mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 10:14:35 +00:00
use workspace, prepare for config server and gui (#48)
This commit is contained in:
@@ -0,0 +1,200 @@
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::{
|
||||
sync::{mpsc, Mutex},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
||||
PeerId,
|
||||
};
|
||||
|
||||
use super::{
|
||||
foreign_network_manager::{ForeignNetworkServiceClient, FOREIGN_NETWORK_SERVICE_ID},
|
||||
peer_conn::PeerConn,
|
||||
peer_map::PeerMap,
|
||||
peer_rpc::PeerRpcManager,
|
||||
};
|
||||
|
||||
pub struct ForeignNetworkClient {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_rpc: Arc<PeerRpcManager>,
|
||||
my_peer_id: PeerId,
|
||||
|
||||
peer_map: Arc<PeerMap>,
|
||||
|
||||
next_hop: Arc<DashMap<PeerId, PeerId>>,
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
}
|
||||
|
||||
impl ForeignNetworkClient {
|
||||
pub fn new(
|
||||
global_ctx: ArcGlobalCtx,
|
||||
packet_sender_to_mgr: mpsc::Sender<Bytes>,
|
||||
peer_rpc: Arc<PeerRpcManager>,
|
||||
my_peer_id: PeerId,
|
||||
) -> Self {
|
||||
let peer_map = Arc::new(PeerMap::new(
|
||||
packet_sender_to_mgr,
|
||||
global_ctx.clone(),
|
||||
my_peer_id,
|
||||
));
|
||||
let next_hop = Arc::new(DashMap::new());
|
||||
|
||||
Self {
|
||||
global_ctx,
|
||||
peer_rpc,
|
||||
my_peer_id,
|
||||
|
||||
peer_map,
|
||||
|
||||
next_hop,
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_new_peer_conn(&self, peer_conn: PeerConn) {
|
||||
tracing::warn!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network client");
|
||||
self.peer_map.add_new_peer_conn(peer_conn).await
|
||||
}
|
||||
|
||||
async fn collect_next_hop_in_foreign_network_task(
|
||||
network_identity: NetworkIdentity,
|
||||
peer_map: Arc<PeerMap>,
|
||||
peer_rpc: Arc<PeerRpcManager>,
|
||||
next_hop: Arc<DashMap<PeerId, PeerId>>,
|
||||
) {
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
|
||||
peer_map.clean_peer_without_conn().await;
|
||||
|
||||
let new_next_hop = Self::collect_next_hop_in_foreign_network(
|
||||
network_identity.clone(),
|
||||
peer_map.clone(),
|
||||
peer_rpc.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
next_hop.clear();
|
||||
for (k, v) in new_next_hop.into_iter() {
|
||||
next_hop.insert(k, v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_next_hop_in_foreign_network(
|
||||
network_identity: NetworkIdentity,
|
||||
peer_map: Arc<PeerMap>,
|
||||
peer_rpc: Arc<PeerRpcManager>,
|
||||
) -> DashMap<PeerId, PeerId> {
|
||||
let peers = peer_map.list_peers().await;
|
||||
let mut tasks = JoinSet::new();
|
||||
if !peers.is_empty() {
|
||||
tracing::warn!(?peers, my_peer_id = ?peer_rpc.my_peer_id(), "collect next hop in foreign network");
|
||||
}
|
||||
for peer in peers {
|
||||
let peer_rpc = peer_rpc.clone();
|
||||
let network_identity = network_identity.clone();
|
||||
tasks.spawn(async move {
|
||||
let Ok(Some(peers_in_foreign)) = peer_rpc
|
||||
.do_client_rpc_scoped(FOREIGN_NETWORK_SERVICE_ID, peer, |c| async {
|
||||
let c =
|
||||
ForeignNetworkServiceClient::new(tarpc::client::Config::default(), c)
|
||||
.spawn();
|
||||
let mut rpc_ctx = tarpc::context::current();
|
||||
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(2);
|
||||
let ret = c.list_network_peers(rpc_ctx, network_identity).await;
|
||||
ret
|
||||
})
|
||||
.await
|
||||
else {
|
||||
return (peer, vec![]);
|
||||
};
|
||||
|
||||
(peer, peers_in_foreign)
|
||||
});
|
||||
}
|
||||
|
||||
let new_next_hop = DashMap::new();
|
||||
while let Some(join_ret) = tasks.join_next().await {
|
||||
let Ok((gateway, peer_ids)) = join_ret else {
|
||||
tracing::error!(?join_ret, "collect next hop in foreign network failed");
|
||||
continue;
|
||||
};
|
||||
for ret in peer_ids {
|
||||
new_next_hop.insert(ret, gateway);
|
||||
}
|
||||
}
|
||||
|
||||
new_next_hop
|
||||
}
|
||||
|
||||
pub fn has_next_hop(&self, peer_id: PeerId) -> bool {
|
||||
self.get_next_hop(peer_id).is_some()
|
||||
}
|
||||
|
||||
pub fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId> {
|
||||
if self.peer_map.has_peer(peer_id) {
|
||||
return Some(peer_id.clone());
|
||||
}
|
||||
self.next_hop.get(&peer_id).map(|v| v.clone())
|
||||
}
|
||||
|
||||
pub async fn send_msg(&self, msg: Bytes, peer_id: PeerId) -> Result<(), Error> {
|
||||
if let Some(next_hop) = self.get_next_hop(peer_id) {
|
||||
let ret = self.peer_map.send_msg_directly(msg, next_hop).await;
|
||||
if ret.is_err() {
|
||||
tracing::error!(
|
||||
?ret,
|
||||
?peer_id,
|
||||
?next_hop,
|
||||
"foreign network client send msg failed"
|
||||
);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
Err(Error::RouteError(Some("no next hop".to_string())))
|
||||
}
|
||||
|
||||
pub fn list_foreign_peers(&self) -> Vec<PeerId> {
|
||||
let mut peers = vec![];
|
||||
for item in self.next_hop.iter() {
|
||||
if item.key() != &self.my_peer_id {
|
||||
peers.push(item.key().clone());
|
||||
}
|
||||
}
|
||||
peers
|
||||
}
|
||||
|
||||
pub async fn run(&self) {
|
||||
self.tasks
|
||||
.lock()
|
||||
.await
|
||||
.spawn(Self::collect_next_hop_in_foreign_network_task(
|
||||
self.global_ctx.get_network_identity(),
|
||||
self.peer_map.clone(),
|
||||
self.peer_rpc.clone(),
|
||||
self.next_hop.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
pub fn get_next_hop_table(&self) -> DashMap<PeerId, PeerId> {
|
||||
let next_hop = DashMap::new();
|
||||
for item in self.next_hop.iter() {
|
||||
next_hop.insert(item.key().clone(), item.value().clone());
|
||||
}
|
||||
next_hop
|
||||
}
|
||||
|
||||
pub fn get_peer_map(&self) -> Arc<PeerMap> {
|
||||
self.peer_map.clone()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,459 @@
|
||||
/*
|
||||
foreign_network_manager is used to forward packets of other networks. currently
|
||||
only forward packets of peers that directly connected to this node.
|
||||
|
||||
in future, with the help wo peer center we can forward packets of peers that
|
||||
connected to any node in the local network.
|
||||
*/
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::{
|
||||
sync::{
|
||||
mpsc::{self, UnboundedReceiver, UnboundedSender},
|
||||
Mutex,
|
||||
},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent, NetworkIdentity},
|
||||
PeerId,
|
||||
};
|
||||
|
||||
use super::{
|
||||
packet::{self},
|
||||
peer_conn::PeerConn,
|
||||
peer_map::PeerMap,
|
||||
peer_rpc::{PeerRpcManager, PeerRpcManagerTransport},
|
||||
};
|
||||
|
||||
struct ForeignNetworkEntry {
|
||||
network: NetworkIdentity,
|
||||
peer_map: Arc<PeerMap>,
|
||||
}
|
||||
|
||||
impl ForeignNetworkEntry {
|
||||
fn new(
|
||||
network: NetworkIdentity,
|
||||
packet_sender: mpsc::Sender<Bytes>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
my_peer_id: PeerId,
|
||||
) -> Self {
|
||||
let peer_map = Arc::new(PeerMap::new(packet_sender, global_ctx, my_peer_id));
|
||||
Self { network, peer_map }
|
||||
}
|
||||
}
|
||||
|
||||
struct ForeignNetworkManagerData {
|
||||
network_peer_maps: DashMap<String, Arc<ForeignNetworkEntry>>,
|
||||
peer_network_map: DashMap<PeerId, String>,
|
||||
}
|
||||
|
||||
impl ForeignNetworkManagerData {
|
||||
async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
let network_name = self
|
||||
.peer_network_map
|
||||
.get(&dst_peer_id)
|
||||
.ok_or_else(|| Error::RouteError(Some("network not found".to_string())))?
|
||||
.clone();
|
||||
let entry = self
|
||||
.network_peer_maps
|
||||
.get(&network_name)
|
||||
.ok_or_else(|| Error::RouteError(Some("no peer in network".to_string())))?
|
||||
.clone();
|
||||
entry.peer_map.send_msg(msg, dst_peer_id).await
|
||||
}
|
||||
|
||||
fn get_peer_network(&self, peer_id: PeerId) -> Option<String> {
|
||||
self.peer_network_map.get(&peer_id).map(|v| v.clone())
|
||||
}
|
||||
|
||||
fn get_network_entry(&self, network_name: &str) -> Option<Arc<ForeignNetworkEntry>> {
|
||||
self.network_peer_maps.get(network_name).map(|v| v.clone())
|
||||
}
|
||||
|
||||
fn remove_peer(&self, peer_id: PeerId) {
|
||||
self.peer_network_map.remove(&peer_id);
|
||||
self.network_peer_maps.retain(|_, v| !v.peer_map.is_empty());
|
||||
}
|
||||
|
||||
fn clear_no_conn_peer(&self) {
|
||||
for item in self.network_peer_maps.iter() {
|
||||
let peer_map = item.value().peer_map.clone();
|
||||
tokio::spawn(async move {
|
||||
peer_map.clean_peer_without_conn().await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RpcTransport {
|
||||
my_peer_id: PeerId,
|
||||
data: Arc<ForeignNetworkManagerData>,
|
||||
|
||||
packet_recv: Mutex<UnboundedReceiver<Bytes>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerRpcManagerTransport for RpcTransport {
|
||||
fn my_peer_id(&self) -> PeerId {
|
||||
self.my_peer_id
|
||||
}
|
||||
|
||||
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
self.data.send_msg(msg, dst_peer_id).await
|
||||
}
|
||||
|
||||
async fn recv(&self) -> Result<Bytes, Error> {
|
||||
if let Some(o) = self.packet_recv.lock().await.recv().await {
|
||||
Ok(o)
|
||||
} else {
|
||||
Err(Error::Unknown)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub const FOREIGN_NETWORK_SERVICE_ID: u32 = 1;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait ForeignNetworkService {
|
||||
async fn list_network_peers(network_identy: NetworkIdentity) -> Option<Vec<PeerId>>;
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl ForeignNetworkService for Arc<ForeignNetworkManagerData> {
|
||||
async fn list_network_peers(
|
||||
self,
|
||||
_: tarpc::context::Context,
|
||||
network_identy: NetworkIdentity,
|
||||
) -> Option<Vec<PeerId>> {
|
||||
let entry = self.network_peer_maps.get(&network_identy.network_name)?;
|
||||
Some(entry.peer_map.list_peers().await)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ForeignNetworkManager {
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
packet_sender_to_mgr: mpsc::Sender<Bytes>,
|
||||
|
||||
packet_sender: mpsc::Sender<Bytes>,
|
||||
packet_recv: Mutex<Option<mpsc::Receiver<Bytes>>>,
|
||||
|
||||
data: Arc<ForeignNetworkManagerData>,
|
||||
rpc_mgr: Arc<PeerRpcManager>,
|
||||
rpc_transport_sender: UnboundedSender<Bytes>,
|
||||
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
}
|
||||
|
||||
impl ForeignNetworkManager {
|
||||
pub fn new(
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
packet_sender_to_mgr: mpsc::Sender<Bytes>,
|
||||
) -> Self {
|
||||
// recv packet from all foreign networks
|
||||
let (packet_sender, packet_recv) = mpsc::channel(1000);
|
||||
|
||||
let data = Arc::new(ForeignNetworkManagerData {
|
||||
network_peer_maps: DashMap::new(),
|
||||
peer_network_map: DashMap::new(),
|
||||
});
|
||||
|
||||
// handle rpc from foreign networks
|
||||
let (rpc_transport_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
|
||||
let rpc_mgr = Arc::new(PeerRpcManager::new(RpcTransport {
|
||||
my_peer_id,
|
||||
data: data.clone(),
|
||||
packet_recv: Mutex::new(peer_rpc_tspt_recv),
|
||||
}));
|
||||
|
||||
Self {
|
||||
my_peer_id,
|
||||
global_ctx,
|
||||
packet_sender_to_mgr,
|
||||
|
||||
packet_sender,
|
||||
packet_recv: Mutex::new(Some(packet_recv)),
|
||||
|
||||
data,
|
||||
rpc_mgr,
|
||||
rpc_transport_sender,
|
||||
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_peer_conn(&self, peer_conn: PeerConn) -> Result<(), Error> {
|
||||
tracing::info!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network manager");
|
||||
|
||||
let entry = self
|
||||
.data
|
||||
.network_peer_maps
|
||||
.entry(peer_conn.get_network_identity().network_name.clone())
|
||||
.or_insert_with(|| {
|
||||
Arc::new(ForeignNetworkEntry::new(
|
||||
peer_conn.get_network_identity(),
|
||||
self.packet_sender.clone(),
|
||||
self.global_ctx.clone(),
|
||||
self.my_peer_id,
|
||||
))
|
||||
})
|
||||
.clone();
|
||||
|
||||
self.data.peer_network_map.insert(
|
||||
peer_conn.get_peer_id(),
|
||||
peer_conn.get_network_identity().network_name.clone(),
|
||||
);
|
||||
|
||||
if entry.network.network_secret != peer_conn.get_network_identity().network_secret {
|
||||
return Err(anyhow::anyhow!("network secret not match").into());
|
||||
}
|
||||
|
||||
Ok(entry.peer_map.add_new_peer_conn(peer_conn).await)
|
||||
}
|
||||
|
||||
async fn start_global_event_handler(&self) {
|
||||
let data = self.data.clone();
|
||||
let mut s = self.global_ctx.subscribe();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
while let Ok(e) = s.recv().await {
|
||||
if let GlobalCtxEvent::PeerRemoved(peer_id) = &e {
|
||||
tracing::info!(?e, "remove peer from foreign network manager");
|
||||
data.remove_peer(*peer_id);
|
||||
} else if let GlobalCtxEvent::PeerConnRemoved(..) = &e {
|
||||
tracing::info!(?e, "clear no conn peer from foreign network manager");
|
||||
data.clear_no_conn_peer();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn start_packet_recv(&self) {
|
||||
let mut recv = self.packet_recv.lock().await.take().unwrap();
|
||||
let sender_to_mgr = self.packet_sender_to_mgr.clone();
|
||||
let my_node_id = self.my_peer_id;
|
||||
let rpc_sender = self.rpc_transport_sender.clone();
|
||||
let data = self.data.clone();
|
||||
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
while let Some(packet_bytes) = recv.recv().await {
|
||||
let packet = packet::Packet::decode(&packet_bytes);
|
||||
let from_peer_id = packet.from_peer.into();
|
||||
let to_peer_id = packet.to_peer.into();
|
||||
if to_peer_id == my_node_id {
|
||||
if packet.packet_type == packet::PacketType::TaRpc {
|
||||
rpc_sender.send(packet_bytes.clone()).unwrap();
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = sender_to_mgr.send(packet_bytes).await {
|
||||
tracing::error!("send packet to mgr failed: {:?}", e);
|
||||
}
|
||||
} else {
|
||||
let Some(from_network) = data.get_peer_network(from_peer_id) else {
|
||||
continue;
|
||||
};
|
||||
let Some(to_network) = data.get_peer_network(to_peer_id) else {
|
||||
continue;
|
||||
};
|
||||
if from_network != to_network {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(entry) = data.get_network_entry(&from_network) {
|
||||
let ret = entry.peer_map.send_msg(packet_bytes, to_peer_id).await;
|
||||
if ret.is_err() {
|
||||
tracing::error!("forward packet to peer failed: {:?}", ret.err());
|
||||
}
|
||||
} else {
|
||||
tracing::error!("foreign network not found: {}", from_network);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn register_peer_rpc_service(&self) {
|
||||
self.rpc_mgr.run();
|
||||
self.rpc_mgr
|
||||
.run_service(FOREIGN_NETWORK_SERVICE_ID, self.data.clone().serve())
|
||||
}
|
||||
|
||||
pub async fn run(&self) {
|
||||
self.start_global_event_handler().await;
|
||||
self.start_packet_recv().await;
|
||||
self.register_peer_rpc_service().await;
|
||||
}
|
||||
|
||||
pub async fn list_foreign_networks(&self) -> DashMap<String, Vec<PeerId>> {
|
||||
let ret = DashMap::new();
|
||||
for item in self.data.network_peer_maps.iter() {
|
||||
let network_name = item.key().clone();
|
||||
ret.insert(network_name, vec![]);
|
||||
}
|
||||
|
||||
for mut n in ret.iter_mut() {
|
||||
let network_name = n.key().clone();
|
||||
let Some(item) = self
|
||||
.data
|
||||
.network_peer_maps
|
||||
.get(&network_name)
|
||||
.map(|v| v.clone())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
n.value_mut().extend(item.peer_map.list_peers().await);
|
||||
}
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
common::global_ctx::tests::get_mock_global_ctx_with_network,
|
||||
connector::udp_hole_punch::tests::{
|
||||
create_mock_peer_manager_with_mock_stun, replace_stun_info_collector,
|
||||
},
|
||||
peers::{
|
||||
peer_manager::{PeerManager, RouteAlgoType},
|
||||
tests::{connect_peer_manager, wait_route_appear},
|
||||
},
|
||||
rpc::NatType,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
async fn create_mock_peer_manager_for_foreign_network(network: &str) -> Arc<PeerManager> {
|
||||
let (s, _r) = tokio::sync::mpsc::channel(1000);
|
||||
let peer_mgr = Arc::new(PeerManager::new(
|
||||
RouteAlgoType::Ospf,
|
||||
get_mock_global_ctx_with_network(Some(NetworkIdentity {
|
||||
network_name: network.to_string(),
|
||||
network_secret: network.to_string(),
|
||||
})),
|
||||
s,
|
||||
));
|
||||
replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
|
||||
peer_mgr.run().await.unwrap();
|
||||
peer_mgr
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_foreign_network_manager() {
|
||||
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
|
||||
let pm_center2 =
|
||||
create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
|
||||
connect_peer_manager(pm_center.clone(), pm_center2.clone()).await;
|
||||
|
||||
let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
|
||||
let pmb_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
|
||||
connect_peer_manager(pma_net1.clone(), pm_center.clone()).await;
|
||||
connect_peer_manager(pmb_net1.clone(), pm_center.clone()).await;
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
let mut succ = false;
|
||||
while now.elapsed().as_secs() < 10 {
|
||||
let table = pma_net1.get_foreign_network_client().get_next_hop_table();
|
||||
if table.len() >= 1 {
|
||||
succ = true;
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
}
|
||||
assert!(succ);
|
||||
|
||||
assert_eq!(
|
||||
vec![pm_center.my_peer_id()],
|
||||
pma_net1
|
||||
.get_foreign_network_client()
|
||||
.get_peer_map()
|
||||
.list_peers()
|
||||
.await
|
||||
);
|
||||
assert_eq!(
|
||||
vec![pm_center.my_peer_id()],
|
||||
pmb_net1
|
||||
.get_foreign_network_client()
|
||||
.get_peer_map()
|
||||
.list_peers()
|
||||
.await
|
||||
);
|
||||
wait_route_appear(pma_net1.clone(), pmb_net1.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, pma_net1.list_routes().await.len());
|
||||
assert_eq!(1, pmb_net1.list_routes().await.len());
|
||||
|
||||
let pmc_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
|
||||
connect_peer_manager(pmc_net1.clone(), pm_center.clone()).await;
|
||||
wait_route_appear(pma_net1.clone(), pmc_net1.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
wait_route_appear(pmb_net1.clone(), pmc_net1.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(2, pmc_net1.list_routes().await.len());
|
||||
|
||||
let pma_net2 = create_mock_peer_manager_for_foreign_network("net2").await;
|
||||
let pmb_net2 = create_mock_peer_manager_for_foreign_network("net2").await;
|
||||
connect_peer_manager(pma_net2.clone(), pm_center.clone()).await;
|
||||
connect_peer_manager(pmb_net2.clone(), pm_center.clone()).await;
|
||||
wait_route_appear(pma_net2.clone(), pmb_net2.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, pma_net2.list_routes().await.len());
|
||||
assert_eq!(1, pmb_net2.list_routes().await.len());
|
||||
|
||||
assert_eq!(
|
||||
5,
|
||||
pm_center
|
||||
.get_foreign_network_manager()
|
||||
.data
|
||||
.peer_network_map
|
||||
.len()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
2,
|
||||
pm_center
|
||||
.get_foreign_network_manager()
|
||||
.data
|
||||
.network_peer_maps
|
||||
.len()
|
||||
);
|
||||
|
||||
drop(pmb_net2);
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
assert_eq!(
|
||||
4,
|
||||
pm_center
|
||||
.get_foreign_network_manager()
|
||||
.data
|
||||
.peer_network_map
|
||||
.len()
|
||||
);
|
||||
drop(pma_net2);
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
assert_eq!(
|
||||
3,
|
||||
pm_center
|
||||
.get_foreign_network_manager()
|
||||
.data
|
||||
.peer_network_map
|
||||
.len()
|
||||
);
|
||||
assert_eq!(
|
||||
1,
|
||||
pm_center
|
||||
.get_foreign_network_manager()
|
||||
.data
|
||||
.network_peer_maps
|
||||
.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
pub mod packet;
|
||||
pub mod peer;
|
||||
pub mod peer_conn;
|
||||
pub mod peer_manager;
|
||||
pub mod peer_map;
|
||||
pub mod peer_ospf_route;
|
||||
pub mod peer_rip_route;
|
||||
pub mod peer_rpc;
|
||||
pub mod route_trait;
|
||||
pub mod rpc_service;
|
||||
|
||||
pub mod foreign_network_client;
|
||||
pub mod foreign_network_manager;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests;
|
||||
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
|
||||
#[async_trait::async_trait]
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait PeerPacketFilter {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
_packet: &packet::ArchivedPacket,
|
||||
_data: &Bytes,
|
||||
) -> Option<()> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait NicPacketFilter {
|
||||
async fn try_process_packet_from_nic(&self, data: BytesMut) -> BytesMut;
|
||||
}
|
||||
|
||||
type BoxPeerPacketFilter = Box<dyn PeerPacketFilter + Send + Sync>;
|
||||
type BoxNicPacketFilter = Box<dyn NicPacketFilter + Send + Sync>;
|
||||
@@ -0,0 +1,254 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rkyv::{Archive, Deserialize, Serialize};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::common::{
|
||||
global_ctx::NetworkIdentity,
|
||||
rkyv_util::{decode_from_bytes, encode_to_bytes, vec_to_string},
|
||||
PeerId,
|
||||
};
|
||||
|
||||
const MAGIC: u32 = 0xd1e1a5e1;
|
||||
const VERSION: u32 = 1;
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, PartialEq, Clone)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub struct UUID(uuid::Bytes);
|
||||
|
||||
// impl Debug for UUID
|
||||
impl std::fmt::Debug for UUID {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let uuid = uuid::Uuid::from_bytes(self.0);
|
||||
write!(f, "{}", uuid)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<uuid::Uuid> for UUID {
|
||||
fn from(uuid: uuid::Uuid) -> Self {
|
||||
UUID(*uuid.as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UUID> for uuid::Uuid {
|
||||
fn from(uuid: UUID) -> Self {
|
||||
uuid::Uuid::from_bytes(uuid.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl ArchivedUUID {
|
||||
pub fn to_uuid(&self) -> uuid::Uuid {
|
||||
uuid::Uuid::from_bytes(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ArchivedUUID> for UUID {
|
||||
fn from(uuid: &ArchivedUUID) -> Self {
|
||||
UUID(uuid.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||
pub struct HandShake {
|
||||
pub magic: u32,
|
||||
pub my_peer_id: PeerId,
|
||||
pub version: u32,
|
||||
pub features: Vec<String>,
|
||||
pub network_identity: NetworkIdentity,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||
pub struct RoutePacket {
|
||||
pub route_id: u8,
|
||||
pub body: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub enum CtrlPacketPayload {
|
||||
HandShake(HandShake),
|
||||
RoutePacket(RoutePacket),
|
||||
Ping(u32),
|
||||
Pong(u32),
|
||||
TaRpc(u32, u32, bool, Vec<u8>), // u32: service_id, u32: transact_id, bool: is_req, Vec<u8>: rpc body
|
||||
}
|
||||
|
||||
impl CtrlPacketPayload {
|
||||
pub fn from_packet(p: &ArchivedPacket) -> CtrlPacketPayload {
|
||||
assert_ne!(p.packet_type, PacketType::Data);
|
||||
postcard::from_bytes(p.payload.as_bytes()).unwrap()
|
||||
}
|
||||
|
||||
pub fn from_packet2(p: &Packet) -> CtrlPacketPayload {
|
||||
postcard::from_bytes(p.payload.as_bytes()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(u8)]
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub enum PacketType {
|
||||
Data = 1,
|
||||
HandShake = 2,
|
||||
RoutePacket = 3,
|
||||
Ping = 4,
|
||||
Pong = 5,
|
||||
TaRpc = 6,
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
pub struct Packet {
|
||||
pub from_peer: PeerId,
|
||||
pub to_peer: PeerId,
|
||||
pub packet_type: PacketType,
|
||||
pub payload: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Packet {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
|
||||
self.from_peer,
|
||||
self.to_peer,
|
||||
self.packet_type,
|
||||
&self.payload.as_bytes()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ArchivedPacket {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
|
||||
self.from_peer,
|
||||
self.to_peer,
|
||||
self.packet_type,
|
||||
&self.payload.as_bytes()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
pub fn decode(v: &[u8]) -> &ArchivedPacket {
|
||||
decode_from_bytes::<Packet>(v).unwrap()
|
||||
}
|
||||
|
||||
pub fn new(
|
||||
from_peer: PeerId,
|
||||
to_peer: PeerId,
|
||||
packet_type: PacketType,
|
||||
payload: Vec<u8>,
|
||||
) -> Self {
|
||||
Packet {
|
||||
from_peer,
|
||||
to_peer,
|
||||
packet_type,
|
||||
payload: vec_to_string(payload),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Packet> for Bytes {
|
||||
fn from(val: Packet) -> Self {
|
||||
encode_to_bytes::<_, 4096>(&val)
|
||||
}
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
pub fn new_handshake(from_peer: PeerId, network: &NetworkIdentity) -> Self {
|
||||
let handshake = CtrlPacketPayload::HandShake(HandShake {
|
||||
magic: MAGIC,
|
||||
my_peer_id: from_peer,
|
||||
version: VERSION,
|
||||
features: Vec::new(),
|
||||
network_identity: network.clone().into(),
|
||||
});
|
||||
Packet::new(
|
||||
from_peer.into(),
|
||||
0,
|
||||
PacketType::HandShake,
|
||||
postcard::to_allocvec(&handshake).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_data_packet(from_peer: PeerId, to_peer: PeerId, data: &[u8]) -> Self {
|
||||
Packet::new(from_peer, to_peer, PacketType::Data, data.to_vec())
|
||||
}
|
||||
|
||||
pub fn new_route_packet(from_peer: PeerId, to_peer: PeerId, route_id: u8, data: &[u8]) -> Self {
|
||||
let route = CtrlPacketPayload::RoutePacket(RoutePacket {
|
||||
route_id,
|
||||
body: data.to_vec(),
|
||||
});
|
||||
Packet::new(
|
||||
from_peer,
|
||||
to_peer,
|
||||
PacketType::RoutePacket,
|
||||
postcard::to_allocvec(&route).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_ping_packet(from_peer: PeerId, to_peer: PeerId, seq: u32) -> Self {
|
||||
let ping = CtrlPacketPayload::Ping(seq);
|
||||
Packet::new(
|
||||
from_peer,
|
||||
to_peer,
|
||||
PacketType::Ping,
|
||||
postcard::to_allocvec(&ping).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_pong_packet(from_peer: PeerId, to_peer: PeerId, seq: u32) -> Self {
|
||||
let pong = CtrlPacketPayload::Pong(seq);
|
||||
Packet::new(
|
||||
from_peer,
|
||||
to_peer,
|
||||
PacketType::Pong,
|
||||
postcard::to_allocvec(&pong).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_tarpc_packet(
|
||||
from_peer: PeerId,
|
||||
to_peer: PeerId,
|
||||
service_id: u32,
|
||||
transact_id: u32,
|
||||
is_req: bool,
|
||||
body: Vec<u8>,
|
||||
) -> Self {
|
||||
let ta_rpc = CtrlPacketPayload::TaRpc(service_id, transact_id, is_req, body);
|
||||
Packet::new(
|
||||
from_peer,
|
||||
to_peer,
|
||||
PacketType::TaRpc,
|
||||
postcard::to_allocvec(&ta_rpc).unwrap(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::common::new_peer_id;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn serialize() {
|
||||
let a = "abcde";
|
||||
let out = Packet::new_data_packet(new_peer_id(), new_peer_id(), a.as_bytes());
|
||||
// let out = T::new(a.as_bytes());
|
||||
let out_bytes: Bytes = out.into();
|
||||
println!("out str: {:?}", a.as_bytes());
|
||||
println!("out bytes: {:?}", out_bytes);
|
||||
|
||||
let archived = Packet::decode(&out_bytes[..]);
|
||||
println!("in packet: {:?}", archived);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
|
||||
use tokio::{
|
||||
select,
|
||||
sync::{mpsc, Mutex},
|
||||
task::JoinHandle,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Instrument;
|
||||
|
||||
use super::peer_conn::{PeerConn, PeerConnId};
|
||||
use crate::common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
||||
PeerId,
|
||||
};
|
||||
use crate::rpc::PeerConnInfo;
|
||||
|
||||
type ArcPeerConn = Arc<Mutex<PeerConn>>;
|
||||
type ConnMap = Arc<DashMap<PeerConnId, ArcPeerConn>>;
|
||||
|
||||
pub struct Peer {
|
||||
pub peer_node_id: PeerId,
|
||||
conns: ConnMap,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
|
||||
packet_recv_chan: mpsc::Sender<Bytes>,
|
||||
|
||||
close_event_sender: mpsc::Sender<PeerConnId>,
|
||||
close_event_listener: JoinHandle<()>,
|
||||
|
||||
shutdown_notifier: Arc<tokio::sync::Notify>,
|
||||
}
|
||||
|
||||
impl Peer {
|
||||
pub fn new(
|
||||
peer_node_id: PeerId,
|
||||
packet_recv_chan: mpsc::Sender<Bytes>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
) -> Self {
|
||||
let conns: ConnMap = Arc::new(DashMap::new());
|
||||
let (close_event_sender, mut close_event_receiver) = mpsc::channel(10);
|
||||
let shutdown_notifier = Arc::new(tokio::sync::Notify::new());
|
||||
|
||||
let conns_copy = conns.clone();
|
||||
let shutdown_notifier_copy = shutdown_notifier.clone();
|
||||
let global_ctx_copy = global_ctx.clone();
|
||||
let close_event_listener = tokio::spawn(
|
||||
async move {
|
||||
loop {
|
||||
select! {
|
||||
ret = close_event_receiver.recv() => {
|
||||
if ret.is_none() {
|
||||
break;
|
||||
}
|
||||
let ret = ret.unwrap();
|
||||
tracing::warn!(
|
||||
?peer_node_id,
|
||||
?ret,
|
||||
"notified that peer conn is closed",
|
||||
);
|
||||
|
||||
if let Some((_, conn)) = conns_copy.remove(&ret) {
|
||||
global_ctx_copy.issue_event(GlobalCtxEvent::PeerConnRemoved(
|
||||
conn.lock().await.get_conn_info(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
_ = shutdown_notifier_copy.notified() => {
|
||||
close_event_receiver.close();
|
||||
tracing::warn!(?peer_node_id, "peer close event listener notified");
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::info!("peer {} close event listener exit", peer_node_id);
|
||||
}
|
||||
.instrument(tracing::info_span!(
|
||||
"peer_close_event_listener",
|
||||
?peer_node_id,
|
||||
)),
|
||||
);
|
||||
|
||||
Peer {
|
||||
peer_node_id,
|
||||
conns: conns.clone(),
|
||||
packet_recv_chan,
|
||||
global_ctx,
|
||||
|
||||
close_event_sender,
|
||||
close_event_listener,
|
||||
|
||||
shutdown_notifier,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_peer_conn(&self, mut conn: PeerConn) {
|
||||
conn.set_close_event_sender(self.close_event_sender.clone());
|
||||
conn.start_recv_loop(self.packet_recv_chan.clone());
|
||||
conn.start_pingpong();
|
||||
self.global_ctx
|
||||
.issue_event(GlobalCtxEvent::PeerConnAdded(conn.get_conn_info()));
|
||||
self.conns
|
||||
.insert(conn.get_conn_id(), Arc::new(Mutex::new(conn)));
|
||||
}
|
||||
|
||||
pub async fn send_msg(&self, msg: Bytes) -> Result<(), Error> {
|
||||
let Some(conn) = self.conns.iter().next() else {
|
||||
return Err(Error::PeerNoConnectionError(self.peer_node_id));
|
||||
};
|
||||
|
||||
let conn_clone = conn.clone();
|
||||
drop(conn);
|
||||
conn_clone.lock().await.send_msg(msg).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn close_peer_conn(&self, conn_id: &PeerConnId) -> Result<(), Error> {
|
||||
let has_key = self.conns.contains_key(conn_id);
|
||||
if !has_key {
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
self.close_event_sender.send(conn_id.clone()).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn list_peer_conns(&self) -> Vec<PeerConnInfo> {
|
||||
let mut conns = vec![];
|
||||
for conn in self.conns.iter() {
|
||||
// do not lock here, otherwise it will cause dashmap deadlock
|
||||
conns.push(conn.clone());
|
||||
}
|
||||
|
||||
let mut ret = Vec::new();
|
||||
for conn in conns {
|
||||
ret.push(conn.lock().await.get_conn_info());
|
||||
}
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
// pritn on drop
|
||||
impl Drop for Peer {
|
||||
fn drop(&mut self) {
|
||||
self.shutdown_notifier.notify_one();
|
||||
tracing::info!("peer {} drop", self.peer_node_id);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use tokio::{sync::mpsc, time::timeout};
|
||||
|
||||
use crate::{
|
||||
common::{global_ctx::tests::get_mock_global_ctx, new_peer_id},
|
||||
peers::peer_conn::PeerConn,
|
||||
tunnels::ring_tunnel::create_ring_tunnel_pair,
|
||||
};
|
||||
|
||||
use super::Peer;
|
||||
|
||||
#[tokio::test]
|
||||
async fn close_peer() {
|
||||
let (local_packet_send, _local_packet_recv) = mpsc::channel(10);
|
||||
let (remote_packet_send, _remote_packet_recv) = mpsc::channel(10);
|
||||
let global_ctx = get_mock_global_ctx();
|
||||
let local_peer = Peer::new(new_peer_id(), local_packet_send, global_ctx.clone());
|
||||
let remote_peer = Peer::new(new_peer_id(), remote_packet_send, global_ctx.clone());
|
||||
|
||||
let (local_tunnel, remote_tunnel) = create_ring_tunnel_pair();
|
||||
let mut local_peer_conn =
|
||||
PeerConn::new(local_peer.peer_node_id, global_ctx.clone(), local_tunnel);
|
||||
let mut remote_peer_conn =
|
||||
PeerConn::new(remote_peer.peer_node_id, global_ctx.clone(), remote_tunnel);
|
||||
|
||||
assert!(!local_peer_conn.handshake_done());
|
||||
assert!(!remote_peer_conn.handshake_done());
|
||||
|
||||
let (a, b) = tokio::join!(
|
||||
local_peer_conn.do_handshake_as_client(),
|
||||
remote_peer_conn.do_handshake_as_server()
|
||||
);
|
||||
a.unwrap();
|
||||
b.unwrap();
|
||||
|
||||
let local_conn_id = local_peer_conn.get_conn_id();
|
||||
|
||||
local_peer.add_peer_conn(local_peer_conn).await;
|
||||
remote_peer.add_peer_conn(remote_peer_conn).await;
|
||||
|
||||
assert_eq!(local_peer.list_peer_conns().await.len(), 1);
|
||||
assert_eq!(remote_peer.list_peer_conns().await.len(), 1);
|
||||
|
||||
let close_handler =
|
||||
tokio::spawn(async move { local_peer.close_peer_conn(&local_conn_id).await });
|
||||
|
||||
// wait for remote peer conn close
|
||||
timeout(std::time::Duration::from_secs(5), async {
|
||||
while (&remote_peer).list_peer_conns().await.len() != 0 {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
println!("wait for close handler");
|
||||
close_handler.await.unwrap().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,652 @@
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use pnet::datalink::NetworkInterface;
|
||||
|
||||
use tokio::{
|
||||
sync::{broadcast, mpsc, Mutex},
|
||||
task::JoinSet,
|
||||
time::{timeout, Duration},
|
||||
};
|
||||
|
||||
use tokio_util::{bytes::Bytes, sync::PollSender};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
||||
PeerId,
|
||||
},
|
||||
define_tunnel_filter_chain,
|
||||
peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType},
|
||||
rpc::{PeerConnInfo, PeerConnStats},
|
||||
tunnels::{
|
||||
stats::{Throughput, WindowLatency},
|
||||
tunnel_filter::StatsRecorderTunnelFilter,
|
||||
DatagramSink, Tunnel, TunnelError,
|
||||
},
|
||||
};
|
||||
|
||||
use super::packet::{self, HandShake, Packet};
|
||||
|
||||
pub type PacketRecvChan = mpsc::Sender<Bytes>;
|
||||
|
||||
pub type PeerConnId = uuid::Uuid;
|
||||
|
||||
macro_rules! wait_response {
|
||||
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
|
||||
let rsp_vec = timeout(Duration::from_secs(1), $stream.next()).await;
|
||||
if rsp_vec.is_err() {
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"wait handshake response timeout".to_owned(),
|
||||
));
|
||||
}
|
||||
let rsp_vec = rsp_vec.unwrap().unwrap()?;
|
||||
|
||||
let $out_var;
|
||||
let rsp_bytes = Packet::decode(&rsp_vec);
|
||||
if rsp_bytes.packet_type != PacketType::HandShake {
|
||||
tracing::error!("unexpected packet type: {:?}", rsp_bytes);
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"unexpected packet type".to_owned(),
|
||||
));
|
||||
}
|
||||
let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes);
|
||||
match &resp_payload {
|
||||
$pattern => $out_var = $value,
|
||||
_ => {
|
||||
tracing::error!(
|
||||
"unexpected packet: {:?}, pattern: {:?}",
|
||||
rsp_bytes,
|
||||
stringify!($pattern)
|
||||
);
|
||||
return Err(TunnelError::WaitRespError("unexpected packet".to_owned()));
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PeerInfo {
|
||||
magic: u32,
|
||||
pub my_peer_id: PeerId,
|
||||
version: u32,
|
||||
pub features: Vec<String>,
|
||||
pub interfaces: Vec<NetworkInterface>,
|
||||
pub network_identity: NetworkIdentity,
|
||||
}
|
||||
|
||||
impl<'a> From<&HandShake> for PeerInfo {
|
||||
fn from(hs: &HandShake) -> Self {
|
||||
PeerInfo {
|
||||
magic: hs.magic.into(),
|
||||
my_peer_id: hs.my_peer_id.into(),
|
||||
version: hs.version.into(),
|
||||
features: hs.features.iter().map(|x| x.to_string()).collect(),
|
||||
interfaces: Vec::new(),
|
||||
network_identity: hs.network_identity.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PeerConnPinger {
|
||||
my_peer_id: PeerId,
|
||||
peer_id: PeerId,
|
||||
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
|
||||
ctrl_sender: broadcast::Sender<Bytes>,
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
tasks: JoinSet<Result<(), TunnelError>>,
|
||||
}
|
||||
|
||||
impl Debug for PeerConnPinger {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerConnPinger")
|
||||
.field("my_peer_id", &self.my_peer_id)
|
||||
.field("peer_id", &self.peer_id)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerConnPinger {
|
||||
pub fn new(
|
||||
my_peer_id: PeerId,
|
||||
peer_id: PeerId,
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
ctrl_sender: broadcast::Sender<Bytes>,
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
) -> Self {
|
||||
Self {
|
||||
my_peer_id,
|
||||
peer_id,
|
||||
sink: Arc::new(Mutex::new(sink)),
|
||||
tasks: JoinSet::new(),
|
||||
latency_stats,
|
||||
ctrl_sender,
|
||||
loss_rate_stats,
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_pingpong_once(
|
||||
my_node_id: PeerId,
|
||||
peer_id: PeerId,
|
||||
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
|
||||
receiver: &mut broadcast::Receiver<Bytes>,
|
||||
seq: u32,
|
||||
) -> Result<u128, TunnelError> {
|
||||
// should add seq here. so latency can be calculated more accurately
|
||||
let req = packet::Packet::new_ping_packet(my_node_id, peer_id, seq).into();
|
||||
tracing::trace!("send ping packet: {:?}", req);
|
||||
sink.lock().await.send(req).await.map_err(|e| {
|
||||
tracing::warn!("send ping packet error: {:?}", e);
|
||||
TunnelError::CommonError("send ping packet error".to_owned())
|
||||
})?;
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
|
||||
// wait until we get a pong packet in ctrl_resp_receiver
|
||||
let resp = timeout(Duration::from_secs(1), async {
|
||||
loop {
|
||||
match receiver.recv().await {
|
||||
Ok(p) => {
|
||||
let ctrl_payload =
|
||||
packet::CtrlPacketPayload::from_packet(Packet::decode(&p));
|
||||
if let packet::CtrlPacketPayload::Pong(resp_seq) = ctrl_payload {
|
||||
if resp_seq == seq {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("recv pong resp error: {:?}", e);
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"recv pong resp error".to_owned(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.await;
|
||||
|
||||
tracing::trace!(?resp, "wait ping response done");
|
||||
|
||||
if resp.is_err() {
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"wait ping response timeout".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
if resp.as_ref().unwrap().is_err() {
|
||||
return Err(resp.unwrap().err().unwrap());
|
||||
}
|
||||
|
||||
Ok(now.elapsed().as_micros())
|
||||
}
|
||||
|
||||
async fn pingpong(&mut self) {
|
||||
let sink = self.sink.clone();
|
||||
let my_node_id = self.my_peer_id;
|
||||
let peer_id = self.peer_id;
|
||||
let latency_stats = self.latency_stats.clone();
|
||||
|
||||
let (ping_res_sender, mut ping_res_receiver) = tokio::sync::mpsc::channel(100);
|
||||
|
||||
let stopped = Arc::new(AtomicU32::new(0));
|
||||
|
||||
// generate a pingpong task every 200ms
|
||||
let mut pingpong_tasks = JoinSet::new();
|
||||
let ctrl_resp_sender = self.ctrl_sender.clone();
|
||||
let stopped_clone = stopped.clone();
|
||||
self.tasks.spawn(async move {
|
||||
let mut req_seq = 0;
|
||||
loop {
|
||||
let receiver = ctrl_resp_sender.subscribe();
|
||||
let ping_res_sender = ping_res_sender.clone();
|
||||
let sink = sink.clone();
|
||||
|
||||
if stopped_clone.load(Ordering::Relaxed) != 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
while pingpong_tasks.len() > 5 {
|
||||
pingpong_tasks.join_next().await;
|
||||
}
|
||||
|
||||
pingpong_tasks.spawn(async move {
|
||||
let mut receiver = receiver.resubscribe();
|
||||
let pingpong_once_ret = Self::do_pingpong_once(
|
||||
my_node_id,
|
||||
peer_id,
|
||||
sink.clone(),
|
||||
&mut receiver,
|
||||
req_seq,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = ping_res_sender.send(pingpong_once_ret).await {
|
||||
tracing::info!(?e, "pingpong task send result error, exit..");
|
||||
};
|
||||
});
|
||||
|
||||
req_seq = req_seq.wrapping_add(1);
|
||||
tokio::time::sleep(Duration::from_millis(1000)).await;
|
||||
}
|
||||
});
|
||||
|
||||
// one with 1% precision
|
||||
let loss_rate_stats_1 = WindowLatency::new(100);
|
||||
// one with 20% precision, so we can fast fail this conn.
|
||||
let loss_rate_stats_20 = WindowLatency::new(5);
|
||||
|
||||
let mut counter: u64 = 0;
|
||||
|
||||
while let Some(ret) = ping_res_receiver.recv().await {
|
||||
counter += 1;
|
||||
|
||||
if let Ok(lat) = ret {
|
||||
latency_stats.record_latency(lat as u32);
|
||||
|
||||
loss_rate_stats_1.record_latency(0);
|
||||
loss_rate_stats_20.record_latency(0);
|
||||
} else {
|
||||
loss_rate_stats_1.record_latency(1);
|
||||
loss_rate_stats_20.record_latency(1);
|
||||
}
|
||||
|
||||
let loss_rate_20: f64 = loss_rate_stats_20.get_latency_us();
|
||||
let loss_rate_1: f64 = loss_rate_stats_1.get_latency_us();
|
||||
|
||||
tracing::trace!(
|
||||
?ret,
|
||||
?self,
|
||||
?loss_rate_1,
|
||||
?loss_rate_20,
|
||||
"pingpong task recv pingpong_once result"
|
||||
);
|
||||
|
||||
if (counter > 5 && loss_rate_20 > 0.74) || (counter > 150 && loss_rate_1 > 0.20) {
|
||||
tracing::warn!(
|
||||
?ret,
|
||||
?self,
|
||||
?loss_rate_1,
|
||||
?loss_rate_20,
|
||||
"pingpong loss rate too high, closing"
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
self.loss_rate_stats
|
||||
.store((loss_rate_1 * 100.0) as u32, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
stopped.store(1, Ordering::Relaxed);
|
||||
ping_res_receiver.close();
|
||||
}
|
||||
}
|
||||
|
||||
define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter);
|
||||
|
||||
pub struct PeerConn {
|
||||
conn_id: PeerConnId,
|
||||
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
|
||||
tasks: JoinSet<Result<(), TunnelError>>,
|
||||
|
||||
info: Option<PeerInfo>,
|
||||
|
||||
close_event_sender: Option<mpsc::Sender<PeerConnId>>,
|
||||
|
||||
ctrl_resp_sender: broadcast::Sender<Bytes>,
|
||||
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
throughput: Arc<Throughput>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
enum PeerConnPacketType {
|
||||
Data(Bytes),
|
||||
CtrlReq(Bytes),
|
||||
CtrlResp(Bytes),
|
||||
}
|
||||
|
||||
impl PeerConn {
|
||||
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
|
||||
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
|
||||
let peer_conn_tunnel = PeerConnTunnel::new();
|
||||
let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel);
|
||||
|
||||
PeerConn {
|
||||
conn_id: PeerConnId::new_v4(),
|
||||
|
||||
my_peer_id,
|
||||
global_ctx,
|
||||
|
||||
sink: tunnel.pin_sink(),
|
||||
tunnel: Box::new(tunnel),
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
|
||||
info: None,
|
||||
close_event_sender: None,
|
||||
|
||||
ctrl_resp_sender: ctrl_sender,
|
||||
|
||||
latency_stats: Arc::new(WindowLatency::new(15)),
|
||||
throughput: peer_conn_tunnel.stats.get_throughput().clone(),
|
||||
loss_rate_stats: Arc::new(AtomicU32::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_conn_id(&self) -> PeerConnId {
|
||||
self.conn_id
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
|
||||
tracing::info!("waiting for handshake request from client");
|
||||
wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x);
|
||||
self.info = Some(PeerInfo::from(hs_req));
|
||||
tracing::info!("handshake request: {:?}", hs_req);
|
||||
|
||||
let hs_req = self
|
||||
.global_ctx
|
||||
.net_ns
|
||||
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
|
||||
sink.send(hs_req.into()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
|
||||
let hs_req = self
|
||||
.global_ctx
|
||||
.net_ns
|
||||
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
|
||||
sink.send(hs_req.into()).await?;
|
||||
|
||||
tracing::info!("waiting for handshake request from server");
|
||||
wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x);
|
||||
self.info = Some(PeerInfo::from(hs_rsp));
|
||||
tracing::info!("handshake response: {:?}", hs_rsp);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn handshake_done(&self) -> bool {
|
||||
self.info.is_some()
|
||||
}
|
||||
|
||||
pub fn start_pingpong(&mut self) {
|
||||
let mut pingpong = PeerConnPinger::new(
|
||||
self.my_peer_id,
|
||||
self.get_peer_id(),
|
||||
self.tunnel.pin_sink(),
|
||||
self.ctrl_resp_sender.clone(),
|
||||
self.latency_stats.clone(),
|
||||
self.loss_rate_stats.clone(),
|
||||
);
|
||||
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
|
||||
self.tasks.spawn(async move {
|
||||
pingpong.pingpong().await;
|
||||
|
||||
tracing::warn!(?pingpong, "pingpong task exit");
|
||||
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
log::warn!("close event sender error: {:?}", e);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
let mut sender = PollSender::new(packet_recv_chan.clone());
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
let ctrl_sender = self.ctrl_resp_sender.clone();
|
||||
let conn_info = self.get_conn_info();
|
||||
let conn_info_for_instrument = self.get_conn_info();
|
||||
|
||||
self.tasks.spawn(
|
||||
async move {
|
||||
tracing::info!("start recving peer conn packet");
|
||||
let mut task_ret = Ok(());
|
||||
while let Some(ret) = stream.next().await {
|
||||
if ret.is_err() {
|
||||
tracing::error!(error = ?ret, "peer conn recv error");
|
||||
task_ret = Err(ret.err().unwrap());
|
||||
break;
|
||||
}
|
||||
|
||||
let buf = ret.unwrap();
|
||||
let p = Packet::decode(&buf);
|
||||
match p.packet_type {
|
||||
ArchivedPacketType::Ping => {
|
||||
let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p)
|
||||
else {
|
||||
log::error!("unexpected packet: {:?}", p);
|
||||
continue;
|
||||
};
|
||||
|
||||
let pong = packet::Packet::new_pong_packet(
|
||||
conn_info.my_peer_id,
|
||||
conn_info.peer_id,
|
||||
seq.into(),
|
||||
);
|
||||
|
||||
if let Err(e) = sink.send(pong.into()).await {
|
||||
tracing::error!(?e, "peer conn send req error");
|
||||
}
|
||||
}
|
||||
ArchivedPacketType::Pong => {
|
||||
if let Err(e) = ctrl_sender.send(buf.into()) {
|
||||
tracing::error!(?e, "peer conn send ctrl resp error");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if sender.send(buf.into()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("end recving peer conn packet");
|
||||
|
||||
if let Err(close_ret) = sink.close().await {
|
||||
tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it");
|
||||
}
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
tracing::error!(error = ?e, "peer conn close event send error");
|
||||
}
|
||||
|
||||
task_ret
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), TunnelError> {
|
||||
self.sink.send(msg).await
|
||||
}
|
||||
|
||||
pub fn get_peer_id(&self) -> PeerId {
|
||||
self.info.as_ref().unwrap().my_peer_id
|
||||
}
|
||||
|
||||
pub fn get_network_identity(&self) -> NetworkIdentity {
|
||||
self.info.as_ref().unwrap().network_identity.clone()
|
||||
}
|
||||
|
||||
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {
|
||||
self.close_event_sender = Some(sender);
|
||||
}
|
||||
|
||||
pub fn get_stats(&self) -> PeerConnStats {
|
||||
PeerConnStats {
|
||||
latency_us: self.latency_stats.get_latency_us(),
|
||||
|
||||
tx_bytes: self.throughput.tx_bytes(),
|
||||
rx_bytes: self.throughput.rx_bytes(),
|
||||
|
||||
tx_packets: self.throughput.tx_packets(),
|
||||
rx_packets: self.throughput.rx_packets(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_conn_info(&self) -> PeerConnInfo {
|
||||
PeerConnInfo {
|
||||
conn_id: self.conn_id.to_string(),
|
||||
my_peer_id: self.my_peer_id,
|
||||
peer_id: self.get_peer_id(),
|
||||
features: self.info.as_ref().unwrap().features.clone(),
|
||||
tunnel: self.tunnel.info(),
|
||||
stats: Some(self.get_stats()),
|
||||
loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PeerConn {
|
||||
fn drop(&mut self) {
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
tokio::spawn(async move {
|
||||
let ret = sink.close().await;
|
||||
tracing::info!(error = ?ret, "peer conn tunnel closed.");
|
||||
});
|
||||
log::info!("peer conn {:?} drop", self.conn_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for PeerConn {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerConn")
|
||||
.field("conn_id", &self.conn_id)
|
||||
.field("my_peer_id", &self.my_peer_id)
|
||||
.field("info", &self.info)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use crate::common::global_ctx::tests::get_mock_global_ctx;
|
||||
use crate::common::new_peer_id;
|
||||
use crate::tunnels::tunnel_filter::tests::DropSendTunnelFilter;
|
||||
use crate::tunnels::tunnel_filter::{PacketRecorderTunnelFilter, TunnelWithFilter};
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_conn_handshake() {
|
||||
use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
|
||||
let c_recorder = Arc::new(PacketRecorderTunnelFilter::new());
|
||||
let s_recorder = Arc::new(PacketRecorderTunnelFilter::new());
|
||||
|
||||
let c = TunnelWithFilter::new(c, c_recorder.clone());
|
||||
let s = TunnelWithFilter::new(s, s_recorder.clone());
|
||||
|
||||
let c_peer_id = new_peer_id();
|
||||
let s_peer_id = new_peer_id();
|
||||
|
||||
let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
|
||||
|
||||
let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));
|
||||
|
||||
let (c_ret, s_ret) = tokio::join!(
|
||||
c_peer.do_handshake_as_client(),
|
||||
s_peer.do_handshake_as_server()
|
||||
);
|
||||
|
||||
c_ret.unwrap();
|
||||
s_ret.unwrap();
|
||||
|
||||
assert_eq!(c_recorder.sent.lock().unwrap().len(), 1);
|
||||
assert_eq!(c_recorder.received.lock().unwrap().len(), 1);
|
||||
|
||||
assert_eq!(s_recorder.sent.lock().unwrap().len(), 1);
|
||||
assert_eq!(s_recorder.received.lock().unwrap().len(), 1);
|
||||
|
||||
assert_eq!(c_peer.get_peer_id(), s_peer_id);
|
||||
assert_eq!(s_peer.get_peer_id(), c_peer_id);
|
||||
assert_eq!(c_peer.get_network_identity(), s_peer.get_network_identity());
|
||||
assert_eq!(c_peer.get_network_identity(), NetworkIdentity::default());
|
||||
}
|
||||
|
||||
async fn peer_conn_pingpong_test_common(drop_start: u32, drop_end: u32, conn_closed: bool) {
|
||||
use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
|
||||
// drop 1-3 packets should not affect pingpong
|
||||
let c_recorder = Arc::new(DropSendTunnelFilter::new(drop_start, drop_end));
|
||||
let c = TunnelWithFilter::new(c, c_recorder.clone());
|
||||
|
||||
let c_peer_id = new_peer_id();
|
||||
let s_peer_id = new_peer_id();
|
||||
|
||||
let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
|
||||
let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));
|
||||
|
||||
let (c_ret, s_ret) = tokio::join!(
|
||||
c_peer.do_handshake_as_client(),
|
||||
s_peer.do_handshake_as_server()
|
||||
);
|
||||
|
||||
s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0);
|
||||
s_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
|
||||
|
||||
assert!(c_ret.is_ok());
|
||||
assert!(s_ret.is_ok());
|
||||
|
||||
let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1);
|
||||
c_peer.set_close_event_sender(close_send);
|
||||
c_peer.start_pingpong();
|
||||
c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
|
||||
|
||||
// wait 5s, conn should not be disconnected
|
||||
tokio::time::sleep(Duration::from_secs(15)).await;
|
||||
|
||||
if conn_closed {
|
||||
assert!(close_recv.try_recv().is_ok());
|
||||
} else {
|
||||
assert!(close_recv.try_recv().is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_conn_pingpong_timeout() {
|
||||
peer_conn_pingpong_test_common(3, 5, false).await;
|
||||
peer_conn_pingpong_test_common(5, 12, true).await;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,641 @@
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
net::Ipv4Addr,
|
||||
sync::{Arc, Weak},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
|
||||
use tokio::{
|
||||
sync::{
|
||||
mpsc::{self, UnboundedReceiver, UnboundedSender},
|
||||
Mutex, RwLock,
|
||||
},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_string,
|
||||
PeerId,
|
||||
},
|
||||
peers::{
|
||||
packet, peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport,
|
||||
route_trait::RouteInterface, PeerPacketFilter,
|
||||
},
|
||||
tunnels::{SinkItem, Tunnel, TunnelConnector},
|
||||
};
|
||||
|
||||
use super::{
|
||||
foreign_network_client::ForeignNetworkClient,
|
||||
foreign_network_manager::ForeignNetworkManager,
|
||||
peer_conn::PeerConnId,
|
||||
peer_map::PeerMap,
|
||||
peer_ospf_route::PeerRoute,
|
||||
peer_rip_route::BasicRoute,
|
||||
peer_rpc::PeerRpcManager,
|
||||
route_trait::{ArcRoute, Route},
|
||||
BoxNicPacketFilter, BoxPeerPacketFilter,
|
||||
};
|
||||
|
||||
struct RpcTransport {
|
||||
my_peer_id: PeerId,
|
||||
peers: Weak<PeerMap>,
|
||||
foreign_peers: Mutex<Option<Weak<ForeignNetworkClient>>>,
|
||||
|
||||
packet_recv: Mutex<UnboundedReceiver<Bytes>>,
|
||||
peer_rpc_tspt_sender: UnboundedSender<Bytes>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerRpcManagerTransport for RpcTransport {
|
||||
fn my_peer_id(&self) -> PeerId {
|
||||
self.my_peer_id
|
||||
}
|
||||
|
||||
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
let foreign_peers = self
|
||||
.foreign_peers
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.ok_or(Error::Unknown)?
|
||||
.upgrade()
|
||||
.ok_or(Error::Unknown)?;
|
||||
let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
|
||||
|
||||
let ret = peers.send_msg(msg.clone(), dst_peer_id).await;
|
||||
|
||||
if matches!(ret, Err(Error::RouteError(..))) && foreign_peers.has_next_hop(dst_peer_id) {
|
||||
tracing::info!(
|
||||
?dst_peer_id,
|
||||
?self.my_peer_id,
|
||||
"failed to send msg to peer, try foreign network",
|
||||
);
|
||||
return foreign_peers.send_msg(msg, dst_peer_id).await;
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
async fn recv(&self) -> Result<Bytes, Error> {
|
||||
if let Some(o) = self.packet_recv.lock().await.recv().await {
|
||||
Ok(o)
|
||||
} else {
|
||||
Err(Error::Unknown)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum RouteAlgoType {
|
||||
Rip,
|
||||
Ospf,
|
||||
None,
|
||||
}
|
||||
|
||||
enum RouteAlgoInst {
|
||||
Rip(Arc<BasicRoute>),
|
||||
Ospf(Arc<PeerRoute>),
|
||||
None,
|
||||
}
|
||||
|
||||
pub struct PeerManager {
|
||||
my_peer_id: PeerId,
|
||||
|
||||
global_ctx: ArcGlobalCtx,
|
||||
nic_channel: mpsc::Sender<SinkItem>,
|
||||
|
||||
tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
|
||||
packet_recv: Arc<Mutex<Option<mpsc::Receiver<Bytes>>>>,
|
||||
|
||||
peers: Arc<PeerMap>,
|
||||
|
||||
peer_rpc_mgr: Arc<PeerRpcManager>,
|
||||
peer_rpc_tspt: Arc<RpcTransport>,
|
||||
|
||||
peer_packet_process_pipeline: Arc<RwLock<Vec<BoxPeerPacketFilter>>>,
|
||||
nic_packet_process_pipeline: Arc<RwLock<Vec<BoxNicPacketFilter>>>,
|
||||
|
||||
route_algo_inst: RouteAlgoInst,
|
||||
|
||||
foreign_network_manager: Arc<ForeignNetworkManager>,
|
||||
foreign_network_client: Arc<ForeignNetworkClient>,
|
||||
}
|
||||
|
||||
impl Debug for PeerManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerManager")
|
||||
.field("my_peer_id", &self.my_peer_id())
|
||||
.field("instance_name", &self.global_ctx.inst_name)
|
||||
.field("net_ns", &self.global_ctx.net_ns.name())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerManager {
|
||||
pub fn new(
|
||||
route_algo: RouteAlgoType,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
nic_channel: mpsc::Sender<SinkItem>,
|
||||
) -> Self {
|
||||
let my_peer_id = rand::random();
|
||||
|
||||
let (packet_send, packet_recv) = mpsc::channel(100);
|
||||
let peers = Arc::new(PeerMap::new(
|
||||
packet_send.clone(),
|
||||
global_ctx.clone(),
|
||||
my_peer_id,
|
||||
));
|
||||
|
||||
// TODO: remove these because we have impl pipeline processor.
|
||||
let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
|
||||
let rpc_tspt = Arc::new(RpcTransport {
|
||||
my_peer_id,
|
||||
peers: Arc::downgrade(&peers),
|
||||
foreign_peers: Mutex::new(None),
|
||||
packet_recv: Mutex::new(peer_rpc_tspt_recv),
|
||||
peer_rpc_tspt_sender,
|
||||
});
|
||||
let peer_rpc_mgr = Arc::new(PeerRpcManager::new(rpc_tspt.clone()));
|
||||
|
||||
let route_algo_inst = match route_algo {
|
||||
RouteAlgoType::Rip => {
|
||||
RouteAlgoInst::Rip(Arc::new(BasicRoute::new(my_peer_id, global_ctx.clone())))
|
||||
}
|
||||
RouteAlgoType::Ospf => RouteAlgoInst::Ospf(PeerRoute::new(
|
||||
my_peer_id,
|
||||
global_ctx.clone(),
|
||||
peer_rpc_mgr.clone(),
|
||||
)),
|
||||
RouteAlgoType::None => RouteAlgoInst::None,
|
||||
};
|
||||
|
||||
let foreign_network_manager = Arc::new(ForeignNetworkManager::new(
|
||||
my_peer_id,
|
||||
global_ctx.clone(),
|
||||
packet_send.clone(),
|
||||
));
|
||||
let foreign_network_client = Arc::new(ForeignNetworkClient::new(
|
||||
global_ctx.clone(),
|
||||
packet_send.clone(),
|
||||
peer_rpc_mgr.clone(),
|
||||
my_peer_id,
|
||||
));
|
||||
|
||||
PeerManager {
|
||||
my_peer_id,
|
||||
|
||||
global_ctx,
|
||||
nic_channel,
|
||||
|
||||
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||
|
||||
packet_recv: Arc::new(Mutex::new(Some(packet_recv))),
|
||||
|
||||
peers: peers.clone(),
|
||||
|
||||
peer_rpc_mgr,
|
||||
peer_rpc_tspt: rpc_tspt,
|
||||
|
||||
peer_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
|
||||
nic_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
|
||||
|
||||
route_algo_inst,
|
||||
|
||||
foreign_network_manager,
|
||||
foreign_network_client,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_client_tunnel(
|
||||
&self,
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
) -> Result<(PeerId, PeerConnId), Error> {
|
||||
let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
|
||||
peer.do_handshake_as_client().await?;
|
||||
let conn_id = peer.get_conn_id();
|
||||
let peer_id = peer.get_peer_id();
|
||||
if peer.get_network_identity() == self.global_ctx.get_network_identity() {
|
||||
self.peers.add_new_peer_conn(peer).await;
|
||||
} else {
|
||||
self.foreign_network_client.add_new_peer_conn(peer).await;
|
||||
}
|
||||
Ok((peer_id, conn_id))
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn try_connect<C>(&self, mut connector: C) -> Result<(PeerId, PeerConnId), Error>
|
||||
where
|
||||
C: TunnelConnector + Debug,
|
||||
{
|
||||
let ns = self.global_ctx.net_ns.clone();
|
||||
let t = ns
|
||||
.run_async(|| async move { connector.connect().await })
|
||||
.await?;
|
||||
self.add_client_tunnel(t).await
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn add_tunnel_as_server(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
|
||||
tracing::info!("add tunnel as server start");
|
||||
let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
|
||||
peer.do_handshake_as_server().await?;
|
||||
if peer.get_network_identity() == self.global_ctx.get_network_identity() {
|
||||
self.peers.add_new_peer_conn(peer).await;
|
||||
} else {
|
||||
self.foreign_network_manager.add_peer_conn(peer).await?;
|
||||
}
|
||||
tracing::info!("add tunnel as server done");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_peer_recv(&self) {
|
||||
let mut recv = ReceiverStream::new(self.packet_recv.lock().await.take().unwrap());
|
||||
let my_peer_id = self.my_peer_id;
|
||||
let peers = self.peers.clone();
|
||||
let pipe_line = self.peer_packet_process_pipeline.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
log::trace!("start_peer_recv");
|
||||
while let Some(ret) = recv.next().await {
|
||||
log::trace!("peer recv a packet...: {:?}", ret);
|
||||
let packet = packet::Packet::decode(&ret);
|
||||
let from_peer_id: PeerId = packet.from_peer.into();
|
||||
let to_peer_id: PeerId = packet.to_peer.into();
|
||||
if to_peer_id != my_peer_id {
|
||||
log::trace!(
|
||||
"need forward: to_peer_id: {:?}, my_peer_id: {:?}",
|
||||
to_peer_id,
|
||||
my_peer_id
|
||||
);
|
||||
let ret = peers.send_msg(ret.clone(), to_peer_id).await;
|
||||
if ret.is_err() {
|
||||
log::error!(
|
||||
"forward packet error: {:?}, dst: {:?}, from: {:?}",
|
||||
ret,
|
||||
to_peer_id,
|
||||
from_peer_id
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let mut processed = false;
|
||||
for pipeline in pipe_line.read().await.iter().rev() {
|
||||
if let Some(_) = pipeline.try_process_packet_from_peer(&packet, &ret).await
|
||||
{
|
||||
processed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !processed {
|
||||
tracing::error!("unexpected packet: {:?}", ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
panic!("done_peer_recv");
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn add_packet_process_pipeline(&self, pipeline: BoxPeerPacketFilter) {
|
||||
// newest pipeline will be executed first
|
||||
self.peer_packet_process_pipeline
|
||||
.write()
|
||||
.await
|
||||
.push(pipeline);
|
||||
}
|
||||
|
||||
pub async fn add_nic_packet_process_pipeline(&self, pipeline: BoxNicPacketFilter) {
|
||||
// newest pipeline will be executed first
|
||||
self.nic_packet_process_pipeline
|
||||
.write()
|
||||
.await
|
||||
.push(pipeline);
|
||||
}
|
||||
|
||||
async fn init_packet_process_pipeline(&self) {
|
||||
// for tun/tap ip/eth packet.
|
||||
struct NicPacketProcessor {
|
||||
nic_channel: mpsc::Sender<SinkItem>,
|
||||
}
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for NicPacketProcessor {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
data: &Bytes,
|
||||
) -> Option<()> {
|
||||
if packet.packet_type == packet::PacketType::Data {
|
||||
// TODO: use a function to get the body ref directly for zero copy
|
||||
self.nic_channel
|
||||
.send(extract_bytes_from_archived_string(data, &packet.payload))
|
||||
.await
|
||||
.unwrap();
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
self.add_packet_process_pipeline(Box::new(NicPacketProcessor {
|
||||
nic_channel: self.nic_channel.clone(),
|
||||
}))
|
||||
.await;
|
||||
|
||||
// for peer rpc packet
|
||||
struct PeerRpcPacketProcessor {
|
||||
peer_rpc_tspt_sender: UnboundedSender<Bytes>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for PeerRpcPacketProcessor {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
data: &Bytes,
|
||||
) -> Option<()> {
|
||||
if packet.packet_type == packet::PacketType::TaRpc {
|
||||
self.peer_rpc_tspt_sender.send(data.clone()).unwrap();
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
self.add_packet_process_pipeline(Box::new(PeerRpcPacketProcessor {
|
||||
peer_rpc_tspt_sender: self.peer_rpc_tspt.peer_rpc_tspt_sender.clone(),
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn add_route<T>(&self, route: T)
|
||||
where
|
||||
T: Route + PeerPacketFilter + Send + Sync + Clone + 'static,
|
||||
{
|
||||
// for route
|
||||
self.add_packet_process_pipeline(Box::new(route.clone()))
|
||||
.await;
|
||||
|
||||
struct Interface {
|
||||
my_peer_id: PeerId,
|
||||
peers: Weak<PeerMap>,
|
||||
foreign_network_client: Weak<ForeignNetworkClient>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RouteInterface for Interface {
|
||||
async fn list_peers(&self) -> Vec<PeerId> {
|
||||
let Some(foreign_client) = self.foreign_network_client.upgrade() else {
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let Some(peer_map) = self.peers.upgrade() else {
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let mut peers = foreign_client.list_foreign_peers();
|
||||
peers.extend(peer_map.list_peers_with_conn().await);
|
||||
peers
|
||||
}
|
||||
async fn send_route_packet(
|
||||
&self,
|
||||
msg: Bytes,
|
||||
route_id: u8,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<(), Error> {
|
||||
let foreign_client = self
|
||||
.foreign_network_client
|
||||
.upgrade()
|
||||
.ok_or(Error::Unknown)?;
|
||||
let peer_map = self.peers.upgrade().ok_or(Error::Unknown)?;
|
||||
|
||||
let packet_bytes: Bytes =
|
||||
packet::Packet::new_route_packet(self.my_peer_id, dst_peer_id, route_id, &msg)
|
||||
.into();
|
||||
if foreign_client.has_next_hop(dst_peer_id) {
|
||||
return foreign_client.send_msg(packet_bytes, dst_peer_id).await;
|
||||
}
|
||||
|
||||
peer_map.send_msg_directly(packet_bytes, dst_peer_id).await
|
||||
}
|
||||
fn my_peer_id(&self) -> PeerId {
|
||||
self.my_peer_id
|
||||
}
|
||||
}
|
||||
|
||||
let my_peer_id = self.my_peer_id;
|
||||
let _route_id = route
|
||||
.open(Box::new(Interface {
|
||||
my_peer_id,
|
||||
peers: Arc::downgrade(&self.peers),
|
||||
foreign_network_client: Arc::downgrade(&self.foreign_network_client),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let arc_route: ArcRoute = Arc::new(Box::new(route));
|
||||
self.peers.add_route(arc_route).await;
|
||||
}
|
||||
|
||||
pub fn get_route(&self) -> Box<dyn Route + Send + Sync + 'static> {
|
||||
match &self.route_algo_inst {
|
||||
RouteAlgoInst::Rip(route) => Box::new(route.clone()),
|
||||
RouteAlgoInst::Ospf(route) => Box::new(route.clone()),
|
||||
RouteAlgoInst::None => panic!("no route"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_routes(&self) -> Vec<crate::rpc::Route> {
|
||||
self.get_route().list_routes().await
|
||||
}
|
||||
|
||||
async fn run_nic_packet_process_pipeline(&self, mut data: BytesMut) -> BytesMut {
|
||||
for pipeline in self.nic_packet_process_pipeline.read().await.iter().rev() {
|
||||
data = pipeline.try_process_packet_from_nic(data).await;
|
||||
}
|
||||
data
|
||||
}
|
||||
|
||||
pub async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
self.peers.send_msg(msg, dst_peer_id).await
|
||||
}
|
||||
|
||||
pub async fn send_msg_ipv4(&self, msg: BytesMut, ipv4_addr: Ipv4Addr) -> Result<(), Error> {
|
||||
log::trace!(
|
||||
"do send_msg in peer manager, msg: {:?}, ipv4_addr: {}",
|
||||
msg,
|
||||
ipv4_addr
|
||||
);
|
||||
|
||||
let mut dst_peers = vec![];
|
||||
// NOTE: currently we only support ipv4 and cidr is 24
|
||||
if ipv4_addr.is_broadcast() || ipv4_addr.is_multicast() || ipv4_addr.octets()[3] == 255 {
|
||||
dst_peers.extend(
|
||||
self.peers
|
||||
.list_routes()
|
||||
.await
|
||||
.iter()
|
||||
.map(|x| x.key().clone()),
|
||||
);
|
||||
} else if let Some(peer_id) = self.peers.get_peer_id_by_ipv4(&ipv4_addr).await {
|
||||
dst_peers.push(peer_id);
|
||||
}
|
||||
|
||||
if dst_peers.is_empty() {
|
||||
tracing::info!("no peer id for ipv4: {}", ipv4_addr);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let msg = self.run_nic_packet_process_pipeline(msg).await;
|
||||
let mut errs: Vec<Error> = vec![];
|
||||
|
||||
for peer_id in dst_peers.iter() {
|
||||
let msg: Bytes =
|
||||
packet::Packet::new_data_packet(self.my_peer_id, peer_id.clone(), &msg).into();
|
||||
let send_ret = self.peers.send_msg(msg.clone(), *peer_id).await;
|
||||
|
||||
if matches!(send_ret, Err(Error::RouteError(..)))
|
||||
&& self.foreign_network_client.has_next_hop(*peer_id)
|
||||
{
|
||||
let foreign_send_ret = self.foreign_network_client.send_msg(msg, *peer_id).await;
|
||||
if foreign_send_ret.is_ok() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(send_ret) = send_ret {
|
||||
errs.push(send_ret);
|
||||
}
|
||||
}
|
||||
|
||||
tracing::trace!(?dst_peers, "do send_msg in peer manager done");
|
||||
|
||||
if errs.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
tracing::error!(?errs, "send_msg has error");
|
||||
Err(anyhow::anyhow!("send_msg has error: {:?}", errs).into())
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_clean_peer_without_conn_routine(&self) {
|
||||
let peer_map = self.peers.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
loop {
|
||||
peer_map.clean_peer_without_conn().await;
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn run_foriegn_network(&self) {
|
||||
self.peer_rpc_tspt
|
||||
.foreign_peers
|
||||
.lock()
|
||||
.await
|
||||
.replace(Arc::downgrade(&self.foreign_network_client));
|
||||
|
||||
self.foreign_network_manager.run().await;
|
||||
self.foreign_network_client.run().await;
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<(), Error> {
|
||||
match &self.route_algo_inst {
|
||||
RouteAlgoInst::Ospf(route) => self.add_route(route.clone()).await,
|
||||
RouteAlgoInst::Rip(route) => self.add_route(route.clone()).await,
|
||||
RouteAlgoInst::None => {}
|
||||
};
|
||||
|
||||
self.init_packet_process_pipeline().await;
|
||||
self.peer_rpc_mgr.run();
|
||||
|
||||
self.start_peer_recv().await;
|
||||
self.run_clean_peer_without_conn_routine().await;
|
||||
|
||||
self.run_foriegn_network().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_peer_map(&self) -> Arc<PeerMap> {
|
||||
self.peers.clone()
|
||||
}
|
||||
|
||||
pub fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
|
||||
self.peer_rpc_mgr.clone()
|
||||
}
|
||||
|
||||
pub fn my_node_id(&self) -> uuid::Uuid {
|
||||
self.global_ctx.get_id()
|
||||
}
|
||||
|
||||
pub fn my_peer_id(&self) -> PeerId {
|
||||
self.my_peer_id
|
||||
}
|
||||
|
||||
pub fn get_global_ctx(&self) -> ArcGlobalCtx {
|
||||
self.global_ctx.clone()
|
||||
}
|
||||
|
||||
pub fn get_nic_channel(&self) -> mpsc::Sender<SinkItem> {
|
||||
self.nic_channel.clone()
|
||||
}
|
||||
|
||||
pub fn get_basic_route(&self) -> Arc<BasicRoute> {
|
||||
match &self.route_algo_inst {
|
||||
RouteAlgoInst::Rip(route) => route.clone(),
|
||||
_ => panic!("not rip route"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_foreign_network_manager(&self) -> Arc<ForeignNetworkManager> {
|
||||
self.foreign_network_manager.clone()
|
||||
}
|
||||
|
||||
pub fn get_foreign_network_client(&self) -> Arc<ForeignNetworkClient> {
|
||||
self.foreign_network_client.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use crate::{
|
||||
connector::udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun,
|
||||
peers::tests::{connect_peer_manager, wait_for_condition, wait_route_appear},
|
||||
rpc::NatType,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn drop_peer_manager() {
|
||||
let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
|
||||
let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
|
||||
let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
|
||||
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_c.clone()).await;
|
||||
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// wait mgr_a have 2 peers
|
||||
wait_for_condition(
|
||||
|| async { peer_mgr_a.get_peer_map().list_peers_with_conn().await.len() == 2 },
|
||||
std::time::Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
|
||||
drop(peer_mgr_b);
|
||||
|
||||
wait_for_condition(
|
||||
|| async { peer_mgr_a.get_peer_map().list_peers_with_conn().await.len() == 1 },
|
||||
std::time::Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,234 @@
|
||||
use std::{net::Ipv4Addr, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
||||
PeerId,
|
||||
},
|
||||
rpc::PeerConnInfo,
|
||||
tunnels::TunnelError,
|
||||
};
|
||||
|
||||
use super::{
|
||||
peer::Peer,
|
||||
peer_conn::{PeerConn, PeerConnId},
|
||||
route_trait::ArcRoute,
|
||||
};
|
||||
|
||||
pub struct PeerMap {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
my_peer_id: PeerId,
|
||||
peer_map: DashMap<PeerId, Arc<Peer>>,
|
||||
packet_send: mpsc::Sender<Bytes>,
|
||||
routes: RwLock<Vec<ArcRoute>>,
|
||||
}
|
||||
|
||||
impl PeerMap {
|
||||
pub fn new(
|
||||
packet_send: mpsc::Sender<Bytes>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
my_peer_id: PeerId,
|
||||
) -> Self {
|
||||
PeerMap {
|
||||
global_ctx,
|
||||
my_peer_id,
|
||||
peer_map: DashMap::new(),
|
||||
packet_send,
|
||||
routes: RwLock::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn add_new_peer(&self, peer: Peer) {
|
||||
let peer_id = peer.peer_node_id.clone();
|
||||
self.peer_map.insert(peer_id.clone(), Arc::new(peer));
|
||||
self.global_ctx
|
||||
.issue_event(GlobalCtxEvent::PeerAdded(peer_id));
|
||||
}
|
||||
|
||||
pub async fn add_new_peer_conn(&self, peer_conn: PeerConn) {
|
||||
let peer_id = peer_conn.get_peer_id();
|
||||
let no_entry = self.peer_map.get(&peer_id).is_none();
|
||||
if no_entry {
|
||||
let new_peer = Peer::new(peer_id, self.packet_send.clone(), self.global_ctx.clone());
|
||||
new_peer.add_peer_conn(peer_conn).await;
|
||||
self.add_new_peer(new_peer).await;
|
||||
} else {
|
||||
let peer = self.peer_map.get(&peer_id).unwrap().clone();
|
||||
peer.add_peer_conn(peer_conn).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn get_peer_by_id(&self, peer_id: PeerId) -> Option<Arc<Peer>> {
|
||||
self.peer_map.get(&peer_id).map(|v| v.clone())
|
||||
}
|
||||
|
||||
pub fn has_peer(&self, peer_id: PeerId) -> bool {
|
||||
self.peer_map.contains_key(&peer_id)
|
||||
}
|
||||
|
||||
pub async fn send_msg_directly(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
if dst_peer_id == self.my_peer_id {
|
||||
return Ok(self
|
||||
.packet_send
|
||||
.send(msg)
|
||||
.await
|
||||
.with_context(|| "send msg to self failed")?);
|
||||
}
|
||||
|
||||
match self.get_peer_by_id(dst_peer_id) {
|
||||
Some(peer) => {
|
||||
peer.send_msg(msg).await?;
|
||||
}
|
||||
None => {
|
||||
log::error!("no peer for dst_peer_id: {}", dst_peer_id);
|
||||
return Err(Error::RouteError(None));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
if dst_peer_id == self.my_peer_id {
|
||||
return Ok(self
|
||||
.packet_send
|
||||
.send(msg)
|
||||
.await
|
||||
.with_context(|| "send msg to self failed")?);
|
||||
}
|
||||
|
||||
// get route info
|
||||
let mut gateway_peer_id = None;
|
||||
for route in self.routes.read().await.iter() {
|
||||
gateway_peer_id = route.get_next_hop(dst_peer_id).await;
|
||||
if gateway_peer_id.is_none() {
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if gateway_peer_id.is_none() && self.has_peer(dst_peer_id) {
|
||||
gateway_peer_id = Some(dst_peer_id);
|
||||
}
|
||||
|
||||
let Some(gateway_peer_id) = gateway_peer_id else {
|
||||
tracing::trace!(
|
||||
"no gateway for dst_peer_id: {}, peers: {:?}, my_peer_id: {}",
|
||||
dst_peer_id,
|
||||
self.peer_map.iter().map(|v| *v.key()).collect::<Vec<_>>(),
|
||||
self.my_peer_id
|
||||
);
|
||||
return Err(Error::RouteError(None));
|
||||
};
|
||||
|
||||
self.send_msg_directly(msg.clone(), gateway_peer_id).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
pub async fn get_peer_id_by_ipv4(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
|
||||
for route in self.routes.read().await.iter() {
|
||||
let peer_id = route.get_peer_id_by_ipv4(ipv4).await;
|
||||
if peer_id.is_some() {
|
||||
return peer_id;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.peer_map.is_empty()
|
||||
}
|
||||
|
||||
pub async fn list_peers(&self) -> Vec<PeerId> {
|
||||
let mut ret = Vec::new();
|
||||
for item in self.peer_map.iter() {
|
||||
let peer_id = item.key();
|
||||
ret.push(*peer_id);
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
pub async fn list_peers_with_conn(&self) -> Vec<PeerId> {
|
||||
let mut ret = Vec::new();
|
||||
let peers = self.list_peers().await;
|
||||
for peer_id in peers.iter() {
|
||||
let Some(peer) = self.get_peer_by_id(*peer_id) else {
|
||||
continue;
|
||||
};
|
||||
if peer.list_peer_conns().await.len() > 0 {
|
||||
ret.push(*peer_id);
|
||||
}
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
pub async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>> {
|
||||
if let Some(p) = self.get_peer_by_id(peer_id) {
|
||||
Some(p.list_peer_conns().await)
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close_peer_conn(
|
||||
&self,
|
||||
peer_id: PeerId,
|
||||
conn_id: &PeerConnId,
|
||||
) -> Result<(), Error> {
|
||||
if let Some(p) = self.get_peer_by_id(peer_id) {
|
||||
p.close_peer_conn(conn_id).await
|
||||
} else {
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close_peer(&self, peer_id: PeerId) -> Result<(), TunnelError> {
|
||||
let remove_ret = self.peer_map.remove(&peer_id);
|
||||
self.global_ctx
|
||||
.issue_event(GlobalCtxEvent::PeerRemoved(peer_id));
|
||||
tracing::info!(
|
||||
?peer_id,
|
||||
has_old_value = ?remove_ret.is_some(),
|
||||
peer_ref_counter = ?remove_ret.map(|v| Arc::strong_count(&v.1)),
|
||||
"peer is closed"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn add_route(&self, route: ArcRoute) {
|
||||
let mut routes = self.routes.write().await;
|
||||
routes.insert(0, route);
|
||||
}
|
||||
|
||||
pub async fn clean_peer_without_conn(&self) {
|
||||
let mut to_remove = vec![];
|
||||
|
||||
for peer_id in self.list_peers().await {
|
||||
let conns = self.list_peer_conns(peer_id).await;
|
||||
if conns.is_none() || conns.as_ref().unwrap().is_empty() {
|
||||
to_remove.push(peer_id);
|
||||
}
|
||||
}
|
||||
|
||||
for peer_id in to_remove {
|
||||
self.close_peer(peer_id).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_routes(&self) -> DashMap<PeerId, PeerId> {
|
||||
let route_map = DashMap::new();
|
||||
for route in self.routes.read().await.iter() {
|
||||
for item in route.list_routes().await.iter() {
|
||||
route_map.insert(item.peer_id, item.next_hop_peer_id);
|
||||
}
|
||||
}
|
||||
route_map
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,770 @@
|
||||
use std::{
|
||||
net::Ipv4Addr,
|
||||
sync::{atomic::AtomicU32, Arc},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use tokio::{
|
||||
sync::{Mutex, RwLock},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
|
||||
peers::{
|
||||
packet,
|
||||
route_trait::{Route, RouteInterfaceBox},
|
||||
},
|
||||
rpc::{NatType, StunInfo},
|
||||
};
|
||||
|
||||
use super::{packet::CtrlPacketPayload, PeerPacketFilter};
|
||||
|
||||
const SEND_ROUTE_PERIOD_SEC: u64 = 60;
|
||||
const SEND_ROUTE_FAST_REPLY_SEC: u64 = 5;
|
||||
const ROUTE_EXPIRED_SEC: u64 = 70;
|
||||
|
||||
type Version = u32;
|
||||
|
||||
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug, PartialEq)]
|
||||
// Derives can be passed through to the generated type:
|
||||
pub struct SyncPeerInfo {
|
||||
// means next hop in route table.
|
||||
pub peer_id: PeerId,
|
||||
pub cost: u32,
|
||||
pub ipv4_addr: Option<Ipv4Addr>,
|
||||
pub proxy_cidrs: Vec<String>,
|
||||
pub hostname: Option<String>,
|
||||
pub udp_stun_info: i8,
|
||||
}
|
||||
|
||||
impl SyncPeerInfo {
|
||||
pub fn new_self(from_peer: PeerId, global_ctx: &ArcGlobalCtx) -> Self {
|
||||
SyncPeerInfo {
|
||||
peer_id: from_peer,
|
||||
cost: 0,
|
||||
ipv4_addr: global_ctx.get_ipv4(),
|
||||
proxy_cidrs: global_ctx
|
||||
.get_proxy_cidrs()
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.chain(global_ctx.get_vpn_portal_cidr().map(|x| x.to_string()))
|
||||
.collect(),
|
||||
hostname: global_ctx.get_hostname(),
|
||||
udp_stun_info: global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type as i8,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clone_for_route_table(&self, next_hop: PeerId, cost: u32, from: &Self) -> Self {
|
||||
SyncPeerInfo {
|
||||
peer_id: next_hop,
|
||||
cost,
|
||||
ipv4_addr: from.ipv4_addr.clone(),
|
||||
proxy_cidrs: from.proxy_cidrs.clone(),
|
||||
hostname: from.hostname.clone(),
|
||||
udp_stun_info: from.udp_stun_info,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
|
||||
pub struct SyncPeer {
|
||||
pub myself: SyncPeerInfo,
|
||||
pub neighbors: Vec<SyncPeerInfo>,
|
||||
// the route table version of myself
|
||||
pub version: Version,
|
||||
// the route table version of peer that we have received last time
|
||||
pub peer_version: Option<Version>,
|
||||
// if we do not have latest peer version, need_reply is true
|
||||
pub need_reply: bool,
|
||||
}
|
||||
|
||||
impl SyncPeer {
|
||||
pub fn new(
|
||||
from_peer: PeerId,
|
||||
_to_peer: PeerId,
|
||||
neighbors: Vec<SyncPeerInfo>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
version: Version,
|
||||
peer_version: Option<Version>,
|
||||
need_reply: bool,
|
||||
) -> Self {
|
||||
SyncPeer {
|
||||
myself: SyncPeerInfo::new_self(from_peer, &global_ctx),
|
||||
neighbors,
|
||||
version,
|
||||
peer_version,
|
||||
need_reply,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SyncPeerFromRemote {
|
||||
packet: SyncPeer,
|
||||
last_update: std::time::Instant,
|
||||
}
|
||||
|
||||
type SyncPeerFromRemoteMap = Arc<DashMap<PeerId, SyncPeerFromRemote>>;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RouteTable {
|
||||
route_info: DashMap<PeerId, SyncPeerInfo>,
|
||||
ipv4_peer_id_map: DashMap<Ipv4Addr, PeerId>,
|
||||
cidr_peer_id_map: DashMap<cidr::IpCidr, PeerId>,
|
||||
}
|
||||
|
||||
impl RouteTable {
|
||||
fn new() -> Self {
|
||||
RouteTable {
|
||||
route_info: DashMap::new(),
|
||||
ipv4_peer_id_map: DashMap::new(),
|
||||
cidr_peer_id_map: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn copy_from(&self, other: &Self) {
|
||||
self.route_info.clear();
|
||||
for item in other.route_info.iter() {
|
||||
let (k, v) = item.pair();
|
||||
self.route_info.insert(*k, v.clone());
|
||||
}
|
||||
|
||||
self.ipv4_peer_id_map.clear();
|
||||
for item in other.ipv4_peer_id_map.iter() {
|
||||
let (k, v) = item.pair();
|
||||
self.ipv4_peer_id_map.insert(*k, *v);
|
||||
}
|
||||
|
||||
self.cidr_peer_id_map.clear();
|
||||
for item in other.cidr_peer_id_map.iter() {
|
||||
let (k, v) = item.pair();
|
||||
self.cidr_peer_id_map.insert(*k, *v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RouteVersion(Arc<AtomicU32>);
|
||||
|
||||
impl RouteVersion {
|
||||
fn new() -> Self {
|
||||
// RouteVersion(Arc::new(AtomicU32::new(rand::random())))
|
||||
RouteVersion(Arc::new(AtomicU32::new(0)))
|
||||
}
|
||||
|
||||
fn get(&self) -> Version {
|
||||
self.0.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn inc(&self) {
|
||||
self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BasicRoute {
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
interface: Arc<Mutex<Option<RouteInterfaceBox>>>,
|
||||
|
||||
route_table: Arc<RouteTable>,
|
||||
|
||||
sync_peer_from_remote: SyncPeerFromRemoteMap,
|
||||
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
|
||||
need_sync_notifier: Arc<tokio::sync::Notify>,
|
||||
|
||||
version: RouteVersion,
|
||||
myself: Arc<RwLock<SyncPeerInfo>>,
|
||||
last_send_time_map: Arc<DashMap<PeerId, (Version, Option<Version>, Instant)>>,
|
||||
}
|
||||
|
||||
impl BasicRoute {
|
||||
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx) -> Self {
|
||||
BasicRoute {
|
||||
my_peer_id,
|
||||
global_ctx: global_ctx.clone(),
|
||||
interface: Arc::new(Mutex::new(None)),
|
||||
|
||||
route_table: Arc::new(RouteTable::new()),
|
||||
|
||||
sync_peer_from_remote: Arc::new(DashMap::new()),
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
|
||||
need_sync_notifier: Arc::new(tokio::sync::Notify::new()),
|
||||
|
||||
version: RouteVersion::new(),
|
||||
myself: Arc::new(RwLock::new(SyncPeerInfo::new_self(
|
||||
my_peer_id.into(),
|
||||
&global_ctx,
|
||||
))),
|
||||
last_send_time_map: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_route_table(
|
||||
my_id: PeerId,
|
||||
sync_peer_reqs: SyncPeerFromRemoteMap,
|
||||
route_table: Arc<RouteTable>,
|
||||
) {
|
||||
tracing::trace!(my_id = ?my_id, route_table = ?route_table, "update route table");
|
||||
|
||||
let new_route_table = Arc::new(RouteTable::new());
|
||||
for item in sync_peer_reqs.iter() {
|
||||
Self::update_route_table_with_req(my_id, &item.value().packet, new_route_table.clone());
|
||||
}
|
||||
|
||||
route_table.copy_from(&new_route_table);
|
||||
}
|
||||
|
||||
async fn update_myself(
|
||||
my_peer_id: PeerId,
|
||||
myself: &Arc<RwLock<SyncPeerInfo>>,
|
||||
global_ctx: &ArcGlobalCtx,
|
||||
) -> bool {
|
||||
let new_myself = SyncPeerInfo::new_self(my_peer_id, &global_ctx);
|
||||
if *myself.read().await != new_myself {
|
||||
*myself.write().await = new_myself;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn update_route_table_with_req(my_id: PeerId, packet: &SyncPeer, route_table: Arc<RouteTable>) {
|
||||
let peer_id = packet.myself.peer_id.clone();
|
||||
let update = |cost: u32, peer_info: &SyncPeerInfo| {
|
||||
let node_id: PeerId = peer_info.peer_id.into();
|
||||
let ret = route_table
|
||||
.route_info
|
||||
.entry(node_id.clone().into())
|
||||
.and_modify(|info| {
|
||||
if info.cost > cost {
|
||||
*info = info.clone_for_route_table(peer_id, cost, &peer_info);
|
||||
}
|
||||
})
|
||||
.or_insert(
|
||||
peer_info
|
||||
.clone()
|
||||
.clone_for_route_table(peer_id, cost, &peer_info),
|
||||
)
|
||||
.value()
|
||||
.clone();
|
||||
|
||||
if ret.cost > 6 {
|
||||
log::error!(
|
||||
"cost too large: {}, may lost connection, remove it",
|
||||
ret.cost
|
||||
);
|
||||
route_table.route_info.remove(&node_id);
|
||||
}
|
||||
|
||||
log::trace!(
|
||||
"update route info, to: {:?}, gateway: {:?}, cost: {}, peer: {:?}",
|
||||
node_id,
|
||||
peer_id,
|
||||
cost,
|
||||
&peer_info
|
||||
);
|
||||
|
||||
if let Some(ipv4) = peer_info.ipv4_addr {
|
||||
route_table
|
||||
.ipv4_peer_id_map
|
||||
.insert(ipv4.clone(), node_id.clone().into());
|
||||
}
|
||||
|
||||
for cidr in peer_info.proxy_cidrs.iter() {
|
||||
let cidr: cidr::IpCidr = cidr.parse().unwrap();
|
||||
route_table
|
||||
.cidr_peer_id_map
|
||||
.insert(cidr, node_id.clone().into());
|
||||
}
|
||||
};
|
||||
|
||||
for neighbor in packet.neighbors.iter() {
|
||||
if neighbor.peer_id == my_id {
|
||||
continue;
|
||||
}
|
||||
update(neighbor.cost + 1, &neighbor);
|
||||
log::trace!("route info: {:?}", neighbor);
|
||||
}
|
||||
|
||||
// add the sender peer to route info
|
||||
update(1, &packet.myself);
|
||||
|
||||
log::trace!("my_id: {:?}, current route table: {:?}", my_id, route_table);
|
||||
}
|
||||
|
||||
async fn send_sync_peer_request(
|
||||
interface: &RouteInterfaceBox,
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_id: PeerId,
|
||||
route_table: Arc<RouteTable>,
|
||||
my_version: Version,
|
||||
peer_version: Option<Version>,
|
||||
need_reply: bool,
|
||||
) -> Result<(), Error> {
|
||||
let mut route_info_copy: Vec<SyncPeerInfo> = Vec::new();
|
||||
// copy the route info
|
||||
for item in route_table.route_info.iter() {
|
||||
let (k, v) = item.pair();
|
||||
route_info_copy.push(v.clone().clone_for_route_table(*k, v.cost, &v));
|
||||
}
|
||||
let msg = SyncPeer::new(
|
||||
my_peer_id,
|
||||
peer_id,
|
||||
route_info_copy,
|
||||
global_ctx,
|
||||
my_version,
|
||||
peer_version,
|
||||
need_reply,
|
||||
);
|
||||
// TODO: this may exceed the MTU of the tunnel
|
||||
interface
|
||||
.send_route_packet(postcard::to_allocvec(&msg).unwrap().into(), 1, peer_id)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn sync_peer_periodically(&self) {
|
||||
let route_table = self.route_table.clone();
|
||||
let global_ctx = self.global_ctx.clone();
|
||||
let my_peer_id = self.my_peer_id.clone();
|
||||
let interface = self.interface.clone();
|
||||
let notifier = self.need_sync_notifier.clone();
|
||||
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
|
||||
let myself = self.myself.clone();
|
||||
let version = self.version.clone();
|
||||
let last_send_time_map = self.last_send_time_map.clone();
|
||||
self.tasks.lock().await.spawn(
|
||||
async move {
|
||||
loop {
|
||||
if Self::update_myself(my_peer_id,&myself, &global_ctx).await {
|
||||
version.inc();
|
||||
tracing::info!(
|
||||
my_id = ?my_peer_id,
|
||||
version = version.get(),
|
||||
"update route table version when myself changed"
|
||||
);
|
||||
}
|
||||
|
||||
let lockd_interface = interface.lock().await;
|
||||
let interface = lockd_interface.as_ref().unwrap();
|
||||
let last_send_time_map_new = DashMap::new();
|
||||
let peers = interface.list_peers().await;
|
||||
for peer in peers.iter() {
|
||||
let last_send_time = last_send_time_map.get(peer).map(|v| *v).unwrap_or((0, None, Instant::now() - Duration::from_secs(3600)));
|
||||
let my_version_peer_saved = sync_peer_from_remote.get(peer).and_then(|v| v.packet.peer_version);
|
||||
let peer_have_latest_version = my_version_peer_saved == Some(version.get());
|
||||
if peer_have_latest_version && last_send_time.2.elapsed().as_secs() < SEND_ROUTE_PERIOD_SEC {
|
||||
last_send_time_map_new.insert(*peer, last_send_time);
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::trace!(
|
||||
my_id = ?my_peer_id,
|
||||
dst_peer_id = ?peer,
|
||||
version = version.get(),
|
||||
?my_version_peer_saved,
|
||||
last_send_version = ?last_send_time.0,
|
||||
last_send_peer_version = ?last_send_time.1,
|
||||
last_send_elapse = ?last_send_time.2.elapsed().as_secs(),
|
||||
"need send route info"
|
||||
);
|
||||
let peer_version_we_saved = sync_peer_from_remote.get(&peer).and_then(|v| Some(v.packet.version));
|
||||
last_send_time_map_new.insert(*peer, (version.get(), peer_version_we_saved, Instant::now()));
|
||||
let ret = Self::send_sync_peer_request(
|
||||
interface,
|
||||
my_peer_id.clone(),
|
||||
global_ctx.clone(),
|
||||
*peer,
|
||||
route_table.clone(),
|
||||
version.get(),
|
||||
peer_version_we_saved,
|
||||
!peer_have_latest_version,
|
||||
)
|
||||
.await;
|
||||
|
||||
match &ret {
|
||||
Ok(_) => {
|
||||
log::trace!("send sync peer request to peer: {}", peer);
|
||||
}
|
||||
Err(Error::PeerNoConnectionError(_)) => {
|
||||
log::trace!("peer {} no connection", peer);
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
"send sync peer request to peer: {} error: {:?}",
|
||||
peer,
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
last_send_time_map.clear();
|
||||
for item in last_send_time_map_new.iter() {
|
||||
let (k, v) = item.pair();
|
||||
last_send_time_map.insert(*k, *v);
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
_ = notifier.notified() => {
|
||||
log::trace!("sync peer request triggered by notifier");
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_secs(1)) => {
|
||||
log::trace!("sync peer request triggered by timeout");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("sync_peer_periodically", my_id = ?self.my_peer_id, global_ctx = ?self.global_ctx),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
async fn check_expired_sync_peer_from_remote(&self) {
|
||||
let route_table = self.route_table.clone();
|
||||
let my_peer_id = self.my_peer_id.clone();
|
||||
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
|
||||
let notifier = self.need_sync_notifier.clone();
|
||||
let interface = self.interface.clone();
|
||||
let version = self.version.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
loop {
|
||||
let mut need_update_route = false;
|
||||
let now = std::time::Instant::now();
|
||||
let mut need_remove = Vec::new();
|
||||
let connected_peers = interface.lock().await.as_ref().unwrap().list_peers().await;
|
||||
for item in sync_peer_from_remote.iter() {
|
||||
let (k, v) = item.pair();
|
||||
if now.duration_since(v.last_update).as_secs() > ROUTE_EXPIRED_SEC
|
||||
|| !connected_peers.contains(k)
|
||||
{
|
||||
need_update_route = true;
|
||||
need_remove.insert(0, k.clone());
|
||||
}
|
||||
}
|
||||
|
||||
for k in need_remove.iter() {
|
||||
log::warn!("remove expired sync peer: {:?}", k);
|
||||
sync_peer_from_remote.remove(k);
|
||||
}
|
||||
|
||||
if need_update_route {
|
||||
Self::update_route_table(
|
||||
my_peer_id,
|
||||
sync_peer_from_remote.clone(),
|
||||
route_table.clone(),
|
||||
);
|
||||
version.inc();
|
||||
tracing::info!(
|
||||
my_id = ?my_peer_id,
|
||||
version = version.get(),
|
||||
"update route table when check expired peer"
|
||||
);
|
||||
notifier.notify_one();
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn get_peer_id_for_proxy(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
|
||||
let ipv4 = std::net::IpAddr::V4(*ipv4);
|
||||
for item in self.route_table.cidr_peer_id_map.iter() {
|
||||
let (k, v) = item.pair();
|
||||
if k.contains(&ipv4) {
|
||||
return Some(*v);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, packet), fields(my_id = ?self.my_peer_id, ctx = ?self.global_ctx))]
|
||||
async fn handle_route_packet(&self, src_peer_id: PeerId, packet: Bytes) {
|
||||
let packet = postcard::from_bytes::<SyncPeer>(&packet).unwrap();
|
||||
let p = &packet;
|
||||
let mut updated = true;
|
||||
assert_eq!(packet.myself.peer_id, src_peer_id);
|
||||
self.sync_peer_from_remote
|
||||
.entry(packet.myself.peer_id.into())
|
||||
.and_modify(|v| {
|
||||
if v.packet.myself == p.myself && v.packet.neighbors == p.neighbors {
|
||||
updated = false;
|
||||
} else {
|
||||
v.packet = p.clone();
|
||||
}
|
||||
v.packet.version = p.version;
|
||||
v.packet.peer_version = p.peer_version;
|
||||
v.last_update = std::time::Instant::now();
|
||||
})
|
||||
.or_insert(SyncPeerFromRemote {
|
||||
packet: p.clone(),
|
||||
last_update: std::time::Instant::now(),
|
||||
});
|
||||
|
||||
if updated {
|
||||
Self::update_route_table(
|
||||
self.my_peer_id.clone(),
|
||||
self.sync_peer_from_remote.clone(),
|
||||
self.route_table.clone(),
|
||||
);
|
||||
self.version.inc();
|
||||
tracing::info!(
|
||||
my_id = ?self.my_peer_id,
|
||||
?p,
|
||||
version = self.version.get(),
|
||||
"update route table when receive route packet"
|
||||
);
|
||||
}
|
||||
|
||||
if packet.need_reply {
|
||||
self.last_send_time_map
|
||||
.entry(packet.myself.peer_id.into())
|
||||
.and_modify(|v| {
|
||||
const FAST_REPLY_DURATION: u64 =
|
||||
SEND_ROUTE_PERIOD_SEC - SEND_ROUTE_FAST_REPLY_SEC;
|
||||
if v.0 != self.version.get() || v.1 != Some(p.version) {
|
||||
v.2 = Instant::now() - Duration::from_secs(3600);
|
||||
} else if v.2.elapsed().as_secs() < FAST_REPLY_DURATION {
|
||||
// do not send same version route info too frequently
|
||||
v.2 = Instant::now() - Duration::from_secs(FAST_REPLY_DURATION);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if updated || packet.need_reply {
|
||||
self.need_sync_notifier.notify_one();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Route for BasicRoute {
|
||||
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()> {
|
||||
*self.interface.lock().await = Some(interface);
|
||||
self.sync_peer_periodically().await;
|
||||
self.check_expired_sync_peer_from_remote().await;
|
||||
Ok(1)
|
||||
}
|
||||
|
||||
async fn close(&self) {}
|
||||
|
||||
async fn get_next_hop(&self, dst_peer_id: PeerId) -> Option<PeerId> {
|
||||
match self.route_table.route_info.get(&dst_peer_id) {
|
||||
Some(info) => {
|
||||
return Some(info.peer_id.clone().into());
|
||||
}
|
||||
None => {
|
||||
log::error!("no route info for dst_peer_id: {}", dst_peer_id);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_routes(&self) -> Vec<crate::rpc::Route> {
|
||||
let mut routes = Vec::new();
|
||||
|
||||
let parse_route_info = |real_peer_id: PeerId, route_info: &SyncPeerInfo| {
|
||||
let mut route = crate::rpc::Route::default();
|
||||
route.ipv4_addr = if let Some(ipv4_addr) = route_info.ipv4_addr {
|
||||
ipv4_addr.to_string()
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
route.peer_id = real_peer_id;
|
||||
route.next_hop_peer_id = route_info.peer_id;
|
||||
route.cost = route_info.cost as i32;
|
||||
route.proxy_cidrs = route_info.proxy_cidrs.clone();
|
||||
route.hostname = if let Some(hostname) = &route_info.hostname {
|
||||
hostname.clone()
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
let mut stun_info = StunInfo::default();
|
||||
if let Ok(udp_nat_type) = NatType::try_from(route_info.udp_stun_info as i32) {
|
||||
stun_info.set_udp_nat_type(udp_nat_type);
|
||||
}
|
||||
route.stun_info = Some(stun_info);
|
||||
|
||||
route
|
||||
};
|
||||
|
||||
self.route_table.route_info.iter().for_each(|item| {
|
||||
routes.push(parse_route_info(*item.key(), item.value()));
|
||||
});
|
||||
|
||||
routes
|
||||
}
|
||||
|
||||
async fn get_peer_id_by_ipv4(&self, ipv4_addr: &Ipv4Addr) -> Option<PeerId> {
|
||||
if let Some(peer_id) = self.route_table.ipv4_peer_id_map.get(ipv4_addr) {
|
||||
return Some(*peer_id);
|
||||
}
|
||||
|
||||
if let Some(peer_id) = self.get_peer_id_for_proxy(ipv4_addr) {
|
||||
return Some(peer_id);
|
||||
}
|
||||
|
||||
log::info!("no peer id for ipv4: {}", ipv4_addr);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for BasicRoute {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
_data: &Bytes,
|
||||
) -> Option<()> {
|
||||
if packet.packet_type == packet::PacketType::RoutePacket {
|
||||
let CtrlPacketPayload::RoutePacket(route_packet) =
|
||||
CtrlPacketPayload::from_packet(packet)
|
||||
else {
|
||||
return None;
|
||||
};
|
||||
|
||||
self.handle_route_packet(
|
||||
packet.from_peer.into(),
|
||||
route_packet.body.into_boxed_slice().into(),
|
||||
)
|
||||
.await;
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
common::{global_ctx::tests::get_mock_global_ctx, PeerId},
|
||||
connector::udp_hole_punch::tests::replace_stun_info_collector,
|
||||
peers::{
|
||||
peer_manager::{PeerManager, RouteAlgoType},
|
||||
peer_rip_route::Version,
|
||||
tests::{connect_peer_manager, wait_route_appear},
|
||||
},
|
||||
rpc::NatType,
|
||||
};
|
||||
|
||||
async fn create_mock_pmgr() -> Arc<PeerManager> {
|
||||
let (s, _r) = tokio::sync::mpsc::channel(1000);
|
||||
let peer_mgr = Arc::new(PeerManager::new(
|
||||
RouteAlgoType::Rip,
|
||||
get_mock_global_ctx(),
|
||||
s,
|
||||
));
|
||||
replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
|
||||
peer_mgr.run().await.unwrap();
|
||||
peer_mgr
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rip_route() {
|
||||
let peer_mgr_a = create_mock_pmgr().await;
|
||||
let peer_mgr_b = create_mock_pmgr().await;
|
||||
let peer_mgr_c = create_mock_pmgr().await;
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
|
||||
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mgrs = vec![peer_mgr_a.clone(), peer_mgr_b.clone(), peer_mgr_c.clone()];
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(4)).await;
|
||||
|
||||
let check_version = |version: Version, peer_id: PeerId, mgrs: &Vec<Arc<PeerManager>>| {
|
||||
for mgr in mgrs.iter() {
|
||||
tracing::warn!(
|
||||
"check version: {:?}, {:?}, {:?}, {:?}",
|
||||
version,
|
||||
peer_id,
|
||||
mgr,
|
||||
mgr.get_basic_route().sync_peer_from_remote
|
||||
);
|
||||
assert_eq!(
|
||||
version,
|
||||
mgr.get_basic_route()
|
||||
.sync_peer_from_remote
|
||||
.get(&peer_id)
|
||||
.unwrap()
|
||||
.packet
|
||||
.version,
|
||||
);
|
||||
assert_eq!(
|
||||
mgr.get_basic_route()
|
||||
.sync_peer_from_remote
|
||||
.get(&peer_id)
|
||||
.unwrap()
|
||||
.packet
|
||||
.peer_version
|
||||
.unwrap(),
|
||||
mgr.get_basic_route().version.get()
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let check_sanity = || {
|
||||
// check peer version in other peer mgr are correct.
|
||||
check_version(
|
||||
peer_mgr_b.get_basic_route().version.get(),
|
||||
peer_mgr_b.my_peer_id(),
|
||||
&vec![peer_mgr_a.clone(), peer_mgr_c.clone()],
|
||||
);
|
||||
|
||||
check_version(
|
||||
peer_mgr_a.get_basic_route().version.get(),
|
||||
peer_mgr_a.my_peer_id(),
|
||||
&vec![peer_mgr_b.clone()],
|
||||
);
|
||||
|
||||
check_version(
|
||||
peer_mgr_c.get_basic_route().version.get(),
|
||||
peer_mgr_c.my_peer_id(),
|
||||
&vec![peer_mgr_b.clone()],
|
||||
);
|
||||
};
|
||||
|
||||
check_sanity();
|
||||
|
||||
let versions = mgrs
|
||||
.iter()
|
||||
.map(|x| x.get_basic_route().version.get())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
|
||||
let versions2 = mgrs
|
||||
.iter()
|
||||
.map(|x| x.get_basic_route().version.get())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(versions, versions2);
|
||||
check_sanity();
|
||||
|
||||
assert!(peer_mgr_a.get_basic_route().version.get() <= 3);
|
||||
assert!(peer_mgr_b.get_basic_route().version.get() <= 6);
|
||||
assert!(peer_mgr_c.get_basic_route().version.get() <= 3);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,581 @@
|
||||
use std::sync::{atomic::AtomicU32, Arc};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use rkyv::Deserialize;
|
||||
use tarpc::{server::Channel, transport::channel::UnboundedChannel};
|
||||
use tokio::{
|
||||
sync::mpsc::{self, UnboundedSender},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, PeerId},
|
||||
peers::packet::Packet,
|
||||
};
|
||||
|
||||
use super::packet::CtrlPacketPayload;
|
||||
|
||||
type PeerRpcServiceId = u32;
|
||||
type PeerRpcTransactId = u32;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait PeerRpcManagerTransport: Send + Sync + 'static {
|
||||
fn my_peer_id(&self) -> PeerId;
|
||||
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error>;
|
||||
async fn recv(&self) -> Result<Bytes, Error>;
|
||||
}
|
||||
|
||||
type PacketSender = UnboundedSender<Packet>;
|
||||
|
||||
struct PeerRpcEndPoint {
|
||||
peer_id: PeerId,
|
||||
packet_sender: PacketSender,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
type PeerRpcEndPointCreator = Box<dyn Fn(PeerId) -> PeerRpcEndPoint + Send + Sync + 'static>;
|
||||
#[derive(Hash, Eq, PartialEq, Clone)]
|
||||
struct PeerRpcClientCtxKey(PeerId, PeerRpcServiceId, PeerRpcTransactId);
|
||||
|
||||
// handle rpc request from one peer
|
||||
pub struct PeerRpcManager {
|
||||
service_map: Arc<DashMap<PeerRpcServiceId, PacketSender>>,
|
||||
tasks: JoinSet<()>,
|
||||
tspt: Arc<Box<dyn PeerRpcManagerTransport>>,
|
||||
|
||||
service_registry: Arc<DashMap<PeerRpcServiceId, PeerRpcEndPointCreator>>,
|
||||
peer_rpc_endpoints: Arc<DashMap<(PeerId, PeerRpcServiceId), PeerRpcEndPoint>>,
|
||||
|
||||
client_resp_receivers: Arc<DashMap<PeerRpcClientCtxKey, PacketSender>>,
|
||||
|
||||
transact_id: AtomicU32,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PeerRpcManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerRpcManager")
|
||||
.field("node_id", &self.tspt.my_peer_id())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TaRpcPacketInfo {
|
||||
from_peer: PeerId,
|
||||
to_peer: PeerId,
|
||||
service_id: PeerRpcServiceId,
|
||||
transact_id: PeerRpcTransactId,
|
||||
is_req: bool,
|
||||
content: Vec<u8>,
|
||||
}
|
||||
|
||||
impl PeerRpcManager {
|
||||
pub fn new(tspt: impl PeerRpcManagerTransport) -> Self {
|
||||
Self {
|
||||
service_map: Arc::new(DashMap::new()),
|
||||
tasks: JoinSet::new(),
|
||||
tspt: Arc::new(Box::new(tspt)),
|
||||
|
||||
service_registry: Arc::new(DashMap::new()),
|
||||
peer_rpc_endpoints: Arc::new(DashMap::new()),
|
||||
|
||||
client_resp_receivers: Arc::new(DashMap::new()),
|
||||
|
||||
transact_id: AtomicU32::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_service<S, Req>(self: &Self, service_id: PeerRpcServiceId, s: S) -> ()
|
||||
where
|
||||
S: tarpc::server::Serve<Req> + Clone + Send + Sync + 'static,
|
||||
Req: Send + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
|
||||
S::Resp:
|
||||
Send + std::fmt::Debug + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
|
||||
S::Fut: Send + 'static,
|
||||
{
|
||||
let tspt = self.tspt.clone();
|
||||
let creator = Box::new(move |peer_id: PeerId| {
|
||||
let mut tasks = JoinSet::new();
|
||||
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
|
||||
let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
let server = tarpc::server::BaseChannel::with_defaults(server_transport);
|
||||
|
||||
let my_peer_id_clone = tspt.my_peer_id();
|
||||
let peer_id_clone = peer_id.clone();
|
||||
|
||||
let o = server.execute(s.clone());
|
||||
tasks.spawn(o);
|
||||
|
||||
let tspt = tspt.clone();
|
||||
tasks.spawn(async move {
|
||||
let mut cur_req_peer_id = None;
|
||||
let mut cur_transact_id = 0;
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(resp) = client_transport.next() => {
|
||||
let Some(cur_req_peer_id) = cur_req_peer_id.take() else {
|
||||
tracing::error!("[PEER RPC MGR] cur_req_peer_id is none, ignore this resp");
|
||||
continue;
|
||||
};
|
||||
|
||||
tracing::trace!(resp = ?resp, "recv packet from client");
|
||||
if resp.is_err() {
|
||||
tracing::warn!(err = ?resp.err(),
|
||||
"[PEER RPC MGR] client_transport in server side got channel error, ignore it.");
|
||||
continue;
|
||||
}
|
||||
let resp = resp.unwrap();
|
||||
|
||||
let serialized_resp = postcard::to_allocvec(&resp);
|
||||
if serialized_resp.is_err() {
|
||||
tracing::error!(error = ?serialized_resp.err(), "serialize resp failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
let msg = Packet::new_tarpc_packet(
|
||||
tspt.my_peer_id(),
|
||||
cur_req_peer_id,
|
||||
service_id,
|
||||
cur_transact_id,
|
||||
false,
|
||||
serialized_resp.unwrap(),
|
||||
);
|
||||
|
||||
if let Err(e) = tspt.send(msg.into(), peer_id).await {
|
||||
tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed");
|
||||
}
|
||||
}
|
||||
Some(packet) = packet_receiver.recv() => {
|
||||
let info = Self::parse_rpc_packet(&packet);
|
||||
if let Err(e) = info {
|
||||
tracing::error!(error = ?e, packet = ?packet, "parse rpc packet failed");
|
||||
continue;
|
||||
}
|
||||
let info = info.unwrap();
|
||||
|
||||
if info.from_peer != peer_id {
|
||||
tracing::warn!("recv packet from peer, but peer_id not match, ignore it");
|
||||
continue;
|
||||
}
|
||||
|
||||
if cur_req_peer_id.is_some() {
|
||||
tracing::warn!("cur_req_peer_id is not none, ignore this packet");
|
||||
continue;
|
||||
}
|
||||
|
||||
assert_eq!(info.service_id, service_id);
|
||||
cur_req_peer_id = Some(packet.from_peer.clone().into());
|
||||
cur_transact_id = info.transact_id;
|
||||
|
||||
tracing::trace!("recv packet from peer, packet: {:?}", packet);
|
||||
|
||||
let decoded_ret = postcard::from_bytes(&info.content.as_slice());
|
||||
if let Err(e) = decoded_ret {
|
||||
tracing::error!(error = ?e, "decode rpc packet failed");
|
||||
continue;
|
||||
}
|
||||
let decoded: tarpc::ClientMessage<Req> = decoded_ret.unwrap();
|
||||
|
||||
if let Err(e) = client_transport.send(decoded).await {
|
||||
tracing::error!(error = ?e, "send to req to client transport failed");
|
||||
}
|
||||
}
|
||||
else => {
|
||||
tracing::warn!("[PEER RPC MGR] service runner destroy, peer_id: {}, service_id: {}", peer_id, service_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}.instrument(tracing::info_span!("service_runner", my_id = ?my_peer_id_clone, peer_id = ?peer_id_clone, service_id = ?service_id)));
|
||||
|
||||
tracing::info!(
|
||||
"[PEER RPC MGR] create new service endpoint for peer {}, service {}",
|
||||
peer_id,
|
||||
service_id
|
||||
);
|
||||
|
||||
return PeerRpcEndPoint {
|
||||
peer_id,
|
||||
packet_sender,
|
||||
tasks,
|
||||
};
|
||||
// let resp = client_transport.next().await;
|
||||
});
|
||||
|
||||
if let Some(_) = self.service_registry.insert(service_id, creator) {
|
||||
panic!(
|
||||
"[PEER RPC MGR] service {} is already registered",
|
||||
service_id
|
||||
);
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"[PEER RPC MGR] register service {} succeed, my_node_id {}",
|
||||
service_id,
|
||||
self.tspt.my_peer_id()
|
||||
)
|
||||
}
|
||||
|
||||
fn parse_rpc_packet(packet: &Packet) -> Result<TaRpcPacketInfo, Error> {
|
||||
let ctrl_packet_payload = CtrlPacketPayload::from_packet2(&packet);
|
||||
match &ctrl_packet_payload {
|
||||
CtrlPacketPayload::TaRpc(id, tid, is_req, body) => Ok(TaRpcPacketInfo {
|
||||
from_peer: packet.from_peer.into(),
|
||||
to_peer: packet.to_peer.into(),
|
||||
service_id: *id,
|
||||
transact_id: *tid,
|
||||
is_req: *is_req,
|
||||
content: body.clone(),
|
||||
}),
|
||||
_ => Err(Error::ShellCommandError("invalid packet".to_owned())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(&self) {
|
||||
let tspt = self.tspt.clone();
|
||||
let service_registry = self.service_registry.clone();
|
||||
let peer_rpc_endpoints = self.peer_rpc_endpoints.clone();
|
||||
let client_resp_receivers = self.client_resp_receivers.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let Ok(o) = tspt.recv().await else {
|
||||
tracing::warn!("peer rpc transport read aborted, exiting");
|
||||
break;
|
||||
};
|
||||
let packet = Packet::decode(&o);
|
||||
let packet: Packet = packet.deserialize(&mut rkyv::Infallible).unwrap();
|
||||
let info = Self::parse_rpc_packet(&packet).unwrap();
|
||||
|
||||
if info.is_req {
|
||||
if !service_registry.contains_key(&info.service_id) {
|
||||
log::warn!(
|
||||
"service {} not found, my_node_id: {}",
|
||||
info.service_id,
|
||||
tspt.my_peer_id()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let endpoint = peer_rpc_endpoints
|
||||
.entry((info.from_peer, info.service_id))
|
||||
.or_insert_with(|| {
|
||||
service_registry.get(&info.service_id).unwrap()(info.from_peer)
|
||||
});
|
||||
|
||||
endpoint.packet_sender.send(packet).unwrap();
|
||||
} else {
|
||||
if let Some(a) = client_resp_receivers.get(&PeerRpcClientCtxKey(
|
||||
info.from_peer,
|
||||
info.service_id,
|
||||
info.transact_id,
|
||||
)) {
|
||||
log::trace!("recv resp: {:?}", packet);
|
||||
if let Err(e) = a.send(packet) {
|
||||
tracing::error!(error = ?e, "send resp to client failed");
|
||||
}
|
||||
} else {
|
||||
log::warn!("client resp receiver not found, info: {:?}", info);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(f))]
|
||||
pub async fn do_client_rpc_scoped<CM, Req, RpcRet, Fut>(
|
||||
&self,
|
||||
service_id: PeerRpcServiceId,
|
||||
dst_peer_id: PeerId,
|
||||
f: impl FnOnce(UnboundedChannel<CM, Req>) -> Fut,
|
||||
) -> RpcRet
|
||||
where
|
||||
CM: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
|
||||
Req: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
|
||||
Fut: std::future::Future<Output = RpcRet>,
|
||||
{
|
||||
let mut tasks = JoinSet::new();
|
||||
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
|
||||
let (client_transport, server_transport) =
|
||||
tarpc::transport::channel::unbounded::<CM, Req>();
|
||||
|
||||
let (mut server_s, mut server_r) = server_transport.split();
|
||||
|
||||
let transact_id = self
|
||||
.transact_id
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
let tspt = self.tspt.clone();
|
||||
tasks.spawn(async move {
|
||||
while let Some(a) = server_r.next().await {
|
||||
if a.is_err() {
|
||||
tracing::error!(error = ?a.err(), "channel error");
|
||||
continue;
|
||||
}
|
||||
|
||||
let a = postcard::to_allocvec(&a.unwrap());
|
||||
if a.is_err() {
|
||||
tracing::error!(error = ?a.err(), "bincode serialize failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
let a = Packet::new_tarpc_packet(
|
||||
tspt.my_peer_id(),
|
||||
dst_peer_id,
|
||||
service_id,
|
||||
transact_id,
|
||||
true,
|
||||
a.unwrap(),
|
||||
);
|
||||
|
||||
if let Err(e) = tspt.send(a.into(), dst_peer_id).await {
|
||||
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!("[PEER RPC MGR] server trasport read aborted");
|
||||
});
|
||||
|
||||
tasks.spawn(async move {
|
||||
while let Some(packet) = packet_receiver.recv().await {
|
||||
tracing::trace!("tunnel recv: {:?}", packet);
|
||||
|
||||
let info = PeerRpcManager::parse_rpc_packet(&packet);
|
||||
if let Err(e) = info {
|
||||
tracing::error!(error = ?e, "parse rpc packet failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
let decoded = postcard::from_bytes(&info.unwrap().content.as_slice());
|
||||
if let Err(e) = decoded {
|
||||
tracing::error!(error = ?e, "decode rpc packet failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = server_s.send(decoded.unwrap()).await {
|
||||
tracing::error!(error = ?e, "send to rpc server channel failed");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!("[PEER RPC MGR] server packet read aborted");
|
||||
});
|
||||
|
||||
let key = PeerRpcClientCtxKey(dst_peer_id, service_id, transact_id);
|
||||
let _insert_ret = self
|
||||
.client_resp_receivers
|
||||
.insert(key.clone(), packet_sender);
|
||||
|
||||
let ret = f(client_transport).await;
|
||||
|
||||
self.client_resp_receivers.remove(&key);
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn my_peer_id(&self) -> PeerId {
|
||||
self.tspt.my_peer_id()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, new_peer_id, PeerId},
|
||||
peers::{
|
||||
peer_rpc::PeerRpcManager,
|
||||
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
|
||||
},
|
||||
tunnels::{self, ring_tunnel::create_ring_tunnel_pair},
|
||||
};
|
||||
|
||||
use super::PeerRpcManagerTransport;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait TestRpcService {
|
||||
async fn hello(s: String) -> String;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MockService {
|
||||
prefix: String,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl TestRpcService for MockService {
|
||||
async fn hello(self, _: tarpc::context::Context, s: String) -> String {
|
||||
format!("{} {}", self.prefix, s)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_rpc_basic_test() {
|
||||
struct MockTransport {
|
||||
tunnel: Box<dyn tunnels::Tunnel>,
|
||||
my_peer_id: PeerId,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerRpcManagerTransport for MockTransport {
|
||||
fn my_peer_id(&self) -> PeerId {
|
||||
self.my_peer_id
|
||||
}
|
||||
async fn send(&self, msg: Bytes, _dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
println!("rpc mgr send: {:?}", msg);
|
||||
self.tunnel.pin_sink().send(msg).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
async fn recv(&self) -> Result<Bytes, Error> {
|
||||
let ret = self.tunnel.pin_stream().next().await.unwrap();
|
||||
println!("rpc mgr recv: {:?}", ret);
|
||||
return ret.map(|v| v.freeze()).map_err(|_| Error::Unknown);
|
||||
}
|
||||
}
|
||||
|
||||
let (ct, st) = create_ring_tunnel_pair();
|
||||
|
||||
let server_rpc_mgr = PeerRpcManager::new(MockTransport {
|
||||
tunnel: st,
|
||||
my_peer_id: new_peer_id(),
|
||||
});
|
||||
server_rpc_mgr.run();
|
||||
let s = MockService {
|
||||
prefix: "hello".to_owned(),
|
||||
};
|
||||
server_rpc_mgr.run_service(1, s.serve());
|
||||
|
||||
let client_rpc_mgr = PeerRpcManager::new(MockTransport {
|
||||
tunnel: ct,
|
||||
my_peer_id: new_peer_id(),
|
||||
});
|
||||
client_rpc_mgr.run();
|
||||
|
||||
let ret = client_rpc_mgr
|
||||
.do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
|
||||
println!("ret: {:?}", ret);
|
||||
assert_eq!(ret.unwrap(), "hello abc");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rpc_with_peer_manager() {
|
||||
let peer_mgr_a = create_mock_peer_manager().await;
|
||||
let peer_mgr_b = create_mock_peer_manager().await;
|
||||
let peer_mgr_c = create_mock_peer_manager().await;
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
|
||||
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
|
||||
assert_eq!(
|
||||
peer_mgr_a.get_peer_map().list_peers().await[0],
|
||||
peer_mgr_b.my_peer_id()
|
||||
);
|
||||
|
||||
assert_eq!(peer_mgr_c.get_peer_map().list_peers().await.len(), 1);
|
||||
assert_eq!(
|
||||
peer_mgr_c.get_peer_map().list_peers().await[0],
|
||||
peer_mgr_b.my_peer_id()
|
||||
);
|
||||
|
||||
let s = MockService {
|
||||
prefix: "hello".to_owned(),
|
||||
};
|
||||
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
|
||||
|
||||
let ip_list = peer_mgr_a
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
println!("ip_list: {:?}", ip_list);
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello abc");
|
||||
|
||||
// call again
|
||||
let ip_list = peer_mgr_a
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abcd".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
println!("ip_list: {:?}", ip_list);
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello abcd");
|
||||
|
||||
let ip_list = peer_mgr_c
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "bcd".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
println!("ip_list: {:?}", ip_list);
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello bcd");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_service_with_peer_manager() {
|
||||
let peer_mgr_a = create_mock_peer_manager().await;
|
||||
let peer_mgr_b = create_mock_peer_manager().await;
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
|
||||
assert_eq!(
|
||||
peer_mgr_a.get_peer_map().list_peers().await[0],
|
||||
peer_mgr_b.my_peer_id()
|
||||
);
|
||||
|
||||
let s = MockService {
|
||||
prefix: "hello_a".to_owned(),
|
||||
};
|
||||
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
|
||||
let b = MockService {
|
||||
prefix: "hello_b".to_owned(),
|
||||
};
|
||||
peer_mgr_b.get_peer_rpc_mgr().run_service(2, b.serve());
|
||||
|
||||
let ip_list = peer_mgr_a
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello_a abc");
|
||||
|
||||
let ip_list = peer_mgr_a
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(2, peer_mgr_b.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello_b abc");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
use std::{net::Ipv4Addr, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::common::{error::Error, PeerId};
|
||||
|
||||
#[async_trait]
|
||||
pub trait RouteInterface {
|
||||
async fn list_peers(&self) -> Vec<PeerId>;
|
||||
async fn send_route_packet(
|
||||
&self,
|
||||
msg: Bytes,
|
||||
route_id: u8,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<(), Error>;
|
||||
fn my_peer_id(&self) -> PeerId;
|
||||
}
|
||||
|
||||
pub type RouteInterfaceBox = Box<dyn RouteInterface + Send + Sync>;
|
||||
|
||||
#[async_trait]
|
||||
#[auto_impl::auto_impl(Box, Arc)]
|
||||
pub trait Route {
|
||||
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()>;
|
||||
async fn close(&self);
|
||||
|
||||
async fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId>;
|
||||
async fn list_routes(&self) -> Vec<crate::rpc::Route>;
|
||||
|
||||
async fn get_peer_id_by_ipv4(&self, _ipv4: &Ipv4Addr) -> Option<PeerId> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub type ArcRoute = Arc<Box<dyn Route + Send + Sync>>;
|
||||
@@ -0,0 +1,63 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::rpc::{
|
||||
cli::PeerInfo,
|
||||
peer_manage_rpc_server::PeerManageRpc,
|
||||
{ListPeerRequest, ListPeerResponse, ListRouteRequest, ListRouteResponse},
|
||||
};
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use super::peer_manager::PeerManager;
|
||||
|
||||
pub struct PeerManagerRpcService {
|
||||
peer_manager: Arc<PeerManager>,
|
||||
}
|
||||
|
||||
impl PeerManagerRpcService {
|
||||
pub fn new(peer_manager: Arc<PeerManager>) -> Self {
|
||||
PeerManagerRpcService { peer_manager }
|
||||
}
|
||||
|
||||
pub async fn list_peers(&self) -> Vec<PeerInfo> {
|
||||
let peers = self.peer_manager.get_peer_map().list_peers().await;
|
||||
let mut peer_infos = Vec::new();
|
||||
for peer in peers {
|
||||
let mut peer_info = PeerInfo::default();
|
||||
peer_info.peer_id = peer;
|
||||
|
||||
if let Some(conns) = self.peer_manager.get_peer_map().list_peer_conns(peer).await {
|
||||
peer_info.conns = conns;
|
||||
}
|
||||
|
||||
peer_infos.push(peer_info);
|
||||
}
|
||||
|
||||
peer_infos
|
||||
}
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl PeerManageRpc for PeerManagerRpcService {
|
||||
async fn list_peer(
|
||||
&self,
|
||||
_request: Request<ListPeerRequest>, // Accept request of type HelloRequest
|
||||
) -> Result<Response<ListPeerResponse>, Status> {
|
||||
let mut reply = ListPeerResponse::default();
|
||||
|
||||
let peers = self.list_peers().await;
|
||||
for peer in peers {
|
||||
reply.peer_infos.push(peer);
|
||||
}
|
||||
|
||||
Ok(Response::new(reply))
|
||||
}
|
||||
|
||||
async fn list_route(
|
||||
&self,
|
||||
_request: Request<ListRouteRequest>, // Accept request of type HelloRequest
|
||||
) -> Result<Response<ListRouteResponse>, Status> {
|
||||
let mut reply = ListRouteResponse::default();
|
||||
reply.routes = self.peer_manager.list_routes().await;
|
||||
Ok(Response::new(reply))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::Future;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::tests::get_mock_global_ctx, PeerId},
|
||||
tunnels::ring_tunnel::create_ring_tunnel_pair,
|
||||
};
|
||||
|
||||
use super::peer_manager::{PeerManager, RouteAlgoType};
|
||||
|
||||
pub async fn create_mock_peer_manager() -> Arc<PeerManager> {
|
||||
let (s, _r) = tokio::sync::mpsc::channel(1000);
|
||||
let peer_mgr = Arc::new(PeerManager::new(
|
||||
RouteAlgoType::Ospf,
|
||||
get_mock_global_ctx(),
|
||||
s,
|
||||
));
|
||||
peer_mgr.run().await.unwrap();
|
||||
peer_mgr
|
||||
}
|
||||
|
||||
pub async fn connect_peer_manager(client: Arc<PeerManager>, server: Arc<PeerManager>) {
|
||||
let (a_ring, b_ring) = create_ring_tunnel_pair();
|
||||
let a_mgr_copy = client.clone();
|
||||
tokio::spawn(async move {
|
||||
a_mgr_copy.add_client_tunnel(a_ring).await.unwrap();
|
||||
});
|
||||
let b_mgr_copy = server.clone();
|
||||
tokio::spawn(async move {
|
||||
b_mgr_copy.add_tunnel_as_server(b_ring).await.unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn wait_route_appear_with_cost(
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
node_id: PeerId,
|
||||
cost: Option<i32>,
|
||||
) -> Result<(), Error> {
|
||||
let now = std::time::Instant::now();
|
||||
while now.elapsed().as_secs() < 5 {
|
||||
let route = peer_mgr.list_routes().await;
|
||||
if route
|
||||
.iter()
|
||||
.any(|r| r.peer_id == node_id && (cost.is_none() || r.cost == cost.unwrap()))
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
}
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
|
||||
pub async fn wait_route_appear(
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
target_peer: Arc<PeerManager>,
|
||||
) -> Result<(), Error> {
|
||||
wait_route_appear_with_cost(peer_mgr.clone(), target_peer.my_peer_id(), None).await?;
|
||||
wait_route_appear_with_cost(target_peer, peer_mgr.my_peer_id(), None).await
|
||||
}
|
||||
|
||||
pub async fn wait_for_condition<F, FRet>(mut condition: F, timeout: std::time::Duration) -> ()
|
||||
where
|
||||
F: FnMut() -> FRet + Send,
|
||||
FRet: Future<Output = bool>,
|
||||
{
|
||||
let now = std::time::Instant::now();
|
||||
while now.elapsed() < timeout {
|
||||
if condition().await {
|
||||
return;
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
}
|
||||
assert!(condition().await, "Timeout")
|
||||
}
|
||||
Reference in New Issue
Block a user