mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 02:09:06 +00:00
Initial Version
This commit is contained in:
@@ -0,0 +1,105 @@
|
||||
[package]
|
||||
name = "easytier-core"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["easytier"]
|
||||
rust-version = "1.75"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[lib]
|
||||
name = "easytier_rpc"
|
||||
path = "src/rpc/lib.rs"
|
||||
|
||||
[dependencies]
|
||||
tracing = { version = "0.1", features = ["log"] }
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
tracing-appender = "0.2.3"
|
||||
log = "0.4"
|
||||
thiserror = "1.0"
|
||||
auto_impl = "1.1.0"
|
||||
crossbeam = "0.8.4"
|
||||
|
||||
gethostname = "0.4.3"
|
||||
|
||||
futures = "0.3"
|
||||
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio-stream = "0.1"
|
||||
tokio-util = { version = "0.7.9", features = ["codec", "net"] }
|
||||
|
||||
async-stream = "0.3.5"
|
||||
async-trait = "0.1.74"
|
||||
|
||||
dashmap = "5.5.3"
|
||||
timedmap = "1.0.1"
|
||||
|
||||
# for tap device
|
||||
tun = { version = "0.6.1", features = ["async"] }
|
||||
# for net ns
|
||||
nix = { version = "0.27", features = ["sched", "socket", "ioctl"] }
|
||||
|
||||
uuid = { version = "1.5.0", features = [
|
||||
"v4",
|
||||
"fast-rng",
|
||||
"macro-diagnostics",
|
||||
"serde",
|
||||
] }
|
||||
|
||||
# for ring tunnel
|
||||
crossbeam-queue = "0.3"
|
||||
once_cell = "1.18.0"
|
||||
|
||||
# for packet
|
||||
rkyv = { "version" = "0.7.42", features = ["validation", "archive_le"] }
|
||||
|
||||
# for rpc
|
||||
tonic = "0.10"
|
||||
prost = "0.12"
|
||||
anyhow = "1.0"
|
||||
tarpc = { version = "0.32", features = ["tokio1", "serde1"] }
|
||||
bincode = "1.3"
|
||||
|
||||
url = "2.5.0"
|
||||
|
||||
# for tun packet
|
||||
byteorder = "1.5.0"
|
||||
|
||||
# for proxy
|
||||
cidr = "0.2.2"
|
||||
socket2 = "0.5.5"
|
||||
|
||||
# for hole punching
|
||||
stun-format = { git = "https://github.com/KKRainbow/stun-format.git", features = [
|
||||
"fmt",
|
||||
"rfc3489",
|
||||
"iana",
|
||||
] }
|
||||
rand = "0.8.5"
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1.0"
|
||||
features = ["derive"]
|
||||
|
||||
[dependencies.pnet]
|
||||
version = "0.34.0"
|
||||
features = ["serde"]
|
||||
|
||||
[dependencies.clap]
|
||||
version = "4.4"
|
||||
features = ["derive"]
|
||||
|
||||
[dependencies.public-ip]
|
||||
version = "0.2"
|
||||
features = ["default"]
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.10"
|
||||
|
||||
[target.'cfg(windows)'.build-dependencies]
|
||||
reqwest = { version = "0.11", features = ["blocking"] }
|
||||
zip = "*"
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
serial_test = "*"
|
||||
@@ -0,0 +1,95 @@
|
||||
#[cfg(target_os = "windows")]
|
||||
use std::{
|
||||
env,
|
||||
fs::File,
|
||||
io::{copy, Cursor},
|
||||
path::PathBuf,
|
||||
};
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
struct WindowsBuild {}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
impl WindowsBuild {
|
||||
fn check_protoc_exist() -> Option<PathBuf> {
|
||||
let path = env::var_os("PROTOC").map(PathBuf::from);
|
||||
if path.is_some() && path.as_ref().unwrap().exists() {
|
||||
return path;
|
||||
}
|
||||
|
||||
let path = env::var_os("PATH").unwrap_or_default();
|
||||
for p in env::split_paths(&path) {
|
||||
let p = p.join("protoc");
|
||||
if p.exists() {
|
||||
return Some(p);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn get_cargo_target_dir() -> Result<std::path::PathBuf, Box<dyn std::error::Error>> {
|
||||
let out_dir = std::path::PathBuf::from(std::env::var("OUT_DIR")?);
|
||||
let profile = std::env::var("PROFILE")?;
|
||||
let mut target_dir = None;
|
||||
let mut sub_path = out_dir.as_path();
|
||||
while let Some(parent) = sub_path.parent() {
|
||||
if parent.ends_with(&profile) {
|
||||
target_dir = Some(parent);
|
||||
break;
|
||||
}
|
||||
sub_path = parent;
|
||||
}
|
||||
let target_dir = target_dir.ok_or("not found")?;
|
||||
Ok(target_dir.to_path_buf())
|
||||
}
|
||||
|
||||
fn download_protoc() -> PathBuf {
|
||||
println!("cargo:info=use exist protoc: {:?}", "k");
|
||||
let out_dir = Self::get_cargo_target_dir().unwrap();
|
||||
let fname = out_dir.join("protoc");
|
||||
if fname.exists() {
|
||||
println!("cargo:info=use exist protoc: {:?}", fname);
|
||||
return fname;
|
||||
}
|
||||
|
||||
println!("cargo:info=need download protoc, please wait...");
|
||||
|
||||
let url = "https://github.com/protocolbuffers/protobuf/releases/download/v26.0-rc1/protoc-26.0-rc-1-win64.zip";
|
||||
let response = reqwest::blocking::get(url).unwrap();
|
||||
println!("{:?}", response);
|
||||
let mut content = response
|
||||
.bytes()
|
||||
.map(|v| v.to_vec())
|
||||
.map(Cursor::new)
|
||||
.map(zip::ZipArchive::new)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let protoc_zipped_file = content.by_name("bin/protoc.exe").unwrap();
|
||||
let mut content = protoc_zipped_file;
|
||||
|
||||
copy(&mut content, &mut File::create(&fname).unwrap()).unwrap();
|
||||
|
||||
fname
|
||||
}
|
||||
|
||||
pub fn check_for_win() {
|
||||
// add third_party dir to link search path
|
||||
println!("cargo:rustc-link-search=native=third_party/");
|
||||
let protoc_path = if let Some(o) = Self::check_protoc_exist() {
|
||||
println!("cargo:info=use os exist protoc: {:?}", o);
|
||||
o
|
||||
} else {
|
||||
Self::download_protoc()
|
||||
};
|
||||
std::env::set_var("PROTOC", protoc_path);
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
#[cfg(target_os = "windows")]
|
||||
WindowsBuild::check_for_win();
|
||||
|
||||
tonic_build::compile_protos("proto/cli.proto")?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
syntax = "proto3";
|
||||
package cli;
|
||||
|
||||
message Status {
|
||||
int32 code = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message PeerConnStats {
|
||||
uint64 rx_bytes = 1;
|
||||
uint64 tx_bytes = 2;
|
||||
|
||||
uint64 rx_packets = 3;
|
||||
uint64 tx_packets = 4;
|
||||
|
||||
uint64 latency_us = 5;
|
||||
}
|
||||
|
||||
message TunnelInfo {
|
||||
string tunnel_type = 1;
|
||||
string local_addr = 2;
|
||||
string remote_addr = 3;
|
||||
}
|
||||
|
||||
message PeerConnInfo {
|
||||
string conn_id = 1;
|
||||
string my_node_id = 2;
|
||||
string peer_id = 3;
|
||||
repeated string features = 4;
|
||||
TunnelInfo tunnel = 5;
|
||||
PeerConnStats stats = 6;
|
||||
}
|
||||
|
||||
message PeerInfo {
|
||||
string peer_id = 1;
|
||||
repeated PeerConnInfo conns = 2;
|
||||
}
|
||||
|
||||
message ListPeerRequest {}
|
||||
|
||||
message ListPeerResponse {
|
||||
repeated PeerInfo peer_infos = 1;
|
||||
}
|
||||
|
||||
enum NatType {
|
||||
// has NAT; but own a single public IP, port is not changed
|
||||
Unknown = 0;
|
||||
OpenInternet = 1;
|
||||
NoPAT = 2;
|
||||
FullCone = 3;
|
||||
Restricted = 4;
|
||||
PortRestricted = 5;
|
||||
Symmetric = 6;
|
||||
SymUdpFirewall = 7;
|
||||
}
|
||||
|
||||
message StunInfo {
|
||||
NatType udp_nat_type = 1;
|
||||
NatType tcp_nat_type = 2;
|
||||
int64 last_update_time = 3;
|
||||
}
|
||||
|
||||
message Route {
|
||||
string peer_id = 1;
|
||||
string ipv4_addr = 2;
|
||||
string next_hop_peer_id = 3;
|
||||
int32 cost = 4;
|
||||
repeated string proxy_cidrs = 5;
|
||||
string hostname = 6;
|
||||
StunInfo stun_info = 7;
|
||||
}
|
||||
|
||||
message ListRouteRequest {}
|
||||
|
||||
message ListRouteResponse {
|
||||
repeated Route routes = 1;
|
||||
}
|
||||
|
||||
service PeerManageRpc {
|
||||
rpc ListPeer (ListPeerRequest) returns (ListPeerResponse);
|
||||
rpc ListRoute (ListRouteRequest) returns (ListRouteResponse);
|
||||
}
|
||||
|
||||
enum ConnectorStatus {
|
||||
CONNECTED = 0;
|
||||
DISCONNECTED = 1;
|
||||
CONNECTING = 2;
|
||||
}
|
||||
|
||||
message Connector {
|
||||
string url = 1;
|
||||
ConnectorStatus status = 2;
|
||||
}
|
||||
|
||||
message ListConnectorRequest {}
|
||||
|
||||
message ListConnectorResponse {
|
||||
repeated Connector connectors = 1;
|
||||
}
|
||||
|
||||
enum ConnectorManageAction {
|
||||
ADD = 0;
|
||||
REMOVE = 1;
|
||||
}
|
||||
|
||||
message ManageConnectorRequest {
|
||||
ConnectorManageAction action = 1;
|
||||
string url = 2;
|
||||
}
|
||||
|
||||
message ManageConnectorResponse { }
|
||||
|
||||
service ConnectorManageRpc {
|
||||
rpc ListConnector (ListConnectorRequest) returns (ListConnectorResponse);
|
||||
rpc ManageConnector (ManageConnectorRequest) returns (ManageConnectorResponse);
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
// use filesystem as a config store
|
||||
|
||||
use std::{
|
||||
ffi::OsStr,
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
static DEFAULT_BASE_DIR: &str = "/var/lib/easytier";
|
||||
static DIR_ROOT_CONFIG_FILE_NAME: &str = "__root__";
|
||||
|
||||
pub struct ConfigFs {
|
||||
_db_name: String,
|
||||
db_path: PathBuf,
|
||||
}
|
||||
|
||||
impl ConfigFs {
|
||||
pub fn new(db_name: &str) -> Self {
|
||||
Self::new_with_dir(db_name, DEFAULT_BASE_DIR)
|
||||
}
|
||||
|
||||
pub fn new_with_dir(db_name: &str, dir: &str) -> Self {
|
||||
let p = Path::new(OsStr::new(dir)).join(OsStr::new(db_name));
|
||||
std::fs::create_dir_all(&p).unwrap();
|
||||
ConfigFs {
|
||||
_db_name: db_name.to_string(),
|
||||
db_path: p,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(&self, key: &str) -> Result<String, std::io::Error> {
|
||||
let path = self.db_path.join(OsStr::new(key));
|
||||
// if path is dir, read the DIR_ROOT_CONFIG_FILE_NAME in it
|
||||
if path.is_dir() {
|
||||
let path = path.join(OsStr::new(DIR_ROOT_CONFIG_FILE_NAME));
|
||||
std::fs::read_to_string(path)
|
||||
} else if path.is_file() {
|
||||
return std::fs::read_to_string(path);
|
||||
} else {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
"key not found",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list_keys(&self, key: &str) -> Result<Vec<String>, std::io::Error> {
|
||||
let path = self.db_path.join(OsStr::new(key));
|
||||
let mut keys = Vec::new();
|
||||
for entry in std::fs::read_dir(path)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
let key = path.file_name().unwrap().to_str().unwrap().to_string();
|
||||
if key != DIR_ROOT_CONFIG_FILE_NAME {
|
||||
keys.push(key);
|
||||
}
|
||||
}
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn remove(&self, key: &str) -> Result<(), std::io::Error> {
|
||||
let path = self.db_path.join(OsStr::new(key));
|
||||
// if path is dir, remove the DIR_ROOT_CONFIG_FILE_NAME in it
|
||||
if path.is_dir() {
|
||||
std::fs::remove_dir_all(path)
|
||||
} else if path.is_file() {
|
||||
return std::fs::remove_file(path);
|
||||
} else {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
"key not found",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_dir(&self, key: &str) -> Result<std::fs::File, std::io::Error> {
|
||||
let path = self.db_path.join(OsStr::new(key));
|
||||
// if path is dir, write the DIR_ROOT_CONFIG_FILE_NAME in it
|
||||
if path.is_file() {
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::AlreadyExists,
|
||||
"key already exists",
|
||||
))
|
||||
} else {
|
||||
std::fs::create_dir_all(&path)?;
|
||||
return std::fs::File::create(path.join(OsStr::new(DIR_ROOT_CONFIG_FILE_NAME)));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_file(&self, key: &str) -> Result<std::fs::File, std::io::Error> {
|
||||
let path = self.db_path.join(OsStr::new(key));
|
||||
let base_dir = path.parent().unwrap();
|
||||
if !path.is_file() {
|
||||
std::fs::create_dir_all(base_dir)?;
|
||||
}
|
||||
std::fs::File::create(path)
|
||||
}
|
||||
|
||||
pub fn get_or_add<F>(
|
||||
&self,
|
||||
key: &str,
|
||||
val_fn: F,
|
||||
add_dir: bool,
|
||||
) -> Result<String, std::io::Error>
|
||||
where
|
||||
F: FnOnce() -> String,
|
||||
{
|
||||
let get_ret = self.get(key);
|
||||
match get_ret {
|
||||
Ok(v) => Ok(v),
|
||||
Err(e) => {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
let val = val_fn();
|
||||
if add_dir {
|
||||
let mut f = self.add_dir(key)?;
|
||||
f.write_all(val.as_bytes())?;
|
||||
} else {
|
||||
let mut f = self.add_file(key)?;
|
||||
f.write_all(val.as_bytes())?;
|
||||
}
|
||||
Ok(val)
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn get_or_add_dir<F>(&self, key: &str, val_fn: F) -> Result<String, std::io::Error>
|
||||
where
|
||||
F: FnOnce() -> String,
|
||||
{
|
||||
self.get_or_add(key, val_fn, true)
|
||||
}
|
||||
|
||||
pub fn get_or_add_file<F>(&self, key: &str, val_fn: F) -> Result<String, std::io::Error>
|
||||
where
|
||||
F: FnOnce() -> String,
|
||||
{
|
||||
self.get_or_add(key, val_fn, false)
|
||||
}
|
||||
|
||||
pub fn get_or_default<F>(&self, key: &str, default: F) -> Result<String, std::io::Error>
|
||||
where
|
||||
F: FnOnce() -> String,
|
||||
{
|
||||
let get_ret = self.get(key);
|
||||
match get_ret {
|
||||
Ok(v) => Ok(v),
|
||||
Err(e) => {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
Ok(default())
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
|
||||
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 60;
|
||||
pub const DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC: u64 = 60;
|
||||
|
||||
macro_rules! define_global_var {
|
||||
($name:ident, $type:ty, $init:expr) => {
|
||||
pub static $name: once_cell::sync::Lazy<tokio::sync::Mutex<$type>> =
|
||||
once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new($init));
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! use_global_var {
|
||||
($name:ident) => {
|
||||
crate::common::constants::$name.lock().await.to_owned()
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! set_global_var {
|
||||
($name:ident, $val:expr) => {
|
||||
*crate::common::constants::$name.lock().await = $val
|
||||
};
|
||||
}
|
||||
|
||||
define_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS, u64, 1000);
|
||||
|
||||
pub const UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID: u32 = 2;
|
||||
@@ -0,0 +1,43 @@
|
||||
use std::{io, result};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::tunnels;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("io error")]
|
||||
IOError(#[from] io::Error),
|
||||
#[error("rust tun error {0}")]
|
||||
TunError(#[from] tun::Error),
|
||||
#[error("tunnel error {0}")]
|
||||
TunnelError(#[from] tunnels::TunnelError),
|
||||
#[error("Peer has no conn, PeerId: {0}")]
|
||||
PeerNoConnectionError(uuid::Uuid),
|
||||
#[error("RouteError: {0}")]
|
||||
RouteError(String),
|
||||
#[error("Not found")]
|
||||
NotFound,
|
||||
#[error("Invalid Url: {0}")]
|
||||
InvalidUrl(String),
|
||||
#[error("Shell Command error: {0}")]
|
||||
ShellCommandError(String),
|
||||
// #[error("Rpc listen error: {0}")]
|
||||
// RpcListenError(String),
|
||||
#[error("Rpc connect error: {0}")]
|
||||
RpcConnectError(String),
|
||||
#[error("Rpc error: {0}")]
|
||||
RpcClientError(#[from] tarpc::client::RpcError),
|
||||
#[error("Timeout error: {0}")]
|
||||
Timeout(#[from] tokio::time::error::Elapsed),
|
||||
#[error("url in blacklist")]
|
||||
UrlInBlacklist,
|
||||
#[error("unknown data store error")]
|
||||
Unknown,
|
||||
#[error("anyhow error: {0}")]
|
||||
AnyhowError(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
pub type Result<T> = result::Result<T, Error>;
|
||||
|
||||
// impl From for std::
|
||||
@@ -0,0 +1,259 @@
|
||||
use std::{io::Write, sync::Arc};
|
||||
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use easytier_rpc::PeerConnInfo;
|
||||
|
||||
use super::{
|
||||
config_fs::ConfigFs,
|
||||
netns::NetNS,
|
||||
network::IPCollector,
|
||||
stun::{StunInfoCollector, StunInfoCollectorTrait},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum GlobalCtxEvent {
|
||||
PeerAdded,
|
||||
PeerRemoved,
|
||||
PeerConnAdded(PeerConnInfo),
|
||||
PeerConnRemoved(PeerConnInfo),
|
||||
}
|
||||
|
||||
type EventBus = tokio::sync::broadcast::Sender<GlobalCtxEvent>;
|
||||
type EventBusSubscriber = tokio::sync::broadcast::Receiver<GlobalCtxEvent>;
|
||||
|
||||
pub struct GlobalCtx {
|
||||
pub inst_name: String,
|
||||
pub id: uuid::Uuid,
|
||||
pub config_fs: ConfigFs,
|
||||
pub net_ns: NetNS,
|
||||
|
||||
event_bus: EventBus,
|
||||
|
||||
cached_ipv4: AtomicCell<Option<std::net::Ipv4Addr>>,
|
||||
cached_proxy_cidrs: AtomicCell<Option<Vec<cidr::IpCidr>>>,
|
||||
|
||||
ip_collector: Arc<IPCollector>,
|
||||
|
||||
hotname: AtomicCell<Option<String>>,
|
||||
|
||||
stun_info_collection: Box<dyn StunInfoCollectorTrait>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for GlobalCtx {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GlobalCtx")
|
||||
.field("inst_name", &self.inst_name)
|
||||
.field("id", &self.id)
|
||||
.field("net_ns", &self.net_ns.name())
|
||||
.field("event_bus", &"EventBus")
|
||||
.field("ipv4", &self.cached_ipv4)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub type ArcGlobalCtx = std::sync::Arc<GlobalCtx>;
|
||||
|
||||
impl GlobalCtx {
|
||||
pub fn new(inst_name: &str, config_fs: ConfigFs, net_ns: NetNS) -> Self {
|
||||
let id = config_fs
|
||||
.get_or_add_file("inst_id", || uuid::Uuid::new_v4().to_string())
|
||||
.unwrap();
|
||||
let id = uuid::Uuid::parse_str(&id).unwrap();
|
||||
|
||||
let (event_bus, _) = tokio::sync::broadcast::channel(100);
|
||||
|
||||
// NOTICE: we may need to choose stun stun server based on geo location
|
||||
// stun server cross nation may return a external ip address with high latency and loss rate
|
||||
let default_stun_servers = vec![
|
||||
"stun.miwifi.com:3478".to_string(),
|
||||
"stun.qq.com:3478".to_string(),
|
||||
"stun.chat.bilibili.com:3478".to_string(),
|
||||
"fwa.lifesizecloud.com:3478".to_string(),
|
||||
"stun.isp.net.au:3478".to_string(),
|
||||
"stun.nextcloud.com:3478".to_string(),
|
||||
"stun.freeswitch.org:3478".to_string(),
|
||||
"stun.voip.blackberry.com:3478".to_string(),
|
||||
"stunserver.stunprotocol.org:3478".to_string(),
|
||||
"stun.sipnet.com:3478".to_string(),
|
||||
"stun.radiojar.com:3478".to_string(),
|
||||
"stun.sonetel.com:3478".to_string(),
|
||||
"stun.voipgate.com:3478".to_string(),
|
||||
"stun.counterpath.com:3478".to_string(),
|
||||
"180.235.108.91:3478".to_string(),
|
||||
"193.22.2.248:3478".to_string(),
|
||||
];
|
||||
|
||||
GlobalCtx {
|
||||
inst_name: inst_name.to_string(),
|
||||
id,
|
||||
config_fs,
|
||||
net_ns: net_ns.clone(),
|
||||
event_bus,
|
||||
cached_ipv4: AtomicCell::new(None),
|
||||
cached_proxy_cidrs: AtomicCell::new(None),
|
||||
|
||||
ip_collector: Arc::new(IPCollector::new(net_ns)),
|
||||
|
||||
hotname: AtomicCell::new(None),
|
||||
|
||||
stun_info_collection: Box::new(StunInfoCollector::new(default_stun_servers)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn subscribe(&self) -> EventBusSubscriber {
|
||||
self.event_bus.subscribe()
|
||||
}
|
||||
|
||||
pub fn issue_event(&self, event: GlobalCtxEvent) {
|
||||
if self.event_bus.receiver_count() != 0 {
|
||||
self.event_bus.send(event).unwrap();
|
||||
} else {
|
||||
log::warn!("No subscriber for event: {:?}", event);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_ipv4(&self) -> Option<std::net::Ipv4Addr> {
|
||||
if let Some(ret) = self.cached_ipv4.load() {
|
||||
return Some(ret);
|
||||
}
|
||||
|
||||
let Ok(addr) = self.config_fs.get("ipv4") else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let Ok(addr) = addr.parse() else {
|
||||
tracing::error!("invalid ipv4 addr: {}", addr);
|
||||
return None;
|
||||
};
|
||||
|
||||
self.cached_ipv4.store(Some(addr));
|
||||
return Some(addr);
|
||||
}
|
||||
|
||||
pub fn set_ipv4(&mut self, addr: std::net::Ipv4Addr) {
|
||||
self.config_fs
|
||||
.add_file("ipv4")
|
||||
.unwrap()
|
||||
.write_all(addr.to_string().as_bytes())
|
||||
.unwrap();
|
||||
|
||||
self.cached_ipv4.store(None);
|
||||
}
|
||||
|
||||
pub fn add_proxy_cidr(&self, cidr: cidr::IpCidr) -> Result<(), std::io::Error> {
|
||||
let escaped_cidr = cidr.to_string().replace("/", "_");
|
||||
self.config_fs
|
||||
.add_file(&format!("proxy_cidrs/{}", escaped_cidr))?;
|
||||
self.cached_proxy_cidrs.store(None);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn remove_proxy_cidr(&self, cidr: cidr::IpCidr) -> Result<(), std::io::Error> {
|
||||
let escaped_cidr = cidr.to_string().replace("/", "_");
|
||||
self.config_fs
|
||||
.remove(&format!("proxy_cidrs/{}", escaped_cidr))?;
|
||||
self.cached_proxy_cidrs.store(None);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_proxy_cidrs(&self) -> Vec<cidr::IpCidr> {
|
||||
if let Some(proxy_cidrs) = self.cached_proxy_cidrs.take() {
|
||||
self.cached_proxy_cidrs.store(Some(proxy_cidrs.clone()));
|
||||
return proxy_cidrs;
|
||||
}
|
||||
|
||||
let Ok(keys) = self.config_fs.list_keys("proxy_cidrs") else {
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let mut ret = Vec::new();
|
||||
for key in keys.iter() {
|
||||
let key = key.replace("_", "/");
|
||||
let Ok(cidr) = key.parse() else {
|
||||
tracing::error!("invalid proxy cidr: {}", key);
|
||||
continue;
|
||||
};
|
||||
ret.push(cidr);
|
||||
}
|
||||
|
||||
self.cached_proxy_cidrs.store(Some(ret.clone()));
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn get_ip_collector(&self) -> Arc<IPCollector> {
|
||||
self.ip_collector.clone()
|
||||
}
|
||||
|
||||
pub fn get_hostname(&self) -> Option<String> {
|
||||
if let Some(hostname) = self.hotname.take() {
|
||||
self.hotname.store(Some(hostname.clone()));
|
||||
return Some(hostname);
|
||||
}
|
||||
|
||||
let hostname = gethostname::gethostname().to_string_lossy().to_string();
|
||||
self.hotname.store(Some(hostname.clone()));
|
||||
return Some(hostname);
|
||||
}
|
||||
|
||||
pub fn get_stun_info_collector(&self) -> impl StunInfoCollectorTrait + '_ {
|
||||
self.stun_info_collection.as_ref()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn replace_stun_info_collector(&self, collector: Box<dyn StunInfoCollectorTrait>) {
|
||||
// force replace the stun_info_collection without mut and drop the old one
|
||||
let ptr = &self.stun_info_collection as *const Box<dyn StunInfoCollectorTrait>;
|
||||
let ptr = ptr as *mut Box<dyn StunInfoCollectorTrait>;
|
||||
unsafe {
|
||||
std::ptr::drop_in_place(ptr);
|
||||
std::ptr::write(ptr, collector);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_id(&self) -> uuid::Uuid {
|
||||
self.id
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_global_ctx() {
|
||||
let config_fs = ConfigFs::new("/tmp/easytier");
|
||||
let net_ns = NetNS::new(None);
|
||||
let global_ctx = GlobalCtx::new("test", config_fs, net_ns);
|
||||
|
||||
let mut subscriber = global_ctx.subscribe();
|
||||
global_ctx.issue_event(GlobalCtxEvent::PeerAdded);
|
||||
global_ctx.issue_event(GlobalCtxEvent::PeerRemoved);
|
||||
global_ctx.issue_event(GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default()));
|
||||
global_ctx.issue_event(GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default()));
|
||||
|
||||
assert_eq!(subscriber.recv().await.unwrap(), GlobalCtxEvent::PeerAdded);
|
||||
assert_eq!(
|
||||
subscriber.recv().await.unwrap(),
|
||||
GlobalCtxEvent::PeerRemoved
|
||||
);
|
||||
assert_eq!(
|
||||
subscriber.recv().await.unwrap(),
|
||||
GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default())
|
||||
);
|
||||
assert_eq!(
|
||||
subscriber.recv().await.unwrap(),
|
||||
GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default())
|
||||
);
|
||||
}
|
||||
|
||||
pub fn get_mock_global_ctx() -> ArcGlobalCtx {
|
||||
let node_id = uuid::Uuid::new_v4();
|
||||
let config_fs = ConfigFs::new_with_dir(node_id.to_string().as_str(), "/tmp/easytier");
|
||||
let net_ns = NetNS::new(None);
|
||||
std::sync::Arc::new(GlobalCtx::new(
|
||||
format!("test_{}", node_id).as_str(),
|
||||
config_fs,
|
||||
net_ns,
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,312 @@
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::process::Command;
|
||||
|
||||
use super::error::Error;
|
||||
|
||||
#[async_trait]
|
||||
pub trait IfConfiguerTrait {
|
||||
async fn add_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error>;
|
||||
async fn remove_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error>;
|
||||
async fn add_ipv4_ip(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error>;
|
||||
async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error>;
|
||||
async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error>;
|
||||
async fn wait_interface_show(&self, _name: &str) -> Result<(), Error> {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
fn cidr_to_subnet_mask(prefix_length: u8) -> Ipv4Addr {
|
||||
if prefix_length > 32 {
|
||||
panic!("Invalid CIDR prefix length");
|
||||
}
|
||||
|
||||
let subnet_mask: u32 = (!0u32)
|
||||
.checked_shl(32 - u32::from(prefix_length))
|
||||
.unwrap_or(0);
|
||||
Ipv4Addr::new(
|
||||
((subnet_mask >> 24) & 0xFF) as u8,
|
||||
((subnet_mask >> 16) & 0xFF) as u8,
|
||||
((subnet_mask >> 8) & 0xFF) as u8,
|
||||
(subnet_mask & 0xFF) as u8,
|
||||
)
|
||||
}
|
||||
|
||||
async fn run_shell_cmd(cmd: &str) -> Result<(), Error> {
|
||||
let cmd_out = if cfg!(target_os = "windows") {
|
||||
Command::new("cmd").arg("/C").arg(cmd).output().await?
|
||||
} else {
|
||||
Command::new("sh").arg("-c").arg(cmd).output().await?
|
||||
};
|
||||
let stdout = String::from_utf8_lossy(cmd_out.stdout.as_slice());
|
||||
let stderr = String::from_utf8_lossy(cmd_out.stderr.as_slice());
|
||||
let ec = cmd_out.status.code();
|
||||
let succ = cmd_out.status.success();
|
||||
tracing::info!(?cmd, ?ec, ?succ, ?stdout, ?stderr, "run shell cmd");
|
||||
|
||||
if !cmd_out.status.success() {
|
||||
return Err(Error::ShellCommandError(
|
||||
stdout.to_string() + &stderr.to_string(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct MacIfConfiger {}
|
||||
#[async_trait]
|
||||
impl IfConfiguerTrait for MacIfConfiger {
|
||||
async fn add_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"route -n add {} -netmask {} -interface {} -hopcount 7",
|
||||
address,
|
||||
cidr_to_subnet_mask(cidr_prefix),
|
||||
name
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn remove_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"route -n delete {} -netmask {} -interface {}",
|
||||
address,
|
||||
cidr_to_subnet_mask(cidr_prefix),
|
||||
name
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn add_ipv4_ip(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"ifconfig {} {:?}/{:?} 10.8.8.8 up",
|
||||
name, address, cidr_prefix,
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
|
||||
run_shell_cmd(format!("ifconfig {} {}", name, if up { "up" } else { "down" }).as_str())
|
||||
.await
|
||||
}
|
||||
|
||||
async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
|
||||
if ip.is_none() {
|
||||
run_shell_cmd(format!("ifconfig {} inet delete", name).as_str()).await
|
||||
} else {
|
||||
run_shell_cmd(
|
||||
format!("ifconfig {} inet {} delete", name, ip.unwrap().to_string()).as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LinuxIfConfiger {}
|
||||
#[async_trait]
|
||||
impl IfConfiguerTrait for LinuxIfConfiger {
|
||||
async fn add_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"ip route add {}/{} dev {} metric 65535",
|
||||
address, cidr_prefix, name
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn remove_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(format!("ip route del {}/{} dev {}", address, cidr_prefix, name).as_str())
|
||||
.await
|
||||
}
|
||||
|
||||
async fn add_ipv4_ip(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(format!("ip addr add {:?}/{:?} dev {}", address, cidr_prefix, name).as_str())
|
||||
.await
|
||||
}
|
||||
|
||||
async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
|
||||
run_shell_cmd(format!("ip link set {} {}", name, if up { "up" } else { "down" }).as_str())
|
||||
.await
|
||||
}
|
||||
|
||||
async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
|
||||
if ip.is_none() {
|
||||
run_shell_cmd(format!("ip addr flush dev {}", name).as_str()).await
|
||||
} else {
|
||||
run_shell_cmd(
|
||||
format!("ip addr del {:?} dev {}", ip.unwrap().to_string(), name).as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WindowsIfConfiger {}
|
||||
|
||||
#[async_trait]
|
||||
impl IfConfiguerTrait for WindowsIfConfiger {
|
||||
async fn add_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"route add {} mask {} {}",
|
||||
address,
|
||||
cidr_to_subnet_mask(cidr_prefix),
|
||||
name
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn remove_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"route delete {} mask {} {}",
|
||||
address,
|
||||
cidr_to_subnet_mask(cidr_prefix),
|
||||
name
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn add_ipv4_ip(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"netsh interface ipv4 add address {} address={} mask={}",
|
||||
name,
|
||||
address,
|
||||
cidr_to_subnet_mask(cidr_prefix)
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"netsh interface set interface {} {}",
|
||||
name,
|
||||
if up { "enable" } else { "disable" }
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
|
||||
if ip.is_none() {
|
||||
run_shell_cmd(format!("netsh interface ipv4 delete address {}", name).as_str()).await
|
||||
} else {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"netsh interface ipv4 delete address {} address={}",
|
||||
name,
|
||||
ip.unwrap().to_string()
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_interface_show(&self, name: &str) -> Result<(), Error> {
|
||||
Ok(
|
||||
tokio::time::timeout(std::time::Duration::from_secs(10), async move {
|
||||
loop {
|
||||
let Ok(_) = run_shell_cmd(
|
||||
format!("netsh interface ipv4 show interfaces {}", name).as_str(),
|
||||
)
|
||||
.await
|
||||
else {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
continue;
|
||||
};
|
||||
break;
|
||||
}
|
||||
})
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub type IfConfiger = MacIfConfiger;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
pub type IfConfiger = LinuxIfConfiger;
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
pub type IfConfiger = WindowsIfConfiger;
|
||||
@@ -0,0 +1,9 @@
|
||||
pub mod config_fs;
|
||||
pub mod constants;
|
||||
pub mod error;
|
||||
pub mod global_ctx;
|
||||
pub mod ifcfg;
|
||||
pub mod netns;
|
||||
pub mod network;
|
||||
pub mod rkyv_util;
|
||||
pub mod stun;
|
||||
@@ -0,0 +1,118 @@
|
||||
use futures::Future;
|
||||
use once_cell::sync::Lazy;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
use nix::sched::{setns, CloneFlags};
|
||||
#[cfg(target_os = "linux")]
|
||||
use std::os::fd::AsFd;
|
||||
|
||||
pub struct NetNSGuard {
|
||||
#[cfg(target_os = "linux")]
|
||||
old_ns: Option<std::fs::File>,
|
||||
}
|
||||
|
||||
type NetNSLock = Mutex<()>;
|
||||
static LOCK: Lazy<NetNSLock> = Lazy::new(|| Mutex::new(()));
|
||||
pub static ROOT_NETNS_NAME: &str = "_root_ns";
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
impl NetNSGuard {
|
||||
pub fn new(ns: Option<String>) -> Box<Self> {
|
||||
let old_ns = if ns.is_some() {
|
||||
let old_ns = if cfg!(target_os = "linux") {
|
||||
Some(std::fs::File::open("/proc/self/ns/net").unwrap())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self::switch_ns(ns);
|
||||
old_ns
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Box::new(NetNSGuard { old_ns })
|
||||
}
|
||||
|
||||
fn switch_ns(name: Option<String>) {
|
||||
if name.is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
let ns_path: String;
|
||||
let name = name.unwrap();
|
||||
if name == ROOT_NETNS_NAME {
|
||||
ns_path = "/proc/1/ns/net".to_string();
|
||||
} else {
|
||||
ns_path = format!("/var/run/netns/{}", name);
|
||||
}
|
||||
|
||||
let ns = std::fs::File::open(ns_path).unwrap();
|
||||
log::info!(
|
||||
"[INIT NS] switching to new ns_name: {:?}, ns_file: {:?}",
|
||||
name,
|
||||
ns
|
||||
);
|
||||
|
||||
setns(ns.as_fd(), CloneFlags::CLONE_NEWNET).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
impl Drop for NetNSGuard {
|
||||
fn drop(&mut self) {
|
||||
if self.old_ns.is_none() {
|
||||
return;
|
||||
}
|
||||
log::info!("[INIT NS] switching back to old ns, ns: {:?}", self.old_ns);
|
||||
setns(
|
||||
self.old_ns.as_ref().unwrap().as_fd(),
|
||||
CloneFlags::CLONE_NEWNET,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
impl NetNSGuard {
|
||||
pub fn new(_ns: Option<String>) -> Box<Self> {
|
||||
Box::new(NetNSGuard {})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NetNS {
|
||||
name: Option<String>,
|
||||
}
|
||||
|
||||
impl NetNS {
|
||||
pub fn new(name: Option<String>) -> Self {
|
||||
NetNS { name }
|
||||
}
|
||||
|
||||
pub async fn run_async<F, Fut, Ret>(&self, f: F) -> Ret
|
||||
where
|
||||
F: FnOnce() -> Fut,
|
||||
Fut: Future<Output = Ret>,
|
||||
{
|
||||
// TODO: do we really need this lock
|
||||
// let _lock = LOCK.lock().await;
|
||||
let _guard = NetNSGuard::new(self.name.clone());
|
||||
f().await
|
||||
}
|
||||
|
||||
pub fn run<F, Ret>(&self, f: F) -> Ret
|
||||
where
|
||||
F: FnOnce() -> Ret,
|
||||
{
|
||||
let _guard = NetNSGuard::new(self.name.clone());
|
||||
f()
|
||||
}
|
||||
|
||||
pub fn guard(&self) -> Box<NetNSGuard> {
|
||||
NetNSGuard::new(self.name.clone())
|
||||
}
|
||||
|
||||
pub fn name(&self) -> Option<String> {
|
||||
self.name.clone()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
|
||||
use easytier_rpc::peer::GetIpListResponse;
|
||||
use pnet::datalink::NetworkInterface;
|
||||
use tokio::{
|
||||
sync::{Mutex, RwLock},
|
||||
task::JoinSet,
|
||||
};
|
||||
|
||||
use super::{constants::DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC, netns::NetNS};
|
||||
|
||||
struct InterfaceFilter {
|
||||
iface: NetworkInterface,
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
impl InterfaceFilter {
|
||||
async fn is_iface_bridge(&self) -> bool {
|
||||
let path = format!("/sys/class/net/{}/bridge", self.iface.name);
|
||||
tokio::fs::metadata(&path).await.is_ok()
|
||||
}
|
||||
|
||||
async fn is_iface_phsical(&self) -> bool {
|
||||
let path = format!("/sys/class/net/{}/device", self.iface.name);
|
||||
tokio::fs::metadata(&path).await.is_ok()
|
||||
}
|
||||
|
||||
async fn filter_iface(&self) -> bool {
|
||||
tracing::trace!(
|
||||
"filter linux iface: {:?}, is_point_to_point: {}, is_loopback: {}, is_up: {}, is_lower_up: {}, is_bridge: {}, is_physical: {}",
|
||||
self.iface,
|
||||
self.iface.is_point_to_point(),
|
||||
self.iface.is_loopback(),
|
||||
self.iface.is_up(),
|
||||
self.iface.is_lower_up(),
|
||||
self.is_iface_bridge().await,
|
||||
self.is_iface_phsical().await,
|
||||
);
|
||||
|
||||
!self.iface.is_point_to_point()
|
||||
&& !self.iface.is_loopback()
|
||||
&& self.iface.is_up()
|
||||
&& self.iface.is_lower_up()
|
||||
&& (self.is_iface_bridge().await || self.is_iface_phsical().await)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
impl InterfaceFilter {
|
||||
async fn is_interface_physical(interface_name: &str) -> bool {
|
||||
let output = tokio::process::Command::new("networksetup")
|
||||
.args(&["-listallhardwareports"])
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to execute command");
|
||||
|
||||
let stdout = std::str::from_utf8(&output.stdout).expect("Invalid UTF-8");
|
||||
|
||||
let lines: Vec<&str> = stdout.lines().collect();
|
||||
|
||||
for i in 0..lines.len() {
|
||||
let line = lines[i];
|
||||
|
||||
if line.contains("Device:") && line.contains(interface_name) {
|
||||
let next_line = lines[i + 1];
|
||||
if next_line.contains("Virtual Interface") {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
async fn filter_iface(&self) -> bool {
|
||||
!self.iface.is_point_to_point()
|
||||
&& !self.iface.is_loopback()
|
||||
&& self.iface.is_up()
|
||||
&& Self::is_interface_physical(&self.iface.name).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
impl InterfaceFilter {
|
||||
async fn filter_iface(&self) -> bool {
|
||||
!self.iface.is_point_to_point() && !self.iface.is_loopback() && self.iface.is_up()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn local_ipv4() -> std::io::Result<std::net::Ipv4Addr> {
|
||||
let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;
|
||||
socket.connect("8.8.8.8:80").await?;
|
||||
let addr = socket.local_addr()?;
|
||||
match addr.ip() {
|
||||
std::net::IpAddr::V4(ip) => Ok(ip),
|
||||
std::net::IpAddr::V6(_) => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::AddrNotAvailable,
|
||||
"no ipv4 address",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn local_ipv6() -> std::io::Result<std::net::Ipv6Addr> {
|
||||
let socket = tokio::net::UdpSocket::bind("[::]:0").await?;
|
||||
socket
|
||||
.connect("[2001:4860:4860:0000:0000:0000:0000:8888]:80")
|
||||
.await?;
|
||||
let addr = socket.local_addr()?;
|
||||
match addr.ip() {
|
||||
std::net::IpAddr::V6(ip) => Ok(ip),
|
||||
std::net::IpAddr::V4(_) => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::AddrNotAvailable,
|
||||
"no ipv4 address",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IPCollector {
|
||||
cached_ip_list: Arc<RwLock<GetIpListResponse>>,
|
||||
collect_ip_task: Mutex<JoinSet<()>>,
|
||||
net_ns: NetNS,
|
||||
}
|
||||
|
||||
impl IPCollector {
|
||||
pub fn new(net_ns: NetNS) -> Self {
|
||||
Self {
|
||||
cached_ip_list: Arc::new(RwLock::new(GetIpListResponse::new())),
|
||||
collect_ip_task: Mutex::new(JoinSet::new()),
|
||||
net_ns,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn collect_ip_addrs(&self) -> GetIpListResponse {
|
||||
let mut task = self.collect_ip_task.lock().await;
|
||||
if task.is_empty() {
|
||||
let cached_ip_list = self.cached_ip_list.clone();
|
||||
*cached_ip_list.write().await =
|
||||
Self::do_collect_ip_addrs(false, self.net_ns.clone()).await;
|
||||
let net_ns = self.net_ns.clone();
|
||||
task.spawn(async move {
|
||||
loop {
|
||||
let ip_addrs = Self::do_collect_ip_addrs(true, net_ns.clone()).await;
|
||||
*cached_ip_list.write().await = ip_addrs;
|
||||
tokio::time::sleep(std::time::Duration::from_secs(
|
||||
DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC,
|
||||
))
|
||||
.await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return self.cached_ip_list.read().await.deref().clone();
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(net_ns))]
|
||||
async fn do_collect_ip_addrs(with_public: bool, net_ns: NetNS) -> GetIpListResponse {
|
||||
let mut ret = easytier_rpc::peer::GetIpListResponse {
|
||||
public_ipv4: "".to_string(),
|
||||
interface_ipv4s: vec![],
|
||||
public_ipv6: "".to_string(),
|
||||
interface_ipv6s: vec![],
|
||||
};
|
||||
|
||||
if with_public {
|
||||
if let Some(v4_addr) =
|
||||
public_ip::addr_with(public_ip::http::ALL, public_ip::Version::V4).await
|
||||
{
|
||||
ret.public_ipv4 = v4_addr.to_string();
|
||||
}
|
||||
|
||||
if let Some(v6_addr) = public_ip::addr_v6().await {
|
||||
ret.public_ipv6 = v6_addr.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
let _g = net_ns.guard();
|
||||
let ifaces = pnet::datalink::interfaces();
|
||||
for iface in ifaces {
|
||||
let f = InterfaceFilter {
|
||||
iface: iface.clone(),
|
||||
};
|
||||
|
||||
if !f.filter_iface().await {
|
||||
continue;
|
||||
}
|
||||
|
||||
for ip in iface.ips {
|
||||
let ip: std::net::IpAddr = ip.ip();
|
||||
if ip.is_loopback() || ip.is_multicast() {
|
||||
continue;
|
||||
}
|
||||
if ip.is_ipv4() {
|
||||
ret.interface_ipv4s.push(ip.to_string());
|
||||
} else if ip.is_ipv6() {
|
||||
ret.interface_ipv6s.push(ip.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(v4_addr) = local_ipv4().await {
|
||||
tracing::trace!("got local ipv4: {}", v4_addr);
|
||||
if !ret.interface_ipv4s.contains(&v4_addr.to_string()) {
|
||||
ret.interface_ipv4s.push(v4_addr.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(v6_addr) = local_ipv6().await {
|
||||
tracing::trace!("got local ipv6: {}", v6_addr);
|
||||
if !ret.interface_ipv6s.contains(&v6_addr.to_string()) {
|
||||
ret.interface_ipv6s.push(v6_addr.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
use rkyv::{
|
||||
validation::{validators::DefaultValidator, CheckTypeError},
|
||||
vec::ArchivedVec,
|
||||
Archive, CheckBytes, Serialize,
|
||||
};
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
|
||||
pub fn decode_from_bytes_checked<'a, T: Archive>(
|
||||
bytes: &'a [u8],
|
||||
) -> Result<&'a T::Archived, CheckTypeError<T::Archived, DefaultValidator<'a>>>
|
||||
where
|
||||
T::Archived: CheckBytes<DefaultValidator<'a>>,
|
||||
{
|
||||
rkyv::check_archived_root::<T>(bytes)
|
||||
}
|
||||
|
||||
pub fn decode_from_bytes<'a, T: Archive>(
|
||||
bytes: &'a [u8],
|
||||
) -> Result<&'a T::Archived, CheckTypeError<T::Archived, DefaultValidator<'a>>>
|
||||
where
|
||||
T::Archived: CheckBytes<DefaultValidator<'a>>,
|
||||
{
|
||||
// rkyv::check_archived_root::<T>(bytes)
|
||||
unsafe { Ok(rkyv::archived_root::<T>(bytes)) }
|
||||
}
|
||||
|
||||
// allow deseraial T to Bytes
|
||||
pub fn encode_to_bytes<T, const N: usize>(val: &T) -> Bytes
|
||||
where
|
||||
T: Serialize<rkyv::ser::serializers::AllocSerializer<N>>,
|
||||
{
|
||||
let ret = rkyv::to_bytes::<_, N>(val).unwrap();
|
||||
// let mut r = BytesMut::new();
|
||||
// r.extend_from_slice(&ret);
|
||||
// r.freeze()
|
||||
ret.into_boxed_slice().into()
|
||||
}
|
||||
|
||||
pub fn extract_bytes_from_archived_vec(raw_data: &Bytes, archived_data: &ArchivedVec<u8>) -> Bytes {
|
||||
let ptr_range = archived_data.as_ptr_range();
|
||||
let offset = ptr_range.start as usize - raw_data.as_ptr() as usize;
|
||||
let len = ptr_range.end as usize - ptr_range.start as usize;
|
||||
return raw_data.slice(offset..offset + len);
|
||||
}
|
||||
|
||||
pub fn extract_bytes_mut_from_archived_vec(
|
||||
raw_data: &mut BytesMut,
|
||||
archived_data: &ArchivedVec<u8>,
|
||||
) -> BytesMut {
|
||||
let ptr_range = archived_data.as_ptr_range();
|
||||
let offset = ptr_range.start as usize - raw_data.as_ptr() as usize;
|
||||
let len = ptr_range.end as usize - ptr_range.start as usize;
|
||||
raw_data.split_off(offset).split_to(len)
|
||||
}
|
||||
@@ -0,0 +1,433 @@
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use easytier_rpc::{NatType, StunInfo};
|
||||
use stun_format::Attr;
|
||||
use tokio::net::{lookup_host, UdpSocket};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use crate::common::error::Error;
|
||||
|
||||
struct Stun {
|
||||
stun_server: String,
|
||||
req_repeat: u8,
|
||||
resp_timeout: Duration,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct BindRequestResponse {
|
||||
source_addr: SocketAddr,
|
||||
mapped_socket_addr: Option<SocketAddr>,
|
||||
changed_socket_addr: Option<SocketAddr>,
|
||||
|
||||
ip_changed: bool,
|
||||
port_changed: bool,
|
||||
}
|
||||
|
||||
impl BindRequestResponse {
|
||||
pub fn get_mapped_addr_no_check(&self) -> &SocketAddr {
|
||||
self.mapped_socket_addr.as_ref().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Stun {
|
||||
pub fn new(stun_server: String) -> Self {
|
||||
Self {
|
||||
stun_server,
|
||||
req_repeat: 3,
|
||||
resp_timeout: Duration::from_millis(3000),
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_stun_response<'a, const N: usize>(
|
||||
&self,
|
||||
buf: &'a mut [u8; N],
|
||||
udp: &UdpSocket,
|
||||
tids: &Vec<u128>,
|
||||
) -> Result<(stun_format::Msg<'a>, SocketAddr), Error> {
|
||||
let mut now = tokio::time::Instant::now();
|
||||
let deadline = now + self.resp_timeout;
|
||||
|
||||
while now < deadline {
|
||||
let mut udp_buf = [0u8; 1500];
|
||||
let (len, remote_addr) =
|
||||
tokio::time::timeout(deadline - now, udp.recv_from(udp_buf.as_mut_slice()))
|
||||
.await??;
|
||||
now = tokio::time::Instant::now();
|
||||
|
||||
if len < 20 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO:: we cannot borrow `buf` directly in udp recv_from, so we copy it here
|
||||
unsafe { std::ptr::copy(udp_buf.as_ptr(), buf.as_ptr() as *mut u8, len) };
|
||||
|
||||
let msg = stun_format::Msg::<'a>::from(&buf[..]);
|
||||
tracing::trace!(b = ?&udp_buf[..len], ?msg, ?tids, "recv stun response");
|
||||
|
||||
if msg.typ().is_none() || msg.tid().is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if matches!(
|
||||
msg.typ().as_ref().unwrap(),
|
||||
stun_format::MsgType::BindingResponse
|
||||
) && tids.contains(msg.tid().as_ref().unwrap())
|
||||
{
|
||||
return Ok((msg, remote_addr));
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::Unknown)
|
||||
}
|
||||
|
||||
fn stun_addr(addr: stun_format::SocketAddr) -> SocketAddr {
|
||||
match addr {
|
||||
stun_format::SocketAddr::V4(ip, port) => {
|
||||
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(ip), port))
|
||||
}
|
||||
stun_format::SocketAddr::V6(ip, port) => {
|
||||
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(ip), port, 0, 0))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn extrace_mapped_addr(msg: &stun_format::Msg) -> Option<SocketAddr> {
|
||||
let mut mapped_addr = None;
|
||||
for x in msg.attrs_iter() {
|
||||
match x {
|
||||
Attr::MappedAddress(addr) => {
|
||||
if mapped_addr.is_none() {
|
||||
let _ = mapped_addr.insert(Self::stun_addr(addr));
|
||||
}
|
||||
}
|
||||
Attr::XorMappedAddress(addr) => {
|
||||
if mapped_addr.is_none() {
|
||||
let _ = mapped_addr.insert(Self::stun_addr(addr));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
mapped_addr
|
||||
}
|
||||
|
||||
fn extract_changed_addr(msg: &stun_format::Msg) -> Option<SocketAddr> {
|
||||
let mut changed_addr = None;
|
||||
for x in msg.attrs_iter() {
|
||||
match x {
|
||||
Attr::ChangedAddress(addr) => {
|
||||
if changed_addr.is_none() {
|
||||
let _ = changed_addr.insert(Self::stun_addr(addr));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
changed_addr
|
||||
}
|
||||
|
||||
pub async fn bind_request(
|
||||
&self,
|
||||
source_port: u16,
|
||||
change_ip: bool,
|
||||
change_port: bool,
|
||||
) -> Result<BindRequestResponse, Error> {
|
||||
let stun_host = lookup_host(&self.stun_server)
|
||||
.await?
|
||||
.next()
|
||||
.ok_or(Error::NotFound)?;
|
||||
// let udp_socket = socket2::Socket::new(
|
||||
// match stun_host {
|
||||
// SocketAddr::V4(..) => socket2::Domain::IPV4,
|
||||
// SocketAddr::V6(..) => socket2::Domain::IPV6,
|
||||
// },
|
||||
// socket2::Type::DGRAM,
|
||||
// Some(socket2::Protocol::UDP),
|
||||
// )?;
|
||||
// udp_socket.set_reuse_port(true)?;
|
||||
// udp_socket.set_reuse_address(true)?;
|
||||
let udp = UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?;
|
||||
|
||||
// repeat req in case of packet loss
|
||||
let mut tids = vec![];
|
||||
for _ in 0..self.req_repeat {
|
||||
let mut buf = [0u8; 28];
|
||||
// memset buf
|
||||
unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()) };
|
||||
let mut msg = stun_format::MsgBuilder::from(buf.as_mut_slice());
|
||||
msg.typ(stun_format::MsgType::BindingRequest).unwrap();
|
||||
let tid = rand::random::<u32>();
|
||||
msg.tid(tid as u128).unwrap();
|
||||
if change_ip || change_port {
|
||||
msg.add_attr(Attr::ChangeRequest {
|
||||
change_ip,
|
||||
change_port,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
tids.push(tid as u128);
|
||||
tracing::trace!(b = ?msg.as_bytes(), tid, "send stun request");
|
||||
udp.send_to(msg.as_bytes(), &stun_host).await?;
|
||||
}
|
||||
|
||||
tracing::trace!("waiting stun response");
|
||||
let mut buf = [0; 1620];
|
||||
let (msg, recv_addr) = self.wait_stun_response(&mut buf, &udp, &tids).await?;
|
||||
|
||||
let changed_socket_addr = Self::extract_changed_addr(&msg);
|
||||
let ip_changed = stun_host.ip() != recv_addr.ip();
|
||||
let port_changed = stun_host.port() != recv_addr.port();
|
||||
|
||||
let resp = BindRequestResponse {
|
||||
source_addr: udp.local_addr()?,
|
||||
mapped_socket_addr: Self::extrace_mapped_addr(&msg),
|
||||
changed_socket_addr,
|
||||
ip_changed,
|
||||
port_changed,
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
?stun_host,
|
||||
?recv_addr,
|
||||
?changed_socket_addr,
|
||||
"finish stun bind request"
|
||||
);
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
}
|
||||
|
||||
struct UdpNatTypeDetector {
|
||||
stun_servers: Vec<String>,
|
||||
}
|
||||
|
||||
impl UdpNatTypeDetector {
|
||||
pub fn new(stun_servers: Vec<String>) -> Self {
|
||||
Self { stun_servers }
|
||||
}
|
||||
|
||||
async fn get_udp_nat_type(&self, mut source_port: u16) -> NatType {
|
||||
// Like classic STUN (rfc3489). Detect NAT behavior for UDP.
|
||||
// Modified from rfc3489. Requires at least two STUN servers.
|
||||
let mut ret_test1_1 = None;
|
||||
let mut ret_test1_2 = None;
|
||||
let mut ret_test2 = None;
|
||||
let mut ret_test3 = None;
|
||||
|
||||
if source_port == 0 {
|
||||
let udp = UdpSocket::bind("0.0.0.0:0").await.unwrap();
|
||||
source_port = udp.local_addr().unwrap().port();
|
||||
}
|
||||
|
||||
let mut succ = false;
|
||||
for server_ip in &self.stun_servers {
|
||||
let stun = Stun::new(server_ip.clone());
|
||||
let ret = stun.bind_request(source_port, false, false).await;
|
||||
if ret.is_err() {
|
||||
// Try another STUN server
|
||||
continue;
|
||||
}
|
||||
if ret_test1_1.is_none() {
|
||||
ret_test1_1 = ret.ok();
|
||||
continue;
|
||||
}
|
||||
ret_test1_2 = ret.ok();
|
||||
let ret = stun.bind_request(source_port, true, true).await;
|
||||
if let Ok(resp) = ret {
|
||||
if !resp.ip_changed || !resp.port_changed {
|
||||
// Try another STUN server
|
||||
continue;
|
||||
}
|
||||
}
|
||||
ret_test2 = ret.ok();
|
||||
ret_test3 = stun.bind_request(source_port, false, true).await.ok();
|
||||
succ = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if !succ {
|
||||
return NatType::Unknown;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
?ret_test1_1,
|
||||
?ret_test1_2,
|
||||
?ret_test2,
|
||||
?ret_test3,
|
||||
"finish stun test, try to detect nat type"
|
||||
);
|
||||
|
||||
let ret_test1_1 = ret_test1_1.unwrap();
|
||||
let ret_test1_2 = ret_test1_2.unwrap();
|
||||
|
||||
if ret_test1_1.mapped_socket_addr != ret_test1_2.mapped_socket_addr {
|
||||
return NatType::Symmetric;
|
||||
}
|
||||
|
||||
if ret_test1_1.mapped_socket_addr.is_some()
|
||||
&& ret_test1_1.source_addr == ret_test1_1.mapped_socket_addr.unwrap()
|
||||
{
|
||||
if !ret_test2.is_none() {
|
||||
return NatType::OpenInternet;
|
||||
} else {
|
||||
return NatType::SymUdpFirewall;
|
||||
}
|
||||
} else {
|
||||
if let Some(ret_test2) = ret_test2 {
|
||||
if source_port == ret_test2.get_mapped_addr_no_check().port()
|
||||
&& source_port == ret_test1_1.get_mapped_addr_no_check().port()
|
||||
{
|
||||
return NatType::NoPat;
|
||||
} else {
|
||||
return NatType::FullCone;
|
||||
}
|
||||
} else {
|
||||
if !ret_test3.is_none() {
|
||||
return NatType::Restricted;
|
||||
} else {
|
||||
return NatType::PortRestricted;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
#[auto_impl::auto_impl(&, Arc, Box)]
|
||||
pub trait StunInfoCollectorTrait: Send + Sync {
|
||||
fn get_stun_info(&self) -> StunInfo;
|
||||
async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error>;
|
||||
}
|
||||
|
||||
pub struct StunInfoCollector {
|
||||
stun_servers: Arc<RwLock<Vec<String>>>,
|
||||
udp_nat_type: Arc<AtomicCell<(NatType, std::time::Instant)>>,
|
||||
redetect_notify: Arc<tokio::sync::Notify>,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl StunInfoCollectorTrait for StunInfoCollector {
|
||||
fn get_stun_info(&self) -> StunInfo {
|
||||
let (typ, time) = self.udp_nat_type.load();
|
||||
StunInfo {
|
||||
udp_nat_type: typ as i32,
|
||||
tcp_nat_type: 0,
|
||||
last_update_time: time.elapsed().as_secs() as i64,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error> {
|
||||
let stun_servers = self.stun_servers.read().await.clone();
|
||||
for server in stun_servers.iter() {
|
||||
let stun = Stun::new(server.clone());
|
||||
let Ok(ret) = stun.bind_request(local_port, false, false).await else {
|
||||
tracing::warn!(?server, "stun bind request failed");
|
||||
continue;
|
||||
};
|
||||
if let Some(mapped_addr) = ret.mapped_socket_addr {
|
||||
return Ok(mapped_addr);
|
||||
}
|
||||
}
|
||||
Err(Error::NotFound)
|
||||
}
|
||||
}
|
||||
|
||||
impl StunInfoCollector {
|
||||
pub fn new(stun_servers: Vec<String>) -> Self {
|
||||
let mut ret = Self {
|
||||
stun_servers: Arc::new(RwLock::new(stun_servers)),
|
||||
udp_nat_type: Arc::new(AtomicCell::new((
|
||||
NatType::Unknown,
|
||||
std::time::Instant::now(),
|
||||
))),
|
||||
redetect_notify: Arc::new(tokio::sync::Notify::new()),
|
||||
tasks: JoinSet::new(),
|
||||
};
|
||||
|
||||
ret.start_stun_routine();
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
fn start_stun_routine(&mut self) {
|
||||
let stun_servers = self.stun_servers.clone();
|
||||
let udp_nat_type = self.udp_nat_type.clone();
|
||||
let redetect_notify = self.redetect_notify.clone();
|
||||
self.tasks.spawn(async move {
|
||||
loop {
|
||||
let detector = UdpNatTypeDetector::new(stun_servers.read().await.clone());
|
||||
let ret = detector.get_udp_nat_type(0).await;
|
||||
udp_nat_type.store((ret, std::time::Instant::now()));
|
||||
|
||||
let sleep_sec = match ret {
|
||||
NatType::Unknown => 15,
|
||||
_ => 60,
|
||||
};
|
||||
tracing::info!(?ret, ?sleep_sec, "finish udp nat type detect");
|
||||
|
||||
tokio::select! {
|
||||
_ = redetect_notify.notified() => {}
|
||||
_ = tokio::time::sleep(Duration::from_secs(sleep_sec)) => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn update_stun_info(&self) {
|
||||
self.redetect_notify.notify_one();
|
||||
}
|
||||
|
||||
pub async fn set_stun_servers(&self, stun_servers: Vec<String>) {
|
||||
*self.stun_servers.write().await = stun_servers;
|
||||
self.update_stun_info();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stun_bind_request() {
|
||||
// miwifi / qq seems not correctly responde to change_ip and change_port, they always try to change the src ip and port.
|
||||
// let stun = Stun::new("stun.counterpath.com:3478".to_string());
|
||||
let stun = Stun::new("180.235.108.91:3478".to_string());
|
||||
// let stun = Stun::new("193.22.2.248:3478".to_string());
|
||||
|
||||
// let stun = Stun::new("stun.chat.bilibili.com:3478".to_string());
|
||||
// let stun = Stun::new("stun.miwifi.com:3478".to_string());
|
||||
|
||||
let rs = stun.bind_request(12345, true, true).await.unwrap();
|
||||
assert!(rs.ip_changed);
|
||||
assert!(rs.port_changed);
|
||||
|
||||
let rs = stun.bind_request(12345, true, false).await.unwrap();
|
||||
assert!(rs.ip_changed);
|
||||
assert!(!rs.port_changed);
|
||||
|
||||
let rs = stun.bind_request(12345, false, true).await.unwrap();
|
||||
assert!(!rs.ip_changed);
|
||||
assert!(rs.port_changed);
|
||||
|
||||
let rs = stun.bind_request(12345, false, false).await.unwrap();
|
||||
assert!(!rs.ip_changed);
|
||||
assert!(!rs.port_changed);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_udp_nat_type_detect() {
|
||||
let detector = UdpNatTypeDetector::new(vec![
|
||||
"stun.counterpath.com:3478".to_string(),
|
||||
"180.235.108.91:3478".to_string(),
|
||||
]);
|
||||
let ret = detector.get_udp_nat_type(0).await;
|
||||
|
||||
assert_eq!(ret, NatType::FullCone);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,325 @@
|
||||
// try connect peers directly, with either its public ip or lan ip
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
constants::{self, DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC},
|
||||
error::Error,
|
||||
global_ctx::ArcGlobalCtx,
|
||||
network::IPCollector,
|
||||
},
|
||||
peers::{peer_manager::PeerManager, peer_rpc::PeerRpcManager, PeerId},
|
||||
};
|
||||
|
||||
use easytier_rpc::{peer::GetIpListResponse, PeerConnInfo};
|
||||
use tokio::{task::JoinSet, time::timeout};
|
||||
use tracing::Instrument;
|
||||
|
||||
use super::create_connector_by_url;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait DirectConnectorRpc {
|
||||
async fn get_ip_list() -> GetIpListResponse;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait PeerManagerForDirectConnector {
|
||||
async fn list_peers(&self) -> Vec<PeerId>;
|
||||
async fn list_peer_conns(&self, peer_id: &PeerId) -> Option<Vec<PeerConnInfo>>;
|
||||
fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager>;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerManagerForDirectConnector for PeerManager {
|
||||
async fn list_peers(&self) -> Vec<PeerId> {
|
||||
let mut ret = vec![];
|
||||
|
||||
let routes = self.list_routes().await;
|
||||
for r in routes.iter() {
|
||||
ret.push(r.peer_id.parse().unwrap());
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
async fn list_peer_conns(&self, peer_id: &PeerId) -> Option<Vec<PeerConnInfo>> {
|
||||
self.get_peer_map().list_peer_conns(peer_id).await
|
||||
}
|
||||
|
||||
fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
|
||||
self.get_peer_rpc_mgr()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct DirectConnectorManagerRpcServer {
|
||||
// TODO: this only cache for one src peer, should make it global
|
||||
ip_list_collector: Arc<IPCollector>,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl DirectConnectorRpc for DirectConnectorManagerRpcServer {
|
||||
async fn get_ip_list(self, _: tarpc::context::Context) -> GetIpListResponse {
|
||||
return self.ip_list_collector.collect_ip_addrs().await;
|
||||
}
|
||||
}
|
||||
|
||||
impl DirectConnectorManagerRpcServer {
|
||||
pub fn new(ip_collector: Arc<IPCollector>) -> Self {
|
||||
Self {
|
||||
ip_list_collector: ip_collector,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Clone)]
|
||||
struct DstBlackListItem(PeerId, String);
|
||||
|
||||
struct DirectConnectorManagerData {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
dst_blacklist: timedmap::TimedMap<DstBlackListItem, ()>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for DirectConnectorManagerData {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("DirectConnectorManagerData")
|
||||
.field("peer_manager", &self.peer_manager)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DirectConnectorManager {
|
||||
my_node_id: uuid::Uuid,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl DirectConnectorManager {
|
||||
pub fn new(
|
||||
my_node_id: uuid::Uuid,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
) -> Self {
|
||||
Self {
|
||||
my_node_id,
|
||||
global_ctx: global_ctx.clone(),
|
||||
data: Arc::new(DirectConnectorManagerData {
|
||||
global_ctx,
|
||||
peer_manager,
|
||||
dst_blacklist: timedmap::TimedMap::new(),
|
||||
}),
|
||||
tasks: JoinSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(&mut self) {
|
||||
self.run_as_server();
|
||||
self.run_as_client();
|
||||
}
|
||||
|
||||
pub fn run_as_server(&mut self) {
|
||||
self.data.peer_manager.get_peer_rpc_mgr().run_service(
|
||||
constants::DIRECT_CONNECTOR_SERVICE_ID,
|
||||
DirectConnectorManagerRpcServer::new(self.global_ctx.get_ip_collector()).serve(),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn run_as_client(&mut self) {
|
||||
let data = self.data.clone();
|
||||
let my_node_id = self.my_node_id.clone();
|
||||
self.tasks.spawn(
|
||||
async move {
|
||||
loop {
|
||||
let peers = data.peer_manager.list_peers().await;
|
||||
let mut tasks = JoinSet::new();
|
||||
for peer_id in peers {
|
||||
if peer_id == my_node_id {
|
||||
continue;
|
||||
}
|
||||
tasks.spawn(Self::do_try_direct_connect(data.clone(), peer_id));
|
||||
}
|
||||
|
||||
while let Some(task_ret) = tasks.join_next().await {
|
||||
tracing::trace!(?task_ret, "direct connect task ret");
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("direct_connector_client", my_id = ?self.my_node_id)),
|
||||
);
|
||||
}
|
||||
|
||||
async fn do_try_connect_to_ip(
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
dst_peer_id: PeerId,
|
||||
addr: String,
|
||||
) -> Result<(), Error> {
|
||||
data.dst_blacklist.cleanup();
|
||||
if data
|
||||
.dst_blacklist
|
||||
.contains(&DstBlackListItem(dst_peer_id.clone(), addr.clone()))
|
||||
{
|
||||
tracing::trace!("try_connect_to_ip failed, addr in blacklist: {}", addr);
|
||||
return Err(Error::UrlInBlacklist);
|
||||
}
|
||||
|
||||
let connector = create_connector_by_url(&addr, data.global_ctx.get_ip_collector()).await?;
|
||||
let (peer_id, conn_id) = timeout(
|
||||
std::time::Duration::from_secs(5),
|
||||
data.peer_manager.try_connect(connector),
|
||||
)
|
||||
.await??;
|
||||
|
||||
// let (peer_id, conn_id) = data.peer_manager.try_connect(connector).await?;
|
||||
|
||||
if peer_id != dst_peer_id {
|
||||
tracing::info!(
|
||||
"connect to ip succ: {}, but peer id mismatch, expect: {}, actual: {}",
|
||||
addr,
|
||||
dst_peer_id,
|
||||
peer_id
|
||||
);
|
||||
data.peer_manager
|
||||
.get_peer_map()
|
||||
.close_peer_conn(&peer_id, &conn_id)
|
||||
.await?;
|
||||
return Err(Error::InvalidUrl(addr));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn try_connect_to_ip(
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
dst_peer_id: PeerId,
|
||||
addr: String,
|
||||
) {
|
||||
let ret = Self::do_try_connect_to_ip(data.clone(), dst_peer_id, addr.clone()).await;
|
||||
if let Err(e) = ret {
|
||||
if !matches!(e, Error::UrlInBlacklist) {
|
||||
tracing::info!(
|
||||
"try_connect_to_ip failed: {:?}, peer_id: {}",
|
||||
e,
|
||||
dst_peer_id
|
||||
);
|
||||
data.dst_blacklist.insert(
|
||||
DstBlackListItem(dst_peer_id.clone(), addr.clone()),
|
||||
(),
|
||||
std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
log::info!("try_connect_to_ip success, peer_id: {}", dst_peer_id);
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn do_try_direct_connect(
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<(), Error> {
|
||||
let peer_manager = data.peer_manager.clone();
|
||||
// check if we have direct connection with dst_peer_id
|
||||
if let Some(c) = peer_manager.list_peer_conns(&dst_peer_id).await {
|
||||
// currently if we have any type of direct connection (udp or tcp), we will not try to connect
|
||||
if !c.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
log::trace!("try direct connect to peer: {}", dst_peer_id);
|
||||
|
||||
let ip_list = peer_manager
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, dst_peer_id, |c| async {
|
||||
let client =
|
||||
DirectConnectorRpcClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ip_list = client.get_ip_list(tarpc::context::current()).await;
|
||||
tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list");
|
||||
ip_list
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut tasks = JoinSet::new();
|
||||
ip_list.interface_ipv4s.iter().for_each(|ip| {
|
||||
let addr = format!("{}://{}:{}", "tcp", ip, 11010);
|
||||
tasks.spawn(Self::try_connect_to_ip(
|
||||
data.clone(),
|
||||
dst_peer_id.clone(),
|
||||
addr,
|
||||
));
|
||||
});
|
||||
|
||||
let addr = format!("{}://{}:{}", "tcp", ip_list.public_ipv4.clone(), 11010);
|
||||
tasks.spawn(Self::try_connect_to_ip(
|
||||
data.clone(),
|
||||
dst_peer_id.clone(),
|
||||
addr,
|
||||
));
|
||||
|
||||
while let Some(ret) = tasks.join_next().await {
|
||||
if let Err(e) = ret {
|
||||
log::error!("join direct connect task failed: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
connector::direct::DirectConnectorManager,
|
||||
instance::listeners::ListenerManager,
|
||||
peers::tests::{
|
||||
connect_peer_manager, create_mock_peer_manager, wait_route_appear,
|
||||
wait_route_appear_with_cost,
|
||||
},
|
||||
tunnels::tcp_tunnel::TcpTunnelListener,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn direct_connector_basic_test() {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
let p_b = create_mock_peer_manager().await;
|
||||
let p_c = create_mock_peer_manager().await;
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
|
||||
wait_route_appear(p_a.clone(), p_c.my_node_id())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut dm_a =
|
||||
DirectConnectorManager::new(p_a.my_node_id(), p_a.get_global_ctx(), p_a.clone());
|
||||
let mut dm_c =
|
||||
DirectConnectorManager::new(p_c.my_node_id(), p_c.get_global_ctx(), p_c.clone());
|
||||
|
||||
dm_a.run_as_client();
|
||||
dm_c.run_as_server();
|
||||
|
||||
let mut lis_c = ListenerManager::new(
|
||||
p_c.my_node_id(),
|
||||
p_c.get_global_ctx().net_ns.clone(),
|
||||
p_c.clone(),
|
||||
);
|
||||
|
||||
lis_c
|
||||
.add_listener(TcpTunnelListener::new(
|
||||
"tcp://0.0.0.0:11010".parse().unwrap(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
lis_c.run().await.unwrap();
|
||||
|
||||
wait_route_appear_with_cost(p_a.clone(), p_c.my_node_id(), Some(1))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,384 @@
|
||||
use std::{collections::BTreeSet, sync::Arc};
|
||||
|
||||
use dashmap::{DashMap, DashSet};
|
||||
use easytier_rpc::{
|
||||
connector_manage_rpc_server::ConnectorManageRpc, Connector, ConnectorStatus,
|
||||
ListConnectorRequest, ManageConnectorRequest,
|
||||
};
|
||||
use tokio::{
|
||||
sync::{broadcast::Receiver, mpsc, Mutex},
|
||||
task::JoinSet,
|
||||
time::timeout,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
||||
netns::NetNS,
|
||||
},
|
||||
connector::set_bind_addr_for_peer_connector,
|
||||
peers::peer_manager::PeerManager,
|
||||
tunnels::{Tunnel, TunnelConnector},
|
||||
use_global_var,
|
||||
};
|
||||
|
||||
use super::create_connector_by_url;
|
||||
|
||||
type ConnectorMap = Arc<DashMap<String, Box<dyn TunnelConnector + Send + Sync>>>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ReconnResult {
|
||||
dead_url: String,
|
||||
peer_id: uuid::Uuid,
|
||||
conn_id: uuid::Uuid,
|
||||
}
|
||||
|
||||
struct ConnectorManagerData {
|
||||
connectors: ConnectorMap,
|
||||
reconnecting: DashSet<String>,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
alive_conn_urls: Arc<Mutex<BTreeSet<String>>>,
|
||||
// user removed connector urls
|
||||
removed_conn_urls: Arc<DashSet<String>>,
|
||||
net_ns: NetNS,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
}
|
||||
|
||||
pub struct ManualConnectorManager {
|
||||
my_node_id: uuid::Uuid,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
data: Arc<ConnectorManagerData>,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl ManualConnectorManager {
|
||||
pub fn new(
|
||||
my_node_id: uuid::Uuid,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
) -> Self {
|
||||
let connectors = Arc::new(DashMap::new());
|
||||
let tasks = JoinSet::new();
|
||||
let event_subscriber = global_ctx.subscribe();
|
||||
|
||||
let mut ret = Self {
|
||||
my_node_id,
|
||||
global_ctx: global_ctx.clone(),
|
||||
data: Arc::new(ConnectorManagerData {
|
||||
connectors,
|
||||
reconnecting: DashSet::new(),
|
||||
peer_manager,
|
||||
alive_conn_urls: Arc::new(Mutex::new(BTreeSet::new())),
|
||||
removed_conn_urls: Arc::new(DashSet::new()),
|
||||
net_ns: global_ctx.net_ns.clone(),
|
||||
global_ctx,
|
||||
}),
|
||||
tasks,
|
||||
};
|
||||
|
||||
ret.tasks
|
||||
.spawn(Self::conn_mgr_routine(ret.data.clone(), event_subscriber));
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn add_connector<T>(&self, connector: T)
|
||||
where
|
||||
T: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
log::info!("add_connector: {}", connector.remote_url());
|
||||
self.data
|
||||
.connectors
|
||||
.insert(connector.remote_url().into(), Box::new(connector));
|
||||
}
|
||||
|
||||
pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> {
|
||||
self.add_connector(create_connector_by_url(url, self.global_ctx.get_ip_collector()).await?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn remove_connector(&self, url: &str) -> Result<(), Error> {
|
||||
log::info!("remove_connector: {}", url);
|
||||
if !self.list_connectors().await.iter().any(|x| x.url == url) {
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
self.data.removed_conn_urls.insert(url.into());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn list_connectors(&self) -> Vec<Connector> {
|
||||
let conn_urls: BTreeSet<String> = self
|
||||
.data
|
||||
.connectors
|
||||
.iter()
|
||||
.map(|x| x.key().clone().into())
|
||||
.collect();
|
||||
|
||||
let dead_urls: BTreeSet<String> = Self::collect_dead_conns(self.data.clone())
|
||||
.await
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let mut ret = Vec::new();
|
||||
|
||||
for conn_url in conn_urls {
|
||||
let mut status = ConnectorStatus::Connected;
|
||||
if dead_urls.contains(&conn_url) {
|
||||
status = ConnectorStatus::Disconnected;
|
||||
}
|
||||
ret.insert(
|
||||
0,
|
||||
Connector {
|
||||
url: conn_url,
|
||||
status: status.into(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let reconnecting_urls: BTreeSet<String> = self
|
||||
.data
|
||||
.reconnecting
|
||||
.iter()
|
||||
.map(|x| x.clone().into())
|
||||
.collect();
|
||||
|
||||
for conn_url in reconnecting_urls {
|
||||
ret.insert(
|
||||
0,
|
||||
Connector {
|
||||
url: conn_url,
|
||||
status: ConnectorStatus::Connecting.into(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
async fn conn_mgr_routine(
|
||||
data: Arc<ConnectorManagerData>,
|
||||
mut event_recv: Receiver<GlobalCtxEvent>,
|
||||
) {
|
||||
log::warn!("conn_mgr_routine started");
|
||||
let mut reconn_interval = tokio::time::interval(std::time::Duration::from_millis(
|
||||
use_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS),
|
||||
));
|
||||
let mut reconn_tasks = JoinSet::new();
|
||||
let (reconn_result_send, mut reconn_result_recv) = mpsc::channel(100);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
event = event_recv.recv() => {
|
||||
if let Ok(event) = event {
|
||||
Self::handle_event(&event, data.clone()).await;
|
||||
} else {
|
||||
log::warn!("event_recv closed");
|
||||
panic!("event_recv closed");
|
||||
}
|
||||
}
|
||||
|
||||
_ = reconn_interval.tick() => {
|
||||
let dead_urls = Self::collect_dead_conns(data.clone()).await;
|
||||
if dead_urls.is_empty() {
|
||||
continue;
|
||||
}
|
||||
for dead_url in dead_urls {
|
||||
let data_clone = data.clone();
|
||||
let sender = reconn_result_send.clone();
|
||||
let (_, connector) = data.connectors.remove(&dead_url).unwrap();
|
||||
let insert_succ = data.reconnecting.insert(dead_url.clone());
|
||||
assert!(insert_succ);
|
||||
reconn_tasks.spawn(async move {
|
||||
sender.send(Self::conn_reconnect(data_clone.clone(), dead_url, connector).await).await.unwrap();
|
||||
});
|
||||
}
|
||||
log::info!("reconn_interval tick, done");
|
||||
}
|
||||
|
||||
ret = reconn_result_recv.recv() => {
|
||||
log::warn!("reconn_tasks done, out: {:?}", ret);
|
||||
let _ = reconn_tasks.join_next().await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_event(event: &GlobalCtxEvent, data: Arc<ConnectorManagerData>) {
|
||||
match event {
|
||||
GlobalCtxEvent::PeerConnAdded(conn_info) => {
|
||||
let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
|
||||
data.alive_conn_urls.lock().await.insert(addr);
|
||||
log::warn!("peer conn added: {:?}", conn_info);
|
||||
}
|
||||
|
||||
GlobalCtxEvent::PeerConnRemoved(conn_info) => {
|
||||
let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
|
||||
data.alive_conn_urls.lock().await.remove(&addr);
|
||||
log::warn!("peer conn removed: {:?}", conn_info);
|
||||
}
|
||||
|
||||
GlobalCtxEvent::PeerAdded => todo!(),
|
||||
GlobalCtxEvent::PeerRemoved => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_remove_connector(data: Arc<ConnectorManagerData>) {
|
||||
let remove_later = DashSet::new();
|
||||
for it in data.removed_conn_urls.iter() {
|
||||
let url = it.key();
|
||||
if let Some(_) = data.connectors.remove(url) {
|
||||
log::warn!("connector: {}, removed", url);
|
||||
continue;
|
||||
} else if data.reconnecting.contains(url) {
|
||||
log::warn!("connector: {}, reconnecting, remove later.", url);
|
||||
remove_later.insert(url.clone());
|
||||
continue;
|
||||
} else {
|
||||
log::warn!("connector: {}, not found", url);
|
||||
}
|
||||
}
|
||||
data.removed_conn_urls.clear();
|
||||
for it in remove_later.iter() {
|
||||
data.removed_conn_urls.insert(it.key().clone());
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_dead_conns(data: Arc<ConnectorManagerData>) -> BTreeSet<String> {
|
||||
Self::handle_remove_connector(data.clone());
|
||||
|
||||
let curr_alive = data.alive_conn_urls.lock().await.clone();
|
||||
let all_urls: BTreeSet<String> = data
|
||||
.connectors
|
||||
.iter()
|
||||
.map(|x| x.key().clone().into())
|
||||
.collect();
|
||||
&all_urls - &curr_alive
|
||||
}
|
||||
|
||||
async fn conn_reconnect(
|
||||
data: Arc<ConnectorManagerData>,
|
||||
dead_url: String,
|
||||
connector: Box<dyn TunnelConnector + Send + Sync>,
|
||||
) -> Result<ReconnResult, Error> {
|
||||
let connector = Arc::new(Mutex::new(Some(connector)));
|
||||
let net_ns = data.net_ns.clone();
|
||||
|
||||
log::info!("reconnect: {}", dead_url);
|
||||
|
||||
let connector_clone = connector.clone();
|
||||
let data_clone = data.clone();
|
||||
let url_clone = dead_url.clone();
|
||||
let ip_collector = data.global_ctx.get_ip_collector();
|
||||
let reconn_task = async move {
|
||||
let mut locked = connector_clone.lock().await;
|
||||
let conn = locked.as_mut().unwrap();
|
||||
// TODO: should support set v6 here, use url in connector array
|
||||
set_bind_addr_for_peer_connector(conn, true, &ip_collector).await;
|
||||
let _g = net_ns.guard();
|
||||
log::info!("reconnect try connect... conn: {:?}", conn);
|
||||
let tunnel = conn.connect().await?;
|
||||
log::info!("reconnect get tunnel succ: {:?}", tunnel);
|
||||
assert_eq!(
|
||||
url_clone,
|
||||
tunnel.info().unwrap().remote_addr,
|
||||
"info: {:?}",
|
||||
tunnel.info()
|
||||
);
|
||||
let (peer_id, conn_id) = data_clone.peer_manager.add_client_tunnel(tunnel).await?;
|
||||
log::info!("reconnect succ: {} {} {}", peer_id, conn_id, url_clone);
|
||||
Ok(ReconnResult {
|
||||
dead_url: url_clone,
|
||||
peer_id,
|
||||
conn_id,
|
||||
})
|
||||
};
|
||||
|
||||
let ret = timeout(std::time::Duration::from_secs(1), reconn_task).await;
|
||||
log::info!("reconnect: {} done, ret: {:?}", dead_url, ret);
|
||||
|
||||
let conn = connector.lock().await.take().unwrap();
|
||||
data.reconnecting.remove(&dead_url).unwrap();
|
||||
data.connectors.insert(dead_url.clone(), conn);
|
||||
|
||||
ret?
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConnectorManagerRpcService(pub Arc<ManualConnectorManager>);
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl ConnectorManageRpc for ConnectorManagerRpcService {
|
||||
async fn list_connector(
|
||||
&self,
|
||||
_request: tonic::Request<ListConnectorRequest>,
|
||||
) -> Result<tonic::Response<easytier_rpc::ListConnectorResponse>, tonic::Status> {
|
||||
let mut ret = easytier_rpc::ListConnectorResponse::default();
|
||||
let connectors = self.0.list_connectors().await;
|
||||
ret.connectors = connectors;
|
||||
Ok(tonic::Response::new(ret))
|
||||
}
|
||||
|
||||
async fn manage_connector(
|
||||
&self,
|
||||
request: tonic::Request<ManageConnectorRequest>,
|
||||
) -> Result<tonic::Response<easytier_rpc::ManageConnectorResponse>, tonic::Status> {
|
||||
let req = request.into_inner();
|
||||
let url = url::Url::parse(&req.url)
|
||||
.map_err(|_| tonic::Status::invalid_argument("invalid url"))?;
|
||||
if req.action == easytier_rpc::ConnectorManageAction::Remove as i32 {
|
||||
self.0.remove_connector(url.path()).await.map_err(|e| {
|
||||
tonic::Status::invalid_argument(format!("remove connector failed: {:?}", e))
|
||||
})?;
|
||||
return Ok(tonic::Response::new(
|
||||
easytier_rpc::ManageConnectorResponse::default(),
|
||||
));
|
||||
} else {
|
||||
self.0
|
||||
.add_connector_by_url(url.as_str())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::invalid_argument(format!("add connector failed: {:?}", e))
|
||||
})?;
|
||||
}
|
||||
Ok(tonic::Response::new(
|
||||
easytier_rpc::ManageConnectorResponse::default(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
peers::tests::create_mock_peer_manager,
|
||||
set_global_var,
|
||||
tunnels::{Tunnel, TunnelError},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reconnect_with_connecting_addr() {
|
||||
set_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS, 1);
|
||||
|
||||
let peer_mgr = create_mock_peer_manager().await;
|
||||
let my_node_id = uuid::Uuid::new_v4();
|
||||
let mgr = ManualConnectorManager::new(my_node_id, peer_mgr.get_global_ctx(), peer_mgr);
|
||||
|
||||
struct MockConnector {}
|
||||
#[async_trait::async_trait]
|
||||
impl TunnelConnector for MockConnector {
|
||||
fn remote_url(&self) -> url::Url {
|
||||
url::Url::parse("tcp://aa.com").unwrap()
|
||||
}
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
Err(TunnelError::CommonError("fake error".into()))
|
||||
}
|
||||
}
|
||||
|
||||
mgr.add_connector(MockConnector {});
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
use std::{
|
||||
net::{SocketAddr, SocketAddrV4, SocketAddrV6},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, network::IPCollector},
|
||||
tunnels::{
|
||||
ring_tunnel::RingTunnelConnector, tcp_tunnel::TcpTunnelConnector,
|
||||
udp_tunnel::UdpTunnelConnector, TunnelConnector,
|
||||
},
|
||||
};
|
||||
|
||||
pub mod direct;
|
||||
pub mod manual;
|
||||
pub mod udp_hole_punch;
|
||||
|
||||
async fn set_bind_addr_for_peer_connector(
|
||||
connector: &mut impl TunnelConnector,
|
||||
is_ipv4: bool,
|
||||
ip_collector: &Arc<IPCollector>,
|
||||
) {
|
||||
let ips = ip_collector.collect_ip_addrs().await;
|
||||
if is_ipv4 {
|
||||
let mut bind_addrs = vec![];
|
||||
for ipv4 in ips.interface_ipv4s {
|
||||
let socket_addr = SocketAddrV4::new(ipv4.parse().unwrap(), 0).into();
|
||||
bind_addrs.push(socket_addr);
|
||||
}
|
||||
connector.set_bind_addrs(bind_addrs);
|
||||
} else {
|
||||
let mut bind_addrs = vec![];
|
||||
for ipv6 in ips.interface_ipv6s {
|
||||
let socket_addr = SocketAddrV6::new(ipv6.parse().unwrap(), 0, 0, 0).into();
|
||||
bind_addrs.push(socket_addr);
|
||||
}
|
||||
connector.set_bind_addrs(bind_addrs);
|
||||
}
|
||||
let _ = connector;
|
||||
}
|
||||
|
||||
pub async fn create_connector_by_url(
|
||||
url: &str,
|
||||
ip_collector: Arc<IPCollector>,
|
||||
) -> Result<Box<dyn TunnelConnector + Send + Sync + 'static>, Error> {
|
||||
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
|
||||
match url.scheme() {
|
||||
"tcp" => {
|
||||
let dst_addr =
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp")?;
|
||||
let mut connector = TcpTunnelConnector::new(url);
|
||||
set_bind_addr_for_peer_connector(&mut connector, dst_addr.is_ipv4(), &ip_collector)
|
||||
.await;
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
"udp" => {
|
||||
let dst_addr =
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp")?;
|
||||
let mut connector = UdpTunnelConnector::new(url);
|
||||
set_bind_addr_for_peer_connector(&mut connector, dst_addr.is_ipv4(), &ip_collector)
|
||||
.await;
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
"ring" => {
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring")?;
|
||||
let connector = RingTunnelConnector::new(url);
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::InvalidUrl(url.into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,523 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use easytier_rpc::NatType;
|
||||
use rand::{seq::SliceRandom, Rng, SeedableRng};
|
||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
constants, error::Error, global_ctx::ArcGlobalCtx, rkyv_util::encode_to_bytes,
|
||||
stun::StunInfoCollectorTrait,
|
||||
},
|
||||
peers::{peer_manager::PeerManager, PeerId},
|
||||
tunnels::{
|
||||
udp_tunnel::{UdpPacket, UdpTunnelConnector, UdpTunnelListener},
|
||||
Tunnel, TunnelConnCounter, TunnelListener,
|
||||
},
|
||||
};
|
||||
|
||||
use super::direct::PeerManagerForDirectConnector;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait UdpHolePunchService {
|
||||
async fn try_punch_hole(local_mapped_addr: SocketAddr) -> Option<SocketAddr>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UdpHolePunchListener {
|
||||
socket: Arc<UdpSocket>,
|
||||
tasks: JoinSet<()>,
|
||||
running: Arc<AtomicCell<bool>>,
|
||||
mapped_addr: SocketAddr,
|
||||
conn_counter: Arc<Box<dyn TunnelConnCounter>>,
|
||||
|
||||
listen_time: std::time::Instant,
|
||||
last_select_time: AtomicCell<std::time::Instant>,
|
||||
last_connected_time: Arc<AtomicCell<std::time::Instant>>,
|
||||
}
|
||||
|
||||
impl UdpHolePunchListener {
|
||||
async fn get_avail_port() -> Result<u16, Error> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
Ok(socket.local_addr()?.port())
|
||||
}
|
||||
|
||||
pub async fn new(peer_mgr: Arc<PeerManager>) -> Result<Self, Error> {
|
||||
let port = Self::get_avail_port().await?;
|
||||
let listen_url = format!("udp://0.0.0.0:{}", port);
|
||||
|
||||
let gctx = peer_mgr.get_global_ctx();
|
||||
let stun_info_collect = gctx.get_stun_info_collector();
|
||||
let mapped_addr = stun_info_collect.get_udp_port_mapping(port).await?;
|
||||
|
||||
let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap());
|
||||
|
||||
listener.listen().await?;
|
||||
let socket = listener.get_socket().unwrap();
|
||||
|
||||
let running = Arc::new(AtomicCell::new(true));
|
||||
let running_clone = running.clone();
|
||||
|
||||
let last_connected_time = Arc::new(AtomicCell::new(std::time::Instant::now()));
|
||||
let last_connected_time_clone = last_connected_time.clone();
|
||||
|
||||
let conn_counter = listener.get_conn_counter();
|
||||
let mut tasks = JoinSet::new();
|
||||
|
||||
tasks.spawn(async move {
|
||||
while let Ok(conn) = listener.accept().await {
|
||||
last_connected_time_clone.store(std::time::Instant::now());
|
||||
tracing::warn!(?conn, "udp hole punching listener got peer connection");
|
||||
if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
|
||||
tracing::error!(?e, "failed to add tunnel as server in hole punch listener");
|
||||
}
|
||||
}
|
||||
|
||||
running_clone.store(false);
|
||||
});
|
||||
|
||||
tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started");
|
||||
|
||||
Ok(Self {
|
||||
tasks,
|
||||
socket,
|
||||
running,
|
||||
mapped_addr,
|
||||
conn_counter,
|
||||
|
||||
listen_time: std::time::Instant::now(),
|
||||
last_select_time: AtomicCell::new(std::time::Instant::now()),
|
||||
last_connected_time,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_socket(&self) -> Arc<UdpSocket> {
|
||||
self.last_select_time.store(std::time::Instant::now());
|
||||
self.socket.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UdpHolePunchConnectorData {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
listeners: Arc<Mutex<Vec<UdpHolePunchListener>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct UdpHolePunchRpcServer {
|
||||
data: Arc<UdpHolePunchConnectorData>,
|
||||
|
||||
tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl UdpHolePunchService for UdpHolePunchRpcServer {
|
||||
async fn try_punch_hole(
|
||||
self,
|
||||
_: tarpc::context::Context,
|
||||
local_mapped_addr: SocketAddr,
|
||||
) -> Option<SocketAddr> {
|
||||
let (socket, mapped_addr) = self.select_listener().await?;
|
||||
tracing::warn!(?local_mapped_addr, ?mapped_addr, "start hole punching");
|
||||
|
||||
let my_udp_nat_type = self
|
||||
.data
|
||||
.global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type;
|
||||
|
||||
// if we are restricted, we need to send hole punching resp to client
|
||||
if my_udp_nat_type == NatType::PortRestricted as i32
|
||||
|| my_udp_nat_type == NatType::Restricted as i32
|
||||
{
|
||||
// send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
for _ in 0..10 {
|
||||
tracing::info!(?local_mapped_addr, "sending hole punching packet");
|
||||
// generate a 128 bytes vec with random data
|
||||
let mut rng = rand::rngs::StdRng::from_entropy();
|
||||
let mut buf = vec![0u8; 128];
|
||||
rng.fill(&mut buf[..]);
|
||||
|
||||
let udp_packet = UdpPacket::new_hole_punch_packet(buf);
|
||||
let udp_packet_bytes = encode_to_bytes::<_, 256>(&udp_packet);
|
||||
let _ = socket
|
||||
.send_to(udp_packet_bytes.as_ref(), local_mapped_addr)
|
||||
.await;
|
||||
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Some(mapped_addr)
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpHolePunchRpcServer {
|
||||
pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
|
||||
Self {
|
||||
data,
|
||||
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
async fn select_listener(&self) -> Option<(Arc<UdpSocket>, SocketAddr)> {
|
||||
let all_listener_sockets = &self.data.listeners;
|
||||
|
||||
// remove listener that not have connection in for 20 seconds
|
||||
all_listener_sockets.lock().await.retain(|listener| {
|
||||
listener.last_connected_time.load().elapsed().as_secs() < 20
|
||||
&& listener.conn_counter.get() > 0
|
||||
});
|
||||
|
||||
let mut use_last = false;
|
||||
if all_listener_sockets.lock().await.len() < 4 {
|
||||
tracing::warn!("creating new udp hole punching listener");
|
||||
all_listener_sockets.lock().await.push(
|
||||
UdpHolePunchListener::new(self.data.peer_mgr.clone())
|
||||
.await
|
||||
.ok()?,
|
||||
);
|
||||
use_last = true;
|
||||
}
|
||||
|
||||
let locked = all_listener_sockets.lock().await;
|
||||
|
||||
let listener = if use_last {
|
||||
locked.last()?
|
||||
} else {
|
||||
locked.choose(&mut rand::rngs::StdRng::from_entropy())?
|
||||
};
|
||||
|
||||
Some((listener.get_socket().await, listener.mapped_addr))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpHolePunchConnector {
|
||||
data: Arc<UdpHolePunchConnectorData>,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
// Currently support:
|
||||
// Symmetric -> Full Cone
|
||||
// Any Type of Full Cone -> Any Type of Full Cone
|
||||
|
||||
// if same level of full cone, node with smaller peer_id will be the initiator
|
||||
// if different level of full cone, node with more strict level will be the initiator
|
||||
|
||||
impl UdpHolePunchConnector {
|
||||
pub fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc<PeerManager>) -> Self {
|
||||
Self {
|
||||
data: Arc::new(UdpHolePunchConnectorData {
|
||||
global_ctx,
|
||||
peer_mgr,
|
||||
listeners: Arc::new(Mutex::new(Vec::new())),
|
||||
}),
|
||||
tasks: JoinSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_as_client(&mut self) -> Result<(), Error> {
|
||||
let data = self.data.clone();
|
||||
self.tasks.spawn(async move {
|
||||
Self::main_loop(data).await;
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_as_server(&mut self) -> Result<(), Error> {
|
||||
self.data.peer_mgr.get_peer_rpc_mgr().run_service(
|
||||
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
|
||||
UdpHolePunchRpcServer::new(self.data.clone()).serve(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) -> Result<(), Error> {
|
||||
self.run_as_client().await?;
|
||||
self.run_as_server().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn collect_peer_to_connect(data: Arc<UdpHolePunchConnectorData>) -> Vec<PeerId> {
|
||||
let mut peers_to_connect = Vec::new();
|
||||
|
||||
// do not do anything if:
|
||||
// 1. our stun test has not finished
|
||||
// 2. our nat type is OpenInternet or NoPat, which means we can wait other peers to connect us
|
||||
let my_nat_type = data
|
||||
.global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type;
|
||||
|
||||
let my_nat_type = NatType::try_from(my_nat_type).unwrap();
|
||||
|
||||
if my_nat_type == NatType::Unknown
|
||||
|| my_nat_type == NatType::OpenInternet
|
||||
|| my_nat_type == NatType::NoPat
|
||||
{
|
||||
return peers_to_connect;
|
||||
}
|
||||
|
||||
// collect peer list from peer manager and do some filter:
|
||||
// 1. peers without direct conns;
|
||||
// 2. peers is full cone (any restricted type);
|
||||
for route in data.peer_mgr.list_routes().await.iter() {
|
||||
let Some(peer_stun_info) = route.stun_info.as_ref() else {
|
||||
continue;
|
||||
};
|
||||
let Ok(peer_nat_type) = NatType::try_from(peer_stun_info.udp_nat_type) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let peer_id: PeerId = route.peer_id.parse().unwrap();
|
||||
let conns = data.peer_mgr.list_peer_conns(&peer_id).await;
|
||||
if conns.is_some() && conns.unwrap().len() > 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// if peer is symmetric ignore it because we cannot connect to it
|
||||
// if peer is open internet or no pat, direct connector will connecto to it
|
||||
if peer_nat_type == NatType::Unknown
|
||||
|| peer_nat_type == NatType::OpenInternet
|
||||
|| peer_nat_type == NatType::NoPat
|
||||
|| peer_nat_type == NatType::Symmetric
|
||||
|| peer_nat_type == NatType::SymUdpFirewall
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// if we are symmetric, we can only connect to full cone
|
||||
// TODO: can also connect to restricted full cone, with some extra work
|
||||
if (my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall)
|
||||
&& peer_nat_type != NatType::FullCone
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// if we have smae level of full cone, node with smaller peer_id will be the initiator
|
||||
if my_nat_type == peer_nat_type {
|
||||
if data.global_ctx.id > peer_id {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// if we have different level of full cone
|
||||
// we will be the initiator if we have more strict level
|
||||
if my_nat_type < peer_nat_type {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
?peer_id,
|
||||
?peer_nat_type,
|
||||
?my_nat_type,
|
||||
?data.global_ctx.id,
|
||||
"found peer to do hole punching"
|
||||
);
|
||||
|
||||
peers_to_connect.push(peer_id);
|
||||
}
|
||||
|
||||
peers_to_connect
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn do_hole_punching(
|
||||
data: Arc<UdpHolePunchConnectorData>,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<Box<dyn Tunnel>, anyhow::Error> {
|
||||
tracing::info!(?dst_peer_id, "start hole punching");
|
||||
// client: choose a local udp port, and get the pubic mapped port from stun server
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await.with_context(|| "")?;
|
||||
let local_socket_addr = socket.local_addr()?;
|
||||
let local_port = socket.local_addr()?.port();
|
||||
drop(socket); // drop the socket to release the port
|
||||
|
||||
let local_mapped_addr = data
|
||||
.global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_udp_port_mapping(local_port)
|
||||
.await
|
||||
.with_context(|| "failed to get udp port mapping")?;
|
||||
|
||||
// client -> server: tell server the mapped port, server will return the mapped address of listening port.
|
||||
let Some(remote_mapped_addr) = data
|
||||
.peer_mgr
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(
|
||||
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
|
||||
dst_peer_id,
|
||||
|c| async {
|
||||
let client =
|
||||
UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let remote_mapped_addr = client
|
||||
.try_punch_hole(tarpc::context::current(), local_mapped_addr)
|
||||
.await;
|
||||
tracing::info!(?remote_mapped_addr, ?dst_peer_id, "got remote mapped addr");
|
||||
remote_mapped_addr
|
||||
},
|
||||
)
|
||||
.await?
|
||||
else {
|
||||
return Err(anyhow::anyhow!("failed to get remote mapped addr"));
|
||||
};
|
||||
|
||||
// server: will send some punching resps, total 10 packets.
|
||||
// client: use the socket to create UdpTunnel with UdpTunnelConnector
|
||||
// NOTICE: UdpTunnelConnector will ignore the punching resp packet sent by remote.
|
||||
|
||||
let connector = UdpTunnelConnector::new(
|
||||
format!(
|
||||
"udp://{}:{}",
|
||||
remote_mapped_addr.ip(),
|
||||
remote_mapped_addr.port()
|
||||
)
|
||||
.to_string()
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let socket = UdpSocket::bind(local_socket_addr)
|
||||
.await
|
||||
.with_context(|| "")?;
|
||||
Ok(connector
|
||||
.try_connect_with_socket(socket)
|
||||
.await
|
||||
.with_context(|| "UdpTunnelConnector failed to connect remote")?)
|
||||
}
|
||||
|
||||
async fn main_loop(data: Arc<UdpHolePunchConnectorData>) {
|
||||
loop {
|
||||
let peers_to_connect = Self::collect_peer_to_connect(data.clone()).await;
|
||||
tracing::trace!(?peers_to_connect, "peers to connect");
|
||||
if peers_to_connect.len() == 0 {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut tasks: JoinSet<Result<(), anyhow::Error>> = JoinSet::new();
|
||||
for peer_id in peers_to_connect {
|
||||
let data = data.clone();
|
||||
tasks.spawn(
|
||||
async move {
|
||||
let tunnel = Self::do_hole_punching(data.clone(), peer_id)
|
||||
.await
|
||||
.with_context(|| "failed to do hole punching")?;
|
||||
|
||||
let _ =
|
||||
data.peer_mgr
|
||||
.add_client_tunnel(tunnel)
|
||||
.await
|
||||
.with_context(|| {
|
||||
"failed to add tunnel as client in hole punch connector"
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.instrument(tracing::info_span!("doing hole punching client", ?peer_id)),
|
||||
);
|
||||
}
|
||||
|
||||
while let Some(res) = tasks.join_next().await {
|
||||
if let Err(e) = res {
|
||||
tracing::error!(?e, "failed to join hole punching job");
|
||||
continue;
|
||||
}
|
||||
|
||||
match res.unwrap() {
|
||||
Err(e) => {
|
||||
tracing::error!(?e, "failed to do hole punching job");
|
||||
}
|
||||
Ok(_) => {
|
||||
tracing::info!("hole punching job succeed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use easytier_rpc::{NatType, StunInfo};
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, stun::StunInfoCollectorTrait},
|
||||
connector::udp_hole_punch::UdpHolePunchConnector,
|
||||
peers::{
|
||||
peer_manager::PeerManager,
|
||||
tests::{
|
||||
connect_peer_manager, create_mock_peer_manager, wait_route_appear,
|
||||
wait_route_appear_with_cost,
|
||||
},
|
||||
},
|
||||
tests::enable_log,
|
||||
};
|
||||
|
||||
struct MockStunInfoCollector {
|
||||
udp_nat_type: NatType,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl StunInfoCollectorTrait for MockStunInfoCollector {
|
||||
fn get_stun_info(&self) -> StunInfo {
|
||||
StunInfo {
|
||||
udp_nat_type: self.udp_nat_type as i32,
|
||||
tcp_nat_type: NatType::Unknown as i32,
|
||||
last_update_time: std::time::Instant::now().elapsed().as_secs() as i64,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_udp_port_mapping(&self, port: u16) -> Result<std::net::SocketAddr, Error> {
|
||||
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_mock_peer_manager_with_mock_stun(udp_nat_type: NatType) -> Arc<PeerManager> {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
let collector = Box::new(MockStunInfoCollector { udp_nat_type });
|
||||
p_a.get_global_ctx().replace_stun_info_collector(collector);
|
||||
p_a
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn hole_punching() {
|
||||
enable_log();
|
||||
let p_a = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
|
||||
let p_b = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await;
|
||||
let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
|
||||
wait_route_appear(p_a.clone(), p_c.my_node_id())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
println!("{:?}", p_a.list_routes().await);
|
||||
|
||||
let mut hole_punching_a = UdpHolePunchConnector::new(p_a.get_global_ctx(), p_a.clone());
|
||||
let mut hole_punching_c = UdpHolePunchConnector::new(p_c.get_global_ctx(), p_c.clone());
|
||||
|
||||
hole_punching_a.run().await.unwrap();
|
||||
hole_punching_c.run().await.unwrap();
|
||||
|
||||
wait_route_appear_with_cost(p_a.clone(), p_c.my_node_id(), Some(1))
|
||||
.await
|
||||
.unwrap();
|
||||
println!("{:?}", p_a.list_routes().await);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,301 @@
|
||||
use std::{
|
||||
mem::MaybeUninit,
|
||||
net::{IpAddr, Ipv4Addr, SocketAddrV4},
|
||||
sync::Arc,
|
||||
thread,
|
||||
};
|
||||
|
||||
use pnet::packet::{
|
||||
icmp::{self, IcmpTypes},
|
||||
ip::IpNextHeaderProtocols,
|
||||
ipv4::{self, Ipv4Packet, MutableIpv4Packet},
|
||||
Packet,
|
||||
};
|
||||
use socket2::Socket;
|
||||
use tokio::{
|
||||
sync::{mpsc::UnboundedSender, Mutex},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx},
|
||||
peers::{
|
||||
packet,
|
||||
peer_manager::{PeerManager, PeerPacketFilter},
|
||||
PeerId,
|
||||
},
|
||||
};
|
||||
|
||||
use super::CidrSet;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct IcmpNatKey {
|
||||
dst_ip: std::net::IpAddr,
|
||||
icmp_id: u16,
|
||||
icmp_seq: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct IcmpNatEntry {
|
||||
src_peer_id: PeerId,
|
||||
my_peer_id: PeerId,
|
||||
src_ip: IpAddr,
|
||||
start_time: std::time::Instant,
|
||||
}
|
||||
|
||||
impl IcmpNatEntry {
|
||||
fn new(src_peer_id: PeerId, my_peer_id: PeerId, src_ip: IpAddr) -> Result<Self, Error> {
|
||||
Ok(Self {
|
||||
src_peer_id,
|
||||
my_peer_id,
|
||||
src_ip,
|
||||
start_time: std::time::Instant::now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type IcmpNatTable = Arc<dashmap::DashMap<IcmpNatKey, IcmpNatEntry>>;
|
||||
type NewPacketSender = tokio::sync::mpsc::UnboundedSender<IcmpNatKey>;
|
||||
type NewPacketReceiver = tokio::sync::mpsc::UnboundedReceiver<IcmpNatKey>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct IcmpProxy {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
|
||||
cidr_set: CidrSet,
|
||||
socket: socket2::Socket,
|
||||
|
||||
nat_table: IcmpNatTable,
|
||||
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
}
|
||||
|
||||
fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, IpAddr), Error> {
|
||||
let (size, addr) = socket.recv_from(buf)?;
|
||||
let addr = match addr.as_socket() {
|
||||
None => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
|
||||
Some(add) => add.ip(),
|
||||
};
|
||||
Ok((size, addr))
|
||||
}
|
||||
|
||||
fn socket_recv_loop(
|
||||
socket: Socket,
|
||||
nat_table: IcmpNatTable,
|
||||
sender: UnboundedSender<packet::Packet>,
|
||||
) {
|
||||
let mut buf = [0u8; 4096];
|
||||
let data: &mut [MaybeUninit<u8>] = unsafe { std::mem::transmute(&mut buf[12..]) };
|
||||
|
||||
loop {
|
||||
let Ok((len, peer_ip)) = socket_recv(&socket, data) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if !peer_ip.is_ipv4() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(mut ipv4_packet) = MutableIpv4Packet::new(&mut buf[12..12 + len]) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Some(icmp_packet) = icmp::echo_reply::EchoReplyPacket::new(ipv4_packet.payload())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if icmp_packet.get_icmp_type() != IcmpTypes::EchoReply {
|
||||
continue;
|
||||
}
|
||||
|
||||
let key = IcmpNatKey {
|
||||
dst_ip: peer_ip,
|
||||
icmp_id: icmp_packet.get_identifier(),
|
||||
icmp_seq: icmp_packet.get_sequence_number(),
|
||||
};
|
||||
|
||||
let Some((_, v)) = nat_table.remove(&key) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// send packet back to the peer where this request origin.
|
||||
let IpAddr::V4(dest_ip) = v.src_ip else {
|
||||
continue;
|
||||
};
|
||||
|
||||
ipv4_packet.set_destination(dest_ip);
|
||||
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
|
||||
|
||||
let peer_packet = packet::Packet::new_data_packet(
|
||||
v.my_peer_id,
|
||||
v.src_peer_id,
|
||||
&ipv4_packet.to_immutable().packet(),
|
||||
);
|
||||
|
||||
if let Err(e) = sender.send(peer_packet) {
|
||||
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for IcmpProxy {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
_: &Bytes,
|
||||
) -> Option<()> {
|
||||
let _ = self.global_ctx.get_ipv4()?;
|
||||
|
||||
let packet::ArchivedPacketBody::Data(x) = &packet.body else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let ipv4 = Ipv4Packet::new(&x.data)?;
|
||||
|
||||
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?;
|
||||
|
||||
if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
|
||||
// drop it because we do not support other icmp types
|
||||
tracing::trace!("unsupported icmp type: {:?}", icmp_packet.get_icmp_type());
|
||||
return Some(());
|
||||
}
|
||||
|
||||
let icmp_id = icmp_packet.get_identifier();
|
||||
let icmp_seq = icmp_packet.get_sequence_number();
|
||||
|
||||
let key = IcmpNatKey {
|
||||
dst_ip: ipv4.get_destination().into(),
|
||||
icmp_id,
|
||||
icmp_seq,
|
||||
};
|
||||
|
||||
if packet.to_peer.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let value = IcmpNatEntry::new(
|
||||
packet.from_peer.to_uuid(),
|
||||
packet.to_peer.as_ref().unwrap().to_uuid(),
|
||||
ipv4.get_source().into(),
|
||||
)
|
||||
.ok()?;
|
||||
|
||||
if let Some(old) = self.nat_table.insert(key, value) {
|
||||
tracing::info!("icmp nat table entry replaced: {:?}", old);
|
||||
}
|
||||
|
||||
if let Err(e) = self.send_icmp_packet(ipv4.get_destination(), &icmp_packet) {
|
||||
tracing::error!("send icmp packet failed: {:?}", e);
|
||||
}
|
||||
|
||||
Some(())
|
||||
}
|
||||
}
|
||||
|
||||
impl IcmpProxy {
|
||||
pub fn new(
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
) -> Result<Arc<Self>, Error> {
|
||||
let cidr_set = CidrSet::new(global_ctx.clone());
|
||||
|
||||
let _g = global_ctx.net_ns.guard();
|
||||
let socket = socket2::Socket::new(
|
||||
socket2::Domain::IPV4,
|
||||
socket2::Type::RAW,
|
||||
Some(socket2::Protocol::ICMPV4),
|
||||
)?;
|
||||
socket.bind(&socket2::SockAddr::from(SocketAddrV4::new(
|
||||
std::net::Ipv4Addr::UNSPECIFIED,
|
||||
0,
|
||||
)))?;
|
||||
|
||||
let ret = Self {
|
||||
global_ctx,
|
||||
peer_manager,
|
||||
cidr_set,
|
||||
socket,
|
||||
|
||||
nat_table: Arc::new(dashmap::DashMap::new()),
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
};
|
||||
|
||||
Ok(Arc::new(ret))
|
||||
}
|
||||
|
||||
pub async fn start(self: &Arc<Self>) -> Result<(), Error> {
|
||||
self.start_icmp_proxy().await?;
|
||||
self.start_nat_table_cleaner().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_nat_table_cleaner(self: &Arc<Self>) -> Result<(), Error> {
|
||||
let nat_table = self.nat_table.clone();
|
||||
self.tasks.lock().await.spawn(
|
||||
async move {
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
nat_table.retain(|_, v| v.start_time.elapsed().as_secs() < 20);
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("icmp proxy nat table cleaner")),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_icmp_proxy(self: &Arc<Self>) -> Result<(), Error> {
|
||||
let socket = self.socket.try_clone()?;
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
|
||||
let nat_table = self.nat_table.clone();
|
||||
thread::spawn(|| {
|
||||
socket_recv_loop(socket, nat_table, sender);
|
||||
});
|
||||
|
||||
let peer_manager = self.peer_manager.clone();
|
||||
self.tasks.lock().await.spawn(
|
||||
async move {
|
||||
while let Some(msg) = receiver.recv().await {
|
||||
let to_peer_id: uuid::Uuid = msg.to_peer.as_ref().unwrap().clone().into();
|
||||
let ret = peer_manager.send_msg(msg.into(), &to_peer_id).await;
|
||||
if ret.is_err() {
|
||||
tracing::error!("send icmp packet to peer failed: {:?}", ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("icmp proxy send loop")),
|
||||
);
|
||||
|
||||
self.peer_manager
|
||||
.add_packet_process_pipeline(Box::new(self.clone()))
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn send_icmp_packet(
|
||||
&self,
|
||||
dst_ip: Ipv4Addr,
|
||||
icmp_packet: &icmp::echo_request::EchoRequestPacket,
|
||||
) -> Result<(), Error> {
|
||||
self.socket.send_to(
|
||||
icmp_packet.packet(),
|
||||
&SocketAddrV4::new(dst_ip.into(), 0).into(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
use dashmap::DashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use crate::common::global_ctx::ArcGlobalCtx;
|
||||
|
||||
pub mod icmp_proxy;
|
||||
pub mod tcp_proxy;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CidrSet {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
cidr_set: Arc<DashSet<cidr::IpCidr>>,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl CidrSet {
|
||||
pub fn new(global_ctx: ArcGlobalCtx) -> Self {
|
||||
let mut ret = Self {
|
||||
global_ctx,
|
||||
cidr_set: Arc::new(DashSet::new()),
|
||||
tasks: JoinSet::new(),
|
||||
};
|
||||
ret.run_cidr_updater();
|
||||
ret
|
||||
}
|
||||
|
||||
fn run_cidr_updater(&mut self) {
|
||||
let global_ctx = self.global_ctx.clone();
|
||||
let cidr_set = self.cidr_set.clone();
|
||||
self.tasks.spawn(async move {
|
||||
let mut last_cidrs = vec![];
|
||||
loop {
|
||||
let cidrs = global_ctx.get_proxy_cidrs();
|
||||
if cidrs != last_cidrs {
|
||||
last_cidrs = cidrs.clone();
|
||||
cidr_set.clear();
|
||||
for cidr in cidrs.iter() {
|
||||
cidr_set.insert(cidr.clone());
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn contains_v4(&self, ip: std::net::Ipv4Addr) -> bool {
|
||||
let ip = ip.into();
|
||||
return self.cidr_set.iter().any(|cidr| cidr.contains(&ip));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,402 @@
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use dashmap::DashMap;
|
||||
use pnet::packet::ip::IpNextHeaderProtocols;
|
||||
use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
|
||||
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||
use std::sync::atomic::AtomicU16;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::io::copy_bidirectional;
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::common::error::Result;
|
||||
use crate::common::global_ctx::GlobalCtx;
|
||||
use crate::common::netns::NetNS;
|
||||
use crate::peers::packet::{self, ArchivedPacket};
|
||||
use crate::peers::peer_manager::{NicPacketFilter, PeerManager, PeerPacketFilter};
|
||||
|
||||
use super::CidrSet;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
enum NatDstEntryState {
|
||||
// receive syn packet but not start connecting to dst
|
||||
SynReceived,
|
||||
// connecting to dst
|
||||
ConnectingDst,
|
||||
// connected to dst
|
||||
Connected,
|
||||
// connection closed
|
||||
Closed,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct NatDstEntry {
|
||||
id: uuid::Uuid,
|
||||
src: SocketAddr,
|
||||
dst: SocketAddr,
|
||||
start_time: Instant,
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
state: AtomicCell<NatDstEntryState>,
|
||||
}
|
||||
|
||||
impl NatDstEntry {
|
||||
pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4(),
|
||||
src,
|
||||
dst,
|
||||
start_time: Instant::now(),
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
state: AtomicCell::new(NatDstEntryState::SynReceived),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ArcNatDstEntry = Arc<NatDstEntry>;
|
||||
|
||||
type SynSockMap = Arc<DashMap<SocketAddr, ArcNatDstEntry>>;
|
||||
type ConnSockMap = Arc<DashMap<uuid::Uuid, ArcNatDstEntry>>;
|
||||
// peer src addr to nat entry, when respond tcp packet, should modify the tcp src addr to the nat entry's dst addr
|
||||
type AddrConnSockMap = Arc<DashMap<SocketAddr, ArcNatDstEntry>>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TcpProxy {
|
||||
global_ctx: Arc<GlobalCtx>,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
local_port: AtomicU16,
|
||||
|
||||
tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
|
||||
syn_map: SynSockMap,
|
||||
conn_map: ConnSockMap,
|
||||
addr_conn_map: AddrConnSockMap,
|
||||
|
||||
cidr_set: CidrSet,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for TcpProxy {
|
||||
async fn try_process_packet_from_peer(&self, packet: &ArchivedPacket, _: &Bytes) -> Option<()> {
|
||||
let ipv4_addr = self.global_ctx.get_ipv4()?;
|
||||
|
||||
let packet::ArchivedPacketBody::Data(x) = &packet.body else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let ipv4 = Ipv4Packet::new(&x.data)?;
|
||||
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp {
|
||||
return None;
|
||||
}
|
||||
|
||||
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
tracing::trace!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received");
|
||||
|
||||
let mut packet_buffer = BytesMut::with_capacity(x.data.len());
|
||||
packet_buffer.extend_from_slice(&x.data.to_vec());
|
||||
|
||||
let (ip_buffer, tcp_buffer) =
|
||||
packet_buffer.split_at_mut(ipv4.get_header_length() as usize * 4);
|
||||
|
||||
let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
|
||||
let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
|
||||
|
||||
let is_tcp_syn = tcp_packet.get_flags() & pnet::packet::tcp::TcpFlags::SYN != 0;
|
||||
if is_tcp_syn {
|
||||
let source_ip = ip_packet.get_source();
|
||||
let source_port = tcp_packet.get_source();
|
||||
let src = SocketAddr::V4(SocketAddrV4::new(source_ip, source_port));
|
||||
|
||||
let dest_ip = ip_packet.get_destination();
|
||||
let dest_port = tcp_packet.get_destination();
|
||||
let dst = SocketAddr::V4(SocketAddrV4::new(dest_ip, dest_port));
|
||||
|
||||
let old_val = self
|
||||
.syn_map
|
||||
.insert(src, Arc::new(NatDstEntry::new(src, dst)));
|
||||
tracing::trace!(src = ?src, dst = ?dst, old_entry = ?old_val, "tcp syn received");
|
||||
}
|
||||
|
||||
ip_packet.set_destination(ipv4_addr);
|
||||
tcp_packet.set_destination(self.get_local_port());
|
||||
Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
|
||||
|
||||
tracing::trace!(ip_packet = ?ip_packet, tcp_packet = ?tcp_packet, "tcp packet forwarded");
|
||||
|
||||
if let Err(e) = self
|
||||
.peer_manager
|
||||
.get_nic_channel()
|
||||
.send(packet_buffer.freeze())
|
||||
.await
|
||||
{
|
||||
tracing::error!("send to nic failed: {:?}", e);
|
||||
}
|
||||
|
||||
Some(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl NicPacketFilter for TcpProxy {
|
||||
async fn try_process_packet_from_nic(&self, mut data: BytesMut) -> BytesMut {
|
||||
let Some(my_ipv4) = self.global_ctx.get_ipv4() else {
|
||||
return data;
|
||||
};
|
||||
|
||||
let header_len = {
|
||||
let Some(ipv4) = &Ipv4Packet::new(&data[..]) else {
|
||||
return data;
|
||||
};
|
||||
|
||||
if ipv4.get_version() != 4
|
||||
|| ipv4.get_source() != my_ipv4
|
||||
|| ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp
|
||||
{
|
||||
return data;
|
||||
}
|
||||
|
||||
ipv4.get_header_length() as usize * 4
|
||||
};
|
||||
|
||||
let (ip_buffer, tcp_buffer) = data.split_at_mut(header_len);
|
||||
let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
|
||||
let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
|
||||
|
||||
if tcp_packet.get_source() != self.get_local_port() {
|
||||
return data;
|
||||
}
|
||||
|
||||
let dst_addr = SocketAddr::V4(SocketAddrV4::new(
|
||||
ip_packet.get_destination(),
|
||||
tcp_packet.get_destination(),
|
||||
));
|
||||
|
||||
tracing::trace!(dst_addr = ?dst_addr, "tcp packet try find entry");
|
||||
let entry = if let Some(entry) = self.addr_conn_map.get(&dst_addr) {
|
||||
entry
|
||||
} else {
|
||||
let Some(syn_entry) = self.syn_map.get(&dst_addr) else {
|
||||
return data;
|
||||
};
|
||||
syn_entry
|
||||
};
|
||||
let nat_entry = entry.clone();
|
||||
drop(entry);
|
||||
assert_eq!(nat_entry.src, dst_addr);
|
||||
|
||||
let IpAddr::V4(ip) = nat_entry.dst.ip() else {
|
||||
panic!("v4 nat entry src ip is not v4");
|
||||
};
|
||||
|
||||
ip_packet.set_source(ip);
|
||||
tcp_packet.set_source(nat_entry.dst.port());
|
||||
Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
|
||||
|
||||
tracing::trace!(dst_addr = ?dst_addr, nat_entry = ?nat_entry, packet = ?ip_packet, "tcp packet after modified");
|
||||
|
||||
data
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpProxy {
|
||||
pub fn new(global_ctx: Arc<GlobalCtx>, peer_manager: Arc<PeerManager>) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
global_ctx: global_ctx.clone(),
|
||||
peer_manager,
|
||||
|
||||
local_port: AtomicU16::new(0),
|
||||
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||
|
||||
syn_map: Arc::new(DashMap::new()),
|
||||
conn_map: Arc::new(DashMap::new()),
|
||||
addr_conn_map: Arc::new(DashMap::new()),
|
||||
|
||||
cidr_set: CidrSet::new(global_ctx),
|
||||
})
|
||||
}
|
||||
|
||||
fn update_ipv4_packet_checksum(
|
||||
ipv4_packet: &mut MutableIpv4Packet,
|
||||
tcp_packet: &mut MutableTcpPacket,
|
||||
) {
|
||||
tcp_packet.set_checksum(ipv4_checksum(
|
||||
&tcp_packet.to_immutable(),
|
||||
&ipv4_packet.get_source(),
|
||||
&ipv4_packet.get_destination(),
|
||||
));
|
||||
|
||||
ipv4_packet.set_checksum(pnet::packet::ipv4::checksum(&ipv4_packet.to_immutable()));
|
||||
}
|
||||
|
||||
pub async fn start(self: &Arc<Self>) -> Result<()> {
|
||||
self.run_syn_map_cleaner().await?;
|
||||
self.run_listener().await?;
|
||||
self.peer_manager
|
||||
.add_packet_process_pipeline(Box::new(self.clone()))
|
||||
.await;
|
||||
self.peer_manager
|
||||
.add_nic_packet_process_pipeline(Box::new(self.clone()))
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_syn_map_cleaner(&self) -> Result<()> {
|
||||
let syn_map = self.syn_map.clone();
|
||||
let tasks = self.tasks.clone();
|
||||
let syn_map_cleaner_task = async move {
|
||||
loop {
|
||||
syn_map.retain(|_, entry| {
|
||||
if entry.start_time.elapsed() > Duration::from_secs(30) {
|
||||
tracing::warn!(entry = ?entry, "syn nat entry expired");
|
||||
entry.state.store(NatDstEntryState::Closed);
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||
}
|
||||
};
|
||||
tasks.lock().await.spawn(syn_map_cleaner_task);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_listener(&self) -> Result<()> {
|
||||
// bind on both v4 & v6
|
||||
let listen_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
|
||||
|
||||
let net_ns = self.global_ctx.net_ns.clone();
|
||||
let tcp_listener = net_ns
|
||||
.run_async(|| async { TcpListener::bind(&listen_addr).await })
|
||||
.await?;
|
||||
|
||||
self.local_port.store(
|
||||
tcp_listener.local_addr()?.port(),
|
||||
std::sync::atomic::Ordering::Relaxed,
|
||||
);
|
||||
|
||||
let tasks = self.tasks.clone();
|
||||
let syn_map = self.syn_map.clone();
|
||||
let conn_map = self.conn_map.clone();
|
||||
let addr_conn_map = self.addr_conn_map.clone();
|
||||
let accept_task = async move {
|
||||
tracing::info!(listener = ?tcp_listener, "tcp connection start accepting");
|
||||
|
||||
let conn_map = conn_map.clone();
|
||||
while let Ok((tcp_stream, socket_addr)) = tcp_listener.accept().await {
|
||||
let Some(entry) = syn_map.get(&socket_addr) else {
|
||||
tracing::error!("tcp connection from unknown source: {:?}", socket_addr);
|
||||
continue;
|
||||
};
|
||||
assert_eq!(entry.state.load(), NatDstEntryState::SynReceived);
|
||||
|
||||
let entry_clone = entry.clone();
|
||||
drop(entry);
|
||||
syn_map.remove_if(&socket_addr, |_, entry| entry.id == entry_clone.id);
|
||||
|
||||
entry_clone.state.store(NatDstEntryState::ConnectingDst);
|
||||
|
||||
let _ = addr_conn_map.insert(entry_clone.src, entry_clone.clone());
|
||||
let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone());
|
||||
assert!(old_nat_val.is_none());
|
||||
|
||||
tasks.lock().await.spawn(Self::connect_to_nat_dst(
|
||||
net_ns.clone(),
|
||||
tcp_stream,
|
||||
conn_map.clone(),
|
||||
addr_conn_map.clone(),
|
||||
entry_clone,
|
||||
));
|
||||
}
|
||||
tracing::error!("nat tcp listener exited");
|
||||
panic!("nat tcp listener exited");
|
||||
};
|
||||
self.tasks
|
||||
.lock()
|
||||
.await
|
||||
.spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener")));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_entry_from_all_conn_map(
|
||||
conn_map: ConnSockMap,
|
||||
addr_conn_map: AddrConnSockMap,
|
||||
nat_entry: ArcNatDstEntry,
|
||||
) {
|
||||
conn_map.remove(&nat_entry.id);
|
||||
addr_conn_map.remove_if(&nat_entry.src, |_, entry| entry.id == nat_entry.id);
|
||||
}
|
||||
|
||||
async fn connect_to_nat_dst(
|
||||
net_ns: NetNS,
|
||||
src_tcp_stream: TcpStream,
|
||||
conn_map: ConnSockMap,
|
||||
addr_conn_map: AddrConnSockMap,
|
||||
nat_entry: ArcNatDstEntry,
|
||||
) {
|
||||
if let Err(e) = src_tcp_stream.set_nodelay(true) {
|
||||
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
|
||||
}
|
||||
|
||||
let _guard = net_ns.guard();
|
||||
let socket = TcpSocket::new_v4().unwrap();
|
||||
if let Err(e) = socket.set_nodelay(true) {
|
||||
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
|
||||
}
|
||||
let Ok(Ok(dst_tcp_stream)) = tokio::time::timeout(
|
||||
Duration::from_secs(10),
|
||||
TcpSocket::new_v4().unwrap().connect(nat_entry.dst),
|
||||
)
|
||||
.await
|
||||
else {
|
||||
tracing::error!("connect to dst failed: {:?}", nat_entry);
|
||||
nat_entry.state.store(NatDstEntryState::Closed);
|
||||
Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, nat_entry);
|
||||
return;
|
||||
};
|
||||
drop(_guard);
|
||||
|
||||
assert_eq!(nat_entry.state.load(), NatDstEntryState::ConnectingDst);
|
||||
nat_entry.state.store(NatDstEntryState::Connected);
|
||||
|
||||
Self::handle_nat_connection(
|
||||
src_tcp_stream,
|
||||
dst_tcp_stream,
|
||||
conn_map,
|
||||
addr_conn_map,
|
||||
nat_entry,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn handle_nat_connection(
|
||||
mut src_tcp_stream: TcpStream,
|
||||
mut dst_tcp_stream: TcpStream,
|
||||
conn_map: ConnSockMap,
|
||||
addr_conn_map: AddrConnSockMap,
|
||||
nat_entry: ArcNatDstEntry,
|
||||
) {
|
||||
let nat_entry_clone = nat_entry.clone();
|
||||
nat_entry.tasks.lock().await.spawn(async move {
|
||||
let ret = copy_bidirectional(&mut src_tcp_stream, &mut dst_tcp_stream).await;
|
||||
tracing::trace!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed");
|
||||
nat_entry_clone.state.store(NatDstEntryState::Closed);
|
||||
|
||||
Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, nat_entry_clone);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn get_local_port(&self) -> u16 {
|
||||
self.local_port.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,413 @@
|
||||
use std::borrow::BorrowMut;
|
||||
use std::io::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
use pnet::packet::ethernet::EthernetPacket;
|
||||
use pnet::packet::ipv4::Ipv4Packet;
|
||||
|
||||
use tokio::{sync::Mutex, task::JoinSet};
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
use tonic::transport::Server;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::common::config_fs::ConfigFs;
|
||||
use crate::common::error::Error;
|
||||
use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx};
|
||||
use crate::common::netns::NetNS;
|
||||
use crate::connector::direct::DirectConnectorManager;
|
||||
use crate::connector::manual::{ConnectorManagerRpcService, ManualConnectorManager};
|
||||
use crate::connector::udp_hole_punch::UdpHolePunchConnector;
|
||||
use crate::gateway::icmp_proxy::IcmpProxy;
|
||||
use crate::gateway::tcp_proxy::TcpProxy;
|
||||
use crate::peers::peer_manager::PeerManager;
|
||||
use crate::peers::rip_route::BasicRoute;
|
||||
use crate::peers::rpc_service::PeerManagerRpcService;
|
||||
use crate::tunnels::SinkItem;
|
||||
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
use super::listeners::ListenerManager;
|
||||
use super::virtual_nic;
|
||||
|
||||
pub struct InstanceConfigWriter {
|
||||
config: ConfigFs,
|
||||
}
|
||||
|
||||
impl InstanceConfigWriter {
|
||||
pub fn new(inst_name: &str) -> Self {
|
||||
InstanceConfigWriter {
|
||||
config: ConfigFs::new(inst_name),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_ns(self, net_ns: Option<String>) -> Self {
|
||||
let net_ns_in_conf = if let Some(net_ns) = net_ns {
|
||||
net_ns
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
self.config
|
||||
.add_file("net_ns")
|
||||
.unwrap()
|
||||
.write_all(net_ns_in_conf.as_bytes())
|
||||
.unwrap();
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_addr(self, addr: String) -> Self {
|
||||
self.config
|
||||
.add_file("ipv4")
|
||||
.unwrap()
|
||||
.write_all(addr.as_bytes())
|
||||
.unwrap();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Instance {
|
||||
inst_name: String,
|
||||
|
||||
id: uuid::Uuid,
|
||||
|
||||
virtual_nic: Option<Arc<virtual_nic::VirtualNic>>,
|
||||
peer_packet_receiver: Option<ReceiverStream<SinkItem>>,
|
||||
|
||||
tasks: JoinSet<()>,
|
||||
|
||||
peer_manager: Arc<PeerManager>,
|
||||
listener_manager: Arc<Mutex<ListenerManager<PeerManager>>>,
|
||||
conn_manager: Arc<ManualConnectorManager>,
|
||||
direct_conn_manager: Arc<DirectConnectorManager>,
|
||||
udp_hole_puncher: Arc<Mutex<UdpHolePunchConnector>>,
|
||||
|
||||
tcp_proxy: Arc<TcpProxy>,
|
||||
icmp_proxy: Arc<IcmpProxy>,
|
||||
|
||||
global_ctx: ArcGlobalCtx,
|
||||
}
|
||||
|
||||
impl Instance {
|
||||
pub fn new(inst_name: &str) -> Self {
|
||||
let config = ConfigFs::new(inst_name);
|
||||
let net_ns_in_conf = config.get_or_default("net_ns", || "".to_string()).unwrap();
|
||||
let net_ns = NetNS::new(if net_ns_in_conf.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(net_ns_in_conf.clone())
|
||||
});
|
||||
|
||||
let addr = config
|
||||
.get_or_default("ipv4", || "10.144.144.10".to_string())
|
||||
.unwrap();
|
||||
|
||||
log::info!(
|
||||
"[INIT] instance creating. inst_name: {}, addr: {}, netns: {}",
|
||||
inst_name,
|
||||
addr,
|
||||
net_ns_in_conf
|
||||
);
|
||||
|
||||
let (peer_packet_sender, peer_packet_receiver) = tokio::sync::mpsc::channel(100);
|
||||
|
||||
let global_ctx = Arc::new(GlobalCtx::new(inst_name, config, net_ns.clone()));
|
||||
|
||||
let id = global_ctx.get_id();
|
||||
|
||||
let peer_manager = Arc::new(PeerManager::new(
|
||||
global_ctx.clone(),
|
||||
peer_packet_sender.clone(),
|
||||
));
|
||||
|
||||
let listener_manager = Arc::new(Mutex::new(ListenerManager::new(
|
||||
id,
|
||||
net_ns.clone(),
|
||||
peer_manager.clone(),
|
||||
)));
|
||||
|
||||
let conn_manager = Arc::new(ManualConnectorManager::new(
|
||||
id,
|
||||
global_ctx.clone(),
|
||||
peer_manager.clone(),
|
||||
));
|
||||
|
||||
let mut direct_conn_manager =
|
||||
DirectConnectorManager::new(id, global_ctx.clone(), peer_manager.clone());
|
||||
direct_conn_manager.run();
|
||||
|
||||
let udp_hole_puncher = UdpHolePunchConnector::new(global_ctx.clone(), peer_manager.clone());
|
||||
|
||||
let arc_tcp_proxy = TcpProxy::new(global_ctx.clone(), peer_manager.clone());
|
||||
let arc_icmp_proxy = IcmpProxy::new(global_ctx.clone(), peer_manager.clone()).unwrap();
|
||||
|
||||
Instance {
|
||||
inst_name: inst_name.to_string(),
|
||||
id,
|
||||
|
||||
virtual_nic: None,
|
||||
peer_packet_receiver: Some(ReceiverStream::new(peer_packet_receiver)),
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
peer_manager,
|
||||
listener_manager,
|
||||
conn_manager,
|
||||
direct_conn_manager: Arc::new(direct_conn_manager),
|
||||
udp_hole_puncher: Arc::new(Mutex::new(udp_hole_puncher)),
|
||||
|
||||
tcp_proxy: arc_tcp_proxy,
|
||||
icmp_proxy: arc_icmp_proxy,
|
||||
|
||||
global_ctx,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_conn_manager(&self) -> Arc<ManualConnectorManager> {
|
||||
self.conn_manager.clone()
|
||||
}
|
||||
|
||||
async fn do_forward_nic_to_peers_ipv4(ret: BytesMut, mgr: &PeerManager) {
|
||||
if let Some(ipv4) = Ipv4Packet::new(&ret) {
|
||||
if ipv4.get_version() != 4 {
|
||||
tracing::info!("[USER_PACKET] not ipv4 packet: {:?}", ipv4);
|
||||
}
|
||||
let dst_ipv4 = ipv4.get_destination();
|
||||
tracing::trace!(
|
||||
?ret,
|
||||
"[USER_PACKET] recv new packet from tun device and forward to peers."
|
||||
);
|
||||
let send_ret = mgr.send_msg_ipv4(ret, dst_ipv4).await;
|
||||
if send_ret.is_err() {
|
||||
tracing::trace!(?send_ret, "[USER_PACKET] send_msg_ipv4 failed")
|
||||
}
|
||||
} else {
|
||||
tracing::warn!(?ret, "[USER_PACKET] not ipv4 packet");
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_forward_nic_to_peers_ethernet(mut ret: BytesMut, mgr: &PeerManager) {
|
||||
if let Some(eth) = EthernetPacket::new(&ret) {
|
||||
log::warn!("begin to forward: {:?}, type: {}", eth, eth.get_ethertype());
|
||||
Self::do_forward_nic_to_peers_ipv4(ret.split_off(14), mgr).await;
|
||||
} else {
|
||||
log::warn!("not ipv4 packet: {:?}", ret);
|
||||
}
|
||||
}
|
||||
|
||||
fn do_forward_nic_to_peers(&mut self) -> Result<(), Error> {
|
||||
// read from nic and write to corresponding tunnel
|
||||
let nic = self.virtual_nic.as_ref().unwrap();
|
||||
let nic = nic.clone();
|
||||
let mgr = self.peer_manager.clone();
|
||||
|
||||
self.tasks.spawn(async move {
|
||||
let mut stream = nic.pin_recv_stream();
|
||||
while let Some(ret) = stream.next().await {
|
||||
if ret.is_err() {
|
||||
log::error!("read from nic failed: {:?}", ret);
|
||||
break;
|
||||
}
|
||||
Self::do_forward_nic_to_peers_ipv4(ret.unwrap(), mgr.as_ref()).await;
|
||||
// Self::do_forward_nic_to_peers_ethernet(ret.into(), mgr.as_ref()).await;
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn do_forward_peers_to_nic(
|
||||
tasks: &mut JoinSet<()>,
|
||||
nic: Arc<virtual_nic::VirtualNic>,
|
||||
channel: Option<ReceiverStream<Bytes>>,
|
||||
) {
|
||||
tasks.spawn(async move {
|
||||
let send = nic.pin_send_stream();
|
||||
let channel = channel.unwrap();
|
||||
let ret = channel
|
||||
.map(|packet| {
|
||||
log::trace!(
|
||||
"[USER_PACKET] forward packet from peers to nic. packet: {:?}",
|
||||
packet
|
||||
);
|
||||
Ok(packet)
|
||||
})
|
||||
.forward(send)
|
||||
.await;
|
||||
if ret.is_err() {
|
||||
panic!("do_forward_tunnel_to_nic");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) -> Result<(), Error> {
|
||||
let ipv4_addr = self.global_ctx.get_ipv4().unwrap();
|
||||
|
||||
let mut nic = virtual_nic::VirtualNic::new(self.get_global_ctx())
|
||||
.create_dev()
|
||||
.await?
|
||||
.link_up()
|
||||
.await?
|
||||
.remove_ip(None)
|
||||
.await?
|
||||
.add_ip(ipv4_addr, 24)
|
||||
.await?;
|
||||
|
||||
if cfg!(target_os = "macos") {
|
||||
nic = nic.add_route(ipv4_addr, 24).await?;
|
||||
}
|
||||
|
||||
self.virtual_nic = Some(Arc::new(nic));
|
||||
|
||||
self.do_forward_nic_to_peers().unwrap();
|
||||
Self::do_forward_peers_to_nic(
|
||||
self.tasks.borrow_mut(),
|
||||
self.virtual_nic.as_ref().unwrap().clone(),
|
||||
self.peer_packet_receiver.take(),
|
||||
);
|
||||
|
||||
self.listener_manager
|
||||
.lock()
|
||||
.await
|
||||
.prepare_listeners()
|
||||
.await?;
|
||||
self.listener_manager.lock().await.run().await?;
|
||||
self.peer_manager.run().await?;
|
||||
|
||||
let route = BasicRoute::new(self.id(), self.global_ctx.clone());
|
||||
self.peer_manager.set_route(route).await;
|
||||
|
||||
self.run_rpc_server().unwrap();
|
||||
|
||||
self.tcp_proxy.start().await.unwrap();
|
||||
self.icmp_proxy.start().await.unwrap();
|
||||
self.run_proxy_cidrs_route_updater();
|
||||
|
||||
self.udp_hole_puncher.lock().await.run().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_peer_manager(&self) -> Arc<PeerManager> {
|
||||
self.peer_manager.clone()
|
||||
}
|
||||
|
||||
pub async fn close_peer_conn(&mut self, peer_id: &Uuid, conn_id: &Uuid) -> Result<(), Error> {
|
||||
self.peer_manager
|
||||
.get_peer_map()
|
||||
.close_peer_conn(peer_id, conn_id)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn wait(&mut self) {
|
||||
while let Some(ret) = self.tasks.join_next().await {
|
||||
log::info!("task finished: {:?}", ret);
|
||||
ret.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> uuid::Uuid {
|
||||
self.id
|
||||
}
|
||||
|
||||
fn run_rpc_server(&mut self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let addr = "0.0.0.0:15888".parse()?;
|
||||
let peer_mgr = self.peer_manager.clone();
|
||||
let conn_manager = self.conn_manager.clone();
|
||||
let net_ns = self.global_ctx.net_ns.clone();
|
||||
|
||||
self.tasks.spawn(async move {
|
||||
let _g = net_ns.guard();
|
||||
log::info!("[INIT RPC] start rpc server. addr: {}", addr);
|
||||
Server::builder()
|
||||
.add_service(
|
||||
easytier_rpc::peer_manage_rpc_server::PeerManageRpcServer::new(
|
||||
PeerManagerRpcService::new(peer_mgr),
|
||||
),
|
||||
)
|
||||
.add_service(
|
||||
easytier_rpc::connector_manage_rpc_server::ConnectorManageRpcServer::new(
|
||||
ConnectorManagerRpcService(conn_manager.clone()),
|
||||
),
|
||||
)
|
||||
.serve(addr)
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_proxy_cidrs_route_updater(&mut self) {
|
||||
let peer_mgr = self.peer_manager.clone();
|
||||
let net_ns = self.global_ctx.net_ns.clone();
|
||||
let nic = self.virtual_nic.as_ref().unwrap().clone();
|
||||
|
||||
self.tasks.spawn(async move {
|
||||
let mut cur_proxy_cidrs = vec![];
|
||||
loop {
|
||||
let mut proxy_cidrs = vec![];
|
||||
let routes = peer_mgr.list_routes().await;
|
||||
for r in routes {
|
||||
for cidr in r.proxy_cidrs {
|
||||
let Ok(cidr) = cidr.parse::<cidr::Ipv4Cidr>() else {
|
||||
continue;
|
||||
};
|
||||
proxy_cidrs.push(cidr);
|
||||
}
|
||||
}
|
||||
|
||||
// if route is in cur_proxy_cidrs but not in proxy_cidrs, delete it.
|
||||
for cidr in cur_proxy_cidrs.iter() {
|
||||
if proxy_cidrs.contains(cidr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let _g = net_ns.guard();
|
||||
let ret = nic
|
||||
.get_ifcfg()
|
||||
.remove_ipv4_route(
|
||||
nic.ifname(),
|
||||
cidr.first_address(),
|
||||
cidr.network_length(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if ret.is_err() {
|
||||
tracing::trace!(
|
||||
cidr = ?cidr,
|
||||
err = ?ret,
|
||||
"remove route failed.",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for cidr in proxy_cidrs.iter() {
|
||||
if cur_proxy_cidrs.contains(cidr) {
|
||||
continue;
|
||||
}
|
||||
let _g = net_ns.guard();
|
||||
let ret = nic
|
||||
.get_ifcfg()
|
||||
.add_ipv4_route(nic.ifname(), cidr.first_address(), cidr.network_length())
|
||||
.await;
|
||||
|
||||
if ret.is_err() {
|
||||
tracing::trace!(
|
||||
cidr = ?cidr,
|
||||
err = ?ret,
|
||||
"add route failed.",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
cur_proxy_cidrs = proxy_cidrs;
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn get_global_ctx(&self) -> ArcGlobalCtx {
|
||||
self.global_ctx.clone()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::{sync::Mutex, task::JoinSet};
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, netns::NetNS},
|
||||
peers::peer_manager::PeerManager,
|
||||
tunnels::{
|
||||
ring_tunnel::RingTunnelListener, tcp_tunnel::TcpTunnelListener,
|
||||
udp_tunnel::UdpTunnelListener, Tunnel, TunnelListener,
|
||||
},
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
pub trait TunnelHandlerForListener {
|
||||
async fn handle_tunnel(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelHandlerForListener for PeerManager {
|
||||
#[tracing::instrument]
|
||||
async fn handle_tunnel(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
|
||||
self.add_tunnel_as_server(tunnel).await
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ListenerManager<H> {
|
||||
my_node_id: uuid::Uuid,
|
||||
net_ns: NetNS,
|
||||
listeners: Vec<Arc<Mutex<dyn TunnelListener>>>,
|
||||
peer_manager: Arc<H>,
|
||||
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManager<H> {
|
||||
pub fn new(my_node_id: uuid::Uuid, net_ns: NetNS, peer_manager: Arc<H>) -> Self {
|
||||
Self {
|
||||
my_node_id,
|
||||
net_ns,
|
||||
listeners: Vec::new(),
|
||||
peer_manager,
|
||||
tasks: JoinSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn prepare_listeners(&mut self) -> Result<(), Error> {
|
||||
self.add_listener(UdpTunnelListener::new(
|
||||
"udp://0.0.0.0:11010".parse().unwrap(),
|
||||
))
|
||||
.await?;
|
||||
self.add_listener(TcpTunnelListener::new(
|
||||
"tcp://0.0.0.0:11010".parse().unwrap(),
|
||||
))
|
||||
.await?;
|
||||
self.add_listener(RingTunnelListener::new(
|
||||
format!("ring://{}", self.my_node_id).parse().unwrap(),
|
||||
))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn add_listener<Listener>(&mut self, listener: Listener) -> Result<(), Error>
|
||||
where
|
||||
Listener: TunnelListener + 'static,
|
||||
{
|
||||
let listener = Arc::new(Mutex::new(listener));
|
||||
self.listeners.push(listener);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn run_listener(listener: Arc<Mutex<dyn TunnelListener>>, peer_manager: Arc<H>) {
|
||||
let mut l = listener.lock().await;
|
||||
while let Ok(ret) = l.accept().await {
|
||||
tracing::info!(ret = ?ret, "conn accepted");
|
||||
let server_ret = peer_manager.handle_tunnel(ret).await;
|
||||
if let Err(e) = &server_ret {
|
||||
tracing::error!(error = ?e, "handle conn error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) -> Result<(), Error> {
|
||||
for listener in &self.listeners {
|
||||
let _guard = self.net_ns.guard();
|
||||
log::warn!("run listener: {:?}", listener);
|
||||
listener.lock().await.listen().await?;
|
||||
self.tasks.spawn(Self::run_listener(
|
||||
listener.clone(),
|
||||
self.peer_manager.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::tunnels::{ring_tunnel::RingTunnelConnector, TunnelConnector};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MockListenerHandler {}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelHandlerForListener for MockListenerHandler {
|
||||
async fn handle_tunnel(&self, _tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
|
||||
let data = "abc";
|
||||
_tunnel.pin_sink().send(data.into()).await.unwrap();
|
||||
Err(Error::Unknown)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_error_in_accept() {
|
||||
let net_ns = NetNS::new(None);
|
||||
let handler = Arc::new(MockListenerHandler {});
|
||||
let mut listener_mgr =
|
||||
ListenerManager::new(uuid::Uuid::new_v4(), net_ns.clone(), handler.clone());
|
||||
|
||||
let ring_id = format!("ring://{}", uuid::Uuid::new_v4());
|
||||
|
||||
listener_mgr
|
||||
.add_listener(RingTunnelListener::new(ring_id.parse().unwrap()))
|
||||
.await
|
||||
.unwrap();
|
||||
listener_mgr.run().await.unwrap();
|
||||
|
||||
let connect_once = |ring_id| async move {
|
||||
let tunnel = RingTunnelConnector::new(ring_id).connect().await.unwrap();
|
||||
assert_eq!(tunnel.pin_stream().next().await.unwrap().unwrap(), "abc");
|
||||
tunnel
|
||||
};
|
||||
|
||||
timeout(std::time::Duration::from_secs(1), async move {
|
||||
connect_once(ring_id.parse().unwrap()).await;
|
||||
// handle tunnel fail should not impact the second connect
|
||||
connect_once(ring_id.parse().unwrap()).await;
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
pub mod instance;
|
||||
pub mod listeners;
|
||||
pub mod tun_codec;
|
||||
pub mod virtual_nic;
|
||||
@@ -0,0 +1,179 @@
|
||||
use std::io;
|
||||
|
||||
use byteorder::{NativeEndian, NetworkEndian, WriteBytesExt};
|
||||
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
/// A packet protocol IP version
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
enum PacketProtocol {
|
||||
#[default]
|
||||
IPv4,
|
||||
IPv6,
|
||||
Other(u8),
|
||||
}
|
||||
|
||||
// Note: the protocol in the packet information header is platform dependent.
|
||||
impl PacketProtocol {
|
||||
#[cfg(any(target_os = "linux", target_os = "android"))]
|
||||
fn into_pi_field(self) -> Result<u16, io::Error> {
|
||||
use nix::libc;
|
||||
match self {
|
||||
PacketProtocol::IPv4 => Ok(libc::ETH_P_IP as u16),
|
||||
PacketProtocol::IPv6 => Ok(libc::ETH_P_IPV6 as u16),
|
||||
PacketProtocol::Other(_) => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"neither an IPv4 nor IPv6 packet",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_os = "macos", target_os = "ios"))]
|
||||
fn into_pi_field(self) -> Result<u16, io::Error> {
|
||||
use nix::libc;
|
||||
match self {
|
||||
PacketProtocol::IPv4 => Ok(libc::PF_INET as u16),
|
||||
PacketProtocol::IPv6 => Ok(libc::PF_INET6 as u16),
|
||||
PacketProtocol::Other(_) => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"neither an IPv4 nor IPv6 packet",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn into_pi_field(self) -> Result<u16, io::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TunPacketBuffer {
|
||||
Bytes(Bytes),
|
||||
BytesMut(BytesMut),
|
||||
}
|
||||
|
||||
impl From<TunPacketBuffer> for Bytes {
|
||||
fn from(buf: TunPacketBuffer) -> Self {
|
||||
match buf {
|
||||
TunPacketBuffer::Bytes(bytes) => bytes,
|
||||
TunPacketBuffer::BytesMut(bytes) => bytes.freeze(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for TunPacketBuffer {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
match self {
|
||||
TunPacketBuffer::Bytes(bytes) => bytes.as_ref(),
|
||||
TunPacketBuffer::BytesMut(bytes) => bytes.as_ref(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tun Packet to be sent or received on the TUN interface.
|
||||
#[derive(Debug)]
|
||||
pub struct TunPacket(PacketProtocol, TunPacketBuffer);
|
||||
|
||||
/// Infer the protocol based on the first nibble in the packet buffer.
|
||||
fn infer_proto(buf: &[u8]) -> PacketProtocol {
|
||||
match buf[0] >> 4 {
|
||||
4 => PacketProtocol::IPv4,
|
||||
6 => PacketProtocol::IPv6,
|
||||
p => PacketProtocol::Other(p),
|
||||
}
|
||||
}
|
||||
|
||||
impl TunPacket {
|
||||
/// Create a new `TunPacket` based on a byte slice.
|
||||
pub fn new(buffer: TunPacketBuffer) -> TunPacket {
|
||||
let proto = infer_proto(buffer.as_ref());
|
||||
TunPacket(proto, buffer)
|
||||
}
|
||||
|
||||
/// Return this packet's bytes.
|
||||
pub fn get_bytes(&self) -> &[u8] {
|
||||
match &self.1 {
|
||||
TunPacketBuffer::Bytes(bytes) => bytes.as_ref(),
|
||||
TunPacketBuffer::BytesMut(bytes) => bytes.as_ref(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_bytes(self) -> Bytes {
|
||||
match self.1 {
|
||||
TunPacketBuffer::Bytes(bytes) => bytes,
|
||||
TunPacketBuffer::BytesMut(bytes) => bytes.freeze(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_bytes_mut(self) -> BytesMut {
|
||||
match self.1 {
|
||||
TunPacketBuffer::Bytes(_) => panic!("cannot into_bytes_mut from bytes"),
|
||||
TunPacketBuffer::BytesMut(bytes) => bytes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A TunPacket Encoder/Decoder.
|
||||
pub struct TunPacketCodec(bool, i32);
|
||||
|
||||
impl TunPacketCodec {
|
||||
/// Create a new `TunPacketCodec` specifying whether the underlying
|
||||
/// tunnel Device has enabled the packet information header.
|
||||
pub fn new(pi: bool, mtu: i32) -> TunPacketCodec {
|
||||
TunPacketCodec(pi, mtu)
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for TunPacketCodec {
|
||||
type Item = TunPacket;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
if buf.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut pkt = buf.split_to(buf.len());
|
||||
|
||||
// reserve enough space for the next packet
|
||||
if self.0 {
|
||||
buf.reserve(self.1 as usize + 4);
|
||||
} else {
|
||||
buf.reserve(self.1 as usize);
|
||||
}
|
||||
|
||||
// if the packet information is enabled we have to ignore the first 4 bytes
|
||||
if self.0 {
|
||||
let _ = pkt.split_to(4);
|
||||
}
|
||||
|
||||
let proto = infer_proto(pkt.as_ref());
|
||||
Ok(Some(TunPacket(proto, TunPacketBuffer::BytesMut(pkt))))
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<TunPacket> for TunPacketCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: TunPacket, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
dst.reserve(item.get_bytes().len() + 4);
|
||||
match item {
|
||||
TunPacket(proto, bytes) if self.0 => {
|
||||
// build the packet information header comprising of 2 u16
|
||||
// fields: flags and protocol.
|
||||
let mut buf = Vec::<u8>::with_capacity(4);
|
||||
|
||||
// flags is always 0
|
||||
buf.write_u16::<NativeEndian>(0)?;
|
||||
// write the protocol as network byte order
|
||||
buf.write_u16::<NetworkEndian>(proto.into_pi_field()?)?;
|
||||
|
||||
dst.put_slice(&buf);
|
||||
dst.put(Bytes::from(bytes));
|
||||
}
|
||||
TunPacket(_, bytes) => dst.put(Bytes::from(bytes)),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
use std::{net::Ipv4Addr, pin::Pin};
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Result,
|
||||
global_ctx::ArcGlobalCtx,
|
||||
ifcfg::{IfConfiger, IfConfiguerTrait},
|
||||
},
|
||||
tunnels::{
|
||||
codec::BytesCodec, common::FramedTunnel, DatagramSink, DatagramStream, Tunnel, TunnelError,
|
||||
},
|
||||
};
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio_util::{bytes::Bytes, codec::Framed};
|
||||
use tun::Device;
|
||||
|
||||
use super::tun_codec::{TunPacket, TunPacketCodec};
|
||||
|
||||
pub struct VirtualNic {
|
||||
dev_name: String,
|
||||
queue_num: usize,
|
||||
|
||||
global_ctx: ArcGlobalCtx,
|
||||
|
||||
ifname: Option<String>,
|
||||
tun: Option<Box<dyn Tunnel>>,
|
||||
ifcfg: Box<dyn IfConfiguerTrait + Send + Sync + 'static>,
|
||||
}
|
||||
|
||||
impl VirtualNic {
|
||||
pub fn new(global_ctx: ArcGlobalCtx) -> Self {
|
||||
Self {
|
||||
dev_name: "".to_owned(),
|
||||
queue_num: 1,
|
||||
global_ctx,
|
||||
ifname: None,
|
||||
tun: None,
|
||||
ifcfg: Box::new(IfConfiger {}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_dev_name(mut self, dev_name: &str) -> Result<Self> {
|
||||
self.dev_name = dev_name.to_owned();
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn set_queue_num(mut self, queue_num: usize) -> Result<Self> {
|
||||
self.queue_num = queue_num;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
async fn create_dev_ret_err(&mut self) -> Result<()> {
|
||||
let mut config = tun::Configuration::default();
|
||||
let has_packet_info = cfg!(target_os = "macos");
|
||||
config.layer(tun::Layer::L3);
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
config.platform(|config| {
|
||||
// detect protocol by ourselves for cross platform
|
||||
config.packet_information(false);
|
||||
});
|
||||
config.name(self.dev_name.clone());
|
||||
}
|
||||
|
||||
if self.queue_num != 1 {
|
||||
todo!("queue_num != 1")
|
||||
}
|
||||
config.queues(self.queue_num);
|
||||
config.up();
|
||||
|
||||
let dev = {
|
||||
let _g = self.global_ctx.net_ns.guard();
|
||||
tun::create_as_async(&config)?
|
||||
};
|
||||
let ifname = dev.get_ref().name()?;
|
||||
self.ifcfg.wait_interface_show(ifname.as_str()).await?;
|
||||
|
||||
let ft: Box<dyn Tunnel> = if has_packet_info {
|
||||
let framed = Framed::new(dev, TunPacketCodec::new(true, 2500));
|
||||
let (sink, stream) = framed.split();
|
||||
|
||||
let new_stream = stream.map(|item| match item {
|
||||
Ok(item) => Ok(item.into_bytes_mut()),
|
||||
Err(err) => {
|
||||
println!("tun stream error: {:?}", err);
|
||||
Err(TunnelError::TunError(err.to_string()))
|
||||
}
|
||||
});
|
||||
|
||||
let new_sink = Box::pin(sink.with(|item: Bytes| async move {
|
||||
if false {
|
||||
return Err(TunnelError::TunError("tun sink error".to_owned()));
|
||||
}
|
||||
Ok(TunPacket::new(super::tun_codec::TunPacketBuffer::Bytes(
|
||||
item,
|
||||
)))
|
||||
}));
|
||||
|
||||
Box::new(FramedTunnel::new(new_stream, new_sink, None))
|
||||
} else {
|
||||
let framed = Framed::new(dev, BytesCodec::new(2500));
|
||||
let (sink, stream) = framed.split();
|
||||
Box::new(FramedTunnel::new(stream, sink, None))
|
||||
};
|
||||
|
||||
self.ifname = Some(ifname.to_owned());
|
||||
self.tun = Some(ft);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_dev(mut self) -> Result<Self> {
|
||||
self.create_dev_ret_err().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn ifname(&self) -> &str {
|
||||
self.ifname.as_ref().unwrap().as_str()
|
||||
}
|
||||
|
||||
pub async fn link_up(self) -> Result<Self> {
|
||||
let _g = self.global_ctx.net_ns.guard();
|
||||
self.ifcfg.set_link_status(self.ifname(), true).await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub async fn add_route(self, address: Ipv4Addr, cidr: u8) -> Result<Self> {
|
||||
let _g = self.global_ctx.net_ns.guard();
|
||||
self.ifcfg
|
||||
.add_ipv4_route(self.ifname(), address, cidr)
|
||||
.await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub async fn remove_ip(self, ip: Option<Ipv4Addr>) -> Result<Self> {
|
||||
let _g = self.global_ctx.net_ns.guard();
|
||||
self.ifcfg.remove_ip(self.ifname(), ip).await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub async fn add_ip(self, ip: Ipv4Addr, cidr: i32) -> Result<Self> {
|
||||
let _g = self.global_ctx.net_ns.guard();
|
||||
self.ifcfg
|
||||
.add_ipv4_ip(self.ifname(), ip, cidr as u8)
|
||||
.await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn pin_recv_stream(&self) -> Pin<Box<dyn DatagramStream>> {
|
||||
self.tun.as_ref().unwrap().pin_stream()
|
||||
}
|
||||
|
||||
pub fn pin_send_stream(&self) -> Pin<Box<dyn DatagramSink>> {
|
||||
self.tun.as_ref().unwrap().pin_sink()
|
||||
}
|
||||
|
||||
pub fn get_ifcfg(&self) -> &dyn IfConfiguerTrait {
|
||||
self.ifcfg.as_ref()
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::tests::get_mock_global_ctx},
|
||||
tests::enable_log,
|
||||
};
|
||||
|
||||
use super::VirtualNic;
|
||||
|
||||
async fn run_test_helper() -> Result<VirtualNic, Error> {
|
||||
let dev = VirtualNic::new(get_mock_global_ctx()).create_dev().await?;
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
|
||||
dev.link_up()
|
||||
.await?
|
||||
.remove_ip(None)
|
||||
.await?
|
||||
.add_ip("10.144.111.1".parse().unwrap(), 24)
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tun_test() {
|
||||
enable_log();
|
||||
let _dev = run_test_helper().await.unwrap();
|
||||
|
||||
// let mut stream = nic.pin_recv_stream();
|
||||
// while let Some(item) = stream.next().await {
|
||||
// println!("item: {:?}", item);
|
||||
// }
|
||||
|
||||
// let framed = dev.into_framed();
|
||||
// let (mut s, mut b) = framed.split();
|
||||
// loop {
|
||||
// let tmp = b.next().await.unwrap().unwrap();
|
||||
// let tmp = EthernetPacket::new(tmp.get_bytes());
|
||||
// println!("ret: {:?}", tmp.unwrap());
|
||||
// }
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
mod common;
|
||||
mod connector;
|
||||
mod gateway;
|
||||
mod instance;
|
||||
mod peers;
|
||||
mod tunnels;
|
||||
|
||||
use instance::instance::{Instance, InstanceConfigWriter};
|
||||
use tracing::level_filters::LevelFilter;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Cli {
|
||||
/// the instance name
|
||||
#[arg(short = 'n', long, default_value = "default")]
|
||||
instance_name: String,
|
||||
|
||||
/// specify the network namespace, default is the root namespace
|
||||
#[arg(long)]
|
||||
net_ns: Option<String>,
|
||||
|
||||
#[arg(short, long)]
|
||||
ipv4: Option<String>,
|
||||
|
||||
#[arg(short, long)]
|
||||
peers: Vec<String>,
|
||||
}
|
||||
|
||||
fn init_logger() {
|
||||
// logger to rolling file
|
||||
let file_filter = EnvFilter::builder()
|
||||
.with_default_directive(LevelFilter::INFO.into())
|
||||
.from_env()
|
||||
.unwrap();
|
||||
let file_appender = tracing_appender::rolling::Builder::new()
|
||||
.rotation(tracing_appender::rolling::Rotation::DAILY)
|
||||
.max_log_files(5)
|
||||
.filename_prefix("core.log")
|
||||
.build("/var/log/easytier")
|
||||
.expect("failed to initialize rolling file appender");
|
||||
let mut file_layer = tracing_subscriber::fmt::layer();
|
||||
file_layer.set_ansi(false);
|
||||
let file_layer = file_layer
|
||||
.with_writer(file_appender)
|
||||
.with_filter(file_filter);
|
||||
|
||||
// logger to console
|
||||
let console_filter = EnvFilter::builder()
|
||||
.with_default_directive(LevelFilter::WARN.into())
|
||||
.from_env()
|
||||
.unwrap();
|
||||
let console_layer = tracing_subscriber::fmt::layer()
|
||||
.pretty()
|
||||
.with_writer(std::io::stderr)
|
||||
.with_filter(console_filter);
|
||||
|
||||
tracing_subscriber::Registry::default()
|
||||
.with(console_layer)
|
||||
.with(file_layer)
|
||||
.init();
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
#[tracing::instrument]
|
||||
pub async fn main() {
|
||||
init_logger();
|
||||
|
||||
let cli = Cli::parse();
|
||||
tracing::info!(cli = ?cli, "cli args parsed");
|
||||
|
||||
let cfg = InstanceConfigWriter::new(cli.instance_name.as_str()).set_ns(cli.net_ns.clone());
|
||||
if let Some(ipv4) = &cli.ipv4 {
|
||||
cfg.set_addr(ipv4.clone());
|
||||
}
|
||||
|
||||
let mut inst = Instance::new(cli.instance_name.as_str());
|
||||
|
||||
let mut events = inst.get_global_ctx().subscribe();
|
||||
tokio::spawn(async move {
|
||||
while let Ok(e) = events.recv().await {
|
||||
log::warn!("event: {:?}", e);
|
||||
}
|
||||
});
|
||||
|
||||
inst.run().await.unwrap();
|
||||
|
||||
for peer in cli.peers {
|
||||
inst.get_conn_manager()
|
||||
.add_connector_by_url(peer.as_str())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
inst.wait().await;
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
pub mod packet;
|
||||
pub mod peer;
|
||||
pub mod peer_conn;
|
||||
pub mod peer_manager;
|
||||
pub mod peer_map;
|
||||
pub mod peer_rpc;
|
||||
pub mod rip_route;
|
||||
pub mod route_trait;
|
||||
pub mod rpc_service;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests;
|
||||
|
||||
pub type PeerId = uuid::Uuid;
|
||||
@@ -0,0 +1,205 @@
|
||||
use rkyv::{Archive, Deserialize, Serialize};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::common::rkyv_util::{decode_from_bytes, encode_to_bytes};
|
||||
|
||||
const MAGIC: u32 = 0xd1e1a5e1;
|
||||
const VERSION: u32 = 1;
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, PartialEq, Clone)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub struct UUID(uuid::Bytes);
|
||||
|
||||
// impl Debug for UUID
|
||||
impl std::fmt::Debug for UUID {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let uuid = uuid::Uuid::from_bytes(self.0);
|
||||
write!(f, "{}", uuid)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<uuid::Uuid> for UUID {
|
||||
fn from(uuid: uuid::Uuid) -> Self {
|
||||
UUID(*uuid.as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UUID> for uuid::Uuid {
|
||||
fn from(uuid: UUID) -> Self {
|
||||
uuid::Uuid::from_bytes(uuid.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl ArchivedUUID {
|
||||
pub fn to_uuid(&self) -> uuid::Uuid {
|
||||
uuid::Uuid::from_bytes(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ArchivedUUID> for UUID {
|
||||
fn from(uuid: &ArchivedUUID) -> Self {
|
||||
UUID(uuid.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub struct HandShake {
|
||||
pub magic: u32,
|
||||
pub my_peer_id: UUID,
|
||||
pub version: u32,
|
||||
pub features: Vec<String>,
|
||||
// pub interfaces: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub struct RoutePacket {
|
||||
pub route_id: u8,
|
||||
pub body: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub enum CtrlPacketBody {
|
||||
HandShake(HandShake),
|
||||
RoutePacket(RoutePacket),
|
||||
Ping,
|
||||
Pong,
|
||||
TaRpc(u32, bool, Vec<u8>), // u32: service_id, bool: is_req, Vec<u8>: rpc body
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub struct DataPacketBody {
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub enum PacketBody {
|
||||
Ctrl(CtrlPacketBody),
|
||||
Data(DataPacketBody),
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub struct Packet {
|
||||
pub from_peer: UUID,
|
||||
pub to_peer: Option<UUID>,
|
||||
pub body: PacketBody,
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
pub fn decode(v: &[u8]) -> &ArchivedPacket {
|
||||
decode_from_bytes::<Packet>(v).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Packet> for Bytes {
|
||||
fn from(val: Packet) -> Self {
|
||||
encode_to_bytes::<_, 4096>(&val)
|
||||
}
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
pub fn new_handshake(from_peer: uuid::Uuid) -> Self {
|
||||
Packet {
|
||||
from_peer: from_peer.into(),
|
||||
to_peer: None,
|
||||
body: PacketBody::Ctrl(CtrlPacketBody::HandShake(HandShake {
|
||||
magic: MAGIC,
|
||||
my_peer_id: from_peer.into(),
|
||||
version: VERSION,
|
||||
features: Vec::new(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_data_packet(from_peer: uuid::Uuid, to_peer: uuid::Uuid, data: &[u8]) -> Self {
|
||||
Packet {
|
||||
from_peer: from_peer.into(),
|
||||
to_peer: Some(to_peer.into()),
|
||||
body: PacketBody::Data(DataPacketBody {
|
||||
data: data.to_vec(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_route_packet(
|
||||
from_peer: uuid::Uuid,
|
||||
to_peer: uuid::Uuid,
|
||||
route_id: u8,
|
||||
data: &[u8],
|
||||
) -> Self {
|
||||
Packet {
|
||||
from_peer: from_peer.into(),
|
||||
to_peer: Some(to_peer.into()),
|
||||
body: PacketBody::Ctrl(CtrlPacketBody::RoutePacket(RoutePacket {
|
||||
route_id,
|
||||
body: data.to_vec(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_ping_packet(from_peer: uuid::Uuid, to_peer: uuid::Uuid) -> Self {
|
||||
Packet {
|
||||
from_peer: from_peer.into(),
|
||||
to_peer: Some(to_peer.into()),
|
||||
body: PacketBody::Ctrl(CtrlPacketBody::Ping),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_pong_packet(from_peer: uuid::Uuid, to_peer: uuid::Uuid) -> Self {
|
||||
Packet {
|
||||
from_peer: from_peer.into(),
|
||||
to_peer: Some(to_peer.into()),
|
||||
body: PacketBody::Ctrl(CtrlPacketBody::Pong),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_tarpc_packet(
|
||||
from_peer: uuid::Uuid,
|
||||
to_peer: uuid::Uuid,
|
||||
service_id: u32,
|
||||
is_req: bool,
|
||||
body: Vec<u8>,
|
||||
) -> Self {
|
||||
Packet {
|
||||
from_peer: from_peer.into(),
|
||||
to_peer: Some(to_peer.into()),
|
||||
body: PacketBody::Ctrl(CtrlPacketBody::TaRpc(service_id, is_req, body)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn serialize() {
|
||||
let a = "abcde";
|
||||
let out = Packet::new_data_packet(uuid::Uuid::new_v4(), uuid::Uuid::new_v4(), a.as_bytes());
|
||||
// let out = T::new(a.as_bytes());
|
||||
let out_bytes: Bytes = out.into();
|
||||
println!("out str: {:?}", a.as_bytes());
|
||||
println!("out bytes: {:?}", out_bytes);
|
||||
|
||||
let archived = Packet::decode(&out_bytes[..]);
|
||||
println!("in packet: {:?}", archived);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use easytier_rpc::PeerConnInfo;
|
||||
|
||||
use tokio::{
|
||||
select,
|
||||
sync::{mpsc, Mutex},
|
||||
task::JoinHandle,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Instrument;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
||||
};
|
||||
|
||||
use super::peer_conn::PeerConn;
|
||||
|
||||
type ArcPeerConn = Arc<Mutex<PeerConn>>;
|
||||
type ConnMap = Arc<DashMap<Uuid, ArcPeerConn>>;
|
||||
|
||||
pub struct Peer {
|
||||
pub peer_node_id: uuid::Uuid,
|
||||
conns: ConnMap,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
|
||||
packet_recv_chan: mpsc::Sender<Bytes>,
|
||||
|
||||
close_event_sender: mpsc::Sender<Uuid>,
|
||||
close_event_listener: JoinHandle<()>,
|
||||
|
||||
shutdown_notifier: Arc<tokio::sync::Notify>,
|
||||
}
|
||||
|
||||
impl Peer {
|
||||
pub fn new(
|
||||
peer_node_id: uuid::Uuid,
|
||||
packet_recv_chan: mpsc::Sender<Bytes>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
) -> Self {
|
||||
let conns: ConnMap = Arc::new(DashMap::new());
|
||||
let (close_event_sender, mut close_event_receiver) = mpsc::channel(10);
|
||||
let shutdown_notifier = Arc::new(tokio::sync::Notify::new());
|
||||
|
||||
let conns_copy = conns.clone();
|
||||
let shutdown_notifier_copy = shutdown_notifier.clone();
|
||||
let global_ctx_copy = global_ctx.clone();
|
||||
let close_event_listener = tokio::spawn(
|
||||
async move {
|
||||
loop {
|
||||
select! {
|
||||
ret = close_event_receiver.recv() => {
|
||||
if ret.is_none() {
|
||||
break;
|
||||
}
|
||||
let ret = ret.unwrap();
|
||||
tracing::warn!(
|
||||
?peer_node_id,
|
||||
?ret,
|
||||
"notified that peer conn is closed",
|
||||
);
|
||||
|
||||
if let Some((_, conn)) = conns_copy.remove(&ret) {
|
||||
global_ctx_copy.issue_event(GlobalCtxEvent::PeerConnRemoved(
|
||||
conn.lock().await.get_conn_info(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
_ = shutdown_notifier_copy.notified() => {
|
||||
close_event_receiver.close();
|
||||
tracing::warn!(?peer_node_id, "peer close event listener notified");
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::info!("peer {} close event listener exit", peer_node_id);
|
||||
}
|
||||
.instrument(tracing::info_span!(
|
||||
"peer_close_event_listener",
|
||||
?peer_node_id,
|
||||
)),
|
||||
);
|
||||
|
||||
Peer {
|
||||
peer_node_id,
|
||||
conns: conns.clone(),
|
||||
packet_recv_chan,
|
||||
global_ctx,
|
||||
|
||||
close_event_sender,
|
||||
close_event_listener,
|
||||
|
||||
shutdown_notifier,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_peer_conn(&self, mut conn: PeerConn) {
|
||||
conn.set_close_event_sender(self.close_event_sender.clone());
|
||||
conn.start_recv_loop(self.packet_recv_chan.clone());
|
||||
self.global_ctx
|
||||
.issue_event(GlobalCtxEvent::PeerConnAdded(conn.get_conn_info()));
|
||||
self.conns
|
||||
.insert(conn.get_conn_id(), Arc::new(Mutex::new(conn)));
|
||||
}
|
||||
|
||||
pub async fn send_msg(&self, msg: Bytes) -> Result<(), Error> {
|
||||
let Some(conn) = self.conns.iter().next() else {
|
||||
return Err(Error::PeerNoConnectionError(self.peer_node_id));
|
||||
};
|
||||
|
||||
let conn_clone = conn.clone();
|
||||
drop(conn);
|
||||
conn_clone.lock().await.send_msg(msg).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn close_peer_conn(&self, conn_id: &Uuid) -> Result<(), Error> {
|
||||
let has_key = self.conns.contains_key(conn_id);
|
||||
if !has_key {
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
self.close_event_sender.send(conn_id.clone()).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn list_peer_conns(&self) -> Vec<PeerConnInfo> {
|
||||
let mut conns = vec![];
|
||||
for conn in self.conns.iter() {
|
||||
// do not lock here, otherwise it will cause dashmap deadlock
|
||||
conns.push(conn.clone());
|
||||
}
|
||||
|
||||
let mut ret = Vec::new();
|
||||
for conn in conns {
|
||||
ret.push(conn.lock().await.get_conn_info());
|
||||
}
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
// pritn on drop
|
||||
impl Drop for Peer {
|
||||
fn drop(&mut self) {
|
||||
self.shutdown_notifier.notify_one();
|
||||
tracing::info!("peer {} drop", self.peer_node_id);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::{sync::mpsc, time::timeout};
|
||||
|
||||
use crate::{
|
||||
common::{config_fs::ConfigFs, global_ctx::GlobalCtx, netns::NetNS},
|
||||
peers::peer_conn::PeerConn,
|
||||
tunnels::ring_tunnel::create_ring_tunnel_pair,
|
||||
};
|
||||
|
||||
use super::Peer;
|
||||
|
||||
#[tokio::test]
|
||||
async fn close_peer() {
|
||||
let (local_packet_send, _local_packet_recv) = mpsc::channel(10);
|
||||
let (remote_packet_send, _remote_packet_recv) = mpsc::channel(10);
|
||||
let global_ctx = Arc::new(GlobalCtx::new(
|
||||
"test",
|
||||
ConfigFs::new("/tmp/easytier-test"),
|
||||
NetNS::new(None),
|
||||
));
|
||||
let local_peer = Peer::new(uuid::Uuid::new_v4(), local_packet_send, global_ctx.clone());
|
||||
let remote_peer = Peer::new(uuid::Uuid::new_v4(), remote_packet_send, global_ctx.clone());
|
||||
|
||||
let (local_tunnel, remote_tunnel) = create_ring_tunnel_pair();
|
||||
let mut local_peer_conn =
|
||||
PeerConn::new(local_peer.peer_node_id, global_ctx.clone(), local_tunnel);
|
||||
let mut remote_peer_conn =
|
||||
PeerConn::new(remote_peer.peer_node_id, global_ctx.clone(), remote_tunnel);
|
||||
|
||||
assert!(!local_peer_conn.handshake_done());
|
||||
assert!(!remote_peer_conn.handshake_done());
|
||||
|
||||
let (a, b) = tokio::join!(
|
||||
local_peer_conn.do_handshake_as_client(),
|
||||
remote_peer_conn.do_handshake_as_server()
|
||||
);
|
||||
a.unwrap();
|
||||
b.unwrap();
|
||||
|
||||
let local_conn_id = local_peer_conn.get_conn_id();
|
||||
|
||||
local_peer.add_peer_conn(local_peer_conn).await;
|
||||
remote_peer.add_peer_conn(remote_peer_conn).await;
|
||||
|
||||
assert_eq!(local_peer.list_peer_conns().await.len(), 1);
|
||||
assert_eq!(remote_peer.list_peer_conns().await.len(), 1);
|
||||
|
||||
let close_handler =
|
||||
tokio::spawn(async move { local_peer.close_peer_conn(&local_conn_id).await });
|
||||
|
||||
// wait for remote peer conn close
|
||||
timeout(std::time::Duration::from_secs(5), async {
|
||||
while (&remote_peer).list_peer_conns().await.len() != 0 {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
println!("wait for close handler");
|
||||
close_handler.await.unwrap().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,484 @@
|
||||
use std::{pin::Pin, sync::Arc};
|
||||
|
||||
use easytier_rpc::{PeerConnInfo, PeerConnStats};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use pnet::datalink::NetworkInterface;
|
||||
|
||||
use tokio::{
|
||||
sync::{broadcast, mpsc},
|
||||
task::JoinSet,
|
||||
time::{timeout, Duration},
|
||||
};
|
||||
|
||||
use tokio_util::{
|
||||
bytes::{Bytes, BytesMut},
|
||||
sync::PollSender,
|
||||
};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::global_ctx::ArcGlobalCtx,
|
||||
define_tunnel_filter_chain,
|
||||
tunnels::{
|
||||
stats::{Throughput, WindowLatency},
|
||||
tunnel_filter::StatsRecorderTunnelFilter,
|
||||
DatagramSink, Tunnel, TunnelError,
|
||||
},
|
||||
};
|
||||
|
||||
use super::packet::{self, ArchivedCtrlPacketBody, ArchivedHandShake, Packet};
|
||||
|
||||
pub type PacketRecvChan = mpsc::Sender<Bytes>;
|
||||
|
||||
macro_rules! wait_response {
|
||||
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
|
||||
let rsp_vec = timeout(Duration::from_secs(1), $stream.next()).await;
|
||||
if rsp_vec.is_err() {
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"wait handshake response timeout".to_owned(),
|
||||
));
|
||||
}
|
||||
let rsp_vec = rsp_vec.unwrap().unwrap()?;
|
||||
|
||||
let $out_var;
|
||||
let rsp_bytes = Packet::decode(&rsp_vec);
|
||||
match &rsp_bytes.body {
|
||||
$pattern => $out_var = $value,
|
||||
_ => {
|
||||
log::error!(
|
||||
"unexpected packet: {:?}, pattern: {:?}",
|
||||
rsp_bytes,
|
||||
stringify!($pattern)
|
||||
);
|
||||
return Err(TunnelError::WaitRespError("unexpected packet".to_owned()));
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub struct PeerInfo {
|
||||
magic: u32,
|
||||
pub my_peer_id: uuid::Uuid,
|
||||
version: u32,
|
||||
pub features: Vec<String>,
|
||||
pub interfaces: Vec<NetworkInterface>,
|
||||
}
|
||||
|
||||
impl<'a> From<&ArchivedHandShake> for PeerInfo {
|
||||
fn from(hs: &ArchivedHandShake) -> Self {
|
||||
PeerInfo {
|
||||
magic: hs.magic.into(),
|
||||
my_peer_id: hs.my_peer_id.to_uuid(),
|
||||
version: hs.version.into(),
|
||||
features: hs.features.iter().map(|x| x.to_string()).collect(),
|
||||
interfaces: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter);
|
||||
|
||||
pub struct PeerConn {
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
my_node_id: uuid::Uuid,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
|
||||
tasks: JoinSet<Result<(), TunnelError>>,
|
||||
|
||||
info: Option<PeerInfo>,
|
||||
|
||||
close_event_sender: Option<mpsc::Sender<uuid::Uuid>>,
|
||||
|
||||
ctrl_resp_sender: broadcast::Sender<Bytes>,
|
||||
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
throughput: Arc<Throughput>,
|
||||
}
|
||||
|
||||
enum PeerConnPacketType {
|
||||
Data(Bytes),
|
||||
CtrlReq(Bytes),
|
||||
CtrlResp(Bytes),
|
||||
}
|
||||
|
||||
static CTRL_REQ_PACKET_PREFIX: &[u8] = &[0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0];
|
||||
static CTRL_RESP_PACKET_PREFIX: &[u8] = &[0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf1];
|
||||
|
||||
impl PeerConn {
|
||||
pub fn new(node_id: uuid::Uuid, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
|
||||
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
|
||||
let peer_conn_tunnel = PeerConnTunnel::new();
|
||||
let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel);
|
||||
|
||||
PeerConn {
|
||||
conn_id: uuid::Uuid::new_v4(),
|
||||
|
||||
my_node_id: node_id,
|
||||
global_ctx,
|
||||
|
||||
sink: tunnel.pin_sink(),
|
||||
tunnel: Box::new(tunnel),
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
|
||||
info: None,
|
||||
close_event_sender: None,
|
||||
|
||||
ctrl_resp_sender: ctrl_sender,
|
||||
|
||||
latency_stats: Arc::new(WindowLatency::new(15)),
|
||||
throughput: peer_conn_tunnel.stats.get_throughput().clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_conn_id(&self) -> uuid::Uuid {
|
||||
self.conn_id
|
||||
}
|
||||
|
||||
pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
|
||||
wait_response!(stream, hs_req, packet::ArchivedPacketBody::Ctrl(ArchivedCtrlPacketBody::HandShake(x)) => x);
|
||||
self.info = Some(PeerInfo::from(hs_req));
|
||||
log::info!("handshake request: {:?}", hs_req);
|
||||
|
||||
let hs_req = self
|
||||
.global_ctx
|
||||
.net_ns
|
||||
.run(|| packet::Packet::new_handshake(self.my_node_id));
|
||||
sink.send(hs_req.into()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
|
||||
let hs_req = self
|
||||
.global_ctx
|
||||
.net_ns
|
||||
.run(|| packet::Packet::new_handshake(self.my_node_id));
|
||||
sink.send(hs_req.into()).await?;
|
||||
|
||||
wait_response!(stream, hs_rsp, packet::ArchivedPacketBody::Ctrl(ArchivedCtrlPacketBody::HandShake(x)) => x);
|
||||
self.info = Some(PeerInfo::from(hs_rsp));
|
||||
log::info!("handshake response: {:?}", hs_rsp);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn handshake_done(&self) -> bool {
|
||||
self.info.is_some()
|
||||
}
|
||||
|
||||
async fn do_pingpong_once(
|
||||
my_node_id: uuid::Uuid,
|
||||
peer_id: uuid::Uuid,
|
||||
sink: &mut Pin<Box<dyn DatagramSink>>,
|
||||
receiver: &mut broadcast::Receiver<Bytes>,
|
||||
) -> Result<u128, TunnelError> {
|
||||
// should add seq here. so latency can be calculated more accurately
|
||||
let req = Self::build_ctrl_msg(
|
||||
packet::Packet::new_ping_packet(my_node_id, peer_id).into(),
|
||||
true,
|
||||
);
|
||||
log::trace!("send ping packet: {:?}", req);
|
||||
sink.send(req).await?;
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
|
||||
// wait until we get a pong packet in ctrl_resp_receiver
|
||||
let resp = timeout(Duration::from_secs(4), async {
|
||||
loop {
|
||||
match receiver.recv().await {
|
||||
Ok(p) => {
|
||||
if let packet::ArchivedPacketBody::Ctrl(
|
||||
packet::ArchivedCtrlPacketBody::Pong,
|
||||
) = &Packet::decode(&p).body
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("recv pong resp error: {:?}", e);
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"recv pong resp error".to_owned(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.await;
|
||||
|
||||
if resp.is_err() {
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"wait ping response timeout".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
if resp.as_ref().unwrap().is_err() {
|
||||
return Err(resp.unwrap().err().unwrap());
|
||||
}
|
||||
|
||||
Ok(now.elapsed().as_micros())
|
||||
}
|
||||
|
||||
fn start_pingpong(&mut self) {
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
let my_node_id = self.my_node_id;
|
||||
let peer_id = self.get_peer_id();
|
||||
let receiver = self.ctrl_resp_sender.subscribe();
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
let latency_stats = self.latency_stats.clone();
|
||||
|
||||
self.tasks.spawn(async move {
|
||||
//sleep 1s
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
loop {
|
||||
let mut receiver = receiver.resubscribe();
|
||||
if let Ok(lat) =
|
||||
Self::do_pingpong_once(my_node_id, peer_id, &mut sink, &mut receiver).await
|
||||
{
|
||||
log::trace!(
|
||||
"pingpong latency: {}us, my_node_id: {}, peer_id: {}",
|
||||
lat,
|
||||
my_node_id,
|
||||
peer_id
|
||||
);
|
||||
latency_stats.record_latency(lat as u64);
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
log::warn!(
|
||||
"pingpong task exit, my_node_id: {}, peer_id: {}",
|
||||
my_node_id,
|
||||
peer_id,
|
||||
);
|
||||
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
log::warn!("close event sender error: {:?}", e);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
fn get_packet_type(mut bytes_item: Bytes) -> PeerConnPacketType {
|
||||
if bytes_item.starts_with(CTRL_REQ_PACKET_PREFIX) {
|
||||
PeerConnPacketType::CtrlReq(bytes_item.split_off(CTRL_REQ_PACKET_PREFIX.len()))
|
||||
} else if bytes_item.starts_with(CTRL_RESP_PACKET_PREFIX) {
|
||||
PeerConnPacketType::CtrlResp(bytes_item.split_off(CTRL_RESP_PACKET_PREFIX.len()))
|
||||
} else {
|
||||
PeerConnPacketType::Data(bytes_item)
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_ctrl_req_packet(
|
||||
bytes_item: Bytes,
|
||||
conn_info: &PeerConnInfo,
|
||||
) -> Result<Bytes, TunnelError> {
|
||||
let packet = Packet::decode(&bytes_item);
|
||||
match packet.body {
|
||||
packet::ArchivedPacketBody::Ctrl(packet::ArchivedCtrlPacketBody::Ping) => {
|
||||
log::trace!("recv ping packet: {:?}", packet);
|
||||
Ok(Self::build_ctrl_msg(
|
||||
packet::Packet::new_pong_packet(
|
||||
conn_info.my_node_id.parse().unwrap(),
|
||||
conn_info.peer_id.parse().unwrap(),
|
||||
)
|
||||
.into(),
|
||||
false,
|
||||
))
|
||||
}
|
||||
_ => {
|
||||
log::error!("unexpected packet: {:?}", packet);
|
||||
Err(TunnelError::CommonError("unexpected packet".to_owned()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
let mut sender = PollSender::new(packet_recv_chan.clone());
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
let ctrl_sender = self.ctrl_resp_sender.clone();
|
||||
let conn_info = self.get_conn_info();
|
||||
let conn_info_for_instrument = self.get_conn_info();
|
||||
|
||||
self.tasks.spawn(
|
||||
async move {
|
||||
tracing::info!("start recving peer conn packet");
|
||||
while let Some(ret) = stream.next().await {
|
||||
if ret.is_err() {
|
||||
tracing::error!(error = ?ret, "peer conn recv error");
|
||||
if let Err(close_ret) = sink.close().await {
|
||||
tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it");
|
||||
}
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
tracing::error!(error = ?e, "peer conn close event send error");
|
||||
}
|
||||
return Err(ret.err().unwrap());
|
||||
}
|
||||
|
||||
match Self::get_packet_type(ret.unwrap().into()) {
|
||||
PeerConnPacketType::Data(item) => sender.send(item).await.unwrap(),
|
||||
PeerConnPacketType::CtrlReq(item) => {
|
||||
let ret = Self::handle_ctrl_req_packet(item, &conn_info).unwrap();
|
||||
if let Err(e) = sink.send(ret).await {
|
||||
tracing::error!(?e, "peer conn send req error");
|
||||
}
|
||||
}
|
||||
PeerConnPacketType::CtrlResp(item) => {
|
||||
if let Err(e) = ctrl_sender.send(item) {
|
||||
tracing::error!(?e, "peer conn send ctrl resp error");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::info!("end recving peer conn packet");
|
||||
Ok(())
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
|
||||
),
|
||||
);
|
||||
|
||||
self.start_pingpong();
|
||||
}
|
||||
|
||||
pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), TunnelError> {
|
||||
self.sink.send(msg).await
|
||||
}
|
||||
|
||||
fn build_ctrl_msg(msg: Bytes, is_req: bool) -> Bytes {
|
||||
let prefix: &'static [u8] = if is_req {
|
||||
CTRL_REQ_PACKET_PREFIX
|
||||
} else {
|
||||
CTRL_RESP_PACKET_PREFIX
|
||||
};
|
||||
let mut new_msg = BytesMut::new();
|
||||
new_msg.reserve(prefix.len() + msg.len());
|
||||
new_msg.extend_from_slice(prefix);
|
||||
new_msg.extend_from_slice(&msg);
|
||||
new_msg.into()
|
||||
}
|
||||
|
||||
pub fn get_peer_id(&self) -> uuid::Uuid {
|
||||
self.info.as_ref().unwrap().my_peer_id
|
||||
}
|
||||
|
||||
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<uuid::Uuid>) {
|
||||
self.close_event_sender = Some(sender);
|
||||
}
|
||||
|
||||
pub fn get_stats(&self) -> PeerConnStats {
|
||||
PeerConnStats {
|
||||
latency_us: self.latency_stats.get_latency_us(),
|
||||
|
||||
tx_bytes: self.throughput.tx_bytes(),
|
||||
rx_bytes: self.throughput.rx_bytes(),
|
||||
|
||||
tx_packets: self.throughput.tx_packets(),
|
||||
rx_packets: self.throughput.rx_packets(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_conn_info(&self) -> PeerConnInfo {
|
||||
PeerConnInfo {
|
||||
conn_id: self.conn_id.to_string(),
|
||||
my_node_id: self.my_node_id.to_string(),
|
||||
peer_id: self.get_peer_id().to_string(),
|
||||
features: self.info.as_ref().unwrap().features.clone(),
|
||||
tunnel: self.tunnel.info(),
|
||||
stats: Some(self.get_stats()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PeerConn {
|
||||
fn drop(&mut self) {
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
tokio::spawn(async move {
|
||||
let ret = sink.close().await;
|
||||
tracing::info!(error = ?ret, "peer conn tunnel closed.");
|
||||
});
|
||||
log::info!("peer conn {:?} drop", self.conn_id);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use crate::common::config_fs::ConfigFs;
|
||||
use crate::common::global_ctx::GlobalCtx;
|
||||
use crate::common::netns::NetNS;
|
||||
use crate::tunnels::tunnel_filter::{PacketRecorderTunnelFilter, TunnelWithFilter};
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_conn_handshake() {
|
||||
use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
|
||||
let c_recorder = Arc::new(PacketRecorderTunnelFilter::new());
|
||||
let s_recorder = Arc::new(PacketRecorderTunnelFilter::new());
|
||||
|
||||
let c = TunnelWithFilter::new(c, c_recorder.clone());
|
||||
let s = TunnelWithFilter::new(s, s_recorder.clone());
|
||||
|
||||
let c_uuid = uuid::Uuid::new_v4();
|
||||
let s_uuid = uuid::Uuid::new_v4();
|
||||
|
||||
let mut c_peer = PeerConn::new(
|
||||
c_uuid,
|
||||
Arc::new(GlobalCtx::new(
|
||||
"c",
|
||||
ConfigFs::new_with_dir("c", "/tmp"),
|
||||
NetNS::new(None),
|
||||
)),
|
||||
Box::new(c),
|
||||
);
|
||||
|
||||
let mut s_peer = PeerConn::new(
|
||||
s_uuid,
|
||||
Arc::new(GlobalCtx::new(
|
||||
"c",
|
||||
ConfigFs::new_with_dir("c", "/tmp"),
|
||||
NetNS::new(None),
|
||||
)),
|
||||
Box::new(s),
|
||||
);
|
||||
|
||||
let (c_ret, s_ret) = tokio::join!(
|
||||
c_peer.do_handshake_as_client(),
|
||||
s_peer.do_handshake_as_server()
|
||||
);
|
||||
|
||||
c_ret.unwrap();
|
||||
s_ret.unwrap();
|
||||
|
||||
assert_eq!(c_recorder.sent.lock().unwrap().len(), 1);
|
||||
assert_eq!(c_recorder.received.lock().unwrap().len(), 1);
|
||||
|
||||
assert_eq!(s_recorder.sent.lock().unwrap().len(), 1);
|
||||
assert_eq!(s_recorder.received.lock().unwrap().len(), 1);
|
||||
|
||||
assert_eq!(c_peer.get_peer_id(), s_uuid);
|
||||
assert_eq!(s_peer.get_peer_id(), c_uuid);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,539 @@
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
net::Ipv4Addr,
|
||||
sync::{atomic::AtomicU8, Arc},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{StreamExt, TryFutureExt};
|
||||
|
||||
use tokio::{
|
||||
sync::{
|
||||
mpsc::{self, UnboundedReceiver, UnboundedSender},
|
||||
Mutex, RwLock,
|
||||
},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_vec},
|
||||
peers::{
|
||||
packet::{self},
|
||||
peer_conn::PeerConn,
|
||||
peer_rpc::PeerRpcManagerTransport,
|
||||
route_trait::RouteInterface,
|
||||
},
|
||||
tunnels::{SinkItem, Tunnel, TunnelConnector},
|
||||
};
|
||||
|
||||
use super::{
|
||||
peer_map::PeerMap,
|
||||
peer_rpc::PeerRpcManager,
|
||||
route_trait::{ArcRoute, Route},
|
||||
PeerId,
|
||||
};
|
||||
|
||||
struct RpcTransport {
|
||||
my_peer_id: uuid::Uuid,
|
||||
peers: Arc<PeerMap>,
|
||||
|
||||
packet_recv: Mutex<UnboundedReceiver<Bytes>>,
|
||||
peer_rpc_tspt_sender: UnboundedSender<Bytes>,
|
||||
|
||||
route: Arc<Mutex<Option<ArcRoute>>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerRpcManagerTransport for RpcTransport {
|
||||
fn my_peer_id(&self) -> Uuid {
|
||||
self.my_peer_id
|
||||
}
|
||||
|
||||
async fn send(&self, msg: Bytes, dst_peer_id: &uuid::Uuid) -> Result<(), Error> {
|
||||
let route = self.route.lock().await;
|
||||
if route.is_none() {
|
||||
log::error!("no route info when send rpc msg");
|
||||
return Err(Error::RouteError("No route info".to_string()));
|
||||
}
|
||||
|
||||
self.peers
|
||||
.send_msg(msg, dst_peer_id, route.as_ref().unwrap().clone())
|
||||
.map_err(|e| e.into())
|
||||
.await
|
||||
}
|
||||
|
||||
async fn recv(&self) -> Result<Bytes, Error> {
|
||||
if let Some(o) = self.packet_recv.lock().await.recv().await {
|
||||
Ok(o)
|
||||
} else {
|
||||
Err(Error::Unknown)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait PeerPacketFilter {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
data: &Bytes,
|
||||
) -> Option<()>;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait NicPacketFilter {
|
||||
async fn try_process_packet_from_nic(&self, data: BytesMut) -> BytesMut;
|
||||
}
|
||||
|
||||
type BoxPeerPacketFilter = Box<dyn PeerPacketFilter + Send + Sync>;
|
||||
type BoxNicPacketFilter = Box<dyn NicPacketFilter + Send + Sync>;
|
||||
|
||||
pub struct PeerManager {
|
||||
my_node_id: uuid::Uuid,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
nic_channel: mpsc::Sender<SinkItem>,
|
||||
|
||||
tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
|
||||
packet_recv: Arc<Mutex<Option<mpsc::Receiver<Bytes>>>>,
|
||||
|
||||
peers: Arc<PeerMap>,
|
||||
route: Arc<Mutex<Option<ArcRoute>>>,
|
||||
cur_route_id: AtomicU8,
|
||||
|
||||
peer_rpc_mgr: Arc<PeerRpcManager>,
|
||||
peer_rpc_tspt: Arc<RpcTransport>,
|
||||
|
||||
peer_packet_process_pipeline: Arc<RwLock<Vec<BoxPeerPacketFilter>>>,
|
||||
nic_packet_process_pipeline: Arc<RwLock<Vec<BoxNicPacketFilter>>>,
|
||||
}
|
||||
|
||||
impl Debug for PeerManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerManager")
|
||||
.field("my_node_id", &self.my_node_id)
|
||||
.field("instance_name", &self.global_ctx.inst_name)
|
||||
.field("net_ns", &self.global_ctx.net_ns.name())
|
||||
.field("cur_route_id", &self.cur_route_id)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerManager {
|
||||
pub fn new(global_ctx: ArcGlobalCtx, nic_channel: mpsc::Sender<SinkItem>) -> Self {
|
||||
let (packet_send, packet_recv) = mpsc::channel(100);
|
||||
let peers = Arc::new(PeerMap::new(packet_send.clone()));
|
||||
|
||||
// TODO: remove these because we have impl pipeline processor.
|
||||
let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
|
||||
let rpc_tspt = Arc::new(RpcTransport {
|
||||
my_peer_id: global_ctx.get_id(),
|
||||
peers: peers.clone(),
|
||||
packet_recv: Mutex::new(peer_rpc_tspt_recv),
|
||||
peer_rpc_tspt_sender,
|
||||
route: Arc::new(Mutex::new(None)),
|
||||
});
|
||||
|
||||
PeerManager {
|
||||
my_node_id: global_ctx.get_id(),
|
||||
global_ctx,
|
||||
nic_channel,
|
||||
|
||||
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||
|
||||
packet_recv: Arc::new(Mutex::new(Some(packet_recv))),
|
||||
|
||||
peers: peers.clone(),
|
||||
route: Arc::new(Mutex::new(None)),
|
||||
cur_route_id: AtomicU8::new(0),
|
||||
|
||||
peer_rpc_mgr: Arc::new(PeerRpcManager::new(rpc_tspt.clone())),
|
||||
peer_rpc_tspt: rpc_tspt,
|
||||
|
||||
peer_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
|
||||
nic_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_client_tunnel(&self, tunnel: Box<dyn Tunnel>) -> Result<(Uuid, Uuid), Error> {
|
||||
let mut peer = PeerConn::new(self.my_node_id, self.global_ctx.clone(), tunnel);
|
||||
peer.do_handshake_as_client().await?;
|
||||
let conn_id = peer.get_conn_id();
|
||||
let peer_id = peer.get_peer_id();
|
||||
self.peers
|
||||
.add_new_peer_conn(peer, self.global_ctx.clone())
|
||||
.await;
|
||||
Ok((peer_id, conn_id))
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn try_connect<C>(&self, mut connector: C) -> Result<(Uuid, Uuid), Error>
|
||||
where
|
||||
C: TunnelConnector + Debug,
|
||||
{
|
||||
let ns = self.global_ctx.net_ns.clone();
|
||||
let t = ns
|
||||
.run_async(|| async move { connector.connect().await })
|
||||
.await?;
|
||||
self.add_client_tunnel(t).await
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn add_tunnel_as_server(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
|
||||
tracing::info!("add tunnel as server start");
|
||||
let mut peer = PeerConn::new(self.my_node_id, self.global_ctx.clone(), tunnel);
|
||||
peer.do_handshake_as_server().await?;
|
||||
self.peers
|
||||
.add_new_peer_conn(peer, self.global_ctx.clone())
|
||||
.await;
|
||||
tracing::info!("add tunnel as server done");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_peer_recv(&self) {
|
||||
let mut recv = ReceiverStream::new(self.packet_recv.lock().await.take().unwrap());
|
||||
let my_node_id = self.my_node_id;
|
||||
let peers = self.peers.clone();
|
||||
let arc_route = self.route.clone();
|
||||
let pipe_line = self.peer_packet_process_pipeline.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
log::trace!("start_peer_recv");
|
||||
while let Some(ret) = recv.next().await {
|
||||
log::trace!("peer recv a packet...: {:?}", ret);
|
||||
let packet = packet::Packet::decode(&ret);
|
||||
let from_peer_uuid = packet.from_peer.to_uuid();
|
||||
let to_peer_uuid = packet.to_peer.as_ref().unwrap().to_uuid();
|
||||
if to_peer_uuid != my_node_id {
|
||||
let locked_arc_route = arc_route.lock().await;
|
||||
if locked_arc_route.is_none() {
|
||||
log::error!("no route info after recv a packet");
|
||||
continue;
|
||||
}
|
||||
|
||||
let route = locked_arc_route.as_ref().unwrap().clone();
|
||||
drop(locked_arc_route);
|
||||
log::trace!(
|
||||
"need forward: to_peer_uuid: {:?}, my_uuid: {:?}",
|
||||
to_peer_uuid,
|
||||
my_node_id
|
||||
);
|
||||
let ret = peers
|
||||
.send_msg(ret.clone(), &to_peer_uuid, route.clone())
|
||||
.await;
|
||||
if ret.is_err() {
|
||||
log::error!(
|
||||
"forward packet error: {:?}, dst: {:?}, from: {:?}",
|
||||
ret,
|
||||
to_peer_uuid,
|
||||
from_peer_uuid
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let mut processed = false;
|
||||
for pipeline in pipe_line.read().await.iter().rev() {
|
||||
if let Some(_) = pipeline.try_process_packet_from_peer(&packet, &ret).await
|
||||
{
|
||||
processed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !processed {
|
||||
tracing::error!("unexpected packet: {:?}", ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
panic!("done_peer_recv");
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn add_packet_process_pipeline(&self, pipeline: BoxPeerPacketFilter) {
|
||||
// newest pipeline will be executed first
|
||||
self.peer_packet_process_pipeline
|
||||
.write()
|
||||
.await
|
||||
.push(pipeline);
|
||||
}
|
||||
|
||||
pub async fn add_nic_packet_process_pipeline(&self, pipeline: BoxNicPacketFilter) {
|
||||
// newest pipeline will be executed first
|
||||
self.nic_packet_process_pipeline
|
||||
.write()
|
||||
.await
|
||||
.push(pipeline);
|
||||
}
|
||||
|
||||
async fn init_packet_process_pipeline(&self) {
|
||||
use packet::ArchivedPacketBody;
|
||||
|
||||
// for tun/tap ip/eth packet.
|
||||
struct NicPacketProcessor {
|
||||
nic_channel: mpsc::Sender<SinkItem>,
|
||||
}
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for NicPacketProcessor {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
data: &Bytes,
|
||||
) -> Option<()> {
|
||||
if let packet::ArchivedPacketBody::Data(x) = &packet.body {
|
||||
// TODO: use a function to get the body ref directly for zero copy
|
||||
self.nic_channel
|
||||
.send(extract_bytes_from_archived_vec(&data, &x.data))
|
||||
.await
|
||||
.unwrap();
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
self.add_packet_process_pipeline(Box::new(NicPacketProcessor {
|
||||
nic_channel: self.nic_channel.clone(),
|
||||
}))
|
||||
.await;
|
||||
|
||||
// for peer manager router packet
|
||||
struct RoutePacketProcessor {
|
||||
route: Arc<Mutex<Option<ArcRoute>>>,
|
||||
}
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for RoutePacketProcessor {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
data: &Bytes,
|
||||
) -> Option<()> {
|
||||
if let ArchivedPacketBody::Ctrl(packet::ArchivedCtrlPacketBody::RoutePacket(
|
||||
route_packet,
|
||||
)) = &packet.body
|
||||
{
|
||||
let r = self.route.lock().await;
|
||||
match r.as_ref() {
|
||||
Some(x) => {
|
||||
let x = x.clone();
|
||||
drop(r);
|
||||
x.handle_route_packet(
|
||||
packet.from_peer.to_uuid(),
|
||||
extract_bytes_from_archived_vec(&data, &route_packet.body),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
None => {
|
||||
log::error!("no route info when handle route packet");
|
||||
}
|
||||
}
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
self.add_packet_process_pipeline(Box::new(RoutePacketProcessor {
|
||||
route: self.route.clone(),
|
||||
}))
|
||||
.await;
|
||||
|
||||
// for peer rpc packet
|
||||
struct PeerRpcPacketProcessor {
|
||||
peer_rpc_tspt_sender: UnboundedSender<Bytes>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for PeerRpcPacketProcessor {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
data: &Bytes,
|
||||
) -> Option<()> {
|
||||
if let ArchivedPacketBody::Ctrl(packet::ArchivedCtrlPacketBody::TaRpc(..)) =
|
||||
&packet.body
|
||||
{
|
||||
self.peer_rpc_tspt_sender.send(data.clone()).unwrap();
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
self.add_packet_process_pipeline(Box::new(PeerRpcPacketProcessor {
|
||||
peer_rpc_tspt_sender: self.peer_rpc_tspt.peer_rpc_tspt_sender.clone(),
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn set_route<T>(&self, route: T)
|
||||
where
|
||||
T: Route + Send + Sync + 'static,
|
||||
{
|
||||
struct Interface {
|
||||
my_node_id: uuid::Uuid,
|
||||
peers: Arc<PeerMap>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RouteInterface for Interface {
|
||||
async fn list_peers(&self) -> Vec<PeerId> {
|
||||
self.peers.list_peers_with_conn().await
|
||||
}
|
||||
async fn send_route_packet(
|
||||
&self,
|
||||
msg: Bytes,
|
||||
route_id: u8,
|
||||
dst_peer_id: &PeerId,
|
||||
) -> Result<(), Error> {
|
||||
self.peers
|
||||
.send_msg_directly(
|
||||
packet::Packet::new_route_packet(
|
||||
self.my_node_id,
|
||||
*dst_peer_id,
|
||||
route_id,
|
||||
&msg,
|
||||
)
|
||||
.into(),
|
||||
dst_peer_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
let my_node_id = self.my_node_id;
|
||||
let route_id = route
|
||||
.open(Box::new(Interface {
|
||||
my_node_id,
|
||||
peers: self.peers.clone(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
self.cur_route_id
|
||||
.store(route_id, std::sync::atomic::Ordering::Relaxed);
|
||||
let arc_route: ArcRoute = Arc::new(Box::new(route));
|
||||
|
||||
self.route.lock().await.replace(arc_route.clone());
|
||||
|
||||
self.peer_rpc_tspt
|
||||
.route
|
||||
.lock()
|
||||
.await
|
||||
.replace(arc_route.clone());
|
||||
}
|
||||
|
||||
pub async fn list_routes(&self) -> Vec<easytier_rpc::Route> {
|
||||
let route_info = self.route.lock().await;
|
||||
if route_info.is_none() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let route = route_info.as_ref().unwrap().clone();
|
||||
drop(route_info);
|
||||
route.list_routes().await
|
||||
}
|
||||
|
||||
async fn run_nic_packet_process_pipeline(&self, mut data: BytesMut) -> BytesMut {
|
||||
for pipeline in self.nic_packet_process_pipeline.read().await.iter().rev() {
|
||||
data = pipeline.try_process_packet_from_nic(data).await;
|
||||
}
|
||||
data
|
||||
}
|
||||
|
||||
pub async fn send_msg(&self, msg: Bytes, dst_peer_id: &PeerId) -> Result<(), Error> {
|
||||
self.peer_rpc_tspt.send(msg, dst_peer_id).await
|
||||
}
|
||||
|
||||
pub async fn send_msg_ipv4(&self, msg: BytesMut, ipv4_addr: Ipv4Addr) -> Result<(), Error> {
|
||||
let route_info = self.route.lock().await;
|
||||
if route_info.is_none() {
|
||||
log::error!("no route info");
|
||||
return Err(Error::RouteError("No route info".to_string()));
|
||||
}
|
||||
|
||||
let route = route_info.as_ref().unwrap().clone();
|
||||
drop(route_info);
|
||||
|
||||
log::trace!(
|
||||
"do send_msg in peer manager, msg: {:?}, ipv4_addr: {}",
|
||||
msg,
|
||||
ipv4_addr
|
||||
);
|
||||
|
||||
match route.get_peer_id_by_ipv4(&ipv4_addr).await {
|
||||
Some(peer_id) => {
|
||||
let msg = self.run_nic_packet_process_pipeline(msg).await;
|
||||
self.peers
|
||||
.send_msg(
|
||||
packet::Packet::new_data_packet(self.my_node_id, peer_id, &msg).into(),
|
||||
&peer_id,
|
||||
route.clone(),
|
||||
)
|
||||
.await?;
|
||||
log::trace!(
|
||||
"do send_msg in peer manager done, dst_peer_id: {:?}",
|
||||
peer_id
|
||||
);
|
||||
}
|
||||
None => {
|
||||
log::trace!("no peer id for ipv4: {}", ipv4_addr);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_clean_peer_without_conn_routine(&self) {
|
||||
let peer_map = self.peers.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
loop {
|
||||
let mut to_remove = vec![];
|
||||
|
||||
for peer_id in peer_map.list_peers().await {
|
||||
let conns = peer_map.list_peer_conns(&peer_id).await;
|
||||
if conns.is_none() || conns.as_ref().unwrap().is_empty() {
|
||||
to_remove.push(peer_id);
|
||||
}
|
||||
}
|
||||
|
||||
for peer_id in to_remove {
|
||||
peer_map.close_peer(&peer_id).await.unwrap();
|
||||
}
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<(), Error> {
|
||||
self.init_packet_process_pipeline().await;
|
||||
self.start_peer_recv().await;
|
||||
self.peer_rpc_mgr.run();
|
||||
self.run_clean_peer_without_conn_routine().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_peer_map(&self) -> Arc<PeerMap> {
|
||||
self.peers.clone()
|
||||
}
|
||||
|
||||
pub fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
|
||||
self.peer_rpc_mgr.clone()
|
||||
}
|
||||
|
||||
pub fn my_node_id(&self) -> uuid::Uuid {
|
||||
self.my_node_id
|
||||
}
|
||||
|
||||
pub fn get_global_ctx(&self) -> ArcGlobalCtx {
|
||||
self.global_ctx.clone()
|
||||
}
|
||||
|
||||
pub fn get_nic_channel(&self) -> mpsc::Sender<SinkItem> {
|
||||
self.nic_channel.clone()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use easytier_rpc::PeerConnInfo;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx},
|
||||
tunnels::TunnelError,
|
||||
};
|
||||
|
||||
use super::{peer::Peer, peer_conn::PeerConn, route_trait::ArcRoute, PeerId};
|
||||
|
||||
pub struct PeerMap {
|
||||
peer_map: DashMap<PeerId, Arc<Peer>>,
|
||||
packet_send: mpsc::Sender<Bytes>,
|
||||
}
|
||||
|
||||
impl PeerMap {
|
||||
pub fn new(packet_send: mpsc::Sender<Bytes>) -> Self {
|
||||
PeerMap {
|
||||
peer_map: DashMap::new(),
|
||||
packet_send,
|
||||
}
|
||||
}
|
||||
|
||||
async fn add_new_peer(&self, peer: Peer) {
|
||||
self.peer_map.insert(peer.peer_node_id, Arc::new(peer));
|
||||
}
|
||||
|
||||
pub async fn add_new_peer_conn(&self, peer_conn: PeerConn, global_ctx: ArcGlobalCtx) {
|
||||
let peer_id = peer_conn.get_peer_id();
|
||||
let no_entry = self.peer_map.get(&peer_id).is_none();
|
||||
if no_entry {
|
||||
let new_peer = Peer::new(peer_id, self.packet_send.clone(), global_ctx);
|
||||
new_peer.add_peer_conn(peer_conn).await;
|
||||
self.add_new_peer(new_peer).await;
|
||||
} else {
|
||||
let peer = self.peer_map.get(&peer_id).unwrap().clone();
|
||||
peer.add_peer_conn(peer_conn).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn get_peer_by_id(&self, peer_id: &PeerId) -> Option<Arc<Peer>> {
|
||||
self.peer_map.get(peer_id).map(|v| v.clone())
|
||||
}
|
||||
|
||||
pub async fn send_msg_directly(
|
||||
&self,
|
||||
msg: Bytes,
|
||||
dst_peer_id: &uuid::Uuid,
|
||||
) -> Result<(), Error> {
|
||||
match self.get_peer_by_id(dst_peer_id) {
|
||||
Some(peer) => {
|
||||
peer.send_msg(msg).await?;
|
||||
}
|
||||
None => {
|
||||
log::error!("no peer for dst_peer_id: {}", dst_peer_id);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_msg(
|
||||
&self,
|
||||
msg: Bytes,
|
||||
dst_peer_id: &uuid::Uuid,
|
||||
route: ArcRoute,
|
||||
) -> Result<(), Error> {
|
||||
// get route info
|
||||
let gateway_peer_id = route.get_next_hop(dst_peer_id).await;
|
||||
|
||||
if gateway_peer_id.is_none() {
|
||||
log::error!("no gateway for dst_peer_id: {}", dst_peer_id);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let gateway_peer_id = gateway_peer_id.unwrap();
|
||||
self.send_msg_directly(msg, &gateway_peer_id).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn list_peers(&self) -> Vec<PeerId> {
|
||||
let mut ret = Vec::new();
|
||||
for item in self.peer_map.iter() {
|
||||
let peer_id = item.key();
|
||||
ret.push(*peer_id);
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
pub async fn list_peers_with_conn(&self) -> Vec<PeerId> {
|
||||
let mut ret = Vec::new();
|
||||
let peers = self.list_peers().await;
|
||||
for peer_id in peers.iter() {
|
||||
let Some(peer) = self.get_peer_by_id(peer_id) else {
|
||||
continue;
|
||||
};
|
||||
if peer.list_peer_conns().await.len() > 0 {
|
||||
ret.push(*peer_id);
|
||||
}
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
pub async fn list_peer_conns(&self, peer_id: &PeerId) -> Option<Vec<PeerConnInfo>> {
|
||||
if let Some(p) = self.get_peer_by_id(peer_id) {
|
||||
Some(p.list_peer_conns().await)
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close_peer_conn(
|
||||
&self,
|
||||
peer_id: &PeerId,
|
||||
conn_id: &uuid::Uuid,
|
||||
) -> Result<(), Error> {
|
||||
if let Some(p) = self.get_peer_by_id(peer_id) {
|
||||
p.close_peer_conn(conn_id).await
|
||||
} else {
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close_peer(&self, peer_id: &PeerId) -> Result<(), TunnelError> {
|
||||
let remove_ret = self.peer_map.remove(peer_id);
|
||||
tracing::info!(
|
||||
?peer_id,
|
||||
has_old_value = ?remove_ret.is_some(),
|
||||
peer_ref_counter = ?remove_ret.map(|v| Arc::strong_count(&v.1)),
|
||||
"peer is closed"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,509 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use rkyv::Deserialize;
|
||||
use tarpc::{server::Channel, transport::channel::UnboundedChannel};
|
||||
use tokio::{
|
||||
sync::mpsc::{self, UnboundedSender},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{common::error::Error, peers::packet::Packet};
|
||||
|
||||
use super::packet::{CtrlPacketBody, PacketBody};
|
||||
|
||||
type PeerRpcServiceId = u32;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait PeerRpcManagerTransport: Send + Sync + 'static {
|
||||
fn my_peer_id(&self) -> uuid::Uuid;
|
||||
async fn send(&self, msg: Bytes, dst_peer_id: &uuid::Uuid) -> Result<(), Error>;
|
||||
async fn recv(&self) -> Result<Bytes, Error>;
|
||||
}
|
||||
|
||||
type PacketSender = UnboundedSender<Packet>;
|
||||
|
||||
struct PeerRpcEndPoint {
|
||||
peer_id: uuid::Uuid,
|
||||
packet_sender: PacketSender,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
type PeerRpcEndPointCreator = Box<dyn Fn(uuid::Uuid) -> PeerRpcEndPoint + Send + Sync + 'static>;
|
||||
#[derive(Hash, Eq, PartialEq, Clone)]
|
||||
struct PeerRpcClientCtxKey(uuid::Uuid, PeerRpcServiceId);
|
||||
|
||||
// handle rpc request from one peer
|
||||
pub struct PeerRpcManager {
|
||||
service_map: Arc<DashMap<PeerRpcServiceId, PacketSender>>,
|
||||
tasks: JoinSet<()>,
|
||||
tspt: Arc<Box<dyn PeerRpcManagerTransport>>,
|
||||
|
||||
service_registry: Arc<DashMap<PeerRpcServiceId, PeerRpcEndPointCreator>>,
|
||||
peer_rpc_endpoints: Arc<DashMap<(uuid::Uuid, PeerRpcServiceId), PeerRpcEndPoint>>,
|
||||
|
||||
client_resp_receivers: Arc<DashMap<PeerRpcClientCtxKey, PacketSender>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PeerRpcManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerRpcManager")
|
||||
.field("node_id", &self.tspt.my_peer_id())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TaRpcPacketInfo {
|
||||
from_peer: uuid::Uuid,
|
||||
to_peer: uuid::Uuid,
|
||||
service_id: PeerRpcServiceId,
|
||||
is_req: bool,
|
||||
content: Vec<u8>,
|
||||
}
|
||||
|
||||
impl PeerRpcManager {
|
||||
pub fn new(tspt: impl PeerRpcManagerTransport) -> Self {
|
||||
Self {
|
||||
service_map: Arc::new(DashMap::new()),
|
||||
tasks: JoinSet::new(),
|
||||
tspt: Arc::new(Box::new(tspt)),
|
||||
|
||||
service_registry: Arc::new(DashMap::new()),
|
||||
peer_rpc_endpoints: Arc::new(DashMap::new()),
|
||||
|
||||
client_resp_receivers: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_service<S, Req>(self: &Self, service_id: PeerRpcServiceId, s: S) -> ()
|
||||
where
|
||||
S: tarpc::server::Serve<Req> + Clone + Send + Sync + 'static,
|
||||
Req: Send + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
|
||||
S::Resp:
|
||||
Send + std::fmt::Debug + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
|
||||
S::Fut: Send + 'static,
|
||||
{
|
||||
let tspt = self.tspt.clone();
|
||||
let creator = Box::new(move |peer_id: uuid::Uuid| {
|
||||
let mut tasks = JoinSet::new();
|
||||
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
|
||||
let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
let server = tarpc::server::BaseChannel::with_defaults(server_transport);
|
||||
|
||||
let my_peer_id_clone = tspt.my_peer_id();
|
||||
let peer_id_clone = peer_id.clone();
|
||||
|
||||
let o = server.execute(s.clone());
|
||||
tasks.spawn(o);
|
||||
|
||||
let tspt = tspt.clone();
|
||||
tasks.spawn(async move {
|
||||
let mut cur_req_uuid = None;
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(resp) = client_transport.next() => {
|
||||
tracing::trace!(resp = ?resp, "recv packet from client");
|
||||
if resp.is_err() {
|
||||
tracing::warn!(err = ?resp.err(),
|
||||
"[PEER RPC MGR] client_transport in server side got channel error, ignore it.");
|
||||
continue;
|
||||
}
|
||||
let resp = resp.unwrap();
|
||||
|
||||
if cur_req_uuid.is_none() {
|
||||
tracing::error!("[PEER RPC MGR] cur_req_uuid is none, ignore this resp");
|
||||
continue;
|
||||
}
|
||||
|
||||
let serialized_resp = bincode::serialize(&resp);
|
||||
if serialized_resp.is_err() {
|
||||
tracing::error!(error = ?serialized_resp.err(), "serialize resp failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
let msg = Packet::new_tarpc_packet(
|
||||
tspt.my_peer_id(),
|
||||
cur_req_uuid.take().unwrap(),
|
||||
service_id,
|
||||
false,
|
||||
serialized_resp.unwrap(),
|
||||
);
|
||||
|
||||
if let Err(e) = tspt.send(msg.into(), &peer_id).await {
|
||||
tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed");
|
||||
}
|
||||
}
|
||||
Some(packet) = packet_receiver.recv() => {
|
||||
let info = Self::parse_rpc_packet(&packet);
|
||||
if let Err(e) = info {
|
||||
tracing::error!(error = ?e, packet = ?packet, "parse rpc packet failed");
|
||||
continue;
|
||||
}
|
||||
let info = info.unwrap();
|
||||
|
||||
assert_eq!(info.service_id, service_id);
|
||||
cur_req_uuid = Some(packet.from_peer.clone().into());
|
||||
|
||||
tracing::trace!("recv packet from peer, packet: {:?}", packet);
|
||||
|
||||
let decoded_ret = bincode::deserialize(&info.content.as_slice());
|
||||
if let Err(e) = decoded_ret {
|
||||
tracing::error!(error = ?e, "decode rpc packet failed");
|
||||
continue;
|
||||
}
|
||||
let decoded: tarpc::ClientMessage<Req> = decoded_ret.unwrap();
|
||||
|
||||
if let Err(e) = client_transport.send(decoded).await {
|
||||
tracing::error!(error = ?e, "send to req to client transport failed");
|
||||
}
|
||||
}
|
||||
else => {
|
||||
tracing::warn!("[PEER RPC MGR] service runner destroy, peer_id: {}, service_id: {}", peer_id, service_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}.instrument(tracing::info_span!("service_runner", my_id = ?my_peer_id_clone, peer_id = ?peer_id_clone, service_id = ?service_id)));
|
||||
|
||||
tracing::info!(
|
||||
"[PEER RPC MGR] create new service endpoint for peer {}, service {}",
|
||||
peer_id,
|
||||
service_id
|
||||
);
|
||||
|
||||
return PeerRpcEndPoint {
|
||||
peer_id,
|
||||
packet_sender,
|
||||
tasks,
|
||||
};
|
||||
// let resp = client_transport.next().await;
|
||||
});
|
||||
|
||||
if let Some(_) = self.service_registry.insert(service_id, creator) {
|
||||
panic!(
|
||||
"[PEER RPC MGR] service {} is already registered",
|
||||
service_id
|
||||
);
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"[PEER RPC MGR] register service {} succeed, my_node_id {}",
|
||||
service_id,
|
||||
self.tspt.my_peer_id()
|
||||
)
|
||||
}
|
||||
|
||||
fn parse_rpc_packet(packet: &Packet) -> Result<TaRpcPacketInfo, Error> {
|
||||
match &packet.body {
|
||||
PacketBody::Ctrl(CtrlPacketBody::TaRpc(id, is_req, body)) => Ok(TaRpcPacketInfo {
|
||||
from_peer: packet.from_peer.clone().into(),
|
||||
to_peer: packet.to_peer.clone().unwrap().into(),
|
||||
service_id: *id,
|
||||
is_req: *is_req,
|
||||
content: body.clone(),
|
||||
}),
|
||||
_ => Err(Error::ShellCommandError("invalid packet".to_owned())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(&self) {
|
||||
let tspt = self.tspt.clone();
|
||||
let service_registry = self.service_registry.clone();
|
||||
let peer_rpc_endpoints = self.peer_rpc_endpoints.clone();
|
||||
let client_resp_receivers = self.client_resp_receivers.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let o = tspt.recv().await.unwrap();
|
||||
let packet = Packet::decode(&o);
|
||||
let packet: Packet = packet.deserialize(&mut rkyv::Infallible).unwrap();
|
||||
let info = Self::parse_rpc_packet(&packet).unwrap();
|
||||
|
||||
if info.is_req {
|
||||
if !service_registry.contains_key(&info.service_id) {
|
||||
log::warn!(
|
||||
"service {} not found, my_node_id: {}",
|
||||
info.service_id,
|
||||
tspt.my_peer_id()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let endpoint = peer_rpc_endpoints
|
||||
.entry((info.to_peer, info.service_id))
|
||||
.or_insert_with(|| {
|
||||
service_registry.get(&info.service_id).unwrap()(info.from_peer)
|
||||
});
|
||||
|
||||
endpoint.packet_sender.send(packet).unwrap();
|
||||
} else {
|
||||
if let Some(a) = client_resp_receivers
|
||||
.get(&PeerRpcClientCtxKey(info.from_peer, info.service_id))
|
||||
{
|
||||
log::trace!("recv resp: {:?}", packet);
|
||||
if let Err(e) = a.send(packet) {
|
||||
tracing::error!(error = ?e, "send resp to client failed");
|
||||
}
|
||||
} else {
|
||||
log::warn!("client resp receiver not found, info: {:?}", info);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(f))]
|
||||
pub async fn do_client_rpc_scoped<CM, Req, RpcRet, Fut>(
|
||||
&self,
|
||||
service_id: PeerRpcServiceId,
|
||||
dst_peer_id: uuid::Uuid,
|
||||
f: impl FnOnce(UnboundedChannel<CM, Req>) -> Fut,
|
||||
) -> RpcRet
|
||||
where
|
||||
CM: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
|
||||
Req: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
|
||||
Fut: std::future::Future<Output = RpcRet>,
|
||||
{
|
||||
let mut tasks = JoinSet::new();
|
||||
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
|
||||
let (client_transport, server_transport) =
|
||||
tarpc::transport::channel::unbounded::<CM, Req>();
|
||||
|
||||
let (mut server_s, mut server_r) = server_transport.split();
|
||||
|
||||
let tspt = self.tspt.clone();
|
||||
tasks.spawn(async move {
|
||||
while let Some(a) = server_r.next().await {
|
||||
if a.is_err() {
|
||||
tracing::error!(error = ?a.err(), "channel error");
|
||||
continue;
|
||||
}
|
||||
|
||||
let a = bincode::serialize(&a.unwrap());
|
||||
if a.is_err() {
|
||||
tracing::error!(error = ?a.err(), "bincode serialize failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
let a = Packet::new_tarpc_packet(
|
||||
tspt.my_peer_id(),
|
||||
dst_peer_id,
|
||||
service_id,
|
||||
true,
|
||||
a.unwrap(),
|
||||
);
|
||||
|
||||
if let Err(e) = tspt.send(a.into(), &dst_peer_id).await {
|
||||
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!("[PEER RPC MGR] server trasport read aborted");
|
||||
});
|
||||
|
||||
tasks.spawn(async move {
|
||||
while let Some(packet) = packet_receiver.recv().await {
|
||||
tracing::trace!("tunnel recv: {:?}", packet);
|
||||
|
||||
let info = PeerRpcManager::parse_rpc_packet(&packet);
|
||||
if let Err(e) = info {
|
||||
tracing::error!(error = ?e, "parse rpc packet failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
let decoded = bincode::deserialize(&info.unwrap().content.as_slice());
|
||||
if let Err(e) = decoded {
|
||||
tracing::error!(error = ?e, "decode rpc packet failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = server_s.send(decoded.unwrap()).await {
|
||||
tracing::error!(error = ?e, "send to rpc server channel failed");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!("[PEER RPC MGR] server packet read aborted");
|
||||
});
|
||||
|
||||
let _insert_ret = self
|
||||
.client_resp_receivers
|
||||
.insert(PeerRpcClientCtxKey(dst_peer_id, service_id), packet_sender);
|
||||
|
||||
f(client_transport).await
|
||||
}
|
||||
|
||||
pub fn my_peer_id(&self) -> uuid::Uuid {
|
||||
self.tspt.my_peer_id()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::{
|
||||
common::error::Error,
|
||||
peers::{
|
||||
peer_rpc::PeerRpcManager,
|
||||
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
|
||||
},
|
||||
tunnels::{self, ring_tunnel::create_ring_tunnel_pair},
|
||||
};
|
||||
|
||||
use super::PeerRpcManagerTransport;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait TestRpcService {
|
||||
async fn hello(s: String) -> String;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MockService {
|
||||
prefix: String,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl TestRpcService for MockService {
|
||||
async fn hello(self, _: tarpc::context::Context, s: String) -> String {
|
||||
format!("{} {}", self.prefix, s)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_rpc_basic_test() {
|
||||
struct MockTransport {
|
||||
tunnel: Box<dyn tunnels::Tunnel>,
|
||||
my_peer_id: uuid::Uuid,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerRpcManagerTransport for MockTransport {
|
||||
fn my_peer_id(&self) -> uuid::Uuid {
|
||||
self.my_peer_id
|
||||
}
|
||||
async fn send(&self, msg: Bytes, _dst_peer_id: &uuid::Uuid) -> Result<(), Error> {
|
||||
println!("rpc mgr send: {:?}", msg);
|
||||
self.tunnel.pin_sink().send(msg).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
async fn recv(&self) -> Result<Bytes, Error> {
|
||||
let ret = self.tunnel.pin_stream().next().await.unwrap();
|
||||
println!("rpc mgr recv: {:?}", ret);
|
||||
return ret.map(|v| v.freeze()).map_err(|_| Error::Unknown);
|
||||
}
|
||||
}
|
||||
|
||||
let (ct, st) = create_ring_tunnel_pair();
|
||||
|
||||
let server_rpc_mgr = PeerRpcManager::new(MockTransport {
|
||||
tunnel: st,
|
||||
my_peer_id: uuid::Uuid::new_v4(),
|
||||
});
|
||||
server_rpc_mgr.run();
|
||||
let s = MockService {
|
||||
prefix: "hello".to_owned(),
|
||||
};
|
||||
server_rpc_mgr.run_service(1, s.serve());
|
||||
|
||||
let client_rpc_mgr = PeerRpcManager::new(MockTransport {
|
||||
tunnel: ct,
|
||||
my_peer_id: uuid::Uuid::new_v4(),
|
||||
});
|
||||
client_rpc_mgr.run();
|
||||
|
||||
let ret = client_rpc_mgr
|
||||
.do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
|
||||
println!("ret: {:?}", ret);
|
||||
assert_eq!(ret.unwrap(), "hello abc");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rpc_with_peer_manager() {
|
||||
let peer_mgr_a = create_mock_peer_manager().await;
|
||||
let peer_mgr_b = create_mock_peer_manager().await;
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.my_node_id())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
|
||||
assert_eq!(
|
||||
peer_mgr_a.get_peer_map().list_peers().await[0],
|
||||
peer_mgr_b.my_node_id()
|
||||
);
|
||||
|
||||
let s = MockService {
|
||||
prefix: "hello".to_owned(),
|
||||
};
|
||||
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
|
||||
|
||||
let ip_list = peer_mgr_a
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, peer_mgr_b.my_node_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
|
||||
println!("ip_list: {:?}", ip_list);
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello abc");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_service_with_peer_manager() {
|
||||
let peer_mgr_a = create_mock_peer_manager().await;
|
||||
let peer_mgr_b = create_mock_peer_manager().await;
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.my_node_id())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
|
||||
assert_eq!(
|
||||
peer_mgr_a.get_peer_map().list_peers().await[0],
|
||||
peer_mgr_b.my_node_id()
|
||||
);
|
||||
|
||||
let s = MockService {
|
||||
prefix: "hello_a".to_owned(),
|
||||
};
|
||||
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
|
||||
let b = MockService {
|
||||
prefix: "hello_b".to_owned(),
|
||||
};
|
||||
peer_mgr_b.get_peer_rpc_mgr().run_service(2, b.serve());
|
||||
|
||||
let ip_list = peer_mgr_a
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, peer_mgr_b.my_node_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello_a abc");
|
||||
|
||||
let ip_list = peer_mgr_a
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(2, peer_mgr_b.my_node_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello_b abc");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,480 @@
|
||||
use std::{net::Ipv4Addr, sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use easytier_rpc::{NatType, StunInfo};
|
||||
use rkyv::{Archive, Deserialize, Serialize};
|
||||
use tokio::{sync::Mutex, task::JoinSet};
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Instrument;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Error,
|
||||
global_ctx::ArcGlobalCtx,
|
||||
rkyv_util::{decode_from_bytes, encode_to_bytes},
|
||||
stun::StunInfoCollectorTrait,
|
||||
},
|
||||
peers::{
|
||||
packet::{self, UUID},
|
||||
route_trait::{Route, RouteInterfaceBox},
|
||||
PeerId,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Clone, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub struct SyncPeerInfo {
|
||||
// means next hop in route table.
|
||||
pub peer_id: UUID,
|
||||
pub cost: u32,
|
||||
pub ipv4_addr: Option<Ipv4Addr>,
|
||||
pub proxy_cidrs: Vec<String>,
|
||||
pub hostname: Option<String>,
|
||||
pub udp_stun_info: i8,
|
||||
}
|
||||
|
||||
impl SyncPeerInfo {
|
||||
pub fn new_self(from_peer: UUID, global_ctx: &ArcGlobalCtx) -> Self {
|
||||
SyncPeerInfo {
|
||||
peer_id: from_peer,
|
||||
cost: 0,
|
||||
ipv4_addr: global_ctx.get_ipv4(),
|
||||
proxy_cidrs: global_ctx
|
||||
.get_proxy_cidrs()
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect(),
|
||||
hostname: global_ctx.get_hostname(),
|
||||
udp_stun_info: global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type as i8,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clone_for_route_table(&self, next_hop: &UUID, cost: u32, from: &Self) -> Self {
|
||||
SyncPeerInfo {
|
||||
peer_id: next_hop.clone(),
|
||||
cost,
|
||||
ipv4_addr: from.ipv4_addr.clone(),
|
||||
proxy_cidrs: from.proxy_cidrs.clone(),
|
||||
hostname: from.hostname.clone(),
|
||||
udp_stun_info: from.udp_stun_info,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Clone, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub struct SyncPeer {
|
||||
pub myself: SyncPeerInfo,
|
||||
pub neighbors: Vec<SyncPeerInfo>,
|
||||
}
|
||||
|
||||
impl SyncPeer {
|
||||
pub fn new(
|
||||
from_peer: UUID,
|
||||
_to_peer: UUID,
|
||||
neighbors: Vec<SyncPeerInfo>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
) -> Self {
|
||||
SyncPeer {
|
||||
myself: SyncPeerInfo::new_self(from_peer, &global_ctx),
|
||||
neighbors,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SyncPeerFromRemote {
|
||||
packet: SyncPeer,
|
||||
last_update: std::time::Instant,
|
||||
}
|
||||
|
||||
type SyncPeerFromRemoteMap = Arc<DashMap<uuid::Uuid, SyncPeerFromRemote>>;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct RouteTable {
|
||||
route_info: DashMap<uuid::Uuid, SyncPeerInfo>,
|
||||
ipv4_peer_id_map: DashMap<Ipv4Addr, uuid::Uuid>,
|
||||
cidr_peer_id_map: DashMap<cidr::IpCidr, uuid::Uuid>,
|
||||
}
|
||||
|
||||
impl RouteTable {
|
||||
fn new() -> Self {
|
||||
RouteTable {
|
||||
route_info: DashMap::new(),
|
||||
ipv4_peer_id_map: DashMap::new(),
|
||||
cidr_peer_id_map: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn copy_from(&self, other: &Self) {
|
||||
self.route_info.clear();
|
||||
for item in other.route_info.iter() {
|
||||
let (k, v) = item.pair();
|
||||
self.route_info.insert(*k, v.clone());
|
||||
}
|
||||
|
||||
self.ipv4_peer_id_map.clear();
|
||||
for item in other.ipv4_peer_id_map.iter() {
|
||||
let (k, v) = item.pair();
|
||||
self.ipv4_peer_id_map.insert(*k, *v);
|
||||
}
|
||||
|
||||
self.cidr_peer_id_map.clear();
|
||||
for item in other.cidr_peer_id_map.iter() {
|
||||
let (k, v) = item.pair();
|
||||
self.cidr_peer_id_map.insert(*k, *v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BasicRoute {
|
||||
my_peer_id: packet::UUID,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
interface: Arc<Mutex<Option<RouteInterfaceBox>>>,
|
||||
|
||||
route_table: Arc<RouteTable>,
|
||||
|
||||
sync_peer_from_remote: SyncPeerFromRemoteMap,
|
||||
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
|
||||
need_sync_notifier: Arc<tokio::sync::Notify>,
|
||||
}
|
||||
|
||||
impl BasicRoute {
|
||||
pub fn new(my_peer_id: Uuid, global_ctx: ArcGlobalCtx) -> Self {
|
||||
BasicRoute {
|
||||
my_peer_id: my_peer_id.into(),
|
||||
global_ctx,
|
||||
interface: Arc::new(Mutex::new(None)),
|
||||
|
||||
route_table: Arc::new(RouteTable::new()),
|
||||
|
||||
sync_peer_from_remote: Arc::new(DashMap::new()),
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
|
||||
need_sync_notifier: Arc::new(tokio::sync::Notify::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_route_table(
|
||||
my_id: packet::UUID,
|
||||
sync_peer_reqs: SyncPeerFromRemoteMap,
|
||||
route_table: Arc<RouteTable>,
|
||||
) {
|
||||
tracing::trace!(my_id = ?my_id, route_table = ?route_table, "update route table");
|
||||
|
||||
let new_route_table = Arc::new(RouteTable::new());
|
||||
for item in sync_peer_reqs.iter() {
|
||||
Self::update_route_table_with_req(
|
||||
my_id.clone(),
|
||||
&item.value().packet,
|
||||
new_route_table.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
route_table.copy_from(&new_route_table);
|
||||
}
|
||||
|
||||
fn update_route_table_with_req(
|
||||
my_id: packet::UUID,
|
||||
packet: &SyncPeer,
|
||||
route_table: Arc<RouteTable>,
|
||||
) {
|
||||
let peer_id = packet.myself.peer_id.clone();
|
||||
let update = |cost: u32, peer_info: &SyncPeerInfo| {
|
||||
let node_id: uuid::Uuid = peer_info.peer_id.clone().into();
|
||||
let ret = route_table
|
||||
.route_info
|
||||
.entry(node_id.clone().into())
|
||||
.and_modify(|info| {
|
||||
if info.cost > cost {
|
||||
*info = info.clone_for_route_table(&peer_id, cost, &peer_info);
|
||||
}
|
||||
})
|
||||
.or_insert(
|
||||
peer_info
|
||||
.clone()
|
||||
.clone_for_route_table(&peer_id, cost, &peer_info),
|
||||
)
|
||||
.value()
|
||||
.clone();
|
||||
|
||||
if ret.cost > 32 {
|
||||
log::error!(
|
||||
"cost too large: {}, may lost connection, remove it",
|
||||
ret.cost
|
||||
);
|
||||
route_table.route_info.remove(&node_id);
|
||||
}
|
||||
|
||||
log::trace!(
|
||||
"update route info, to: {:?}, gateway: {:?}, cost: {}, peer: {:?}",
|
||||
node_id,
|
||||
peer_id,
|
||||
cost,
|
||||
&peer_info
|
||||
);
|
||||
|
||||
if let Some(ipv4) = peer_info.ipv4_addr {
|
||||
route_table
|
||||
.ipv4_peer_id_map
|
||||
.insert(ipv4.clone(), node_id.clone().into());
|
||||
}
|
||||
|
||||
for cidr in peer_info.proxy_cidrs.iter() {
|
||||
let cidr: cidr::IpCidr = cidr.parse().unwrap();
|
||||
route_table
|
||||
.cidr_peer_id_map
|
||||
.insert(cidr, node_id.clone().into());
|
||||
}
|
||||
};
|
||||
|
||||
for neighbor in packet.neighbors.iter() {
|
||||
if neighbor.peer_id == my_id {
|
||||
continue;
|
||||
}
|
||||
update(neighbor.cost + 1, &neighbor);
|
||||
log::trace!("route info: {:?}", neighbor);
|
||||
}
|
||||
|
||||
// add the sender peer to route info
|
||||
update(1, &packet.myself);
|
||||
|
||||
log::trace!("my_id: {:?}, current route table: {:?}", my_id, route_table);
|
||||
}
|
||||
|
||||
async fn send_sync_peer_request(
|
||||
interface: &RouteInterfaceBox,
|
||||
my_peer_id: packet::UUID,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_id: PeerId,
|
||||
route_table: Arc<RouteTable>,
|
||||
) -> Result<(), Error> {
|
||||
let mut route_info_copy: Vec<SyncPeerInfo> = Vec::new();
|
||||
// copy the route info
|
||||
for item in route_table.route_info.iter() {
|
||||
let (k, v) = item.pair();
|
||||
route_info_copy.push(v.clone().clone_for_route_table(&(*k).into(), v.cost, &v));
|
||||
}
|
||||
let msg = SyncPeer::new(my_peer_id, peer_id.into(), route_info_copy, global_ctx);
|
||||
// TODO: this may exceed the MTU of the tunnel
|
||||
interface
|
||||
.send_route_packet(encode_to_bytes::<_, 4096>(&msg), 1, &peer_id)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn sync_peer_periodically(&self) {
|
||||
let route_table = self.route_table.clone();
|
||||
let global_ctx = self.global_ctx.clone();
|
||||
let my_peer_id = self.my_peer_id.clone();
|
||||
let interface = self.interface.clone();
|
||||
let notifier = self.need_sync_notifier.clone();
|
||||
self.tasks.lock().await.spawn(
|
||||
async move {
|
||||
loop {
|
||||
let lockd_interface = interface.lock().await;
|
||||
let interface = lockd_interface.as_ref().unwrap();
|
||||
let peers = interface.list_peers().await;
|
||||
for peer in peers.iter() {
|
||||
let ret = Self::send_sync_peer_request(
|
||||
interface,
|
||||
my_peer_id.clone(),
|
||||
global_ctx.clone(),
|
||||
*peer,
|
||||
route_table.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match &ret {
|
||||
Ok(_) => {
|
||||
log::trace!("send sync peer request to peer: {}", peer);
|
||||
}
|
||||
Err(Error::PeerNoConnectionError(_)) => {
|
||||
log::trace!("peer {} no connection", peer);
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
"send sync peer request to peer: {} error: {:?}",
|
||||
peer,
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
tokio::select! {
|
||||
_ = notifier.notified() => {
|
||||
log::trace!("sync peer request triggered by notifier");
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_secs(1)) => {
|
||||
log::trace!("sync peer request triggered by timeout");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("sync_peer_periodically", my_id = ?self.my_peer_id, global_ctx = ?self.global_ctx),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
async fn check_expired_sync_peer_from_remote(&self) {
|
||||
let route_table = self.route_table.clone();
|
||||
let my_peer_id = self.my_peer_id.clone();
|
||||
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
|
||||
let notifier = self.need_sync_notifier.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
loop {
|
||||
let mut need_update_route = false;
|
||||
let now = std::time::Instant::now();
|
||||
let mut need_remove = Vec::new();
|
||||
for item in sync_peer_from_remote.iter() {
|
||||
let (k, v) = item.pair();
|
||||
if now.duration_since(v.last_update).as_secs() > 5 {
|
||||
need_update_route = true;
|
||||
need_remove.insert(0, k.clone());
|
||||
}
|
||||
}
|
||||
|
||||
for k in need_remove.iter() {
|
||||
log::warn!("remove expired sync peer: {:?}", k);
|
||||
sync_peer_from_remote.remove(k);
|
||||
}
|
||||
|
||||
if need_update_route {
|
||||
Self::update_route_table(
|
||||
my_peer_id.clone(),
|
||||
sync_peer_from_remote.clone(),
|
||||
route_table.clone(),
|
||||
);
|
||||
notifier.notify_one();
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn get_peer_id_for_proxy(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
|
||||
let ipv4 = std::net::IpAddr::V4(*ipv4);
|
||||
for item in self.route_table.cidr_peer_id_map.iter() {
|
||||
let (k, v) = item.pair();
|
||||
if k.contains(&ipv4) {
|
||||
return Some(*v);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Route for BasicRoute {
|
||||
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()> {
|
||||
*self.interface.lock().await = Some(interface);
|
||||
self.sync_peer_periodically().await;
|
||||
self.check_expired_sync_peer_from_remote().await;
|
||||
Ok(1)
|
||||
}
|
||||
|
||||
async fn close(&self) {}
|
||||
|
||||
#[tracing::instrument(skip(self, packet), fields(my_id = ?self.my_peer_id, ctx = ?self.global_ctx))]
|
||||
async fn handle_route_packet(&self, src_peer_id: uuid::Uuid, packet: Bytes) {
|
||||
let packet = decode_from_bytes::<SyncPeer>(&packet).unwrap();
|
||||
let p: SyncPeer = packet.deserialize(&mut rkyv::Infallible).unwrap();
|
||||
let mut updated = true;
|
||||
assert_eq!(packet.myself.peer_id.to_uuid(), src_peer_id);
|
||||
self.sync_peer_from_remote
|
||||
.entry(packet.myself.peer_id.to_uuid())
|
||||
.and_modify(|v| {
|
||||
if v.packet == *packet {
|
||||
updated = false;
|
||||
} else {
|
||||
v.packet = p.clone();
|
||||
}
|
||||
v.last_update = std::time::Instant::now();
|
||||
})
|
||||
.or_insert(SyncPeerFromRemote {
|
||||
packet: p.clone(),
|
||||
last_update: std::time::Instant::now(),
|
||||
});
|
||||
|
||||
if updated {
|
||||
Self::update_route_table(
|
||||
self.my_peer_id.clone(),
|
||||
self.sync_peer_from_remote.clone(),
|
||||
self.route_table.clone(),
|
||||
);
|
||||
self.need_sync_notifier.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_peer_id_by_ipv4(&self, ipv4_addr: &Ipv4Addr) -> Option<PeerId> {
|
||||
if let Some(peer_id) = self.route_table.ipv4_peer_id_map.get(ipv4_addr) {
|
||||
return Some(*peer_id);
|
||||
}
|
||||
|
||||
if let Some(peer_id) = self.get_peer_id_for_proxy(ipv4_addr) {
|
||||
return Some(peer_id);
|
||||
}
|
||||
|
||||
log::info!("no peer id for ipv4: {}", ipv4_addr);
|
||||
return None;
|
||||
}
|
||||
|
||||
async fn get_next_hop(&self, dst_peer_id: &PeerId) -> Option<PeerId> {
|
||||
match self.route_table.route_info.get(dst_peer_id) {
|
||||
Some(info) => {
|
||||
return Some(info.peer_id.clone().into());
|
||||
}
|
||||
None => {
|
||||
log::error!("no route info for dst_peer_id: {}", dst_peer_id);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_routes(&self) -> Vec<easytier_rpc::Route> {
|
||||
let mut routes = Vec::new();
|
||||
|
||||
let parse_route_info = |real_peer_id: &Uuid, route_info: &SyncPeerInfo| {
|
||||
let mut route = easytier_rpc::Route::default();
|
||||
route.ipv4_addr = if let Some(ipv4_addr) = route_info.ipv4_addr {
|
||||
ipv4_addr.to_string()
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
route.peer_id = real_peer_id.to_string();
|
||||
route.next_hop_peer_id = Uuid::from(route_info.peer_id.clone()).to_string();
|
||||
route.cost = route_info.cost as i32;
|
||||
route.proxy_cidrs = route_info.proxy_cidrs.clone();
|
||||
route.hostname = if let Some(hostname) = &route_info.hostname {
|
||||
hostname.clone()
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
let mut stun_info = StunInfo::default();
|
||||
if let Ok(udp_nat_type) = NatType::try_from(route_info.udp_stun_info as i32) {
|
||||
stun_info.set_udp_nat_type(udp_nat_type);
|
||||
}
|
||||
route.stun_info = Some(stun_info);
|
||||
|
||||
route
|
||||
};
|
||||
|
||||
self.route_table.route_info.iter().for_each(|item| {
|
||||
routes.push(parse_route_info(item.key(), item.value()));
|
||||
});
|
||||
|
||||
routes
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
use std::{net::Ipv4Addr, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::common::error::Error;
|
||||
|
||||
use super::PeerId;
|
||||
|
||||
#[async_trait]
|
||||
pub trait RouteInterface {
|
||||
async fn list_peers(&self) -> Vec<PeerId>;
|
||||
async fn send_route_packet(
|
||||
&self,
|
||||
msg: Bytes,
|
||||
route_id: u8,
|
||||
dst_peer_id: &PeerId,
|
||||
) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
pub type RouteInterfaceBox = Box<dyn RouteInterface + Send + Sync>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Route {
|
||||
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()>;
|
||||
async fn close(&self);
|
||||
|
||||
async fn get_peer_id_by_ipv4(&self, ipv4: &Ipv4Addr) -> Option<PeerId>;
|
||||
async fn get_next_hop(&self, peer_id: &PeerId) -> Option<PeerId>;
|
||||
|
||||
async fn handle_route_packet(&self, src_peer_id: PeerId, packet: Bytes);
|
||||
|
||||
async fn list_routes(&self) -> Vec<easytier_rpc::Route>;
|
||||
}
|
||||
|
||||
pub type ArcRoute = Arc<Box<dyn Route + Send + Sync>>;
|
||||
@@ -0,0 +1,55 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use easytier_rpc::peer_manage_rpc_server::PeerManageRpc;
|
||||
use easytier_rpc::{ListPeerRequest, ListPeerResponse, ListRouteRequest, ListRouteResponse};
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use super::peer_manager::PeerManager;
|
||||
|
||||
pub struct PeerManagerRpcService {
|
||||
peer_manager: Arc<PeerManager>,
|
||||
}
|
||||
|
||||
impl PeerManagerRpcService {
|
||||
pub fn new(peer_manager: Arc<PeerManager>) -> Self {
|
||||
PeerManagerRpcService { peer_manager }
|
||||
}
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl PeerManageRpc for PeerManagerRpcService {
|
||||
async fn list_peer(
|
||||
&self,
|
||||
_request: Request<ListPeerRequest>, // Accept request of type HelloRequest
|
||||
) -> Result<Response<ListPeerResponse>, Status> {
|
||||
let mut reply = ListPeerResponse::default();
|
||||
|
||||
let peers = self.peer_manager.get_peer_map().list_peers().await;
|
||||
for peer in peers {
|
||||
let mut peer_info = easytier_rpc::PeerInfo::default();
|
||||
peer_info.peer_id = peer.to_string();
|
||||
|
||||
if let Some(conns) = self
|
||||
.peer_manager
|
||||
.get_peer_map()
|
||||
.list_peer_conns(&peer)
|
||||
.await
|
||||
{
|
||||
peer_info.conns = conns;
|
||||
}
|
||||
|
||||
reply.peer_infos.push(peer_info);
|
||||
}
|
||||
|
||||
Ok(Response::new(reply))
|
||||
}
|
||||
|
||||
async fn list_route(
|
||||
&self,
|
||||
_request: Request<ListRouteRequest>, // Accept request of type HelloRequest
|
||||
) -> Result<Response<ListRouteResponse>, Status> {
|
||||
let mut reply = ListRouteResponse::default();
|
||||
reply.routes = self.peer_manager.list_routes().await;
|
||||
Ok(Response::new(reply))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::tests::get_mock_global_ctx},
|
||||
peers::rip_route::BasicRoute,
|
||||
tunnels::ring_tunnel::create_ring_tunnel_pair,
|
||||
};
|
||||
|
||||
use super::peer_manager::PeerManager;
|
||||
|
||||
pub async fn create_mock_peer_manager() -> Arc<PeerManager> {
|
||||
let (s, _r) = tokio::sync::mpsc::channel(1000);
|
||||
let peer_mgr = Arc::new(PeerManager::new(get_mock_global_ctx(), s));
|
||||
peer_mgr
|
||||
.set_route(BasicRoute::new(
|
||||
peer_mgr.my_node_id(),
|
||||
peer_mgr.get_global_ctx(),
|
||||
))
|
||||
.await;
|
||||
peer_mgr.run().await.unwrap();
|
||||
peer_mgr
|
||||
}
|
||||
|
||||
pub async fn connect_peer_manager(client: Arc<PeerManager>, server: Arc<PeerManager>) {
|
||||
let (a_ring, b_ring) = create_ring_tunnel_pair();
|
||||
let a_mgr_copy = client.clone();
|
||||
tokio::spawn(async move {
|
||||
a_mgr_copy.add_client_tunnel(a_ring).await.unwrap();
|
||||
});
|
||||
let b_mgr_copy = server.clone();
|
||||
tokio::spawn(async move {
|
||||
b_mgr_copy.add_tunnel_as_server(b_ring).await.unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn wait_route_appear_with_cost(
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
node_id: uuid::Uuid,
|
||||
cost: Option<i32>,
|
||||
) -> Result<(), Error> {
|
||||
let now = std::time::Instant::now();
|
||||
while now.elapsed().as_secs() < 5 {
|
||||
let route = peer_mgr.list_routes().await;
|
||||
if route.iter().any(|r| {
|
||||
r.peer_id.clone().parse::<uuid::Uuid>().unwrap() == node_id
|
||||
&& (cost.is_none() || r.cost == cost.unwrap())
|
||||
}) {
|
||||
return Ok(());
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
}
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
|
||||
pub async fn wait_route_appear(
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
node_id: uuid::Uuid,
|
||||
) -> Result<(), Error> {
|
||||
wait_route_appear_with_cost(peer_mgr, node_id, None).await
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
tonic::include_proto!("cli"); // The string specified here must match the proto package name
|
||||
@@ -0,0 +1,4 @@
|
||||
pub mod cli;
|
||||
pub use cli::*;
|
||||
|
||||
pub mod peer;
|
||||
@@ -0,0 +1,20 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
|
||||
pub struct GetIpListResponse {
|
||||
pub public_ipv4: String,
|
||||
pub interface_ipv4s: Vec<String>,
|
||||
pub public_ipv6: String,
|
||||
pub interface_ipv6s: Vec<String>,
|
||||
}
|
||||
|
||||
impl GetIpListResponse {
|
||||
pub fn new() -> Self {
|
||||
GetIpListResponse {
|
||||
public_ipv4: "".to_string(),
|
||||
interface_ipv4s: vec![],
|
||||
public_ipv6: "".to_string(),
|
||||
interface_ipv6s: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
mod three_node;
|
||||
|
||||
pub fn get_guest_veth_name(net_ns: &str) -> &str {
|
||||
Box::leak(format!("veth_{}_g", net_ns).into_boxed_str())
|
||||
}
|
||||
|
||||
pub fn get_host_veth_name(net_ns: &str) -> &str {
|
||||
Box::leak(format!("veth_{}_h", net_ns).into_boxed_str())
|
||||
}
|
||||
|
||||
pub fn del_netns(name: &str) {
|
||||
// del veth host
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["link", "del", get_host_veth_name(name)])
|
||||
.output();
|
||||
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["netns", "del", name])
|
||||
.output();
|
||||
}
|
||||
|
||||
pub fn create_netns(name: &str, ipv4: &str) {
|
||||
// create netns
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["netns", "add", name])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
// set lo up
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["netns", "exec", name, "ip", "link", "set", "lo", "up"])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&[
|
||||
"link",
|
||||
"add",
|
||||
get_host_veth_name(name),
|
||||
"type",
|
||||
"veth",
|
||||
"peer",
|
||||
"name",
|
||||
get_guest_veth_name(name),
|
||||
])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["link", "set", get_guest_veth_name(name), "netns", name])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&[
|
||||
"netns",
|
||||
"exec",
|
||||
name,
|
||||
"ip",
|
||||
"link",
|
||||
"set",
|
||||
get_guest_veth_name(name),
|
||||
"up",
|
||||
])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["link", "set", get_host_veth_name(name), "up"])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&[
|
||||
"netns",
|
||||
"exec",
|
||||
name,
|
||||
"ip",
|
||||
"addr",
|
||||
"add",
|
||||
ipv4,
|
||||
"dev",
|
||||
get_guest_veth_name(name),
|
||||
])
|
||||
.output()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn prepare_bridge(name: &str) {
|
||||
// del bridge with brctl
|
||||
let _ = std::process::Command::new("brctl")
|
||||
.args(&["delbr", name])
|
||||
.output();
|
||||
|
||||
// create new br
|
||||
let _ = std::process::Command::new("brctl")
|
||||
.args(&["addbr", name])
|
||||
.output();
|
||||
}
|
||||
|
||||
pub fn add_ns_to_bridge(br_name: &str, ns_name: &str) {
|
||||
// use brctl to add ns to bridge
|
||||
let _ = std::process::Command::new("brctl")
|
||||
.args(&["addif", br_name, get_host_veth_name(ns_name)])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
// set bridge up
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["link", "set", br_name, "up"])
|
||||
.output()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn enable_log() {
|
||||
let filter = tracing_subscriber::EnvFilter::builder()
|
||||
.with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
|
||||
.from_env()
|
||||
.unwrap();
|
||||
tracing_subscriber::fmt::fmt()
|
||||
.pretty()
|
||||
.with_env_filter(filter)
|
||||
.init();
|
||||
}
|
||||
|
||||
fn check_route(ipv4: &str, dst_peer_id: uuid::Uuid, routes: Vec<easytier_rpc::Route>) {
|
||||
let mut found = false;
|
||||
for r in routes.iter() {
|
||||
if r.ipv4_addr == ipv4.to_string() {
|
||||
found = true;
|
||||
assert_eq!(r.peer_id, dst_peer_id.to_string(), "{:?}", routes);
|
||||
}
|
||||
}
|
||||
assert!(found);
|
||||
}
|
||||
|
||||
async fn wait_proxy_route_appear(
|
||||
mgr: &std::sync::Arc<crate::peers::peer_manager::PeerManager>,
|
||||
ipv4: &str,
|
||||
dst_peer_id: uuid::Uuid,
|
||||
proxy_cidr: &str,
|
||||
) {
|
||||
let now = std::time::Instant::now();
|
||||
loop {
|
||||
for r in mgr.list_routes().await.iter() {
|
||||
let r = r;
|
||||
if r.proxy_cidrs.contains(&proxy_cidr.to_owned()) {
|
||||
assert_eq!(r.peer_id, dst_peer_id.to_string());
|
||||
assert_eq!(r.ipv4_addr, ipv4);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if now.elapsed().as_secs() > 5 {
|
||||
panic!("wait proxy route appear timeout");
|
||||
}
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn set_link_status(net_ns: &str, up: bool) {
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&[
|
||||
"netns",
|
||||
"exec",
|
||||
net_ns,
|
||||
"ip",
|
||||
"link",
|
||||
"set",
|
||||
get_guest_veth_name(net_ns),
|
||||
if up { "up" } else { "down" },
|
||||
])
|
||||
.output()
|
||||
.unwrap();
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
use super::*;
|
||||
|
||||
use crate::{
|
||||
common::netns::{NetNS, ROOT_NETNS_NAME},
|
||||
instance::instance::{Instance, InstanceConfigWriter},
|
||||
tunnels::{
|
||||
common::tests::_tunnel_pingpong_netns,
|
||||
ring_tunnel::RingTunnelConnector,
|
||||
tcp_tunnel::{TcpTunnelConnector, TcpTunnelListener},
|
||||
udp_tunnel::UdpTunnelConnector,
|
||||
},
|
||||
};
|
||||
|
||||
pub fn prepare_linux_namespaces() {
|
||||
del_netns("net_a");
|
||||
del_netns("net_b");
|
||||
del_netns("net_c");
|
||||
del_netns("net_d");
|
||||
|
||||
create_netns("net_a", "10.1.1.1/24");
|
||||
create_netns("net_b", "10.1.1.2/24");
|
||||
create_netns("net_c", "10.1.2.3/24");
|
||||
create_netns("net_d", "10.1.2.4/24");
|
||||
|
||||
prepare_bridge("br_a");
|
||||
prepare_bridge("br_b");
|
||||
|
||||
add_ns_to_bridge("br_a", "net_a");
|
||||
add_ns_to_bridge("br_a", "net_b");
|
||||
add_ns_to_bridge("br_b", "net_c");
|
||||
add_ns_to_bridge("br_b", "net_d");
|
||||
}
|
||||
|
||||
pub async fn prepare_inst_configs() {
|
||||
InstanceConfigWriter::new("inst1")
|
||||
.set_ns(Some("net_a".into()))
|
||||
.set_addr("10.144.144.1".to_owned());
|
||||
|
||||
InstanceConfigWriter::new("inst2")
|
||||
.set_ns(Some("net_b".into()))
|
||||
.set_addr("10.144.144.2".to_owned());
|
||||
|
||||
InstanceConfigWriter::new("inst3")
|
||||
.set_ns(Some("net_c".into()))
|
||||
.set_addr("10.144.144.3".to_owned());
|
||||
}
|
||||
|
||||
pub async fn init_three_node(proto: &str) -> Vec<Instance> {
|
||||
log::set_max_level(log::LevelFilter::Info);
|
||||
prepare_linux_namespaces();
|
||||
|
||||
prepare_inst_configs().await;
|
||||
|
||||
let mut inst1 = Instance::new("inst1");
|
||||
let mut inst2 = Instance::new("inst2");
|
||||
let mut inst3 = Instance::new("inst3");
|
||||
|
||||
inst1.run().await.unwrap();
|
||||
inst2.run().await.unwrap();
|
||||
inst3.run().await.unwrap();
|
||||
|
||||
if proto == "tcp" {
|
||||
inst2
|
||||
.get_conn_manager()
|
||||
.add_connector(TcpTunnelConnector::new(
|
||||
"tcp://10.1.1.1:11010".parse().unwrap(),
|
||||
));
|
||||
} else {
|
||||
inst2
|
||||
.get_conn_manager()
|
||||
.add_connector(UdpTunnelConnector::new(
|
||||
"udp://10.1.1.1:11010".parse().unwrap(),
|
||||
));
|
||||
}
|
||||
|
||||
inst2
|
||||
.get_conn_manager()
|
||||
.add_connector(RingTunnelConnector::new(
|
||||
format!("ring://{}", inst3.id()).parse().unwrap(),
|
||||
));
|
||||
|
||||
// wait inst2 have two route.
|
||||
let now = std::time::Instant::now();
|
||||
loop {
|
||||
if inst2.get_peer_manager().list_routes().await.len() == 2 {
|
||||
break;
|
||||
}
|
||||
if now.elapsed().as_secs() > 5 {
|
||||
panic!("wait inst2 have two route timeout");
|
||||
}
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
|
||||
vec![inst1, inst2, inst3]
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn basic_three_node_test_tcp() {
|
||||
let insts = init_three_node("tcp").await;
|
||||
|
||||
check_route(
|
||||
"10.144.144.2",
|
||||
insts[1].id(),
|
||||
insts[0].get_peer_manager().list_routes().await,
|
||||
);
|
||||
|
||||
check_route(
|
||||
"10.144.144.3",
|
||||
insts[2].id(),
|
||||
insts[0].get_peer_manager().list_routes().await,
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn basic_three_node_test_udp() {
|
||||
let insts = init_three_node("udp").await;
|
||||
|
||||
check_route(
|
||||
"10.144.144.2",
|
||||
insts[1].id(),
|
||||
insts[0].get_peer_manager().list_routes().await,
|
||||
);
|
||||
|
||||
check_route(
|
||||
"10.144.144.3",
|
||||
insts[2].id(),
|
||||
insts[0].get_peer_manager().list_routes().await,
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn tcp_proxy_three_node_test() {
|
||||
let insts = init_three_node("tcp").await;
|
||||
|
||||
insts[2]
|
||||
.get_global_ctx()
|
||||
.add_proxy_cidr("10.1.2.0/24".parse().unwrap())
|
||||
.unwrap();
|
||||
assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
|
||||
|
||||
wait_proxy_route_appear(
|
||||
&insts[0].get_peer_manager(),
|
||||
"10.144.144.3",
|
||||
insts[2].id(),
|
||||
"10.1.2.0/24",
|
||||
)
|
||||
.await;
|
||||
|
||||
// wait updater
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
|
||||
|
||||
let tcp_listener = TcpTunnelListener::new("tcp://10.1.2.4:22223".parse().unwrap());
|
||||
let tcp_connector = TcpTunnelConnector::new("tcp://10.1.2.4:22223".parse().unwrap());
|
||||
|
||||
_tunnel_pingpong_netns(
|
||||
tcp_listener,
|
||||
tcp_connector,
|
||||
NetNS::new(Some("net_d".into())),
|
||||
NetNS::new(Some("net_a".into())),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn icmp_proxy_three_node_test() {
|
||||
let insts = init_three_node("tcp").await;
|
||||
|
||||
insts[2]
|
||||
.get_global_ctx()
|
||||
.add_proxy_cidr("10.1.2.0/24".parse().unwrap())
|
||||
.unwrap();
|
||||
assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
|
||||
|
||||
wait_proxy_route_appear(
|
||||
&insts[0].get_peer_manager(),
|
||||
"10.144.144.3",
|
||||
insts[2].id(),
|
||||
"10.1.2.0/24",
|
||||
)
|
||||
.await;
|
||||
|
||||
// wait updater
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
|
||||
|
||||
// send ping with shell in net_a to net_d
|
||||
let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard();
|
||||
let code = tokio::process::Command::new("ip")
|
||||
.args(&[
|
||||
"netns", "exec", "net_a", "ping", "-c", "1", "-W", "1", "10.1.2.4",
|
||||
])
|
||||
.status()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(code.code().unwrap(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn proxy_three_node_disconnect_test() {
|
||||
InstanceConfigWriter::new("inst4")
|
||||
.set_ns(Some("net_d".into()))
|
||||
.set_addr("10.144.144.4".to_owned());
|
||||
|
||||
let mut inst4 = Instance::new("inst4");
|
||||
inst4
|
||||
.get_conn_manager()
|
||||
.add_connector(TcpTunnelConnector::new(
|
||||
"tcp://10.1.2.3:11010".parse().unwrap(),
|
||||
));
|
||||
inst4.run().await.unwrap();
|
||||
|
||||
tokio::spawn(async {
|
||||
loop {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
|
||||
set_link_status("net_d", false);
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
|
||||
set_link_status("net_d", true);
|
||||
}
|
||||
});
|
||||
|
||||
// TODO: add some traffic here, also should check route & peer list
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(35)).await;
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
use std::result::Result;
|
||||
use tokio::io;
|
||||
use tokio_util::{
|
||||
bytes::{BufMut, Bytes, BytesMut},
|
||||
codec::{Decoder, Encoder},
|
||||
};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
|
||||
pub struct BytesCodec {
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl BytesCodec {
|
||||
/// Creates a new `BytesCodec` for shipping around raw bytes.
|
||||
pub fn new(capacity: usize) -> BytesCodec {
|
||||
BytesCodec { capacity }
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for BytesCodec {
|
||||
type Item = BytesMut;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
|
||||
if !buf.is_empty() {
|
||||
let len = buf.len();
|
||||
let ret = Some(buf.split_to(len));
|
||||
buf.reserve(self.capacity);
|
||||
Ok(ret)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Bytes> for BytesCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
buf.reserve(data.len());
|
||||
buf.put(data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<BytesMut> for BytesCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
buf.reserve(data.len());
|
||||
buf.put(data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,399 @@
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
net::IpAddr,
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
|
||||
use async_stream::stream;
|
||||
use futures::{Future, FutureExt, Sink, SinkExt, Stream, StreamExt};
|
||||
use tokio::{sync::Mutex, time::error::Elapsed};
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::tunnels::{SinkError, TunnelError};
|
||||
|
||||
use super::{DatagramSink, DatagramStream, SinkItem, StreamT, Tunnel, TunnelInfo};
|
||||
|
||||
pub struct FramedTunnel<R, W> {
|
||||
read: Arc<Mutex<R>>,
|
||||
write: Arc<Mutex<W>>,
|
||||
|
||||
info: Option<TunnelInfo>,
|
||||
}
|
||||
|
||||
impl<R, RE, W, WE> FramedTunnel<R, W>
|
||||
where
|
||||
R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
|
||||
{
|
||||
pub fn new(read: R, write: W, info: Option<TunnelInfo>) -> Self {
|
||||
FramedTunnel {
|
||||
read: Arc::new(Mutex::new(read)),
|
||||
write: Arc::new(Mutex::new(write)),
|
||||
info,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_tunnel_with_info(read: R, write: W, info: TunnelInfo) -> Box<dyn Tunnel> {
|
||||
Box::new(FramedTunnel::new(read, write, Some(info)))
|
||||
}
|
||||
|
||||
pub fn recv_stream(&self) -> impl DatagramStream {
|
||||
let read = self.read.clone();
|
||||
let info = self.info.clone();
|
||||
stream! {
|
||||
loop {
|
||||
let read_ret = read.lock().await.next().await;
|
||||
if read_ret.is_none() {
|
||||
tracing::info!(?info, "read_ret is none");
|
||||
yield Err(TunnelError::CommonError("recv stream closed".to_string()));
|
||||
} else {
|
||||
let read_ret = read_ret.unwrap();
|
||||
if read_ret.is_err() {
|
||||
let err = read_ret.err().unwrap();
|
||||
tracing::info!(?info, "recv stream read error");
|
||||
yield Err(TunnelError::CommonError(err.to_string()));
|
||||
} else {
|
||||
yield Ok(read_ret.unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_sink(&self) -> impl DatagramSink {
|
||||
struct SendSink<W, WE> {
|
||||
write: Arc<Mutex<W>>,
|
||||
max_buffer_size: usize,
|
||||
sending_buffers: Option<VecDeque<SinkItem>>,
|
||||
send_task:
|
||||
Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
|
||||
close_task:
|
||||
Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
|
||||
}
|
||||
|
||||
impl<W, WE> SendSink<W, WE>
|
||||
where
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
|
||||
{
|
||||
fn try_send_buffser(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<std::result::Result<(), WE>> {
|
||||
if self.send_task.is_none() {
|
||||
let mut buffers = self.sending_buffers.take().unwrap();
|
||||
let tun = self.write.clone();
|
||||
let send_task = async move {
|
||||
if buffers.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut locked_tun = tun.lock_owned().await;
|
||||
while let Some(buf) = buffers.front() {
|
||||
log::trace!(
|
||||
"try_send buffer, len: {:?}, buf: {:?}",
|
||||
buffers.len(),
|
||||
&buf
|
||||
);
|
||||
let timeout_task = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(1),
|
||||
locked_tun.send(buf.clone()),
|
||||
);
|
||||
let send_res = timeout_task.await;
|
||||
let Ok(send_res) = send_res else {
|
||||
// panic!("send timeout");
|
||||
let err = send_res.err().unwrap();
|
||||
return Err(err.into());
|
||||
};
|
||||
let Ok(_) = send_res else {
|
||||
let err = send_res.err().unwrap();
|
||||
println!("send error: {:?}", err);
|
||||
return Err(err);
|
||||
};
|
||||
buffers.pop_front();
|
||||
}
|
||||
return Ok(());
|
||||
};
|
||||
self.send_task = Some(Box::pin(send_task));
|
||||
}
|
||||
|
||||
let ret = ready!(self.send_task.as_mut().unwrap().poll_unpin(cx));
|
||||
self.send_task = None;
|
||||
self.sending_buffers = Some(VecDeque::new());
|
||||
return Poll::Ready(ret);
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, WE> Sink<SinkItem> for SendSink<W, WE>
|
||||
where
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
|
||||
{
|
||||
type Error = SinkError;
|
||||
|
||||
fn poll_ready(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
let self_mut = self.get_mut();
|
||||
let sending_buf = self_mut.sending_buffers.as_ref();
|
||||
// if sending_buffers is None, must already be doing flush
|
||||
if sending_buf.is_none() || sending_buf.unwrap().len() > self_mut.max_buffer_size {
|
||||
return self_mut.poll_flush_unpin(cx);
|
||||
} else {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
assert!(self.send_task.is_none());
|
||||
let self_mut = self.get_mut();
|
||||
self_mut.sending_buffers.as_mut().unwrap().push_back(item);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
let self_mut = self.get_mut();
|
||||
let ret = self_mut.try_send_buffser(cx);
|
||||
match ret {
|
||||
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(SinkError::CommonError(e.to_string()))),
|
||||
Poll::Pending => {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
let self_mut = self.get_mut();
|
||||
if self_mut.close_task.is_none() {
|
||||
let tun = self_mut.write.clone();
|
||||
let close_task = async move {
|
||||
let mut locked_tun = tun.lock_owned().await;
|
||||
return locked_tun.close().await;
|
||||
};
|
||||
self_mut.close_task = Some(Box::pin(close_task));
|
||||
}
|
||||
|
||||
let ret = ready!(self_mut.close_task.as_mut().unwrap().poll_unpin(cx));
|
||||
self_mut.close_task = None;
|
||||
|
||||
if ret.is_err() {
|
||||
return Poll::Ready(Err(SinkError::CommonError(
|
||||
ret.err().unwrap().to_string(),
|
||||
)));
|
||||
} else {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SendSink {
|
||||
write: self.write.clone(),
|
||||
max_buffer_size: 1000,
|
||||
sending_buffers: Some(VecDeque::new()),
|
||||
send_task: None,
|
||||
close_task: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, RE, W, WE> Tunnel for FramedTunnel<R, W>
|
||||
where
|
||||
R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
|
||||
{
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
if self.info.is_none() {
|
||||
None
|
||||
} else {
|
||||
Some(self.info.clone().unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TunnelWithCustomInfo {
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
info: TunnelInfo,
|
||||
}
|
||||
|
||||
impl TunnelWithCustomInfo {
|
||||
pub fn new(tunnel: Box<dyn Tunnel>, info: TunnelInfo) -> Self {
|
||||
TunnelWithCustomInfo { tunnel, info }
|
||||
}
|
||||
}
|
||||
|
||||
impl Tunnel for TunnelWithCustomInfo {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
self.tunnel.stream()
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
self.tunnel.sink()
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
Some(self.info.clone())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
|
||||
let ifaces = pnet::datalink::interfaces();
|
||||
for iface in ifaces {
|
||||
for ip in iface.ips {
|
||||
if ip.ip() == *local_ip {
|
||||
return Some(iface.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub mod tests {
|
||||
use std::time::Instant;
|
||||
|
||||
use futures::SinkExt;
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
|
||||
|
||||
use crate::{
|
||||
common::netns::NetNS,
|
||||
tunnels::{close_tunnel, TunnelConnector, TunnelListener},
|
||||
};
|
||||
|
||||
pub async fn _tunnel_echo_server(tunnel: Box<dyn super::Tunnel>, once: bool) {
|
||||
let mut recv = Box::into_pin(tunnel.stream());
|
||||
let mut send = Box::into_pin(tunnel.sink());
|
||||
|
||||
while let Some(ret) = recv.next().await {
|
||||
if ret.is_err() {
|
||||
log::trace!("recv error: {:?}", ret.err().unwrap());
|
||||
break;
|
||||
}
|
||||
let res = ret.unwrap();
|
||||
log::trace!("recv a msg, try echo back: {:?}", res);
|
||||
send.send(Bytes::from(res)).await.unwrap();
|
||||
if once {
|
||||
break;
|
||||
}
|
||||
}
|
||||
log::warn!("echo server exit...");
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_pingpong<L, C>(listener: L, connector: C)
|
||||
where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
_tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_pingpong_netns<L, C>(
|
||||
mut listener: L,
|
||||
mut connector: C,
|
||||
l_netns: NetNS,
|
||||
c_netns: NetNS,
|
||||
) where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
l_netns
|
||||
.run_async(|| async {
|
||||
listener.listen().await.unwrap();
|
||||
})
|
||||
.await;
|
||||
|
||||
let lis = tokio::spawn(async move {
|
||||
let ret = listener.accept().await.unwrap();
|
||||
assert_eq!(
|
||||
ret.info().unwrap().local_addr,
|
||||
listener.local_url().to_string()
|
||||
);
|
||||
_tunnel_echo_server(ret, false).await
|
||||
});
|
||||
|
||||
let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
tunnel.info().unwrap().remote_addr,
|
||||
connector.remote_url().to_string()
|
||||
);
|
||||
|
||||
let mut send = tunnel.pin_sink();
|
||||
let mut recv = tunnel.pin_stream();
|
||||
let send_data = Bytes::from("abc");
|
||||
send.send(send_data).await.unwrap();
|
||||
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
println!("echo back: {:?}", ret);
|
||||
assert_eq!(ret, Bytes::from("abc"));
|
||||
|
||||
close_tunnel(&tunnel).await.unwrap();
|
||||
|
||||
if connector.remote_url().scheme() == "udp" {
|
||||
lis.abort();
|
||||
} else {
|
||||
// lis should finish in 1 second
|
||||
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await;
|
||||
assert!(ret.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_bench<L, C>(mut listener: L, mut connector: C)
|
||||
where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
listener.listen().await.unwrap();
|
||||
|
||||
let lis = tokio::spawn(async move {
|
||||
let ret = listener.accept().await.unwrap();
|
||||
_tunnel_echo_server(ret, false).await
|
||||
});
|
||||
|
||||
let tunnel = connector.connect().await.unwrap();
|
||||
|
||||
let mut send = tunnel.pin_sink();
|
||||
let mut recv = tunnel.pin_stream();
|
||||
|
||||
// prepare a 4k buffer with random data
|
||||
let mut send_buf = BytesMut::new();
|
||||
for _ in 0..64 {
|
||||
send_buf.put_i128(rand::random::<i128>());
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
let mut count = 0;
|
||||
while now.elapsed().as_secs() < 3 {
|
||||
send.send(send_buf.clone().freeze()).await.unwrap();
|
||||
let _ = recv.next().await.unwrap().unwrap();
|
||||
count += 1;
|
||||
}
|
||||
println!("bps: {}", (count / 1024) * 4 / now.elapsed().as_secs());
|
||||
|
||||
lis.abort();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
pub mod codec;
|
||||
pub mod common;
|
||||
pub mod ring_tunnel;
|
||||
pub mod stats;
|
||||
pub mod tcp_tunnel;
|
||||
pub mod tunnel_filter;
|
||||
pub mod udp_tunnel;
|
||||
|
||||
use std::{fmt::Debug, net::SocketAddr, pin::Pin, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use easytier_rpc::TunnelInfo;
|
||||
use futures::{Sink, SinkExt, Stream};
|
||||
|
||||
use thiserror::Error;
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum TunnelError {
|
||||
#[error("Error: {0}")]
|
||||
CommonError(String),
|
||||
#[error("io error")]
|
||||
IOError(#[from] std::io::Error),
|
||||
#[error("wait resp error")]
|
||||
WaitRespError(String),
|
||||
#[error("Connect Error: {0}")]
|
||||
ConnectError(String),
|
||||
#[error("Invalid Protocol: {0}")]
|
||||
InvalidProtocol(String),
|
||||
#[error("Invalid Addr: {0}")]
|
||||
InvalidAddr(String),
|
||||
#[error("Tun Error: {0}")]
|
||||
TunError(String),
|
||||
#[error("timeout")]
|
||||
Timeout(#[from] tokio::time::error::Elapsed),
|
||||
}
|
||||
|
||||
pub type StreamT = BytesMut;
|
||||
pub type StreamItem = Result<StreamT, TunnelError>;
|
||||
pub type SinkItem = Bytes;
|
||||
pub type SinkError = TunnelError;
|
||||
|
||||
pub trait DatagramStream: Stream<Item = StreamItem> + Send + Sync {}
|
||||
impl<T> DatagramStream for T where T: Stream<Item = StreamItem> + Send + Sync {}
|
||||
pub trait DatagramSink: Sink<SinkItem, Error = SinkError> + Send + Sync {}
|
||||
impl<T> DatagramSink for T where T: Sink<SinkItem, Error = SinkError> + Send + Sync {}
|
||||
|
||||
#[auto_impl::auto_impl(Box, Arc)]
|
||||
pub trait Tunnel: Send + Sync {
|
||||
fn stream(&self) -> Box<dyn DatagramStream>;
|
||||
fn sink(&self) -> Box<dyn DatagramSink>;
|
||||
|
||||
fn pin_stream(&self) -> Pin<Box<dyn DatagramStream>> {
|
||||
Box::into_pin(self.stream())
|
||||
}
|
||||
|
||||
fn pin_sink(&self) -> Pin<Box<dyn DatagramSink>> {
|
||||
Box::into_pin(self.sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo>;
|
||||
}
|
||||
|
||||
pub async fn close_tunnel(t: &Box<dyn Tunnel>) -> Result<(), TunnelError> {
|
||||
t.pin_sink().close().await
|
||||
}
|
||||
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait TunnelConnCounter: 'static + Send + Sync + Debug {
|
||||
fn get(&self) -> u32;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[auto_impl::auto_impl(Box)]
|
||||
pub trait TunnelListener: Send + Sync {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError>;
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
|
||||
fn local_url(&self) -> url::Url;
|
||||
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
|
||||
#[derive(Debug)]
|
||||
struct FakeTunnelConnCounter {}
|
||||
impl TunnelConnCounter for FakeTunnelConnCounter {
|
||||
fn get(&self) -> u32 {
|
||||
0
|
||||
}
|
||||
}
|
||||
Arc::new(Box::new(FakeTunnelConnCounter {}))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[auto_impl::auto_impl(Box)]
|
||||
pub trait TunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
|
||||
fn remote_url(&self) -> url::Url;
|
||||
fn set_bind_addrs(&mut self, _addrs: Vec<SocketAddr>) {}
|
||||
}
|
||||
|
||||
pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url {
|
||||
url::Url::parse(format!("{}://{}", scheme, addr).as_str()).unwrap()
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn Tunnel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Tunnel")
|
||||
.field("info", &self.info())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn TunnelConnector + Sync + Send {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TunnelConnector")
|
||||
.field("remote_url", &self.remote_url())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn TunnelListener {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TunnelListener")
|
||||
.field("local_url", &self.local_url())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait FromUrl {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError>
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
pub(crate) fn check_scheme_and_get_socket_addr<T>(
|
||||
url: &url::Url,
|
||||
scheme: &str,
|
||||
) -> Result<T, TunnelError>
|
||||
where
|
||||
T: FromUrl,
|
||||
{
|
||||
if url.scheme() != scheme {
|
||||
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
|
||||
}
|
||||
|
||||
Ok(T::from_url(url.clone())?)
|
||||
}
|
||||
|
||||
impl FromUrl for SocketAddr {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
|
||||
Ok(url.socket_addrs(|| None)?.pop().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromUrl for uuid::Uuid {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
|
||||
let o = url.host_str().unwrap();
|
||||
let o = uuid::Uuid::parse_str(o).map_err(|e| TunnelError::InvalidAddr(e.to_string()))?;
|
||||
Ok(o)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,391 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
task::Poll,
|
||||
};
|
||||
|
||||
use async_stream::stream;
|
||||
use crossbeam_queue::ArrayQueue;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Sink;
|
||||
use once_cell::sync::Lazy;
|
||||
use tokio::sync::{Mutex, Notify};
|
||||
|
||||
use futures::FutureExt;
|
||||
use tokio_util::bytes::BytesMut;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::tunnels::{SinkError, SinkItem};
|
||||
|
||||
use super::{
|
||||
build_url_from_socket_addr, check_scheme_and_get_socket_addr, DatagramSink, DatagramStream,
|
||||
Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
|
||||
};
|
||||
|
||||
static RING_TUNNEL_CAP: usize = 1000;
|
||||
|
||||
pub struct RingTunnel {
|
||||
id: Uuid,
|
||||
ring: Arc<ArrayQueue<SinkItem>>,
|
||||
consume_notify: Arc<Notify>,
|
||||
produce_notify: Arc<Notify>,
|
||||
closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl RingTunnel {
|
||||
pub fn new(cap: usize) -> Self {
|
||||
RingTunnel {
|
||||
id: Uuid::new_v4(),
|
||||
ring: Arc::new(ArrayQueue::new(cap)),
|
||||
consume_notify: Arc::new(Notify::new()),
|
||||
produce_notify: Arc::new(Notify::new()),
|
||||
closed: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_id(id: Uuid, cap: usize) -> Self {
|
||||
let mut ret = Self::new(cap);
|
||||
ret.id = id;
|
||||
ret
|
||||
}
|
||||
|
||||
fn recv_stream(&self) -> impl DatagramStream {
|
||||
let ring = self.ring.clone();
|
||||
let produce_notify = self.produce_notify.clone();
|
||||
let consume_notify = self.consume_notify.clone();
|
||||
let closed = self.closed.clone();
|
||||
let id = self.id;
|
||||
stream! {
|
||||
loop {
|
||||
if closed.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
log::warn!("ring recv tunnel {:?} closed", id);
|
||||
yield Err(TunnelError::CommonError("Closed".to_owned()));
|
||||
}
|
||||
match ring.pop() {
|
||||
Some(v) => {
|
||||
let mut out = BytesMut::new();
|
||||
out.extend_from_slice(&v);
|
||||
consume_notify.notify_one();
|
||||
log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v);
|
||||
yield Ok(out);
|
||||
},
|
||||
None => {
|
||||
log::trace!("waiting recv buffer, id: {}", id);
|
||||
produce_notify.notified().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn send_sink(&self) -> impl DatagramSink {
|
||||
let ring = self.ring.clone();
|
||||
let produce_notify = self.produce_notify.clone();
|
||||
let consume_notify = self.consume_notify.clone();
|
||||
let closed = self.closed.clone();
|
||||
let id = self.id;
|
||||
|
||||
// type T = RingTunnel;
|
||||
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
struct T {
|
||||
ring: RingTunnel,
|
||||
wait_consume_task: Option<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl T {
|
||||
fn wait_ring_consume(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
expected_size: usize,
|
||||
) -> std::task::Poll<()> {
|
||||
let self_mut = self.get_mut();
|
||||
if self_mut.ring.ring.len() <= expected_size {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
if self_mut.wait_consume_task.is_none() {
|
||||
let id = self_mut.ring.id;
|
||||
let consume_notify = self_mut.ring.consume_notify.clone();
|
||||
let ring = self_mut.ring.ring.clone();
|
||||
let task = async move {
|
||||
log::trace!(
|
||||
"waiting ring consume done, expected_size: {}, id: {}",
|
||||
expected_size,
|
||||
id
|
||||
);
|
||||
while ring.len() > expected_size {
|
||||
consume_notify.notified().await;
|
||||
}
|
||||
log::trace!(
|
||||
"ring consume done, expected_size: {}, id: {}",
|
||||
expected_size,
|
||||
id
|
||||
);
|
||||
};
|
||||
self_mut.wait_consume_task = Some(tokio::spawn(task));
|
||||
}
|
||||
let task = self_mut.wait_consume_task.as_mut().unwrap();
|
||||
match task.poll_unpin(cx) {
|
||||
Poll::Ready(_) => {
|
||||
self_mut.wait_consume_task = None;
|
||||
Poll::Ready(())
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<SinkItem> for T {
|
||||
type Error = SinkError;
|
||||
|
||||
fn poll_ready(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
let expected_size = self.ring.ring.capacity() - 1;
|
||||
match self.wait_ring_consume(cx, expected_size) {
|
||||
Poll::Ready(_) => Poll::Ready(Ok(())),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
item: SinkItem,
|
||||
) -> Result<(), Self::Error> {
|
||||
log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item);
|
||||
self.ring.ring.push(item).unwrap();
|
||||
self.ring.produce_notify.notify_one();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
self.ring
|
||||
.closed
|
||||
.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||
log::warn!("ring tunnel send {:?} closed", self.ring.id);
|
||||
self.ring.produce_notify.notify_one();
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
T {
|
||||
ring: RingTunnel {
|
||||
id,
|
||||
ring,
|
||||
consume_notify,
|
||||
produce_notify,
|
||||
closed,
|
||||
},
|
||||
wait_consume_task: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Connection {
|
||||
client: RingTunnel,
|
||||
server: RingTunnel,
|
||||
connect_notify: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl Tunnel for RingTunnel {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
None
|
||||
// Some(TunnelInfo {
|
||||
// tunnel_type: "ring".to_owned(),
|
||||
// local_addr: format!("ring://{}", self.id),
|
||||
// remote_addr: format!("ring://{}", self.id),
|
||||
// })
|
||||
}
|
||||
}
|
||||
|
||||
static CONNECTION_MAP: Lazy<Arc<Mutex<HashMap<uuid::Uuid, Arc<Connection>>>>> =
|
||||
Lazy::new(|| Arc::new(Mutex::new(HashMap::new())));
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RingTunnelListener {
|
||||
listerner_addr: url::Url,
|
||||
}
|
||||
|
||||
impl RingTunnelListener {
|
||||
pub fn new(key: url::Url) -> Self {
|
||||
RingTunnelListener {
|
||||
listerner_addr: key,
|
||||
}
|
||||
}
|
||||
}
|
||||
struct ConnectionForServer {
|
||||
conn: Arc<Connection>,
|
||||
}
|
||||
|
||||
impl Tunnel for ConnectionForServer {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.conn.server.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.conn.client.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
Some(TunnelInfo {
|
||||
tunnel_type: "ring".to_owned(),
|
||||
local_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
|
||||
remote_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct ConnectionForClient {
|
||||
conn: Arc<Connection>,
|
||||
}
|
||||
|
||||
impl Tunnel for ConnectionForClient {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.conn.client.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.conn.server.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
Some(TunnelInfo {
|
||||
tunnel_type: "ring".to_owned(),
|
||||
local_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
|
||||
remote_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl RingTunnelListener {
|
||||
async fn add_connection(listener_addr: uuid::Uuid) {
|
||||
CONNECTION_MAP.lock().await.insert(
|
||||
listener_addr.clone(),
|
||||
Arc::new(Connection {
|
||||
client: RingTunnel::new(RING_TUNNEL_CAP),
|
||||
server: RingTunnel::new_with_id(listener_addr.clone(), RING_TUNNEL_CAP),
|
||||
connect_notify: Arc::new(Notify::new()),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
|
||||
check_scheme_and_get_socket_addr::<Uuid>(&self.listerner_addr, "ring")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for RingTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
log::info!("listen new conn of key: {}", self.listerner_addr);
|
||||
Self::add_connection(self.get_addr()?).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
log::info!("waiting accept new conn of key: {}", self.listerner_addr);
|
||||
let val = CONNECTION_MAP
|
||||
.lock()
|
||||
.await
|
||||
.get(&self.get_addr()?)
|
||||
.unwrap()
|
||||
.clone();
|
||||
val.connect_notify.notified().await;
|
||||
CONNECTION_MAP.lock().await.remove(&self.get_addr()?);
|
||||
Self::add_connection(self.get_addr()?).await;
|
||||
log::info!("accept new conn of key: {}", self.listerner_addr);
|
||||
Ok(Box::new(ConnectionForServer { conn: val }))
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.listerner_addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RingTunnelConnector {
|
||||
remote_addr: url::Url,
|
||||
}
|
||||
|
||||
impl RingTunnelConnector {
|
||||
pub fn new(remote_addr: url::Url) -> Self {
|
||||
RingTunnelConnector { remote_addr }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelConnector for RingTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let val = CONNECTION_MAP
|
||||
.lock()
|
||||
.await
|
||||
.get(&check_scheme_and_get_socket_addr::<Uuid>(
|
||||
&self.remote_addr,
|
||||
"ring",
|
||||
)?)
|
||||
.unwrap()
|
||||
.clone();
|
||||
val.connect_notify.notify_one();
|
||||
log::info!("connecting");
|
||||
Ok(Box::new(ConnectionForClient { conn: val }))
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.remote_addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
|
||||
let conn = Arc::new(Connection {
|
||||
client: RingTunnel::new(RING_TUNNEL_CAP),
|
||||
server: RingTunnel::new(RING_TUNNEL_CAP),
|
||||
connect_notify: Arc::new(Notify::new()),
|
||||
});
|
||||
(
|
||||
Box::new(ConnectionForServer { conn: conn.clone() }),
|
||||
Box::new(ConnectionForClient { conn: conn }),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn ring_pingpong() {
|
||||
let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
|
||||
let listener = RingTunnelListener::new(id.clone());
|
||||
let connector = RingTunnelConnector::new(id.clone());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ring_bench() {
|
||||
let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
|
||||
let listener = RingTunnelListener::new(id.clone());
|
||||
let connector = RingTunnelConnector::new(id);
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
use std::sync::atomic::{AtomicU32, AtomicU64};
|
||||
|
||||
pub struct WindowLatency {
|
||||
latency_us_window: Vec<AtomicU64>,
|
||||
latency_us_window_index: AtomicU32,
|
||||
latency_us_window_size: AtomicU32,
|
||||
}
|
||||
|
||||
impl WindowLatency {
|
||||
pub fn new(window_size: u32) -> Self {
|
||||
Self {
|
||||
latency_us_window: (0..window_size).map(|_| AtomicU64::new(0)).collect(),
|
||||
latency_us_window_index: AtomicU32::new(0),
|
||||
latency_us_window_size: AtomicU32::new(window_size),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn record_latency(&self, latency_us: u64) {
|
||||
let index = self
|
||||
.latency_us_window_index
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
let index = index
|
||||
% self
|
||||
.latency_us_window_size
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
self.latency_us_window[index as usize]
|
||||
.store(latency_us, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn get_latency_us(&self) -> u64 {
|
||||
let window_size = self
|
||||
.latency_us_window_size
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let mut sum = 0;
|
||||
let mut count = 0;
|
||||
for i in 0..window_size {
|
||||
let latency_us =
|
||||
self.latency_us_window[i as usize].load(std::sync::atomic::Ordering::Relaxed);
|
||||
if latency_us > 0 {
|
||||
sum += latency_us;
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
0
|
||||
} else {
|
||||
sum / count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Throughput {
|
||||
tx_bytes: AtomicU64,
|
||||
rx_bytes: AtomicU64,
|
||||
|
||||
tx_packets: AtomicU64,
|
||||
rx_packets: AtomicU64,
|
||||
}
|
||||
|
||||
impl Throughput {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tx_bytes: AtomicU64::new(0),
|
||||
rx_bytes: AtomicU64::new(0),
|
||||
|
||||
tx_packets: AtomicU64::new(0),
|
||||
rx_packets: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tx_bytes(&self) -> u64 {
|
||||
self.tx_bytes.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn rx_bytes(&self) -> u64 {
|
||||
self.rx_bytes.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn tx_packets(&self) -> u64 {
|
||||
self.tx_packets.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn rx_packets(&self) -> u64 {
|
||||
self.rx_packets.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn record_tx_bytes(&self, bytes: u64) {
|
||||
self.tx_bytes
|
||||
.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
|
||||
self.tx_packets
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn record_rx_bytes(&self, bytes: u64) {
|
||||
self.rx_bytes
|
||||
.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
|
||||
self.rx_packets
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,284 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
|
||||
|
||||
use super::{
|
||||
check_scheme_and_get_socket_addr, common::FramedTunnel, Tunnel, TunnelInfo, TunnelListener,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TcpTunnelListener {
|
||||
addr: url::Url,
|
||||
listener: Option<TcpListener>,
|
||||
}
|
||||
|
||||
impl TcpTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
TcpTunnelListener {
|
||||
addr,
|
||||
listener: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for TcpTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
|
||||
let socket = if addr.is_ipv4() {
|
||||
TcpSocket::new_v4()?
|
||||
} else {
|
||||
TcpSocket::new_v6()?
|
||||
};
|
||||
|
||||
socket.set_reuseaddr(true)?;
|
||||
#[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||
socket.set_reuseport(true)?;
|
||||
socket.bind(addr)?;
|
||||
|
||||
self.listener = Some(socket.listen(1024)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let listener = self.listener.as_ref().unwrap();
|
||||
let (stream, _) = listener.accept().await?;
|
||||
stream.set_nodelay(true).unwrap();
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "tcp".to_owned(),
|
||||
local_addr: self.local_url().into(),
|
||||
remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp")
|
||||
.into(),
|
||||
};
|
||||
|
||||
let (r, w) = tokio::io::split(stream);
|
||||
Ok(FramedTunnel::new_tunnel_with_info(
|
||||
FramedRead::new(r, LengthDelimitedCodec::new()),
|
||||
FramedWrite::new(w, LengthDelimitedCodec::new()),
|
||||
info,
|
||||
))
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tunnel_with_tcp_stream(
|
||||
stream: TcpStream,
|
||||
remote_url: url::Url,
|
||||
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
stream.set_nodelay(true).unwrap();
|
||||
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "tcp".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp")
|
||||
.into(),
|
||||
remote_addr: remote_url.into(),
|
||||
};
|
||||
|
||||
let (r, w) = tokio::io::split(stream);
|
||||
Ok(Box::new(FramedTunnel::new_tunnel_with_info(
|
||||
FramedRead::new(r, LengthDelimitedCodec::new()),
|
||||
FramedWrite::new(w, LengthDelimitedCodec::new()),
|
||||
info,
|
||||
)))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TcpTunnelConnector {
|
||||
addr: url::Url,
|
||||
|
||||
bind_addrs: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl TcpTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
TcpTunnelConnector {
|
||||
addr,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
tracing::info!(addr = ?self.addr, "connect tcp start");
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
let stream = TcpStream::connect(addr).await?;
|
||||
tracing::info!(addr = ?self.addr, "connect tcp succ");
|
||||
return get_tunnel_with_tcp_stream(stream, self.addr.clone().into());
|
||||
}
|
||||
|
||||
async fn connect_with_custom_bind(
|
||||
&mut self,
|
||||
is_ipv4: bool,
|
||||
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
let socket = if is_ipv4 {
|
||||
TcpSocket::new_v4()?
|
||||
} else {
|
||||
TcpSocket::new_v6()?
|
||||
};
|
||||
socket.set_reuseaddr(true)?;
|
||||
|
||||
#[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||
socket.set_reuseport(true)?;
|
||||
|
||||
socket.bind(*bind_addr)?;
|
||||
// linux does not use interface of bind_addr to send packet, so we need to bind device
|
||||
// mac can handle this with bind correctly
|
||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
||||
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
|
||||
tracing::trace!(dev_name = ?dev_name, "bind device");
|
||||
socket.bind_device(Some(dev_name.as_bytes()))?;
|
||||
}
|
||||
futures.push(socket.connect(dst_addr.clone()));
|
||||
}
|
||||
|
||||
let Some(ret) = futures.next().await else {
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"join connect futures failed".to_owned(),
|
||||
));
|
||||
};
|
||||
|
||||
return get_tunnel_with_tcp_stream(ret?, self.addr.clone().into());
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for TcpTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
if self.bind_addrs.is_empty() {
|
||||
self.connect_with_default_bind().await
|
||||
} else if self.bind_addrs[0].is_ipv4() {
|
||||
self.connect_with_custom_bind(true).await
|
||||
} else {
|
||||
self.connect_with_custom_bind(false).await
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::SinkExt;
|
||||
|
||||
use crate::tunnels::{
|
||||
common::tests::{_tunnel_bench, _tunnel_pingpong},
|
||||
TunnelConnector,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_pingpong() {
|
||||
let listener = TcpTunnelListener::new("tcp://0.0.0.0:11011".parse().unwrap());
|
||||
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11011".parse().unwrap());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_bench() {
|
||||
let listener = TcpTunnelListener::new("tcp://0.0.0.0:11012".parse().unwrap());
|
||||
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11012".parse().unwrap());
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_bench_with_bind() {
|
||||
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap());
|
||||
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic]
|
||||
async fn tcp_bench_with_bind_fail() {
|
||||
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
// test slow send lock in framed tunnel
|
||||
#[tokio::test]
|
||||
async fn tcp_multiple_sender_and_slow_receiver() {
|
||||
// console_subscriber::init();
|
||||
let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
|
||||
listener.listen().await.unwrap();
|
||||
let t1 = tokio::spawn(async move {
|
||||
let t = listener.accept().await.unwrap();
|
||||
let mut stream = t.pin_stream();
|
||||
|
||||
let now = tokio::time::Instant::now();
|
||||
|
||||
while let Some(Ok(_)) = stream.next().await {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
if now.elapsed().as_secs() > 5 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("t1 exit");
|
||||
});
|
||||
|
||||
let tunnel = connector.connect().await.unwrap();
|
||||
let mut sink1 = tunnel.pin_sink();
|
||||
let t2 = tokio::spawn(async move {
|
||||
for i in 0..1000000 {
|
||||
let a = sink1.send(b"hello".to_vec().into()).await;
|
||||
if a.is_err() {
|
||||
tracing::info!(?a, "t2 exit with err");
|
||||
break;
|
||||
}
|
||||
|
||||
if i % 5000 == 0 {
|
||||
tracing::info!(i, "send2 1000");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("t2 exit");
|
||||
});
|
||||
|
||||
let mut sink2 = tunnel.pin_sink();
|
||||
let t3 = tokio::spawn(async move {
|
||||
for i in 0..1000000 {
|
||||
let a = sink2.send(b"hello".to_vec().into()).await;
|
||||
if a.is_err() {
|
||||
tracing::info!(?a, "t3 exit with err");
|
||||
break;
|
||||
}
|
||||
|
||||
if i % 5000 == 0 {
|
||||
tracing::info!(i, "send2 1000");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("t3 exit");
|
||||
});
|
||||
|
||||
let t4 = tokio::spawn(async move {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
tracing::info!("closing");
|
||||
let close_ret = tunnel.pin_sink().close().await;
|
||||
tracing::info!("closed {:?}", close_ret);
|
||||
});
|
||||
|
||||
let _ = tokio::join!(t1, t2, t3, t4);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,228 @@
|
||||
use std::{
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use easytier_rpc::TunnelInfo;
|
||||
use futures::{Sink, SinkExt, Stream, StreamExt};
|
||||
|
||||
use self::stats::Throughput;
|
||||
|
||||
use super::*;
|
||||
use crate::tunnels::{DatagramSink, DatagramStream, SinkError, SinkItem, StreamItem, Tunnel};
|
||||
|
||||
pub trait TunnelFilter {
|
||||
fn before_send(&self, data: SinkItem) -> Result<SinkItem, SinkError>;
|
||||
fn after_received(&self, data: StreamItem) -> Result<BytesMut, TunnelError>;
|
||||
}
|
||||
|
||||
pub struct TunnelWithFilter<T, F> {
|
||||
inner: T,
|
||||
filter: Arc<F>,
|
||||
}
|
||||
|
||||
impl<T, F> Tunnel for TunnelWithFilter<T, F>
|
||||
where
|
||||
T: Tunnel + Send + Sync + 'static,
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
struct SinkWrapper<F> {
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
filter: Arc<F>,
|
||||
}
|
||||
impl<F> Sink<SinkItem> for SinkWrapper<F>
|
||||
where
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
type Error = SinkError;
|
||||
|
||||
fn poll_ready(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().sink.poll_ready_unpin(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
let item = self.filter.before_send(item)?;
|
||||
self.get_mut().sink.start_send_unpin(item)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().sink.poll_flush_unpin(cx)
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().sink.poll_close_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
Box::new(SinkWrapper {
|
||||
sink: self.inner.pin_sink(),
|
||||
filter: self.filter.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
struct StreamWrapper<F> {
|
||||
stream: Pin<Box<dyn DatagramStream>>,
|
||||
filter: Arc<F>,
|
||||
}
|
||||
impl<F> Stream for StreamWrapper<F>
|
||||
where
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
type Item = StreamItem;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let self_mut = self.get_mut();
|
||||
match self_mut.stream.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(ret)) => {
|
||||
Poll::Ready(Some(self_mut.filter.after_received(ret)))
|
||||
}
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Box::new(StreamWrapper {
|
||||
stream: self.inner.pin_stream(),
|
||||
filter: self.filter.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
self.inner.info()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, F> TunnelWithFilter<T, F>
|
||||
where
|
||||
T: Tunnel + Send + Sync + 'static,
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
pub fn new(inner: T, filter: Arc<F>) -> Self {
|
||||
Self { inner, filter }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PacketRecorderTunnelFilter {
|
||||
pub received: Arc<std::sync::Mutex<Vec<Bytes>>>,
|
||||
pub sent: Arc<std::sync::Mutex<Vec<Bytes>>>,
|
||||
}
|
||||
|
||||
impl TunnelFilter for PacketRecorderTunnelFilter {
|
||||
fn before_send(&self, data: SinkItem) -> Result<SinkItem, SinkError> {
|
||||
self.received.lock().unwrap().push(data.clone());
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
fn after_received(&self, data: StreamItem) -> Result<BytesMut, TunnelError> {
|
||||
match data {
|
||||
Ok(v) => {
|
||||
self.sent.lock().unwrap().push(v.clone().into());
|
||||
Ok(v)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketRecorderTunnelFilter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
received: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
sent: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StatsRecorderTunnelFilter {
|
||||
throughput: Arc<Throughput>,
|
||||
}
|
||||
|
||||
impl TunnelFilter for StatsRecorderTunnelFilter {
|
||||
fn before_send(&self, data: SinkItem) -> Result<SinkItem, SinkError> {
|
||||
self.throughput.record_tx_bytes(data.len() as u64);
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
fn after_received(&self, data: StreamItem) -> Result<BytesMut, TunnelError> {
|
||||
match data {
|
||||
Ok(v) => {
|
||||
self.throughput.record_rx_bytes(v.len() as u64);
|
||||
Ok(v)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StatsRecorderTunnelFilter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
throughput: Arc::new(Throughput::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_throughput(&self) -> Arc<Throughput> {
|
||||
self.throughput.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! define_tunnel_filter_chain {
|
||||
($type_name:ident $(, $field_name:ident = $filter_type:ty)+) => (
|
||||
pub struct $type_name {
|
||||
$($field_name: std::sync::Arc<$filter_type>,)+
|
||||
}
|
||||
|
||||
impl $type_name {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
$($field_name: std::sync::Arc::new(<$filter_type>::new()),)+
|
||||
}
|
||||
}
|
||||
|
||||
pub fn wrap_tunnel(&self, tunnel: impl Tunnel + 'static) -> impl Tunnel {
|
||||
$(
|
||||
let tunnel = crate::tunnels::tunnel_filter::TunnelWithFilter::new(tunnel, self.$field_name.clone());
|
||||
)+
|
||||
tunnel
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tunnels::ring_tunnel::RingTunnel;
|
||||
#[tokio::test]
|
||||
async fn test_nested_filter() {
|
||||
define_tunnel_filter_chain!(
|
||||
Filter,
|
||||
a = PacketRecorderTunnelFilter,
|
||||
b = PacketRecorderTunnelFilter,
|
||||
c = PacketRecorderTunnelFilter
|
||||
);
|
||||
|
||||
let filter = Filter::new();
|
||||
let tunnel = filter.wrap_tunnel(RingTunnel::new(1));
|
||||
|
||||
let mut s = tunnel.pin_sink();
|
||||
s.send(Bytes::from("hello")).await.unwrap();
|
||||
|
||||
assert_eq!(1, filter.a.received.lock().unwrap().len());
|
||||
assert_eq!(1, filter.b.received.lock().unwrap().len());
|
||||
assert_eq!(1, filter.c.received.lock().unwrap().len());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,574 @@
|
||||
use std::{fmt::Debug, pin::Pin, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use easytier_rpc::TunnelInfo;
|
||||
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
|
||||
use rkyv::{Archive, Deserialize, Serialize};
|
||||
use std::net::SocketAddr;
|
||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
use tokio_util::{
|
||||
bytes::{Buf, Bytes, BytesMut},
|
||||
udp::UdpFramed,
|
||||
};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::rkyv_util::{self, encode_to_bytes},
|
||||
tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
|
||||
};
|
||||
|
||||
use super::{
|
||||
codec::BytesCodec,
|
||||
common::{FramedTunnel, TunnelWithCustomInfo},
|
||||
ring_tunnel::create_ring_tunnel_pair,
|
||||
DatagramSink, DatagramStream, Tunnel, TunnelListener,
|
||||
};
|
||||
|
||||
pub const UDP_DATA_MTU: usize = 2500;
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub enum UdpPacketPayload {
|
||||
Syn,
|
||||
Sack,
|
||||
HolePunch(Vec<u8>),
|
||||
Data(Vec<u8>),
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub struct UdpPacket {
|
||||
pub conn_id: u32,
|
||||
pub payload: UdpPacketPayload,
|
||||
}
|
||||
|
||||
impl UdpPacket {
|
||||
pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
|
||||
Self {
|
||||
conn_id,
|
||||
payload: UdpPacketPayload::Data(data),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
|
||||
Self {
|
||||
conn_id: 0,
|
||||
payload: UdpPacketPayload::HolePunch(data),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_syn_packet(conn_id: u32) -> Self {
|
||||
Self {
|
||||
conn_id,
|
||||
payload: UdpPacketPayload::Syn,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_sack_packet(conn_id: u32) -> Self {
|
||||
Self {
|
||||
conn_id,
|
||||
payload: UdpPacketPayload::Sack,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
|
||||
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf) else {
|
||||
tracing::warn!(?buf, "udp decode error");
|
||||
return None;
|
||||
};
|
||||
|
||||
if udp_packet.conn_id != conn_id.clone() {
|
||||
tracing::warn!(?udp_packet, ?conn_id, "udp conn id not match");
|
||||
return None;
|
||||
}
|
||||
|
||||
let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else {
|
||||
tracing::warn!(?udp_packet, "udp payload not data");
|
||||
return None;
|
||||
};
|
||||
|
||||
let ptr_range = payload.as_ptr_range();
|
||||
let offset = ptr_range.start as usize - buf.as_ptr() as usize;
|
||||
let len = ptr_range.end as usize - ptr_range.start as usize;
|
||||
buf.advance(offset);
|
||||
buf.truncate(len);
|
||||
tracing::trace!(?offset, ?len, ?buf, "udp payload data");
|
||||
|
||||
Some(buf)
|
||||
}
|
||||
|
||||
fn get_tunnel_from_socket(
|
||||
socket: Arc<UdpSocket>,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Box<dyn super::Tunnel> {
|
||||
let udp = UdpFramed::new(socket.clone(), BytesCodec::new(UDP_DATA_MTU));
|
||||
let (sink, stream) = udp.split();
|
||||
|
||||
let recv_addr = addr;
|
||||
let stream = stream.filter_map(move |v| async move {
|
||||
tracing::trace!(?v, "udp stream recv something");
|
||||
if v.is_err() {
|
||||
tracing::warn!(?v, "udp stream error");
|
||||
return Some(Err(super::TunnelError::CommonError(
|
||||
"udp stream error".to_owned(),
|
||||
)));
|
||||
}
|
||||
|
||||
let (buf, addr) = v.unwrap();
|
||||
assert_eq!(addr, recv_addr.clone());
|
||||
Some(Ok(try_get_data_payload(buf, conn_id.clone())?))
|
||||
});
|
||||
let stream = Box::pin(stream);
|
||||
|
||||
let sender_addr = addr;
|
||||
let sink = Box::pin(sink.with(move |v: Bytes| async move {
|
||||
if false {
|
||||
return Err(super::TunnelError::CommonError("udp sink error".to_owned()));
|
||||
}
|
||||
|
||||
// TODO: two copy here, how to avoid?
|
||||
let udp_packet = UdpPacket::new_data_packet(conn_id, v.to_vec());
|
||||
tracing::trace!(?udp_packet, ?v, "udp send packet");
|
||||
let v = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
|
||||
|
||||
Ok((v, sender_addr))
|
||||
}));
|
||||
|
||||
FramedTunnel::new_tunnel_with_info(
|
||||
stream,
|
||||
sink,
|
||||
// TODO: this remote addr is not a url
|
||||
super::TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(
|
||||
&socket.local_addr().unwrap().to_string(),
|
||||
"udp",
|
||||
)
|
||||
.into(),
|
||||
remote_addr: super::build_url_from_socket_addr(&addr.to_string(), "udp").into(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
struct StreamSinkPair(
|
||||
Pin<Box<dyn DatagramStream>>,
|
||||
Pin<Box<dyn DatagramSink>>,
|
||||
u32,
|
||||
);
|
||||
type ArcStreamSinkPair = Arc<Mutex<StreamSinkPair>>;
|
||||
|
||||
pub struct UdpTunnelListener {
|
||||
addr: url::Url,
|
||||
socket: Option<Arc<UdpSocket>>,
|
||||
|
||||
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
||||
forward_tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
|
||||
conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>,
|
||||
conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>,
|
||||
}
|
||||
|
||||
impl UdpTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
let (conn_send, conn_recv) = tokio::sync::mpsc::channel(100);
|
||||
Self {
|
||||
addr,
|
||||
socket: None,
|
||||
sock_map: Arc::new(DashMap::new()),
|
||||
forward_tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||
conn_recv,
|
||||
conn_send: Some(conn_send),
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_forward_packet(
|
||||
sock_map: &DashMap<SocketAddr, ArcStreamSinkPair>,
|
||||
buf: BytesMut,
|
||||
addr: SocketAddr,
|
||||
) -> Result<(), super::TunnelError> {
|
||||
let entry = sock_map.get_mut(&addr);
|
||||
if entry.is_none() {
|
||||
log::warn!("udp forward packet: {:?}, {:?}, no entry", addr, buf);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::trace!("udp forward packet: {:?}, {:?}", addr, buf);
|
||||
let entry = entry.unwrap();
|
||||
let pair = entry.value().clone();
|
||||
drop(entry);
|
||||
|
||||
let Some(buf) = try_get_data_payload(buf, pair.lock().await.2) else {
|
||||
return Ok(());
|
||||
};
|
||||
pair.lock().await.1.send(buf.freeze()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_connect(
|
||||
socket: Arc<UdpSocket>,
|
||||
addr: SocketAddr,
|
||||
forward_tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
||||
local_url: url::Url,
|
||||
conn_id: u32,
|
||||
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
tracing::info!(?conn_id, ?addr, "udp connection accept handling",);
|
||||
|
||||
let udp_packet = UdpPacket::new_sack_packet(conn_id);
|
||||
let sack_buf = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
|
||||
socket.send_to(&sack_buf, addr).await?;
|
||||
|
||||
let (ctunnel, stunnel) = create_ring_tunnel_pair();
|
||||
let udp_tunnel = get_tunnel_from_socket(socket.clone(), addr, conn_id);
|
||||
let ss_pair = StreamSinkPair(ctunnel.pin_stream(), ctunnel.pin_sink(), conn_id);
|
||||
let addr_copy = addr.clone();
|
||||
sock_map.insert(addr, Arc::new(Mutex::new(ss_pair)));
|
||||
let ctunnel_stream = ctunnel.pin_stream();
|
||||
forward_tasks.lock().await.spawn(async move {
|
||||
let ret = ctunnel_stream
|
||||
.map(|v| {
|
||||
tracing::trace!(?v, "udp stream recv something in forward task");
|
||||
if v.is_err() {
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"udp stream error".to_owned(),
|
||||
));
|
||||
}
|
||||
Ok(v.unwrap().freeze())
|
||||
})
|
||||
.forward(udp_tunnel.pin_sink())
|
||||
.await;
|
||||
if let None = sock_map.remove(&addr_copy) {
|
||||
log::warn!("udp forward packet: {:?}, no entry", addr_copy);
|
||||
}
|
||||
close_tunnel(&ctunnel).await.unwrap();
|
||||
log::warn!("udp connection forward done: {:?}, {:?}", addr_copy, ret);
|
||||
});
|
||||
|
||||
Ok(Box::new(TunnelWithCustomInfo::new(
|
||||
stunnel,
|
||||
TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: local_url.into(),
|
||||
remote_addr: build_url_from_socket_addr(&addr.to_string(), "udp").into(),
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
pub fn get_socket(&self) -> Option<Arc<UdpSocket>> {
|
||||
self.socket.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for UdpTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
|
||||
self.socket = Some(Arc::new(UdpSocket::bind(addr).await?));
|
||||
|
||||
let socket = self.socket.as_ref().unwrap().clone();
|
||||
let forward_tasks = self.forward_tasks.clone();
|
||||
let sock_map = self.sock_map.clone();
|
||||
let conn_send = self.conn_send.take().unwrap();
|
||||
let local_url = self.local_url().clone();
|
||||
self.forward_tasks.lock().await.spawn(
|
||||
async move {
|
||||
loop {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.resize(2500, 0);
|
||||
let (_size, addr) = socket.recv_from(&mut buf).await.unwrap();
|
||||
let _ = buf.split_off(_size);
|
||||
log::trace!(
|
||||
"udp recv packet: {:?}, buf: {:?}, size: {}",
|
||||
addr,
|
||||
buf,
|
||||
_size
|
||||
);
|
||||
|
||||
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf)
|
||||
else {
|
||||
tracing::warn!(?buf, "udp decode error in forward task");
|
||||
continue;
|
||||
};
|
||||
|
||||
if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
|
||||
let conn = Self::handle_connect(
|
||||
socket.clone(),
|
||||
addr,
|
||||
forward_tasks.clone(),
|
||||
sock_map.clone(),
|
||||
local_url.clone(),
|
||||
udp_packet.conn_id.into(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
if let Err(e) = conn_send.send(conn).await {
|
||||
tracing::warn!(?e, "udp send conn to accept channel error");
|
||||
}
|
||||
} else {
|
||||
Self::try_forward_packet(sock_map.as_ref(), buf, addr)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("udp forward task", ?self.socket)),
|
||||
);
|
||||
|
||||
// let forward_tasks_clone = self.forward_tasks.clone();
|
||||
// tokio::spawn(async move {
|
||||
// loop {
|
||||
// let mut locked_forward_tasks = forward_tasks_clone.lock().await;
|
||||
// tokio::select! {
|
||||
// ret = locked_forward_tasks.join_next() => {
|
||||
// tracing::warn!(?ret, "udp forward task exit");
|
||||
// }
|
||||
// else => {
|
||||
// drop(locked_forward_tasks);
|
||||
// tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
// continue;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
log::info!("start udp accept: {:?}", self.addr);
|
||||
while let Some(conn) = self.conn_recv.recv().await {
|
||||
return Ok(conn);
|
||||
}
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"udp accept error".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
|
||||
struct UdpTunnelConnCounter {
|
||||
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
||||
}
|
||||
|
||||
impl Debug for UdpTunnelConnCounter {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("UdpTunnelConnCounter")
|
||||
.field("sock_map_len", &self.sock_map.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl TunnelConnCounter for UdpTunnelConnCounter {
|
||||
fn get(&self) -> u32 {
|
||||
self.sock_map.len() as u32
|
||||
}
|
||||
}
|
||||
|
||||
Arc::new(Box::new(UdpTunnelConnCounter {
|
||||
sock_map: self.sock_map.clone(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpTunnelConnector {
|
||||
addr: url::Url,
|
||||
bind_addrs: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl UdpTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
Self {
|
||||
addr,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_sack(
|
||||
socket: &UdpSocket,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Result<(), super::TunnelError> {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.resize(128, 0);
|
||||
|
||||
let (usize, recv_addr) = tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(3),
|
||||
socket.recv_from(&mut buf),
|
||||
)
|
||||
.await??;
|
||||
|
||||
if recv_addr != addr {
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, unexpected sack addr: {:?}, {:?}",
|
||||
recv_addr, addr
|
||||
)));
|
||||
}
|
||||
|
||||
let _ = buf.split_off(usize);
|
||||
|
||||
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf) else {
|
||||
tracing::warn!(?buf, "udp decode error in wait sack");
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, decode error. buf: {:?}",
|
||||
buf
|
||||
)));
|
||||
};
|
||||
|
||||
if conn_id != udp_packet.conn_id {
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, conn id not match. conn_id: {:?}, {:?}",
|
||||
conn_id, udp_packet.conn_id
|
||||
)));
|
||||
}
|
||||
|
||||
if !matches!(udp_packet.payload, ArchivedUdpPacketPayload::Sack) {
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, unexpected payload. payload: {:?}",
|
||||
udp_packet.payload
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_sack_loop(
|
||||
socket: &UdpSocket,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Result<(), super::TunnelError> {
|
||||
while let Err(err) = Self::wait_sack(socket, addr, conn_id).await {
|
||||
tracing::warn!(?err, "udp wait sack error");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn try_connect_with_socket(
|
||||
&self,
|
||||
socket: UdpSocket,
|
||||
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
|
||||
log::warn!("udp connect: {:?}", self.addr);
|
||||
|
||||
// send syn
|
||||
let conn_id = rand::random();
|
||||
let udp_packet = UdpPacket::new_syn_packet(conn_id);
|
||||
let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
|
||||
let ret = socket.send_to(&b, &addr).await?;
|
||||
tracing::warn!(?udp_packet, ?ret, "udp send syn");
|
||||
|
||||
// wait sack
|
||||
tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(3),
|
||||
Self::wait_sack_loop(&socket, addr, conn_id),
|
||||
)
|
||||
.await??;
|
||||
|
||||
// sack done
|
||||
let local_addr = socket.local_addr().unwrap().to_string();
|
||||
Ok(Box::new(TunnelWithCustomInfo::new(
|
||||
get_tunnel_from_socket(Arc::new(socket), addr, conn_id),
|
||||
TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&local_addr, "udp").into(),
|
||||
remote_addr: self.remote_url().into(),
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
return self.try_connect_with_socket(socket).await;
|
||||
}
|
||||
|
||||
async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
let socket = UdpSocket::bind(*bind_addr).await?;
|
||||
|
||||
// linux does not use interface of bind_addr to send packet, so we need to bind device
|
||||
// mac can handle this with bind correctly
|
||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
||||
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
|
||||
tracing::trace!(dev_name = ?dev_name, "bind device");
|
||||
socket.bind_device(Some(dev_name.as_bytes()))?;
|
||||
}
|
||||
|
||||
futures.push(self.try_connect_with_socket(socket));
|
||||
}
|
||||
|
||||
let Some(ret) = futures.next().await else {
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"join connect futures failed".to_owned(),
|
||||
));
|
||||
};
|
||||
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for UdpTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
if self.bind_addrs.is_empty() {
|
||||
self.connect_with_default_bind().await
|
||||
} else {
|
||||
self.connect_with_custom_bind().await
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_pingpong() {
|
||||
let listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap());
|
||||
let connector = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_bench() {
|
||||
let listener = UdpTunnelListener::new("udp://0.0.0.0:5555".parse().unwrap());
|
||||
let connector = UdpTunnelConnector::new("udp://127.0.0.1:5555".parse().unwrap());
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_bench_with_bind() {
|
||||
let listener = UdpTunnelListener::new("udp://127.0.0.1:5554".parse().unwrap());
|
||||
let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5554".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic]
|
||||
async fn udp_bench_with_bind_fail() {
|
||||
let listener = UdpTunnelListener::new("udp://127.0.0.1:5553".parse().unwrap());
|
||||
let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5553".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user