use std::{fmt::Debug, sync::Arc}; use easytier::{ common::scoped_task::ScopedTask, proto::{ rpc_impl::bidirect::BidirectRpcManager, rpc_types::{self, controller::BaseController}, web::{ HeartbeatRequest, HeartbeatResponse, NetworkConfig, RunNetworkInstanceRequest, WebClientService, WebClientServiceClientFactory, WebServerService, WebServerServiceServer, }, }, tunnel::Tunnel, }; use tokio::sync::{broadcast, RwLock}; use super::storage::{Storage, StorageToken, WeakRefStorage}; #[derive(Debug)] pub struct SessionData { storage: WeakRefStorage, client_url: url::Url, storage_token: Option, notifier: broadcast::Sender, req: Option, } impl SessionData { fn new(storage: WeakRefStorage, client_url: url::Url) -> Self { let (tx, _rx1) = broadcast::channel(2); SessionData { storage, client_url, storage_token: None, notifier: tx, req: None, } } pub fn req(&self) -> Option { self.req.clone() } pub fn heartbeat_waiter(&self) -> broadcast::Receiver { self.notifier.subscribe() } } impl Drop for SessionData { fn drop(&mut self) { if let Ok(storage) = Storage::try_from(self.storage.clone()) { if let Some(token) = self.storage_token.as_ref() { storage.remove_client(token); } } } } pub type SharedSessionData = Arc>; #[derive(Clone)] struct SessionRpcService { data: SharedSessionData, } #[async_trait::async_trait] impl WebServerService for SessionRpcService { type Controller = BaseController; async fn heartbeat( &self, _: BaseController, req: HeartbeatRequest, ) -> rpc_types::error::Result { let mut data = self.data.write().await; if data.req.replace(req.clone()).is_none() { assert!(data.storage_token.is_none()); data.storage_token = Some(StorageToken { token: req.user_token.clone().into(), client_url: data.client_url.clone(), machine_id: req .machine_id .clone() .map(Into::into) .unwrap_or(uuid::Uuid::new_v4()), }); if let Ok(storage) = Storage::try_from(data.storage.clone()) { storage.add_client(data.storage_token.as_ref().unwrap().clone()); } } let _ = data.notifier.send(req); Ok(HeartbeatResponse {}) } } pub struct Session { rpc_mgr: BidirectRpcManager, data: SharedSessionData, run_network_on_start_task: Option>, } impl Debug for Session { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Session").field("data", &self.data).finish() } } type SessionRpcClient = Box + Send>; impl Session { pub fn new(storage: WeakRefStorage, client_url: url::Url) -> Self { let session_data = SessionData::new(storage, client_url); let data = Arc::new(RwLock::new(session_data)); let rpc_mgr = BidirectRpcManager::new().set_rx_timeout(Some(std::time::Duration::from_secs(30))); rpc_mgr.rpc_server().registry().register( WebServerServiceServer::new(SessionRpcService { data: data.clone() }), "", ); Session { rpc_mgr, data, run_network_on_start_task: None, } } pub async fn serve(&mut self, tunnel: Box) { self.rpc_mgr.run_with_tunnel(tunnel); let data = self.data.read().await; self.run_network_on_start_task.replace( tokio::spawn(Self::run_network_on_start( data.heartbeat_waiter(), data.storage.clone(), self.scoped_rpc_client(), )) .into(), ); } async fn run_network_on_start( mut heartbeat_waiter: broadcast::Receiver, storage: WeakRefStorage, rpc_client: SessionRpcClient, ) { loop { heartbeat_waiter = heartbeat_waiter.resubscribe(); let req = heartbeat_waiter.recv().await; if req.is_err() { tracing::error!( "Failed to receive heartbeat request, error: {:?}", req.err() ); return; } let req = req.unwrap(); if req.machine_id.is_none() { tracing::warn!(?req, "Machine id is not set, ignore"); continue; } let running_inst_ids = req .running_network_instances .iter() .map(|x| x.to_string()) .collect::>(); let Some(storage) = storage.upgrade() else { tracing::error!("Failed to get storage"); return; }; let user_id = match storage .db .get_user_id_by_token(req.user_token.clone()) .await { Ok(Some(user_id)) => user_id, Ok(None) => { tracing::info!("User not found by token: {:?}", req.user_token); return; } Err(e) => { tracing::error!("Failed to get user id by token, error: {:?}", e); return; } }; let local_configs = match storage .db .list_network_configs(user_id, Some(req.machine_id.unwrap().into()), true) .await { Ok(configs) => configs, Err(e) => { tracing::error!("Failed to list network configs, error: {:?}", e); return; } }; let mut has_failed = false; for c in local_configs { if running_inst_ids.contains(&c.network_instance_id) { continue; } let ret = rpc_client .run_network_instance( BaseController::default(), RunNetworkInstanceRequest { inst_id: Some(c.network_instance_id.clone().into()), config: Some( serde_json::from_str::(&c.network_config).unwrap(), ), }, ) .await; tracing::info!( ?user_id, "Run network instance: {:?}, user_token: {:?}", ret, req.user_token ); has_failed |= ret.is_err(); } if !has_failed { tracing::info!(?req, "All network instances are running"); break; } } } pub fn is_running(&self) -> bool { self.rpc_mgr.is_running() } pub fn data(&self) -> SharedSessionData { self.data.clone() } pub fn scoped_rpc_client(&self) -> SessionRpcClient { self.rpc_mgr .rpc_client() .scoped_client::>(1, 1, "".to_string()) } pub async fn get_token(&self) -> Option { self.data.read().await.storage_token.clone() } pub async fn get_heartbeat_req(&self) -> Option { self.data.read().await.req() } }