diff --git a/easytier/src/common/global_ctx.rs b/easytier/src/common/global_ctx.rs index b6af900b..04d1eefd 100644 --- a/easytier/src/common/global_ctx.rs +++ b/easytier/src/common/global_ctx.rs @@ -42,6 +42,7 @@ pub type NetworkIdentity = crate::common::config::NetworkIdentity; pub enum GlobalCtxEvent { TunDeviceReady(String), TunDeviceError(String), + TunDeviceFallback(String), PeerAdded(PeerId), PeerRemoved(PeerId), diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index 8d30f5fe..3f9f48f9 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "tun")] -use std::any::Any; use std::collections::HashSet; use std::net::{IpAddr, Ipv4Addr}; use std::sync::atomic::{AtomicBool, Ordering}; @@ -65,6 +63,10 @@ use crate::vpn_portal::{self, VpnPortal}; #[cfg(feature = "magic-dns")] use super::dns_server::{MAGIC_DNS_FAKE_IP, runner::DnsRunner}; use super::listeners::ListenerManager; +#[cfg(feature = "tun")] +use super::shared_tun::{ + SharedTunAccess, SharedTunAttachError, SharedTunMemberHandle, try_attach_shared_tun, +}; #[cfg(feature = "socks5")] use crate::gateway::socks5::Socks5Server; @@ -133,6 +135,13 @@ impl IpProxy { #[cfg(feature = "tun")] type NicCtx = super::virtual_nic::NicCtx; +#[cfg(feature = "tun")] +enum NicRuntime { + Dedicated(NicCtx), + Shared(SharedTunMemberHandle), + Dummy(JoinSet<()>), +} + #[cfg(feature = "magic-dns")] struct MagicDnsContainer { dns_runner_task: ScopedTask<()>, @@ -142,7 +151,7 @@ struct MagicDnsContainer { // nic container will be cleared when dhcp ip changed #[cfg(feature = "tun")] pub struct NicCtxContainer { - nic_ctx: Option>, + nic_ctx: Option, #[cfg(feature = "magic-dns")] magic_dns: Option, } @@ -150,14 +159,14 @@ pub struct NicCtxContainer { #[cfg(feature = "tun")] impl NicCtxContainer { #[cfg(not(feature = "magic-dns"))] - fn new(nic_ctx: NicCtx) -> Self { + fn new_dedicated(nic_ctx: NicCtx) -> Self { Self { - nic_ctx: 
Some(Box::new(nic_ctx)), + nic_ctx: Some(NicRuntime::Dedicated(nic_ctx)), } } #[cfg(feature = "magic-dns")] - fn new(nic_ctx: NicCtx, dns_runner: Option) -> Self { + fn new_dedicated(nic_ctx: NicCtx, dns_runner: Option) -> Self { if let Some(mut dns_runner) = dns_runner { let token = CancellationToken::new(); let token_clone = token.clone(); @@ -165,7 +174,7 @@ impl NicCtxContainer { let _ = dns_runner.run(token_clone).await; }); Self { - nic_ctx: Some(Box::new(nic_ctx)), + nic_ctx: Some(NicRuntime::Dedicated(nic_ctx)), magic_dns: Some(MagicDnsContainer { dns_runner_task: task.into(), dns_runner_cancel_token: token, @@ -173,15 +182,45 @@ impl NicCtxContainer { } } else { Self { - nic_ctx: Some(Box::new(nic_ctx)), + nic_ctx: Some(NicRuntime::Dedicated(nic_ctx)), magic_dns: None, } } } - fn new_with_any(ctx: T) -> Self { + #[cfg(not(feature = "magic-dns"))] + fn new_shared(handle: SharedTunMemberHandle) -> Self { Self { - nic_ctx: Some(Box::new(ctx)), + nic_ctx: Some(NicRuntime::Shared(handle)), + } + } + + #[cfg(feature = "magic-dns")] + fn new_shared(handle: SharedTunMemberHandle, dns_runner: Option) -> Self { + if let Some(mut dns_runner) = dns_runner { + let token = CancellationToken::new(); + let token_clone = token.clone(); + let task = tokio::spawn(async move { + let _ = dns_runner.run(token_clone).await; + }); + Self { + nic_ctx: Some(NicRuntime::Shared(handle)), + magic_dns: Some(MagicDnsContainer { + dns_runner_task: task.into(), + dns_runner_cancel_token: token, + }), + } + } else { + Self { + nic_ctx: Some(NicRuntime::Shared(handle)), + magic_dns: None, + } + } + } + + fn new_dummy(tasks: JoinSet<()>) -> Self { + Self { + nic_ctx: Some(NicRuntime::Dummy(tasks)), #[cfg(feature = "magic-dns")] magic_dns: None, } @@ -666,19 +705,31 @@ impl Instance { // use a mock nic ctx to consume packets. 
#[cfg(feature = "tun")] - async fn clear_nic_ctx( - arc_nic_ctx: ArcNicCtx, - packet_recv: Arc>, - ) { + async fn cleanup_nic_ctx(mut old_ctx: NicCtxContainer) { + if let Some(runtime) = old_ctx.nic_ctx.take() { + match runtime { + NicRuntime::Shared(handle) => handle.shutdown().await, + NicRuntime::Dedicated(_) | NicRuntime::Dummy(_) => {} + } + } + #[cfg(feature = "magic-dns")] - if let Some(old_ctx) = arc_nic_ctx.lock().await.take() - && let Some(dns_runner) = old_ctx.magic_dns - { + if let Some(dns_runner) = old_ctx.magic_dns.take() { dns_runner.dns_runner_cancel_token.cancel(); tracing::debug!("cancelling dns runner task"); let ret = dns_runner.dns_runner_task.await; tracing::debug!("dns runner task cancelled, ret: {:?}", ret); - }; + } + } + + #[cfg(feature = "tun")] + async fn clear_nic_ctx( + arc_nic_ctx: ArcNicCtx, + packet_recv: Arc>, + ) { + if let Some(old_ctx) = arc_nic_ctx.lock().await.take() { + Self::cleanup_nic_ctx(old_ctx).await; + } let mut tasks = JoinSet::new(); tasks.spawn(async move { @@ -690,7 +741,7 @@ impl Instance { arc_nic_ctx .lock() .await - .replace(NicCtxContainer::new_with_any(tasks)); + .replace(NicCtxContainer::new_dummy(tasks)); tracing::debug!("nic ctx cleared."); } @@ -716,13 +767,13 @@ impl Instance { } #[cfg(feature = "tun")] - async fn use_new_nic_ctx( + async fn use_new_dedicated_nic_ctx( arc_nic_ctx: ArcNicCtx, nic_ctx: NicCtx, #[cfg(feature = "magic-dns")] magic_dns: Option, ) { let mut g = arc_nic_ctx.lock().await; - *g = Some(NicCtxContainer::new( + *g = Some(NicCtxContainer::new_dedicated( nic_ctx, #[cfg(feature = "magic-dns")] magic_dns, @@ -730,6 +781,91 @@ impl Instance { tracing::debug!("nic ctx updated."); } + #[cfg(feature = "tun")] + async fn use_new_shared_nic_ctx( + arc_nic_ctx: ArcNicCtx, + handle: SharedTunMemberHandle, + #[cfg(feature = "magic-dns")] magic_dns: Option, + ) { + let mut g = arc_nic_ctx.lock().await; + *g = Some(NicCtxContainer::new_shared( + handle, + #[cfg(feature = "magic-dns")] + 
magic_dns, + )); + tracing::debug!("shared nic ctx updated."); + } + + #[cfg(all(not(mobile), feature = "tun"))] + async fn setup_dedicated_nic_ctx( + arc_nic_ctx: ArcNicCtx, + global_ctx: ArcGlobalCtx, + peer_mgr: Arc, + peer_packet_receiver: Arc>, + close_notifier: Arc, + ipv4_addr: Option, + ipv6_addr: Option, + ) -> Result<(), Error> { + let mut new_nic_ctx = + NicCtx::new(global_ctx, &peer_mgr, peer_packet_receiver, close_notifier); + new_nic_ctx.run(ipv4_addr, ipv6_addr).await?; + + #[cfg(feature = "magic-dns")] + { + let ifname = new_nic_ctx.ifname().await; + let dns_runner = + ipv4_addr.and_then(|ipv4| Self::create_magic_dns_runner(peer_mgr, ifname, ipv4)); + Self::use_new_dedicated_nic_ctx(arc_nic_ctx, new_nic_ctx, dns_runner).await; + } + #[cfg(not(feature = "magic-dns"))] + Self::use_new_dedicated_nic_ctx(arc_nic_ctx, new_nic_ctx).await; + + Ok(()) + } + + #[cfg(all(not(mobile), feature = "tun"))] + async fn try_setup_shared_tun( + arc_nic_ctx: ArcNicCtx, + global_ctx: ArcGlobalCtx, + peer_mgr: Arc, + peer_packet_receiver: Arc>, + close_notifier: Arc, + ipv4_addr: Option, + ) -> Result { + match try_attach_shared_tun( + global_ctx.clone(), + peer_mgr.clone(), + peer_packet_receiver, + close_notifier, + SharedTunAccess::Native, + ) + .await + { + Ok(attached) => { + global_ctx.issue_event(GlobalCtxEvent::TunDeviceReady(attached.ifname.clone())); + #[cfg(feature = "magic-dns")] + let dns_runner = ipv4_addr + .and_then(|ip| { + Self::create_magic_dns_runner(peer_mgr, Some(attached.ifname.clone()), ip) + }); + Self::use_new_shared_nic_ctx( + arc_nic_ctx, + attached.handle, + #[cfg(feature = "magic-dns")] + dns_runner, + ) + .await; + Ok(true) + } + Err(SharedTunAttachError::Fallback(reason)) => { + tracing::info!(instance_id = %global_ctx.get_id(), %reason, "shared tun unavailable, falling back to dedicated tun"); + global_ctx.issue_event(GlobalCtxEvent::TunDeviceFallback(reason)); + Ok(false) + } + Err(SharedTunAttachError::Fatal(err)) => Err(err), + } + } 
+ // Warning, if there is an IP conflict in the network when using DHCP, the IP will be automatically changed. fn check_dhcp_ip_conflict(&self) { use rand::Rng; @@ -813,37 +949,57 @@ impl Instance { continue; } + global_ctx_c.set_ipv4(Some(ip)); + #[cfg(all(not(mobile), feature = "tun"))] { - let mut new_nic_ctx = NicCtx::new( + match Self::try_setup_shared_tun( + nic_ctx.clone(), global_ctx_c.clone(), - &peer_manager_c, + peer_manager_c.clone(), _peer_packet_receiver.clone(), nic_closed_notifier.clone(), - ); - if let Err(e) = new_nic_ctx.run(Some(ip), global_ctx_c.get_ipv6()).await { - tracing::error!( - ?current_dhcp_ip, - ?candidate_ipv4_addr, - ?e, - "add ip failed" - ); - global_ctx_c.set_ipv4(None); - continue; - } - #[cfg(feature = "magic-dns")] - let ifname = new_nic_ctx.ifname().await; - Self::use_new_nic_ctx( - nic_ctx.clone(), - new_nic_ctx, - #[cfg(feature = "magic-dns")] - Self::create_magic_dns_runner(peer_manager_c.clone(), ifname, ip), + Some(ip), ) - .await; + .await + { + Ok(true) => {} + Ok(false) => { + if let Err(e) = Self::setup_dedicated_nic_ctx( + nic_ctx.clone(), + global_ctx_c.clone(), + peer_manager_c.clone(), + _peer_packet_receiver.clone(), + nic_closed_notifier.clone(), + Some(ip), + global_ctx_c.get_ipv6(), + ) + .await + { + tracing::error!( + ?current_dhcp_ip, + ?candidate_ipv4_addr, + ?e, + "add ip failed" + ); + global_ctx_c.set_ipv4(None); + continue; + } + } + Err(e) => { + tracing::error!( + ?current_dhcp_ip, + ?candidate_ipv4_addr, + ?e, + "shared tun attach failed" + ); + global_ctx_c.set_ipv4(None); + continue; + } + } } current_dhcp_ip = Some(ip); - global_ctx_c.set_ipv4(Some(ip)); global_ctx_c.issue_event(GlobalCtxEvent::DhcpIpv4Changed(last_ip, Some(ip))); } else { current_dhcp_ip = None; @@ -883,36 +1039,51 @@ impl Instance { return; }; - let mut new_nic_ctx = NicCtx::new( + let shared_ready = match Self::try_setup_shared_tun( + nic_ctx.clone(), peer_mgr.get_global_ctx(), - &peer_mgr, + peer_mgr.clone(), 
peer_packet_receiver.clone(), close_notifier.clone(), - ); - - if let Err(e) = new_nic_ctx.run(ipv4_addr, ipv6_addr).await { - if let Some(output_tx) = output_tx.take() { - let _ = output_tx.send(Err(e)); - return; - } - tracing::error!("failed to create new nic ctx, err: {:?}", e); - tokio::time::sleep(Duration::from_secs(1)).await; - continue; - } - - // Create Magic DNS runner only if we have IPv4 - #[cfg(feature = "magic-dns")] + ipv4_addr, + ) + .await { - let ifname = new_nic_ctx.ifname().await; - let dns_runner = if let Some(ipv4) = ipv4_addr { - Self::create_magic_dns_runner(peer_mgr, ifname, ipv4) - } else { - None - }; - Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx, dns_runner).await; + Ok(ready) => ready, + Err(e) => { + if let Some(output_tx) = output_tx.take() { + let _ = output_tx.send(Err(e)); + return; + } + tracing::error!("failed to attach shared tun, err: {:?}", e); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } + }; + + if shared_ready { + // shared nic context is installed + } else { + if let Err(e) = Self::setup_dedicated_nic_ctx( + nic_ctx.clone(), + peer_mgr.get_global_ctx(), + peer_mgr.clone(), + peer_packet_receiver.clone(), + close_notifier.clone(), + ipv4_addr, + ipv6_addr, + ) + .await + { + if let Some(output_tx) = output_tx.take() { + let _ = output_tx.send(Err(e)); + return; + } + tracing::error!("failed to create new nic ctx, err: {:?}", e); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } } - #[cfg(not(feature = "magic-dns"))] - Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx).await; } if let Some(output_tx) = output_tx.take() { @@ -1480,30 +1651,60 @@ impl Instance { return Ok(()); } let close_notifier = Arc::new(Notify::new()); - let mut new_nic_ctx = NicCtx::new( + match try_attach_shared_tun( global_ctx.clone(), - &peer_manager, + peer_manager.clone(), peer_packet_receiver.clone(), close_notifier.clone(), - ); - new_nic_ctx - .run_for_mobile(fd) - .await - .with_context(|| "add 
ip failed")?; + SharedTunAccess::MobileFd(fd), + ) + .await + { + Ok(attached) => { + global_ctx.issue_event(GlobalCtxEvent::TunDeviceReady(attached.ifname.clone())); + let magic_dns_runner = if let Some(ipv4) = global_ctx.get_ipv4() { + Self::create_magic_dns_runner( + peer_manager.clone(), + Some(attached.ifname.clone()), + ipv4, + ) + } else { + None + }; + Self::use_new_shared_nic_ctx(nic_ctx, attached.handle, magic_dns_runner).await; + } + Err(SharedTunAttachError::Fallback(reason)) => { + tracing::info!(instance_id = %global_ctx.get_id(), %reason, "shared mobile tun unavailable, falling back to dedicated tun"); + global_ctx.issue_event(GlobalCtxEvent::TunDeviceFallback(reason)); + let mut dedicated_nic_ctx = NicCtx::new( + global_ctx.clone(), + &peer_manager, + peer_packet_receiver.clone(), + close_notifier.clone(), + ); + dedicated_nic_ctx + .run_for_mobile(fd) + .await + .with_context(|| "add ip failed")?; - let magic_dns_runner = if let Some(ipv4) = global_ctx.get_ipv4() { - Self::create_magic_dns_runner(peer_manager.clone(), None, ipv4) - } else { - None - }; - Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx, magic_dns_runner).await; + let magic_dns_runner = if let Some(ipv4) = global_ctx.get_ipv4() { + Self::create_magic_dns_runner(peer_manager.clone(), None, ipv4) + } else { + None + }; + Self::use_new_dedicated_nic_ctx(nic_ctx, dedicated_nic_ctx, magic_dns_runner).await; + } + Err(SharedTunAttachError::Fatal(err)) => return Err(err.into()), + } Ok(()) } pub async fn clear_resources(&mut self) { self.peer_manager.clear_resources().await; #[cfg(feature = "tun")] - let _ = self.nic_ctx.lock().await.take(); + if let Some(old_ctx) = self.nic_ctx.lock().await.take() { + Self::cleanup_nic_ctx(old_ctx).await; + } } } @@ -1515,7 +1716,9 @@ impl Drop for Instance { let nic_ctx = self.nic_ctx.clone(); tokio::spawn(async move { #[cfg(feature = "tun")] - nic_ctx.lock().await.take(); + if let Some(old_ctx) = nic_ctx.lock().await.take() { + 
Instance::cleanup_nic_ctx(old_ctx).await; + } if let Some(pm) = pm.upgrade() { pm.clear_resources().await; }; diff --git a/easytier/src/instance/mod.rs b/easytier/src/instance/mod.rs index d50ad23e..3e3e9b2a 100644 --- a/easytier/src/instance/mod.rs +++ b/easytier/src/instance/mod.rs @@ -6,5 +6,8 @@ pub mod listeners; pub mod proxy_cidrs_monitor; +#[cfg(feature = "tun")] +pub mod shared_tun; + #[cfg(feature = "tun")] pub mod virtual_nic; diff --git a/easytier/src/instance/shared_tun.rs b/easytier/src/instance/shared_tun.rs new file mode 100644 index 00000000..4a121d8b --- /dev/null +++ b/easytier/src/instance/shared_tun.rs @@ -0,0 +1,1002 @@ +use std::{ + collections::{BTreeSet, HashMap}, + net::{Ipv4Addr, Ipv6Addr}, + sync::{Arc, Weak}, +}; + +use futures::{SinkExt, StreamExt}; +use once_cell::sync::Lazy; +use tokio::{ + sync::{Mutex, Notify, RwLock, mpsc}, + task::JoinSet, +}; +use tokio_util::sync::CancellationToken; + +use crate::{ + common::{ + config::ConfigLoader, + error::Error, + global_ctx::{ArcGlobalCtx, GlobalCtxEvent}, + netns::NetNS, + }, + instance::{ + proxy_cidrs_monitor::ProxyCidrsMonitor, + virtual_nic::{NicCtx, VirtualNic}, + }, + peers::{PacketRecvChanReceiver, peer_manager::PeerManager, recv_packet_from_chan}, + tunnel::{Tunnel, packet_def::ZCPacket}, +}; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum SharedTunAccess { + Native, + #[cfg(mobile)] + MobileFd(i32), +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct SharedTunKey { + netns: Option, + access: SharedTunAccess, +} + +impl SharedTunKey { + fn new(netns: &NetNS, access: SharedTunAccess) -> Self { + Self { + netns: netns.name(), + access, + } + } +} + +pub struct SharedTunAttach { + pub handle: SharedTunMemberHandle, + pub ifname: String, +} + +pub enum SharedTunAttachError { + Fallback(String), + Fatal(Error), +} + +impl From for SharedTunAttachError { + fn from(value: Error) -> Self { + Self::Fatal(value) + } +} + +#[derive(Clone, Debug, Default, PartialEq, Eq)] 
+struct MemberClaims { + ipv4: Option, + ipv6: Option, + owned_proxy_v4_routes: BTreeSet, + proxy_v4_routes: BTreeSet, + effective_mtu: u16, + default_route_v4: bool, + default_route_v6: bool, +} + +impl MemberClaims { + fn local_v4_prefix(&self) -> Option { + self.ipv4.map(|ipv4| ipv4.network()) + } + + fn shared_route_v4_prefixes(&self) -> BTreeSet { + let mut ret = self.owned_proxy_v4_routes.clone(); + if self.default_route_v4 { + ret.insert(cidr::Ipv4Cidr::new(Ipv4Addr::UNSPECIFIED, 0).unwrap()); + } + ret + } + + fn reachable_v4_prefixes(&self) -> BTreeSet { + let mut ret = self.proxy_v4_routes.clone(); + if self.default_route_v4 { + ret.insert(cidr::Ipv4Cidr::new(Ipv4Addr::UNSPECIFIED, 0).unwrap()); + } + ret + } + + fn dispatch_v4_prefixes(&self) -> BTreeSet { + let mut ret = self.shared_route_v4_prefixes(); + if let Some(prefix) = self.local_v4_prefix() { + ret.insert(prefix); + } + ret + } + + fn local_v6_prefix(&self) -> Option { + self.ipv6.map(|ipv6| ipv6.network()) + } + + fn shared_route_v6_prefixes(&self) -> BTreeSet { + let mut ret = BTreeSet::new(); + if self.default_route_v6 { + ret.insert(cidr::Ipv6Cidr::new(Ipv6Addr::UNSPECIFIED, 0).unwrap()); + } + ret + } + + fn dispatch_v6_prefixes(&self) -> BTreeSet { + let mut ret = self.shared_route_v6_prefixes(); + if let Some(prefix) = self.local_v6_prefix() { + ret.insert(prefix); + } + ret + } +} + +#[derive(Clone)] +struct MemberRuntimeContext { + device: Arc, + slot: Arc, + peer_packet_receiver: Arc>, +} + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +struct AppliedConfig { + ipv4_addrs: BTreeSet<(Ipv4Addr, u8)>, + ipv6_addrs: BTreeSet<(Ipv6Addr, u8)>, + ipv4_routes: BTreeSet, + ipv6_routes: BTreeSet, + mtu: Option, +} + +struct SharedTunMemberSlot { + instance_id: uuid::Uuid, + global_ctx: ArcGlobalCtx, + peer_manager: Weak, + claims: RwLock, + close_notifier: Arc, +} + +struct SharedTunDevice { + key: SharedTunKey, + nic: Arc>, + ifname: String, + writer_tx: mpsc::Sender, + members: RwLock>>, + 
current_config: Mutex, + cancel: CancellationToken, + tasks: Mutex>, +} + +pub struct SharedTunMemberHandle { + device: Arc, + member_id: uuid::Uuid, + cancel: CancellationToken, + tasks: JoinSet<()>, + shutdown: bool, +} + +static SHARED_TUN_REGISTRY: Lazy = Lazy::new(SharedTunRegistry::default); + +#[derive(Default)] +struct SharedTunRegistry { + devices: Mutex>>, +} + +struct SharedTunAttachRequest { + key: SharedTunKey, + global_ctx: ArcGlobalCtx, + peer_manager: Arc, + peer_packet_receiver: Arc>, + close_notifier: Arc, + claims: MemberClaims, + access: SharedTunAccess, +} + +pub async fn try_attach_shared_tun( + global_ctx: ArcGlobalCtx, + peer_manager: Arc, + peer_packet_receiver: Arc>, + close_notifier: Arc, + access: SharedTunAccess, +) -> Result { + if global_ctx.config.get_flags().no_tun { + return Err(SharedTunAttachError::Fallback( + "shared tun is disabled when no_tun is enabled".to_owned(), + )); + } + + let proxy_v4_routes = + ProxyCidrsMonitor::diff_proxy_cidrs(peer_manager.as_ref(), &global_ctx, &BTreeSet::new()) + .await + .0; + let claims = build_member_claims(&global_ctx, proxy_v4_routes); + let key = SharedTunKey::new(&global_ctx.net_ns, access.clone()); + + SHARED_TUN_REGISTRY + .attach(SharedTunAttachRequest { + global_ctx, + peer_manager, + peer_packet_receiver, + close_notifier, + key, + claims, + access, + }) + .await +} + +impl SharedTunRegistry { + async fn attach( + &self, + request: SharedTunAttachRequest, + ) -> Result { + let SharedTunAttachRequest { + key, + global_ctx, + peer_manager, + peer_packet_receiver, + close_notifier, + claims, + access, + } = request; + + let device = { + let mut devices = self.devices.lock().await; + if let Some(device) = devices.get(&key) { + device.clone() + } else { + let device = SharedTunDevice::new(key.clone(), global_ctx.clone(), access) + .await + .map_err(SharedTunAttachError::Fatal)?; + devices.insert(key.clone(), device.clone()); + device + } + }; + + let handle = match device + .attach_member( 
+ global_ctx.clone(), + peer_manager, + peer_packet_receiver, + close_notifier, + claims, + ) + .await + { + Ok(handle) => handle, + Err(err) => return Err(SharedTunAttachError::Fallback(err)), + }; + + Ok(SharedTunAttach { + handle, + ifname: device.ifname.clone(), + }) + } + + async fn remove_if_unused(&self, key: &SharedTunKey, device: &Arc) { + let should_shutdown = { + let mut devices = self.devices.lock().await; + if let Some(existing) = devices.get(key) + && Arc::ptr_eq(existing, device) + && device.member_count().await == 0 + { + devices.remove(key); + true + } else { + false + } + }; + + if should_shutdown { + device.shutdown().await; + } + } +} + +impl SharedTunDevice { + async fn new( + key: SharedTunKey, + global_ctx: ArcGlobalCtx, + access: SharedTunAccess, + ) -> Result, Error> { + let mut nic = VirtualNic::new(global_ctx); + let tunnel = match access { + SharedTunAccess::Native => nic.create_dev().await?, + #[cfg(mobile)] + SharedTunAccess::MobileFd(fd) => nic.create_dev_for_mobile(fd).await?, + }; + + let ifname = nic.ifname().to_owned(); + let (stream, sink) = tunnel.split(); + let (writer_tx, writer_rx) = mpsc::channel(256); + + let device = Arc::new(Self { + key, + nic: Arc::new(Mutex::new(nic)), + ifname, + writer_tx, + members: RwLock::new(HashMap::new()), + current_config: Mutex::new(AppliedConfig::default()), + cancel: CancellationToken::new(), + tasks: Mutex::new(JoinSet::new()), + }); + device.start_runtime_tasks(stream, sink, writer_rx).await; + Ok(device) + } + + async fn start_runtime_tasks( + self: &Arc, + mut stream: std::pin::Pin>, + mut sink: std::pin::Pin>, + mut writer_rx: mpsc::Receiver, + ) { + let cancel = self.cancel.clone(); + let reader_device = self.clone(); + self.tasks.lock().await.spawn(async move { + loop { + tokio::select! 
{ + _ = cancel.cancelled() => break, + item = stream.next() => { + let Some(item) = item else { + reader_device.fail_all_members("shared tun reader closed").await; + break; + }; + + match item { + Ok(packet) => reader_device.dispatch_packet(packet).await, + Err(err) => { + tracing::error!(?err, "shared tun reader error"); + reader_device.fail_all_members("shared tun reader error").await; + break; + } + } + } + } + } + }); + + let cancel = self.cancel.clone(); + let writer_device = self.clone(); + self.tasks.lock().await.spawn(async move { + loop { + tokio::select! { + _ = cancel.cancelled() => break, + packet = writer_rx.recv() => { + let Some(packet) = packet else { + break; + }; + if let Err(err) = sink.send(packet).await { + tracing::error!(?err, "shared tun writer error"); + writer_device.fail_all_members("shared tun writer error").await; + break; + } + } + } + } + }); + } + + async fn attach_member( + self: &Arc, + global_ctx: ArcGlobalCtx, + peer_manager: Arc, + peer_packet_receiver: Arc>, + close_notifier: Arc, + claims: MemberClaims, + ) -> Result { + let dev_name = global_ctx.get_flags().dev_name; + if !dev_name.is_empty() && dev_name != self.ifname { + return Err(format!( + "shared tun device {} does not match requested dev_name {}", + self.ifname, dev_name + )); + } + + self.validate_claims(global_ctx.get_id(), &claims).await?; + + let slot = Arc::new(SharedTunMemberSlot { + instance_id: global_ctx.get_id(), + global_ctx: global_ctx.clone(), + peer_manager: Arc::downgrade(&peer_manager), + claims: RwLock::new(claims), + close_notifier, + }); + self.members + .write() + .await + .insert(slot.instance_id, slot.clone()); + + if let Err(err) = self.apply_config().await { + self.members.write().await.remove(&slot.instance_id); + return Err(format!("failed to apply shared tun config: {err}")); + } + + let cancel = CancellationToken::new(); + let mut tasks = JoinSet::new(); + let runtime_ctx = MemberRuntimeContext { + device: self.clone(), + slot: slot.clone(), 
+ peer_packet_receiver, + }; + self.spawn_peer_to_tun_task(&mut tasks, runtime_ctx.clone(), cancel.clone()); + self.spawn_member_refresh_task(&mut tasks, runtime_ctx, global_ctx, cancel.clone()); + + Ok(SharedTunMemberHandle { + device: self.clone(), + member_id: slot.instance_id, + cancel, + tasks, + shutdown: false, + }) + } + + async fn update_member_claims( + &self, + member_id: uuid::Uuid, + claims: MemberClaims, + ) -> Result<(), String> { + self.validate_claims(member_id, &claims).await?; + let slot = { + let members = self.members.read().await; + members.get(&member_id).cloned() + } + .ok_or_else(|| format!("shared tun member {} not found", member_id))?; + *slot.claims.write().await = claims; + self.apply_config() + .await + .map_err(|err| format!("failed to apply shared tun config: {err}"))?; + Ok(()) + } + + fn spawn_peer_to_tun_task( + &self, + tasks: &mut JoinSet<()>, + runtime_ctx: MemberRuntimeContext, + member_cancel: CancellationToken, + ) { + let writer_tx = self.writer_tx.clone(); + let device_cancel = self.cancel.clone(); + tasks.spawn(async move { + let mut packet_recv = runtime_ctx.peer_packet_receiver.lock().await; + loop { + tokio::select! { + _ = device_cancel.cancelled() => break, + _ = member_cancel.cancelled() => break, + packet = recv_packet_from_chan(&mut packet_recv) => { + let Ok(packet) = packet else { + break; + }; + if writer_tx.send(packet).await.is_err() { + break; + } + } + } + } + }); + } + + fn spawn_member_refresh_task( + &self, + tasks: &mut JoinSet<()>, + runtime_ctx: MemberRuntimeContext, + global_ctx: ArcGlobalCtx, + member_cancel: CancellationToken, + ) { + let device_cancel = self.cancel.clone(); + tasks.spawn(async move { + let mut event_receiver = global_ctx.subscribe(); + let mut cur_proxy_cidrs = runtime_ctx.slot.claims.read().await.proxy_v4_routes.clone(); + + loop { + tokio::select! 
{ + _ = device_cancel.cancelled() => break, + _ = member_cancel.cancelled() => break, + event = event_receiver.recv() => { + let Some(event) = handle_member_event(&mut event_receiver, event) else { + break; + }; + + if !should_refresh_member_claims(&event) { + continue; + } + + let Some(peer_manager) = runtime_ctx.slot.peer_manager.upgrade() else { + break; + }; + let (new_proxy_cidrs, _, _) = ProxyCidrsMonitor::diff_proxy_cidrs( + peer_manager.as_ref(), + &runtime_ctx.slot.global_ctx, + &cur_proxy_cidrs, + ) + .await; + cur_proxy_cidrs = new_proxy_cidrs.clone(); + + let claims = build_member_claims(&runtime_ctx.slot.global_ctx, new_proxy_cidrs); + if let Err(err) = runtime_ctx + .device + .update_member_claims(runtime_ctx.slot.instance_id, claims) + .await + { + tracing::warn!(instance_id = %runtime_ctx.slot.instance_id, %err, "shared tun member update failed"); + runtime_ctx + .slot + .global_ctx + .issue_event(GlobalCtxEvent::TunDeviceFallback(err.clone())); + runtime_ctx.slot.close_notifier.notify_one(); + break; + } + } + } + } + }); + } + + async fn validate_claims( + &self, + member_id: uuid::Uuid, + claims: &MemberClaims, + ) -> Result<(), String> { + let others = { + let members = self.members.read().await; + members + .iter() + .filter(|(id, _)| **id != member_id) + .map(|(_, slot)| slot.clone()) + .collect::>() + }; + + for other in others { + let other_claims = other.claims.read().await.clone(); + if let (Some(left), Some(right)) = (claims.ipv4, other_claims.ipv4) + && left.address() == right.address() + { + return Err(format!( + "shared tun conflict: duplicated IPv4 address {}", + left.address() + )); + } + if let (Some(left), Some(right)) = (claims.ipv6, other_claims.ipv6) + && left.address() == right.address() + { + return Err(format!( + "shared tun conflict: duplicated IPv6 address {}", + left.address() + )); + } + + for prefix in claims.shared_route_v4_prefixes() { + if other_claims.dispatch_v4_prefixes().contains(&prefix) { + return Err(format!( + 
"shared tun conflict: duplicated IPv4 route prefix {}", + prefix + )); + } + } + if let Some(prefix) = claims.local_v4_prefix() + && other_claims.shared_route_v4_prefixes().contains(&prefix) + { + return Err(format!( + "shared tun conflict: duplicated IPv4 route prefix {}", + prefix + )); + } + + for prefix in claims.shared_route_v6_prefixes() { + if other_claims.dispatch_v6_prefixes().contains(&prefix) { + return Err(format!( + "shared tun conflict: duplicated IPv6 route prefix {}", + prefix + )); + } + } + if let Some(prefix) = claims.local_v6_prefix() + && other_claims.shared_route_v6_prefixes().contains(&prefix) + { + return Err(format!( + "shared tun conflict: duplicated IPv6 route prefix {}", + prefix + )); + } + } + + Ok(()) + } + + async fn dispatch_packet(&self, packet: ZCPacket) { + let owner = self.select_owner(&packet).await; + let Some(slot) = owner else { + tracing::trace!("shared tun dropped packet without owner"); + return; + }; + let Some(peer_manager) = slot.peer_manager.upgrade() else { + tracing::trace!(instance_id = %slot.instance_id, "shared tun owner peer manager dropped"); + return; + }; + + NicCtx::forward_nic_packet_to_peers(packet, peer_manager.as_ref()).await; + } + + async fn select_owner(&self, packet: &ZCPacket) -> Option> { + let members = { + let members = self.members.read().await; + members.values().cloned().collect::>() + }; + + let payload = packet.payload(); + if payload.is_empty() { + return None; + } + + match payload[0] >> 4 { + 4 => self.select_owner_ipv4(payload, members).await, + 6 => self.select_owner_ipv6(payload, members).await, + _ => None, + } + } + + async fn select_owner_ipv4( + &self, + payload: &[u8], + members: Vec>, + ) -> Option> { + let ipv4 = pnet::packet::ipv4::Ipv4Packet::new(payload)?; + for slot in &members { + let claims = slot.claims.read().await; + if claims + .ipv4 + .map(|inet| inet.address() == ipv4.get_source()) + .unwrap_or(false) + { + return Some(slot.clone()); + } + } + + let dst = 
ipv4.get_destination(); + let mut best_owned: Option<(u8, Arc)> = None; + for slot in &members { + let claims = slot.claims.read().await; + for prefix in claims.dispatch_v4_prefixes() { + if prefix.contains(&dst) { + let prefix_len = prefix.network_length(); + if best_owned + .as_ref() + .map(|(best_len, _)| prefix_len > *best_len) + .unwrap_or(true) + { + best_owned = Some((prefix_len, slot.clone())); + } + } + } + } + + if let Some((_, slot)) = best_owned { + return Some(slot); + } + + let mut best_reachable: Option<(u8, Arc)> = None; + for slot in members { + let claims = slot.claims.read().await; + for prefix in claims.reachable_v4_prefixes() { + if prefix.contains(&dst) { + let prefix_len = prefix.network_length(); + if best_reachable + .as_ref() + .map(|(best_len, _)| prefix_len > *best_len) + .unwrap_or(true) + { + best_reachable = Some((prefix_len, slot.clone())); + } + } + } + } + + best_reachable.map(|(_, slot)| slot) + } + + async fn select_owner_ipv6( + &self, + payload: &[u8], + members: Vec>, + ) -> Option> { + let ipv6 = pnet::packet::ipv6::Ipv6Packet::new(payload)?; + for slot in &members { + let claims = slot.claims.read().await; + if claims + .ipv6 + .map(|inet| inet.address() == ipv6.get_source()) + .unwrap_or(false) + { + return Some(slot.clone()); + } + } + + let dst = ipv6.get_destination(); + let mut best: Option<(u8, Arc)> = None; + for slot in members { + let claims = slot.claims.read().await; + for prefix in claims.dispatch_v6_prefixes() { + if prefix.contains(&dst) { + let prefix_len = prefix.network_length(); + if best + .as_ref() + .map(|(best_len, _)| prefix_len > *best_len) + .unwrap_or(true) + { + best = Some((prefix_len, slot.clone())); + } + } + } + } + + best.map(|(_, slot)| slot) + } + + async fn apply_config(&self) -> Result<(), Error> { + let slots = { + let members = self.members.read().await; + members.values().cloned().collect::>() + }; + + let mut desired = AppliedConfig::default(); + for slot in slots { + let claims = 
slot.claims.read().await.clone(); + if let Some(ipv4) = claims.ipv4 { + desired + .ipv4_addrs + .insert((ipv4.address(), ipv4.network_length())); + #[cfg(any( + all(target_os = "macos", not(feature = "macos-ne")), + target_os = "freebsd" + ))] + desired.ipv4_routes.insert(ipv4.network()); + } + if let Some(ipv6) = claims.ipv6 { + desired + .ipv6_addrs + .insert((ipv6.address(), ipv6.network_length())); + #[cfg(any( + all(target_os = "macos", not(feature = "macos-ne")), + target_os = "freebsd" + ))] + desired.ipv6_routes.insert(ipv6.network()); + } + desired + .ipv4_routes + .extend(claims.proxy_v4_routes.iter().copied()); + desired.mtu = Some(match desired.mtu { + Some(cur) => cur.min(claims.effective_mtu), + None => claims.effective_mtu, + }); + } + + let mut current = self.current_config.lock().await; + if *current == desired { + return Ok(()); + } + + let nic = self.nic.lock().await; + nic.link_up().await?; + if current.mtu != desired.mtu + && let Some(mtu) = desired.mtu + { + nic.set_mtu(mtu).await?; + } + + if current.ipv4_addrs != desired.ipv4_addrs { + nic.remove_ip(None).await?; + for (addr, prefix) in desired.ipv4_addrs.iter().copied() { + nic.add_ip(addr, prefix as i32).await?; + } + } + if current.ipv6_addrs != desired.ipv6_addrs { + nic.remove_ipv6(None).await?; + for (addr, prefix) in desired.ipv6_addrs.iter().copied() { + nic.add_ipv6(addr, prefix as i32).await?; + } + } + + for prefix in current.ipv4_routes.difference(&desired.ipv4_routes) { + nic.remove_route(prefix.first_address(), prefix.network_length()) + .await?; + } + for prefix in desired.ipv4_routes.difference(¤t.ipv4_routes) { + nic.add_route(prefix.first_address(), prefix.network_length()) + .await?; + } + + for prefix in current.ipv6_routes.difference(&desired.ipv6_routes) { + nic.remove_ipv6_route(prefix.first_address(), prefix.network_length()) + .await?; + } + for prefix in desired.ipv6_routes.difference(¤t.ipv6_routes) { + nic.add_ipv6_route(prefix.first_address(), 
prefix.network_length()) + .await?; + } + + *current = desired; + Ok(()) + } + + async fn detach_member(&self, member_id: uuid::Uuid) { + self.members.write().await.remove(&member_id); + if let Err(err) = self.apply_config().await { + tracing::warn!(instance_id = %member_id, ?err, "failed to reconcile shared tun after detach"); + } + } + + async fn member_count(&self) -> usize { + self.members.read().await.len() + } + + async fn fail_all_members(&self, reason: &str) { + self.cancel.cancel(); + let members = { + let members = self.members.read().await; + members.values().cloned().collect::>() + }; + for slot in members { + slot.global_ctx + .issue_event(GlobalCtxEvent::TunDeviceError(reason.to_owned())); + slot.close_notifier.notify_one(); + } + } + + async fn shutdown(&self) { + self.cancel.cancel(); + let mut tasks = self.tasks.lock().await; + tasks.abort_all(); + while tasks.join_next().await.is_some() {} + } +} + +impl SharedTunMemberHandle { + pub async fn shutdown(mut self) { + if self.shutdown { + return; + } + self.shutdown = true; + self.cancel.cancel(); + self.tasks.abort_all(); + while self.tasks.join_next().await.is_some() {} + self.device.detach_member(self.member_id).await; + SHARED_TUN_REGISTRY + .remove_if_unused(&self.device.key, &self.device) + .await; + } +} + +impl Drop for SharedTunMemberHandle { + fn drop(&mut self) { + if self.shutdown { + return; + } + self.cancel.cancel(); + self.tasks.abort_all(); + let device = self.device.clone(); + let member_id = self.member_id; + let key = self.device.key.clone(); + tokio::spawn(async move { + device.detach_member(member_id).await; + SHARED_TUN_REGISTRY.remove_if_unused(&key, &device).await; + }); + } +} + +fn handle_member_event( + event_receiver: &mut tokio::sync::broadcast::Receiver, + event: Result, +) -> Option { + match event { + Ok(event) => Some(event), + Err(tokio::sync::broadcast::error::RecvError::Closed) => None, + Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { + 
*event_receiver = event_receiver.resubscribe(); + Some(GlobalCtxEvent::ProxyCidrsUpdated(Vec::new(), Vec::new())) + } + } +} + +fn should_refresh_member_claims(event: &GlobalCtxEvent) -> bool { + matches!( + event, + GlobalCtxEvent::ProxyCidrsUpdated(_, _) + | GlobalCtxEvent::ConfigPatched(_) + | GlobalCtxEvent::DhcpIpv4Changed(_, _) + ) +} + +fn build_member_claims( + global_ctx: &ArcGlobalCtx, + proxy_v4_routes: BTreeSet, +) -> MemberClaims { + let flags = global_ctx.get_flags(); + let effective_mtu = flags + .mtu + .saturating_sub(if flags.enable_encryption { 20 } else { 0 }); + let enable_exit_node = global_ctx.enable_exit_node(); + MemberClaims { + ipv4: global_ctx.get_ipv4(), + ipv6: global_ctx.get_ipv6(), + owned_proxy_v4_routes: collect_owned_proxy_v4_routes(global_ctx), + proxy_v4_routes, + effective_mtu: effective_mtu as u16, + default_route_v4: enable_exit_node, + default_route_v6: enable_exit_node, + } +} + +fn collect_owned_proxy_v4_routes(global_ctx: &ArcGlobalCtx) -> BTreeSet { + let mut routes = global_ctx + .config + .get_proxy_cidrs() + .into_iter() + .map(|cfg| cfg.mapped_cidr.unwrap_or(cfg.cidr)) + .collect::>(); + + if let Some(cidr) = global_ctx.get_vpn_portal_cidr() { + routes.insert(cidr); + } + + routes +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use super::MemberClaims; + + #[test] + fn nested_prefixes_are_distinct() { + let mut left = MemberClaims::default(); + left.owned_proxy_v4_routes + .insert(cidr::Ipv4Cidr::from_str("10.0.0.0/16").unwrap()); + + let mut right = MemberClaims::default(); + right + .owned_proxy_v4_routes + .insert(cidr::Ipv4Cidr::from_str("10.0.1.0/24").unwrap()); + + assert_ne!(left.dispatch_v4_prefixes(), right.dispatch_v4_prefixes()); + } + + #[test] + fn exit_node_adds_default_prefix() { + let claims = MemberClaims { + default_route_v4: true, + ..Default::default() + }; + + assert!( + claims + .dispatch_v4_prefixes() + .contains(&cidr::Ipv4Cidr::from_str("0.0.0.0/0").unwrap()) + ); + } + + #[test] 
+ fn identical_local_v4_subnets_do_not_conflict_in_dispatch() { + let left = MemberClaims { + ipv4: Some(cidr::Ipv4Inet::from_str("10.144.145.1/24").unwrap()), + ..Default::default() + }; + + let right = MemberClaims { + ipv4: Some(cidr::Ipv4Inet::from_str("10.144.145.2/24").unwrap()), + ..Default::default() + }; + + assert_eq!(left.local_v4_prefix(), right.local_v4_prefix()); + assert!(left.shared_route_v4_prefixes().is_empty()); + assert!(right.shared_route_v4_prefixes().is_empty()); + } + + #[test] + fn learned_proxy_routes_do_not_become_owned_claims() { + let mut claims = MemberClaims::default(); + claims + .proxy_v4_routes + .insert(cidr::Ipv4Cidr::from_str("10.1.2.0/24").unwrap()); + + assert!(claims.shared_route_v4_prefixes().is_empty()); + assert!( + claims + .reachable_v4_prefixes() + .contains(&cidr::Ipv4Cidr::from_str("10.1.2.0/24").unwrap()) + ); + } +} diff --git a/easytier/src/instance/virtual_nic.rs b/easytier/src/instance/virtual_nic.rs index 280147b4..19305fb3 100644 --- a/easytier/src/instance/virtual_nic.rs +++ b/easytier/src/instance/virtual_nic.rs @@ -742,6 +742,22 @@ impl VirtualNic { Ok(()) } + pub async fn remove_route(&self, address: Ipv4Addr, cidr: u8) -> Result<(), Error> { + let _g = self.global_ctx.net_ns.guard(); + self.ifcfg + .remove_ipv4_route(self.ifname(), address, cidr) + .await?; + Ok(()) + } + + pub async fn remove_ipv6_route(&self, address: Ipv6Addr, cidr: u8) -> Result<(), Error> { + let _g = self.global_ctx.net_ns.guard(); + self.ifcfg + .remove_ipv6_route(self.ifname(), address, cidr) + .await?; + Ok(()) + } + pub async fn remove_ip(&self, ip: Option) -> Result<(), Error> { let _g = self.global_ctx.net_ns.guard(); self.ifcfg.remove_ip(self.ifname(), ip).await?; @@ -770,6 +786,12 @@ impl VirtualNic { Ok(()) } + pub async fn set_mtu(&self, mtu: u16) -> Result<(), Error> { + let _g = self.global_ctx.net_ns.guard(); + self.ifcfg.set_mtu(self.ifname(), mtu as u32).await?; + Ok(()) + } + pub fn get_ifcfg(&self) -> impl 
IfConfiguerTrait + use<> { IfConfiger {} } @@ -943,6 +965,10 @@ impl NicCtx { } } + pub(crate) async fn forward_nic_packet_to_peers(ret: ZCPacket, mgr: &PeerManager) { + Self::do_forward_nic_to_peers(ret, mgr).await; + } + fn do_forward_nic_to_peers_task( &mut self, mut stream: Pin>, diff --git a/easytier/src/instance_manager.rs b/easytier/src/instance_manager.rs index 4ad710e8..b1efd880 100644 --- a/easytier/src/instance_manager.rs +++ b/easytier/src/instance_manager.rs @@ -363,6 +363,10 @@ fn handle_event( event!(error, %err, "[{}] tun device error", instance_id); } + GlobalCtxEvent::TunDeviceFallback(reason) => { + event!(warn, %reason, "[{}] tun device fallback", instance_id); + } + GlobalCtxEvent::Connecting(dst) => { event!(info, category: "CONNECTION", %dst, "[{}] connecting to peer", instance_id); } diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index 5832771f..058a38fd 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -258,6 +258,211 @@ pub async fn drop_insts(insts: Vec) { while set.join_next().await.is_some() {} } +async fn wait_for_tun_ready_event( + receiver: &mut tokio::sync::broadcast::Receiver, +) -> String { + tokio::time::timeout(Duration::from_secs(5), async { + loop { + if let crate::common::global_ctx::GlobalCtxEvent::TunDeviceReady(ifname) = receiver.recv().await.unwrap() { + return ifname; + } + } + }) + .await + .unwrap() +} + +async fn assert_no_tun_ready_event( + receiver: &mut tokio::sync::broadcast::Receiver, + timeout: Duration, +) { + tokio::time::timeout(timeout, async { + loop { + if let crate::common::global_ctx::GlobalCtxEvent::TunDeviceReady(ifname) = receiver.recv().await.unwrap() { + panic!("unexpected TunDeviceReady event: {ifname}"); + } + } + }) + .await + .ok(); +} + +async fn assert_no_tun_fallback_event( + receiver: &mut tokio::sync::broadcast::Receiver, + timeout: Duration, +) { + tokio::time::timeout(timeout, async { + loop { + if let 
crate::common::global_ctx::GlobalCtxEvent::TunDeviceFallback(reason) = receiver.recv().await.unwrap() { + panic!("unexpected TunDeviceFallback event: {reason}"); + } + } + }) + .await + .ok(); +} + +fn assert_tcp_proxy_metric_has_protocol( + inst: &Instance, + protocol: TcpProxyEntryTransportType, + min_value: u64, +) { + let metrics = inst + .get_global_ctx() + .stats_manager() + .get_metrics_by_prefix(&MetricName::TcpProxyConnect.to_string()); + + assert!( + metrics.iter().any(|metric| { + metric.value >= min_value + && metric.labels.labels().iter().any(|l| { + let t = LabelType::Protocol(protocol.as_str_name().to_string()); + t.key() == l.key && t.value() == l.value + }) + }), + "metrics: {:?}", + metrics + ); +} + +async fn shared_tun_subnet_proxy_transport_test( + transport: TcpProxyEntryTransportType, + source_shared: bool, +) { + prepare_linux_namespaces(); + + let center_cfg = get_inst_config("center", Some("net_b"), "10.144.144.100", "fd00::64/64"); + center_cfg.set_listeners(vec![]); + let mut center = Instance::new(center_cfg); + + let mut shared_events = Vec::new(); + let mut insts = Vec::new(); + let dst_idx; + + if source_shared { + let source_cfg = get_inst_config("src_shared", Some("net_a"), "10.144.144.1", "fd00::1/64"); + source_cfg.set_listeners(vec![]); + source_cfg.set_socks5_portal(None); + let mut source_flags = source_cfg.get_flags(); + source_flags.dev_name = "et_ssrc0".to_string(); + match transport { + TcpProxyEntryTransportType::Kcp => source_flags.enable_kcp_proxy = true, + TcpProxyEntryTransportType::Quic => source_flags.enable_quic_proxy = true, + _ => unreachable!(), + } + source_cfg.set_flags(source_flags.clone()); + let source = Instance::new(source_cfg); + shared_events.push(source.get_global_ctx().subscribe()); + + let source_peer = get_inst_config("src_peer", Some("net_a"), "10.144.144.2", "fd00::2/64"); + source_peer.set_listeners(vec![]); + source_peer.set_socks5_portal(None); + source_peer.set_flags(source_flags); + let 
source_peer = Instance::new(source_peer); + shared_events.push(source_peer.get_global_ctx().subscribe()); + + let dst_cfg = get_inst_config("dst", Some("net_c"), "10.144.144.3", "fd00::3/64"); + dst_cfg.set_listeners(vec![]); + dst_cfg.set_socks5_portal(None); + dst_cfg + .add_proxy_cidr("10.1.2.0/24".parse().unwrap(), None) + .unwrap(); + let dst = Instance::new(dst_cfg); + + insts.push(source); + insts.push(source_peer); + insts.push(dst); + dst_idx = 2; + } else { + let src_cfg = get_inst_config("src", Some("net_a"), "10.144.144.1", "fd00::1/64"); + src_cfg.set_listeners(vec![]); + src_cfg.set_socks5_portal(None); + let mut src_flags = src_cfg.get_flags(); + match transport { + TcpProxyEntryTransportType::Kcp => src_flags.enable_kcp_proxy = true, + TcpProxyEntryTransportType::Quic => src_flags.enable_quic_proxy = true, + _ => unreachable!(), + } + src_cfg.set_flags(src_flags); + let src = Instance::new(src_cfg); + + let dst_cfg = get_inst_config("dst_shared", Some("net_c"), "10.144.144.3", "fd00::3/64"); + dst_cfg.set_listeners(vec![]); + dst_cfg.set_socks5_portal(None); + dst_cfg + .add_proxy_cidr("10.1.2.0/24".parse().unwrap(), None) + .unwrap(); + let mut dst_flags = dst_cfg.get_flags(); + dst_flags.dev_name = "et_sdst0".to_string(); + dst_cfg.set_flags(dst_flags.clone()); + let dst = Instance::new(dst_cfg); + shared_events.push(dst.get_global_ctx().subscribe()); + + let dst_peer = get_inst_config("dst_peer", Some("net_c"), "10.144.144.4", "fd00::4/64"); + dst_peer.set_listeners(vec![]); + dst_peer.set_socks5_portal(None); + dst_peer.set_flags(dst_flags); + let dst_peer = Instance::new(dst_peer); + shared_events.push(dst_peer.get_global_ctx().subscribe()); + + insts.push(src); + insts.push(dst); + insts.push(dst_peer); + dst_idx = 1; + } + + center.run().await.unwrap(); + for inst in &mut insts { + inst.run().await.unwrap(); + } + + let ifname = wait_for_tun_ready_event(&mut shared_events[0]).await; + for receiver in shared_events.iter_mut().skip(1) { + 
assert_eq!(ifname, wait_for_tun_ready_event(receiver).await); + } + + insts[0] + .get_conn_manager() + .add_connector(RingTunnelConnector::new( + format!("ring://{}", center.id()).parse().unwrap(), + )); + insts[dst_idx] + .get_conn_manager() + .add_connector(RingTunnelConnector::new( + format!("ring://{}", center.id()).parse().unwrap(), + )); + + wait_for_condition( + || async { + insts[0].get_peer_manager().list_routes().await.len() >= 2 + && insts[dst_idx].get_peer_manager().list_routes().await.len() >= 2 + }, + Duration::from_secs(8), + ) + .await; + + wait_proxy_route_appear( + &insts[0].get_peer_manager(), + "10.144.144.3/24", + insts[dst_idx].peer_id(), + "10.1.2.0/24", + ) + .await; + + subnet_proxy_test_icmp("10.1.2.4", Duration::from_secs(8)).await; + subnet_proxy_test_tcp("10.1.2.4", "10.1.2.4", Duration::from_secs(8)).await; + subnet_proxy_test_udp("10.1.2.4", "10.1.2.4", Duration::from_secs(8)).await; + + assert_tcp_proxy_metric_has_protocol(&insts[0], transport, 1); + for receiver in &mut shared_events { + assert_no_tun_fallback_event(receiver, Duration::from_secs(2)).await; + } + + let mut all_insts = vec![center]; + all_insts.extend(insts); + drop_insts(all_insts).await; +} + async fn ping_test(from_netns: &str, target_ip: &str, payload_size: Option) -> bool { let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard(); let code = tokio::process::Command::new("ip") @@ -994,6 +1199,293 @@ pub async fn foreign_network_forward_nic_data() { drop_insts(vec![center_inst, inst1, inst2]).await; } +#[tokio::test] +#[serial_test::serial] +pub async fn shared_tun_same_namespace_real_tun() { + prepare_linux_namespaces(); + + let center_cfg = get_inst_config("center", Some("net_a"), "10.144.144.1", "fd00::1/64"); + center_cfg.set_listeners(vec![]); + let mut center = Instance::new(center_cfg); + + let shared_cfg_1 = get_inst_config("shared_1", Some("net_b"), "10.144.144.2", "fd00::2/64"); + shared_cfg_1.set_listeners(vec![]); + 
shared_cfg_1.set_socks5_portal(None); + let mut shared_flags = shared_cfg_1.get_flags(); + shared_flags.dev_name = "et_shared0".to_string(); + shared_cfg_1.set_flags(shared_flags.clone()); + let mut shared_1 = Instance::new(shared_cfg_1); + + let shared_cfg_2 = get_inst_config("shared_2", Some("net_b"), "10.144.144.3", "fd00::3/64"); + shared_cfg_2.set_listeners(vec![]); + shared_cfg_2.set_socks5_portal(None); + shared_cfg_2.set_flags(shared_flags); + let mut shared_2 = Instance::new(shared_cfg_2); + + let remote_cfg = get_inst_config("remote", Some("net_c"), "10.144.144.4", "fd00::4/64"); + remote_cfg.set_listeners(vec![]); + let mut remote = Instance::new(remote_cfg); + + let mut shared_1_events = shared_1.get_global_ctx().subscribe(); + let mut shared_2_events = shared_2.get_global_ctx().subscribe(); + + center.run().await.unwrap(); + shared_1.run().await.unwrap(); + shared_2.run().await.unwrap(); + remote.run().await.unwrap(); + + let shared_1_tun = wait_for_tun_ready_event(&mut shared_1_events).await; + let shared_2_tun = wait_for_tun_ready_event(&mut shared_2_events).await; + assert_eq!(shared_1_tun, shared_2_tun); + + shared_1 + .get_conn_manager() + .add_connector(RingTunnelConnector::new( + format!("ring://{}", center.id()).parse().unwrap(), + )); + shared_2 + .get_conn_manager() + .add_connector(RingTunnelConnector::new( + format!("ring://{}", center.id()).parse().unwrap(), + )); + remote + .get_conn_manager() + .add_connector(RingTunnelConnector::new( + format!("ring://{}", center.id()).parse().unwrap(), + )); + + wait_for_condition( + || async { + shared_1.get_peer_manager().list_routes().await.len() == 3 + && shared_2.get_peer_manager().list_routes().await.len() == 3 + && remote.get_peer_manager().list_routes().await.len() == 3 + }, + Duration::from_secs(8), + ) + .await; + + wait_for_condition( + || async { ping_test("net_c", "10.144.144.2", None).await }, + Duration::from_secs(8), + ) + .await; + wait_for_condition( + || async { ping_test("net_c", 
"10.144.144.3", None).await }, + Duration::from_secs(8), + ) + .await; + wait_for_condition( + || async { ping_test("net_b", "10.144.144.4", None).await }, + Duration::from_secs(8), + ) + .await; + + drop_insts(vec![center, shared_1, shared_2, remote]).await; +} + +#[tokio::test] +#[serial_test::serial] +pub async fn shared_tun_proxy_cidr_same_namespace_real_tun() { + prepare_linux_namespaces(); + + let center_cfg = get_inst_config("center", Some("net_a"), "10.144.144.1", "fd00::1/64"); + center_cfg.set_listeners(vec![]); + let mut center = Instance::new(center_cfg); + + let shared_cfg_1 = get_inst_config("shared_1", Some("net_c"), "10.144.144.2", "fd00::2/64"); + shared_cfg_1.set_listeners(vec![]); + shared_cfg_1.set_socks5_portal(None); + shared_cfg_1 + .add_proxy_cidr("10.1.2.0/24".parse().unwrap(), None) + .unwrap(); + let mut shared_flags = shared_cfg_1.get_flags(); + shared_flags.dev_name = "et_shp0".to_string(); + shared_cfg_1.set_flags(shared_flags.clone()); + let mut shared_1 = Instance::new(shared_cfg_1); + + let shared_cfg_2 = get_inst_config("shared_2", Some("net_c"), "10.144.144.3", "fd00::3/64"); + shared_cfg_2.set_listeners(vec![]); + shared_cfg_2.set_socks5_portal(None); + shared_cfg_2.set_flags(shared_flags); + let mut shared_2 = Instance::new(shared_cfg_2); + + let remote_cfg = get_inst_config("remote", Some("net_b"), "10.144.144.4", "fd00::4/64"); + remote_cfg.set_listeners(vec![]); + let mut remote = Instance::new(remote_cfg); + + let mut shared_1_events = shared_1.get_global_ctx().subscribe(); + let mut shared_2_events = shared_2.get_global_ctx().subscribe(); + + center.run().await.unwrap(); + shared_1.run().await.unwrap(); + shared_2.run().await.unwrap(); + remote.run().await.unwrap(); + + let shared_1_tun = wait_for_tun_ready_event(&mut shared_1_events).await; + let shared_2_tun = wait_for_tun_ready_event(&mut shared_2_events).await; + assert_eq!(shared_1_tun, shared_2_tun); + + shared_1 + .get_conn_manager() + 
.add_connector(RingTunnelConnector::new( + format!("ring://{}", center.id()).parse().unwrap(), + )); + shared_2 + .get_conn_manager() + .add_connector(RingTunnelConnector::new( + format!("ring://{}", center.id()).parse().unwrap(), + )); + remote + .get_conn_manager() + .add_connector(RingTunnelConnector::new( + format!("ring://{}", center.id()).parse().unwrap(), + )); + + wait_for_condition( + || async { + shared_1.get_peer_manager().list_routes().await.len() == 3 + && shared_2.get_peer_manager().list_routes().await.len() == 3 + && remote.get_peer_manager().list_routes().await.len() == 3 + }, + Duration::from_secs(8), + ) + .await; + + wait_proxy_route_appear( + &center.get_peer_manager(), + "10.144.144.2/24", + shared_1.peer_id(), + "10.1.2.0/24", + ) + .await; + wait_proxy_route_appear( + &remote.get_peer_manager(), + "10.144.144.2/24", + shared_1.peer_id(), + "10.1.2.0/24", + ) + .await; + + wait_for_condition( + || async { ping_test("net_a", "10.1.2.4", None).await }, + Duration::from_secs(8), + ) + .await; + wait_for_condition( + || async { ping_test("net_b", "10.1.2.4", None).await }, + Duration::from_secs(8), + ) + .await; + + assert_no_tun_fallback_event(&mut shared_1_events, Duration::from_secs(2)).await; + assert_no_tun_fallback_event(&mut shared_2_events, Duration::from_secs(2)).await; + + drop_insts(vec![center, shared_1, shared_2, remote]).await; +} + +#[tokio::test] +#[serial_test::serial] +pub async fn shared_tun_kcp_proxy_with_source_shared_tun() { + shared_tun_subnet_proxy_transport_test(TcpProxyEntryTransportType::Kcp, true).await; +} + +#[tokio::test] +#[serial_test::serial] +pub async fn shared_tun_quic_proxy_with_source_shared_tun() { + shared_tun_subnet_proxy_transport_test(TcpProxyEntryTransportType::Quic, true).await; +} + +#[tokio::test] +#[serial_test::serial] +pub async fn shared_tun_kcp_proxy_with_destination_shared_tun() { + shared_tun_subnet_proxy_transport_test(TcpProxyEntryTransportType::Kcp, false).await; +} + +#[tokio::test] 
+#[serial_test::serial] +pub async fn shared_tun_quic_proxy_with_destination_shared_tun() { + shared_tun_subnet_proxy_transport_test(TcpProxyEntryTransportType::Quic, false).await; +} + +#[tokio::test] +#[serial_test::serial] +pub async fn same_namespace_no_tun_skips_shared_tun_and_keeps_connectivity() { + prepare_linux_namespaces(); + + let center_cfg = get_inst_config("center", Some("net_a"), "10.144.144.1", "fd00::1/64"); + center_cfg.set_listeners(vec![]); + let mut center = Instance::new(center_cfg); + + let no_tun_cfg_1 = get_inst_config("no_tun_1", Some("net_b"), "10.144.144.2", "fd00::2/64"); + no_tun_cfg_1.set_listeners(vec![]); + no_tun_cfg_1.set_socks5_portal(None); + let mut no_tun_flags = no_tun_cfg_1.get_flags(); + no_tun_flags.dev_name = "et_shared_disabled0".to_string(); + no_tun_flags.no_tun = true; + no_tun_cfg_1.set_flags(no_tun_flags.clone()); + let mut no_tun_1 = Instance::new(no_tun_cfg_1); + + let no_tun_cfg_2 = get_inst_config("no_tun_2", Some("net_b"), "10.144.144.3", "fd00::3/64"); + no_tun_cfg_2.set_listeners(vec![]); + no_tun_cfg_2.set_socks5_portal(None); + no_tun_cfg_2.set_flags(no_tun_flags); + let mut no_tun_2 = Instance::new(no_tun_cfg_2); + + let remote_cfg = get_inst_config("remote", Some("net_c"), "10.144.144.4", "fd00::4/64"); + remote_cfg.set_listeners(vec![]); + let mut remote = Instance::new(remote_cfg); + + let mut no_tun_1_events = no_tun_1.get_global_ctx().subscribe(); + let mut no_tun_2_events = no_tun_2.get_global_ctx().subscribe(); + + center.run().await.unwrap(); + no_tun_1.run().await.unwrap(); + no_tun_2.run().await.unwrap(); + remote.run().await.unwrap(); + + assert_no_tun_ready_event(&mut no_tun_1_events, Duration::from_secs(2)).await; + assert_no_tun_ready_event(&mut no_tun_2_events, Duration::from_secs(2)).await; + + no_tun_1 + .get_conn_manager() + .add_connector(RingTunnelConnector::new( + format!("ring://{}", center.id()).parse().unwrap(), + )); + no_tun_2 + .get_conn_manager() + 
.add_connector(RingTunnelConnector::new( + format!("ring://{}", center.id()).parse().unwrap(), + )); + remote + .get_conn_manager() + .add_connector(RingTunnelConnector::new( + format!("ring://{}", center.id()).parse().unwrap(), + )); + + wait_for_condition( + || async { + no_tun_1.get_peer_manager().list_routes().await.len() == 3 + && no_tun_2.get_peer_manager().list_routes().await.len() == 3 + && remote.get_peer_manager().list_routes().await.len() == 3 + }, + Duration::from_secs(8), + ) + .await; + + wait_for_condition( + || async { ping_test("net_c", "10.144.144.2", None).await }, + Duration::from_secs(8), + ) + .await; + wait_for_condition( + || async { ping_test("net_c", "10.144.144.3", None).await }, + Duration::from_secs(8), + ) + .await; + + drop_insts(vec![center, no_tun_1, no_tun_2, remote]).await; +} + use std::{net::SocketAddr, str::FromStr}; use defguard_wireguard_rs::{