From eb3b5aae5151eaf466740ba67c68f144f39f880e Mon Sep 17 00:00:00 2001 From: Luna Yao <40349250+ZnqbuZ@users.noreply.github.com> Date: Sat, 25 Apr 2026 12:24:36 +0200 Subject: [PATCH] utils: add DetachableTask & ContextGuard (#2138) --- easytier/src/common/defer.rs | 26 -- easytier/src/common/mod.rs | 1 - easytier/src/peers/peer_conn.rs | 27 +- easytier/src/utils/guard.rs | 638 ++++++++++++++++++++++++++++++++ easytier/src/utils/mod.rs | 1 + easytier/src/utils/task.rs | 287 ++++++++++++++ 6 files changed, 939 insertions(+), 41 deletions(-) delete mode 100644 easytier/src/common/defer.rs create mode 100644 easytier/src/utils/guard.rs diff --git a/easytier/src/common/defer.rs b/easytier/src/common/defer.rs deleted file mode 100644 index c243d1ce..00000000 --- a/easytier/src/common/defer.rs +++ /dev/null @@ -1,26 +0,0 @@ -#[doc(hidden)] -pub struct Defer { - // internal struct used by defer! macro - func: Option, -} - -impl Defer { - pub fn new(func: F) -> Self { - Self { func: Some(func) } - } -} - -impl Drop for Defer { - fn drop(&mut self) { - if let Some(f) = self.func.take() { - f() - } - } -} - -#[macro_export] -macro_rules! defer { - ( $($tt:tt)* ) => { - let _deferred = $crate::common::defer::Defer::new(|| { $($tt)* }); - }; -} diff --git a/easytier/src/common/mod.rs b/easytier/src/common/mod.rs index 5e5e9a2e..f1c03b11 100644 --- a/easytier/src/common/mod.rs +++ b/easytier/src/common/mod.rs @@ -14,7 +14,6 @@ pub mod acl_processor; pub mod compressor; pub mod config; pub mod constants; -pub mod defer; pub mod dns; pub mod env_parser; pub mod error; diff --git a/easytier/src/peers/peer_conn.rs b/easytier/src/peers/peer_conn.rs index 5b273eff..9849edc7 100644 --- a/easytier/src/peers/peer_conn.rs +++ b/easytier/src/peers/peer_conn.rs @@ -1,3 +1,5 @@ +use crossbeam::atomic::AtomicCell; +use futures::{StreamExt, TryFutureExt}; use std::{ any::Any, fmt::Debug, @@ -8,9 +10,6 @@ use std::{ }, }; -use crossbeam::atomic::AtomicCell; -use futures::{StreamExt, TryFutureExt}; - use base64::Engine as _; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; use hmac::Mac; @@ -27,14 +26,21 @@ use zerocopy::AsBytes; use snow::{HandshakeState, params::NoiseParams}; +use super::{ + PacketRecvChan, + peer_conn_ping::PeerConnPinger, + peer_session::{PeerSession, PeerSessionAction}, + traffic_metrics::AggregateTrafficMetrics, +}; +use crate::utils::BoxExt; use crate::{ common::{ PeerId, config::{NetworkIdentity, NetworkSecretDigest}, - defer, error::Error, global_ctx::ArcGlobalCtx, }, + guard, peers::peer_session::{PeerSessionStore, SessionKey, UpsertResponderSessionReturn}, proto::{ api::instance::{PeerConnInfo, PeerConnStats}, @@ -54,13 +60,6 @@ use crate::{ use_global_var, }; -use super::{ - PacketRecvChan, - peer_conn_ping::PeerConnPinger, - peer_session::{PeerSession, PeerSessionAction}, - traffic_metrics::AggregateTrafficMetrics, -}; - pub type PeerConnId = uuid::Uuid; const MAGIC: u32 = 0xd1e1a5e1; @@ -381,9 +380,9 @@ impl PeerConn { session_filter, noise_handshake_result: None, - tunnel: Arc::new(Mutex::new(Box::new(defer::Defer::new(move || { - mpsc_tunnel.close() - })))), + tunnel: Arc::new(Mutex::new( + guard!([mut mpsc_tunnel] mpsc_tunnel.close()).boxed(), + )), sink, recv: Mutex::new(Some(recv)), tunnel_info, diff --git a/easytier/src/utils/guard.rs b/easytier/src/utils/guard.rs new file mode 100644 index 00000000..eca86824 --- /dev/null +++ b/easytier/src/utils/guard.rs @@ -0,0 +1,638 @@ +//! # Guard Module Utilities +//! +//! This module provides mechanisms for scope-based resource management and deferred execution. +//! +//! ### ⚠️ Critical Usage Note: Diverging Expressions +//! +//! Do not use "naked" diverging expressions—such as `panic!`, `todo!`, or `loop {}`—as +//! the sole content of sync guard closure. This prevents the compiler from +//! distinguishing between synchronous (`ASYNC = false`) and asynchronous +//! (`ASYNC = true`) implementations, leading to a type inference error (E0277). +//! +//! ### Technical Context +//! +//! The `!` (Never Type) is a bottom type that can be coerced into any other type. +//! Because it satisfies both the `()` requirement for sync guards and the `Future` +//! requirement for async guards, the compiler encounters an inference deadlock. +//! +//! ### Workaround +//! +//! For macros like `guard!` or `guarded!`, force the closure to resolve to `()` +//! by explicitly setting the guard to `sync`: +//! +//! ```rust +//! let _g = guard!([val] sync { +//! panic!("critical failure"); +//! }); +//! ``` + +use crate::utils::task::{DetachableTask, TaskSpawner}; +use std::fmt::Debug; +use std::mem::ManuallyDrop; +use std::ops::{Deref, DerefMut}; + +pub trait CallableGuard { + type Output; + fn call(self, context: Context) -> Self::Output; +} + +impl CallableGuard for Guard +where + Guard: FnOnce(Context), +{ + type Output = (); + + fn call(self, context: Context) -> Self::Output { + self(context) + } +} + +impl CallableGuard for Guard +where + Guard: FnOnce(Context) -> Task + Send + 'static, + Task: Future + Send + 'static, + _R: Send + 'static, +{ + type Output = DetachableTask, Task>; + + fn call(self, context: Context) -> Self::Output { + DetachableTask::new(self(context)) + } +} + +pub struct ContextGuard> { + context: ManuallyDrop, + guard: ManuallyDrop, +} + +impl> Deref + for ContextGuard +{ + type Target = Context; + + fn deref(&self) -> &Self::Target { + &self.context + } +} + +impl> DerefMut + for ContextGuard +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.context + } +} + +impl> Debug + for ContextGuard +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let name = if ASYNC { + "ContextGuard::Async" + } else { + "ContextGuard::Sync" + }; + f.debug_struct(name) + .field("context", &self.context) + .finish_non_exhaustive() + } +} + +impl> + ContextGuard +{ + /// Creates a new `ContextGuard`. + /// + /// **Note on generics:** The seemingly unused `_R` generic parameter and the + /// `Guard: FnOnce(Context) -> _R` trait bound are intentionally included. + /// They act as a hint to help the compiler infer closure types. + pub fn new<_R>(context: Context, guard: Guard) -> Self + where + Guard: FnOnce(Context) -> _R, + { + ContextGuard { + context: ManuallyDrop::new(context), + guard: ManuallyDrop::new(guard), + } + } +} + +impl> + ContextGuard +{ + unsafe fn call(&mut self) -> Guard::Output { + unsafe { + let context = ManuallyDrop::take(&mut self.context); + let guard = ManuallyDrop::take(&mut self.guard); + + guard.call(context) + } + } + + pub fn trigger(self) -> Guard::Output { + let mut this = ManuallyDrop::new(self); + unsafe { this.call() } + } + + pub fn defuse(self) -> Context { + let mut this = ManuallyDrop::new(self); + unsafe { + ManuallyDrop::drop(&mut this.guard); + ManuallyDrop::take(&mut this.context) + } + } +} + +impl> Drop + for ContextGuard +{ + fn drop(&mut self) { + let _: Guard::Output = unsafe { self.call() }; + } +} + +// region macro + +#[doc(hidden)] +#[macro_export] +macro_rules! __guarded { + (@parse@action $guard:ident => $($tt:tt)*) => { + $crate::__guarded! { @parse@async action: [ @stmt $guard ] ; $($tt)* } + }; + + (@parse@action $($tt:tt)*) => { + $crate::__guarded! { @parse@async action: [ @stmt __guard ] ; $($tt)* } + }; + + (@parse@async action: [ $($action:tt)* ] ; sync $($tt:tt)*) => { + $crate::__guarded! { @parse@move action: [ $($action)* ] ; async: [ false ] ; $($tt)* } + }; + + (@parse@async action: [ $($action:tt)* ] ; $($tt:tt)*) => { + $crate::__guarded! { @parse@move action: [ $($action)* ] ; async: [ _ ] ; $($tt)* } + }; + + (@parse@move action: [ $($action:tt)* ] ; async: [ $async:tt ] ; move $($tt:tt)*) => { + $crate::__guarded! { @parse action: [ $($action)* ] ; async: [ $async ] ; move: [ move ] ; $($tt)* } + }; + + (@parse@move action: [ $($action:tt)* ] ; async: [ $async:tt ] ; $($tt:tt)*) => { + $crate::__guarded! { @parse action: [ $($action)* ] ; async: [ $async ] ; move: [] ; $($tt)* } + }; + + ( + @parse action: [ $($action:tt)* ] ; async: [ $async:tt ] ; move: [ $($move:tt)? ] ; + [ $($args:tt)* ] $body:block + ) => { + $crate::__guarded! { + action: [ $($action)* ] + async: [ $async ] + move: [ $($move)? ] + mut: [] + rest: [ $($args)* , ] + args: [] + vars: [] + body: [ $body ] + } + }; + + ( + @parse action: [ $($action:tt)* ] ; async: [ $async:tt ] ; move: [ $($move:tt)? ] ; + $body:block + ) => { + $crate::__guarded! { + @parse action: [ $($action)* ] ; async: [ $async ] ; move: [ $($move)? ] ; + [] $body + } + }; + + ( + @parse action: [ $($action:tt)* ] ; async: [ $async:tt ] ; move: [ $($move:tt)? ] ; + [ $($args:tt)* ] $($body:tt)* + ) => { + $crate::__guarded! { + @parse action: [ $($action)* ] ; async: [ $async ] ; move: [ $($move)? ] ; + [ $($args)* ] { $($body)* } + } + }; + + ( + @parse action: [ $($action:tt)* ] ; async: [ $async:tt ] ; move: [ $($move:tt)? ] ; + $($body:tt)* + ) => { + $crate::__guarded! { + @parse action: [ $($action)* ] ; async: [ $async ] ; move: [ $($move)? ] ; + [] { $($body)* } + } + }; + + ( + action: [ $($action:tt)* ] + async: [ $async:tt ] + move: [ $($move:tt)? ] + mut: [ $($mut:tt)? ] + rest: [ mut $arg:ident , $($rest:tt)* ] + args: [ $($args:ident)* ] + vars: [ $($vars:tt)* ] + body: [ $body:expr ] + ) => { + $crate::__guarded! { + action: [ $($action)* ] + async: [ $async ] + move: [ $($move)? ] + mut: [ mut ] + rest: [ $($rest)* ] + args: [ $($args)* $arg ] + vars: [ $($vars)* [mut $arg] ] + body: [ $body ] + } + }; + + ( + action: [ $($action:tt)* ] + async: [ $async:tt ] + move: [ $($move:tt)? ] + mut: [ $($mut:tt)? ] + rest: [ $arg:ident , $($rest:tt)* ] + args: [ $($args:ident)* ] + vars: [ $($vars:tt)* ] + body: [ $body:expr ] + ) => { + $crate::__guarded! { + action: [ $($action)* ] + async: [ $async ] + move: [ $($move)? ] + mut: [ $($mut)? ] + rest: [ $($rest)* ] + args: [ $($args)* $arg ] + vars: [ $($vars)* [$arg] ] + body: [ $body ] + } + }; + + ( + action: [ @stmt $guard:ident ] + async: [ $async:tt ] + move: [ $($move:tt)? ] + mut: [ $($mut:tt)? ] + rest: [ $(,)* ] + args: [ $($args:ident)* ] + vars: [ $([$($vars:tt)*])* ] + body: [ $body:expr ] + ) => { + let $($mut)? $guard = $crate::utils::guard::ContextGuard::<$async, _, _>::new( + ( $($args),* ), + $($move)? |#[allow(unused_parens, unused_mut)] ( $($($vars)*),* )| $body + ); + + #[allow(unused_parens, unused_variables, clippy::toplevel_ref_arg)] + let ( $(ref $($vars)*),* ) = *$guard; + }; + + ( + action: [ @expr ] + async: [ $async:tt ] + move: [ $($move:tt)? ] + mut: [ $($mut:tt)? ] + rest: [ $(,)* ] + args: [ $($args:ident)* ] + vars: [ $([$($vars:tt)*])* ] + body: [ $body:expr ] + ) => { + $crate::utils::guard::ContextGuard::<$async, _, _>::new( + ( $($args),* ), + $($move)? |#[allow(unused_parens)] ( $($($vars)*),* )| $body + ) + }; +} + +/// Creates a [`ContextGuard`] object, binding it to a variable with the specified name (e.g., `_guard`). +/// Context variables specified in the macro invocation are available within and after the guard body. +/// +/// **Note:** For usage with `panic!` or `loop`, see the [module-level documentation](self) +/// regarding type inference deadlocks. +#[macro_export] +macro_rules! guarded { + ( $($tt:tt)* ) => { + $crate::__guarded! { @parse@action $($tt)* } + }; +} + +/// Creates a [`ContextGuard`] object, without binding it to a variable. +/// Context variables specified in the macro invocation are available within the guard body. +/// +/// **Note:** For usage with `panic!` or `loop`, see the [module-level documentation](self) +/// regarding type inference deadlocks. +#[macro_export] +macro_rules! guard { + ( $($tt:tt)* ) => { + $crate::__guarded! { @parse@async action: [ @expr ] ; $($tt)* } + }; +} + +// endregion + +/// Alias for [`guarded!`]. +/// +/// **Note:** For usage with `panic!` or `loop`, see the [module-level documentation](self) +/// regarding type inference deadlocks. +#[macro_export] +macro_rules! defer { + ( $($tt:tt)* ) => { + $crate::guarded! { $($tt)* } + }; +} + +#[cfg(test)] +mod tests { + use std::panic::catch_unwind; + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Duration; + use tokio::sync::oneshot; + + #[test] + fn trigger_sync_executes_once() { + let called = Arc::new(AtomicUsize::new(0)); + let observed = Arc::new(AtomicUsize::new(0)); + + let value = 7usize; + let guard = { + let called = called.clone(); + let observed = observed.clone(); + crate::guard!(move [value] { + called.fetch_add(1, Ordering::SeqCst); + observed.store(value, Ordering::SeqCst); + }) + }; + + guard.trigger(); + + assert_eq!(called.load(Ordering::SeqCst), 1); + assert_eq!(observed.load(Ordering::SeqCst), 7); + } + + #[test] + fn defuse_sync_returns_context_without_running_guard() { + let called = Arc::new(AtomicUsize::new(0)); + + let value = String::from("hello"); + let guard = { + let called = called.clone(); + crate::guard!(move [mut value] { + value.push_str(" world"); + called.fetch_add(1, Ordering::SeqCst); + }) + }; + + let context = guard.defuse(); + assert_eq!(context, "hello"); + assert_eq!(called.load(Ordering::SeqCst), 0); + } + + #[test] + fn drop_sync_triggers_guard() { + let called = Arc::new(AtomicUsize::new(0)); + + { + let called = called.clone(); + crate::guarded!([called] { + called.fetch_add(1, Ordering::SeqCst); + }); + } + + assert_eq!(called.load(Ordering::SeqCst), 1); + } + + #[test] + fn drop_propagates_guard_panic() { + let dropped = catch_unwind(|| { + guarded! { + sync { + panic!("boom"); + } + } + }); + + assert!(dropped.is_err()); + } + + #[tokio::test] + async fn trigger_async_returns_runnable_task() { + let called = Arc::new(AtomicUsize::new(0)); + + let value = 5usize; + let guard = { + let called = called.clone(); + crate::guard!(move [value] async move { + called.fetch_add(value, Ordering::SeqCst); + }) + }; + let task = guard.trigger(); + task.await; + + assert_eq!(called.load(Ordering::SeqCst), 5); + } + + #[tokio::test] + async fn drop_async_detaches_task() { + let (tx, rx) = oneshot::channel(); + + { + let mut tx = Some(tx); + let value = 9usize; + let _guard = crate::guard!(move [value] { + let tx = tx.take(); + async move { + if let Some(tx) = tx { + let _ = tx.send(value); + } + } + }); + } + + let value = tokio::time::timeout(Duration::from_secs(1), rx) + .await + .expect("detached task should run") + .expect("detached task should send value"); + assert_eq!(value, 9); + } + + #[tokio::test] + async fn defuse_async_does_not_execute() { + let called = Arc::new(AtomicUsize::new(0)); + + let value = 11usize; + let guard = { + let called = called.clone(); + crate::guard!(move [value] async move { + called.fetch_add(value, Ordering::SeqCst); + }) + }; + + let context = guard.defuse(); + assert_eq!(context, 11); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!(called.load(Ordering::SeqCst), 0); + } + + #[test] + fn guarded_named_mut_binding_updates_context_before_drop() { + let committed = Arc::new(AtomicUsize::new(0)); + + { + let value = 1usize; + let step = 2usize; + let committed = committed.clone(); + + crate::guarded!(scope_guard => [mut value, step] { + committed.store(value + step, Ordering::SeqCst); + }); + + *value += 10; + assert_eq!(*value, 11); + assert_eq!(*step, 2); + + drop(scope_guard); + } + + assert_eq!(committed.load(Ordering::SeqCst), 13); + } + + #[test] + fn guard_expression_parses_without_braces() { + let observed = Arc::new(AtomicUsize::new(0)); + + let value = 3usize; + let observed_clone = observed.clone(); + let guard = crate::guard!([value] observed_clone.store(value, Ordering::SeqCst)); + guard.trigger(); + + assert_eq!(observed.load(Ordering::SeqCst), 3); + } + + #[test] + fn defer_alias_behaves_like_guarded_statement() { + let called = Arc::new(AtomicUsize::new(0)); + + { + let n = 42usize; + let called = called.clone(); + crate::defer!([n] { + called.store(n, Ordering::SeqCst); + }); + } + + assert_eq!(called.load(Ordering::SeqCst), 42); + } + + #[tokio::test] + async fn guard_and_guarded_macro_usage_matrix() { + // 1) guard!: block body + trailing comma args + trigger() + let sink = Arc::new(AtomicUsize::new(0)); + let v = 1usize; + let sink_clone = sink.clone(); + let g1 = crate::guard!([v,] { + sink_clone.store(v, Ordering::SeqCst); + }); + g1.trigger(); + assert_eq!(sink.load(Ordering::SeqCst), 1); + + // 2) guard!: expression body (no braces) + let sink = Arc::new(AtomicUsize::new(0)); + let sink_clone = sink.clone(); + let v = 2usize; + let g2 = crate::guard!([v] sink_clone.store(v, Ordering::SeqCst)); + g2.trigger(); + assert_eq!(sink.load(Ordering::SeqCst), 2); + + // 3) guard!: explicit sync + no args form + let sink = Arc::new(AtomicUsize::new(0)); + let sink_clone = sink.clone(); + let g3 = crate::guard!(sync { + sink_clone.store(3, Ordering::SeqCst); + }); + g3.trigger(); + assert_eq!(sink.load(Ordering::SeqCst), 3); + + // 4) guard!: move capture + defuse() prevents execution + let sink = Arc::new(AtomicUsize::new(0)); + let owned = String::from("owned"); + let sink_clone = sink.clone(); + let g4 = crate::guard!(move [owned] { + if owned == "owned" { + sink_clone.store(4, Ordering::SeqCst); + } + }); + let context = g4.defuse(); + assert_eq!(context, "owned"); + assert_eq!(sink.load(Ordering::SeqCst), 0); + + // 5) guard!: async block inference + trigger() returns task + let sink = Arc::new(AtomicUsize::new(0)); + let sink_clone = sink.clone(); + let n = 5usize; + let g5 = crate::guard!([n] async move { + sink_clone.fetch_add(n, Ordering::SeqCst); + }); + g5.trigger().await; + assert_eq!(sink.load(Ordering::SeqCst), 5); + + // 6) guarded!: named binding + mut arg visible outside + explicit drop + let sink = Arc::new(AtomicUsize::new(0)); + { + let value = 6usize; + let delta = 1usize; + let sink_clone = sink.clone(); + + crate::guarded!(named => [mut value, delta] { + sink_clone.store(value + delta, Ordering::SeqCst); + }); + + *value += 10; + assert_eq!(*value, 16); + assert_eq!(*delta, 1); + drop(named); + } + assert_eq!(sink.load(Ordering::SeqCst), 17); + + // 7) guarded!: unnamed statement + expression body + implicit drop at scope end + let sink = Arc::new(AtomicUsize::new(0)); + { + let n = 7usize; + let sink_clone = sink.clone(); + crate::guarded!([n] sink_clone.store(n, Ordering::SeqCst)); + } + assert_eq!(sink.load(Ordering::SeqCst), 7); + + // 8) guarded!: explicit sync + panic path propagates on drop + let dropped = catch_unwind(|| { + guarded! { + sync { + panic!("matrix-boom"); + } + } + }); + assert!(dropped.is_err()); + + // 9) guarded!: async inference on drop detaches and executes + let (tx, rx) = oneshot::channel(); + { + let tx = Some(tx); + crate::guarded!([mut tx] { + let tx = tx.take(); + async move { + if let Some(tx) = tx { + let _ = tx.send(9usize); + } + } + }); + } + let detached = tokio::time::timeout(Duration::from_secs(1), rx) + .await + .expect("detached task should complete") + .expect("detached task should send value"); + assert_eq!(detached, 9); + } +} diff --git a/easytier/src/utils/mod.rs b/easytier/src/utils/mod.rs index 1339c10d..192dd9b9 100644 --- a/easytier/src/utils/mod.rs +++ b/easytier/src/utils/mod.rs @@ -1,3 +1,4 @@ +pub mod guard; pub mod panic; pub mod string; pub mod task; diff --git a/easytier/src/utils/task.rs b/easytier/src/utils/task.rs index a9635ee5..dcad7933 100644 --- a/easytier/src/utils/task.rs +++ b/easytier/src/utils/task.rs @@ -1,5 +1,7 @@ +use crate::utils::guard::ContextGuard; use std::future::Future; use std::io; +use std::ops::DerefMut; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; @@ -7,6 +9,8 @@ use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tokio_util::task::AbortOnDropHandle; +// region CancellableTask + #[derive(Debug)] pub struct CancellableTask { handle: AbortOnDropHandle, @@ -74,3 +78,286 @@ impl Future for CancellableTask { .map(|result| result.map_err(Into::into)) } } + +// endregion + +// region DetachableTask + +/// A pinned, heap-allocated task. +/// +/// **Why Box?** Heap allocation is required because if the task detaches, +/// it outlives the current stack frame. `Pin>` ensures its memory address +/// remains completely stable during and after the transfer. +type BoxTask = Pin>; + +struct DetachableTaskContext { + spawner: Spawner, + task: Option>, +} +type DetachableTaskGuardHelper = ContextGuard; +type DetachableTaskGuard = + DetachableTaskGuardHelper>; + +/// A task wrapper that executes inline but automatically detaches to a background spawner +/// if the current execution context is interrupted or dropped. +/// +/// `DetachableTask` ensures anti-cancellation. If the outer future is dropped (e.g., due to +/// a timeout or a `select!` branch failing), the underlying unfinished task is seamlessly +/// transferred to a background executor via an RAII guard. +/// +/// # Advantages over `tokio::spawn` + `.await JoinHandle` +/// +/// 1. **Zero Initial Scheduling Overhead**: Prioritizes inline execution. If the task +/// completes before being interrupted, it entirely bypasses the runtime's scheduling queue, +/// eliminating queuing latency and context-switching CPU costs. Spawning is strictly a fallback. +/// +/// 2. **Context Locality**: Before detachment, the task is polled directly by the caller's thread. +/// This implicitly preserves the current execution context, including thread-local storage (TLS), +/// Tokio `task_local!` variables, and `tracing` spans, which would otherwise be immediately +/// lost or require explicit propagation across task boundaries. +pub struct DetachableTask { + guard: DetachableTaskGuard, +} + +impl DetachableTask { + pub fn detach(self) { + self.guard.trigger() + } + + pub fn reclaim(self) -> BoxTask { + self.guard.defuse().task.unwrap() + } +} + +pub type TaskSpawner::Output>> = fn(BoxTask) -> R; + +impl DetachableTask { + pub fn with_spawner( + spawner: Spawner, + task: Task, + ) -> DetachableTask + where + Spawner: FnOnce(BoxTask) -> _R, + { + let context = DetachableTaskContext { + spawner, + task: Some(Box::pin(task)), + }; + DetachableTask { + guard: crate::guard!([context] if let Some(task) = context.task { + (context.spawner)(task); + }), + } + } + + pub fn new(task: Task) -> DetachableTask, Task> + where + Task: Future + Send + 'static, + ::Output: Send + 'static, + { + Self::with_spawner(|task| tokio::runtime::Handle::current().spawn(task), task) + } +} + +impl) -> _R, _R, Task> IntoFuture for DetachableTask +where + Task: Future, +{ + type Output = Task::Output; + type IntoFuture = DetachableTaskFuture; + + fn into_future(self) -> Self::IntoFuture { + DetachableTaskFuture { guard: self.guard } + } +} + +pub struct DetachableTaskFuture { + guard: DetachableTaskGuard, +} + +impl) -> _R, _R, Task> Future for DetachableTaskFuture +where + Task: Future, +{ + type Output = Task::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // SAFETY: + // 1. We only access the outer struct's unpinned fields. + // 2. The inner task remains securely pinned on the heap via `BoxTask`. + // 3. We never expose a mutable, unpinned reference to the underlying task. + let this = unsafe { self.get_unchecked_mut() }; + let context = this.guard.deref_mut(); + let mut task = context.task.take().expect("polled after completion"); + let poll = task.as_mut().poll(cx); + if poll.is_pending() { + context.task = Some(task); + } + poll + } +} + +// endregion + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; + use std::time::Duration; + use tokio::sync::{mpsc, oneshot}; + + #[tokio::test] + async fn spawn_when_dropped() { + let spawned = Arc::new(AtomicBool::new(false)); + { + let spawned = spawned.clone(); + let _task = DetachableTask::new(async move { + spawned.store(true, Ordering::SeqCst); + }); + } + + tokio::time::timeout(Duration::from_secs(1), async { + while !spawned.load(Ordering::SeqCst) { + tokio::task::yield_now().await; + } + }) + .await + .expect("task should be spawned on drop"); + } + + #[tokio::test] + async fn await_completed_task_does_not_detach() { + let spawn_count = Arc::new(AtomicUsize::new(0)); + let result = { + let spawn_count = spawn_count.clone(); + DetachableTask::with_spawner( + move |_| { + spawn_count.fetch_add(1, Ordering::SeqCst); + }, + async { 7usize }, + ) + .await + }; + + assert_eq!(result, 7); + assert_eq!(spawn_count.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn drop_without_await_and_runs_once() { + let spawn_count = Arc::new(AtomicUsize::new(0)); + let (done_tx, done_rx) = oneshot::channel(); + + { + let spawn_count = spawn_count.clone(); + let _task = DetachableTask::with_spawner( + move |f| { + spawn_count.fetch_add(1, Ordering::SeqCst); + tokio::spawn(async move { + let result = f.await; + let _ = done_tx.send(result); + }); + }, + async { 42usize }, + ); + } + + let detached_result = tokio::time::timeout(Duration::from_secs(1), done_rx) + .await + .expect("detached task should finish") + .expect("detached task should send result"); + + assert_eq!(detached_result, 42); + assert_eq!(spawn_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn drop_after_await_still_detaches() { + let spawn_count = Arc::new(AtomicUsize::new(0)); + let (value_tx, mut value_rx) = mpsc::channel(4); + let (done_tx, done_rx) = oneshot::channel(); + + let handle = { + let future = async move { + let mut sum = 0; + while let Some(value) = value_rx.recv().await { + sum += value; + } + sum + }; + + let spawn_count = spawn_count.clone(); + let task = DetachableTask::with_spawner( + move |f| { + spawn_count.fetch_add(1, Ordering::SeqCst); + tokio::spawn(async move { + let result = f.await; + let _ = done_tx.send(result); + }); + }, + future, + ); + + tokio::spawn(task.into_future()) + }; + + value_tx + .send(10) + .await + .expect("value receiver should still exist"); + handle.abort(); + value_tx + .send(11) + .await + .expect("value receiver should still exist"); + drop(value_tx); + + let detached_result = tokio::time::timeout(Duration::from_secs(1), done_rx) + .await + .expect("detached polled task should finish") + .expect("detached polled task should send result"); + + assert_eq!(detached_result, 21); + assert_eq!(spawn_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn panic_during_inline_poll_does_not_detach_on_drop() { + struct PanicOnPollFuture { + poll_count: Arc, + } + + impl Future for PanicOnPollFuture { + type Output = (); + + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + self.poll_count.fetch_add(1, Ordering::SeqCst); + panic!("panic during inline poll") + } + } + + let poll_count = Arc::new(AtomicUsize::new(0)); + let detach_count = Arc::new(AtomicUsize::new(0)); + + let task = { + let detach_count = detach_count.clone(); + DetachableTask::with_spawner( + move |_| { + detach_count.fetch_add(1, Ordering::SeqCst); + }, + PanicOnPollFuture { + poll_count: poll_count.clone(), + }, + ) + }; + + let err = tokio::spawn(task.into_future()) + .await + .expect_err("inline poll panic should propagate"); + + assert!(err.is_panic()); + assert_eq!(poll_count.load(Ordering::SeqCst), 1); + assert_eq!(detach_count.load(Ordering::SeqCst), 0); + } +}