refactor: remove NoGroAsyncUdpSocket (#1867)

This commit is contained in:
Luna Yao
2026-04-10 17:22:08 +02:00
committed by GitHub
parent 19c80c7b9c
commit 8311b11713
12 changed files with 401 additions and 172 deletions
+3 -1
View File
@@ -105,7 +105,9 @@ pub async fn create_connector_by_url(
IpScheme::Tcp => TcpTunnelConnector::new(url).boxed(), IpScheme::Tcp => TcpTunnelConnector::new(url).boxed(),
IpScheme::Udp => UdpTunnelConnector::new(url).boxed(), IpScheme::Udp => UdpTunnelConnector::new(url).boxed(),
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
IpScheme::Quic => tunnel::quic::QuicTunnelConnector::new(url).boxed(), IpScheme::Quic => {
tunnel::quic::QuicTunnelConnector::new(url, global_ctx.clone()).boxed()
}
#[cfg(feature = "wireguard")] #[cfg(feature = "wireguard")]
IpScheme::Wg => { IpScheme::Wg => {
use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector}; use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
+6 -4
View File
@@ -26,7 +26,9 @@ use derivative::Derivative;
use derive_more::{Constructor, Deref, DerefMut, From, Into}; use derive_more::{Constructor, Deref, DerefMut, From, Into};
use prost::Message; use prost::Message;
use quinn::udp::{EcnCodepoint, RecvMeta, Transmit}; use quinn::udp::{EcnCodepoint, RecvMeta, Transmit};
use quinn::{AsyncUdpSocket, Endpoint, RecvStream, SendStream, StreamId, TokioRuntime, UdpPoller}; use quinn::{
AsyncUdpSocket, Endpoint, RecvStream, SendStream, StreamId, UdpPoller, default_runtime,
};
use std::cmp::min; use std::cmp::min;
use std::future::Future; use std::future::Future;
use std::io::IoSliceMut; use std::io::IoSliceMut;
@@ -806,7 +808,7 @@ impl QuicProxy {
endpoint_config(), endpoint_config(),
Some(server_config()), Some(server_config()),
Arc::new(socket), Arc::new(socket),
Arc::new(TokioRuntime), default_runtime().unwrap(),
) )
.unwrap(); .unwrap();
endpoint.set_default_client_config(client_config()); endpoint.set_default_client_config(client_config());
@@ -1020,7 +1022,7 @@ mod tests {
endpoint_config.clone(), endpoint_config.clone(),
Some(server_config.clone()), Some(server_config.clone()),
socket_client.clone(), socket_client.clone(),
Arc::new(TokioRuntime), default_runtime().unwrap(),
) )
.unwrap(); .unwrap();
client_endpoint.set_default_client_config(client_config.clone()); client_endpoint.set_default_client_config(client_config.clone());
@@ -1030,7 +1032,7 @@ mod tests {
endpoint_config.clone(), endpoint_config.clone(),
Some(server_config.clone()), Some(server_config.clone()),
socket_server.clone(), socket_server.clone(),
Arc::new(TokioRuntime), default_runtime().unwrap(),
) )
.unwrap(); .unwrap();
server_endpoint.set_default_client_config(client_config.clone()); server_endpoint.set_default_client_config(client_config.clone());
+3 -3
View File
@@ -31,7 +31,7 @@ use crate::{
tokio_smoltcp::{BufferSize, Net, NetConfig, channel_device}, tokio_smoltcp::{BufferSize, Net, NetConfig, channel_device},
}, },
tunnel::{ tunnel::{
common::setup_sokcet2, common::setup_socket2,
packet_def::{PacketType, ZCPacket}, packet_def::{PacketType, ZCPacket},
}, },
}; };
@@ -336,7 +336,7 @@ fn bind_tcp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<TcpListener, Error
Some(socket2::Protocol::TCP), Some(socket2::Protocol::TCP),
)?; )?;
setup_sokcet2(&socket2_socket, &addr)?; setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into()); let socket = TcpSocket::from_std_stream(socket2_socket.into());
@@ -355,7 +355,7 @@ fn bind_udp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<UdpSocket, Error>
Some(socket2::Protocol::UDP), Some(socket2::Protocol::UDP),
)?; )?;
setup_sokcet2(&socket2_socket, &addr)?; setup_socket2(&socket2_socket, &addr, true)?;
Ok(UdpSocket::from_std(socket2_socket.into())?) Ok(UdpSocket::from_std(socket2_socket.into())?)
} }
+2 -2
View File
@@ -29,7 +29,7 @@ use crate::{
gateway::ip_reassembler::{ComposeIpv4PacketArgs, compose_ipv4_packet}, gateway::ip_reassembler::{ComposeIpv4PacketArgs, compose_ipv4_packet},
peers::{PeerPacketFilter, peer_manager::PeerManager}, peers::{PeerPacketFilter, peer_manager::PeerManager},
tunnel::{ tunnel::{
common::{reserve_buf, setup_sokcet2}, common::{reserve_buf, setup_socket2},
packet_def::{PacketType, ZCPacket}, packet_def::{PacketType, ZCPacket},
}, },
}; };
@@ -72,7 +72,7 @@ impl UdpNatEntry {
Some(socket2::Protocol::UDP), Some(socket2::Protocol::UDP),
)?; )?;
let dst_socket_addr = "0.0.0.0:0".parse().unwrap(); let dst_socket_addr = "0.0.0.0:0".parse().unwrap();
setup_sokcet2(&socket2_socket, &dst_socket_addr)?; setup_socket2(&socket2_socket, &dst_socket_addr, true)?;
Some(UdpSocket::from_std(socket2_socket.into())?) Some(UdpSocket::from_std(socket2_socket.into())?)
}; };
+5 -3
View File
@@ -25,7 +25,7 @@ use crate::{
pub fn create_listener_by_url( pub fn create_listener_by_url(
l: &url::Url, l: &url::Url,
#[allow(unused_variables)] ctx: ArcGlobalCtx, global_ctx: ArcGlobalCtx,
) -> Result<Box<dyn TunnelListener>, Error> { ) -> Result<Box<dyn TunnelListener>, Error> {
Ok(match l.try_into()? { Ok(match l.try_into()? {
TunnelScheme::Ip(scheme) => match scheme { TunnelScheme::Ip(scheme) => match scheme {
@@ -34,7 +34,7 @@ pub fn create_listener_by_url(
#[cfg(feature = "wireguard")] #[cfg(feature = "wireguard")]
IpScheme::Wg => { IpScheme::Wg => {
use crate::tunnel::wireguard::{WgConfig, WgTunnelListener}; use crate::tunnel::wireguard::{WgConfig, WgTunnelListener};
let nid = ctx.get_network_identity(); let nid = global_ctx.get_network_identity();
let wg_config = WgConfig::new_from_network_identity( let wg_config = WgConfig::new_from_network_identity(
&nid.network_name, &nid.network_name,
&nid.network_secret.unwrap_or_default(), &nid.network_secret.unwrap_or_default(),
@@ -42,7 +42,9 @@ pub fn create_listener_by_url(
WgTunnelListener::new(l.clone(), wg_config).boxed() WgTunnelListener::new(l.clone(), wg_config).boxed()
} }
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
IpScheme::Quic => tunnel::quic::QuicTunnelListener::new(l.clone()).boxed(), IpScheme::Quic => {
tunnel::quic::QuicTunnelListener::new(l.clone(), global_ctx.clone()).boxed()
}
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
IpScheme::Ws | IpScheme::Wss => { IpScheme::Ws | IpScheme::Wss => {
tunnel::websocket::WsTunnelListener::new(l.clone()).boxed() tunnel::websocket::WsTunnelListener::new(l.clone()).boxed()
+7 -4
View File
@@ -344,10 +344,11 @@ pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
None None
} }
pub(crate) fn setup_sokcet2_ext( pub(crate) fn setup_socket2_ext(
socket2_socket: &socket2::Socket, socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr, bind_addr: &SocketAddr,
#[allow(unused_variables)] bind_dev: Option<String>, #[allow(unused_variables)] bind_dev: Option<String>,
only_v6: bool,
) -> Result<(), TunnelError> { ) -> Result<(), TunnelError> {
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
{ {
@@ -356,7 +357,7 @@ pub(crate) fn setup_sokcet2_ext(
} }
if bind_addr.is_ipv6() { if bind_addr.is_ipv6() {
socket2_socket.set_only_v6(true)?; socket2_socket.set_only_v6(only_v6)?;
} }
socket2_socket.set_nonblocking(true)?; socket2_socket.set_nonblocking(true)?;
@@ -428,14 +429,16 @@ where
Err(last_err.unwrap_or(TunnelError::Shutdown)) Err(last_err.unwrap_or(TunnelError::Shutdown))
} }
pub(crate) fn setup_sokcet2( pub(crate) fn setup_socket2(
socket2_socket: &socket2::Socket, socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr, bind_addr: &SocketAddr,
only_v6: bool,
) -> Result<(), TunnelError> { ) -> Result<(), TunnelError> {
setup_sokcet2_ext( setup_socket2_ext(
socket2_socket, socket2_socket,
bind_addr, bind_addr,
super::common::get_interface_name_by_ip(&bind_addr.ip()), super::common::get_interface_name_by_ip(&bind_addr.ip()),
only_v6,
) )
} }
+1 -1
View File
@@ -46,7 +46,7 @@ pub mod unix;
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum TunnelError { pub enum TunnelError {
#[error("io error")] #[error("io error: {0}")]
IOError(#[from] std::io::Error), IOError(#[from] std::io::Error),
#[error("invalid packet. msg: {0}")] #[error("invalid packet. msg: {0}")]
InvalidPacket(String), InvalidPacket(String),
+358 -138
View File
@@ -2,22 +2,25 @@
//! //!
//! Checkout the `README.md` for guidance. //! Checkout the `README.md` for guidance.
use std::{ use super::{FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
error::Error, io::IoSliceMut, net::SocketAddr, pin::Pin, sync::Arc, task::Poll, time::Duration, use crate::common::global_ctx::ArcGlobalCtx;
};
use crate::tunnel::{ use crate::tunnel::{
FromUrl, TunnelInfo, TunnelInfo,
common::{FramedReader, FramedWriter, TunnelWrapper, setup_sokcet2}, common::{FramedReader, FramedWriter, TunnelWrapper, setup_socket2},
}; };
use anyhow::Context; use anyhow::Context;
use derivative::Derivative;
use super::{IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener}; use derive_more::{Deref, DerefMut};
use parking_lot::RwLock;
use quinn::{ use quinn::{
AsyncUdpSocket, ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig, ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig, TransportConfig,
TransportConfig, UdpPoller, congestion::BbrConfig, udp::RecvMeta, congestion::BbrConfig, default_runtime,
}; };
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::OnceLock;
use std::{net::SocketAddr, sync::Arc, time::Duration};
// region config
pub fn transport_config() -> Arc<TransportConfig> { pub fn transport_config() -> Arc<TransportConfig> {
let mut config = TransportConfig::default(); let mut config = TransportConfig::default();
@@ -50,86 +53,287 @@ pub fn endpoint_config() -> EndpointConfig {
config.max_udp_payload_size(65527).unwrap(); config.max_udp_payload_size(65527).unwrap();
config config
} }
//endregion
#[derive(Clone, Debug)] //region rw pool
struct NoGroAsyncUdpSocket { #[derive(Derivative)]
inner: Arc<dyn AsyncUdpSocket>, #[derivative(Default(bound = ""))]
#[derive(Debug, Deref, DerefMut)]
struct RwPoolInner<Item> {
#[deref]
#[deref_mut]
pool: Vec<Item>,
enabled: bool,
} }
impl AsyncUdpSocket for NoGroAsyncUdpSocket { #[derive(Debug)]
fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>> { struct RwPool<Item> {
self.inner.clone().create_io_poller() ephemeral: RwLock<RwPoolInner<Item>>,
persistent: RwLock<RwPoolInner<Item>>,
capacity: usize,
}
impl<Item> RwPool<Item> {
fn new(capacity: usize) -> Self {
Self {
ephemeral: RwLock::new(RwPoolInner::default()),
persistent: RwLock::new(RwPoolInner::default()),
capacity,
}
} }
fn try_send(&self, transmit: &quinn::udp::Transmit) -> std::io::Result<()> { /// return the capacity of the ephemeral pool;
self.inner.try_send(transmit) /// if `ephemeral` or `persistent` is None, read lock `self`'s pool
} fn capacity(
/// Receive UDP datagrams, or register to be woken if receiving may succeed in the future
fn poll_recv(
&self, &self,
cx: &mut std::task::Context, ephemeral: Option<&RwPoolInner<Item>>,
bufs: &mut [IoSliceMut<'_>], persistent: Option<&RwPoolInner<Item>>,
meta: &mut [RecvMeta], ) -> usize {
) -> Poll<std::io::Result<usize>> { let guard;
self.inner.poll_recv(cx, bufs, meta) let ephemeral = if let Some(ephemeral) = ephemeral {
ephemeral
} else {
guard = self.ephemeral.read();
&guard
};
let guard;
let persistent = if let Some(persistent) = persistent {
persistent
} else {
guard = self.persistent.read();
&guard
};
(self.capacity * ephemeral.enabled as usize).saturating_sub(persistent.len())
} }
/// Look up the local IP address and port used by this socket fn is_full(&self) -> bool {
fn local_addr(&self) -> std::io::Result<SocketAddr> { let pool = self.ephemeral.read();
self.inner.local_addr() pool.len() >= self.capacity(Some(&pool), None)
} }
fn may_fragment(&self) -> bool { fn is_enabled(&self) -> bool {
self.inner.may_fragment() self.ephemeral.read().enabled
} }
fn max_transmit_segments(&self) -> usize { fn enable(&self) {
self.inner.max_transmit_segments() self.ephemeral.write().enabled = true;
self.resize();
} }
fn max_receive_segments(&self) -> usize { fn disable(&self) {
1 self.ephemeral.write().enabled = false;
self.resize();
}
/// push an item to the persistent pool
fn push(&self, item: Item) {
self.persistent.write().push(item);
self.resize();
}
/// try to push an item to the ephemeral pool, return the item if full
fn try_push(&self, item: Item) -> Option<Item> {
let mut pool = self.ephemeral.write();
if pool.len() < self.capacity(Some(&pool), None) {
pool.push(item);
return None;
}
Some(item)
}
fn resize(&self) {
let resize = {
let pool = self.ephemeral.read();
pool.capacity() != self.capacity(Some(&pool), None)
};
if resize {
let mut pool = self.ephemeral.write();
let capacity = self.capacity(Some(&pool), None);
pool.reserve_exact(capacity);
pool.truncate(capacity);
pool.shrink_to(capacity);
}
}
fn with_iter<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut dyn Iterator<Item = &Item>) -> R,
{
let ephemeral = self.ephemeral.read();
let persistent = self.persistent.read();
f(&mut persistent.iter().chain(ephemeral.iter()))
}
}
//endregion
//region endpoint manager
#[derive(Debug)]
pub struct QuicEndpointManager {
ipv4: RwPool<Endpoint>,
ipv6: RwPool<Endpoint>,
both: RwPool<Endpoint>,
}
static QUIC_ENDPOINT_MANAGER: OnceLock<QuicEndpointManager> = OnceLock::new();
impl QuicEndpointManager {
fn try_create(addr: SocketAddr, dual_stack: bool) -> std::io::Result<Endpoint> {
let socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_socket2(&socket, &addr, addr.is_ipv6() && !dual_stack)
.map_err(std::io::Error::other)?;
let socket = std::net::UdpSocket::from(socket);
let runtime = default_runtime().ok_or(std::io::Error::other("no async runtime found"))?;
let mut endpoint = Endpoint::new_with_abstract_socket(
endpoint_config(),
None,
runtime.wrap_udp_socket(socket)?,
runtime,
)?;
endpoint.set_default_client_config(client_config());
Ok(endpoint)
}
fn create<F>(&self, mut selector: F) -> std::io::Result<(&RwPool<Endpoint>, Option<Endpoint>)>
where
F: FnMut(&QuicEndpointManager) -> (&RwPool<Endpoint>, Option<(SocketAddr, bool)>),
{
loop {
let (pool, r) = selector(self);
let Some((addr, dual_stack)) = r else {
return Ok((pool, None));
};
let endpoint = Self::try_create(addr, dual_stack);
if let Err(e) = endpoint.as_ref()
&& dual_stack
{
tracing::warn!("create dual stack quic endpoint failed: {:?}", e);
self.both.disable();
self.ipv4.enable();
self.ipv6.enable();
continue;
}
return Ok((pool, Some(endpoint?)));
}
} }
} }
/// Constructs a QUIC endpoint configured to listen for incoming connections on a certain address impl QuicEndpointManager {
/// and port. fn new(capacity: usize) -> Self {
/// let ipv4 = RwPool::new(capacity.div_ceil(2));
/// ## Returns let ipv6 = RwPool::new(capacity.div_ceil(2));
/// let both = RwPool::new(capacity);
/// - an [`Endpoint`] configured to accept incoming QUIC connections both.enable();
#[allow(unused)] Self { ipv4, ipv6, both }
pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<Endpoint, Box<dyn Error>> { }
let server_config = server_config();
let client_config = client_config();
let endpoint_config = endpoint_config();
let socket2_socket = socket2::Socket::new( fn load(global_ctx: &ArcGlobalCtx) -> &Self {
socket2::Domain::for_address(bind_addr), let capacity = global_ctx
socket2::Type::DGRAM, .config
Some(socket2::Protocol::UDP), .get_flags()
)?; .multi_thread
setup_sokcet2(&socket2_socket, &bind_addr)?; .then(std::thread::available_parallelism)
let socket = std::net::UdpSocket::from(socket2_socket); .and_then(|r| r.ok())
.map(|n| n.get())
.unwrap_or(1);
let runtime = let mgr = QUIC_ENDPOINT_MANAGER.get();
quinn::default_runtime().ok_or_else(|| std::io::Error::other("no async runtime found"))?; match mgr {
let socket: NoGroAsyncUdpSocket = NoGroAsyncUdpSocket { Some(mgr) => {
inner: runtime.wrap_udp_socket(socket)?, for pool in [&mgr.ipv4, &mgr.ipv6, &mgr.both] {
}; pool.resize();
let mut endpoint = Endpoint::new_with_abstract_socket( }
endpoint_config, }
Some(server_config), None => {
Arc::new(socket), let _ = QUIC_ENDPOINT_MANAGER.set(Self::new(capacity));
runtime, }
)?; }
endpoint.set_default_client_config(client_config);
Ok(endpoint) QUIC_ENDPOINT_MANAGER.get().unwrap()
}
/// Get a QUIC endpoint to be used as a server
///
/// # Arguments
/// * `addr`: listen address
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> std::io::Result<Endpoint> {
let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| {
let dual_stack = addr.ip() == Ipv6Addr::UNSPECIFIED && mgr.both.is_enabled();
let pool = if addr.is_ipv4() {
&mgr.ipv4
} else if dual_stack {
&mgr.both
} else {
&mgr.ipv6
};
(pool, Some((addr, dual_stack)))
})?;
let endpoint = endpoint.expect("server endpoint creation should not return None");
endpoint.set_server_config(Some(server_config()));
pool.push(endpoint.clone());
Ok(endpoint)
}
/// Get a quic endpoint to be used as a client
///
/// # Arguments
/// * `ip_version`: the IP version of the remote address
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> std::io::Result<Endpoint> {
let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| {
let dual_stack = mgr.both.is_enabled();
let (pool, addr) = match ip_version {
IpVersion::V4 if !dual_stack => (&mgr.ipv4, (Ipv4Addr::UNSPECIFIED, 0).into()),
_ => {
let pool = if dual_stack { &mgr.both } else { &mgr.ipv6 };
(pool, (Ipv6Addr::UNSPECIFIED, 0).into())
}
};
if pool.is_full() {
(pool, None)
} else {
(pool, Some((addr, dual_stack)))
}
})?;
if let Some(endpoint) = endpoint {
pool.try_push(endpoint);
}
Ok(pool.with_iter(|iter| iter.min_by_key(|e| e.open_connections()).unwrap().clone()))
}
async fn connect(
global_ctx: &ArcGlobalCtx,
addr: SocketAddr,
) -> std::io::Result<(Endpoint, Connection)> {
let ip_version = if addr.ip().is_ipv4() {
IpVersion::V4
} else {
IpVersion::V6
};
let endpoint = Self::client(global_ctx, ip_version)?;
let connection = endpoint
.connect(addr, "localhost")
.map_err(std::io::Error::other)?
.await?;
Ok((endpoint, connection))
}
} }
//endregion
#[allow(unused)]
pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"];
struct ConnWrapper { struct ConnWrapper {
conn: Connection, conn: Connection,
@@ -143,13 +347,15 @@ impl Drop for ConnWrapper {
pub struct QuicTunnelListener { pub struct QuicTunnelListener {
addr: url::Url, addr: url::Url,
global_ctx: ArcGlobalCtx,
endpoint: Option<Endpoint>, endpoint: Option<Endpoint>,
} }
impl QuicTunnelListener { impl QuicTunnelListener {
pub fn new(addr: url::Url) -> Self { pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self {
QuicTunnelListener { QuicTunnelListener {
addr, addr,
global_ctx,
endpoint: None, endpoint: None,
} }
} }
@@ -192,13 +398,11 @@ impl QuicTunnelListener {
impl TunnelListener for QuicTunnelListener { impl TunnelListener for QuicTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?; let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let endpoint = make_server_endpoint(addr) let endpoint = QuicEndpointManager::server(&self.global_ctx, addr)?;
.map_err(|e| anyhow::anyhow!("make server endpoint error: {:?}", e))?;
self.endpoint = Some(endpoint);
self.addr self.addr
.set_port(Some(self.endpoint.as_ref().unwrap().local_addr()?.port())) .set_port(Some(endpoint.local_addr()?.port()))
.unwrap(); .unwrap();
self.endpoint = Some(endpoint);
Ok(()) Ok(())
} }
@@ -222,15 +426,15 @@ impl TunnelListener for QuicTunnelListener {
pub struct QuicTunnelConnector { pub struct QuicTunnelConnector {
addr: url::Url, addr: url::Url,
endpoint: Option<Endpoint>, global_ctx: ArcGlobalCtx,
ip_version: IpVersion, ip_version: IpVersion,
} }
impl QuicTunnelConnector { impl QuicTunnelConnector {
pub fn new(addr: url::Url) -> Self { pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self {
QuicTunnelConnector { QuicTunnelConnector {
addr, addr,
endpoint: None, global_ctx,
ip_version: IpVersion::Both, ip_version: IpVersion::Both,
} }
} }
@@ -240,38 +444,10 @@ impl QuicTunnelConnector {
impl TunnelConnector for QuicTunnelConnector { impl TunnelConnector for QuicTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> { async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?; let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if addr.port() == 0 { let (endpoint, connection) = QuicEndpointManager::connect(&self.global_ctx, addr).await?;
return Err(TunnelError::InvalidAddr(format!(
"invalid remote QUIC port 0 in url: {} (port 0 is not a valid QUIC port)",
self.addr
)));
}
let local_addr = if addr.is_ipv4() {
"0.0.0.0:0"
} else {
"[::]:0"
};
let mut endpoint = Endpoint::client(local_addr.parse().unwrap())?;
endpoint.set_default_client_config(client_config());
// connect to server
let connection = endpoint
.connect(addr, "localhost")
.map_err(|e| {
TunnelError::InvalidAddr(format!(
"failed to create QUIC connection, url: {}, error: {}",
self.addr, e
))
})?
.await
.with_context(|| "connect failed")?;
tracing::info!("[client] connected: addr={}", connection.remote_address());
let local_addr = endpoint.local_addr()?; let local_addr = endpoint.local_addr()?;
self.endpoint = Some(endpoint);
let (w, r) = connection let (w, r) = connection
.open_bi() .open_bi()
.await .await
@@ -308,68 +484,112 @@ impl TunnelConnector for QuicTunnelConnector {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::common::global_ctx::tests::get_mock_global_ctx_with_network;
use crate::tunnel::{ use crate::tunnel::{
IpVersion, TunnelConnector, TunnelConnector,
common::tests::{_tunnel_bench, _tunnel_pingpong}, common::tests::{_tunnel_bench, _tunnel_pingpong},
}; };
use std::sync::LazyLock;
use tokio::runtime::{Builder, Runtime};
use super::*; use super::*;
#[tokio::test] // Shared runtime for all tests to avoid endpoint invalidation across runtimes
async fn quic_pingpong() { static RUNTIME: LazyLock<Runtime> =
let listener = QuicTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap()); LazyLock::new(|| Builder::new_multi_thread().enable_all().build().unwrap());
let connector = QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap());
fn global_ctx() -> ArcGlobalCtx {
let identity = crate::common::config::NetworkIdentity::default();
get_mock_global_ctx_with_network(Some(identity))
}
#[test]
fn quic_pingpong() {
RUNTIME.block_on(quic_pingpong_impl())
}
async fn quic_pingpong_impl() {
let listener = QuicTunnelListener::new("quic://[::]:21011".parse().unwrap(), global_ctx());
let connector =
QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap(), global_ctx());
_tunnel_pingpong(listener, connector).await _tunnel_pingpong(listener, connector).await
} }
#[tokio::test] #[test]
async fn quic_bench() { fn quic_bench() {
let listener = QuicTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap()); RUNTIME.block_on(quic_bench_impl())
let connector = QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap()); }
async fn quic_bench_impl() {
let listener = QuicTunnelListener::new("quic://[::]:21012".parse().unwrap(), global_ctx());
let connector =
QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap(), global_ctx());
_tunnel_bench(listener, connector).await _tunnel_bench(listener, connector).await
} }
#[tokio::test] #[test]
async fn ipv6_pingpong() { fn ipv6_pingpong() {
let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap()); RUNTIME.block_on(ipv6_pingpong_impl())
let connector = QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap()); }
async fn ipv6_pingpong_impl() {
let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap(), global_ctx());
let connector =
QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap(), global_ctx());
_tunnel_pingpong(listener, connector).await _tunnel_pingpong(listener, connector).await
} }
#[tokio::test] #[test]
async fn ipv6_domain_pingpong() { fn ipv6_domain_pingpong() {
let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap()); RUNTIME.block_on(ipv6_domain_pingpong_impl())
let mut connector = }
QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap()); async fn ipv6_domain_pingpong_impl() {
let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap(), global_ctx());
let mut connector = QuicTunnelConnector::new(
"quic://test.easytier.top:31016".parse().unwrap(),
global_ctx(),
);
connector.set_ip_version(IpVersion::V6); connector.set_ip_version(IpVersion::V6);
_tunnel_pingpong(listener, connector).await; _tunnel_pingpong(listener, connector).await;
let listener = QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap()); let listener =
let mut connector = QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap(), global_ctx());
QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap()); let mut connector = QuicTunnelConnector::new(
"quic://test.easytier.top:31016".parse().unwrap(),
global_ctx(),
);
connector.set_ip_version(IpVersion::V4); connector.set_ip_version(IpVersion::V4);
_tunnel_pingpong(listener, connector).await; _tunnel_pingpong(listener, connector).await;
} }
#[tokio::test] #[test]
async fn test_alloc_port() { fn alloc_port() {
RUNTIME.block_on(alloc_port_impl())
}
async fn alloc_port_impl() {
// v4 // v4
let mut listener = QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap()); let mut listener =
QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap(), global_ctx());
listener.listen().await.unwrap(); listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap(); let port = listener.local_url().port().unwrap();
assert!(port > 0); assert!(port > 0);
// v6 // v6
let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap()); let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap(), global_ctx());
listener.listen().await.unwrap(); listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap(); let port = listener.local_url().port().unwrap();
assert!(port > 0); assert!(port > 0);
} }
#[tokio::test] #[test]
async fn quic_connector_reject_port_zero() { fn invalid_peer_addr() {
let mut connector = QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap()); RUNTIME.block_on(invalid_peer_addr_impl())
let err = connector.connect().await.unwrap_err().to_string(); }
assert!(err.contains("port 0"), "unexpected error: {}", err); async fn invalid_peer_addr_impl() {
let mut connector =
QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx());
let err = connector.connect().await.unwrap_err();
assert!(
err.to_string().contains("invalid remote address"),
"unexpected error: {:?}",
err
);
} }
} }
+3 -3
View File
@@ -1,7 +1,7 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use super::{FromUrl, TunnelInfo}; use super::{FromUrl, TunnelInfo};
use crate::tunnel::common::setup_sokcet2; use crate::tunnel::common::setup_socket2;
use async_trait::async_trait; use async_trait::async_trait;
use futures::stream::FuturesUnordered; use futures::stream::FuturesUnordered;
use tokio::net::{TcpListener, TcpSocket, TcpStream}; use tokio::net::{TcpListener, TcpSocket, TcpStream};
@@ -66,7 +66,7 @@ impl TunnelListener for TcpTunnelListener {
socket2::Type::STREAM, socket2::Type::STREAM,
Some(socket2::Protocol::TCP), Some(socket2::Protocol::TCP),
)?; )?;
setup_sokcet2(&socket2_socket, &addr)?; setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into()); let socket = TcpSocket::from_std_stream(socket2_socket.into());
if let Err(e) = socket.set_nodelay(true) { if let Err(e) = socket.set_nodelay(true) {
@@ -175,7 +175,7 @@ impl TcpTunnelConnector {
Some(socket2::Protocol::TCP), Some(socket2::Protocol::TCP),
)?; )?;
if let Err(e) = setup_sokcet2(&socket2_socket, bind_addr) { if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e); tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue; continue;
} }
+5 -5
View File
@@ -24,7 +24,7 @@ use tracing::{Instrument, instrument};
use super::{ use super::{
FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener, FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener,
TunnelUrl, TunnelUrl,
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures}, common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures},
packet_def::{UDP_TUNNEL_HEADER_SIZE, UDPTunnelHeader, V6HolePunchPacket}, packet_def::{UDP_TUNNEL_HEADER_SIZE, UDPTunnelHeader, V6HolePunchPacket},
ring::{RingSink, RingStream}, ring::{RingSink, RingStream},
}; };
@@ -545,9 +545,9 @@ impl TunnelListener for UdpTunnelListener {
let tunnel_url: TunnelUrl = self.addr.clone().into(); let tunnel_url: TunnelUrl = self.addr.clone().into();
if let Some(bind_dev) = tunnel_url.bind_dev() { if let Some(bind_dev) = tunnel_url.bind_dev() {
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?; setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?;
} else { } else {
setup_sokcet2(&socket2_socket, &addr)?; setup_socket2(&socket2_socket, &addr, true)?;
} }
self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
@@ -838,7 +838,7 @@ impl UdpTunnelConnector {
socket2::Type::DGRAM, socket2::Type::DGRAM,
Some(socket2::Protocol::UDP), Some(socket2::Protocol::UDP),
)?; )?;
if let Err(e) = setup_sokcet2(&socket2_socket, bind_addr) { if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e); tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue; continue;
} }
@@ -1040,7 +1040,7 @@ mod tests {
Some(socket2::Protocol::UDP), Some(socket2::Protocol::UDP),
) )
.unwrap(); .unwrap();
setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap(); setup_socket2_ext(&socket2_socket, &addr, bind_dev.clone(), true).unwrap();
} }
} }
+3 -3
View File
@@ -1,6 +1,6 @@
use super::{ use super::{
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener, FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
common::{TunnelWrapper, setup_sokcet2, wait_for_connect_futures}, common::{TunnelWrapper, setup_socket2, wait_for_connect_futures},
insecure_tls::{get_insecure_tls_cert, init_crypto_provider}, insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
packet_def::{ZCPacket, ZCPacketType}, packet_def::{ZCPacket, ZCPacketType},
}; };
@@ -166,7 +166,7 @@ impl TunnelListener for WsTunnelListener {
socket2::Type::STREAM, socket2::Type::STREAM,
Some(socket2::Protocol::TCP), Some(socket2::Protocol::TCP),
)?; )?;
setup_sokcet2(&socket2_socket, &addr)?; setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into()); let socket = TcpSocket::from_std_stream(socket2_socket.into());
self.addr self.addr
@@ -291,7 +291,7 @@ impl WsTunnelConnector {
Some(socket2::Protocol::TCP), Some(socket2::Protocol::TCP),
)?; )?;
if let Err(e) = setup_sokcet2(&socket2_socket, bind_addr) { if let Err(e) = setup_socket2(&socket2_socket, bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e); tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue; continue;
} }
+5 -5
View File
@@ -23,7 +23,7 @@ use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use super::{ use super::{
FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink, FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink,
ZCPacketStream, ZCPacketStream,
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures}, common::{setup_socket2, setup_socket2_ext, wait_for_connect_futures},
generate_digest_from_str, generate_digest_from_str,
packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacketType}, packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacketType},
ring::create_ring_tunnel_pair, ring::create_ring_tunnel_pair,
@@ -563,9 +563,9 @@ impl TunnelListener for WgTunnelListener {
let tunnel_url: TunnelUrl = self.addr.clone().into(); let tunnel_url: TunnelUrl = self.addr.clone().into();
if let Some(bind_dev) = tunnel_url.bind_dev() { if let Some(bind_dev) = tunnel_url.bind_dev() {
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?; setup_socket2_ext(&socket2_socket, &addr, Some(bind_dev), true)?;
} else { } else {
setup_sokcet2(&socket2_socket, &addr)?; setup_socket2(&socket2_socket, &addr, true)?;
} }
self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
@@ -700,7 +700,7 @@ impl WgTunnelConnector {
socket2::Type::DGRAM, socket2::Type::DGRAM,
Some(socket2::Protocol::UDP), Some(socket2::Protocol::UDP),
)?; )?;
setup_sokcet2_ext(&socket2_socket, &"[::]:0".parse().unwrap(), None)?; setup_socket2_ext(&socket2_socket, &"[::]:0".parse().unwrap(), None, true)?;
let socket = UdpSocket::from_std(socket2_socket.into())?; let socket = UdpSocket::from_std(socket2_socket.into())?;
Self::connect_with_socket(self.addr.clone(), self.config.clone(), socket, addr).await Self::connect_with_socket(self.addr.clone(), self.config.clone(), socket, addr).await
} }
@@ -728,7 +728,7 @@ impl super::TunnelConnector for WgTunnelConnector {
socket2::Type::DGRAM, socket2::Type::DGRAM,
Some(socket2::Protocol::UDP), Some(socket2::Protocol::UDP),
)?; )?;
if let Err(e) = setup_sokcet2(&socket2_socket, &bind_addr) { if let Err(e) = setup_socket2(&socket2_socket, &bind_addr, true) {
tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e); tracing::error!(bind_addr = ?bind_addr, ?addr, "bind addr fail: {:?}", e);
continue; continue;
} }