Files
Easytier/easytier/src/tunnel/tcp.rs
T
Luna Yao 6f3e708679 tunnel(bind): gather all bind logic to a single function (#2070)
* extract a Bindable trait for binding TcpSocket, TcpListener, and UdpSocket
2026-04-12 22:16:58 +08:00

312 lines
10 KiB
Rust

use std::net::SocketAddr;
use super::{FromUrl, TunnelInfo};
use crate::tunnel::common::bind;
use async_trait::async_trait;
use futures::stream::FuturesUnordered;
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use super::{
IpVersion, Tunnel, TunnelError, TunnelListener,
common::{FramedReader, FramedWriter, TunnelWrapper, wait_for_connect_futures},
};
const TCP_MTU_BYTES: usize = 2000;
#[derive(Debug)]
pub struct TcpTunnelListener {
addr: url::Url,
listener: Option<TcpListener>,
}
impl TcpTunnelListener {
pub fn new(addr: url::Url) -> Self {
TcpTunnelListener {
addr,
listener: None,
}
}
async fn do_accept(&self) -> Result<Box<dyn Tunnel>, std::io::Error> {
let listener = self.listener.as_ref().unwrap();
let (stream, _) = listener.accept().await?;
if let Err(e) = stream.set_nodelay(true) {
tracing::warn!(?e, "set_nodelay fail in accept");
}
let info = TunnelInfo {
tunnel_type: "tcp".to_owned(),
local_addr: Some(self.local_url().into()),
remote_addr: Some(
super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(),
),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(),
),
};
let (r, w) = stream.into_split();
Ok(Box::new(TunnelWrapper::new(
FramedReader::new(r, TCP_MTU_BYTES),
FramedWriter::new(w),
Some(info),
)))
}
}
#[async_trait]
impl TunnelListener for TcpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
self.listener = None;
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let listener = bind::<TcpListener>().addr(addr).only_v6(true).call()?;
self.addr
.set_port(Some(listener.local_addr()?.port()))
.unwrap();
self.listener = Some(listener);
Ok(())
}
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
loop {
match self.do_accept().await {
Ok(ret) => return Ok(ret),
Err(e) => {
use std::io::ErrorKind::*;
if matches!(
e.kind(),
NotConnected | ConnectionAborted | ConnectionRefused | ConnectionReset
) {
tracing::warn!(?e, "accept fail with retryable error: {:?}", e);
continue;
}
tracing::warn!(?e, "accept fail");
return Err(e.into());
}
}
}
}
fn local_url(&self) -> url::Url {
self.addr.clone()
}
}
fn get_tunnel_with_tcp_stream(
stream: TcpStream,
remote_url: url::Url,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
if let Err(e) = stream.set_nodelay(true) {
tracing::warn!(?e, "set_nodelay fail in get_tunnel_with_tcp_stream");
}
let info = TunnelInfo {
tunnel_type: "tcp".to_owned(),
local_addr: Some(
super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp").into(),
),
remote_addr: Some(remote_url.into()),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(),
),
};
let (r, w) = stream.into_split();
Ok(Box::new(TunnelWrapper::new(
FramedReader::new(r, TCP_MTU_BYTES),
FramedWriter::new(w),
Some(info),
)))
}
#[derive(Debug)]
pub struct TcpTunnelConnector {
addr: url::Url,
bind_addrs: Vec<SocketAddr>,
ip_version: IpVersion,
}
impl TcpTunnelConnector {
pub fn new(addr: url::Url) -> Self {
TcpTunnelConnector {
addr,
bind_addrs: vec![],
ip_version: IpVersion::Both,
}
}
async fn connect_with_default_bind(
&self,
addr: SocketAddr,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
tracing::info!(url = ?self.addr, ?addr, "connect tcp start, bind addrs: {:?}", self.bind_addrs);
let stream = TcpStream::connect(addr).await?;
tracing::info!(url = ?self.addr, ?addr, "connect tcp succ");
get_tunnel_with_tcp_stream(stream, self.addr.clone())
}
async fn connect_with_custom_bind(
&self,
addr: SocketAddr,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let futures = FuturesUnordered::new();
for bind_addr in self.bind_addrs.iter() {
tracing::info!(?bind_addr, ?addr, "bind addr");
match bind::<TcpSocket>().addr(*bind_addr).only_v6(true).call() {
Ok(socket) => futures.push(socket.connect(addr)),
Err(error) => {
tracing::error!(?bind_addr, ?addr, ?error, "bind addr fail");
continue;
}
}
}
let ret = wait_for_connect_futures(futures).await;
get_tunnel_with_tcp_stream(ret?, self.addr.clone())
}
}
#[async_trait]
impl super::TunnelConnector for TcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if self.bind_addrs.is_empty() {
self.connect_with_default_bind(addr).await
} else {
self.connect_with_custom_bind(addr).await
}
}
fn remote_url(&self) -> url::Url {
self.addr.clone()
}
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
self.bind_addrs = addrs;
}
fn set_ip_version(&mut self, ip_version: IpVersion) {
self.ip_version = ip_version;
}
}
#[cfg(test)]
mod tests {
use crate::tunnel::{
TunnelConnector,
common::tests::{_tunnel_bench, _tunnel_pingpong},
};
use super::*;
#[tokio::test]
async fn tcp_pingpong() {
let listener = TcpTunnelListener::new("tcp://0.0.0.0:31011".parse().unwrap());
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:31011".parse().unwrap());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn tcp_bench() {
let listener = TcpTunnelListener::new("tcp://0.0.0.0:31012".parse().unwrap());
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:31012".parse().unwrap());
_tunnel_bench(listener, connector).await
}
#[tokio::test]
async fn tcp_bench_with_bind() {
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap());
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap());
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
#[should_panic]
async fn tcp_bench_with_bind_fail() {
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn bind_same_port() {
let mut listener = TcpTunnelListener::new("tcp://[::]:31014".parse().unwrap());
let mut listener2 = TcpTunnelListener::new("tcp://0.0.0.0:31014".parse().unwrap());
listener.listen().await.unwrap();
listener2.listen().await.unwrap();
}
#[tokio::test]
async fn ipv6_pingpong() {
let listener = TcpTunnelListener::new("tcp://[::1]:31015".parse().unwrap());
let connector = TcpTunnelConnector::new("tcp://[::1]:31015".parse().unwrap());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn ipv6_domain_pingpong() {
let listener = TcpTunnelListener::new("tcp://[::1]:31015".parse().unwrap());
let mut connector =
TcpTunnelConnector::new("tcp://test.easytier.top:31015".parse().unwrap());
connector.set_ip_version(IpVersion::V6);
_tunnel_pingpong(listener, connector).await;
let listener = TcpTunnelListener::new("tcp://127.0.0.1:31015".parse().unwrap());
let mut connector =
TcpTunnelConnector::new("tcp://test.easytier.top:31015".parse().unwrap());
connector.set_ip_version(IpVersion::V4);
_tunnel_pingpong(listener, connector).await;
}
#[tokio::test]
async fn connector_keeps_source_addr_and_reports_resolved_addr() {
let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:0".parse().unwrap());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
let source_url: url::Url = format!("tcp://localhost:{port}").parse().unwrap();
let mut connector = TcpTunnelConnector::new(source_url.clone());
connector.set_ip_version(IpVersion::V4);
let accept_task = tokio::spawn(async move { listener.accept().await.unwrap() });
let tunnel = connector.connect().await.unwrap();
let accepted_tunnel = accept_task.await.unwrap();
let info = tunnel.info().unwrap();
assert_eq!(info.remote_addr.unwrap().url, source_url.to_string());
let resolved_remote_addr: url::Url = info.resolved_remote_addr.unwrap().into();
assert_eq!(resolved_remote_addr.host_str(), Some("127.0.0.1"));
assert_eq!(resolved_remote_addr.port(), Some(port));
let accepted_info = accepted_tunnel.info().unwrap();
assert_eq!(
accepted_info.remote_addr,
accepted_info.resolved_remote_addr,
);
}
#[tokio::test]
async fn test_alloc_port() {
// v4
let mut listener = TcpTunnelListener::new("tcp://0.0.0.0:0".parse().unwrap());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
assert!(port > 0);
// v6
let mut listener = TcpTunnelListener::new("tcp://[::]:0".parse().unwrap());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
assert!(port > 0);
}
}