use workspace, prepare for config server and gui (#48)

This commit is contained in:
Sijie.Sun
2024-04-04 10:33:53 +08:00
committed by GitHub
parent bb4ae71869
commit 4eb7efe5fc
77 changed files with 162 additions and 195 deletions
+293
View File
@@ -0,0 +1,293 @@
use std::{
mem::MaybeUninit,
net::{IpAddr, Ipv4Addr, SocketAddrV4},
sync::Arc,
thread,
};
use pnet::packet::{
icmp::{self, IcmpTypes},
ip::IpNextHeaderProtocols,
ipv4::{self, Ipv4Packet, MutableIpv4Packet},
Packet,
};
use socket2::Socket;
use tokio::{
sync::{mpsc::UnboundedSender, Mutex},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
peers::{packet, peer_manager::PeerManager, PeerPacketFilter},
};
use super::CidrSet;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct IcmpNatKey {
dst_ip: std::net::IpAddr,
icmp_id: u16,
icmp_seq: u16,
}
#[derive(Debug)]
struct IcmpNatEntry {
src_peer_id: PeerId,
my_peer_id: PeerId,
src_ip: IpAddr,
start_time: std::time::Instant,
}
impl IcmpNatEntry {
fn new(src_peer_id: PeerId, my_peer_id: PeerId, src_ip: IpAddr) -> Result<Self, Error> {
Ok(Self {
src_peer_id,
my_peer_id,
src_ip,
start_time: std::time::Instant::now(),
})
}
}
type IcmpNatTable = Arc<dashmap::DashMap<IcmpNatKey, IcmpNatEntry>>;
type NewPacketSender = tokio::sync::mpsc::UnboundedSender<IcmpNatKey>;
type NewPacketReceiver = tokio::sync::mpsc::UnboundedReceiver<IcmpNatKey>;
#[derive(Debug)]
pub struct IcmpProxy {
global_ctx: ArcGlobalCtx,
peer_manager: Arc<PeerManager>,
cidr_set: CidrSet,
socket: socket2::Socket,
nat_table: IcmpNatTable,
tasks: Mutex<JoinSet<()>>,
}
fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, IpAddr), Error> {
let (size, addr) = socket.recv_from(buf)?;
let addr = match addr.as_socket() {
None => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
Some(add) => add.ip(),
};
Ok((size, addr))
}
fn socket_recv_loop(
socket: Socket,
nat_table: IcmpNatTable,
sender: UnboundedSender<packet::Packet>,
) {
let mut buf = [0u8; 4096];
let data: &mut [MaybeUninit<u8>] = unsafe { std::mem::transmute(&mut buf[12..]) };
loop {
let Ok((len, peer_ip)) = socket_recv(&socket, data) else {
continue;
};
if !peer_ip.is_ipv4() {
continue;
}
let Some(mut ipv4_packet) = MutableIpv4Packet::new(&mut buf[12..12 + len]) else {
continue;
};
let Some(icmp_packet) = icmp::echo_reply::EchoReplyPacket::new(ipv4_packet.payload())
else {
continue;
};
if icmp_packet.get_icmp_type() != IcmpTypes::EchoReply {
continue;
}
let key = IcmpNatKey {
dst_ip: peer_ip,
icmp_id: icmp_packet.get_identifier(),
icmp_seq: icmp_packet.get_sequence_number(),
};
let Some((_, v)) = nat_table.remove(&key) else {
continue;
};
// send packet back to the peer where this request origin.
let IpAddr::V4(dest_ip) = v.src_ip else {
continue;
};
ipv4_packet.set_destination(dest_ip);
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
let peer_packet = packet::Packet::new_data_packet(
v.my_peer_id,
v.src_peer_id,
&ipv4_packet.to_immutable().packet(),
);
if let Err(e) = sender.send(peer_packet) {
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
break;
}
}
}
#[async_trait::async_trait]
impl PeerPacketFilter for IcmpProxy {
async fn try_process_packet_from_peer(
&self,
packet: &packet::ArchivedPacket,
_: &Bytes,
) -> Option<()> {
let _ = self.global_ctx.get_ipv4()?;
if packet.packet_type != packet::PacketType::Data {
return None;
};
let ipv4 = Ipv4Packet::new(&packet.payload.as_bytes())?;
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp
{
return None;
}
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
return None;
}
let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?;
if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
// drop it because we do not support other icmp types
tracing::trace!("unsupported icmp type: {:?}", icmp_packet.get_icmp_type());
return Some(());
}
let icmp_id = icmp_packet.get_identifier();
let icmp_seq = icmp_packet.get_sequence_number();
let key = IcmpNatKey {
dst_ip: ipv4.get_destination().into(),
icmp_id,
icmp_seq,
};
let value = IcmpNatEntry::new(
packet.from_peer.into(),
packet.to_peer.into(),
ipv4.get_source().into(),
)
.ok()?;
if let Some(old) = self.nat_table.insert(key, value) {
tracing::info!("icmp nat table entry replaced: {:?}", old);
}
if let Err(e) = self.send_icmp_packet(ipv4.get_destination(), &icmp_packet) {
tracing::error!("send icmp packet failed: {:?}", e);
}
Some(())
}
}
impl IcmpProxy {
pub fn new(
global_ctx: ArcGlobalCtx,
peer_manager: Arc<PeerManager>,
) -> Result<Arc<Self>, Error> {
let cidr_set = CidrSet::new(global_ctx.clone());
let _g = global_ctx.net_ns.guard();
let socket = socket2::Socket::new(
socket2::Domain::IPV4,
socket2::Type::RAW,
Some(socket2::Protocol::ICMPV4),
)?;
socket.bind(&socket2::SockAddr::from(SocketAddrV4::new(
std::net::Ipv4Addr::UNSPECIFIED,
0,
)))?;
let ret = Self {
global_ctx,
peer_manager,
cidr_set,
socket,
nat_table: Arc::new(dashmap::DashMap::new()),
tasks: Mutex::new(JoinSet::new()),
};
Ok(Arc::new(ret))
}
pub async fn start(self: &Arc<Self>) -> Result<(), Error> {
self.start_icmp_proxy().await?;
self.start_nat_table_cleaner().await?;
Ok(())
}
async fn start_nat_table_cleaner(self: &Arc<Self>) -> Result<(), Error> {
let nat_table = self.nat_table.clone();
self.tasks.lock().await.spawn(
async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
nat_table.retain(|_, v| v.start_time.elapsed().as_secs() < 20);
}
}
.instrument(tracing::info_span!("icmp proxy nat table cleaner")),
);
Ok(())
}
async fn start_icmp_proxy(self: &Arc<Self>) -> Result<(), Error> {
let socket = self.socket.try_clone()?;
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
let nat_table = self.nat_table.clone();
thread::spawn(|| {
socket_recv_loop(socket, nat_table, sender);
});
let peer_manager = self.peer_manager.clone();
self.tasks.lock().await.spawn(
async move {
while let Some(msg) = receiver.recv().await {
let to_peer_id = msg.to_peer.into();
let ret = peer_manager.send_msg(msg.into(), to_peer_id).await;
if ret.is_err() {
tracing::error!("send icmp packet to peer failed: {:?}", ret);
}
}
}
.instrument(tracing::info_span!("icmp proxy send loop")),
);
self.peer_manager
.add_packet_process_pipeline(Box::new(self.clone()))
.await;
Ok(())
}
fn send_icmp_packet(
&self,
dst_ip: Ipv4Addr,
icmp_packet: &icmp::echo_request::EchoRequestPacket,
) -> Result<(), Error> {
self.socket.send_to(
icmp_packet.packet(),
&SocketAddrV4::new(dst_ip.into(), 0).into(),
)?;
Ok(())
}
}
+56
View File
@@ -0,0 +1,56 @@
use dashmap::DashSet;
use std::sync::Arc;
use tokio::task::JoinSet;
use crate::common::global_ctx::ArcGlobalCtx;
pub mod icmp_proxy;
pub mod tcp_proxy;
pub mod udp_proxy;
#[derive(Debug)]
struct CidrSet {
global_ctx: ArcGlobalCtx,
cidr_set: Arc<DashSet<cidr::IpCidr>>,
tasks: JoinSet<()>,
}
impl CidrSet {
pub fn new(global_ctx: ArcGlobalCtx) -> Self {
let mut ret = Self {
global_ctx,
cidr_set: Arc::new(DashSet::new()),
tasks: JoinSet::new(),
};
ret.run_cidr_updater();
ret
}
fn run_cidr_updater(&mut self) {
let global_ctx = self.global_ctx.clone();
let cidr_set = self.cidr_set.clone();
self.tasks.spawn(async move {
let mut last_cidrs = vec![];
loop {
let cidrs = global_ctx.get_proxy_cidrs();
if cidrs != last_cidrs {
last_cidrs = cidrs.clone();
cidr_set.clear();
for cidr in cidrs.iter() {
cidr_set.insert(cidr.clone());
}
}
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
});
}
pub fn contains_v4(&self, ip: std::net::Ipv4Addr) -> bool {
let ip = ip.into();
return self.cidr_set.iter().any(|cidr| cidr.contains(&ip));
}
pub fn is_empty(&self) -> bool {
return self.cidr_set.is_empty();
}
}
+407
View File
@@ -0,0 +1,407 @@
use crossbeam::atomic::AtomicCell;
use dashmap::DashMap;
use pnet::packet::ip::IpNextHeaderProtocols;
use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::atomic::AtomicU16;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::copy_bidirectional;
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::bytes::{Bytes, BytesMut};
use tracing::Instrument;
use crate::common::error::Result;
use crate::common::global_ctx::GlobalCtx;
use crate::common::join_joinset_background;
use crate::common::netns::NetNS;
use crate::peers::packet::{self, ArchivedPacket};
use crate::peers::peer_manager::PeerManager;
use crate::peers::{NicPacketFilter, PeerPacketFilter};
use super::CidrSet;
#[derive(Debug, Clone, Copy, PartialEq)]
enum NatDstEntryState {
// receive syn packet but not start connecting to dst
SynReceived,
// connecting to dst
ConnectingDst,
// connected to dst
Connected,
// connection closed
Closed,
}
#[derive(Debug)]
pub struct NatDstEntry {
id: uuid::Uuid,
src: SocketAddr,
dst: SocketAddr,
start_time: Instant,
tasks: Mutex<JoinSet<()>>,
state: AtomicCell<NatDstEntryState>,
}
impl NatDstEntry {
pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
Self {
id: uuid::Uuid::new_v4(),
src,
dst,
start_time: Instant::now(),
tasks: Mutex::new(JoinSet::new()),
state: AtomicCell::new(NatDstEntryState::SynReceived),
}
}
}
type ArcNatDstEntry = Arc<NatDstEntry>;
type SynSockMap = Arc<DashMap<SocketAddr, ArcNatDstEntry>>;
type ConnSockMap = Arc<DashMap<uuid::Uuid, ArcNatDstEntry>>;
// peer src addr to nat entry, when respond tcp packet, should modify the tcp src addr to the nat entry's dst addr
type AddrConnSockMap = Arc<DashMap<SocketAddr, ArcNatDstEntry>>;
#[derive(Debug)]
pub struct TcpProxy {
global_ctx: Arc<GlobalCtx>,
peer_manager: Arc<PeerManager>,
local_port: AtomicU16,
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
syn_map: SynSockMap,
conn_map: ConnSockMap,
addr_conn_map: AddrConnSockMap,
cidr_set: CidrSet,
}
#[async_trait::async_trait]
impl PeerPacketFilter for TcpProxy {
async fn try_process_packet_from_peer(&self, packet: &ArchivedPacket, _: &Bytes) -> Option<()> {
let ipv4_addr = self.global_ctx.get_ipv4()?;
if packet.packet_type != packet::PacketType::Data {
return None;
};
let payload_bytes = packet.payload.as_bytes();
let ipv4 = Ipv4Packet::new(payload_bytes)?;
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp {
return None;
}
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
return None;
}
tracing::trace!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received");
let mut packet_buffer = BytesMut::with_capacity(payload_bytes.len());
packet_buffer.extend_from_slice(&payload_bytes.to_vec());
let (ip_buffer, tcp_buffer) =
packet_buffer.split_at_mut(ipv4.get_header_length() as usize * 4);
let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
let is_tcp_syn = tcp_packet.get_flags() & pnet::packet::tcp::TcpFlags::SYN != 0;
if is_tcp_syn {
let source_ip = ip_packet.get_source();
let source_port = tcp_packet.get_source();
let src = SocketAddr::V4(SocketAddrV4::new(source_ip, source_port));
let dest_ip = ip_packet.get_destination();
let dest_port = tcp_packet.get_destination();
let dst = SocketAddr::V4(SocketAddrV4::new(dest_ip, dest_port));
let old_val = self
.syn_map
.insert(src, Arc::new(NatDstEntry::new(src, dst)));
tracing::trace!(src = ?src, dst = ?dst, old_entry = ?old_val, "tcp syn received");
}
ip_packet.set_destination(ipv4_addr);
tcp_packet.set_destination(self.get_local_port());
Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
tracing::trace!(ip_packet = ?ip_packet, tcp_packet = ?tcp_packet, "tcp packet forwarded");
if let Err(e) = self
.peer_manager
.get_nic_channel()
.send(packet_buffer.freeze())
.await
{
tracing::error!("send to nic failed: {:?}", e);
}
Some(())
}
}
#[async_trait::async_trait]
impl NicPacketFilter for TcpProxy {
async fn try_process_packet_from_nic(&self, mut data: BytesMut) -> BytesMut {
let Some(my_ipv4) = self.global_ctx.get_ipv4() else {
return data;
};
let header_len = {
let Some(ipv4) = &Ipv4Packet::new(&data[..]) else {
return data;
};
if ipv4.get_version() != 4
|| ipv4.get_source() != my_ipv4
|| ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp
{
return data;
}
ipv4.get_header_length() as usize * 4
};
let (ip_buffer, tcp_buffer) = data.split_at_mut(header_len);
let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
if tcp_packet.get_source() != self.get_local_port() {
return data;
}
let dst_addr = SocketAddr::V4(SocketAddrV4::new(
ip_packet.get_destination(),
tcp_packet.get_destination(),
));
tracing::trace!(dst_addr = ?dst_addr, "tcp packet try find entry");
let entry = if let Some(entry) = self.addr_conn_map.get(&dst_addr) {
entry
} else {
let Some(syn_entry) = self.syn_map.get(&dst_addr) else {
return data;
};
syn_entry
};
let nat_entry = entry.clone();
drop(entry);
assert_eq!(nat_entry.src, dst_addr);
let IpAddr::V4(ip) = nat_entry.dst.ip() else {
panic!("v4 nat entry src ip is not v4");
};
ip_packet.set_source(ip);
tcp_packet.set_source(nat_entry.dst.port());
Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
tracing::trace!(dst_addr = ?dst_addr, nat_entry = ?nat_entry, packet = ?ip_packet, "tcp packet after modified");
data
}
}
impl TcpProxy {
pub fn new(global_ctx: Arc<GlobalCtx>, peer_manager: Arc<PeerManager>) -> Arc<Self> {
Arc::new(Self {
global_ctx: global_ctx.clone(),
peer_manager,
local_port: AtomicU16::new(0),
tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
syn_map: Arc::new(DashMap::new()),
conn_map: Arc::new(DashMap::new()),
addr_conn_map: Arc::new(DashMap::new()),
cidr_set: CidrSet::new(global_ctx),
})
}
fn update_ipv4_packet_checksum(
ipv4_packet: &mut MutableIpv4Packet,
tcp_packet: &mut MutableTcpPacket,
) {
tcp_packet.set_checksum(ipv4_checksum(
&tcp_packet.to_immutable(),
&ipv4_packet.get_source(),
&ipv4_packet.get_destination(),
));
ipv4_packet.set_checksum(pnet::packet::ipv4::checksum(&ipv4_packet.to_immutable()));
}
pub async fn start(self: &Arc<Self>) -> Result<()> {
self.run_syn_map_cleaner().await?;
self.run_listener().await?;
self.peer_manager
.add_packet_process_pipeline(Box::new(self.clone()))
.await;
self.peer_manager
.add_nic_packet_process_pipeline(Box::new(self.clone()))
.await;
join_joinset_background(self.tasks.clone(), "TcpProxy".to_owned());
Ok(())
}
async fn run_syn_map_cleaner(&self) -> Result<()> {
let syn_map = self.syn_map.clone();
let tasks = self.tasks.clone();
let syn_map_cleaner_task = async move {
loop {
syn_map.retain(|_, entry| {
if entry.start_time.elapsed() > Duration::from_secs(30) {
tracing::warn!(entry = ?entry, "syn nat entry expired");
entry.state.store(NatDstEntryState::Closed);
false
} else {
true
}
});
tokio::time::sleep(Duration::from_secs(10)).await;
}
};
tasks.lock().unwrap().spawn(syn_map_cleaner_task);
Ok(())
}
async fn run_listener(&self) -> Result<()> {
// bind on both v4 & v6
let listen_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
let net_ns = self.global_ctx.net_ns.clone();
let tcp_listener = net_ns
.run_async(|| async { TcpListener::bind(&listen_addr).await })
.await?;
self.local_port.store(
tcp_listener.local_addr()?.port(),
std::sync::atomic::Ordering::Relaxed,
);
let tasks = self.tasks.clone();
let syn_map = self.syn_map.clone();
let conn_map = self.conn_map.clone();
let addr_conn_map = self.addr_conn_map.clone();
let accept_task = async move {
tracing::info!(listener = ?tcp_listener, "tcp connection start accepting");
let conn_map = conn_map.clone();
while let Ok((tcp_stream, socket_addr)) = tcp_listener.accept().await {
let Some(entry) = syn_map.get(&socket_addr) else {
tracing::error!("tcp connection from unknown source: {:?}", socket_addr);
continue;
};
assert_eq!(entry.state.load(), NatDstEntryState::SynReceived);
let entry_clone = entry.clone();
drop(entry);
syn_map.remove_if(&socket_addr, |_, entry| entry.id == entry_clone.id);
entry_clone.state.store(NatDstEntryState::ConnectingDst);
let _ = addr_conn_map.insert(entry_clone.src, entry_clone.clone());
let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone());
assert!(old_nat_val.is_none());
tasks.lock().unwrap().spawn(Self::connect_to_nat_dst(
net_ns.clone(),
tcp_stream,
conn_map.clone(),
addr_conn_map.clone(),
entry_clone,
));
}
tracing::error!("nat tcp listener exited");
panic!("nat tcp listener exited");
};
self.tasks
.lock()
.unwrap()
.spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener")));
Ok(())
}
fn remove_entry_from_all_conn_map(
conn_map: ConnSockMap,
addr_conn_map: AddrConnSockMap,
nat_entry: ArcNatDstEntry,
) {
conn_map.remove(&nat_entry.id);
addr_conn_map.remove_if(&nat_entry.src, |_, entry| entry.id == nat_entry.id);
}
async fn connect_to_nat_dst(
net_ns: NetNS,
src_tcp_stream: TcpStream,
conn_map: ConnSockMap,
addr_conn_map: AddrConnSockMap,
nat_entry: ArcNatDstEntry,
) {
if let Err(e) = src_tcp_stream.set_nodelay(true) {
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
}
let _guard = net_ns.guard();
let socket = TcpSocket::new_v4().unwrap();
if let Err(e) = socket.set_nodelay(true) {
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
}
let Ok(Ok(dst_tcp_stream)) = tokio::time::timeout(
Duration::from_secs(10),
TcpSocket::new_v4().unwrap().connect(nat_entry.dst),
)
.await
else {
tracing::error!("connect to dst failed: {:?}", nat_entry);
nat_entry.state.store(NatDstEntryState::Closed);
Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, nat_entry);
return;
};
drop(_guard);
assert_eq!(nat_entry.state.load(), NatDstEntryState::ConnectingDst);
nat_entry.state.store(NatDstEntryState::Connected);
Self::handle_nat_connection(
src_tcp_stream,
dst_tcp_stream,
conn_map,
addr_conn_map,
nat_entry,
)
.await;
}
async fn handle_nat_connection(
mut src_tcp_stream: TcpStream,
mut dst_tcp_stream: TcpStream,
conn_map: ConnSockMap,
addr_conn_map: AddrConnSockMap,
nat_entry: ArcNatDstEntry,
) {
let nat_entry_clone = nat_entry.clone();
nat_entry.tasks.lock().await.spawn(async move {
let ret = copy_bidirectional(&mut src_tcp_stream, &mut dst_tcp_stream).await;
tracing::trace!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed");
nat_entry_clone.state.store(NatDstEntryState::Closed);
Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, nat_entry_clone);
});
}
pub fn get_local_port(&self) -> u16 {
self.local_port.load(std::sync::atomic::Ordering::Relaxed)
}
}
+383
View File
@@ -0,0 +1,383 @@
use std::{
net::{SocketAddr, SocketAddrV4},
sync::{atomic::AtomicBool, Arc},
time::Duration,
};
use dashmap::DashMap;
use pnet::packet::{
ip::IpNextHeaderProtocols,
ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet},
udp::{self, MutableUdpPacket},
Packet,
};
use tokio::{
net::UdpSocket,
sync::{
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
Mutex,
},
task::{JoinHandle, JoinSet},
time::timeout,
};
use tokio_util::bytes::Bytes;
use tracing::Level;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
peers::{packet, peer_manager::PeerManager, PeerPacketFilter},
tunnels::common::setup_sokcet2,
};
use super::CidrSet;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct UdpNatKey {
src_socket: SocketAddr,
}
#[derive(Debug)]
struct UdpNatEntry {
src_peer_id: PeerId,
my_peer_id: PeerId,
src_socket: SocketAddr,
socket: UdpSocket,
forward_task: Mutex<Option<JoinHandle<()>>>,
stopped: AtomicBool,
start_time: std::time::Instant,
}
impl UdpNatEntry {
#[tracing::instrument(err(level = Level::WARN))]
fn new(src_peer_id: PeerId, my_peer_id: PeerId, src_socket: SocketAddr) -> Result<Self, Error> {
// TODO: try use src port, so we will be ip restricted nat type
let socket2_socket = socket2::Socket::new(
socket2::Domain::IPV4,
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
let dst_socket_addr = "0.0.0.0:0".parse().unwrap();
setup_sokcet2(&socket2_socket, &dst_socket_addr)?;
let socket = UdpSocket::from_std(socket2_socket.into())?;
Ok(Self {
src_peer_id,
my_peer_id,
src_socket,
socket,
forward_task: Mutex::new(None),
stopped: AtomicBool::new(false),
start_time: std::time::Instant::now(),
})
}
pub fn stop(&self) {
self.stopped
.store(true, std::sync::atomic::Ordering::Relaxed);
}
async fn compose_ipv4_packet(
self: &Arc<Self>,
packet_sender: &mut UnboundedSender<packet::Packet>,
buf: &mut [u8],
src_v4: &SocketAddrV4,
payload_len: usize,
payload_mtu: usize,
ip_id: u16,
) -> Result<(), Error> {
let SocketAddr::V4(nat_src_v4) = self.src_socket else {
return Err(Error::Unknown);
};
assert_eq!(0, payload_mtu % 8);
// udp payload is in buf[20 + 8..]
let mut udp_packet = MutableUdpPacket::new(&mut buf[20..28 + payload_len]).unwrap();
udp_packet.set_source(src_v4.port());
udp_packet.set_destination(self.src_socket.port());
udp_packet.set_length(payload_len as u16 + 8);
udp_packet.set_checksum(udp::ipv4_checksum(
&udp_packet.to_immutable(),
src_v4.ip(),
nat_src_v4.ip(),
));
let payload_len = payload_len + 8; // include udp header
let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu;
let mut buf_offset = 0;
let mut fragment_offset = 0;
let mut cur_piece = 0;
while fragment_offset < payload_len {
let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len);
let fragment_len = next_fragment_offset - fragment_offset;
let mut ipv4_packet =
MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20])
.unwrap();
ipv4_packet.set_version(4);
ipv4_packet.set_header_length(5);
ipv4_packet.set_total_length((fragment_len + 20) as u16);
ipv4_packet.set_identification(ip_id);
if total_pieces > 1 {
if cur_piece != total_pieces - 1 {
ipv4_packet.set_flags(Ipv4Flags::MoreFragments);
} else {
ipv4_packet.set_flags(0);
}
assert_eq!(0, fragment_offset % 8);
ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8);
} else {
ipv4_packet.set_flags(Ipv4Flags::DontFragment);
ipv4_packet.set_fragment_offset(0);
}
ipv4_packet.set_ecn(0);
ipv4_packet.set_dscp(0);
ipv4_packet.set_ttl(32);
ipv4_packet.set_source(src_v4.ip().clone());
ipv4_packet.set_destination(nat_src_v4.ip().clone());
ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Udp);
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
tracing::trace!(?ipv4_packet, "udp nat packet response send");
let peer_packet = packet::Packet::new_data_packet(
self.my_peer_id,
self.src_peer_id,
&ipv4_packet.to_immutable().packet(),
);
if let Err(e) = packet_sender.send(peer_packet) {
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
return Err(Error::AnyhowError(e.into()));
}
buf_offset += next_fragment_offset - fragment_offset;
fragment_offset = next_fragment_offset;
cur_piece += 1;
}
Ok(())
}
async fn forward_task(self: Arc<Self>, mut packet_sender: UnboundedSender<packet::Packet>) {
let mut buf = [0u8; 8192];
let mut udp_body: &mut [u8] = unsafe { std::mem::transmute(&mut buf[20 + 8..]) };
let mut ip_id = 1;
loop {
let (len, src_socket) = match timeout(
Duration::from_secs(120),
self.socket.recv_from(&mut udp_body),
)
.await
{
Ok(Ok(x)) => x,
Ok(Err(err)) => {
tracing::error!(?err, "udp nat recv failed");
break;
}
Err(err) => {
tracing::error!(?err, "udp nat recv timeout");
break;
}
};
tracing::trace!(?len, ?src_socket, "udp nat packet response received");
if self.stopped.load(std::sync::atomic::Ordering::Relaxed) {
break;
}
let SocketAddr::V4(src_v4) = src_socket else {
continue;
};
let Ok(_) = Self::compose_ipv4_packet(
&self,
&mut packet_sender,
&mut buf,
&src_v4,
len,
1200,
ip_id,
)
.await
else {
break;
};
ip_id = ip_id.wrapping_add(1);
}
self.stop();
}
}
#[derive(Debug)]
pub struct UdpProxy {
global_ctx: ArcGlobalCtx,
peer_manager: Arc<PeerManager>,
cidr_set: CidrSet,
nat_table: Arc<DashMap<UdpNatKey, Arc<UdpNatEntry>>>,
sender: UnboundedSender<packet::Packet>,
receiver: Mutex<Option<UnboundedReceiver<packet::Packet>>>,
tasks: Mutex<JoinSet<()>>,
}
#[async_trait::async_trait]
impl PeerPacketFilter for UdpProxy {
async fn try_process_packet_from_peer(
&self,
packet: &packet::ArchivedPacket,
_: &Bytes,
) -> Option<()> {
if self.cidr_set.is_empty() {
return None;
}
let _ = self.global_ctx.get_ipv4()?;
if packet.packet_type != packet::PacketType::Data {
return None;
};
let ipv4 = Ipv4Packet::new(packet.payload.as_bytes())?;
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Udp {
return None;
}
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
return None;
}
let udp_packet = udp::UdpPacket::new(ipv4.payload())?;
tracing::trace!(
?packet,
?ipv4,
?udp_packet,
"udp nat packet request received"
);
let nat_key = UdpNatKey {
src_socket: SocketAddr::new(ipv4.get_source().into(), udp_packet.get_source()),
};
let nat_entry = self
.nat_table
.entry(nat_key)
.or_try_insert_with::<Error>(|| {
tracing::info!(?packet, ?ipv4, ?udp_packet, "udp nat table entry created");
let _g = self.global_ctx.net_ns.guard();
Ok(Arc::new(UdpNatEntry::new(
packet.from_peer.into(),
packet.to_peer.into(),
nat_key.src_socket,
)?))
})
.ok()?
.clone();
if nat_entry.forward_task.lock().await.is_none() {
nat_entry
.forward_task
.lock()
.await
.replace(tokio::spawn(UdpNatEntry::forward_task(
nat_entry.clone(),
self.sender.clone(),
)));
}
// TODO: should it be async.
let dst_socket =
SocketAddr::new(ipv4.get_destination().into(), udp_packet.get_destination());
let send_ret = {
let _g = self.global_ctx.net_ns.guard();
nat_entry
.socket
.send_to(udp_packet.payload(), dst_socket)
.await
};
if let Err(send_err) = send_ret {
tracing::error!(
?send_err,
?nat_key,
?nat_entry,
?send_err,
"udp nat send failed"
);
}
Some(())
}
}
impl UdpProxy {
pub fn new(
global_ctx: ArcGlobalCtx,
peer_manager: Arc<PeerManager>,
) -> Result<Arc<Self>, Error> {
let cidr_set = CidrSet::new(global_ctx.clone());
let (sender, receiver) = unbounded_channel();
let ret = Self {
global_ctx,
peer_manager,
cidr_set,
nat_table: Arc::new(DashMap::new()),
sender,
receiver: Mutex::new(Some(receiver)),
tasks: Mutex::new(JoinSet::new()),
};
Ok(Arc::new(ret))
}
pub async fn start(self: &Arc<Self>) -> Result<(), Error> {
self.peer_manager
.add_packet_process_pipeline(Box::new(self.clone()))
.await;
// clean up nat table
let nat_table = self.nat_table.clone();
self.tasks.lock().await.spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(15)).await;
nat_table.retain(|_, v| {
if v.start_time.elapsed().as_secs() > 120 {
tracing::info!(?v, "udp nat table entry removed");
v.stop();
false
} else {
true
}
});
}
});
// forward packets to peer manager
let mut receiver = self.receiver.lock().await.take().unwrap();
let peer_manager = self.peer_manager.clone();
self.tasks.lock().await.spawn(async move {
while let Some(msg) = receiver.recv().await {
let to_peer_id: PeerId = msg.to_peer.into();
tracing::trace!(?msg, ?to_peer_id, "udp nat packet response send");
let ret = peer_manager.send_msg(msg.into(), to_peer_id).await;
if ret.is_err() {
tracing::error!("send icmp packet to peer failed: {:?}", ret);
}
}
});
Ok(())
}
}
impl Drop for UdpProxy {
fn drop(&mut self) {
for v in self.nat_table.iter() {
v.stop();
}
}
}