use std::{ collections::{HashMap, HashSet}, net::{IpAddr, SocketAddr}, str::FromStr as _, sync::Arc, time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; use crate::common::{config::ConfigLoader, global_ctx::ArcGlobalCtx, token_bucket::TokenBucket}; use crate::proto::acl::*; use anyhow::Context as _; use dashmap::DashMap; use tokio::task::JoinSet; // Performance-optimized key for rate limiting to avoid string allocations #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct RateLimitKey { pub chain_type: ChainType, pub rule_priority: u32, } impl RateLimitKey { pub fn new(chain_type: ChainType, rule_priority: u32) -> Self { Self { chain_type, rule_priority, } } } /// Value wrapper for rate limiters with last update timestamp pub struct RateLimitValue { pub token_bucket: Arc, pub last_update: Instant, } // Performance-optimized rule identifier to avoid string allocations #[derive(Debug, Clone, PartialEq, Eq)] pub enum RuleId { Priority(u32), Stateful(u32), Default, } impl RuleId { /// Convert to string only when actually needed (lazy evaluation) pub fn to_string_cached(&self) -> String { match self { RuleId::Priority(p) => p.to_string(), RuleId::Stateful(p) => format!("stateful-{}", p), RuleId::Default => "default".to_string(), } } /// Get string representation for logging (optimized for hot path) pub fn as_str(&self) -> String { self.to_string_cached() } } // Fast lookup structures for performance optimization #[derive(Debug, Clone)] pub struct FastLookupRule { pub priority: u32, pub protocol: Protocol, pub src_ip_ranges: Vec, pub dst_ip_ranges: Vec, pub src_port_ranges: Vec<(u16, u16)>, pub dst_port_ranges: Vec<(u16, u16)>, pub source_groups: HashSet, pub destination_groups: HashSet, pub action: Action, pub enabled: bool, pub stateful: bool, pub rate_limit: u32, pub burst_limit: u32, pub rule_stats: Arc, } // Cache key combining packet info and chain type #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct AclCacheKey { pub chain_type: ChainType, pub protocol: Protocol, pub src_ip: IpAddr, pub dst_ip: IpAddr, pub src_port: u16, pub dst_port: u16, pub src_groups: Arc>, pub dst_groups: Arc>, } impl AclCacheKey { pub fn from_packet_info(packet_info: &PacketInfo, chain_type: ChainType) -> Self { Self { chain_type, protocol: packet_info.protocol, src_ip: packet_info.src_ip, dst_ip: packet_info.dst_ip, src_port: packet_info.src_port.unwrap_or(0), dst_port: packet_info.dst_port.unwrap_or(0), src_groups: packet_info.src_groups.clone(), dst_groups: packet_info.dst_groups.clone(), } } } // Cache entry with timestamp for LRU cleanup #[derive(Debug, Clone)] pub struct AclCacheEntry { pub action: Action, pub matched_rule: RuleId, pub last_access: std::time::Instant, // New fields to track rule characteristics for proper cache behavior pub conn_track_key: Option, pub rate_limit_keys: Vec, pub chain_type: ChainType, pub acl_result: Option, pub rule_stats_vec: Vec>, } // Packet info extracted for ACL processing #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct PacketInfo { pub src_ip: IpAddr, pub dst_ip: IpAddr, pub src_port: Option, pub dst_port: Option, pub protocol: Protocol, pub packet_size: usize, pub src_groups: Arc>, pub dst_groups: Arc>, } // ACL processing result #[derive(Debug, Clone)] pub struct AclResult { pub action: Action, pub matched_rule: Option, pub should_log: bool, pub log_context: Option, } impl AclResult { /// Get matched rule as string (lazy evaluation) pub fn matched_rule_string(&self) -> Option { self.matched_rule.as_ref().map(|r| r.to_string_cached()) } /// Get matched rule as string reference for logging (compatibility method) pub fn matched_rule_str(&self) -> Option { self.matched_rule.as_ref().map(|r| r.as_str()) } } // Context for lazy log message construction #[derive(Debug, Clone)] pub enum AclLogContext { StatefulMatch { src_ip: IpAddr, dst_ip: IpAddr, }, RuleMatch { src_ip: IpAddr, dst_ip: IpAddr, action: Action, }, DefaultDrop, DefaultAllow, UnsupportedChainType, RateLimitDrop, } impl AclLogContext { pub fn to_message(&self) -> String { match self { AclLogContext::StatefulMatch { src_ip, dst_ip } => { format!("Stateful match: {} -> {}", src_ip, dst_ip) } AclLogContext::RuleMatch { src_ip, dst_ip, action, } => { format!("Rule match: {} -> {} action: {:?}", src_ip, dst_ip, action) } AclLogContext::DefaultDrop => "No matching rule, default drop".to_string(), AclLogContext::DefaultAllow => "No matching rule, default allow".to_string(), AclLogContext::UnsupportedChainType => "Unsupported chain type".to_string(), AclLogContext::RateLimitDrop => "Rate limit drop".to_string(), } } } pub type SharedState = ( Arc>, Arc>, Arc>, ); // High-performance ACL processor - No more internal locks! pub struct AclProcessor { // Immutable rule vectors - no locks needed since they're never modified after creation inbound_rules: Vec, outbound_rules: Vec, forward_rules: Vec, default_inbound_action: Action, default_outbound_action: Action, default_forward_action: Action, default_rule_stats: Arc, // Connection tracking table - shared across different processor instances if needed conn_track: Arc>, // Rate limiting buckets per rule using TokenBucket with optimized keys rate_limiters: Arc>, // Rule lookup cache with LRU cleanup rule_cache: Arc>, cache_max_size: usize, cache_cleanup_interval: Duration, // Statistics stats: Arc>, tasks: JoinSet<()>, } impl AclProcessor { /// Create a new ACL processor with pre-built immutable rules /// This is the main constructor that should be used pub fn new(acl_config: Acl) -> Self { Self::new_with_shared_state(acl_config, None, None, None) } /// Create a new ACL processor while preserving connection tracking and rate limiting state /// This is useful for hot reloading where you want to preserve established connections pub fn new_with_shared_state( acl_config: Acl, conn_track: Option>>, rate_limiters: Option>>, stats: Option>>, ) -> Self { let (inbound_rules, outbound_rules, forward_rules) = Self::build_rules(&acl_config); let (default_inbound_action, default_outbound_action, default_forward_action) = Self::build_default_actions(&acl_config); let tasks = JoinSet::new(); let mut processor = Self { inbound_rules, outbound_rules, forward_rules, default_inbound_action, default_outbound_action, default_forward_action, default_rule_stats: Arc::new(RuleStats { rule: None, stat: Some(StatItem { packet_count: 0, byte_count: 0, }), }), conn_track: conn_track.unwrap_or_else(|| Arc::new(DashMap::new())), rate_limiters: rate_limiters.unwrap_or_else(|| Arc::new(DashMap::new())), rule_cache: Arc::new(DashMap::new()), // Always start with fresh cache cache_max_size: 1024, // Limit cache to 1k entries cache_cleanup_interval: Duration::from_secs(20), // Cleanup every 5 minutes stats: stats.unwrap_or_else(|| Arc::new(DashMap::new())), tasks, }; processor.start_cache_cleanup_task(); processor } fn build_default_actions(acl_config: &Acl) -> (Action, Action, Action) { let default_inbound_action = acl_config .acl_v1 .as_ref() .and_then(|v1| { v1.chains .iter() .find(|c| c.chain_type == ChainType::Inbound as i32) }) .map(|c| c.default_action()) .unwrap_or(Action::Allow); let default_outbound_action = acl_config .acl_v1 .as_ref() .and_then(|v1| { v1.chains .iter() .find(|c| c.chain_type == ChainType::Outbound as i32) }) .map(|c| c.default_action()) .unwrap_or(Action::Allow); let default_forward_action = acl_config .acl_v1 .as_ref() .and_then(|v1| { v1.chains .iter() .find(|c| c.chain_type == ChainType::Forward as i32) }) .map(|c| c.default_action()) .unwrap_or(Action::Allow); ( default_inbound_action, default_outbound_action, default_forward_action, ) } /// Build all rule vectors from configuration fn build_rules( acl_config: &Acl, ) -> ( Vec, Vec, Vec, ) { let mut inbound_rules = Vec::new(); let mut outbound_rules = Vec::new(); let mut forward_rules = Vec::new(); // Build new rule vectors if let Some(ref acl_v1) = acl_config.acl_v1 { for chain in &acl_v1.chains { if !chain.enabled { continue; } let mut rules = chain .rules .iter() .filter(|rule| rule.enabled) .map(Self::convert_to_fast_lookup_rule) .collect::>(); // Sort by priority (higher priority first) rules.sort_by(|a, b| b.priority.cmp(&a.priority)); match chain.chain_type() { ChainType::Inbound => inbound_rules.extend(rules), ChainType::Outbound => outbound_rules.extend(rules), ChainType::Forward => forward_rules.extend(rules), _ => {} } } } tracing::info!( "ACL rules built: {} inbound, {} outbound, {} forward", inbound_rules.len(), outbound_rules.len(), forward_rules.len(), ); (inbound_rules, outbound_rules, forward_rules) } /// Start periodic cache cleanup task fn start_cache_cleanup_task(&mut self) { let rate_limiters = self.rate_limiters.clone(); let rule_cache = self.rule_cache.clone(); let cache_max_size = self.cache_max_size; let cleanup_interval = self.cache_cleanup_interval; self.tasks.spawn(async move { let mut interval = tokio::time::interval(cleanup_interval); loop { interval.tick().await; Self::cleanup_cache(&rule_cache, cache_max_size); rule_cache.shrink_to_fit(); rate_limiters.retain(|_, v| v.last_update.elapsed() < cleanup_interval); rate_limiters.shrink_to_fit(); } }); let conn_track = self.conn_track.clone(); self.tasks.spawn(async move { let mut interval = tokio::time::interval(cleanup_interval); loop { interval.tick().await; Self::cleanup_expired_connections(conn_track.clone(), 60); conn_track.shrink_to_fit(); } }); } /// Clean up cache using LRU strategy fn cleanup_cache(cache: &DashMap, max_size: usize) { // remove cache not be used in last 15 second let expired_timepoint = Instant::now() .checked_sub(Duration::from_secs(15)) .unwrap_or(Instant::now()); cache.retain(|_, entry| entry.last_access > expired_timepoint); let current_size = cache.len(); if current_size <= max_size { return; } // Remove oldest entries (LRU cleanup) let mut entries: Vec<(AclCacheKey, std::time::Instant)> = cache .iter() .map(|entry| (entry.key().clone(), entry.value().last_access)) .collect(); // Sort by last_access (oldest first) entries.sort_by_key(|(_, last_access)| *last_access); // Remove oldest 20% of entries let to_remove = current_size - max_size + (max_size / 5); for (key, _) in entries.into_iter().take(to_remove) { cache.remove(&key); } tracing::debug!( "Cache cleanup completed: removed {} entries, current size: {}", to_remove, cache.len() ); } pub fn process_packet_with_cache_entry( &self, packet_info: &PacketInfo, cache_entry: &AclCacheEntry, ) -> AclResult { for rate_limit_key in cache_entry.rate_limit_keys.iter() { // bucket should already be created, so rate and burst are not important if !self.check_rate_limit(rate_limit_key, 1, 1, false) { return AclResult { action: Action::Drop, matched_rule: Some(cache_entry.matched_rule.clone()), should_log: false, log_context: Some(AclLogContext::RateLimitDrop), }; } } if let Some(conn_track_key) = cache_entry.conn_track_key.as_ref() { self.check_connection_state(conn_track_key, packet_info); } self.inc_cache_entry_stats(cache_entry, packet_info); cache_entry.acl_result.clone().unwrap() } fn inc_cache_entry_stats(&self, cache_entry: &AclCacheEntry, packet_info: &PacketInfo) { for rule_stats in cache_entry.rule_stats_vec.iter() { // Use unsafe code to mutate the contents behind the Arc let stat_ptr = rule_stats.stat.as_ref().unwrap() as *const StatItem as *mut StatItem; unsafe { (*stat_ptr).packet_count += 1; (*stat_ptr).byte_count += packet_info.packet_size as u64; } } } pub fn get_rules_stats(&self) -> Vec { let mut stats: Vec = Vec::new(); for rule in self.inbound_rules.iter() { stats.push((*rule.rule_stats).clone()); } for rule in self.outbound_rules.iter() { stats.push((*rule.rule_stats).clone()); } for rule in self.forward_rules.iter() { stats.push((*rule.rule_stats).clone()); } stats } /// Process a packet through ACL rules - Now lock-free! pub fn process_packet(&self, packet_info: &PacketInfo, chain_type: ChainType) -> AclResult { // Check cache first for performance let cache_key = AclCacheKey::from_packet_info(packet_info, chain_type); // If cache hit and can skip checks, return cached result if let Some(mut cached) = self.rule_cache.get_mut(&cache_key) { // Update last access time for LRU cached.last_access = Instant::now(); self.increment_stat(AclStatKey::CacheHits); return self.process_packet_with_cache_entry(packet_info, &cached); } // Direct access to rules - no locks needed! let rules = match chain_type { ChainType::Inbound => &self.inbound_rules, ChainType::Outbound => &self.outbound_rules, ChainType::Forward => &self.forward_rules, _ => { return AclResult { action: Action::Drop, matched_rule: Some(RuleId::Default), should_log: false, log_context: Some(AclLogContext::UnsupportedChainType), } } }; let mut cache_entry = AclCacheEntry { action: Action::Allow, matched_rule: RuleId::Default, last_access: Instant::now(), conn_track_key: None, rate_limit_keys: vec![], chain_type, acl_result: None, rule_stats_vec: vec![], }; // Process rules in priority order for rule in rules.iter() { if !rule.enabled || !self.rule_matches(rule, packet_info) { continue; } // Check rate limiting if configured if rule.rate_limit > 0 { let rule_key = RateLimitKey::new(chain_type, rule.priority); cache_entry.rate_limit_keys.push(rule_key.clone()); cache_entry.rule_stats_vec.push(rule.rule_stats.clone()); if !self.check_rate_limit(&rule_key, rule.rate_limit, rule.burst_limit, true) { // rate limited, drop packet return AclResult { action: Action::Drop, matched_rule: Some(RuleId::Priority(rule.priority)), should_log: false, log_context: Some(AclLogContext::RateLimitDrop), }; } } // Handle stateful connections if configured if rule.stateful && rule.action == Action::Allow { let conn_track_key = self.conn_track_key(packet_info); self.check_connection_state(&conn_track_key, packet_info); cache_entry.rule_stats_vec.push(rule.rule_stats.clone()); cache_entry.matched_rule = RuleId::Stateful(rule.priority); cache_entry.conn_track_key = Some(conn_track_key); cache_entry.acl_result = Some(AclResult { action: Action::Allow, matched_rule: Some(RuleId::Stateful(rule.priority)), should_log: false, log_context: Some(AclLogContext::StatefulMatch { src_ip: packet_info.src_ip, dst_ip: packet_info.dst_ip, }), }); } else { // Rule matched, return action cache_entry.rule_stats_vec.push(rule.rule_stats.clone()); cache_entry.matched_rule = RuleId::Priority(rule.priority); cache_entry.acl_result = Some(AclResult { action: rule.action, matched_rule: Some(RuleId::Priority(rule.priority)), should_log: false, log_context: Some(AclLogContext::RuleMatch { src_ip: packet_info.src_ip, dst_ip: packet_info.dst_ip, action: rule.action, }), }); } // Cache the result with rule info self.increment_stat(AclStatKey::RuleMatches); self.inc_cache_entry_stats(&cache_entry, packet_info); self.cache_result(&cache_key, cache_entry.clone()); return cache_entry.acl_result.clone().unwrap(); } let default_action = match chain_type { ChainType::Inbound => self.default_inbound_action, ChainType::Outbound => self.default_outbound_action, ChainType::Forward => self.default_forward_action, _ => Action::Allow, }; // No rule matched, return default drop if default_action == Action::Drop { self.increment_stat(AclStatKey::DefaultDrops); } else { self.increment_stat(AclStatKey::DefaultAllows); } let log_context = if default_action == Action::Drop { AclLogContext::DefaultDrop } else { AclLogContext::DefaultAllow }; cache_entry .rule_stats_vec .push(self.default_rule_stats.clone()); cache_entry.matched_rule = RuleId::Default; cache_entry.acl_result = Some(AclResult { action: default_action, matched_rule: Some(RuleId::Default), should_log: false, log_context: Some(log_context), }); // Cache the default result (no rule info) self.inc_cache_entry_stats(&cache_entry, packet_info); self.cache_result(&cache_key, cache_entry.clone()); cache_entry.acl_result.clone().unwrap() } /// Get shared state for preserving across hot reloads pub fn get_shared_state(&self) -> SharedState { ( self.conn_track.clone(), self.rate_limiters.clone(), self.stats.clone(), ) } /// Cache an ACL result fn cache_result(&self, cache_key: &AclCacheKey, cache_entry: AclCacheEntry) { self.rule_cache.insert(cache_key.clone(), cache_entry); // Trigger cleanup if cache is getting too large if self.rule_cache.len() > self.cache_max_size * 2 { let cache = self.rule_cache.clone(); let max_size = self.cache_max_size; Self::cleanup_cache(&cache, max_size); } } /// Check if a rule matches the packet fn rule_matches(&self, rule: &FastLookupRule, packet_info: &PacketInfo) -> bool { // Protocol check if rule.protocol != Protocol::Any && rule.protocol as i32 != packet_info.protocol as i32 { return false; } // Source IP check if !rule.src_ip_ranges.is_empty() { let matches = rule .src_ip_ranges .iter() .any(|cidr| match (cidr, packet_info.src_ip) { (cidr::IpCidr::V4(v4_cidr), IpAddr::V4(v4_addr)) => v4_cidr.contains(&v4_addr), (cidr::IpCidr::V6(v6_cidr), IpAddr::V6(v6_addr)) => v6_cidr.contains(&v6_addr), _ => false, }); if !matches { return false; } } // Destination IP check if !rule.dst_ip_ranges.is_empty() { let matches = rule .dst_ip_ranges .iter() .any(|cidr| match (cidr, packet_info.dst_ip) { (cidr::IpCidr::V4(v4_cidr), IpAddr::V4(v4_addr)) => v4_cidr.contains(&v4_addr), (cidr::IpCidr::V6(v6_cidr), IpAddr::V6(v6_addr)) => v6_cidr.contains(&v6_addr), _ => false, }); if !matches { return false; } } // Source port check if let Some(src_port) = packet_info.src_port { if !rule.src_port_ranges.is_empty() { let matches = rule .src_port_ranges .iter() .any(|(start, end)| src_port >= *start && src_port <= *end); if !matches { return false; } } } // Destination port check if let Some(dst_port) = packet_info.dst_port { if !rule.dst_port_ranges.is_empty() { let matches = rule .dst_port_ranges .iter() .any(|(start, end)| dst_port >= *start && dst_port <= *end); if !matches { return false; } } } // Source group check if !rule.source_groups.is_empty() { let matches = packet_info .src_groups .iter() .any(|group| rule.source_groups.contains(group)); if !matches { return false; } } // Destination group check if !rule.destination_groups.is_empty() { let matches = packet_info .dst_groups .iter() .any(|group| rule.destination_groups.contains(group)); if !matches { return false; } } true } fn conn_track_key(&self, packet_info: &PacketInfo) -> String { format!( "{}:{}->{}:{}", packet_info.src_ip, packet_info.src_port.unwrap_or(0), packet_info.dst_ip, packet_info.dst_port.unwrap_or(0) ) } /// Check connection state for stateful rules fn check_connection_state(&self, conn_track_key: &str, packet_info: &PacketInfo) { self.conn_track .entry(conn_track_key.to_string()) .and_modify(|x| { x.last_seen = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); x.packet_count += 1; x.byte_count += packet_info.packet_size as u64; x.state = ConnState::Established as i32; }) .or_insert_with(|| ConnTrackEntry { src_addr: Some( SocketAddr::new(packet_info.src_ip, packet_info.src_port.unwrap_or(0)).into(), ), dst_addr: Some( SocketAddr::new(packet_info.dst_ip, packet_info.dst_port.unwrap_or(0)).into(), ), protocol: packet_info.protocol as i32, state: ConnState::New as i32, created_at: SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(), last_seen: SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(), packet_count: 1, byte_count: packet_info.packet_size as u64, }); } /// Check rate limiting for a rule fn check_rate_limit( &self, rule_key: &RateLimitKey, rate: u32, burst: u32, allow_create: bool, ) -> bool { if rate == 0 { return true; // No rate limiting } let mut rate_limiter = self .rate_limiters .entry(rule_key.clone()) .or_insert_with(|| { if !allow_create { panic!("Rate limit bucket not found"); } RateLimitValue { token_bucket: TokenBucket::new( burst as u64, rate as u64, Duration::from_millis(10), ), last_update: Instant::now(), } }); // Try to consume 1 token (1 packet) rate_limiter.last_update = Instant::now(); rate_limiter.token_bucket.try_consume(1) } /// Convert proto Rule to FastLookupRule fn convert_to_fast_lookup_rule(rule: &Rule) -> FastLookupRule { let src_ip_ranges = rule .source_ips .iter() .filter_map(|x| Self::convert_ip_inet_to_cidr(x.as_str())) .collect(); let dst_ip_ranges = rule .destination_ips .iter() .filter_map(|x| Self::convert_ip_inet_to_cidr(x.as_str())) .collect(); let src_port_ranges = rule .source_ports .iter() .filter_map(|port_range| { if let Some((start, end)) = parse_port_range(port_range) { Some((start, end)) } else { None } }) .collect(); let dst_port_ranges = rule .ports .iter() .filter_map(|port_range| { if let Some((start, end)) = parse_port_range(port_range) { Some((start, end)) } else { None } }) .collect(); FastLookupRule { priority: rule.priority, protocol: rule.protocol(), src_ip_ranges, dst_ip_ranges, src_port_ranges, dst_port_ranges, source_groups: rule.source_groups.iter().cloned().collect(), destination_groups: rule.destination_groups.iter().cloned().collect(), action: rule.action(), enabled: rule.enabled, stateful: rule.stateful, rate_limit: rule.rate_limit, burst_limit: rule.burst_limit, rule_stats: Arc::new(RuleStats { rule: Some(rule.clone()), stat: Some(StatItem { packet_count: 0, byte_count: 0, }), }), } } /// Convert IpInet to CIDR for fast lookup fn convert_ip_inet_to_cidr(input: &str) -> Option { cidr::IpCidr::from_str(input).ok() } /// Increment statistics counter pub fn increment_stat(&self, key: AclStatKey) { self.stats .entry(key) .and_modify(|counter| *counter += 1) .or_insert(1); } /// Get statistics pub fn get_stats(&self) -> HashMap { let mut stats = self .stats .iter() .map(|entry| (entry.key().as_str(), *entry.value())) .collect::>(); // Add cache statistics using enum keys stats.insert(AclStatKey::CacheSize.as_str(), self.rule_cache.len() as u64); stats.insert( AclStatKey::CacheMaxSize.as_str(), self.cache_max_size as u64, ); stats } /// Clean up expired connection tracking entries pub fn cleanup_expired_connections( conn_track: Arc>, timeout_secs: u64, ) { let current_time = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); let keys_to_remove: Vec = conn_track .iter() .filter_map(|entry| { if current_time - entry.last_seen > timeout_secs { Some(entry.key().clone()) } else { None } }) .collect(); for key in keys_to_remove { conn_track.remove(&key); } } /// Get cache hit rate pub fn get_cache_hit_rate(&self) -> f64 { let cache_hits = self .stats .get(&AclStatKey::CacheHits) .map(|v| *v.value()) .unwrap_or(0); let total_requests = cache_hits + self .stats .get(&AclStatKey::RuleMatches) .map(|v| *v.value()) .unwrap_or(0); if total_requests == 0 { 0.0 } else { cache_hits as f64 / total_requests as f64 } } } // 新增辅助函数 fn parse_port_start(port_strs: &[String]) -> Option { port_strs .iter() .filter_map(|s| parse_port_range(s).map(|(start, _)| start)) .min() } fn parse_port_end(port_strs: &[String]) -> Option { port_strs .iter() .filter_map(|s| parse_port_range(s).map(|(_, end)| end)) .max() } fn parse_port_range(s: &str) -> Option<(u16, u16)> { if let Some((start, end)) = s.split_once('-') { let start = start.trim().parse().ok()?; let end = end.trim().parse().ok()?; Some((start, end)) } else { let port = s.trim().parse().ok()?; Some((port, port)) } } // Statistics key enum for better performance #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum AclStatKey { // Cache statistics CacheHits, CacheSize, CacheMaxSize, RuleMatches, DefaultAllows, DefaultDrops, // Global packet statistics PacketsTotal, PacketsAllowed, PacketsDropped, PacketsNoop, // Per-chain statistics InboundPacketsTotal, InboundPacketsAllowed, InboundPacketsDropped, InboundPacketsNoop, OutboundPacketsTotal, OutboundPacketsAllowed, OutboundPacketsDropped, OutboundPacketsNoop, ForwardPacketsTotal, ForwardPacketsAllowed, ForwardPacketsDropped, ForwardPacketsNoop, UnknownPacketsTotal, UnknownPacketsAllowed, UnknownPacketsDropped, UnknownPacketsNoop, } impl AclStatKey { pub fn as_str(&self) -> String { format!("{:?}", self) } pub fn from_chain_and_action(chain_type: ChainType, stat_type: AclStatType) -> Self { match (chain_type, stat_type) { (ChainType::Inbound, AclStatType::Total) => AclStatKey::InboundPacketsTotal, (ChainType::Inbound, AclStatType::Allowed) => AclStatKey::InboundPacketsAllowed, (ChainType::Inbound, AclStatType::Dropped) => AclStatKey::InboundPacketsDropped, (ChainType::Inbound, AclStatType::Noop) => AclStatKey::InboundPacketsNoop, (ChainType::Outbound, AclStatType::Total) => AclStatKey::OutboundPacketsTotal, (ChainType::Outbound, AclStatType::Allowed) => AclStatKey::OutboundPacketsAllowed, (ChainType::Outbound, AclStatType::Dropped) => AclStatKey::OutboundPacketsDropped, (ChainType::Outbound, AclStatType::Noop) => AclStatKey::OutboundPacketsNoop, (ChainType::Forward, AclStatType::Total) => AclStatKey::ForwardPacketsTotal, (ChainType::Forward, AclStatType::Allowed) => AclStatKey::ForwardPacketsAllowed, (ChainType::Forward, AclStatType::Dropped) => AclStatKey::ForwardPacketsDropped, (ChainType::Forward, AclStatType::Noop) => AclStatKey::ForwardPacketsNoop, (_, AclStatType::Total) => AclStatKey::UnknownPacketsTotal, (_, AclStatType::Allowed) => AclStatKey::UnknownPacketsAllowed, (_, AclStatType::Dropped) => AclStatKey::UnknownPacketsDropped, (_, AclStatType::Noop) => AclStatKey::UnknownPacketsNoop, } } } pub struct AclRuleBuilder { pub acl: Option, pub tcp_whitelist: Vec, pub udp_whitelist: Vec, pub whitelist_priority: Option, } impl AclRuleBuilder { fn parse_port_list(port_list: &[String]) -> anyhow::Result> { let mut ports = Vec::new(); for port_spec in port_list { if port_spec.contains('-') { // Handle port range like "8000-9000" let parts: Vec<&str> = port_spec.split('-').collect(); if parts.len() != 2 { return Err(anyhow::anyhow!("Invalid port range format: {}", port_spec)); } let start: u16 = parts[0] .parse() .with_context(|| format!("Invalid start port in range: {}", port_spec))?; let end: u16 = parts[1] .parse() .with_context(|| format!("Invalid end port in range: {}", port_spec))?; if start > end { return Err(anyhow::anyhow!( "Start port must be <= end port in range: {}", port_spec )); } // acl can handle port range ports.push(port_spec.clone()); } else { // Handle single port let port: u16 = port_spec .parse() .with_context(|| format!("Invalid port number: {}", port_spec))?; ports.push(port.to_string()); } } Ok(ports) } fn generate_acl_from_whitelists(&mut self) -> anyhow::Result<()> { if self.tcp_whitelist.is_empty() && self.udp_whitelist.is_empty() { return Ok(()); } // Create inbound chain for whitelist rules let mut inbound_chain = Chain { name: "inbound_whitelist".to_string(), chain_type: ChainType::Inbound as i32, description: "Auto-generated inbound whitelist from CLI".to_string(), enabled: true, rules: vec![], default_action: Action::Allow as i32, }; let mut rule_priority = self.whitelist_priority.unwrap_or(1000u32); // Add TCP whitelist rules if !self.tcp_whitelist.is_empty() { let tcp_ports = Self::parse_port_list(&self.tcp_whitelist)?; let tcp_rule = Rule { name: "tcp_whitelist".to_string(), description: "Auto-generated TCP whitelist rule".to_string(), priority: rule_priority, enabled: true, protocol: Protocol::Tcp as i32, ports: tcp_ports, source_ips: vec![], destination_ips: vec![], source_ports: vec![], action: Action::Allow as i32, rate_limit: 0, burst_limit: 0, stateful: true, source_groups: vec![], destination_groups: vec![], }; let tcp_rule_deny_other = Rule { name: "tcp_whitelist_deny_other".to_string(), description: "Auto-generated TCP whitelist rule to deny other ports".to_string(), priority: 0, enabled: true, protocol: Protocol::Tcp as i32, ports: vec!["0-65535".to_string()], source_ips: vec![], destination_ips: vec![], source_ports: vec![], action: Action::Drop as i32, rate_limit: 0, burst_limit: 0, stateful: false, source_groups: vec![], destination_groups: vec![], }; inbound_chain.rules.push(tcp_rule); inbound_chain.rules.push(tcp_rule_deny_other); rule_priority -= 1; } // Add UDP whitelist rules if !self.udp_whitelist.is_empty() { let udp_ports = Self::parse_port_list(&self.udp_whitelist)?; let udp_rule = Rule { name: "udp_whitelist".to_string(), description: "Auto-generated UDP whitelist rule".to_string(), priority: rule_priority, enabled: true, protocol: Protocol::Udp as i32, ports: udp_ports, source_ips: vec![], destination_ips: vec![], source_ports: vec![], action: Action::Allow as i32, rate_limit: 0, burst_limit: 0, stateful: false, source_groups: vec![], destination_groups: vec![], }; let udp_rule_deny_other = Rule { name: "udp_whitelist_deny_other".to_string(), description: "Auto-generated UDP whitelist rule to deny other ports".to_string(), priority: 0, enabled: true, protocol: Protocol::Udp as i32, ports: vec!["0-65535".to_string()], source_ips: vec![], destination_ips: vec![], source_ports: vec![], action: Action::Drop as i32, rate_limit: 0, burst_limit: 0, stateful: false, source_groups: vec![], destination_groups: vec![], }; inbound_chain.rules.push(udp_rule); inbound_chain.rules.push(udp_rule_deny_other); } if self.acl.is_none() { self.acl = Some(Acl::default()); } let acl = self.acl.as_mut().unwrap(); if let Some(ref mut acl_v1) = acl.acl_v1 { acl_v1.chains.push(inbound_chain); } else { acl.acl_v1 = Some(AclV1 { chains: vec![inbound_chain], group: Some(GroupInfo { declares: vec![], members: vec![], }), }); } Ok(()) } fn do_build(mut self) -> anyhow::Result> { self.generate_acl_from_whitelists()?; Ok(self.acl.clone()) } pub fn build(global_ctx: &ArcGlobalCtx) -> anyhow::Result> { let builder = AclRuleBuilder { acl: global_ctx.config.get_acl(), tcp_whitelist: global_ctx.config.get_tcp_whitelist(), udp_whitelist: global_ctx.config.get_udp_whitelist(), whitelist_priority: None, }; builder.do_build() } } #[derive(Debug, Clone, Copy)] pub enum AclStatType { Total, Allowed, Dropped, Noop, } #[cfg(test)] mod tests { use super::*; use std::hash::{Hash, Hasher}; use std::net::{IpAddr, Ipv4Addr}; #[tokio::test] async fn test_group_based_acl_rules() { let mut acl_config = Acl::default(); let mut acl_v1 = AclV1::default(); let mut chain = Chain { name: "group_test_chain".to_string(), chain_type: ChainType::Inbound as i32, enabled: true, default_action: Action::Drop as i32, ..Default::default() }; // Rules chain.rules.push(Rule { name: "allow_admins_to_db".to_string(), priority: 100, enabled: true, action: Action::Allow as i32, protocol: Protocol::Any as i32, source_groups: vec!["admin".to_string()], destination_groups: vec!["db-server".to_string()], ..Default::default() }); chain.rules.push(Rule { name: "allow_devs_from_anywhere".to_string(), priority: 90, enabled: true, action: Action::Allow as i32, protocol: Protocol::Any as i32, source_groups: vec!["dev".to_string()], ..Default::default() }); chain.rules.push(Rule { name: "deny_guests_to_db".to_string(), priority: 80, enabled: true, action: Action::Drop as i32, protocol: Protocol::Any as i32, source_groups: vec!["guest".to_string()], destination_groups: vec!["db-server".to_string()], ..Default::default() }); chain.rules.push(Rule { name: "allow_specific_ip".to_string(), priority: 70, enabled: true, action: Action::Allow as i32, protocol: Protocol::Any as i32, source_ips: vec!["1.2.3.4/32".to_string()], ..Default::default() }); acl_v1.chains.push(chain); acl_config.acl_v1 = Some(acl_v1); let processor = AclProcessor::new(acl_config); // Case 3.1: Source group match (devs from anywhere) let mut packet_info = create_test_packet_info(); packet_info.src_groups = Arc::new(vec!["dev".to_string()]); let result = processor.process_packet(&packet_info, ChainType::Inbound); assert_eq!(result.action, Action::Allow); assert_eq!(result.matched_rule, Some(RuleId::Priority(90))); // Case 3.2: Source group no match packet_info.src_groups = Arc::new(vec!["guest".to_string()]); let result = processor.process_packet(&packet_info, ChainType::Inbound); assert_eq!(result.action, Action::Drop); // Default drop assert_eq!(result.matched_rule, Some(RuleId::Default)); // Case 3.3: Destination group match (deny guests to db) packet_info.src_groups = Arc::new(vec!["guest".to_string()]); packet_info.dst_groups = Arc::new(vec!["db-server".to_string()]); let result = processor.process_packet(&packet_info, ChainType::Inbound); assert_eq!(result.action, Action::Drop); assert_eq!(result.matched_rule, Some(RuleId::Priority(80))); // Case 3.4: Source and Destination groups match packet_info.src_groups = Arc::new(vec!["admin".to_string()]); packet_info.dst_groups = Arc::new(vec!["db-server".to_string()]); let result = processor.process_packet(&packet_info, ChainType::Inbound); assert_eq!(result.action, Action::Allow); assert_eq!(result.matched_rule, Some(RuleId::Priority(100))); // Case 3.5: Partial match (admin to web-server) packet_info.src_groups = Arc::new(vec!["admin".to_string()]); packet_info.dst_groups = Arc::new(vec!["web-server".to_string()]); let result = processor.process_packet(&packet_info, ChainType::Inbound); assert_eq!(result.action, Action::Drop); // Default drop assert_eq!(result.matched_rule, Some(RuleId::Default)); // Case 3.6: Rule with no group definition packet_info.src_ip = "1.2.3.4".parse().unwrap(); packet_info.src_groups = Arc::new(vec!["admin".to_string()]); packet_info.dst_groups = Arc::new(vec![]); let result = processor.process_packet(&packet_info, ChainType::Inbound); assert_eq!(result.action, Action::Allow); assert_eq!(result.matched_rule, Some(RuleId::Priority(70))); } fn create_test_acl_config() -> Acl { let mut acl_config = Acl::default(); let mut acl_v1 = AclV1::default(); // Create inbound chain let mut chain = Chain { name: "test_inbound".to_string(), chain_type: ChainType::Inbound as i32, enabled: true, ..Default::default() }; // Allow all rule let rule = Rule { name: "allow_all".to_string(), priority: 100, enabled: true, action: Action::Allow as i32, protocol: Protocol::Any as i32, ..Default::default() }; chain.rules.push(rule); acl_v1.chains.push(chain); acl_config.acl_v1 = Some(acl_v1); acl_config } fn create_test_packet_info() -> PacketInfo { PacketInfo { src_ip: IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), dst_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), src_port: Some(12345), dst_port: Some(80), protocol: Protocol::Tcp, packet_size: 1024, src_groups: Arc::new(vec![]), dst_groups: Arc::new(vec![]), } } #[test] fn test_acl_cache_key_creation() { let packet_info = create_test_packet_info(); let cache_key = AclCacheKey::from_packet_info(&packet_info, ChainType::Inbound); assert_eq!(cache_key.chain_type, ChainType::Inbound); assert_eq!(cache_key.protocol, Protocol::Tcp); assert_eq!( cache_key.src_ip, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)) ); assert_eq!(cache_key.dst_ip, IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))); assert_eq!(cache_key.src_port, 12345); assert_eq!(cache_key.dst_port, 80); } #[test] fn test_acl_cache_key_equality() { let packet_info1 = create_test_packet_info(); let packet_info2 = create_test_packet_info(); let key1 = AclCacheKey::from_packet_info(&packet_info1, ChainType::Inbound); let key2 = AclCacheKey::from_packet_info(&packet_info2, ChainType::Inbound); assert_eq!(key1, key2); // Test hash consistency use std::collections::hash_map::DefaultHasher; let mut hasher1 = DefaultHasher::new(); let mut hasher2 = DefaultHasher::new(); key1.hash(&mut hasher1); key2.hash(&mut hasher2); assert_eq!(hasher1.finish(), hasher2.finish()); } #[tokio::test] async fn test_acl_processor_basic_functionality() { let acl_config = create_test_acl_config(); let processor = AclProcessor::new(acl_config); let packet_info = create_test_packet_info(); let result = processor.process_packet(&packet_info, ChainType::Inbound); assert_eq!(result.action, Action::Allow); assert!(result.matched_rule.is_some()); } #[tokio::test] async fn test_acl_cache_hit() { let acl_config = create_test_acl_config(); let processor = AclProcessor::new(acl_config); let packet_info = create_test_packet_info(); // First request - should be a cache miss let result1 = processor.process_packet(&packet_info, ChainType::Inbound); // Second request - should be a cache hit let result2 = processor.process_packet(&packet_info, ChainType::Inbound); assert_eq!(result1.action, result2.action); assert_eq!(result1.matched_rule, result2.matched_rule); // Check cache statistics let stats = processor.get_stats(); assert_eq!(stats.get(&AclStatKey::CacheHits.as_str()).unwrap_or(&0), &1); assert!(processor.get_cache_hit_rate() > 0.0); } #[tokio::test] async fn test_lock_free_hot_reload_demo() { println!("\n=== ACL 优化演示:无锁热加载 ==="); // 创建初始配置 let initial_config = create_test_acl_config(); let processor = AclProcessor::new(initial_config); let packet_info = create_test_packet_info(); // 处理一些数据包 println!("1. 处理初始数据包..."); let result1 = processor.process_packet(&packet_info, ChainType::Inbound); assert_eq!(result1.action, Action::Allow); println!(" ✓ 数据包被允许通过"); // 获取共享状态 let (conn_track, rate_limiters, stats) = processor.get_shared_state(); println!("2. 保存连接跟踪和统计状态..."); println!(" ✓ 连接数: {}", conn_track.len()); println!(" ✓ 限流器数量: {}", rate_limiters.len()); println!(" ✓ 统计计数器数量: {}", stats.len()); // 创建新配置(模拟热加载) let mut new_config = create_test_acl_config(); if let Some(ref mut acl_v1) = new_config.acl_v1 { let drop_rule = Rule { name: "drop_all".to_string(), priority: 200, enabled: true, action: Action::Drop as i32, protocol: Protocol::Any as i32, ..Default::default() }; acl_v1.chains[0].rules.push(drop_rule); } // 创建新的处理器实例(热加载) println!("3. 执行热加载(创建新的处理器实例)..."); let new_processor = AclProcessor::new_with_shared_state( new_config, Some(conn_track.clone()), Some(rate_limiters.clone()), Some(stats.clone()), ); // 验证新处理器的行为 let result2 = new_processor.process_packet(&packet_info, ChainType::Inbound); assert_eq!(result2.action, Action::Drop); // 新规则应该拒绝 println!(" ✓ 新规则生效:数据包被拒绝"); // 验证状态被保留 let (new_conn_track, new_rate_limiters, new_stats) = new_processor.get_shared_state(); assert!(Arc::ptr_eq(&conn_track, &new_conn_track)); assert!(Arc::ptr_eq(&rate_limiters, &new_rate_limiters)); assert!(Arc::ptr_eq(&stats, &new_stats)); println!(" ✓ 连接状态和统计信息被完整保留"); println!("\n=== 性能优化效果 ==="); println!("✓ 无锁访问:处理器内部不再有任何锁"); println!("✓ 零拷贝:规则访问直接引用,无需克隆Arc"); println!("✓ 热加载:创建新实例替换,保留所有状态"); println!("✓ 内存效率:消除了多层Arc包装的开销"); } #[tokio::test] async fn test_performance_and_security_balance() { // Create ACL config with different rule types let mut acl_config = Acl::default(); let mut acl_v1 = AclV1::default(); let mut chain = Chain { name: "performance_test".to_string(), chain_type: ChainType::Inbound as i32, enabled: true, ..Default::default() }; // 1. High-priority simple rule for UDP (can be cached efficiently) let simple_rule = Rule { name: "simple_udp".to_string(), priority: 300, enabled: true, action: Action::Allow as i32, protocol: Protocol::Udp as i32, ..Default::default() }; // No stateful or rate limit - can benefit from full cache optimization chain.rules.push(simple_rule); // 2. Medium-priority stateful + rate-limited rule for TCP (security critical) let security_rule = Rule { name: "security_tcp".to_string(), priority: 200, enabled: true, action: Action::Allow as i32, protocol: Protocol::Tcp as i32, stateful: true, rate_limit: 100, burst_limit: 200, ..Default::default() }; chain.rules.push(security_rule); // 3. Low-priority default allow rule for Any let default_rule = Rule { name: "default_allow".to_string(), priority: 100, enabled: true, action: Action::Allow as i32, protocol: Protocol::Any as i32, ..Default::default() }; chain.rules.push(default_rule); acl_v1.chains.push(chain); acl_config.acl_v1 = Some(acl_v1); let processor = AclProcessor::new(acl_config); // Test simple UDP packet (should hit high-priority simple rule and be cached) let udp_packet = PacketInfo { src_ip: IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), dst_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), src_port: Some(12345), dst_port: Some(53), // DNS protocol: Protocol::Udp, // UDP packet_size: 512, src_groups: Arc::new(vec![]), dst_groups: Arc::new(vec![]), }; // Test TCP packet (should hit stateful+rate-limited rule) let tcp_packet = PacketInfo { src_ip: IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), dst_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), src_port: Some(12345), dst_port: Some(80), // HTTP protocol: Protocol::Tcp, // TCP packet_size: 1024, src_groups: Arc::new(vec![]), dst_groups: Arc::new(vec![]), }; // Process UDP packets multiple times println!("\n=== Performance Test Results ==="); for i in 1..=5 { let result = processor.process_packet(&udp_packet, ChainType::Inbound); assert_eq!(result.action, Action::Allow); // UDP packets should match the highest priority rule that applies // Since all rules allow "Any" protocol, UDP will match the highest priority one println!( "UDP packet {}: Allowed by rule (priority {:?})", i, result.matched_rule ); } // Process TCP packets multiple times (stateful + rate limited) for i in 1..=3 { let result = processor.process_packet(&tcp_packet, ChainType::Inbound); println!( "TCP packet {}: {:?} by rule (priority {:?})", i, result.action, result.matched_rule ); } let stats = processor.get_stats(); println!("\nStatistics:"); println!( " Cache hits: {}", stats.get(&AclStatKey::CacheHits.as_str()).unwrap_or(&0) ); println!( " Rule matches: {}", stats.get(&AclStatKey::RuleMatches.as_str()).unwrap_or(&0) ); println!( " Cache hit rate: {:.1}%", processor.get_cache_hit_rate() * 100.0 ); println!("\n✓ Stateful + rate-limited rules: Always processed for security"); println!("✓ Simple rules: Cached for performance"); println!( "✓ Cache hit rate: {:.1}%", processor.get_cache_hit_rate() * 100.0 ); } #[test] fn test_rate_limit_drop_log_context() { // Test that RateLimitDrop log context is properly created let context = AclLogContext::RateLimitDrop; let message = context.to_message(); assert_eq!(message, "Rate limit drop"); } #[tokio::test] async fn test_rate_limit_drop_behavior() { let mut acl_config = create_test_acl_config(); // Create a very restrictive rate-limited rule if let Some(ref mut acl_v1) = acl_config.acl_v1 { let rule = Rule { name: "strict_rate_limit".to_string(), priority: 200, enabled: true, action: Action::Allow as i32, protocol: Protocol::Any as i32, rate_limit: 1, // Allow only 1 packet per second burst_limit: 1, // Burst of 1 packet ..Default::default() }; acl_v1.chains[0].rules.push(rule); } let processor = AclProcessor::new(acl_config); let packet_info = create_test_packet_info(); // First request should be allowed let result1 = processor.process_packet(&packet_info, ChainType::Inbound); assert_eq!(result1.action, Action::Allow); assert_eq!(result1.matched_rule, Some(RuleId::Priority(200))); // Second request should be rate limited and dropped immediately let result2 = processor.process_packet(&packet_info, ChainType::Inbound); assert_eq!(result2.action, Action::Drop); assert_eq!(result2.matched_rule, Some(RuleId::Priority(200))); assert!(!result2.should_log); // Verify the specific log context assert!(matches!( result2.log_context, Some(AclLogContext::RateLimitDrop) )); } }