mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 02:09:06 +00:00
tunnel(bind): gather all bind logic to a single function (#2070)
* extract a Bindable trait for binding TcpSocket, TcpListener, and UdpSocket
This commit is contained in:
+138
-43
@@ -1,3 +1,7 @@
|
||||
use bon::builder;
|
||||
use futures::{Future, Sink, Stream, stream::FuturesUnordered};
|
||||
use network_interface::NetworkInterfaceConfig as _;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
any::Any,
|
||||
net::{IpAddr, SocketAddr},
|
||||
@@ -5,26 +9,21 @@ use std::{
|
||||
sync::{Arc, Mutex},
|
||||
task::{Poll, ready},
|
||||
};
|
||||
|
||||
use futures::{Future, Sink, Stream, stream::FuturesUnordered};
|
||||
use network_interface::NetworkInterfaceConfig as _;
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_util::io::poll_write_buf;
|
||||
use zerocopy::FromBytes as _;
|
||||
|
||||
use super::TunnelInfo;
|
||||
|
||||
use crate::tunnel::packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacket};
|
||||
|
||||
use super::{
|
||||
SinkItem, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
|
||||
buf::BufList,
|
||||
packet_def::{TCP_TUNNEL_HEADER_SIZE, TCPTunnelHeader, ZCPacketType},
|
||||
};
|
||||
use crate::common::netns::NetNS;
|
||||
use crate::tunnel::packet_def::{PEER_MANAGER_HEADER_SIZE, ZCPacket};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use tokio::net::{TcpListener, TcpSocket, UdpSocket};
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_util::io::poll_write_buf;
|
||||
use zerocopy::FromBytes as _;
|
||||
|
||||
pub struct TunnelWrapper<R, W> {
|
||||
reader: Arc<Mutex<Option<R>>>,
|
||||
@@ -344,7 +343,70 @@ pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn setup_socket2_ext(
|
||||
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
|
||||
mut futures: FuturesUnordered<Fut>,
|
||||
) -> Result<Ret, TunnelError>
|
||||
where
|
||||
Fut: Future<Output = Result<Ret, E>> + Send,
|
||||
E: std::error::Error + Into<TunnelError> + Send + 'static,
|
||||
{
|
||||
// return last error
|
||||
let mut last_err = None;
|
||||
|
||||
while let Some(ret) = futures.next().await {
|
||||
if let Err(e) = ret {
|
||||
last_err = Some(e.into());
|
||||
} else {
|
||||
return ret.map_err(|e| e.into());
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err.unwrap_or(TunnelError::Shutdown))
|
||||
}
|
||||
|
||||
// region bind
|
||||
|
||||
pub trait Bindable: Sized {
|
||||
const TYPE: socket2::Type;
|
||||
const PROTOCOL: Option<socket2::Protocol>;
|
||||
|
||||
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError>;
|
||||
}
|
||||
|
||||
impl Bindable for TcpSocket {
|
||||
const TYPE: socket2::Type = socket2::Type::STREAM;
|
||||
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::TCP);
|
||||
|
||||
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
|
||||
let socket = TcpSocket::from_std_stream(socket.into());
|
||||
|
||||
if let Err(error) = socket.set_nodelay(true) {
|
||||
tracing::warn!(?error, "set_nodelay failed for tcp socket");
|
||||
}
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bindable for TcpListener {
|
||||
const TYPE: socket2::Type = socket2::Type::STREAM;
|
||||
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::TCP);
|
||||
|
||||
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
|
||||
Ok(TcpSocket::finalize(socket)?.listen(1024)?)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bindable for UdpSocket {
|
||||
const TYPE: socket2::Type = socket2::Type::DGRAM;
|
||||
const PROTOCOL: Option<socket2::Protocol> = Some(socket2::Protocol::UDP);
|
||||
|
||||
fn finalize(socket: socket2::Socket) -> Result<Self, TunnelError> {
|
||||
Ok(UdpSocket::from_std(socket.into())?)
|
||||
}
|
||||
}
|
||||
|
||||
fn setup_socket2_ext(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
#[allow(unused_variables)] bind_dev: Option<String>,
|
||||
@@ -408,38 +470,69 @@ pub(crate) fn setup_socket2_ext(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
|
||||
mut futures: FuturesUnordered<Fut>,
|
||||
) -> Result<Ret, TunnelError>
|
||||
where
|
||||
Fut: Future<Output = Result<Ret, E>> + Send,
|
||||
E: std::error::Error + Into<TunnelError> + Send + 'static,
|
||||
{
|
||||
// return last error
|
||||
let mut last_err = None;
|
||||
|
||||
while let Some(ret) = futures.next().await {
|
||||
if let Err(e) = ret {
|
||||
last_err = Some(e.into());
|
||||
} else {
|
||||
return ret.map_err(|e| e.into());
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err.unwrap_or(TunnelError::Shutdown))
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub enum BindDev {
|
||||
#[default]
|
||||
Auto,
|
||||
Disabled,
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
pub(crate) fn setup_socket2(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
only_v6: bool,
|
||||
) -> Result<(), TunnelError> {
|
||||
setup_socket2_ext(
|
||||
socket2_socket,
|
||||
bind_addr,
|
||||
super::common::get_interface_name_by_ip(&bind_addr.ip()),
|
||||
only_v6,
|
||||
)
|
||||
impl From<String> for BindDev {
|
||||
fn from(value: String) -> Self {
|
||||
if value.is_empty() {
|
||||
Self::Disabled
|
||||
} else {
|
||||
Self::Custom(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for BindDev {
|
||||
fn from(value: &str) -> Self {
|
||||
value.to_string().into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Binds a socket to a specific address and optionally a network interface.
|
||||
///
|
||||
/// This function creates a new socket, applies specific configurations (such as
|
||||
/// binding to a device or setting IPv6-only flags), and finalizes it into the
|
||||
/// requested [`Bindable`] type.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `addr` - The `SocketAddr` to bind the socket to.
|
||||
/// * `dev` - The name of the network interface to bind to:
|
||||
/// * **(default) `BindDev::Auto`**: Enables **auto-discovery**. The function will attempt to automatically
|
||||
/// resolve the interface name associated with the provided `addr.ip()`.
|
||||
/// * **empty string or `BindDev::Disabled`**: **Disables** auto-discovery and
|
||||
/// explicitly chooses **not** to bind to any specific device. The routing will be
|
||||
/// left entirely to the OS.
|
||||
/// * **non-empty string or `BindDev::Custom(..)`**: Skips auto-discovery and explicitly binds to
|
||||
/// the specified interface.
|
||||
/// * `net_ns` - An optional network namespace to switch into before creating the socket.
|
||||
/// * `only_v6` - If `true`, sets the `IPV6_V6ONLY` flag on the socket.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns a [`TunnelError`] if socket creation, configuration, or finalization fails.
|
||||
#[builder]
|
||||
pub fn bind<B: Bindable>(
|
||||
addr: SocketAddr,
|
||||
#[builder(default, into)] dev: BindDev,
|
||||
net_ns: Option<NetNS>,
|
||||
#[builder(default)] only_v6: bool,
|
||||
) -> Result<B, TunnelError> {
|
||||
let _g = net_ns.map(|n| n.guard());
|
||||
let dev = match dev {
|
||||
BindDev::Auto => get_interface_name_by_ip(&addr.ip()),
|
||||
BindDev::Disabled => None,
|
||||
BindDev::Custom(s) => Some(s),
|
||||
};
|
||||
let socket = socket2::Socket::new(socket2::Domain::for_address(addr), B::TYPE, B::PROTOCOL)?;
|
||||
setup_socket2_ext(&socket, &addr, dev, only_v6)?;
|
||||
B::finalize(socket)
|
||||
}
|
||||
|
||||
pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
|
||||
@@ -448,6 +541,8 @@ pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
|
||||
}
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
pub mod tests {
|
||||
use atomic_shim::AtomicU64;
|
||||
use std::{sync::Arc, time::Instant};
|
||||
|
||||
Reference in New Issue
Block a user