feat: add upnp support (#1449)

This commit is contained in:
Debugger Chen
2026-04-21 17:19:04 +08:00
committed by GitHub
parent f4319c4d4f
commit 5cd0a3e846
26 changed files with 3707 additions and 235 deletions
+118 -33
View File
@@ -2,7 +2,7 @@
use std::{
collections::HashSet,
net::{IpAddr, Ipv6Addr, SocketAddr},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
str::FromStr,
sync::{
Arc,
@@ -27,7 +27,7 @@ use crate::{
proto::{
peer_rpc::{
DirectConnectorRpc, DirectConnectorRpcClientFactory, DirectConnectorRpcServer,
GetIpListRequest, GetIpListResponse, SendV6HolePunchPacketRequest,
GetIpListRequest, GetIpListResponse, SendUdpHolePunchPacketRequest,
},
rpc_types::controller::BaseController,
},
@@ -117,37 +117,25 @@ impl DirectConnectorManagerData {
}
}
async fn remote_send_v6_hole_punch_packet(
async fn remote_send_udp_hole_punch_packet(
&self,
dst_peer_id: PeerId,
local_socket: &UdpSocket,
connector_addr: SocketAddr,
remote_url: &url::Url,
) -> Result<(), Error> {
if !matches_scheme!(remote_url, TunnelScheme::Ip(IpScheme::Udp)) {
return Err(anyhow::anyhow!(
"udp hole punch packet only applies to udp listener: {}",
remote_url
)
.into());
}
let global_ctx = self.peer_manager.get_global_ctx();
let listener_port = remote_url.port().ok_or(anyhow::anyhow!(
"failed to parse port from remote url: {}",
remote_url
))?;
let connector_ip = global_ctx
.get_stun_info_collector()
.get_stun_info()
.public_ip
.iter()
.find(|x| x.contains(":"))
.ok_or(anyhow::anyhow!(
"failed to get public ipv6 address from stun info"
))?
.parse::<std::net::Ipv6Addr>()
.with_context(|| {
format!(
"failed to parse public ipv6 address from stun info: {:?}",
global_ctx.get_stun_info_collector().get_stun_info()
)
})?;
let connector_addr = SocketAddr::new(
std::net::IpAddr::V6(connector_ip),
local_socket.local_addr()?.port(),
);
let rpc_stub = self
.peer_manager
@@ -160,9 +148,9 @@ impl DirectConnectorManagerData {
);
rpc_stub
.send_v6_hole_punch_packet(
.send_udp_hole_punch_packet(
BaseController::default(),
SendV6HolePunchPacketRequest {
SendUdpHolePunchPacketRequest {
listener_port: listener_port as u32,
connector_addr: Some(connector_addr.into()),
},
@@ -170,7 +158,7 @@ impl DirectConnectorManagerData {
.await
.with_context(|| {
format!(
"do rpc, send v6 hole punch packet to peer {} at {}",
"do rpc, send udp hole punch packet to peer {} at {}",
dst_peer_id, remote_url
)
})?;
@@ -188,11 +176,34 @@ impl DirectConnectorManagerData {
.await
.with_context(|| format!("failed to bind local socket for {}", remote_url))?,
);
let connector_ip = self
.peer_manager
.get_global_ctx()
.get_stun_info_collector()
.get_stun_info()
.public_ip
.iter()
.find(|x| x.contains(':'))
.ok_or(anyhow::anyhow!(
"failed to get public ipv6 address from stun info"
))?
.parse::<Ipv6Addr>()
.with_context(|| {
format!(
"failed to parse public ipv6 address from stun info: {:?}",
self.peer_manager
.get_global_ctx()
.get_stun_info_collector()
.get_stun_info()
)
})?;
let connector_addr =
SocketAddr::new(IpAddr::V6(connector_ip), local_socket.local_addr()?.port());
// ask remote to send v6 hole punch packet
// and no matter what the result is, continue to connect
let _ = self
.remote_send_v6_hole_punch_packet(dst_peer_id, &local_socket, remote_url)
.remote_send_udp_hole_punch_packet(dst_peer_id, connector_addr, remote_url)
.await;
let udp_connector = UdpTunnelConnector::new(remote_url.clone());
@@ -207,14 +218,80 @@ impl DirectConnectorManagerData {
.await
}
async fn connect_to_public_ipv4(
&self,
dst_peer_id: PeerId,
remote_url: &url::Url,
) -> Result<(PeerId, PeerConnId), Error> {
let local_socket = {
let _g = self.global_ctx.net_ns.guard();
Arc::new(
UdpSocket::bind("0.0.0.0:0")
.await
.with_context(|| format!("failed to bind local socket for {}", remote_url))?,
)
};
let connector_addr = self
.peer_manager
.get_global_ctx()
.get_stun_info_collector()
.get_udp_port_mapping_with_socket(local_socket.clone())
.await
.with_context(|| format!("failed to get udp port mapping for {}", remote_url))?;
let _ = self
.remote_send_udp_hole_punch_packet(dst_peer_id, connector_addr, remote_url)
.await;
let udp_connector = UdpTunnelConnector::new(remote_url.clone());
let remote_addr = SocketAddr::from_url(remote_url.clone(), IpVersion::V4).await?;
let ret = udp_connector
.try_connect_with_socket(local_socket, remote_addr)
.await?;
self.peer_manager
.add_client_tunnel_with_peer_id_hint(ret, true, Some(dst_peer_id))
.await
}
async fn do_try_connect_to_ip(&self, dst_peer_id: PeerId, addr: String) -> Result<(), Error> {
let connector = create_connector_by_url(&addr, &self.global_ctx, IpVersion::Both).await?;
let remote_url = connector.remote_url();
let (peer_id, conn_id) = if matches_scheme!(remote_url, TunnelScheme::Ip(IpScheme::Udp))
&& matches!(remote_url.host(), Some(Host::Ipv6(_)))
{
self.connect_to_public_ipv6(dst_peer_id, &remote_url)
.await?
let (peer_id, conn_id) = if matches_scheme!(remote_url, TunnelScheme::Ip(IpScheme::Udp)) {
match remote_url.host() {
Some(Host::Ipv6(_)) => {
self.connect_to_public_ipv6(dst_peer_id, &remote_url)
.await?
}
Some(Host::Ipv4(ip)) if is_public_ipv4(ip) => {
match self.connect_to_public_ipv4(dst_peer_id, &remote_url).await {
Ok(ret) => ret,
Err(err) => {
tracing::debug!(
?err,
%remote_url,
"udp public ipv4 listener punch failed, falling back to direct connect"
);
timeout(
std::time::Duration::from_secs(3),
self.peer_manager.try_direct_connect_with_peer_id_hint(
connector,
Some(dst_peer_id),
),
)
.await??
}
}
}
_ => {
timeout(
std::time::Duration::from_secs(3),
self.peer_manager
.try_direct_connect_with_peer_id_hint(connector, Some(dst_peer_id)),
)
.await??
}
}
} else {
timeout(
std::time::Duration::from_secs(3),
@@ -577,6 +654,14 @@ impl DirectConnectorManagerData {
}
}
fn is_public_ipv4(ip: Ipv4Addr) -> bool {
!ip.is_private()
&& !ip.is_loopback()
&& !ip.is_link_local()
&& !ip.is_broadcast()
&& !ip.is_unspecified()
}
impl std::fmt::Debug for DirectConnectorManagerData {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DirectConnectorManagerData")
+7
View File
@@ -621,6 +621,13 @@ mod tests {
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
}
async fn get_udp_port_mapping_with_socket(
&self,
udp: std::sync::Arc<tokio::net::UdpSocket>,
) -> Result<SocketAddr, Error> {
self.get_udp_port_mapping(udp.local_addr()?.port()).await
}
async fn get_tcp_port_mapping(&self, mut port: u16) -> Result<SocketAddr, Error> {
if port == 0 {
port = 40144;
+223 -47
View File
@@ -13,8 +13,7 @@ use zerocopy::FromBytes as _;
use crate::{
common::{
PeerId, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS,
stun::StunInfoCollectorTrait as _,
PeerId, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS, upnp,
},
defer,
peers::peer_manager::PeerManager,
@@ -27,6 +26,7 @@ use crate::{
};
pub(crate) const HOLE_PUNCH_PACKET_BODY_LEN: u16 = 16;
const MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS: usize = 4;
fn generate_shuffled_port_vec() -> Vec<u16> {
let mut rng = rand::thread_rng();
@@ -352,6 +352,8 @@ pub(crate) struct UdpHolePunchListener {
tasks: JoinSet<()>,
running: Arc<AtomicCell<bool>>,
mapped_addr: SocketAddr,
has_port_mapping_lease: bool,
_port_mapping_lease: Option<upnp::UdpPortMappingLease>,
conn_counter: Arc<Box<dyn TunnelConnCounter>>,
listen_time: std::time::Instant,
@@ -360,11 +362,6 @@ pub(crate) struct UdpHolePunchListener {
}
impl UdpHolePunchListener {
async fn get_avail_port() -> Result<u16, Error> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
Ok(socket.local_addr()?.port())
}
#[instrument(err)]
pub async fn new(peer_mgr: Arc<PeerManager>) -> Result<Self, Error> {
Self::new_ext(peer_mgr, true, None).await
@@ -376,18 +373,24 @@ impl UdpHolePunchListener {
with_mapped_addr: bool,
port: Option<u16>,
) -> Result<Self, Error> {
let port = port.unwrap_or(Self::get_avail_port().await?);
let listen_url = format!("udp://0.0.0.0:{}", port);
let socket = {
let _g = peer_mgr.get_global_ctx().net_ns.guard();
Arc::new(UdpSocket::bind((Ipv4Addr::UNSPECIFIED, port.unwrap_or(0))).await?)
};
let local_port = socket.local_addr()?.port();
let listen_url: url::Url = format!("udp://0.0.0.0:{local_port}").parse().unwrap();
let mapped_addr = if with_mapped_addr {
let gctx = peer_mgr.get_global_ctx();
let stun_info_collect = gctx.get_stun_info_collector();
stun_info_collect.get_udp_port_mapping(port).await?
let (mapped_addr, port_mapping_lease) = if with_mapped_addr {
upnp::resolve_udp_public_addr(peer_mgr.get_global_ctx(), &listen_url, socket.clone())
.await?
} else {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), port))
(
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, local_port)),
None,
)
};
let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap());
let mut listener = UdpTunnelListener::new_with_socket(listen_url, socket.clone());
{
let _g = peer_mgr.get_global_ctx().net_ns.guard();
@@ -437,6 +440,8 @@ impl UdpHolePunchListener {
socket,
running,
mapped_addr,
has_port_mapping_lease: port_mapping_lease.is_some(),
_port_mapping_lease: port_mapping_lease,
conn_counter,
listen_time: std::time::Instant::now(),
@@ -517,45 +522,87 @@ impl PunchHoleServerCommon {
pub(crate) async fn select_listener(
&self,
use_new_listener: bool,
prefer_port_mapping: bool,
) -> Option<(Arc<UdpSocket>, SocketAddr)> {
let all_listener_sockets = &self.listeners;
let mut use_last = false;
if all_listener_sockets.lock().await.len() < 16 || use_new_listener {
tracing::warn!("creating new udp hole punching listener");
all_listener_sockets.lock().await.push(
UdpHolePunchListener::new(self.peer_mgr.clone())
.await
.ok()?,
);
use_last = true;
}
let mut locked = all_listener_sockets.lock().await;
let listener = if use_last {
Some(locked.last_mut()?)
} else {
// use the listener that is active most recently
locked
.iter_mut()
.filter(|l| !l.mapped_addr.ip().is_unspecified())
.max_by_key(|listener| listener.last_active_time.load())
let (listener_count, has_reusable_listener, has_port_mapping_listener) = {
let locked = self.listeners.lock().await;
(
locked.len(),
locked.iter().any(can_reuse_public_listener),
locked.iter().any(can_reuse_port_mapping_listener),
)
};
let should_create = should_create_public_listener(
listener_count,
has_reusable_listener,
has_port_mapping_listener,
use_new_listener,
prefer_port_mapping,
);
if listener.is_none() || listener.as_ref().unwrap().mapped_addr.ip().is_unspecified() {
if should_create {
tracing::warn!(
?use_new_listener,
"no available udp hole punching listener with mapped address"
max_listeners = MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS,
"creating udp hole punching listener"
);
if !use_new_listener {
return self.select_listener(true).await;
} else {
return None;
match UdpHolePunchListener::new(self.peer_mgr.clone()).await {
Ok(listener) => self.listeners.lock().await.push(listener),
Err(err) => {
tracing::warn!(?err, "failed to create udp hole punching listener");
}
}
}
let listener = listener.unwrap();
let mut locked = self.listeners.lock().await;
let listener_count = locked.len();
let listener_idx = if prefer_port_mapping {
select_reusable_port_mapping_listener_idx(locked.as_slice())
.or_else(|| {
if should_create && locked.last().is_some_and(can_reuse_public_listener) {
Some(locked.len() - 1)
} else {
None
}
})
.or_else(|| select_reusable_public_listener_idx(locked.as_slice()))
} else if should_create {
locked.len().checked_sub(1)
} else {
select_reusable_public_listener_idx(locked.as_slice())
};
let Some(listener_idx) = listener_idx else {
tracing::warn!(
?use_new_listener,
?prefer_port_mapping,
listener_count,
max_listeners = MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS,
"no available udp hole punching listener with mapped address"
);
if should_retry_public_listener_selection(
use_new_listener,
listener_count,
prefer_port_mapping,
has_port_mapping_listener,
) {
drop(locked);
return self.select_listener(true, prefer_port_mapping).await;
}
return None;
};
let listener = &mut locked[listener_idx];
if !can_reuse_public_listener(listener) {
tracing::warn!(
?use_new_listener,
?prefer_port_mapping,
listener_count,
max_listeners = MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS,
"selected udp hole punching listener is not reusable"
);
return None;
}
Some((listener.get_socket().await, listener.mapped_addr))
}
@@ -572,7 +619,73 @@ impl PunchHoleServerCommon {
}
}
#[tracing::instrument(err, ret(level=Level::DEBUG), skip(ports))]
fn can_reuse_public_listener(listener: &UdpHolePunchListener) -> bool {
listener.running.load() && !listener.mapped_addr.ip().is_unspecified()
}
fn can_reuse_port_mapping_listener(listener: &UdpHolePunchListener) -> bool {
can_reuse_public_listener(listener) && listener.has_port_mapping_lease
}
fn select_reusable_public_listener_idx(listeners: &[UdpHolePunchListener]) -> Option<usize> {
// Reuse the listener that was active most recently.
listeners
.iter()
.enumerate()
.filter(|(_, listener)| can_reuse_public_listener(listener))
.max_by_key(|(_, listener)| listener.last_active_time.load())
.map(|(idx, _)| idx)
}
fn select_reusable_port_mapping_listener_idx(listeners: &[UdpHolePunchListener]) -> Option<usize> {
listeners
.iter()
.enumerate()
.filter(|(_, listener)| can_reuse_port_mapping_listener(listener))
.max_by_key(|(_, listener)| listener.last_active_time.load())
.map(|(idx, _)| idx)
}
fn should_create_public_listener(
current_listener_count: usize,
has_reusable_listener: bool,
has_port_mapping_listener: bool,
force_new_listener: bool,
prefer_port_mapping: bool,
) -> bool {
if current_listener_count >= MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS {
return false;
}
if current_listener_count == 0 {
return true;
}
if force_new_listener {
return true;
}
if prefer_port_mapping && !has_port_mapping_listener {
return true;
}
!has_reusable_listener
}
fn should_retry_public_listener_selection(
force_new_listener: bool,
current_listener_count: usize,
prefer_port_mapping: bool,
has_port_mapping_listener: bool,
) -> bool {
if prefer_port_mapping && has_port_mapping_listener {
return false;
}
!force_new_listener && current_listener_count < MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS
}
#[tracing::instrument(err, ret(level=Level::DEBUG))]
pub(crate) async fn send_symmetric_hole_punch_packet(
ports: &[u16],
udp: Arc<UdpSocket>,
@@ -647,3 +760,66 @@ pub(crate) async fn try_connect_with_socket(
.await
.map_err(Error::from)
}
#[cfg(test)]
mod tests {
use super::{
MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS, should_create_public_listener,
should_retry_public_listener_selection,
};
#[test]
fn listener_selection_prefers_reuse_before_cap() {
assert!(!should_create_public_listener(1, true, true, false, false));
assert!(!should_create_public_listener(
MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS,
true,
true,
false,
false
));
}
#[test]
fn listener_selection_creates_when_empty_or_no_reusable_listener() {
assert!(should_create_public_listener(0, false, false, false, false));
assert!(should_create_public_listener(1, false, false, false, false));
}
#[test]
fn listener_selection_force_new_respects_cap() {
assert!(should_create_public_listener(1, true, true, true, false));
assert!(!should_create_public_listener(
MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS,
true,
true,
true,
false
));
}
#[test]
fn listener_selection_prefers_port_mapping_until_available() {
assert!(should_create_public_listener(1, true, false, false, true));
assert!(!should_create_public_listener(1, true, true, false, true));
}
#[test]
fn listener_selection_retry_respects_cap() {
assert!(should_retry_public_listener_selection(
false, 1, false, false
));
assert!(!should_retry_public_listener_selection(
false,
MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS,
false,
false
));
assert!(!should_retry_public_listener_selection(
true, 1, false, false
));
assert!(!should_retry_public_listener_selection(
false, 1, true, true
));
}
}
+16 -17
View File
@@ -7,7 +7,7 @@ use anyhow::Context;
use tokio::net::UdpSocket;
use crate::{
common::{PeerId, scoped_task::ScopedTask, stun::StunInfoCollectorTrait},
common::{PeerId, scoped_task::ScopedTask, upnp},
connector::udp_hole_punch::common::{
HOLE_PUNCH_PACKET_BODY_LEN, UdpSocketArray, try_connect_with_socket,
},
@@ -117,23 +117,19 @@ impl PunchConeHoleClient {
let _g = self.peer_mgr.get_global_ctx().net_ns.guard();
Arc::new(UdpSocket::bind("0.0.0.0:0").await?)
};
let local_addr = local_socket
.local_addr()
.with_context(|| "failed to get local port from udp array")?;
let local_port = local_addr.port();
drop(local_socket);
let local_mapped_addr = global_ctx
.get_stun_info_collector()
.get_udp_port_mapping(local_port)
.await
.with_context(|| "failed to get udp port mapping")?;
let local_socket = {
let _g = self.peer_mgr.get_global_ctx().net_ns.guard();
Arc::new(UdpSocket::bind(local_addr).await?)
};
.with_context(|| "failed to get local addr from udp punch socket")?;
let local_listener: url::Url = format!("udp://0.0.0.0:{}", local_addr.port())
.parse()
.unwrap();
let (local_mapped_addr, _local_port_mapping_lease) = upnp::resolve_udp_public_addr(
global_ctx.clone(),
&local_listener,
local_socket.clone(),
)
.await
.with_context(|| "failed to resolve udp public addr for cone hole punch")?;
// client -> server: tell server the mapped port, server will return the mapped address of listening port.
let rpc_stub = self
@@ -149,7 +145,10 @@ impl PunchConeHoleClient {
let resp = rpc_stub
.select_punch_listener(
BaseController::default(),
SelectPunchListenerRequest { force_new: false },
SelectPunchListenerRequest {
force_new: false,
prefer_port_mapping: true,
},
)
.await;
+9 -1
View File
@@ -88,7 +88,7 @@ impl UdpHolePunchRpc for UdpHolePunchServer {
) -> rpc_types::error::Result<SelectPunchListenerResponse> {
let (_, addr) = self
.common
.select_listener(input.force_new)
.select_listener(input.force_new, input.prefer_port_mapping)
.await
.ok_or(anyhow::anyhow!("no listener available"))?;
@@ -584,6 +584,11 @@ impl UdpHolePunchConnector {
Ok(())
}
#[cfg(test)]
pub async fn run_immediately_for_test(&self) {
self.client.run_immediately().await;
}
}
#[cfg(test)]
@@ -614,6 +619,9 @@ pub mod tests {
udp_nat_type: NatType,
) -> Arc<PeerManager> {
let p_a = create_mock_peer_manager().await;
let mut flags = p_a.get_global_ctx().get_flags();
flags.disable_upnp = true;
p_a.get_global_ctx().set_flags(flags);
replace_stun_info_collector(p_a.clone(), udp_nat_type);
p_a
}
@@ -434,7 +434,10 @@ impl PunchSymToConeHoleClient {
let resp = rpc_stub
.select_punch_listener(
BaseController::default(),
SelectPunchListenerRequest { force_new: false },
SelectPunchListenerRequest {
force_new: false,
prefer_port_mapping: true,
},
)
.await;