use std::{ net::{Ipv4Addr, SocketAddr}, path::PathBuf, sync::{Arc, Mutex}, }; use anyhow::Context; use serde::{Deserialize, Serialize}; use crate::tunnel::generate_digest_from_str; #[auto_impl::auto_impl(Box, &)] pub trait ConfigLoader: Send + Sync { fn get_id(&self) -> uuid::Uuid; fn set_id(&self, id: uuid::Uuid); fn get_hostname(&self) -> String; fn set_hostname(&self, name: Option); fn get_inst_name(&self) -> String; fn set_inst_name(&self, name: String); fn get_netns(&self) -> Option; fn set_netns(&self, ns: Option); fn get_ipv4(&self) -> Option; fn set_ipv4(&self, addr: Option); fn get_dhcp(&self) -> bool; fn set_dhcp(&self, dhcp: bool); fn add_proxy_cidr(&self, cidr: cidr::IpCidr); fn remove_proxy_cidr(&self, cidr: cidr::IpCidr); fn get_proxy_cidrs(&self) -> Vec; fn get_network_identity(&self) -> NetworkIdentity; fn set_network_identity(&self, identity: NetworkIdentity); fn get_listener_uris(&self) -> Vec; fn get_file_logger_config(&self) -> FileLoggerConfig; fn set_file_logger_config(&self, config: FileLoggerConfig); fn get_console_logger_config(&self) -> ConsoleLoggerConfig; fn set_console_logger_config(&self, config: ConsoleLoggerConfig); fn get_peers(&self) -> Vec; fn set_peers(&self, peers: Vec); fn get_listeners(&self) -> Vec; fn set_listeners(&self, listeners: Vec); fn get_rpc_portal(&self) -> Option; fn set_rpc_portal(&self, addr: SocketAddr); fn get_vpn_portal_config(&self) -> Option; fn set_vpn_portal_config(&self, config: VpnPortalConfig); fn get_flags(&self) -> Flags; fn set_flags(&self, flags: Flags); fn get_exit_nodes(&self) -> Vec; fn set_exit_nodes(&self, nodes: Vec); fn dump(&self) -> String; } pub type NetworkSecretDigest = [u8; 32]; #[derive(Debug, Clone, Deserialize, Serialize, Default)] pub struct NetworkIdentity { pub network_name: String, pub network_secret: Option, #[serde(skip)] pub network_secret_digest: Option, } impl PartialEq for NetworkIdentity { fn eq(&self, other: &Self) -> bool { if self.network_name != other.network_name { return false; } if self.network_secret.is_some() && other.network_secret.is_some() && self.network_secret != other.network_secret { return false; } if self.network_secret_digest.is_some() && other.network_secret_digest.is_some() && self.network_secret_digest != other.network_secret_digest { return false; } return true; } } impl NetworkIdentity { pub fn new(network_name: String, network_secret: String) -> Self { let mut network_secret_digest = [0u8; 32]; generate_digest_from_str(&network_name, &network_secret, &mut network_secret_digest); NetworkIdentity { network_name, network_secret: Some(network_secret), network_secret_digest: Some(network_secret_digest), } } pub fn default() -> Self { Self::new("default".to_string(), "".to_string()) } } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] pub struct PeerConfig { pub uri: url::Url, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] pub struct NetworkConfig { pub cidr: String, pub allow: Option>, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)] pub struct FileLoggerConfig { pub level: Option, pub file: Option, pub dir: Option, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)] pub struct ConsoleLoggerConfig { pub level: Option, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] pub struct VpnPortalConfig { pub client_cidr: cidr::Ipv4Cidr, pub wireguard_listen: SocketAddr, } // Flags is used to control the behavior of the program #[derive(derivative::Derivative, Deserialize, Serialize)] #[derivative(Debug, Clone, PartialEq, Default)] pub struct Flags { #[derivative(Default(value = "\"tcp\".to_string()"))] pub default_protocol: String, #[derivative(Default(value = "true"))] pub enable_encryption: bool, #[derivative(Default(value = "true"))] pub enable_ipv6: bool, #[derivative(Default(value = "1380"))] pub mtu: u16, #[derivative(Default(value = "true"))] pub latency_first: bool, #[derivative(Default(value = "false"))] pub enable_exit_node: bool, #[derivative(Default(value = "false"))] pub no_tun: bool, #[derivative(Default(value = "false"))] pub use_smoltcp: bool, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] struct Config { netns: Option, hostname: Option, instance_name: Option, instance_id: Option, ipv4: Option, dhcp: Option, network_identity: Option, listeners: Option>, exit_nodes: Option>, peer: Option>, proxy_network: Option>, file_logger: Option, console_logger: Option, rpc_portal: Option, vpn_portal_config: Option, flags: Option, } #[derive(Debug, Clone)] pub struct TomlConfigLoader { config: Arc>, } impl Default for TomlConfigLoader { fn default() -> Self { TomlConfigLoader::new_from_str("").unwrap() } } impl TomlConfigLoader { pub fn new_from_str(config_str: &str) -> Result { let config = toml::de::from_str::(config_str).with_context(|| { format!( "failed to parse config file: {}\n{}", config_str, config_str ) })?; Ok(TomlConfigLoader { config: Arc::new(Mutex::new(config)), }) } pub fn new(config_path: &PathBuf) -> Result { let config_str = std::fs::read_to_string(config_path) .with_context(|| format!("failed to read config file: {:?}", config_path))?; let ret = Self::new_from_str(&config_str)?; let old_ns = ret.get_network_identity(); ret.set_network_identity(NetworkIdentity::new( old_ns.network_name, old_ns.network_secret.unwrap_or_default(), )); Ok(ret) } } impl ConfigLoader for TomlConfigLoader { fn get_inst_name(&self) -> String { self.config .lock() .unwrap() .instance_name .clone() .unwrap_or("default".to_string()) } fn set_inst_name(&self, name: String) { self.config.lock().unwrap().instance_name = Some(name); } fn get_hostname(&self) -> String { let hostname = self.config.lock().unwrap().hostname.clone(); match hostname { Some(hostname) => { if !hostname.is_empty() { let mut name = hostname .chars() .filter(|c| c.is_ascii_alphanumeric() || *c == '-' || *c == '_') .take(32) .collect::(); if name.len() > 32 { name = name.chars().take(32).collect::(); } if hostname != name { self.set_hostname(Some(name.clone())); } name } else { self.set_hostname(None); gethostname::gethostname().to_string_lossy().to_string() } } None => gethostname::gethostname().to_string_lossy().to_string(), } } fn set_hostname(&self, name: Option) { self.config.lock().unwrap().hostname = name; } fn get_netns(&self) -> Option { self.config.lock().unwrap().netns.clone() } fn set_netns(&self, ns: Option) { self.config.lock().unwrap().netns = ns; } fn get_ipv4(&self) -> Option { let locked_config = self.config.lock().unwrap(); locked_config .ipv4 .as_ref() .map(|s| s.parse().ok()) .flatten() } fn set_ipv4(&self, addr: Option) { self.config.lock().unwrap().ipv4 = if let Some(addr) = addr { Some(addr.to_string()) } else { None }; } fn get_dhcp(&self) -> bool { self.config.lock().unwrap().dhcp.unwrap_or_default() } fn set_dhcp(&self, dhcp: bool) { self.config.lock().unwrap().dhcp = Some(dhcp); } fn add_proxy_cidr(&self, cidr: cidr::IpCidr) { let mut locked_config = self.config.lock().unwrap(); if locked_config.proxy_network.is_none() { locked_config.proxy_network = Some(vec![]); } let cidr_str = cidr.to_string(); // insert if no duplicate if !locked_config .proxy_network .as_ref() .unwrap() .iter() .any(|c| c.cidr == cidr_str) { locked_config .proxy_network .as_mut() .unwrap() .push(NetworkConfig { cidr: cidr_str, allow: None, }); } } fn remove_proxy_cidr(&self, cidr: cidr::IpCidr) { let mut locked_config = self.config.lock().unwrap(); if let Some(proxy_cidrs) = &mut locked_config.proxy_network { let cidr_str = cidr.to_string(); proxy_cidrs.retain(|c| c.cidr != cidr_str); } } fn get_proxy_cidrs(&self) -> Vec { self.config .lock() .unwrap() .proxy_network .as_ref() .map(|v| { v.iter() .map(|c| c.cidr.parse().unwrap()) .collect::>() }) .unwrap_or_default() } fn get_id(&self) -> uuid::Uuid { let mut locked_config = self.config.lock().unwrap(); if locked_config.instance_id.is_none() { let id = uuid::Uuid::new_v4(); locked_config.instance_id = Some(id); id } else { locked_config.instance_id.as_ref().unwrap().clone() } } fn set_id(&self, id: uuid::Uuid) { self.config.lock().unwrap().instance_id = Some(id); } fn get_network_identity(&self) -> NetworkIdentity { self.config .lock() .unwrap() .network_identity .clone() .unwrap_or_else(NetworkIdentity::default) } fn set_network_identity(&self, identity: NetworkIdentity) { self.config.lock().unwrap().network_identity = Some(identity); } fn get_listener_uris(&self) -> Vec { self.config .lock() .unwrap() .listeners .clone() .unwrap_or_default() } fn get_file_logger_config(&self) -> FileLoggerConfig { self.config .lock() .unwrap() .file_logger .clone() .unwrap_or_default() } fn set_file_logger_config(&self, config: FileLoggerConfig) { self.config.lock().unwrap().file_logger = Some(config); } fn get_console_logger_config(&self) -> ConsoleLoggerConfig { self.config .lock() .unwrap() .console_logger .clone() .unwrap_or_default() } fn set_console_logger_config(&self, config: ConsoleLoggerConfig) { self.config.lock().unwrap().console_logger = Some(config); } fn get_peers(&self) -> Vec { self.config.lock().unwrap().peer.clone().unwrap_or_default() } fn set_peers(&self, peers: Vec) { self.config.lock().unwrap().peer = Some(peers); } fn get_listeners(&self) -> Vec { self.config .lock() .unwrap() .listeners .clone() .unwrap_or_default() } fn set_listeners(&self, listeners: Vec) { self.config.lock().unwrap().listeners = Some(listeners); } fn get_rpc_portal(&self) -> Option { self.config.lock().unwrap().rpc_portal } fn set_rpc_portal(&self, addr: SocketAddr) { self.config.lock().unwrap().rpc_portal = Some(addr); } fn get_vpn_portal_config(&self) -> Option { self.config.lock().unwrap().vpn_portal_config.clone() } fn set_vpn_portal_config(&self, config: VpnPortalConfig) { self.config.lock().unwrap().vpn_portal_config = Some(config); } fn get_flags(&self) -> Flags { self.config .lock() .unwrap() .flags .clone() .unwrap_or_default() } fn set_flags(&self, flags: Flags) { self.config.lock().unwrap().flags = Some(flags); } fn get_exit_nodes(&self) -> Vec { self.config .lock() .unwrap() .exit_nodes .clone() .unwrap_or_default() } fn set_exit_nodes(&self, nodes: Vec) { self.config.lock().unwrap().exit_nodes = Some(nodes); } fn dump(&self) -> String { toml::to_string_pretty(&*self.config.lock().unwrap()).unwrap() } } #[cfg(test)] pub mod tests { use super::*; #[tokio::test] async fn full_example_test() { let config_str = r#" instance_name = "default" instance_id = "87ede5a2-9c3d-492d-9bbe-989b9d07e742" ipv4 = "10.144.144.10" listeners = [ "tcp://0.0.0.0:11010", "udp://0.0.0.0:11010" ] [network_identity] network_name = "default" network_secret = "" [[peer]] uri = "tcp://public.kkrainbow.top:11010" [[peer]] uri = "udp://192.168.94.33:11010" [[proxy_network]] cidr = "10.147.223.0/24" allow = ["tcp", "udp", "icmp"] [[proxy_network]] cidr = "10.1.1.0/24" allow = ["tcp", "icmp"] [file_logger] level = "info" file = "easytier" dir = "/tmp/easytier" [console_logger] level = "warn" "#; let ret = TomlConfigLoader::new_from_str(config_str); if let Err(e) = &ret { println!("{}", e); } else { println!("{:?}", ret.as_ref().unwrap()); } assert!(ret.is_ok()); let ret = ret.unwrap(); assert_eq!("10.144.144.10", ret.get_ipv4().unwrap().to_string()); assert_eq!( vec!["tcp://0.0.0.0:11010", "udp://0.0.0.0:11010"], ret.get_listener_uris() .iter() .map(|u| u.to_string()) .collect::>() ); println!("{}", ret.dump()); } }