diff --git a/Cargo.lock b/Cargo.lock index 36c870d7..6db0df14 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2291,6 +2291,7 @@ dependencies = [ "machine-uid", "maplit", "mimalloc", + "moka", "multimap", "natpmp", "netlink-packet-core", @@ -5103,9 +5104,12 @@ version = "0.12.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a9321642ca94a4282428e6ea4af8cc2ca4eac48ac7a6a4ea8f33f76d0ce70926" dependencies = [ + "async-lock", "crossbeam-channel", "crossbeam-epoch", "crossbeam-utils", + "event-listener", + "futures-util", "loom", "parking_lot", "portable-atomic", diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 1bbb2b8a..4b645546 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -70,6 +70,7 @@ async-stream = "0.3.5" async-trait = "0.1.74" dashmap = "6.0" +moka = { version = "0.12", features = ["future"] } timedmap = "=1.0.1" # for full-path zero-copy diff --git a/easytier/src/gateway/quic_proxy.rs b/easytier/src/gateway/quic_proxy.rs index 5d9d90ab..0019aaf6 100644 --- a/easytier/src/gateway/quic_proxy.rs +++ b/easytier/src/gateway/quic_proxy.rs @@ -25,6 +25,7 @@ use dashmap::DashMap; use derivative::Derivative; use derive_more::{Constructor, Deref, DerefMut, From, Into}; use guarden::defer; +use moka::future::Cache; use prost::Message; use quinn::udp::{EcnCodepoint, RecvMeta, Transmit}; use quinn::{ @@ -43,8 +44,8 @@ use tokio::io::{AsyncReadExt, Join, join}; use tokio::sync::mpsc::error::TrySendError; use tokio::sync::mpsc::{Receiver, Sender, channel}; use tokio::task::JoinSet; -use tokio::time::{Instant, timeout}; -use tokio::{join, pin, select}; +use tokio::time::timeout; +use tokio::{join, select}; use tokio_util::sync::PollSender; use tracing::{debug, error, info, instrument, trace, warn}; @@ -279,6 +280,7 @@ impl From<(SendStream, RecvStream)> for QuicStream { pub struct NatDstQuicConnector { pub(crate) endpoint: Endpoint, pub(crate) peer_mgr: Weak, + pub(crate) conn_map: Cache, } #[async_trait::async_trait] @@ -302,7 +304,6 @@ impl NatDstConnector for NatDstQuicConnector { }; trace!("quic nat dst: {:?}, dst peers: {:?}", nat_dst, dst_peer_id); - let addr = QuicAddr::new(dst_peer_id, PacketType::QuicSrc).into(); let header = { let conn_data = QuicConnData { @@ -323,50 +324,65 @@ impl NatDstConnector for NatDstQuicConnector { buf.freeze() }; - let mut connect_tasks = JoinSet::>::new(); - let connect = |tasks: &mut JoinSet<_>| { + for attempt in 0..2 { let endpoint = self.endpoint.clone(); - let header = header.clone(); - tasks.spawn(async move { - let connection = endpoint.connect(addr, "")?.await?; - let mut stream: QuicStream = connection.open_bi().await?.into(); - stream.writer_mut().write_chunk(header).await?; - Ok(stream) - }); - }; - - connect(&mut connect_tasks); - - let timer = tokio::time::sleep(Duration::from_millis(200)); - pin!(timer); - - let mut retry_remain = 5; - loop { - select! { - Some(result) = connect_tasks.join_next() => { - match result { - Ok(Ok(stream)) => return Ok(stream.into()), - _ => { - if connect_tasks.is_empty() { - if retry_remain == 0 { - return Err(anyhow!("failed to connect to nat dst: {:?}", nat_dst).into()) - } - - retry_remain -= 1; - connect(&mut connect_tasks); - timer.as_mut().reset(Instant::now() + Duration::from_millis(200)) - } - } + let connection = match self + .conn_map + .try_get_with(dst_peer_id, async move { + endpoint + .connect(addr, "") + .map_err(|e| anyhow!("quic connect: {:#}", e))? + .await + .map_err(|e| anyhow!("quic connection: {:#}", e)) + }) + .await + { + Ok(conn) => conn, + Err(e) => { + if attempt == 0 { + debug!("quic connect failed, retrying: {:#}", e); + tokio::time::sleep(Duration::from_millis(300)).await; + continue; } + return Err(anyhow!("{:#}", e).into()); } - _ = &mut timer, if retry_remain > 0 => { - retry_remain -= 1; - connect(&mut connect_tasks); - timer.as_mut().reset(Instant::now() + Duration::from_millis(200)); + }; + + let stream: Result = async { + let mut stream: QuicStream = connection + .open_bi() + .await + .map_err(|e| anyhow!("open bi: {:#}", e))? + .into(); + stream.writer_mut().write_chunk(header.clone()).await?; + Ok(stream.into()) + } + .await; + + match stream { + Ok(stream) => return Ok(stream), + Err(error) => { + debug!( + ?dst_peer_id, + attempt, + ?error, + "quic connect: stream setup failed" + ); } } + + // Evict stale connection; + self.conn_map.invalidate(&dst_peer_id).await; } + + Err(anyhow!( + "quic connect: failed after {} attempts, dst_peer_id={}, nat_dst={}", + 2, + dst_peer_id, + nat_dst + ) + .into()) } #[inline] @@ -595,10 +611,17 @@ impl QuicStreamReceiver { } }; - match Self::establish_stream(stream, ctx.clone()).await { - Ok(stream) => drop(tasks.spawn(stream)), - Err(e) => warn!("failed to establish quic stream from {:?}: {:?}", connection.remote_address(), e), - } + let ctx = ctx.clone(); + tasks.spawn(async move { + match Self::establish_stream(stream, ctx).await { + Ok(transfer_fut) => { + if let Err(e) = transfer_fut.await { + warn!("quic stream transfer error: {:?}", e); + } + } + Err(e) => warn!("failed to establish quic stream: {:?}", e), + } + }); } res = tasks.join_next(), if !tasks.is_empty() => { @@ -840,11 +863,26 @@ impl QuicProxy { return; } + let conn_map = Cache::builder() + .max_capacity(u8::MAX.into()) // same with max_concurrent_bidi_streams, can be increased + .time_to_idle(Duration::from_secs(600)) + .build(); + + let conn_map_bg = conn_map.clone(); + self.tasks.spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + loop { + interval.tick().await; + conn_map_bg.run_pending_tasks().await; + } + }); + let tcp_proxy = TcpProxyForQuicSrc(TcpProxy::new( peer_mgr.clone(), NatDstQuicConnector { endpoint: endpoint.clone(), peer_mgr: Arc::downgrade(&peer_mgr), + conn_map, }, ));