feat: add upnp support (#1449)

This commit is contained in:
Debugger Chen
2026-04-21 17:19:04 +08:00
committed by GitHub
parent f4319c4d4f
commit 5cd0a3e846
26 changed files with 3707 additions and 235 deletions
+223 -47
View File
@@ -13,8 +13,7 @@ use zerocopy::FromBytes as _;
use crate::{
common::{
PeerId, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS,
stun::StunInfoCollectorTrait as _,
PeerId, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS, upnp,
},
defer,
peers::peer_manager::PeerManager,
@@ -27,6 +26,7 @@ use crate::{
};
pub(crate) const HOLE_PUNCH_PACKET_BODY_LEN: u16 = 16;
const MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS: usize = 4;
fn generate_shuffled_port_vec() -> Vec<u16> {
let mut rng = rand::thread_rng();
@@ -352,6 +352,8 @@ pub(crate) struct UdpHolePunchListener {
tasks: JoinSet<()>,
running: Arc<AtomicCell<bool>>,
mapped_addr: SocketAddr,
has_port_mapping_lease: bool,
_port_mapping_lease: Option<upnp::UdpPortMappingLease>,
conn_counter: Arc<Box<dyn TunnelConnCounter>>,
listen_time: std::time::Instant,
@@ -360,11 +362,6 @@ pub(crate) struct UdpHolePunchListener {
}
impl UdpHolePunchListener {
async fn get_avail_port() -> Result<u16, Error> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
Ok(socket.local_addr()?.port())
}
#[instrument(err)]
pub async fn new(peer_mgr: Arc<PeerManager>) -> Result<Self, Error> {
Self::new_ext(peer_mgr, true, None).await
@@ -376,18 +373,24 @@ impl UdpHolePunchListener {
with_mapped_addr: bool,
port: Option<u16>,
) -> Result<Self, Error> {
let port = port.unwrap_or(Self::get_avail_port().await?);
let listen_url = format!("udp://0.0.0.0:{}", port);
let socket = {
let _g = peer_mgr.get_global_ctx().net_ns.guard();
Arc::new(UdpSocket::bind((Ipv4Addr::UNSPECIFIED, port.unwrap_or(0))).await?)
};
let local_port = socket.local_addr()?.port();
let listen_url: url::Url = format!("udp://0.0.0.0:{local_port}").parse().unwrap();
let mapped_addr = if with_mapped_addr {
let gctx = peer_mgr.get_global_ctx();
let stun_info_collect = gctx.get_stun_info_collector();
stun_info_collect.get_udp_port_mapping(port).await?
let (mapped_addr, port_mapping_lease) = if with_mapped_addr {
upnp::resolve_udp_public_addr(peer_mgr.get_global_ctx(), &listen_url, socket.clone())
.await?
} else {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), port))
(
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, local_port)),
None,
)
};
let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap());
let mut listener = UdpTunnelListener::new_with_socket(listen_url, socket.clone());
{
let _g = peer_mgr.get_global_ctx().net_ns.guard();
@@ -437,6 +440,8 @@ impl UdpHolePunchListener {
socket,
running,
mapped_addr,
has_port_mapping_lease: port_mapping_lease.is_some(),
_port_mapping_lease: port_mapping_lease,
conn_counter,
listen_time: std::time::Instant::now(),
@@ -517,45 +522,87 @@ impl PunchHoleServerCommon {
pub(crate) async fn select_listener(
&self,
use_new_listener: bool,
prefer_port_mapping: bool,
) -> Option<(Arc<UdpSocket>, SocketAddr)> {
let all_listener_sockets = &self.listeners;
let mut use_last = false;
if all_listener_sockets.lock().await.len() < 16 || use_new_listener {
tracing::warn!("creating new udp hole punching listener");
all_listener_sockets.lock().await.push(
UdpHolePunchListener::new(self.peer_mgr.clone())
.await
.ok()?,
);
use_last = true;
}
let mut locked = all_listener_sockets.lock().await;
let listener = if use_last {
Some(locked.last_mut()?)
} else {
// use the listener that is active most recently
locked
.iter_mut()
.filter(|l| !l.mapped_addr.ip().is_unspecified())
.max_by_key(|listener| listener.last_active_time.load())
let (listener_count, has_reusable_listener, has_port_mapping_listener) = {
let locked = self.listeners.lock().await;
(
locked.len(),
locked.iter().any(can_reuse_public_listener),
locked.iter().any(can_reuse_port_mapping_listener),
)
};
let should_create = should_create_public_listener(
listener_count,
has_reusable_listener,
has_port_mapping_listener,
use_new_listener,
prefer_port_mapping,
);
if listener.is_none() || listener.as_ref().unwrap().mapped_addr.ip().is_unspecified() {
if should_create {
tracing::warn!(
?use_new_listener,
"no available udp hole punching listener with mapped address"
max_listeners = MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS,
"creating udp hole punching listener"
);
if !use_new_listener {
return self.select_listener(true).await;
} else {
return None;
match UdpHolePunchListener::new(self.peer_mgr.clone()).await {
Ok(listener) => self.listeners.lock().await.push(listener),
Err(err) => {
tracing::warn!(?err, "failed to create udp hole punching listener");
}
}
}
let listener = listener.unwrap();
let mut locked = self.listeners.lock().await;
let listener_count = locked.len();
let listener_idx = if prefer_port_mapping {
select_reusable_port_mapping_listener_idx(locked.as_slice())
.or_else(|| {
if should_create && locked.last().is_some_and(can_reuse_public_listener) {
Some(locked.len() - 1)
} else {
None
}
})
.or_else(|| select_reusable_public_listener_idx(locked.as_slice()))
} else if should_create {
locked.len().checked_sub(1)
} else {
select_reusable_public_listener_idx(locked.as_slice())
};
let Some(listener_idx) = listener_idx else {
tracing::warn!(
?use_new_listener,
?prefer_port_mapping,
listener_count,
max_listeners = MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS,
"no available udp hole punching listener with mapped address"
);
if should_retry_public_listener_selection(
use_new_listener,
listener_count,
prefer_port_mapping,
has_port_mapping_listener,
) {
drop(locked);
return self.select_listener(true, prefer_port_mapping).await;
}
return None;
};
let listener = &mut locked[listener_idx];
if !can_reuse_public_listener(listener) {
tracing::warn!(
?use_new_listener,
?prefer_port_mapping,
listener_count,
max_listeners = MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS,
"selected udp hole punching listener is not reusable"
);
return None;
}
Some((listener.get_socket().await, listener.mapped_addr))
}
@@ -572,7 +619,73 @@ impl PunchHoleServerCommon {
}
}
#[tracing::instrument(err, ret(level=Level::DEBUG), skip(ports))]
fn can_reuse_public_listener(listener: &UdpHolePunchListener) -> bool {
listener.running.load() && !listener.mapped_addr.ip().is_unspecified()
}
fn can_reuse_port_mapping_listener(listener: &UdpHolePunchListener) -> bool {
can_reuse_public_listener(listener) && listener.has_port_mapping_lease
}
fn select_reusable_public_listener_idx(listeners: &[UdpHolePunchListener]) -> Option<usize> {
// Reuse the listener that was active most recently.
listeners
.iter()
.enumerate()
.filter(|(_, listener)| can_reuse_public_listener(listener))
.max_by_key(|(_, listener)| listener.last_active_time.load())
.map(|(idx, _)| idx)
}
fn select_reusable_port_mapping_listener_idx(listeners: &[UdpHolePunchListener]) -> Option<usize> {
listeners
.iter()
.enumerate()
.filter(|(_, listener)| can_reuse_port_mapping_listener(listener))
.max_by_key(|(_, listener)| listener.last_active_time.load())
.map(|(idx, _)| idx)
}
fn should_create_public_listener(
current_listener_count: usize,
has_reusable_listener: bool,
has_port_mapping_listener: bool,
force_new_listener: bool,
prefer_port_mapping: bool,
) -> bool {
if current_listener_count >= MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS {
return false;
}
if current_listener_count == 0 {
return true;
}
if force_new_listener {
return true;
}
if prefer_port_mapping && !has_port_mapping_listener {
return true;
}
!has_reusable_listener
}
fn should_retry_public_listener_selection(
force_new_listener: bool,
current_listener_count: usize,
prefer_port_mapping: bool,
has_port_mapping_listener: bool,
) -> bool {
if prefer_port_mapping && has_port_mapping_listener {
return false;
}
!force_new_listener && current_listener_count < MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS
}
#[tracing::instrument(err, ret(level=Level::DEBUG))]
pub(crate) async fn send_symmetric_hole_punch_packet(
ports: &[u16],
udp: Arc<UdpSocket>,
@@ -647,3 +760,66 @@ pub(crate) async fn try_connect_with_socket(
.await
.map_err(Error::from)
}
#[cfg(test)]
mod tests {
use super::{
MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS, should_create_public_listener,
should_retry_public_listener_selection,
};
#[test]
fn listener_selection_prefers_reuse_before_cap() {
assert!(!should_create_public_listener(1, true, true, false, false));
assert!(!should_create_public_listener(
MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS,
true,
true,
false,
false
));
}
#[test]
fn listener_selection_creates_when_empty_or_no_reusable_listener() {
assert!(should_create_public_listener(0, false, false, false, false));
assert!(should_create_public_listener(1, false, false, false, false));
}
#[test]
fn listener_selection_force_new_respects_cap() {
assert!(should_create_public_listener(1, true, true, true, false));
assert!(!should_create_public_listener(
MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS,
true,
true,
true,
false
));
}
#[test]
fn listener_selection_prefers_port_mapping_until_available() {
assert!(should_create_public_listener(1, true, false, false, true));
assert!(!should_create_public_listener(1, true, true, false, true));
}
#[test]
fn listener_selection_retry_respects_cap() {
assert!(should_retry_public_listener_selection(
false, 1, false, false
));
assert!(!should_retry_public_listener_selection(
false,
MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS,
false,
false
));
assert!(!should_retry_public_listener_selection(
true, 1, false, false
));
assert!(!should_retry_public_listener_selection(
false, 1, true, true
));
}
}