tunnel(bind): gather all bind logic to a single function (#2070)

* extract a Bindable trait for binding TcpSocket, TcpListener, and UdpSocket
This commit is contained in:
Luna Yao
2026-04-12 16:16:58 +02:00
committed by GitHub
parent 869e1b89f5
commit 6f3e708679
10 changed files with 370 additions and 5846 deletions
+138 -43
View File
@@ -1,3 +1,7 @@
use bon::builder;
use futures::{Future, Sink, Stream, stream::FuturesUnordered};
use network_interface::NetworkInterfaceConfig as _;
use pin_project_lite::pin_project;
use std::{
any::Any,
net::{IpAddr, SocketAddr},
@@ -5,26 +9,21 @@ use std::{
sync::{Arc, Mutex},
task::{Poll, ready},
};
use futures::{Future, Sink, Stream, stream::FuturesUnordered};
use network_interface::NetworkInterfaceConfig as _;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio_stream::StreamExt;
use tokio_util::io::poll_write_buf;
use zerocopy::FromBytes as _;
use super::TunnelInfo;
use crate::tunnel::packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacket};
use super::{
SinkItem, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
buf::BufList,
packet_def::{TCP_TUNNEL_HEADER_SIZE, TCPTunnelHeader, ZCPacketType},
};
use crate::common::netns::NetNS;
use crate::tunnel::packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacket};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio::net::{TcpListener, TcpSocket, UdpSocket};
use tokio_stream::StreamExt;
use tokio_util::io::poll_write_buf;
use zerocopy::FromBytes as _;
pub struct TunnelWrapper<R, W> {
reader: Arc<Mutex<Option<R>>>,
@@ -344,7 +343,70 @@ pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
None
}
pub(crate) fn setup_socket2_ext(
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
mut futures: FuturesUnordered<Fut>,
) -> Result<Ret, TunnelError>
where
Fut: Future<Output = Result<Ret, E>> + Send,
E: std::error::Error + Into<TunnelError> + Send + 'static,
{
// return last error
let mut last_err = None;
while let Some(ret) = futures.next().await {
if let Err(e) = ret {
last_err = Some(e.into());
} else {
return ret.map_err(|e| e.into());
}
}
Err(last_err.unwrap_or(TunnelError::Shutdown))
}
// region bind
pub trait Bindable: Sized {
const TYPE: socket2::Type;
const PROTOCOL: Option<socket2::Protocol>;
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError>;
}
impl Bindable for TcpSocket {
const TYPE: socket2::Type = socket2::Type::STREAM;
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::TCP);
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
let socket = TcpSocket::from_std_stream(socket.into());
if let Err(error) = socket.set_nodelay(true) {
tracing::warn!(?error, "set_nodelay failed for tcp socket");
}
Ok(socket)
}
}
impl Bindable for TcpListener {
const TYPE: socket2::Type = socket2::Type::STREAM;
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::TCP);
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
Ok(TcpSocket::finalize(socket)?.listen(1024)?)
}
}
impl Bindable for UdpSocket {
const TYPE: socket2::Type = socket2::Type::DGRAM;
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::UDP);
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
Ok(UdpSocket::from_std(socket.into())?)
}
}
fn setup_socket2_ext(
socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr,
#[allow(unused_variables)] bind_dev: Option<String>,
@@ -408,38 +470,69 @@ pub(crate) fn setup_socket2_ext(
Ok(())
}
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
mut futures: FuturesUnordered<Fut>,
) -> Result<Ret, TunnelError>
where
Fut: Future<Output = Result<Ret, E>> + Send,
E: std::error::Error + Into<TunnelError> + Send + 'static,
{
// return last error
let mut last_err = None;
while let Some(ret) = futures.next().await {
if let Err(e) = ret {
last_err = Some(e.into());
} else {
return ret.map_err(|e| e.into());
}
}
Err(last_err.unwrap_or(TunnelError::Shutdown))
#[derive(Debug, Default, Clone)]
pub enum BindDev {
#[default]
Auto,
Disabled,
Custom(String),
}
pub(crate) fn setup_socket2(
socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr,
only_v6: bool,
) -> Result<(), TunnelError> {
setup_socket2_ext(
socket2_socket,
bind_addr,
super::common::get_interface_name_by_ip(&bind_addr.ip()),
only_v6,
)
impl From<String> for BindDev {
fn from(value: String) -> Self {
if value.is_empty() {
Self::Disabled
} else {
Self::Custom(value)
}
}
}
impl From<&str> for BindDev {
fn from(value: &str) -> Self {
value.to_string().into()
}
}
/// Binds a socket to a specific address and optionally a network interface.
///
/// This function creates a new socket, applies specific configurations (such as
/// binding to a device or setting IPv6-only flags), and finalizes it into the
/// requested [`Bindable`] type.
///
/// # Arguments
///
/// * `addr` - The `SocketAddr` to bind the socket to.
/// * `dev` - The name of the network interface to bind to:
/// * **(default) `BindDev::Auto`**: Enables **auto-discovery**. The function will attempt to automatically
/// resolve the interface name associated with the provided `addr.ip()`.
/// * **empty string or `BindDev::Disabled`**: **Disables** auto-discovery and
/// explicitly chooses **not** to bind to any specific device. The routing will be
/// left entirely to the OS.
/// * **non-empty string or `BindDev::Custom(..)`**: Skips auto-discovery and explicitly binds to
/// the specified interface.
/// * `net_ns` - An optional network namespace to switch into before creating the socket.
/// * `only_v6` - If `true`, sets the `IPV6_V6ONLY` flag on the socket.
///
/// # Errors
///
/// Returns a [`TunnelError`] if socket creation, configuration, or finalization fails.
#[builder]
pub fn bind<B: Bindable>(
addr: SocketAddr,
#[builder(default, into)] dev: BindDev,
net_ns: Option<NetNS>,
#[builder(default)] only_v6: bool,
) -> Result<B, TunnelError> {
let _g = net_ns.map(|n| n.guard());
let dev = match dev {
BindDev::Auto => get_interface_name_by_ip(&addr.ip()),
BindDev::Disabled => None,
BindDev::Custom(s) => Some(s),
};
let socket = socket2::Socket::new(socket2::Domain::for_address(addr), B::TYPE, B::PROTOCOL)?;
setup_socket2_ext(&socket, &addr, dev, only_v6)?;
B::finalize(socket)
}
pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
@@ -448,6 +541,8 @@ pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
}
}
// endregion
pub mod tests {
use atomic_shim::AtomicU64;
use std::{sync::Arc, time::Instant};