mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 18:24:36 +00:00
tunnel(bind): gather all bind logic to a single function (#2070)
* extract a Bindable trait for binding TcpSocket, TcpListener, and UdpSocket
This commit is contained in:
+27
-23
@@ -4,9 +4,10 @@
|
||||
|
||||
use super::{FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener};
|
||||
use crate::common::global_ctx::ArcGlobalCtx;
|
||||
use crate::tunnel::common::bind;
|
||||
use crate::tunnel::{
|
||||
TunnelInfo,
|
||||
common::{FramedReader, FramedWriter, TunnelWrapper, setup_socket2},
|
||||
common::{FramedReader, FramedWriter, TunnelWrapper},
|
||||
};
|
||||
use anyhow::Context;
|
||||
use derivative::Derivative;
|
||||
@@ -19,6 +20,7 @@ use quinn::{
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::OnceLock;
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
use tokio::net::UdpSocket;
|
||||
|
||||
// region config
|
||||
pub fn transport_config() -> Arc<TransportConfig> {
|
||||
@@ -179,27 +181,28 @@ pub struct QuicEndpointManager {
|
||||
static QUIC_ENDPOINT_MANAGER: OnceLock<QuicEndpointManager> = OnceLock::new();
|
||||
|
||||
impl QuicEndpointManager {
|
||||
fn try_create(addr: SocketAddr, dual_stack: bool) -> std::io::Result<Endpoint> {
|
||||
let socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_socket2(&socket, &addr, addr.is_ipv6() && !dual_stack)
|
||||
.map_err(std::io::Error::other)?;
|
||||
let socket = std::net::UdpSocket::from(socket);
|
||||
let runtime = default_runtime().ok_or(std::io::Error::other("no async runtime found"))?;
|
||||
fn try_create(addr: SocketAddr, dual_stack: bool) -> Result<Endpoint, TunnelError> {
|
||||
let socket = bind::<UdpSocket>()
|
||||
.addr(addr)
|
||||
.only_v6(addr.is_ipv6() && !dual_stack)
|
||||
.call()?;
|
||||
let runtime = default_runtime().ok_or(TunnelError::InternalError(
|
||||
"no async runtime found".to_owned(),
|
||||
))?;
|
||||
let mut endpoint = Endpoint::new_with_abstract_socket(
|
||||
endpoint_config(),
|
||||
None,
|
||||
runtime.wrap_udp_socket(socket)?,
|
||||
runtime.wrap_udp_socket(socket.into_std()?)?,
|
||||
runtime,
|
||||
)?;
|
||||
endpoint.set_default_client_config(client_config());
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
fn create<F>(&self, mut selector: F) -> std::io::Result<(&RwPool<Endpoint>, Option<Endpoint>)>
|
||||
fn create<F>(
|
||||
&self,
|
||||
mut selector: F,
|
||||
) -> Result<(&RwPool<Endpoint>, Option<Endpoint>), TunnelError>
|
||||
where
|
||||
F: FnMut(&QuicEndpointManager) -> (&RwPool<Endpoint>, Option<(SocketAddr, bool)>),
|
||||
{
|
||||
@@ -210,10 +213,10 @@ impl QuicEndpointManager {
|
||||
};
|
||||
|
||||
let endpoint = Self::try_create(addr, dual_stack);
|
||||
if let Err(e) = endpoint.as_ref()
|
||||
if let Err(error) = endpoint.as_ref()
|
||||
&& dual_stack
|
||||
{
|
||||
tracing::warn!("create dual stack quic endpoint failed: {:?}", e);
|
||||
tracing::warn!(?error, "create dual stack quic endpoint failed");
|
||||
self.both.disable();
|
||||
self.ipv4.enable();
|
||||
self.ipv6.enable();
|
||||
@@ -263,7 +266,7 @@ impl QuicEndpointManager {
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `addr`: listen address
|
||||
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> std::io::Result<Endpoint> {
|
||||
fn server(global_ctx: &ArcGlobalCtx, addr: SocketAddr) -> Result<Endpoint, TunnelError> {
|
||||
let mgr = Self::load(global_ctx);
|
||||
|
||||
let (pool, endpoint) = mgr.create(|mgr| {
|
||||
@@ -289,7 +292,7 @@ impl QuicEndpointManager {
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `ip_version`: the IP version of the remote address
|
||||
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> std::io::Result<Endpoint> {
|
||||
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> Result<Endpoint, TunnelError> {
|
||||
let mgr = Self::load(global_ctx);
|
||||
|
||||
let (pool, endpoint) = mgr.create(|mgr| {
|
||||
@@ -318,7 +321,7 @@ impl QuicEndpointManager {
|
||||
async fn connect(
|
||||
global_ctx: &ArcGlobalCtx,
|
||||
addr: SocketAddr,
|
||||
) -> std::io::Result<(Endpoint, Connection)> {
|
||||
) -> Result<(Endpoint, Connection), TunnelError> {
|
||||
let ip_version = if addr.ip().is_ipv4() {
|
||||
IpVersion::V4
|
||||
} else {
|
||||
@@ -327,8 +330,9 @@ impl QuicEndpointManager {
|
||||
let endpoint = Self::client(global_ctx, ip_version)?;
|
||||
let connection = endpoint
|
||||
.connect(addr, "localhost")
|
||||
.map_err(std::io::Error::other)?
|
||||
.await?;
|
||||
.with_context(|| format!("failed to create connection to {}", addr))?
|
||||
.await
|
||||
.with_context(|| format!("failed to connect to {}", addr))?;
|
||||
|
||||
Ok((endpoint, connection))
|
||||
}
|
||||
@@ -585,10 +589,10 @@ mod tests {
|
||||
async fn invalid_peer_addr_impl() {
|
||||
let mut connector =
|
||||
QuicTunnelConnector::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx());
|
||||
let err = connector.connect().await.unwrap_err();
|
||||
let err = format!("{:?}", connector.connect().await.unwrap_err());
|
||||
assert!(
|
||||
err.to_string().contains("invalid remote address"),
|
||||
"unexpected error: {:?}",
|
||||
err.contains("invalid remote address"),
|
||||
"unexpected error: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user