pub mod session; pub mod storage; use std::sync::{ Arc, atomic::{AtomicU32, Ordering}, }; use dashmap::DashMap; use easytier::{ proto::{ api::manage::WebClientService, rpc_types::controller::BaseController, web::HeartbeatRequest, }, rpc_service::remote_client::{self, RemoteClientManager}, tunnel::TunnelListener, web_client::security, }; use maxminddb::geoip2; use session::{Location, Session}; use storage::{Storage, StorageToken}; use crate::FeatureFlags; use crate::webhook::SharedWebhookConfig; use tokio::task::JoinSet; use crate::db::{Db, UserIdInDb, entity::user_running_network_configs}; #[derive(rust_embed::Embed)] #[folder = "resources/"] #[include = "geoip2-cn.mmdb"] struct GeoipDb; fn load_geoip_db(geoip_db: Option) -> Option>> { if let Some(path) = geoip_db { match maxminddb::Reader::open_readfile(&path) { Ok(reader) => { tracing::info!("Successfully loaded GeoIP2 database from {}", path); Some(reader) } Err(err) => { tracing::debug!("Failed to load GeoIP2 database from {}: {}", path, err); None } } } else { let db = GeoipDb::get("geoip2-cn.mmdb").unwrap(); let reader = maxminddb::Reader::from_source(db.data.to_vec()).ok()?; tracing::info!("Successfully loaded GeoIP2 database from embedded file"); Some(reader) } } #[derive(Debug)] pub struct ClientManager { tasks: JoinSet<()>, listeners_cnt: Arc, client_sessions: Arc>>, storage: Storage, feature_flags: Arc, webhook_config: SharedWebhookConfig, geoip_db: Arc>>>, } impl ClientManager { pub fn new( db: Db, geoip_db: Option, feature_flags: Arc, webhook_config: SharedWebhookConfig, ) -> Self { let client_sessions = Arc::new(DashMap::new()); let sessions: Arc>> = client_sessions.clone(); let mut tasks = JoinSet::new(); tasks.spawn(async move { loop { tokio::time::sleep(std::time::Duration::from_secs(15)).await; sessions.retain(|_, session| session.is_running()); } }); ClientManager { tasks, listeners_cnt: Arc::new(AtomicU32::new(0)), client_sessions, storage: Storage::new(db), feature_flags, webhook_config, geoip_db: Arc::new(load_geoip_db(geoip_db)), } } pub async fn add_listener( &mut self, mut listener: L, ) -> Result<(), anyhow::Error> { listener.listen().await?; self.listeners_cnt.fetch_add(1, Ordering::Relaxed); let sessions = self.client_sessions.clone(); let storage = self.storage.weak_ref(); let listeners_cnt = self.listeners_cnt.clone(); let geoip_db = self.geoip_db.clone(); let feature_flags = self.feature_flags.clone(); let webhook_config = self.webhook_config.clone(); self.tasks.spawn(async move { while let Ok(tunnel) = listener.accept().await { let (tunnel, secure) = match security::accept_or_upgrade_server_tunnel(tunnel).await { Ok(v) => v, Err(error) => { tracing::warn!(%error, "failed to accept secure tunnel, dropping connection"); continue; } }; let info = tunnel.info().unwrap(); let client_url: url::Url = info.remote_addr.unwrap().into(); let location = Self::lookup_location(&client_url, geoip_db.clone()); tracing::info!( "New session from {:?}, secure: {}, location: {:?}", client_url, secure, location ); let mut session = Session::new( storage.clone(), client_url.clone(), location, feature_flags.clone(), webhook_config.clone(), ); session.serve(tunnel).await; sessions.insert(client_url, Arc::new(session)); } listeners_cnt.fetch_sub(1, Ordering::Relaxed); }); Ok(()) } pub fn is_running(&self) -> bool { self.listeners_cnt.load(Ordering::Relaxed) > 0 } pub async fn list_sessions(&self) -> Vec { let sessions = self .client_sessions .iter() .map(|item| item.value().clone()) .collect::>(); let mut ret: Vec = vec![]; for s in sessions { if let Some(t) = s.get_token().await { ret.push(t); } } ret } pub fn get_session_by_machine_id( &self, user_id: UserIdInDb, machine_id: &uuid::Uuid, ) -> Option> { let c_url = self .storage .get_client_url_by_machine_id(user_id, machine_id)?; self.client_sessions .get(&c_url) .map(|item| item.value().clone()) } pub async fn disconnect_session_by_machine_id( &self, user_id: UserIdInDb, machine_id: &uuid::Uuid, ) -> bool { let Some(client_url) = self .storage .get_client_url_by_machine_id(user_id, machine_id) else { return false; }; let Some((_, session)) = self.client_sessions.remove(&client_url) else { return false; }; session.stop().await; true } pub async fn list_machine_by_user_id(&self, user_id: UserIdInDb) -> Vec { self.storage.list_user_clients(user_id) } pub async fn get_heartbeat_requests(&self, client_url: &url::Url) -> Option { let s = self.client_sessions.get(client_url)?.clone(); s.data().read().await.req() } pub async fn get_machine_location(&self, client_url: &url::Url) -> Option { let s = self.client_sessions.get(client_url)?.clone(); s.data().read().await.location().cloned() } fn db(&self) -> &Db { self.storage.db() } fn lookup_location( client_url: &url::Url, geoip_db: Arc>>>, ) -> Option { let host = client_url.host_str()?; let ip: std::net::IpAddr = if let Ok(ip) = host.parse() { ip } else { tracing::debug!("Failed to parse host as IP address: {}", host); return None; }; // Skip lookup for private/special IPs let is_private = match ip { std::net::IpAddr::V4(ipv4) => { ipv4.is_private() || ipv4.is_loopback() || ipv4.is_unspecified() } std::net::IpAddr::V6(ipv6) => ipv6.is_loopback() || ipv6.is_unspecified(), }; if is_private { tracing::debug!("Skipping GeoIP lookup for special IP: {}", ip); let location = Location { country: "本地网络".to_string(), city: None, region: None, }; return Some(location); } let location = if let Some(db) = &*geoip_db { match db.lookup::(ip) { Ok(city) => { let country = city .country .and_then(|c| c.names) .and_then(|n| { n.get("zh-CN") .or_else(|| n.get("en")) .map(|s| s.to_string()) }) .unwrap_or_else(|| "海外".to_string()); let city_name = city.city.and_then(|c| c.names).and_then(|n| { n.get("zh-CN") .or_else(|| n.get("en")) .map(|s| s.to_string()) }); let region = city.subdivisions.map(|r| { r.iter() .filter_map(|x| x.names.as_ref()) .filter_map(|x| x.get("zh-CN").or_else(|| x.get("en"))) .map(|x| x.to_string()) .collect::>() .join(",") }); Location { country, city: city_name, region, } } Err(err) => { tracing::debug!("GeoIP lookup failed for {}: {}", ip, err); Location { country: "海外".to_string(), city: None, region: None, } } } } else { tracing::debug!( "GeoIP database not available, using default location for {}", ip ); Location { country: "海外".to_string(), city: None, region: None, } }; Some(location) } } impl RemoteClientManager< (UserIdInDb, uuid::Uuid), user_running_network_configs::Model, sea_orm::DbErr, > for ClientManager { fn get_rpc_client( &self, (user_id, machine_id): (UserIdInDb, uuid::Uuid), ) -> Option + Send>> { let s = self.get_session_by_machine_id(user_id, &machine_id)?; Some(s.scoped_rpc_client()) } fn get_storage( &self, ) -> &impl remote_client::Storage< (UserIdInDb, uuid::Uuid), user_running_network_configs::Model, sea_orm::DbErr, > { self.storage.db() } } #[cfg(test)] mod tests { use std::{sync::Arc, time::Duration}; use easytier::{ instance_manager::NetworkInstanceManager, tunnel::{ common::tests::wait_for_condition, udp::{UdpTunnelConnector, UdpTunnelListener}, }, web_client::WebClient, }; use sqlx::Executor; use crate::{FeatureFlags, client_manager::ClientManager, db::Db}; #[tokio::test] async fn test_client() { let listener = UdpTunnelListener::new("udp://0.0.0.0:54333".parse().unwrap()); let mut mgr = ClientManager::new( Db::memory_db().await, None, Arc::new(FeatureFlags::default()), Arc::new(crate::webhook::WebhookConfig::new( None, None, None, None, None, )), ); mgr.add_listener(Box::new(listener)).await.unwrap(); mgr.db() .inner() .execute("INSERT INTO users (username, password) VALUES ('test', 'test')") .await .unwrap(); let connector = UdpTunnelConnector::new("udp://127.0.0.1:54333".parse().unwrap()); let _c = WebClient::new( connector, "test", "test", false, Arc::new(NetworkInstanceManager::new()), None, ); wait_for_condition( || async { !mgr.client_sessions.is_empty() }, Duration::from_secs(12), ) .await; let req = tokio::time::timeout(Duration::from_secs(12), async { loop { let sessions = mgr .client_sessions .iter() .map(|item| item.value().clone()) .collect::>(); if sessions.is_empty() { tokio::time::sleep(Duration::from_millis(100)).await; continue; } let mut found_req = None; for session in sessions { if let Some(req) = session.data().read().await.req() { found_req = Some(req); break; } } if let Some(req) = found_req { break req; } tokio::time::sleep(Duration::from_millis(100)).await; } }) .await .unwrap(); println!("{:?}", req); println!("{:?}", mgr); } }