diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 99c07751..e136e43d 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -62,7 +62,7 @@ futures = { version = "0.3", features = ["bilock", "unstable"] } tokio = { version = "1", features = ["full"] } tokio-stream = "0.1" -tokio-util = { version = "0.7.9", features = ["codec", "net", "io"] } +tokio-util = { version = "0.7.9", features = ["codec", "net", "io", "rt"] } async-stream = "0.3.5" async-trait = "0.1.74" diff --git a/easytier/src/utils/task.rs b/easytier/src/utils/task.rs index db25df64..a9635ee5 100644 --- a/easytier/src/utils/task.rs +++ b/easytier/src/utils/task.rs @@ -1,140 +1,76 @@ -use crate::common::scoped_task::ScopedTask; -use derivative::Derivative; -use derive_more::{Deref, DerefMut}; -use parking_lot::Mutex; use std::future::Future; -use std::mem::take; -use std::sync::Arc; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; -use tokio::sync::Notify; -use tokio::task::{AbortHandle, JoinError}; +use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; +use tokio_util::task::AbortOnDropHandle; -#[derive(Derivative, Debug)] -#[derivative(Default(bound = ""))] -enum AsyncRuntimeState { - #[derivative(Default)] - Idle, - Running { - id: tokio::task::Id, - task: ScopedTask, - token: CancellationToken, - }, - Stopping(AbortHandle), +#[derive(Debug)] +pub struct CancellableTask { + handle: AbortOnDropHandle, + token: CancellationToken, } -#[derive(Derivative, Debug)] -#[derivative(Default(bound = ""))] -pub struct AsyncRuntimeInner { - state: Mutex>, - idle: Notify, -} - -#[derive(Derivative, Deref, DerefMut)] -#[derivative(Debug = "transparent", Default(bound = ""), Clone(bound = ""))] -pub struct AsyncRuntime(Arc>); - -impl AsyncRuntime { - pub fn token(&self) -> Option { - if let AsyncRuntimeState::Running { token, .. } = &*self.state.lock() { - Some(token.clone()) - } else { - None - } +impl CancellableTask { + pub fn token(&self) -> &CancellationToken { + &self.token } - pub fn start(&self, token: Option, factory: F) -> anyhow::Result<()> - where - F: FnOnce(CancellationToken) -> Fut, - Fut: Future + Send + 'static, - { - let mut state = self.state.lock(); - if !matches!(*state, AsyncRuntimeState::Idle) { - return Err(anyhow::anyhow!("task is already running/stopping")); - } - - let token = token.unwrap_or_default(); - - let task = { - let f = factory(token.clone()); - let this = (*self).clone(); - tokio::spawn(async move { - let result = f.await; - let mut state = this.state.lock(); - if let AsyncRuntimeState::Running { id, .. } = &*state - && *id == tokio::task::id() - { - take(&mut *state); - } - result - }) - }; - - *state = AsyncRuntimeState::Running { - id: task.id(), - task: task.into(), + pub fn with_handle(token: CancellationToken, handle: JoinHandle) -> Self { + Self { + handle: AbortOnDropHandle::new(handle), token, - }; - - Ok(()) + } } - pub async fn stop(&self, timeout: Duration) -> Option> { - let state = { - let mut state = self.state.lock(); - match &*state { - AsyncRuntimeState::Running { .. } => { - let AsyncRuntimeState::Running { task, token, .. } = take(&mut *state) else { - unreachable!() - }; - *state = AsyncRuntimeState::Stopping(task.abort_handle()); - Ok((task, token)) - } - AsyncRuntimeState::Stopping(_) => Err(self.idle.notified()), - AsyncRuntimeState::Idle => return None, - } - }; + pub async fn stop(mut self, timeout: Option) -> io::Result { + self.token.cancel(); - let (mut task, token) = match state { - Ok(running) => running, - Err(stopping) => { - stopping.await; - return None; - } - }; - - token.cancel(); - let result = if let Ok(result) = tokio::time::timeout(timeout, &mut task).await { - result - } else { - task.abort(); - tracing::warn!("task stop timeout after {:?}, aborted", timeout); - task.await - }; - - { - let mut state = self.state.lock(); - if matches!(*state, AsyncRuntimeState::Stopping(_)) { - *state = AsyncRuntimeState::Idle; - drop(state); - self.idle.notify_waiters(); - } - } - - Some(result) - } - - pub fn abort(&self) { - let mut state = self.state.lock(); - match &*state { - AsyncRuntimeState::Running { task, .. } => { - task.abort(); - *state = AsyncRuntimeState::Idle; - drop(state); - self.idle.notify_waiters(); - } - AsyncRuntimeState::Stopping(handle) => handle.abort(), - _ => {} + match timeout { + Some(timeout) => tokio::time::timeout(timeout, &mut self.handle) + .await + .map_err(|e| { + tracing::warn!("task stop timeout after {:?}, aborted", timeout); + io::Error::new(io::ErrorKind::TimedOut, e) + })?, + None => self.handle.await, } + .map_err(Into::into) + } +} + +impl CancellableTask { + pub fn new(token: CancellationToken, future: F) -> Self + where + F: Future + Send + 'static, + { + Self::with_handle(token, tokio::spawn(future)) + } + + pub fn spawn(factory: impl FnOnce(CancellationToken) -> F) -> Self + where + F: Future + Send + 'static, + { + let token = CancellationToken::new(); + Self::new(token.clone(), factory(token)) + } + + pub fn child(&self, factory: impl FnOnce(CancellationToken) -> F) -> Self + where + F: Future + Send + 'static, + { + let token = self.token.clone(); + Self::new(token.clone(), factory(token)) + } +} + +impl Future for CancellableTask { + type Output = io::Result; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.handle) + .poll(cx) + .map(|result| result.map_err(Into::into)) } }