use std::{ collections::{HashMap, HashSet}, fmt::Debug, str::FromStr as _, sync::Arc, }; use anyhow::Context; use easytier::{ common::scoped_task::ScopedTask, proto::{ api::manage::{ NetworkConfig, RunNetworkInstanceRequest, WebClientService, WebClientServiceClientFactory, }, rpc_impl::bidirect::BidirectRpcManager, rpc_types::{self, controller::BaseController}, web::{HeartbeatRequest, HeartbeatResponse, WebServerService, WebServerServiceServer}, }, rpc_service::remote_client::{ListNetworkProps, Storage as _}, tunnel::Tunnel, }; use tokio::sync::{RwLock, broadcast}; use super::storage::{Storage, StorageToken, WeakRefStorage}; use crate::FeatureFlags; use crate::webhook::SharedWebhookConfig; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Location { pub country: String, pub city: Option, pub region: Option, } #[derive(Debug)] pub struct SessionData { storage: WeakRefStorage, feature_flags: Arc, webhook_config: SharedWebhookConfig, client_url: url::Url, storage_token: Option, binding_version: Option, applied_config_revision: Option, notifier: broadcast::Sender, req: Option, location: Option, } impl SessionData { fn new( storage: WeakRefStorage, client_url: url::Url, location: Option, feature_flags: Arc, webhook_config: SharedWebhookConfig, ) -> Self { let (tx, _rx1) = broadcast::channel(2); SessionData { storage, feature_flags, webhook_config, client_url, storage_token: None, binding_version: None, applied_config_revision: None, notifier: tx, req: None, location, } } pub fn req(&self) -> Option { self.req.clone() } pub fn heartbeat_waiter(&self) -> broadcast::Receiver { self.notifier.subscribe() } pub fn location(&self) -> Option<&Location> { self.location.as_ref() } } impl Drop for SessionData { fn drop(&mut self) { if let Ok(storage) = Storage::try_from(self.storage.clone()) && let Some(token) = self.storage_token.as_ref() { storage.remove_client(token); // Notify the webhook receiver when a node disconnects. if self.webhook_config.is_enabled() { let webhook = self.webhook_config.clone(); let machine_id = token.machine_id.to_string(); let user_id = Some(token.user_id); let token_value = token.token.clone(); let web_instance_id = webhook.web_instance_id.clone(); let binding_version = self.binding_version; tokio::spawn(async move { webhook .notify_node_disconnected(&crate::webhook::NodeDisconnectedRequest { machine_id, token: token_value, user_id, web_instance_id, binding_version, }) .await; }); } } } } pub type SharedSessionData = Arc>; #[derive(Clone)] struct SessionRpcService { data: SharedSessionData, } impl SessionRpcService { fn normalize_network_config( mut network_config: serde_json::Value, inst_id: uuid::Uuid, ) -> anyhow::Result { let network_name = network_config .get("network_name") .and_then(|v| v.as_str()) .filter(|v| !v.is_empty()) .ok_or_else(|| anyhow::anyhow!("webhook response missing network_name"))? .to_string(); let config_obj = network_config .as_object_mut() .ok_or_else(|| anyhow::anyhow!("webhook network_config must be a JSON object"))?; config_obj.insert( "instance_id".to_string(), serde_json::Value::String(inst_id.to_string()), ); config_obj .entry("instance_name".to_string()) .or_insert_with(|| serde_json::Value::String(network_name)); Ok(serde_json::from_value::(network_config)?) } async fn reconcile_managed_network_configs( storage: &Storage, user_id: i32, machine_id: uuid::Uuid, desired_configs: Vec, ) -> anyhow::Result<()> { let existing_configs = storage .db() .list_network_configs((user_id, machine_id), ListNetworkProps::All) .await .map_err(|e| anyhow::anyhow!("failed to list existing network configs: {:?}", e))?; let existing_ids = existing_configs .iter() .filter_map(|cfg| uuid::Uuid::parse_str(&cfg.network_instance_id).ok()) .collect::>(); let mut desired_ids = HashSet::with_capacity(desired_configs.len()); let mut normalized = HashMap::with_capacity(desired_configs.len()); for desired in desired_configs { let inst_id = uuid::Uuid::parse_str(&desired.instance_id).with_context(|| { format!( "invalid desired managed instance id: {}", desired.instance_id ) })?; let config = Self::normalize_network_config(desired.network_config, inst_id)?; desired_ids.insert(inst_id); normalized.insert(inst_id, config); } for (inst_id, config) in normalized { storage .db() .insert_or_update_user_network_config((user_id, machine_id), inst_id, config) .await .map_err(|e| { anyhow::anyhow!( "failed to persist managed network config {}: {:?}", inst_id, e ) })?; } let stale_ids = existing_ids .difference(&desired_ids) .copied() .collect::>(); if !stale_ids.is_empty() { storage .db() .delete_network_configs((user_id, machine_id), &stale_ids) .await .map_err(|e| anyhow::anyhow!("failed to delete stale network configs: {:?}", e))?; } Ok(()) } async fn handle_heartbeat( &self, req: HeartbeatRequest, ) -> rpc_types::error::Result { let mut data = self.data.write().await; let Ok(storage) = Storage::try_from(data.storage.clone()) else { tracing::error!("Failed to get storage"); return Ok(HeartbeatResponse {}); }; let machine_id: uuid::Uuid = req.machine_id.map(Into::into).ok_or(anyhow::anyhow!( "Machine id is not set correctly, expect uuid but got: {:?}", req.machine_id ))?; let ( user_id, webhook_managed_network_configs, webhook_config_revision, webhook_validated, binding_version, ) = if data.webhook_config.is_enabled() { let webhook_req = crate::webhook::ValidateTokenRequest { token: req.user_token.clone(), machine_id: machine_id.to_string(), public_ip: data.client_url.host_str().map(str::to_string), hostname: req.hostname.clone(), version: req.easytier_version.clone(), os_type: req.device_os.as_ref().map(|info| info.os_type.clone()), os_version: req.device_os.as_ref().map(|info| info.version.clone()), os_distribution: req.device_os.as_ref().map(|info| info.distribution.clone()), web_instance_id: data.webhook_config.web_instance_id.clone(), web_instance_api_base_url: data.webhook_config.web_instance_api_base_url.clone(), }; let resp = data .webhook_config .validate_token(&webhook_req) .await .map_err(|e| anyhow::anyhow!("Webhook token validation failed: {:?}", e))?; if resp.valid { let user_id = match storage .db() .get_user_id_by_token(req.user_token.clone()) .await .map_err(|e| anyhow::anyhow!("DB error: {:?}", e))? { Some(id) => id, None => storage .auto_create_user(&req.user_token) .await .with_context(|| { format!("Failed to auto-create webhook user: {:?}", req.user_token) })?, }; ( user_id, resp.managed_network_configs, resp.config_revision, true, Some(resp.binding_version), ) } else { return Err(anyhow::anyhow!( "Webhook rejected token for machine {:?}: {:?}", machine_id, req.user_token ) .into()); } } else { let user_id = match storage .db() .get_user_id_by_token(req.user_token.clone()) .await .with_context(|| { format!( "Failed to get user id by token from db: {:?}", req.user_token ) })? { Some(id) => id, None if data.feature_flags.allow_auto_create_user => storage .auto_create_user(&req.user_token) .await .with_context(|| format!("Failed to auto-create user: {:?}", req.user_token))?, None => { return Err( anyhow::anyhow!("User not found by token: {:?}", req.user_token).into(), ); } }; (user_id, Vec::new(), String::new(), false, None) }; if webhook_validated && data.applied_config_revision.as_deref() != Some(webhook_config_revision.as_str()) { Self::reconcile_managed_network_configs( &storage, user_id, machine_id, webhook_managed_network_configs, ) .await .map_err(rpc_types::error::Error::from)?; data.applied_config_revision = Some(webhook_config_revision); } if data.req.replace(req.clone()).is_none() { assert!(data.storage_token.is_none()); data.storage_token = Some(StorageToken { token: req.user_token.clone(), client_url: data.client_url.clone(), machine_id, user_id, }); data.binding_version = binding_version; // Notify the webhook receiver on the first successful heartbeat. if data.webhook_config.is_enabled() { let webhook = data.webhook_config.clone(); let connect_req = crate::webhook::NodeConnectedRequest { machine_id: machine_id.to_string(), token: req.user_token.clone(), user_id: Some(user_id), hostname: req.hostname.clone(), version: req.easytier_version.clone(), os_type: req.device_os.as_ref().map(|info| info.os_type.clone()), os_version: req.device_os.as_ref().map(|info| info.version.clone()), os_distribution: req.device_os.as_ref().map(|info| info.distribution.clone()), web_instance_id: webhook.web_instance_id.clone(), binding_version, }; tokio::spawn(async move { webhook.notify_node_connected(&connect_req).await; }); } } let Ok(report_time) = chrono::DateTime::::from_str(&req.report_time) else { tracing::error!("Failed to parse report time: {:?}", req.report_time); return Ok(HeartbeatResponse {}); }; storage.update_client( data.storage_token.as_ref().unwrap().clone(), report_time.timestamp(), ); let _ = data.notifier.send(req); Ok(HeartbeatResponse {}) } } #[async_trait::async_trait] impl WebServerService for SessionRpcService { type Controller = BaseController; async fn heartbeat( &self, _: BaseController, req: HeartbeatRequest, ) -> rpc_types::error::Result { let ret = self.handle_heartbeat(req).await; if ret.is_err() { tracing::warn!("Failed to handle heartbeat: {:?}", ret); // sleep for a while to avoid client busy loop tokio::time::sleep(std::time::Duration::from_secs(2)).await; } ret } async fn get_feature( &self, _: BaseController, _: easytier::proto::web::GetFeatureRequest, ) -> rpc_types::error::Result { Ok(easytier::proto::web::GetFeatureResponse { support_encryption: true, }) } } pub struct Session { rpc_mgr: BidirectRpcManager, data: SharedSessionData, run_network_on_start_task: Option>, } impl Debug for Session { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Session").field("data", &self.data).finish() } } type SessionRpcClient = Box + Send>; impl Session { pub fn new( storage: WeakRefStorage, client_url: url::Url, location: Option, feature_flags: Arc, webhook_config: SharedWebhookConfig, ) -> Self { let session_data = SessionData::new(storage, client_url, location, feature_flags, webhook_config); let data = Arc::new(RwLock::new(session_data)); let rpc_mgr = BidirectRpcManager::new().set_rx_timeout(Some(std::time::Duration::from_secs(30))); rpc_mgr.rpc_server().registry().register( WebServerServiceServer::new(SessionRpcService { data: data.clone() }), "", ); Session { rpc_mgr, data, run_network_on_start_task: None, } } pub async fn serve(&mut self, tunnel: Box) { self.rpc_mgr.run_with_tunnel(tunnel); let data = self.data.read().await; self.run_network_on_start_task.replace( tokio::spawn(Self::run_network_on_start( data.heartbeat_waiter(), data.storage.clone(), self.scoped_rpc_client(), )) .into(), ); } async fn run_network_on_start( mut heartbeat_waiter: broadcast::Receiver, storage: WeakRefStorage, rpc_client: SessionRpcClient, ) { let mut cleaned_web_managed_instances = false; let mut last_desired_inst_ids: Option> = None; loop { heartbeat_waiter = heartbeat_waiter.resubscribe(); let req = heartbeat_waiter.recv().await; if req.is_err() { tracing::error!( "Failed to receive heartbeat request, error: {:?}", req.err() ); return; } let req = req.unwrap(); let Some(machine_id) = req.machine_id else { tracing::warn!(?req, "Machine id is not set, ignore"); continue; }; let running_inst_ids = req .running_network_instances .iter() .map(|x| x.to_string()) .collect::>(); let Some(storage) = storage.upgrade() else { tracing::error!("Failed to get storage"); return; }; let user_id = match storage .db .get_user_id_by_token(req.user_token.clone()) .await { Ok(Some(user_id)) => user_id, Ok(None) => { tracing::info!("User not found by token: {:?}", req.user_token); return; } Err(e) => { tracing::error!("Failed to get user id by token, error: {:?}", e); return; } }; let local_configs = match storage .db .list_network_configs((user_id, machine_id.into()), ListNetworkProps::EnabledOnly) .await { Ok(configs) => configs, Err(e) => { tracing::error!("Failed to list network configs, error: {:?}", e); return; } }; let mut has_failed = false; let should_be_alive_inst_ids = local_configs .iter() .map(|cfg| cfg.network_instance_id.clone()) .collect::>(); let desired_changed = last_desired_inst_ids .as_ref() .is_none_or(|last| last != &should_be_alive_inst_ids); if !cleaned_web_managed_instances || desired_changed { let all_local_configs = match storage .db .list_network_configs((user_id, machine_id.into()), ListNetworkProps::All) .await { Ok(configs) => configs, Err(e) => { tracing::error!("Failed to list all network configs, error: {:?}", e); return; } }; let all_inst_ids = all_local_configs .iter() .map(|cfg| cfg.network_instance_id.clone()) .collect::>(); let should_delete_ids = running_inst_ids .iter() .chain(all_inst_ids.iter()) .filter(|inst_id| !should_be_alive_inst_ids.contains(*inst_id)) .filter_map(|inst_id| uuid::Uuid::parse_str(inst_id).ok()) .map(Into::into) .collect::>(); if !should_delete_ids.is_empty() { let ret = rpc_client .delete_network_instance( BaseController::default(), easytier::proto::api::manage::DeleteNetworkInstanceRequest { inst_ids: should_delete_ids, }, ) .await; tracing::info!( ?user_id, "Clean non-web-managed network instances on start: {:?}, user_token: {:?}", ret, req.user_token ); has_failed |= ret.is_err(); } if !has_failed { cleaned_web_managed_instances = true; last_desired_inst_ids = Some(should_be_alive_inst_ids.clone()); } } for c in local_configs { if running_inst_ids.contains(&c.network_instance_id) { continue; } let ret = rpc_client .run_network_instance( BaseController::default(), RunNetworkInstanceRequest { inst_id: Some(c.network_instance_id.clone().into()), config: Some( serde_json::from_str::(&c.network_config).unwrap(), ), overwrite: false, }, ) .await; tracing::info!( ?user_id, "Run network instance: {:?}, user_token: {:?}", ret, req.user_token ); has_failed |= ret.is_err(); } if !has_failed { last_desired_inst_ids = Some(should_be_alive_inst_ids); } } } pub fn is_running(&self) -> bool { self.rpc_mgr.is_running() } pub async fn stop(&self) { self.rpc_mgr.stop().await; } pub fn data(&self) -> SharedSessionData { self.data.clone() } pub fn scoped_client(&self) -> F::ClientImpl { self.rpc_mgr .rpc_client() .scoped_client::(1, 1, "".to_string()) } pub fn scoped_rpc_client(&self) -> SessionRpcClient { self.scoped_client::>() } pub async fn get_token(&self) -> Option { self.data.read().await.storage_token.clone() } pub async fn get_heartbeat_req(&self) -> Option { self.data.read().await.req() } } #[cfg(test)] mod tests { use easytier::rpc_service::remote_client::{ListNetworkProps, Storage as _}; use serde_json::json; use super::{super::storage::Storage, *}; #[tokio::test] async fn reconcile_managed_network_configs_upserts_and_deletes_exact_set() { let storage = Storage::new(crate::db::Db::memory_db().await); let user_id = storage .db() .auto_create_user("webhook-user") .await .unwrap() .id; let machine_id = uuid::Uuid::new_v4(); let keep_id = uuid::Uuid::new_v4(); let stale_id = uuid::Uuid::new_v4(); let new_id = uuid::Uuid::new_v4(); storage .db() .insert_or_update_user_network_config( (user_id, machine_id), keep_id, NetworkConfig { network_name: Some("old-name".to_string()), ..Default::default() }, ) .await .unwrap(); storage .db() .insert_or_update_user_network_config( (user_id, machine_id), stale_id, NetworkConfig { network_name: Some("stale".to_string()), ..Default::default() }, ) .await .unwrap(); SessionRpcService::reconcile_managed_network_configs( &storage, user_id, machine_id, vec![ crate::webhook::ManagedNetworkConfig { instance_id: keep_id.to_string(), network_config: json!({ "instance_id": keep_id.to_string(), "network_name": "updated-name" }), }, crate::webhook::ManagedNetworkConfig { instance_id: new_id.to_string(), network_config: json!({ "instance_id": new_id.to_string(), "network_name": "new-name" }), }, ], ) .await .unwrap(); let configs = storage .db() .list_network_configs((user_id, machine_id), ListNetworkProps::All) .await .unwrap(); let config_ids = configs .iter() .map(|cfg| cfg.network_instance_id.clone()) .collect::>(); assert_eq!(configs.len(), 2); assert!(config_ids.contains(&keep_id.to_string())); assert!(config_ids.contains(&new_id.to_string())); assert!(!config_ids.contains(&stale_id.to_string())); let updated_keep = storage .db() .get_network_config((user_id, machine_id), &keep_id.to_string()) .await .unwrap() .unwrap(); let updated_keep_config: NetworkConfig = serde_json::from_str(&updated_keep.network_config).unwrap(); assert_eq!( updated_keep_config.network_name.as_deref(), Some("updated-name") ); } }