Add fake tcp tunnel (experimental) (#1673)

support faketcp to avoid tcp-over-tcp problem.
linux/macos/windows are supported.

better to be used in internet env, the maximum 
performance is majorly limited by windivert/raw socket.
This commit is contained in:
KKRainbow
2025-12-25 00:10:32 +08:00
committed by GitHub
parent 0712ef762d
commit 28cd6da502
27 changed files with 3744 additions and 19 deletions
+3 -3
View File
@@ -245,13 +245,13 @@ jobs:
# windows is the only OS using a different convention for executable file name
if [[ $OS =~ ^windows.*$ && $TARGET =~ ^x86_64.*$ ]]; then
SUFFIX=.exe
cp easytier/third_party/*.dll ./artifacts/objects/
cp easytier/third_party/x86_64/* ./artifacts/objects/
elif [[ $OS =~ ^windows.*$ && $TARGET =~ ^i686.*$ ]]; then
SUFFIX=.exe
cp easytier/third_party/i686/*.dll ./artifacts/objects/
cp easytier/third_party/i686/* ./artifacts/objects/
elif [[ $OS =~ ^windows.*$ && $TARGET =~ ^aarch64.*$ ]]; then
SUFFIX=.exe
cp easytier/third_party/arm64/*.dll ./artifacts/objects/
cp easytier/third_party/arm64/* ./artifacts/objects/
fi
if [[ $GITHUB_REF_TYPE =~ ^tag$ ]]; then
TAG=$GITHUB_REF_NAME
+3 -3
View File
@@ -170,11 +170,11 @@ jobs:
if: ${{ matrix.OS == 'windows-latest' }}
run: |
if [[ $GUI_TARGET =~ ^aarch64.*$ ]]; then
cp ./easytier/third_party/arm64/*.dll ./easytier-gui/src-tauri/
cp ./easytier/third_party/arm64/* ./easytier-gui/src-tauri/
elif [[ $GUI_TARGET =~ ^i686.*$ ]]; then
cp ./easytier/third_party/i686/*.dll ./easytier-gui/src-tauri/
cp ./easytier/third_party/i686/* ./easytier-gui/src-tauri/
else
cp ./easytier/third_party/*.dll ./easytier-gui/src-tauri/
cp ./easytier/third_party/x86_64/* ./easytier-gui/src-tauri/
fi
- name: Build GUI
+1
View File
@@ -38,6 +38,7 @@ node_modules
.vite
easytier-gui/src-tauri/*.dll
easytier-gui/src-tauri/*.sys
/easytier-contrib/easytier-ohrs/dist/
.direnv
Generated
+60 -1
View File
@@ -2095,6 +2095,7 @@ dependencies = [
"bytecodec",
"byteorder",
"bytes",
"cfg-if",
"chrono",
"cidr",
"clap",
@@ -2107,6 +2108,7 @@ dependencies = [
"derive_builder",
"easytier-rpc-build 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"encoding",
"flume 0.12.0",
"futures",
"futures-util",
"gethostname 0.5.0",
@@ -2195,6 +2197,7 @@ dependencies = [
"which 7.0.3",
"wildmatch",
"winapi",
"windivert",
"windows 0.52.0",
"windows-service",
"windows-sys 0.52.0",
@@ -2577,6 +2580,15 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "etherparse"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "827292ea592108849932ad8e30218f8b1f21c0dfd0696698a18b5d0aed62d990"
dependencies = [
"arrayvec",
]
[[package]]
name = "event-listener"
version = "5.3.1"
@@ -2615,6 +2627,9 @@ name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
dependencies = [
"getrandom 0.2.15",
]
[[package]]
name = "fdeflate"
@@ -2675,6 +2690,18 @@ dependencies = [
"spin",
]
[[package]]
name = "flume"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be"
dependencies = [
"fastrand",
"futures-core",
"futures-sink",
"spin",
]
[[package]]
name = "fnv"
version = "1.0.7"
@@ -8175,7 +8202,7 @@ checksum = "d5b2cf34a45953bfd3daaf3db0f7a7878ab9b7a6b91b422d24a7a9e4c857b680"
dependencies = [
"atoi",
"chrono",
"flume",
"flume 0.11.0",
"futures-channel",
"futures-core",
"futures-executor",
@@ -10327,6 +10354,29 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windivert"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc6b6833a760d1c36b489314a5541a12a39d162dc8341d8f6f400212b96d3df1"
dependencies = [
"etherparse",
"thiserror 1.0.63",
"windivert-sys",
"windows 0.48.0",
]
[[package]]
name = "windivert-sys"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "832bc4af9272458a8a64395b3aabe10dc4089546486fcbd0e19b9b6d28ba6e54"
dependencies = [
"cc",
"thiserror 1.0.63",
"windows 0.48.0",
]
[[package]]
name = "window-vibrancy"
version = "0.6.0"
@@ -10342,6 +10392,15 @@ dependencies = [
"windows-version",
]
[[package]]
name = "windows"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f"
dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "windows"
version = "0.52.0"
@@ -3,7 +3,8 @@
"externalBin": [],
"resources": [
"./wintun.dll",
"./Packet.dll"
"./Packet.dll",
"./*.sys"
],
"windows": {
"webviewInstallMode": {
+7
View File
@@ -220,6 +220,10 @@ hmac = "0.12.1"
sha2 = "0.10.8"
shellexpand = "3.1.1"
# for fake tcp
flume = "0.12"
cfg-if = "1.0"
[target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "freebsd"))'.dependencies]
machine-uid = "0.5.3"
@@ -233,6 +237,9 @@ resolv-conf = "0.7.3"
dbus = { version = "0.9.7", features = ["vendored"] }
which = "7.0.3"
[target.'cfg(all(windows, any(target_arch = "x86_64", target_arch = "x86")))'.dependencies]
windivert = { version = "0.6.0", features = ["static"] }
[target.'cfg(windows)'.dependencies]
windows = { version = "0.52.0", features = [
"Win32_Foundation",
+1 -1
View File
@@ -70,7 +70,7 @@ impl WindowsBuild {
let target = std::env::var("TARGET").unwrap_or_default();
if target.contains("x86_64") {
println!("cargo:rustc-link-search=native=easytier/third_party/");
println!("cargo:rustc-link-search=native=easytier/third_party/x86_64/");
} else if target.contains("i686") {
println!("cargo:rustc-link-search=native=easytier/third_party/i686/");
} else if target.contains("aarch64") {
+17 -2
View File
@@ -12,8 +12,9 @@ use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, idn, network::IPCollector},
tunnel::{
check_scheme_and_get_socket_addr, ring::RingTunnelConnector, tcp::TcpTunnelConnector,
udp::UdpTunnelConnector, IpVersion, TunnelConnector,
check_scheme_and_get_socket_addr, fake_tcp::FakeTcpTunnelConnector,
ring::RingTunnelConnector, tcp::TcpTunnelConnector, udp::UdpTunnelConnector, IpVersion,
TunnelConnector,
},
};
@@ -157,6 +158,20 @@ pub async fn create_connector_by_url(
let connector = dns_connector::DNSTunnelConnector::new(url, global_ctx.clone());
Box::new(connector)
}
"faketcp" => {
let dst_addr =
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "faketcp", ip_version).await?;
let mut connector = FakeTcpTunnelConnector::new(url);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
}
Box::new(connector)
}
_ => {
return Err(Error::InvalidUrl(url.into()));
}
+4 -3
View File
@@ -21,8 +21,8 @@ use crate::{
},
peers::peer_manager::PeerManager,
tunnel::{
ring::RingTunnelListener, tcp::TcpTunnelListener, udp::UdpTunnelListener, Tunnel,
TunnelListener,
fake_tcp::FakeTcpTunnelListener, ring::RingTunnelListener, tcp::TcpTunnelListener,
udp::UdpTunnelListener, Tunnel, TunnelListener,
},
};
@@ -49,6 +49,7 @@ pub fn get_listener_by_url(
use crate::tunnel::websocket::WSTunnelListener;
Box::new(WSTunnelListener::new(l.clone()))
}
"faketcp" => Box::new(FakeTcpTunnelListener::new(l.clone())),
_ => {
return Err(Error::InvalidUrl(l.to_string()));
}
@@ -143,7 +144,7 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
&& !is_url_host_ipv6(&l)
&& is_url_host_unspecified(&l)
// quic enables dual-stack by default, may conflict with v4 listener
&& l.scheme() != "quic"
&& l.scheme() != "quic" && l.scheme() != "faketcp"
{
let mut ipv6_listener = l.clone();
ipv6_listener
+2 -2
View File
@@ -223,8 +223,8 @@ impl PeerConn {
if peer_mgr_hdr.packet_type != PacketType::HandShake as u8 {
return Err(Error::WaitRespError(format!(
"unexpected packet type: {:?}",
peer_mgr_hdr.packet_type
"unexpected packet type: {:?}, packet: {:?}",
peer_mgr_hdr.packet_type, rsp
)));
}
+5
View File
@@ -534,6 +534,11 @@ pub mod tests {
let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();
println!("connect: {:?}", tunnel.info());
if connector.remote_url().scheme() == "faketcp" {
// listener need some time to start capturing packet
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
assert_eq!(
url::Url::from(tunnel.info().unwrap().remote_addr.unwrap()),
connector.remote_url(),
+6
View File
@@ -0,0 +1,6 @@
Copyright 2021-2025 Datong Sun dndx@idndx.com
Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
https://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
https://opensource.org/licenses/MIT>, at your option. Files in the project may
not be copied, modified, or distributed except according to those terms.
+482
View File
@@ -0,0 +1,482 @@
mod netfilter;
mod packet;
mod stack;
use std::net::{IpAddr, Ipv4Addr, UdpSocket};
use std::sync::Arc;
use std::{net::SocketAddr, pin::Pin};
use bytes::BytesMut;
use pnet::datalink;
use pnet::util::MacAddr;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use crate::common::scoped_task::ScopedTask;
use crate::tunnel::fake_tcp::netfilter::create_tun;
use crate::tunnel::{common::TunnelWrapper, Tunnel, TunnelError, TunnelInfo, TunnelListener};
use futures::Future;
use dashmap::DashMap;
struct IpToIfNameCache {
ip_to_ifname: DashMap<IpAddr, (String, Option<MacAddr>)>,
}
impl IpToIfNameCache {
fn new() -> Self {
Self {
ip_to_ifname: DashMap::new(),
}
}
fn reload_ip_to_ifname(&self) {
self.ip_to_ifname.clear();
let interfaces = datalink::interfaces();
for iface in interfaces {
for ip in iface.ips.iter() {
self.ip_to_ifname
.insert(ip.ip(), (iface.name.clone(), iface.mac));
}
}
}
fn get_ifname(&self, ip: &IpAddr) -> Option<(String, Option<MacAddr>)> {
if let Some(ifname) = self.ip_to_ifname.get(ip) {
Some(ifname.clone())
} else {
self.reload_ip_to_ifname();
self.ip_to_ifname.get(ip).map(|s| s.clone())
}
}
}
fn get_faketcp_tunnel_type_str(driver_type: &str) -> String {
format!("faketcp_{}", driver_type)
}
pub struct FakeTcpTunnelListener {
addr: url::Url,
os_listener: Option<tokio::net::TcpListener>,
// interface_name -> fake tcp stack
stack_map: DashMap<String, Arc<Mutex<stack::Stack>>>,
// a cache from ip addr to interface name
ip_to_ifname: IpToIfNameCache,
}
impl FakeTcpTunnelListener {
pub fn new(addr: url::Url) -> Self {
// Define filter: Capture all packets (or refine this if needed)
// For FakeTCP, we probably want to capture packets destined to us?
// But `stack::Stack` handles IP/TCP logic.
// Maybe we just capture everything for now as a raw tunnel?
// Or better, filter based on some criteria?
// The user said "satisfy filter function".
// Let's create a filter that accepts everything for now, or maybe only IP packets?
FakeTcpTunnelListener {
addr,
os_listener: None,
stack_map: DashMap::new(),
ip_to_ifname: IpToIfNameCache::new(),
}
}
async fn do_accept(&mut self) -> Result<AcceptResult, TunnelError> {
loop {
match self.os_listener.as_mut().unwrap().accept().await {
Ok((s, remote_addr)) => {
let Ok(local_addr) = s.local_addr() else {
tracing::warn!("accept fail with local_addr error");
continue;
};
let Some((interface_name, mac)) =
self.ip_to_ifname.get_ifname(&local_addr.ip())
else {
tracing::warn!("accept fail with interface_name error");
continue;
};
return Ok(AcceptResult {
socket: s,
local_addr,
remote_addr,
interface_name,
mac,
});
}
Err(e) => {
use std::io::ErrorKind::*;
if matches!(
e.kind(),
NotConnected | ConnectionAborted | ConnectionRefused | ConnectionReset
) {
tracing::warn!(?e, "accept fail with retryable error: {:?}", e);
continue;
}
tracing::warn!(?e, "accept fail");
return Err(e.into());
}
}
}
}
async fn get_stack(
&self,
accept_result: &AcceptResult,
) -> Result<Arc<Mutex<stack::Stack>>, TunnelError> {
let local_socket_addr = accept_result.local_addr;
let interface_name = &accept_result.interface_name;
let (local_ip, local_ip6) = match local_socket_addr.ip() {
IpAddr::V4(ip) => (Some(ip), None),
IpAddr::V6(ip) => (None, Some(ip)),
};
let ret = self
.stack_map
.entry(interface_name.to_string())
.or_insert_with(|| {
let tun = create_tun(interface_name, None, local_socket_addr);
tracing::info!(
?local_socket_addr,
"create new stack with interface_name: {:?}",
interface_name
);
// TODO: Get local MAC address of the interface
Arc::new(Mutex::new(stack::Stack::new(
tun,
local_ip.unwrap_or(Ipv4Addr::UNSPECIFIED),
local_ip6,
accept_result.mac,
)))
})
.clone();
Ok(ret)
}
}
fn build_os_socket_reader_task(mut socket: TcpStream) -> ScopedTask<()> {
let os_socket_reader_task: ScopedTask<()> = tokio::spawn(async move {
// read the os socket until it's closed
let mut buf = [0u8; 1024];
while let Ok(size) = socket.read(&mut buf).await {
tracing::trace!("read {} bytes from os socket", size);
if size == 0 {
break;
}
}
tracing::info!("FakeTcpTunnelListener os socket closed");
})
.into();
os_socket_reader_task
}
#[derive(Debug)]
struct AcceptResult {
socket: TcpStream,
local_addr: SocketAddr,
remote_addr: SocketAddr,
interface_name: String,
mac: Option<MacAddr>,
}
#[async_trait::async_trait]
impl TunnelListener for FakeTcpTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
let port = self.addr.port().unwrap_or(0);
let bind_addr = crate::tunnel::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"faketcp",
crate::tunnel::IpVersion::Both,
)
.await?;
let os_listener = tokio::net::TcpListener::bind(bind_addr).await?;
tracing::info!(port, "FakeTcpTunnelListener listening");
self.os_listener = Some(os_listener);
// self.stack.lock().await.listen(port);
Ok(())
}
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
tracing::debug!("FakeTcpTunnelListener waiting for accept");
let res = self.do_accept().await?;
let stack = self.get_stack(&res).await?;
let socket = stack
.lock()
.await
.alloc_established_socket(res.local_addr, res.remote_addr, stack::State::Established)
.await;
tracing::info!(
?res,
remote = socket.remote_addr().to_string(),
"FakeTcpTunnelListener accepted connection"
);
let info = TunnelInfo {
tunnel_type: get_faketcp_tunnel_type_str(stack.lock().await.driver_type()),
local_addr: Some(self.local_url().into()),
remote_addr: Some(
crate::tunnel::build_url_from_socket_addr(
&socket.remote_addr().to_string(),
"faketcp",
)
.into(),
),
};
// We treat the fake tcp socket as a datagram tunnel directly
// The reader/writer will interface with the socket using recv_bytes/send
// We need to adapt the socket to ZCPacketStream and ZCPacketSink
// Since FakeTcpTunnel is a datagram tunnel, we don't need FramedReader/Writer (which are for stream based tunnels like TCP)
// We should wrap the socket into something that produces/consumes ZCPacket directly.
let socket = Arc::new(socket);
let reader = FakeTcpStream::new(socket.clone());
let writer = FakeTcpSink::new(socket);
Ok(Box::new(TunnelWrapper::new_with_associate_data(
reader,
writer,
Some(info),
Some(Box::new(build_os_socket_reader_task(res.socket))),
)))
}
fn local_url(&self) -> url::Url {
self.addr.clone()
}
}
pub struct FakeTcpTunnelConnector {
addr: url::Url,
ip_to_if_name: IpToIfNameCache,
}
impl FakeTcpTunnelConnector {
pub fn new(addr: url::Url) -> Self {
FakeTcpTunnelConnector {
addr,
ip_to_if_name: IpToIfNameCache::new(),
}
}
}
fn get_local_ip_for_destination(destination: IpAddr) -> Option<IpAddr> {
// 使用一个不可路由的、私有的、或回环地址创建一个临时的 socket,让内核自动选择源接口。
// 对于 IPv4,使用 0.0.0.0; 对于 IPv6,使用 ::
let bind_addr = if destination.is_ipv4() {
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))
} else {
IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0))
};
// 绑定到一个临时端口 (0)
let socket = UdpSocket::bind((bind_addr, 0)).ok()?;
// 尝试连接到目标地址。这不会真正发送数据包,只是让内核确定路由。
socket.connect((destination, 80)).ok()?; // 使用一个常见的端口,例如 80
// 获取 socket 的本地地址信息
socket.local_addr().map(|addr| addr.ip()).ok()
}
#[async_trait::async_trait]
impl crate::tunnel::TunnelConnector for FakeTcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let remote_addr = crate::tunnel::check_scheme_and_get_socket_addr::<SocketAddr>(
&self.addr,
"faketcp",
crate::tunnel::IpVersion::Both,
)
.await?;
let local_ip = get_local_ip_for_destination(remote_addr.ip())
.ok_or(TunnelError::InternalError("Failed to get local ip".into()))?;
let os_socket = tokio::net::TcpSocket::new_v4()?;
os_socket.bind("0.0.0.0:0".parse().unwrap())?;
let local_port = os_socket.local_addr()?.port();
let local_addr = SocketAddr::new(local_ip, local_port);
let (interface_name, mac) =
self.ip_to_if_name
.get_ifname(&local_ip)
.ok_or(TunnelError::InternalError(
"Failed to get interface name".into(),
))?;
let (local_ip, local_ip6) = match local_ip {
IpAddr::V4(ip) => (Some(ip), None),
IpAddr::V6(ip) => (None, Some(ip)),
};
let tun = create_tun(&interface_name, Some(remote_addr), local_addr);
let local_ip = local_ip.unwrap_or("0.0.0.0".parse().unwrap());
let mut stack = stack::Stack::new(tun, local_ip, local_ip6, mac);
let driver_type = stack.driver_type();
let socket = stack
.alloc_established_socket(local_addr, remote_addr, stack::State::SynSent)
.await;
let os_stream = os_socket.connect(remote_addr).await?;
tracing::info!(?remote_addr, "FakeTcpTunnelConnector connecting");
socket.recv_bytes().await.ok_or(TunnelError::InternalError(
"Failed to recv bytes to establish connection".into(),
))?;
tracing::info!(local_addr = ?socket.local_addr(), "FakeTcpTunnelConnector connected");
let info = TunnelInfo {
tunnel_type: get_faketcp_tunnel_type_str(driver_type),
local_addr: Some(
crate::tunnel::build_url_from_socket_addr(
&socket.local_addr().to_string(),
"faketcp",
)
.into(),
),
remote_addr: Some(self.addr.clone().into()),
};
let socket = Arc::new(socket);
let reader = FakeTcpStream::new(socket.clone());
let writer = FakeTcpSink::new(socket);
Ok(Box::new(TunnelWrapper::new_with_associate_data(
reader,
writer,
Some(info),
Some(Box::new((build_os_socket_reader_task(os_stream), stack))),
)))
}
fn remote_url(&self) -> url::Url {
self.addr.clone()
}
}
use crate::tunnel::packet_def::{ZCPacket, ZCPacketType};
use crate::tunnel::{SinkError, SinkItem, StreamItem};
use futures::{Sink, Stream};
use std::task::{Context as TaskContext, Poll};
struct FakeTcpStream {
socket: Arc<stack::Socket>,
#[allow(clippy::type_complexity)]
recv_fut: Option<Pin<Box<dyn Future<Output = Option<Vec<u8>>> + Send + Sync>>>,
}
impl FakeTcpStream {
fn new(socket: Arc<stack::Socket>) -> Self {
Self {
socket,
recv_fut: None,
}
}
}
impl Stream for FakeTcpStream {
type Item = StreamItem;
fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
let s = self.get_mut();
if s.recv_fut.is_none() {
let socket = s.socket.clone();
s.recv_fut = Some(Box::pin(async move { socket.recv_bytes().await }));
}
match s.recv_fut.as_mut().unwrap().as_mut().poll(cx) {
Poll::Ready(Some(data)) => {
let mut buf = BytesMut::new();
buf.extend_from_slice(&data);
let packet = ZCPacket::new_from_buf(buf, ZCPacketType::DummyTunnel);
s.recv_fut = None;
Poll::Ready(Some(Ok(packet)))
}
Poll::Ready(None) => {
// 连接关闭
s.recv_fut = None;
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
}
struct FakeTcpSink {
socket: Arc<stack::Socket>,
}
impl FakeTcpSink {
fn new(socket: Arc<stack::Socket>) -> Self {
Self { socket }
}
}
impl Sink<SinkItem> for FakeTcpSink {
type Error = SinkError;
fn poll_ready(
self: Pin<&mut Self>,
_cx: &mut TaskContext<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
// We need to send the packet as bytes
// The item is ZCPacket, which has into_bytes() method
let bytes = item.convert_type(ZCPacketType::DummyTunnel).into_bytes();
// Let's just spawn for now as a simple implementation, noting the limitation.
self.socket.try_send(&bytes);
Ok(())
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut TaskContext<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(
self: Pin<&mut Self>,
_cx: &mut TaskContext<'_>,
) -> Poll<Result<(), Self::Error>> {
self.socket.close();
Poll::Ready(Ok(()))
}
}
#[cfg(test)]
mod tests {
use crate::tunnel::common::tests::_tunnel_pingpong;
use super::*;
#[tokio::test]
async fn faketcp_pingpong() {
#[cfg(target_family = "unix")]
{
if unsafe { nix::libc::geteuid() } != 0 {
return;
}
}
let listener = FakeTcpTunnelListener::new("faketcp://0.0.0.0:31011".parse().unwrap());
let connector = FakeTcpTunnelConnector::new("faketcp://127.0.0.1:31011".parse().unwrap());
_tunnel_pingpong(listener, connector).await
}
}
@@ -0,0 +1,708 @@
use bytes::Bytes;
use bytes::BytesMut;
use nix::libc;
use std::ffi::CString;
use std::io;
use std::mem;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};
use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::tunnel::fake_tcp::stack;
const ETH_HDR_LEN: usize = 14;
const ETH_TYPE_OFFSET: u32 = 12;
const ETHERTYPE_IPV4: u32 = 0x0800;
const ETHERTYPE_IPV6: u32 = 0x86DD;
const IPPROTO_TCP_U32: u32 = 6;
const BPF_LD: u16 = 0x00;
const BPF_LDX: u16 = 0x01;
const BPF_JMP: u16 = 0x05;
const BPF_RET: u16 = 0x06;
const BPF_W: u16 = 0x00;
const BPF_H: u16 = 0x08;
const BPF_B: u16 = 0x10;
const BPF_ABS: u16 = 0x20;
const BPF_IND: u16 = 0x40;
const BPF_MSH: u16 = 0xa0;
const BPF_JA: u16 = 0x00;
const BPF_JEQ: u16 = 0x10;
const BPF_K: u16 = 0x00;
fn stmt(code: u16, k: u32) -> libc::sock_filter {
libc::sock_filter {
code,
jt: 0,
jf: 0,
k,
}
}
fn jeq(k: u32, jt: u8, jf: u8) -> libc::sock_filter {
libc::sock_filter {
code: BPF_JMP | BPF_JEQ | BPF_K,
jt,
jf,
k,
}
}
fn ja(k: u32) -> libc::sock_filter {
libc::sock_filter {
code: BPF_JMP | BPF_JA,
jt: 0,
jf: 0,
k,
}
}
#[derive(Clone, Copy)]
struct Label(usize);
struct JeqPatch {
idx: usize,
t: Label,
f: Label,
}
struct JaPatch {
idx: usize,
target: Label,
}
struct BpfBuilder {
insns: Vec<libc::sock_filter>,
labels: Vec<Option<usize>>,
jeq_patches: Vec<JeqPatch>,
ja_patches: Vec<JaPatch>,
}
impl BpfBuilder {
fn new() -> Self {
Self {
insns: Vec::new(),
labels: Vec::new(),
jeq_patches: Vec::new(),
ja_patches: Vec::new(),
}
}
fn new_label(&mut self) -> Label {
let idx = self.labels.len();
self.labels.push(None);
Label(idx)
}
fn set_label(&mut self, label: Label) {
self.labels[label.0] = Some(self.insns.len());
}
fn push(&mut self, insn: libc::sock_filter) {
self.insns.push(insn);
}
fn push_jeq(&mut self, k: u32, t: Label, f: Label) {
let idx = self.insns.len();
self.insns.push(jeq(k, 0, 0));
self.jeq_patches.push(JeqPatch { idx, t, f });
}
fn push_ja(&mut self, target: Label) {
let idx = self.insns.len();
self.insns.push(ja(0));
self.ja_patches.push(JaPatch { idx, target });
}
fn finish(mut self) -> io::Result<Vec<libc::sock_filter>> {
for patch in self.jeq_patches {
let JeqPatch { idx, t, f } = patch;
let cur = idx + 1;
let t_pos =
self.labels.get(t.0).and_then(|v| *v).ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "unresolved label")
})?;
let f_pos =
self.labels.get(f.0).and_then(|v| *v).ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "unresolved label")
})?;
if t_pos < cur || f_pos < cur {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"backward bpf jump",
));
}
let jt: u8 = (t_pos - cur)
.try_into()
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bpf jump too far"))?;
let jf: u8 = (f_pos - cur)
.try_into()
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bpf jump too far"))?;
self.insns[idx].jt = jt;
self.insns[idx].jf = jf;
}
for patch in self.ja_patches {
let JaPatch { idx, target } = patch;
let cur = idx + 1;
let t_pos =
self.labels.get(target.0).and_then(|v| *v).ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "unresolved label")
})?;
if t_pos < cur {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"backward bpf jump",
));
}
self.insns[idx].k = (t_pos - cur) as u32;
}
Ok(self.insns)
}
}
fn build_tcp_filter(
src_addr: Option<SocketAddr>,
dst_addr: SocketAddr,
) -> io::Result<Vec<libc::sock_filter>> {
if let Some(src) = src_addr {
if src.is_ipv4() != dst_addr.is_ipv4() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"src/dst addr family mismatch",
));
}
}
let mut b = BpfBuilder::new();
let l_check_ipv6 = b.new_label();
let l_ipv4 = b.new_label();
let l_ipv6 = b.new_label();
let l_accept = b.new_label();
let l_reject = b.new_label();
b.push(stmt(BPF_LD | BPF_H | BPF_ABS, ETH_TYPE_OFFSET));
b.push_jeq(ETHERTYPE_IPV4, l_ipv4, l_check_ipv6);
b.set_label(l_check_ipv6);
b.push_jeq(ETHERTYPE_IPV6, l_ipv6, l_reject);
if dst_addr.is_ipv4() {
b.set_label(l_ipv4);
let l_v4_proto_ok = b.new_label();
b.push(stmt(BPF_LD | BPF_B | BPF_ABS, (ETH_HDR_LEN + 9) as u32));
b.push_jeq(IPPROTO_TCP_U32, l_v4_proto_ok, l_reject);
b.set_label(l_v4_proto_ok);
let dst_ip = match dst_addr.ip() {
IpAddr::V4(ip) => u32::from(ip),
_ => unreachable!(),
};
let l_v4_dstip_ok = b.new_label();
b.push(stmt(BPF_LD | BPF_W | BPF_ABS, (ETH_HDR_LEN + 16) as u32));
b.push_jeq(dst_ip, l_v4_dstip_ok, l_reject);
b.set_label(l_v4_dstip_ok);
if let Some(src) = src_addr {
let src_ip = match src.ip() {
IpAddr::V4(ip) => u32::from(ip),
_ => unreachable!(),
};
let l_v4_srcip_ok = b.new_label();
b.push(stmt(BPF_LD | BPF_W | BPF_ABS, (ETH_HDR_LEN + 12) as u32));
b.push_jeq(src_ip, l_v4_srcip_ok, l_reject);
b.set_label(l_v4_srcip_ok);
}
b.push(stmt(BPF_LDX | BPF_B | BPF_MSH, ETH_HDR_LEN as u32));
let l_v4_dstport_ok = b.new_label();
b.push(stmt(BPF_LD | BPF_H | BPF_IND, (ETH_HDR_LEN + 2) as u32));
b.push_jeq(dst_addr.port() as u32, l_v4_dstport_ok, l_reject);
b.set_label(l_v4_dstport_ok);
if let Some(src) = src_addr {
b.push(stmt(BPF_LD | BPF_H | BPF_IND, ETH_HDR_LEN as u32));
b.push_jeq(src.port() as u32, l_accept, l_reject);
} else {
b.push_ja(l_accept);
}
} else {
b.set_label(l_ipv6);
let l_v6_proto_ok = b.new_label();
b.push(stmt(BPF_LD | BPF_B | BPF_ABS, (ETH_HDR_LEN + 6) as u32));
b.push_jeq(IPPROTO_TCP_U32, l_v6_proto_ok, l_reject);
b.set_label(l_v6_proto_ok);
let dst_ip = match dst_addr.ip() {
IpAddr::V6(ip) => ip.octets(),
_ => unreachable!(),
};
for (i, chunk) in dst_ip.chunks_exact(4).enumerate() {
let off = ETH_HDR_LEN + 24 + (i * 4);
let v = u32::from_be_bytes(chunk.try_into().unwrap());
let l_v6_dstip_word_ok = b.new_label();
b.push(stmt(BPF_LD | BPF_W | BPF_ABS, off as u32));
b.push_jeq(v, l_v6_dstip_word_ok, l_reject);
b.set_label(l_v6_dstip_word_ok);
}
if let Some(src) = src_addr {
let src_ip = match src.ip() {
IpAddr::V6(ip) => ip.octets(),
_ => unreachable!(),
};
for (i, chunk) in src_ip.chunks_exact(4).enumerate() {
let off = ETH_HDR_LEN + 8 + (i * 4);
let v = u32::from_be_bytes(chunk.try_into().unwrap());
let l_v6_srcip_word_ok = b.new_label();
b.push(stmt(BPF_LD | BPF_W | BPF_ABS, off as u32));
b.push_jeq(v, l_v6_srcip_word_ok, l_reject);
b.set_label(l_v6_srcip_word_ok);
}
}
let l_v6_dstport_ok = b.new_label();
b.push(stmt(
BPF_LD | BPF_H | BPF_ABS,
(ETH_HDR_LEN + 40 + 2) as u32,
));
b.push_jeq(dst_addr.port() as u32, l_v6_dstport_ok, l_reject);
b.set_label(l_v6_dstport_ok);
if let Some(src) = src_addr {
b.push(stmt(BPF_LD | BPF_H | BPF_ABS, (ETH_HDR_LEN + 40) as u32));
b.push_jeq(src.port() as u32, l_accept, l_reject);
} else {
b.push_ja(l_accept);
}
}
b.set_label(l_accept);
b.push(stmt(BPF_RET | BPF_K, 0xFFFF));
b.set_label(l_reject);
if dst_addr.is_ipv4() {
b.set_label(l_ipv6);
} else {
b.set_label(l_ipv4);
}
b.push(stmt(BPF_RET | BPF_K, 0));
b.finish()
}
pub struct LinuxBpfTun {
fd: OwnedFd,
ifindex: i32,
stop: Arc<AtomicBool>,
worker: Option<std::thread::JoinHandle<()>>,
recv_queue: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
}
impl LinuxBpfTun {
pub fn new(
interface_name: &str,
src_addr: Option<SocketAddr>,
dst_addr: SocketAddr,
) -> io::Result<Self> {
let c_ifname = CString::new(interface_name)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid interface name"))?;
let ifindex = unsafe { libc::if_nametoindex(c_ifname.as_ptr()) as i32 };
if ifindex <= 0 {
return Err(io::Error::new(
io::ErrorKind::NotFound,
"interface not found",
));
}
let proto: i32 = (libc::ETH_P_ALL as u16).to_be() as i32;
let fd = unsafe { libc::socket(libc::AF_PACKET, libc::SOCK_RAW, proto) };
if fd < 0 {
return Err(io::Error::last_os_error());
}
let fd = unsafe { OwnedFd::from_raw_fd(fd) };
let mut addr: libc::sockaddr_ll = unsafe { mem::zeroed() };
addr.sll_family = libc::AF_PACKET as u16;
addr.sll_protocol = (libc::ETH_P_ALL as u16).to_be();
addr.sll_ifindex = ifindex;
let bind_ret = unsafe {
libc::bind(
fd.as_raw_fd(),
&addr as *const _ as *const libc::sockaddr,
mem::size_of::<libc::sockaddr_ll>() as u32,
)
};
if bind_ret != 0 {
return Err(io::Error::last_os_error());
}
let filter = build_tcp_filter(src_addr, dst_addr)?;
let mut prog = libc::sock_fprog {
len: filter
.len()
.try_into()
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bpf program too long"))?,
filter: filter.as_ptr() as *mut libc::sock_filter,
};
let opt_ret = unsafe {
libc::setsockopt(
fd.as_raw_fd(),
libc::SOL_SOCKET,
libc::SO_ATTACH_FILTER,
&mut prog as *mut _ as *mut libc::c_void,
mem::size_of::<libc::sock_fprog>() as u32,
)
};
if opt_ret != 0 {
return Err(io::Error::last_os_error());
}
let timeout = libc::timeval {
tv_sec: 0,
tv_usec: 200_000,
};
let _ = unsafe {
libc::setsockopt(
fd.as_raw_fd(),
libc::SOL_SOCKET,
libc::SO_RCVTIMEO,
&timeout as *const _ as *const libc::c_void,
mem::size_of::<libc::timeval>() as u32,
)
};
let stop = Arc::new(AtomicBool::new(false));
let (tx, rx) = tokio::sync::mpsc::channel(1024);
let stop_clone = stop.clone();
let read_fd = fd.as_raw_fd();
let worker = std::thread::spawn(move || {
let mut buf = vec![0u8; 65536];
while !stop_clone.load(AtomicOrdering::Relaxed) {
let n = unsafe {
libc::recv(read_fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len(), 0)
};
if n < 0 {
let err = io::Error::last_os_error();
if matches!(
err.kind(),
io::ErrorKind::Interrupted | io::ErrorKind::WouldBlock
) {
continue;
}
break;
}
if n == 0 {
continue;
}
let data = buf[..(n as usize)].to_vec();
if tx.blocking_send(data).is_err() {
break;
}
}
});
tracing::info!(
interface_name,
ifindex,
"LinuxBpfTun created with filter {:?}",
filter
);
Ok(Self {
fd,
ifindex,
stop,
worker: Some(worker),
recv_queue: Mutex::new(rx),
})
}
}
impl Drop for LinuxBpfTun {
fn drop(&mut self) {
self.stop.store(true, AtomicOrdering::Relaxed);
let _ = unsafe { libc::shutdown(self.fd.as_raw_fd(), libc::SHUT_RD) };
if let Some(worker) = self.worker.take() {
let _ = worker.join();
}
}
}
#[async_trait::async_trait]
impl stack::Tun for LinuxBpfTun {
async fn recv(&self, packet: &mut BytesMut) -> Result<usize, std::io::Error> {
let mut rx = self.recv_queue.lock().await;
match rx.recv().await {
Some(data) => {
packet.extend_from_slice(&data);
Ok(data.len())
}
None => Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"LinuxBpfTun channel closed",
)),
}
}
fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error> {
if packet.len() < 6 {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"packet too short",
));
}
let mut addr: libc::sockaddr_ll = unsafe { mem::zeroed() };
addr.sll_family = libc::AF_PACKET as u16;
addr.sll_protocol = (libc::ETH_P_ALL as u16).to_be();
addr.sll_ifindex = self.ifindex;
addr.sll_halen = 6;
addr.sll_addr[..6].copy_from_slice(&packet[..6]);
let ret = unsafe {
libc::sendto(
self.fd.as_raw_fd(),
packet.as_ptr() as *const libc::c_void,
packet.len(),
0,
&addr as *const _ as *const libc::sockaddr,
mem::size_of::<libc::sockaddr_ll>() as u32,
)
};
if ret < 0 {
return Err(std::io::Error::last_os_error());
}
Ok(())
}
fn driver_type(&self) -> &'static str {
"linux_bpf"
}
}
#[cfg(all(test, target_os = "linux"))]
mod tests {
use super::*;
use crate::tunnel::fake_tcp::packet::build_tcp_packet;
use crate::tunnel::fake_tcp::stack::Tun;
use pnet::datalink;
use pnet::packet::tcp::TcpFlags;
use pnet::util::MacAddr;
use rand::Rng;
use std::net::{IpAddr, Ipv4Addr};
use tokio::time::{timeout, Duration};
fn is_root() -> bool {
unsafe { libc::geteuid() == 0 }
}
fn pick_interface_v4() -> Option<(String, Ipv4Addr, MacAddr)> {
let interfaces = datalink::interfaces();
for iface in interfaces {
let Some(mac) = iface.mac else {
continue;
};
if iface.is_loopback() {
continue;
}
let ipv4 = iface.ips.iter().find_map(|n| match n.ip() {
IpAddr::V4(ip) => Some(ip),
IpAddr::V6(_) => None,
})?;
return Some((iface.name, ipv4, mac));
}
None
}
fn send_raw_frame(interface_name: &str, frame: &[u8]) -> io::Result<()> {
if frame.len() < 6 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"frame too short",
));
}
let c_ifname = CString::new(interface_name)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid interface name"))?;
let ifindex = unsafe { libc::if_nametoindex(c_ifname.as_ptr()) as i32 };
if ifindex <= 0 {
return Err(io::Error::new(
io::ErrorKind::NotFound,
"interface not found",
));
}
let proto: i32 = (libc::ETH_P_ALL as u16).to_be() as i32;
let fd = unsafe { libc::socket(libc::AF_PACKET, libc::SOCK_RAW, proto) };
if fd < 0 {
return Err(io::Error::last_os_error());
}
let fd = unsafe { OwnedFd::from_raw_fd(fd) };
let mut addr: libc::sockaddr_ll = unsafe { mem::zeroed() };
addr.sll_family = libc::AF_PACKET as u16;
addr.sll_protocol = (libc::ETH_P_ALL as u16).to_be();
addr.sll_ifindex = ifindex;
addr.sll_halen = 6;
addr.sll_addr[..6].copy_from_slice(&frame[..6]);
let ret = unsafe {
libc::sendto(
fd.as_raw_fd(),
frame.as_ptr() as *const libc::c_void,
frame.len(),
0,
&addr as *const _ as *const libc::sockaddr,
mem::size_of::<libc::sockaddr_ll>() as u32,
)
};
if ret < 0 {
return Err(io::Error::last_os_error());
}
Ok(())
}
#[tokio::test]
async fn linux_bpf_tun_receives_matching_ipv4_frame() {
if !is_root() {
eprintln!("linux_bpf_tun_receives_matching_ipv4_frame: skipped (not root)");
return;
}
let Some((ifname, dst_ip, mac)) = pick_interface_v4() else {
eprintln!("linux_bpf_tun_receives_matching_ipv4_frame: skipped (no suitable iface)");
return;
};
let dst_port: u16 = rand::thread_rng().gen_range(40000..60000);
let dst_addr = SocketAddr::new(IpAddr::V4(dst_ip), dst_port);
eprintln!(
"linux_bpf_tun_receives_matching_ipv4_frame: ifname={ifname} dst_addr={dst_addr} mac={mac}"
);
let tun = LinuxBpfTun::new(&ifname, None, dst_addr).unwrap();
let src_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 123, 0, 1)), 12345);
eprintln!(
"linux_bpf_tun_receives_matching_ipv4_frame: sending frame src_addr={src_addr} -> dst_addr={dst_addr}"
);
let frame = build_tcp_packet(
mac,
mac,
src_addr,
dst_addr,
1,
0,
TcpFlags::SYN,
Some(b"ping"),
);
send_raw_frame(&ifname, &frame).unwrap();
let mut received = BytesMut::new();
let n = timeout(Duration::from_secs(2), tun.recv(&mut received))
.await
.unwrap()
.unwrap();
eprintln!(
"linux_bpf_tun_receives_matching_ipv4_frame: received {} bytes",
n
);
assert_eq!(n, frame.len());
assert_eq!(&received[..], &frame[..]);
}
#[tokio::test]
async fn linux_bpf_tun_filters_out_non_matching_ipv4_frame() {
if !is_root() {
eprintln!("linux_bpf_tun_filters_out_non_matching_ipv4_frame: skipped (not root)");
return;
}
let Some((ifname, dst_ip, mac)) = pick_interface_v4() else {
eprintln!(
"linux_bpf_tun_filters_out_non_matching_ipv4_frame: skipped (no suitable iface)"
);
return;
};
let dst_port: u16 = rand::thread_rng().gen_range(40000..60000);
let dst_addr = SocketAddr::new(IpAddr::V4(dst_ip), dst_port);
eprintln!(
"linux_bpf_tun_filters_out_non_matching_ipv4_frame: ifname={ifname} dst_addr={dst_addr} mac={mac}"
);
let tun = LinuxBpfTun::new(&ifname, None, dst_addr).unwrap();
let src_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 123, 0, 2)), 23456);
let non_matching_dst = SocketAddr::new(IpAddr::V4(dst_ip), dst_port.wrapping_add(1));
eprintln!(
"linux_bpf_tun_filters_out_non_matching_ipv4_frame: sending non-matching src_addr={src_addr} -> dst_addr={non_matching_dst}"
);
let non_matching = build_tcp_packet(
mac,
mac,
src_addr,
non_matching_dst,
1,
0,
TcpFlags::SYN,
Some(b"nope"),
);
send_raw_frame(&ifname, &non_matching).unwrap();
let mut received = BytesMut::new();
let non_matching_timeout = timeout(Duration::from_millis(400), tun.recv(&mut received))
.await
.is_err();
eprintln!(
"linux_bpf_tun_filters_out_non_matching_ipv4_frame: non-matching recv timeout={}",
non_matching_timeout
);
assert!(non_matching_timeout);
eprintln!(
"linux_bpf_tun_filters_out_non_matching_ipv4_frame: sending matching src_addr={src_addr} -> dst_addr={dst_addr}"
);
let matching = build_tcp_packet(
mac,
mac,
src_addr,
dst_addr,
2,
0,
TcpFlags::SYN,
Some(b"ok"),
);
send_raw_frame(&ifname, &matching).unwrap();
let mut received2 = BytesMut::new();
let n = timeout(Duration::from_secs(2), tun.recv(&mut received2))
.await
.unwrap()
.unwrap();
eprintln!(
"linux_bpf_tun_filters_out_non_matching_ipv4_frame: received {} bytes",
n
);
assert_eq!(n, matching.len());
assert_eq!(&received2[..], &matching[..]);
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,87 @@
pub mod pnet;
use std::{net::SocketAddr, sync::Arc};
cfg_if::cfg_if! {
if #[cfg(target_os = "linux")] {
pub mod linux_bpf;
pub fn create_tun(
interface_name: &str,
src_addr: Option<SocketAddr>,
dst_addr: SocketAddr,
) -> Arc<dyn super::stack::Tun> {
match linux_bpf::LinuxBpfTun::new(interface_name, src_addr, dst_addr) {
Ok(tun) => Arc::new(tun),
Err(e) => {
tracing::warn!(
?e,
interface_name,
"LinuxBpfTun init failed, falling back to PnetTun"
);
Arc::new(pnet::PnetTun::new(
interface_name,
pnet::create_packet_filter(src_addr, dst_addr),
))
}
}
}
} else if #[cfg(target_os = "macos")] {
pub mod macos_bpf;
pub fn create_tun(
interface_name: &str,
src_addr: Option<SocketAddr>,
dst_addr: SocketAddr,
) -> Arc<dyn super::stack::Tun> {
match macos_bpf::MacosBpfTun::new(interface_name, src_addr, dst_addr) {
Ok(tun) => Arc::new(tun),
Err(e) => {
tracing::warn!(
?e,
interface_name,
"MacosBpfTun init failed, falling back to PnetTun"
);
Arc::new(pnet::PnetTun::new(
interface_name,
pnet::create_packet_filter(src_addr, dst_addr),
))
}
}
}
} else if #[cfg(all(windows, any(target_arch = "x86_64", target_arch = "x86")))] {
pub mod windivert;
pub fn create_tun(
_interface_name: &str,
_src_addr: Option<SocketAddr>,
local_addr: SocketAddr,
) -> Arc<dyn super::stack::Tun> {
match windivert::WinDivertTun::new(local_addr) {
Ok(tun) => Arc::new(tun),
Err(e) => {
tracing::warn!(
?e,
?local_addr,
"WinDivertTun init failed, falling back to PnetTun"
);
Arc::new(pnet::PnetTun::new(
local_addr.to_string().as_str(),
pnet::create_packet_filter(None, local_addr),
))
}
}
}
} else {
pub fn create_tun(
interface_name: &str,
src_addr: Option<SocketAddr>,
dst_addr: SocketAddr,
) -> Arc<dyn super::stack::Tun> {
Arc::new(pnet::PnetTun::new(
interface_name,
pnet::create_packet_filter(src_addr, dst_addr),
))
}
}
}
@@ -0,0 +1,304 @@
use std::{
net::{IpAddr, SocketAddr},
sync::{
atomic::{AtomicU32, Ordering},
Arc, Weak,
},
};
use bytes::{Bytes, BytesMut};
use dashmap::DashMap;
use once_cell::sync::Lazy;
use pnet::{
datalink::{self, DataLinkSender, NetworkInterface},
packet::{ethernet::EtherTypes, ip::IpNextHeaderProtocols, ipv6::Ipv6Packet},
};
use tokio::sync::Mutex;
use crate::tunnel::fake_tcp::stack;
type PacketFilter = Box<dyn Fn(&[u8]) -> bool + Send + Sync>;
fn filter_tcp_packet(
packet: &[u8],
src_addr: Option<&SocketAddr>,
dst_addr: Option<&SocketAddr>,
) -> bool {
use pnet::packet::ethernet::EthernetPacket;
use pnet::packet::ipv4::Ipv4Packet;
use pnet::packet::tcp::TcpPacket;
use pnet::packet::Packet;
let ethernet = if let Some(ethernet) = EthernetPacket::new(packet) {
ethernet
} else {
return false;
};
match ethernet.get_ethertype() {
EtherTypes::Ipv4 => {
let ipv4 = if let Some(ipv4) = Ipv4Packet::new(ethernet.payload()) {
ipv4
} else {
return false;
};
if ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp {
return false;
}
let tcp = if let Some(tcp) = TcpPacket::new(ipv4.payload()) {
tcp
} else {
return false;
};
if let Some(src_addr) = src_addr {
if IpAddr::V4(ipv4.get_source()) != src_addr.ip() {
return false;
}
if tcp.get_source() != src_addr.port() {
return false;
}
}
if let Some(dst_addr) = dst_addr {
if IpAddr::V4(ipv4.get_destination()) != dst_addr.ip() {
return false;
}
if tcp.get_destination() != dst_addr.port() {
return false;
}
}
tracing::trace!(
?tcp,
"FakeTcpTunnelListener packet matched filter, dispatching, src_addr: {:?}, dst_addr: {:?}, packet_src_ip: {:?}, packet_dst_ip: {:?}, packet_src_port: {:?}, packet_dst_port: {:?}",
src_addr,
dst_addr,
ipv4.get_source(),
ipv4.get_destination(),
tcp.get_source(),
tcp.get_destination(),
);
}
EtherTypes::Ipv6 => {
let ipv6 = if let Some(ipv6) = Ipv6Packet::new(ethernet.payload()) {
ipv6
} else {
return false;
};
if ipv6.get_next_header() != IpNextHeaderProtocols::Tcp {
return false;
}
let tcp = if let Some(tcp) = TcpPacket::new(ipv6.payload()) {
tcp
} else {
return false;
};
if let Some(src_addr) = src_addr {
if IpAddr::V6(ipv6.get_source()) != src_addr.ip() {
return false;
}
if tcp.get_source() != src_addr.port() {
return false;
}
}
if let Some(dst_addr) = dst_addr {
if IpAddr::V6(ipv6.get_destination()) != dst_addr.ip() {
return false;
}
if tcp.get_destination() != dst_addr.port() {
return false;
}
}
tracing::trace!(
?tcp,
"FakeTcpTunnelListener packet matched filter, dispatching"
);
}
_ => return false,
}
true
}
pub fn create_packet_filter(src_addr: Option<SocketAddr>, dst_addr: SocketAddr) -> PacketFilter {
Box::new(move |packet: &[u8]| -> bool {
filter_tcp_packet(packet, src_addr.as_ref(), Some(&dst_addr))
})
}
struct Subscriber {
filter: PacketFilter,
sender: tokio::sync::mpsc::Sender<Vec<u8>>,
}
struct InterfaceWorker {
tx: Mutex<Box<dyn DataLinkSender>>,
subscribers: Arc<DashMap<u32, Subscriber>>,
}
impl InterfaceWorker {
fn new(interface: NetworkInterface) -> Arc<Self> {
let (tx, mut rx) = match datalink::channel(&interface, Default::default()) {
Ok(pnet::datalink::Channel::Ethernet(tx, rx)) => (tx, rx),
Ok(_) => panic!("Unhandled channel type"),
Err(e) => panic!(
"An error occurred when creating the datalink channel: {}",
e
),
};
let subscribers = Arc::new(DashMap::<u32, Subscriber>::new());
let subscribers_clone = subscribers.clone();
std::thread::spawn(move || {
loop {
match rx.next() {
Ok(packet) => {
// Iterate over subscribers and send packet if filter matches
// Note: DashMap iteration might be slow if many subscribers, but usually few per interface.
// For high performance we might need a better structure or read-copy-update.
for r in subscribers_clone.iter() {
let subscriber = r.value();
if (subscriber.filter)(packet) {
tracing::trace!(
?packet,
"InterfaceWorker packet matched filter, dispatching"
);
// Try send, ignore errors (best effort)
let _ = subscriber.sender.try_send(packet.to_vec());
}
}
}
Err(e) => {
tracing::error!("InterfaceWorker read error: {}", e);
// If interface goes down, we might need to handle it.
// For now just break and maybe the worker is dead.
break;
}
}
}
});
Arc::new(Self {
tx: Mutex::new(tx),
subscribers,
})
}
fn subscribe(&self, filter: PacketFilter, sender: tokio::sync::mpsc::Sender<Vec<u8>>) -> u32 {
static ID_GEN: AtomicU32 = AtomicU32::new(0);
let id = ID_GEN.fetch_add(1, Ordering::Relaxed);
self.subscribers.insert(id, Subscriber { filter, sender });
id
}
fn unsubscribe(&self, id: u32) {
self.subscribers.remove(&id);
}
}
static INTERFACE_MANAGERS: Lazy<DashMap<String, Weak<InterfaceWorker>>> = Lazy::new(DashMap::new);
fn get_or_create_worker(interface_name: &str) -> Arc<InterfaceWorker> {
// Check if we have an active worker
if let Some(worker) = INTERFACE_MANAGERS
.get(interface_name)
.and_then(|w| w.upgrade())
{
return worker;
}
// Need to create new worker.
// Lock effectively by using entry API? DashMap entry API might not be enough for complex init.
// Let's use a double-check locking style or just accept race condition (creating two workers and one wins).
// DashMap doesn't support easy "compute_if_absent" with async or heavy logic without blocking the map shard.
// But creation is rare.
// Let's find interface first.
let interfaces = datalink::interfaces();
let interface = interfaces
.into_iter()
.find(|iface| iface.name == interface_name)
.expect("Network interface not found");
let worker = InterfaceWorker::new(interface);
INTERFACE_MANAGERS.insert(interface_name.to_string(), Arc::downgrade(&worker));
worker
}
pub struct PnetTun {
worker: Arc<InterfaceWorker>,
subscription_id: u32,
recv_queue: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
}
impl PnetTun {
pub fn new(interface_name: &str, filter: PacketFilter) -> Self {
tracing::debug!(interface_name, "Creating new PnetTun");
let worker = get_or_create_worker(interface_name);
let (tx, rx) = tokio::sync::mpsc::channel(1024);
let id = worker.subscribe(filter, tx);
Self {
worker,
subscription_id: id,
recv_queue: Mutex::new(rx),
}
}
}
impl Drop for PnetTun {
fn drop(&mut self) {
tracing::debug!(subscription_id = self.subscription_id, "Dropping PnetTun");
self.worker.unsubscribe(self.subscription_id);
}
}
#[async_trait::async_trait]
impl stack::Tun for PnetTun {
async fn recv(&self, packet: &mut BytesMut) -> Result<usize, std::io::Error> {
let mut rx = self.recv_queue.lock().await;
match rx.recv().await {
Some(data) => {
tracing::trace!(?data, "PnetTun received packet");
packet.extend_from_slice(&data);
Ok(data.len())
}
None => {
tracing::warn!("PnetTun recv channel closed");
Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"PnetTun channel closed",
))
}
}
}
fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error> {
tracing::trace!(len = packet.len(), "PnetTun try_sending packet");
// We need async lock for tx.
// try_send is sync. We can use try_lock if available or blocking lock.
// tokio::sync::Mutex::try_lock is available.
if let Ok(mut tx) = self.worker.tx.try_lock() {
tx.send_to(packet, None)
.ok_or(std::io::Error::other("send_to failed"))?
} else {
Err(std::io::Error::new(
std::io::ErrorKind::WouldBlock,
"PnetTun tx lock busy",
))
}
}
fn driver_type(&self) -> &'static str {
"pnet"
}
}
@@ -0,0 +1,196 @@
use std::cell::UnsafeCell;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::Context as _;
use bytes::{Bytes, BytesMut};
use tokio::sync::Mutex;
use windivert::error::WinDivertError;
use windivert::packet::WinDivertPacket;
use windivert::prelude::{WinDivertFlags, WinDivertShutdownMode};
use windivert::{layer, WinDivert};
use crate::tunnel::fake_tcp::stack;
struct WinDivertReader {
inner: UnsafeCell<WinDivert<layer::NetworkLayer>>,
}
unsafe impl Send for WinDivertReader {}
unsafe impl Sync for WinDivertReader {}
impl WinDivertReader {
fn new(inner: WinDivert<layer::NetworkLayer>) -> Self {
Self {
inner: UnsafeCell::new(inner),
}
}
fn recv<'a>(
&self,
buffer: Option<&'a mut [u8]>,
) -> Result<WinDivertPacket<'a, layer::NetworkLayer>, WinDivertError> {
let inner = unsafe { &*self.inner.get() };
inner.recv(buffer)
}
fn shutdown(&self) -> anyhow::Result<()> {
let inner = unsafe { &mut *self.inner.get() };
inner
.shutdown(WinDivertShutdownMode::Recv)
.with_context(|| "WinDivertReader shutdown failed")?;
Ok(())
}
fn close(&self) -> anyhow::Result<()> {
let inner = unsafe { &mut *self.inner.get() };
inner
.close(windivert::CloseAction::Nothing)
.with_context(|| "WinDivertReader close failed")?;
Ok(())
}
}
impl Drop for WinDivertReader {
fn drop(&mut self) {
if let Err(e) = self.close() {
tracing::error!("WinDivertReader close failed: {:?}", e);
}
}
}
pub struct WinDivertTun {
recv_queue: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
sender: Arc<std::sync::Mutex<WinDivert<layer::NetworkLayer>>>,
reader: Arc<WinDivertReader>,
}
impl Drop for WinDivertTun {
fn drop(&mut self) {
if let Ok(mut sender) = self.sender.lock() {
if let Err(e) = sender.close(windivert::CloseAction::Nothing) {
tracing::error!("WinDivertSender close failed: {:?}", e);
}
}
if let Err(e) = self.reader.shutdown() {
tracing::error!("WinDivertReader shutdown failed: {:?}", e);
}
}
}
impl WinDivertTun {
pub fn new(local_addr: SocketAddr) -> io::Result<Self> {
let (tx, rx) = tokio::sync::mpsc::channel(1024);
let ip_filter = match local_addr {
SocketAddr::V4(addr) => format!("ip.DstAddr == {}", addr.ip()),
SocketAddr::V6(addr) => format!("ipv6.DstAddr == {}", addr.ip()),
};
// Filter: DstIP == LocalIP AND TCP.
let filter = format!("{} and tcp", ip_filter);
// Sniff mode: 1 (WINDIVERT_FLAG_SNIFF)
// Layer: Network (0)
// Priority: 0
let flags = WinDivertFlags::default().set_sniff();
let reader = WinDivert::network(&filter, 0, flags)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let reader = Arc::new(WinDivertReader::new(reader));
let reader_clone = reader.clone();
std::thread::spawn(move || {
let reader = reader_clone;
let mut buffer = vec![0u8; 65536];
loop {
match reader.recv(Some(&mut buffer)) {
Ok(packet) => {
let data = &packet.data;
let mut eth_data = vec![0u8; 14 + data.len()];
// Set EtherType
if data.len() > 0 && data[0] >> 4 == 4 {
eth_data[12] = 0x08;
eth_data[13] = 0x00;
} else {
eth_data[12] = 0x86;
eth_data[13] = 0xDD;
}
eth_data[14..].copy_from_slice(data);
if let Err(_) = tx.blocking_send(eth_data) {
break;
}
}
Err(_) => {
// log error?
break;
}
}
}
});
// Sender: non-sniff, empty filter?
// Use "false" to avoid capturing anything.
// Flags: 0
let sender = WinDivert::network("false", 0, WinDivertFlags::default())
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
Ok(Self {
recv_queue: Mutex::new(rx),
sender: Arc::new(std::sync::Mutex::new(sender)),
reader,
})
}
}
#[async_trait::async_trait]
impl stack::Tun for WinDivertTun {
async fn recv(&self, packet: &mut BytesMut) -> Result<usize, std::io::Error> {
let mut rx = self.recv_queue.lock().await;
match rx.recv().await {
Some(data) => {
packet.extend_from_slice(&data);
Ok(data.len())
}
None => Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"Channel closed",
)),
}
}
fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error> {
// Strip ethernet header
if packet.len() < 14 {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Packet too short",
));
}
let ip_data = &packet[14..];
let Ok(sender) = self.sender.try_lock() else {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"WinDivert sender lock failed",
));
};
let mut pkt = unsafe { WinDivertPacket::<layer::NetworkLayer>::new(ip_data.to_vec()) };
pkt.address.set_outbound(true);
sender.send(&pkt).map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("WinDivert send failed: {}", e),
)
})?;
Ok(())
}
fn driver_type(&self) -> &'static str {
"windivert"
}
}
+159
View File
@@ -0,0 +1,159 @@
use bytes::{Bytes, BytesMut};
use pnet::packet::ethernet::{EtherTypes, EthernetPacket, MutableEthernetPacket};
use pnet::packet::{ip, ipv4, ipv6, tcp};
use pnet::util::MacAddr;
use std::convert::TryInto;
use std::net::{IpAddr, SocketAddr};
const IPV4_HEADER_LEN: usize = 20;
const IPV6_HEADER_LEN: usize = 40;
const TCP_HEADER_LEN: usize = 20;
pub const MAX_PACKET_LEN: usize = 1500;
#[derive(Debug)]
pub enum IPPacket<'p> {
V4(ipv4::Ipv4Packet<'p>),
V6(ipv6::Ipv6Packet<'p>),
}
impl IPPacket<'_> {
pub fn get_source(&self) -> IpAddr {
match self {
IPPacket::V4(p) => IpAddr::V4(p.get_source()),
IPPacket::V6(p) => IpAddr::V6(p.get_source()),
}
}
pub fn get_destination(&self) -> IpAddr {
match self {
IPPacket::V4(p) => IpAddr::V4(p.get_destination()),
IPPacket::V6(p) => IpAddr::V6(p.get_destination()),
}
}
}
const ETH_HDR_LEN: usize = 14;
#[allow(clippy::too_many_arguments)]
pub fn build_tcp_packet(
src_mac: MacAddr,
dst_mac: MacAddr,
local_addr: SocketAddr,
remote_addr: SocketAddr,
seq: u32,
ack: u32,
flags: u8,
payload: Option<&[u8]>,
) -> Bytes {
let ip_header_len = match local_addr {
SocketAddr::V4(_) => IPV4_HEADER_LEN,
SocketAddr::V6(_) => IPV6_HEADER_LEN,
};
let wscale = (flags & tcp::TcpFlags::SYN) != 0;
let tcp_header_len = TCP_HEADER_LEN + if wscale { 4 } else { 0 }; // nop + wscale
let tcp_total_len = tcp_header_len + payload.map_or(0, |payload| payload.len());
let total_len = ip_header_len + tcp_total_len;
let mut buf = BytesMut::zeroed(ETH_HDR_LEN + total_len);
let mut eth_buf = buf.split_to(ETH_HDR_LEN);
let mut ip_buf = buf.split_to(ip_header_len);
let mut tcp_buf = buf.split_to(tcp_total_len);
assert_eq!(0, buf.len());
let mut tcp = tcp::MutableTcpPacket::new(&mut tcp_buf).unwrap();
tcp.set_window(0xffff);
tcp.set_source(local_addr.port());
tcp.set_destination(remote_addr.port());
tcp.set_sequence(seq);
tcp.set_acknowledgement(ack);
tcp.set_flags(flags);
tcp.set_data_offset(TCP_HEADER_LEN as u8 / 4 + if wscale { 1 } else { 0 });
if wscale {
let wscale = tcp::TcpOption::wscale(14);
tcp.set_options(&[tcp::TcpOption::nop(), wscale]);
}
if let Some(payload) = payload {
tcp.set_payload(payload);
}
let mut ethernet = MutableEthernetPacket::new(&mut eth_buf).unwrap();
ethernet.set_destination(dst_mac);
ethernet.set_source(src_mac);
ethernet.set_ethertype(EtherTypes::Ipv4);
match (local_addr, remote_addr) {
(SocketAddr::V4(local), SocketAddr::V4(remote)) => {
let mut v4 = ipv4::MutableIpv4Packet::new(&mut ip_buf).unwrap();
v4.set_version(4);
v4.set_header_length(IPV4_HEADER_LEN as u8 / 4);
v4.set_next_level_protocol(ip::IpNextHeaderProtocols::Tcp);
v4.set_ttl(64);
v4.set_source(*local.ip());
v4.set_destination(*remote.ip());
v4.set_total_length(total_len.try_into().unwrap());
v4.set_flags(ipv4::Ipv4Flags::DontFragment);
tcp.set_checksum(tcp::ipv4_checksum(
&tcp.to_immutable(),
&v4.get_source(),
&v4.get_destination(),
));
v4.set_checksum(ipv4::checksum(&v4.to_immutable()));
}
(SocketAddr::V6(local), SocketAddr::V6(remote)) => {
let mut v6 = ipv6::MutableIpv6Packet::new(&mut ip_buf).unwrap();
v6.set_version(6);
v6.set_payload_length(tcp_total_len.try_into().unwrap());
v6.set_next_header(ip::IpNextHeaderProtocols::Tcp);
v6.set_hop_limit(64);
v6.set_source(*local.ip());
v6.set_destination(*remote.ip());
tcp.set_checksum(tcp::ipv6_checksum(
&tcp.to_immutable(),
&v6.get_source(),
&v6.get_destination(),
));
}
_ => unreachable!(),
};
ip_buf.unsplit(tcp_buf);
eth_buf.unsplit(ip_buf);
eth_buf.freeze()
}
#[tracing::instrument(ret)]
pub fn parse_ip_packet(
buf: &Bytes,
) -> Option<(MacAddr, MacAddr, IPPacket<'_>, tcp::TcpPacket<'_>)> {
let eth = EthernetPacket::new(buf).unwrap();
let src_mac = eth.get_source();
let dst_mac = eth.get_destination();
tracing::trace!("Parsing IP packet: {:?}", eth);
let buf = &buf[ETH_HDR_LEN..];
if buf[0] >> 4 == 4 {
let v4 = ipv4::Ipv4Packet::new(buf).unwrap();
if v4.get_next_level_protocol() != ip::IpNextHeaderProtocols::Tcp {
return None;
}
let tcp = tcp::TcpPacket::new(&buf[IPV4_HEADER_LEN..]).unwrap();
Some((src_mac, dst_mac, IPPacket::V4(v4), tcp))
} else if buf[0] >> 4 == 6 {
let v6 = ipv6::Ipv6Packet::new(buf).unwrap();
if v6.get_next_header() != ip::IpNextHeaderProtocols::Tcp {
return None;
}
let tcp = tcp::TcpPacket::new(&buf[IPV6_HEADER_LEN..]).unwrap();
Some((src_mac, dst_mac, IPPacket::V6(v6), tcp))
} else {
tracing::trace!("Invalid IP version: {}", buf[0] >> 4);
None
}
}
+561
View File
@@ -0,0 +1,561 @@
//! A minimum, userspace TCP based datagram stack
//!
//! # Overview
//!
//! `fake-tcp` is a reusable library that implements a minimum TCP stack in
//! user space using the Tun interface. It allows programs to send datagrams
//! as if they are part of a TCP connection. `fake-tcp` has been tested to
//! be able to pass through a variety of NAT and stateful firewalls while
//! fully preserves certain desirable behavior such as out of order delivery
//! and no congestion/flow controls.
//!
//! # Core Concepts
//!
//! The core of the `fake-tcp` crate compose of two structures. [`Stack`] and
//! [`Socket`].
//!
//! ## [`Stack`]
//!
//! [`Stack`] represents a virtual TCP stack that operates at
//! Layer 3. It is responsible for:
//!
//! * TCP active and passive open and handshake
//! * `RST` handling
//! * Interact with the Tun interface at Layer 3
//! * Distribute incoming datagrams to corresponding [`Socket`]
//!
//! ## [`Socket`]
//!
//! [`Socket`] represents a TCP connection. It registers the identifying
//! tuple `(src_ip, src_port, dest_ip, dest_port)` inside the [`Stack`] so
//! so that incoming packets can be distributed to the right [`Socket`] with
//! using a channel. It is also what the client should use for
//! sending/receiving datagrams.
//!
//! # Examples
//!
//! Please see [`client.rs`](https://github.com/dndx/phantun/blob/main/phantun/src/bin/client.rs)
//! and [`server.rs`](https://github.com/dndx/phantun/blob/main/phantun/src/bin/server.rs) files
//! from the `phantun` crate for how to use this library in client/server mode, respectively.
use crate::common::scoped_task::ScopedTask;
use super::packet::*;
use bytes::{Bytes, BytesMut};
use crossbeam::atomic::AtomicCell;
use pnet::packet::tcp::TcpOptionNumbers;
use pnet::packet::{tcp, Packet};
use pnet::util::MacAddr;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::{
atomic::{AtomicU32, Ordering},
Arc, RwLock,
};
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::time;
use tracing::{info, trace, warn};
const TIMEOUT: time::Duration = time::Duration::from_secs(1);
const RETRIES: usize = 6;
const MPMC_BUFFER_LEN: usize = 512;
const MPSC_BUFFER_LEN: usize = 128;
const MAX_UNACKED_LEN: u32 = 128 * 1024 * 1024; // 128MB
#[async_trait::async_trait]
pub trait Tun: Send + Sync + 'static {
async fn recv(&self, packet: &mut BytesMut) -> Result<usize, std::io::Error>;
fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error>;
fn driver_type(&self) -> &'static str;
}
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
struct AddrTuple {
local_addr: SocketAddr,
remote_addr: SocketAddr,
}
impl AddrTuple {
fn new(local_addr: SocketAddr, remote_addr: SocketAddr) -> AddrTuple {
AddrTuple {
local_addr,
remote_addr,
}
}
}
struct Shared {
tuples: RwLock<HashMap<AddrTuple, flume::Sender<Bytes>>>,
listening: RwLock<HashSet<u16>>,
tun: Arc<dyn Tun>,
ready: mpsc::Sender<Socket>,
tuples_purge: broadcast::Sender<AddrTuple>,
}
pub struct Stack {
shared: Arc<Shared>,
local_ip: Ipv4Addr,
local_ip6: Option<Ipv6Addr>,
local_mac: MacAddr,
ready: mpsc::Receiver<Socket>,
reader_task: ScopedTask<()>,
}
#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug)]
pub enum State {
Idle,
SynSent,
SynReceived,
Established,
}
pub struct Socket {
shared: Arc<Shared>,
tun: Arc<dyn Tun>,
incoming: flume::Receiver<Bytes>,
local_addr: SocketAddr,
remote_addr: SocketAddr,
local_mac: MacAddr,
remote_mac: AtomicCell<Option<MacAddr>>,
seq: AtomicU32,
ack: AtomicU32,
last_ack: AtomicU32,
state: AtomicCell<State>,
}
/// A socket that represents a unique TCP connection between a server and client.
///
/// The `Socket` object itself satisfies `Sync` and `Send`, which means it can
/// be safely called within an async future.
///
/// To close a TCP connection that is no longer needed, simply drop this object
/// out of scope.
impl Socket {
#[allow(clippy::too_many_arguments)]
fn new(
shared: Arc<Shared>,
tun: Arc<dyn Tun>,
local_addr: SocketAddr,
remote_addr: SocketAddr,
local_mac: MacAddr,
remote_mac: Option<MacAddr>,
ack: Option<u32>,
state: State,
) -> (Socket, flume::Sender<Bytes>) {
let (incoming_tx, incoming_rx) = flume::bounded(MPMC_BUFFER_LEN);
(
Socket {
shared,
tun,
incoming: incoming_rx,
local_addr,
remote_addr,
local_mac,
remote_mac: AtomicCell::new(remote_mac),
seq: AtomicU32::new(0),
ack: AtomicU32::new(ack.unwrap_or(0)),
last_ack: AtomicU32::new(ack.unwrap_or(0)),
state: AtomicCell::new(state),
},
incoming_tx,
)
}
fn build_tcp_packet(&self, flags: u8, payload: Option<&[u8]>) -> Bytes {
let ack = self.ack.load(Ordering::Relaxed);
self.last_ack.store(ack, Ordering::Relaxed);
build_tcp_packet(
self.local_mac,
self.remote_mac.load().unwrap_or(MacAddr::zero()),
self.local_addr,
self.remote_addr,
self.seq.load(Ordering::Relaxed),
ack,
flags,
payload,
)
}
/// Sends a datagram to the other end.
///
/// This method takes `&self`, and it can be called safely by multiple threads
/// at the same time.
///
/// A return of `None` means the Tun socket returned an error
/// and this socket must be closed.
pub fn try_send(&self, payload: &[u8]) -> Option<()> {
match self.state.load() {
State::Established => {
let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload));
self.seq.fetch_add(payload.len() as u32, Ordering::Relaxed);
self.tun.try_send(&buf).ok().and(Some(()))
}
_ => unreachable!(),
}
}
pub fn close(&self) {
if self.state.load() != State::Idle {
let buf = self.build_tcp_packet(tcp::TcpFlags::RST, None);
let _ = self.tun.try_send(&buf);
self.state.store(State::Idle);
}
}
pub async fn recv_bytes(&self) -> Option<Vec<u8>> {
let mut buf = [0u8; 2048];
self.recv(&mut buf).await.map(|size| buf[..size].to_vec())
}
/// Attempt to receive a datagram from the other end.
///
/// This method takes `&self`, and it can be called safely by multiple threads
/// at the same time.
///
/// A return of `None` means the TCP connection is broken
/// and this socket must be closed.
pub async fn recv(&self, buf: &mut [u8]) -> Option<usize> {
tracing::trace!(
"Socket recv called, local_addr: {:?}, remote_addr: {:?}",
self.local_addr,
self.remote_addr
);
loop {
match self.state.load() {
State::Established => {
let Ok(raw_buf) = self.incoming.recv_async().await else {
info!("Connection {} recv error", self);
return None;
};
let (src_mac, dst_mac, _v4_packet, tcp_packet) =
parse_ip_packet(&raw_buf).unwrap();
tracing::trace!(
"Socket received TCP packet from {}({:?}) to {}({:?}): {:?}",
self.remote_addr,
src_mac,
self.local_addr,
dst_mac,
tcp_packet
);
self.remote_mac.store(Some(src_mac));
if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
info!("Connection {} reset by peer", self);
return None;
}
if (tcp_packet.get_flags() & tcp::TcpFlags::ACK) != 0
&& tcp_packet.payload().is_empty()
{
self.seq
.store(tcp_packet.get_acknowledgement(), Ordering::Relaxed);
}
let payload = tcp_packet.payload();
let new_ack = tcp_packet.get_sequence().wrapping_add(payload.len() as u32);
self.ack.store(new_ack, Ordering::Relaxed);
for opt in tcp_packet.get_options_iter() {
if opt.get_number() == TcpOptionNumbers::SACK {
// SACK 选项类型为 5
let payload = opt.payload();
for chunk in payload.chunks(8) {
if chunk.len() != 8 {
continue;
}
let left = tcp_packet.get_acknowledgement();
let right = u32::from_be_bytes(chunk[0..4].try_into().unwrap());
let len = right.wrapping_sub(left);
let sack_end = u32::from_be_bytes(chunk[4..8].try_into().unwrap());
if len == 0 || sack_end <= left {
continue;
}
let send_len = std::cmp::min(len, 1400) as usize;
let data = vec![0u8; send_len];
let buf = build_tcp_packet(
self.local_mac,
self.remote_mac.load().unwrap_or(MacAddr::zero()),
self.local_addr,
self.remote_addr,
left,
self.ack.load(Ordering::Relaxed),
tcp::TcpFlags::ACK,
Some(&data),
);
if let Err(e) = self.tun.try_send(&buf) {
tracing::error!("Failed to send SACK response: {}", e);
}
break;
}
}
}
if payload.is_empty() {
continue;
}
if payload.len() >= buf.len() {
tracing::warn!(
"Payload len {} > buf len {}, tcp: {:?}, payload: {:?}",
payload.len(),
buf.len(),
tcp_packet,
payload
);
continue;
}
buf[..payload.len()].copy_from_slice(payload);
return Some(payload.len());
}
State::SynSent => {
let Ok(Ok(buf)) = time::timeout(TIMEOUT, self.incoming.recv_async()).await
else {
info!("Waiting for client SYN + ACK timed out");
return None;
};
let (src_mac, _dst_mac, _v4_packet, tcp_packet) =
parse_ip_packet(&buf).unwrap();
if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
tracing::trace!("Connection {} reset by peer", self);
return None;
}
let expected_flag = tcp::TcpFlags::SYN | tcp::TcpFlags::ACK;
if (tcp_packet.get_flags() & expected_flag) == expected_flag {
// found our SYN + ACK
self.seq
.store(tcp_packet.get_acknowledgement(), Ordering::Relaxed);
self.ack
.store(tcp_packet.get_sequence() + 1, Ordering::Relaxed);
self.remote_mac.store(Some(src_mac));
self.state.store(State::Established);
return Some(0);
}
}
_ => unreachable!(),
}
}
}
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
pub fn remote_addr(&self) -> SocketAddr {
self.remote_addr
}
}
impl Drop for Socket {
/// Drop the socket and close the TCP connection
fn drop(&mut self) {
let tuple = AddrTuple::new(self.local_addr, self.remote_addr);
// dissociates ourself from the dispatch map
assert!(self.shared.tuples.write().unwrap().remove(&tuple).is_some());
// purge cache
let _ = self.shared.tuples_purge.send(tuple);
let buf = build_tcp_packet(
self.local_mac,
self.remote_mac.load().unwrap_or(MacAddr::zero()),
self.local_addr,
self.remote_addr,
self.seq.load(Ordering::Relaxed),
0,
tcp::TcpFlags::RST,
None,
);
if let Err(e) = self.tun.try_send(&buf) {
warn!("Unable to send RST to remote end: {}", e);
}
info!("Fake TCP connection to {} closed", self);
}
}
impl fmt::Display for Socket {
/// User-friendly string representation of the socket
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"(Fake TCP connection from {} to {})",
self.local_addr, self.remote_addr
)
}
}
/// A userspace TCP state machine
impl Stack {
/// Create a new stack, `tun` is an array of [`Tun`](tokio_tun::Tun).
/// When more than one [`Tun`](tokio_tun::Tun) object is passed in, same amount
/// of reader will be spawned later. This allows user to utilize the performance
/// benefit of Multiqueue Tun support on machines with SMP.
pub fn new(
tun: Arc<dyn Tun>,
local_ip: Ipv4Addr,
local_ip6: Option<Ipv6Addr>,
local_mac: Option<MacAddr>,
) -> Stack {
let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN);
let (tuples_purge_tx, _tuples_purge_rx) = broadcast::channel(16);
let shared = Arc::new(Shared {
tuples: RwLock::new(HashMap::new()),
tun: tun.clone(),
listening: RwLock::new(HashSet::new()),
ready: ready_tx,
tuples_purge: tuples_purge_tx.clone(),
});
let t = tokio::spawn(Stack::reader_task(
tun,
shared.clone(),
tuples_purge_tx.subscribe(),
));
Stack {
shared,
local_ip,
local_ip6,
local_mac: local_mac.unwrap_or(MacAddr::zero()),
ready: ready_rx,
reader_task: t.into(),
}
}
/// Returns the driver type of the stack.
pub fn driver_type(&self) -> &'static str {
self.shared.tun.driver_type()
}
/// Listens for incoming connections on the given `port`.
pub fn listen(&mut self, port: u16) {
assert!(self.shared.listening.write().unwrap().insert(port));
}
/// Accepts an incoming connection.
pub async fn accept(&mut self) -> Socket {
self.ready.recv().await.unwrap()
}
pub async fn alloc_established_socket(
&mut self,
local_addr: SocketAddr,
remote_addr: SocketAddr,
state: State,
) -> Socket {
let tuple = AddrTuple::new(local_addr, remote_addr);
let mut tuples = self.shared.tuples.write().unwrap();
let (sock, incoming) = Socket::new(
self.shared.clone(),
// self.shared.tun.choose(&mut rng).unwrap().clone(),
self.shared.tun.clone(), // Simplification: just use the first tun
local_addr,
remote_addr,
self.local_mac,
None,
Some(0), // Initial ACK
state,
);
assert!(tuples.insert(tuple, incoming).is_none());
sock
}
async fn reader_task(
tun: Arc<dyn Tun>,
shared: Arc<Shared>,
mut tuples_purge: broadcast::Receiver<AddrTuple>,
) {
let mut tuples: HashMap<AddrTuple, flume::Sender<Bytes>> = HashMap::new();
loop {
let mut buf = BytesMut::new();
tokio::select! {
size = tun.recv(&mut buf) => {
let size = size.unwrap();
tracing::trace!(len = size, ?buf, "PnetTun received packet");
let buf = buf.split().freeze();
match parse_ip_packet(&buf) {
Some((_src_mac, _dst_mac, ip_packet, tcp_packet)) => {
let local_addr = SocketAddr::new(
ip_packet.get_destination(),
tcp_packet.get_destination(),
);
let remote_addr = SocketAddr::new(
ip_packet.get_source(),
tcp_packet.get_source(),
);
let tuple = AddrTuple::new(local_addr, remote_addr);
if let Some(c) = tuples.get(&tuple) {
if c.send_async(buf).await.is_err() {
trace!("Cache hit, but receiver already closed, dropping packet");
}
continue;
// If not Ok, receiver has been closed and just fall through to the slow
// path below
} else {
trace!("Cache miss, checking the shared tuples table for connection");
let sender = {
let tuples = shared.tuples.read().unwrap();
tuples.get(&tuple).cloned()
};
if let Some(c) = sender {
trace!("Storing connection information into local tuples");
tuples.insert(tuple, c.clone());
if let Err(e) = c.send_async(buf).await {
trace!("Error sending packet to connection: {:?}", e);
}
continue;
}
}
if tcp_packet.get_flags() == tcp::TcpFlags::SYN
&& shared
.listening
.read()
.unwrap()
.contains(&tcp_packet.get_destination())
{
trace!(?tcp_packet, "Received SYN packet for port {}, ignoring", tcp_packet.get_destination());
continue;
} else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 {
info!("Unknown RST TCP packet from {}, ignoring", remote_addr);
continue;
}
}
None => {
trace!("Dropping packet with no IP/TCP header");
continue;
}
}
},
tuple = tuples_purge.recv() => {
let tuple = tuple.unwrap();
tuples.remove(&tuple);
trace!("Removed cached tuple: {:?}", tuple);
}
}
}
}
}
+13 -3
View File
@@ -15,6 +15,7 @@ use self::packet_def::ZCPacket;
pub mod buf;
pub mod common;
pub mod fake_tcp;
pub mod filter;
pub mod mpsc;
pub mod packet_def;
@@ -23,8 +24,14 @@ pub mod stats;
pub mod tcp;
pub mod udp;
pub const PROTO_PORT_OFFSET: &[(&str, u16)] =
&[("tcp", 0), ("udp", 0), ("wg", 1), ("ws", 1), ("wss", 2)];
pub const PROTO_PORT_OFFSET: &[(&str, u16)] = &[
("tcp", 0),
("udp", 0),
("wg", 1),
("ws", 1),
("wss", 2),
("faketcp", 3),
];
#[cfg(feature = "wireguard")]
pub mod wireguard;
@@ -139,7 +146,9 @@ pub trait TunnelConnector: Send {
pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url {
if let Ok(sock_addr) = addr.parse::<SocketAddr>() {
let mut ret_url = url::Url::parse(format!("{}://0.0.0.0", scheme).as_str()).unwrap();
let url_str = format!("{}://0.0.0.0", scheme);
let mut ret_url = url::Url::parse(url_str.as_str())
.unwrap_or_else(|_| panic!("invalid url: {}", url_str));
ret_url.set_ip_host(sock_addr.ip()).unwrap();
ret_url.set_port(Some(sock_addr.port())).unwrap();
ret_url
@@ -200,6 +209,7 @@ fn default_port(scheme: &str) -> Option<u16> {
"udp" => Some(11010),
"ws" => Some(11011),
"wss" => Some(11012),
"faketcp" => Some(11013),
"quic" => Some(11012),
"wg" => Some(11011),
_ => None,
+1
View File
@@ -0,0 +1 @@
WinDivert doesn't support aarch64, this is a placeholder file to make tauri happy.
Binary file not shown.
Binary file not shown.