mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 10:14:35 +00:00
fix(quic): prune stopped endpoints from pool (#2195)
* remove wss port 0 compatibility code * fix(quic): prune stopped endpoints from pool
This commit is contained in:
+202
-16
@@ -14,8 +14,8 @@ use derivative::Derivative;
|
||||
use derive_more::{Deref, DerefMut};
|
||||
use parking_lot::RwLock;
|
||||
use quinn::{
|
||||
ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig, TransportConfig,
|
||||
congestion::BbrConfig, default_runtime,
|
||||
ClientConfig, ConnectError, Connection, Endpoint, EndpointConfig, ServerConfig,
|
||||
TransportConfig, congestion::BbrConfig, default_runtime,
|
||||
};
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::OnceLock;
|
||||
@@ -135,6 +135,12 @@ impl<Item> RwPool<Item> {
|
||||
self.resize();
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
let persistent_len = self.persistent.read().len();
|
||||
let ephemeral_len = self.ephemeral.read().len();
|
||||
persistent_len + ephemeral_len
|
||||
}
|
||||
|
||||
/// try to push an item to the ephemeral pool, return the item if full
|
||||
fn try_push(&self, item: Item) -> Option<Item> {
|
||||
let mut pool = self.ephemeral.write();
|
||||
@@ -168,6 +174,49 @@ impl<Item> RwPool<Item> {
|
||||
f(&mut persistent.iter().chain(ephemeral.iter()))
|
||||
}
|
||||
}
|
||||
|
||||
impl RwPool<Endpoint> {
|
||||
fn retain_endpoints<F>(&self, mut keep: F) -> usize
|
||||
where
|
||||
F: FnMut(&Endpoint) -> bool,
|
||||
{
|
||||
let persistent_removed = {
|
||||
let mut persistent = self.persistent.write();
|
||||
let before = persistent.len();
|
||||
persistent.retain(|endpoint| keep(endpoint));
|
||||
before - persistent.len()
|
||||
};
|
||||
|
||||
let ephemeral_removed = {
|
||||
let mut ephemeral = self.ephemeral.write();
|
||||
let before = ephemeral.len();
|
||||
ephemeral.retain(|endpoint| keep(endpoint));
|
||||
before - ephemeral.len()
|
||||
};
|
||||
|
||||
let removed = persistent_removed + ephemeral_removed;
|
||||
if removed > 0 {
|
||||
self.resize();
|
||||
}
|
||||
removed
|
||||
}
|
||||
|
||||
fn remove_by_local_addr(&self, local_addr: SocketAddr) -> usize {
|
||||
self.retain_endpoints(|endpoint| endpoint.local_addr().ok() != Some(local_addr))
|
||||
}
|
||||
|
||||
fn contains_local_addr(&self, local_addr: SocketAddr) -> bool {
|
||||
self.persistent
|
||||
.read()
|
||||
.iter()
|
||||
.any(|endpoint| endpoint.local_addr().ok() == Some(local_addr))
|
||||
|| self
|
||||
.ephemeral
|
||||
.read()
|
||||
.iter()
|
||||
.any(|endpoint| endpoint.local_addr().ok() == Some(local_addr))
|
||||
}
|
||||
}
|
||||
//endregion
|
||||
|
||||
//region endpoint manager
|
||||
@@ -262,6 +311,20 @@ impl QuicEndpointManager {
|
||||
QUIC_ENDPOINT_MANAGER.get().unwrap()
|
||||
}
|
||||
|
||||
fn client_pool(&self, ip_version: IpVersion) -> &RwPool<Endpoint> {
|
||||
let dual_stack = self.both.is_enabled();
|
||||
match ip_version {
|
||||
IpVersion::V4 if !dual_stack => &self.ipv4,
|
||||
_ => {
|
||||
if dual_stack {
|
||||
&self.both
|
||||
} else {
|
||||
&self.ipv6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a QUIC endpoint to be used as a server
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -288,14 +351,8 @@ impl QuicEndpointManager {
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
/// Get a quic endpoint to be used as a client
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `ip_version`: the IP version of the remote address
|
||||
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> Result<Endpoint, TunnelError> {
|
||||
let mgr = Self::load(global_ctx);
|
||||
|
||||
let (pool, endpoint) = mgr.create(|mgr| {
|
||||
fn client_endpoint(&self, ip_version: IpVersion) -> Result<Endpoint, TunnelError> {
|
||||
let (pool, endpoint) = self.create(|mgr| {
|
||||
let dual_stack = mgr.both.is_enabled();
|
||||
let (pool, addr) = match ip_version {
|
||||
IpVersion::V4 if !dual_stack => (&mgr.ipv4, (Ipv4Addr::UNSPECIFIED, 0).into()),
|
||||
@@ -318,6 +375,26 @@ impl QuicEndpointManager {
|
||||
Ok(pool.with_iter(|iter| iter.min_by_key(|e| e.open_connections()).unwrap().clone()))
|
||||
}
|
||||
|
||||
fn remove_endpoint(&self, endpoint: &Endpoint) -> usize {
|
||||
let Ok(local_addr) = endpoint.local_addr() else {
|
||||
return 0;
|
||||
};
|
||||
self.remove_endpoint_by_local_addr(local_addr)
|
||||
}
|
||||
|
||||
fn remove_endpoint_by_local_addr(&self, local_addr: SocketAddr) -> usize {
|
||||
[&self.ipv4, &self.ipv6, &self.both]
|
||||
.into_iter()
|
||||
.map(|pool| pool.remove_by_local_addr(local_addr))
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn contains_local_addr(&self, local_addr: SocketAddr) -> bool {
|
||||
[&self.ipv4, &self.ipv6, &self.both]
|
||||
.into_iter()
|
||||
.any(|pool| pool.contains_local_addr(local_addr))
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
global_ctx: &ArcGlobalCtx,
|
||||
addr: SocketAddr,
|
||||
@@ -327,14 +404,52 @@ impl QuicEndpointManager {
|
||||
} else {
|
||||
IpVersion::V6
|
||||
};
|
||||
let endpoint = Self::client(global_ctx, ip_version)?;
|
||||
let connection = endpoint
|
||||
.connect(addr, "localhost")
|
||||
.with_context(|| format!("failed to create connection to {}", addr))?
|
||||
Self::load(global_ctx)
|
||||
.connect_with_ip_version(addr, ip_version)
|
||||
.await
|
||||
.with_context(|| format!("failed to connect to {}", addr))?;
|
||||
}
|
||||
|
||||
Ok((endpoint, connection))
|
||||
async fn connect_with_ip_version(
|
||||
&self,
|
||||
addr: SocketAddr,
|
||||
ip_version: IpVersion,
|
||||
) -> Result<(Endpoint, Connection), TunnelError> {
|
||||
let max_endpoint_stopping_retries = self.client_pool(ip_version).len().saturating_add(1);
|
||||
let mut endpoint_stopping_retries = 0;
|
||||
|
||||
loop {
|
||||
let endpoint = self.client_endpoint(ip_version)?;
|
||||
let connecting = match endpoint.connect(addr, "localhost") {
|
||||
Ok(connecting) => connecting,
|
||||
Err(ConnectError::EndpointStopping) => {
|
||||
let local_addr = endpoint.local_addr().ok();
|
||||
let removed = self.remove_endpoint(&endpoint);
|
||||
endpoint_stopping_retries += 1;
|
||||
tracing::warn!(
|
||||
?addr,
|
||||
?local_addr,
|
||||
removed,
|
||||
"removed stopped quic endpoint and retry connect"
|
||||
);
|
||||
if endpoint_stopping_retries > max_endpoint_stopping_retries {
|
||||
return Err(anyhow::Error::new(ConnectError::EndpointStopping)
|
||||
.context(format!("failed to create connection to {}", addr))
|
||||
.into());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(anyhow::Error::new(e)
|
||||
.context(format!("failed to create connection to {}", addr))
|
||||
.into());
|
||||
}
|
||||
};
|
||||
let connection = connecting
|
||||
.await
|
||||
.with_context(|| format!("failed to connect to {}", addr))?;
|
||||
|
||||
return Ok((endpoint, connection));
|
||||
}
|
||||
}
|
||||
}
|
||||
//endregion
|
||||
@@ -398,6 +513,18 @@ impl QuicTunnelListener {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for QuicTunnelListener {
|
||||
fn drop(&mut self) {
|
||||
let Some(endpoint) = &self.endpoint else {
|
||||
return;
|
||||
};
|
||||
let Ok(local_addr) = endpoint.local_addr() else {
|
||||
return;
|
||||
};
|
||||
QuicEndpointManager::load(&self.global_ctx).remove_endpoint_by_local_addr(local_addr);
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TunnelListener for QuicTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
@@ -516,6 +643,20 @@ mod tests {
|
||||
get_mock_global_ctx_with_network(Some(identity))
|
||||
}
|
||||
|
||||
fn stopped_client_endpoint() -> (Endpoint, SocketAddr) {
|
||||
let rt = Builder::new_current_thread().enable_all().build().unwrap();
|
||||
let endpoint = rt.block_on(async {
|
||||
QuicEndpointManager::try_create((Ipv4Addr::UNSPECIFIED, 0).into(), false).unwrap()
|
||||
});
|
||||
let local_addr = endpoint.local_addr().unwrap();
|
||||
drop(rt);
|
||||
assert!(matches!(
|
||||
endpoint.connect("127.0.0.1:1".parse().unwrap(), "localhost"),
|
||||
Err(ConnectError::EndpointStopping)
|
||||
));
|
||||
(endpoint, local_addr)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quic_pingpong() {
|
||||
RUNTIME.block_on(quic_pingpong_impl())
|
||||
@@ -591,6 +732,51 @@ mod tests {
|
||||
assert!(port > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn listener_drop_removes_persistent_endpoint() {
|
||||
RUNTIME.block_on(listener_drop_removes_persistent_endpoint_impl())
|
||||
}
|
||||
async fn listener_drop_removes_persistent_endpoint_impl() {
|
||||
let global_ctx = global_ctx();
|
||||
let endpoint_addr = {
|
||||
let mut listener =
|
||||
QuicTunnelListener::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx.clone());
|
||||
listener.listen().await.unwrap();
|
||||
let endpoint_addr = listener.endpoint.as_ref().unwrap().local_addr().unwrap();
|
||||
assert!(QuicEndpointManager::load(&global_ctx).contains_local_addr(endpoint_addr));
|
||||
endpoint_addr
|
||||
};
|
||||
|
||||
assert!(!QuicEndpointManager::load(&global_ctx).contains_local_addr(endpoint_addr));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_removes_stopped_endpoints_and_retries() {
|
||||
let (stopped_endpoint_a, stopped_addr_a) = stopped_client_endpoint();
|
||||
let (stopped_endpoint_b, stopped_addr_b) = stopped_client_endpoint();
|
||||
|
||||
RUNTIME.block_on(async move {
|
||||
let mgr = QuicEndpointManager::new(2);
|
||||
mgr.both.push(stopped_endpoint_a);
|
||||
mgr.both.push(stopped_endpoint_b);
|
||||
assert!(mgr.contains_local_addr(stopped_addr_a));
|
||||
assert!(mgr.contains_local_addr(stopped_addr_b));
|
||||
|
||||
let err = mgr
|
||||
.connect_with_ip_version("127.0.0.1:0".parse().unwrap(), IpVersion::V4)
|
||||
.await
|
||||
.unwrap_err();
|
||||
let err = format!("{:?}", err);
|
||||
assert!(
|
||||
err.contains("invalid remote address"),
|
||||
"unexpected error: {}",
|
||||
err
|
||||
);
|
||||
assert!(!mgr.contains_local_addr(stopped_addr_a));
|
||||
assert!(!mgr.contains_local_addr(stopped_addr_b));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_peer_addr() {
|
||||
RUNTIME.block_on(invalid_peer_addr_impl())
|
||||
|
||||
Reference in New Issue
Block a user