Files
Easytier/easytier/src/connector/tcp_hole_punch.rs
T
2026-04-21 17:19:04 +08:00

778 lines
25 KiB
Rust

use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::Arc,
time::{Duration, Instant},
};
use anyhow::{Context, Error};
use rand::Rng as _;
use tokio::task::JoinSet;
use crate::{
common::{PeerId, join_joinset_background, stun::StunInfoCollectorTrait},
connector::udp_hole_punch::BackOff,
peers::{
peer_manager::PeerManager,
peer_task::{PeerTaskLauncher, PeerTaskManager},
},
proto::{
common::NatType,
peer_rpc::{
TcpHolePunchRequest, TcpHolePunchResponse, TcpHolePunchRpc,
TcpHolePunchRpcClientFactory, TcpHolePunchRpcServer,
},
rpc_types::{self, controller::BaseController},
},
tunnel::{
TunnelConnector as _, TunnelListener as _,
tcp::{TcpTunnelConnector, TcpTunnelListener},
},
};
use crate::connector::{should_background_p2p_with_peer, should_try_p2p_with_peer};
pub const BLACKLIST_TIMEOUT_SEC: u64 = 3600;
fn handle_rpc_result<T>(
ret: Result<T, rpc_types::error::Error>,
dst_peer_id: PeerId,
blacklist: &timedmap::TimedMap<PeerId, ()>,
) -> Result<T, rpc_types::error::Error> {
match ret {
Ok(ret) => Ok(ret),
Err(e) => {
if matches!(e, rpc_types::error::Error::InvalidServiceKey(_, _)) {
blacklist.insert(dst_peer_id, (), Duration::from_secs(BLACKLIST_TIMEOUT_SEC));
}
Err(e)
}
}
}
fn is_symmetric_tcp_nat(nat_type: NatType) -> bool {
matches!(
nat_type,
NatType::Symmetric | NatType::SymmetricEasyInc | NatType::SymmetricEasyDec
)
}
fn bind_addr_for_port(port: u16, is_v6: bool) -> SocketAddr {
if is_v6 {
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port)
} else {
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port)
}
}
async fn select_local_port(peer_mgr: &Arc<PeerManager>, is_v6: bool) -> Result<u16, Error> {
let bind_addr = bind_addr_for_port(0, is_v6);
tracing::trace!(?bind_addr, is_v6, "tcp hole punch select local port");
let _g = peer_mgr.get_global_ctx().net_ns.guard();
let listener = tokio::net::TcpListener::bind(bind_addr).await?;
let port = listener.local_addr()?.port();
tracing::debug!(?bind_addr, port, "tcp hole punch selected local port");
Ok(port)
}
// tcp support simultaneous connect, so initiator and server can both use connect.
async fn try_connect_to_remote(
peer_mgr: Arc<PeerManager>,
a_mapped_addr: SocketAddr,
local_port: u16,
is_client: bool,
max_attempts: u32,
) -> Result<(), Error> {
tracing::info!(
?a_mapped_addr,
local_port,
"tcp hole punch server start connect loop"
);
let mut connector =
TcpTunnelConnector::new(format!("tcp://{}", a_mapped_addr).parse().unwrap());
connector.set_bind_addrs(vec![bind_addr_for_port(
local_port,
a_mapped_addr.is_ipv6(),
)]);
let start = tokio::time::Instant::now();
let mut attempts: u32 = 0;
while start.elapsed() < Duration::from_secs(10) && attempts < max_attempts {
attempts = attempts.wrapping_add(1);
let _g = peer_mgr.get_global_ctx().net_ns.guard();
if let Ok(Ok(tunnel)) =
tokio::time::timeout(Duration::from_secs(3), connector.connect()).await
{
let add_tunnel_ret = if is_client {
peer_mgr.add_client_tunnel(tunnel, false).await.map(|_| ())
} else {
peer_mgr.add_tunnel_as_server(tunnel, false).await
};
if let Err(e) = add_tunnel_ret {
tracing::error!(
?a_mapped_addr,
local_port,
attempts,
?e,
"tcp hole punch server connected and added client tunnel failed"
);
continue;
} else {
tracing::info!(
?a_mapped_addr,
local_port,
attempts,
is_client,
"tcp hole punch server connected and added tunnel"
);
return Ok(());
}
}
tracing::trace!(
?a_mapped_addr,
local_port,
attempts,
"tcp hole punch server connect attempt failed"
);
let sleep_ms = rand::thread_rng().gen_range(10..100);
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
}
tracing::warn!(
?a_mapped_addr,
local_port,
attempts,
"tcp hole punch server connect loop timeout"
);
Err(anyhow::anyhow!(
"tcp hole punch server connect loop timeout"
))
}
struct TcpHolePunchServer {
peer_mgr: Arc<PeerManager>,
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
}
impl TcpHolePunchServer {
fn new(peer_mgr: Arc<PeerManager>) -> Arc<Self> {
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
join_joinset_background(tasks.clone(), "tcp hole punch server".to_string());
Arc::new(Self { peer_mgr, tasks })
}
}
#[async_trait::async_trait]
impl TcpHolePunchRpc for TcpHolePunchServer {
type Controller = BaseController;
#[tracing::instrument(skip(self), fields(a_mapped_addr = ?input.connector_mapped_addr), err)]
async fn exchange_mapped_addr(
&self,
_ctrl: Self::Controller,
input: TcpHolePunchRequest,
) -> rpc_types::error::Result<TcpHolePunchResponse> {
let my_tcp_nat_type = NatType::try_from(
self.peer_mgr
.get_global_ctx()
.get_stun_info_collector()
.get_stun_info()
.tcp_nat_type,
)
.unwrap_or(NatType::Unknown);
tracing::debug!(?my_tcp_nat_type, "tcp hole punch rpc received");
if matches!(my_tcp_nat_type, NatType::Unknown) {
tracing::warn!(?my_tcp_nat_type, "tcp hole punch rpc rejected (unknown)");
return Err(anyhow::anyhow!("tcp nat type unknown not supported").into());
}
let a_mapped_addr = input
.connector_mapped_addr
.ok_or(anyhow::anyhow!("connector_mapped_addr is required"))?;
let a_mapped_addr: SocketAddr = a_mapped_addr.into();
let a_ip = a_mapped_addr.ip();
if a_ip.is_unspecified() || a_ip.is_multicast() {
tracing::warn!(?a_mapped_addr, "tcp hole punch rpc invalid connector addr");
return Err(anyhow::anyhow!("connector_mapped_addr is malformed").into());
}
let is_v6 = a_mapped_addr.is_ipv6();
let local_port = select_local_port(&self.peer_mgr, is_v6).await?;
let mapped_addr = self
.peer_mgr
.get_global_ctx()
.get_stun_info_collector()
.get_tcp_port_mapping(local_port)
.await
.with_context(|| "failed to get tcp port mapping")?;
tracing::info!(
?a_mapped_addr,
local_port,
?mapped_addr,
"tcp hole punch rpc responding with listener mapped addr and start connecting"
);
let peer_mgr = self.peer_mgr.clone();
self.tasks.lock().unwrap().spawn(async move {
let _ = try_connect_to_remote(peer_mgr, a_mapped_addr, local_port, true, 5).await;
});
Ok(TcpHolePunchResponse {
listener_mapped_addr: Some(mapped_addr.into()),
})
}
}
struct TcpHolePunchConnectorData {
peer_mgr: Arc<PeerManager>,
blacklist: Arc<timedmap::TimedMap<PeerId, ()>>,
}
impl TcpHolePunchConnectorData {
fn new(peer_mgr: Arc<PeerManager>) -> Arc<Self> {
Arc::new(Self {
peer_mgr,
blacklist: Arc::new(timedmap::TimedMap::new()),
})
}
async fn punch_as_initiator(self: Arc<Self>, dst_peer_id: PeerId) -> Result<(), Error> {
let mut backoff = BackOff::new(vec![1000, 1000, 4000, 8000]);
loop {
backoff.sleep_for_next_backoff().await;
if self.do_punch_as_initiator(dst_peer_id).await.is_ok() {
break;
}
if self.blacklist.contains(&dst_peer_id) {
tracing::warn!(
dst_peer_id,
"tcp hole punch initiator skipped (blacklisted)"
);
break;
}
}
Ok(())
}
#[tracing::instrument(skip(self), fields(dst_peer_id), err)]
async fn do_punch_as_initiator(&self, dst_peer_id: PeerId) -> Result<(), Error> {
let global_ctx = self.peer_mgr.get_global_ctx();
let my_tcp_nat_type = NatType::try_from(
global_ctx
.get_stun_info_collector()
.get_stun_info()
.tcp_nat_type,
)
.unwrap_or(NatType::Unknown);
tracing::debug!(?my_tcp_nat_type, "tcp hole punch initiator start");
if is_symmetric_tcp_nat(my_tcp_nat_type) || my_tcp_nat_type == NatType::Unknown {
tracing::debug!("tcp hole punch initiator skipped (symmetric)");
return Ok(());
}
let local_port = select_local_port(&self.peer_mgr, false).await?;
let mapped_addr = global_ctx
.get_stun_info_collector()
.get_tcp_port_mapping(local_port)
.await
.with_context(|| "failed to get tcp port mapping")?;
tracing::info!(
dst_peer_id,
local_port,
?mapped_addr,
"tcp hole punch initiator got mapped addr, start rpc exchange"
);
let rpc_stub = self
.peer_mgr
.get_peer_rpc_mgr()
.rpc_client()
.scoped_client::<TcpHolePunchRpcClientFactory<BaseController>>(
self.peer_mgr.my_peer_id(),
dst_peer_id,
global_ctx.get_network_name(),
);
let resp = rpc_stub
.exchange_mapped_addr(
BaseController {
timeout_ms: 6000,
..Default::default()
},
TcpHolePunchRequest {
connector_mapped_addr: Some(mapped_addr.into()),
},
)
.await;
let resp = handle_rpc_result(resp, dst_peer_id, &self.blacklist)?;
let remote_mapped_addr = resp
.listener_mapped_addr
.ok_or(anyhow::anyhow!("listener_mapped_addr is required"))?;
let remote_mapped_addr: SocketAddr = remote_mapped_addr.into();
tracing::info!(
dst_peer_id,
?remote_mapped_addr,
"tcp hole punch initiator rpc returned"
);
if let Ok(()) = try_connect_to_remote(
self.peer_mgr.clone(),
remote_mapped_addr,
local_port,
false,
1,
)
.await
{
tracing::info!(
dst_peer_id,
local_port,
?remote_mapped_addr,
"tcp hole punch initiator connected to remote mapped addr with simultaneous connection"
);
return Ok(());
}
tracing::debug!(
dst_peer_id,
local_port,
?remote_mapped_addr,
"tcp hole punch initiator sent syn to remote mapped addr"
);
let mut listener =
TcpTunnelListener::new(format!("tcp://0.0.0.0:{}", local_port).parse().unwrap());
{
let _g = self.peer_mgr.get_global_ctx().net_ns.guard();
listener.listen().await?;
}
tracing::info!(
dst_peer_id,
local_port,
url = %listener.local_url(),
"tcp hole punch initiator listening"
);
tokio::time::timeout(
Duration::from_secs(10),
self.accept_loop(&mut listener, dst_peer_id),
)
.await??;
tracing::info!(
dst_peer_id,
"tcp hole punch initiator accepted and added server tunnel"
);
Ok(())
}
async fn accept_loop(
&self,
listener: &mut TcpTunnelListener,
dst_peer_id: PeerId,
) -> Result<(), Error> {
loop {
match listener.accept().await {
Ok(tunnel) => {
if let Err(e) = self.peer_mgr.add_tunnel_as_server(tunnel, false).await {
tracing::error!("tcp hole punch add tunnel error: {}", e);
continue;
}
tracing::info!(
dst_peer_id,
"tcp hole punch initiator accepted and added server tunnel"
);
}
Err(e) => {
tracing::error!("tcp hole punch accept error: {}", e);
}
}
}
}
}
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
struct TcpPunchTaskInfo {
dst_peer_id: PeerId,
}
#[derive(Clone)]
struct TcpHolePunchPeerTaskLauncher {}
#[async_trait::async_trait]
impl PeerTaskLauncher for TcpHolePunchPeerTaskLauncher {
type Data = Arc<TcpHolePunchConnectorData>;
type CollectPeerItem = TcpPunchTaskInfo;
type TaskRet = ();
fn new_data(&self, peer_mgr: Arc<PeerManager>) -> Self::Data {
TcpHolePunchConnectorData::new(peer_mgr)
}
#[tracing::instrument(skip(self, data))]
async fn collect_peers_need_task(&self, data: &Self::Data) -> Vec<Self::CollectPeerItem> {
let global_ctx = data.peer_mgr.get_global_ctx();
let flags = global_ctx.get_flags();
let lazy_p2p = flags.lazy_p2p;
let my_tcp_nat_type = NatType::try_from(
global_ctx
.get_stun_info_collector()
.get_stun_info()
.tcp_nat_type,
)
.unwrap_or(NatType::Unknown);
if is_symmetric_tcp_nat(my_tcp_nat_type) || my_tcp_nat_type == NatType::Unknown {
tracing::trace!(
?my_tcp_nat_type,
"tcp hole punch task collect skipped (symmetric)"
);
return vec![];
}
let my_peer_id = data.peer_mgr.my_peer_id();
let now = Instant::now();
data.blacklist.cleanup();
let mut peers_to_connect = Vec::new();
for route in data.peer_mgr.list_routes().await.iter() {
let static_allowed = should_background_p2p_with_peer(
route.feature_flag.as_ref(),
false,
lazy_p2p,
flags.disable_p2p,
flags.need_p2p,
);
let dynamic_allowed = should_try_p2p_with_peer(
route.feature_flag.as_ref(),
false,
flags.disable_p2p,
flags.need_p2p,
) && data.peer_mgr.has_recent_traffic(route.peer_id, now);
if !static_allowed && !dynamic_allowed {
continue;
}
let peer_id: PeerId = route.peer_id;
if peer_id == my_peer_id {
tracing::trace!(peer_id, "tcp hole punch task collect skip self");
continue;
}
if data.blacklist.contains(&peer_id) {
tracing::debug!(peer_id, "tcp hole punch task collect skip blacklisted");
continue;
}
if data.peer_mgr.get_peer_map().has_peer(peer_id) {
tracing::trace!(peer_id, "tcp hole punch task collect skip already has peer");
continue;
}
let peer_tcp_nat_type = route
.stun_info
.as_ref()
.map(|x| x.tcp_nat_type)
.unwrap_or(0);
let peer_tcp_nat_type =
NatType::try_from(peer_tcp_nat_type).unwrap_or(NatType::Unknown);
if matches!(peer_tcp_nat_type, NatType::Unknown) {
tracing::debug!(
peer_id,
?peer_tcp_nat_type,
"tcp hole punch task collect skip peer unknown"
);
continue;
}
tracing::info!(
peer_id,
my_peer_id,
?my_tcp_nat_type,
?peer_tcp_nat_type,
"tcp hole punch task collect add peer"
);
peers_to_connect.push(TcpPunchTaskInfo {
dst_peer_id: peer_id,
});
}
peers_to_connect
}
async fn launch_task(
&self,
data: &Self::Data,
item: Self::CollectPeerItem,
) -> tokio::task::JoinHandle<Result<Self::TaskRet, anyhow::Error>> {
let data = data.clone();
tokio::spawn(async move { data.punch_as_initiator(item.dst_peer_id).await.map(|_| ()) })
}
async fn all_task_done(&self, _data: &Self::Data) {}
fn loop_interval_ms(&self) -> u64 {
5000
}
}
pub struct TcpHolePunchConnector {
server: Arc<TcpHolePunchServer>,
client: PeerTaskManager<TcpHolePunchPeerTaskLauncher>,
peer_mgr: Arc<PeerManager>,
}
impl TcpHolePunchConnector {
pub fn new(peer_mgr: Arc<PeerManager>) -> Self {
Self {
server: TcpHolePunchServer::new(peer_mgr.clone()),
client: PeerTaskManager::new_with_external_signal(
TcpHolePunchPeerTaskLauncher {},
peer_mgr.clone(),
Some(peer_mgr.p2p_demand_notify()),
),
peer_mgr,
}
}
pub async fn run_as_client(&mut self) -> Result<(), Error> {
tracing::info!("tcp hole punch client start");
self.client.start();
Ok(())
}
pub async fn run_as_server(&mut self) -> Result<(), Error> {
tracing::info!("tcp hole punch server register rpc");
self.peer_mgr
.get_peer_rpc_mgr()
.rpc_server()
.registry()
.register(
TcpHolePunchRpcServer::new_arc(self.server.clone()),
&self.peer_mgr.get_global_ctx().get_network_name(),
);
Ok(())
}
pub async fn run(&mut self) -> Result<(), Error> {
let flags = self.peer_mgr.get_global_ctx().get_flags();
if flags.disable_tcp_hole_punching {
tracing::debug!(
"tcp hole punch disabled by disable_tcp_hole_punching(={});",
flags.disable_tcp_hole_punching
);
return Ok(());
}
self.run_as_client().await?;
self.run_as_server().await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::{net::SocketAddr, sync::Arc, time::Duration};
use crate::{
common::{error::Error, stun::StunInfoCollectorTrait},
connector::tcp_hole_punch::TcpHolePunchConnector,
peers::{
peer_manager::PeerManager,
peer_task::PeerTaskLauncher,
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
},
proto::common::{NatType, StunInfo},
tunnel::common::tests::wait_for_condition,
};
use super::TcpHolePunchPeerTaskLauncher;
struct MockStunInfoCollector {
udp_nat_type: NatType,
tcp_nat_type: NatType,
}
#[async_trait::async_trait]
impl StunInfoCollectorTrait for MockStunInfoCollector {
fn get_stun_info(&self) -> StunInfo {
StunInfo {
udp_nat_type: self.udp_nat_type as i32,
tcp_nat_type: self.tcp_nat_type as i32,
last_update_time: 0,
public_ip: vec!["127.0.0.1".to_string(), "::1".to_string()],
min_port: 100,
max_port: 200,
}
}
async fn get_udp_port_mapping(&self, mut port: u16) -> Result<SocketAddr, Error> {
if port == 0 {
port = 40144;
}
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
}
async fn get_udp_port_mapping_with_socket(
&self,
udp: std::sync::Arc<tokio::net::UdpSocket>,
) -> Result<SocketAddr, Error> {
self.get_udp_port_mapping(udp.local_addr()?.port()).await
}
async fn get_tcp_port_mapping(&self, mut port: u16) -> Result<SocketAddr, Error> {
if port == 0 {
port = 40144;
}
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
}
}
fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, tcp_nat_type: NatType) {
let collector = Box::new(MockStunInfoCollector {
udp_nat_type: NatType::Unknown,
tcp_nat_type,
});
peer_mgr
.get_global_ctx()
.replace_stun_info_collector(collector);
}
async fn collect_lazy_punch_peers(peer_mgr: Arc<PeerManager>) -> Vec<u32> {
let launcher = TcpHolePunchPeerTaskLauncher {};
let data = launcher.new_data(peer_mgr);
launcher
.collect_peers_need_task(&data)
.await
.into_iter()
.map(|task| task.dst_peer_id)
.collect()
}
#[tokio::test]
async fn tcp_hole_punch_connects() {
let p_a = create_mock_peer_manager().await;
let p_b = create_mock_peer_manager().await;
let p_c = create_mock_peer_manager().await;
replace_stun_info_collector(p_a.clone(), NatType::PortRestricted);
replace_stun_info_collector(p_b.clone(), NatType::PortRestricted);
replace_stun_info_collector(p_c.clone(), NatType::PortRestricted);
connect_peer_manager(p_a.clone(), p_b.clone()).await;
connect_peer_manager(p_b.clone(), p_c.clone()).await;
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
let mut hole_punching_a = TcpHolePunchConnector::new(p_a.clone());
let mut hole_punching_c = TcpHolePunchConnector::new(p_c.clone());
hole_punching_a.run().await.unwrap();
hole_punching_c.run().await.unwrap();
hole_punching_a.client.run_immediately().await;
hole_punching_c.client.run_immediately().await;
wait_for_condition(
|| {
let p_a = p_a.clone();
let p_c = p_c.clone();
async move {
let a_has = p_a
.get_peer_map()
.list_peer_conns(p_c.my_peer_id())
.await
.is_some_and(|c| !c.is_empty());
let c_has = p_c
.get_peer_map()
.list_peer_conns(p_a.my_peer_id())
.await
.is_some_and(|c| !c.is_empty());
a_has || c_has
}
},
Duration::from_secs(15),
)
.await;
}
#[tokio::test]
async fn tcp_hole_punch_skip_symmetric_peer() {
let p_a = create_mock_peer_manager().await;
let p_b = create_mock_peer_manager().await;
let p_c = create_mock_peer_manager().await;
replace_stun_info_collector(p_a.clone(), NatType::Symmetric);
replace_stun_info_collector(p_b.clone(), NatType::PortRestricted);
replace_stun_info_collector(p_c.clone(), NatType::Symmetric);
connect_peer_manager(p_a.clone(), p_b.clone()).await;
connect_peer_manager(p_b.clone(), p_c.clone()).await;
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
let mut hole_punching_a = TcpHolePunchConnector::new(p_a.clone());
let mut hole_punching_c = TcpHolePunchConnector::new(p_c.clone());
hole_punching_a.run().await.unwrap();
hole_punching_c.run().await.unwrap();
hole_punching_a.client.run_immediately().await;
hole_punching_c.client.run_immediately().await;
tokio::time::sleep(Duration::from_secs(2)).await;
assert!(
p_a.get_peer_map()
.list_peer_conns(p_c.my_peer_id())
.await
.map(|c| c.is_empty())
.unwrap_or(true)
);
assert!(
p_c.get_peer_map()
.list_peer_conns(p_a.my_peer_id())
.await
.map(|c| c.is_empty())
.unwrap_or(true)
);
}
#[tokio::test]
async fn lazy_p2p_collects_tcp_hole_punch_tasks_only_after_recent_traffic() {
let p_a = create_mock_peer_manager().await;
let p_b = create_mock_peer_manager().await;
let p_c = create_mock_peer_manager().await;
replace_stun_info_collector(p_a.clone(), NatType::PortRestricted);
replace_stun_info_collector(p_b.clone(), NatType::PortRestricted);
replace_stun_info_collector(p_c.clone(), NatType::PortRestricted);
let mut flags = p_a.get_global_ctx().get_flags();
flags.lazy_p2p = true;
p_a.get_global_ctx().set_flags(flags);
connect_peer_manager(p_a.clone(), p_b.clone()).await;
connect_peer_manager(p_b.clone(), p_c.clone()).await;
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
assert!(
!collect_lazy_punch_peers(p_a.clone())
.await
.contains(&p_c.my_peer_id())
);
p_a.mark_recent_traffic(p_c.my_peer_id());
assert!(
collect_lazy_punch_peers(p_a.clone())
.await
.contains(&p_c.my_peer_id())
);
}
}