allow listener retry listen (#554)

This commit is contained in:
Sijie.Sun
2025-01-09 00:01:41 +08:00
committed by GitHub
parent d2ec60e108
commit 306817ae9a
3 changed files with 157 additions and 73 deletions
+1 -1
View File
@@ -27,7 +27,7 @@ pub fn gen_default_flags() -> Flags {
relay_all_peer_rpc: false, relay_all_peer_rpc: false,
disable_udp_hole_punching: false, disable_udp_hole_punching: false,
ipv6_listener: "udp://[::]:0".to_string(), ipv6_listener: "udp://[::]:0".to_string(),
multi_thread: false, multi_thread: true,
data_compress_algo: CompressionAlgoPb::None.into(), data_compress_algo: CompressionAlgoPb::None.into(),
} }
} }
+4 -1
View File
@@ -230,7 +230,10 @@ impl GlobalCtx {
} }
pub fn add_running_listener(&self, url: url::Url) { pub fn add_running_listener(&self, url: url::Url) {
self.running_listeners.lock().unwrap().push(url); let mut l = self.running_listeners.lock().unwrap();
if !l.contains(&url) {
l.push(url);
}
} }
pub fn get_vpn_portal_cidr(&self) -> Option<cidr::Ipv4Cidr> { pub fn get_vpn_portal_cidr(&self) -> Option<cidr::Ipv4Cidr> {
+152 -71
View File
@@ -1,8 +1,7 @@
use std::{fmt::Debug, sync::Arc}; use std::{fmt::Debug, sync::Arc};
use anyhow::Context;
use async_trait::async_trait; use async_trait::async_trait;
use tokio::{sync::Mutex, task::JoinSet}; use tokio::task::JoinSet;
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
use crate::tunnel::quic::QUICTunnelListener; use crate::tunnel::quic::QUICTunnelListener;
@@ -63,16 +62,20 @@ impl TunnelHandlerForListener for PeerManager {
} }
} }
#[derive(Debug, Clone)] pub trait ListenerCreatorTrait: Fn() -> Box<dyn TunnelListener> + Send + Sync {}
struct Listener { impl<T: Send + Sync> ListenerCreatorTrait for T where T: Fn() -> Box<dyn TunnelListener> + Send {}
inner: Arc<Mutex<dyn TunnelListener>>, pub type ListenerCreator = Box<dyn ListenerCreatorTrait>;
#[derive(Clone)]
struct ListenerFactory {
creator_fn: Arc<ListenerCreator>,
must_succ: bool, must_succ: bool,
} }
pub struct ListenerManager<H> { pub struct ListenerManager<H> {
global_ctx: ArcGlobalCtx, global_ctx: ArcGlobalCtx,
net_ns: NetNS, net_ns: NetNS,
listeners: Vec<Listener>, listeners: Vec<ListenerFactory>,
peer_manager: Arc<H>, peer_manager: Arc<H>,
tasks: JoinSet<()>, tasks: JoinSet<()>,
@@ -90,31 +93,39 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
} }
pub async fn prepare_listeners(&mut self) -> Result<(), Error> { pub async fn prepare_listeners(&mut self) -> Result<(), Error> {
let self_id = self.global_ctx.get_id();
self.add_listener( self.add_listener(
RingTunnelListener::new( move || {
format!("ring://{}", self.global_ctx.get_id()) Box::new(RingTunnelListener::new(
.parse() format!("ring://{}", self_id).parse().unwrap(),
.unwrap(), ))
), },
true, true,
) )
.await?; .await?;
for l in self.global_ctx.config.get_listener_uris().iter() { for l in self.global_ctx.config.get_listener_uris().iter() {
let Ok(lis) = get_listener_by_url(l, self.global_ctx.clone()) else { let l = l.clone();
let Ok(_) = get_listener_by_url(&l, self.global_ctx.clone()) else {
let msg = format!("failed to get listener by url: {}, maybe not supported", l); let msg = format!("failed to get listener by url: {}, maybe not supported", l);
self.global_ctx self.global_ctx
.issue_event(GlobalCtxEvent::ListenerAddFailed(l.clone(), msg)); .issue_event(GlobalCtxEvent::ListenerAddFailed(l.clone(), msg));
continue; continue;
}; };
self.add_listener(lis, true).await?; let ctx = self.global_ctx.clone();
self.add_listener(move || get_listener_by_url(&l, ctx.clone()).unwrap(), true)
.await?;
} }
if self.global_ctx.config.get_flags().enable_ipv6 { if self.global_ctx.config.get_flags().enable_ipv6 {
let ipv6_listener = self.global_ctx.config.get_flags().ipv6_listener.clone(); let ipv6_listener = self.global_ctx.config.get_flags().ipv6_listener.clone();
let _ = self let _ = self
.add_listener( .add_listener(
UdpTunnelListener::new(ipv6_listener.parse().unwrap()), move || {
Box::new(UdpTunnelListener::new(
ipv6_listener.clone().parse().unwrap(),
))
},
false, false,
) )
.await?; .await?;
@@ -123,85 +134,91 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
Ok(()) Ok(())
} }
pub async fn add_listener<L>(&mut self, listener: L, must_succ: bool) -> Result<(), Error> pub async fn add_listener<C: ListenerCreatorTrait + 'static>(
where &mut self,
L: TunnelListener + 'static, creator: C,
{ must_succ: bool,
let listener = Arc::new(Mutex::new(listener)); ) -> Result<(), Error> {
self.listeners.push(Listener { self.listeners.push(ListenerFactory {
inner: listener, creator_fn: Arc::new(Box::new(creator)),
must_succ, must_succ,
}); });
Ok(()) Ok(())
} }
#[tracing::instrument] #[tracing::instrument(skip(creator))]
async fn run_listener( async fn run_listener(
listener: Arc<Mutex<dyn TunnelListener>>, creator: Arc<ListenerCreator>,
peer_manager: Arc<H>, peer_manager: Arc<H>,
global_ctx: ArcGlobalCtx, global_ctx: ArcGlobalCtx,
) { ) {
let mut l = listener.lock().await;
global_ctx.add_running_listener(l.local_url());
global_ctx.issue_event(GlobalCtxEvent::ListenerAdded(l.local_url()));
loop { loop {
let ret = match l.accept().await { let mut l = (creator)();
Ok(ret) => ret, let _g = global_ctx.net_ns.guard();
match l.listen().await {
Ok(_) => {
global_ctx.add_running_listener(l.local_url());
global_ctx.issue_event(GlobalCtxEvent::ListenerAdded(l.local_url()));
}
Err(e) => { Err(e) => {
global_ctx.issue_event(GlobalCtxEvent::ListenerAcceptFailed( global_ctx.issue_event(GlobalCtxEvent::ListenerAddFailed(
l.local_url(), l.local_url(),
e.to_string(), e.to_string(),
)); ));
tracing::error!(?e, ?l, "listener accept error"); tracing::error!(?e, ?l, "listener listen error");
tokio::time::sleep(std::time::Duration::from_secs(1)).await; tokio::time::sleep(std::time::Duration::from_secs(1)).await;
continue; continue;
} }
}; }
loop {
let ret = match l.accept().await {
Ok(ret) => ret,
Err(e) => {
global_ctx.issue_event(GlobalCtxEvent::ListenerAcceptFailed(
l.local_url(),
format!("error: {}, retry listen later...", e.to_string()),
));
tracing::error!(?e, ?l, "listener accept error");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
break;
}
};
let tunnel_info = ret.info().unwrap(); let tunnel_info = ret.info().unwrap();
global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted( global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted(
tunnel_info tunnel_info
.local_addr .local_addr
.clone() .clone()
.unwrap_or_default() .unwrap_or_default()
.to_string(), .to_string(),
tunnel_info tunnel_info
.remote_addr .remote_addr
.clone() .clone()
.unwrap_or_default() .unwrap_or_default()
.to_string(), .to_string(),
)); ));
tracing::info!(ret = ?ret, "conn accepted"); tracing::info!(ret = ?ret, "conn accepted");
let peer_manager = peer_manager.clone(); let peer_manager = peer_manager.clone();
let global_ctx = global_ctx.clone(); let global_ctx = global_ctx.clone();
tokio::spawn(async move { tokio::spawn(async move {
let server_ret = peer_manager.handle_tunnel(ret).await; let server_ret = peer_manager.handle_tunnel(ret).await;
if let Err(e) = &server_ret { if let Err(e) = &server_ret {
global_ctx.issue_event(GlobalCtxEvent::ConnectionError( global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
tunnel_info.local_addr.unwrap_or_default().to_string(), tunnel_info.local_addr.unwrap_or_default().to_string(),
tunnel_info.remote_addr.unwrap_or_default().to_string(), tunnel_info.remote_addr.unwrap_or_default().to_string(),
e.to_string(), e.to_string(),
)); ));
tracing::error!(error = ?e, "handle conn error"); tracing::error!(error = ?e, "handle conn error");
} }
}); });
}
} }
} }
pub async fn run(&mut self) -> Result<(), Error> { pub async fn run(&mut self) -> Result<(), Error> {
for listener in &self.listeners { for listener in &self.listeners {
let _guard = self.net_ns.guard();
let addr = listener.inner.lock().await.local_url();
tracing::warn!("run listener: {:?}", listener);
listener
.inner
.lock()
.await
.listen()
.await
.with_context(|| format!("failed to add listener {}", addr))?;
self.tasks.spawn(Self::run_listener( self.tasks.spawn(Self::run_listener(
listener.inner.clone(), listener.creator_fn.clone(),
self.peer_manager.clone(), self.peer_manager.clone(),
self.global_ctx.clone(), self.global_ctx.clone(),
)); ));
@@ -213,12 +230,14 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::atomic::{AtomicI32, Ordering};
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use tokio::time::timeout; use tokio::time::timeout;
use crate::{ use crate::{
common::global_ctx::tests::get_mock_global_ctx, common::global_ctx::tests::get_mock_global_ctx,
tunnel::{packet_def::ZCPacket, ring::RingTunnelConnector, TunnelConnector}, tunnel::{packet_def::ZCPacket, ring::RingTunnelConnector, TunnelConnector, TunnelError},
}; };
use super::*; use super::*;
@@ -245,12 +264,18 @@ mod tests {
let ring_id = format!("ring://{}", uuid::Uuid::new_v4()); let ring_id = format!("ring://{}", uuid::Uuid::new_v4());
let ring_id_clone = ring_id.clone();
listener_mgr listener_mgr
.add_listener(RingTunnelListener::new(ring_id.parse().unwrap()), true) .add_listener(
move || Box::new(RingTunnelListener::new(ring_id_clone.parse().unwrap())),
true,
)
.await .await
.unwrap(); .unwrap();
listener_mgr.run().await.unwrap(); listener_mgr.run().await.unwrap();
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let connect_once = |ring_id| async move { let connect_once = |ring_id| async move {
let tunnel = RingTunnelConnector::new(ring_id).connect().await.unwrap(); let tunnel = RingTunnelConnector::new(ring_id).connect().await.unwrap();
let (mut recv, _send) = tunnel.split(); let (mut recv, _send) = tunnel.split();
@@ -269,4 +294,60 @@ mod tests {
.await .await
.unwrap(); .unwrap();
} }
#[tokio::test]
async fn retry_listen() {
let counter = Arc::new(AtomicI32::new(0));
let drop_counter = Arc::new(AtomicI32::new(0));
struct MockListener {
counter: Arc<AtomicI32>,
drop_counter: Arc<AtomicI32>,
}
#[async_trait::async_trait]
impl TunnelListener for MockListener {
fn local_url(&self) -> url::Url {
"mock://".parse().unwrap()
}
async fn listen(&mut self) -> Result<(), TunnelError> {
self.counter.fetch_add(1, Ordering::Relaxed);
Ok(())
}
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
Err(TunnelError::BufferFull)
}
}
impl Drop for MockListener {
fn drop(&mut self) {
self.drop_counter.fetch_add(1, Ordering::Relaxed);
}
}
let handler = Arc::new(MockListenerHandler {});
let mut listener_mgr = ListenerManager::new(get_mock_global_ctx(), handler.clone());
let counter_clone = counter.clone();
let drop_counter_clone = drop_counter.clone();
listener_mgr
.add_listener(
move || {
Box::new(MockListener {
counter: counter_clone.clone(),
drop_counter: drop_counter_clone.clone(),
})
},
true,
)
.await
.unwrap();
listener_mgr.run().await.unwrap();
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
assert!(counter.load(Ordering::Relaxed) >= 2);
assert!(drop_counter.load(Ordering::Relaxed) >= 1);
}
} }