use std::{net::SocketAddr, sync::Arc}; use anyhow::Context; use crossbeam::atomic::AtomicCell; use rand::{seq::SliceRandom, SeedableRng}; use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; use tracing::Instrument; use crate::{ common::{ constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, stun::StunInfoCollectorTrait, PeerId, }, peers::peer_manager::PeerManager, rpc::NatType, tunnel::{ common::setup_sokcet2, udp::{new_hole_punch_packet, UdpTunnelConnector, UdpTunnelListener}, Tunnel, TunnelConnCounter, TunnelListener, }, }; use super::direct::PeerManagerForDirectConnector; #[tarpc::service] pub trait UdpHolePunchService { async fn try_punch_hole(local_mapped_addr: SocketAddr) -> Option; } #[derive(Debug)] struct UdpHolePunchListener { socket: Arc, tasks: JoinSet<()>, running: Arc>, mapped_addr: SocketAddr, conn_counter: Arc>, listen_time: std::time::Instant, last_select_time: AtomicCell, last_connected_time: Arc>, } impl UdpHolePunchListener { async fn get_avail_port() -> Result { let socket = UdpSocket::bind("0.0.0.0:0").await?; Ok(socket.local_addr()?.port()) } pub async fn new(peer_mgr: Arc) -> Result { let port = Self::get_avail_port().await?; let listen_url = format!("udp://0.0.0.0:{}", port); let gctx = peer_mgr.get_global_ctx(); let stun_info_collect = gctx.get_stun_info_collector(); let mapped_addr = stun_info_collect.get_udp_port_mapping(port).await?; let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap()); { let _g = peer_mgr.get_global_ctx().net_ns.guard(); listener.listen().await?; } let socket = listener.get_socket().unwrap(); let running = Arc::new(AtomicCell::new(true)); let running_clone = running.clone(); let last_connected_time = Arc::new(AtomicCell::new(std::time::Instant::now())); let last_connected_time_clone = last_connected_time.clone(); let conn_counter = listener.get_conn_counter(); let mut tasks = JoinSet::new(); tasks.spawn(async move { while let Ok(conn) = listener.accept().await { last_connected_time_clone.store(std::time::Instant::now()); tracing::warn!(?conn, "udp hole punching listener got peer connection"); let peer_mgr = peer_mgr.clone(); tokio::spawn(async move { if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await { tracing::error!( ?e, "failed to add tunnel as server in hole punch listener" ); } }); } running_clone.store(false); }); tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started"); Ok(Self { tasks, socket, running, mapped_addr, conn_counter, listen_time: std::time::Instant::now(), last_select_time: AtomicCell::new(std::time::Instant::now()), last_connected_time, }) } pub async fn get_socket(&self) -> Arc { self.last_select_time.store(std::time::Instant::now()); self.socket.clone() } } #[derive(Debug)] struct UdpHolePunchConnectorData { global_ctx: ArcGlobalCtx, peer_mgr: Arc, listeners: Arc>>, } #[derive(Clone)] struct UdpHolePunchRpcServer { data: Arc, tasks: Arc>>, } #[tarpc::server] impl UdpHolePunchService for UdpHolePunchRpcServer { async fn try_punch_hole( self, _: tarpc::context::Context, local_mapped_addr: SocketAddr, ) -> Option { let (socket, mapped_addr) = self.select_listener().await?; tracing::warn!(?local_mapped_addr, ?mapped_addr, "start hole punching"); let my_udp_nat_type = self .data .global_ctx .get_stun_info_collector() .get_stun_info() .udp_nat_type; // if we are restricted, we need to send hole punching resp to client if my_udp_nat_type == NatType::PortRestricted as i32 || my_udp_nat_type == NatType::Restricted as i32 { // send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second self.tasks.lock().unwrap().spawn(async move { for _ in 0..10 { tracing::info!(?local_mapped_addr, "sending hole punching packet"); let udp_packet = new_hole_punch_packet(); let _ = socket .send_to(&udp_packet.into_bytes(), local_mapped_addr) .await; tokio::time::sleep(std::time::Duration::from_millis(300)).await; } }); } Some(mapped_addr) } } impl UdpHolePunchRpcServer { pub fn new(data: Arc) -> Self { let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new())); join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned()); Self { data, tasks } } async fn select_listener(&self) -> Option<(Arc, SocketAddr)> { let all_listener_sockets = &self.data.listeners; // remove listener that not have connection in for 20 seconds all_listener_sockets.lock().await.retain(|listener| { listener.last_connected_time.load().elapsed().as_secs() < 20 && listener.conn_counter.get() > 0 }); let mut use_last = false; if all_listener_sockets.lock().await.len() < 4 { tracing::warn!("creating new udp hole punching listener"); all_listener_sockets.lock().await.push( UdpHolePunchListener::new(self.data.peer_mgr.clone()) .await .ok()?, ); use_last = true; } let locked = all_listener_sockets.lock().await; let listener = if use_last { locked.last()? } else { locked.choose(&mut rand::rngs::StdRng::from_entropy())? }; Some((listener.get_socket().await, listener.mapped_addr)) } } pub struct UdpHolePunchConnector { data: Arc, tasks: JoinSet<()>, } // Currently support: // Symmetric -> Full Cone // Any Type of Full Cone -> Any Type of Full Cone // if same level of full cone, node with smaller peer_id will be the initiator // if different level of full cone, node with more strict level will be the initiator impl UdpHolePunchConnector { pub fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc) -> Self { Self { data: Arc::new(UdpHolePunchConnectorData { global_ctx, peer_mgr, listeners: Arc::new(Mutex::new(Vec::new())), }), tasks: JoinSet::new(), } } pub async fn run_as_client(&mut self) -> Result<(), Error> { let data = self.data.clone(); self.tasks.spawn(async move { Self::main_loop(data).await; }); Ok(()) } pub async fn run_as_server(&mut self) -> Result<(), Error> { self.data.peer_mgr.get_peer_rpc_mgr().run_service( constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID, UdpHolePunchRpcServer::new(self.data.clone()).serve(), ); Ok(()) } pub async fn run(&mut self) -> Result<(), Error> { self.run_as_client().await?; self.run_as_server().await?; Ok(()) } async fn collect_peer_to_connect(data: Arc) -> Vec { let mut peers_to_connect = Vec::new(); // do not do anything if: // 1. our stun test has not finished // 2. our nat type is OpenInternet or NoPat, which means we can wait other peers to connect us let my_nat_type = data .global_ctx .get_stun_info_collector() .get_stun_info() .udp_nat_type; let my_nat_type = NatType::try_from(my_nat_type).unwrap(); if my_nat_type == NatType::Unknown || my_nat_type == NatType::OpenInternet || my_nat_type == NatType::NoPat { return peers_to_connect; } // collect peer list from peer manager and do some filter: // 1. peers without direct conns; // 2. peers is full cone (any restricted type); for route in data.peer_mgr.list_routes().await.iter() { let Some(peer_stun_info) = route.stun_info.as_ref() else { continue; }; let Ok(peer_nat_type) = NatType::try_from(peer_stun_info.udp_nat_type) else { continue; }; let peer_id: PeerId = route.peer_id; let conns = data.peer_mgr.list_peer_conns(peer_id).await; if conns.is_some() && conns.unwrap().len() > 0 { continue; } // if peer is symmetric ignore it because we cannot connect to it // if peer is open internet or no pat, direct connector will connecto to it if peer_nat_type == NatType::Unknown || peer_nat_type == NatType::OpenInternet || peer_nat_type == NatType::NoPat || peer_nat_type == NatType::Symmetric || peer_nat_type == NatType::SymUdpFirewall { continue; } // if we are symmetric, we can only connect to full cone // TODO: can also connect to restricted full cone, with some extra work if (my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall) && peer_nat_type != NatType::FullCone { continue; } // if we have smae level of full cone, node with smaller peer_id will be the initiator if my_nat_type == peer_nat_type { if data.peer_mgr.my_peer_id() > peer_id { continue; } } else { // if we have different level of full cone // we will be the initiator if we have more strict level if my_nat_type < peer_nat_type { continue; } } tracing::info!( ?peer_id, ?peer_nat_type, ?my_nat_type, ?data.global_ctx.id, "found peer to do hole punching" ); peers_to_connect.push(peer_id); } peers_to_connect } #[tracing::instrument] async fn do_hole_punching( data: Arc, dst_peer_id: PeerId, ) -> Result, anyhow::Error> { tracing::info!(?dst_peer_id, "start hole punching"); // client: choose a local udp port, and get the pubic mapped port from stun server let socket = { let _g = data.global_ctx.net_ns.guard(); UdpSocket::bind("0.0.0.0:0").await.with_context(|| "")? }; let local_socket_addr = socket.local_addr()?; let local_port = socket.local_addr()?.port(); drop(socket); // drop the socket to release the port let local_mapped_addr = data .global_ctx .get_stun_info_collector() .get_udp_port_mapping(local_port) .await .with_context(|| "failed to get udp port mapping")?; // client -> server: tell server the mapped port, server will return the mapped address of listening port. let Some(remote_mapped_addr) = data .peer_mgr .get_peer_rpc_mgr() .do_client_rpc_scoped( constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID, dst_peer_id, |c| async { let client = UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c).spawn(); let remote_mapped_addr = client .try_punch_hole(tarpc::context::current(), local_mapped_addr) .await; tracing::info!(?remote_mapped_addr, ?dst_peer_id, "got remote mapped addr"); remote_mapped_addr }, ) .await? else { return Err(anyhow::anyhow!("failed to get remote mapped addr")); }; // server: will send some punching resps, total 10 packets. // client: use the socket to create UdpTunnel with UdpTunnelConnector // NOTICE: UdpTunnelConnector will ignore the punching resp packet sent by remote. let connector = UdpTunnelConnector::new( format!( "udp://{}:{}", remote_mapped_addr.ip(), remote_mapped_addr.port() ) .to_string() .parse() .unwrap(), ); let _g = data.global_ctx.net_ns.guard(); let socket2_socket = socket2::Socket::new( socket2::Domain::for_address(local_socket_addr), socket2::Type::DGRAM, Some(socket2::Protocol::UDP), )?; setup_sokcet2(&socket2_socket, &local_socket_addr)?; let socket = UdpSocket::from_std(socket2_socket.into())?; Ok(connector .try_connect_with_socket(socket, remote_mapped_addr) .await .with_context(|| "UdpTunnelConnector failed to connect remote")?) } async fn main_loop(data: Arc) { loop { let peers_to_connect = Self::collect_peer_to_connect(data.clone()).await; tracing::trace!(?peers_to_connect, "peers to connect"); if peers_to_connect.len() == 0 { tokio::time::sleep(std::time::Duration::from_secs(5)).await; continue; } let mut tasks: JoinSet> = JoinSet::new(); for peer_id in peers_to_connect { let data = data.clone(); tasks.spawn( async move { let tunnel = Self::do_hole_punching(data.clone(), peer_id) .await .with_context(|| "failed to do hole punching")?; let _ = data.peer_mgr .add_client_tunnel(tunnel) .await .with_context(|| { "failed to add tunnel as client in hole punch connector" })?; Ok(()) } .instrument(tracing::info_span!("doing hole punching client", ?peer_id)), ); } while let Some(res) = tasks.join_next().await { if let Err(e) = res { tracing::error!(?e, "failed to join hole punching job"); continue; } match res.unwrap() { Err(e) => { tracing::error!(?e, "failed to do hole punching job"); } Ok(_) => { tracing::info!("hole punching job succeed"); } } } tokio::time::sleep(std::time::Duration::from_secs(10)).await; } } } #[cfg(test)] pub mod tests { use std::sync::Arc; use crate::rpc::{NatType, StunInfo}; use crate::{ common::{error::Error, stun::StunInfoCollectorTrait}, connector::udp_hole_punch::UdpHolePunchConnector, peers::{ peer_manager::PeerManager, tests::{ connect_peer_manager, create_mock_peer_manager, wait_route_appear, wait_route_appear_with_cost, }, }, }; struct MockStunInfoCollector { udp_nat_type: NatType, } #[async_trait::async_trait] impl StunInfoCollectorTrait for MockStunInfoCollector { fn get_stun_info(&self) -> StunInfo { StunInfo { udp_nat_type: self.udp_nat_type as i32, tcp_nat_type: NatType::Unknown as i32, last_update_time: std::time::Instant::now().elapsed().as_secs() as i64, } } async fn get_udp_port_mapping(&self, port: u16) -> Result { Ok(format!("127.0.0.1:{}", port).parse().unwrap()) } } pub fn replace_stun_info_collector(peer_mgr: Arc, udp_nat_type: NatType) { let collector = Box::new(MockStunInfoCollector { udp_nat_type }); peer_mgr .get_global_ctx() .replace_stun_info_collector(collector); } pub async fn create_mock_peer_manager_with_mock_stun( udp_nat_type: NatType, ) -> Arc { let p_a = create_mock_peer_manager().await; replace_stun_info_collector(p_a.clone(), udp_nat_type); p_a } #[tokio::test] async fn hole_punching() { let p_a = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; let p_b = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await; let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; connect_peer_manager(p_a.clone(), p_b.clone()).await; connect_peer_manager(p_b.clone(), p_c.clone()).await; wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); println!("{:?}", p_a.list_routes().await); let mut hole_punching_a = UdpHolePunchConnector::new(p_a.get_global_ctx(), p_a.clone()); let mut hole_punching_c = UdpHolePunchConnector::new(p_c.get_global_ctx(), p_c.clone()); hole_punching_a.run().await.unwrap(); hole_punching_c.run().await.unwrap(); wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1)) .await .unwrap(); println!("{:?}", p_a.list_routes().await); } }