From bcb2e512d480db94265adff9784701babaa28756 Mon Sep 17 00:00:00 2001 From: Luna Yao <40349250+ZnqbuZ@users.noreply.github.com> Date: Thu, 16 Apr 2026 17:32:07 +0200 Subject: [PATCH] utils: move code to a dedicated mod; add AsyncRuntime (#2072) --- easytier-gui/src-tauri/src/lib.rs | 4 +- easytier-web/src/main.rs | 2 +- easytier/src/common/ifcfg/mod.rs | 4 +- easytier/src/common/scoped_task.rs | 25 +--- easytier/src/core.rs | 2 +- easytier/src/easytier-cli.rs | 2 +- easytier/src/utils/mod.rs | 30 +++++ easytier/src/{utils.rs => utils/panic.rs} | 58 +-------- easytier/src/utils/string.rs | 26 ++++ easytier/src/utils/task.rs | 140 ++++++++++++++++++++++ 10 files changed, 211 insertions(+), 82 deletions(-) create mode 100644 easytier/src/utils/mod.rs rename easytier/src/{utils.rs => utils/panic.rs} (68%) create mode 100644 easytier/src/utils/string.rs create mode 100644 easytier/src/utils/task.rs diff --git a/easytier-gui/src-tauri/src/lib.rs b/easytier-gui/src-tauri/src/lib.rs index 2b6c613b..d4dd6a4d 100644 --- a/easytier-gui/src-tauri/src/lib.rs +++ b/easytier-gui/src-tauri/src/lib.rs @@ -24,7 +24,7 @@ use easytier::{ tunnel::TunnelListener, tunnel::ring::RingTunnelListener, tunnel::tcp::TcpTunnelListener, - utils::{self}, + utils::panic::setup_panic_handler, }; use std::ops::Deref; use std::sync::Arc; @@ -1120,7 +1120,7 @@ pub fn run_gui() -> std::process::ExitCode { process::exit(0); } - utils::setup_panic_handler(); + setup_panic_handler(); let mut builder = tauri::Builder::default(); diff --git a/easytier-web/src/main.rs b/easytier-web/src/main.rs index 6c741727..10404d3c 100644 --- a/easytier-web/src/main.rs +++ b/easytier-web/src/main.rs @@ -17,7 +17,7 @@ use easytier::{ network::{local_ipv4, local_ipv6}, }, tunnel::{TunnelListener, tcp::TcpTunnelListener, udp::UdpTunnelListener}, - utils::setup_panic_handler, + utils::panic::setup_panic_handler, }; use easytier::tunnel::IpScheme; diff --git a/easytier/src/common/ifcfg/mod.rs b/easytier/src/common/ifcfg/mod.rs index 568e71bc..a8e4085f 100644 --- a/easytier/src/common/ifcfg/mod.rs +++ b/easytier/src/common/ifcfg/mod.rs @@ -119,8 +119,8 @@ async fn run_shell_cmd(cmd: &str) -> Result<(), Error> { .creation_flags(CREATE_NO_WINDOW) .output() .await?; - stdout = crate::utils::utf8_or_gbk_to_string(cmd_out.stdout.as_slice()); - stderr = crate::utils::utf8_or_gbk_to_string(cmd_out.stderr.as_slice()); + stdout = crate::utils::string::utf8_or_gbk_to_string(cmd_out.stdout.as_slice()); + stderr = crate::utils::string::utf8_or_gbk_to_string(cmd_out.stderr.as_slice()); }; #[cfg(not(target_os = "windows"))] diff --git a/easytier/src/common/scoped_task.rs b/easytier/src/common/scoped_task.rs index 56690087..b87a1f19 100644 --- a/easytier/src/common/scoped_task.rs +++ b/easytier/src/common/scoped_task.rs @@ -4,40 +4,25 @@ //! For example, if task A spawned task B but is doing something else, and task B is waiting for task C to join, //! aborting A will also abort both B and C. +use derive_more::{Deref, DerefMut, From}; use std::future::Future; -use std::ops::Deref; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::task::JoinHandle; -#[derive(Debug)] -pub struct ScopedTask { - inner: JoinHandle, -} +#[derive(Debug, From, Deref, DerefMut)] +pub struct ScopedTask(JoinHandle); impl Drop for ScopedTask { fn drop(&mut self) { - self.inner.abort() + self.abort() } } impl Future for ScopedTask { type Output = as Future>::Output; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.inner).poll(cx) - } -} - -impl From> for ScopedTask { - fn from(inner: JoinHandle) -> Self { - Self { inner } - } -} - -impl Deref for ScopedTask { - type Target = JoinHandle; - fn deref(&self) -> &Self::Target { - &self.inner + Pin::new(&mut self.0).poll(cx) } } diff --git a/easytier/src/core.rs b/easytier/src/core.rs index 218d87af..634f36bb 100644 --- a/easytier/src/core.rs +++ b/easytier/src/core.rs @@ -23,7 +23,7 @@ use crate::{ launcher::add_proxy_network_to_config, proto::common::{CompressionAlgoPb, SecureModeConfig}, rpc_service::ApiRpcServer, - utils::setup_panic_handler, + utils::panic::setup_panic_handler, web_client, }; use anyhow::Context; diff --git a/easytier/src/easytier-cli.rs b/easytier/src/easytier-cli.rs index af55343d..f18a131e 100644 --- a/easytier/src/easytier-cli.rs +++ b/easytier/src/easytier-cli.rs @@ -76,7 +76,7 @@ use easytier::{ rpc_types::controller::BaseController, }, tunnel::{TunnelScheme, tcp::TcpTunnelConnector}, - utils::{PeerRoutePair, cost_to_str}, + utils::{PeerRoutePair, string::cost_to_str}, }; rust_i18n::i18n!("locales", fallback = "en"); diff --git a/easytier/src/utils/mod.rs b/easytier/src/utils/mod.rs new file mode 100644 index 00000000..1339c10d --- /dev/null +++ b/easytier/src/utils/mod.rs @@ -0,0 +1,30 @@ +pub mod panic; +pub mod string; +pub mod task; + +use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener}; +use std::sync::{Arc, Weak}; + +pub type PeerRoutePair = crate::proto::api::instance::PeerRoutePair; + +pub fn check_tcp_available(port: u16) -> bool { + let s = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port); + TcpListener::bind(s).is_ok() +} + +pub fn find_free_tcp_port(mut range: std::ops::Range) -> Option { + range.find(|&port| check_tcp_available(port)) +} + +pub fn weak_upgrade(weak: &Weak) -> anyhow::Result> { + weak.upgrade() + .ok_or_else(|| anyhow::anyhow!("{} not available", std::any::type_name::())) +} + +pub trait BoxExt: Sized { + fn boxed(self) -> Box { + Box::new(self) + } +} + +impl BoxExt for T {} diff --git a/easytier/src/utils.rs b/easytier/src/utils/panic.rs similarity index 68% rename from easytier/src/utils.rs rename to easytier/src/utils/panic.rs index e2b35676..6421824b 100644 --- a/easytier/src/utils.rs +++ b/easytier/src/utils/panic.rs @@ -1,43 +1,14 @@ use crate::common::log; use indoc::formatdoc; -use std::sync::Arc; -use std::{fs::OpenOptions, str::FromStr}; - -pub type PeerRoutePair = crate::proto::api::instance::PeerRoutePair; - -pub fn cost_to_str(cost: i32) -> String { - if cost == 1 { - "p2p".to_string() - } else { - format!("relay({})", cost) - } -} - -pub fn float_to_str(f: f64, precision: usize) -> String { - format!("{:.1$}", f, precision) -} - -#[cfg(target_os = "windows")] -pub fn utf8_or_gbk_to_string(s: &[u8]) -> String { - use encoding::{DecoderTrap, Encoding, all::GBK}; - if let Ok(utf8_str) = String::from_utf8(s.to_vec()) { - utf8_str - } else { - // 如果解码失败,则尝试使用GBK解码 - if let Ok(gbk_str) = GBK.decode(s, DecoderTrap::Strict) { - gbk_str - } else { - String::from_utf8_lossy(s).to_string() - } - } -} +use std::fs::OpenOptions; +use std::str::FromStr; +use std::{backtrace, io::Write}; thread_local! { static PANIC_COUNT : std::cell::RefCell = const { std::cell::RefCell::new(0) }; } pub fn setup_panic_handler() { - use std::{backtrace, io::Write}; std::panic::set_hook(Box::new(|info| { let mut stderr = std::io::stderr(); let sep = format!("{}\n", "=======".repeat(10)); @@ -126,26 +97,3 @@ pub fn setup_panic_handler() { std::process::exit(1); })); } - -pub fn check_tcp_available(port: u16) -> bool { - use std::net::TcpListener; - let s = std::net::SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), port); - TcpListener::bind(s).is_ok() -} - -pub fn find_free_tcp_port(mut range: std::ops::Range) -> Option { - range.find(|&port| check_tcp_available(port)) -} - -pub fn weak_upgrade(weak: &std::sync::Weak) -> anyhow::Result> { - weak.upgrade() - .ok_or_else(|| anyhow::anyhow!("{} not available", std::any::type_name::())) -} - -pub trait BoxExt: Sized { - fn boxed(self) -> Box { - Box::new(self) - } -} - -impl BoxExt for T {} diff --git a/easytier/src/utils/string.rs b/easytier/src/utils/string.rs new file mode 100644 index 00000000..8dab764f --- /dev/null +++ b/easytier/src/utils/string.rs @@ -0,0 +1,26 @@ +pub fn cost_to_str(cost: i32) -> String { + if cost == 1 { + "p2p".to_string() + } else { + format!("relay({})", cost) + } +} + +pub fn float_to_str(f: f64, precision: usize) -> String { + format!("{:.1$}", f, precision) +} + +#[cfg(target_os = "windows")] +pub fn utf8_or_gbk_to_string(s: &[u8]) -> String { + use encoding::{DecoderTrap, Encoding, all::GBK}; + if let Ok(utf8_str) = String::from_utf8(s.to_vec()) { + utf8_str + } else { + // 如果解码失败,则尝试使用GBK解码 + if let Ok(gbk_str) = GBK.decode(s, DecoderTrap::Strict) { + gbk_str + } else { + String::from_utf8_lossy(s).to_string() + } + } +} diff --git a/easytier/src/utils/task.rs b/easytier/src/utils/task.rs new file mode 100644 index 00000000..db25df64 --- /dev/null +++ b/easytier/src/utils/task.rs @@ -0,0 +1,140 @@ +use crate::common::scoped_task::ScopedTask; +use derivative::Derivative; +use derive_more::{Deref, DerefMut}; +use parking_lot::Mutex; +use std::future::Future; +use std::mem::take; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Notify; +use tokio::task::{AbortHandle, JoinError}; +use tokio_util::sync::CancellationToken; + +#[derive(Derivative, Debug)] +#[derivative(Default(bound = ""))] +enum AsyncRuntimeState { + #[derivative(Default)] + Idle, + Running { + id: tokio::task::Id, + task: ScopedTask, + token: CancellationToken, + }, + Stopping(AbortHandle), +} + +#[derive(Derivative, Debug)] +#[derivative(Default(bound = ""))] +pub struct AsyncRuntimeInner { + state: Mutex>, + idle: Notify, +} + +#[derive(Derivative, Deref, DerefMut)] +#[derivative(Debug = "transparent", Default(bound = ""), Clone(bound = ""))] +pub struct AsyncRuntime(Arc>); + +impl AsyncRuntime { + pub fn token(&self) -> Option { + if let AsyncRuntimeState::Running { token, .. } = &*self.state.lock() { + Some(token.clone()) + } else { + None + } + } + + pub fn start(&self, token: Option, factory: F) -> anyhow::Result<()> + where + F: FnOnce(CancellationToken) -> Fut, + Fut: Future + Send + 'static, + { + let mut state = self.state.lock(); + if !matches!(*state, AsyncRuntimeState::Idle) { + return Err(anyhow::anyhow!("task is already running/stopping")); + } + + let token = token.unwrap_or_default(); + + let task = { + let f = factory(token.clone()); + let this = (*self).clone(); + tokio::spawn(async move { + let result = f.await; + let mut state = this.state.lock(); + if let AsyncRuntimeState::Running { id, .. } = &*state + && *id == tokio::task::id() + { + take(&mut *state); + } + result + }) + }; + + *state = AsyncRuntimeState::Running { + id: task.id(), + task: task.into(), + token, + }; + + Ok(()) + } + + pub async fn stop(&self, timeout: Duration) -> Option> { + let state = { + let mut state = self.state.lock(); + match &*state { + AsyncRuntimeState::Running { .. } => { + let AsyncRuntimeState::Running { task, token, .. } = take(&mut *state) else { + unreachable!() + }; + *state = AsyncRuntimeState::Stopping(task.abort_handle()); + Ok((task, token)) + } + AsyncRuntimeState::Stopping(_) => Err(self.idle.notified()), + AsyncRuntimeState::Idle => return None, + } + }; + + let (mut task, token) = match state { + Ok(running) => running, + Err(stopping) => { + stopping.await; + return None; + } + }; + + token.cancel(); + let result = if let Ok(result) = tokio::time::timeout(timeout, &mut task).await { + result + } else { + task.abort(); + tracing::warn!("task stop timeout after {:?}, aborted", timeout); + task.await + }; + + { + let mut state = self.state.lock(); + if matches!(*state, AsyncRuntimeState::Stopping(_)) { + *state = AsyncRuntimeState::Idle; + drop(state); + self.idle.notify_waiters(); + } + } + + Some(result) + } + + pub fn abort(&self) { + let mut state = self.state.lock(); + match &*state { + AsyncRuntimeState::Running { task, .. } => { + task.abort(); + *state = AsyncRuntimeState::Idle; + drop(state); + self.idle.notify_waiters(); + } + AsyncRuntimeState::Stopping(handle) => handle.abort(), + _ => {} + } + } +}