refactor: listener/connector protocol abstraction (#2026)

* fix listener protocol detection
* replace IpProtocol with IpNextHeaderProtocol
* use an enum to gather all listener schemes
* rename ListenerScheme to TunnelScheme; replace IpNextHeaderProtocols with socket2::Protocol
* move TunnelScheme to tunnel
* add IpScheme, simplify connector creation
* format; fix some typos; remove check_scheme_...;
* remove PROTO_PORT_OFFSET
* rename WSTunnel.. -> WsTunnel.., DNSTunnel.. -> DnsTunnel..
This commit is contained in:
Luna Yao
2026-04-04 04:55:58 +02:00
committed by GitHub
parent 9cc617ae4c
commit e91a0da70a
18 changed files with 481 additions and 526 deletions
+25 -17
View File
@@ -7,7 +7,7 @@ use std::net::IpAddr;
use std::sync::Arc; use std::sync::Arc;
use clap::Parser; use clap::Parser;
use easytier::tunnel::websocket::WSTunnelListener; use easytier::tunnel::websocket::WsTunnelListener;
use easytier::{ use easytier::{
common::{ common::{
config::{ConsoleLoggerConfig, FileLoggerConfig, LoggingConfigLoader}, config::{ConsoleLoggerConfig, FileLoggerConfig, LoggingConfigLoader},
@@ -20,6 +20,8 @@ use easytier::{
utils::setup_panic_handler, utils::setup_panic_handler,
}; };
use easytier::tunnel::IpScheme;
use easytier::utils::BoxExt;
use mimalloc::MiMalloc; use mimalloc::MiMalloc;
mod client_manager; mod client_manager;
@@ -192,14 +194,12 @@ impl LoggingConfigLoader for &Cli {
} }
} }
pub fn get_listener_by_url(l: &url::Url) -> Result<Box<dyn TunnelListener>, Error> { pub fn get_listener_by_url(scheme: IpScheme, l: &url::Url) -> Option<Box<dyn TunnelListener>> {
Ok(match l.scheme() { Some(match scheme {
"tcp" => Box::new(TcpTunnelListener::new(l.clone())), IpScheme::Tcp => TcpTunnelListener::new(l.clone()).boxed(),
"udp" => Box::new(UdpTunnelListener::new(l.clone())), IpScheme::Udp => UdpTunnelListener::new(l.clone()).boxed(),
"ws" => Box::new(WSTunnelListener::new(l.clone())), IpScheme::Ws => WsTunnelListener::new(l.clone()).boxed(),
_ => { _ => return None,
return Err(Error::InvalidUrl(l.to_string()));
}
}) })
} }
@@ -213,15 +213,23 @@ async fn get_dual_stack_listener(
), ),
Error, Error,
> { > {
let is_protocol_support_dual_stack = let scheme = protocol
protocol.trim().to_lowercase() == "tcp" || protocol.trim().to_lowercase() == "udp"; .parse()
let v6_listener = if is_protocol_support_dual_stack && local_ipv6().await.is_ok() { .map_err(|_| Error::InvalidUrl(protocol.to_string()))?;
get_listener_by_url(&format!("{}://[::0]:{}", protocol, port).parse().unwrap()).ok() let v6_listener =
} else { if local_ipv6().await.is_ok() && matches!(scheme, IpScheme::Tcp | IpScheme::Udp) {
None get_listener_by_url(
}; scheme,
&format!("{protocol}://[::]:{port}").parse().unwrap(),
)
} else {
None
};
let v4_listener = if local_ipv4().await.is_ok() { let v4_listener = if local_ipv4().await.is_ok() {
get_listener_by_url(&format!("{}://0.0.0.0:{}", protocol, port).parse().unwrap()).ok() get_listener_by_url(
scheme,
&format!("{protocol}://0.0.0.0:{port}").parse().unwrap(),
)
} else { } else {
None None
}; };
+21 -23
View File
@@ -1,8 +1,7 @@
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::{ use std::{
collections::{hash_map::DefaultHasher, HashMap},
hash::Hasher, hash::Hasher,
net::{IpAddr, SocketAddr},
sync::{Arc, Mutex}, sync::{Arc, Mutex},
time::{SystemTime, UNIX_EPOCH}, time::{SystemTime, UNIX_EPOCH},
}; };
@@ -10,21 +9,6 @@ use std::{
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use dashmap::DashMap; use dashmap::DashMap;
use crate::common::config::ProxyNetworkConfig;
use crate::common::shrink_dashmap;
use crate::common::stats_manager::StatsManager;
use crate::common::token_bucket::TokenBucketManager;
use crate::peers::acl_filter::AclFilter;
use crate::peers::credential_manager::CredentialManager;
use crate::proto::acl::GroupIdentity;
use crate::proto::api::config::InstanceConfigPatch;
use crate::proto::api::instance::PeerConnInfo;
use crate::proto::common::{PeerFeatureFlag, PortForwardConfigPb};
use crate::proto::peer_rpc::PeerGroupInfo;
use crossbeam::atomic::AtomicCell;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use super::{ use super::{
config::{ConfigLoader, Flags}, config::{ConfigLoader, Flags},
netns::NetNS, netns::NetNS,
@@ -32,6 +16,24 @@ use super::{
stun::{StunInfoCollector, StunInfoCollectorTrait}, stun::{StunInfoCollector, StunInfoCollectorTrait},
PeerId, PeerId,
}; };
use crate::{
common::{
config::ProxyNetworkConfig, shrink_dashmap, stats_manager::StatsManager,
token_bucket::TokenBucketManager,
},
peers::{acl_filter::AclFilter, credential_manager::CredentialManager},
proto::{
acl::GroupIdentity,
api::{config::InstanceConfigPatch, instance::PeerConnInfo},
common::{PeerFeatureFlag, PortForwardConfigPb},
peer_rpc::PeerGroupInfo,
},
tunnel::matches_protocol,
};
use crossbeam::atomic::AtomicCell;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use socket2::Protocol;
pub type NetworkIdentity = crate::common::config::NetworkIdentity; pub type NetworkIdentity = crate::common::config::NetworkIdentity;
@@ -625,15 +627,11 @@ impl GlobalCtx {
} }
fn is_port_in_running_listeners(&self, port: u16, is_udp: bool) -> bool { fn is_port_in_running_listeners(&self, port: u16, is_udp: bool) -> bool {
let check_proto = |listener_proto: &str| {
let listener_is_udp = matches!(listener_proto, "udp" | "wg");
listener_is_udp == is_udp
};
self.running_listeners self.running_listeners
.lock() .lock()
.unwrap() .unwrap()
.iter() .iter()
.any(|x| x.port() == Some(port) && check_proto(x.scheme())) .any(|x| x.port() == Some(port) && matches_protocol!(x, Protocol::UDP) == is_udp)
} }
#[tracing::instrument(ret, skip(self))] #[tracing::instrument(ret, skip(self))]
+23 -23
View File
@@ -31,19 +31,20 @@ use crate::{
}, },
rpc_types::controller::BaseController, rpc_types::controller::BaseController,
}, },
tunnel::{udp::UdpTunnelConnector, IpVersion}, tunnel::{matches_protocol, udp::UdpTunnelConnector, IpVersion},
use_global_var, use_global_var,
}; };
use anyhow::Context;
use rand::Rng;
use tokio::{net::UdpSocket, task::JoinSet, time::timeout};
use url::Host;
use super::{ use super::{
create_connector_by_url, should_background_p2p_with_peer, should_try_p2p_with_peer, create_connector_by_url, should_background_p2p_with_peer, should_try_p2p_with_peer,
udp_hole_punch, udp_hole_punch,
}; };
use crate::tunnel::{matches_scheme, FromUrl, IpScheme, TunnelScheme};
use anyhow::Context;
use rand::Rng;
use socket2::Protocol;
use tokio::{net::UdpSocket, task::JoinSet, time::timeout};
use url::Host;
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1; pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300; pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300;
@@ -189,9 +190,7 @@ impl DirectConnectorManagerData {
.await; .await;
let udp_connector = UdpTunnelConnector::new(remote_url.clone()); let udp_connector = UdpTunnelConnector::new(remote_url.clone());
let remote_addr = let remote_addr = SocketAddr::from_url(remote_url.clone(), IpVersion::V6).await?;
super::check_scheme_and_get_socket_addr::<SocketAddr>(remote_url, "udp", IpVersion::V6)
.await?;
let ret = udp_connector let ret = udp_connector
.try_connect_with_socket(local_socket, remote_addr) .try_connect_with_socket(local_socket, remote_addr)
.await?; .await?;
@@ -205,18 +204,19 @@ impl DirectConnectorManagerData {
async fn do_try_connect_to_ip(&self, dst_peer_id: PeerId, addr: String) -> Result<(), Error> { async fn do_try_connect_to_ip(&self, dst_peer_id: PeerId, addr: String) -> Result<(), Error> {
let connector = create_connector_by_url(&addr, &self.global_ctx, IpVersion::Both).await?; let connector = create_connector_by_url(&addr, &self.global_ctx, IpVersion::Both).await?;
let remote_url = connector.remote_url(); let remote_url = connector.remote_url();
let (peer_id, conn_id) = let (peer_id, conn_id) = if matches_scheme!(remote_url, TunnelScheme::Ip(IpScheme::Udp))
if remote_url.scheme() == "udp" && matches!(remote_url.host(), Some(Host::Ipv6(_))) { && matches!(remote_url.host(), Some(Host::Ipv6(_)))
self.connect_to_public_ipv6(dst_peer_id, &remote_url) {
.await? self.connect_to_public_ipv6(dst_peer_id, &remote_url)
} else { .await?
timeout( } else {
std::time::Duration::from_secs(3), timeout(
self.peer_manager std::time::Duration::from_secs(3),
.try_direct_connect_with_peer_id_hint(connector, Some(dst_peer_id)), self.peer_manager
) .try_direct_connect_with_peer_id_hint(connector, Some(dst_peer_id)),
.await?? )
}; .await??
};
if peer_id != dst_peer_id && !TESTING.load(Ordering::Relaxed) { if peer_id != dst_peer_id && !TESTING.load(Ordering::Relaxed) {
tracing::info!( tracing::info!(
@@ -306,7 +306,7 @@ impl DirectConnectorManagerData {
let listener_host = addrs.pop(); let listener_host = addrs.pop();
tracing::info!(?listener_host, ?listener, "try direct connect to peer"); tracing::info!(?listener_host, ?listener, "try direct connect to peer");
let is_udp = matches!(listener.scheme(), "udp" | "wg"); let is_udp = matches_protocol!(listener, Protocol::UDP);
// Snapshot running listeners once; used for cheap port pre-checks before the // Snapshot running listeners once; used for cheap port pre-checks before the
// expensive should_deny_proxy call (which binds a socket per IP) in the // expensive should_deny_proxy call (which binds a socket per IP) in the
// unspecified-address expansion loops below. // unspecified-address expansion loops below.
@@ -314,7 +314,7 @@ impl DirectConnectorManagerData {
let port_has_local_listener = |port: u16| -> bool { let port_has_local_listener = |port: u16| -> bool {
local_listeners local_listeners
.iter() .iter()
.any(|l| l.port() == Some(port) && (matches!(l.scheme(), "udp" | "wg") == is_udp)) .any(|l| l.port() == Some(port) && matches_protocol!(l, Protocol::UDP) == is_udp)
}; };
match listener_host { match listener_host {
+38 -40
View File
@@ -1,5 +1,6 @@
use std::{net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
use super::{create_connector_by_url, http_connector::TunnelWithInfo};
use crate::{ use crate::{
common::{ common::{
dns::{resolve_txt_record, RESOLVER}, dns::{resolve_txt_record, RESOLVER},
@@ -7,16 +8,15 @@ use crate::{
global_ctx::ArcGlobalCtx, global_ctx::ArcGlobalCtx,
log, log,
}, },
tunnel::{IpVersion, Tunnel, TunnelConnector, TunnelError, PROTO_PORT_OFFSET}, proto::common::TunnelInfo,
tunnel::{IpScheme, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelScheme},
}; };
use anyhow::Context; use anyhow::Context;
use dashmap::DashSet; use dashmap::DashSet;
use hickory_resolver::proto::rr::rdata::SRV; use hickory_resolver::proto::rr::rdata::SRV;
use itertools::Itertools;
use rand::{seq::SliceRandom, Rng as _}; use rand::{seq::SliceRandom, Rng as _};
use strum::VariantArray;
use crate::proto::common::TunnelInfo;
use super::{create_connector_by_url, http_connector::TunnelWithInfo};
fn weighted_choice<T>(options: &[(T, u64)]) -> Option<&T> { fn weighted_choice<T>(options: &[(T, u64)]) -> Option<&T> {
let total_weight = options.iter().map(|(_, weight)| *weight).sum(); let total_weight = options.iter().map(|(_, weight)| *weight).sum();
@@ -35,16 +35,18 @@ fn weighted_choice<T>(options: &[(T, u64)]) -> Option<&T> {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct DNSTunnelConnector { pub struct DnsTunnelConnector {
scheme: TunnelScheme,
addr: url::Url, addr: url::Url,
bind_addrs: Vec<SocketAddr>, bind_addrs: Vec<SocketAddr>,
global_ctx: ArcGlobalCtx, global_ctx: ArcGlobalCtx,
ip_version: IpVersion, ip_version: IpVersion,
} }
impl DNSTunnelConnector { impl DnsTunnelConnector {
pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self { pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self {
Self { Self {
scheme: (&addr).try_into().unwrap(),
addr, addr,
bind_addrs: Vec::new(), bind_addrs: Vec::new(),
global_ctx, global_ctx,
@@ -82,7 +84,7 @@ impl DNSTunnelConnector {
Ok(connector) Ok(connector)
} }
fn handle_one_srv_record(record: &SRV, protocol: &str) -> Result<(url::Url, u64), Error> { fn handle_one_srv_record(record: &SRV, protocol: IpScheme) -> Result<(url::Url, u64), Error> {
// port must be non-zero // port must be non-zero
if record.port() == 0 { if record.port() == 0 {
return Err(anyhow::anyhow!("port must be non-zero").into()); return Err(anyhow::anyhow!("port must be non-zero").into());
@@ -112,15 +114,15 @@ impl DNSTunnelConnector {
) -> Result<Box<dyn TunnelConnector>, Error> { ) -> Result<Box<dyn TunnelConnector>, Error> {
tracing::info!("handle_srv_record: {}", domain_name); tracing::info!("handle_srv_record: {}", domain_name);
let srv_domains = PROTO_PORT_OFFSET let srv_domains = IpScheme::VARIANTS
.iter() .iter()
.map(|(p, _)| (format!("_easytier._{}.{}", p, domain_name), *p)) // _easytier._udp.{domain_name} .map(|s| (s, format!("_easytier._{}.{}", s, domain_name)))
.collect::<Vec<_>>(); .collect_vec();
tracing::info!("build srv_domains: {:?}", srv_domains); tracing::info!("build srv_domains: {:?}", srv_domains);
let responses = Arc::new(DashSet::new()); let responses = Arc::new(DashSet::new());
let srv_lookup_tasks = srv_domains let srv_lookup_tasks = srv_domains
.iter() .iter()
.map(|(srv_domain, protocol)| { .map(|(protocol, srv_domain)| {
let resolver = RESOLVER.clone(); let resolver = RESOLVER.clone();
let responses = responses.clone(); let responses = responses.clone();
async move { async move {
@@ -129,7 +131,7 @@ impl DNSTunnelConnector {
})?; })?;
tracing::info!(?response, ?srv_domain, "srv_lookup response"); tracing::info!(?response, ?srv_domain, "srv_lookup response");
for record in response.iter() { for record in response.iter() {
let parsed_record = Self::handle_one_srv_record(record, protocol); let parsed_record = Self::handle_one_srv_record(record, **protocol);
tracing::info!(?parsed_record, ?srv_domain, "parsed_record"); tracing::info!(?parsed_record, ?srv_domain, "parsed_record");
if let Err(e) = &parsed_record { if let Err(e) = &parsed_record {
log::warn!("got invalid srv record {:?}", e); log::warn!("got invalid srv record {:?}", e);
@@ -162,32 +164,28 @@ impl DNSTunnelConnector {
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl super::TunnelConnector for DNSTunnelConnector { impl super::TunnelConnector for DnsTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> { async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let mut conn = if self.addr.scheme() == "txt" { let mut conn = match self.scheme {
self.handle_txt_record( TunnelScheme::Txt => self
self.addr .handle_txt_record(
.host_str() self.addr
.as_ref() .host_str()
.ok_or(anyhow::anyhow!("host should not be empty in txt url"))?, .as_ref()
) .ok_or(anyhow::anyhow!("host should not be empty in txt url"))?,
.await )
.with_context(|| "get txt record url failed")? .await
} else if self.addr.scheme() == "srv" { .with_context(|| "get txt record url failed")?,
self.handle_srv_record( TunnelScheme::Srv => self
self.addr .handle_srv_record(
.host_str() self.addr
.as_ref() .host_str()
.ok_or(anyhow::anyhow!("host should not be empty in srv url"))?, .as_ref()
) .ok_or(anyhow::anyhow!("host should not be empty in srv url"))?,
.await )
.with_context(|| "get srv record url failed")? .await
} else { .with_context(|| "get srv record url failed")?,
return Err(anyhow::anyhow!( _ => return Err(anyhow::anyhow!("unsupported dns scheme: {:?}", self.scheme).into()),
"unsupported dns scheme: {}, expecting txt or srv",
self.addr.scheme()
)
.into());
}; };
let t = conn.connect().await?; let t = conn.connect().await?;
let info = t.info().unwrap_or_default(); let info = t.info().unwrap_or_default();
@@ -227,7 +225,7 @@ mod tests {
async fn test_txt() { async fn test_txt() {
let url = "txt://txt.easytier.cn"; let url = "txt://txt.easytier.cn";
let global_ctx = get_mock_global_ctx(); let global_ctx = get_mock_global_ctx();
let mut connector = DNSTunnelConnector::new(url.parse().unwrap(), global_ctx); let mut connector = DnsTunnelConnector::new(url.parse().unwrap(), global_ctx);
connector.set_ip_version(IpVersion::V4); connector.set_ip_version(IpVersion::V4);
for _ in 0..5 { for _ in 0..5 {
match connector.connect().await { match connector.connect().await {
@@ -246,7 +244,7 @@ mod tests {
async fn test_srv() { async fn test_srv() {
let url = "srv://easytier.cn"; let url = "srv://easytier.cn";
let global_ctx = get_mock_global_ctx(); let global_ctx = get_mock_global_ctx();
let mut connector = DNSTunnelConnector::new(url.parse().unwrap(), global_ctx); let mut connector = DnsTunnelConnector::new(url.parse().unwrap(), global_ctx);
connector.set_ip_version(IpVersion::V4); connector.set_ip_version(IpVersion::V4);
for _ in 0..5 { for _ in 0..5 {
match connector.connect().await { match connector.connect().await {
+41 -116
View File
@@ -3,24 +3,17 @@ use std::{
sync::Arc, sync::Arc,
}; };
use http_connector::HttpTunnelConnector;
#[cfg(feature = "faketcp")]
use crate::tunnel::fake_tcp::FakeTcpTunnelConnector;
#[cfg(feature = "quic")]
use crate::tunnel::quic::QUICTunnelConnector;
#[cfg(unix)]
use crate::tunnel::unix::UnixSocketTunnelConnector;
#[cfg(feature = "wireguard")]
use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
use crate::{ use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, idn, network::IPCollector}, common::{error::Error, global_ctx::ArcGlobalCtx, idn, network::IPCollector},
connector::dns_connector::DnsTunnelConnector,
proto::common::PeerFeatureFlag, proto::common::PeerFeatureFlag,
tunnel::{ tunnel::{
check_scheme_and_get_socket_addr, ring::RingTunnelConnector, tcp::TcpTunnelConnector, self, ring::RingTunnelConnector, tcp::TcpTunnelConnector, udp::UdpTunnelConnector, FromUrl,
udp::UdpTunnelConnector, IpVersion, TunnelConnector, IpScheme, IpVersion, TunnelConnector, TunnelError, TunnelScheme,
}, },
utils::BoxExt,
}; };
use http_connector::HttpTunnelConnector;
pub mod direct; pub mod direct;
pub mod manual; pub mod manual;
@@ -90,84 +83,34 @@ pub async fn create_connector_by_url(
) -> Result<Box<dyn TunnelConnector + 'static>, Error> { ) -> Result<Box<dyn TunnelConnector + 'static>, Error> {
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?; let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
let url = idn::convert_idn_to_ascii(url)?; let url = idn::convert_idn_to_ascii(url)?;
let mut connector: Box<dyn TunnelConnector + 'static> = match url.scheme() { let scheme = (&url)
"tcp" => { .try_into()
let dst_addr = .map_err(|_| TunnelError::InvalidProtocol(url.scheme().to_owned()))?;
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp", ip_version).await?; let mut connector: Box<dyn TunnelConnector + 'static> = match scheme {
let mut connector = TcpTunnelConnector::new(url); TunnelScheme::Ip(scheme) => {
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
"udp" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp", ip_version).await?;
let mut connector = UdpTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
"http" | "https" => {
let connector = HttpTunnelConnector::new(url, global_ctx.clone());
Box::new(connector)
}
"ring" => {
check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring", IpVersion::Both).await?;
let connector = RingTunnelConnector::new(url);
Box::new(connector)
}
#[cfg(feature = "quic")]
"quic" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "quic", ip_version).await?;
let mut connector = QUICTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
#[cfg(feature = "wireguard")]
"wg" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg", ip_version).await?;
let nid = global_ctx.get_network_identity();
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
&nid.network_secret.unwrap_or_default(),
);
let mut connector = WgTunnelConnector::new(url, wg_config);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
#[cfg(feature = "websocket")]
"ws" | "wss" => {
use crate::tunnel::FromUrl;
let dst_addr = SocketAddr::from_url(url.clone(), ip_version).await?; let dst_addr = SocketAddr::from_url(url.clone(), ip_version).await?;
let mut connector = crate::tunnel::websocket::WSTunnelConnector::new(url); let mut connector: Box<dyn TunnelConnector> = match scheme {
IpScheme::Tcp => TcpTunnelConnector::new(url).boxed(),
IpScheme::Udp => UdpTunnelConnector::new(url).boxed(),
#[cfg(feature = "quic")]
IpScheme::Quic => tunnel::quic::QuicTunnelConnector::new(url).boxed(),
#[cfg(feature = "wireguard")]
IpScheme::Wg => {
use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
let nid = global_ctx.get_network_identity();
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
&nid.network_secret.unwrap_or_default(),
);
WgTunnelConnector::new(url, wg_config).boxed()
}
#[cfg(feature = "websocket")]
IpScheme::Ws | IpScheme::Wss => {
tunnel::websocket::WsTunnelConnector::new(url).boxed()
}
#[cfg(feature = "faketcp")]
IpScheme::FakeTcp => tunnel::fake_tcp::FakeTcpTunnelConnector::new(url).boxed(),
};
if global_ctx.config.get_flags().bind_device { if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector( set_bind_addr_for_peer_connector(
&mut connector, &mut connector,
@@ -176,40 +119,22 @@ pub async fn create_connector_by_url(
) )
.await; .await;
} }
Box::new(connector) connector
} }
"txt" | "srv" => { #[cfg(unix)]
TunnelScheme::Unix => tunnel::unix::UnixSocketTunnelConnector::new(url).boxed(),
TunnelScheme::Http | TunnelScheme::Https => {
HttpTunnelConnector::new(url, global_ctx.clone()).boxed()
}
TunnelScheme::Ring => RingTunnelConnector::new(url).boxed(),
TunnelScheme::Txt | TunnelScheme::Srv => {
if url.host_str().is_none() { if url.host_str().is_none() {
return Err(Error::InvalidUrl(format!( return Err(Error::InvalidUrl(format!(
"host should not be empty in txt or srv url: {}", "host should not be empty in txt or srv url: {}",
url url
))); )));
} }
let connector = dns_connector::DNSTunnelConnector::new(url, global_ctx.clone()); DnsTunnelConnector::new(url, global_ctx.clone()).boxed()
Box::new(connector)
}
#[cfg(feature = "faketcp")]
"faketcp" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "faketcp", ip_version).await?;
let mut connector = FakeTcpTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
#[cfg(unix)]
"unix" => {
let connector = UnixSocketTunnelConnector::new(url);
Box::new(connector)
}
_ => {
return Err(Error::InvalidUrl(url.into()));
} }
}; };
connector.set_ip_version(ip_version); connector.set_ip_version(ip_version);
+14 -17
View File
@@ -22,7 +22,6 @@ use crate::{
launcher::add_proxy_network_to_config, launcher::add_proxy_network_to_config,
proto::common::{CompressionAlgoPb, SecureModeConfig}, proto::common::{CompressionAlgoPb, SecureModeConfig},
rpc_service::ApiRpcServer, rpc_service::ApiRpcServer,
tunnel::PROTO_PORT_OFFSET,
utils::setup_panic_handler, utils::setup_panic_handler,
web_client, ShellType, web_client, ShellType,
}; };
@@ -30,8 +29,10 @@ use anyhow::Context;
use cidr::IpCidr; use cidr::IpCidr;
use clap::{CommandFactory, Parser}; use clap::{CommandFactory, Parser};
use rust_i18n::t; use rust_i18n::t;
use strum::VariantArray;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use crate::tunnel::IpScheme;
#[cfg(feature = "jemalloc-prof")] #[cfg(feature = "jemalloc-prof")]
use jemalloc_ctl::{epoch, stats, Access as _, AsName as _}; use jemalloc_ctl::{epoch, stats, Access as _, AsName as _};
@@ -742,8 +743,12 @@ impl Cli {
let mut listeners: Vec<String> = Vec::new(); let mut listeners: Vec<String> = Vec::new();
if origin_listeners.len() == 1 { if origin_listeners.len() == 1 {
if let Ok(port) = origin_listeners[0].parse::<u16>() { if let Ok(port) = origin_listeners[0].parse::<u16>() {
for (proto, offset) in PROTO_PORT_OFFSET { for proto in IpScheme::VARIANTS {
listeners.push(format!("{}://0.0.0.0:{}", proto, port + *offset)); listeners.push(format!(
"{}://0.0.0.0:{}",
proto,
port + proto.port_offset()
));
} }
return Ok(listeners); return Ok(listeners);
} }
@@ -758,20 +763,15 @@ impl Cli {
panic!("failed to parse listener: {}", l); panic!("failed to parse listener: {}", l);
} }
} else { } else {
let Some((proto, offset)) = PROTO_PORT_OFFSET let scheme: IpScheme = proto_port[0].parse()?;
.iter()
.find(|(proto, _)| *proto == proto_port[0])
else {
return Err(anyhow::anyhow!("unknown protocol: {}", proto_port[0]));
};
let port = if proto_port.len() == 2 { let port = if proto_port.len() == 2 {
proto_port[1].parse::<u16>().unwrap() proto_port[1].parse::<u16>().unwrap()
} else { } else {
11010 + offset 11010 + scheme.port_offset()
}; };
listeners.push(format!("{}://0.0.0.0:{}", proto, port)); listeners.push(format!("{}://0.0.0.0:{}", scheme, port));
} }
} }
@@ -1134,8 +1134,7 @@ impl LoggingConfigLoader for &LoggingOptions {
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
fn win_service_set_work_dir(service_name: &std::ffi::OsString) -> anyhow::Result<()> { fn win_service_set_work_dir(service_name: &std::ffi::OsString) -> anyhow::Result<()> {
use crate::common::constants::WIN_SERVICE_WORK_DIR_REG_KEY; use crate::common::constants::WIN_SERVICE_WORK_DIR_REG_KEY;
use winreg::enums::*; use winreg::{enums::*, RegKey};
use winreg::RegKey;
let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); let hklm = RegKey::predef(HKEY_LOCAL_MACHINE);
let key = hklm.open_subkey_with_flags(WIN_SERVICE_WORK_DIR_REG_KEY, KEY_READ)?; let key = hklm.open_subkey_with_flags(WIN_SERVICE_WORK_DIR_REG_KEY, KEY_READ)?;
@@ -1215,11 +1214,9 @@ fn parse_cli() -> Cli {
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
fn win_service_main(arg: Vec<std::ffi::OsString>) { fn win_service_main(arg: Vec<std::ffi::OsString>) {
use std::sync::Arc; use std::{sync::Arc, time::Duration};
use std::time::Duration;
use tokio::sync::Notify; use tokio::sync::Notify;
use windows_service::service::*; use windows_service::{service::*, service_control_handler::*};
use windows_service::service_control_handler::*;
_ = win_service_set_work_dir(&arg[0]); _ = win_service_set_work_dir(&arg[0]);
+37 -45
View File
@@ -9,12 +9,6 @@ use anyhow::Context;
use async_trait::async_trait; use async_trait::async_trait;
use tokio::task::JoinSet; use tokio::task::JoinSet;
#[cfg(feature = "faketcp")]
use crate::tunnel::fake_tcp::FakeTcpTunnelListener;
#[cfg(feature = "quic")]
use crate::tunnel::quic::QUICTunnelListener;
#[cfg(feature = "wireguard")]
use crate::tunnel::wireguard::{WgConfig, WgTunnelListener};
use crate::{ use crate::{
common::{ common::{
error::Error, error::Error,
@@ -23,44 +17,42 @@ use crate::{
}, },
peers::peer_manager::PeerManager, peers::peer_manager::PeerManager,
tunnel::{ tunnel::{
ring::RingTunnelListener, tcp::TcpTunnelListener, udp::UdpTunnelListener, Tunnel, self, ring::RingTunnelListener, tcp::TcpTunnelListener, udp::UdpTunnelListener, IpScheme,
TunnelListener, Tunnel, TunnelListener, TunnelScheme,
}, },
utils::BoxExt,
}; };
pub fn get_listener_by_url( pub fn create_listener_by_url(
l: &url::Url, l: &url::Url,
_ctx: ArcGlobalCtx, #[allow(unused_variables)] ctx: ArcGlobalCtx,
) -> Result<Box<dyn TunnelListener>, Error> { ) -> Result<Box<dyn TunnelListener>, Error> {
Ok(match l.scheme() { Ok(match l.try_into()? {
"tcp" => Box::new(TcpTunnelListener::new(l.clone())), TunnelScheme::Ip(scheme) => match scheme {
"udp" => Box::new(UdpTunnelListener::new(l.clone())), IpScheme::Tcp => TcpTunnelListener::new(l.clone()).boxed(),
#[cfg(feature = "wireguard")] IpScheme::Udp => UdpTunnelListener::new(l.clone()).boxed(),
"wg" => { #[cfg(feature = "wireguard")]
let nid = _ctx.get_network_identity(); IpScheme::Wg => {
let wg_config = WgConfig::new_from_network_identity( use crate::tunnel::wireguard::{WgConfig, WgTunnelListener};
&nid.network_name, let nid = ctx.get_network_identity();
&nid.network_secret.unwrap_or_default(), let wg_config = WgConfig::new_from_network_identity(
); &nid.network_name,
Box::new(WgTunnelListener::new(l.clone(), wg_config)) &nid.network_secret.unwrap_or_default(),
} );
#[cfg(feature = "quic")] WgTunnelListener::new(l.clone(), wg_config).boxed()
"quic" => Box::new(QUICTunnelListener::new(l.clone())), }
#[cfg(feature = "websocket")] #[cfg(feature = "quic")]
"ws" | "wss" => { IpScheme::Quic => tunnel::quic::QuicTunnelListener::new(l.clone()).boxed(),
use crate::tunnel::websocket::WSTunnelListener; #[cfg(feature = "websocket")]
Box::new(WSTunnelListener::new(l.clone())) IpScheme::Ws | IpScheme::Wss => {
} tunnel::websocket::WsTunnelListener::new(l.clone()).boxed()
#[cfg(feature = "faketcp")] }
"faketcp" => Box::new(FakeTcpTunnelListener::new(l.clone())), #[cfg(feature = "faketcp")]
IpScheme::FakeTcp => tunnel::fake_tcp::FakeTcpTunnelListener::new(l.clone()).boxed(),
},
#[cfg(unix)] #[cfg(unix)]
"unix" => { TunnelScheme::Unix => tunnel::unix::UnixSocketTunnelListener::new(l.clone()).boxed(),
use crate::tunnel::unix::UnixSocketTunnelListener; _ => return Err(Error::InvalidUrl(l.to_string())),
Box::new(UnixSocketTunnelListener::new(l.clone()))
}
_ => {
return Err(Error::InvalidUrl(l.to_string()));
}
}) })
} }
@@ -133,7 +125,7 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
for l in self.global_ctx.config.get_listener_uris().iter() { for l in self.global_ctx.config.get_listener_uris().iter() {
let l = l.clone(); let l = l.clone();
let Ok(_) = get_listener_by_url(&l, self.global_ctx.clone()) else { let Ok(_) = create_listener_by_url(&l, self.global_ctx.clone()) else {
let msg = format!("failed to get listener by url: {}, maybe not supported", l); let msg = format!("failed to get listener by url: {}, maybe not supported", l);
self.global_ctx self.global_ctx
.issue_event(GlobalCtxEvent::ListenerAddFailed(l.clone(), msg)); .issue_event(GlobalCtxEvent::ListenerAddFailed(l.clone(), msg));
@@ -143,7 +135,7 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
let listener = l.clone(); let listener = l.clone();
self.add_listener( self.add_listener(
move || get_listener_by_url(&listener, ctx.clone()).unwrap(), move || create_listener_by_url(&listener, ctx.clone()).unwrap(),
true, true,
) )
.await?; .await?;
@@ -160,7 +152,7 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
.with_context(|| format!("failed to set ipv6 host for listener: {}", l))?; .with_context(|| format!("failed to set ipv6 host for listener: {}", l))?;
let ctx = self.global_ctx.clone(); let ctx = self.global_ctx.clone();
self.add_listener( self.add_listener(
move || get_listener_by_url(&ipv6_listener, ctx.clone()).unwrap(), move || create_listener_by_url(&ipv6_listener, ctx.clone()).unwrap(),
false, false,
) )
.await?; .await?;
@@ -361,10 +353,6 @@ mod tests {
#[async_trait::async_trait] #[async_trait::async_trait]
impl TunnelListener for MockListener { impl TunnelListener for MockListener {
fn local_url(&self) -> url::Url {
"mock://".parse().unwrap()
}
async fn listen(&mut self) -> Result<(), TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
self.counter.fetch_add(1, Ordering::Relaxed); self.counter.fetch_add(1, Ordering::Relaxed);
Ok(()) Ok(())
@@ -374,6 +362,10 @@ mod tests {
tokio::time::sleep(std::time::Duration::from_secs(1)).await; tokio::time::sleep(std::time::Duration::from_secs(1)).await;
Err(TunnelError::BufferFull) Err(TunnelError::BufferFull)
} }
fn local_url(&self) -> url::Url {
"mock://".parse().unwrap()
}
} }
impl Drop for MockListener { impl Drop for MockListener {
+3 -3
View File
@@ -2003,7 +2003,7 @@ mod tests {
create_connector_by_url, direct::PeerManagerForDirectConnector, create_connector_by_url, direct::PeerManagerForDirectConnector,
udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun, udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun,
}, },
instance::listeners::get_listener_by_url, instance::listeners::create_listener_by_url,
peers::{ peers::{
create_packet_recv_chan, create_packet_recv_chan,
peer_conn::tests::set_secure_mode_cfg, peer_conn::tests::set_secure_mode_cfg,
@@ -2783,7 +2783,7 @@ mod tests {
let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
register_service(&peer_mgr_c.peer_rpc_mgr, "", 0, "hello c"); register_service(&peer_mgr_c.peer_rpc_mgr, "", 0, "hello c");
let mut listener1 = get_listener_by_url( let mut listener1 = create_listener_by_url(
&format!("{}://0.0.0.0:31013", proto1).parse().unwrap(), &format!("{}://0.0.0.0:31013", proto1).parse().unwrap(),
peer_mgr_b.get_global_ctx(), peer_mgr_b.get_global_ctx(),
) )
@@ -2802,7 +2802,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let mut listener2 = get_listener_by_url( let mut listener2 = create_listener_by_url(
&format!("{}://0.0.0.0:31014", proto2).parse().unwrap(), &format!("{}://0.0.0.0:31014", proto2).parse().unwrap(),
peer_mgr_c.get_global_ctx(), peer_mgr_c.get_global_ctx(),
) )
+2 -2
View File
@@ -156,14 +156,14 @@ async fn init_three_node_ex_with_inst3<F: Fn(TomlConfigLoader) -> TomlConfigLoad
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
inst1 inst1
.get_conn_manager() .get_conn_manager()
.add_connector(crate::tunnel::websocket::WSTunnelConnector::new( .add_connector(crate::tunnel::websocket::WsTunnelConnector::new(
"ws://10.1.1.2:11011".parse().unwrap(), "ws://10.1.1.2:11011".parse().unwrap(),
)); ));
} else if proto == "wss" { } else if proto == "wss" {
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
inst1 inst1
.get_conn_manager() .get_conn_manager()
.add_connector(crate::tunnel::websocket::WSTunnelConnector::new( .add_connector(crate::tunnel::websocket::WsTunnelConnector::new(
"wss://10.1.1.2:11012".parse().unwrap(), "wss://10.1.1.2:11012".parse().unwrap(),
)); ));
} }
+21 -30
View File
@@ -2,20 +2,28 @@ mod netfilter;
mod packet; mod packet;
mod stack; mod stack;
use std::net::{IpAddr, Ipv4Addr, UdpSocket};
use std::sync::Arc;
use std::{net::SocketAddr, pin::Pin};
use bytes::BytesMut; use bytes::BytesMut;
use futures::{Sink, Stream};
use network_interface::NetworkInterfaceConfig; use network_interface::NetworkInterfaceConfig;
use pnet::util::MacAddr; use pnet::util::MacAddr;
use tokio::io::AsyncReadExt; use std::{
use tokio::net::TcpStream; net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
use tokio::sync::Mutex; pin::Pin,
sync::Arc,
task::{Context as TaskContext, Poll},
};
use tokio::{io::AsyncReadExt, net::TcpStream, sync::Mutex};
use crate::common::scoped_task::ScopedTask; use crate::{
use crate::tunnel::fake_tcp::netfilter::create_tun; common::scoped_task::ScopedTask,
use crate::tunnel::{common::TunnelWrapper, Tunnel, TunnelError, TunnelInfo, TunnelListener}; tunnel::{
common::TunnelWrapper,
fake_tcp::netfilter::create_tun,
packet_def::{ZCPacket, ZCPacketType, PEER_MANAGER_HEADER_SIZE, TCP_TUNNEL_HEADER_SIZE},
FromUrl, IpVersion, SinkError, SinkItem, StreamItem, Tunnel, TunnelConnector, TunnelError,
TunnelInfo, TunnelListener,
},
};
use futures::Future; use futures::Future;
@@ -207,12 +215,7 @@ struct AcceptResult {
impl TunnelListener for FakeTcpTunnelListener { impl TunnelListener for FakeTcpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
let port = self.addr.port().unwrap_or(0); let port = self.addr.port().unwrap_or(0);
let bind_addr = crate::tunnel::check_scheme_and_get_socket_addr::<SocketAddr>( let bind_addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
&self.addr,
"faketcp",
crate::tunnel::IpVersion::Both,
)
.await?;
let os_listener = tokio::net::TcpListener::bind(bind_addr).await?; let os_listener = tokio::net::TcpListener::bind(bind_addr).await?;
tracing::info!(port, "FakeTcpTunnelListener listening"); tracing::info!(port, "FakeTcpTunnelListener listening");
self.os_listener = Some(os_listener); self.os_listener = Some(os_listener);
@@ -306,14 +309,9 @@ fn get_local_ip_for_destination(destination: IpAddr) -> Option<IpAddr> {
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl crate::tunnel::TunnelConnector for FakeTcpTunnelConnector { impl TunnelConnector for FakeTcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> { async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let remote_addr = crate::tunnel::check_scheme_and_get_socket_addr::<SocketAddr>( let remote_addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
&self.addr,
"faketcp",
crate::tunnel::IpVersion::Both,
)
.await?;
let local_ip = get_local_ip_for_destination(remote_addr.ip()) let local_ip = get_local_ip_for_destination(remote_addr.ip())
.ok_or(TunnelError::InternalError("Failed to get local ip".into()))?; .ok_or(TunnelError::InternalError("Failed to get local ip".into()))?;
@@ -387,13 +385,6 @@ impl crate::tunnel::TunnelConnector for FakeTcpTunnelConnector {
} }
} }
use crate::tunnel::packet_def::{
ZCPacket, ZCPacketType, PEER_MANAGER_HEADER_SIZE, TCP_TUNNEL_HEADER_SIZE,
};
use crate::tunnel::{SinkError, SinkItem, StreamItem};
use futures::{Sink, Stream};
use std::task::{Context as TaskContext, Poll};
type RecvFut = Pin<Box<dyn Future<Output = Option<(BytesMut, usize)>> + Send + Sync>>; type RecvFut = Pin<Box<dyn Future<Output = Option<(BytesMut, usize)>> + Send + Sync>>;
enum FakeTcpStreamState { enum FakeTcpStreamState {
+140 -52
View File
@@ -1,16 +1,19 @@
use std::collections::hash_map::DefaultHasher; use std::{
use std::hash::Hasher; collections::hash_map::DefaultHasher, hash::Hasher, net::SocketAddr, pin::Pin, sync::Arc,
use std::{net::SocketAddr, pin::Pin, sync::Arc}; };
use crate::{
common::{dns::socket_addrs, error::Error},
proto::common::TunnelInfo,
};
use async_trait::async_trait; use async_trait::async_trait;
use derive_more::{From, TryInto};
use futures::{Sink, Stream}; use futures::{Sink, Stream};
use socket2::Protocol;
use std::fmt::Debug; use std::fmt::Debug;
use strum::{Display, EnumString, VariantArray};
use tokio::time::error::Elapsed; use tokio::time::error::Elapsed;
use crate::common::dns::socket_addrs;
use crate::proto::common::TunnelInfo;
use self::packet_def::ZCPacket; use self::packet_def::ZCPacket;
pub mod buf; pub mod buf;
@@ -23,15 +26,6 @@ pub mod stats;
pub mod tcp; pub mod tcp;
pub mod udp; pub mod udp;
pub const PROTO_PORT_OFFSET: &[(&str, u16)] = &[
("tcp", 0),
("udp", 0),
("wg", 1),
("ws", 1),
("wss", 2),
("faketcp", 3),
];
#[cfg(feature = "faketcp")] #[cfg(feature = "faketcp")]
pub mod fake_tcp; pub mod fake_tcp;
@@ -193,45 +187,23 @@ pub(crate) trait FromUrl {
Self: Sized; Self: Sized;
} }
pub(crate) async fn check_scheme_and_get_socket_addr<T>(
url: &url::Url,
scheme: &str,
ip_version: IpVersion,
) -> Result<T, TunnelError>
where
T: FromUrl,
{
if url.scheme() != scheme {
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
}
T::from_url(url.clone(), ip_version).await
}
fn default_port(scheme: &str) -> Option<u16> {
match scheme {
"tcp" => Some(11010),
"udp" => Some(11010),
"ws" => Some(80),
"wss" => Some(443),
"faketcp" => Some(11013),
"quic" => Some(11012),
"wg" => Some(11011),
_ => None,
}
}
#[async_trait::async_trait] #[async_trait::async_trait]
impl FromUrl for SocketAddr { impl FromUrl for SocketAddr {
async fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError> { async fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError> {
let addrs = socket_addrs(&url, || default_port(url.scheme())) let addrs = socket_addrs(&url, || {
.await (&url)
.map_err(|e| { .try_into()
TunnelError::InvalidAddr(format!( .ok()
"failed to resolve socket addr, url: {}, error: {}", .and_then(|s: TunnelScheme| s.try_into().ok())
url, e .map(IpScheme::default_port)
)) })
})?; .await
.map_err(|e| {
TunnelError::InvalidAddr(format!(
"failed to resolve socket addr, url: {}, error: {}",
url, e
))
})?;
tracing::debug!(?addrs, ?ip_version, ?url, "convert url to socket addrs"); tracing::debug!(?addrs, ?ip_version, ?url, "convert url to socket addrs");
let addrs = addrs let addrs = addrs
.into_iter() .into_iter()
@@ -305,3 +277,119 @@ pub fn generate_digest_from_str(str1: &str, str2: &str, digest: &mut [u8]) {
hasher.write(&digest[..(i + 1) * 8]); hasher.write(&digest[..(i + 1) * 8]);
} }
} }
#[derive(Debug, Clone, Copy)]
struct IpSchemeAttributes {
protocol: Protocol,
port_offset: u16,
}
#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, VariantArray)]
#[strum(serialize_all = "lowercase")]
pub enum IpScheme {
Tcp,
Udp,
#[cfg(feature = "wireguard")]
Wg,
#[cfg(feature = "quic")]
Quic,
#[cfg(feature = "websocket")]
Ws,
#[cfg(feature = "websocket")]
Wss,
#[cfg(feature = "faketcp")]
FakeTcp,
}
impl IpScheme {
const fn attributes(self) -> IpSchemeAttributes {
let (protocol, port_offset) = match self {
Self::Tcp => (Protocol::TCP, 0),
Self::Udp => (Protocol::UDP, 0),
#[cfg(feature = "wireguard")]
Self::Wg => (Protocol::UDP, 1),
#[cfg(feature = "quic")]
Self::Quic => (Protocol::UDP, 2),
#[cfg(feature = "websocket")]
Self::Ws => (Protocol::TCP, 1),
#[cfg(feature = "websocket")]
Self::Wss => (Protocol::TCP, 2),
#[cfg(feature = "faketcp")]
Self::FakeTcp => (Protocol::TCP, 3),
};
IpSchemeAttributes {
protocol,
port_offset,
}
}
pub const fn protocol(self) -> Protocol {
self.attributes().protocol
}
pub const fn port_offset(self) -> u16 {
self.attributes().port_offset
}
pub const fn default_port(self) -> u16 {
match self {
#[cfg(feature = "websocket")]
Self::Ws => 80,
#[cfg(feature = "websocket")]
Self::Wss => 443,
_ => 11010 + self.port_offset(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, EnumString, From, TryInto)]
#[strum(serialize_all = "lowercase")]
pub enum TunnelScheme {
#[strum(disabled)]
Ip(IpScheme),
#[cfg(unix)]
Unix,
// Only for connector
Http,
Https,
Ring,
Txt,
Srv,
}
impl TryFrom<&url::Url> for TunnelScheme {
type Error = Error;
fn try_from(value: &url::Url) -> Result<Self, Self::Error> {
let scheme = value.scheme();
scheme.parse().or_else(|_| {
Ok(TunnelScheme::Ip(
scheme
.parse()
.map_err(|_| Error::InvalidUrl(value.to_string()))?,
))
})
}
}
macro_rules! __matches_scheme__ {
($url:expr, $( $pattern:pat_param )|+ ) => {
matches!($crate::tunnel::TunnelScheme::try_from(($url).as_ref()), Ok($( $pattern )|+))
};
}
pub(crate) use __matches_scheme__ as matches_scheme;
pub fn get_protocol_by_url(l: &url::Url) -> Result<Protocol, Error> {
let TunnelScheme::Ip(scheme) = l.try_into()? else {
return Err(Error::InvalidUrl(l.to_string()));
};
Ok(scheme.protocol())
}
macro_rules! __matches_protocol__ {
($url:expr, $( $pattern:pat_param )|+ ) => {
matches!($crate::tunnel::get_protocol_by_url($url), Ok($( $pattern )|+))
};
}
pub(crate) use __matches_protocol__ as matches_protocol;
+26 -34
View File
@@ -8,20 +8,16 @@ use std::{
use crate::tunnel::{ use crate::tunnel::{
common::{setup_sokcet2, FramedReader, FramedWriter, TunnelWrapper}, common::{setup_sokcet2, FramedReader, FramedWriter, TunnelWrapper},
TunnelInfo, FromUrl, TunnelInfo,
}; };
use anyhow::Context; use anyhow::Context;
use super::{IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
use quinn::{ use quinn::{
congestion::BbrConfig, udp::RecvMeta, AsyncUdpSocket, ClientConfig, Connection, Endpoint, congestion::BbrConfig, udp::RecvMeta, AsyncUdpSocket, ClientConfig, Connection, Endpoint,
EndpointConfig, ServerConfig, TransportConfig, UdpPoller, EndpointConfig, ServerConfig, TransportConfig, UdpPoller,
}; };
use super::{
check_scheme_and_get_socket_addr, IpVersion, Tunnel, TunnelConnector, TunnelError,
TunnelListener,
};
pub fn transport_config() -> Arc<TransportConfig> { pub fn transport_config() -> Arc<TransportConfig> {
let mut config = TransportConfig::default(); let mut config = TransportConfig::default();
@@ -145,14 +141,14 @@ impl Drop for ConnWrapper {
} }
} }
pub struct QUICTunnelListener { pub struct QuicTunnelListener {
addr: url::Url, addr: url::Url,
endpoint: Option<Endpoint>, endpoint: Option<Endpoint>,
} }
impl QUICTunnelListener { impl QuicTunnelListener {
pub fn new(addr: url::Url) -> Self { pub fn new(addr: url::Url) -> Self {
QUICTunnelListener { QuicTunnelListener {
addr, addr,
endpoint: None, endpoint: None,
} }
@@ -190,11 +186,9 @@ impl QUICTunnelListener {
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl TunnelListener for QUICTunnelListener { impl TunnelListener for QuicTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic", IpVersion::Both)
.await?;
let endpoint = make_server_endpoint(addr) let endpoint = make_server_endpoint(addr)
.map_err(|e| anyhow::anyhow!("make server endpoint error: {:?}", e))?; .map_err(|e| anyhow::anyhow!("make server endpoint error: {:?}", e))?;
self.endpoint = Some(endpoint); self.endpoint = Some(endpoint);
@@ -223,15 +217,15 @@ impl TunnelListener for QUICTunnelListener {
} }
} }
pub struct QUICTunnelConnector { pub struct QuicTunnelConnector {
addr: url::Url, addr: url::Url,
endpoint: Option<Endpoint>, endpoint: Option<Endpoint>,
ip_version: IpVersion, ip_version: IpVersion,
} }
impl QUICTunnelConnector { impl QuicTunnelConnector {
pub fn new(addr: url::Url) -> Self { pub fn new(addr: url::Url) -> Self {
QUICTunnelConnector { QuicTunnelConnector {
addr, addr,
endpoint: None, endpoint: None,
ip_version: IpVersion::Both, ip_version: IpVersion::Both,
@@ -240,11 +234,9 @@ impl QUICTunnelConnector {
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl TunnelConnector for QUICTunnelConnector { impl TunnelConnector for QuicTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> { async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic", self.ip_version)
.await?;
if addr.port() == 0 { if addr.port() == 0 {
return Err(TunnelError::InvalidAddr(format!( return Err(TunnelError::InvalidAddr(format!(
"invalid remote QUIC port 0 in url: {} (port 0 is not a valid QUIC port)", "invalid remote QUIC port 0 in url: {} (port 0 is not a valid QUIC port)",
@@ -318,36 +310,36 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn quic_pingpong() { async fn quic_pingpong() {
let listener = QUICTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap()); let listener = QuicTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap());
let connector = QUICTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap()); let connector = QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap());
_tunnel_pingpong(listener, connector).await _tunnel_pingpong(listener, connector).await
} }
#[tokio::test] #[tokio::test]
async fn quic_bench() { async fn quic_bench() {
let listener = QUICTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap()); let listener = QuicTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap());
let connector = QUICTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap()); let connector = QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap());
_tunnel_bench(listener, connector).await _tunnel_bench(listener, connector).await
} }
#[tokio::test] #[tokio::test]
async fn ipv6_pingpong() { async fn ipv6_pingpong() {
let listener = QUICTunnelListener::new("quic://[::1]:31015".parse().unwrap()); let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap());
let connector = QUICTunnelConnector::new("quic://[::1]:31015".parse().unwrap()); let connector = QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap());
_tunnel_pingpong(listener, connector).await _tunnel_pingpong(listener, connector).await
} }
#[tokio::test] #[tokio::test]
async fn ipv6_domain_pingpong() { async fn ipv6_domain_pingpong() {
let listener = QUICTunnelListener::new("quic://[::1]:31016".parse().unwrap()); let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap());
let mut connector = let mut connector =
QUICTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap()); QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
connector.set_ip_version(IpVersion::V6); connector.set_ip_version(IpVersion::V6);
_tunnel_pingpong(listener, connector).await; _tunnel_pingpong(listener, connector).await;
let listener = QUICTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap()); let listener = QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap());
let mut connector = let mut connector =
QUICTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap()); QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
connector.set_ip_version(IpVersion::V4); connector.set_ip_version(IpVersion::V4);
_tunnel_pingpong(listener, connector).await; _tunnel_pingpong(listener, connector).await;
} }
@@ -355,13 +347,13 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_alloc_port() { async fn test_alloc_port() {
// v4 // v4
let mut listener = QUICTunnelListener::new("quic://0.0.0.0:0".parse().unwrap()); let mut listener = QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap());
listener.listen().await.unwrap(); listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap(); let port = listener.local_url().port().unwrap();
assert!(port > 0); assert!(port > 0);
// v6 // v6
let mut listener = QUICTunnelListener::new("quic://[::]:0".parse().unwrap()); let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap());
listener.listen().await.unwrap(); listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap(); let port = listener.local_url().port().unwrap();
assert!(port > 0); assert!(port > 0);
@@ -369,7 +361,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn quic_connector_reject_port_zero() { async fn quic_connector_reject_port_zero() {
let mut connector = QUICTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap()); let mut connector = QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap());
let err = connector.connect().await.unwrap_err().to_string(); let err = connector.connect().await.unwrap_err().to_string();
assert!(err.contains("port 0"), "unexpected error: {}", err); assert!(err.contains("port 0"), "unexpected error: {}", err);
} }
+17 -28
View File
@@ -1,3 +1,5 @@
use async_ringbuf::{traits::*, AsyncHeapCons, AsyncHeapProd, AsyncHeapRb};
use crossbeam::atomic::AtomicCell;
use std::{ use std::{
collections::HashMap, collections::HashMap,
fmt::Debug, fmt::Debug,
@@ -5,9 +7,6 @@ use std::{
task::{ready, Poll}, task::{ready, Poll},
}; };
use async_ringbuf::{traits::*, AsyncHeapCons, AsyncHeapProd, AsyncHeapRb};
use crossbeam::atomic::AtomicCell;
use async_trait::async_trait; use async_trait::async_trait;
use futures::{Sink, SinkExt, Stream, StreamExt}; use futures::{Sink, SinkExt, Stream, StreamExt};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
@@ -16,15 +15,15 @@ use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use uuid::Uuid; use uuid::Uuid;
use crate::tunnel::{SinkError, SinkItem}; use crate::tunnel::{FromUrl, IpVersion, SinkError, SinkItem};
use super::{ use super::{
build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::TunnelWrapper, build_url_from_socket_addr, common::TunnelWrapper, StreamItem, Tunnel, TunnelConnector,
StreamItem, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener, TunnelError, TunnelInfo, TunnelListener,
}; };
pub static RING_TUNNEL_CAP: usize = 128; pub static RING_TUNNEL_CAP: usize = 128;
static RING_TUNNEL_RESERVERD_CAP: usize = 4; static RING_TUNNEL_RESERVED_CAP: usize = 4;
type RingLock = parking_lot::Mutex<()>; type RingLock = parking_lot::Mutex<()>;
@@ -44,7 +43,7 @@ impl RingTunnel {
pub fn new(cap: usize) -> Self { pub fn new(cap: usize) -> Self {
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let ring_impl = AsyncHeapRb::new(std::cmp::max(RING_TUNNEL_RESERVERD_CAP * 2, cap)); let ring_impl = AsyncHeapRb::new(std::cmp::max(RING_TUNNEL_RESERVED_CAP * 2, cap));
let (ring_prod_impl, ring_cons_impl) = ring_impl.split(); let (ring_prod_impl, ring_cons_impl) = ring_impl.split();
Self { Self {
id, id,
@@ -120,7 +119,7 @@ impl RingSink {
pub fn try_send(&mut self, item: RingItem) -> Result<(), RingItem> { pub fn try_send(&mut self, item: RingItem) -> Result<(), RingItem> {
let base = self.ring_prod_impl.base(); let base = self.ring_prod_impl.base();
if base.occupied_len() >= base.capacity().get() - RING_TUNNEL_RESERVERD_CAP { if base.occupied_len() >= base.capacity().get() - RING_TUNNEL_RESERVED_CAP {
return Err(item); return Err(item);
} }
self.ring_prod_impl.try_push(item) self.ring_prod_impl.try_push(item)
@@ -188,7 +187,7 @@ static CONNECTION_MAP: Lazy<Arc<std::sync::Mutex<ConnectionMap>>> =
#[derive(Debug)] #[derive(Debug)]
pub struct RingTunnelListener { pub struct RingTunnelListener {
listerner_addr: url::Url, listener_addr: url::Url,
conn_sender: UnboundedSender<Arc<Connection>>, conn_sender: UnboundedSender<Arc<Connection>>,
conn_receiver: UnboundedReceiver<Arc<Connection>>, conn_receiver: UnboundedReceiver<Arc<Connection>>,
@@ -199,7 +198,7 @@ impl RingTunnelListener {
pub fn new(key: url::Url) -> Self { pub fn new(key: url::Url) -> Self {
let (conn_sender, conn_receiver) = tokio::sync::mpsc::unbounded_channel(); let (conn_sender, conn_receiver) = tokio::sync::mpsc::unbounded_channel();
RingTunnelListener { RingTunnelListener {
listerner_addr: key, listener_addr: key,
conn_sender, conn_sender,
conn_receiver, conn_receiver,
key_in_conn_map: None, key_in_conn_map: None,
@@ -232,20 +231,15 @@ fn get_tunnel_for_server(conn: Arc<Connection>) -> impl Tunnel {
} }
impl RingTunnelListener { impl RingTunnelListener {
async fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> { async fn get_addr(&self) -> Result<Uuid, TunnelError> {
check_scheme_and_get_socket_addr::<Uuid>( Uuid::from_url(self.listener_addr.clone(), IpVersion::Both).await
&self.listerner_addr,
"ring",
super::IpVersion::Both,
)
.await
} }
} }
#[async_trait] #[async_trait]
impl TunnelListener for RingTunnelListener { impl TunnelListener for RingTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
tracing::info!("listen new conn of key: {}", self.listerner_addr); tracing::info!("listen new conn of key: {}", self.listener_addr);
let addr = self.get_addr().await?; let addr = self.get_addr().await?;
CONNECTION_MAP CONNECTION_MAP
.lock() .lock()
@@ -256,11 +250,11 @@ impl TunnelListener for RingTunnelListener {
} }
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> { async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
tracing::info!("waiting accept new conn of key: {}", self.listerner_addr); tracing::info!("waiting accept new conn of key: {}", self.listener_addr);
let my_addr = self.get_addr().await?; let my_addr = self.get_addr().await?;
if let Some(conn) = self.conn_receiver.recv().await { if let Some(conn) = self.conn_receiver.recv().await {
if conn.server.id == my_addr { if conn.server.id == my_addr {
tracing::info!("accept new conn of key: {}", self.listerner_addr); tracing::info!("accept new conn of key: {}", self.listener_addr);
return Ok(Box::new(get_tunnel_for_server(conn))); return Ok(Box::new(get_tunnel_for_server(conn)));
} else { } else {
tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id"); tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
@@ -276,7 +270,7 @@ impl TunnelListener for RingTunnelListener {
} }
fn local_url(&self) -> url::Url { fn local_url(&self) -> url::Url {
self.listerner_addr.clone() self.listener_addr.clone()
} }
} }
@@ -301,12 +295,7 @@ impl RingTunnelConnector {
#[async_trait] #[async_trait]
impl TunnelConnector for RingTunnelConnector { impl TunnelConnector for RingTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> { async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let remote_addr = check_scheme_and_get_socket_addr::<Uuid>( let remote_addr = Uuid::from_url(self.remote_addr.clone(), IpVersion::Both).await?;
&self.remote_addr,
"ring",
super::IpVersion::Both,
)
.await?;
let entry = CONNECTION_MAP let entry = CONNECTION_MAP
.lock() .lock()
.unwrap() .unwrap()
+5 -11
View File
@@ -1,14 +1,12 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use super::{FromUrl, TunnelInfo};
use crate::tunnel::common::setup_sokcet2;
use async_trait::async_trait; use async_trait::async_trait;
use futures::stream::FuturesUnordered; use futures::stream::FuturesUnordered;
use tokio::net::{TcpListener, TcpSocket, TcpStream}; use tokio::net::{TcpListener, TcpSocket, TcpStream};
use super::TunnelInfo;
use crate::tunnel::common::setup_sokcet2;
use super::{ use super::{
check_scheme_and_get_socket_addr,
common::{wait_for_connect_futures, FramedReader, FramedWriter, TunnelWrapper}, common::{wait_for_connect_futures, FramedReader, FramedWriter, TunnelWrapper},
IpVersion, Tunnel, TunnelError, TunnelListener, IpVersion, Tunnel, TunnelError, TunnelListener,
}; };
@@ -58,9 +56,7 @@ impl TcpTunnelListener {
impl TunnelListener for TcpTunnelListener { impl TunnelListener for TcpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
self.listener = None; self.listener = None;
let addr = let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp", IpVersion::Both)
.await?;
let socket2_socket = socket2::Socket::new( let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr), socket2::Domain::for_address(addr),
@@ -189,10 +185,8 @@ impl TcpTunnelConnector {
#[async_trait] #[async_trait]
impl super::TunnelConnector for TcpTunnelConnector { impl super::TunnelConnector for TcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> { async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp", self.ip_version)
.await?;
if self.bind_addrs.is_empty() { if self.bind_addrs.is_empty() {
self.connect_with_default_bind(addr).await self.connect_with_default_bind(addr).await
} else { } else {
+18 -31
View File
@@ -21,7 +21,13 @@ use tokio::{
use tracing::{instrument, Instrument}; use tracing::{instrument, Instrument};
use super::{packet_def::V6HolePunchPacket, TunnelInfo}; use super::{
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
packet_def::{UDPTunnelHeader, V6HolePunchPacket, UDP_TUNNEL_HEADER_SIZE},
ring::{RingSink, RingStream},
FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener,
TunnelUrl,
};
use crate::{ use crate::{
common::{join_joinset_background, scoped_task::ScopedTask, shrink_dashmap}, common::{join_joinset_background, scoped_task::ScopedTask, shrink_dashmap},
tunnel::{ tunnel::{
@@ -32,13 +38,6 @@ use crate::{
}, },
}; };
use super::{
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
packet_def::{UDPTunnelHeader, UDP_TUNNEL_HEADER_SIZE},
ring::{RingSink, RingStream},
IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelListener, TunnelUrl,
};
pub const UDP_DATA_MTU: usize = 2000; pub const UDP_DATA_MTU: usize = 2000;
type UdpCloseEventSender = UnboundedSender<(SocketAddr, Option<TunnelError>)>; type UdpCloseEventSender = UnboundedSender<(SocketAddr, Option<TunnelError>)>;
@@ -149,11 +148,11 @@ async fn respond_stun_packet(
req_buf: Vec<u8>, req_buf: Vec<u8>,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
use crate::common::stun_codec_ext::*; use crate::common::stun_codec_ext::*;
use bytecodec::DecodeExt as _; use bytecodec::{DecodeExt as _, EncodeExt as _};
use bytecodec::EncodeExt as _; use stun_codec::{
use stun_codec::rfc5389::attributes::XorMappedAddress; rfc5389::{attributes::XorMappedAddress, methods::BINDING},
use stun_codec::rfc5389::methods::BINDING; Message, MessageClass, MessageDecoder, MessageEncoder,
use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder}; };
let mut decoder = MessageDecoder::<Attribute>::new(); let mut decoder = MessageDecoder::<Attribute>::new();
let req_msg = decoder let req_msg = decoder
@@ -532,13 +531,8 @@ impl UdpTunnelListener {
#[async_trait] #[async_trait]
impl TunnelListener for UdpTunnelListener { impl TunnelListener for UdpTunnelListener {
async fn listen(&mut self) -> Result<(), super::TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>( let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
&self.addr,
"udp",
IpVersion::Both,
)
.await?;
let socket2_socket = socket2::Socket::new( let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr), socket2::Domain::for_address(addr),
@@ -851,13 +845,8 @@ impl UdpTunnelConnector {
#[async_trait] #[async_trait]
impl super::TunnelConnector for UdpTunnelConnector { impl super::TunnelConnector for UdpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> { async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>( let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
&self.addr,
"udp",
self.ip_version,
)
.await?;
if self.bind_addrs.is_empty() || addr.is_ipv6() { if self.bind_addrs.is_empty() || addr.is_ipv6() {
self.connect_with_default_bind(addr).await self.connect_with_default_bind(addr).await
} else { } else {
@@ -889,7 +878,6 @@ mod tests {
use crate::{ use crate::{
common::global_ctx::tests::get_mock_global_ctx, common::global_ctx::tests::get_mock_global_ctx,
tunnel::{ tunnel::{
check_scheme_and_get_socket_addr,
common::{ common::{
get_interface_name_by_ip, get_interface_name_by_ip,
tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong, wait_for_condition}, tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong, wait_for_condition},
@@ -1034,9 +1022,8 @@ mod tests {
for ip in ips { for ip in ips {
println!("bind to ip: {}, {:?}", ip, bind_dev); println!("bind to ip: {}, {:?}", ip, bind_dev);
let addr = check_scheme_and_get_socket_addr::<SocketAddr>( let addr = SocketAddr::from_url(
&format!("udp://{}:11111", ip).parse().unwrap(), format!("udp://{}:11111", ip).parse().unwrap(),
"udp",
IpVersion::Both, IpVersion::Both,
) )
.await .await
+29 -29
View File
@@ -1,10 +1,20 @@
use super::{
common::{setup_sokcet2, wait_for_connect_futures, TunnelWrapper},
insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
packet_def::{ZCPacket, ZCPacketType},
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
};
use crate::{proto::common::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config};
use anyhow::Context; use anyhow::Context;
use bytes::BytesMut; use bytes::BytesMut;
use forwarded_header_value::ForwardedHeaderValue; use forwarded_header_value::ForwardedHeaderValue;
use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use pnet::ipnetwork::IpNetwork; use pnet::ipnetwork::IpNetwork;
use std::sync::LazyLock; use std::{
use std::{net::SocketAddr, sync::Arc, time::Duration}; net::SocketAddr,
sync::{Arc, LazyLock},
time::Duration,
};
use tokio::{ use tokio::{
net::{TcpListener, TcpSocket, TcpStream}, net::{TcpListener, TcpSocket, TcpStream},
time::timeout, time::timeout,
@@ -14,16 +24,6 @@ use tokio_util::either::Either;
use tokio_websockets::{ClientBuilder, Limits, MaybeTlsStream, Message, ServerBuilder}; use tokio_websockets::{ClientBuilder, Limits, MaybeTlsStream, Message, ServerBuilder};
use zerocopy::AsBytes; use zerocopy::AsBytes;
use super::TunnelInfo;
use crate::tunnel::insecure_tls::get_insecure_tls_client_config;
use super::{
common::{setup_sokcet2, wait_for_connect_futures, TunnelWrapper},
insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
packet_def::{ZCPacket, ZCPacketType},
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
};
fn is_wss(addr: &url::Url) -> Result<bool, TunnelError> { fn is_wss(addr: &url::Url) -> Result<bool, TunnelError> {
match addr.scheme() { match addr.scheme() {
"ws" => Ok(false), "ws" => Ok(false),
@@ -77,14 +77,14 @@ static TRUSTED_PROXIES: LazyLock<Vec<IpNetwork>> = LazyLock::new(|| {
}); });
#[derive(Debug)] #[derive(Debug)]
pub struct WSTunnelListener { pub struct WsTunnelListener {
addr: url::Url, addr: url::Url,
listener: Option<TcpListener>, listener: Option<TcpListener>,
} }
impl WSTunnelListener { impl WsTunnelListener {
pub fn new(addr: url::Url) -> Self { pub fn new(addr: url::Url) -> Self {
WSTunnelListener { WsTunnelListener {
addr, addr,
listener: None, listener: None,
} }
@@ -159,7 +159,7 @@ impl WSTunnelListener {
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl TunnelListener for WSTunnelListener { impl TunnelListener for WsTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?; let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new( let socket2_socket = socket2::Socket::new(
@@ -199,16 +199,16 @@ impl TunnelListener for WSTunnelListener {
} }
} }
pub struct WSTunnelConnector { pub struct WsTunnelConnector {
addr: url::Url, addr: url::Url,
ip_version: IpVersion, ip_version: IpVersion,
bind_addrs: Vec<SocketAddr>, bind_addrs: Vec<SocketAddr>,
} }
impl WSTunnelConnector { impl WsTunnelConnector {
pub fn new(addr: url::Url) -> Self { pub fn new(addr: url::Url) -> Self {
WSTunnelConnector { WsTunnelConnector {
addr, addr,
ip_version: IpVersion::Both, ip_version: IpVersion::Both,
@@ -307,8 +307,8 @@ impl WSTunnelConnector {
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl TunnelConnector for WSTunnelConnector { impl TunnelConnector for WsTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> { async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?; let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if self.bind_addrs.is_empty() || addr.is_ipv6() { if self.bind_addrs.is_empty() || addr.is_ipv6() {
self.connect_with_default_bind(addr).await self.connect_with_default_bind(addr).await
@@ -340,9 +340,9 @@ pub mod tests {
#[tokio::test] #[tokio::test]
#[serial_test::serial] #[serial_test::serial]
async fn ws_pingpong(#[values("ws", "wss")] proto: &str) { async fn ws_pingpong(#[values("ws", "wss")] proto: &str) {
let listener = WSTunnelListener::new(format!("{}://0.0.0.0:25556", proto).parse().unwrap()); let listener = WsTunnelListener::new(format!("{}://0.0.0.0:25556", proto).parse().unwrap());
let connector = let connector =
WSTunnelConnector::new(format!("{}://127.0.0.1:25556", proto).parse().unwrap()); WsTunnelConnector::new(format!("{}://127.0.0.1:25556", proto).parse().unwrap());
_tunnel_pingpong(listener, connector).await _tunnel_pingpong(listener, connector).await
} }
@@ -350,9 +350,9 @@ pub mod tests {
#[tokio::test] #[tokio::test]
#[serial_test::serial] #[serial_test::serial]
async fn ws_pingpong_bind(#[values("ws", "wss")] proto: &str) { async fn ws_pingpong_bind(#[values("ws", "wss")] proto: &str) {
let listener = WSTunnelListener::new(format!("{}://0.0.0.0:25557", proto).parse().unwrap()); let listener = WsTunnelListener::new(format!("{}://0.0.0.0:25557", proto).parse().unwrap());
let mut connector = let mut connector =
WSTunnelConnector::new(format!("{}://127.0.0.1:25557", proto).parse().unwrap()); WsTunnelConnector::new(format!("{}://127.0.0.1:25557", proto).parse().unwrap());
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]); connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await _tunnel_pingpong(listener, connector).await
} }
@@ -371,16 +371,16 @@ pub mod tests {
#[tokio::test] #[tokio::test]
async fn ws_accept_wss() { async fn ws_accept_wss() {
let mut listener = WSTunnelListener::new("wss://0.0.0.0:25558".parse().unwrap()); let mut listener = WsTunnelListener::new("wss://0.0.0.0:25558".parse().unwrap());
listener.listen().await.unwrap(); listener.listen().await.unwrap();
let j = tokio::spawn(async move { let j = tokio::spawn(async move {
let _ = listener.accept().await; let _ = listener.accept().await;
}); });
let mut connector = WSTunnelConnector::new("ws://127.0.0.1:25558".parse().unwrap()); let mut connector = WsTunnelConnector::new("ws://127.0.0.1:25558".parse().unwrap());
connector.connect().await.unwrap_err(); connector.connect().await.unwrap_err();
let mut connector = WSTunnelConnector::new("wss://127.0.0.1:25558".parse().unwrap()); let mut connector = WsTunnelConnector::new("wss://127.0.0.1:25558".parse().unwrap());
connector.connect().await.unwrap(); connector.connect().await.unwrap();
j.abort(); j.abort();
@@ -388,7 +388,7 @@ pub mod tests {
#[tokio::test] #[tokio::test]
async fn ws_forwarded() { async fn ws_forwarded() {
let mut listener = WSTunnelListener::new("ws://127.0.0.1:25559".parse().unwrap()); let mut listener = WsTunnelListener::new("ws://127.0.0.1:25559".parse().unwrap());
listener.listen().await.unwrap(); listener.listen().await.unwrap();
let server_task = tokio::spawn(async move { let server_task = tokio::spawn(async move {
+20 -23
View File
@@ -20,7 +20,14 @@ use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use rand::RngCore; use rand::RngCore;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use super::TunnelInfo; use super::{
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
generate_digest_from_str,
packet_def::{ZCPacketType, PEER_MANAGER_HEADER_SIZE},
ring::create_ring_tunnel_pair,
FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink,
ZCPacketStream,
};
use crate::{ use crate::{
common::shrink_dashmap, common::shrink_dashmap,
tunnel::{ tunnel::{
@@ -30,15 +37,6 @@ use crate::{
}, },
}; };
use super::{
check_scheme_and_get_socket_addr,
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
generate_digest_from_str,
packet_def::{ZCPacketType, PEER_MANAGER_HEADER_SIZE},
ring::create_ring_tunnel_pair,
IpVersion, Tunnel, TunnelError, TunnelListener, TunnelUrl, ZCPacketSink, ZCPacketStream,
};
const MAX_PACKET: usize = 2048; const MAX_PACKET: usize = 2048;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -202,7 +200,10 @@ impl WgPeerData {
match self.udp.send_to(packet, self.endpoint).await { match self.udp.send_to(packet, self.endpoint).await {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); tracing::error!(
"Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}",
e
);
return; return;
} }
}; };
@@ -214,7 +215,10 @@ impl WgPeerData {
match self.udp.send_to(packet, self.endpoint).await { match self.udp.send_to(packet, self.endpoint).await {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); tracing::error!(
"Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}",
e
);
break; break;
} }
}; };
@@ -550,10 +554,8 @@ impl WgTunnelListener {
#[async_trait] #[async_trait]
impl TunnelListener for WgTunnelListener { impl TunnelListener for WgTunnelListener {
async fn listen(&mut self) -> Result<(), super::TunnelError> { async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg", IpVersion::Both)
.await?;
let socket2_socket = socket2::Socket::new( let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr), socket2::Domain::for_address(addr),
socket2::Type::DGRAM, socket2::Type::DGRAM,
@@ -705,13 +707,8 @@ impl WgTunnelConnector {
#[async_trait] #[async_trait]
impl super::TunnelConnector for WgTunnelConnector { impl super::TunnelConnector for WgTunnelConnector {
#[tracing::instrument] #[tracing::instrument]
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> { async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>( let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
&self.addr,
"wg",
self.ip_version,
)
.await?;
if addr.is_ipv6() { if addr.is_ipv6() {
return self.connect_with_ipv6(addr).await; return self.connect_with_ipv6(addr).await;
+1 -2
View File
@@ -36,8 +36,7 @@ thread_local! {
} }
pub fn setup_panic_handler() { pub fn setup_panic_handler() {
use std::backtrace; use std::{backtrace, io::Write};
use std::io::Write;
std::panic::set_hook(Box::new(|info| { std::panic::set_hook(Box::new(|info| {
let mut stderr = std::io::stderr(); let mut stderr = std::io::stderr();
let sep = format!("{}\n", "=======".repeat(10)); let sep = format!("{}\n", "=======".repeat(10));