// this mod wrap tunnel to a mpsc tunnel, based on crossbeam_channel use std::{pin::Pin, time::Duration}; use anyhow::Context; use tokio::time::timeout; use crate::common::scoped_task::ScopedTask; use super::{packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream}; // use tokio::sync::mpsc::{channel, error::TrySendError, Receiver, Sender}; use tachyonix::{channel, Receiver, Sender, TrySendError}; use futures::SinkExt; #[derive(Clone)] pub struct MpscTunnelSender(Sender); impl MpscTunnelSender { pub async fn send(&self, item: ZCPacket) -> Result<(), TunnelError> { self.0.send(item).await.with_context(|| "send error")?; Ok(()) } pub fn try_send(&self, item: ZCPacket) -> Result<(), TunnelError> { self.0.try_send(item).map_err(|e| match e { TrySendError::Full(_) => TunnelError::BufferFull, TrySendError::Closed(_) => TunnelError::Shutdown, }) } } pub struct MpscTunnel { tx: Option>, tunnel: T, stream: Option>>, task: ScopedTask<()>, } impl MpscTunnel { pub fn new(tunnel: T, send_timeout: Option) -> Self { let (tx, mut rx) = channel(32); let (stream, mut sink) = tunnel.split(); let task = tokio::spawn(async move { loop { if let Err(e) = Self::forward_one_round(&mut rx, &mut sink, send_timeout).await { tracing::error!(?e, "forward error"); break; } } rx.close(); let close_ret = timeout(Duration::from_secs(5), sink.close()).await; tracing::warn!(?close_ret, "mpsc close sink"); }); Self { tx: Some(tx), tunnel, stream: Some(stream), task: task.into(), } } async fn forward_one_round( rx: &mut Receiver, sink: &mut Pin>, send_timeout_ms: Option, ) -> Result<(), TunnelError> { let item = rx.recv().await.with_context(|| "recv error")?; if let Some(timeout_ms) = send_timeout_ms { Self::forward_one_round_with_timeout(rx, sink, item, timeout_ms).await } else { Self::forward_one_round_no_timeout(rx, sink, item).await } } async fn forward_one_round_no_timeout( rx: &mut Receiver, sink: &mut Pin>, initial_item: ZCPacket, ) -> Result<(), TunnelError> { sink.feed(initial_item).await?; while let Ok(item) = rx.try_recv() { match sink.feed(item).await { Err(e) => { tracing::error!(?e, "feed error"); return Err(e); } Ok(_) => {} } } sink.flush().await } async fn forward_one_round_with_timeout( rx: &mut Receiver, sink: &mut Pin>, initial_item: ZCPacket, timeout_ms: Duration, ) -> Result<(), TunnelError> { match timeout(timeout_ms, async move { Self::forward_one_round_no_timeout(rx, sink, initial_item).await }) .await { Ok(Ok(_)) => Ok(()), Ok(Err(e)) => { tracing::error!(?e, "forward error"); Err(e) } Err(e) => { tracing::error!(?e, "forward timeout"); Err(e.into()) } } } pub fn get_stream(&mut self) -> Pin> { self.stream.take().unwrap() } pub fn get_sink(&self) -> MpscTunnelSender { MpscTunnelSender(self.tx.as_ref().unwrap().clone()) } pub fn close(&mut self) { self.tx.take(); self.task.abort(); } } #[cfg(test)] mod tests { use futures::StreamExt; use crate::tunnel::{ ring::{create_ring_tunnel_pair, RING_TUNNEL_CAP}, tcp::{TcpTunnelConnector, TcpTunnelListener}, TunnelConnector, TunnelListener, }; use super::*; // test slow send lock in framed tunnel #[tokio::test] async fn mpsc_slow_receiver() { let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap()); let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap()); listener.listen().await.unwrap(); let t1 = tokio::spawn(async move { let t = listener.accept().await.unwrap(); let (mut stream, _sink) = t.split(); let now = tokio::time::Instant::now(); let mut a_counter = 0; let mut b_counter = 0; while let Some(Ok(msg)) = stream.next().await { tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; if now.elapsed().as_secs() > 5 { break; } if msg.payload() == "hello".as_bytes() { a_counter += 1; } else if msg.payload() == "hello2".as_bytes() { b_counter += 1; } } tracing::info!("t1 exit"); assert_ne!(a_counter, 0); assert_ne!(b_counter, 0); }); let tunnel = connector.connect().await.unwrap(); let mpsc_tunnel = MpscTunnel::new(tunnel, None); let sink1 = mpsc_tunnel.get_sink(); let t2 = tokio::spawn(async move { for i in 0..1000000 { tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; let a = sink1 .send(ZCPacket::new_with_payload("hello".as_bytes())) .await; if a.is_err() { tracing::info!(?a, "t2 exit with err"); break; } if i % 5000 == 0 { tracing::info!(i, "send2 1000"); } } tracing::info!("t2 exit"); }); let sink2 = mpsc_tunnel.get_sink(); let t3 = tokio::spawn(async move { for i in 0..1000000 { tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; let a = sink2 .send(ZCPacket::new_with_payload("hello2".as_bytes())) .await; if a.is_err() { tracing::info!(?a, "t3 exit with err"); break; } if i % 5000 == 0 { tracing::info!(i, "send2 1000"); } } tracing::info!("t3 exit"); }); let t4 = tokio::spawn(async move { tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; tracing::info!("closing"); drop(mpsc_tunnel); tracing::info!("closed"); }); let _ = tokio::join!(t1, t2, t3, t4); } #[tokio::test] async fn mpsc_slow_receiver_with_send_timeout() { let (a, _b) = create_ring_tunnel_pair(); let mpsc_tunnel = MpscTunnel::new(a, Some(Duration::from_secs(1))); let s = mpsc_tunnel.get_sink(); for _ in 0..RING_TUNNEL_CAP { s.send(ZCPacket::new_with_payload(&[0; 1024])) .await .unwrap(); } tokio::time::sleep(Duration::from_millis(1500)).await; let e = s.send(ZCPacket::new_with_payload(&[0; 1024])).await; assert!(e.is_ok()); tokio::time::sleep(Duration::from_millis(1500)).await; let e = s.send(ZCPacket::new_with_payload(&[0; 1024])).await; assert!(e.is_err()); } }