mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 18:24:36 +00:00
fix handshake dead lock, clean old code (#61)
* fix handshake dead lock * remove old code
This commit is contained in:
@@ -2,7 +2,7 @@ use std::{io, result};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{tunnel, tunnels};
|
||||
use crate::tunnel;
|
||||
|
||||
use super::PeerId;
|
||||
|
||||
@@ -13,7 +13,7 @@ pub enum Error {
|
||||
#[error("rust tun error {0}")]
|
||||
TunError(#[from] tun::Error),
|
||||
#[error("tunnel error {0}")]
|
||||
TunnelError(#[from] tunnels::TunnelError),
|
||||
TunnelError(#[from] tunnel::TunnelError),
|
||||
#[error("Peer has no conn, PeerId: {0}")]
|
||||
PeerNoConnectionError(PeerId),
|
||||
#[error("RouteError: {0:?}")]
|
||||
@@ -42,9 +42,6 @@ pub enum Error {
|
||||
#[error("wait resp error: {0}")]
|
||||
WaitRespError(String),
|
||||
|
||||
#[error("tunnel error")]
|
||||
TunnelErr(#[from] tunnel::TunnelError),
|
||||
|
||||
#[error("message decode error: {0}")]
|
||||
MessageDecodeError(String),
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ use tokio::{
|
||||
};
|
||||
|
||||
use crate::{
|
||||
common::PeerId, peers::zc_peer_conn::PeerConnId, rpc as easytier_rpc, tunnel::TunnelConnector,
|
||||
common::PeerId, peers::peer_conn::PeerConnId, rpc as easytier_rpc, tunnel::TunnelConnector,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::{
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector},
|
||||
tunnel::{
|
||||
check_scheme_and_get_socket_addr,
|
||||
quic::QUICTunnelConnector,
|
||||
ring::RingTunnelConnector,
|
||||
tcp::TcpTunnelConnector,
|
||||
@@ -50,8 +51,7 @@ pub async fn create_connector_by_url(
|
||||
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
|
||||
match url.scheme() {
|
||||
"tcp" => {
|
||||
let dst_addr =
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp")?;
|
||||
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp")?;
|
||||
let mut connector = TcpTunnelConnector::new(url);
|
||||
set_bind_addr_for_peer_connector(
|
||||
&mut connector,
|
||||
@@ -62,8 +62,7 @@ pub async fn create_connector_by_url(
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
"udp" => {
|
||||
let dst_addr =
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp")?;
|
||||
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp")?;
|
||||
let mut connector = UdpTunnelConnector::new(url);
|
||||
set_bind_addr_for_peer_connector(
|
||||
&mut connector,
|
||||
@@ -74,13 +73,12 @@ pub async fn create_connector_by_url(
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
"ring" => {
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring")?;
|
||||
check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring")?;
|
||||
let connector = RingTunnelConnector::new(url);
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
"quic" => {
|
||||
let dst_addr =
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "quic")?;
|
||||
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&url, "quic")?;
|
||||
let mut connector = QUICTunnelConnector::new(url);
|
||||
set_bind_addr_for_peer_connector(
|
||||
&mut connector,
|
||||
@@ -91,8 +89,7 @@ pub async fn create_connector_by_url(
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
"wg" => {
|
||||
let dst_addr =
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg")?;
|
||||
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg")?;
|
||||
let nid = global_ctx.get_network_identity();
|
||||
let wg_config = WgConfig::new_from_network_identity(
|
||||
&nid.network_name,
|
||||
|
||||
@@ -10,7 +10,6 @@ mod arch;
|
||||
mod common;
|
||||
mod rpc;
|
||||
mod tunnel;
|
||||
mod tunnels;
|
||||
mod utils;
|
||||
|
||||
use crate::{
|
||||
|
||||
@@ -17,7 +17,6 @@ mod peer_center;
|
||||
mod peers;
|
||||
mod rpc;
|
||||
mod tunnel;
|
||||
mod tunnels;
|
||||
mod vpn_portal;
|
||||
|
||||
use common::{
|
||||
|
||||
@@ -26,8 +26,10 @@ use tracing::Level;
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
|
||||
peers::{peer_manager::PeerManager, PeerPacketFilter},
|
||||
tunnel::packet_def::{PacketType, ZCPacket},
|
||||
tunnels::common::setup_sokcet2,
|
||||
tunnel::{
|
||||
common::setup_sokcet2,
|
||||
packet_def::{PacketType, ZCPacket},
|
||||
},
|
||||
};
|
||||
|
||||
use super::CidrSet;
|
||||
|
||||
@@ -22,9 +22,9 @@ use crate::gateway::icmp_proxy::IcmpProxy;
|
||||
use crate::gateway::tcp_proxy::TcpProxy;
|
||||
use crate::gateway::udp_proxy::UdpProxy;
|
||||
use crate::peer_center::instance::PeerCenterInstance;
|
||||
use crate::peers::peer_conn::PeerConnId;
|
||||
use crate::peers::peer_manager::{PeerManager, RouteAlgoType};
|
||||
use crate::peers::rpc_service::PeerManagerRpcService;
|
||||
use crate::peers::zc_peer_conn::PeerConnId;
|
||||
use crate::peers::PacketRecvChanReceiver;
|
||||
use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc;
|
||||
use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
|
||||
|
||||
@@ -9,6 +9,5 @@ pub mod peer_center;
|
||||
pub mod peers;
|
||||
pub mod rpc;
|
||||
pub mod tunnel;
|
||||
pub mod tunnels;
|
||||
pub mod utils;
|
||||
pub mod vpn_portal;
|
||||
|
||||
@@ -137,7 +137,7 @@ impl Encryptor for AesGcmCipher {
|
||||
mod tests {
|
||||
use crate::{
|
||||
peers::encrypt::{ring_aes_gcm::AesGcmCipher, Encryptor},
|
||||
tunnel::packet_def::{ZCPacket, ZCPacketType, AES_GCM_ENCRYPTION_RESERVED},
|
||||
tunnel::packet_def::{ZCPacket, AES_GCM_ENCRYPTION_RESERVED},
|
||||
};
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -17,9 +17,9 @@ use crate::{
|
||||
|
||||
use super::{
|
||||
foreign_network_manager::{ForeignNetworkServiceClient, FOREIGN_NETWORK_SERVICE_ID},
|
||||
peer_conn::PeerConn,
|
||||
peer_map::PeerMap,
|
||||
peer_rpc::PeerRpcManager,
|
||||
zc_peer_conn::PeerConn,
|
||||
PacketRecvChan,
|
||||
};
|
||||
|
||||
|
||||
@@ -26,9 +26,9 @@ use crate::{
|
||||
};
|
||||
|
||||
use super::{
|
||||
peer_conn::PeerConn,
|
||||
peer_map::PeerMap,
|
||||
peer_rpc::{PeerRpcManager, PeerRpcManagerTransport},
|
||||
zc_peer_conn::PeerConn,
|
||||
PacketRecvChan, PacketRecvChanReceiver,
|
||||
};
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
pub mod packet;
|
||||
pub mod peer;
|
||||
// pub mod peer_conn;
|
||||
pub mod peer_conn;
|
||||
pub mod peer_conn_ping;
|
||||
pub mod peer_manager;
|
||||
pub mod peer_map;
|
||||
@@ -9,7 +10,6 @@ pub mod peer_rip_route;
|
||||
pub mod peer_rpc;
|
||||
pub mod route_trait;
|
||||
pub mod rpc_service;
|
||||
pub mod zc_peer_conn;
|
||||
|
||||
pub mod foreign_network_client;
|
||||
pub mod foreign_network_manager;
|
||||
|
||||
@@ -8,7 +8,7 @@ use tokio::{select, sync::mpsc, task::JoinHandle};
|
||||
use tracing::Instrument;
|
||||
|
||||
use super::{
|
||||
zc_peer_conn::{PeerConn, PeerConnId},
|
||||
peer_conn::{PeerConn, PeerConnId},
|
||||
PacketRecvChan,
|
||||
};
|
||||
use crate::rpc::PeerConnInfo;
|
||||
@@ -175,7 +175,7 @@ mod tests {
|
||||
|
||||
use crate::{
|
||||
common::{global_ctx::tests::get_mock_global_ctx, new_peer_id},
|
||||
peers::zc_peer_conn::PeerConn,
|
||||
peers::peer_conn::PeerConn,
|
||||
tunnel::ring::create_ring_tunnel_pair,
|
||||
};
|
||||
|
||||
|
||||
+228
-418
@@ -1,4 +1,5 @@
|
||||
use std::{
|
||||
any::Any,
|
||||
fmt::Debug,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
@@ -7,8 +8,9 @@ use std::{
|
||||
},
|
||||
};
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use pnet::datalink::NetworkInterface;
|
||||
use futures::{SinkExt, StreamExt, TryFutureExt};
|
||||
|
||||
use prost::Message;
|
||||
|
||||
use tokio::{
|
||||
sync::{broadcast, mpsc, Mutex},
|
||||
@@ -16,293 +18,34 @@ use tokio::{
|
||||
time::{timeout, Duration},
|
||||
};
|
||||
|
||||
use tokio_util::{bytes::Bytes, sync::PollSender};
|
||||
use tokio_util::sync::PollSender;
|
||||
use tracing::Instrument;
|
||||
use zerocopy::AsBytes;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
||||
config::{NetworkIdentity, NetworkSecretDigest},
|
||||
error::Error,
|
||||
global_ctx::ArcGlobalCtx,
|
||||
PeerId,
|
||||
},
|
||||
define_tunnel_filter_chain,
|
||||
peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType},
|
||||
rpc::{PeerConnInfo, PeerConnStats},
|
||||
tunnels::{
|
||||
peers::packet::PacketType,
|
||||
rpc::{HandshakeRequest, PeerConnInfo, PeerConnStats, TunnelInfo},
|
||||
tunnel::{
|
||||
filter::{StatsRecorderTunnelFilter, TunnelFilter, TunnelWithFilter},
|
||||
mpsc::{MpscTunnel, MpscTunnelSender},
|
||||
packet_def::ZCPacket,
|
||||
stats::{Throughput, WindowLatency},
|
||||
tunnel_filter::StatsRecorderTunnelFilter,
|
||||
DatagramSink, Tunnel, TunnelError,
|
||||
Tunnel, TunnelError, ZCPacketStream,
|
||||
},
|
||||
};
|
||||
|
||||
use super::packet::{self, HandShake, Packet};
|
||||
|
||||
pub type PacketRecvChan = mpsc::Sender<Bytes>;
|
||||
use super::{peer_conn_ping::PeerConnPinger, PacketRecvChan};
|
||||
|
||||
pub type PeerConnId = uuid::Uuid;
|
||||
|
||||
macro_rules! wait_response {
|
||||
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
|
||||
let Ok(rsp_vec) = timeout(Duration::from_secs(1), $stream.next()).await else {
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"wait handshake response timeout".to_owned(),
|
||||
));
|
||||
};
|
||||
let Some(rsp_vec) = rsp_vec else {
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"wait handshake response get none".to_owned(),
|
||||
));
|
||||
};
|
||||
let Ok(rsp_vec) = rsp_vec else {
|
||||
return Err(TunnelError::WaitRespError(format!(
|
||||
"wait handshake response get error {}",
|
||||
rsp_vec.err().unwrap()
|
||||
)));
|
||||
};
|
||||
|
||||
let $out_var;
|
||||
let rsp_bytes = Packet::decode(&rsp_vec);
|
||||
if rsp_bytes.packet_type != PacketType::HandShake {
|
||||
tracing::error!("unexpected packet type: {:?}", rsp_bytes);
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"unexpected packet type".to_owned(),
|
||||
));
|
||||
}
|
||||
let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes);
|
||||
match &resp_payload {
|
||||
$pattern => $out_var = $value,
|
||||
_ => {
|
||||
tracing::error!(
|
||||
"unexpected packet: {:?}, pattern: {:?}",
|
||||
rsp_bytes,
|
||||
stringify!($pattern)
|
||||
);
|
||||
return Err(TunnelError::WaitRespError("unexpected packet".to_owned()));
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PeerInfo {
|
||||
magic: u32,
|
||||
pub my_peer_id: PeerId,
|
||||
version: u32,
|
||||
pub features: Vec<String>,
|
||||
pub interfaces: Vec<NetworkInterface>,
|
||||
pub network_identity: NetworkIdentity,
|
||||
}
|
||||
|
||||
impl<'a> From<&HandShake> for PeerInfo {
|
||||
fn from(hs: &HandShake) -> Self {
|
||||
PeerInfo {
|
||||
magic: hs.magic.into(),
|
||||
my_peer_id: hs.my_peer_id.into(),
|
||||
version: hs.version.into(),
|
||||
features: hs.features.iter().map(|x| x.to_string()).collect(),
|
||||
interfaces: Vec::new(),
|
||||
network_identity: hs.network_identity.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PeerConnPinger {
|
||||
my_peer_id: PeerId,
|
||||
peer_id: PeerId,
|
||||
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
|
||||
ctrl_sender: broadcast::Sender<Bytes>,
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
tasks: JoinSet<Result<(), TunnelError>>,
|
||||
}
|
||||
|
||||
impl Debug for PeerConnPinger {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerConnPinger")
|
||||
.field("my_peer_id", &self.my_peer_id)
|
||||
.field("peer_id", &self.peer_id)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerConnPinger {
|
||||
pub fn new(
|
||||
my_peer_id: PeerId,
|
||||
peer_id: PeerId,
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
ctrl_sender: broadcast::Sender<Bytes>,
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
) -> Self {
|
||||
Self {
|
||||
my_peer_id,
|
||||
peer_id,
|
||||
sink: Arc::new(Mutex::new(sink)),
|
||||
tasks: JoinSet::new(),
|
||||
latency_stats,
|
||||
ctrl_sender,
|
||||
loss_rate_stats,
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_pingpong_once(
|
||||
my_node_id: PeerId,
|
||||
peer_id: PeerId,
|
||||
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
|
||||
receiver: &mut broadcast::Receiver<Bytes>,
|
||||
seq: u32,
|
||||
) -> Result<u128, TunnelError> {
|
||||
// should add seq here. so latency can be calculated more accurately
|
||||
let req = packet::Packet::new_ping_packet(my_node_id, peer_id, seq).into();
|
||||
tracing::trace!("send ping packet: {:?}", req);
|
||||
sink.lock().await.send(req).await.map_err(|e| {
|
||||
tracing::warn!("send ping packet error: {:?}", e);
|
||||
TunnelError::CommonError("send ping packet error".to_owned())
|
||||
})?;
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
|
||||
// wait until we get a pong packet in ctrl_resp_receiver
|
||||
let resp = timeout(Duration::from_secs(1), async {
|
||||
loop {
|
||||
match receiver.recv().await {
|
||||
Ok(p) => {
|
||||
let ctrl_payload =
|
||||
packet::CtrlPacketPayload::from_packet(Packet::decode(&p));
|
||||
if let packet::CtrlPacketPayload::Pong(resp_seq) = ctrl_payload {
|
||||
if resp_seq == seq {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("recv pong resp error: {:?}", e);
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"recv pong resp error".to_owned(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.await;
|
||||
|
||||
tracing::trace!(?resp, "wait ping response done");
|
||||
|
||||
if resp.is_err() {
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"wait ping response timeout".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
if resp.as_ref().unwrap().is_err() {
|
||||
return Err(resp.unwrap().err().unwrap());
|
||||
}
|
||||
|
||||
Ok(now.elapsed().as_micros())
|
||||
}
|
||||
|
||||
async fn pingpong(&mut self) {
|
||||
let sink = self.sink.clone();
|
||||
let my_node_id = self.my_peer_id;
|
||||
let peer_id = self.peer_id;
|
||||
let latency_stats = self.latency_stats.clone();
|
||||
|
||||
let (ping_res_sender, mut ping_res_receiver) = tokio::sync::mpsc::channel(100);
|
||||
|
||||
let stopped = Arc::new(AtomicU32::new(0));
|
||||
|
||||
// generate a pingpong task every 200ms
|
||||
let mut pingpong_tasks = JoinSet::new();
|
||||
let ctrl_resp_sender = self.ctrl_sender.clone();
|
||||
let stopped_clone = stopped.clone();
|
||||
self.tasks.spawn(async move {
|
||||
let mut req_seq = 0;
|
||||
loop {
|
||||
let receiver = ctrl_resp_sender.subscribe();
|
||||
let ping_res_sender = ping_res_sender.clone();
|
||||
let sink = sink.clone();
|
||||
|
||||
if stopped_clone.load(Ordering::Relaxed) != 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
while pingpong_tasks.len() > 5 {
|
||||
pingpong_tasks.join_next().await;
|
||||
}
|
||||
|
||||
pingpong_tasks.spawn(async move {
|
||||
let mut receiver = receiver.resubscribe();
|
||||
let pingpong_once_ret = Self::do_pingpong_once(
|
||||
my_node_id,
|
||||
peer_id,
|
||||
sink.clone(),
|
||||
&mut receiver,
|
||||
req_seq,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = ping_res_sender.send(pingpong_once_ret).await {
|
||||
tracing::info!(?e, "pingpong task send result error, exit..");
|
||||
};
|
||||
});
|
||||
|
||||
req_seq = req_seq.wrapping_add(1);
|
||||
tokio::time::sleep(Duration::from_millis(1000)).await;
|
||||
}
|
||||
});
|
||||
|
||||
// one with 1% precision
|
||||
let loss_rate_stats_1 = WindowLatency::new(100);
|
||||
// one with 20% precision, so we can fast fail this conn.
|
||||
let loss_rate_stats_20 = WindowLatency::new(5);
|
||||
|
||||
let mut counter: u64 = 0;
|
||||
|
||||
while let Some(ret) = ping_res_receiver.recv().await {
|
||||
counter += 1;
|
||||
|
||||
if let Ok(lat) = ret {
|
||||
latency_stats.record_latency(lat as u32);
|
||||
|
||||
loss_rate_stats_1.record_latency(0);
|
||||
loss_rate_stats_20.record_latency(0);
|
||||
} else {
|
||||
loss_rate_stats_1.record_latency(1);
|
||||
loss_rate_stats_20.record_latency(1);
|
||||
}
|
||||
|
||||
let loss_rate_20: f64 = loss_rate_stats_20.get_latency_us();
|
||||
let loss_rate_1: f64 = loss_rate_stats_1.get_latency_us();
|
||||
|
||||
tracing::trace!(
|
||||
?ret,
|
||||
?self,
|
||||
?loss_rate_1,
|
||||
?loss_rate_20,
|
||||
"pingpong task recv pingpong_once result"
|
||||
);
|
||||
|
||||
if (counter > 5 && loss_rate_20 > 0.74) || (counter > 150 && loss_rate_1 > 0.20) {
|
||||
tracing::warn!(
|
||||
?ret,
|
||||
?self,
|
||||
?loss_rate_1,
|
||||
?loss_rate_20,
|
||||
"pingpong loss rate too high, closing"
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
self.loss_rate_stats
|
||||
.store((loss_rate_1 * 100.0) as u32, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
stopped.store(1, Ordering::Relaxed);
|
||||
ping_res_receiver.close();
|
||||
}
|
||||
}
|
||||
|
||||
define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter);
|
||||
const MAGIC: u32 = 0xd1e1a5e1;
|
||||
const VERSION: u32 = 1;
|
||||
|
||||
pub struct PeerConn {
|
||||
conn_id: PeerConnId,
|
||||
@@ -310,33 +53,45 @@ pub struct PeerConn {
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
tunnel: Arc<Mutex<Box<dyn Any + Send + 'static>>>,
|
||||
sink: MpscTunnelSender,
|
||||
recv: Arc<Mutex<Option<Pin<Box<dyn ZCPacketStream>>>>>,
|
||||
tunnel_info: Option<TunnelInfo>,
|
||||
|
||||
tasks: JoinSet<Result<(), TunnelError>>,
|
||||
|
||||
info: Option<PeerInfo>,
|
||||
info: Option<HandshakeRequest>,
|
||||
|
||||
close_event_sender: Option<mpsc::Sender<PeerConnId>>,
|
||||
|
||||
ctrl_resp_sender: broadcast::Sender<Bytes>,
|
||||
ctrl_resp_sender: broadcast::Sender<ZCPacket>,
|
||||
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
throughput: Arc<Throughput>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
enum PeerConnPacketType {
|
||||
Data(Bytes),
|
||||
CtrlReq(Bytes),
|
||||
CtrlResp(Bytes),
|
||||
impl Debug for PeerConn {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerConn")
|
||||
.field("conn_id", &self.conn_id)
|
||||
.field("my_peer_id", &self.my_peer_id)
|
||||
.field("info", &self.info)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerConn {
|
||||
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
|
||||
let tunnel_info = tunnel.info();
|
||||
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
|
||||
let peer_conn_tunnel = PeerConnTunnel::new();
|
||||
let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel);
|
||||
|
||||
let peer_conn_tunnel_filter = StatsRecorderTunnelFilter::new();
|
||||
let throughput = peer_conn_tunnel_filter.filter_output();
|
||||
let peer_conn_tunnel = TunnelWithFilter::new(tunnel, peer_conn_tunnel_filter);
|
||||
let mut mpsc_tunnel = MpscTunnel::new(peer_conn_tunnel);
|
||||
|
||||
let (recv, sink) = (mpsc_tunnel.get_stream(), mpsc_tunnel.get_sink());
|
||||
|
||||
PeerConn {
|
||||
conn_id: PeerConnId::new_v4(),
|
||||
@@ -344,8 +99,10 @@ impl PeerConn {
|
||||
my_peer_id,
|
||||
global_ctx,
|
||||
|
||||
sink: tunnel.pin_sink(),
|
||||
tunnel: Box::new(tunnel),
|
||||
tunnel: Arc::new(Mutex::new(Box::new(mpsc_tunnel))),
|
||||
sink,
|
||||
recv: Arc::new(Mutex::new(Some(recv))),
|
||||
tunnel_info,
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
|
||||
@@ -355,7 +112,7 @@ impl PeerConn {
|
||||
ctrl_resp_sender: ctrl_sender,
|
||||
|
||||
latency_stats: Arc::new(WindowLatency::new(15)),
|
||||
throughput: peer_conn_tunnel.stats.get_throughput().clone(),
|
||||
throughput,
|
||||
loss_rate_stats: Arc::new(AtomicU32::new(0)),
|
||||
}
|
||||
}
|
||||
@@ -364,41 +121,97 @@ impl PeerConn {
|
||||
self.conn_id
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
async fn wait_handshake(&mut self, need_retry: &mut bool) -> Result<HandshakeRequest, Error> {
|
||||
*need_retry = false;
|
||||
|
||||
tracing::info!("waiting for handshake request from client");
|
||||
wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x);
|
||||
self.info = Some(PeerInfo::from(hs_req));
|
||||
tracing::info!("handshake request: {:?}", hs_req);
|
||||
let mut locked = self.recv.lock().await;
|
||||
let recv = locked.as_mut().unwrap();
|
||||
let Some(rsp) = recv.next().await else {
|
||||
return Err(Error::WaitRespError(
|
||||
"conn closed during wait handshake response".to_owned(),
|
||||
));
|
||||
};
|
||||
|
||||
let hs_req = self
|
||||
.global_ctx
|
||||
.net_ns
|
||||
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
|
||||
sink.send(hs_req.into()).await?;
|
||||
*need_retry = true;
|
||||
|
||||
let rsp = rsp?;
|
||||
let rsp = HandshakeRequest::decode(rsp.payload()).map_err(|e| {
|
||||
Error::WaitRespError(format!("decode handshake response error: {:?}", e))
|
||||
})?;
|
||||
|
||||
if rsp.network_secret_digrest.len() != std::mem::size_of::<NetworkSecretDigest>() {
|
||||
return Err(Error::WaitRespError(
|
||||
"invalid network secret digest".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
return Ok(rsp);
|
||||
}
|
||||
|
||||
async fn wait_handshake_loop(&mut self) -> Result<HandshakeRequest, Error> {
|
||||
timeout(Duration::from_secs(5), async move {
|
||||
loop {
|
||||
let mut need_retry = true;
|
||||
match self.wait_handshake(&mut need_retry).await {
|
||||
Ok(rsp) => return Ok(rsp),
|
||||
Err(e) => {
|
||||
log::warn!("wait handshake error: {:?}", e);
|
||||
if !need_retry {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.map_err(|e| Error::WaitRespError(format!("wait handshake timeout: {:?}", e)))
|
||||
.await?
|
||||
}
|
||||
|
||||
async fn send_handshake(&mut self) -> Result<(), Error> {
|
||||
let network = self.global_ctx.get_network_identity();
|
||||
let mut req = HandshakeRequest {
|
||||
magic: MAGIC,
|
||||
my_peer_id: self.my_peer_id,
|
||||
version: VERSION,
|
||||
features: Vec::new(),
|
||||
network_name: network.network_name.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
req.network_secret_digrest
|
||||
.extend_from_slice(&network.network_secret_digest.unwrap_or_default());
|
||||
|
||||
let hs_req = req.encode_to_vec();
|
||||
let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes());
|
||||
zc_packet.fill_peer_manager_hdr(
|
||||
self.my_peer_id,
|
||||
PeerId::default(),
|
||||
PacketType::HandShake as u8,
|
||||
);
|
||||
|
||||
self.sink.send(zc_packet).await.map_err(|e| {
|
||||
tracing::warn!("send handshake request error: {:?}", e);
|
||||
Error::WaitRespError("send handshake request error".to_owned())
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
|
||||
let hs_req = self
|
||||
.global_ctx
|
||||
.net_ns
|
||||
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
|
||||
sink.send(hs_req.into()).await?;
|
||||
pub async fn do_handshake_as_server(&mut self) -> Result<(), Error> {
|
||||
let rsp = self.wait_handshake_loop().await?;
|
||||
tracing::info!("handshake request: {:?}", rsp);
|
||||
self.info = Some(rsp);
|
||||
self.send_handshake().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_client(&mut self) -> Result<(), Error> {
|
||||
self.send_handshake().await?;
|
||||
tracing::info!("waiting for handshake request from server");
|
||||
wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x);
|
||||
self.info = Some(PeerInfo::from(hs_rsp));
|
||||
tracing::info!("handshake response: {:?}", hs_rsp);
|
||||
|
||||
let rsp = self.wait_handshake_loop().await?;
|
||||
tracing::info!("handshake response: {:?}", rsp);
|
||||
self.info = Some(rsp);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -406,11 +219,72 @@ impl PeerConn {
|
||||
self.info.is_some()
|
||||
}
|
||||
|
||||
pub async fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
|
||||
let mut stream = self.recv.lock().await.take().unwrap();
|
||||
let sink = self.sink.clone();
|
||||
let mut sender = PollSender::new(packet_recv_chan.clone());
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
let ctrl_sender = self.ctrl_resp_sender.clone();
|
||||
let _conn_info = self.get_conn_info();
|
||||
let conn_info_for_instrument = self.get_conn_info();
|
||||
|
||||
self.tasks.spawn(
|
||||
async move {
|
||||
tracing::info!("start recving peer conn packet");
|
||||
let mut task_ret = Ok(());
|
||||
while let Some(ret) = stream.next().await {
|
||||
if ret.is_err() {
|
||||
tracing::error!(error = ?ret, "peer conn recv error");
|
||||
task_ret = Err(ret.err().unwrap());
|
||||
break;
|
||||
}
|
||||
|
||||
let mut zc_packet = ret.unwrap();
|
||||
let Some(peer_mgr_hdr) = zc_packet.mut_peer_manager_header() else {
|
||||
tracing::error!(
|
||||
"unexpected packet: {:?}, cannot decode peer manager hdr",
|
||||
zc_packet
|
||||
);
|
||||
continue;
|
||||
};
|
||||
|
||||
if peer_mgr_hdr.packet_type == PacketType::Ping as u8 {
|
||||
peer_mgr_hdr.packet_type = PacketType::Pong as u8;
|
||||
if let Err(e) = sink.send(zc_packet).await {
|
||||
tracing::error!(?e, "peer conn send req error");
|
||||
}
|
||||
} else if peer_mgr_hdr.packet_type == PacketType::Pong as u8 {
|
||||
if let Err(e) = ctrl_sender.send(zc_packet) {
|
||||
tracing::error!(?e, "peer conn send ctrl resp error");
|
||||
}
|
||||
} else {
|
||||
if sender.send(zc_packet).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("end recving peer conn packet");
|
||||
|
||||
drop(sink);
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
tracing::error!(error = ?e, "peer conn close event send error");
|
||||
}
|
||||
|
||||
task_ret
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn start_pingpong(&mut self) {
|
||||
let mut pingpong = PeerConnPinger::new(
|
||||
self.my_peer_id,
|
||||
self.get_peer_id(),
|
||||
self.tunnel.pin_sink(),
|
||||
self.sink.clone(),
|
||||
self.ctrl_resp_sender.clone(),
|
||||
self.latency_stats.clone(),
|
||||
self.loss_rate_stats.clone(),
|
||||
@@ -432,79 +306,8 @@ impl PeerConn {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
let mut sender = PollSender::new(packet_recv_chan.clone());
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
let ctrl_sender = self.ctrl_resp_sender.clone();
|
||||
let conn_info = self.get_conn_info();
|
||||
let conn_info_for_instrument = self.get_conn_info();
|
||||
|
||||
self.tasks.spawn(
|
||||
async move {
|
||||
tracing::info!("start recving peer conn packet");
|
||||
let mut task_ret = Ok(());
|
||||
while let Some(ret) = stream.next().await {
|
||||
if ret.is_err() {
|
||||
tracing::error!(error = ?ret, "peer conn recv error");
|
||||
task_ret = Err(ret.err().unwrap());
|
||||
break;
|
||||
}
|
||||
|
||||
let buf = ret.unwrap();
|
||||
let p = Packet::decode(&buf);
|
||||
match p.packet_type {
|
||||
ArchivedPacketType::Ping => {
|
||||
let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p)
|
||||
else {
|
||||
log::error!("unexpected packet: {:?}", p);
|
||||
continue;
|
||||
};
|
||||
|
||||
let pong = packet::Packet::new_pong_packet(
|
||||
conn_info.my_peer_id,
|
||||
conn_info.peer_id,
|
||||
seq.into(),
|
||||
);
|
||||
|
||||
if let Err(e) = sink.send(pong.into()).await {
|
||||
tracing::error!(?e, "peer conn send req error");
|
||||
}
|
||||
}
|
||||
ArchivedPacketType::Pong => {
|
||||
if let Err(e) = ctrl_sender.send(buf.into()) {
|
||||
tracing::error!(?e, "peer conn send ctrl resp error");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if sender.send(buf.into()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("end recving peer conn packet");
|
||||
|
||||
if let Err(close_ret) = sink.close().await {
|
||||
tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it");
|
||||
}
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
tracing::error!(error = ?e, "peer conn close event send error");
|
||||
}
|
||||
|
||||
task_ret
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), TunnelError> {
|
||||
self.sink.send(msg).await
|
||||
pub async fn send_msg(&self, msg: ZCPacket) -> Result<(), Error> {
|
||||
Ok(self.sink.send(msg).await?)
|
||||
}
|
||||
|
||||
pub fn get_peer_id(&self) -> PeerId {
|
||||
@@ -512,7 +315,17 @@ impl PeerConn {
|
||||
}
|
||||
|
||||
pub fn get_network_identity(&self) -> NetworkIdentity {
|
||||
self.info.as_ref().unwrap().network_identity.clone()
|
||||
let info = self.info.as_ref().unwrap();
|
||||
let mut ret = NetworkIdentity {
|
||||
network_name: info.network_name.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
ret.network_secret_digest = Some([0u8; 32]);
|
||||
ret.network_secret_digest
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.copy_from_slice(&info.network_secret_digrest);
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {
|
||||
@@ -537,34 +350,13 @@ impl PeerConn {
|
||||
my_peer_id: self.my_peer_id,
|
||||
peer_id: self.get_peer_id(),
|
||||
features: self.info.as_ref().unwrap().features.clone(),
|
||||
tunnel: self.tunnel.info(),
|
||||
tunnel: self.tunnel_info.clone(),
|
||||
stats: Some(self.get_stats()),
|
||||
loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PeerConn {
|
||||
fn drop(&mut self) {
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
tokio::spawn(async move {
|
||||
let ret = sink.close().await;
|
||||
tracing::info!(error = ?ret, "peer conn tunnel closed.");
|
||||
});
|
||||
log::info!("peer conn {:?} drop", self.conn_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for PeerConn {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerConn")
|
||||
.field("conn_id", &self.conn_id)
|
||||
.field("my_peer_id", &self.my_peer_id)
|
||||
.field("info", &self.info)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
@@ -572,12 +364,12 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::common::global_ctx::tests::get_mock_global_ctx;
|
||||
use crate::common::new_peer_id;
|
||||
use crate::tunnels::tunnel_filter::tests::DropSendTunnelFilter;
|
||||
use crate::tunnels::tunnel_filter::{PacketRecorderTunnelFilter, TunnelWithFilter};
|
||||
use crate::tunnel::filter::tests::DropSendTunnelFilter;
|
||||
use crate::tunnel::filter::PacketRecorderTunnelFilter;
|
||||
use crate::tunnel::ring::create_ring_tunnel_pair;
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_conn_handshake() {
|
||||
use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
|
||||
let c_recorder = Arc::new(PacketRecorderTunnelFilter::new());
|
||||
@@ -614,7 +406,6 @@ mod tests {
|
||||
}
|
||||
|
||||
async fn peer_conn_pingpong_test_common(drop_start: u32, drop_end: u32, conn_closed: bool) {
|
||||
use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
|
||||
// drop 1-3 packets should not affect pingpong
|
||||
@@ -633,7 +424,9 @@ mod tests {
|
||||
);
|
||||
|
||||
s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0);
|
||||
s_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
|
||||
s_peer
|
||||
.start_recv_loop(tokio::sync::mpsc::channel(200).0)
|
||||
.await;
|
||||
|
||||
assert!(c_ret.is_ok());
|
||||
assert!(s_ret.is_ok());
|
||||
@@ -641,7 +434,9 @@ mod tests {
|
||||
let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1);
|
||||
c_peer.set_close_event_sender(close_send);
|
||||
c_peer.start_pingpong();
|
||||
c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
|
||||
c_peer
|
||||
.start_recv_loop(tokio::sync::mpsc::channel(200).0)
|
||||
.await;
|
||||
|
||||
// wait 5s, conn should not be disconnected
|
||||
tokio::time::sleep(Duration::from_secs(15)).await;
|
||||
@@ -658,4 +453,19 @@ mod tests {
|
||||
peer_conn_pingpong_test_common(3, 5, false).await;
|
||||
peer_conn_pingpong_test_common(5, 12, true).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn close_tunnel_during_handshake() {
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
let mut c_peer = PeerConn::new(new_peer_id(), get_mock_global_ctx(), Box::new(c));
|
||||
let j = tokio::spawn(async move {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
drop(s);
|
||||
});
|
||||
timeout(Duration::from_millis(1500), c_peer.do_handshake_as_client())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap_err();
|
||||
let _ = tokio::join!(j);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,8 +22,8 @@ use tokio_util::bytes::Bytes;
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
|
||||
peers::{
|
||||
packet, peer_rpc::PeerRpcManagerTransport, route_trait::RouteInterface,
|
||||
zc_peer_conn::PeerConn, PeerPacketFilter,
|
||||
packet, peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport,
|
||||
route_trait::RouteInterface, PeerPacketFilter,
|
||||
},
|
||||
tunnel::{
|
||||
packet_def::{PacketType, ZCPacket},
|
||||
@@ -35,12 +35,12 @@ use super::{
|
||||
encrypt::{ring_aes_gcm::AesGcmCipher, Encryptor, NullCipher},
|
||||
foreign_network_client::ForeignNetworkClient,
|
||||
foreign_network_manager::ForeignNetworkManager,
|
||||
peer_conn::PeerConnId,
|
||||
peer_map::PeerMap,
|
||||
peer_ospf_route::PeerRoute,
|
||||
peer_rip_route::BasicRoute,
|
||||
peer_rpc::PeerRpcManager,
|
||||
route_trait::{ArcRoute, Route},
|
||||
zc_peer_conn::PeerConnId,
|
||||
BoxNicPacketFilter, BoxPeerPacketFilter, PacketRecvChanReceiver,
|
||||
};
|
||||
|
||||
|
||||
@@ -12,13 +12,13 @@ use crate::{
|
||||
},
|
||||
rpc::PeerConnInfo,
|
||||
tunnel::packet_def::ZCPacket,
|
||||
tunnels::TunnelError,
|
||||
tunnel::TunnelError,
|
||||
};
|
||||
|
||||
use super::{
|
||||
peer::Peer,
|
||||
peer_conn::{PeerConn, PeerConnId},
|
||||
route_trait::ArcRoute,
|
||||
zc_peer_conn::{PeerConn, PeerConnId},
|
||||
PacketRecvChan,
|
||||
};
|
||||
|
||||
|
||||
@@ -1,769 +0,0 @@
|
||||
use std::{
|
||||
any::Any,
|
||||
fmt::Debug,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use futures::{SinkExt, StreamExt, TryFutureExt};
|
||||
|
||||
use prost::Message;
|
||||
|
||||
use tokio::{
|
||||
sync::{broadcast, mpsc, Mutex},
|
||||
task::JoinSet,
|
||||
time::{timeout, Duration},
|
||||
};
|
||||
|
||||
use tokio_util::sync::PollSender;
|
||||
use tracing::Instrument;
|
||||
use zerocopy::AsBytes;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
config::{NetworkIdentity, NetworkSecretDigest},
|
||||
error::Error,
|
||||
global_ctx::ArcGlobalCtx,
|
||||
PeerId,
|
||||
},
|
||||
peers::packet::PacketType,
|
||||
rpc::{HandshakeRequest, PeerConnInfo, PeerConnStats, TunnelInfo},
|
||||
tunnel::{
|
||||
filter::{StatsRecorderTunnelFilter, TunnelFilter, TunnelWithFilter},
|
||||
mpsc::{MpscTunnel, MpscTunnelSender},
|
||||
packet_def::ZCPacket,
|
||||
stats::{Throughput, WindowLatency},
|
||||
Tunnel, TunnelError, ZCPacketStream,
|
||||
},
|
||||
};
|
||||
|
||||
use super::{peer_conn_ping::PeerConnPinger, PacketRecvChan};
|
||||
|
||||
pub type PeerConnId = uuid::Uuid;
|
||||
|
||||
const MAGIC: u32 = 0xd1e1a5e1;
|
||||
const VERSION: u32 = 1;
|
||||
|
||||
pub struct PeerConn {
|
||||
conn_id: PeerConnId,
|
||||
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
|
||||
tunnel: Arc<Mutex<Box<dyn Any + Send + 'static>>>,
|
||||
sink: MpscTunnelSender,
|
||||
recv: Arc<Mutex<Option<Pin<Box<dyn ZCPacketStream>>>>>,
|
||||
tunnel_info: Option<TunnelInfo>,
|
||||
|
||||
tasks: JoinSet<Result<(), TunnelError>>,
|
||||
|
||||
info: Option<HandshakeRequest>,
|
||||
|
||||
close_event_sender: Option<mpsc::Sender<PeerConnId>>,
|
||||
|
||||
ctrl_resp_sender: broadcast::Sender<ZCPacket>,
|
||||
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
throughput: Arc<Throughput>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
impl Debug for PeerConn {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerConn")
|
||||
.field("conn_id", &self.conn_id)
|
||||
.field("my_peer_id", &self.my_peer_id)
|
||||
.field("info", &self.info)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerConn {
|
||||
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
|
||||
let tunnel_info = tunnel.info();
|
||||
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
|
||||
|
||||
let peer_conn_tunnel_filter = StatsRecorderTunnelFilter::new();
|
||||
let throughput = peer_conn_tunnel_filter.filter_output();
|
||||
let peer_conn_tunnel = TunnelWithFilter::new(tunnel, peer_conn_tunnel_filter);
|
||||
let mut mpsc_tunnel = MpscTunnel::new(peer_conn_tunnel);
|
||||
|
||||
let (recv, sink) = (mpsc_tunnel.get_stream(), mpsc_tunnel.get_sink());
|
||||
|
||||
PeerConn {
|
||||
conn_id: PeerConnId::new_v4(),
|
||||
|
||||
my_peer_id,
|
||||
global_ctx,
|
||||
|
||||
tunnel: Arc::new(Mutex::new(Box::new(mpsc_tunnel))),
|
||||
sink,
|
||||
recv: Arc::new(Mutex::new(Some(recv))),
|
||||
tunnel_info,
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
|
||||
info: None,
|
||||
close_event_sender: None,
|
||||
|
||||
ctrl_resp_sender: ctrl_sender,
|
||||
|
||||
latency_stats: Arc::new(WindowLatency::new(15)),
|
||||
throughput,
|
||||
loss_rate_stats: Arc::new(AtomicU32::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_conn_id(&self) -> PeerConnId {
|
||||
self.conn_id
|
||||
}
|
||||
|
||||
async fn wait_handshake(&mut self) -> Result<HandshakeRequest, Error> {
|
||||
let mut locked = self.recv.lock().await;
|
||||
let recv = locked.as_mut().unwrap();
|
||||
let Some(rsp) = recv.next().await else {
|
||||
return Err(Error::WaitRespError(
|
||||
"conn closed during wait handshake response".to_owned(),
|
||||
));
|
||||
};
|
||||
let rsp = rsp?;
|
||||
let rsp = HandshakeRequest::decode(rsp.payload()).map_err(|e| {
|
||||
Error::WaitRespError(format!("decode handshake response error: {:?}", e))
|
||||
})?;
|
||||
|
||||
if rsp.network_secret_digrest.len() != std::mem::size_of::<NetworkSecretDigest>() {
|
||||
return Err(Error::WaitRespError(
|
||||
"invalid network secret digest".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
return Ok(rsp);
|
||||
}
|
||||
|
||||
async fn wait_handshake_loop(&mut self) -> Result<HandshakeRequest, Error> {
|
||||
Ok(timeout(Duration::from_secs(5), async move {
|
||||
loop {
|
||||
match self.wait_handshake().await {
|
||||
Ok(rsp) => return rsp,
|
||||
Err(e) => {
|
||||
log::warn!("wait handshake error: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.map_err(|e| Error::WaitRespError(format!("wait handshake timeout: {:?}", e)))
|
||||
.await?)
|
||||
}
|
||||
|
||||
async fn send_handshake(&mut self) -> Result<(), Error> {
|
||||
let network = self.global_ctx.get_network_identity();
|
||||
let mut req = HandshakeRequest {
|
||||
magic: MAGIC,
|
||||
my_peer_id: self.my_peer_id,
|
||||
version: VERSION,
|
||||
features: Vec::new(),
|
||||
network_name: network.network_name.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
req.network_secret_digrest
|
||||
.extend_from_slice(&network.network_secret_digest.unwrap_or_default());
|
||||
|
||||
let hs_req = req.encode_to_vec();
|
||||
let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes());
|
||||
zc_packet.fill_peer_manager_hdr(
|
||||
self.my_peer_id,
|
||||
PeerId::default(),
|
||||
PacketType::HandShake as u8,
|
||||
);
|
||||
|
||||
self.sink.send(zc_packet).await.map_err(|e| {
|
||||
tracing::warn!("send handshake request error: {:?}", e);
|
||||
Error::WaitRespError("send handshake request error".to_owned())
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_server(&mut self) -> Result<(), Error> {
|
||||
let rsp = self.wait_handshake_loop().await?;
|
||||
tracing::info!("handshake request: {:?}", rsp);
|
||||
self.info = Some(rsp);
|
||||
self.send_handshake().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_client(&mut self) -> Result<(), Error> {
|
||||
self.send_handshake().await?;
|
||||
tracing::info!("waiting for handshake request from server");
|
||||
let rsp = self.wait_handshake_loop().await?;
|
||||
tracing::info!("handshake response: {:?}", rsp);
|
||||
self.info = Some(rsp);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn handshake_done(&self) -> bool {
|
||||
self.info.is_some()
|
||||
}
|
||||
|
||||
pub async fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
|
||||
let mut stream = self.recv.lock().await.take().unwrap();
|
||||
let sink = self.sink.clone();
|
||||
let mut sender = PollSender::new(packet_recv_chan.clone());
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
let ctrl_sender = self.ctrl_resp_sender.clone();
|
||||
let _conn_info = self.get_conn_info();
|
||||
let conn_info_for_instrument = self.get_conn_info();
|
||||
|
||||
self.tasks.spawn(
|
||||
async move {
|
||||
tracing::info!("start recving peer conn packet");
|
||||
let mut task_ret = Ok(());
|
||||
while let Some(ret) = stream.next().await {
|
||||
if ret.is_err() {
|
||||
tracing::error!(error = ?ret, "peer conn recv error");
|
||||
task_ret = Err(ret.err().unwrap());
|
||||
break;
|
||||
}
|
||||
|
||||
let mut zc_packet = ret.unwrap();
|
||||
let Some(peer_mgr_hdr) = zc_packet.mut_peer_manager_header() else {
|
||||
tracing::error!(
|
||||
"unexpected packet: {:?}, cannot decode peer manager hdr",
|
||||
zc_packet
|
||||
);
|
||||
continue;
|
||||
};
|
||||
|
||||
if peer_mgr_hdr.packet_type == PacketType::Ping as u8 {
|
||||
peer_mgr_hdr.packet_type = PacketType::Pong as u8;
|
||||
if let Err(e) = sink.send(zc_packet).await {
|
||||
tracing::error!(?e, "peer conn send req error");
|
||||
}
|
||||
} else if peer_mgr_hdr.packet_type == PacketType::Pong as u8 {
|
||||
if let Err(e) = ctrl_sender.send(zc_packet) {
|
||||
tracing::error!(?e, "peer conn send ctrl resp error");
|
||||
}
|
||||
} else {
|
||||
if sender.send(zc_packet).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("end recving peer conn packet");
|
||||
|
||||
drop(sink);
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
tracing::error!(error = ?e, "peer conn close event send error");
|
||||
}
|
||||
|
||||
task_ret
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn start_pingpong(&mut self) {
|
||||
let mut pingpong = PeerConnPinger::new(
|
||||
self.my_peer_id,
|
||||
self.get_peer_id(),
|
||||
self.sink.clone(),
|
||||
self.ctrl_resp_sender.clone(),
|
||||
self.latency_stats.clone(),
|
||||
self.loss_rate_stats.clone(),
|
||||
);
|
||||
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
|
||||
self.tasks.spawn(async move {
|
||||
pingpong.pingpong().await;
|
||||
|
||||
tracing::warn!(?pingpong, "pingpong task exit");
|
||||
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
log::warn!("close event sender error: {:?}", e);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn send_msg(&self, msg: ZCPacket) -> Result<(), Error> {
|
||||
Ok(self.sink.send(msg).await?)
|
||||
}
|
||||
|
||||
pub fn get_peer_id(&self) -> PeerId {
|
||||
self.info.as_ref().unwrap().my_peer_id
|
||||
}
|
||||
|
||||
pub fn get_network_identity(&self) -> NetworkIdentity {
|
||||
let info = self.info.as_ref().unwrap();
|
||||
let mut ret = NetworkIdentity {
|
||||
network_name: info.network_name.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
ret.network_secret_digest = Some([0u8; 32]);
|
||||
ret.network_secret_digest
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.copy_from_slice(&info.network_secret_digrest);
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {
|
||||
self.close_event_sender = Some(sender);
|
||||
}
|
||||
|
||||
pub fn get_stats(&self) -> PeerConnStats {
|
||||
PeerConnStats {
|
||||
latency_us: self.latency_stats.get_latency_us(),
|
||||
|
||||
tx_bytes: self.throughput.tx_bytes(),
|
||||
rx_bytes: self.throughput.rx_bytes(),
|
||||
|
||||
tx_packets: self.throughput.tx_packets(),
|
||||
rx_packets: self.throughput.rx_packets(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_conn_info(&self) -> PeerConnInfo {
|
||||
PeerConnInfo {
|
||||
conn_id: self.conn_id.to_string(),
|
||||
my_peer_id: self.my_peer_id,
|
||||
peer_id: self.get_peer_id(),
|
||||
features: self.info.as_ref().unwrap().features.clone(),
|
||||
tunnel: self.tunnel_info.clone(),
|
||||
stats: Some(self.get_stats()),
|
||||
loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use crate::common::global_ctx::tests::get_mock_global_ctx;
|
||||
use crate::common::new_peer_id;
|
||||
use crate::tunnel::filter::tests::DropSendTunnelFilter;
|
||||
use crate::tunnel::filter::PacketRecorderTunnelFilter;
|
||||
use crate::tunnel::ring::create_ring_tunnel_pair;
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_conn_handshake() {
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
|
||||
let c_recorder = Arc::new(PacketRecorderTunnelFilter::new());
|
||||
let s_recorder = Arc::new(PacketRecorderTunnelFilter::new());
|
||||
|
||||
let c = TunnelWithFilter::new(c, c_recorder.clone());
|
||||
let s = TunnelWithFilter::new(s, s_recorder.clone());
|
||||
|
||||
let c_peer_id = new_peer_id();
|
||||
let s_peer_id = new_peer_id();
|
||||
|
||||
let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
|
||||
|
||||
let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));
|
||||
|
||||
let (c_ret, s_ret) = tokio::join!(
|
||||
c_peer.do_handshake_as_client(),
|
||||
s_peer.do_handshake_as_server()
|
||||
);
|
||||
|
||||
c_ret.unwrap();
|
||||
s_ret.unwrap();
|
||||
|
||||
assert_eq!(c_recorder.sent.lock().unwrap().len(), 1);
|
||||
assert_eq!(c_recorder.received.lock().unwrap().len(), 1);
|
||||
|
||||
assert_eq!(s_recorder.sent.lock().unwrap().len(), 1);
|
||||
assert_eq!(s_recorder.received.lock().unwrap().len(), 1);
|
||||
|
||||
assert_eq!(c_peer.get_peer_id(), s_peer_id);
|
||||
assert_eq!(s_peer.get_peer_id(), c_peer_id);
|
||||
assert_eq!(c_peer.get_network_identity(), s_peer.get_network_identity());
|
||||
assert_eq!(c_peer.get_network_identity(), NetworkIdentity::default());
|
||||
}
|
||||
|
||||
async fn peer_conn_pingpong_test_common(drop_start: u32, drop_end: u32, conn_closed: bool) {
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
|
||||
// drop 1-3 packets should not affect pingpong
|
||||
let c_recorder = Arc::new(DropSendTunnelFilter::new(drop_start, drop_end));
|
||||
let c = TunnelWithFilter::new(c, c_recorder.clone());
|
||||
|
||||
let c_peer_id = new_peer_id();
|
||||
let s_peer_id = new_peer_id();
|
||||
|
||||
let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
|
||||
let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));
|
||||
|
||||
let (c_ret, s_ret) = tokio::join!(
|
||||
c_peer.do_handshake_as_client(),
|
||||
s_peer.do_handshake_as_server()
|
||||
);
|
||||
|
||||
s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0);
|
||||
s_peer
|
||||
.start_recv_loop(tokio::sync::mpsc::channel(200).0)
|
||||
.await;
|
||||
|
||||
assert!(c_ret.is_ok());
|
||||
assert!(s_ret.is_ok());
|
||||
|
||||
let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1);
|
||||
c_peer.set_close_event_sender(close_send);
|
||||
c_peer.start_pingpong();
|
||||
c_peer
|
||||
.start_recv_loop(tokio::sync::mpsc::channel(200).0)
|
||||
.await;
|
||||
|
||||
// wait 5s, conn should not be disconnected
|
||||
tokio::time::sleep(Duration::from_secs(15)).await;
|
||||
|
||||
if conn_closed {
|
||||
assert!(close_recv.try_recv().is_ok());
|
||||
} else {
|
||||
assert!(close_recv.try_recv().is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_conn_pingpong_timeout() {
|
||||
peer_conn_pingpong_test_common(3, 5, false).await;
|
||||
peer_conn_pingpong_test_common(5, 12, true).await;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use pnet::datalink::NetworkInterface;
|
||||
|
||||
use tokio::{
|
||||
sync::{broadcast, mpsc, Mutex},
|
||||
task::JoinSet,
|
||||
time::{timeout, Duration},
|
||||
};
|
||||
|
||||
use tokio_util::{bytes::Bytes, sync::PollSender};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
||||
PeerId,
|
||||
},
|
||||
define_tunnel_filter_chain,
|
||||
peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType},
|
||||
rpc::{PeerConnInfo, PeerConnStats},
|
||||
tunnel::{mpsc::MpscTunnelSender, stats::WindowLatency, TunnelError},
|
||||
};
|
||||
|
||||
use super::packet::{self, HandShake, Packet};
|
||||
|
||||
pub type PacketRecvChan = mpsc::Sender<Bytes>;
|
||||
|
||||
macro_rules! wait_response {
|
||||
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
|
||||
let Ok(rsp_vec) = timeout(Duration::from_secs(1), $stream.next()).await else {
|
||||
return Err(Error::WaitRespError(
|
||||
"wait handshake response timeout".to_owned(),
|
||||
));
|
||||
};
|
||||
let Some(rsp_vec) = rsp_vec else {
|
||||
return Err(Error::WaitRespError(
|
||||
"wait handshake response get none".to_owned(),
|
||||
));
|
||||
};
|
||||
let Ok(rsp_vec) = rsp_vec else {
|
||||
return Err(Error::WaitRespError(format!(
|
||||
"wait handshake response get error {}",
|
||||
rsp_vec.err().unwrap()
|
||||
)));
|
||||
};
|
||||
|
||||
let $out_var;
|
||||
let rsp_bytes = Packet::decode(&rsp_vec);
|
||||
if rsp_bytes.packet_type != PacketType::HandShake {
|
||||
tracing::error!("unexpected packet type: {:?}", rsp_bytes);
|
||||
return Err(Error::WaitRespError("unexpected packet type".to_owned()));
|
||||
}
|
||||
let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes);
|
||||
match &resp_payload {
|
||||
$pattern => $out_var = $value,
|
||||
_ => {
|
||||
tracing::error!(
|
||||
"unexpected packet: {:?}, pattern: {:?}",
|
||||
rsp_bytes,
|
||||
stringify!($pattern)
|
||||
);
|
||||
return Err(Error::WaitRespError("unexpected packet".to_owned()));
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl<'a> From<&HandShake> for PeerInfo {
|
||||
fn from(hs: &HandShake) -> Self {
|
||||
PeerInfo {
|
||||
magic: hs.magic.into(),
|
||||
my_peer_id: hs.my_peer_id.into(),
|
||||
version: hs.version.into(),
|
||||
features: hs.features.iter().map(|x| x.to_string()).collect(),
|
||||
interfaces: Vec::new(),
|
||||
network_identity: hs.network_identity.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter);
|
||||
|
||||
pub struct PeerConn {
|
||||
conn_id: PeerConnId,
|
||||
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
|
||||
tasks: JoinSet<Result<(), TunnelError>>,
|
||||
|
||||
info: Option<PeerInfo>,
|
||||
|
||||
close_event_sender: Option<mpsc::Sender<PeerConnId>>,
|
||||
|
||||
ctrl_resp_sender: broadcast::Sender<Bytes>,
|
||||
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
throughput: Arc<Throughput>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
enum PeerConnPacketType {
|
||||
Data(Bytes),
|
||||
CtrlReq(Bytes),
|
||||
CtrlResp(Bytes),
|
||||
}
|
||||
|
||||
impl PeerConn {
|
||||
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
|
||||
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
|
||||
let peer_conn_tunnel = PeerConnTunnel::new();
|
||||
let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel);
|
||||
|
||||
PeerConn {
|
||||
conn_id: PeerConnId::new_v4(),
|
||||
|
||||
my_peer_id,
|
||||
global_ctx,
|
||||
|
||||
sink: tunnel.pin_sink(),
|
||||
tunnel: Box::new(tunnel),
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
|
||||
info: None,
|
||||
close_event_sender: None,
|
||||
|
||||
ctrl_resp_sender: ctrl_sender,
|
||||
|
||||
latency_stats: Arc::new(WindowLatency::new(15)),
|
||||
throughput: peer_conn_tunnel.stats.get_throughput().clone(),
|
||||
loss_rate_stats: Arc::new(AtomicU32::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_conn_id(&self) -> PeerConnId {
|
||||
self.conn_id
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
|
||||
tracing::info!("waiting for handshake request from client");
|
||||
wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x);
|
||||
self.info = Some(PeerInfo::from(hs_req));
|
||||
tracing::info!("handshake request: {:?}", hs_req);
|
||||
|
||||
let hs_req = self
|
||||
.global_ctx
|
||||
.net_ns
|
||||
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
|
||||
sink.send(hs_req.into()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
|
||||
let hs_req = self
|
||||
.global_ctx
|
||||
.net_ns
|
||||
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
|
||||
sink.send(hs_req.into()).await?;
|
||||
|
||||
tracing::info!("waiting for handshake request from server");
|
||||
wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x);
|
||||
self.info = Some(PeerInfo::from(hs_rsp));
|
||||
tracing::info!("handshake response: {:?}", hs_rsp);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn handshake_done(&self) -> bool {
|
||||
self.info.is_some()
|
||||
}
|
||||
|
||||
pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
let mut sender = PollSender::new(packet_recv_chan.clone());
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
let ctrl_sender = self.ctrl_resp_sender.clone();
|
||||
let conn_info = self.get_conn_info();
|
||||
let conn_info_for_instrument = self.get_conn_info();
|
||||
|
||||
self.tasks.spawn(
|
||||
async move {
|
||||
tracing::info!("start recving peer conn packet");
|
||||
let mut task_ret = Ok(());
|
||||
while let Some(ret) = stream.next().await {
|
||||
if ret.is_err() {
|
||||
tracing::error!(error = ?ret, "peer conn recv error");
|
||||
task_ret = Err(ret.err().unwrap());
|
||||
break;
|
||||
}
|
||||
|
||||
let buf = ret.unwrap();
|
||||
let p = Packet::decode(&buf);
|
||||
match p.packet_type {
|
||||
ArchivedPacketType::Ping => {
|
||||
let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p)
|
||||
else {
|
||||
log::error!("unexpected packet: {:?}", p);
|
||||
continue;
|
||||
};
|
||||
|
||||
let pong = packet::Packet::new_pong_packet(
|
||||
conn_info.my_peer_id,
|
||||
conn_info.peer_id,
|
||||
seq.into(),
|
||||
);
|
||||
|
||||
if let Err(e) = sink.send(pong.into()).await {
|
||||
tracing::error!(?e, "peer conn send req error");
|
||||
}
|
||||
}
|
||||
ArchivedPacketType::Pong => {
|
||||
if let Err(e) = ctrl_sender.send(buf.into()) {
|
||||
tracing::error!(?e, "peer conn send ctrl resp error");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if sender.send(buf.into()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("end recving peer conn packet");
|
||||
|
||||
if let Err(close_ret) = sink.close().await {
|
||||
tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it");
|
||||
}
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
tracing::error!(error = ?e, "peer conn close event send error");
|
||||
}
|
||||
|
||||
task_ret
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), Error> {
|
||||
self.sink.send(msg).await
|
||||
}
|
||||
|
||||
pub fn get_peer_id(&self) -> PeerId {
|
||||
self.info.as_ref().unwrap().my_peer_id
|
||||
}
|
||||
|
||||
pub fn get_network_identity(&self) -> NetworkIdentity {
|
||||
self.info.as_ref().unwrap().network_identity.clone()
|
||||
}
|
||||
|
||||
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {
|
||||
self.close_event_sender = Some(sender);
|
||||
}
|
||||
|
||||
pub fn get_stats(&self) -> PeerConnStats {
|
||||
PeerConnStats {
|
||||
latency_us: self.latency_stats.get_latency_us(),
|
||||
|
||||
tx_bytes: self.throughput.tx_bytes(),
|
||||
rx_bytes: self.throughput.rx_bytes(),
|
||||
|
||||
tx_packets: self.throughput.tx_packets(),
|
||||
rx_packets: self.throughput.rx_packets(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_conn_info(&self) -> PeerConnInfo {
|
||||
PeerConnInfo {
|
||||
conn_id: self.conn_id.to_string(),
|
||||
my_peer_id: self.my_peer_id,
|
||||
peer_id: self.get_peer_id(),
|
||||
features: self.info.as_ref().unwrap().features.clone(),
|
||||
tunnel: self.tunnel.info(),
|
||||
stats: Some(self.get_stats()),
|
||||
loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PeerConn {
|
||||
fn drop(&mut self) {
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
tokio::spawn(async move {
|
||||
let ret = sink.close().await;
|
||||
tracing::info!(error = ?ret, "peer conn tunnel closed.");
|
||||
});
|
||||
log::info!("peer conn {:?} drop", self.conn_id);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
*/
|
||||
@@ -626,7 +626,10 @@ impl UdpTunnelConnector {
|
||||
)
|
||||
.await??;
|
||||
|
||||
socket.connect(recv_addr).await?;
|
||||
if recv_addr != addr {
|
||||
tracing::debug!(?recv_addr, ?addr, "udp connect addr not match");
|
||||
}
|
||||
|
||||
self.build_tunnel(socket, addr, conn_id).await
|
||||
}
|
||||
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
use std::result::Result;
|
||||
use tokio::io;
|
||||
use tokio_util::{
|
||||
bytes::{BufMut, Bytes, BytesMut},
|
||||
codec::{Decoder, Encoder},
|
||||
};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
|
||||
pub struct BytesCodec {
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl BytesCodec {
|
||||
/// Creates a new `BytesCodec` for shipping around raw bytes.
|
||||
pub fn new(capacity: usize) -> BytesCodec {
|
||||
BytesCodec { capacity }
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for BytesCodec {
|
||||
type Item = BytesMut;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
|
||||
if !buf.is_empty() {
|
||||
let len = buf.len();
|
||||
let ret = Some(buf.split_to(len));
|
||||
buf.reserve(self.capacity);
|
||||
Ok(ret)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Bytes> for BytesCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
buf.reserve(data.len());
|
||||
buf.put(data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<BytesMut> for BytesCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
buf.reserve(data.len());
|
||||
buf.put(data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,483 +0,0 @@
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
|
||||
use async_stream::stream;
|
||||
use futures::{stream::FuturesUnordered, Future, FutureExt, Sink, SinkExt, Stream, StreamExt};
|
||||
use network_interface::NetworkInterfaceConfig;
|
||||
use tokio::{sync::Mutex, time::error::Elapsed};
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::tunnels::{SinkError, TunnelError};
|
||||
|
||||
use super::{DatagramSink, DatagramStream, SinkItem, StreamT, Tunnel, TunnelInfo};
|
||||
|
||||
pub struct FramedTunnel<R, W> {
|
||||
read: Arc<Mutex<R>>,
|
||||
write: Arc<Mutex<W>>,
|
||||
|
||||
info: Option<TunnelInfo>,
|
||||
}
|
||||
|
||||
impl<R, RE, W, WE> FramedTunnel<R, W>
where
    R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
    W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
    RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
    WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
{
    /// Wrap a framed read half and write half into a tunnel.
    pub fn new(read: R, write: W, info: Option<TunnelInfo>) -> Self {
        FramedTunnel {
            read: Arc::new(Mutex::new(read)),
            write: Arc::new(Mutex::new(write)),
            info,
        }
    }

    /// Convenience constructor returning a boxed `Tunnel` with metadata attached.
    pub fn new_tunnel_with_info(read: R, write: W, info: TunnelInfo) -> Box<dyn Tunnel> {
        Box::new(FramedTunnel::new(read, write, Some(info)))
    }

    /// Datagram stream over the read half.
    ///
    /// EOF and inner-stream errors are both surfaced as
    /// `TunnelError::CommonError` items. NOTE(review): the loop never breaks,
    /// even after yielding an error — the caller is expected to drop the
    /// stream once it sees `Err`.
    pub fn recv_stream(&self) -> impl DatagramStream {
        let read = self.read.clone();
        let info = self.info.clone();
        stream! {
            loop {
                // Lock held only for a single `next()` call.
                let read_ret = read.lock().await.next().await;
                if read_ret.is_none() {
                    // Inner stream finished: report EOF as an error item.
                    tracing::info!(?info, "read_ret is none");
                    yield Err(TunnelError::CommonError("recv stream closed".to_string()));
                } else {
                    let read_ret = read_ret.unwrap();
                    if read_ret.is_err() {
                        let err = read_ret.err().unwrap();
                        tracing::info!(?info, "recv stream read error");
                        yield Err(TunnelError::CommonError(err.to_string()));
                    } else {
                        yield Ok(read_ret.unwrap());
                    }
                }
            }
        }
    }

    /// Buffering datagram sink over the write half.
    ///
    /// Items are queued in `sending_buffers` and flushed as a batch; each
    /// underlying `send` is bounded by a 1s timeout (hence the
    /// `WE: From<Elapsed>` bound).
    pub fn send_sink(&self) -> impl DatagramSink {
        // Sink adapter holding the queued items plus any in-flight
        // send/close futures.
        struct SendSink<W, WE> {
            write: Arc<Mutex<W>>,
            // A flush is forced from poll_ready once the queue exceeds this.
            max_buffer_size: usize,
            // `None` while a flush is in progress (queue taken by send_task).
            sending_buffers: Option<VecDeque<SinkItem>>,
            send_task:
                Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
            close_task:
                Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
        }

        impl<W, WE> SendSink<W, WE>
        where
            W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
            WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
        {
            // Start (or keep driving) the batched send of all queued buffers.
            fn try_send_buffser(
                &mut self,
                cx: &mut Context<'_>,
            ) -> Poll<std::result::Result<(), WE>> {
                if self.send_task.is_none() {
                    // Move the queue into a new owned future.
                    let mut buffers = self.sending_buffers.take().unwrap();
                    let tun = self.write.clone();
                    let send_task = async move {
                        if buffers.is_empty() {
                            return Ok(());
                        }

                        // Hold the write lock for the whole batch.
                        let mut locked_tun = tun.lock_owned().await;
                        while let Some(buf) = buffers.front() {
                            log::trace!(
                                "try_send buffer, len: {:?}, buf: {:?}",
                                buffers.len(),
                                &buf
                            );
                            // Bound each send with a 1s timeout so a stuck
                            // peer cannot block the flush forever.
                            let timeout_task = tokio::time::timeout(
                                std::time::Duration::from_secs(1),
                                locked_tun.send(buf.clone()),
                            );
                            let send_res = timeout_task.await;
                            let Ok(send_res) = send_res else {
                                // panic!("send timeout");
                                let err = send_res.err().unwrap();
                                return Err(err.into());
                            };
                            let Ok(_) = send_res else {
                                let err = send_res.err().unwrap();
                                println!("send error: {:?}", err);
                                return Err(err);
                            };
                            // Only drop the buffer once it was actually sent.
                            buffers.pop_front();
                        }
                        return Ok(());
                    };
                    self.send_task = Some(Box::pin(send_task));
                }

                let ret = ready!(self.send_task.as_mut().unwrap().poll_unpin(cx));
                // Batch finished: reset state for the next round of sends.
                self.send_task = None;
                self.sending_buffers = Some(VecDeque::new());
                return Poll::Ready(ret);
            }
        }

        impl<W, WE> Sink<SinkItem> for SendSink<W, WE>
        where
            W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
            WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
        {
            type Error = SinkError;

            fn poll_ready(
                self: Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                let self_mut = self.get_mut();
                let sending_buf = self_mut.sending_buffers.as_ref();
                // if sending_buffers is None, must already be doing flush
                if sending_buf.is_none() || sending_buf.unwrap().len() > self_mut.max_buffer_size {
                    return self_mut.poll_flush_unpin(cx);
                } else {
                    return Poll::Ready(Ok(()));
                }
            }

            fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
                // poll_ready guarantees no flush is in flight here.
                assert!(self.send_task.is_none());
                let self_mut = self.get_mut();
                self_mut.sending_buffers.as_mut().unwrap().push_back(item);
                Ok(())
            }

            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                let self_mut = self.get_mut();
                let ret = self_mut.try_send_buffser(cx);
                match ret {
                    Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
                    // Map the writer's error type into the sink's error type.
                    Poll::Ready(Err(e)) => Poll::Ready(Err(SinkError::CommonError(e.to_string()))),
                    Poll::Pending => {
                        return Poll::Pending;
                    }
                }
            }

            fn poll_close(
                self: Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                let self_mut = self.get_mut();
                if self_mut.close_task.is_none() {
                    let tun = self_mut.write.clone();
                    let close_task = async move {
                        let mut locked_tun = tun.lock_owned().await;
                        return locked_tun.close().await;
                    };
                    self_mut.close_task = Some(Box::pin(close_task));
                }

                let ret = ready!(self_mut.close_task.as_mut().unwrap().poll_unpin(cx));
                self_mut.close_task = None;

                if ret.is_err() {
                    return Poll::Ready(Err(SinkError::CommonError(
                        ret.err().unwrap().to_string(),
                    )));
                } else {
                    return Poll::Ready(Ok(()));
                }
            }
        }

        SendSink {
            write: self.write.clone(),
            // Flush threshold used by poll_ready.
            max_buffer_size: 1000,
            sending_buffers: Some(VecDeque::new()),
            send_task: None,
            close_task: None,
        }
    }
}
|
||||
|
||||
impl<R, RE, W, WE> Tunnel for FramedTunnel<R, W>
|
||||
where
|
||||
R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
|
||||
{
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
if self.info.is_none() {
|
||||
None
|
||||
} else {
|
||||
Some(self.info.clone().unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper that overrides the `info()` reported by an inner tunnel.
pub struct TunnelWithCustomInfo {
    // The tunnel actually carrying the data.
    tunnel: Box<dyn Tunnel>,
    // Metadata returned instead of `tunnel.info()`.
    info: TunnelInfo,
}
|
||||
|
||||
impl TunnelWithCustomInfo {
|
||||
pub fn new(tunnel: Box<dyn Tunnel>, info: TunnelInfo) -> Self {
|
||||
TunnelWithCustomInfo { tunnel, info }
|
||||
}
|
||||
}
|
||||
|
||||
impl Tunnel for TunnelWithCustomInfo {
    /// Delegate reads to the wrapped tunnel.
    fn stream(&self) -> Box<dyn DatagramStream> {
        self.tunnel.stream()
    }

    /// Delegate writes to the wrapped tunnel.
    fn sink(&self) -> Box<dyn DatagramSink> {
        self.tunnel.sink()
    }

    /// Report the overriding info, never the inner tunnel's.
    fn info(&self) -> Option<TunnelInfo> {
        Some(self.info.clone())
    }
}
|
||||
|
||||
pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
|
||||
if local_ip.is_unspecified() || local_ip.is_multicast() {
|
||||
return None;
|
||||
}
|
||||
let ifaces = network_interface::NetworkInterface::show().ok()?;
|
||||
for iface in ifaces {
|
||||
for addr in iface.addr {
|
||||
if addr.ip() == *local_ip {
|
||||
return Some(iface.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::error!(?local_ip, "can not find interface name by ip");
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn setup_sokcet2_ext(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
bind_dev: Option<String>,
|
||||
) -> Result<(), TunnelError> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM);
|
||||
crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, bind_dev, is_udp)?;
|
||||
}
|
||||
|
||||
socket2_socket.set_nonblocking(true)?;
|
||||
socket2_socket.set_reuse_address(true)?;
|
||||
socket2_socket.bind(&socket2::SockAddr::from(*bind_addr))?;
|
||||
|
||||
// #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||
// socket2_socket.set_reuse_port(true)?;
|
||||
|
||||
if bind_addr.ip().is_unspecified() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// linux/mac does not use interface of bind_addr to send packet, so we need to bind device
|
||||
// win can handle this with bind correctly
|
||||
#[cfg(any(target_os = "ios", target_os = "macos"))]
|
||||
if let Some(dev_name) = bind_dev {
|
||||
// use IP_BOUND_IF to bind device
|
||||
unsafe {
|
||||
let dev_idx = nix::libc::if_nametoindex(dev_name.as_str().as_ptr() as *const i8);
|
||||
tracing::warn!(?dev_idx, ?dev_name, "bind device");
|
||||
socket2_socket.bind_device_by_index_v4(std::num::NonZeroU32::new(dev_idx))?;
|
||||
tracing::warn!(?dev_idx, ?dev_name, "bind device doen");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
||||
if let Some(dev_name) = bind_dev {
|
||||
tracing::trace!(dev_name = ?dev_name, "bind device");
|
||||
socket2_socket.bind_device(Some(dev_name.as_bytes()))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
|
||||
mut futures: FuturesUnordered<Fut>,
|
||||
) -> Result<Ret, super::TunnelError>
|
||||
where
|
||||
Fut: Future<Output = Result<Ret, E>> + Send + Sync,
|
||||
E: std::error::Error + Into<super::TunnelError> + Send + Sync + 'static,
|
||||
{
|
||||
// return last error
|
||||
let mut last_err = None;
|
||||
|
||||
while let Some(ret) = futures.next().await {
|
||||
if let Err(e) = ret {
|
||||
last_err = Some(e.into());
|
||||
} else {
|
||||
return ret.map_err(|e| e.into());
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err.unwrap_or(super::TunnelError::CommonError(
|
||||
"no connect futures".to_string(),
|
||||
)))
|
||||
}
|
||||
|
||||
pub(crate) fn setup_sokcet2(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
) -> Result<(), TunnelError> {
|
||||
setup_sokcet2_ext(
|
||||
socket2_socket,
|
||||
bind_addr,
|
||||
super::common::get_interface_name_by_ip(&bind_addr.ip()),
|
||||
)
|
||||
}
|
||||
|
||||
/// Shared test helpers for tunnel implementations (echo server, ping-pong
/// round trip, throughput bench).
pub mod tests {
    use std::time::Instant;

    use futures::SinkExt;
    use tokio_stream::StreamExt;
    use tokio_util::bytes::{BufMut, Bytes, BytesMut};

    use crate::{
        common::netns::NetNS,
        tunnels::{close_tunnel, TunnelConnector, TunnelListener},
    };

    /// Echo every received datagram back to the sender; with `once`, stop
    /// after the first round trip. Exits when the stream yields an error.
    pub async fn _tunnel_echo_server(tunnel: Box<dyn super::Tunnel>, once: bool) {
        let mut recv = Box::into_pin(tunnel.stream());
        let mut send = Box::into_pin(tunnel.sink());

        while let Some(ret) = recv.next().await {
            if ret.is_err() {
                log::trace!("recv error: {:?}", ret.err().unwrap());
                break;
            }
            let res = ret.unwrap();
            log::trace!("recv a msg, try echo back: {:?}", res);
            send.send(Bytes::from(res)).await.unwrap();
            if once {
                break;
            }
        }
        log::warn!("echo server exit...");
    }

    /// Ping-pong test with both ends in fresh default network namespaces.
    pub(crate) async fn _tunnel_pingpong<L, C>(listener: L, connector: C)
    where
        L: TunnelListener + Send + Sync + 'static,
        C: TunnelConnector + Send + Sync + 'static,
    {
        _tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await
    }

    /// Full round-trip test: listen in `l_netns`, connect from `c_netns`,
    /// send one datagram, assert it is echoed back, then close the tunnel.
    pub(crate) async fn _tunnel_pingpong_netns<L, C>(
        mut listener: L,
        mut connector: C,
        l_netns: NetNS,
        c_netns: NetNS,
    ) where
        L: TunnelListener + Send + Sync + 'static,
        C: TunnelConnector + Send + Sync + 'static,
    {
        l_netns
            .run_async(|| async {
                listener.listen().await.unwrap();
            })
            .await;

        let lis = tokio::spawn(async move {
            let ret = listener.accept().await.unwrap();
            // Accepted tunnel must report the listener's own url.
            assert_eq!(
                ret.info().unwrap().local_addr,
                listener.local_url().to_string()
            );
            _tunnel_echo_server(ret, false).await
        });

        let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();

        // Connected tunnel must report the connector's target url.
        assert_eq!(
            tunnel.info().unwrap().remote_addr,
            connector.remote_url().to_string()
        );

        let mut send = tunnel.pin_sink();
        let mut recv = tunnel.pin_stream();
        let send_data = Bytes::from("12345678abcdefg");
        send.send(send_data).await.unwrap();
        let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next())
            .await
            .unwrap()
            .unwrap()
            .unwrap();
        println!("echo back: {:?}", ret);
        assert_eq!(ret, Bytes::from("12345678abcdefg"));

        close_tunnel(&tunnel).await.unwrap();

        // Connectionless transports never observe a close, so the echo task
        // must be aborted; stream transports should exit on their own.
        if ["udp", "wg"].contains(&connector.remote_url().scheme()) {
            lis.abort();
        } else {
            // lis should finish in 1 second
            let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await;
            assert!(ret.is_ok());
        }
    }

    /// Rough throughput benchmark: echo datagrams for 3 seconds and print an
    /// approximate rate.
    pub(crate) async fn _tunnel_bench<L, C>(mut listener: L, mut connector: C)
    where
        L: TunnelListener + Send + Sync + 'static,
        C: TunnelConnector + Send + Sync + 'static,
    {
        listener.listen().await.unwrap();

        let lis = tokio::spawn(async move {
            let ret = listener.accept().await.unwrap();
            _tunnel_echo_server(ret, false).await
        });

        let tunnel = connector.connect().await.unwrap();

        let mut send = tunnel.pin_sink();
        let mut recv = tunnel.pin_stream();

        // Fill a buffer with random data: 64 x 16-byte i128 = 1 KiB.
        // NOTE(review): the original comment claimed "4k"; the actual size
        // is 1 KiB.
        let mut send_buf = BytesMut::new();
        for _ in 0..64 {
            send_buf.put_i128(rand::random::<i128>());
        }

        let now = Instant::now();
        let mut count = 0;
        while now.elapsed().as_secs() < 3 {
            send.send(send_buf.clone().freeze()).await.unwrap();
            let _ = recv.next().await.unwrap().unwrap();
            count += 1;
        }
        println!("bps: {}", (count / 1024) * 4 / now.elapsed().as_secs());

        lis.abort();
    }
}
|
||||
@@ -1,192 +0,0 @@
|
||||
pub mod codec;
|
||||
pub mod common;
|
||||
// pub mod ring_tunnel;
|
||||
// pub mod stats;
|
||||
// pub mod tcp_tunnel;
|
||||
// pub mod tunnel_filter;
|
||||
// pub mod udp_tunnel;
|
||||
// pub mod wireguard;
|
||||
|
||||
use std::{fmt::Debug, net::SocketAddr, pin::Pin, sync::Arc};
|
||||
|
||||
use crate::rpc::TunnelInfo;
|
||||
use async_trait::async_trait;
|
||||
use futures::{Sink, SinkExt, Stream};
|
||||
|
||||
use thiserror::Error;
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
|
||||
/// Errors surfaced by tunnel listeners/connectors and their streams/sinks.
#[derive(Error, Debug)]
pub enum TunnelError {
    // Catch-all used when a foreign error is stringified.
    #[error("Error: {0}")]
    CommonError(String),
    #[error("io error")]
    IOError(#[from] std::io::Error),
    #[error("wait resp error {0}")]
    WaitRespError(String),
    #[error("Connect Error: {0}")]
    ConnectError(String),
    // URL scheme did not match the expected transport.
    #[error("Invalid Protocol: {0}")]
    InvalidProtocol(String),
    #[error("Invalid Addr: {0}")]
    InvalidAddr(String),
    #[error("Tun Error: {0}")]
    TunError(String),
    // Converted automatically from `tokio::time::timeout` expiry.
    #[error("timeout")]
    Timeout(#[from] tokio::time::error::Elapsed),
}
|
||||
|
||||
// Item type yielded by a tunnel's read side.
pub type StreamT = BytesMut;
pub type StreamItem = Result<StreamT, TunnelError>;
// Item type accepted by a tunnel's write side.
pub type SinkItem = Bytes;
pub type SinkError = TunnelError;

/// Read side of a tunnel: a stream of framed datagrams.
pub trait DatagramStream: Stream<Item = StreamItem> + Send + Sync {}
// Blanket impl: any matching stream is a DatagramStream.
impl<T> DatagramStream for T where T: Stream<Item = StreamItem> + Send + Sync {}
/// Write side of a tunnel: a sink of framed datagrams.
pub trait DatagramSink: Sink<SinkItem, Error = SinkError> + Send + Sync {}
// Blanket impl: any matching sink is a DatagramSink.
impl<T> DatagramSink for T where T: Sink<SinkItem, Error = SinkError> + Send + Sync {}
|
||||
|
||||
/// A bidirectional datagram transport.
///
/// `auto_impl` makes `Box<dyn Tunnel>` and `Arc<dyn Tunnel>` tunnels too.
#[auto_impl::auto_impl(Box, Arc)]
pub trait Tunnel: Send + Sync {
    /// Read side of the tunnel.
    fn stream(&self) -> Box<dyn DatagramStream>;
    /// Write side of the tunnel.
    fn sink(&self) -> Box<dyn DatagramSink>;

    /// `stream()` pinned and ready to poll.
    fn pin_stream(&self) -> Pin<Box<dyn DatagramStream>> {
        Box::into_pin(self.stream())
    }

    /// `sink()` pinned and ready to poll.
    fn pin_sink(&self) -> Pin<Box<dyn DatagramSink>> {
        Box::into_pin(self.sink())
    }

    /// Transport metadata (type, local/remote addr), when known.
    fn info(&self) -> Option<TunnelInfo>;
}
|
||||
|
||||
/// Close a tunnel by closing its sink; the peer's stream then terminates.
pub async fn close_tunnel(t: &Box<dyn Tunnel>) -> Result<(), TunnelError> {
    t.pin_sink().close().await
}
|
||||
|
||||
/// Reports how many live connections a listener currently owns.
#[auto_impl::auto_impl(Arc)]
pub trait TunnelConnCounter: 'static + Send + Sync + Debug {
    /// Current connection count.
    fn get(&self) -> u32;
}
|
||||
|
||||
/// Server side of a tunnel transport.
#[async_trait]
#[auto_impl::auto_impl(Box)]
pub trait TunnelListener: Send + Sync {
    /// Bind/register the listening endpoint.
    async fn listen(&mut self) -> Result<(), TunnelError>;
    /// Wait for and return the next inbound tunnel.
    async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
    /// URL this listener is bound to.
    fn local_url(&self) -> url::Url;
    /// Connection counter; the default is a stub that always reports 0.
    fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
        #[derive(Debug)]
        struct FakeTunnelConnCounter {}
        impl TunnelConnCounter for FakeTunnelConnCounter {
            fn get(&self) -> u32 {
                0
            }
        }
        Arc::new(Box::new(FakeTunnelConnCounter {}))
    }
}
|
||||
|
||||
/// Client side of a tunnel transport.
#[async_trait]
#[auto_impl::auto_impl(Box)]
pub trait TunnelConnector {
    /// Establish a tunnel to `remote_url()`.
    async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
    /// URL this connector dials.
    fn remote_url(&self) -> url::Url;
    /// Restrict local bind addresses; the default implementation ignores them.
    fn set_bind_addrs(&mut self, _addrs: Vec<SocketAddr>) {}
}
|
||||
|
||||
/// Build a `scheme://addr` url (e.g. `tcp://1.2.3.4:5` or `ring://<uuid>`).
///
/// The parameter stays `&String` on purpose: call sites pass
/// `&value.into()` and rely on the parameter type to drive `.into()`
/// inference to `String`.
///
/// # Panics
/// Panics if `scheme`/`addr` do not combine into a parseable url; callers
/// pass socket addresses or uuids, which always do.
pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url {
    url::Url::parse(format!("{}://{}", scheme, addr).as_str())
        .expect("scheme://addr must form a valid url")
}
|
||||
|
||||
// Manual Debug impl: trait objects cannot derive it; show only metadata.
impl std::fmt::Debug for dyn Tunnel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Tunnel")
            .field("info", &self.info())
            .finish()
    }
}
|
||||
|
||||
// Manual Debug impl for boxed connectors; identifies them by target url.
impl std::fmt::Debug for dyn TunnelConnector + Sync + Send {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TunnelConnector")
            .field("remote_url", &self.remote_url())
            .finish()
    }
}
|
||||
|
||||
// Manual Debug impl for boxed listeners; identifies them by bound url.
impl std::fmt::Debug for dyn TunnelListener {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TunnelListener")
            .field("local_url", &self.local_url())
            .finish()
    }
}
|
||||
|
||||
/// Parse an address-like value (socket address, uuid, ...) out of a tunnel URL.
pub(crate) trait FromUrl {
    /// # Errors
    /// Returns a `TunnelError` when the URL does not encode a valid address.
    fn from_url(url: url::Url) -> Result<Self, TunnelError>
    where
        Self: Sized;
}
|
||||
|
||||
pub(crate) fn check_scheme_and_get_socket_addr<T>(
|
||||
url: &url::Url,
|
||||
scheme: &str,
|
||||
) -> Result<T, TunnelError>
|
||||
where
|
||||
T: FromUrl,
|
||||
{
|
||||
if url.scheme() != scheme {
|
||||
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
|
||||
}
|
||||
|
||||
Ok(T::from_url(url.clone())?)
|
||||
}
|
||||
|
||||
impl FromUrl for SocketAddr {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
|
||||
Ok(url.socket_addrs(|| None)?.pop().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromUrl for uuid::Uuid {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
|
||||
let o = url.host_str().unwrap();
|
||||
let o = uuid::Uuid::parse_str(o).map_err(|e| TunnelError::InvalidAddr(e.to_string()))?;
|
||||
Ok(o)
|
||||
}
|
||||
}
|
||||
|
||||
/// Newtype over `url::Url` adding tunnel-specific accessors (e.g. the bind
/// device encoded in the path).
pub struct TunnelUrl {
    inner: url::Url,
}
|
||||
|
||||
// Wrap a raw URL without validation.
impl From<url::Url> for TunnelUrl {
    fn from(url: url::Url) -> Self {
        TunnelUrl { inner: url }
    }
}
|
||||
|
||||
// Unwrap back to the raw URL.
impl From<TunnelUrl> for url::Url {
    fn from(url: TunnelUrl) -> Self {
        url.into_inner()
    }
}
|
||||
|
||||
impl TunnelUrl {
|
||||
pub fn into_inner(self) -> url::Url {
|
||||
self.inner
|
||||
}
|
||||
|
||||
pub fn bind_dev(&self) -> Option<String> {
|
||||
self.inner.path().strip_prefix("/").and_then(|s| {
|
||||
if s.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(String::from_utf8(percent_encoding::percent_decode_str(&s).collect()).unwrap())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,436 +0,0 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU32},
|
||||
Arc,
|
||||
},
|
||||
task::Poll,
|
||||
};
|
||||
|
||||
use async_stream::stream;
|
||||
use crossbeam_queue::ArrayQueue;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Sink;
|
||||
use once_cell::sync::Lazy;
|
||||
use tokio::sync::{
|
||||
mpsc::{UnboundedReceiver, UnboundedSender},
|
||||
Mutex, Notify,
|
||||
};
|
||||
|
||||
use futures::FutureExt;
|
||||
use tokio_util::bytes::BytesMut;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::tunnels::{SinkError, SinkItem};
|
||||
|
||||
use super::{
|
||||
build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::FramedTunnel,
|
||||
DatagramSink, DatagramStream, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
|
||||
};
|
||||
|
||||
static RING_TUNNEL_CAP: usize = 1000;
|
||||
|
||||
/// Shared bounded buffer backing one direction of a ring tunnel.
struct Ring {
    // Identity; a server-side ring is created with the listener's uuid.
    id: Uuid,
    // The bounded datagram queue itself.
    ring: ArrayQueue<SinkItem>,
    // Signalled by the reader after popping (space became available).
    consume_notify: Notify,
    // Signalled by the writer after pushing (data available) and on close.
    produce_notify: Notify,
    // Set once a side closes; readers/writers then error out.
    closed: AtomicBool,
}
|
||||
|
||||
impl Ring {
|
||||
fn new(cap: usize, id: uuid::Uuid) -> Self {
|
||||
Self {
|
||||
id,
|
||||
ring: ArrayQueue::new(cap),
|
||||
consume_notify: Notify::new(),
|
||||
produce_notify: Notify::new(),
|
||||
closed: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
fn close(&self) {
|
||||
self.closed
|
||||
.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||
self.produce_notify.notify_one();
|
||||
}
|
||||
|
||||
fn closed(&self) -> bool {
|
||||
self.closed.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
/// One endpoint of an in-memory ring tunnel.
pub struct RingTunnel {
    // Endpoint id (the host part of its `ring://<uuid>` url).
    id: Uuid,
    // Shared buffer this endpoint reads from / writes to.
    ring: Arc<Ring>,
    // Live sender handles; the last one to drop closes the ring.
    sender_counter: Arc<AtomicU32>,
}
|
||||
|
||||
impl RingTunnel {
    /// Create a tunnel endpoint backed by a fresh ring of capacity `cap`.
    pub fn new(cap: usize) -> Self {
        let id = Uuid::new_v4();
        RingTunnel {
            id: id.clone(),
            ring: Arc::new(Ring::new(cap, id)),
            sender_counter: Arc::new(AtomicU32::new(1)),
        }
    }

    /// Like `new`, but with a caller-chosen id (the ring connector uses the
    /// listener's uuid here).
    pub fn new_with_id(id: Uuid, cap: usize) -> Self {
        let mut ret = Self::new(cap);
        ret.id = id;
        ret
    }

    /// Stream popping datagrams off the ring, parking on `produce_notify`
    /// while it is empty. Yields `Err` once the ring has been closed.
    fn recv_stream(&self) -> impl DatagramStream {
        let ring = self.ring.clone();
        let id = self.id;
        stream! {
            loop {
                match ring.ring.pop() {
                    Some(v) => {
                        let mut out = BytesMut::new();
                        out.extend_from_slice(&v);
                        // Wake a producer waiting for free space.
                        ring.consume_notify.notify_one();
                        log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v);
                        yield Ok(out);
                    },
                    None => {
                        if ring.closed() {
                            log::warn!("ring recv tunnel {:?} closed", id);
                            yield Err(TunnelError::CommonError("ring closed".to_owned()));
                        }
                        log::trace!("waiting recv buffer, id: {}", id);
                        ring.produce_notify.notified().await;
                    }
                }
            }
        }
    }

    /// Sink pushing datagrams onto the ring, applying backpressure via
    /// `consume_notify` while the ring is full.
    fn send_sink(&self) -> impl DatagramSink {
        let ring = self.ring.clone();
        let sender_counter = self.sender_counter.clone();
        use tokio::task::JoinHandle;

        // Sink adapter; the last dropped sender handle closes the ring.
        struct T {
            ring: Arc<Ring>,
            // In-flight task waiting for the ring to drain.
            wait_consume_task: Option<JoinHandle<()>>,
            sender_counter: Arc<AtomicU32>,
        }

        impl T {
            // Resolve Ready once the ring holds at most `expected_size` items.
            fn wait_ring_consume(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                expected_size: usize,
            ) -> std::task::Poll<()> {
                let self_mut = self.get_mut();
                if self_mut.ring.ring.len() <= expected_size {
                    return Poll::Ready(());
                }
                if self_mut.wait_consume_task.is_none() {
                    let id = self_mut.ring.id;
                    let ring = self_mut.ring.clone();
                    // Spawn a waiter so the sink's poll stays non-blocking.
                    let task = async move {
                        log::trace!(
                            "waiting ring consume done, expected_size: {}, id: {}",
                            expected_size,
                            id
                        );
                        while ring.ring.len() > expected_size {
                            ring.consume_notify.notified().await;
                        }
                        log::trace!(
                            "ring consume done, expected_size: {}, id: {}",
                            expected_size,
                            id
                        );
                    };
                    self_mut.wait_consume_task = Some(tokio::spawn(task));
                }
                let task = self_mut.wait_consume_task.as_mut().unwrap();
                match task.poll_unpin(cx) {
                    Poll::Ready(_) => {
                        self_mut.wait_consume_task = None;
                        Poll::Ready(())
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
        }

        impl Sink<SinkItem> for T {
            type Error = SinkError;

            fn poll_ready(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                if self.ring.closed() {
                    return Poll::Ready(Err(TunnelError::CommonError(
                        "ring closed during ready".to_owned(),
                    )
                    .into()));
                }
                // Ready once at least one slot is free.
                let expected_size = self.ring.ring.capacity() - 1;
                match self.wait_ring_consume(cx, expected_size) {
                    Poll::Ready(_) => Poll::Ready(Ok(())),
                    Poll::Pending => Poll::Pending,
                }
            }

            fn start_send(
                self: std::pin::Pin<&mut Self>,
                item: SinkItem,
            ) -> Result<(), Self::Error> {
                if self.ring.closed() {
                    return Err(
                        TunnelError::CommonError("ring closed during send".to_owned()).into(),
                    );
                }
                log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item);
                // poll_ready guaranteed a free slot, so push cannot fail.
                self.ring.ring.push(item).unwrap();
                self.ring.produce_notify.notify_one();
                Ok(())
            }

            fn poll_flush(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                if self.ring.closed() {
                    return Poll::Ready(Err(TunnelError::CommonError(
                        "ring closed during flush".to_owned(),
                    )
                    .into()));
                }
                // Pushed items are immediately visible to the reader; there
                // is nothing extra to flush.
                Poll::Ready(Ok(()))
            }

            fn poll_close(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                self.ring.close();
                Poll::Ready(Ok(()))
            }
        }

        impl Drop for T {
            // When the last sender handle goes away, close the ring so
            // blocked readers stop waiting.
            fn drop(&mut self) {
                let rem = self
                    .sender_counter
                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
                if rem == 1 {
                    self.ring.close()
                }
            }
        }

        // Account for the new sender handle being handed out.
        sender_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        T {
            ring,
            wait_consume_task: None,
            sender_counter,
        }
    }
}
|
||||
|
||||
impl Drop for RingTunnel {
|
||||
fn drop(&mut self) {
|
||||
let rem = self
|
||||
.sender_counter
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||
if rem == 1 {
|
||||
self.ring.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A connected ring pair: one endpoint per side.
struct Connection {
    // Connector-side endpoint.
    client: RingTunnel,
    // Listener-side endpoint (created with the listener's uuid as its id).
    server: RingTunnel,
}
|
||||
|
||||
impl Tunnel for RingTunnel {
    fn stream(&self) -> Box<dyn DatagramStream> {
        Box::new(self.recv_stream())
    }

    fn sink(&self) -> Box<dyn DatagramSink> {
        Box::new(self.send_sink())
    }

    /// Bare ring tunnels carry no metadata; `get_tunnel_for_client/server`
    /// wrap them with a `FramedTunnel` that attaches it.
    fn info(&self) -> Option<TunnelInfo> {
        None
    }
}
|
||||
|
||||
// Global registry mapping a ring listener's uuid to the channel on which
// connectors deliver new `Connection`s.
static CONNECTION_MAP: Lazy<Arc<Mutex<HashMap<uuid::Uuid, UnboundedSender<Arc<Connection>>>>>> =
    Lazy::new(|| Arc::new(Mutex::new(HashMap::new())));
|
||||
|
||||
/// Listener side of the in-process ring transport.
#[derive(Debug)]
pub struct RingTunnelListener {
    // `ring://<uuid>` url this listener registers under.
    listerner_addr: url::Url,
    // Sender placed into CONNECTION_MAP for connectors to use.
    conn_sender: UnboundedSender<Arc<Connection>>,
    // Receives connections delivered by connectors.
    conn_receiver: UnboundedReceiver<Arc<Connection>>,
}
|
||||
|
||||
impl RingTunnelListener {
|
||||
pub fn new(key: url::Url) -> Self {
|
||||
let (conn_sender, conn_receiver) = tokio::sync::mpsc::unbounded_channel();
|
||||
RingTunnelListener {
|
||||
listerner_addr: key,
|
||||
conn_sender,
|
||||
conn_receiver,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Connector-side view of a ring connection: read from the `client` ring,
/// write into the `server` ring.
fn get_tunnel_for_client(conn: Arc<Connection>) -> Box<dyn Tunnel> {
    FramedTunnel::new_tunnel_with_info(
        Box::pin(conn.client.recv_stream()),
        conn.server.send_sink(),
        TunnelInfo {
            tunnel_type: "ring".to_owned(),
            local_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
            remote_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
        },
    )
}
|
||||
|
||||
/// Listener-side view of a ring connection: read from the `server` ring,
/// write into the `client` ring (mirror of `get_tunnel_for_client`).
fn get_tunnel_for_server(conn: Arc<Connection>) -> Box<dyn Tunnel> {
    FramedTunnel::new_tunnel_with_info(
        Box::pin(conn.server.recv_stream()),
        conn.client.send_sink(),
        TunnelInfo {
            tunnel_type: "ring".to_owned(),
            local_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
            remote_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
        },
    )
}
|
||||
|
||||
impl RingTunnelListener {
    // Extract this listener's uuid from its `ring://<uuid>` url.
    fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
        check_scheme_and_get_socket_addr::<Uuid>(&self.listerner_addr, "ring")
    }
}
|
||||
|
||||
#[async_trait]
impl TunnelListener for RingTunnelListener {
    /// Register this listener's uuid in the global connection map so
    /// connectors can find it.
    async fn listen(&mut self) -> Result<(), TunnelError> {
        log::info!("listen new conn of key: {}", self.listerner_addr);
        CONNECTION_MAP
            .lock()
            .await
            .insert(self.get_addr()?, self.conn_sender.clone());
        Ok(())
    }

    /// Wait for a connector to deliver a `Connection`, then hand back the
    /// server-side tunnel of that pair.
    async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
        log::info!("waiting accept new conn of key: {}", self.listerner_addr);
        let my_addr = self.get_addr()?;
        if let Some(conn) = self.conn_receiver.recv().await {
            // Sanity check: the delivered connection must target this
            // listener's uuid.
            if conn.server.id == my_addr {
                log::info!("accept new conn of key: {}", self.listerner_addr);
                return Ok(get_tunnel_for_server(conn));
            } else {
                tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
                return Err(TunnelError::CommonError(
                    "accept got wrong ring server id".to_owned(),
                ));
            }
        }

        // All senders were dropped: the listener was unregistered.
        return Err(TunnelError::CommonError("conn receiver stopped".to_owned()));
    }

    fn local_url(&self) -> url::Url {
        self.listerner_addr.clone()
    }
}
|
||||
|
||||
/// Connector side of the in-process ring transport.
pub struct RingTunnelConnector {
    // `ring://<uuid>` url of the target listener.
    remote_addr: url::Url,
}
|
||||
|
||||
impl RingTunnelConnector {
|
||||
pub fn new(remote_addr: url::Url) -> Self {
|
||||
RingTunnelConnector { remote_addr }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelConnector for RingTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let remote_addr = check_scheme_and_get_socket_addr::<Uuid>(&self.remote_addr, "ring")?;
|
||||
let entry = CONNECTION_MAP
|
||||
.lock()
|
||||
.await
|
||||
.get(&remote_addr)
|
||||
.unwrap()
|
||||
.clone();
|
||||
log::info!("connecting");
|
||||
let conn = Arc::new(Connection {
|
||||
client: RingTunnel::new(RING_TUNNEL_CAP),
|
||||
server: RingTunnel::new_with_id(remote_addr.clone(), RING_TUNNEL_CAP),
|
||||
});
|
||||
entry
|
||||
.send(conn.clone())
|
||||
.map_err(|_| TunnelError::CommonError("send conn to listner failed".to_owned()))?;
|
||||
Ok(get_tunnel_for_client(conn))
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.remote_addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
|
||||
let conn = Arc::new(Connection {
|
||||
client: RingTunnel::new(RING_TUNNEL_CAP),
|
||||
server: RingTunnel::new(RING_TUNNEL_CAP),
|
||||
});
|
||||
(
|
||||
Box::new(get_tunnel_for_server(conn.clone())),
|
||||
Box::new(get_tunnel_for_client(conn)),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use futures::StreamExt;

    use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};

    use super::*;

    /// One-datagram round trip through a ring listener/connector pair.
    #[tokio::test]
    async fn ring_pingpong() {
        let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
        let listener = RingTunnelListener::new(id.clone());
        let connector = RingTunnelConnector::new(id.clone());
        _tunnel_pingpong(listener, connector).await
    }

    /// Throughput smoke test over a ring pair.
    #[tokio::test]
    async fn ring_bench() {
        let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
        let listener = RingTunnelListener::new(id.clone());
        let connector = RingTunnelConnector::new(id);
        _tunnel_bench(listener, connector).await
    }

    /// Dropping one endpoint must surface an error on the peer's stream.
    #[tokio::test]
    async fn ring_close() {
        let (stunnel, ctunnel) = create_ring_tunnel_pair();
        drop(stunnel);

        let mut stream = ctunnel.pin_stream();
        let ret = stream.next().await;
        assert!(ret.as_ref().unwrap().is_err(), "expect Err, got {:?}", ret);
    }
}
|
||||
@@ -1,95 +0,0 @@
|
||||
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering::Relaxed};
|
||||
|
||||
/// Lock-free rolling-average latency tracker.
///
/// Keeps the most recent `window_size` samples in a circular buffer of
/// atomics and maintains their running `sum` and `count` incrementally,
/// so recording a sample and reading the average are both O(1).
pub struct WindowLatency {
    samples: Vec<AtomicU32>,
    next_slot: AtomicU32,
    window_size: u32,

    sum: AtomicU32,
    count: AtomicU32,
}

impl WindowLatency {
    /// Create a tracker averaging over the last `window_size` samples.
    pub fn new(window_size: u32) -> Self {
        let samples = std::iter::repeat_with(|| AtomicU32::new(0))
            .take(window_size as usize)
            .collect();
        Self {
            samples,
            next_slot: AtomicU32::new(0),
            window_size,
            sum: AtomicU32::new(0),
            count: AtomicU32::new(0),
        }
    }

    /// Record one latency sample (microseconds), evicting the oldest
    /// sample once the window is full.
    pub fn record_latency(&self, latency_us: u32) {
        let slot = self.next_slot.fetch_add(1, Relaxed);
        if self.count.load(Relaxed) < self.window_size {
            self.count.fetch_add(1, Relaxed);
        }

        let slot = slot % self.window_size;
        let evicted = self.samples[slot as usize].swap(latency_us, Relaxed);

        // Keep `sum` equal to the sum of samples currently in the window
        // by applying the delta between the new and evicted values.
        if evicted < latency_us {
            self.sum.fetch_add(latency_us - evicted, Relaxed);
        } else {
            self.sum.fetch_sub(evicted - latency_us, Relaxed);
        }
    }

    /// Average latency over the window (integer division for integer T);
    /// returns 0 when no sample has been recorded yet.
    pub fn get_latency_us<T: From<u32> + std::ops::Div<Output = T>>(&self) -> T {
        let count = self.count.load(Relaxed);
        let sum = self.sum.load(Relaxed);
        if count == 0 {
            0.into()
        } else {
            T::from(sum) / T::from(count)
        }
    }
}
|
||||
|
||||
/// Cumulative byte/packet counters for both traffic directions.
///
/// All counters are atomics so one instance can be shared across tasks
/// behind an `Arc` without extra locking.
pub struct Throughput {
    tx_bytes: AtomicU64,
    rx_bytes: AtomicU64,

    tx_packets: AtomicU64,
    rx_packets: AtomicU64,
}

impl Throughput {
    /// New instance with every counter at zero.
    pub fn new() -> Self {
        Self {
            tx_bytes: AtomicU64::new(0),
            rx_bytes: AtomicU64::new(0),
            tx_packets: AtomicU64::new(0),
            rx_packets: AtomicU64::new(0),
        }
    }

    /// Account one sent packet of `bytes` bytes.
    pub fn record_tx_bytes(&self, bytes: u64) {
        self.tx_bytes.fetch_add(bytes, Relaxed);
        self.tx_packets.fetch_add(1, Relaxed);
    }

    /// Account one received packet of `bytes` bytes.
    pub fn record_rx_bytes(&self, bytes: u64) {
        self.rx_bytes.fetch_add(bytes, Relaxed);
        self.rx_packets.fetch_add(1, Relaxed);
    }

    /// Total bytes sent so far.
    pub fn tx_bytes(&self) -> u64 {
        self.tx_bytes.load(Relaxed)
    }

    /// Total bytes received so far.
    pub fn rx_bytes(&self) -> u64 {
        self.rx_bytes.load(Relaxed)
    }

    /// Total packets sent so far.
    pub fn tx_packets(&self) -> u64 {
        self.tx_packets.load(Relaxed)
    }

    /// Total packets received so far.
    pub fn rx_packets(&self) -> u64 {
        self.rx_packets.load(Relaxed)
    }
}
|
||||
@@ -1,270 +0,0 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
|
||||
|
||||
use crate::tunnels::common::setup_sokcet2;
|
||||
|
||||
use super::{
|
||||
check_scheme_and_get_socket_addr,
|
||||
common::{wait_for_connect_futures, FramedTunnel},
|
||||
Tunnel, TunnelInfo, TunnelListener,
|
||||
};
|
||||
|
||||
// Accepts inbound TCP connections and wraps each one into a
// length-delimited framed tunnel.
#[derive(Debug)]
pub struct TcpTunnelListener {
    // URL the listener binds to; must use the "tcp" scheme.
    addr: url::Url,
    // Bound socket; populated by `listen()`, `None` before that.
    listener: Option<TcpListener>,
}
|
||||
|
||||
impl TcpTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
TcpTunnelListener {
|
||||
addr,
|
||||
listener: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for TcpTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
|
||||
let socket = if addr.is_ipv4() {
|
||||
TcpSocket::new_v4()?
|
||||
} else {
|
||||
TcpSocket::new_v6()?
|
||||
};
|
||||
|
||||
socket.set_reuseaddr(true)?;
|
||||
// #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||
// socket.set_reuseport(true)?;
|
||||
socket.bind(addr)?;
|
||||
|
||||
self.listener = Some(socket.listen(1024)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let listener = self.listener.as_ref().unwrap();
|
||||
let (stream, _) = listener.accept().await?;
|
||||
stream.set_nodelay(true).unwrap();
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "tcp".to_owned(),
|
||||
local_addr: self.local_url().into(),
|
||||
remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp")
|
||||
.into(),
|
||||
};
|
||||
|
||||
let (r, w) = tokio::io::split(stream);
|
||||
Ok(FramedTunnel::new_tunnel_with_info(
|
||||
FramedRead::new(r, LengthDelimitedCodec::new()),
|
||||
FramedWrite::new(w, LengthDelimitedCodec::new()),
|
||||
info,
|
||||
))
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tunnel_with_tcp_stream(
|
||||
stream: TcpStream,
|
||||
remote_url: url::Url,
|
||||
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
stream.set_nodelay(true).unwrap();
|
||||
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "tcp".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp")
|
||||
.into(),
|
||||
remote_addr: remote_url.into(),
|
||||
};
|
||||
|
||||
let (r, w) = tokio::io::split(stream);
|
||||
Ok(Box::new(FramedTunnel::new_tunnel_with_info(
|
||||
FramedRead::new(r, LengthDelimitedCodec::new()),
|
||||
FramedWrite::new(w, LengthDelimitedCodec::new()),
|
||||
info,
|
||||
)))
|
||||
}
|
||||
|
||||
// Dials a remote TCP tunnel endpoint.
#[derive(Debug)]
pub struct TcpTunnelConnector {
    // Destination URL; must use the "tcp" scheme.
    addr: url::Url,

    // Optional local addresses to bind before connecting; when non-empty,
    // one connection attempt is raced per bind address.
    bind_addrs: Vec<SocketAddr>,
}
|
||||
|
||||
impl TcpTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
TcpTunnelConnector {
|
||||
addr,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
tracing::info!(addr = ?self.addr, "connect tcp start");
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
let stream = TcpStream::connect(addr).await?;
|
||||
tracing::info!(addr = ?self.addr, "connect tcp succ");
|
||||
return get_tunnel_with_tcp_stream(stream, self.addr.clone().into());
|
||||
}
|
||||
|
||||
async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let futures = FuturesUnordered::new();
|
||||
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
tracing::info!(bind_addr = ?bind_addr, ?dst_addr, "bind addr");
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(dst_addr),
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, bind_addr)?;
|
||||
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
futures.push(socket.connect(dst_addr.clone()));
|
||||
}
|
||||
|
||||
let ret = wait_for_connect_futures(futures).await;
|
||||
return get_tunnel_with_tcp_stream(ret?, self.addr.clone().into());
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for TcpTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
if self.bind_addrs.is_empty() {
|
||||
self.connect_with_default_bind().await
|
||||
} else {
|
||||
self.connect_with_custom_bind().await
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use futures::{SinkExt, StreamExt};

    use crate::tunnels::{
        common::tests::{_tunnel_bench, _tunnel_pingpong},
        TunnelConnector,
    };

    use super::*;

    // Basic request/response smoke test over loopback TCP.
    #[tokio::test]
    async fn tcp_pingpong() {
        let listener = TcpTunnelListener::new("tcp://0.0.0.0:11011".parse().unwrap());
        let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11011".parse().unwrap());
        _tunnel_pingpong(listener, connector).await
    }

    // Throughput benchmark over loopback TCP (shared test helper).
    #[tokio::test]
    async fn tcp_bench() {
        let listener = TcpTunnelListener::new("tcp://0.0.0.0:11012".parse().unwrap());
        let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11012".parse().unwrap());
        _tunnel_bench(listener, connector).await
    }

    // Connecting with an explicit, valid local bind address must succeed.
    #[tokio::test]
    async fn tcp_bench_with_bind() {
        let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap());
        let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap());
        connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    // Binding to an address this host does not own must fail the connect.
    // NOTE(review): this test and tcp_multiple_sender_and_slow_receiver
    // both use port 11014 — they may conflict when run in parallel; verify.
    #[tokio::test]
    #[should_panic]
    async fn tcp_bench_with_bind_fail() {
        let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
        let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
        connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    // test slow send lock in framed tunnel
    // Two fast senders against a deliberately slow receiver; after 5s the
    // tunnel is closed and all tasks must terminate (no deadlock).
    #[tokio::test]
    async fn tcp_multiple_sender_and_slow_receiver() {
        // console_subscriber::init();
        let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
        let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());

        listener.listen().await.unwrap();
        // t1: slow receiver, sleeps 100ms per received packet.
        let t1 = tokio::spawn(async move {
            let t = listener.accept().await.unwrap();
            let mut stream = t.pin_stream();

            let now = tokio::time::Instant::now();

            while let Some(Ok(_)) = stream.next().await {
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
                if now.elapsed().as_secs() > 5 {
                    break;
                }
            }

            tracing::info!("t1 exit");
        });

        let tunnel = connector.connect().await.unwrap();
        let mut sink1 = tunnel.pin_sink();
        // t2: first fast sender; exits when the sink reports an error.
        let t2 = tokio::spawn(async move {
            for i in 0..1000000 {
                let a = sink1.send(b"hello".to_vec().into()).await;
                if a.is_err() {
                    tracing::info!(?a, "t2 exit with err");
                    break;
                }

                if i % 5000 == 0 {
                    tracing::info!(i, "send2 1000");
                }
            }

            tracing::info!("t2 exit");
        });

        let mut sink2 = tunnel.pin_sink();
        // t3: second fast sender sharing the same tunnel.
        let t3 = tokio::spawn(async move {
            for i in 0..1000000 {
                let a = sink2.send(b"hello".to_vec().into()).await;
                if a.is_err() {
                    tracing::info!(?a, "t3 exit with err");
                    break;
                }

                if i % 5000 == 0 {
                    tracing::info!(i, "send2 1000");
                }
            }

            tracing::info!("t3 exit");
        });

        // t4: closes the tunnel after 5s, which should unblock the senders.
        let t4 = tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
            tracing::info!("closing");
            let close_ret = tunnel.pin_sink().close().await;
            tracing::info!("closed {:?}", close_ret);
        });

        let _ = tokio::join!(t1, t2, t3, t4);
    }
}
|
||||
@@ -1,279 +0,0 @@
|
||||
use std::{
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use crate::rpc::TunnelInfo;
|
||||
use futures::{Sink, SinkExt, Stream, StreamExt};
|
||||
|
||||
use self::stats::Throughput;
|
||||
|
||||
use super::*;
|
||||
use crate::tunnels::{DatagramSink, DatagramStream, SinkError, SinkItem, StreamItem, Tunnel};
|
||||
|
||||
pub trait TunnelFilter {
|
||||
fn before_send(&self, data: SinkItem) -> Option<Result<SinkItem, SinkError>> {
|
||||
Some(Ok(data))
|
||||
}
|
||||
fn after_received(&self, data: StreamItem) -> Option<Result<BytesMut, TunnelError>> {
|
||||
match data {
|
||||
Ok(v) => Some(Ok(v)),
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tunnel decorator that routes every sent/received packet through a
// `TunnelFilter` before delegating to the wrapped tunnel.
pub struct TunnelWithFilter<T, F> {
    // The wrapped tunnel.
    inner: T,
    // Shared filter; cloned into the sink and stream wrappers.
    filter: Arc<F>,
}
|
||||
|
||||
impl<T, F> Tunnel for TunnelWithFilter<T, F>
where
    T: Tunnel + Send + Sync + 'static,
    F: TunnelFilter + Send + Sync + 'static,
{
    // Sink side: run `filter.before_send` on every item. Items mapped to
    // `None` are dropped silently; everything else is forwarded to the
    // inner sink. All poll_* methods delegate unchanged.
    fn sink(&self) -> Box<dyn DatagramSink> {
        struct SinkWrapper<F> {
            sink: Pin<Box<dyn DatagramSink>>,
            filter: Arc<F>,
        }
        impl<F> Sink<SinkItem> for SinkWrapper<F>
        where
            F: TunnelFilter + Send + Sync + 'static,
        {
            type Error = SinkError;

            fn poll_ready(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                self.get_mut().sink.poll_ready_unpin(cx)
            }

            fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
                // Packet dropped by the filter: report success without
                // touching the inner sink.
                let Some(item) = self.filter.before_send(item) else {
                    return Ok(());
                };
                self.get_mut().sink.start_send_unpin(item?)
            }

            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                self.get_mut().sink.poll_flush_unpin(cx)
            }

            fn poll_close(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                self.get_mut().sink.poll_close_unpin(cx)
            }
        }

        Box::new(SinkWrapper {
            sink: self.inner.pin_sink(),
            filter: self.filter.clone(),
        })
    }

    // Stream side: run `filter.after_received` on every yielded item.
    // `None` results are skipped (the loop polls the inner stream again),
    // so dropped packets never surface to the reader.
    fn stream(&self) -> Box<dyn DatagramStream> {
        struct StreamWrapper<F> {
            stream: Pin<Box<dyn DatagramStream>>,
            filter: Arc<F>,
        }
        impl<F> Stream for StreamWrapper<F>
        where
            F: TunnelFilter + Send + Sync + 'static,
        {
            type Item = StreamItem;

            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let self_mut = self.get_mut);
                loop {
                    match self_mut.stream.poll_next_unpin(cx) {
                        Poll::Ready(Some(ret)) => {
                            let Some(ret) = self_mut.filter.after_received(ret) else {
                                continue;
                            };
                            return Poll::Ready(Some(ret));
                        }
                        Poll::Ready(None) => {
                            return Poll::Ready(None);
                        }
                        Poll::Pending => {
                            return Poll::Pending;
                        }
                    }
                }
            }
        }

        Box::new(StreamWrapper {
            stream: self.inner.pin_stream(),
            filter: self.filter.clone(),
        })
    }

    // Tunnel metadata is delegated untouched to the wrapped tunnel.
    fn info(&self) -> Option<TunnelInfo> {
        self.inner.info()
    }
}
|
||||
|
||||
impl<T, F> TunnelWithFilter<T, F>
|
||||
where
|
||||
T: Tunnel + Send + Sync + 'static,
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
pub fn new(inner: T, filter: Arc<F>) -> Self {
|
||||
Self { inner, filter }
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper that captures every packet passing through the filter.
//
// NOTE(review): the field names look inverted relative to the hooks —
// `before_send` pushes into `received` and `after_received` pushes into
// `sent` (see the `TunnelFilter` impl). Callers depend on the current
// names, so this is documented rather than renamed.
pub struct PacketRecorderTunnelFilter {
    pub received: Arc<std::sync::Mutex<Vec<Bytes>>>,
    pub sent: Arc<std::sync::Mutex<Vec<Bytes>>>,
}
|
||||
|
||||
impl TunnelFilter for PacketRecorderTunnelFilter {
    // Record and pass through every outgoing packet.
    // NOTE(review): pushes into `received` although this is the send path —
    // the buffer names appear swapped; see the struct definition.
    fn before_send(&self, data: SinkItem) -> Option<Result<SinkItem, SinkError>> {
        self.received.lock().unwrap().push(data.clone());
        Some(Ok(data))
    }

    // Record and pass through every successfully received packet; errors
    // pass through unrecorded.
    fn after_received(&self, data: StreamItem) -> Option<Result<BytesMut, TunnelError>> {
        match data {
            Ok(v) => {
                self.sent.lock().unwrap().push(v.clone().into());
                Some(Ok(v))
            }
            Err(e) => Some(Err(e)),
        }
    }
}
|
||||
|
||||
impl PacketRecorderTunnelFilter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
received: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
sent: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter that accumulates tx/rx byte and packet counters for all traffic
// crossing the wrapped tunnel.
pub struct StatsRecorderTunnelFilter {
    // Shared counters; also handed out via `get_throughput()`.
    throughput: Arc<Throughput>,
}
|
||||
|
||||
impl TunnelFilter for StatsRecorderTunnelFilter {
|
||||
fn before_send(&self, data: SinkItem) -> Option<Result<SinkItem, SinkError>> {
|
||||
self.throughput.record_tx_bytes(data.len() as u64);
|
||||
Some(Ok(data))
|
||||
}
|
||||
|
||||
fn after_received(&self, data: StreamItem) -> Option<Result<BytesMut, TunnelError>> {
|
||||
match data {
|
||||
Ok(v) => {
|
||||
self.throughput.record_rx_bytes(v.len() as u64);
|
||||
Some(Ok(v))
|
||||
}
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StatsRecorderTunnelFilter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
throughput: Arc::new(Throughput::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_throughput(&self) -> Arc<Throughput> {
|
||||
self.throughput.clone()
|
||||
}
|
||||
}
|
||||
|
||||
// Generates a struct holding one `Arc`'d filter per `field = Type` pair,
// plus a `wrap_tunnel` method that layers them around a tunnel. Filters
// are applied in declaration order, so the last listed filter becomes the
// outermost wrapper.
#[macro_export]
macro_rules! define_tunnel_filter_chain {
    ($type_name:ident $(, $field_name:ident = $filter_type:ty)+) => (
        pub struct $type_name {
            $($field_name: std::sync::Arc<$filter_type>,)+
        }

        impl $type_name {
            pub fn new() -> Self {
                Self {
                    $($field_name: std::sync::Arc::new(<$filter_type>::new()),)+
                }
            }

            pub fn wrap_tunnel(&self, tunnel: impl Tunnel + 'static) -> impl Tunnel {
                // Each expansion shadows `tunnel` with one more wrapper.
                $(
                    let tunnel = crate::tunnels::tunnel_filter::TunnelWithFilter::new(tunnel, self.$field_name.clone());
                )+
                tunnel
            }
        }
    )
}
|
||||
|
||||
#[cfg(test)]
pub mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;
    use crate::tunnels::ring_tunnel::RingTunnel;

    // Filter that drops outgoing packets whose 1-based sequence number is
    // in [start, end) and passes everything else through.
    pub struct DropSendTunnelFilter {
        start: AtomicU32,
        end: AtomicU32,
        cur: AtomicU32,
    }

    impl TunnelFilter for DropSendTunnelFilter {
        fn before_send(&self, data: SinkItem) -> Option<Result<SinkItem, SinkError>> {
            self.cur.fetch_add(1, Ordering::SeqCst);
            if self.cur.load(Ordering::SeqCst) >= self.start.load(Ordering::SeqCst)
                && self.cur.load(std::sync::atomic::Ordering::SeqCst)
                    < self.end.load(Ordering::SeqCst)
            {
                tracing::trace!("drop packet: {:?}", data);
                return None;
            }
            Some(Ok(data))
        }
    }

    impl DropSendTunnelFilter {
        pub fn new(start: u32, end: u32) -> Self {
            Self {
                start: AtomicU32::new(start),
                end: AtomicU32::new(end),
                cur: AtomicU32::new(0),
            }
        }
    }

    // A packet sent through a chain of recorders must be captured exactly
    // once by each layer.
    #[tokio::test]
    async fn test_nested_filter() {
        define_tunnel_filter_chain!(
            Filter,
            a = PacketRecorderTunnelFilter,
            b = PacketRecorderTunnelFilter,
            c = PacketRecorderTunnelFilter
        );

        let filter = Filter::new();
        let tunnel = filter.wrap_tunnel(RingTunnel::new(1));

        let mut s = tunnel.pin_sink();
        s.send(Bytes::from("hello")).await.unwrap();

        assert_eq!(1, filter.a.received.lock().unwrap().len());
        assert_eq!(1, filter.b.received.lock().unwrap().len());
        assert_eq!(1, filter.c.received.lock().unwrap().len());
    }
}
|
||||
@@ -1,768 +0,0 @@
|
||||
use std::{fmt::Debug, pin::Pin, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
|
||||
use rkyv::{Archive, Deserialize, Serialize};
|
||||
use std::net::SocketAddr;
|
||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
use tokio_util::{
|
||||
bytes::{Buf, Bytes, BytesMut},
|
||||
udp::UdpFramed,
|
||||
};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
join_joinset_background,
|
||||
rkyv_util::{self, encode_to_bytes, vec_to_string},
|
||||
},
|
||||
rpc::TunnelInfo,
|
||||
tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
|
||||
};
|
||||
|
||||
use super::{
|
||||
codec::BytesCodec,
|
||||
common::{
|
||||
setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures, FramedTunnel,
|
||||
TunnelWithCustomInfo,
|
||||
},
|
||||
ring_tunnel::create_ring_tunnel_pair,
|
||||
DatagramSink, DatagramStream, Tunnel, TunnelListener, TunnelUrl,
|
||||
};
|
||||
|
||||
pub const UDP_DATA_MTU: usize = 65000;
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
// Payload variants of a `UdpPacket`. `HolePunch`/`Data` carry raw bytes
// stored as `String` (converted via `vec_to_string`) — presumably to fit
// rkyv's archived string representation; confirm before changing.
pub enum UdpPacketPayload {
    // Connection request (first half of the handshake).
    Syn,
    // Acknowledgement of a Syn (handshake reply).
    Sack,
    // NAT hole-punching probe payload.
    HolePunch(String),
    // Ordinary tunnel data.
    Data(String),
}
|
||||
|
||||
impl std::fmt::Debug for UdpPacketPayload {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
|
||||
match self {
|
||||
UdpPacketPayload::Syn => tmp.field("Syn", &"").finish(),
|
||||
UdpPacketPayload::Sack => tmp.field("Sack", &"").finish(),
|
||||
UdpPacketPayload::HolePunch(s) => tmp.field("HolePunch", &s.as_bytes()).finish(),
|
||||
UdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))]
// Envelope for every datagram exchanged by the UDP tunnel.
pub struct UdpPacket {
    // Demultiplexing key for the connection; 0 for hole-punch packets.
    pub conn_id: u32,
    // Must equal UDP_PACKET_MAGIC; used to reject foreign datagrams.
    pub magic: u32,
    pub payload: UdpPacketPayload,
}
|
||||
|
||||
const UDP_PACKET_MAGIC: u32 = 0x19941126;
|
||||
|
||||
impl std::fmt::Debug for ArchivedUdpPacketPayload {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
|
||||
match self {
|
||||
ArchivedUdpPacketPayload::Syn => tmp.field("Syn", &"").finish(),
|
||||
ArchivedUdpPacketPayload::Sack => tmp.field("Sack", &"").finish(),
|
||||
ArchivedUdpPacketPayload::HolePunch(s) => {
|
||||
tmp.field("HolePunch", &s.as_bytes()).finish()
|
||||
}
|
||||
ArchivedUdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpPacket {
|
||||
pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
|
||||
Self {
|
||||
conn_id,
|
||||
magic: UDP_PACKET_MAGIC,
|
||||
payload: UdpPacketPayload::Data(vec_to_string(data)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
|
||||
Self {
|
||||
conn_id: 0,
|
||||
magic: UDP_PACKET_MAGIC,
|
||||
payload: UdpPacketPayload::HolePunch(vec_to_string(data)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_syn_packet(conn_id: u32) -> Self {
|
||||
Self {
|
||||
conn_id,
|
||||
magic: UDP_PACKET_MAGIC,
|
||||
payload: UdpPacketPayload::Syn,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_sack_packet(conn_id: u32) -> Self {
|
||||
Self {
|
||||
conn_id,
|
||||
magic: UDP_PACKET_MAGIC,
|
||||
payload: UdpPacketPayload::Sack,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
|
||||
let Ok(udp_packet) = rkyv_util::decode_from_bytes::<UdpPacket>(&buf) else {
|
||||
tracing::warn!(?buf, "udp decode error");
|
||||
return None;
|
||||
};
|
||||
|
||||
if udp_packet.conn_id != conn_id.clone() {
|
||||
tracing::warn!(?udp_packet, ?conn_id, "udp conn id not match");
|
||||
return None;
|
||||
}
|
||||
|
||||
if udp_packet.magic != UDP_PACKET_MAGIC {
|
||||
tracing::trace!(?udp_packet, "udp magic not match");
|
||||
return None;
|
||||
}
|
||||
|
||||
let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else {
|
||||
tracing::warn!(?udp_packet, "udp payload not data");
|
||||
return None;
|
||||
};
|
||||
|
||||
let offset = payload.as_ptr() as usize - buf.as_ptr() as usize;
|
||||
let len = payload.len();
|
||||
if offset + len > buf.len() {
|
||||
tracing::warn!(?offset, ?len, ?buf, "udp payload data out of range");
|
||||
return None;
|
||||
}
|
||||
|
||||
buf.advance(offset);
|
||||
buf.truncate(len);
|
||||
tracing::trace!(?offset, ?len, ?buf, "udp payload data");
|
||||
|
||||
Some(buf)
|
||||
}
|
||||
|
||||
// Build a datagram tunnel on top of a (shared) UDP socket that talks only
// to `addr` and tags every outgoing packet with `conn_id`.
fn get_tunnel_from_socket(
    socket: Arc<UdpSocket>,
    addr: SocketAddr,
    conn_id: u32,
) -> Box<dyn super::Tunnel> {
    let udp = UdpFramed::new(socket.clone(), BytesCodec::new(UDP_DATA_MTU));
    let (sink, stream) = udp.split();

    let recv_addr = addr;
    // Receive path: surface socket errors, drop datagrams from any other
    // peer, and unwrap the rkyv envelope down to the raw data payload.
    let stream = stream.filter_map(move |v| async move {
        tracing::trace!(?v, "udp stream recv something");
        if v.is_err() {
            tracing::warn!(?v, "udp stream error");
            return Some(Err(super::TunnelError::CommonError(
                "udp stream error".to_owned(),
            )));
        }

        let (buf, addr) = v.unwrap();
        if recv_addr != addr {
            tracing::warn!(?addr, ?recv_addr, "udp recv addr not match");
            return None;
        }

        Some(Ok(try_get_data_payload(buf, conn_id.clone())?))
    });
    let stream = Box::pin(stream);

    let sender_addr = addr;
    // Send path: wrap each outgoing buffer into a data `UdpPacket`.
    let sink = Box::pin(sink.with(move |v: Bytes| async move {
        // NOTE(review): this dead branch appears to exist only to pin the
        // closure's error type to `TunnelError` for the `with` adapter —
        // confirm before removing it.
        if false {
            return Err(super::TunnelError::CommonError("udp sink error".to_owned()));
        }

        // TODO: two copy here, how to avoid?
        let udp_packet = UdpPacket::new_data_packet(conn_id, v.to_vec());
        let v = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
        tracing::trace!(?udp_packet, ?v, "udp send packet");

        Ok((v, sender_addr))
    }));

    FramedTunnel::new_tunnel_with_info(
        stream,
        sink,
        // TODO: this remote addr is not a url
        super::TunnelInfo {
            tunnel_type: "udp".to_owned(),
            local_addr: super::build_url_from_socket_addr(
                &socket.local_addr().unwrap().to_string(),
                "udp",
            )
            .into(),
            remote_addr: super::build_url_from_socket_addr(&addr.to_string(), "udp").into(),
        },
    )
}
|
||||
|
||||
// Per-peer bookkeeping for the UDP listener: the ring-tunnel stream (.0)
// and sink (.1) that shuttle payloads, plus the connection id (.2).
pub(crate) struct StreamSinkPair(
    pub Pin<Box<dyn DatagramStream>>,
    pub Pin<Box<dyn DatagramSink>>,
    pub u32,
);
// Shared, async-locked handle to a `StreamSinkPair`.
pub(crate) type ArcStreamSinkPair = Arc<Mutex<StreamSinkPair>>;
|
||||
|
||||
// Multiplexes many logical tunnel connections over one UDP socket.
pub struct UdpTunnelListener {
    // Listen URL; must use the "udp" scheme.
    addr: url::Url,
    // Bound socket; populated by `listen()`.
    socket: Option<Arc<UdpSocket>>,

    // Per-peer state keyed by remote address; shared with the demux task.
    sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
    // Background tasks: the demux loop plus one forwarder per connection.
    forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,

    // Accepted tunnels flow from the demux task to `accept()` through
    // this channel.
    conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>,
    // Consumed by `listen()`; `Some` only before the first listen.
    conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>,
}
|
||||
|
||||
impl UdpTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
let (conn_send, conn_recv) = tokio::sync::mpsc::channel(100);
|
||||
Self {
|
||||
addr,
|
||||
socket: None,
|
||||
sock_map: Arc::new(DashMap::new()),
|
||||
forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
|
||||
conn_recv,
|
||||
conn_send: Some(conn_send),
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_forward_packet(
|
||||
sock_map: &DashMap<SocketAddr, ArcStreamSinkPair>,
|
||||
buf: BytesMut,
|
||||
addr: SocketAddr,
|
||||
) -> Result<(), super::TunnelError> {
|
||||
let entry = sock_map.get_mut(&addr);
|
||||
if entry.is_none() {
|
||||
log::warn!("udp forward packet: {:?}, {:?}, no entry", addr, buf);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::trace!("udp forward packet: {:?}, {:?}", addr, buf);
|
||||
let entry = entry.unwrap();
|
||||
let pair = entry.value().clone();
|
||||
drop(entry);
|
||||
|
||||
let Some(buf) = try_get_data_payload(buf, pair.lock().await.2) else {
|
||||
return Ok(());
|
||||
};
|
||||
pair.lock().await.1.send(buf.freeze()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_connect(
|
||||
socket: Arc<UdpSocket>,
|
||||
addr: SocketAddr,
|
||||
forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
||||
local_url: url::Url,
|
||||
conn_id: u32,
|
||||
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
tracing::info!(?conn_id, ?addr, "udp connection accept handling",);
|
||||
|
||||
let udp_packet = UdpPacket::new_sack_packet(conn_id);
|
||||
let sack_buf = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
|
||||
socket.send_to(&sack_buf, addr).await?;
|
||||
|
||||
let (ctunnel, stunnel) = create_ring_tunnel_pair();
|
||||
let udp_tunnel = get_tunnel_from_socket(socket.clone(), addr, conn_id);
|
||||
let ss_pair = StreamSinkPair(ctunnel.pin_stream(), ctunnel.pin_sink(), conn_id);
|
||||
let addr_copy = addr.clone();
|
||||
sock_map.insert(addr, Arc::new(Mutex::new(ss_pair)));
|
||||
let ctunnel_stream = ctunnel.pin_stream();
|
||||
forward_tasks.lock().unwrap().spawn(async move {
|
||||
let ret = ctunnel_stream
|
||||
.map(|v| {
|
||||
tracing::trace!(?v, "udp stream recv something in forward task");
|
||||
if v.is_err() {
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"udp stream error".to_owned(),
|
||||
));
|
||||
}
|
||||
Ok(v.unwrap().freeze())
|
||||
})
|
||||
.forward(udp_tunnel.pin_sink())
|
||||
.await;
|
||||
if let None = sock_map.remove(&addr_copy) {
|
||||
log::warn!("udp forward packet: {:?}, no entry", addr_copy);
|
||||
}
|
||||
close_tunnel(&ctunnel).await.unwrap();
|
||||
log::warn!("udp connection forward done: {:?}, {:?}", addr_copy, ret);
|
||||
});
|
||||
|
||||
Ok(Box::new(TunnelWithCustomInfo::new(
|
||||
stunnel,
|
||||
TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: local_url.into(),
|
||||
remote_addr: build_url_from_socket_addr(&addr.to_string(), "udp").into(),
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
pub fn get_socket(&self) -> Option<Arc<UdpSocket>> {
|
||||
self.socket.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl TunnelListener for UdpTunnelListener {
    // Bind the UDP socket and spawn the demultiplexer loop: SYN packets
    // create a new per-peer tunnel (handed to `accept()` via the channel),
    // all other packets are routed to the matching peer's ring tunnel.
    // Can only be called once per listener (`conn_send` is consumed; a
    // second call panics on the `take().unwrap()`).
    async fn listen(&mut self) -> Result<(), super::TunnelError> {
        let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;

        let socket2_socket = socket2::Socket::new(
            socket2::Domain::for_address(addr),
            socket2::Type::DGRAM,
            Some(socket2::Protocol::UDP),
        )?;

        // Bind, optionally to a specific device embedded in the URL.
        let tunnel_url: TunnelUrl = self.addr.clone().into();
        if let Some(bind_dev) = tunnel_url.bind_dev() {
            setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
        } else {
            setup_sokcet2(&socket2_socket, &addr)?;
        }

        self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));

        let socket = self.socket.as_ref().unwrap().clone();
        let forward_tasks = self.forward_tasks.clone();
        let sock_map = self.sock_map.clone();
        let conn_send = self.conn_send.take().unwrap();
        let local_url = self.local_url().clone();
        self.forward_tasks.lock().unwrap().spawn(
            async move {
                loop {
                    let mut buf = BytesMut::new();
                    buf.resize(UDP_DATA_MTU, 0);
                    // NOTE(review): a recv error kills this demux task via
                    // unwrap, silently breaking the listener — confirm
                    // whether this should be handled instead.
                    let (_size, addr) = socket.recv_from(&mut buf).await.unwrap();
                    // Trim the buffer down to the received datagram.
                    let _ = buf.split_off(_size);
                    log::trace!(
                        "udp recv packet: {:?}, buf: {:?}, size: {}",
                        addr,
                        buf,
                        _size
                    );

                    let Ok(udp_packet) = rkyv_util::decode_from_bytes::<UdpPacket>(&buf) else {
                        tracing::warn!(?buf, "udp decode error in forward task");
                        continue;
                    };

                    if udp_packet.magic != UDP_PACKET_MAGIC {
                        tracing::trace!(?udp_packet, "udp magic not match");
                        continue;
                    }

                    if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
                        // New connection request: perform the SACK
                        // handshake and queue the tunnel for `accept()`.
                        let Ok(conn) = Self::handle_connect(
                            socket.clone(),
                            addr,
                            forward_tasks.clone(),
                            sock_map.clone(),
                            local_url.clone(),
                            udp_packet.conn_id.into(),
                        )
                        .await
                        else {
                            tracing::error!(?addr, "udp handle connect error");
                            continue;
                        };
                        if let Err(e) = conn_send.send(conn).await {
                            tracing::warn!(?e, "udp send conn to accept channel error");
                        }
                    } else {
                        Self::try_forward_packet(sock_map.as_ref(), buf, addr)
                            .await
                            .unwrap();
                    }
                }
            }
            .instrument(tracing::info_span!("udp forward task", ?self.socket)),
        );

        join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned());

        Ok(())
    }

    // Wait for the demux task to hand over the next accepted connection.
    async fn accept(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
        log::info!("start udp accept: {:?}", self.addr);
        while let Some(conn) = self.conn_recv.recv().await {
            return Ok(conn);
        }
        // Channel closed: the demux task is gone.
        return Err(super::TunnelError::CommonError(
            "udp accept error".to_owned(),
        ));
    }

    fn local_url(&self) -> url::Url {
        self.addr.clone()
    }

    // The live connection count is derived from the per-peer map size.
    fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
        struct UdpTunnelConnCounter {
            sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
        }

        impl Debug for UdpTunnelConnCounter {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct("UdpTunnelConnCounter")
                    .field("sock_map_len", &self.sock_map.len())
                    .finish()
            }
        }

        impl TunnelConnCounter for UdpTunnelConnCounter {
            fn get(&self) -> u32 {
                self.sock_map.len() as u32
            }
        }

        Arc::new(Box::new(UdpTunnelConnCounter {
            sock_map: self.sock_map.clone(),
        }))
    }
}
|
||||
|
||||
pub struct UdpTunnelConnector {
|
||||
addr: url::Url,
|
||||
bind_addrs: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl UdpTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
Self {
|
||||
addr,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_sack(
|
||||
socket: &UdpSocket,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Result<(), super::TunnelError> {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.resize(128, 0);
|
||||
|
||||
let (usize, recv_addr) = tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(3),
|
||||
socket.recv_from(&mut buf),
|
||||
)
|
||||
.await??;
|
||||
|
||||
if recv_addr != addr {
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, unexpected sack addr: {:?}, {:?}",
|
||||
recv_addr, addr
|
||||
)));
|
||||
}
|
||||
|
||||
let _ = buf.split_off(usize);
|
||||
|
||||
let Ok(udp_packet) = rkyv_util::decode_from_bytes::<UdpPacket>(&buf) else {
|
||||
tracing::warn!(?buf, "udp decode error in wait sack");
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, decode error. buf: {:?}",
|
||||
buf
|
||||
)));
|
||||
};
|
||||
|
||||
if udp_packet.magic != UDP_PACKET_MAGIC {
|
||||
tracing::trace!(?udp_packet, "udp magic not match");
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, magic not match. magic: {:?}",
|
||||
udp_packet.magic
|
||||
)));
|
||||
}
|
||||
|
||||
if conn_id != udp_packet.conn_id {
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, conn id not match. conn_id: {:?}, {:?}",
|
||||
conn_id, udp_packet.conn_id
|
||||
)));
|
||||
}
|
||||
|
||||
if !matches!(udp_packet.payload, ArchivedUdpPacketPayload::Sack) {
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, unexpected payload. payload: {:?}",
|
||||
udp_packet.payload
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_sack_loop(
|
||||
socket: &UdpSocket,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Result<(), super::TunnelError> {
|
||||
while let Err(err) = Self::wait_sack(socket, addr, conn_id).await {
|
||||
tracing::warn!(?err, "udp wait sack error");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn try_connect_with_socket(
|
||||
&self,
|
||||
socket: UdpSocket,
|
||||
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
|
||||
log::warn!("udp connect: {:?}", self.addr);
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
crate::arch::windows::disable_connection_reset(&socket)?;
|
||||
|
||||
// send syn
|
||||
let conn_id = rand::random();
|
||||
let udp_packet = UdpPacket::new_syn_packet(conn_id);
|
||||
let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
|
||||
let ret = socket.send_to(&b, &addr).await?;
|
||||
tracing::warn!(?udp_packet, ?ret, "udp send syn");
|
||||
|
||||
// wait sack
|
||||
tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(3),
|
||||
Self::wait_sack_loop(&socket, addr, conn_id),
|
||||
)
|
||||
.await??;
|
||||
|
||||
// sack done
|
||||
let local_addr = socket.local_addr().unwrap().to_string();
|
||||
Ok(Box::new(TunnelWithCustomInfo::new(
|
||||
get_tunnel_from_socket(Arc::new(socket), addr, conn_id),
|
||||
TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&local_addr, "udp").into(),
|
||||
remote_addr: self.remote_url().into(),
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
return self.try_connect_with_socket(socket).await;
|
||||
}
|
||||
|
||||
async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let futures = FuturesUnordered::new();
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(*bind_addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &bind_addr)?;
|
||||
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||
futures.push(self.try_connect_with_socket(socket));
|
||||
}
|
||||
wait_for_connect_futures(futures).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for UdpTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
if self.bind_addrs.is_empty() {
|
||||
self.connect_with_default_bind().await
|
||||
} else {
|
||||
self.connect_with_custom_bind().await
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rand::Rng;
    use tokio::time::timeout;

    // NOTE(review): the non-test code in this file uses the renamed `tunnel`
    // module (see `super::check_scheme_and_get_socket_addr` above); the old
    // `crate::tunnels` path was stale and is fixed here for consistency.
    use crate::{
        common::global_ctx::tests::get_mock_global_ctx,
        tunnel::{
            check_scheme_and_get_socket_addr,
            common::{
                get_interface_name_by_ip, setup_sokcet2_ext,
                tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong},
            },
        },
    };

    use super::*;

    #[tokio::test]
    async fn udp_pingpong() {
        let listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap());
        let connector = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap());
        _tunnel_pingpong(listener, connector).await
    }

    #[tokio::test]
    async fn udp_bench() {
        let listener = UdpTunnelListener::new("udp://0.0.0.0:5555".parse().unwrap());
        let connector = UdpTunnelConnector::new("udp://127.0.0.1:5555".parse().unwrap());
        _tunnel_bench(listener, connector).await
    }

    #[tokio::test]
    async fn udp_bench_with_bind() {
        let listener = UdpTunnelListener::new("udp://127.0.0.1:5554".parse().unwrap());
        let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5554".parse().unwrap());
        connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    // Binding to an address no local interface owns must make the connect fail.
    #[tokio::test]
    #[should_panic]
    async fn udp_bench_with_bind_fail() {
        let listener = UdpTunnelListener::new("udp://127.0.0.1:5553".parse().unwrap());
        let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5553".parse().unwrap());
        connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    // Floods `remote_url` with garbage datagrams so tests can verify that
    // unrelated traffic does not disturb established tunnels.
    async fn send_random_data_to_socket(remote_url: url::Url) {
        let socket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
        socket
            .connect(format!(
                "{}:{}",
                remote_url.host().unwrap(),
                remote_url.port().unwrap()
            ))
            .await
            .unwrap();

        // get a random 100-len buf
        loop {
            let mut buf = vec![0u8; 100];
            rand::thread_rng().fill(&mut buf[..]);
            socket.send(&buf).await.unwrap();
            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        }
    }

    // Two concurrent tunnels over one listener must stay isolated, even while
    // random garbage is sprayed at both endpoints.
    #[tokio::test]
    async fn udp_multiple_conns() {
        let mut listener = UdpTunnelListener::new("udp://0.0.0.0:5557".parse().unwrap());
        listener.listen().await.unwrap();

        let _lis = tokio::spawn(async move {
            loop {
                let ret = listener.accept().await.unwrap();
                assert_eq!(
                    ret.info().unwrap().local_addr,
                    listener.local_url().to_string()
                );
                tokio::spawn(async move { _tunnel_echo_server(ret, false).await });
            }
        });

        let mut connector1 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap());
        let mut connector2 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap());

        let t1 = connector1.connect().await.unwrap();
        let t2 = connector2.connect().await.unwrap();

        tokio::spawn(timeout(
            Duration::from_secs(2),
            send_random_data_to_socket(t1.info().unwrap().local_addr.parse().unwrap()),
        ));
        tokio::spawn(timeout(
            Duration::from_secs(2),
            send_random_data_to_socket(t1.info().unwrap().remote_addr.parse().unwrap()),
        ));
        tokio::spawn(timeout(
            Duration::from_secs(2),
            send_random_data_to_socket(t2.info().unwrap().remote_addr.parse().unwrap()),
        ));

        let sender1 = tokio::spawn(async move {
            let mut sink = t1.pin_sink();
            let mut stream = t1.pin_stream();

            for i in 0..10 {
                sink.send(Bytes::from("hello1")).await.unwrap();
                let recv = stream.next().await.unwrap().unwrap();
                println!("t1 recv: {:?}, {:?}", recv, i);
                assert_eq!(recv, Bytes::from("hello1"));
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
            }
        });

        let sender2 = tokio::spawn(async move {
            let mut sink = t2.pin_sink();
            let mut stream = t2.pin_stream();

            for i in 0..10 {
                sink.send(Bytes::from("hello2")).await.unwrap();
                let recv = stream.next().await.unwrap().unwrap();
                println!("t2 recv: {:?}, {:?}", recv, i);
                assert_eq!(recv, Bytes::from("hello2"));
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
            }
        });

        let _ = tokio::join!(sender1, sender2);
    }

    // Round-trip encode/decode of a data packet must survive and stay printable.
    #[tokio::test]
    async fn udp_packet_print() {
        let udp_packet = UdpPacket::new_data_packet(1, vec![1, 2, 3, 4, 5]);
        let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
        let a_udp_packet = rkyv_util::decode_from_bytes::<UdpPacket>(&b).unwrap();
        println!("{:?}, {:?}", udp_packet, a_udp_packet);
    }

    // Every local ipv4 must be bindable while the socket is pinned to the
    // device that owns the first ip.
    #[tokio::test]
    async fn bind_multi_ip_to_same_dev() {
        let global_ctx = get_mock_global_ctx();
        let ips = global_ctx
            .get_ip_collector()
            .collect_ip_addrs()
            .await
            .interface_ipv4s;
        if ips.is_empty() {
            return;
        }
        let bind_dev = get_interface_name_by_ip(&ips[0].parse().unwrap());

        for ip in ips {
            println!("bind to ip: {:?}, {:?}", ip, bind_dev);
            let addr = check_scheme_and_get_socket_addr::<SocketAddr>(
                &format!("udp://{}:11111", ip).parse().unwrap(),
                "udp",
            )
            .unwrap();
            let socket2_socket = socket2::Socket::new(
                socket2::Domain::for_address(addr),
                socket2::Type::DGRAM,
                Some(socket2::Protocol::UDP),
            )
            .unwrap();
            setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap();
        }
    }
}
|
||||
@@ -1,841 +0,0 @@
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
fmt::{Debug, Formatter},
|
||||
hash::Hasher,
|
||||
net::SocketAddr,
|
||||
pin::Pin,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_recursion::async_recursion;
|
||||
use async_trait::async_trait;
|
||||
use boringtun::{
|
||||
noise::{errors::WireGuardError, Tunn, TunnResult},
|
||||
x25519::{PublicKey, StaticSecret},
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
|
||||
use rand::RngCore;
|
||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
|
||||
use crate::{
|
||||
rpc::TunnelInfo,
|
||||
tunnels::{build_url_from_socket_addr, common::TunnelWithCustomInfo},
|
||||
};
|
||||
|
||||
use super::{
|
||||
check_scheme_and_get_socket_addr,
|
||||
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
|
||||
ring_tunnel::create_ring_tunnel_pair,
|
||||
DatagramSink, DatagramStream, Tunnel, TunnelError, TunnelListener, TunnelUrl,
|
||||
};
|
||||
|
||||
/// Upper bound for a single datagram buffer, just under the 65535-byte UDP
/// maximum to leave headroom for headers.
const MAX_PACKET: usize = 65500;
|
||||
|
||||
/// Who sits on the far side of the wireguard tunnel, which decides whether an
/// IP header must be synthesized/stripped around each payload.
#[derive(Debug, Clone)]
enum WgType {
    // Easytier peer: payloads are raw tunnel frames, so an IP header is
    // added on encapsulation and removed on decapsulation.
    InternalUse,
    // Real wireguard peer: payloads already carry their own IP header,
    // which is kept untouched.
    ExternalUse,
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WgConfig {
|
||||
my_secret_key: StaticSecret,
|
||||
my_public_key: PublicKey,
|
||||
|
||||
peer_secret_key: StaticSecret,
|
||||
peer_public_key: PublicKey,
|
||||
|
||||
wg_type: WgType,
|
||||
}
|
||||
|
||||
impl WgConfig {
    /// Derives a deterministic keypair from the network name/secret so every
    /// node of the same network computes identical keys; both "my" and "peer"
    /// sides share that keypair (`wg_type` = `InternalUse`).
    ///
    /// NOTE(review/security): `DefaultHasher` is not a cryptographic hash and
    /// each `finish()` yields only 64 bits, so the 256-bit key has far less
    /// entropy than its size suggests — confirm this fits the threat model.
    pub fn new_from_network_identity(network_name: &str, network_secret: &str) -> Self {
        let mut my_sec = [0u8; 32];
        // Stretch the 64-bit hash into 32 bytes by repeatedly folding the
        // previously produced bytes back into the hasher.
        let mut hasher = DefaultHasher::new();
        hasher.write(network_name.as_bytes());
        hasher.write(network_secret.as_bytes());
        my_sec[0..8].copy_from_slice(&hasher.finish().to_be_bytes());
        hasher.write(&my_sec[0..8]);
        my_sec[8..16].copy_from_slice(&hasher.finish().to_be_bytes());
        hasher.write(&my_sec[0..16]);
        my_sec[16..24].copy_from_slice(&hasher.finish().to_be_bytes());
        hasher.write(&my_sec[0..24]);
        my_sec[24..32].copy_from_slice(&hasher.finish().to_be_bytes());

        // Both tunnel ends use the same seed, hence the same keypair.
        let my_secret_key = StaticSecret::from(my_sec);
        let my_public_key = PublicKey::from(&my_secret_key);
        let peer_secret_key = StaticSecret::from(my_sec);
        let peer_public_key = my_public_key.clone();

        WgConfig {
            my_secret_key,
            my_public_key,
            peer_secret_key,
            peer_public_key,

            wg_type: WgType::InternalUse,
        }
    }

    /// Builds an asymmetric config for the wireguard portal: server and client
    /// halves get distinct keypairs derived from the two seeds
    /// (`wg_type` = `ExternalUse`).
    pub fn new_for_portal(server_key_seed: &str, client_key_seed: &str) -> Self {
        let server_cfg = Self::new_from_network_identity("server", server_key_seed);
        let client_cfg = Self::new_from_network_identity("client", client_key_seed);
        Self {
            my_secret_key: server_cfg.my_secret_key,
            my_public_key: server_cfg.my_public_key,
            peer_secret_key: client_cfg.my_secret_key,
            peer_public_key: client_cfg.my_public_key,

            wg_type: WgType::ExternalUse,
        }
    }

    /// Raw bytes of our static secret key.
    pub fn my_secret_key(&self) -> &[u8] {
        self.my_secret_key.as_bytes()
    }

    /// Raw bytes of the peer's static secret key.
    pub fn peer_secret_key(&self) -> &[u8] {
        self.peer_secret_key.as_bytes()
    }

    /// Raw bytes of our public key.
    pub fn my_public_key(&self) -> &[u8] {
        self.my_public_key.as_bytes()
    }

    /// Raw bytes of the peer's public key.
    pub fn peer_public_key(&self) -> &[u8] {
        self.peer_public_key.as_bytes()
    }
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct WgPeerData {
|
||||
udp: Arc<UdpSocket>, // only for send
|
||||
endpoint: SocketAddr,
|
||||
tunn: Arc<Mutex<Tunn>>,
|
||||
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
|
||||
stream: Arc<Mutex<Pin<Box<dyn DatagramStream>>>>,
|
||||
wg_type: WgType,
|
||||
stopped: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
// Manual Debug: only the endpoint and local socket addr are meaningful;
// the tunn/sink/stream fields are not Debug and are omitted.
impl Debug for WgPeerData {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WgPeerData")
            .field("endpoint", &self.endpoint)
            .field("local", &self.udp.local_addr())
            .finish()
    }
}
|
||||
|
||||
impl WgPeerData {
    /// Encrypts one outbound plaintext packet with the wireguard session and
    /// sends the ciphertext to the remote endpoint. For `InternalUse` peers a
    /// synthetic IP header is prepended first (see `add_ip_header`).
    ///
    /// # Errors
    /// Only the UDP send failure is propagated; encapsulation errors and
    /// unexpected boringtun states are logged and swallowed.
    #[tracing::instrument]
    async fn handle_one_packet_from_me(&self, packet: &[u8]) -> Result<(), anyhow::Error> {
        let mut send_buf = vec![0u8; MAX_PACKET];

        // Keep the tunn lock scope minimal: drop it before awaiting the send.
        let encapsulate_result = {
            let mut peer = self.tunn.lock().await;
            if matches!(self.wg_type, WgType::InternalUse) {
                peer.encapsulate(&self.add_ip_header(&packet), &mut send_buf)
            } else {
                peer.encapsulate(&packet, &mut send_buf)
            }
        };

        tracing::trace!(
            ?encapsulate_result,
            "Received {} bytes from me",
            packet.len()
        );

        match encapsulate_result {
            TunnResult::WriteToNetwork(packet) => {
                self.udp
                    .send_to(packet, self.endpoint)
                    .await
                    .context("Failed to send encrypted IP packet to WireGuard endpoint.")?;
                tracing::debug!(
                    "Sent {} bytes to WireGuard endpoint (encrypted IP packet)",
                    packet.len()
                );
            }
            TunnResult::Err(e) => {
                tracing::error!("Failed to encapsulate IP packet: {:?}", e);
            }
            TunnResult::Done => {
                // Ignored
            }
            other => {
                tracing::error!(
                    "Unexpected WireGuard state during encapsulation: {:?}",
                    other
                );
            }
        };
        Ok(())
    }

    /// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint,
    /// decapsulates them, and dispatches newly received IP packets.
    #[tracing::instrument]
    pub async fn handle_one_packet_from_peer(&self, recv_buf: &[u8]) {
        let mut send_buf = vec![0u8; MAX_PACKET];
        let data = &recv_buf[..];
        let decapsulate_result = {
            let mut peer = self.tunn.lock().await;
            peer.decapsulate(None, data, &mut send_buf)
        };

        tracing::debug!("Decapsulation result: {:?}", decapsulate_result);

        match decapsulate_result {
            // Protocol traffic (handshake/keepalive replies): send it back,
            // then drain any queued packets boringtun wants on the wire by
            // calling decapsulate with an empty input until it stops.
            TunnResult::WriteToNetwork(packet) => {
                match self.udp.send_to(packet, self.endpoint).await {
                    Ok(_) => {}
                    Err(e) => {
                        tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
                        return;
                    }
                };
                let mut peer = self.tunn.lock().await;
                loop {
                    let mut send_buf = vec![0u8; MAX_PACKET];
                    match peer.decapsulate(None, &[], &mut send_buf) {
                        TunnResult::WriteToNetwork(packet) => {
                            match self.udp.send_to(packet, self.endpoint).await {
                                Ok(_) => {}
                                Err(e) => {
                                    tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
                                    break;
                                }
                            };
                        }
                        _ => {
                            break;
                        }
                    }
                }
            }
            // Decrypted data packet: forward the plaintext into the local ring
            // tunnel, stripping the IP header for internal-use peers.
            TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => {
                tracing::debug!(
                    "WireGuard endpoint sent IP packet of {} bytes",
                    packet.len()
                );
                let ret = self
                    .sink
                    .lock()
                    .await
                    .send(
                        if matches!(self.wg_type, WgType::InternalUse) {
                            // First nibble of the IP header is the version.
                            self.remove_ip_header(packet, packet[0] >> 4 == 4)
                        } else {
                            packet
                        }
                        .to_vec()
                        .into(),
                    )
                    .await;
                if ret.is_err() {
                    tracing::error!("Failed to send packet to tunnel: {:?}", ret);
                }
            }
            _ => {
                tracing::warn!(
                    "Unexpected WireGuard state during decapsulation: {:?}",
                    decapsulate_result
                );
            }
        }
    }

    /// Acts on one `update_timers` result: sends pending protocol packets,
    /// re-initiates the handshake on expiry (recursing to send it), and
    /// sleeps briefly on `Done`/unexpected states to avoid a busy loop.
    #[tracing::instrument]
    #[async_recursion]
    async fn handle_routine_tun_result<'a: 'async_recursion>(&self, result: TunnResult<'a>) -> () {
        match result {
            TunnResult::WriteToNetwork(packet) => {
                tracing::debug!(
                    "Sending routine packet of {} bytes to WireGuard endpoint",
                    packet.len()
                );
                match self.udp.send_to(packet, self.endpoint).await {
                    Ok(_) => {}
                    Err(e) => {
                        tracing::error!(
                            "Failed to send routine packet to WireGuard endpoint: {:?}",
                            e
                        );
                    }
                };
            }
            TunnResult::Err(WireGuardError::ConnectionExpired) => {
                tracing::warn!("Wireguard handshake has expired!");

                // Build a fresh handshake initiation and recurse once to
                // dispatch it through the WriteToNetwork arm above.
                let mut buf = vec![0u8; MAX_PACKET];
                let result = self
                    .tunn
                    .lock()
                    .await
                    .format_handshake_initiation(&mut buf[..], false);

                self.handle_routine_tun_result(result).await
            }
            TunnResult::Err(e) => {
                tracing::error!(
                    "Failed to prepare routine packet for WireGuard endpoint: {:?}",
                    e
                );
            }
            TunnResult::Done => {
                // Sleep for a bit
                tokio::time::sleep(Duration::from_millis(250)).await;
            }
            other => {
                tracing::warn!("Unexpected WireGuard routine task state: {:?}", other);
                tokio::time::sleep(Duration::from_millis(250)).await;
            }
        };
    }

    /// WireGuard Routine task. Handles Handshake, keep-alive, etc.
    pub async fn routine_task(self) {
        loop {
            let mut send_buf = vec![0u8; MAX_PACKET];
            let tun_result = { self.tunn.lock().await.update_timers(&mut send_buf) };
            self.handle_routine_tun_result(tun_result).await;
        }
    }

    /// Prepends a minimal 20-byte IPv4 header (version/IHL 0x45, TTL 64,
    /// zeroed addresses and checksum) so raw tunnel frames look like IP
    /// packets to boringtun.
    // NOTE(review): the header checksum is left 0 — presumably acceptable
    // because the synthetic header is stripped again on the other side.
    fn add_ip_header(&self, packet: &[u8]) -> Vec<u8> {
        let mut ret = vec![0u8; packet.len() + 20];
        let ip_header = ret.as_mut_slice();
        ip_header[0] = 0x45;
        ip_header[1] = 0;
        ip_header[2..4].copy_from_slice(&((packet.len() + 20) as u16).to_be_bytes());
        ip_header[4..6].copy_from_slice(&0u16.to_be_bytes());
        ip_header[6..8].copy_from_slice(&0u16.to_be_bytes());
        ip_header[8] = 64;
        ip_header[9] = 0;
        ip_header[10..12].copy_from_slice(&0u16.to_be_bytes());
        ip_header[12..16].copy_from_slice(&0u32.to_be_bytes());
        ip_header[16..20].copy_from_slice(&0u32.to_be_bytes());
        ip_header[20..].copy_from_slice(packet);
        ret
    }

    /// Strips the fixed-size IP header: 20 bytes for IPv4, 40 for IPv6.
    // NOTE(review): assumes no IPv4 options / IPv6 extension headers —
    // consistent with the headers produced by `add_ip_header`.
    fn remove_ip_header<'a>(&self, packet: &'a [u8], is_v4: bool) -> &'a [u8] {
        if is_v4 {
            return &packet[20..];
        } else {
            return &packet[40..];
        }
    }
}
|
||||
|
||||
struct WgPeer {
|
||||
udp: Arc<UdpSocket>, // only for send
|
||||
config: WgConfig,
|
||||
endpoint: SocketAddr,
|
||||
|
||||
data: Option<WgPeerData>,
|
||||
tasks: JoinSet<()>,
|
||||
|
||||
access_time: std::time::Instant,
|
||||
}
|
||||
|
||||
impl WgPeer {
    /// Creates an idle peer; no tasks run until `start_and_get_tunnel`.
    fn new(udp: Arc<UdpSocket>, config: WgConfig, endpoint: SocketAddr) -> Self {
        WgPeer {
            udp,
            config,
            endpoint,

            data: None,
            tasks: JoinSet::new(),

            access_time: std::time::Instant::now(),
        }
    }

    /// Outbound pump: encrypts frames read from the local ring tunnel and
    /// sends them to the endpoint until the stream ends, then marks the
    /// session stopped so the reaper can drop this peer.
    async fn handle_packet_from_me(data: WgPeerData) {
        while let Some(Ok(packet)) = data.stream.lock().await.next().await {
            let ret = data.handle_one_packet_from_me(&packet).await;
            if let Err(e) = ret {
                tracing::error!("Failed to handle packet from me: {}", e);
            }
        }
        data.stopped
            .store(true, std::sync::atomic::Ordering::Relaxed);
    }

    /// Feeds one encrypted datagram from the peer into the session and
    /// refreshes the idle timer.
    // NOTE(review): panics if called before `start_and_get_tunnel` (data is
    // None); callers in this file only reach it after the peer was started.
    async fn handle_packet_from_peer(&mut self, packet: &[u8]) {
        self.access_time = std::time::Instant::now();
        tracing::trace!("Received {} bytes from peer", packet.len());
        let data = self.data.as_ref().unwrap();
        data.handle_one_packet_from_peer(packet).await;
    }

    /// Builds the session state, spawns the outbound pump and the wireguard
    /// routine task, and returns the client half of the ring tunnel that
    /// carries this peer's plaintext traffic.
    fn start_and_get_tunnel(&mut self) -> Box<dyn Tunnel> {
        let (stunnel, ctunnel) = create_ring_tunnel_pair();

        let data = WgPeerData {
            udp: self.udp.clone(),
            endpoint: self.endpoint,
            tunn: Arc::new(Mutex::new(
                Tunn::new(
                    self.config.my_secret_key.clone(),
                    self.config.peer_public_key.clone(),
                    None,
                    None,
                    // Random session index for this tunnel instance.
                    rand::thread_rng().next_u32(),
                    None,
                )
                .unwrap(),
            )),
            sink: Arc::new(Mutex::new(stunnel.pin_sink())),
            stream: Arc::new(Mutex::new(stunnel.pin_stream())),
            wg_type: self.config.wg_type.clone(),
            stopped: Arc::new(AtomicBool::new(false)),
        };

        self.data = Some(data.clone());
        self.tasks.spawn(Self::handle_packet_from_me(data.clone()));
        self.tasks.spawn(data.routine_task());

        ctunnel
    }

    /// Whether the outbound pump has finished (local stream closed).
    // NOTE(review): also panics if `data` is still None; only call on a
    // started peer (the reaper sees peers only after insertion post-start).
    fn stopped(&self) -> bool {
        self.data
            .as_ref()
            .unwrap()
            .stopped
            .load(std::sync::atomic::Ordering::Relaxed)
    }
}
|
||||
|
||||
impl Drop for WgPeer {
    /// Aborts the peer's background tasks and asynchronously closes the local
    /// sink so the other end of the ring tunnel observes EOF.
    // NOTE(review): `tokio::spawn` here requires dropping inside a runtime
    // context; presumably guaranteed because peers live in runtime tasks.
    fn drop(&mut self) {
        self.tasks.abort_all();
        if let Some(data) = self.data.clone() {
            tokio::spawn(async move {
                let _ = data.sink.lock().await.close().await;
            });
        }
    }
}
|
||||
|
||||
type ConnSender = tokio::sync::mpsc::UnboundedSender<Box<dyn Tunnel>>;
|
||||
type ConnReceiver = tokio::sync::mpsc::UnboundedReceiver<Box<dyn Tunnel>>;
|
||||
|
||||
pub struct WgTunnelListener {
|
||||
addr: url::Url,
|
||||
config: WgConfig,
|
||||
|
||||
udp: Option<Arc<UdpSocket>>,
|
||||
conn_recv: ConnReceiver,
|
||||
conn_send: Option<ConnSender>,
|
||||
|
||||
wg_peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
|
||||
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl WgTunnelListener {
|
||||
pub fn new(addr: url::Url, config: WgConfig) -> Self {
|
||||
let (conn_send, conn_recv) = tokio::sync::mpsc::unbounded_channel();
|
||||
WgTunnelListener {
|
||||
addr,
|
||||
config,
|
||||
|
||||
udp: None,
|
||||
conn_recv,
|
||||
conn_send: Some(conn_send),
|
||||
|
||||
wg_peer_map: Arc::new(DashMap::new()),
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_udp_socket(&self) -> Arc<UdpSocket> {
|
||||
self.udp.as_ref().unwrap().clone()
|
||||
}
|
||||
|
||||
async fn handle_udp_incoming(
|
||||
socket: Arc<UdpSocket>,
|
||||
config: WgConfig,
|
||||
conn_sender: ConnSender,
|
||||
peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
|
||||
) {
|
||||
let mut tasks = JoinSet::new();
|
||||
|
||||
let peer_map_clone = peer_map.clone();
|
||||
tasks.spawn(async move {
|
||||
loop {
|
||||
peer_map_clone
|
||||
.retain(|_, peer| peer.access_time.elapsed().as_secs() < 61 && !peer.stopped());
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
});
|
||||
|
||||
let mut buf = vec![0u8; MAX_PACKET];
|
||||
loop {
|
||||
let Ok((n, addr)) = socket.recv_from(&mut buf).await else {
|
||||
tracing::error!("Failed to receive from UDP socket");
|
||||
break;
|
||||
};
|
||||
|
||||
let data = &buf[..n];
|
||||
tracing::trace!("Received {} bytes from {}", n, addr);
|
||||
|
||||
if !peer_map.contains_key(&addr) {
|
||||
tracing::info!("New peer: {}", addr);
|
||||
let mut wg = WgPeer::new(socket.clone(), config.clone(), addr.clone());
|
||||
let tunnel = Box::new(TunnelWithCustomInfo::new(
|
||||
wg.start_and_get_tunnel(),
|
||||
TunnelInfo {
|
||||
tunnel_type: "wg".to_owned(),
|
||||
local_addr: build_url_from_socket_addr(
|
||||
&socket.local_addr().unwrap().to_string(),
|
||||
"wg",
|
||||
)
|
||||
.into(),
|
||||
remote_addr: build_url_from_socket_addr(&addr.to_string(), "wg").into(),
|
||||
},
|
||||
));
|
||||
if let Err(e) = conn_sender.send(tunnel) {
|
||||
tracing::error!("Failed to send tunnel to conn_sender: {}", e);
|
||||
}
|
||||
peer_map.insert(addr, wg);
|
||||
}
|
||||
|
||||
let mut peer = peer_map.get_mut(&addr).unwrap();
|
||||
peer.handle_packet_from_peer(data).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for WgTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg")?;
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
|
||||
let tunnel_url: TunnelUrl = self.addr.clone().into();
|
||||
if let Some(bind_dev) = tunnel_url.bind_dev() {
|
||||
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
|
||||
} else {
|
||||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
}
|
||||
|
||||
self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
|
||||
self.tasks.spawn(Self::handle_udp_incoming(
|
||||
self.get_udp_socket(),
|
||||
self.config.clone(),
|
||||
self.conn_send.take().unwrap(),
|
||||
self.wg_peer_map.clone(),
|
||||
));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
while let Some(tunnel) = self.conn_recv.recv().await {
|
||||
tracing::info!(?tunnel, "Accepted tunnel");
|
||||
return Ok(tunnel);
|
||||
}
|
||||
Err(TunnelError::CommonError(
|
||||
"Failed to accept tunnel".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WgClientTunnel {
|
||||
wg_peer: WgPeer,
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
info: TunnelInfo,
|
||||
}
|
||||
|
||||
impl Tunnel for WgClientTunnel {
    // Plain delegation to the inner ring tunnel for both halves.
    fn stream(&self) -> Box<dyn DatagramStream> {
        self.tunnel.stream()
    }

    fn sink(&self) -> Box<dyn DatagramSink> {
        self.tunnel.sink()
    }

    // Reports the info captured at connect time instead of the inner
    // (ring) tunnel's info.
    fn info(&self) -> Option<TunnelInfo> {
        Some(self.info.clone())
    }
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WgTunnelConnector {
|
||||
addr: url::Url,
|
||||
config: WgConfig,
|
||||
udp: Option<Arc<UdpSocket>>,
|
||||
|
||||
bind_addrs: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
// Manual Debug: the key config is deliberately omitted from the output.
impl Debug for WgTunnelConnector {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WgTunnelConnector")
            .field("addr", &self.addr)
            .field("udp", &self.udp)
            .finish()
    }
}
|
||||
|
||||
impl WgTunnelConnector {
|
||||
pub fn new(addr: url::Url, config: WgConfig) -> Self {
|
||||
WgTunnelConnector {
|
||||
addr,
|
||||
config,
|
||||
udp: None,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn create_handshake_init(tun: &mut Tunn) -> Vec<u8> {
|
||||
let mut dst = vec![0u8; 2048];
|
||||
let handshake_init = tun.format_handshake_initiation(&mut dst, false);
|
||||
assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_)));
|
||||
let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init {
|
||||
sent
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
handshake_init.into()
|
||||
}
|
||||
|
||||
fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec<u8> {
|
||||
let mut dst = vec![0u8; 2048];
|
||||
let keepalive = tun.decapsulate(None, handshake_resp, &mut dst);
|
||||
assert!(
|
||||
matches!(keepalive, TunnResult::WriteToNetwork(_)),
|
||||
"Failed to parse handshake response, {:?}",
|
||||
keepalive
|
||||
);
|
||||
|
||||
let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive {
|
||||
sent
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
keepalive.into()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(config))]
|
||||
async fn connect_with_socket(
|
||||
addr_url: url::Url,
|
||||
config: WgConfig,
|
||||
udp: UdpSocket,
|
||||
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&addr_url, "wg")?;
|
||||
tracing::warn!("wg connect: {:?}", addr);
|
||||
let local_addr = udp.local_addr().unwrap().to_string();
|
||||
|
||||
let mut my_tun = Tunn::new(
|
||||
config.my_secret_key.clone(),
|
||||
config.peer_public_key.clone(),
|
||||
None,
|
||||
None,
|
||||
rand::thread_rng().next_u32(),
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let init = Self::create_handshake_init(&mut my_tun);
|
||||
udp.send_to(&init, addr).await?;
|
||||
|
||||
let mut buf = vec![0u8; MAX_PACKET];
|
||||
let (n, _) = udp.recv_from(&mut buf).await.unwrap();
|
||||
let keepalive = Self::parse_handshake_resp(&mut my_tun, &buf[..n]);
|
||||
udp.send_to(&keepalive, addr).await?;
|
||||
|
||||
let mut wg_peer = WgPeer::new(Arc::new(udp), config.clone(), addr);
|
||||
let tunnel = wg_peer.start_and_get_tunnel();
|
||||
|
||||
let data = wg_peer.data.as_ref().unwrap().clone();
|
||||
wg_peer.tasks.spawn(async move {
|
||||
loop {
|
||||
let mut buf = vec![0u8; MAX_PACKET];
|
||||
let (n, recv_addr) = data.udp.recv_from(&mut buf).await.unwrap();
|
||||
if recv_addr != addr {
|
||||
continue;
|
||||
}
|
||||
data.handle_one_packet_from_peer(&buf[..n]).await;
|
||||
}
|
||||
});
|
||||
|
||||
let ret = Box::new(WgClientTunnel {
|
||||
wg_peer,
|
||||
tunnel,
|
||||
info: TunnelInfo {
|
||||
tunnel_type: "wg".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&local_addr, "wg").into(),
|
||||
remote_addr: addr_url.to_string(),
|
||||
},
|
||||
});
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for WgTunnelConnector {
|
||||
#[tracing::instrument]
|
||||
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
let bind_addrs = if self.bind_addrs.is_empty() {
|
||||
vec!["0.0.0.0:0".parse().unwrap()]
|
||||
} else {
|
||||
self.bind_addrs.clone()
|
||||
};
|
||||
let futures = FuturesUnordered::new();
|
||||
|
||||
for bind_addr in bind_addrs.into_iter() {
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(bind_addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &bind_addr)?;
|
||||
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||
tracing::info!(?bind_addr, ?self.addr, "prepare wg connect task");
|
||||
futures.push(Self::connect_with_socket(
|
||||
self.addr.clone(),
|
||||
self.config.clone(),
|
||||
socket,
|
||||
));
|
||||
}
|
||||
|
||||
wait_for_connect_futures(futures).await
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
pub mod tests {
    use boringtun::*;

    // Fixed stale paths: the old `crate::tunnels` module was removed in this
    // cleanup; the helpers now live under `crate::tunnel`.
    use crate::tunnel::common::tests::{_tunnel_bench, _tunnel_pingpong};
    use crate::tunnel::{wireguard::*, TunnelConnector};

    /// Build a matching (server, client) wireguard config pair: each side's
    /// secret key corresponds to the other side's `peer_public_key`.
    pub fn create_wg_config() -> (WgConfig, WgConfig) {
        let my_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng());
        let my_public_key = x25519::PublicKey::from(&my_secret_key);

        let their_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng());
        let their_public_key = x25519::PublicKey::from(&their_secret_key);

        let server_cfg = WgConfig {
            my_secret_key: my_secret_key.clone(),
            my_public_key,
            peer_secret_key: their_secret_key.clone(),
            peer_public_key: their_public_key.clone(),
            wg_type: WgType::InternalUse,
        };

        let client_cfg = WgConfig {
            my_secret_key: their_secret_key,
            my_public_key: their_public_key,
            peer_secret_key: my_secret_key,
            peer_public_key: my_public_key,
            wg_type: WgType::InternalUse,
        };

        (server_cfg, client_cfg)
    }

    #[tokio::test]
    async fn wg_pingpong() {
        let (server_cfg, client_cfg) = create_wg_config();
        let listener = WgTunnelListener::new("wg://0.0.0.0:5599".parse().unwrap(), server_cfg);
        let connector = WgTunnelConnector::new("wg://127.0.0.1:5599".parse().unwrap(), client_cfg);
        _tunnel_pingpong(listener, connector).await
    }

    #[tokio::test]
    async fn wg_bench() {
        let (server_cfg, client_cfg) = create_wg_config();
        let listener = WgTunnelListener::new("wg://0.0.0.0:5598".parse().unwrap(), server_cfg);
        let connector = WgTunnelConnector::new("wg://127.0.0.1:5598".parse().unwrap(), client_cfg);
        _tunnel_bench(listener, connector).await
    }

    #[tokio::test]
    async fn wg_bench_with_bind() {
        let (server_cfg, client_cfg) = create_wg_config();
        let listener = WgTunnelListener::new("wg://127.0.0.1:5597".parse().unwrap(), server_cfg);
        let mut connector =
            WgTunnelConnector::new("wg://127.0.0.1:5597".parse().unwrap(), client_cfg);
        connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    // Binding to an address that is not local must make the connect fail.
    #[tokio::test]
    #[should_panic]
    async fn wg_bench_with_bind_fail() {
        let (server_cfg, client_cfg) = create_wg_config();
        let listener = WgTunnelListener::new("wg://127.0.0.1:5596".parse().unwrap(), server_cfg);
        let mut connector =
            WgTunnelConnector::new("wg://127.0.0.1:5596".parse().unwrap(), client_cfg);
        connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    // After every client connection is dropped, the server side must erase
    // the corresponding entries from its peer map.
    #[tokio::test]
    async fn wg_server_erase_from_map_after_close() {
        let (server_cfg, client_cfg) = create_wg_config();
        let mut listener =
            WgTunnelListener::new("wg://127.0.0.1:5595".parse().unwrap(), server_cfg);
        listener.listen().await.unwrap();

        const CONN_COUNT: usize = 10;

        tokio::spawn(async move {
            for _ in 0..CONN_COUNT {
                let mut connector = WgTunnelConnector::new(
                    "wg://127.0.0.1:5595".parse().unwrap(),
                    client_cfg.clone(),
                );
                let ret = connector.connect().await;
                assert!(ret.is_ok());
                drop(ret);
            }
        });

        for _ in 0..CONN_COUNT {
            let conn = listener.accept().await;
            assert!(conn.is_ok());
            drop(conn);
        }

        // Give the server's cleanup tasks time to observe the closed conns.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        assert_eq!(0, listener.wg_peer_map.len());
    }
}
|
||||
Reference in New Issue
Block a user