tunnel(bind): gather all bind logic to a single function (#2070)

* extract a Bindable trait for binding TcpSocket, TcpListener, and UdpSocket
This commit is contained in:
Luna Yao
2026-04-12 16:16:58 +02:00
committed by GitHub
parent 869e1b89f5
commit 6f3e708679
10 changed files with 370 additions and 5846 deletions
+20 -47
View File
@@ -18,7 +18,7 @@ use crate::gateway::kcp_proxy::NatDstKcpConnector;
use crate::{
common::{
config::PortForwardConfig, global_ctx::GlobalCtxEvent, join_joinset_background,
netns::NetNS, scoped_task::ScopedTask,
scoped_task::ScopedTask,
},
gateway::{
fast_socks5::{
@@ -30,10 +30,7 @@ use crate::{
ip_reassembler::IpReassembler,
tokio_smoltcp::{BufferSize, Net, NetConfig, channel_device},
},
tunnel::{
common::setup_socket2,
packet_def::{PacketType, ZCPacket},
},
tunnel::packet_def::{PacketType, ZCPacket},
};
use anyhow::Context;
use dashmap::DashMap;
@@ -42,21 +39,21 @@ use pnet::packet::{
};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpSocket, UdpSocket},
net::{TcpListener, UdpSocket},
select,
sync::{Mutex, Notify, mpsc},
task::JoinSet,
time::timeout,
};
#[cfg(feature = "kcp")]
use super::tcp_proxy::NatDstConnector as _;
use crate::tunnel::common::bind;
use crate::{
common::{error::Error, global_ctx::GlobalCtx},
peers::{PeerPacketFilter, peer_manager::PeerManager},
};
#[cfg(feature = "kcp")]
use super::tcp_proxy::NatDstConnector as _;
enum SocksUdpSocket {
UdpSocket(Arc<tokio::net::UdpSocket>),
SmolUdpSocket(super::tokio_smoltcp::UdpSocket),
@@ -328,38 +325,6 @@ impl AsyncTcpConnector for Socks5AutoConnector {
}
}
fn bind_tcp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<TcpListener, Error> {
let _g = net_ns.guard();
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
setup_socket2(&socket2_socket, &addr, true)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into());
if let Err(e) = socket.set_nodelay(true) {
tracing::warn!(?e, "set_nodelay fail in listen");
}
Ok(socket.listen(1024)?)
}
fn bind_udp_socket(addr: SocketAddr, net_ns: NetNS) -> Result<UdpSocket, Error> {
let _g = net_ns.guard();
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_socket2(&socket2_socket, &addr, true)?;
Ok(UdpSocket::from_std(socket2_socket.into())?)
}
struct Socks5ServerNet {
ipv4_addr: cidr::Ipv4Inet,
auth: Option<SimpleUserPassword>,
@@ -702,10 +667,10 @@ impl Socks5Server {
proxy_url.port().unwrap()
);
let listener = bind_tcp_socket(
bind_addr.parse::<SocketAddr>().unwrap(),
self.global_ctx.net_ns.clone(),
)?;
let listener = bind::<TcpListener>()
.addr(bind_addr.parse::<SocketAddr>().unwrap())
.net_ns(self.global_ctx.net_ns.clone())
.call()?;
let entries = self.entries.clone();
let entry_count = self.entry_count.clone();
@@ -838,7 +803,10 @@ impl Socks5Server {
pub async fn add_tcp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
let listener = bind_tcp_socket(bind_addr, self.global_ctx.net_ns.clone())?;
let listener = bind::<TcpListener>()
.addr(bind_addr)
.net_ns(self.global_ctx.net_ns.clone())
.call()?;
let net = self.net.clone();
let entries = self.entries.clone();
@@ -906,7 +874,12 @@ impl Socks5Server {
#[tracing::instrument(name = "add_udp_port_forward", skip(self))]
pub async fn add_udp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
let socket = Arc::new(bind_udp_socket(bind_addr, self.global_ctx.net_ns.clone())?);
let socket = Arc::new(
bind::<UdpSocket>()
.addr(bind_addr)
.net_ns(self.global_ctx.net_ns.clone())
.call()?,
);
let entries = self.entries.clone();
let entry_count = self.entry_count.clone();