From 5b35c51da99a6744a86df81008486c367f047da2 Mon Sep 17 00:00:00 2001 From: KKRainbow <443152178@qq.com> Date: Mon, 13 Apr 2026 11:03:09 +0800 Subject: [PATCH] fix packet split on udp tunnel and avoid tcp proxy access rpc portal (#2107) * distinct control / data when forward packets * fix rpc split for udp tunnel * feat(easytier-web): pass public ip in validate token webhook * protect rpc port from subnet proxy --- easytier-web/src/client_manager/session.rs | 1 + easytier-web/src/webhook.rs | 1 + easytier/src/common/global_ctx.rs | 19 ++ easytier/src/common/stats_manager.rs | 10 + easytier/src/peers/foreign_network_manager.rs | 104 ++++++- easytier/src/peers/peer_manager.rs | 162 ++++++++++- easytier/src/peers/traffic_metrics.rs | 38 +-- easytier/src/proto/rpc_impl/packet.rs | 266 +++++++++++++++--- easytier/src/rpc_service/api.rs | 18 +- easytier/src/rpc_service/mod.rs | 1 + easytier/src/rpc_service/protected_port.rs | 61 ++++ 11 files changed, 602 insertions(+), 79 deletions(-) create mode 100644 easytier/src/rpc_service/protected_port.rs diff --git a/easytier-web/src/client_manager/session.rs b/easytier-web/src/client_manager/session.rs index 3a2e527f..ab0f4fb6 100644 --- a/easytier-web/src/client_manager/session.rs +++ b/easytier-web/src/client_manager/session.rs @@ -233,6 +233,7 @@ impl SessionRpcService { let webhook_req = crate::webhook::ValidateTokenRequest { token: req.user_token.clone(), machine_id: machine_id.to_string(), + public_ip: data.client_url.host_str().map(str::to_string), hostname: req.hostname.clone(), version: req.easytier_version.clone(), os_type: req.device_os.as_ref().map(|info| info.os_type.clone()), diff --git a/easytier-web/src/webhook.rs b/easytier-web/src/webhook.rs index f93a14ea..be252b11 100644 --- a/easytier-web/src/webhook.rs +++ b/easytier-web/src/webhook.rs @@ -49,6 +49,7 @@ impl WebhookConfig { pub struct ValidateTokenRequest { pub token: String, pub machine_id: String, + pub public_ip: Option, pub hostname: String, pub version: String, pub os_type: Option, diff --git a/easytier/src/common/global_ctx.rs b/easytier/src/common/global_ctx.rs index ed3a2286..b6af900b 100644 --- a/easytier/src/common/global_ctx.rs +++ b/easytier/src/common/global_ctx.rs @@ -28,6 +28,7 @@ use crate::{ common::{PeerFeatureFlag, PortForwardConfigPb}, peer_rpc::PeerGroupInfo, }, + rpc_service::protected_port, tunnel::matches_protocol, }; use crossbeam::atomic::AtomicCell; @@ -658,6 +659,7 @@ impl GlobalCtx { if dst_is_local_virtual_ip || dst_is_local_phy_ip { // if is local ip, make sure the port is not one of the listening ports self.is_port_in_running_listeners(dst_addr.port(), is_udp) + || (!is_udp && protected_port::is_protected_tcp_port(dst_addr.port())) } else { false } @@ -765,6 +767,23 @@ pub mod tests { assert!(feature_flags.is_public_server); } + #[tokio::test] + async fn should_deny_proxy_for_process_wide_rpc_port() { + protected_port::clear_protected_tcp_ports_for_test(); + protected_port::register_protected_tcp_port(15888); + + let config = TomlConfigLoader::default(); + let global_ctx = GlobalCtx::new(config); + let rpc_addr = SocketAddr::from(([127, 0, 0, 1], 15888)); + let other_tcp_addr = SocketAddr::from(([127, 0, 0, 1], 15889)); + + assert!(global_ctx.should_deny_proxy(&rpc_addr, false)); + assert!(!global_ctx.should_deny_proxy(&rpc_addr, true)); + assert!(!global_ctx.should_deny_proxy(&other_tcp_addr, false)); + + protected_port::clear_protected_tcp_ports_for_test(); + } + pub fn get_mock_global_ctx_with_network( network_identy: Option, ) -> ArcGlobalCtx { diff --git a/easytier/src/common/stats_manager.rs b/easytier/src/common/stats_manager.rs index d5355759..01f018f1 100644 --- a/easytier/src/common/stats_manager.rs +++ b/easytier/src/common/stats_manager.rs @@ -42,6 +42,8 @@ pub enum MetricName { TrafficControlBytesRxByInstance, /// Traffic bytes forwarded TrafficBytesForwarded, + /// Control-plane traffic bytes forwarded + TrafficControlBytesForwarded, /// Traffic bytes sent to self TrafficBytesSelfTx, /// Traffic bytes received from self @@ -71,6 +73,8 @@ pub enum MetricName { TrafficControlPacketsRxByInstance, /// Traffic packets forwarded TrafficPacketsForwarded, + /// Control-plane traffic packets forwarded + TrafficControlPacketsForwarded, /// Traffic packets sent to self TrafficPacketsSelfTx, /// Traffic packets received from self @@ -117,6 +121,9 @@ impl fmt::Display for MetricName { write!(f, "traffic_control_bytes_rx_by_instance") } MetricName::TrafficBytesForwarded => write!(f, "traffic_bytes_forwarded"), + MetricName::TrafficControlBytesForwarded => { + write!(f, "traffic_control_bytes_forwarded") + } MetricName::TrafficBytesSelfTx => write!(f, "traffic_bytes_self_tx"), MetricName::TrafficBytesSelfRx => write!(f, "traffic_bytes_self_rx"), MetricName::TrafficBytesForeignForwardRx => { @@ -146,6 +153,9 @@ impl fmt::Display for MetricName { write!(f, "traffic_control_packets_rx_by_instance") } MetricName::TrafficPacketsForwarded => write!(f, "traffic_packets_forwarded"), + MetricName::TrafficControlPacketsForwarded => { + write!(f, "traffic_control_packets_forwarded") + } MetricName::TrafficPacketsSelfTx => write!(f, "traffic_packets_self_tx"), MetricName::TrafficPacketsSelfRx => write!(f, "traffic_packets_self_rx"), MetricName::TrafficPacketsForeignForwardRx => { diff --git a/easytier/src/peers/foreign_network_manager.rs b/easytier/src/peers/foreign_network_manager.rs index 3edd9603..c9693e16 100644 --- a/easytier/src/peers/foreign_network_manager.rs +++ b/easytier/src/peers/foreign_network_manager.rs @@ -55,8 +55,8 @@ use super::{ relay_peer_map::RelayPeerMap, route_trait::NextHopPolicy, traffic_metrics::{ - InstanceLabelKind, LogicalTrafficMetrics, TrafficMetricRecorder, - route_peer_info_instance_id, + InstanceLabelKind, LogicalTrafficMetrics, TrafficKind, TrafficMetricRecorder, + route_peer_info_instance_id, traffic_kind, }, }; @@ -419,12 +419,19 @@ impl ForeignNetworkEntry { let label_set = LabelSet::new().with_label_type(LabelType::NetworkName(network_name.clone())); - let forward_bytes = self + let forward_data_bytes = self .stats_mgr .get_counter(MetricName::TrafficBytesForwarded, label_set.clone()); - let forward_packets = self + let forward_data_packets = self .stats_mgr .get_counter(MetricName::TrafficPacketsForwarded, label_set.clone()); + let forward_control_bytes = self + .stats_mgr + .get_counter(MetricName::TrafficControlBytesForwarded, label_set.clone()); + let forward_control_packets = self.stats_mgr.get_counter( + MetricName::TrafficControlPacketsForwarded, + label_set.clone(), + ); let rx_bytes = self .stats_mgr .get_counter(MetricName::TrafficBytesSelfRx, label_set.clone()); @@ -502,8 +509,16 @@ impl ForeignNetworkEntry { } } - forward_bytes.add(buf_len as u64); - forward_packets.inc(); + match traffic_kind(packet_type) { + TrafficKind::Data => { + forward_data_bytes.add(buf_len as u64); + forward_data_packets.inc(); + } + TrafficKind::Control => { + forward_control_bytes.add(buf_len as u64); + forward_control_packets.inc(); + } + } let gateway_peer_id = peer_map .get_gateway_peer_id(to_peer_id, NextHopPolicy::LeastHop) @@ -1293,6 +1308,11 @@ pub mod tests { MetricName::TrafficBytesForwarded, network_labels.clone(), ); + let forwarded_packets_before = metric_value( + &pm_center, + MetricName::TrafficPacketsForwarded, + network_labels.clone(), + ); let rx_bytes_before = metric_value( &pm_center, MetricName::TrafficBytesRx, @@ -1320,6 +1340,7 @@ pub mod tests { pmb_net1.my_peer_id(), PacketType::Data as u8, ); + let transit_pkt_len = transit_pkt.buf_len() as u64; pma_net1 .get_foreign_network_client() .send_msg(transit_pkt, center_peer_id) @@ -1334,7 +1355,12 @@ pub mod tests { &pm_center, MetricName::TrafficBytesForwarded, network_labels.clone(), - ) > forwarded_bytes_before + ) >= forwarded_bytes_before + transit_pkt_len + && metric_value( + &pm_center, + MetricName::TrafficPacketsForwarded, + network_labels.clone(), + ) > forwarded_packets_before } }, Duration::from_secs(5), @@ -1371,6 +1397,70 @@ pub mod tests { ); } + #[tokio::test] + async fn foreign_network_transit_control_forwarding_records_control_forwarded_metrics() { + let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await; + let pmb_net1 = create_mock_peer_manager_for_foreign_network("net1").await; + + connect_peer_manager(pma_net1.clone(), pm_center.clone()).await; + connect_peer_manager(pmb_net1.clone(), pm_center.clone()).await; + wait_route_appear(pma_net1.clone(), pmb_net1.clone()) + .await + .unwrap(); + + let center_peer_id = pm_center + .get_foreign_network_manager() + .get_network_peer_id("net1") + .unwrap(); + let network_labels = + LabelSet::new().with_label_type(LabelType::NetworkName("net1".to_string())); + let forwarded_bytes_before = metric_value( + &pm_center, + MetricName::TrafficControlBytesForwarded, + network_labels.clone(), + ); + let forwarded_packets_before = metric_value( + &pm_center, + MetricName::TrafficControlPacketsForwarded, + network_labels.clone(), + ); + + let mut transit_pkt = ZCPacket::new_with_payload(b"foreign-control-transit"); + transit_pkt.fill_peer_manager_hdr( + pma_net1.my_peer_id(), + pmb_net1.my_peer_id(), + PacketType::RpcReq as u8, + ); + let transit_pkt_len = transit_pkt.buf_len() as u64; + pma_net1 + .get_foreign_network_client() + .send_msg(transit_pkt, center_peer_id) + .await + .unwrap(); + + wait_for_condition( + || { + let pm_center = pm_center.clone(); + let network_labels = network_labels.clone(); + async move { + metric_value( + &pm_center, + MetricName::TrafficControlBytesForwarded, + network_labels.clone(), + ) >= forwarded_bytes_before + transit_pkt_len + && metric_value( + &pm_center, + MetricName::TrafficControlPacketsForwarded, + network_labels.clone(), + ) > forwarded_packets_before + } + }, + Duration::from_secs(5), + ) + .await; + } + #[tokio::test] async fn foreign_network_encapsulated_forwarding_records_tx_metrics() { set_global_var!(OSPF_UPDATE_MY_GLOBAL_FOREIGN_NETWORK_INTERVAL_SEC, 1); diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index e428aeb8..c2dcbc0f 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -37,8 +37,8 @@ use crate::{ recv_packet_from_chan, route_trait::{ForeignNetworkRouteInfoMap, MockRoute, NextHopPolicy, RouteInterface}, traffic_metrics::{ - InstanceLabelKind, LogicalTrafficMetrics, TrafficMetricRecorder, - route_peer_info_instance_id, + InstanceLabelKind, LogicalTrafficMetrics, TrafficKind, TrafficMetricRecorder, + route_peer_info_instance_id, traffic_kind, }, }, proto::{ @@ -888,10 +888,16 @@ impl PeerManager { stats_mgr.get_counter(MetricName::TrafficBytesSelfRx, label_set.clone()); let self_rx_packets = stats_mgr.get_counter(MetricName::TrafficPacketsSelfRx, label_set.clone()); - let forward_tx_bytes = + let forward_data_tx_bytes = stats_mgr.get_counter(MetricName::TrafficBytesForwarded, label_set.clone()); - let forward_tx_packets = + let forward_data_tx_packets = stats_mgr.get_counter(MetricName::TrafficPacketsForwarded, label_set.clone()); + let forward_control_tx_bytes = + stats_mgr.get_counter(MetricName::TrafficControlBytesForwarded, label_set.clone()); + let forward_control_tx_packets = stats_mgr.get_counter( + MetricName::TrafficControlPacketsForwarded, + label_set.clone(), + ); let compress_tx_bytes_before = self.self_tx_counters.compress_tx_bytes_before.clone(); let compress_tx_bytes_after = self.self_tx_counters.compress_tx_bytes_after.clone(); @@ -966,8 +972,16 @@ impl PeerManager { self_tx_bytes.add(ret.buf_len() as u64); self_tx_packets.inc(); } else { - forward_tx_bytes.add(buf_len as u64); - forward_tx_packets.inc(); + match traffic_kind(packet_type) { + TrafficKind::Data => { + forward_data_tx_bytes.add(buf_len as u64); + forward_data_tx_packets.inc(); + } + TrafficKind::Control => { + forward_control_tx_bytes.add(buf_len as u64); + forward_control_tx_packets.inc(); + } + } } tracing::trace!(?to_peer_id, ?my_peer_id, "need forward"); @@ -2053,6 +2067,12 @@ mod tests { .unwrap_or(0) } + fn network_labels(peer_mgr: &PeerManager) -> LabelSet { + LabelSet::new().with_label_type(LabelType::NetworkName( + peer_mgr.get_global_ctx().get_network_name(), + )) + } + #[test] fn recent_traffic_fanout_policy_only_marks_single_peer() { assert!(PeerManager::should_mark_recent_traffic_for_fanout(0)); @@ -2439,6 +2459,136 @@ mod tests { ); } + #[tokio::test] + async fn send_msg_internal_records_data_forwarded_metrics_for_transit_peer() { + let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + + connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await; + connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await; + wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone()) + .await + .unwrap(); + + let b_network_labels = network_labels(&peer_mgr_b); + let forwarded_bytes_before = metric_value( + &peer_mgr_b, + MetricName::TrafficBytesForwarded, + &b_network_labels, + ); + let forwarded_packets_before = metric_value( + &peer_mgr_b, + MetricName::TrafficPacketsForwarded, + &b_network_labels, + ); + + let mut pkt = ZCPacket::new_with_payload(b"forward-data"); + pkt.fill_peer_manager_hdr( + peer_mgr_a.my_peer_id(), + peer_mgr_c.my_peer_id(), + PacketType::Data as u8, + ); + let pkt_len = pkt.buf_len() as u64; + + PeerManager::send_msg_internal( + &peer_mgr_a.peers, + &peer_mgr_a.foreign_network_client, + &peer_mgr_a.relay_peer_map, + Some(&peer_mgr_a.traffic_metrics), + pkt, + peer_mgr_c.my_peer_id(), + ) + .await + .unwrap(); + + wait_for_condition( + || { + let peer_mgr_b = peer_mgr_b.clone(); + let b_network_labels = b_network_labels.clone(); + async move { + metric_value( + &peer_mgr_b, + MetricName::TrafficBytesForwarded, + &b_network_labels, + ) >= forwarded_bytes_before + pkt_len + && metric_value( + &peer_mgr_b, + MetricName::TrafficPacketsForwarded, + &b_network_labels, + ) > forwarded_packets_before + } + }, + Duration::from_secs(5), + ) + .await; + } + + #[tokio::test] + async fn send_msg_internal_records_control_forwarded_metrics_for_transit_peer() { + let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + + connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await; + connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await; + wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone()) + .await + .unwrap(); + + let b_network_labels = network_labels(&peer_mgr_b); + let forwarded_bytes_before = metric_value( + &peer_mgr_b, + MetricName::TrafficControlBytesForwarded, + &b_network_labels, + ); + let forwarded_packets_before = metric_value( + &peer_mgr_b, + MetricName::TrafficControlPacketsForwarded, + &b_network_labels, + ); + + let mut pkt = ZCPacket::new_with_payload(b"forward-control"); + pkt.fill_peer_manager_hdr( + peer_mgr_a.my_peer_id(), + peer_mgr_c.my_peer_id(), + PacketType::RpcReq as u8, + ); + let pkt_len = pkt.buf_len() as u64; + + PeerManager::send_msg_internal( + &peer_mgr_a.peers, + &peer_mgr_a.foreign_network_client, + &peer_mgr_a.relay_peer_map, + Some(&peer_mgr_a.traffic_metrics), + pkt, + peer_mgr_c.my_peer_id(), + ) + .await + .unwrap(); + + wait_for_condition( + || { + let peer_mgr_b = peer_mgr_b.clone(); + let b_network_labels = b_network_labels.clone(); + async move { + metric_value( + &peer_mgr_b, + MetricName::TrafficControlBytesForwarded, + &b_network_labels, + ) >= forwarded_bytes_before + pkt_len + && metric_value( + &peer_mgr_b, + MetricName::TrafficControlPacketsForwarded, + &b_network_labels, + ) > forwarded_packets_before + } + }, + Duration::from_secs(5), + ) + .await; + } + #[tokio::test] async fn recent_traffic_tolerates_future_timestamps() { let peer_mgr_a = create_lazy_peer_manager().await; diff --git a/easytier/src/peers/traffic_metrics.rs b/easytier/src/peers/traffic_metrics.rs index a75006b8..ecd5c5d4 100644 --- a/easytier/src/peers/traffic_metrics.rs +++ b/easytier/src/peers/traffic_metrics.rs @@ -220,12 +220,27 @@ impl LogicalTrafficMetrics { } } -#[derive(Clone, Copy)] -enum TrafficKind { +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum TrafficKind { Data, Control, } +pub(crate) fn traffic_kind(packet_type: u8) -> TrafficKind { + if packet_type == PacketType::Data as u8 + || packet_type == PacketType::KcpSrc as u8 + || packet_type == PacketType::KcpDst as u8 + || packet_type == PacketType::QuicSrc as u8 + || packet_type == PacketType::QuicDst as u8 + || packet_type == PacketType::DataWithKcpSrcModified as u8 + || packet_type == PacketType::DataWithQuicSrcModified as u8 + { + TrafficKind::Data + } else { + TrafficKind::Control + } +} + #[derive(Clone)] struct TrafficMetricGroup { data: Arc, @@ -282,7 +297,7 @@ impl TrafficMetricRecorder { return; } self.tx_metrics - .select(Self::traffic_kind(packet_type)) + .select(traffic_kind(packet_type)) .record_with_resolver(peer_id, bytes, || self.resolve_instance_id(peer_id)) .await; } @@ -292,7 +307,7 @@ impl TrafficMetricRecorder { return; } self.rx_metrics - .select(Self::traffic_kind(packet_type)) + .select(traffic_kind(packet_type)) .record_with_resolver(peer_id, bytes, || self.resolve_instance_id(peer_id)) .await; } @@ -314,21 +329,6 @@ impl TrafficMetricRecorder { fn resolve_instance_id(&self, peer_id: PeerId) -> BoxFuture<'static, Option> { (self.resolve_instance_id)(peer_id) } - - fn traffic_kind(packet_type: u8) -> TrafficKind { - if packet_type == PacketType::Data as u8 - || packet_type == PacketType::KcpSrc as u8 - || packet_type == PacketType::KcpDst as u8 - || packet_type == PacketType::QuicSrc as u8 - || packet_type == PacketType::QuicDst as u8 - || packet_type == PacketType::DataWithKcpSrcModified as u8 - || packet_type == PacketType::DataWithQuicSrcModified as u8 - { - TrafficKind::Data - } else { - TrafficKind::Control - } - } } pub(crate) fn route_peer_info_instance_id(route_peer_info: &RoutePeerInfo) -> Option { diff --git a/easytier/src/proto/rpc_impl/packet.rs b/easytier/src/proto/rpc_impl/packet.rs index 4f8b62b7..a03085df 100644 --- a/easytier/src/proto/rpc_impl/packet.rs +++ b/easytier/src/proto/rpc_impl/packet.rs @@ -1,4 +1,4 @@ -use prost::Message as _; +use prost::{Message as _, length_delimiter_len}; use crate::{ common::{PeerId, compressor::DefaultCompressor}, @@ -6,12 +6,15 @@ use crate::{ common::{CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket}, rpc_types::error::Error, }, - tunnel::packet_def::{CompressorAlgo, PacketType, ZCPacket}, + tunnel::packet_def::{CompressorAlgo, PacketType, TAIL_RESERVED_SIZE, ZCPacket, ZCPacketType}, }; use super::RpcTransactId; -const RPC_PACKET_CONTENT_MTU: usize = 1300; +// Budget the final UDP payload size on the wire for peer RPC over `udp://`. +// This includes EasyTier's UDP tunnel header, peer header, and reserved tail +// space for encryption/compression metadata, but excludes the outer IP header. +const RPC_PACKET_UDP_PAYLOAD_BUDGET: usize = 1300; pub async fn compress_packet( accepted_compression_algo: CompressionAlgoPb, @@ -150,44 +153,166 @@ pub struct BuildRpcPacketArgs<'a> { pub compression_info: RpcCompressionInfo, } +// Fixed transport overhead for peer RPC carried by EasyTier's UDP tunnel: +// +// UDP payload budget +// +-------------------------------------------------------------------------+ +// | EasyTier UDP tunnel hdr | PeerManager hdr | RpcPacket bytes | tail room | +// +-------------------------------------------------------------------------+ +// |<------ ZCPacketType::UDP payload_offset ------>|<-- TAIL_RESERVED_SIZE -->| +// +// `udp_rpc_tunnel_overhead()` is everything except `RpcPacket bytes`. +fn udp_rpc_tunnel_overhead() -> usize { + ZCPacketType::UDP.get_packet_offsets().payload_offset + TAIL_RESERVED_SIZE +} + +// Maximum encoded RpcPacket size we can admit before adding it to a UDP tunnel. +// This budget excludes the outer UDP/IP headers because the caller only controls +// the EasyTier payload carried inside the UDP datagram. +fn max_rpc_packet_encoded_len_for_udp() -> usize { + RPC_PACKET_UDP_PAYLOAD_BUDGET.saturating_sub(udp_rpc_tunnel_overhead()) +} + +// Build one logical RpcPacket piece. This is reused both for the actual output +// packets and for sizing templates that estimate worst-case protobuf overhead. +fn build_rpc_piece( + args: &BuildRpcPacketArgs<'_>, + total_pieces: u32, + piece_idx: u32, + body: &[u8], +) -> RpcPacket { + RpcPacket { + from_peer: args.from_peer, + to_peer: args.to_peer, + descriptor: if piece_idx == 0 + || args.compression_info.algo == CompressionAlgoPb::None as i32 + { + // old version must have descriptor on every piece + Some(args.rpc_desc.clone()) + } else { + None + }, + is_request: args.is_req, + total_pieces, + piece_idx, + transaction_id: args.transaction_id, + body: body.to_vec(), + trace_id: args.trace_id, + compression_info: if piece_idx == 0 { + Some(args.compression_info) + } else { + None + }, + } +} + +fn pick_piece_len_for_budget( + base_encoded_len_without_body: usize, + remaining: usize, + max_encoded_len: usize, +) -> usize { + if remaining == 0 { + return 0; + } + + // Minimum non-empty body field encoding cost: + // body tag (1 byte) + body length (1 byte) + body data (1 byte) + if base_encoded_len_without_body + 3 > max_encoded_len { + tracing::warn!( + base_encoded_len_without_body, + max_encoded_len, + "rpc metadata exceeds udp payload budget; falling back to a minimal piece" + ); + return 1; + } + + // `budget` is what remains for the protobuf `body` field after all fixed + // RpcPacket metadata has been accounted for. + let budget = max_encoded_len - base_encoded_len_without_body; + // Reserve the bytes field wrapper conservatively, then use the rest for + // the body itself. + // + // Encoded RpcPacket layout relevant to `body`: + // + // +------------------------------- max_encoded_len -------------------------------+ + // | fixed RpcPacket fields | body tag (1B) | body len varint (worst-case) | body | + // +--------------------------------------------------------------------------- --+ + // ^ ^ + // | `- reserve by using the varint width of `budget` + // `- base_encoded_len_without_body + // + // This is intentionally conservative. A few bytes may be left unused, but + // every piece stays within the UDP payload budget without iterative sizing. + let reserved_for_body_header = 1 + length_delimiter_len(budget); + remaining + .min(budget.saturating_sub(reserved_for_body_header)) + .max(1) +} + +// Pre-split the raw RPC content using conservative worst-case protobuf sizing. +// We compute separate base sizes for the first piece and later pieces because +// only the first piece carries `compression_info`, and old compatibility rules +// may also force `descriptor` to appear on every piece. +// +// Split flow: +// +// raw RPC content +// +--------------------------------------------------------------+ +// | args.content | +// +--------------------------------------------------------------+ +// | first piece uses first_piece_base_len +// | later pieces use other_piece_base_len +// v +// +-----------+-----------+-----------+----- ... +// | offset,len| offset,len| offset,len| +// +-----------+-----------+-----------+----- ... +// +// The result is only a slicing plan. Actual RpcPacket objects are built later +// with the real `total_pieces`. +fn split_rpc_content_for_udp_budget(args: &BuildRpcPacketArgs<'_>) -> Vec<(usize, usize)> { + if args.content.is_empty() { + return vec![(0, 0)]; + } + + let max_encoded_len = max_rpc_packet_encoded_len_for_udp().max(1); + // Use the worst-case varint width for piece counters so the budget remains + // valid without iterating on `total_pieces`/`piece_idx`. + let first_piece_base_len = build_rpc_piece(args, u32::MAX, 0, &[]).encoded_len(); + let other_piece_base_len = build_rpc_piece(args, u32::MAX, u32::MAX, &[]).encoded_len(); + + let mut pieces = Vec::new(); + let mut offset = 0usize; + while offset < args.content.len() { + // First and subsequent pieces have different metadata shapes, so they + // use different fixed-size templates. + let base_len = if pieces.is_empty() { + first_piece_base_len + } else { + other_piece_base_len + }; + let piece_len = + pick_piece_len_for_budget(base_len, args.content.len() - offset, max_encoded_len); + pieces.push((offset, piece_len)); + offset += piece_len; + } + + pieces +} + +// Build the final transport packets after the payload has been split. We do the +// actual `total_pieces` assignment only here so the wire packet stays accurate, +// while the earlier sizing step remains simple and conservatively safe. pub fn build_rpc_packet(args: BuildRpcPacketArgs<'_>) -> Vec { let mut ret = Vec::new(); - let content_mtu = RPC_PACKET_CONTENT_MTU; - let total_pieces = args.content.len().div_ceil(content_mtu); - let mut cur_offset = 0; - while cur_offset < args.content.len() || args.content.is_empty() { - let mut cur_len = content_mtu; - if cur_offset + cur_len > args.content.len() { - cur_len = args.content.len() - cur_offset; - } - - let mut cur_content = Vec::new(); - cur_content.extend_from_slice(&args.content[cur_offset..cur_offset + cur_len]); - - let cur_packet = RpcPacket { - from_peer: args.from_peer, - to_peer: args.to_peer, - descriptor: if cur_offset == 0 - || args.compression_info.algo == CompressionAlgoPb::None as i32 - { - // old version must have descriptor on every piece - Some(args.rpc_desc.clone()) - } else { - None - }, - is_request: args.is_req, - total_pieces: total_pieces as u32, - piece_idx: (cur_offset / RPC_PACKET_CONTENT_MTU) as u32, - transaction_id: args.transaction_id, - body: cur_content, - trace_id: args.trace_id, - compression_info: if cur_offset == 0 { - Some(args.compression_info) - } else { - None - }, - }; - cur_offset += cur_len; + let pieces = split_rpc_content_for_udp_budget(&args); + let total_pieces = pieces.len() as u32; + for (piece_idx, (offset, len)) in pieces.into_iter().enumerate() { + let cur_packet = build_rpc_piece( + &args, + total_pieces, + piece_idx as u32, + &args.content[offset..offset + len], + ); let packet_type = if args.is_req { PacketType::RpcReq @@ -200,11 +325,66 @@ pub fn build_rpc_packet(args: BuildRpcPacketArgs<'_>) -> Vec { let mut zc_packet = ZCPacket::new_with_payload(&buf); zc_packet.fill_peer_manager_hdr(args.from_peer, args.to_peer, packet_type as u8); ret.push(zc_packet); - - if args.content.is_empty() { - break; - } } ret } + +#[cfg(test)] +mod tests { + use super::*; + + fn build_test_args<'a>( + content: &'a [u8], + compression_algo: CompressionAlgoPb, + ) -> BuildRpcPacketArgs<'a> { + BuildRpcPacketArgs { + from_peer: 11, + to_peer: 22, + rpc_desc: RpcDescriptor { + domain_name: "very-long-domain-name-for-rpc-packet-budget-check".repeat(2), + proto_name: "extremely.verbose.proto.name.for.rpc.packet.tests".repeat(2), + service_name: "LargeMetadataServiceForRpcPacketBudget".repeat(2), + method_index: 7, + }, + transaction_id: 33, + is_req: true, + content, + trace_id: 44, + compression_info: RpcCompressionInfo { + algo: compression_algo.into(), + accepted_algo: CompressionAlgoPb::Zstd.into(), + }, + } + } + + fn udp_packet_size_after_tail(packet: &ZCPacket) -> usize { + ZCPacketType::UDP.get_packet_offsets().payload_offset + + packet.payload_len() + + TAIL_RESERVED_SIZE + } + + #[test] + fn build_rpc_packet_respects_udp_budget_with_large_metadata() { + let content = vec![0x5a; 4096]; + let packets = build_rpc_packet(build_test_args(&content, CompressionAlgoPb::None)); + + assert!(packets.len() > 1); + for packet in packets { + assert!( + udp_packet_size_after_tail(&packet) <= RPC_PACKET_UDP_PAYLOAD_BUDGET, + "packet size {} exceeded budget {}", + udp_packet_size_after_tail(&packet), + RPC_PACKET_UDP_PAYLOAD_BUDGET + ); + } + } + + #[test] + fn build_rpc_packet_respects_udp_budget_for_empty_payload() { + let packets = build_rpc_packet(build_test_args(&[], CompressionAlgoPb::Zstd)); + + assert_eq!(1, packets.len()); + assert!(udp_packet_size_after_tail(&packets[0]) <= RPC_PACKET_UDP_PAYLOAD_BUDGET); + } +} diff --git a/easytier/src/rpc_service/api.rs b/easytier/src/rpc_service/api.rs index 77aafd0d..4aedca31 100644 --- a/easytier/src/rpc_service/api.rs +++ b/easytier/src/rpc_service/api.rs @@ -27,8 +27,8 @@ use crate::{ instance_manage::InstanceManageRpcService, logger::LoggerRpcService, mapped_listener_manage::MappedListenerManageRpcService, peer_center::PeerCenterManageRpcService, peer_manage::PeerManageRpcService, - port_forward_manage::PortForwardManageRpcService, proxy::TcpProxyRpcService, - stats::StatsRpcService, vpn_portal::VpnPortalRpcService, + port_forward_manage::PortForwardManageRpcService, protected_port, + proxy::TcpProxyRpcService, stats::StatsRpcService, vpn_portal::VpnPortalRpcService, }, tunnel::{TunnelListener, tcp::TcpTunnelListener}, web_client::{DefaultHooks, WebClientHooks}, @@ -36,6 +36,7 @@ use crate::{ pub struct ApiRpcServer { rpc_server: StandAloneServer, + protected_tcp_port: Option, } impl ApiRpcServer { @@ -44,14 +45,17 @@ impl ApiRpcServer { rpc_portal_whitelist: Option>, instance_manager: Arc, ) -> anyhow::Result { + let rpc_addr = parse_rpc_portal(rpc_portal)?; let mut server = Self::from_tunnel( TcpTunnelListener::new( - format!("tcp://{}", parse_rpc_portal(rpc_portal)?) + format!("tcp://{}", rpc_addr) .parse() .context("failed to parse rpc portal address")?, ), instance_manager, ); + protected_port::register_protected_tcp_port(rpc_addr.port()); + server.protected_tcp_port = Some(rpc_addr.port()); server .rpc_server @@ -65,7 +69,10 @@ impl ApiRpcServer { pub fn from_tunnel(tunnel: T, instance_manager: Arc) -> Self { let rpc_server = StandAloneServer::new(tunnel); register_api_rpc_service(&instance_manager, rpc_server.registry(), None); - Self { rpc_server } + Self { + rpc_server, + protected_tcp_port: None, + } } } @@ -83,6 +90,9 @@ impl ApiRpcServer { impl Drop for ApiRpcServer { fn drop(&mut self) { + if let Some(port) = self.protected_tcp_port.take() { + protected_port::unregister_protected_tcp_port(port); + } self.rpc_server.registry().unregister_all(); } } diff --git a/easytier/src/rpc_service/mod.rs b/easytier/src/rpc_service/mod.rs index a51cd987..b9af8285 100644 --- a/easytier/src/rpc_service/mod.rs +++ b/easytier/src/rpc_service/mod.rs @@ -6,6 +6,7 @@ mod mapped_listener_manage; mod peer_center; mod peer_manage; mod port_forward_manage; +pub(crate) mod protected_port; mod proxy; mod stats; mod vpn_portal; diff --git a/easytier/src/rpc_service/protected_port.rs b/easytier/src/rpc_service/protected_port.rs new file mode 100644 index 00000000..b44216b5 --- /dev/null +++ b/easytier/src/rpc_service/protected_port.rs @@ -0,0 +1,61 @@ +use std::collections::HashMap; +use std::sync::Mutex; + +use once_cell::sync::Lazy; + +static PROTECTED_TCP_PORTS: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + +pub fn register_protected_tcp_port(port: u16) { + let mut ports = PROTECTED_TCP_PORTS.lock().unwrap(); + *ports.entry(port).or_default() += 1; +} + +pub fn unregister_protected_tcp_port(port: u16) { + let mut ports = PROTECTED_TCP_PORTS.lock().unwrap(); + if let Some(ref_count) = ports.get_mut(&port) { + *ref_count -= 1; + if *ref_count == 0 { + ports.remove(&port); + } + } +} + +pub fn is_protected_tcp_port(port: u16) -> bool { + PROTECTED_TCP_PORTS.lock().unwrap().contains_key(&port) +} + +#[cfg(test)] +pub fn clear_protected_tcp_ports_for_test() { + PROTECTED_TCP_PORTS.lock().unwrap().clear(); +} + +#[cfg(test)] +mod tests { + use super::{ + clear_protected_tcp_ports_for_test, is_protected_tcp_port, register_protected_tcp_port, + unregister_protected_tcp_port, + }; + + #[test] + fn protected_tcp_port_registry_is_ref_counted() { + clear_protected_tcp_ports_for_test(); + + register_protected_tcp_port(15888); + register_protected_tcp_port(15888); + assert!(is_protected_tcp_port(15888)); + + unregister_protected_tcp_port(15888); + assert!(is_protected_tcp_port(15888)); + + unregister_protected_tcp_port(15888); + assert!(!is_protected_tcp_port(15888)); + } + + #[test] + fn unregistering_unknown_port_is_a_noop() { + clear_protected_tcp_ports_for_test(); + unregister_protected_tcp_port(15888); + assert!(!is_protected_tcp_port(15888)); + } +}