refactor: listener/connector protocol abstraction (#2026)

* fix listener protocol detection
* replace IpProtocol with IpNextHeaderProtocol
* use an enum to gather all listener schemes
* rename ListenerScheme to TunnelScheme; replace IpNextHeaderProtocols with socket2::Protocol
* move TunnelScheme to tunnel
* add IpScheme, simplify connector creation
* format; fix some typos; remove check_scheme_...;
* remove PROTO_PORT_OFFSET
* rename WSTunnel.. -> WsTunnel.., DNSTunnel.. -> DnsTunnel..
This commit is contained in:
Luna Yao
2026-04-04 04:55:58 +02:00
committed by GitHub
parent 9cc617ae4c
commit e91a0da70a
18 changed files with 481 additions and 526 deletions
+140 -52
View File
@@ -1,16 +1,19 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;
use std::{net::SocketAddr, pin::Pin, sync::Arc};
use std::{
collections::hash_map::DefaultHasher, hash::Hasher, net::SocketAddr, pin::Pin, sync::Arc,
};
use crate::{
common::{dns::socket_addrs, error::Error},
proto::common::TunnelInfo,
};
use async_trait::async_trait;
use derive_more::{From, TryInto};
use futures::{Sink, Stream};
use socket2::Protocol;
use std::fmt::Debug;
use strum::{Display, EnumString, VariantArray};
use tokio::time::error::Elapsed;
use crate::common::dns::socket_addrs;
use crate::proto::common::TunnelInfo;
use self::packet_def::ZCPacket;
pub mod buf;
@@ -23,15 +26,6 @@ pub mod stats;
pub mod tcp;
pub mod udp;
pub const PROTO_PORT_OFFSET: &[(&str, u16)] = &[
("tcp", 0),
("udp", 0),
("wg", 1),
("ws", 1),
("wss", 2),
("faketcp", 3),
];
#[cfg(feature = "faketcp")]
pub mod fake_tcp;
@@ -193,45 +187,23 @@ pub(crate) trait FromUrl {
Self: Sized;
}
pub(crate) async fn check_scheme_and_get_socket_addr<T>(
url: &url::Url,
scheme: &str,
ip_version: IpVersion,
) -> Result<T, TunnelError>
where
T: FromUrl,
{
if url.scheme() != scheme {
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
}
T::from_url(url.clone(), ip_version).await
}
fn default_port(scheme: &str) -> Option<u16> {
match scheme {
"tcp" => Some(11010),
"udp" => Some(11010),
"ws" => Some(80),
"wss" => Some(443),
"faketcp" => Some(11013),
"quic" => Some(11012),
"wg" => Some(11011),
_ => None,
}
}
#[async_trait::async_trait]
impl FromUrl for SocketAddr {
async fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError> {
let addrs = socket_addrs(&url, || default_port(url.scheme()))
.await
.map_err(|e| {
TunnelError::InvalidAddr(format!(
"failed to resolve socket addr, url: {}, error: {}",
url, e
))
})?;
let addrs = socket_addrs(&url, || {
(&url)
.try_into()
.ok()
.and_then(|s: TunnelScheme| s.try_into().ok())
.map(IpScheme::default_port)
})
.await
.map_err(|e| {
TunnelError::InvalidAddr(format!(
"failed to resolve socket addr, url: {}, error: {}",
url, e
))
})?;
tracing::debug!(?addrs, ?ip_version, ?url, "convert url to socket addrs");
let addrs = addrs
.into_iter()
@@ -305,3 +277,119 @@ pub fn generate_digest_from_str(str1: &str, str2: &str, digest: &mut [u8]) {
hasher.write(&digest[..(i + 1) * 8]);
}
}
#[derive(Debug, Clone, Copy)]
struct IpSchemeAttributes {
protocol: Protocol,
port_offset: u16,
}
#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, VariantArray)]
#[strum(serialize_all = "lowercase")]
pub enum IpScheme {
Tcp,
Udp,
#[cfg(feature = "wireguard")]
Wg,
#[cfg(feature = "quic")]
Quic,
#[cfg(feature = "websocket")]
Ws,
#[cfg(feature = "websocket")]
Wss,
#[cfg(feature = "faketcp")]
FakeTcp,
}
impl IpScheme {
const fn attributes(self) -> IpSchemeAttributes {
let (protocol, port_offset) = match self {
Self::Tcp => (Protocol::TCP, 0),
Self::Udp => (Protocol::UDP, 0),
#[cfg(feature = "wireguard")]
Self::Wg => (Protocol::UDP, 1),
#[cfg(feature = "quic")]
Self::Quic => (Protocol::UDP, 2),
#[cfg(feature = "websocket")]
Self::Ws => (Protocol::TCP, 1),
#[cfg(feature = "websocket")]
Self::Wss => (Protocol::TCP, 2),
#[cfg(feature = "faketcp")]
Self::FakeTcp => (Protocol::TCP, 3),
};
IpSchemeAttributes {
protocol,
port_offset,
}
}
pub const fn protocol(self) -> Protocol {
self.attributes().protocol
}
pub const fn port_offset(self) -> u16 {
self.attributes().port_offset
}
pub const fn default_port(self) -> u16 {
match self {
#[cfg(feature = "websocket")]
Self::Ws => 80,
#[cfg(feature = "websocket")]
Self::Wss => 443,
_ => 11010 + self.port_offset(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, EnumString, From, TryInto)]
#[strum(serialize_all = "lowercase")]
pub enum TunnelScheme {
#[strum(disabled)]
Ip(IpScheme),
#[cfg(unix)]
Unix,
// Only for connector
Http,
Https,
Ring,
Txt,
Srv,
}
impl TryFrom<&url::Url> for TunnelScheme {
type Error = Error;
fn try_from(value: &url::Url) -> Result<Self, Self::Error> {
let scheme = value.scheme();
scheme.parse().or_else(|_| {
Ok(TunnelScheme::Ip(
scheme
.parse()
.map_err(|_| Error::InvalidUrl(value.to_string()))?,
))
})
}
}
macro_rules! __matches_scheme__ {
($url:expr, $( $pattern:pat_param )|+ ) => {
matches!($crate::tunnel::TunnelScheme::try_from(($url).as_ref()), Ok($( $pattern )|+))
};
}
pub(crate) use __matches_scheme__ as matches_scheme;
pub fn get_protocol_by_url(l: &url::Url) -> Result<Protocol, Error> {
let TunnelScheme::Ip(scheme) = l.try_into()? else {
return Err(Error::InvalidUrl(l.to_string()));
};
Ok(scheme.protocol())
}
macro_rules! __matches_protocol__ {
($url:expr, $( $pattern:pat_param )|+ ) => {
matches!($crate::tunnel::get_protocol_by_url($url), Ok($( $pattern )|+))
};
}
pub(crate) use __matches_protocol__ as matches_protocol;