some minor bug fixs (#41)

* fix joinset leak; 

* fix udp packet format

* fix trace log panic

* avoid waiting after listener accept
This commit is contained in:
Sijie.Sun
2024-03-24 22:21:47 +08:00
committed by GitHub
parent 0f6f553010
commit ce889e990e
6 changed files with 186 additions and 49 deletions
+88
View File
@@ -1,3 +1,11 @@
use std::{
fmt::Debug,
future,
sync::{Arc, Mutex},
};
use tokio::task::JoinSet;
use tracing::Instrument;
pub mod config; pub mod config;
pub mod constants; pub mod constants;
pub mod error; pub mod error;
@@ -30,3 +38,83 @@ pub type PeerId = u32;
pub fn new_peer_id() -> PeerId { pub fn new_peer_id() -> PeerId {
rand::random() rand::random()
} }
pub fn join_joinset_background<T: Debug + Send + Sync + 'static>(
js: Arc<Mutex<JoinSet<T>>>,
origin: String,
) {
let js = Arc::downgrade(&js);
tokio::spawn(
async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
if js.weak_count() == 0 {
tracing::info!("joinset task exit");
break;
}
future::poll_fn(|cx| {
tracing::info!("try join joinset tasks");
let Some(js) = js.upgrade() else {
return std::task::Poll::Ready(());
};
let mut js = js.lock().unwrap();
while !js.is_empty() {
let ret = js.poll_join_next(cx);
if ret.is_pending() {
return std::task::Poll::Pending;
}
}
std::task::Poll::Ready(())
})
.await;
}
}
.instrument(tracing::info_span!(
"join_joinset_background",
origin = origin
)),
);
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_join_joinset_backgroud() {
let js = Arc::new(Mutex::new(JoinSet::<()>::new()));
join_joinset_background(js.clone(), "TEST".to_owned());
js.try_lock().unwrap().spawn(async {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
});
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
assert!(js.try_lock().unwrap().is_empty());
for _ in 0..5 {
js.try_lock().unwrap().spawn(async {
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
});
tokio::task::yield_now().await;
}
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
for _ in 0..5 {
js.try_lock().unwrap().spawn(async {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
});
tokio::task::yield_now().await;
}
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
assert!(js.try_lock().unwrap().is_empty());
let weak_js = Arc::downgrade(&js);
drop(js);
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
assert_eq!(weak_js.weak_count(), 0);
}
}
+14 -9
View File
@@ -8,8 +8,8 @@ use tracing::Instrument;
use crate::{ use crate::{
common::{ common::{
constants, error::Error, global_ctx::ArcGlobalCtx, rkyv_util::encode_to_bytes, constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background,
stun::StunInfoCollectorTrait, PeerId, rkyv_util::encode_to_bytes, stun::StunInfoCollectorTrait, PeerId,
}, },
peers::peer_manager::PeerManager, peers::peer_manager::PeerManager,
rpc::NatType, rpc::NatType,
@@ -75,9 +75,15 @@ impl UdpHolePunchListener {
while let Ok(conn) = listener.accept().await { while let Ok(conn) = listener.accept().await {
last_connected_time_clone.store(std::time::Instant::now()); last_connected_time_clone.store(std::time::Instant::now());
tracing::warn!(?conn, "udp hole punching listener got peer connection"); tracing::warn!(?conn, "udp hole punching listener got peer connection");
let peer_mgr = peer_mgr.clone();
tokio::spawn(async move {
if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await { if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
tracing::error!(?e, "failed to add tunnel as server in hole punch listener"); tracing::error!(
?e,
"failed to add tunnel as server in hole punch listener"
);
} }
});
} }
running_clone.store(false); running_clone.store(false);
@@ -115,7 +121,7 @@ struct UdpHolePunchConnectorData {
struct UdpHolePunchRpcServer { struct UdpHolePunchRpcServer {
data: Arc<UdpHolePunchConnectorData>, data: Arc<UdpHolePunchConnectorData>,
tasks: Arc<Mutex<JoinSet<()>>>, tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
} }
#[tarpc::server] #[tarpc::server]
@@ -140,7 +146,7 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
|| my_udp_nat_type == NatType::Restricted as i32 || my_udp_nat_type == NatType::Restricted as i32
{ {
// send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second // send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second
self.tasks.lock().await.spawn(async move { self.tasks.lock().unwrap().spawn(async move {
for _ in 0..10 { for _ in 0..10 {
tracing::info!(?local_mapped_addr, "sending hole punching packet"); tracing::info!(?local_mapped_addr, "sending hole punching packet");
// generate a 128 bytes vec with random data // generate a 128 bytes vec with random data
@@ -164,10 +170,9 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
impl UdpHolePunchRpcServer { impl UdpHolePunchRpcServer {
pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self { pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
Self { let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
data, join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned());
tasks: Arc::new(Mutex::new(JoinSet::new())), Self { data, tasks }
}
} }
async fn select_listener(&self) -> Option<(Arc<UdpSocket>, SocketAddr)> { async fn select_listener(&self) -> Option<(Arc<UdpSocket>, SocketAddr)> {
+7 -5
View File
@@ -16,6 +16,7 @@ use tracing::Instrument;
use crate::common::error::Result; use crate::common::error::Result;
use crate::common::global_ctx::GlobalCtx; use crate::common::global_ctx::GlobalCtx;
use crate::common::join_joinset_background;
use crate::common::netns::NetNS; use crate::common::netns::NetNS;
use crate::peers::packet::{self, ArchivedPacket}; use crate::peers::packet::{self, ArchivedPacket};
use crate::peers::peer_manager::PeerManager; use crate::peers::peer_manager::PeerManager;
@@ -71,7 +72,7 @@ pub struct TcpProxy {
peer_manager: Arc<PeerManager>, peer_manager: Arc<PeerManager>,
local_port: AtomicU16, local_port: AtomicU16,
tasks: Arc<Mutex<JoinSet<()>>>, tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
syn_map: SynSockMap, syn_map: SynSockMap,
conn_map: ConnSockMap, conn_map: ConnSockMap,
@@ -215,7 +216,7 @@ impl TcpProxy {
peer_manager, peer_manager,
local_port: AtomicU16::new(0), local_port: AtomicU16::new(0),
tasks: Arc::new(Mutex::new(JoinSet::new())), tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
syn_map: Arc::new(DashMap::new()), syn_map: Arc::new(DashMap::new()),
conn_map: Arc::new(DashMap::new()), conn_map: Arc::new(DashMap::new()),
@@ -247,6 +248,7 @@ impl TcpProxy {
self.peer_manager self.peer_manager
.add_nic_packet_process_pipeline(Box::new(self.clone())) .add_nic_packet_process_pipeline(Box::new(self.clone()))
.await; .await;
join_joinset_background(self.tasks.clone(), "TcpProxy".to_owned());
Ok(()) Ok(())
} }
@@ -268,7 +270,7 @@ impl TcpProxy {
tokio::time::sleep(Duration::from_secs(10)).await; tokio::time::sleep(Duration::from_secs(10)).await;
} }
}; };
tasks.lock().await.spawn(syn_map_cleaner_task); tasks.lock().unwrap().spawn(syn_map_cleaner_task);
Ok(()) Ok(())
} }
@@ -312,7 +314,7 @@ impl TcpProxy {
let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone()); let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone());
assert!(old_nat_val.is_none()); assert!(old_nat_val.is_none());
tasks.lock().await.spawn(Self::connect_to_nat_dst( tasks.lock().unwrap().spawn(Self::connect_to_nat_dst(
net_ns.clone(), net_ns.clone(),
tcp_stream, tcp_stream,
conn_map.clone(), conn_map.clone(),
@@ -325,7 +327,7 @@ impl TcpProxy {
}; };
self.tasks self.tasks
.lock() .lock()
.await .unwrap()
.spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener"))); .spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener")));
Ok(()) Ok(())
+4
View File
@@ -100,6 +100,9 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
tunnel_info.remote_addr.clone(), tunnel_info.remote_addr.clone(),
)); ));
tracing::info!(ret = ?ret, "conn accepted"); tracing::info!(ret = ?ret, "conn accepted");
let peer_manager = peer_manager.clone();
let global_ctx = global_ctx.clone();
tokio::spawn(async move {
let server_ret = peer_manager.handle_tunnel(ret).await; let server_ret = peer_manager.handle_tunnel(ret).await;
if let Err(e) = &server_ret { if let Err(e) = &server_ret {
global_ctx.issue_event(GlobalCtxEvent::ConnectionError( global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
@@ -109,6 +112,7 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
)); ));
tracing::error!(error = ?e, "handle conn error"); tracing::error!(error = ?e, "handle conn error");
} }
});
} }
} }
+14 -1
View File
@@ -99,7 +99,7 @@ pub enum PacketType {
TaRpc = 6, TaRpc = 6,
} }
#[derive(Archive, Deserialize, Serialize, Debug)] #[derive(Archive, Deserialize, Serialize)]
#[archive(compare(PartialEq), check_bytes)] #[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type: // Derives can be passed through to the generated type:
pub struct Packet { pub struct Packet {
@@ -109,6 +109,19 @@ pub struct Packet {
pub payload: String, pub payload: String,
} }
impl std::fmt::Debug for Packet {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
self.from_peer,
self.to_peer,
self.packet_type,
&self.payload.as_bytes()
)
}
}
impl std::fmt::Debug for ArchivedPacket { impl std::fmt::Debug for ArchivedPacket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!( write!(
+48 -23
View File
@@ -13,7 +13,10 @@ use tokio_util::{
use tracing::Instrument; use tracing::Instrument;
use crate::{ use crate::{
common::rkyv_util::{self, encode_to_bytes, vec_to_string}, common::{
join_joinset_background,
rkyv_util::{self, encode_to_bytes, vec_to_string},
},
rpc::TunnelInfo, rpc::TunnelInfo,
tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector}, tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
}; };
@@ -27,7 +30,7 @@ use super::{
pub const UDP_DATA_MTU: usize = 2500; pub const UDP_DATA_MTU: usize = 2500;
#[derive(Archive, Deserialize, Serialize, Debug)] #[derive(Archive, Deserialize, Serialize)]
#[archive(compare(PartialEq), check_bytes)] #[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type: // Derives can be passed through to the generated type:
pub enum UdpPacketPayload { pub enum UdpPacketPayload {
@@ -37,14 +40,29 @@ pub enum UdpPacketPayload {
Data(String), Data(String),
} }
impl std::fmt::Debug for UdpPacketPayload {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
match self {
UdpPacketPayload::Syn => tmp.field("Syn", &"").finish(),
UdpPacketPayload::Sack => tmp.field("Sack", &"").finish(),
UdpPacketPayload::HolePunch(s) => tmp.field("HolePunch", &s.as_bytes()).finish(),
UdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(),
}
}
}
#[derive(Archive, Deserialize, Serialize, Debug)] #[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)] #[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))] #[archive_attr(derive(Debug))]
pub struct UdpPacket { pub struct UdpPacket {
pub conn_id: u32, pub conn_id: u32,
pub magic: u32,
pub payload: UdpPacketPayload, pub payload: UdpPacketPayload,
} }
const UDP_PACKET_MAGIC: u32 = 0x19941126;
impl std::fmt::Debug for ArchivedUdpPacketPayload { impl std::fmt::Debug for ArchivedUdpPacketPayload {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut tmp = f.debug_struct("ArchivedUdpPacketPayload"); let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
@@ -63,6 +81,7 @@ impl UdpPacket {
pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self { pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
Self { Self {
conn_id, conn_id,
magic: UDP_PACKET_MAGIC,
payload: UdpPacketPayload::Data(vec_to_string(data)), payload: UdpPacketPayload::Data(vec_to_string(data)),
} }
} }
@@ -70,6 +89,7 @@ impl UdpPacket {
pub fn new_hole_punch_packet(data: Vec<u8>) -> Self { pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
Self { Self {
conn_id: 0, conn_id: 0,
magic: UDP_PACKET_MAGIC,
payload: UdpPacketPayload::HolePunch(vec_to_string(data)), payload: UdpPacketPayload::HolePunch(vec_to_string(data)),
} }
} }
@@ -77,6 +97,7 @@ impl UdpPacket {
pub fn new_syn_packet(conn_id: u32) -> Self { pub fn new_syn_packet(conn_id: u32) -> Self {
Self { Self {
conn_id, conn_id,
magic: UDP_PACKET_MAGIC,
payload: UdpPacketPayload::Syn, payload: UdpPacketPayload::Syn,
} }
} }
@@ -84,6 +105,7 @@ impl UdpPacket {
pub fn new_sack_packet(conn_id: u32) -> Self { pub fn new_sack_packet(conn_id: u32) -> Self {
Self { Self {
conn_id, conn_id,
magic: UDP_PACKET_MAGIC,
payload: UdpPacketPayload::Sack, payload: UdpPacketPayload::Sack,
} }
} }
@@ -100,6 +122,11 @@ fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
return None; return None;
} }
if udp_packet.magic != UDP_PACKET_MAGIC {
tracing::warn!(?udp_packet, "udp magic not match");
return None;
}
let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else { let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else {
tracing::warn!(?udp_packet, "udp payload not data"); tracing::warn!(?udp_packet, "udp payload not data");
return None; return None;
@@ -189,7 +216,7 @@ pub struct UdpTunnelListener {
socket: Option<Arc<UdpSocket>>, socket: Option<Arc<UdpSocket>>,
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>, sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
forward_tasks: Arc<Mutex<JoinSet<()>>>, forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>, conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>,
conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>, conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>,
@@ -202,7 +229,7 @@ impl UdpTunnelListener {
addr, addr,
socket: None, socket: None,
sock_map: Arc::new(DashMap::new()), sock_map: Arc::new(DashMap::new()),
forward_tasks: Arc::new(Mutex::new(JoinSet::new())), forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
conn_recv, conn_recv,
conn_send: Some(conn_send), conn_send: Some(conn_send),
} }
@@ -234,7 +261,7 @@ impl UdpTunnelListener {
async fn handle_connect( async fn handle_connect(
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
addr: SocketAddr, addr: SocketAddr,
forward_tasks: Arc<Mutex<JoinSet<()>>>, forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>, sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
local_url: url::Url, local_url: url::Url,
conn_id: u32, conn_id: u32,
@@ -251,7 +278,7 @@ impl UdpTunnelListener {
let addr_copy = addr.clone(); let addr_copy = addr.clone();
sock_map.insert(addr, Arc::new(Mutex::new(ss_pair))); sock_map.insert(addr, Arc::new(Mutex::new(ss_pair)));
let ctunnel_stream = ctunnel.pin_stream(); let ctunnel_stream = ctunnel.pin_stream();
forward_tasks.lock().await.spawn(async move { forward_tasks.lock().unwrap().spawn(async move {
let ret = ctunnel_stream let ret = ctunnel_stream
.map(|v| { .map(|v| {
tracing::trace!(?v, "udp stream recv something in forward task"); tracing::trace!(?v, "udp stream recv something in forward task");
@@ -304,7 +331,7 @@ impl TunnelListener for UdpTunnelListener {
let sock_map = self.sock_map.clone(); let sock_map = self.sock_map.clone();
let conn_send = self.conn_send.take().unwrap(); let conn_send = self.conn_send.take().unwrap();
let local_url = self.local_url().clone(); let local_url = self.local_url().clone();
self.forward_tasks.lock().await.spawn( self.forward_tasks.lock().unwrap().spawn(
async move { async move {
loop { loop {
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
@@ -323,6 +350,11 @@ impl TunnelListener for UdpTunnelListener {
continue; continue;
}; };
if udp_packet.magic != UDP_PACKET_MAGIC {
tracing::info!(?udp_packet, "udp magic not match");
continue;
}
if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) { if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
let Ok(conn) = Self::handle_connect( let Ok(conn) = Self::handle_connect(
socket.clone(), socket.clone(),
@@ -350,22 +382,7 @@ impl TunnelListener for UdpTunnelListener {
.instrument(tracing::info_span!("udp forward task", ?self.socket)), .instrument(tracing::info_span!("udp forward task", ?self.socket)),
); );
// let forward_tasks_clone = self.forward_tasks.clone(); join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned());
// tokio::spawn(async move {
// loop {
// let mut locked_forward_tasks = forward_tasks_clone.lock().await;
// tokio::select! {
// ret = locked_forward_tasks.join_next() => {
// tracing::warn!(?ret, "udp forward task exit");
// }
// else => {
// drop(locked_forward_tasks);
// tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
// continue;
// }
// }
// }
// });
Ok(()) Ok(())
} }
@@ -453,6 +470,14 @@ impl UdpTunnelConnector {
))); )));
}; };
if udp_packet.magic != UDP_PACKET_MAGIC {
tracing::info!(?udp_packet, "udp magic not match");
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, magic not match. magic: {:?}",
udp_packet.magic
)));
}
if conn_id != udp_packet.conn_id { if conn_id != udp_packet.conn_id {
return Err(super::TunnelError::ConnectError(format!( return Err(super::TunnelError::ConnectError(format!(
"udp connect error, conn id not match. conn_id: {:?}, {:?}", "udp connect error, conn id not match. conn_id: {:?}, {:?}",