// Mirror of https://github.com/EasyTier/EasyTier.git
// (synced 2026-05-15 10:25:40 +00:00; 392 lines, 14 KiB, Rust)
use std::marker::PhantomData;
|
|
use std::pin::Pin;
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
use bytes::Bytes;
|
|
use dashmap::DashMap;
|
|
use guarden::defer;
|
|
use prost::Message;
|
|
use tokio::sync::mpsc;
|
|
use tokio::task::JoinSet;
|
|
use tokio::time::timeout;
|
|
use tokio_stream::StreamExt;
|
|
|
|
use crate::common::{
|
|
PeerId, shrink_dashmap,
|
|
stats_manager::{LabelSet, LabelType, MetricName, StatsManager},
|
|
};
|
|
use crate::proto::common::{
|
|
CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse,
|
|
};
|
|
use crate::proto::rpc_impl::packet::{
|
|
BuildRpcPacketArgs, build_rpc_packet, compress_packet, decompress_packet,
|
|
};
|
|
use crate::proto::rpc_types::controller::Controller;
|
|
use crate::proto::rpc_types::descriptor::MethodDescriptor;
|
|
use crate::proto::rpc_types::{
|
|
__rt::RpcClientFactory, descriptor::ServiceDescriptor, handler::Handler,
|
|
};
|
|
|
|
use crate::proto::rpc_types::error::Result;
|
|
use crate::tunnel::mpsc::{MpscTunnel, MpscTunnelSender};
|
|
use crate::tunnel::packet_def::ZCPacket;
|
|
use crate::tunnel::ring::create_ring_tunnel_pair;
|
|
use crate::tunnel::{Tunnel, TunnelError, ZCPacketStream};
|
|
|
|
use super::packet::PacketMerger;
|
|
use super::{RpcTransactId, Transport};
|
|
|
|
/// Process-wide counter from which every outgoing RPC takes its transaction id.
///
/// The counter starts at a random value on first use and is advanced by one per call
/// (`fetch_add(1, ..)` in the handler created by `Client::scoped_client`).
/// NOTE(review): the random seed presumably reduces the chance of reusing ids from a
/// previous process run — confirm.
static CUR_TID: once_cell::sync::Lazy<atomic_shim::AtomicI64> =
    once_cell::sync::Lazy::new(|| {
        let seed: i64 = rand::random();
        atomic_shim::AtomicI64::new(seed)
    });
|
|
|
|
type RpcPacketSender = mpsc::UnboundedSender<RpcPacket>;
|
|
type RpcPacketReceiver = mpsc::UnboundedReceiver<RpcPacket>;
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
|
struct InflightRequestKey {
|
|
from_peer_id: PeerId,
|
|
to_peer_id: PeerId,
|
|
transaction_id: RpcTransactId,
|
|
}
|
|
|
|
struct InflightRequest {
|
|
sender: RpcPacketSender,
|
|
merger: PacketMerger,
|
|
start_time: std::time::Instant,
|
|
}
|
|
|
|
impl std::fmt::Debug for InflightRequest {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("InflightRequest")
|
|
.field("sender", &self.sender)
|
|
.field("start_time", &self.start_time)
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Default)]
|
|
pub struct PeerInfo {
|
|
pub peer_id: PeerId,
|
|
pub compression_info: RpcCompressionInfo,
|
|
pub last_active: Option<std::time::Instant>,
|
|
}
|
|
|
|
type InflightRequestTable = Arc<DashMap<InflightRequestKey, InflightRequest>>;
|
|
pub type PeerInfoTable = Arc<DashMap<PeerId, PeerInfo>>;
|
|
|
|
pub struct Client {
|
|
mpsc: Mutex<MpscTunnel<Box<dyn Tunnel>>>,
|
|
transport: Mutex<Transport>,
|
|
inflight_requests: InflightRequestTable,
|
|
peer_info: PeerInfoTable,
|
|
tasks: Mutex<JoinSet<()>>,
|
|
stats_manager: Option<Arc<StatsManager>>,
|
|
}
|
|
|
|
impl Default for Client {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl Client {
|
|
pub fn new() -> Self {
|
|
let (ring_a, ring_b) = create_ring_tunnel_pair();
|
|
Self {
|
|
mpsc: Mutex::new(MpscTunnel::new(ring_a, None)),
|
|
transport: Mutex::new(MpscTunnel::new(ring_b, None)),
|
|
inflight_requests: Arc::new(DashMap::new()),
|
|
peer_info: Arc::new(DashMap::new()),
|
|
tasks: Mutex::new(JoinSet::new()),
|
|
stats_manager: None,
|
|
}
|
|
}
|
|
|
|
pub fn new_with_stats_manager(stats_manager: Arc<StatsManager>) -> Self {
|
|
let mut ret = Self::new();
|
|
ret.stats_manager = Some(stats_manager);
|
|
ret
|
|
}
|
|
|
|
pub fn get_transport_sink(&self) -> MpscTunnelSender {
|
|
self.transport.lock().unwrap().get_sink()
|
|
}
|
|
|
|
pub fn get_transport_stream(&self) -> Pin<Box<dyn ZCPacketStream>> {
|
|
self.transport.lock().unwrap().get_stream()
|
|
}
|
|
|
|
pub fn run(&self) {
|
|
let mut tasks = self.tasks.lock().unwrap();
|
|
|
|
let peer_infos = self.peer_info.clone();
|
|
tasks.spawn(async move {
|
|
loop {
|
|
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
|
|
let now = std::time::Instant::now();
|
|
peer_infos.retain(|_, v| {
|
|
if let Some(last_active) = v.last_active {
|
|
return now.duration_since(last_active)
|
|
< std::time::Duration::from_secs(120);
|
|
}
|
|
true
|
|
});
|
|
peer_infos.shrink_to_fit();
|
|
}
|
|
});
|
|
|
|
let mut rx = self.mpsc.lock().unwrap().get_stream();
|
|
let inflight_requests = self.inflight_requests.clone();
|
|
tasks.spawn(async move {
|
|
while let Some(packet) = rx.next().await {
|
|
if let Err(err) = packet {
|
|
tracing::error!(?err, "Failed to receive packet");
|
|
continue;
|
|
}
|
|
let packet = match RpcPacket::decode(packet.unwrap().payload()) {
|
|
Err(err) => {
|
|
tracing::error!(?err, "Failed to decode packet");
|
|
continue;
|
|
}
|
|
Ok(packet) => packet,
|
|
};
|
|
|
|
if packet.is_request {
|
|
tracing::warn!(?packet, "Received non-response packet");
|
|
continue;
|
|
}
|
|
|
|
let key = InflightRequestKey {
|
|
from_peer_id: packet.to_peer,
|
|
to_peer_id: packet.from_peer,
|
|
transaction_id: packet.transaction_id,
|
|
};
|
|
|
|
let Some(mut inflight_request) = inflight_requests.get_mut(&key) else {
|
|
tracing::warn!(
|
|
?key,
|
|
?inflight_requests,
|
|
"No inflight request found for key"
|
|
);
|
|
continue;
|
|
};
|
|
|
|
tracing::trace!(?packet, "Received response packet");
|
|
|
|
let ret = inflight_request.merger.feed(packet);
|
|
match ret {
|
|
Ok(Some(rpc_packet)) => {
|
|
inflight_request.sender.send(rpc_packet).unwrap();
|
|
}
|
|
Ok(None) => {}
|
|
Err(err) => {
|
|
tracing::error!(?err, "Failed to feed packet to merger");
|
|
}
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
pub fn scoped_client<F: RpcClientFactory>(
|
|
&self,
|
|
from_peer_id: PeerId,
|
|
to_peer_id: PeerId,
|
|
domain_name: String,
|
|
) -> F::ClientImpl {
|
|
#[derive(Clone)]
|
|
struct HandlerImpl<F> {
|
|
domain_name: String,
|
|
from_peer_id: PeerId,
|
|
to_peer_id: PeerId,
|
|
zc_packet_sender: MpscTunnelSender,
|
|
inflight_requests: InflightRequestTable,
|
|
peer_info: PeerInfoTable,
|
|
stats_manager: Option<Arc<StatsManager>>,
|
|
_phan: PhantomData<F>,
|
|
}
|
|
|
|
impl<F: RpcClientFactory> HandlerImpl<F> {
|
|
async fn do_rpc(
|
|
&self,
|
|
packets: Vec<ZCPacket>,
|
|
rx: &mut RpcPacketReceiver,
|
|
) -> Result<RpcPacket> {
|
|
for packet in packets {
|
|
self.zc_packet_sender.send(packet).await?;
|
|
}
|
|
|
|
Ok(rx.recv().await.ok_or(TunnelError::Shutdown)?)
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl<F: RpcClientFactory> Handler for HandlerImpl<F> {
|
|
type Descriptor = F::Descriptor;
|
|
type Controller = F::Controller;
|
|
|
|
async fn call(
|
|
&self,
|
|
mut ctrl: Self::Controller,
|
|
method: <Self::Descriptor as ServiceDescriptor>::Method,
|
|
input: bytes::Bytes,
|
|
) -> Result<bytes::Bytes> {
|
|
let start_time = std::time::Instant::now();
|
|
let transaction_id = CUR_TID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
|
let (tx, mut rx) = mpsc::unbounded_channel();
|
|
let key = InflightRequestKey {
|
|
from_peer_id: self.from_peer_id,
|
|
to_peer_id: self.to_peer_id,
|
|
transaction_id,
|
|
};
|
|
let desc = self.service_descriptor();
|
|
let labels = LabelSet::new()
|
|
.with_label_type(LabelType::NetworkName(self.domain_name.to_string()))
|
|
.with_label_type(LabelType::SrcPeerId(self.from_peer_id))
|
|
.with_label_type(LabelType::DstPeerId(self.to_peer_id))
|
|
.with_label_type(LabelType::ServiceName(desc.name().to_string()))
|
|
.with_label_type(LabelType::MethodName(method.name().to_string()));
|
|
|
|
defer!(self.inflight_requests.remove(&key); shrink_dashmap(&self.inflight_requests, Some(4)););
|
|
self.inflight_requests.insert(
|
|
key.clone(),
|
|
InflightRequest {
|
|
sender: tx,
|
|
merger: PacketMerger::new(),
|
|
start_time,
|
|
},
|
|
);
|
|
|
|
// Record RPC client TX stats
|
|
if let Some(ref stats_manager) = self.stats_manager {
|
|
stats_manager
|
|
.get_counter(MetricName::PeerRpcClientTx, labels.clone())
|
|
.inc();
|
|
}
|
|
|
|
let rpc_desc = RpcDescriptor {
|
|
domain_name: self.domain_name.clone(),
|
|
proto_name: desc.proto_name().to_string(),
|
|
service_name: desc.name().to_string(),
|
|
method_index: method.index() as u32,
|
|
};
|
|
|
|
let rpc_req = RpcRequest {
|
|
request: if let Some(raw_input) = ctrl.get_raw_input() {
|
|
raw_input.into()
|
|
} else {
|
|
input.into()
|
|
},
|
|
timeout_ms: ctrl.timeout_ms(),
|
|
..Default::default()
|
|
};
|
|
|
|
let peer_info = self
|
|
.peer_info
|
|
.get(&self.to_peer_id)
|
|
.map(|v| v.clone())
|
|
.unwrap_or_default();
|
|
let (buf, c_algo) = compress_packet(
|
|
peer_info.compression_info.accepted_algo(),
|
|
&rpc_req.encode_to_vec(),
|
|
)
|
|
.await
|
|
.unwrap();
|
|
|
|
let packets = build_rpc_packet(BuildRpcPacketArgs {
|
|
from_peer: self.from_peer_id,
|
|
to_peer: self.to_peer_id,
|
|
rpc_desc,
|
|
transaction_id,
|
|
is_req: true,
|
|
content: &buf,
|
|
trace_id: ctrl.trace_id(),
|
|
compression_info: RpcCompressionInfo {
|
|
algo: c_algo.into(),
|
|
accepted_algo: CompressionAlgoPb::Zstd.into(),
|
|
},
|
|
});
|
|
let timeout_dur = std::time::Duration::from_millis(ctrl.timeout_ms() as u64);
|
|
let mut rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??;
|
|
|
|
if let Some(compression_info) = rpc_packet.compression_info {
|
|
self.peer_info.insert(
|
|
self.to_peer_id,
|
|
PeerInfo {
|
|
peer_id: self.to_peer_id,
|
|
compression_info,
|
|
last_active: Some(std::time::Instant::now()),
|
|
},
|
|
);
|
|
|
|
rpc_packet.body =
|
|
decompress_packet(compression_info.algo(), &rpc_packet.body).await?;
|
|
}
|
|
|
|
assert_eq!(rpc_packet.transaction_id, transaction_id);
|
|
|
|
let rpc_resp = RpcResponse::decode(Bytes::from(rpc_packet.body))?;
|
|
|
|
if let Some(err) = &rpc_resp.error {
|
|
// Record RPC error stats
|
|
if let Some(ref stats_manager) = self.stats_manager {
|
|
let labels = labels
|
|
.clone()
|
|
.with_label_type(LabelType::ErrorType(format!("{:?}", err.error_kind)))
|
|
.with_label_type(LabelType::Status("error".to_string()));
|
|
|
|
stats_manager
|
|
.get_counter(MetricName::PeerRpcErrors, labels.clone())
|
|
.inc();
|
|
|
|
let duration_ms = start_time.elapsed().as_millis() as u64;
|
|
stats_manager
|
|
.get_counter(MetricName::PeerRpcDuration, labels)
|
|
.add(duration_ms);
|
|
}
|
|
return Err(err.into());
|
|
}
|
|
|
|
let raw_output = Bytes::from(rpc_resp.response);
|
|
ctrl.set_raw_output(raw_output.clone());
|
|
|
|
// Record RPC client RX and duration stats
|
|
if let Some(ref stats_manager) = self.stats_manager {
|
|
let labels = labels
|
|
.clone()
|
|
.with_label_type(LabelType::Status("success".to_string()));
|
|
|
|
stats_manager
|
|
.get_counter(MetricName::PeerRpcClientRx, labels.clone())
|
|
.inc();
|
|
|
|
let duration_ms = start_time.elapsed().as_millis() as u64;
|
|
stats_manager
|
|
.get_counter(MetricName::PeerRpcDuration, labels)
|
|
.add(duration_ms);
|
|
}
|
|
|
|
Ok(raw_output)
|
|
}
|
|
}
|
|
|
|
F::new(HandlerImpl::<F> {
|
|
domain_name,
|
|
from_peer_id,
|
|
to_peer_id,
|
|
zc_packet_sender: self.mpsc.lock().unwrap().get_sink(),
|
|
inflight_requests: self.inflight_requests.clone(),
|
|
peer_info: self.peer_info.clone(),
|
|
stats_manager: self.stats_manager.clone(),
|
|
_phan: PhantomData,
|
|
})
|
|
}
|
|
|
|
pub fn inflight_count(&self) -> usize {
|
|
self.inflight_requests.len()
|
|
}
|
|
|
|
pub fn peer_info_table(&self) -> PeerInfoTable {
|
|
self.peer_info.clone()
|
|
}
|
|
}
|