mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 10:14:35 +00:00
@@ -0,0 +1,730 @@
|
||||
use std::{
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::{Context, Error};
|
||||
use rand::Rng as _;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use crate::{
|
||||
common::{join_joinset_background, stun::StunInfoCollectorTrait, PeerId},
|
||||
connector::udp_hole_punch::BackOff,
|
||||
peers::{
|
||||
peer_manager::PeerManager,
|
||||
peer_task::{PeerTaskLauncher, PeerTaskManager},
|
||||
},
|
||||
proto::{
|
||||
common::NatType,
|
||||
peer_rpc::{
|
||||
TcpHolePunchRequest, TcpHolePunchResponse, TcpHolePunchRpc,
|
||||
TcpHolePunchRpcClientFactory, TcpHolePunchRpcServer,
|
||||
},
|
||||
rpc_types::{self, controller::BaseController},
|
||||
},
|
||||
tunnel::{
|
||||
common::setup_sokcet2,
|
||||
tcp::{TcpTunnelConnector, TcpTunnelListener},
|
||||
TunnelConnector as _, TunnelListener as _,
|
||||
},
|
||||
};
|
||||
|
||||
pub const BLACKLIST_TIMEOUT_SEC: u64 = 3600;
|
||||
|
||||
fn handle_rpc_result<T>(
|
||||
ret: Result<T, rpc_types::error::Error>,
|
||||
dst_peer_id: PeerId,
|
||||
blacklist: &timedmap::TimedMap<PeerId, ()>,
|
||||
) -> Result<T, rpc_types::error::Error> {
|
||||
match ret {
|
||||
Ok(ret) => Ok(ret),
|
||||
Err(e) => {
|
||||
if matches!(e, rpc_types::error::Error::InvalidServiceKey(_, _)) {
|
||||
blacklist.insert(dst_peer_id, (), Duration::from_secs(BLACKLIST_TIMEOUT_SEC));
|
||||
}
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_symmetric_tcp_nat(nat_type: NatType) -> bool {
|
||||
matches!(
|
||||
nat_type,
|
||||
NatType::Symmetric | NatType::SymmetricEasyInc | NatType::SymmetricEasyDec
|
||||
)
|
||||
}
|
||||
|
||||
fn bind_addr_for_port(port: u16, is_v6: bool) -> SocketAddr {
|
||||
if is_v6 {
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port)
|
||||
} else {
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port)
|
||||
}
|
||||
}
|
||||
|
||||
async fn select_local_port(peer_mgr: &Arc<PeerManager>, is_v6: bool) -> Result<u16, Error> {
|
||||
let bind_addr = bind_addr_for_port(0, is_v6);
|
||||
tracing::trace!(?bind_addr, is_v6, "tcp hole punch select local port");
|
||||
let _g = peer_mgr.get_global_ctx().net_ns.guard();
|
||||
let listener = tokio::net::TcpListener::bind(bind_addr).await?;
|
||||
let port = listener.local_addr()?.port();
|
||||
tracing::debug!(?bind_addr, port, "tcp hole punch selected local port");
|
||||
Ok(port)
|
||||
}
|
||||
|
||||
async fn send_syn_from_port(
|
||||
peer_mgr: &Arc<PeerManager>,
|
||||
local_port: u16,
|
||||
dst: SocketAddr,
|
||||
) -> Result<(), Error> {
|
||||
let bind_addr = bind_addr_for_port(local_port, dst.is_ipv6());
|
||||
tracing::debug!(?bind_addr, ?dst, "tcp hole punch send syn");
|
||||
let _g = peer_mgr.get_global_ctx().net_ns.guard();
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(dst),
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &bind_addr)?;
|
||||
let socket = tokio::net::TcpSocket::from_std_stream(socket2_socket.into());
|
||||
match tokio::time::timeout(Duration::from_millis(600), socket.connect(dst)).await {
|
||||
Ok(Ok(_stream)) => {
|
||||
tracing::trace!(?bind_addr, ?dst, "tcp hole punch syn connect succeeded");
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
tracing::trace!(?bind_addr, ?dst, ?e, "tcp hole punch syn connect failed");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::trace!(?bind_addr, ?dst, ?e, "tcp hole punch syn connect timeout");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// tcp support simultaneous connect, so initiator and server can both use connect.
|
||||
async fn try_connect_to_remote(
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
a_mapped_addr: SocketAddr,
|
||||
local_port: u16,
|
||||
is_client: bool,
|
||||
max_attempts: u32,
|
||||
) -> Result<(), Error> {
|
||||
tracing::info!(
|
||||
?a_mapped_addr,
|
||||
local_port,
|
||||
"tcp hole punch server start connect loop"
|
||||
);
|
||||
|
||||
let mut connector =
|
||||
TcpTunnelConnector::new(format!("tcp://{}", a_mapped_addr).parse().unwrap());
|
||||
connector.set_bind_addrs(vec![bind_addr_for_port(
|
||||
local_port,
|
||||
a_mapped_addr.is_ipv6(),
|
||||
)]);
|
||||
|
||||
let start = tokio::time::Instant::now();
|
||||
let mut attempts: u32 = 0;
|
||||
while start.elapsed() < Duration::from_secs(10) && attempts < max_attempts {
|
||||
attempts = attempts.wrapping_add(1);
|
||||
let _g = peer_mgr.get_global_ctx().net_ns.guard();
|
||||
if let Ok(Ok(tunnel)) =
|
||||
tokio::time::timeout(Duration::from_secs(3), connector.connect()).await
|
||||
{
|
||||
let add_tunnel_ret = if is_client {
|
||||
peer_mgr.add_client_tunnel(tunnel, false).await.map(|_| ())
|
||||
} else {
|
||||
peer_mgr.add_tunnel_as_server(tunnel, false).await
|
||||
};
|
||||
if let Err(e) = add_tunnel_ret {
|
||||
tracing::error!(
|
||||
?a_mapped_addr,
|
||||
local_port,
|
||||
attempts,
|
||||
?e,
|
||||
"tcp hole punch server connected and added client tunnel failed"
|
||||
);
|
||||
continue;
|
||||
} else {
|
||||
tracing::info!(
|
||||
?a_mapped_addr,
|
||||
local_port,
|
||||
attempts,
|
||||
is_client,
|
||||
"tcp hole punch server connected and added tunnel"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
tracing::trace!(
|
||||
?a_mapped_addr,
|
||||
local_port,
|
||||
attempts,
|
||||
"tcp hole punch server connect attempt failed"
|
||||
);
|
||||
let sleep_ms = rand::thread_rng().gen_range(10..100);
|
||||
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
|
||||
}
|
||||
|
||||
tracing::warn!(
|
||||
?a_mapped_addr,
|
||||
local_port,
|
||||
attempts,
|
||||
"tcp hole punch server connect loop timeout"
|
||||
);
|
||||
|
||||
Err(anyhow::anyhow!(
|
||||
"tcp hole punch server connect loop timeout"
|
||||
))
|
||||
}
|
||||
|
||||
struct TcpHolePunchServer {
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||
}
|
||||
|
||||
impl TcpHolePunchServer {
|
||||
fn new(peer_mgr: Arc<PeerManager>) -> Arc<Self> {
|
||||
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
|
||||
join_joinset_background(tasks.clone(), "tcp hole punch server".to_string());
|
||||
Arc::new(Self { peer_mgr, tasks })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TcpHolePunchRpc for TcpHolePunchServer {
|
||||
type Controller = BaseController;
|
||||
|
||||
#[tracing::instrument(skip(self), fields(a_mapped_addr = ?input.connector_mapped_addr), err)]
|
||||
async fn exchange_mapped_addr(
|
||||
&self,
|
||||
_ctrl: Self::Controller,
|
||||
input: TcpHolePunchRequest,
|
||||
) -> rpc_types::error::Result<TcpHolePunchResponse> {
|
||||
let my_tcp_nat_type = NatType::try_from(
|
||||
self.peer_mgr
|
||||
.get_global_ctx()
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.tcp_nat_type,
|
||||
)
|
||||
.unwrap_or(NatType::Unknown);
|
||||
tracing::debug!(?my_tcp_nat_type, "tcp hole punch rpc received");
|
||||
if matches!(my_tcp_nat_type, NatType::Unknown) {
|
||||
tracing::warn!(?my_tcp_nat_type, "tcp hole punch rpc rejected (unknown)");
|
||||
return Err(anyhow::anyhow!("tcp nat type unknown not supported").into());
|
||||
}
|
||||
|
||||
let a_mapped_addr = input
|
||||
.connector_mapped_addr
|
||||
.ok_or(anyhow::anyhow!("connector_mapped_addr is required"))?;
|
||||
let a_mapped_addr: SocketAddr = a_mapped_addr.into();
|
||||
let a_ip = a_mapped_addr.ip();
|
||||
if a_ip.is_unspecified() || a_ip.is_multicast() {
|
||||
tracing::warn!(?a_mapped_addr, "tcp hole punch rpc invalid connector addr");
|
||||
return Err(anyhow::anyhow!("connector_mapped_addr is malformed").into());
|
||||
}
|
||||
|
||||
let is_v6 = a_mapped_addr.is_ipv6();
|
||||
let local_port = select_local_port(&self.peer_mgr, is_v6).await?;
|
||||
let mapped_addr = self
|
||||
.peer_mgr
|
||||
.get_global_ctx()
|
||||
.get_stun_info_collector()
|
||||
.get_tcp_port_mapping(local_port)
|
||||
.await
|
||||
.with_context(|| "failed to get tcp port mapping")?;
|
||||
|
||||
tracing::info!(
|
||||
?a_mapped_addr,
|
||||
local_port,
|
||||
?mapped_addr,
|
||||
"tcp hole punch rpc responding with listener mapped addr and start connecting"
|
||||
);
|
||||
|
||||
let peer_mgr = self.peer_mgr.clone();
|
||||
self.tasks.lock().unwrap().spawn(async move {
|
||||
let _ = try_connect_to_remote(peer_mgr, a_mapped_addr, local_port, true, 5).await;
|
||||
});
|
||||
|
||||
Ok(TcpHolePunchResponse {
|
||||
listener_mapped_addr: Some(mapped_addr.into()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct TcpHolePunchConnectorData {
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
blacklist: Arc<timedmap::TimedMap<PeerId, ()>>,
|
||||
}
|
||||
|
||||
impl TcpHolePunchConnectorData {
|
||||
fn new(peer_mgr: Arc<PeerManager>) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
peer_mgr,
|
||||
blacklist: Arc::new(timedmap::TimedMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
async fn punch_as_initiator(self: Arc<Self>, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
let mut backoff = BackOff::new(vec![1000, 1000, 4000, 8000]);
|
||||
|
||||
loop {
|
||||
backoff.sleep_for_next_backoff().await;
|
||||
if self.do_punch_as_initiator(dst_peer_id).await.is_ok() {
|
||||
break;
|
||||
}
|
||||
|
||||
if self.blacklist.contains(&dst_peer_id) {
|
||||
tracing::warn!(
|
||||
dst_peer_id,
|
||||
"tcp hole punch initiator skipped (blacklisted)"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), fields(dst_peer_id), err)]
|
||||
async fn do_punch_as_initiator(&self, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
let global_ctx = self.peer_mgr.get_global_ctx();
|
||||
let my_tcp_nat_type = NatType::try_from(
|
||||
global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.tcp_nat_type,
|
||||
)
|
||||
.unwrap_or(NatType::Unknown);
|
||||
tracing::debug!(?my_tcp_nat_type, "tcp hole punch initiator start");
|
||||
if is_symmetric_tcp_nat(my_tcp_nat_type) || my_tcp_nat_type == NatType::Unknown {
|
||||
tracing::debug!("tcp hole punch initiator skipped (symmetric)");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let local_port = select_local_port(&self.peer_mgr, false).await?;
|
||||
let mapped_addr = global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_tcp_port_mapping(local_port)
|
||||
.await
|
||||
.with_context(|| "failed to get tcp port mapping")?;
|
||||
|
||||
tracing::info!(
|
||||
dst_peer_id,
|
||||
local_port,
|
||||
?mapped_addr,
|
||||
"tcp hole punch initiator got mapped addr, start rpc exchange"
|
||||
);
|
||||
|
||||
let rpc_stub = self
|
||||
.peer_mgr
|
||||
.get_peer_rpc_mgr()
|
||||
.rpc_client()
|
||||
.scoped_client::<TcpHolePunchRpcClientFactory<BaseController>>(
|
||||
self.peer_mgr.my_peer_id(),
|
||||
dst_peer_id,
|
||||
global_ctx.get_network_name(),
|
||||
);
|
||||
|
||||
let resp = rpc_stub
|
||||
.exchange_mapped_addr(
|
||||
BaseController {
|
||||
timeout_ms: 6000,
|
||||
..Default::default()
|
||||
},
|
||||
TcpHolePunchRequest {
|
||||
connector_mapped_addr: Some(mapped_addr.into()),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
let resp = handle_rpc_result(resp, dst_peer_id, &self.blacklist)?;
|
||||
let remote_mapped_addr = resp
|
||||
.listener_mapped_addr
|
||||
.ok_or(anyhow::anyhow!("listener_mapped_addr is required"))?;
|
||||
let remote_mapped_addr: SocketAddr = remote_mapped_addr.into();
|
||||
tracing::info!(
|
||||
dst_peer_id,
|
||||
?remote_mapped_addr,
|
||||
"tcp hole punch initiator rpc returned"
|
||||
);
|
||||
|
||||
if let Ok(()) = try_connect_to_remote(
|
||||
self.peer_mgr.clone(),
|
||||
remote_mapped_addr,
|
||||
local_port,
|
||||
false,
|
||||
1,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::info!(
|
||||
dst_peer_id,
|
||||
local_port,
|
||||
?remote_mapped_addr,
|
||||
"tcp hole punch initiator connected to remote mapped addr with simultaneous connection"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
dst_peer_id,
|
||||
local_port,
|
||||
?remote_mapped_addr,
|
||||
"tcp hole punch initiator sent syn to remote mapped addr"
|
||||
);
|
||||
|
||||
let mut listener =
|
||||
TcpTunnelListener::new(format!("tcp://0.0.0.0:{}", local_port).parse().unwrap());
|
||||
{
|
||||
let _g = self.peer_mgr.get_global_ctx().net_ns.guard();
|
||||
listener.listen().await?;
|
||||
}
|
||||
tracing::info!(
|
||||
dst_peer_id,
|
||||
local_port,
|
||||
url = %listener.local_url(),
|
||||
"tcp hole punch initiator listening"
|
||||
);
|
||||
|
||||
tokio::time::timeout(
|
||||
Duration::from_secs(10),
|
||||
self.accept_loop(&mut listener, dst_peer_id),
|
||||
)
|
||||
.await??;
|
||||
|
||||
tracing::info!(
|
||||
dst_peer_id,
|
||||
"tcp hole punch initiator accepted and added server tunnel"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept_loop(
|
||||
&self,
|
||||
listener: &mut TcpTunnelListener,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<(), Error> {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok(tunnel) => {
|
||||
if let Err(e) = self.peer_mgr.add_tunnel_as_server(tunnel, false).await {
|
||||
tracing::error!("tcp hole punch add tunnel error: {}", e);
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
dst_peer_id,
|
||||
"tcp hole punch initiator accepted and added server tunnel"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("tcp hole punch accept error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
|
||||
struct TcpPunchTaskInfo {
|
||||
dst_peer_id: PeerId,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TcpHolePunchPeerTaskLauncher {}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerTaskLauncher for TcpHolePunchPeerTaskLauncher {
|
||||
type Data = Arc<TcpHolePunchConnectorData>;
|
||||
type CollectPeerItem = TcpPunchTaskInfo;
|
||||
type TaskRet = ();
|
||||
|
||||
fn new_data(&self, peer_mgr: Arc<PeerManager>) -> Self::Data {
|
||||
TcpHolePunchConnectorData::new(peer_mgr)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, data))]
|
||||
async fn collect_peers_need_task(&self, data: &Self::Data) -> Vec<Self::CollectPeerItem> {
|
||||
let global_ctx = data.peer_mgr.get_global_ctx();
|
||||
let my_tcp_nat_type = NatType::try_from(
|
||||
global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.tcp_nat_type,
|
||||
)
|
||||
.unwrap_or(NatType::Unknown);
|
||||
if is_symmetric_tcp_nat(my_tcp_nat_type) || my_tcp_nat_type == NatType::Unknown {
|
||||
tracing::trace!(
|
||||
?my_tcp_nat_type,
|
||||
"tcp hole punch task collect skipped (symmetric)"
|
||||
);
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let my_peer_id = data.peer_mgr.my_peer_id();
|
||||
|
||||
data.blacklist.cleanup();
|
||||
|
||||
let mut peers_to_connect = Vec::new();
|
||||
for route in data.peer_mgr.list_routes().await.iter() {
|
||||
if route
|
||||
.feature_flag
|
||||
.map(|x| x.is_public_server)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let peer_id: PeerId = route.peer_id;
|
||||
if peer_id == my_peer_id {
|
||||
tracing::trace!(peer_id, "tcp hole punch task collect skip self");
|
||||
continue;
|
||||
}
|
||||
|
||||
if data.blacklist.contains(&peer_id) {
|
||||
tracing::debug!(peer_id, "tcp hole punch task collect skip blacklisted");
|
||||
continue;
|
||||
}
|
||||
|
||||
if data.peer_mgr.get_peer_map().has_peer(peer_id) {
|
||||
tracing::trace!(peer_id, "tcp hole punch task collect skip already has peer");
|
||||
continue;
|
||||
}
|
||||
|
||||
let peer_tcp_nat_type = route
|
||||
.stun_info
|
||||
.as_ref()
|
||||
.map(|x| x.tcp_nat_type)
|
||||
.unwrap_or(0);
|
||||
let peer_tcp_nat_type =
|
||||
NatType::try_from(peer_tcp_nat_type).unwrap_or(NatType::Unknown);
|
||||
if matches!(peer_tcp_nat_type, NatType::Unknown) {
|
||||
tracing::debug!(
|
||||
peer_id,
|
||||
?peer_tcp_nat_type,
|
||||
"tcp hole punch task collect skip peer unknown"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
peer_id,
|
||||
my_peer_id,
|
||||
?my_tcp_nat_type,
|
||||
?peer_tcp_nat_type,
|
||||
"tcp hole punch task collect add peer"
|
||||
);
|
||||
peers_to_connect.push(TcpPunchTaskInfo {
|
||||
dst_peer_id: peer_id,
|
||||
});
|
||||
}
|
||||
|
||||
peers_to_connect
|
||||
}
|
||||
|
||||
async fn launch_task(
|
||||
&self,
|
||||
data: &Self::Data,
|
||||
item: Self::CollectPeerItem,
|
||||
) -> tokio::task::JoinHandle<Result<Self::TaskRet, anyhow::Error>> {
|
||||
let data = data.clone();
|
||||
tokio::spawn(async move { data.punch_as_initiator(item.dst_peer_id).await.map(|_| ()) })
|
||||
}
|
||||
|
||||
async fn all_task_done(&self, _data: &Self::Data) {}
|
||||
|
||||
fn loop_interval_ms(&self) -> u64 {
|
||||
5000
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TcpHolePunchConnector {
|
||||
server: Arc<TcpHolePunchServer>,
|
||||
client: PeerTaskManager<TcpHolePunchPeerTaskLauncher>,
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
}
|
||||
|
||||
impl TcpHolePunchConnector {
|
||||
pub fn new(peer_mgr: Arc<PeerManager>) -> Self {
|
||||
Self {
|
||||
server: TcpHolePunchServer::new(peer_mgr.clone()),
|
||||
client: PeerTaskManager::new(TcpHolePunchPeerTaskLauncher {}, peer_mgr.clone()),
|
||||
peer_mgr,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_as_client(&mut self) -> Result<(), Error> {
|
||||
tracing::info!("tcp hole punch client start");
|
||||
self.client.start();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_as_server(&mut self) -> Result<(), Error> {
|
||||
tracing::info!("tcp hole punch server register rpc");
|
||||
self.peer_mgr
|
||||
.get_peer_rpc_mgr()
|
||||
.rpc_server()
|
||||
.registry()
|
||||
.register(
|
||||
TcpHolePunchRpcServer::new(self.server.clone()),
|
||||
&self.peer_mgr.get_global_ctx().get_network_name(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) -> Result<(), Error> {
|
||||
if self.peer_mgr.get_global_ctx().get_flags().disable_p2p {
|
||||
tracing::debug!("tcp hole punch disabled by disable_p2p");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.run_as_client().await?;
|
||||
self.run_as_server().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, stun::StunInfoCollectorTrait},
|
||||
connector::tcp_hole_punch::TcpHolePunchConnector,
|
||||
peers::{
|
||||
peer_manager::PeerManager,
|
||||
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
|
||||
},
|
||||
proto::common::{NatType, StunInfo},
|
||||
tunnel::common::tests::wait_for_condition,
|
||||
};
|
||||
|
||||
struct MockStunInfoCollector {
|
||||
udp_nat_type: NatType,
|
||||
tcp_nat_type: NatType,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl StunInfoCollectorTrait for MockStunInfoCollector {
|
||||
fn get_stun_info(&self) -> StunInfo {
|
||||
StunInfo {
|
||||
udp_nat_type: self.udp_nat_type as i32,
|
||||
tcp_nat_type: self.tcp_nat_type as i32,
|
||||
last_update_time: 0,
|
||||
public_ip: vec!["127.0.0.1".to_string(), "::1".to_string()],
|
||||
min_port: 100,
|
||||
max_port: 200,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_udp_port_mapping(&self, mut port: u16) -> Result<SocketAddr, Error> {
|
||||
if port == 0 {
|
||||
port = 40144;
|
||||
}
|
||||
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
|
||||
}
|
||||
|
||||
async fn get_tcp_port_mapping(&self, mut port: u16) -> Result<SocketAddr, Error> {
|
||||
if port == 0 {
|
||||
port = 40144;
|
||||
}
|
||||
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, tcp_nat_type: NatType) {
|
||||
let collector = Box::new(MockStunInfoCollector {
|
||||
udp_nat_type: NatType::Unknown,
|
||||
tcp_nat_type,
|
||||
});
|
||||
peer_mgr
|
||||
.get_global_ctx()
|
||||
.replace_stun_info_collector(collector);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_hole_punch_connects() {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
let p_b = create_mock_peer_manager().await;
|
||||
let p_c = create_mock_peer_manager().await;
|
||||
|
||||
replace_stun_info_collector(p_a.clone(), NatType::PortRestricted);
|
||||
replace_stun_info_collector(p_b.clone(), NatType::PortRestricted);
|
||||
replace_stun_info_collector(p_c.clone(), NatType::PortRestricted);
|
||||
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
|
||||
|
||||
let mut hole_punching_a = TcpHolePunchConnector::new(p_a.clone());
|
||||
let mut hole_punching_c = TcpHolePunchConnector::new(p_c.clone());
|
||||
hole_punching_a.run().await.unwrap();
|
||||
hole_punching_c.run().await.unwrap();
|
||||
|
||||
hole_punching_a.client.run_immediately().await;
|
||||
hole_punching_c.client.run_immediately().await;
|
||||
|
||||
wait_for_condition(
|
||||
|| {
|
||||
let p_a = p_a.clone();
|
||||
let p_c = p_c.clone();
|
||||
async move {
|
||||
let a_has = p_a
|
||||
.get_peer_map()
|
||||
.list_peer_conns(p_c.my_peer_id())
|
||||
.await
|
||||
.is_some_and(|c| !c.is_empty());
|
||||
let c_has = p_c
|
||||
.get_peer_map()
|
||||
.list_peer_conns(p_a.my_peer_id())
|
||||
.await
|
||||
.is_some_and(|c| !c.is_empty());
|
||||
a_has || c_has
|
||||
}
|
||||
},
|
||||
Duration::from_secs(15),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_hole_punch_skip_symmetric_peer() {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
let p_b = create_mock_peer_manager().await;
|
||||
let p_c = create_mock_peer_manager().await;
|
||||
|
||||
replace_stun_info_collector(p_a.clone(), NatType::Symmetric);
|
||||
replace_stun_info_collector(p_b.clone(), NatType::PortRestricted);
|
||||
replace_stun_info_collector(p_c.clone(), NatType::Symmetric);
|
||||
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
|
||||
|
||||
let mut hole_punching_a = TcpHolePunchConnector::new(p_a.clone());
|
||||
let mut hole_punching_c = TcpHolePunchConnector::new(p_c.clone());
|
||||
hole_punching_a.run().await.unwrap();
|
||||
hole_punching_c.run().await.unwrap();
|
||||
|
||||
hole_punching_a.client.run_immediately().await;
|
||||
hole_punching_c.client.run_immediately().await;
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
|
||||
assert!(p_a
|
||||
.get_peer_map()
|
||||
.list_peer_conns(p_c.my_peer_id())
|
||||
.await
|
||||
.map(|c| c.is_empty())
|
||||
.unwrap_or(true));
|
||||
assert!(p_c
|
||||
.get_peer_map()
|
||||
.list_peer_conns(p_a.my_peer_id())
|
||||
.await
|
||||
.map(|c| c.is_empty())
|
||||
.unwrap_or(true));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user