refactor: remove ScopedTask (#2125)

* replace ScopedTask with AbortOnDropHandle
This commit is contained in:
Luna Yao
2026-04-25 09:20:25 +02:00
committed by GitHub
parent 820d9095d3
commit 5a1668c753
32 changed files with 161 additions and 300 deletions
Generated
+1
View File
@@ -2519,6 +2519,7 @@ dependencies = [
"thiserror 1.0.63",
"thunk-rs",
"tokio",
"tokio-util",
"tower-http",
"tower-sessions",
"tower-sessions-sqlx-store",
@@ -7,15 +7,15 @@ use std::{
use anyhow::Context as _;
use dashmap::DashMap;
use easytier::{
common::{
config::{ConfigFileControl, ConfigLoader, NetworkIdentity, PeerConfig, TomlConfigLoader},
scoped_task::ScopedTask,
common::config::{
ConfigFileControl, ConfigLoader, NetworkIdentity, PeerConfig, TomlConfigLoader,
},
defer,
instance_manager::NetworkInstanceManager,
};
use serde::{Deserialize, Serialize};
use sqlx::any;
use tokio_util::task::AbortOnDropHandle;
use tracing::{debug, error, info, instrument, warn};
use crate::db::{
@@ -240,7 +240,7 @@ pub struct HealthChecker {
db: Db,
instance_mgr: Arc<NetworkInstanceManager>,
inst_id_map: DashMap<i32, uuid::Uuid>,
node_tasks: DashMap<i32, ScopedTask<()>>,
node_tasks: DashMap<i32, AbortOnDropHandle<()>>,
node_records: Arc<DashMap<i32, HealthyMemRecord>>,
node_cfg: Arc<DashMap<i32, TomlConfigLoader>>,
}
@@ -465,7 +465,7 @@ impl HealthChecker {
}
// 启动健康检查任务
let task = ScopedTask::from(tokio::spawn(Self::node_health_check_task(
let task = AbortOnDropHandle::new(tokio::spawn(Self::node_health_check_task(
node_id,
cfg.get_id(),
Arc::clone(&self.instance_mgr),
+1
View File
@@ -10,6 +10,7 @@ tracing = { version = "0.1", features = ["log"] }
anyhow = { version = "1.0" }
thiserror = "1.0"
tokio = { version = "1", features = ["full"] }
tokio-util = { version = "0.7", features = ["rt"] }
dashmap = "6.1"
url = "2.2"
async-trait = "0.1"
+11 -10
View File
@@ -7,7 +7,7 @@ use std::{
use anyhow::Context;
use easytier::{
common::{config::ConfigSource, scoped_task::ScopedTask},
common::config::ConfigSource,
proto::{
api::manage::{
ConfigSource as RpcConfigSource, NetworkConfig, NetworkMeta, RunNetworkInstanceRequest,
@@ -21,6 +21,7 @@ use easytier::{
tunnel::Tunnel,
};
use tokio::sync::{RwLock, broadcast};
use tokio_util::task::AbortOnDropHandle;
use super::storage::{Storage, StorageToken, WeakRefStorage};
use crate::FeatureFlags;
@@ -475,7 +476,7 @@ pub struct Session {
data: SharedSessionData,
run_network_on_start_task: Option<ScopedTask<()>>,
run_network_on_start_task: Option<AbortOnDropHandle<()>>,
}
impl Debug for Session {
@@ -517,14 +518,14 @@ impl Session {
self.rpc_mgr.run_with_tunnel(tunnel);
let data = self.data.read().await;
self.run_network_on_start_task.replace(
tokio::spawn(Self::run_network_on_start(
data.heartbeat_waiter(),
data.storage.clone(),
self.scoped_rpc_client(),
))
.into(),
);
self.run_network_on_start_task
.replace(AbortOnDropHandle::new(tokio::spawn(
Self::run_network_on_start(
data.heartbeat_waiter(),
data.storage.clone(),
self.scoped_rpc_client(),
),
)));
}
fn collect_webhook_source_instance_ids(
+10 -13
View File
@@ -17,12 +17,12 @@ use axum_login::tower_sessions::{ExpiredDeletion, SessionManagerLayer};
use axum_login::{AuthManagerLayerBuilder, AuthUser, AuthzBackend, login_required};
use axum_messages::MessagesManagerLayer;
use easytier::common::config::{ConfigLoader, TomlConfigLoader};
use easytier::common::scoped_task::ScopedTask;
use easytier::launcher::NetworkConfig;
use easytier::proto::rpc_types;
use network::NetworkApi;
use sea_orm::DbErr;
use tokio::net::TcpListener;
use tokio_util::task::AbortOnDropHandle;
use tower_sessions::Expiry;
use tower_sessions::cookie::time::Duration;
use tower_sessions::cookie::{Key, SameSite};
@@ -199,8 +199,8 @@ impl RestfulServer {
mut self,
) -> Result<
(
ScopedTask<()>,
ScopedTask<tower_sessions::session_store::Result<()>>,
AbortOnDropHandle<()>,
AbortOnDropHandle<tower_sessions::session_store::Result<()>>,
),
anyhow::Error,
> {
@@ -213,13 +213,11 @@ impl RestfulServer {
let session_store = SqliteStore::new(self.db.inner());
session_store.migrate().await?;
let delete_task: ScopedTask<tower_sessions::session_store::Result<()>> =
tokio::task::spawn(
session_store
.clone()
.continuously_delete_expired(tokio::time::Duration::from_secs(60)),
)
.into();
let delete_task = AbortOnDropHandle::new(tokio::task::spawn(
session_store
.clone()
.continuously_delete_expired(tokio::time::Duration::from_secs(60)),
));
// Generate a cryptographic key to sign the session cookie.
let key = Key::generate();
@@ -298,10 +296,9 @@ impl RestfulServer {
app
};
let serve_task: ScopedTask<()> = tokio::spawn(async move {
let serve_task = AbortOnDropHandle::new(tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
})
.into();
}));
Ok((serve_task, delete_task))
}
+5 -6
View File
@@ -6,10 +6,10 @@ use axum::{
routing,
};
use axum_embed::ServeEmbed;
use easytier::common::scoped_task::ScopedTask;
use rust_embed::RustEmbed;
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tokio_util::task::AbortOnDropHandle;
/// Embed assets for web dashboard, build frontend first
#[derive(RustEmbed, Clone)]
@@ -59,7 +59,7 @@ pub fn build_router(api_host: Option<url::Url>) -> Router {
pub struct WebServer {
bind_addr: SocketAddr,
router: Router,
serve_task: Option<ScopedTask<()>>,
serve_task: Option<AbortOnDropHandle<()>>,
}
impl WebServer {
@@ -71,14 +71,13 @@ impl WebServer {
})
}
pub async fn start(self) -> Result<ScopedTask<()>, anyhow::Error> {
pub async fn start(self) -> Result<AbortOnDropHandle<()>, anyhow::Error> {
let listener = TcpListener::bind(self.bind_addr).await?;
let app = self.router;
let task = tokio::spawn(async move {
let task = AbortOnDropHandle::new(tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
})
.into();
}));
Ok(task)
}
-1
View File
@@ -25,7 +25,6 @@ pub mod log;
pub mod netns;
pub mod network;
pub mod os_info;
pub mod scoped_task;
pub mod stats_manager;
pub mod stun;
pub mod stun_codec_ext;
-119
View File
@@ -1,119 +0,0 @@
//! This crate provides a wrapper type of Tokio's JoinHandle: `ScopedTask`, which aborts the task when it's dropped.
//! `ScopedTask` can still be awaited to join the child-task, and abort-on-drop will still trigger while it is being awaited.
//!
//! For example, if task A spawned task B but is doing something else, and task B is waiting for task C to join,
//! aborting A will also abort both B and C.
use derive_more::{Deref, DerefMut, From};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::task::JoinHandle;
#[derive(Debug, From, Deref, DerefMut)]
pub struct ScopedTask<T>(JoinHandle<T>);
impl<T> Drop for ScopedTask<T> {
fn drop(&mut self) {
self.abort()
}
}
impl<T> Future for ScopedTask<T> {
type Output = <JoinHandle<T> as Future>::Output;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx)
}
}
#[cfg(test)]
mod tests {
use super::ScopedTask;
use futures_util::future::pending;
use std::sync::{Arc, RwLock};
use tokio::task::yield_now;
struct Sentry(Arc<RwLock<bool>>);
impl Drop for Sentry {
fn drop(&mut self) {
*self.0.write().unwrap() = true
}
}
#[tokio::test]
async fn drop_while_not_waiting_for_join() {
let dropped = Arc::new(RwLock::new(false));
let sentry = Sentry(dropped.clone());
let task = ScopedTask::from(tokio::spawn(async move {
let _sentry = sentry;
pending::<()>().await
}));
yield_now().await;
assert!(!*dropped.read().unwrap());
drop(task);
yield_now().await;
assert!(*dropped.read().unwrap());
}
#[tokio::test]
async fn drop_while_waiting_for_join() {
let dropped = Arc::new(RwLock::new(false));
let sentry = Sentry(dropped.clone());
let handle = tokio::spawn(async move {
ScopedTask::from(tokio::spawn(async move {
let _sentry = sentry;
pending::<()>().await
}))
.await
.unwrap()
});
yield_now().await;
assert!(!*dropped.read().unwrap());
handle.abort();
yield_now().await;
assert!(*dropped.read().unwrap());
}
#[tokio::test]
async fn no_drop_only_join() {
assert_eq!(
ScopedTask::from(tokio::spawn(async {
yield_now().await;
5
}))
.await
.unwrap(),
5
)
}
#[tokio::test]
async fn manually_abort_before_drop() {
let dropped = Arc::new(RwLock::new(false));
let sentry = Sentry(dropped.clone());
let task = ScopedTask::from(tokio::spawn(async move {
let _sentry = sentry;
pending::<()>().await
}));
yield_now().await;
assert!(!*dropped.read().unwrap());
task.abort();
yield_now().await;
assert!(*dropped.read().unwrap());
}
#[tokio::test]
async fn manually_abort_then_join() {
let dropped = Arc::new(RwLock::new(false));
let sentry = Sentry(dropped.clone());
let task = ScopedTask::from(tokio::spawn(async move {
let _sentry = sentry;
pending::<()>().await
}));
yield_now().await;
assert!(!*dropped.read().unwrap());
task.abort();
yield_now().await;
assert!(task.await.is_err());
}
}
+3 -4
View File
@@ -5,8 +5,7 @@ use std::fmt;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::time::interval;
use crate::common::scoped_task::ScopedTask;
use tokio_util::task::AbortOnDropHandle;
/// Predefined metric names for type safety
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
@@ -578,7 +577,7 @@ impl MetricSnapshot {
/// StatsManager manages global statistics with high performance counters
pub struct StatsManager {
counters: Arc<DashMap<MetricKey, Arc<MetricData>>>,
cleanup_task: ScopedTask<()>,
cleanup_task: AbortOnDropHandle<()>,
}
impl StatsManager {
@@ -611,7 +610,7 @@ impl StatsManager {
Self {
counters,
cleanup_task: cleanup_task.into(),
cleanup_task: AbortOnDropHandle::new(cleanup_task),
}
}
+5 -7
View File
@@ -1343,11 +1343,9 @@ impl StunInfoCollectorTrait for MockStunInfoCollector {
#[cfg(test)]
mod tests {
use crate::{
common::scoped_task::ScopedTask,
tunnel::{TunnelListener, udp::UdpTunnelListener},
};
use crate::tunnel::{TunnelListener, udp::UdpTunnelListener};
use tokio::time::{sleep, timeout};
use tokio_util::task::AbortOnDropHandle;
use super::*;
@@ -1441,7 +1439,7 @@ mod tests {
use stun_codec::rfc5389::attributes::XorMappedAddress;
use tokio::net::TcpListener;
async fn spawn_tcp_stun_server() -> (SocketAddr, ScopedTask<()>) {
async fn spawn_tcp_stun_server() -> (SocketAddr, AbortOnDropHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let server_addr = listener.local_addr().unwrap();
@@ -1465,7 +1463,7 @@ mod tests {
stream.write_all(rsp_buf.as_slice()).await.unwrap();
});
(server_addr, task.into())
(server_addr, AbortOnDropHandle::new(task))
}
let (server1, _t1) = spawn_tcp_stun_server().await;
@@ -1504,7 +1502,7 @@ mod tests {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let server_addr = listener.local_addr().unwrap();
let _t = ScopedTask::from(tokio::spawn(async move {
let _t = AbortOnDropHandle::new(tokio::spawn(async move {
for _ in 0..8 {
let Ok((mut stream, peer_addr)) = listener.accept().await else {
break;
+5 -5
View File
@@ -5,8 +5,8 @@ use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::sync::Notify;
use tokio::time;
use tokio_util::task::AbortOnDropHandle;
use crate::common::scoped_task::ScopedTask;
use crate::proto::common::LimiterConfig;
/// Token Bucket rate limiter using atomic operations
@@ -14,7 +14,7 @@ pub struct TokenBucket {
available_tokens: AtomicU64, // Current token count (atomic)
last_refill_time: AtomicU64, // Last refill time as micros since epoch
config: BucketConfig, // Immutable configuration
refill_task: Mutex<Option<ScopedTask<()>>>, // Background refill task
refill_task: Mutex<Option<AbortOnDropHandle<()>>>, // Background refill task
start_time: Instant, // Bucket creation time
refill_notifier: Arc<Notify>,
@@ -91,7 +91,7 @@ impl TokenBucket {
.refill_task
.lock()
.unwrap()
.replace(refill_task.into());
.replace(AbortOnDropHandle::new(refill_task));
arc_self
}
@@ -172,7 +172,7 @@ impl TokenBucket {
pub struct TokenBucketManager {
buckets: Arc<DashMap<String, Arc<TokenBucket>>>,
retain_task: ScopedTask<()>,
retain_task: AbortOnDropHandle<()>,
}
impl Default for TokenBucketManager {
@@ -205,7 +205,7 @@ impl TokenBucketManager {
Self {
buckets,
retain_task: retain_task.into(),
retain_task: AbortOnDropHandle::new(retain_task),
}
}
@@ -6,9 +6,10 @@ use std::{
use anyhow::Context;
use tokio::sync::Mutex;
use tokio_util::task::AbortOnDropHandle;
use crate::{
common::{PeerId, scoped_task::ScopedTask, stun::StunInfoCollectorTrait},
common::{PeerId, stun::StunInfoCollectorTrait},
connector::udp_hole_punch::common::{
HOLE_PUNCH_PACKET_BODY_LEN, UdpHolePunchListener, try_connect_with_socket,
},
@@ -32,7 +33,7 @@ const REMOTE_WAIT_TIME_MS: u64 = 5000;
pub(crate) struct PunchBothEasySymHoleServer {
common: Arc<PunchHoleServerCommon>,
task: Mutex<Option<ScopedTask<()>>>,
task: Mutex<Option<AbortOnDropHandle<()>>>,
}
impl PunchBothEasySymHoleServer {
@@ -161,7 +162,7 @@ impl PunchBothEasySymHoleServer {
}
});
*locked_task = Some(task.into());
*locked_task = Some(AbortOnDropHandle::new(task));
return Ok(SendPunchPacketBothEasySymResponse {
is_busy: false,
base_mapped_addr: Some(cur_mapped_addr.into()),
@@ -5,9 +5,10 @@ use std::{
use anyhow::Context;
use tokio::net::UdpSocket;
use tokio_util::task::AbortOnDropHandle;
use crate::{
common::{PeerId, scoped_task::ScopedTask, upnp},
common::{PeerId, upnp},
connector::udp_hole_punch::common::{
HOLE_PUNCH_PACKET_BODY_LEN, UdpSocketArray, try_connect_with_socket,
},
@@ -178,7 +179,7 @@ impl PunchConeHoleClient {
send_from_local().await?;
let scoped_punch_task: ScopedTask<()> = tokio::spawn(async move {
let punch_task = AbortOnDropHandle::new(tokio::spawn(async move {
if let Err(e) = rpc_stub
.send_punch_packet_cone(
BaseController {
@@ -198,8 +199,7 @@ impl PunchConeHoleClient {
{
tracing::error!(?e, "failed to call remote send punch packet");
}
})
.into();
}));
// server: will send some punching resps, total 10 packets.
// client: use the socket to create UdpTunnel with UdpTunnelConnector
@@ -208,7 +208,7 @@ impl PunchConeHoleClient {
while finish_time.is_none() || finish_time.as_ref().unwrap().elapsed().as_millis() < 1000 {
tokio::time::sleep(Duration::from_millis(200)).await;
if finish_time.is_none() && (*scoped_punch_task).is_finished() {
if finish_time.is_none() && punch_task.is_finished() {
finish_time = Some(Instant::now());
}
@@ -11,12 +11,11 @@ use std::{
use anyhow::Context;
use rand::{Rng, seq::SliceRandom};
use tokio::{net::UdpSocket, sync::RwLock};
use tokio_util::task::AbortOnDropHandle;
use tracing::Level;
use crate::{
common::{
PeerId, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask, stun::StunInfoCollectorTrait,
},
common::{PeerId, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait},
connector::udp_hole_punch::{
common::{
HOLE_PUNCH_PACKET_BODY_LEN, send_symmetric_hole_punch_packet, try_connect_with_socket,
@@ -360,7 +359,7 @@ impl PunchSymToConeHoleClient {
packet: &[u8],
tid: u32,
remote_mapped_addr: crate::proto::common::SocketAddr,
scoped_punch_task: &ScopedTask<T>,
punch_task: &AbortOnDropHandle<T>,
) -> Result<Option<Box<dyn Tunnel>>, anyhow::Error> {
// no matter what the result is, we should check if we received any hole punching packet
let mut ret_tunnel: Option<Box<dyn Tunnel>> = None;
@@ -372,7 +371,7 @@ impl PunchSymToConeHoleClient {
tokio::time::sleep(Duration::from_millis(200)).await;
if finish_time.is_none() && (*scoped_punch_task).is_finished() {
if finish_time.is_none() && punch_task.is_finished() {
finish_time = Some(Instant::now());
}
@@ -482,27 +481,27 @@ impl PunchSymToConeHoleClient {
if self.punch_predicablely.load(Ordering::Relaxed) && base_port_for_easy_sym.is_some() {
let rpc_stub = self.get_rpc_stub(dst_peer_id).await;
let scoped_punch_task: ScopedTask<()> =
tokio::spawn(Self::remote_send_hole_punch_packet_predicable(
let punch_task = AbortOnDropHandle::new(tokio::spawn(
Self::remote_send_hole_punch_packet_predicable(
rpc_stub,
base_port_for_easy_sym,
my_nat_info,
remote_mapped_addr,
public_ips.clone(),
tid,
))
.into();
),
));
let ret_tunnel = Self::check_hole_punch_result(
global_ctx.clone(),
&udp_array,
&packet,
tid,
remote_mapped_addr,
&scoped_punch_task,
&punch_task,
)
.await?;
let task_ret = scoped_punch_task.await;
let task_ret = punch_task.await;
tracing::debug!(?ret_tunnel, ?task_ret, "predictable punch task got result");
if let Some(tunnel) = ret_tunnel {
return Ok(Some(tunnel));
@@ -510,27 +509,26 @@ impl PunchSymToConeHoleClient {
}
let rpc_stub = self.get_rpc_stub(dst_peer_id).await;
let scoped_punch_task: ScopedTask<Option<u32>> =
tokio::spawn(Self::remote_send_hole_punch_packet_random(
let punch_task =
AbortOnDropHandle::new(tokio::spawn(Self::remote_send_hole_punch_packet_random(
rpc_stub,
remote_mapped_addr,
public_ips.clone(),
tid,
round,
port_index,
))
.into();
)));
let ret_tunnel = Self::check_hole_punch_result(
global_ctx,
&udp_array,
&packet,
tid,
remote_mapped_addr,
&scoped_punch_task,
&punch_task,
)
.await?;
let punch_task_result = scoped_punch_task.await;
let punch_task_result = punch_task.await;
tracing::debug!(?punch_task_result, ?ret_tunnel, "punch task got result");
if let Ok(Some(next_port_idx)) = punch_task_result {
@@ -644,7 +642,7 @@ pub mod tests {
#[tokio::test]
#[serial_test::serial(hole_punch)]
async fn hole_punching_symmetric_only_predict(#[values("true", "false")] is_inc: bool) {
use crate::common::scoped_task::ScopedTask;
use tokio_util::task::AbortOnDropHandle;
RUN_TESTING.store(true, std::sync::atomic::Ordering::Relaxed);
@@ -694,12 +692,12 @@ pub mod tests {
let counter = Arc::new(AtomicU32::new(0));
let mut tasks: Vec<ScopedTask<()>> = vec![];
let mut tasks: Vec<AbortOnDropHandle<()>> = vec![];
// all these sockets should receive hole punching packet
for udp in udps.iter().map(Arc::clone) {
let counter = counter.clone();
tasks.push(ScopedTask::from(tokio::spawn(async move {
tasks.push(AbortOnDropHandle::new(tokio::spawn(async move {
let mut buf = [0u8; 1024];
let (len, addr) = udp.recv_from(&mut buf).await.unwrap();
println!(
+4 -6
View File
@@ -12,14 +12,12 @@ use crossbeam::atomic::AtomicCell;
#[cfg(feature = "kcp")]
use kcp_sys::{endpoint::KcpEndpoint, stream::KcpStream};
use tokio_util::sync::{CancellationToken, DropGuard};
use tokio_util::task::AbortOnDropHandle;
#[cfg(feature = "kcp")]
use crate::gateway::kcp_proxy::NatDstKcpConnector;
use crate::{
common::{
config::PortForwardConfig, global_ctx::GlobalCtxEvent, join_joinset_background,
scoped_task::ScopedTask,
},
common::{config::PortForwardConfig, global_ctx::GlobalCtxEvent, join_joinset_background},
gateway::{
fast_socks5::{
server::{
@@ -473,7 +471,7 @@ pub struct Socks5Server {
entries: Socks5EntrySet,
udp_client_map: Arc<DashMap<UdpClientKey, Arc<UdpClientInfo>>>,
udp_forward_task: Arc<DashMap<UdpClientKey, ScopedTask<()>>>,
udp_forward_task: Arc<DashMap<UdpClientKey, AbortOnDropHandle<()>>>,
#[cfg(feature = "kcp")]
kcp_endpoint: Mutex<Option<Weak<KcpEndpoint>>>,
@@ -997,7 +995,7 @@ impl Socks5Server {
let client_addr = addr;
udp_forward_task.insert(
udp_client_key.clone(),
ScopedTask::from(tokio::spawn(async move {
AbortOnDropHandle::new(tokio::spawn(async move {
loop {
let mut buf = vec![0u8; 8192];
match socks_udp.recv_from(&mut buf).await {
+3 -4
View File
@@ -22,8 +22,7 @@ use smoltcp::{
pub use socket::{TcpListener, TcpStream, UdpSocket};
pub use socket_allocator::BufferSize;
use tokio::sync::Notify;
use crate::common::scoped_task::ScopedTask;
use tokio_util::task::AbortOnDropHandle;
/// The async devices.
pub mod channel_device;
@@ -79,7 +78,7 @@ pub struct Net {
ip_addr: IpCidr,
from_port: AtomicU16,
stopper: Arc<Notify>,
fut: ScopedTask<io::Result<()>>,
fut: AbortOnDropHandle<io::Result<()>>,
}
impl std::fmt::Debug for Net {
@@ -131,7 +130,7 @@ impl Net {
ip_addr: config.ip_addr,
from_port: AtomicU16::new(10001),
stopper,
fut: ScopedTask::from(tokio::spawn(fut)),
fut: AbortOnDropHandle::new(tokio::spawn(fut)),
}
}
pub fn get_address(&self) -> IpAddr {
+4 -3
View File
@@ -21,13 +21,14 @@ use tokio::{
task::{JoinHandle, JoinSet},
time::timeout,
};
use tokio_util::task::AbortOnDropHandle;
use tracing::Level;
use super::{CidrSet, ip_reassembler::IpReassembler};
use crate::tunnel::common::bind;
use crate::{
common::{PeerId, error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask},
common::{PeerId, error::Error, global_ctx::ArcGlobalCtx},
gateway::ip_reassembler::{ComposeIpv4PacketArgs, compose_ipv4_packet},
peers::{PeerPacketFilter, peer_manager::PeerManager},
tunnel::{
@@ -149,7 +150,7 @@ impl UdpNatEntry {
let (s, mut r) = channel(128);
let self_clone = self.clone();
let recv_task = ScopedTask::from(tokio::spawn(async move {
let recv_task = AbortOnDropHandle::new(tokio::spawn(async move {
let mut cur_buf = BytesMut::new();
loop {
if self_clone
@@ -194,7 +195,7 @@ impl UdpNatEntry {
}));
let self_clone = self.clone();
let send_task = ScopedTask::from(tokio::spawn(async move {
let send_task = AbortOnDropHandle::new(tokio::spawn(async move {
let mut ip_id = 1;
while let Some((mut packet, len, src_socket)) = r.recv().await {
let SocketAddr::V4(mut src_v4) = src_socket else {
+4 -4
View File
@@ -16,13 +16,13 @@ use tokio::sync::{Mutex, Notify};
use tokio::{sync::oneshot, task::JoinSet};
#[cfg(feature = "magic-dns")]
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use crate::common::PeerId;
use crate::common::acl_processor::AclRuleBuilder;
use crate::common::config::ConfigLoader;
use crate::common::error::Error;
use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent};
use crate::common::scoped_task::ScopedTask;
use crate::connector::direct::DirectConnectorManager;
use crate::connector::manual::{ConnectorManagerRpcService, ManualConnectorManager};
use crate::connector::tcp_hole_punch::TcpHolePunchConnector;
@@ -135,7 +135,7 @@ type NicCtx = super::virtual_nic::NicCtx;
#[cfg(feature = "magic-dns")]
struct MagicDnsContainer {
dns_runner_task: ScopedTask<()>,
dns_runner_task: AbortOnDropHandle<()>,
dns_runner_cancel_token: CancellationToken,
}
@@ -167,7 +167,7 @@ impl NicCtxContainer {
Self {
nic_ctx: Some(Box::new(nic_ctx)),
magic_dns: Some(MagicDnsContainer {
dns_runner_task: task.into(),
dns_runner_task: AbortOnDropHandle::new(task),
dns_runner_cancel_token: token,
}),
}
@@ -558,7 +558,7 @@ pub struct Instance {
#[cfg(feature = "socks5")]
socks5_server: Arc<Socks5Server>,
proxy_cidrs_monitor: Option<ScopedTask<()>>,
proxy_cidrs_monitor: Option<AbortOnDropHandle<()>>,
global_ctx: ArcGlobalCtx,
}
+3 -3
View File
@@ -3,8 +3,8 @@ use std::sync::{Arc, Weak};
use std::time::Instant;
use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtxEvent};
use crate::common::scoped_task::ScopedTask;
use crate::peers::peer_manager::PeerManager;
use tokio_util::task::AbortOnDropHandle;
/// ProxyCidrsMonitor monitors changes in proxy CIDRs from peer routes
/// and emits GlobalCtxEvent::ProxyCidrsUpdated with added/removed diffs.
@@ -58,8 +58,8 @@ impl ProxyCidrsMonitor {
}
/// Starts monitoring proxy_cidrs changes and emits events with diffs
pub fn start(self) -> ScopedTask<()> {
ScopedTask::from(tokio::spawn(async move {
pub fn start(self) -> AbortOnDropHandle<()> {
AbortOnDropHandle::new(tokio::spawn(async move {
let mut cur_proxy_cidrs = BTreeSet::new();
let mut last_update = None::<Instant>;
+4 -4
View File
@@ -1,13 +1,13 @@
use dashmap::DashMap;
use std::fmt::{Display, Formatter};
use std::{collections::BTreeMap, path::PathBuf, sync::Arc};
use tokio_util::task::AbortOnDropHandle;
use crate::{
common::{
config::{ConfigFileControl, ConfigLoader, ConfigSource, TomlConfigLoader},
global_ctx::{EventBusSubscriber, GlobalCtxEvent},
log,
scoped_task::ScopedTask,
},
launcher::{NetworkInstance, NetworkInstanceRunningInfo},
proto::{self},
@@ -27,7 +27,7 @@ impl Drop for DaemonGuard {
pub struct NetworkInstanceManager {
instance_map: Arc<DashMap<uuid::Uuid, NetworkInstance>>,
instance_stop_tasks: Arc<DashMap<uuid::Uuid, ScopedTask<()>>>,
instance_stop_tasks: Arc<DashMap<uuid::Uuid, AbortOnDropHandle<()>>>,
stop_check_notifier: Arc<tokio::sync::Notify>,
instance_error_messages: Arc<DashMap<uuid::Uuid, String>>,
config_dir: Option<PathBuf>,
@@ -78,12 +78,12 @@ impl NetworkInstanceManager {
let stop_check_notifier = self.stop_check_notifier.clone();
self.instance_stop_tasks.insert(
instance_id,
ScopedTask::from(tokio::spawn(async move {
AbortOnDropHandle::new(tokio::spawn(async move {
let Some(instance_stop_notifier) = instance_stop_notifier else {
return;
};
let _t = instance_event_receiver
.map(|event| ScopedTask::from(handle_event(instance_id, event)));
.map(|event| AbortOnDropHandle::new(handle_event(instance_id, event)));
instance_stop_notifier.notified().await;
if let Some(instance) = instance_map.get(&instance_id)
&& let Some(error) = instance.get_latest_error_msg()
+4 -5
View File
@@ -13,7 +13,6 @@ use pnet::packet::{
Packet as _, ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, udp::UdpPacket,
};
use crate::common::scoped_task::ScopedTask;
use crate::proto::acl::{AclStats, Protocol};
use crate::tunnel::packet_def::PacketType;
use crate::{
@@ -21,6 +20,7 @@ use crate::{
proto::acl::{Acl, Action, ChainType},
tunnel::packet_def::ZCPacket,
};
use tokio_util::task::AbortOnDropHandle;
#[derive(Debug, Eq, PartialEq, Hash)]
struct OutboundAllowRecord {
@@ -63,7 +63,7 @@ pub struct AclFilter {
// Track allowed outbound packets and automatically allow their corresponding inbound response
// packets, even if they would normally be dropped by ACL rules
outbound_allow_records: Arc<DashMap<OutboundAllowRecord, Instant>>,
clean_task: ScopedTask<()>,
clean_task: AbortOnDropHandle<()>,
}
impl Default for AclFilter {
@@ -80,14 +80,13 @@ impl AclFilter {
acl_processor: ArcSwap::from(Arc::new(AclProcessor::new(Acl::default()))),
acl_enabled: Arc::new(AtomicBool::new(false)),
outbound_allow_records,
clean_task: tokio::spawn(async move {
clean_task: AbortOnDropHandle::new(tokio::spawn(async move {
let max_life = std::time::Duration::from_secs(30);
loop {
record_clone.retain(|_, v| v.elapsed() < max_life);
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
}
})
.into(),
})),
}
}
+12 -14
View File
@@ -1,9 +1,10 @@
use std::sync::{Arc, Mutex};
use crate::{
common::{PeerId, error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask},
common::{PeerId, error::Error, global_ctx::ArcGlobalCtx},
tunnel::packet_def::ZCPacket,
};
use tokio_util::task::AbortOnDropHandle;
use super::{PacketRecvChan, peer_conn::PeerConn, peer_map::PeerMap, peer_rpc::PeerRpcManager};
@@ -13,7 +14,7 @@ pub struct ForeignNetworkClient {
my_peer_id: PeerId,
peer_map: Arc<PeerMap>,
task: Mutex<Option<ScopedTask<()>>>,
task: Mutex<Option<AbortOnDropHandle<()>>>,
}
impl ForeignNetworkClient {
@@ -82,18 +83,15 @@ impl ForeignNetworkClient {
pub async fn run(&self) {
let peer_map = Arc::downgrade(&self.peer_map);
*self.task.lock().unwrap() = Some(
tokio::spawn(async move {
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
let Some(peer_map) = peer_map.upgrade() else {
break;
};
peer_map.clean_peer_without_conn().await;
}
})
.into(),
);
*self.task.lock().unwrap() = Some(AbortOnDropHandle::new(tokio::spawn(async move {
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
let Some(peer_map) = peer_map.upgrade() else {
break;
};
peer_map.clean_peer_without_conn().await;
}
})));
}
pub fn get_peer_map(&self) -> Arc<PeerMap> {
+7 -10
View File
@@ -12,6 +12,7 @@ use super::{
PacketRecvChan,
peer_conn::{PeerConn, PeerConnId},
};
use crate::{common::shrink_dashmap, proto::api::instance::PeerConnInfo};
use crate::{
common::{
PeerId,
@@ -21,10 +22,7 @@ use crate::{
proto::peer_rpc::PeerIdentityType,
tunnel::packet_def::ZCPacket,
};
use crate::{
common::{scoped_task::ScopedTask, shrink_dashmap},
proto::api::instance::PeerConnInfo,
};
use tokio_util::task::AbortOnDropHandle;
type ArcPeerConn = Arc<PeerConn>;
type ConnMap = Arc<DashMap<PeerConnId, ArcPeerConn>>;
@@ -37,14 +35,14 @@ pub struct Peer {
packet_recv_chan: PacketRecvChan,
close_event_sender: mpsc::Sender<PeerConnId>,
close_event_listener: ScopedTask<()>,
close_event_listener: AbortOnDropHandle<()>,
shutdown_notifier: Arc<tokio::sync::Notify>,
default_conn_id: Arc<AtomicCell<PeerConnId>>,
peer_identity_type: Arc<AtomicCell<Option<PeerIdentityType>>>,
peer_public_key: Arc<RwLock<Option<Vec<u8>>>>,
default_conn_id_clear_task: ScopedTask<()>,
default_conn_id_clear_task: AbortOnDropHandle<()>,
}
impl Peer {
@@ -64,7 +62,7 @@ impl Peer {
let conns_copy = conns.clone();
let shutdown_notifier_copy = shutdown_notifier.clone();
let global_ctx_copy = global_ctx.clone();
let close_event_listener = tokio::spawn(
let close_event_listener = AbortOnDropHandle::new(tokio::spawn(
async move {
loop {
select! {
@@ -103,14 +101,13 @@ impl Peer {
"peer_close_event_listener",
?peer_node_id,
)),
)
.into();
));
let default_conn_id = Arc::new(AtomicCell::new(PeerConnId::default()));
let conns_copy = conns.clone();
let default_conn_id_copy = default_conn_id.clone();
let default_conn_id_clear_task = ScopedTask::from(tokio::spawn(async move {
let default_conn_id_clear_task = AbortOnDropHandle::new(tokio::spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
if conns_copy.len() > 1 {
+2 -2
View File
@@ -1606,7 +1606,6 @@ pub mod tests {
use crate::common::global_ctx::GlobalCtx;
use crate::common::global_ctx::tests::get_mock_global_ctx;
use crate::common::new_peer_id;
use crate::common::scoped_task::ScopedTask;
use crate::common::stats_manager::{LabelSet, LabelType, MetricName};
use crate::peers::create_packet_recv_chan;
use crate::peers::recv_packet_from_chan;
@@ -1614,6 +1613,7 @@ pub mod tests {
use crate::tunnel::filter::PacketRecorderTunnelFilter;
use crate::tunnel::filter::tests::DropSendTunnelFilter;
use crate::tunnel::ring::create_ring_tunnel_pair;
use tokio_util::task::AbortOnDropHandle;
pub fn set_secure_mode_cfg(global_ctx: &GlobalCtx, enabled: bool) {
if !enabled {
@@ -2200,7 +2200,7 @@ pub mod tests {
c_peer.start_recv_loop(create_packet_recv_chan().0).await;
let throughput = c_peer.throughput.clone();
let _t = ScopedTask::from(tokio::spawn(async move {
let _t = AbortOnDropHandle::new(tokio::spawn(async move {
// if not drop both, we mock some rx traffic for client peer to test pinger
if drop_both {
return;
+9 -8
View File
@@ -11,8 +11,8 @@ use tokio::select;
use tokio::sync::Notify;
use tokio::task::JoinHandle;
use crate::common::scoped_task::ScopedTask;
use anyhow::Error;
use tokio_util::task::AbortOnDropHandle;
use super::peer_manager::PeerManager;
@@ -72,7 +72,7 @@ pub trait PeerTaskLauncher: Send + Sync + Clone + 'static {
pub struct PeerTaskManager<Launcher: PeerTaskLauncher> {
launcher: Launcher,
main_loop_task: Mutex<Option<ScopedTask<()>>>,
main_loop_task: Mutex<Option<AbortOnDropHandle<()>>>,
run_signal: Arc<Notify>,
external_signal: Option<Arc<ExternalTaskSignal>>,
data: Launcher::Data,
@@ -105,13 +105,12 @@ where
}
pub fn start(&self) {
let task = tokio::spawn(Self::main_loop(
let task = AbortOnDropHandle::new(tokio::spawn(Self::main_loop(
self.launcher.clone(),
self.data.clone(),
self.run_signal.clone(),
self.external_signal.clone(),
))
.into();
)));
self.main_loop_task.lock().unwrap().replace(task);
}
@@ -121,7 +120,7 @@ where
signal: Arc<Notify>,
external_signal: Option<Arc<ExternalTaskSignal>>,
) {
let peer_task_map = Arc::new(DashMap::<C, ScopedTask<Result<T, Error>>>::new());
let peer_task_map = Arc::new(DashMap::<C, AbortOnDropHandle<Result<T, Error>>>::new());
let mut external_signal_version = external_signal.as_ref().map(|signal| signal.version());
loop {
@@ -158,8 +157,10 @@ where
}
tracing::debug!(?item, "launch hole punching task");
peer_task_map
.insert(item.clone(), launcher.launch_task(&data, item).await.into());
peer_task_map.insert(
item.clone(),
AbortOnDropHandle::new(launcher.launch_task(&data, item).await),
);
}
} else if peer_task_map.is_empty() {
launcher.all_task_done(&data).await;
+3 -4
View File
@@ -424,11 +424,11 @@ async fn standalone_rpc_test() {
#[tokio::test]
async fn test_bidirect_rpc_manager() {
use crate::common::scoped_task::ScopedTask;
use crate::proto::rpc_impl::bidirect::BidirectRpcManager;
use crate::tunnel::tcp::{TcpTunnelConnector, TcpTunnelListener};
use crate::tunnel::{TunnelConnector, TunnelListener};
use tokio::sync::Notify;
use tokio_util::task::AbortOnDropHandle;
let c = BidirectRpcManager::new();
let s = BidirectRpcManager::new();
@@ -448,7 +448,7 @@ async fn test_bidirect_rpc_manager() {
let server_test_done = Arc::new(Notify::new());
let server_test_done_clone = server_test_done.clone();
let mut tcp_listener = TcpTunnelListener::new("tcp://0.0.0.0:55443".parse().unwrap());
let s_task: ScopedTask<()> = tokio::spawn(async move {
let s_task = AbortOnDropHandle::new(tokio::spawn(async move {
tcp_listener.listen().await.unwrap();
let tunnel = tcp_listener.accept().await.unwrap();
s.run_with_tunnel(tunnel);
@@ -471,8 +471,7 @@ async fn test_bidirect_rpc_manager() {
server_test_done_clone.notify_one();
s.wait().await;
})
.into();
}));
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+3 -5
View File
@@ -779,10 +779,8 @@ pub async fn data_compress(
#[tokio::test]
#[serial_test::serial]
pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str) {
use crate::{
common::scoped_task::ScopedTask,
tunnel::wireguard::{WgConfig, WgTunnelConnector},
};
use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
use tokio_util::task::AbortOnDropHandle;
let insts = init_three_node(proto).await;
let mut inst4 = Instance::new(get_inst_config(
@@ -838,7 +836,7 @@ pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str
.await;
set_link_status("net_d", false);
let _t = ScopedTask::from(tokio::spawn(async move {
let _t = AbortOnDropHandle::new(tokio::spawn(async move {
// do some ping in net_a to trigger net_c pingpong
loop {
ping_test("net_a", "10.144.144.4", Some(1)).await;
+10 -14
View File
@@ -14,18 +14,16 @@ use std::{
};
use tokio::{io::AsyncReadExt, net::TcpStream, sync::Mutex};
use crate::{
common::scoped_task::ScopedTask,
tunnel::{
FromUrl, IpVersion, SinkError, SinkItem, StreamItem, Tunnel, TunnelConnector, TunnelError,
TunnelInfo, TunnelListener,
common::TunnelWrapper,
fake_tcp::netfilter::create_tun,
packet_def::{PEER_MANAGER_HEADER_SIZE, TCP_TUNNEL_HEADER_SIZE, ZCPacket, ZCPacketType},
},
use crate::tunnel::{
FromUrl, IpVersion, SinkError, SinkItem, StreamItem, Tunnel, TunnelConnector, TunnelError,
TunnelInfo, TunnelListener,
common::TunnelWrapper,
fake_tcp::netfilter::create_tun,
packet_def::{PEER_MANAGER_HEADER_SIZE, TCP_TUNNEL_HEADER_SIZE, ZCPacket, ZCPacketType},
};
use futures::Future;
use tokio_util::task::AbortOnDropHandle;
use dashmap::DashMap;
@@ -186,8 +184,8 @@ impl FakeTcpTunnelListener {
}
}
fn build_os_socket_reader_task(mut socket: TcpStream) -> ScopedTask<()> {
let os_socket_reader_task: ScopedTask<()> = tokio::spawn(async move {
fn build_os_socket_reader_task(mut socket: TcpStream) -> AbortOnDropHandle<()> {
AbortOnDropHandle::new(tokio::spawn(async move {
// read the os socket until it's closed
let mut buf = [0u8; 1024];
while let Ok(size) = socket.read(&mut buf).await {
@@ -197,9 +195,7 @@ fn build_os_socket_reader_task(mut socket: TcpStream) -> ScopedTask<()> {
}
}
tracing::info!("FakeTcpTunnelListener os socket closed");
})
.into();
os_socket_reader_task
}))
}
#[derive(Debug)]
+3 -4
View File
@@ -38,8 +38,6 @@
//! and [`server.rs`](https://github.com/dndx/phantun/blob/main/phantun/src/bin/server.rs) files
//! from the `phantun` crate for how to use this library in client/server mode, respectively.
use crate::common::scoped_task::ScopedTask;
use super::packet::*;
use bytes::{Bytes, BytesMut};
use crossbeam::atomic::AtomicCell;
@@ -55,6 +53,7 @@ use std::sync::{
};
use tokio::sync::broadcast;
use tokio::time;
use tokio_util::task::AbortOnDropHandle;
use tracing::{info, trace, warn};
const TIMEOUT: time::Duration = time::Duration::from_secs(1);
@@ -96,7 +95,7 @@ pub struct Stack {
local_ip: Ipv4Addr,
local_ip6: Option<Ipv6Addr>,
local_mac: MacAddr,
reader_task: ScopedTask<()>,
reader_task: AbortOnDropHandle<()>,
}
#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug)]
@@ -418,7 +417,7 @@ impl Stack {
local_ip,
local_ip6,
local_mac: local_mac.unwrap_or(MacAddr::zero()),
reader_task: t.into(),
reader_task: AbortOnDropHandle::new(t),
}
}
+4 -3
View File
@@ -5,11 +5,12 @@ use std::{pin::Pin, time::Duration};
use anyhow::Context;
use tokio::time::timeout;
use crate::{common::scoped_task::ScopedTask, proto::common::TunnelInfo};
use crate::proto::common::TunnelInfo;
use super::{Tunnel, TunnelError, ZCPacketSink, ZCPacketStream, packet_def::ZCPacket};
use tokio::sync::mpsc::{Receiver, Sender, channel, error::TrySendError};
use tokio_util::task::AbortOnDropHandle;
// use tachyonix::{channel, Receiver, Sender, TrySendError};
use futures::SinkExt;
@@ -37,7 +38,7 @@ pub struct MpscTunnel<T> {
tunnel: T,
stream: Option<Pin<Box<dyn ZCPacketStream>>>,
task: ScopedTask<()>,
task: AbortOnDropHandle<()>,
}
impl<T: Tunnel> MpscTunnel<T> {
@@ -61,7 +62,7 @@ impl<T: Tunnel> MpscTunnel<T> {
tx: Some(tx),
tunnel,
stream: Some(stream),
task: task.into(),
task: AbortOnDropHandle::new(task),
}
}
+5 -6
View File
@@ -17,6 +17,7 @@ use tokio::{
sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender},
task::JoinSet,
};
use tokio_util::task::AbortOnDropHandle;
use tracing::{Instrument, instrument};
use super::{
@@ -28,7 +29,7 @@ use super::{
};
use crate::tunnel::common::bind;
use crate::{
common::{join_joinset_background, scoped_task::ScopedTask, shrink_dashmap},
common::{join_joinset_background, shrink_dashmap},
tunnel::{
build_url_from_socket_addr,
common::{TunnelWrapper, reserve_buf},
@@ -339,7 +340,7 @@ struct UdpConnection {
dst_addr: SocketAddr,
ring_sender: RingSink,
forward_task: ScopedTask<()>,
forward_task: AbortOnDropHandle<()>,
}
impl UdpConnection {
@@ -352,15 +353,13 @@ impl UdpConnection {
close_event_sender: UdpCloseEventSender,
) -> Self {
let s = socket.clone();
let forward_task = tokio::spawn(async move {
let forward_task = AbortOnDropHandle::new(tokio::spawn(async move {
let close_event_sender = close_event_sender;
let err = forward_from_ring_to_udp(ring_recv, &s, &dst_addr, conn_id).await;
if let Err(e) = close_event_sender.send((dst_addr, err)) {
tracing::error!(?e, "udp send close event error");
}
})
.into();
}));
Self {
socket,
conn_id,
+4 -3
View File
@@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::{
common::{
config::TomlConfigLoader, global_ctx::GlobalCtx, log, os_info::collect_device_os_info,
scoped_task::ScopedTask, set_default_machine_id, stun::MockStunInfoCollector,
set_default_machine_id, stun::MockStunInfoCollector,
},
connector::create_connector_by_url,
instance_manager::{DaemonGuard, NetworkInstanceManager},
@@ -12,6 +12,7 @@ use crate::{
};
use anyhow::{Context as _, Result};
use async_trait::async_trait;
use tokio_util::task::AbortOnDropHandle;
use url::Url;
use uuid::Uuid;
@@ -43,7 +44,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
pub struct WebClient {
controller: Arc<controller::Controller>,
tasks: ScopedTask<()>,
tasks: AbortOnDropHandle<()>,
manager_guard: DaemonGuard,
connected: Arc<AtomicBool>,
}
@@ -70,7 +71,7 @@ impl WebClient {
let controller_clone = controller.clone();
let connected_clone = connected.clone();
let tasks = ScopedTask::from(tokio::spawn(async move {
let tasks = AbortOnDropHandle::new(tokio::spawn(async move {
Self::routine(
controller_clone,
connected_clone,