tunnel(bind): gather all bind logic to a single function (#2070)

* extract a Bindable trait for binding TcpSocket, TcpListener, and UdpSocket
This commit is contained in:
Luna Yao
2026-04-12 16:16:58 +02:00
committed by GitHub
parent 869e1b89f5
commit 6f3e708679
10 changed files with 370 additions and 5846 deletions
+27 -23
View File
@@ -4,9 +4,10 @@
use super::{FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
use crate::common::global_ctx::ArcGlobalCtx;
use crate::tunnel::common::bind;
use crate::tunnel::{
TunnelInfo,
common::{FramedReader, FramedWriter, TunnelWrapper, setup_socket2},
common::{FramedReader, FramedWriter, TunnelWrapper},
};
use anyhow::Context;
use derivative::Derivative;
@@ -19,6 +20,7 @@ use quinn::{
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::OnceLock;
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::net::UdpSocket;
// region config
pub fn transport_config() -> Arc<TransportConfig> {
@@ -179,27 +181,28 @@ pub struct QuicEndpointManager {
static QUIC_ENDPOINT_MANAGER: OnceLock<QuicEndpointManager> = OnceLock::new();
impl QuicEndpointManager {
fn try_create(addr: SocketAddr, dual_stack: bool) -> std::io::Result<Endpoint> {
let socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_socket2(&socket, &addr, addr.is_ipv6() && !dual_stack)
.map_err(std::io::Error::other)?;
let socket = std::net::UdpSocket::from(socket);
let runtime = default_runtime().ok_or(std::io::Error::other("no async runtime found"))?;
fn try_create(addr: SocketAddr, dual_stack: bool) -> Result<Endpoint, TunnelError> {
let socket = bind::<UdpSocket>()
.addr(addr)
.only_v6(addr.is_ipv6() && !dual_stack)
.call()?;
let runtime = default_runtime().ok_or(TunnelError::InternalError(
"no async runtime found".to_owned(),
))?;
let mut endpoint = Endpoint::new_with_abstract_socket(
endpoint_config(),
None,
runtime.wrap_udp_socket(socket)?,
runtime.wrap_udp_socket(socket.into_std()?)?,
runtime,
)?;
endpoint.set_default_client_config(client_config());
Ok(endpoint)
}
fn create<F>(&self, mut selector: F) -> std::io::Result<(&RwPool<Endpoint>, Option<Endpoint>)>
fn create<F>(
&self,
mut selector: F,
) -> Result<(&RwPool<Endpoint>, Option<Endpoint>), TunnelError>
where
F: FnMut(&QuicEndpointManager) -> (&RwPool<Endpoint>, Option<(SocketAddr, bool)>),
{
@@ -210,10 +213,10 @@ impl QuicEndpointManager {
};
let endpoint = Self::try_create(addr, dual_stack);
if let Err(e) = endpoint.as_ref()
if let Err(error) = endpoint.as_ref()
&& dual_stack
{
tracing::warn!("create dual stack quic endpoint failed: {:?}", e);
tracing::warn!(?error, "create dual stack quic endpoint failed");
self.both.disable();
self.ipv4.enable();
self.ipv6.enable();
@@ -263,7 +266,7 @@ impl QuicEndpointManager {
///
/// # Arguments
/// * `addr`: listen address
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> std::io::Result<Endpoint> {
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> Result<Endpoint, TunnelError> {
let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| {
@@ -289,7 +292,7 @@ impl QuicEndpointManager {
///
/// # Arguments
/// * `ip_version`: the IP version of the remote address
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> std::io::Result<Endpoint> {
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> Result<Endpoint, TunnelError> {
let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| {
@@ -318,7 +321,7 @@ impl QuicEndpointManager {
async fn connect(
global_ctx: &ArcGlobalCtx,
addr: SocketAddr,
) -> std::io::Result<(Endpoint, Connection)> {
) -> Result<(Endpoint, Connection), TunnelError> {
let ip_version = if addr.ip().is_ipv4() {
IpVersion::V4
} else {
@@ -327,8 +330,9 @@ impl QuicEndpointManager {
let endpoint = Self::client(global_ctx, ip_version)?;
let connection = endpoint
.connect(addr, "localhost")
.map_err(std::io::Error::other)?
.await?;
.with_context(|| format!("failed to create connection to {}", addr))?
.await
.with_context(|| format!("failed to connect to {}", addr))?;
Ok((endpoint, connection))
}
@@ -585,10 +589,10 @@ mod tests {
async fn invalid_peer_addr_impl() {
let mut connector =
QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx());
let err = connector.connect().await.unwrap_err();
let err = format!("{:?}", connector.connect().await.unwrap_err());
assert!(
err.to_string().contains("invalid remote address"),
"unexpected error: {:?}",
err.contains("invalid remote address"),
"unexpected error: {}",
err
);
}