fix wireguard deadlock

This commit is contained in:
sijie.sun
2024-04-28 22:08:11 +08:00
committed by Sijie.Sun
parent b3717d974b
commit 577cef131b
+12 -10
View File
@@ -14,6 +14,7 @@ use boringtun::{
x25519::{PublicKey, StaticSecret}, x25519::{PublicKey, StaticSecret},
}; };
use bytes::BytesMut; use bytes::BytesMut;
use crossbeam::atomic::AtomicCell;
use dashmap::DashMap; use dashmap::DashMap;
use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use rand::RngCore; use rand::RngCore;
@@ -343,7 +344,7 @@ struct WgPeer {
data: Option<WgPeerData>, data: Option<WgPeerData>,
tasks: JoinSet<()>, tasks: JoinSet<()>,
access_time: std::time::Instant, access_time: AtomicCell<std::time::Instant>,
} }
impl WgPeer { impl WgPeer {
@@ -358,7 +359,7 @@ impl WgPeer {
data: None, data: None,
tasks: JoinSet::new(), tasks: JoinSet::new(),
access_time: std::time::Instant::now(), access_time: AtomicCell::new(std::time::Instant::now()),
} }
} }
@@ -373,8 +374,8 @@ impl WgPeer {
.store(true, std::sync::atomic::Ordering::Relaxed); .store(true, std::sync::atomic::Ordering::Relaxed);
} }
async fn handle_packet_from_peer(&mut self, packet: &[u8]) { async fn handle_packet_from_peer(&self, packet: &[u8]) {
self.access_time = std::time::Instant::now(); self.access_time.store(std::time::Instant::now());
tracing::trace!("Received {} bytes from peer", packet.len()); tracing::trace!("Received {} bytes from peer", packet.len());
let data = self.data.as_ref().unwrap(); let data = self.data.as_ref().unwrap();
// TODO: improve this // TODO: improve this
@@ -436,7 +437,7 @@ pub struct WgTunnelListener {
conn_recv: ConnReceiver, conn_recv: ConnReceiver,
conn_send: Option<ConnSender>, conn_send: Option<ConnSender>,
wg_peer_map: Arc<DashMap<SocketAddr, WgPeer>>, wg_peer_map: Arc<DashMap<SocketAddr, Arc<WgPeer>>>,
tasks: JoinSet<()>, tasks: JoinSet<()>,
} }
@@ -466,15 +467,16 @@ impl WgTunnelListener {
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
config: WgConfig, config: WgConfig,
conn_sender: ConnSender, conn_sender: ConnSender,
peer_map: Arc<DashMap<SocketAddr, WgPeer>>, peer_map: Arc<DashMap<SocketAddr, Arc<WgPeer>>>,
) { ) {
let mut tasks = JoinSet::new(); let mut tasks = JoinSet::new();
let peer_map_clone = peer_map.clone(); let peer_map_clone = peer_map.clone();
tasks.spawn(async move { tasks.spawn(async move {
loop { loop {
peer_map_clone peer_map_clone.retain(|_, peer| {
.retain(|_, peer| peer.access_time.elapsed().as_secs() < 61 && !peer.stopped()); peer.access_time.load().elapsed().as_secs() < 61 && !peer.stopped()
});
tokio::time::sleep(Duration::from_secs(1)).await; tokio::time::sleep(Duration::from_secs(1)).await;
} }
}); });
@@ -509,10 +511,10 @@ impl WgTunnelListener {
if let Err(e) = conn_sender.send(tunnel) { if let Err(e) = conn_sender.send(tunnel) {
tracing::error!("Failed to send tunnel to conn_sender: {}", e); tracing::error!("Failed to send tunnel to conn_sender: {}", e);
} }
peer_map.insert(addr, wg); peer_map.insert(addr, Arc::new(wg));
} }
let mut peer = peer_map.get_mut(&addr).unwrap(); let peer = peer_map.get(&addr).unwrap().clone();
peer.handle_packet_from_peer(data).await; peer.handle_packet_from_peer(data).await;
} }
} }