adapt tun device to zerocopy (#57)

This commit is contained in:
Sijie.Sun
2024-04-25 23:25:37 +08:00
committed by GitHub
parent 3467890270
commit 57c9f11371
15 changed files with 405 additions and 150 deletions
+37 -36
View File
@@ -1,13 +1,13 @@
use std::borrow::BorrowMut;
use std::net::Ipv4Addr;
use std::pin::Pin;
use std::sync::{Arc, Weak};
use anyhow::Context;
use futures::{SinkExt, StreamExt};
use pnet::packet::ethernet::EthernetPacket;
use pnet::packet::ipv4::Ipv4Packet;
use bytes::BytesMut;
use tokio::{sync::Mutex, task::JoinSet};
use tonic::transport::Server;
@@ -30,11 +30,14 @@ use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc;
use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
use crate::tunnel::packet_def::ZCPacket;
use crate::tunnel::{ZCPacketSink, ZCPacketStream};
use crate::vpn_portal::{self, VpnPortal};
use super::listeners::ListenerManager;
use super::virtual_nic;
use crate::common::ifcfg::IfConfiguerTrait;
#[derive(Clone)]
struct IpProxy {
tcp_proxy: Arc<TcpProxy>,
@@ -156,8 +159,8 @@ impl Instance {
self.conn_manager.clone()
}
async fn do_forward_nic_to_peers_ipv4(ret: BytesMut, mgr: &PeerManager) {
if let Some(ipv4) = Ipv4Packet::new(&ret) {
async fn do_forward_nic_to_peers_ipv4(ret: ZCPacket, mgr: &PeerManager) {
if let Some(ipv4) = Ipv4Packet::new(ret.payload()) {
if ipv4.get_version() != 4 {
tracing::info!("[USER_PACKET] not ipv4 packet: {:?}", ipv4);
return;
@@ -169,9 +172,7 @@ impl Instance {
);
// TODO: use zero-copy
let send_ret = mgr
.send_msg_ipv4(ZCPacket::new_with_payload(ret.as_ref()), dst_ipv4)
.await;
let send_ret = mgr.send_msg_ipv4(ret, dst_ipv4).await;
if send_ret.is_err() {
tracing::trace!(?send_ret, "[USER_PACKET] send_msg_ipv4 failed")
}
@@ -180,23 +181,23 @@ impl Instance {
}
}
async fn do_forward_nic_to_peers_ethernet(mut ret: BytesMut, mgr: &PeerManager) {
if let Some(eth) = EthernetPacket::new(&ret) {
log::warn!("begin to forward: {:?}, type: {}", eth, eth.get_ethertype());
Self::do_forward_nic_to_peers_ipv4(ret.split_off(14), mgr).await;
} else {
log::warn!("not ipv4 packet: {:?}", ret);
}
}
// async fn do_forward_nic_to_peers_ethernet(mut ret: BytesMut, mgr: &PeerManager) {
// if let Some(eth) = EthernetPacket::new(&ret) {
// log::warn!("begin to forward: {:?}, type: {}", eth, eth.get_ethertype());
// Self::do_forward_nic_to_peers_ipv4(ret.split_off(14), mgr).await;
// } else {
// log::warn!("not ipv4 packet: {:?}", ret);
// }
// }
fn do_forward_nic_to_peers(&mut self) -> Result<(), Error> {
fn do_forward_nic_to_peers(
&mut self,
mut stream: Pin<Box<dyn ZCPacketStream>>,
) -> Result<(), Error> {
// read from nic and write to corresponding tunnel
let nic = self.virtual_nic.as_ref().unwrap();
let nic = nic.clone();
let mgr = self.peer_manager.clone();
self.tasks.spawn(async move {
let mut stream = nic.pin_recv_stream();
while let Some(ret) = stream.next().await {
if ret.is_err() {
log::error!("read from nic failed: {:?}", ret);
@@ -212,21 +213,17 @@ impl Instance {
fn do_forward_peers_to_nic(
tasks: &mut JoinSet<()>,
nic: Arc<virtual_nic::VirtualNic>,
mut sink: Pin<Box<dyn ZCPacketSink>>,
channel: Option<PacketRecvChanReceiver>,
) {
tasks.spawn(async move {
let mut send = nic.pin_send_stream();
let mut channel = channel.unwrap();
while let Some(packet) = channel.recv().await {
tracing::trace!(
"[USER_PACKET] forward packet from peers to nic. packet: {:?}",
packet
);
let mut b = BytesMut::new();
b.extend_from_slice(packet.payload());
let ret = send.send(b.freeze()).await;
let ret = sink.send(packet).await;
if ret.is_err() {
panic!("do_forward_tunnel_to_nic");
}
@@ -244,19 +241,19 @@ impl Instance {
}
async fn prepare_tun_device(&mut self) -> Result<(), Error> {
let nic = virtual_nic::VirtualNic::new(self.get_global_ctx())
.create_dev()
.await?;
let mut nic = virtual_nic::VirtualNic::new(self.get_global_ctx());
let tunnel = nic.create_dev().await?;
self.global_ctx
.issue_event(GlobalCtxEvent::TunDeviceReady(nic.ifname().to_string()));
let (stream, sink) = tunnel.split();
self.virtual_nic = Some(Arc::new(nic));
self.do_forward_nic_to_peers().unwrap();
self.do_forward_nic_to_peers(stream).unwrap();
Self::do_forward_peers_to_nic(
self.tasks.borrow_mut(),
self.virtual_nic.as_ref().unwrap().clone(),
sink,
self.peer_packet_receiver.take(),
);
@@ -438,6 +435,8 @@ impl Instance {
let global_ctx = self.global_ctx.clone();
let net_ns = self.global_ctx.net_ns.clone();
let nic = self.virtual_nic.as_ref().unwrap().clone();
let ifcfg = nic.get_ifcfg();
let ifname = nic.ifname().to_owned();
self.tasks.spawn(async move {
let mut cur_proxy_cidrs = vec![];
@@ -464,10 +463,9 @@ impl Instance {
}
let _g = net_ns.guard();
let ret = nic
.get_ifcfg()
let ret = ifcfg
.remove_ipv4_route(
nic.ifname(),
ifname.as_str(),
cidr.first_address(),
cidr.network_length(),
)
@@ -487,9 +485,12 @@ impl Instance {
continue;
}
let _g = net_ns.guard();
let ret = nic
.get_ifcfg()
.add_ipv4_route(nic.ifname(), cidr.first_address(), cidr.network_length())
let ret = ifcfg
.add_ipv4_route(
ifname.as_str(),
cidr.first_address(),
cidr.network_length(),
)
.await;
if ret.is_err() {
+221 -63
View File
@@ -1,21 +1,207 @@
use std::{net::Ipv4Addr, pin::Pin};
use std::{
io,
net::Ipv4Addr,
pin::Pin,
task::{Context, Poll},
};
use crate::{
common::{
error::Result,
error::Error,
global_ctx::ArcGlobalCtx,
ifcfg::{IfConfiger, IfConfiguerTrait},
},
tunnels::{
codec::BytesCodec, common::FramedTunnel, DatagramSink, DatagramStream, Tunnel, TunnelError,
tunnel::{
common::{FramedWriter, TunnelWrapper, ZCPacketToBytes},
packet_def::ZCPacket,
StreamItem, Tunnel, TunnelError,
},
};
use futures::{SinkExt, StreamExt};
use tokio_util::{bytes::Bytes, codec::Framed};
use tun::Device;
use byteorder::WriteBytesExt as _;
use futures::{lock::BiLock, ready, Stream};
use pin_project_lite::pin_project;
use tokio::io::AsyncWrite;
use tokio_util::{bytes::Bytes, io::poll_read_buf};
use tun::{create_as_async, AsyncDevice, Configuration, Device as _, Layer};
use zerocopy::{NativeEndian, NetworkEndian};
use super::tun_codec::{TunPacket, TunPacketCodec};
pin_project! {
pub struct TunStream {
#[pin]
l: BiLock<AsyncDevice>,
cur_packet: Option<ZCPacket>,
}
}
impl TunStream {
pub fn new(l: BiLock<AsyncDevice>) -> Self {
Self {
l,
cur_packet: None,
}
}
}
impl Stream for TunStream {
type Item = StreamItem;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<StreamItem>> {
let self_mut = self.project();
let mut g = ready!(self_mut.l.poll_lock(cx));
if self_mut.cur_packet.is_none() {
*self_mut.cur_packet = Some(ZCPacket::new_with_reserved_payload(2048));
}
let cur_packet = self_mut.cur_packet.as_mut().unwrap();
match ready!(poll_read_buf(
g.as_pin_mut(),
cx,
&mut cur_packet.mut_inner()
)) {
Ok(0) => Poll::Ready(None),
Ok(_n) => Poll::Ready(Some(Ok(self_mut.cur_packet.take().unwrap()))),
Err(err) => {
println!("tun stream error: {:?}", err);
Poll::Ready(None)
}
}
}
}
#[derive(Debug, Clone, Copy, Default)]
enum PacketProtocol {
#[default]
IPv4,
IPv6,
Other(u8),
}
// Note: the protocol in the packet information header is platform dependent.
impl PacketProtocol {
#[cfg(any(target_os = "linux", target_os = "android"))]
fn into_pi_field(self) -> Result<u16, io::Error> {
use nix::libc;
match self {
PacketProtocol::IPv4 => Ok(libc::ETH_P_IP as u16),
PacketProtocol::IPv6 => Ok(libc::ETH_P_IPV6 as u16),
PacketProtocol::Other(_) => Err(io::Error::new(
io::ErrorKind::Other,
"neither an IPv4 nor IPv6 packet",
)),
}
}
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn into_pi_field(self) -> Result<u16, io::Error> {
use nix::libc;
match self {
PacketProtocol::IPv4 => Ok(libc::PF_INET as u16),
PacketProtocol::IPv6 => Ok(libc::PF_INET6 as u16),
PacketProtocol::Other(_) => Err(io::Error::new(
io::ErrorKind::Other,
"neither an IPv4 nor IPv6 packet",
)),
}
}
#[cfg(target_os = "windows")]
fn into_pi_field(self) -> Result<u16, io::Error> {
unimplemented!()
}
}
/// Infer the protocol based on the first nibble in the packet buffer.
fn infer_proto(buf: &[u8]) -> PacketProtocol {
match buf[0] >> 4 {
4 => PacketProtocol::IPv4,
6 => PacketProtocol::IPv6,
p => PacketProtocol::Other(p),
}
}
struct TunZCPacketToBytes {
has_packet_info: bool,
}
impl TunZCPacketToBytes {
pub fn new(has_packet_info: bool) -> Self {
Self { has_packet_info }
}
pub fn fill_packet_info(&self, mut buf: &mut [u8]) -> Result<(), io::Error> {
// flags is always 0
buf.write_u16::<NativeEndian>(0)?;
// write the protocol as network byte order
buf.write_u16::<NetworkEndian>(infer_proto(&buf).into_pi_field()?)?;
Ok(())
}
}
impl ZCPacketToBytes for TunZCPacketToBytes {
fn into_bytes(&self, zc_packet: ZCPacket) -> Result<Bytes, TunnelError> {
let payload_offset = zc_packet.payload_offset();
let mut inner = zc_packet.inner();
// we have peer manager header, so payload offset must larger than 4
assert!(payload_offset >= 4);
let ret = if self.has_packet_info {
let mut inner = inner.split_off(payload_offset - 4);
self.fill_packet_info(&mut inner[0..4])?;
inner
} else {
inner.split_off(payload_offset)
};
tracing::debug!(?ret, ?payload_offset, "convert zc packet to tun packet");
Ok(ret.into())
}
}
pin_project! {
pub struct TunAsyncWrite {
#[pin]
l: BiLock<AsyncDevice>,
}
}
impl AsyncWrite for TunAsyncWrite {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let self_mut = self.project();
let mut g = ready!(self_mut.l.poll_lock(cx));
g.as_pin_mut().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let self_mut = self.project();
let mut g = ready!(self_mut.l.poll_lock(cx));
g.as_pin_mut().poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let self_mut = self.project();
let mut g = ready!(self_mut.l.poll_lock(cx));
g.as_pin_mut().poll_shutdown(cx)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
let self_mut = self.project();
let mut g = ready!(self_mut.l.poll_lock(cx));
g.as_pin_mut().poll_write_vectored(cx, bufs)
}
fn is_write_vectored(&self) -> bool {
true
}
}
pub struct VirtualNic {
dev_name: String,
@@ -24,7 +210,6 @@ pub struct VirtualNic {
global_ctx: ArcGlobalCtx,
ifname: Option<String>,
tun: Option<Box<dyn Tunnel>>,
ifcfg: Box<dyn IfConfiguerTrait + Send + Sync + 'static>,
}
@@ -35,25 +220,24 @@ impl VirtualNic {
queue_num: 1,
global_ctx,
ifname: None,
tun: None,
ifcfg: Box::new(IfConfiger {}),
}
}
pub fn set_dev_name(mut self, dev_name: &str) -> Result<Self> {
pub fn set_dev_name(mut self, dev_name: &str) -> Result<Self, Error> {
self.dev_name = dev_name.to_owned();
Ok(self)
}
pub fn set_queue_num(mut self, queue_num: usize) -> Result<Self> {
pub fn set_queue_num(mut self, queue_num: usize) -> Result<Self, Error> {
self.queue_num = queue_num;
Ok(self)
}
async fn create_dev_ret_err(&mut self) -> Result<()> {
let mut config = tun::Configuration::default();
async fn create_dev_ret_err(&mut self) -> Result<Box<dyn Tunnel>, Error> {
let mut config = Configuration::default();
let has_packet_info = cfg!(target_os = "macos");
config.layer(tun::Layer::L3);
config.layer(Layer::L3);
#[cfg(target_os = "linux")]
{
@@ -71,61 +255,42 @@ impl VirtualNic {
let dev = {
let _g = self.global_ctx.net_ns.guard();
tun::create_as_async(&config)?
create_as_async(&config)?
};
let ifname = dev.get_ref().name()?;
self.ifcfg.wait_interface_show(ifname.as_str()).await?;
let ft: Box<dyn Tunnel> = if has_packet_info {
let framed = Framed::new(dev, TunPacketCodec::new(true, 2500));
let (sink, stream) = framed.split();
let (a, b) = BiLock::new(dev);
let new_stream = stream.map(|item| match item {
Ok(item) => Ok(item.into_bytes_mut()),
Err(err) => {
println!("tun stream error: {:?}", err);
Err(TunnelError::TunError(err.to_string()))
}
});
let new_sink = Box::pin(sink.with(|item: Bytes| async move {
if false {
return Err(TunnelError::TunError("tun sink error".to_owned()));
}
Ok(TunPacket::new(super::tun_codec::TunPacketBuffer::Bytes(
item,
)))
}));
Box::new(FramedTunnel::new(new_stream, new_sink, None))
} else {
let framed = Framed::new(dev, BytesCodec::new(2500));
let (sink, stream) = framed.split();
Box::new(FramedTunnel::new(stream, sink, None))
};
let ft = TunnelWrapper::new(
TunStream::new(a),
FramedWriter::new_with_converter(
TunAsyncWrite { l: b },
TunZCPacketToBytes::new(has_packet_info),
),
None,
);
self.ifname = Some(ifname.to_owned());
self.tun = Some(ft);
Ok(())
Ok(Box::new(ft))
}
pub async fn create_dev(mut self) -> Result<Self> {
self.create_dev_ret_err().await?;
Ok(self)
pub async fn create_dev(&mut self) -> Result<Box<dyn Tunnel>, Error> {
self.create_dev_ret_err().await
}
pub fn ifname(&self) -> &str {
self.ifname.as_ref().unwrap().as_str()
}
pub async fn link_up(&self) -> Result<()> {
pub async fn link_up(&self) -> Result<(), Error> {
let _g = self.global_ctx.net_ns.guard();
self.ifcfg.set_link_status(self.ifname(), true).await?;
Ok(())
}
pub async fn add_route(&self, address: Ipv4Addr, cidr: u8) -> Result<()> {
pub async fn add_route(&self, address: Ipv4Addr, cidr: u8) -> Result<(), Error> {
let _g = self.global_ctx.net_ns.guard();
self.ifcfg
.add_ipv4_route(self.ifname(), address, cidr)
@@ -133,13 +298,13 @@ impl VirtualNic {
Ok(())
}
pub async fn remove_ip(&self, ip: Option<Ipv4Addr>) -> Result<()> {
pub async fn remove_ip(&self, ip: Option<Ipv4Addr>) -> Result<(), Error> {
let _g = self.global_ctx.net_ns.guard();
self.ifcfg.remove_ip(self.ifname(), ip).await?;
Ok(())
}
pub async fn add_ip(&self, ip: Ipv4Addr, cidr: i32) -> Result<()> {
pub async fn add_ip(&self, ip: Ipv4Addr, cidr: i32) -> Result<(), Error> {
let _g = self.global_ctx.net_ns.guard();
self.ifcfg
.add_ipv4_ip(self.ifname(), ip, cidr as u8)
@@ -147,16 +312,8 @@ impl VirtualNic {
Ok(())
}
pub fn pin_recv_stream(&self) -> Pin<Box<dyn DatagramStream>> {
self.tun.as_ref().unwrap().pin_stream()
}
pub fn pin_send_stream(&self) -> Pin<Box<dyn DatagramSink>> {
self.tun.as_ref().unwrap().pin_sink()
}
pub fn get_ifcfg(&self) -> &dyn IfConfiguerTrait {
self.ifcfg.as_ref()
pub fn get_ifcfg(&self) -> impl IfConfiguerTrait {
IfConfiger {}
}
}
#[cfg(test)]
@@ -166,7 +323,8 @@ mod tests {
use super::VirtualNic;
async fn run_test_helper() -> Result<VirtualNic, Error> {
let dev = VirtualNic::new(get_mock_global_ctx()).create_dev().await?;
let mut dev = VirtualNic::new(get_mock_global_ctx());
let _tunnel = dev.create_dev().await?;
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;