From 5a1668c7533b4d0064f105e53ad11263651f9101 Mon Sep 17 00:00:00 2001 From: Luna Yao <40349250+ZnqbuZ@users.noreply.github.com> Date: Sat, 25 Apr 2026 09:20:25 +0200 Subject: [PATCH] refactor: remove ScopedTask (#2125) * replace ScopedTask with AbortOnDropHandle --- Cargo.lock | 1 + .../easytier-uptime/src/health_checker.rs | 10 +- easytier-web/Cargo.toml | 1 + easytier-web/src/client_manager/session.rs | 21 ++-- easytier-web/src/restful/mod.rs | 23 ++-- easytier-web/src/web/mod.rs | 11 +- easytier/src/common/mod.rs | 1 - easytier/src/common/scoped_task.rs | 119 ------------------ easytier/src/common/stats_manager.rs | 7 +- easytier/src/common/stun.rs | 12 +- easytier/src/common/token_bucket.rs | 10 +- .../connector/udp_hole_punch/both_easy_sym.rs | 7 +- easytier/src/connector/udp_hole_punch/cone.rs | 10 +- .../connector/udp_hole_punch/sym_to_cone.rs | 38 +++--- easytier/src/gateway/socks5.rs | 10 +- easytier/src/gateway/tokio_smoltcp/mod.rs | 7 +- easytier/src/gateway/udp_proxy.rs | 7 +- easytier/src/instance/instance.rs | 8 +- easytier/src/instance/proxy_cidrs_monitor.rs | 6 +- easytier/src/instance_manager.rs | 8 +- easytier/src/peers/acl_filter.rs | 9 +- easytier/src/peers/foreign_network_client.rs | 26 ++-- easytier/src/peers/peer.rs | 17 ++- easytier/src/peers/peer_conn.rs | 4 +- easytier/src/peers/peer_task.rs | 17 +-- easytier/src/proto/tests.rs | 7 +- easytier/src/tests/three_node.rs | 8 +- easytier/src/tunnel/fake_tcp/mod.rs | 24 ++-- easytier/src/tunnel/fake_tcp/stack.rs | 7 +- easytier/src/tunnel/mpsc.rs | 7 +- easytier/src/tunnel/udp.rs | 11 +- easytier/src/web_client/mod.rs | 7 +- 32 files changed, 161 insertions(+), 300 deletions(-) delete mode 100644 easytier/src/common/scoped_task.rs diff --git a/Cargo.lock b/Cargo.lock index 156ccb62..0f5dd647 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2519,6 +2519,7 @@ dependencies = [ "thiserror 1.0.63", "thunk-rs", "tokio", + "tokio-util", "tower-http", "tower-sessions", "tower-sessions-sqlx-store", diff --git a/easytier-contrib/easytier-uptime/src/health_checker.rs b/easytier-contrib/easytier-uptime/src/health_checker.rs index 49198a30..9346e7ef 100644 --- a/easytier-contrib/easytier-uptime/src/health_checker.rs +++ b/easytier-contrib/easytier-uptime/src/health_checker.rs @@ -7,15 +7,15 @@ use std::{ use anyhow::Context as _; use dashmap::DashMap; use easytier::{ - common::{ - config::{ConfigFileControl, ConfigLoader, NetworkIdentity, PeerConfig, TomlConfigLoader}, - scoped_task::ScopedTask, + common::config::{ + ConfigFileControl, ConfigLoader, NetworkIdentity, PeerConfig, TomlConfigLoader, }, defer, instance_manager::NetworkInstanceManager, }; use serde::{Deserialize, Serialize}; use sqlx::any; +use tokio_util::task::AbortOnDropHandle; use tracing::{debug, error, info, instrument, warn}; use crate::db::{ @@ -240,7 +240,7 @@ pub struct HealthChecker { db: Db, instance_mgr: Arc, inst_id_map: DashMap, - node_tasks: DashMap>, + node_tasks: DashMap>, node_records: Arc>, node_cfg: Arc>, } @@ -465,7 +465,7 @@ impl HealthChecker { } // 启动健康检查任务 - let task = ScopedTask::from(tokio::spawn(Self::node_health_check_task( + let task = AbortOnDropHandle::new(tokio::spawn(Self::node_health_check_task( node_id, cfg.get_id(), Arc::clone(&self.instance_mgr), diff --git a/easytier-web/Cargo.toml b/easytier-web/Cargo.toml index 1c1d9b7d..d86e3181 100644 --- a/easytier-web/Cargo.toml +++ b/easytier-web/Cargo.toml @@ -10,6 +10,7 @@ tracing = { version = "0.1", features = ["log"] } anyhow = { version = "1.0" } thiserror = "1.0" tokio = { version = "1", features = ["full"] } +tokio-util = { version = "0.7", features = ["rt"] } dashmap = "6.1" url = "2.2" async-trait = "0.1" diff --git a/easytier-web/src/client_manager/session.rs b/easytier-web/src/client_manager/session.rs index 99b1daea..5fe39451 100644 --- a/easytier-web/src/client_manager/session.rs +++ b/easytier-web/src/client_manager/session.rs @@ -7,7 +7,7 @@ use std::{ use anyhow::Context; use easytier::{ - common::{config::ConfigSource, scoped_task::ScopedTask}, + common::config::ConfigSource, proto::{ api::manage::{ ConfigSource as RpcConfigSource, NetworkConfig, NetworkMeta, RunNetworkInstanceRequest, @@ -21,6 +21,7 @@ use easytier::{ tunnel::Tunnel, }; use tokio::sync::{RwLock, broadcast}; +use tokio_util::task::AbortOnDropHandle; use super::storage::{Storage, StorageToken, WeakRefStorage}; use crate::FeatureFlags; @@ -475,7 +476,7 @@ pub struct Session { data: SharedSessionData, - run_network_on_start_task: Option>, + run_network_on_start_task: Option>, } impl Debug for Session { @@ -517,14 +518,14 @@ impl Session { self.rpc_mgr.run_with_tunnel(tunnel); let data = self.data.read().await; - self.run_network_on_start_task.replace( - tokio::spawn(Self::run_network_on_start( - data.heartbeat_waiter(), - data.storage.clone(), - self.scoped_rpc_client(), - )) - .into(), - ); + self.run_network_on_start_task + .replace(AbortOnDropHandle::new(tokio::spawn( + Self::run_network_on_start( + data.heartbeat_waiter(), + data.storage.clone(), + self.scoped_rpc_client(), + ), + ))); } fn collect_webhook_source_instance_ids( diff --git a/easytier-web/src/restful/mod.rs b/easytier-web/src/restful/mod.rs index b78e9e00..91fc1996 100644 --- a/easytier-web/src/restful/mod.rs +++ b/easytier-web/src/restful/mod.rs @@ -17,12 +17,12 @@ use axum_login::tower_sessions::{ExpiredDeletion, SessionManagerLayer}; use axum_login::{AuthManagerLayerBuilder, AuthUser, AuthzBackend, login_required}; use axum_messages::MessagesManagerLayer; use easytier::common::config::{ConfigLoader, TomlConfigLoader}; -use easytier::common::scoped_task::ScopedTask; use easytier::launcher::NetworkConfig; use easytier::proto::rpc_types; use network::NetworkApi; use sea_orm::DbErr; use tokio::net::TcpListener; +use tokio_util::task::AbortOnDropHandle; use tower_sessions::Expiry; use tower_sessions::cookie::time::Duration; use tower_sessions::cookie::{Key, SameSite}; @@ -199,8 +199,8 @@ impl RestfulServer { mut self, ) -> Result< ( - ScopedTask<()>, - ScopedTask>, + AbortOnDropHandle<()>, + AbortOnDropHandle>, ), anyhow::Error, > { @@ -213,13 +213,11 @@ impl RestfulServer { let session_store = SqliteStore::new(self.db.inner()); session_store.migrate().await?; - let delete_task: ScopedTask> = - tokio::task::spawn( - session_store - .clone() - .continuously_delete_expired(tokio::time::Duration::from_secs(60)), - ) - .into(); + let delete_task = AbortOnDropHandle::new(tokio::task::spawn( + session_store + .clone() + .continuously_delete_expired(tokio::time::Duration::from_secs(60)), + )); // Generate a cryptographic key to sign the session cookie. let key = Key::generate(); @@ -298,10 +296,9 @@ impl RestfulServer { app }; - let serve_task: ScopedTask<()> = tokio::spawn(async move { + let serve_task = AbortOnDropHandle::new(tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); - }) - .into(); + })); Ok((serve_task, delete_task)) } diff --git a/easytier-web/src/web/mod.rs b/easytier-web/src/web/mod.rs index 934108ed..bd164b43 100644 --- a/easytier-web/src/web/mod.rs +++ b/easytier-web/src/web/mod.rs @@ -6,10 +6,10 @@ use axum::{ routing, }; use axum_embed::ServeEmbed; -use easytier::common::scoped_task::ScopedTask; use rust_embed::RustEmbed; use std::net::SocketAddr; use tokio::net::TcpListener; +use tokio_util::task::AbortOnDropHandle; /// Embed assets for web dashboard, build frontend first #[derive(RustEmbed, Clone)] @@ -59,7 +59,7 @@ pub fn build_router(api_host: Option) -> Router { pub struct WebServer { bind_addr: SocketAddr, router: Router, - serve_task: Option>, + serve_task: Option>, } impl WebServer { @@ -71,14 +71,13 @@ impl WebServer { }) } - pub async fn start(self) -> Result, anyhow::Error> { + pub async fn start(self) -> Result, anyhow::Error> { let listener = TcpListener::bind(self.bind_addr).await?; let app = self.router; - let task = tokio::spawn(async move { + let task = AbortOnDropHandle::new(tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); - }) - .into(); + })); Ok(task) } diff --git a/easytier/src/common/mod.rs b/easytier/src/common/mod.rs index 8d066dde..5e5e9a2e 100644 --- a/easytier/src/common/mod.rs +++ b/easytier/src/common/mod.rs @@ -25,7 +25,6 @@ pub mod log; pub mod netns; pub mod network; pub mod os_info; -pub mod scoped_task; pub mod stats_manager; pub mod stun; pub mod stun_codec_ext; diff --git a/easytier/src/common/scoped_task.rs b/easytier/src/common/scoped_task.rs deleted file mode 100644 index b87a1f19..00000000 --- a/easytier/src/common/scoped_task.rs +++ /dev/null @@ -1,119 +0,0 @@ -//! This crate provides a wrapper type of Tokio's JoinHandle: `ScopedTask`, which aborts the task when it's dropped. -//! `ScopedTask` can still be awaited to join the child-task, and abort-on-drop will still trigger while it is being awaited. -//! -//! For example, if task A spawned task B but is doing something else, and task B is waiting for task C to join, -//! aborting A will also abort both B and C. - -use derive_more::{Deref, DerefMut, From}; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio::task::JoinHandle; - -#[derive(Debug, From, Deref, DerefMut)] -pub struct ScopedTask(JoinHandle); - -impl Drop for ScopedTask { - fn drop(&mut self) { - self.abort() - } -} - -impl Future for ScopedTask { - type Output = as Future>::Output; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.0).poll(cx) - } -} - -#[cfg(test)] -mod tests { - use super::ScopedTask; - use futures_util::future::pending; - use std::sync::{Arc, RwLock}; - use tokio::task::yield_now; - - struct Sentry(Arc>); - impl Drop for Sentry { - fn drop(&mut self) { - *self.0.write().unwrap() = true - } - } - - #[tokio::test] - async fn drop_while_not_waiting_for_join() { - let dropped = Arc::new(RwLock::new(false)); - let sentry = Sentry(dropped.clone()); - let task = ScopedTask::from(tokio::spawn(async move { - let _sentry = sentry; - pending::<()>().await - })); - yield_now().await; - assert!(!*dropped.read().unwrap()); - drop(task); - yield_now().await; - assert!(*dropped.read().unwrap()); - } - - #[tokio::test] - async fn drop_while_waiting_for_join() { - let dropped = Arc::new(RwLock::new(false)); - let sentry = Sentry(dropped.clone()); - let handle = tokio::spawn(async move { - ScopedTask::from(tokio::spawn(async move { - let _sentry = sentry; - pending::<()>().await - })) - .await - .unwrap() - }); - yield_now().await; - assert!(!*dropped.read().unwrap()); - handle.abort(); - yield_now().await; - assert!(*dropped.read().unwrap()); - } - - #[tokio::test] - async fn no_drop_only_join() { - assert_eq!( - ScopedTask::from(tokio::spawn(async { - yield_now().await; - 5 - })) - .await - .unwrap(), - 5 - ) - } - - #[tokio::test] - async fn manually_abort_before_drop() { - let dropped = Arc::new(RwLock::new(false)); - let sentry = Sentry(dropped.clone()); - let task = ScopedTask::from(tokio::spawn(async move { - let _sentry = sentry; - pending::<()>().await - })); - yield_now().await; - assert!(!*dropped.read().unwrap()); - task.abort(); - yield_now().await; - assert!(*dropped.read().unwrap()); - } - - #[tokio::test] - async fn manually_abort_then_join() { - let dropped = Arc::new(RwLock::new(false)); - let sentry = Sentry(dropped.clone()); - let task = ScopedTask::from(tokio::spawn(async move { - let _sentry = sentry; - pending::<()>().await - })); - yield_now().await; - assert!(!*dropped.read().unwrap()); - task.abort(); - yield_now().await; - assert!(task.await.is_err()); - } -} diff --git a/easytier/src/common/stats_manager.rs b/easytier/src/common/stats_manager.rs index 01f018f1..b7314dd6 100644 --- a/easytier/src/common/stats_manager.rs +++ b/easytier/src/common/stats_manager.rs @@ -5,8 +5,7 @@ use std::fmt; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::time::interval; - -use crate::common::scoped_task::ScopedTask; +use tokio_util::task::AbortOnDropHandle; /// Predefined metric names for type safety #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -578,7 +577,7 @@ impl MetricSnapshot { /// StatsManager manages global statistics with high performance counters pub struct StatsManager { counters: Arc>>, - cleanup_task: ScopedTask<()>, + cleanup_task: AbortOnDropHandle<()>, } impl StatsManager { @@ -611,7 +610,7 @@ impl StatsManager { Self { counters, - cleanup_task: cleanup_task.into(), + cleanup_task: AbortOnDropHandle::new(cleanup_task), } } diff --git a/easytier/src/common/stun.rs b/easytier/src/common/stun.rs index a4a8a397..008a63f8 100644 --- a/easytier/src/common/stun.rs +++ b/easytier/src/common/stun.rs @@ -1343,11 +1343,9 @@ impl StunInfoCollectorTrait for MockStunInfoCollector { #[cfg(test)] mod tests { - use crate::{ - common::scoped_task::ScopedTask, - tunnel::{TunnelListener, udp::UdpTunnelListener}, - }; + use crate::tunnel::{TunnelListener, udp::UdpTunnelListener}; use tokio::time::{sleep, timeout}; + use tokio_util::task::AbortOnDropHandle; use super::*; @@ -1441,7 +1439,7 @@ mod tests { use stun_codec::rfc5389::attributes::XorMappedAddress; use tokio::net::TcpListener; - async fn spawn_tcp_stun_server() -> (SocketAddr, ScopedTask<()>) { + async fn spawn_tcp_stun_server() -> (SocketAddr, AbortOnDropHandle<()>) { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let server_addr = listener.local_addr().unwrap(); @@ -1465,7 +1463,7 @@ mod tests { stream.write_all(rsp_buf.as_slice()).await.unwrap(); }); - (server_addr, task.into()) + (server_addr, AbortOnDropHandle::new(task)) } let (server1, _t1) = spawn_tcp_stun_server().await; @@ -1504,7 +1502,7 @@ mod tests { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let server_addr = listener.local_addr().unwrap(); - let _t = ScopedTask::from(tokio::spawn(async move { + let _t = AbortOnDropHandle::new(tokio::spawn(async move { for _ in 0..8 { let Ok((mut stream, peer_addr)) = listener.accept().await else { break; diff --git a/easytier/src/common/token_bucket.rs b/easytier/src/common/token_bucket.rs index efffe030..ab1c9d4c 100644 --- a/easytier/src/common/token_bucket.rs +++ b/easytier/src/common/token_bucket.rs @@ -5,8 +5,8 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use tokio::sync::Notify; use tokio::time; +use tokio_util::task::AbortOnDropHandle; -use crate::common::scoped_task::ScopedTask; use crate::proto::common::LimiterConfig; /// Token Bucket rate limiter using atomic operations @@ -14,7 +14,7 @@ pub struct TokenBucket { available_tokens: AtomicU64, // Current token count (atomic) last_refill_time: AtomicU64, // Last refill time as micros since epoch config: BucketConfig, // Immutable configuration - refill_task: Mutex>>, // Background refill task + refill_task: Mutex>>, // Background refill task start_time: Instant, // Bucket creation time refill_notifier: Arc, @@ -91,7 +91,7 @@ impl TokenBucket { .refill_task .lock() .unwrap() - .replace(refill_task.into()); + .replace(AbortOnDropHandle::new(refill_task)); arc_self } @@ -172,7 +172,7 @@ impl TokenBucket { pub struct TokenBucketManager { buckets: Arc>>, - retain_task: ScopedTask<()>, + retain_task: AbortOnDropHandle<()>, } impl Default for TokenBucketManager { @@ -205,7 +205,7 @@ impl TokenBucketManager { Self { buckets, - retain_task: retain_task.into(), + retain_task: AbortOnDropHandle::new(retain_task), } } diff --git a/easytier/src/connector/udp_hole_punch/both_easy_sym.rs b/easytier/src/connector/udp_hole_punch/both_easy_sym.rs index 87f49538..03f30723 100644 --- a/easytier/src/connector/udp_hole_punch/both_easy_sym.rs +++ b/easytier/src/connector/udp_hole_punch/both_easy_sym.rs @@ -6,9 +6,10 @@ use std::{ use anyhow::Context; use tokio::sync::Mutex; +use tokio_util::task::AbortOnDropHandle; use crate::{ - common::{PeerId, scoped_task::ScopedTask, stun::StunInfoCollectorTrait}, + common::{PeerId, stun::StunInfoCollectorTrait}, connector::udp_hole_punch::common::{ HOLE_PUNCH_PACKET_BODY_LEN, UdpHolePunchListener, try_connect_with_socket, }, @@ -32,7 +33,7 @@ const REMOTE_WAIT_TIME_MS: u64 = 5000; pub(crate) struct PunchBothEasySymHoleServer { common: Arc, - task: Mutex>>, + task: Mutex>>, } impl PunchBothEasySymHoleServer { @@ -161,7 +162,7 @@ impl PunchBothEasySymHoleServer { } }); - *locked_task = Some(task.into()); + *locked_task = Some(AbortOnDropHandle::new(task)); return Ok(SendPunchPacketBothEasySymResponse { is_busy: false, base_mapped_addr: Some(cur_mapped_addr.into()), diff --git a/easytier/src/connector/udp_hole_punch/cone.rs b/easytier/src/connector/udp_hole_punch/cone.rs index fc36a322..bb738e09 100644 --- a/easytier/src/connector/udp_hole_punch/cone.rs +++ b/easytier/src/connector/udp_hole_punch/cone.rs @@ -5,9 +5,10 @@ use std::{ use anyhow::Context; use tokio::net::UdpSocket; +use tokio_util::task::AbortOnDropHandle; use crate::{ - common::{PeerId, scoped_task::ScopedTask, upnp}, + common::{PeerId, upnp}, connector::udp_hole_punch::common::{ HOLE_PUNCH_PACKET_BODY_LEN, UdpSocketArray, try_connect_with_socket, }, @@ -178,7 +179,7 @@ impl PunchConeHoleClient { send_from_local().await?; - let scoped_punch_task: ScopedTask<()> = tokio::spawn(async move { + let punch_task = AbortOnDropHandle::new(tokio::spawn(async move { if let Err(e) = rpc_stub .send_punch_packet_cone( BaseController { @@ -198,8 +199,7 @@ impl PunchConeHoleClient { { tracing::error!(?e, "failed to call remote send punch packet"); } - }) - .into(); + })); // server: will send some punching resps, total 10 packets. // client: use the socket to create UdpTunnel with UdpTunnelConnector @@ -208,7 +208,7 @@ impl PunchConeHoleClient { while finish_time.is_none() || finish_time.as_ref().unwrap().elapsed().as_millis() < 1000 { tokio::time::sleep(Duration::from_millis(200)).await; - if finish_time.is_none() && (*scoped_punch_task).is_finished() { + if finish_time.is_none() && punch_task.is_finished() { finish_time = Some(Instant::now()); } diff --git a/easytier/src/connector/udp_hole_punch/sym_to_cone.rs b/easytier/src/connector/udp_hole_punch/sym_to_cone.rs index f2475fa0..213664ec 100644 --- a/easytier/src/connector/udp_hole_punch/sym_to_cone.rs +++ b/easytier/src/connector/udp_hole_punch/sym_to_cone.rs @@ -11,12 +11,11 @@ use std::{ use anyhow::Context; use rand::{Rng, seq::SliceRandom}; use tokio::{net::UdpSocket, sync::RwLock}; +use tokio_util::task::AbortOnDropHandle; use tracing::Level; use crate::{ - common::{ - PeerId, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask, stun::StunInfoCollectorTrait, - }, + common::{PeerId, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait}, connector::udp_hole_punch::{ common::{ HOLE_PUNCH_PACKET_BODY_LEN, send_symmetric_hole_punch_packet, try_connect_with_socket, @@ -360,7 +359,7 @@ impl PunchSymToConeHoleClient { packet: &[u8], tid: u32, remote_mapped_addr: crate::proto::common::SocketAddr, - scoped_punch_task: &ScopedTask, + punch_task: &AbortOnDropHandle, ) -> Result>, anyhow::Error> { // no matter what the result is, we should check if we received any hole punching packet let mut ret_tunnel: Option> = None; @@ -372,7 +371,7 @@ impl PunchSymToConeHoleClient { tokio::time::sleep(Duration::from_millis(200)).await; - if finish_time.is_none() && (*scoped_punch_task).is_finished() { + if finish_time.is_none() && punch_task.is_finished() { finish_time = Some(Instant::now()); } @@ -482,27 +481,27 @@ impl PunchSymToConeHoleClient { if self.punch_predicablely.load(Ordering::Relaxed) && base_port_for_easy_sym.is_some() { let rpc_stub = self.get_rpc_stub(dst_peer_id).await; - let scoped_punch_task: ScopedTask<()> = - tokio::spawn(Self::remote_send_hole_punch_packet_predicable( + let punch_task = AbortOnDropHandle::new(tokio::spawn( + Self::remote_send_hole_punch_packet_predicable( rpc_stub, base_port_for_easy_sym, my_nat_info, remote_mapped_addr, public_ips.clone(), tid, - )) - .into(); + ), + )); let ret_tunnel = Self::check_hole_punch_result( global_ctx.clone(), &udp_array, &packet, tid, remote_mapped_addr, - &scoped_punch_task, + &punch_task, ) .await?; - let task_ret = scoped_punch_task.await; + let task_ret = punch_task.await; tracing::debug!(?ret_tunnel, ?task_ret, "predictable punch task got result"); if let Some(tunnel) = ret_tunnel { return Ok(Some(tunnel)); @@ -510,27 +509,26 @@ impl PunchSymToConeHoleClient { } let rpc_stub = self.get_rpc_stub(dst_peer_id).await; - let scoped_punch_task: ScopedTask> = - tokio::spawn(Self::remote_send_hole_punch_packet_random( + let punch_task = + AbortOnDropHandle::new(tokio::spawn(Self::remote_send_hole_punch_packet_random( rpc_stub, remote_mapped_addr, public_ips.clone(), tid, round, port_index, - )) - .into(); + ))); let ret_tunnel = Self::check_hole_punch_result( global_ctx, &udp_array, &packet, tid, remote_mapped_addr, - &scoped_punch_task, + &punch_task, ) .await?; - let punch_task_result = scoped_punch_task.await; + let punch_task_result = punch_task.await; tracing::debug!(?punch_task_result, ?ret_tunnel, "punch task got result"); if let Ok(Some(next_port_idx)) = punch_task_result { @@ -644,7 +642,7 @@ pub mod tests { #[tokio::test] #[serial_test::serial(hole_punch)] async fn hole_punching_symmetric_only_predict(#[values("true", "false")] is_inc: bool) { - use crate::common::scoped_task::ScopedTask; + use tokio_util::task::AbortOnDropHandle; RUN_TESTING.store(true, std::sync::atomic::Ordering::Relaxed); @@ -694,12 +692,12 @@ pub mod tests { let counter = Arc::new(AtomicU32::new(0)); - let mut tasks: Vec> = vec![]; + let mut tasks: Vec> = vec![]; // all these sockets should receive hole punching packet for udp in udps.iter().map(Arc::clone) { let counter = counter.clone(); - tasks.push(ScopedTask::from(tokio::spawn(async move { + tasks.push(AbortOnDropHandle::new(tokio::spawn(async move { let mut buf = [0u8; 1024]; let (len, addr) = udp.recv_from(&mut buf).await.unwrap(); println!( diff --git a/easytier/src/gateway/socks5.rs b/easytier/src/gateway/socks5.rs index 3bdceaf1..fcad40bf 100644 --- a/easytier/src/gateway/socks5.rs +++ b/easytier/src/gateway/socks5.rs @@ -12,14 +12,12 @@ use crossbeam::atomic::AtomicCell; #[cfg(feature = "kcp")] use kcp_sys::{endpoint::KcpEndpoint, stream::KcpStream}; use tokio_util::sync::{CancellationToken, DropGuard}; +use tokio_util::task::AbortOnDropHandle; #[cfg(feature = "kcp")] use crate::gateway::kcp_proxy::NatDstKcpConnector; use crate::{ - common::{ - config::PortForwardConfig, global_ctx::GlobalCtxEvent, join_joinset_background, - scoped_task::ScopedTask, - }, + common::{config::PortForwardConfig, global_ctx::GlobalCtxEvent, join_joinset_background}, gateway::{ fast_socks5::{ server::{ @@ -473,7 +471,7 @@ pub struct Socks5Server { entries: Socks5EntrySet, udp_client_map: Arc>>, - udp_forward_task: Arc>>, + udp_forward_task: Arc>>, #[cfg(feature = "kcp")] kcp_endpoint: Mutex>>, @@ -997,7 +995,7 @@ impl Socks5Server { let client_addr = addr; udp_forward_task.insert( udp_client_key.clone(), - ScopedTask::from(tokio::spawn(async move { + AbortOnDropHandle::new(tokio::spawn(async move { loop { let mut buf = vec![0u8; 8192]; match socks_udp.recv_from(&mut buf).await { diff --git a/easytier/src/gateway/tokio_smoltcp/mod.rs b/easytier/src/gateway/tokio_smoltcp/mod.rs index 302496fd..adb2d2f3 100644 --- a/easytier/src/gateway/tokio_smoltcp/mod.rs +++ b/easytier/src/gateway/tokio_smoltcp/mod.rs @@ -22,8 +22,7 @@ use smoltcp::{ pub use socket::{TcpListener, TcpStream, UdpSocket}; pub use socket_allocator::BufferSize; use tokio::sync::Notify; - -use crate::common::scoped_task::ScopedTask; +use tokio_util::task::AbortOnDropHandle; /// The async devices. pub mod channel_device; @@ -79,7 +78,7 @@ pub struct Net { ip_addr: IpCidr, from_port: AtomicU16, stopper: Arc, - fut: ScopedTask>, + fut: AbortOnDropHandle>, } impl std::fmt::Debug for Net { @@ -131,7 +130,7 @@ impl Net { ip_addr: config.ip_addr, from_port: AtomicU16::new(10001), stopper, - fut: ScopedTask::from(tokio::spawn(fut)), + fut: AbortOnDropHandle::new(tokio::spawn(fut)), } } pub fn get_address(&self) -> IpAddr { diff --git a/easytier/src/gateway/udp_proxy.rs b/easytier/src/gateway/udp_proxy.rs index 5da5cb09..45e71f30 100644 --- a/easytier/src/gateway/udp_proxy.rs +++ b/easytier/src/gateway/udp_proxy.rs @@ -21,13 +21,14 @@ use tokio::{ task::{JoinHandle, JoinSet}, time::timeout, }; +use tokio_util::task::AbortOnDropHandle; use tracing::Level; use super::{CidrSet, ip_reassembler::IpReassembler}; use crate::tunnel::common::bind; use crate::{ - common::{PeerId, error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask}, + common::{PeerId, error::Error, global_ctx::ArcGlobalCtx}, gateway::ip_reassembler::{ComposeIpv4PacketArgs, compose_ipv4_packet}, peers::{PeerPacketFilter, peer_manager::PeerManager}, tunnel::{ @@ -149,7 +150,7 @@ impl UdpNatEntry { let (s, mut r) = channel(128); let self_clone = self.clone(); - let recv_task = ScopedTask::from(tokio::spawn(async move { + let recv_task = AbortOnDropHandle::new(tokio::spawn(async move { let mut cur_buf = BytesMut::new(); loop { if self_clone @@ -194,7 +195,7 @@ impl UdpNatEntry { })); let self_clone = self.clone(); - let send_task = ScopedTask::from(tokio::spawn(async move { + let send_task = AbortOnDropHandle::new(tokio::spawn(async move { let mut ip_id = 1; while let Some((mut packet, len, src_socket)) = r.recv().await { let SocketAddr::V4(mut src_v4) = src_socket else { diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index 8d30f5fe..30ba9b50 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -16,13 +16,13 @@ use tokio::sync::{Mutex, Notify}; use tokio::{sync::oneshot, task::JoinSet}; #[cfg(feature = "magic-dns")] use tokio_util::sync::CancellationToken; +use tokio_util::task::AbortOnDropHandle; use crate::common::PeerId; use crate::common::acl_processor::AclRuleBuilder; use crate::common::config::ConfigLoader; use crate::common::error::Error; use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent}; -use crate::common::scoped_task::ScopedTask; use crate::connector::direct::DirectConnectorManager; use crate::connector::manual::{ConnectorManagerRpcService, ManualConnectorManager}; use crate::connector::tcp_hole_punch::TcpHolePunchConnector; @@ -135,7 +135,7 @@ type NicCtx = super::virtual_nic::NicCtx; #[cfg(feature = "magic-dns")] struct MagicDnsContainer { - dns_runner_task: ScopedTask<()>, + dns_runner_task: AbortOnDropHandle<()>, dns_runner_cancel_token: CancellationToken, } @@ -167,7 +167,7 @@ impl NicCtxContainer { Self { nic_ctx: Some(Box::new(nic_ctx)), magic_dns: Some(MagicDnsContainer { - dns_runner_task: task.into(), + dns_runner_task: AbortOnDropHandle::new(task), dns_runner_cancel_token: token, }), } @@ -558,7 +558,7 @@ pub struct Instance { #[cfg(feature = "socks5")] socks5_server: Arc, - proxy_cidrs_monitor: Option>, + proxy_cidrs_monitor: Option>, global_ctx: ArcGlobalCtx, } diff --git a/easytier/src/instance/proxy_cidrs_monitor.rs b/easytier/src/instance/proxy_cidrs_monitor.rs index 3ae5c3f5..2efbf915 100644 --- a/easytier/src/instance/proxy_cidrs_monitor.rs +++ b/easytier/src/instance/proxy_cidrs_monitor.rs @@ -3,8 +3,8 @@ use std::sync::{Arc, Weak}; use std::time::Instant; use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtxEvent}; -use crate::common::scoped_task::ScopedTask; use crate::peers::peer_manager::PeerManager; +use tokio_util::task::AbortOnDropHandle; /// ProxyCidrsMonitor monitors changes in proxy CIDRs from peer routes /// and emits GlobalCtxEvent::ProxyCidrsUpdated with added/removed diffs. @@ -58,8 +58,8 @@ impl ProxyCidrsMonitor { } /// Starts monitoring proxy_cidrs changes and emits events with diffs - pub fn start(self) -> ScopedTask<()> { - ScopedTask::from(tokio::spawn(async move { + pub fn start(self) -> AbortOnDropHandle<()> { + AbortOnDropHandle::new(tokio::spawn(async move { let mut cur_proxy_cidrs = BTreeSet::new(); let mut last_update = None::; diff --git a/easytier/src/instance_manager.rs b/easytier/src/instance_manager.rs index 26b257e3..2867c516 100644 --- a/easytier/src/instance_manager.rs +++ b/easytier/src/instance_manager.rs @@ -1,13 +1,13 @@ use dashmap::DashMap; use std::fmt::{Display, Formatter}; use std::{collections::BTreeMap, path::PathBuf, sync::Arc}; +use tokio_util::task::AbortOnDropHandle; use crate::{ common::{ config::{ConfigFileControl, ConfigLoader, ConfigSource, TomlConfigLoader}, global_ctx::{EventBusSubscriber, GlobalCtxEvent}, log, - scoped_task::ScopedTask, }, launcher::{NetworkInstance, NetworkInstanceRunningInfo}, proto::{self}, @@ -27,7 +27,7 @@ impl Drop for DaemonGuard { pub struct NetworkInstanceManager { instance_map: Arc>, - instance_stop_tasks: Arc>>, + instance_stop_tasks: Arc>>, stop_check_notifier: Arc, instance_error_messages: Arc>, config_dir: Option, @@ -78,12 +78,12 @@ impl NetworkInstanceManager { let stop_check_notifier = self.stop_check_notifier.clone(); self.instance_stop_tasks.insert( instance_id, - ScopedTask::from(tokio::spawn(async move { + AbortOnDropHandle::new(tokio::spawn(async move { let Some(instance_stop_notifier) = instance_stop_notifier else { return; }; let _t = instance_event_receiver - .map(|event| ScopedTask::from(handle_event(instance_id, event))); + .map(|event| AbortOnDropHandle::new(handle_event(instance_id, event))); instance_stop_notifier.notified().await; if let Some(instance) = instance_map.get(&instance_id) && let Some(error) = instance.get_latest_error_msg() diff --git a/easytier/src/peers/acl_filter.rs b/easytier/src/peers/acl_filter.rs index b4271d63..b611a864 100644 --- a/easytier/src/peers/acl_filter.rs +++ b/easytier/src/peers/acl_filter.rs @@ -13,7 +13,6 @@ use pnet::packet::{ Packet as _, ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, udp::UdpPacket, }; -use crate::common::scoped_task::ScopedTask; use crate::proto::acl::{AclStats, Protocol}; use crate::tunnel::packet_def::PacketType; use crate::{ @@ -21,6 +20,7 @@ use crate::{ proto::acl::{Acl, Action, ChainType}, tunnel::packet_def::ZCPacket, }; +use tokio_util::task::AbortOnDropHandle; #[derive(Debug, Eq, PartialEq, Hash)] struct OutboundAllowRecord { @@ -63,7 +63,7 @@ pub struct AclFilter { // Track allowed outbound packets and automatically allow their corresponding inbound response // packets, even if they would normally be dropped by ACL rules outbound_allow_records: Arc>, - clean_task: ScopedTask<()>, + clean_task: AbortOnDropHandle<()>, } impl Default for AclFilter { @@ -80,14 +80,13 @@ impl AclFilter { acl_processor: ArcSwap::from(Arc::new(AclProcessor::new(Acl::default()))), acl_enabled: Arc::new(AtomicBool::new(false)), outbound_allow_records, - clean_task: tokio::spawn(async move { + clean_task: AbortOnDropHandle::new(tokio::spawn(async move { let max_life = std::time::Duration::from_secs(30); loop { record_clone.retain(|_, v| v.elapsed() < max_life); tokio::time::sleep(std::time::Duration::from_secs(30)).await; } - }) - .into(), + })), } } diff --git a/easytier/src/peers/foreign_network_client.rs b/easytier/src/peers/foreign_network_client.rs index 9670c15e..4c8d6765 100644 --- a/easytier/src/peers/foreign_network_client.rs +++ b/easytier/src/peers/foreign_network_client.rs @@ -1,9 +1,10 @@ use std::sync::{Arc, Mutex}; use crate::{ - common::{PeerId, error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask}, + common::{PeerId, error::Error, global_ctx::ArcGlobalCtx}, tunnel::packet_def::ZCPacket, }; +use tokio_util::task::AbortOnDropHandle; use super::{PacketRecvChan, peer_conn::PeerConn, peer_map::PeerMap, peer_rpc::PeerRpcManager}; @@ -13,7 +14,7 @@ pub struct ForeignNetworkClient { my_peer_id: PeerId, peer_map: Arc, - task: Mutex>>, + task: Mutex>>, } impl ForeignNetworkClient { @@ -82,18 +83,15 @@ impl ForeignNetworkClient { pub async fn run(&self) { let peer_map = Arc::downgrade(&self.peer_map); - *self.task.lock().unwrap() = Some( - tokio::spawn(async move { - loop { - tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - let Some(peer_map) = peer_map.upgrade() else { - break; - }; - peer_map.clean_peer_without_conn().await; - } - }) - .into(), - ); + *self.task.lock().unwrap() = Some(AbortOnDropHandle::new(tokio::spawn(async move { + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + let Some(peer_map) = peer_map.upgrade() else { + break; + }; + peer_map.clean_peer_without_conn().await; + } + }))); } pub fn get_peer_map(&self) -> Arc { diff --git a/easytier/src/peers/peer.rs b/easytier/src/peers/peer.rs index 4c3b55ff..63af93f8 100644 --- a/easytier/src/peers/peer.rs +++ b/easytier/src/peers/peer.rs @@ -12,6 +12,7 @@ use super::{ PacketRecvChan, peer_conn::{PeerConn, PeerConnId}, }; +use crate::{common::shrink_dashmap, proto::api::instance::PeerConnInfo}; use crate::{ common::{ PeerId, @@ -21,10 +22,7 @@ use crate::{ proto::peer_rpc::PeerIdentityType, tunnel::packet_def::ZCPacket, }; -use crate::{ - common::{scoped_task::ScopedTask, shrink_dashmap}, - proto::api::instance::PeerConnInfo, -}; +use tokio_util::task::AbortOnDropHandle; type ArcPeerConn = Arc; type ConnMap = Arc>; @@ -37,14 +35,14 @@ pub struct Peer { packet_recv_chan: PacketRecvChan, close_event_sender: mpsc::Sender, - close_event_listener: ScopedTask<()>, + close_event_listener: AbortOnDropHandle<()>, shutdown_notifier: Arc, default_conn_id: Arc>, peer_identity_type: Arc>>, peer_public_key: Arc>>>, - default_conn_id_clear_task: ScopedTask<()>, + default_conn_id_clear_task: AbortOnDropHandle<()>, } impl Peer { @@ -64,7 +62,7 @@ impl Peer { let conns_copy = conns.clone(); let shutdown_notifier_copy = shutdown_notifier.clone(); let global_ctx_copy = global_ctx.clone(); - let close_event_listener = tokio::spawn( + let close_event_listener = AbortOnDropHandle::new(tokio::spawn( async move { loop { select! { @@ -103,14 +101,13 @@ impl Peer { "peer_close_event_listener", ?peer_node_id, )), - ) - .into(); + )); let default_conn_id = Arc::new(AtomicCell::new(PeerConnId::default())); let conns_copy = conns.clone(); let default_conn_id_copy = default_conn_id.clone(); - let default_conn_id_clear_task = ScopedTask::from(tokio::spawn(async move { + let default_conn_id_clear_task = AbortOnDropHandle::new(tokio::spawn(async move { loop { tokio::time::sleep(std::time::Duration::from_secs(5)).await; if conns_copy.len() > 1 { diff --git a/easytier/src/peers/peer_conn.rs b/easytier/src/peers/peer_conn.rs index 4aafb1b9..5b273eff 100644 --- a/easytier/src/peers/peer_conn.rs +++ b/easytier/src/peers/peer_conn.rs @@ -1606,7 +1606,6 @@ pub mod tests { use crate::common::global_ctx::GlobalCtx; use crate::common::global_ctx::tests::get_mock_global_ctx; use crate::common::new_peer_id; - use crate::common::scoped_task::ScopedTask; use crate::common::stats_manager::{LabelSet, LabelType, MetricName}; use crate::peers::create_packet_recv_chan; use crate::peers::recv_packet_from_chan; @@ -1614,6 +1613,7 @@ pub mod tests { use crate::tunnel::filter::PacketRecorderTunnelFilter; use crate::tunnel::filter::tests::DropSendTunnelFilter; use crate::tunnel::ring::create_ring_tunnel_pair; + use tokio_util::task::AbortOnDropHandle; pub fn set_secure_mode_cfg(global_ctx: &GlobalCtx, enabled: bool) { if !enabled { @@ -2200,7 +2200,7 @@ pub mod tests { c_peer.start_recv_loop(create_packet_recv_chan().0).await; let throughput = c_peer.throughput.clone(); - let _t = ScopedTask::from(tokio::spawn(async move { + let _t = AbortOnDropHandle::new(tokio::spawn(async move { // if not drop both, we mock some rx traffic for client peer to test pinger if drop_both { return; diff --git a/easytier/src/peers/peer_task.rs b/easytier/src/peers/peer_task.rs index 72371819..47fd7fcb 100644 --- a/easytier/src/peers/peer_task.rs +++ b/easytier/src/peers/peer_task.rs @@ -11,8 +11,8 @@ use tokio::select; use tokio::sync::Notify; use tokio::task::JoinHandle; -use crate::common::scoped_task::ScopedTask; use anyhow::Error; +use tokio_util::task::AbortOnDropHandle; use super::peer_manager::PeerManager; @@ -72,7 +72,7 @@ pub trait PeerTaskLauncher: Send + Sync + Clone + 'static { pub struct PeerTaskManager { launcher: Launcher, - main_loop_task: Mutex>>, + main_loop_task: Mutex>>, run_signal: Arc, external_signal: Option>, data: Launcher::Data, @@ -105,13 +105,12 @@ where } pub fn start(&self) { - let task = tokio::spawn(Self::main_loop( + let task = AbortOnDropHandle::new(tokio::spawn(Self::main_loop( self.launcher.clone(), self.data.clone(), self.run_signal.clone(), self.external_signal.clone(), - )) - .into(); + ))); self.main_loop_task.lock().unwrap().replace(task); } @@ -121,7 +120,7 @@ where signal: Arc, external_signal: Option>, ) { - let peer_task_map = Arc::new(DashMap::>>::new()); + let peer_task_map = Arc::new(DashMap::>>::new()); let mut external_signal_version = external_signal.as_ref().map(|signal| signal.version()); loop { @@ -158,8 +157,10 @@ where } tracing::debug!(?item, "launch hole punching task"); - peer_task_map - .insert(item.clone(), launcher.launch_task(&data, item).await.into()); + peer_task_map.insert( + item.clone(), + AbortOnDropHandle::new(launcher.launch_task(&data, item).await), + ); } } else if peer_task_map.is_empty() { launcher.all_task_done(&data).await; diff --git a/easytier/src/proto/tests.rs b/easytier/src/proto/tests.rs index ee32bea3..89258a77 100644 --- a/easytier/src/proto/tests.rs +++ b/easytier/src/proto/tests.rs @@ -424,11 +424,11 @@ async fn standalone_rpc_test() { #[tokio::test] async fn test_bidirect_rpc_manager() { - use crate::common::scoped_task::ScopedTask; use crate::proto::rpc_impl::bidirect::BidirectRpcManager; use crate::tunnel::tcp::{TcpTunnelConnector, TcpTunnelListener}; use crate::tunnel::{TunnelConnector, TunnelListener}; use tokio::sync::Notify; + use tokio_util::task::AbortOnDropHandle; let c = BidirectRpcManager::new(); let s = BidirectRpcManager::new(); @@ -448,7 +448,7 @@ async fn test_bidirect_rpc_manager() { let server_test_done = Arc::new(Notify::new()); let server_test_done_clone = server_test_done.clone(); let mut tcp_listener = TcpTunnelListener::new("tcp://0.0.0.0:55443".parse().unwrap()); - let s_task: ScopedTask<()> = tokio::spawn(async move { + let s_task = AbortOnDropHandle::new(tokio::spawn(async move { tcp_listener.listen().await.unwrap(); let tunnel = tcp_listener.accept().await.unwrap(); s.run_with_tunnel(tunnel); @@ -471,8 +471,7 @@ async fn test_bidirect_rpc_manager() { server_test_done_clone.notify_one(); s.wait().await; - }) - .into(); + })); tokio::time::sleep(std::time::Duration::from_secs(1)).await; diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index 5832771f..0b12c46b 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -779,10 +779,8 @@ pub async fn data_compress( #[tokio::test] #[serial_test::serial] pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str) { - use crate::{ - common::scoped_task::ScopedTask, - tunnel::wireguard::{WgConfig, WgTunnelConnector}, - }; + use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector}; + use tokio_util::task::AbortOnDropHandle; let insts = init_three_node(proto).await; let mut inst4 = Instance::new(get_inst_config( @@ -838,7 +836,7 @@ pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str .await; set_link_status("net_d", false); - let _t = ScopedTask::from(tokio::spawn(async move { + let _t = AbortOnDropHandle::new(tokio::spawn(async move { // do some ping in net_a to trigger net_c pingpong loop { ping_test("net_a", "10.144.144.4", Some(1)).await; diff --git a/easytier/src/tunnel/fake_tcp/mod.rs b/easytier/src/tunnel/fake_tcp/mod.rs index c255ab54..895f6633 100644 --- a/easytier/src/tunnel/fake_tcp/mod.rs +++ b/easytier/src/tunnel/fake_tcp/mod.rs @@ -14,18 +14,16 @@ use std::{ }; use tokio::{io::AsyncReadExt, net::TcpStream, sync::Mutex}; -use crate::{ - common::scoped_task::ScopedTask, - tunnel::{ - FromUrl, IpVersion, SinkError, SinkItem, StreamItem, Tunnel, TunnelConnector, TunnelError, - TunnelInfo, TunnelListener, - common::TunnelWrapper, - fake_tcp::netfilter::create_tun, - packet_def::{PEER_MANAGER_HEADER_SIZE, TCP_TUNNEL_HEADER_SIZE, ZCPacket, ZCPacketType}, - }, +use crate::tunnel::{ + FromUrl, IpVersion, SinkError, SinkItem, StreamItem, Tunnel, TunnelConnector, TunnelError, + TunnelInfo, TunnelListener, + common::TunnelWrapper, + fake_tcp::netfilter::create_tun, + packet_def::{PEER_MANAGER_HEADER_SIZE, TCP_TUNNEL_HEADER_SIZE, ZCPacket, ZCPacketType}, }; use futures::Future; +use tokio_util::task::AbortOnDropHandle; use dashmap::DashMap; @@ -186,8 +184,8 @@ impl FakeTcpTunnelListener { } } -fn build_os_socket_reader_task(mut socket: TcpStream) -> ScopedTask<()> { - let os_socket_reader_task: ScopedTask<()> = tokio::spawn(async move { +fn build_os_socket_reader_task(mut socket: TcpStream) -> AbortOnDropHandle<()> { + AbortOnDropHandle::new(tokio::spawn(async move { // read the os socket until it's closed let mut buf = [0u8; 1024]; while let Ok(size) = socket.read(&mut buf).await { @@ -197,9 +195,7 @@ fn build_os_socket_reader_task(mut socket: TcpStream) -> ScopedTask<()> { } } tracing::info!("FakeTcpTunnelListener os socket closed"); - }) - .into(); - os_socket_reader_task + })) } #[derive(Debug)] diff --git a/easytier/src/tunnel/fake_tcp/stack.rs b/easytier/src/tunnel/fake_tcp/stack.rs index b3c311f8..e9696524 100644 --- a/easytier/src/tunnel/fake_tcp/stack.rs +++ b/easytier/src/tunnel/fake_tcp/stack.rs @@ -38,8 +38,6 @@ //! and [`server.rs`](https://github.com/dndx/phantun/blob/main/phantun/src/bin/server.rs) files //! from the `phantun` crate for how to use this library in client/server mode, respectively. -use crate::common::scoped_task::ScopedTask; - use super::packet::*; use bytes::{Bytes, BytesMut}; use crossbeam::atomic::AtomicCell; @@ -55,6 +53,7 @@ use std::sync::{ }; use tokio::sync::broadcast; use tokio::time; +use tokio_util::task::AbortOnDropHandle; use tracing::{info, trace, warn}; const TIMEOUT: time::Duration = time::Duration::from_secs(1); @@ -96,7 +95,7 @@ pub struct Stack { local_ip: Ipv4Addr, local_ip6: Option, local_mac: MacAddr, - reader_task: ScopedTask<()>, + reader_task: AbortOnDropHandle<()>, } #[derive(Hash, Eq, PartialEq, Clone, Copy, Debug)] @@ -418,7 +417,7 @@ impl Stack { local_ip, local_ip6, local_mac: local_mac.unwrap_or(MacAddr::zero()), - reader_task: t.into(), + reader_task: AbortOnDropHandle::new(t), } } diff --git a/easytier/src/tunnel/mpsc.rs b/easytier/src/tunnel/mpsc.rs index 77e77257..e15231ae 100644 --- a/easytier/src/tunnel/mpsc.rs +++ b/easytier/src/tunnel/mpsc.rs @@ -5,11 +5,12 @@ use std::{pin::Pin, time::Duration}; use anyhow::Context; use tokio::time::timeout; -use crate::{common::scoped_task::ScopedTask, proto::common::TunnelInfo}; +use crate::proto::common::TunnelInfo; use super::{Tunnel, TunnelError, ZCPacketSink, ZCPacketStream, packet_def::ZCPacket}; use tokio::sync::mpsc::{Receiver, Sender, channel, error::TrySendError}; +use tokio_util::task::AbortOnDropHandle; // use tachyonix::{channel, Receiver, Sender, TrySendError}; use futures::SinkExt; @@ -37,7 +38,7 @@ pub struct MpscTunnel { tunnel: T, stream: Option>>, - task: ScopedTask<()>, + task: AbortOnDropHandle<()>, } impl MpscTunnel { @@ -61,7 +62,7 @@ impl MpscTunnel { tx: Some(tx), tunnel, stream: Some(stream), - task: task.into(), + task: AbortOnDropHandle::new(task), } } diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index 7bd2ecf9..c8a4c0ed 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -17,6 +17,7 @@ use tokio::{ sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender}, task::JoinSet, }; +use tokio_util::task::AbortOnDropHandle; use tracing::{Instrument, instrument}; use super::{ @@ -28,7 +29,7 @@ use super::{ }; use crate::tunnel::common::bind; use crate::{ - common::{join_joinset_background, scoped_task::ScopedTask, shrink_dashmap}, + common::{join_joinset_background, shrink_dashmap}, tunnel::{ build_url_from_socket_addr, common::{TunnelWrapper, reserve_buf}, @@ -339,7 +340,7 @@ struct UdpConnection { dst_addr: SocketAddr, ring_sender: RingSink, - forward_task: ScopedTask<()>, + forward_task: AbortOnDropHandle<()>, } impl UdpConnection { @@ -352,15 +353,13 @@ impl UdpConnection { close_event_sender: UdpCloseEventSender, ) -> Self { let s = socket.clone(); - let forward_task = tokio::spawn(async move { + let forward_task = AbortOnDropHandle::new(tokio::spawn(async move { let close_event_sender = close_event_sender; let err = forward_from_ring_to_udp(ring_recv, &s, &dst_addr, conn_id).await; if let Err(e) = close_event_sender.send((dst_addr, err)) { tracing::error!(?e, "udp send close event error"); } - }) - .into(); - + })); Self { socket, conn_id, diff --git a/easytier/src/web_client/mod.rs b/easytier/src/web_client/mod.rs index ee478e1c..8a2ff6be 100644 --- a/easytier/src/web_client/mod.rs +++ b/easytier/src/web_client/mod.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use crate::{ common::{ config::TomlConfigLoader, global_ctx::GlobalCtx, log, os_info::collect_device_os_info, - scoped_task::ScopedTask, set_default_machine_id, stun::MockStunInfoCollector, + set_default_machine_id, stun::MockStunInfoCollector, }, connector::create_connector_by_url, instance_manager::{DaemonGuard, NetworkInstanceManager}, @@ -12,6 +12,7 @@ use crate::{ }; use anyhow::{Context as _, Result}; use async_trait::async_trait; +use tokio_util::task::AbortOnDropHandle; use url::Url; use uuid::Uuid; @@ -43,7 +44,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; pub struct WebClient { controller: Arc, - tasks: ScopedTask<()>, + tasks: AbortOnDropHandle<()>, manager_guard: DaemonGuard, connected: Arc, } @@ -70,7 +71,7 @@ impl WebClient { let controller_clone = controller.clone(); let connected_clone = connected.clone(); - let tasks = ScopedTask::from(tokio::spawn(async move { + let tasks = AbortOnDropHandle::new(tokio::spawn(async move { Self::routine( controller_clone, connected_clone,