refactor: introduce HedgeExt for task hedging; rewrite NatDstQuicConnector (#2229)

This commit is contained in:
Luna Yao
2026-05-12 14:26:16 +02:00
committed by GitHub
parent 513695297c
commit 8428a89d2d
9 changed files with 271 additions and 166 deletions
+2 -1
View File
@@ -1,5 +1,4 @@
use std::{io, result};
use thiserror::Error;
use crate::tunnel;
@@ -55,4 +54,6 @@ pub enum Error {
pub type Result<T> = result::Result<T, Error>;
pub type ErrorCollection = crate::utils::error::ErrorCollection<Error>;
// impl From for std::
+45 -59
View File
@@ -4,7 +4,7 @@ use std::{
time::Duration,
};
use anyhow::Context;
use anyhow::{Context, anyhow, bail};
use bytes::Bytes;
use dashmap::DashMap;
use guarden::defer;
@@ -15,12 +15,13 @@ use kcp_sys::{
stream::KcpStream,
};
use prost::Message;
use tokio::{select, task::JoinSet};
use tokio::task::JoinSet;
use super::{
CidrSet,
tcp_proxy::{NatDstConnector, NatDstTcpConnector, TcpProxy},
};
use crate::utils::task::HedgeExt;
use crate::{
common::{
acl_processor::PacketInfo,
@@ -114,72 +115,57 @@ pub struct NatDstKcpConnector {
impl NatDstConnector for NatDstKcpConnector {
type DstStream = KcpStream;
async fn connect(&self, src: SocketAddr, nat_dst: SocketAddr) -> Result<Self::DstStream> {
async fn connect(
&self,
src: SocketAddr,
nat_dst: SocketAddr,
) -> anyhow::Result<Self::DstStream> {
let peer_mgr = self
.peer_mgr
.upgrade()
.ok_or_else(|| anyhow!("peer manager is not available"))?;
let dst_peer = {
let SocketAddr::V4(addr) = nat_dst else {
bail!("ipv6 is not supported");
};
peer_mgr
.get_peer_map()
.get_peer_id_by_ipv4(addr.ip())
.await
.ok_or_else(|| anyhow!("no peer found for nat dst: {}", nat_dst))?
};
tracing::trace!(?nat_dst, ?dst_peer, "kcp nat");
let conn_data = KcpConnData {
src: Some(src.into()),
dst: Some(nat_dst.into()),
};
let Some(peer_mgr) = self.peer_mgr.upgrade() else {
return Err(anyhow::anyhow!("peer manager is not available").into());
};
let stream = (0..5)
.map(|_| {
let kcp_endpoint = self.kcp_endpoint.clone();
let my_peer_id = peer_mgr.my_peer_id();
let dst_peer_id = match nat_dst {
SocketAddr::V4(addr) => peer_mgr.get_peer_map().get_peer_id_by_ipv4(addr.ip()).await,
SocketAddr::V6(_) => return Err(anyhow::anyhow!("ipv6 is not supported").into()),
};
async move {
let conn_id = kcp_endpoint
.connect(
Duration::from_secs(10),
my_peer_id,
dst_peer,
Bytes::from(conn_data.encode_to_vec()),
)
.await?;
let Some(dst_peer) = dst_peer_id else {
return Err(anyhow::anyhow!("no peer found for nat dst: {}", nat_dst).into());
};
tracing::trace!("kcp nat dst: {:?}, dst peers: {:?}", nat_dst, dst_peer);
let mut connect_tasks: JoinSet<std::result::Result<ConnId, anyhow::Error>> = JoinSet::new();
let mut retry_remain = 5;
loop {
select! {
Some(Ok(Ok(ret))) = connect_tasks.join_next() => {
// just wait for the previous connection to finish
let stream = KcpStream::new(&self.kcp_endpoint, ret)
.ok_or(anyhow::anyhow!("failed to create kcp stream"))?;
return Ok(stream);
KcpStream::new(&kcp_endpoint, conn_id).context("failed to create kcp stream")
}
_ = tokio::time::sleep(Duration::from_millis(200)), if !connect_tasks.is_empty() && retry_remain > 0 => {
// no successful connection yet, trigger another connection attempt
}
else => {
// got error in connect_tasks, continue to retry
if retry_remain == 0 && connect_tasks.is_empty() {
break;
}
}
}
})
.hedge(Duration::from_millis(200))
.await
.context("failed to connect to peer")?;
// create a new connection task
if retry_remain == 0 {
continue;
}
retry_remain -= 1;
let kcp_endpoint = self.kcp_endpoint.clone();
let my_peer_id = peer_mgr.my_peer_id();
let conn_data_clone = conn_data;
connect_tasks.spawn(async move {
kcp_endpoint
.connect(
Duration::from_secs(10),
my_peer_id,
dst_peer,
Bytes::from(conn_data_clone.encode_to_vec()),
)
.await
.with_context(|| format!("failed to connect to nat dst: {}", nat_dst))
});
}
Err(anyhow::anyhow!("failed to connect to nat dst: {}", nat_dst).into())
Ok(stream)
}
fn check_packet_from_peer_fast(&self, _cidr_set: &CidrSet, _global_ctx: &GlobalCtx) -> bool {
+89 -79
View File
@@ -18,7 +18,8 @@ use crate::tunnel::packet_def::{
PacketType, PeerManagerHeader, TAIL_RESERVED_SIZE, ZCPacket, ZCPacketType,
};
use crate::tunnel::quic::{client_config, endpoint_config, server_config};
use anyhow::{Context, Error, anyhow};
use crate::utils::task::HedgeExt;
use anyhow::{Context, Error, anyhow, bail, ensure};
use atomic_refcell::AtomicRefCell;
use bytes::{BufMut, Bytes, BytesMut};
use dashmap::DashMap;
@@ -29,7 +30,8 @@ use moka::future::Cache;
use prost::Message;
use quinn::udp::{EcnCodepoint, RecvMeta, Transmit};
use quinn::{
AsyncUdpSocket, Endpoint, RecvStream, SendStream, StreamId, UdpPoller, default_runtime,
AsyncUdpSocket, Connection, ConnectionError, Endpoint, RecvStream, SendStream, StreamId,
UdpPoller, WriteError, default_runtime,
};
use std::cmp::min;
use std::future::Future;
@@ -280,7 +282,7 @@ impl From<(SendStream, RecvStream)> for QuicStream {
pub struct NatDstQuicConnector {
pub(crate) endpoint: Endpoint,
pub(crate) peer_mgr: Weak<PeerManager>,
pub(crate) conn_map: Cache<PeerId, quinn::Connection>,
pub(crate) conn_map: Cache<PeerId, Connection>,
}
#[async_trait::async_trait]
@@ -291,20 +293,25 @@ impl NatDstConnector for NatDstQuicConnector {
&self,
src: SocketAddr,
nat_dst: SocketAddr,
) -> crate::common::error::Result<Self::DstStream> {
let Some(peer_mgr) = self.peer_mgr.upgrade() else {
return Err(anyhow::anyhow!("peer manager is not available").into());
) -> anyhow::Result<Self::DstStream> {
let peer_mgr = self
.peer_mgr
.upgrade()
.ok_or_else(|| anyhow!("peer manager is not available"))?;
let dst_peer = {
let SocketAddr::V4(addr) = nat_dst else {
bail!("ipv6 is not supported");
};
peer_mgr
.get_peer_map()
.get_peer_id_by_ipv4(addr.ip())
.await
.ok_or_else(|| anyhow!("no peer found for nat dst: {}", nat_dst))?
};
let Some(dst_peer_id) = (match nat_dst {
SocketAddr::V4(addr) => peer_mgr.get_peer_map().get_peer_id_by_ipv4(addr.ip()).await,
SocketAddr::V6(_) => return Err(anyhow::anyhow!("ipv6 is not supported").into()),
}) else {
return Err(anyhow::anyhow!("no peer found for nat dst: {}", nat_dst).into());
};
tracing::trace!(?nat_dst, ?dst_peer, "quic nat");
trace!("quic nat dst: {:?}, dst peers: {:?}", nat_dst, dst_peer_id);
let addr = QuicAddr::new(dst_peer_id, PacketType::QuicSrc).into();
let header = {
let conn_data = QuicConnData {
src: Some(src.into()),
@@ -312,77 +319,91 @@ impl NatDstConnector for NatDstQuicConnector {
};
let len = conn_data.encoded_len();
if len > (u16::MAX as usize) {
return Err(anyhow!("conn data too large: {:?}", len).into());
}
ensure!(len <= u16::MAX as usize, "conn data too large: {len}");
let mut buf = BytesMut::with_capacity(2 + len);
buf.put_u16(len as u16);
conn_data.encode(&mut buf).unwrap();
conn_data.encode(&mut buf)?;
buf.freeze()
};
for attempt in 0..2 {
let endpoint = self.endpoint.clone();
let reconnect = || async move {
self.conn_map.invalidate(&dst_peer).await;
let connection = match self
.conn_map
.try_get_with(dst_peer_id, async move {
endpoint
.connect(addr, "")
.map_err(|e| anyhow!("quic connect: {:#}", e))?
.await
.map_err(|e| anyhow!("quic connection: {:#}", e))
})
.await
{
Ok(conn) => conn,
Err(e) => {
if attempt == 0 {
debug!("quic connect failed, retrying: {:#}", e);
tokio::time::sleep(Duration::from_millis(300)).await;
continue;
let connect = (0..5)
.map(|_| {
let endpoint = self.endpoint.clone();
async move {
endpoint
.connect(QuicAddr::new(dst_peer, PacketType::QuicSrc).into(), "")
.context("failed to create connection")?
.await
.context("connection failed")
}
return Err(anyhow!("{:#}", e).into());
}
};
})
.hedge(Duration::from_millis(200));
let stream: Result<QuicStreamInner, anyhow::Error> = async {
self.conn_map
.try_get_with(dst_peer, connect)
.await
.context("failed to connect to peer")
};
let mut reconnected = false;
let mut connection = if let Some(connection) = self.conn_map.get(&dst_peer).await
&& connection.close_reason().is_none()
{
connection
} else {
reconnected = true;
reconnect().await?
};
loop {
let is_retryable = |error: &ConnectionError| {
matches!(
error,
ConnectionError::ConnectionClosed(_)
| ConnectionError::ApplicationClosed(_)
| ConnectionError::Reset
| ConnectionError::TimedOut
)
};
let mut retry = !reconnected;
let header = header.clone();
let result = async {
let mut stream: QuicStream = connection
.open_bi()
.await
.map_err(|e| anyhow!("open bi: {:#}", e))?
.inspect_err(|error| retry &= is_retryable(error))?
.into();
stream.writer_mut().write_chunk(header.clone()).await?;
stream
.writer_mut()
.write_chunk(header)
.await
.inspect_err(|error| {
retry &= matches!(error, WriteError::ConnectionLost(error) if is_retryable(error))
})?;
Ok(stream.into())
}
.await;
.await;
match stream {
Ok(stream) => return Ok(stream),
Err(error) => {
debug!(
?dst_peer_id,
attempt,
?error,
"quic connect: stream setup failed"
);
if let Err(error) = &result {
if retry {
debug!(?error, "failed to open quic stream, retrying...");
reconnected = true;
connection = reconnect().await?;
continue;
} else {
self.conn_map.invalidate(&dst_peer).await;
}
}
// Evict stale connection;
self.conn_map.invalidate(&dst_peer_id).await;
break result;
}
Err(anyhow!(
"quic connect: failed after {} attempts, dst_peer_id={}, nat_dst={}",
2,
dst_peer_id,
nat_dst
)
.into())
}
#[inline]
@@ -839,7 +860,7 @@ impl QuicProxy {
Arc::new(socket),
default_runtime().unwrap(),
)
.unwrap();
.unwrap(); // TODO: maybe a different transport config
endpoint.set_default_client_config(client_config());
self.endpoint = Some(endpoint.clone());
@@ -863,26 +884,15 @@ impl QuicProxy {
return;
}
let conn_map = Cache::builder()
.max_capacity(u8::MAX.into()) // same with max_concurrent_bidi_streams, can be increased
.time_to_idle(Duration::from_secs(600))
.build();
let conn_map_bg = conn_map.clone();
self.tasks.spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
conn_map_bg.run_pending_tasks().await;
}
});
let tcp_proxy = TcpProxyForQuicSrc(TcpProxy::new(
peer_mgr.clone(),
NatDstQuicConnector {
endpoint: endpoint.clone(),
peer_mgr: Arc::downgrade(&peer_mgr),
conn_map,
conn_map: Cache::builder()
.max_capacity(u8::MAX.into()) // cf. quinn transport config (max_concurrent_bidi_streams)
.time_to_idle(Duration::from_secs(600)) // cf. quinn transport config (max_idle_timeout)
.build(),
},
));
+1 -1
View File
@@ -240,7 +240,7 @@ impl AsyncTcpConnector for Socks5KcpConnector {
let ret = c
.connect(self.src_addr, addr)
.await
.map_err(|e| super::fast_socks5::SocksError::Other(e.into()))?;
.map_err(super::fast_socks5::SocksError::Other)?;
Ok(SocksTcpStream::Kcp(ret))
}
}
+8 -9
View File
@@ -44,7 +44,7 @@ use super::tokio_smoltcp::{self, Net, NetConfig, channel_device};
pub(crate) trait NatDstConnector: Send + Sync + Clone + 'static {
type DstStream: AsyncRead + AsyncWrite + Unpin + Send;
async fn connect(&self, src: SocketAddr, dst: SocketAddr) -> Result<Self::DstStream>;
async fn connect(&self, src: SocketAddr, dst: SocketAddr) -> anyhow::Result<Self::DstStream>;
fn check_packet_from_peer_fast(&self, cidr_set: &CidrSet, global_ctx: &GlobalCtx) -> bool;
fn check_packet_from_peer(
&self,
@@ -63,14 +63,13 @@ pub struct NatDstTcpConnector;
#[async_trait::async_trait]
impl NatDstConnector for NatDstTcpConnector {
type DstStream = TcpStream;
async fn connect(&self, _src: SocketAddr, nat_dst: SocketAddr) -> Result<Self::DstStream> {
let socket = match TcpSocket::new_v4() {
Ok(s) => s,
Err(error) => {
log::error!(?error, "create v4 socket failed");
return Err(error.into());
}
};
async fn connect(
&self,
_src: SocketAddr,
nat_dst: SocketAddr,
) -> anyhow::Result<Self::DstStream> {
let socket = TcpSocket::new_v4()
.inspect_err(|error| log::error!(?error, "create v4 socket failed"))?;
let stream = timeout(Duration::from_secs(10), socket.connect(nat_dst))
.await?
+5 -17
View File
@@ -1,6 +1,6 @@
use delegate::delegate;
use derivative::Derivative;
use derive_more::{Deref, DerefMut, From, IntoIterator};
use derive_more::{AsMut, AsRef, Deref, DerefMut, From, IntoIterator};
use prost::Message;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
@@ -45,11 +45,15 @@ where
From,
Deref,
DerefMut,
AsRef,
AsMut,
Serialize,
Deserialize,
IntoIterator,
)]
#[derivative(Default(bound = ""))]
#[as_ref(forward)]
#[as_mut(forward)]
#[serde(transparent)]
#[into_iterator(owned, ref, ref_mut)]
pub struct RepeatedMessageModel<Model>(Vec<Model>);
@@ -74,22 +78,6 @@ impl<Model> Extend<Model> for RepeatedMessageModel<Model> {
}
}
impl<Model> AsRef<[Model]> for RepeatedMessageModel<Model> {
delegate! {
to self.0 {
fn as_ref(&self) -> &[Model];
}
}
}
impl<Model> AsMut<[Model]> for RepeatedMessageModel<Model> {
delegate! {
to self.0 {
fn as_mut(&mut self) -> &mut [Model];
}
}
}
impl<'m, Message, Model> TryFrom<&'m [Message]> for RepeatedMessageModel<Model>
where
Message: prost::Message,
+58
View File
@@ -0,0 +1,58 @@
use delegate::delegate;
use derivative::Derivative;
use derive_more::{AsMut, AsRef, Deref, DerefMut, From, Into, IntoIterator};
use std::fmt;
use std::fmt::Display;
use thiserror::Error;
/// A list of errors of type `E` that is itself usable as a single error value.
///
/// The wrapper holds a plain `Vec<E>`. `Default` is derived with an empty
/// bound (`bound = ""`), so it does not require `E: Default`. The `Display`
/// implementation in this module prints the collected errors as a numbered
/// list.
#[derive(Derivative, Debug, From, Into, Deref, DerefMut, AsRef, AsMut, IntoIterator, Error)]
#[derivative(Default(bound = ""))]
#[as_ref(forward)]
#[as_mut(forward)]
#[into_iterator(owned, ref, ref_mut)]
pub struct ErrorCollection<E> {
/// The collected errors, in the order they were added.
pub errors: Vec<E>,
}
// Constructors for `ErrorCollection`, delegated to the associated functions
// of `Vec` that carry the same names.
impl<E> ErrorCollection<E> {
delegate! {
to Vec {
// Delegates to `Vec::new`. Per the `delegate` crate, `#[into]` applies
// `.into()` to the delegated call's result, turning the `Vec` into `Self`.
#[into]
pub fn new() -> Self;
// Delegates to `Vec::with_capacity(capacity)`, converted the same way.
#[into]
pub fn with_capacity(capacity: usize) -> Self;
}
}
}
/// Builds an [`ErrorCollection`] from any iterator whose items convert into `E`.
impl<E, Item: Into<E>> FromIterator<Item> for ErrorCollection<E> {
    /// Converts every item with `Into<E>` and stores the results in iteration order.
    fn from_iter<I: IntoIterator<Item = Item>>(iter: I) -> Self {
        let errors: Vec<E> = iter.into_iter().map(|item| item.into()).collect();
        Self { errors }
    }
}
impl<E> Extend<E> for ErrorCollection<E> {
delegate! {
to self.errors {
fn extend<T: IntoIterator<Item = E>>(&mut self, iter: T);
}
}
}
/// Renders the collection as a summary line followed by one numbered line per error.
impl<E: Display> Display for ErrorCollection<E> {
    /// Writes `No errors` for an empty collection; otherwise writes the error
    /// count and then each error on its own line, numbered from 1.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let total = self.errors.len();
        if total == 0 {
            write!(f, "No errors")
        } else {
            write!(f, "{} error(s) occurred:", total)?;
            // Pair each error with its 1-based position in the list.
            for (ordinal, err) in (1usize..).zip(&self.errors) {
                writeln!(f)?;
                write!(f, " {}. {}", ordinal, err)?;
            }
            Ok(())
        }
    }
}
+1
View File
@@ -1,3 +1,4 @@
pub mod error;
pub mod panic;
pub mod string;
pub mod task;
+62
View File
@@ -1,9 +1,13 @@
use crate::utils::error::ErrorCollection;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::task::JoinHandle;
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
@@ -78,3 +82,61 @@ impl<Output> Future for CancellableTask<Output> {
}
// endregion
// region HedgeExt
/// Extension trait that adds hedged execution to any iterator of fallible
/// futures.
///
/// Each item of the iterator is one attempt at the same operation. The
/// blanket implementation in this module starts the first attempt right away
/// and starts a further attempt whenever `delay` passes without any attempt
/// finishing, or as soon as an attempt fails, until the iterator runs out.
pub(crate) trait HedgeExt: Iterator + Sized {
/// Drives the attempts and resolves to the first `Ok` value produced.
///
/// `delay` is how long to wait for a result before launching the next
/// attempt alongside the ones already running.
///
/// Returns `Err` with every attempt's error collected once all attempts
/// have failed and the iterator has no more attempts to offer. An empty
/// iterator therefore yields an empty `ErrorCollection`.
async fn hedge<T, E>(self, delay: Duration) -> Result<T, ErrorCollection<E>>
where
Self::Item: Future<Output = Result<T, E>>;
}
/// Blanket implementation: every iterator gains `hedge`.
impl<I> HedgeExt for I
where
    I: Iterator,
{
    /// Runs the iterator's futures as hedged attempts and returns the first success.
    ///
    /// Attempts are polled together from a `FuturesUnordered` inside this
    /// future. A new attempt is pulled from the iterator when `delay` elapses
    /// without a result, or right after an attempt fails. When every attempt
    /// has failed and the iterator is drained, all errors are returned.
    async fn hedge<T, E>(mut self, delay: Duration) -> Result<T, ErrorCollection<E>>
    where
        Self::Item: Future<Output = Result<T, E>>,
    {
        let mut in_flight = FuturesUnordered::new();
        let mut failures = ErrorCollection::new();

        // Kick off the first attempt. `more` records whether the iterator may
        // still have attempts left to hand out.
        let mut more = match self.next() {
            Some(attempt) => {
                in_flight.push(attempt);
                true
            }
            None => false,
        };

        while !in_flight.is_empty() {
            // Each branch reports whether another attempt should be launched now.
            let launch_next = tokio::select! {
                outcome = in_flight.next() => match outcome {
                    Some(Ok(value)) => return Ok(value),
                    Some(Err(error)) => {
                        failures.push(error);
                        // Replace the failed attempt only if the iterator is not drained.
                        more
                    }
                    // The loop condition guarantees at least one pending attempt.
                    None => unreachable!(),
                },
                // Nothing finished within `delay`: hedge with one more attempt.
                _ = sleep(delay), if more => true,
            };

            if launch_next {
                match self.next() {
                    Some(attempt) => in_flight.push(attempt),
                    None => more = false,
                }
            }
        }

        Err(failures)
    }
}
// endregion