use std::{ sync::{atomic::AtomicU32, Arc, Mutex}, time::Duration, }; use anyhow::Context as _; use futures::{SinkExt as _, StreamExt}; use tokio::task::JoinSet; use crate::{ common::join_joinset_background, proto::rpc_types::{__rt::RpcClientFactory, error::Error}, tunnel::{Tunnel, TunnelConnector, TunnelListener}, }; use super::{client::Client, server::Server, service_registry::ServiceRegistry}; struct StandAloneServerOneTunnel { tunnel: Box, rpc_server: Server, } impl StandAloneServerOneTunnel { pub fn new(tunnel: Box, registry: Arc) -> Self { let rpc_server = Server::new_with_registry(registry); StandAloneServerOneTunnel { tunnel, rpc_server } } pub async fn run(self) { use tokio_stream::StreamExt as _; let (tunnel_rx, tunnel_tx) = self.tunnel.split(); let (rpc_rx, rpc_tx) = ( self.rpc_server.get_transport_stream(), self.rpc_server.get_transport_sink(), ); let mut tasks = JoinSet::new(); tasks.spawn(async move { let ret = tunnel_rx.timeout(Duration::from_secs(60)); tokio::pin!(ret); while let Ok(Some(Ok(p))) = ret.try_next().await { if let Err(e) = rpc_tx.send(p).await { tracing::error!("tunnel_rx send to rpc_tx error: {:?}", e); break; } } tracing::info!("forward tunnel_rx to rpc_tx done"); }); tasks.spawn(async move { let ret = rpc_rx.forward(tunnel_tx).await; tracing::info!("rpc_rx forward tunnel_tx done: {:?}", ret); }); self.rpc_server.run(); while let Some(ret) = tasks.join_next().await { self.rpc_server.close(); tracing::info!("task done: {:?}", ret); } tracing::info!("all tasks done"); } } pub struct StandAloneServer { registry: Arc, listener: Option, inflight_server: Arc, tasks: Arc>>, } impl StandAloneServer { pub fn new(listener: L) -> Self { StandAloneServer { registry: Arc::new(ServiceRegistry::new()), listener: Some(listener), inflight_server: Arc::new(AtomicU32::new(0)), tasks: Arc::new(Mutex::new(JoinSet::new())), } } pub fn registry(&self) -> &ServiceRegistry { &self.registry } pub async fn serve(&mut self) -> Result<(), Error> { let tasks = self.tasks.clone(); let mut listener = self.listener.take().unwrap(); let registry = self.registry.clone(); join_joinset_background(tasks.clone(), "standalone server tasks".to_string()); listener .listen() .await .with_context(|| "failed to listen")?; let inflight_server = self.inflight_server.clone(); self.tasks.lock().unwrap().spawn(async move { while let Ok(tunnel) = listener.accept().await { let server = StandAloneServerOneTunnel::new(tunnel, registry.clone()); let inflight_server = inflight_server.clone(); inflight_server.fetch_add(1, std::sync::atomic::Ordering::Relaxed); tasks.lock().unwrap().spawn(async move { server.run().await; inflight_server.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); }); } panic!("standalone server listener exit"); }); Ok(()) } pub fn inflight_server(&self) -> u32 { self.inflight_server .load(std::sync::atomic::Ordering::Relaxed) } } struct StandAloneClientOneTunnel { rpc_client: Client, tasks: Arc>>, error: Arc>>, } impl StandAloneClientOneTunnel { pub fn new(tunnel: Box) -> Self { let rpc_client = Client::new(); let (mut rpc_rx, rpc_tx) = ( rpc_client.get_transport_stream(), rpc_client.get_transport_sink(), ); let tasks = Arc::new(Mutex::new(JoinSet::new())); let (mut tunnel_rx, mut tunnel_tx) = tunnel.split(); let error_store = Arc::new(Mutex::new(None)); let error = error_store.clone(); tasks.lock().unwrap().spawn(async move { while let Some(p) = rpc_rx.next().await { match p { Ok(p) => { if let Err(e) = tunnel_tx .send(p) .await .with_context(|| "failed to send packet") { *error.lock().unwrap() = Some(e.into()); } } Err(e) => { *error.lock().unwrap() = Some(anyhow::Error::from(e).into()); } } } *error.lock().unwrap() = Some(anyhow::anyhow!("rpc_rx next exit").into()); }); let error = error_store.clone(); tasks.lock().unwrap().spawn(async move { while let Some(p) = tunnel_rx.next().await { match p { Ok(p) => { if let Err(e) = rpc_tx .send(p) .await .with_context(|| "failed to send packet") { *error.lock().unwrap() = Some(e.into()); } } Err(e) => { *error.lock().unwrap() = Some(anyhow::Error::from(e).into()); } } } *error.lock().unwrap() = Some(anyhow::anyhow!("tunnel_rx next exit").into()); }); rpc_client.run(); StandAloneClientOneTunnel { rpc_client, tasks, error: error_store, } } pub fn take_error(&self) -> Option { self.error.lock().unwrap().take() } } pub struct StandAloneClient { connector: C, client: Option, } impl StandAloneClient { pub fn new(connector: C) -> Self { StandAloneClient { connector, client: None, } } async fn connect(&mut self) -> Result, Error> { Ok(self.connector.connect().await.with_context(|| { format!( "failed to connect to server: {:?}", self.connector.remote_url() ) })?) } pub async fn scoped_client( &mut self, domain_name: String, ) -> Result { let mut c = self.client.take(); let error = c.as_ref().and_then(|c| c.take_error()); if c.is_none() || error.is_some() { tracing::info!("reconnect due to error: {:?}", error); let tunnel = self.connect().await?; c = Some(StandAloneClientOneTunnel::new(tunnel)); } self.client = c; Ok(self .client .as_ref() .unwrap() .rpc_client .scoped_client::(1, 1, domain_name)) } }