use std::net::SocketAddr; use super::{FromUrl, TunnelInfo}; use crate::tunnel::common::bind; use async_trait::async_trait; use futures::stream::FuturesUnordered; use tokio::net::{TcpListener, TcpSocket, TcpStream}; use super::{ IpVersion, Tunnel, TunnelError, TunnelListener, common::{FramedReader, FramedWriter, TunnelWrapper, wait_for_connect_futures}, }; const TCP_MTU_BYTES: usize = 2000; #[derive(Debug)] pub struct TcpTunnelListener { addr: url::Url, listener: Option, } impl TcpTunnelListener { pub fn new(addr: url::Url) -> Self { TcpTunnelListener { addr, listener: None, } } async fn do_accept(&self) -> Result, std::io::Error> { let listener = self.listener.as_ref().unwrap(); let (stream, _) = listener.accept().await?; if let Err(e) = stream.set_nodelay(true) { tracing::warn!(?e, "set_nodelay fail in accept"); } let info = TunnelInfo { tunnel_type: "tcp".to_owned(), local_addr: Some(self.local_url().into()), remote_addr: Some( super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(), ), resolved_remote_addr: Some( super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(), ), }; let (r, w) = stream.into_split(); Ok(Box::new(TunnelWrapper::new( FramedReader::new(r, TCP_MTU_BYTES), FramedWriter::new(w), Some(info), ))) } } #[async_trait] impl TunnelListener for TcpTunnelListener { async fn listen(&mut self) -> Result<(), TunnelError> { self.listener = None; let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?; let listener = bind::().addr(addr).only_v6(true).call()?; self.addr .set_port(Some(listener.local_addr()?.port())) .unwrap(); self.listener = Some(listener); Ok(()) } async fn accept(&mut self) -> Result, super::TunnelError> { loop { match self.do_accept().await { Ok(ret) => return Ok(ret), Err(e) => { use std::io::ErrorKind::*; if matches!( e.kind(), NotConnected | ConnectionAborted | ConnectionRefused | ConnectionReset ) { tracing::warn!(?e, "accept fail with retryable error: {:?}", e); continue; } tracing::warn!(?e, "accept fail"); return Err(e.into()); } } } } fn local_url(&self) -> url::Url { self.addr.clone() } } fn get_tunnel_with_tcp_stream( stream: TcpStream, remote_url: url::Url, ) -> Result, super::TunnelError> { if let Err(e) = stream.set_nodelay(true) { tracing::warn!(?e, "set_nodelay fail in get_tunnel_with_tcp_stream"); } let info = TunnelInfo { tunnel_type: "tcp".to_owned(), local_addr: Some( super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp").into(), ), remote_addr: Some(remote_url.into()), resolved_remote_addr: Some( super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(), ), }; let (r, w) = stream.into_split(); Ok(Box::new(TunnelWrapper::new( FramedReader::new(r, TCP_MTU_BYTES), FramedWriter::new(w), Some(info), ))) } #[derive(Debug)] pub struct TcpTunnelConnector { addr: url::Url, bind_addrs: Vec, ip_version: IpVersion, resolved_addr: Option, } impl TcpTunnelConnector { pub fn new(addr: url::Url) -> Self { TcpTunnelConnector { addr, bind_addrs: vec![], ip_version: IpVersion::Both, resolved_addr: None, } } async fn connect_with_default_bind( &self, addr: SocketAddr, ) -> Result, super::TunnelError> { tracing::info!(url = ?self.addr, ?addr, "connect tcp start, bind addrs: {:?}", self.bind_addrs); let stream = TcpStream::connect(addr).await?; tracing::info!(url = ?self.addr, ?addr, "connect tcp succ"); get_tunnel_with_tcp_stream(stream, self.addr.clone()) } async fn connect_with_custom_bind( &self, addr: SocketAddr, ) -> Result, super::TunnelError> { let futures = FuturesUnordered::new(); for bind_addr in self.bind_addrs.iter() { tracing::info!(?bind_addr, ?addr, "bind addr"); match bind::().addr(*bind_addr).only_v6(true).call() { Ok(socket) => futures.push(socket.connect(addr)), Err(error) => { tracing::error!(?bind_addr, ?addr, ?error, "bind addr fail"); continue; } } } let ret = wait_for_connect_futures(futures).await; get_tunnel_with_tcp_stream(ret?, self.addr.clone()) } } #[async_trait] impl super::TunnelConnector for TcpTunnelConnector { async fn connect(&mut self) -> Result, TunnelError> { let addr = match self.resolved_addr { Some(addr) => addr, None => SocketAddr::from_url(self.addr.clone(), self.ip_version).await?, }; if self.bind_addrs.is_empty() { self.connect_with_default_bind(addr).await } else { self.connect_with_custom_bind(addr).await } } fn remote_url(&self) -> url::Url { self.addr.clone() } fn set_bind_addrs(&mut self, addrs: Vec) { self.bind_addrs = addrs; } fn set_ip_version(&mut self, ip_version: IpVersion) { self.ip_version = ip_version; } fn set_resolved_addr(&mut self, addr: SocketAddr) { self.resolved_addr = Some(addr); } } #[cfg(test)] mod tests { use crate::tunnel::{ TunnelConnector, common::tests::{_tunnel_bench, _tunnel_pingpong}, }; use super::*; #[tokio::test] async fn tcp_pingpong() { let listener = TcpTunnelListener::new("tcp://0.0.0.0:31011".parse().unwrap()); let connector = TcpTunnelConnector::new("tcp://127.0.0.1:31011".parse().unwrap()); _tunnel_pingpong(listener, connector).await } #[tokio::test] async fn tcp_bench() { let listener = TcpTunnelListener::new("tcp://0.0.0.0:31012".parse().unwrap()); let connector = TcpTunnelConnector::new("tcp://127.0.0.1:31012".parse().unwrap()); _tunnel_bench(listener, connector).await } #[tokio::test] async fn tcp_bench_with_bind() { let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap()); let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap()); connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]); _tunnel_pingpong(listener, connector).await } #[tokio::test] #[should_panic] async fn tcp_bench_with_bind_fail() { let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap()); let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap()); connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]); _tunnel_pingpong(listener, connector).await } #[tokio::test] async fn bind_same_port() { let mut listener = TcpTunnelListener::new("tcp://[::]:31014".parse().unwrap()); let mut listener2 = TcpTunnelListener::new("tcp://0.0.0.0:31014".parse().unwrap()); listener.listen().await.unwrap(); listener2.listen().await.unwrap(); } #[tokio::test] async fn ipv6_pingpong() { let listener = TcpTunnelListener::new("tcp://[::1]:31015".parse().unwrap()); let connector = TcpTunnelConnector::new("tcp://[::1]:31015".parse().unwrap()); _tunnel_pingpong(listener, connector).await } #[tokio::test] async fn ipv6_domain_pingpong() { let listener = TcpTunnelListener::new("tcp://[::1]:31015".parse().unwrap()); let mut connector = TcpTunnelConnector::new("tcp://test.easytier.top:31015".parse().unwrap()); connector.set_ip_version(IpVersion::V6); _tunnel_pingpong(listener, connector).await; let listener = TcpTunnelListener::new("tcp://127.0.0.1:31015".parse().unwrap()); let mut connector = TcpTunnelConnector::new("tcp://test.easytier.top:31015".parse().unwrap()); connector.set_ip_version(IpVersion::V4); _tunnel_pingpong(listener, connector).await; } #[tokio::test] async fn connector_keeps_source_addr_and_reports_resolved_addr() { let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:0".parse().unwrap()); listener.listen().await.unwrap(); let port = listener.local_url().port().unwrap(); let source_url: url::Url = format!("tcp://localhost:{port}").parse().unwrap(); let mut connector = TcpTunnelConnector::new(source_url.clone()); connector.set_ip_version(IpVersion::V4); let accept_task = tokio::spawn(async move { listener.accept().await.unwrap() }); let tunnel = connector.connect().await.unwrap(); let accepted_tunnel = accept_task.await.unwrap(); let info = tunnel.info().unwrap(); assert_eq!(info.remote_addr.unwrap().url, source_url.to_string()); let resolved_remote_addr: url::Url = info.resolved_remote_addr.unwrap().into(); assert_eq!(resolved_remote_addr.host_str(), Some("127.0.0.1")); assert_eq!(resolved_remote_addr.port(), Some(port)); let accepted_info = accepted_tunnel.info().unwrap(); assert_eq!( accepted_info.remote_addr, accepted_info.resolved_remote_addr, ); } #[tokio::test] async fn connector_uses_pre_resolved_addr_without_resolving_url() { let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:0".parse().unwrap()); listener.listen().await.unwrap(); let port = listener.local_url().port().unwrap(); let source_url: url::Url = format!("tcp://unresolvable.invalid:{port}") .parse() .unwrap(); let resolved_addr: SocketAddr = format!("127.0.0.1:{port}").parse().unwrap(); let mut connector = TcpTunnelConnector::new(source_url.clone()); connector.set_resolved_addr(resolved_addr); let accept_task = tokio::spawn(async move { listener.accept().await.unwrap() }); let tunnel = connector.connect().await.unwrap(); let _accepted_tunnel = accept_task.await.unwrap(); let info = tunnel.info().unwrap(); assert_eq!(info.remote_addr.unwrap().url, source_url.to_string()); let resolved_remote_addr: url::Url = info.resolved_remote_addr.unwrap().into(); assert_eq!(resolved_remote_addr.host_str(), Some("127.0.0.1")); assert_eq!(resolved_remote_addr.port(), Some(port)); } #[tokio::test] async fn test_alloc_port() { // v4 let mut listener = TcpTunnelListener::new("tcp://0.0.0.0:0".parse().unwrap()); listener.listen().await.unwrap(); let port = listener.local_url().port().unwrap(); assert!(port > 0); // v6 let mut listener = TcpTunnelListener::new("tcp://[::]:0".parse().unwrap()); listener.listen().await.unwrap(); let port = listener.local_url().port().unwrap(); assert!(port > 0); } }