mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 10:14:35 +00:00
fix handshake dead lock, clean old code (#61)
* fix handshake dead lock * remove old code
This commit is contained in:
+228
-418
@@ -1,4 +1,5 @@
|
||||
use std::{
|
||||
any::Any,
|
||||
fmt::Debug,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
@@ -7,8 +8,9 @@ use std::{
|
||||
},
|
||||
};
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use pnet::datalink::NetworkInterface;
|
||||
use futures::{SinkExt, StreamExt, TryFutureExt};
|
||||
|
||||
use prost::Message;
|
||||
|
||||
use tokio::{
|
||||
sync::{broadcast, mpsc, Mutex},
|
||||
@@ -16,293 +18,34 @@ use tokio::{
|
||||
time::{timeout, Duration},
|
||||
};
|
||||
|
||||
use tokio_util::{bytes::Bytes, sync::PollSender};
|
||||
use tokio_util::sync::PollSender;
|
||||
use tracing::Instrument;
|
||||
use zerocopy::AsBytes;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
||||
config::{NetworkIdentity, NetworkSecretDigest},
|
||||
error::Error,
|
||||
global_ctx::ArcGlobalCtx,
|
||||
PeerId,
|
||||
},
|
||||
define_tunnel_filter_chain,
|
||||
peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType},
|
||||
rpc::{PeerConnInfo, PeerConnStats},
|
||||
tunnels::{
|
||||
peers::packet::PacketType,
|
||||
rpc::{HandshakeRequest, PeerConnInfo, PeerConnStats, TunnelInfo},
|
||||
tunnel::{
|
||||
filter::{StatsRecorderTunnelFilter, TunnelFilter, TunnelWithFilter},
|
||||
mpsc::{MpscTunnel, MpscTunnelSender},
|
||||
packet_def::ZCPacket,
|
||||
stats::{Throughput, WindowLatency},
|
||||
tunnel_filter::StatsRecorderTunnelFilter,
|
||||
DatagramSink, Tunnel, TunnelError,
|
||||
Tunnel, TunnelError, ZCPacketStream,
|
||||
},
|
||||
};
|
||||
|
||||
use super::packet::{self, HandShake, Packet};
|
||||
|
||||
pub type PacketRecvChan = mpsc::Sender<Bytes>;
|
||||
use super::{peer_conn_ping::PeerConnPinger, PacketRecvChan};
|
||||
|
||||
pub type PeerConnId = uuid::Uuid;
|
||||
|
||||
macro_rules! wait_response {
|
||||
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
|
||||
let Ok(rsp_vec) = timeout(Duration::from_secs(1), $stream.next()).await else {
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"wait handshake response timeout".to_owned(),
|
||||
));
|
||||
};
|
||||
let Some(rsp_vec) = rsp_vec else {
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"wait handshake response get none".to_owned(),
|
||||
));
|
||||
};
|
||||
let Ok(rsp_vec) = rsp_vec else {
|
||||
return Err(TunnelError::WaitRespError(format!(
|
||||
"wait handshake response get error {}",
|
||||
rsp_vec.err().unwrap()
|
||||
)));
|
||||
};
|
||||
|
||||
let $out_var;
|
||||
let rsp_bytes = Packet::decode(&rsp_vec);
|
||||
if rsp_bytes.packet_type != PacketType::HandShake {
|
||||
tracing::error!("unexpected packet type: {:?}", rsp_bytes);
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"unexpected packet type".to_owned(),
|
||||
));
|
||||
}
|
||||
let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes);
|
||||
match &resp_payload {
|
||||
$pattern => $out_var = $value,
|
||||
_ => {
|
||||
tracing::error!(
|
||||
"unexpected packet: {:?}, pattern: {:?}",
|
||||
rsp_bytes,
|
||||
stringify!($pattern)
|
||||
);
|
||||
return Err(TunnelError::WaitRespError("unexpected packet".to_owned()));
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PeerInfo {
|
||||
magic: u32,
|
||||
pub my_peer_id: PeerId,
|
||||
version: u32,
|
||||
pub features: Vec<String>,
|
||||
pub interfaces: Vec<NetworkInterface>,
|
||||
pub network_identity: NetworkIdentity,
|
||||
}
|
||||
|
||||
impl<'a> From<&HandShake> for PeerInfo {
|
||||
fn from(hs: &HandShake) -> Self {
|
||||
PeerInfo {
|
||||
magic: hs.magic.into(),
|
||||
my_peer_id: hs.my_peer_id.into(),
|
||||
version: hs.version.into(),
|
||||
features: hs.features.iter().map(|x| x.to_string()).collect(),
|
||||
interfaces: Vec::new(),
|
||||
network_identity: hs.network_identity.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PeerConnPinger {
|
||||
my_peer_id: PeerId,
|
||||
peer_id: PeerId,
|
||||
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
|
||||
ctrl_sender: broadcast::Sender<Bytes>,
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
tasks: JoinSet<Result<(), TunnelError>>,
|
||||
}
|
||||
|
||||
impl Debug for PeerConnPinger {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerConnPinger")
|
||||
.field("my_peer_id", &self.my_peer_id)
|
||||
.field("peer_id", &self.peer_id)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerConnPinger {
|
||||
pub fn new(
|
||||
my_peer_id: PeerId,
|
||||
peer_id: PeerId,
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
ctrl_sender: broadcast::Sender<Bytes>,
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
) -> Self {
|
||||
Self {
|
||||
my_peer_id,
|
||||
peer_id,
|
||||
sink: Arc::new(Mutex::new(sink)),
|
||||
tasks: JoinSet::new(),
|
||||
latency_stats,
|
||||
ctrl_sender,
|
||||
loss_rate_stats,
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_pingpong_once(
|
||||
my_node_id: PeerId,
|
||||
peer_id: PeerId,
|
||||
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
|
||||
receiver: &mut broadcast::Receiver<Bytes>,
|
||||
seq: u32,
|
||||
) -> Result<u128, TunnelError> {
|
||||
// should add seq here. so latency can be calculated more accurately
|
||||
let req = packet::Packet::new_ping_packet(my_node_id, peer_id, seq).into();
|
||||
tracing::trace!("send ping packet: {:?}", req);
|
||||
sink.lock().await.send(req).await.map_err(|e| {
|
||||
tracing::warn!("send ping packet error: {:?}", e);
|
||||
TunnelError::CommonError("send ping packet error".to_owned())
|
||||
})?;
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
|
||||
// wait until we get a pong packet in ctrl_resp_receiver
|
||||
let resp = timeout(Duration::from_secs(1), async {
|
||||
loop {
|
||||
match receiver.recv().await {
|
||||
Ok(p) => {
|
||||
let ctrl_payload =
|
||||
packet::CtrlPacketPayload::from_packet(Packet::decode(&p));
|
||||
if let packet::CtrlPacketPayload::Pong(resp_seq) = ctrl_payload {
|
||||
if resp_seq == seq {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("recv pong resp error: {:?}", e);
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"recv pong resp error".to_owned(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.await;
|
||||
|
||||
tracing::trace!(?resp, "wait ping response done");
|
||||
|
||||
if resp.is_err() {
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"wait ping response timeout".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
if resp.as_ref().unwrap().is_err() {
|
||||
return Err(resp.unwrap().err().unwrap());
|
||||
}
|
||||
|
||||
Ok(now.elapsed().as_micros())
|
||||
}
|
||||
|
||||
async fn pingpong(&mut self) {
|
||||
let sink = self.sink.clone();
|
||||
let my_node_id = self.my_peer_id;
|
||||
let peer_id = self.peer_id;
|
||||
let latency_stats = self.latency_stats.clone();
|
||||
|
||||
let (ping_res_sender, mut ping_res_receiver) = tokio::sync::mpsc::channel(100);
|
||||
|
||||
let stopped = Arc::new(AtomicU32::new(0));
|
||||
|
||||
// generate a pingpong task every 200ms
|
||||
let mut pingpong_tasks = JoinSet::new();
|
||||
let ctrl_resp_sender = self.ctrl_sender.clone();
|
||||
let stopped_clone = stopped.clone();
|
||||
self.tasks.spawn(async move {
|
||||
let mut req_seq = 0;
|
||||
loop {
|
||||
let receiver = ctrl_resp_sender.subscribe();
|
||||
let ping_res_sender = ping_res_sender.clone();
|
||||
let sink = sink.clone();
|
||||
|
||||
if stopped_clone.load(Ordering::Relaxed) != 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
while pingpong_tasks.len() > 5 {
|
||||
pingpong_tasks.join_next().await;
|
||||
}
|
||||
|
||||
pingpong_tasks.spawn(async move {
|
||||
let mut receiver = receiver.resubscribe();
|
||||
let pingpong_once_ret = Self::do_pingpong_once(
|
||||
my_node_id,
|
||||
peer_id,
|
||||
sink.clone(),
|
||||
&mut receiver,
|
||||
req_seq,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = ping_res_sender.send(pingpong_once_ret).await {
|
||||
tracing::info!(?e, "pingpong task send result error, exit..");
|
||||
};
|
||||
});
|
||||
|
||||
req_seq = req_seq.wrapping_add(1);
|
||||
tokio::time::sleep(Duration::from_millis(1000)).await;
|
||||
}
|
||||
});
|
||||
|
||||
// one with 1% precision
|
||||
let loss_rate_stats_1 = WindowLatency::new(100);
|
||||
// one with 20% precision, so we can fast fail this conn.
|
||||
let loss_rate_stats_20 = WindowLatency::new(5);
|
||||
|
||||
let mut counter: u64 = 0;
|
||||
|
||||
while let Some(ret) = ping_res_receiver.recv().await {
|
||||
counter += 1;
|
||||
|
||||
if let Ok(lat) = ret {
|
||||
latency_stats.record_latency(lat as u32);
|
||||
|
||||
loss_rate_stats_1.record_latency(0);
|
||||
loss_rate_stats_20.record_latency(0);
|
||||
} else {
|
||||
loss_rate_stats_1.record_latency(1);
|
||||
loss_rate_stats_20.record_latency(1);
|
||||
}
|
||||
|
||||
let loss_rate_20: f64 = loss_rate_stats_20.get_latency_us();
|
||||
let loss_rate_1: f64 = loss_rate_stats_1.get_latency_us();
|
||||
|
||||
tracing::trace!(
|
||||
?ret,
|
||||
?self,
|
||||
?loss_rate_1,
|
||||
?loss_rate_20,
|
||||
"pingpong task recv pingpong_once result"
|
||||
);
|
||||
|
||||
if (counter > 5 && loss_rate_20 > 0.74) || (counter > 150 && loss_rate_1 > 0.20) {
|
||||
tracing::warn!(
|
||||
?ret,
|
||||
?self,
|
||||
?loss_rate_1,
|
||||
?loss_rate_20,
|
||||
"pingpong loss rate too high, closing"
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
self.loss_rate_stats
|
||||
.store((loss_rate_1 * 100.0) as u32, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
stopped.store(1, Ordering::Relaxed);
|
||||
ping_res_receiver.close();
|
||||
}
|
||||
}
|
||||
|
||||
define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter);
|
||||
const MAGIC: u32 = 0xd1e1a5e1;
|
||||
const VERSION: u32 = 1;
|
||||
|
||||
pub struct PeerConn {
|
||||
conn_id: PeerConnId,
|
||||
@@ -310,33 +53,45 @@ pub struct PeerConn {
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
tunnel: Arc<Mutex<Box<dyn Any + Send + 'static>>>,
|
||||
sink: MpscTunnelSender,
|
||||
recv: Arc<Mutex<Option<Pin<Box<dyn ZCPacketStream>>>>>,
|
||||
tunnel_info: Option<TunnelInfo>,
|
||||
|
||||
tasks: JoinSet<Result<(), TunnelError>>,
|
||||
|
||||
info: Option<PeerInfo>,
|
||||
info: Option<HandshakeRequest>,
|
||||
|
||||
close_event_sender: Option<mpsc::Sender<PeerConnId>>,
|
||||
|
||||
ctrl_resp_sender: broadcast::Sender<Bytes>,
|
||||
ctrl_resp_sender: broadcast::Sender<ZCPacket>,
|
||||
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
throughput: Arc<Throughput>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
enum PeerConnPacketType {
|
||||
Data(Bytes),
|
||||
CtrlReq(Bytes),
|
||||
CtrlResp(Bytes),
|
||||
impl Debug for PeerConn {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerConn")
|
||||
.field("conn_id", &self.conn_id)
|
||||
.field("my_peer_id", &self.my_peer_id)
|
||||
.field("info", &self.info)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerConn {
|
||||
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
|
||||
let tunnel_info = tunnel.info();
|
||||
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
|
||||
let peer_conn_tunnel = PeerConnTunnel::new();
|
||||
let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel);
|
||||
|
||||
let peer_conn_tunnel_filter = StatsRecorderTunnelFilter::new();
|
||||
let throughput = peer_conn_tunnel_filter.filter_output();
|
||||
let peer_conn_tunnel = TunnelWithFilter::new(tunnel, peer_conn_tunnel_filter);
|
||||
let mut mpsc_tunnel = MpscTunnel::new(peer_conn_tunnel);
|
||||
|
||||
let (recv, sink) = (mpsc_tunnel.get_stream(), mpsc_tunnel.get_sink());
|
||||
|
||||
PeerConn {
|
||||
conn_id: PeerConnId::new_v4(),
|
||||
@@ -344,8 +99,10 @@ impl PeerConn {
|
||||
my_peer_id,
|
||||
global_ctx,
|
||||
|
||||
sink: tunnel.pin_sink(),
|
||||
tunnel: Box::new(tunnel),
|
||||
tunnel: Arc::new(Mutex::new(Box::new(mpsc_tunnel))),
|
||||
sink,
|
||||
recv: Arc::new(Mutex::new(Some(recv))),
|
||||
tunnel_info,
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
|
||||
@@ -355,7 +112,7 @@ impl PeerConn {
|
||||
ctrl_resp_sender: ctrl_sender,
|
||||
|
||||
latency_stats: Arc::new(WindowLatency::new(15)),
|
||||
throughput: peer_conn_tunnel.stats.get_throughput().clone(),
|
||||
throughput,
|
||||
loss_rate_stats: Arc::new(AtomicU32::new(0)),
|
||||
}
|
||||
}
|
||||
@@ -364,41 +121,97 @@ impl PeerConn {
|
||||
self.conn_id
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
async fn wait_handshake(&mut self, need_retry: &mut bool) -> Result<HandshakeRequest, Error> {
|
||||
*need_retry = false;
|
||||
|
||||
tracing::info!("waiting for handshake request from client");
|
||||
wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x);
|
||||
self.info = Some(PeerInfo::from(hs_req));
|
||||
tracing::info!("handshake request: {:?}", hs_req);
|
||||
let mut locked = self.recv.lock().await;
|
||||
let recv = locked.as_mut().unwrap();
|
||||
let Some(rsp) = recv.next().await else {
|
||||
return Err(Error::WaitRespError(
|
||||
"conn closed during wait handshake response".to_owned(),
|
||||
));
|
||||
};
|
||||
|
||||
let hs_req = self
|
||||
.global_ctx
|
||||
.net_ns
|
||||
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
|
||||
sink.send(hs_req.into()).await?;
|
||||
*need_retry = true;
|
||||
|
||||
let rsp = rsp?;
|
||||
let rsp = HandshakeRequest::decode(rsp.payload()).map_err(|e| {
|
||||
Error::WaitRespError(format!("decode handshake response error: {:?}", e))
|
||||
})?;
|
||||
|
||||
if rsp.network_secret_digrest.len() != std::mem::size_of::<NetworkSecretDigest>() {
|
||||
return Err(Error::WaitRespError(
|
||||
"invalid network secret digest".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
return Ok(rsp);
|
||||
}
|
||||
|
||||
async fn wait_handshake_loop(&mut self) -> Result<HandshakeRequest, Error> {
|
||||
timeout(Duration::from_secs(5), async move {
|
||||
loop {
|
||||
let mut need_retry = true;
|
||||
match self.wait_handshake(&mut need_retry).await {
|
||||
Ok(rsp) => return Ok(rsp),
|
||||
Err(e) => {
|
||||
log::warn!("wait handshake error: {:?}", e);
|
||||
if !need_retry {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.map_err(|e| Error::WaitRespError(format!("wait handshake timeout: {:?}", e)))
|
||||
.await?
|
||||
}
|
||||
|
||||
async fn send_handshake(&mut self) -> Result<(), Error> {
|
||||
let network = self.global_ctx.get_network_identity();
|
||||
let mut req = HandshakeRequest {
|
||||
magic: MAGIC,
|
||||
my_peer_id: self.my_peer_id,
|
||||
version: VERSION,
|
||||
features: Vec::new(),
|
||||
network_name: network.network_name.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
req.network_secret_digrest
|
||||
.extend_from_slice(&network.network_secret_digest.unwrap_or_default());
|
||||
|
||||
let hs_req = req.encode_to_vec();
|
||||
let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes());
|
||||
zc_packet.fill_peer_manager_hdr(
|
||||
self.my_peer_id,
|
||||
PeerId::default(),
|
||||
PacketType::HandShake as u8,
|
||||
);
|
||||
|
||||
self.sink.send(zc_packet).await.map_err(|e| {
|
||||
tracing::warn!("send handshake request error: {:?}", e);
|
||||
Error::WaitRespError("send handshake request error".to_owned())
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
|
||||
let hs_req = self
|
||||
.global_ctx
|
||||
.net_ns
|
||||
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
|
||||
sink.send(hs_req.into()).await?;
|
||||
pub async fn do_handshake_as_server(&mut self) -> Result<(), Error> {
|
||||
let rsp = self.wait_handshake_loop().await?;
|
||||
tracing::info!("handshake request: {:?}", rsp);
|
||||
self.info = Some(rsp);
|
||||
self.send_handshake().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_client(&mut self) -> Result<(), Error> {
|
||||
self.send_handshake().await?;
|
||||
tracing::info!("waiting for handshake request from server");
|
||||
wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x);
|
||||
self.info = Some(PeerInfo::from(hs_rsp));
|
||||
tracing::info!("handshake response: {:?}", hs_rsp);
|
||||
|
||||
let rsp = self.wait_handshake_loop().await?;
|
||||
tracing::info!("handshake response: {:?}", rsp);
|
||||
self.info = Some(rsp);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -406,11 +219,72 @@ impl PeerConn {
|
||||
self.info.is_some()
|
||||
}
|
||||
|
||||
pub async fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
|
||||
let mut stream = self.recv.lock().await.take().unwrap();
|
||||
let sink = self.sink.clone();
|
||||
let mut sender = PollSender::new(packet_recv_chan.clone());
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
let ctrl_sender = self.ctrl_resp_sender.clone();
|
||||
let _conn_info = self.get_conn_info();
|
||||
let conn_info_for_instrument = self.get_conn_info();
|
||||
|
||||
self.tasks.spawn(
|
||||
async move {
|
||||
tracing::info!("start recving peer conn packet");
|
||||
let mut task_ret = Ok(());
|
||||
while let Some(ret) = stream.next().await {
|
||||
if ret.is_err() {
|
||||
tracing::error!(error = ?ret, "peer conn recv error");
|
||||
task_ret = Err(ret.err().unwrap());
|
||||
break;
|
||||
}
|
||||
|
||||
let mut zc_packet = ret.unwrap();
|
||||
let Some(peer_mgr_hdr) = zc_packet.mut_peer_manager_header() else {
|
||||
tracing::error!(
|
||||
"unexpected packet: {:?}, cannot decode peer manager hdr",
|
||||
zc_packet
|
||||
);
|
||||
continue;
|
||||
};
|
||||
|
||||
if peer_mgr_hdr.packet_type == PacketType::Ping as u8 {
|
||||
peer_mgr_hdr.packet_type = PacketType::Pong as u8;
|
||||
if let Err(e) = sink.send(zc_packet).await {
|
||||
tracing::error!(?e, "peer conn send req error");
|
||||
}
|
||||
} else if peer_mgr_hdr.packet_type == PacketType::Pong as u8 {
|
||||
if let Err(e) = ctrl_sender.send(zc_packet) {
|
||||
tracing::error!(?e, "peer conn send ctrl resp error");
|
||||
}
|
||||
} else {
|
||||
if sender.send(zc_packet).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("end recving peer conn packet");
|
||||
|
||||
drop(sink);
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
tracing::error!(error = ?e, "peer conn close event send error");
|
||||
}
|
||||
|
||||
task_ret
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn start_pingpong(&mut self) {
|
||||
let mut pingpong = PeerConnPinger::new(
|
||||
self.my_peer_id,
|
||||
self.get_peer_id(),
|
||||
self.tunnel.pin_sink(),
|
||||
self.sink.clone(),
|
||||
self.ctrl_resp_sender.clone(),
|
||||
self.latency_stats.clone(),
|
||||
self.loss_rate_stats.clone(),
|
||||
@@ -432,79 +306,8 @@ impl PeerConn {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
let mut sender = PollSender::new(packet_recv_chan.clone());
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
let ctrl_sender = self.ctrl_resp_sender.clone();
|
||||
let conn_info = self.get_conn_info();
|
||||
let conn_info_for_instrument = self.get_conn_info();
|
||||
|
||||
self.tasks.spawn(
|
||||
async move {
|
||||
tracing::info!("start recving peer conn packet");
|
||||
let mut task_ret = Ok(());
|
||||
while let Some(ret) = stream.next().await {
|
||||
if ret.is_err() {
|
||||
tracing::error!(error = ?ret, "peer conn recv error");
|
||||
task_ret = Err(ret.err().unwrap());
|
||||
break;
|
||||
}
|
||||
|
||||
let buf = ret.unwrap();
|
||||
let p = Packet::decode(&buf);
|
||||
match p.packet_type {
|
||||
ArchivedPacketType::Ping => {
|
||||
let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p)
|
||||
else {
|
||||
log::error!("unexpected packet: {:?}", p);
|
||||
continue;
|
||||
};
|
||||
|
||||
let pong = packet::Packet::new_pong_packet(
|
||||
conn_info.my_peer_id,
|
||||
conn_info.peer_id,
|
||||
seq.into(),
|
||||
);
|
||||
|
||||
if let Err(e) = sink.send(pong.into()).await {
|
||||
tracing::error!(?e, "peer conn send req error");
|
||||
}
|
||||
}
|
||||
ArchivedPacketType::Pong => {
|
||||
if let Err(e) = ctrl_sender.send(buf.into()) {
|
||||
tracing::error!(?e, "peer conn send ctrl resp error");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if sender.send(buf.into()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("end recving peer conn packet");
|
||||
|
||||
if let Err(close_ret) = sink.close().await {
|
||||
tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it");
|
||||
}
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
tracing::error!(error = ?e, "peer conn close event send error");
|
||||
}
|
||||
|
||||
task_ret
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), TunnelError> {
|
||||
self.sink.send(msg).await
|
||||
pub async fn send_msg(&self, msg: ZCPacket) -> Result<(), Error> {
|
||||
Ok(self.sink.send(msg).await?)
|
||||
}
|
||||
|
||||
pub fn get_peer_id(&self) -> PeerId {
|
||||
@@ -512,7 +315,17 @@ impl PeerConn {
|
||||
}
|
||||
|
||||
pub fn get_network_identity(&self) -> NetworkIdentity {
|
||||
self.info.as_ref().unwrap().network_identity.clone()
|
||||
let info = self.info.as_ref().unwrap();
|
||||
let mut ret = NetworkIdentity {
|
||||
network_name: info.network_name.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
ret.network_secret_digest = Some([0u8; 32]);
|
||||
ret.network_secret_digest
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.copy_from_slice(&info.network_secret_digrest);
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {
|
||||
@@ -537,34 +350,13 @@ impl PeerConn {
|
||||
my_peer_id: self.my_peer_id,
|
||||
peer_id: self.get_peer_id(),
|
||||
features: self.info.as_ref().unwrap().features.clone(),
|
||||
tunnel: self.tunnel.info(),
|
||||
tunnel: self.tunnel_info.clone(),
|
||||
stats: Some(self.get_stats()),
|
||||
loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PeerConn {
|
||||
fn drop(&mut self) {
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
tokio::spawn(async move {
|
||||
let ret = sink.close().await;
|
||||
tracing::info!(error = ?ret, "peer conn tunnel closed.");
|
||||
});
|
||||
log::info!("peer conn {:?} drop", self.conn_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for PeerConn {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerConn")
|
||||
.field("conn_id", &self.conn_id)
|
||||
.field("my_peer_id", &self.my_peer_id)
|
||||
.field("info", &self.info)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
@@ -572,12 +364,12 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::common::global_ctx::tests::get_mock_global_ctx;
|
||||
use crate::common::new_peer_id;
|
||||
use crate::tunnels::tunnel_filter::tests::DropSendTunnelFilter;
|
||||
use crate::tunnels::tunnel_filter::{PacketRecorderTunnelFilter, TunnelWithFilter};
|
||||
use crate::tunnel::filter::tests::DropSendTunnelFilter;
|
||||
use crate::tunnel::filter::PacketRecorderTunnelFilter;
|
||||
use crate::tunnel::ring::create_ring_tunnel_pair;
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_conn_handshake() {
|
||||
use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
|
||||
let c_recorder = Arc::new(PacketRecorderTunnelFilter::new());
|
||||
@@ -614,7 +406,6 @@ mod tests {
|
||||
}
|
||||
|
||||
async fn peer_conn_pingpong_test_common(drop_start: u32, drop_end: u32, conn_closed: bool) {
|
||||
use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
|
||||
// drop 1-3 packets should not affect pingpong
|
||||
@@ -633,7 +424,9 @@ mod tests {
|
||||
);
|
||||
|
||||
s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0);
|
||||
s_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
|
||||
s_peer
|
||||
.start_recv_loop(tokio::sync::mpsc::channel(200).0)
|
||||
.await;
|
||||
|
||||
assert!(c_ret.is_ok());
|
||||
assert!(s_ret.is_ok());
|
||||
@@ -641,7 +434,9 @@ mod tests {
|
||||
let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1);
|
||||
c_peer.set_close_event_sender(close_send);
|
||||
c_peer.start_pingpong();
|
||||
c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
|
||||
c_peer
|
||||
.start_recv_loop(tokio::sync::mpsc::channel(200).0)
|
||||
.await;
|
||||
|
||||
// wait 5s, conn should not be disconnected
|
||||
tokio::time::sleep(Duration::from_secs(15)).await;
|
||||
@@ -658,4 +453,19 @@ mod tests {
|
||||
peer_conn_pingpong_test_common(3, 5, false).await;
|
||||
peer_conn_pingpong_test_common(5, 12, true).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn close_tunnel_during_handshake() {
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
let mut c_peer = PeerConn::new(new_peer_id(), get_mock_global_ctx(), Box::new(c));
|
||||
let j = tokio::spawn(async move {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
drop(s);
|
||||
});
|
||||
timeout(Duration::from_millis(1500), c_peer.do_handshake_as_client())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap_err();
|
||||
let _ = tokio::join!(j);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user