refactor: listener/connector protocol abstraction (#2026)

* fix listener protocol detection
* replace IpProtocol with IpNextHeaderProtocol
* use an enum to gather all listener schemes
* rename ListenerScheme to TunnelScheme; replace IpNextHeaderProtocols with socket2::Protocol
* move TunnelScheme to tunnel
* add IpScheme, simplify connector creation
* format; fix some typos; remove check_scheme_...;
* remove PROTO_PORT_OFFSET
* rename WSTunnel.. -> WsTunnel.., DNSTunnel.. -> DnsTunnel..
This commit is contained in:
Luna Yao
2026-04-04 04:55:58 +02:00
committed by GitHub
parent 9cc617ae4c
commit e91a0da70a
18 changed files with 481 additions and 526 deletions
+23 -23
View File
@@ -31,19 +31,20 @@ use crate::{
},
rpc_types::controller::BaseController,
},
tunnel::{udp::UdpTunnelConnector, IpVersion},
tunnel::{matches_protocol, udp::UdpTunnelConnector, IpVersion},
use_global_var,
};
use anyhow::Context;
use rand::Rng;
use tokio::{net::UdpSocket, task::JoinSet, time::timeout};
use url::Host;
use super::{
create_connector_by_url, should_background_p2p_with_peer, should_try_p2p_with_peer,
udp_hole_punch,
};
use crate::tunnel::{matches_scheme, FromUrl, IpScheme, TunnelScheme};
use anyhow::Context;
use rand::Rng;
use socket2::Protocol;
use tokio::{net::UdpSocket, task::JoinSet, time::timeout};
use url::Host;
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300;
@@ -189,9 +190,7 @@ impl DirectConnectorManagerData {
.await;
let udp_connector = UdpTunnelConnector::new(remote_url.clone());
let remote_addr =
super::check_scheme_and_get_socket_addr::<SocketAddr>(remote_url, "udp", IpVersion::V6)
.await?;
let remote_addr = SocketAddr::from_url(remote_url.clone(), IpVersion::V6).await?;
let ret = udp_connector
.try_connect_with_socket(local_socket, remote_addr)
.await?;
@@ -205,18 +204,19 @@ impl DirectConnectorManagerData {
async fn do_try_connect_to_ip(&self, dst_peer_id: PeerId, addr: String) -> Result<(), Error> {
let connector = create_connector_by_url(&addr, &self.global_ctx, IpVersion::Both).await?;
let remote_url = connector.remote_url();
let (peer_id, conn_id) =
if remote_url.scheme() == "udp" && matches!(remote_url.host(), Some(Host::Ipv6(_))) {
self.connect_to_public_ipv6(dst_peer_id, &remote_url)
.await?
} else {
timeout(
std::time::Duration::from_secs(3),
self.peer_manager
.try_direct_connect_with_peer_id_hint(connector, Some(dst_peer_id)),
)
.await??
};
let (peer_id, conn_id) = if matches_scheme!(remote_url, TunnelScheme::Ip(IpScheme::Udp))
&& matches!(remote_url.host(), Some(Host::Ipv6(_)))
{
self.connect_to_public_ipv6(dst_peer_id, &remote_url)
.await?
} else {
timeout(
std::time::Duration::from_secs(3),
self.peer_manager
.try_direct_connect_with_peer_id_hint(connector, Some(dst_peer_id)),
)
.await??
};
if peer_id != dst_peer_id && !TESTING.load(Ordering::Relaxed) {
tracing::info!(
@@ -306,7 +306,7 @@ impl DirectConnectorManagerData {
let listener_host = addrs.pop();
tracing::info!(?listener_host, ?listener, "try direct connect to peer");
let is_udp = matches!(listener.scheme(), "udp" | "wg");
let is_udp = matches_protocol!(listener, Protocol::UDP);
// Snapshot running listeners once; used for cheap port pre-checks before the
// expensive should_deny_proxy call (which binds a socket per IP) in the
// unspecified-address expansion loops below.
@@ -314,7 +314,7 @@ impl DirectConnectorManagerData {
let port_has_local_listener = |port: u16| -> bool {
local_listeners
.iter()
.any(|l| l.port() == Some(port) && (matches!(l.scheme(), "udp" | "wg") == is_udp))
.any(|l| l.port() == Some(port) && matches_protocol!(l, Protocol::UDP) == is_udp)
};
match listener_host {
+38 -40
View File
@@ -1,5 +1,6 @@
use std::{net::SocketAddr, sync::Arc};
use super::{create_connector_by_url, http_connector::TunnelWithInfo};
use crate::{
common::{
dns::{resolve_txt_record, RESOLVER},
@@ -7,16 +8,15 @@ use crate::{
global_ctx::ArcGlobalCtx,
log,
},
tunnel::{IpVersion, Tunnel, TunnelConnector, TunnelError, PROTO_PORT_OFFSET},
proto::common::TunnelInfo,
tunnel::{IpScheme, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelScheme},
};
use anyhow::Context;
use dashmap::DashSet;
use hickory_resolver::proto::rr::rdata::SRV;
use itertools::Itertools;
use rand::{seq::SliceRandom, Rng as _};
use crate::proto::common::TunnelInfo;
use super::{create_connector_by_url, http_connector::TunnelWithInfo};
use strum::VariantArray;
fn weighted_choice<T>(options: &[(T, u64)]) -> Option<&T> {
let total_weight = options.iter().map(|(_, weight)| *weight).sum();
@@ -35,16 +35,18 @@ fn weighted_choice<T>(options: &[(T, u64)]) -> Option<&T> {
}
#[derive(Debug)]
pub struct DNSTunnelConnector {
pub struct DnsTunnelConnector {
scheme: TunnelScheme,
addr: url::Url,
bind_addrs: Vec<SocketAddr>,
global_ctx: ArcGlobalCtx,
ip_version: IpVersion,
}
impl DNSTunnelConnector {
impl DnsTunnelConnector {
pub fn new(addr: url::Url, global_ctx: ArcGlobalCtx) -> Self {
Self {
scheme: (&addr).try_into().unwrap(),
addr,
bind_addrs: Vec::new(),
global_ctx,
@@ -82,7 +84,7 @@ impl DNSTunnelConnector {
Ok(connector)
}
fn handle_one_srv_record(record: &SRV, protocol: &str) -> Result<(url::Url, u64), Error> {
fn handle_one_srv_record(record: &SRV, protocol: IpScheme) -> Result<(url::Url, u64), Error> {
// port must be non-zero
if record.port() == 0 {
return Err(anyhow::anyhow!("port must be non-zero").into());
@@ -112,15 +114,15 @@ impl DNSTunnelConnector {
) -> Result<Box<dyn TunnelConnector>, Error> {
tracing::info!("handle_srv_record: {}", domain_name);
let srv_domains = PROTO_PORT_OFFSET
let srv_domains = IpScheme::VARIANTS
.iter()
.map(|(p, _)| (format!("_easytier._{}.{}", p, domain_name), *p)) // _easytier._udp.{domain_name}
.collect::<Vec<_>>();
.map(|s| (s, format!("_easytier._{}.{}", s, domain_name)))
.collect_vec();
tracing::info!("build srv_domains: {:?}", srv_domains);
let responses = Arc::new(DashSet::new());
let srv_lookup_tasks = srv_domains
.iter()
.map(|(srv_domain, protocol)| {
.map(|(protocol, srv_domain)| {
let resolver = RESOLVER.clone();
let responses = responses.clone();
async move {
@@ -129,7 +131,7 @@ impl DNSTunnelConnector {
})?;
tracing::info!(?response, ?srv_domain, "srv_lookup response");
for record in response.iter() {
let parsed_record = Self::handle_one_srv_record(record, protocol);
let parsed_record = Self::handle_one_srv_record(record, **protocol);
tracing::info!(?parsed_record, ?srv_domain, "parsed_record");
if let Err(e) = &parsed_record {
log::warn!("got invalid srv record {:?}", e);
@@ -162,32 +164,28 @@ impl DNSTunnelConnector {
}
#[async_trait::async_trait]
impl super::TunnelConnector for DNSTunnelConnector {
impl super::TunnelConnector for DnsTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let mut conn = if self.addr.scheme() == "txt" {
self.handle_txt_record(
self.addr
.host_str()
.as_ref()
.ok_or(anyhow::anyhow!("host should not be empty in txt url"))?,
)
.await
.with_context(|| "get txt record url failed")?
} else if self.addr.scheme() == "srv" {
self.handle_srv_record(
self.addr
.host_str()
.as_ref()
.ok_or(anyhow::anyhow!("host should not be empty in srv url"))?,
)
.await
.with_context(|| "get srv record url failed")?
} else {
return Err(anyhow::anyhow!(
"unsupported dns scheme: {}, expecting txt or srv",
self.addr.scheme()
)
.into());
let mut conn = match self.scheme {
TunnelScheme::Txt => self
.handle_txt_record(
self.addr
.host_str()
.as_ref()
.ok_or(anyhow::anyhow!("host should not be empty in txt url"))?,
)
.await
.with_context(|| "get txt record url failed")?,
TunnelScheme::Srv => self
.handle_srv_record(
self.addr
.host_str()
.as_ref()
.ok_or(anyhow::anyhow!("host should not be empty in srv url"))?,
)
.await
.with_context(|| "get srv record url failed")?,
_ => return Err(anyhow::anyhow!("unsupported dns scheme: {:?}", self.scheme).into()),
};
let t = conn.connect().await?;
let info = t.info().unwrap_or_default();
@@ -227,7 +225,7 @@ mod tests {
async fn test_txt() {
let url = "txt://txt.easytier.cn";
let global_ctx = get_mock_global_ctx();
let mut connector = DNSTunnelConnector::new(url.parse().unwrap(), global_ctx);
let mut connector = DnsTunnelConnector::new(url.parse().unwrap(), global_ctx);
connector.set_ip_version(IpVersion::V4);
for _ in 0..5 {
match connector.connect().await {
@@ -246,7 +244,7 @@ mod tests {
async fn test_srv() {
let url = "srv://easytier.cn";
let global_ctx = get_mock_global_ctx();
let mut connector = DNSTunnelConnector::new(url.parse().unwrap(), global_ctx);
let mut connector = DnsTunnelConnector::new(url.parse().unwrap(), global_ctx);
connector.set_ip_version(IpVersion::V4);
for _ in 0..5 {
match connector.connect().await {
+41 -116
View File
@@ -3,24 +3,17 @@ use std::{
sync::Arc,
};
use http_connector::HttpTunnelConnector;
#[cfg(feature = "faketcp")]
use crate::tunnel::fake_tcp::FakeTcpTunnelConnector;
#[cfg(feature = "quic")]
use crate::tunnel::quic::QUICTunnelConnector;
#[cfg(unix)]
use crate::tunnel::unix::UnixSocketTunnelConnector;
#[cfg(feature = "wireguard")]
use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, idn, network::IPCollector},
connector::dns_connector::DnsTunnelConnector,
proto::common::PeerFeatureFlag,
tunnel::{
check_scheme_and_get_socket_addr, ring::RingTunnelConnector, tcp::TcpTunnelConnector,
udp::UdpTunnelConnector, IpVersion, TunnelConnector,
self, ring::RingTunnelConnector, tcp::TcpTunnelConnector, udp::UdpTunnelConnector, FromUrl,
IpScheme, IpVersion, TunnelConnector, TunnelError, TunnelScheme,
},
utils::BoxExt,
};
use http_connector::HttpTunnelConnector;
pub mod direct;
pub mod manual;
@@ -90,84 +83,34 @@ pub async fn create_connector_by_url(
) -> Result<Box<dyn TunnelConnector + 'static>, Error> {
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
let url = idn::convert_idn_to_ascii(url)?;
let mut connector: Box<dyn TunnelConnector + 'static> = match url.scheme() {
"tcp" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp", ip_version).await?;
let mut connector = TcpTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
"udp" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp", ip_version).await?;
let mut connector = UdpTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
"http" | "https" => {
let connector = HttpTunnelConnector::new(url, global_ctx.clone());
Box::new(connector)
}
"ring" => {
check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring", IpVersion::Both).await?;
let connector = RingTunnelConnector::new(url);
Box::new(connector)
}
#[cfg(feature = "quic")]
"quic" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "quic", ip_version).await?;
let mut connector = QUICTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
#[cfg(feature = "wireguard")]
"wg" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg", ip_version).await?;
let nid = global_ctx.get_network_identity();
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
&nid.network_secret.unwrap_or_default(),
);
let mut connector = WgTunnelConnector::new(url, wg_config);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
#[cfg(feature = "websocket")]
"ws" | "wss" => {
use crate::tunnel::FromUrl;
let scheme = (&url)
.try_into()
.map_err(|_| TunnelError::InvalidProtocol(url.scheme().to_owned()))?;
let mut connector: Box<dyn TunnelConnector + 'static> = match scheme {
TunnelScheme::Ip(scheme) => {
let dst_addr = SocketAddr::from_url(url.clone(), ip_version).await?;
let mut connector = crate::tunnel::websocket::WSTunnelConnector::new(url);
let mut connector: Box<dyn TunnelConnector> = match scheme {
IpScheme::Tcp => TcpTunnelConnector::new(url).boxed(),
IpScheme::Udp => UdpTunnelConnector::new(url).boxed(),
#[cfg(feature = "quic")]
IpScheme::Quic => tunnel::quic::QuicTunnelConnector::new(url).boxed(),
#[cfg(feature = "wireguard")]
IpScheme::Wg => {
use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
let nid = global_ctx.get_network_identity();
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
&nid.network_secret.unwrap_or_default(),
);
WgTunnelConnector::new(url, wg_config).boxed()
}
#[cfg(feature = "websocket")]
IpScheme::Ws | IpScheme::Wss => {
tunnel::websocket::WsTunnelConnector::new(url).boxed()
}
#[cfg(feature = "faketcp")]
IpScheme::FakeTcp => tunnel::fake_tcp::FakeTcpTunnelConnector::new(url).boxed(),
};
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
@@ -176,40 +119,22 @@ pub async fn create_connector_by_url(
)
.await;
}
Box::new(connector)
connector
}
"txt" | "srv" => {
#[cfg(unix)]
TunnelScheme::Unix => tunnel::unix::UnixSocketTunnelConnector::new(url).boxed(),
TunnelScheme::Http | TunnelScheme::Https => {
HttpTunnelConnector::new(url, global_ctx.clone()).boxed()
}
TunnelScheme::Ring => RingTunnelConnector::new(url).boxed(),
TunnelScheme::Txt | TunnelScheme::Srv => {
if url.host_str().is_none() {
return Err(Error::InvalidUrl(format!(
"host should not be empty in txt or srv url: {}",
url
)));
}
let connector = dns_connector::DNSTunnelConnector::new(url, global_ctx.clone());
Box::new(connector)
}
#[cfg(feature = "faketcp")]
"faketcp" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "faketcp", ip_version).await?;
let mut connector = FakeTcpTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
#[cfg(unix)]
"unix" => {
let connector = UnixSocketTunnelConnector::new(url);
Box::new(connector)
}
_ => {
return Err(Error::InvalidUrl(url.into()));
DnsTunnelConnector::new(url, global_ctx.clone()).boxed()
}
};
connector.set_ip_version(ip_version);