diff --git a/.github/workflows/core.yml b/.github/workflows/core.yml index 12cdcc9b..79703a83 100644 --- a/.github/workflows/core.yml +++ b/.github/workflows/core.yml @@ -245,13 +245,13 @@ jobs: # windows is the only OS using a different convention for executable file name if [[ $OS =~ ^windows.*$ && $TARGET =~ ^x86_64.*$ ]]; then SUFFIX=.exe - cp easytier/third_party/*.dll ./artifacts/objects/ + cp easytier/third_party/x86_64/* ./artifacts/objects/ elif [[ $OS =~ ^windows.*$ && $TARGET =~ ^i686.*$ ]]; then SUFFIX=.exe - cp easytier/third_party/i686/*.dll ./artifacts/objects/ + cp easytier/third_party/i686/* ./artifacts/objects/ elif [[ $OS =~ ^windows.*$ && $TARGET =~ ^aarch64.*$ ]]; then SUFFIX=.exe - cp easytier/third_party/arm64/*.dll ./artifacts/objects/ + cp easytier/third_party/arm64/* ./artifacts/objects/ fi if [[ $GITHUB_REF_TYPE =~ ^tag$ ]]; then TAG=$GITHUB_REF_NAME diff --git a/.github/workflows/gui.yml b/.github/workflows/gui.yml index c402ff9b..7ac89e40 100644 --- a/.github/workflows/gui.yml +++ b/.github/workflows/gui.yml @@ -170,11 +170,11 @@ jobs: if: ${{ matrix.OS == 'windows-latest' }} run: | if [[ $GUI_TARGET =~ ^aarch64.*$ ]]; then - cp ./easytier/third_party/arm64/*.dll ./easytier-gui/src-tauri/ + cp ./easytier/third_party/arm64/* ./easytier-gui/src-tauri/ elif [[ $GUI_TARGET =~ ^i686.*$ ]]; then - cp ./easytier/third_party/i686/*.dll ./easytier-gui/src-tauri/ + cp ./easytier/third_party/i686/* ./easytier-gui/src-tauri/ else - cp ./easytier/third_party/*.dll ./easytier-gui/src-tauri/ + cp ./easytier/third_party/x86_64/* ./easytier-gui/src-tauri/ fi - name: Build GUI diff --git a/.gitignore b/.gitignore index edebcc4e..9d45b2a2 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,7 @@ node_modules .vite easytier-gui/src-tauri/*.dll +easytier-gui/src-tauri/*.sys /easytier-contrib/easytier-ohrs/dist/ .direnv diff --git a/Cargo.lock b/Cargo.lock index ef348d63..5c55c4db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2095,6 +2095,7 @@ dependencies = [ "bytecodec", "byteorder", "bytes", + "cfg-if", "chrono", "cidr", "clap", @@ -2107,6 +2108,7 @@ dependencies = [ "derive_builder", "easytier-rpc-build 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "encoding", + "flume 0.12.0", "futures", "futures-util", "gethostname 0.5.0", @@ -2195,6 +2197,7 @@ dependencies = [ "which 7.0.3", "wildmatch", "winapi", + "windivert", "windows 0.52.0", "windows-service", "windows-sys 0.52.0", @@ -2577,6 +2580,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "etherparse" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "827292ea592108849932ad8e30218f8b1f21c0dfd0696698a18b5d0aed62d990" +dependencies = [ + "arrayvec", +] + [[package]] name = "event-listener" version = "5.3.1" @@ -2615,6 +2627,9 @@ name = "fastrand" version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +dependencies = [ + "getrandom 0.2.15", +] [[package]] name = "fdeflate" @@ -2675,6 +2690,18 @@ dependencies = [ "spin", ] +[[package]] +name = "flume" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be" +dependencies = [ + "fastrand", + "futures-core", + "futures-sink", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -8175,7 +8202,7 @@ checksum = "d5b2cf34a45953bfd3daaf3db0f7a7878ab9b7a6b91b422d24a7a9e4c857b680" dependencies = [ "atoi", "chrono", - "flume", + "flume 0.11.0", "futures-channel", "futures-core", "futures-executor", @@ -10327,6 +10354,29 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windivert" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6b6833a760d1c36b489314a5541a12a39d162dc8341d8f6f400212b96d3df1" +dependencies = [ + "etherparse", + "thiserror 1.0.63", + "windivert-sys", + "windows 0.48.0", +] + +[[package]] +name = "windivert-sys" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "832bc4af9272458a8a64395b3aabe10dc4089546486fcbd0e19b9b6d28ba6e54" +dependencies = [ + "cc", + "thiserror 1.0.63", + "windows 0.48.0", +] + [[package]] name = "window-vibrancy" version = "0.6.0" @@ -10342,6 +10392,15 @@ dependencies = [ "windows-version", ] +[[package]] +name = "windows" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" +dependencies = [ + "windows-targets 0.48.5", +] + [[package]] name = "windows" version = "0.52.0" diff --git a/easytier-gui/src-tauri/tauri.windows.conf.json b/easytier-gui/src-tauri/tauri.windows.conf.json index 43d71532..b71ceaa6 100644 --- a/easytier-gui/src-tauri/tauri.windows.conf.json +++ b/easytier-gui/src-tauri/tauri.windows.conf.json @@ -3,7 +3,8 @@ "externalBin": [], "resources": [ "./wintun.dll", - "./Packet.dll" + "./Packet.dll", + "./*.sys" ], "windows": { "webviewInstallMode": { diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 86d8de6a..9e7fb99a 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -220,6 +220,10 @@ hmac = "0.12.1" sha2 = "0.10.8" shellexpand = "3.1.1" +# for fake tcp +flume = "0.12" +cfg-if = "1.0" + [target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "freebsd"))'.dependencies] machine-uid = "0.5.3" @@ -233,6 +237,9 @@ resolv-conf = "0.7.3" dbus = { version = "0.9.7", features = ["vendored"] } which = "7.0.3" +[target.'cfg(all(windows, any(target_arch = "x86_64", target_arch = "x86")))'.dependencies] +windivert = { version = "0.6.0", features = ["static"] } + [target.'cfg(windows)'.dependencies] windows = { version = "0.52.0", features = [ "Win32_Foundation", diff --git a/easytier/build.rs b/easytier/build.rs index adef2ef1..bd4e3c00 100644 --- a/easytier/build.rs +++ b/easytier/build.rs @@ -70,7 +70,7 @@ impl WindowsBuild { let target = std::env::var("TARGET").unwrap_or_default(); if target.contains("x86_64") { - println!("cargo:rustc-link-search=native=easytier/third_party/"); + println!("cargo:rustc-link-search=native=easytier/third_party/x86_64/"); } else if target.contains("i686") { println!("cargo:rustc-link-search=native=easytier/third_party/i686/"); } else if target.contains("aarch64") { diff --git a/easytier/src/connector/mod.rs b/easytier/src/connector/mod.rs index 4b50fb54..a3ea8644 100644 --- a/easytier/src/connector/mod.rs +++ b/easytier/src/connector/mod.rs @@ -12,8 +12,9 @@ use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector}; use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, idn, network::IPCollector}, tunnel::{ - check_scheme_and_get_socket_addr, ring::RingTunnelConnector, tcp::TcpTunnelConnector, - udp::UdpTunnelConnector, IpVersion, TunnelConnector, + check_scheme_and_get_socket_addr, fake_tcp::FakeTcpTunnelConnector, + ring::RingTunnelConnector, tcp::TcpTunnelConnector, udp::UdpTunnelConnector, IpVersion, + TunnelConnector, }, }; @@ -157,6 +158,20 @@ pub async fn create_connector_by_url( let connector = dns_connector::DNSTunnelConnector::new(url, global_ctx.clone()); Box::new(connector) } + "faketcp" => { + let dst_addr = + check_scheme_and_get_socket_addr::(&url, "faketcp", ip_version).await?; + let mut connector = FakeTcpTunnelConnector::new(url); + if global_ctx.config.get_flags().bind_device { + set_bind_addr_for_peer_connector( + &mut connector, + dst_addr.is_ipv4(), + &global_ctx.get_ip_collector(), + ) + .await; + } + Box::new(connector) + } _ => { return Err(Error::InvalidUrl(url.into())); } diff --git a/easytier/src/instance/listeners.rs b/easytier/src/instance/listeners.rs index 2fb8ed28..27e7fc17 100644 --- a/easytier/src/instance/listeners.rs +++ b/easytier/src/instance/listeners.rs @@ -21,8 +21,8 @@ use crate::{ }, peers::peer_manager::PeerManager, tunnel::{ - ring::RingTunnelListener, tcp::TcpTunnelListener, udp::UdpTunnelListener, Tunnel, - TunnelListener, + fake_tcp::FakeTcpTunnelListener, ring::RingTunnelListener, tcp::TcpTunnelListener, + udp::UdpTunnelListener, Tunnel, TunnelListener, }, }; @@ -49,6 +49,7 @@ pub fn get_listener_by_url( use crate::tunnel::websocket::WSTunnelListener; Box::new(WSTunnelListener::new(l.clone())) } + "faketcp" => Box::new(FakeTcpTunnelListener::new(l.clone())), _ => { return Err(Error::InvalidUrl(l.to_string())); } @@ -143,7 +144,7 @@ impl ListenerManage && !is_url_host_ipv6(&l) && is_url_host_unspecified(&l) // quic enables dual-stack by default, may conflict with v4 listener - && l.scheme() != "quic" + && l.scheme() != "quic" && l.scheme() != "faketcp" { let mut ipv6_listener = l.clone(); ipv6_listener diff --git a/easytier/src/peers/peer_conn.rs b/easytier/src/peers/peer_conn.rs index 98d322d8..c6f2c002 100644 --- a/easytier/src/peers/peer_conn.rs +++ b/easytier/src/peers/peer_conn.rs @@ -223,8 +223,8 @@ impl PeerConn { if peer_mgr_hdr.packet_type != PacketType::HandShake as u8 { return Err(Error::WaitRespError(format!( - "unexpected packet type: {:?}", - peer_mgr_hdr.packet_type + "unexpected packet type: {:?}, packet: {:?}", + peer_mgr_hdr.packet_type, rsp ))); } diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs index b5ac636d..50b81e6b 100644 --- a/easytier/src/tunnel/common.rs +++ b/easytier/src/tunnel/common.rs @@ -534,6 +534,11 @@ pub mod tests { let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap(); println!("connect: {:?}", tunnel.info()); + if connector.remote_url().scheme() == "faketcp" { + // listener need some time to start capturing packet + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + assert_eq!( url::Url::from(tunnel.info().unwrap().remote_addr.unwrap()), connector.remote_url(), diff --git a/easytier/src/tunnel/fake_tcp/LICENSE b/easytier/src/tunnel/fake_tcp/LICENSE new file mode 100644 index 00000000..e1b074a5 --- /dev/null +++ b/easytier/src/tunnel/fake_tcp/LICENSE @@ -0,0 +1,6 @@ +Copyright 2021-2025 Datong Sun dndx@idndx.com + +Licensed under the Apache License, Version 2.0 or the MIT license , at your option. Files in the project may +not be copied, modified, or distributed except according to those terms. \ No newline at end of file diff --git a/easytier/src/tunnel/fake_tcp/mod.rs b/easytier/src/tunnel/fake_tcp/mod.rs new file mode 100644 index 00000000..6f77a5e7 --- /dev/null +++ b/easytier/src/tunnel/fake_tcp/mod.rs @@ -0,0 +1,482 @@ +mod netfilter; +mod packet; +mod stack; + +use std::net::{IpAddr, Ipv4Addr, UdpSocket}; +use std::sync::Arc; +use std::{net::SocketAddr, pin::Pin}; + +use bytes::BytesMut; +use pnet::datalink; +use pnet::util::MacAddr; +use tokio::io::AsyncReadExt; +use tokio::net::TcpStream; +use tokio::sync::Mutex; + +use crate::common::scoped_task::ScopedTask; +use crate::tunnel::fake_tcp::netfilter::create_tun; +use crate::tunnel::{common::TunnelWrapper, Tunnel, TunnelError, TunnelInfo, TunnelListener}; + +use futures::Future; + +use dashmap::DashMap; + +struct IpToIfNameCache { + ip_to_ifname: DashMap)>, +} + +impl IpToIfNameCache { + fn new() -> Self { + Self { + ip_to_ifname: DashMap::new(), + } + } + + fn reload_ip_to_ifname(&self) { + self.ip_to_ifname.clear(); + let interfaces = datalink::interfaces(); + for iface in interfaces { + for ip in iface.ips.iter() { + self.ip_to_ifname + .insert(ip.ip(), (iface.name.clone(), iface.mac)); + } + } + } + + fn get_ifname(&self, ip: &IpAddr) -> Option<(String, Option)> { + if let Some(ifname) = self.ip_to_ifname.get(ip) { + Some(ifname.clone()) + } else { + self.reload_ip_to_ifname(); + self.ip_to_ifname.get(ip).map(|s| s.clone()) + } + } +} + +fn get_faketcp_tunnel_type_str(driver_type: &str) -> String { + format!("faketcp_{}", driver_type) +} + +pub struct FakeTcpTunnelListener { + addr: url::Url, + os_listener: Option, + // interface_name -> fake tcp stack + stack_map: DashMap>>, + // a cache from ip addr to interface name + ip_to_ifname: IpToIfNameCache, +} + +impl FakeTcpTunnelListener { + pub fn new(addr: url::Url) -> Self { + // Define filter: Capture all packets (or refine this if needed) + // For FakeTCP, we probably want to capture packets destined to us? + // But `stack::Stack` handles IP/TCP logic. + // Maybe we just capture everything for now as a raw tunnel? + // Or better, filter based on some criteria? + // The user said "satisfy filter function". + // Let's create a filter that accepts everything for now, or maybe only IP packets? + FakeTcpTunnelListener { + addr, + os_listener: None, + stack_map: DashMap::new(), + ip_to_ifname: IpToIfNameCache::new(), + } + } + + async fn do_accept(&mut self) -> Result { + loop { + match self.os_listener.as_mut().unwrap().accept().await { + Ok((s, remote_addr)) => { + let Ok(local_addr) = s.local_addr() else { + tracing::warn!("accept fail with local_addr error"); + continue; + }; + let Some((interface_name, mac)) = + self.ip_to_ifname.get_ifname(&local_addr.ip()) + else { + tracing::warn!("accept fail with interface_name error"); + continue; + }; + return Ok(AcceptResult { + socket: s, + local_addr, + remote_addr, + interface_name, + mac, + }); + } + Err(e) => { + use std::io::ErrorKind::*; + if matches!( + e.kind(), + NotConnected | ConnectionAborted | ConnectionRefused | ConnectionReset + ) { + tracing::warn!(?e, "accept fail with retryable error: {:?}", e); + continue; + } + tracing::warn!(?e, "accept fail"); + return Err(e.into()); + } + } + } + } + + async fn get_stack( + &self, + accept_result: &AcceptResult, + ) -> Result>, TunnelError> { + let local_socket_addr = accept_result.local_addr; + + let interface_name = &accept_result.interface_name; + + let (local_ip, local_ip6) = match local_socket_addr.ip() { + IpAddr::V4(ip) => (Some(ip), None), + IpAddr::V6(ip) => (None, Some(ip)), + }; + + let ret = self + .stack_map + .entry(interface_name.to_string()) + .or_insert_with(|| { + let tun = create_tun(interface_name, None, local_socket_addr); + tracing::info!( + ?local_socket_addr, + "create new stack with interface_name: {:?}", + interface_name + ); + // TODO: Get local MAC address of the interface + Arc::new(Mutex::new(stack::Stack::new( + tun, + local_ip.unwrap_or(Ipv4Addr::UNSPECIFIED), + local_ip6, + accept_result.mac, + ))) + }) + .clone(); + + Ok(ret) + } +} + +fn build_os_socket_reader_task(mut socket: TcpStream) -> ScopedTask<()> { + let os_socket_reader_task: ScopedTask<()> = tokio::spawn(async move { + // read the os socket until it's closed + let mut buf = [0u8; 1024]; + while let Ok(size) = socket.read(&mut buf).await { + tracing::trace!("read {} bytes from os socket", size); + if size == 0 { + break; + } + } + tracing::info!("FakeTcpTunnelListener os socket closed"); + }) + .into(); + os_socket_reader_task +} + +#[derive(Debug)] +struct AcceptResult { + socket: TcpStream, + local_addr: SocketAddr, + remote_addr: SocketAddr, + interface_name: String, + mac: Option, +} + +#[async_trait::async_trait] +impl TunnelListener for FakeTcpTunnelListener { + async fn listen(&mut self) -> Result<(), TunnelError> { + let port = self.addr.port().unwrap_or(0); + let bind_addr = crate::tunnel::check_scheme_and_get_socket_addr::( + &self.addr, + "faketcp", + crate::tunnel::IpVersion::Both, + ) + .await?; + let os_listener = tokio::net::TcpListener::bind(bind_addr).await?; + tracing::info!(port, "FakeTcpTunnelListener listening"); + self.os_listener = Some(os_listener); + // self.stack.lock().await.listen(port); + Ok(()) + } + + async fn accept(&mut self) -> Result, TunnelError> { + tracing::debug!("FakeTcpTunnelListener waiting for accept"); + let res = self.do_accept().await?; + let stack = self.get_stack(&res).await?; + let socket = stack + .lock() + .await + .alloc_established_socket(res.local_addr, res.remote_addr, stack::State::Established) + .await; + + tracing::info!( + ?res, + remote = socket.remote_addr().to_string(), + "FakeTcpTunnelListener accepted connection" + ); + + let info = TunnelInfo { + tunnel_type: get_faketcp_tunnel_type_str(stack.lock().await.driver_type()), + local_addr: Some(self.local_url().into()), + remote_addr: Some( + crate::tunnel::build_url_from_socket_addr( + &socket.remote_addr().to_string(), + "faketcp", + ) + .into(), + ), + }; + + // We treat the fake tcp socket as a datagram tunnel directly + // The reader/writer will interface with the socket using recv_bytes/send + // We need to adapt the socket to ZCPacketStream and ZCPacketSink + + // Since FakeTcpTunnel is a datagram tunnel, we don't need FramedReader/Writer (which are for stream based tunnels like TCP) + // We should wrap the socket into something that produces/consumes ZCPacket directly. + + let socket = Arc::new(socket); + let reader = FakeTcpStream::new(socket.clone()); + let writer = FakeTcpSink::new(socket); + + Ok(Box::new(TunnelWrapper::new_with_associate_data( + reader, + writer, + Some(info), + Some(Box::new(build_os_socket_reader_task(res.socket))), + ))) + } + + fn local_url(&self) -> url::Url { + self.addr.clone() + } +} + +pub struct FakeTcpTunnelConnector { + addr: url::Url, + ip_to_if_name: IpToIfNameCache, +} + +impl FakeTcpTunnelConnector { + pub fn new(addr: url::Url) -> Self { + FakeTcpTunnelConnector { + addr, + ip_to_if_name: IpToIfNameCache::new(), + } + } +} + +fn get_local_ip_for_destination(destination: IpAddr) -> Option { + // 使用一个不可路由的、私有的、或回环地址创建一个临时的 socket,让内核自动选择源接口。 + // 对于 IPv4,使用 0.0.0.0; 对于 IPv6,使用 :: + let bind_addr = if destination.is_ipv4() { + IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)) + } else { + IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)) + }; + + // 绑定到一个临时端口 (0) + let socket = UdpSocket::bind((bind_addr, 0)).ok()?; + + // 尝试连接到目标地址。这不会真正发送数据包,只是让内核确定路由。 + socket.connect((destination, 80)).ok()?; // 使用一个常见的端口,例如 80 + + // 获取 socket 的本地地址信息 + socket.local_addr().map(|addr| addr.ip()).ok() +} + +#[async_trait::async_trait] +impl crate::tunnel::TunnelConnector for FakeTcpTunnelConnector { + async fn connect(&mut self) -> Result, TunnelError> { + let remote_addr = crate::tunnel::check_scheme_and_get_socket_addr::( + &self.addr, + "faketcp", + crate::tunnel::IpVersion::Both, + ) + .await?; + let local_ip = get_local_ip_for_destination(remote_addr.ip()) + .ok_or(TunnelError::InternalError("Failed to get local ip".into()))?; + + let os_socket = tokio::net::TcpSocket::new_v4()?; + os_socket.bind("0.0.0.0:0".parse().unwrap())?; + let local_port = os_socket.local_addr()?.port(); + let local_addr = SocketAddr::new(local_ip, local_port); + + let (interface_name, mac) = + self.ip_to_if_name + .get_ifname(&local_ip) + .ok_or(TunnelError::InternalError( + "Failed to get interface name".into(), + ))?; + + let (local_ip, local_ip6) = match local_ip { + IpAddr::V4(ip) => (Some(ip), None), + IpAddr::V6(ip) => (None, Some(ip)), + }; + + let tun = create_tun(&interface_name, Some(remote_addr), local_addr); + let local_ip = local_ip.unwrap_or("0.0.0.0".parse().unwrap()); + let mut stack = stack::Stack::new(tun, local_ip, local_ip6, mac); + let driver_type = stack.driver_type(); + + let socket = stack + .alloc_established_socket(local_addr, remote_addr, stack::State::SynSent) + .await; + + let os_stream = os_socket.connect(remote_addr).await?; + + tracing::info!(?remote_addr, "FakeTcpTunnelConnector connecting"); + + socket.recv_bytes().await.ok_or(TunnelError::InternalError( + "Failed to recv bytes to establish connection".into(), + ))?; + + tracing::info!(local_addr = ?socket.local_addr(), "FakeTcpTunnelConnector connected"); + + let info = TunnelInfo { + tunnel_type: get_faketcp_tunnel_type_str(driver_type), + local_addr: Some( + crate::tunnel::build_url_from_socket_addr( + &socket.local_addr().to_string(), + "faketcp", + ) + .into(), + ), + remote_addr: Some(self.addr.clone().into()), + }; + + let socket = Arc::new(socket); + let reader = FakeTcpStream::new(socket.clone()); + let writer = FakeTcpSink::new(socket); + + Ok(Box::new(TunnelWrapper::new_with_associate_data( + reader, + writer, + Some(info), + Some(Box::new((build_os_socket_reader_task(os_stream), stack))), + ))) + } + + fn remote_url(&self) -> url::Url { + self.addr.clone() + } +} + +use crate::tunnel::packet_def::{ZCPacket, ZCPacketType}; +use crate::tunnel::{SinkError, SinkItem, StreamItem}; +use futures::{Sink, Stream}; +use std::task::{Context as TaskContext, Poll}; + +struct FakeTcpStream { + socket: Arc, + #[allow(clippy::type_complexity)] + recv_fut: Option>> + Send + Sync>>>, +} + +impl FakeTcpStream { + fn new(socket: Arc) -> Self { + Self { + socket, + recv_fut: None, + } + } +} + +impl Stream for FakeTcpStream { + type Item = StreamItem; + + fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + let s = self.get_mut(); + if s.recv_fut.is_none() { + let socket = s.socket.clone(); + s.recv_fut = Some(Box::pin(async move { socket.recv_bytes().await })); + } + + match s.recv_fut.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready(Some(data)) => { + let mut buf = BytesMut::new(); + buf.extend_from_slice(&data); + let packet = ZCPacket::new_from_buf(buf, ZCPacketType::DummyTunnel); + + s.recv_fut = None; + + Poll::Ready(Some(Ok(packet))) + } + Poll::Ready(None) => { + // 连接关闭 + s.recv_fut = None; + Poll::Ready(None) + } + Poll::Pending => Poll::Pending, + } + } +} + +struct FakeTcpSink { + socket: Arc, +} + +impl FakeTcpSink { + fn new(socket: Arc) -> Self { + Self { socket } + } +} + +impl Sink for FakeTcpSink { + type Error = SinkError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut TaskContext<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> { + // We need to send the packet as bytes + // The item is ZCPacket, which has into_bytes() method + let bytes = item.convert_type(ZCPacketType::DummyTunnel).into_bytes(); + + // Let's just spawn for now as a simple implementation, noting the limitation. + self.socket.try_send(&bytes); + + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut TaskContext<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut TaskContext<'_>, + ) -> Poll> { + self.socket.close(); + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +mod tests { + use crate::tunnel::common::tests::_tunnel_pingpong; + + use super::*; + + #[tokio::test] + async fn faketcp_pingpong() { + #[cfg(target_family = "unix")] + { + if unsafe { nix::libc::geteuid() } != 0 { + return; + } + } + + let listener = FakeTcpTunnelListener::new("faketcp://0.0.0.0:31011".parse().unwrap()); + let connector = FakeTcpTunnelConnector::new("faketcp://127.0.0.1:31011".parse().unwrap()); + + _tunnel_pingpong(listener, connector).await + } +} diff --git a/easytier/src/tunnel/fake_tcp/netfilter/linux_bpf.rs b/easytier/src/tunnel/fake_tcp/netfilter/linux_bpf.rs new file mode 100644 index 00000000..8e5379b9 --- /dev/null +++ b/easytier/src/tunnel/fake_tcp/netfilter/linux_bpf.rs @@ -0,0 +1,708 @@ +use bytes::Bytes; +use bytes::BytesMut; +use nix::libc; +use std::ffi::CString; +use std::io; +use std::mem; +use std::net::IpAddr; +use std::net::SocketAddr; +use std::os::fd::{AsRawFd, FromRawFd, OwnedFd}; +use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering}; +use std::sync::Arc; +use tokio::sync::Mutex; + +use crate::tunnel::fake_tcp::stack; + +const ETH_HDR_LEN: usize = 14; +const ETH_TYPE_OFFSET: u32 = 12; +const ETHERTYPE_IPV4: u32 = 0x0800; +const ETHERTYPE_IPV6: u32 = 0x86DD; +const IPPROTO_TCP_U32: u32 = 6; + +const BPF_LD: u16 = 0x00; +const BPF_LDX: u16 = 0x01; +const BPF_JMP: u16 = 0x05; +const BPF_RET: u16 = 0x06; + +const BPF_W: u16 = 0x00; +const BPF_H: u16 = 0x08; +const BPF_B: u16 = 0x10; + +const BPF_ABS: u16 = 0x20; +const BPF_IND: u16 = 0x40; +const BPF_MSH: u16 = 0xa0; + +const BPF_JA: u16 = 0x00; +const BPF_JEQ: u16 = 0x10; + +const BPF_K: u16 = 0x00; + +fn stmt(code: u16, k: u32) -> libc::sock_filter { + libc::sock_filter { + code, + jt: 0, + jf: 0, + k, + } +} + +fn jeq(k: u32, jt: u8, jf: u8) -> libc::sock_filter { + libc::sock_filter { + code: BPF_JMP | BPF_JEQ | BPF_K, + jt, + jf, + k, + } +} + +fn ja(k: u32) -> libc::sock_filter { + libc::sock_filter { + code: BPF_JMP | BPF_JA, + jt: 0, + jf: 0, + k, + } +} + +#[derive(Clone, Copy)] +struct Label(usize); + +struct JeqPatch { + idx: usize, + t: Label, + f: Label, +} + +struct JaPatch { + idx: usize, + target: Label, +} + +struct BpfBuilder { + insns: Vec, + labels: Vec>, + jeq_patches: Vec, + ja_patches: Vec, +} + +impl BpfBuilder { + fn new() -> Self { + Self { + insns: Vec::new(), + labels: Vec::new(), + jeq_patches: Vec::new(), + ja_patches: Vec::new(), + } + } + + fn new_label(&mut self) -> Label { + let idx = self.labels.len(); + self.labels.push(None); + Label(idx) + } + + fn set_label(&mut self, label: Label) { + self.labels[label.0] = Some(self.insns.len()); + } + + fn push(&mut self, insn: libc::sock_filter) { + self.insns.push(insn); + } + + fn push_jeq(&mut self, k: u32, t: Label, f: Label) { + let idx = self.insns.len(); + self.insns.push(jeq(k, 0, 0)); + self.jeq_patches.push(JeqPatch { idx, t, f }); + } + + fn push_ja(&mut self, target: Label) { + let idx = self.insns.len(); + self.insns.push(ja(0)); + self.ja_patches.push(JaPatch { idx, target }); + } + + fn finish(mut self) -> io::Result> { + for patch in self.jeq_patches { + let JeqPatch { idx, t, f } = patch; + let cur = idx + 1; + let t_pos = + self.labels.get(t.0).and_then(|v| *v).ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "unresolved label") + })?; + let f_pos = + self.labels.get(f.0).and_then(|v| *v).ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "unresolved label") + })?; + + if t_pos < cur || f_pos < cur { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "backward bpf jump", + )); + } + + let jt: u8 = (t_pos - cur) + .try_into() + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bpf jump too far"))?; + let jf: u8 = (f_pos - cur) + .try_into() + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bpf jump too far"))?; + + self.insns[idx].jt = jt; + self.insns[idx].jf = jf; + } + + for patch in self.ja_patches { + let JaPatch { idx, target } = patch; + let cur = idx + 1; + let t_pos = + self.labels.get(target.0).and_then(|v| *v).ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "unresolved label") + })?; + if t_pos < cur { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "backward bpf jump", + )); + } + self.insns[idx].k = (t_pos - cur) as u32; + } + + Ok(self.insns) + } +} + +fn build_tcp_filter( + src_addr: Option, + dst_addr: SocketAddr, +) -> io::Result> { + if let Some(src) = src_addr { + if src.is_ipv4() != dst_addr.is_ipv4() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "src/dst addr family mismatch", + )); + } + } + + let mut b = BpfBuilder::new(); + let l_check_ipv6 = b.new_label(); + let l_ipv4 = b.new_label(); + let l_ipv6 = b.new_label(); + let l_accept = b.new_label(); + let l_reject = b.new_label(); + + b.push(stmt(BPF_LD | BPF_H | BPF_ABS, ETH_TYPE_OFFSET)); + b.push_jeq(ETHERTYPE_IPV4, l_ipv4, l_check_ipv6); + + b.set_label(l_check_ipv6); + b.push_jeq(ETHERTYPE_IPV6, l_ipv6, l_reject); + + if dst_addr.is_ipv4() { + b.set_label(l_ipv4); + let l_v4_proto_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_B | BPF_ABS, (ETH_HDR_LEN + 9) as u32)); + b.push_jeq(IPPROTO_TCP_U32, l_v4_proto_ok, l_reject); + + b.set_label(l_v4_proto_ok); + let dst_ip = match dst_addr.ip() { + IpAddr::V4(ip) => u32::from(ip), + _ => unreachable!(), + }; + let l_v4_dstip_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, (ETH_HDR_LEN + 16) as u32)); + b.push_jeq(dst_ip, l_v4_dstip_ok, l_reject); + + b.set_label(l_v4_dstip_ok); + if let Some(src) = src_addr { + let src_ip = match src.ip() { + IpAddr::V4(ip) => u32::from(ip), + _ => unreachable!(), + }; + let l_v4_srcip_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, (ETH_HDR_LEN + 12) as u32)); + b.push_jeq(src_ip, l_v4_srcip_ok, l_reject); + b.set_label(l_v4_srcip_ok); + } + + b.push(stmt(BPF_LDX | BPF_B | BPF_MSH, ETH_HDR_LEN as u32)); + + let l_v4_dstport_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_H | BPF_IND, (ETH_HDR_LEN + 2) as u32)); + b.push_jeq(dst_addr.port() as u32, l_v4_dstport_ok, l_reject); + + b.set_label(l_v4_dstport_ok); + if let Some(src) = src_addr { + b.push(stmt(BPF_LD | BPF_H | BPF_IND, ETH_HDR_LEN as u32)); + b.push_jeq(src.port() as u32, l_accept, l_reject); + } else { + b.push_ja(l_accept); + } + } else { + b.set_label(l_ipv6); + let l_v6_proto_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_B | BPF_ABS, (ETH_HDR_LEN + 6) as u32)); + b.push_jeq(IPPROTO_TCP_U32, l_v6_proto_ok, l_reject); + + b.set_label(l_v6_proto_ok); + let dst_ip = match dst_addr.ip() { + IpAddr::V6(ip) => ip.octets(), + _ => unreachable!(), + }; + for (i, chunk) in dst_ip.chunks_exact(4).enumerate() { + let off = ETH_HDR_LEN + 24 + (i * 4); + let v = u32::from_be_bytes(chunk.try_into().unwrap()); + let l_v6_dstip_word_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, off as u32)); + b.push_jeq(v, l_v6_dstip_word_ok, l_reject); + b.set_label(l_v6_dstip_word_ok); + } + + if let Some(src) = src_addr { + let src_ip = match src.ip() { + IpAddr::V6(ip) => ip.octets(), + _ => unreachable!(), + }; + for (i, chunk) in src_ip.chunks_exact(4).enumerate() { + let off = ETH_HDR_LEN + 8 + (i * 4); + let v = u32::from_be_bytes(chunk.try_into().unwrap()); + let l_v6_srcip_word_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, off as u32)); + b.push_jeq(v, l_v6_srcip_word_ok, l_reject); + b.set_label(l_v6_srcip_word_ok); + } + } + + let l_v6_dstport_ok = b.new_label(); + b.push(stmt( + BPF_LD | BPF_H | BPF_ABS, + (ETH_HDR_LEN + 40 + 2) as u32, + )); + b.push_jeq(dst_addr.port() as u32, l_v6_dstport_ok, l_reject); + + b.set_label(l_v6_dstport_ok); + if let Some(src) = src_addr { + b.push(stmt(BPF_LD | BPF_H | BPF_ABS, (ETH_HDR_LEN + 40) as u32)); + b.push_jeq(src.port() as u32, l_accept, l_reject); + } else { + b.push_ja(l_accept); + } + } + + b.set_label(l_accept); + b.push(stmt(BPF_RET | BPF_K, 0xFFFF)); + + b.set_label(l_reject); + if dst_addr.is_ipv4() { + b.set_label(l_ipv6); + } else { + b.set_label(l_ipv4); + } + b.push(stmt(BPF_RET | BPF_K, 0)); + + b.finish() +} + +pub struct LinuxBpfTun { + fd: OwnedFd, + ifindex: i32, + stop: Arc, + worker: Option>, + recv_queue: Mutex>>, +} + +impl LinuxBpfTun { + pub fn new( + interface_name: &str, + src_addr: Option, + dst_addr: SocketAddr, + ) -> io::Result { + let c_ifname = CString::new(interface_name) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid interface name"))?; + let ifindex = unsafe { libc::if_nametoindex(c_ifname.as_ptr()) as i32 }; + if ifindex <= 0 { + return Err(io::Error::new( + io::ErrorKind::NotFound, + "interface not found", + )); + } + + let proto: i32 = (libc::ETH_P_ALL as u16).to_be() as i32; + let fd = unsafe { libc::socket(libc::AF_PACKET, libc::SOCK_RAW, proto) }; + if fd < 0 { + return Err(io::Error::last_os_error()); + } + let fd = unsafe { OwnedFd::from_raw_fd(fd) }; + + let mut addr: libc::sockaddr_ll = unsafe { mem::zeroed() }; + addr.sll_family = libc::AF_PACKET as u16; + addr.sll_protocol = (libc::ETH_P_ALL as u16).to_be(); + addr.sll_ifindex = ifindex; + + let bind_ret = unsafe { + libc::bind( + fd.as_raw_fd(), + &addr as *const _ as *const libc::sockaddr, + mem::size_of::() as u32, + ) + }; + if bind_ret != 0 { + return Err(io::Error::last_os_error()); + } + + let filter = build_tcp_filter(src_addr, dst_addr)?; + let mut prog = libc::sock_fprog { + len: filter + .len() + .try_into() + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bpf program too long"))?, + filter: filter.as_ptr() as *mut libc::sock_filter, + }; + let opt_ret = unsafe { + libc::setsockopt( + fd.as_raw_fd(), + libc::SOL_SOCKET, + libc::SO_ATTACH_FILTER, + &mut prog as *mut _ as *mut libc::c_void, + mem::size_of::() as u32, + ) + }; + if opt_ret != 0 { + return Err(io::Error::last_os_error()); + } + + let timeout = libc::timeval { + tv_sec: 0, + tv_usec: 200_000, + }; + let _ = unsafe { + libc::setsockopt( + fd.as_raw_fd(), + libc::SOL_SOCKET, + libc::SO_RCVTIMEO, + &timeout as *const _ as *const libc::c_void, + mem::size_of::() as u32, + ) + }; + + let stop = Arc::new(AtomicBool::new(false)); + let (tx, rx) = tokio::sync::mpsc::channel(1024); + let stop_clone = stop.clone(); + let read_fd = fd.as_raw_fd(); + + let worker = std::thread::spawn(move || { + let mut buf = vec![0u8; 65536]; + while !stop_clone.load(AtomicOrdering::Relaxed) { + let n = unsafe { + libc::recv(read_fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len(), 0) + }; + if n < 0 { + let err = io::Error::last_os_error(); + if matches!( + err.kind(), + io::ErrorKind::Interrupted | io::ErrorKind::WouldBlock + ) { + continue; + } + break; + } + if n == 0 { + continue; + } + let data = buf[..(n as usize)].to_vec(); + if tx.blocking_send(data).is_err() { + break; + } + } + }); + + tracing::info!( + interface_name, + ifindex, + "LinuxBpfTun created with filter {:?}", + filter + ); + + Ok(Self { + fd, + ifindex, + stop, + worker: Some(worker), + recv_queue: Mutex::new(rx), + }) + } +} + +impl Drop for LinuxBpfTun { + fn drop(&mut self) { + self.stop.store(true, AtomicOrdering::Relaxed); + let _ = unsafe { libc::shutdown(self.fd.as_raw_fd(), libc::SHUT_RD) }; + if let Some(worker) = self.worker.take() { + let _ = worker.join(); + } + } +} + +#[async_trait::async_trait] +impl stack::Tun for LinuxBpfTun { + async fn recv(&self, packet: &mut BytesMut) -> Result { + let mut rx = self.recv_queue.lock().await; + match rx.recv().await { + Some(data) => { + packet.extend_from_slice(&data); + Ok(data.len()) + } + None => Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "LinuxBpfTun channel closed", + )), + } + } + + fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error> { + if packet.len() < 6 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "packet too short", + )); + } + + let mut addr: libc::sockaddr_ll = unsafe { mem::zeroed() }; + addr.sll_family = libc::AF_PACKET as u16; + addr.sll_protocol = (libc::ETH_P_ALL as u16).to_be(); + addr.sll_ifindex = self.ifindex; + addr.sll_halen = 6; + addr.sll_addr[..6].copy_from_slice(&packet[..6]); + + let ret = unsafe { + libc::sendto( + self.fd.as_raw_fd(), + packet.as_ptr() as *const libc::c_void, + packet.len(), + 0, + &addr as *const _ as *const libc::sockaddr, + mem::size_of::() as u32, + ) + }; + if ret < 0 { + return Err(std::io::Error::last_os_error()); + } + Ok(()) + } + + fn driver_type(&self) -> &'static str { + "linux_bpf" + } +} + +#[cfg(all(test, target_os = "linux"))] +mod tests { + use super::*; + + use crate::tunnel::fake_tcp::packet::build_tcp_packet; + use crate::tunnel::fake_tcp::stack::Tun; + use pnet::datalink; + use pnet::packet::tcp::TcpFlags; + use pnet::util::MacAddr; + use rand::Rng; + use std::net::{IpAddr, Ipv4Addr}; + use tokio::time::{timeout, Duration}; + + fn is_root() -> bool { + unsafe { libc::geteuid() == 0 } + } + + fn pick_interface_v4() -> Option<(String, Ipv4Addr, MacAddr)> { + let interfaces = datalink::interfaces(); + for iface in interfaces { + let Some(mac) = iface.mac else { + continue; + }; + if iface.is_loopback() { + continue; + } + let ipv4 = iface.ips.iter().find_map(|n| match n.ip() { + IpAddr::V4(ip) => Some(ip), + IpAddr::V6(_) => None, + })?; + return Some((iface.name, ipv4, mac)); + } + None + } + + fn send_raw_frame(interface_name: &str, frame: &[u8]) -> io::Result<()> { + if frame.len() < 6 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "frame too short", + )); + } + + let c_ifname = CString::new(interface_name) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid interface name"))?; + let ifindex = unsafe { libc::if_nametoindex(c_ifname.as_ptr()) as i32 }; + if ifindex <= 0 { + return Err(io::Error::new( + io::ErrorKind::NotFound, + "interface not found", + )); + } + + let proto: i32 = (libc::ETH_P_ALL as u16).to_be() as i32; + let fd = unsafe { libc::socket(libc::AF_PACKET, libc::SOCK_RAW, proto) }; + if fd < 0 { + return Err(io::Error::last_os_error()); + } + let fd = unsafe { OwnedFd::from_raw_fd(fd) }; + + let mut addr: libc::sockaddr_ll = unsafe { mem::zeroed() }; + addr.sll_family = libc::AF_PACKET as u16; + addr.sll_protocol = (libc::ETH_P_ALL as u16).to_be(); + addr.sll_ifindex = ifindex; + addr.sll_halen = 6; + addr.sll_addr[..6].copy_from_slice(&frame[..6]); + + let ret = unsafe { + libc::sendto( + fd.as_raw_fd(), + frame.as_ptr() as *const libc::c_void, + frame.len(), + 0, + &addr as *const _ as *const libc::sockaddr, + mem::size_of::() as u32, + ) + }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(()) + } + + #[tokio::test] + async fn linux_bpf_tun_receives_matching_ipv4_frame() { + if !is_root() { + eprintln!("linux_bpf_tun_receives_matching_ipv4_frame: skipped (not root)"); + return; + } + + let Some((ifname, dst_ip, mac)) = pick_interface_v4() else { + eprintln!("linux_bpf_tun_receives_matching_ipv4_frame: skipped (no suitable iface)"); + return; + }; + + let dst_port: u16 = rand::thread_rng().gen_range(40000..60000); + let dst_addr = SocketAddr::new(IpAddr::V4(dst_ip), dst_port); + eprintln!( + "linux_bpf_tun_receives_matching_ipv4_frame: ifname={ifname} dst_addr={dst_addr} mac={mac}" + ); + + let tun = LinuxBpfTun::new(&ifname, None, dst_addr).unwrap(); + + let src_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 123, 0, 1)), 12345); + eprintln!( + "linux_bpf_tun_receives_matching_ipv4_frame: sending frame src_addr={src_addr} -> dst_addr={dst_addr}" + ); + let frame = build_tcp_packet( + mac, + mac, + src_addr, + dst_addr, + 1, + 0, + TcpFlags::SYN, + Some(b"ping"), + ); + + send_raw_frame(&ifname, &frame).unwrap(); + + let mut received = BytesMut::new(); + let n = timeout(Duration::from_secs(2), tun.recv(&mut received)) + .await + .unwrap() + .unwrap(); + eprintln!( + "linux_bpf_tun_receives_matching_ipv4_frame: received {} bytes", + n + ); + assert_eq!(n, frame.len()); + assert_eq!(&received[..], &frame[..]); + } + + #[tokio::test] + async fn linux_bpf_tun_filters_out_non_matching_ipv4_frame() { + if !is_root() { + eprintln!("linux_bpf_tun_filters_out_non_matching_ipv4_frame: skipped (not root)"); + return; + } + + let Some((ifname, dst_ip, mac)) = pick_interface_v4() else { + eprintln!( + "linux_bpf_tun_filters_out_non_matching_ipv4_frame: skipped (no suitable iface)" + ); + return; + }; + + let dst_port: u16 = rand::thread_rng().gen_range(40000..60000); + let dst_addr = SocketAddr::new(IpAddr::V4(dst_ip), dst_port); + eprintln!( + "linux_bpf_tun_filters_out_non_matching_ipv4_frame: ifname={ifname} dst_addr={dst_addr} mac={mac}" + ); + + let tun = LinuxBpfTun::new(&ifname, None, dst_addr).unwrap(); + + let src_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 123, 0, 2)), 23456); + let non_matching_dst = SocketAddr::new(IpAddr::V4(dst_ip), dst_port.wrapping_add(1)); + eprintln!( + "linux_bpf_tun_filters_out_non_matching_ipv4_frame: sending non-matching src_addr={src_addr} -> dst_addr={non_matching_dst}" + ); + let non_matching = build_tcp_packet( + mac, + mac, + src_addr, + non_matching_dst, + 1, + 0, + TcpFlags::SYN, + Some(b"nope"), + ); + send_raw_frame(&ifname, &non_matching).unwrap(); + + let mut received = BytesMut::new(); + let non_matching_timeout = timeout(Duration::from_millis(400), tun.recv(&mut received)) + .await + .is_err(); + eprintln!( + "linux_bpf_tun_filters_out_non_matching_ipv4_frame: non-matching recv timeout={}", + non_matching_timeout + ); + assert!(non_matching_timeout); + + eprintln!( + "linux_bpf_tun_filters_out_non_matching_ipv4_frame: sending matching src_addr={src_addr} -> dst_addr={dst_addr}" + ); + let matching = build_tcp_packet( + mac, + mac, + src_addr, + dst_addr, + 2, + 0, + TcpFlags::SYN, + Some(b"ok"), + ); + send_raw_frame(&ifname, &matching).unwrap(); + + let mut received2 = BytesMut::new(); + let n = timeout(Duration::from_secs(2), tun.recv(&mut received2)) + .await + .unwrap() + .unwrap(); + eprintln!( + "linux_bpf_tun_filters_out_non_matching_ipv4_frame: received {} bytes", + n + ); + assert_eq!(n, matching.len()); + assert_eq!(&received2[..], &matching[..]); + } +} diff --git a/easytier/src/tunnel/fake_tcp/netfilter/macos_bpf.rs b/easytier/src/tunnel/fake_tcp/netfilter/macos_bpf.rs new file mode 100644 index 00000000..ca252257 --- /dev/null +++ b/easytier/src/tunnel/fake_tcp/netfilter/macos_bpf.rs @@ -0,0 +1,1122 @@ +use bytes::Bytes; +use bytes::BytesMut; +use nix::libc; +use std::ffi::CString; +use std::io; +use std::mem; +use std::net::IpAddr; +use std::net::SocketAddr; +use std::os::fd::{AsRawFd, FromRawFd, OwnedFd}; +use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering}; +use std::sync::Arc; +use tokio::sync::Mutex; +use tracing::{debug, info, warn}; + +use crate::tunnel::fake_tcp::stack; + +const ETH_HDR_LEN: usize = 14; +const ETH_TYPE_OFFSET: u32 = 12; +const ETHERTYPE_IPV4: u32 = 0x0800; +const ETHERTYPE_IPV6: u32 = 0x86DD; +const IPPROTO_TCP_U32: u32 = 6; + +const DLT_EN10MB: u32 = 1; +const DLT_NULL: u32 = 4; +const DLT_RAW: u32 = 12; +const DLT_LOOP: u32 = 108; + +const BPF_LD: u16 = 0x00; +const BPF_LDX: u16 = 0x01; +const BPF_JMP: u16 = 0x05; +const BPF_RET: u16 = 0x06; + +const BPF_W: u16 = 0x00; +const BPF_H: u16 = 0x08; +const BPF_B: u16 = 0x10; + +const BPF_ABS: u16 = 0x20; +const BPF_IND: u16 = 0x40; +const BPF_MSH: u16 = 0xa0; + +const BPF_JA: u16 = 0x00; +const BPF_JEQ: u16 = 0x10; + +const BPF_K: u16 = 0x00; + +const BPF_GROUP: u8 = b'B'; + +const BIOCGBLEN_NUM: u8 = 102; +const BIOCSBLEN_NUM: u8 = 102; +const BIOCSETF_NUM: u8 = 103; +const BIOCFLUSH_NUM: u8 = 104; +const BIOCGDLT_NUM: u8 = 106; +const BIOCSETIF_NUM: u8 = 108; +const BIOCSRTIMEOUT_NUM: u8 = 109; +const BIOCIMMEDIATE_NUM: u8 = 112; +const BIOCSHDRCMPLT_NUM: u8 = 117; +const BIOCSSEESENT_NUM: u8 = 119; + +const IOCPARM_MASK: u32 = 0x1fff; +const IOC_VOID: u32 = 0x2000_0000; +const IOC_OUT: u32 = 0x4000_0000; +const IOC_IN: u32 = 0x8000_0000; +const IOC_INOUT: u32 = IOC_IN | IOC_OUT; + +#[derive(Clone, Copy)] +enum LinkType { + En10Mb, + Null, + Raw, + Loop, + Utun, +} + +impl LinkType { + fn from_dlt(dlt: u32) -> Option { + match dlt { + DLT_EN10MB => Some(Self::En10Mb), + DLT_NULL => Some(Self::Null), + DLT_RAW => Some(Self::Raw), + DLT_LOOP => Some(Self::Loop), + _ => None, + } + } +} + +fn looks_like_ip(packet: &[u8]) -> bool { + matches!(packet.first().map(|b| b >> 4), Some(4 | 6)) +} + +fn maybe_unwrap_utun_payload(packet: &[u8]) -> Option<&[u8]> { + if looks_like_ip(packet) { + return Some(packet); + } + if packet.len() < 5 { + return None; + } + let payload = &packet[4..]; + if !looks_like_ip(payload) { + return None; + } + Some(payload) +} + +fn ether_type_from_ip_packet(ip: &[u8]) -> Option { + let v = *ip.first()?; + match v >> 4 { + 4 => Some(0x0800), + 6 => Some(0x86DD), + _ => None, + } +} + +fn wrap_ip_with_ethernet(ip: &[u8]) -> Option> { + let ether_type = ether_type_from_ip_packet(ip)?; + let mut out = vec![0u8; ETH_HDR_LEN + ip.len()]; + out[12..14].copy_from_slice(ðer_type.to_be_bytes()); + out[ETH_HDR_LEN..].copy_from_slice(ip); + Some(out) +} + +fn family_word_for_null(family: u32) -> u32 { + u32::from_be_bytes(family.to_ne_bytes()) +} + +#[repr(C)] +#[derive(Clone, Copy)] +struct BpfInsn { + code: u16, + jt: u8, + jf: u8, + k: u32, +} + +#[repr(C)] +struct BpfProgram { + bf_len: u32, + bf_insns: *mut BpfInsn, +} + +fn read_u16_ne(bytes: &[u8]) -> u16 { + u16::from_ne_bytes([bytes[0], bytes[1]]) +} + +fn read_u32_ne(bytes: &[u8]) -> u32 { + u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) +} + +fn bpf_word_align_with(align: usize, x: usize) -> usize { + (x + (align - 1)) & !(align - 1) +} + +fn parse_bpf_record(buf: &[u8], align: usize) -> Option<(usize, std::ops::Range, u16, u32)> { + let max_shift = std::cmp::min(align, buf.len()); + for shift in 0..max_shift { + let rest = &buf[shift..]; + + let try_ts8 = || -> Option<(usize, std::ops::Range, u16, u32)> { + let base_hdr_len = 18usize; + if rest.len() < base_hdr_len { + return None; + } + let caplen = read_u32_ne(rest.get(8..12)?) as usize; + let datalen = read_u32_ne(rest.get(12..16)?) as usize; + let hdrlen = read_u16_ne(rest.get(16..18)?) as usize; + if hdrlen < base_hdr_len || hdrlen > 512 { + return None; + } + if caplen > datalen { + return None; + } + let pkt_start = shift + hdrlen; + let pkt_end = pkt_start.checked_add(caplen)?; + if pkt_end > buf.len() { + return None; + } + let advance = shift + bpf_word_align_with(align, hdrlen + caplen); + Some((advance, pkt_start..pkt_end, hdrlen as u16, caplen as u32)) + }; + + if let Some(v) = try_ts8() { + return Some(v); + } + + let try_ts16 = || -> Option<(usize, std::ops::Range, u16, u32)> { + let base_hdr_len = 26usize; + if rest.len() < base_hdr_len { + return None; + } + let caplen = read_u32_ne(rest.get(16..20)?) as usize; + let datalen = read_u32_ne(rest.get(20..24)?) as usize; + let hdrlen = read_u16_ne(rest.get(24..26)?) as usize; + if hdrlen < base_hdr_len || hdrlen > 512 { + return None; + } + if caplen > datalen { + return None; + } + let pkt_start = shift + hdrlen; + let pkt_end = pkt_start.checked_add(caplen)?; + if pkt_end > buf.len() { + return None; + } + let advance = shift + bpf_word_align_with(align, hdrlen + caplen); + Some((advance, pkt_start..pkt_end, hdrlen as u16, caplen as u32)) + }; + + if let Some(v) = try_ts16() { + return Some(v); + } + } + None +} + +fn ioc(inout: u32, group: u8, num: u8, len: u32) -> libc::c_ulong { + (inout | ((len & IOCPARM_MASK) << 16) | ((group as u32) << 8) | (num as u32)) as libc::c_ulong +} + +fn io(group: u8, num: u8) -> libc::c_ulong { + ioc(IOC_VOID, group, num, 0) +} + +fn ior(group: u8, num: u8) -> libc::c_ulong { + ioc(IOC_OUT, group, num, mem::size_of::() as u32) +} + +fn iow(group: u8, num: u8) -> libc::c_ulong { + ioc(IOC_IN, group, num, mem::size_of::() as u32) +} + +fn iowr(group: u8, num: u8) -> libc::c_ulong { + ioc(IOC_INOUT, group, num, mem::size_of::() as u32) +} + +unsafe fn ioctl_ptr(fd: libc::c_int, req: libc::c_ulong, arg: *mut T) -> io::Result<()> { + let ret = libc::ioctl(fd, req, arg); + if ret < 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) +} + +unsafe fn ioctl_void(fd: libc::c_int, req: libc::c_ulong) -> io::Result<()> { + let ret = libc::ioctl(fd, req); + if ret < 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) +} + +fn stmt(code: u16, k: u32) -> BpfInsn { + BpfInsn { + code, + jt: 0, + jf: 0, + k, + } +} + +fn jeq(k: u32, jt: u8, jf: u8) -> BpfInsn { + BpfInsn { + code: BPF_JMP | BPF_JEQ | BPF_K, + jt, + jf, + k, + } +} + +fn ja(k: u32) -> BpfInsn { + BpfInsn { + code: BPF_JMP | BPF_JA, + jt: 0, + jf: 0, + k, + } +} + +#[derive(Clone, Copy)] +struct Label(usize); + +struct JeqPatch { + idx: usize, + t: Label, + f: Label, +} + +struct JaPatch { + idx: usize, + target: Label, +} + +struct BpfBuilder { + insns: Vec, + labels: Vec>, + jeq_patches: Vec, + ja_patches: Vec, +} + +impl BpfBuilder { + fn new() -> Self { + Self { + insns: Vec::new(), + labels: Vec::new(), + jeq_patches: Vec::new(), + ja_patches: Vec::new(), + } + } + + fn new_label(&mut self) -> Label { + let idx = self.labels.len(); + self.labels.push(None); + Label(idx) + } + + fn set_label(&mut self, label: Label) { + self.labels[label.0] = Some(self.insns.len()); + } + + fn push(&mut self, insn: BpfInsn) { + self.insns.push(insn); + } + + fn push_jeq(&mut self, k: u32, t: Label, f: Label) { + let idx = self.insns.len(); + self.insns.push(jeq(k, 0, 0)); + self.jeq_patches.push(JeqPatch { idx, t, f }); + } + + fn push_ja(&mut self, target: Label) { + let idx = self.insns.len(); + self.insns.push(ja(0)); + self.ja_patches.push(JaPatch { idx, target }); + } + + fn finish(mut self) -> io::Result> { + for patch in self.jeq_patches { + let JeqPatch { idx, t, f } = patch; + let cur = idx + 1; + let t_pos = + self.labels.get(t.0).and_then(|v| *v).ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "unresolved label") + })?; + let f_pos = + self.labels.get(f.0).and_then(|v| *v).ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "unresolved label") + })?; + + if t_pos < cur || f_pos < cur { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "backward bpf jump", + )); + } + + let jt: u8 = (t_pos - cur) + .try_into() + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bpf jump too far"))?; + let jf: u8 = (f_pos - cur) + .try_into() + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bpf jump too far"))?; + + self.insns[idx].jt = jt; + self.insns[idx].jf = jf; + } + + for patch in self.ja_patches { + let JaPatch { idx, target } = patch; + let cur = idx + 1; + let t_pos = + self.labels.get(target.0).and_then(|v| *v).ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "unresolved label") + })?; + if t_pos < cur { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "backward bpf jump", + )); + } + self.insns[idx].k = (t_pos - cur) as u32; + } + + Ok(self.insns) + } +} + +fn build_tcp_filter_ethernet( + src_addr: Option, + dst_addr: SocketAddr, +) -> io::Result> { + if let Some(src) = src_addr { + if src.is_ipv4() != dst_addr.is_ipv4() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "src/dst addr family mismatch", + )); + } + } + + let mut b = BpfBuilder::new(); + let l_check_ipv6 = b.new_label(); + let l_ipv4 = b.new_label(); + let l_ipv6 = b.new_label(); + let l_accept = b.new_label(); + let l_reject = b.new_label(); + + b.push(stmt(BPF_LD | BPF_H | BPF_ABS, ETH_TYPE_OFFSET)); + b.push_jeq(ETHERTYPE_IPV4, l_ipv4, l_check_ipv6); + + b.set_label(l_check_ipv6); + b.push_jeq(ETHERTYPE_IPV6, l_ipv6, l_reject); + + if dst_addr.is_ipv4() { + b.set_label(l_ipv4); + let l_v4_proto_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_B | BPF_ABS, (ETH_HDR_LEN + 9) as u32)); + b.push_jeq(IPPROTO_TCP_U32, l_v4_proto_ok, l_reject); + + b.set_label(l_v4_proto_ok); + let dst_ip = match dst_addr.ip() { + IpAddr::V4(ip) => u32::from(ip), + _ => unreachable!(), + }; + let l_v4_dstip_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, (ETH_HDR_LEN + 16) as u32)); + b.push_jeq(dst_ip, l_v4_dstip_ok, l_reject); + + b.set_label(l_v4_dstip_ok); + if let Some(src) = src_addr { + let src_ip = match src.ip() { + IpAddr::V4(ip) => u32::from(ip), + _ => unreachable!(), + }; + let l_v4_srcip_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, (ETH_HDR_LEN + 12) as u32)); + b.push_jeq(src_ip, l_v4_srcip_ok, l_reject); + b.set_label(l_v4_srcip_ok); + } + + b.push(stmt(BPF_LDX | BPF_B | BPF_MSH, ETH_HDR_LEN as u32)); + + let l_v4_dstport_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_H | BPF_IND, (ETH_HDR_LEN + 2) as u32)); + b.push_jeq(dst_addr.port() as u32, l_v4_dstport_ok, l_reject); + + b.set_label(l_v4_dstport_ok); + if let Some(src) = src_addr { + b.push(stmt(BPF_LD | BPF_H | BPF_IND, ETH_HDR_LEN as u32)); + b.push_jeq(src.port() as u32, l_accept, l_reject); + } else { + b.push_ja(l_accept); + } + } else { + b.set_label(l_ipv6); + let l_v6_proto_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_B | BPF_ABS, (ETH_HDR_LEN + 6) as u32)); + b.push_jeq(IPPROTO_TCP_U32, l_v6_proto_ok, l_reject); + + b.set_label(l_v6_proto_ok); + let dst_ip = match dst_addr.ip() { + IpAddr::V6(ip) => ip.octets(), + _ => unreachable!(), + }; + for (i, chunk) in dst_ip.chunks_exact(4).enumerate() { + let off = ETH_HDR_LEN + 24 + (i * 4); + let v = u32::from_be_bytes(chunk.try_into().unwrap()); + let l_v6_dstip_word_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, off as u32)); + b.push_jeq(v, l_v6_dstip_word_ok, l_reject); + b.set_label(l_v6_dstip_word_ok); + } + + if let Some(src) = src_addr { + let src_ip = match src.ip() { + IpAddr::V6(ip) => ip.octets(), + _ => unreachable!(), + }; + for (i, chunk) in src_ip.chunks_exact(4).enumerate() { + let off = ETH_HDR_LEN + 8 + (i * 4); + let v = u32::from_be_bytes(chunk.try_into().unwrap()); + let l_v6_srcip_word_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, off as u32)); + b.push_jeq(v, l_v6_srcip_word_ok, l_reject); + b.set_label(l_v6_srcip_word_ok); + } + } + + let l_v6_dstport_ok = b.new_label(); + b.push(stmt( + BPF_LD | BPF_H | BPF_ABS, + (ETH_HDR_LEN + 40 + 2) as u32, + )); + b.push_jeq(dst_addr.port() as u32, l_v6_dstport_ok, l_reject); + + b.set_label(l_v6_dstport_ok); + if let Some(src) = src_addr { + b.push(stmt(BPF_LD | BPF_H | BPF_ABS, (ETH_HDR_LEN + 40) as u32)); + b.push_jeq(src.port() as u32, l_accept, l_reject); + } else { + b.push_ja(l_accept); + } + } + + b.set_label(l_accept); + b.push(stmt(BPF_RET | BPF_K, 0xFFFF)); + + b.set_label(l_reject); + if dst_addr.is_ipv4() { + b.set_label(l_ipv6); + } else { + b.set_label(l_ipv4); + } + b.push(stmt(BPF_RET | BPF_K, 0)); + + b.finish() +} + +fn build_tcp_filter_ip( + base: u32, + src_addr: Option, + dst_addr: SocketAddr, + family_word: Option, +) -> io::Result> { + if let Some(src) = src_addr { + if src.is_ipv4() != dst_addr.is_ipv4() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "src/dst addr family mismatch", + )); + } + } + + let mut b = BpfBuilder::new(); + let l_accept = b.new_label(); + let l_reject = b.new_label(); + + if let Some(family_word) = family_word { + let l_family_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, 0)); + b.push_jeq(family_word, l_family_ok, l_reject); + b.set_label(l_family_ok); + } + + if dst_addr.is_ipv4() { + let l_v4_proto_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_B | BPF_ABS, base + 9)); + b.push_jeq(IPPROTO_TCP_U32, l_v4_proto_ok, l_reject); + + b.set_label(l_v4_proto_ok); + let dst_ip = match dst_addr.ip() { + IpAddr::V4(ip) => u32::from(ip), + _ => unreachable!(), + }; + let l_v4_dstip_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, base + 16)); + b.push_jeq(dst_ip, l_v4_dstip_ok, l_reject); + + b.set_label(l_v4_dstip_ok); + if let Some(src) = src_addr { + let src_ip = match src.ip() { + IpAddr::V4(ip) => u32::from(ip), + _ => unreachable!(), + }; + let l_v4_srcip_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, base + 12)); + b.push_jeq(src_ip, l_v4_srcip_ok, l_reject); + b.set_label(l_v4_srcip_ok); + } + + b.push(stmt(BPF_LDX | BPF_B | BPF_MSH, base)); + + let l_v4_dstport_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_H | BPF_IND, base + 2)); + b.push_jeq(dst_addr.port() as u32, l_v4_dstport_ok, l_reject); + + b.set_label(l_v4_dstport_ok); + if let Some(src) = src_addr { + b.push(stmt(BPF_LD | BPF_H | BPF_IND, base)); + b.push_jeq(src.port() as u32, l_accept, l_reject); + } else { + b.push_ja(l_accept); + } + } else { + let l_v6_proto_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_B | BPF_ABS, base + 6)); + b.push_jeq(IPPROTO_TCP_U32, l_v6_proto_ok, l_reject); + + b.set_label(l_v6_proto_ok); + let dst_ip = match dst_addr.ip() { + IpAddr::V6(ip) => ip.octets(), + _ => unreachable!(), + }; + for (i, chunk) in dst_ip.chunks_exact(4).enumerate() { + let off = base + 24 + (i * 4) as u32; + let v = u32::from_be_bytes(chunk.try_into().unwrap()); + let l_v6_dstip_word_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, off)); + b.push_jeq(v, l_v6_dstip_word_ok, l_reject); + b.set_label(l_v6_dstip_word_ok); + } + + if let Some(src) = src_addr { + let src_ip = match src.ip() { + IpAddr::V6(ip) => ip.octets(), + _ => unreachable!(), + }; + for (i, chunk) in src_ip.chunks_exact(4).enumerate() { + let off = base + 8 + (i * 4) as u32; + let v = u32::from_be_bytes(chunk.try_into().unwrap()); + let l_v6_srcip_word_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_W | BPF_ABS, off)); + b.push_jeq(v, l_v6_srcip_word_ok, l_reject); + b.set_label(l_v6_srcip_word_ok); + } + } + + let l_v6_dstport_ok = b.new_label(); + b.push(stmt(BPF_LD | BPF_H | BPF_ABS, base + 40 + 2)); + b.push_jeq(dst_addr.port() as u32, l_v6_dstport_ok, l_reject); + + b.set_label(l_v6_dstport_ok); + if let Some(src) = src_addr { + b.push(stmt(BPF_LD | BPF_H | BPF_ABS, base + 40)); + b.push_jeq(src.port() as u32, l_accept, l_reject); + } else { + b.push_ja(l_accept); + } + } + + b.set_label(l_accept); + b.push(stmt(BPF_RET | BPF_K, 0xFFFF)); + + b.set_label(l_reject); + b.push(stmt(BPF_RET | BPF_K, 0)); + + b.finish() +} + +fn build_tcp_filter_utun( + src_addr: Option, + dst_addr: SocketAddr, +) -> io::Result> { + let raw = build_tcp_filter_ip(0, src_addr, dst_addr, None)?; + + let family = if dst_addr.is_ipv4() { + libc::AF_INET as u32 + } else { + libc::AF_INET6 as u32 + }; + let family_hdr = + build_tcp_filter_ip(4, src_addr, dst_addr, Some(family_word_for_null(family)))?; + + if raw.is_empty() { + return Ok(family_hdr); + } + + let mut combined = raw; + if let Some(last) = combined.last_mut() { + if last.code == (BPF_RET | BPF_K) && last.k == 0 { + *last = ja(0); + } else { + combined.push(ja(0)); + } + } + combined.extend(family_hdr); + Ok(combined) +} + +fn build_tcp_filter( + link_type: LinkType, + src_addr: Option, + dst_addr: SocketAddr, +) -> io::Result> { + match link_type { + LinkType::En10Mb => build_tcp_filter_ethernet(src_addr, dst_addr), + LinkType::Raw => build_tcp_filter_ip(0, src_addr, dst_addr, None), + LinkType::Null => { + let family = if dst_addr.is_ipv4() { + libc::AF_INET as u32 + } else { + libc::AF_INET6 as u32 + }; + build_tcp_filter_ip(4, src_addr, dst_addr, Some(family_word_for_null(family))) + } + LinkType::Loop => { + let family = if dst_addr.is_ipv4() { + libc::AF_INET as u32 + } else { + libc::AF_INET6 as u32 + }; + build_tcp_filter_ip(4, src_addr, dst_addr, Some(family)) + } + LinkType::Utun => build_tcp_filter_utun(src_addr, dst_addr), + } +} + +fn open_bpf_device() -> io::Result { + let mut last_err: Option = None; + for i in 0..256 { + let path = format!("/dev/bpf{}", i); + let c_path = CString::new(path.as_str()) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "path"))?; + let fd = unsafe { libc::open(c_path.as_ptr(), libc::O_RDWR) }; + if fd >= 0 { + debug!(path, "opened bpf device"); + return Ok(unsafe { OwnedFd::from_raw_fd(fd) }); + } + let err = io::Error::last_os_error(); + if err.raw_os_error() == Some(libc::EBUSY) { + last_err = Some(err); + continue; + } + last_err = Some(err); + } + Err(last_err + .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotFound, "no available /dev/bpf device"))) +} + +fn set_ifreq_name(ifr: &mut libc::ifreq, interface_name: &str) -> io::Result<()> { + let bytes = interface_name.as_bytes(); + let ifnamsiz = libc::IFNAMSIZ as usize; + if bytes.len() >= ifnamsiz { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "interface name too long", + )); + } + for i in 0..ifnamsiz { + ifr.ifr_name[i] = 0; + } + for (i, &b) in bytes.iter().enumerate() { + ifr.ifr_name[i] = b as libc::c_char; + } + Ok(()) +} + +pub struct MacosBpfTun { + fd: OwnedFd, + link_type: LinkType, + stop: Arc, + worker: Option>, + recv_queue: Mutex>>, +} + +impl MacosBpfTun { + pub fn new( + interface_name: &str, + src_addr: Option, + dst_addr: SocketAddr, + ) -> io::Result { + let fd = open_bpf_device()?; + let raw_fd = fd.as_raw_fd(); + + let mut buf_len: libc::c_uint = 0; + unsafe { + ioctl_ptr( + raw_fd, + ior::(BPF_GROUP, BIOCGBLEN_NUM), + &mut buf_len, + ) + }?; + if buf_len == 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "bpf buffer length is zero", + )); + } + + let mut desired_buf_len: libc::c_uint = buf_len; + let _ = unsafe { + ioctl_ptr( + raw_fd, + iowr::(BPF_GROUP, BIOCSBLEN_NUM), + &mut desired_buf_len, + ) + }; + + let mut immediate: libc::c_uint = 1; + unsafe { + ioctl_ptr( + raw_fd, + iow::(BPF_GROUP, BIOCIMMEDIATE_NUM), + &mut immediate, + ) + }?; + + let mut seesent: libc::c_uint = 0; + unsafe { + ioctl_ptr( + raw_fd, + iow::(BPF_GROUP, BIOCSSEESENT_NUM), + &mut seesent, + ) + }?; + + let mut hdr_complete: libc::c_uint = 1; + match unsafe { + ioctl_ptr( + raw_fd, + iow::(BPF_GROUP, BIOCSHDRCMPLT_NUM), + &mut hdr_complete, + ) + } { + Ok(()) => {} + Err(e) if e.raw_os_error() == Some(libc::EINVAL) => {} + Err(e) => return Err(e), + } + + let timeout = libc::timeval { + tv_sec: 0, + tv_usec: 200_000, + }; + let mut timeout_mut = timeout; + unsafe { + ioctl_ptr( + raw_fd, + iow::(BPF_GROUP, BIOCSRTIMEOUT_NUM), + &mut timeout_mut, + ) + }?; + + unsafe { ioctl_void(raw_fd, io(BPF_GROUP, BIOCFLUSH_NUM)) }?; + + let mut ifr: libc::ifreq = unsafe { mem::zeroed() }; + set_ifreq_name(&mut ifr, interface_name)?; + unsafe { + ioctl_ptr( + raw_fd, + iow::(BPF_GROUP, BIOCSETIF_NUM), + &mut ifr, + ) + }?; + + let mut dlt: libc::c_uint = 0; + unsafe { + ioctl_ptr( + raw_fd, + ior::(BPF_GROUP, BIOCGDLT_NUM), + &mut dlt, + ) + }?; + + let link_type = if interface_name.starts_with("utun") { + LinkType::Utun + } else { + LinkType::from_dlt(dlt as u32).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("unsupported datalink type {}", dlt), + ) + })? + }; + + let filter = build_tcp_filter(link_type, src_addr, dst_addr)?; + + let mut bpf_insns: Vec = filter; + let mut prog = BpfProgram { + bf_len: bpf_insns + .len() + .try_into() + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bpf program too long"))?, + bf_insns: bpf_insns.as_mut_ptr(), + }; + unsafe { + ioctl_ptr( + raw_fd, + iow::(BPF_GROUP, BIOCSETF_NUM), + &mut prog, + ) + }?; + + info!( + interface_name, + ?src_addr, + ?dst_addr, + dlt, + link_type = match link_type { + LinkType::En10Mb => "en10mb", + LinkType::Null => "null", + LinkType::Raw => "raw", + LinkType::Loop => "loop", + LinkType::Utun => "utun", + }, + filter_len = bpf_insns.len(), + buf_len, + desired_buf_len, + "MacosBpfTun created" + ); + + let stop = Arc::new(AtomicBool::new(false)); + let (tx, rx) = tokio::sync::mpsc::channel(1024); + let stop_clone = stop.clone(); + let read_fd = raw_fd; + let worker_link_type = link_type; + let worker = std::thread::spawn(move || { + let mut buf = vec![0u8; desired_buf_len as usize]; + let mut wrap_fail_logs_left: u8 = 5; + let mut bad_record_logs_left: u8 = 8; + let mut shifted_record_logs_left: u8 = 8; + let align = mem::size_of::(); + while !stop_clone.load(AtomicOrdering::Relaxed) { + let n = unsafe { + libc::read(read_fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) + }; + if n < 0 { + let err = io::Error::last_os_error(); + if matches!( + err.kind(), + io::ErrorKind::Interrupted | io::ErrorKind::WouldBlock + ) { + continue; + } + warn!(?err, "MacosBpfTun bpf read failed"); + break; + } + if n == 0 { + continue; + } + let mut off = 0usize; + let n = n as usize; + while off < n { + let window = &buf[off..n]; + let Some((advance, pkt_range, hdr_len, cap_len)) = + parse_bpf_record(window, align) + else { + if bad_record_logs_left > 0 { + bad_record_logs_left -= 1; + let preview_len = std::cmp::min(window.len(), 48); + let preview = &window[..preview_len]; + warn!(off, read_len = n, preview = ?preview, "MacosBpfTun failed to parse bpf records"); + } + break; + }; + + let pkt_start = off + pkt_range.start; + let pkt_end = off + pkt_range.end; + let shift = (pkt_range.start as usize).saturating_sub(hdr_len as usize); + if shift != 0 && shifted_record_logs_left > 0 { + shifted_record_logs_left -= 1; + warn!( + off, + record_start = off + shift, + shift, + hdr_len, + cap_len, + read_len = n, + "MacosBpfTun parsed bpf record with non-zero offset" + ); + } + + let packet = &buf[pkt_start..pkt_end]; + let framed = match worker_link_type { + LinkType::En10Mb => Some(packet.to_vec()), + LinkType::Raw => wrap_ip_with_ethernet(packet), + LinkType::Null | LinkType::Loop => { + if packet.len() < 4 { + None + } else { + wrap_ip_with_ethernet(&packet[4..]) + } + } + LinkType::Utun => { + maybe_unwrap_utun_payload(packet).and_then(wrap_ip_with_ethernet) + } + }; + if let Some(framed) = framed { + if tx.blocking_send(framed).is_err() { + return; + } + } else if wrap_fail_logs_left > 0 { + wrap_fail_logs_left -= 1; + warn!( + link_type = match worker_link_type { + LinkType::En10Mb => "en10mb", + LinkType::Null => "null", + LinkType::Raw => "raw", + LinkType::Loop => "loop", + LinkType::Utun => "utun", + }, + packet_len = packet.len(), + "MacosBpfTun failed to wrap packet" + ); + } + if advance == 0 { + break; + } + off = off.saturating_add(advance); + } + } + }); + + Ok(Self { + fd, + link_type, + stop, + worker: Some(worker), + recv_queue: Mutex::new(rx), + }) + } +} + +impl Drop for MacosBpfTun { + fn drop(&mut self) { + self.stop.store(true, AtomicOrdering::Relaxed); + if let Some(worker) = self.worker.take() { + let _ = worker.join(); + } + } +} + +#[async_trait::async_trait] +impl stack::Tun for MacosBpfTun { + async fn recv(&self, packet: &mut BytesMut) -> Result { + let mut rx = self.recv_queue.lock().await; + match rx.recv().await { + Some(data) => { + packet.extend_from_slice(&data); + Ok(data.len()) + } + None => Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "MacosBpfTun channel closed", + )), + } + } + + #[tracing::instrument(ret, skip(self))] + fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error> { + if packet.len() < ETH_HDR_LEN { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "packet too short", + )); + } + let payload = &packet[ETH_HDR_LEN..]; + let write_all = |ptr: *const u8, len: usize| -> Result<(), std::io::Error> { + let ret = unsafe { libc::write(self.fd.as_raw_fd(), ptr as *const libc::c_void, len) }; + if ret < 0 { + return Err(std::io::Error::last_os_error()); + } + Ok(()) + }; + + let mut out_len = 0usize; + let res = match self.link_type { + LinkType::En10Mb => { + out_len = packet.len(); + write_all(packet.as_ptr(), packet.len()) + } + LinkType::Raw => { + out_len = payload.len(); + write_all(payload.as_ptr(), payload.len()) + } + LinkType::Null | LinkType::Loop | LinkType::Utun => { + let family = match payload.first().map(|b| b >> 4) { + Some(4) => libc::AF_INET as u32, + Some(6) => libc::AF_INET6 as u32, + _ => { + warn!( + first_byte = payload.first().copied(), + payload_len = payload.len(), + "MacosBpfTun try_send invalid ip version" + ); + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid ip version", + )); + } + }; + + let primary_hdr = match self.link_type { + LinkType::Null => family.to_ne_bytes(), + LinkType::Loop => family.to_be_bytes(), + LinkType::Utun => family.to_ne_bytes(), + _ => unreachable!(), + }; + + let mut out = vec![0u8; 4 + payload.len()]; + out[..4].copy_from_slice(&primary_hdr); + out[4..].copy_from_slice(payload); + out_len = out.len(); + + match write_all(out.as_ptr(), out.len()) { + Ok(()) => Ok(()), + Err(e) + if matches!(self.link_type, LinkType::Utun) + && e.raw_os_error() == Some(libc::EINVAL) + && primary_hdr != family.to_be_bytes() => + { + let mut out = vec![0u8; 4 + payload.len()]; + out[..4].copy_from_slice(&family.to_be_bytes()); + out[4..].copy_from_slice(payload); + out_len = out.len(); + write_all(out.as_ptr(), out.len()) + } + Err(e) => Err(e), + } + } + }; + + if let Err(err) = res { + warn!( + ?err, + link_type = match self.link_type { + LinkType::En10Mb => "en10mb", + LinkType::Null => "null", + LinkType::Raw => "raw", + LinkType::Loop => "loop", + LinkType::Utun => "utun", + }, + in_len = packet.len(), + out_len, + "MacosBpfTun bpf write failed" + ); + return Err(err); + } + + Ok(()) + } + + fn driver_type(&self) -> &'static str { + "macos_bpf" + } +} diff --git a/easytier/src/tunnel/fake_tcp/netfilter/mod.rs b/easytier/src/tunnel/fake_tcp/netfilter/mod.rs new file mode 100644 index 00000000..5ed33d7a --- /dev/null +++ b/easytier/src/tunnel/fake_tcp/netfilter/mod.rs @@ -0,0 +1,87 @@ +pub mod pnet; + +use std::{net::SocketAddr, sync::Arc}; + +cfg_if::cfg_if! { + if #[cfg(target_os = "linux")] { + pub mod linux_bpf; + + pub fn create_tun( + interface_name: &str, + src_addr: Option, + dst_addr: SocketAddr, + ) -> Arc { + match linux_bpf::LinuxBpfTun::new(interface_name, src_addr, dst_addr) { + Ok(tun) => Arc::new(tun), + Err(e) => { + tracing::warn!( + ?e, + interface_name, + "LinuxBpfTun init failed, falling back to PnetTun" + ); + Arc::new(pnet::PnetTun::new( + interface_name, + pnet::create_packet_filter(src_addr, dst_addr), + )) + } + } + } + } else if #[cfg(target_os = "macos")] { + pub mod macos_bpf; + + pub fn create_tun( + interface_name: &str, + src_addr: Option, + dst_addr: SocketAddr, + ) -> Arc { + match macos_bpf::MacosBpfTun::new(interface_name, src_addr, dst_addr) { + Ok(tun) => Arc::new(tun), + Err(e) => { + tracing::warn!( + ?e, + interface_name, + "MacosBpfTun init failed, falling back to PnetTun" + ); + Arc::new(pnet::PnetTun::new( + interface_name, + pnet::create_packet_filter(src_addr, dst_addr), + )) + } + } + } + } else if #[cfg(all(windows, any(target_arch = "x86_64", target_arch = "x86")))] { + pub mod windivert; + + pub fn create_tun( + _interface_name: &str, + _src_addr: Option, + local_addr: SocketAddr, + ) -> Arc { + match windivert::WinDivertTun::new(local_addr) { + Ok(tun) => Arc::new(tun), + Err(e) => { + tracing::warn!( + ?e, + ?local_addr, + "WinDivertTun init failed, falling back to PnetTun" + ); + Arc::new(pnet::PnetTun::new( + local_addr.to_string().as_str(), + pnet::create_packet_filter(None, local_addr), + )) + } + } + } + } else { + pub fn create_tun( + interface_name: &str, + src_addr: Option, + dst_addr: SocketAddr, + ) -> Arc { + Arc::new(pnet::PnetTun::new( + interface_name, + pnet::create_packet_filter(src_addr, dst_addr), + )) + } + } +} diff --git a/easytier/src/tunnel/fake_tcp/netfilter/pnet.rs b/easytier/src/tunnel/fake_tcp/netfilter/pnet.rs new file mode 100644 index 00000000..d7246a0c --- /dev/null +++ b/easytier/src/tunnel/fake_tcp/netfilter/pnet.rs @@ -0,0 +1,304 @@ +use std::{ + net::{IpAddr, SocketAddr}, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Weak, + }, +}; + +use bytes::{Bytes, BytesMut}; +use dashmap::DashMap; +use once_cell::sync::Lazy; +use pnet::{ + datalink::{self, DataLinkSender, NetworkInterface}, + packet::{ethernet::EtherTypes, ip::IpNextHeaderProtocols, ipv6::Ipv6Packet}, +}; +use tokio::sync::Mutex; + +use crate::tunnel::fake_tcp::stack; + +type PacketFilter = Box bool + Send + Sync>; + +fn filter_tcp_packet( + packet: &[u8], + src_addr: Option<&SocketAddr>, + dst_addr: Option<&SocketAddr>, +) -> bool { + use pnet::packet::ethernet::EthernetPacket; + use pnet::packet::ipv4::Ipv4Packet; + use pnet::packet::tcp::TcpPacket; + use pnet::packet::Packet; + + let ethernet = if let Some(ethernet) = EthernetPacket::new(packet) { + ethernet + } else { + return false; + }; + + match ethernet.get_ethertype() { + EtherTypes::Ipv4 => { + let ipv4 = if let Some(ipv4) = Ipv4Packet::new(ethernet.payload()) { + ipv4 + } else { + return false; + }; + + if ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp { + return false; + } + + let tcp = if let Some(tcp) = TcpPacket::new(ipv4.payload()) { + tcp + } else { + return false; + }; + + if let Some(src_addr) = src_addr { + if IpAddr::V4(ipv4.get_source()) != src_addr.ip() { + return false; + } + if tcp.get_source() != src_addr.port() { + return false; + } + } + + if let Some(dst_addr) = dst_addr { + if IpAddr::V4(ipv4.get_destination()) != dst_addr.ip() { + return false; + } + if tcp.get_destination() != dst_addr.port() { + return false; + } + } + + tracing::trace!( + ?tcp, + "FakeTcpTunnelListener packet matched filter, dispatching, src_addr: {:?}, dst_addr: {:?}, packet_src_ip: {:?}, packet_dst_ip: {:?}, packet_src_port: {:?}, packet_dst_port: {:?}", + src_addr, + dst_addr, + ipv4.get_source(), + ipv4.get_destination(), + tcp.get_source(), + tcp.get_destination(), + ); + } + EtherTypes::Ipv6 => { + let ipv6 = if let Some(ipv6) = Ipv6Packet::new(ethernet.payload()) { + ipv6 + } else { + return false; + }; + + if ipv6.get_next_header() != IpNextHeaderProtocols::Tcp { + return false; + } + + let tcp = if let Some(tcp) = TcpPacket::new(ipv6.payload()) { + tcp + } else { + return false; + }; + + if let Some(src_addr) = src_addr { + if IpAddr::V6(ipv6.get_source()) != src_addr.ip() { + return false; + } + if tcp.get_source() != src_addr.port() { + return false; + } + } + + if let Some(dst_addr) = dst_addr { + if IpAddr::V6(ipv6.get_destination()) != dst_addr.ip() { + return false; + } + if tcp.get_destination() != dst_addr.port() { + return false; + } + } + + tracing::trace!( + ?tcp, + "FakeTcpTunnelListener packet matched filter, dispatching" + ); + } + _ => return false, + } + + true +} + +pub fn create_packet_filter(src_addr: Option, dst_addr: SocketAddr) -> PacketFilter { + Box::new(move |packet: &[u8]| -> bool { + filter_tcp_packet(packet, src_addr.as_ref(), Some(&dst_addr)) + }) +} + +struct Subscriber { + filter: PacketFilter, + sender: tokio::sync::mpsc::Sender>, +} + +struct InterfaceWorker { + tx: Mutex>, + subscribers: Arc>, +} + +impl InterfaceWorker { + fn new(interface: NetworkInterface) -> Arc { + let (tx, mut rx) = match datalink::channel(&interface, Default::default()) { + Ok(pnet::datalink::Channel::Ethernet(tx, rx)) => (tx, rx), + Ok(_) => panic!("Unhandled channel type"), + Err(e) => panic!( + "An error occurred when creating the datalink channel: {}", + e + ), + }; + + let subscribers = Arc::new(DashMap::::new()); + let subscribers_clone = subscribers.clone(); + + std::thread::spawn(move || { + loop { + match rx.next() { + Ok(packet) => { + // Iterate over subscribers and send packet if filter matches + // Note: DashMap iteration might be slow if many subscribers, but usually few per interface. + // For high performance we might need a better structure or read-copy-update. + for r in subscribers_clone.iter() { + let subscriber = r.value(); + if (subscriber.filter)(packet) { + tracing::trace!( + ?packet, + "InterfaceWorker packet matched filter, dispatching" + ); + // Try send, ignore errors (best effort) + let _ = subscriber.sender.try_send(packet.to_vec()); + } + } + } + Err(e) => { + tracing::error!("InterfaceWorker read error: {}", e); + // If interface goes down, we might need to handle it. + // For now just break and maybe the worker is dead. + break; + } + } + } + }); + + Arc::new(Self { + tx: Mutex::new(tx), + subscribers, + }) + } + + fn subscribe(&self, filter: PacketFilter, sender: tokio::sync::mpsc::Sender>) -> u32 { + static ID_GEN: AtomicU32 = AtomicU32::new(0); + let id = ID_GEN.fetch_add(1, Ordering::Relaxed); + self.subscribers.insert(id, Subscriber { filter, sender }); + id + } + + fn unsubscribe(&self, id: u32) { + self.subscribers.remove(&id); + } +} + +static INTERFACE_MANAGERS: Lazy>> = Lazy::new(DashMap::new); + +fn get_or_create_worker(interface_name: &str) -> Arc { + // Check if we have an active worker + if let Some(worker) = INTERFACE_MANAGERS + .get(interface_name) + .and_then(|w| w.upgrade()) + { + return worker; + } + + // Need to create new worker. + // Lock effectively by using entry API? DashMap entry API might not be enough for complex init. + // Let's use a double-check locking style or just accept race condition (creating two workers and one wins). + // DashMap doesn't support easy "compute_if_absent" with async or heavy logic without blocking the map shard. + + // But creation is rare. + // Let's find interface first. + let interfaces = datalink::interfaces(); + let interface = interfaces + .into_iter() + .find(|iface| iface.name == interface_name) + .expect("Network interface not found"); + + let worker = InterfaceWorker::new(interface); + INTERFACE_MANAGERS.insert(interface_name.to_string(), Arc::downgrade(&worker)); + worker +} + +pub struct PnetTun { + worker: Arc, + subscription_id: u32, + recv_queue: Mutex>>, +} + +impl PnetTun { + pub fn new(interface_name: &str, filter: PacketFilter) -> Self { + tracing::debug!(interface_name, "Creating new PnetTun"); + let worker = get_or_create_worker(interface_name); + let (tx, rx) = tokio::sync::mpsc::channel(1024); + let id = worker.subscribe(filter, tx); + + Self { + worker, + subscription_id: id, + recv_queue: Mutex::new(rx), + } + } +} + +impl Drop for PnetTun { + fn drop(&mut self) { + tracing::debug!(subscription_id = self.subscription_id, "Dropping PnetTun"); + self.worker.unsubscribe(self.subscription_id); + } +} + +#[async_trait::async_trait] +impl stack::Tun for PnetTun { + async fn recv(&self, packet: &mut BytesMut) -> Result { + let mut rx = self.recv_queue.lock().await; + match rx.recv().await { + Some(data) => { + tracing::trace!(?data, "PnetTun received packet"); + packet.extend_from_slice(&data); + Ok(data.len()) + } + None => { + tracing::warn!("PnetTun recv channel closed"); + Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "PnetTun channel closed", + )) + } + } + } + + fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error> { + tracing::trace!(len = packet.len(), "PnetTun try_sending packet"); + // We need async lock for tx. + // try_send is sync. We can use try_lock if available or blocking lock. + // tokio::sync::Mutex::try_lock is available. + if let Ok(mut tx) = self.worker.tx.try_lock() { + tx.send_to(packet, None) + .ok_or(std::io::Error::other("send_to failed"))? + } else { + Err(std::io::Error::new( + std::io::ErrorKind::WouldBlock, + "PnetTun tx lock busy", + )) + } + } + + fn driver_type(&self) -> &'static str { + "pnet" + } +} diff --git a/easytier/src/tunnel/fake_tcp/netfilter/windivert.rs b/easytier/src/tunnel/fake_tcp/netfilter/windivert.rs new file mode 100644 index 00000000..5aac0dcf --- /dev/null +++ b/easytier/src/tunnel/fake_tcp/netfilter/windivert.rs @@ -0,0 +1,196 @@ +use std::cell::UnsafeCell; +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; + +use anyhow::Context as _; +use bytes::{Bytes, BytesMut}; +use tokio::sync::Mutex; +use windivert::error::WinDivertError; +use windivert::packet::WinDivertPacket; +use windivert::prelude::{WinDivertFlags, WinDivertShutdownMode}; +use windivert::{layer, WinDivert}; + +use crate::tunnel::fake_tcp::stack; + +struct WinDivertReader { + inner: UnsafeCell>, +} + +unsafe impl Send for WinDivertReader {} +unsafe impl Sync for WinDivertReader {} + +impl WinDivertReader { + fn new(inner: WinDivert) -> Self { + Self { + inner: UnsafeCell::new(inner), + } + } + + fn recv<'a>( + &self, + buffer: Option<&'a mut [u8]>, + ) -> Result, WinDivertError> { + let inner = unsafe { &*self.inner.get() }; + inner.recv(buffer) + } + + fn shutdown(&self) -> anyhow::Result<()> { + let inner = unsafe { &mut *self.inner.get() }; + inner + .shutdown(WinDivertShutdownMode::Recv) + .with_context(|| "WinDivertReader shutdown failed")?; + Ok(()) + } + + fn close(&self) -> anyhow::Result<()> { + let inner = unsafe { &mut *self.inner.get() }; + inner + .close(windivert::CloseAction::Nothing) + .with_context(|| "WinDivertReader close failed")?; + Ok(()) + } +} + +impl Drop for WinDivertReader { + fn drop(&mut self) { + if let Err(e) = self.close() { + tracing::error!("WinDivertReader close failed: {:?}", e); + } + } +} + +pub struct WinDivertTun { + recv_queue: Mutex>>, + sender: Arc>>, + reader: Arc, +} + +impl Drop for WinDivertTun { + fn drop(&mut self) { + if let Ok(mut sender) = self.sender.lock() { + if let Err(e) = sender.close(windivert::CloseAction::Nothing) { + tracing::error!("WinDivertSender close failed: {:?}", e); + } + } + if let Err(e) = self.reader.shutdown() { + tracing::error!("WinDivertReader shutdown failed: {:?}", e); + } + } +} + +impl WinDivertTun { + pub fn new(local_addr: SocketAddr) -> io::Result { + let (tx, rx) = tokio::sync::mpsc::channel(1024); + + let ip_filter = match local_addr { + SocketAddr::V4(addr) => format!("ip.DstAddr == {}", addr.ip()), + SocketAddr::V6(addr) => format!("ipv6.DstAddr == {}", addr.ip()), + }; + // Filter: DstIP == LocalIP AND TCP. + let filter = format!("{} and tcp", ip_filter); + + // Sniff mode: 1 (WINDIVERT_FLAG_SNIFF) + // Layer: Network (0) + // Priority: 0 + let flags = WinDivertFlags::default().set_sniff(); + let reader = WinDivert::network(&filter, 0, flags) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + let reader = Arc::new(WinDivertReader::new(reader)); + let reader_clone = reader.clone(); + + std::thread::spawn(move || { + let reader = reader_clone; + let mut buffer = vec![0u8; 65536]; + loop { + match reader.recv(Some(&mut buffer)) { + Ok(packet) => { + let data = &packet.data; + + let mut eth_data = vec![0u8; 14 + data.len()]; + // Set EtherType + if data.len() > 0 && data[0] >> 4 == 4 { + eth_data[12] = 0x08; + eth_data[13] = 0x00; + } else { + eth_data[12] = 0x86; + eth_data[13] = 0xDD; + } + eth_data[14..].copy_from_slice(data); + + if let Err(_) = tx.blocking_send(eth_data) { + break; + } + } + Err(_) => { + // log error? + break; + } + } + } + }); + + // Sender: non-sniff, empty filter? + // Use "false" to avoid capturing anything. + // Flags: 0 + let sender = WinDivert::network("false", 0, WinDivertFlags::default()) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + + Ok(Self { + recv_queue: Mutex::new(rx), + sender: Arc::new(std::sync::Mutex::new(sender)), + reader, + }) + } +} + +#[async_trait::async_trait] +impl stack::Tun for WinDivertTun { + async fn recv(&self, packet: &mut BytesMut) -> Result { + let mut rx = self.recv_queue.lock().await; + match rx.recv().await { + Some(data) => { + packet.extend_from_slice(&data); + Ok(data.len()) + } + None => Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "Channel closed", + )), + } + } + + fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error> { + // Strip ethernet header + if packet.len() < 14 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Packet too short", + )); + } + let ip_data = &packet[14..]; + + let Ok(sender) = self.sender.try_lock() else { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "WinDivert sender lock failed", + )); + }; + + let mut pkt = unsafe { WinDivertPacket::::new(ip_data.to_vec()) }; + pkt.address.set_outbound(true); + + sender.send(&pkt).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::Other, + format!("WinDivert send failed: {}", e), + ) + })?; + + Ok(()) + } + + fn driver_type(&self) -> &'static str { + "windivert" + } +} diff --git a/easytier/src/tunnel/fake_tcp/packet.rs b/easytier/src/tunnel/fake_tcp/packet.rs new file mode 100644 index 00000000..0651fd70 --- /dev/null +++ b/easytier/src/tunnel/fake_tcp/packet.rs @@ -0,0 +1,159 @@ +use bytes::{Bytes, BytesMut}; +use pnet::packet::ethernet::{EtherTypes, EthernetPacket, MutableEthernetPacket}; +use pnet::packet::{ip, ipv4, ipv6, tcp}; +use pnet::util::MacAddr; +use std::convert::TryInto; +use std::net::{IpAddr, SocketAddr}; + +const IPV4_HEADER_LEN: usize = 20; +const IPV6_HEADER_LEN: usize = 40; +const TCP_HEADER_LEN: usize = 20; +pub const MAX_PACKET_LEN: usize = 1500; + +#[derive(Debug)] +pub enum IPPacket<'p> { + V4(ipv4::Ipv4Packet<'p>), + V6(ipv6::Ipv6Packet<'p>), +} + +impl IPPacket<'_> { + pub fn get_source(&self) -> IpAddr { + match self { + IPPacket::V4(p) => IpAddr::V4(p.get_source()), + IPPacket::V6(p) => IpAddr::V6(p.get_source()), + } + } + + pub fn get_destination(&self) -> IpAddr { + match self { + IPPacket::V4(p) => IpAddr::V4(p.get_destination()), + IPPacket::V6(p) => IpAddr::V6(p.get_destination()), + } + } +} + +const ETH_HDR_LEN: usize = 14; + +#[allow(clippy::too_many_arguments)] +pub fn build_tcp_packet( + src_mac: MacAddr, + dst_mac: MacAddr, + local_addr: SocketAddr, + remote_addr: SocketAddr, + seq: u32, + ack: u32, + flags: u8, + payload: Option<&[u8]>, +) -> Bytes { + let ip_header_len = match local_addr { + SocketAddr::V4(_) => IPV4_HEADER_LEN, + SocketAddr::V6(_) => IPV6_HEADER_LEN, + }; + let wscale = (flags & tcp::TcpFlags::SYN) != 0; + let tcp_header_len = TCP_HEADER_LEN + if wscale { 4 } else { 0 }; // nop + wscale + let tcp_total_len = tcp_header_len + payload.map_or(0, |payload| payload.len()); + let total_len = ip_header_len + tcp_total_len; + let mut buf = BytesMut::zeroed(ETH_HDR_LEN + total_len); + + let mut eth_buf = buf.split_to(ETH_HDR_LEN); + let mut ip_buf = buf.split_to(ip_header_len); + let mut tcp_buf = buf.split_to(tcp_total_len); + assert_eq!(0, buf.len()); + + let mut tcp = tcp::MutableTcpPacket::new(&mut tcp_buf).unwrap(); + tcp.set_window(0xffff); + tcp.set_source(local_addr.port()); + tcp.set_destination(remote_addr.port()); + tcp.set_sequence(seq); + tcp.set_acknowledgement(ack); + tcp.set_flags(flags); + tcp.set_data_offset(TCP_HEADER_LEN as u8 / 4 + if wscale { 1 } else { 0 }); + if wscale { + let wscale = tcp::TcpOption::wscale(14); + tcp.set_options(&[tcp::TcpOption::nop(), wscale]); + } + + if let Some(payload) = payload { + tcp.set_payload(payload); + } + + let mut ethernet = MutableEthernetPacket::new(&mut eth_buf).unwrap(); + ethernet.set_destination(dst_mac); + ethernet.set_source(src_mac); + ethernet.set_ethertype(EtherTypes::Ipv4); + + match (local_addr, remote_addr) { + (SocketAddr::V4(local), SocketAddr::V4(remote)) => { + let mut v4 = ipv4::MutableIpv4Packet::new(&mut ip_buf).unwrap(); + v4.set_version(4); + v4.set_header_length(IPV4_HEADER_LEN as u8 / 4); + v4.set_next_level_protocol(ip::IpNextHeaderProtocols::Tcp); + v4.set_ttl(64); + v4.set_source(*local.ip()); + v4.set_destination(*remote.ip()); + v4.set_total_length(total_len.try_into().unwrap()); + v4.set_flags(ipv4::Ipv4Flags::DontFragment); + + tcp.set_checksum(tcp::ipv4_checksum( + &tcp.to_immutable(), + &v4.get_source(), + &v4.get_destination(), + )); + + v4.set_checksum(ipv4::checksum(&v4.to_immutable())); + } + (SocketAddr::V6(local), SocketAddr::V6(remote)) => { + let mut v6 = ipv6::MutableIpv6Packet::new(&mut ip_buf).unwrap(); + v6.set_version(6); + v6.set_payload_length(tcp_total_len.try_into().unwrap()); + v6.set_next_header(ip::IpNextHeaderProtocols::Tcp); + v6.set_hop_limit(64); + v6.set_source(*local.ip()); + v6.set_destination(*remote.ip()); + + tcp.set_checksum(tcp::ipv6_checksum( + &tcp.to_immutable(), + &v6.get_source(), + &v6.get_destination(), + )); + } + _ => unreachable!(), + }; + + ip_buf.unsplit(tcp_buf); + eth_buf.unsplit(ip_buf); + eth_buf.freeze() +} + +#[tracing::instrument(ret)] +pub fn parse_ip_packet( + buf: &Bytes, +) -> Option<(MacAddr, MacAddr, IPPacket<'_>, tcp::TcpPacket<'_>)> { + let eth = EthernetPacket::new(buf).unwrap(); + let src_mac = eth.get_source(); + let dst_mac = eth.get_destination(); + + tracing::trace!("Parsing IP packet: {:?}", eth); + + let buf = &buf[ETH_HDR_LEN..]; + if buf[0] >> 4 == 4 { + let v4 = ipv4::Ipv4Packet::new(buf).unwrap(); + if v4.get_next_level_protocol() != ip::IpNextHeaderProtocols::Tcp { + return None; + } + + let tcp = tcp::TcpPacket::new(&buf[IPV4_HEADER_LEN..]).unwrap(); + Some((src_mac, dst_mac, IPPacket::V4(v4), tcp)) + } else if buf[0] >> 4 == 6 { + let v6 = ipv6::Ipv6Packet::new(buf).unwrap(); + if v6.get_next_header() != ip::IpNextHeaderProtocols::Tcp { + return None; + } + + let tcp = tcp::TcpPacket::new(&buf[IPV6_HEADER_LEN..]).unwrap(); + Some((src_mac, dst_mac, IPPacket::V6(v6), tcp)) + } else { + tracing::trace!("Invalid IP version: {}", buf[0] >> 4); + None + } +} diff --git a/easytier/src/tunnel/fake_tcp/stack.rs b/easytier/src/tunnel/fake_tcp/stack.rs new file mode 100644 index 00000000..2392967e --- /dev/null +++ b/easytier/src/tunnel/fake_tcp/stack.rs @@ -0,0 +1,561 @@ +//! A minimum, userspace TCP based datagram stack +//! +//! # Overview +//! +//! `fake-tcp` is a reusable library that implements a minimum TCP stack in +//! user space using the Tun interface. It allows programs to send datagrams +//! as if they are part of a TCP connection. `fake-tcp` has been tested to +//! be able to pass through a variety of NAT and stateful firewalls while +//! fully preserves certain desirable behavior such as out of order delivery +//! and no congestion/flow controls. +//! +//! # Core Concepts +//! +//! The core of the `fake-tcp` crate compose of two structures. [`Stack`] and +//! [`Socket`]. +//! +//! ## [`Stack`] +//! +//! [`Stack`] represents a virtual TCP stack that operates at +//! Layer 3. It is responsible for: +//! +//! * TCP active and passive open and handshake +//! * `RST` handling +//! * Interact with the Tun interface at Layer 3 +//! * Distribute incoming datagrams to corresponding [`Socket`] +//! +//! ## [`Socket`] +//! +//! [`Socket`] represents a TCP connection. It registers the identifying +//! tuple `(src_ip, src_port, dest_ip, dest_port)` inside the [`Stack`] so +//! so that incoming packets can be distributed to the right [`Socket`] with +//! using a channel. It is also what the client should use for +//! sending/receiving datagrams. +//! +//! # Examples +//! +//! Please see [`client.rs`](https://github.com/dndx/phantun/blob/main/phantun/src/bin/client.rs) +//! and [`server.rs`](https://github.com/dndx/phantun/blob/main/phantun/src/bin/server.rs) files +//! from the `phantun` crate for how to use this library in client/server mode, respectively. + +use crate::common::scoped_task::ScopedTask; + +use super::packet::*; +use bytes::{Bytes, BytesMut}; +use crossbeam::atomic::AtomicCell; +use pnet::packet::tcp::TcpOptionNumbers; +use pnet::packet::{tcp, Packet}; +use pnet::util::MacAddr; +use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, RwLock, +}; +use tokio::sync::broadcast; +use tokio::sync::mpsc; +use tokio::time; +use tracing::{info, trace, warn}; + +const TIMEOUT: time::Duration = time::Duration::from_secs(1); +const RETRIES: usize = 6; +const MPMC_BUFFER_LEN: usize = 512; +const MPSC_BUFFER_LEN: usize = 128; +const MAX_UNACKED_LEN: u32 = 128 * 1024 * 1024; // 128MB + +#[async_trait::async_trait] +pub trait Tun: Send + Sync + 'static { + async fn recv(&self, packet: &mut BytesMut) -> Result; + fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error>; + fn driver_type(&self) -> &'static str; +} + +#[derive(Hash, Eq, PartialEq, Clone, Debug)] +struct AddrTuple { + local_addr: SocketAddr, + remote_addr: SocketAddr, +} + +impl AddrTuple { + fn new(local_addr: SocketAddr, remote_addr: SocketAddr) -> AddrTuple { + AddrTuple { + local_addr, + remote_addr, + } + } +} + +struct Shared { + tuples: RwLock>>, + listening: RwLock>, + tun: Arc, + ready: mpsc::Sender, + tuples_purge: broadcast::Sender, +} + +pub struct Stack { + shared: Arc, + local_ip: Ipv4Addr, + local_ip6: Option, + local_mac: MacAddr, + ready: mpsc::Receiver, + reader_task: ScopedTask<()>, +} + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug)] +pub enum State { + Idle, + SynSent, + SynReceived, + Established, +} + +pub struct Socket { + shared: Arc, + tun: Arc, + incoming: flume::Receiver, + local_addr: SocketAddr, + remote_addr: SocketAddr, + local_mac: MacAddr, + remote_mac: AtomicCell>, + seq: AtomicU32, + ack: AtomicU32, + last_ack: AtomicU32, + state: AtomicCell, +} + +/// A socket that represents a unique TCP connection between a server and client. +/// +/// The `Socket` object itself satisfies `Sync` and `Send`, which means it can +/// be safely called within an async future. +/// +/// To close a TCP connection that is no longer needed, simply drop this object +/// out of scope. +impl Socket { + #[allow(clippy::too_many_arguments)] + fn new( + shared: Arc, + tun: Arc, + local_addr: SocketAddr, + remote_addr: SocketAddr, + local_mac: MacAddr, + remote_mac: Option, + ack: Option, + state: State, + ) -> (Socket, flume::Sender) { + let (incoming_tx, incoming_rx) = flume::bounded(MPMC_BUFFER_LEN); + + ( + Socket { + shared, + tun, + incoming: incoming_rx, + local_addr, + remote_addr, + local_mac, + remote_mac: AtomicCell::new(remote_mac), + seq: AtomicU32::new(0), + ack: AtomicU32::new(ack.unwrap_or(0)), + last_ack: AtomicU32::new(ack.unwrap_or(0)), + state: AtomicCell::new(state), + }, + incoming_tx, + ) + } + + fn build_tcp_packet(&self, flags: u8, payload: Option<&[u8]>) -> Bytes { + let ack = self.ack.load(Ordering::Relaxed); + self.last_ack.store(ack, Ordering::Relaxed); + + build_tcp_packet( + self.local_mac, + self.remote_mac.load().unwrap_or(MacAddr::zero()), + self.local_addr, + self.remote_addr, + self.seq.load(Ordering::Relaxed), + ack, + flags, + payload, + ) + } + + /// Sends a datagram to the other end. + /// + /// This method takes `&self`, and it can be called safely by multiple threads + /// at the same time. + /// + /// A return of `None` means the Tun socket returned an error + /// and this socket must be closed. + pub fn try_send(&self, payload: &[u8]) -> Option<()> { + match self.state.load() { + State::Established => { + let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload)); + self.seq.fetch_add(payload.len() as u32, Ordering::Relaxed); + self.tun.try_send(&buf).ok().and(Some(())) + } + _ => unreachable!(), + } + } + + pub fn close(&self) { + if self.state.load() != State::Idle { + let buf = self.build_tcp_packet(tcp::TcpFlags::RST, None); + let _ = self.tun.try_send(&buf); + self.state.store(State::Idle); + } + } + + pub async fn recv_bytes(&self) -> Option> { + let mut buf = [0u8; 2048]; + self.recv(&mut buf).await.map(|size| buf[..size].to_vec()) + } + + /// Attempt to receive a datagram from the other end. + /// + /// This method takes `&self`, and it can be called safely by multiple threads + /// at the same time. + /// + /// A return of `None` means the TCP connection is broken + /// and this socket must be closed. + pub async fn recv(&self, buf: &mut [u8]) -> Option { + tracing::trace!( + "Socket recv called, local_addr: {:?}, remote_addr: {:?}", + self.local_addr, + self.remote_addr + ); + loop { + match self.state.load() { + State::Established => { + let Ok(raw_buf) = self.incoming.recv_async().await else { + info!("Connection {} recv error", self); + return None; + }; + + let (src_mac, dst_mac, _v4_packet, tcp_packet) = + parse_ip_packet(&raw_buf).unwrap(); + + tracing::trace!( + "Socket received TCP packet from {}({:?}) to {}({:?}): {:?}", + self.remote_addr, + src_mac, + self.local_addr, + dst_mac, + tcp_packet + ); + + self.remote_mac.store(Some(src_mac)); + + if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 { + info!("Connection {} reset by peer", self); + return None; + } + + if (tcp_packet.get_flags() & tcp::TcpFlags::ACK) != 0 + && tcp_packet.payload().is_empty() + { + self.seq + .store(tcp_packet.get_acknowledgement(), Ordering::Relaxed); + } + + let payload = tcp_packet.payload(); + + let new_ack = tcp_packet.get_sequence().wrapping_add(payload.len() as u32); + self.ack.store(new_ack, Ordering::Relaxed); + + for opt in tcp_packet.get_options_iter() { + if opt.get_number() == TcpOptionNumbers::SACK { + // SACK 选项类型为 5 + let payload = opt.payload(); + for chunk in payload.chunks(8) { + if chunk.len() != 8 { + continue; + } + let left = tcp_packet.get_acknowledgement(); + let right = u32::from_be_bytes(chunk[0..4].try_into().unwrap()); + let len = right.wrapping_sub(left); + + let sack_end = u32::from_be_bytes(chunk[4..8].try_into().unwrap()); + if len == 0 || sack_end <= left { + continue; + } + + let send_len = std::cmp::min(len, 1400) as usize; + let data = vec![0u8; send_len]; + + let buf = build_tcp_packet( + self.local_mac, + self.remote_mac.load().unwrap_or(MacAddr::zero()), + self.local_addr, + self.remote_addr, + left, + self.ack.load(Ordering::Relaxed), + tcp::TcpFlags::ACK, + Some(&data), + ); + + if let Err(e) = self.tun.try_send(&buf) { + tracing::error!("Failed to send SACK response: {}", e); + } + break; + } + } + } + + if payload.is_empty() { + continue; + } + + if payload.len() >= buf.len() { + tracing::warn!( + "Payload len {} > buf len {}, tcp: {:?}, payload: {:?}", + payload.len(), + buf.len(), + tcp_packet, + payload + ); + continue; + } + + buf[..payload.len()].copy_from_slice(payload); + + return Some(payload.len()); + } + State::SynSent => { + let Ok(Ok(buf)) = time::timeout(TIMEOUT, self.incoming.recv_async()).await + else { + info!("Waiting for client SYN + ACK timed out"); + return None; + }; + let (src_mac, _dst_mac, _v4_packet, tcp_packet) = + parse_ip_packet(&buf).unwrap(); + + if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 { + tracing::trace!("Connection {} reset by peer", self); + return None; + } + + let expected_flag = tcp::TcpFlags::SYN | tcp::TcpFlags::ACK; + if (tcp_packet.get_flags() & expected_flag) == expected_flag { + // found our SYN + ACK + self.seq + .store(tcp_packet.get_acknowledgement(), Ordering::Relaxed); + self.ack + .store(tcp_packet.get_sequence() + 1, Ordering::Relaxed); + self.remote_mac.store(Some(src_mac)); + self.state.store(State::Established); + return Some(0); + } + } + + _ => unreachable!(), + } + } + } + + pub fn local_addr(&self) -> SocketAddr { + self.local_addr + } + + pub fn remote_addr(&self) -> SocketAddr { + self.remote_addr + } +} + +impl Drop for Socket { + /// Drop the socket and close the TCP connection + fn drop(&mut self) { + let tuple = AddrTuple::new(self.local_addr, self.remote_addr); + // dissociates ourself from the dispatch map + assert!(self.shared.tuples.write().unwrap().remove(&tuple).is_some()); + // purge cache + let _ = self.shared.tuples_purge.send(tuple); + + let buf = build_tcp_packet( + self.local_mac, + self.remote_mac.load().unwrap_or(MacAddr::zero()), + self.local_addr, + self.remote_addr, + self.seq.load(Ordering::Relaxed), + 0, + tcp::TcpFlags::RST, + None, + ); + if let Err(e) = self.tun.try_send(&buf) { + warn!("Unable to send RST to remote end: {}", e); + } + + info!("Fake TCP connection to {} closed", self); + } +} + +impl fmt::Display for Socket { + /// User-friendly string representation of the socket + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "(Fake TCP connection from {} to {})", + self.local_addr, self.remote_addr + ) + } +} + +/// A userspace TCP state machine +impl Stack { + /// Create a new stack, `tun` is an array of [`Tun`](tokio_tun::Tun). + /// When more than one [`Tun`](tokio_tun::Tun) object is passed in, same amount + /// of reader will be spawned later. This allows user to utilize the performance + /// benefit of Multiqueue Tun support on machines with SMP. + pub fn new( + tun: Arc, + local_ip: Ipv4Addr, + local_ip6: Option, + local_mac: Option, + ) -> Stack { + let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN); + let (tuples_purge_tx, _tuples_purge_rx) = broadcast::channel(16); + let shared = Arc::new(Shared { + tuples: RwLock::new(HashMap::new()), + tun: tun.clone(), + listening: RwLock::new(HashSet::new()), + ready: ready_tx, + tuples_purge: tuples_purge_tx.clone(), + }); + + let t = tokio::spawn(Stack::reader_task( + tun, + shared.clone(), + tuples_purge_tx.subscribe(), + )); + + Stack { + shared, + local_ip, + local_ip6, + local_mac: local_mac.unwrap_or(MacAddr::zero()), + ready: ready_rx, + reader_task: t.into(), + } + } + + /// Returns the driver type of the stack. + pub fn driver_type(&self) -> &'static str { + self.shared.tun.driver_type() + } + + /// Listens for incoming connections on the given `port`. + pub fn listen(&mut self, port: u16) { + assert!(self.shared.listening.write().unwrap().insert(port)); + } + + /// Accepts an incoming connection. + pub async fn accept(&mut self) -> Socket { + self.ready.recv().await.unwrap() + } + + pub async fn alloc_established_socket( + &mut self, + local_addr: SocketAddr, + remote_addr: SocketAddr, + state: State, + ) -> Socket { + let tuple = AddrTuple::new(local_addr, remote_addr); + let mut tuples = self.shared.tuples.write().unwrap(); + let (sock, incoming) = Socket::new( + self.shared.clone(), + // self.shared.tun.choose(&mut rng).unwrap().clone(), + self.shared.tun.clone(), // Simplification: just use the first tun + local_addr, + remote_addr, + self.local_mac, + None, + Some(0), // Initial ACK + state, + ); + assert!(tuples.insert(tuple, incoming).is_none()); + sock + } + + async fn reader_task( + tun: Arc, + shared: Arc, + mut tuples_purge: broadcast::Receiver, + ) { + let mut tuples: HashMap> = HashMap::new(); + + loop { + let mut buf = BytesMut::new(); + + tokio::select! { + size = tun.recv(&mut buf) => { + let size = size.unwrap(); + tracing::trace!(len = size, ?buf, "PnetTun received packet"); + let buf = buf.split().freeze(); + + match parse_ip_packet(&buf) { + Some((_src_mac, _dst_mac, ip_packet, tcp_packet)) => { + let local_addr = SocketAddr::new( + ip_packet.get_destination(), + tcp_packet.get_destination(), + ); + let remote_addr = SocketAddr::new( + ip_packet.get_source(), + tcp_packet.get_source(), + ); + + let tuple = AddrTuple::new(local_addr, remote_addr); + if let Some(c) = tuples.get(&tuple) { + if c.send_async(buf).await.is_err() { + trace!("Cache hit, but receiver already closed, dropping packet"); + } + + continue; + + // If not Ok, receiver has been closed and just fall through to the slow + // path below + } else { + trace!("Cache miss, checking the shared tuples table for connection"); + let sender = { + let tuples = shared.tuples.read().unwrap(); + tuples.get(&tuple).cloned() + }; + + if let Some(c) = sender { + trace!("Storing connection information into local tuples"); + tuples.insert(tuple, c.clone()); + if let Err(e) = c.send_async(buf).await { + trace!("Error sending packet to connection: {:?}", e); + } + continue; + } + } + + if tcp_packet.get_flags() == tcp::TcpFlags::SYN + && shared + .listening + .read() + .unwrap() + .contains(&tcp_packet.get_destination()) + { + trace!(?tcp_packet, "Received SYN packet for port {}, ignoring", tcp_packet.get_destination()); + continue; + } else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 { + info!("Unknown RST TCP packet from {}, ignoring", remote_addr); + continue; + } + } + None => { + trace!("Dropping packet with no IP/TCP header"); + continue; + } + } + }, + tuple = tuples_purge.recv() => { + let tuple = tuple.unwrap(); + tuples.remove(&tuple); + trace!("Removed cached tuple: {:?}", tuple); + } + } + } + } +} diff --git a/easytier/src/tunnel/mod.rs b/easytier/src/tunnel/mod.rs index 5139f9af..20c38e09 100644 --- a/easytier/src/tunnel/mod.rs +++ b/easytier/src/tunnel/mod.rs @@ -15,6 +15,7 @@ use self::packet_def::ZCPacket; pub mod buf; pub mod common; +pub mod fake_tcp; pub mod filter; pub mod mpsc; pub mod packet_def; @@ -23,8 +24,14 @@ pub mod stats; pub mod tcp; pub mod udp; -pub const PROTO_PORT_OFFSET: &[(&str, u16)] = - &[("tcp", 0), ("udp", 0), ("wg", 1), ("ws", 1), ("wss", 2)]; +pub const PROTO_PORT_OFFSET: &[(&str, u16)] = &[ + ("tcp", 0), + ("udp", 0), + ("wg", 1), + ("ws", 1), + ("wss", 2), + ("faketcp", 3), +]; #[cfg(feature = "wireguard")] pub mod wireguard; @@ -139,7 +146,9 @@ pub trait TunnelConnector: Send { pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url { if let Ok(sock_addr) = addr.parse::() { - let mut ret_url = url::Url::parse(format!("{}://0.0.0.0", scheme).as_str()).unwrap(); + let url_str = format!("{}://0.0.0.0", scheme); + let mut ret_url = url::Url::parse(url_str.as_str()) + .unwrap_or_else(|_| panic!("invalid url: {}", url_str)); ret_url.set_ip_host(sock_addr.ip()).unwrap(); ret_url.set_port(Some(sock_addr.port())).unwrap(); ret_url @@ -200,6 +209,7 @@ fn default_port(scheme: &str) -> Option { "udp" => Some(11010), "ws" => Some(11011), "wss" => Some(11012), + "faketcp" => Some(11013), "quic" => Some(11012), "wg" => Some(11011), _ => None, diff --git a/easytier/third_party/arm64/WinDivert64.sys b/easytier/third_party/arm64/WinDivert64.sys new file mode 100644 index 00000000..ea9865f3 --- /dev/null +++ b/easytier/third_party/arm64/WinDivert64.sys @@ -0,0 +1 @@ +WinDivert doesn't support aarch64, this is a placeholder file to make tauri happy. diff --git a/easytier/third_party/i686/WinDivert32.sys b/easytier/third_party/i686/WinDivert32.sys new file mode 100644 index 00000000..d06738cb Binary files /dev/null and b/easytier/third_party/i686/WinDivert32.sys differ diff --git a/easytier/third_party/Packet.dll b/easytier/third_party/x86_64/Packet.dll similarity index 100% rename from easytier/third_party/Packet.dll rename to easytier/third_party/x86_64/Packet.dll diff --git a/easytier/third_party/Packet.lib b/easytier/third_party/x86_64/Packet.lib similarity index 100% rename from easytier/third_party/Packet.lib rename to easytier/third_party/x86_64/Packet.lib diff --git a/easytier/third_party/x86_64/WinDivert64.sys b/easytier/third_party/x86_64/WinDivert64.sys new file mode 100644 index 00000000..218ccaf4 Binary files /dev/null and b/easytier/third_party/x86_64/WinDivert64.sys differ diff --git a/easytier/third_party/wintun.dll b/easytier/third_party/x86_64/wintun.dll similarity index 100% rename from easytier/third_party/wintun.dll rename to easytier/third_party/x86_64/wintun.dll