use std::{ net::{Ipv4Addr, SocketAddr, SocketAddrV4}, sync::Arc, time::Duration, }; use crossbeam::atomic::AtomicCell; use dashmap::{DashMap, DashSet}; use rand::seq::SliceRandom as _; use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; use tracing::{instrument, Instrument, Level}; use zerocopy::FromBytes as _; use crate::{ common::{ error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS, stun::StunInfoCollectorTrait as _, PeerId, }, defer, peers::peer_manager::PeerManager, proto::common::NatType, tunnel::{ packet_def::{UDPTunnelHeader, UdpPacketType, UDP_TUNNEL_HEADER_SIZE}, udp::{new_hole_punch_packet, UdpTunnelConnector, UdpTunnelListener}, Tunnel, TunnelConnCounter, TunnelListener as _, }, }; pub(crate) const HOLE_PUNCH_PACKET_BODY_LEN: u16 = 32; fn generate_shuffled_port_vec() -> Vec { let mut rng = rand::thread_rng(); let mut port_vec: Vec = (1..=65535).collect(); port_vec.shuffle(&mut rng); port_vec } pub(crate) enum UdpPunchClientMethod { None, ConeToCone, SymToCone, EasySymToEasySym, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub(crate) enum UdpNatType { Unknown, Open(NatType), Cone(NatType), // bool means if it is incremental EasySymmetric(NatType, bool), HardSymmetric(NatType), } impl From for UdpNatType { fn from(nat_type: NatType) -> Self { match nat_type { NatType::Unknown => UdpNatType::Unknown, NatType::NoPat | NatType::OpenInternet => UdpNatType::Open(nat_type), NatType::FullCone | NatType::Restricted | NatType::PortRestricted => { UdpNatType::Cone(nat_type) } NatType::Symmetric | NatType::SymUdpFirewall => UdpNatType::HardSymmetric(nat_type), NatType::SymmetricEasyInc => UdpNatType::EasySymmetric(nat_type, true), NatType::SymmetricEasyDec => UdpNatType::EasySymmetric(nat_type, false), } } } impl Into for UdpNatType { fn into(self) -> NatType { match self { UdpNatType::Unknown => NatType::Unknown, UdpNatType::Open(nat_type) => nat_type, UdpNatType::Cone(nat_type) => nat_type, UdpNatType::EasySymmetric(nat_type, _) => nat_type, UdpNatType::HardSymmetric(nat_type) => nat_type, } } } impl UdpNatType { pub(crate) fn is_open(&self) -> bool { matches!(self, UdpNatType::Open(_)) } pub(crate) fn is_unknown(&self) -> bool { matches!(self, UdpNatType::Unknown) } pub(crate) fn is_sym(&self) -> bool { self.is_hard_sym() || self.is_easy_sym() } pub(crate) fn is_hard_sym(&self) -> bool { matches!(self, UdpNatType::HardSymmetric(_)) } pub(crate) fn is_easy_sym(&self) -> bool { matches!(self, UdpNatType::EasySymmetric(_, _)) } pub(crate) fn is_cone(&self) -> bool { matches!(self, UdpNatType::Cone(_)) } pub(crate) fn get_inc_of_easy_sym(&self) -> Option { match self { UdpNatType::EasySymmetric(_, inc) => Some(*inc), _ => None, } } pub(crate) fn get_punch_hole_method(&self, other: Self) -> UdpPunchClientMethod { if other.is_unknown() { if self.is_sym() { return UdpPunchClientMethod::SymToCone; } else { return UdpPunchClientMethod::ConeToCone; } } if self.is_unknown() { if other.is_sym() { return UdpPunchClientMethod::None; } else { return UdpPunchClientMethod::ConeToCone; } } if self.is_open() || other.is_open() { // open nat does not need to punch hole return UdpPunchClientMethod::None; } if self.is_cone() { if other.is_sym() { return UdpPunchClientMethod::None; } else { return UdpPunchClientMethod::ConeToCone; } } else if self.is_easy_sym() { if other.is_hard_sym() { return UdpPunchClientMethod::None; } else if other.is_easy_sym() { return UdpPunchClientMethod::EasySymToEasySym; } else { return UdpPunchClientMethod::SymToCone; } } else if self.is_hard_sym() { if other.is_sym() { return UdpPunchClientMethod::None; } else { return UdpPunchClientMethod::SymToCone; } } unreachable!("invalid nat type"); } pub(crate) fn can_punch_hole_as_client( &self, other: Self, my_peer_id: PeerId, dst_peer_id: PeerId, ) -> bool { match self.get_punch_hole_method(other) { UdpPunchClientMethod::None => false, UdpPunchClientMethod::ConeToCone | UdpPunchClientMethod::SymToCone => true, UdpPunchClientMethod::EasySymToEasySym => my_peer_id < dst_peer_id, } } } #[derive(Debug)] pub(crate) struct PunchedUdpSocket { pub(crate) socket: Arc, pub(crate) tid: u32, pub(crate) remote_addr: SocketAddr, } // used for symmetric hole punching, binding to multiple ports to increase the chance of success pub(crate) struct UdpSocketArray { sockets: Arc>>, max_socket_count: usize, net_ns: NetNS, tasks: Arc>>, intreast_tids: Arc>, tid_to_socket: Arc>>, } impl UdpSocketArray { pub fn new(max_socket_count: usize, net_ns: NetNS) -> Self { let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new())); join_joinset_background(tasks.clone(), "UdpSocketArray".to_owned()); Self { sockets: Arc::new(DashMap::new()), max_socket_count, net_ns, tasks, intreast_tids: Arc::new(DashSet::new()), tid_to_socket: Arc::new(DashMap::new()), } } pub fn started(&self) -> bool { !self.sockets.is_empty() } pub async fn add_new_socket(&self, socket: Arc) -> Result<(), anyhow::Error> { let socket_map = self.sockets.clone(); let local_addr = socket.local_addr()?; let intreast_tids = self.intreast_tids.clone(); let tid_to_socket = self.tid_to_socket.clone(); socket_map.insert(local_addr, socket.clone()); self.tasks.lock().unwrap().spawn( async move { defer!(socket_map.remove(&local_addr);); let mut buf = [0u8; UDP_TUNNEL_HEADER_SIZE + HOLE_PUNCH_PACKET_BODY_LEN as usize]; tracing::trace!(?local_addr, "udp socket added"); loop { let Ok((len, addr)) = socket.recv_from(&mut buf).await else { break; }; tracing::debug!(?len, ?addr, "got raw packet"); if len != UDP_TUNNEL_HEADER_SIZE + HOLE_PUNCH_PACKET_BODY_LEN as usize { continue; } let Some(p) = UDPTunnelHeader::ref_from_prefix(&buf) else { continue; }; let tid = p.conn_id.get(); let valid = p.msg_type == UdpPacketType::HolePunch as u8 && p.len.get() == HOLE_PUNCH_PACKET_BODY_LEN; tracing::debug!(?p, ?addr, ?tid, ?valid, ?p, "got udp hole punch packet"); if !valid { continue; } if intreast_tids.contains(&tid) { tracing::info!(?addr, ?tid, "got hole punching packet with intreast tid"); tid_to_socket .entry(tid) .or_insert_with(Vec::new) .push(PunchedUdpSocket { socket: socket.clone(), tid, remote_addr: addr, }); break; } } tracing::debug!(?local_addr, "udp socket recv loop end"); } .instrument(tracing::info_span!("udp array socket recv loop")), ); Ok(()) } #[instrument(err)] pub async fn start(&self) -> Result<(), anyhow::Error> { tracing::info!("starting udp socket array"); while self.sockets.len() < self.max_socket_count { let socket = { let _g = self.net_ns.guard(); Arc::new(UdpSocket::bind("0.0.0.0:0").await?) }; self.add_new_socket(socket).await?; } Ok(()) } #[instrument(err)] pub async fn send_with_all(&self, data: &[u8], addr: SocketAddr) -> Result<(), anyhow::Error> { tracing::info!(?addr, "sending hole punching packet"); let sockets = self .sockets .iter() .map(|s| s.value().clone()) .collect::>(); for socket in sockets.iter() { for _ in 0..3 { socket.send_to(data, addr).await?; } } Ok(()) } #[instrument(ret(level = Level::DEBUG))] pub fn try_fetch_punched_socket(&self, tid: u32) -> Option { tracing::debug!(?tid, "try fetch punched socket"); self.tid_to_socket.get_mut(&tid)?.value_mut().pop() } pub fn add_intreast_tid(&self, tid: u32) { self.intreast_tids.insert(tid); } pub fn remove_intreast_tid(&self, tid: u32) { self.intreast_tids.remove(&tid); self.tid_to_socket.remove(&tid); } } impl std::fmt::Debug for UdpSocketArray { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("UdpSocketArray") .field("sockets", &self.sockets.len()) .field("max_socket_count", &self.max_socket_count) .field("started", &self.started()) .field("intreast_tids", &self.intreast_tids.len()) .field("tid_to_socket", &self.tid_to_socket.len()) .finish() } } #[derive(Debug)] pub(crate) struct UdpHolePunchListener { socket: Arc, tasks: JoinSet<()>, running: Arc>, mapped_addr: SocketAddr, conn_counter: Arc>, listen_time: std::time::Instant, last_select_time: AtomicCell, last_active_time: Arc>, } impl UdpHolePunchListener { async fn get_avail_port() -> Result { let socket = UdpSocket::bind("0.0.0.0:0").await?; Ok(socket.local_addr()?.port()) } #[instrument(err)] pub async fn new(peer_mgr: Arc) -> Result { Self::new_ext(peer_mgr, true, None).await } #[instrument(err)] pub async fn new_ext( peer_mgr: Arc, with_mapped_addr: bool, port: Option, ) -> Result { let port = port.unwrap_or(Self::get_avail_port().await?); let listen_url = format!("udp://0.0.0.0:{}", port); let mapped_addr = if with_mapped_addr { let gctx = peer_mgr.get_global_ctx(); let stun_info_collect = gctx.get_stun_info_collector(); stun_info_collect.get_udp_port_mapping(port).await? } else { SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), port)) }; let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap()); { let _g = peer_mgr.get_global_ctx().net_ns.guard(); listener.listen().await?; } let socket = listener.get_socket().unwrap(); let running = Arc::new(AtomicCell::new(true)); let running_clone = running.clone(); let conn_counter = listener.get_conn_counter(); let mut tasks = JoinSet::new(); tasks.spawn(async move { while let Ok(conn) = listener.accept().await { tracing::warn!(?conn, "udp hole punching listener got peer connection"); let peer_mgr = peer_mgr.clone(); tokio::spawn(async move { if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await { tracing::error!( ?e, "failed to add tunnel as server in hole punch listener" ); } }); } running_clone.store(false); }); let last_active_time = Arc::new(AtomicCell::new(std::time::Instant::now())); let conn_counter_clone = conn_counter.clone(); let last_active_time_clone = last_active_time.clone(); tasks.spawn(async move { loop { tokio::time::sleep(std::time::Duration::from_secs(5)).await; if conn_counter_clone.get().unwrap_or(0) != 0 { last_active_time_clone.store(std::time::Instant::now()); } } }); tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started"); Ok(Self { tasks, socket, running, mapped_addr, conn_counter, listen_time: std::time::Instant::now(), last_select_time: AtomicCell::new(std::time::Instant::now()), last_active_time, }) } pub async fn get_socket(&self) -> Arc { self.last_select_time.store(std::time::Instant::now()); self.socket.clone() } pub async fn get_conn_count(&self) -> usize { self.conn_counter.get().unwrap_or(0) as usize } } pub(crate) struct PunchHoleServerCommon { peer_mgr: Arc, listeners: Arc>>, tasks: Arc>>, } impl PunchHoleServerCommon { pub(crate) fn new(peer_mgr: Arc) -> Self { let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new())); join_joinset_background(tasks.clone(), "PunchHoleServerCommon".to_owned()); let listeners = Arc::new(Mutex::new(Vec::::new())); let l = listeners.clone(); tasks.lock().unwrap().spawn(async move { loop { tokio::time::sleep(Duration::from_secs(5)).await; { // remove listener that is not active for 40 seconds but keep listeners that are selected less than 30 seconds l.lock().await.retain(|listener| { listener.last_active_time.load().elapsed().as_secs() < 40 || listener.last_select_time.load().elapsed().as_secs() < 30 }); } } }); Self { peer_mgr, listeners, tasks, } } pub(crate) async fn add_listener(&self, listener: UdpHolePunchListener) { self.listeners.lock().await.push(listener); } pub(crate) async fn find_listener(&self, addr: &SocketAddr) -> Option> { let all_listener_sockets = self.listeners.lock().await; let listener = all_listener_sockets .iter() .find(|listener| listener.mapped_addr == *addr && listener.running.load())?; Some(listener.get_socket().await) } pub(crate) async fn my_udp_nat_type(&self) -> i32 { self.peer_mgr .get_global_ctx() .get_stun_info_collector() .get_stun_info() .udp_nat_type } pub(crate) async fn select_listener( &self, use_new_listener: bool, ) -> Option<(Arc, SocketAddr)> { let all_listener_sockets = &self.listeners; let mut use_last = false; if all_listener_sockets.lock().await.len() < 16 || use_new_listener { tracing::warn!("creating new udp hole punching listener"); all_listener_sockets.lock().await.push( UdpHolePunchListener::new(self.peer_mgr.clone()) .await .ok()?, ); use_last = true; } let mut locked = all_listener_sockets.lock().await; let listener = if use_last { locked.last_mut()? } else { // use the listener that is active most recently locked .iter_mut() .max_by_key(|listener| listener.last_active_time.load())? }; if listener.mapped_addr.ip().is_unspecified() { tracing::info!("listener mapped addr is unspecified, trying to get mapped addr"); listener.mapped_addr = self .get_global_ctx() .get_stun_info_collector() .get_udp_port_mapping(listener.mapped_addr.port()) .await .ok()?; } Some((listener.get_socket().await, listener.mapped_addr)) } pub(crate) fn get_joinset(&self) -> Arc>> { self.tasks.clone() } pub(crate) fn get_global_ctx(&self) -> ArcGlobalCtx { self.peer_mgr.get_global_ctx() } pub(crate) fn get_peer_mgr(&self) -> Arc { self.peer_mgr.clone() } } #[tracing::instrument(err, ret(level=Level::DEBUG), skip(ports))] pub(crate) async fn send_symmetric_hole_punch_packet( ports: &Vec, udp: Arc, transaction_id: u32, public_ips: &Vec, port_start_idx: usize, max_packets: usize, ) -> Result { tracing::debug!("sending hard symmetric hole punching packet"); let mut sent_packets = 0; let mut cur_port_idx = port_start_idx; while sent_packets < max_packets { let port = ports[cur_port_idx % ports.len()]; for pub_ip in public_ips { let addr = SocketAddr::V4(SocketAddrV4::new(*pub_ip, port)); for _ in 0..3 { let packet = new_hole_punch_packet(transaction_id, HOLE_PUNCH_PACKET_BODY_LEN); udp.send_to(&packet.into_bytes(), addr).await?; } sent_packets += 1; } cur_port_idx = cur_port_idx.wrapping_add(1); tokio::time::sleep(Duration::from_millis(1)).await; } Ok(cur_port_idx % ports.len()) } pub(crate) async fn try_connect_with_socket( socket: Arc, remote_mapped_addr: SocketAddr, ) -> Result, Error> { let connector = UdpTunnelConnector::new( format!( "udp://{}:{}", remote_mapped_addr.ip(), remote_mapped_addr.port() ) .to_string() .parse() .unwrap(), ); connector .try_connect_with_socket(socket, remote_mapped_addr) .await .map_err(|e| Error::from(e)) }