// Source file: Easytier/easytier/src/gateway/quic_proxy.rs
// Snapshot taken 2026-04-12 13:04:21 +08:00 (1333 lines, 44 KiB, Rust).

use crate::common::PeerId;
use crate::common::acl_processor::PacketInfo;
use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx};
use crate::gateway::CidrSet;
use crate::gateway::tcp_proxy::{NatDstConnector, TcpProxy};
use crate::gateway::wrapped_proxy::{ProxyAclHandler, TcpProxyForWrappedSrcTrait};
use crate::peers::PeerPacketFilter;
use crate::peers::peer_manager::PeerManager;
use crate::proto::acl::{ChainType, Protocol};
use crate::proto::api::instance::{
ListTcpProxyEntryRequest, ListTcpProxyEntryResponse, TcpProxyEntry, TcpProxyEntryState,
TcpProxyEntryTransportType, TcpProxyRpc,
};
use crate::proto::peer_rpc::KcpConnData as QuicConnData;
use crate::proto::rpc_types;
use crate::proto::rpc_types::controller::BaseController;
use crate::tunnel::packet_def::{
PacketType, PeerManagerHeader, TAIL_RESERVED_SIZE, ZCPacket, ZCPacketType,
};
use crate::tunnel::quic::{client_config, endpoint_config, server_config};
use anyhow::{Context, Error, anyhow};
use atomic_refcell::AtomicRefCell;
use bytes::{BufMut, Bytes, BytesMut};
use dashmap::DashMap;
use derivative::Derivative;
use derive_more::{Constructor, Deref, DerefMut, From, Into};
use prost::Message;
use quinn::udp::{EcnCodepoint, RecvMeta, Transmit};
use quinn::{
AsyncUdpSocket, Endpoint, RecvStream, SendStream, StreamId, UdpPoller, default_runtime,
};
use std::cmp::min;
use std::future::Future;
use std::io::IoSliceMut;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::pin::Pin;
use std::ptr::copy_nonoverlapping;
use std::sync::{Arc, Weak};
use std::task::Poll;
use std::time::Duration;
use tokio::io::{AsyncReadExt, Join, join};
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::{Receiver, Sender, channel};
use tokio::task::JoinSet;
use tokio::time::{Instant, timeout};
use tokio::{join, pin, select};
use tokio_util::sync::PollSender;
use tracing::{debug, error, info, instrument, trace, warn};
//region packet
/// A raw datagram exchanged between the in-memory QUIC endpoint and the peer
/// network.
#[derive(Debug, Constructor)]
struct QuicPacket {
    // Virtual source/destination address; encodes peer id + packet type
    // (see `QuicAddr`).
    addr: SocketAddr,
    // Datagram bytes; outgoing packets may hold several GSO segments, each
    // padded with the configured margins (see `QuicSocket::try_send`).
    payload: BytesMut,
    // Segment stride (chunk + margins) for outgoing GSO batches; `None` for
    // incoming packets.
    segment: Option<usize>,
    // Explicit congestion notification codepoint, if any.
    ecn: Option<EcnCodepoint>,
}
/// Byte margins reserved before (`header`) and after (`trailer`) every QUIC
/// datagram segment so it can later be framed as a ZCPacket without copying.
#[derive(Debug, Clone, Copy, From, Into)]
pub struct PacketMargins {
    pub header: usize,
    pub trailer: usize,
}
impl PacketMargins {
    /// Total number of margin bytes wrapped around one segment.
    pub fn len(&self) -> usize {
        self.header + self.trailer
    }
}
//endregion
//region socket
/// `UdpPoller` backed by the outgoing packet channel: the virtual socket is
/// "writable" whenever the channel can reserve a send slot.
#[derive(Debug)]
struct QuicSocketPoller {
    tx: PollSender<QuicPacket>,
}
impl UdpPoller for QuicSocketPoller {
    /// Report writability by probing the channel for a free slot.
    fn poll_writable(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context,
    ) -> Poll<std::io::Result<()>> {
        let sender = &mut self.get_mut().tx;
        match sender.poll_reserve(cx) {
            Poll::Ready(Ok(())) => {
                // We only probed for capacity; release the reserved slot so
                // the subsequent `try_send` performs its own reservation.
                sender.abort_send();
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                e,
            ))),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// In-memory UDP socket for quinn: datagrams travel through bounded mpsc
/// channels to/from the peer network instead of a kernel socket.
#[derive(Debug)]
pub struct QuicSocket {
    // Virtual local address (IPv4 derived from our peer id, port 0).
    addr: SocketAddr,
    // Incoming datagrams, fed by `QuicPacketReceiver`.
    rx: AtomicRefCell<Receiver<QuicPacket>>,
    // Outgoing datagrams, drained by `QuicPacketSender`.
    tx: Sender<QuicPacket>,
    // Per-segment byte margins reserved for ZCPacket framing.
    margins: PacketMargins,
}
impl AsyncUdpSocket for QuicSocket {
fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>> {
Box::into_pin(Box::new(QuicSocketPoller {
tx: PollSender::new(self.tx.clone()),
}))
}
fn try_send(&self, transmit: &Transmit) -> std::io::Result<()> {
match transmit.destination {
SocketAddr::V4(addr) => {
let len = transmit.contents.len();
trace!("{:?} sending {:?} bytes to {:?}", self.addr, len, addr);
let permit = self.tx.try_reserve().map_err(|e| match e {
TrySendError::Full(_) => std::io::ErrorKind::WouldBlock,
TrySendError::Closed(_) => std::io::ErrorKind::BrokenPipe,
})?;
let segment_size = transmit.segment_size.unwrap_or(len);
let chunks = transmit.contents.chunks(segment_size);
let segment = segment_size + self.margins.len();
let mut payload = BytesMut::with_capacity(chunks.len() * segment);
// The length of the last chunk could be smaller than segment_size
for chunk in chunks {
let len = chunk.len();
unsafe {
copy_nonoverlapping(
chunk.as_ptr(),
payload.as_mut_ptr().add(self.margins.header),
len,
);
payload.advance_mut(len + self.margins.len());
}
}
permit.send(QuicPacket {
addr: transmit.destination,
payload,
segment: Some(segment),
ecn: transmit.ecn,
});
Ok(())
}
_ => Err(std::io::ErrorKind::ConnectionRefused.into()),
}
}
fn poll_recv(
&self,
cx: &mut std::task::Context,
bufs: &mut [IoSliceMut<'_>],
meta: &mut [RecvMeta],
) -> Poll<std::io::Result<usize>> {
if bufs.is_empty() || meta.is_empty() {
return Poll::Ready(Ok(0));
}
let mut rx = self.rx.borrow_mut();
let mut count = 0;
for (buf, meta) in bufs.iter_mut().zip(meta.iter_mut()) {
match rx.poll_recv(cx) {
Poll::Ready(Some(packet)) => {
let len = packet.payload.len();
if len > buf.len() {
warn!(
"buffer too small for packet: {:?} < {:?}, dropped",
buf.len(),
len,
);
continue;
}
trace!(
"{:?} received {:?} bytes from {:?}",
self.addr, len, packet.addr
);
buf[0..len].copy_from_slice(&packet.payload);
*meta = RecvMeta {
addr: packet.addr,
len,
stride: len,
ecn: packet.ecn,
dst_ip: None,
};
count += 1;
}
Poll::Ready(None) if count > 0 => break,
Poll::Ready(None) => {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
"socket closed",
)));
}
Poll::Pending => break,
}
}
if count > 0 {
Poll::Ready(Ok(count))
} else {
Poll::Pending
}
}
fn local_addr(&self) -> std::io::Result<SocketAddr> {
Ok(self.addr)
}
}
//endregion
//region addr
/// Virtual address used on the in-memory socket: the peer id is packed into
/// the IPv4 address and the packet type into the port (see the conversions
/// below).
#[derive(Debug, Clone, Copy, Constructor)]
struct QuicAddr {
    peer_id: PeerId,
    packet_type: PacketType,
}
impl From<QuicAddr> for SocketAddr {
#[inline]
fn from(value: QuicAddr) -> Self {
SocketAddr::new(IpAddr::V4(value.peer_id.into()), value.packet_type as u16)
}
}
impl TryFrom<SocketAddr> for QuicAddr {
    type Error = ();

    /// Unpack a virtual socket address back into peer id + packet type.
    /// Fails for IPv6 addresses and for ports that are not a QUIC packet type.
    #[inline]
    fn try_from(value: SocketAddr) -> Result<Self, Self::Error> {
        match value.ip() {
            IpAddr::V4(ipv4) => {
                let port = value.port();
                let packet_type = if port == PacketType::QuicSrc as u16 {
                    PacketType::QuicSrc
                } else if port == PacketType::QuicDst as u16 {
                    PacketType::QuicDst
                } else {
                    return Err(());
                };
                Ok(Self::new(ipv4.into(), packet_type))
            }
            IpAddr::V6(_) => Err(()),
        }
    }
}
//endregion
//region stream
// Reader half joined with writer half so a bidirectional QUIC stream can be
// used as a single `AsyncRead + AsyncWrite` value.
type QuicStreamInner = Join<RecvStream, SendStream>;
/// Thin wrapper around the joined stream halves.
#[derive(Debug, Deref, DerefMut, From, Into)]
struct QuicStream {
    #[deref]
    #[deref_mut]
    inner: QuicStreamInner,
}
impl QuicStream {
    /// (recv id, send id) pair; used as the proxy-entry map key.
    #[inline]
    fn id(&self) -> (StreamId, StreamId) {
        (self.reader().id(), self.writer().id())
    }
}
impl From<(SendStream, RecvStream)> for QuicStream {
    #[inline]
    fn from(value: (SendStream, RecvStream)) -> Self {
        // quinn yields (send, recv); `join` expects (reader, writer).
        join(value.1, value.0).into()
    }
}
//endregion
/// `NatDstConnector` that reaches the NAT destination through a QUIC stream
/// opened towards the peer owning the destination IP.
#[derive(Debug, Clone)]
pub struct NatDstQuicConnector {
    pub(crate) endpoint: Endpoint,
    // Weak to avoid a reference cycle with the peer manager.
    pub(crate) peer_mgr: Weak<PeerManager>,
}
#[async_trait::async_trait]
impl NatDstConnector for NatDstQuicConnector {
    type DstStream = QuicStreamInner;

    /// Open a bidirectional QUIC stream to the peer that owns `nat_dst` and
    /// send the length-prefixed `QuicConnData` header describing the proxied
    /// connection (original src/dst addresses).
    ///
    /// Attempts are raced: while none has succeeded, an extra attempt is
    /// launched every 200 ms (or immediately once all in-flight attempts have
    /// failed), with a budget of 5 retries beyond the first attempt.
    async fn connect(
        &self,
        src: SocketAddr,
        nat_dst: SocketAddr,
    ) -> crate::common::error::Result<Self::DstStream> {
        let Some(peer_mgr) = self.peer_mgr.upgrade() else {
            return Err(anyhow::anyhow!("peer manager is not available").into());
        };
        // Resolve which peer announces the destination IPv4 address.
        let Some(dst_peer_id) = (match nat_dst {
            SocketAddr::V4(addr) => peer_mgr.get_peer_map().get_peer_id_by_ipv4(addr.ip()).await,
            SocketAddr::V6(_) => return Err(anyhow::anyhow!("ipv6 is not supported").into()),
        }) else {
            return Err(anyhow::anyhow!("no peer found for nat dst: {}", nat_dst).into());
        };
        trace!("quic nat dst: {:?}, dst peers: {:?}", nat_dst, dst_peer_id);
        let addr = QuicAddr::new(dst_peer_id, PacketType::QuicSrc).into();
        // Serialize the connection metadata once: a u16 big-endian length
        // followed by the protobuf-encoded `QuicConnData`. Cloned per attempt.
        let header = {
            let conn_data = QuicConnData {
                src: Some(src.into()),
                dst: Some(nat_dst.into()),
            };
            let len = conn_data.encoded_len();
            if len > (u16::MAX as usize) {
                return Err(anyhow!("conn data too large: {:?}", len).into());
            }
            let mut buf = BytesMut::with_capacity(2 + len);
            buf.put_u16(len as u16);
            conn_data.encode(&mut buf).unwrap();
            buf.freeze()
        };
        let mut connect_tasks = JoinSet::<Result<QuicStream, Error>>::new();
        // One connection attempt: QUIC handshake, open a bi stream, then
        // write the connection header.
        let connect = |tasks: &mut JoinSet<_>| {
            let endpoint = self.endpoint.clone();
            let header = header.clone();
            tasks.spawn(async move {
                let connection = endpoint.connect(addr, "")?.await?;
                let mut stream: QuicStream = connection.open_bi().await?.into();
                stream.writer_mut().write_chunk(header).await?;
                Ok(stream)
            });
        };
        connect(&mut connect_tasks);
        let timer = tokio::time::sleep(Duration::from_millis(200));
        pin!(timer);
        let mut retry_remain = 5;
        loop {
            select! {
                Some(result) = connect_tasks.join_next() => {
                    match result {
                        // First successful attempt wins; the rest are dropped
                        // when the JoinSet is dropped.
                        Ok(Ok(stream)) => return Ok(stream.into()),
                        _ => {
                            // All in-flight attempts have failed: retry
                            // immediately if budget remains, and restart the
                            // pacing timer.
                            if connect_tasks.is_empty() {
                                if retry_remain == 0 {
                                    return Err(anyhow!("failed to connect to nat dst: {:?}", nat_dst).into())
                                }
                                retry_remain -= 1;
                                connect(&mut connect_tasks);
                                timer.as_mut().reset(Instant::now() + Duration::from_millis(200))
                            }
                        }
                    }
                }
                // Pacing timer: launch an additional parallel attempt every
                // 200 ms while retries remain.
                _ = &mut timer, if retry_remain > 0 => {
                    retry_remain -= 1;
                    connect(&mut connect_tasks);
                    timer.as_mut().reset(Instant::now() + Duration::from_millis(200));
                }
            }
        }
    }

    /// Fast-path check: always accept here; the precise filtering happens in
    /// `check_packet_from_peer`.
    #[inline]
    fn check_packet_from_peer_fast(&self, _cidr_set: &CidrSet, _global_ctx: &GlobalCtx) -> bool {
        true
    }

    /// Accept only packets looped back to ourselves that are already marked
    /// as QUIC-src-modified.
    #[inline]
    fn check_packet_from_peer(
        &self,
        _cidr_set: &CidrSet,
        _global_ctx: &GlobalCtx,
        hdr: &PeerManagerHeader,
        _ipv4: &Ipv4Addr,
        _real_dst_ip: &mut Ipv4Addr,
    ) -> bool {
        hdr.from_peer_id == hdr.to_peer_id && hdr.is_quic_src_modified()
    }

    #[inline]
    fn transport_type(&self) -> TcpProxyEntryTransportType {
        TcpProxyEntryTransportType::Quic
    }
}
/// Newtype wiring the generic `TcpProxy` to the QUIC connector on the source
/// side.
#[derive(Clone)]
struct TcpProxyForQuicSrc(Arc<TcpProxy<NatDstQuicConnector>>);
#[async_trait::async_trait]
impl TcpProxyForWrappedSrcTrait for TcpProxyForQuicSrc {
    type Connector = NatDstQuicConnector;
    #[inline]
    fn get_tcp_proxy(&self) -> &Arc<TcpProxy<Self::Connector>> {
        &self.0
    }
    /// Flag the header so the receiving side can recognize QUIC-wrapped
    /// traffic (checked in `check_packet_from_peer`).
    #[inline]
    fn mark_src_modified(hdr: &mut PeerManagerHeader) -> &mut PeerManagerHeader {
        hdr.mark_quic_src_modified()
    }
    /// Only wrap traffic towards destinations that accept QUIC proxy input.
    #[inline]
    async fn check_dst_allow_wrapped_input(&self, dst_ip: &Ipv4Addr) -> bool {
        let Some(peer_manager) = self.0.get_peer_manager() else {
            return false;
        };
        peer_manager
            .check_allow_quic_to_dst(&IpAddr::V4(*dst_ip))
            .await
    }
}
/// Which side of the proxied connection this endpoint plays; determines the
/// packet types used for incoming and outgoing traffic.
#[derive(Debug)]
enum QuicProxyRole {
    Src,
    Dst,
}
impl QuicProxyRole {
    /// Packet type this role expects to receive from peers.
    #[inline]
    const fn incoming(&self) -> PacketType {
        match self {
            Self::Src => PacketType::QuicDst,
            Self::Dst => PacketType::QuicSrc,
        }
    }

    /// Packet type this role stamps onto outgoing packets.
    #[inline]
    const fn outgoing(&self) -> PacketType {
        match self {
            Self::Src => PacketType::QuicSrc,
            Self::Dst => PacketType::QuicDst,
        }
    }
}
// Receive packets from peers and forward them to the QUIC endpoint
/// Peer-packet filter that captures QUIC packets addressed to this role and
/// feeds them into the in-memory endpoint socket.
#[derive(Debug)]
struct QuicPacketReceiver {
    tx: Sender<QuicPacket>,
    role: QuicProxyRole,
}
#[async_trait::async_trait]
impl PeerPacketFilter for QuicPacketReceiver {
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
let header = packet.peer_manager_header().unwrap();
if header.packet_type != self.role.incoming() as u8 {
return Some(packet);
}
let addr = QuicAddr::new(header.from_peer_id.get(), self.role.outgoing());
if let Err(e) = self.tx.try_send(QuicPacket::new(
addr.into(),
packet.payload_bytes(),
None,
None,
)) {
debug!("failed to send quic packet to endpoint: {:?}", e);
}
None
}
}
// Send to peers packets received from the QUIC endpoint
/// Drains datagrams produced by the QUIC endpoint and forwards them to the
/// destination peer as ZCPackets.
#[derive(Debug)]
struct QuicPacketSender {
    peer_mgr: Arc<PeerManager>,
    rx: Receiver<QuicPacket>,
    // Pre-built ZCPacket header bytes stamped into each segment's header margin.
    header: Bytes,
    zc_packet_type: ZCPacketType,
    margins: PacketMargins,
}
impl QuicPacketSender {
    /// Forward endpoint output to peers, splitting GSO batches back into
    /// individual segments of at most `segment` bytes each.
    #[instrument]
    pub async fn run(mut self) {
        while let Some(packet) = self.rx.recv().await {
            let Ok(addr) = QuicAddr::try_from(packet.addr) else {
                error!("invalid quic packet addr: {:?}", packet.addr);
                continue;
            };
            let mut payload = packet.payload;
            // `QuicSocket::try_send` always sets the stride for outgoing packets.
            let segment = packet
                .segment
                .expect("segment size must be set for outgoing quic packet");
            while !payload.is_empty() {
                let len = min(payload.len(), segment);
                // Shadowing: split one segment off the front of the batch.
                let mut payload = payload.split_to(len);
                // Stamp the real ZCPacket header into the reserved header
                // margin and drop the trailer margin.
                payload[..self.margins.header].copy_from_slice(&self.header);
                payload.truncate(len - self.margins.trailer);
                let mut packet = ZCPacket::new_from_buf(payload, self.zc_packet_type);
                packet.fill_peer_manager_hdr(
                    self.peer_mgr.my_peer_id(),
                    addr.peer_id,
                    addr.packet_type as u8,
                );
                if let Err(e) = self.peer_mgr.send_msg_for_proxy(packet, addr.peer_id).await {
                    error!("failed to send QUIC packet to peer: {:?}", e);
                }
            }
        }
    }
}
/// Shared state for incoming (dst-side) proxy streams.
#[derive(Derivative, Clone)]
#[derivative(Debug)]
struct QuicStreamContext {
    global_ctx: ArcGlobalCtx,
    // Live proxy connections keyed by (recv id, send id); exposed over RPC.
    proxy_entries: Arc<DashMap<(StreamId, StreamId), TcpProxyEntry>>,
    // Proxy CIDR mappings used to rewrite destination addresses.
    cidr_set: Arc<CidrSet>,
    #[derivative(Debug = "ignore")]
    route: Arc<dyn crate::peers::route_trait::Route + Send + Sync + 'static>,
}
impl QuicStreamContext {
fn new(peer_mgr: Arc<PeerManager>) -> Self {
let global_ctx = peer_mgr.get_global_ctx();
Self {
global_ctx: global_ctx.clone(),
proxy_entries: Arc::new(DashMap::new()),
cidr_set: Arc::new(CidrSet::new(global_ctx.clone())),
route: Arc::new(peer_mgr.get_route()),
}
}
}
/// Accept loop for incoming QUIC connections on the dst side.
struct QuicStreamReceiver {
    endpoint: Endpoint,
    // One task per accepted connection.
    tasks: JoinSet<()>,
    ctx: Arc<QuicStreamContext>,
}
impl QuicStreamReceiver {
/// Accept incoming QUIC connections and, for each one, accept bidirectional
/// streams and proxy them to their NAT destinations.
async fn run(mut self) {
    loop {
        select! {
            biased;
            Some(incoming) = self.endpoint.accept() => {
                let addr = incoming.remote_address();
                // Start the handshake for the incoming connection.
                let connection = match incoming.accept() {
                    Ok(connection) => connection,
                    Err(e) => {
                        error!("failed to accept quic connection from {:?}: {:?}", addr, e);
                        continue;
                    }
                };
                let addr = connection.remote_address();
                // Wait for the handshake to complete.
                let connection = match connection.await {
                    Ok(connection) => connection,
                    Err(e) => {
                        error!("failed to accept quic connection from {:?}: {:?}", addr, e);
                        continue;
                    }
                };
                let ctx = self.ctx.clone();
                // One task per connection, which in turn spawns one task per
                // proxied stream.
                self.tasks.spawn(async move {
                    let mut tasks = JoinSet::new();
                    loop {
                        select! {
                            biased;
                            e = connection.closed() => {
                                info!("connection to {:?} closed: {:?}", addr, e);
                                break;
                            }
                            stream = connection.accept_bi() => {
                                let stream = match stream {
                                    Ok(stream) => stream.into(),
                                    Err(e) => {
                                        warn!("failed to accept bi stream from {:?}: {:?}", connection.remote_address(), e);
                                        break;
                                    }
                                };
                                // Parse the header, run ACL checks, connect to
                                // the real destination, then spawn the relay.
                                match Self::establish_stream(stream, ctx.clone()).await {
                                    Ok(stream) => drop(tasks.spawn(stream)),
                                    Err(e) => warn!("failed to establish quic stream from {:?}: {:?}", connection.remote_address(), e),
                                }
                            }
                            // Reap finished stream relay tasks.
                            res = tasks.join_next(), if !tasks.is_empty() => {
                                debug!("quic stream task completed for {:?}: {:?}", addr, res);
                            }
                        }
                    }
                    connection.close(1u32.into(), b"error");
                });
            }
            // Reap finished connection tasks.
            _ = self.tasks.join_next(), if !self.tasks.is_empty() => {}
        }
    }
}
/// Read the length-prefixed stream header (u16 length + that many bytes),
/// bounding both the size and the time spent waiting.
async fn read_stream_header(stream: &mut QuicStream) -> Result<Bytes, Error> {
    const STREAM_HEADER_READ_TIMEOUT: Duration = Duration::from_secs(5);
    const STREAM_HEADER_LIMIT: u16 = 512;
    let len = timeout(STREAM_HEADER_READ_TIMEOUT, stream.read_u16())
        .await
        .context("timeout reading header length")??;
    if len > STREAM_HEADER_LIMIT {
        return Err(anyhow::anyhow!("stream header too long"));
    }
    let mut header = Vec::with_capacity(usize::from(len));
    // `take` caps the read so a misbehaving peer cannot exceed the declared length.
    let read_body = stream
        .reader_mut()
        .take(u64::from(len))
        .read_to_end(&mut header);
    timeout(STREAM_HEADER_READ_TIMEOUT, read_body)
        .await
        .context("timeout reading header")??;
    Ok(header.into())
}
/// Validate and register a freshly accepted proxy stream.
///
/// Reads the `QuicConnData` header, records a `TcpProxyEntry`, applies CIDR
/// remapping and ACL checks, connects to the real destination, and returns a
/// future that relays data between the QUIC stream and the destination.
async fn establish_stream(
    mut stream: QuicStream,
    ctx: Arc<QuicStreamContext>,
) -> Result<impl Future<Output = crate::common::error::Result<()>>, Error> {
    let conn_data = Self::read_stream_header(&mut stream).await?;
    let conn_data_parsed = QuicConnData::decode(conn_data.as_ref())
        .context("failed to decode quic stream header")?;
    let handle = stream.id();
    let proxy_entries = &ctx.proxy_entries;
    // Track the connection for the RPC listing; starts in ConnectingDst state.
    proxy_entries.insert(
        handle,
        TcpProxyEntry {
            src: conn_data_parsed.src,
            dst: conn_data_parsed.dst,
            start_time: chrono::Local::now().timestamp() as u64,
            state: TcpProxyEntryState::ConnectingDst.into(),
            transport_type: TcpProxyEntryTransportType::Quic.into(),
        },
    );
    // NOTE(review): if `crate::defer!` drops at the end of THIS function's
    // scope, the entry is removed as soon as the relay future is returned
    // (i.e. right after being marked Connected) rather than when the relay
    // finishes — confirm the macro's semantics.
    crate::defer! {
        proxy_entries.remove(&handle);
        // Shrink the map when it carries much more capacity than entries.
        if proxy_entries.capacity() - proxy_entries.len() > 16 {
            proxy_entries.shrink_to_fit();
        }
    }
    let src_socket: SocketAddr = conn_data_parsed
        .src
        .ok_or_else(|| anyhow!("missing src addr in quic stream header"))?
        .into();
    let mut dst_socket: SocketAddr = conn_data_parsed
        .dst
        .ok_or_else(|| anyhow!("missing dst addr in quic stream header"))?
        .into();
    // Map a proxy-CIDR destination back to its real address.
    if let IpAddr::V4(dst_v4_ip) = dst_socket.ip() {
        let mut real_ip = dst_v4_ip;
        if ctx.cidr_set.contains_v4(dst_v4_ip, &mut real_ip) {
            dst_socket.set_ip(real_ip.into());
        }
    };
    let src_ip = src_socket.ip();
    let dst_ip = dst_socket.ip();
    let route = ctx.route.clone();
    // Resolve ACL peer groups for both endpoints concurrently.
    let (src_groups, dst_groups) = join!(
        route.get_peer_groups_by_ip(&src_ip),
        route.get_peer_groups_by_ip(&dst_ip)
    );
    let global_ctx = ctx.global_ctx.clone();
    // Refuse to proxy towards our own running listeners.
    if global_ctx.should_deny_proxy(&dst_socket, false) {
        return Err(anyhow::anyhow!(
            "dst socket {:?} is in running listeners, ignore it",
            dst_socket
        ));
    }
    // When the destination is our own virtual IP and there is no TUN device,
    // redirect the connection to loopback.
    let send_to_self = global_ctx.is_ip_local_virtual_ip(&dst_ip);
    if send_to_self && global_ctx.no_tun() {
        dst_socket = format!("127.0.0.1:{}", dst_socket.port()).parse()?;
    }
    let acl_handler = ProxyAclHandler {
        acl_filter: global_ctx.get_acl_filter().clone(),
        packet_info: PacketInfo {
            src_ip,
            dst_ip,
            src_port: Some(src_socket.port()),
            dst_port: Some(dst_socket.port()),
            protocol: Protocol::Tcp,
            packet_size: conn_data.len(),
            src_groups,
            dst_groups,
        },
        // Traffic to ourselves is inbound; everything else is forwarded.
        chain_type: if send_to_self {
            ChainType::Inbound
        } else {
            ChainType::Forward
        },
    };
    acl_handler.handle_packet(&conn_data)?;
    debug!("quic connect to dst socket: {:?}", dst_socket);
    // Connect from inside the instance's network namespace.
    let _g = global_ctx.net_ns.guard();
    let connector = crate::gateway::tcp_proxy::NatDstTcpConnector {};
    let ret = connector.connect("0.0.0.0:0".parse()?, dst_socket).await?;
    if let Some(mut e) = proxy_entries.get_mut(&handle) {
        e.state = TcpProxyEntryState::Connected.into();
    }
    // The caller spawns this future to relay bytes with per-packet ACL checks.
    Ok(async move {
        acl_handler
            .copy_bidirection_with_acl(stream.inner, ret)
            .await
    })
}
}
/// Top-level QUIC proxy: owns the shared endpoint plus the optional source
/// (outgoing) and destination (incoming) halves.
pub struct QuicProxy {
    peer_mgr: Arc<PeerManager>,
    // Set once by `run`; also serves as the "already running" flag.
    endpoint: Option<Endpoint>,
    src: Option<QuicProxySrc>,
    dst: Option<QuicProxyDst>,
    tasks: JoinSet<()>,
}
impl QuicProxy {
    /// Source half, if `run` was started with `src == true`.
    #[inline]
    pub fn src(&self) -> Option<&QuicProxySrc> {
        self.src.as_ref()
    }
    /// Destination half, if `run` was started with `dst == true`.
    #[inline]
    pub fn dst(&self) -> Option<&QuicProxyDst> {
        self.dst.as_ref()
    }
}
impl QuicProxy {
    /// Create an idle proxy; call `run` to start it.
    pub fn new(peer_mgr: Arc<PeerManager>) -> Self {
        Self {
            peer_mgr,
            endpoint: None,
            src: None,
            dst: None,
            tasks: JoinSet::new(),
        }
    }

    /// Start the proxy: build the in-memory socket and QUIC endpoint, spawn
    /// the packet sender, and optionally start the src and/or dst halves.
    pub async fn run(&mut self, src: bool, dst: bool) {
        trace!("quic proxy starting");
        if self.endpoint.is_some() {
            error!("quic proxy already running");
            return;
        }
        // Capture the ZCPacket header bytes and packet type from an empty
        // packet; the header is re-stamped onto every outgoing segment.
        let (header, zc_packet_type) = {
            let header = ZCPacket::new_with_payload(&[]);
            let zc_packet_type = header.packet_type();
            let payload_offset = header.payload_offset();
            (
                header.inner().split_to(payload_offset).freeze(),
                zc_packet_type,
            )
        };
        let margins = (header.len(), TAIL_RESERVED_SIZE).into();
        let (in_tx, in_rx) = channel(1024);
        let (out_tx, out_rx) = channel(1024);
        // Virtual socket whose address encodes our own peer id.
        let socket = QuicSocket {
            addr: SocketAddr::new(Ipv4Addr::from(self.peer_mgr.my_peer_id()).into(), 0),
            rx: AtomicRefCell::new(in_rx),
            tx: out_tx,
            margins,
        };
        let mut endpoint = Endpoint::new_with_abstract_socket(
            endpoint_config(),
            Some(server_config()),
            Arc::new(socket),
            default_runtime().unwrap(),
        )
        .unwrap();
        endpoint.set_default_client_config(client_config());
        self.endpoint = Some(endpoint.clone());
        let peer_mgr = self.peer_mgr.clone();
        // Drain endpoint output towards peers.
        self.tasks.spawn(
            QuicPacketSender {
                peer_mgr,
                rx: out_rx,
                header,
                zc_packet_type,
                margins,
            }
            .run(),
        );
        let peer_mgr = self.peer_mgr.clone();
        if src {
            if self.src.is_some() {
                error!("quic proxy src already running");
                return;
            }
            let tcp_proxy = TcpProxyForQuicSrc(TcpProxy::new(
                peer_mgr.clone(),
                NatDstQuicConnector {
                    endpoint: endpoint.clone(),
                    peer_mgr: Arc::downgrade(&peer_mgr),
                },
            ));
            let src = QuicProxySrc {
                peer_mgr: peer_mgr.clone(),
                tcp_proxy,
                tx: in_tx.clone(),
            };
            src.run().await;
            self.src = Some(src);
        }
        if dst {
            if self.dst.is_some() {
                error!("quic proxy dst already running");
                return;
            }
            let stream_ctx = Arc::new(QuicStreamContext::new(peer_mgr.clone()));
            let dst = QuicProxyDst {
                peer_mgr: peer_mgr.clone(),
                tx: in_tx.clone(),
                stream_ctx: stream_ctx.clone(),
            };
            dst.run().await;
            // Accept incoming proxy connections on the shared endpoint.
            self.tasks.spawn(
                QuicStreamReceiver {
                    endpoint: endpoint.clone(),
                    tasks: JoinSet::new(),
                    ctx: stream_ctx,
                }
                .run(),
            );
            self.dst = Some(dst);
        }
    }
}
/// Source half: intercepts outgoing TCP connections and tunnels them to the
/// destination peer over QUIC.
pub struct QuicProxySrc {
    peer_mgr: Arc<PeerManager>,
    tcp_proxy: TcpProxyForQuicSrc,
    // Feeds incoming QUIC datagrams into the shared endpoint socket.
    tx: Sender<QuicPacket>,
}
impl QuicProxySrc {
    /// Underlying TCP proxy (e.g. for RPC/statistics queries).
    #[inline]
    pub fn get_tcp_proxy(&self) -> Arc<TcpProxy<NatDstQuicConnector>> {
        self.tcp_proxy.get_tcp_proxy().clone()
    }
}
impl QuicProxySrc {
    /// Register the packet-processing pipelines and start the TCP proxy.
    /// Registration order matters: NIC interception, then the inner TCP
    /// proxy, then the QUIC packet receiver.
    async fn run(&self) {
        trace!("quic proxy src starting");
        // Intercept TCP traffic coming from the local NIC.
        self.peer_mgr
            .add_nic_packet_process_pipeline(Box::new(self.tcp_proxy.clone()))
            .await;
        // Let the inner TCP proxy see packets arriving from peers.
        self.peer_mgr
            .add_packet_process_pipeline(Box::new(self.tcp_proxy.0.clone()))
            .await;
        // Feed QuicDst-typed packets into the endpoint socket.
        self.peer_mgr
            .add_packet_process_pipeline(Box::new(QuicPacketReceiver {
                tx: self.tx.clone(),
                role: QuicProxyRole::Src,
            }))
            .await;
        self.tcp_proxy.0.start(false).await.unwrap();
    }
}
/// Destination half: receives QUIC packets from peers and feeds them into the
/// local endpoint, whose accept loop then serves the proxied streams.
pub struct QuicProxyDst {
    peer_mgr: Arc<PeerManager>,
    tx: Sender<QuicPacket>,
    // Shared with `QuicStreamReceiver`; also backs the RPC service.
    stream_ctx: Arc<QuicStreamContext>,
}
impl QuicProxyDst {
    /// Register the pipeline that feeds QuicSrc-typed packets to the endpoint.
    async fn run(&self) {
        trace!("quic proxy dst starting");
        self.peer_mgr
            .add_packet_process_pipeline(Box::new(QuicPacketReceiver {
                tx: self.tx.clone(),
                role: QuicProxyRole::Dst,
            }))
            .await;
    }
}
/// RPC service exposing the dst-side proxy-entry table; holds only a weak
/// reference so it never keeps the proxy alive.
#[derive(Clone, Deref, DerefMut, From, Into)]
pub struct QuicProxyDstRpcService(Weak<DashMap<(StreamId, StreamId), TcpProxyEntry>>);
impl QuicProxyDstRpcService {
    pub fn new(quic_proxy_dst: &QuicProxyDst) -> Self {
        Self(Arc::downgrade(&quic_proxy_dst.stream_ctx.proxy_entries))
    }
}
#[async_trait::async_trait]
impl TcpProxyRpc for QuicProxyDstRpcService {
    type Controller = BaseController;

    /// List the currently tracked QUIC proxy connections.
    async fn list_tcp_proxy_entry(
        &self,
        _: BaseController,
        _request: ListTcpProxyEntryRequest,
    ) -> Result<ListTcpProxyEntryResponse, rpc_types::error::Error> {
        let mut reply = ListTcpProxyEntryResponse::default();
        // The table may already be gone if the proxy was shut down; reply
        // with an empty list in that case.
        if let Some(entries) = self.0.upgrade() {
            reply
                .entries
                .extend(entries.iter().map(|entry| *entry.value()));
        }
        Ok(reply)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use bytes::Buf;
/// Helper function: Create a pair of interconnected QuicSockets.
/// Data sent by socket_a will enter socket_b's rx, and vice versa.
fn make_socket_pair() -> (QuicSocket, QuicSocket) {
    let addr_a: SocketAddr = "127.0.0.1:5000".parse().unwrap();
    let addr_b: SocketAddr = "127.0.0.1:5001".parse().unwrap();
    // Bidirectional channels: A->B and B->A.
    // Sufficient capacity to prevent packet loss during high concurrency.
    let (tx_a_out, rx_a_out) = channel::<QuicPacket>(50_000);
    let (tx_b_in, rx_b_in) = channel::<QuicPacket>(50_000);
    let (tx_b_out, rx_b_out) = channel::<QuicPacket>(50_000);
    let (tx_a_in, rx_a_in) = channel::<QuicPacket>(50_000);
    // Non-trivial margins so the tests exercise the header/trailer handling.
    let margins = (20, 25).into();
    forward(rx_a_out, tx_b_in, addr_a, margins);
    forward(rx_b_out, tx_a_in, addr_b, margins);
    let socket_a = QuicSocket {
        addr: addr_a,
        rx: AtomicRefCell::new(rx_a_in),
        tx: tx_a_out,
        margins,
    };
    let socket_b = QuicSocket {
        addr: addr_b,
        rx: AtomicRefCell::new(rx_b_in),
        tx: tx_b_out,
        margins,
    };
    (socket_a, socket_b)
}
/// Build a (client, server) pair of QUIC endpoints connected through an
/// in-memory socket pair.
fn endpoint() -> (Endpoint, Endpoint) {
    let endpoint_config = endpoint_config();
    let server_config = server_config();
    let client_config = client_config();
    // 1. Create an in-memory socket pair.
    let (socket_client, socket_server) = make_socket_pair();
    let socket_client = Arc::new(socket_client);
    let socket_server = Arc::new(socket_server);
    // 2. Configure the client endpoint (also given a server config so either
    // side could accept).
    let mut client_endpoint = Endpoint::new_with_abstract_socket(
        endpoint_config.clone(),
        Some(server_config.clone()),
        socket_client.clone(),
        default_runtime().unwrap(),
    )
    .unwrap();
    client_endpoint.set_default_client_config(client_config.clone());
    // 3. Configure the server endpoint.
    let mut server_endpoint = Endpoint::new_with_abstract_socket(
        endpoint_config.clone(),
        Some(server_config.clone()),
        socket_server.clone(),
        default_runtime().unwrap(),
    )
    .unwrap();
    server_endpoint.set_default_client_config(client_config.clone());
    (client_endpoint, server_endpoint)
}
/// Shuttle packets from `rx` to `tx`, rewriting the source address and
/// stripping the header/trailer margins — mimicking what the peer network +
/// `QuicPacketSender` would do in production.
fn forward(
    mut rx: Receiver<QuicPacket>,
    tx: Sender<QuicPacket>,
    addr: SocketAddr,
    margins: PacketMargins,
) {
    const BATCH_SIZE: usize = 128;
    tokio::spawn(async move {
        // Use a buffer for batch processing to cut per-packet wakeup overhead.
        let mut buffer = Vec::with_capacity(BATCH_SIZE);
        // recv_many wakes up when data is available, taking up to BATCH_SIZE
        // packets at a time instead of one.
        while rx.recv_many(&mut buffer, BATCH_SIZE).await > 0 {
            for packet in buffer.iter_mut() {
                // Rewrite the source address and strip the margins.
                packet.addr = addr;
                packet.payload.advance(margins.header);
                packet
                    .payload
                    .truncate(packet.payload.len() - margins.trailer);
            }
            // Batch forward.
            for packet in buffer.drain(..) {
                if let Err(e) = tx.send(packet).await {
                    info!("{:?}", e);
                    return; // Channel closed
                }
            }
        }
    });
}
/// Smoke test: one connection, one bidirectional stream, ping/pong exchange
/// through the in-memory socket pair.
#[tokio::test]
async fn test_ping() -> anyhow::Result<()> {
    let (client_endpoint, server_endpoint) = endpoint();
    let server_addr = server_endpoint.local_addr()?;
    // Server receive task.
    let server_handle = tokio::spawn(async move {
        println!("Server: Waiting for connection...");
        if let Some(conn) = server_endpoint.accept().await {
            let connection = conn.await.unwrap();
            println!(
                "Server: Connection accepted from {}",
                connection.remote_address()
            );
            // Accept bidirectional stream
            let (mut send, mut recv) = connection.accept_bi().await.unwrap();
            // Read data
            let mut buf = vec![0u8; 10];
            recv.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"ping______");
            println!("Server: Received 'ping______'");
            // Send reply
            send.write_all(b"pong______").await.unwrap();
            send.finish().unwrap();
            let _ = connection.closed().await;
        }
    });
    // Client initiates connection.
    // Note: The connect address here must be V4, because try_send is limited to SocketAddr::V4
    println!("Client: Connecting...");
    let connection = client_endpoint.connect(server_addr, "localhost")?.await?;
    println!("Client: Connected!");
    // Open a stream and send data
    let (mut send, mut recv) = connection.open_bi().await?;
    send.write_all(b"ping______").await?;
    send.finish()?;
    // Read reply
    let mut buf = vec![0u8; 10];
    recv.read_exact(&mut buf).await?;
    assert_eq!(&buf, b"pong______");
    println!("Client: Received 'pong______'");
    // Cleanup
    connection.close(0u32.into(), b"done");
    // Wait for Server to finish
    let _ = tokio::time::timeout(Duration::from_secs(2), server_handle).await;
    Ok(())
}
/// Single-stream throughput benchmark over the in-memory socket pair.
#[tokio::test]
#[ignore = "consumes massive memory (~16GB)"]
async fn test_bandwidth() -> anyhow::Result<()> {
    // Define test data volume.
    // Total test size: 32768 * 1024 * 1024 bytes = 32 GiB.
    const TOTAL_SIZE: usize = 32768 * 1024 * 1024;
    // Write chunk size: 1 MB (simulate large chunk write)
    const CHUNK_SIZE: usize = 1024 * 1024;
    let (client_endpoint, server_endpoint) = endpoint();
    let server_addr = server_endpoint.local_addr()?;
    // Server side (receive and timing).
    let server_handle = tokio::spawn(async move {
        if let Some(conn) = server_endpoint.accept().await {
            let connection = conn.await.unwrap();
            // Accept unidirectional stream
            let mut recv = connection.accept_uni().await.unwrap();
            let start = std::time::Instant::now();
            let mut received = 0;
            // Loop read until the stream ends
            // read_chunk performs slightly better than read_exact because it reduces internal buffer copying
            while let Some(chunk) = recv.read_chunk(usize::MAX, true).await.unwrap() {
                received += chunk.bytes.len();
            }
            let duration = start.elapsed();
            assert_eq!(received, TOTAL_SIZE, "Data length mismatch");
            let seconds = duration.as_secs_f64();
            let mbps = (received as f64 * 8.0) / (1_000_000.0 * seconds);
            let gbps = mbps / 1000.0;
            println!("--------------------------------------------------");
            println!("Server Recv Statistics:");
            println!(" Total Data: {} MB", received / 1024 / 1024);
            println!(" Duration : {:.2?}", duration);
            println!(" Throughput: {:.2} Gbps ({:.2} Mbps)", gbps, mbps);
            println!("--------------------------------------------------");
            // Keep connection until the Client disconnects
            let _ = connection.closed().await;
        }
    });
    // Client side (send).
    let connection = client_endpoint.connect(server_addr, "localhost")?.await?;
    let mut send = connection.open_uni().await?;
    // Construct a 1MB data chunk
    let data_chunk = vec![0u8; CHUNK_SIZE];
    let bytes_data = Bytes::from(data_chunk); // Use Bytes to avoid repeated allocation
    println!("Client: Start sending {} MB...", TOTAL_SIZE / 1024 / 1024);
    let start_send = std::time::Instant::now();
    let chunks = TOTAL_SIZE / CHUNK_SIZE;
    for _ in 0..chunks {
        // write_chunk is most efficient when used with Bytes
        send.write_chunk(bytes_data.clone()).await?;
    }
    // Tell peer sending is finished
    send.finish()?;
    // Wait for the stream to close completely (ensure peer received FIN)
    send.stopped().await?;
    let send_duration = start_send.elapsed();
    println!("Client: Send finished in {:.2?}", send_duration);
    // Close connection
    connection.close(0u32.into(), b"done");
    // Wait for Server to print results
    let _ = tokio::time::timeout(Duration::from_secs(5), server_handle).await;
    Ok(())
}
/// Multi-stream throughput benchmark: 16 concurrent unidirectional streams of
/// 1 GiB each, with a simple head/tail marker check for stream isolation.
#[tokio::test]
#[ignore = "consumes massive memory (~16GB)"]
async fn test_bandwidth_parallel() -> anyhow::Result<()> {
    // Configuration parameters.
    const STREAM_COUNT: usize = 16; // Number of concurrent streams
    const STREAM_SIZE: usize = 1024 * 1024 * 1024; // Each stream sends 1GB
    let (client_endpoint, server_endpoint) = endpoint();
    let server_addr = server_endpoint.local_addr()?;
    // Server side (concurrent receiver).
    let server_handle = tokio::spawn(async move {
        if let Some(conn) = server_endpoint.accept().await {
            let connection = conn.await.unwrap();
            println!("Server: Accepted connection");
            let mut stream_handles = Vec::new();
            let start = std::time::Instant::now();
            // Accept an expected number of streams
            for i in 0..STREAM_COUNT {
                match connection.accept_uni().await {
                    Ok(mut recv) => {
                        // Start an independent processing task for each stream
                        let handle = tokio::spawn(async move {
                            // Read all data
                            match recv.read_to_end(usize::MAX).await {
                                Ok(data) => {
                                    // Verify length
                                    assert_eq!(
                                        data.len(),
                                        STREAM_SIZE,
                                        "Stream {} length mismatch",
                                        i
                                    );
                                    // Verify data content (verify data isolation)
                                    // We agree that the first byte of data is (stream_index % 255)
                                    // This ensures stream data is not mixed
                                    let expected_byte = data[0] as usize; // Get the actual received marker
                                    // Simple check of head and tail here, CRC can be used in production
                                    if data[data.len() - 1] != data[0] {
                                        panic!("Stream data corruption");
                                    }
                                    expected_byte // Return marker for statistics
                                }
                                Err(e) => panic!("Stream read error: {}", e),
                            }
                        });
                        stream_handles.push(handle);
                    }
                    Err(e) => panic!("Failed to accept stream {}: {}", i, e),
                }
            }
            // Wait for all streams to finish processing
            let results = futures::future::join_all(stream_handles).await;
            let duration = start.elapsed();
            // `speed` is in Mbps (decimal, 10^6 bits/s).
            let speed = ((STREAM_COUNT * STREAM_SIZE) as f64 * 8.0)
                / (duration.as_secs_f64() * 1_000_000.0);
            println!("--------------------------------------------------");
            println!("Server: All {} streams received processing.", results.len());
            println!("Total Time: {:.2?}", duration);
            println!(
                "Total Data: {} MB",
                (STREAM_COUNT * STREAM_SIZE) / 1024 / 1024
            );
            // FIX: Mbps -> Gbps uses the decimal factor 1000 (matching the
            // 1_000_000 divisor above and `test_bandwidth`), not 1024.
            println!(
                "Average Speed: {:.2} Gbps ({:.2} Mbps)",
                speed / 1000.0,
                speed
            );
            println!("--------------------------------------------------");
            // Keep connection until the Client disconnects
            let _ = connection.closed().await;
        }
    });
    // Client side (concurrent sender).
    let connection = client_endpoint.connect(server_addr, "localhost")?.await?;
    println!(
        "Client: Connected, starting {} parallel streams...",
        STREAM_COUNT
    );
    let start_send = std::time::Instant::now();
    let mut client_tasks = Vec::new();
    // Start sending tasks concurrently
    for i in 0..STREAM_COUNT {
        let conn = connection.clone();
        client_tasks.push(tokio::spawn(async move {
            // Open unidirectional stream
            let mut send = conn.open_uni().await.expect("Failed to open stream");
            // Construct data: use i as the padding marker to verify isolation
            // All bytes are filled with (i % 255)
            let fill_byte = (i % 255) as u8;
            let data = vec![fill_byte; STREAM_SIZE];
            let bytes_data = Bytes::from(data);
            send.write_chunk(bytes_data).await.expect("Write failed");
            send.finish().expect("Finish failed");
            // Wait for Server to acknowledge receipt of FIN
            send.stopped().await.expect("Stopped failed");
        }));
    }
    // Wait for all sending tasks to complete
    futures::future::join_all(client_tasks).await;
    let send_duration = start_send.elapsed();
    println!("Client: All streams sent in {:.2?}", send_duration);
    // Close connection
    connection.close(0u32.into(), b"done");
    // Wait for Server to finish
    let _ = tokio::time::timeout(Duration::from_secs(10), server_handle).await;
    Ok(())
}
}