use std::sync::{Arc, Mutex, atomic::AtomicBool}; use futures::{SinkExt as _, StreamExt}; use guarden::defer; use tokio::{task::JoinSet, time::timeout}; use crate::{ proto::rpc_types::error::Error, tunnel::{Tunnel, packet_def::PacketType, ring::create_ring_tunnel_pair}, }; use super::{client::Client, server::Server, service_registry::ServiceRegistry}; use crate::common::stats_manager::StatsManager; pub struct BidirectRpcManager { rpc_client: Client, rpc_server: Server, rx_timeout: Option, error: Arc>>, tunnel: Mutex>>, running: Arc, tasks: Mutex>>, } impl Default for BidirectRpcManager { fn default() -> Self { Self::new() } } impl BidirectRpcManager { pub fn new() -> Self { Self { rpc_client: Client::new(), rpc_server: Server::new(), rx_timeout: None, error: Arc::new(Mutex::new(None)), tunnel: Mutex::new(None), running: Arc::new(AtomicBool::new(false)), tasks: Mutex::new(None), } } pub fn new_with_stats_manager(stats_manager: Arc) -> Self { Self { rpc_client: Client::new_with_stats_manager(stats_manager.clone()), rpc_server: Server::new_with_registry_and_stats_manager( Arc::new(ServiceRegistry::new()), stats_manager, ), rx_timeout: None, error: Arc::new(Mutex::new(None)), tunnel: Mutex::new(None), running: Arc::new(AtomicBool::new(false)), tasks: Mutex::new(None), } } pub fn set_rx_timeout(mut self, timeout: Option) -> Self { self.rx_timeout = timeout; self } pub fn run_and_create_tunnel(&self) -> Box { let (ret, inner) = create_ring_tunnel_pair(); self.run_with_tunnel(inner); ret } pub fn run_with_tunnel(&self, inner: Box) { let mut tasks = JoinSet::new(); self.rpc_client.run(); self.rpc_server.run(); self.running .store(true, std::sync::atomic::Ordering::Relaxed); let (server_tx, mut server_rx) = ( self.rpc_server.get_transport_sink(), self.rpc_server.get_transport_stream(), ); let (client_tx, mut client_rx) = ( self.rpc_client.get_transport_sink(), self.rpc_client.get_transport_stream(), ); let (mut inner_rx, mut inner_tx) = inner.split(); self.tunnel.lock().unwrap().replace(inner); let e_clone = self.error.clone(); let r_clone = self.running.clone(); tasks.spawn(async move { defer! { r_clone.store(false, std::sync::atomic::Ordering::Relaxed); } loop { let packet = tokio::select! { Some(Ok(packet)) = server_rx.next() => { tracing::trace!(?packet, "recv rpc packet from server"); packet } Some(Ok(packet)) = client_rx.next() => { tracing::trace!(?packet, "recv rpc packet from client"); packet } else => { tracing::warn!("rpc transport read aborted, exiting"); break; } }; if let Err(e) = inner_tx.send(packet).await { tracing::error!(error = ?e, "send to peer failed"); e_clone.lock().unwrap().replace(Error::from(e)); } } }); let recv_timeout = self.rx_timeout; let e_clone = self.error.clone(); let r_clone = self.running.clone(); tasks.spawn(async move { defer! { r_clone.store(false, std::sync::atomic::Ordering::Relaxed); } loop { let ret = if let Some(recv_timeout) = recv_timeout { match timeout(recv_timeout, inner_rx.next()).await { Ok(ret) => ret, Err(e) => { e_clone.lock().unwrap().replace(e.into()); break; } } } else { inner_rx.next().await }; let o = match ret { Some(Ok(o)) => o, Some(Err(e)) => { tracing::error!(error = ?e, "recv from peer failed"); e_clone.lock().unwrap().replace(Error::from(e)); break; } None => { tracing::warn!("peer rpc transport read aborted, exiting"); e_clone.lock().unwrap().replace(Error::Shutdown); break; } }; let Some(peer_manager_header) = o.peer_manager_header() else { tracing::error!("peer manager header not found"); continue; }; if peer_manager_header.packet_type == PacketType::RpcReq as u8 { server_tx.send(o).await.unwrap(); continue; } else if peer_manager_header.packet_type == PacketType::RpcResp as u8 { client_tx.send(o).await.unwrap(); continue; } } }); self.tasks.lock().unwrap().replace(tasks); } pub fn rpc_client(&self) -> &Client { &self.rpc_client } pub fn rpc_server(&self) -> &Server { &self.rpc_server } pub async fn stop(&self) { let Some(mut tasks) = self.tasks.lock().unwrap().take() else { return; }; tasks.abort_all(); while tasks.join_next().await.is_some() {} } pub fn take_error(&self) -> Option { self.error.lock().unwrap().take() } pub async fn wait(&self) { let Some(mut tasks) = self.tasks.lock().unwrap().take() else { return; }; while tasks.join_next().await.is_some() { // when any task is done, abort all tasks tasks.abort_all(); } } pub fn is_running(&self) -> bool { self.running.load(std::sync::atomic::Ordering::Relaxed) } }