fix packet split on udp tunnel and prevent tcp proxy from accessing the rpc portal (#2107)

* distinct control / data when forward packets
* fix rpc split for udp tunnel
* feat(easytier-web): pass public ip in validate token webhook
* protect rpc port from subnet proxy
This commit is contained in:
KKRainbow
2026-04-13 11:03:09 +08:00
committed by GitHub
parent ec7ddd3bad
commit 5b35c51da9
11 changed files with 602 additions and 79 deletions
+19
View File
@@ -28,6 +28,7 @@ use crate::{
common::{PeerFeatureFlag, PortForwardConfigPb},
peer_rpc::PeerGroupInfo,
},
rpc_service::protected_port,
tunnel::matches_protocol,
};
use crossbeam::atomic::AtomicCell;
@@ -658,6 +659,7 @@ impl GlobalCtx {
if dst_is_local_virtual_ip || dst_is_local_phy_ip {
// if is local ip, make sure the port is not one of the listening ports
self.is_port_in_running_listeners(dst_addr.port(), is_udp)
|| (!is_udp && protected_port::is_protected_tcp_port(dst_addr.port()))
} else {
false
}
@@ -765,6 +767,23 @@ pub mod tests {
assert!(feature_flags.is_public_server);
}
/// The subnet proxy must refuse TCP connections to the process-wide
/// protected RPC port, while UDP to the same port and TCP to any other
/// port remain allowed.
#[tokio::test]
async fn should_deny_proxy_for_process_wide_rpc_port() {
    const RPC_PORT: u16 = 15888;

    // Start from a clean registry so earlier tests cannot leak state in.
    protected_port::clear_protected_tcp_ports_for_test();
    protected_port::register_protected_tcp_port(RPC_PORT);

    let global_ctx = GlobalCtx::new(TomlConfigLoader::default());
    let protected = SocketAddr::from(([127, 0, 0, 1], RPC_PORT));
    let unprotected = SocketAddr::from(([127, 0, 0, 1], RPC_PORT + 1));

    // TCP (is_udp == false) to the protected port is denied...
    assert!(global_ctx.should_deny_proxy(&protected, false));
    // ...but UDP to the same port and TCP to a different port are not.
    assert!(!global_ctx.should_deny_proxy(&protected, true));
    assert!(!global_ctx.should_deny_proxy(&unprotected, false));

    protected_port::clear_protected_tcp_ports_for_test();
}
pub fn get_mock_global_ctx_with_network(
network_identy: Option<NetworkIdentity>,
) -> ArcGlobalCtx {
+10
View File
@@ -42,6 +42,8 @@ pub enum MetricName {
TrafficControlBytesRxByInstance,
/// Traffic bytes forwarded
TrafficBytesForwarded,
/// Control-plane traffic bytes forwarded
TrafficControlBytesForwarded,
/// Traffic bytes sent to self
TrafficBytesSelfTx,
/// Traffic bytes received from self
@@ -71,6 +73,8 @@ pub enum MetricName {
TrafficControlPacketsRxByInstance,
/// Traffic packets forwarded
TrafficPacketsForwarded,
/// Control-plane traffic packets forwarded
TrafficControlPacketsForwarded,
/// Traffic packets sent to self
TrafficPacketsSelfTx,
/// Traffic packets received from self
@@ -117,6 +121,9 @@ impl fmt::Display for MetricName {
write!(f, "traffic_control_bytes_rx_by_instance")
}
MetricName::TrafficBytesForwarded => write!(f, "traffic_bytes_forwarded"),
MetricName::TrafficControlBytesForwarded => {
write!(f, "traffic_control_bytes_forwarded")
}
MetricName::TrafficBytesSelfTx => write!(f, "traffic_bytes_self_tx"),
MetricName::TrafficBytesSelfRx => write!(f, "traffic_bytes_self_rx"),
MetricName::TrafficBytesForeignForwardRx => {
@@ -146,6 +153,9 @@ impl fmt::Display for MetricName {
write!(f, "traffic_control_packets_rx_by_instance")
}
MetricName::TrafficPacketsForwarded => write!(f, "traffic_packets_forwarded"),
MetricName::TrafficControlPacketsForwarded => {
write!(f, "traffic_control_packets_forwarded")
}
MetricName::TrafficPacketsSelfTx => write!(f, "traffic_packets_self_tx"),
MetricName::TrafficPacketsSelfRx => write!(f, "traffic_packets_self_rx"),
MetricName::TrafficPacketsForeignForwardRx => {
+97 -7
View File
@@ -55,8 +55,8 @@ use super::{
relay_peer_map::RelayPeerMap,
route_trait::NextHopPolicy,
traffic_metrics::{
InstanceLabelKind, LogicalTrafficMetrics, TrafficMetricRecorder,
route_peer_info_instance_id,
InstanceLabelKind, LogicalTrafficMetrics, TrafficKind, TrafficMetricRecorder,
route_peer_info_instance_id, traffic_kind,
},
};
@@ -419,12 +419,19 @@ impl ForeignNetworkEntry {
let label_set =
LabelSet::new().with_label_type(LabelType::NetworkName(network_name.clone()));
let forward_bytes = self
let forward_data_bytes = self
.stats_mgr
.get_counter(MetricName::TrafficBytesForwarded, label_set.clone());
let forward_packets = self
let forward_data_packets = self
.stats_mgr
.get_counter(MetricName::TrafficPacketsForwarded, label_set.clone());
let forward_control_bytes = self
.stats_mgr
.get_counter(MetricName::TrafficControlBytesForwarded, label_set.clone());
let forward_control_packets = self.stats_mgr.get_counter(
MetricName::TrafficControlPacketsForwarded,
label_set.clone(),
);
let rx_bytes = self
.stats_mgr
.get_counter(MetricName::TrafficBytesSelfRx, label_set.clone());
@@ -502,8 +509,16 @@ impl ForeignNetworkEntry {
}
}
forward_bytes.add(buf_len as u64);
forward_packets.inc();
match traffic_kind(packet_type) {
TrafficKind::Data => {
forward_data_bytes.add(buf_len as u64);
forward_data_packets.inc();
}
TrafficKind::Control => {
forward_control_bytes.add(buf_len as u64);
forward_control_packets.inc();
}
}
let gateway_peer_id = peer_map
.get_gateway_peer_id(to_peer_id, NextHopPolicy::LeastHop)
@@ -1293,6 +1308,11 @@ pub mod tests {
MetricName::TrafficBytesForwarded,
network_labels.clone(),
);
let forwarded_packets_before = metric_value(
&pm_center,
MetricName::TrafficPacketsForwarded,
network_labels.clone(),
);
let rx_bytes_before = metric_value(
&pm_center,
MetricName::TrafficBytesRx,
@@ -1320,6 +1340,7 @@ pub mod tests {
pmb_net1.my_peer_id(),
PacketType::Data as u8,
);
let transit_pkt_len = transit_pkt.buf_len() as u64;
pma_net1
.get_foreign_network_client()
.send_msg(transit_pkt, center_peer_id)
@@ -1334,7 +1355,12 @@ pub mod tests {
&pm_center,
MetricName::TrafficBytesForwarded,
network_labels.clone(),
) > forwarded_bytes_before
) >= forwarded_bytes_before + transit_pkt_len
&& metric_value(
&pm_center,
MetricName::TrafficPacketsForwarded,
network_labels.clone(),
) > forwarded_packets_before
}
},
Duration::from_secs(5),
@@ -1371,6 +1397,70 @@ pub mod tests {
);
}
/// A control-plane packet (PacketType::RpcReq) relayed between two peers of
/// foreign network "net1" through the public "center" node must be counted
/// in the TrafficControl*Forwarded metrics (not the data-plane counters).
#[tokio::test]
async fn foreign_network_transit_control_forwarding_records_control_forwarded_metrics() {
    // pm_center is the public relay; pma/pmb are two peers of network "net1".
    let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
    let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
    let pmb_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
    // Both foreign peers connect only to the center, so traffic between
    // them has to transit through it.
    connect_peer_manager(pma_net1.clone(), pm_center.clone()).await;
    connect_peer_manager(pmb_net1.clone(), pm_center.clone()).await;
    wait_route_appear(pma_net1.clone(), pmb_net1.clone())
        .await
        .unwrap();
    // Peer id the center uses for the "net1" foreign network entry.
    let center_peer_id = pm_center
        .get_foreign_network_manager()
        .get_network_peer_id("net1")
        .unwrap();
    let network_labels =
        LabelSet::new().with_label_type(LabelType::NetworkName("net1".to_string()));
    // Snapshot the control-forwarded counters before sending so the
    // assertions below measure only this test's packet.
    let forwarded_bytes_before = metric_value(
        &pm_center,
        MetricName::TrafficControlBytesForwarded,
        network_labels.clone(),
    );
    let forwarded_packets_before = metric_value(
        &pm_center,
        MetricName::TrafficControlPacketsForwarded,
        network_labels.clone(),
    );
    // Build an RpcReq packet from pma to pmb; RpcReq is classified as
    // control-plane traffic by traffic_kind().
    let mut transit_pkt = ZCPacket::new_with_payload(b"foreign-control-transit");
    transit_pkt.fill_peer_manager_hdr(
        pma_net1.my_peer_id(),
        pmb_net1.my_peer_id(),
        PacketType::RpcReq as u8,
    );
    let transit_pkt_len = transit_pkt.buf_len() as u64;
    pma_net1
        .get_foreign_network_client()
        .send_msg(transit_pkt, center_peer_id)
        .await
        .unwrap();
    // Metrics are updated asynchronously on the relay; poll until both the
    // byte and packet control-forwarded counters reflect the transit packet.
    wait_for_condition(
        || {
            let pm_center = pm_center.clone();
            let network_labels = network_labels.clone();
            async move {
                metric_value(
                    &pm_center,
                    MetricName::TrafficControlBytesForwarded,
                    network_labels.clone(),
                ) >= forwarded_bytes_before + transit_pkt_len
                    && metric_value(
                        &pm_center,
                        MetricName::TrafficControlPacketsForwarded,
                        network_labels.clone(),
                    ) > forwarded_packets_before
            }
        },
        Duration::from_secs(5),
    )
    .await;
}
#[tokio::test]
async fn foreign_network_encapsulated_forwarding_records_tx_metrics() {
set_global_var!(OSPF_UPDATE_MY_GLOBAL_FOREIGN_NETWORK_INTERVAL_SEC, 1);
+156 -6
View File
@@ -37,8 +37,8 @@ use crate::{
recv_packet_from_chan,
route_trait::{ForeignNetworkRouteInfoMap, MockRoute, NextHopPolicy, RouteInterface},
traffic_metrics::{
InstanceLabelKind, LogicalTrafficMetrics, TrafficMetricRecorder,
route_peer_info_instance_id,
InstanceLabelKind, LogicalTrafficMetrics, TrafficKind, TrafficMetricRecorder,
route_peer_info_instance_id, traffic_kind,
},
},
proto::{
@@ -888,10 +888,16 @@ impl PeerManager {
stats_mgr.get_counter(MetricName::TrafficBytesSelfRx, label_set.clone());
let self_rx_packets =
stats_mgr.get_counter(MetricName::TrafficPacketsSelfRx, label_set.clone());
let forward_tx_bytes =
let forward_data_tx_bytes =
stats_mgr.get_counter(MetricName::TrafficBytesForwarded, label_set.clone());
let forward_tx_packets =
let forward_data_tx_packets =
stats_mgr.get_counter(MetricName::TrafficPacketsForwarded, label_set.clone());
let forward_control_tx_bytes =
stats_mgr.get_counter(MetricName::TrafficControlBytesForwarded, label_set.clone());
let forward_control_tx_packets = stats_mgr.get_counter(
MetricName::TrafficControlPacketsForwarded,
label_set.clone(),
);
let compress_tx_bytes_before = self.self_tx_counters.compress_tx_bytes_before.clone();
let compress_tx_bytes_after = self.self_tx_counters.compress_tx_bytes_after.clone();
@@ -966,8 +972,16 @@ impl PeerManager {
self_tx_bytes.add(ret.buf_len() as u64);
self_tx_packets.inc();
} else {
forward_tx_bytes.add(buf_len as u64);
forward_tx_packets.inc();
match traffic_kind(packet_type) {
TrafficKind::Data => {
forward_data_tx_bytes.add(buf_len as u64);
forward_data_tx_packets.inc();
}
TrafficKind::Control => {
forward_control_tx_bytes.add(buf_len as u64);
forward_control_tx_packets.inc();
}
}
}
tracing::trace!(?to_peer_id, ?my_peer_id, "need forward");
@@ -2053,6 +2067,12 @@ mod tests {
.unwrap_or(0)
}
/// Build the metric label set that identifies `peer_mgr`'s network by name.
fn network_labels(peer_mgr: &PeerManager) -> LabelSet {
    let network_name = peer_mgr.get_global_ctx().get_network_name();
    let label = LabelType::NetworkName(network_name);
    LabelSet::new().with_label_type(label)
}
#[test]
fn recent_traffic_fanout_policy_only_marks_single_peer() {
assert!(PeerManager::should_mark_recent_traffic_for_fanout(0));
@@ -2439,6 +2459,136 @@ mod tests {
);
}
/// A data packet (PacketType::Data) sent A -> C through transit node B must
/// be counted on B in the data-plane forwarded metrics
/// (TrafficBytesForwarded / TrafficPacketsForwarded).
#[tokio::test]
async fn send_msg_internal_records_data_forwarded_metrics_for_transit_peer() {
    let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
    let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
    let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
    // Chain topology A - B - C: A can only reach C via B, forcing B to
    // forward the packet.
    connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
    connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
    wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
        .await
        .unwrap();
    let b_network_labels = network_labels(&peer_mgr_b);
    // Snapshot B's forwarded counters so the assertions measure only this
    // test's packet.
    let forwarded_bytes_before = metric_value(
        &peer_mgr_b,
        MetricName::TrafficBytesForwarded,
        &b_network_labels,
    );
    let forwarded_packets_before = metric_value(
        &peer_mgr_b,
        MetricName::TrafficPacketsForwarded,
        &b_network_labels,
    );
    // PacketType::Data is classified as data-plane traffic by traffic_kind().
    let mut pkt = ZCPacket::new_with_payload(b"forward-data");
    pkt.fill_peer_manager_hdr(
        peer_mgr_a.my_peer_id(),
        peer_mgr_c.my_peer_id(),
        PacketType::Data as u8,
    );
    let pkt_len = pkt.buf_len() as u64;
    // Call the internal send path directly so the packet takes the normal
    // peer-to-peer route A -> B -> C.
    PeerManager::send_msg_internal(
        &peer_mgr_a.peers,
        &peer_mgr_a.foreign_network_client,
        &peer_mgr_a.relay_peer_map,
        Some(&peer_mgr_a.traffic_metrics),
        pkt,
        peer_mgr_c.my_peer_id(),
    )
    .await
    .unwrap();
    // Counters update asynchronously on B; poll until both byte and packet
    // data-forwarded metrics reflect the transit packet.
    wait_for_condition(
        || {
            let peer_mgr_b = peer_mgr_b.clone();
            let b_network_labels = b_network_labels.clone();
            async move {
                metric_value(
                    &peer_mgr_b,
                    MetricName::TrafficBytesForwarded,
                    &b_network_labels,
                ) >= forwarded_bytes_before + pkt_len
                    && metric_value(
                        &peer_mgr_b,
                        MetricName::TrafficPacketsForwarded,
                        &b_network_labels,
                    ) > forwarded_packets_before
            }
        },
        Duration::from_secs(5),
    )
    .await;
}
/// A control packet (PacketType::RpcReq) sent A -> C through transit node B
/// must be counted on B in the control-plane forwarded metrics
/// (TrafficControlBytesForwarded / TrafficControlPacketsForwarded).
#[tokio::test]
async fn send_msg_internal_records_control_forwarded_metrics_for_transit_peer() {
    let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
    let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
    let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
    // Chain topology A - B - C: A can only reach C via B, forcing B to
    // forward the packet.
    connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
    connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
    wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
        .await
        .unwrap();
    let b_network_labels = network_labels(&peer_mgr_b);
    // Snapshot B's control-forwarded counters before sending.
    let forwarded_bytes_before = metric_value(
        &peer_mgr_b,
        MetricName::TrafficControlBytesForwarded,
        &b_network_labels,
    );
    let forwarded_packets_before = metric_value(
        &peer_mgr_b,
        MetricName::TrafficControlPacketsForwarded,
        &b_network_labels,
    );
    // RpcReq is classified as control-plane traffic by traffic_kind().
    let mut pkt = ZCPacket::new_with_payload(b"forward-control");
    pkt.fill_peer_manager_hdr(
        peer_mgr_a.my_peer_id(),
        peer_mgr_c.my_peer_id(),
        PacketType::RpcReq as u8,
    );
    let pkt_len = pkt.buf_len() as u64;
    // Call the internal send path directly so the packet takes the normal
    // peer-to-peer route A -> B -> C.
    PeerManager::send_msg_internal(
        &peer_mgr_a.peers,
        &peer_mgr_a.foreign_network_client,
        &peer_mgr_a.relay_peer_map,
        Some(&peer_mgr_a.traffic_metrics),
        pkt,
        peer_mgr_c.my_peer_id(),
    )
    .await
    .unwrap();
    // Counters update asynchronously on B; poll until both byte and packet
    // control-forwarded metrics reflect the transit packet.
    wait_for_condition(
        || {
            let peer_mgr_b = peer_mgr_b.clone();
            let b_network_labels = b_network_labels.clone();
            async move {
                metric_value(
                    &peer_mgr_b,
                    MetricName::TrafficControlBytesForwarded,
                    &b_network_labels,
                ) >= forwarded_bytes_before + pkt_len
                    && metric_value(
                        &peer_mgr_b,
                        MetricName::TrafficControlPacketsForwarded,
                        &b_network_labels,
                    ) > forwarded_packets_before
            }
        },
        Duration::from_secs(5),
    )
    .await;
}
#[tokio::test]
async fn recent_traffic_tolerates_future_timestamps() {
let peer_mgr_a = create_lazy_peer_manager().await;
+19 -19
View File
@@ -220,12 +220,27 @@ impl LogicalTrafficMetrics {
}
}
#[derive(Clone, Copy)]
enum TrafficKind {
/// Coarse classification of peer-manager packets for traffic metrics:
/// payload-carrying packets vs. everything else (see `traffic_kind`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum TrafficKind {
    /// User payload traffic (plain data and the KCP/QUIC data variants).
    Data,
    /// Control-plane traffic (any packet type not listed as data).
    Control,
}
/// Classify a peer-manager `packet_type` as data-plane or control-plane.
///
/// Data-plane covers plain data plus every KCP/QUIC data variant; every
/// other packet type counts as control-plane.
pub(crate) fn traffic_kind(packet_type: u8) -> TrafficKind {
    let data_packet_types = [
        PacketType::Data as u8,
        PacketType::KcpSrc as u8,
        PacketType::KcpDst as u8,
        PacketType::QuicSrc as u8,
        PacketType::QuicDst as u8,
        PacketType::DataWithKcpSrcModified as u8,
        PacketType::DataWithQuicSrcModified as u8,
    ];
    if data_packet_types.contains(&packet_type) {
        TrafficKind::Data
    } else {
        TrafficKind::Control
    }
}
#[derive(Clone)]
struct TrafficMetricGroup {
data: Arc<LogicalTrafficMetrics>,
@@ -282,7 +297,7 @@ impl TrafficMetricRecorder {
return;
}
self.tx_metrics
.select(Self::traffic_kind(packet_type))
.select(traffic_kind(packet_type))
.record_with_resolver(peer_id, bytes, || self.resolve_instance_id(peer_id))
.await;
}
@@ -292,7 +307,7 @@ impl TrafficMetricRecorder {
return;
}
self.rx_metrics
.select(Self::traffic_kind(packet_type))
.select(traffic_kind(packet_type))
.record_with_resolver(peer_id, bytes, || self.resolve_instance_id(peer_id))
.await;
}
@@ -314,21 +329,6 @@ impl TrafficMetricRecorder {
fn resolve_instance_id(&self, peer_id: PeerId) -> BoxFuture<'static, Option<String>> {
(self.resolve_instance_id)(peer_id)
}
fn traffic_kind(packet_type: u8) -> TrafficKind {
if packet_type == PacketType::Data as u8
|| packet_type == PacketType::KcpSrc as u8
|| packet_type == PacketType::KcpDst as u8
|| packet_type == PacketType::QuicSrc as u8
|| packet_type == PacketType::QuicDst as u8
|| packet_type == PacketType::DataWithKcpSrcModified as u8
|| packet_type == PacketType::DataWithQuicSrcModified as u8
{
TrafficKind::Data
} else {
TrafficKind::Control
}
}
}
pub(crate) fn route_peer_info_instance_id(route_peer_info: &RoutePeerInfo) -> Option<String> {
+223 -43
View File
@@ -1,4 +1,4 @@
use prost::Message as _;
use prost::{Message as _, length_delimiter_len};
use crate::{
common::{PeerId, compressor::DefaultCompressor},
@@ -6,12 +6,15 @@ use crate::{
common::{CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket},
rpc_types::error::Error,
},
tunnel::packet_def::{CompressorAlgo, PacketType, ZCPacket},
tunnel::packet_def::{CompressorAlgo, PacketType, TAIL_RESERVED_SIZE, ZCPacket, ZCPacketType},
};
use super::RpcTransactId;
const RPC_PACKET_CONTENT_MTU: usize = 1300;
// Budget the final UDP payload size on the wire for peer RPC over `udp://`.
// This includes EasyTier's UDP tunnel header, peer header, and reserved tail
// space for encryption/compression metadata, but excludes the outer IP header.
const RPC_PACKET_UDP_PAYLOAD_BUDGET: usize = 1300;
pub async fn compress_packet(
accepted_compression_algo: CompressionAlgoPb,
@@ -150,44 +153,166 @@ pub struct BuildRpcPacketArgs<'a> {
pub compression_info: RpcCompressionInfo,
}
// Fixed transport overhead for peer RPC carried by EasyTier's UDP tunnel:
// everything inside the UDP payload except the encoded `RpcPacket` bytes.
//
// UDP payload budget
// +-------------------------------------------------------------------------+
// | EasyTier UDP tunnel hdr | PeerManager hdr | RpcPacket bytes | tail room |
// +-------------------------------------------------------------------------+
// |<------ ZCPacketType::UDP payload_offset ------>|<- TAIL_RESERVED_SIZE ->|
fn udp_rpc_tunnel_overhead() -> usize {
    let offsets = ZCPacketType::UDP.get_packet_offsets();
    offsets.payload_offset + TAIL_RESERVED_SIZE
}
// Largest encoded RpcPacket we admit into a single UDP tunnel packet.
// The outer UDP/IP headers are deliberately excluded: the caller only
// controls the EasyTier payload carried inside the datagram.
fn max_rpc_packet_encoded_len_for_udp() -> usize {
    let overhead = udp_rpc_tunnel_overhead();
    RPC_PACKET_UDP_PAYLOAD_BUDGET.saturating_sub(overhead)
}
// Assemble one logical RpcPacket piece. Shared by the real output packets
// and by the zero-body sizing templates that bound protobuf overhead.
fn build_rpc_piece(
    args: &BuildRpcPacketArgs<'_>,
    total_pieces: u32,
    piece_idx: u32,
    body: &[u8],
) -> RpcPacket {
    let is_first_piece = piece_idx == 0;
    // Old peers (compression algo None) expect the descriptor on every
    // piece; otherwise only the first piece carries it.
    let needs_descriptor =
        is_first_piece || args.compression_info.algo == CompressionAlgoPb::None as i32;
    RpcPacket {
        from_peer: args.from_peer,
        to_peer: args.to_peer,
        descriptor: needs_descriptor.then(|| args.rpc_desc.clone()),
        is_request: args.is_req,
        total_pieces,
        piece_idx,
        transaction_id: args.transaction_id,
        body: body.to_vec(),
        trace_id: args.trace_id,
        // Compression info travels only on the first piece.
        compression_info: is_first_piece.then(|| args.compression_info),
    }
}
// Decide how many raw content bytes the next piece may carry so the fully
// encoded RpcPacket stays within `max_encoded_len`.
//
// `base_encoded_len_without_body` is the encoded size of this piece with an
// empty body; `remaining` is how much raw content is still unsent.
fn pick_piece_len_for_budget(
    base_encoded_len_without_body: usize,
    remaining: usize,
    max_encoded_len: usize,
) -> usize {
    if remaining == 0 {
        return 0;
    }
    // A non-empty protobuf `body` field costs at least 3 extra bytes:
    // tag (1) + length varint (1) + one byte of data.
    if base_encoded_len_without_body + 3 > max_encoded_len {
        tracing::warn!(
            base_encoded_len_without_body,
            max_encoded_len,
            "rpc metadata exceeds udp payload budget; falling back to a minimal piece"
        );
        return 1;
    }
    // Bytes left for the `body` field once the fixed RpcPacket metadata is
    // paid for.
    let budget = max_encoded_len - base_encoded_len_without_body;
    // Conservatively reserve the body field's tag plus a worst-case length
    // varint (the varint width of `budget` itself). A few bytes may go
    // unused, but no piece can overshoot the UDP payload budget and no
    // iterative re-sizing is needed.
    let reserved_for_body_header = 1 + length_delimiter_len(budget);
    let max_body_len = budget.saturating_sub(reserved_for_body_header);
    remaining.min(max_body_len).max(1)
}
// Plan how to slice the raw RPC content so every resulting piece fits the
// UDP payload budget, using conservative worst-case protobuf sizing.
//
// First and later pieces get separate base sizes: only the first piece
// carries `compression_info`, and old-compatibility rules may additionally
// pin `descriptor` to every piece. Piece counters are sized at their
// worst-case varint width (u32::MAX) so the plan stays valid whatever
// `total_pieces` ends up being.
//
// Returns only a slicing plan of `(offset, len)` pairs; the real RpcPacket
// objects are built later with the final `total_pieces`.
fn split_rpc_content_for_udp_budget(args: &BuildRpcPacketArgs<'_>) -> Vec<(usize, usize)> {
    if args.content.is_empty() {
        // Even an empty payload produces one (empty) piece on the wire.
        return vec![(0, 0)];
    }

    let max_encoded_len = max_rpc_packet_encoded_len_for_udp().max(1);
    let first_piece_base_len = build_rpc_piece(args, u32::MAX, 0, &[]).encoded_len();
    let other_piece_base_len = build_rpc_piece(args, u32::MAX, u32::MAX, &[]).encoded_len();

    let total_len = args.content.len();
    let mut plan = Vec::new();
    let mut offset = 0usize;
    while offset < total_len {
        // The first piece carries extra metadata, so it uses its own
        // fixed-size template.
        let base_len = if plan.is_empty() {
            first_piece_base_len
        } else {
            other_piece_base_len
        };
        let piece_len = pick_piece_len_for_budget(base_len, total_len - offset, max_encoded_len);
        plan.push((offset, piece_len));
        offset += piece_len;
    }
    plan
}
// Build the final transport packets after the payload has been split. We do the
// actual `total_pieces` assignment only here so the wire packet stays accurate,
// while the earlier sizing step remains simple and conservatively safe.
pub fn build_rpc_packet(args: BuildRpcPacketArgs<'_>) -> Vec<ZCPacket> {
let mut ret = Vec::new();
let content_mtu = RPC_PACKET_CONTENT_MTU;
let total_pieces = args.content.len().div_ceil(content_mtu);
let mut cur_offset = 0;
while cur_offset < args.content.len() || args.content.is_empty() {
let mut cur_len = content_mtu;
if cur_offset + cur_len > args.content.len() {
cur_len = args.content.len() - cur_offset;
}
let mut cur_content = Vec::new();
cur_content.extend_from_slice(&args.content[cur_offset..cur_offset + cur_len]);
let cur_packet = RpcPacket {
from_peer: args.from_peer,
to_peer: args.to_peer,
descriptor: if cur_offset == 0
|| args.compression_info.algo == CompressionAlgoPb::None as i32
{
// old version must have descriptor on every piece
Some(args.rpc_desc.clone())
} else {
None
},
is_request: args.is_req,
total_pieces: total_pieces as u32,
piece_idx: (cur_offset / RPC_PACKET_CONTENT_MTU) as u32,
transaction_id: args.transaction_id,
body: cur_content,
trace_id: args.trace_id,
compression_info: if cur_offset == 0 {
Some(args.compression_info)
} else {
None
},
};
cur_offset += cur_len;
let pieces = split_rpc_content_for_udp_budget(&args);
let total_pieces = pieces.len() as u32;
for (piece_idx, (offset, len)) in pieces.into_iter().enumerate() {
let cur_packet = build_rpc_piece(
&args,
total_pieces,
piece_idx as u32,
&args.content[offset..offset + len],
);
let packet_type = if args.is_req {
PacketType::RpcReq
@@ -200,11 +325,66 @@ pub fn build_rpc_packet(args: BuildRpcPacketArgs<'_>) -> Vec<ZCPacket> {
let mut zc_packet = ZCPacket::new_with_payload(&buf);
zc_packet.fill_peer_manager_hdr(args.from_peer, args.to_peer, packet_type as u8);
ret.push(zc_packet);
if args.content.is_empty() {
break;
}
}
ret
}
#[cfg(test)]
mod tests {
use super::*;
/// Build RPC packet arguments with deliberately oversized descriptor
/// metadata so the UDP budget logic is exercised under worst-case overhead.
fn build_test_args<'a>(
    content: &'a [u8],
    compression_algo: CompressionAlgoPb,
) -> BuildRpcPacketArgs<'a> {
    let rpc_desc = RpcDescriptor {
        domain_name: "very-long-domain-name-for-rpc-packet-budget-check".repeat(2),
        proto_name: "extremely.verbose.proto.name.for.rpc.packet.tests".repeat(2),
        service_name: "LargeMetadataServiceForRpcPacketBudget".repeat(2),
        method_index: 7,
    };
    let compression_info = RpcCompressionInfo {
        algo: compression_algo.into(),
        accepted_algo: CompressionAlgoPb::Zstd.into(),
    };
    BuildRpcPacketArgs {
        from_peer: 11,
        to_peer: 22,
        rpc_desc,
        transaction_id: 33,
        is_req: true,
        content,
        trace_id: 44,
        compression_info,
    }
}
/// Size of the packet as carried in a UDP tunnel payload: tunnel/peer
/// headers plus the RPC payload plus the reserved tail room.
fn udp_packet_size_after_tail(packet: &ZCPacket) -> usize {
    let header_len = ZCPacketType::UDP.get_packet_offsets().payload_offset;
    header_len + packet.payload_len() + TAIL_RESERVED_SIZE
}
/// A 4 KiB payload with oversized metadata must be split into multiple
/// packets, each staying within the UDP payload budget.
#[test]
fn build_rpc_packet_respects_udp_budget_with_large_metadata() {
    let content = vec![0x5a; 4096];
    let packets = build_rpc_packet(build_test_args(&content, CompressionAlgoPb::None));
    // 4 KiB cannot fit in one budgeted packet, so a split must happen.
    assert!(packets.len() > 1);
    for packet in packets {
        let wire_size = udp_packet_size_after_tail(&packet);
        assert!(
            wire_size <= RPC_PACKET_UDP_PAYLOAD_BUDGET,
            "packet size {} exceeded budget {}",
            wire_size,
            RPC_PACKET_UDP_PAYLOAD_BUDGET
        );
    }
}
/// An empty payload still yields exactly one packet, and that packet stays
/// within the UDP payload budget.
#[test]
fn build_rpc_packet_respects_udp_budget_for_empty_payload() {
    let packets = build_rpc_packet(build_test_args(&[], CompressionAlgoPb::Zstd));
    assert_eq!(1, packets.len());
    let wire_size = udp_packet_size_after_tail(&packets[0]);
    assert!(wire_size <= RPC_PACKET_UDP_PAYLOAD_BUDGET);
}
}
+14 -4
View File
@@ -27,8 +27,8 @@ use crate::{
instance_manage::InstanceManageRpcService, logger::LoggerRpcService,
mapped_listener_manage::MappedListenerManageRpcService,
peer_center::PeerCenterManageRpcService, peer_manage::PeerManageRpcService,
port_forward_manage::PortForwardManageRpcService, proxy::TcpProxyRpcService,
stats::StatsRpcService, vpn_portal::VpnPortalRpcService,
port_forward_manage::PortForwardManageRpcService, protected_port,
proxy::TcpProxyRpcService, stats::StatsRpcService, vpn_portal::VpnPortalRpcService,
},
tunnel::{TunnelListener, tcp::TcpTunnelListener},
web_client::{DefaultHooks, WebClientHooks},
@@ -36,6 +36,7 @@ use crate::{
pub struct ApiRpcServer<T: TunnelListener + 'static> {
rpc_server: StandAloneServer<T>,
protected_tcp_port: Option<u16>,
}
impl ApiRpcServer<TcpTunnelListener> {
@@ -44,14 +45,17 @@ impl ApiRpcServer<TcpTunnelListener> {
rpc_portal_whitelist: Option<Vec<IpCidr>>,
instance_manager: Arc<NetworkInstanceManager>,
) -> anyhow::Result<Self> {
let rpc_addr = parse_rpc_portal(rpc_portal)?;
let mut server = Self::from_tunnel(
TcpTunnelListener::new(
format!("tcp://{}", parse_rpc_portal(rpc_portal)?)
format!("tcp://{}", rpc_addr)
.parse()
.context("failed to parse rpc portal address")?,
),
instance_manager,
);
protected_port::register_protected_tcp_port(rpc_addr.port());
server.protected_tcp_port = Some(rpc_addr.port());
server
.rpc_server
@@ -65,7 +69,10 @@ impl<T: TunnelListener + 'static> ApiRpcServer<T> {
pub fn from_tunnel(tunnel: T, instance_manager: Arc<NetworkInstanceManager>) -> Self {
let rpc_server = StandAloneServer::new(tunnel);
register_api_rpc_service(&instance_manager, rpc_server.registry(), None);
Self { rpc_server }
Self {
rpc_server,
protected_tcp_port: None,
}
}
}
@@ -83,6 +90,9 @@ impl<T: TunnelListener + 'static> ApiRpcServer<T> {
impl<T: TunnelListener + 'static> Drop for ApiRpcServer<T> {
fn drop(&mut self) {
if let Some(port) = self.protected_tcp_port.take() {
protected_port::unregister_protected_tcp_port(port);
}
self.rpc_server.registry().unregister_all();
}
}
+1
View File
@@ -6,6 +6,7 @@ mod mapped_listener_manage;
mod peer_center;
mod peer_manage;
mod port_forward_manage;
pub(crate) mod protected_port;
mod proxy;
mod stats;
mod vpn_portal;
@@ -0,0 +1,61 @@
//! Process-wide registry of TCP ports that must never be reachable through
//! the subnet proxy (e.g. the RPC portal port). Entries are reference-counted
//! so several servers can protect the same port independently.

use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};

// std::sync::LazyLock replaces the once_cell::sync::Lazy dependency; the
// behavior (lazy, thread-safe one-time init) is identical.
static PROTECTED_TCP_PORTS: LazyLock<Mutex<HashMap<u16, usize>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));

/// Mark `port` as protected. Each call must be paired with one later
/// `unregister_protected_tcp_port` call; registrations are counted.
pub fn register_protected_tcp_port(port: u16) {
    let mut ports = PROTECTED_TCP_PORTS.lock().unwrap();
    *ports.entry(port).or_default() += 1;
}

/// Release one protection reference for `port`; the port stops being
/// protected once every registration is released. Unknown ports are a no-op.
pub fn unregister_protected_tcp_port(port: u16) {
    let mut ports = PROTECTED_TCP_PORTS.lock().unwrap();
    if let Some(ref_count) = ports.get_mut(&port) {
        *ref_count -= 1;
        if *ref_count == 0 {
            ports.remove(&port);
        }
    }
}

/// Whether `port` currently holds at least one protection registration.
pub fn is_protected_tcp_port(port: u16) -> bool {
    PROTECTED_TCP_PORTS.lock().unwrap().contains_key(&port)
}

/// Test helper: wipe the registry so tests do not observe each other's state.
#[cfg(test)]
pub fn clear_protected_tcp_ports_for_test() {
    PROTECTED_TCP_PORTS.lock().unwrap().clear();
}

#[cfg(test)]
mod tests {
    use super::{
        clear_protected_tcp_ports_for_test, is_protected_tcp_port, register_protected_tcp_port,
        unregister_protected_tcp_port,
    };

    #[test]
    fn protected_tcp_port_registry_is_ref_counted() {
        clear_protected_tcp_ports_for_test();
        register_protected_tcp_port(15888);
        register_protected_tcp_port(15888);
        assert!(is_protected_tcp_port(15888));
        unregister_protected_tcp_port(15888);
        assert!(is_protected_tcp_port(15888));
        unregister_protected_tcp_port(15888);
        assert!(!is_protected_tcp_port(15888));
    }

    #[test]
    fn unregistering_unknown_port_is_a_noop() {
        clear_protected_tcp_ports_for_test();
        unregister_protected_tcp_port(15888);
        assert!(!is_protected_tcp_port(15888));
    }
}