use std::sync::Arc; use crossbeam::atomic::AtomicCell; use dashmap::{DashMap, DashSet}; use parking_lot::RwLock; use tokio::{select, sync::mpsc}; use tracing::Instrument; use super::{ peer_conn::{PeerConn, PeerConnId}, PacketRecvChan, }; use crate::{ common::{ error::Error, global_ctx::{ArcGlobalCtx, GlobalCtxEvent}, PeerId, }, proto::peer_rpc::PeerIdentityType, tunnel::packet_def::ZCPacket, }; use crate::{ common::{scoped_task::ScopedTask, shrink_dashmap}, proto::api::instance::PeerConnInfo, }; type ArcPeerConn = Arc; type ConnMap = Arc>; pub struct Peer { pub peer_node_id: PeerId, conns: ConnMap, global_ctx: ArcGlobalCtx, packet_recv_chan: PacketRecvChan, close_event_sender: mpsc::Sender, close_event_listener: ScopedTask<()>, shutdown_notifier: Arc, default_conn_id: Arc>, peer_identity_type: Arc>>, peer_public_key: Arc>>>, default_conn_id_clear_task: ScopedTask<()>, } impl Peer { pub fn new( peer_node_id: PeerId, packet_recv_chan: PacketRecvChan, global_ctx: ArcGlobalCtx, ) -> Self { let conns: ConnMap = Arc::new(DashMap::new()); let (close_event_sender, mut close_event_receiver) = mpsc::channel(10); let shutdown_notifier = Arc::new(tokio::sync::Notify::new()); let peer_identity_type = Arc::new(AtomicCell::new(None)); let peer_identity_type_copy = peer_identity_type.clone(); let peer_public_key = Arc::new(RwLock::new(None)); let peer_public_key_copy = peer_public_key.clone(); let conns_copy = conns.clone(); let shutdown_notifier_copy = shutdown_notifier.clone(); let global_ctx_copy = global_ctx.clone(); let close_event_listener = tokio::spawn( async move { loop { select! { ret = close_event_receiver.recv() => { if ret.is_none() { break; } let ret = ret.unwrap(); tracing::warn!( ?peer_node_id, ?ret, "notified that peer conn is closed", ); if let Some((_, conn)) = conns_copy.remove(&ret) { global_ctx_copy.issue_event(GlobalCtxEvent::PeerConnRemoved( conn.get_conn_info(), )); shrink_dashmap(&conns_copy, Some(4)); if conns_copy.is_empty() { peer_identity_type_copy.store(None); *peer_public_key_copy.write() = None; } } } _ = shutdown_notifier_copy.notified() => { close_event_receiver.close(); tracing::warn!(?peer_node_id, "peer close event listener notified"); } } } tracing::info!("peer {} close event listener exit", peer_node_id); } .instrument(tracing::info_span!( "peer_close_event_listener", ?peer_node_id, )), ) .into(); let default_conn_id = Arc::new(AtomicCell::new(PeerConnId::default())); let conns_copy = conns.clone(); let default_conn_id_copy = default_conn_id.clone(); let default_conn_id_clear_task = ScopedTask::from(tokio::spawn(async move { loop { tokio::time::sleep(std::time::Duration::from_secs(5)).await; if conns_copy.len() > 1 { default_conn_id_copy.store(PeerConnId::default()); } } })); Peer { peer_node_id, conns, packet_recv_chan, global_ctx, close_event_sender, close_event_listener, shutdown_notifier, default_conn_id, peer_identity_type, peer_public_key, default_conn_id_clear_task, } } pub async fn add_peer_conn(&self, mut conn: PeerConn) -> Result<(), Error> { let conn_identity_type = conn.get_peer_identity_type(); let peer_identity_type = self.peer_identity_type.load(); if let Some(peer_identity_type) = peer_identity_type { if peer_identity_type != conn_identity_type { return Err(Error::SecretKeyError(format!( "peer identity type mismatch. peer: {:?}, conn: {:?}", peer_identity_type, conn_identity_type ))); } } else { self.peer_identity_type.store(Some(conn_identity_type)); } let close_notifier = conn.get_close_notifier(); let conn_info = conn.get_conn_info(); let conn_pubkey = conn_info.noise_remote_static_pubkey.clone(); { let mut peer_pubkey = self.peer_public_key.write(); if let Some(existing_pubkey) = peer_pubkey.as_ref() { if existing_pubkey != &conn_pubkey { return Err(Error::SecretKeyError(format!( "peer public key mismatch. peer_id: {}, existing_len: {}, new_len: {}", self.peer_node_id, existing_pubkey.len(), conn_pubkey.len() ))); } } else { *peer_pubkey = Some(conn_pubkey); } } conn.start_recv_loop(self.packet_recv_chan.clone()).await; conn.start_pingpong(); self.conns.insert(conn.get_conn_id(), Arc::new(conn)); let close_event_sender = self.close_event_sender.clone(); tokio::spawn(async move { let conn_id = close_notifier.get_conn_id(); if let Some(mut waiter) = close_notifier.get_waiter().await { let _ = waiter.recv().await; } if let Err(e) = close_event_sender.send(conn_id).await { tracing::warn!(?conn_id, "failed to send close event: {}", e); } }); self.global_ctx .issue_event(GlobalCtxEvent::PeerConnAdded(conn_info)); Ok(()) } async fn select_conn(&self) -> Option { let default_conn_id = self.default_conn_id.load(); if let Some(conn) = self.conns.get(&default_conn_id) { return Some(conn.clone()); } // find a conn with the smallest latency let mut min_latency = u64::MAX; for conn in self.conns.iter() { let latency = conn.value().get_stats().latency_us; if latency < min_latency { min_latency = latency; self.default_conn_id.store(conn.get_conn_id()); } } self.conns .get(&self.default_conn_id.load()) .map(|conn| conn.clone()) } pub async fn send_msg(&self, msg: ZCPacket) -> Result<(), Error> { let Some(conn) = self.select_conn().await else { return Err(Error::PeerNoConnectionError(self.peer_node_id)); }; conn.send_msg(msg).await?; Ok(()) } pub async fn close_peer_conn(&self, conn_id: &PeerConnId) -> Result<(), Error> { let has_key = self.conns.contains_key(conn_id); if !has_key { return Err(Error::NotFound); } self.close_event_sender.send(*conn_id).await.unwrap(); Ok(()) } pub async fn list_peer_conns(&self) -> Vec { let mut conns = vec![]; for conn in self.conns.iter() { // do not lock here, otherwise it will cause dashmap deadlock conns.push(conn.clone()); } let mut ret = Vec::new(); for conn in conns { let info = conn.get_conn_info(); if !info.is_closed { ret.push(info); } else { let conn_id = info.conn_id.parse().unwrap(); let _ = self.close_peer_conn(&conn_id).await; } } ret } pub fn has_live_conns(&self) -> bool { self.conns.iter().any(|entry| !entry.value().is_closed()) } pub fn has_directly_connected_conn(&self) -> bool { self.conns .iter() .any(|entry| !entry.value().is_closed() && !entry.value().is_hole_punched()) } pub fn get_directly_connections(&self) -> DashSet { self.conns .iter() .filter(|entry| !(entry.value()).is_hole_punched()) .map(|entry| (entry.value()).get_conn_id()) .collect() } pub fn get_default_conn_id(&self) -> PeerConnId { self.default_conn_id.load() } pub fn get_peer_identity_type(&self) -> Option { self.peer_identity_type.load() } pub fn get_peer_public_key(&self) -> Option> { self.peer_public_key.read().clone() } } // pritn on drop impl Drop for Peer { fn drop(&mut self) { self.conns.retain(|_, conn| { self.global_ctx .issue_event(GlobalCtxEvent::PeerConnRemoved(conn.get_conn_info())); false }); self.shutdown_notifier.notify_one(); tracing::info!("peer {} drop", self.peer_node_id); } } #[cfg(test)] mod tests { use base64::prelude::{Engine as _, BASE64_STANDARD}; use rand::rngs::OsRng; use std::sync::Arc; use tokio::time::timeout; use crate::{ common::{ config::{NetworkIdentity, PeerConfig}, global_ctx::{tests::get_mock_global_ctx, GlobalCtx}, new_peer_id, }, peers::{create_packet_recv_chan, peer_conn::PeerConn, peer_session::PeerSessionStore}, proto::common::SecureModeConfig, tunnel::ring::create_ring_tunnel_pair, }; use super::Peer; fn set_secure_mode_cfg(global_ctx: &GlobalCtx, enabled: bool) { if !enabled { global_ctx.config.set_secure_mode(None); } else { let private = x25519_dalek::StaticSecret::random_from_rng(OsRng); let public = x25519_dalek::PublicKey::from(&private); global_ctx.config.set_secure_mode(Some(SecureModeConfig { enabled: true, local_private_key: Some(BASE64_STANDARD.encode(private.as_bytes())), local_public_key: Some(BASE64_STANDARD.encode(public.as_bytes())), })); } } #[tokio::test] async fn close_peer() { let (local_packet_send, _local_packet_recv) = create_packet_recv_chan(); let (remote_packet_send, _remote_packet_recv) = create_packet_recv_chan(); let global_ctx = get_mock_global_ctx(); let local_peer = Peer::new(new_peer_id(), local_packet_send, global_ctx.clone()); let remote_peer = Peer::new(new_peer_id(), remote_packet_send, global_ctx.clone()); let ps = Arc::new(PeerSessionStore::new()); let (local_tunnel, remote_tunnel) = create_ring_tunnel_pair(); let mut local_peer_conn = PeerConn::new( local_peer.peer_node_id, global_ctx.clone(), local_tunnel, ps.clone(), ); let mut remote_peer_conn = PeerConn::new( remote_peer.peer_node_id, global_ctx.clone(), remote_tunnel, ps.clone(), ); assert!(!local_peer_conn.handshake_done()); assert!(!remote_peer_conn.handshake_done()); let (a, b) = tokio::join!( local_peer_conn.do_handshake_as_client(), remote_peer_conn.do_handshake_as_server() ); a.unwrap(); b.unwrap(); let local_conn_id = local_peer_conn.get_conn_id(); local_peer.add_peer_conn(local_peer_conn).await.unwrap(); remote_peer.add_peer_conn(remote_peer_conn).await.unwrap(); assert_eq!(local_peer.list_peer_conns().await.len(), 1); assert_eq!(remote_peer.list_peer_conns().await.len(), 1); let close_handler = tokio::spawn(async move { local_peer.close_peer_conn(&local_conn_id).await }); // wait for remote peer conn close timeout(std::time::Duration::from_secs(5), async { while !remote_peer.list_peer_conns().await.is_empty() { tokio::time::sleep(std::time::Duration::from_millis(100)).await; } }) .await .unwrap(); println!("wait for close handler"); close_handler.await.unwrap().unwrap(); } #[tokio::test] async fn reject_peer_conn_with_mismatched_identity_type() { let (packet_send, _packet_recv) = create_packet_recv_chan(); let global_ctx = get_mock_global_ctx(); let local_peer_id = new_peer_id(); let remote_peer_id = new_peer_id(); let peer = Peer::new(remote_peer_id, packet_send, global_ctx); let ps = Arc::new(PeerSessionStore::new()); let (shared_client_tunnel, shared_server_tunnel) = create_ring_tunnel_pair(); let shared_client_ctx = get_mock_global_ctx(); let shared_server_ctx = get_mock_global_ctx(); shared_client_ctx .config .set_network_identity(NetworkIdentity::new("net1".to_string(), "sec2".to_string())); shared_server_ctx .config .set_network_identity(NetworkIdentity { network_name: "net2".to_string(), network_secret: None, network_secret_digest: None, }); set_secure_mode_cfg(&shared_client_ctx, true); set_secure_mode_cfg(&shared_server_ctx, true); let remote_url: url::Url = shared_client_tunnel .info() .unwrap() .remote_addr .unwrap() .url .parse() .unwrap(); shared_client_ctx.config.set_peers(vec![PeerConfig { uri: remote_url, peer_public_key: Some( shared_server_ctx .config .get_secure_mode() .unwrap() .local_public_key .unwrap(), ), }]); let mut shared_client_conn = PeerConn::new( local_peer_id, shared_client_ctx, Box::new(shared_client_tunnel), ps.clone(), ); let mut shared_server_conn = PeerConn::new( remote_peer_id, shared_server_ctx, Box::new(shared_server_tunnel), ps.clone(), ); let (c1, s1) = tokio::join!( shared_client_conn.do_handshake_as_client(), shared_server_conn.do_handshake_as_server() ); c1.unwrap(); s1.unwrap(); assert_eq!( shared_client_conn.get_peer_identity_type(), crate::proto::peer_rpc::PeerIdentityType::SharedNode ); let (admin_client_tunnel, admin_server_tunnel) = create_ring_tunnel_pair(); let admin_client_ctx = get_mock_global_ctx(); let admin_server_ctx = get_mock_global_ctx(); admin_client_ctx .config .set_network_identity(NetworkIdentity::new("net1".to_string(), "sec2".to_string())); admin_server_ctx .config .set_network_identity(NetworkIdentity::new("net1".to_string(), "sec2".to_string())); set_secure_mode_cfg(&admin_client_ctx, true); set_secure_mode_cfg(&admin_server_ctx, true); let mut admin_client_conn = PeerConn::new( local_peer_id, admin_client_ctx, Box::new(admin_client_tunnel), Arc::new(PeerSessionStore::new()), ); let mut admin_server_conn = PeerConn::new( remote_peer_id, admin_server_ctx, Box::new(admin_server_tunnel), Arc::new(PeerSessionStore::new()), ); let (c2, s2) = tokio::join!( admin_client_conn.do_handshake_as_client(), admin_server_conn.do_handshake_as_server() ); c2.unwrap(); s2.unwrap(); assert_eq!( admin_client_conn.get_peer_identity_type(), crate::proto::peer_rpc::PeerIdentityType::Admin ); peer.add_peer_conn(shared_client_conn).await.unwrap(); let ret = peer.add_peer_conn(admin_client_conn).await; assert!(ret.is_err()); } #[tokio::test] async fn reject_peer_conn_with_mismatched_public_key() { let (packet_send, _packet_recv) = create_packet_recv_chan(); let local_peer_id = new_peer_id(); let remote_peer_id = new_peer_id(); let peer = Peer::new(remote_peer_id, packet_send, get_mock_global_ctx()); let ps = Arc::new(PeerSessionStore::new()); let (client_tunnel_1, server_tunnel_1) = create_ring_tunnel_pair(); let client_ctx_1 = get_mock_global_ctx(); let server_ctx_1 = get_mock_global_ctx(); client_ctx_1 .config .set_network_identity(NetworkIdentity::new("net1".to_string(), "sec1".to_string())); server_ctx_1 .config .set_network_identity(NetworkIdentity::new("net1".to_string(), "sec1".to_string())); set_secure_mode_cfg(&client_ctx_1, true); set_secure_mode_cfg(&server_ctx_1, true); let mut client_conn_1 = PeerConn::new( local_peer_id, client_ctx_1, Box::new(client_tunnel_1), ps.clone(), ); let mut server_conn_1 = PeerConn::new( remote_peer_id, server_ctx_1, Box::new(server_tunnel_1), ps.clone(), ); let (c1, s1) = tokio::join!( client_conn_1.do_handshake_as_client(), server_conn_1.do_handshake_as_server() ); c1.unwrap(); s1.unwrap(); let (client_tunnel_2, server_tunnel_2) = create_ring_tunnel_pair(); let client_ctx_2 = get_mock_global_ctx(); let server_ctx_2 = get_mock_global_ctx(); client_ctx_2 .config .set_network_identity(NetworkIdentity::new("net1".to_string(), "sec1".to_string())); server_ctx_2 .config .set_network_identity(NetworkIdentity::new("net1".to_string(), "sec1".to_string())); set_secure_mode_cfg(&client_ctx_2, true); set_secure_mode_cfg(&server_ctx_2, true); let mut client_conn_2 = PeerConn::new( local_peer_id, client_ctx_2, Box::new(client_tunnel_2), Arc::new(PeerSessionStore::new()), ); let mut server_conn_2 = PeerConn::new( remote_peer_id, server_ctx_2, Box::new(server_tunnel_2), Arc::new(PeerSessionStore::new()), ); let (c2, s2) = tokio::join!( client_conn_2.do_handshake_as_client(), server_conn_2.do_handshake_as_server() ); c2.unwrap(); s2.unwrap(); let pubkey_1 = client_conn_1.get_conn_info().noise_remote_static_pubkey; let pubkey_2 = client_conn_2.get_conn_info().noise_remote_static_pubkey; assert_ne!(pubkey_1, pubkey_2); peer.add_peer_conn(client_conn_1).await.unwrap(); assert_eq!(peer.get_peer_public_key(), Some(pubkey_1)); let ret = peer.add_peer_conn(client_conn_2).await; assert!(ret.is_err()); } }