mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-14 01:45:46 +00:00
refactor: introduce HedgeExt for task hedging; rewrite NatDstQuicConnector (#2229)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
use std::{io, result};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::tunnel;
|
||||
@@ -55,4 +54,6 @@ pub enum Error {
|
||||
|
||||
pub type Result<T> = result::Result<T, Error>;
|
||||
|
||||
pub type ErrorCollection = crate::utils::error::ErrorCollection<Error>;
|
||||
|
||||
// impl From for std::
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::{Context, anyhow, bail};
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use guarden::defer;
|
||||
@@ -15,12 +15,13 @@ use kcp_sys::{
|
||||
stream::KcpStream,
|
||||
};
|
||||
use prost::Message;
|
||||
use tokio::{select, task::JoinSet};
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use super::{
|
||||
CidrSet,
|
||||
tcp_proxy::{NatDstConnector, NatDstTcpConnector, TcpProxy},
|
||||
};
|
||||
use crate::utils::task::HedgeExt;
|
||||
use crate::{
|
||||
common::{
|
||||
acl_processor::PacketInfo,
|
||||
@@ -114,72 +115,57 @@ pub struct NatDstKcpConnector {
|
||||
impl NatDstConnector for NatDstKcpConnector {
|
||||
type DstStream = KcpStream;
|
||||
|
||||
async fn connect(&self, src: SocketAddr, nat_dst: SocketAddr) -> Result<Self::DstStream> {
|
||||
async fn connect(
|
||||
&self,
|
||||
src: SocketAddr,
|
||||
nat_dst: SocketAddr,
|
||||
) -> anyhow::Result<Self::DstStream> {
|
||||
let peer_mgr = self
|
||||
.peer_mgr
|
||||
.upgrade()
|
||||
.ok_or_else(|| anyhow!("peer manager is not available"))?;
|
||||
|
||||
let dst_peer = {
|
||||
let SocketAddr::V4(addr) = nat_dst else {
|
||||
bail!("ipv6 is not supported");
|
||||
};
|
||||
peer_mgr
|
||||
.get_peer_map()
|
||||
.get_peer_id_by_ipv4(addr.ip())
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("no peer found for nat dst: {}", nat_dst))?
|
||||
};
|
||||
|
||||
tracing::trace!(?nat_dst, ?dst_peer, "kcp nat");
|
||||
|
||||
let conn_data = KcpConnData {
|
||||
src: Some(src.into()),
|
||||
dst: Some(nat_dst.into()),
|
||||
};
|
||||
|
||||
let Some(peer_mgr) = self.peer_mgr.upgrade() else {
|
||||
return Err(anyhow::anyhow!("peer manager is not available").into());
|
||||
};
|
||||
let stream = (0..5)
|
||||
.map(|_| {
|
||||
let kcp_endpoint = self.kcp_endpoint.clone();
|
||||
let my_peer_id = peer_mgr.my_peer_id();
|
||||
|
||||
let dst_peer_id = match nat_dst {
|
||||
SocketAddr::V4(addr) => peer_mgr.get_peer_map().get_peer_id_by_ipv4(addr.ip()).await,
|
||||
SocketAddr::V6(_) => return Err(anyhow::anyhow!("ipv6 is not supported").into()),
|
||||
};
|
||||
async move {
|
||||
let conn_id = kcp_endpoint
|
||||
.connect(
|
||||
Duration::from_secs(10),
|
||||
my_peer_id,
|
||||
dst_peer,
|
||||
Bytes::from(conn_data.encode_to_vec()),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let Some(dst_peer) = dst_peer_id else {
|
||||
return Err(anyhow::anyhow!("no peer found for nat dst: {}", nat_dst).into());
|
||||
};
|
||||
|
||||
tracing::trace!("kcp nat dst: {:?}, dst peers: {:?}", nat_dst, dst_peer);
|
||||
|
||||
let mut connect_tasks: JoinSet<std::result::Result<ConnId, anyhow::Error>> = JoinSet::new();
|
||||
let mut retry_remain = 5;
|
||||
loop {
|
||||
select! {
|
||||
Some(Ok(Ok(ret))) = connect_tasks.join_next() => {
|
||||
// just wait for the previous connection to finish
|
||||
let stream = KcpStream::new(&self.kcp_endpoint, ret)
|
||||
.ok_or(anyhow::anyhow!("failed to create kcp stream"))?;
|
||||
return Ok(stream);
|
||||
KcpStream::new(&kcp_endpoint, conn_id).context("failed to create kcp stream")
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_millis(200)), if !connect_tasks.is_empty() && retry_remain > 0 => {
|
||||
// no successful connection yet, trigger another connection attempt
|
||||
}
|
||||
else => {
|
||||
// got error in connect_tasks, continue to retry
|
||||
if retry_remain == 0 && connect_tasks.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.hedge(Duration::from_millis(200))
|
||||
.await
|
||||
.context("failed to connect to peer")?;
|
||||
|
||||
// create a new connection task
|
||||
if retry_remain == 0 {
|
||||
continue;
|
||||
}
|
||||
retry_remain -= 1;
|
||||
|
||||
let kcp_endpoint = self.kcp_endpoint.clone();
|
||||
let my_peer_id = peer_mgr.my_peer_id();
|
||||
let conn_data_clone = conn_data;
|
||||
|
||||
connect_tasks.spawn(async move {
|
||||
kcp_endpoint
|
||||
.connect(
|
||||
Duration::from_secs(10),
|
||||
my_peer_id,
|
||||
dst_peer,
|
||||
Bytes::from(conn_data_clone.encode_to_vec()),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("failed to connect to nat dst: {}", nat_dst))
|
||||
});
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!("failed to connect to nat dst: {}", nat_dst).into())
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn check_packet_from_peer_fast(&self, _cidr_set: &CidrSet, _global_ctx: &GlobalCtx) -> bool {
|
||||
|
||||
@@ -18,7 +18,8 @@ use crate::tunnel::packet_def::{
|
||||
PacketType, PeerManagerHeader, TAIL_RESERVED_SIZE, ZCPacket, ZCPacketType,
|
||||
};
|
||||
use crate::tunnel::quic::{client_config, endpoint_config, server_config};
|
||||
use anyhow::{Context, Error, anyhow};
|
||||
use crate::utils::task::HedgeExt;
|
||||
use anyhow::{Context, Error, anyhow, bail, ensure};
|
||||
use atomic_refcell::AtomicRefCell;
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use dashmap::DashMap;
|
||||
@@ -29,7 +30,8 @@ use moka::future::Cache;
|
||||
use prost::Message;
|
||||
use quinn::udp::{EcnCodepoint, RecvMeta, Transmit};
|
||||
use quinn::{
|
||||
AsyncUdpSocket, Endpoint, RecvStream, SendStream, StreamId, UdpPoller, default_runtime,
|
||||
AsyncUdpSocket, Connection, ConnectionError, Endpoint, RecvStream, SendStream, StreamId,
|
||||
UdpPoller, WriteError, default_runtime,
|
||||
};
|
||||
use std::cmp::min;
|
||||
use std::future::Future;
|
||||
@@ -280,7 +282,7 @@ impl From<(SendStream, RecvStream)> for QuicStream {
|
||||
pub struct NatDstQuicConnector {
|
||||
pub(crate) endpoint: Endpoint,
|
||||
pub(crate) peer_mgr: Weak<PeerManager>,
|
||||
pub(crate) conn_map: Cache<PeerId, quinn::Connection>,
|
||||
pub(crate) conn_map: Cache<PeerId, Connection>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -291,20 +293,25 @@ impl NatDstConnector for NatDstQuicConnector {
|
||||
&self,
|
||||
src: SocketAddr,
|
||||
nat_dst: SocketAddr,
|
||||
) -> crate::common::error::Result<Self::DstStream> {
|
||||
let Some(peer_mgr) = self.peer_mgr.upgrade() else {
|
||||
return Err(anyhow::anyhow!("peer manager is not available").into());
|
||||
) -> anyhow::Result<Self::DstStream> {
|
||||
let peer_mgr = self
|
||||
.peer_mgr
|
||||
.upgrade()
|
||||
.ok_or_else(|| anyhow!("peer manager is not available"))?;
|
||||
|
||||
let dst_peer = {
|
||||
let SocketAddr::V4(addr) = nat_dst else {
|
||||
bail!("ipv6 is not supported");
|
||||
};
|
||||
peer_mgr
|
||||
.get_peer_map()
|
||||
.get_peer_id_by_ipv4(addr.ip())
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("no peer found for nat dst: {}", nat_dst))?
|
||||
};
|
||||
|
||||
let Some(dst_peer_id) = (match nat_dst {
|
||||
SocketAddr::V4(addr) => peer_mgr.get_peer_map().get_peer_id_by_ipv4(addr.ip()).await,
|
||||
SocketAddr::V6(_) => return Err(anyhow::anyhow!("ipv6 is not supported").into()),
|
||||
}) else {
|
||||
return Err(anyhow::anyhow!("no peer found for nat dst: {}", nat_dst).into());
|
||||
};
|
||||
tracing::trace!(?nat_dst, ?dst_peer, "quic nat");
|
||||
|
||||
trace!("quic nat dst: {:?}, dst peers: {:?}", nat_dst, dst_peer_id);
|
||||
let addr = QuicAddr::new(dst_peer_id, PacketType::QuicSrc).into();
|
||||
let header = {
|
||||
let conn_data = QuicConnData {
|
||||
src: Some(src.into()),
|
||||
@@ -312,77 +319,91 @@ impl NatDstConnector for NatDstQuicConnector {
|
||||
};
|
||||
|
||||
let len = conn_data.encoded_len();
|
||||
if len > (u16::MAX as usize) {
|
||||
return Err(anyhow!("conn data too large: {:?}", len).into());
|
||||
}
|
||||
ensure!(len <= u16::MAX as usize, "conn data too large: {len}");
|
||||
|
||||
let mut buf = BytesMut::with_capacity(2 + len);
|
||||
|
||||
buf.put_u16(len as u16);
|
||||
conn_data.encode(&mut buf).unwrap();
|
||||
conn_data.encode(&mut buf)?;
|
||||
|
||||
buf.freeze()
|
||||
};
|
||||
|
||||
for attempt in 0..2 {
|
||||
let endpoint = self.endpoint.clone();
|
||||
let reconnect = || async move {
|
||||
self.conn_map.invalidate(&dst_peer).await;
|
||||
|
||||
let connection = match self
|
||||
.conn_map
|
||||
.try_get_with(dst_peer_id, async move {
|
||||
endpoint
|
||||
.connect(addr, "")
|
||||
.map_err(|e| anyhow!("quic connect: {:#}", e))?
|
||||
.await
|
||||
.map_err(|e| anyhow!("quic connection: {:#}", e))
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
if attempt == 0 {
|
||||
debug!("quic connect failed, retrying: {:#}", e);
|
||||
tokio::time::sleep(Duration::from_millis(300)).await;
|
||||
continue;
|
||||
let connect = (0..5)
|
||||
.map(|_| {
|
||||
let endpoint = self.endpoint.clone();
|
||||
async move {
|
||||
endpoint
|
||||
.connect(QuicAddr::new(dst_peer, PacketType::QuicSrc).into(), "")
|
||||
.context("failed to create connection")?
|
||||
.await
|
||||
.context("connection failed")
|
||||
}
|
||||
return Err(anyhow!("{:#}", e).into());
|
||||
}
|
||||
};
|
||||
})
|
||||
.hedge(Duration::from_millis(200));
|
||||
|
||||
let stream: Result<QuicStreamInner, anyhow::Error> = async {
|
||||
self.conn_map
|
||||
.try_get_with(dst_peer, connect)
|
||||
.await
|
||||
.context("failed to connect to peer")
|
||||
};
|
||||
|
||||
let mut reconnected = false;
|
||||
|
||||
let mut connection = if let Some(connection) = self.conn_map.get(&dst_peer).await
|
||||
&& connection.close_reason().is_none()
|
||||
{
|
||||
connection
|
||||
} else {
|
||||
reconnected = true;
|
||||
reconnect().await?
|
||||
};
|
||||
|
||||
loop {
|
||||
let is_retryable = |error: &ConnectionError| {
|
||||
matches!(
|
||||
error,
|
||||
ConnectionError::ConnectionClosed(_)
|
||||
| ConnectionError::ApplicationClosed(_)
|
||||
| ConnectionError::Reset
|
||||
| ConnectionError::TimedOut
|
||||
)
|
||||
};
|
||||
let mut retry = !reconnected;
|
||||
let header = header.clone();
|
||||
let result = async {
|
||||
let mut stream: QuicStream = connection
|
||||
.open_bi()
|
||||
.await
|
||||
.map_err(|e| anyhow!("open bi: {:#}", e))?
|
||||
.inspect_err(|error| retry &= is_retryable(error))?
|
||||
.into();
|
||||
stream.writer_mut().write_chunk(header.clone()).await?;
|
||||
stream
|
||||
.writer_mut()
|
||||
.write_chunk(header)
|
||||
.await
|
||||
.inspect_err(|error| {
|
||||
retry &= matches!(error, WriteError::ConnectionLost(error) if is_retryable(error))
|
||||
})?;
|
||||
Ok(stream.into())
|
||||
}
|
||||
.await;
|
||||
.await;
|
||||
|
||||
match stream {
|
||||
Ok(stream) => return Ok(stream),
|
||||
Err(error) => {
|
||||
debug!(
|
||||
?dst_peer_id,
|
||||
attempt,
|
||||
?error,
|
||||
"quic connect: stream setup failed"
|
||||
);
|
||||
if let Err(error) = &result {
|
||||
if retry {
|
||||
debug!(?error, "failed to open quic stream, retrying...");
|
||||
reconnected = true;
|
||||
connection = reconnect().await?;
|
||||
continue;
|
||||
} else {
|
||||
self.conn_map.invalidate(&dst_peer).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Evict stale connection;
|
||||
self.conn_map.invalidate(&dst_peer_id).await;
|
||||
break result;
|
||||
}
|
||||
|
||||
Err(anyhow!(
|
||||
"quic connect: failed after {} attempts, dst_peer_id={}, nat_dst={}",
|
||||
2,
|
||||
dst_peer_id,
|
||||
nat_dst
|
||||
)
|
||||
.into())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@@ -839,7 +860,7 @@ impl QuicProxy {
|
||||
Arc::new(socket),
|
||||
default_runtime().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
.unwrap(); // TODO: maybe a different transport config
|
||||
endpoint.set_default_client_config(client_config());
|
||||
self.endpoint = Some(endpoint.clone());
|
||||
|
||||
@@ -863,26 +884,15 @@ impl QuicProxy {
|
||||
return;
|
||||
}
|
||||
|
||||
let conn_map = Cache::builder()
|
||||
.max_capacity(u8::MAX.into()) // same with max_concurrent_bidi_streams, can be increased
|
||||
.time_to_idle(Duration::from_secs(600))
|
||||
.build();
|
||||
|
||||
let conn_map_bg = conn_map.clone();
|
||||
self.tasks.spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(60));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
conn_map_bg.run_pending_tasks().await;
|
||||
}
|
||||
});
|
||||
|
||||
let tcp_proxy = TcpProxyForQuicSrc(TcpProxy::new(
|
||||
peer_mgr.clone(),
|
||||
NatDstQuicConnector {
|
||||
endpoint: endpoint.clone(),
|
||||
peer_mgr: Arc::downgrade(&peer_mgr),
|
||||
conn_map,
|
||||
conn_map: Cache::builder()
|
||||
.max_capacity(u8::MAX.into()) // cf. quinn transport config (max_concurrent_bidi_streams)
|
||||
.time_to_idle(Duration::from_secs(600)) // cf. quinn transport config (max_idle_timeout)
|
||||
.build(),
|
||||
},
|
||||
));
|
||||
|
||||
|
||||
@@ -240,7 +240,7 @@ impl AsyncTcpConnector for Socks5KcpConnector {
|
||||
let ret = c
|
||||
.connect(self.src_addr, addr)
|
||||
.await
|
||||
.map_err(|e| super::fast_socks5::SocksError::Other(e.into()))?;
|
||||
.map_err(super::fast_socks5::SocksError::Other)?;
|
||||
Ok(SocksTcpStream::Kcp(ret))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ use super::tokio_smoltcp::{self, Net, NetConfig, channel_device};
|
||||
pub(crate) trait NatDstConnector: Send + Sync + Clone + 'static {
|
||||
type DstStream: AsyncRead + AsyncWrite + Unpin + Send;
|
||||
|
||||
async fn connect(&self, src: SocketAddr, dst: SocketAddr) -> Result<Self::DstStream>;
|
||||
async fn connect(&self, src: SocketAddr, dst: SocketAddr) -> anyhow::Result<Self::DstStream>;
|
||||
fn check_packet_from_peer_fast(&self, cidr_set: &CidrSet, global_ctx: &GlobalCtx) -> bool;
|
||||
fn check_packet_from_peer(
|
||||
&self,
|
||||
@@ -63,14 +63,13 @@ pub struct NatDstTcpConnector;
|
||||
#[async_trait::async_trait]
|
||||
impl NatDstConnector for NatDstTcpConnector {
|
||||
type DstStream = TcpStream;
|
||||
async fn connect(&self, _src: SocketAddr, nat_dst: SocketAddr) -> Result<Self::DstStream> {
|
||||
let socket = match TcpSocket::new_v4() {
|
||||
Ok(s) => s,
|
||||
Err(error) => {
|
||||
log::error!(?error, "create v4 socket failed");
|
||||
return Err(error.into());
|
||||
}
|
||||
};
|
||||
async fn connect(
|
||||
&self,
|
||||
_src: SocketAddr,
|
||||
nat_dst: SocketAddr,
|
||||
) -> anyhow::Result<Self::DstStream> {
|
||||
let socket = TcpSocket::new_v4()
|
||||
.inspect_err(|error| log::error!(?error, "create v4 socket failed"))?;
|
||||
|
||||
let stream = timeout(Duration::from_secs(10), socket.connect(nat_dst))
|
||||
.await?
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use delegate::delegate;
|
||||
use derivative::Derivative;
|
||||
use derive_more::{Deref, DerefMut, From, IntoIterator};
|
||||
use derive_more::{AsMut, AsRef, Deref, DerefMut, From, IntoIterator};
|
||||
use prost::Message;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
@@ -45,11 +45,15 @@ where
|
||||
From,
|
||||
Deref,
|
||||
DerefMut,
|
||||
AsRef,
|
||||
AsMut,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
IntoIterator,
|
||||
)]
|
||||
#[derivative(Default(bound = ""))]
|
||||
#[as_ref(forward)]
|
||||
#[as_mut(forward)]
|
||||
#[serde(transparent)]
|
||||
#[into_iterator(owned, ref, ref_mut)]
|
||||
pub struct RepeatedMessageModel<Model>(Vec<Model>);
|
||||
@@ -74,22 +78,6 @@ impl<Model> Extend<Model> for RepeatedMessageModel<Model> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<Model> AsRef<[Model]> for RepeatedMessageModel<Model> {
|
||||
delegate! {
|
||||
to self.0 {
|
||||
fn as_ref(&self) -> &[Model];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Model> AsMut<[Model]> for RepeatedMessageModel<Model> {
|
||||
delegate! {
|
||||
to self.0 {
|
||||
fn as_mut(&mut self) -> &mut [Model];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'m, Message, Model> TryFrom<&'m [Message]> for RepeatedMessageModel<Model>
|
||||
where
|
||||
Message: prost::Message,
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
use delegate::delegate;
|
||||
use derivative::Derivative;
|
||||
use derive_more::{AsMut, AsRef, Deref, DerefMut, From, Into, IntoIterator};
|
||||
use std::fmt;
|
||||
use std::fmt::Display;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Derivative, Debug, From, Into, Deref, DerefMut, AsRef, AsMut, IntoIterator, Error)]
|
||||
#[derivative(Default(bound = ""))]
|
||||
#[as_ref(forward)]
|
||||
#[as_mut(forward)]
|
||||
#[into_iterator(owned, ref, ref_mut)]
|
||||
pub struct ErrorCollection<E> {
|
||||
pub errors: Vec<E>,
|
||||
}
|
||||
|
||||
impl<E> ErrorCollection<E> {
|
||||
delegate! {
|
||||
to Vec {
|
||||
#[into]
|
||||
pub fn new() -> Self;
|
||||
#[into]
|
||||
pub fn with_capacity(capacity: usize) -> Self;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E, Item: Into<E>> FromIterator<Item> for ErrorCollection<E> {
|
||||
fn from_iter<I: IntoIterator<Item = Item>>(iter: I) -> Self {
|
||||
Self {
|
||||
errors: iter.into_iter().map(Into::into).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> Extend<E> for ErrorCollection<E> {
|
||||
delegate! {
|
||||
to self.errors {
|
||||
fn extend<T: IntoIterator<Item = E>>(&mut self, iter: T);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Display> Display for ErrorCollection<E> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
if self.errors.is_empty() {
|
||||
return write!(f, "No errors");
|
||||
}
|
||||
|
||||
write!(f, "{} error(s) occurred:", self.errors.len())?;
|
||||
for (i, err) in self.errors.iter().enumerate() {
|
||||
writeln!(f)?;
|
||||
write!(f, " {}. {}", i + 1, err)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod error;
|
||||
pub mod panic;
|
||||
pub mod string;
|
||||
pub mod task;
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
use crate::utils::error::ErrorCollection;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::sleep;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
|
||||
@@ -78,3 +82,61 @@ impl<Output> Future for CancellableTask<Output> {
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region HedgeExt
|
||||
|
||||
pub(crate) trait HedgeExt: Iterator + Sized {
|
||||
async fn hedge<T, E>(self, delay: Duration) -> Result<T, ErrorCollection<E>>
|
||||
where
|
||||
Self::Item: Future<Output = Result<T, E>>;
|
||||
}
|
||||
|
||||
impl<I> HedgeExt for I
|
||||
where
|
||||
I: Iterator,
|
||||
{
|
||||
async fn hedge<T, E>(mut self, delay: Duration) -> Result<T, ErrorCollection<E>>
|
||||
where
|
||||
Self::Item: Future<Output = Result<T, E>>,
|
||||
{
|
||||
let mut tasks = FuturesUnordered::new();
|
||||
let mut errors = ErrorCollection::new();
|
||||
let mut exhausted = false;
|
||||
|
||||
macro_rules! spawn {
|
||||
() => {
|
||||
if let Some(fut) = self.next() {
|
||||
tasks.push(fut);
|
||||
} else {
|
||||
exhausted = true;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
spawn!();
|
||||
|
||||
while !tasks.is_empty() {
|
||||
tokio::select! {
|
||||
res = tasks.next() => {
|
||||
match res {
|
||||
Some(Ok(v)) => return Ok(v),
|
||||
Some(Err(e)) => errors.push(e),
|
||||
None => unreachable!(),
|
||||
}
|
||||
|
||||
if !exhausted {
|
||||
spawn!();
|
||||
}
|
||||
}
|
||||
|
||||
_ = sleep(delay), if !exhausted => {
|
||||
spawn!();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(errors)
|
||||
}
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
Reference in New Issue
Block a user