From cd2cf563586e5017cf5ef86530c0f1e83fca4ecd Mon Sep 17 00:00:00 2001 From: Luna Yao <40349250+ZnqbuZ@users.noreply.github.com> Date: Mon, 2 Feb 2026 04:53:40 +0100 Subject: [PATCH] refactor: handle quic proxy internally instead of use external udp port (#1743) * deprecate quic_listen_port, add disable_relay_quic and enable_relay_foreign_network_quic * add set_src_modified to TcpProxyForWrappedSrcTrait * prioritize quic over kcp --- Cargo.lock | 62 +- easytier/Cargo.toml | 12 +- easytier/locales/app.yml | 9 +- easytier/src/common/config.rs | 13 +- easytier/src/common/global_ctx.rs | 14 +- easytier/src/core.rs | 33 +- easytier/src/gateway/kcp_proxy.rs | 4 + easytier/src/gateway/quic_proxy.rs | 1541 +++++++++++++---- easytier/src/gateway/wrapped_proxy.rs | 11 +- easytier/src/instance/instance.rs | 70 +- easytier/src/launcher.rs | 5 - easytier/src/peers/acl_filter.rs | 50 +- easytier/src/peers/foreign_network_manager.rs | 1 + easytier/src/peers/peer_manager.rs | 48 + easytier/src/peers/peer_ospf_route.rs | 7 +- easytier/src/proto/api_manage.proto | 2 +- easytier/src/proto/common.proto | 10 +- easytier/src/proto/peer_rpc.proto | 2 +- easytier/src/tests/three_node.rs | 31 +- easytier/src/tunnel/packet_def.rs | 20 + easytier/src/tunnel/quic.rs | 4 +- 21 files changed, 1419 insertions(+), 530 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8053697f..6a638162 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -491,6 +491,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atomic_refcell" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c" + [[package]] name = "auto_impl" version = "1.2.1" @@ -1499,6 +1505,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +[[package]] +name = "convert_case" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "cookie" version = "0.18.1" @@ -1918,6 +1933,17 @@ dependencies = [ "serde", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "derive-new" version = "0.6.0" @@ -1977,13 +2003,36 @@ version = "0.99.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" dependencies = [ - "convert_case", + "convert_case 0.4.0", "proc-macro2", "quote", "rustc_version", "syn 2.0.87", ] +[[package]] +name = "derive_more" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb" +dependencies = [ + "convert_case 0.10.0", + "proc-macro2", + "quote", + "rustc_version", + "syn 2.0.87", + "unicode-xid", +] + [[package]] name = "digest" version = "0.10.7" @@ -2153,6 +2202,7 @@ dependencies = [ "async-stream", "async-trait", "atomic-shim", + "atomic_refcell", "auto_impl", "base64 0.22.1", "bitflags 2.8.0", @@ -2171,7 +2221,9 @@ dependencies = [ "dashmap", "dbus", "defguard_wireguard_rs", + "derivative", "derive_builder", + "derive_more 2.1.1", "easytier-rpc-build", "encoding", "flume 0.12.0", @@ -7567,7 +7619,7 @@ checksum = "0c37578180969d00692904465fb7f6b3d50b9a2b952b87c23d0e2e5cb5013416" dependencies = [ "bitflags 1.3.2", "cssparser", - "derive_more", + "derive_more 0.99.18", "fxhash", "log", "phf 0.8.0", @@ -9855,6 +9907,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "unicode_categories" version = "0.1.1" diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 82cd82ad..4b33ee5f 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -36,6 +36,8 @@ tracing-subscriber = { version = "0.3", features = [ "local-time", "time", ] } +derivative = "2.2.0" +derive_more = {version = "2.1.1", features = ["full"]} console-subscriber = { version = "0.4.1", optional = true } thiserror = "1.0" auto_impl = "1.1.0" @@ -64,8 +66,10 @@ zerocopy = { version = "0.7.32", features = ["derive", "simd"] } bytes = "1.5.0" pin-project-lite = "0.2.13" +atomic_refcell = "0.1.13" + quinn = { version = "0.11.8", optional = true, features = ["ring"] } -quinn-plaintext = { version = "0.3.0", optional = true} +quinn-plaintext = { version = "0.3.0", optional = true } rustls = { version = "0.23.0", features = [ "ring", @@ -86,7 +90,7 @@ http = { version = "1", default-features = false, features = [ tokio-rustls = { version = "0.26", default-features = false, optional = true } # for tap device -tun = { package = "tun-easytier", git="https://github.com/EasyTier/rust-tun", features = [ +tun = { package = "tun-easytier", git = "https://github.com/EasyTier/rust-tun", features = [ "async", ], optional = true } # for net ns @@ -263,10 +267,10 @@ winreg = "0.52" windows-service = "0.7.0" windows-sys = { version = "0.52", features = [ "Win32_NetworkManagement_IpHelper", - "Win32_NetworkManagement_Ndis", + "Win32_NetworkManagement_Ndis", "Win32_Networking_WinSock", "Win32_Foundation" -]} +] } winapi = { version = "0.3.9", features = ["impl-default"] } [target.'cfg(not(windows))'.dependencies] diff --git a/easytier/locales/app.yml b/easytier/locales/app.yml index eeb0c21a..dd3397ef 100644 --- a/easytier/locales/app.yml +++ b/easytier/locales/app.yml @@ -196,9 +196,6 @@ core_clap: disable_quic_input: en: "do not allow other nodes to use QUIC to proxy tcp streams to this node. when a node with QUIC proxy enabled accesses this node, the original tcp connection is preserved." zh-CN: "不允许其他节点使用 QUIC 代理 TCP 流到此节点。开启 QUIC 代理的节点访问此节点时,依然使用原始 TCP 连接。" - quic_listen_port: - en: "the port to listen for quic connections, default is 0 (random port)" - zh-CN: "监听 QUIC 连接的端口,默认值为0(随机端口)。" port_forward: en: "forward local port to remote port in virtual network. e.g.: udp://0.0.0.0:12345/10.126.126.1:23456, means forward local udp port 12345 to 10.126.126.1:23456 in the virtual network. can specify multiple." zh-CN: "将本地端口转发到虚拟网络中的远程端口。例如:udp://0.0.0.0:12345/10.126.126.1:23456,表示将本地UDP端口12345转发到虚拟网络中的10.126.126.1:23456。可以指定多个。" @@ -223,9 +220,15 @@ core_clap: disable_relay_kcp: en: "if true, disable relay kcp packets. avoid consuming too many bandwidth. default is false" zh-CN: "如果为true,则禁止节点转发 KCP 数据包,防止过度消耗流量。默认值为false" + disable_relay_quic: + en: "if true, disable relay quic packets. avoid consuming too many bandwidth. default is false" + zh-CN: "如果为true,则禁止节点转发 QUIC 数据包,防止过度消耗流量。默认值为false" enable_relay_foreign_network_kcp: en: "if true, allow relay kcp packets from foreign network. default is false (not forward foreign network kcp packets)" zh-CN: "如果为true,则作为共享节点时也可以转发其他网络的 KCP 数据包。默认值为false(不转发)" + enable_relay_foreign_network_quic: + en: "if true, allow relay quic packets from foreign network. default is false (not forward foreign network quic packets)" + zh-CN: "如果为true,则作为共享节点时也可以转发其他网络的 QUIC 数据包。默认值为false(不转发)" stun_servers: en: "Override default STUN servers; If configured but empty, STUN servers are not used" zh-CN: "覆盖内置的默认 STUN server 列表;如果设置了但是为空,则不使用 STUN servers;如果没设置,则使用默认 STUN server 列表" diff --git a/easytier/src/common/config.rs b/easytier/src/common/config.rs index c0f86105..e6f07fd2 100644 --- a/easytier/src/common/config.rs +++ b/easytier/src/common/config.rs @@ -24,6 +24,7 @@ use super::env_parser; pub type Flags = crate::proto::common::FlagsInConfig; pub fn gen_default_flags() -> Flags { + #[allow(deprecated)] Flags { default_protocol: "tcp".to_string(), dev_name: "".to_string(), @@ -52,12 +53,15 @@ pub fn gen_default_flags() -> Flags { private_mode: false, enable_quic_proxy: false, disable_quic_input: false, - quic_listen_port: 0, + disable_relay_quic: false, + enable_relay_foreign_network_quic: false, foreign_relay_bps_limit: u64::MAX, multi_thread_count: 2, encryption_algorithm: "aes-gcm".to_string(), disable_sym_hole_punching: false, tld_dns_zone: DEFAULT_ET_DNS_ZONE.to_string(), + + quic_listen_port: u32::MAX, } } @@ -1584,7 +1588,6 @@ enable_ipv6 = ${ENABLE_IPV6} async fn test_numeric_type_env_vars() { // 设置数字类型的环境变量 std::env::set_var("MTU_VALUE", "1400"); - std::env::set_var("QUIC_PORT", "8080"); std::env::set_var("THREAD_COUNT", "4"); let mut temp_file = NamedTempFile::new().unwrap(); @@ -1597,7 +1600,6 @@ network_secret = "secret" [flags] mtu = ${MTU_VALUE} -quic_listen_port = ${QUIC_PORT} multi_thread_count = ${THREAD_COUNT} "#; temp_file.write_all(config_content.as_bytes()).unwrap(); @@ -1611,10 +1613,6 @@ multi_thread_count = ${THREAD_COUNT} // 验证数字值被正确解析 let flags = config.get_flags(); assert_eq!(flags.mtu, 1400, "mtu should be 1400"); - assert_eq!( - flags.quic_listen_port, 8080, - "quic_listen_port should be 8080" - ); assert_eq!( flags.multi_thread_count, 4, "multi_thread_count should be 4" @@ -1626,7 +1624,6 @@ multi_thread_count = ${THREAD_COUNT} // 清理 std::env::remove_var("MTU_VALUE"); - std::env::remove_var("QUIC_PORT"); std::env::remove_var("THREAD_COUNT"); } diff --git a/easytier/src/common/global_ctx.rs b/easytier/src/common/global_ctx.rs index 360e113d..d8df3d83 100644 --- a/easytier/src/common/global_ctx.rs +++ b/easytier/src/common/global_ctx.rs @@ -92,8 +92,6 @@ pub struct GlobalCtx { feature_flags: AtomicCell, - quic_proxy_port: AtomicCell>, - token_bucket_manager: TokenBucketManager, stats_manager: Arc, @@ -149,6 +147,8 @@ impl GlobalCtx { kcp_input: !config_fs.get_flags().disable_kcp_input, no_relay_kcp: config_fs.get_flags().disable_relay_kcp, support_conn_list_sync: true, // Enable selective peer list sync by default + quic_input: !config_fs.get_flags().disable_quic_input, + no_relay_quic: config_fs.get_flags().disable_relay_quic, ..Default::default() }; @@ -181,7 +181,6 @@ impl GlobalCtx { p2p_only, feature_flags: AtomicCell::new(feature_flags), - quic_proxy_port: AtomicCell::new(None), token_bucket_manager: TokenBucketManager::new(), @@ -393,15 +392,6 @@ impl GlobalCtx { self.feature_flags.store(flags); } - pub fn get_quic_proxy_port(&self) -> Option { - self.quic_proxy_port.load() - } - - pub fn set_quic_proxy_port(&self, port: Option) { - self.acl_filter.set_quic_udp_port(port.unwrap_or(0)); - self.quic_proxy_port.store(port); - } - pub fn token_bucket_manager(&self) -> &TokenBucketManager { &self.token_bucket_manager } diff --git a/easytier/src/core.rs b/easytier/src/core.rs index 55dce023..fc11b7e0 100644 --- a/easytier/src/core.rs +++ b/easytier/src/core.rs @@ -507,14 +507,6 @@ struct NetworkOptions { )] disable_quic_input: Option, - #[arg( - long, - env = "ET_QUIC_LISTEN_PORT", - help = t!("core_clap.quic_listen_port").to_string(), - num_args = 0..=1, - )] - quic_listen_port: Option, - #[arg( long, env = "ET_PORT_FORWARD", @@ -576,6 +568,15 @@ struct NetworkOptions { )] disable_relay_kcp: Option, + #[arg( + long, + env = "ET_DISABLE_RELAY_QUIC", + help = t!("core_clap.disable_relay_quic").to_string(), + num_args = 0..=1, + default_missing_value = "true" + )] + disable_relay_quic: Option, + #[arg( long, env = "ET_ENABLE_RELAY_FOREIGN_NETWORK_KCP", @@ -585,6 +586,15 @@ struct NetworkOptions { )] enable_relay_foreign_network_kcp: Option, + #[arg( + long, + env = "ET_ENABLE_RELAY_FOREIGN_NETWORK_QUIC", + help = t!("core_clap.enable_relay_foreign_network_quic").to_string(), + num_args = 0..=1, + default_missing_value = "true" + )] + enable_relay_foreign_network_quic: Option, + #[arg( long, env = "ET_STUN_SERVERS", @@ -1030,9 +1040,6 @@ impl NetworkOptions { f.disable_kcp_input = self.disable_kcp_input.unwrap_or(f.disable_kcp_input); f.enable_quic_proxy = self.enable_quic_proxy.unwrap_or(f.enable_quic_proxy); f.disable_quic_input = self.disable_quic_input.unwrap_or(f.disable_quic_input); - if let Some(quic_listen_port) = self.quic_listen_port { - f.quic_listen_port = quic_listen_port as u32; - } f.accept_dns = self.accept_dns.unwrap_or(f.accept_dns); f.private_mode = self.private_mode.unwrap_or(f.private_mode); f.foreign_relay_bps_limit = self @@ -1040,9 +1047,13 @@ impl NetworkOptions { .unwrap_or(f.foreign_relay_bps_limit); f.multi_thread_count = self.multi_thread_count.unwrap_or(f.multi_thread_count); f.disable_relay_kcp = self.disable_relay_kcp.unwrap_or(f.disable_relay_kcp); + f.disable_relay_quic = self.disable_relay_quic.unwrap_or(f.disable_relay_quic); f.enable_relay_foreign_network_kcp = self .enable_relay_foreign_network_kcp .unwrap_or(f.enable_relay_foreign_network_kcp); + f.enable_relay_foreign_network_quic = self + .enable_relay_foreign_network_quic + .unwrap_or(f.enable_relay_foreign_network_quic); f.disable_sym_hole_punching = self.disable_sym_hole_punching.unwrap_or(false); // Configure tld_dns_zone: use provided value if set if let Some(tld_dns_zone) = &self.tld_dns_zone { diff --git a/easytier/src/gateway/kcp_proxy.rs b/easytier/src/gateway/kcp_proxy.rs index 18b3e443..a9ee61bb 100644 --- a/easytier/src/gateway/kcp_proxy.rs +++ b/easytier/src/gateway/kcp_proxy.rs @@ -213,6 +213,10 @@ impl TcpProxyForWrappedSrcTrait for TcpProxyForKcpSrc { &self.0 } + fn set_src_modified(hdr: &mut PeerManagerHeader, modified: bool) -> &mut PeerManagerHeader { + hdr.set_kcp_src_modified(modified) + } + async fn check_dst_allow_wrapped_input(&self, dst_ip: &Ipv4Addr) -> bool { let Some(peer_manager) = self.0.get_peer_manager() else { return false; diff --git a/easytier/src/gateway/quic_proxy.rs b/easytier/src/gateway/quic_proxy.rs index 7303801a..45dddd18 100644 --- a/easytier/src/gateway/quic_proxy.rs +++ b/easytier/src/gateway/quic_proxy.rs @@ -1,169 +1,376 @@ -use anyhow::Context; -use dashmap::DashMap; -use pnet::packet::ipv4::Ipv4Packet; -use prost::Message as _; -use quinn::{Endpoint, Incoming}; -use std::net::{IpAddr, Ipv4Addr}; -use std::sync::{Arc, Mutex, Weak}; -use std::{net::SocketAddr, pin::Pin}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; -use tokio::net::TcpStream; -use tokio::task::JoinSet; -use tokio::time::timeout; - use crate::common::acl_processor::PacketInfo; -use crate::common::config::ConfigLoader; -use crate::common::error::Result; use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx}; -use crate::common::join_joinset_background; -use crate::defer; -use crate::gateway::tcp_proxy::{NatDstConnector, NatDstTcpConnector, TcpProxy}; +use crate::common::PeerId; +use crate::gateway::tcp_proxy::{NatDstConnector, TcpProxy}; use crate::gateway::wrapped_proxy::{ProxyAclHandler, TcpProxyForWrappedSrcTrait}; use crate::gateway::CidrSet; use crate::peers::peer_manager::PeerManager; +use crate::peers::PeerPacketFilter; use crate::proto::acl::{ChainType, Protocol}; use crate::proto::api::instance::{ ListTcpProxyEntryRequest, ListTcpProxyEntryResponse, TcpProxyEntry, TcpProxyEntryState, TcpProxyEntryTransportType, TcpProxyRpc, }; -use crate::proto::common::ProxyDstInfo; +use crate::proto::peer_rpc::KcpConnData as QuicConnData; use crate::proto::rpc_types; use crate::proto::rpc_types::controller::BaseController; -use crate::tunnel::packet_def::PeerManagerHeader; -use crate::tunnel::quic::{client_config, make_server_endpoint}; +use crate::tunnel::packet_def::{ + PacketType, PeerManagerHeader, ZCPacket, ZCPacketType, TAIL_RESERVED_SIZE, +}; +use crate::tunnel::quic::{client_config, endpoint_config, server_config}; +use anyhow::{anyhow, Context, Error}; +use atomic_refcell::AtomicRefCell; +use bytes::{BufMut, Bytes, BytesMut}; +use dashmap::DashMap; +use derivative::Derivative; +use derive_more::{Constructor, Deref, DerefMut, From, Into}; +use pnet::packet::ipv4::Ipv4Packet; +use prost::Message; +use quinn::udp::{EcnCodepoint, RecvMeta, Transmit}; +use quinn::{AsyncUdpSocket, Endpoint, RecvStream, SendStream, StreamId, TokioRuntime, UdpPoller}; +use std::cmp::min; +use std::future::Future; +use std::io::IoSliceMut; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::pin::Pin; +use std::ptr::copy_nonoverlapping; +use std::sync::{Arc, Weak}; +use std::task::Poll; +use std::time::Duration; +use tokio::io::{join, AsyncReadExt, Join}; +use tokio::sync::mpsc::error::TrySendError; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::task::JoinSet; +use tokio::time::{timeout, Instant}; +use tokio::{join, pin, select}; +use tokio_util::sync::PollSender; +use tracing::{debug, error, info, instrument, trace, warn}; -pub struct QUICStream { - endpoint: Option, - connection: Option, - sender: quinn::SendStream, - receiver: quinn::RecvStream, +//region packet +#[derive(Debug, Constructor)] +struct QuicPacket { + addr: SocketAddr, + payload: BytesMut, + segment: Option, + ecn: Option, } -impl AsyncRead for QUICStream { - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - let this = self.get_mut(); - Pin::new(&mut this.receiver).poll_read(cx, buf) +#[derive(Debug, Clone, Copy, From, Into)] +pub struct PacketMargins { + pub header: usize, + pub trailer: usize, +} + +impl PacketMargins { + pub fn len(&self) -> usize { + self.header + self.trailer + } +} +//endregion + +//region socket +#[derive(Debug)] +struct QuicSocketPoller { + tx: PollSender, +} + +impl UdpPoller for QuicSocketPoller { + fn poll_writable( + self: Pin<&mut Self>, + cx: &mut std::task::Context, + ) -> Poll> { + self.get_mut() + .tx + .poll_reserve(cx) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)) } } -impl AsyncWrite for QUICStream { - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - let this = self.get_mut(); - AsyncWrite::poll_write(Pin::new(&mut this.sender), cx, buf) +#[derive(Debug)] +pub struct QuicSocket { + addr: SocketAddr, + rx: AtomicRefCell>, + tx: Sender, + margins: PacketMargins, +} + +impl AsyncUdpSocket for QuicSocket { + fn create_io_poller(self: Arc) -> Pin> { + Box::into_pin(Box::new(QuicSocketPoller { + tx: PollSender::new(self.tx.clone()), + })) } - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.get_mut(); - Pin::new(&mut this.sender).poll_flush(cx) + fn try_send(&self, transmit: &Transmit) -> std::io::Result<()> { + match transmit.destination { + SocketAddr::V4(addr) => { + let len = transmit.contents.len(); + trace!("{:?} sending {:?} bytes to {:?}", self.addr, len, addr); + + let permit = self.tx.try_reserve().map_err(|e| match e { + TrySendError::Full(_) => std::io::ErrorKind::WouldBlock, + TrySendError::Closed(_) => std::io::ErrorKind::BrokenPipe, + })?; + + let segment_size = transmit.segment_size.unwrap_or(len); + let chunks = transmit.contents.chunks(segment_size); + let segment = segment_size + self.margins.len(); + + let mut payload = BytesMut::with_capacity(chunks.len() * segment); + + // The length of the last chunk could be smaller than segment_size + for chunk in chunks { + let len = chunk.len(); + unsafe { + copy_nonoverlapping( + chunk.as_ptr(), + payload.as_mut_ptr().add(self.margins.header), + len, + ); + payload.advance_mut(len + self.margins.len()); + } + } + + permit.send(QuicPacket { + addr: transmit.destination, + payload, + segment: Some(segment), + ecn: transmit.ecn, + }); + + Ok(()) + } + _ => Err(std::io::ErrorKind::ConnectionRefused.into()), + } } - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.get_mut(); - Pin::new(&mut this.sender).poll_shutdown(cx) + fn poll_recv( + &self, + cx: &mut std::task::Context, + bufs: &mut [IoSliceMut<'_>], + meta: &mut [RecvMeta], + ) -> Poll> { + if bufs.is_empty() || meta.is_empty() { + return Poll::Ready(Ok(0)); + } + + let mut rx = self.rx.borrow_mut(); + let mut count = 0; + + for (buf, meta) in bufs.iter_mut().zip(meta.iter_mut()) { + match rx.poll_recv(cx) { + Poll::Ready(Some(packet)) => { + let len = packet.payload.len(); + if len > buf.len() { + warn!( + "buffer too small for packet: {:?} < {:?}, dropped", + buf.len(), + len, + ); + continue; + } + trace!( + "{:?} received {:?} bytes from {:?}", + self.addr, + len, + packet.addr + ); + buf[0..len].copy_from_slice(&packet.payload); + *meta = RecvMeta { + addr: packet.addr, + len, + stride: len, + ecn: packet.ecn, + dst_ip: None, + }; + count += 1; + } + Poll::Ready(None) if count > 0 => break, + Poll::Ready(None) => { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "socket closed", + ))) + } + Poll::Pending => break, + } + } + + if count > 0 { + Poll::Ready(Ok(count)) + } else { + Poll::Pending + } + } + + fn local_addr(&self) -> std::io::Result { + Ok(self.addr) } } +//endregion + +//region addr +#[derive(Debug, Clone, Copy, Constructor)] +struct QuicAddr { + peer_id: PeerId, + packet_type: PacketType, +} + +impl From for SocketAddr { + #[inline] + fn from(value: QuicAddr) -> Self { + SocketAddr::new(IpAddr::V4(value.peer_id.into()), value.packet_type as u16) + } +} + +impl TryFrom for QuicAddr { + type Error = (); + + #[inline] + fn try_from(value: SocketAddr) -> Result { + let IpAddr::V4(ipv4) = value.ip() else { + return Err(()); + }; + let peer_id = ipv4.into(); + + let packet_type = match value.port() { + p if p == PacketType::QuicSrc as u16 => PacketType::QuicSrc, + p if p == PacketType::QuicDst as u16 => PacketType::QuicDst, + _ => return Err(()), + }; + + Ok(Self { + peer_id, + packet_type, + }) + } +} +//endregion + +//region stream +type QuicStreamInner = Join; +#[derive(Debug, Deref, DerefMut, From, Into)] +struct QuicStream { + #[deref] + #[deref_mut] + inner: QuicStreamInner, +} + +impl QuicStream { + #[inline] + fn id(&self) -> (StreamId, StreamId) { + (self.reader().id(), self.writer().id()) + } +} + +impl From<(SendStream, RecvStream)> for QuicStream { + #[inline] + fn from(value: (SendStream, RecvStream)) -> Self { + join(value.1, value.0).into() + } +} +//endregion #[derive(Debug, Clone)] -pub struct NatDstQUICConnector { +pub struct NatDstQuicConnector { + pub(crate) endpoint: Endpoint, pub(crate) peer_mgr: Weak, } #[async_trait::async_trait] -impl NatDstConnector for NatDstQUICConnector { - type DstStream = QUICStream; +impl NatDstConnector for NatDstQuicConnector { + type DstStream = QuicStreamInner; - #[tracing::instrument(skip(self), level = "debug", name = "NatDstQUICConnector::connect")] - async fn connect(&self, src: SocketAddr, nat_dst: SocketAddr) -> Result { + async fn connect( + &self, + src: SocketAddr, + nat_dst: SocketAddr, + ) -> crate::common::error::Result { let Some(peer_mgr) = self.peer_mgr.upgrade() else { return Err(anyhow::anyhow!("peer manager is not available").into()); }; - let IpAddr::V4(dst_ipv4) = nat_dst.ip() else { - return Err(anyhow::anyhow!("src must be an IPv4 address").into()); + let Some(dst_peer_id) = (match nat_dst { + SocketAddr::V4(addr) => peer_mgr.get_peer_map().get_peer_id_by_ipv4(addr.ip()).await, + SocketAddr::V6(_) => return Err(anyhow::anyhow!("ipv6 is not supported").into()), + }) else { + return Err(anyhow::anyhow!("no peer found for nat dst: {}", nat_dst).into()); }; - let Some(dst_peer) = peer_mgr.get_peer_map().get_peer_id_by_ipv4(&dst_ipv4).await else { - return Err(anyhow::anyhow!("no peer found for dst: {}", nat_dst).into()); + trace!("quic nat dst: {:?}, dst peers: {:?}", nat_dst, dst_peer_id); + + let addr = QuicAddr::new(dst_peer_id, PacketType::QuicSrc).into(); + let header = { + let conn_data = QuicConnData { + src: Some(src.into()), + dst: Some(nat_dst.into()), + }; + + let len = conn_data.encoded_len(); + if len > (u16::MAX as usize) { + return Err(anyhow!("conn data too large: {:?}", len).into()); + } + + let mut buf = BytesMut::with_capacity(2 + len); + + buf.put_u16(len as u16); + conn_data.encode(&mut buf).unwrap(); + + buf.freeze() }; - let Some(dst_peer_info) = peer_mgr.get_peer_map().get_route_peer_info(dst_peer).await - else { - return Err(anyhow::anyhow!("no peer info found for dst peer: {}", dst_peer).into()); + let mut connect_tasks = JoinSet::>::new(); + let connect = |tasks: &mut JoinSet<_>| { + let endpoint = self.endpoint.clone(); + let header = header.clone(); + + tasks.spawn(async move { + let connection = endpoint.connect(addr, "")?.await?; + let mut stream: QuicStream = connection.open_bi().await?.into(); + stream.writer_mut().write_chunk(header).await?; + Ok(stream) + }); }; - let Some(dst_ipv4): Option = dst_peer_info.ipv4_addr.map(Into::into) else { - return Err(anyhow::anyhow!("no ipv4 found for dst peer: {}", dst_peer).into()); - }; + connect(&mut connect_tasks); - let Some(quic_port) = dst_peer_info.quic_port else { - return Err(anyhow::anyhow!("no quic port found for dst peer: {}", dst_peer).into()); - }; + let timer = tokio::time::sleep(Duration::from_millis(200)); + pin!(timer); - let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap()) - .with_context(|| format!("failed to create QUIC endpoint for src: {}", src))?; - endpoint.set_default_client_config(client_config()); + let mut retry_remain = 5; + loop { + select! { + Some(result) = connect_tasks.join_next() => { + match result { + Ok(Ok(stream)) => return Ok(stream.into()), + _ => { + if connect_tasks.is_empty() { + if retry_remain == 0 { + return Err(anyhow!("failed to connect to nat dst: {:?}", nat_dst).into()) + } - // connect to server - let connection = { - let _g = peer_mgr.get_global_ctx().net_ns.guard(); - endpoint - .connect( - SocketAddr::new(dst_ipv4.into(), quic_port as u16), - "localhost", - ) - .unwrap() - .await - .with_context(|| { - format!( - "failed to connect to NAT destination {} from {}, real dst: {}", - nat_dst, src, dst_ipv4 - ) - })? - }; - - let (mut w, r) = connection - .open_bi() - .await - .with_context(|| "open_bi failed")?; - - let proxy_dst_info = ProxyDstInfo { - dst_addr: Some(nat_dst.into()), - }; - let proxy_dst_info_buf = proxy_dst_info.encode_to_vec(); - let buf_len = proxy_dst_info_buf.len() as u8; - w.write(&buf_len.to_le_bytes()) - .await - .with_context(|| "failed to write proxy dst info buf len to QUIC stream")?; - w.write(&proxy_dst_info_buf) - .await - .with_context(|| "failed to write proxy dst info to QUIC stream")?; - - Ok(QUICStream { - endpoint: Some(endpoint), - connection: Some(connection), - sender: w, - receiver: r, - }) + retry_remain -= 1; + connect(&mut connect_tasks); + timer.as_mut().reset(Instant::now() + Duration::from_millis(200)) + } + } + } + } + _ = &mut timer, if retry_remain > 0 => { + retry_remain -= 1; + connect(&mut connect_tasks); + timer.as_mut().reset(Instant::now() + Duration::from_millis(200)); + } + } + } } + #[inline] fn check_packet_from_peer_fast(&self, _cidr_set: &CidrSet, _global_ctx: &GlobalCtx) -> bool { true } + #[inline] fn check_packet_from_peer( &self, _cidr_set: &CidrSet, @@ -172,283 +379,335 @@ impl NatDstConnector for NatDstQUICConnector { _ipv4: &Ipv4Packet, _real_dst_ip: &mut Ipv4Addr, ) -> bool { - hdr.from_peer_id == hdr.to_peer_id && !hdr.is_kcp_src_modified() + hdr.from_peer_id == hdr.to_peer_id && hdr.is_quic_src_modified() } + #[inline] fn transport_type(&self) -> TcpProxyEntryTransportType { TcpProxyEntryTransportType::Quic } } #[derive(Clone)] -struct TcpProxyForQUICSrc(Arc>); +struct TcpProxyForQuicSrc(Arc>); #[async_trait::async_trait] -impl TcpProxyForWrappedSrcTrait for TcpProxyForQUICSrc { - type Connector = NatDstQUICConnector; +impl TcpProxyForWrappedSrcTrait for TcpProxyForQuicSrc { + type Connector = NatDstQuicConnector; + #[inline] fn get_tcp_proxy(&self) -> &Arc> { &self.0 } + #[inline] + fn set_src_modified(hdr: &mut PeerManagerHeader, modified: bool) -> &mut PeerManagerHeader { + hdr.set_quic_src_modified(modified) + } + + #[inline] async fn check_dst_allow_wrapped_input(&self, dst_ip: &Ipv4Addr) -> bool { let Some(peer_manager) = self.0.get_peer_manager() else { return false; }; - let peer_map: Arc = peer_manager.get_peer_map(); - let Some(dst_peer_id) = peer_map.get_peer_id_by_ipv4(dst_ip).await else { - return false; - }; - let Some(peer_info) = peer_map.get_route_peer_info(dst_peer_id).await else { - return false; - }; - tracing::debug!( - "check dst {} allow quic input, peer info: {:?}", - dst_ip, - peer_info - ); - let Some(quic_port) = peer_info.quic_port else { - return false; - }; - quic_port > 0 + peer_manager + .check_allow_quic_to_dst(&IpAddr::V4(*dst_ip)) + .await } } -pub struct QUICProxySrc { - peer_manager: Arc, - tcp_proxy: TcpProxyForQUICSrc, +#[derive(Debug)] +enum QuicProxyRole { + Src, + Dst, } -impl QUICProxySrc { - pub async fn new(peer_manager: Arc) -> Self { - let tcp_proxy = TcpProxy::new( - peer_manager.clone(), - NatDstQUICConnector { - peer_mgr: Arc::downgrade(&peer_manager), - }, - ); - - Self { - peer_manager, - tcp_proxy: TcpProxyForQUICSrc(tcp_proxy), +impl QuicProxyRole { + #[inline] + const fn incoming(&self) -> PacketType { + match self { + QuicProxyRole::Src => PacketType::QuicDst, + QuicProxyRole::Dst => PacketType::QuicSrc, } } - pub async fn start(&self) { - self.peer_manager - .add_nic_packet_process_pipeline(Box::new(self.tcp_proxy.clone())) - .await; - self.peer_manager - .add_packet_process_pipeline(Box::new(self.tcp_proxy.0.clone())) - .await; - self.tcp_proxy.0.start(false).await.unwrap(); - } - - pub fn get_tcp_proxy(&self) -> Arc> { - self.tcp_proxy.0.clone() + #[inline] + const fn outgoing(&self) -> PacketType { + match self { + QuicProxyRole::Src => PacketType::QuicSrc, + QuicProxyRole::Dst => PacketType::QuicDst, + } } } -pub struct QUICProxyDst { - global_ctx: Arc, - endpoint: Arc, - proxy_entries: Arc>, - tasks: Arc>>, +// Receive packets from peers and forward them to the QUIC endpoint +#[derive(Debug)] +struct QuicPacketReceiver { + tx: Sender, + role: QuicProxyRole, +} + +#[async_trait::async_trait] +impl PeerPacketFilter for QuicPacketReceiver { + async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option { + let header = packet.peer_manager_header().unwrap(); + + if header.packet_type != self.role.incoming() as u8 { + return Some(packet); + } + + let addr = QuicAddr::new(header.from_peer_id.get(), self.role.outgoing()); + + if let Err(e) = self.tx.try_send(QuicPacket::new( + addr.into(), + packet.payload_bytes(), + None, + None, + )) { + debug!("failed to send quic packet to endpoint: {:?}", e); + } + + None + } +} + +// Send to peers packets received from the QUIC endpoint +#[derive(Debug)] +struct QuicPacketSender { + peer_mgr: Arc, + rx: Receiver, + + header: Bytes, + zc_packet_type: ZCPacketType, + margins: PacketMargins, +} + +impl QuicPacketSender { + #[instrument] + pub async fn run(mut self) { + while let Some(packet) = self.rx.recv().await { + let Ok(addr) = QuicAddr::try_from(packet.addr) else { + error!("invalid quic packet addr: {:?}", packet.addr); + continue; + }; + + let mut payload = packet.payload; + let segment = packet + .segment + .expect("segment size must be set for outgoing quic packet"); + + while !payload.is_empty() { + let len = min(payload.len(), segment); + let mut payload = payload.split_to(len); + payload[..self.margins.header].copy_from_slice(&self.header); + payload.truncate(len - self.margins.trailer); + let mut packet = ZCPacket::new_from_buf(payload, self.zc_packet_type); + + packet.fill_peer_manager_hdr( + self.peer_mgr.my_peer_id(), + addr.peer_id, + addr.packet_type as u8, + ); + + if let Err(e) = self.peer_mgr.send_msg_for_proxy(packet, addr.peer_id).await { + error!("failed to send QUIC packet to peer: {:?}", e); + } + } + } + } +} + +#[derive(Derivative, Clone)] +#[derivative(Debug)] +struct QuicStreamContext { + global_ctx: ArcGlobalCtx, + proxy_entries: Arc>, + cidr_set: Arc, + #[derivative(Debug = "ignore")] route: Arc, } -impl QUICProxyDst { - pub fn new( - global_ctx: ArcGlobalCtx, - route: Arc, - ) -> Result { - let _g = global_ctx.net_ns.guard(); - let endpoint = make_server_endpoint( - format!("0.0.0.0:{}", global_ctx.config.get_flags().quic_listen_port) - .parse() - .unwrap(), - ) - .map_err(|e| anyhow::anyhow!("failed to create QUIC endpoint: {}", e))?; - let tasks = Arc::new(Mutex::new(JoinSet::new())); - join_joinset_background(tasks.clone(), "QUICProxyDst tasks".to_string()); - Ok(Self { - global_ctx, - endpoint: Arc::new(endpoint), +impl QuicStreamContext { + fn new(peer_mgr: Arc) -> Self { + let global_ctx = peer_mgr.get_global_ctx(); + Self { + global_ctx: global_ctx.clone(), proxy_entries: Arc::new(DashMap::new()), - tasks, - route, - }) + cidr_set: Arc::new(CidrSet::new(global_ctx.clone())), + route: Arc::new(peer_mgr.get_route()), + } } +} - pub async fn start(&self) -> Result<()> { - let endpoint = self.endpoint.clone(); - let tasks = Arc::downgrade(&self.tasks.clone()); - let ctx = self.global_ctx.clone(); - let cidr_set = Arc::new(CidrSet::new(ctx.clone())); - let proxy_entries = self.proxy_entries.clone(); - let route = self.route.clone(); +struct QuicStreamReceiver { + endpoint: Endpoint, + tasks: JoinSet<()>, + ctx: Arc, +} - let task = async move { - loop { - match endpoint.accept().await { - Some(conn) => { - let Some(tasks) = tasks.upgrade() else { - tracing::warn!( - "QUICProxyDst tasks is not available, stopping accept loop" - ); - return; - }; - tasks - .lock() - .unwrap() - .spawn(Self::handle_connection_with_timeout( - conn, - ctx.clone(), - cidr_set.clone(), - proxy_entries.clone(), - route.clone(), - )); - } - None => { - return; - } +impl QuicStreamReceiver { + async fn run(mut self) { + loop { + select! { + biased; + + Some(incoming) = self.endpoint.accept() => { + let addr = incoming.remote_address(); + let connection = match incoming.accept() { + Ok(connection) => connection, + Err(e) => { + error!("failed to accept quic connection from {:?}: {:?}", addr, e); + continue; + } + }; + + let addr = connection.remote_address(); + let connection = match connection.await { + Ok(connection) => connection, + Err(e) => { + error!("failed to accept quic connection from {:?}: {:?}", addr, e); + continue; + } + }; + + let ctx = self.ctx.clone(); + self.tasks.spawn(async move { + let mut tasks = JoinSet::new(); + loop { + select! { + biased; + + e = connection.closed() => { + info!("connection to {:?} closed: {:?}", addr, e); + break; + } + + stream = connection.accept_bi() => { + let stream = match stream { + Ok(stream) => stream.into(), + Err(e) => { + warn!("failed to accept bi stream from {:?}: {:?}", connection.remote_address(), e); + break; + } + }; + + match Self::establish_stream(stream, ctx.clone()).await { + Ok(stream) => drop(tasks.spawn(stream)), + Err(e) => warn!("failed to establish quic stream from {:?}: {:?}", connection.remote_address(), e), + } + } + + res = tasks.join_next(), if !tasks.is_empty() => { + debug!("quic stream task completed for {:?}: {:?}", addr, res); + } + } + } + + connection.close(1u32.into(), b"error"); + }); } + + _ = self.tasks.join_next(), if !self.tasks.is_empty() => {} } - }; - - self.tasks.lock().unwrap().spawn(task); - - Ok(()) + } } - pub fn local_addr(&self) -> Result { - self.endpoint.local_addr().map_err(Into::into) - } - - async fn handle_connection_with_timeout( - conn: Incoming, - ctx: Arc, - cidr_set: Arc, - proxy_entries: Arc>, - route: Arc, - ) { - let remote_addr = conn.remote_address(); - defer!( - proxy_entries.remove(&remote_addr); - if proxy_entries.capacity() - proxy_entries.len() > 16 { - proxy_entries.shrink_to_fit(); - } - ); - let ret = timeout( - std::time::Duration::from_secs(10), - Self::handle_connection( - conn, - ctx, - cidr_set, - remote_addr, - proxy_entries.clone(), - route, - ), + async fn read_stream_header(stream: &mut QuicStream) -> Result { + const STREAM_HEADER_READ_TIMEOUT: Duration = Duration::from_secs(5); + const STREAM_HEADER_LIMIT: u16 = 512; + let len = timeout(STREAM_HEADER_READ_TIMEOUT, stream.read_u16()) + .await + .context("timeout reading header length")??; + if len > STREAM_HEADER_LIMIT { + return Err(anyhow::anyhow!("stream header too long")); + } + let mut header = Vec::with_capacity(len as usize); + timeout( + STREAM_HEADER_READ_TIMEOUT, + stream + .reader_mut() + .take(len as u64) + .read_to_end(&mut header), ) - .await; - - match ret { - Ok(Ok((quic_stream, tcp_stream, acl))) => { - let remote_addr = quic_stream.connection.as_ref().map(|c| c.remote_address()); - let ret = acl.copy_bidirection_with_acl(quic_stream, tcp_stream).await; - tracing::info!( - "QUIC connection handled, result: {:?}, remote addr: {:?}", - ret, - remote_addr, - ); - } - Ok(Err(e)) => { - tracing::error!("Failed to handle QUIC connection: {}", e); - } - Err(_) => { - tracing::warn!("Timeout while handling QUIC connection"); - } - } + .await + .context("timeout reading header")??; + Ok(header.into()) } - async fn handle_connection( - incoming: Incoming, - ctx: ArcGlobalCtx, - cidr_set: Arc, - proxy_entry_key: SocketAddr, - proxy_entries: Arc>, - route: Arc, - ) -> Result<(QUICStream, TcpStream, ProxyAclHandler)> { - let conn = incoming.await.with_context(|| "accept failed")?; - let addr = conn.remote_address(); - tracing::info!("Accepted QUIC connection from {}", addr); - let (w, mut r) = conn.accept_bi().await.with_context(|| "accept_bi failed")?; - let len = r - .read_u8() - .await - .with_context(|| "failed to read proxy dst info buf len")?; - let mut buf = vec![0u8; len as usize]; - r.read_exact(&mut buf) - .await - .with_context(|| "failed to read proxy dst info")?; - - let proxy_dst_info = - ProxyDstInfo::decode(&buf[..]).with_context(|| "failed to decode proxy dst info")?; - - let dst_socket: SocketAddr = proxy_dst_info - .dst_addr - .map(Into::into) - .ok_or_else(|| anyhow::anyhow!("no dst addr in proxy dst info"))?; - - let SocketAddr::V4(mut dst_socket) = dst_socket else { - return Err(anyhow::anyhow!("NAT destination must be an IPv4 address").into()); - }; - - let mut real_ip = *dst_socket.ip(); - if cidr_set.contains_v4(*dst_socket.ip(), &mut real_ip) { - dst_socket.set_ip(real_ip); - } - - let src_ip = addr.ip(); - let dst_ip = *dst_socket.ip(); - let (src_groups, dst_groups) = tokio::join!( - route.get_peer_groups_by_ip(&src_ip), - route.get_peer_groups_by_ipv4(&dst_ip) - ); - - if ctx.should_deny_proxy(&dst_socket.into(), false) { - return Err(anyhow::anyhow!( - "dst socket {:?} is in running listeners, ignore it", - dst_socket - ) - .into()); - } - - let send_to_self = ctx.is_ip_local_virtual_ip(&dst_ip.into()); - if send_to_self && ctx.no_tun() { - dst_socket = format!("127.0.0.1:{}", dst_socket.port()).parse().unwrap(); - } + async fn establish_stream( + mut stream: QuicStream, + ctx: Arc, + ) -> Result>, Error> { + let conn_data = Self::read_stream_header(&mut stream).await?; + let conn_data_parsed = QuicConnData::decode(conn_data.as_ref()) + .context("failed to decode quic stream header")?; + let handle = stream.id(); + let proxy_entries = &ctx.proxy_entries; proxy_entries.insert( - proxy_entry_key, + handle, TcpProxyEntry { - src: Some(addr.into()), - dst: Some(SocketAddr::V4(dst_socket).into()), + src: conn_data_parsed.src, + dst: conn_data_parsed.dst, start_time: chrono::Local::now().timestamp() as u64, state: TcpProxyEntryState::ConnectingDst.into(), transport_type: TcpProxyEntryTransportType::Quic.into(), }, ); + crate::defer! { + proxy_entries.remove(&handle); + if proxy_entries.capacity() - proxy_entries.len() > 16 { + proxy_entries.shrink_to_fit(); + } + } + + let src_socket: SocketAddr = conn_data_parsed + .src + .ok_or_else(|| anyhow!("missing src addr in quic stream header"))? + .into(); + let mut dst_socket: SocketAddr = conn_data_parsed + .dst + .ok_or_else(|| anyhow!("missing dst addr in quic stream header"))? + .into(); + + if let IpAddr::V4(dst_v4_ip) = dst_socket.ip() { + let mut real_ip = dst_v4_ip; + if ctx.cidr_set.contains_v4(dst_v4_ip, &mut real_ip) { + dst_socket.set_ip(real_ip.into()); + } + }; + + let src_ip = src_socket.ip(); + let dst_ip = dst_socket.ip(); + + let route = ctx.route.clone(); + let (src_groups, dst_groups) = join!( + route.get_peer_groups_by_ip(&src_ip), + route.get_peer_groups_by_ip(&dst_ip) + ); + + let global_ctx = ctx.global_ctx.clone(); + if global_ctx.should_deny_proxy(&dst_socket, false) { + return Err(anyhow::anyhow!( + "dst socket {:?} is in running listeners, ignore it", + dst_socket + )); + } + + let send_to_self = global_ctx.is_ip_local_virtual_ip(&dst_ip); + if send_to_self && global_ctx.no_tun() { + dst_socket = format!("127.0.0.1:{}", dst_socket.port()).parse()?; + } let acl_handler = ProxyAclHandler { - acl_filter: ctx.get_acl_filter().clone(), + acl_filter: global_ctx.get_acl_filter().clone(), packet_info: PacketInfo { src_ip, - dst_ip: dst_ip.into(), - src_port: Some(addr.port()), + dst_ip, + src_port: Some(src_socket.port()), dst_port: Some(dst_socket.port()), protocol: Protocol::Tcp, - packet_size: len as usize, + packet_size: conn_data.len(), src_groups, dst_groups, }, @@ -458,49 +717,236 @@ impl QUICProxyDst { ChainType::Forward }, }; - acl_handler.handle_packet(&buf)?; + acl_handler.handle_packet(&conn_data)?; - let connector = NatDstTcpConnector {}; + debug!("quic connect to dst socket: {:?}", dst_socket); - let dst_stream = { - let _g = ctx.net_ns.guard(); - connector - .connect("0.0.0.0:0".parse().unwrap(), dst_socket.into()) - .await? - }; + let _g = global_ctx.net_ns.guard(); + let connector = crate::gateway::tcp_proxy::NatDstTcpConnector {}; + let ret = connector.connect("0.0.0.0:0".parse()?, dst_socket).await?; - if let Some(mut e) = proxy_entries.get_mut(&proxy_entry_key) { + if let Some(mut e) = proxy_entries.get_mut(&handle) { e.state = TcpProxyEntryState::Connected.into(); } - let quic_stream = QUICStream { - endpoint: None, - connection: Some(conn), - sender: w, - receiver: r, - }; - - Ok((quic_stream, dst_stream, acl_handler)) + Ok(async move { + acl_handler + .copy_bidirection_with_acl(stream.inner, ret) + .await + }) } } -#[derive(Clone)] -pub struct QUICProxyDstRpcService(Weak>); +pub struct QuicProxy { + peer_mgr: Arc, -impl QUICProxyDstRpcService { - pub fn new(quic_proxy_dst: &QUICProxyDst) -> Self { - Self(Arc::downgrade(&quic_proxy_dst.proxy_entries)) + endpoint: Option, + + src: Option, + dst: Option, + + tasks: JoinSet<()>, +} + +impl QuicProxy { + #[inline] + pub fn src(&self) -> Option<&QuicProxySrc> { + self.src.as_ref() + } + + #[inline] + pub fn dst(&self) -> Option<&QuicProxyDst> { + self.dst.as_ref() + } +} + +impl QuicProxy { + pub fn new(peer_mgr: Arc) -> Self { + Self { + peer_mgr, + endpoint: None, + src: None, + dst: None, + tasks: JoinSet::new(), + } + } + + pub async fn run(&mut self, src: bool, dst: bool) { + trace!("quic proxy starting"); + + if self.endpoint.is_some() { + error!("quic proxy already running"); + return; + } + + let (header, zc_packet_type) = { + let header = ZCPacket::new_with_payload(&[]); + let zc_packet_type = header.packet_type(); + let payload_offset = header.payload_offset(); + ( + header.inner().split_to(payload_offset).freeze(), + zc_packet_type, + ) + }; + + let margins = (header.len(), TAIL_RESERVED_SIZE).into(); + + let (in_tx, in_rx) = channel(1024); + let (out_tx, out_rx) = channel(1024); + + let socket = QuicSocket { + addr: SocketAddr::new(Ipv4Addr::from(self.peer_mgr.my_peer_id()).into(), 0), + rx: AtomicRefCell::new(in_rx), + tx: out_tx, + margins, + }; + + let mut endpoint = Endpoint::new_with_abstract_socket( + endpoint_config(), + Some(server_config()), + Arc::new(socket), + Arc::new(TokioRuntime), + ) + .unwrap(); + endpoint.set_default_client_config(client_config()); + self.endpoint = Some(endpoint.clone()); + + let peer_mgr = self.peer_mgr.clone(); + self.tasks.spawn( + QuicPacketSender { + peer_mgr, + rx: out_rx, + header, + zc_packet_type, + margins, + } + .run(), + ); + + let peer_mgr = self.peer_mgr.clone(); + + if src { + if self.src.is_some() { + error!("quic proxy src already running"); + return; + } + + let tcp_proxy = TcpProxyForQuicSrc(TcpProxy::new( + peer_mgr.clone(), + NatDstQuicConnector { + endpoint: endpoint.clone(), + peer_mgr: Arc::downgrade(&peer_mgr), + }, + )); + + let src = QuicProxySrc { + peer_mgr: peer_mgr.clone(), + tcp_proxy, + tx: in_tx.clone(), + }; + src.run().await; + + self.src = Some(src); + } + + if dst { + if self.dst.is_some() { + error!("quic proxy dst already running"); + return; + } + + let stream_ctx = Arc::new(QuicStreamContext::new(peer_mgr.clone())); + + let dst = QuicProxyDst { + peer_mgr: peer_mgr.clone(), + tx: in_tx.clone(), + stream_ctx: stream_ctx.clone(), + }; + dst.run().await; + + self.tasks.spawn( + QuicStreamReceiver { + endpoint: endpoint.clone(), + tasks: JoinSet::new(), + ctx: stream_ctx, + } + .run(), + ); + + self.dst = Some(dst); + } + } +} + +pub struct QuicProxySrc { + peer_mgr: Arc, + tcp_proxy: TcpProxyForQuicSrc, + + tx: Sender, +} + +impl QuicProxySrc { + #[inline] + pub fn get_tcp_proxy(&self) -> Arc> { + self.tcp_proxy.get_tcp_proxy().clone() + } +} + +impl QuicProxySrc { + async fn run(&self) { + trace!("quic proxy src starting"); + self.peer_mgr + .add_nic_packet_process_pipeline(Box::new(self.tcp_proxy.clone())) + .await; + self.peer_mgr + .add_packet_process_pipeline(Box::new(self.tcp_proxy.0.clone())) + .await; + self.peer_mgr + .add_packet_process_pipeline(Box::new(QuicPacketReceiver { + tx: self.tx.clone(), + role: QuicProxyRole::Src, + })) + .await; + self.tcp_proxy.0.start(false).await.unwrap(); + } +} + +pub struct QuicProxyDst { + peer_mgr: Arc, + + tx: Sender, + stream_ctx: Arc, +} + +impl QuicProxyDst { + async fn run(&self) { + trace!("quic proxy dst starting"); + self.peer_mgr + .add_packet_process_pipeline(Box::new(QuicPacketReceiver { + tx: self.tx.clone(), + role: QuicProxyRole::Dst, + })) + .await; + } +} + +#[derive(Clone, Deref, DerefMut, From, Into)] +pub struct QuicProxyDstRpcService(Weak>); + +impl QuicProxyDstRpcService { + pub fn new(quic_proxy_dst: &QuicProxyDst) -> Self { + Self(Arc::downgrade(&quic_proxy_dst.stream_ctx.proxy_entries)) } } #[async_trait::async_trait] -impl TcpProxyRpc for QUICProxyDstRpcService { +impl TcpProxyRpc for QuicProxyDstRpcService { type Controller = BaseController; async fn list_tcp_proxy_entry( &self, _: BaseController, _request: ListTcpProxyEntryRequest, // Accept request of type HelloRequest - ) -> std::result::Result { + ) -> Result { let mut reply = ListTcpProxyEntryResponse::default(); if let Some(tcp_proxy) = self.0.upgrade() { for item in tcp_proxy.iter() { @@ -510,3 +956,380 @@ impl TcpProxyRpc for QUICProxyDstRpcService { Ok(reply) } } + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Buf; + + fn init() { + let _ = tracing_subscriber::fmt() + .with_env_filter("debug") + .try_init(); + } + + /// Helper function: Create a pair of interconnected QuicSockets. + /// Data sent by socket_a will enter socket_b's rx, and vice versa. + fn make_socket_pair() -> (QuicSocket, QuicSocket) { + let addr_a: SocketAddr = "127.0.0.1:5000".parse().unwrap(); + let addr_b: SocketAddr = "127.0.0.1:5001".parse().unwrap(); + + // Bidirectional channels: A->B and B->A + // Sufficient capacity to prevent packet loss during high concurrency + let (tx_a_out, rx_a_out) = channel::(50_000); + let (tx_b_in, rx_b_in) = channel::(50_000); + + let (tx_b_out, rx_b_out) = channel::(50_000); + let (tx_a_in, rx_a_in) = channel::(50_000); + + let margins = (20, 25).into(); + + forward(rx_a_out, tx_b_in, addr_a, margins); + forward(rx_b_out, tx_a_in, addr_b, margins); + + let socket_a = QuicSocket { + addr: addr_a, + rx: AtomicRefCell::new(rx_a_in), + tx: tx_a_out, + margins, + }; + + let socket_b = QuicSocket { + addr: addr_b, + rx: AtomicRefCell::new(rx_b_in), + tx: tx_b_out, + margins, + }; + + (socket_a, socket_b) + } + + fn endpoint() -> (Endpoint, Endpoint) { + let endpoint_config = endpoint_config(); + let server_config = server_config(); + let client_config = client_config(); + + // 1. Create an in-memory Socket pair + let (socket_client, socket_server) = make_socket_pair(); + let socket_client = Arc::new(socket_client); + let socket_server = Arc::new(socket_server); + + // 3. Configure Client Endpoint + let mut client_endpoint = Endpoint::new_with_abstract_socket( + endpoint_config.clone(), + Some(server_config.clone()), + socket_client.clone(), + Arc::new(TokioRuntime), + ) + .unwrap(); + client_endpoint.set_default_client_config(client_config.clone()); + + // 2. Configure Server Endpoint + let mut server_endpoint = Endpoint::new_with_abstract_socket( + endpoint_config.clone(), + Some(server_config.clone()), + socket_server.clone(), + Arc::new(TokioRuntime), + ) + .unwrap(); + server_endpoint.set_default_client_config(client_config.clone()); + + (client_endpoint, server_endpoint) + } + + fn forward( + mut rx: Receiver, + tx: Sender, + addr: SocketAddr, + margins: PacketMargins, + ) { + const BATCH_SIZE: usize = 128; + tokio::spawn(async move { + // Key optimization: use buffer for batch processing + let mut buffer = Vec::with_capacity(BATCH_SIZE); + + // recv_many wakes up when data is available, taking up to 100 packets at a time + // This reduces context switch overhead by 99 times compared to taking 1 packet at a time + while rx.recv_many(&mut buffer, BATCH_SIZE).await > 0 { + for packet in buffer.iter_mut() { + // [Filter Logic]: Modify address here + packet.addr = addr; + packet.payload.advance(margins.header); + packet + .payload + .truncate(packet.payload.len() - margins.trailer); + } + // Batch forward + for packet in buffer.drain(..) { + if let Err(e) = tx.send(packet).await { + info!("{:?}", e); + return; // Channel closed + } + } + } + }); + } + + #[tokio::test] + async fn test_ping() -> anyhow::Result<()> { + let (client_endpoint, server_endpoint) = endpoint(); + let server_addr = server_endpoint.local_addr()?; + + // 4. Server receive task + let server_handle = tokio::spawn(async move { + println!("Server: Waiting for connection..."); + if let Some(conn) = server_endpoint.accept().await { + let connection = conn.await.unwrap(); + println!( + "Server: Connection accepted from {}", + connection.remote_address() + ); + + // Accept bidirectional stream + let (mut send, mut recv) = connection.accept_bi().await.unwrap(); + + // Read data + let mut buf = vec![0u8; 10]; + recv.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"ping______"); + println!("Server: Received 'ping______'"); + + // Send reply + send.write_all(b"pong______").await.unwrap(); + send.finish().unwrap(); + + let _ = connection.closed().await; + } + }); + + // 5. Client initiates connection + // Note: The connect address here must be V4, because try_send is limited to SocketAddr::V4 + println!("Client: Connecting..."); + let connection = client_endpoint.connect(server_addr, "localhost")?.await?; + println!("Client: Connected!"); + + // Open a stream and send data + let (mut send, mut recv) = connection.open_bi().await?; + send.write_all(b"ping______").await?; + send.finish()?; + + // Read reply + let mut buf = vec![0u8; 10]; + recv.read_exact(&mut buf).await?; + assert_eq!(&buf, b"pong______"); + println!("Client: Received 'pong______'"); + + // 6. Cleanup + connection.close(0u32.into(), b"done"); + // Wait for Server to finish + let _ = tokio::time::timeout(Duration::from_secs(2), server_handle).await; + + Ok(()) + } + + #[tokio::test] + #[ignore = "consumes massive memory (~16GB)"] + async fn test_bandwidth() -> anyhow::Result<()> { + // --- 3. Define test data volume --- + // Total test size: 512 MB + const TOTAL_SIZE: usize = 32768 * 1024 * 1024; + // Write chunk size: 1 MB (simulate large chunk write) + const CHUNK_SIZE: usize = 1024 * 1024; + + let (client_endpoint, server_endpoint) = endpoint(); + let server_addr = server_endpoint.local_addr()?; + + // --- 4. Server side (receive and timing) --- + let server_handle = tokio::spawn(async move { + if let Some(conn) = server_endpoint.accept().await { + let connection = conn.await.unwrap(); + // Accept unidirectional stream + let mut recv = connection.accept_uni().await.unwrap(); + + let start = std::time::Instant::now(); + let mut received = 0; + + // Loop read until the stream ends + // read_chunk performs slightly better than read_exact because it reduces internal buffer copying + while let Some(chunk) = recv.read_chunk(usize::MAX, true).await.unwrap() { + received += chunk.bytes.len(); + } + + let duration = start.elapsed(); + assert_eq!(received, TOTAL_SIZE, "Data length mismatch"); + + let seconds = duration.as_secs_f64(); + let mbps = (received as f64 * 8.0) / (1_000_000.0 * seconds); + let gbps = mbps / 1000.0; + + println!("--------------------------------------------------"); + println!("Server Recv Statistics:"); + println!(" Total Data: {} MB", received / 1024 / 1024); + println!(" Duration : {:.2?}", duration); + println!(" Throughput: {:.2} Gbps ({:.2} Mbps)", gbps, mbps); + println!("--------------------------------------------------"); + + // Keep connection until the Client disconnects + let _ = connection.closed().await; + } + }); + + // --- 5. Client side (send) --- + let connection = client_endpoint.connect(server_addr, "localhost")?.await?; + let mut send = connection.open_uni().await?; + + // Construct a 1MB data chunk + let data_chunk = vec![0u8; CHUNK_SIZE]; + let bytes_data = Bytes::from(data_chunk); // Use Bytes to avoid repeated allocation + + println!("Client: Start sending {} MB...", TOTAL_SIZE / 1024 / 1024); + let start_send = std::time::Instant::now(); + + let chunks = TOTAL_SIZE / CHUNK_SIZE; + for _ in 0..chunks { + // write_chunk is most efficient when used with Bytes + send.write_chunk(bytes_data.clone()).await?; + } + + // Tell peer sending is finished + send.finish()?; + // Wait for the stream to close completely (ensure peer received FIN) + send.stopped().await?; + + let send_duration = start_send.elapsed(); + println!("Client: Send finished in {:.2?}", send_duration); + + // Close connection + connection.close(0u32.into(), b"done"); + + // Wait for Server to print results + let _ = tokio::time::timeout(Duration::from_secs(5), server_handle).await; + + Ok(()) + } + + #[tokio::test] + #[ignore = "consumes massive memory (~16GB)"] + async fn test_bandwidth_parallel() -> anyhow::Result<()> { + // --- 1. Configuration parameters --- + const STREAM_COUNT: usize = 16; // Number of concurrent streams + const STREAM_SIZE: usize = 1024 * 1024 * 1024; // Each stream sends 1GB + + let (client_endpoint, server_endpoint) = endpoint(); + let server_addr = server_endpoint.local_addr()?; + + // --- 3. Server side (concurrent receiver) --- + let server_handle = tokio::spawn(async move { + if let Some(conn) = server_endpoint.accept().await { + let connection = conn.await.unwrap(); + println!("Server: Accepted connection"); + + let mut stream_handles = Vec::new(); + let start = std::time::Instant::now(); + + // Accept an expected number of streams + for i in 0..STREAM_COUNT { + match connection.accept_uni().await { + Ok(mut recv) => { + // Start an independent processing task for each stream + let handle = tokio::spawn(async move { + // Read all data + match recv.read_to_end(usize::MAX).await { + Ok(data) => { + // Verify length + assert_eq!( + data.len(), + STREAM_SIZE, + "Stream {} length mismatch", + i + ); + // Verify data content (verify data isolation) + // We agree that the first byte of data is (stream_index % 255) + // This ensures stream data is not mixed + let expected_byte = data[0] as usize; // Get the actual received marker + // Simple check of head and tail here, CRC can be used in production + if data[data.len() - 1] != data[0] { + panic!("Stream data corruption"); + } + expected_byte // Return marker for statistics + } + Err(e) => panic!("Stream read error: {}", e), + } + }); + stream_handles.push(handle); + } + Err(e) => panic!("Failed to accept stream {}: {}", i, e), + } + } + + // Wait for all streams to finish processing + let results = futures::future::join_all(stream_handles).await; + let duration = start.elapsed(); + + let speed = ((STREAM_COUNT * STREAM_SIZE) as f64 * 8.0) + / (duration.as_secs_f64() * 1_000_000.0); + + println!("--------------------------------------------------"); + println!("Server: All {} streams received processing.", results.len()); + println!("Total Time: {:.2?}", duration); + println!( + "Total Data: {} MB", + (STREAM_COUNT * STREAM_SIZE) / 1024 / 1024 + ); + println!( + "Average Speed: {:.2} Gbps ({:.2} Mbps)", + speed / 1024.0, + speed + ); + println!("--------------------------------------------------"); + + // Keep connection until the Client disconnects + let _ = connection.closed().await; + } + }); + + // --- 4. Client side (concurrent sender) --- + let connection = client_endpoint.connect(server_addr, "localhost")?.await?; + println!( + "Client: Connected, starting {} parallel streams...", + STREAM_COUNT + ); + + let start_send = std::time::Instant::now(); + let mut client_tasks = Vec::new(); + + // Start sending tasks concurrently + for i in 0..STREAM_COUNT { + let conn = connection.clone(); + client_tasks.push(tokio::spawn(async move { + // Open unidirectional stream + let mut send = conn.open_uni().await.expect("Failed to open stream"); + + // Construct data: use i as the padding marker to verify isolation + // All bytes are filled with (i % 255) + let fill_byte = (i % 255) as u8; + let data = vec![fill_byte; STREAM_SIZE]; + let bytes_data = Bytes::from(data); + + send.write_chunk(bytes_data).await.expect("Write failed"); + send.finish().expect("Finish failed"); + // Wait for Server to acknowledge receipt of FIN + send.stopped().await.expect("Stopped failed"); + })); + } + + // Wait for all sending tasks to complete + futures::future::join_all(client_tasks).await; + + let send_duration = start_send.elapsed(); + println!("Client: All streams sent in {:.2?}", send_duration); + + // Close connection + connection.close(0u32.into(), b"done"); + + // Wait for Server to finish + let _ = tokio::time::timeout(Duration::from_secs(10), server_handle).await; + + Ok(()) + } +} diff --git a/easytier/src/gateway/wrapped_proxy.rs b/easytier/src/gateway/wrapped_proxy.rs index 35d94343..87d28b43 100644 --- a/easytier/src/gateway/wrapped_proxy.rs +++ b/easytier/src/gateway/wrapped_proxy.rs @@ -12,14 +12,12 @@ use pnet::packet::{ use tokio::io::{copy_bidirectional, AsyncRead, AsyncWrite}; use tokio_util::io::InspectReader; +use crate::tunnel::packet_def::PeerManagerHeader; use crate::{ common::{acl_processor::PacketInfo, error::Result}, gateway::tcp_proxy::{NatDstConnector, TcpProxy}, peers::{acl_filter::AclFilter, NicPacketFilter}, - proto::{ - acl::{Action, ChainType}, - api::instance::TcpProxyEntryTransportType, - }, + proto::acl::{Action, ChainType}, tunnel::packet_def::ZCPacket, }; @@ -71,6 +69,7 @@ impl ProxyAclHandler { pub(crate) trait TcpProxyForWrappedSrcTrait: Send + Sync + 'static { type Connector: NatDstConnector; fn get_tcp_proxy(&self) -> &Arc>; + fn set_src_modified(hdr: &mut PeerManagerHeader, modified: bool) -> &mut PeerManagerHeader; async fn check_dst_allow_wrapped_input(&self, dst_ip: &Ipv4Addr) -> bool; } @@ -142,9 +141,7 @@ impl> NicPacket let hdr = zc_packet.mut_peer_manager_header().unwrap(); hdr.to_peer_id = self.get_tcp_proxy().get_my_peer_id().into(); - if self.get_tcp_proxy().get_transport_type() == TcpProxyEntryTransportType::Kcp { - hdr.set_kcp_src_modified(true); - } + Self::set_src_modified(hdr, true); true } } diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index 623a5b0b..9bcfbbc5 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -31,7 +31,7 @@ use crate::gateway::icmp_proxy::IcmpProxy; #[cfg(feature = "kcp")] use crate::gateway::kcp_proxy::{KcpProxyDst, KcpProxyDstRpcService, KcpProxySrc}; #[cfg(feature = "quic")] -use crate::gateway::quic_proxy::{QUICProxyDst, QUICProxyDstRpcService, QUICProxySrc}; +use crate::gateway::quic_proxy::{QuicProxy, QuicProxyDstRpcService}; use crate::gateway::tcp_proxy::{NatDstTcpConnector, TcpProxy, TcpProxyRpcService}; use crate::gateway::udp_proxy::UdpProxy; use crate::peer_center::instance::PeerCenterInstance; @@ -541,9 +541,7 @@ pub struct Instance { kcp_proxy_dst: Option, #[cfg(feature = "quic")] - quic_proxy_src: Option, - #[cfg(feature = "quic")] - quic_proxy_dst: Option, + quic_proxy: Option, peer_center: Arc, @@ -627,9 +625,7 @@ impl Instance { kcp_proxy_dst: None, #[cfg(feature = "quic")] - quic_proxy_src: None, - #[cfg(feature = "quic")] - quic_proxy_dst: None, + quic_proxy: None, peer_center, @@ -927,21 +923,6 @@ impl Instance { }); } - #[cfg(feature = "quic")] - async fn run_quic_dst(&mut self) -> Result<(), Error> { - if self.global_ctx.get_flags().disable_quic_input { - return Ok(()); - } - - let route = Arc::new(self.peer_manager.get_route()); - let quic_dst = QUICProxyDst::new(self.global_ctx.clone(), route)?; - quic_dst.start().await?; - self.global_ctx - .set_quic_proxy_port(Some(quic_dst.local_addr()?.port())); - self.quic_proxy_dst = Some(quic_dst); - Ok(()) - } - pub async fn run(&mut self) -> Result<(), Error> { self.listener_manager .lock() @@ -982,19 +963,13 @@ impl Instance { } #[cfg(feature = "quic")] - if self.global_ctx.get_flags().enable_quic_proxy { - let quic_src = QUICProxySrc::new(self.get_peer_manager()).await; - quic_src.start().await; - self.quic_proxy_src = Some(quic_src); - } - - #[cfg(feature = "quic")] - if !self.global_ctx.get_flags().disable_quic_input { - if let Err(e) = self.run_quic_dst().await { - eprintln!( - "quic input start failed: {:?} (some platforms may not support)", - e - ); + { + let quic_src = self.global_ctx.get_flags().enable_quic_proxy; + let quic_dst = !self.global_ctx.get_flags().disable_quic_input; + if quic_src || quic_dst { + let mut quic_proxy = QuicProxy::new(self.get_peer_manager()); + quic_proxy.run(quic_src, quic_dst).await; + self.quic_proxy = Some(quic_proxy); } } @@ -1423,19 +1398,20 @@ impl Instance { } #[cfg(feature = "quic")] - if let Some(quic_proxy) = self.quic_proxy_src.as_ref() { - tcp_proxy_rpc_services.insert( - "quic_src".to_string(), - Arc::new(TcpProxyRpcService::new(quic_proxy.get_tcp_proxy())), - ); - } + if let Some(quic_proxy) = self.quic_proxy.as_ref() { + if let Some(quic_src) = quic_proxy.src() { + tcp_proxy_rpc_services.insert( + "quic_src".to_string(), + Arc::new(TcpProxyRpcService::new(quic_src.get_tcp_proxy())), + ); + } - #[cfg(feature = "quic")] - if let Some(quic_proxy) = self.quic_proxy_dst.as_ref() { - tcp_proxy_rpc_services.insert( - "quic_dst".to_string(), - Arc::new(QUICProxyDstRpcService::new(quic_proxy)), - ); + if let Some(quic_dst) = quic_proxy.dst() { + tcp_proxy_rpc_services.insert( + "quic_dst".to_string(), + Arc::new(QuicProxyDstRpcService::new(quic_dst)), + ); + } } tcp_proxy_rpc_services diff --git a/easytier/src/launcher.rs b/easytier/src/launcher.rs index 4e6df74f..e4b1f34f 100644 --- a/easytier/src/launcher.rs +++ b/easytier/src/launcher.rs @@ -710,10 +710,6 @@ impl NetworkConfig { flags.disable_quic_input = disable_quic_input; } - if let Some(quic_listen_port) = self.quic_listen_port { - flags.quic_listen_port = quic_listen_port as u32; - } - if let Some(disable_p2p) = self.disable_p2p { flags.disable_p2p = disable_p2p; } @@ -912,7 +908,6 @@ impl NetworkConfig { result.disable_kcp_input = Some(flags.disable_kcp_input); result.enable_quic_proxy = Some(flags.enable_quic_proxy); result.disable_quic_input = Some(flags.disable_quic_input); - result.quic_listen_port = Some(flags.quic_listen_port as i32); result.disable_p2p = Some(flags.disable_p2p); result.p2p_only = Some(flags.p2p_only); result.bind_device = Some(flags.bind_device); diff --git a/easytier/src/peers/acl_filter.rs b/easytier/src/peers/acl_filter.rs index 2ae3c8f5..6b3d6263 100644 --- a/easytier/src/peers/acl_filter.rs +++ b/easytier/src/peers/acl_filter.rs @@ -1,5 +1,5 @@ use std::net::{Ipv4Addr, Ipv6Addr}; -use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::atomic::Ordering; use std::time::Instant; use std::{ net::IpAddr, @@ -59,7 +59,6 @@ pub struct AclFilter { // Use ArcSwap for lock-free atomic replacement during hot reload acl_processor: ArcSwap, acl_enabled: Arc, - quic_udp_port: AtomicU16, // Track allowed outbound packets and automatically allow their corresponding inbound response // packets, even if they would normally be dropped by ACL rules @@ -80,7 +79,6 @@ impl AclFilter { Self { acl_processor: ArcSwap::from(Arc::new(AclProcessor::new(Acl::default()))), acl_enabled: Arc::new(AtomicBool::new(false)), - quic_udp_port: AtomicU16::new(0), outbound_allow_records, clean_task: tokio::spawn(async move { let max_life = std::time::Duration::from_secs(30); @@ -295,40 +293,6 @@ impl AclFilter { processor.increment_stat(AclStatKey::PacketsTotal); } - fn check_is_quic_packet( - &self, - packet_info: &PacketInfo, - my_ipv4: &Option, - my_ipv6: &Option, - ) -> bool { - if packet_info.protocol != Protocol::Udp { - return false; - } - - let quic_port = self.get_quic_udp_port(); - if quic_port == 0 { - return false; - } - - // quic input - if packet_info.dst_port == Some(quic_port) - && (packet_info.dst_ip == my_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED) - || packet_info.dst_ip == my_ipv6.unwrap_or(Ipv6Addr::UNSPECIFIED)) - { - return true; - } - - // quic output - if packet_info.src_port == Some(quic_port) - && (packet_info.src_ip == my_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED) - || packet_info.src_ip == my_ipv6.unwrap_or(Ipv6Addr::UNSPECIFIED)) - { - return true; - } - - false - } - /// Common ACL processing logic pub fn process_packet_with_acl( &self, @@ -360,10 +324,6 @@ impl AclFilter { } }; - if self.check_is_quic_packet(&packet_info, &my_ipv4, &my_ipv6) { - return true; - } - let chain_type = if is_in { if packet_info.dst_ip == my_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED) || packet_info.dst_ip == my_ipv6.unwrap_or(Ipv6Addr::UNSPECIFIED) @@ -424,12 +384,4 @@ impl AclFilter { } } } - - pub fn get_quic_udp_port(&self) -> u16 { - self.quic_udp_port.load(Ordering::Relaxed) - } - - pub fn set_quic_udp_port(&self, port: u16) { - self.quic_udp_port.store(port, Ordering::Relaxed); - } } diff --git a/easytier/src/peers/foreign_network_manager.rs b/easytier/src/peers/foreign_network_manager.rs index 8adc78f8..47ede94c 100644 --- a/easytier/src/peers/foreign_network_manager.rs +++ b/easytier/src/peers/foreign_network_manager.rs @@ -169,6 +169,7 @@ impl ForeignNetworkEntry { let mut flags = config.get_flags(); flags.disable_relay_kcp = !global_ctx.get_flags().enable_relay_foreign_network_kcp; + flags.disable_relay_quic = !global_ctx.get_flags().enable_relay_foreign_network_quic; config.set_flags(flags); config.set_mapped_listeners(Some(global_ctx.config.get_mapped_listeners())); diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index 2717eccb..442312d9 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -1489,6 +1489,54 @@ impl PeerManager { true } + pub async fn check_allow_quic_to_dst(&self, dst_ip: &IpAddr) -> bool { + let route = self.get_route(); + let Some(dst_peer_id) = route.get_peer_id_by_ip(dst_ip).await else { + return false; + }; + let Some(peer_info) = route.get_peer_info(dst_peer_id).await else { + return false; + }; + + // check dst allow quic input + if !peer_info + .feature_flag + .map(|x| x.quic_input) + .unwrap_or(false) + { + return false; + } + + let next_hop_policy = Self::get_next_hop_policy(self.global_ctx.get_flags().latency_first); + // check relay node allow relay quic. + let Some(next_hop_id) = route + .get_next_hop_with_policy(dst_peer_id, next_hop_policy) + .await + else { + return false; + }; + + if next_hop_id == dst_peer_id { + // dst p2p, no need to relay + return true; + } + + let Some(next_hop_info) = route.get_peer_info(next_hop_id).await else { + return false; + }; + + // check next hop allow quic relay + if next_hop_info + .feature_flag + .map(|x| x.no_relay_quic) + .unwrap_or(false) + { + return false; + } + + true + } + pub async fn update_exit_nodes(&self) { let exit_nodes = self.global_ctx.config.get_exit_nodes(); *self.exit_nodes.write().await = exit_nodes; diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index e8d3b062..2b8f3316 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -123,6 +123,7 @@ fn is_foreign_network_info_newer( } impl RoutePeerInfo { + #[allow(deprecated)] pub fn new() -> Self { Self { peer_id: 0, @@ -141,9 +142,10 @@ impl RoutePeerInfo { feature_flag: None, peer_route_id: 0, network_length: 24, - quic_port: None, ipv6_addr: None, groups: Vec::new(), + + quic_port: None, } } @@ -191,10 +193,11 @@ impl RoutePeerInfo { .map(|x| x.network_length() as u32) .unwrap_or(24), - quic_port: global_ctx.get_quic_proxy_port().map(|x| x as u32), ipv6_addr: global_ctx.get_ipv6().map(|x| x.into()), groups: global_ctx.get_acl_groups(my_peer_id), + + ..Default::default() } } diff --git a/easytier/src/proto/api_manage.proto b/easytier/src/proto/api_manage.proto index de52ebc1..36e75c26 100644 --- a/easytier/src/proto/api_manage.proto +++ b/easytier/src/proto/api_manage.proto @@ -72,7 +72,7 @@ message NetworkConfig { optional bool enable_quic_proxy = 45; optional bool disable_quic_input = 46; - optional int32 quic_listen_port = 50; + optional int32 quic_listen_port = 50 [deprecated = true]; repeated PortForwardConfig port_forwards = 48; optional bool disable_sym_hole_punching = 49; diff --git a/easytier/src/proto/common.proto b/easytier/src/proto/common.proto index 9bbcb7d6..bb9f82b9 100644 --- a/easytier/src/proto/common.proto +++ b/easytier/src/proto/common.proto @@ -41,8 +41,11 @@ message FlagsInConfig { bool enable_quic_proxy = 24; // does this peer allow quic input bool disable_quic_input = 25; + // disable relay local network quic packets + bool disable_relay_quic = 35; + // quic listen port - uint32 quic_listen_port = 33; + uint32 quic_listen_port = 33 [deprecated = true]; // a global relay limit, only work for foreign network uint64 foreign_relay_bps_limit = 26; @@ -52,6 +55,9 @@ message FlagsInConfig { // enable relay foreign network kcp packets bool enable_relay_foreign_network_kcp = 28; + // enable relay foreign network quic packets + bool enable_relay_foreign_network_quic = 36; + // encryption algorithm to use, empty string means default (aes-gcm) string encryption_algorithm = 29; @@ -208,6 +214,8 @@ message PeerFeatureFlag { bool kcp_input = 3; bool no_relay_kcp = 4; bool support_conn_list_sync = 5; + bool quic_input = 6; + bool no_relay_quic = 7; } enum SocketType { diff --git a/easytier/src/proto/peer_rpc.proto b/easytier/src/proto/peer_rpc.proto index 8d162560..77861f74 100644 --- a/easytier/src/proto/peer_rpc.proto +++ b/easytier/src/proto/peer_rpc.proto @@ -23,7 +23,7 @@ message RoutePeerInfo { uint32 network_length = 13; - optional uint32 quic_port = 14; + optional uint32 quic_port = 14 [deprecated = true]; optional common.Ipv6Inet ipv6_addr = 15; repeated PeerGroupInfo groups = 16; diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index b5e349c0..1b014a00 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -634,22 +634,7 @@ pub async fn subnet_proxy_three_node_test( subnet_proxy_test_tcp(listen_ip, target_ip).await; subnet_proxy_test_udp(listen_ip, target_ip).await; } - - if enable_kcp_proxy && !disable_kcp_input { - let metrics = insts[0] - .get_global_ctx() - .stats_manager() - .get_metrics_by_prefix(&MetricName::TcpProxyConnect.to_string()); - assert_eq!(metrics.len(), 3); - for metric in metrics { - assert_eq!(1, metric.value); - assert!(metric.labels.labels().iter().any(|l| { - let t = - LabelType::Protocol(TcpProxyEntryTransportType::Kcp.as_str_name().to_string()); - t.key() == l.key && t.value() == l.value - })); - } - } else if enable_quic_proxy && !disable_quic_input { + if enable_quic_proxy && !disable_quic_input { let metrics = insts[0] .get_global_ctx() .stats_manager() @@ -663,6 +648,20 @@ pub async fn subnet_proxy_three_node_test( t.key() == l.key && t.value() == l.value })); } + } else if enable_kcp_proxy && !disable_kcp_input { + let metrics = insts[0] + .get_global_ctx() + .stats_manager() + .get_metrics_by_prefix(&MetricName::TcpProxyConnect.to_string()); + assert_eq!(metrics.len(), 3); + for metric in metrics { + assert_eq!(1, metric.value); + assert!(metric.labels.labels().iter().any(|l| { + let t = + LabelType::Protocol(TcpProxyEntryTransportType::Kcp.as_str_name().to_string()); + t.key() == l.key && t.value() == l.value + })); + } } else { // tcp subnet proxy let metrics = insts[2] diff --git a/easytier/src/tunnel/packet_def.rs b/easytier/src/tunnel/packet_def.rs index 095eb321..ed4f84fe 100644 --- a/easytier/src/tunnel/packet_def.rs +++ b/easytier/src/tunnel/packet_def.rs @@ -72,6 +72,8 @@ pub enum PacketType { ForeignNetworkPacket = 10, KcpSrc = 11, KcpDst = 12, + QuicSrc = 16, + QuicDst = 17, NoiseHandshakeMsg1 = 13, NoiseHandshakeMsg2 = 14, NoiseHandshakeMsg3 = 15, @@ -85,6 +87,7 @@ bitflags::bitflags! { const NO_PROXY = 0b0000_1000; const COMPRESSED = 0b0001_0000; const KCP_SRC_MODIFIED = 0b0010_0000; + const QUIC_SRC_MODIFIED = 0b1000_0000; const NOT_SEND_TO_TUN = 0b0100_0000; const _ = !0; @@ -206,6 +209,23 @@ impl PeerManagerHeader { .contains(PeerManagerHeaderFlags::KCP_SRC_MODIFIED) } + pub fn set_quic_src_modified(&mut self, modified: bool) -> &mut Self { + let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap(); + if modified { + flags.insert(PeerManagerHeaderFlags::QUIC_SRC_MODIFIED); + } else { + flags.remove(PeerManagerHeaderFlags::QUIC_SRC_MODIFIED); + } + self.flags = flags.bits(); + self + } + + pub fn is_quic_src_modified(&self) -> bool { + PeerManagerHeaderFlags::from_bits(self.flags) + .unwrap() + .contains(PeerManagerHeaderFlags::QUIC_SRC_MODIFIED) + } + pub fn set_not_send_to_tun(&mut self, not_send_to_tun: bool) -> &mut Self { let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap(); if not_send_to_tun { diff --git a/easytier/src/tunnel/quic.rs b/easytier/src/tunnel/quic.rs index e8772f42..8512d048 100644 --- a/easytier/src/tunnel/quic.rs +++ b/easytier/src/tunnel/quic.rs @@ -26,12 +26,12 @@ pub fn transport_config() -> Arc { let mut config = TransportConfig::default(); config - // .max_concurrent_bidi_streams(VarInt::MAX) + .max_concurrent_bidi_streams(u8::MAX.into()) .max_concurrent_uni_streams(0u8.into()) .keep_alive_interval(Some(Duration::from_secs(5))) .initial_mtu(1200) .min_mtu(1200) - .enable_segmentation_offload(false) + .enable_segmentation_offload(true) .congestion_controller_factory(Arc::new(BbrConfig::default())); Arc::new(config)