fix(quic): prune stopped endpoints from pool (#2195)

* remove wss port 0 compatibility code
* fix(quic): prune stopped endpoints from pool
This commit is contained in:
KKRainbow
2026-05-01 18:51:39 +08:00
committed by GitHub
parent 852d1c9e14
commit 3542e944cb
3 changed files with 260 additions and 34 deletions
+202 -16
View File
@@ -14,8 +14,8 @@ use derivative::Derivative;
use derive_more::{Deref, DerefMut};
use parking_lot::RwLock;
use quinn::{
ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig, TransportConfig,
congestion::BbrConfig, default_runtime,
ClientConfig, ConnectError, Connection, Endpoint, EndpointConfig, ServerConfig,
TransportConfig, congestion::BbrConfig, default_runtime,
};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::OnceLock;
@@ -135,6 +135,12 @@ impl<Item> RwPool<Item> {
self.resize();
}
fn len(&self) -> usize {
let persistent_len = self.persistent.read().len();
let ephemeral_len = self.ephemeral.read().len();
persistent_len + ephemeral_len
}
/// try to push an item to the ephemeral pool, return the item if full
fn try_push(&self, item: Item) -> Option<Item> {
let mut pool = self.ephemeral.write();
@@ -168,6 +174,49 @@ impl<Item> RwPool<Item> {
f(&mut persistent.iter().chain(ephemeral.iter()))
}
}
impl RwPool<Endpoint> {
fn retain_endpoints<F>(&self, mut keep: F) -> usize
where
F: FnMut(&Endpoint) -> bool,
{
let persistent_removed = {
let mut persistent = self.persistent.write();
let before = persistent.len();
persistent.retain(|endpoint| keep(endpoint));
before - persistent.len()
};
let ephemeral_removed = {
let mut ephemeral = self.ephemeral.write();
let before = ephemeral.len();
ephemeral.retain(|endpoint| keep(endpoint));
before - ephemeral.len()
};
let removed = persistent_removed + ephemeral_removed;
if removed > 0 {
self.resize();
}
removed
}
fn remove_by_local_addr(&self, local_addr: SocketAddr) -> usize {
self.retain_endpoints(|endpoint| endpoint.local_addr().ok() != Some(local_addr))
}
fn contains_local_addr(&self, local_addr: SocketAddr) -> bool {
self.persistent
.read()
.iter()
.any(|endpoint| endpoint.local_addr().ok() == Some(local_addr))
|| self
.ephemeral
.read()
.iter()
.any(|endpoint| endpoint.local_addr().ok() == Some(local_addr))
}
}
//endregion
//region endpoint manager
@@ -262,6 +311,20 @@ impl QuicEndpointManager {
QUIC_ENDPOINT_MANAGER.get().unwrap()
}
fn client_pool(&self, ip_version: IpVersion) -> &RwPool<Endpoint> {
let dual_stack = self.both.is_enabled();
match ip_version {
IpVersion::V4 if !dual_stack => &self.ipv4,
_ => {
if dual_stack {
&self.both
} else {
&self.ipv6
}
}
}
}
/// Get a QUIC endpoint to be used as a server
///
/// # Arguments
@@ -288,14 +351,8 @@ impl QuicEndpointManager {
Ok(endpoint)
}
/// Get a quic endpoint to be used as a client
///
/// # Arguments
/// * `ip_version`: the IP version of the remote address
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> Result<Endpoint, TunnelError> {
let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| {
fn client_endpoint(&self, ip_version: IpVersion) -> Result<Endpoint, TunnelError> {
let (pool, endpoint) = self.create(|mgr| {
let dual_stack = mgr.both.is_enabled();
let (pool, addr) = match ip_version {
IpVersion::V4 if !dual_stack => (&mgr.ipv4, (Ipv4Addr::UNSPECIFIED, 0).into()),
@@ -318,6 +375,26 @@ impl QuicEndpointManager {
Ok(pool.with_iter(|iter| iter.min_by_key(|e| e.open_connections()).unwrap().clone()))
}
fn remove_endpoint(&self, endpoint: &Endpoint) -> usize {
let Ok(local_addr) = endpoint.local_addr() else {
return 0;
};
self.remove_endpoint_by_local_addr(local_addr)
}
fn remove_endpoint_by_local_addr(&self, local_addr: SocketAddr) -> usize {
[&self.ipv4, &self.ipv6, &self.both]
.into_iter()
.map(|pool| pool.remove_by_local_addr(local_addr))
.sum()
}
fn contains_local_addr(&self, local_addr: SocketAddr) -> bool {
[&self.ipv4, &self.ipv6, &self.both]
.into_iter()
.any(|pool| pool.contains_local_addr(local_addr))
}
async fn connect(
global_ctx: &ArcGlobalCtx,
addr: SocketAddr,
@@ -327,14 +404,52 @@ impl QuicEndpointManager {
} else {
IpVersion::V6
};
let endpoint = Self::client(global_ctx, ip_version)?;
let connection = endpoint
.connect(addr, "localhost")
.with_context(|| format!("failed to create connection to {}", addr))?
Self::load(global_ctx)
.connect_with_ip_version(addr, ip_version)
.await
.with_context(|| format!("failed to connect to {}", addr))?;
}
Ok((endpoint, connection))
async fn connect_with_ip_version(
&self,
addr: SocketAddr,
ip_version: IpVersion,
) -> Result<(Endpoint, Connection), TunnelError> {
let max_endpoint_stopping_retries = self.client_pool(ip_version).len().saturating_add(1);
let mut endpoint_stopping_retries = 0;
loop {
let endpoint = self.client_endpoint(ip_version)?;
let connecting = match endpoint.connect(addr, "localhost") {
Ok(connecting) => connecting,
Err(ConnectError::EndpointStopping) => {
let local_addr = endpoint.local_addr().ok();
let removed = self.remove_endpoint(&endpoint);
endpoint_stopping_retries += 1;
tracing::warn!(
?addr,
?local_addr,
removed,
"removed stopped quic endpoint and retry connect"
);
if endpoint_stopping_retries > max_endpoint_stopping_retries {
return Err(anyhow::Error::new(ConnectError::EndpointStopping)
.context(format!("failed to create connection to {}", addr))
.into());
}
continue;
}
Err(e) => {
return Err(anyhow::Error::new(e)
.context(format!("failed to create connection to {}", addr))
.into());
}
};
let connection = connecting
.await
.with_context(|| format!("failed to connect to {}", addr))?;
return Ok((endpoint, connection));
}
}
}
//endregion
@@ -398,6 +513,18 @@ impl QuicTunnelListener {
}
}
impl Drop for QuicTunnelListener {
fn drop(&mut self) {
let Some(endpoint) = &self.endpoint else {
return;
};
let Ok(local_addr) = endpoint.local_addr() else {
return;
};
QuicEndpointManager::load(&self.global_ctx).remove_endpoint_by_local_addr(local_addr);
}
}
#[async_trait::async_trait]
impl TunnelListener for QuicTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
@@ -516,6 +643,20 @@ mod tests {
get_mock_global_ctx_with_network(Some(identity))
}
fn stopped_client_endpoint() -> (Endpoint, SocketAddr) {
let rt = Builder::new_current_thread().enable_all().build().unwrap();
let endpoint = rt.block_on(async {
QuicEndpointManager::try_create((Ipv4Addr::UNSPECIFIED, 0).into(), false).unwrap()
});
let local_addr = endpoint.local_addr().unwrap();
drop(rt);
assert!(matches!(
endpoint.connect("127.0.0.1:1".parse().unwrap(), "localhost"),
Err(ConnectError::EndpointStopping)
));
(endpoint, local_addr)
}
#[test]
fn quic_pingpong() {
RUNTIME.block_on(quic_pingpong_impl())
@@ -591,6 +732,51 @@ mod tests {
assert!(port > 0);
}
#[test]
fn listener_drop_removes_persistent_endpoint() {
RUNTIME.block_on(listener_drop_removes_persistent_endpoint_impl())
}
async fn listener_drop_removes_persistent_endpoint_impl() {
let global_ctx = global_ctx();
let endpoint_addr = {
let mut listener =
QuicTunnelListener::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx.clone());
listener.listen().await.unwrap();
let endpoint_addr = listener.endpoint.as_ref().unwrap().local_addr().unwrap();
assert!(QuicEndpointManager::load(&global_ctx).contains_local_addr(endpoint_addr));
endpoint_addr
};
assert!(!QuicEndpointManager::load(&global_ctx).contains_local_addr(endpoint_addr));
}
#[test]
fn connect_removes_stopped_endpoints_and_retries() {
let (stopped_endpoint_a, stopped_addr_a) = stopped_client_endpoint();
let (stopped_endpoint_b, stopped_addr_b) = stopped_client_endpoint();
RUNTIME.block_on(async move {
let mgr = QuicEndpointManager::new(2);
mgr.both.push(stopped_endpoint_a);
mgr.both.push(stopped_endpoint_b);
assert!(mgr.contains_local_addr(stopped_addr_a));
assert!(mgr.contains_local_addr(stopped_addr_b));
let err = mgr
.connect_with_ip_version("127.0.0.1:0".parse().unwrap(), IpVersion::V4)
.await
.unwrap_err();
let err = format!("{:?}", err);
assert!(
err.contains("invalid remote address"),
"unexpected error: {}",
err
);
assert!(!mgr.contains_local_addr(stopped_addr_a));
assert!(!mgr.contains_local_addr(stopped_addr_b));
});
}
#[test]
fn invalid_peer_addr() {
RUNTIME.block_on(invalid_peer_addr_impl())