cli for port forward and tcp whitelist (#1165)

This commit is contained in:
Sijie.Sun
2025-07-29 09:30:47 +08:00
committed by GitHub
parent 5514de1187
commit 2ec88da823
8 changed files with 828 additions and 171 deletions
+142 -1
View File
@@ -6,8 +6,9 @@ use std::{
time::{Duration, SystemTime, UNIX_EPOCH}, time::{Duration, SystemTime, UNIX_EPOCH},
}; };
use crate::common::token_bucket::TokenBucket; use crate::common::{config::ConfigLoader, global_ctx::ArcGlobalCtx, token_bucket::TokenBucket};
use crate::proto::acl::*; use crate::proto::acl::*;
use anyhow::Context as _;
use dashmap::DashMap; use dashmap::DashMap;
use tokio::task::JoinSet; use tokio::task::JoinSet;
@@ -993,6 +994,146 @@ impl AclStatKey {
} }
} }
pub struct AclRuleBuilder {
pub acl: Option<Acl>,
pub tcp_whitelist: Vec<String>,
pub udp_whitelist: Vec<String>,
pub whitelist_priority: Option<u32>,
}
impl AclRuleBuilder {
fn parse_port_list(port_list: &[String]) -> anyhow::Result<Vec<String>> {
let mut ports = Vec::new();
for port_spec in port_list {
if port_spec.contains('-') {
// Handle port range like "8000-9000"
let parts: Vec<&str> = port_spec.split('-').collect();
if parts.len() != 2 {
return Err(anyhow::anyhow!("Invalid port range format: {}", port_spec));
}
let start: u16 = parts[0]
.parse()
.with_context(|| format!("Invalid start port in range: {}", port_spec))?;
let end: u16 = parts[1]
.parse()
.with_context(|| format!("Invalid end port in range: {}", port_spec))?;
if start > end {
return Err(anyhow::anyhow!(
"Start port must be <= end port in range: {}",
port_spec
));
}
// acl can handle port range
ports.push(port_spec.clone());
} else {
// Handle single port
let port: u16 = port_spec
.parse()
.with_context(|| format!("Invalid port number: {}", port_spec))?;
ports.push(port.to_string());
}
}
Ok(ports)
}
fn generate_acl_from_whitelists(&mut self) -> anyhow::Result<()> {
if self.tcp_whitelist.is_empty() && self.udp_whitelist.is_empty() {
return Ok(());
}
// Create inbound chain for whitelist rules
let mut inbound_chain = Chain {
name: "inbound_whitelist".to_string(),
chain_type: ChainType::Inbound as i32,
description: "Auto-generated inbound whitelist from CLI".to_string(),
enabled: true,
rules: vec![],
default_action: Action::Drop as i32, // Default deny
};
let mut rule_priority = self.whitelist_priority.unwrap_or(1000u32);
// Add TCP whitelist rules
if !self.tcp_whitelist.is_empty() {
let tcp_ports = Self::parse_port_list(&self.tcp_whitelist)?;
let tcp_rule = Rule {
name: "tcp_whitelist".to_string(),
description: "Auto-generated TCP whitelist rule".to_string(),
priority: rule_priority,
enabled: true,
protocol: Protocol::Tcp as i32,
ports: tcp_ports,
source_ips: vec![],
destination_ips: vec![],
source_ports: vec![],
action: Action::Allow as i32,
rate_limit: 0,
burst_limit: 0,
stateful: true,
};
inbound_chain.rules.push(tcp_rule);
rule_priority -= 1;
}
// Add UDP whitelist rules
if !self.udp_whitelist.is_empty() {
let udp_ports = Self::parse_port_list(&self.udp_whitelist)?;
let udp_rule = Rule {
name: "udp_whitelist".to_string(),
description: "Auto-generated UDP whitelist rule".to_string(),
priority: rule_priority,
enabled: true,
protocol: Protocol::Udp as i32,
ports: udp_ports,
source_ips: vec![],
destination_ips: vec![],
source_ports: vec![],
action: Action::Allow as i32,
rate_limit: 0,
burst_limit: 0,
stateful: false,
};
inbound_chain.rules.push(udp_rule);
}
if self.acl.is_none() {
self.acl = Some(Acl::default());
}
let acl = self.acl.as_mut().unwrap();
if let Some(ref mut acl_v1) = acl.acl_v1 {
acl_v1.chains.push(inbound_chain);
} else {
acl.acl_v1 = Some(AclV1 {
chains: vec![inbound_chain],
});
}
Ok(())
}
fn do_build(mut self) -> anyhow::Result<Option<Acl>> {
self.generate_acl_from_whitelists()?;
Ok(self.acl.clone())
}
pub fn build(global_ctx: &ArcGlobalCtx) -> anyhow::Result<Option<Acl>> {
let builder = AclRuleBuilder {
acl: global_ctx.config.get_acl(),
tcp_whitelist: global_ctx.config.get_tcp_whitelist(),
udp_whitelist: global_ctx.config.get_udp_whitelist(),
whitelist_priority: None,
};
builder.do_build()
}
}
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub enum AclStatType { pub enum AclStatType {
Total, Total,
+36 -1
View File
@@ -122,6 +122,12 @@ pub trait ConfigLoader: Send + Sync {
fn get_acl(&self) -> Option<Acl>; fn get_acl(&self) -> Option<Acl>;
fn set_acl(&self, acl: Option<Acl>); fn set_acl(&self, acl: Option<Acl>);
fn get_tcp_whitelist(&self) -> Vec<String>;
fn set_tcp_whitelist(&self, whitelist: Vec<String>);
fn get_udp_whitelist(&self) -> Vec<String>;
fn set_udp_whitelist(&self, whitelist: Vec<String>);
fn dump(&self) -> String; fn dump(&self) -> String;
} }
@@ -230,7 +236,7 @@ pub struct VpnPortalConfig {
pub wireguard_listen: SocketAddr, pub wireguard_listen: SocketAddr,
} }
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct PortForwardConfig { pub struct PortForwardConfig {
pub bind_addr: SocketAddr, pub bind_addr: SocketAddr,
pub dst_addr: SocketAddr, pub dst_addr: SocketAddr,
@@ -299,6 +305,9 @@ struct Config {
flags_struct: Option<Flags>, flags_struct: Option<Flags>,
acl: Option<Acl>, acl: Option<Acl>,
tcp_whitelist: Option<Vec<String>>,
udp_whitelist: Option<Vec<String>>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -665,6 +674,32 @@ impl ConfigLoader for TomlConfigLoader {
self.config.lock().unwrap().acl = acl; self.config.lock().unwrap().acl = acl;
} }
fn get_tcp_whitelist(&self) -> Vec<String> {
self.config
.lock()
.unwrap()
.tcp_whitelist
.clone()
.unwrap_or_default()
}
fn set_tcp_whitelist(&self, whitelist: Vec<String>) {
self.config.lock().unwrap().tcp_whitelist = Some(whitelist);
}
fn get_udp_whitelist(&self) -> Vec<String> {
self.config
.lock()
.unwrap()
.udp_whitelist
.clone()
.unwrap_or_default()
}
fn set_udp_whitelist(&self, whitelist: Vec<String>) {
self.config.lock().unwrap().udp_whitelist = Some(whitelist);
}
fn dump(&self) -> String { fn dump(&self) -> String {
let default_flags_json = serde_json::to_string(&gen_default_flags()).unwrap(); let default_flags_json = serde_json::to_string(&gen_default_flags()).unwrap();
let default_flags_hashmap = let default_flags_hashmap =
+382 -9
View File
@@ -22,23 +22,25 @@ use tokio::time::timeout;
use easytier::{ use easytier::{
common::{ common::{
config::PortForwardConfig,
constants::EASYTIER_VERSION, constants::EASYTIER_VERSION,
stun::{StunInfoCollector, StunInfoCollectorTrait}, stun::{StunInfoCollector, StunInfoCollectorTrait},
}, },
proto::{ proto::{
cli::{ cli::{
list_peer_route_pair, AclManageRpc, AclManageRpcClientFactory, ConnectorManageRpc, list_peer_route_pair, AclManageRpc, AclManageRpcClientFactory, AddPortForwardRequest,
ConnectorManageRpcClientFactory, DumpRouteRequest, GetAclStatsRequest, ConnectorManageRpc, ConnectorManageRpcClientFactory, DumpRouteRequest,
GetVpnPortalInfoRequest, ListConnectorRequest, ListForeignNetworkRequest, GetAclStatsRequest, GetVpnPortalInfoRequest, GetWhitelistRequest, ListConnectorRequest,
ListGlobalForeignNetworkRequest, ListMappedListenerRequest, ListPeerRequest, ListForeignNetworkRequest, ListGlobalForeignNetworkRequest, ListMappedListenerRequest,
ListPeerResponse, ListRouteRequest, ListRouteResponse, ManageMappedListenerRequest, ListPeerRequest, ListPeerResponse, ListPortForwardRequest, ListRouteRequest,
MappedListenerManageAction, MappedListenerManageRpc, ListRouteResponse, ManageMappedListenerRequest, MappedListenerManageAction,
MappedListenerManageRpcClientFactory, NodeInfo, PeerManageRpc, MappedListenerManageRpc, MappedListenerManageRpcClientFactory, NodeInfo, PeerManageRpc,
PeerManageRpcClientFactory, ShowNodeInfoRequest, TcpProxyEntryState, PeerManageRpcClientFactory, PortForwardManageRpc, PortForwardManageRpcClientFactory,
RemovePortForwardRequest, SetWhitelistRequest, ShowNodeInfoRequest, TcpProxyEntryState,
TcpProxyEntryTransportType, TcpProxyRpc, TcpProxyRpcClientFactory, VpnPortalRpc, TcpProxyEntryTransportType, TcpProxyRpc, TcpProxyRpcClientFactory, VpnPortalRpc,
VpnPortalRpcClientFactory, VpnPortalRpcClientFactory,
}, },
common::NatType, common::{NatType, SocketType},
peer_rpc::{GetGlobalPeerMapRequest, PeerCenterRpc, PeerCenterRpcClientFactory}, peer_rpc::{GetGlobalPeerMapRequest, PeerCenterRpc, PeerCenterRpcClientFactory},
rpc_impl::standalone::StandAloneClient, rpc_impl::standalone::StandAloneClient,
rpc_types::controller::BaseController, rpc_types::controller::BaseController,
@@ -96,6 +98,10 @@ enum SubCommand {
Proxy, Proxy,
#[command(about = "show ACL rules statistics")] #[command(about = "show ACL rules statistics")]
Acl(AclArgs), Acl(AclArgs),
#[command(about = "manage port forwarding")]
PortForward(PortForwardArgs),
#[command(about = "manage TCP/UDP whitelist")]
Whitelist(WhitelistArgs),
#[command(about = t!("core_clap.generate_completions").to_string())] #[command(about = t!("core_clap.generate_completions").to_string())]
GenAutocomplete { shell: Shell }, GenAutocomplete { shell: Shell },
} }
@@ -193,6 +199,62 @@ enum AclSubCommand {
Stats, Stats,
} }
#[derive(Args, Debug)]
struct PortForwardArgs {
#[command(subcommand)]
sub_command: Option<PortForwardSubCommand>,
}
#[derive(Subcommand, Debug)]
enum PortForwardSubCommand {
/// Add port forward rule
Add {
#[arg(help = "Protocol (tcp/udp)")]
protocol: String,
#[arg(help = "Local bind address (e.g., 0.0.0.0:8080)")]
bind_addr: String,
#[arg(help = "Destination address (e.g., 10.1.1.1:80)")]
dst_addr: String,
},
/// Remove port forward rule
Remove {
#[arg(help = "Protocol (tcp/udp)")]
protocol: String,
#[arg(help = "Local bind address (e.g., 0.0.0.0:8080)")]
bind_addr: String,
#[arg(help = "Optional Destination address (e.g., 10.1.1.1:80)")]
dst_addr: Option<String>,
},
/// List port forward rules
List,
}
#[derive(Args, Debug)]
struct WhitelistArgs {
#[command(subcommand)]
sub_command: Option<WhitelistSubCommand>,
}
#[derive(Subcommand, Debug)]
enum WhitelistSubCommand {
/// Set TCP port whitelist
SetTcp {
#[arg(help = "TCP ports (e.g., 80,443,8000-9000)")]
ports: String,
},
/// Set UDP port whitelist
SetUdp {
#[arg(help = "UDP ports (e.g., 53,5000-6000)")]
ports: String,
},
/// Clear TCP whitelist
ClearTcp,
/// Clear UDP whitelist
ClearUdp,
/// Show current whitelist configuration
Show,
}
#[derive(Args, Debug)] #[derive(Args, Debug)]
struct ServiceArgs { struct ServiceArgs {
#[arg(short, long, default_value = env!("CARGO_PKG_NAME"), help = "service name")] #[arg(short, long, default_value = env!("CARGO_PKG_NAME"), help = "service name")]
@@ -340,6 +402,18 @@ impl CommandHandler<'_> {
.with_context(|| "failed to get vpn portal client")?) .with_context(|| "failed to get vpn portal client")?)
} }
async fn get_port_forward_manager_client(
&self,
) -> Result<Box<dyn PortForwardManageRpc<Controller = BaseController>>, Error> {
Ok(self
.client
.lock()
.unwrap()
.scoped_client::<PortForwardManageRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get port forward manager client")?)
}
async fn list_peers(&self) -> Result<ListPeerResponse, Error> { async fn list_peers(&self) -> Result<ListPeerResponse, Error> {
let client = self.get_peer_manager_client().await?; let client = self.get_peer_manager_client().await?;
let request = ListPeerRequest::default(); let request = ListPeerRequest::default();
@@ -788,6 +862,265 @@ impl CommandHandler<'_> {
} }
Ok(url) Ok(url)
} }
async fn handle_port_forward_add(
&self,
protocol: &str,
bind_addr: &str,
dst_addr: &str,
) -> Result<(), Error> {
let bind_addr: std::net::SocketAddr = bind_addr
.parse()
.with_context(|| format!("Invalid bind address: {}", bind_addr))?;
let dst_addr: std::net::SocketAddr = dst_addr
.parse()
.with_context(|| format!("Invalid destination address: {}", dst_addr))?;
if protocol != "tcp" && protocol != "udp" {
return Err(anyhow::anyhow!("Protocol must be 'tcp' or 'udp'"));
}
let client = self.get_port_forward_manager_client().await?;
let request = AddPortForwardRequest {
cfg: Some(
PortForwardConfig {
proto: protocol.to_string(),
bind_addr: bind_addr.into(),
dst_addr: dst_addr.into(),
}
.into(),
),
};
client
.add_port_forward(BaseController::default(), request)
.await?;
println!(
"Port forward rule added: {} {} -> {}",
protocol, bind_addr, dst_addr
);
Ok(())
}
async fn handle_port_forward_remove(
&self,
protocol: &str,
bind_addr: &str,
dst_addr: Option<&str>,
) -> Result<(), Error> {
let bind_addr: std::net::SocketAddr = bind_addr
.parse()
.with_context(|| format!("Invalid bind address: {}", bind_addr))?;
if protocol != "tcp" && protocol != "udp" {
return Err(anyhow::anyhow!("Protocol must be 'tcp' or 'udp'"));
}
let client = self.get_port_forward_manager_client().await?;
let request = RemovePortForwardRequest {
cfg: Some(
PortForwardConfig {
proto: protocol.to_string(),
bind_addr: bind_addr.into(),
dst_addr: dst_addr
.map(|s| s.parse::<SocketAddr>().unwrap())
.map(Into::into)
.unwrap_or("0.0.0.0:0".parse::<SocketAddr>().unwrap().into()),
}
.into(),
),
};
client
.remove_port_forward(BaseController::default(), request)
.await?;
println!("Port forward rule removed: {} {}", protocol, bind_addr);
Ok(())
}
async fn handle_port_forward_list(&self) -> Result<(), Error> {
let client = self.get_port_forward_manager_client().await?;
let request = ListPortForwardRequest::default();
let response = client
.list_port_forward(BaseController::default(), request)
.await?;
if self.verbose || *self.output_format == OutputFormat::Json {
println!("{}", serde_json::to_string_pretty(&response)?);
return Ok(());
}
#[derive(tabled::Tabled, serde::Serialize)]
struct PortForwardTableItem {
protocol: String,
bind_addr: String,
dst_addr: String,
}
let items: Vec<PortForwardTableItem> = response
.cfgs
.into_iter()
.map(|rule| PortForwardTableItem {
protocol: format!(
"{:?}",
SocketType::try_from(rule.socket_type).unwrap_or(SocketType::Tcp)
),
bind_addr: rule
.bind_addr
.map(|addr| addr.to_string())
.unwrap_or_default(),
dst_addr: rule
.dst_addr
.map(|addr| addr.to_string())
.unwrap_or_default(),
})
.collect();
print_output(&items, self.output_format)?;
Ok(())
}
async fn handle_whitelist_set_tcp(&self, ports: &str) -> Result<(), Error> {
let tcp_ports = Self::parse_port_list(ports)?;
let client = self.get_acl_manager_client().await?;
// Get current UDP ports to preserve them
let current = client
.get_whitelist(BaseController::default(), GetWhitelistRequest::default())
.await?;
let request = SetWhitelistRequest {
tcp_ports,
udp_ports: current.udp_ports,
};
client
.set_whitelist(BaseController::default(), request)
.await?;
println!("TCP whitelist updated: {}", ports);
Ok(())
}
async fn handle_whitelist_set_udp(&self, ports: &str) -> Result<(), Error> {
let udp_ports = Self::parse_port_list(ports)?;
let client = self.get_acl_manager_client().await?;
// Get current TCP ports to preserve them
let current = client
.get_whitelist(BaseController::default(), GetWhitelistRequest::default())
.await?;
let request = SetWhitelistRequest {
tcp_ports: current.tcp_ports,
udp_ports,
};
client
.set_whitelist(BaseController::default(), request)
.await?;
println!("UDP whitelist updated: {}", ports);
Ok(())
}
async fn handle_whitelist_clear_tcp(&self) -> Result<(), Error> {
let client = self.get_acl_manager_client().await?;
// Get current UDP ports to preserve them
let current = client
.get_whitelist(BaseController::default(), GetWhitelistRequest::default())
.await?;
let request = SetWhitelistRequest {
tcp_ports: vec![],
udp_ports: current.udp_ports,
};
client
.set_whitelist(BaseController::default(), request)
.await?;
println!("TCP whitelist cleared");
Ok(())
}
async fn handle_whitelist_clear_udp(&self) -> Result<(), Error> {
let client = self.get_acl_manager_client().await?;
// Get current TCP ports to preserve them
let current = client
.get_whitelist(BaseController::default(), GetWhitelistRequest::default())
.await?;
let request = SetWhitelistRequest {
tcp_ports: current.tcp_ports,
udp_ports: vec![],
};
client
.set_whitelist(BaseController::default(), request)
.await?;
println!("UDP whitelist cleared");
Ok(())
}
async fn handle_whitelist_show(&self) -> Result<(), Error> {
let client = self.get_acl_manager_client().await?;
let request = GetWhitelistRequest::default();
let response = client
.get_whitelist(BaseController::default(), request)
.await?;
if self.verbose || *self.output_format == OutputFormat::Json {
println!("{}", serde_json::to_string_pretty(&response)?);
return Ok(());
}
println!(
"TCP Whitelist: {}",
if response.tcp_ports.is_empty() {
"None".to_string()
} else {
response.tcp_ports.join(", ")
}
);
println!(
"UDP Whitelist: {}",
if response.udp_ports.is_empty() {
"None".to_string()
} else {
response.udp_ports.join(", ")
}
);
Ok(())
}
fn parse_port_list(ports_str: &str) -> Result<Vec<String>, Error> {
let mut ports = Vec::new();
for port_spec in ports_str.split(',') {
let port_spec = port_spec.trim();
if port_spec.contains('-') {
// Handle port range
let parts: Vec<&str> = port_spec.split('-').collect();
if parts.len() != 2 {
return Err(anyhow::anyhow!("Invalid port range: {}", port_spec));
}
let start: u16 = parts[0]
.parse()
.with_context(|| format!("Invalid start port: {}", parts[0]))?;
let end: u16 = parts[1]
.parse()
.with_context(|| format!("Invalid end port: {}", parts[1]))?;
if start > end {
return Err(anyhow::anyhow!("Invalid port range: start > end"));
}
ports.push(format!("{}-{}", start, end));
} else {
// Handle single port
let port: u16 = port_spec
.parse()
.with_context(|| format!("Invalid port number: {}", port_spec))?;
ports.push(port.to_string());
}
}
Ok(ports)
}
} }
#[derive(Debug)] #[derive(Debug)]
@@ -1494,6 +1827,46 @@ async fn main() -> Result<(), Error> {
handler.handle_acl_stats().await?; handler.handle_acl_stats().await?;
} }
}, },
SubCommand::PortForward(port_forward_args) => match &port_forward_args.sub_command {
Some(PortForwardSubCommand::Add {
protocol,
bind_addr,
dst_addr,
}) => {
handler
.handle_port_forward_add(protocol, bind_addr, dst_addr)
.await?;
}
Some(PortForwardSubCommand::Remove {
protocol,
bind_addr,
dst_addr,
}) => {
handler
.handle_port_forward_remove(protocol, bind_addr, dst_addr.as_deref())
.await?;
}
Some(PortForwardSubCommand::List) | None => {
handler.handle_port_forward_list().await?;
}
},
SubCommand::Whitelist(whitelist_args) => match &whitelist_args.sub_command {
Some(WhitelistSubCommand::SetTcp { ports }) => {
handler.handle_whitelist_set_tcp(ports).await?;
}
Some(WhitelistSubCommand::SetUdp { ports }) => {
handler.handle_whitelist_set_udp(ports).await?;
}
Some(WhitelistSubCommand::ClearTcp) => {
handler.handle_whitelist_clear_tcp().await?;
}
Some(WhitelistSubCommand::ClearUdp) => {
handler.handle_whitelist_clear_udp().await?;
}
Some(WhitelistSubCommand::Show) | None => {
handler.handle_whitelist_show().await?;
}
},
SubCommand::GenAutocomplete { shell } => { SubCommand::GenAutocomplete { shell } => {
let mut cmd = Cli::command(); let mut cmd = Cli::command();
easytier::print_completions(shell, &mut cmd, "easytier-cli"); easytier::print_completions(shell, &mut cmd, "easytier-cli");
+8 -117
View File
@@ -29,10 +29,7 @@ use easytier::{
connector::create_connector_by_url, connector::create_connector_by_url,
instance_manager::NetworkInstanceManager, instance_manager::NetworkInstanceManager,
launcher::{add_proxy_network_to_config, ConfigSource}, launcher::{add_proxy_network_to_config, ConfigSource},
proto::{ proto::common::{CompressionAlgoPb, NatType},
acl::{Acl, AclV1, Action, Chain, ChainType, Protocol, Rule},
common::{CompressionAlgoPb, NatType},
},
tunnel::{IpVersion, PROTO_PORT_OFFSET}, tunnel::{IpVersion, PROTO_PORT_OFFSET},
utils::{init_logger, setup_panic_handler}, utils::{init_logger, setup_panic_handler},
web_client, web_client,
@@ -622,115 +619,6 @@ impl NetworkOptions {
false false
} }
fn parse_port_list(port_list: &[String]) -> anyhow::Result<Vec<String>> {
let mut ports = Vec::new();
for port_spec in port_list {
if port_spec.contains('-') {
// Handle port range like "8000-9000"
let parts: Vec<&str> = port_spec.split('-').collect();
if parts.len() != 2 {
return Err(anyhow::anyhow!("Invalid port range format: {}", port_spec));
}
let start: u16 = parts[0]
.parse()
.with_context(|| format!("Invalid start port in range: {}", port_spec))?;
let end: u16 = parts[1]
.parse()
.with_context(|| format!("Invalid end port in range: {}", port_spec))?;
if start > end {
return Err(anyhow::anyhow!(
"Start port must be <= end port in range: {}",
port_spec
));
}
// acl can handle port range
ports.push(port_spec.clone());
} else {
// Handle single port
let port: u16 = port_spec
.parse()
.with_context(|| format!("Invalid port number: {}", port_spec))?;
ports.push(port.to_string());
}
}
Ok(ports)
}
fn generate_acl_from_whitelists(&self) -> anyhow::Result<Option<Acl>> {
if self.tcp_whitelist.is_empty() && self.udp_whitelist.is_empty() {
return Ok(None);
}
let mut acl = Acl {
acl_v1: Some(AclV1 { chains: vec![] }),
};
let acl_v1 = acl.acl_v1.as_mut().unwrap();
// Create inbound chain for whitelist rules
let mut inbound_chain = Chain {
name: "inbound_whitelist".to_string(),
chain_type: ChainType::Inbound as i32,
description: "Auto-generated inbound whitelist from CLI".to_string(),
enabled: true,
rules: vec![],
default_action: Action::Drop as i32, // Default deny
};
let mut rule_priority = 1000u32;
// Add TCP whitelist rules
if !self.tcp_whitelist.is_empty() {
let tcp_ports = Self::parse_port_list(&self.tcp_whitelist)?;
let tcp_rule = Rule {
name: "tcp_whitelist".to_string(),
description: "Auto-generated TCP whitelist rule".to_string(),
priority: rule_priority,
enabled: true,
protocol: Protocol::Tcp as i32,
ports: tcp_ports,
source_ips: vec![],
destination_ips: vec![],
source_ports: vec![],
action: Action::Allow as i32,
rate_limit: 0,
burst_limit: 0,
stateful: true,
};
inbound_chain.rules.push(tcp_rule);
rule_priority -= 1;
}
// Add UDP whitelist rules
if !self.udp_whitelist.is_empty() {
let udp_ports = Self::parse_port_list(&self.udp_whitelist)?;
let udp_rule = Rule {
name: "udp_whitelist".to_string(),
description: "Auto-generated UDP whitelist rule".to_string(),
priority: rule_priority,
enabled: true,
protocol: Protocol::Udp as i32,
ports: udp_ports,
source_ips: vec![],
destination_ips: vec![],
source_ports: vec![],
action: Action::Allow as i32,
rate_limit: 0,
burst_limit: 0,
stateful: false,
};
inbound_chain.rules.push(udp_rule);
}
acl_v1.chains.push(inbound_chain);
Ok(Some(acl))
}
fn merge_into(&self, cfg: &mut TomlConfigLoader) -> anyhow::Result<()> { fn merge_into(&self, cfg: &mut TomlConfigLoader) -> anyhow::Result<()> {
if self.hostname.is_some() { if self.hostname.is_some() {
cfg.set_hostname(self.hostname.clone()); cfg.set_hostname(self.hostname.clone());
@@ -988,10 +876,13 @@ impl NetworkOptions {
cfg.set_exit_nodes(self.exit_nodes.clone()); cfg.set_exit_nodes(self.exit_nodes.clone());
} }
// Handle port whitelists by generating ACL configuration let mut old_tcp_whitelist = cfg.get_tcp_whitelist();
if let Some(acl) = self.generate_acl_from_whitelists()? { old_tcp_whitelist.extend(self.tcp_whitelist.clone());
cfg.set_acl(Some(acl)); cfg.set_tcp_whitelist(old_tcp_whitelist);
}
let mut old_udp_whitelist = cfg.get_udp_whitelist();
old_udp_whitelist.extend(self.udp_whitelist.clone());
cfg.set_udp_whitelist(old_udp_whitelist);
Ok(()) Ok(())
} }
+72 -28
View File
@@ -6,6 +6,7 @@ use std::{
use crossbeam::atomic::AtomicCell; use crossbeam::atomic::AtomicCell;
use kcp_sys::{endpoint::KcpEndpoint, stream::KcpStream}; use kcp_sys::{endpoint::KcpEndpoint, stream::KcpStream};
use tokio_util::sync::{CancellationToken, DropGuard};
use crate::{ use crate::{
common::{ common::{
@@ -432,6 +433,8 @@ pub struct Socks5Server {
udp_forward_task: Arc<DashMap<UdpClientKey, ScopedTask<()>>>, udp_forward_task: Arc<DashMap<UdpClientKey, ScopedTask<()>>>,
kcp_endpoint: Mutex<Option<Weak<KcpEndpoint>>>, kcp_endpoint: Mutex<Option<Weak<KcpEndpoint>>>,
cancel_tokens: DashMap<PortForwardConfig, DropGuard>,
} }
#[async_trait::async_trait] #[async_trait::async_trait]
@@ -531,6 +534,8 @@ impl Socks5Server {
udp_forward_task: Arc::new(DashMap::new()), udp_forward_task: Arc::new(DashMap::new()),
kcp_endpoint: Mutex::new(None), kcp_endpoint: Mutex::new(None),
cancel_tokens: DashMap::new(),
}) })
} }
@@ -614,10 +619,9 @@ impl Socks5Server {
need_start = true; need_start = true;
}; };
for port_forward in self.global_ctx.config.get_port_forwards() { let cfgs = self.global_ctx.config.get_port_forwards();
self.add_port_forward(port_forward).await?; self.reload_port_forwards(&cfgs).await?;
need_start = true; need_start = need_start || cfgs.len() > 0;
}
if need_start { if need_start {
self.peer_manager self.peer_manager
@@ -630,6 +634,26 @@ impl Socks5Server {
Ok(()) Ok(())
} }
pub async fn reload_port_forwards(&self, cfgs: &Vec<PortForwardConfig>) -> Result<(), Error> {
// remove entries not in new cfg
self.cancel_tokens.retain(|k, _| {
cfgs.iter().any(|cfg| {
if cfg.dst_addr.ip().is_unspecified() {
k.bind_addr == cfg.bind_addr && k.proto == cfg.proto
} else {
k == cfg
}
})
});
// add new ones
for cfg in cfgs {
if !self.cancel_tokens.contains_key(cfg) {
self.add_port_forward(cfg.clone()).await?;
}
}
Ok(())
}
async fn handle_port_forward_connection( async fn handle_port_forward_connection(
mut incoming_socket: tokio::net::TcpStream, mut incoming_socket: tokio::net::TcpStream,
connector: Box<dyn AsyncTcpConnector<S = SocksTcpStream> + Send>, connector: Box<dyn AsyncTcpConnector<S = SocksTcpStream> + Send>,
@@ -660,12 +684,10 @@ impl Socks5Server {
pub async fn add_port_forward(&self, cfg: PortForwardConfig) -> Result<(), Error> { pub async fn add_port_forward(&self, cfg: PortForwardConfig) -> Result<(), Error> {
match cfg.proto.to_lowercase().as_str() { match cfg.proto.to_lowercase().as_str() {
"tcp" => { "tcp" => {
self.add_tcp_port_forward(cfg.bind_addr, cfg.dst_addr) self.add_tcp_port_forward(&cfg).await?;
.await?;
} }
"udp" => { "udp" => {
self.add_udp_port_forward(cfg.bind_addr, cfg.dst_addr) self.add_udp_port_forward(&cfg).await?;
.await?;
} }
_ => { _ => {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
@@ -680,11 +702,12 @@ impl Socks5Server {
Ok(()) Ok(())
} }
pub async fn add_tcp_port_forward( pub fn remove_port_forward(&self, cfg: PortForwardConfig) {
&self, let _ = self.cancel_tokens.remove(&cfg);
bind_addr: SocketAddr, }
dst_addr: SocketAddr,
) -> Result<(), Error> { pub async fn add_tcp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
let listener = bind_tcp_socket(bind_addr, self.global_ctx.net_ns.clone())?; let listener = bind_tcp_socket(bind_addr, self.global_ctx.net_ns.clone())?;
let net = self.net.clone(); let net = self.net.clone();
@@ -693,14 +716,26 @@ impl Socks5Server {
let forward_tasks = tasks.clone(); let forward_tasks = tasks.clone();
let kcp_endpoint = self.kcp_endpoint.lock().await.clone(); let kcp_endpoint = self.kcp_endpoint.lock().await.clone();
let peer_mgr = Arc::downgrade(&self.peer_manager.clone()); let peer_mgr = Arc::downgrade(&self.peer_manager.clone());
let cancel_token = CancellationToken::new();
self.cancel_tokens
.insert(cfg.clone(), cancel_token.clone().drop_guard());
self.tasks.lock().unwrap().spawn(async move { self.tasks.lock().unwrap().spawn(async move {
loop { loop {
let (incoming_socket, addr) = match listener.accept().await { let (incoming_socket, addr) = select! {
Ok(result) => result, biased;
Err(err) => { _ = cancel_token.cancelled() => {
tracing::error!("port forward accept error = {:?}", err); tracing::info!("port forward for {:?} cancelled", bind_addr);
continue; break;
}
res = listener.accept() => {
match res {
Ok(result) => result,
Err(err) => {
tracing::error!("port forward accept error = {:?}", err);
continue;
}
}
} }
}; };
@@ -747,11 +782,8 @@ impl Socks5Server {
} }
#[tracing::instrument(name = "add_udp_port_forward", skip(self))] #[tracing::instrument(name = "add_udp_port_forward", skip(self))]
pub async fn add_udp_port_forward( pub async fn add_udp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
&self, let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
bind_addr: SocketAddr,
dst_addr: SocketAddr,
) -> Result<(), Error> {
let socket = Arc::new(bind_udp_socket(bind_addr, self.global_ctx.net_ns.clone())?); let socket = Arc::new(bind_udp_socket(bind_addr, self.global_ctx.net_ns.clone())?);
let entries = self.entries.clone(); let entries = self.entries.clone();
@@ -759,16 +791,28 @@ impl Socks5Server {
let net = self.net.clone(); let net = self.net.clone();
let udp_client_map = self.udp_client_map.clone(); let udp_client_map = self.udp_client_map.clone();
let udp_forward_task = self.udp_forward_task.clone(); let udp_forward_task = self.udp_forward_task.clone();
let cancel_token = CancellationToken::new();
self.cancel_tokens
.insert(cfg.clone(), cancel_token.clone().drop_guard());
self.tasks.lock().unwrap().spawn(async move { self.tasks.lock().unwrap().spawn(async move {
loop { loop {
// we set the max buffer size of smoltcp to 8192, so we need to use a buffer size that is less than 8192 here. // we set the max buffer size of smoltcp to 8192, so we need to use a buffer size that is less than 8192 here.
let mut buf = vec![0u8; 8192]; let mut buf = vec![0u8; 8192];
let (len, addr) = match socket.recv_from(&mut buf).await { let (len, addr) = select! {
Ok(result) => result, biased;
Err(err) => { _ = cancel_token.cancelled() => {
tracing::error!("udp port forward recv error = {:?}", err); tracing::info!("udp port forward for {:?} cancelled", bind_addr);
continue; break;
}
res = socket.recv_from(&mut buf) => {
match res {
Ok(result) => result,
Err(err) => {
tracing::error!("udp port forward recv error = {:?}", err);
continue;
}
}
} }
}; };
+95 -8
View File
@@ -10,6 +10,7 @@ use cidr::{IpCidr, Ipv4Inet};
use tokio::{sync::Mutex, task::JoinSet}; use tokio::{sync::Mutex, task::JoinSet};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::common::acl_processor::AclRuleBuilder;
use crate::common::config::ConfigLoader; use crate::common::config::ConfigLoader;
use crate::common::error::Error; use crate::common::error::Error;
use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent}; use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent};
@@ -29,13 +30,15 @@ use crate::peers::peer_manager::{PeerManager, RouteAlgoType};
use crate::peers::rpc_service::PeerManagerRpcService; use crate::peers::rpc_service::PeerManagerRpcService;
use crate::peers::{create_packet_recv_chan, recv_packet_from_chan, PacketRecvChanReceiver}; use crate::peers::{create_packet_recv_chan, recv_packet_from_chan, PacketRecvChanReceiver};
use crate::proto::cli::VpnPortalRpc; use crate::proto::cli::VpnPortalRpc;
use crate::proto::cli::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
use crate::proto::cli::{ use crate::proto::cli::{
ListMappedListenerRequest, ListMappedListenerResponse, ManageMappedListenerRequest, AddPortForwardRequest, AddPortForwardResponse, ListMappedListenerRequest,
ManageMappedListenerResponse, MappedListener, MappedListenerManageAction, ListMappedListenerResponse, ListPortForwardRequest, ListPortForwardResponse,
MappedListenerManageRpc, ManageMappedListenerRequest, ManageMappedListenerResponse, MappedListener,
MappedListenerManageAction, MappedListenerManageRpc, PortForwardManageRpc,
RemovePortForwardRequest, RemovePortForwardResponse,
}; };
use crate::proto::common::TunnelInfo; use crate::proto::cli::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
use crate::proto::common::{PortForwardConfigPb, TunnelInfo};
use crate::proto::peer_rpc::PeerCenterRpcServer; use crate::proto::peer_rpc::PeerCenterRpcServer;
use crate::proto::rpc_impl::standalone::{RpcServerHook, StandAloneServer}; use crate::proto::rpc_impl::standalone::{RpcServerHook, StandAloneServer};
use crate::proto::rpc_types; use crate::proto::rpc_types;
@@ -609,9 +612,9 @@ impl Instance {
} }
} }
if let Some(acl) = self.global_ctx.config.get_acl() { self.global_ctx
self.global_ctx.get_acl_filter().reload_rules(Some(&acl)); .get_acl_filter()
} .reload_rules(AclRuleBuilder::build(&self.global_ctx)?.as_ref());
// run after tun device created, so listener can bind to tun device, which may be required by win 10 // run after tun device created, so listener can bind to tun device, which may be required by win 10
self.ip_proxy = Some(IpProxy::new( self.ip_proxy = Some(IpProxy::new(
@@ -790,6 +793,85 @@ impl Instance {
MappedListenerManagerRpcService(self.global_ctx.clone()) MappedListenerManagerRpcService(self.global_ctx.clone())
} }
fn get_port_forward_manager_rpc_service(
&self,
) -> impl PortForwardManageRpc<Controller = BaseController> + Clone {
#[derive(Clone)]
pub struct PortForwardManagerRpcService {
global_ctx: ArcGlobalCtx,
socks5_server: Weak<Socks5Server>,
}
#[async_trait::async_trait]
impl PortForwardManageRpc for PortForwardManagerRpcService {
type Controller = BaseController;
async fn add_port_forward(
&self,
_: BaseController,
request: AddPortForwardRequest,
) -> Result<AddPortForwardResponse, rpc_types::error::Error> {
let Some(socks5_server) = self.socks5_server.upgrade() else {
return Err(anyhow::anyhow!("socks5 server not available").into());
};
if let Some(cfg) = request.cfg {
tracing::info!("Port forward rule added: {:?}", cfg);
let mut current_forwards = self.global_ctx.config.get_port_forwards();
current_forwards.push(cfg.into());
self.global_ctx
.config
.set_port_forwards(current_forwards.clone());
socks5_server
.reload_port_forwards(&current_forwards)
.await
.with_context(|| "Failed to reload port forwards")?;
}
Ok(AddPortForwardResponse {})
}
async fn remove_port_forward(
&self,
_: BaseController,
request: RemovePortForwardRequest,
) -> Result<RemovePortForwardResponse, rpc_types::error::Error> {
let Some(socks5_server) = self.socks5_server.upgrade() else {
return Err(anyhow::anyhow!("socks5 server not available").into());
};
let Some(cfg) = request.cfg else {
return Err(anyhow::anyhow!("port forward config is empty").into());
};
let cfg = cfg.into();
let mut current_forwards = self.global_ctx.config.get_port_forwards();
current_forwards.retain(|e| *e != cfg);
self.global_ctx
.config
.set_port_forwards(current_forwards.clone());
socks5_server
.reload_port_forwards(&current_forwards)
.await
.with_context(|| "Failed to reload port forwards")?;
tracing::info!("Port forward rule removed: {:?}", cfg);
Ok(RemovePortForwardResponse {})
}
async fn list_port_forward(
&self,
_: BaseController,
_request: ListPortForwardRequest,
) -> Result<ListPortForwardResponse, rpc_types::error::Error> {
let forwards = self.global_ctx.config.get_port_forwards();
let cfgs: Vec<PortForwardConfigPb> = forwards.into_iter().map(Into::into).collect();
Ok(ListPortForwardResponse { cfgs })
}
}
PortForwardManagerRpcService {
global_ctx: self.global_ctx.clone(),
socks5_server: Arc::downgrade(&self.socks5_server),
}
}
async fn run_rpc_server(&mut self) -> Result<(), Error> { async fn run_rpc_server(&mut self) -> Result<(), Error> {
let Some(_) = self.global_ctx.config.get_rpc_portal() else { let Some(_) = self.global_ctx.config.get_rpc_portal() else {
tracing::info!("rpc server not enabled, because rpc_portal is not set."); tracing::info!("rpc server not enabled, because rpc_portal is not set.");
@@ -803,6 +885,7 @@ impl Instance {
let peer_center = self.peer_center.clone(); let peer_center = self.peer_center.clone();
let vpn_portal_rpc = self.get_vpn_portal_rpc_service(); let vpn_portal_rpc = self.get_vpn_portal_rpc_service();
let mapped_listener_manager_rpc = self.get_mapped_listener_manager_rpc_service(); let mapped_listener_manager_rpc = self.get_mapped_listener_manager_rpc_service();
let port_forward_manager_rpc = self.get_port_forward_manager_rpc_service();
let s = self.rpc_server.as_mut().unwrap(); let s = self.rpc_server.as_mut().unwrap();
let peer_mgr_rpc_service = PeerManagerRpcService::new(peer_mgr.clone()); let peer_mgr_rpc_service = PeerManagerRpcService::new(peer_mgr.clone());
@@ -823,6 +906,10 @@ impl Instance {
MappedListenerManageRpcServer::new(mapped_listener_manager_rpc), MappedListenerManageRpcServer::new(mapped_listener_manager_rpc),
"", "",
); );
s.registry().register(
PortForwardManageRpcServer::new(port_forward_manager_rpc),
"",
);
if let Some(ip_proxy) = self.ip_proxy.as_ref() { if let Some(ip_proxy) = self.ip_proxy.as_ref() {
s.registry().register( s.registry().register(
+53 -7
View File
@@ -1,13 +1,18 @@
use std::sync::Arc; use std::sync::Arc;
use crate::proto::{ use crate::{
cli::{ common::acl_processor::AclRuleBuilder,
AclManageRpc, DumpRouteRequest, DumpRouteResponse, GetAclStatsRequest, GetAclStatsResponse, proto::{
ListForeignNetworkRequest, ListForeignNetworkResponse, ListGlobalForeignNetworkRequest, cli::{
ListGlobalForeignNetworkResponse, ListPeerRequest, ListPeerResponse, ListRouteRequest, AclManageRpc, DumpRouteRequest, DumpRouteResponse, GetAclStatsRequest,
ListRouteResponse, PeerInfo, PeerManageRpc, ShowNodeInfoRequest, ShowNodeInfoResponse, GetAclStatsResponse, GetWhitelistRequest, GetWhitelistResponse,
ListForeignNetworkRequest, ListForeignNetworkResponse, ListGlobalForeignNetworkRequest,
ListGlobalForeignNetworkResponse, ListPeerRequest, ListPeerResponse, ListRouteRequest,
ListRouteResponse, PeerInfo, PeerManageRpc, SetWhitelistRequest, SetWhitelistResponse,
ShowNodeInfoRequest, ShowNodeInfoResponse,
},
rpc_types::{self, controller::BaseController},
}, },
rpc_types::{self, controller::BaseController},
}; };
use super::peer_manager::PeerManager; use super::peer_manager::PeerManager;
@@ -153,4 +158,45 @@ impl AclManageRpc for PeerManagerRpcService {
acl_stats: Some(acl_stats), acl_stats: Some(acl_stats),
}) })
} }
async fn set_whitelist(
&self,
_: BaseController,
request: SetWhitelistRequest,
) -> Result<SetWhitelistResponse, rpc_types::error::Error> {
tracing::info!(
"Setting whitelist - TCP: {:?}, UDP: {:?}",
request.tcp_ports,
request.udp_ports
);
let global_ctx = self.peer_manager.get_global_ctx();
global_ctx.config.set_tcp_whitelist(request.tcp_ports);
global_ctx.config.set_udp_whitelist(request.udp_ports);
global_ctx
.get_acl_filter()
.reload_rules(AclRuleBuilder::build(&global_ctx)?.as_ref());
Ok(SetWhitelistResponse {})
}
async fn get_whitelist(
&self,
_: BaseController,
_request: GetWhitelistRequest,
) -> Result<GetWhitelistResponse, rpc_types::error::Error> {
let global_ctx = self.peer_manager.get_global_ctx();
let tcp_ports = global_ctx.config.get_tcp_whitelist();
let udp_ports = global_ctx.config.get_udp_whitelist();
tracing::info!(
"Getting whitelist - TCP: {:?}, UDP: {:?}",
tcp_ports,
udp_ports
);
Ok(GetWhitelistResponse {
tcp_ports,
udp_ports,
})
}
} }
+40
View File
@@ -261,4 +261,44 @@ message GetAclStatsResponse {
service AclManageRpc { service AclManageRpc {
rpc GetAclStats(GetAclStatsRequest) returns (GetAclStatsResponse); rpc GetAclStats(GetAclStatsRequest) returns (GetAclStatsResponse);
rpc SetWhitelist(SetWhitelistRequest) returns (SetWhitelistResponse);
rpc GetWhitelist(GetWhitelistRequest) returns (GetWhitelistResponse);
}
message SetWhitelistRequest {
repeated string tcp_ports = 1;
repeated string udp_ports = 2;
}
message SetWhitelistResponse {}
message GetWhitelistRequest {}
message GetWhitelistResponse {
repeated string tcp_ports = 1;
repeated string udp_ports = 2;
}
message AddPortForwardRequest {
common.PortForwardConfigPb cfg = 1;
}
message AddPortForwardResponse {}
message RemovePortForwardRequest {
common.PortForwardConfigPb cfg = 1;
}
message RemovePortForwardResponse {}
message ListPortForwardRequest {}
message ListPortForwardResponse {
repeated common.PortForwardConfigPb cfgs = 1;
}
service PortForwardManageRpc {
rpc AddPortForward(AddPortForwardRequest) returns (AddPortForwardResponse);
rpc RemovePortForward(RemovePortForwardRequest) returns (RemovePortForwardResponse);
rpc ListPortForward(ListPortForwardRequest) returns (ListPortForwardResponse);
} }