mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-15 10:25:40 +00:00
use customized rpc implementation, remove Tarpc & Tonic (#348)
This patch removes Tarpc and Tonic gRPC, implementing a customized RPC framework that can be used by the peer RPC and CLI interfaces; the web config server can also use this RPC framework. Moreover, the public-server logic is rewritten to use OSPF routing for public-server-based networking, which makes a public server mesh possible.
This commit is contained in:
@@ -0,0 +1,240 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use prost::Message;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::timeout;
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
use crate::common::PeerId;
|
||||
use crate::defer;
|
||||
use crate::proto::common::{RpcDescriptor, RpcPacket, RpcRequest, RpcResponse};
|
||||
use crate::proto::rpc_impl::packet::build_rpc_packet;
|
||||
use crate::proto::rpc_types::controller::Controller;
|
||||
use crate::proto::rpc_types::descriptor::MethodDescriptor;
|
||||
use crate::proto::rpc_types::{
|
||||
__rt::RpcClientFactory, descriptor::ServiceDescriptor, handler::Handler,
|
||||
};
|
||||
|
||||
use crate::proto::rpc_types::error::Result;
|
||||
use crate::tunnel::mpsc::{MpscTunnel, MpscTunnelSender};
|
||||
use crate::tunnel::packet_def::ZCPacket;
|
||||
use crate::tunnel::ring::create_ring_tunnel_pair;
|
||||
use crate::tunnel::{Tunnel, TunnelError, ZCPacketStream};
|
||||
|
||||
use super::packet::PacketMerger;
|
||||
use super::{RpcTransactId, Transport};
|
||||
|
||||
// Process-wide transaction-id counter shared by all clients; seeded with a
// random value so ids from different runs are unlikely to collide.
// `atomic_shim` supplies AtomicI64 on targets without native 64-bit atomics.
static CUR_TID: once_cell::sync::Lazy<atomic_shim::AtomicI64> =
    once_cell::sync::Lazy::new(|| atomic_shim::AtomicI64::new(rand::random()));

// Channel endpoints used to hand a fully reassembled response packet back to
// the caller that is awaiting it.
type RpcPacketSender = mpsc::UnboundedSender<RpcPacket>;
type RpcPacketReceiver = mpsc::UnboundedReceiver<RpcPacket>;

// Uniquely identifies one outstanding RPC by its (src, dst, transaction) triple.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct InflightRequestKey {
    from_peer_id: PeerId,
    to_peer_id: PeerId,
    transaction_id: RpcTransactId,
}

// Book-keeping for one outstanding RPC.
struct InflightRequest {
    sender: RpcPacketSender,        // delivers the merged response to the caller
    merger: PacketMerger,           // reassembles multi-piece responses
    start_time: std::time::Instant, // when the request was issued
}

// Shared between the client and its background receive task.
type InflightRequestTable = Arc<DashMap<InflightRequestKey, InflightRequest>>;
|
||||
|
||||
/// Client side of the customized RPC framework. Requests are serialized into
/// `ZCPacket`s and exchanged with the transport layer through an in-memory
/// ring tunnel pair.
pub struct Client {
    // Tunnel end consumed internally: requests go out and responses come in.
    mpsc: Mutex<MpscTunnel<Box<dyn Tunnel>>>,
    // Opposite tunnel end, exposed to the transport layer via
    // `get_transport_sink` / `get_transport_stream`.
    transport: Mutex<Transport>,
    // Requests currently awaiting their responses.
    inflight_requests: InflightRequestTable,
    // Background tasks spawned by `run()`.
    tasks: Arc<Mutex<JoinSet<()>>>,
}
|
||||
|
||||
impl Client {
|
||||
pub fn new() -> Self {
|
||||
let (ring_a, ring_b) = create_ring_tunnel_pair();
|
||||
Self {
|
||||
mpsc: Mutex::new(MpscTunnel::new(ring_a)),
|
||||
transport: Mutex::new(MpscTunnel::new(ring_b)),
|
||||
inflight_requests: Arc::new(DashMap::new()),
|
||||
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_transport_sink(&self) -> MpscTunnelSender {
|
||||
self.transport.lock().unwrap().get_sink()
|
||||
}
|
||||
|
||||
pub fn get_transport_stream(&self) -> Pin<Box<dyn ZCPacketStream>> {
|
||||
self.transport.lock().unwrap().get_stream()
|
||||
}
|
||||
|
||||
pub fn run(&self) {
|
||||
let mut tasks = self.tasks.lock().unwrap();
|
||||
|
||||
let mut rx = self.mpsc.lock().unwrap().get_stream();
|
||||
let inflight_requests = self.inflight_requests.clone();
|
||||
tasks.spawn(async move {
|
||||
while let Some(packet) = rx.next().await {
|
||||
if let Err(err) = packet {
|
||||
tracing::error!(?err, "Failed to receive packet");
|
||||
continue;
|
||||
}
|
||||
let packet = match RpcPacket::decode(packet.unwrap().payload()) {
|
||||
Err(err) => {
|
||||
tracing::error!(?err, "Failed to decode packet");
|
||||
continue;
|
||||
}
|
||||
Ok(packet) => packet,
|
||||
};
|
||||
|
||||
if packet.is_request {
|
||||
tracing::warn!(?packet, "Received non-response packet");
|
||||
continue;
|
||||
}
|
||||
|
||||
let key = InflightRequestKey {
|
||||
from_peer_id: packet.to_peer,
|
||||
to_peer_id: packet.from_peer,
|
||||
transaction_id: packet.transaction_id,
|
||||
};
|
||||
|
||||
let Some(mut inflight_request) = inflight_requests.get_mut(&key) else {
|
||||
tracing::warn!(?key, "No inflight request found for key");
|
||||
continue;
|
||||
};
|
||||
|
||||
let ret = inflight_request.merger.feed(packet);
|
||||
match ret {
|
||||
Ok(Some(rpc_packet)) => {
|
||||
inflight_request.sender.send(rpc_packet).unwrap();
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(err) => {
|
||||
tracing::error!(?err, "Failed to feed packet to merger");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn scoped_client<F: RpcClientFactory>(
|
||||
&self,
|
||||
from_peer_id: PeerId,
|
||||
to_peer_id: PeerId,
|
||||
domain_name: String,
|
||||
) -> F::ClientImpl {
|
||||
#[derive(Clone)]
|
||||
struct HandlerImpl<F> {
|
||||
domain_name: String,
|
||||
from_peer_id: PeerId,
|
||||
to_peer_id: PeerId,
|
||||
zc_packet_sender: MpscTunnelSender,
|
||||
inflight_requests: InflightRequestTable,
|
||||
_phan: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RpcClientFactory> HandlerImpl<F> {
|
||||
async fn do_rpc(
|
||||
&self,
|
||||
packets: Vec<ZCPacket>,
|
||||
rx: &mut RpcPacketReceiver,
|
||||
) -> Result<RpcPacket> {
|
||||
for packet in packets {
|
||||
self.zc_packet_sender.send(packet).await?;
|
||||
}
|
||||
|
||||
Ok(rx.recv().await.ok_or(TunnelError::Shutdown)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<F: RpcClientFactory> Handler for HandlerImpl<F> {
|
||||
type Descriptor = F::Descriptor;
|
||||
type Controller = F::Controller;
|
||||
|
||||
async fn call(
|
||||
&self,
|
||||
ctrl: Self::Controller,
|
||||
method: <Self::Descriptor as ServiceDescriptor>::Method,
|
||||
input: bytes::Bytes,
|
||||
) -> Result<bytes::Bytes> {
|
||||
let transaction_id = CUR_TID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
let key = InflightRequestKey {
|
||||
from_peer_id: self.from_peer_id,
|
||||
to_peer_id: self.to_peer_id,
|
||||
transaction_id,
|
||||
};
|
||||
|
||||
defer!(self.inflight_requests.remove(&key););
|
||||
self.inflight_requests.insert(
|
||||
key.clone(),
|
||||
InflightRequest {
|
||||
sender: tx,
|
||||
merger: PacketMerger::new(),
|
||||
start_time: std::time::Instant::now(),
|
||||
},
|
||||
);
|
||||
|
||||
let desc = self.service_descriptor();
|
||||
|
||||
let rpc_desc = RpcDescriptor {
|
||||
domain_name: self.domain_name.clone(),
|
||||
proto_name: desc.proto_name().to_string(),
|
||||
service_name: desc.name().to_string(),
|
||||
method_index: method.index() as u32,
|
||||
};
|
||||
|
||||
let rpc_req = RpcRequest {
|
||||
descriptor: Some(rpc_desc.clone()),
|
||||
request: input.into(),
|
||||
timeout_ms: ctrl.timeout_ms(),
|
||||
};
|
||||
|
||||
let packets = build_rpc_packet(
|
||||
self.from_peer_id,
|
||||
self.to_peer_id,
|
||||
rpc_desc,
|
||||
transaction_id,
|
||||
true,
|
||||
&rpc_req.encode_to_vec(),
|
||||
ctrl.trace_id(),
|
||||
);
|
||||
|
||||
let timeout_dur = std::time::Duration::from_millis(ctrl.timeout_ms() as u64);
|
||||
let rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??;
|
||||
|
||||
assert_eq!(rpc_packet.transaction_id, transaction_id);
|
||||
|
||||
let rpc_resp = RpcResponse::decode(Bytes::from(rpc_packet.body))?;
|
||||
|
||||
if let Some(err) = &rpc_resp.error {
|
||||
return Err(err.into());
|
||||
}
|
||||
|
||||
Ok(bytes::Bytes::from(rpc_resp.response))
|
||||
}
|
||||
}
|
||||
|
||||
F::new(HandlerImpl::<F> {
|
||||
domain_name: domain_name.to_string(),
|
||||
from_peer_id,
|
||||
to_peer_id,
|
||||
zc_packet_sender: self.mpsc.lock().unwrap().get_sink(),
|
||||
inflight_requests: self.inflight_requests.clone(),
|
||||
_phan: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn inflight_count(&self) -> usize {
|
||||
self.inflight_requests.len()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
use crate::tunnel::{mpsc::MpscTunnel, Tunnel};
|
||||
|
||||
/// Controller type used by all services in this RPC implementation.
pub type RpcController = super::rpc_types::controller::BaseController;

pub mod client;
pub mod packet;
pub mod server;
pub mod service_registry;
pub mod standalone;

/// Bidirectional in-memory transport shared by the RPC client and server.
pub type Transport = MpscTunnel<Box<dyn Tunnel>>;
/// Identifier correlating a request with its response packets.
pub type RpcTransactId = i64;
|
||||
@@ -0,0 +1,161 @@
|
||||
use prost::Message as _;
|
||||
|
||||
use crate::{
|
||||
common::PeerId,
|
||||
proto::{
|
||||
common::{RpcDescriptor, RpcPacket},
|
||||
rpc_types::error::Error,
|
||||
},
|
||||
tunnel::packet_def::{PacketType, ZCPacket},
|
||||
};
|
||||
|
||||
use super::RpcTransactId;
|
||||
|
||||
// Maximum payload bytes carried by one RPC packet piece; larger messages are
// split by `build_rpc_packet` and reassembled by `PacketMerger`.
const RPC_PACKET_CONTENT_MTU: usize = 1300;

/// Reassembles a multi-piece RPC message from individually received pieces.
pub struct PacketMerger {
    first_piece: Option<RpcPacket>,   // template carrying descriptor/ids of the current message
    pieces: Vec<RpcPacket>,           // one slot per piece, indexed by piece_idx
    last_updated: std::time::Instant, // for expiring idle mergers
}
|
||||
|
||||
impl PacketMerger {
    /// Create an empty merger with no message in progress.
    pub fn new() -> Self {
        Self {
            first_piece: None,
            pieces: Vec::new(),
            last_updated: std::time::Instant::now(),
        }
    }

    /// Return the reassembled packet if every piece has arrived, else None.
    fn try_merge_pieces(&self) -> Option<RpcPacket> {
        if self.first_piece.is_none() || self.pieces.is_empty() {
            return None;
        }

        for p in &self.pieces {
            // some piece is missing: unfilled slots are `Default`, whose
            // total_pieces is 0
            if p.total_pieces == 0 {
                return None;
            }
        }

        // all pieces are received; concatenate bodies in piece_idx order
        let mut body = Vec::new();
        for p in &self.pieces {
            body.extend_from_slice(&p.body);
        }

        // Reuse the first piece as a template so descriptor/ids are preserved.
        let mut tmpl_packet = self.first_piece.as_ref().unwrap().clone();
        tmpl_packet.total_pieces = 1;
        tmpl_packet.piece_idx = 0;
        tmpl_packet.body = body;

        Some(tmpl_packet)
    }

    /// Feed one received piece. Returns `Ok(Some(..))` once the full message
    /// is assembled, `Ok(None)` while pieces are still missing, and `Err` on
    /// a malformed piece.
    pub fn feed(&mut self, rpc_packet: RpcPacket) -> Result<Option<RpcPacket>, Error> {
        let total_pieces = rpc_packet.total_pieces;
        let piece_idx = rpc_packet.piece_idx;

        if rpc_packet.descriptor.is_none() {
            return Err(Error::MalformatRpcPacket(
                "descriptor is missing".to_owned(),
            ));
        }

        // for compatibility with old version (which never split messages,
        // sending total_pieces == 0)
        if total_pieces == 0 && piece_idx == 0 {
            return Ok(Some(rpc_packet));
        }

        // about 32MB max size (32K pieces at the 1300-byte MTU)
        if total_pieces > 32 * 1024 || total_pieces == 0 {
            return Err(Error::MalformatRpcPacket(format!(
                "total_pieces is invalid: {}",
                total_pieces
            )));
        }

        if piece_idx >= total_pieces {
            return Err(Error::MalformatRpcPacket(
                "piece_idx >= total_pieces".to_owned(),
            ));
        }

        // A piece from a different transaction or sender restarts reassembly.
        if self.first_piece.is_none()
            || self.first_piece.as_ref().unwrap().transaction_id != rpc_packet.transaction_id
            || self.first_piece.as_ref().unwrap().from_peer != rpc_packet.from_peer
        {
            self.first_piece = Some(rpc_packet.clone());
            self.pieces.clear();
        }

        self.pieces
            .resize(total_pieces as usize, Default::default());
        self.pieces[piece_idx as usize] = rpc_packet;

        self.last_updated = std::time::Instant::now();

        Ok(self.try_merge_pieces())
    }

    /// Time of the most recent `feed`, used by owners to expire stale mergers.
    pub fn last_updated(&self) -> std::time::Instant {
        self.last_updated
    }
}
|
||||
|
||||
pub fn build_rpc_packet(
|
||||
from_peer: PeerId,
|
||||
to_peer: PeerId,
|
||||
rpc_desc: RpcDescriptor,
|
||||
transaction_id: RpcTransactId,
|
||||
is_req: bool,
|
||||
content: &Vec<u8>,
|
||||
trace_id: i32,
|
||||
) -> Vec<ZCPacket> {
|
||||
let mut ret = Vec::new();
|
||||
let content_mtu = RPC_PACKET_CONTENT_MTU;
|
||||
let total_pieces = (content.len() + content_mtu - 1) / content_mtu;
|
||||
let mut cur_offset = 0;
|
||||
while cur_offset < content.len() || content.len() == 0 {
|
||||
let mut cur_len = content_mtu;
|
||||
if cur_offset + cur_len > content.len() {
|
||||
cur_len = content.len() - cur_offset;
|
||||
}
|
||||
|
||||
let mut cur_content = Vec::new();
|
||||
cur_content.extend_from_slice(&content[cur_offset..cur_offset + cur_len]);
|
||||
|
||||
let cur_packet = RpcPacket {
|
||||
from_peer,
|
||||
to_peer,
|
||||
descriptor: Some(rpc_desc.clone()),
|
||||
is_request: is_req,
|
||||
total_pieces: total_pieces as u32,
|
||||
piece_idx: (cur_offset / content_mtu) as u32,
|
||||
transaction_id,
|
||||
body: cur_content,
|
||||
trace_id,
|
||||
};
|
||||
cur_offset += cur_len;
|
||||
|
||||
let packet_type = if is_req {
|
||||
PacketType::RpcReq
|
||||
} else {
|
||||
PacketType::RpcResp
|
||||
};
|
||||
|
||||
let mut buf = Vec::new();
|
||||
cur_packet.encode(&mut buf).unwrap();
|
||||
let mut zc_packet = ZCPacket::new_with_payload(&buf);
|
||||
zc_packet.fill_peer_manager_hdr(from_peer, to_peer, packet_type as u8);
|
||||
ret.push(zc_packet);
|
||||
|
||||
if content.len() == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
@@ -0,0 +1,207 @@
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use prost::Message;
|
||||
use tokio::{task::JoinSet, time::timeout};
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
use crate::{
|
||||
common::{join_joinset_background, PeerId},
|
||||
proto::{
|
||||
common::{self, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse},
|
||||
rpc_types::error::Result,
|
||||
},
|
||||
tunnel::{
|
||||
mpsc::{MpscTunnel, MpscTunnelSender},
|
||||
ring::create_ring_tunnel_pair,
|
||||
Tunnel, ZCPacketStream,
|
||||
},
|
||||
};
|
||||
|
||||
use super::{
|
||||
packet::{build_rpc_packet, PacketMerger},
|
||||
service_registry::ServiceRegistry,
|
||||
RpcController, Transport,
|
||||
};
|
||||
|
||||
// Identifies a partially received request currently being reassembled.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct PacketMergerKey {
    from_peer_id: PeerId,
    rpc_desc: RpcDescriptor,
    transaction_id: i64,
}

/// Server side of the customized RPC framework: receives request packets,
/// reassembles them, dispatches to registered services, and sends responses.
pub struct Server {
    // Services this server can dispatch to.
    registry: Arc<ServiceRegistry>,

    // Tunnel end taken (consumed) by `run()`; None afterwards.
    mpsc: Mutex<Option<MpscTunnel<Box<dyn Tunnel>>>>,

    // Opposite tunnel end, exposed via get_transport_sink / get_transport_stream.
    transport: Mutex<Transport>,

    // Background tasks spawned by `run()`.
    tasks: Arc<Mutex<JoinSet<()>>>,
    // In-progress reassembly of multi-piece requests.
    packet_mergers: Arc<DashMap<PacketMergerKey, PacketMerger>>,
}
|
||||
|
||||
impl Server {
    /// Create a server with its own empty service registry.
    pub fn new() -> Self {
        Server::new_with_registry(Arc::new(ServiceRegistry::new()))
    }

    /// Create a server that shares an existing registry (several servers may
    /// dispatch to the same set of services).
    pub fn new_with_registry(registry: Arc<ServiceRegistry>) -> Self {
        let (ring_a, ring_b) = create_ring_tunnel_pair();

        Self {
            registry,
            mpsc: Mutex::new(Some(MpscTunnel::new(ring_a))),
            transport: Mutex::new(MpscTunnel::new(ring_b)),
            tasks: Arc::new(Mutex::new(JoinSet::new())),
            packet_mergers: Arc::new(DashMap::new()),
        }
    }

    /// Registry used to (un)register services.
    pub fn registry(&self) -> &ServiceRegistry {
        &self.registry
    }

    /// Sink used by the transport layer to feed incoming request packets.
    pub fn get_transport_sink(&self) -> MpscTunnelSender {
        self.transport.lock().unwrap().get_sink()
    }

    /// Stream of response packets to be delivered by the transport layer.
    pub fn get_transport_stream(&self) -> Pin<Box<dyn ZCPacketStream>> {
        self.transport.lock().unwrap().get_stream()
    }

    /// Start the receive/dispatch loop plus a janitor task for stale mergers.
    /// NOTE(review): this takes the inner tunnel out of `self.mpsc`, so `run`
    /// must be called at most once — a second call panics on the unwrap.
    pub fn run(&self) {
        let tasks = self.tasks.clone();
        join_joinset_background(tasks.clone(), "rpc server".to_string());

        let mpsc = self.mpsc.lock().unwrap().take().unwrap();

        let packet_merges = self.packet_mergers.clone();
        let reg = self.registry.clone();
        let t = tasks.clone();
        tasks.lock().unwrap().spawn(async move {
            let mut mpsc = mpsc;
            let mut rx = mpsc.get_stream();

            while let Some(packet) = rx.next().await {
                if let Err(err) = packet {
                    tracing::error!(?err, "Failed to receive packet");
                    continue;
                }
                let packet = match common::RpcPacket::decode(packet.unwrap().payload()) {
                    Err(err) => {
                        tracing::error!(?err, "Failed to decode packet");
                        continue;
                    }
                    Ok(packet) => packet,
                };

                if !packet.is_request {
                    tracing::warn!(?packet, "Received non-request packet");
                    continue;
                }

                let key = PacketMergerKey {
                    from_peer_id: packet.from_peer,
                    rpc_desc: packet.descriptor.clone().unwrap_or_default(),
                    transaction_id: packet.transaction_id,
                };

                // Accumulate pieces until a full request is assembled.
                let ret = packet_merges
                    .entry(key.clone())
                    .or_insert_with(PacketMerger::new)
                    .feed(packet);

                match ret {
                    Ok(Some(packet)) => {
                        packet_merges.remove(&key);
                        // Dispatch each complete request on its own task so a
                        // slow handler does not stall the receive loop.
                        t.lock().unwrap().spawn(Self::handle_rpc(
                            mpsc.get_sink(),
                            packet,
                            reg.clone(),
                        ));
                    }
                    Ok(None) => {}
                    Err(err) => {
                        tracing::error!("Failed to feed packet to merger, {}", err.to_string());
                    }
                }
            }
        });

        // Janitor: every 5s, drop mergers idle for 10s or more.
        let packet_mergers = self.packet_mergers.clone();
        tasks.lock().unwrap().spawn(async move {
            loop {
                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
                packet_mergers.retain(|_, v| v.last_updated().elapsed().as_secs() < 10);
            }
        });
    }

    /// Decode a complete request and invoke the target service method, with
    /// the request's own timeout applied to the call.
    async fn handle_rpc_request(packet: RpcPacket, reg: Arc<ServiceRegistry>) -> Result<Bytes> {
        let rpc_request = RpcRequest::decode(Bytes::from(packet.body))?;
        let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64);
        let ctrl = RpcController {};
        Ok(timeout(
            timeout_duration,
            reg.call_method(
                // descriptor presence was validated by PacketMerger::feed
                packet.descriptor.unwrap(),
                ctrl,
                Bytes::from(rpc_request.request),
            ),
        )
        .await??)
    }

    /// Run one request to completion and send the response (or error) back.
    async fn handle_rpc(sender: MpscTunnelSender, packet: RpcPacket, reg: Arc<ServiceRegistry>) {
        let from_peer = packet.from_peer;
        let to_peer = packet.to_peer;
        let transaction_id = packet.transaction_id;
        let trace_id = packet.trace_id;
        let desc = packet.descriptor.clone().unwrap();

        let mut resp_msg = RpcResponse::default();
        let now = std::time::Instant::now();

        let resp_bytes = Self::handle_rpc_request(packet, reg).await;

        match &resp_bytes {
            Ok(r) => {
                resp_msg.response = r.clone().into();
            }
            Err(err) => {
                // Errors are reported to the caller inside the response body.
                resp_msg.error = Some(err.into());
            }
        };
        resp_msg.runtime_us = now.elapsed().as_micros() as u64;

        // The response travels in the reverse direction of the request.
        let packets = build_rpc_packet(
            to_peer,
            from_peer,
            desc,
            transaction_id,
            false,
            &resp_msg.encode_to_vec(),
            trace_id,
        );

        for packet in packets {
            if let Err(err) = sender.send(packet).await {
                tracing::error!(?err, "Failed to send response packet");
            }
        }
    }

    /// Number of requests currently being reassembled.
    pub fn inflight_count(&self) -> usize {
        self.packet_mergers.len()
    }

    /// Close the transport-facing tunnel end.
    pub fn close(&self) {
        self.transport.lock().unwrap().close();
    }
}
|
||||
@@ -0,0 +1,105 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
|
||||
use crate::proto::common::RpcDescriptor;
|
||||
use crate::proto::rpc_types;
|
||||
use crate::proto::rpc_types::descriptor::ServiceDescriptor;
|
||||
use crate::proto::rpc_types::handler::{Handler, HandlerExt};
|
||||
|
||||
use super::RpcController;
|
||||
|
||||
/// Identifies a registered service by its (domain, service, proto) triple.
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub struct ServiceKey {
    pub domain_name: String,
    pub service_name: String,
    pub proto_name: String,
}
|
||||
|
||||
impl From<&RpcDescriptor> for ServiceKey {
|
||||
fn from(desc: &RpcDescriptor) -> Self {
|
||||
Self {
|
||||
domain_name: desc.domain_name.to_string(),
|
||||
service_name: desc.service_name.to_string(),
|
||||
proto_name: desc.proto_name.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ServiceEntry {
|
||||
service: Arc<Box<dyn HandlerExt<Controller = RpcController>>>,
|
||||
}
|
||||
|
||||
impl ServiceEntry {
|
||||
fn new<H: Handler<Controller = RpcController>>(h: H) -> Self {
|
||||
Self {
|
||||
service: Arc::new(Box::new(h)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn call_method(
|
||||
&self,
|
||||
ctrl: RpcController,
|
||||
method_index: u8,
|
||||
input: bytes::Bytes,
|
||||
) -> rpc_types::error::Result<bytes::Bytes> {
|
||||
self.service.call_method(ctrl, method_index, input).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps (domain, service, proto) keys to registered service handlers.
pub struct ServiceRegistry {
    table: DashMap<ServiceKey, ServiceEntry>,
}
|
||||
|
||||
impl ServiceRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
table: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register<H: Handler<Controller = RpcController>>(&self, h: H, domain_name: &str) {
|
||||
let desc = h.service_descriptor();
|
||||
let key = ServiceKey {
|
||||
domain_name: domain_name.to_string(),
|
||||
service_name: desc.name().to_string(),
|
||||
proto_name: desc.proto_name().to_string(),
|
||||
};
|
||||
let entry = ServiceEntry::new(h);
|
||||
self.table.insert(key, entry);
|
||||
}
|
||||
|
||||
pub fn unregister<H: Handler<Controller = RpcController>>(
|
||||
&self,
|
||||
h: H,
|
||||
domain_name: &str,
|
||||
) -> Option<()> {
|
||||
let desc = h.service_descriptor();
|
||||
let key = ServiceKey {
|
||||
domain_name: domain_name.to_string(),
|
||||
service_name: desc.name().to_string(),
|
||||
proto_name: desc.proto_name().to_string(),
|
||||
};
|
||||
self.table.remove(&key).map(|_| ())
|
||||
}
|
||||
|
||||
pub async fn call_method(
|
||||
&self,
|
||||
rpc_desc: RpcDescriptor,
|
||||
ctrl: RpcController,
|
||||
input: bytes::Bytes,
|
||||
) -> rpc_types::error::Result<bytes::Bytes> {
|
||||
let service_key = ServiceKey::from(&rpc_desc);
|
||||
let method_index = rpc_desc.method_index as u8;
|
||||
let entry = self
|
||||
.table
|
||||
.get(&service_key)
|
||||
.ok_or(rpc_types::error::Error::InvalidServiceKey(
|
||||
service_key.service_name.clone(),
|
||||
service_key.proto_name.clone(),
|
||||
))?
|
||||
.clone();
|
||||
entry.call_method(ctrl, method_index, input).await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,245 @@
|
||||
use std::{
|
||||
sync::{atomic::AtomicU32, Arc, Mutex},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures::{SinkExt as _, StreamExt};
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use crate::{
|
||||
common::join_joinset_background,
|
||||
proto::rpc_types::{__rt::RpcClientFactory, error::Error},
|
||||
tunnel::{Tunnel, TunnelConnector, TunnelListener},
|
||||
};
|
||||
|
||||
use super::{client::Client, server::Server, service_registry::ServiceRegistry};
|
||||
|
||||
// Serves RPC requests arriving over a single accepted tunnel, with a
// dedicated `Server` instance bound to the shared registry.
struct StandAloneServerOneTunnel {
    tunnel: Box<dyn Tunnel>,
    rpc_server: Server,
}
|
||||
|
||||
impl StandAloneServerOneTunnel {
    /// Wrap an accepted tunnel with an RPC server sharing `registry`.
    pub fn new(tunnel: Box<dyn Tunnel>, registry: Arc<ServiceRegistry>) -> Self {
        let rpc_server = Server::new_with_registry(registry);
        StandAloneServerOneTunnel { tunnel, rpc_server }
    }

    /// Pump packets between the tunnel and the RPC server until either
    /// direction ends or the tunnel is idle for 60 seconds.
    pub async fn run(self) {
        use tokio_stream::StreamExt as _;

        let (tunnel_rx, tunnel_tx) = self.tunnel.split();
        let (rpc_rx, rpc_tx) = (
            self.rpc_server.get_transport_stream(),
            self.rpc_server.get_transport_sink(),
        );

        let mut tasks = JoinSet::new();

        // Inbound: tunnel -> rpc server, with a 60s idle timeout that ends
        // the loop (and thus the connection) when nothing arrives.
        tasks.spawn(async move {
            let ret = tunnel_rx.timeout(Duration::from_secs(60));
            tokio::pin!(ret);
            while let Ok(Some(Ok(p))) = ret.try_next().await {
                if let Err(e) = rpc_tx.send(p).await {
                    tracing::error!("tunnel_rx send to rpc_tx error: {:?}", e);
                    break;
                }
            }
            tracing::info!("forward tunnel_rx to rpc_tx done");
        });

        // Outbound: rpc server responses -> tunnel.
        tasks.spawn(async move {
            let ret = rpc_rx.forward(tunnel_tx).await;
            tracing::info!("rpc_rx forward tunnel_tx done: {:?}", ret);
        });

        self.rpc_server.run();

        // Close the server as soon as either pump finishes, then drain the
        // remaining task(s).
        while let Some(ret) = tasks.join_next().await {
            self.rpc_server.close();
            tracing::info!("task done: {:?}", ret);
        }

        tracing::info!("all tasks done");
    }
}
|
||||
|
||||
/// Accept-loop RPC server: spawns one `StandAloneServerOneTunnel` per
/// accepted connection, all sharing a single service registry.
pub struct StandAloneServer<L> {
    // Services shared by every accepted connection.
    registry: Arc<ServiceRegistry>,
    // Taken by `serve`; None afterwards.
    listener: Option<L>,
    // Count of currently live connections.
    inflight_server: Arc<AtomicU32>,
    // Background accept/connection tasks.
    tasks: Arc<Mutex<JoinSet<()>>>,
}
|
||||
|
||||
impl<L: TunnelListener + 'static> StandAloneServer<L> {
    /// Create a server that will accept tunnels from `listener`.
    pub fn new(listener: L) -> Self {
        StandAloneServer {
            registry: Arc::new(ServiceRegistry::new()),
            listener: Some(listener),
            inflight_server: Arc::new(AtomicU32::new(0)),
            tasks: Arc::new(Mutex::new(JoinSet::new())),
        }
    }

    /// Registry to populate with services before calling `serve`.
    pub fn registry(&self) -> &ServiceRegistry {
        &self.registry
    }

    /// Start listening and accept connections in background tasks.
    /// NOTE(review): takes the listener out of `self`, so `serve` must be
    /// called at most once — a second call panics on the unwrap.
    pub async fn serve(&mut self) -> Result<(), Error> {
        let tasks = self.tasks.clone();
        let mut listener = self.listener.take().unwrap();
        let registry = self.registry.clone();

        // Reap finished connection tasks in the background.
        join_joinset_background(tasks.clone(), "standalone server tasks".to_string());

        listener
            .listen()
            .await
            .with_context(|| "failed to listen")?;

        let inflight_server = self.inflight_server.clone();

        self.tasks.lock().unwrap().spawn(async move {
            while let Ok(tunnel) = listener.accept().await {
                let server = StandAloneServerOneTunnel::new(tunnel, registry.clone());
                let inflight_server = inflight_server.clone();
                // Track live connections for `inflight_server()`.
                inflight_server.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                tasks.lock().unwrap().spawn(async move {
                    server.run().await;
                    inflight_server.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
                });
            }
            // The accept loop is not expected to end; treat it as fatal.
            panic!("standalone server listener exit");
        });

        Ok(())
    }

    /// Number of currently connected client tunnels.
    pub fn inflight_server(&self) -> u32 {
        self.inflight_server
            .load(std::sync::atomic::Ordering::Relaxed)
    }
}
|
||||
|
||||
// One live connection: an RPC client plus the pump tasks wiring it to a
// tunnel, and the first transport error observed (if any).
struct StandAloneClientOneTunnel {
    rpc_client: Client,
    // Background pump tasks for both directions.
    tasks: Arc<Mutex<JoinSet<()>>>,
    // First error recorded by either pump; read via `take_error`.
    error: Arc<Mutex<Option<Error>>>,
}
|
||||
|
||||
impl StandAloneClientOneTunnel {
    /// Build a client over an established tunnel and spawn two pump tasks
    /// (client -> tunnel, tunnel -> client). The first failure in either
    /// direction is stored for `take_error` to report.
    pub fn new(tunnel: Box<dyn Tunnel>) -> Self {
        let rpc_client = Client::new();
        let (mut rpc_rx, rpc_tx) = (
            rpc_client.get_transport_stream(),
            rpc_client.get_transport_sink(),
        );
        let tasks = Arc::new(Mutex::new(JoinSet::new()));

        let (mut tunnel_rx, mut tunnel_tx) = tunnel.split();

        let error_store = Arc::new(Mutex::new(None));

        // Outgoing direction: client request packets onto the tunnel.
        let error = error_store.clone();
        tasks.lock().unwrap().spawn(async move {
            while let Some(p) = rpc_rx.next().await {
                match p {
                    Ok(p) => {
                        if let Err(e) = tunnel_tx
                            .send(p)
                            .await
                            .with_context(|| "failed to send packet")
                        {
                            *error.lock().unwrap() = Some(e.into());
                        }
                    }
                    Err(e) => {
                        *error.lock().unwrap() = Some(anyhow::Error::from(e).into());
                    }
                }
            }

            // Stream ended: record it so the owner reconnects on next use.
            *error.lock().unwrap() = Some(anyhow::anyhow!("rpc_rx next exit").into());
        });

        // Incoming direction: tunnel response packets into the client.
        let error = error_store.clone();
        tasks.lock().unwrap().spawn(async move {
            while let Some(p) = tunnel_rx.next().await {
                match p {
                    Ok(p) => {
                        if let Err(e) = rpc_tx
                            .send(p)
                            .await
                            .with_context(|| "failed to send packet")
                        {
                            *error.lock().unwrap() = Some(e.into());
                        }
                    }
                    Err(e) => {
                        *error.lock().unwrap() = Some(anyhow::Error::from(e).into());
                    }
                }
            }

            *error.lock().unwrap() = Some(anyhow::anyhow!("tunnel_rx next exit").into());
        });

        rpc_client.run();

        StandAloneClientOneTunnel {
            rpc_client,
            tasks,
            error: error_store,
        }
    }

    /// Take (and clear) the first recorded transport error, if any.
    pub fn take_error(&self) -> Option<Error> {
        self.error.lock().unwrap().take()
    }
}
|
||||
|
||||
/// Client that owns a connector and lazily (re)establishes a single
/// standalone RPC connection on demand.
pub struct StandAloneClient<C: TunnelConnector> {
    connector: C,
    // Current live connection, if any.
    client: Option<StandAloneClientOneTunnel>,
}
|
||||
|
||||
impl<C: TunnelConnector> StandAloneClient<C> {
    /// Create a client that will connect lazily via `connector`.
    pub fn new(connector: C) -> Self {
        StandAloneClient {
            connector,
            client: None,
        }
    }

    // Establish a fresh tunnel to the server.
    async fn connect(&mut self) -> Result<Box<dyn Tunnel>, Error> {
        Ok(self.connector.connect().await.with_context(|| {
            format!(
                "failed to connect to server: {:?}",
                self.connector.remote_url()
            )
        })?)
    }

    /// Get a typed RPC stub for factory `F`, (re)connecting if there is no
    /// live connection or the previous one recorded a transport error.
    pub async fn scoped_client<F: RpcClientFactory>(
        &mut self,
        domain_name: String,
    ) -> Result<F::ClientImpl, Error> {
        let mut c = self.client.take();
        // `take_error` clears the stored error; Some means the old
        // connection is broken and must be replaced.
        let error = c.as_ref().and_then(|c| c.take_error());
        if c.is_none() || error.is_some() {
            tracing::info!("reconnect due to error: {:?}", error);
            let tunnel = self.connect().await?;
            c = Some(StandAloneClientOneTunnel::new(tunnel));
        }

        self.client = c;

        // Peer ids are fixed at (1, 1) here; presumably standalone mode is
        // point-to-point so real peer ids are irrelevant — TODO confirm.
        Ok(self
            .client
            .as_ref()
            .unwrap()
            .rpc_client
            .scoped_client::<F>(1, 1, domain_name))
    }
}
|
||||
Reference in New Issue
Block a user