mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-06 17:59:11 +00:00
fix dns query (#864)
1. dns resolver should be global unique so dns cache can work. avoid dns query influence hole punching. 2. when system dns failed, fallback to hickory dns.
This commit is contained in:
@@ -0,0 +1,134 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::atomic::AtomicBool;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
|
use hickory_proto::runtime::TokioRuntimeProvider;
|
||||||
|
use hickory_proto::xfer::Protocol;
|
||||||
|
use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, ResolverConfig, ResolverOpts};
|
||||||
|
use hickory_resolver::name_server::{GenericConnector, TokioConnectionProvider};
|
||||||
|
use hickory_resolver::system_conf::read_system_conf;
|
||||||
|
use hickory_resolver::{Resolver, TokioResolver};
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use tokio::net::lookup_host;
|
||||||
|
|
||||||
|
use super::error::Error;
|
||||||
|
|
||||||
|
pub fn get_default_resolver_config() -> ResolverConfig {
|
||||||
|
let mut default_resolve_config = ResolverConfig::new();
|
||||||
|
default_resolve_config.add_name_server(NameServerConfig::new(
|
||||||
|
"223.5.5.5:53".parse().unwrap(),
|
||||||
|
Protocol::Udp,
|
||||||
|
));
|
||||||
|
default_resolve_config.add_name_server(NameServerConfig::new(
|
||||||
|
"180.184.1.1:53".parse().unwrap(),
|
||||||
|
Protocol::Udp,
|
||||||
|
));
|
||||||
|
default_resolve_config
|
||||||
|
}
|
||||||
|
|
||||||
|
pub static ALLOW_USE_SYSTEM_DNS_RESOLVER: Lazy<AtomicBool> = Lazy::new(|| AtomicBool::new(true));
|
||||||
|
|
||||||
|
pub static RESOLVER: Lazy<Arc<Resolver<GenericConnector<TokioRuntimeProvider>>>> =
|
||||||
|
Lazy::new(|| {
|
||||||
|
let system_cfg = read_system_conf();
|
||||||
|
let mut cfg = get_default_resolver_config();
|
||||||
|
let mut opt = ResolverOpts::default();
|
||||||
|
if let Ok(s) = system_cfg {
|
||||||
|
for ns in s.0.name_servers() {
|
||||||
|
cfg.add_name_server(ns.clone());
|
||||||
|
}
|
||||||
|
opt = s.1;
|
||||||
|
}
|
||||||
|
opt.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
|
||||||
|
let builder = TokioResolver::builder_with_config(cfg, TokioConnectionProvider::default())
|
||||||
|
.with_options(opt);
|
||||||
|
Arc::new(builder.build())
|
||||||
|
});
|
||||||
|
|
||||||
|
pub async fn resolve_txt_record(domain_name: &str) -> Result<String, Error> {
|
||||||
|
let r = RESOLVER.clone();
|
||||||
|
let response = r.txt_lookup(domain_name).await.with_context(|| {
|
||||||
|
format!(
|
||||||
|
"txt_lookup failed, domain_name: {}",
|
||||||
|
domain_name.to_string()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let txt_record = response.iter().next().with_context(|| {
|
||||||
|
format!(
|
||||||
|
"no txt record found, domain_name: {}",
|
||||||
|
domain_name.to_string()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let txt_data = String::from_utf8_lossy(&txt_record.txt_data()[0]);
|
||||||
|
tracing::info!(?txt_data, ?domain_name, "get txt record");
|
||||||
|
|
||||||
|
Ok(txt_data.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn socket_addrs(
|
||||||
|
url: &url::Url,
|
||||||
|
default_port_number: impl Fn() -> Option<u16>,
|
||||||
|
) -> Result<Vec<SocketAddr>, Error> {
|
||||||
|
let host = url.host_str().ok_or(Error::InvalidUrl(url.to_string()))?;
|
||||||
|
let port = url
|
||||||
|
.port()
|
||||||
|
.or_else(default_port_number)
|
||||||
|
.ok_or(Error::InvalidUrl(url.to_string()))?;
|
||||||
|
|
||||||
|
// if host is an ip address, return it directly
|
||||||
|
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
||||||
|
return Ok(vec![SocketAddr::new(ip, port)]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if ALLOW_USE_SYSTEM_DNS_RESOLVER.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
|
let socket_addr = format!("{}:{}", host, port);
|
||||||
|
match lookup_host(socket_addr).await {
|
||||||
|
Ok(a) => {
|
||||||
|
let a = a.collect();
|
||||||
|
tracing::debug!(?a, "system dns lookup done");
|
||||||
|
return Ok(a);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(?e, "system dns lookup failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// use hickory_resolver
|
||||||
|
let ret = RESOLVER.lookup_ip(host).await.with_context(|| {
|
||||||
|
format!(
|
||||||
|
"hickory dns lookup_ip failed, host: {}, port: {}",
|
||||||
|
host, port
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
Ok(ret
|
||||||
|
.iter()
|
||||||
|
.map(|ip| SocketAddr::new(ip, port))
|
||||||
|
.collect::<Vec<_>>())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::defer;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_socket_addrs() {
|
||||||
|
let url = url::Url::parse("tcp://public.easytier.cn:80").unwrap();
|
||||||
|
let addrs = socket_addrs(&url, || Some(80)).await.unwrap();
|
||||||
|
assert_eq!(2, addrs.len(), "addrs: {:?}", addrs);
|
||||||
|
println!("addrs: {:?}", addrs);
|
||||||
|
|
||||||
|
ALLOW_USE_SYSTEM_DNS_RESOLVER.store(false, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
defer!(
|
||||||
|
ALLOW_USE_SYSTEM_DNS_RESOLVER.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
);
|
||||||
|
let addrs = socket_addrs(&url, || Some(80)).await.unwrap();
|
||||||
|
assert_eq!(2, addrs.len(), "addrs: {:?}", addrs);
|
||||||
|
println!("addrs2: {:?}", addrs);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,6 +11,7 @@ pub mod compressor;
|
|||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod constants;
|
pub mod constants;
|
||||||
pub mod defer;
|
pub mod defer;
|
||||||
|
pub mod dns;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod global_ctx;
|
pub mod global_ctx;
|
||||||
pub mod ifcfg;
|
pub mod ifcfg;
|
||||||
|
|||||||
@@ -8,10 +8,6 @@ use crate::proto::common::{NatType, StunInfo};
|
|||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
use crossbeam::atomic::AtomicCell;
|
use crossbeam::atomic::AtomicCell;
|
||||||
use hickory_proto::xfer::Protocol;
|
|
||||||
use hickory_resolver::config::{NameServerConfig, ResolverConfig};
|
|
||||||
use hickory_resolver::name_server::TokioConnectionProvider;
|
|
||||||
use hickory_resolver::TokioResolver;
|
|
||||||
use rand::seq::IteratorRandom;
|
use rand::seq::IteratorRandom;
|
||||||
use tokio::net::{lookup_host, UdpSocket};
|
use tokio::net::{lookup_host, UdpSocket};
|
||||||
use tokio::sync::{broadcast, Mutex};
|
use tokio::sync::{broadcast, Mutex};
|
||||||
@@ -24,45 +20,9 @@ use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder};
|
|||||||
|
|
||||||
use crate::common::error::Error;
|
use crate::common::error::Error;
|
||||||
|
|
||||||
|
use super::dns::resolve_txt_record;
|
||||||
use super::stun_codec_ext::*;
|
use super::stun_codec_ext::*;
|
||||||
|
|
||||||
pub fn get_default_resolver_config() -> ResolverConfig {
|
|
||||||
let mut default_resolve_config = ResolverConfig::new();
|
|
||||||
default_resolve_config.add_name_server(NameServerConfig::new(
|
|
||||||
"223.5.5.5:53".parse().unwrap(),
|
|
||||||
Protocol::Udp,
|
|
||||||
));
|
|
||||||
default_resolve_config.add_name_server(NameServerConfig::new(
|
|
||||||
"180.184.1.1:53".parse().unwrap(),
|
|
||||||
Protocol::Udp,
|
|
||||||
));
|
|
||||||
default_resolve_config
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn resolve_txt_record(
|
|
||||||
domain_name: &str,
|
|
||||||
resolver: &TokioResolver,
|
|
||||||
) -> Result<String, Error> {
|
|
||||||
let response = resolver.txt_lookup(domain_name).await.with_context(|| {
|
|
||||||
format!(
|
|
||||||
"txt_lookup failed, domain_name: {}",
|
|
||||||
domain_name.to_string()
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let txt_record = response.iter().next().with_context(|| {
|
|
||||||
format!(
|
|
||||||
"no txt record found, domain_name: {}",
|
|
||||||
domain_name.to_string()
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let txt_data = String::from_utf8_lossy(&txt_record.txt_data()[0]);
|
|
||||||
tracing::info!(?txt_data, ?domain_name, "get txt record");
|
|
||||||
|
|
||||||
Ok(txt_data.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
struct HostResolverIter {
|
struct HostResolverIter {
|
||||||
hostnames: Vec<String>,
|
hostnames: Vec<String>,
|
||||||
ips: Vec<SocketAddr>,
|
ips: Vec<SocketAddr>,
|
||||||
@@ -81,13 +41,7 @@ impl HostResolverIter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn get_txt_record(domain_name: &str) -> Result<Vec<String>, Error> {
|
async fn get_txt_record(domain_name: &str) -> Result<Vec<String>, Error> {
|
||||||
let resolver = TokioResolver::builder_tokio()
|
let txt_data = resolve_txt_record(domain_name).await?;
|
||||||
.unwrap_or(TokioResolver::builder_with_config(
|
|
||||||
get_default_resolver_config(),
|
|
||||||
TokioConnectionProvider::default(),
|
|
||||||
))
|
|
||||||
.build();
|
|
||||||
let txt_data = resolve_txt_record(domain_name, &resolver).await?;
|
|
||||||
Ok(txt_data.split(" ").map(|x| x.to_string()).collect())
|
Ok(txt_data.split(" ").map(|x| x.to_string()).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,17 +2,15 @@ use std::{net::SocketAddr, sync::Arc};
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{
|
common::{
|
||||||
|
dns::{resolve_txt_record, RESOLVER},
|
||||||
error::Error,
|
error::Error,
|
||||||
global_ctx::ArcGlobalCtx,
|
global_ctx::ArcGlobalCtx,
|
||||||
stun::{get_default_resolver_config, resolve_txt_record},
|
|
||||||
},
|
},
|
||||||
tunnel::{IpVersion, Tunnel, TunnelConnector, TunnelError, PROTO_PORT_OFFSET},
|
tunnel::{IpVersion, Tunnel, TunnelConnector, TunnelError, PROTO_PORT_OFFSET},
|
||||||
};
|
};
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use dashmap::DashSet;
|
use dashmap::DashSet;
|
||||||
use hickory_resolver::{
|
use hickory_resolver::proto::rr::rdata::SRV;
|
||||||
name_server::TokioConnectionProvider, proto::rr::rdata::SRV, TokioResolver,
|
|
||||||
};
|
|
||||||
use rand::{seq::SliceRandom, Rng as _};
|
use rand::{seq::SliceRandom, Rng as _};
|
||||||
|
|
||||||
use crate::proto::common::TunnelInfo;
|
use crate::proto::common::TunnelInfo;
|
||||||
@@ -58,14 +56,7 @@ impl DNSTunnelConnector {
|
|||||||
&self,
|
&self,
|
||||||
domain_name: &str,
|
domain_name: &str,
|
||||||
) -> Result<Box<dyn TunnelConnector>, Error> {
|
) -> Result<Box<dyn TunnelConnector>, Error> {
|
||||||
let resolver = TokioResolver::builder_tokio()
|
let txt_data = resolve_txt_record(domain_name)
|
||||||
.unwrap_or(TokioResolver::builder_with_config(
|
|
||||||
get_default_resolver_config(),
|
|
||||||
TokioConnectionProvider::default(),
|
|
||||||
))
|
|
||||||
.build();
|
|
||||||
|
|
||||||
let txt_data = resolve_txt_record(domain_name, &resolver)
|
|
||||||
.await
|
.await
|
||||||
.with_context(|| format!("resolve txt record failed, domain_name: {}", domain_name))?;
|
.with_context(|| format!("resolve txt record failed, domain_name: {}", domain_name))?;
|
||||||
|
|
||||||
@@ -120,13 +111,6 @@ impl DNSTunnelConnector {
|
|||||||
) -> Result<Box<dyn TunnelConnector>, Error> {
|
) -> Result<Box<dyn TunnelConnector>, Error> {
|
||||||
tracing::info!("handle_srv_record: {}", domain_name);
|
tracing::info!("handle_srv_record: {}", domain_name);
|
||||||
|
|
||||||
let resolver = TokioResolver::builder_tokio()
|
|
||||||
.unwrap_or(TokioResolver::builder_with_config(
|
|
||||||
get_default_resolver_config(),
|
|
||||||
TokioConnectionProvider::default(),
|
|
||||||
))
|
|
||||||
.build();
|
|
||||||
|
|
||||||
let srv_domains = PROTO_PORT_OFFSET
|
let srv_domains = PROTO_PORT_OFFSET
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(p, _)| (format!("_easytier._{}.{}", p, domain_name), *p)) // _easytier._udp.{domain_name}
|
.map(|(p, _)| (format!("_easytier._{}.{}", p, domain_name), *p)) // _easytier._udp.{domain_name}
|
||||||
@@ -136,7 +120,7 @@ impl DNSTunnelConnector {
|
|||||||
let srv_lookup_tasks = srv_domains
|
let srv_lookup_tasks = srv_domains
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(srv_domain, protocol)| {
|
.map(|(srv_domain, protocol)| {
|
||||||
let resolver = resolver.clone();
|
let resolver = RESOLVER.clone();
|
||||||
let responses = responses.clone();
|
let responses = responses.clone();
|
||||||
async move {
|
async move {
|
||||||
let response = resolver.srv_lookup(srv_domain).await.with_context(|| {
|
let response = resolver.srv_lookup(srv_domain).await.with_context(|| {
|
||||||
|
|||||||
@@ -60,7 +60,8 @@ pub async fn create_connector_by_url(
|
|||||||
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
|
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
|
||||||
let mut connector: Box<dyn TunnelConnector + 'static> = match url.scheme() {
|
let mut connector: Box<dyn TunnelConnector + 'static> = match url.scheme() {
|
||||||
"tcp" => {
|
"tcp" => {
|
||||||
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp", ip_version)?;
|
let dst_addr =
|
||||||
|
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp", ip_version).await?;
|
||||||
let mut connector = TcpTunnelConnector::new(url);
|
let mut connector = TcpTunnelConnector::new(url);
|
||||||
if global_ctx.config.get_flags().bind_device {
|
if global_ctx.config.get_flags().bind_device {
|
||||||
set_bind_addr_for_peer_connector(
|
set_bind_addr_for_peer_connector(
|
||||||
@@ -73,7 +74,8 @@ pub async fn create_connector_by_url(
|
|||||||
Box::new(connector)
|
Box::new(connector)
|
||||||
}
|
}
|
||||||
"udp" => {
|
"udp" => {
|
||||||
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp", ip_version)?;
|
let dst_addr =
|
||||||
|
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp", ip_version).await?;
|
||||||
let mut connector = UdpTunnelConnector::new(url);
|
let mut connector = UdpTunnelConnector::new(url);
|
||||||
if global_ctx.config.get_flags().bind_device {
|
if global_ctx.config.get_flags().bind_device {
|
||||||
set_bind_addr_for_peer_connector(
|
set_bind_addr_for_peer_connector(
|
||||||
@@ -90,14 +92,14 @@ pub async fn create_connector_by_url(
|
|||||||
Box::new(connector)
|
Box::new(connector)
|
||||||
}
|
}
|
||||||
"ring" => {
|
"ring" => {
|
||||||
check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring", IpVersion::Both)?;
|
check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring", IpVersion::Both).await?;
|
||||||
let connector = RingTunnelConnector::new(url);
|
let connector = RingTunnelConnector::new(url);
|
||||||
Box::new(connector)
|
Box::new(connector)
|
||||||
}
|
}
|
||||||
#[cfg(feature = "quic")]
|
#[cfg(feature = "quic")]
|
||||||
"quic" => {
|
"quic" => {
|
||||||
let dst_addr =
|
let dst_addr =
|
||||||
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "quic", ip_version)?;
|
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "quic", ip_version).await?;
|
||||||
let mut connector = QUICTunnelConnector::new(url);
|
let mut connector = QUICTunnelConnector::new(url);
|
||||||
if global_ctx.config.get_flags().bind_device {
|
if global_ctx.config.get_flags().bind_device {
|
||||||
set_bind_addr_for_peer_connector(
|
set_bind_addr_for_peer_connector(
|
||||||
@@ -111,7 +113,8 @@ pub async fn create_connector_by_url(
|
|||||||
}
|
}
|
||||||
#[cfg(feature = "wireguard")]
|
#[cfg(feature = "wireguard")]
|
||||||
"wg" => {
|
"wg" => {
|
||||||
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg", ip_version)?;
|
let dst_addr =
|
||||||
|
check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg", ip_version).await?;
|
||||||
let nid = global_ctx.get_network_identity();
|
let nid = global_ctx.get_network_identity();
|
||||||
let wg_config = WgConfig::new_from_network_identity(
|
let wg_config = WgConfig::new_from_network_identity(
|
||||||
&nid.network_name,
|
&nid.network_name,
|
||||||
@@ -131,7 +134,7 @@ pub async fn create_connector_by_url(
|
|||||||
#[cfg(feature = "websocket")]
|
#[cfg(feature = "websocket")]
|
||||||
"ws" | "wss" => {
|
"ws" | "wss" => {
|
||||||
use crate::tunnel::FromUrl;
|
use crate::tunnel::FromUrl;
|
||||||
let dst_addr = SocketAddr::from_url(url.clone(), ip_version)?;
|
let dst_addr = SocketAddr::from_url(url.clone(), ip_version).await?;
|
||||||
let mut connector = crate::tunnel::websocket::WSTunnelConnector::new(url);
|
let mut connector = crate::tunnel::websocket::WSTunnelConnector::new(url);
|
||||||
if global_ctx.config.get_flags().bind_device {
|
if global_ctx.config.get_flags().bind_device {
|
||||||
set_bind_addr_for_peer_connector(
|
set_bind_addr_for_peer_connector(
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ use tokio::net::{TcpListener, UdpSocket};
|
|||||||
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
|
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
|
||||||
use tokio::task::JoinSet;
|
use tokio::task::JoinSet;
|
||||||
|
|
||||||
use crate::common::stun::get_default_resolver_config;
|
use crate::common::dns::get_default_resolver_config;
|
||||||
|
|
||||||
use super::config::{GeneralConfig, Record, RunConfig};
|
use super::config::{GeneralConfig, Record, RunConfig};
|
||||||
|
|
||||||
|
|||||||
@@ -401,8 +401,8 @@ pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
|
|||||||
mut futures: FuturesUnordered<Fut>,
|
mut futures: FuturesUnordered<Fut>,
|
||||||
) -> Result<Ret, TunnelError>
|
) -> Result<Ret, TunnelError>
|
||||||
where
|
where
|
||||||
Fut: Future<Output = Result<Ret, E>> + Send + Sync,
|
Fut: Future<Output = Result<Ret, E>> + Send,
|
||||||
E: std::error::Error + Into<TunnelError> + Send + Sync + 'static,
|
E: std::error::Error + Into<TunnelError> + Send + 'static,
|
||||||
{
|
{
|
||||||
// return last error
|
// return last error
|
||||||
let mut last_err = None;
|
let mut last_err = None;
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ use std::fmt::Debug;
|
|||||||
|
|
||||||
use tokio::time::error::Elapsed;
|
use tokio::time::error::Elapsed;
|
||||||
|
|
||||||
|
use crate::common::dns::socket_addrs;
|
||||||
use crate::proto::common::TunnelInfo;
|
use crate::proto::common::TunnelInfo;
|
||||||
|
|
||||||
use self::packet_def::ZCPacket;
|
use self::packet_def::ZCPacket;
|
||||||
@@ -169,13 +170,14 @@ impl std::fmt::Debug for dyn TunnelListener {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
pub(crate) trait FromUrl {
|
pub(crate) trait FromUrl {
|
||||||
fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError>
|
async fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError>
|
||||||
where
|
where
|
||||||
Self: Sized;
|
Self: Sized;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn check_scheme_and_get_socket_addr_ext<T>(
|
pub(crate) async fn check_scheme_and_get_socket_addr_ext<T>(
|
||||||
url: &url::Url,
|
url: &url::Url,
|
||||||
scheme: &str,
|
scheme: &str,
|
||||||
ip_version: IpVersion,
|
ip_version: IpVersion,
|
||||||
@@ -187,10 +189,10 @@ where
|
|||||||
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
|
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(T::from_url(url.clone(), ip_version)?)
|
Ok(T::from_url(url.clone(), ip_version).await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn check_scheme_and_get_socket_addr<T>(
|
pub(crate) async fn check_scheme_and_get_socket_addr<T>(
|
||||||
url: &url::Url,
|
url: &url::Url,
|
||||||
scheme: &str,
|
scheme: &str,
|
||||||
ip_version: IpVersion,
|
ip_version: IpVersion,
|
||||||
@@ -202,7 +204,7 @@ where
|
|||||||
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
|
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(T::from_url(url.clone(), ip_version)?)
|
Ok(T::from_url(url.clone(), ip_version).await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_port(scheme: &str) -> Option<u16> {
|
fn default_port(scheme: &str) -> Option<u16> {
|
||||||
@@ -217,9 +219,17 @@ fn default_port(scheme: &str) -> Option<u16> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
impl FromUrl for SocketAddr {
|
impl FromUrl for SocketAddr {
|
||||||
fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError> {
|
async fn from_url(url: url::Url, ip_version: IpVersion) -> Result<Self, TunnelError> {
|
||||||
let addrs = url.socket_addrs(|| default_port(url.scheme()))?;
|
let addrs = socket_addrs(&url, || default_port(url.scheme()))
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
TunnelError::InvalidAddr(format!(
|
||||||
|
"failed to resolve socket addr, url: {}, error: {}",
|
||||||
|
url, e
|
||||||
|
))
|
||||||
|
})?;
|
||||||
tracing::debug!(?addrs, ?ip_version, ?url, "convert url to socket addrs");
|
tracing::debug!(?addrs, ?ip_version, ?url, "convert url to socket addrs");
|
||||||
let addrs = addrs
|
let addrs = addrs
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@@ -239,8 +249,9 @@ impl FromUrl for SocketAddr {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
impl FromUrl for uuid::Uuid {
|
impl FromUrl for uuid::Uuid {
|
||||||
fn from_url(url: url::Url, _ip_version: IpVersion) -> Result<Self, TunnelError> {
|
async fn from_url(url: url::Url, _ip_version: IpVersion) -> Result<Self, TunnelError> {
|
||||||
let o = url.host_str().unwrap();
|
let o = url.host_str().unwrap();
|
||||||
let o = uuid::Uuid::parse_str(o).map_err(|e| TunnelError::InvalidAddr(e.to_string()))?;
|
let o = uuid::Uuid::parse_str(o).map_err(|e| TunnelError::InvalidAddr(e.to_string()))?;
|
||||||
Ok(o)
|
Ok(o)
|
||||||
|
|||||||
@@ -85,7 +85,8 @@ impl QUICTunnelListener {
|
|||||||
impl TunnelListener for QUICTunnelListener {
|
impl TunnelListener for QUICTunnelListener {
|
||||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||||
let addr =
|
let addr =
|
||||||
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic", IpVersion::Both)?;
|
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic", IpVersion::Both)
|
||||||
|
.await?;
|
||||||
let (endpoint, server_cert) = make_server_endpoint(addr).unwrap();
|
let (endpoint, server_cert) = make_server_endpoint(addr).unwrap();
|
||||||
self.endpoint = Some(endpoint);
|
self.endpoint = Some(endpoint);
|
||||||
self.server_cert = Some(server_cert);
|
self.server_cert = Some(server_cert);
|
||||||
@@ -149,11 +150,9 @@ impl QUICTunnelConnector {
|
|||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl TunnelConnector for QUICTunnelConnector {
|
impl TunnelConnector for QUICTunnelConnector {
|
||||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
let addr = check_scheme_and_get_socket_addr_ext::<SocketAddr>(
|
let addr =
|
||||||
&self.addr,
|
check_scheme_and_get_socket_addr_ext::<SocketAddr>(&self.addr, "quic", self.ip_version)
|
||||||
"quic",
|
.await?;
|
||||||
self.ip_version,
|
|
||||||
)?;
|
|
||||||
let local_addr = if addr.is_ipv4() {
|
let local_addr = if addr.is_ipv4() {
|
||||||
"0.0.0.0:0"
|
"0.0.0.0:0"
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -230,12 +230,13 @@ fn get_tunnel_for_server(conn: Arc<Connection>) -> impl Tunnel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl RingTunnelListener {
|
impl RingTunnelListener {
|
||||||
fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
|
async fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
|
||||||
check_scheme_and_get_socket_addr::<Uuid>(
|
check_scheme_and_get_socket_addr::<Uuid>(
|
||||||
&self.listerner_addr,
|
&self.listerner_addr,
|
||||||
"ring",
|
"ring",
|
||||||
super::IpVersion::Both,
|
super::IpVersion::Both,
|
||||||
)
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -246,13 +247,13 @@ impl TunnelListener for RingTunnelListener {
|
|||||||
CONNECTION_MAP
|
CONNECTION_MAP
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
.await
|
||||||
.insert(self.get_addr()?, self.conn_sender.clone());
|
.insert(self.get_addr().await?, self.conn_sender.clone());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||||
tracing::info!("waiting accept new conn of key: {}", self.listerner_addr);
|
tracing::info!("waiting accept new conn of key: {}", self.listerner_addr);
|
||||||
let my_addr = self.get_addr()?;
|
let my_addr = self.get_addr().await?;
|
||||||
if let Some(conn) = self.conn_receiver.recv().await {
|
if let Some(conn) = self.conn_receiver.recv().await {
|
||||||
if conn.server.id == my_addr {
|
if conn.server.id == my_addr {
|
||||||
tracing::info!("accept new conn of key: {}", self.listerner_addr);
|
tracing::info!("accept new conn of key: {}", self.listerner_addr);
|
||||||
@@ -292,7 +293,8 @@ impl TunnelConnector for RingTunnelConnector {
|
|||||||
&self.remote_addr,
|
&self.remote_addr,
|
||||||
"ring",
|
"ring",
|
||||||
super::IpVersion::Both,
|
super::IpVersion::Both,
|
||||||
)?;
|
)
|
||||||
|
.await?;
|
||||||
let entry = CONNECTION_MAP
|
let entry = CONNECTION_MAP
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -59,7 +59,8 @@ impl TunnelListener for TcpTunnelListener {
|
|||||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||||
self.listener = None;
|
self.listener = None;
|
||||||
let addr =
|
let addr =
|
||||||
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp", IpVersion::Both)?;
|
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp", IpVersion::Both)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let socket2_socket = socket2::Socket::new(
|
let socket2_socket = socket2::Socket::new(
|
||||||
socket2::Domain::for_address(addr),
|
socket2::Domain::for_address(addr),
|
||||||
@@ -190,7 +191,8 @@ impl TcpTunnelConnector {
|
|||||||
impl super::TunnelConnector for TcpTunnelConnector {
|
impl super::TunnelConnector for TcpTunnelConnector {
|
||||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
let addr =
|
let addr =
|
||||||
check_scheme_and_get_socket_addr_ext::<SocketAddr>(&self.addr, "tcp", self.ip_version)?;
|
check_scheme_and_get_socket_addr_ext::<SocketAddr>(&self.addr, "tcp", self.ip_version)
|
||||||
|
.await?;
|
||||||
if self.bind_addrs.is_empty() {
|
if self.bind_addrs.is_empty() {
|
||||||
self.connect_with_default_bind(addr).await
|
self.connect_with_default_bind(addr).await
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -477,7 +477,8 @@ impl TunnelListener for UdpTunnelListener {
|
|||||||
&self.addr,
|
&self.addr,
|
||||||
"udp",
|
"udp",
|
||||||
IpVersion::Both,
|
IpVersion::Both,
|
||||||
)?;
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let socket2_socket = socket2::Socket::new(
|
let socket2_socket = socket2::Socket::new(
|
||||||
socket2::Domain::for_address(addr),
|
socket2::Domain::for_address(addr),
|
||||||
@@ -781,7 +782,8 @@ impl super::TunnelConnector for UdpTunnelConnector {
|
|||||||
&self.addr,
|
&self.addr,
|
||||||
"udp",
|
"udp",
|
||||||
self.ip_version,
|
self.ip_version,
|
||||||
)?;
|
)
|
||||||
|
.await?;
|
||||||
if self.bind_addrs.is_empty() || addr.is_ipv6() {
|
if self.bind_addrs.is_empty() || addr.is_ipv6() {
|
||||||
self.connect_with_default_bind(addr).await
|
self.connect_with_default_bind(addr).await
|
||||||
} else {
|
} else {
|
||||||
@@ -963,6 +965,7 @@ mod tests {
|
|||||||
"udp",
|
"udp",
|
||||||
IpVersion::Both,
|
IpVersion::Both,
|
||||||
)
|
)
|
||||||
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let socket2_socket = socket2::Socket::new(
|
let socket2_socket = socket2::Socket::new(
|
||||||
socket2::Domain::for_address(addr),
|
socket2::Domain::for_address(addr),
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ impl WSTunnelListener {
|
|||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl TunnelListener for WSTunnelListener {
|
impl TunnelListener for WSTunnelListener {
|
||||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||||
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both)?;
|
let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
|
||||||
let socket2_socket = socket2::Socket::new(
|
let socket2_socket = socket2::Socket::new(
|
||||||
socket2::Domain::for_address(addr),
|
socket2::Domain::for_address(addr),
|
||||||
socket2::Type::STREAM,
|
socket2::Type::STREAM,
|
||||||
@@ -182,7 +182,7 @@ impl WSTunnelConnector {
|
|||||||
tcp_socket: TcpSocket,
|
tcp_socket: TcpSocket,
|
||||||
) -> Result<Box<dyn Tunnel>, TunnelError> {
|
) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||||
let is_wss = is_wss(&addr)?;
|
let is_wss = is_wss(&addr)?;
|
||||||
let socket_addr = SocketAddr::from_url(addr.clone(), ip_version)?;
|
let socket_addr = SocketAddr::from_url(addr.clone(), ip_version).await?;
|
||||||
let domain = addr.domain();
|
let domain = addr.domain();
|
||||||
let host = socket_addr.ip();
|
let host = socket_addr.ip();
|
||||||
let stream = tcp_socket.connect(socket_addr).await?;
|
let stream = tcp_socket.connect(socket_addr).await?;
|
||||||
@@ -205,12 +205,8 @@ impl WSTunnelConnector {
|
|||||||
let tls_conn =
|
let tls_conn =
|
||||||
tokio_rustls::TlsConnector::from(Arc::new(get_insecure_tls_client_config()));
|
tokio_rustls::TlsConnector::from(Arc::new(get_insecure_tls_client_config()));
|
||||||
let domain_or_ip = match domain {
|
let domain_or_ip = match domain {
|
||||||
None => {
|
None => host.to_string(),
|
||||||
host.to_string()
|
Some(domain) => domain.to_string(),
|
||||||
}
|
|
||||||
Some(domain) => {
|
|
||||||
domain.to_string()
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
let stream = tls_conn
|
let stream = tls_conn
|
||||||
.connect(domain_or_ip.try_into().unwrap(), stream)
|
.connect(domain_or_ip.try_into().unwrap(), stream)
|
||||||
@@ -274,7 +270,7 @@ impl WSTunnelConnector {
|
|||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl TunnelConnector for WSTunnelConnector {
|
impl TunnelConnector for WSTunnelConnector {
|
||||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version)?;
|
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
|
||||||
if self.bind_addrs.is_empty() || addr.is_ipv6() {
|
if self.bind_addrs.is_empty() || addr.is_ipv6() {
|
||||||
self.connect_with_default_bind(addr).await
|
self.connect_with_default_bind(addr).await
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -548,7 +548,8 @@ impl WgTunnelListener {
|
|||||||
impl TunnelListener for WgTunnelListener {
|
impl TunnelListener for WgTunnelListener {
|
||||||
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||||
let addr =
|
let addr =
|
||||||
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg", IpVersion::Both)?;
|
check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg", IpVersion::Both)
|
||||||
|
.await?;
|
||||||
let socket2_socket = socket2::Socket::new(
|
let socket2_socket = socket2::Socket::new(
|
||||||
socket2::Domain::for_address(addr),
|
socket2::Domain::for_address(addr),
|
||||||
socket2::Type::DGRAM,
|
socket2::Type::DGRAM,
|
||||||
@@ -705,7 +706,8 @@ impl super::TunnelConnector for WgTunnelConnector {
|
|||||||
&self.addr,
|
&self.addr,
|
||||||
"wg",
|
"wg",
|
||||||
self.ip_version,
|
self.ip_version,
|
||||||
)?;
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
if addr.is_ipv6() {
|
if addr.is_ipv6() {
|
||||||
return self.connect_with_ipv6(addr).await;
|
return self.connect_with_ipv6(addr).await;
|
||||||
|
|||||||
Reference in New Issue
Block a user