mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 10:14:35 +00:00
some minor bug fixs (#41)
* fix joinset leak; * fix udp packet format * fix trace log panic * avoid waiting after listener accept
This commit is contained in:
@@ -1,3 +1,11 @@
|
|||||||
|
use std::{
|
||||||
|
fmt::Debug,
|
||||||
|
future,
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
};
|
||||||
|
use tokio::task::JoinSet;
|
||||||
|
use tracing::Instrument;
|
||||||
|
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod constants;
|
pub mod constants;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
@@ -30,3 +38,83 @@ pub type PeerId = u32;
|
|||||||
pub fn new_peer_id() -> PeerId {
|
pub fn new_peer_id() -> PeerId {
|
||||||
rand::random()
|
rand::random()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn join_joinset_background<T: Debug + Send + Sync + 'static>(
|
||||||
|
js: Arc<Mutex<JoinSet<T>>>,
|
||||||
|
origin: String,
|
||||||
|
) {
|
||||||
|
let js = Arc::downgrade(&js);
|
||||||
|
tokio::spawn(
|
||||||
|
async move {
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||||
|
if js.weak_count() == 0 {
|
||||||
|
tracing::info!("joinset task exit");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
future::poll_fn(|cx| {
|
||||||
|
tracing::info!("try join joinset tasks");
|
||||||
|
let Some(js) = js.upgrade() else {
|
||||||
|
return std::task::Poll::Ready(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut js = js.lock().unwrap();
|
||||||
|
while !js.is_empty() {
|
||||||
|
let ret = js.poll_join_next(cx);
|
||||||
|
if ret.is_pending() {
|
||||||
|
return std::task::Poll::Pending;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::task::Poll::Ready(())
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.instrument(tracing::info_span!(
|
||||||
|
"join_joinset_background",
|
||||||
|
origin = origin
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_join_joinset_backgroud() {
|
||||||
|
let js = Arc::new(Mutex::new(JoinSet::<()>::new()));
|
||||||
|
join_joinset_background(js.clone(), "TEST".to_owned());
|
||||||
|
js.try_lock().unwrap().spawn(async {
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||||
|
});
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||||
|
assert!(js.try_lock().unwrap().is_empty());
|
||||||
|
|
||||||
|
for _ in 0..5 {
|
||||||
|
js.try_lock().unwrap().spawn(async {
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||||
|
});
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||||
|
|
||||||
|
for _ in 0..5 {
|
||||||
|
js.try_lock().unwrap().spawn(async {
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||||
|
});
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||||
|
assert!(js.try_lock().unwrap().is_empty());
|
||||||
|
|
||||||
|
let weak_js = Arc::downgrade(&js);
|
||||||
|
drop(js);
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||||
|
assert_eq!(weak_js.weak_count(), 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ use tracing::Instrument;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{
|
common::{
|
||||||
constants, error::Error, global_ctx::ArcGlobalCtx, rkyv_util::encode_to_bytes,
|
constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background,
|
||||||
stun::StunInfoCollectorTrait, PeerId,
|
rkyv_util::encode_to_bytes, stun::StunInfoCollectorTrait, PeerId,
|
||||||
},
|
},
|
||||||
peers::peer_manager::PeerManager,
|
peers::peer_manager::PeerManager,
|
||||||
rpc::NatType,
|
rpc::NatType,
|
||||||
@@ -75,9 +75,15 @@ impl UdpHolePunchListener {
|
|||||||
while let Ok(conn) = listener.accept().await {
|
while let Ok(conn) = listener.accept().await {
|
||||||
last_connected_time_clone.store(std::time::Instant::now());
|
last_connected_time_clone.store(std::time::Instant::now());
|
||||||
tracing::warn!(?conn, "udp hole punching listener got peer connection");
|
tracing::warn!(?conn, "udp hole punching listener got peer connection");
|
||||||
if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
|
let peer_mgr = peer_mgr.clone();
|
||||||
tracing::error!(?e, "failed to add tunnel as server in hole punch listener");
|
tokio::spawn(async move {
|
||||||
}
|
if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
|
||||||
|
tracing::error!(
|
||||||
|
?e,
|
||||||
|
"failed to add tunnel as server in hole punch listener"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
running_clone.store(false);
|
running_clone.store(false);
|
||||||
@@ -115,7 +121,7 @@ struct UdpHolePunchConnectorData {
|
|||||||
struct UdpHolePunchRpcServer {
|
struct UdpHolePunchRpcServer {
|
||||||
data: Arc<UdpHolePunchConnectorData>,
|
data: Arc<UdpHolePunchConnectorData>,
|
||||||
|
|
||||||
tasks: Arc<Mutex<JoinSet<()>>>,
|
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tarpc::server]
|
#[tarpc::server]
|
||||||
@@ -140,7 +146,7 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
|
|||||||
|| my_udp_nat_type == NatType::Restricted as i32
|
|| my_udp_nat_type == NatType::Restricted as i32
|
||||||
{
|
{
|
||||||
// send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second
|
// send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second
|
||||||
self.tasks.lock().await.spawn(async move {
|
self.tasks.lock().unwrap().spawn(async move {
|
||||||
for _ in 0..10 {
|
for _ in 0..10 {
|
||||||
tracing::info!(?local_mapped_addr, "sending hole punching packet");
|
tracing::info!(?local_mapped_addr, "sending hole punching packet");
|
||||||
// generate a 128 bytes vec with random data
|
// generate a 128 bytes vec with random data
|
||||||
@@ -164,10 +170,9 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
|
|||||||
|
|
||||||
impl UdpHolePunchRpcServer {
|
impl UdpHolePunchRpcServer {
|
||||||
pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
|
pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
|
||||||
Self {
|
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
|
||||||
data,
|
join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned());
|
||||||
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
Self { data, tasks }
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn select_listener(&self) -> Option<(Arc<UdpSocket>, SocketAddr)> {
|
async fn select_listener(&self) -> Option<(Arc<UdpSocket>, SocketAddr)> {
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ use tracing::Instrument;
|
|||||||
|
|
||||||
use crate::common::error::Result;
|
use crate::common::error::Result;
|
||||||
use crate::common::global_ctx::GlobalCtx;
|
use crate::common::global_ctx::GlobalCtx;
|
||||||
|
use crate::common::join_joinset_background;
|
||||||
use crate::common::netns::NetNS;
|
use crate::common::netns::NetNS;
|
||||||
use crate::peers::packet::{self, ArchivedPacket};
|
use crate::peers::packet::{self, ArchivedPacket};
|
||||||
use crate::peers::peer_manager::PeerManager;
|
use crate::peers::peer_manager::PeerManager;
|
||||||
@@ -71,7 +72,7 @@ pub struct TcpProxy {
|
|||||||
peer_manager: Arc<PeerManager>,
|
peer_manager: Arc<PeerManager>,
|
||||||
local_port: AtomicU16,
|
local_port: AtomicU16,
|
||||||
|
|
||||||
tasks: Arc<Mutex<JoinSet<()>>>,
|
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||||
|
|
||||||
syn_map: SynSockMap,
|
syn_map: SynSockMap,
|
||||||
conn_map: ConnSockMap,
|
conn_map: ConnSockMap,
|
||||||
@@ -215,7 +216,7 @@ impl TcpProxy {
|
|||||||
peer_manager,
|
peer_manager,
|
||||||
|
|
||||||
local_port: AtomicU16::new(0),
|
local_port: AtomicU16::new(0),
|
||||||
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
|
||||||
|
|
||||||
syn_map: Arc::new(DashMap::new()),
|
syn_map: Arc::new(DashMap::new()),
|
||||||
conn_map: Arc::new(DashMap::new()),
|
conn_map: Arc::new(DashMap::new()),
|
||||||
@@ -247,6 +248,7 @@ impl TcpProxy {
|
|||||||
self.peer_manager
|
self.peer_manager
|
||||||
.add_nic_packet_process_pipeline(Box::new(self.clone()))
|
.add_nic_packet_process_pipeline(Box::new(self.clone()))
|
||||||
.await;
|
.await;
|
||||||
|
join_joinset_background(self.tasks.clone(), "TcpProxy".to_owned());
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -268,7 +270,7 @@ impl TcpProxy {
|
|||||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
tasks.lock().await.spawn(syn_map_cleaner_task);
|
tasks.lock().unwrap().spawn(syn_map_cleaner_task);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -312,7 +314,7 @@ impl TcpProxy {
|
|||||||
let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone());
|
let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone());
|
||||||
assert!(old_nat_val.is_none());
|
assert!(old_nat_val.is_none());
|
||||||
|
|
||||||
tasks.lock().await.spawn(Self::connect_to_nat_dst(
|
tasks.lock().unwrap().spawn(Self::connect_to_nat_dst(
|
||||||
net_ns.clone(),
|
net_ns.clone(),
|
||||||
tcp_stream,
|
tcp_stream,
|
||||||
conn_map.clone(),
|
conn_map.clone(),
|
||||||
@@ -325,7 +327,7 @@ impl TcpProxy {
|
|||||||
};
|
};
|
||||||
self.tasks
|
self.tasks
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
.unwrap()
|
||||||
.spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener")));
|
.spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener")));
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -100,15 +100,19 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
|
|||||||
tunnel_info.remote_addr.clone(),
|
tunnel_info.remote_addr.clone(),
|
||||||
));
|
));
|
||||||
tracing::info!(ret = ?ret, "conn accepted");
|
tracing::info!(ret = ?ret, "conn accepted");
|
||||||
let server_ret = peer_manager.handle_tunnel(ret).await;
|
let peer_manager = peer_manager.clone();
|
||||||
if let Err(e) = &server_ret {
|
let global_ctx = global_ctx.clone();
|
||||||
global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
|
tokio::spawn(async move {
|
||||||
tunnel_info.local_addr,
|
let server_ret = peer_manager.handle_tunnel(ret).await;
|
||||||
tunnel_info.remote_addr,
|
if let Err(e) = &server_ret {
|
||||||
e.to_string(),
|
global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
|
||||||
));
|
tunnel_info.local_addr,
|
||||||
tracing::error!(error = ?e, "handle conn error");
|
tunnel_info.remote_addr,
|
||||||
}
|
e.to_string(),
|
||||||
|
));
|
||||||
|
tracing::error!(error = ?e, "handle conn error");
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+14
-1
@@ -99,7 +99,7 @@ pub enum PacketType {
|
|||||||
TaRpc = 6,
|
TaRpc = 6,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
#[derive(Archive, Deserialize, Serialize)]
|
||||||
#[archive(compare(PartialEq), check_bytes)]
|
#[archive(compare(PartialEq), check_bytes)]
|
||||||
// Derives can be passed through to the generated type:
|
// Derives can be passed through to the generated type:
|
||||||
pub struct Packet {
|
pub struct Packet {
|
||||||
@@ -109,6 +109,19 @@ pub struct Packet {
|
|||||||
pub payload: String,
|
pub payload: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for Packet {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
|
||||||
|
self.from_peer,
|
||||||
|
self.to_peer,
|
||||||
|
self.packet_type,
|
||||||
|
&self.payload.as_bytes()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for ArchivedPacket {
|
impl std::fmt::Debug for ArchivedPacket {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(
|
write!(
|
||||||
|
|||||||
+48
-23
@@ -13,7 +13,10 @@ use tokio_util::{
|
|||||||
use tracing::Instrument;
|
use tracing::Instrument;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::rkyv_util::{self, encode_to_bytes, vec_to_string},
|
common::{
|
||||||
|
join_joinset_background,
|
||||||
|
rkyv_util::{self, encode_to_bytes, vec_to_string},
|
||||||
|
},
|
||||||
rpc::TunnelInfo,
|
rpc::TunnelInfo,
|
||||||
tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
|
tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
|
||||||
};
|
};
|
||||||
@@ -27,7 +30,7 @@ use super::{
|
|||||||
|
|
||||||
pub const UDP_DATA_MTU: usize = 2500;
|
pub const UDP_DATA_MTU: usize = 2500;
|
||||||
|
|
||||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
#[derive(Archive, Deserialize, Serialize)]
|
||||||
#[archive(compare(PartialEq), check_bytes)]
|
#[archive(compare(PartialEq), check_bytes)]
|
||||||
// Derives can be passed through to the generated type:
|
// Derives can be passed through to the generated type:
|
||||||
pub enum UdpPacketPayload {
|
pub enum UdpPacketPayload {
|
||||||
@@ -37,14 +40,29 @@ pub enum UdpPacketPayload {
|
|||||||
Data(String),
|
Data(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for UdpPacketPayload {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
|
||||||
|
match self {
|
||||||
|
UdpPacketPayload::Syn => tmp.field("Syn", &"").finish(),
|
||||||
|
UdpPacketPayload::Sack => tmp.field("Sack", &"").finish(),
|
||||||
|
UdpPacketPayload::HolePunch(s) => tmp.field("HolePunch", &s.as_bytes()).finish(),
|
||||||
|
UdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||||
#[archive(compare(PartialEq), check_bytes)]
|
#[archive(compare(PartialEq), check_bytes)]
|
||||||
#[archive_attr(derive(Debug))]
|
#[archive_attr(derive(Debug))]
|
||||||
pub struct UdpPacket {
|
pub struct UdpPacket {
|
||||||
pub conn_id: u32,
|
pub conn_id: u32,
|
||||||
|
pub magic: u32,
|
||||||
pub payload: UdpPacketPayload,
|
pub payload: UdpPacketPayload,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const UDP_PACKET_MAGIC: u32 = 0x19941126;
|
||||||
|
|
||||||
impl std::fmt::Debug for ArchivedUdpPacketPayload {
|
impl std::fmt::Debug for ArchivedUdpPacketPayload {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
|
let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
|
||||||
@@ -63,6 +81,7 @@ impl UdpPacket {
|
|||||||
pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
|
pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
conn_id,
|
conn_id,
|
||||||
|
magic: UDP_PACKET_MAGIC,
|
||||||
payload: UdpPacketPayload::Data(vec_to_string(data)),
|
payload: UdpPacketPayload::Data(vec_to_string(data)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -70,6 +89,7 @@ impl UdpPacket {
|
|||||||
pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
|
pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
conn_id: 0,
|
conn_id: 0,
|
||||||
|
magic: UDP_PACKET_MAGIC,
|
||||||
payload: UdpPacketPayload::HolePunch(vec_to_string(data)),
|
payload: UdpPacketPayload::HolePunch(vec_to_string(data)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -77,6 +97,7 @@ impl UdpPacket {
|
|||||||
pub fn new_syn_packet(conn_id: u32) -> Self {
|
pub fn new_syn_packet(conn_id: u32) -> Self {
|
||||||
Self {
|
Self {
|
||||||
conn_id,
|
conn_id,
|
||||||
|
magic: UDP_PACKET_MAGIC,
|
||||||
payload: UdpPacketPayload::Syn,
|
payload: UdpPacketPayload::Syn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -84,6 +105,7 @@ impl UdpPacket {
|
|||||||
pub fn new_sack_packet(conn_id: u32) -> Self {
|
pub fn new_sack_packet(conn_id: u32) -> Self {
|
||||||
Self {
|
Self {
|
||||||
conn_id,
|
conn_id,
|
||||||
|
magic: UDP_PACKET_MAGIC,
|
||||||
payload: UdpPacketPayload::Sack,
|
payload: UdpPacketPayload::Sack,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -100,6 +122,11 @@ fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
|
|||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if udp_packet.magic != UDP_PACKET_MAGIC {
|
||||||
|
tracing::warn!(?udp_packet, "udp magic not match");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else {
|
let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else {
|
||||||
tracing::warn!(?udp_packet, "udp payload not data");
|
tracing::warn!(?udp_packet, "udp payload not data");
|
||||||
return None;
|
return None;
|
||||||
@@ -189,7 +216,7 @@ pub struct UdpTunnelListener {
|
|||||||
socket: Option<Arc<UdpSocket>>,
|
socket: Option<Arc<UdpSocket>>,
|
||||||
|
|
||||||
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
||||||
forward_tasks: Arc<Mutex<JoinSet<()>>>,
|
forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||||
|
|
||||||
conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>,
|
conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>,
|
||||||
conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>,
|
conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>,
|
||||||
@@ -202,7 +229,7 @@ impl UdpTunnelListener {
|
|||||||
addr,
|
addr,
|
||||||
socket: None,
|
socket: None,
|
||||||
sock_map: Arc::new(DashMap::new()),
|
sock_map: Arc::new(DashMap::new()),
|
||||||
forward_tasks: Arc::new(Mutex::new(JoinSet::new())),
|
forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
|
||||||
conn_recv,
|
conn_recv,
|
||||||
conn_send: Some(conn_send),
|
conn_send: Some(conn_send),
|
||||||
}
|
}
|
||||||
@@ -234,7 +261,7 @@ impl UdpTunnelListener {
|
|||||||
async fn handle_connect(
|
async fn handle_connect(
|
||||||
socket: Arc<UdpSocket>,
|
socket: Arc<UdpSocket>,
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
forward_tasks: Arc<Mutex<JoinSet<()>>>,
|
forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||||
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
||||||
local_url: url::Url,
|
local_url: url::Url,
|
||||||
conn_id: u32,
|
conn_id: u32,
|
||||||
@@ -251,7 +278,7 @@ impl UdpTunnelListener {
|
|||||||
let addr_copy = addr.clone();
|
let addr_copy = addr.clone();
|
||||||
sock_map.insert(addr, Arc::new(Mutex::new(ss_pair)));
|
sock_map.insert(addr, Arc::new(Mutex::new(ss_pair)));
|
||||||
let ctunnel_stream = ctunnel.pin_stream();
|
let ctunnel_stream = ctunnel.pin_stream();
|
||||||
forward_tasks.lock().await.spawn(async move {
|
forward_tasks.lock().unwrap().spawn(async move {
|
||||||
let ret = ctunnel_stream
|
let ret = ctunnel_stream
|
||||||
.map(|v| {
|
.map(|v| {
|
||||||
tracing::trace!(?v, "udp stream recv something in forward task");
|
tracing::trace!(?v, "udp stream recv something in forward task");
|
||||||
@@ -304,7 +331,7 @@ impl TunnelListener for UdpTunnelListener {
|
|||||||
let sock_map = self.sock_map.clone();
|
let sock_map = self.sock_map.clone();
|
||||||
let conn_send = self.conn_send.take().unwrap();
|
let conn_send = self.conn_send.take().unwrap();
|
||||||
let local_url = self.local_url().clone();
|
let local_url = self.local_url().clone();
|
||||||
self.forward_tasks.lock().await.spawn(
|
self.forward_tasks.lock().unwrap().spawn(
|
||||||
async move {
|
async move {
|
||||||
loop {
|
loop {
|
||||||
let mut buf = BytesMut::new();
|
let mut buf = BytesMut::new();
|
||||||
@@ -323,6 +350,11 @@ impl TunnelListener for UdpTunnelListener {
|
|||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if udp_packet.magic != UDP_PACKET_MAGIC {
|
||||||
|
tracing::info!(?udp_packet, "udp magic not match");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
|
if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
|
||||||
let Ok(conn) = Self::handle_connect(
|
let Ok(conn) = Self::handle_connect(
|
||||||
socket.clone(),
|
socket.clone(),
|
||||||
@@ -350,22 +382,7 @@ impl TunnelListener for UdpTunnelListener {
|
|||||||
.instrument(tracing::info_span!("udp forward task", ?self.socket)),
|
.instrument(tracing::info_span!("udp forward task", ?self.socket)),
|
||||||
);
|
);
|
||||||
|
|
||||||
// let forward_tasks_clone = self.forward_tasks.clone();
|
join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned());
|
||||||
// tokio::spawn(async move {
|
|
||||||
// loop {
|
|
||||||
// let mut locked_forward_tasks = forward_tasks_clone.lock().await;
|
|
||||||
// tokio::select! {
|
|
||||||
// ret = locked_forward_tasks.join_next() => {
|
|
||||||
// tracing::warn!(?ret, "udp forward task exit");
|
|
||||||
// }
|
|
||||||
// else => {
|
|
||||||
// drop(locked_forward_tasks);
|
|
||||||
// tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
|
||||||
// continue;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -453,6 +470,14 @@ impl UdpTunnelConnector {
|
|||||||
)));
|
)));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if udp_packet.magic != UDP_PACKET_MAGIC {
|
||||||
|
tracing::info!(?udp_packet, "udp magic not match");
|
||||||
|
return Err(super::TunnelError::ConnectError(format!(
|
||||||
|
"udp connect error, magic not match. magic: {:?}",
|
||||||
|
udp_packet.magic
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
if conn_id != udp_packet.conn_id {
|
if conn_id != udp_packet.conn_id {
|
||||||
return Err(super::TunnelError::ConnectError(format!(
|
return Err(super::TunnelError::ConnectError(format!(
|
||||||
"udp connect error, conn id not match. conn_id: {:?}, {:?}",
|
"udp connect error, conn id not match. conn_id: {:?}, {:?}",
|
||||||
|
|||||||
Reference in New Issue
Block a user