use std::{ mem::MaybeUninit, net::{IpAddr, Ipv4Addr, SocketAddrV4}, sync::Arc, thread, }; use pnet::packet::{ icmp::{self, IcmpTypes}, ip::IpNextHeaderProtocols, ipv4::{self, Ipv4Packet, MutableIpv4Packet}, Packet, }; use socket2::Socket; use tokio::{ sync::{mpsc::UnboundedSender, Mutex}, task::JoinSet, }; use tracing::Instrument; use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, peers::{peer_manager::PeerManager, PeerPacketFilter}, tunnel::packet_def::{PacketType, ZCPacket}, }; use super::CidrSet; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] struct IcmpNatKey { dst_ip: std::net::IpAddr, icmp_id: u16, icmp_seq: u16, } #[derive(Debug)] struct IcmpNatEntry { src_peer_id: PeerId, my_peer_id: PeerId, src_ip: IpAddr, start_time: std::time::Instant, } impl IcmpNatEntry { fn new(src_peer_id: PeerId, my_peer_id: PeerId, src_ip: IpAddr) -> Result { Ok(Self { src_peer_id, my_peer_id, src_ip, start_time: std::time::Instant::now(), }) } } type IcmpNatTable = Arc>; type NewPacketSender = tokio::sync::mpsc::UnboundedSender; type NewPacketReceiver = tokio::sync::mpsc::UnboundedReceiver; #[derive(Debug)] pub struct IcmpProxy { global_ctx: ArcGlobalCtx, peer_manager: Arc, cidr_set: CidrSet, socket: std::sync::Mutex>, nat_table: IcmpNatTable, tasks: Mutex>, } fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit]) -> Result<(usize, IpAddr), Error> { let (size, addr) = socket.recv_from(buf)?; let addr = match addr.as_socket() { None => IpAddr::V4(Ipv4Addr::UNSPECIFIED), Some(add) => add.ip(), }; Ok((size, addr)) } fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSender) { let mut buf = [0u8; 2048]; let data: &mut [MaybeUninit] = unsafe { std::mem::transmute(&mut buf[..]) }; loop { let Ok((len, peer_ip)) = socket_recv(&socket, data) else { continue; }; if !peer_ip.is_ipv4() { continue; } let Some(mut ipv4_packet) = MutableIpv4Packet::new(&mut buf[..len]) else { continue; }; let Some(icmp_packet) = icmp::echo_reply::EchoReplyPacket::new(ipv4_packet.payload()) else { continue; }; if icmp_packet.get_icmp_type() != IcmpTypes::EchoReply { continue; } let key = IcmpNatKey { dst_ip: peer_ip, icmp_id: icmp_packet.get_identifier(), icmp_seq: icmp_packet.get_sequence_number(), }; let Some((_, v)) = nat_table.remove(&key) else { continue; }; // send packet back to the peer where this request origin. let IpAddr::V4(dest_ip) = v.src_ip else { continue; }; ipv4_packet.set_destination(dest_ip); // MacOS do not correctly set ip length when receiving from raw socket ipv4_packet.set_total_length(len as u16); ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable())); let mut p = ZCPacket::new_with_payload(ipv4_packet.packet()); p.fill_peer_manager_hdr( v.my_peer_id.into(), v.src_peer_id.into(), PacketType::Data as u8, ); if let Err(e) = sender.send(p) { tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e); break; } } } #[async_trait::async_trait] impl PeerPacketFilter for IcmpProxy { async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option { if let Some(_) = self.try_handle_peer_packet(&packet).await { return None; } else { return Some(packet); } } } impl IcmpProxy { pub fn new( global_ctx: ArcGlobalCtx, peer_manager: Arc, ) -> Result, Error> { let cidr_set = CidrSet::new(global_ctx.clone()); let ret = Self { global_ctx, peer_manager, cidr_set, socket: std::sync::Mutex::new(None), nat_table: Arc::new(dashmap::DashMap::new()), tasks: Mutex::new(JoinSet::new()), }; Ok(Arc::new(ret)) } pub async fn start(self: &Arc) -> Result<(), Error> { let _g = self.global_ctx.net_ns.guard(); let socket = socket2::Socket::new( socket2::Domain::IPV4, socket2::Type::RAW, Some(socket2::Protocol::ICMPV4), )?; socket.bind(&socket2::SockAddr::from(SocketAddrV4::new( std::net::Ipv4Addr::UNSPECIFIED, 0, )))?; self.socket.lock().unwrap().replace(socket); self.start_icmp_proxy().await?; self.start_nat_table_cleaner().await?; Ok(()) } async fn start_nat_table_cleaner(self: &Arc) -> Result<(), Error> { let nat_table = self.nat_table.clone(); self.tasks.lock().await.spawn( async move { loop { tokio::time::sleep(std::time::Duration::from_secs(1)).await; nat_table.retain(|_, v| v.start_time.elapsed().as_secs() < 20); } } .instrument(tracing::info_span!("icmp proxy nat table cleaner")), ); Ok(()) } async fn start_icmp_proxy(self: &Arc) -> Result<(), Error> { let socket = self.socket.lock().unwrap().as_ref().unwrap().try_clone()?; let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel(); let nat_table = self.nat_table.clone(); thread::spawn(|| { socket_recv_loop(socket, nat_table, sender); }); let peer_manager = self.peer_manager.clone(); self.tasks.lock().await.spawn( async move { while let Some(msg) = receiver.recv().await { let hdr = msg.peer_manager_header().unwrap(); let to_peer_id = hdr.to_peer_id.into(); let ret = peer_manager.send_msg(msg, to_peer_id).await; if ret.is_err() { tracing::error!("send icmp packet to peer failed: {:?}", ret); } } } .instrument(tracing::info_span!("icmp proxy send loop")), ); self.peer_manager .add_packet_process_pipeline(Box::new(self.clone())) .await; Ok(()) } fn send_icmp_packet( &self, dst_ip: Ipv4Addr, icmp_packet: &icmp::echo_request::EchoRequestPacket, ) -> Result<(), Error> { self.socket.lock().unwrap().as_ref().unwrap().send_to( icmp_packet.packet(), &SocketAddrV4::new(dst_ip.into(), 0).into(), )?; Ok(()) } async fn try_handle_peer_packet(&self, packet: &ZCPacket) -> Option<()> { let _ = self.global_ctx.get_ipv4()?; let hdr = packet.peer_manager_header().unwrap(); if hdr.packet_type != PacketType::Data as u8 { return None; }; let ipv4 = Ipv4Packet::new(&packet.payload())?; if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp { return None; } if !self.cidr_set.contains_v4(ipv4.get_destination()) { return None; } let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?; if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest { // drop it because we do not support other icmp types tracing::trace!("unsupported icmp type: {:?}", icmp_packet.get_icmp_type()); return Some(()); } let icmp_id = icmp_packet.get_identifier(); let icmp_seq = icmp_packet.get_sequence_number(); let key = IcmpNatKey { dst_ip: ipv4.get_destination().into(), icmp_id, icmp_seq, }; let value = IcmpNatEntry::new( hdr.from_peer_id.into(), hdr.to_peer_id.into(), ipv4.get_source().into(), ) .ok()?; if let Some(old) = self.nat_table.insert(key, value) { tracing::info!("icmp nat table entry replaced: {:?}", old); } if let Err(e) = self.send_icmp_packet(ipv4.get_destination(), &icmp_packet) { tracing::error!("send icmp packet failed: {:?}", e); } Some(()) } }