mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-06 17:59:11 +00:00
Add fake tcp tunnel (experimental) (#1673)
support faketcp to avoid tcp-over-tcp problem. linux/macos/windows are supported. better to be used in internet env, the maximum performance is majorly limited by windivert/raw socket.
This commit is contained in:
@@ -12,8 +12,9 @@ use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, idn, network::IPCollector},
|
||||
tunnel::{
|
||||
check_scheme_and_get_socket_addr, ring::RingTunnelConnector, tcp::TcpTunnelConnector,
|
||||
udp::UdpTunnelConnector, IpVersion, TunnelConnector,
|
||||
check_scheme_and_get_socket_addr, fake_tcp::FakeTcpTunnelConnector,
|
||||
ring::RingTunnelConnector, tcp::TcpTunnelConnector, udp::UdpTunnelConnector, IpVersion,
|
||||
TunnelConnector,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -157,6 +158,20 @@ pub async fn create_connector_by_url(
|
||||
let connector = dns_connector::DNSTunnelConnector::new(url, global_ctx.clone());
|
||||
Box::new(connector)
|
||||
}
|
||||
"faketcp" => {
|
||||
let dst_addr =
|
||||
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "faketcp", ip_version).await?;
|
||||
let mut connector = FakeTcpTunnelConnector::new(url);
|
||||
if global_ctx.config.get_flags().bind_device {
|
||||
set_bind_addr_for_peer_connector(
|
||||
&mut connector,
|
||||
dst_addr.is_ipv4(),
|
||||
&global_ctx.get_ip_collector(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Box::new(connector)
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::InvalidUrl(url.into()));
|
||||
}
|
||||
|
||||
@@ -21,8 +21,8 @@ use crate::{
|
||||
},
|
||||
peers::peer_manager::PeerManager,
|
||||
tunnel::{
|
||||
ring::RingTunnelListener, tcp::TcpTunnelListener, udp::UdpTunnelListener, Tunnel,
|
||||
TunnelListener,
|
||||
fake_tcp::FakeTcpTunnelListener, ring::RingTunnelListener, tcp::TcpTunnelListener,
|
||||
udp::UdpTunnelListener, Tunnel, TunnelListener,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -49,6 +49,7 @@ pub fn get_listener_by_url(
|
||||
use crate::tunnel::websocket::WSTunnelListener;
|
||||
Box::new(WSTunnelListener::new(l.clone()))
|
||||
}
|
||||
"faketcp" => Box::new(FakeTcpTunnelListener::new(l.clone())),
|
||||
_ => {
|
||||
return Err(Error::InvalidUrl(l.to_string()));
|
||||
}
|
||||
@@ -143,7 +144,7 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
|
||||
&& !is_url_host_ipv6(&l)
|
||||
&& is_url_host_unspecified(&l)
|
||||
// quic enables dual-stack by default, may conflict with v4 listener
|
||||
&& l.scheme() != "quic"
|
||||
&& l.scheme() != "quic" && l.scheme() != "faketcp"
|
||||
{
|
||||
let mut ipv6_listener = l.clone();
|
||||
ipv6_listener
|
||||
|
||||
@@ -223,8 +223,8 @@ impl PeerConn {
|
||||
|
||||
if peer_mgr_hdr.packet_type != PacketType::HandShake as u8 {
|
||||
return Err(Error::WaitRespError(format!(
|
||||
"unexpected packet type: {:?}",
|
||||
peer_mgr_hdr.packet_type
|
||||
"unexpected packet type: {:?}, packet: {:?}",
|
||||
peer_mgr_hdr.packet_type, rsp
|
||||
)));
|
||||
}
|
||||
|
||||
|
||||
@@ -534,6 +534,11 @@ pub mod tests {
|
||||
let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();
|
||||
println!("connect: {:?}", tunnel.info());
|
||||
|
||||
if connector.remote_url().scheme() == "faketcp" {
|
||||
// listener need some time to start capturing packet
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
url::Url::from(tunnel.info().unwrap().remote_addr.unwrap()),
|
||||
connector.remote_url(),
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
Copyright 2021-2025 Datong Sun dndx@idndx.com
|
||||
|
||||
Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
|
||||
https://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
|
||||
https://opensource.org/licenses/MIT>, at your option. Files in the project may
|
||||
not be copied, modified, or distributed except according to those terms.
|
||||
@@ -0,0 +1,482 @@
|
||||
mod netfilter;
|
||||
mod packet;
|
||||
mod stack;
|
||||
|
||||
use std::net::{IpAddr, Ipv4Addr, UdpSocket};
|
||||
use std::sync::Arc;
|
||||
use std::{net::SocketAddr, pin::Pin};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use pnet::datalink;
|
||||
use pnet::util::MacAddr;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::common::scoped_task::ScopedTask;
|
||||
use crate::tunnel::fake_tcp::netfilter::create_tun;
|
||||
use crate::tunnel::{common::TunnelWrapper, Tunnel, TunnelError, TunnelInfo, TunnelListener};
|
||||
|
||||
use futures::Future;
|
||||
|
||||
use dashmap::DashMap;
|
||||
|
||||
struct IpToIfNameCache {
|
||||
ip_to_ifname: DashMap<IpAddr, (String, Option<MacAddr>)>,
|
||||
}
|
||||
|
||||
impl IpToIfNameCache {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
ip_to_ifname: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reload_ip_to_ifname(&self) {
|
||||
self.ip_to_ifname.clear();
|
||||
let interfaces = datalink::interfaces();
|
||||
for iface in interfaces {
|
||||
for ip in iface.ips.iter() {
|
||||
self.ip_to_ifname
|
||||
.insert(ip.ip(), (iface.name.clone(), iface.mac));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_ifname(&self, ip: &IpAddr) -> Option<(String, Option<MacAddr>)> {
|
||||
if let Some(ifname) = self.ip_to_ifname.get(ip) {
|
||||
Some(ifname.clone())
|
||||
} else {
|
||||
self.reload_ip_to_ifname();
|
||||
self.ip_to_ifname.get(ip).map(|s| s.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_faketcp_tunnel_type_str(driver_type: &str) -> String {
|
||||
format!("faketcp_{}", driver_type)
|
||||
}
|
||||
|
||||
pub struct FakeTcpTunnelListener {
|
||||
addr: url::Url,
|
||||
os_listener: Option<tokio::net::TcpListener>,
|
||||
// interface_name -> fake tcp stack
|
||||
stack_map: DashMap<String, Arc<Mutex<stack::Stack>>>,
|
||||
// a cache from ip addr to interface name
|
||||
ip_to_ifname: IpToIfNameCache,
|
||||
}
|
||||
|
||||
impl FakeTcpTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
// Define filter: Capture all packets (or refine this if needed)
|
||||
// For FakeTCP, we probably want to capture packets destined to us?
|
||||
// But `stack::Stack` handles IP/TCP logic.
|
||||
// Maybe we just capture everything for now as a raw tunnel?
|
||||
// Or better, filter based on some criteria?
|
||||
// The user said "satisfy filter function".
|
||||
// Let's create a filter that accepts everything for now, or maybe only IP packets?
|
||||
FakeTcpTunnelListener {
|
||||
addr,
|
||||
os_listener: None,
|
||||
stack_map: DashMap::new(),
|
||||
ip_to_ifname: IpToIfNameCache::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_accept(&mut self) -> Result<AcceptResult, TunnelError> {
|
||||
loop {
|
||||
match self.os_listener.as_mut().unwrap().accept().await {
|
||||
Ok((s, remote_addr)) => {
|
||||
let Ok(local_addr) = s.local_addr() else {
|
||||
tracing::warn!("accept fail with local_addr error");
|
||||
continue;
|
||||
};
|
||||
let Some((interface_name, mac)) =
|
||||
self.ip_to_ifname.get_ifname(&local_addr.ip())
|
||||
else {
|
||||
tracing::warn!("accept fail with interface_name error");
|
||||
continue;
|
||||
};
|
||||
return Ok(AcceptResult {
|
||||
socket: s,
|
||||
local_addr,
|
||||
remote_addr,
|
||||
interface_name,
|
||||
mac,
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
use std::io::ErrorKind::*;
|
||||
if matches!(
|
||||
e.kind(),
|
||||
NotConnected | ConnectionAborted | ConnectionRefused | ConnectionReset
|
||||
) {
|
||||
tracing::warn!(?e, "accept fail with retryable error: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
tracing::warn!(?e, "accept fail");
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_stack(
|
||||
&self,
|
||||
accept_result: &AcceptResult,
|
||||
) -> Result<Arc<Mutex<stack::Stack>>, TunnelError> {
|
||||
let local_socket_addr = accept_result.local_addr;
|
||||
|
||||
let interface_name = &accept_result.interface_name;
|
||||
|
||||
let (local_ip, local_ip6) = match local_socket_addr.ip() {
|
||||
IpAddr::V4(ip) => (Some(ip), None),
|
||||
IpAddr::V6(ip) => (None, Some(ip)),
|
||||
};
|
||||
|
||||
let ret = self
|
||||
.stack_map
|
||||
.entry(interface_name.to_string())
|
||||
.or_insert_with(|| {
|
||||
let tun = create_tun(interface_name, None, local_socket_addr);
|
||||
tracing::info!(
|
||||
?local_socket_addr,
|
||||
"create new stack with interface_name: {:?}",
|
||||
interface_name
|
||||
);
|
||||
// TODO: Get local MAC address of the interface
|
||||
Arc::new(Mutex::new(stack::Stack::new(
|
||||
tun,
|
||||
local_ip.unwrap_or(Ipv4Addr::UNSPECIFIED),
|
||||
local_ip6,
|
||||
accept_result.mac,
|
||||
)))
|
||||
})
|
||||
.clone();
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
fn build_os_socket_reader_task(mut socket: TcpStream) -> ScopedTask<()> {
|
||||
let os_socket_reader_task: ScopedTask<()> = tokio::spawn(async move {
|
||||
// read the os socket until it's closed
|
||||
let mut buf = [0u8; 1024];
|
||||
while let Ok(size) = socket.read(&mut buf).await {
|
||||
tracing::trace!("read {} bytes from os socket", size);
|
||||
if size == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
tracing::info!("FakeTcpTunnelListener os socket closed");
|
||||
})
|
||||
.into();
|
||||
os_socket_reader_task
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AcceptResult {
|
||||
socket: TcpStream,
|
||||
local_addr: SocketAddr,
|
||||
remote_addr: SocketAddr,
|
||||
interface_name: String,
|
||||
mac: Option<MacAddr>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TunnelListener for FakeTcpTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
let port = self.addr.port().unwrap_or(0);
|
||||
let bind_addr = crate::tunnel::check_scheme_and_get_socket_addr::<SocketAddr>(
|
||||
&self.addr,
|
||||
"faketcp",
|
||||
crate::tunnel::IpVersion::Both,
|
||||
)
|
||||
.await?;
|
||||
let os_listener = tokio::net::TcpListener::bind(bind_addr).await?;
|
||||
tracing::info!(port, "FakeTcpTunnelListener listening");
|
||||
self.os_listener = Some(os_listener);
|
||||
// self.stack.lock().await.listen(port);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
tracing::debug!("FakeTcpTunnelListener waiting for accept");
|
||||
let res = self.do_accept().await?;
|
||||
let stack = self.get_stack(&res).await?;
|
||||
let socket = stack
|
||||
.lock()
|
||||
.await
|
||||
.alloc_established_socket(res.local_addr, res.remote_addr, stack::State::Established)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
?res,
|
||||
remote = socket.remote_addr().to_string(),
|
||||
"FakeTcpTunnelListener accepted connection"
|
||||
);
|
||||
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: get_faketcp_tunnel_type_str(stack.lock().await.driver_type()),
|
||||
local_addr: Some(self.local_url().into()),
|
||||
remote_addr: Some(
|
||||
crate::tunnel::build_url_from_socket_addr(
|
||||
&socket.remote_addr().to_string(),
|
||||
"faketcp",
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
};
|
||||
|
||||
// We treat the fake tcp socket as a datagram tunnel directly
|
||||
// The reader/writer will interface with the socket using recv_bytes/send
|
||||
// We need to adapt the socket to ZCPacketStream and ZCPacketSink
|
||||
|
||||
// Since FakeTcpTunnel is a datagram tunnel, we don't need FramedReader/Writer (which are for stream based tunnels like TCP)
|
||||
// We should wrap the socket into something that produces/consumes ZCPacket directly.
|
||||
|
||||
let socket = Arc::new(socket);
|
||||
let reader = FakeTcpStream::new(socket.clone());
|
||||
let writer = FakeTcpSink::new(socket);
|
||||
|
||||
Ok(Box::new(TunnelWrapper::new_with_associate_data(
|
||||
reader,
|
||||
writer,
|
||||
Some(info),
|
||||
Some(Box::new(build_os_socket_reader_task(res.socket))),
|
||||
)))
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FakeTcpTunnelConnector {
|
||||
addr: url::Url,
|
||||
ip_to_if_name: IpToIfNameCache,
|
||||
}
|
||||
|
||||
impl FakeTcpTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
FakeTcpTunnelConnector {
|
||||
addr,
|
||||
ip_to_if_name: IpToIfNameCache::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_local_ip_for_destination(destination: IpAddr) -> Option<IpAddr> {
|
||||
// 使用一个不可路由的、私有的、或回环地址创建一个临时的 socket,让内核自动选择源接口。
|
||||
// 对于 IPv4,使用 0.0.0.0; 对于 IPv6,使用 ::
|
||||
let bind_addr = if destination.is_ipv4() {
|
||||
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))
|
||||
} else {
|
||||
IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0))
|
||||
};
|
||||
|
||||
// 绑定到一个临时端口 (0)
|
||||
let socket = UdpSocket::bind((bind_addr, 0)).ok()?;
|
||||
|
||||
// 尝试连接到目标地址。这不会真正发送数据包,只是让内核确定路由。
|
||||
socket.connect((destination, 80)).ok()?; // 使用一个常见的端口,例如 80
|
||||
|
||||
// 获取 socket 的本地地址信息
|
||||
socket.local_addr().map(|addr| addr.ip()).ok()
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl crate::tunnel::TunnelConnector for FakeTcpTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
let remote_addr = crate::tunnel::check_scheme_and_get_socket_addr::<SocketAddr>(
|
||||
&self.addr,
|
||||
"faketcp",
|
||||
crate::tunnel::IpVersion::Both,
|
||||
)
|
||||
.await?;
|
||||
let local_ip = get_local_ip_for_destination(remote_addr.ip())
|
||||
.ok_or(TunnelError::InternalError("Failed to get local ip".into()))?;
|
||||
|
||||
let os_socket = tokio::net::TcpSocket::new_v4()?;
|
||||
os_socket.bind("0.0.0.0:0".parse().unwrap())?;
|
||||
let local_port = os_socket.local_addr()?.port();
|
||||
let local_addr = SocketAddr::new(local_ip, local_port);
|
||||
|
||||
let (interface_name, mac) =
|
||||
self.ip_to_if_name
|
||||
.get_ifname(&local_ip)
|
||||
.ok_or(TunnelError::InternalError(
|
||||
"Failed to get interface name".into(),
|
||||
))?;
|
||||
|
||||
let (local_ip, local_ip6) = match local_ip {
|
||||
IpAddr::V4(ip) => (Some(ip), None),
|
||||
IpAddr::V6(ip) => (None, Some(ip)),
|
||||
};
|
||||
|
||||
let tun = create_tun(&interface_name, Some(remote_addr), local_addr);
|
||||
let local_ip = local_ip.unwrap_or("0.0.0.0".parse().unwrap());
|
||||
let mut stack = stack::Stack::new(tun, local_ip, local_ip6, mac);
|
||||
let driver_type = stack.driver_type();
|
||||
|
||||
let socket = stack
|
||||
.alloc_established_socket(local_addr, remote_addr, stack::State::SynSent)
|
||||
.await;
|
||||
|
||||
let os_stream = os_socket.connect(remote_addr).await?;
|
||||
|
||||
tracing::info!(?remote_addr, "FakeTcpTunnelConnector connecting");
|
||||
|
||||
socket.recv_bytes().await.ok_or(TunnelError::InternalError(
|
||||
"Failed to recv bytes to establish connection".into(),
|
||||
))?;
|
||||
|
||||
tracing::info!(local_addr = ?socket.local_addr(), "FakeTcpTunnelConnector connected");
|
||||
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: get_faketcp_tunnel_type_str(driver_type),
|
||||
local_addr: Some(
|
||||
crate::tunnel::build_url_from_socket_addr(
|
||||
&socket.local_addr().to_string(),
|
||||
"faketcp",
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
remote_addr: Some(self.addr.clone().into()),
|
||||
};
|
||||
|
||||
let socket = Arc::new(socket);
|
||||
let reader = FakeTcpStream::new(socket.clone());
|
||||
let writer = FakeTcpSink::new(socket);
|
||||
|
||||
Ok(Box::new(TunnelWrapper::new_with_associate_data(
|
||||
reader,
|
||||
writer,
|
||||
Some(info),
|
||||
Some(Box::new((build_os_socket_reader_task(os_stream), stack))),
|
||||
)))
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
use crate::tunnel::packet_def::{ZCPacket, ZCPacketType};
|
||||
use crate::tunnel::{SinkError, SinkItem, StreamItem};
|
||||
use futures::{Sink, Stream};
|
||||
use std::task::{Context as TaskContext, Poll};
|
||||
|
||||
struct FakeTcpStream {
|
||||
socket: Arc<stack::Socket>,
|
||||
#[allow(clippy::type_complexity)]
|
||||
recv_fut: Option<Pin<Box<dyn Future<Output = Option<Vec<u8>>> + Send + Sync>>>,
|
||||
}
|
||||
|
||||
impl FakeTcpStream {
|
||||
fn new(socket: Arc<stack::Socket>) -> Self {
|
||||
Self {
|
||||
socket,
|
||||
recv_fut: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for FakeTcpStream {
|
||||
type Item = StreamItem;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
|
||||
let s = self.get_mut();
|
||||
if s.recv_fut.is_none() {
|
||||
let socket = s.socket.clone();
|
||||
s.recv_fut = Some(Box::pin(async move { socket.recv_bytes().await }));
|
||||
}
|
||||
|
||||
match s.recv_fut.as_mut().unwrap().as_mut().poll(cx) {
|
||||
Poll::Ready(Some(data)) => {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.extend_from_slice(&data);
|
||||
let packet = ZCPacket::new_from_buf(buf, ZCPacketType::DummyTunnel);
|
||||
|
||||
s.recv_fut = None;
|
||||
|
||||
Poll::Ready(Some(Ok(packet)))
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
// 连接关闭
|
||||
s.recv_fut = None;
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct FakeTcpSink {
|
||||
socket: Arc<stack::Socket>,
|
||||
}
|
||||
|
||||
impl FakeTcpSink {
|
||||
fn new(socket: Arc<stack::Socket>) -> Self {
|
||||
Self { socket }
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<SinkItem> for FakeTcpSink {
|
||||
type Error = SinkError;
|
||||
|
||||
fn poll_ready(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut TaskContext<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
// We need to send the packet as bytes
|
||||
// The item is ZCPacket, which has into_bytes() method
|
||||
let bytes = item.convert_type(ZCPacketType::DummyTunnel).into_bytes();
|
||||
|
||||
// Let's just spawn for now as a simple implementation, noting the limitation.
|
||||
self.socket.try_send(&bytes);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut TaskContext<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut TaskContext<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.socket.close();
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tunnel::common::tests::_tunnel_pingpong;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn faketcp_pingpong() {
|
||||
#[cfg(target_family = "unix")]
|
||||
{
|
||||
if unsafe { nix::libc::geteuid() } != 0 {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let listener = FakeTcpTunnelListener::new("faketcp://0.0.0.0:31011".parse().unwrap());
|
||||
let connector = FakeTcpTunnelConnector::new("faketcp://127.0.0.1:31011".parse().unwrap());
|
||||
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,708 @@
|
||||
use bytes::Bytes;
|
||||
use bytes::BytesMut;
|
||||
use nix::libc;
|
||||
use std::ffi::CString;
|
||||
use std::io;
|
||||
use std::mem;
|
||||
use std::net::IpAddr;
|
||||
use std::net::SocketAddr;
|
||||
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};
|
||||
use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::tunnel::fake_tcp::stack;
|
||||
|
||||
const ETH_HDR_LEN: usize = 14;
|
||||
const ETH_TYPE_OFFSET: u32 = 12;
|
||||
const ETHERTYPE_IPV4: u32 = 0x0800;
|
||||
const ETHERTYPE_IPV6: u32 = 0x86DD;
|
||||
const IPPROTO_TCP_U32: u32 = 6;
|
||||
|
||||
const BPF_LD: u16 = 0x00;
|
||||
const BPF_LDX: u16 = 0x01;
|
||||
const BPF_JMP: u16 = 0x05;
|
||||
const BPF_RET: u16 = 0x06;
|
||||
|
||||
const BPF_W: u16 = 0x00;
|
||||
const BPF_H: u16 = 0x08;
|
||||
const BPF_B: u16 = 0x10;
|
||||
|
||||
const BPF_ABS: u16 = 0x20;
|
||||
const BPF_IND: u16 = 0x40;
|
||||
const BPF_MSH: u16 = 0xa0;
|
||||
|
||||
const BPF_JA: u16 = 0x00;
|
||||
const BPF_JEQ: u16 = 0x10;
|
||||
|
||||
const BPF_K: u16 = 0x00;
|
||||
|
||||
fn stmt(code: u16, k: u32) -> libc::sock_filter {
|
||||
libc::sock_filter {
|
||||
code,
|
||||
jt: 0,
|
||||
jf: 0,
|
||||
k,
|
||||
}
|
||||
}
|
||||
|
||||
fn jeq(k: u32, jt: u8, jf: u8) -> libc::sock_filter {
|
||||
libc::sock_filter {
|
||||
code: BPF_JMP | BPF_JEQ | BPF_K,
|
||||
jt,
|
||||
jf,
|
||||
k,
|
||||
}
|
||||
}
|
||||
|
||||
fn ja(k: u32) -> libc::sock_filter {
|
||||
libc::sock_filter {
|
||||
code: BPF_JMP | BPF_JA,
|
||||
jt: 0,
|
||||
jf: 0,
|
||||
k,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct Label(usize);
|
||||
|
||||
struct JeqPatch {
|
||||
idx: usize,
|
||||
t: Label,
|
||||
f: Label,
|
||||
}
|
||||
|
||||
struct JaPatch {
|
||||
idx: usize,
|
||||
target: Label,
|
||||
}
|
||||
|
||||
struct BpfBuilder {
|
||||
insns: Vec<libc::sock_filter>,
|
||||
labels: Vec<Option<usize>>,
|
||||
jeq_patches: Vec<JeqPatch>,
|
||||
ja_patches: Vec<JaPatch>,
|
||||
}
|
||||
|
||||
impl BpfBuilder {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
insns: Vec::new(),
|
||||
labels: Vec::new(),
|
||||
jeq_patches: Vec::new(),
|
||||
ja_patches: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn new_label(&mut self) -> Label {
|
||||
let idx = self.labels.len();
|
||||
self.labels.push(None);
|
||||
Label(idx)
|
||||
}
|
||||
|
||||
fn set_label(&mut self, label: Label) {
|
||||
self.labels[label.0] = Some(self.insns.len());
|
||||
}
|
||||
|
||||
fn push(&mut self, insn: libc::sock_filter) {
|
||||
self.insns.push(insn);
|
||||
}
|
||||
|
||||
fn push_jeq(&mut self, k: u32, t: Label, f: Label) {
|
||||
let idx = self.insns.len();
|
||||
self.insns.push(jeq(k, 0, 0));
|
||||
self.jeq_patches.push(JeqPatch { idx, t, f });
|
||||
}
|
||||
|
||||
fn push_ja(&mut self, target: Label) {
|
||||
let idx = self.insns.len();
|
||||
self.insns.push(ja(0));
|
||||
self.ja_patches.push(JaPatch { idx, target });
|
||||
}
|
||||
|
||||
fn finish(mut self) -> io::Result<Vec<libc::sock_filter>> {
|
||||
for patch in self.jeq_patches {
|
||||
let JeqPatch { idx, t, f } = patch;
|
||||
let cur = idx + 1;
|
||||
let t_pos =
|
||||
self.labels.get(t.0).and_then(|v| *v).ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidInput, "unresolved label")
|
||||
})?;
|
||||
let f_pos =
|
||||
self.labels.get(f.0).and_then(|v| *v).ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidInput, "unresolved label")
|
||||
})?;
|
||||
|
||||
if t_pos < cur || f_pos < cur {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"backward bpf jump",
|
||||
));
|
||||
}
|
||||
|
||||
let jt: u8 = (t_pos - cur)
|
||||
.try_into()
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bpf jump too far"))?;
|
||||
let jf: u8 = (f_pos - cur)
|
||||
.try_into()
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bpf jump too far"))?;
|
||||
|
||||
self.insns[idx].jt = jt;
|
||||
self.insns[idx].jf = jf;
|
||||
}
|
||||
|
||||
for patch in self.ja_patches {
|
||||
let JaPatch { idx, target } = patch;
|
||||
let cur = idx + 1;
|
||||
let t_pos =
|
||||
self.labels.get(target.0).and_then(|v| *v).ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidInput, "unresolved label")
|
||||
})?;
|
||||
if t_pos < cur {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"backward bpf jump",
|
||||
));
|
||||
}
|
||||
self.insns[idx].k = (t_pos - cur) as u32;
|
||||
}
|
||||
|
||||
Ok(self.insns)
|
||||
}
|
||||
}
|
||||
|
||||
fn build_tcp_filter(
|
||||
src_addr: Option<SocketAddr>,
|
||||
dst_addr: SocketAddr,
|
||||
) -> io::Result<Vec<libc::sock_filter>> {
|
||||
if let Some(src) = src_addr {
|
||||
if src.is_ipv4() != dst_addr.is_ipv4() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"src/dst addr family mismatch",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let mut b = BpfBuilder::new();
|
||||
let l_check_ipv6 = b.new_label();
|
||||
let l_ipv4 = b.new_label();
|
||||
let l_ipv6 = b.new_label();
|
||||
let l_accept = b.new_label();
|
||||
let l_reject = b.new_label();
|
||||
|
||||
b.push(stmt(BPF_LD | BPF_H | BPF_ABS, ETH_TYPE_OFFSET));
|
||||
b.push_jeq(ETHERTYPE_IPV4, l_ipv4, l_check_ipv6);
|
||||
|
||||
b.set_label(l_check_ipv6);
|
||||
b.push_jeq(ETHERTYPE_IPV6, l_ipv6, l_reject);
|
||||
|
||||
if dst_addr.is_ipv4() {
|
||||
b.set_label(l_ipv4);
|
||||
let l_v4_proto_ok = b.new_label();
|
||||
b.push(stmt(BPF_LD | BPF_B | BPF_ABS, (ETH_HDR_LEN + 9) as u32));
|
||||
b.push_jeq(IPPROTO_TCP_U32, l_v4_proto_ok, l_reject);
|
||||
|
||||
b.set_label(l_v4_proto_ok);
|
||||
let dst_ip = match dst_addr.ip() {
|
||||
IpAddr::V4(ip) => u32::from(ip),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let l_v4_dstip_ok = b.new_label();
|
||||
b.push(stmt(BPF_LD | BPF_W | BPF_ABS, (ETH_HDR_LEN + 16) as u32));
|
||||
b.push_jeq(dst_ip, l_v4_dstip_ok, l_reject);
|
||||
|
||||
b.set_label(l_v4_dstip_ok);
|
||||
if let Some(src) = src_addr {
|
||||
let src_ip = match src.ip() {
|
||||
IpAddr::V4(ip) => u32::from(ip),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let l_v4_srcip_ok = b.new_label();
|
||||
b.push(stmt(BPF_LD | BPF_W | BPF_ABS, (ETH_HDR_LEN + 12) as u32));
|
||||
b.push_jeq(src_ip, l_v4_srcip_ok, l_reject);
|
||||
b.set_label(l_v4_srcip_ok);
|
||||
}
|
||||
|
||||
b.push(stmt(BPF_LDX | BPF_B | BPF_MSH, ETH_HDR_LEN as u32));
|
||||
|
||||
let l_v4_dstport_ok = b.new_label();
|
||||
b.push(stmt(BPF_LD | BPF_H | BPF_IND, (ETH_HDR_LEN + 2) as u32));
|
||||
b.push_jeq(dst_addr.port() as u32, l_v4_dstport_ok, l_reject);
|
||||
|
||||
b.set_label(l_v4_dstport_ok);
|
||||
if let Some(src) = src_addr {
|
||||
b.push(stmt(BPF_LD | BPF_H | BPF_IND, ETH_HDR_LEN as u32));
|
||||
b.push_jeq(src.port() as u32, l_accept, l_reject);
|
||||
} else {
|
||||
b.push_ja(l_accept);
|
||||
}
|
||||
} else {
|
||||
b.set_label(l_ipv6);
|
||||
let l_v6_proto_ok = b.new_label();
|
||||
b.push(stmt(BPF_LD | BPF_B | BPF_ABS, (ETH_HDR_LEN + 6) as u32));
|
||||
b.push_jeq(IPPROTO_TCP_U32, l_v6_proto_ok, l_reject);
|
||||
|
||||
b.set_label(l_v6_proto_ok);
|
||||
let dst_ip = match dst_addr.ip() {
|
||||
IpAddr::V6(ip) => ip.octets(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
for (i, chunk) in dst_ip.chunks_exact(4).enumerate() {
|
||||
let off = ETH_HDR_LEN + 24 + (i * 4);
|
||||
let v = u32::from_be_bytes(chunk.try_into().unwrap());
|
||||
let l_v6_dstip_word_ok = b.new_label();
|
||||
b.push(stmt(BPF_LD | BPF_W | BPF_ABS, off as u32));
|
||||
b.push_jeq(v, l_v6_dstip_word_ok, l_reject);
|
||||
b.set_label(l_v6_dstip_word_ok);
|
||||
}
|
||||
|
||||
if let Some(src) = src_addr {
|
||||
let src_ip = match src.ip() {
|
||||
IpAddr::V6(ip) => ip.octets(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
for (i, chunk) in src_ip.chunks_exact(4).enumerate() {
|
||||
let off = ETH_HDR_LEN + 8 + (i * 4);
|
||||
let v = u32::from_be_bytes(chunk.try_into().unwrap());
|
||||
let l_v6_srcip_word_ok = b.new_label();
|
||||
b.push(stmt(BPF_LD | BPF_W | BPF_ABS, off as u32));
|
||||
b.push_jeq(v, l_v6_srcip_word_ok, l_reject);
|
||||
b.set_label(l_v6_srcip_word_ok);
|
||||
}
|
||||
}
|
||||
|
||||
let l_v6_dstport_ok = b.new_label();
|
||||
b.push(stmt(
|
||||
BPF_LD | BPF_H | BPF_ABS,
|
||||
(ETH_HDR_LEN + 40 + 2) as u32,
|
||||
));
|
||||
b.push_jeq(dst_addr.port() as u32, l_v6_dstport_ok, l_reject);
|
||||
|
||||
b.set_label(l_v6_dstport_ok);
|
||||
if let Some(src) = src_addr {
|
||||
b.push(stmt(BPF_LD | BPF_H | BPF_ABS, (ETH_HDR_LEN + 40) as u32));
|
||||
b.push_jeq(src.port() as u32, l_accept, l_reject);
|
||||
} else {
|
||||
b.push_ja(l_accept);
|
||||
}
|
||||
}
|
||||
|
||||
b.set_label(l_accept);
|
||||
b.push(stmt(BPF_RET | BPF_K, 0xFFFF));
|
||||
|
||||
b.set_label(l_reject);
|
||||
if dst_addr.is_ipv4() {
|
||||
b.set_label(l_ipv6);
|
||||
} else {
|
||||
b.set_label(l_ipv4);
|
||||
}
|
||||
b.push(stmt(BPF_RET | BPF_K, 0));
|
||||
|
||||
b.finish()
|
||||
}
|
||||
|
||||
pub struct LinuxBpfTun {
|
||||
fd: OwnedFd,
|
||||
ifindex: i32,
|
||||
stop: Arc<AtomicBool>,
|
||||
worker: Option<std::thread::JoinHandle<()>>,
|
||||
recv_queue: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
|
||||
}
|
||||
|
||||
impl LinuxBpfTun {
|
||||
pub fn new(
|
||||
interface_name: &str,
|
||||
src_addr: Option<SocketAddr>,
|
||||
dst_addr: SocketAddr,
|
||||
) -> io::Result<Self> {
|
||||
let c_ifname = CString::new(interface_name)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid interface name"))?;
|
||||
let ifindex = unsafe { libc::if_nametoindex(c_ifname.as_ptr()) as i32 };
|
||||
if ifindex <= 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
"interface not found",
|
||||
));
|
||||
}
|
||||
|
||||
let proto: i32 = (libc::ETH_P_ALL as u16).to_be() as i32;
|
||||
let fd = unsafe { libc::socket(libc::AF_PACKET, libc::SOCK_RAW, proto) };
|
||||
if fd < 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
let fd = unsafe { OwnedFd::from_raw_fd(fd) };
|
||||
|
||||
let mut addr: libc::sockaddr_ll = unsafe { mem::zeroed() };
|
||||
addr.sll_family = libc::AF_PACKET as u16;
|
||||
addr.sll_protocol = (libc::ETH_P_ALL as u16).to_be();
|
||||
addr.sll_ifindex = ifindex;
|
||||
|
||||
let bind_ret = unsafe {
|
||||
libc::bind(
|
||||
fd.as_raw_fd(),
|
||||
&addr as *const _ as *const libc::sockaddr,
|
||||
mem::size_of::<libc::sockaddr_ll>() as u32,
|
||||
)
|
||||
};
|
||||
if bind_ret != 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
|
||||
let filter = build_tcp_filter(src_addr, dst_addr)?;
|
||||
let mut prog = libc::sock_fprog {
|
||||
len: filter
|
||||
.len()
|
||||
.try_into()
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bpf program too long"))?,
|
||||
filter: filter.as_ptr() as *mut libc::sock_filter,
|
||||
};
|
||||
let opt_ret = unsafe {
|
||||
libc::setsockopt(
|
||||
fd.as_raw_fd(),
|
||||
libc::SOL_SOCKET,
|
||||
libc::SO_ATTACH_FILTER,
|
||||
&mut prog as *mut _ as *mut libc::c_void,
|
||||
mem::size_of::<libc::sock_fprog>() as u32,
|
||||
)
|
||||
};
|
||||
if opt_ret != 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
|
||||
let timeout = libc::timeval {
|
||||
tv_sec: 0,
|
||||
tv_usec: 200_000,
|
||||
};
|
||||
let _ = unsafe {
|
||||
libc::setsockopt(
|
||||
fd.as_raw_fd(),
|
||||
libc::SOL_SOCKET,
|
||||
libc::SO_RCVTIMEO,
|
||||
&timeout as *const _ as *const libc::c_void,
|
||||
mem::size_of::<libc::timeval>() as u32,
|
||||
)
|
||||
};
|
||||
|
||||
let stop = Arc::new(AtomicBool::new(false));
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(1024);
|
||||
let stop_clone = stop.clone();
|
||||
let read_fd = fd.as_raw_fd();
|
||||
|
||||
let worker = std::thread::spawn(move || {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
while !stop_clone.load(AtomicOrdering::Relaxed) {
|
||||
let n = unsafe {
|
||||
libc::recv(read_fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len(), 0)
|
||||
};
|
||||
if n < 0 {
|
||||
let err = io::Error::last_os_error();
|
||||
if matches!(
|
||||
err.kind(),
|
||||
io::ErrorKind::Interrupted | io::ErrorKind::WouldBlock
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
if n == 0 {
|
||||
continue;
|
||||
}
|
||||
let data = buf[..(n as usize)].to_vec();
|
||||
if tx.blocking_send(data).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tracing::info!(
|
||||
interface_name,
|
||||
ifindex,
|
||||
"LinuxBpfTun created with filter {:?}",
|
||||
filter
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
fd,
|
||||
ifindex,
|
||||
stop,
|
||||
worker: Some(worker),
|
||||
recv_queue: Mutex::new(rx),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for LinuxBpfTun {
|
||||
fn drop(&mut self) {
|
||||
self.stop.store(true, AtomicOrdering::Relaxed);
|
||||
let _ = unsafe { libc::shutdown(self.fd.as_raw_fd(), libc::SHUT_RD) };
|
||||
if let Some(worker) = self.worker.take() {
|
||||
let _ = worker.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl stack::Tun for LinuxBpfTun {
|
||||
async fn recv(&self, packet: &mut BytesMut) -> Result<usize, std::io::Error> {
|
||||
let mut rx = self.recv_queue.lock().await;
|
||||
match rx.recv().await {
|
||||
Some(data) => {
|
||||
packet.extend_from_slice(&data);
|
||||
Ok(data.len())
|
||||
}
|
||||
None => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::BrokenPipe,
|
||||
"LinuxBpfTun channel closed",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error> {
|
||||
if packet.len() < 6 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"packet too short",
|
||||
));
|
||||
}
|
||||
|
||||
let mut addr: libc::sockaddr_ll = unsafe { mem::zeroed() };
|
||||
addr.sll_family = libc::AF_PACKET as u16;
|
||||
addr.sll_protocol = (libc::ETH_P_ALL as u16).to_be();
|
||||
addr.sll_ifindex = self.ifindex;
|
||||
addr.sll_halen = 6;
|
||||
addr.sll_addr[..6].copy_from_slice(&packet[..6]);
|
||||
|
||||
let ret = unsafe {
|
||||
libc::sendto(
|
||||
self.fd.as_raw_fd(),
|
||||
packet.as_ptr() as *const libc::c_void,
|
||||
packet.len(),
|
||||
0,
|
||||
&addr as *const _ as *const libc::sockaddr,
|
||||
mem::size_of::<libc::sockaddr_ll>() as u32,
|
||||
)
|
||||
};
|
||||
if ret < 0 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn driver_type(&self) -> &'static str {
|
||||
"linux_bpf"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, target_os = "linux"))]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::tunnel::fake_tcp::packet::build_tcp_packet;
|
||||
use crate::tunnel::fake_tcp::stack::Tun;
|
||||
use pnet::datalink;
|
||||
use pnet::packet::tcp::TcpFlags;
|
||||
use pnet::util::MacAddr;
|
||||
use rand::Rng;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
fn is_root() -> bool {
|
||||
unsafe { libc::geteuid() == 0 }
|
||||
}
|
||||
|
||||
fn pick_interface_v4() -> Option<(String, Ipv4Addr, MacAddr)> {
|
||||
let interfaces = datalink::interfaces();
|
||||
for iface in interfaces {
|
||||
let Some(mac) = iface.mac else {
|
||||
continue;
|
||||
};
|
||||
if iface.is_loopback() {
|
||||
continue;
|
||||
}
|
||||
let ipv4 = iface.ips.iter().find_map(|n| match n.ip() {
|
||||
IpAddr::V4(ip) => Some(ip),
|
||||
IpAddr::V6(_) => None,
|
||||
})?;
|
||||
return Some((iface.name, ipv4, mac));
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn send_raw_frame(interface_name: &str, frame: &[u8]) -> io::Result<()> {
|
||||
if frame.len() < 6 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"frame too short",
|
||||
));
|
||||
}
|
||||
|
||||
let c_ifname = CString::new(interface_name)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid interface name"))?;
|
||||
let ifindex = unsafe { libc::if_nametoindex(c_ifname.as_ptr()) as i32 };
|
||||
if ifindex <= 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
"interface not found",
|
||||
));
|
||||
}
|
||||
|
||||
let proto: i32 = (libc::ETH_P_ALL as u16).to_be() as i32;
|
||||
let fd = unsafe { libc::socket(libc::AF_PACKET, libc::SOCK_RAW, proto) };
|
||||
if fd < 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
let fd = unsafe { OwnedFd::from_raw_fd(fd) };
|
||||
|
||||
let mut addr: libc::sockaddr_ll = unsafe { mem::zeroed() };
|
||||
addr.sll_family = libc::AF_PACKET as u16;
|
||||
addr.sll_protocol = (libc::ETH_P_ALL as u16).to_be();
|
||||
addr.sll_ifindex = ifindex;
|
||||
addr.sll_halen = 6;
|
||||
addr.sll_addr[..6].copy_from_slice(&frame[..6]);
|
||||
|
||||
let ret = unsafe {
|
||||
libc::sendto(
|
||||
fd.as_raw_fd(),
|
||||
frame.as_ptr() as *const libc::c_void,
|
||||
frame.len(),
|
||||
0,
|
||||
&addr as *const _ as *const libc::sockaddr,
|
||||
mem::size_of::<libc::sockaddr_ll>() as u32,
|
||||
)
|
||||
};
|
||||
if ret < 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn linux_bpf_tun_receives_matching_ipv4_frame() {
|
||||
if !is_root() {
|
||||
eprintln!("linux_bpf_tun_receives_matching_ipv4_frame: skipped (not root)");
|
||||
return;
|
||||
}
|
||||
|
||||
let Some((ifname, dst_ip, mac)) = pick_interface_v4() else {
|
||||
eprintln!("linux_bpf_tun_receives_matching_ipv4_frame: skipped (no suitable iface)");
|
||||
return;
|
||||
};
|
||||
|
||||
let dst_port: u16 = rand::thread_rng().gen_range(40000..60000);
|
||||
let dst_addr = SocketAddr::new(IpAddr::V4(dst_ip), dst_port);
|
||||
eprintln!(
|
||||
"linux_bpf_tun_receives_matching_ipv4_frame: ifname={ifname} dst_addr={dst_addr} mac={mac}"
|
||||
);
|
||||
|
||||
let tun = LinuxBpfTun::new(&ifname, None, dst_addr).unwrap();
|
||||
|
||||
let src_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 123, 0, 1)), 12345);
|
||||
eprintln!(
|
||||
"linux_bpf_tun_receives_matching_ipv4_frame: sending frame src_addr={src_addr} -> dst_addr={dst_addr}"
|
||||
);
|
||||
let frame = build_tcp_packet(
|
||||
mac,
|
||||
mac,
|
||||
src_addr,
|
||||
dst_addr,
|
||||
1,
|
||||
0,
|
||||
TcpFlags::SYN,
|
||||
Some(b"ping"),
|
||||
);
|
||||
|
||||
send_raw_frame(&ifname, &frame).unwrap();
|
||||
|
||||
let mut received = BytesMut::new();
|
||||
let n = timeout(Duration::from_secs(2), tun.recv(&mut received))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
eprintln!(
|
||||
"linux_bpf_tun_receives_matching_ipv4_frame: received {} bytes",
|
||||
n
|
||||
);
|
||||
assert_eq!(n, frame.len());
|
||||
assert_eq!(&received[..], &frame[..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn linux_bpf_tun_filters_out_non_matching_ipv4_frame() {
|
||||
if !is_root() {
|
||||
eprintln!("linux_bpf_tun_filters_out_non_matching_ipv4_frame: skipped (not root)");
|
||||
return;
|
||||
}
|
||||
|
||||
let Some((ifname, dst_ip, mac)) = pick_interface_v4() else {
|
||||
eprintln!(
|
||||
"linux_bpf_tun_filters_out_non_matching_ipv4_frame: skipped (no suitable iface)"
|
||||
);
|
||||
return;
|
||||
};
|
||||
|
||||
let dst_port: u16 = rand::thread_rng().gen_range(40000..60000);
|
||||
let dst_addr = SocketAddr::new(IpAddr::V4(dst_ip), dst_port);
|
||||
eprintln!(
|
||||
"linux_bpf_tun_filters_out_non_matching_ipv4_frame: ifname={ifname} dst_addr={dst_addr} mac={mac}"
|
||||
);
|
||||
|
||||
let tun = LinuxBpfTun::new(&ifname, None, dst_addr).unwrap();
|
||||
|
||||
let src_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 123, 0, 2)), 23456);
|
||||
let non_matching_dst = SocketAddr::new(IpAddr::V4(dst_ip), dst_port.wrapping_add(1));
|
||||
eprintln!(
|
||||
"linux_bpf_tun_filters_out_non_matching_ipv4_frame: sending non-matching src_addr={src_addr} -> dst_addr={non_matching_dst}"
|
||||
);
|
||||
let non_matching = build_tcp_packet(
|
||||
mac,
|
||||
mac,
|
||||
src_addr,
|
||||
non_matching_dst,
|
||||
1,
|
||||
0,
|
||||
TcpFlags::SYN,
|
||||
Some(b"nope"),
|
||||
);
|
||||
send_raw_frame(&ifname, &non_matching).unwrap();
|
||||
|
||||
let mut received = BytesMut::new();
|
||||
let non_matching_timeout = timeout(Duration::from_millis(400), tun.recv(&mut received))
|
||||
.await
|
||||
.is_err();
|
||||
eprintln!(
|
||||
"linux_bpf_tun_filters_out_non_matching_ipv4_frame: non-matching recv timeout={}",
|
||||
non_matching_timeout
|
||||
);
|
||||
assert!(non_matching_timeout);
|
||||
|
||||
eprintln!(
|
||||
"linux_bpf_tun_filters_out_non_matching_ipv4_frame: sending matching src_addr={src_addr} -> dst_addr={dst_addr}"
|
||||
);
|
||||
let matching = build_tcp_packet(
|
||||
mac,
|
||||
mac,
|
||||
src_addr,
|
||||
dst_addr,
|
||||
2,
|
||||
0,
|
||||
TcpFlags::SYN,
|
||||
Some(b"ok"),
|
||||
);
|
||||
send_raw_frame(&ifname, &matching).unwrap();
|
||||
|
||||
let mut received2 = BytesMut::new();
|
||||
let n = timeout(Duration::from_secs(2), tun.recv(&mut received2))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
eprintln!(
|
||||
"linux_bpf_tun_filters_out_non_matching_ipv4_frame: received {} bytes",
|
||||
n
|
||||
);
|
||||
assert_eq!(n, matching.len());
|
||||
assert_eq!(&received2[..], &matching[..]);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,87 @@
|
||||
pub mod pnet;
|
||||
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(target_os = "linux")] {
|
||||
pub mod linux_bpf;
|
||||
|
||||
pub fn create_tun(
|
||||
interface_name: &str,
|
||||
src_addr: Option<SocketAddr>,
|
||||
dst_addr: SocketAddr,
|
||||
) -> Arc<dyn super::stack::Tun> {
|
||||
match linux_bpf::LinuxBpfTun::new(interface_name, src_addr, dst_addr) {
|
||||
Ok(tun) => Arc::new(tun),
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
?e,
|
||||
interface_name,
|
||||
"LinuxBpfTun init failed, falling back to PnetTun"
|
||||
);
|
||||
Arc::new(pnet::PnetTun::new(
|
||||
interface_name,
|
||||
pnet::create_packet_filter(src_addr, dst_addr),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if #[cfg(target_os = "macos")] {
|
||||
pub mod macos_bpf;
|
||||
|
||||
pub fn create_tun(
|
||||
interface_name: &str,
|
||||
src_addr: Option<SocketAddr>,
|
||||
dst_addr: SocketAddr,
|
||||
) -> Arc<dyn super::stack::Tun> {
|
||||
match macos_bpf::MacosBpfTun::new(interface_name, src_addr, dst_addr) {
|
||||
Ok(tun) => Arc::new(tun),
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
?e,
|
||||
interface_name,
|
||||
"MacosBpfTun init failed, falling back to PnetTun"
|
||||
);
|
||||
Arc::new(pnet::PnetTun::new(
|
||||
interface_name,
|
||||
pnet::create_packet_filter(src_addr, dst_addr),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if #[cfg(all(windows, any(target_arch = "x86_64", target_arch = "x86")))] {
|
||||
pub mod windivert;
|
||||
|
||||
pub fn create_tun(
|
||||
_interface_name: &str,
|
||||
_src_addr: Option<SocketAddr>,
|
||||
local_addr: SocketAddr,
|
||||
) -> Arc<dyn super::stack::Tun> {
|
||||
match windivert::WinDivertTun::new(local_addr) {
|
||||
Ok(tun) => Arc::new(tun),
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
?e,
|
||||
?local_addr,
|
||||
"WinDivertTun init failed, falling back to PnetTun"
|
||||
);
|
||||
Arc::new(pnet::PnetTun::new(
|
||||
local_addr.to_string().as_str(),
|
||||
pnet::create_packet_filter(None, local_addr),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
pub fn create_tun(
|
||||
interface_name: &str,
|
||||
src_addr: Option<SocketAddr>,
|
||||
dst_addr: SocketAddr,
|
||||
) -> Arc<dyn super::stack::Tun> {
|
||||
Arc::new(pnet::PnetTun::new(
|
||||
interface_name,
|
||||
pnet::create_packet_filter(src_addr, dst_addr),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,304 @@
|
||||
use std::{
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
Arc, Weak,
|
||||
},
|
||||
};
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use dashmap::DashMap;
|
||||
use once_cell::sync::Lazy;
|
||||
use pnet::{
|
||||
datalink::{self, DataLinkSender, NetworkInterface},
|
||||
packet::{ethernet::EtherTypes, ip::IpNextHeaderProtocols, ipv6::Ipv6Packet},
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::tunnel::fake_tcp::stack;
|
||||
|
||||
type PacketFilter = Box<dyn Fn(&[u8]) -> bool + Send + Sync>;
|
||||
|
||||
fn filter_tcp_packet(
|
||||
packet: &[u8],
|
||||
src_addr: Option<&SocketAddr>,
|
||||
dst_addr: Option<&SocketAddr>,
|
||||
) -> bool {
|
||||
use pnet::packet::ethernet::EthernetPacket;
|
||||
use pnet::packet::ipv4::Ipv4Packet;
|
||||
use pnet::packet::tcp::TcpPacket;
|
||||
use pnet::packet::Packet;
|
||||
|
||||
let ethernet = if let Some(ethernet) = EthernetPacket::new(packet) {
|
||||
ethernet
|
||||
} else {
|
||||
return false;
|
||||
};
|
||||
|
||||
match ethernet.get_ethertype() {
|
||||
EtherTypes::Ipv4 => {
|
||||
let ipv4 = if let Some(ipv4) = Ipv4Packet::new(ethernet.payload()) {
|
||||
ipv4
|
||||
} else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp {
|
||||
return false;
|
||||
}
|
||||
|
||||
let tcp = if let Some(tcp) = TcpPacket::new(ipv4.payload()) {
|
||||
tcp
|
||||
} else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if let Some(src_addr) = src_addr {
|
||||
if IpAddr::V4(ipv4.get_source()) != src_addr.ip() {
|
||||
return false;
|
||||
}
|
||||
if tcp.get_source() != src_addr.port() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(dst_addr) = dst_addr {
|
||||
if IpAddr::V4(ipv4.get_destination()) != dst_addr.ip() {
|
||||
return false;
|
||||
}
|
||||
if tcp.get_destination() != dst_addr.port() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::trace!(
|
||||
?tcp,
|
||||
"FakeTcpTunnelListener packet matched filter, dispatching, src_addr: {:?}, dst_addr: {:?}, packet_src_ip: {:?}, packet_dst_ip: {:?}, packet_src_port: {:?}, packet_dst_port: {:?}",
|
||||
src_addr,
|
||||
dst_addr,
|
||||
ipv4.get_source(),
|
||||
ipv4.get_destination(),
|
||||
tcp.get_source(),
|
||||
tcp.get_destination(),
|
||||
);
|
||||
}
|
||||
EtherTypes::Ipv6 => {
|
||||
let ipv6 = if let Some(ipv6) = Ipv6Packet::new(ethernet.payload()) {
|
||||
ipv6
|
||||
} else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if ipv6.get_next_header() != IpNextHeaderProtocols::Tcp {
|
||||
return false;
|
||||
}
|
||||
|
||||
let tcp = if let Some(tcp) = TcpPacket::new(ipv6.payload()) {
|
||||
tcp
|
||||
} else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if let Some(src_addr) = src_addr {
|
||||
if IpAddr::V6(ipv6.get_source()) != src_addr.ip() {
|
||||
return false;
|
||||
}
|
||||
if tcp.get_source() != src_addr.port() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(dst_addr) = dst_addr {
|
||||
if IpAddr::V6(ipv6.get_destination()) != dst_addr.ip() {
|
||||
return false;
|
||||
}
|
||||
if tcp.get_destination() != dst_addr.port() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::trace!(
|
||||
?tcp,
|
||||
"FakeTcpTunnelListener packet matched filter, dispatching"
|
||||
);
|
||||
}
|
||||
_ => return false,
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
pub fn create_packet_filter(src_addr: Option<SocketAddr>, dst_addr: SocketAddr) -> PacketFilter {
|
||||
Box::new(move |packet: &[u8]| -> bool {
|
||||
filter_tcp_packet(packet, src_addr.as_ref(), Some(&dst_addr))
|
||||
})
|
||||
}
|
||||
|
||||
struct Subscriber {
|
||||
filter: PacketFilter,
|
||||
sender: tokio::sync::mpsc::Sender<Vec<u8>>,
|
||||
}
|
||||
|
||||
struct InterfaceWorker {
|
||||
tx: Mutex<Box<dyn DataLinkSender>>,
|
||||
subscribers: Arc<DashMap<u32, Subscriber>>,
|
||||
}
|
||||
|
||||
impl InterfaceWorker {
|
||||
fn new(interface: NetworkInterface) -> Arc<Self> {
|
||||
let (tx, mut rx) = match datalink::channel(&interface, Default::default()) {
|
||||
Ok(pnet::datalink::Channel::Ethernet(tx, rx)) => (tx, rx),
|
||||
Ok(_) => panic!("Unhandled channel type"),
|
||||
Err(e) => panic!(
|
||||
"An error occurred when creating the datalink channel: {}",
|
||||
e
|
||||
),
|
||||
};
|
||||
|
||||
let subscribers = Arc::new(DashMap::<u32, Subscriber>::new());
|
||||
let subscribers_clone = subscribers.clone();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
loop {
|
||||
match rx.next() {
|
||||
Ok(packet) => {
|
||||
// Iterate over subscribers and send packet if filter matches
|
||||
// Note: DashMap iteration might be slow if many subscribers, but usually few per interface.
|
||||
// For high performance we might need a better structure or read-copy-update.
|
||||
for r in subscribers_clone.iter() {
|
||||
let subscriber = r.value();
|
||||
if (subscriber.filter)(packet) {
|
||||
tracing::trace!(
|
||||
?packet,
|
||||
"InterfaceWorker packet matched filter, dispatching"
|
||||
);
|
||||
// Try send, ignore errors (best effort)
|
||||
let _ = subscriber.sender.try_send(packet.to_vec());
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("InterfaceWorker read error: {}", e);
|
||||
// If interface goes down, we might need to handle it.
|
||||
// For now just break and maybe the worker is dead.
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Arc::new(Self {
|
||||
tx: Mutex::new(tx),
|
||||
subscribers,
|
||||
})
|
||||
}
|
||||
|
||||
fn subscribe(&self, filter: PacketFilter, sender: tokio::sync::mpsc::Sender<Vec<u8>>) -> u32 {
|
||||
static ID_GEN: AtomicU32 = AtomicU32::new(0);
|
||||
let id = ID_GEN.fetch_add(1, Ordering::Relaxed);
|
||||
self.subscribers.insert(id, Subscriber { filter, sender });
|
||||
id
|
||||
}
|
||||
|
||||
fn unsubscribe(&self, id: u32) {
|
||||
self.subscribers.remove(&id);
|
||||
}
|
||||
}
|
||||
|
||||
static INTERFACE_MANAGERS: Lazy<DashMap<String, Weak<InterfaceWorker>>> = Lazy::new(DashMap::new);
|
||||
|
||||
fn get_or_create_worker(interface_name: &str) -> Arc<InterfaceWorker> {
|
||||
// Check if we have an active worker
|
||||
if let Some(worker) = INTERFACE_MANAGERS
|
||||
.get(interface_name)
|
||||
.and_then(|w| w.upgrade())
|
||||
{
|
||||
return worker;
|
||||
}
|
||||
|
||||
// Need to create new worker.
|
||||
// Lock effectively by using entry API? DashMap entry API might not be enough for complex init.
|
||||
// Let's use a double-check locking style or just accept race condition (creating two workers and one wins).
|
||||
// DashMap doesn't support easy "compute_if_absent" with async or heavy logic without blocking the map shard.
|
||||
|
||||
// But creation is rare.
|
||||
// Let's find interface first.
|
||||
let interfaces = datalink::interfaces();
|
||||
let interface = interfaces
|
||||
.into_iter()
|
||||
.find(|iface| iface.name == interface_name)
|
||||
.expect("Network interface not found");
|
||||
|
||||
let worker = InterfaceWorker::new(interface);
|
||||
INTERFACE_MANAGERS.insert(interface_name.to_string(), Arc::downgrade(&worker));
|
||||
worker
|
||||
}
|
||||
|
||||
pub struct PnetTun {
|
||||
worker: Arc<InterfaceWorker>,
|
||||
subscription_id: u32,
|
||||
recv_queue: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
|
||||
}
|
||||
|
||||
impl PnetTun {
|
||||
pub fn new(interface_name: &str, filter: PacketFilter) -> Self {
|
||||
tracing::debug!(interface_name, "Creating new PnetTun");
|
||||
let worker = get_or_create_worker(interface_name);
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(1024);
|
||||
let id = worker.subscribe(filter, tx);
|
||||
|
||||
Self {
|
||||
worker,
|
||||
subscription_id: id,
|
||||
recv_queue: Mutex::new(rx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PnetTun {
|
||||
fn drop(&mut self) {
|
||||
tracing::debug!(subscription_id = self.subscription_id, "Dropping PnetTun");
|
||||
self.worker.unsubscribe(self.subscription_id);
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl stack::Tun for PnetTun {
|
||||
async fn recv(&self, packet: &mut BytesMut) -> Result<usize, std::io::Error> {
|
||||
let mut rx = self.recv_queue.lock().await;
|
||||
match rx.recv().await {
|
||||
Some(data) => {
|
||||
tracing::trace!(?data, "PnetTun received packet");
|
||||
packet.extend_from_slice(&data);
|
||||
Ok(data.len())
|
||||
}
|
||||
None => {
|
||||
tracing::warn!("PnetTun recv channel closed");
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::BrokenPipe,
|
||||
"PnetTun channel closed",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error> {
|
||||
tracing::trace!(len = packet.len(), "PnetTun try_sending packet");
|
||||
// We need async lock for tx.
|
||||
// try_send is sync. We can use try_lock if available or blocking lock.
|
||||
// tokio::sync::Mutex::try_lock is available.
|
||||
if let Ok(mut tx) = self.worker.tx.try_lock() {
|
||||
tx.send_to(packet, None)
|
||||
.ok_or(std::io::Error::other("send_to failed"))?
|
||||
} else {
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::WouldBlock,
|
||||
"PnetTun tx lock busy",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn driver_type(&self) -> &'static str {
|
||||
"pnet"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
use std::cell::UnsafeCell;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use tokio::sync::Mutex;
|
||||
use windivert::error::WinDivertError;
|
||||
use windivert::packet::WinDivertPacket;
|
||||
use windivert::prelude::{WinDivertFlags, WinDivertShutdownMode};
|
||||
use windivert::{layer, WinDivert};
|
||||
|
||||
use crate::tunnel::fake_tcp::stack;
|
||||
|
||||
struct WinDivertReader {
|
||||
inner: UnsafeCell<WinDivert<layer::NetworkLayer>>,
|
||||
}
|
||||
|
||||
unsafe impl Send for WinDivertReader {}
|
||||
unsafe impl Sync for WinDivertReader {}
|
||||
|
||||
impl WinDivertReader {
|
||||
fn new(inner: WinDivert<layer::NetworkLayer>) -> Self {
|
||||
Self {
|
||||
inner: UnsafeCell::new(inner),
|
||||
}
|
||||
}
|
||||
|
||||
fn recv<'a>(
|
||||
&self,
|
||||
buffer: Option<&'a mut [u8]>,
|
||||
) -> Result<WinDivertPacket<'a, layer::NetworkLayer>, WinDivertError> {
|
||||
let inner = unsafe { &*self.inner.get() };
|
||||
inner.recv(buffer)
|
||||
}
|
||||
|
||||
fn shutdown(&self) -> anyhow::Result<()> {
|
||||
let inner = unsafe { &mut *self.inner.get() };
|
||||
inner
|
||||
.shutdown(WinDivertShutdownMode::Recv)
|
||||
.with_context(|| "WinDivertReader shutdown failed")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn close(&self) -> anyhow::Result<()> {
|
||||
let inner = unsafe { &mut *self.inner.get() };
|
||||
inner
|
||||
.close(windivert::CloseAction::Nothing)
|
||||
.with_context(|| "WinDivertReader close failed")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WinDivertReader {
|
||||
fn drop(&mut self) {
|
||||
if let Err(e) = self.close() {
|
||||
tracing::error!("WinDivertReader close failed: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WinDivertTun {
|
||||
recv_queue: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
|
||||
sender: Arc<std::sync::Mutex<WinDivert<layer::NetworkLayer>>>,
|
||||
reader: Arc<WinDivertReader>,
|
||||
}
|
||||
|
||||
impl Drop for WinDivertTun {
|
||||
fn drop(&mut self) {
|
||||
if let Ok(mut sender) = self.sender.lock() {
|
||||
if let Err(e) = sender.close(windivert::CloseAction::Nothing) {
|
||||
tracing::error!("WinDivertSender close failed: {:?}", e);
|
||||
}
|
||||
}
|
||||
if let Err(e) = self.reader.shutdown() {
|
||||
tracing::error!("WinDivertReader shutdown failed: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WinDivertTun {
|
||||
pub fn new(local_addr: SocketAddr) -> io::Result<Self> {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(1024);
|
||||
|
||||
let ip_filter = match local_addr {
|
||||
SocketAddr::V4(addr) => format!("ip.DstAddr == {}", addr.ip()),
|
||||
SocketAddr::V6(addr) => format!("ipv6.DstAddr == {}", addr.ip()),
|
||||
};
|
||||
// Filter: DstIP == LocalIP AND TCP.
|
||||
let filter = format!("{} and tcp", ip_filter);
|
||||
|
||||
// Sniff mode: 1 (WINDIVERT_FLAG_SNIFF)
|
||||
// Layer: Network (0)
|
||||
// Priority: 0
|
||||
let flags = WinDivertFlags::default().set_sniff();
|
||||
let reader = WinDivert::network(&filter, 0, flags)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
let reader = Arc::new(WinDivertReader::new(reader));
|
||||
let reader_clone = reader.clone();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
let reader = reader_clone;
|
||||
let mut buffer = vec![0u8; 65536];
|
||||
loop {
|
||||
match reader.recv(Some(&mut buffer)) {
|
||||
Ok(packet) => {
|
||||
let data = &packet.data;
|
||||
|
||||
let mut eth_data = vec![0u8; 14 + data.len()];
|
||||
// Set EtherType
|
||||
if data.len() > 0 && data[0] >> 4 == 4 {
|
||||
eth_data[12] = 0x08;
|
||||
eth_data[13] = 0x00;
|
||||
} else {
|
||||
eth_data[12] = 0x86;
|
||||
eth_data[13] = 0xDD;
|
||||
}
|
||||
eth_data[14..].copy_from_slice(data);
|
||||
|
||||
if let Err(_) = tx.blocking_send(eth_data) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// log error?
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Sender: non-sniff, empty filter?
|
||||
// Use "false" to avoid capturing anything.
|
||||
// Flags: 0
|
||||
let sender = WinDivert::network("false", 0, WinDivertFlags::default())
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
|
||||
Ok(Self {
|
||||
recv_queue: Mutex::new(rx),
|
||||
sender: Arc::new(std::sync::Mutex::new(sender)),
|
||||
reader,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl stack::Tun for WinDivertTun {
|
||||
async fn recv(&self, packet: &mut BytesMut) -> Result<usize, std::io::Error> {
|
||||
let mut rx = self.recv_queue.lock().await;
|
||||
match rx.recv().await {
|
||||
Some(data) => {
|
||||
packet.extend_from_slice(&data);
|
||||
Ok(data.len())
|
||||
}
|
||||
None => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::BrokenPipe,
|
||||
"Channel closed",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error> {
|
||||
// Strip ethernet header
|
||||
if packet.len() < 14 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"Packet too short",
|
||||
));
|
||||
}
|
||||
let ip_data = &packet[14..];
|
||||
|
||||
let Ok(sender) = self.sender.try_lock() else {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"WinDivert sender lock failed",
|
||||
));
|
||||
};
|
||||
|
||||
let mut pkt = unsafe { WinDivertPacket::<layer::NetworkLayer>::new(ip_data.to_vec()) };
|
||||
pkt.address.set_outbound(true);
|
||||
|
||||
sender.send(&pkt).map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
format!("WinDivert send failed: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn driver_type(&self) -> &'static str {
|
||||
"windivert"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use pnet::packet::ethernet::{EtherTypes, EthernetPacket, MutableEthernetPacket};
|
||||
use pnet::packet::{ip, ipv4, ipv6, tcp};
|
||||
use pnet::util::MacAddr;
|
||||
use std::convert::TryInto;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
|
||||
const IPV4_HEADER_LEN: usize = 20;
|
||||
const IPV6_HEADER_LEN: usize = 40;
|
||||
const TCP_HEADER_LEN: usize = 20;
|
||||
pub const MAX_PACKET_LEN: usize = 1500;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum IPPacket<'p> {
|
||||
V4(ipv4::Ipv4Packet<'p>),
|
||||
V6(ipv6::Ipv6Packet<'p>),
|
||||
}
|
||||
|
||||
impl IPPacket<'_> {
|
||||
pub fn get_source(&self) -> IpAddr {
|
||||
match self {
|
||||
IPPacket::V4(p) => IpAddr::V4(p.get_source()),
|
||||
IPPacket::V6(p) => IpAddr::V6(p.get_source()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_destination(&self) -> IpAddr {
|
||||
match self {
|
||||
IPPacket::V4(p) => IpAddr::V4(p.get_destination()),
|
||||
IPPacket::V6(p) => IpAddr::V6(p.get_destination()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const ETH_HDR_LEN: usize = 14;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn build_tcp_packet(
|
||||
src_mac: MacAddr,
|
||||
dst_mac: MacAddr,
|
||||
local_addr: SocketAddr,
|
||||
remote_addr: SocketAddr,
|
||||
seq: u32,
|
||||
ack: u32,
|
||||
flags: u8,
|
||||
payload: Option<&[u8]>,
|
||||
) -> Bytes {
|
||||
let ip_header_len = match local_addr {
|
||||
SocketAddr::V4(_) => IPV4_HEADER_LEN,
|
||||
SocketAddr::V6(_) => IPV6_HEADER_LEN,
|
||||
};
|
||||
let wscale = (flags & tcp::TcpFlags::SYN) != 0;
|
||||
let tcp_header_len = TCP_HEADER_LEN + if wscale { 4 } else { 0 }; // nop + wscale
|
||||
let tcp_total_len = tcp_header_len + payload.map_or(0, |payload| payload.len());
|
||||
let total_len = ip_header_len + tcp_total_len;
|
||||
let mut buf = BytesMut::zeroed(ETH_HDR_LEN + total_len);
|
||||
|
||||
let mut eth_buf = buf.split_to(ETH_HDR_LEN);
|
||||
let mut ip_buf = buf.split_to(ip_header_len);
|
||||
let mut tcp_buf = buf.split_to(tcp_total_len);
|
||||
assert_eq!(0, buf.len());
|
||||
|
||||
let mut tcp = tcp::MutableTcpPacket::new(&mut tcp_buf).unwrap();
|
||||
tcp.set_window(0xffff);
|
||||
tcp.set_source(local_addr.port());
|
||||
tcp.set_destination(remote_addr.port());
|
||||
tcp.set_sequence(seq);
|
||||
tcp.set_acknowledgement(ack);
|
||||
tcp.set_flags(flags);
|
||||
tcp.set_data_offset(TCP_HEADER_LEN as u8 / 4 + if wscale { 1 } else { 0 });
|
||||
if wscale {
|
||||
let wscale = tcp::TcpOption::wscale(14);
|
||||
tcp.set_options(&[tcp::TcpOption::nop(), wscale]);
|
||||
}
|
||||
|
||||
if let Some(payload) = payload {
|
||||
tcp.set_payload(payload);
|
||||
}
|
||||
|
||||
let mut ethernet = MutableEthernetPacket::new(&mut eth_buf).unwrap();
|
||||
ethernet.set_destination(dst_mac);
|
||||
ethernet.set_source(src_mac);
|
||||
ethernet.set_ethertype(EtherTypes::Ipv4);
|
||||
|
||||
match (local_addr, remote_addr) {
|
||||
(SocketAddr::V4(local), SocketAddr::V4(remote)) => {
|
||||
let mut v4 = ipv4::MutableIpv4Packet::new(&mut ip_buf).unwrap();
|
||||
v4.set_version(4);
|
||||
v4.set_header_length(IPV4_HEADER_LEN as u8 / 4);
|
||||
v4.set_next_level_protocol(ip::IpNextHeaderProtocols::Tcp);
|
||||
v4.set_ttl(64);
|
||||
v4.set_source(*local.ip());
|
||||
v4.set_destination(*remote.ip());
|
||||
v4.set_total_length(total_len.try_into().unwrap());
|
||||
v4.set_flags(ipv4::Ipv4Flags::DontFragment);
|
||||
|
||||
tcp.set_checksum(tcp::ipv4_checksum(
|
||||
&tcp.to_immutable(),
|
||||
&v4.get_source(),
|
||||
&v4.get_destination(),
|
||||
));
|
||||
|
||||
v4.set_checksum(ipv4::checksum(&v4.to_immutable()));
|
||||
}
|
||||
(SocketAddr::V6(local), SocketAddr::V6(remote)) => {
|
||||
let mut v6 = ipv6::MutableIpv6Packet::new(&mut ip_buf).unwrap();
|
||||
v6.set_version(6);
|
||||
v6.set_payload_length(tcp_total_len.try_into().unwrap());
|
||||
v6.set_next_header(ip::IpNextHeaderProtocols::Tcp);
|
||||
v6.set_hop_limit(64);
|
||||
v6.set_source(*local.ip());
|
||||
v6.set_destination(*remote.ip());
|
||||
|
||||
tcp.set_checksum(tcp::ipv6_checksum(
|
||||
&tcp.to_immutable(),
|
||||
&v6.get_source(),
|
||||
&v6.get_destination(),
|
||||
));
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
ip_buf.unsplit(tcp_buf);
|
||||
eth_buf.unsplit(ip_buf);
|
||||
eth_buf.freeze()
|
||||
}
|
||||
|
||||
#[tracing::instrument(ret)]
|
||||
pub fn parse_ip_packet(
|
||||
buf: &Bytes,
|
||||
) -> Option<(MacAddr, MacAddr, IPPacket<'_>, tcp::TcpPacket<'_>)> {
|
||||
let eth = EthernetPacket::new(buf).unwrap();
|
||||
let src_mac = eth.get_source();
|
||||
let dst_mac = eth.get_destination();
|
||||
|
||||
tracing::trace!("Parsing IP packet: {:?}", eth);
|
||||
|
||||
let buf = &buf[ETH_HDR_LEN..];
|
||||
if buf[0] >> 4 == 4 {
|
||||
let v4 = ipv4::Ipv4Packet::new(buf).unwrap();
|
||||
if v4.get_next_level_protocol() != ip::IpNextHeaderProtocols::Tcp {
|
||||
return None;
|
||||
}
|
||||
|
||||
let tcp = tcp::TcpPacket::new(&buf[IPV4_HEADER_LEN..]).unwrap();
|
||||
Some((src_mac, dst_mac, IPPacket::V4(v4), tcp))
|
||||
} else if buf[0] >> 4 == 6 {
|
||||
let v6 = ipv6::Ipv6Packet::new(buf).unwrap();
|
||||
if v6.get_next_header() != ip::IpNextHeaderProtocols::Tcp {
|
||||
return None;
|
||||
}
|
||||
|
||||
let tcp = tcp::TcpPacket::new(&buf[IPV6_HEADER_LEN..]).unwrap();
|
||||
Some((src_mac, dst_mac, IPPacket::V6(v6), tcp))
|
||||
} else {
|
||||
tracing::trace!("Invalid IP version: {}", buf[0] >> 4);
|
||||
None
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,561 @@
|
||||
//! A minimum, userspace TCP based datagram stack
|
||||
//!
|
||||
//! # Overview
|
||||
//!
|
||||
//! `fake-tcp` is a reusable library that implements a minimum TCP stack in
|
||||
//! user space using the Tun interface. It allows programs to send datagrams
|
||||
//! as if they are part of a TCP connection. `fake-tcp` has been tested to
|
||||
//! be able to pass through a variety of NAT and stateful firewalls while
|
||||
//! fully preserves certain desirable behavior such as out of order delivery
|
||||
//! and no congestion/flow controls.
|
||||
//!
|
||||
//! # Core Concepts
|
||||
//!
|
||||
//! The core of the `fake-tcp` crate compose of two structures. [`Stack`] and
|
||||
//! [`Socket`].
|
||||
//!
|
||||
//! ## [`Stack`]
|
||||
//!
|
||||
//! [`Stack`] represents a virtual TCP stack that operates at
|
||||
//! Layer 3. It is responsible for:
|
||||
//!
|
||||
//! * TCP active and passive open and handshake
|
||||
//! * `RST` handling
|
||||
//! * Interact with the Tun interface at Layer 3
|
||||
//! * Distribute incoming datagrams to corresponding [`Socket`]
|
||||
//!
|
||||
//! ## [`Socket`]
|
||||
//!
|
||||
//! [`Socket`] represents a TCP connection. It registers the identifying
|
||||
//! tuple `(src_ip, src_port, dest_ip, dest_port)` inside the [`Stack`] so
|
||||
//! so that incoming packets can be distributed to the right [`Socket`] with
|
||||
//! using a channel. It is also what the client should use for
|
||||
//! sending/receiving datagrams.
|
||||
//!
|
||||
//! # Examples
|
||||
//!
|
||||
//! Please see [`client.rs`](https://github.com/dndx/phantun/blob/main/phantun/src/bin/client.rs)
|
||||
//! and [`server.rs`](https://github.com/dndx/phantun/blob/main/phantun/src/bin/server.rs) files
|
||||
//! from the `phantun` crate for how to use this library in client/server mode, respectively.
|
||||
|
||||
use crate::common::scoped_task::ScopedTask;
|
||||
|
||||
use super::packet::*;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use pnet::packet::tcp::TcpOptionNumbers;
|
||||
use pnet::packet::{tcp, Packet};
|
||||
use pnet::util::MacAddr;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fmt;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::sync::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
Arc, RwLock,
|
||||
};
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time;
|
||||
use tracing::{info, trace, warn};
|
||||
|
||||
const TIMEOUT: time::Duration = time::Duration::from_secs(1);
|
||||
const RETRIES: usize = 6;
|
||||
const MPMC_BUFFER_LEN: usize = 512;
|
||||
const MPSC_BUFFER_LEN: usize = 128;
|
||||
const MAX_UNACKED_LEN: u32 = 128 * 1024 * 1024; // 128MB
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Tun: Send + Sync + 'static {
|
||||
async fn recv(&self, packet: &mut BytesMut) -> Result<usize, std::io::Error>;
|
||||
fn try_send(&self, packet: &Bytes) -> Result<(), std::io::Error>;
|
||||
fn driver_type(&self) -> &'static str;
|
||||
}
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
|
||||
struct AddrTuple {
|
||||
local_addr: SocketAddr,
|
||||
remote_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl AddrTuple {
|
||||
fn new(local_addr: SocketAddr, remote_addr: SocketAddr) -> AddrTuple {
|
||||
AddrTuple {
|
||||
local_addr,
|
||||
remote_addr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Shared {
|
||||
tuples: RwLock<HashMap<AddrTuple, flume::Sender<Bytes>>>,
|
||||
listening: RwLock<HashSet<u16>>,
|
||||
tun: Arc<dyn Tun>,
|
||||
ready: mpsc::Sender<Socket>,
|
||||
tuples_purge: broadcast::Sender<AddrTuple>,
|
||||
}
|
||||
|
||||
pub struct Stack {
|
||||
shared: Arc<Shared>,
|
||||
local_ip: Ipv4Addr,
|
||||
local_ip6: Option<Ipv6Addr>,
|
||||
local_mac: MacAddr,
|
||||
ready: mpsc::Receiver<Socket>,
|
||||
reader_task: ScopedTask<()>,
|
||||
}
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug)]
|
||||
pub enum State {
|
||||
Idle,
|
||||
SynSent,
|
||||
SynReceived,
|
||||
Established,
|
||||
}
|
||||
|
||||
pub struct Socket {
|
||||
shared: Arc<Shared>,
|
||||
tun: Arc<dyn Tun>,
|
||||
incoming: flume::Receiver<Bytes>,
|
||||
local_addr: SocketAddr,
|
||||
remote_addr: SocketAddr,
|
||||
local_mac: MacAddr,
|
||||
remote_mac: AtomicCell<Option<MacAddr>>,
|
||||
seq: AtomicU32,
|
||||
ack: AtomicU32,
|
||||
last_ack: AtomicU32,
|
||||
state: AtomicCell<State>,
|
||||
}
|
||||
|
||||
/// A socket that represents a unique TCP connection between a server and client.
|
||||
///
|
||||
/// The `Socket` object itself satisfies `Sync` and `Send`, which means it can
|
||||
/// be safely called within an async future.
|
||||
///
|
||||
/// To close a TCP connection that is no longer needed, simply drop this object
|
||||
/// out of scope.
|
||||
impl Socket {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
shared: Arc<Shared>,
|
||||
tun: Arc<dyn Tun>,
|
||||
local_addr: SocketAddr,
|
||||
remote_addr: SocketAddr,
|
||||
local_mac: MacAddr,
|
||||
remote_mac: Option<MacAddr>,
|
||||
ack: Option<u32>,
|
||||
state: State,
|
||||
) -> (Socket, flume::Sender<Bytes>) {
|
||||
let (incoming_tx, incoming_rx) = flume::bounded(MPMC_BUFFER_LEN);
|
||||
|
||||
(
|
||||
Socket {
|
||||
shared,
|
||||
tun,
|
||||
incoming: incoming_rx,
|
||||
local_addr,
|
||||
remote_addr,
|
||||
local_mac,
|
||||
remote_mac: AtomicCell::new(remote_mac),
|
||||
seq: AtomicU32::new(0),
|
||||
ack: AtomicU32::new(ack.unwrap_or(0)),
|
||||
last_ack: AtomicU32::new(ack.unwrap_or(0)),
|
||||
state: AtomicCell::new(state),
|
||||
},
|
||||
incoming_tx,
|
||||
)
|
||||
}
|
||||
|
||||
fn build_tcp_packet(&self, flags: u8, payload: Option<&[u8]>) -> Bytes {
|
||||
let ack = self.ack.load(Ordering::Relaxed);
|
||||
self.last_ack.store(ack, Ordering::Relaxed);
|
||||
|
||||
build_tcp_packet(
|
||||
self.local_mac,
|
||||
self.remote_mac.load().unwrap_or(MacAddr::zero()),
|
||||
self.local_addr,
|
||||
self.remote_addr,
|
||||
self.seq.load(Ordering::Relaxed),
|
||||
ack,
|
||||
flags,
|
||||
payload,
|
||||
)
|
||||
}
|
||||
|
||||
/// Sends a datagram to the other end.
|
||||
///
|
||||
/// This method takes `&self`, and it can be called safely by multiple threads
|
||||
/// at the same time.
|
||||
///
|
||||
/// A return of `None` means the Tun socket returned an error
|
||||
/// and this socket must be closed.
|
||||
pub fn try_send(&self, payload: &[u8]) -> Option<()> {
|
||||
match self.state.load() {
|
||||
State::Established => {
|
||||
let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload));
|
||||
self.seq.fetch_add(payload.len() as u32, Ordering::Relaxed);
|
||||
self.tun.try_send(&buf).ok().and(Some(()))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
if self.state.load() != State::Idle {
|
||||
let buf = self.build_tcp_packet(tcp::TcpFlags::RST, None);
|
||||
let _ = self.tun.try_send(&buf);
|
||||
self.state.store(State::Idle);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn recv_bytes(&self) -> Option<Vec<u8>> {
|
||||
let mut buf = [0u8; 2048];
|
||||
self.recv(&mut buf).await.map(|size| buf[..size].to_vec())
|
||||
}
|
||||
|
||||
/// Attempt to receive a datagram from the other end.
|
||||
///
|
||||
/// This method takes `&self`, and it can be called safely by multiple threads
|
||||
/// at the same time.
|
||||
///
|
||||
/// A return of `None` means the TCP connection is broken
|
||||
/// and this socket must be closed.
|
||||
pub async fn recv(&self, buf: &mut [u8]) -> Option<usize> {
|
||||
tracing::trace!(
|
||||
"Socket recv called, local_addr: {:?}, remote_addr: {:?}",
|
||||
self.local_addr,
|
||||
self.remote_addr
|
||||
);
|
||||
loop {
|
||||
match self.state.load() {
|
||||
State::Established => {
|
||||
let Ok(raw_buf) = self.incoming.recv_async().await else {
|
||||
info!("Connection {} recv error", self);
|
||||
return None;
|
||||
};
|
||||
|
||||
let (src_mac, dst_mac, _v4_packet, tcp_packet) =
|
||||
parse_ip_packet(&raw_buf).unwrap();
|
||||
|
||||
tracing::trace!(
|
||||
"Socket received TCP packet from {}({:?}) to {}({:?}): {:?}",
|
||||
self.remote_addr,
|
||||
src_mac,
|
||||
self.local_addr,
|
||||
dst_mac,
|
||||
tcp_packet
|
||||
);
|
||||
|
||||
self.remote_mac.store(Some(src_mac));
|
||||
|
||||
if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
|
||||
info!("Connection {} reset by peer", self);
|
||||
return None;
|
||||
}
|
||||
|
||||
if (tcp_packet.get_flags() & tcp::TcpFlags::ACK) != 0
|
||||
&& tcp_packet.payload().is_empty()
|
||||
{
|
||||
self.seq
|
||||
.store(tcp_packet.get_acknowledgement(), Ordering::Relaxed);
|
||||
}
|
||||
|
||||
let payload = tcp_packet.payload();
|
||||
|
||||
let new_ack = tcp_packet.get_sequence().wrapping_add(payload.len() as u32);
|
||||
self.ack.store(new_ack, Ordering::Relaxed);
|
||||
|
||||
for opt in tcp_packet.get_options_iter() {
|
||||
if opt.get_number() == TcpOptionNumbers::SACK {
|
||||
// SACK 选项类型为 5
|
||||
let payload = opt.payload();
|
||||
for chunk in payload.chunks(8) {
|
||||
if chunk.len() != 8 {
|
||||
continue;
|
||||
}
|
||||
let left = tcp_packet.get_acknowledgement();
|
||||
let right = u32::from_be_bytes(chunk[0..4].try_into().unwrap());
|
||||
let len = right.wrapping_sub(left);
|
||||
|
||||
let sack_end = u32::from_be_bytes(chunk[4..8].try_into().unwrap());
|
||||
if len == 0 || sack_end <= left {
|
||||
continue;
|
||||
}
|
||||
|
||||
let send_len = std::cmp::min(len, 1400) as usize;
|
||||
let data = vec![0u8; send_len];
|
||||
|
||||
let buf = build_tcp_packet(
|
||||
self.local_mac,
|
||||
self.remote_mac.load().unwrap_or(MacAddr::zero()),
|
||||
self.local_addr,
|
||||
self.remote_addr,
|
||||
left,
|
||||
self.ack.load(Ordering::Relaxed),
|
||||
tcp::TcpFlags::ACK,
|
||||
Some(&data),
|
||||
);
|
||||
|
||||
if let Err(e) = self.tun.try_send(&buf) {
|
||||
tracing::error!("Failed to send SACK response: {}", e);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if payload.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if payload.len() >= buf.len() {
|
||||
tracing::warn!(
|
||||
"Payload len {} > buf len {}, tcp: {:?}, payload: {:?}",
|
||||
payload.len(),
|
||||
buf.len(),
|
||||
tcp_packet,
|
||||
payload
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
buf[..payload.len()].copy_from_slice(payload);
|
||||
|
||||
return Some(payload.len());
|
||||
}
|
||||
State::SynSent => {
|
||||
let Ok(Ok(buf)) = time::timeout(TIMEOUT, self.incoming.recv_async()).await
|
||||
else {
|
||||
info!("Waiting for client SYN + ACK timed out");
|
||||
return None;
|
||||
};
|
||||
let (src_mac, _dst_mac, _v4_packet, tcp_packet) =
|
||||
parse_ip_packet(&buf).unwrap();
|
||||
|
||||
if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
|
||||
tracing::trace!("Connection {} reset by peer", self);
|
||||
return None;
|
||||
}
|
||||
|
||||
let expected_flag = tcp::TcpFlags::SYN | tcp::TcpFlags::ACK;
|
||||
if (tcp_packet.get_flags() & expected_flag) == expected_flag {
|
||||
// found our SYN + ACK
|
||||
self.seq
|
||||
.store(tcp_packet.get_acknowledgement(), Ordering::Relaxed);
|
||||
self.ack
|
||||
.store(tcp_packet.get_sequence() + 1, Ordering::Relaxed);
|
||||
self.remote_mac.store(Some(src_mac));
|
||||
self.state.store(State::Established);
|
||||
return Some(0);
|
||||
}
|
||||
}
|
||||
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn local_addr(&self) -> SocketAddr {
|
||||
self.local_addr
|
||||
}
|
||||
|
||||
pub fn remote_addr(&self) -> SocketAddr {
|
||||
self.remote_addr
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Socket {
|
||||
/// Drop the socket and close the TCP connection
|
||||
fn drop(&mut self) {
|
||||
let tuple = AddrTuple::new(self.local_addr, self.remote_addr);
|
||||
// dissociates ourself from the dispatch map
|
||||
assert!(self.shared.tuples.write().unwrap().remove(&tuple).is_some());
|
||||
// purge cache
|
||||
let _ = self.shared.tuples_purge.send(tuple);
|
||||
|
||||
let buf = build_tcp_packet(
|
||||
self.local_mac,
|
||||
self.remote_mac.load().unwrap_or(MacAddr::zero()),
|
||||
self.local_addr,
|
||||
self.remote_addr,
|
||||
self.seq.load(Ordering::Relaxed),
|
||||
0,
|
||||
tcp::TcpFlags::RST,
|
||||
None,
|
||||
);
|
||||
if let Err(e) = self.tun.try_send(&buf) {
|
||||
warn!("Unable to send RST to remote end: {}", e);
|
||||
}
|
||||
|
||||
info!("Fake TCP connection to {} closed", self);
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Socket {
|
||||
/// User-friendly string representation of the socket
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"(Fake TCP connection from {} to {})",
|
||||
self.local_addr, self.remote_addr
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// A userspace TCP state machine
|
||||
impl Stack {
|
||||
/// Create a new stack, `tun` is an array of [`Tun`](tokio_tun::Tun).
|
||||
/// When more than one [`Tun`](tokio_tun::Tun) object is passed in, same amount
|
||||
/// of reader will be spawned later. This allows user to utilize the performance
|
||||
/// benefit of Multiqueue Tun support on machines with SMP.
|
||||
pub fn new(
|
||||
tun: Arc<dyn Tun>,
|
||||
local_ip: Ipv4Addr,
|
||||
local_ip6: Option<Ipv6Addr>,
|
||||
local_mac: Option<MacAddr>,
|
||||
) -> Stack {
|
||||
let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN);
|
||||
let (tuples_purge_tx, _tuples_purge_rx) = broadcast::channel(16);
|
||||
let shared = Arc::new(Shared {
|
||||
tuples: RwLock::new(HashMap::new()),
|
||||
tun: tun.clone(),
|
||||
listening: RwLock::new(HashSet::new()),
|
||||
ready: ready_tx,
|
||||
tuples_purge: tuples_purge_tx.clone(),
|
||||
});
|
||||
|
||||
let t = tokio::spawn(Stack::reader_task(
|
||||
tun,
|
||||
shared.clone(),
|
||||
tuples_purge_tx.subscribe(),
|
||||
));
|
||||
|
||||
Stack {
|
||||
shared,
|
||||
local_ip,
|
||||
local_ip6,
|
||||
local_mac: local_mac.unwrap_or(MacAddr::zero()),
|
||||
ready: ready_rx,
|
||||
reader_task: t.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the driver type of the stack.
|
||||
pub fn driver_type(&self) -> &'static str {
|
||||
self.shared.tun.driver_type()
|
||||
}
|
||||
|
||||
/// Listens for incoming connections on the given `port`.
|
||||
pub fn listen(&mut self, port: u16) {
|
||||
assert!(self.shared.listening.write().unwrap().insert(port));
|
||||
}
|
||||
|
||||
/// Accepts an incoming connection.
|
||||
pub async fn accept(&mut self) -> Socket {
|
||||
self.ready.recv().await.unwrap()
|
||||
}
|
||||
|
||||
pub async fn alloc_established_socket(
|
||||
&mut self,
|
||||
local_addr: SocketAddr,
|
||||
remote_addr: SocketAddr,
|
||||
state: State,
|
||||
) -> Socket {
|
||||
let tuple = AddrTuple::new(local_addr, remote_addr);
|
||||
let mut tuples = self.shared.tuples.write().unwrap();
|
||||
let (sock, incoming) = Socket::new(
|
||||
self.shared.clone(),
|
||||
// self.shared.tun.choose(&mut rng).unwrap().clone(),
|
||||
self.shared.tun.clone(), // Simplification: just use the first tun
|
||||
local_addr,
|
||||
remote_addr,
|
||||
self.local_mac,
|
||||
None,
|
||||
Some(0), // Initial ACK
|
||||
state,
|
||||
);
|
||||
assert!(tuples.insert(tuple, incoming).is_none());
|
||||
sock
|
||||
}
|
||||
|
||||
async fn reader_task(
|
||||
tun: Arc<dyn Tun>,
|
||||
shared: Arc<Shared>,
|
||||
mut tuples_purge: broadcast::Receiver<AddrTuple>,
|
||||
) {
|
||||
let mut tuples: HashMap<AddrTuple, flume::Sender<Bytes>> = HashMap::new();
|
||||
|
||||
loop {
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
tokio::select! {
|
||||
size = tun.recv(&mut buf) => {
|
||||
let size = size.unwrap();
|
||||
tracing::trace!(len = size, ?buf, "PnetTun received packet");
|
||||
let buf = buf.split().freeze();
|
||||
|
||||
match parse_ip_packet(&buf) {
|
||||
Some((_src_mac, _dst_mac, ip_packet, tcp_packet)) => {
|
||||
let local_addr = SocketAddr::new(
|
||||
ip_packet.get_destination(),
|
||||
tcp_packet.get_destination(),
|
||||
);
|
||||
let remote_addr = SocketAddr::new(
|
||||
ip_packet.get_source(),
|
||||
tcp_packet.get_source(),
|
||||
);
|
||||
|
||||
let tuple = AddrTuple::new(local_addr, remote_addr);
|
||||
if let Some(c) = tuples.get(&tuple) {
|
||||
if c.send_async(buf).await.is_err() {
|
||||
trace!("Cache hit, but receiver already closed, dropping packet");
|
||||
}
|
||||
|
||||
continue;
|
||||
|
||||
// If not Ok, receiver has been closed and just fall through to the slow
|
||||
// path below
|
||||
} else {
|
||||
trace!("Cache miss, checking the shared tuples table for connection");
|
||||
let sender = {
|
||||
let tuples = shared.tuples.read().unwrap();
|
||||
tuples.get(&tuple).cloned()
|
||||
};
|
||||
|
||||
if let Some(c) = sender {
|
||||
trace!("Storing connection information into local tuples");
|
||||
tuples.insert(tuple, c.clone());
|
||||
if let Err(e) = c.send_async(buf).await {
|
||||
trace!("Error sending packet to connection: {:?}", e);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if tcp_packet.get_flags() == tcp::TcpFlags::SYN
|
||||
&& shared
|
||||
.listening
|
||||
.read()
|
||||
.unwrap()
|
||||
.contains(&tcp_packet.get_destination())
|
||||
{
|
||||
trace!(?tcp_packet, "Received SYN packet for port {}, ignoring", tcp_packet.get_destination());
|
||||
continue;
|
||||
} else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 {
|
||||
info!("Unknown RST TCP packet from {}, ignoring", remote_addr);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
trace!("Dropping packet with no IP/TCP header");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
},
|
||||
tuple = tuples_purge.recv() => {
|
||||
let tuple = tuple.unwrap();
|
||||
tuples.remove(&tuple);
|
||||
trace!("Removed cached tuple: {:?}", tuple);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ use self::packet_def::ZCPacket;
|
||||
|
||||
pub mod buf;
|
||||
pub mod common;
|
||||
pub mod fake_tcp;
|
||||
pub mod filter;
|
||||
pub mod mpsc;
|
||||
pub mod packet_def;
|
||||
@@ -23,8 +24,14 @@ pub mod stats;
|
||||
pub mod tcp;
|
||||
pub mod udp;
|
||||
|
||||
pub const PROTO_PORT_OFFSET: &[(&str, u16)] =
|
||||
&[("tcp", 0), ("udp", 0), ("wg", 1), ("ws", 1), ("wss", 2)];
|
||||
pub const PROTO_PORT_OFFSET: &[(&str, u16)] = &[
|
||||
("tcp", 0),
|
||||
("udp", 0),
|
||||
("wg", 1),
|
||||
("ws", 1),
|
||||
("wss", 2),
|
||||
("faketcp", 3),
|
||||
];
|
||||
|
||||
#[cfg(feature = "wireguard")]
|
||||
pub mod wireguard;
|
||||
@@ -139,7 +146,9 @@ pub trait TunnelConnector: Send {
|
||||
|
||||
pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url {
|
||||
if let Ok(sock_addr) = addr.parse::<SocketAddr>() {
|
||||
let mut ret_url = url::Url::parse(format!("{}://0.0.0.0", scheme).as_str()).unwrap();
|
||||
let url_str = format!("{}://0.0.0.0", scheme);
|
||||
let mut ret_url = url::Url::parse(url_str.as_str())
|
||||
.unwrap_or_else(|_| panic!("invalid url: {}", url_str));
|
||||
ret_url.set_ip_host(sock_addr.ip()).unwrap();
|
||||
ret_url.set_port(Some(sock_addr.port())).unwrap();
|
||||
ret_url
|
||||
@@ -200,6 +209,7 @@ fn default_port(scheme: &str) -> Option<u16> {
|
||||
"udp" => Some(11010),
|
||||
"ws" => Some(11011),
|
||||
"wss" => Some(11012),
|
||||
"faketcp" => Some(11013),
|
||||
"quic" => Some(11012),
|
||||
"wg" => Some(11011),
|
||||
_ => None,
|
||||
|
||||
Reference in New Issue
Block a user