use workspace, prepare for config server and gui (#48)

This commit is contained in:
Sijie.Sun
2024-04-04 10:33:53 +08:00
committed by GitHub
parent bb4ae71869
commit 4eb7efe5fc
77 changed files with 162 additions and 195 deletions
+411
View File
@@ -0,0 +1,411 @@
// try connect peers directly, with either its public ip or lan ip
use std::sync::Arc;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
peers::{peer_manager::PeerManager, peer_rpc::PeerRpcManager},
};
use crate::rpc::{peer::GetIpListResponse, PeerConnInfo};
use tokio::{task::JoinSet, time::timeout};
use tracing::Instrument;
use super::create_connector_by_url;
/// RPC service id under which the direct-connector server registers itself
/// on the peer RPC manager.
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
/// How long (seconds) a failed (peer, addr) or (peer, scheme) pair stays
/// blacklisted before direct-connect attempts retry it.
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300;
/// RPC interface served by the direct-connector server side; clients call it
/// to learn a peer's addresses before attempting a direct connection.
#[tarpc::service]
pub trait DirectConnectorRpc {
    /// Returns the remote peer's collected ip addresses plus its currently
    /// running listeners.
    async fn get_ip_list() -> GetIpListResponse;
}
/// The subset of `PeerManager` functionality the direct connector needs;
/// kept as a trait so tests can substitute a mock.
#[async_trait::async_trait]
pub trait PeerManagerForDirectConnector {
    /// All peer ids currently known via routing.
    async fn list_peers(&self) -> Vec<PeerId>;
    /// Direct connections to `peer_id`, or `None` if the peer is unknown.
    async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>>;
    /// Handle to the RPC manager used for peer-to-peer RPC.
    fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager>;
}
/// `PeerManager` satisfies the direct connector's needs by delegating to its
/// routing table, peer map and RPC manager.
#[async_trait::async_trait]
impl PeerManagerForDirectConnector for PeerManager {
    async fn list_peers(&self) -> Vec<PeerId> {
        // every routed peer is a candidate for a direct connection
        self.list_routes()
            .await
            .iter()
            .map(|route| route.peer_id)
            .collect()
    }

    async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>> {
        self.get_peer_map().list_peer_conns(peer_id).await
    }

    fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
        // resolves to the inherent method on PeerManager
        self.get_peer_rpc_mgr()
    }
}
/// Server-side state for answering `DirectConnectorRpc` requests.
#[derive(Clone)]
struct DirectConnectorManagerRpcServer {
    // TODO: this only cache for one src peer, should make it global
    global_ctx: ArcGlobalCtx,
}
#[tarpc::server]
impl DirectConnectorRpc for DirectConnectorManagerRpcServer {
    /// Collects our current ip addresses and attaches the urls of the
    /// listeners that are actually running, so the caller knows which
    /// (scheme, port) combinations to try.
    async fn get_ip_list(self, _: tarpc::context::Context) -> GetIpListResponse {
        let mut ret = self.global_ctx.get_ip_collector().collect_ip_addrs().await;
        ret.listeners = self.global_ctx.get_running_listeners();
        ret
    }
}
impl DirectConnectorManagerRpcServer {
    /// Builds the RPC server over the shared global context.
    pub fn new(global_ctx: ArcGlobalCtx) -> Self {
        Self { global_ctx }
    }
}
/// Blacklist key for a (destination peer, exact address) pair that recently
/// failed to connect.
#[derive(Hash, Eq, PartialEq, Clone)]
struct DstBlackListItem(PeerId, String);

/// Blacklist key for a (destination peer, url scheme) pair for which every
/// address attempt failed. (Name keeps the existing "Scheme" spelling used
/// by the `dst_sceme_blacklist` field.)
#[derive(Hash, Eq, PartialEq, Clone)]
struct DstSchemeBlackListItem(PeerId, String);
/// Shared state of the direct connector, passed to the spawned tasks.
struct DirectConnectorManagerData {
    global_ctx: ArcGlobalCtx,
    peer_manager: Arc<PeerManager>,
    // (peer, addr) pairs that failed recently; entries expire automatically
    dst_blacklist: timedmap::TimedMap<DstBlackListItem, ()>,
    // (peer, scheme) pairs where no address worked; entries expire automatically
    dst_sceme_blacklist: timedmap::TimedMap<DstSchemeBlackListItem, ()>,
}
impl DirectConnectorManagerData {
    /// Creates the shared state with empty blacklists.
    pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
        Self {
            global_ctx,
            peer_manager,
            dst_blacklist: timedmap::TimedMap::new(),
            dst_sceme_blacklist: timedmap::TimedMap::new(),
        }
    }
}
/// Manual `Debug`: only the peer manager is shown; the global context and
/// blacklists are omitted from the output.
impl std::fmt::Debug for DirectConnectorManagerData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("DirectConnectorManagerData");
        builder.field("peer_manager", &self.peer_manager);
        builder.finish()
    }
}
/// Tries to connect to peers directly, using either their public ip or lan
/// ip, and serves our own address list to other peers.
pub struct DirectConnectorManager {
    global_ctx: ArcGlobalCtx,
    data: Arc<DirectConnectorManagerData>,
    // background tasks; aborted when the manager is dropped
    tasks: JoinSet<()>,
}
impl DirectConnectorManager {
    /// Creates a new manager; nothing runs until `run` (or one of the
    /// `run_as_*` halves) is called.
    pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
        Self {
            global_ctx: global_ctx.clone(),
            data: Arc::new(DirectConnectorManagerData::new(global_ctx, peer_manager)),
            tasks: JoinSet::new(),
        }
    }

    /// Starts both roles: serve our ip list to peers, and periodically try
    /// direct connections to other peers.
    pub fn run(&mut self) {
        self.run_as_server();
        self.run_as_client();
    }

    /// Registers the `DirectConnectorRpc` service so remote peers can query
    /// our ip list and running listeners.
    pub fn run_as_server(&mut self) {
        self.data.peer_manager.get_peer_rpc_mgr().run_service(
            DIRECT_CONNECTOR_SERVICE_ID,
            DirectConnectorManagerRpcServer::new(self.global_ctx.clone()).serve(),
        );
    }

    /// Spawns a loop that, every 5 seconds, concurrently tries a direct
    /// connection to every known peer except ourselves.
    pub fn run_as_client(&mut self) {
        let data = self.data.clone();
        let my_peer_id = self.data.peer_manager.my_peer_id();
        self.tasks.spawn(
            async move {
                loop {
                    let peers = data.peer_manager.list_peers().await;
                    let mut tasks = JoinSet::new();
                    for peer_id in peers {
                        if peer_id == my_peer_id {
                            continue;
                        }
                        tasks.spawn(Self::do_try_direct_connect(data.clone(), peer_id));
                    }

                    // drain all attempts before sleeping for the next round
                    while let Some(task_ret) = tasks.join_next().await {
                        tracing::trace!(?task_ret, "direct connect task ret");
                    }

                    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
                }
            }
            .instrument(
                tracing::info_span!("direct_connector_client", my_id = ?self.global_ctx.id),
            ),
        );
    }

    /// Connects to `addr` (5s timeout) and verifies the resulting connection
    /// really belongs to `dst_peer_id`; a mismatched connection is closed and
    /// reported as `Error::InvalidUrl`.
    async fn do_try_connect_to_ip(
        data: Arc<DirectConnectorManagerData>,
        dst_peer_id: PeerId,
        addr: String,
    ) -> Result<(), Error> {
        data.dst_blacklist.cleanup();
        if data
            .dst_blacklist
            .contains(&DstBlackListItem(dst_peer_id.clone(), addr.clone()))
        {
            tracing::trace!("try_connect_to_ip failed, addr in blacklist: {}", addr);
            return Err(Error::UrlInBlacklist);
        }

        let connector = create_connector_by_url(&addr, &data.global_ctx).await?;
        let (peer_id, conn_id) = timeout(
            std::time::Duration::from_secs(5),
            data.peer_manager.try_connect(connector),
        )
        .await??;

        if peer_id != dst_peer_id {
            // we reached *someone*, just not the peer we wanted; tear the
            // connection down so it does not linger
            tracing::info!(
                "connect to ip succ: {}, but peer id mismatch, expect: {}, actual: {}",
                addr,
                dst_peer_id,
                peer_id
            );
            data.peer_manager
                .get_peer_map()
                .close_peer_conn(peer_id, &conn_id)
                .await?;
            return Err(Error::InvalidUrl(addr));
        }
        Ok(())
    }

    /// Wrapper around `do_try_connect_to_ip` that blacklists the (peer, addr)
    /// pair after a real failure so it is not retried for
    /// `DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC` seconds.
    #[tracing::instrument]
    async fn try_connect_to_ip(
        data: Arc<DirectConnectorManagerData>,
        dst_peer_id: PeerId,
        addr: String,
    ) -> Result<(), Error> {
        match Self::do_try_connect_to_ip(data.clone(), dst_peer_id, addr.clone()).await {
            Ok(()) => {
                log::info!("try_connect_to_ip success, peer_id: {}", dst_peer_id);
                Ok(())
            }
            Err(e) => {
                // an addr that was already blacklisted is not a *new* failure;
                // do not refresh its blacklist entry
                if !matches!(e, Error::UrlInBlacklist) {
                    tracing::info!(
                        "try_connect_to_ip failed: {:?}, peer_id: {}",
                        e,
                        dst_peer_id
                    );
                    data.dst_blacklist.insert(
                        DstBlackListItem(dst_peer_id.clone(), addr.clone()),
                        (),
                        std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
                    );
                }
                Err(e)
            }
        }
    }

    /// Picks one listener scheme of the peer (preferring the configured
    /// default protocol), then concurrently tries that scheme against every
    /// interface ipv4 plus the public ipv4. When no attempt succeeds the
    /// (peer, scheme) pair is blacklisted for a while.
    #[tracing::instrument]
    async fn do_try_direct_connect_internal(
        data: Arc<DirectConnectorManagerData>,
        dst_peer_id: PeerId,
        ip_list: GetIpListResponse,
    ) -> Result<(), Error> {
        // candidate listeners: ring tunnels are in-process only, a port is
        // required, and recently-failed schemes for this peer are skipped
        let available_listeners = ip_list
            .listeners
            .iter()
            .filter(|l| l.scheme() != "ring")
            .filter(|l| l.port().is_some())
            .filter(|l| {
                !data.dst_sceme_blacklist.contains(&DstSchemeBlackListItem(
                    dst_peer_id.clone(),
                    l.scheme().to_string(),
                ))
            })
            .collect::<Vec<_>>();

        let mut listener = available_listeners.first().ok_or(anyhow::anyhow!(
            "peer {} have no valid listener",
            dst_peer_id
        ))?;

        // if have default listener, use it first
        listener = available_listeners
            .iter()
            .find(|l| l.scheme() == data.global_ctx.get_flags().default_protocol)
            .unwrap_or(listener);

        let mut tasks = JoinSet::new();
        ip_list.interface_ipv4s.iter().for_each(|ip| {
            let addr = format!(
                "{}://{}:{}",
                listener.scheme(),
                ip,
                listener.port().unwrap_or(11010)
            );
            tasks.spawn(Self::try_connect_to_ip(
                data.clone(),
                dst_peer_id.clone(),
                addr,
            ));
        });

        let addr = format!(
            "{}://{}:{}",
            listener.scheme(),
            ip_list.public_ipv4.clone(),
            listener.port().unwrap_or(11010)
        );
        tasks.spawn(Self::try_connect_to_ip(
            data.clone(),
            dst_peer_id.clone(),
            addr,
        ));

        let mut has_succ = false;
        while let Some(ret) = tasks.join_next().await {
            if let Err(e) = ret {
                log::error!("join direct connect task failed: {:?}", e);
            } else if let Ok(Ok(_)) = ret {
                has_succ = true;
            }
        }

        if !has_succ {
            data.dst_sceme_blacklist.insert(
                DstSchemeBlackListItem(dst_peer_id.clone(), listener.scheme().to_string()),
                (),
                std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
            );
        }

        Ok(())
    }

    /// Asks `dst_peer_id` for its ip list over RPC and tries to establish a
    /// direct connection; does nothing when a direct connection already
    /// exists.
    #[tracing::instrument]
    async fn do_try_direct_connect(
        data: Arc<DirectConnectorManagerData>,
        dst_peer_id: PeerId,
    ) -> Result<(), Error> {
        let peer_manager = data.peer_manager.clone();
        // check if we have direct connection with dst_peer_id
        if let Some(c) = peer_manager.list_peer_conns(dst_peer_id).await {
            // currently if we have any type of direct connection (udp or tcp), we will not try to connect
            if !c.is_empty() {
                return Ok(());
            }
        }

        log::trace!("try direct connect to peer: {}", dst_peer_id);

        let ip_list = peer_manager
            .get_peer_rpc_mgr()
            // use the named service id (not a magic `1`) so this stays in
            // sync with the id registered in `run_as_server`
            .do_client_rpc_scoped(DIRECT_CONNECTOR_SERVICE_ID, dst_peer_id, |c| async {
                let client =
                    DirectConnectorRpcClient::new(tarpc::client::Config::default(), c).spawn();
                let ip_list = client.get_ip_list(tarpc::context::current()).await;
                tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list");
                ip_list
            })
            .await?;

        Self::do_try_direct_connect_internal(data, dst_peer_id, ip_list).await
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        connector::direct::{
            DirectConnectorManager, DirectConnectorManagerData, DstBlackListItem,
            DstSchemeBlackListItem,
        },
        instance::listeners::ListenerManager,
        peers::tests::{
            connect_peer_manager, create_mock_peer_manager, wait_route_appear,
            wait_route_appear_with_cost,
        },
        rpc::peer::GetIpListResponse,
    };

    /// End-to-end: A reaches C only via B (cost 2); after C starts a real
    /// listener and the direct connector runs, A should reach C with cost 1.
    #[rstest::rstest]
    #[tokio::test]
    async fn direct_connector_basic_test(#[values("tcp", "udp", "wg")] proto: &str) {
        let p_a = create_mock_peer_manager().await;
        let p_b = create_mock_peer_manager().await;
        let p_c = create_mock_peer_manager().await;
        connect_peer_manager(p_a.clone(), p_b.clone()).await;
        connect_peer_manager(p_b.clone(), p_c.clone()).await;

        wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();

        let mut dm_a = DirectConnectorManager::new(p_a.get_global_ctx(), p_a.clone());
        let mut dm_c = DirectConnectorManager::new(p_c.get_global_ctx(), p_c.clone());

        dm_a.run_as_client();
        dm_c.run_as_server();

        // distinct ports per protocol to avoid clashes between test cases
        let port = if proto == "wg" { 11040 } else { 11041 };
        p_c.get_global_ctx()
            .config
            .set_listeners(vec![format!("{}://0.0.0.0:{}", proto, port)
                .parse()
                .unwrap()]);
        let mut lis_c = ListenerManager::new(p_c.get_global_ctx(), p_c.clone());
        lis_c.prepare_listeners().await.unwrap();

        lis_c.run().await.unwrap();

        wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1))
            .await
            .unwrap();
    }

    /// A connect attempt that fails on every address should blacklist both
    /// the (peer, scheme) pair and the (peer, addr) pair.
    #[tokio::test]
    async fn direct_connector_scheme_blacklist() {
        let p_a = create_mock_peer_manager().await;
        let data = Arc::new(DirectConnectorManagerData::new(
            p_a.get_global_ctx(),
            p_a.clone(),
        ));
        let mut ip_list = GetIpListResponse::new();
        ip_list
            .listeners
            .push("tcp://127.0.0.1:10222".parse().unwrap());
        ip_list.interface_ipv4s.push("127.0.0.1".to_string());

        DirectConnectorManager::do_try_direct_connect_internal(data.clone(), 1, ip_list.clone())
            .await
            .unwrap();

        assert!(data
            .dst_sceme_blacklist
            .contains(&DstSchemeBlackListItem(1, "tcp".into())));

        assert!(data
            .dst_blacklist
            .contains(&DstBlackListItem(1, ip_list.listeners[0].to_string())));
    }
}
+390
View File
@@ -0,0 +1,390 @@
use std::{collections::BTreeSet, sync::Arc};
use dashmap::{DashMap, DashSet};
use tokio::{
sync::{broadcast::Receiver, mpsc, Mutex},
task::JoinSet,
time::timeout,
};
use crate::{common::PeerId, peers::peer_conn::PeerConnId, rpc as easytier_rpc};
use crate::{
common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
netns::NetNS,
},
connector::set_bind_addr_for_peer_connector,
peers::peer_manager::PeerManager,
rpc::{
connector_manage_rpc_server::ConnectorManageRpc, Connector, ConnectorStatus,
ListConnectorRequest, ManageConnectorRequest,
},
tunnels::{Tunnel, TunnelConnector},
use_global_var,
};
use super::create_connector_by_url;
type ConnectorMap = Arc<DashMap<String, Box<dyn TunnelConnector + Send + Sync>>>;
/// Outcome of a successful reconnect attempt.
#[derive(Debug, Clone)]
struct ReconnResult {
    // url that was dead and has been re-established
    dead_url: String,
    peer_id: PeerId,
    conn_id: PeerConnId,
}
/// Shared state between the manager API and its background routine.
struct ConnectorManagerData {
    connectors: ConnectorMap,
    // urls currently being reconnected by a background task
    reconnecting: DashSet<String>,
    peer_manager: Arc<PeerManager>,
    // remote addrs of currently-alive peer connections, kept in sync by
    // PeerConnAdded/PeerConnRemoved events
    alive_conn_urls: Arc<Mutex<BTreeSet<String>>>,
    // user removed connector urls
    removed_conn_urls: Arc<DashSet<String>>,
    net_ns: NetNS,
    global_ctx: ArcGlobalCtx,
}
/// Manages user-configured connectors: keeps them connected, reconnects dead
/// ones in the background, and exposes add/remove/list operations.
pub struct ManualConnectorManager {
    global_ctx: ArcGlobalCtx,
    data: Arc<ConnectorManagerData>,
    // background routine handle; aborted on drop
    tasks: JoinSet<()>,
}
impl ManualConnectorManager {
    /// Creates the manager and spawns the background routine that tracks
    /// connection events and reconnects dead connectors.
    pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
        let connectors = Arc::new(DashMap::new());
        let tasks = JoinSet::new();
        let event_subscriber = global_ctx.subscribe();

        let mut ret = Self {
            global_ctx: global_ctx.clone(),
            data: Arc::new(ConnectorManagerData {
                connectors,
                reconnecting: DashSet::new(),
                peer_manager,
                alive_conn_urls: Arc::new(Mutex::new(BTreeSet::new())),
                removed_conn_urls: Arc::new(DashSet::new()),
                net_ns: global_ctx.net_ns.clone(),
                global_ctx,
            }),
            tasks,
        };

        ret.tasks
            .spawn(Self::conn_mgr_routine(ret.data.clone(), event_subscriber));

        ret
    }

    /// Registers a connector, keyed by its remote url.
    pub fn add_connector<T>(&self, connector: T)
    where
        T: TunnelConnector + Send + Sync + 'static,
    {
        log::info!("add_connector: {}", connector.remote_url());
        self.data
            .connectors
            .insert(connector.remote_url().into(), Box::new(connector));
    }

    /// Parses `url` and registers the resulting connector.
    pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> {
        self.add_connector(create_connector_by_url(url, &self.global_ctx).await?);
        Ok(())
    }

    /// Marks a connector for removal. The actual removal happens in the
    /// background routine; a connector that is mid-reconnect is removed once
    /// the attempt finishes.
    pub async fn remove_connector(&self, url: &str) -> Result<(), Error> {
        log::info!("remove_connector: {}", url);
        if !self.list_connectors().await.iter().any(|x| x.url == url) {
            return Err(Error::NotFound);
        }
        self.data.removed_conn_urls.insert(url.into());
        Ok(())
    }

    /// Lists all connectors with their status: reconnecting ones first, then
    /// connected/disconnected ones.
    pub async fn list_connectors(&self) -> Vec<Connector> {
        let conn_urls: BTreeSet<String> = self
            .data
            .connectors
            .iter()
            .map(|x| x.key().clone().into())
            .collect();

        let dead_urls: BTreeSet<String> = Self::collect_dead_conns(self.data.clone())
            .await
            .into_iter()
            .collect();

        let mut ret = Vec::new();

        for conn_url in conn_urls {
            let mut status = ConnectorStatus::Connected;
            if dead_urls.contains(&conn_url) {
                status = ConnectorStatus::Disconnected;
            }
            // front-insert keeps reverse-sorted order (matches original behavior)
            ret.insert(
                0,
                Connector {
                    url: conn_url,
                    status: status.into(),
                },
            );
        }

        let reconnecting_urls: BTreeSet<String> = self
            .data
            .reconnecting
            .iter()
            .map(|x| x.clone().into())
            .collect();

        for conn_url in reconnecting_urls {
            ret.insert(
                0,
                Connector {
                    url: conn_url,
                    status: ConnectorStatus::Connecting.into(),
                },
            );
        }

        ret
    }

    /// Background loop: reacts to connection events, periodically detects
    /// dead connectors and spawns reconnect tasks for them.
    async fn conn_mgr_routine(
        data: Arc<ConnectorManagerData>,
        mut event_recv: Receiver<GlobalCtxEvent>,
    ) {
        log::warn!("conn_mgr_routine started");
        let mut reconn_interval = tokio::time::interval(std::time::Duration::from_millis(
            use_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS),
        ));
        let mut reconn_tasks = JoinSet::new();
        let (reconn_result_send, mut reconn_result_recv) = mpsc::channel(100);

        loop {
            tokio::select! {
                event = event_recv.recv() => {
                    if let Ok(event) = event {
                        Self::handle_event(&event, data.clone()).await;
                    } else {
                        // the global event bus is gone; nothing sensible to do
                        log::warn!("event_recv closed");
                        panic!("event_recv closed");
                    }
                }

                _ = reconn_interval.tick() => {
                    let dead_urls = Self::collect_dead_conns(data.clone()).await;
                    if dead_urls.is_empty() {
                        continue;
                    }
                    for dead_url in dead_urls {
                        let data_clone = data.clone();
                        let sender = reconn_result_send.clone();
                        // move the connector out of the map while it is being
                        // reconnected; conn_reconnect puts it back
                        let (_, connector) = data.connectors.remove(&dead_url).unwrap();
                        let insert_succ = data.reconnecting.insert(dead_url.clone());
                        assert!(insert_succ);

                        reconn_tasks.spawn(async move {
                            sender.send(Self::conn_reconnect(data_clone.clone(), dead_url, connector).await).await.unwrap();
                        });
                    }
                    log::info!("reconn_interval tick, done");
                }

                ret = reconn_result_recv.recv() => {
                    log::warn!("reconn_tasks done, out: {:?}", ret);
                    let _ = reconn_tasks.join_next().await.unwrap();
                }
            }
        }
    }

    /// Keeps `alive_conn_urls` in sync with peer connection add/remove events.
    async fn handle_event(event: &GlobalCtxEvent, data: Arc<ConnectorManagerData>) {
        match event {
            GlobalCtxEvent::PeerConnAdded(conn_info) => {
                let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
                data.alive_conn_urls.lock().await.insert(addr);
                log::warn!("peer conn added: {:?}", conn_info);
            }

            GlobalCtxEvent::PeerConnRemoved(conn_info) => {
                let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
                data.alive_conn_urls.lock().await.remove(&addr);
                log::warn!("peer conn removed: {:?}", conn_info);
            }

            _ => {}
        }
    }

    /// Applies pending user removals; urls that are mid-reconnect are kept
    /// pending for the next pass.
    fn handle_remove_connector(data: Arc<ConnectorManagerData>) {
        let remove_later = DashSet::new();
        for it in data.removed_conn_urls.iter() {
            let url = it.key();
            if data.connectors.remove(url).is_some() {
                log::warn!("connector: {}, removed", url);
            } else if data.reconnecting.contains(url) {
                log::warn!("connector: {}, reconnecting, remove later.", url);
                remove_later.insert(url.clone());
            } else {
                log::warn!("connector: {}, not found", url);
            }
        }
        data.removed_conn_urls.clear();
        for it in remove_later.iter() {
            data.removed_conn_urls.insert(it.key().clone());
        }
    }

    /// Returns the urls of registered connectors that have no alive peer
    /// connection (after applying pending removals).
    async fn collect_dead_conns(data: Arc<ConnectorManagerData>) -> BTreeSet<String> {
        Self::handle_remove_connector(data.clone());

        let curr_alive = data.alive_conn_urls.lock().await.clone();
        let all_urls: BTreeSet<String> = data
            .connectors
            .iter()
            .map(|x| x.key().clone().into())
            .collect();
        &all_urls - &curr_alive
    }

    /// Attempts one reconnect of `connector` (1s overall timeout) and always
    /// returns the connector to the map afterwards, whatever the outcome.
    async fn conn_reconnect(
        data: Arc<ConnectorManagerData>,
        dead_url: String,
        connector: Box<dyn TunnelConnector + Send + Sync>,
    ) -> Result<ReconnResult, Error> {
        let connector = Arc::new(Mutex::new(Some(connector)));
        let net_ns = data.net_ns.clone();
        log::info!("reconnect: {}", dead_url);

        let connector_clone = connector.clone();
        let data_clone = data.clone();
        let url_clone = dead_url.clone();
        let ip_collector = data.global_ctx.get_ip_collector();
        let reconn_task = async move {
            let mut locked = connector_clone.lock().await;
            let conn = locked.as_mut().unwrap();
            // TODO: should support set v6 here, use url in connector array
            set_bind_addr_for_peer_connector(conn, true, &ip_collector).await;

            data_clone
                .global_ctx
                .issue_event(GlobalCtxEvent::Connecting(conn.remote_url().clone()));

            let _g = net_ns.guard();
            log::info!("reconnect try connect... conn: {:?}", conn);
            let tunnel = conn.connect().await?;
            log::info!("reconnect get tunnel succ: {:?}", tunnel);
            assert_eq!(
                url_clone,
                tunnel.info().unwrap().remote_addr,
                "info: {:?}",
                tunnel.info()
            );
            let (peer_id, conn_id) = data_clone.peer_manager.add_client_tunnel(tunnel).await?;
            log::info!("reconnect succ: {} {} {}", peer_id, conn_id, url_clone);
            Ok(ReconnResult {
                dead_url: url_clone,
                peer_id,
                conn_id,
            })
        };

        let ret = timeout(std::time::Duration::from_secs(1), reconn_task).await;
        log::info!("reconnect: {} done, ret: {:?}", dead_url, ret);
        if ret.is_err() || ret.as_ref().unwrap().is_err() {
            data.global_ctx.issue_event(GlobalCtxEvent::ConnectError(
                dead_url.clone(),
                format!("{:?}", ret),
            ));
        }

        // put the connector back and clear the reconnecting flag regardless
        // of success, so it remains managed
        let conn = connector.lock().await.take().unwrap();
        data.reconnecting.remove(&dead_url).unwrap();
        data.connectors.insert(dead_url.clone(), conn);

        ret?
    }
}
pub struct ConnectorManagerRpcService(pub Arc<ManualConnectorManager>);
#[tonic::async_trait]
impl ConnectorManageRpc for ConnectorManagerRpcService {
    /// Returns all connectors with their current status.
    async fn list_connector(
        &self,
        _request: tonic::Request<ListConnectorRequest>,
    ) -> Result<tonic::Response<easytier_rpc::ListConnectorResponse>, tonic::Status> {
        let mut ret = easytier_rpc::ListConnectorResponse::default();
        let connectors = self.0.list_connectors().await;
        ret.connectors = connectors;
        Ok(tonic::Response::new(ret))
    }

    /// Adds or removes a connector identified by its url.
    async fn manage_connector(
        &self,
        request: tonic::Request<ManageConnectorRequest>,
    ) -> Result<tonic::Response<easytier_rpc::ManageConnectorResponse>, tonic::Status> {
        let req = request.into_inner();
        let url = url::Url::parse(&req.url)
            .map_err(|_| tonic::Status::invalid_argument("invalid url"))?;
        if req.action == easytier_rpc::ConnectorManageAction::Remove as i32 {
            // Connectors are keyed by their complete remote url (e.g.
            // "tcp://1.2.3.4:11010"). `url.path()` is empty for such urls and
            // could never match, so pass the full url string instead.
            self.0.remove_connector(url.as_str()).await.map_err(|e| {
                tonic::Status::invalid_argument(format!("remove connector failed: {:?}", e))
            })?;
            return Ok(tonic::Response::new(
                easytier_rpc::ManageConnectorResponse::default(),
            ));
        } else {
            self.0
                .add_connector_by_url(url.as_str())
                .await
                .map_err(|e| {
                    tonic::Status::invalid_argument(format!("add connector failed: {:?}", e))
                })?;
        }
        Ok(tonic::Response::new(
            easytier_rpc::ManageConnectorResponse::default(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        peers::tests::create_mock_peer_manager,
        set_global_var,
        tunnels::{Tunnel, TunnelError},
    };

    use super::*;

    /// Smoke test: a connector that always fails after 10ms must keep the
    /// reconnect loop cycling without panicking for the test duration.
    #[tokio::test]
    async fn test_reconnect_with_connecting_addr() {
        // reconnect as fast as possible so many attempts happen within 5s
        set_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS, 1);

        let peer_mgr = create_mock_peer_manager().await;

        let mgr = ManualConnectorManager::new(peer_mgr.get_global_ctx(), peer_mgr);

        // connector whose connect() always errors out
        struct MockConnector {}
        #[async_trait::async_trait]
        impl TunnelConnector for MockConnector {
            fn remote_url(&self) -> url::Url {
                url::Url::parse("tcp://aa.com").unwrap()
            }
            async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                Err(TunnelError::CommonError("fake error".into()))
            }
        }

        mgr.add_connector(MockConnector {});

        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    }
}
+99
View File
@@ -0,0 +1,99 @@
use std::{
net::{SocketAddr, SocketAddrV4, SocketAddrV6},
sync::Arc,
};
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector},
tunnels::{
ring_tunnel::RingTunnelConnector,
tcp_tunnel::TcpTunnelConnector,
udp_tunnel::UdpTunnelConnector,
wireguard::{WgConfig, WgTunnelConnector},
TunnelConnector,
},
};
pub mod direct;
pub mod manual;
pub mod udp_hole_punch;
/// Binds the connector's outgoing sockets to every local interface address of
/// the requested family, so direct-connect attempts originate from each
/// candidate local ip.
///
/// Port 0 lets the OS pick an ephemeral port per bind address.
async fn set_bind_addr_for_peer_connector(
    connector: &mut impl TunnelConnector,
    is_ipv4: bool,
    ip_collector: &Arc<IPCollector>,
) {
    let ips = ip_collector.collect_ip_addrs().await;
    if is_ipv4 {
        let mut bind_addrs = vec![];
        for ipv4 in ips.interface_ipv4s {
            // NOTE(review): assumes the collector only yields parseable
            // addresses; a malformed entry would panic here
            let socket_addr = SocketAddrV4::new(ipv4.parse().unwrap(), 0).into();
            bind_addrs.push(socket_addr);
        }
        connector.set_bind_addrs(bind_addrs);
    } else {
        let mut bind_addrs = vec![];
        for ipv6 in ips.interface_ipv6s {
            let socket_addr = SocketAddrV6::new(ipv6.parse().unwrap(), 0, 0, 0).into();
            bind_addrs.push(socket_addr);
        }
        connector.set_bind_addrs(bind_addrs);
    }
    // (dead `let _ = connector;` removed — the connector is used above)
}
pub async fn create_connector_by_url(
url: &str,
global_ctx: &ArcGlobalCtx,
) -> Result<Box<dyn TunnelConnector + Send + Sync + 'static>, Error> {
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
match url.scheme() {
"tcp" => {
let dst_addr =
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp")?;
let mut connector = TcpTunnelConnector::new(url);
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
return Ok(Box::new(connector));
}
"udp" => {
let dst_addr =
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp")?;
let mut connector = UdpTunnelConnector::new(url);
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
return Ok(Box::new(connector));
}
"ring" => {
crate::tunnels::check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring")?;
let connector = RingTunnelConnector::new(url);
return Ok(Box::new(connector));
}
"wg" => {
let dst_addr =
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg")?;
let nid = global_ctx.get_network_identity();
let wg_config =
WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret);
let mut connector = WgTunnelConnector::new(url, wg_config);
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
return Ok(Box::new(connector));
}
_ => {
return Err(Error::InvalidUrl(url.into()));
}
}
}
+545
View File
@@ -0,0 +1,545 @@
use std::{net::SocketAddr, sync::Arc};
use anyhow::Context;
use crossbeam::atomic::AtomicCell;
use rand::{seq::SliceRandom, Rng, SeedableRng};
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use tracing::Instrument;
use crate::{
common::{
constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background,
rkyv_util::encode_to_bytes, stun::StunInfoCollectorTrait, PeerId,
},
peers::peer_manager::PeerManager,
rpc::NatType,
tunnels::{
common::setup_sokcet2,
udp_tunnel::{UdpPacket, UdpTunnelConnector, UdpTunnelListener},
Tunnel, TunnelConnCounter, TunnelListener,
},
};
use super::direct::PeerManagerForDirectConnector;
/// RPC served by the hole-punch server side: the client reports its public
/// mapped addr; the server replies with the mapped addr of a listener the
/// client should dial.
#[tarpc::service]
pub trait UdpHolePunchService {
    async fn try_punch_hole(local_mapped_addr: SocketAddr) -> Option<SocketAddr>;
}
/// One udp listener used as a hole-punch target, together with its STUN
/// mapped address and bookkeeping timestamps for reuse/eviction.
#[derive(Debug)]
struct UdpHolePunchListener {
    socket: Arc<UdpSocket>,
    tasks: JoinSet<()>,
    // false once the accept loop exits
    running: Arc<AtomicCell<bool>>,
    // public (mapped) address learned via STUN
    mapped_addr: SocketAddr,
    conn_counter: Arc<Box<dyn TunnelConnCounter>>,
    listen_time: std::time::Instant,
    // last time this listener was handed out by select_listener
    last_select_time: AtomicCell<std::time::Instant>,
    // last time a peer connection was accepted on this listener
    last_connected_time: Arc<AtomicCell<std::time::Instant>>,
}
impl UdpHolePunchListener {
    /// Asks the OS for a currently-free udp port by binding to port 0.
    /// NOTE(review): the probe socket is dropped before the port is re-bound
    /// in `new`, so another process could grab it in between — confirm this
    /// race is acceptable here.
    async fn get_avail_port() -> Result<u16, Error> {
        let socket = UdpSocket::bind("0.0.0.0:0").await?;
        Ok(socket.local_addr()?.port())
    }

    /// Starts a udp tunnel listener on a fresh port, resolves its public
    /// mapped address via STUN, and spawns a task that registers every
    /// accepted connection as a server-side peer tunnel.
    pub async fn new(peer_mgr: Arc<PeerManager>) -> Result<Self, Error> {
        let port = Self::get_avail_port().await?;
        let listen_url = format!("udp://0.0.0.0:{}", port);

        let gctx = peer_mgr.get_global_ctx();
        let stun_info_collect = gctx.get_stun_info_collector();
        let mapped_addr = stun_info_collect.get_udp_port_mapping(port).await?;

        let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap());

        {
            // listen inside the configured network namespace
            let _g = peer_mgr.get_global_ctx().net_ns.guard();
            listener.listen().await?;
        }
        let socket = listener.get_socket().unwrap();

        let running = Arc::new(AtomicCell::new(true));
        let running_clone = running.clone();

        let last_connected_time = Arc::new(AtomicCell::new(std::time::Instant::now()));
        let last_connected_time_clone = last_connected_time.clone();

        let conn_counter = listener.get_conn_counter();
        let mut tasks = JoinSet::new();

        tasks.spawn(async move {
            while let Ok(conn) = listener.accept().await {
                last_connected_time_clone.store(std::time::Instant::now());
                tracing::warn!(?conn, "udp hole punching listener got peer connection");
                let peer_mgr = peer_mgr.clone();
                // register each accepted tunnel without blocking the accept loop
                tokio::spawn(async move {
                    if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
                        tracing::error!(
                            ?e,
                            "failed to add tunnel as server in hole punch listener"
                        );
                    }
                });
            }
            running_clone.store(false);
        });

        tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started");

        Ok(Self {
            tasks,
            socket,
            running,
            mapped_addr,
            conn_counter,
            listen_time: std::time::Instant::now(),
            last_select_time: AtomicCell::new(std::time::Instant::now()),
            last_connected_time,
        })
    }

    /// Hands out the listener's socket for sending punch packets and records
    /// the selection time.
    pub async fn get_socket(&self) -> Arc<UdpSocket> {
        self.last_select_time.store(std::time::Instant::now());
        self.socket.clone()
    }
}
/// Shared state of the hole-punch connector (both client and server roles).
#[derive(Debug)]
struct UdpHolePunchConnectorData {
    global_ctx: ArcGlobalCtx,
    peer_mgr: Arc<PeerManager>,
    // pool of punch-target listeners, grown on demand in select_listener
    listeners: Arc<Mutex<Vec<UdpHolePunchListener>>>,
}
/// Server side of the hole-punch RPC.
#[derive(Clone)]
struct UdpHolePunchRpcServer {
    data: Arc<UdpHolePunchConnectorData>,
    // background punch-packet senders, reaped by join_joinset_background
    tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
}
#[tarpc::server]
impl UdpHolePunchService for UdpHolePunchRpcServer {
    /// Picks (or creates) a local hole-punch listener and returns its public
    /// mapped address. When our own NAT is a restricted cone, also starts a
    /// background task that sends punch packets toward the caller's mapped
    /// address so our NAT opens a mapping for it.
    async fn try_punch_hole(
        self,
        _: tarpc::context::Context,
        local_mapped_addr: SocketAddr,
    ) -> Option<SocketAddr> {
        let (socket, mapped_addr) = self.select_listener().await?;
        tracing::warn!(?local_mapped_addr, ?mapped_addr, "start hole punching");

        let my_udp_nat_type = self
            .data
            .global_ctx
            .get_stun_info_collector()
            .get_stun_info()
            .udp_nat_type;

        // if we are restricted, we need to send hole punching resp to client
        if my_udp_nat_type == NatType::PortRestricted as i32
            || my_udp_nat_type == NatType::Restricted as i32
        {
            // send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second
            self.tasks.lock().unwrap().spawn(async move {
                for _ in 0..10 {
                    tracing::info!(?local_mapped_addr, "sending hole punching packet");
                    // generate a 128 bytes vec with random data
                    let mut rng = rand::rngs::StdRng::from_entropy();
                    let mut buf = vec![0u8; 128];
                    rng.fill(&mut buf[..]);

                    let udp_packet = UdpPacket::new_hole_punch_packet(buf);
                    let udp_packet_bytes = encode_to_bytes::<_, 256>(&udp_packet);
                    // send errors are ignored: punching is best-effort
                    let _ = socket
                        .send_to(udp_packet_bytes.as_ref(), local_mapped_addr)
                        .await;

                    tokio::time::sleep(std::time::Duration::from_millis(300)).await;
                }
            });
        }

        Some(mapped_addr)
    }
}
impl UdpHolePunchRpcServer {
    /// Creates the server and registers its task set for background reaping.
    pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
        let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
        join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned());
        Self { data, tasks }
    }

    /// Returns a listener socket plus its public mapped address, evicting
    /// stale listeners first and creating a new one when fewer than 4 remain.
    /// NOTE(review): the listeners mutex is locked and released several times
    /// here, so concurrent callers may interleave between the len-check, the
    /// push and the final pick — confirm this is intended.
    async fn select_listener(&self) -> Option<(Arc<UdpSocket>, SocketAddr)> {
        let all_listener_sockets = &self.data.listeners;

        // remove listener that not have connection in for 20 seconds
        all_listener_sockets.lock().await.retain(|listener| {
            listener.last_connected_time.load().elapsed().as_secs() < 20
                && listener.conn_counter.get() > 0
        });

        let mut use_last = false;
        if all_listener_sockets.lock().await.len() < 4 {
            tracing::warn!("creating new udp hole punching listener");
            all_listener_sockets.lock().await.push(
                UdpHolePunchListener::new(self.data.peer_mgr.clone())
                    .await
                    .ok()?,
            );
            use_last = true;
        }

        let locked = all_listener_sockets.lock().await;

        let listener = if use_last {
            // freshly created listener sits at the end
            locked.last()?
        } else {
            // otherwise pick a random existing one to spread load
            locked.choose(&mut rand::rngs::StdRng::from_entropy())?
        };

        Some((listener.get_socket().await, listener.mapped_addr))
    }
}
/// Establishes direct udp connections through NATs via STUN-assisted hole
/// punching; runs a client loop and serves the punch RPC.
pub struct UdpHolePunchConnector {
    data: Arc<UdpHolePunchConnectorData>,
    // client-side background loop; aborted on drop
    tasks: JoinSet<()>,
}
// Currently support:
// Symmetric -> Full Cone
// Any Type of Full Cone -> Any Type of Full Cone
// if same level of full cone, node with smaller peer_id will be the initiator
// if different level of full cone, node with more strict level will be the initiator
impl UdpHolePunchConnector {
/// Creates the connector; call `run` (or the `run_as_*` halves) to start it.
pub fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc<PeerManager>) -> Self {
    Self {
        data: Arc::new(UdpHolePunchConnectorData {
            global_ctx,
            peer_mgr,
            listeners: Arc::new(Mutex::new(Vec::new())),
        }),
        tasks: JoinSet::new(),
    }
}
/// Spawns the client-side loop that initiates hole punching toward
/// suitable peers.
pub async fn run_as_client(&mut self) -> Result<(), Error> {
    let data = self.data.clone();
    self.tasks.spawn(async move {
        Self::main_loop(data).await;
    });

    Ok(())
}
/// Registers the hole-punch RPC service on the peer RPC manager.
pub async fn run_as_server(&mut self) -> Result<(), Error> {
    self.data.peer_mgr.get_peer_rpc_mgr().run_service(
        constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
        UdpHolePunchRpcServer::new(self.data.clone()).serve(),
    );

    Ok(())
}
/// Starts both the client loop and the RPC server.
pub async fn run(&mut self) -> Result<(), Error> {
    self.run_as_client().await?;
    self.run_as_server().await?;

    Ok(())
}
/// Selects the peers we should initiate hole punching toward, based on our
/// own NAT type and each peer's advertised NAT type (see the initiator
/// rules in the comments below).
async fn collect_peer_to_connect(data: Arc<UdpHolePunchConnectorData>) -> Vec<PeerId> {
    let mut peers_to_connect = Vec::new();

    // do not do anything if:
    // 1. our stun test has not finished
    // 2. our nat type is OpenInternet or NoPat, which means we can wait other peers to connect us
    let my_nat_type = data
        .global_ctx
        .get_stun_info_collector()
        .get_stun_info()
        .udp_nat_type;
    let my_nat_type = NatType::try_from(my_nat_type).unwrap();
    if my_nat_type == NatType::Unknown
        || my_nat_type == NatType::OpenInternet
        || my_nat_type == NatType::NoPat
    {
        return peers_to_connect;
    }

    // collect peer list from peer manager and do some filter:
    // 1. peers without direct conns;
    // 2. peers is full cone (any restricted type);
    for route in data.peer_mgr.list_routes().await.iter() {
        let Some(peer_stun_info) = route.stun_info.as_ref() else {
            continue;
        };
        let Ok(peer_nat_type) = NatType::try_from(peer_stun_info.udp_nat_type) else {
            continue;
        };

        let peer_id: PeerId = route.peer_id;

        // skip peers we already have a direct connection to
        let conns = data.peer_mgr.list_peer_conns(peer_id).await;
        if conns.map_or(false, |c| !c.is_empty()) {
            continue;
        }

        // if peer is symmetric ignore it because we cannot connect to it
        // if peer is open internet or no pat, direct connector will connect to it
        if peer_nat_type == NatType::Unknown
            || peer_nat_type == NatType::OpenInternet
            || peer_nat_type == NatType::NoPat
            || peer_nat_type == NatType::Symmetric
            || peer_nat_type == NatType::SymUdpFirewall
        {
            continue;
        }

        // if we are symmetric, we can only connect to full cone
        // TODO: can also connect to restricted full cone, with some extra work
        if (my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall)
            && peer_nat_type != NatType::FullCone
        {
            continue;
        }

        // if we have same level of full cone, node with smaller peer_id will be the initiator
        if my_nat_type == peer_nat_type {
            if data.peer_mgr.my_peer_id() > peer_id {
                continue;
            }
        } else {
            // if we have different level of full cone
            // we will be the initiator if we have more strict level
            if my_nat_type < peer_nat_type {
                continue;
            }
        }

        tracing::info!(
            ?peer_id,
            ?peer_nat_type,
            ?my_nat_type,
            ?data.global_ctx.id,
            "found peer to do hole punching"
        );

        peers_to_connect.push(peer_id);
    }

    peers_to_connect
}
/// Client side of a hole punch toward `dst_peer_id`:
/// 1. bind a local udp port and learn its public mapping via STUN;
/// 2. tell the server our mapped addr over RPC; it replies with the mapped
///    addr of one of its listeners (and, if restricted, starts sending
///    punch packets toward us);
/// 3. re-bind the same local port and connect it to the server's mapped addr.
#[tracing::instrument]
async fn do_hole_punching(
    data: Arc<UdpHolePunchConnectorData>,
    dst_peer_id: PeerId,
) -> Result<Box<dyn Tunnel>, anyhow::Error> {
    tracing::info!(?dst_peer_id, "start hole punching");
    // client: choose a local udp port, and get the pubic mapped port from stun server
    let socket = {
        let _g = data.global_ctx.net_ns.guard();
        UdpSocket::bind("0.0.0.0:0").await.with_context(|| "")?
    };
    let local_socket_addr = socket.local_addr()?;
    let local_port = socket.local_addr()?.port();
    // NOTE(review): another process could claim this port before the re-bind
    // below — confirm this race is acceptable
    drop(socket); // drop the socket to release the port

    let local_mapped_addr = data
        .global_ctx
        .get_stun_info_collector()
        .get_udp_port_mapping(local_port)
        .await
        .with_context(|| "failed to get udp port mapping")?;

    // client -> server: tell server the mapped port, server will return the mapped address of listening port.
    let Some(remote_mapped_addr) = data
        .peer_mgr
        .get_peer_rpc_mgr()
        .do_client_rpc_scoped(
            constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
            dst_peer_id,
            |c| async {
                let client =
                    UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c).spawn();
                let remote_mapped_addr = client
                    .try_punch_hole(tarpc::context::current(), local_mapped_addr)
                    .await;
                tracing::info!(?remote_mapped_addr, ?dst_peer_id, "got remote mapped addr");
                remote_mapped_addr
            },
        )
        .await?
    else {
        return Err(anyhow::anyhow!("failed to get remote mapped addr"));
    };

    // server: will send some punching resps, total 10 packets.
    // client: use the socket to create UdpTunnel with UdpTunnelConnector
    // NOTICE: UdpTunnelConnector will ignore the punching resp packet sent by remote.
    let connector = UdpTunnelConnector::new(
        format!(
            "udp://{}:{}",
            remote_mapped_addr.ip(),
            remote_mapped_addr.port()
        )
        .to_string()
        .parse()
        .unwrap(),
    );

    // re-bind the same local port inside the network namespace so the NAT
    // mapping established above is reused
    let _g = data.global_ctx.net_ns.guard();
    let socket2_socket = socket2::Socket::new(
        socket2::Domain::for_address(local_socket_addr),
        socket2::Type::DGRAM,
        Some(socket2::Protocol::UDP),
    )?;
    setup_sokcet2(&socket2_socket, &local_socket_addr)?;
    let socket = UdpSocket::from_std(socket2_socket.into())?;

    Ok(connector
        .try_connect_with_socket(socket)
        .await
        .with_context(|| "UdpTunnelConnector failed to connect remote")?)
}
/// Background driver loop: never returns.
///
/// Each round it collects candidate peers (no direct conn, compatible NAT
/// levels), spawns one hole-punch attempt per peer into a `JoinSet`, waits
/// for all of them, then sleeps before the next round (5s when idle, 10s
/// after a working round).
async fn main_loop(data: Arc<UdpHolePunchConnectorData>) {
    loop {
        let peers_to_connect = Self::collect_peer_to_connect(data.clone()).await;
        tracing::trace!(?peers_to_connect, "peers to connect");

        if peers_to_connect.is_empty() {
            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
            continue;
        }

        let mut tasks: JoinSet<Result<(), anyhow::Error>> = JoinSet::new();
        for peer_id in peers_to_connect {
            let data = data.clone();
            tasks.spawn(
                async move {
                    let tunnel = Self::do_hole_punching(data.clone(), peer_id)
                        .await
                        .with_context(|| "failed to do hole punching")?;
                    let _ = data
                        .peer_mgr
                        .add_client_tunnel(tunnel)
                        .await
                        .with_context(|| {
                            "failed to add tunnel as client in hole punch connector"
                        })?;
                    Ok(())
                }
                .instrument(tracing::info_span!("doing hole punching client", ?peer_id)),
            );
        }

        while let Some(res) = tasks.join_next().await {
            match res {
                // outer Err: the task itself panicked or was cancelled
                Err(e) => {
                    tracing::error!(?e, "failed to join hole punching job");
                }
                // inner Err: punching or tunnel registration failed
                Ok(Err(e)) => {
                    tracing::error!(?e, "failed to do hole punching job");
                }
                Ok(Ok(_)) => {
                    tracing::info!("hole punching job succeed");
                }
            }
        }

        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
    }
}
}
#[cfg(test)]
pub mod tests {
    use std::sync::Arc;

    use crate::rpc::{NatType, StunInfo};
    use crate::{
        common::{error::Error, stun::StunInfoCollectorTrait},
        connector::udp_hole_punch::UdpHolePunchConnector,
        peers::{
            peer_manager::PeerManager,
            tests::{
                connect_peer_manager, create_mock_peer_manager, wait_route_appear,
                wait_route_appear_with_cost,
            },
        },
    };

    /// Stun-collector stub: reports a fixed udp NAT type and maps any local
    /// port straight to 127.0.0.1, so tests need no real STUN server.
    struct MockStunInfoCollector {
        udp_nat_type: NatType,
    }

    #[async_trait::async_trait]
    impl StunInfoCollectorTrait for MockStunInfoCollector {
        fn get_stun_info(&self) -> StunInfo {
            StunInfo {
                udp_nat_type: self.udp_nat_type as i32,
                tcp_nat_type: NatType::Unknown as i32,
                last_update_time: std::time::Instant::now().elapsed().as_secs() as i64,
            }
        }

        async fn get_udp_port_mapping(&self, port: u16) -> Result<std::net::SocketAddr, Error> {
            Ok(format!("127.0.0.1:{}", port).parse().unwrap())
        }
    }

    /// Swap the peer manager's stun collector for the mock above.
    pub fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, udp_nat_type: NatType) {
        peer_mgr
            .get_global_ctx()
            .replace_stun_info_collector(Box::new(MockStunInfoCollector { udp_nat_type }));
    }

    /// Build a mock peer manager whose stun collector reports `udp_nat_type`.
    pub async fn create_mock_peer_manager_with_mock_stun(
        udp_nat_type: NatType,
    ) -> Arc<PeerManager> {
        let mgr = create_mock_peer_manager().await;
        replace_stun_info_collector(mgr.clone(), udp_nat_type);
        mgr
    }

    #[tokio::test]
    async fn hole_punching() {
        // topology: left (port-restricted) -- relay (symmetric) -- right (port-restricted);
        // left and right should punch a direct link through the relay.
        let left = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
        let relay = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await;
        let right = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
        connect_peer_manager(left.clone(), relay.clone()).await;
        connect_peer_manager(relay.clone(), right.clone()).await;
        wait_route_appear(left.clone(), right.clone()).await.unwrap();
        println!("{:?}", left.list_routes().await);

        let mut puncher_left = UdpHolePunchConnector::new(left.get_global_ctx(), left.clone());
        let mut puncher_right = UdpHolePunchConnector::new(right.get_global_ctx(), right.clone());
        puncher_left.run().await.unwrap();
        puncher_right.run().await.unwrap();

        // a direct tunnel shows up as a cost-1 route from left to right
        wait_route_appear_with_cost(left.clone(), right.my_peer_id(), Some(1))
            .await
            .unwrap();
        println!("{:?}", left.list_routes().await);
    }
}