v6 hole punch (#873)

Some devices have ipv6 but don't allow input connection, this patch add hole punching for these devices.

- **add v6 hole punch msg to udp tunnel**
- **send hole punch packet when do ipv6 direct connect**
This commit is contained in:
Sijie.Sun
2025-05-24 22:57:33 +08:00
committed by GitHub
parent fc397c35c5
commit 29994b663a
15 changed files with 499 additions and 198 deletions
-15
View File
@@ -177,21 +177,6 @@ pub(crate) trait FromUrl {
Self: Sized;
}
pub(crate) async fn check_scheme_and_get_socket_addr_ext<T>(
url: &url::Url,
scheme: &str,
ip_version: IpVersion,
) -> Result<T, TunnelError>
where
T: FromUrl,
{
if url.scheme() != scheme {
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
}
Ok(T::from_url(url.clone(), ip_version).await?)
}
pub(crate) async fn check_scheme_and_get_socket_addr<T>(
url: &url::Url,
scheme: &str,
+9
View File
@@ -28,6 +28,15 @@ pub enum UdpPacketType {
Data = 3,
Fin = 4,
HolePunch = 5,
V6HolePunch = 6, // when receiving v6 hole punch packet, the packet contains a socket addr of other peer, we
// will send a hole punch packet to that peer. we only accept this packet from lookback interface.
}
#[repr(C, packed)]
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
pub struct V6HolePunchPacket {
pub dst_ipv6: [u8; 16],
pub dst_port: U16<DefaultEndian>,
}
#[repr(C, packed)]
+1 -2
View File
@@ -5,7 +5,6 @@
use std::{error::Error, net::SocketAddr, sync::Arc};
use crate::tunnel::{
check_scheme_and_get_socket_addr_ext,
common::{FramedReader, FramedWriter, TunnelWrapper},
TunnelInfo,
};
@@ -151,7 +150,7 @@ impl QUICTunnelConnector {
impl TunnelConnector for QUICTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let addr =
check_scheme_and_get_socket_addr_ext::<SocketAddr>(&self.addr, "quic", self.ip_version)
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic", self.ip_version)
.await?;
let local_addr = if addr.is_ipv4() {
"0.0.0.0:0"
+2 -2
View File
@@ -8,7 +8,7 @@ use super::TunnelInfo;
use crate::tunnel::common::setup_sokcet2;
use super::{
check_scheme_and_get_socket_addr, check_scheme_and_get_socket_addr_ext,
check_scheme_and_get_socket_addr,
common::{wait_for_connect_futures, FramedReader, FramedWriter, TunnelWrapper},
IpVersion, Tunnel, TunnelError, TunnelListener,
};
@@ -191,7 +191,7 @@ impl TcpTunnelConnector {
impl super::TunnelConnector for TcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let addr =
check_scheme_and_get_socket_addr_ext::<SocketAddr>(&self.addr, "tcp", self.ip_version)
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp", self.ip_version)
.await?;
if self.bind_addrs.is_empty() {
self.connect_with_default_bind(addr).await
+102 -4
View File
@@ -1,5 +1,6 @@
use std::{
fmt::Debug,
net::{Ipv6Addr, SocketAddrV6},
sync::{Arc, Weak},
};
@@ -9,7 +10,7 @@ use bytes::BytesMut;
use dashmap::DashMap;
use futures::{stream::FuturesUnordered, StreamExt};
use rand::{Rng, SeedableRng};
use zerocopy::AsBytes;
use zerocopy::{AsBytes, FromBytes};
use std::net::SocketAddr;
use tokio::{
@@ -20,7 +21,7 @@ use tokio::{
use tracing::{instrument, Instrument};
use super::TunnelInfo;
use super::{packet_def::V6HolePunchPacket, TunnelInfo};
use crate::{
common::{join_joinset_background, scoped_task::ScopedTask},
tunnel::{
@@ -43,7 +44,7 @@ pub const UDP_DATA_MTU: usize = 2000;
type UdpCloseEventSender = UnboundedSender<(SocketAddr, Option<TunnelError>)>;
type UdpCloseEventReceiver = UnboundedReceiver<(SocketAddr, Option<TunnelError>)>;
fn new_udp_packet<F>(f: F, udp_body: Option<&mut [u8]>) -> ZCPacket
fn new_udp_packet<F>(f: F, udp_body: Option<&[u8]>) -> ZCPacket
where
F: FnOnce(&mut UDPTunnelHeader),
{
@@ -97,6 +98,29 @@ pub fn new_hole_punch_packet(tid: u32, buf_len: u16) -> ZCPacket {
)
}
pub fn new_v6_hole_punch_packet(dst: &SocketAddrV6) -> ZCPacket {
// generate a 128 bytes vec with random data
let mut body = V6HolePunchPacket::default();
body.dst_ipv6.copy_from_slice(&dst.ip().octets());
body.dst_port.set(dst.port());
new_udp_packet(
|header| {
header.msg_type = UdpPacketType::V6HolePunch as u8;
header.conn_id.set(dst.port() as u32);
header
.len
.set(std::mem::size_of::<V6HolePunchPacket>() as u16);
},
Some(body.as_bytes()),
)
}
fn extrace_dst_addr_from_hole_punch_packet(buf: &[u8]) -> Option<SocketAddrV6> {
let body = V6HolePunchPacket::ref_from_prefix(&buf[..])?;
let ip = Ipv6Addr::from(body.dst_ipv6);
Some(SocketAddrV6::new(ip, body.dst_port.get(), 0, 0))
}
fn is_stun_packet(b: &[u8]) -> bool {
// stun has following pattern:
// 1. first two bits are 0b00
@@ -104,6 +128,21 @@ fn is_stun_packet(b: &[u8]) -> bool {
b[4..8] == [0x21, 0x12, 0xA4, 0x42] && b[0] & 0xC0 == 0
}
pub async fn send_v6_hole_punch_packet(
listener_port: u16,
dst_addr: SocketAddrV6,
) -> Result<(), TunnelError> {
let local_socket = UdpSocket::bind("[::1]:0").await?;
let udp_packet = new_v6_hole_punch_packet(&dst_addr);
let remote_addr = format!("[::1]:{}", listener_port)
.parse::<SocketAddr>()
.unwrap();
local_socket
.send_to(&udp_packet.into_bytes(), remote_addr)
.await?;
Ok(())
}
async fn respond_stun_packet(
socket: Arc<UdpSocket>,
addr: SocketAddr,
@@ -421,6 +460,27 @@ impl UdpTunnelListenerData {
tracing::error!(?e, "udp respond stun packet error");
}
});
} else if header.msg_type == UdpPacketType::V6HolePunch as u8 {
if !addr.ip().is_loopback() {
tracing::warn!(?addr, "v6 hole punch packet should be from loopback");
return;
}
if !addr.ip().is_ipv6() {
tracing::warn!(?addr, "v6 hole punch packet should be sent from ipv6");
return;
}
let Some(dst_addr) = extrace_dst_addr_from_hole_punch_packet(zc_packet.udp_payload())
else {
tracing::warn!("invalid v6 hole punch packet");
return;
};
let socket = self.socket.as_ref().unwrap().clone();
let udp_packet = new_hole_punch_packet(1, 32);
if let Err(e) = socket.try_send_to(&udp_packet.into_bytes(), SocketAddr::V6(dst_addr)) {
tracing::error!(?e, "udp send hole punch packet error");
}
tracing::debug!(?dst_addr, "udp forward packet send hole punch packet");
return;
} else if header.msg_type != UdpPacketType::HolePunch as u8 {
let Some(mut conn) = self.sock_map.get_mut(&addr) else {
tracing::trace!(?header, "udp forward packet error, connection not found");
@@ -429,6 +489,8 @@ impl UdpTunnelListenerData {
if let Err(e) = conn.handle_packet_from_remote(zc_packet) {
tracing::trace!(?e, "udp forward packet error");
}
} else {
tracing::trace!(?header, "udp forward packet ignore hole punch packet");
}
}
@@ -778,7 +840,7 @@ impl UdpTunnelConnector {
#[async_trait]
impl super::TunnelConnector for UdpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr_ext::<SocketAddr>(
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"udp",
self.ip_version,
@@ -1055,4 +1117,40 @@ mod tests {
)
.await;
}
#[tokio::test]
async fn test_v6_hole_punch_packet() {
let mut lis = UdpTunnelListener::new("udp://[::]:0".parse().unwrap());
lis.listen().await.unwrap();
// a socket to receive forwarded hole punch packets
let socket = Arc::new(UdpSocket::bind("[::]:0").await.unwrap());
let socket_clone = socket.clone();
let t = tokio::spawn(async move {
let mut buf = BytesMut::new();
buf.resize(128, 0);
socket_clone.recv_from(&mut buf).await.unwrap();
});
tracing::info!("lis local addr: {:?}", lis.local_url());
tracing::info!("socket local addr: {:?}", socket.local_addr().unwrap());
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// a socket to send v6 hole punch packets
send_v6_hole_punch_packet(
lis.local_url().port().unwrap(),
match socket.local_addr().unwrap() {
std::net::SocketAddr::V6(addr_v6) => addr_v6,
_ => panic!("Expected an IPv6 address"),
},
)
.await
.unwrap();
tokio::time::timeout(tokio::time::Duration::from_secs(2), t)
.await
.expect("Timeout waiting for v6 hole punch packet")
.unwrap();
}
}
+1 -1
View File
@@ -702,7 +702,7 @@ impl WgTunnelConnector {
impl super::TunnelConnector for WgTunnelConnector {
#[tracing::instrument]
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr_ext::<SocketAddr>(
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"wg",
self.ip_version,