diff --git a/easytier/src/peers/acl_filter.rs b/easytier/src/peers/acl_filter.rs index 180865a7..2ae3c8f5 100644 --- a/easytier/src/peers/acl_filter.rs +++ b/easytier/src/peers/acl_filter.rs @@ -1,16 +1,19 @@ use std::net::{Ipv4Addr, Ipv6Addr}; use std::sync::atomic::{AtomicU16, Ordering}; +use std::time::Instant; use std::{ net::IpAddr, sync::{atomic::AtomicBool, Arc}, }; use arc_swap::ArcSwap; +use dashmap::DashMap; use pnet::packet::ipv6::Ipv6Packet; use pnet::packet::{ ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, udp::UdpPacket, Packet as _, }; +use crate::common::scoped_task::ScopedTask; use crate::proto::acl::{AclStats, Protocol}; use crate::tunnel::packet_def::PacketType; use crate::{ @@ -19,6 +22,37 @@ use crate::{ tunnel::packet_def::ZCPacket, }; +#[derive(Debug, Eq, PartialEq, Hash)] +struct OutboundAllowRecord { + src_ip: IpAddr, + dst_ip: IpAddr, + src_port: Option, + dst_port: Option, + protocol: Protocol, +} + +impl OutboundAllowRecord { + fn new_from_inbound_packet(p: &PacketInfo) -> Self { + Self { + src_ip: p.src_ip, + dst_ip: p.dst_ip, + src_port: p.src_port, + dst_port: p.dst_port, + protocol: p.protocol, + } + } + + fn new_from_outbound_packet(p: &PacketInfo) -> Self { + Self { + src_ip: p.dst_ip, + dst_ip: p.src_ip, + src_port: p.dst_port, + dst_port: p.src_port, + protocol: p.protocol, + } + } +} + /// ACL filter that can be inserted into the packet processing pipeline /// Optimized with lock-free hot reloading via atomic processor replacement pub struct AclFilter { @@ -26,6 +60,11 @@ pub struct AclFilter { acl_processor: ArcSwap, acl_enabled: Arc, quic_udp_port: AtomicU16, + + // Track allowed outbound packets and automatically allow their corresponding inbound response + // packets, even if they would normally be dropped by ACL rules + outbound_allow_records: Arc>, + clean_task: ScopedTask<()>, } impl Default for AclFilter { @@ -36,10 +75,21 @@ impl Default for AclFilter { impl AclFilter { pub fn new() -> Self { + let outbound_allow_records = Arc::new(DashMap::new()); + let record_clone = outbound_allow_records.clone(); Self { acl_processor: ArcSwap::from(Arc::new(AclProcessor::new(Acl::default()))), acl_enabled: Arc::new(AtomicBool::new(false)), quic_udp_port: AtomicU16::new(0), + outbound_allow_records, + clean_task: tokio::spawn(async move { + let max_life = std::time::Duration::from_secs(30); + loop { + record_clone.retain(|_, v| v.elapsed() < max_life); + tokio::time::sleep(std::time::Duration::from_secs(30)).await; + } + }) + .into(), } } @@ -336,8 +386,32 @@ impl AclFilter { // Check if packet should be allowed match acl_result.action { - Action::Allow | Action::Noop => true, + Action::Allow | Action::Noop => { + if matches!(chain_type, ChainType::Outbound) { + self.outbound_allow_records.insert( + OutboundAllowRecord::new_from_outbound_packet(&packet_info), + Instant::now(), + ); + } + true + } Action::Drop => { + if is_in { + let record = OutboundAllowRecord::new_from_inbound_packet(&packet_info); + let entry = self.outbound_allow_records.entry(record); + if let dashmap::Entry::Occupied(mut entry) = entry { + entry.insert(Instant::now()); + tracing::trace!( + "ACL: Allowing {:?} packet from {} to {} because of existing allow record, chain_type: {:?}", + packet_info.protocol, + packet_info.src_ip, + packet_info.dst_ip, + chain_type, + ); + return true; + } + } + tracing::trace!( "ACL: Dropping {:?} packet from {} to {}, chain_type: {:?}", packet_info.protocol, diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index 9730b1f1..bb419dcd 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -2468,12 +2468,21 @@ pub async fn acl_group_self_test( #[rstest::rstest] #[tokio::test] #[serial_test::serial] -pub async fn whitelist_test(#[values("tcp", "udp")] protocol: &str) { +pub async fn whitelist_test( + #[values("tcp", "udp")] protocol: &str, + #[values(true, false)] test_outbound_allow_list: bool, +) { let port = 44553; + let acl_configured_inst = if test_outbound_allow_list { + "inst1" + } else { + "inst3" + }; let insts = init_three_node_ex( protocol, move |cfg| { - if cfg.get_inst_name() == "inst3" { + let port = if test_outbound_allow_list { 0 } else { port }; + if cfg.get_inst_name() == acl_configured_inst { if protocol == "tcp" { cfg.set_tcp_whitelist(vec![format!("{}", port)]); } else if protocol == "udp" { @@ -2536,6 +2545,10 @@ pub async fn whitelist_test(#[values("tcp", "udp")] protocol: &str) { .unwrap_or_else(|_| panic!("{} should be allowed", p)); } + if test_outbound_allow_list { + return; + } + // test other port let other_port = port + 1; for p in ["tcp", "udp"] {