mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-14 01:45:46 +00:00
209 lines
6.5 KiB
Rust
209 lines
6.5 KiB
Rust
use std::sync::{Arc, Mutex, atomic::AtomicBool};
|
|
|
|
use futures::{SinkExt as _, StreamExt};
|
|
use guarden::defer;
|
|
use tokio::{task::JoinSet, time::timeout};
|
|
|
|
use crate::{
|
|
proto::rpc_types::error::Error,
|
|
tunnel::{Tunnel, packet_def::PacketType, ring::create_ring_tunnel_pair},
|
|
};
|
|
|
|
use super::{client::Client, server::Server, service_registry::ServiceRegistry};
|
|
use crate::common::stats_manager::StatsManager;
|
|
|
|
/// Runs a paired RPC [`Client`] and [`Server`] over a single tunnel,
/// forwarding packets in both directions via background tasks.
pub struct BidirectRpcManager {
    // Local endpoint issuing outgoing RPC requests; peer responses are
    // routed back to it by packet type.
    rpc_client: Client,
    // Local endpoint serving RPC requests received from the peer.
    rpc_server: Server,

    // Optional timeout applied to each read from the tunnel; on expiry the
    // inbound task records the error and exits.
    rx_timeout: Option<std::time::Duration>,
    // First fatal transport error observed by the forwarding tasks;
    // consumed via `take_error`.
    error: Arc<Mutex<Option<Error>>>,
    // Tunnel object kept alive while the manager runs (its split halves are
    // moved into the forwarding tasks).
    tunnel: Mutex<Option<Box<dyn Tunnel>>>,
    // Set true when the forwarding tasks start, cleared when either exits.
    running: Arc<AtomicBool>,

    // Forwarding tasks; taken (and then awaited/aborted) by `stop`/`wait`.
    tasks: Mutex<Option<JoinSet<()>>>,
}
|
|
|
|
impl Default for BidirectRpcManager {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl BidirectRpcManager {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
rpc_client: Client::new(),
|
|
rpc_server: Server::new(),
|
|
|
|
rx_timeout: None,
|
|
error: Arc::new(Mutex::new(None)),
|
|
tunnel: Mutex::new(None),
|
|
running: Arc::new(AtomicBool::new(false)),
|
|
|
|
tasks: Mutex::new(None),
|
|
}
|
|
}
|
|
|
|
pub fn new_with_stats_manager(stats_manager: Arc<StatsManager>) -> Self {
|
|
Self {
|
|
rpc_client: Client::new_with_stats_manager(stats_manager.clone()),
|
|
rpc_server: Server::new_with_registry_and_stats_manager(
|
|
Arc::new(ServiceRegistry::new()),
|
|
stats_manager,
|
|
),
|
|
|
|
rx_timeout: None,
|
|
error: Arc::new(Mutex::new(None)),
|
|
tunnel: Mutex::new(None),
|
|
running: Arc::new(AtomicBool::new(false)),
|
|
|
|
tasks: Mutex::new(None),
|
|
}
|
|
}
|
|
|
|
pub fn set_rx_timeout(mut self, timeout: Option<std::time::Duration>) -> Self {
|
|
self.rx_timeout = timeout;
|
|
self
|
|
}
|
|
|
|
pub fn run_and_create_tunnel(&self) -> Box<dyn Tunnel> {
|
|
let (ret, inner) = create_ring_tunnel_pair();
|
|
self.run_with_tunnel(inner);
|
|
ret
|
|
}
|
|
|
|
pub fn run_with_tunnel(&self, inner: Box<dyn Tunnel>) {
|
|
let mut tasks = JoinSet::new();
|
|
self.rpc_client.run();
|
|
self.rpc_server.run();
|
|
self.running
|
|
.store(true, std::sync::atomic::Ordering::Relaxed);
|
|
|
|
let (server_tx, mut server_rx) = (
|
|
self.rpc_server.get_transport_sink(),
|
|
self.rpc_server.get_transport_stream(),
|
|
);
|
|
let (client_tx, mut client_rx) = (
|
|
self.rpc_client.get_transport_sink(),
|
|
self.rpc_client.get_transport_stream(),
|
|
);
|
|
|
|
let (mut inner_rx, mut inner_tx) = inner.split();
|
|
self.tunnel.lock().unwrap().replace(inner);
|
|
|
|
let e_clone = self.error.clone();
|
|
let r_clone = self.running.clone();
|
|
tasks.spawn(async move {
|
|
defer! {
|
|
r_clone.store(false, std::sync::atomic::Ordering::Relaxed);
|
|
}
|
|
loop {
|
|
let packet = tokio::select! {
|
|
Some(Ok(packet)) = server_rx.next() => {
|
|
tracing::trace!(?packet, "recv rpc packet from server");
|
|
packet
|
|
}
|
|
Some(Ok(packet)) = client_rx.next() => {
|
|
tracing::trace!(?packet, "recv rpc packet from client");
|
|
packet
|
|
}
|
|
else => {
|
|
tracing::warn!("rpc transport read aborted, exiting");
|
|
break;
|
|
}
|
|
};
|
|
|
|
if let Err(e) = inner_tx.send(packet).await {
|
|
tracing::error!(error = ?e, "send to peer failed");
|
|
e_clone.lock().unwrap().replace(Error::from(e));
|
|
}
|
|
}
|
|
});
|
|
|
|
let recv_timeout = self.rx_timeout;
|
|
let e_clone = self.error.clone();
|
|
let r_clone = self.running.clone();
|
|
tasks.spawn(async move {
|
|
defer! {
|
|
r_clone.store(false, std::sync::atomic::Ordering::Relaxed);
|
|
}
|
|
loop {
|
|
let ret = if let Some(recv_timeout) = recv_timeout {
|
|
match timeout(recv_timeout, inner_rx.next()).await {
|
|
Ok(ret) => ret,
|
|
Err(e) => {
|
|
e_clone.lock().unwrap().replace(e.into());
|
|
break;
|
|
}
|
|
}
|
|
} else {
|
|
inner_rx.next().await
|
|
};
|
|
|
|
let o = match ret {
|
|
Some(Ok(o)) => o,
|
|
Some(Err(e)) => {
|
|
tracing::error!(error = ?e, "recv from peer failed");
|
|
e_clone.lock().unwrap().replace(Error::from(e));
|
|
break;
|
|
}
|
|
None => {
|
|
tracing::warn!("peer rpc transport read aborted, exiting");
|
|
e_clone.lock().unwrap().replace(Error::Shutdown);
|
|
break;
|
|
}
|
|
};
|
|
|
|
let Some(peer_manager_header) = o.peer_manager_header() else {
|
|
tracing::error!("peer manager header not found");
|
|
continue;
|
|
};
|
|
if peer_manager_header.packet_type == PacketType::RpcReq as u8 {
|
|
server_tx.send(o).await.unwrap();
|
|
continue;
|
|
} else if peer_manager_header.packet_type == PacketType::RpcResp as u8 {
|
|
client_tx.send(o).await.unwrap();
|
|
continue;
|
|
}
|
|
}
|
|
});
|
|
|
|
self.tasks.lock().unwrap().replace(tasks);
|
|
}
|
|
|
|
/// Borrow the RPC client endpoint for issuing outgoing requests.
pub fn rpc_client(&self) -> &Client {
    &self.rpc_client
}
|
|
|
|
/// Borrow the RPC server endpoint (e.g. to register services).
pub fn rpc_server(&self) -> &Server {
    &self.rpc_server
}
|
|
|
|
pub async fn stop(&self) {
|
|
let Some(mut tasks) = self.tasks.lock().unwrap().take() else {
|
|
return;
|
|
};
|
|
tasks.abort_all();
|
|
while tasks.join_next().await.is_some() {}
|
|
}
|
|
|
|
/// Removes and returns the first transport error recorded by the
/// forwarding tasks, if any; subsequent calls return `None`.
pub fn take_error(&self) -> Option<Error> {
    let mut guard = self.error.lock().unwrap();
    guard.take()
}
|
|
|
|
pub async fn wait(&self) {
|
|
let Some(mut tasks) = self.tasks.lock().unwrap().take() else {
|
|
return;
|
|
};
|
|
while tasks.join_next().await.is_some() {
|
|
// when any task is done, abort all tasks
|
|
tasks.abort_all();
|
|
}
|
|
}
|
|
|
|
pub fn is_running(&self) -> bool {
|
|
self.running.load(std::sync::atomic::Ordering::Relaxed)
|
|
}
|
|
}
|