mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-06 17:59:11 +00:00
@@ -39,6 +39,7 @@ pub fn gen_default_flags() -> Flags {
|
||||
disable_p2p: false,
|
||||
p2p_only: false,
|
||||
relay_all_peer_rpc: false,
|
||||
disable_tcp_hole_punching: false,
|
||||
disable_udp_hole_punching: false,
|
||||
multi_thread: true,
|
||||
data_compress_algo: CompressionAlgoPb::None.into(),
|
||||
|
||||
+571
-42
@@ -1,5 +1,5 @@
|
||||
use std::collections::BTreeSet;
|
||||
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
@@ -9,6 +9,8 @@ use anyhow::Context;
|
||||
use chrono::Local;
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use rand::seq::IteratorRandom;
|
||||
use socket2::{SockAddr, SockRef};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{lookup_host, UdpSocket};
|
||||
use tokio::sync::{broadcast, Mutex};
|
||||
use tokio::task::JoinSet;
|
||||
@@ -375,16 +377,28 @@ impl StunClientBuilder {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UdpNatTypeDetectResult {
|
||||
pub enum StunTransport {
|
||||
Udp,
|
||||
Tcp,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StunNatTypeDetectResult {
|
||||
transport: StunTransport,
|
||||
source_addr: SocketAddr,
|
||||
stun_resps: Vec<BindRequestResponse>,
|
||||
// if we are easy symmetric nat, we need to test with another port to check inc or dec
|
||||
extra_bind_test: Option<BindRequestResponse>,
|
||||
}
|
||||
|
||||
impl UdpNatTypeDetectResult {
|
||||
fn new(source_addr: SocketAddr, stun_resps: Vec<BindRequestResponse>) -> Self {
|
||||
impl StunNatTypeDetectResult {
|
||||
fn new(
|
||||
transport: StunTransport,
|
||||
source_addr: SocketAddr,
|
||||
stun_resps: Vec<BindRequestResponse>,
|
||||
) -> Self {
|
||||
Self {
|
||||
transport,
|
||||
source_addr,
|
||||
stun_resps,
|
||||
extra_bind_test: None,
|
||||
@@ -447,7 +461,7 @@ impl UdpNatTypeDetectResult {
|
||||
mapped_addr_count == 1
|
||||
}
|
||||
|
||||
pub fn nat_type(&self) -> NatType {
|
||||
fn nat_type_udp(&self) -> NatType {
|
||||
if self.stun_server_count() < 2 {
|
||||
return NatType::Unknown;
|
||||
}
|
||||
@@ -498,6 +512,33 @@ impl UdpNatTypeDetectResult {
|
||||
}
|
||||
}
|
||||
|
||||
fn nat_type_tcp(&self) -> NatType {
|
||||
if self.is_open_internet() {
|
||||
return NatType::OpenInternet;
|
||||
}
|
||||
|
||||
if self.stun_server_count() < 2 || self.stun_resps.is_empty() {
|
||||
return NatType::Unknown;
|
||||
}
|
||||
|
||||
if self.is_cone() {
|
||||
if self.is_pat() {
|
||||
NatType::NoPat
|
||||
} else {
|
||||
NatType::FullCone
|
||||
}
|
||||
} else {
|
||||
NatType::Symmetric
|
||||
}
|
||||
}
|
||||
|
||||
pub fn nat_type(&self) -> NatType {
|
||||
match self.transport {
|
||||
StunTransport::Udp => self.nat_type_udp(),
|
||||
StunTransport::Tcp => self.nat_type_tcp(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn public_ips(&self) -> Vec<IpAddr> {
|
||||
self.stun_resps
|
||||
.iter()
|
||||
@@ -521,7 +562,7 @@ impl UdpNatTypeDetectResult {
|
||||
self.source_addr
|
||||
}
|
||||
|
||||
pub fn extend_result(&mut self, other: UdpNatTypeDetectResult) {
|
||||
pub fn extend_result(&mut self, other: StunNatTypeDetectResult) {
|
||||
self.stun_resps.extend(other.stun_resps);
|
||||
}
|
||||
|
||||
@@ -575,7 +616,10 @@ impl UdpNatTypeDetector {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn detect_nat_type(&self, source_port: u16) -> Result<UdpNatTypeDetectResult, Error> {
|
||||
pub async fn detect_nat_type(
|
||||
&self,
|
||||
source_port: u16,
|
||||
) -> Result<StunNatTypeDetectResult, Error> {
|
||||
let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?);
|
||||
self.detect_nat_type_with_socket(udp).await
|
||||
}
|
||||
@@ -584,7 +628,7 @@ impl UdpNatTypeDetector {
|
||||
pub async fn detect_nat_type_with_socket(
|
||||
&self,
|
||||
udp: Arc<UdpSocket>,
|
||||
) -> Result<UdpNatTypeDetectResult, Error> {
|
||||
) -> Result<StunNatTypeDetectResult, Error> {
|
||||
let mut stun_servers = vec![];
|
||||
let mut host_resolver = HostResolverIter::new(
|
||||
self.stun_server_hosts.clone(),
|
||||
@@ -623,7 +667,241 @@ impl UdpNatTypeDetector {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(UdpNatTypeDetectResult::new(udp.local_addr()?, bind_resps))
|
||||
Ok(StunNatTypeDetectResult::new(
|
||||
StunTransport::Udp,
|
||||
udp.local_addr()?,
|
||||
bind_resps,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct TcpStunClient {
|
||||
stun_server: SocketAddr,
|
||||
conn_timeout: Duration,
|
||||
io_timeout: Duration,
|
||||
source_port: u16,
|
||||
}
|
||||
|
||||
impl TcpStunClient {
|
||||
pub fn new(stun_server: SocketAddr, source_port: u16) -> Self {
|
||||
Self {
|
||||
stun_server,
|
||||
conn_timeout: Duration::from_millis(1500),
|
||||
io_timeout: Duration::from_millis(3000),
|
||||
source_port,
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_mapped_addr(msg: &Message<Attribute>) -> Option<SocketAddr> {
|
||||
let mut mapped_addr = None;
|
||||
for x in msg.attributes() {
|
||||
match x {
|
||||
Attribute::MappedAddress(addr) => {
|
||||
if mapped_addr.is_none() {
|
||||
let _ = mapped_addr.insert(addr.address());
|
||||
}
|
||||
}
|
||||
Attribute::XorMappedAddress(addr) => {
|
||||
if mapped_addr.is_none() {
|
||||
let _ = mapped_addr.insert(addr.address());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
mapped_addr
|
||||
}
|
||||
|
||||
fn message_size_from_header(header: &[u8; 20]) -> Result<usize, Error> {
|
||||
if (header[0] & 0b1100_0000) != 0 {
|
||||
return Err(Error::MessageDecodeError(
|
||||
"invalid stun message type".to_string(),
|
||||
));
|
||||
}
|
||||
let msg_len = u16::from_be_bytes([header[2], header[3]]) as usize;
|
||||
if !msg_len.is_multiple_of(4) {
|
||||
return Err(Error::MessageDecodeError(
|
||||
"invalid stun message length".to_string(),
|
||||
));
|
||||
}
|
||||
let total = 20usize
|
||||
.checked_add(msg_len)
|
||||
.ok_or_else(|| Error::MessageDecodeError("invalid stun message size".to_string()))?;
|
||||
if total > 4096 {
|
||||
return Err(Error::MessageDecodeError(
|
||||
"stun message too large".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(total)
|
||||
}
|
||||
|
||||
async fn tcp_read_stun_message(
|
||||
stream: &mut tokio::net::TcpStream,
|
||||
timeout: Duration,
|
||||
) -> Result<Message<Attribute>, Error> {
|
||||
let mut header = [0u8; 20];
|
||||
tokio::time::timeout(timeout, stream.read_exact(&mut header)).await??;
|
||||
let total_size = Self::message_size_from_header(&header)?;
|
||||
let mut buf = vec![0u8; total_size];
|
||||
buf[..20].copy_from_slice(&header);
|
||||
if total_size > 20 {
|
||||
tokio::time::timeout(timeout, stream.read_exact(&mut buf[20..])).await??;
|
||||
}
|
||||
|
||||
let mut decoder = MessageDecoder::<Attribute>::new();
|
||||
let Ok(msg) = decoder
|
||||
.decode_from_bytes(&buf)
|
||||
.with_context(|| "decode tcp stun message")?
|
||||
else {
|
||||
return Err(Error::MessageDecodeError(
|
||||
"invalid stun message".to_string(),
|
||||
));
|
||||
};
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
async fn connect(&self) -> Result<tokio::net::TcpStream, Error> {
|
||||
let bind_addr = match self.stun_server {
|
||||
SocketAddr::V4(_) => {
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), self.source_port)
|
||||
}
|
||||
SocketAddr::V6(_) => {
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), self.source_port)
|
||||
}
|
||||
};
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(self.stun_server),
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
|
||||
if bind_addr.is_ipv6() {
|
||||
socket2_socket.set_only_v6(true)?;
|
||||
}
|
||||
|
||||
socket2_socket.set_nonblocking(true)?;
|
||||
socket2_socket.set_reuse_address(true)?;
|
||||
|
||||
#[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||
{
|
||||
let _ = socket2_socket.set_reuse_port(true);
|
||||
}
|
||||
|
||||
socket2_socket.bind(&SockAddr::from(bind_addr))?;
|
||||
|
||||
let socket = tokio::net::TcpSocket::from_std_stream(socket2_socket.into());
|
||||
let stream =
|
||||
tokio::time::timeout(self.conn_timeout, socket.connect(self.stun_server)).await??;
|
||||
|
||||
let _ = SockRef::from(&stream).set_linger(Some(Duration::ZERO));
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
#[tracing::instrument(ret, level = Level::TRACE)]
|
||||
pub async fn bind_request(self) -> Result<BindRequestResponse, Error> {
|
||||
let mut tids = vec![];
|
||||
|
||||
let mut stream = self.connect().await?;
|
||||
let local_addr = stream.local_addr()?;
|
||||
let stun_host = self.stun_server;
|
||||
|
||||
let tid = rand::random::<u32>();
|
||||
let message = Message::<Attribute>::new(MessageClass::Request, BINDING, u32_to_tid(tid));
|
||||
let mut encoder = MessageEncoder::new();
|
||||
let msg = encoder
|
||||
.encode_into_bytes(message.clone())
|
||||
.with_context(|| "encode tcp stun message")?;
|
||||
tids.push(tid);
|
||||
tokio::time::timeout(self.io_timeout, stream.write_all(msg.as_slice())).await??;
|
||||
|
||||
let now = Instant::now();
|
||||
let msg = Self::tcp_read_stun_message(&mut stream, self.io_timeout).await?;
|
||||
if msg.class() != MessageClass::SuccessResponse
|
||||
|| msg.method() != BINDING
|
||||
|| !tids.contains(&tid_to_u32(&msg.transaction_id()))
|
||||
{
|
||||
return Err(Error::MessageDecodeError(
|
||||
"unexpected stun response".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(BindRequestResponse {
|
||||
local_addr,
|
||||
stun_server_addr: stun_host,
|
||||
recv_from_addr: stun_host,
|
||||
mapped_socket_addr: Self::extract_mapped_addr(&msg),
|
||||
changed_socket_addr: None,
|
||||
change_ip: false,
|
||||
change_port: false,
|
||||
real_ip_changed: false,
|
||||
real_port_changed: false,
|
||||
latency_us: now.elapsed().as_micros() as u32,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TcpNatTypeDetector {
|
||||
stun_server_hosts: Vec<String>,
|
||||
max_ip_per_domain: u32,
|
||||
}
|
||||
|
||||
impl TcpNatTypeDetector {
|
||||
pub fn new(stun_server_hosts: Vec<String>, max_ip_per_domain: u32) -> Self {
|
||||
Self {
|
||||
stun_server_hosts,
|
||||
max_ip_per_domain,
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub async fn detect_nat_type(
|
||||
&self,
|
||||
source_port: u16,
|
||||
) -> Result<StunNatTypeDetectResult, Error> {
|
||||
let mut stun_servers = vec![];
|
||||
let mut host_resolver = HostResolverIter::new(
|
||||
self.stun_server_hosts.clone(),
|
||||
self.max_ip_per_domain,
|
||||
false,
|
||||
);
|
||||
while let Some(addr) = host_resolver.next().await {
|
||||
stun_servers.push(addr);
|
||||
}
|
||||
|
||||
let mut bind_resps = vec![];
|
||||
let mut source_addr = None;
|
||||
let mut selected_source_port = if source_port == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(source_port)
|
||||
};
|
||||
for server in stun_servers.iter() {
|
||||
let resp = TcpStunClient::new(*server, selected_source_port.unwrap_or(0))
|
||||
.bind_request()
|
||||
.await;
|
||||
if let Ok(resp) = resp {
|
||||
if selected_source_port.is_none() {
|
||||
selected_source_port = Some(resp.local_addr.port());
|
||||
}
|
||||
source_addr.get_or_insert(resp.local_addr);
|
||||
bind_resps.push(resp);
|
||||
if bind_resps.len() >= 3 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let Some(source_addr) = source_addr else {
|
||||
return Err(Error::NotFound);
|
||||
};
|
||||
Ok(StunNatTypeDetectResult::new(
|
||||
StunTransport::Tcp,
|
||||
source_addr,
|
||||
bind_resps,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -632,12 +910,15 @@ impl UdpNatTypeDetector {
|
||||
pub trait StunInfoCollectorTrait: Send + Sync {
|
||||
fn get_stun_info(&self) -> StunInfo;
|
||||
async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error>;
|
||||
async fn get_tcp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error>;
|
||||
}
|
||||
|
||||
pub struct StunInfoCollector {
|
||||
stun_servers: Arc<RwLock<Vec<String>>>,
|
||||
tcp_stun_servers: Arc<RwLock<Vec<String>>>,
|
||||
stun_servers_v6: Arc<RwLock<Vec<String>>>,
|
||||
udp_nat_test_result: Arc<RwLock<Option<UdpNatTypeDetectResult>>>,
|
||||
udp_nat_test_result: Arc<RwLock<Option<StunNatTypeDetectResult>>>,
|
||||
tcp_nat_test_result: Arc<RwLock<Option<StunNatTypeDetectResult>>>,
|
||||
public_ipv6: Arc<AtomicCell<Option<Ipv6Addr>>>,
|
||||
nat_test_result_time: Arc<AtomicCell<chrono::DateTime<Local>>>,
|
||||
redetect_notify: Arc<tokio::sync::Notify>,
|
||||
@@ -650,21 +931,44 @@ impl StunInfoCollectorTrait for StunInfoCollector {
|
||||
fn get_stun_info(&self) -> StunInfo {
|
||||
self.start_stun_routine();
|
||||
|
||||
let Some(result) = self.udp_nat_test_result.read().unwrap().clone() else {
|
||||
let udp_result = self.udp_nat_test_result.read().unwrap().clone();
|
||||
let tcp_result = self.tcp_nat_test_result.read().unwrap().clone();
|
||||
if udp_result.is_none() && tcp_result.is_none() {
|
||||
return Default::default();
|
||||
};
|
||||
}
|
||||
|
||||
let mut public_ip = BTreeSet::<String>::new();
|
||||
if let Some(result) = &udp_result {
|
||||
public_ip.extend(result.public_ips().into_iter().map(|x| x.to_string()));
|
||||
}
|
||||
if let Some(result) = &tcp_result {
|
||||
public_ip.extend(result.public_ips().into_iter().map(|x| x.to_string()));
|
||||
}
|
||||
if let Some(v6) = self.public_ipv6.load() {
|
||||
public_ip.insert(v6.to_string());
|
||||
}
|
||||
|
||||
StunInfo {
|
||||
udp_nat_type: result.nat_type() as i32,
|
||||
tcp_nat_type: 0,
|
||||
udp_nat_type: udp_result
|
||||
.as_ref()
|
||||
.map(|x| x.nat_type() as i32)
|
||||
.unwrap_or(NatType::Unknown as i32),
|
||||
tcp_nat_type: tcp_result
|
||||
.as_ref()
|
||||
.map(|x| x.nat_type() as i32)
|
||||
.unwrap_or(NatType::Unknown as i32),
|
||||
last_update_time: self.nat_test_result_time.load().timestamp(),
|
||||
public_ip: result
|
||||
.public_ips()
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.chain(self.public_ipv6.load().map(|x| x.to_string()))
|
||||
.collect(),
|
||||
min_port: result.min_port() as u32,
|
||||
max_port: result.max_port() as u32,
|
||||
public_ip: public_ip.into_iter().collect(),
|
||||
min_port: udp_result
|
||||
.as_ref()
|
||||
.map(|x| x.min_port() as u32)
|
||||
.or_else(|| tcp_result.as_ref().map(|x| x.min_port() as u32))
|
||||
.unwrap_or(0),
|
||||
max_port: udp_result
|
||||
.as_ref()
|
||||
.map(|x| x.max_port() as u32)
|
||||
.or_else(|| tcp_result.as_ref().map(|x| x.max_port() as u32))
|
||||
.unwrap_or(0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -715,14 +1019,60 @@ impl StunInfoCollectorTrait for StunInfoCollector {
|
||||
|
||||
Err(Error::NotFound)
|
||||
}
|
||||
|
||||
async fn get_tcp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error> {
|
||||
self.start_stun_routine();
|
||||
|
||||
let mut stun_servers = self
|
||||
.tcp_nat_test_result
|
||||
.read()
|
||||
.unwrap()
|
||||
.clone()
|
||||
.map(|x| x.collect_available_stun_server())
|
||||
.unwrap_or_default();
|
||||
|
||||
if stun_servers.is_empty() {
|
||||
let mut host_resolver =
|
||||
HostResolverIter::new(self.tcp_stun_servers.read().unwrap().clone(), 2, false);
|
||||
while let Some(addr) = host_resolver.next().await {
|
||||
stun_servers.push(addr);
|
||||
if stun_servers.len() >= 2 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if stun_servers.is_empty() {
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
|
||||
for server in stun_servers.iter() {
|
||||
let Ok(ret) = TcpStunClient::new(*server, local_port).bind_request().await else {
|
||||
tracing::warn!(?server, "tcp stun bind request failed");
|
||||
continue;
|
||||
};
|
||||
|
||||
if let Some(mapped_addr) = ret.mapped_socket_addr {
|
||||
return Ok(mapped_addr);
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::NotFound)
|
||||
}
|
||||
}
|
||||
|
||||
impl StunInfoCollector {
|
||||
pub fn new(stun_servers: Vec<String>, stun_servers_v6: Vec<String>) -> Self {
|
||||
pub fn new(
|
||||
udp_stun_servers: Vec<String>,
|
||||
tcp_stun_servers: Vec<String>,
|
||||
stun_servers_v6: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
stun_servers: Arc::new(RwLock::new(stun_servers)),
|
||||
stun_servers: Arc::new(RwLock::new(udp_stun_servers)),
|
||||
tcp_stun_servers: Arc::new(RwLock::new(tcp_stun_servers)),
|
||||
stun_servers_v6: Arc::new(RwLock::new(stun_servers_v6)),
|
||||
udp_nat_test_result: Arc::new(RwLock::new(None)),
|
||||
tcp_nat_test_result: Arc::new(RwLock::new(None)),
|
||||
public_ipv6: Arc::new(AtomicCell::new(None)),
|
||||
nat_test_result_time: Arc::new(AtomicCell::new(Local::now())),
|
||||
redetect_notify: Arc::new(tokio::sync::Notify::new()),
|
||||
@@ -732,7 +1082,11 @@ impl StunInfoCollector {
|
||||
}
|
||||
|
||||
pub fn new_with_default_servers() -> Self {
|
||||
Self::new(Self::get_default_servers(), Self::get_default_servers_v6())
|
||||
Self::new(
|
||||
Self::get_default_servers(),
|
||||
Self::get_default_tcp_servers(),
|
||||
Self::get_default_servers_v6(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn set_stun_servers(&self, stun_servers: Vec<String>) {
|
||||
@@ -745,6 +1099,11 @@ impl StunInfoCollector {
|
||||
*g = stun_servers_v6;
|
||||
}
|
||||
|
||||
pub fn set_tcp_stun_servers(&self, stun_servers: Vec<String>) {
|
||||
let mut g = self.tcp_stun_servers.write().unwrap();
|
||||
*g = stun_servers;
|
||||
}
|
||||
|
||||
pub fn get_default_servers() -> Vec<String> {
|
||||
// NOTICE: we may need to choose stun server based on geolocation
|
||||
// stun server cross nation may return an external ip address with high latency and loss rate
|
||||
@@ -759,6 +1118,21 @@ impl StunInfoCollector {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_default_tcp_servers() -> Vec<String> {
|
||||
[
|
||||
"stun.hot-chilli.net",
|
||||
"stun.fitauto.ru",
|
||||
"fwa.lifesizecloud.com",
|
||||
"global.turn.twilio.com",
|
||||
"turn.cloudflare.com",
|
||||
"stun.voip.blackberry.com",
|
||||
"stun.radiojar.com",
|
||||
]
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_default_servers_v6() -> Vec<String> {
|
||||
["txt:stun-v6.easytier.cn"]
|
||||
.iter()
|
||||
@@ -794,35 +1168,35 @@ impl StunInfoCollector {
|
||||
|
||||
let stun_servers = self.stun_servers.clone();
|
||||
let udp_nat_test_result = self.udp_nat_test_result.clone();
|
||||
let udp_test_time = self.nat_test_result_time.clone();
|
||||
let nat_test_time = self.nat_test_result_time.clone();
|
||||
let redetect_notify = self.redetect_notify.clone();
|
||||
self.tasks.lock().unwrap().spawn(async move {
|
||||
loop {
|
||||
let servers = stun_servers.read().unwrap().clone();
|
||||
// use first three and random choose one from the rest
|
||||
let servers = servers
|
||||
let udp_servers = stun_servers.read().unwrap().clone();
|
||||
let udp_servers: Vec<String> = udp_servers
|
||||
.iter()
|
||||
.take(2)
|
||||
.chain(servers.iter().skip(2).choose(&mut rand::thread_rng()))
|
||||
.chain(udp_servers.iter().skip(2).choose(&mut rand::thread_rng()))
|
||||
.map(|x| x.to_string())
|
||||
.collect();
|
||||
let detector = UdpNatTypeDetector::new(servers, 1);
|
||||
let mut ret = detector.detect_nat_type(0).await;
|
||||
tracing::debug!(?ret, "finish udp nat type detect");
|
||||
|
||||
let udp_detector = UdpNatTypeDetector::new(udp_servers, 1);
|
||||
let mut udp_ret = udp_detector.detect_nat_type(0).await;
|
||||
tracing::debug!(?udp_ret, "finish udp nat type detect");
|
||||
|
||||
let mut nat_type = NatType::Unknown;
|
||||
if let Ok(resp) = &ret {
|
||||
if let Ok(resp) = &udp_ret {
|
||||
tracing::debug!(?resp, "got udp nat type detect result");
|
||||
nat_type = resp.nat_type();
|
||||
}
|
||||
|
||||
// if nat type is symmtric, detect with another port to gather more info
|
||||
if nat_type == NatType::Symmetric {
|
||||
let old_resp = ret.as_mut().unwrap();
|
||||
let old_resp = udp_ret.as_mut().unwrap();
|
||||
tracing::debug!(?old_resp, "start get extra bind result");
|
||||
let available_stun_servers = old_resp.collect_available_stun_server();
|
||||
for server in available_stun_servers.iter() {
|
||||
let ret = detector
|
||||
let ret = udp_detector
|
||||
.get_extra_bind_result(0, *server)
|
||||
.await
|
||||
.with_context(|| "get extra bind result failed");
|
||||
@@ -835,8 +1209,8 @@ impl StunInfoCollector {
|
||||
}
|
||||
|
||||
let mut sleep_sec = 10;
|
||||
if let Ok(resp) = &ret {
|
||||
udp_test_time.store(Local::now());
|
||||
if let Ok(resp) = &udp_ret {
|
||||
nat_test_time.store(Local::now());
|
||||
*udp_nat_test_result.write().unwrap() = Some(resp.clone());
|
||||
if nat_type != NatType::Unknown
|
||||
&& (nat_type != NatType::Symmetric || resp.extra_bind_test.is_some())
|
||||
@@ -852,6 +1226,40 @@ impl StunInfoCollector {
|
||||
}
|
||||
});
|
||||
|
||||
let tcp_stun_servers = self.tcp_stun_servers.clone();
|
||||
let tcp_nat_test_result = self.tcp_nat_test_result.clone();
|
||||
let nat_test_time = self.nat_test_result_time.clone();
|
||||
let redetect_notify = self.redetect_notify.clone();
|
||||
self.tasks.lock().unwrap().spawn(async move {
|
||||
loop {
|
||||
let tcp_servers = tcp_stun_servers.read().unwrap().clone();
|
||||
let tcp_servers: Vec<String> = tcp_servers
|
||||
.iter()
|
||||
.take(2)
|
||||
.chain(tcp_servers.iter().skip(2).choose(&mut rand::thread_rng()))
|
||||
.map(|x| x.to_string())
|
||||
.collect();
|
||||
|
||||
let tcp_detector = TcpNatTypeDetector::new(tcp_servers, 1);
|
||||
let tcp_ret = tcp_detector.detect_nat_type(0).await;
|
||||
tracing::debug!(?tcp_ret, "finish tcp nat type detect");
|
||||
|
||||
let mut sleep_sec = 10;
|
||||
if let Ok(resp) = &tcp_ret {
|
||||
nat_test_time.store(Local::now());
|
||||
*tcp_nat_test_result.write().unwrap() = Some(resp.clone());
|
||||
if resp.nat_type() != NatType::Unknown {
|
||||
sleep_sec = 600;
|
||||
}
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
_ = redetect_notify.notified() => {}
|
||||
_ = tokio::time::sleep(Duration::from_secs(sleep_sec)) => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// for ipv6
|
||||
let stun_servers = self.stun_servers_v6.clone();
|
||||
let stored_ipv6 = self.public_ipv6.clone();
|
||||
@@ -878,7 +1286,7 @@ impl StunInfoCollector {
|
||||
}
|
||||
|
||||
pub fn update_stun_info(&self) {
|
||||
self.redetect_notify.notify_one();
|
||||
self.redetect_notify.notify_waiters();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -905,6 +1313,13 @@ impl StunInfoCollectorTrait for MockStunInfoCollector {
|
||||
}
|
||||
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
|
||||
}
|
||||
|
||||
async fn get_tcp_port_mapping(&self, mut port: u16) -> Result<std::net::SocketAddr, Error> {
|
||||
if port == 0 {
|
||||
port = 40144;
|
||||
}
|
||||
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -962,9 +1377,9 @@ mod tests {
|
||||
let detector = UdpNatTypeDetector::new(stun_servers, 1);
|
||||
for _ in 0..5 {
|
||||
let ret = detector.detect_nat_type(0).await;
|
||||
println!("{:#?}, {:?}", ret, ret.as_ref().unwrap().nat_type());
|
||||
if ret.is_ok() {
|
||||
assert!(!ret.unwrap().stun_resps.is_empty());
|
||||
println!("{:#?}, {:?}", ret, ret.as_ref().map(|x| x.nat_type()));
|
||||
if let Ok(resp) = ret {
|
||||
assert!(!resp.stun_resps.is_empty());
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -974,6 +1389,120 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_public_tcp_stun_server_fitauto_ru() {
|
||||
let stun_servers = vec![
|
||||
"stun.fitauto.ru".to_string(),
|
||||
"stun.hot-chilli.net".to_string(),
|
||||
];
|
||||
let detector = TcpNatTypeDetector::new(stun_servers, 3);
|
||||
let ret = detector.detect_nat_type(0).await;
|
||||
println!("{:#?}, {:?}", ret, ret.as_ref().map(|x| x.nat_type()));
|
||||
if let Ok(resp) = ret {
|
||||
assert!(!resp.stun_resps.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_internal_tcp_stun_server_reuse_same_local_port() {
|
||||
use stun_codec::rfc5389::attributes::XorMappedAddress;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
async fn spawn_tcp_stun_server() -> SocketAddr {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = listener.local_addr().unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let (mut stream, peer_addr) = listener.accept().await.unwrap();
|
||||
|
||||
let req = TcpStunClient::tcp_read_stun_message(&mut stream, Duration::from_secs(2))
|
||||
.await
|
||||
.unwrap();
|
||||
let mut resp_msg = Message::<Attribute>::new(
|
||||
MessageClass::SuccessResponse,
|
||||
BINDING,
|
||||
req.transaction_id(),
|
||||
);
|
||||
resp_msg.add_attribute(Attribute::XorMappedAddress(XorMappedAddress::new(
|
||||
peer_addr,
|
||||
)));
|
||||
|
||||
let mut encoder = MessageEncoder::new();
|
||||
let rsp_buf = encoder.encode_into_bytes(resp_msg).unwrap();
|
||||
stream.write_all(rsp_buf.as_slice()).await.unwrap();
|
||||
});
|
||||
|
||||
server_addr
|
||||
}
|
||||
|
||||
let server1 = spawn_tcp_stun_server().await;
|
||||
let server2 = spawn_tcp_stun_server().await;
|
||||
|
||||
let stun_servers = vec![server1.to_string(), server2.to_string()];
|
||||
let detector = TcpNatTypeDetector::new(stun_servers, 1);
|
||||
|
||||
let ret = detector.detect_nat_type(0).await.unwrap();
|
||||
assert!(ret.stun_resps.len() >= 2);
|
||||
|
||||
let local_ports = ret
|
||||
.stun_resps
|
||||
.iter()
|
||||
.map(|x| x.local_addr.port())
|
||||
.collect::<BTreeSet<_>>();
|
||||
assert_eq!(local_ports.len(), 1);
|
||||
|
||||
let mapped_ports = ret
|
||||
.stun_resps
|
||||
.iter()
|
||||
.map(|x| x.mapped_socket_addr.unwrap().port())
|
||||
.collect::<BTreeSet<_>>();
|
||||
assert_eq!(mapped_ports.len(), 1);
|
||||
assert_eq!(
|
||||
local_ports.into_iter().next(),
|
||||
mapped_ports.into_iter().next()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stun_info_collector_tcp_port_mapping() {
|
||||
use stun_codec::rfc5389::attributes::XorMappedAddress;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = listener.local_addr().unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
for _ in 0..8 {
|
||||
let Ok((mut stream, peer_addr)) = listener.accept().await else {
|
||||
break;
|
||||
};
|
||||
|
||||
let req = TcpStunClient::tcp_read_stun_message(&mut stream, Duration::from_secs(2))
|
||||
.await
|
||||
.unwrap();
|
||||
let mut resp_msg = Message::<Attribute>::new(
|
||||
MessageClass::SuccessResponse,
|
||||
BINDING,
|
||||
req.transaction_id(),
|
||||
);
|
||||
resp_msg.add_attribute(Attribute::XorMappedAddress(XorMappedAddress::new(
|
||||
peer_addr,
|
||||
)));
|
||||
|
||||
let mut encoder = MessageEncoder::new();
|
||||
let rsp_buf = encoder.encode_into_bytes(resp_msg).unwrap();
|
||||
stream.write_all(rsp_buf.as_slice()).await.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let collector = StunInfoCollector::new(vec![], vec![server_addr.to_string()], vec![]);
|
||||
collector.set_tcp_stun_servers(vec![server_addr.to_string()]);
|
||||
let mapped = collector.get_tcp_port_mapping(0).await.unwrap();
|
||||
assert_eq!(mapped.ip(), IpAddr::V4(Ipv4Addr::LOCALHOST));
|
||||
assert!(mapped.port() > 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_v4_stun() {
|
||||
let mut udp_server = UdpTunnelListener::new("udp://0.0.0.0:55355".parse().unwrap());
|
||||
|
||||
@@ -35,7 +35,6 @@ use crate::{
|
||||
use_global_var,
|
||||
};
|
||||
|
||||
use crate::proto::api::instance::PeerConnInfo;
|
||||
use anyhow::Context;
|
||||
use rand::Rng;
|
||||
use tokio::{net::UdpSocket, task::JoinSet, time::timeout};
|
||||
@@ -51,7 +50,6 @@ static TESTING: AtomicBool = AtomicBool::new(false);
|
||||
#[async_trait::async_trait]
|
||||
pub trait PeerManagerForDirectConnector {
|
||||
async fn list_peers(&self) -> Vec<PeerId>;
|
||||
async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>>;
|
||||
fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager>;
|
||||
}
|
||||
|
||||
@@ -73,10 +71,6 @@ impl PeerManagerForDirectConnector for PeerManager {
|
||||
ret
|
||||
}
|
||||
|
||||
async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>> {
|
||||
self.get_peer_map().list_peer_conns(peer_id).await
|
||||
}
|
||||
|
||||
fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
|
||||
self.get_peer_rpc_mgr()
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ use crate::{
|
||||
|
||||
pub mod direct;
|
||||
pub mod manual;
|
||||
pub mod tcp_hole_punch;
|
||||
pub mod udp_hole_punch;
|
||||
|
||||
pub mod dns_connector;
|
||||
|
||||
@@ -0,0 +1,730 @@
|
||||
use std::{
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::{Context, Error};
|
||||
use rand::Rng as _;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use crate::{
|
||||
common::{join_joinset_background, stun::StunInfoCollectorTrait, PeerId},
|
||||
connector::udp_hole_punch::BackOff,
|
||||
peers::{
|
||||
peer_manager::PeerManager,
|
||||
peer_task::{PeerTaskLauncher, PeerTaskManager},
|
||||
},
|
||||
proto::{
|
||||
common::NatType,
|
||||
peer_rpc::{
|
||||
TcpHolePunchRequest, TcpHolePunchResponse, TcpHolePunchRpc,
|
||||
TcpHolePunchRpcClientFactory, TcpHolePunchRpcServer,
|
||||
},
|
||||
rpc_types::{self, controller::BaseController},
|
||||
},
|
||||
tunnel::{
|
||||
common::setup_sokcet2,
|
||||
tcp::{TcpTunnelConnector, TcpTunnelListener},
|
||||
TunnelConnector as _, TunnelListener as _,
|
||||
},
|
||||
};
|
||||
|
||||
pub const BLACKLIST_TIMEOUT_SEC: u64 = 3600;
|
||||
|
||||
fn handle_rpc_result<T>(
|
||||
ret: Result<T, rpc_types::error::Error>,
|
||||
dst_peer_id: PeerId,
|
||||
blacklist: &timedmap::TimedMap<PeerId, ()>,
|
||||
) -> Result<T, rpc_types::error::Error> {
|
||||
match ret {
|
||||
Ok(ret) => Ok(ret),
|
||||
Err(e) => {
|
||||
if matches!(e, rpc_types::error::Error::InvalidServiceKey(_, _)) {
|
||||
blacklist.insert(dst_peer_id, (), Duration::from_secs(BLACKLIST_TIMEOUT_SEC));
|
||||
}
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_symmetric_tcp_nat(nat_type: NatType) -> bool {
|
||||
matches!(
|
||||
nat_type,
|
||||
NatType::Symmetric | NatType::SymmetricEasyInc | NatType::SymmetricEasyDec
|
||||
)
|
||||
}
|
||||
|
||||
fn bind_addr_for_port(port: u16, is_v6: bool) -> SocketAddr {
|
||||
if is_v6 {
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port)
|
||||
} else {
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port)
|
||||
}
|
||||
}
|
||||
|
||||
async fn select_local_port(peer_mgr: &Arc<PeerManager>, is_v6: bool) -> Result<u16, Error> {
|
||||
let bind_addr = bind_addr_for_port(0, is_v6);
|
||||
tracing::trace!(?bind_addr, is_v6, "tcp hole punch select local port");
|
||||
let _g = peer_mgr.get_global_ctx().net_ns.guard();
|
||||
let listener = tokio::net::TcpListener::bind(bind_addr).await?;
|
||||
let port = listener.local_addr()?.port();
|
||||
tracing::debug!(?bind_addr, port, "tcp hole punch selected local port");
|
||||
Ok(port)
|
||||
}
|
||||
|
||||
async fn send_syn_from_port(
|
||||
peer_mgr: &Arc<PeerManager>,
|
||||
local_port: u16,
|
||||
dst: SocketAddr,
|
||||
) -> Result<(), Error> {
|
||||
let bind_addr = bind_addr_for_port(local_port, dst.is_ipv6());
|
||||
tracing::debug!(?bind_addr, ?dst, "tcp hole punch send syn");
|
||||
let _g = peer_mgr.get_global_ctx().net_ns.guard();
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(dst),
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &bind_addr)?;
|
||||
let socket = tokio::net::TcpSocket::from_std_stream(socket2_socket.into());
|
||||
match tokio::time::timeout(Duration::from_millis(600), socket.connect(dst)).await {
|
||||
Ok(Ok(_stream)) => {
|
||||
tracing::trace!(?bind_addr, ?dst, "tcp hole punch syn connect succeeded");
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
tracing::trace!(?bind_addr, ?dst, ?e, "tcp hole punch syn connect failed");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::trace!(?bind_addr, ?dst, ?e, "tcp hole punch syn connect timeout");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// tcp support simultaneous connect, so initiator and server can both use connect.
|
||||
async fn try_connect_to_remote(
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
a_mapped_addr: SocketAddr,
|
||||
local_port: u16,
|
||||
is_client: bool,
|
||||
max_attempts: u32,
|
||||
) -> Result<(), Error> {
|
||||
tracing::info!(
|
||||
?a_mapped_addr,
|
||||
local_port,
|
||||
"tcp hole punch server start connect loop"
|
||||
);
|
||||
|
||||
let mut connector =
|
||||
TcpTunnelConnector::new(format!("tcp://{}", a_mapped_addr).parse().unwrap());
|
||||
connector.set_bind_addrs(vec![bind_addr_for_port(
|
||||
local_port,
|
||||
a_mapped_addr.is_ipv6(),
|
||||
)]);
|
||||
|
||||
let start = tokio::time::Instant::now();
|
||||
let mut attempts: u32 = 0;
|
||||
while start.elapsed() < Duration::from_secs(10) && attempts < max_attempts {
|
||||
attempts = attempts.wrapping_add(1);
|
||||
let _g = peer_mgr.get_global_ctx().net_ns.guard();
|
||||
if let Ok(Ok(tunnel)) =
|
||||
tokio::time::timeout(Duration::from_secs(3), connector.connect()).await
|
||||
{
|
||||
let add_tunnel_ret = if is_client {
|
||||
peer_mgr.add_client_tunnel(tunnel, false).await.map(|_| ())
|
||||
} else {
|
||||
peer_mgr.add_tunnel_as_server(tunnel, false).await
|
||||
};
|
||||
if let Err(e) = add_tunnel_ret {
|
||||
tracing::error!(
|
||||
?a_mapped_addr,
|
||||
local_port,
|
||||
attempts,
|
||||
?e,
|
||||
"tcp hole punch server connected and added client tunnel failed"
|
||||
);
|
||||
continue;
|
||||
} else {
|
||||
tracing::info!(
|
||||
?a_mapped_addr,
|
||||
local_port,
|
||||
attempts,
|
||||
is_client,
|
||||
"tcp hole punch server connected and added tunnel"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
tracing::trace!(
|
||||
?a_mapped_addr,
|
||||
local_port,
|
||||
attempts,
|
||||
"tcp hole punch server connect attempt failed"
|
||||
);
|
||||
let sleep_ms = rand::thread_rng().gen_range(10..100);
|
||||
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
|
||||
}
|
||||
|
||||
tracing::warn!(
|
||||
?a_mapped_addr,
|
||||
local_port,
|
||||
attempts,
|
||||
"tcp hole punch server connect loop timeout"
|
||||
);
|
||||
|
||||
Err(anyhow::anyhow!(
|
||||
"tcp hole punch server connect loop timeout"
|
||||
))
|
||||
}
|
||||
|
||||
struct TcpHolePunchServer {
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||
}
|
||||
|
||||
impl TcpHolePunchServer {
|
||||
fn new(peer_mgr: Arc<PeerManager>) -> Arc<Self> {
|
||||
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
|
||||
join_joinset_background(tasks.clone(), "tcp hole punch server".to_string());
|
||||
Arc::new(Self { peer_mgr, tasks })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TcpHolePunchRpc for TcpHolePunchServer {
|
||||
type Controller = BaseController;
|
||||
|
||||
#[tracing::instrument(skip(self), fields(a_mapped_addr = ?input.connector_mapped_addr), err)]
|
||||
async fn exchange_mapped_addr(
|
||||
&self,
|
||||
_ctrl: Self::Controller,
|
||||
input: TcpHolePunchRequest,
|
||||
) -> rpc_types::error::Result<TcpHolePunchResponse> {
|
||||
let my_tcp_nat_type = NatType::try_from(
|
||||
self.peer_mgr
|
||||
.get_global_ctx()
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.tcp_nat_type,
|
||||
)
|
||||
.unwrap_or(NatType::Unknown);
|
||||
tracing::debug!(?my_tcp_nat_type, "tcp hole punch rpc received");
|
||||
if matches!(my_tcp_nat_type, NatType::Unknown) {
|
||||
tracing::warn!(?my_tcp_nat_type, "tcp hole punch rpc rejected (unknown)");
|
||||
return Err(anyhow::anyhow!("tcp nat type unknown not supported").into());
|
||||
}
|
||||
|
||||
let a_mapped_addr = input
|
||||
.connector_mapped_addr
|
||||
.ok_or(anyhow::anyhow!("connector_mapped_addr is required"))?;
|
||||
let a_mapped_addr: SocketAddr = a_mapped_addr.into();
|
||||
let a_ip = a_mapped_addr.ip();
|
||||
if a_ip.is_unspecified() || a_ip.is_multicast() {
|
||||
tracing::warn!(?a_mapped_addr, "tcp hole punch rpc invalid connector addr");
|
||||
return Err(anyhow::anyhow!("connector_mapped_addr is malformed").into());
|
||||
}
|
||||
|
||||
let is_v6 = a_mapped_addr.is_ipv6();
|
||||
let local_port = select_local_port(&self.peer_mgr, is_v6).await?;
|
||||
let mapped_addr = self
|
||||
.peer_mgr
|
||||
.get_global_ctx()
|
||||
.get_stun_info_collector()
|
||||
.get_tcp_port_mapping(local_port)
|
||||
.await
|
||||
.with_context(|| "failed to get tcp port mapping")?;
|
||||
|
||||
tracing::info!(
|
||||
?a_mapped_addr,
|
||||
local_port,
|
||||
?mapped_addr,
|
||||
"tcp hole punch rpc responding with listener mapped addr and start connecting"
|
||||
);
|
||||
|
||||
let peer_mgr = self.peer_mgr.clone();
|
||||
self.tasks.lock().unwrap().spawn(async move {
|
||||
let _ = try_connect_to_remote(peer_mgr, a_mapped_addr, local_port, true, 5).await;
|
||||
});
|
||||
|
||||
Ok(TcpHolePunchResponse {
|
||||
listener_mapped_addr: Some(mapped_addr.into()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct TcpHolePunchConnectorData {
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
blacklist: Arc<timedmap::TimedMap<PeerId, ()>>,
|
||||
}
|
||||
|
||||
impl TcpHolePunchConnectorData {
|
||||
fn new(peer_mgr: Arc<PeerManager>) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
peer_mgr,
|
||||
blacklist: Arc::new(timedmap::TimedMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
async fn punch_as_initiator(self: Arc<Self>, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
let mut backoff = BackOff::new(vec![1000, 1000, 4000, 8000]);
|
||||
|
||||
loop {
|
||||
backoff.sleep_for_next_backoff().await;
|
||||
if self.do_punch_as_initiator(dst_peer_id).await.is_ok() {
|
||||
break;
|
||||
}
|
||||
|
||||
if self.blacklist.contains(&dst_peer_id) {
|
||||
tracing::warn!(
|
||||
dst_peer_id,
|
||||
"tcp hole punch initiator skipped (blacklisted)"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), fields(dst_peer_id), err)]
|
||||
async fn do_punch_as_initiator(&self, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
let global_ctx = self.peer_mgr.get_global_ctx();
|
||||
let my_tcp_nat_type = NatType::try_from(
|
||||
global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.tcp_nat_type,
|
||||
)
|
||||
.unwrap_or(NatType::Unknown);
|
||||
tracing::debug!(?my_tcp_nat_type, "tcp hole punch initiator start");
|
||||
if is_symmetric_tcp_nat(my_tcp_nat_type) || my_tcp_nat_type == NatType::Unknown {
|
||||
tracing::debug!("tcp hole punch initiator skipped (symmetric)");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let local_port = select_local_port(&self.peer_mgr, false).await?;
|
||||
let mapped_addr = global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_tcp_port_mapping(local_port)
|
||||
.await
|
||||
.with_context(|| "failed to get tcp port mapping")?;
|
||||
|
||||
tracing::info!(
|
||||
dst_peer_id,
|
||||
local_port,
|
||||
?mapped_addr,
|
||||
"tcp hole punch initiator got mapped addr, start rpc exchange"
|
||||
);
|
||||
|
||||
let rpc_stub = self
|
||||
.peer_mgr
|
||||
.get_peer_rpc_mgr()
|
||||
.rpc_client()
|
||||
.scoped_client::<TcpHolePunchRpcClientFactory<BaseController>>(
|
||||
self.peer_mgr.my_peer_id(),
|
||||
dst_peer_id,
|
||||
global_ctx.get_network_name(),
|
||||
);
|
||||
|
||||
let resp = rpc_stub
|
||||
.exchange_mapped_addr(
|
||||
BaseController {
|
||||
timeout_ms: 6000,
|
||||
..Default::default()
|
||||
},
|
||||
TcpHolePunchRequest {
|
||||
connector_mapped_addr: Some(mapped_addr.into()),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
let resp = handle_rpc_result(resp, dst_peer_id, &self.blacklist)?;
|
||||
let remote_mapped_addr = resp
|
||||
.listener_mapped_addr
|
||||
.ok_or(anyhow::anyhow!("listener_mapped_addr is required"))?;
|
||||
let remote_mapped_addr: SocketAddr = remote_mapped_addr.into();
|
||||
tracing::info!(
|
||||
dst_peer_id,
|
||||
?remote_mapped_addr,
|
||||
"tcp hole punch initiator rpc returned"
|
||||
);
|
||||
|
||||
if let Ok(()) = try_connect_to_remote(
|
||||
self.peer_mgr.clone(),
|
||||
remote_mapped_addr,
|
||||
local_port,
|
||||
false,
|
||||
1,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::info!(
|
||||
dst_peer_id,
|
||||
local_port,
|
||||
?remote_mapped_addr,
|
||||
"tcp hole punch initiator connected to remote mapped addr with simultaneous connection"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
dst_peer_id,
|
||||
local_port,
|
||||
?remote_mapped_addr,
|
||||
"tcp hole punch initiator sent syn to remote mapped addr"
|
||||
);
|
||||
|
||||
let mut listener =
|
||||
TcpTunnelListener::new(format!("tcp://0.0.0.0:{}", local_port).parse().unwrap());
|
||||
{
|
||||
let _g = self.peer_mgr.get_global_ctx().net_ns.guard();
|
||||
listener.listen().await?;
|
||||
}
|
||||
tracing::info!(
|
||||
dst_peer_id,
|
||||
local_port,
|
||||
url = %listener.local_url(),
|
||||
"tcp hole punch initiator listening"
|
||||
);
|
||||
|
||||
tokio::time::timeout(
|
||||
Duration::from_secs(10),
|
||||
self.accept_loop(&mut listener, dst_peer_id),
|
||||
)
|
||||
.await??;
|
||||
|
||||
tracing::info!(
|
||||
dst_peer_id,
|
||||
"tcp hole punch initiator accepted and added server tunnel"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept_loop(
|
||||
&self,
|
||||
listener: &mut TcpTunnelListener,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<(), Error> {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok(tunnel) => {
|
||||
if let Err(e) = self.peer_mgr.add_tunnel_as_server(tunnel, false).await {
|
||||
tracing::error!("tcp hole punch add tunnel error: {}", e);
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
dst_peer_id,
|
||||
"tcp hole punch initiator accepted and added server tunnel"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("tcp hole punch accept error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
|
||||
struct TcpPunchTaskInfo {
|
||||
dst_peer_id: PeerId,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TcpHolePunchPeerTaskLauncher {}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerTaskLauncher for TcpHolePunchPeerTaskLauncher {
|
||||
type Data = Arc<TcpHolePunchConnectorData>;
|
||||
type CollectPeerItem = TcpPunchTaskInfo;
|
||||
type TaskRet = ();
|
||||
|
||||
fn new_data(&self, peer_mgr: Arc<PeerManager>) -> Self::Data {
|
||||
TcpHolePunchConnectorData::new(peer_mgr)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, data))]
|
||||
async fn collect_peers_need_task(&self, data: &Self::Data) -> Vec<Self::CollectPeerItem> {
|
||||
let global_ctx = data.peer_mgr.get_global_ctx();
|
||||
let my_tcp_nat_type = NatType::try_from(
|
||||
global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.tcp_nat_type,
|
||||
)
|
||||
.unwrap_or(NatType::Unknown);
|
||||
if is_symmetric_tcp_nat(my_tcp_nat_type) || my_tcp_nat_type == NatType::Unknown {
|
||||
tracing::trace!(
|
||||
?my_tcp_nat_type,
|
||||
"tcp hole punch task collect skipped (symmetric)"
|
||||
);
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let my_peer_id = data.peer_mgr.my_peer_id();
|
||||
|
||||
data.blacklist.cleanup();
|
||||
|
||||
let mut peers_to_connect = Vec::new();
|
||||
for route in data.peer_mgr.list_routes().await.iter() {
|
||||
if route
|
||||
.feature_flag
|
||||
.map(|x| x.is_public_server)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let peer_id: PeerId = route.peer_id;
|
||||
if peer_id == my_peer_id {
|
||||
tracing::trace!(peer_id, "tcp hole punch task collect skip self");
|
||||
continue;
|
||||
}
|
||||
|
||||
if data.blacklist.contains(&peer_id) {
|
||||
tracing::debug!(peer_id, "tcp hole punch task collect skip blacklisted");
|
||||
continue;
|
||||
}
|
||||
|
||||
if data.peer_mgr.get_peer_map().has_peer(peer_id) {
|
||||
tracing::trace!(peer_id, "tcp hole punch task collect skip already has peer");
|
||||
continue;
|
||||
}
|
||||
|
||||
let peer_tcp_nat_type = route
|
||||
.stun_info
|
||||
.as_ref()
|
||||
.map(|x| x.tcp_nat_type)
|
||||
.unwrap_or(0);
|
||||
let peer_tcp_nat_type =
|
||||
NatType::try_from(peer_tcp_nat_type).unwrap_or(NatType::Unknown);
|
||||
if matches!(peer_tcp_nat_type, NatType::Unknown) {
|
||||
tracing::debug!(
|
||||
peer_id,
|
||||
?peer_tcp_nat_type,
|
||||
"tcp hole punch task collect skip peer unknown"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
peer_id,
|
||||
my_peer_id,
|
||||
?my_tcp_nat_type,
|
||||
?peer_tcp_nat_type,
|
||||
"tcp hole punch task collect add peer"
|
||||
);
|
||||
peers_to_connect.push(TcpPunchTaskInfo {
|
||||
dst_peer_id: peer_id,
|
||||
});
|
||||
}
|
||||
|
||||
peers_to_connect
|
||||
}
|
||||
|
||||
async fn launch_task(
|
||||
&self,
|
||||
data: &Self::Data,
|
||||
item: Self::CollectPeerItem,
|
||||
) -> tokio::task::JoinHandle<Result<Self::TaskRet, anyhow::Error>> {
|
||||
let data = data.clone();
|
||||
tokio::spawn(async move { data.punch_as_initiator(item.dst_peer_id).await.map(|_| ()) })
|
||||
}
|
||||
|
||||
async fn all_task_done(&self, _data: &Self::Data) {}
|
||||
|
||||
fn loop_interval_ms(&self) -> u64 {
|
||||
5000
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TcpHolePunchConnector {
|
||||
server: Arc<TcpHolePunchServer>,
|
||||
client: PeerTaskManager<TcpHolePunchPeerTaskLauncher>,
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
}
|
||||
|
||||
impl TcpHolePunchConnector {
|
||||
pub fn new(peer_mgr: Arc<PeerManager>) -> Self {
|
||||
Self {
|
||||
server: TcpHolePunchServer::new(peer_mgr.clone()),
|
||||
client: PeerTaskManager::new(TcpHolePunchPeerTaskLauncher {}, peer_mgr.clone()),
|
||||
peer_mgr,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_as_client(&mut self) -> Result<(), Error> {
|
||||
tracing::info!("tcp hole punch client start");
|
||||
self.client.start();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_as_server(&mut self) -> Result<(), Error> {
|
||||
tracing::info!("tcp hole punch server register rpc");
|
||||
self.peer_mgr
|
||||
.get_peer_rpc_mgr()
|
||||
.rpc_server()
|
||||
.registry()
|
||||
.register(
|
||||
TcpHolePunchRpcServer::new(self.server.clone()),
|
||||
&self.peer_mgr.get_global_ctx().get_network_name(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) -> Result<(), Error> {
|
||||
if self.peer_mgr.get_global_ctx().get_flags().disable_p2p {
|
||||
tracing::debug!("tcp hole punch disabled by disable_p2p");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.run_as_client().await?;
|
||||
self.run_as_server().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, stun::StunInfoCollectorTrait},
|
||||
connector::tcp_hole_punch::TcpHolePunchConnector,
|
||||
peers::{
|
||||
peer_manager::PeerManager,
|
||||
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
|
||||
},
|
||||
proto::common::{NatType, StunInfo},
|
||||
tunnel::common::tests::wait_for_condition,
|
||||
};
|
||||
|
||||
struct MockStunInfoCollector {
|
||||
udp_nat_type: NatType,
|
||||
tcp_nat_type: NatType,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl StunInfoCollectorTrait for MockStunInfoCollector {
|
||||
fn get_stun_info(&self) -> StunInfo {
|
||||
StunInfo {
|
||||
udp_nat_type: self.udp_nat_type as i32,
|
||||
tcp_nat_type: self.tcp_nat_type as i32,
|
||||
last_update_time: 0,
|
||||
public_ip: vec!["127.0.0.1".to_string(), "::1".to_string()],
|
||||
min_port: 100,
|
||||
max_port: 200,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_udp_port_mapping(&self, mut port: u16) -> Result<SocketAddr, Error> {
|
||||
if port == 0 {
|
||||
port = 40144;
|
||||
}
|
||||
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
|
||||
}
|
||||
|
||||
async fn get_tcp_port_mapping(&self, mut port: u16) -> Result<SocketAddr, Error> {
|
||||
if port == 0 {
|
||||
port = 40144;
|
||||
}
|
||||
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, tcp_nat_type: NatType) {
|
||||
let collector = Box::new(MockStunInfoCollector {
|
||||
udp_nat_type: NatType::Unknown,
|
||||
tcp_nat_type,
|
||||
});
|
||||
peer_mgr
|
||||
.get_global_ctx()
|
||||
.replace_stun_info_collector(collector);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_hole_punch_connects() {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
let p_b = create_mock_peer_manager().await;
|
||||
let p_c = create_mock_peer_manager().await;
|
||||
|
||||
replace_stun_info_collector(p_a.clone(), NatType::PortRestricted);
|
||||
replace_stun_info_collector(p_b.clone(), NatType::PortRestricted);
|
||||
replace_stun_info_collector(p_c.clone(), NatType::PortRestricted);
|
||||
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
|
||||
|
||||
let mut hole_punching_a = TcpHolePunchConnector::new(p_a.clone());
|
||||
let mut hole_punching_c = TcpHolePunchConnector::new(p_c.clone());
|
||||
hole_punching_a.run().await.unwrap();
|
||||
hole_punching_c.run().await.unwrap();
|
||||
|
||||
hole_punching_a.client.run_immediately().await;
|
||||
hole_punching_c.client.run_immediately().await;
|
||||
|
||||
wait_for_condition(
|
||||
|| {
|
||||
let p_a = p_a.clone();
|
||||
let p_c = p_c.clone();
|
||||
async move {
|
||||
let a_has = p_a
|
||||
.get_peer_map()
|
||||
.list_peer_conns(p_c.my_peer_id())
|
||||
.await
|
||||
.is_some_and(|c| !c.is_empty());
|
||||
let c_has = p_c
|
||||
.get_peer_map()
|
||||
.list_peer_conns(p_a.my_peer_id())
|
||||
.await
|
||||
.is_some_and(|c| !c.is_empty());
|
||||
a_has || c_has
|
||||
}
|
||||
},
|
||||
Duration::from_secs(15),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_hole_punch_skip_symmetric_peer() {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
let p_b = create_mock_peer_manager().await;
|
||||
let p_c = create_mock_peer_manager().await;
|
||||
|
||||
replace_stun_info_collector(p_a.clone(), NatType::Symmetric);
|
||||
replace_stun_info_collector(p_b.clone(), NatType::PortRestricted);
|
||||
replace_stun_info_collector(p_c.clone(), NatType::Symmetric);
|
||||
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
|
||||
|
||||
let mut hole_punching_a = TcpHolePunchConnector::new(p_a.clone());
|
||||
let mut hole_punching_c = TcpHolePunchConnector::new(p_c.clone());
|
||||
hole_punching_a.run().await.unwrap();
|
||||
hole_punching_c.run().await.unwrap();
|
||||
|
||||
hole_punching_a.client.run_immediately().await;
|
||||
hole_punching_c.client.run_immediately().await;
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
|
||||
assert!(p_a
|
||||
.get_peer_map()
|
||||
.list_peer_conns(p_c.my_peer_id())
|
||||
.await
|
||||
.map(|c| c.is_empty())
|
||||
.unwrap_or(true));
|
||||
assert!(p_c
|
||||
.get_peer_map()
|
||||
.list_peer_conns(p_a.my_peer_id())
|
||||
.await
|
||||
.map(|c| c.is_empty())
|
||||
.unwrap_or(true));
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,6 @@ use tokio::{sync::Mutex, task::JoinHandle};
|
||||
|
||||
use crate::{
|
||||
common::{stun::StunInfoCollectorTrait, PeerId},
|
||||
connector::direct::PeerManagerForDirectConnector,
|
||||
peers::{
|
||||
peer_manager::PeerManager,
|
||||
peer_task::{PeerTaskLauncher, PeerTaskManager},
|
||||
@@ -461,8 +460,7 @@ impl PeerTaskLauncher for UdpHolePunchPeerTaskLauncher {
|
||||
continue;
|
||||
}
|
||||
|
||||
let conns = data.peer_mgr.list_peer_conns(peer_id).await;
|
||||
if conns.is_some() && !conns.unwrap().is_empty() {
|
||||
if data.peer_mgr.get_peer_map().has_peer(peer_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
@@ -421,6 +421,15 @@ struct NetworkOptions {
|
||||
)]
|
||||
disable_udp_hole_punching: Option<bool>,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
env = "ET_DISABLE_TCP_HOLE_PUNCHING",
|
||||
help = t!("core_clap.disable_tcp_hole_punching").to_string(),
|
||||
num_args = 0..=1,
|
||||
default_missing_value = "true"
|
||||
)]
|
||||
disable_tcp_hole_punching: Option<bool>,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
env = "ET_DISABLE_SYM_HOLE_PUNCHING",
|
||||
@@ -925,6 +934,9 @@ impl NetworkOptions {
|
||||
}
|
||||
f.disable_p2p = self.disable_p2p.unwrap_or(f.disable_p2p);
|
||||
f.p2p_only = self.p2p_only.unwrap_or(f.p2p_only);
|
||||
f.disable_tcp_hole_punching = self
|
||||
.disable_tcp_hole_punching
|
||||
.unwrap_or(f.disable_tcp_hole_punching);
|
||||
f.disable_udp_hole_punching = self
|
||||
.disable_udp_hole_punching
|
||||
.unwrap_or(f.disable_udp_hole_punching);
|
||||
|
||||
@@ -1470,7 +1470,9 @@ async fn main() -> Result<(), Error> {
|
||||
let collector = StunInfoCollector::new_with_default_servers();
|
||||
loop {
|
||||
let ret = collector.get_stun_info();
|
||||
if ret.udp_nat_type != NatType::Unknown as i32 {
|
||||
if ret.udp_nat_type != NatType::Unknown as i32
|
||||
&& ret.tcp_nat_type != NatType::Unknown as i32
|
||||
{
|
||||
if cli.output_format == OutputFormat::Json {
|
||||
match serde_json::to_string_pretty(&ret) {
|
||||
Ok(json) => println!("{}", json),
|
||||
|
||||
@@ -21,6 +21,7 @@ use crate::common::scoped_task::ScopedTask;
|
||||
use crate::common::PeerId;
|
||||
use crate::connector::direct::DirectConnectorManager;
|
||||
use crate::connector::manual::{ConnectorManagerRpcService, ManualConnectorManager};
|
||||
use crate::connector::tcp_hole_punch::TcpHolePunchConnector;
|
||||
use crate::connector::udp_hole_punch::UdpHolePunchConnector;
|
||||
use crate::gateway::icmp_proxy::IcmpProxy;
|
||||
use crate::gateway::kcp_proxy::{KcpProxyDst, KcpProxyDstRpcService, KcpProxySrc};
|
||||
@@ -516,6 +517,7 @@ pub struct Instance {
|
||||
conn_manager: Arc<ManualConnectorManager>,
|
||||
direct_conn_manager: Arc<DirectConnectorManager>,
|
||||
udp_hole_puncher: Arc<Mutex<UdpHolePunchConnector>>,
|
||||
tcp_hole_puncher: Arc<Mutex<TcpHolePunchConnector>>,
|
||||
|
||||
ip_proxy: Option<IpProxy>,
|
||||
|
||||
@@ -571,6 +573,7 @@ impl Instance {
|
||||
direct_conn_manager.run();
|
||||
|
||||
let udp_hole_puncher = UdpHolePunchConnector::new(peer_manager.clone());
|
||||
let tcp_hole_puncher = TcpHolePunchConnector::new(peer_manager.clone());
|
||||
|
||||
let peer_center = Arc::new(PeerCenterInstance::new(peer_manager.clone()));
|
||||
|
||||
@@ -594,6 +597,7 @@ impl Instance {
|
||||
conn_manager,
|
||||
direct_conn_manager: Arc::new(direct_conn_manager),
|
||||
udp_hole_puncher: Arc::new(Mutex::new(udp_hole_puncher)),
|
||||
tcp_hole_puncher: Arc::new(Mutex::new(tcp_hole_puncher)),
|
||||
|
||||
ip_proxy: None,
|
||||
kcp_proxy_src: None,
|
||||
@@ -949,6 +953,7 @@ impl Instance {
|
||||
self.run_ip_proxy().await?;
|
||||
|
||||
self.udp_hole_puncher.lock().await.run().await?;
|
||||
self.tcp_hole_puncher.lock().await.run().await?;
|
||||
|
||||
self.peer_center.init().await;
|
||||
let route_calc = self.peer_center.get_cost_calculator();
|
||||
|
||||
@@ -201,6 +201,7 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
|
||||
return;
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
loop {
|
||||
|
||||
@@ -740,6 +740,10 @@ impl NetworkConfig {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(disable_tcp_hole_punching) = self.disable_tcp_hole_punching {
|
||||
flags.disable_tcp_hole_punching = disable_tcp_hole_punching;
|
||||
}
|
||||
|
||||
if let Some(disable_udp_hole_punching) = self.disable_udp_hole_punching {
|
||||
flags.disable_udp_hole_punching = disable_udp_hole_punching;
|
||||
}
|
||||
@@ -898,6 +902,7 @@ impl NetworkConfig {
|
||||
result.multi_thread = Some(flags.multi_thread);
|
||||
result.proxy_forward_by_system = Some(flags.proxy_forward_by_system);
|
||||
result.disable_encryption = Some(!flags.enable_encryption);
|
||||
result.disable_tcp_hole_punching = Some(flags.disable_tcp_hole_punching);
|
||||
result.disable_udp_hole_punching = Some(flags.disable_udp_hole_punching);
|
||||
result.disable_sym_hole_punching = Some(flags.disable_sym_hole_punching);
|
||||
result.enable_magic_dns = Some(flags.accept_dns);
|
||||
@@ -1140,6 +1145,7 @@ mod tests {
|
||||
flags.multi_thread = rng.gen_bool(0.7);
|
||||
flags.proxy_forward_by_system = rng.gen_bool(0.3);
|
||||
flags.enable_encryption = rng.gen_bool(0.8);
|
||||
flags.disable_tcp_hole_punching = rng.gen_bool(0.2);
|
||||
flags.disable_udp_hole_punching = rng.gen_bool(0.2);
|
||||
flags.accept_dns = rng.gen_bool(0.6);
|
||||
flags.mtu = rng.gen_range(1200..1500);
|
||||
|
||||
@@ -545,11 +545,12 @@ mod tests {
|
||||
|
||||
println!("rpc service ready, {:#?}", rpc_service.global_peer_map);
|
||||
|
||||
if digest.is_none() {
|
||||
digest = Some(rpc_service.global_peer_map_digest.load());
|
||||
} else {
|
||||
if let Some(prev) = digest {
|
||||
let v = rpc_service.global_peer_map_digest.load();
|
||||
assert_eq!(digest.unwrap(), v);
|
||||
assert_eq!(prev, v);
|
||||
digest = Some(prev);
|
||||
} else {
|
||||
digest = Some(rpc_service.global_peer_map_digest.load());
|
||||
}
|
||||
|
||||
let mut route_cost = pc.get_cost_calculator();
|
||||
|
||||
@@ -131,7 +131,8 @@ impl RoutePeerInfo {
|
||||
ipv4_addr: None,
|
||||
proxy_cidrs: Vec::new(),
|
||||
hostname: None,
|
||||
udp_stun_info: 0,
|
||||
udp_nat_type: 0,
|
||||
tcp_nat_type: 0,
|
||||
// ensure this is updated when the peer_infos/conn_info/foreign_network lock is acquired.
|
||||
// else we may assign a older timestamp than iterate time.
|
||||
last_update: None,
|
||||
@@ -160,6 +161,7 @@ impl RoutePeerInfo {
|
||||
peer_route_id: u64,
|
||||
global_ctx: &ArcGlobalCtx,
|
||||
) -> Self {
|
||||
let stun_info = global_ctx.get_stun_info_collector().get_stun_info();
|
||||
Self {
|
||||
peer_id: my_peer_id,
|
||||
inst_id: Some(global_ctx.get_id().into()),
|
||||
@@ -174,10 +176,8 @@ impl RoutePeerInfo {
|
||||
.map(|x| x.to_string())
|
||||
.collect(),
|
||||
hostname: Some(global_ctx.get_hostname()),
|
||||
udp_stun_info: global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type,
|
||||
udp_nat_type: stun_info.udp_nat_type,
|
||||
tcp_nat_type: stun_info.tcp_nat_type,
|
||||
|
||||
// these two fields should not participate in comparison.
|
||||
last_update: None,
|
||||
@@ -251,9 +251,12 @@ impl From<RoutePeerInfo> for crate::proto::api::instance::Route {
|
||||
hostname: val.hostname.unwrap_or_default(),
|
||||
stun_info: {
|
||||
let mut stun_info = StunInfo::default();
|
||||
if let Ok(udp_nat_type) = NatType::try_from(val.udp_stun_info) {
|
||||
if let Ok(udp_nat_type) = NatType::try_from(val.udp_nat_type) {
|
||||
stun_info.set_udp_nat_type(udp_nat_type);
|
||||
}
|
||||
if let Ok(tcp_nat_type) = NatType::try_from(val.tcp_nat_type) {
|
||||
stun_info.set_tcp_nat_type(tcp_nat_type);
|
||||
}
|
||||
Some(stun_info)
|
||||
},
|
||||
inst_id: val.inst_id.map(|x| x.to_string()).unwrap_or_default(),
|
||||
@@ -869,10 +872,10 @@ impl RouteTable {
|
||||
self.get_next_hop(peer_id).is_some()
|
||||
}
|
||||
|
||||
fn get_nat_type(&self, peer_id: PeerId) -> Option<NatType> {
|
||||
fn get_udp_nat_type(&self, peer_id: PeerId) -> Option<NatType> {
|
||||
self.peer_infos
|
||||
.get(&peer_id)
|
||||
.map(|x| NatType::try_from(x.udp_stun_info).unwrap_or_default())
|
||||
.map(|x| NatType::try_from(x.udp_nat_type).unwrap_or_default())
|
||||
}
|
||||
|
||||
// return graph and start node index (node of my peer id).
|
||||
@@ -2516,7 +2519,7 @@ impl RouteSessionManager {
|
||||
let mut new_initiator_dst = None;
|
||||
// if any peer has NoPAT or OpenInternet stun type, we should use it.
|
||||
for peer_id in initiator_candidates.iter() {
|
||||
let Some(nat_type) = service_impl.route_table.get_nat_type(*peer_id) else {
|
||||
let Some(nat_type) = service_impl.route_table.get_udp_nat_type(*peer_id) else {
|
||||
continue;
|
||||
};
|
||||
if nat_type == NatType::NoPat || nat_type == NatType::OpenInternet {
|
||||
|
||||
@@ -80,6 +80,7 @@ message NetworkConfig {
|
||||
optional bool p2p_only = 51;
|
||||
optional common.CompressionAlgoPb data_compress_algo = 52;
|
||||
optional string encryption_algorithm = 53;
|
||||
optional bool disable_tcp_hole_punching = 54;
|
||||
}
|
||||
|
||||
message PortForwardConfig {
|
||||
|
||||
@@ -62,6 +62,8 @@ message FlagsInConfig {
|
||||
string tld_dns_zone = 31;
|
||||
|
||||
bool p2p_only = 32;
|
||||
|
||||
bool disable_tcp_hole_punching = 34;
|
||||
}
|
||||
|
||||
message RpcDescriptor {
|
||||
|
||||
@@ -13,7 +13,7 @@ message RoutePeerInfo {
|
||||
optional common.Ipv4Addr ipv4_addr = 4;
|
||||
repeated string proxy_cidrs = 5;
|
||||
optional string hostname = 6;
|
||||
common.NatType udp_stun_info = 7;
|
||||
common.NatType udp_nat_type = 7;
|
||||
google.protobuf.Timestamp last_update = 8;
|
||||
uint32 version = 9;
|
||||
|
||||
@@ -27,6 +27,8 @@ message RoutePeerInfo {
|
||||
optional common.Ipv6Inet ipv6_addr = 15;
|
||||
|
||||
repeated PeerGroupInfo groups = 16;
|
||||
|
||||
common.NatType tcp_nat_type = 17;
|
||||
}
|
||||
|
||||
message PeerIdVersion {
|
||||
@@ -207,6 +209,14 @@ service UdpHolePunchRpc {
|
||||
returns (SendPunchPacketBothEasySymResponse);
|
||||
}
|
||||
|
||||
message TcpHolePunchRequest { common.SocketAddr connector_mapped_addr = 1; }
|
||||
|
||||
message TcpHolePunchResponse { common.SocketAddr listener_mapped_addr = 1; }
|
||||
|
||||
service TcpHolePunchRpc {
|
||||
rpc ExchangeMappedAddr(TcpHolePunchRequest) returns (TcpHolePunchResponse);
|
||||
}
|
||||
|
||||
message DirectConnectedPeerInfo { int32 latency_ms = 1; }
|
||||
|
||||
message PeerInfoForGlobalMap {
|
||||
|
||||
@@ -134,25 +134,25 @@ impl FakeTcpTunnelListener {
|
||||
IpAddr::V6(ip) => (None, Some(ip)),
|
||||
};
|
||||
|
||||
let ret = self
|
||||
.stack_map
|
||||
.entry(interface_name.to_string())
|
||||
.or_insert_with(|| {
|
||||
let tun = create_tun(interface_name, None, local_socket_addr);
|
||||
let ret = match self.stack_map.entry(interface_name.to_string()) {
|
||||
dashmap::Entry::Occupied(entry) => entry.get().clone(),
|
||||
dashmap::Entry::Vacant(entry) => {
|
||||
let tun = create_tun(interface_name, None, local_socket_addr)?;
|
||||
tracing::info!(
|
||||
?local_socket_addr,
|
||||
"create new stack with interface_name: {:?}",
|
||||
interface_name
|
||||
);
|
||||
// TODO: Get local MAC address of the interface
|
||||
Arc::new(Mutex::new(stack::Stack::new(
|
||||
let stack = Arc::new(Mutex::new(stack::Stack::new(
|
||||
tun,
|
||||
local_ip.unwrap_or(Ipv4Addr::UNSPECIFIED),
|
||||
local_ip6,
|
||||
accept_result.mac,
|
||||
)))
|
||||
})
|
||||
.clone();
|
||||
)));
|
||||
entry.insert(stack.clone());
|
||||
stack
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
@@ -314,7 +314,7 @@ impl crate::tunnel::TunnelConnector for FakeTcpTunnelConnector {
|
||||
IpAddr::V6(ip) => (None, Some(ip)),
|
||||
};
|
||||
|
||||
let tun = create_tun(&interface_name, Some(remote_addr), local_addr);
|
||||
let tun = create_tun(&interface_name, Some(remote_addr), local_addr)?;
|
||||
let local_ip = local_ip.unwrap_or("0.0.0.0".parse().unwrap());
|
||||
let mut stack = stack::Stack::new(tun, local_ip, local_ip6, mac);
|
||||
let driver_type = stack.driver_type();
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
pub mod pnet;
|
||||
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::{io, net::SocketAddr, sync::Arc};
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(target_os = "linux")] {
|
||||
@@ -10,19 +10,19 @@ cfg_if::cfg_if! {
|
||||
interface_name: &str,
|
||||
src_addr: Option<SocketAddr>,
|
||||
dst_addr: SocketAddr,
|
||||
) -> Arc<dyn super::stack::Tun> {
|
||||
) -> io::Result<Arc<dyn super::stack::Tun>> {
|
||||
match linux_bpf::LinuxBpfTun::new(interface_name, src_addr, dst_addr) {
|
||||
Ok(tun) => Arc::new(tun),
|
||||
Ok(tun) => Ok(Arc::new(tun)),
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
?e,
|
||||
interface_name,
|
||||
"LinuxBpfTun init failed, falling back to PnetTun"
|
||||
);
|
||||
Arc::new(pnet::PnetTun::new(
|
||||
Ok(Arc::new(pnet::PnetTun::new(
|
||||
interface_name,
|
||||
pnet::create_packet_filter(src_addr, dst_addr),
|
||||
))
|
||||
)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -33,19 +33,19 @@ cfg_if::cfg_if! {
|
||||
interface_name: &str,
|
||||
src_addr: Option<SocketAddr>,
|
||||
dst_addr: SocketAddr,
|
||||
) -> Arc<dyn super::stack::Tun> {
|
||||
) -> io::Result<Arc<dyn super::stack::Tun>> {
|
||||
match macos_bpf::MacosBpfTun::new(interface_name, src_addr, dst_addr) {
|
||||
Ok(tun) => Arc::new(tun),
|
||||
Ok(tun) => Ok(Arc::new(tun)),
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
?e,
|
||||
interface_name,
|
||||
"MacosBpfTun init failed, falling back to PnetTun"
|
||||
);
|
||||
Arc::new(pnet::PnetTun::new(
|
||||
Ok(Arc::new(pnet::PnetTun::new(
|
||||
interface_name,
|
||||
pnet::create_packet_filter(src_addr, dst_addr),
|
||||
))
|
||||
)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -56,19 +56,19 @@ cfg_if::cfg_if! {
|
||||
_interface_name: &str,
|
||||
_src_addr: Option<SocketAddr>,
|
||||
local_addr: SocketAddr,
|
||||
) -> Arc<dyn super::stack::Tun> {
|
||||
) -> io::Result<Arc<dyn super::stack::Tun>> {
|
||||
match windivert::WinDivertTun::new(local_addr) {
|
||||
Ok(tun) => Arc::new(tun),
|
||||
Ok(tun) => Ok(Arc::new(tun)),
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
?e,
|
||||
?local_addr,
|
||||
"WinDivertTun init failed, falling back to PnetTun"
|
||||
);
|
||||
Arc::new(pnet::PnetTun::new(
|
||||
Ok(Arc::new(pnet::PnetTun::new(
|
||||
local_addr.to_string().as_str(),
|
||||
pnet::create_packet_filter(None, local_addr),
|
||||
))
|
||||
)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -77,11 +77,11 @@ cfg_if::cfg_if! {
|
||||
interface_name: &str,
|
||||
src_addr: Option<SocketAddr>,
|
||||
dst_addr: SocketAddr,
|
||||
) -> Arc<dyn super::stack::Tun> {
|
||||
Arc::new(pnet::PnetTun::new(
|
||||
) -> io::Result<Arc<dyn super::stack::Tun>> {
|
||||
Ok(Arc::new(pnet::PnetTun::new(
|
||||
interface_name,
|
||||
pnet::create_packet_filter(src_addr, dst_addr),
|
||||
))
|
||||
)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::{
|
||||
io,
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
@@ -145,14 +146,11 @@ struct InterfaceWorker {
|
||||
}
|
||||
|
||||
impl InterfaceWorker {
|
||||
fn new(interface: NetworkInterface) -> Arc<Self> {
|
||||
fn new(interface: NetworkInterface) -> io::Result<Arc<Self>> {
|
||||
let (tx, mut rx) = match datalink::channel(&interface, Default::default()) {
|
||||
Ok(pnet::datalink::Channel::Ethernet(tx, rx)) => (tx, rx),
|
||||
Ok(_) => panic!("Unhandled channel type"),
|
||||
Err(e) => panic!(
|
||||
"An error occurred when creating the datalink channel: {}",
|
||||
e
|
||||
),
|
||||
Ok(_) => return Err(io::Error::other("Unhandled channel type")),
|
||||
Err(e) => return Err(io::Error::other(e)),
|
||||
};
|
||||
|
||||
let subscribers = Arc::new(DashMap::<u32, Subscriber>::new());
|
||||
@@ -187,10 +185,10 @@ impl InterfaceWorker {
|
||||
}
|
||||
});
|
||||
|
||||
Arc::new(Self {
|
||||
Ok(Arc::new(Self {
|
||||
tx: Mutex::new(tx),
|
||||
subscribers,
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
fn subscribe(&self, filter: PacketFilter, sender: tokio::sync::mpsc::Sender<Vec<u8>>) -> u32 {
|
||||
@@ -207,13 +205,13 @@ impl InterfaceWorker {
|
||||
|
||||
static INTERFACE_MANAGERS: Lazy<DashMap<String, Weak<InterfaceWorker>>> = Lazy::new(DashMap::new);
|
||||
|
||||
fn get_or_create_worker(interface_name: &str) -> Arc<InterfaceWorker> {
|
||||
fn get_or_create_worker(interface_name: &str) -> io::Result<Arc<InterfaceWorker>> {
|
||||
// Check if we have an active worker
|
||||
if let Some(worker) = INTERFACE_MANAGERS
|
||||
.get(interface_name)
|
||||
.and_then(|w| w.upgrade())
|
||||
{
|
||||
return worker;
|
||||
return Ok(worker);
|
||||
}
|
||||
|
||||
// Need to create new worker.
|
||||
@@ -229,9 +227,9 @@ fn get_or_create_worker(interface_name: &str) -> Arc<InterfaceWorker> {
|
||||
.find(|iface| iface.name == interface_name)
|
||||
.expect("Network interface not found");
|
||||
|
||||
let worker = InterfaceWorker::new(interface);
|
||||
let worker = InterfaceWorker::new(interface)?;
|
||||
INTERFACE_MANAGERS.insert(interface_name.to_string(), Arc::downgrade(&worker));
|
||||
worker
|
||||
Ok(worker)
|
||||
}
|
||||
|
||||
pub struct PnetTun {
|
||||
@@ -241,17 +239,17 @@ pub struct PnetTun {
|
||||
}
|
||||
|
||||
impl PnetTun {
|
||||
pub fn new(interface_name: &str, filter: PacketFilter) -> Self {
|
||||
pub fn new(interface_name: &str, filter: PacketFilter) -> io::Result<Self> {
|
||||
tracing::debug!(interface_name, "Creating new PnetTun");
|
||||
let worker = get_or_create_worker(interface_name);
|
||||
let worker = get_or_create_worker(interface_name)?;
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(1024);
|
||||
let id = worker.subscribe(filter, tx);
|
||||
|
||||
Self {
|
||||
Ok(Self {
|
||||
worker,
|
||||
subscription_id: id,
|
||||
recv_queue: Mutex::new(rx),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user