diff --git a/easytier/src/gateway/socks5.rs b/easytier/src/gateway/socks5.rs index 7fe62d24..ff31c4e9 100644 --- a/easytier/src/gateway/socks5.rs +++ b/easytier/src/gateway/socks5.rs @@ -2,7 +2,7 @@ use std::{ any::Any, net::{IpAddr, Ipv4Addr, SocketAddr}, sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicUsize, Ordering}, Arc, Weak, }, time::{Duration, Instant}, @@ -163,6 +163,7 @@ type Socks5EntrySet = Arc>; struct SmolTcpConnector { net: Arc, entries: Socks5EntrySet, + entry_count: Arc, current_entry: std::sync::Mutex>, } @@ -187,6 +188,7 @@ impl AsyncTcpConnector for SmolTcpConnector { *self.current_entry.lock().unwrap() = Some(entry.clone()); self.entries .insert(entry, Socks5EntryData::Tcp(tmp_listener)); + self.entry_count.fetch_add(1, Ordering::Relaxed); if addr.ip() == local_addr { let modified_addr = @@ -215,6 +217,7 @@ impl Drop for SmolTcpConnector { if let Some(entry) = self.current_entry.lock().unwrap().take() { tracing::debug!("drop smoltcp connector entry {:?}", entry); self.entries.remove(&entry); + self.entry_count.fetch_sub(1, Ordering::Relaxed); } } } @@ -256,6 +259,7 @@ struct Socks5AutoConnector { kcp_endpoint: Option>, peer_mgr: Weak, entries: Socks5EntrySet, + entry_count: Arc, smoltcp_net: Option>, src_addr: SocketAddr, @@ -310,6 +314,7 @@ impl AsyncTcpConnector for Socks5AutoConnector { (_, _) => Box::new(SmolTcpConnector { net: self.smoltcp_net.clone().unwrap(), entries: self.entries.clone(), + entry_count: self.entry_count.clone(), current_entry: std::sync::Mutex::new(None), }), }; @@ -317,6 +322,7 @@ impl AsyncTcpConnector for Socks5AutoConnector { let connector = Box::new(SmolTcpConnector { net: self.smoltcp_net.clone().unwrap(), entries: self.entries.clone(), + entry_count: self.entry_count.clone(), current_entry: std::sync::Mutex::new(None), }); @@ -514,13 +520,13 @@ pub struct Socks5Server { socks5_enabled: Arc, cancel_tokens: Arc>, port_forward_list_change_notifier: Arc, + entry_count: Arc, } #[async_trait::async_trait] impl PeerPacketFilter for Socks5Server { async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option { - if self.cancel_tokens.is_empty() - && self.entries.is_empty() + if self.entry_count.load(Ordering::Relaxed) == 0 && !self.socks5_enabled.load(Ordering::Relaxed) { return Some(packet); @@ -628,6 +634,7 @@ impl Socks5Server { socks5_enabled: Arc::new(AtomicBool::new(false)), cancel_tokens: Arc::new(DashMap::new()), port_forward_list_change_notifier: Arc::new(Notify::new()), + entry_count: Arc::new(AtomicUsize::new(0)), }) } @@ -637,6 +644,7 @@ impl Socks5Server { let peer_manager = self.peer_manager.clone(); let packet_recv = self.packet_recv.clone(); let entries = self.entries.clone(); + let entry_count = self.entry_count.clone(); let udp_client_map = self.udp_client_map.clone(); let cancel_tokens = self.cancel_tokens.clone(); let port_forward_list_change_notifier = self.port_forward_list_change_notifier.clone(); @@ -656,7 +664,10 @@ impl Socks5Server { if prev_ipv4 != cur_ipv4 { prev_ipv4 = cur_ipv4; - entries.clear(); + entries.retain(|_, _| { + entry_count.fetch_sub(1, Ordering::Relaxed); + false + }); udp_client_map.clear(); if let Some(cur_ipv4) = cur_ipv4 { @@ -701,6 +712,7 @@ impl Socks5Server { )?; let entries = self.entries.clone(); + let entry_count = self.entry_count.clone(); let peer_manager = self.peer_manager.clone(); let net = self.net.clone(); self.tasks.lock().unwrap().spawn(async move { @@ -720,6 +732,7 @@ impl Socks5Server { peer_mgr: peer_manager.clone(), src_addr: addr, inner_connector: parking_lot::Mutex::new(None), + entry_count: entry_count.clone(), }; if let Some(net) = net.lock().await.as_ref() { net.handle_tcp_stream(socket, connector); @@ -832,6 +845,7 @@ impl Socks5Server { let net = self.net.clone(); let entries = self.entries.clone(); + let entry_count = self.entry_count.clone(); let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new())); join_joinset_background(tasks.clone(), "tcp port forward".to_string()); let forward_tasks = tasks; @@ -874,6 +888,7 @@ impl Socks5Server { entries: entries.clone(), smoltcp_net: net.lock().await.as_ref().map(|net| net.smoltcp_net.clone()), src_addr: addr, + entry_count: entry_count.clone(), inner_connector: parking_lot::Mutex::new(None), }; @@ -897,6 +912,7 @@ impl Socks5Server { let socket = Arc::new(bind_udp_socket(bind_addr, self.global_ctx.net_ns.clone())?); let entries = self.entries.clone(); + let entry_count = self.entry_count.clone(); let net_ns = self.global_ctx.net_ns.clone(); let net = self.net.clone(); let udp_client_map = self.udp_client_map.clone(); @@ -1005,6 +1021,7 @@ impl Socks5Server { client_info.entry_key.clone(), Socks5EntryData::Udp((socks_udp.clone(), udp_client_key.clone())), ); + entry_count.fetch_add(1, Ordering::Relaxed); let socks = socket.clone(); let client_addr = addr; @@ -1057,6 +1074,7 @@ impl Socks5Server { let udp_client_map = self.udp_client_map.clone(); let udp_forward_task = self.udp_forward_task.clone(); let entries = self.entries.clone(); + let entry_count = self.entry_count.clone(); let cancel_tokens = self.cancel_tokens.clone(); self.tasks.lock().unwrap().spawn(async move { loop { @@ -1068,7 +1086,11 @@ impl Socks5Server { udp_forward_task.retain(|k, _| udp_client_map.contains_key(k)); entries.retain(|_, data| match data { Socks5EntryData::Udp((_, udp_client_key)) => { - udp_client_map.contains_key(udp_client_key) + let keep = udp_client_map.contains_key(udp_client_key); + if !keep { + entry_count.fetch_sub(1, Ordering::Relaxed); + } + keep } _ => true, });