From 3542e944cb4554ae7df550d1f4769a4244f516ef Mon Sep 17 00:00:00 2001 From: KKRainbow <443152178@qq.com> Date: Fri, 1 May 2026 18:51:39 +0800 Subject: [PATCH] fix(quic): prune stopped endpoints from pool (#2195) * remove wss port 0 compatibility code * fix(quic): prune stopped endpoints from pool --- easytier/src/common/dns.rs | 29 ++-- easytier/src/tests/credential_tests.rs | 47 +++++- easytier/src/tunnel/quic.rs | 218 +++++++++++++++++++++++-- 3 files changed, 260 insertions(+), 34 deletions(-) diff --git a/easytier/src/common/dns.rs b/easytier/src/common/dns.rs index 20922526..5c047a46 100644 --- a/easytier/src/common/dns.rs +++ b/easytier/src/common/dns.rs @@ -73,16 +73,6 @@ pub async fn socket_addrs( .port() .or_else(default_port_number) .ok_or(Error::InvalidUrl(url.to_string()))?; - // See https://github.com/EasyTier/EasyTier/pull/947 - // here is for compatibility with old version - let port = match port { - 0 => match url.scheme() { - "ws" => 80, - "wss" => 443, - _ => port, - }, - _ => port, - }; // if host is an ip address, return it directly match host { @@ -139,4 +129,23 @@ mod tests { assert_eq!(2, addrs.len(), "addrs: {:?}", addrs); println!("addrs2: {:?}", addrs); } + + #[tokio::test] + async fn socket_addrs_preserves_explicit_zero_port() { + let cases = [ + ("ws://127.0.0.1:0", 80, 0), + ("wss://127.0.0.1:0", 443, 0), + ("ws://127.0.0.1", 80, 80), + ("wss://127.0.0.1", 443, 443), + ]; + + for (raw_url, default_port, expected_port) in cases { + let url = url::Url::parse(raw_url).unwrap(); + let addrs = socket_addrs(&url, || Some(default_port)).await.unwrap(); + assert_eq!( + addrs, + vec![SocketAddr::from(([127, 0, 0, 1], expected_port))] + ); + } + } } diff --git a/easytier/src/tests/credential_tests.rs b/easytier/src/tests/credential_tests.rs index e3652da0..0907a322 100644 --- a/easytier/src/tests/credential_tests.rs +++ b/easytier/src/tests/credential_tests.rs @@ -247,12 +247,14 @@ fn create_public_server_config() -> TomlConfigLoader { config } -fn create_need_p2p_admin_config() -> TomlConfigLoader { +fn create_need_p2p_admin_config(listener_scheme: &str) -> TomlConfigLoader { let config = TomlConfigLoader::default(); config.set_inst_name(NEED_P2P_ADMIN_NETWORK_NAME.to_string()); config.set_hostname(Some("need-p2p-admin".to_string())); config.set_netns(Some("ns_c3".to_string())); - config.set_listeners(vec!["tcp://0.0.0.0:11020".parse().unwrap()]); + config.set_listeners(vec![ + format!("{listener_scheme}://0.0.0.0:0").parse().unwrap(), + ]); config.set_network_identity(NetworkIdentity::new( NEED_P2P_ADMIN_NETWORK_NAME.to_string(), PUBLIC_SERVER_SHARED_SECRET.to_string(), @@ -326,6 +328,21 @@ async fn wait_direct_peer(inst: &Instance, peer_id: u32, timeout: Duration, labe .await; } +async fn wait_running_listener(inst: &Instance, scheme: &str, timeout: Duration, label: &str) { + wait_for_condition( + || async { + let listeners = inst.get_global_ctx().get_running_listeners(); + let matched = listeners.iter().any(|listener| { + listener.scheme() == scheme && listener.port().is_some_and(|p| p != 0) + }); + println!("{label}: running listeners={:?}", listeners); + matched + }, + timeout, + ) + .await; +} + async fn wait_route_cost(inst: &Instance, peer_id: u32, cost: i32, timeout: Duration, label: &str) { wait_for_condition( || async { @@ -370,18 +387,32 @@ async fn wait_foreign_network_count(inst: &Instance, expected: usize, timeout: D /// Public server <- admin peer (need_p2p) <- two credential peers. /// /// Credential peers set `disable_p2p=true`, while the admin peer advertises `need_p2p=true`. -/// The credential peers should still proactively build direct TCP peers with the admin peer -/// through peer RPC forwarded by the public server. +/// The credential peers should still proactively build direct peers with the admin peer through +/// peer RPC forwarded by the public server, even when the admin listener binds an ephemeral port. +#[rstest] +#[case("quic")] +#[case("wss")] +#[case("tcp")] +#[case("udp")] #[tokio::test] #[serial_test::serial] -async fn credential_peers_p2p_to_need_p2p_admin_through_public_server() { +async fn credential_peers_p2p_to_need_p2p_admin_through_public_server( + #[case] admin_listener_scheme: &str, +) { prepare_credential_network(); let mut public_server_inst = Instance::new(create_public_server_config()); public_server_inst.run().await.unwrap(); - let mut admin_inst = Instance::new(create_need_p2p_admin_config()); + let mut admin_inst = Instance::new(create_need_p2p_admin_config(admin_listener_scheme)); admin_inst.run().await.unwrap(); + wait_running_listener( + &admin_inst, + admin_listener_scheme, + Duration::from_secs(10), + "admin ephemeral listener", + ) + .await; admin_inst .get_conn_manager() .add_connector(UdpTunnelConnector::new( @@ -458,8 +489,8 @@ async fn credential_peers_p2p_to_need_p2p_admin_through_public_server() { let credential_a_peer_id = credential_a_inst.peer_id(); let credential_b_peer_id = credential_b_inst.peer_id(); println!( - "admin={}, credential_a={}, credential_b={}", - admin_peer_id, credential_a_peer_id, credential_b_peer_id + "admin={}, credential_a={}, credential_b={}, admin_listener_scheme={}", + admin_peer_id, credential_a_peer_id, credential_b_peer_id, admin_listener_scheme ); wait_direct_peer( diff --git a/easytier/src/tunnel/quic.rs b/easytier/src/tunnel/quic.rs index 2a30396b..c4304dae 100644 --- a/easytier/src/tunnel/quic.rs +++ b/easytier/src/tunnel/quic.rs @@ -14,8 +14,8 @@ use derivative::Derivative; use derive_more::{Deref, DerefMut}; use parking_lot::RwLock; use quinn::{ - ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig, TransportConfig, - congestion::BbrConfig, default_runtime, + ClientConfig, ConnectError, Connection, Endpoint, EndpointConfig, ServerConfig, + TransportConfig, congestion::BbrConfig, default_runtime, }; use std::net::{Ipv4Addr, Ipv6Addr}; use std::sync::OnceLock; @@ -135,6 +135,12 @@ impl RwPool { self.resize(); } + fn len(&self) -> usize { + let persistent_len = self.persistent.read().len(); + let ephemeral_len = self.ephemeral.read().len(); + persistent_len + ephemeral_len + } + /// try to push an item to the ephemeral pool, return the item if full fn try_push(&self, item: Item) -> Option { let mut pool = self.ephemeral.write(); @@ -168,6 +174,49 @@ impl RwPool { f(&mut persistent.iter().chain(ephemeral.iter())) } } + +impl RwPool { + fn retain_endpoints(&self, mut keep: F) -> usize + where + F: FnMut(&Endpoint) -> bool, + { + let persistent_removed = { + let mut persistent = self.persistent.write(); + let before = persistent.len(); + persistent.retain(|endpoint| keep(endpoint)); + before - persistent.len() + }; + + let ephemeral_removed = { + let mut ephemeral = self.ephemeral.write(); + let before = ephemeral.len(); + ephemeral.retain(|endpoint| keep(endpoint)); + before - ephemeral.len() + }; + + let removed = persistent_removed + ephemeral_removed; + if removed > 0 { + self.resize(); + } + removed + } + + fn remove_by_local_addr(&self, local_addr: SocketAddr) -> usize { + self.retain_endpoints(|endpoint| endpoint.local_addr().ok() != Some(local_addr)) + } + + fn contains_local_addr(&self, local_addr: SocketAddr) -> bool { + self.persistent + .read() + .iter() + .any(|endpoint| endpoint.local_addr().ok() == Some(local_addr)) + || self + .ephemeral + .read() + .iter() + .any(|endpoint| endpoint.local_addr().ok() == Some(local_addr)) + } +} //endregion //region endpoint manager @@ -262,6 +311,20 @@ impl QuicEndpointManager { QUIC_ENDPOINT_MANAGER.get().unwrap() } + fn client_pool(&self, ip_version: IpVersion) -> &RwPool { + let dual_stack = self.both.is_enabled(); + match ip_version { + IpVersion::V4 if !dual_stack => &self.ipv4, + _ => { + if dual_stack { + &self.both + } else { + &self.ipv6 + } + } + } + } + /// Get a QUIC endpoint to be used as a server /// /// # Arguments @@ -288,14 +351,8 @@ impl QuicEndpointManager { Ok(endpoint) } - /// Get a quic endpoint to be used as a client - /// - /// # Arguments - /// * `ip_version`: the IP version of the remote address - fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> Result { - let mgr = Self::load(global_ctx); - - let (pool, endpoint) = mgr.create(|mgr| { + fn client_endpoint(&self, ip_version: IpVersion) -> Result { + let (pool, endpoint) = self.create(|mgr| { let dual_stack = mgr.both.is_enabled(); let (pool, addr) = match ip_version { IpVersion::V4 if !dual_stack => (&mgr.ipv4, (Ipv4Addr::UNSPECIFIED, 0).into()), @@ -318,6 +375,26 @@ impl QuicEndpointManager { Ok(pool.with_iter(|iter| iter.min_by_key(|e| e.open_connections()).unwrap().clone())) } + fn remove_endpoint(&self, endpoint: &Endpoint) -> usize { + let Ok(local_addr) = endpoint.local_addr() else { + return 0; + }; + self.remove_endpoint_by_local_addr(local_addr) + } + + fn remove_endpoint_by_local_addr(&self, local_addr: SocketAddr) -> usize { + [&self.ipv4, &self.ipv6, &self.both] + .into_iter() + .map(|pool| pool.remove_by_local_addr(local_addr)) + .sum() + } + + fn contains_local_addr(&self, local_addr: SocketAddr) -> bool { + [&self.ipv4, &self.ipv6, &self.both] + .into_iter() + .any(|pool| pool.contains_local_addr(local_addr)) + } + async fn connect( global_ctx: &ArcGlobalCtx, addr: SocketAddr, @@ -327,14 +404,52 @@ impl QuicEndpointManager { } else { IpVersion::V6 }; - let endpoint = Self::client(global_ctx, ip_version)?; - let connection = endpoint - .connect(addr, "localhost") - .with_context(|| format!("failed to create connection to {}", addr))? + Self::load(global_ctx) + .connect_with_ip_version(addr, ip_version) .await - .with_context(|| format!("failed to connect to {}", addr))?; + } - Ok((endpoint, connection)) + async fn connect_with_ip_version( + &self, + addr: SocketAddr, + ip_version: IpVersion, + ) -> Result<(Endpoint, Connection), TunnelError> { + let max_endpoint_stopping_retries = self.client_pool(ip_version).len().saturating_add(1); + let mut endpoint_stopping_retries = 0; + + loop { + let endpoint = self.client_endpoint(ip_version)?; + let connecting = match endpoint.connect(addr, "localhost") { + Ok(connecting) => connecting, + Err(ConnectError::EndpointStopping) => { + let local_addr = endpoint.local_addr().ok(); + let removed = self.remove_endpoint(&endpoint); + endpoint_stopping_retries += 1; + tracing::warn!( + ?addr, + ?local_addr, + removed, + "removed stopped quic endpoint and retry connect" + ); + if endpoint_stopping_retries > max_endpoint_stopping_retries { + return Err(anyhow::Error::new(ConnectError::EndpointStopping) + .context(format!("failed to create connection to {}", addr)) + .into()); + } + continue; + } + Err(e) => { + return Err(anyhow::Error::new(e) + .context(format!("failed to create connection to {}", addr)) + .into()); + } + }; + let connection = connecting + .await + .with_context(|| format!("failed to connect to {}", addr))?; + + return Ok((endpoint, connection)); + } } } //endregion @@ -398,6 +513,18 @@ impl QuicTunnelListener { } } +impl Drop for QuicTunnelListener { + fn drop(&mut self) { + let Some(endpoint) = &self.endpoint else { + return; + }; + let Ok(local_addr) = endpoint.local_addr() else { + return; + }; + QuicEndpointManager::load(&self.global_ctx).remove_endpoint_by_local_addr(local_addr); + } +} + #[async_trait::async_trait] impl TunnelListener for QuicTunnelListener { async fn listen(&mut self) -> Result<(), TunnelError> { @@ -516,6 +643,20 @@ mod tests { get_mock_global_ctx_with_network(Some(identity)) } + fn stopped_client_endpoint() -> (Endpoint, SocketAddr) { + let rt = Builder::new_current_thread().enable_all().build().unwrap(); + let endpoint = rt.block_on(async { + QuicEndpointManager::try_create((Ipv4Addr::UNSPECIFIED, 0).into(), false).unwrap() + }); + let local_addr = endpoint.local_addr().unwrap(); + drop(rt); + assert!(matches!( + endpoint.connect("127.0.0.1:1".parse().unwrap(), "localhost"), + Err(ConnectError::EndpointStopping) + )); + (endpoint, local_addr) + } + #[test] fn quic_pingpong() { RUNTIME.block_on(quic_pingpong_impl()) @@ -591,6 +732,51 @@ mod tests { assert!(port > 0); } + #[test] + fn listener_drop_removes_persistent_endpoint() { + RUNTIME.block_on(listener_drop_removes_persistent_endpoint_impl()) + } + async fn listener_drop_removes_persistent_endpoint_impl() { + let global_ctx = global_ctx(); + let endpoint_addr = { + let mut listener = + QuicTunnelListener::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx.clone()); + listener.listen().await.unwrap(); + let endpoint_addr = listener.endpoint.as_ref().unwrap().local_addr().unwrap(); + assert!(QuicEndpointManager::load(&global_ctx).contains_local_addr(endpoint_addr)); + endpoint_addr + }; + + assert!(!QuicEndpointManager::load(&global_ctx).contains_local_addr(endpoint_addr)); + } + + #[test] + fn connect_removes_stopped_endpoints_and_retries() { + let (stopped_endpoint_a, stopped_addr_a) = stopped_client_endpoint(); + let (stopped_endpoint_b, stopped_addr_b) = stopped_client_endpoint(); + + RUNTIME.block_on(async move { + let mgr = QuicEndpointManager::new(2); + mgr.both.push(stopped_endpoint_a); + mgr.both.push(stopped_endpoint_b); + assert!(mgr.contains_local_addr(stopped_addr_a)); + assert!(mgr.contains_local_addr(stopped_addr_b)); + + let err = mgr + .connect_with_ip_version("127.0.0.1:0".parse().unwrap(), IpVersion::V4) + .await + .unwrap_err(); + let err = format!("{:?}", err); + assert!( + err.contains("invalid remote address"), + "unexpected error: {}", + err + ); + assert!(!mgr.contains_local_addr(stopped_addr_a)); + assert!(!mgr.contains_local_addr(stopped_addr_b)); + }); + } + #[test] fn invalid_peer_addr() { RUNTIME.block_on(invalid_peer_addr_impl())