diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index bb829f03..88b2f0b3 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -8,7 +8,7 @@ use anyhow::Context; use async_trait::async_trait; use bytes::BytesMut; use dashmap::DashMap; -use futures::{stream::FuturesUnordered, StreamExt}; +use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; use rand::{Rng, SeedableRng}; use zerocopy::{AsBytes, FromBytes}; @@ -265,18 +265,18 @@ async fn forward_from_ring_to_udp( } } -async fn udp_recv_from_socket_forward_task(socket: Arc, allow_stun: bool, mut f: F) -where - F: FnMut(ZCPacket, SocketAddr), -{ - let mut buf = BytesMut::new(); +async fn udp_recv_from_socket_forward_task( + socket: &UdpSocket, + buf: &mut BytesMut, + allow_stun: bool, +) -> Result<(ZCPacket, SocketAddr), TunnelError> { loop { - reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 4); - let (dg_size, addr) = match socket.recv_buf_from(&mut buf).await { + reserve_buf(buf, UDP_DATA_MTU, UDP_DATA_MTU * 4); + let (dg_size, addr) = match socket.recv_buf_from(buf).await { Ok(v) => v, Err(e) => { tracing::error!(?e, "udp recv from socket error"); - break; + return Err(e.into()); } }; tracing::trace!( @@ -294,7 +294,7 @@ where } }; - f(zc_packet, addr); + return Ok((zc_packet, addr)); } } @@ -335,7 +335,10 @@ impl UdpConnection { } } - pub fn handle_packet_from_remote(&mut self, zc_packet: ZCPacket) -> Result<(), TunnelError> { + pub async fn handle_packet_from_remote( + &mut self, + zc_packet: ZCPacket, + ) -> Result<(), TunnelError> { let header = zc_packet.udp_tunnel_header().unwrap(); let conn_id = header.conn_id.get(); @@ -347,13 +350,7 @@ impl UdpConnection { return Err(TunnelError::ConnIdNotMatch(self.conn_id, conn_id)); } - if zc_packet.is_lossy() { - if let Err(e) = self.ring_sender.try_send(zc_packet) { - tracing::trace!(?e, "ring sender full, drop lossy packet"); - } - } else if let Err(e) = self.ring_sender.force_send(zc_packet) { - tracing::trace!(?e, "ring sender full, drop non-lossy packet"); - } + self.ring_sender.send(zc_packet).await?; Ok(()) } @@ -442,7 +439,7 @@ impl UdpTunnelListenerData { } } - fn do_forward_one_packet_to_conn(&self, zc_packet: ZCPacket, addr: SocketAddr) { + async fn do_forward_one_packet_to_conn(&self, zc_packet: ZCPacket, addr: SocketAddr) { let header = zc_packet.udp_tunnel_header().unwrap(); if header.msg_type == UdpPacketType::Syn as u8 { tokio::spawn(Self::handle_new_connect(self.clone(), addr, zc_packet)); @@ -481,7 +478,7 @@ impl UdpTunnelListenerData { tracing::trace!(?header, "udp forward packet error, connection not found"); return; }; - if let Err(e) = conn.handle_packet_from_remote(zc_packet) { + if let Err(e) = conn.handle_packet_from_remote(zc_packet).await { tracing::trace!(?e, "udp forward packet error"); } } else { @@ -491,10 +488,16 @@ impl UdpTunnelListenerData { async fn do_forward_task(self) { let socket = self.socket.as_ref().unwrap().clone(); - udp_recv_from_socket_forward_task(socket, true, |zc_packet, addr| { - self.do_forward_one_packet_to_conn(zc_packet, addr); - }) - .await; + let mut buf = BytesMut::new(); + loop { + match udp_recv_from_socket_forward_task(&socket, &mut buf, true).await { + Ok((zc_packet, addr)) => self.do_forward_one_packet_to_conn(zc_packet, addr).await, + Err(e) => { + tracing::error!(?e, "udp recv packet error"); + break; + } + } + } } } @@ -730,18 +733,31 @@ impl UdpTunnelConnector { ); let socket_clone = socket.clone(); + + let recv_loop = async move { + let mut buf = BytesMut::new(); + loop { + match udp_recv_from_socket_forward_task(&socket_clone, &mut buf, false).await { + Ok((zc_packet, addr)) => { + tracing::trace!(?addr, "connector udp forward task done"); + if let Err(e) = udp_conn.handle_packet_from_remote(zc_packet).await { + tracing::trace!(?e, ?addr, "udp forward packet error"); + } + } + Err(e) => { + tracing::trace!(?e, "udp forward task error"); + break; + } + } + } + }; tokio::spawn( async move { tokio::select! { _ = close_event_recv.recv() => { tracing::debug!("connector udp close event"); } - _ = udp_recv_from_socket_forward_task(socket_clone,false, |zc_packet, addr| { - tracing::trace!(?addr, "connector udp forward task done"); - if let Err(e) = udp_conn.handle_packet_from_remote(zc_packet) { - tracing::trace!(?e, ?addr, "udp forward packet error"); - } - }) => { + _ = recv_loop => { tracing::debug!("connector udp forward task done"); } }