diff --git a/easytier/src/common/global_ctx.rs b/easytier/src/common/global_ctx.rs index 5e9ed3c4..69c4f17b 100644 --- a/easytier/src/common/global_ctx.rs +++ b/easytier/src/common/global_ctx.rs @@ -1,5 +1,5 @@ use std::{ - collections::{HashMap, hash_map::DefaultHasher}, + collections::{BTreeSet, HashMap, hash_map::DefaultHasher}, hash::Hasher, net::{IpAddr, SocketAddr}, sync::{Arc, Mutex}, @@ -203,6 +203,7 @@ pub struct GlobalCtx { cached_ipv4: AtomicCell>, cached_ipv6: AtomicCell>, public_ipv6_lease: AtomicCell>, + public_ipv6_routes: Mutex>, cached_proxy_cidrs: AtomicCell>>, ip_collector: Mutex>>, @@ -300,6 +301,7 @@ impl GlobalCtx { cached_ipv4: AtomicCell::new(None), cached_ipv6: AtomicCell::new(None), public_ipv6_lease: AtomicCell::new(None), + public_ipv6_routes: Mutex::new(BTreeSet::new()), cached_proxy_cidrs: AtomicCell::new(None), ip_collector: Mutex::new(Some(Arc::new(IPCollector::new( @@ -395,6 +397,11 @@ impl GlobalCtx { self.public_ipv6_lease.store(addr); } + pub fn set_public_ipv6_routes(&self, routes: BTreeSet) { + *self.public_ipv6_routes.lock().unwrap() = + routes.into_iter().map(|route| route.address()).collect(); + } + pub fn is_ip_local_ipv6(&self, ip: &std::net::Ipv6Addr) -> bool { self.get_ipv6().map(|x| x.address() == *ip).unwrap_or(false) || self @@ -403,6 +410,10 @@ impl GlobalCtx { .unwrap_or(false) } + pub fn is_ip_easytier_managed_ipv6(&self, ip: &std::net::Ipv6Addr) -> bool { + self.is_ip_local_ipv6(ip) || self.public_ipv6_routes.lock().unwrap().contains(ip) + } + pub fn get_advertised_ipv6_public_addr_prefix(&self) -> Option { *self.advertised_ipv6_public_addr_prefix.lock().unwrap() } diff --git a/easytier/src/connector/direct.rs b/easytier/src/connector/direct.rs index d4abb637..e6296f11 100644 --- a/easytier/src/connector/direct.rs +++ b/easytier/src/connector/direct.rs @@ -64,6 +64,24 @@ async fn resolve_mapped_listener_addrs(listener: &url::Url) -> Result bool { + is_usable_public_ipv6_candidate_with_mode(ip, global_ctx, TESTING.load(Ordering::Relaxed)) +} + +fn is_usable_public_ipv6_candidate_with_mode( + ip: &Ipv6Addr, + global_ctx: &ArcGlobalCtx, + testing: bool, +) -> bool { + !global_ctx.is_ip_easytier_managed_ipv6(ip) + && (testing + || (!ip.is_loopback() + && !ip.is_unspecified() + && !ip.is_unique_local() + && !ip.is_unicast_link_local() + && !ip.is_multicast())) +} + #[async_trait::async_trait] pub trait PeerManagerForDirectConnector { async fn list_peers(&self) -> Vec; @@ -190,34 +208,28 @@ impl DirectConnectorManagerData { .with_context(|| format!("failed to bind local socket for {}", remote_url))?, ); let connector_ip = self - .peer_manager - .get_global_ctx() + .global_ctx .get_stun_info_collector() .get_stun_info() .public_ip .iter() - .find(|x| x.contains(':')) - .ok_or(anyhow::anyhow!( - "failed to get public ipv6 address from stun info" - ))? - .parse::() - .with_context(|| { - format!( - "failed to parse public ipv6 address from stun info: {:?}", - self.peer_manager - .get_global_ctx() - .get_stun_info_collector() - .get_stun_info() - ) - })?; - let connector_addr = - SocketAddr::new(IpAddr::V6(connector_ip), local_socket.local_addr()?.port()); + .filter_map(|ip| ip.parse::().ok()) + .find(|ip| !self.global_ctx.is_ip_easytier_managed_ipv6(ip)); // ask remote to send v6 hole punch packet // and no matter what the result is, continue to connect - let _ = self - .remote_send_udp_hole_punch_packet(dst_peer_id, connector_addr, remote_url) - .await; + if let Some(connector_ip) = connector_ip { + let connector_addr = + SocketAddr::new(IpAddr::V6(connector_ip), local_socket.local_addr()?.port()); + let _ = self + .remote_send_udp_hole_punch_packet(dst_peer_id, connector_addr, remote_url) + .await; + } else { + tracing::debug!( + ?remote_url, + "skip remote IPv6 hole-punch packet; no non-EasyTier public IPv6 in STUN info" + ); + } let udp_connector = UdpTunnelConnector::new(remote_url.clone()); let remote_addr = SocketAddr::from_url(remote_url.clone(), IpVersion::V6).await?; @@ -479,14 +491,7 @@ impl DirectConnectorManagerData { .iter() .chain(ip_list.public_ipv6.iter()) .filter_map(|x| Ipv6Addr::from_str(&x.to_string()).ok()) - .filter(|x| { - TESTING.load(Ordering::Relaxed) - || (!x.is_loopback() - && !x.is_unspecified() - && !x.is_unique_local() - && !x.is_unicast_link_local() - && !x.is_multicast()) - }) + .filter(|x| is_usable_public_ipv6_candidate(x, &self.global_ctx)) .collect::>() .iter() .for_each(|ip| { @@ -515,6 +520,11 @@ impl DirectConnectorManagerData { ); } }); + } else if self.global_ctx.is_ip_easytier_managed_ipv6(s_addr.ip()) { + tracing::debug!( + ?listener, + "skip EasyTier-managed IPv6 as direct-connect target" + ); } else if !s_addr.ip().is_loopback() || TESTING.load(Ordering::Relaxed) { if self .global_ctx @@ -790,9 +800,10 @@ impl DirectConnectorManager { #[cfg(test)] mod tests { - use std::sync::Arc; + use std::{collections::BTreeSet, sync::Arc}; use crate::{ + common::global_ctx::tests::get_mock_global_ctx, connector::direct::{ DirectConnectorManager, DirectConnectorManagerData, DstListenerUrlBlackListItem, }, @@ -809,6 +820,24 @@ mod tests { use super::{TESTING, mapped_listener_port, resolve_mapped_listener_addrs}; + #[tokio::test] + async fn public_ipv6_candidate_rejects_easytier_managed_addr_even_in_tests() { + let global_ctx = get_mock_global_ctx(); + let managed_ipv6: cidr::Ipv6Inet = "2001:db8::2/128".parse().unwrap(); + global_ctx.set_public_ipv6_routes(BTreeSet::from([managed_ipv6])); + + assert!(!super::is_usable_public_ipv6_candidate_with_mode( + &"2001:db8::2".parse().unwrap(), + &global_ctx, + true, + )); + assert!(super::is_usable_public_ipv6_candidate_with_mode( + &"::1".parse().unwrap(), + &global_ctx, + true, + )); + } + #[test] fn udp_ipv6_url_matches_hole_punch_branch_condition() { let remote_url: url::Url = "udp://[2001:db8::1]:11010".parse().unwrap(); diff --git a/easytier/src/connector/mod.rs b/easytier/src/connector/mod.rs index 832f0173..bfc6c7f2 100644 --- a/easytier/src/connector/mod.rs +++ b/easytier/src/connector/mod.rs @@ -1,19 +1,17 @@ -use std::{ - net::{SocketAddr, SocketAddrV4, SocketAddrV6}, - sync::Arc, -}; +use std::net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use crate::{ - common::{error::Error, global_ctx::ArcGlobalCtx, idn, network::IPCollector}, + common::{dns::socket_addrs, error::Error, global_ctx::ArcGlobalCtx, idn}, connector::dns_connector::DnsTunnelConnector, proto::common::PeerFeatureFlag, tunnel::{ - self, FromUrl, IpScheme, IpVersion, TunnelConnector, TunnelError, TunnelScheme, + self, IpScheme, IpVersion, TunnelConnector, TunnelError, TunnelScheme, ring::RingTunnelConnector, tcp::TcpTunnelConnector, udp::UdpTunnelConnector, }, utils::BoxExt, }; use http_connector::HttpTunnelConnector; +use rand::seq::SliceRandom; pub mod direct; pub mod manual; @@ -56,7 +54,7 @@ pub(crate) fn should_background_p2p_with_peer( async fn set_bind_addr_for_peer_connector( connector: &mut (impl TunnelConnector + ?Sized), is_ipv4: bool, - ip_collector: &Arc, + global_ctx: &ArcGlobalCtx, ) { if cfg!(any( target_os = "android", @@ -69,7 +67,7 @@ async fn set_bind_addr_for_peer_connector( return; } - let ips = ip_collector.collect_ip_addrs().await; + let ips = global_ctx.get_ip_collector().collect_ip_addrs().await; if is_ipv4 { let mut bind_addrs = vec![]; for ipv4 in ips.interface_ipv4s { @@ -80,7 +78,11 @@ async fn set_bind_addr_for_peer_connector( } else { let mut bind_addrs = vec![]; for ipv6 in ips.interface_ipv6s.iter().chain(ips.public_ipv6.iter()) { - let socket_addr = SocketAddrV6::new(std::net::Ipv6Addr::from(*ipv6), 0, 0, 0).into(); + let ipv6 = std::net::Ipv6Addr::from(*ipv6); + if global_ctx.is_ip_easytier_managed_ipv6(&ipv6) { + continue; + } + let socket_addr = SocketAddrV6::new(ipv6, 0, 0, 0).into(); bind_addrs.push(socket_addr); } connector.set_bind_addrs(bind_addrs); @@ -88,6 +90,144 @@ async fn set_bind_addr_for_peer_connector( let _ = connector; } +struct ResolvedConnectorAddr { + addr: SocketAddr, + ip_version: IpVersion, +} + +fn connector_default_port(url: &url::Url) -> Option { + url.try_into() + .ok() + .and_then(|s: TunnelScheme| s.try_into().ok()) + .map(IpScheme::default_port) +} + +fn addr_matches_ip_version(addr: &SocketAddr, ip_version: IpVersion) -> bool { + match ip_version { + IpVersion::V4 => addr.is_ipv4(), + IpVersion::V6 => addr.is_ipv6(), + IpVersion::Both => true, + } +} + +fn infer_effective_ip_version(addrs: &[SocketAddr], requested_ip_version: IpVersion) -> IpVersion { + match requested_ip_version { + IpVersion::Both if addrs.iter().all(SocketAddr::is_ipv4) => IpVersion::V4, + IpVersion::Both if addrs.iter().all(SocketAddr::is_ipv6) => IpVersion::V6, + _ => requested_ip_version, + } +} + +async fn easytier_managed_ipv6_source_for_dst( + global_ctx: &ArcGlobalCtx, + dst_addr: SocketAddrV6, +) -> Result, Error> { + let socket = { + let _g = global_ctx.net_ns.guard(); + tokio::net::UdpSocket::bind("[::]:0").await? + }; + socket.connect(SocketAddr::V6(dst_addr)).await?; + + let IpAddr::V6(local_ip) = socket.local_addr()?.ip() else { + return Ok(None); + }; + + Ok(global_ctx + .is_ip_easytier_managed_ipv6(&local_ip) + .then_some(local_ip)) +} + +async fn ipv6_connector_reject_reason( + url: &url::Url, + global_ctx: &ArcGlobalCtx, + v6_addr: SocketAddrV6, + skip_source_validation_errors: bool, +) -> Result, Error> { + if global_ctx.is_ip_easytier_managed_ipv6(v6_addr.ip()) { + return Ok(Some(format!( + "{} resolves to EasyTier-managed IPv6 {}", + url, + v6_addr.ip() + ))); + } + + match easytier_managed_ipv6_source_for_dst(global_ctx, v6_addr).await { + Ok(Some(local_ip)) => Ok(Some(format!( + "{} would use EasyTier-managed IPv6 {} as local source for {}", + url, local_ip, v6_addr + ))), + Ok(None) => Ok(None), + Err(err) if skip_source_validation_errors => Ok(Some(format!( + "{} IPv6 candidate {} could not be validated: {}", + url, v6_addr, err + ))), + Err(err) => Err(err), + } +} + +async fn resolve_connector_socket_addr( + url: &url::Url, + global_ctx: &ArcGlobalCtx, + ip_version: IpVersion, +) -> Result { + let addrs = socket_addrs(url, || connector_default_port(url)) + .await + .map_err(|e| { + TunnelError::InvalidAddr(format!( + "failed to resolve socket addr, url: {}, error: {}", + url, e + )) + })?; + + let mut usable_addrs = Vec::new(); + let mut rejected_ipv6_reason = None; + let skip_source_validation_errors = ip_version == IpVersion::Both; + for addr in addrs + .into_iter() + .filter(|addr| addr_matches_ip_version(addr, ip_version)) + { + if let SocketAddr::V6(v6_addr) = addr + && let Some(reason) = ipv6_connector_reject_reason( + url, + global_ctx, + v6_addr, + skip_source_validation_errors, + ) + .await? + { + rejected_ipv6_reason = Some(reason); + continue; + } + + usable_addrs.push(addr); + } + + if usable_addrs.is_empty() { + if let Some(reason) = rejected_ipv6_reason { + return Err(Error::InvalidUrl(format!( + "{}, refusing overlay-backed underlay connection", + reason + ))); + } + + return Err(Error::TunnelError(TunnelError::NoDnsRecordFound( + ip_version, + ))); + } + + let effective_ip_version = infer_effective_ip_version(&usable_addrs, ip_version); + + let addr = usable_addrs + .choose(&mut rand::thread_rng()) + .copied() + .ok_or_else(|| Error::TunnelError(TunnelError::NoDnsRecordFound(ip_version)))?; + + Ok(ResolvedConnectorAddr { + addr, + ip_version: effective_ip_version, + }) +} + pub async fn create_connector_by_url( url: &str, global_ctx: &ArcGlobalCtx, @@ -98,9 +238,11 @@ pub async fn create_connector_by_url( let scheme = (&url) .try_into() .map_err(|_| TunnelError::InvalidProtocol(url.scheme().to_owned()))?; + let mut effective_connector_ip_version = ip_version; let mut connector: Box = match scheme { TunnelScheme::Ip(scheme) => { - let dst_addr = SocketAddr::from_url(url.clone(), ip_version).await?; + let resolved_addr = resolve_connector_socket_addr(&url, global_ctx, ip_version).await?; + effective_connector_ip_version = resolved_addr.ip_version; let mut connector: Box = match scheme { IpScheme::Tcp => TcpTunnelConnector::new(url).boxed(), IpScheme::Udp => UdpTunnelConnector::new(url).boxed(), @@ -125,11 +267,12 @@ pub async fn create_connector_by_url( #[cfg(feature = "faketcp")] IpScheme::FakeTcp => tunnel::fake_tcp::FakeTcpTunnelConnector::new(url).boxed(), }; + connector.set_resolved_addr(resolved_addr.addr); if global_ctx.config.get_flags().bind_device { set_bind_addr_for_peer_connector( &mut connector, - dst_addr.is_ipv4(), - &global_ctx.get_ip_collector(), + resolved_addr.addr.is_ipv4(), + global_ctx, ) .await; } @@ -151,16 +294,38 @@ pub async fn create_connector_by_url( DnsTunnelConnector::new(url, global_ctx.clone()).boxed() } }; - connector.set_ip_version(ip_version); + connector.set_ip_version(effective_connector_ip_version); Ok(connector) } #[cfg(test)] mod tests { - use crate::proto::common::PeerFeatureFlag; + use std::collections::BTreeSet; - use super::{should_background_p2p_with_peer, should_try_p2p_with_peer}; + use crate::{ + common::global_ctx::tests::get_mock_global_ctx, proto::common::PeerFeatureFlag, + tunnel::IpVersion, + }; + + use super::{ + create_connector_by_url, should_background_p2p_with_peer, should_try_p2p_with_peer, + }; + + #[tokio::test] + async fn connector_rejects_easytier_managed_ipv6_destination() { + let global_ctx = get_mock_global_ctx(); + let public_route: cidr::Ipv6Inet = "2001:db8::2/128".parse().unwrap(); + global_ctx.set_public_ipv6_routes(BTreeSet::from([public_route])); + + let ret = + create_connector_by_url("tcp://[2001:db8::2]:11010", &global_ctx, IpVersion::V6).await; + + assert!(matches!( + ret, + Err(crate::common::error::Error::InvalidUrl(_)) + )); + } #[test] fn lazy_background_p2p_requires_need_p2p() { diff --git a/easytier/src/connector/udp_hole_punch/common.rs b/easytier/src/connector/udp_hole_punch/common.rs index 2b0d40be..e82b0663 100644 --- a/easytier/src/connector/udp_hole_punch/common.rs +++ b/easytier/src/connector/udp_hole_punch/common.rs @@ -719,25 +719,31 @@ async fn check_udp_socket_local_addr( ) -> Result<(), Error> { let socket = UdpSocket::bind("0.0.0.0:0").await?; socket.connect(remote_mapped_addr).await?; - if let Ok(local_addr) = socket.local_addr() { - // local_addr should not be equal to an EasyTier-managed virtual/public address. - match local_addr.ip() { - IpAddr::V4(ip) => { - if global_ctx.get_ipv4().map(|ip| ip.address()) == Some(ip) { - return Err(anyhow::anyhow!("local address is virtual ipv4").into()); - } - } - IpAddr::V6(ip) => { - if global_ctx.is_ip_local_ipv6(&ip) { - return Err(anyhow::anyhow!("local address is easytier-managed ipv6").into()); - } - } - } + if let Ok(local_addr) = socket.local_addr() + && let Some(err) = easytier_managed_local_addr_error(&global_ctx, local_addr) + { + return Err(anyhow::anyhow!(err).into()); } Ok(()) } +fn easytier_managed_local_addr_error( + global_ctx: &ArcGlobalCtx, + local_addr: SocketAddr, +) -> Option<&'static str> { + // local_addr should not be equal to an EasyTier-managed virtual/public address. + match local_addr.ip() { + IpAddr::V4(ip) if global_ctx.get_ipv4().map(|ip| ip.address()) == Some(ip) => { + Some("local address is virtual ipv4") + } + IpAddr::V6(ip) if global_ctx.is_ip_easytier_managed_ipv6(&ip) => { + Some("local address is easytier-managed ipv6") + } + _ => None, + } +} + pub(crate) async fn try_connect_with_socket( global_ctx: ArcGlobalCtx, socket: Arc, @@ -763,11 +769,29 @@ pub(crate) async fn try_connect_with_socket( #[cfg(test)] mod tests { + use std::{collections::BTreeSet, net::SocketAddr}; + + use crate::common::global_ctx::tests::get_mock_global_ctx; + use super::{ - MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS, should_create_public_listener, - should_retry_public_listener_selection, + MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS, easytier_managed_local_addr_error, + should_create_public_listener, should_retry_public_listener_selection, }; + #[tokio::test] + async fn local_addr_check_rejects_easytier_public_ipv6_route() { + let global_ctx = get_mock_global_ctx(); + let public_route: cidr::Ipv6Inet = "2001:db8::4/128".parse().unwrap(); + global_ctx.set_public_ipv6_routes(BTreeSet::from([public_route])); + + let local_addr: SocketAddr = "[2001:db8::4]:1234".parse().unwrap(); + + assert_eq!( + easytier_managed_local_addr_error(&global_ctx, local_addr), + Some("local address is easytier-managed ipv6") + ); + } + #[test] fn listener_selection_prefers_reuse_before_cap() { assert!(!should_create_public_listener(1, true, true, false, false)); diff --git a/easytier/src/peers/acl_filter.rs b/easytier/src/peers/acl_filter.rs index adb95f68..58446206 100644 --- a/easytier/src/peers/acl_filter.rs +++ b/easytier/src/peers/acl_filter.rs @@ -94,6 +94,8 @@ impl AclFilter { /// Preserves connection tracking and rate limiting state across reloads /// Now lock-free and doesn't require &mut self! pub fn reload_rules(&self, acl_config: Option<&Acl>) { + self.outbound_allow_records.clear(); + let Some(acl_config) = acl_config else { self.acl_enabled.store(false, Ordering::Relaxed); return; @@ -400,14 +402,15 @@ mod tests { use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, sync::Arc, + time::Instant, }; use crate::{ common::acl_processor::PacketInfo, - proto::acl::{ChainType, Protocol}, + proto::acl::{Acl, ChainType, Protocol}, }; - use super::AclFilter; + use super::{AclFilter, OutboundAllowRecord}; fn packet_info(dst_ip: IpAddr) -> PacketInfo { PacketInfo { @@ -445,4 +448,40 @@ mod tests { assert_eq!(chain, ChainType::Forward); } + + #[tokio::test] + async fn reload_rules_clears_outbound_allow_records() { + let filter = AclFilter::new(); + filter.outbound_allow_records.insert( + OutboundAllowRecord { + src_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), + dst_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)), + src_port: Some(1234), + dst_port: Some(80), + protocol: Protocol::Tcp, + }, + Instant::now(), + ); + assert_eq!(filter.outbound_allow_records.len(), 1); + + filter.reload_rules(Some(&Acl::default())); + + assert_eq!(filter.outbound_allow_records.len(), 0); + + filter.outbound_allow_records.insert( + OutboundAllowRecord { + src_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)), + dst_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), + src_port: Some(4321), + dst_port: Some(443), + protocol: Protocol::Tcp, + }, + Instant::now(), + ); + assert_eq!(filter.outbound_allow_records.len(), 1); + + filter.reload_rules(None); + + assert_eq!(filter.outbound_allow_records.len(), 0); + } } diff --git a/easytier/src/peers/peer_rpc_service.rs b/easytier/src/peers/peer_rpc_service.rs index bc639210..08e82d2d 100644 --- a/easytier/src/peers/peer_rpc_service.rs +++ b/easytier/src/peers/peer_rpc_service.rs @@ -12,6 +12,22 @@ use crate::{ tunnel::udp, }; +fn remove_easytier_managed_ipv6s(ret: &mut GetIpListResponse, global_ctx: &ArcGlobalCtx) { + ret.interface_ipv6s.retain(|ip| { + let ip = std::net::Ipv6Addr::from(*ip); + !global_ctx.is_ip_easytier_managed_ipv6(&ip) + }); + + if ret + .public_ipv6 + .as_ref() + .map(|ip| std::net::Ipv6Addr::from(*ip)) + .is_some_and(|ip| global_ctx.is_ip_easytier_managed_ipv6(&ip)) + { + ret.public_ipv6 = None; + } +} + #[derive(Clone)] pub struct DirectConnectorManagerRpcServer { // TODO: this only cache for one src peer, should make it global @@ -36,15 +52,7 @@ impl DirectConnectorRpc for DirectConnectorManagerRpcServer { .chain(self.global_ctx.get_running_listeners()) .map(Into::into) .collect(); - // remove et ipv6 from the interface ipv6 list - if let Some(et_ipv6) = self.global_ctx.get_ipv6() { - let et_ipv6: crate::proto::common::Ipv6Addr = et_ipv6.address().into(); - ret.interface_ipv6s.retain(|x| *x != et_ipv6); - } - if let Some(public_ipv6) = self.global_ctx.get_public_ipv6_lease() { - let public_ipv6: crate::proto::common::Ipv6Addr = public_ipv6.address().into(); - ret.interface_ipv6s.retain(|x| *x != public_ipv6); - } + remove_easytier_managed_ipv6s(&mut ret, &self.global_ctx); tracing::trace!( "get_ip_list: public_ipv4: {:?}, public_ipv6: {:?}, listeners: {:?}", ret.public_ipv4, @@ -88,3 +96,41 @@ impl DirectConnectorManagerRpcServer { Self { global_ctx } } } + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use crate::{ + common::global_ctx::tests::get_mock_global_ctx, + peers::peer_rpc_service::remove_easytier_managed_ipv6s, proto::peer_rpc::GetIpListResponse, + }; + + #[tokio::test] + async fn get_ip_list_sanitizer_removes_managed_ipv6_from_all_sources() { + let global_ctx = get_mock_global_ctx(); + let virtual_ipv6 = "fd00::1/64".parse().unwrap(); + let public_ipv6 = "2001:db8::2/128".parse().unwrap(); + let physical_ipv6: std::net::Ipv6Addr = "2001:db8::3".parse().unwrap(); + let routed_ipv6: cidr::Ipv6Inet = "2001:db8::4/128".parse().unwrap(); + global_ctx.set_ipv6(Some(virtual_ipv6)); + global_ctx.set_public_ipv6_lease(Some(public_ipv6)); + global_ctx.set_public_ipv6_routes(BTreeSet::from([routed_ipv6])); + + let mut ip_list = GetIpListResponse { + public_ipv6: Some(public_ipv6.address().into()), + interface_ipv6s: vec![ + virtual_ipv6.address().into(), + public_ipv6.address().into(), + routed_ipv6.address().into(), + physical_ipv6.into(), + ], + ..Default::default() + }; + + remove_easytier_managed_ipv6s(&mut ip_list, &global_ctx); + + assert_eq!(ip_list.public_ipv6, None); + assert_eq!(ip_list.interface_ipv6s, vec![physical_ipv6.into()]); + } +} diff --git a/easytier/src/peers/public_ipv6.rs b/easytier/src/peers/public_ipv6.rs index 0773b5e0..90432b4e 100644 --- a/easytier/src/peers/public_ipv6.rs +++ b/easytier/src/peers/public_ipv6.rs @@ -243,6 +243,8 @@ impl PublicIpv6Service { .copied() .collect::>(); *cached_routes = routes; + self.global_ctx + .set_public_ipv6_routes(cached_routes.clone()); self.global_ctx .issue_event(GlobalCtxEvent::PublicIpv6RoutesUpdated(added, removed)); } diff --git a/easytier/src/tunnel/fake_tcp/mod.rs b/easytier/src/tunnel/fake_tcp/mod.rs index 895f6633..8acfd5a0 100644 --- a/easytier/src/tunnel/fake_tcp/mod.rs +++ b/easytier/src/tunnel/fake_tcp/mod.rs @@ -281,6 +281,7 @@ impl TunnelListener for FakeTcpTunnelListener { pub struct FakeTcpTunnelConnector { addr: url::Url, ip_to_if_name: IpToIfNameCache, + resolved_addr: Option, } impl FakeTcpTunnelConnector { @@ -288,6 +289,7 @@ impl FakeTcpTunnelConnector { FakeTcpTunnelConnector { addr, ip_to_if_name: IpToIfNameCache::new(), + resolved_addr: None, } } } @@ -314,7 +316,10 @@ fn get_local_ip_for_destination(destination: IpAddr) -> Option { #[async_trait::async_trait] impl TunnelConnector for FakeTcpTunnelConnector { async fn connect(&mut self) -> Result, TunnelError> { - let remote_addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?; + let remote_addr = match self.resolved_addr { + Some(addr) => addr, + None => SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?, + }; let local_ip = get_local_ip_for_destination(remote_addr.ip()) .ok_or(TunnelError::InternalError("Failed to get local ip".into()))?; @@ -390,6 +395,10 @@ impl TunnelConnector for FakeTcpTunnelConnector { fn remote_url(&self) -> url::Url { self.addr.clone() } + + fn set_resolved_addr(&mut self, addr: SocketAddr) { + self.resolved_addr = Some(addr); + } } type RecvFut = Pin> + Send + Sync>>; diff --git a/easytier/src/tunnel/mod.rs b/easytier/src/tunnel/mod.rs index a090b074..06629b6c 100644 --- a/easytier/src/tunnel/mod.rs +++ b/easytier/src/tunnel/mod.rs @@ -141,6 +141,7 @@ pub trait TunnelConnector: Send { fn remote_url(&self) -> url::Url; fn set_bind_addrs(&mut self, _addrs: Vec) {} fn set_ip_version(&mut self, _ip_version: IpVersion) {} + fn set_resolved_addr(&mut self, _addr: SocketAddr) {} } pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url { diff --git a/easytier/src/tunnel/quic.rs b/easytier/src/tunnel/quic.rs index 106d05a2..2a30396b 100644 --- a/easytier/src/tunnel/quic.rs +++ b/easytier/src/tunnel/quic.rs @@ -432,6 +432,7 @@ pub struct QuicTunnelConnector { addr: url::Url, global_ctx: ArcGlobalCtx, ip_version: IpVersion, + resolved_addr: Option, } impl QuicTunnelConnector { @@ -440,6 +441,7 @@ impl QuicTunnelConnector { addr, global_ctx, ip_version: IpVersion::Both, + resolved_addr: None, } } } @@ -447,7 +449,10 @@ impl QuicTunnelConnector { #[async_trait::async_trait] impl TunnelConnector for QuicTunnelConnector { async fn connect(&mut self) -> Result, TunnelError> { - let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?; + let addr = match self.resolved_addr { + Some(addr) => addr, + None => SocketAddr::from_url(self.addr.clone(), self.ip_version).await?, + }; let (endpoint, connection) = QuicEndpointManager::connect(&self.global_ctx, addr).await?; let local_addr = endpoint.local_addr()?; @@ -484,6 +489,10 @@ impl TunnelConnector for QuicTunnelConnector { fn set_ip_version(&mut self, ip_version: IpVersion) { self.ip_version = ip_version; } + + fn set_resolved_addr(&mut self, addr: SocketAddr) { + self.resolved_addr = Some(addr); + } } #[cfg(test)] diff --git a/easytier/src/tunnel/tcp.rs b/easytier/src/tunnel/tcp.rs index fd501558..ec7de9be 100644 --- a/easytier/src/tunnel/tcp.rs +++ b/easytier/src/tunnel/tcp.rs @@ -129,6 +129,7 @@ pub struct TcpTunnelConnector { bind_addrs: Vec, ip_version: IpVersion, + resolved_addr: Option, } impl TcpTunnelConnector { @@ -137,6 +138,7 @@ impl TcpTunnelConnector { addr, bind_addrs: vec![], ip_version: IpVersion::Both, + resolved_addr: None, } } @@ -175,7 +177,10 @@ impl TcpTunnelConnector { #[async_trait] impl super::TunnelConnector for TcpTunnelConnector { async fn connect(&mut self) -> Result, TunnelError> { - let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?; + let addr = match self.resolved_addr { + Some(addr) => addr, + None => SocketAddr::from_url(self.addr.clone(), self.ip_version).await?, + }; if self.bind_addrs.is_empty() { self.connect_with_default_bind(addr).await } else { @@ -194,6 +199,10 @@ impl super::TunnelConnector for TcpTunnelConnector { fn set_ip_version(&mut self, ip_version: IpVersion) { self.ip_version = ip_version; } + + fn set_resolved_addr(&mut self, addr: SocketAddr) { + self.resolved_addr = Some(addr); + } } #[cfg(test)] @@ -294,6 +303,31 @@ mod tests { ); } + #[tokio::test] + async fn connector_uses_pre_resolved_addr_without_resolving_url() { + let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:0".parse().unwrap()); + listener.listen().await.unwrap(); + + let port = listener.local_url().port().unwrap(); + let source_url: url::Url = format!("tcp://unresolvable.invalid:{port}") + .parse() + .unwrap(); + let resolved_addr: SocketAddr = format!("127.0.0.1:{port}").parse().unwrap(); + let mut connector = TcpTunnelConnector::new(source_url.clone()); + connector.set_resolved_addr(resolved_addr); + + let accept_task = tokio::spawn(async move { listener.accept().await.unwrap() }); + let tunnel = connector.connect().await.unwrap(); + let _accepted_tunnel = accept_task.await.unwrap(); + + let info = tunnel.info().unwrap(); + assert_eq!(info.remote_addr.unwrap().url, source_url.to_string()); + + let resolved_remote_addr: url::Url = info.resolved_remote_addr.unwrap().into(); + assert_eq!(resolved_remote_addr.host_str(), Some("127.0.0.1")); + assert_eq!(resolved_remote_addr.port(), Some(port)); + } + #[tokio::test] async fn test_alloc_port() { // v4 diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index c8a4c0ed..714dd10f 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -682,6 +682,7 @@ pub struct UdpTunnelConnector { addr: url::Url, bind_addrs: Vec, ip_version: IpVersion, + resolved_addr: Option, } impl UdpTunnelConnector { @@ -690,6 +691,7 @@ impl UdpTunnelConnector { addr, bind_addrs: vec![], ip_version: IpVersion::Both, + resolved_addr: None, } } @@ -906,7 +908,10 @@ impl UdpTunnelConnector { #[async_trait] impl super::TunnelConnector for UdpTunnelConnector { async fn connect(&mut self) -> Result, TunnelError> { - let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?; + let addr = match self.resolved_addr { + Some(addr) => addr, + None => SocketAddr::from_url(self.addr.clone(), self.ip_version).await?, + }; if self.bind_addrs.is_empty() || addr.is_ipv6() { self.connect_with_default_bind(addr).await } else { @@ -925,6 +930,10 @@ impl super::TunnelConnector for UdpTunnelConnector { fn set_ip_version(&mut self, ip_version: IpVersion) { self.ip_version = ip_version; } + + fn set_resolved_addr(&mut self, addr: SocketAddr) { + self.resolved_addr = Some(addr); + } } #[cfg(test)] diff --git a/easytier/src/tunnel/websocket.rs b/easytier/src/tunnel/websocket.rs index a9f58bd4..e58736e0 100644 --- a/easytier/src/tunnel/websocket.rs +++ b/easytier/src/tunnel/websocket.rs @@ -198,6 +198,7 @@ impl TunnelListener for WsTunnelListener { pub struct WsTunnelConnector { addr: url::Url, ip_version: IpVersion, + resolved_addr: Option, bind_addrs: Vec, } @@ -207,6 +208,7 @@ impl WsTunnelConnector { WsTunnelConnector { addr, ip_version: IpVersion::Both, + resolved_addr: None, bind_addrs: vec![], } @@ -214,11 +216,10 @@ impl WsTunnelConnector { async fn connect_with( addr: url::Url, - ip_version: IpVersion, + socket_addr: SocketAddr, tcp_socket: TcpSocket, ) -> Result, TunnelError> { let is_wss = is_wss(&addr)?; - let socket_addr = SocketAddr::from_url(addr.clone(), ip_version).await?; let stream = tcp_socket.connect(socket_addr).await?; if let Err(error) = stream.set_nodelay(true) { tracing::warn!(?error, "set_nodelay fail in ws connect"); @@ -273,7 +274,7 @@ impl WsTunnelConnector { } else { TcpSocket::new_v6()? }; - Self::connect_with(self.addr.clone(), self.ip_version, socket).await + Self::connect_with(self.addr.clone(), addr, socket).await } async fn connect_with_custom_bind( @@ -285,11 +286,7 @@ impl WsTunnelConnector { for bind_addr in self.bind_addrs.iter() { tracing::info!(?bind_addr, ?addr, "bind addr"); match bind().addr(*bind_addr).only_v6(true).call() { - Ok(socket) => futures.push(Self::connect_with( - self.addr.clone(), - self.ip_version, - socket, - )), + Ok(socket) => futures.push(Self::connect_with(self.addr.clone(), addr, socket)), Err(error) => { tracing::error!(?bind_addr, ?addr, ?error, "bind addr fail"); continue; @@ -304,7 +301,10 @@ impl WsTunnelConnector { #[async_trait::async_trait] impl TunnelConnector for WsTunnelConnector { async fn connect(&mut self) -> Result, TunnelError> { - let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?; + let addr = match self.resolved_addr { + Some(addr) => addr, + None => SocketAddr::from_url(self.addr.clone(), self.ip_version).await?, + }; if self.bind_addrs.is_empty() || addr.is_ipv6() { self.connect_with_default_bind(addr).await } else { @@ -323,6 +323,10 @@ impl TunnelConnector for WsTunnelConnector { fn set_bind_addrs(&mut self, addrs: Vec) { self.bind_addrs = addrs; } + + fn set_resolved_addr(&mut self, addr: SocketAddr) { + self.resolved_addr = Some(addr); + } } #[cfg(test)] diff --git a/easytier/src/tunnel/wireguard.rs b/easytier/src/tunnel/wireguard.rs index 66fba55a..c37d8d1f 100644 --- a/easytier/src/tunnel/wireguard.rs +++ b/easytier/src/tunnel/wireguard.rs @@ -598,6 +598,7 @@ pub struct WgTunnelConnector { bind_addrs: Vec, ip_version: IpVersion, + resolved_addr: Option, } impl Debug for WgTunnelConnector { @@ -617,6 +618,7 @@ impl WgTunnelConnector { udp: None, bind_addrs: vec![], ip_version: IpVersion::Both, + resolved_addr: None, } } @@ -702,7 +704,10 @@ impl WgTunnelConnector { impl super::TunnelConnector for WgTunnelConnector { #[tracing::instrument] async fn connect(&mut self) -> Result, TunnelError> { - let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?; + let addr = match self.resolved_addr { + Some(addr) => addr, + None => SocketAddr::from_url(self.addr.clone(), self.ip_version).await?, + }; if addr.is_ipv6() { return self.connect_with_ipv6(addr).await; @@ -744,6 +749,10 @@ impl super::TunnelConnector for WgTunnelConnector { fn set_ip_version(&mut self, ip_version: IpVersion) { self.ip_version = ip_version; } + + fn set_resolved_addr(&mut self, addr: SocketAddr) { + self.resolved_addr = Some(addr); + } } #[cfg(test)]