refactor: listener/connector protocol abstraction (#2026)

* fix listener protocol detection
* replace IpProtocol with IpNextHeaderProtocol
* use an enum to gather all listener schemes
* rename ListenerScheme to TunnelScheme; replace IpNextHeaderProtocols with socket2::Protocol
* move TunnelScheme to tunnel
* add IpScheme, simplify connector creation
* format; fix some typos; remove check_scheme_...;
* remove PROTO_PORT_OFFSET
* rename WSTunnel.. -> WsTunnel.., DNSTunnel.. -> DnsTunnel..
This commit is contained in:
Luna Yao
2026-04-04 04:55:58 +02:00
committed by GitHub
parent 9cc617ae4c
commit e91a0da70a
18 changed files with 481 additions and 526 deletions
+21 -23
View File
@@ -1,8 +1,7 @@
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::{
collections::{hash_map::DefaultHasher, HashMap},
hash::Hasher,
net::{IpAddr, SocketAddr},
sync::{Arc, Mutex},
time::{SystemTime, UNIX_EPOCH},
};
@@ -10,21 +9,6 @@ use std::{
use arc_swap::ArcSwap;
use dashmap::DashMap;
use crate::common::config::ProxyNetworkConfig;
use crate::common::shrink_dashmap;
use crate::common::stats_manager::StatsManager;
use crate::common::token_bucket::TokenBucketManager;
use crate::peers::acl_filter::AclFilter;
use crate::peers::credential_manager::CredentialManager;
use crate::proto::acl::GroupIdentity;
use crate::proto::api::config::InstanceConfigPatch;
use crate::proto::api::instance::PeerConnInfo;
use crate::proto::common::{PeerFeatureFlag, PortForwardConfigPb};
use crate::proto::peer_rpc::PeerGroupInfo;
use crossbeam::atomic::AtomicCell;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use super::{
config::{ConfigLoader, Flags},
netns::NetNS,
@@ -32,6 +16,24 @@ use super::{
stun::{StunInfoCollector, StunInfoCollectorTrait},
PeerId,
};
use crate::{
common::{
config::ProxyNetworkConfig, shrink_dashmap, stats_manager::StatsManager,
token_bucket::TokenBucketManager,
},
peers::{acl_filter::AclFilter, credential_manager::CredentialManager},
proto::{
acl::GroupIdentity,
api::{config::InstanceConfigPatch, instance::PeerConnInfo},
common::{PeerFeatureFlag, PortForwardConfigPb},
peer_rpc::PeerGroupInfo,
},
tunnel::matches_protocol,
};
use crossbeam::atomic::AtomicCell;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use socket2::Protocol;
pub type NetworkIdentity = crate::common::config::NetworkIdentity;
@@ -625,15 +627,11 @@ impl GlobalCtx {
}
fn is_port_in_running_listeners(&self, port: u16, is_udp: bool) -> bool {
let check_proto = |listener_proto: &str| {
let listener_is_udp = matches!(listener_proto, "udp" | "wg");
listener_is_udp == is_udp
};
self.running_listeners
.lock()
.unwrap()
.iter()
.any(|x| x.port() == Some(port) && check_proto(x.scheme()))
.any(|x| x.port() == Some(port) && matches_protocol!(x, Protocol::UDP) == is_udp)
}
#[tracing::instrument(ret, skip(self))]
+23 -23
View File
@@ -31,19 +31,20 @@ use crate::{
},
rpc_types::controller::BaseController,
},
tunnel::{udp::UdpTunnelConnector, IpVersion},
tunnel::{matches_protocol, udp::UdpTunnelConnector, IpVersion},
use_global_var,
};
use anyhow::Context;
use rand::Rng;
use tokio::{net::UdpSocket, task::JoinSet, time::timeout};
use url::Host;
use super::{
create_connector_by_url, should_background_p2p_with_peer, should_try_p2p_with_peer,
udp_hole_punch,
};
use crate::tunnel::{matches_scheme, FromUrl, IpScheme, TunnelScheme};
use anyhow::Context;
use rand::Rng;
use socket2::Protocol;
use tokio::{net::UdpSocket, task::JoinSet, time::timeout};
use url::Host;
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300;
@@ -189,9 +190,7 @@ impl DirectConnectorManagerData {
.await;
let udp_connector = UdpTunnelConnector::new(remote_url.clone());
let remote_addr =
super::check_scheme_and_get_socket_addr::<SocketAddr>(remote_url, "udp", IpVersion::V6)
.await?;
let remote_addr = SocketAddr::from_url(remote_url.clone(), IpVersion::V6).await?;
let ret = udp_connector
.try_connect_with_socket(local_socket, remote_addr)
.await?;
@@ -205,18 +204,19 @@ impl DirectConnectorManagerData {
async fn do_try_connect_to_ip(&self, dst_peer_id: PeerId, addr: String) -> Result<(), Error> {
let connector = create_connector_by_url(&addr, &self.global_ctx, IpVersion::Both).await?;
let remote_url = connector.remote_url();
let (peer_id, conn_id) =
if remote_url.scheme() == "udp" && matches!(remote_url.host(), Some(Host::Ipv6(_))) {
self.connect_to_public_ipv6(dst_peer_id, &remote_url)
.await?
} else {
timeout(
std::time::Duration::from_secs(3),
self.peer_manager
.try_direct_connect_with_peer_id_hint(connector, Some(dst_peer_id)),
)
.await??
};
let (peer_id, conn_id) = if matches_scheme!(remote_url, TunnelScheme::Ip(IpScheme::Udp))
&& matches!(remote_url.host(), Some(Host::Ipv6(_)))
{
self.connect_to_public_ipv6(dst_peer_id, &remote_url)
.await?
} else {
timeout(
std::time::Duration::from_secs(3),
self.peer_manager
.try_direct_connect_with_peer_id_hint(connector, Some(dst_peer_id)),
)
.await??
};
if peer_id != dst_peer_id && !TESTING.load(Ordering::Relaxed) {
tracing::info!(
@@ -306,7 +306,7 @@ impl DirectConnectorManagerData {
let listener_host = addrs.pop();
tracing::info!(?listener_host, ?listener, "try direct connect to peer");
let is_udp = matches!(listener.scheme(), "udp" | "wg");
let is_udp = matches_protocol!(listener, Protocol::UDP);
// Snapshot running listeners once; used for cheap port pre-checks before the
// expensive should_deny_proxy call (which binds a socket per IP) in the
// unspecified-address expansion loops below.
@@ -314,7 +314,7 @@ impl DirectConnectorManagerData {
let port_has_local_listener = |port: u16| -> bool {
local_listeners
.iter()
.any(|l| l.port() == Some(port) && (matches!(l.scheme(), "udp" | "wg") == is_udp))
.any(|l| l.port() == Some(port) && matches_protocol!(l, Protocol::UDP) == is_udp)
};
match listener_host {
+38 -40
View File
@@ -1,5 +1,6 @@
use std::{net::SocketAddr, sync::Arc};
use super::{create_connector_by_url, http_connector::TunnelWithInfo};
use crate::{
common::{
dns::{resolve_txt_record, RESOLVER},
@@ -7,16 +8,15 @@ use crate::{
global_ctx::ArcGlobalCtx,
log,
},
tunnel::{IpVersion, Tunnel, TunnelConnector, TunnelError, PROTO_PORT_OFFSET},
proto::common::TunnelInfo,
tunnel::{IpScheme, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelScheme},
};
use anyhow::Context;
use dashmap::DashSet;
use hickory_resolver::proto::rr::rdata::SRV;
use itertools::Itertools;
use rand::{seq::SliceRandom, Rng as _};
use crate::proto::common::TunnelInfo;
use super::{create_connector_by_url, http_connector::TunnelWithInfo};
use strum::VariantArray;
fn weighted_choice<T>(options: &[(T, u64)]) -> Option<&T> {
let total_weight = options.iter().map(|(_, weight)| *weight).sum();
@@ -35,16 +35,18 @@ fn weighted_choice<T>(options: &[(T, u64)]) -> Option<&T> {
}
#[derive(Debug)]
pub struct DNSTunnelConnector {
pub struct DnsTunnelConnector {
scheme: TunnelScheme,
addr: url::Url,
bind_addrs: Vec<SocketAddr>,
global_ctx: ArcGlobalCtx,
ip_version: IpVersion,
}
impl DNSTunnelConnector {
impl DnsTunnelConnector {
pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self {
Self {
scheme: (&addr).try_into().unwrap(),
addr,
bind_addrs: Vec::new(),
global_ctx,
@@ -82,7 +84,7 @@ impl DNSTunnelConnector {
Ok(connector)
}
fn handle_one_srv_record(record: &SRV, protocol: &str) -> Result<(url::Url, u64), Error> {
fn handle_one_srv_record(record: &SRV, protocol: IpScheme) -> Result<(url::Url, u64), Error> {
// port must be non-zero
if record.port() == 0 {
return Err(anyhow::anyhow!("port must be non-zero").into());
@@ -112,15 +114,15 @@ impl DNSTunnelConnector {
) -> Result<Box<dyn TunnelConnector>, Error> {
tracing::info!("handle_srv_record: {}", domain_name);
let srv_domains = PROTO_PORT_OFFSET
let srv_domains = IpScheme::VARIANTS
.iter()
.map(|(p, _)| (format!("_easytier._{}.{}", p, domain_name), *p)) // _easytier._udp.{domain_name}
.collect::<Vec<_>>();
.map(|s| (s, format!("_easytier._{}.{}", s, domain_name)))
.collect_vec();
tracing::info!("build srv_domains: {:?}", srv_domains);
let responses = Arc::new(DashSet::new());
let srv_lookup_tasks = srv_domains
.iter()
.map(|(srv_domain, protocol)| {
.map(|(protocol, srv_domain)| {
let resolver = RESOLVER.clone();
let responses = responses.clone();
async move {
@@ -129,7 +131,7 @@ impl DNSTunnelConnector {
})?;
tracing::info!(?response, ?srv_domain, "srv_lookup response");
for record in response.iter() {
let parsed_record = Self::handle_one_srv_record(record, protocol);
let parsed_record = Self::handle_one_srv_record(record, **protocol);
tracing::info!(?parsed_record, ?srv_domain, "parsed_record");
if let Err(e) = &parsed_record {
log::warn!("got invalid srv record {:?}", e);
@@ -162,32 +164,28 @@ impl DNSTunnelConnector {
}
#[async_trait::async_trait]
impl super::TunnelConnector for DNSTunnelConnector {
impl super::TunnelConnector for DnsTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let mut conn = if self.addr.scheme() == "txt" {
self.handle_txt_record(
self.addr
.host_str()
.as_ref()
.ok_or(anyhow::anyhow!("host should not be empty in txt url"))?,
)
.await
.with_context(|| "get txt record url failed")?
} else if self.addr.scheme() == "srv" {
self.handle_srv_record(
self.addr
.host_str()
.as_ref()
.ok_or(anyhow::anyhow!("host should not be empty in srv url"))?,
)
.await
.with_context(|| "get srv record url failed")?
} else {
return Err(anyhow::anyhow!(
"unsupported dns scheme: {}, expecting txt or srv",
self.addr.scheme()
)
.into());
let mut conn = match self.scheme {
TunnelScheme::Txt => self
.handle_txt_record(
self.addr
.host_str()
.as_ref()
.ok_or(anyhow::anyhow!("host should not be empty in txt url"))?,
)
.await
.with_context(|| "get txt record url failed")?,
TunnelScheme::Srv => self
.handle_srv_record(
self.addr
.host_str()
.as_ref()
.ok_or(anyhow::anyhow!("host should not be empty in srv url"))?,
)
.await
.with_context(|| "get srv record url failed")?,
_ => return Err(anyhow::anyhow!("unsupported dns scheme: {:?}", self.scheme).into()),
};
let t = conn.connect().await?;
let info = t.info().unwrap_or_default();
@@ -227,7 +225,7 @@ mod tests {
async fn test_txt() {
let url = "txt://txt.easytier.cn";
let global_ctx = get_mock_global_ctx();
let mut connector = DNSTunnelConnector::new(url.parse().unwrap(), global_ctx);
let mut connector = DnsTunnelConnector::new(url.parse().unwrap(), global_ctx);
connector.set_ip_version(IpVersion::V4);
for _ in 0..5 {
match connector.connect().await {
@@ -246,7 +244,7 @@ mod tests {
async fn test_srv() {
let url = "srv://easytier.cn";
let global_ctx = get_mock_global_ctx();
let mut connector = DNSTunnelConnector::new(url.parse().unwrap(), global_ctx);
let mut connector = DnsTunnelConnector::new(url.parse().unwrap(), global_ctx);
connector.set_ip_version(IpVersion::V4);
for _ in 0..5 {
match connector.connect().await {
+41 -116
View File
@@ -3,24 +3,17 @@ use std::{
sync::Arc,
};
use http_connector::HttpTunnelConnector;
#[cfg(feature = "faketcp")]
use crate::tunnel::fake_tcp::FakeTcpTunnelConnector;
#[cfg(feature = "quic")]
use crate::tunnel::quic::QUICTunnelConnector;
#[cfg(unix)]
use crate::tunnel::unix::UnixSocketTunnelConnector;
#[cfg(feature = "wireguard")]
use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, idn, network::IPCollector},
connector::dns_connector::DnsTunnelConnector,
proto::common::PeerFeatureFlag,
tunnel::{
check_scheme_and_get_socket_addr, ring::RingTunnelConnector, tcp::TcpTunnelConnector,
udp::UdpTunnelConnector, IpVersion, TunnelConnector,
self, ring::RingTunnelConnector, tcp::TcpTunnelConnector, udp::UdpTunnelConnector, FromUrl,
IpScheme, IpVersion, TunnelConnector, TunnelError, TunnelScheme,
},
utils::BoxExt,
};
use http_connector::HttpTunnelConnector;
pub mod direct;
pub mod manual;
@@ -90,84 +83,34 @@ pub async fn create_connector_by_url(
) -> Result<Box<dyn TunnelConnector + 'static>, Error> {
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
let url = idn::convert_idn_to_ascii(url)?;
let mut connector: Box<dyn TunnelConnector + 'static> = match url.scheme() {
"tcp" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp", ip_version).await?;
let mut connector = TcpTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
"udp" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp", ip_version).await?;
let mut connector = UdpTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
"http" | "https" => {
let connector = HttpTunnelConnector::new(url, global_ctx.clone());
Box::new(connector)
}
"ring" => {
check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring", IpVersion::Both).await?;
let connector = RingTunnelConnector::new(url);
Box::new(connector)
}
#[cfg(feature = "quic")]
"quic" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "quic", ip_version).await?;
let mut connector = QUICTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
#[cfg(feature = "wireguard")]
"wg" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg", ip_version).await?;
let nid = global_ctx.get_network_identity();
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
&nid.network_secret.unwrap_or_default(),
);
let mut connector = WgTunnelConnector::new(url, wg_config);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
#[cfg(feature = "websocket")]
"ws" | "wss" => {
use crate::tunnel::FromUrl;
let scheme = (&url)
.try_into()
.map_err(|_| TunnelError::InvalidProtocol(url.scheme().to_owned()))?;
let mut connector: Box<dyn TunnelConnector + 'static> = match scheme {
TunnelScheme::Ip(scheme) => {
let dst_addr = SocketAddr::from_url(url.clone(), ip_version).await?;
let mut connector = crate::tunnel::websocket::WSTunnelConnector::new(url);
let mut connector: Box<dyn TunnelConnector> = match scheme {
IpScheme::Tcp => TcpTunnelConnector::new(url).boxed(),
IpScheme::Udp => UdpTunnelConnector::new(url).boxed(),
#[cfg(feature = "quic")]
IpScheme::Quic => tunnel::quic::QuicTunnelConnector::new(url).boxed(),
#[cfg(feature = "wireguard")]
IpScheme::Wg => {
use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
let nid = global_ctx.get_network_identity();
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
&nid.network_secret.unwrap_or_default(),
);
WgTunnelConnector::new(url, wg_config).boxed()
}
#[cfg(feature = "websocket")]
IpScheme::Ws | IpScheme::Wss => {
tunnel::websocket::WsTunnelConnector::new(url).boxed()
}
#[cfg(feature = "faketcp")]
IpScheme::FakeTcp => tunnel::fake_tcp::FakeTcpTunnelConnector::new(url).boxed(),
};
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
@@ -176,40 +119,22 @@ pub async fn create_connector_by_url(
)
.await;
}
Box::new(connector)
connector
}
"txt" | "srv" => {
#[cfg(unix)]
TunnelScheme::Unix => tunnel::unix::UnixSocketTunnelConnector::new(url).boxed(),
TunnelScheme::Http | TunnelScheme::Https => {
HttpTunnelConnector::new(url, global_ctx.clone()).boxed()
}
TunnelScheme::Ring => RingTunnelConnector::new(url).boxed(),
TunnelScheme::Txt | TunnelScheme::Srv => {
if url.host_str().is_none() {
return Err(Error::InvalidUrl(format!(
"host should not be empty in txt or srv url: {}",
url
)));
}
let connector = dns_connector::DNSTunnelConnector::new(url, global_ctx.clone());
Box::new(connector)
}
#[cfg(feature = "faketcp")]
"faketcp" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "faketcp", ip_version).await?;
let mut connector = FakeTcpTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
#[cfg(unix)]
"unix" => {
let connector = UnixSocketTunnelConnector::new(url);
Box::new(connector)
}
_ => {
return Err(Error::InvalidUrl(url.into()));
DnsTunnelConnector::new(url, global_ctx.clone()).boxed()
}
};
connector.set_ip_version(ip_version);
+14 -17
View File
@@ -22,7 +22,6 @@ use crate::{
launcher::add_proxy_network_to_config,
proto::common::{CompressionAlgoPb, SecureModeConfig},
rpc_service::ApiRpcServer,
tunnel::PROTO_PORT_OFFSET,
utils::setup_panic_handler,
web_client, ShellType,
};
@@ -30,8 +29,10 @@ use anyhow::Context;
use cidr::IpCidr;
use clap::{CommandFactory, Parser};
use rust_i18n::t;
use strum::VariantArray;
use tokio::io::AsyncReadExt;
use crate::tunnel::IpScheme;
#[cfg(feature = "jemalloc-prof")]
use jemalloc_ctl::{epoch, stats, Access as _, AsName as _};
@@ -742,8 +743,12 @@ impl Cli {
let mut listeners: Vec<String> = Vec::new();
if origin_listeners.len() == 1 {
if let Ok(port) = origin_listeners[0].parse::<u16>() {
for (proto, offset) in PROTO_PORT_OFFSET {
listeners.push(format!("{}://0.0.0.0:{}", proto, port + *offset));
for proto in IpScheme::VARIANTS {
listeners.push(format!(
"{}://0.0.0.0:{}",
proto,
port + proto.port_offset()
));
}
return Ok(listeners);
}
@@ -758,20 +763,15 @@ impl Cli {
panic!("failed to parse listener: {}", l);
}
} else {
let Some((proto, offset)) = PROTO_PORT_OFFSET
.iter()
.find(|(proto, _)| *proto == proto_port[0])
else {
return Err(anyhow::anyhow!("unknown protocol: {}", proto_port[0]));
};
let scheme: IpScheme = proto_port[0].parse()?;
let port = if proto_port.len() == 2 {
proto_port[1].parse::<u16>().unwrap()
} else {
11010 + offset
11010 + scheme.port_offset()
};
listeners.push(format!("{}://0.0.0.0:{}", proto, port));
listeners.push(format!("{}://0.0.0.0:{}", scheme, port));
}
}
@@ -1134,8 +1134,7 @@ impl LoggingConfigLoader for &LoggingOptions {
#[cfg(target_os = "windows")]
fn win_service_set_work_dir(service_name: &std::ffi::OsString) -> anyhow::Result<()> {
use crate::common::constants::WIN_SERVICE_WORK_DIR_REG_KEY;
use winreg::enums::*;
use winreg::RegKey;
use winreg::{enums::*, RegKey};
let hklm = RegKey::predef(HKEY_LOCAL_MACHINE);
let key = hklm.open_subkey_with_flags(WIN_SERVICE_WORK_DIR_REG_KEY, KEY_READ)?;
@@ -1215,11 +1214,9 @@ fn parse_cli() -> Cli {
#[cfg(target_os = "windows")]
fn win_service_main(arg: Vec<std::ffi::OsString>) {
use std::sync::Arc;
use std::time::Duration;
use std::{sync::Arc, time::Duration};
use tokio::sync::Notify;
use windows_service::service::*;
use windows_service::service_control_handler::*;
use windows_service::{service::*, service_control_handler::*};
_ = win_service_set_work_dir(&arg[0]);
+37 -45
View File
@@ -9,12 +9,6 @@ use anyhow::Context;
use async_trait::async_trait;
use tokio::task::JoinSet;
#[cfg(feature = "faketcp")]
use crate::tunnel::fake_tcp::FakeTcpTunnelListener;
#[cfg(feature = "quic")]
use crate::tunnel::quic::QUICTunnelListener;
#[cfg(feature = "wireguard")]
use crate::tunnel::wireguard::{WgConfig, WgTunnelListener};
use crate::{
common::{
error::Error,
@@ -23,44 +17,42 @@ use crate::{
},
peers::peer_manager::PeerManager,
tunnel::{
ring::RingTunnelListener, tcp::TcpTunnelListener, udp::UdpTunnelListener, Tunnel,
TunnelListener,
self, ring::RingTunnelListener, tcp::TcpTunnelListener, udp::UdpTunnelListener, IpScheme,
Tunnel, TunnelListener, TunnelScheme,
},
utils::BoxExt,
};
pub fn get_listener_by_url(
pub fn create_listener_by_url(
l: &url::Url,
_ctx: ArcGlobalCtx,
#[allow(unused_variables)] ctx: ArcGlobalCtx,
) -> Result<Box<dyn TunnelListener>, Error> {
Ok(match l.scheme() {
"tcp" => Box::new(TcpTunnelListener::new(l.clone())),
"udp" => Box::new(UdpTunnelListener::new(l.clone())),
#[cfg(feature = "wireguard")]
"wg" => {
let nid = _ctx.get_network_identity();
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
&nid.network_secret.unwrap_or_default(),
);
Box::new(WgTunnelListener::new(l.clone(), wg_config))
}
#[cfg(feature = "quic")]
"quic" => Box::new(QUICTunnelListener::new(l.clone())),
#[cfg(feature = "websocket")]
"ws" | "wss" => {
use crate::tunnel::websocket::WSTunnelListener;
Box::new(WSTunnelListener::new(l.clone()))
}
#[cfg(feature = "faketcp")]
"faketcp" => Box::new(FakeTcpTunnelListener::new(l.clone())),
Ok(match l.try_into()? {
TunnelScheme::Ip(scheme) => match scheme {
IpScheme::Tcp => TcpTunnelListener::new(l.clone()).boxed(),
IpScheme::Udp => UdpTunnelListener::new(l.clone()).boxed(),
#[cfg(feature = "wireguard")]
IpScheme::Wg => {
use crate::tunnel::wireguard::{WgConfig, WgTunnelListener};
let nid = ctx.get_network_identity();
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
&nid.network_secret.unwrap_or_default(),
);
WgTunnelListener::new(l.clone(), wg_config).boxed()
}
#[cfg(feature = "quic")]
IpScheme::Quic => tunnel::quic::QuicTunnelListener::new(l.clone()).boxed(),
#[cfg(feature = "websocket")]
IpScheme::Ws | IpScheme::Wss => {
tunnel::websocket::WsTunnelListener::new(l.clone()).boxed()
}
#[cfg(feature = "faketcp")]
IpScheme::FakeTcp => tunnel::fake_tcp::FakeTcpTunnelListener::new(l.clone()).boxed(),
},
#[cfg(unix)]
"unix" => {
use crate::tunnel::unix::UnixSocketTunnelListener;
Box::new(UnixSocketTunnelListener::new(l.clone()))
}
_ => {
return Err(Error::InvalidUrl(l.to_string()));
}
TunnelScheme::Unix => tunnel::unix::UnixSocketTunnelListener::new(l.clone()).boxed(),
_ => return Err(Error::InvalidUrl(l.to_string())),
})
}
@@ -133,7 +125,7 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
for l in self.global_ctx.config.get_listener_uris().iter() {
let l = l.clone();
let Ok(_) = get_listener_by_url(&l, self.global_ctx.clone()) else {
let Ok(_) = create_listener_by_url(&l, self.global_ctx.clone()) else {
let msg = format!("failed to get listener by url: {}, maybe not supported", l);
self.global_ctx
.issue_event(GlobalCtxEvent::ListenerAddFailed(l.clone(), msg));
@@ -143,7 +135,7 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
let listener = l.clone();
self.add_listener(
move || get_listener_by_url(&listener, ctx.clone()).unwrap(),
move || create_listener_by_url(&listener, ctx.clone()).unwrap(),
true,
)
.await?;
@@ -160,7 +152,7 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
.with_context(|| format!("failed to set ipv6 host for listener: {}", l))?;
let ctx = self.global_ctx.clone();
self.add_listener(
move || get_listener_by_url(&ipv6_listener, ctx.clone()).unwrap(),
move || create_listener_by_url(&ipv6_listener, ctx.clone()).unwrap(),
false,
)
.await?;
@@ -361,10 +353,6 @@ mod tests {
#[async_trait::async_trait]
impl TunnelListener for MockListener {
fn local_url(&self) -> url::Url {
"mock://".parse().unwrap()
}
async fn listen(&mut self) -> Result<(), TunnelError> {
self.counter.fetch_add(1, Ordering::Relaxed);
Ok(())
@@ -374,6 +362,10 @@ mod tests {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
Err(TunnelError::BufferFull)
}
fn local_url(&self) -> url::Url {
"mock://".parse().unwrap()
}
}
impl Drop for MockListener {
+3 -3
View File
@@ -2003,7 +2003,7 @@ mod tests {
create_connector_by_url, direct::PeerManagerForDirectConnector,
udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun,
},
instance::listeners::get_listener_by_url,
instance::listeners::create_listener_by_url,
peers::{
create_packet_recv_chan,
peer_conn::tests::set_secure_mode_cfg,
@@ -2783,7 +2783,7 @@ mod tests {
let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
register_service(&peer_mgr_c.peer_rpc_mgr, "", 0, "hello c");
let mut listener1 = get_listener_by_url(
let mut listener1 = create_listener_by_url(
&format!("{}://0.0.0.0:31013", proto1).parse().unwrap(),
peer_mgr_b.get_global_ctx(),
)
@@ -2802,7 +2802,7 @@ mod tests {
.await
.unwrap();
let mut listener2 = get_listener_by_url(
let mut listener2 = create_listener_by_url(
&format!("{}://0.0.0.0:31014", proto2).parse().unwrap(),
peer_mgr_c.get_global_ctx(),
)
+2 -2
View File
@@ -156,14 +156,14 @@ async fn init_three_node_ex_with_inst3<F: Fn(TomlConfigLoader) -> TomlConfigLoad
#[cfg(feature = "websocket")]
inst1
.get_conn_manager()
.add_connector(crate::tunnel::websocket::WSTunnelConnector::new(
.add_connector(crate::tunnel::websocket::WsTunnelConnector::new(
"ws://10.1.1.2:11011".parse().unwrap(),
));
} else if proto == "wss" {
#[cfg(feature = "websocket")]
inst1
.get_conn_manager()
.add_connector(crate::tunnel::websocket::WSTunnelConnector::new(
.add_connector(crate::tunnel::websocket::WsTunnelConnector::new(
"wss://10.1.1.2:11012".parse().unwrap(),
));
}
+21 -30
View File
@@ -2,20 +2,28 @@ mod netfilter;
mod packet;
mod stack;
use std::net::{IpAddr, Ipv4Addr, UdpSocket};
use std::sync::Arc;
use std::{net::SocketAddr, pin::Pin};
use bytes::BytesMut;
use futures::{Sink, Stream};
use network_interface::NetworkInterfaceConfig;
use pnet::util::MacAddr;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
pin::Pin,
sync::Arc,
task::{Context as TaskContext, Poll},
};
use tokio::{io::AsyncReadExt, net::TcpStream, sync::Mutex};
use crate::common::scoped_task::ScopedTask;
use crate::tunnel::fake_tcp::netfilter::create_tun;
use crate::tunnel::{common::TunnelWrapper, Tunnel, TunnelError, TunnelInfo, TunnelListener};
use crate::{
common::scoped_task::ScopedTask,
tunnel::{
common::TunnelWrapper,
fake_tcp::netfilter::create_tun,
packet_def::{ZCPacket, ZCPacketType, PEER_MANAGER_HEADER_SIZE, TCP_TUNNEL_HEADER_SIZE},
FromUrl, IpVersion, SinkError, SinkItem, StreamItem, Tunnel, TunnelConnector, TunnelError,
TunnelInfo, TunnelListener,
},
};
use futures::Future;
@@ -207,12 +215,7 @@ struct AcceptResult {
impl TunnelListener for FakeTcpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let port = self.addr.port().unwrap_or(0);
let bind_addr = crate::tunnel::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"faketcp",
crate::tunnel::IpVersion::Both,
)
.await?;
let bind_addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let os_listener = tokio::net::TcpListener::bind(bind_addr).await?;
tracing::info!(port, "FakeTcpTunnelListener listening");
self.os_listener = Some(os_listener);
@@ -306,14 +309,9 @@ fn get_local_ip_for_destination(destination: IpAddr) -> Option<IpAddr> {
}
#[async_trait::async_trait]
impl crate::tunnel::TunnelConnector for FakeTcpTunnelConnector {
impl TunnelConnector for FakeTcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let remote_addr = crate::tunnel::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"faketcp",
crate::tunnel::IpVersion::Both,
)
.await?;
let remote_addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let local_ip = get_local_ip_for_destination(remote_addr.ip())
.ok_or(TunnelError::InternalError("Failed to get local ip".into()))?;
@@ -387,13 +385,6 @@ impl crate::tunnel::TunnelConnector for FakeTcpTunnelConnector {
}
}
use crate::tunnel::packet_def::{
ZCPacket, ZCPacketType, PEER_MANAGER_HEADER_SIZE, TCP_TUNNEL_HEADER_SIZE,
};
use crate::tunnel::{SinkError, SinkItem, StreamItem};
use futures::{Sink, Stream};
use std::task::{Context as TaskContext, Poll};
type RecvFut = Pin<Box<dyn Future<Output = Option<(BytesMut, usize)>> + Send + Sync>>;
enum FakeTcpStreamState {
+140 -52
View File
@@ -1,16 +1,19 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;
use std::{net::SocketAddr, pin::Pin, sync::Arc};
use std::{
collections::hash_map::DefaultHasher, hash::Hasher, net::SocketAddr, pin::Pin, sync::Arc,
};
use crate::{
common::{dns::socket_addrs, error::Error},
proto::common::TunnelInfo,
};
use async_trait::async_trait;
use derive_more::{From, TryInto};
use futures::{Sink, Stream};
use socket2::Protocol;
use std::fmt::Debug;
use strum::{Display, EnumString, VariantArray};
use tokio::time::error::Elapsed;
use crate::common::dns::socket_addrs;
use crate::proto::common::TunnelInfo;
use self::packet_def::ZCPacket;
pub mod buf;
@@ -23,15 +26,6 @@ pub mod stats;
pub mod tcp;
pub mod udp;
pub const PROTO_PORT_OFFSET: &[(&str, u16)] = &[
("tcp", 0),
("udp", 0),
("wg", 1),
("ws", 1),
("wss", 2),
("faketcp", 3),
];
#[cfg(feature = "faketcp")]
pub mod fake_tcp;
@@ -193,45 +187,23 @@ pub(crate) trait FromUrl {
Self: Sized;
}
pub(crate) async fn check_scheme_and_get_socket_addr<T>(
url: &url::Url,
scheme: &str,
ip_version: IpVersion,
) -> Result<T, TunnelError>
where
T: FromUrl,
{
if url.scheme() != scheme {
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
}
T::from_url(url.clone(), ip_version).await
}
fn default_port(scheme: &str) -> Option<u16> {
match scheme {
"tcp" => Some(11010),
"udp" => Some(11010),
"ws" => Some(80),
"wss" => Some(443),
"faketcp" => Some(11013),
"quic" => Some(11012),
"wg" => Some(11011),
_ => None,
}
}
#[async_trait::async_trait]
impl FromUrl for SocketAddr {
async fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError> {
let addrs = socket_addrs(&url, || default_port(url.scheme()))
.await
.map_err(|e| {
TunnelError::InvalidAddr(format!(
"failed to resolve socket addr, url: {}, error: {}",
url, e
))
})?;
let addrs = socket_addrs(&url, || {
(&url)
.try_into()
.ok()
.and_then(|s: TunnelScheme| s.try_into().ok())
.map(IpScheme::default_port)
})
.await
.map_err(|e| {
TunnelError::InvalidAddr(format!(
"failed to resolve socket addr, url: {}, error: {}",
url, e
))
})?;
tracing::debug!(?addrs, ?ip_version, ?url, "convert url to socket addrs");
let addrs = addrs
.into_iter()
@@ -305,3 +277,119 @@ pub fn generate_digest_from_str(str1: &str, str2: &str, digest: &mut [u8]) {
hasher.write(&digest[..(i + 1) * 8]);
}
}
#[derive(Debug, Clone, Copy)]
struct IpSchemeAttributes {
protocol: Protocol,
port_offset: u16,
}
#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, VariantArray)]
#[strum(serialize_all = "lowercase")]
pub enum IpScheme {
Tcp,
Udp,
#[cfg(feature = "wireguard")]
Wg,
#[cfg(feature = "quic")]
Quic,
#[cfg(feature = "websocket")]
Ws,
#[cfg(feature = "websocket")]
Wss,
#[cfg(feature = "faketcp")]
FakeTcp,
}
impl IpScheme {
const fn attributes(self) -> IpSchemeAttributes {
let (protocol, port_offset) = match self {
Self::Tcp => (Protocol::TCP, 0),
Self::Udp => (Protocol::UDP, 0),
#[cfg(feature = "wireguard")]
Self::Wg => (Protocol::UDP, 1),
#[cfg(feature = "quic")]
Self::Quic => (Protocol::UDP, 2),
#[cfg(feature = "websocket")]
Self::Ws => (Protocol::TCP, 1),
#[cfg(feature = "websocket")]
Self::Wss => (Protocol::TCP, 2),
#[cfg(feature = "faketcp")]
Self::FakeTcp => (Protocol::TCP, 3),
};
IpSchemeAttributes {
protocol,
port_offset,
}
}
pub const fn protocol(self) -> Protocol {
self.attributes().protocol
}
pub const fn port_offset(self) -> u16 {
self.attributes().port_offset
}
pub const fn default_port(self) -> u16 {
match self {
#[cfg(feature = "websocket")]
Self::Ws => 80,
#[cfg(feature = "websocket")]
Self::Wss => 443,
_ => 11010 + self.port_offset(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, EnumString, From, TryInto)]
#[strum(serialize_all = "lowercase")]
pub enum TunnelScheme {
#[strum(disabled)]
Ip(IpScheme),
#[cfg(unix)]
Unix,
// Only for connector
Http,
Https,
Ring,
Txt,
Srv,
}
impl TryFrom<&url::Url> for TunnelScheme {
type Error = Error;
fn try_from(value: &url::Url) -> Result<Self, Self::Error> {
let scheme = value.scheme();
scheme.parse().or_else(|_| {
Ok(TunnelScheme::Ip(
scheme
.parse()
.map_err(|_| Error::InvalidUrl(value.to_string()))?,
))
})
}
}
macro_rules! __matches_scheme__ {
($url:expr, $( $pattern:pat_param )|+ ) => {
matches!($crate::tunnel::TunnelScheme::try_from(($url).as_ref()), Ok($( $pattern )|+))
};
}
pub(crate) use __matches_scheme__ as matches_scheme;
pub fn get_protocol_by_url(l: &url::Url) -> Result<Protocol, Error> {
let TunnelScheme::Ip(scheme) = l.try_into()? else {
return Err(Error::InvalidUrl(l.to_string()));
};
Ok(scheme.protocol())
}
macro_rules! __matches_protocol__ {
($url:expr, $( $pattern:pat_param )|+ ) => {
matches!($crate::tunnel::get_protocol_by_url($url), Ok($( $pattern )|+))
};
}
pub(crate) use __matches_protocol__ as matches_protocol;
+26 -34
View File
@@ -8,20 +8,16 @@ use std::{
use crate::tunnel::{
common::{setup_sokcet2, FramedReader, FramedWriter, TunnelWrapper},
TunnelInfo,
FromUrl, TunnelInfo,
};
use anyhow::Context;
use super::{IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
use quinn::{
congestion::BbrConfig, udp::RecvMeta, AsyncUdpSocket, ClientConfig, Connection, Endpoint,
EndpointConfig, ServerConfig, TransportConfig, UdpPoller,
};
use super::{
check_scheme_and_get_socket_addr, IpVersion, Tunnel, TunnelConnector, TunnelError,
TunnelListener,
};
pub fn transport_config() -> Arc<TransportConfig> {
let mut config = TransportConfig::default();
@@ -145,14 +141,14 @@ impl Drop for ConnWrapper {
}
}
pub struct QUICTunnelListener {
pub struct QuicTunnelListener {
addr: url::Url,
endpoint: Option<Endpoint>,
}
impl QUICTunnelListener {
impl QuicTunnelListener {
pub fn new(addr: url::Url) -> Self {
QUICTunnelListener {
QuicTunnelListener {
addr,
endpoint: None,
}
@@ -190,11 +186,9 @@ impl QUICTunnelListener {
}
#[async_trait::async_trait]
impl TunnelListener for QUICTunnelListener {
impl TunnelListener for QuicTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic", IpVersion::Both)
.await?;
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let endpoint = make_server_endpoint(addr)
.map_err(|e| anyhow::anyhow!("make server endpoint error: {:?}", e))?;
self.endpoint = Some(endpoint);
@@ -223,15 +217,15 @@ impl TunnelListener for QUICTunnelListener {
}
}
pub struct QUICTunnelConnector {
pub struct QuicTunnelConnector {
addr: url::Url,
endpoint: Option<Endpoint>,
ip_version: IpVersion,
}
impl QUICTunnelConnector {
impl QuicTunnelConnector {
pub fn new(addr: url::Url) -> Self {
QUICTunnelConnector {
QuicTunnelConnector {
addr,
endpoint: None,
ip_version: IpVersion::Both,
@@ -240,11 +234,9 @@ impl QUICTunnelConnector {
}
#[async_trait::async_trait]
impl TunnelConnector for QUICTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic", self.ip_version)
.await?;
impl TunnelConnector for QuicTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if addr.port() == 0 {
return Err(TunnelError::InvalidAddr(format!(
"invalid remote QUIC port 0 in url: {} (port 0 is not a valid QUIC port)",
@@ -318,36 +310,36 @@ mod tests {
#[tokio::test]
async fn quic_pingpong() {
let listener = QUICTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap());
let connector = QUICTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap());
let listener = QuicTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap());
let connector = QuicTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn quic_bench() {
let listener = QUICTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap());
let connector = QUICTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap());
let listener = QuicTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap());
let connector = QuicTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap());
_tunnel_bench(listener, connector).await
}
#[tokio::test]
async fn ipv6_pingpong() {
let listener = QUICTunnelListener::new("quic://[::1]:31015".parse().unwrap());
let connector = QUICTunnelConnector::new("quic://[::1]:31015".parse().unwrap());
let listener = QuicTunnelListener::new("quic://[::1]:31015".parse().unwrap());
let connector = QuicTunnelConnector::new("quic://[::1]:31015".parse().unwrap());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn ipv6_domain_pingpong() {
let listener = QUICTunnelListener::new("quic://[::1]:31016".parse().unwrap());
let listener = QuicTunnelListener::new("quic://[::1]:31016".parse().unwrap());
let mut connector =
QUICTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
connector.set_ip_version(IpVersion::V6);
_tunnel_pingpong(listener, connector).await;
let listener = QUICTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap());
let listener = QuicTunnelListener::new("quic://127.0.0.1:31016".parse().unwrap());
let mut connector =
QUICTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
QuicTunnelConnector::new("quic://test.easytier.top:31016".parse().unwrap());
connector.set_ip_version(IpVersion::V4);
_tunnel_pingpong(listener, connector).await;
}
@@ -355,13 +347,13 @@ mod tests {
#[tokio::test]
async fn test_alloc_port() {
// v4
let mut listener = QUICTunnelListener::new("quic://0.0.0.0:0".parse().unwrap());
let mut listener = QuicTunnelListener::new("quic://0.0.0.0:0".parse().unwrap());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
assert!(port > 0);
// v6
let mut listener = QUICTunnelListener::new("quic://[::]:0".parse().unwrap());
let mut listener = QuicTunnelListener::new("quic://[::]:0".parse().unwrap());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
assert!(port > 0);
@@ -369,7 +361,7 @@ mod tests {
#[tokio::test]
async fn quic_connector_reject_port_zero() {
let mut connector = QUICTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap());
let mut connector = QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap());
let err = connector.connect().await.unwrap_err().to_string();
assert!(err.contains("port 0"), "unexpected error: {}", err);
}
+17 -28
View File
@@ -1,3 +1,5 @@
use async_ringbuf::{traits::*, AsyncHeapCons, AsyncHeapProd, AsyncHeapRb};
use crossbeam::atomic::AtomicCell;
use std::{
collections::HashMap,
fmt::Debug,
@@ -5,9 +7,6 @@ use std::{
task::{ready, Poll},
};
use async_ringbuf::{traits::*, AsyncHeapCons, AsyncHeapProd, AsyncHeapRb};
use crossbeam::atomic::AtomicCell;
use async_trait::async_trait;
use futures::{Sink, SinkExt, Stream, StreamExt};
use once_cell::sync::Lazy;
@@ -16,15 +15,15 @@ use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use uuid::Uuid;
use crate::tunnel::{SinkError, SinkItem};
use crate::tunnel::{FromUrl, IpVersion, SinkError, SinkItem};
use super::{
build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::TunnelWrapper,
StreamItem, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
build_url_from_socket_addr, common::TunnelWrapper, StreamItem, Tunnel, TunnelConnector,
TunnelError, TunnelInfo, TunnelListener,
};
pub static RING_TUNNEL_CAP: usize = 128;
static RING_TUNNEL_RESERVERD_CAP: usize = 4;
static RING_TUNNEL_RESERVED_CAP: usize = 4;
type RingLock = parking_lot::Mutex<()>;
@@ -44,7 +43,7 @@ impl RingTunnel {
pub fn new(cap: usize) -> Self {
let id = Uuid::new_v4();
let ring_impl = AsyncHeapRb::new(std::cmp::max(RING_TUNNEL_RESERVERD_CAP * 2, cap));
let ring_impl = AsyncHeapRb::new(std::cmp::max(RING_TUNNEL_RESERVED_CAP * 2, cap));
let (ring_prod_impl, ring_cons_impl) = ring_impl.split();
Self {
id,
@@ -120,7 +119,7 @@ impl RingSink {
pub fn try_send(&mut self, item: RingItem) -> Result<(), RingItem> {
let base = self.ring_prod_impl.base();
if base.occupied_len() >= base.capacity().get() - RING_TUNNEL_RESERVERD_CAP {
if base.occupied_len() >= base.capacity().get() - RING_TUNNEL_RESERVED_CAP {
return Err(item);
}
self.ring_prod_impl.try_push(item)
@@ -188,7 +187,7 @@ static CONNECTION_MAP: Lazy<Arc<std::sync::Mutex<ConnectionMap>>> =
#[derive(Debug)]
pub struct RingTunnelListener {
listerner_addr: url::Url,
listener_addr: url::Url,
conn_sender: UnboundedSender<Arc<Connection>>,
conn_receiver: UnboundedReceiver<Arc<Connection>>,
@@ -199,7 +198,7 @@ impl RingTunnelListener {
pub fn new(key: url::Url) -> Self {
let (conn_sender, conn_receiver) = tokio::sync::mpsc::unbounded_channel();
RingTunnelListener {
listerner_addr: key,
listener_addr: key,
conn_sender,
conn_receiver,
key_in_conn_map: None,
@@ -232,20 +231,15 @@ fn get_tunnel_for_server(conn: Arc<Connection>) -> impl Tunnel {
}
impl RingTunnelListener {
async fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
check_scheme_and_get_socket_addr::<Uuid>(
&self.listerner_addr,
"ring",
super::IpVersion::Both,
)
.await
async fn get_addr(&self) -> Result<Uuid, TunnelError> {
Uuid::from_url(self.listener_addr.clone(), IpVersion::Both).await
}
}
#[async_trait]
impl TunnelListener for RingTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
tracing::info!("listen new conn of key: {}", self.listerner_addr);
tracing::info!("listen new conn of key: {}", self.listener_addr);
let addr = self.get_addr().await?;
CONNECTION_MAP
.lock()
@@ -256,11 +250,11 @@ impl TunnelListener for RingTunnelListener {
}
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
tracing::info!("waiting accept new conn of key: {}", self.listerner_addr);
tracing::info!("waiting accept new conn of key: {}", self.listener_addr);
let my_addr = self.get_addr().await?;
if let Some(conn) = self.conn_receiver.recv().await {
if conn.server.id == my_addr {
tracing::info!("accept new conn of key: {}", self.listerner_addr);
tracing::info!("accept new conn of key: {}", self.listener_addr);
return Ok(Box::new(get_tunnel_for_server(conn)));
} else {
tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
@@ -276,7 +270,7 @@ impl TunnelListener for RingTunnelListener {
}
fn local_url(&self) -> url::Url {
self.listerner_addr.clone()
self.listener_addr.clone()
}
}
@@ -301,12 +295,7 @@ impl RingTunnelConnector {
#[async_trait]
impl TunnelConnector for RingTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let remote_addr = check_scheme_and_get_socket_addr::<Uuid>(
&self.remote_addr,
"ring",
super::IpVersion::Both,
)
.await?;
let remote_addr = Uuid::from_url(self.remote_addr.clone(), IpVersion::Both).await?;
let entry = CONNECTION_MAP
.lock()
.unwrap()
+5 -11
View File
@@ -1,14 +1,12 @@
use std::net::SocketAddr;
use super::{FromUrl, TunnelInfo};
use crate::tunnel::common::setup_sokcet2;
use async_trait::async_trait;
use futures::stream::FuturesUnordered;
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use super::TunnelInfo;
use crate::tunnel::common::setup_sokcet2;
use super::{
check_scheme_and_get_socket_addr,
common::{wait_for_connect_futures, FramedReader, FramedWriter, TunnelWrapper},
IpVersion, Tunnel, TunnelError, TunnelListener,
};
@@ -58,9 +56,7 @@ impl TcpTunnelListener {
impl TunnelListener for TcpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
self.listener = None;
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp", IpVersion::Both)
.await?;
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
@@ -189,10 +185,8 @@ impl TcpTunnelConnector {
#[async_trait]
impl super::TunnelConnector for TcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp", self.ip_version)
.await?;
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if self.bind_addrs.is_empty() {
self.connect_with_default_bind(addr).await
} else {
+18 -31
View File
@@ -21,7 +21,13 @@ use tokio::{
use tracing::{instrument, Instrument};
use super::{packet_def::V6HolePunchPacket, TunnelInfo};
use super::{
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
packet_def::{UDPTunnelHeader, V6HolePunchPacket, UDP_TUNNEL_HEADER_SIZE},
ring::{RingSink, RingStream},
FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener,
TunnelUrl,
};
use crate::{
common::{join_joinset_background, scoped_task::ScopedTask, shrink_dashmap},
tunnel::{
@@ -32,13 +38,6 @@ use crate::{
},
};
use super::{
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
packet_def::{UDPTunnelHeader, UDP_TUNNEL_HEADER_SIZE},
ring::{RingSink, RingStream},
IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelListener, TunnelUrl,
};
pub const UDP_DATA_MTU: usize = 2000;
type UdpCloseEventSender = UnboundedSender<(SocketAddr, Option<TunnelError>)>;
@@ -149,11 +148,11 @@ async fn respond_stun_packet(
req_buf: Vec<u8>,
) -> Result<(), anyhow::Error> {
use crate::common::stun_codec_ext::*;
use bytecodec::DecodeExt as _;
use bytecodec::EncodeExt as _;
use stun_codec::rfc5389::attributes::XorMappedAddress;
use stun_codec::rfc5389::methods::BINDING;
use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder};
use bytecodec::{DecodeExt as _, EncodeExt as _};
use stun_codec::{
rfc5389::{attributes::XorMappedAddress, methods::BINDING},
Message, MessageClass, MessageDecoder, MessageEncoder,
};
let mut decoder = MessageDecoder::<Attribute>::new();
let req_msg = decoder
@@ -532,13 +531,8 @@ impl UdpTunnelListener {
#[async_trait]
impl TunnelListener for UdpTunnelListener {
async fn listen(&mut self) -> Result<(), super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"udp",
IpVersion::Both,
)
.await?;
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
@@ -851,13 +845,8 @@ impl UdpTunnelConnector {
#[async_trait]
impl super::TunnelConnector for UdpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"udp",
self.ip_version,
)
.await?;
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if self.bind_addrs.is_empty() || addr.is_ipv6() {
self.connect_with_default_bind(addr).await
} else {
@@ -889,7 +878,6 @@ mod tests {
use crate::{
common::global_ctx::tests::get_mock_global_ctx,
tunnel::{
check_scheme_and_get_socket_addr,
common::{
get_interface_name_by_ip,
tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong, wait_for_condition},
@@ -1034,9 +1022,8 @@ mod tests {
for ip in ips {
println!("bind to ip: {}, {:?}", ip, bind_dev);
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(
&format!("udp://{}:11111", ip).parse().unwrap(),
"udp",
let addr = SocketAddr::from_url(
format!("udp://{}:11111", ip).parse().unwrap(),
IpVersion::Both,
)
.await
+29 -29
View File
@@ -1,10 +1,20 @@
use super::{
common::{setup_sokcet2, wait_for_connect_futures, TunnelWrapper},
insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
packet_def::{ZCPacket, ZCPacketType},
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
};
use crate::{proto::common::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config};
use anyhow::Context;
use bytes::BytesMut;
use forwarded_header_value::ForwardedHeaderValue;
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use pnet::ipnetwork::IpNetwork;
use std::sync::LazyLock;
use std::{net::SocketAddr, sync::Arc, time::Duration};
use std::{
net::SocketAddr,
sync::{Arc, LazyLock},
time::Duration,
};
use tokio::{
net::{TcpListener, TcpSocket, TcpStream},
time::timeout,
@@ -14,16 +24,6 @@ use tokio_util::either::Either;
use tokio_websockets::{ClientBuilder, Limits, MaybeTlsStream, Message, ServerBuilder};
use zerocopy::AsBytes;
use super::TunnelInfo;
use crate::tunnel::insecure_tls::get_insecure_tls_client_config;
use super::{
common::{setup_sokcet2, wait_for_connect_futures, TunnelWrapper},
insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
packet_def::{ZCPacket, ZCPacketType},
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
};
fn is_wss(addr: &url::Url) -> Result<bool, TunnelError> {
match addr.scheme() {
"ws" => Ok(false),
@@ -77,14 +77,14 @@ static TRUSTED_PROXIES: LazyLock<Vec<IpNetwork>> = LazyLock::new(|| {
});
#[derive(Debug)]
pub struct WSTunnelListener {
pub struct WsTunnelListener {
addr: url::Url,
listener: Option<TcpListener>,
}
impl WSTunnelListener {
impl WsTunnelListener {
pub fn new(addr: url::Url) -> Self {
WSTunnelListener {
WsTunnelListener {
addr,
listener: None,
}
@@ -159,7 +159,7 @@ impl WSTunnelListener {
}
#[async_trait::async_trait]
impl TunnelListener for WSTunnelListener {
impl TunnelListener for WsTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
@@ -199,16 +199,16 @@ impl TunnelListener for WSTunnelListener {
}
}
pub struct WSTunnelConnector {
pub struct WsTunnelConnector {
addr: url::Url,
ip_version: IpVersion,
bind_addrs: Vec<SocketAddr>,
}
impl WSTunnelConnector {
impl WsTunnelConnector {
pub fn new(addr: url::Url) -> Self {
WSTunnelConnector {
WsTunnelConnector {
addr,
ip_version: IpVersion::Both,
@@ -307,8 +307,8 @@ impl WSTunnelConnector {
}
#[async_trait::async_trait]
impl TunnelConnector for WSTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
impl TunnelConnector for WsTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if self.bind_addrs.is_empty() || addr.is_ipv6() {
self.connect_with_default_bind(addr).await
@@ -340,9 +340,9 @@ pub mod tests {
#[tokio::test]
#[serial_test::serial]
async fn ws_pingpong(#[values("ws", "wss")] proto: &str) {
let listener = WSTunnelListener::new(format!("{}://0.0.0.0:25556", proto).parse().unwrap());
let listener = WsTunnelListener::new(format!("{}://0.0.0.0:25556", proto).parse().unwrap());
let connector =
WSTunnelConnector::new(format!("{}://127.0.0.1:25556", proto).parse().unwrap());
WsTunnelConnector::new(format!("{}://127.0.0.1:25556", proto).parse().unwrap());
_tunnel_pingpong(listener, connector).await
}
@@ -350,9 +350,9 @@ pub mod tests {
#[tokio::test]
#[serial_test::serial]
async fn ws_pingpong_bind(#[values("ws", "wss")] proto: &str) {
let listener = WSTunnelListener::new(format!("{}://0.0.0.0:25557", proto).parse().unwrap());
let listener = WsTunnelListener::new(format!("{}://0.0.0.0:25557", proto).parse().unwrap());
let mut connector =
WSTunnelConnector::new(format!("{}://127.0.0.1:25557", proto).parse().unwrap());
WsTunnelConnector::new(format!("{}://127.0.0.1:25557", proto).parse().unwrap());
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
@@ -371,16 +371,16 @@ pub mod tests {
#[tokio::test]
async fn ws_accept_wss() {
let mut listener = WSTunnelListener::new("wss://0.0.0.0:25558".parse().unwrap());
let mut listener = WsTunnelListener::new("wss://0.0.0.0:25558".parse().unwrap());
listener.listen().await.unwrap();
let j = tokio::spawn(async move {
let _ = listener.accept().await;
});
let mut connector = WSTunnelConnector::new("ws://127.0.0.1:25558".parse().unwrap());
let mut connector = WsTunnelConnector::new("ws://127.0.0.1:25558".parse().unwrap());
connector.connect().await.unwrap_err();
let mut connector = WSTunnelConnector::new("wss://127.0.0.1:25558".parse().unwrap());
let mut connector = WsTunnelConnector::new("wss://127.0.0.1:25558".parse().unwrap());
connector.connect().await.unwrap();
j.abort();
@@ -388,7 +388,7 @@ pub mod tests {
#[tokio::test]
async fn ws_forwarded() {
let mut listener = WSTunnelListener::new("ws://127.0.0.1:25559".parse().unwrap());
let mut listener = WsTunnelListener::new("ws://127.0.0.1:25559".parse().unwrap());
listener.listen().await.unwrap();
let server_task = tokio::spawn(async move {
+20 -23
View File
@@ -20,7 +20,14 @@ use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use rand::RngCore;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use super::TunnelInfo;
use super::{
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
generate_digest_from_str,
packet_def::{ZCPacketType, PEER_MANAGER_HEADER_SIZE},
ring::create_ring_tunnel_pair,
FromUrl, IpVersion, Tunnel, TunnelError, TunnelInfo, TunnelListener, TunnelUrl, ZCPacketSink,
ZCPacketStream,
};
use crate::{
common::shrink_dashmap,
tunnel::{
@@ -30,15 +37,6 @@ use crate::{
},
};
use super::{
check_scheme_and_get_socket_addr,
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
generate_digest_from_str,
packet_def::{ZCPacketType, PEER_MANAGER_HEADER_SIZE},
ring::create_ring_tunnel_pair,
IpVersion, Tunnel, TunnelError, TunnelListener, TunnelUrl, ZCPacketSink, ZCPacketStream,
};
const MAX_PACKET: usize = 2048;
#[derive(Debug, Clone)]
@@ -202,7 +200,10 @@ impl WgPeerData {
match self.udp.send_to(packet, self.endpoint).await {
Ok(_) => {}
Err(e) => {
tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
tracing::error!(
"Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}",
e
);
return;
}
};
@@ -214,7 +215,10 @@ impl WgPeerData {
match self.udp.send_to(packet, self.endpoint).await {
Ok(_) => {}
Err(e) => {
tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
tracing::error!(
"Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}",
e
);
break;
}
};
@@ -550,10 +554,8 @@ impl WgTunnelListener {
#[async_trait]
impl TunnelListener for WgTunnelListener {
async fn listen(&mut self) -> Result<(), super::TunnelError> {
let addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg", IpVersion::Both)
.await?;
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
@@ -705,13 +707,8 @@ impl WgTunnelConnector {
#[async_trait]
impl super::TunnelConnector for WgTunnelConnector {
#[tracing::instrument]
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"wg",
self.ip_version,
)
.await?;
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
if addr.is_ipv6() {
return self.connect_with_ipv6(addr).await;
+1 -2
View File
@@ -36,8 +36,7 @@ thread_local! {
}
pub fn setup_panic_handler() {
use std::backtrace;
use std::io::Write;
use std::{backtrace, io::Write};
std::panic::set_hook(Box::new(|info| {
let mut stderr = std::io::stderr();
let sep = format!("{}\n", "=======".repeat(10));