diff --git a/easytier/src/common/acl_processor.rs b/easytier/src/common/acl_processor.rs index ce4904d2..a789e5e9 100644 --- a/easytier/src/common/acl_processor.rs +++ b/easytier/src/common/acl_processor.rs @@ -39,6 +39,7 @@ pub struct RateLimitValue { pub enum RuleId { Priority(u32), Stateful(u32), + StatefulReverse, Default, } @@ -48,6 +49,7 @@ impl RuleId { match self { RuleId::Priority(p) => p.to_string(), RuleId::Stateful(p) => format!("stateful-{}", p), + RuleId::StatefulReverse => "stateful-reverse".to_string(), RuleId::Default => "default".to_string(), } } @@ -482,24 +484,30 @@ impl AclProcessor { stats } - /// Process a packet through ACL rules - Now lock-free! + /// Process a packet through ACL rules. pub fn process_packet(&self, packet_info: &PacketInfo, chain_type: ChainType) -> AclResult { - if let Some(result) = self.check_reverse_connection(packet_info) { - return result; - } - // Check cache first for performance let cache_key = AclCacheKey::from_packet_info(packet_info, chain_type); - // If cache hit and can skip checks, return cached result + // If cache hit and can skip checks, return cached result. Cached drops may be + // overridden by a stateful reverse connection that was created after caching. if let Some(mut cached) = self.rule_cache.get_mut(&cache_key) { // Update last access time for LRU cached.last_access = Instant::now(); self.increment_stat(AclStatKey::CacheHits); + if cached.acl_result.as_ref().map(|r| r.action) == Some(Action::Drop) + && let Some(result) = self.check_reverse_connection(packet_info) + { + return result; + } return self.process_packet_with_cache_entry(packet_info, &cached); } + if let Some(result) = self.check_reverse_connection(packet_info) { + return result; + } + // Direct access to rules - no locks needed! let rules = match chain_type { ChainType::Inbound => &self.inbound_rules, @@ -734,22 +742,35 @@ impl AclProcessor { } fn conn_track_key(&self, packet_info: &PacketInfo) -> String { - format!( - "{}:{}->{}:{}", + Self::make_conn_track_key( packet_info.src_ip, - packet_info.src_port.unwrap_or(0), + packet_info.src_port, packet_info.dst_ip, - packet_info.dst_port.unwrap_or(0) + packet_info.dst_port, ) } fn reverse_conn_track_key(&self, packet_info: &PacketInfo) -> String { + Self::make_conn_track_key( + packet_info.dst_ip, + packet_info.dst_port, + packet_info.src_ip, + packet_info.src_port, + ) + } + + fn make_conn_track_key( + src_ip: IpAddr, + src_port: Option, + dst_ip: IpAddr, + dst_port: Option, + ) -> String { format!( "{}:{}->{}:{}", - packet_info.dst_ip, - packet_info.dst_port.unwrap_or(0), - packet_info.src_ip, - packet_info.src_port.unwrap_or(0) + src_ip, + src_port.unwrap_or(0), + dst_ip, + dst_port.unwrap_or(0) ) } @@ -759,7 +780,7 @@ impl AclProcessor { Self::update_conn_track_entry(entry.value_mut(), packet_info); Some(AclResult { action: Action::Allow, - matched_rule: Some(RuleId::Default), + matched_rule: Some(RuleId::StatefulReverse), should_log: false, log_context: Some(AclLogContext::StatefulMatch { src_ip: packet_info.src_ip, @@ -1413,8 +1434,11 @@ mod tests { } } - #[tokio::test] - async fn test_stateful_allows_reverse_traffic_before_default_drop() { + #[test] + fn test_stateful_allows_reverse_traffic_before_default_drop() { + let runtime = tokio::runtime::Runtime::new().unwrap(); + let _runtime_guard = runtime.enter(); + let mut acl_config = Acl::default(); let mut acl_v1 = AclV1::default(); @@ -1474,6 +1498,7 @@ mod tests { let inbound_result = processor.process_packet(&inbound_reply, ChainType::Inbound); assert_eq!(inbound_result.action, Action::Allow); + assert_eq!(inbound_result.matched_rule, Some(RuleId::StatefulReverse)); } #[test]