mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 18:24:36 +00:00
zero copy tunnel (#55)
make tunnel zero copy, for better performance. remove most of the locks in io path. introduce quic tunnel prepare for encryption
This commit is contained in:
@@ -0,0 +1,92 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::io::IoSlice;
|
||||
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
|
||||
pub(crate) struct BufList<T> {
|
||||
bufs: VecDeque<T>,
|
||||
}
|
||||
|
||||
impl<T: Buf> BufList<T> {
|
||||
pub(crate) fn new() -> BufList<T> {
|
||||
BufList {
|
||||
bufs: VecDeque::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn push(&mut self, buf: T) {
|
||||
debug_assert!(buf.has_remaining());
|
||||
self.bufs.push_back(buf);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn bufs_cnt(&self) -> usize {
|
||||
self.bufs.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Buf> Buf for BufList<T> {
|
||||
#[inline]
|
||||
fn remaining(&self) -> usize {
|
||||
self.bufs.iter().map(|buf| buf.remaining()).sum()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn chunk(&self) -> &[u8] {
|
||||
self.bufs.front().map(Buf::chunk).unwrap_or_default()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn advance(&mut self, mut cnt: usize) {
|
||||
while cnt > 0 {
|
||||
{
|
||||
let front = &mut self.bufs[0];
|
||||
let rem = front.remaining();
|
||||
if rem > cnt {
|
||||
front.advance(cnt);
|
||||
return;
|
||||
} else {
|
||||
front.advance(rem);
|
||||
cnt -= rem;
|
||||
}
|
||||
}
|
||||
self.bufs.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
|
||||
if dst.is_empty() {
|
||||
return 0;
|
||||
}
|
||||
let mut vecs = 0;
|
||||
for buf in &self.bufs {
|
||||
vecs += buf.chunks_vectored(&mut dst[vecs..]);
|
||||
if vecs == dst.len() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
vecs
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn copy_to_bytes(&mut self, len: usize) -> Bytes {
|
||||
// Our inner buffer may have an optimized version of copy_to_bytes, and if the whole
|
||||
// request can be fulfilled by the front buffer, we can take advantage.
|
||||
match self.bufs.front_mut() {
|
||||
Some(front) if front.remaining() == len => {
|
||||
let b = front.copy_to_bytes(len);
|
||||
self.bufs.pop_front();
|
||||
b
|
||||
}
|
||||
Some(front) if front.remaining() > len => front.copy_to_bytes(len),
|
||||
_ => {
|
||||
assert!(len <= self.remaining(), "`len` greater than remaining");
|
||||
let mut bm = BytesMut::with_capacity(len);
|
||||
bm.put(self.take(len));
|
||||
bm.freeze()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,539 @@
|
||||
use std::{
|
||||
any::Any,
|
||||
net::{IpAddr, SocketAddr},
|
||||
pin::Pin,
|
||||
sync::{Arc, Mutex},
|
||||
task::{ready, Poll},
|
||||
};
|
||||
|
||||
use futures::{stream::FuturesUnordered, Future, Sink, Stream};
|
||||
use network_interface::NetworkInterfaceConfig as _;
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_util::io::{poll_read_buf, poll_write_buf};
|
||||
use zerocopy::FromBytes as _;
|
||||
|
||||
use crate::{
|
||||
rpc::TunnelInfo,
|
||||
tunnel::packet_def::{ZCPacket, PEER_MANAGER_HEADER_SIZE},
|
||||
};
|
||||
|
||||
use super::{
|
||||
buf::BufList,
|
||||
packet_def::{TCPTunnelHeader, ZCPacketType, TCP_TUNNEL_HEADER_SIZE},
|
||||
SinkItem, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
|
||||
};
|
||||
|
||||
pub struct TunnelWrapper<R, W> {
|
||||
reader: Arc<Mutex<Option<R>>>,
|
||||
writer: Arc<Mutex<Option<W>>>,
|
||||
info: Option<TunnelInfo>,
|
||||
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||
}
|
||||
|
||||
impl<R, W> TunnelWrapper<R, W> {
|
||||
pub fn new(reader: R, writer: W, info: Option<TunnelInfo>) -> Self {
|
||||
Self::new_with_associate_data(reader, writer, info, None)
|
||||
}
|
||||
|
||||
pub fn new_with_associate_data(
|
||||
reader: R,
|
||||
writer: W,
|
||||
info: Option<TunnelInfo>,
|
||||
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||
) -> Self {
|
||||
TunnelWrapper {
|
||||
reader: Arc::new(Mutex::new(Some(reader))),
|
||||
writer: Arc::new(Mutex::new(Some(writer))),
|
||||
info,
|
||||
associate_data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, W> Tunnel for TunnelWrapper<R, W>
|
||||
where
|
||||
R: ZCPacketStream + Send + 'static,
|
||||
W: ZCPacketSink + Send + 'static,
|
||||
{
|
||||
fn split(&self) -> (Pin<Box<dyn ZCPacketStream>>, Pin<Box<dyn ZCPacketSink>>) {
|
||||
let reader = self.reader.lock().unwrap().take().unwrap();
|
||||
let writer = self.writer.lock().unwrap().take().unwrap();
|
||||
(Box::pin(reader), Box::pin(writer))
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
self.info.clone()
|
||||
}
|
||||
}
|
||||
|
||||
// a length delimited codec for async reader
|
||||
pin_project! {
|
||||
pub struct FramedReader<R> {
|
||||
#[pin]
|
||||
reader: R,
|
||||
buf: BytesMut,
|
||||
state: FrameReaderState,
|
||||
max_packet_size: usize,
|
||||
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||
}
|
||||
}
|
||||
|
||||
// usize means the size remaining to read
|
||||
enum FrameReaderState {
|
||||
ReadingHeader(usize),
|
||||
ReadingBody(usize),
|
||||
}
|
||||
|
||||
impl<R> FramedReader<R> {
|
||||
pub fn new(reader: R, max_packet_size: usize) -> Self {
|
||||
Self::new_with_associate_data(reader, max_packet_size, None)
|
||||
}
|
||||
|
||||
pub fn new_with_associate_data(
|
||||
reader: R,
|
||||
max_packet_size: usize,
|
||||
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||
) -> Self {
|
||||
FramedReader {
|
||||
reader,
|
||||
buf: BytesMut::with_capacity(max_packet_size),
|
||||
state: FrameReaderState::ReadingHeader(4),
|
||||
max_packet_size,
|
||||
associate_data,
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_one_packet(buf: &mut BytesMut) -> Option<ZCPacket> {
|
||||
if buf.len() < TCP_TUNNEL_HEADER_SIZE {
|
||||
// header is not complete
|
||||
return None;
|
||||
}
|
||||
|
||||
let header = TCPTunnelHeader::ref_from_prefix(&buf[..]).unwrap();
|
||||
let body_len = header.len.get() as usize;
|
||||
if buf.len() < TCP_TUNNEL_HEADER_SIZE + body_len {
|
||||
// body is not complete
|
||||
return None;
|
||||
}
|
||||
|
||||
// extract one packet
|
||||
let packet_buf = buf.split_to(TCP_TUNNEL_HEADER_SIZE + body_len);
|
||||
Some(ZCPacket::new_from_buf(packet_buf, ZCPacketType::TCP))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> Stream for FramedReader<R>
|
||||
where
|
||||
R: AsyncRead + Send + 'static + Unpin,
|
||||
{
|
||||
type Item = StreamItem;
|
||||
|
||||
fn poll_next(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Self::Item>> {
|
||||
let mut self_mut = self.project();
|
||||
|
||||
loop {
|
||||
while let Some(packet) = Self::extract_one_packet(self_mut.buf) {
|
||||
return Poll::Ready(Some(Ok(packet)));
|
||||
}
|
||||
|
||||
reserve_buf(
|
||||
&mut self_mut.buf,
|
||||
*self_mut.max_packet_size,
|
||||
*self_mut.max_packet_size * 64,
|
||||
);
|
||||
|
||||
match ready!(poll_read_buf(
|
||||
self_mut.reader.as_mut(),
|
||||
cx,
|
||||
&mut self_mut.buf
|
||||
)) {
|
||||
Ok(size) => {
|
||||
if size == 0 {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Poll::Ready(Some(Err(TunnelError::IOError(e))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct FramedWriter<W> {
|
||||
#[pin]
|
||||
writer: W,
|
||||
sending_bufs: BufList<Bytes>,
|
||||
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> FramedWriter<W> {
|
||||
pub fn new(writer: W) -> Self {
|
||||
Self::new_with_associate_data(writer, None)
|
||||
}
|
||||
|
||||
pub fn new_with_associate_data(
|
||||
writer: W,
|
||||
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||
) -> Self {
|
||||
FramedWriter {
|
||||
writer,
|
||||
sending_bufs: BufList::new(),
|
||||
associate_data: associate_data,
|
||||
}
|
||||
}
|
||||
|
||||
fn max_buffer_count(&self) -> usize {
|
||||
64
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> Sink<SinkItem> for FramedWriter<W>
|
||||
where
|
||||
W: AsyncWrite + Send + 'static,
|
||||
{
|
||||
type Error = TunnelError;
|
||||
|
||||
fn poll_ready(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
let max_buffer_count = self.max_buffer_count();
|
||||
if self.sending_bufs.bufs_cnt() >= max_buffer_count {
|
||||
self.as_mut().poll_flush(cx)
|
||||
} else {
|
||||
tracing::trace!(bufs_cnt = self.sending_bufs.bufs_cnt(), "ready to send");
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, mut item: ZCPacket) -> Result<(), Self::Error> {
|
||||
let tcp_len = PEER_MANAGER_HEADER_SIZE + item.payload_len();
|
||||
let Some(header) = item.mut_tcp_tunnel_header() else {
|
||||
return Err(TunnelError::InvalidPacket("packet too short".to_string()));
|
||||
};
|
||||
header.len.set(tcp_len.try_into().unwrap());
|
||||
|
||||
let item = item.into_bytes(ZCPacketType::TCP);
|
||||
self.project().sending_bufs.push(item);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
let mut pinned = self.project();
|
||||
let mut remaining = pinned.sending_bufs.remaining();
|
||||
while remaining != 0 {
|
||||
let n = ready!(poll_write_buf(
|
||||
pinned.writer.as_mut(),
|
||||
cx,
|
||||
pinned.sending_bufs
|
||||
))?;
|
||||
if n == 0 {
|
||||
return Poll::Ready(Err(TunnelError::IOError(std::io::Error::new(
|
||||
std::io::ErrorKind::WriteZero,
|
||||
"failed to \
|
||||
write frame to transport",
|
||||
))));
|
||||
}
|
||||
remaining -= n;
|
||||
}
|
||||
|
||||
tracing::trace!(?remaining, "flushed");
|
||||
|
||||
// Try flushing the underlying IO
|
||||
ready!(pinned.writer.poll_flush(cx))?;
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
ready!(self.as_mut().poll_flush(cx))?;
|
||||
ready!(self.project().writer.poll_shutdown(cx))?;
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
|
||||
if local_ip.is_unspecified() || local_ip.is_multicast() {
|
||||
return None;
|
||||
}
|
||||
let ifaces = network_interface::NetworkInterface::show().ok()?;
|
||||
for iface in ifaces {
|
||||
for addr in iface.addr {
|
||||
if addr.ip() == *local_ip {
|
||||
return Some(iface.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::error!(?local_ip, "can not find interface name by ip");
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn setup_sokcet2_ext(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
bind_dev: Option<String>,
|
||||
) -> Result<(), TunnelError> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM);
|
||||
crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, bind_dev, is_udp)?;
|
||||
}
|
||||
|
||||
socket2_socket.set_nonblocking(true)?;
|
||||
socket2_socket.set_reuse_address(true)?;
|
||||
socket2_socket.bind(&socket2::SockAddr::from(*bind_addr))?;
|
||||
|
||||
// #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||
// socket2_socket.set_reuse_port(true)?;
|
||||
|
||||
if bind_addr.ip().is_unspecified() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// linux/mac does not use interface of bind_addr to send packet, so we need to bind device
|
||||
// win can handle this with bind correctly
|
||||
#[cfg(any(target_os = "ios", target_os = "macos"))]
|
||||
if let Some(dev_name) = bind_dev {
|
||||
// use IP_BOUND_IF to bind device
|
||||
unsafe {
|
||||
let dev_idx = nix::libc::if_nametoindex(dev_name.as_str().as_ptr() as *const i8);
|
||||
tracing::warn!(?dev_idx, ?dev_name, "bind device");
|
||||
socket2_socket.bind_device_by_index_v4(std::num::NonZeroU32::new(dev_idx))?;
|
||||
tracing::warn!(?dev_idx, ?dev_name, "bind device doen");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
||||
if let Some(dev_name) = bind_dev {
|
||||
tracing::trace!(dev_name = ?dev_name, "bind device");
|
||||
socket2_socket.bind_device(Some(dev_name.as_bytes()))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
|
||||
mut futures: FuturesUnordered<Fut>,
|
||||
) -> Result<Ret, TunnelError>
|
||||
where
|
||||
Fut: Future<Output = Result<Ret, E>> + Send + Sync,
|
||||
E: std::error::Error + Into<TunnelError> + Send + Sync + 'static,
|
||||
{
|
||||
// return last error
|
||||
let mut last_err = None;
|
||||
|
||||
while let Some(ret) = futures.next().await {
|
||||
if let Err(e) = ret {
|
||||
last_err = Some(e.into());
|
||||
} else {
|
||||
return ret.map_err(|e| e.into());
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err.unwrap_or(TunnelError::Shutdown))
|
||||
}
|
||||
|
||||
pub(crate) fn setup_sokcet2(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
) -> Result<(), TunnelError> {
|
||||
setup_sokcet2_ext(
|
||||
socket2_socket,
|
||||
bind_addr,
|
||||
super::common::get_interface_name_by_ip(&bind_addr.ip()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
|
||||
if buf.capacity() < min_size {
|
||||
buf.reserve(max_size);
|
||||
}
|
||||
}
|
||||
|
||||
pub mod tests {
|
||||
use std::time::Instant;
|
||||
|
||||
use futures::{SinkExt, StreamExt, TryStreamExt};
|
||||
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
|
||||
|
||||
use crate::{
|
||||
common::netns::NetNS,
|
||||
tunnel::{packet_def::ZCPacket, TunnelConnector, TunnelListener},
|
||||
};
|
||||
|
||||
pub async fn _tunnel_echo_server(tunnel: Box<dyn super::Tunnel>, once: bool) {
|
||||
let (mut recv, mut send) = tunnel.split();
|
||||
|
||||
if !once {
|
||||
recv.forward(send).await.unwrap();
|
||||
} else {
|
||||
let Some(ret) = recv.next().await else {
|
||||
assert!(false, "recv error");
|
||||
return;
|
||||
};
|
||||
|
||||
if ret.is_err() {
|
||||
tracing::debug!(?ret, "recv error");
|
||||
return;
|
||||
}
|
||||
|
||||
let res = ret.unwrap();
|
||||
tracing::debug!(?res, "recv a msg, try echo back");
|
||||
send.send(res).await.unwrap();
|
||||
}
|
||||
|
||||
tracing::warn!("echo server exit...");
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_pingpong<L, C>(listener: L, connector: C)
|
||||
where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
_tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_pingpong_netns<L, C>(
|
||||
mut listener: L,
|
||||
mut connector: C,
|
||||
l_netns: NetNS,
|
||||
c_netns: NetNS,
|
||||
) where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
l_netns
|
||||
.run_async(|| async {
|
||||
listener.listen().await.unwrap();
|
||||
})
|
||||
.await;
|
||||
|
||||
let lis = tokio::spawn(async move {
|
||||
let ret = listener.accept().await.unwrap();
|
||||
assert_eq!(
|
||||
ret.info().unwrap().local_addr,
|
||||
listener.local_url().to_string()
|
||||
);
|
||||
_tunnel_echo_server(ret, false).await
|
||||
});
|
||||
|
||||
let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
tunnel.info().unwrap().remote_addr,
|
||||
connector.remote_url().to_string()
|
||||
);
|
||||
|
||||
let (mut recv, mut send) = tunnel.split();
|
||||
|
||||
send.send(ZCPacket::new_with_payload("12345678abcdefg".as_bytes()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
println!("echo back: {:?}", ret);
|
||||
assert_eq!(ret.payload(), Bytes::from("12345678abcdefg"));
|
||||
|
||||
drop(send);
|
||||
|
||||
if ["udp", "wg"].contains(&connector.remote_url().scheme()) {
|
||||
lis.abort();
|
||||
} else {
|
||||
// lis should finish in 1 second
|
||||
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await;
|
||||
assert!(ret.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_bench<L, C>(mut listener: L, mut connector: C)
|
||||
where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
listener.listen().await.unwrap();
|
||||
|
||||
let lis = tokio::spawn(async move {
|
||||
let ret = listener.accept().await.unwrap();
|
||||
_tunnel_echo_server(ret, false).await
|
||||
});
|
||||
|
||||
let tunnel = connector.connect().await.unwrap();
|
||||
|
||||
let (recv, mut send) = tunnel.split();
|
||||
|
||||
// prepare a 4k buffer with random data
|
||||
let mut send_buf = BytesMut::new();
|
||||
for _ in 0..64 {
|
||||
send_buf.put_i128(rand::random::<i128>());
|
||||
}
|
||||
|
||||
let r = tokio::spawn(async move {
|
||||
let now = Instant::now();
|
||||
let count = recv
|
||||
.try_fold(0usize, |mut ret, _| async move {
|
||||
ret += 1;
|
||||
Ok(ret)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
println!(
|
||||
"bps: {}",
|
||||
(count / 1024) * 4 / now.elapsed().as_secs() as usize
|
||||
);
|
||||
});
|
||||
|
||||
let now = Instant::now();
|
||||
while now.elapsed().as_secs() < 10 {
|
||||
// send.feed(item)
|
||||
let item = ZCPacket::new_with_payload(send_buf.as_ref());
|
||||
let _ = send.feed(item).await.unwrap();
|
||||
}
|
||||
|
||||
drop(send);
|
||||
drop(connector);
|
||||
drop(tunnel);
|
||||
|
||||
tracing::warn!("wait for recv to finish...");
|
||||
|
||||
let _ = tokio::join!(r);
|
||||
|
||||
lis.abort();
|
||||
let _ = tokio::join!(lis);
|
||||
}
|
||||
|
||||
pub fn enable_log() {
|
||||
let filter = tracing_subscriber::EnvFilter::builder()
|
||||
.with_default_directive(tracing::level_filters::LevelFilter::TRACE.into())
|
||||
.from_env()
|
||||
.unwrap()
|
||||
.add_directive("tarpc=error".parse().unwrap());
|
||||
tracing_subscriber::fmt::fmt()
|
||||
.pretty()
|
||||
.with_env_filter(filter)
|
||||
.init();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,362 @@
|
||||
use std::{
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use crate::rpc::TunnelInfo;
|
||||
use auto_impl::auto_impl;
|
||||
use futures::{Sink, SinkExt, Stream, StreamExt};
|
||||
|
||||
use self::stats::Throughput;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[auto_impl(Arc, Box)]
|
||||
pub trait TunnelFilter: Send + Sync {
|
||||
type FilterOutput;
|
||||
|
||||
fn before_send(&self, data: SinkItem) -> Option<SinkItem> {
|
||||
Some(data)
|
||||
}
|
||||
|
||||
fn after_received(&self, data: StreamItem) -> Option<StreamItem> {
|
||||
match data {
|
||||
Ok(v) => Some(Ok(v)),
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_output(&self) -> Self::FilterOutput;
|
||||
}
|
||||
|
||||
pub struct TunnelFilterChain<A, B> {
|
||||
a: A,
|
||||
b: B,
|
||||
}
|
||||
|
||||
impl<A, B, OA, OB> TunnelFilter for TunnelFilterChain<A, B>
|
||||
where
|
||||
A: TunnelFilter<FilterOutput = OA>,
|
||||
B: TunnelFilter<FilterOutput = OB>,
|
||||
{
|
||||
type FilterOutput = (OA, OB);
|
||||
fn before_send(&self, data: SinkItem) -> Option<SinkItem> {
|
||||
let data = self.a.before_send(data)?;
|
||||
self.b.before_send(data)
|
||||
}
|
||||
fn after_received(&self, data: StreamItem) -> Option<StreamItem> {
|
||||
let data = self.b.after_received(data)?;
|
||||
self.a.after_received(data)
|
||||
}
|
||||
fn filter_output(&self) -> Self::FilterOutput {
|
||||
(self.a.filter_output(), self.b.filter_output())
|
||||
}
|
||||
}
|
||||
|
||||
impl<A, B> TunnelFilterChain<A, B> {
|
||||
pub fn new(a: A, b: B) -> Self {
|
||||
Self { a, b }
|
||||
}
|
||||
|
||||
pub fn chain<T: TunnelFilter>(self, c: T) -> TunnelFilterChain<Self, T> {
|
||||
TunnelFilterChain::new(self, c)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EmptyFilter;
|
||||
impl TunnelFilter for EmptyFilter {
|
||||
type FilterOutput = ();
|
||||
fn filter_output(&self) {}
|
||||
}
|
||||
|
||||
pub trait ToTunnelChain {
|
||||
fn to_chain(self) -> TunnelFilterChain<EmptyFilter, Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
TunnelFilterChain::new(EmptyFilter, self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<O, T: TunnelFilter<FilterOutput = O>> ToTunnelChain for T {}
|
||||
|
||||
pub struct TunnelWithFilter<T, F> {
|
||||
inner: T,
|
||||
filter: Arc<F>,
|
||||
}
|
||||
|
||||
impl<T, F> TunnelWithFilter<T, F>
|
||||
where
|
||||
T: Tunnel + Send + 'static,
|
||||
F: TunnelFilter + Send + 'static,
|
||||
{
|
||||
pub fn new(inner: T, filter: F) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
filter: Arc::new(filter),
|
||||
}
|
||||
}
|
||||
|
||||
fn wrap_sink<S: ZCPacketSink + Unpin + 'static>(&self, sink: S) -> impl ZCPacketSink {
|
||||
struct SinkWrapper<F, S> {
|
||||
sink: S,
|
||||
filter: Arc<F>,
|
||||
}
|
||||
|
||||
impl<F, S> Sink<ZCPacket> for SinkWrapper<F, S>
|
||||
where
|
||||
F: TunnelFilter + 'static,
|
||||
S: ZCPacketSink + 'static + Unpin,
|
||||
{
|
||||
type Error = SinkError;
|
||||
|
||||
fn poll_ready(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().sink.poll_ready_unpin(cx)
|
||||
}
|
||||
|
||||
fn start_send(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
item: ZCPacket,
|
||||
) -> Result<(), Self::Error> {
|
||||
let Some(item) = self.filter.before_send(item) else {
|
||||
return Ok(());
|
||||
};
|
||||
self.get_mut().sink.start_send_unpin(item)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().sink.poll_flush_unpin(cx)
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().sink.poll_close_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
SinkWrapper {
|
||||
sink,
|
||||
filter: self.filter.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn wrap_stream<S: ZCPacketStream + Unpin + 'static>(&self, stream: S) -> impl ZCPacketStream {
|
||||
struct StreamWrapper<F, S> {
|
||||
stream: S,
|
||||
filter: Arc<F>,
|
||||
}
|
||||
|
||||
impl<F, S> Stream for StreamWrapper<F, S>
|
||||
where
|
||||
F: TunnelFilter + 'static,
|
||||
S: ZCPacketStream + 'static + Unpin,
|
||||
{
|
||||
type Item = StreamItem;
|
||||
|
||||
fn poll_next(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
let self_mut = self.get_mut();
|
||||
loop {
|
||||
match self_mut.stream.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(ret)) => {
|
||||
let Some(ret) = self_mut.filter.after_received(ret) else {
|
||||
continue;
|
||||
};
|
||||
return Poll::Ready(Some(ret));
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
Poll::Pending => {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StreamWrapper {
|
||||
stream,
|
||||
filter: self.filter.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, F> Tunnel for TunnelWithFilter<T, F>
|
||||
where
|
||||
T: Tunnel + Send + 'static,
|
||||
F: TunnelFilter + Send + 'static,
|
||||
{
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
self.inner.info()
|
||||
}
|
||||
|
||||
fn split(&self) -> (Pin<Box<dyn ZCPacketStream>>, Pin<Box<dyn ZCPacketSink>>) {
|
||||
let (stream, sink) = self.inner.split();
|
||||
(
|
||||
Box::pin(self.wrap_stream(stream)),
|
||||
Box::pin(self.wrap_sink(sink)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PacketRecorderTunnelFilter {
|
||||
pub received: Arc<std::sync::Mutex<Vec<ZCPacket>>>,
|
||||
pub sent: Arc<std::sync::Mutex<Vec<ZCPacket>>>,
|
||||
}
|
||||
|
||||
impl TunnelFilter for PacketRecorderTunnelFilter {
|
||||
type FilterOutput = (Vec<ZCPacket>, Vec<ZCPacket>);
|
||||
|
||||
fn before_send(&self, data: SinkItem) -> Option<SinkItem> {
|
||||
self.received.lock().unwrap().push(data.clone());
|
||||
Some(data)
|
||||
}
|
||||
|
||||
fn after_received(&self, data: StreamItem) -> Option<StreamItem> {
|
||||
match data {
|
||||
Ok(v) => {
|
||||
self.sent.lock().unwrap().push(v.clone().into());
|
||||
Some(Ok(v))
|
||||
}
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_output(&self) -> Self::FilterOutput {
|
||||
(
|
||||
self.received.lock().unwrap().clone(),
|
||||
self.sent.lock().unwrap().clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketRecorderTunnelFilter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
received: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
sent: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StatsRecorderTunnelFilter {
|
||||
throughput: Arc<Throughput>,
|
||||
}
|
||||
|
||||
impl TunnelFilter for StatsRecorderTunnelFilter {
|
||||
type FilterOutput = Arc<Throughput>;
|
||||
|
||||
fn before_send(&self, data: SinkItem) -> Option<SinkItem> {
|
||||
self.throughput.record_tx_bytes(data.buf_len() as u64);
|
||||
Some(data)
|
||||
}
|
||||
|
||||
fn after_received(&self, data: StreamItem) -> Option<StreamItem> {
|
||||
match data {
|
||||
Ok(v) => {
|
||||
self.throughput.record_rx_bytes(v.buf_len() as u64);
|
||||
Some(Ok(v))
|
||||
}
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_output(&self) -> Self::FilterOutput {
|
||||
self.throughput.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl StatsRecorderTunnelFilter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
throughput: Arc::new(Throughput::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_throughput(&self) -> Arc<Throughput> {
|
||||
self.throughput.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
|
||||
use filter::ring::create_ring_tunnel_pair;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub struct DropSendTunnelFilter {
|
||||
start: AtomicU32,
|
||||
end: AtomicU32,
|
||||
cur: AtomicU32,
|
||||
}
|
||||
|
||||
impl TunnelFilter for DropSendTunnelFilter {
|
||||
type FilterOutput = ();
|
||||
|
||||
fn before_send(&self, data: SinkItem) -> Option<SinkItem> {
|
||||
self.cur.fetch_add(1, Ordering::SeqCst);
|
||||
if self.cur.load(Ordering::SeqCst) >= self.start.load(Ordering::SeqCst)
|
||||
&& self.cur.load(std::sync::atomic::Ordering::SeqCst)
|
||||
< self.end.load(Ordering::SeqCst)
|
||||
{
|
||||
tracing::trace!("drop packet: {:?}", data);
|
||||
return None;
|
||||
}
|
||||
Some(data)
|
||||
}
|
||||
|
||||
fn filter_output(&self) {}
|
||||
}
|
||||
|
||||
impl DropSendTunnelFilter {
|
||||
pub fn new(start: u32, end: u32) -> Self {
|
||||
Self {
|
||||
start: AtomicU32::new(start),
|
||||
end: AtomicU32::new(end),
|
||||
cur: AtomicU32::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_nested_filter() {
|
||||
let filter = Arc::new(
|
||||
PacketRecorderTunnelFilter::new()
|
||||
.to_chain()
|
||||
.chain(PacketRecorderTunnelFilter::new())
|
||||
.chain(PacketRecorderTunnelFilter::new())
|
||||
.chain(PacketRecorderTunnelFilter::new()),
|
||||
);
|
||||
let (s, _b) = create_ring_tunnel_pair();
|
||||
let tunnel = TunnelWithFilter::new(s, filter.clone());
|
||||
|
||||
let (_r, mut s) = tunnel.split();
|
||||
s.send(ZCPacket::new_with_payload("ab".as_bytes()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let out = filter.filter_output();
|
||||
|
||||
let a = out.0 .0 .0 .1;
|
||||
let b = out.0 .0 .1;
|
||||
let c = out.0 .1;
|
||||
let _d = out.1;
|
||||
|
||||
assert_eq!(1, a.0.len());
|
||||
assert_eq!(1, b.0.len());
|
||||
assert_eq!(1, c.0.len());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
use std::{net::SocketAddr, pin::Pin, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{Sink, Stream};
|
||||
use std::fmt::Debug;
|
||||
|
||||
use tokio::time::error::Elapsed;
|
||||
|
||||
use crate::rpc::TunnelInfo;
|
||||
|
||||
use self::packet_def::ZCPacket;
|
||||
|
||||
pub mod buf;
|
||||
pub mod common;
|
||||
pub mod filter;
|
||||
pub mod mpsc;
|
||||
pub mod packet_def;
|
||||
pub mod quic;
|
||||
pub mod ring;
|
||||
pub mod stats;
|
||||
pub mod tcp;
|
||||
pub mod udp;
|
||||
pub mod wireguard;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum TunnelError {
|
||||
#[error("io error")]
|
||||
IOError(#[from] std::io::Error),
|
||||
#[error("invalid packet. msg: {0}")]
|
||||
InvalidPacket(String),
|
||||
#[error("exceed max packet size. max: {0}, input: {1}")]
|
||||
ExceedMaxPacketSize(usize, usize),
|
||||
|
||||
#[error("invalid protocol: {0}")]
|
||||
InvalidProtocol(String),
|
||||
#[error("invalid addr: {0}")]
|
||||
InvalidAddr(String),
|
||||
|
||||
#[error("internal error {0}")]
|
||||
InternalError(String),
|
||||
|
||||
#[error("conn id not match, expect: {0}, actual: {1}")]
|
||||
ConnIdNotMatch(u32, u32),
|
||||
#[error("buffer full")]
|
||||
BufferFull,
|
||||
|
||||
#[error("timeout")]
|
||||
Timeout(#[from] Elapsed),
|
||||
|
||||
#[error("anyhow error: {0}")]
|
||||
Anyhow(#[from] anyhow::Error),
|
||||
|
||||
#[error("shutdown")]
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
pub type StreamT = packet_def::ZCPacket;
|
||||
pub type StreamItem = Result<StreamT, TunnelError>;
|
||||
pub type SinkItem = packet_def::ZCPacket;
|
||||
pub type SinkError = TunnelError;
|
||||
|
||||
pub trait ZCPacketStream: Stream<Item = StreamItem> + Send {}
|
||||
impl<T> ZCPacketStream for T where T: Stream<Item = StreamItem> + Send {}
|
||||
pub trait ZCPacketSink: Sink<SinkItem, Error = SinkError> + Send {}
|
||||
impl<T> ZCPacketSink for T where T: Sink<SinkItem, Error = SinkError> + Send {}
|
||||
|
||||
#[auto_impl::auto_impl(Box, Arc)]
|
||||
pub trait Tunnel: Send {
|
||||
fn split(&self) -> (Pin<Box<dyn ZCPacketStream>>, Pin<Box<dyn ZCPacketSink>>);
|
||||
fn info(&self) -> Option<TunnelInfo>;
|
||||
}
|
||||
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait TunnelConnCounter: 'static + Send + Sync + Debug {
|
||||
fn get(&self) -> u32;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[auto_impl::auto_impl(Box)]
|
||||
pub trait TunnelListener: Send {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError>;
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
|
||||
fn local_url(&self) -> url::Url;
|
||||
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
|
||||
#[derive(Debug)]
|
||||
struct FakeTunnelConnCounter {}
|
||||
impl TunnelConnCounter for FakeTunnelConnCounter {
|
||||
fn get(&self) -> u32 {
|
||||
0
|
||||
}
|
||||
}
|
||||
Arc::new(Box::new(FakeTunnelConnCounter {}))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[auto_impl::auto_impl(Box)]
|
||||
pub trait TunnelConnector: Send {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
|
||||
fn remote_url(&self) -> url::Url;
|
||||
fn set_bind_addrs(&mut self, _addrs: Vec<SocketAddr>) {}
|
||||
}
|
||||
|
||||
pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url {
|
||||
url::Url::parse(format!("{}://{}", scheme, addr).as_str()).unwrap()
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn Tunnel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Tunnel")
|
||||
.field("info", &self.info())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn TunnelConnector {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TunnelConnector")
|
||||
.field("remote_url", &self.remote_url())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn TunnelListener {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TunnelListener")
|
||||
.field("local_url", &self.local_url())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait FromUrl {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError>
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
pub(crate) fn check_scheme_and_get_socket_addr<T>(
|
||||
url: &url::Url,
|
||||
scheme: &str,
|
||||
) -> Result<T, TunnelError>
|
||||
where
|
||||
T: FromUrl,
|
||||
{
|
||||
if url.scheme() != scheme {
|
||||
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
|
||||
}
|
||||
|
||||
Ok(T::from_url(url.clone())?)
|
||||
}
|
||||
|
||||
impl FromUrl for SocketAddr {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
|
||||
Ok(url.socket_addrs(|| None)?.pop().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromUrl for uuid::Uuid {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
|
||||
let o = url.host_str().unwrap();
|
||||
let o = uuid::Uuid::parse_str(o).map_err(|e| TunnelError::InvalidAddr(e.to_string()))?;
|
||||
Ok(o)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TunnelUrl {
|
||||
inner: url::Url,
|
||||
}
|
||||
|
||||
impl From<url::Url> for TunnelUrl {
|
||||
fn from(url: url::Url) -> Self {
|
||||
TunnelUrl { inner: url }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TunnelUrl> for url::Url {
|
||||
fn from(url: TunnelUrl) -> Self {
|
||||
url.into_inner()
|
||||
}
|
||||
}
|
||||
|
||||
impl TunnelUrl {
|
||||
pub fn into_inner(self) -> url::Url {
|
||||
self.inner
|
||||
}
|
||||
|
||||
pub fn bind_dev(&self) -> Option<String> {
|
||||
self.inner.path().strip_prefix("/").and_then(|s| {
|
||||
if s.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(String::from_utf8(percent_encoding::percent_decode_str(&s).collect()).unwrap())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
// this mod wrap tunnel to a mpsc tunnel, based on crossbeam_channel
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use anyhow::Context;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use super::{packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream};
|
||||
|
||||
use tachyonix::{channel, Receiver, Sender};
|
||||
|
||||
use futures::SinkExt;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MpscTunnelSender(Sender<ZCPacket>);
|
||||
|
||||
impl MpscTunnelSender {
|
||||
pub async fn send(&self, item: ZCPacket) -> Result<(), TunnelError> {
|
||||
self.0.send(item).await.with_context(|| "send error")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MpscTunnel<T> {
|
||||
tx: Sender<ZCPacket>,
|
||||
|
||||
tunnel: T,
|
||||
stream: Option<Pin<Box<dyn ZCPacketStream>>>,
|
||||
|
||||
task: Option<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl<T: Tunnel> MpscTunnel<T> {
|
||||
pub fn new(tunnel: T) -> Self {
|
||||
let (tx, mut rx) = channel(32);
|
||||
let (stream, mut sink) = tunnel.split();
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
loop {
|
||||
if let Err(e) = Self::forward_one_round(&mut rx, &mut sink).await {
|
||||
tracing::error!(?e, "forward error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
tx,
|
||||
tunnel,
|
||||
stream: Some(stream),
|
||||
task: Some(task),
|
||||
}
|
||||
}
|
||||
|
||||
async fn forward_one_round(
|
||||
rx: &mut Receiver<ZCPacket>,
|
||||
sink: &mut Pin<Box<dyn ZCPacketSink>>,
|
||||
) -> Result<(), TunnelError> {
|
||||
let item = rx.recv().await.with_context(|| "recv error")?;
|
||||
sink.feed(item).await?;
|
||||
while let Ok(item) = rx.try_recv() {
|
||||
if let Err(e) = sink.feed(item).await {
|
||||
tracing::error!(?e, "feed error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
sink.flush().await
|
||||
}
|
||||
|
||||
pub fn get_stream(&mut self) -> Pin<Box<dyn ZCPacketStream>> {
|
||||
self.stream.take().unwrap()
|
||||
}
|
||||
|
||||
pub fn get_sink(&self) -> MpscTunnelSender {
|
||||
MpscTunnelSender(self.tx.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Tunnel> From<T> for MpscTunnel<T> {
|
||||
fn from(tunnel: T) -> Self {
|
||||
Self::new(tunnel)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::StreamExt;
|
||||
|
||||
use crate::tunnel::{
|
||||
tcp::{TcpTunnelConnector, TcpTunnelListener},
|
||||
TunnelConnector, TunnelListener,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
// test slow send lock in framed tunnel
|
||||
#[tokio::test]
|
||||
async fn mpsc_slow_receiver() {
|
||||
let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
|
||||
listener.listen().await.unwrap();
|
||||
let t1 = tokio::spawn(async move {
|
||||
let t = listener.accept().await.unwrap();
|
||||
let (mut stream, _sink) = t.split();
|
||||
let now = tokio::time::Instant::now();
|
||||
|
||||
let mut a_counter = 0;
|
||||
let mut b_counter = 0;
|
||||
|
||||
while let Some(Ok(msg)) = stream.next().await {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
if now.elapsed().as_secs() > 5 {
|
||||
break;
|
||||
}
|
||||
|
||||
if msg.payload() == "hello".as_bytes() {
|
||||
a_counter += 1;
|
||||
} else if msg.payload() == "hello2".as_bytes() {
|
||||
b_counter += 1;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("t1 exit");
|
||||
assert_ne!(a_counter, 0);
|
||||
assert_ne!(b_counter, 0);
|
||||
});
|
||||
|
||||
let tunnel = connector.connect().await.unwrap();
|
||||
let mpsc_tunnel = MpscTunnel::from(tunnel);
|
||||
|
||||
let sink1 = mpsc_tunnel.get_sink();
|
||||
let t2 = tokio::spawn(async move {
|
||||
for i in 0..1000000 {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||
let a = sink1
|
||||
.send(ZCPacket::new_with_payload("hello".as_bytes()))
|
||||
.await;
|
||||
if a.is_err() {
|
||||
tracing::info!(?a, "t2 exit with err");
|
||||
break;
|
||||
}
|
||||
|
||||
if i % 5000 == 0 {
|
||||
tracing::info!(i, "send2 1000");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("t2 exit");
|
||||
});
|
||||
|
||||
let sink2 = mpsc_tunnel.get_sink();
|
||||
let t3 = tokio::spawn(async move {
|
||||
for i in 0..1000000 {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
let a = sink2
|
||||
.send(ZCPacket::new_with_payload("hello2".as_bytes()))
|
||||
.await;
|
||||
if a.is_err() {
|
||||
tracing::info!(?a, "t3 exit with err");
|
||||
break;
|
||||
}
|
||||
|
||||
if i % 5000 == 0 {
|
||||
tracing::info!(i, "send2 1000");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("t3 exit");
|
||||
});
|
||||
|
||||
let t4 = tokio::spawn(async move {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
tracing::info!("closing");
|
||||
drop(mpsc_tunnel);
|
||||
tracing::info!("closed");
|
||||
});
|
||||
|
||||
let _ = tokio::join!(t1, t2, t3, t4);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,340 @@
|
||||
use bytes::Bytes;
|
||||
use bytes::BytesMut;
|
||||
use zerocopy::byteorder::*;
|
||||
use zerocopy::AsBytes;
|
||||
use zerocopy::FromBytes;
|
||||
use zerocopy::FromZeroes;
|
||||
|
||||
type DefaultEndian = LittleEndian;
|
||||
|
||||
// TCP TunnelHeader
|
||||
#[repr(C, packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
||||
pub struct TCPTunnelHeader {
|
||||
pub len: U32<DefaultEndian>,
|
||||
}
|
||||
pub const TCP_TUNNEL_HEADER_SIZE: usize = std::mem::size_of::<TCPTunnelHeader>();
|
||||
|
||||
#[derive(AsBytes, FromZeroes, Clone, Debug)]
|
||||
#[repr(u8)]
|
||||
pub enum UdpPacketType {
|
||||
Invalid = 0,
|
||||
Syn = 1,
|
||||
Sack = 2,
|
||||
Data = 3,
|
||||
Fin = 4,
|
||||
HolePunch = 5,
|
||||
}
|
||||
|
||||
#[repr(C, packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
||||
pub struct UDPTunnelHeader {
|
||||
pub conn_id: U32<DefaultEndian>,
|
||||
pub msg_type: u8,
|
||||
pub padding: u8,
|
||||
pub len: U16<DefaultEndian>,
|
||||
}
|
||||
pub const UDP_TUNNEL_HEADER_SIZE: usize = std::mem::size_of::<UDPTunnelHeader>();
|
||||
|
||||
#[repr(C, packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
||||
pub struct WGTunnelHeader {
|
||||
pub ipv4_header: [u8; 20],
|
||||
}
|
||||
pub const WG_TUNNEL_HEADER_SIZE: usize = std::mem::size_of::<WGTunnelHeader>();
|
||||
|
||||
#[derive(AsBytes, FromZeroes, Clone, Debug)]
|
||||
#[repr(u8)]
|
||||
pub enum PacketType {
|
||||
Invalid = 0,
|
||||
Data = 1,
|
||||
HandShake = 2,
|
||||
RoutePacket = 3,
|
||||
Ping = 4,
|
||||
Pong = 5,
|
||||
TaRpc = 6,
|
||||
Route = 7,
|
||||
}
|
||||
|
||||
#[repr(C, packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
||||
pub struct PeerManagerHeader {
|
||||
pub from_peer_id: U32<DefaultEndian>,
|
||||
pub to_peer_id: U32<DefaultEndian>,
|
||||
pub packet_type: u8,
|
||||
pub len: U32<DefaultEndian>,
|
||||
}
|
||||
pub const PEER_MANAGER_HEADER_SIZE: usize = std::mem::size_of::<PeerManagerHeader>();
|
||||
|
||||
const fn max(a: usize, b: usize) -> usize {
|
||||
[a, b][(a < b) as usize]
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct ZCPacketOffsets {
|
||||
pub payload_offset: usize,
|
||||
pub peer_manager_header_offset: usize,
|
||||
pub tcp_tunnel_header_offset: usize,
|
||||
pub udp_tunnel_header_offset: usize,
|
||||
pub wg_tunnel_header_offset: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum ZCPacketType {
|
||||
// received from peer tcp connection
|
||||
TCP,
|
||||
// received from peer udp connection
|
||||
UDP,
|
||||
// received from peer wireguard connection
|
||||
WG,
|
||||
// received from local tun device, should reserve header space for tcp or udp tunnel
|
||||
NIC,
|
||||
}
|
||||
|
||||
const PAYLOAD_OFFSET_FOR_NIC_PACKET: usize = max(
|
||||
max(TCP_TUNNEL_HEADER_SIZE, UDP_TUNNEL_HEADER_SIZE),
|
||||
WG_TUNNEL_HEADER_SIZE,
|
||||
) + PEER_MANAGER_HEADER_SIZE;
|
||||
|
||||
impl ZCPacketType {
|
||||
pub fn get_packet_offsets(&self) -> ZCPacketOffsets {
|
||||
match self {
|
||||
ZCPacketType::TCP => ZCPacketOffsets {
|
||||
payload_offset: TCP_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE,
|
||||
peer_manager_header_offset: TCP_TUNNEL_HEADER_SIZE,
|
||||
..Default::default()
|
||||
},
|
||||
ZCPacketType::UDP => ZCPacketOffsets {
|
||||
payload_offset: UDP_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE,
|
||||
peer_manager_header_offset: UDP_TUNNEL_HEADER_SIZE,
|
||||
..Default::default()
|
||||
},
|
||||
ZCPacketType::WG => ZCPacketOffsets {
|
||||
payload_offset: WG_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE,
|
||||
peer_manager_header_offset: WG_TUNNEL_HEADER_SIZE,
|
||||
..Default::default()
|
||||
},
|
||||
ZCPacketType::NIC => ZCPacketOffsets {
|
||||
payload_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET,
|
||||
peer_manager_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET
|
||||
- PEER_MANAGER_HEADER_SIZE,
|
||||
tcp_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET
|
||||
- PEER_MANAGER_HEADER_SIZE
|
||||
- TCP_TUNNEL_HEADER_SIZE,
|
||||
udp_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET
|
||||
- PEER_MANAGER_HEADER_SIZE
|
||||
- UDP_TUNNEL_HEADER_SIZE,
|
||||
wg_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET
|
||||
- PEER_MANAGER_HEADER_SIZE
|
||||
- WG_TUNNEL_HEADER_SIZE,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ZCPacket {
|
||||
inner: BytesMut,
|
||||
packet_type: ZCPacketType,
|
||||
}
|
||||
|
||||
impl ZCPacket {
|
||||
pub fn new_nic_packet() -> Self {
|
||||
Self {
|
||||
inner: BytesMut::new(),
|
||||
packet_type: ZCPacketType::NIC,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_from_buf(buf: BytesMut, packet_type: ZCPacketType) -> Self {
|
||||
Self {
|
||||
inner: buf,
|
||||
packet_type,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_payload(payload: &[u8]) -> Self {
|
||||
let mut ret = Self::new_nic_packet();
|
||||
let total_len = ret.packet_type.get_packet_offsets().payload_offset + payload.len();
|
||||
ret.inner.resize(total_len, 0);
|
||||
ret.mut_payload()[..payload.len()].copy_from_slice(&payload);
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn packet_type(&self) -> ZCPacketType {
|
||||
self.packet_type
|
||||
}
|
||||
|
||||
pub fn mut_payload(&mut self) -> &mut [u8] {
|
||||
&mut self.inner[self.packet_type.get_packet_offsets().payload_offset..]
|
||||
}
|
||||
|
||||
pub fn mut_peer_manager_header(&mut self) -> Option<&mut PeerManagerHeader> {
|
||||
PeerManagerHeader::mut_from_prefix(
|
||||
&mut self.inner[self
|
||||
.packet_type
|
||||
.get_packet_offsets()
|
||||
.peer_manager_header_offset..],
|
||||
)
|
||||
}
|
||||
|
||||
pub fn mut_tcp_tunnel_header(&mut self) -> Option<&mut TCPTunnelHeader> {
|
||||
TCPTunnelHeader::mut_from_prefix(
|
||||
&mut self.inner[self
|
||||
.packet_type
|
||||
.get_packet_offsets()
|
||||
.tcp_tunnel_header_offset..],
|
||||
)
|
||||
}
|
||||
|
||||
pub fn mut_udp_tunnel_header(&mut self) -> Option<&mut UDPTunnelHeader> {
|
||||
UDPTunnelHeader::mut_from_prefix(
|
||||
&mut self.inner[self
|
||||
.packet_type
|
||||
.get_packet_offsets()
|
||||
.udp_tunnel_header_offset..],
|
||||
)
|
||||
}
|
||||
|
||||
pub fn mut_wg_tunnel_header(&mut self) -> Option<&mut WGTunnelHeader> {
|
||||
WGTunnelHeader::mut_from_prefix(
|
||||
&mut self.inner[self
|
||||
.packet_type
|
||||
.get_packet_offsets()
|
||||
.wg_tunnel_header_offset..],
|
||||
)
|
||||
}
|
||||
|
||||
// ref versions
|
||||
pub fn payload(&self) -> &[u8] {
|
||||
&self.inner[self.packet_type.get_packet_offsets().payload_offset..]
|
||||
}
|
||||
|
||||
pub fn peer_manager_header(&self) -> Option<&PeerManagerHeader> {
|
||||
PeerManagerHeader::ref_from_prefix(
|
||||
&self.inner[self
|
||||
.packet_type
|
||||
.get_packet_offsets()
|
||||
.peer_manager_header_offset..],
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tcp_tunnel_header(&self) -> Option<&TCPTunnelHeader> {
|
||||
TCPTunnelHeader::ref_from_prefix(
|
||||
&self.inner[self
|
||||
.packet_type
|
||||
.get_packet_offsets()
|
||||
.tcp_tunnel_header_offset..],
|
||||
)
|
||||
}
|
||||
|
||||
pub fn udp_tunnel_header(&self) -> Option<&UDPTunnelHeader> {
|
||||
UDPTunnelHeader::ref_from_prefix(
|
||||
&self.inner[self
|
||||
.packet_type
|
||||
.get_packet_offsets()
|
||||
.udp_tunnel_header_offset..],
|
||||
)
|
||||
}
|
||||
|
||||
pub fn udp_payload(&self) -> &[u8] {
|
||||
&self.inner[self
|
||||
.packet_type
|
||||
.get_packet_offsets()
|
||||
.udp_tunnel_header_offset
|
||||
+ UDP_TUNNEL_HEADER_SIZE..]
|
||||
}
|
||||
|
||||
pub fn payload_len(&self) -> usize {
|
||||
let payload_offset = self.packet_type.get_packet_offsets().payload_offset;
|
||||
self.inner.len() - payload_offset
|
||||
}
|
||||
|
||||
pub fn buf_len(&self) -> usize {
|
||||
self.inner.len()
|
||||
}
|
||||
|
||||
pub fn fill_peer_manager_hdr(&mut self, from_peer_id: u32, to_peer_id: u32, packet_type: u8) {
|
||||
let payload_len = self.payload_len();
|
||||
let hdr = self.mut_peer_manager_header().unwrap();
|
||||
hdr.from_peer_id.set(from_peer_id);
|
||||
hdr.to_peer_id.set(to_peer_id);
|
||||
hdr.packet_type = packet_type;
|
||||
hdr.len.set(payload_len as u32);
|
||||
}
|
||||
|
||||
pub fn into_bytes(mut self, target_packet_type: ZCPacketType) -> Bytes {
|
||||
if target_packet_type == self.packet_type {
|
||||
return self.inner.freeze();
|
||||
} else {
|
||||
assert_eq!(
|
||||
self.packet_type,
|
||||
ZCPacketType::NIC,
|
||||
"only support NIC, got {:?}",
|
||||
self
|
||||
);
|
||||
}
|
||||
|
||||
match target_packet_type {
|
||||
ZCPacketType::TCP => self
|
||||
.inner
|
||||
.split_off(
|
||||
self.packet_type
|
||||
.get_packet_offsets()
|
||||
.tcp_tunnel_header_offset,
|
||||
)
|
||||
.freeze(),
|
||||
ZCPacketType::UDP => self
|
||||
.inner
|
||||
.split_off(
|
||||
self.packet_type
|
||||
.get_packet_offsets()
|
||||
.udp_tunnel_header_offset,
|
||||
)
|
||||
.freeze(),
|
||||
ZCPacketType::WG => self
|
||||
.inner
|
||||
.split_off(
|
||||
self.packet_type
|
||||
.get_packet_offsets()
|
||||
.wg_tunnel_header_offset,
|
||||
)
|
||||
.freeze(),
|
||||
ZCPacketType::NIC => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inner(self) -> BytesMut {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_zc_packet() {
|
||||
let payload = b"hello world";
|
||||
let mut packet = ZCPacket::new_with_payload(payload);
|
||||
let peer_manager_header = packet.mut_peer_manager_header().unwrap();
|
||||
peer_manager_header.packet_type = PacketType::Data as u8;
|
||||
peer_manager_header.len.set(payload.len() as u32);
|
||||
|
||||
let tcp_tunnel_header = packet.mut_tcp_tunnel_header().unwrap();
|
||||
tcp_tunnel_header.len.set(payload.len() as u32);
|
||||
|
||||
// let udp_tunnel_header = packet.mut_udp_tunnel_header().unwrap();
|
||||
// udp_tunnel_header.conn_id = 1;
|
||||
// udp_tunnel_header.msg_type = 2;
|
||||
// udp_tunnel_header.len = payload.len() as u32;
|
||||
|
||||
assert_eq!(packet.payload(), b"hello world");
|
||||
assert_eq!(packet.payload_len(), 11);
|
||||
println!("{:?}", packet.inner);
|
||||
|
||||
let tcp_packet = packet.into_bytes(ZCPacketType::TCP);
|
||||
assert_eq!(&tcp_packet[..1], b"\x0b");
|
||||
println!("{:?}", tcp_packet);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
//! This example demonstrates how to make a QUIC connection that ignores the server certificate.
|
||||
//!
|
||||
//! Checkout the `README.md` for guidance.
|
||||
|
||||
use std::{error::Error, net::SocketAddr, sync::Arc};
|
||||
|
||||
use crate::{
|
||||
rpc::TunnelInfo,
|
||||
tunnel::common::{FramedReader, FramedWriter, TunnelWrapper},
|
||||
};
|
||||
use anyhow::Context;
|
||||
use quinn::{ClientConfig, Connection, Endpoint, ServerConfig};
|
||||
|
||||
use super::{
|
||||
check_scheme_and_get_socket_addr, Tunnel, TunnelConnector, TunnelError, TunnelListener,
|
||||
};
|
||||
|
||||
/// Dummy certificate verifier that treats any certificate as valid.
|
||||
/// NOTE, such verification is vulnerable to MITM attacks, but convenient for testing.
|
||||
struct SkipServerVerification;
|
||||
|
||||
impl SkipServerVerification {
|
||||
fn new() -> Arc<Self> {
|
||||
Arc::new(Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl rustls::client::ServerCertVerifier for SkipServerVerification {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::Certificate,
|
||||
_intermediates: &[rustls::Certificate],
|
||||
_server_name: &rustls::ServerName,
|
||||
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: std::time::SystemTime,
|
||||
) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::ServerCertVerified::assertion())
|
||||
}
|
||||
}
|
||||
|
||||
fn configure_client() -> ClientConfig {
|
||||
let crypto = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_custom_certificate_verifier(SkipServerVerification::new())
|
||||
.with_no_client_auth();
|
||||
|
||||
ClientConfig::new(Arc::new(crypto))
|
||||
}
|
||||
|
||||
/// Constructs a QUIC endpoint configured to listen for incoming connections on a certain address
|
||||
/// and port.
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// - a stream of incoming QUIC connections
|
||||
/// - server certificate serialized into DER format
|
||||
#[allow(unused)]
|
||||
pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<(Endpoint, Vec<u8>), Box<dyn Error>> {
|
||||
let (server_config, server_cert) = configure_server()?;
|
||||
let endpoint = Endpoint::server(server_config, bind_addr)?;
|
||||
Ok((endpoint, server_cert))
|
||||
}
|
||||
|
||||
/// Returns default server configuration along with its certificate.
|
||||
fn configure_server() -> Result<(ServerConfig, Vec<u8>), Box<dyn Error>> {
|
||||
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
|
||||
let cert_der = cert.serialize_der().unwrap();
|
||||
let priv_key = cert.serialize_private_key_der();
|
||||
let priv_key = rustls::PrivateKey(priv_key);
|
||||
let cert_chain = vec![rustls::Certificate(cert_der.clone())];
|
||||
|
||||
let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?;
|
||||
let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
|
||||
transport_config.max_concurrent_uni_streams(10_u8.into());
|
||||
transport_config.max_concurrent_bidi_streams(10_u8.into());
|
||||
|
||||
Ok((server_config, cert_der))
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"];
|
||||
|
||||
/// Runs a QUIC server bound to given address.
|
||||
|
||||
struct ConnWrapper {
|
||||
conn: Connection,
|
||||
}
|
||||
|
||||
impl Drop for ConnWrapper {
|
||||
fn drop(&mut self) {
|
||||
self.conn.close(0u32.into(), b"done");
|
||||
}
|
||||
}
|
||||
|
||||
pub struct QUICTunnelListener {
|
||||
addr: url::Url,
|
||||
endpoint: Option<Endpoint>,
|
||||
server_cert: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl QUICTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
QUICTunnelListener {
|
||||
addr,
|
||||
endpoint: None,
|
||||
server_cert: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TunnelListener for QUICTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic")?;
|
||||
let (endpoint, server_cert) = make_server_endpoint(addr).unwrap();
|
||||
self.endpoint = Some(endpoint);
|
||||
self.server_cert = Some(server_cert);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
// accept a single connection
|
||||
let incoming_conn = self.endpoint.as_ref().unwrap().accept().await.unwrap();
|
||||
let conn = incoming_conn.await.unwrap();
|
||||
println!(
|
||||
"[server] connection accepted: addr={}",
|
||||
conn.remote_address()
|
||||
);
|
||||
let remote_addr = conn.remote_address();
|
||||
let (w, r) = conn.accept_bi().await.with_context(|| "accept_bi failed")?;
|
||||
|
||||
let arc_conn = Arc::new(ConnWrapper { conn });
|
||||
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "quic".to_owned(),
|
||||
local_addr: self.local_url().into(),
|
||||
remote_addr: super::build_url_from_socket_addr(&remote_addr.to_string(), "quic").into(),
|
||||
};
|
||||
|
||||
Ok(Box::new(TunnelWrapper::new(
|
||||
FramedReader::new_with_associate_data(r, 4500, Some(Box::new(arc_conn.clone()))),
|
||||
FramedWriter::new_with_associate_data(w, Some(Box::new(arc_conn))),
|
||||
Some(info),
|
||||
)))
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct QUICTunnelConnector {
|
||||
addr: url::Url,
|
||||
endpoint: Option<Endpoint>,
|
||||
}
|
||||
|
||||
impl QUICTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
QUICTunnelConnector {
|
||||
addr,
|
||||
endpoint: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TunnelConnector for QUICTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic")?;
|
||||
|
||||
let mut endpoint = Endpoint::client("127.0.0.1:0".parse().unwrap())?;
|
||||
endpoint.set_default_client_config(configure_client());
|
||||
|
||||
// connect to server
|
||||
let connection = endpoint.connect(addr, "localhost").unwrap().await.unwrap();
|
||||
println!("[client] connected: addr={}", connection.remote_address());
|
||||
|
||||
let local_addr = endpoint.local_addr().unwrap();
|
||||
|
||||
self.endpoint = Some(endpoint);
|
||||
|
||||
let (w, r) = connection
|
||||
.open_bi()
|
||||
.await
|
||||
.with_context(|| "open_bi failed")?;
|
||||
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "quic".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&local_addr.to_string(), "quic").into(),
|
||||
remote_addr: self.addr.to_string(),
|
||||
};
|
||||
|
||||
let arc_conn = Arc::new(ConnWrapper { conn: connection });
|
||||
Ok(Box::new(TunnelWrapper::new(
|
||||
FramedReader::new_with_associate_data(r, 4500, Some(Box::new(arc_conn.clone()))),
|
||||
FramedWriter::new_with_associate_data(w, Some(Box::new(arc_conn))),
|
||||
Some(info),
|
||||
)))
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tunnel::common::tests::{_tunnel_bench, _tunnel_pingpong};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn quic_pingpong() {
|
||||
let listener = QUICTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap());
|
||||
let connector = QUICTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quic_bench() {
|
||||
let listener = QUICTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap());
|
||||
let connector = QUICTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap());
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,427 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
task::{Poll, Waker},
|
||||
};
|
||||
|
||||
use atomicbox::AtomicOptionBox;
|
||||
use crossbeam_queue::ArrayQueue;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{Sink, Stream};
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
use tokio::sync::{
|
||||
mpsc::{UnboundedReceiver, UnboundedSender},
|
||||
Mutex,
|
||||
};
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::tunnel::{SinkError, SinkItem};
|
||||
|
||||
use super::{
|
||||
build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::TunnelWrapper,
|
||||
StreamItem, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
|
||||
};
|
||||
|
||||
static RING_TUNNEL_CAP: usize = 128;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RingTunnel {
|
||||
id: Uuid,
|
||||
ring: ArrayQueue<SinkItem>,
|
||||
closed: AtomicBool,
|
||||
|
||||
wait_for_new_item: AtomicOptionBox<Waker>,
|
||||
wait_for_empty_slot: AtomicOptionBox<Waker>,
|
||||
}
|
||||
|
||||
impl RingTunnel {
|
||||
fn wait_for_new_item<T>(&self, cx: &mut std::task::Context<'_>) -> Poll<T> {
|
||||
let ret = self
|
||||
.wait_for_new_item
|
||||
.swap(Some(Box::new(cx.waker().clone())), Ordering::AcqRel);
|
||||
if let Some(old_waker) = ret {
|
||||
assert!(old_waker.will_wake(cx.waker()));
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn wait_for_empty_slot<T>(&self, cx: &mut std::task::Context<'_>) -> Poll<T> {
|
||||
let ret = self
|
||||
.wait_for_empty_slot
|
||||
.swap(Some(Box::new(cx.waker().clone())), Ordering::AcqRel);
|
||||
if let Some(old_waker) = ret {
|
||||
assert!(old_waker.will_wake(cx.waker()));
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn notify_new_item(&self) {
|
||||
if let Some(w) = self.wait_for_new_item.take(Ordering::AcqRel) {
|
||||
tracing::trace!(?self.id, "notify new item");
|
||||
w.wake();
|
||||
}
|
||||
}
|
||||
|
||||
fn notify_empty_slot(&self) {
|
||||
if let Some(w) = self.wait_for_empty_slot.take(Ordering::AcqRel) {
|
||||
tracing::trace!(?self.id, "notify empty slot");
|
||||
w.wake();
|
||||
}
|
||||
}
|
||||
|
||||
fn id(&self) -> &Uuid {
|
||||
&self.id
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.ring.len()
|
||||
}
|
||||
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.ring.capacity()
|
||||
}
|
||||
|
||||
fn close(&self) {
|
||||
tracing::info!("close ring tunnel {:?}", self.id);
|
||||
self.closed
|
||||
.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||
self.notify_new_item();
|
||||
}
|
||||
|
||||
fn closed(&self) -> bool {
|
||||
self.closed.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn new(cap: usize) -> Self {
|
||||
let id = Uuid::new_v4();
|
||||
Self {
|
||||
id: id.clone(),
|
||||
ring: ArrayQueue::new(cap),
|
||||
closed: AtomicBool::new(false),
|
||||
|
||||
wait_for_new_item: AtomicOptionBox::new(None),
|
||||
wait_for_empty_slot: AtomicOptionBox::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_id(id: Uuid, cap: usize) -> Self {
|
||||
let mut ret = Self::new(cap);
|
||||
ret.id = id;
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RingStream {
|
||||
tunnel: Arc<RingTunnel>,
|
||||
}
|
||||
|
||||
impl RingStream {
|
||||
pub fn new(tunnel: Arc<RingTunnel>) -> Self {
|
||||
Self { tunnel }
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for RingStream {
|
||||
type Item = StreamItem;
|
||||
|
||||
fn poll_next(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
let s = self.get_mut();
|
||||
let ret = s.tunnel.ring.pop();
|
||||
match ret {
|
||||
Some(v) => {
|
||||
s.tunnel.notify_empty_slot();
|
||||
return Poll::Ready(Some(Ok(v)));
|
||||
}
|
||||
None => {
|
||||
if s.tunnel.closed() {
|
||||
tracing::warn!("ring recv tunnel {:?} closed", s.tunnel.id());
|
||||
return Poll::Ready(None);
|
||||
} else {
|
||||
tracing::trace!("waiting recv buffer, id: {}", s.tunnel.id());
|
||||
}
|
||||
s.tunnel.wait_for_new_item(cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RingSink {
|
||||
tunnel: Arc<RingTunnel>,
|
||||
}
|
||||
|
||||
impl Drop for RingSink {
|
||||
fn drop(&mut self) {
|
||||
self.tunnel.close();
|
||||
}
|
||||
}
|
||||
|
||||
impl RingSink {
|
||||
pub fn new(tunnel: Arc<RingTunnel>) -> Self {
|
||||
Self { tunnel }
|
||||
}
|
||||
|
||||
pub fn push_no_check(&self, item: SinkItem) -> Result<(), TunnelError> {
|
||||
if self.tunnel.closed() {
|
||||
return Err(TunnelError::Shutdown);
|
||||
}
|
||||
|
||||
log::trace!("id: {}, send buffer, buf: {:?}", self.tunnel.id(), &item);
|
||||
self.tunnel.ring.push(item).unwrap();
|
||||
self.tunnel.notify_new_item();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn has_empty_slot(&self) -> bool {
|
||||
self.tunnel.len() < self.tunnel.capacity()
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<SinkItem> for RingSink {
|
||||
type Error = SinkError;
|
||||
|
||||
fn poll_ready(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
let self_mut = self.get_mut();
|
||||
if !self_mut.has_empty_slot() {
|
||||
if self_mut.tunnel.closed() {
|
||||
return Poll::Ready(Err(TunnelError::Shutdown));
|
||||
}
|
||||
self_mut.tunnel.wait_for_empty_slot(cx)
|
||||
} else {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(self: std::pin::Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
self.push_no_check(item)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
if self.tunnel.closed() {
|
||||
return Poll::Ready(Err(TunnelError::Shutdown));
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
self.tunnel.close();
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
struct Connection {
|
||||
client: Arc<RingTunnel>,
|
||||
server: Arc<RingTunnel>,
|
||||
}
|
||||
|
||||
static CONNECTION_MAP: Lazy<Arc<Mutex<HashMap<uuid::Uuid, UnboundedSender<Arc<Connection>>>>>> =
|
||||
Lazy::new(|| Arc::new(Mutex::new(HashMap::new())));
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RingTunnelListener {
|
||||
listerner_addr: url::Url,
|
||||
conn_sender: UnboundedSender<Arc<Connection>>,
|
||||
conn_receiver: UnboundedReceiver<Arc<Connection>>,
|
||||
}
|
||||
|
||||
impl RingTunnelListener {
|
||||
pub fn new(key: url::Url) -> Self {
|
||||
let (conn_sender, conn_receiver) = tokio::sync::mpsc::unbounded_channel();
|
||||
RingTunnelListener {
|
||||
listerner_addr: key,
|
||||
conn_sender,
|
||||
conn_receiver,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tunnel_for_client(conn: Arc<Connection>) -> impl Tunnel {
|
||||
TunnelWrapper::new(
|
||||
RingStream::new(conn.client.clone()),
|
||||
RingSink::new(conn.server.clone()),
|
||||
Some(TunnelInfo {
|
||||
tunnel_type: "ring".to_owned(),
|
||||
local_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
|
||||
remote_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn get_tunnel_for_server(conn: Arc<Connection>) -> impl Tunnel {
|
||||
TunnelWrapper::new(
|
||||
RingStream::new(conn.server.clone()),
|
||||
RingSink::new(conn.client.clone()),
|
||||
Some(TunnelInfo {
|
||||
tunnel_type: "ring".to_owned(),
|
||||
local_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
|
||||
remote_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
impl RingTunnelListener {
|
||||
fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
|
||||
check_scheme_and_get_socket_addr::<Uuid>(&self.listerner_addr, "ring")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for RingTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
log::info!("listen new conn of key: {}", self.listerner_addr);
|
||||
CONNECTION_MAP
|
||||
.lock()
|
||||
.await
|
||||
.insert(self.get_addr()?, self.conn_sender.clone());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
log::info!("waiting accept new conn of key: {}", self.listerner_addr);
|
||||
let my_addr = self.get_addr()?;
|
||||
if let Some(conn) = self.conn_receiver.recv().await {
|
||||
if conn.server.id == my_addr {
|
||||
log::info!("accept new conn of key: {}", self.listerner_addr);
|
||||
return Ok(Box::new(get_tunnel_for_server(conn)));
|
||||
} else {
|
||||
tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
|
||||
return Err(TunnelError::InternalError(
|
||||
"accept got wrong ring server id".to_owned(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
return Err(TunnelError::InternalError(
|
||||
"conn receiver stopped".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.listerner_addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RingTunnelConnector {
|
||||
remote_addr: url::Url,
|
||||
}
|
||||
|
||||
impl RingTunnelConnector {
|
||||
pub fn new(remote_addr: url::Url) -> Self {
|
||||
RingTunnelConnector { remote_addr }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelConnector for RingTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let remote_addr = check_scheme_and_get_socket_addr::<Uuid>(&self.remote_addr, "ring")?;
|
||||
let entry = CONNECTION_MAP
|
||||
.lock()
|
||||
.await
|
||||
.get(&remote_addr)
|
||||
.unwrap()
|
||||
.clone();
|
||||
log::info!("connecting");
|
||||
let conn = Arc::new(Connection {
|
||||
client: Arc::new(RingTunnel::new(RING_TUNNEL_CAP)),
|
||||
server: Arc::new(RingTunnel::new_with_id(
|
||||
remote_addr.clone(),
|
||||
RING_TUNNEL_CAP,
|
||||
)),
|
||||
});
|
||||
entry
|
||||
.send(conn.clone())
|
||||
.map_err(|_| TunnelError::InternalError("send conn to listner failed".to_owned()))?;
|
||||
Ok(Box::new(get_tunnel_for_client(conn)))
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.remote_addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
|
||||
let conn = Arc::new(Connection {
|
||||
client: Arc::new(RingTunnel::new(RING_TUNNEL_CAP)),
|
||||
server: Arc::new(RingTunnel::new(RING_TUNNEL_CAP)),
|
||||
});
|
||||
(
|
||||
Box::new(get_tunnel_for_server(conn.clone())),
|
||||
Box::new(get_tunnel_for_client(conn)),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::StreamExt;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::tunnel::common::tests::{_tunnel_bench, _tunnel_pingpong};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn ring_pingpong() {
|
||||
let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
|
||||
let listener = RingTunnelListener::new(id.clone());
|
||||
let connector = RingTunnelConnector::new(id.clone());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ring_bench() {
|
||||
let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
|
||||
let listener = RingTunnelListener::new(id.clone());
|
||||
let connector = RingTunnelConnector::new(id);
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ring_close() {
|
||||
let (stunnel, ctunnel) = create_ring_tunnel_pair();
|
||||
drop(stunnel);
|
||||
|
||||
let mut stream = ctunnel.split().0;
|
||||
let ret = stream.next().await;
|
||||
assert!(ret.as_ref().is_none(), "expect none, got {:?}", ret);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_ring_stream() {
|
||||
let (_stunnel, ctunnel) = create_ring_tunnel_pair();
|
||||
let mut stream = ctunnel.split().0;
|
||||
let task = tokio::spawn(async move {
|
||||
let _ = stream.next().await;
|
||||
});
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
task.abort();
|
||||
let _ = tokio::join!(task);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ring_stream_recv_timeout() {
|
||||
let (_stunnel, ctunnel) = create_ring_tunnel_pair();
|
||||
let mut stream = ctunnel.split().0;
|
||||
let _ = timeout(tokio::time::Duration::from_millis(10), stream.next()).await;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering::Relaxed};
|
||||
|
||||
pub struct WindowLatency {
|
||||
latency_us_window: Vec<AtomicU32>,
|
||||
latency_us_window_index: AtomicU32,
|
||||
latency_us_window_size: u32,
|
||||
|
||||
sum: AtomicU32,
|
||||
count: AtomicU32,
|
||||
}
|
||||
|
||||
impl WindowLatency {
|
||||
pub fn new(window_size: u32) -> Self {
|
||||
Self {
|
||||
latency_us_window: (0..window_size).map(|_| AtomicU32::new(0)).collect(),
|
||||
latency_us_window_index: AtomicU32::new(0),
|
||||
latency_us_window_size: window_size,
|
||||
|
||||
sum: AtomicU32::new(0),
|
||||
count: AtomicU32::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn record_latency(&self, latency_us: u32) {
|
||||
let index = self.latency_us_window_index.fetch_add(1, Relaxed);
|
||||
if self.count.load(Relaxed) < self.latency_us_window_size {
|
||||
self.count.fetch_add(1, Relaxed);
|
||||
}
|
||||
|
||||
let index = index % self.latency_us_window_size;
|
||||
let old_lat = self.latency_us_window[index as usize].swap(latency_us, Relaxed);
|
||||
|
||||
if old_lat < latency_us {
|
||||
self.sum.fetch_add(latency_us - old_lat, Relaxed);
|
||||
} else {
|
||||
self.sum.fetch_sub(old_lat - latency_us, Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_latency_us<T: From<u32> + std::ops::Div<Output = T>>(&self) -> T {
|
||||
let count = self.count.load(Relaxed);
|
||||
let sum = self.sum.load(Relaxed);
|
||||
if count == 0 {
|
||||
0.into()
|
||||
} else {
|
||||
(T::from(sum)) / T::from(count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Throughput {
|
||||
tx_bytes: AtomicU64,
|
||||
rx_bytes: AtomicU64,
|
||||
|
||||
tx_packets: AtomicU64,
|
||||
rx_packets: AtomicU64,
|
||||
}
|
||||
|
||||
impl Throughput {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tx_bytes: AtomicU64::new(0),
|
||||
rx_bytes: AtomicU64::new(0),
|
||||
|
||||
tx_packets: AtomicU64::new(0),
|
||||
rx_packets: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tx_bytes(&self) -> u64 {
|
||||
self.tx_bytes.load(Relaxed)
|
||||
}
|
||||
|
||||
pub fn rx_bytes(&self) -> u64 {
|
||||
self.rx_bytes.load(Relaxed)
|
||||
}
|
||||
|
||||
pub fn tx_packets(&self) -> u64 {
|
||||
self.tx_packets.load(Relaxed)
|
||||
}
|
||||
|
||||
pub fn rx_packets(&self) -> u64 {
|
||||
self.rx_packets.load(Relaxed)
|
||||
}
|
||||
|
||||
pub fn record_tx_bytes(&self, bytes: u64) {
|
||||
self.tx_bytes.fetch_add(bytes, Relaxed);
|
||||
self.tx_packets.fetch_add(1, Relaxed);
|
||||
}
|
||||
|
||||
pub fn record_rx_bytes(&self, bytes: u64) {
|
||||
self.rx_bytes.fetch_add(bytes, Relaxed);
|
||||
self.rx_packets.fetch_add(1, Relaxed);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
|
||||
use crate::{rpc::TunnelInfo, tunnel::common::setup_sokcet2};
|
||||
|
||||
use super::{
|
||||
check_scheme_and_get_socket_addr,
|
||||
common::{wait_for_connect_futures, FramedReader, FramedWriter, TunnelWrapper},
|
||||
Tunnel, TunnelError, TunnelListener,
|
||||
};
|
||||
|
||||
const TCP_MTU_BYTES: usize = 64 * 1024;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TcpTunnelListener {
|
||||
addr: url::Url,
|
||||
listener: Option<TcpListener>,
|
||||
}
|
||||
|
||||
impl TcpTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
TcpTunnelListener {
|
||||
addr,
|
||||
listener: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for TcpTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
|
||||
let socket = if addr.is_ipv4() {
|
||||
TcpSocket::new_v4()?
|
||||
} else {
|
||||
TcpSocket::new_v6()?
|
||||
};
|
||||
|
||||
socket.set_reuseaddr(true)?;
|
||||
// #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||
// socket.set_reuseport(true)?;
|
||||
socket.bind(addr)?;
|
||||
|
||||
self.listener = Some(socket.listen(1024)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let listener = self.listener.as_ref().unwrap();
|
||||
let (stream, _) = listener.accept().await?;
|
||||
stream.set_nodelay(true).unwrap();
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "tcp".to_owned(),
|
||||
local_addr: self.local_url().into(),
|
||||
remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp")
|
||||
.into(),
|
||||
};
|
||||
|
||||
let (r, w) = stream.into_split();
|
||||
Ok(Box::new(TunnelWrapper::new(
|
||||
FramedReader::new(r, TCP_MTU_BYTES),
|
||||
FramedWriter::new(w),
|
||||
Some(info),
|
||||
)))
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tunnel_with_tcp_stream(
|
||||
stream: TcpStream,
|
||||
remote_url: url::Url,
|
||||
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
stream.set_nodelay(true).unwrap();
|
||||
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "tcp".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp")
|
||||
.into(),
|
||||
remote_addr: remote_url.into(),
|
||||
};
|
||||
|
||||
let (r, w) = stream.into_split();
|
||||
Ok(Box::new(TunnelWrapper::new(
|
||||
FramedReader::new(r, TCP_MTU_BYTES),
|
||||
FramedWriter::new(w),
|
||||
Some(info),
|
||||
)))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TcpTunnelConnector {
|
||||
addr: url::Url,
|
||||
|
||||
bind_addrs: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl TcpTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
TcpTunnelConnector {
|
||||
addr,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
tracing::info!(addr = ?self.addr, "connect tcp start");
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
let stream = TcpStream::connect(addr).await?;
|
||||
tracing::info!(addr = ?self.addr, "connect tcp succ");
|
||||
return get_tunnel_with_tcp_stream(stream, self.addr.clone().into());
|
||||
}
|
||||
|
||||
async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let futures = FuturesUnordered::new();
|
||||
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
tracing::info!(bind_addr = ?bind_addr, ?dst_addr, "bind addr");
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(dst_addr),
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, bind_addr)?;
|
||||
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
futures.push(socket.connect(dst_addr.clone()));
|
||||
}
|
||||
|
||||
let ret = wait_for_connect_futures(futures).await;
|
||||
return get_tunnel_with_tcp_stream(ret?, self.addr.clone().into());
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for TcpTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
if self.bind_addrs.is_empty() {
|
||||
self.connect_with_default_bind().await
|
||||
} else {
|
||||
self.connect_with_custom_bind().await
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tunnel::{
|
||||
common::tests::{_tunnel_bench, _tunnel_pingpong},
|
||||
TunnelConnector,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_pingpong() {
|
||||
let listener = TcpTunnelListener::new("tcp://0.0.0.0:31011".parse().unwrap());
|
||||
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:31011".parse().unwrap());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_bench() {
|
||||
let listener = TcpTunnelListener::new("tcp://0.0.0.0:31012".parse().unwrap());
|
||||
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:31012".parse().unwrap());
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_bench_with_bind() {
|
||||
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap());
|
||||
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic]
|
||||
async fn tcp_bench_with_bind_fail() {
|
||||
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,838 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use dashmap::DashMap;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use tokio::{
|
||||
net::UdpSocket,
|
||||
sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender},
|
||||
task::{JoinHandle, JoinSet},
|
||||
};
|
||||
|
||||
use tracing::{instrument, Instrument};
|
||||
|
||||
use crate::{
|
||||
common::join_joinset_background,
|
||||
rpc::TunnelInfo,
|
||||
tunnel::{
|
||||
common::{reserve_buf, TunnelWrapper},
|
||||
packet_def::{UdpPacketType, ZCPacket, ZCPacketType},
|
||||
ring::RingTunnel,
|
||||
},
|
||||
};
|
||||
|
||||
use super::{
|
||||
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
|
||||
packet_def::{UDPTunnelHeader, UDP_TUNNEL_HEADER_SIZE},
|
||||
ring::{RingSink, RingStream},
|
||||
Tunnel, TunnelConnCounter, TunnelError, TunnelListener, TunnelUrl,
|
||||
};
|
||||
|
||||
pub const UDP_DATA_MTU: usize = 65000;
|
||||
|
||||
type UdpCloseEventSender = UnboundedSender<Option<TunnelError>>;
|
||||
type UdpCloseEventReceiver = UnboundedReceiver<Option<TunnelError>>;
|
||||
|
||||
fn new_udp_packet<F>(f: F, udp_body: Option<&mut [u8]>) -> ZCPacket
|
||||
where
|
||||
F: FnOnce(&mut UDPTunnelHeader),
|
||||
{
|
||||
let mut buf = BytesMut::new();
|
||||
buf.resize(
|
||||
UDP_TUNNEL_HEADER_SIZE + udp_body.as_ref().map(|v| v.len()).unwrap_or(0),
|
||||
0,
|
||||
);
|
||||
buf[UDP_TUNNEL_HEADER_SIZE..].copy_from_slice(udp_body.unwrap());
|
||||
|
||||
let mut ret = ZCPacket::new_from_buf(buf, ZCPacketType::UDP);
|
||||
let header = ret.mut_udp_tunnel_header().unwrap();
|
||||
f(header);
|
||||
ret
|
||||
}
|
||||
|
||||
fn new_syn_packet(conn_id: u32, magic: u64) -> ZCPacket {
|
||||
new_udp_packet(
|
||||
|header| {
|
||||
header.msg_type = UdpPacketType::Syn as u8;
|
||||
header.conn_id.set(conn_id);
|
||||
header.len.set(8);
|
||||
},
|
||||
Some(&mut magic.to_le_bytes()),
|
||||
)
|
||||
}
|
||||
|
||||
fn new_sack_packet(conn_id: u32, magic: u64) -> ZCPacket {
|
||||
new_udp_packet(
|
||||
|header| {
|
||||
header.msg_type = UdpPacketType::Sack as u8;
|
||||
header.conn_id.set(conn_id);
|
||||
header.len.set(8);
|
||||
},
|
||||
Some(&mut magic.to_le_bytes()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_hole_punch_packet() -> ZCPacket {
|
||||
// generate a 128 bytes vec with random data
|
||||
let mut rng = rand::rngs::StdRng::from_entropy();
|
||||
let mut buf = vec![0u8; 128];
|
||||
rng.fill(&mut buf[..]);
|
||||
new_udp_packet(
|
||||
|header| {
|
||||
header.msg_type = UdpPacketType::HolePunch as u8;
|
||||
header.conn_id.set(0);
|
||||
header.len.set(0);
|
||||
},
|
||||
Some(&mut buf),
|
||||
)
|
||||
}
|
||||
|
||||
fn get_zcpacket_from_buf(buf: BytesMut) -> Result<ZCPacket, TunnelError> {
|
||||
let dg_size = buf.len();
|
||||
if dg_size < UDP_TUNNEL_HEADER_SIZE {
|
||||
return Err(TunnelError::InvalidPacket(format!(
|
||||
"udp packet size too small: {:?}, packet: {:?}",
|
||||
dg_size, buf
|
||||
)));
|
||||
}
|
||||
|
||||
let zc_packet = ZCPacket::new_from_buf(buf, ZCPacketType::UDP);
|
||||
let header = zc_packet.udp_tunnel_header().unwrap();
|
||||
let payload_len = header.len.get() as usize;
|
||||
if payload_len != dg_size - UDP_TUNNEL_HEADER_SIZE {
|
||||
return Err(TunnelError::InvalidPacket(format!(
|
||||
"udp packet payload len not match: header len: {:?}, real len: {:?}",
|
||||
payload_len, dg_size
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(zc_packet)
|
||||
}
|
||||
|
||||
#[instrument]
|
||||
async fn forward_from_ring_to_udp(
|
||||
mut ring_recv: RingStream,
|
||||
socket: &Arc<UdpSocket>,
|
||||
addr: &SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Option<TunnelError> {
|
||||
tracing::debug!("udp forward from ring to udp");
|
||||
loop {
|
||||
let Some(buf) = ring_recv.next().await else {
|
||||
return None;
|
||||
};
|
||||
let mut packet = match buf {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
return Some(e);
|
||||
}
|
||||
};
|
||||
|
||||
let udp_payload_len = packet.udp_payload().len();
|
||||
let header = packet.mut_udp_tunnel_header().unwrap();
|
||||
header.conn_id.set(conn_id);
|
||||
header.len.set(udp_payload_len as u16);
|
||||
header.msg_type = UdpPacketType::Data as u8;
|
||||
|
||||
let buf = packet.into_bytes(ZCPacketType::UDP);
|
||||
tracing::trace!(?udp_payload_len, ?buf, "udp forward from ring to udp");
|
||||
let ret = socket.send_to(&buf, &addr).await;
|
||||
if ret.is_err() {
|
||||
return Some(TunnelError::IOError(ret.unwrap_err()));
|
||||
} else if ret.unwrap() == 0 {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct UdpConnection {
|
||||
socket: Arc<UdpSocket>,
|
||||
conn_id: u32,
|
||||
dst_addr: SocketAddr,
|
||||
|
||||
ring_sender: RingSink,
|
||||
forward_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl UdpConnection {
|
||||
pub fn new(
|
||||
socket: Arc<UdpSocket>,
|
||||
conn_id: u32,
|
||||
dst_addr: SocketAddr,
|
||||
ring_sender: RingSink,
|
||||
ring_recv: RingStream,
|
||||
close_event_sender: UdpCloseEventSender,
|
||||
) -> Self {
|
||||
let s = socket.clone();
|
||||
let forward_task = tokio::spawn(async move {
|
||||
let close_event_sender = close_event_sender;
|
||||
let err = forward_from_ring_to_udp(ring_recv, &s, &dst_addr, conn_id).await;
|
||||
if let Err(e) = close_event_sender.send(err) {
|
||||
tracing::error!(?e, "udp send close event error");
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
socket,
|
||||
conn_id,
|
||||
dst_addr,
|
||||
ring_sender,
|
||||
forward_task,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for UdpConnection {
|
||||
fn drop(&mut self) {
|
||||
self.forward_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct UdpTunnelListenerData {
|
||||
local_url: url::Url,
|
||||
socket: Option<Arc<UdpSocket>>,
|
||||
sock_map: Arc<DashMap<SocketAddr, UdpConnection>>,
|
||||
conn_send: Sender<Box<dyn Tunnel>>,
|
||||
close_event_sender: UdpCloseEventSender,
|
||||
}
|
||||
|
||||
impl UdpTunnelListenerData {
|
||||
pub fn new(
|
||||
local_url: url::Url,
|
||||
conn_send: Sender<Box<dyn Tunnel>>,
|
||||
close_event_sender: UdpCloseEventSender,
|
||||
) -> Self {
|
||||
Self {
|
||||
local_url,
|
||||
socket: None,
|
||||
sock_map: Arc::new(DashMap::new()),
|
||||
conn_send,
|
||||
close_event_sender,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_new_connect(self: Self, remote_addr: SocketAddr, zc_packet: ZCPacket) {
|
||||
let udp_payload = zc_packet.udp_payload();
|
||||
if udp_payload.len() != 8 {
|
||||
tracing::warn!(
|
||||
"udp syn packet payload len not match: {:?}, packet: {:?}",
|
||||
udp_payload.len(),
|
||||
zc_packet,
|
||||
);
|
||||
return;
|
||||
}
|
||||
let magic = u64::from_le_bytes(udp_payload[..8].try_into().unwrap());
|
||||
let conn_id = zc_packet.udp_tunnel_header().unwrap().conn_id.get();
|
||||
|
||||
tracing::info!(?conn_id, ?remote_addr, "udp connection accept handling",);
|
||||
let socket = self.socket.as_ref().unwrap().clone();
|
||||
|
||||
let sack_buf = new_sack_packet(conn_id, magic).into_bytes(ZCPacketType::UDP);
|
||||
if let Err(e) = socket.send_to(&sack_buf, remote_addr).await {
|
||||
tracing::error!(?e, "udp send sack packet error");
|
||||
return;
|
||||
}
|
||||
|
||||
let ring_for_send_udp = Arc::new(RingTunnel::new(128));
|
||||
let ring_for_recv_udp = Arc::new(RingTunnel::new(128));
|
||||
tracing::debug!(
|
||||
?ring_for_send_udp,
|
||||
?ring_for_recv_udp,
|
||||
"udp build tunnel for listener"
|
||||
);
|
||||
|
||||
let internal_conn = UdpConnection::new(
|
||||
socket.clone(),
|
||||
conn_id,
|
||||
remote_addr,
|
||||
RingSink::new(ring_for_recv_udp.clone()),
|
||||
RingStream::new(ring_for_send_udp.clone()),
|
||||
self.close_event_sender.clone(),
|
||||
);
|
||||
self.sock_map.insert(remote_addr, internal_conn);
|
||||
|
||||
let conn = Box::new(TunnelWrapper::new(
|
||||
Box::new(RingStream::new(ring_for_recv_udp)),
|
||||
Box::new(RingSink::new(ring_for_send_udp)),
|
||||
Some(TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: self.local_url.clone().into(),
|
||||
remote_addr: url::Url::parse(&format!("udp://{}", remote_addr))
|
||||
.unwrap()
|
||||
.into(),
|
||||
}),
|
||||
));
|
||||
|
||||
if let Err(e) = self.conn_send.send(conn).await {
|
||||
tracing::warn!(?e, "udp send conn to accept channel error");
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_forward_packet(
|
||||
self: &Self,
|
||||
remote_addr: &SocketAddr,
|
||||
conn_id: u32,
|
||||
p: ZCPacket,
|
||||
) -> Result<(), TunnelError> {
|
||||
let Some(conn) = self.sock_map.get(remote_addr) else {
|
||||
return Err(TunnelError::InternalError(
|
||||
"udp connection not found".to_owned(),
|
||||
));
|
||||
};
|
||||
|
||||
if conn.conn_id != conn_id {
|
||||
return Err(TunnelError::ConnIdNotMatch(conn.conn_id, conn_id));
|
||||
}
|
||||
|
||||
if !conn.ring_sender.has_empty_slot() {
|
||||
return Err(TunnelError::BufferFull);
|
||||
}
|
||||
|
||||
conn.ring_sender.push_no_check(p)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn process_forward_packet(&self, zc_packet: ZCPacket, addr: &SocketAddr) {
|
||||
let header = zc_packet.udp_tunnel_header().unwrap();
|
||||
if header.msg_type == UdpPacketType::Syn as u8 {
|
||||
tokio::spawn(Self::handle_new_connect(self.clone(), *addr, zc_packet));
|
||||
} else {
|
||||
if let Err(e) = self
|
||||
.try_forward_packet(addr, header.conn_id.get(), zc_packet)
|
||||
.await
|
||||
{
|
||||
tracing::trace!(?e, "udp forward packet error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_forward_task(self: Self) {
|
||||
let socket = self.socket.as_ref().unwrap().clone();
|
||||
let mut buf = BytesMut::new();
|
||||
loop {
|
||||
reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 128);
|
||||
let (dg_size, addr) = socket.recv_buf_from(&mut buf).await.unwrap();
|
||||
tracing::trace!(
|
||||
"udp recv packet: {:?}, buf: {:?}, size: {}",
|
||||
addr,
|
||||
buf,
|
||||
dg_size
|
||||
);
|
||||
|
||||
let zc_packet = match get_zcpacket_from_buf(buf.split()) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::warn!(?e, "udp get zc packet from buf error");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
self.process_forward_packet(zc_packet, &addr).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpTunnelListener {
|
||||
addr: url::Url,
|
||||
socket: Option<Arc<UdpSocket>>,
|
||||
|
||||
conn_recv: Receiver<Box<dyn Tunnel>>,
|
||||
data: UdpTunnelListenerData,
|
||||
forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||
close_event_recv: UdpCloseEventReceiver,
|
||||
}
|
||||
|
||||
impl UdpTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
let (close_event_send, close_event_recv) = tokio::sync::mpsc::unbounded_channel();
|
||||
let (conn_send, conn_recv) = tokio::sync::mpsc::channel(100);
|
||||
Self {
|
||||
addr: addr.clone(),
|
||||
socket: None,
|
||||
conn_recv,
|
||||
data: UdpTunnelListenerData::new(addr, conn_send, close_event_send),
|
||||
forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
|
||||
close_event_recv,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_socket(&self) -> Option<Arc<UdpSocket>> {
|
||||
self.socket.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for UdpTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
|
||||
let tunnel_url: TunnelUrl = self.addr.clone().into();
|
||||
if let Some(bind_dev) = tunnel_url.bind_dev() {
|
||||
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
|
||||
} else {
|
||||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
}
|
||||
|
||||
self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
|
||||
self.data.socket = self.socket.clone();
|
||||
|
||||
self.forward_tasks
|
||||
.lock()
|
||||
.unwrap()
|
||||
.spawn(self.data.clone().do_forward_task());
|
||||
|
||||
join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
log::info!("start udp accept: {:?}", self.addr);
|
||||
while let Some(conn) = self.conn_recv.recv().await {
|
||||
return Ok(conn);
|
||||
}
|
||||
return Err(super::TunnelError::InternalError(
|
||||
"udp accept error".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
|
||||
struct UdpTunnelConnCounter {
|
||||
sock_map: Arc<DashMap<SocketAddr, UdpConnection>>,
|
||||
}
|
||||
|
||||
impl Debug for UdpTunnelConnCounter {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("UdpTunnelConnCounter")
|
||||
.field("sock_map_len", &self.sock_map.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl TunnelConnCounter for UdpTunnelConnCounter {
|
||||
fn get(&self) -> u32 {
|
||||
self.sock_map.len() as u32
|
||||
}
|
||||
}
|
||||
|
||||
Arc::new(Box::new(UdpTunnelConnCounter {
|
||||
sock_map: self.data.sock_map.clone(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpTunnelConnector {
|
||||
addr: url::Url,
|
||||
bind_addrs: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl UdpTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
Self {
|
||||
addr,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_sack(
|
||||
socket: &UdpSocket,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
magic: u64,
|
||||
) -> Result<SocketAddr, TunnelError> {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.reserve(UDP_DATA_MTU);
|
||||
|
||||
let (usize, recv_addr) = tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(3),
|
||||
socket.recv_buf_from(&mut buf),
|
||||
)
|
||||
.await??;
|
||||
let zc_packet = get_zcpacket_from_buf(buf.split())?;
|
||||
if recv_addr != addr {
|
||||
tracing::warn!(?recv_addr, ?addr, ?usize, "udp wait sack addr not match");
|
||||
}
|
||||
|
||||
let header = zc_packet.udp_tunnel_header().unwrap();
|
||||
|
||||
if header.conn_id.get() != conn_id {
|
||||
return Err(super::TunnelError::ConnIdNotMatch(
|
||||
header.conn_id.get(),
|
||||
conn_id,
|
||||
));
|
||||
}
|
||||
|
||||
if header.msg_type != UdpPacketType::Sack as u8 {
|
||||
return Err(TunnelError::InvalidPacket("not sack packet".to_owned()));
|
||||
}
|
||||
|
||||
let payload = zc_packet.udp_payload();
|
||||
if payload.len() != 8 {
|
||||
return Err(TunnelError::InvalidPacket(
|
||||
"udp sack packet payload len not match".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
let sack_magic = u64::from_le_bytes(payload[..8].try_into().unwrap());
|
||||
if sack_magic != magic {
|
||||
return Err(TunnelError::InvalidPacket(
|
||||
"udp sack magic not match".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(recv_addr)
|
||||
}
|
||||
|
||||
async fn wait_sack_loop(
|
||||
socket: &UdpSocket,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
magic: u64,
|
||||
) -> Result<SocketAddr, super::TunnelError> {
|
||||
loop {
|
||||
let ret = Self::wait_sack(socket, addr, conn_id, magic).await;
|
||||
if ret.is_err() {
|
||||
tracing::debug!(?ret, "udp wait sack error");
|
||||
continue;
|
||||
} else {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_tunnel(
|
||||
&self,
|
||||
socket: UdpSocket,
|
||||
dst_addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
let socket = Arc::new(socket);
|
||||
let ring_for_send_udp = Arc::new(RingTunnel::new(128));
|
||||
let ring_for_recv_udp = Arc::new(RingTunnel::new(128));
|
||||
tracing::debug!(
|
||||
?ring_for_send_udp,
|
||||
?ring_for_recv_udp,
|
||||
"udp build tunnel for connector"
|
||||
);
|
||||
|
||||
let (close_event_send, mut close_event_recv) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
// forward from ring to udp
|
||||
let socket_sender = socket.clone();
|
||||
let ring_recv = RingStream::new(ring_for_send_udp.clone());
|
||||
tokio::spawn(async move {
|
||||
let err = forward_from_ring_to_udp(ring_recv, &socket_sender, &dst_addr, conn_id).await;
|
||||
tracing::debug!(?err, "udp forward from ring to udp done");
|
||||
close_event_send.send(err).unwrap();
|
||||
});
|
||||
|
||||
let socket_recv = socket.clone();
|
||||
let ring_sender = RingSink::new(ring_for_recv_udp.clone());
|
||||
tokio::spawn(async move {
|
||||
let mut buf = BytesMut::new();
|
||||
loop {
|
||||
reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 128);
|
||||
let ret;
|
||||
tokio::select! {
|
||||
_ = close_event_recv.recv() => {
|
||||
tracing::debug!("connector udp close event");
|
||||
break;
|
||||
}
|
||||
recv_res = socket_recv.recv_buf_from(&mut buf) => ret = Some(recv_res.unwrap()),
|
||||
}
|
||||
let (dg_size, addr) = ret.unwrap();
|
||||
tracing::trace!(
|
||||
"connector udp recv packet: {:?}, buf: {:?}, size: {}",
|
||||
addr,
|
||||
buf,
|
||||
dg_size
|
||||
);
|
||||
|
||||
let zc_packet = match get_zcpacket_from_buf(buf.split()) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::warn!(?e, "connector udp get zc packet from buf error");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let header = zc_packet.udp_tunnel_header().unwrap();
|
||||
if header.conn_id.get() != conn_id {
|
||||
tracing::trace!(
|
||||
"connector udp conn id not match: {:?}, {:?}",
|
||||
header.conn_id.get(),
|
||||
conn_id
|
||||
);
|
||||
}
|
||||
|
||||
if header.msg_type == UdpPacketType::Data as u8 {
|
||||
if let Err(e) = ring_sender.push_no_check(zc_packet) {
|
||||
tracing::trace!(?e, "udp forward packet error");
|
||||
}
|
||||
}
|
||||
}
|
||||
}.instrument(tracing::info_span!("udp connector forward from udp to ring", ?ring_for_recv_udp)));
|
||||
|
||||
Ok(Box::new(TunnelWrapper::new(
|
||||
Box::new(RingStream::new(ring_for_recv_udp)),
|
||||
Box::new(RingSink::new(ring_for_send_udp)),
|
||||
Some(TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: url::Url::parse(&format!("udp://{}", socket.local_addr()?))
|
||||
.unwrap()
|
||||
.into(),
|
||||
remote_addr: self.addr.clone().into(),
|
||||
}),
|
||||
)))
|
||||
}
|
||||
|
||||
pub async fn try_connect_with_socket(
|
||||
&self,
|
||||
socket: UdpSocket,
|
||||
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
|
||||
log::warn!("udp connect: {:?}", self.addr);
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
crate::arch::windows::disable_connection_reset(&socket)?;
|
||||
|
||||
// send syn
|
||||
let conn_id = rand::random();
|
||||
let magic = rand::random();
|
||||
let udp_packet = new_syn_packet(conn_id, magic).into_bytes(ZCPacketType::UDP);
|
||||
let ret = socket.send_to(&udp_packet, &addr).await?;
|
||||
tracing::warn!(?udp_packet, ?ret, "udp send syn");
|
||||
|
||||
// wait sack
|
||||
let recv_addr = tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(3),
|
||||
Self::wait_sack_loop(&socket, addr, conn_id, magic),
|
||||
)
|
||||
.await??;
|
||||
|
||||
socket.connect(recv_addr).await?;
|
||||
self.build_tunnel(socket, addr, conn_id).await
|
||||
}
|
||||
|
||||
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
return self.try_connect_with_socket(socket).await;
|
||||
}
|
||||
|
||||
async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let futures = FuturesUnordered::new();
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(*bind_addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &bind_addr)?;
|
||||
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||
futures.push(self.try_connect_with_socket(socket));
|
||||
}
|
||||
wait_for_connect_futures(futures).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for UdpTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
if self.bind_addrs.is_empty() {
|
||||
self.connect_with_default_bind().await
|
||||
} else {
|
||||
self.connect_with_custom_bind().await
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::SinkExt;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::*;
|
||||
use crate::{
|
||||
common::global_ctx::tests::get_mock_global_ctx,
|
||||
tunnel::{
|
||||
check_scheme_and_get_socket_addr,
|
||||
common::{
|
||||
get_interface_name_by_ip,
|
||||
tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong},
|
||||
},
|
||||
TunnelConnector,
|
||||
},
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_pingpong() {
|
||||
let listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap());
|
||||
let connector = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_bench() {
|
||||
let listener = UdpTunnelListener::new("udp://0.0.0.0:5555".parse().unwrap());
|
||||
let connector = UdpTunnelConnector::new("udp://127.0.0.1:5555".parse().unwrap());
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_bench_with_bind() {
|
||||
let listener = UdpTunnelListener::new("udp://127.0.0.1:5554".parse().unwrap());
|
||||
let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5554".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic]
|
||||
async fn udp_bench_with_bind_fail() {
|
||||
let listener = UdpTunnelListener::new("udp://127.0.0.1:5553".parse().unwrap());
|
||||
let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5553".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
async fn send_random_data_to_socket(remote_url: url::Url) {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
|
||||
socket
|
||||
.connect(format!(
|
||||
"{}:{}",
|
||||
remote_url.host().unwrap(),
|
||||
remote_url.port().unwrap()
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// get a random 100-len buf
|
||||
loop {
|
||||
let mut buf = vec![0u8; 100];
|
||||
rand::thread_rng().fill(&mut buf[..]);
|
||||
socket.send(&buf).await.unwrap();
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_multiple_conns() {
|
||||
let mut listener = UdpTunnelListener::new("udp://0.0.0.0:5557".parse().unwrap());
|
||||
listener.listen().await.unwrap();
|
||||
|
||||
let _lis = tokio::spawn(async move {
|
||||
loop {
|
||||
let ret = listener.accept().await.unwrap();
|
||||
assert_eq!(
|
||||
ret.info().unwrap().local_addr,
|
||||
listener.local_url().to_string()
|
||||
);
|
||||
tokio::spawn(async move { _tunnel_echo_server(ret, false).await });
|
||||
}
|
||||
});
|
||||
|
||||
let mut connector1 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap());
|
||||
let mut connector2 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap());
|
||||
|
||||
let t1 = connector1.connect().await.unwrap();
|
||||
let t2 = connector2.connect().await.unwrap();
|
||||
|
||||
tokio::spawn(timeout(
|
||||
Duration::from_secs(2),
|
||||
send_random_data_to_socket(t1.info().unwrap().local_addr.parse().unwrap()),
|
||||
));
|
||||
tokio::spawn(timeout(
|
||||
Duration::from_secs(2),
|
||||
send_random_data_to_socket(t1.info().unwrap().remote_addr.parse().unwrap()),
|
||||
));
|
||||
tokio::spawn(timeout(
|
||||
Duration::from_secs(2),
|
||||
send_random_data_to_socket(t2.info().unwrap().remote_addr.parse().unwrap()),
|
||||
));
|
||||
|
||||
let sender1 = tokio::spawn(async move {
|
||||
let (mut stream, mut sink) = t1.split();
|
||||
|
||||
for i in 0..10 {
|
||||
sink.send(ZCPacket::new_with_payload("hello1".as_bytes()))
|
||||
.await
|
||||
.unwrap();
|
||||
let recv = stream.next().await.unwrap().unwrap();
|
||||
println!("t1 recv: {:?}, {:?}", recv, i);
|
||||
assert_eq!(recv.payload(), "hello1".as_bytes());
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
});
|
||||
|
||||
let sender2 = tokio::spawn(async move {
|
||||
let (mut stream, mut sink) = t2.split();
|
||||
|
||||
for i in 0..10 {
|
||||
sink.send(ZCPacket::new_with_payload("hello2".as_bytes()))
|
||||
.await
|
||||
.unwrap();
|
||||
let recv = stream.next().await.unwrap().unwrap();
|
||||
println!("t2 recv: {:?}, {:?}", recv, i);
|
||||
assert_eq!(recv.payload(), "hello2".as_bytes());
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
});
|
||||
|
||||
let _ = tokio::join!(sender1, sender2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn bind_multi_ip_to_same_dev() {
|
||||
let global_ctx = get_mock_global_ctx();
|
||||
let ips = global_ctx
|
||||
.get_ip_collector()
|
||||
.collect_ip_addrs()
|
||||
.await
|
||||
.interface_ipv4s;
|
||||
if ips.is_empty() {
|
||||
return;
|
||||
}
|
||||
let bind_dev = get_interface_name_by_ip(&ips[0].parse().unwrap());
|
||||
|
||||
for ip in ips {
|
||||
println!("bind to ip: {:?}, {:?}", ip, bind_dev);
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(
|
||||
&format!("udp://{}:11111", ip).parse().unwrap(),
|
||||
"udp",
|
||||
)
|
||||
.unwrap();
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)
|
||||
.unwrap();
|
||||
setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,827 @@
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
fmt::{Debug, Formatter},
|
||||
hash::Hasher,
|
||||
net::SocketAddr,
|
||||
pin::Pin,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_recursion::async_recursion;
|
||||
use async_trait::async_trait;
|
||||
use boringtun::{
|
||||
noise::{errors::WireGuardError, Tunn, TunnResult},
|
||||
x25519::{PublicKey, StaticSecret},
|
||||
};
|
||||
use bytes::BytesMut;
|
||||
use dashmap::DashMap;
|
||||
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
|
||||
use rand::RngCore;
|
||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
|
||||
use crate::{
|
||||
rpc::TunnelInfo,
|
||||
tunnel::{
|
||||
build_url_from_socket_addr,
|
||||
common::TunnelWrapper,
|
||||
packet_def::{ZCPacket, WG_TUNNEL_HEADER_SIZE},
|
||||
},
|
||||
};
|
||||
|
||||
use super::{
|
||||
check_scheme_and_get_socket_addr,
|
||||
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
|
||||
packet_def::{ZCPacketType, PEER_MANAGER_HEADER_SIZE},
|
||||
ring::create_ring_tunnel_pair,
|
||||
Tunnel, TunnelError, TunnelListener, TunnelUrl, ZCPacketSink, ZCPacketStream,
|
||||
};
|
||||
|
||||
const MAX_PACKET: usize = 65500;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum WgType {
|
||||
// used by easytier peer, need remove/add ip header for in/out wg msg
|
||||
InternalUse,
|
||||
// used by wireguard peer, keep original ip header
|
||||
ExternalUse,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WgConfig {
|
||||
my_secret_key: StaticSecret,
|
||||
my_public_key: PublicKey,
|
||||
|
||||
peer_secret_key: StaticSecret,
|
||||
peer_public_key: PublicKey,
|
||||
|
||||
wg_type: WgType,
|
||||
}
|
||||
|
||||
impl WgConfig {
|
||||
pub fn new_from_network_identity(network_name: &str, network_secret: &str) -> Self {
|
||||
let mut my_sec = [0u8; 32];
|
||||
let mut hasher = DefaultHasher::new();
|
||||
hasher.write(network_name.as_bytes());
|
||||
hasher.write(network_secret.as_bytes());
|
||||
my_sec[0..8].copy_from_slice(&hasher.finish().to_be_bytes());
|
||||
hasher.write(&my_sec[0..8]);
|
||||
my_sec[8..16].copy_from_slice(&hasher.finish().to_be_bytes());
|
||||
hasher.write(&my_sec[0..16]);
|
||||
my_sec[16..24].copy_from_slice(&hasher.finish().to_be_bytes());
|
||||
hasher.write(&my_sec[0..24]);
|
||||
my_sec[24..32].copy_from_slice(&hasher.finish().to_be_bytes());
|
||||
|
||||
let my_secret_key = StaticSecret::from(my_sec);
|
||||
let my_public_key = PublicKey::from(&my_secret_key);
|
||||
let peer_secret_key = StaticSecret::from(my_sec);
|
||||
let peer_public_key = my_public_key.clone();
|
||||
|
||||
WgConfig {
|
||||
my_secret_key,
|
||||
my_public_key,
|
||||
peer_secret_key,
|
||||
peer_public_key,
|
||||
|
||||
wg_type: WgType::InternalUse,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_for_portal(server_key_seed: &str, client_key_seed: &str) -> Self {
|
||||
let server_cfg = Self::new_from_network_identity("server", server_key_seed);
|
||||
let client_cfg = Self::new_from_network_identity("client", client_key_seed);
|
||||
Self {
|
||||
my_secret_key: server_cfg.my_secret_key,
|
||||
my_public_key: server_cfg.my_public_key,
|
||||
peer_secret_key: client_cfg.my_secret_key,
|
||||
peer_public_key: client_cfg.my_public_key,
|
||||
|
||||
wg_type: WgType::ExternalUse,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn my_secret_key(&self) -> &[u8] {
|
||||
self.my_secret_key.as_bytes()
|
||||
}
|
||||
|
||||
pub fn peer_secret_key(&self) -> &[u8] {
|
||||
self.peer_secret_key.as_bytes()
|
||||
}
|
||||
|
||||
pub fn my_public_key(&self) -> &[u8] {
|
||||
self.my_public_key.as_bytes()
|
||||
}
|
||||
|
||||
pub fn peer_public_key(&self) -> &[u8] {
|
||||
self.peer_public_key.as_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct WgPeerData {
|
||||
udp: Arc<UdpSocket>, // only for send
|
||||
endpoint: SocketAddr,
|
||||
tunn: Arc<Mutex<Tunn>>,
|
||||
wg_type: WgType,
|
||||
stopped: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl Debug for WgPeerData {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("WgPeerData")
|
||||
.field("endpoint", &self.endpoint)
|
||||
.field("local", &self.udp.local_addr())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl WgPeerData {
|
||||
#[tracing::instrument]
|
||||
async fn handle_one_packet_from_me(
|
||||
&self,
|
||||
mut zc_packet: ZCPacket,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let mut send_buf = vec![0u8; MAX_PACKET];
|
||||
|
||||
let packet = if matches!(self.wg_type, WgType::InternalUse) {
|
||||
Self::fill_ip_header(&mut zc_packet);
|
||||
zc_packet.into_bytes(ZCPacketType::WG)
|
||||
} else {
|
||||
zc_packet.into_bytes(ZCPacketType::WG)
|
||||
};
|
||||
tracing::trace!(?packet, "Sending packet to peer");
|
||||
|
||||
let encapsulate_result = {
|
||||
let mut peer = self.tunn.lock().await;
|
||||
peer.encapsulate(&packet, &mut send_buf)
|
||||
};
|
||||
|
||||
tracing::trace!(
|
||||
?encapsulate_result,
|
||||
"Received {} bytes from me",
|
||||
packet.len()
|
||||
);
|
||||
|
||||
match encapsulate_result {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
self.udp
|
||||
.send_to(packet, self.endpoint)
|
||||
.await
|
||||
.context("Failed to send encrypted IP packet to WireGuard endpoint.")?;
|
||||
tracing::debug!(
|
||||
"Sent {} bytes to WireGuard endpoint (encrypted IP packet)",
|
||||
packet.len()
|
||||
);
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
tracing::error!("Failed to encapsulate IP packet: {:?}", e);
|
||||
}
|
||||
TunnResult::Done => {
|
||||
// Ignored
|
||||
}
|
||||
other => {
|
||||
tracing::error!(
|
||||
"Unexpected WireGuard state during encapsulation: {:?}",
|
||||
other
|
||||
);
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint,
|
||||
/// decapsulates them, and dispatches newly received IP packets.
|
||||
#[tracing::instrument(skip(sink))]
|
||||
pub async fn handle_one_packet_from_peer<S: ZCPacketSink + Unpin>(
|
||||
&self,
|
||||
mut sink: S,
|
||||
recv_buf: &[u8],
|
||||
) {
|
||||
let mut send_buf = vec![0u8; MAX_PACKET];
|
||||
let data = &recv_buf[..];
|
||||
let decapsulate_result = {
|
||||
let mut peer = self.tunn.lock().await;
|
||||
peer.decapsulate(None, data, &mut send_buf)
|
||||
};
|
||||
|
||||
tracing::debug!("Decapsulation result: {:?}", decapsulate_result);
|
||||
|
||||
match decapsulate_result {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
match self.udp.send_to(packet, self.endpoint).await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let mut peer = self.tunn.lock().await;
|
||||
loop {
|
||||
let mut send_buf = vec![0u8; MAX_PACKET];
|
||||
match peer.decapsulate(None, &[], &mut send_buf) {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
match self.udp.send_to(packet, self.endpoint).await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
}
|
||||
_ => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => {
|
||||
tracing::debug!(
|
||||
?packet,
|
||||
"receive IP packet from peer: {} bytes",
|
||||
packet.len()
|
||||
);
|
||||
let mut b = BytesMut::new();
|
||||
if matches!(self.wg_type, WgType::InternalUse) {
|
||||
b.resize(WG_TUNNEL_HEADER_SIZE, 0);
|
||||
b.extend_from_slice(self.remove_ip_header(packet, packet[0] >> 4 == 4));
|
||||
} else {
|
||||
b.extend_from_slice(packet);
|
||||
};
|
||||
let zc_packet = ZCPacket::new_from_buf(b, ZCPacketType::WG);
|
||||
tracing::trace!(?zc_packet, "forward zc_packet to sink");
|
||||
let ret = sink.send(zc_packet).await;
|
||||
if ret.is_err() {
|
||||
tracing::error!("Failed to send packet to tunnel: {:?}", ret);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!(
|
||||
"Unexpected WireGuard state during decapsulation: {:?}",
|
||||
decapsulate_result
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
#[async_recursion]
|
||||
async fn handle_routine_tun_result<'a: 'async_recursion>(&self, result: TunnResult<'a>) -> () {
|
||||
match result {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
tracing::debug!(
|
||||
"Sending routine packet of {} bytes to WireGuard endpoint",
|
||||
packet.len()
|
||||
);
|
||||
match self.udp.send_to(packet, self.endpoint).await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"Failed to send routine packet to WireGuard endpoint: {:?}",
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
TunnResult::Err(WireGuardError::ConnectionExpired) => {
|
||||
tracing::warn!("Wireguard handshake has expired!");
|
||||
|
||||
let mut buf = vec![0u8; MAX_PACKET];
|
||||
let result = self
|
||||
.tunn
|
||||
.lock()
|
||||
.await
|
||||
.format_handshake_initiation(&mut buf[..], false);
|
||||
|
||||
self.handle_routine_tun_result(result).await
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
tracing::error!(
|
||||
"Failed to prepare routine packet for WireGuard endpoint: {:?}",
|
||||
e
|
||||
);
|
||||
}
|
||||
TunnResult::Done => {
|
||||
// Sleep for a bit
|
||||
tokio::time::sleep(Duration::from_millis(250)).await;
|
||||
}
|
||||
other => {
|
||||
tracing::warn!("Unexpected WireGuard routine task state: {:?}", other);
|
||||
tokio::time::sleep(Duration::from_millis(250)).await;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// WireGuard Routine task. Handles Handshake, keep-alive, etc.
|
||||
pub async fn routine_task(self) {
|
||||
loop {
|
||||
let mut send_buf = vec![0u8; MAX_PACKET];
|
||||
let tun_result = { self.tunn.lock().await.update_timers(&mut send_buf) };
|
||||
self.handle_routine_tun_result(tun_result).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn fill_ip_header(zc_packet: &mut ZCPacket) {
|
||||
let len = zc_packet.payload_len() + PEER_MANAGER_HEADER_SIZE;
|
||||
let ip_header = &mut zc_packet.mut_wg_tunnel_header().unwrap().ipv4_header;
|
||||
ip_header[0] = 0x45;
|
||||
ip_header[1] = 0;
|
||||
ip_header[2..4].copy_from_slice(&((len + 20) as u16).to_be_bytes());
|
||||
ip_header[4..6].copy_from_slice(&0u16.to_be_bytes());
|
||||
ip_header[6..8].copy_from_slice(&0u16.to_be_bytes());
|
||||
ip_header[8] = 64;
|
||||
ip_header[9] = 0;
|
||||
ip_header[10..12].copy_from_slice(&0u16.to_be_bytes());
|
||||
ip_header[12..16].copy_from_slice(&0u32.to_be_bytes());
|
||||
ip_header[16..20].copy_from_slice(&0u32.to_be_bytes());
|
||||
}
|
||||
|
||||
fn remove_ip_header<'a>(&self, packet: &'a [u8], is_v4: bool) -> &'a [u8] {
|
||||
if is_v4 {
|
||||
return &packet[20..];
|
||||
} else {
|
||||
return &packet[40..];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct WgPeer {
|
||||
udp: Arc<UdpSocket>, // only for send
|
||||
config: WgConfig,
|
||||
endpoint: SocketAddr,
|
||||
|
||||
sink: std::sync::Mutex<Option<Pin<Box<dyn ZCPacketSink>>>>,
|
||||
|
||||
data: Option<WgPeerData>,
|
||||
tasks: JoinSet<()>,
|
||||
|
||||
access_time: std::time::Instant,
|
||||
}
|
||||
|
||||
impl WgPeer {
|
||||
fn new(udp: Arc<UdpSocket>, config: WgConfig, endpoint: SocketAddr) -> Self {
|
||||
WgPeer {
|
||||
udp,
|
||||
config,
|
||||
endpoint,
|
||||
|
||||
sink: std::sync::Mutex::new(None),
|
||||
|
||||
data: None,
|
||||
tasks: JoinSet::new(),
|
||||
|
||||
access_time: std::time::Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_packet_from_me<S: ZCPacketStream + Unpin>(mut stream: S, data: WgPeerData) {
|
||||
while let Some(Ok(packet)) = stream.next().await {
|
||||
let ret = data.handle_one_packet_from_me(packet).await;
|
||||
if let Err(e) = ret {
|
||||
tracing::error!("Failed to handle packet from me: {}", e);
|
||||
}
|
||||
}
|
||||
data.stopped
|
||||
.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
async fn handle_packet_from_peer(&mut self, packet: &[u8]) {
|
||||
self.access_time = std::time::Instant::now();
|
||||
tracing::trace!("Received {} bytes from peer", packet.len());
|
||||
let data = self.data.as_ref().unwrap();
|
||||
// TODO: improve this
|
||||
let mut sink = self.sink.lock().unwrap().take().unwrap();
|
||||
data.handle_one_packet_from_peer(&mut sink, packet).await;
|
||||
self.sink.lock().unwrap().replace(sink);
|
||||
}
|
||||
|
||||
fn start_and_get_tunnel(&mut self) -> Box<dyn Tunnel> {
|
||||
let (stunnel, ctunnel) = create_ring_tunnel_pair();
|
||||
|
||||
let (stream, sink) = stunnel.split();
|
||||
|
||||
let data = WgPeerData {
|
||||
udp: self.udp.clone(),
|
||||
endpoint: self.endpoint,
|
||||
tunn: Arc::new(Mutex::new(
|
||||
Tunn::new(
|
||||
self.config.my_secret_key.clone(),
|
||||
self.config.peer_public_key.clone(),
|
||||
None,
|
||||
None,
|
||||
rand::thread_rng().next_u32(),
|
||||
None,
|
||||
)
|
||||
.unwrap(),
|
||||
)),
|
||||
wg_type: self.config.wg_type.clone(),
|
||||
stopped: Arc::new(AtomicBool::new(false)),
|
||||
};
|
||||
|
||||
self.data = Some(data.clone());
|
||||
self.sink.lock().unwrap().replace(sink);
|
||||
|
||||
self.tasks
|
||||
.spawn(Self::handle_packet_from_me(stream, data.clone()));
|
||||
self.tasks.spawn(data.routine_task());
|
||||
|
||||
ctunnel
|
||||
}
|
||||
|
||||
fn stopped(&self) -> bool {
|
||||
self.data
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.stopped
|
||||
.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
type ConnSender = tokio::sync::mpsc::UnboundedSender<Box<dyn Tunnel>>;
|
||||
type ConnReceiver = tokio::sync::mpsc::UnboundedReceiver<Box<dyn Tunnel>>;
|
||||
|
||||
pub struct WgTunnelListener {
|
||||
addr: url::Url,
|
||||
config: WgConfig,
|
||||
|
||||
udp: Option<Arc<UdpSocket>>,
|
||||
conn_recv: ConnReceiver,
|
||||
conn_send: Option<ConnSender>,
|
||||
|
||||
wg_peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
|
||||
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl WgTunnelListener {
|
||||
pub fn new(addr: url::Url, config: WgConfig) -> Self {
|
||||
let (conn_send, conn_recv) = tokio::sync::mpsc::unbounded_channel();
|
||||
WgTunnelListener {
|
||||
addr,
|
||||
config,
|
||||
|
||||
udp: None,
|
||||
conn_recv,
|
||||
conn_send: Some(conn_send),
|
||||
|
||||
wg_peer_map: Arc::new(DashMap::new()),
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_udp_socket(&self) -> Arc<UdpSocket> {
|
||||
self.udp.as_ref().unwrap().clone()
|
||||
}
|
||||
|
||||
async fn handle_udp_incoming(
|
||||
socket: Arc<UdpSocket>,
|
||||
config: WgConfig,
|
||||
conn_sender: ConnSender,
|
||||
peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
|
||||
) {
|
||||
let mut tasks = JoinSet::new();
|
||||
|
||||
let peer_map_clone = peer_map.clone();
|
||||
tasks.spawn(async move {
|
||||
loop {
|
||||
peer_map_clone
|
||||
.retain(|_, peer| peer.access_time.elapsed().as_secs() < 61 && !peer.stopped());
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
});
|
||||
|
||||
let mut buf = vec![0u8; MAX_PACKET];
|
||||
loop {
|
||||
let Ok((n, addr)) = socket.recv_from(&mut buf).await else {
|
||||
tracing::error!("Failed to receive from UDP socket");
|
||||
break;
|
||||
};
|
||||
|
||||
let data = &buf[..n];
|
||||
tracing::trace!(?n, ?addr, "Received bytes from peer");
|
||||
|
||||
if !peer_map.contains_key(&addr) {
|
||||
tracing::info!("New peer: {}", addr);
|
||||
let mut wg = WgPeer::new(socket.clone(), config.clone(), addr.clone());
|
||||
let (stream, sink) = wg.start_and_get_tunnel().split();
|
||||
let tunnel = Box::new(TunnelWrapper::new(
|
||||
stream,
|
||||
sink,
|
||||
Some(TunnelInfo {
|
||||
tunnel_type: "wg".to_owned(),
|
||||
local_addr: build_url_from_socket_addr(
|
||||
&socket.local_addr().unwrap().to_string(),
|
||||
"wg",
|
||||
)
|
||||
.into(),
|
||||
remote_addr: build_url_from_socket_addr(&addr.to_string(), "wg").into(),
|
||||
}),
|
||||
));
|
||||
if let Err(e) = conn_sender.send(tunnel) {
|
||||
tracing::error!("Failed to send tunnel to conn_sender: {}", e);
|
||||
}
|
||||
peer_map.insert(addr, wg);
|
||||
}
|
||||
|
||||
let mut peer = peer_map.get_mut(&addr).unwrap();
|
||||
peer.handle_packet_from_peer(data).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for WgTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg")?;
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
|
||||
let tunnel_url: TunnelUrl = self.addr.clone().into();
|
||||
if let Some(bind_dev) = tunnel_url.bind_dev() {
|
||||
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
|
||||
} else {
|
||||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
}
|
||||
|
||||
self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
|
||||
self.tasks.spawn(Self::handle_udp_incoming(
|
||||
self.get_udp_socket(),
|
||||
self.config.clone(),
|
||||
self.conn_send.take().unwrap(),
|
||||
self.wg_peer_map.clone(),
|
||||
));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
while let Some(tunnel) = self.conn_recv.recv().await {
|
||||
tracing::info!(?tunnel, "Accepted tunnel");
|
||||
return Ok(tunnel);
|
||||
}
|
||||
Err(TunnelError::Shutdown)
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WgTunnelConnector {
|
||||
addr: url::Url,
|
||||
config: WgConfig,
|
||||
udp: Option<Arc<UdpSocket>>,
|
||||
|
||||
bind_addrs: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl Debug for WgTunnelConnector {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("WgTunnelConnector")
|
||||
.field("addr", &self.addr)
|
||||
.field("udp", &self.udp)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl WgTunnelConnector {
|
||||
pub fn new(addr: url::Url, config: WgConfig) -> Self {
|
||||
WgTunnelConnector {
|
||||
addr,
|
||||
config,
|
||||
udp: None,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn create_handshake_init(tun: &mut Tunn) -> Vec<u8> {
|
||||
let mut dst = vec![0u8; 2048];
|
||||
let handshake_init = tun.format_handshake_initiation(&mut dst, false);
|
||||
assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_)));
|
||||
let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init {
|
||||
sent
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
handshake_init.into()
|
||||
}
|
||||
|
||||
fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec<u8> {
|
||||
let mut dst = vec![0u8; 2048];
|
||||
let keepalive = tun.decapsulate(None, handshake_resp, &mut dst);
|
||||
assert!(
|
||||
matches!(keepalive, TunnResult::WriteToNetwork(_)),
|
||||
"Failed to parse handshake response, {:?}",
|
||||
keepalive
|
||||
);
|
||||
|
||||
let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive {
|
||||
sent
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
keepalive.into()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(config))]
|
||||
async fn connect_with_socket(
|
||||
addr_url: url::Url,
|
||||
config: WgConfig,
|
||||
udp: UdpSocket,
|
||||
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&addr_url, "wg")?;
|
||||
tracing::warn!("wg connect: {:?}", addr);
|
||||
let local_addr = udp.local_addr().unwrap().to_string();
|
||||
|
||||
let mut wg_peer = WgPeer::new(Arc::new(udp), config.clone(), addr);
|
||||
let tunnel = wg_peer.start_and_get_tunnel();
|
||||
|
||||
let data = wg_peer.data.as_ref().unwrap().clone();
|
||||
let mut sink = wg_peer.sink.lock().unwrap().take().unwrap();
|
||||
wg_peer.tasks.spawn(async move {
|
||||
loop {
|
||||
let mut buf = vec![0u8; MAX_PACKET];
|
||||
let (n, recv_addr) = data.udp.recv_from(&mut buf).await.unwrap();
|
||||
if recv_addr != addr {
|
||||
continue;
|
||||
}
|
||||
data.handle_one_packet_from_peer(&mut sink, &buf[..n]).await;
|
||||
}
|
||||
});
|
||||
|
||||
let (stream, sink) = tunnel.split();
|
||||
let ret = Box::new(TunnelWrapper::new_with_associate_data(
|
||||
stream,
|
||||
sink,
|
||||
Some(TunnelInfo {
|
||||
tunnel_type: "wg".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&local_addr, "wg").into(),
|
||||
remote_addr: addr_url.to_string(),
|
||||
}),
|
||||
Some(Box::new(wg_peer)),
|
||||
));
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for WgTunnelConnector {
|
||||
#[tracing::instrument]
|
||||
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
let bind_addrs = if self.bind_addrs.is_empty() {
|
||||
vec!["0.0.0.0:0".parse().unwrap()]
|
||||
} else {
|
||||
self.bind_addrs.clone()
|
||||
};
|
||||
let futures = FuturesUnordered::new();
|
||||
|
||||
for bind_addr in bind_addrs.into_iter() {
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(bind_addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &bind_addr)?;
|
||||
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||
tracing::info!(?bind_addr, ?self.addr, "prepare wg connect task");
|
||||
futures.push(Self::connect_with_socket(
|
||||
self.addr.clone(),
|
||||
self.config.clone(),
|
||||
socket,
|
||||
));
|
||||
}
|
||||
|
||||
wait_for_connect_futures(futures).await
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use super::*;
|
||||
use crate::tunnel::{
|
||||
common::tests::{_tunnel_bench, _tunnel_pingpong},
|
||||
TunnelConnector,
|
||||
};
|
||||
use boringtun::*;
|
||||
|
||||
pub fn create_wg_config() -> (WgConfig, WgConfig) {
|
||||
let my_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng());
|
||||
let my_public_key = x25519::PublicKey::from(&my_secret_key);
|
||||
|
||||
let their_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng());
|
||||
let their_public_key = x25519::PublicKey::from(&their_secret_key);
|
||||
|
||||
let server_cfg = WgConfig {
|
||||
my_secret_key: my_secret_key.clone(),
|
||||
my_public_key,
|
||||
peer_secret_key: their_secret_key.clone(),
|
||||
peer_public_key: their_public_key.clone(),
|
||||
wg_type: WgType::InternalUse,
|
||||
};
|
||||
|
||||
let client_cfg = WgConfig {
|
||||
my_secret_key: their_secret_key,
|
||||
my_public_key: their_public_key,
|
||||
peer_secret_key: my_secret_key,
|
||||
peer_public_key: my_public_key,
|
||||
wg_type: WgType::InternalUse,
|
||||
};
|
||||
|
||||
(server_cfg, client_cfg)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_pingpong() {
|
||||
let (server_cfg, client_cfg) = create_wg_config();
|
||||
let listener = WgTunnelListener::new("wg://0.0.0.0:5599".parse().unwrap(), server_cfg);
|
||||
let connector = WgTunnelConnector::new("wg://127.0.0.1:5599".parse().unwrap(), client_cfg);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_bench() {
|
||||
let (server_cfg, client_cfg) = create_wg_config();
|
||||
let listener = WgTunnelListener::new("wg://0.0.0.0:5598".parse().unwrap(), server_cfg);
|
||||
let connector = WgTunnelConnector::new("wg://127.0.0.1:5598".parse().unwrap(), client_cfg);
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_bench_with_bind() {
|
||||
let (server_cfg, client_cfg) = create_wg_config();
|
||||
let listener = WgTunnelListener::new("wg://127.0.0.1:5597".parse().unwrap(), server_cfg);
|
||||
let mut connector =
|
||||
WgTunnelConnector::new("wg://127.0.0.1:5597".parse().unwrap(), client_cfg);
|
||||
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic]
|
||||
async fn wg_bench_with_bind_fail() {
|
||||
let (server_cfg, client_cfg) = create_wg_config();
|
||||
let listener = WgTunnelListener::new("wg://127.0.0.1:5596".parse().unwrap(), server_cfg);
|
||||
let mut connector =
|
||||
WgTunnelConnector::new("wg://127.0.0.1:5596".parse().unwrap(), client_cfg);
|
||||
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_server_erase_from_map_after_close() {
|
||||
let (server_cfg, client_cfg) = create_wg_config();
|
||||
let mut listener =
|
||||
WgTunnelListener::new("wg://127.0.0.1:5595".parse().unwrap(), server_cfg);
|
||||
listener.listen().await.unwrap();
|
||||
|
||||
const CONN_COUNT: usize = 10;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut tunnels = vec![];
|
||||
for _ in 0..CONN_COUNT {
|
||||
let mut connector = WgTunnelConnector::new(
|
||||
"wg://127.0.0.1:5595".parse().unwrap(),
|
||||
client_cfg.clone(),
|
||||
);
|
||||
let ret = connector.connect().await;
|
||||
assert!(ret.is_ok());
|
||||
let t = ret.unwrap();
|
||||
let (_stream, mut sink) = t.split();
|
||||
sink.send(ZCPacket::new_with_payload("payload".as_bytes()))
|
||||
.await
|
||||
.unwrap();
|
||||
tunnels.push(t);
|
||||
}
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
});
|
||||
|
||||
for _ in 0..CONN_COUNT {
|
||||
println!("accepting");
|
||||
let conn = listener.accept().await;
|
||||
let (mut stream, _sink) = conn.unwrap().split();
|
||||
let packet = stream.next().await.unwrap().unwrap();
|
||||
assert_eq!("payload".as_bytes(), packet.payload());
|
||||
println!("accepting drop");
|
||||
}
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
||||
|
||||
assert_eq!(0, listener.wg_peer_map.len());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user