improve faketcp, handle tcp GSO correctly (#1708)

The current implementation incorrectly drops GSO-merged TCP packets, causing unexpected packet loss.
This commit is contained in:
KKRainbow
2025-12-26 23:46:17 +08:00
committed by GitHub
parent 0be4ac1fa5
commit 4341bcba5d
4 changed files with 194 additions and 61 deletions
+1 -1
View File
@@ -203,7 +203,7 @@ jobs:
# Copied and slightly modified from @lmq8267 (https://github.com/lmq8267)
- name: Build Core & Cli (X86_64 FreeBSD)
uses: vmactions/freebsd-vm@v1
uses: vmactions/freebsd-vm@670398e4236735b8b65805c3da44b7a511fb8b27
if: ${{ endsWith(matrix.TARGET, 'freebsd') }}
env:
TARGET: ${{ matrix.TARGET }}
+58 -23
View File
@@ -327,7 +327,11 @@ impl crate::tunnel::TunnelConnector for FakeTcpTunnelConnector {
tracing::info!(?remote_addr, "FakeTcpTunnelConnector connecting");
socket.recv_bytes().await.ok_or(TunnelError::InternalError(
let mut buf = BytesMut::new();
socket
.recv(&mut buf)
.await
.ok_or(TunnelError::InternalError(
"Failed to recv bytes to establish connection".into(),
))?;
@@ -367,17 +371,24 @@ use crate::tunnel::{SinkError, SinkItem, StreamItem};
use futures::{Sink, Stream};
use std::task::{Context as TaskContext, Poll};
/// Boxed future for a single pending socket receive.
///
/// It resolves to the receive buffer together with the size returned by the
/// receive call, or to `None` when the receive call itself returns `None`.
type RecvFut = Pin<Box<dyn Future<Output = Option<(BytesMut, usize)>> + Send + Sync>>;
/// Receive-side state of `FakeTcpStream`, advanced on every `poll_next` call.
enum FakeTcpStreamState {
/// Buffered bytes that are waiting to be split into individual packets.
ConsumingBuf(BytesMut),
/// A socket receive is in flight; the future hands the buffer back once it completes.
PollFuture(RecvFut),
/// The stream has ended; polling yields `None`.
Closed,
}
struct FakeTcpStream {
socket: Arc<stack::Socket>,
#[allow(clippy::type_complexity)]
recv_fut: Option<Pin<Box<dyn Future<Output = Option<Vec<u8>>> + Send + Sync>>>,
state: FakeTcpStreamState,
}
impl FakeTcpStream {
fn new(socket: Arc<stack::Socket>) -> Self {
Self {
socket,
recv_fut: None,
state: FakeTcpStreamState::ConsumingBuf(BytesMut::new()),
}
}
}
@@ -387,27 +398,51 @@ impl Stream for FakeTcpStream {
fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
let s = self.get_mut();
if s.recv_fut.is_none() {
let socket = s.socket.clone();
s.recv_fut = Some(Box::pin(async move { socket.recv_bytes().await }));
loop {
let state = std::mem::replace(&mut s.state, FakeTcpStreamState::Closed);
match state {
FakeTcpStreamState::ConsumingBuf(buf) => {
let buf_len = buf.len();
// check peer manager header and split buf out
let packet = ZCPacket::new_from_buf(buf, ZCPacketType::TCP);
if let Some(tcp_hdr) = packet.tcp_tunnel_header() {
let expected_payload_len = tcp_hdr.len.get() as usize;
if expected_payload_len <= buf_len && expected_payload_len != 0 {
let mut buf = packet.inner();
let new_inner = buf.split_to(expected_payload_len);
s.state = FakeTcpStreamState::ConsumingBuf(buf);
return Poll::Ready(Some(Ok(ZCPacket::new_from_buf(
new_inner,
ZCPacketType::TCP,
))));
}
}
match s.recv_fut.as_mut().unwrap().as_mut().poll(cx) {
Poll::Ready(Some(data)) => {
let mut buf = BytesMut::new();
buf.extend_from_slice(&data);
let packet = ZCPacket::new_from_buf(buf, ZCPacketType::DummyTunnel);
let mut buf = packet.inner();
buf.truncate(0);
s.recv_fut = None;
Poll::Ready(Some(Ok(packet)))
let socket = s.socket.clone();
s.state = FakeTcpStreamState::PollFuture(Box::pin(async move {
let ret = socket.recv(&mut buf).await;
ret.map(|s| (buf, s))
}));
}
FakeTcpStreamState::PollFuture(mut fut) => match fut.as_mut().poll(cx) {
Poll::Ready(Some((buf, _sz))) => {
s.state = FakeTcpStreamState::ConsumingBuf(buf);
}
Poll::Ready(None) => {
// 连接关闭
s.recv_fut = None;
Poll::Ready(None)
s.state = FakeTcpStreamState::Closed;
}
Poll::Pending => {
s.state = FakeTcpStreamState::PollFuture(fut);
return Poll::Pending;
}
},
FakeTcpStreamState::Closed => {
return Poll::Ready(None);
}
}
Poll::Pending => Poll::Pending,
}
}
}
@@ -435,10 +470,10 @@ impl Sink<SinkItem> for FakeTcpSink {
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
// We need to send the packet as bytes
// The item is ZCPacket, which has into_bytes() method
let bytes = item.convert_type(ZCPacketType::DummyTunnel).into_bytes();
// Let's just spawn for now as a simple implementation, noting the limitation.
self.socket.try_send(&bytes);
let mut packet = item.convert_type(ZCPacketType::TCP);
let len = packet.buf_len();
packet.mut_tcp_tunnel_header().unwrap().len.set(len as u32);
self.socket.try_send(&packet.into_bytes());
Ok(())
}
@@ -9,6 +9,7 @@ use std::net::SocketAddr;
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};
use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use crate::tunnel::fake_tcp::stack;
@@ -37,6 +38,11 @@ const BPF_JEQ: u16 = 0x10;
const BPF_K: u16 = 0x00;
const SOL_PACKET: i32 = 263;
const PACKET_STATISTICS: i32 = 6;
const DEFAULT_RCVBUF_BYTES: i32 = 32 * 1024 * 1024;
fn stmt(code: u16, k: u32) -> libc::sock_filter {
libc::sock_filter {
code,
@@ -303,6 +309,63 @@ fn build_tcp_filter(
b.finish()
}
/// Packet and drop counters filled in by `read_packet_socket_stats` through the
/// `PACKET_STATISTICS` socket option.
///
/// `#[repr(C)]` pins the field layout because a pointer to this struct is handed
/// to `getsockopt`, which writes the counters in place.
// NOTE(review): presumably mirrors the kernel's `struct tpacket_stats` — confirm against packet(7).
#[repr(C)]
#[derive(Clone, Copy, Default)]
struct PacketSocketStats {
/// Packet counter reported by the socket.
tp_packets: u32,
/// Drop counter reported by the socket.
tp_drops: u32,
}
/// Requests a receive buffer of `desired_bytes` for socket `fd` via `SO_RCVBUF`
/// and returns the size the kernel reports afterwards.
///
/// The effective size is read back with `getsockopt` rather than assumed to
/// equal the request, so the caller can log or compare both values.
// NOTE(review): the kernel may adjust the requested size (see socket(7)); confirm
// whether callers should react when the reported value is far below the request.
///
/// # Errors
///
/// Returns the last OS error if either `setsockopt` or `getsockopt` fails.
fn set_socket_rcvbuf(fd: i32, desired_bytes: i32) -> io::Result<i32> {
    // SAFETY: the option pointer refers to a live `i32` for the duration of the
    // call and the length passed is exactly the size of that value.
    let ret = unsafe {
        libc::setsockopt(
            fd,
            libc::SOL_SOCKET,
            libc::SO_RCVBUF,
            &desired_bytes as *const _ as *const libc::c_void,
            // Use the platform's `socklen_t` instead of a hard-coded `u32`, matching
            // the `getsockopt` call below.
            mem::size_of_val(&desired_bytes) as libc::socklen_t,
        )
    };
    if ret != 0 {
        return Err(io::Error::last_os_error());
    }

    let mut actual: i32 = 0;
    let mut len = mem::size_of_val(&actual) as libc::socklen_t;
    // SAFETY: `actual` and `len` are live, writable locals and `len` is initialised
    // to the size of `actual`, so the kernel cannot write past either of them.
    let ret = unsafe {
        libc::getsockopt(
            fd,
            libc::SOL_SOCKET,
            libc::SO_RCVBUF,
            &mut actual as *mut _ as *mut libc::c_void,
            &mut len as *mut _,
        )
    };
    if ret != 0 {
        return Err(io::Error::last_os_error());
    }

    Ok(actual)
}
/// Queries the `PACKET_STATISTICS` option of packet socket `fd` and returns the
/// counters written by the kernel.
// NOTE(review): the caller treats the result as a per-read delta, which suggests
// the counters reset on read — confirm against packet(7).
///
/// # Errors
///
/// Returns the last OS error when `getsockopt` reports failure.
fn read_packet_socket_stats(fd: i32) -> io::Result<PacketSocketStats> {
    let mut counters = PacketSocketStats::default();
    let mut opt_len = mem::size_of_val(&counters) as libc::socklen_t;
    let counters_ptr = &mut counters as *mut _ as *mut libc::c_void;

    let rc = unsafe {
        libc::getsockopt(
            fd,
            SOL_PACKET,
            PACKET_STATISTICS,
            counters_ptr,
            &mut opt_len as *mut _,
        )
    };

    match rc {
        0 => Ok(counters),
        _ => Err(io::Error::last_os_error()),
    }
}
pub struct LinuxBpfTun {
fd: OwnedFd,
ifindex: i32,
@@ -350,6 +413,8 @@ impl LinuxBpfTun {
return Err(io::Error::last_os_error());
}
let actual_rcvbuf = set_socket_rcvbuf(fd.as_raw_fd(), DEFAULT_RCVBUF_BYTES)?;
let filter = build_tcp_filter(src_addr, dst_addr)?;
let mut prog = libc::sock_fprog {
len: filter
@@ -389,9 +454,16 @@ impl LinuxBpfTun {
let (tx, rx) = tokio::sync::mpsc::channel(1024);
let stop_clone = stop.clone();
let read_fd = fd.as_raw_fd();
let interface_name_for_worker = interface_name.to_string();
let worker = std::thread::spawn(move || {
let mut buf = vec![0u8; 65536];
let mut stats_enabled = true;
let mut total_packets: u64 = 0;
let mut total_drops: u64 = 0;
let mut total_bytes: u64 = 0;
let mut dropped_by_queue_full: u64 = 0;
let mut last_stats_log = Instant::now();
while !stop_clone.load(AtomicOrdering::Relaxed) {
let n = unsafe {
libc::recv(read_fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len(), 0)
@@ -410,8 +482,60 @@ impl LinuxBpfTun {
continue;
}
let data = buf[..(n as usize)].to_vec();
if tx.blocking_send(data).is_err() {
break;
total_bytes = total_bytes.wrapping_add(n as u64);
match tx.try_send(data) {
Ok(()) => {}
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
dropped_by_queue_full = dropped_by_queue_full.wrapping_add(1);
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => break,
}
if last_stats_log.elapsed() >= Duration::from_secs(1) {
if stats_enabled {
match read_packet_socket_stats(read_fd) {
Ok(delta) => {
total_packets = total_packets.wrapping_add(delta.tp_packets as u64);
total_drops = total_drops.wrapping_add(delta.tp_drops as u64);
let denom =
(delta.tp_packets as u64).saturating_add(delta.tp_drops as u64);
let drop_rate = if denom == 0 {
0.0
} else {
(delta.tp_drops as f64) / (denom as f64)
};
tracing::debug!(
"{}: delta_packets = {}, delta_drops = {}, delta_drop_rate = {}, total_packets = {}, total_drops = {}, total_bytes = {}, dropped_by_queue_full = {}",
interface_name_for_worker,
delta.tp_packets,
delta.tp_drops,
drop_rate,
total_packets,
total_drops,
total_bytes,
dropped_by_queue_full,
);
}
Err(e) => {
stats_enabled = false;
tracing::warn!(
?e,
interface_name_for_worker,
"LinuxBpfTun failed to read PACKET_STATISTICS, stats disabled"
);
}
}
} else {
tracing::debug!(
"{}: total_bytes = {}, dropped_by_queue_full = {}",
interface_name_for_worker,
total_bytes,
dropped_by_queue_full,
);
}
last_stats_log = Instant::now();
}
}
});
@@ -419,6 +543,8 @@ impl LinuxBpfTun {
tracing::info!(
interface_name,
ifindex,
desired_rcvbuf = DEFAULT_RCVBUF_BYTES,
actual_rcvbuf,
"LinuxBpfTun created with filter {:?}",
filter
);
+2 -30
View File
@@ -54,14 +54,12 @@ use std::sync::{
Arc, RwLock,
};
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::time;
use tracing::{info, trace, warn};
const TIMEOUT: time::Duration = time::Duration::from_secs(1);
const RETRIES: usize = 6;
const MPMC_BUFFER_LEN: usize = 512;
const MPSC_BUFFER_LEN: usize = 128;
const MAX_UNACKED_LEN: u32 = 128 * 1024 * 1024; // 128MB
#[async_trait::async_trait]
@@ -90,7 +88,6 @@ struct Shared {
tuples: RwLock<HashMap<AddrTuple, flume::Sender<Bytes>>>,
listening: RwLock<HashSet<u16>>,
tun: Arc<dyn Tun>,
ready: mpsc::Sender<Socket>,
tuples_purge: broadcast::Sender<AddrTuple>,
}
@@ -99,7 +96,6 @@ pub struct Stack {
local_ip: Ipv4Addr,
local_ip6: Option<Ipv6Addr>,
local_mac: MacAddr,
ready: mpsc::Receiver<Socket>,
reader_task: ScopedTask<()>,
}
@@ -206,11 +202,6 @@ impl Socket {
}
}
pub async fn recv_bytes(&self) -> Option<Vec<u8>> {
let mut buf = [0u8; 2048];
self.recv(&mut buf).await.map(|size| buf[..size].to_vec())
}
/// Attempt to receive a datagram from the other end.
///
/// This method takes `&self`, and it can be called safely by multiple threads
@@ -218,7 +209,7 @@ impl Socket {
///
/// A return of `None` means the TCP connection is broken
/// and this socket must be closed.
pub async fn recv(&self, buf: &mut [u8]) -> Option<usize> {
pub async fn recv(&self, buf: &mut BytesMut) -> Option<usize> {
tracing::trace!(
"Socket recv called, local_addr: {:?}, remote_addr: {:?}",
self.local_addr,
@@ -306,18 +297,7 @@ impl Socket {
continue;
}
if payload.len() >= buf.len() {
tracing::warn!(
"Payload len {} > buf len {}, tcp: {:?}, payload: {:?}",
payload.len(),
buf.len(),
tcp_packet,
payload
);
continue;
}
buf[..payload.len()].copy_from_slice(payload);
buf.extend_from_slice(payload);
return Some(payload.len());
}
@@ -412,13 +392,11 @@ impl Stack {
local_ip6: Option<Ipv6Addr>,
local_mac: Option<MacAddr>,
) -> Stack {
let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN);
let (tuples_purge_tx, _tuples_purge_rx) = broadcast::channel(16);
let shared = Arc::new(Shared {
tuples: RwLock::new(HashMap::new()),
tun: tun.clone(),
listening: RwLock::new(HashSet::new()),
ready: ready_tx,
tuples_purge: tuples_purge_tx.clone(),
});
@@ -433,7 +411,6 @@ impl Stack {
local_ip,
local_ip6,
local_mac: local_mac.unwrap_or(MacAddr::zero()),
ready: ready_rx,
reader_task: t.into(),
}
}
@@ -448,11 +425,6 @@ impl Stack {
assert!(self.shared.listening.write().unwrap().insert(port));
}
/// Accepts an incoming connection.
pub async fn accept(&mut self) -> Socket {
self.ready.recv().await.unwrap()
}
pub async fn alloc_established_socket(
&mut self,
local_addr: SocketAddr,