mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 02:09:06 +00:00
use workspace, prepare for config server and gui (#48)
This commit is contained in:
@@ -0,0 +1,411 @@
|
||||
// try connect peers directly, with either its public ip or lan ip
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
|
||||
peers::{peer_manager::PeerManager, peer_rpc::PeerRpcManager},
|
||||
};
|
||||
|
||||
use crate::rpc::{peer::GetIpListResponse, PeerConnInfo};
|
||||
use tokio::{task::JoinSet, time::timeout};
|
||||
use tracing::Instrument;
|
||||
|
||||
use super::create_connector_by_url;
|
||||
|
||||
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
|
||||
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait DirectConnectorRpc {
|
||||
async fn get_ip_list() -> GetIpListResponse;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait PeerManagerForDirectConnector {
|
||||
async fn list_peers(&self) -> Vec<PeerId>;
|
||||
async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>>;
|
||||
fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager>;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerManagerForDirectConnector for PeerManager {
|
||||
async fn list_peers(&self) -> Vec<PeerId> {
|
||||
let mut ret = vec![];
|
||||
|
||||
let routes = self.list_routes().await;
|
||||
for r in routes.iter() {
|
||||
ret.push(r.peer_id);
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>> {
|
||||
self.get_peer_map().list_peer_conns(peer_id).await
|
||||
}
|
||||
|
||||
fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
|
||||
self.get_peer_rpc_mgr()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct DirectConnectorManagerRpcServer {
|
||||
// TODO: this only cache for one src peer, should make it global
|
||||
global_ctx: ArcGlobalCtx,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl DirectConnectorRpc for DirectConnectorManagerRpcServer {
|
||||
async fn get_ip_list(self, _: tarpc::context::Context) -> GetIpListResponse {
|
||||
let mut ret = self.global_ctx.get_ip_collector().collect_ip_addrs().await;
|
||||
ret.listeners = self.global_ctx.get_running_listeners();
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
impl DirectConnectorManagerRpcServer {
|
||||
pub fn new(global_ctx: ArcGlobalCtx) -> Self {
|
||||
Self { global_ctx }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Clone)]
|
||||
struct DstBlackListItem(PeerId, String);
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Clone)]
|
||||
struct DstSchemeBlackListItem(PeerId, String);
|
||||
|
||||
struct DirectConnectorManagerData {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
dst_blacklist: timedmap::TimedMap<DstBlackListItem, ()>,
|
||||
dst_sceme_blacklist: timedmap::TimedMap<DstSchemeBlackListItem, ()>,
|
||||
}
|
||||
|
||||
impl DirectConnectorManagerData {
|
||||
pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
|
||||
Self {
|
||||
global_ctx,
|
||||
peer_manager,
|
||||
dst_blacklist: timedmap::TimedMap::new(),
|
||||
dst_sceme_blacklist: timedmap::TimedMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for DirectConnectorManagerData {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("DirectConnectorManagerData")
|
||||
.field("peer_manager", &self.peer_manager)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DirectConnectorManager {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl DirectConnectorManager {
|
||||
pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
|
||||
Self {
|
||||
global_ctx: global_ctx.clone(),
|
||||
data: Arc::new(DirectConnectorManagerData::new(global_ctx, peer_manager)),
|
||||
tasks: JoinSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(&mut self) {
|
||||
self.run_as_server();
|
||||
self.run_as_client();
|
||||
}
|
||||
|
||||
pub fn run_as_server(&mut self) {
|
||||
self.data.peer_manager.get_peer_rpc_mgr().run_service(
|
||||
DIRECT_CONNECTOR_SERVICE_ID,
|
||||
DirectConnectorManagerRpcServer::new(self.global_ctx.clone()).serve(),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn run_as_client(&mut self) {
|
||||
let data = self.data.clone();
|
||||
let my_peer_id = self.data.peer_manager.my_peer_id();
|
||||
self.tasks.spawn(
|
||||
async move {
|
||||
loop {
|
||||
let peers = data.peer_manager.list_peers().await;
|
||||
let mut tasks = JoinSet::new();
|
||||
for peer_id in peers {
|
||||
if peer_id == my_peer_id {
|
||||
continue;
|
||||
}
|
||||
tasks.spawn(Self::do_try_direct_connect(data.clone(), peer_id));
|
||||
}
|
||||
|
||||
while let Some(task_ret) = tasks.join_next().await {
|
||||
tracing::trace!(?task_ret, "direct connect task ret");
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("direct_connector_client", my_id = ?self.global_ctx.id),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
async fn do_try_connect_to_ip(
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
dst_peer_id: PeerId,
|
||||
addr: String,
|
||||
) -> Result<(), Error> {
|
||||
data.dst_blacklist.cleanup();
|
||||
if data
|
||||
.dst_blacklist
|
||||
.contains(&DstBlackListItem(dst_peer_id.clone(), addr.clone()))
|
||||
{
|
||||
tracing::trace!("try_connect_to_ip failed, addr in blacklist: {}", addr);
|
||||
return Err(Error::UrlInBlacklist);
|
||||
}
|
||||
|
||||
let connector = create_connector_by_url(&addr, &data.global_ctx).await?;
|
||||
let (peer_id, conn_id) = timeout(
|
||||
std::time::Duration::from_secs(5),
|
||||
data.peer_manager.try_connect(connector),
|
||||
)
|
||||
.await??;
|
||||
|
||||
// let (peer_id, conn_id) = data.peer_manager.try_connect(connector).await?;
|
||||
|
||||
if peer_id != dst_peer_id {
|
||||
tracing::info!(
|
||||
"connect to ip succ: {}, but peer id mismatch, expect: {}, actual: {}",
|
||||
addr,
|
||||
dst_peer_id,
|
||||
peer_id
|
||||
);
|
||||
data.peer_manager
|
||||
.get_peer_map()
|
||||
.close_peer_conn(peer_id, &conn_id)
|
||||
.await?;
|
||||
return Err(Error::InvalidUrl(addr));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn try_connect_to_ip(
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
dst_peer_id: PeerId,
|
||||
addr: String,
|
||||
) -> Result<(), Error> {
|
||||
let ret = Self::do_try_connect_to_ip(data.clone(), dst_peer_id, addr.clone()).await;
|
||||
if let Err(e) = ret {
|
||||
if !matches!(e, Error::UrlInBlacklist) {
|
||||
tracing::info!(
|
||||
"try_connect_to_ip failed: {:?}, peer_id: {}",
|
||||
e,
|
||||
dst_peer_id
|
||||
);
|
||||
data.dst_blacklist.insert(
|
||||
DstBlackListItem(dst_peer_id.clone(), addr.clone()),
|
||||
(),
|
||||
std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
|
||||
);
|
||||
}
|
||||
return Err(e);
|
||||
} else {
|
||||
log::info!("try_connect_to_ip success, peer_id: {}", dst_peer_id);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn do_try_direct_connect_internal(
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
dst_peer_id: PeerId,
|
||||
ip_list: GetIpListResponse,
|
||||
) -> Result<(), Error> {
|
||||
let available_listeners = ip_list
|
||||
.listeners
|
||||
.iter()
|
||||
.filter_map(|l| if l.scheme() != "ring" { Some(l) } else { None })
|
||||
.filter(|l| l.port().is_some())
|
||||
.filter(|l| {
|
||||
!data.dst_sceme_blacklist.contains(&DstSchemeBlackListItem(
|
||||
dst_peer_id.clone(),
|
||||
l.scheme().to_string(),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut listener = available_listeners.get(0).ok_or(anyhow::anyhow!(
|
||||
"peer {} have no valid listener",
|
||||
dst_peer_id
|
||||
))?;
|
||||
|
||||
// if have default listener, use it first
|
||||
listener = available_listeners
|
||||
.iter()
|
||||
.find(|l| l.scheme() == data.global_ctx.get_flags().default_protocol)
|
||||
.unwrap_or(listener);
|
||||
|
||||
let mut tasks = JoinSet::new();
|
||||
ip_list.interface_ipv4s.iter().for_each(|ip| {
|
||||
let addr = format!(
|
||||
"{}://{}:{}",
|
||||
listener.scheme(),
|
||||
ip,
|
||||
listener.port().unwrap_or(11010)
|
||||
);
|
||||
tasks.spawn(Self::try_connect_to_ip(
|
||||
data.clone(),
|
||||
dst_peer_id.clone(),
|
||||
addr,
|
||||
));
|
||||
});
|
||||
|
||||
let addr = format!(
|
||||
"{}://{}:{}",
|
||||
listener.scheme(),
|
||||
ip_list.public_ipv4.clone(),
|
||||
listener.port().unwrap_or(11010)
|
||||
);
|
||||
tasks.spawn(Self::try_connect_to_ip(
|
||||
data.clone(),
|
||||
dst_peer_id.clone(),
|
||||
addr,
|
||||
));
|
||||
|
||||
let mut has_succ = false;
|
||||
while let Some(ret) = tasks.join_next().await {
|
||||
if let Err(e) = ret {
|
||||
log::error!("join direct connect task failed: {:?}", e);
|
||||
} else if let Ok(Ok(_)) = ret {
|
||||
has_succ = true;
|
||||
}
|
||||
}
|
||||
|
||||
if !has_succ {
|
||||
data.dst_sceme_blacklist.insert(
|
||||
DstSchemeBlackListItem(dst_peer_id.clone(), listener.scheme().to_string()),
|
||||
(),
|
||||
std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn do_try_direct_connect(
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<(), Error> {
|
||||
let peer_manager = data.peer_manager.clone();
|
||||
// check if we have direct connection with dst_peer_id
|
||||
if let Some(c) = peer_manager.list_peer_conns(dst_peer_id).await {
|
||||
// currently if we have any type of direct connection (udp or tcp), we will not try to connect
|
||||
if !c.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
log::trace!("try direct connect to peer: {}", dst_peer_id);
|
||||
|
||||
let ip_list = peer_manager
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, dst_peer_id, |c| async {
|
||||
let client =
|
||||
DirectConnectorRpcClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ip_list = client.get_ip_list(tarpc::context::current()).await;
|
||||
tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list");
|
||||
ip_list
|
||||
})
|
||||
.await?;
|
||||
|
||||
Self::do_try_direct_connect_internal(data, dst_peer_id, ip_list).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
connector::direct::{
|
||||
DirectConnectorManager, DirectConnectorManagerData, DstBlackListItem,
|
||||
DstSchemeBlackListItem,
|
||||
},
|
||||
instance::listeners::ListenerManager,
|
||||
peers::tests::{
|
||||
connect_peer_manager, create_mock_peer_manager, wait_route_appear,
|
||||
wait_route_appear_with_cost,
|
||||
},
|
||||
rpc::peer::GetIpListResponse,
|
||||
};
|
||||
|
||||
#[rstest::rstest]
|
||||
#[tokio::test]
|
||||
async fn direct_connector_basic_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
let p_b = create_mock_peer_manager().await;
|
||||
let p_c = create_mock_peer_manager().await;
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
|
||||
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
|
||||
|
||||
let mut dm_a = DirectConnectorManager::new(p_a.get_global_ctx(), p_a.clone());
|
||||
let mut dm_c = DirectConnectorManager::new(p_c.get_global_ctx(), p_c.clone());
|
||||
|
||||
dm_a.run_as_client();
|
||||
dm_c.run_as_server();
|
||||
|
||||
let port = if proto == "wg" { 11040 } else { 11041 };
|
||||
p_c.get_global_ctx()
|
||||
.config
|
||||
.set_listeners(vec![format!("{}://0.0.0.0:{}", proto, port)
|
||||
.parse()
|
||||
.unwrap()]);
|
||||
let mut lis_c = ListenerManager::new(p_c.get_global_ctx(), p_c.clone());
|
||||
lis_c.prepare_listeners().await.unwrap();
|
||||
|
||||
lis_c.run().await.unwrap();
|
||||
|
||||
wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn direct_connector_scheme_blacklist() {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
let data = Arc::new(DirectConnectorManagerData::new(
|
||||
p_a.get_global_ctx(),
|
||||
p_a.clone(),
|
||||
));
|
||||
let mut ip_list = GetIpListResponse::new();
|
||||
ip_list
|
||||
.listeners
|
||||
.push("tcp://127.0.0.1:10222".parse().unwrap());
|
||||
|
||||
ip_list.interface_ipv4s.push("127.0.0.1".to_string());
|
||||
|
||||
DirectConnectorManager::do_try_direct_connect_internal(data.clone(), 1, ip_list.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(data
|
||||
.dst_sceme_blacklist
|
||||
.contains(&DstSchemeBlackListItem(1, "tcp".into())));
|
||||
|
||||
assert!(data
|
||||
.dst_blacklist
|
||||
.contains(&DstBlackListItem(1, ip_list.listeners[0].to_string())));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,390 @@
|
||||
use std::{collections::BTreeSet, sync::Arc};
|
||||
|
||||
use dashmap::{DashMap, DashSet};
|
||||
use tokio::{
|
||||
sync::{broadcast::Receiver, mpsc, Mutex},
|
||||
task::JoinSet,
|
||||
time::timeout,
|
||||
};
|
||||
|
||||
use crate::{common::PeerId, peers::peer_conn::PeerConnId, rpc as easytier_rpc};
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
||||
netns::NetNS,
|
||||
},
|
||||
connector::set_bind_addr_for_peer_connector,
|
||||
peers::peer_manager::PeerManager,
|
||||
rpc::{
|
||||
connector_manage_rpc_server::ConnectorManageRpc, Connector, ConnectorStatus,
|
||||
ListConnectorRequest, ManageConnectorRequest,
|
||||
},
|
||||
tunnels::{Tunnel, TunnelConnector},
|
||||
use_global_var,
|
||||
};
|
||||
|
||||
use super::create_connector_by_url;
|
||||
|
||||
type ConnectorMap = Arc<DashMap<String, Box<dyn TunnelConnector + Send + Sync>>>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ReconnResult {
|
||||
dead_url: String,
|
||||
peer_id: PeerId,
|
||||
conn_id: PeerConnId,
|
||||
}
|
||||
|
||||
struct ConnectorManagerData {
|
||||
connectors: ConnectorMap,
|
||||
reconnecting: DashSet<String>,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
alive_conn_urls: Arc<Mutex<BTreeSet<String>>>,
|
||||
// user removed connector urls
|
||||
removed_conn_urls: Arc<DashSet<String>>,
|
||||
net_ns: NetNS,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
}
|
||||
|
||||
pub struct ManualConnectorManager {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
data: Arc<ConnectorManagerData>,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl ManualConnectorManager {
|
||||
pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
|
||||
let connectors = Arc::new(DashMap::new());
|
||||
let tasks = JoinSet::new();
|
||||
let event_subscriber = global_ctx.subscribe();
|
||||
|
||||
let mut ret = Self {
|
||||
global_ctx: global_ctx.clone(),
|
||||
data: Arc::new(ConnectorManagerData {
|
||||
connectors,
|
||||
reconnecting: DashSet::new(),
|
||||
peer_manager,
|
||||
alive_conn_urls: Arc::new(Mutex::new(BTreeSet::new())),
|
||||
removed_conn_urls: Arc::new(DashSet::new()),
|
||||
net_ns: global_ctx.net_ns.clone(),
|
||||
global_ctx,
|
||||
}),
|
||||
tasks,
|
||||
};
|
||||
|
||||
ret.tasks
|
||||
.spawn(Self::conn_mgr_routine(ret.data.clone(), event_subscriber));
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn add_connector<T>(&self, connector: T)
|
||||
where
|
||||
T: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
log::info!("add_connector: {}", connector.remote_url());
|
||||
self.data
|
||||
.connectors
|
||||
.insert(connector.remote_url().into(), Box::new(connector));
|
||||
}
|
||||
|
||||
pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> {
|
||||
self.add_connector(create_connector_by_url(url, &self.global_ctx).await?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn remove_connector(&self, url: &str) -> Result<(), Error> {
|
||||
log::info!("remove_connector: {}", url);
|
||||
if !self.list_connectors().await.iter().any(|x| x.url == url) {
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
self.data.removed_conn_urls.insert(url.into());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn list_connectors(&self) -> Vec<Connector> {
|
||||
let conn_urls: BTreeSet<String> = self
|
||||
.data
|
||||
.connectors
|
||||
.iter()
|
||||
.map(|x| x.key().clone().into())
|
||||
.collect();
|
||||
|
||||
let dead_urls: BTreeSet<String> = Self::collect_dead_conns(self.data.clone())
|
||||
.await
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let mut ret = Vec::new();
|
||||
|
||||
for conn_url in conn_urls {
|
||||
let mut status = ConnectorStatus::Connected;
|
||||
if dead_urls.contains(&conn_url) {
|
||||
status = ConnectorStatus::Disconnected;
|
||||
}
|
||||
ret.insert(
|
||||
0,
|
||||
Connector {
|
||||
url: conn_url,
|
||||
status: status.into(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let reconnecting_urls: BTreeSet<String> = self
|
||||
.data
|
||||
.reconnecting
|
||||
.iter()
|
||||
.map(|x| x.clone().into())
|
||||
.collect();
|
||||
|
||||
for conn_url in reconnecting_urls {
|
||||
ret.insert(
|
||||
0,
|
||||
Connector {
|
||||
url: conn_url,
|
||||
status: ConnectorStatus::Connecting.into(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
async fn conn_mgr_routine(
|
||||
data: Arc<ConnectorManagerData>,
|
||||
mut event_recv: Receiver<GlobalCtxEvent>,
|
||||
) {
|
||||
log::warn!("conn_mgr_routine started");
|
||||
let mut reconn_interval = tokio::time::interval(std::time::Duration::from_millis(
|
||||
use_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS),
|
||||
));
|
||||
let mut reconn_tasks = JoinSet::new();
|
||||
let (reconn_result_send, mut reconn_result_recv) = mpsc::channel(100);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
event = event_recv.recv() => {
|
||||
if let Ok(event) = event {
|
||||
Self::handle_event(&event, data.clone()).await;
|
||||
} else {
|
||||
log::warn!("event_recv closed");
|
||||
panic!("event_recv closed");
|
||||
}
|
||||
}
|
||||
|
||||
_ = reconn_interval.tick() => {
|
||||
let dead_urls = Self::collect_dead_conns(data.clone()).await;
|
||||
if dead_urls.is_empty() {
|
||||
continue;
|
||||
}
|
||||
for dead_url in dead_urls {
|
||||
let data_clone = data.clone();
|
||||
let sender = reconn_result_send.clone();
|
||||
let (_, connector) = data.connectors.remove(&dead_url).unwrap();
|
||||
let insert_succ = data.reconnecting.insert(dead_url.clone());
|
||||
assert!(insert_succ);
|
||||
reconn_tasks.spawn(async move {
|
||||
sender.send(Self::conn_reconnect(data_clone.clone(), dead_url, connector).await).await.unwrap();
|
||||
});
|
||||
}
|
||||
log::info!("reconn_interval tick, done");
|
||||
}
|
||||
|
||||
ret = reconn_result_recv.recv() => {
|
||||
log::warn!("reconn_tasks done, out: {:?}", ret);
|
||||
let _ = reconn_tasks.join_next().await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_event(event: &GlobalCtxEvent, data: Arc<ConnectorManagerData>) {
|
||||
match event {
|
||||
GlobalCtxEvent::PeerConnAdded(conn_info) => {
|
||||
let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
|
||||
data.alive_conn_urls.lock().await.insert(addr);
|
||||
log::warn!("peer conn added: {:?}", conn_info);
|
||||
}
|
||||
|
||||
GlobalCtxEvent::PeerConnRemoved(conn_info) => {
|
||||
let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
|
||||
data.alive_conn_urls.lock().await.remove(&addr);
|
||||
log::warn!("peer conn removed: {:?}", conn_info);
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_remove_connector(data: Arc<ConnectorManagerData>) {
|
||||
let remove_later = DashSet::new();
|
||||
for it in data.removed_conn_urls.iter() {
|
||||
let url = it.key();
|
||||
if let Some(_) = data.connectors.remove(url) {
|
||||
log::warn!("connector: {}, removed", url);
|
||||
continue;
|
||||
} else if data.reconnecting.contains(url) {
|
||||
log::warn!("connector: {}, reconnecting, remove later.", url);
|
||||
remove_later.insert(url.clone());
|
||||
continue;
|
||||
} else {
|
||||
log::warn!("connector: {}, not found", url);
|
||||
}
|
||||
}
|
||||
data.removed_conn_urls.clear();
|
||||
for it in remove_later.iter() {
|
||||
data.removed_conn_urls.insert(it.key().clone());
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_dead_conns(data: Arc<ConnectorManagerData>) -> BTreeSet<String> {
|
||||
Self::handle_remove_connector(data.clone());
|
||||
|
||||
let curr_alive = data.alive_conn_urls.lock().await.clone();
|
||||
let all_urls: BTreeSet<String> = data
|
||||
.connectors
|
||||
.iter()
|
||||
.map(|x| x.key().clone().into())
|
||||
.collect();
|
||||
&all_urls - &curr_alive
|
||||
}
|
||||
|
||||
async fn conn_reconnect(
|
||||
data: Arc<ConnectorManagerData>,
|
||||
dead_url: String,
|
||||
connector: Box<dyn TunnelConnector + Send + Sync>,
|
||||
) -> Result<ReconnResult, Error> {
|
||||
let connector = Arc::new(Mutex::new(Some(connector)));
|
||||
let net_ns = data.net_ns.clone();
|
||||
|
||||
log::info!("reconnect: {}", dead_url);
|
||||
|
||||
let connector_clone = connector.clone();
|
||||
let data_clone = data.clone();
|
||||
let url_clone = dead_url.clone();
|
||||
let ip_collector = data.global_ctx.get_ip_collector();
|
||||
let reconn_task = async move {
|
||||
let mut locked = connector_clone.lock().await;
|
||||
let conn = locked.as_mut().unwrap();
|
||||
// TODO: should support set v6 here, use url in connector array
|
||||
set_bind_addr_for_peer_connector(conn, true, &ip_collector).await;
|
||||
|
||||
data_clone
|
||||
.global_ctx
|
||||
.issue_event(GlobalCtxEvent::Connecting(conn.remote_url().clone()));
|
||||
|
||||
let _g = net_ns.guard();
|
||||
log::info!("reconnect try connect... conn: {:?}", conn);
|
||||
let tunnel = conn.connect().await?;
|
||||
log::info!("reconnect get tunnel succ: {:?}", tunnel);
|
||||
assert_eq!(
|
||||
url_clone,
|
||||
tunnel.info().unwrap().remote_addr,
|
||||
"info: {:?}",
|
||||
tunnel.info()
|
||||
);
|
||||
let (peer_id, conn_id) = data_clone.peer_manager.add_client_tunnel(tunnel).await?;
|
||||
log::info!("reconnect succ: {} {} {}", peer_id, conn_id, url_clone);
|
||||
Ok(ReconnResult {
|
||||
dead_url: url_clone,
|
||||
peer_id,
|
||||
conn_id,
|
||||
})
|
||||
};
|
||||
|
||||
let ret = timeout(std::time::Duration::from_secs(1), reconn_task).await;
|
||||
log::info!("reconnect: {} done, ret: {:?}", dead_url, ret);
|
||||
|
||||
if ret.is_err() || ret.as_ref().unwrap().is_err() {
|
||||
data.global_ctx.issue_event(GlobalCtxEvent::ConnectError(
|
||||
dead_url.clone(),
|
||||
format!("{:?}", ret),
|
||||
));
|
||||
}
|
||||
|
||||
let conn = connector.lock().await.take().unwrap();
|
||||
data.reconnecting.remove(&dead_url).unwrap();
|
||||
data.connectors.insert(dead_url.clone(), conn);
|
||||
|
||||
ret?
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConnectorManagerRpcService(pub Arc<ManualConnectorManager>);
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl ConnectorManageRpc for ConnectorManagerRpcService {
|
||||
async fn list_connector(
|
||||
&self,
|
||||
_request: tonic::Request<ListConnectorRequest>,
|
||||
) -> Result<tonic::Response<easytier_rpc::ListConnectorResponse>, tonic::Status> {
|
||||
let mut ret = easytier_rpc::ListConnectorResponse::default();
|
||||
let connectors = self.0.list_connectors().await;
|
||||
ret.connectors = connectors;
|
||||
Ok(tonic::Response::new(ret))
|
||||
}
|
||||
|
||||
async fn manage_connector(
|
||||
&self,
|
||||
request: tonic::Request<ManageConnectorRequest>,
|
||||
) -> Result<tonic::Response<easytier_rpc::ManageConnectorResponse>, tonic::Status> {
|
||||
let req = request.into_inner();
|
||||
let url = url::Url::parse(&req.url)
|
||||
.map_err(|_| tonic::Status::invalid_argument("invalid url"))?;
|
||||
if req.action == easytier_rpc::ConnectorManageAction::Remove as i32 {
|
||||
self.0.remove_connector(url.path()).await.map_err(|e| {
|
||||
tonic::Status::invalid_argument(format!("remove connector failed: {:?}", e))
|
||||
})?;
|
||||
return Ok(tonic::Response::new(
|
||||
easytier_rpc::ManageConnectorResponse::default(),
|
||||
));
|
||||
} else {
|
||||
self.0
|
||||
.add_connector_by_url(url.as_str())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::invalid_argument(format!("add connector failed: {:?}", e))
|
||||
})?;
|
||||
}
|
||||
Ok(tonic::Response::new(
|
||||
easytier_rpc::ManageConnectorResponse::default(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
peers::tests::create_mock_peer_manager,
|
||||
set_global_var,
|
||||
tunnels::{Tunnel, TunnelError},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reconnect_with_connecting_addr() {
|
||||
set_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS, 1);
|
||||
|
||||
let peer_mgr = create_mock_peer_manager().await;
|
||||
let mgr = ManualConnectorManager::new(peer_mgr.get_global_ctx(), peer_mgr);
|
||||
|
||||
struct MockConnector {}
|
||||
#[async_trait::async_trait]
|
||||
impl TunnelConnector for MockConnector {
|
||||
fn remote_url(&self) -> url::Url {
|
||||
url::Url::parse("tcp://aa.com").unwrap()
|
||||
}
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
Err(TunnelError::CommonError("fake error".into()))
|
||||
}
|
||||
}
|
||||
|
||||
mgr.add_connector(MockConnector {});
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
use std::{
|
||||
net::{SocketAddr, SocketAddrV4, SocketAddrV6},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector},
|
||||
tunnels::{
|
||||
ring_tunnel::RingTunnelConnector,
|
||||
tcp_tunnel::TcpTunnelConnector,
|
||||
udp_tunnel::UdpTunnelConnector,
|
||||
wireguard::{WgConfig, WgTunnelConnector},
|
||||
TunnelConnector,
|
||||
},
|
||||
};
|
||||
|
||||
pub mod direct;
|
||||
pub mod manual;
|
||||
pub mod udp_hole_punch;
|
||||
|
||||
async fn set_bind_addr_for_peer_connector(
|
||||
connector: &mut impl TunnelConnector,
|
||||
is_ipv4: bool,
|
||||
ip_collector: &Arc<IPCollector>,
|
||||
) {
|
||||
let ips = ip_collector.collect_ip_addrs().await;
|
||||
if is_ipv4 {
|
||||
let mut bind_addrs = vec![];
|
||||
for ipv4 in ips.interface_ipv4s {
|
||||
let socket_addr = SocketAddrV4::new(ipv4.parse().unwrap(), 0).into();
|
||||
bind_addrs.push(socket_addr);
|
||||
}
|
||||
connector.set_bind_addrs(bind_addrs);
|
||||
} else {
|
||||
let mut bind_addrs = vec![];
|
||||
for ipv6 in ips.interface_ipv6s {
|
||||
let socket_addr = SocketAddrV6::new(ipv6.parse().unwrap(), 0, 0, 0).into();
|
||||
bind_addrs.push(socket_addr);
|
||||
}
|
||||
connector.set_bind_addrs(bind_addrs);
|
||||
}
|
||||
let _ = connector;
|
||||
}
|
||||
|
||||
pub async fn create_connector_by_url(
|
||||
url: &str,
|
||||
global_ctx: &ArcGlobalCtx,
|
||||
) -> Result<Box<dyn TunnelConnector + Send + Sync + 'static>, Error> {
|
||||
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
|
||||
match url.scheme() {
|
||||
"tcp" => {
|
||||
let dst_addr =
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp")?;
|
||||
let mut connector = TcpTunnelConnector::new(url);
|
||||
set_bind_addr_for_peer_connector(
|
||||
&mut connector,
|
||||
dst_addr.is_ipv4(),
|
||||
&global_ctx.get_ip_collector(),
|
||||
)
|
||||
.await;
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
"udp" => {
|
||||
let dst_addr =
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp")?;
|
||||
let mut connector = UdpTunnelConnector::new(url);
|
||||
set_bind_addr_for_peer_connector(
|
||||
&mut connector,
|
||||
dst_addr.is_ipv4(),
|
||||
&global_ctx.get_ip_collector(),
|
||||
)
|
||||
.await;
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
"ring" => {
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring")?;
|
||||
let connector = RingTunnelConnector::new(url);
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
"wg" => {
|
||||
let dst_addr =
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg")?;
|
||||
let nid = global_ctx.get_network_identity();
|
||||
let wg_config =
|
||||
WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret);
|
||||
let mut connector = WgTunnelConnector::new(url, wg_config);
|
||||
set_bind_addr_for_peer_connector(
|
||||
&mut connector,
|
||||
dst_addr.is_ipv4(),
|
||||
&global_ctx.get_ip_collector(),
|
||||
)
|
||||
.await;
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::InvalidUrl(url.into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,545 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use rand::{seq::SliceRandom, Rng, SeedableRng};
|
||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background,
|
||||
rkyv_util::encode_to_bytes, stun::StunInfoCollectorTrait, PeerId,
|
||||
},
|
||||
peers::peer_manager::PeerManager,
|
||||
rpc::NatType,
|
||||
tunnels::{
|
||||
common::setup_sokcet2,
|
||||
udp_tunnel::{UdpPacket, UdpTunnelConnector, UdpTunnelListener},
|
||||
Tunnel, TunnelConnCounter, TunnelListener,
|
||||
},
|
||||
};
|
||||
|
||||
use super::direct::PeerManagerForDirectConnector;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait UdpHolePunchService {
|
||||
async fn try_punch_hole(local_mapped_addr: SocketAddr) -> Option<SocketAddr>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UdpHolePunchListener {
|
||||
socket: Arc<UdpSocket>,
|
||||
tasks: JoinSet<()>,
|
||||
running: Arc<AtomicCell<bool>>,
|
||||
mapped_addr: SocketAddr,
|
||||
conn_counter: Arc<Box<dyn TunnelConnCounter>>,
|
||||
|
||||
listen_time: std::time::Instant,
|
||||
last_select_time: AtomicCell<std::time::Instant>,
|
||||
last_connected_time: Arc<AtomicCell<std::time::Instant>>,
|
||||
}
|
||||
|
||||
impl UdpHolePunchListener {
|
||||
async fn get_avail_port() -> Result<u16, Error> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
Ok(socket.local_addr()?.port())
|
||||
}
|
||||
|
||||
pub async fn new(peer_mgr: Arc<PeerManager>) -> Result<Self, Error> {
|
||||
let port = Self::get_avail_port().await?;
|
||||
let listen_url = format!("udp://0.0.0.0:{}", port);
|
||||
|
||||
let gctx = peer_mgr.get_global_ctx();
|
||||
let stun_info_collect = gctx.get_stun_info_collector();
|
||||
let mapped_addr = stun_info_collect.get_udp_port_mapping(port).await?;
|
||||
|
||||
let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap());
|
||||
|
||||
{
|
||||
let _g = peer_mgr.get_global_ctx().net_ns.guard();
|
||||
listener.listen().await?;
|
||||
}
|
||||
let socket = listener.get_socket().unwrap();
|
||||
|
||||
let running = Arc::new(AtomicCell::new(true));
|
||||
let running_clone = running.clone();
|
||||
|
||||
let last_connected_time = Arc::new(AtomicCell::new(std::time::Instant::now()));
|
||||
let last_connected_time_clone = last_connected_time.clone();
|
||||
|
||||
let conn_counter = listener.get_conn_counter();
|
||||
let mut tasks = JoinSet::new();
|
||||
|
||||
tasks.spawn(async move {
|
||||
while let Ok(conn) = listener.accept().await {
|
||||
last_connected_time_clone.store(std::time::Instant::now());
|
||||
tracing::warn!(?conn, "udp hole punching listener got peer connection");
|
||||
let peer_mgr = peer_mgr.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
|
||||
tracing::error!(
|
||||
?e,
|
||||
"failed to add tunnel as server in hole punch listener"
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
running_clone.store(false);
|
||||
});
|
||||
|
||||
tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started");
|
||||
|
||||
Ok(Self {
|
||||
tasks,
|
||||
socket,
|
||||
running,
|
||||
mapped_addr,
|
||||
conn_counter,
|
||||
|
||||
listen_time: std::time::Instant::now(),
|
||||
last_select_time: AtomicCell::new(std::time::Instant::now()),
|
||||
last_connected_time,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_socket(&self) -> Arc<UdpSocket> {
|
||||
self.last_select_time.store(std::time::Instant::now());
|
||||
self.socket.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UdpHolePunchConnectorData {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
listeners: Arc<Mutex<Vec<UdpHolePunchListener>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct UdpHolePunchRpcServer {
|
||||
data: Arc<UdpHolePunchConnectorData>,
|
||||
|
||||
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl UdpHolePunchService for UdpHolePunchRpcServer {
|
||||
async fn try_punch_hole(
|
||||
self,
|
||||
_: tarpc::context::Context,
|
||||
local_mapped_addr: SocketAddr,
|
||||
) -> Option<SocketAddr> {
|
||||
let (socket, mapped_addr) = self.select_listener().await?;
|
||||
tracing::warn!(?local_mapped_addr, ?mapped_addr, "start hole punching");
|
||||
|
||||
let my_udp_nat_type = self
|
||||
.data
|
||||
.global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type;
|
||||
|
||||
// if we are restricted, we need to send hole punching resp to client
|
||||
if my_udp_nat_type == NatType::PortRestricted as i32
|
||||
|| my_udp_nat_type == NatType::Restricted as i32
|
||||
{
|
||||
// send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second
|
||||
self.tasks.lock().unwrap().spawn(async move {
|
||||
for _ in 0..10 {
|
||||
tracing::info!(?local_mapped_addr, "sending hole punching packet");
|
||||
// generate a 128 bytes vec with random data
|
||||
let mut rng = rand::rngs::StdRng::from_entropy();
|
||||
let mut buf = vec![0u8; 128];
|
||||
rng.fill(&mut buf[..]);
|
||||
|
||||
let udp_packet = UdpPacket::new_hole_punch_packet(buf);
|
||||
let udp_packet_bytes = encode_to_bytes::<_, 256>(&udp_packet);
|
||||
let _ = socket
|
||||
.send_to(udp_packet_bytes.as_ref(), local_mapped_addr)
|
||||
.await;
|
||||
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Some(mapped_addr)
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpHolePunchRpcServer {
|
||||
pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
|
||||
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
|
||||
join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned());
|
||||
Self { data, tasks }
|
||||
}
|
||||
|
||||
async fn select_listener(&self) -> Option<(Arc<UdpSocket>, SocketAddr)> {
|
||||
let all_listener_sockets = &self.data.listeners;
|
||||
|
||||
// remove listener that not have connection in for 20 seconds
|
||||
all_listener_sockets.lock().await.retain(|listener| {
|
||||
listener.last_connected_time.load().elapsed().as_secs() < 20
|
||||
&& listener.conn_counter.get() > 0
|
||||
});
|
||||
|
||||
let mut use_last = false;
|
||||
if all_listener_sockets.lock().await.len() < 4 {
|
||||
tracing::warn!("creating new udp hole punching listener");
|
||||
all_listener_sockets.lock().await.push(
|
||||
UdpHolePunchListener::new(self.data.peer_mgr.clone())
|
||||
.await
|
||||
.ok()?,
|
||||
);
|
||||
use_last = true;
|
||||
}
|
||||
|
||||
let locked = all_listener_sockets.lock().await;
|
||||
|
||||
let listener = if use_last {
|
||||
locked.last()?
|
||||
} else {
|
||||
locked.choose(&mut rand::rngs::StdRng::from_entropy())?
|
||||
};
|
||||
|
||||
Some((listener.get_socket().await, listener.mapped_addr))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpHolePunchConnector {
|
||||
data: Arc<UdpHolePunchConnectorData>,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
// Currently support:
|
||||
// Symmetric -> Full Cone
|
||||
// Any Type of Full Cone -> Any Type of Full Cone
|
||||
|
||||
// if same level of full cone, node with smaller peer_id will be the initiator
|
||||
// if different level of full cone, node with more strict level will be the initiator
|
||||
|
||||
impl UdpHolePunchConnector {
|
||||
pub fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc<PeerManager>) -> Self {
|
||||
Self {
|
||||
data: Arc::new(UdpHolePunchConnectorData {
|
||||
global_ctx,
|
||||
peer_mgr,
|
||||
listeners: Arc::new(Mutex::new(Vec::new())),
|
||||
}),
|
||||
tasks: JoinSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_as_client(&mut self) -> Result<(), Error> {
|
||||
let data = self.data.clone();
|
||||
self.tasks.spawn(async move {
|
||||
Self::main_loop(data).await;
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_as_server(&mut self) -> Result<(), Error> {
|
||||
self.data.peer_mgr.get_peer_rpc_mgr().run_service(
|
||||
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
|
||||
UdpHolePunchRpcServer::new(self.data.clone()).serve(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) -> Result<(), Error> {
|
||||
self.run_as_client().await?;
|
||||
self.run_as_server().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn collect_peer_to_connect(data: Arc<UdpHolePunchConnectorData>) -> Vec<PeerId> {
|
||||
let mut peers_to_connect = Vec::new();
|
||||
|
||||
// do not do anything if:
|
||||
// 1. our stun test has not finished
|
||||
// 2. our nat type is OpenInternet or NoPat, which means we can wait other peers to connect us
|
||||
let my_nat_type = data
|
||||
.global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type;
|
||||
|
||||
let my_nat_type = NatType::try_from(my_nat_type).unwrap();
|
||||
|
||||
if my_nat_type == NatType::Unknown
|
||||
|| my_nat_type == NatType::OpenInternet
|
||||
|| my_nat_type == NatType::NoPat
|
||||
{
|
||||
return peers_to_connect;
|
||||
}
|
||||
|
||||
// collect peer list from peer manager and do some filter:
|
||||
// 1. peers without direct conns;
|
||||
// 2. peers is full cone (any restricted type);
|
||||
for route in data.peer_mgr.list_routes().await.iter() {
|
||||
let Some(peer_stun_info) = route.stun_info.as_ref() else {
|
||||
continue;
|
||||
};
|
||||
let Ok(peer_nat_type) = NatType::try_from(peer_stun_info.udp_nat_type) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let peer_id: PeerId = route.peer_id;
|
||||
let conns = data.peer_mgr.list_peer_conns(peer_id).await;
|
||||
if conns.is_some() && conns.unwrap().len() > 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// if peer is symmetric ignore it because we cannot connect to it
|
||||
// if peer is open internet or no pat, direct connector will connecto to it
|
||||
if peer_nat_type == NatType::Unknown
|
||||
|| peer_nat_type == NatType::OpenInternet
|
||||
|| peer_nat_type == NatType::NoPat
|
||||
|| peer_nat_type == NatType::Symmetric
|
||||
|| peer_nat_type == NatType::SymUdpFirewall
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// if we are symmetric, we can only connect to full cone
|
||||
// TODO: can also connect to restricted full cone, with some extra work
|
||||
if (my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall)
|
||||
&& peer_nat_type != NatType::FullCone
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// if we have smae level of full cone, node with smaller peer_id will be the initiator
|
||||
if my_nat_type == peer_nat_type {
|
||||
if data.peer_mgr.my_peer_id() > peer_id {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// if we have different level of full cone
|
||||
// we will be the initiator if we have more strict level
|
||||
if my_nat_type < peer_nat_type {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
?peer_id,
|
||||
?peer_nat_type,
|
||||
?my_nat_type,
|
||||
?data.global_ctx.id,
|
||||
"found peer to do hole punching"
|
||||
);
|
||||
|
||||
peers_to_connect.push(peer_id);
|
||||
}
|
||||
|
||||
peers_to_connect
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn do_hole_punching(
|
||||
data: Arc<UdpHolePunchConnectorData>,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<Box<dyn Tunnel>, anyhow::Error> {
|
||||
tracing::info!(?dst_peer_id, "start hole punching");
|
||||
// client: choose a local udp port, and get the pubic mapped port from stun server
|
||||
let socket = {
|
||||
let _g = data.global_ctx.net_ns.guard();
|
||||
UdpSocket::bind("0.0.0.0:0").await.with_context(|| "")?
|
||||
};
|
||||
let local_socket_addr = socket.local_addr()?;
|
||||
let local_port = socket.local_addr()?.port();
|
||||
drop(socket); // drop the socket to release the port
|
||||
|
||||
let local_mapped_addr = data
|
||||
.global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_udp_port_mapping(local_port)
|
||||
.await
|
||||
.with_context(|| "failed to get udp port mapping")?;
|
||||
|
||||
// client -> server: tell server the mapped port, server will return the mapped address of listening port.
|
||||
let Some(remote_mapped_addr) = data
|
||||
.peer_mgr
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(
|
||||
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
|
||||
dst_peer_id,
|
||||
|c| async {
|
||||
let client =
|
||||
UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let remote_mapped_addr = client
|
||||
.try_punch_hole(tarpc::context::current(), local_mapped_addr)
|
||||
.await;
|
||||
tracing::info!(?remote_mapped_addr, ?dst_peer_id, "got remote mapped addr");
|
||||
remote_mapped_addr
|
||||
},
|
||||
)
|
||||
.await?
|
||||
else {
|
||||
return Err(anyhow::anyhow!("failed to get remote mapped addr"));
|
||||
};
|
||||
|
||||
// server: will send some punching resps, total 10 packets.
|
||||
// client: use the socket to create UdpTunnel with UdpTunnelConnector
|
||||
// NOTICE: UdpTunnelConnector will ignore the punching resp packet sent by remote.
|
||||
|
||||
let connector = UdpTunnelConnector::new(
|
||||
format!(
|
||||
"udp://{}:{}",
|
||||
remote_mapped_addr.ip(),
|
||||
remote_mapped_addr.port()
|
||||
)
|
||||
.to_string()
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let _g = data.global_ctx.net_ns.guard();
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(local_socket_addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &local_socket_addr)?;
|
||||
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||
|
||||
Ok(connector
|
||||
.try_connect_with_socket(socket)
|
||||
.await
|
||||
.with_context(|| "UdpTunnelConnector failed to connect remote")?)
|
||||
}
|
||||
|
||||
async fn main_loop(data: Arc<UdpHolePunchConnectorData>) {
|
||||
loop {
|
||||
let peers_to_connect = Self::collect_peer_to_connect(data.clone()).await;
|
||||
tracing::trace!(?peers_to_connect, "peers to connect");
|
||||
if peers_to_connect.len() == 0 {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut tasks: JoinSet<Result<(), anyhow::Error>> = JoinSet::new();
|
||||
for peer_id in peers_to_connect {
|
||||
let data = data.clone();
|
||||
tasks.spawn(
|
||||
async move {
|
||||
let tunnel = Self::do_hole_punching(data.clone(), peer_id)
|
||||
.await
|
||||
.with_context(|| "failed to do hole punching")?;
|
||||
|
||||
let _ =
|
||||
data.peer_mgr
|
||||
.add_client_tunnel(tunnel)
|
||||
.await
|
||||
.with_context(|| {
|
||||
"failed to add tunnel as client in hole punch connector"
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.instrument(tracing::info_span!("doing hole punching client", ?peer_id)),
|
||||
);
|
||||
}
|
||||
|
||||
while let Some(res) = tasks.join_next().await {
|
||||
if let Err(e) = res {
|
||||
tracing::error!(?e, "failed to join hole punching job");
|
||||
continue;
|
||||
}
|
||||
|
||||
match res.unwrap() {
|
||||
Err(e) => {
|
||||
tracing::error!(?e, "failed to do hole punching job");
|
||||
}
|
||||
Ok(_) => {
|
||||
tracing::info!("hole punching job succeed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::rpc::{NatType, StunInfo};
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, stun::StunInfoCollectorTrait},
|
||||
connector::udp_hole_punch::UdpHolePunchConnector,
|
||||
peers::{
|
||||
peer_manager::PeerManager,
|
||||
tests::{
|
||||
connect_peer_manager, create_mock_peer_manager, wait_route_appear,
|
||||
wait_route_appear_with_cost,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
struct MockStunInfoCollector {
|
||||
udp_nat_type: NatType,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl StunInfoCollectorTrait for MockStunInfoCollector {
|
||||
fn get_stun_info(&self) -> StunInfo {
|
||||
StunInfo {
|
||||
udp_nat_type: self.udp_nat_type as i32,
|
||||
tcp_nat_type: NatType::Unknown as i32,
|
||||
last_update_time: std::time::Instant::now().elapsed().as_secs() as i64,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_udp_port_mapping(&self, port: u16) -> Result<std::net::SocketAddr, Error> {
|
||||
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, udp_nat_type: NatType) {
|
||||
let collector = Box::new(MockStunInfoCollector { udp_nat_type });
|
||||
peer_mgr
|
||||
.get_global_ctx()
|
||||
.replace_stun_info_collector(collector);
|
||||
}
|
||||
|
||||
pub async fn create_mock_peer_manager_with_mock_stun(
|
||||
udp_nat_type: NatType,
|
||||
) -> Arc<PeerManager> {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
replace_stun_info_collector(p_a.clone(), udp_nat_type);
|
||||
p_a
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn hole_punching() {
|
||||
let p_a = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
|
||||
let p_b = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await;
|
||||
let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
|
||||
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
|
||||
|
||||
println!("{:?}", p_a.list_routes().await);
|
||||
|
||||
let mut hole_punching_a = UdpHolePunchConnector::new(p_a.get_global_ctx(), p_a.clone());
|
||||
let mut hole_punching_c = UdpHolePunchConnector::new(p_c.get_global_ctx(), p_c.clone());
|
||||
|
||||
hole_punching_a.run().await.unwrap();
|
||||
hole_punching_c.run().await.unwrap();
|
||||
|
||||
wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1))
|
||||
.await
|
||||
.unwrap();
|
||||
println!("{:?}", p_a.list_routes().await);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user