replace AsyncRuntime with simpler CancellableTask (#2136)

This commit is contained in:
Luna Yao
2026-04-25 04:29:53 +02:00
committed by GitHub
parent 2fb41ccbba
commit 820d9095d3
2 changed files with 62 additions and 126 deletions
+1 -1
View File
@@ -62,7 +62,7 @@ futures = { version = "0.3", features = ["bilock", "unstable"] }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tokio-stream = "0.1" tokio-stream = "0.1"
tokio-util = { version = "0.7.9", features = ["codec", "net", "io"] } tokio-util = { version = "0.7.9", features = ["codec", "net", "io", "rt"] }
async-stream = "0.3.5" async-stream = "0.3.5"
async-trait = "0.1.74" async-trait = "0.1.74"
+55 -119
View File
@@ -1,140 +1,76 @@
use crate::common::scoped_task::ScopedTask;
use derivative::Derivative;
use derive_more::{Deref, DerefMut};
use parking_lot::Mutex;
use std::future::Future; use std::future::Future;
use std::mem::take; use std::io;
use std::sync::Arc; use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use tokio::sync::Notify; use tokio::task::JoinHandle;
use tokio::task::{AbortHandle, JoinError};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
#[derive(Derivative, Debug)] #[derive(Debug)]
#[derivative(Default(bound = ""))] pub struct CancellableTask<Output> {
enum AsyncRuntimeState<R: Send + 'static> { handle: AbortOnDropHandle<Output>,
#[derivative(Default)]
Idle,
Running {
id: tokio::task::Id,
task: ScopedTask<R>,
token: CancellationToken, token: CancellationToken,
},
Stopping(AbortHandle),
} }
#[derive(Derivative, Debug)] impl<Output> CancellableTask<Output> {
#[derivative(Default(bound = ""))] pub fn token(&self) -> &CancellationToken {
pub struct AsyncRuntimeInner<R: Send + 'static = ()> { &self.token
state: Mutex<AsyncRuntimeState<R>>,
idle: Notify,
} }
#[derive(Derivative, Deref, DerefMut)] pub fn with_handle(token: CancellationToken, handle: JoinHandle<Output>) -> Self {
#[derivative(Debug = "transparent", Default(bound = ""), Clone(bound = ""))] Self {
pub struct AsyncRuntime<R: Send + 'static = ()>(Arc<AsyncRuntimeInner<R>>); handle: AbortOnDropHandle::new(handle),
impl<R: Send + 'static> AsyncRuntime<R> {
pub fn token(&self) -> Option<CancellationToken> {
if let AsyncRuntimeState::Running { token, .. } = &*self.state.lock() {
Some(token.clone())
} else {
None
}
}
pub fn start<F, Fut>(&self, token: Option<CancellationToken>, factory: F) -> anyhow::Result<()>
where
F: FnOnce(CancellationToken) -> Fut,
Fut: Future<Output = R> + Send + 'static,
{
let mut state = self.state.lock();
if !matches!(*state, AsyncRuntimeState::Idle) {
return Err(anyhow::anyhow!("task is already running/stopping"));
}
let token = token.unwrap_or_default();
let task = {
let f = factory(token.clone());
let this = (*self).clone();
tokio::spawn(async move {
let result = f.await;
let mut state = this.state.lock();
if let AsyncRuntimeState::Running { id, .. } = &*state
&& *id == tokio::task::id()
{
take(&mut *state);
}
result
})
};
*state = AsyncRuntimeState::Running {
id: task.id(),
task: task.into(),
token, token,
}; }
Ok(())
} }
pub async fn stop(&self, timeout: Duration) -> Option<Result<R, JoinError>> { pub async fn stop(mut self, timeout: Option<Duration>) -> io::Result<Output> {
let state = { self.token.cancel();
let mut state = self.state.lock();
match &*state {
AsyncRuntimeState::Running { .. } => {
let AsyncRuntimeState::Running { task, token, .. } = take(&mut *state) else {
unreachable!()
};
*state = AsyncRuntimeState::Stopping(task.abort_handle());
Ok((task, token))
}
AsyncRuntimeState::Stopping(_) => Err(self.idle.notified()),
AsyncRuntimeState::Idle => return None,
}
};
let (mut task, token) = match state { match timeout {
Ok(running) => running, Some(timeout) => tokio::time::timeout(timeout, &mut self.handle)
Err(stopping) => { .await
stopping.await; .map_err(|e| {
return None;
}
};
token.cancel();
let result = if let Ok(result) = tokio::time::timeout(timeout, &mut task).await {
result
} else {
task.abort();
tracing::warn!("task stop timeout after {:?}, aborted", timeout); tracing::warn!("task stop timeout after {:?}, aborted", timeout);
task.await io::Error::new(io::ErrorKind::TimedOut, e)
}; })?,
None => self.handle.await,
}
.map_err(Into::into)
}
}
impl<Output: Send + 'static> CancellableTask<Output> {
pub fn new<F>(token: CancellationToken, future: F) -> Self
where
F: Future<Output = Output> + Send + 'static,
{ {
let mut state = self.state.lock(); Self::with_handle(token, tokio::spawn(future))
if matches!(*state, AsyncRuntimeState::Stopping(_)) { }
*state = AsyncRuntimeState::Idle;
drop(state); pub fn spawn<F>(factory: impl FnOnce(CancellationToken) -> F) -> Self
self.idle.notify_waiters(); where
F: Future<Output = Output> + Send + 'static,
{
let token = CancellationToken::new();
Self::new(token.clone(), factory(token))
}
pub fn child<F>(&self, factory: impl FnOnce(CancellationToken) -> F) -> Self
where
F: Future<Output = Output> + Send + 'static,
{
let token = self.token.clone();
Self::new(token.clone(), factory(token))
} }
} }
Some(result) impl<Output> Future for CancellableTask<Output> {
} type Output = io::Result<Output>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
pub fn abort(&self) { Pin::new(&mut self.handle)
let mut state = self.state.lock(); .poll(cx)
match &*state { .map(|result| result.map_err(Into::into))
AsyncRuntimeState::Running { task, .. } => {
task.abort();
*state = AsyncRuntimeState::Idle;
drop(state);
self.idle.notify_waiters();
}
AsyncRuntimeState::Stopping(handle) => handle.abort(),
_ => {}
}
} }
} }