refactor: remove NoGroAsyncUdpSocket (#1867)

This commit is contained in:
Luna Yao
2026-04-10 17:22:08 +02:00
committed by GitHub
parent 19c80c7b9c
commit 8311b11713
12 changed files with 401 additions and 172 deletions
+358 -138
View File
@@ -2,22 +2,25 @@
//!
//! Checkout the `README.md` for guidance.
use std::{
error::Error, io::IoSliceMut, net::SocketAddr, pin::Pin, sync::Arc, task::Poll, time::Duration,
};
use super::{FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
use crate::common::global_ctx::ArcGlobalCtx;
use crate::tunnel::{
FromUrl, TunnelInfo,
common::{FramedReader, FramedWriter, TunnelWrapper, setup_sokcet2},
TunnelInfo,
common::{FramedReader, FramedWriter, TunnelWrapper, setup_socket2},
};
use anyhow::Context;
use super::{IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
use derivative::Derivative;
use derive_more::{Deref, DerefMut};
use parking_lot::RwLock;
use quinn::{
AsyncUdpSocket, ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig,
TransportConfig, UdpPoller, congestion::BbrConfig, udp::RecvMeta,
ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig, TransportConfig,
congestion::BbrConfig, default_runtime,
};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::OnceLock;
use std::{net::SocketAddr, sync::Arc, time::Duration};
// region config
pub fn transport_config() -> Arc<TransportConfig> {
let mut config = TransportConfig::default();
@@ -50,86 +53,287 @@ pub fn endpoint_config() -> EndpointConfig {
config.max_udp_payload_size(65527).unwrap();
config
}
//endregion
#[derive(Clone, Debug)]
struct NoGroAsyncUdpSocket {
inner: Arc<dyn AsyncUdpSocket>,
//region rw pool
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
#[derive(Debug, Deref, DerefMut)]
struct RwPoolInner<Item> {
#[deref]
#[deref_mut]
pool: Vec<Item>,
enabled: bool,
}
impl AsyncUdpSocket for NoGroAsyncUdpSocket {
fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>> {
self.inner.clone().create_io_poller()
#[derive(Debug)]
struct RwPool<Item> {
ephemeral: RwLock<RwPoolInner<Item>>,
persistent: RwLock<RwPoolInner<Item>>,
capacity: usize,
}
impl<Item> RwPool<Item> {
fn new(capacity: usize) -> Self {
Self {
ephemeral: RwLock::new(RwPoolInner::default()),
persistent: RwLock::new(RwPoolInner::default()),
capacity,
}
}
fn try_send(&self, transmit: &quinn::udp::Transmit) -> std::io::Result<()> {
self.inner.try_send(transmit)
}
/// Receive UDP datagrams, or register to be woken if receiving may succeed in the future
fn poll_recv(
/// return the capacity of the ephemeral pool;
/// if `ephemeral` or `persistent` is None, read lock `self`'s pool
fn capacity(
&self,
cx: &mut std::task::Context,
bufs: &mut [IoSliceMut<'_>],
meta: &mut [RecvMeta],
) -> Poll<std::io::Result<usize>> {
self.inner.poll_recv(cx, bufs, meta)
ephemeral: Option<&RwPoolInner<Item>>,
persistent: Option<&RwPoolInner<Item>>,
) -> usize {
let guard;
let ephemeral = if let Some(ephemeral) = ephemeral {
ephemeral
} else {
guard = self.ephemeral.read();
&guard
};
let guard;
let persistent = if let Some(persistent) = persistent {
persistent
} else {
guard = self.persistent.read();
&guard
};
(self.capacity * ephemeral.enabled as usize).saturating_sub(persistent.len())
}
/// Look up the local IP address and port used by this socket
fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.inner.local_addr()
fn is_full(&self) -> bool {
let pool = self.ephemeral.read();
pool.len() >= self.capacity(Some(&pool), None)
}
fn may_fragment(&self) -> bool {
self.inner.may_fragment()
fn is_enabled(&self) -> bool {
self.ephemeral.read().enabled
}
fn max_transmit_segments(&self) -> usize {
self.inner.max_transmit_segments()
fn enable(&self) {
self.ephemeral.write().enabled = true;
self.resize();
}
fn max_receive_segments(&self) -> usize {
1
fn disable(&self) {
self.ephemeral.write().enabled = false;
self.resize();
}
/// push an item to the persistent pool
fn push(&self, item: Item) {
self.persistent.write().push(item);
self.resize();
}
/// try to push an item to the ephemeral pool, return the item if full
fn try_push(&self, item: Item) -> Option<Item> {
let mut pool = self.ephemeral.write();
if pool.len() < self.capacity(Some(&pool), None) {
pool.push(item);
return None;
}
Some(item)
}
fn resize(&self) {
let resize = {
let pool = self.ephemeral.read();
pool.capacity() != self.capacity(Some(&pool), None)
};
if resize {
let mut pool = self.ephemeral.write();
let capacity = self.capacity(Some(&pool), None);
pool.reserve_exact(capacity);
pool.truncate(capacity);
pool.shrink_to(capacity);
}
}
fn with_iter<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut dyn Iterator<Item = &Item>) -> R,
{
let ephemeral = self.ephemeral.read();
let persistent = self.persistent.read();
f(&mut persistent.iter().chain(ephemeral.iter()))
}
}
//endregion
//region endpoint manager
#[derive(Debug)]
pub struct QuicEndpointManager {
ipv4: RwPool<Endpoint>,
ipv6: RwPool<Endpoint>,
both: RwPool<Endpoint>,
}
static QUIC_ENDPOINT_MANAGER: OnceLock<QuicEndpointManager> = OnceLock::new();
impl QuicEndpointManager {
fn try_create(addr: SocketAddr, dual_stack: bool) -> std::io::Result<Endpoint> {
let socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_socket2(&socket, &addr, addr.is_ipv6() && !dual_stack)
.map_err(std::io::Error::other)?;
let socket = std::net::UdpSocket::from(socket);
let runtime = default_runtime().ok_or(std::io::Error::other("no async runtime found"))?;
let mut endpoint = Endpoint::new_with_abstract_socket(
endpoint_config(),
None,
runtime.wrap_udp_socket(socket)?,
runtime,
)?;
endpoint.set_default_client_config(client_config());
Ok(endpoint)
}
fn create<F>(&self, mut selector: F) -> std::io::Result<(&RwPool<Endpoint>, Option<Endpoint>)>
where
F: FnMut(&QuicEndpointManager) -> (&RwPool<Endpoint>, Option<(SocketAddr, bool)>),
{
loop {
let (pool, r) = selector(self);
let Some((addr, dual_stack)) = r else {
return Ok((pool, None));
};
let endpoint = Self::try_create(addr, dual_stack);
if let Err(e) = endpoint.as_ref()
&& dual_stack
{
tracing::warn!("create dual stack quic endpoint failed: {:?}", e);
self.both.disable();
self.ipv4.enable();
self.ipv6.enable();
continue;
}
return Ok((pool, Some(endpoint?)));
}
}
}
/// Constructs a QUIC endpoint configured to listen for incoming connections on a certain address
/// and port.
///
/// ## Returns
///
/// - an [`Endpoint`] configured to accept incoming QUIC connections
#[allow(unused)]
pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<Endpoint, Box<dyn Error>> {
let server_config = server_config();
let client_config = client_config();
let endpoint_config = endpoint_config();
impl QuicEndpointManager {
fn new(capacity: usize) -> Self {
let ipv4 = RwPool::new(capacity.div_ceil(2));
let ipv6 = RwPool::new(capacity.div_ceil(2));
let both = RwPool::new(capacity);
both.enable();
Self { ipv4, ipv6, both }
}
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(bind_addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_sokcet2(&socket2_socket, &bind_addr)?;
let socket = std::net::UdpSocket::from(socket2_socket);
fn load(global_ctx: &ArcGlobalCtx) -> &Self {
let capacity = global_ctx
.config
.get_flags()
.multi_thread
.then(std::thread::available_parallelism)
.and_then(|r| r.ok())
.map(|n| n.get())
.unwrap_or(1);
let runtime =
quinn::default_runtime().ok_or_else(|| std::io::Error::other("no async runtime found"))?;
let socket: NoGroAsyncUdpSocket = NoGroAsyncUdpSocket {
inner: runtime.wrap_udp_socket(socket)?,
};
let mut endpoint = Endpoint::new_with_abstract_socket(
endpoint_config,
Some(server_config),
Arc::new(socket),
runtime,
)?;
endpoint.set_default_client_config(client_config);
Ok(endpoint)
let mgr = QUIC_ENDPOINT_MANAGER.get();
match mgr {
Some(mgr) => {
for pool in [&mgr.ipv4, &mgr.ipv6, &mgr.both] {
pool.resize();
}
}
None => {
let _ = QUIC_ENDPOINT_MANAGER.set(Self::new(capacity));
}
}
QUIC_ENDPOINT_MANAGER.get().unwrap()
}
/// Get a QUIC endpoint to be used as a server
///
/// # Arguments
/// * `addr`: listen address
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> std::io::Result<Endpoint> {
let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| {
let dual_stack = addr.ip() == Ipv6Addr::UNSPECIFIED && mgr.both.is_enabled();
let pool = if addr.is_ipv4() {
&mgr.ipv4
} else if dual_stack {
&mgr.both
} else {
&mgr.ipv6
};
(pool, Some((addr, dual_stack)))
})?;
let endpoint = endpoint.expect("server endpoint creation should not return None");
endpoint.set_server_config(Some(server_config()));
pool.push(endpoint.clone());
Ok(endpoint)
}
/// Get a quic endpoint to be used as a client
///
/// # Arguments
/// * `ip_version`: the IP version of the remote address
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> std::io::Result<Endpoint> {
let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| {
let dual_stack = mgr.both.is_enabled();
let (pool, addr) = match ip_version {
IpVersion::V4 if !dual_stack => (&mgr.ipv4, (Ipv4Addr::UNSPECIFIED, 0).into()),
_ => {
let pool = if dual_stack { &mgr.both } else { &mgr.ipv6 };
(pool, (Ipv6Addr::UNSPECIFIED, 0).into())
}
};
if pool.is_full() {
(pool, None)
} else {
(pool, Some((addr, dual_stack)))
}
})?;
if let Some(endpoint) = endpoint {
pool.try_push(endpoint);
}
Ok(pool.with_iter(|iter| iter.min_by_key(|e| e.open_connections()).unwrap().clone()))
}
async fn connect(
global_ctx: &ArcGlobalCtx,
addr: SocketAddr,
) -> std::io::Result<(Endpoint, Connection)> {
let ip_version = if addr.ip().is_ipv4() {
IpVersion::V4
} else {
IpVersion::V6
};
let endpoint = Self::client(global_ctx, ip_version)?;
let connection = endpoint
.connect(addr, "localhost")
.map_err(std::io::Error::other)?
.await?;
Ok((endpoint, connection))
}
}
#[allow(unused)]
pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"];
//endregion
struct ConnWrapper {
conn: Connection,
@@ -143,13 +347,15 @@ impl Drop for ConnWrapper {
pub struct QuicTunnelListener {
addr: url::Url,
global_ctx: ArcGlobalCtx,
endpoint: Option<Endpoint>,
}
impl QuicTunnelListener {
pub fn new(addr: url::Url) -> Self {
pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self {
QuicTunnelListener {
addr,
global_ctx,
endpoint: None,
}
}
@@ -192,13 +398,11 @@ impl QuicTunnelListener {
impl TunnelListener for QuicTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let endpoint = make_server_endpoint(addr)
.map_err(|e| anyhow::anyhow!("make server endpoint error: {:?}", e))?;
self.endpoint = Some(endpoint);
let endpoint = QuicEndpointManager::server(&self.global_ctx, addr)?;
self.addr
.set_port(Some(self.endpoint.as_ref().unwrap().local_addr()?.port()))
.set_port(Some(endpoint.local_addr()?.port()))
.unwrap();
self.endpoint = Some(endpoint);
Ok(())
}
@@ -222,15 +426,15 @@ impl TunnelListener for QuicTunnelListener {
pub struct QuicTunnelConnector {
addr: url::Url,
endpoint: Option<Endpoint>,
global_ctx: ArcGlobalCtx,
ip_version: IpVersion,
}
impl QuicTunnelConnector {
pub fn new(addr: url::Url) -> Self {
pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self {
QuicTunnelConnector {
addr,
endpoint: None,
global_ctx,
ip_version: IpVersion::Both,
}
}
@@ -240,38 +444,10 @@ impl QuicTunnelConnector {
impl TunnelConnector for QuicTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if addr.port() == 0 {
return Err(TunnelError::InvalidAddr(format!(
"invalid remote QUIC port 0 in url: {} (port 0 is not a valid QUIC port)",
self.addr
)));
}
let local_addr = if addr.is_ipv4() {
"0.0.0.0:0"
} else {
"[::]:0"
};
let mut endpoint = Endpoint::client(local_addr.parse().unwrap())?;
endpoint.set_default_client_config(client_config());
// connect to server
let connection = endpoint
.connect(addr, "localhost")
.map_err(|e| {
TunnelError::InvalidAddr(format!(
"failed to create QUIC connection, url: {}, error: {}",
self.addr, e
))
})?
.await
.with_context(|| "connect failed")?;
tracing::info!("[client] connected: addr={}", connection.remote_address());
let (endpoint, connection) = QuicEndpointManager::connect(&self.global_ctx, addr).await?;
let local_addr = endpoint.local_addr()?;
self.endpoint = Some(endpoint);
let (w, r) = connection
.open_bi()
.await
@@ -308,68 +484,112 @@ impl TunnelConnector for QuicTunnelConnector {
#[cfg(test)]
mod tests {
use crate::common::global_ctx::tests::get_mock_global_ctx_with_network;
use crate::tunnel::{
IpVersion, TunnelConnector,
TunnelConnector,
common::tests::{_tunnel_bench, _tunnel_pingpong},
};
use std::sync::LazyLock;
use tokio::runtime::{Builder, Runtime};
use super::*;
#[tokio::test]
async fn quic_pingpong() {
let listener = QuicTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap());
let connector = QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap());
// Shared runtime for all tests to avoid endpoint invalidation across runtimes
static RUNTIME: LazyLock<Runtime> =
LazyLock::new(|| Builder::new_multi_thread().enable_all().build().unwrap());
fn global_ctx() -> ArcGlobalCtx {
let identity = crate::common::config::NetworkIdentity::default();
get_mock_global_ctx_with_network(Some(identity))
}
#[test]
fn quic_pingpong() {
RUNTIME.block_on(quic_pingpong_impl())
}
async fn quic_pingpong_impl() {
let listener = QuicTunnelListener::new("quic://[::]:21011".parse().unwrap(), global_ctx());
let connector =
QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap(), global_ctx());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn quic_bench() {
let listener = QuicTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap());
let connector = QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap());
#[test]
fn quic_bench() {
RUNTIME.block_on(quic_bench_impl())
}
async fn quic_bench_impl() {
let listener = QuicTunnelListener::new("quic://[::]:21012".parse().unwrap(), global_ctx());
let connector =
QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap(), global_ctx());
_tunnel_bench(listener, connector).await
}
#[tokio::test]
async fn ipv6_pingpong() {
let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap());
let connector = QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap());
#[test]
fn ipv6_pingpong() {
RUNTIME.block_on(ipv6_pingpong_impl())
}
async fn ipv6_pingpong_impl() {
let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap(), global_ctx());
let connector =
QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap(), global_ctx());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn ipv6_domain_pingpong() {
let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap());
let mut connector =
QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
#[test]
fn ipv6_domain_pingpong() {
RUNTIME.block_on(ipv6_domain_pingpong_impl())
}
async fn ipv6_domain_pingpong_impl() {
let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap(), global_ctx());
let mut connector = QuicTunnelConnector::new(
"quic://test.easytier.top:31016".parse().unwrap(),
global_ctx(),
);
connector.set_ip_version(IpVersion::V6);
_tunnel_pingpong(listener, connector).await;
let listener = QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap());
let mut connector =
QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
let listener =
QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap(), global_ctx());
let mut connector = QuicTunnelConnector::new(
"quic://test.easytier.top:31016".parse().unwrap(),
global_ctx(),
);
connector.set_ip_version(IpVersion::V4);
_tunnel_pingpong(listener, connector).await;
}
#[tokio::test]
async fn test_alloc_port() {
#[test]
fn alloc_port() {
RUNTIME.block_on(alloc_port_impl())
}
async fn alloc_port_impl() {
// v4
let mut listener = QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap());
let mut listener =
QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap(), global_ctx());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
assert!(port > 0);
// v6
let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap());
let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap(), global_ctx());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
assert!(port > 0);
}
#[tokio::test]
async fn quic_connector_reject_port_zero() {
let mut connector = QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap());
let err = connector.connect().await.unwrap_err().to_string();
assert!(err.contains("port 0"), "unexpected error: {}", err);
#[test]
fn invalid_peer_addr() {
RUNTIME.block_on(invalid_peer_addr_impl())
}
async fn invalid_peer_addr_impl() {
let mut connector =
QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx());
let err = connector.connect().await.unwrap_err();
assert!(
err.to_string().contains("invalid remote address"),
"unexpected error: {:?}",
err
);
}
}