use std::marker::PhantomData; use std::pin::Pin; use std::sync::{Arc, Mutex}; use bytes::Bytes; use dashmap::DashMap; use guarden::defer; use prost::Message; use tokio::sync::mpsc; use tokio::task::JoinSet; use tokio::time::timeout; use tokio_stream::StreamExt; use crate::common::{ PeerId, shrink_dashmap, stats_manager::{LabelSet, LabelType, MetricName, StatsManager}, }; use crate::proto::common::{ CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse, }; use crate::proto::rpc_impl::packet::{ BuildRpcPacketArgs, build_rpc_packet, compress_packet, decompress_packet, }; use crate::proto::rpc_types::controller::Controller; use crate::proto::rpc_types::descriptor::MethodDescriptor; use crate::proto::rpc_types::{ __rt::RpcClientFactory, descriptor::ServiceDescriptor, handler::Handler, }; use crate::proto::rpc_types::error::Result; use crate::tunnel::mpsc::{MpscTunnel, MpscTunnelSender}; use crate::tunnel::packet_def::ZCPacket; use crate::tunnel::ring::create_ring_tunnel_pair; use crate::tunnel::{Tunnel, TunnelError, ZCPacketStream}; use super::packet::PacketMerger; use super::{RpcTransactId, Transport}; static CUR_TID: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| atomic_shim::AtomicI64::new(rand::random())); type RpcPacketSender = mpsc::UnboundedSender; type RpcPacketReceiver = mpsc::UnboundedReceiver; #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct InflightRequestKey { from_peer_id: PeerId, to_peer_id: PeerId, transaction_id: RpcTransactId, } struct InflightRequest { sender: RpcPacketSender, merger: PacketMerger, start_time: std::time::Instant, } impl std::fmt::Debug for InflightRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("InflightRequest") .field("sender", &self.sender) .field("start_time", &self.start_time) .finish() } } #[derive(Debug, Clone, Default)] pub struct PeerInfo { pub peer_id: PeerId, pub compression_info: RpcCompressionInfo, pub last_active: Option, } type InflightRequestTable = Arc>; pub type PeerInfoTable = Arc>; pub struct Client { mpsc: Mutex>>, transport: Mutex, inflight_requests: InflightRequestTable, peer_info: PeerInfoTable, tasks: Mutex>, stats_manager: Option>, } impl Default for Client { fn default() -> Self { Self::new() } } impl Client { pub fn new() -> Self { let (ring_a, ring_b) = create_ring_tunnel_pair(); Self { mpsc: Mutex::new(MpscTunnel::new(ring_a, None)), transport: Mutex::new(MpscTunnel::new(ring_b, None)), inflight_requests: Arc::new(DashMap::new()), peer_info: Arc::new(DashMap::new()), tasks: Mutex::new(JoinSet::new()), stats_manager: None, } } pub fn new_with_stats_manager(stats_manager: Arc) -> Self { let mut ret = Self::new(); ret.stats_manager = Some(stats_manager); ret } pub fn get_transport_sink(&self) -> MpscTunnelSender { self.transport.lock().unwrap().get_sink() } pub fn get_transport_stream(&self) -> Pin> { self.transport.lock().unwrap().get_stream() } pub fn run(&self) { let mut tasks = self.tasks.lock().unwrap(); let peer_infos = self.peer_info.clone(); tasks.spawn(async move { loop { tokio::time::sleep(std::time::Duration::from_secs(30)).await; let now = std::time::Instant::now(); peer_infos.retain(|_, v| { if let Some(last_active) = v.last_active { return now.duration_since(last_active) < std::time::Duration::from_secs(120); } true }); peer_infos.shrink_to_fit(); } }); let mut rx = self.mpsc.lock().unwrap().get_stream(); let inflight_requests = self.inflight_requests.clone(); tasks.spawn(async move { while let Some(packet) = rx.next().await { if let Err(err) = packet { tracing::error!(?err, "Failed to receive packet"); continue; } let packet = match RpcPacket::decode(packet.unwrap().payload()) { Err(err) => { tracing::error!(?err, "Failed to decode packet"); continue; } Ok(packet) => packet, }; if packet.is_request { tracing::warn!(?packet, "Received non-response packet"); continue; } let key = InflightRequestKey { from_peer_id: packet.to_peer, to_peer_id: packet.from_peer, transaction_id: packet.transaction_id, }; let Some(mut inflight_request) = inflight_requests.get_mut(&key) else { tracing::warn!( ?key, ?inflight_requests, "No inflight request found for key" ); continue; }; tracing::trace!(?packet, "Received response packet"); let ret = inflight_request.merger.feed(packet); match ret { Ok(Some(rpc_packet)) => { inflight_request.sender.send(rpc_packet).unwrap(); } Ok(None) => {} Err(err) => { tracing::error!(?err, "Failed to feed packet to merger"); } } } }); } pub fn scoped_client( &self, from_peer_id: PeerId, to_peer_id: PeerId, domain_name: String, ) -> F::ClientImpl { #[derive(Clone)] struct HandlerImpl { domain_name: String, from_peer_id: PeerId, to_peer_id: PeerId, zc_packet_sender: MpscTunnelSender, inflight_requests: InflightRequestTable, peer_info: PeerInfoTable, stats_manager: Option>, _phan: PhantomData, } impl HandlerImpl { async fn do_rpc( &self, packets: Vec, rx: &mut RpcPacketReceiver, ) -> Result { for packet in packets { self.zc_packet_sender.send(packet).await?; } Ok(rx.recv().await.ok_or(TunnelError::Shutdown)?) } } #[async_trait::async_trait] impl Handler for HandlerImpl { type Descriptor = F::Descriptor; type Controller = F::Controller; async fn call( &self, mut ctrl: Self::Controller, method: ::Method, input: bytes::Bytes, ) -> Result { let start_time = std::time::Instant::now(); let transaction_id = CUR_TID.fetch_add(1, std::sync::atomic::Ordering::Relaxed); let (tx, mut rx) = mpsc::unbounded_channel(); let key = InflightRequestKey { from_peer_id: self.from_peer_id, to_peer_id: self.to_peer_id, transaction_id, }; let desc = self.service_descriptor(); let labels = LabelSet::new() .with_label_type(LabelType::NetworkName(self.domain_name.to_string())) .with_label_type(LabelType::SrcPeerId(self.from_peer_id)) .with_label_type(LabelType::DstPeerId(self.to_peer_id)) .with_label_type(LabelType::ServiceName(desc.name().to_string())) .with_label_type(LabelType::MethodName(method.name().to_string())); defer!(self.inflight_requests.remove(&key); shrink_dashmap(&self.inflight_requests, Some(4));); self.inflight_requests.insert( key.clone(), InflightRequest { sender: tx, merger: PacketMerger::new(), start_time, }, ); // Record RPC client TX stats if let Some(ref stats_manager) = self.stats_manager { stats_manager .get_counter(MetricName::PeerRpcClientTx, labels.clone()) .inc(); } let rpc_desc = RpcDescriptor { domain_name: self.domain_name.clone(), proto_name: desc.proto_name().to_string(), service_name: desc.name().to_string(), method_index: method.index() as u32, }; let rpc_req = RpcRequest { request: if let Some(raw_input) = ctrl.get_raw_input() { raw_input.into() } else { input.into() }, timeout_ms: ctrl.timeout_ms(), ..Default::default() }; let peer_info = self .peer_info .get(&self.to_peer_id) .map(|v| v.clone()) .unwrap_or_default(); let (buf, c_algo) = compress_packet( peer_info.compression_info.accepted_algo(), &rpc_req.encode_to_vec(), ) .await .unwrap(); let packets = build_rpc_packet(BuildRpcPacketArgs { from_peer: self.from_peer_id, to_peer: self.to_peer_id, rpc_desc, transaction_id, is_req: true, content: &buf, trace_id: ctrl.trace_id(), compression_info: RpcCompressionInfo { algo: c_algo.into(), accepted_algo: CompressionAlgoPb::Zstd.into(), }, }); let timeout_dur = std::time::Duration::from_millis(ctrl.timeout_ms() as u64); let mut rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??; if let Some(compression_info) = rpc_packet.compression_info { self.peer_info.insert( self.to_peer_id, PeerInfo { peer_id: self.to_peer_id, compression_info, last_active: Some(std::time::Instant::now()), }, ); rpc_packet.body = decompress_packet(compression_info.algo(), &rpc_packet.body).await?; } assert_eq!(rpc_packet.transaction_id, transaction_id); let rpc_resp = RpcResponse::decode(Bytes::from(rpc_packet.body))?; if let Some(err) = &rpc_resp.error { // Record RPC error stats if let Some(ref stats_manager) = self.stats_manager { let labels = labels .clone() .with_label_type(LabelType::ErrorType(format!("{:?}", err.error_kind))) .with_label_type(LabelType::Status("error".to_string())); stats_manager .get_counter(MetricName::PeerRpcErrors, labels.clone()) .inc(); let duration_ms = start_time.elapsed().as_millis() as u64; stats_manager .get_counter(MetricName::PeerRpcDuration, labels) .add(duration_ms); } return Err(err.into()); } let raw_output = Bytes::from(rpc_resp.response); ctrl.set_raw_output(raw_output.clone()); // Record RPC client RX and duration stats if let Some(ref stats_manager) = self.stats_manager { let labels = labels .clone() .with_label_type(LabelType::Status("success".to_string())); stats_manager .get_counter(MetricName::PeerRpcClientRx, labels.clone()) .inc(); let duration_ms = start_time.elapsed().as_millis() as u64; stats_manager .get_counter(MetricName::PeerRpcDuration, labels) .add(duration_ms); } Ok(raw_output) } } F::new(HandlerImpl:: { domain_name, from_peer_id, to_peer_id, zc_packet_sender: self.mpsc.lock().unwrap().get_sink(), inflight_requests: self.inflight_requests.clone(), peer_info: self.peer_info.clone(), stats_manager: self.stats_manager.clone(), _phan: PhantomData, }) } pub fn inflight_count(&self) -> usize { self.inflight_requests.len() } pub fn peer_info_table(&self) -> PeerInfoTable { self.peer_info.clone() } }