feat: add upnp support (#1449)

This commit is contained in:
Debugger Chen
2026-04-21 17:19:04 +08:00
committed by GitHub
parent f4319c4d4f
commit 5cd0a3e846
26 changed files with 3707 additions and 235 deletions
+10 -1
View File
@@ -28,10 +28,19 @@ pub enum UdpPacketType {
Data = 3,
Fin = 4,
HolePunch = 5,
V6HolePunch = 6, // when receiving v6 hole punch packet, the packet contains a socket addr of other peer, we
V4HolePunch = 6, // when receiving v4 hole punch packet, the packet contains a socket addr of other peer, we
// will send a hole punch packet to that peer. we only accept this packet from loopback interface.
V6HolePunch = 7, // when receiving v6 hole punch packet, the packet contains a socket addr of other peer, we
// will send a hole punch packet to that peer. we only accept this packet from lookback interface.
}
#[repr(C, packed)]
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
pub struct V4HolePunchPacket {
pub dst_ipv4: [u8; 4],
pub dst_port: U16<DefaultEndian>,
}
#[repr(C, packed)]
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
pub struct V6HolePunchPacket {
+108 -12
View File
@@ -1,6 +1,6 @@
use std::{
fmt::Debug,
net::{Ipv6Addr, SocketAddrV6},
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
sync::{Arc, Weak},
};
@@ -12,7 +12,6 @@ use futures::{SinkExt, StreamExt, stream::FuturesUnordered};
use rand::{Rng, SeedableRng};
use zerocopy::{AsBytes, FromBytes};
use std::net::SocketAddr;
use tokio::{
net::UdpSocket,
sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender},
@@ -24,7 +23,7 @@ use super::{
FromUrl, IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelInfo, TunnelListener,
TunnelUrl,
common::wait_for_connect_futures,
packet_def::{UDP_TUNNEL_HEADER_SIZE, UDPTunnelHeader, V6HolePunchPacket},
packet_def::{UDP_TUNNEL_HEADER_SIZE, UDPTunnelHeader, V4HolePunchPacket, V6HolePunchPacket},
ring::{RingSink, RingStream},
};
use crate::tunnel::common::bind;
@@ -114,6 +113,28 @@ pub fn new_v6_hole_punch_packet(dst: &SocketAddrV6) -> ZCPacket {
)
}
pub fn new_v4_hole_punch_packet(dst: &SocketAddrV4) -> ZCPacket {
let mut body = V4HolePunchPacket::default();
body.dst_ipv4.copy_from_slice(&dst.ip().octets());
body.dst_port.set(dst.port());
new_udp_packet(
|header| {
header.msg_type = UdpPacketType::V4HolePunch as u8;
header.conn_id.set(dst.port() as u32);
header
.len
.set(std::mem::size_of::<V4HolePunchPacket>() as u16);
},
Some(body.as_bytes()),
)
}
fn extract_dst_addr_from_v4_hole_punch_packet(buf: &[u8]) -> Option<SocketAddrV4> {
let body = V4HolePunchPacket::ref_from_prefix(buf)?;
let ip = Ipv4Addr::from(body.dst_ipv4);
Some(SocketAddrV4::new(ip, body.dst_port.get()))
}
fn extrace_dst_addr_from_hole_punch_packet(buf: &[u8]) -> Option<SocketAddrV6> {
let body = V6HolePunchPacket::ref_from_prefix(buf)?;
let ip = Ipv6Addr::from(body.dst_ipv6);
@@ -142,6 +163,21 @@ pub async fn send_v6_hole_punch_packet(
Ok(())
}
pub async fn send_v4_hole_punch_packet(
listener_port: u16,
dst_addr: SocketAddrV4,
) -> Result<(), TunnelError> {
let local_socket = UdpSocket::bind("127.0.0.1:0").await?;
let udp_packet = new_v4_hole_punch_packet(&dst_addr);
let remote_addr = format!("127.0.0.1:{}", listener_port)
.parse::<SocketAddr>()
.unwrap();
local_socket
.send_to(&udp_packet.into_bytes(), remote_addr)
.await?;
Ok(())
}
async fn respond_stun_packet(
socket: Arc<UdpSocket>,
addr: SocketAddr,
@@ -455,6 +491,27 @@ impl UdpTunnelListenerData {
tracing::error!(?e, "udp respond stun packet error");
}
});
} else if header.msg_type == UdpPacketType::V4HolePunch as u8 {
if !addr.ip().is_loopback() {
tracing::warn!(?addr, "v4 hole punch packet should be from loopback");
return;
}
if !addr.ip().is_ipv4() {
tracing::warn!(?addr, "v4 hole punch packet should be sent from ipv4");
return;
}
let Some(dst_addr) =
extract_dst_addr_from_v4_hole_punch_packet(zc_packet.udp_payload())
else {
tracing::warn!("invalid v4 hole punch packet");
return;
};
let socket = self.socket.as_ref().unwrap().clone();
let udp_packet = new_hole_punch_packet(1, 32);
if let Err(e) = socket.try_send_to(&udp_packet.into_bytes(), SocketAddr::V4(dst_addr)) {
tracing::error!(?e, "udp send hole punch packet error");
}
tracing::debug!(?dst_addr, "udp forward packet send hole punch packet");
} else if header.msg_type == UdpPacketType::V6HolePunch as u8 {
if !addr.ip().is_loopback() {
tracing::warn!(?addr, "v6 hole punch packet should be from loopback");
@@ -527,6 +584,12 @@ impl UdpTunnelListener {
}
}
pub fn new_with_socket(addr: url::Url, socket: Arc<UdpSocket>) -> Self {
let mut listener = Self::new(addr);
listener.socket = Some(socket);
listener
}
pub fn get_socket(&self) -> Option<Arc<UdpSocket>> {
self.socket.clone()
}
@@ -535,15 +598,17 @@ impl UdpTunnelListener {
#[async_trait]
impl TunnelListener for UdpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let tunnel_url: TunnelUrl = self.addr.clone().into();
self.socket = Some(Arc::new(
bind()
.addr(addr)
.only_v6(true)
.maybe_dev(tunnel_url.bind_dev())
.call()?,
));
if self.socket.is_none() {
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let tunnel_url: TunnelUrl = self.addr.clone().into();
self.socket = Some(Arc::new(
bind()
.addr(addr)
.only_v6(true)
.maybe_dev(tunnel_url.bind_dev())
.call()?,
));
}
self.data.socket = self.socket.clone();
self.addr
@@ -1147,4 +1212,35 @@ mod tests {
.expect("Timeout waiting for v6 hole punch packet")
.unwrap();
}
#[tokio::test]
async fn test_v4_hole_punch_packet() {
let mut lis = UdpTunnelListener::new("udp://0.0.0.0:0".parse().unwrap());
lis.listen().await.unwrap();
let socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
let socket_clone = socket.clone();
let t = tokio::spawn(async move {
let mut buf = BytesMut::new();
buf.resize(128, 0);
socket_clone.recv_from(&mut buf).await.unwrap();
});
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
send_v4_hole_punch_packet(
lis.local_url().port().unwrap(),
match socket.local_addr().unwrap() {
std::net::SocketAddr::V4(addr_v4) => addr_v4,
_ => panic!("Expected an IPv4 address"),
},
)
.await
.unwrap();
tokio::time::timeout(tokio::time::Duration::from_secs(2), t)
.await
.expect("Timeout waiting for v4 hole punch packet")
.unwrap();
}
}