use std::result::Result; use std::sync::{Arc, Mutex}; use async_trait::async_trait; use dashmap::DashMap; use tokio::select; use tokio::sync::Notify; use tokio::task::JoinHandle; use crate::common::scoped_task::ScopedTask; use anyhow::Error; use super::peer_manager::PeerManager; #[async_trait] pub trait PeerTaskLauncher: Send + Sync + Clone + 'static { type Data; type CollectPeerItem; type TaskRet; fn new_data(&self, peer_mgr: Arc) -> Self::Data; async fn collect_peers_need_task(&self, data: &Self::Data) -> Vec; async fn launch_task( &self, data: &Self::Data, item: Self::CollectPeerItem, ) -> JoinHandle>; async fn all_task_done(&self, _data: &Self::Data) {} fn loop_interval_ms(&self) -> u64 { 5000 } } pub struct PeerTaskManager { launcher: Launcher, peer_mgr: Arc, main_loop_task: Mutex>>, run_signal: Arc, data: Launcher::Data, } impl PeerTaskManager where D: Send + Sync + Clone + 'static, C: std::fmt::Debug + Send + Sync + Clone + core::hash::Hash + Eq + 'static, T: Send + 'static, L: PeerTaskLauncher + 'static, { pub fn new(launcher: L, peer_mgr: Arc) -> Self { let data = launcher.new_data(peer_mgr.clone()); Self { launcher, peer_mgr, main_loop_task: Mutex::new(None), run_signal: Arc::new(Notify::new()), data, } } pub fn start(&self) { let task = tokio::spawn(Self::main_loop( self.launcher.clone(), self.data.clone(), self.run_signal.clone(), )) .into(); self.main_loop_task.lock().unwrap().replace(task); } async fn main_loop(launcher: L, data: D, signal: Arc) { let peer_task_map = Arc::new(DashMap::>>::new()); loop { let peers_to_connect = launcher.collect_peers_need_task(&data).await; // remove task not in peers_to_connect let mut to_remove = vec![]; for item in peer_task_map.iter() { if !peers_to_connect.contains(item.key()) || item.value().is_finished() { to_remove.push(item.key().clone()); } } for key in to_remove { if let Some((_, task)) = peer_task_map.remove(&key) { task.abort(); match task.await { Ok(Ok(_)) => {} Ok(Err(task_ret)) => { tracing::error!(?task_ret, "hole punching task failed"); } Err(e) => { tracing::error!(?e, "hole punching task aborted"); } } } peer_task_map.shrink_to_fit(); } if !peers_to_connect.is_empty() { for item in peers_to_connect { if peer_task_map.contains_key(&item) { continue; } tracing::debug!(?item, "launch hole punching task"); peer_task_map .insert(item.clone(), launcher.launch_task(&data, item).await.into()); } } else if peer_task_map.is_empty() { launcher.all_task_done(&data).await; } select! { _ = tokio::time::sleep(std::time::Duration::from_millis( launcher.loop_interval_ms(), )) => {}, _ = signal.notified() => {} } } } pub async fn run_immediately(&self) { self.run_signal.notify_one(); } pub fn data(&self) -> D { self.data.clone() } }