refactor: listener/connector protocol abstraction (#2026)

* fix listener protocol detection
* replace IpProtocol with IpNextHeaderProtocol
* use an enum to gather all listener schemes
* rename ListenerScheme to TunnelScheme; replace IpNextHeaderProtocols with socket2::Protocol
* move TunnelScheme to tunnel
* add IpScheme, simplify connector creation
* format; fix some typos; remove check_scheme_...;
* remove PROTO_PORT_OFFSET
* rename WSTunnel.. -> WsTunnel.., DNSTunnel.. -> DnsTunnel..
This commit is contained in:
Luna Yao
2026-04-04 04:55:58 +02:00
committed by GitHub
parent 9cc617ae4c
commit e91a0da70a
18 changed files with 481 additions and 526 deletions
+21 -30
View File
@@ -2,20 +2,28 @@ mod netfilter;
mod packet;
mod stack;
use std::net::{IpAddr, Ipv4Addr, UdpSocket};
use std::sync::Arc;
use std::{net::SocketAddr, pin::Pin};
use bytes::BytesMut;
use futures::{Sink, Stream};
use network_interface::NetworkInterfaceConfig;
use pnet::util::MacAddr;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
pin::Pin,
sync::Arc,
task::{Context as TaskContext, Poll},
};
use tokio::{io::AsyncReadExt, net::TcpStream, sync::Mutex};
use crate::common::scoped_task::ScopedTask;
use crate::tunnel::fake_tcp::netfilter::create_tun;
use crate::tunnel::{common::TunnelWrapper, Tunnel, TunnelError, TunnelInfo, TunnelListener};
use crate::{
common::scoped_task::ScopedTask,
tunnel::{
common::TunnelWrapper,
fake_tcp::netfilter::create_tun,
packet_def::{ZCPacket, ZCPacketType, PEER_MANAGER_HEADER_SIZE, TCP_TUNNEL_HEADER_SIZE},
FromUrl, IpVersion, SinkError, SinkItem, StreamItem, Tunnel, TunnelConnector, TunnelError,
TunnelInfo, TunnelListener,
},
};
use futures::Future;
@@ -207,12 +215,7 @@ struct AcceptResult {
impl TunnelListener for FakeTcpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let port = self.addr.port().unwrap_or(0);
let bind_addr = crate::tunnel::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"faketcp",
crate::tunnel::IpVersion::Both,
)
.await?;
let bind_addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let os_listener = tokio::net::TcpListener::bind(bind_addr).await?;
tracing::info!(port, "FakeTcpTunnelListener listening");
self.os_listener = Some(os_listener);
@@ -306,14 +309,9 @@ fn get_local_ip_for_destination(destination: IpAddr) -> Option<IpAddr> {
}
#[async_trait::async_trait]
impl crate::tunnel::TunnelConnector for FakeTcpTunnelConnector {
impl TunnelConnector for FakeTcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let remote_addr = crate::tunnel::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"faketcp",
crate::tunnel::IpVersion::Both,
)
.await?;
let remote_addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let local_ip = get_local_ip_for_destination(remote_addr.ip())
.ok_or(TunnelError::InternalError("Failed to get local ip".into()))?;
@@ -387,13 +385,6 @@ impl crate::tunnel::TunnelConnector for FakeTcpTunnelConnector {
}
}
use crate::tunnel::packet_def::{
ZCPacket, ZCPacketType, PEER_MANAGER_HEADER_SIZE, TCP_TUNNEL_HEADER_SIZE,
};
use crate::tunnel::{SinkError, SinkItem, StreamItem};
use futures::{Sink, Stream};
use std::task::{Context as TaskContext, Poll};
type RecvFut = Pin<Box<dyn Future<Output = Option<(BytesMut, usize)>> + Send + Sync>>;
enum FakeTcpStreamState {
+140 -52
View File
@@ -1,16 +1,19 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;
use std::{net::SocketAddr, pin::Pin, sync::Arc};
use std::{
collections::hash_map::DefaultHasher, hash::Hasher, net::SocketAddr, pin::Pin, sync::Arc,
};
use crate::{
common::{dns::socket_addrs, error::Error},
proto::common::TunnelInfo,
};
use async_trait::async_trait;
use derive_more::{From, TryInto};
use futures::{Sink, Stream};
use socket2::Protocol;
use std::fmt::Debug;
use strum::{Display, EnumString, VariantArray};
use tokio::time::error::Elapsed;
use crate::common::dns::socket_addrs;
use crate::proto::common::TunnelInfo;
use self::packet_def::ZCPacket;
pub mod buf;
@@ -23,15 +26,6 @@ pub mod stats;
pub mod tcp;
pub mod udp;
pub const PROTO_PORT_OFFSET: &[(&str, u16)] = &[
("tcp", 0),
("udp", 0),
("wg", 1),
("ws", 1),
("wss", 2),
("faketcp", 3),
];
#[cfg(feature = "faketcp")]
pub mod fake_tcp;
@@ -193,45 +187,23 @@ pub(crate) trait FromUrl {
Self: Sized;
}
pub(crate) async fn check_scheme_and_get_socket_addr<T>(
url: &url::Url,
scheme: &str,
ip_version: IpVersion,
) -> Result<T, TunnelError>
where
T: FromUrl,
{
if url.scheme() != scheme {
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
}
T::from_url(url.clone(), ip_version).await
}
fn default_port(scheme: &str) -> Option<u16> {
match scheme {
"tcp" => Some(11010),
"udp" => Some(11010),
"ws" => Some(80),
"wss" => Some(443),
"faketcp" => Some(11013),
"quic" => Some(11012),
"wg" => Some(11011),
_ => None,
}
}
#[async_trait::async_trait]
impl FromUrl for SocketAddr {
async fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError> {
let addrs = socket_addrs(&url, || default_port(url.scheme()))
.await
.map_err(|e| {
TunnelError::InvalidAddr(format!(
"failed to resolve socket addr, url: {}, error: {}",
url, e
))
})?;
let addrs = socket_addrs(&url, || {
(&url)
.try_into()
.ok()
.and_then(|s: TunnelScheme| s.try_into().ok())
.map(IpScheme::default_port)
})
.await
.map_err(|e| {
TunnelError::InvalidAddr(format!(
"failed to resolve socket addr, url: {}, error: {}",
url, e
))
})?;
tracing::debug!(?addrs, ?ip_version, ?url, "convert url to socket addrs");
let addrs = addrs
.into_iter()
@@ -305,3 +277,119 @@ pub fn generate_digest_from_str(str1: &str, str2: &str, digest: &mut [u8]) {
hasher.write(&digest[..(i + 1) * 8]);
}
}
#[derive(Debug, Clone, Copy)]
struct IpSchemeAttributes {
protocol: Protocol,
port_offset: u16,
}
#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, VariantArray)]
#[strum(serialize_all = "lowercase")]
pub enum IpScheme {
Tcp,
Udp,
#[cfg(feature = "wireguard")]
Wg,
#[cfg(feature = "quic")]
Quic,
#[cfg(feature = "websocket")]
Ws,
#[cfg(feature = "websocket")]
Wss,
#[cfg(feature = "faketcp")]
FakeTcp,
}
impl IpScheme {
const fn attributes(self) -> IpSchemeAttributes {
let (protocol, port_offset) = match self {
Self::Tcp => (Protocol::TCP, 0),
Self::Udp => (Protocol::UDP, 0),
#[cfg(feature = "wireguard")]
Self::Wg => (Protocol::UDP, 1),
#[cfg(feature = "quic")]
Self::Quic => (Protocol::UDP, 2),
#[cfg(feature = "websocket")]
Self::Ws => (Protocol::TCP, 1),
#[cfg(feature = "websocket")]
Self::Wss => (Protocol::TCP, 2),
#[cfg(feature = "faketcp")]
Self::FakeTcp => (Protocol::TCP, 3),
};
IpSchemeAttributes {
protocol,
port_offset,
}
}
pub const fn protocol(self) -> Protocol {
self.attributes().protocol
}
pub const fn port_offset(self) -> u16 {
self.attributes().port_offset
}
pub const fn default_port(self) -> u16 {
match self {
#[cfg(feature = "websocket")]
Self::Ws => 80,
#[cfg(feature = "websocket")]
Self::Wss => 443,
_ => 11010 + self.port_offset(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, EnumString, From, TryInto)]
#[strum(serialize_all = "lowercase")]
pub enum TunnelScheme {
#[strum(disabled)]
Ip(IpScheme),
#[cfg(unix)]
Unix,
// Only for connector
Http,
Https,
Ring,
Txt,
Srv,
}
impl TryFrom<&url::Url> for TunnelScheme {
type Error = Error;
fn try_from(value: &url::Url) -> Result<Self, Self::Error> {
let scheme = value.scheme();
scheme.parse().or_else(|_| {
Ok(TunnelScheme::Ip(
scheme
.parse()
.map_err(|_| Error::InvalidUrl(value.to_string()))?,
))
})
}
}
macro_rules! __matches_scheme__ {
($url:expr, $( $pattern:pat_param )|+ ) => {
matches!($crate::tunnel::TunnelScheme::try_from(($url).as_ref()), Ok($( $pattern )|+))
};
}
pub(crate) use __matches_scheme__ as matches_scheme;
pub fn get_protocol_by_url(l: &url::Url) -> Result<Protocol, Error> {
let TunnelScheme::Ip(scheme) = l.try_into()? else {
return Err(Error::InvalidUrl(l.to_string()));
};
Ok(scheme.protocol())
}
macro_rules! __matches_protocol__ {
($url:expr, $( $pattern:pat_param )|+ ) => {
matches!($crate::tunnel::get_protocol_by_url($url), Ok($( $pattern )|+))
};
}
pub(crate) use __matches_protocol__ as matches_protocol;
+26 -34
View File
@@ -8,20 +8,16 @@ use std::{
use crate::tunnel::{
common::{setup_sokcet2, FramedReader, FramedWriter, TunnelWrapper},
TunnelInfo,
FromUrl, TunnelInfo,
};
use anyhow::Context;
use super::{IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
use quinn::{
congestion::BbrConfig, udp::RecvMeta, AsyncUdpSocket, ClientConfig, Connection, Endpoint,
EndpointConfig, ServerConfig, TransportConfig, UdpPoller,
};
use super::{
check_scheme_and_get_socket_addr, IpVersion, Tunnel, TunnelConnector, TunnelError,
TunnelListener,
};
pub fn transport_config() -> Arc<TransportConfig> {
let mut config = TransportConfig::default();
@@ -145,14 +141,14 @@ impl Drop for ConnWrapper {
}
}
pub struct QUICTunnelListener {
pub struct QuicTunnelListener {
addr: url::Url,
endpoint: Option<Endpoint>,
}
impl QUICTunnelListener {
impl QuicTunnelListener {
pub fn new(addr: url::Url) -> Self {
QUICTunnelListener {
QuicTunnelListener {
addr,
endpoint: None,
}
@@ -190,11 +186,9 @@ impl QUICTunnelListener {
}
#[async_trait::async_trait]
impl TunnelListener for QUICTunnelListener {
impl TunnelListener for QuicTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic", IpVersion::Both)
.await?;
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let endpoint = make_server_endpoint(addr)
.map_err(|e| anyhow::anyhow!("make server endpoint error: {:?}", e))?;
self.endpoint = Some(endpoint);
@@ -223,15 +217,15 @@ impl TunnelListener for QUICTunnelListener {
}
}
pub struct QUICTunnelConnector {
pub struct QuicTunnelConnector {
addr: url::Url,
endpoint: Option<Endpoint>,
ip_version: IpVersion,
}
impl QUICTunnelConnector {
impl QuicTunnelConnector {
pub fn new(addr: url::Url) -> Self {
QUICTunnelConnector {
QuicTunnelConnector {
addr,
endpoint: None,
ip_version: IpVersion::Both,
@@ -240,11 +234,9 @@ impl QUICTunnelConnector {
}
#[async_trait::async_trait]
impl TunnelConnector for QUICTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic", self.ip_version)
.await?;
impl TunnelConnector for QuicTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if addr.port() == 0 {
return Err(TunnelError::InvalidAddr(format!(
"invalid remote QUIC port 0 in url: {} (port 0 is not a valid QUIC port)",
@@ -318,36 +310,36 @@ mod tests {
#[tokio::test]
async fn quic_pingpong() {
let listener = QUICTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap());
let connector = QUICTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap());
let listener = QuicTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap());
let connector = QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn quic_bench() {
let listener = QUICTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap());
let connector = QUICTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap());
let listener = QuicTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap());
let connector = QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap());
_tunnel_bench(listener, connector).await
}
#[tokio::test]
async fn ipv6_pingpong() {
let listener = QUICTunnelListener::new("quic://[::1]:31015".parse().unwrap());
let connector = QUICTunnelConnector::new("quic://[::1]:31015".parse().unwrap());
let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap());
let connector = QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn ipv6_domain_pingpong() {
let listener = QUICTunnelListener::new("quic://[::1]:31016".parse().unwrap());
let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap());
let mut connector =
QUICTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
connector.set_ip_version(IpVersion::V6);
_tunnel_pingpong(listener, connector).await;
let listener = QUICTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap());
let listener = QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap());
let mut connector =
QUICTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
connector.set_ip_version(IpVersion::V4);
_tunnel_pingpong(listener, connector).await;
}
@@ -355,13 +347,13 @@ mod tests {
#[tokio::test]
async fn test_alloc_port() {
// v4
let mut listener = QUICTunnelListener::new("quic://0.0.0.0:0".parse().unwrap());
let mut listener = QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
assert!(port > 0);
// v6
let mut listener = QUICTunnelListener::new("quic://[::]:0".parse().unwrap());
let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
assert!(port > 0);
@@ -369,7 +361,7 @@ mod tests {
#[tokio::test]
async fn quic_connector_reject_port_zero() {
let mut connector = QUICTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap());
let mut connector = QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap());
let err = connector.connect().await.unwrap_err().to_string();
assert!(err.contains("port 0"), "unexpected error: {}", err);
}
+17 -28
View File
@@ -1,3 +1,5 @@
use async_ringbuf::{traits::*, AsyncHeapCons, AsyncHeapProd, AsyncHeapRb};
use crossbeam::atomic::AtomicCell;
use std::{
collections::HashMap,
fmt::Debug,
@@ -5,9 +7,6 @@ use std::{
task::{ready, Poll},
};
use async_ringbuf::{traits::*, AsyncHeapCons, AsyncHeapProd, AsyncHeapRb};
use crossbeam::atomic::AtomicCell;
use async_trait::async_trait;
use futures::{Sink, SinkExt, Stream, StreamExt};
use once_cell::sync::Lazy;
@@ -16,15 +15,15 @@ use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use uuid::Uuid;
use crate::tunnel::{SinkError, SinkItem};
use crate::tunnel::{FromUrl, IpVersion, SinkError, SinkItem};
use super::{
build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::TunnelWrapper,
StreamItem, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
build_url_from_socket_addr, common::TunnelWrapper, StreamItem, Tunnel, TunnelConnector,
TunnelError, TunnelInfo, TunnelListener,
};
pub static RING_TUNNEL_CAP: usize = 128;
static RING_TUNNEL_RESERVERD_CAP: usize = 4;
static RING_TUNNEL_RESERVED_CAP: usize = 4;
type RingLock = parking_lot::Mutex<()>;
@@ -44,7 +43,7 @@ impl RingTunnel {
pub fn new(cap: usize) -> Self {
let id = Uuid::new_v4();
let ring_impl = AsyncHeapRb::new(std::cmp::max(RING_TUNNEL_RESERVERD_CAP * 2, cap));
let ring_impl = AsyncHeapRb::new(std::cmp::max(RING_TUNNEL_RESERVED_CAP * 2, cap));
let (ring_prod_impl, ring_cons_impl) = ring_impl.split();
Self {
id,
@@ -120,7 +119,7 @@ impl RingSink {
pub fn try_send(&mut self, item: RingItem) -> Result<(), RingItem> {
let base = self.ring_prod_impl.base();
if base.occupied_len() >= base.capacity().get() - RING_TUNNEL_RESERVERD_CAP {
if base.occupied_len() >= base.capacity().get() - RING_TUNNEL_RESERVED_CAP {
return Err(item);
}
self.ring_prod_impl.try_push(item)
@@ -188,7 +187,7 @@ static CONNECTION_MAP: Lazy<Arc<std::sync::Mutex<ConnectionMap>>> =
#[derive(Debug)]
pub struct RingTunnelListener {
listerner_addr: url::Url,
listener_addr: url::Url,
conn_sender: UnboundedSender<Arc<Connection>>,
conn_receiver: UnboundedReceiver<Arc<Connection>>,
@@ -199,7 +198,7 @@ impl RingTunnelListener {
pub fn new(key: url::Url) -> Self {
let (conn_sender, conn_receiver) = tokio::sync::mpsc::unbounded_channel();
RingTunnelListener {
listerner_addr: key,
listener_addr: key,
conn_sender,
conn_receiver,
key_in_conn_map: None,
@@ -232,20 +231,15 @@ fn get_tunnel_for_server(conn: Arc<Connection>) -> impl Tunnel {
}
impl RingTunnelListener {
async fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
check_scheme_and_get_socket_addr::<Uuid>(
&self.listerner_addr,
"ring",
super::IpVersion::Both,
)
.await
async fn get_addr(&self) -> Result<Uuid, TunnelError> {
Uuid::from_url(self.listener_addr.clone(), IpVersion::Both).await
}
}
#[async_trait]
impl TunnelListener for RingTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
tracing::info!("listen new conn of key: {}", self.listerner_addr);
tracing::info!("listen new conn of key: {}", self.listener_addr);
let addr = self.get_addr().await?;
CONNECTION_MAP
.lock()
@@ -256,11 +250,11 @@ impl TunnelListener for RingTunnelListener {
}
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
tracing::info!("waiting accept new conn of key: {}", self.listerner_addr);
tracing::info!("waiting accept new conn of key: {}", self.listener_addr);
let my_addr = self.get_addr().await?;
if let Some(conn) = self.conn_receiver.recv().await {
if conn.server.id == my_addr {
tracing::info!("accept new conn of key: {}", self.listerner_addr);
tracing::info!("accept new conn of key: {}", self.listener_addr);
return Ok(Box::new(get_tunnel_for_server(conn)));
} else {
tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
@@ -276,7 +270,7 @@ impl TunnelListener for RingTunnelListener {
}
fn local_url(&self) -> url::Url {
self.listerner_addr.clone()
self.listener_addr.clone()
}
}
@@ -301,12 +295,7 @@ impl RingTunnelConnector {
#[async_trait]
impl TunnelConnector for RingTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let remote_addr = check_scheme_and_get_socket_addr::<Uuid>(
&self.remote_addr,
"ring",
super::IpVersion::Both,
)
.await?;
let remote_addr = Uuid::from_url(self.remote_addr.clone(), IpVersion::Both).await?;
let entry = CONNECTION_MAP
.lock()
.unwrap()
+5 -11
View File
@@ -1,14 +1,12 @@
use std::net::SocketAddr;
use super::{FromUrl, TunnelInfo};
use crate::tunnel::common::setup_sokcet2;
use async_trait::async_trait;
use futures::stream::FuturesUnordered;
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use super::TunnelInfo;
use crate::tunnel::common::setup_sokcet2;
use super::{
check_scheme_and_get_socket_addr,
common::{wait_for_connect_futures, FramedReader, FramedWriter, TunnelWrapper},
IpVersion, Tunnel, TunnelError, TunnelListener,
};
@@ -58,9 +56,7 @@ impl TcpTunnelListener {
impl TunnelListener for TcpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
self.listener = None;
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp", IpVersion::Both)
.await?;
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
@@ -189,10 +185,8 @@ impl TcpTunnelConnector {
#[async_trait]
impl super::TunnelConnector for TcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp", self.ip_version)
.await?;
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if self.bind_addrs.is_empty() {
self.connect_with_default_bind(addr).await
} else {
+18 -31
View File
@@ -21,7 +21,13 @@ use tokio::{
use tracing::{instrument, Instrument};
use super::{packet_def::V6HolePunchPacket, TunnelInfo};
use super::{
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
packet_def::{UDPTunnelHeader, V6HolePunchPacket, UDP_TUNNEL_HEADER_SIZE},
ring::{RingSink, RingStream},
FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener,
TunnelUrl,
};
use crate::{
common::{join_joinset_background, scoped_task::ScopedTask, shrink_dashmap},
tunnel::{
@@ -32,13 +38,6 @@ use crate::{
},
};
use super::{
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
packet_def::{UDPTunnelHeader, UDP_TUNNEL_HEADER_SIZE},
ring::{RingSink, RingStream},
IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelListener, TunnelUrl,
};
pub const UDP_DATA_MTU: usize = 2000;
type UdpCloseEventSender = UnboundedSender<(SocketAddr, Option<TunnelError>)>;
@@ -149,11 +148,11 @@ async fn respond_stun_packet(
req_buf: Vec<u8>,
) -> Result<(), anyhow::Error> {
use crate::common::stun_codec_ext::*;
use bytecodec::DecodeExt as _;
use bytecodec::EncodeExt as _;
use stun_codec::rfc5389::attributes::XorMappedAddress;
use stun_codec::rfc5389::methods::BINDING;
use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder};
use bytecodec::{DecodeExt as _, EncodeExt as _};
use stun_codec::{
rfc5389::{attributes::XorMappedAddress, methods::BINDING},
Message, MessageClass, MessageDecoder, MessageEncoder,
};
let mut decoder = MessageDecoder::<Attribute>::new();
let req_msg = decoder
@@ -532,13 +531,8 @@ impl UdpTunnelListener {
#[async_trait]
impl TunnelListener for UdpTunnelListener {
async fn listen(&mut self) -> Result<(), super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"udp",
IpVersion::Both,
)
.await?;
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
@@ -851,13 +845,8 @@ impl UdpTunnelConnector {
#[async_trait]
impl super::TunnelConnector for UdpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"udp",
self.ip_version,
)
.await?;
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if self.bind_addrs.is_empty() || addr.is_ipv6() {
self.connect_with_default_bind(addr).await
} else {
@@ -889,7 +878,6 @@ mod tests {
use crate::{
common::global_ctx::tests::get_mock_global_ctx,
tunnel::{
check_scheme_and_get_socket_addr,
common::{
get_interface_name_by_ip,
tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong, wait_for_condition},
@@ -1034,9 +1022,8 @@ mod tests {
for ip in ips {
println!("bind to ip: {}, {:?}", ip, bind_dev);
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(
&format!("udp://{}:11111", ip).parse().unwrap(),
"udp",
let addr = SocketAddr::from_url(
format!("udp://{}:11111", ip).parse().unwrap(),
IpVersion::Both,
)
.await
+29 -29
View File
@@ -1,10 +1,20 @@
use super::{
common::{setup_sokcet2, wait_for_connect_futures, TunnelWrapper},
insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
packet_def::{ZCPacket, ZCPacketType},
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
};
use crate::{proto::common::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config};
use anyhow::Context;
use bytes::BytesMut;
use forwarded_header_value::ForwardedHeaderValue;
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use pnet::ipnetwork::IpNetwork;
use std::sync::LazyLock;
use std::{net::SocketAddr, sync::Arc, time::Duration};
use std::{
net::SocketAddr,
sync::{Arc, LazyLock},
time::Duration,
};
use tokio::{
net::{TcpListener, TcpSocket, TcpStream},
time::timeout,
@@ -14,16 +24,6 @@ use tokio_util::either::Either;
use tokio_websockets::{ClientBuilder, Limits, MaybeTlsStream, Message, ServerBuilder};
use zerocopy::AsBytes;
use super::TunnelInfo;
use crate::tunnel::insecure_tls::get_insecure_tls_client_config;
use super::{
common::{setup_sokcet2, wait_for_connect_futures, TunnelWrapper},
insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
packet_def::{ZCPacket, ZCPacketType},
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
};
fn is_wss(addr: &url::Url) -> Result<bool, TunnelError> {
match addr.scheme() {
"ws" => Ok(false),
@@ -77,14 +77,14 @@ static TRUSTED_PROXIES: LazyLock<Vec<IpNetwork>> = LazyLock::new(|| {
});
#[derive(Debug)]
pub struct WSTunnelListener {
pub struct WsTunnelListener {
addr: url::Url,
listener: Option<TcpListener>,
}
impl WSTunnelListener {
impl WsTunnelListener {
pub fn new(addr: url::Url) -> Self {
WSTunnelListener {
WsTunnelListener {
addr,
listener: None,
}
@@ -159,7 +159,7 @@ impl WSTunnelListener {
}
#[async_trait::async_trait]
impl TunnelListener for WSTunnelListener {
impl TunnelListener for WsTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
@@ -199,16 +199,16 @@ impl TunnelListener for WSTunnelListener {
}
}
pub struct WSTunnelConnector {
pub struct WsTunnelConnector {
addr: url::Url,
ip_version: IpVersion,
bind_addrs: Vec<SocketAddr>,
}
impl WSTunnelConnector {
impl WsTunnelConnector {
pub fn new(addr: url::Url) -> Self {
WSTunnelConnector {
WsTunnelConnector {
addr,
ip_version: IpVersion::Both,
@@ -307,8 +307,8 @@ impl WSTunnelConnector {
}
#[async_trait::async_trait]
impl TunnelConnector for WSTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
impl TunnelConnector for WsTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if self.bind_addrs.is_empty() || addr.is_ipv6() {
self.connect_with_default_bind(addr).await
@@ -340,9 +340,9 @@ pub mod tests {
#[tokio::test]
#[serial_test::serial]
async fn ws_pingpong(#[values("ws", "wss")] proto: &str) {
let listener = WSTunnelListener::new(format!("{}://0.0.0.0:25556", proto).parse().unwrap());
let listener = WsTunnelListener::new(format!("{}://0.0.0.0:25556", proto).parse().unwrap());
let connector =
WSTunnelConnector::new(format!("{}://127.0.0.1:25556", proto).parse().unwrap());
WsTunnelConnector::new(format!("{}://127.0.0.1:25556", proto).parse().unwrap());
_tunnel_pingpong(listener, connector).await
}
@@ -350,9 +350,9 @@ pub mod tests {
#[tokio::test]
#[serial_test::serial]
async fn ws_pingpong_bind(#[values("ws", "wss")] proto: &str) {
let listener = WSTunnelListener::new(format!("{}://0.0.0.0:25557", proto).parse().unwrap());
let listener = WsTunnelListener::new(format!("{}://0.0.0.0:25557", proto).parse().unwrap());
let mut connector =
WSTunnelConnector::new(format!("{}://127.0.0.1:25557", proto).parse().unwrap());
WsTunnelConnector::new(format!("{}://127.0.0.1:25557", proto).parse().unwrap());
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
@@ -371,16 +371,16 @@ pub mod tests {
#[tokio::test]
async fn ws_accept_wss() {
let mut listener = WSTunnelListener::new("wss://0.0.0.0:25558".parse().unwrap());
let mut listener = WsTunnelListener::new("wss://0.0.0.0:25558".parse().unwrap());
listener.listen().await.unwrap();
let j = tokio::spawn(async move {
let _ = listener.accept().await;
});
let mut connector = WSTunnelConnector::new("ws://127.0.0.1:25558".parse().unwrap());
let mut connector = WsTunnelConnector::new("ws://127.0.0.1:25558".parse().unwrap());
connector.connect().await.unwrap_err();
let mut connector = WSTunnelConnector::new("wss://127.0.0.1:25558".parse().unwrap());
let mut connector = WsTunnelConnector::new("wss://127.0.0.1:25558".parse().unwrap());
connector.connect().await.unwrap();
j.abort();
@@ -388,7 +388,7 @@ pub mod tests {
#[tokio::test]
async fn ws_forwarded() {
let mut listener = WSTunnelListener::new("ws://127.0.0.1:25559".parse().unwrap());
let mut listener = WsTunnelListener::new("ws://127.0.0.1:25559".parse().unwrap());
listener.listen().await.unwrap();
let server_task = tokio::spawn(async move {
+20 -23
View File
@@ -20,7 +20,14 @@ use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use rand::RngCore;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use super::TunnelInfo;
use super::{
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
generate_digest_from_str,
packet_def::{ZCPacketType, PEER_MANAGER_HEADER_SIZE},
ring::create_ring_tunnel_pair,
FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink,
ZCPacketStream,
};
use crate::{
common::shrink_dashmap,
tunnel::{
@@ -30,15 +37,6 @@ use crate::{
},
};
use super::{
check_scheme_and_get_socket_addr,
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
generate_digest_from_str,
packet_def::{ZCPacketType, PEER_MANAGER_HEADER_SIZE},
ring::create_ring_tunnel_pair,
IpVersion, Tunnel, TunnelError, TunnelListener, TunnelUrl, ZCPacketSink, ZCPacketStream,
};
const MAX_PACKET: usize = 2048;
#[derive(Debug, Clone)]
@@ -202,7 +200,10 @@ impl WgPeerData {
match self.udp.send_to(packet, self.endpoint).await {
Ok(_) => {}
Err(e) => {
tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
tracing::error!(
"Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}",
e
);
return;
}
};
@@ -214,7 +215,10 @@ impl WgPeerData {
match self.udp.send_to(packet, self.endpoint).await {
Ok(_) => {}
Err(e) => {
tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
tracing::error!(
"Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}",
e
);
break;
}
};
@@ -550,10 +554,8 @@ impl WgTunnelListener {
#[async_trait]
impl TunnelListener for WgTunnelListener {
async fn listen(&mut self) -> Result<(), super::TunnelError> {
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg", IpVersion::Both)
.await?;
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
@@ -705,13 +707,8 @@ impl WgTunnelConnector {
#[async_trait]
impl super::TunnelConnector for WgTunnelConnector {
#[tracing::instrument]
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"wg",
self.ip_version,
)
.await?;
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if addr.is_ipv6() {
return self.connect_with_ipv6(addr).await;