use workspace, prepare for config server and gui (#48)

This commit is contained in:
Sijie.Sun
2024-04-04 10:33:53 +08:00
committed by GitHub
parent bb4ae71869
commit 4eb7efe5fc
77 changed files with 162 additions and 195 deletions
@@ -0,0 +1,200 @@
use std::{
sync::Arc,
time::{Duration, SystemTime},
};
use dashmap::DashMap;
use tokio::{
sync::{mpsc, Mutex},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use crate::common::{
error::Error,
global_ctx::{ArcGlobalCtx, NetworkIdentity},
PeerId,
};
use super::{
foreign_network_manager::{ForeignNetworkServiceClient, FOREIGN_NETWORK_SERVICE_ID},
peer_conn::PeerConn,
peer_map::PeerMap,
peer_rpc::PeerRpcManager,
};
pub struct ForeignNetworkClient {
global_ctx: ArcGlobalCtx,
peer_rpc: Arc<PeerRpcManager>,
my_peer_id: PeerId,
peer_map: Arc<PeerMap>,
next_hop: Arc<DashMap<PeerId, PeerId>>,
tasks: Mutex<JoinSet<()>>,
}
impl ForeignNetworkClient {
pub fn new(
global_ctx: ArcGlobalCtx,
packet_sender_to_mgr: mpsc::Sender<Bytes>,
peer_rpc: Arc<PeerRpcManager>,
my_peer_id: PeerId,
) -> Self {
let peer_map = Arc::new(PeerMap::new(
packet_sender_to_mgr,
global_ctx.clone(),
my_peer_id,
));
let next_hop = Arc::new(DashMap::new());
Self {
global_ctx,
peer_rpc,
my_peer_id,
peer_map,
next_hop,
tasks: Mutex::new(JoinSet::new()),
}
}
pub async fn add_new_peer_conn(&self, peer_conn: PeerConn) {
tracing::warn!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network client");
self.peer_map.add_new_peer_conn(peer_conn).await
}
async fn collect_next_hop_in_foreign_network_task(
network_identity: NetworkIdentity,
peer_map: Arc<PeerMap>,
peer_rpc: Arc<PeerRpcManager>,
next_hop: Arc<DashMap<PeerId, PeerId>>,
) {
loop {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
peer_map.clean_peer_without_conn().await;
let new_next_hop = Self::collect_next_hop_in_foreign_network(
network_identity.clone(),
peer_map.clone(),
peer_rpc.clone(),
)
.await;
next_hop.clear();
for (k, v) in new_next_hop.into_iter() {
next_hop.insert(k, v);
}
}
}
async fn collect_next_hop_in_foreign_network(
network_identity: NetworkIdentity,
peer_map: Arc<PeerMap>,
peer_rpc: Arc<PeerRpcManager>,
) -> DashMap<PeerId, PeerId> {
let peers = peer_map.list_peers().await;
let mut tasks = JoinSet::new();
if !peers.is_empty() {
tracing::warn!(?peers, my_peer_id = ?peer_rpc.my_peer_id(), "collect next hop in foreign network");
}
for peer in peers {
let peer_rpc = peer_rpc.clone();
let network_identity = network_identity.clone();
tasks.spawn(async move {
let Ok(Some(peers_in_foreign)) = peer_rpc
.do_client_rpc_scoped(FOREIGN_NETWORK_SERVICE_ID, peer, |c| async {
let c =
ForeignNetworkServiceClient::new(tarpc::client::Config::default(), c)
.spawn();
let mut rpc_ctx = tarpc::context::current();
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(2);
let ret = c.list_network_peers(rpc_ctx, network_identity).await;
ret
})
.await
else {
return (peer, vec![]);
};
(peer, peers_in_foreign)
});
}
let new_next_hop = DashMap::new();
while let Some(join_ret) = tasks.join_next().await {
let Ok((gateway, peer_ids)) = join_ret else {
tracing::error!(?join_ret, "collect next hop in foreign network failed");
continue;
};
for ret in peer_ids {
new_next_hop.insert(ret, gateway);
}
}
new_next_hop
}
pub fn has_next_hop(&self, peer_id: PeerId) -> bool {
self.get_next_hop(peer_id).is_some()
}
pub fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId> {
if self.peer_map.has_peer(peer_id) {
return Some(peer_id.clone());
}
self.next_hop.get(&peer_id).map(|v| v.clone())
}
pub async fn send_msg(&self, msg: Bytes, peer_id: PeerId) -> Result<(), Error> {
if let Some(next_hop) = self.get_next_hop(peer_id) {
let ret = self.peer_map.send_msg_directly(msg, next_hop).await;
if ret.is_err() {
tracing::error!(
?ret,
?peer_id,
?next_hop,
"foreign network client send msg failed"
);
}
return ret;
}
Err(Error::RouteError(Some("no next hop".to_string())))
}
pub fn list_foreign_peers(&self) -> Vec<PeerId> {
let mut peers = vec![];
for item in self.next_hop.iter() {
if item.key() != &self.my_peer_id {
peers.push(item.key().clone());
}
}
peers
}
pub async fn run(&self) {
self.tasks
.lock()
.await
.spawn(Self::collect_next_hop_in_foreign_network_task(
self.global_ctx.get_network_identity(),
self.peer_map.clone(),
self.peer_rpc.clone(),
self.next_hop.clone(),
));
}
pub fn get_next_hop_table(&self) -> DashMap<PeerId, PeerId> {
let next_hop = DashMap::new();
for item in self.next_hop.iter() {
next_hop.insert(item.key().clone(), item.value().clone());
}
next_hop
}
pub fn get_peer_map(&self) -> Arc<PeerMap> {
self.peer_map.clone()
}
}
@@ -0,0 +1,459 @@
/*
foreign_network_manager is used to forward packets of other networks. currently
only forward packets of peers that directly connected to this node.
in future, with the help wo peer center we can forward packets of peers that
connected to any node in the local network.
*/
use std::sync::Arc;
use dashmap::DashMap;
use tokio::{
sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender},
Mutex,
},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use crate::common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent, NetworkIdentity},
PeerId,
};
use super::{
packet::{self},
peer_conn::PeerConn,
peer_map::PeerMap,
peer_rpc::{PeerRpcManager, PeerRpcManagerTransport},
};
struct ForeignNetworkEntry {
network: NetworkIdentity,
peer_map: Arc<PeerMap>,
}
impl ForeignNetworkEntry {
fn new(
network: NetworkIdentity,
packet_sender: mpsc::Sender<Bytes>,
global_ctx: ArcGlobalCtx,
my_peer_id: PeerId,
) -> Self {
let peer_map = Arc::new(PeerMap::new(packet_sender, global_ctx, my_peer_id));
Self { network, peer_map }
}
}
struct ForeignNetworkManagerData {
network_peer_maps: DashMap<String, Arc<ForeignNetworkEntry>>,
peer_network_map: DashMap<PeerId, String>,
}
impl ForeignNetworkManagerData {
async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
let network_name = self
.peer_network_map
.get(&dst_peer_id)
.ok_or_else(|| Error::RouteError(Some("network not found".to_string())))?
.clone();
let entry = self
.network_peer_maps
.get(&network_name)
.ok_or_else(|| Error::RouteError(Some("no peer in network".to_string())))?
.clone();
entry.peer_map.send_msg(msg, dst_peer_id).await
}
fn get_peer_network(&self, peer_id: PeerId) -> Option<String> {
self.peer_network_map.get(&peer_id).map(|v| v.clone())
}
fn get_network_entry(&self, network_name: &str) -> Option<Arc<ForeignNetworkEntry>> {
self.network_peer_maps.get(network_name).map(|v| v.clone())
}
fn remove_peer(&self, peer_id: PeerId) {
self.peer_network_map.remove(&peer_id);
self.network_peer_maps.retain(|_, v| !v.peer_map.is_empty());
}
fn clear_no_conn_peer(&self) {
for item in self.network_peer_maps.iter() {
let peer_map = item.value().peer_map.clone();
tokio::spawn(async move {
peer_map.clean_peer_without_conn().await;
});
}
}
}
struct RpcTransport {
my_peer_id: PeerId,
data: Arc<ForeignNetworkManagerData>,
packet_recv: Mutex<UnboundedReceiver<Bytes>>,
}
#[async_trait::async_trait]
impl PeerRpcManagerTransport for RpcTransport {
fn my_peer_id(&self) -> PeerId {
self.my_peer_id
}
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
self.data.send_msg(msg, dst_peer_id).await
}
async fn recv(&self) -> Result<Bytes, Error> {
if let Some(o) = self.packet_recv.lock().await.recv().await {
Ok(o)
} else {
Err(Error::Unknown)
}
}
}
pub const FOREIGN_NETWORK_SERVICE_ID: u32 = 1;
#[tarpc::service]
pub trait ForeignNetworkService {
async fn list_network_peers(network_identy: NetworkIdentity) -> Option<Vec<PeerId>>;
}
#[tarpc::server]
impl ForeignNetworkService for Arc<ForeignNetworkManagerData> {
async fn list_network_peers(
self,
_: tarpc::context::Context,
network_identy: NetworkIdentity,
) -> Option<Vec<PeerId>> {
let entry = self.network_peer_maps.get(&network_identy.network_name)?;
Some(entry.peer_map.list_peers().await)
}
}
pub struct ForeignNetworkManager {
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
packet_sender_to_mgr: mpsc::Sender<Bytes>,
packet_sender: mpsc::Sender<Bytes>,
packet_recv: Mutex<Option<mpsc::Receiver<Bytes>>>,
data: Arc<ForeignNetworkManagerData>,
rpc_mgr: Arc<PeerRpcManager>,
rpc_transport_sender: UnboundedSender<Bytes>,
tasks: Mutex<JoinSet<()>>,
}
impl ForeignNetworkManager {
pub fn new(
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
packet_sender_to_mgr: mpsc::Sender<Bytes>,
) -> Self {
// recv packet from all foreign networks
let (packet_sender, packet_recv) = mpsc::channel(1000);
let data = Arc::new(ForeignNetworkManagerData {
network_peer_maps: DashMap::new(),
peer_network_map: DashMap::new(),
});
// handle rpc from foreign networks
let (rpc_transport_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
let rpc_mgr = Arc::new(PeerRpcManager::new(RpcTransport {
my_peer_id,
data: data.clone(),
packet_recv: Mutex::new(peer_rpc_tspt_recv),
}));
Self {
my_peer_id,
global_ctx,
packet_sender_to_mgr,
packet_sender,
packet_recv: Mutex::new(Some(packet_recv)),
data,
rpc_mgr,
rpc_transport_sender,
tasks: Mutex::new(JoinSet::new()),
}
}
pub async fn add_peer_conn(&self, peer_conn: PeerConn) -> Result<(), Error> {
tracing::info!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network manager");
let entry = self
.data
.network_peer_maps
.entry(peer_conn.get_network_identity().network_name.clone())
.or_insert_with(|| {
Arc::new(ForeignNetworkEntry::new(
peer_conn.get_network_identity(),
self.packet_sender.clone(),
self.global_ctx.clone(),
self.my_peer_id,
))
})
.clone();
self.data.peer_network_map.insert(
peer_conn.get_peer_id(),
peer_conn.get_network_identity().network_name.clone(),
);
if entry.network.network_secret != peer_conn.get_network_identity().network_secret {
return Err(anyhow::anyhow!("network secret not match").into());
}
Ok(entry.peer_map.add_new_peer_conn(peer_conn).await)
}
async fn start_global_event_handler(&self) {
let data = self.data.clone();
let mut s = self.global_ctx.subscribe();
self.tasks.lock().await.spawn(async move {
while let Ok(e) = s.recv().await {
if let GlobalCtxEvent::PeerRemoved(peer_id) = &e {
tracing::info!(?e, "remove peer from foreign network manager");
data.remove_peer(*peer_id);
} else if let GlobalCtxEvent::PeerConnRemoved(..) = &e {
tracing::info!(?e, "clear no conn peer from foreign network manager");
data.clear_no_conn_peer();
}
}
});
}
async fn start_packet_recv(&self) {
let mut recv = self.packet_recv.lock().await.take().unwrap();
let sender_to_mgr = self.packet_sender_to_mgr.clone();
let my_node_id = self.my_peer_id;
let rpc_sender = self.rpc_transport_sender.clone();
let data = self.data.clone();
self.tasks.lock().await.spawn(async move {
while let Some(packet_bytes) = recv.recv().await {
let packet = packet::Packet::decode(&packet_bytes);
let from_peer_id = packet.from_peer.into();
let to_peer_id = packet.to_peer.into();
if to_peer_id == my_node_id {
if packet.packet_type == packet::PacketType::TaRpc {
rpc_sender.send(packet_bytes.clone()).unwrap();
continue;
}
if let Err(e) = sender_to_mgr.send(packet_bytes).await {
tracing::error!("send packet to mgr failed: {:?}", e);
}
} else {
let Some(from_network) = data.get_peer_network(from_peer_id) else {
continue;
};
let Some(to_network) = data.get_peer_network(to_peer_id) else {
continue;
};
if from_network != to_network {
continue;
}
if let Some(entry) = data.get_network_entry(&from_network) {
let ret = entry.peer_map.send_msg(packet_bytes, to_peer_id).await;
if ret.is_err() {
tracing::error!("forward packet to peer failed: {:?}", ret.err());
}
} else {
tracing::error!("foreign network not found: {}", from_network);
}
}
}
});
}
async fn register_peer_rpc_service(&self) {
self.rpc_mgr.run();
self.rpc_mgr
.run_service(FOREIGN_NETWORK_SERVICE_ID, self.data.clone().serve())
}
pub async fn run(&self) {
self.start_global_event_handler().await;
self.start_packet_recv().await;
self.register_peer_rpc_service().await;
}
pub async fn list_foreign_networks(&self) -> DashMap<String, Vec<PeerId>> {
let ret = DashMap::new();
for item in self.data.network_peer_maps.iter() {
let network_name = item.key().clone();
ret.insert(network_name, vec![]);
}
for mut n in ret.iter_mut() {
let network_name = n.key().clone();
let Some(item) = self
.data
.network_peer_maps
.get(&network_name)
.map(|v| v.clone())
else {
continue;
};
n.value_mut().extend(item.peer_map.list_peers().await);
}
ret
}
}
#[cfg(test)]
mod tests {
use crate::{
common::global_ctx::tests::get_mock_global_ctx_with_network,
connector::udp_hole_punch::tests::{
create_mock_peer_manager_with_mock_stun, replace_stun_info_collector,
},
peers::{
peer_manager::{PeerManager, RouteAlgoType},
tests::{connect_peer_manager, wait_route_appear},
},
rpc::NatType,
};
use super::*;
async fn create_mock_peer_manager_for_foreign_network(network: &str) -> Arc<PeerManager> {
let (s, _r) = tokio::sync::mpsc::channel(1000);
let peer_mgr = Arc::new(PeerManager::new(
RouteAlgoType::Ospf,
get_mock_global_ctx_with_network(Some(NetworkIdentity {
network_name: network.to_string(),
network_secret: network.to_string(),
})),
s,
));
replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
peer_mgr.run().await.unwrap();
peer_mgr
}
#[tokio::test]
async fn test_foreign_network_manager() {
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
let pm_center2 =
create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
connect_peer_manager(pm_center.clone(), pm_center2.clone()).await;
let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
let pmb_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
connect_peer_manager(pma_net1.clone(), pm_center.clone()).await;
connect_peer_manager(pmb_net1.clone(), pm_center.clone()).await;
let now = std::time::Instant::now();
let mut succ = false;
while now.elapsed().as_secs() < 10 {
let table = pma_net1.get_foreign_network_client().get_next_hop_table();
if table.len() >= 1 {
succ = true;
break;
}
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
assert!(succ);
assert_eq!(
vec![pm_center.my_peer_id()],
pma_net1
.get_foreign_network_client()
.get_peer_map()
.list_peers()
.await
);
assert_eq!(
vec![pm_center.my_peer_id()],
pmb_net1
.get_foreign_network_client()
.get_peer_map()
.list_peers()
.await
);
wait_route_appear(pma_net1.clone(), pmb_net1.clone())
.await
.unwrap();
assert_eq!(1, pma_net1.list_routes().await.len());
assert_eq!(1, pmb_net1.list_routes().await.len());
let pmc_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
connect_peer_manager(pmc_net1.clone(), pm_center.clone()).await;
wait_route_appear(pma_net1.clone(), pmc_net1.clone())
.await
.unwrap();
wait_route_appear(pmb_net1.clone(), pmc_net1.clone())
.await
.unwrap();
assert_eq!(2, pmc_net1.list_routes().await.len());
let pma_net2 = create_mock_peer_manager_for_foreign_network("net2").await;
let pmb_net2 = create_mock_peer_manager_for_foreign_network("net2").await;
connect_peer_manager(pma_net2.clone(), pm_center.clone()).await;
connect_peer_manager(pmb_net2.clone(), pm_center.clone()).await;
wait_route_appear(pma_net2.clone(), pmb_net2.clone())
.await
.unwrap();
assert_eq!(1, pma_net2.list_routes().await.len());
assert_eq!(1, pmb_net2.list_routes().await.len());
assert_eq!(
5,
pm_center
.get_foreign_network_manager()
.data
.peer_network_map
.len()
);
assert_eq!(
2,
pm_center
.get_foreign_network_manager()
.data
.network_peer_maps
.len()
);
drop(pmb_net2);
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
assert_eq!(
4,
pm_center
.get_foreign_network_manager()
.data
.peer_network_map
.len()
);
drop(pma_net2);
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
assert_eq!(
3,
pm_center
.get_foreign_network_manager()
.data
.peer_network_map
.len()
);
assert_eq!(
1,
pm_center
.get_foreign_network_manager()
.data
.network_peer_maps
.len()
);
}
}
+39
View File
@@ -0,0 +1,39 @@
pub mod packet;
pub mod peer;
pub mod peer_conn;
pub mod peer_manager;
pub mod peer_map;
pub mod peer_ospf_route;
pub mod peer_rip_route;
pub mod peer_rpc;
pub mod route_trait;
pub mod rpc_service;
pub mod foreign_network_client;
pub mod foreign_network_manager;
#[cfg(test)]
pub mod tests;
use tokio_util::bytes::{Bytes, BytesMut};
#[async_trait::async_trait]
#[auto_impl::auto_impl(Arc)]
pub trait PeerPacketFilter {
async fn try_process_packet_from_peer(
&self,
_packet: &packet::ArchivedPacket,
_data: &Bytes,
) -> Option<()> {
None
}
}
#[async_trait::async_trait]
#[auto_impl::auto_impl(Arc)]
pub trait NicPacketFilter {
async fn try_process_packet_from_nic(&self, data: BytesMut) -> BytesMut;
}
type BoxPeerPacketFilter = Box<dyn PeerPacketFilter + Send + Sync>;
type BoxNicPacketFilter = Box<dyn NicPacketFilter + Send + Sync>;
+254
View File
@@ -0,0 +1,254 @@
use std::fmt::Debug;
use rkyv::{Archive, Deserialize, Serialize};
use tokio_util::bytes::Bytes;
use crate::common::{
global_ctx::NetworkIdentity,
rkyv_util::{decode_from_bytes, encode_to_bytes, vec_to_string},
PeerId,
};
const MAGIC: u32 = 0xd1e1a5e1;
const VERSION: u32 = 1;
#[derive(Archive, Deserialize, Serialize, PartialEq, Clone)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct UUID(uuid::Bytes);
// impl Debug for UUID
impl std::fmt::Debug for UUID {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let uuid = uuid::Uuid::from_bytes(self.0);
write!(f, "{}", uuid)
}
}
impl From<uuid::Uuid> for UUID {
fn from(uuid: uuid::Uuid) -> Self {
UUID(*uuid.as_bytes())
}
}
impl From<UUID> for uuid::Uuid {
fn from(uuid: UUID) -> Self {
uuid::Uuid::from_bytes(uuid.0)
}
}
impl ArchivedUUID {
pub fn to_uuid(&self) -> uuid::Uuid {
uuid::Uuid::from_bytes(self.0)
}
}
impl From<&ArchivedUUID> for UUID {
fn from(uuid: &ArchivedUUID) -> Self {
UUID(uuid.0)
}
}
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct HandShake {
pub magic: u32,
pub my_peer_id: PeerId,
pub version: u32,
pub features: Vec<String>,
pub network_identity: NetworkIdentity,
}
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct RoutePacket {
pub route_id: u8,
pub body: Vec<u8>,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub enum CtrlPacketPayload {
HandShake(HandShake),
RoutePacket(RoutePacket),
Ping(u32),
Pong(u32),
TaRpc(u32, u32, bool, Vec<u8>), // u32: service_id, u32: transact_id, bool: is_req, Vec<u8>: rpc body
}
impl CtrlPacketPayload {
pub fn from_packet(p: &ArchivedPacket) -> CtrlPacketPayload {
assert_ne!(p.packet_type, PacketType::Data);
postcard::from_bytes(p.payload.as_bytes()).unwrap()
}
pub fn from_packet2(p: &Packet) -> CtrlPacketPayload {
postcard::from_bytes(p.payload.as_bytes()).unwrap()
}
}
#[repr(u8)]
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub enum PacketType {
Data = 1,
HandShake = 2,
RoutePacket = 3,
Ping = 4,
Pong = 5,
TaRpc = 6,
}
#[derive(Archive, Deserialize, Serialize)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
pub struct Packet {
pub from_peer: PeerId,
pub to_peer: PeerId,
pub packet_type: PacketType,
pub payload: String,
}
impl std::fmt::Debug for Packet {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
self.from_peer,
self.to_peer,
self.packet_type,
&self.payload.as_bytes()
)
}
}
impl std::fmt::Debug for ArchivedPacket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
self.from_peer,
self.to_peer,
self.packet_type,
&self.payload.as_bytes()
)
}
}
impl Packet {
pub fn decode(v: &[u8]) -> &ArchivedPacket {
decode_from_bytes::<Packet>(v).unwrap()
}
pub fn new(
from_peer: PeerId,
to_peer: PeerId,
packet_type: PacketType,
payload: Vec<u8>,
) -> Self {
Packet {
from_peer,
to_peer,
packet_type,
payload: vec_to_string(payload),
}
}
}
impl From<Packet> for Bytes {
fn from(val: Packet) -> Self {
encode_to_bytes::<_, 4096>(&val)
}
}
impl Packet {
pub fn new_handshake(from_peer: PeerId, network: &NetworkIdentity) -> Self {
let handshake = CtrlPacketPayload::HandShake(HandShake {
magic: MAGIC,
my_peer_id: from_peer,
version: VERSION,
features: Vec::new(),
network_identity: network.clone().into(),
});
Packet::new(
from_peer.into(),
0,
PacketType::HandShake,
postcard::to_allocvec(&handshake).unwrap(),
)
}
pub fn new_data_packet(from_peer: PeerId, to_peer: PeerId, data: &[u8]) -> Self {
Packet::new(from_peer, to_peer, PacketType::Data, data.to_vec())
}
pub fn new_route_packet(from_peer: PeerId, to_peer: PeerId, route_id: u8, data: &[u8]) -> Self {
let route = CtrlPacketPayload::RoutePacket(RoutePacket {
route_id,
body: data.to_vec(),
});
Packet::new(
from_peer,
to_peer,
PacketType::RoutePacket,
postcard::to_allocvec(&route).unwrap(),
)
}
pub fn new_ping_packet(from_peer: PeerId, to_peer: PeerId, seq: u32) -> Self {
let ping = CtrlPacketPayload::Ping(seq);
Packet::new(
from_peer,
to_peer,
PacketType::Ping,
postcard::to_allocvec(&ping).unwrap(),
)
}
pub fn new_pong_packet(from_peer: PeerId, to_peer: PeerId, seq: u32) -> Self {
let pong = CtrlPacketPayload::Pong(seq);
Packet::new(
from_peer,
to_peer,
PacketType::Pong,
postcard::to_allocvec(&pong).unwrap(),
)
}
pub fn new_tarpc_packet(
from_peer: PeerId,
to_peer: PeerId,
service_id: u32,
transact_id: u32,
is_req: bool,
body: Vec<u8>,
) -> Self {
let ta_rpc = CtrlPacketPayload::TaRpc(service_id, transact_id, is_req, body);
Packet::new(
from_peer,
to_peer,
PacketType::TaRpc,
postcard::to_allocvec(&ta_rpc).unwrap(),
)
}
}
#[cfg(test)]
mod tests {
use crate::common::new_peer_id;
use super::*;
#[tokio::test]
async fn serialize() {
let a = "abcde";
let out = Packet::new_data_packet(new_peer_id(), new_peer_id(), a.as_bytes());
// let out = T::new(a.as_bytes());
let out_bytes: Bytes = out.into();
println!("out str: {:?}", a.as_bytes());
println!("out bytes: {:?}", out_bytes);
let archived = Packet::decode(&out_bytes[..]);
println!("in packet: {:?}", archived);
}
}
+213
View File
@@ -0,0 +1,213 @@
use std::sync::Arc;
use dashmap::DashMap;
use tokio::{
select,
sync::{mpsc, Mutex},
task::JoinHandle,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use super::peer_conn::{PeerConn, PeerConnId};
use crate::common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
PeerId,
};
use crate::rpc::PeerConnInfo;
type ArcPeerConn = Arc<Mutex<PeerConn>>;
type ConnMap = Arc<DashMap<PeerConnId, ArcPeerConn>>;
pub struct Peer {
pub peer_node_id: PeerId,
conns: ConnMap,
global_ctx: ArcGlobalCtx,
packet_recv_chan: mpsc::Sender<Bytes>,
close_event_sender: mpsc::Sender<PeerConnId>,
close_event_listener: JoinHandle<()>,
shutdown_notifier: Arc<tokio::sync::Notify>,
}
impl Peer {
pub fn new(
peer_node_id: PeerId,
packet_recv_chan: mpsc::Sender<Bytes>,
global_ctx: ArcGlobalCtx,
) -> Self {
let conns: ConnMap = Arc::new(DashMap::new());
let (close_event_sender, mut close_event_receiver) = mpsc::channel(10);
let shutdown_notifier = Arc::new(tokio::sync::Notify::new());
let conns_copy = conns.clone();
let shutdown_notifier_copy = shutdown_notifier.clone();
let global_ctx_copy = global_ctx.clone();
let close_event_listener = tokio::spawn(
async move {
loop {
select! {
ret = close_event_receiver.recv() => {
if ret.is_none() {
break;
}
let ret = ret.unwrap();
tracing::warn!(
?peer_node_id,
?ret,
"notified that peer conn is closed",
);
if let Some((_, conn)) = conns_copy.remove(&ret) {
global_ctx_copy.issue_event(GlobalCtxEvent::PeerConnRemoved(
conn.lock().await.get_conn_info(),
));
}
}
_ = shutdown_notifier_copy.notified() => {
close_event_receiver.close();
tracing::warn!(?peer_node_id, "peer close event listener notified");
}
}
}
tracing::info!("peer {} close event listener exit", peer_node_id);
}
.instrument(tracing::info_span!(
"peer_close_event_listener",
?peer_node_id,
)),
);
Peer {
peer_node_id,
conns: conns.clone(),
packet_recv_chan,
global_ctx,
close_event_sender,
close_event_listener,
shutdown_notifier,
}
}
pub async fn add_peer_conn(&self, mut conn: PeerConn) {
conn.set_close_event_sender(self.close_event_sender.clone());
conn.start_recv_loop(self.packet_recv_chan.clone());
conn.start_pingpong();
self.global_ctx
.issue_event(GlobalCtxEvent::PeerConnAdded(conn.get_conn_info()));
self.conns
.insert(conn.get_conn_id(), Arc::new(Mutex::new(conn)));
}
pub async fn send_msg(&self, msg: Bytes) -> Result<(), Error> {
let Some(conn) = self.conns.iter().next() else {
return Err(Error::PeerNoConnectionError(self.peer_node_id));
};
let conn_clone = conn.clone();
drop(conn);
conn_clone.lock().await.send_msg(msg).await?;
Ok(())
}
pub async fn close_peer_conn(&self, conn_id: &PeerConnId) -> Result<(), Error> {
let has_key = self.conns.contains_key(conn_id);
if !has_key {
return Err(Error::NotFound);
}
self.close_event_sender.send(conn_id.clone()).await.unwrap();
Ok(())
}
pub async fn list_peer_conns(&self) -> Vec<PeerConnInfo> {
let mut conns = vec![];
for conn in self.conns.iter() {
// do not lock here, otherwise it will cause dashmap deadlock
conns.push(conn.clone());
}
let mut ret = Vec::new();
for conn in conns {
ret.push(conn.lock().await.get_conn_info());
}
ret
}
}
// pritn on drop
impl Drop for Peer {
fn drop(&mut self) {
self.shutdown_notifier.notify_one();
tracing::info!("peer {} drop", self.peer_node_id);
}
}
#[cfg(test)]
mod tests {
use tokio::{sync::mpsc, time::timeout};
use crate::{
common::{global_ctx::tests::get_mock_global_ctx, new_peer_id},
peers::peer_conn::PeerConn,
tunnels::ring_tunnel::create_ring_tunnel_pair,
};
use super::Peer;
#[tokio::test]
async fn close_peer() {
let (local_packet_send, _local_packet_recv) = mpsc::channel(10);
let (remote_packet_send, _remote_packet_recv) = mpsc::channel(10);
let global_ctx = get_mock_global_ctx();
let local_peer = Peer::new(new_peer_id(), local_packet_send, global_ctx.clone());
let remote_peer = Peer::new(new_peer_id(), remote_packet_send, global_ctx.clone());
let (local_tunnel, remote_tunnel) = create_ring_tunnel_pair();
let mut local_peer_conn =
PeerConn::new(local_peer.peer_node_id, global_ctx.clone(), local_tunnel);
let mut remote_peer_conn =
PeerConn::new(remote_peer.peer_node_id, global_ctx.clone(), remote_tunnel);
assert!(!local_peer_conn.handshake_done());
assert!(!remote_peer_conn.handshake_done());
let (a, b) = tokio::join!(
local_peer_conn.do_handshake_as_client(),
remote_peer_conn.do_handshake_as_server()
);
a.unwrap();
b.unwrap();
let local_conn_id = local_peer_conn.get_conn_id();
local_peer.add_peer_conn(local_peer_conn).await;
remote_peer.add_peer_conn(remote_peer_conn).await;
assert_eq!(local_peer.list_peer_conns().await.len(), 1);
assert_eq!(remote_peer.list_peer_conns().await.len(), 1);
let close_handler =
tokio::spawn(async move { local_peer.close_peer_conn(&local_conn_id).await });
// wait for remote peer conn close
timeout(std::time::Duration::from_secs(5), async {
while (&remote_peer).list_peer_conns().await.len() != 0 {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
})
.await
.unwrap();
println!("wait for close handler");
close_handler.await.unwrap().unwrap();
}
}
+652
View File
@@ -0,0 +1,652 @@
use std::{
fmt::Debug,
pin::Pin,
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
};
use futures::{SinkExt, StreamExt};
use pnet::datalink::NetworkInterface;
use tokio::{
sync::{broadcast, mpsc, Mutex},
task::JoinSet,
time::{timeout, Duration},
};
use tokio_util::{bytes::Bytes, sync::PollSender};
use tracing::Instrument;
use crate::{
common::{
global_ctx::{ArcGlobalCtx, NetworkIdentity},
PeerId,
},
define_tunnel_filter_chain,
peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType},
rpc::{PeerConnInfo, PeerConnStats},
tunnels::{
stats::{Throughput, WindowLatency},
tunnel_filter::StatsRecorderTunnelFilter,
DatagramSink, Tunnel, TunnelError,
},
};
use super::packet::{self, HandShake, Packet};
pub type PacketRecvChan = mpsc::Sender<Bytes>;
pub type PeerConnId = uuid::Uuid;
macro_rules! wait_response {
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
let rsp_vec = timeout(Duration::from_secs(1), $stream.next()).await;
if rsp_vec.is_err() {
return Err(TunnelError::WaitRespError(
"wait handshake response timeout".to_owned(),
));
}
let rsp_vec = rsp_vec.unwrap().unwrap()?;
let $out_var;
let rsp_bytes = Packet::decode(&rsp_vec);
if rsp_bytes.packet_type != PacketType::HandShake {
tracing::error!("unexpected packet type: {:?}", rsp_bytes);
return Err(TunnelError::WaitRespError(
"unexpected packet type".to_owned(),
));
}
let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes);
match &resp_payload {
$pattern => $out_var = $value,
_ => {
tracing::error!(
"unexpected packet: {:?}, pattern: {:?}",
rsp_bytes,
stringify!($pattern)
);
return Err(TunnelError::WaitRespError("unexpected packet".to_owned()));
}
}
};
}
#[derive(Debug)]
pub struct PeerInfo {
magic: u32,
pub my_peer_id: PeerId,
version: u32,
pub features: Vec<String>,
pub interfaces: Vec<NetworkInterface>,
pub network_identity: NetworkIdentity,
}
impl<'a> From<&HandShake> for PeerInfo {
fn from(hs: &HandShake) -> Self {
PeerInfo {
magic: hs.magic.into(),
my_peer_id: hs.my_peer_id.into(),
version: hs.version.into(),
features: hs.features.iter().map(|x| x.to_string()).collect(),
interfaces: Vec::new(),
network_identity: hs.network_identity.clone(),
}
}
}
struct PeerConnPinger {
my_peer_id: PeerId,
peer_id: PeerId,
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
ctrl_sender: broadcast::Sender<Bytes>,
latency_stats: Arc<WindowLatency>,
loss_rate_stats: Arc<AtomicU32>,
tasks: JoinSet<Result<(), TunnelError>>,
}
impl Debug for PeerConnPinger {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PeerConnPinger")
.field("my_peer_id", &self.my_peer_id)
.field("peer_id", &self.peer_id)
.finish()
}
}
impl PeerConnPinger {
pub fn new(
my_peer_id: PeerId,
peer_id: PeerId,
sink: Pin<Box<dyn DatagramSink>>,
ctrl_sender: broadcast::Sender<Bytes>,
latency_stats: Arc<WindowLatency>,
loss_rate_stats: Arc<AtomicU32>,
) -> Self {
Self {
my_peer_id,
peer_id,
sink: Arc::new(Mutex::new(sink)),
tasks: JoinSet::new(),
latency_stats,
ctrl_sender,
loss_rate_stats,
}
}
async fn do_pingpong_once(
my_node_id: PeerId,
peer_id: PeerId,
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
receiver: &mut broadcast::Receiver<Bytes>,
seq: u32,
) -> Result<u128, TunnelError> {
// should add seq here. so latency can be calculated more accurately
let req = packet::Packet::new_ping_packet(my_node_id, peer_id, seq).into();
tracing::trace!("send ping packet: {:?}", req);
sink.lock().await.send(req).await.map_err(|e| {
tracing::warn!("send ping packet error: {:?}", e);
TunnelError::CommonError("send ping packet error".to_owned())
})?;
let now = std::time::Instant::now();
// wait until we get a pong packet in ctrl_resp_receiver
let resp = timeout(Duration::from_secs(1), async {
loop {
match receiver.recv().await {
Ok(p) => {
let ctrl_payload =
packet::CtrlPacketPayload::from_packet(Packet::decode(&p));
if let packet::CtrlPacketPayload::Pong(resp_seq) = ctrl_payload {
if resp_seq == seq {
break;
}
}
}
Err(e) => {
log::warn!("recv pong resp error: {:?}", e);
return Err(TunnelError::WaitRespError(
"recv pong resp error".to_owned(),
));
}
}
}
Ok(())
})
.await;
tracing::trace!(?resp, "wait ping response done");
if resp.is_err() {
return Err(TunnelError::WaitRespError(
"wait ping response timeout".to_owned(),
));
}
if resp.as_ref().unwrap().is_err() {
return Err(resp.unwrap().err().unwrap());
}
Ok(now.elapsed().as_micros())
}
async fn pingpong(&mut self) {
let sink = self.sink.clone();
let my_node_id = self.my_peer_id;
let peer_id = self.peer_id;
let latency_stats = self.latency_stats.clone();
let (ping_res_sender, mut ping_res_receiver) = tokio::sync::mpsc::channel(100);
let stopped = Arc::new(AtomicU32::new(0));
// generate a pingpong task every 200ms
let mut pingpong_tasks = JoinSet::new();
let ctrl_resp_sender = self.ctrl_sender.clone();
let stopped_clone = stopped.clone();
self.tasks.spawn(async move {
let mut req_seq = 0;
loop {
let receiver = ctrl_resp_sender.subscribe();
let ping_res_sender = ping_res_sender.clone();
let sink = sink.clone();
if stopped_clone.load(Ordering::Relaxed) != 0 {
return Ok(());
}
while pingpong_tasks.len() > 5 {
pingpong_tasks.join_next().await;
}
pingpong_tasks.spawn(async move {
let mut receiver = receiver.resubscribe();
let pingpong_once_ret = Self::do_pingpong_once(
my_node_id,
peer_id,
sink.clone(),
&mut receiver,
req_seq,
)
.await;
if let Err(e) = ping_res_sender.send(pingpong_once_ret).await {
tracing::info!(?e, "pingpong task send result error, exit..");
};
});
req_seq = req_seq.wrapping_add(1);
tokio::time::sleep(Duration::from_millis(1000)).await;
}
});
// one with 1% precision
let loss_rate_stats_1 = WindowLatency::new(100);
// one with 20% precision, so we can fast fail this conn.
let loss_rate_stats_20 = WindowLatency::new(5);
let mut counter: u64 = 0;
while let Some(ret) = ping_res_receiver.recv().await {
counter += 1;
if let Ok(lat) = ret {
latency_stats.record_latency(lat as u32);
loss_rate_stats_1.record_latency(0);
loss_rate_stats_20.record_latency(0);
} else {
loss_rate_stats_1.record_latency(1);
loss_rate_stats_20.record_latency(1);
}
let loss_rate_20: f64 = loss_rate_stats_20.get_latency_us();
let loss_rate_1: f64 = loss_rate_stats_1.get_latency_us();
tracing::trace!(
?ret,
?self,
?loss_rate_1,
?loss_rate_20,
"pingpong task recv pingpong_once result"
);
if (counter > 5 && loss_rate_20 > 0.74) || (counter > 150 && loss_rate_1 > 0.20) {
tracing::warn!(
?ret,
?self,
?loss_rate_1,
?loss_rate_20,
"pingpong loss rate too high, closing"
);
break;
}
self.loss_rate_stats
.store((loss_rate_1 * 100.0) as u32, Ordering::Relaxed);
}
stopped.store(1, Ordering::Relaxed);
ping_res_receiver.close();
}
}
define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter);
pub struct PeerConn {
conn_id: PeerConnId,
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
sink: Pin<Box<dyn DatagramSink>>,
tunnel: Box<dyn Tunnel>,
tasks: JoinSet<Result<(), TunnelError>>,
info: Option<PeerInfo>,
close_event_sender: Option<mpsc::Sender<PeerConnId>>,
ctrl_resp_sender: broadcast::Sender<Bytes>,
latency_stats: Arc<WindowLatency>,
throughput: Arc<Throughput>,
loss_rate_stats: Arc<AtomicU32>,
}
enum PeerConnPacketType {
Data(Bytes),
CtrlReq(Bytes),
CtrlResp(Bytes),
}
impl PeerConn {
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
let peer_conn_tunnel = PeerConnTunnel::new();
let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel);
PeerConn {
conn_id: PeerConnId::new_v4(),
my_peer_id,
global_ctx,
sink: tunnel.pin_sink(),
tunnel: Box::new(tunnel),
tasks: JoinSet::new(),
info: None,
close_event_sender: None,
ctrl_resp_sender: ctrl_sender,
latency_stats: Arc::new(WindowLatency::new(15)),
throughput: peer_conn_tunnel.stats.get_throughput().clone(),
loss_rate_stats: Arc::new(AtomicU32::new(0)),
}
}
pub fn get_conn_id(&self) -> PeerConnId {
self.conn_id
}
#[tracing::instrument]
pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> {
let mut stream = self.tunnel.pin_stream();
let mut sink = self.tunnel.pin_sink();
tracing::info!("waiting for handshake request from client");
wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x);
self.info = Some(PeerInfo::from(hs_req));
tracing::info!("handshake request: {:?}", hs_req);
let hs_req = self
.global_ctx
.net_ns
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
sink.send(hs_req.into()).await?;
Ok(())
}
#[tracing::instrument]
pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> {
let mut stream = self.tunnel.pin_stream();
let mut sink = self.tunnel.pin_sink();
let hs_req = self
.global_ctx
.net_ns
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
sink.send(hs_req.into()).await?;
tracing::info!("waiting for handshake request from server");
wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x);
self.info = Some(PeerInfo::from(hs_rsp));
tracing::info!("handshake response: {:?}", hs_rsp);
Ok(())
}
pub fn handshake_done(&self) -> bool {
self.info.is_some()
}
pub fn start_pingpong(&mut self) {
let mut pingpong = PeerConnPinger::new(
self.my_peer_id,
self.get_peer_id(),
self.tunnel.pin_sink(),
self.ctrl_resp_sender.clone(),
self.latency_stats.clone(),
self.loss_rate_stats.clone(),
);
let close_event_sender = self.close_event_sender.clone().unwrap();
let conn_id = self.conn_id;
self.tasks.spawn(async move {
pingpong.pingpong().await;
tracing::warn!(?pingpong, "pingpong task exit");
if let Err(e) = close_event_sender.send(conn_id).await {
log::warn!("close event sender error: {:?}", e);
}
Ok(())
});
}
pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
let mut stream = self.tunnel.pin_stream();
let mut sink = self.tunnel.pin_sink();
let mut sender = PollSender::new(packet_recv_chan.clone());
let close_event_sender = self.close_event_sender.clone().unwrap();
let conn_id = self.conn_id;
let ctrl_sender = self.ctrl_resp_sender.clone();
let conn_info = self.get_conn_info();
let conn_info_for_instrument = self.get_conn_info();
self.tasks.spawn(
async move {
tracing::info!("start recving peer conn packet");
let mut task_ret = Ok(());
while let Some(ret) = stream.next().await {
if ret.is_err() {
tracing::error!(error = ?ret, "peer conn recv error");
task_ret = Err(ret.err().unwrap());
break;
}
let buf = ret.unwrap();
let p = Packet::decode(&buf);
match p.packet_type {
ArchivedPacketType::Ping => {
let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p)
else {
log::error!("unexpected packet: {:?}", p);
continue;
};
let pong = packet::Packet::new_pong_packet(
conn_info.my_peer_id,
conn_info.peer_id,
seq.into(),
);
if let Err(e) = sink.send(pong.into()).await {
tracing::error!(?e, "peer conn send req error");
}
}
ArchivedPacketType::Pong => {
if let Err(e) = ctrl_sender.send(buf.into()) {
tracing::error!(?e, "peer conn send ctrl resp error");
}
}
_ => {
if sender.send(buf.into()).await.is_err() {
break;
}
}
}
}
tracing::info!("end recving peer conn packet");
if let Err(close_ret) = sink.close().await {
tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it");
}
if let Err(e) = close_event_sender.send(conn_id).await {
tracing::error!(error = ?e, "peer conn close event send error");
}
task_ret
}
.instrument(
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
),
);
}
pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), TunnelError> {
self.sink.send(msg).await
}
pub fn get_peer_id(&self) -> PeerId {
self.info.as_ref().unwrap().my_peer_id
}
pub fn get_network_identity(&self) -> NetworkIdentity {
self.info.as_ref().unwrap().network_identity.clone()
}
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {
self.close_event_sender = Some(sender);
}
pub fn get_stats(&self) -> PeerConnStats {
PeerConnStats {
latency_us: self.latency_stats.get_latency_us(),
tx_bytes: self.throughput.tx_bytes(),
rx_bytes: self.throughput.rx_bytes(),
tx_packets: self.throughput.tx_packets(),
rx_packets: self.throughput.rx_packets(),
}
}
pub fn get_conn_info(&self) -> PeerConnInfo {
PeerConnInfo {
conn_id: self.conn_id.to_string(),
my_peer_id: self.my_peer_id,
peer_id: self.get_peer_id(),
features: self.info.as_ref().unwrap().features.clone(),
tunnel: self.tunnel.info(),
stats: Some(self.get_stats()),
loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32,
}
}
}
impl Drop for PeerConn {
fn drop(&mut self) {
let mut sink = self.tunnel.pin_sink();
tokio::spawn(async move {
let ret = sink.close().await;
tracing::info!(error = ?ret, "peer conn tunnel closed.");
});
log::info!("peer conn {:?} drop", self.conn_id);
}
}
impl Debug for PeerConn {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PeerConn")
.field("conn_id", &self.conn_id)
.field("my_peer_id", &self.my_peer_id)
.field("info", &self.info)
.finish()
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use crate::common::global_ctx::tests::get_mock_global_ctx;
use crate::common::new_peer_id;
use crate::tunnels::tunnel_filter::tests::DropSendTunnelFilter;
use crate::tunnels::tunnel_filter::{PacketRecorderTunnelFilter, TunnelWithFilter};
#[tokio::test]
async fn peer_conn_handshake() {
use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
let (c, s) = create_ring_tunnel_pair();
let c_recorder = Arc::new(PacketRecorderTunnelFilter::new());
let s_recorder = Arc::new(PacketRecorderTunnelFilter::new());
let c = TunnelWithFilter::new(c, c_recorder.clone());
let s = TunnelWithFilter::new(s, s_recorder.clone());
let c_peer_id = new_peer_id();
let s_peer_id = new_peer_id();
let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));
let (c_ret, s_ret) = tokio::join!(
c_peer.do_handshake_as_client(),
s_peer.do_handshake_as_server()
);
c_ret.unwrap();
s_ret.unwrap();
assert_eq!(c_recorder.sent.lock().unwrap().len(), 1);
assert_eq!(c_recorder.received.lock().unwrap().len(), 1);
assert_eq!(s_recorder.sent.lock().unwrap().len(), 1);
assert_eq!(s_recorder.received.lock().unwrap().len(), 1);
assert_eq!(c_peer.get_peer_id(), s_peer_id);
assert_eq!(s_peer.get_peer_id(), c_peer_id);
assert_eq!(c_peer.get_network_identity(), s_peer.get_network_identity());
assert_eq!(c_peer.get_network_identity(), NetworkIdentity::default());
}
async fn peer_conn_pingpong_test_common(drop_start: u32, drop_end: u32, conn_closed: bool) {
use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
let (c, s) = create_ring_tunnel_pair();
// drop 1-3 packets should not affect pingpong
let c_recorder = Arc::new(DropSendTunnelFilter::new(drop_start, drop_end));
let c = TunnelWithFilter::new(c, c_recorder.clone());
let c_peer_id = new_peer_id();
let s_peer_id = new_peer_id();
let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));
let (c_ret, s_ret) = tokio::join!(
c_peer.do_handshake_as_client(),
s_peer.do_handshake_as_server()
);
s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0);
s_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
assert!(c_ret.is_ok());
assert!(s_ret.is_ok());
let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1);
c_peer.set_close_event_sender(close_send);
c_peer.start_pingpong();
c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
// wait 5s, conn should not be disconnected
tokio::time::sleep(Duration::from_secs(15)).await;
if conn_closed {
assert!(close_recv.try_recv().is_ok());
} else {
assert!(close_recv.try_recv().is_err());
}
}
#[tokio::test]
async fn peer_conn_pingpong_timeout() {
peer_conn_pingpong_test_common(3, 5, false).await;
peer_conn_pingpong_test_common(5, 12, true).await;
}
}
+641
View File
@@ -0,0 +1,641 @@
use std::{
fmt::Debug,
net::Ipv4Addr,
sync::{Arc, Weak},
};
use async_trait::async_trait;
use futures::StreamExt;
use tokio::{
sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender},
Mutex, RwLock,
},
task::JoinSet,
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::bytes::{Bytes, BytesMut};
use crate::{
common::{
error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_string,
PeerId,
},
peers::{
packet, peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport,
route_trait::RouteInterface, PeerPacketFilter,
},
tunnels::{SinkItem, Tunnel, TunnelConnector},
};
use super::{
foreign_network_client::ForeignNetworkClient,
foreign_network_manager::ForeignNetworkManager,
peer_conn::PeerConnId,
peer_map::PeerMap,
peer_ospf_route::PeerRoute,
peer_rip_route::BasicRoute,
peer_rpc::PeerRpcManager,
route_trait::{ArcRoute, Route},
BoxNicPacketFilter, BoxPeerPacketFilter,
};
struct RpcTransport {
my_peer_id: PeerId,
peers: Weak<PeerMap>,
foreign_peers: Mutex<Option<Weak<ForeignNetworkClient>>>,
packet_recv: Mutex<UnboundedReceiver<Bytes>>,
peer_rpc_tspt_sender: UnboundedSender<Bytes>,
}
#[async_trait::async_trait]
impl PeerRpcManagerTransport for RpcTransport {
fn my_peer_id(&self) -> PeerId {
self.my_peer_id
}
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
let foreign_peers = self
.foreign_peers
.lock()
.await
.as_ref()
.ok_or(Error::Unknown)?
.upgrade()
.ok_or(Error::Unknown)?;
let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
let ret = peers.send_msg(msg.clone(), dst_peer_id).await;
if matches!(ret, Err(Error::RouteError(..))) && foreign_peers.has_next_hop(dst_peer_id) {
tracing::info!(
?dst_peer_id,
?self.my_peer_id,
"failed to send msg to peer, try foreign network",
);
return foreign_peers.send_msg(msg, dst_peer_id).await;
}
ret
}
async fn recv(&self) -> Result<Bytes, Error> {
if let Some(o) = self.packet_recv.lock().await.recv().await {
Ok(o)
} else {
Err(Error::Unknown)
}
}
}
pub enum RouteAlgoType {
Rip,
Ospf,
None,
}
enum RouteAlgoInst {
Rip(Arc<BasicRoute>),
Ospf(Arc<PeerRoute>),
None,
}
pub struct PeerManager {
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
nic_channel: mpsc::Sender<SinkItem>,
tasks: Arc<Mutex<JoinSet<()>>>,
packet_recv: Arc<Mutex<Option<mpsc::Receiver<Bytes>>>>,
peers: Arc<PeerMap>,
peer_rpc_mgr: Arc<PeerRpcManager>,
peer_rpc_tspt: Arc<RpcTransport>,
peer_packet_process_pipeline: Arc<RwLock<Vec<BoxPeerPacketFilter>>>,
nic_packet_process_pipeline: Arc<RwLock<Vec<BoxNicPacketFilter>>>,
route_algo_inst: RouteAlgoInst,
foreign_network_manager: Arc<ForeignNetworkManager>,
foreign_network_client: Arc<ForeignNetworkClient>,
}
impl Debug for PeerManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PeerManager")
.field("my_peer_id", &self.my_peer_id())
.field("instance_name", &self.global_ctx.inst_name)
.field("net_ns", &self.global_ctx.net_ns.name())
.finish()
}
}
impl PeerManager {
pub fn new(
route_algo: RouteAlgoType,
global_ctx: ArcGlobalCtx,
nic_channel: mpsc::Sender<SinkItem>,
) -> Self {
let my_peer_id = rand::random();
let (packet_send, packet_recv) = mpsc::channel(100);
let peers = Arc::new(PeerMap::new(
packet_send.clone(),
global_ctx.clone(),
my_peer_id,
));
// TODO: remove these because we have impl pipeline processor.
let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
let rpc_tspt = Arc::new(RpcTransport {
my_peer_id,
peers: Arc::downgrade(&peers),
foreign_peers: Mutex::new(None),
packet_recv: Mutex::new(peer_rpc_tspt_recv),
peer_rpc_tspt_sender,
});
let peer_rpc_mgr = Arc::new(PeerRpcManager::new(rpc_tspt.clone()));
let route_algo_inst = match route_algo {
RouteAlgoType::Rip => {
RouteAlgoInst::Rip(Arc::new(BasicRoute::new(my_peer_id, global_ctx.clone())))
}
RouteAlgoType::Ospf => RouteAlgoInst::Ospf(PeerRoute::new(
my_peer_id,
global_ctx.clone(),
peer_rpc_mgr.clone(),
)),
RouteAlgoType::None => RouteAlgoInst::None,
};
let foreign_network_manager = Arc::new(ForeignNetworkManager::new(
my_peer_id,
global_ctx.clone(),
packet_send.clone(),
));
let foreign_network_client = Arc::new(ForeignNetworkClient::new(
global_ctx.clone(),
packet_send.clone(),
peer_rpc_mgr.clone(),
my_peer_id,
));
PeerManager {
my_peer_id,
global_ctx,
nic_channel,
tasks: Arc::new(Mutex::new(JoinSet::new())),
packet_recv: Arc::new(Mutex::new(Some(packet_recv))),
peers: peers.clone(),
peer_rpc_mgr,
peer_rpc_tspt: rpc_tspt,
peer_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
nic_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
route_algo_inst,
foreign_network_manager,
foreign_network_client,
}
}
pub async fn add_client_tunnel(
&self,
tunnel: Box<dyn Tunnel>,
) -> Result<(PeerId, PeerConnId), Error> {
let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
peer.do_handshake_as_client().await?;
let conn_id = peer.get_conn_id();
let peer_id = peer.get_peer_id();
if peer.get_network_identity() == self.global_ctx.get_network_identity() {
self.peers.add_new_peer_conn(peer).await;
} else {
self.foreign_network_client.add_new_peer_conn(peer).await;
}
Ok((peer_id, conn_id))
}
#[tracing::instrument]
pub async fn try_connect<C>(&self, mut connector: C) -> Result<(PeerId, PeerConnId), Error>
where
C: TunnelConnector + Debug,
{
let ns = self.global_ctx.net_ns.clone();
let t = ns
.run_async(|| async move { connector.connect().await })
.await?;
self.add_client_tunnel(t).await
}
#[tracing::instrument]
pub async fn add_tunnel_as_server(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
tracing::info!("add tunnel as server start");
let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
peer.do_handshake_as_server().await?;
if peer.get_network_identity() == self.global_ctx.get_network_identity() {
self.peers.add_new_peer_conn(peer).await;
} else {
self.foreign_network_manager.add_peer_conn(peer).await?;
}
tracing::info!("add tunnel as server done");
Ok(())
}
async fn start_peer_recv(&self) {
let mut recv = ReceiverStream::new(self.packet_recv.lock().await.take().unwrap());
let my_peer_id = self.my_peer_id;
let peers = self.peers.clone();
let pipe_line = self.peer_packet_process_pipeline.clone();
self.tasks.lock().await.spawn(async move {
log::trace!("start_peer_recv");
while let Some(ret) = recv.next().await {
log::trace!("peer recv a packet...: {:?}", ret);
let packet = packet::Packet::decode(&ret);
let from_peer_id: PeerId = packet.from_peer.into();
let to_peer_id: PeerId = packet.to_peer.into();
if to_peer_id != my_peer_id {
log::trace!(
"need forward: to_peer_id: {:?}, my_peer_id: {:?}",
to_peer_id,
my_peer_id
);
let ret = peers.send_msg(ret.clone(), to_peer_id).await;
if ret.is_err() {
log::error!(
"forward packet error: {:?}, dst: {:?}, from: {:?}",
ret,
to_peer_id,
from_peer_id
);
}
} else {
let mut processed = false;
for pipeline in pipe_line.read().await.iter().rev() {
if let Some(_) = pipeline.try_process_packet_from_peer(&packet, &ret).await
{
processed = true;
break;
}
}
if !processed {
tracing::error!("unexpected packet: {:?}", ret);
}
}
}
panic!("done_peer_recv");
});
}
pub async fn add_packet_process_pipeline(&self, pipeline: BoxPeerPacketFilter) {
// newest pipeline will be executed first
self.peer_packet_process_pipeline
.write()
.await
.push(pipeline);
}
pub async fn add_nic_packet_process_pipeline(&self, pipeline: BoxNicPacketFilter) {
// newest pipeline will be executed first
self.nic_packet_process_pipeline
.write()
.await
.push(pipeline);
}
async fn init_packet_process_pipeline(&self) {
// for tun/tap ip/eth packet.
struct NicPacketProcessor {
nic_channel: mpsc::Sender<SinkItem>,
}
#[async_trait::async_trait]
impl PeerPacketFilter for NicPacketProcessor {
async fn try_process_packet_from_peer(
&self,
packet: &packet::ArchivedPacket,
data: &Bytes,
) -> Option<()> {
if packet.packet_type == packet::PacketType::Data {
// TODO: use a function to get the body ref directly for zero copy
self.nic_channel
.send(extract_bytes_from_archived_string(data, &packet.payload))
.await
.unwrap();
Some(())
} else {
None
}
}
}
self.add_packet_process_pipeline(Box::new(NicPacketProcessor {
nic_channel: self.nic_channel.clone(),
}))
.await;
// for peer rpc packet
struct PeerRpcPacketProcessor {
peer_rpc_tspt_sender: UnboundedSender<Bytes>,
}
#[async_trait::async_trait]
impl PeerPacketFilter for PeerRpcPacketProcessor {
async fn try_process_packet_from_peer(
&self,
packet: &packet::ArchivedPacket,
data: &Bytes,
) -> Option<()> {
if packet.packet_type == packet::PacketType::TaRpc {
self.peer_rpc_tspt_sender.send(data.clone()).unwrap();
Some(())
} else {
None
}
}
}
self.add_packet_process_pipeline(Box::new(PeerRpcPacketProcessor {
peer_rpc_tspt_sender: self.peer_rpc_tspt.peer_rpc_tspt_sender.clone(),
}))
.await;
}
pub async fn add_route<T>(&self, route: T)
where
T: Route + PeerPacketFilter + Send + Sync + Clone + 'static,
{
// for route
self.add_packet_process_pipeline(Box::new(route.clone()))
.await;
struct Interface {
my_peer_id: PeerId,
peers: Weak<PeerMap>,
foreign_network_client: Weak<ForeignNetworkClient>,
}
#[async_trait]
impl RouteInterface for Interface {
async fn list_peers(&self) -> Vec<PeerId> {
let Some(foreign_client) = self.foreign_network_client.upgrade() else {
return vec![];
};
let Some(peer_map) = self.peers.upgrade() else {
return vec![];
};
let mut peers = foreign_client.list_foreign_peers();
peers.extend(peer_map.list_peers_with_conn().await);
peers
}
async fn send_route_packet(
&self,
msg: Bytes,
route_id: u8,
dst_peer_id: PeerId,
) -> Result<(), Error> {
let foreign_client = self
.foreign_network_client
.upgrade()
.ok_or(Error::Unknown)?;
let peer_map = self.peers.upgrade().ok_or(Error::Unknown)?;
let packet_bytes: Bytes =
packet::Packet::new_route_packet(self.my_peer_id, dst_peer_id, route_id, &msg)
.into();
if foreign_client.has_next_hop(dst_peer_id) {
return foreign_client.send_msg(packet_bytes, dst_peer_id).await;
}
peer_map.send_msg_directly(packet_bytes, dst_peer_id).await
}
fn my_peer_id(&self) -> PeerId {
self.my_peer_id
}
}
let my_peer_id = self.my_peer_id;
let _route_id = route
.open(Box::new(Interface {
my_peer_id,
peers: Arc::downgrade(&self.peers),
foreign_network_client: Arc::downgrade(&self.foreign_network_client),
}))
.await
.unwrap();
let arc_route: ArcRoute = Arc::new(Box::new(route));
self.peers.add_route(arc_route).await;
}
pub fn get_route(&self) -> Box<dyn Route + Send + Sync + 'static> {
match &self.route_algo_inst {
RouteAlgoInst::Rip(route) => Box::new(route.clone()),
RouteAlgoInst::Ospf(route) => Box::new(route.clone()),
RouteAlgoInst::None => panic!("no route"),
}
}
pub async fn list_routes(&self) -> Vec<crate::rpc::Route> {
self.get_route().list_routes().await
}
async fn run_nic_packet_process_pipeline(&self, mut data: BytesMut) -> BytesMut {
for pipeline in self.nic_packet_process_pipeline.read().await.iter().rev() {
data = pipeline.try_process_packet_from_nic(data).await;
}
data
}
pub async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
self.peers.send_msg(msg, dst_peer_id).await
}
pub async fn send_msg_ipv4(&self, msg: BytesMut, ipv4_addr: Ipv4Addr) -> Result<(), Error> {
log::trace!(
"do send_msg in peer manager, msg: {:?}, ipv4_addr: {}",
msg,
ipv4_addr
);
let mut dst_peers = vec![];
// NOTE: currently we only support ipv4 and cidr is 24
if ipv4_addr.is_broadcast() || ipv4_addr.is_multicast() || ipv4_addr.octets()[3] == 255 {
dst_peers.extend(
self.peers
.list_routes()
.await
.iter()
.map(|x| x.key().clone()),
);
} else if let Some(peer_id) = self.peers.get_peer_id_by_ipv4(&ipv4_addr).await {
dst_peers.push(peer_id);
}
if dst_peers.is_empty() {
tracing::info!("no peer id for ipv4: {}", ipv4_addr);
return Ok(());
}
let msg = self.run_nic_packet_process_pipeline(msg).await;
let mut errs: Vec<Error> = vec![];
for peer_id in dst_peers.iter() {
let msg: Bytes =
packet::Packet::new_data_packet(self.my_peer_id, peer_id.clone(), &msg).into();
let send_ret = self.peers.send_msg(msg.clone(), *peer_id).await;
if matches!(send_ret, Err(Error::RouteError(..)))
&& self.foreign_network_client.has_next_hop(*peer_id)
{
let foreign_send_ret = self.foreign_network_client.send_msg(msg, *peer_id).await;
if foreign_send_ret.is_ok() {
continue;
}
}
if let Err(send_ret) = send_ret {
errs.push(send_ret);
}
}
tracing::trace!(?dst_peers, "do send_msg in peer manager done");
if errs.is_empty() {
Ok(())
} else {
tracing::error!(?errs, "send_msg has error");
Err(anyhow::anyhow!("send_msg has error: {:?}", errs).into())
}
}
async fn run_clean_peer_without_conn_routine(&self) {
let peer_map = self.peers.clone();
self.tasks.lock().await.spawn(async move {
loop {
peer_map.clean_peer_without_conn().await;
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
}
});
}
async fn run_foriegn_network(&self) {
self.peer_rpc_tspt
.foreign_peers
.lock()
.await
.replace(Arc::downgrade(&self.foreign_network_client));
self.foreign_network_manager.run().await;
self.foreign_network_client.run().await;
}
pub async fn run(&self) -> Result<(), Error> {
match &self.route_algo_inst {
RouteAlgoInst::Ospf(route) => self.add_route(route.clone()).await,
RouteAlgoInst::Rip(route) => self.add_route(route.clone()).await,
RouteAlgoInst::None => {}
};
self.init_packet_process_pipeline().await;
self.peer_rpc_mgr.run();
self.start_peer_recv().await;
self.run_clean_peer_without_conn_routine().await;
self.run_foriegn_network().await;
Ok(())
}
pub fn get_peer_map(&self) -> Arc<PeerMap> {
self.peers.clone()
}
pub fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
self.peer_rpc_mgr.clone()
}
pub fn my_node_id(&self) -> uuid::Uuid {
self.global_ctx.get_id()
}
pub fn my_peer_id(&self) -> PeerId {
self.my_peer_id
}
pub fn get_global_ctx(&self) -> ArcGlobalCtx {
self.global_ctx.clone()
}
pub fn get_nic_channel(&self) -> mpsc::Sender<SinkItem> {
self.nic_channel.clone()
}
pub fn get_basic_route(&self) -> Arc<BasicRoute> {
match &self.route_algo_inst {
RouteAlgoInst::Rip(route) => route.clone(),
_ => panic!("not rip route"),
}
}
pub fn get_foreign_network_manager(&self) -> Arc<ForeignNetworkManager> {
self.foreign_network_manager.clone()
}
pub fn get_foreign_network_client(&self) -> Arc<ForeignNetworkClient> {
self.foreign_network_client.clone()
}
}
#[cfg(test)]
mod tests {
use crate::{
connector::udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun,
peers::tests::{connect_peer_manager, wait_for_condition, wait_route_appear},
rpc::NatType,
};
#[tokio::test]
async fn drop_peer_manager() {
let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_c.clone()).await;
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
.await
.unwrap();
wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
.await
.unwrap();
// wait mgr_a have 2 peers
wait_for_condition(
|| async { peer_mgr_a.get_peer_map().list_peers_with_conn().await.len() == 2 },
std::time::Duration::from_secs(5),
)
.await;
drop(peer_mgr_b);
wait_for_condition(
|| async { peer_mgr_a.get_peer_map().list_peers_with_conn().await.len() == 1 },
std::time::Duration::from_secs(5),
)
.await;
}
}
+234
View File
@@ -0,0 +1,234 @@
use std::{net::Ipv4Addr, sync::Arc};
use anyhow::Context;
use dashmap::DashMap;
use tokio::sync::{mpsc, RwLock};
use tokio_util::bytes::Bytes;
use crate::{
common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
PeerId,
},
rpc::PeerConnInfo,
tunnels::TunnelError,
};
use super::{
peer::Peer,
peer_conn::{PeerConn, PeerConnId},
route_trait::ArcRoute,
};
pub struct PeerMap {
global_ctx: ArcGlobalCtx,
my_peer_id: PeerId,
peer_map: DashMap<PeerId, Arc<Peer>>,
packet_send: mpsc::Sender<Bytes>,
routes: RwLock<Vec<ArcRoute>>,
}
impl PeerMap {
pub fn new(
packet_send: mpsc::Sender<Bytes>,
global_ctx: ArcGlobalCtx,
my_peer_id: PeerId,
) -> Self {
PeerMap {
global_ctx,
my_peer_id,
peer_map: DashMap::new(),
packet_send,
routes: RwLock::new(Vec::new()),
}
}
async fn add_new_peer(&self, peer: Peer) {
let peer_id = peer.peer_node_id.clone();
self.peer_map.insert(peer_id.clone(), Arc::new(peer));
self.global_ctx
.issue_event(GlobalCtxEvent::PeerAdded(peer_id));
}
pub async fn add_new_peer_conn(&self, peer_conn: PeerConn) {
let peer_id = peer_conn.get_peer_id();
let no_entry = self.peer_map.get(&peer_id).is_none();
if no_entry {
let new_peer = Peer::new(peer_id, self.packet_send.clone(), self.global_ctx.clone());
new_peer.add_peer_conn(peer_conn).await;
self.add_new_peer(new_peer).await;
} else {
let peer = self.peer_map.get(&peer_id).unwrap().clone();
peer.add_peer_conn(peer_conn).await;
}
}
fn get_peer_by_id(&self, peer_id: PeerId) -> Option<Arc<Peer>> {
self.peer_map.get(&peer_id).map(|v| v.clone())
}
pub fn has_peer(&self, peer_id: PeerId) -> bool {
self.peer_map.contains_key(&peer_id)
}
pub async fn send_msg_directly(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
if dst_peer_id == self.my_peer_id {
return Ok(self
.packet_send
.send(msg)
.await
.with_context(|| "send msg to self failed")?);
}
match self.get_peer_by_id(dst_peer_id) {
Some(peer) => {
peer.send_msg(msg).await?;
}
None => {
log::error!("no peer for dst_peer_id: {}", dst_peer_id);
return Err(Error::RouteError(None));
}
}
Ok(())
}
pub async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
if dst_peer_id == self.my_peer_id {
return Ok(self
.packet_send
.send(msg)
.await
.with_context(|| "send msg to self failed")?);
}
// get route info
let mut gateway_peer_id = None;
for route in self.routes.read().await.iter() {
gateway_peer_id = route.get_next_hop(dst_peer_id).await;
if gateway_peer_id.is_none() {
continue;
} else {
break;
}
}
if gateway_peer_id.is_none() && self.has_peer(dst_peer_id) {
gateway_peer_id = Some(dst_peer_id);
}
let Some(gateway_peer_id) = gateway_peer_id else {
tracing::trace!(
"no gateway for dst_peer_id: {}, peers: {:?}, my_peer_id: {}",
dst_peer_id,
self.peer_map.iter().map(|v| *v.key()).collect::<Vec<_>>(),
self.my_peer_id
);
return Err(Error::RouteError(None));
};
self.send_msg_directly(msg.clone(), gateway_peer_id).await?;
return Ok(());
}
pub async fn get_peer_id_by_ipv4(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
for route in self.routes.read().await.iter() {
let peer_id = route.get_peer_id_by_ipv4(ipv4).await;
if peer_id.is_some() {
return peer_id;
}
}
None
}
pub fn is_empty(&self) -> bool {
self.peer_map.is_empty()
}
pub async fn list_peers(&self) -> Vec<PeerId> {
let mut ret = Vec::new();
for item in self.peer_map.iter() {
let peer_id = item.key();
ret.push(*peer_id);
}
ret
}
pub async fn list_peers_with_conn(&self) -> Vec<PeerId> {
let mut ret = Vec::new();
let peers = self.list_peers().await;
for peer_id in peers.iter() {
let Some(peer) = self.get_peer_by_id(*peer_id) else {
continue;
};
if peer.list_peer_conns().await.len() > 0 {
ret.push(*peer_id);
}
}
ret
}
pub async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>> {
if let Some(p) = self.get_peer_by_id(peer_id) {
Some(p.list_peer_conns().await)
} else {
return None;
}
}
pub async fn close_peer_conn(
&self,
peer_id: PeerId,
conn_id: &PeerConnId,
) -> Result<(), Error> {
if let Some(p) = self.get_peer_by_id(peer_id) {
p.close_peer_conn(conn_id).await
} else {
return Err(Error::NotFound);
}
}
pub async fn close_peer(&self, peer_id: PeerId) -> Result<(), TunnelError> {
let remove_ret = self.peer_map.remove(&peer_id);
self.global_ctx
.issue_event(GlobalCtxEvent::PeerRemoved(peer_id));
tracing::info!(
?peer_id,
has_old_value = ?remove_ret.is_some(),
peer_ref_counter = ?remove_ret.map(|v| Arc::strong_count(&v.1)),
"peer is closed"
);
Ok(())
}
pub async fn add_route(&self, route: ArcRoute) {
let mut routes = self.routes.write().await;
routes.insert(0, route);
}
pub async fn clean_peer_without_conn(&self) {
let mut to_remove = vec![];
for peer_id in self.list_peers().await {
let conns = self.list_peer_conns(peer_id).await;
if conns.is_none() || conns.as_ref().unwrap().is_empty() {
to_remove.push(peer_id);
}
}
for peer_id in to_remove {
self.close_peer(peer_id).await.unwrap();
}
}
pub async fn list_routes(&self) -> DashMap<PeerId, PeerId> {
let route_map = DashMap::new();
for route in self.routes.read().await.iter() {
for item in route.list_routes().await.iter() {
route_map.insert(item.peer_id, item.next_hop_peer_id);
}
}
route_map
}
}
File diff suppressed because it is too large Load Diff
+770
View File
@@ -0,0 +1,770 @@
use std::{
net::Ipv4Addr,
sync::{atomic::AtomicU32, Arc},
time::{Duration, Instant},
};
use async_trait::async_trait;
use dashmap::DashMap;
use tokio::{
sync::{Mutex, RwLock},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
peers::{
packet,
route_trait::{Route, RouteInterfaceBox},
},
rpc::{NatType, StunInfo},
};
use super::{packet::CtrlPacketPayload, PeerPacketFilter};
const SEND_ROUTE_PERIOD_SEC: u64 = 60;
const SEND_ROUTE_FAST_REPLY_SEC: u64 = 5;
const ROUTE_EXPIRED_SEC: u64 = 70;
type Version = u32;
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug, PartialEq)]
// Derives can be passed through to the generated type:
pub struct SyncPeerInfo {
// means next hop in route table.
pub peer_id: PeerId,
pub cost: u32,
pub ipv4_addr: Option<Ipv4Addr>,
pub proxy_cidrs: Vec<String>,
pub hostname: Option<String>,
pub udp_stun_info: i8,
}
impl SyncPeerInfo {
pub fn new_self(from_peer: PeerId, global_ctx: &ArcGlobalCtx) -> Self {
SyncPeerInfo {
peer_id: from_peer,
cost: 0,
ipv4_addr: global_ctx.get_ipv4(),
proxy_cidrs: global_ctx
.get_proxy_cidrs()
.iter()
.map(|x| x.to_string())
.chain(global_ctx.get_vpn_portal_cidr().map(|x| x.to_string()))
.collect(),
hostname: global_ctx.get_hostname(),
udp_stun_info: global_ctx
.get_stun_info_collector()
.get_stun_info()
.udp_nat_type as i8,
}
}
pub fn clone_for_route_table(&self, next_hop: PeerId, cost: u32, from: &Self) -> Self {
SyncPeerInfo {
peer_id: next_hop,
cost,
ipv4_addr: from.ipv4_addr.clone(),
proxy_cidrs: from.proxy_cidrs.clone(),
hostname: from.hostname.clone(),
udp_stun_info: from.udp_stun_info,
}
}
}
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
pub struct SyncPeer {
pub myself: SyncPeerInfo,
pub neighbors: Vec<SyncPeerInfo>,
// the route table version of myself
pub version: Version,
// the route table version of peer that we have received last time
pub peer_version: Option<Version>,
// if we do not have latest peer version, need_reply is true
pub need_reply: bool,
}
impl SyncPeer {
pub fn new(
from_peer: PeerId,
_to_peer: PeerId,
neighbors: Vec<SyncPeerInfo>,
global_ctx: ArcGlobalCtx,
version: Version,
peer_version: Option<Version>,
need_reply: bool,
) -> Self {
SyncPeer {
myself: SyncPeerInfo::new_self(from_peer, &global_ctx),
neighbors,
version,
peer_version,
need_reply,
}
}
}
#[derive(Debug)]
struct SyncPeerFromRemote {
packet: SyncPeer,
last_update: std::time::Instant,
}
type SyncPeerFromRemoteMap = Arc<DashMap<PeerId, SyncPeerFromRemote>>;
#[derive(Debug)]
struct RouteTable {
route_info: DashMap<PeerId, SyncPeerInfo>,
ipv4_peer_id_map: DashMap<Ipv4Addr, PeerId>,
cidr_peer_id_map: DashMap<cidr::IpCidr, PeerId>,
}
impl RouteTable {
fn new() -> Self {
RouteTable {
route_info: DashMap::new(),
ipv4_peer_id_map: DashMap::new(),
cidr_peer_id_map: DashMap::new(),
}
}
fn copy_from(&self, other: &Self) {
self.route_info.clear();
for item in other.route_info.iter() {
let (k, v) = item.pair();
self.route_info.insert(*k, v.clone());
}
self.ipv4_peer_id_map.clear();
for item in other.ipv4_peer_id_map.iter() {
let (k, v) = item.pair();
self.ipv4_peer_id_map.insert(*k, *v);
}
self.cidr_peer_id_map.clear();
for item in other.cidr_peer_id_map.iter() {
let (k, v) = item.pair();
self.cidr_peer_id_map.insert(*k, *v);
}
}
}
#[derive(Debug, Clone)]
struct RouteVersion(Arc<AtomicU32>);
impl RouteVersion {
fn new() -> Self {
// RouteVersion(Arc::new(AtomicU32::new(rand::random())))
RouteVersion(Arc::new(AtomicU32::new(0)))
}
fn get(&self) -> Version {
self.0.load(std::sync::atomic::Ordering::Relaxed)
}
fn inc(&self) {
self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
pub struct BasicRoute {
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
interface: Arc<Mutex<Option<RouteInterfaceBox>>>,
route_table: Arc<RouteTable>,
sync_peer_from_remote: SyncPeerFromRemoteMap,
tasks: Mutex<JoinSet<()>>,
need_sync_notifier: Arc<tokio::sync::Notify>,
version: RouteVersion,
myself: Arc<RwLock<SyncPeerInfo>>,
last_send_time_map: Arc<DashMap<PeerId, (Version, Option<Version>, Instant)>>,
}
impl BasicRoute {
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx) -> Self {
BasicRoute {
my_peer_id,
global_ctx: global_ctx.clone(),
interface: Arc::new(Mutex::new(None)),
route_table: Arc::new(RouteTable::new()),
sync_peer_from_remote: Arc::new(DashMap::new()),
tasks: Mutex::new(JoinSet::new()),
need_sync_notifier: Arc::new(tokio::sync::Notify::new()),
version: RouteVersion::new(),
myself: Arc::new(RwLock::new(SyncPeerInfo::new_self(
my_peer_id.into(),
&global_ctx,
))),
last_send_time_map: Arc::new(DashMap::new()),
}
}
fn update_route_table(
my_id: PeerId,
sync_peer_reqs: SyncPeerFromRemoteMap,
route_table: Arc<RouteTable>,
) {
tracing::trace!(my_id = ?my_id, route_table = ?route_table, "update route table");
let new_route_table = Arc::new(RouteTable::new());
for item in sync_peer_reqs.iter() {
Self::update_route_table_with_req(my_id, &item.value().packet, new_route_table.clone());
}
route_table.copy_from(&new_route_table);
}
async fn update_myself(
my_peer_id: PeerId,
myself: &Arc<RwLock<SyncPeerInfo>>,
global_ctx: &ArcGlobalCtx,
) -> bool {
let new_myself = SyncPeerInfo::new_self(my_peer_id, &global_ctx);
if *myself.read().await != new_myself {
*myself.write().await = new_myself;
true
} else {
false
}
}
fn update_route_table_with_req(my_id: PeerId, packet: &SyncPeer, route_table: Arc<RouteTable>) {
let peer_id = packet.myself.peer_id.clone();
let update = |cost: u32, peer_info: &SyncPeerInfo| {
let node_id: PeerId = peer_info.peer_id.into();
let ret = route_table
.route_info
.entry(node_id.clone().into())
.and_modify(|info| {
if info.cost > cost {
*info = info.clone_for_route_table(peer_id, cost, &peer_info);
}
})
.or_insert(
peer_info
.clone()
.clone_for_route_table(peer_id, cost, &peer_info),
)
.value()
.clone();
if ret.cost > 6 {
log::error!(
"cost too large: {}, may lost connection, remove it",
ret.cost
);
route_table.route_info.remove(&node_id);
}
log::trace!(
"update route info, to: {:?}, gateway: {:?}, cost: {}, peer: {:?}",
node_id,
peer_id,
cost,
&peer_info
);
if let Some(ipv4) = peer_info.ipv4_addr {
route_table
.ipv4_peer_id_map
.insert(ipv4.clone(), node_id.clone().into());
}
for cidr in peer_info.proxy_cidrs.iter() {
let cidr: cidr::IpCidr = cidr.parse().unwrap();
route_table
.cidr_peer_id_map
.insert(cidr, node_id.clone().into());
}
};
for neighbor in packet.neighbors.iter() {
if neighbor.peer_id == my_id {
continue;
}
update(neighbor.cost + 1, &neighbor);
log::trace!("route info: {:?}", neighbor);
}
// add the sender peer to route info
update(1, &packet.myself);
log::trace!("my_id: {:?}, current route table: {:?}", my_id, route_table);
}
async fn send_sync_peer_request(
interface: &RouteInterfaceBox,
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
peer_id: PeerId,
route_table: Arc<RouteTable>,
my_version: Version,
peer_version: Option<Version>,
need_reply: bool,
) -> Result<(), Error> {
let mut route_info_copy: Vec<SyncPeerInfo> = Vec::new();
// copy the route info
for item in route_table.route_info.iter() {
let (k, v) = item.pair();
route_info_copy.push(v.clone().clone_for_route_table(*k, v.cost, &v));
}
let msg = SyncPeer::new(
my_peer_id,
peer_id,
route_info_copy,
global_ctx,
my_version,
peer_version,
need_reply,
);
// TODO: this may exceed the MTU of the tunnel
interface
.send_route_packet(postcard::to_allocvec(&msg).unwrap().into(), 1, peer_id)
.await
}
async fn sync_peer_periodically(&self) {
let route_table = self.route_table.clone();
let global_ctx = self.global_ctx.clone();
let my_peer_id = self.my_peer_id.clone();
let interface = self.interface.clone();
let notifier = self.need_sync_notifier.clone();
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
let myself = self.myself.clone();
let version = self.version.clone();
let last_send_time_map = self.last_send_time_map.clone();
self.tasks.lock().await.spawn(
async move {
loop {
if Self::update_myself(my_peer_id,&myself, &global_ctx).await {
version.inc();
tracing::info!(
my_id = ?my_peer_id,
version = version.get(),
"update route table version when myself changed"
);
}
let lockd_interface = interface.lock().await;
let interface = lockd_interface.as_ref().unwrap();
let last_send_time_map_new = DashMap::new();
let peers = interface.list_peers().await;
for peer in peers.iter() {
let last_send_time = last_send_time_map.get(peer).map(|v| *v).unwrap_or((0, None, Instant::now() - Duration::from_secs(3600)));
let my_version_peer_saved = sync_peer_from_remote.get(peer).and_then(|v| v.packet.peer_version);
let peer_have_latest_version = my_version_peer_saved == Some(version.get());
if peer_have_latest_version && last_send_time.2.elapsed().as_secs() < SEND_ROUTE_PERIOD_SEC {
last_send_time_map_new.insert(*peer, last_send_time);
continue;
}
tracing::trace!(
my_id = ?my_peer_id,
dst_peer_id = ?peer,
version = version.get(),
?my_version_peer_saved,
last_send_version = ?last_send_time.0,
last_send_peer_version = ?last_send_time.1,
last_send_elapse = ?last_send_time.2.elapsed().as_secs(),
"need send route info"
);
let peer_version_we_saved = sync_peer_from_remote.get(&peer).and_then(|v| Some(v.packet.version));
last_send_time_map_new.insert(*peer, (version.get(), peer_version_we_saved, Instant::now()));
let ret = Self::send_sync_peer_request(
interface,
my_peer_id.clone(),
global_ctx.clone(),
*peer,
route_table.clone(),
version.get(),
peer_version_we_saved,
!peer_have_latest_version,
)
.await;
match &ret {
Ok(_) => {
log::trace!("send sync peer request to peer: {}", peer);
}
Err(Error::PeerNoConnectionError(_)) => {
log::trace!("peer {} no connection", peer);
}
Err(e) => {
log::error!(
"send sync peer request to peer: {} error: {:?}",
peer,
e
);
}
};
}
last_send_time_map.clear();
for item in last_send_time_map_new.iter() {
let (k, v) = item.pair();
last_send_time_map.insert(*k, *v);
}
tokio::select! {
_ = notifier.notified() => {
log::trace!("sync peer request triggered by notifier");
}
_ = tokio::time::sleep(Duration::from_secs(1)) => {
log::trace!("sync peer request triggered by timeout");
}
}
}
}
.instrument(
tracing::info_span!("sync_peer_periodically", my_id = ?self.my_peer_id, global_ctx = ?self.global_ctx),
),
);
}
async fn check_expired_sync_peer_from_remote(&self) {
let route_table = self.route_table.clone();
let my_peer_id = self.my_peer_id.clone();
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
let notifier = self.need_sync_notifier.clone();
let interface = self.interface.clone();
let version = self.version.clone();
self.tasks.lock().await.spawn(async move {
loop {
let mut need_update_route = false;
let now = std::time::Instant::now();
let mut need_remove = Vec::new();
let connected_peers = interface.lock().await.as_ref().unwrap().list_peers().await;
for item in sync_peer_from_remote.iter() {
let (k, v) = item.pair();
if now.duration_since(v.last_update).as_secs() > ROUTE_EXPIRED_SEC
|| !connected_peers.contains(k)
{
need_update_route = true;
need_remove.insert(0, k.clone());
}
}
for k in need_remove.iter() {
log::warn!("remove expired sync peer: {:?}", k);
sync_peer_from_remote.remove(k);
}
if need_update_route {
Self::update_route_table(
my_peer_id,
sync_peer_from_remote.clone(),
route_table.clone(),
);
version.inc();
tracing::info!(
my_id = ?my_peer_id,
version = version.get(),
"update route table when check expired peer"
);
notifier.notify_one();
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
});
}
fn get_peer_id_for_proxy(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
let ipv4 = std::net::IpAddr::V4(*ipv4);
for item in self.route_table.cidr_peer_id_map.iter() {
let (k, v) = item.pair();
if k.contains(&ipv4) {
return Some(*v);
}
}
None
}
#[tracing::instrument(skip(self, packet), fields(my_id = ?self.my_peer_id, ctx = ?self.global_ctx))]
async fn handle_route_packet(&self, src_peer_id: PeerId, packet: Bytes) {
let packet = postcard::from_bytes::<SyncPeer>(&packet).unwrap();
let p = &packet;
let mut updated = true;
assert_eq!(packet.myself.peer_id, src_peer_id);
self.sync_peer_from_remote
.entry(packet.myself.peer_id.into())
.and_modify(|v| {
if v.packet.myself == p.myself && v.packet.neighbors == p.neighbors {
updated = false;
} else {
v.packet = p.clone();
}
v.packet.version = p.version;
v.packet.peer_version = p.peer_version;
v.last_update = std::time::Instant::now();
})
.or_insert(SyncPeerFromRemote {
packet: p.clone(),
last_update: std::time::Instant::now(),
});
if updated {
Self::update_route_table(
self.my_peer_id.clone(),
self.sync_peer_from_remote.clone(),
self.route_table.clone(),
);
self.version.inc();
tracing::info!(
my_id = ?self.my_peer_id,
?p,
version = self.version.get(),
"update route table when receive route packet"
);
}
if packet.need_reply {
self.last_send_time_map
.entry(packet.myself.peer_id.into())
.and_modify(|v| {
const FAST_REPLY_DURATION: u64 =
SEND_ROUTE_PERIOD_SEC - SEND_ROUTE_FAST_REPLY_SEC;
if v.0 != self.version.get() || v.1 != Some(p.version) {
v.2 = Instant::now() - Duration::from_secs(3600);
} else if v.2.elapsed().as_secs() < FAST_REPLY_DURATION {
// do not send same version route info too frequently
v.2 = Instant::now() - Duration::from_secs(FAST_REPLY_DURATION);
}
});
}
if updated || packet.need_reply {
self.need_sync_notifier.notify_one();
}
}
}
#[async_trait]
impl Route for BasicRoute {
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()> {
*self.interface.lock().await = Some(interface);
self.sync_peer_periodically().await;
self.check_expired_sync_peer_from_remote().await;
Ok(1)
}
async fn close(&self) {}
async fn get_next_hop(&self, dst_peer_id: PeerId) -> Option<PeerId> {
match self.route_table.route_info.get(&dst_peer_id) {
Some(info) => {
return Some(info.peer_id.clone().into());
}
None => {
log::error!("no route info for dst_peer_id: {}", dst_peer_id);
return None;
}
}
}
async fn list_routes(&self) -> Vec<crate::rpc::Route> {
let mut routes = Vec::new();
let parse_route_info = |real_peer_id: PeerId, route_info: &SyncPeerInfo| {
let mut route = crate::rpc::Route::default();
route.ipv4_addr = if let Some(ipv4_addr) = route_info.ipv4_addr {
ipv4_addr.to_string()
} else {
"".to_string()
};
route.peer_id = real_peer_id;
route.next_hop_peer_id = route_info.peer_id;
route.cost = route_info.cost as i32;
route.proxy_cidrs = route_info.proxy_cidrs.clone();
route.hostname = if let Some(hostname) = &route_info.hostname {
hostname.clone()
} else {
"".to_string()
};
let mut stun_info = StunInfo::default();
if let Ok(udp_nat_type) = NatType::try_from(route_info.udp_stun_info as i32) {
stun_info.set_udp_nat_type(udp_nat_type);
}
route.stun_info = Some(stun_info);
route
};
self.route_table.route_info.iter().for_each(|item| {
routes.push(parse_route_info(*item.key(), item.value()));
});
routes
}
async fn get_peer_id_by_ipv4(&self, ipv4_addr: &Ipv4Addr) -> Option<PeerId> {
if let Some(peer_id) = self.route_table.ipv4_peer_id_map.get(ipv4_addr) {
return Some(*peer_id);
}
if let Some(peer_id) = self.get_peer_id_for_proxy(ipv4_addr) {
return Some(peer_id);
}
log::info!("no peer id for ipv4: {}", ipv4_addr);
return None;
}
}
#[async_trait::async_trait]
impl PeerPacketFilter for BasicRoute {
async fn try_process_packet_from_peer(
&self,
packet: &packet::ArchivedPacket,
_data: &Bytes,
) -> Option<()> {
if packet.packet_type == packet::PacketType::RoutePacket {
let CtrlPacketPayload::RoutePacket(route_packet) =
CtrlPacketPayload::from_packet(packet)
else {
return None;
};
self.handle_route_packet(
packet.from_peer.into(),
route_packet.body.into_boxed_slice().into(),
)
.await;
Some(())
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::{
common::{global_ctx::tests::get_mock_global_ctx, PeerId},
connector::udp_hole_punch::tests::replace_stun_info_collector,
peers::{
peer_manager::{PeerManager, RouteAlgoType},
peer_rip_route::Version,
tests::{connect_peer_manager, wait_route_appear},
},
rpc::NatType,
};
async fn create_mock_pmgr() -> Arc<PeerManager> {
let (s, _r) = tokio::sync::mpsc::channel(1000);
let peer_mgr = Arc::new(PeerManager::new(
RouteAlgoType::Rip,
get_mock_global_ctx(),
s,
));
replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
peer_mgr.run().await.unwrap();
peer_mgr
}
#[tokio::test]
async fn test_rip_route() {
let peer_mgr_a = create_mock_pmgr().await;
let peer_mgr_b = create_mock_pmgr().await;
let peer_mgr_c = create_mock_pmgr().await;
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
.await
.unwrap();
wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
.await
.unwrap();
let mgrs = vec![peer_mgr_a.clone(), peer_mgr_b.clone(), peer_mgr_c.clone()];
tokio::time::sleep(tokio::time::Duration::from_secs(4)).await;
let check_version = |version: Version, peer_id: PeerId, mgrs: &Vec<Arc<PeerManager>>| {
for mgr in mgrs.iter() {
tracing::warn!(
"check version: {:?}, {:?}, {:?}, {:?}",
version,
peer_id,
mgr,
mgr.get_basic_route().sync_peer_from_remote
);
assert_eq!(
version,
mgr.get_basic_route()
.sync_peer_from_remote
.get(&peer_id)
.unwrap()
.packet
.version,
);
assert_eq!(
mgr.get_basic_route()
.sync_peer_from_remote
.get(&peer_id)
.unwrap()
.packet
.peer_version
.unwrap(),
mgr.get_basic_route().version.get()
);
}
};
let check_sanity = || {
// check peer version in other peer mgr are correct.
check_version(
peer_mgr_b.get_basic_route().version.get(),
peer_mgr_b.my_peer_id(),
&vec![peer_mgr_a.clone(), peer_mgr_c.clone()],
);
check_version(
peer_mgr_a.get_basic_route().version.get(),
peer_mgr_a.my_peer_id(),
&vec![peer_mgr_b.clone()],
);
check_version(
peer_mgr_c.get_basic_route().version.get(),
peer_mgr_c.my_peer_id(),
&vec![peer_mgr_b.clone()],
);
};
check_sanity();
let versions = mgrs
.iter()
.map(|x| x.get_basic_route().version.get())
.collect::<Vec<_>>();
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
let versions2 = mgrs
.iter()
.map(|x| x.get_basic_route().version.get())
.collect::<Vec<_>>();
assert_eq!(versions, versions2);
check_sanity();
assert!(peer_mgr_a.get_basic_route().version.get() <= 3);
assert!(peer_mgr_b.get_basic_route().version.get() <= 6);
assert!(peer_mgr_c.get_basic_route().version.get() <= 3);
}
}
+581
View File
@@ -0,0 +1,581 @@
use std::sync::{atomic::AtomicU32, Arc};
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use rkyv::Deserialize;
use tarpc::{server::Channel, transport::channel::UnboundedChannel};
use tokio::{
sync::mpsc::{self, UnboundedSender},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use crate::{
common::{error::Error, PeerId},
peers::packet::Packet,
};
use super::packet::CtrlPacketPayload;
type PeerRpcServiceId = u32;
type PeerRpcTransactId = u32;
#[async_trait::async_trait]
#[auto_impl::auto_impl(Arc)]
pub trait PeerRpcManagerTransport: Send + Sync + 'static {
fn my_peer_id(&self) -> PeerId;
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error>;
async fn recv(&self) -> Result<Bytes, Error>;
}
type PacketSender = UnboundedSender<Packet>;
struct PeerRpcEndPoint {
peer_id: PeerId,
packet_sender: PacketSender,
tasks: JoinSet<()>,
}
type PeerRpcEndPointCreator = Box<dyn Fn(PeerId) -> PeerRpcEndPoint + Send + Sync + 'static>;
#[derive(Hash, Eq, PartialEq, Clone)]
struct PeerRpcClientCtxKey(PeerId, PeerRpcServiceId, PeerRpcTransactId);
// handle rpc request from one peer
pub struct PeerRpcManager {
service_map: Arc<DashMap<PeerRpcServiceId, PacketSender>>,
tasks: JoinSet<()>,
tspt: Arc<Box<dyn PeerRpcManagerTransport>>,
service_registry: Arc<DashMap<PeerRpcServiceId, PeerRpcEndPointCreator>>,
peer_rpc_endpoints: Arc<DashMap<(PeerId, PeerRpcServiceId), PeerRpcEndPoint>>,
client_resp_receivers: Arc<DashMap<PeerRpcClientCtxKey, PacketSender>>,
transact_id: AtomicU32,
}
impl std::fmt::Debug for PeerRpcManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PeerRpcManager")
.field("node_id", &self.tspt.my_peer_id())
.finish()
}
}
#[derive(Debug)]
struct TaRpcPacketInfo {
from_peer: PeerId,
to_peer: PeerId,
service_id: PeerRpcServiceId,
transact_id: PeerRpcTransactId,
is_req: bool,
content: Vec<u8>,
}
impl PeerRpcManager {
pub fn new(tspt: impl PeerRpcManagerTransport) -> Self {
Self {
service_map: Arc::new(DashMap::new()),
tasks: JoinSet::new(),
tspt: Arc::new(Box::new(tspt)),
service_registry: Arc::new(DashMap::new()),
peer_rpc_endpoints: Arc::new(DashMap::new()),
client_resp_receivers: Arc::new(DashMap::new()),
transact_id: AtomicU32::new(0),
}
}
pub fn run_service<S, Req>(self: &Self, service_id: PeerRpcServiceId, s: S) -> ()
where
S: tarpc::server::Serve<Req> + Clone + Send + Sync + 'static,
Req: Send + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
S::Resp:
Send + std::fmt::Debug + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
S::Fut: Send + 'static,
{
let tspt = self.tspt.clone();
let creator = Box::new(move |peer_id: PeerId| {
let mut tasks = JoinSet::new();
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = tarpc::server::BaseChannel::with_defaults(server_transport);
let my_peer_id_clone = tspt.my_peer_id();
let peer_id_clone = peer_id.clone();
let o = server.execute(s.clone());
tasks.spawn(o);
let tspt = tspt.clone();
tasks.spawn(async move {
let mut cur_req_peer_id = None;
let mut cur_transact_id = 0;
loop {
tokio::select! {
Some(resp) = client_transport.next() => {
let Some(cur_req_peer_id) = cur_req_peer_id.take() else {
tracing::error!("[PEER RPC MGR] cur_req_peer_id is none, ignore this resp");
continue;
};
tracing::trace!(resp = ?resp, "recv packet from client");
if resp.is_err() {
tracing::warn!(err = ?resp.err(),
"[PEER RPC MGR] client_transport in server side got channel error, ignore it.");
continue;
}
let resp = resp.unwrap();
let serialized_resp = postcard::to_allocvec(&resp);
if serialized_resp.is_err() {
tracing::error!(error = ?serialized_resp.err(), "serialize resp failed");
continue;
}
let msg = Packet::new_tarpc_packet(
tspt.my_peer_id(),
cur_req_peer_id,
service_id,
cur_transact_id,
false,
serialized_resp.unwrap(),
);
if let Err(e) = tspt.send(msg.into(), peer_id).await {
tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed");
}
}
Some(packet) = packet_receiver.recv() => {
let info = Self::parse_rpc_packet(&packet);
if let Err(e) = info {
tracing::error!(error = ?e, packet = ?packet, "parse rpc packet failed");
continue;
}
let info = info.unwrap();
if info.from_peer != peer_id {
tracing::warn!("recv packet from peer, but peer_id not match, ignore it");
continue;
}
if cur_req_peer_id.is_some() {
tracing::warn!("cur_req_peer_id is not none, ignore this packet");
continue;
}
assert_eq!(info.service_id, service_id);
cur_req_peer_id = Some(packet.from_peer.clone().into());
cur_transact_id = info.transact_id;
tracing::trace!("recv packet from peer, packet: {:?}", packet);
let decoded_ret = postcard::from_bytes(&info.content.as_slice());
if let Err(e) = decoded_ret {
tracing::error!(error = ?e, "decode rpc packet failed");
continue;
}
let decoded: tarpc::ClientMessage<Req> = decoded_ret.unwrap();
if let Err(e) = client_transport.send(decoded).await {
tracing::error!(error = ?e, "send to req to client transport failed");
}
}
else => {
tracing::warn!("[PEER RPC MGR] service runner destroy, peer_id: {}, service_id: {}", peer_id, service_id);
}
}
}
}.instrument(tracing::info_span!("service_runner", my_id = ?my_peer_id_clone, peer_id = ?peer_id_clone, service_id = ?service_id)));
tracing::info!(
"[PEER RPC MGR] create new service endpoint for peer {}, service {}",
peer_id,
service_id
);
return PeerRpcEndPoint {
peer_id,
packet_sender,
tasks,
};
// let resp = client_transport.next().await;
});
if let Some(_) = self.service_registry.insert(service_id, creator) {
panic!(
"[PEER RPC MGR] service {} is already registered",
service_id
);
}
log::info!(
"[PEER RPC MGR] register service {} succeed, my_node_id {}",
service_id,
self.tspt.my_peer_id()
)
}
fn parse_rpc_packet(packet: &Packet) -> Result<TaRpcPacketInfo, Error> {
let ctrl_packet_payload = CtrlPacketPayload::from_packet2(&packet);
match &ctrl_packet_payload {
CtrlPacketPayload::TaRpc(id, tid, is_req, body) => Ok(TaRpcPacketInfo {
from_peer: packet.from_peer.into(),
to_peer: packet.to_peer.into(),
service_id: *id,
transact_id: *tid,
is_req: *is_req,
content: body.clone(),
}),
_ => Err(Error::ShellCommandError("invalid packet".to_owned())),
}
}
pub fn run(&self) {
let tspt = self.tspt.clone();
let service_registry = self.service_registry.clone();
let peer_rpc_endpoints = self.peer_rpc_endpoints.clone();
let client_resp_receivers = self.client_resp_receivers.clone();
tokio::spawn(async move {
loop {
let Ok(o) = tspt.recv().await else {
tracing::warn!("peer rpc transport read aborted, exiting");
break;
};
let packet = Packet::decode(&o);
let packet: Packet = packet.deserialize(&mut rkyv::Infallible).unwrap();
let info = Self::parse_rpc_packet(&packet).unwrap();
if info.is_req {
if !service_registry.contains_key(&info.service_id) {
log::warn!(
"service {} not found, my_node_id: {}",
info.service_id,
tspt.my_peer_id()
);
continue;
}
let endpoint = peer_rpc_endpoints
.entry((info.from_peer, info.service_id))
.or_insert_with(|| {
service_registry.get(&info.service_id).unwrap()(info.from_peer)
});
endpoint.packet_sender.send(packet).unwrap();
} else {
if let Some(a) = client_resp_receivers.get(&PeerRpcClientCtxKey(
info.from_peer,
info.service_id,
info.transact_id,
)) {
log::trace!("recv resp: {:?}", packet);
if let Err(e) = a.send(packet) {
tracing::error!(error = ?e, "send resp to client failed");
}
} else {
log::warn!("client resp receiver not found, info: {:?}", info);
}
}
}
});
}
#[tracing::instrument(skip(f))]
pub async fn do_client_rpc_scoped<CM, Req, RpcRet, Fut>(
&self,
service_id: PeerRpcServiceId,
dst_peer_id: PeerId,
f: impl FnOnce(UnboundedChannel<CM, Req>) -> Fut,
) -> RpcRet
where
CM: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
Req: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
Fut: std::future::Future<Output = RpcRet>,
{
let mut tasks = JoinSet::new();
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
let (client_transport, server_transport) =
tarpc::transport::channel::unbounded::<CM, Req>();
let (mut server_s, mut server_r) = server_transport.split();
let transact_id = self
.transact_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let tspt = self.tspt.clone();
tasks.spawn(async move {
while let Some(a) = server_r.next().await {
if a.is_err() {
tracing::error!(error = ?a.err(), "channel error");
continue;
}
let a = postcard::to_allocvec(&a.unwrap());
if a.is_err() {
tracing::error!(error = ?a.err(), "bincode serialize failed");
continue;
}
let a = Packet::new_tarpc_packet(
tspt.my_peer_id(),
dst_peer_id,
service_id,
transact_id,
true,
a.unwrap(),
);
if let Err(e) = tspt.send(a.into(), dst_peer_id).await {
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
}
}
tracing::warn!("[PEER RPC MGR] server trasport read aborted");
});
tasks.spawn(async move {
while let Some(packet) = packet_receiver.recv().await {
tracing::trace!("tunnel recv: {:?}", packet);
let info = PeerRpcManager::parse_rpc_packet(&packet);
if let Err(e) = info {
tracing::error!(error = ?e, "parse rpc packet failed");
continue;
}
let decoded = postcard::from_bytes(&info.unwrap().content.as_slice());
if let Err(e) = decoded {
tracing::error!(error = ?e, "decode rpc packet failed");
continue;
}
if let Err(e) = server_s.send(decoded.unwrap()).await {
tracing::error!(error = ?e, "send to rpc server channel failed");
}
}
tracing::warn!("[PEER RPC MGR] server packet read aborted");
});
let key = PeerRpcClientCtxKey(dst_peer_id, service_id, transact_id);
let _insert_ret = self
.client_resp_receivers
.insert(key.clone(), packet_sender);
let ret = f(client_transport).await;
self.client_resp_receivers.remove(&key);
ret
}
pub fn my_peer_id(&self) -> PeerId {
self.tspt.my_peer_id()
}
}
#[cfg(test)]
mod tests {
use futures::{SinkExt, StreamExt};
use tokio_util::bytes::Bytes;
use crate::{
common::{error::Error, new_peer_id, PeerId},
peers::{
peer_rpc::PeerRpcManager,
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
},
tunnels::{self, ring_tunnel::create_ring_tunnel_pair},
};
use super::PeerRpcManagerTransport;
#[tarpc::service]
pub trait TestRpcService {
async fn hello(s: String) -> String;
}
#[derive(Clone)]
struct MockService {
prefix: String,
}
#[tarpc::server]
impl TestRpcService for MockService {
async fn hello(self, _: tarpc::context::Context, s: String) -> String {
format!("{} {}", self.prefix, s)
}
}
#[tokio::test]
async fn peer_rpc_basic_test() {
struct MockTransport {
tunnel: Box<dyn tunnels::Tunnel>,
my_peer_id: PeerId,
}
#[async_trait::async_trait]
impl PeerRpcManagerTransport for MockTransport {
fn my_peer_id(&self) -> PeerId {
self.my_peer_id
}
async fn send(&self, msg: Bytes, _dst_peer_id: PeerId) -> Result<(), Error> {
println!("rpc mgr send: {:?}", msg);
self.tunnel.pin_sink().send(msg).await.unwrap();
Ok(())
}
async fn recv(&self) -> Result<Bytes, Error> {
let ret = self.tunnel.pin_stream().next().await.unwrap();
println!("rpc mgr recv: {:?}", ret);
return ret.map(|v| v.freeze()).map_err(|_| Error::Unknown);
}
}
let (ct, st) = create_ring_tunnel_pair();
let server_rpc_mgr = PeerRpcManager::new(MockTransport {
tunnel: st,
my_peer_id: new_peer_id(),
});
server_rpc_mgr.run();
let s = MockService {
prefix: "hello".to_owned(),
};
server_rpc_mgr.run_service(1, s.serve());
let client_rpc_mgr = PeerRpcManager::new(MockTransport {
tunnel: ct,
my_peer_id: new_peer_id(),
});
client_rpc_mgr.run();
let ret = client_rpc_mgr
.do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
ret
})
.await;
println!("ret: {:?}", ret);
assert_eq!(ret.unwrap(), "hello abc");
}
#[tokio::test]
async fn test_rpc_with_peer_manager() {
let peer_mgr_a = create_mock_peer_manager().await;
let peer_mgr_b = create_mock_peer_manager().await;
let peer_mgr_c = create_mock_peer_manager().await;
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
.await
.unwrap();
wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
.await
.unwrap();
assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
assert_eq!(
peer_mgr_a.get_peer_map().list_peers().await[0],
peer_mgr_b.my_peer_id()
);
assert_eq!(peer_mgr_c.get_peer_map().list_peers().await.len(), 1);
assert_eq!(
peer_mgr_c.get_peer_map().list_peers().await[0],
peer_mgr_b.my_peer_id()
);
let s = MockService {
prefix: "hello".to_owned(),
};
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
let ip_list = peer_mgr_a
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
ret
})
.await;
println!("ip_list: {:?}", ip_list);
assert_eq!(ip_list.as_ref().unwrap(), "hello abc");
// call again
let ip_list = peer_mgr_a
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), "abcd".to_owned()).await;
ret
})
.await;
println!("ip_list: {:?}", ip_list);
assert_eq!(ip_list.as_ref().unwrap(), "hello abcd");
let ip_list = peer_mgr_c
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), "bcd".to_owned()).await;
ret
})
.await;
println!("ip_list: {:?}", ip_list);
assert_eq!(ip_list.as_ref().unwrap(), "hello bcd");
}
#[tokio::test]
async fn test_multi_service_with_peer_manager() {
let peer_mgr_a = create_mock_peer_manager().await;
let peer_mgr_b = create_mock_peer_manager().await;
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
.await
.unwrap();
assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
assert_eq!(
peer_mgr_a.get_peer_map().list_peers().await[0],
peer_mgr_b.my_peer_id()
);
let s = MockService {
prefix: "hello_a".to_owned(),
};
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
let b = MockService {
prefix: "hello_b".to_owned(),
};
peer_mgr_b.get_peer_rpc_mgr().run_service(2, b.serve());
let ip_list = peer_mgr_a
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
ret
})
.await;
assert_eq!(ip_list.as_ref().unwrap(), "hello_a abc");
let ip_list = peer_mgr_a
.get_peer_rpc_mgr()
.do_client_rpc_scoped(2, peer_mgr_b.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
ret
})
.await;
assert_eq!(ip_list.as_ref().unwrap(), "hello_b abc");
}
}
+36
View File
@@ -0,0 +1,36 @@
use std::{net::Ipv4Addr, sync::Arc};
use async_trait::async_trait;
use tokio_util::bytes::Bytes;
use crate::common::{error::Error, PeerId};
#[async_trait]
pub trait RouteInterface {
async fn list_peers(&self) -> Vec<PeerId>;
async fn send_route_packet(
&self,
msg: Bytes,
route_id: u8,
dst_peer_id: PeerId,
) -> Result<(), Error>;
fn my_peer_id(&self) -> PeerId;
}
pub type RouteInterfaceBox = Box<dyn RouteInterface + Send + Sync>;
#[async_trait]
#[auto_impl::auto_impl(Box, Arc)]
pub trait Route {
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()>;
async fn close(&self);
async fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId>;
async fn list_routes(&self) -> Vec<crate::rpc::Route>;
async fn get_peer_id_by_ipv4(&self, _ipv4: &Ipv4Addr) -> Option<PeerId> {
None
}
}
pub type ArcRoute = Arc<Box<dyn Route + Send + Sync>>;
+63
View File
@@ -0,0 +1,63 @@
use std::sync::Arc;
use crate::rpc::{
cli::PeerInfo,
peer_manage_rpc_server::PeerManageRpc,
{ListPeerRequest, ListPeerResponse, ListRouteRequest, ListRouteResponse},
};
use tonic::{Request, Response, Status};
use super::peer_manager::PeerManager;
pub struct PeerManagerRpcService {
peer_manager: Arc<PeerManager>,
}
impl PeerManagerRpcService {
pub fn new(peer_manager: Arc<PeerManager>) -> Self {
PeerManagerRpcService { peer_manager }
}
pub async fn list_peers(&self) -> Vec<PeerInfo> {
let peers = self.peer_manager.get_peer_map().list_peers().await;
let mut peer_infos = Vec::new();
for peer in peers {
let mut peer_info = PeerInfo::default();
peer_info.peer_id = peer;
if let Some(conns) = self.peer_manager.get_peer_map().list_peer_conns(peer).await {
peer_info.conns = conns;
}
peer_infos.push(peer_info);
}
peer_infos
}
}
#[tonic::async_trait]
impl PeerManageRpc for PeerManagerRpcService {
async fn list_peer(
&self,
_request: Request<ListPeerRequest>, // Accept request of type HelloRequest
) -> Result<Response<ListPeerResponse>, Status> {
let mut reply = ListPeerResponse::default();
let peers = self.list_peers().await;
for peer in peers {
reply.peer_infos.push(peer);
}
Ok(Response::new(reply))
}
async fn list_route(
&self,
_request: Request<ListRouteRequest>, // Accept request of type HelloRequest
) -> Result<Response<ListRouteResponse>, Status> {
let mut reply = ListRouteResponse::default();
reply.routes = self.peer_manager.list_routes().await;
Ok(Response::new(reply))
}
}
+75
View File
@@ -0,0 +1,75 @@
use std::sync::Arc;
use futures::Future;
use crate::{
common::{error::Error, global_ctx::tests::get_mock_global_ctx, PeerId},
tunnels::ring_tunnel::create_ring_tunnel_pair,
};
use super::peer_manager::{PeerManager, RouteAlgoType};
pub async fn create_mock_peer_manager() -> Arc<PeerManager> {
let (s, _r) = tokio::sync::mpsc::channel(1000);
let peer_mgr = Arc::new(PeerManager::new(
RouteAlgoType::Ospf,
get_mock_global_ctx(),
s,
));
peer_mgr.run().await.unwrap();
peer_mgr
}
pub async fn connect_peer_manager(client: Arc<PeerManager>, server: Arc<PeerManager>) {
let (a_ring, b_ring) = create_ring_tunnel_pair();
let a_mgr_copy = client.clone();
tokio::spawn(async move {
a_mgr_copy.add_client_tunnel(a_ring).await.unwrap();
});
let b_mgr_copy = server.clone();
tokio::spawn(async move {
b_mgr_copy.add_tunnel_as_server(b_ring).await.unwrap();
});
}
pub async fn wait_route_appear_with_cost(
peer_mgr: Arc<PeerManager>,
node_id: PeerId,
cost: Option<i32>,
) -> Result<(), Error> {
let now = std::time::Instant::now();
while now.elapsed().as_secs() < 5 {
let route = peer_mgr.list_routes().await;
if route
.iter()
.any(|r| r.peer_id == node_id && (cost.is_none() || r.cost == cost.unwrap()))
{
return Ok(());
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
return Err(Error::NotFound);
}
pub async fn wait_route_appear(
peer_mgr: Arc<PeerManager>,
target_peer: Arc<PeerManager>,
) -> Result<(), Error> {
wait_route_appear_with_cost(peer_mgr.clone(), target_peer.my_peer_id(), None).await?;
wait_route_appear_with_cost(target_peer, peer_mgr.my_peer_id(), None).await
}
pub async fn wait_for_condition<F, FRet>(mut condition: F, timeout: std::time::Duration) -> ()
where
F: FnMut() -> FRet + Send,
FRet: Future<Output = bool>,
{
let now = std::time::Instant::now();
while now.elapsed() < timeout {
if condition().await {
return;
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
assert!(condition().await, "Timeout")
}