mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 10:14:35 +00:00
34f832bbad
* remove log crates * remove udp/raw of smoltcp * make tun an optional dependency, compile for freebsd works
486 lines
16 KiB
Rust
486 lines
16 KiB
Rust
use std::{
|
|
any::Any,
|
|
fmt::Debug,
|
|
pin::Pin,
|
|
sync::{
|
|
atomic::{AtomicU32, Ordering},
|
|
Arc,
|
|
},
|
|
};
|
|
|
|
use futures::{SinkExt, StreamExt, TryFutureExt};
|
|
|
|
use prost::Message;
|
|
|
|
use tokio::{
|
|
sync::{broadcast, mpsc, Mutex},
|
|
task::JoinSet,
|
|
time::{timeout, Duration},
|
|
};
|
|
|
|
use tokio_util::sync::PollSender;
|
|
use tracing::Instrument;
|
|
use zerocopy::AsBytes;
|
|
|
|
use crate::{
|
|
common::{
|
|
config::{NetworkIdentity, NetworkSecretDigest},
|
|
error::Error,
|
|
global_ctx::ArcGlobalCtx,
|
|
PeerId,
|
|
},
|
|
rpc::{HandshakeRequest, PeerConnInfo, PeerConnStats, TunnelInfo},
|
|
tunnel::packet_def::PacketType,
|
|
tunnel::{
|
|
filter::{StatsRecorderTunnelFilter, TunnelFilter, TunnelWithFilter},
|
|
mpsc::{MpscTunnel, MpscTunnelSender},
|
|
packet_def::ZCPacket,
|
|
stats::{Throughput, WindowLatency},
|
|
Tunnel, TunnelError, ZCPacketStream,
|
|
},
|
|
};
|
|
|
|
use super::{peer_conn_ping::PeerConnPinger, PacketRecvChan};
|
|
|
|
/// Identifier of a single peer connection; `PeerConn::new` assigns each connection a random
/// (version 4) UUID.
pub type PeerConnId = uuid::Uuid;
|
|
|
|
/// Magic number written into the `magic` field of every outgoing `HandshakeRequest`.
const MAGIC: u32 = 0xD1E1_A5E1;
/// Handshake protocol version written into the `version` field of every outgoing
/// `HandshakeRequest`.
const VERSION: u32 = 1;
|
|
|
|
pub struct PeerConn {
|
|
conn_id: PeerConnId,
|
|
|
|
my_peer_id: PeerId,
|
|
global_ctx: ArcGlobalCtx,
|
|
|
|
tunnel: Arc<Mutex<Box<dyn Any + Send + 'static>>>,
|
|
sink: MpscTunnelSender,
|
|
recv: Arc<Mutex<Option<Pin<Box<dyn ZCPacketStream>>>>>,
|
|
tunnel_info: Option<TunnelInfo>,
|
|
|
|
tasks: JoinSet<Result<(), TunnelError>>,
|
|
|
|
info: Option<HandshakeRequest>,
|
|
|
|
close_event_sender: Option<mpsc::Sender<PeerConnId>>,
|
|
|
|
ctrl_resp_sender: broadcast::Sender<ZCPacket>,
|
|
|
|
latency_stats: Arc<WindowLatency>,
|
|
throughput: Arc<Throughput>,
|
|
loss_rate_stats: Arc<AtomicU32>,
|
|
}
|
|
|
|
impl Debug for PeerConn {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("PeerConn")
|
|
.field("conn_id", &self.conn_id)
|
|
.field("my_peer_id", &self.my_peer_id)
|
|
.field("info", &self.info)
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
impl PeerConn {
|
|
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
|
|
let tunnel_info = tunnel.info();
|
|
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
|
|
|
|
let peer_conn_tunnel_filter = StatsRecorderTunnelFilter::new();
|
|
let throughput = peer_conn_tunnel_filter.filter_output();
|
|
let peer_conn_tunnel = TunnelWithFilter::new(tunnel, peer_conn_tunnel_filter);
|
|
let mut mpsc_tunnel = MpscTunnel::new(peer_conn_tunnel);
|
|
|
|
let (recv, sink) = (mpsc_tunnel.get_stream(), mpsc_tunnel.get_sink());
|
|
|
|
PeerConn {
|
|
conn_id: PeerConnId::new_v4(),
|
|
|
|
my_peer_id,
|
|
global_ctx,
|
|
|
|
tunnel: Arc::new(Mutex::new(Box::new(mpsc_tunnel))),
|
|
sink,
|
|
recv: Arc::new(Mutex::new(Some(recv))),
|
|
tunnel_info,
|
|
|
|
tasks: JoinSet::new(),
|
|
|
|
info: None,
|
|
close_event_sender: None,
|
|
|
|
ctrl_resp_sender: ctrl_sender,
|
|
|
|
latency_stats: Arc::new(WindowLatency::new(15)),
|
|
throughput,
|
|
loss_rate_stats: Arc::new(AtomicU32::new(0)),
|
|
}
|
|
}
|
|
|
|
pub fn get_conn_id(&self) -> PeerConnId {
|
|
self.conn_id
|
|
}
|
|
|
|
async fn wait_handshake(&mut self, need_retry: &mut bool) -> Result<HandshakeRequest, Error> {
|
|
*need_retry = false;
|
|
|
|
let mut locked = self.recv.lock().await;
|
|
let recv = locked.as_mut().unwrap();
|
|
let Some(rsp) = recv.next().await else {
|
|
return Err(Error::WaitRespError(
|
|
"conn closed during wait handshake response".to_owned(),
|
|
));
|
|
};
|
|
|
|
*need_retry = true;
|
|
|
|
let rsp = rsp?;
|
|
let Some(peer_mgr_hdr) = rsp.peer_manager_header() else {
|
|
return Err(Error::WaitRespError(format!(
|
|
"unexpected packet: {:?}, cannot decode peer manager hdr",
|
|
rsp
|
|
)));
|
|
};
|
|
|
|
if peer_mgr_hdr.packet_type != PacketType::HandShake as u8 {
|
|
return Err(Error::WaitRespError(format!(
|
|
"unexpected packet type: {:?}",
|
|
peer_mgr_hdr.packet_type
|
|
)));
|
|
}
|
|
|
|
let rsp = HandshakeRequest::decode(rsp.payload()).map_err(|e| {
|
|
Error::WaitRespError(format!("decode handshake response error: {:?}", e))
|
|
})?;
|
|
|
|
if rsp.network_secret_digrest.len() != std::mem::size_of::<NetworkSecretDigest>() {
|
|
return Err(Error::WaitRespError(
|
|
"invalid network secret digest".to_owned(),
|
|
));
|
|
}
|
|
|
|
return Ok(rsp);
|
|
}
|
|
|
|
async fn wait_handshake_loop(&mut self) -> Result<HandshakeRequest, Error> {
|
|
timeout(Duration::from_secs(5), async move {
|
|
loop {
|
|
let mut need_retry = true;
|
|
match self.wait_handshake(&mut need_retry).await {
|
|
Ok(rsp) => return Ok(rsp),
|
|
Err(e) => {
|
|
tracing::warn!("wait handshake error: {:?}", e);
|
|
if !need_retry {
|
|
return Err(e);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
})
|
|
.map_err(|e| Error::WaitRespError(format!("wait handshake timeout: {:?}", e)))
|
|
.await?
|
|
}
|
|
|
|
async fn send_handshake(&mut self) -> Result<(), Error> {
|
|
let network = self.global_ctx.get_network_identity();
|
|
let mut req = HandshakeRequest {
|
|
magic: MAGIC,
|
|
my_peer_id: self.my_peer_id,
|
|
version: VERSION,
|
|
features: Vec::new(),
|
|
network_name: network.network_name.clone(),
|
|
..Default::default()
|
|
};
|
|
req.network_secret_digrest
|
|
.extend_from_slice(&network.network_secret_digest.unwrap_or_default());
|
|
|
|
let hs_req = req.encode_to_vec();
|
|
let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes());
|
|
zc_packet.fill_peer_manager_hdr(
|
|
self.my_peer_id,
|
|
PeerId::default(),
|
|
PacketType::HandShake as u8,
|
|
);
|
|
|
|
self.sink.send(zc_packet).await.map_err(|e| {
|
|
tracing::warn!("send handshake request error: {:?}", e);
|
|
Error::WaitRespError("send handshake request error".to_owned())
|
|
})?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tracing::instrument]
|
|
pub async fn do_handshake_as_server(&mut self) -> Result<(), Error> {
|
|
let rsp = self.wait_handshake_loop().await?;
|
|
tracing::info!("handshake request: {:?}", rsp);
|
|
self.info = Some(rsp);
|
|
self.send_handshake().await?;
|
|
Ok(())
|
|
}
|
|
|
|
#[tracing::instrument]
|
|
pub async fn do_handshake_as_client(&mut self) -> Result<(), Error> {
|
|
self.send_handshake().await?;
|
|
tracing::info!("waiting for handshake request from server");
|
|
let rsp = self.wait_handshake_loop().await?;
|
|
tracing::info!("handshake response: {:?}", rsp);
|
|
self.info = Some(rsp);
|
|
Ok(())
|
|
}
|
|
|
|
pub fn handshake_done(&self) -> bool {
|
|
self.info.is_some()
|
|
}
|
|
|
|
pub async fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
|
|
let mut stream = self.recv.lock().await.take().unwrap();
|
|
let sink = self.sink.clone();
|
|
let mut sender = PollSender::new(packet_recv_chan.clone());
|
|
let close_event_sender = self.close_event_sender.clone().unwrap();
|
|
let conn_id = self.conn_id;
|
|
let ctrl_sender = self.ctrl_resp_sender.clone();
|
|
let _conn_info = self.get_conn_info();
|
|
let conn_info_for_instrument = self.get_conn_info();
|
|
|
|
self.tasks.spawn(
|
|
async move {
|
|
tracing::info!("start recving peer conn packet");
|
|
let mut task_ret = Ok(());
|
|
while let Some(ret) = stream.next().await {
|
|
if ret.is_err() {
|
|
tracing::error!(error = ?ret, "peer conn recv error");
|
|
task_ret = Err(ret.err().unwrap());
|
|
break;
|
|
}
|
|
|
|
let mut zc_packet = ret.unwrap();
|
|
let Some(peer_mgr_hdr) = zc_packet.mut_peer_manager_header() else {
|
|
tracing::error!(
|
|
"unexpected packet: {:?}, cannot decode peer manager hdr",
|
|
zc_packet
|
|
);
|
|
continue;
|
|
};
|
|
|
|
if peer_mgr_hdr.packet_type == PacketType::Ping as u8 {
|
|
peer_mgr_hdr.packet_type = PacketType::Pong as u8;
|
|
if let Err(e) = sink.send(zc_packet).await {
|
|
tracing::error!(?e, "peer conn send req error");
|
|
}
|
|
} else if peer_mgr_hdr.packet_type == PacketType::Pong as u8 {
|
|
if let Err(e) = ctrl_sender.send(zc_packet) {
|
|
tracing::error!(?e, "peer conn send ctrl resp error");
|
|
}
|
|
} else {
|
|
if sender.send(zc_packet).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
tracing::info!("end recving peer conn packet");
|
|
|
|
drop(sink);
|
|
if let Err(e) = close_event_sender.send(conn_id).await {
|
|
tracing::error!(error = ?e, "peer conn close event send error");
|
|
}
|
|
|
|
task_ret
|
|
}
|
|
.instrument(
|
|
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
|
|
),
|
|
);
|
|
}
|
|
|
|
pub fn start_pingpong(&mut self) {
|
|
let mut pingpong = PeerConnPinger::new(
|
|
self.my_peer_id,
|
|
self.get_peer_id(),
|
|
self.sink.clone(),
|
|
self.ctrl_resp_sender.clone(),
|
|
self.latency_stats.clone(),
|
|
self.loss_rate_stats.clone(),
|
|
);
|
|
|
|
let close_event_sender = self.close_event_sender.clone().unwrap();
|
|
let conn_id = self.conn_id;
|
|
|
|
self.tasks.spawn(async move {
|
|
pingpong.pingpong().await;
|
|
|
|
tracing::warn!(?pingpong, "pingpong task exit");
|
|
|
|
if let Err(e) = close_event_sender.send(conn_id).await {
|
|
tracing::warn!("close event sender error: {:?}", e);
|
|
}
|
|
|
|
Ok(())
|
|
});
|
|
}
|
|
|
|
pub async fn send_msg(&self, msg: ZCPacket) -> Result<(), Error> {
|
|
Ok(self.sink.send(msg).await?)
|
|
}
|
|
|
|
pub fn get_peer_id(&self) -> PeerId {
|
|
self.info.as_ref().unwrap().my_peer_id
|
|
}
|
|
|
|
pub fn get_network_identity(&self) -> NetworkIdentity {
|
|
let info = self.info.as_ref().unwrap();
|
|
let mut ret = NetworkIdentity {
|
|
network_name: info.network_name.clone(),
|
|
..Default::default()
|
|
};
|
|
ret.network_secret_digest = Some([0u8; 32]);
|
|
ret.network_secret_digest
|
|
.as_mut()
|
|
.unwrap()
|
|
.copy_from_slice(&info.network_secret_digrest);
|
|
ret
|
|
}
|
|
|
|
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {
|
|
self.close_event_sender = Some(sender);
|
|
}
|
|
|
|
pub fn get_stats(&self) -> PeerConnStats {
|
|
PeerConnStats {
|
|
latency_us: self.latency_stats.get_latency_us(),
|
|
|
|
tx_bytes: self.throughput.tx_bytes(),
|
|
rx_bytes: self.throughput.rx_bytes(),
|
|
|
|
tx_packets: self.throughput.tx_packets(),
|
|
rx_packets: self.throughput.rx_packets(),
|
|
}
|
|
}
|
|
|
|
pub fn get_conn_info(&self) -> PeerConnInfo {
|
|
PeerConnInfo {
|
|
conn_id: self.conn_id.to_string(),
|
|
my_peer_id: self.my_peer_id,
|
|
peer_id: self.get_peer_id(),
|
|
features: self.info.as_ref().unwrap().features.clone(),
|
|
tunnel: self.tunnel_info.clone(),
|
|
stats: Some(self.get_stats()),
|
|
loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::common::global_ctx::tests::get_mock_global_ctx;
    use crate::common::new_peer_id;
    use crate::tunnel::filter::tests::DropSendTunnelFilter;
    use crate::tunnel::filter::PacketRecorderTunnelFilter;
    use crate::tunnel::ring::create_ring_tunnel_pair;

    /// A client/server handshake over an in-memory ring tunnel exchanges exactly one packet in
    /// each direction. Afterwards each side knows the other's peer id and both agree on the
    /// network identity.
    #[tokio::test]
    async fn peer_conn_handshake() {
        let (c, s) = create_ring_tunnel_pair();

        let c_recorder = Arc::new(PacketRecorderTunnelFilter::new());
        let s_recorder = Arc::new(PacketRecorderTunnelFilter::new());

        let c = TunnelWithFilter::new(c, c_recorder.clone());
        let s = TunnelWithFilter::new(s, s_recorder.clone());

        let c_peer_id = new_peer_id();
        let s_peer_id = new_peer_id();

        let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
        let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));

        let (c_ret, s_ret) = tokio::join!(
            c_peer.do_handshake_as_client(),
            s_peer.do_handshake_as_server()
        );

        c_ret.unwrap();
        s_ret.unwrap();

        assert_eq!(c_recorder.sent.lock().unwrap().len(), 1);
        assert_eq!(c_recorder.received.lock().unwrap().len(), 1);

        assert_eq!(s_recorder.sent.lock().unwrap().len(), 1);
        assert_eq!(s_recorder.received.lock().unwrap().len(), 1);

        assert_eq!(c_peer.get_peer_id(), s_peer_id);
        assert_eq!(s_peer.get_peer_id(), c_peer_id);
        assert_eq!(c_peer.get_network_identity(), s_peer.get_network_identity());
        assert_eq!(c_peer.get_network_identity(), NetworkIdentity::default());
    }

    /// Shared body of the pingpong tests.
    ///
    /// It handshakes a client/server pair, then runs pingpong from the client. The client's
    /// `DropSendTunnelFilter` discards the outgoing packets selected by `drop_start`/`drop_end`
    /// (see the filter for the exact semantics). After 15 s the test checks whether the client
    /// reported its connection as closed, and compares that with `conn_closed`.
    async fn peer_conn_pingpong_test_common(drop_start: u32, drop_end: u32, conn_closed: bool) {
        let (c, s) = create_ring_tunnel_pair();

        // Only the client's outgoing packets are subject to dropping.
        let c_recorder = Arc::new(DropSendTunnelFilter::new(drop_start, drop_end));
        let c = TunnelWithFilter::new(c, c_recorder.clone());

        let c_peer_id = new_peer_id();
        let s_peer_id = new_peer_id();

        let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
        let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));

        let (c_ret, s_ret) = tokio::join!(
            c_peer.do_handshake_as_client(),
            s_peer.do_handshake_as_server()
        );

        // Check the handshake before starting the recv loops: they read the peer info that the
        // handshake fills in, and would panic with an unrelated message if it failed.
        assert!(c_ret.is_ok());
        assert!(s_ret.is_ok());

        s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0);
        s_peer
            .start_recv_loop(tokio::sync::mpsc::channel(200).0)
            .await;

        let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1);
        c_peer.set_close_event_sender(close_send);
        c_peer.start_pingpong();
        c_peer
            .start_recv_loop(tokio::sync::mpsc::channel(200).0)
            .await;

        // Wait 15s: long enough for the pinger to either ride out the dropped packets or give
        // up on the connection.
        tokio::time::sleep(Duration::from_secs(15)).await;

        if conn_closed {
            assert!(close_recv.try_recv().is_ok());
        } else {
            assert!(close_recv.try_recv().is_err());
        }
    }

    /// A short burst of dropped pings must not close the connection; a long one must.
    #[tokio::test]
    async fn peer_conn_pingpong_timeout() {
        peer_conn_pingpong_test_common(3, 5, false).await;
        peer_conn_pingpong_test_common(5, 12, true).await;
    }

    /// Dropping the server end while the client waits for the handshake reply must make the
    /// client handshake fail promptly, well before its own 5 s handshake timeout.
    #[tokio::test]
    async fn close_tunnel_during_handshake() {
        let (c, s) = create_ring_tunnel_pair();
        let mut c_peer = PeerConn::new(new_peer_id(), get_mock_global_ctx(), Box::new(c));
        let j = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(1)).await;
            drop(s);
        });
        timeout(Duration::from_millis(1500), c_peer.do_handshake_as_client())
            .await
            .unwrap()
            .unwrap_err();
        j.await.unwrap();
    }
}
|