Initial Version

This commit is contained in:
sijie.sun
2023-09-23 01:53:45 +00:00
commit 9779923b87
63 changed files with 10840 additions and 0 deletions
+105
View File
@@ -0,0 +1,105 @@
[package]
name = "easytier-core"
version = "0.1.0"
edition = "2021"
authors = ["easytier"]
rust-version = "1.75"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[lib]
name = "easytier_rpc"
path = "src/rpc/lib.rs"

[dependencies]
# logging / diagnostics
tracing = { version = "0.1", features = ["log"] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing-appender = "0.2.3"
log = "0.4"
thiserror = "1.0"
auto_impl = "1.1.0"
crossbeam = "0.8.4"
gethostname = "0.4.3"
futures = "0.3"
tokio = { version = "1", features = ["full"] }
tokio-stream = "0.1"
tokio-util = { version = "0.7.9", features = ["codec", "net"] }
async-stream = "0.3.5"
async-trait = "0.1.74"
dashmap = "5.5.3"
timedmap = "1.0.1"
# for tap device
tun = { version = "0.6.1", features = ["async"] }
# for net ns
nix = { version = "0.27", features = ["sched", "socket", "ioctl"] }
uuid = { version = "1.5.0", features = [
    "v4",
    "fast-rng",
    "macro-diagnostics",
    "serde",
] }
# for ring tunnel
crossbeam-queue = "0.3"
once_cell = "1.18.0"
# for packet
# fix: the key was written as the quoted string `"version"`; bare `version`
# is the conventional TOML form used everywhere else in this manifest.
rkyv = { version = "0.7.42", features = ["validation", "archive_le"] }
# for rpc
tonic = "0.10"
prost = "0.12"
anyhow = "1.0"
tarpc = { version = "0.32", features = ["tokio1", "serde1"] }
bincode = "1.3"
url = "2.5.0"
# for tun packet
byteorder = "1.5.0"
# for proxy
cidr = "0.2.2"
socket2 = "0.5.5"
# for hole punching
stun-format = { git = "https://github.com/KKRainbow/stun-format.git", features = [
    "fmt",
    "rfc3489",
    "iana",
] }
rand = "0.8.5"

[dependencies.serde]
version = "1.0"
features = ["derive"]

[dependencies.pnet]
version = "0.34.0"
features = ["serde"]

[dependencies.clap]
version = "4.4"
features = ["derive"]

[dependencies.public-ip]
version = "0.2"
features = ["default"]

[build-dependencies]
tonic-build = "0.10"

[target.'cfg(windows)'.build-dependencies]
reqwest = { version = "0.11", features = ["blocking"] }
zip = "*"

[dev-dependencies]
serial_test = "*"
+95
View File
@@ -0,0 +1,95 @@
#[cfg(target_os = "windows")]
use std::{
env,
fs::File,
io::{copy, Cursor},
path::PathBuf,
};
#[cfg(target_os = "windows")]
struct WindowsBuild {}

#[cfg(target_os = "windows")]
impl WindowsBuild {
    /// Locate an existing `protoc` binary: honor the `PROTOC` env var first,
    /// then scan every directory on `PATH`.
    fn check_protoc_exist() -> Option<PathBuf> {
        if let Some(p) = env::var_os("PROTOC").map(PathBuf::from) {
            if p.exists() {
                return Some(p);
            }
        }
        let path = env::var_os("PATH").unwrap_or_default();
        for p in env::split_paths(&path) {
            let p = p.join("protoc");
            if p.exists() {
                return Some(p);
            }
        }
        None
    }

    /// Walk up from `OUT_DIR` to the `target/<profile>` directory, which is
    /// a stable place to cache the downloaded protoc between builds.
    fn get_cargo_target_dir() -> Result<std::path::PathBuf, Box<dyn std::error::Error>> {
        let out_dir = std::path::PathBuf::from(std::env::var("OUT_DIR")?);
        let profile = std::env::var("PROFILE")?;
        let mut target_dir = None;
        let mut sub_path = out_dir.as_path();
        while let Some(parent) = sub_path.parent() {
            if parent.ends_with(&profile) {
                target_dir = Some(parent);
                break;
            }
            sub_path = parent;
        }
        let target_dir = target_dir.ok_or("not found")?;
        Ok(target_dir.to_path_buf())
    }

    /// Download a protoc release zip and extract `bin/protoc.exe` into the
    /// cargo target dir; reuse the cached copy if it is already there.
    fn download_protoc() -> PathBuf {
        // fix: removed leftover debug prints (`{:?}` of "k" and of the raw
        // HTTP response) that polluted the build output.
        let out_dir = Self::get_cargo_target_dir().unwrap();
        let fname = out_dir.join("protoc");
        if fname.exists() {
            println!("cargo:info=use exist protoc: {:?}", fname);
            return fname;
        }
        println!("cargo:info=need download protoc, please wait...");
        let url = "https://github.com/protocolbuffers/protobuf/releases/download/v26.0-rc1/protoc-26.0-rc-1-win64.zip";
        let response = reqwest::blocking::get(url).unwrap();
        let body = response.bytes().unwrap().to_vec();
        let mut archive = zip::ZipArchive::new(Cursor::new(body)).unwrap();
        let mut protoc_zipped_file = archive.by_name("bin/protoc.exe").unwrap();
        copy(&mut protoc_zipped_file, &mut File::create(&fname).unwrap()).unwrap();
        fname
    }

    /// Windows-only build prep: extend the link search path and make sure a
    /// protoc binary is available via the `PROTOC` env var for tonic-build.
    pub fn check_for_win() {
        // add third_party dir to link search path
        println!("cargo:rustc-link-search=native=third_party/");
        let protoc_path = if let Some(o) = Self::check_protoc_exist() {
            println!("cargo:info=use os exist protoc: {:?}", o);
            o
        } else {
            Self::download_protoc()
        };
        std::env::set_var("PROTOC", protoc_path);
    }
}
/// Build-script entry point: ensure `protoc` is available (Windows only),
/// then generate the tonic RPC stubs from `proto/cli.proto`.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    #[cfg(target_os = "windows")]
    WindowsBuild::check_for_win();
    tonic_build::compile_protos("proto/cli.proto").map_err(Into::into)
}
+116
View File
@@ -0,0 +1,116 @@
// Control-plane RPC definitions for the easytier CLI.
syntax = "proto3";
package cli;

// Generic call status; a non-zero code indicates an error described by `message`.
message Status {
  int32 code = 1;
  string message = 2;
}

// Traffic counters and latency for one peer connection.
message PeerConnStats {
  uint64 rx_bytes = 1;
  uint64 tx_bytes = 2;
  uint64 rx_packets = 3;
  uint64 tx_packets = 4;
  uint64 latency_us = 5;
}

// Transport endpoints of the tunnel carrying a connection.
message TunnelInfo {
  string tunnel_type = 1;
  string local_addr = 2;
  string remote_addr = 3;
}

// One live connection between the local node and a peer.
message PeerConnInfo {
  string conn_id = 1;
  string my_node_id = 2;
  string peer_id = 3;
  repeated string features = 4;
  TunnelInfo tunnel = 5;
  PeerConnStats stats = 6;
}

// A peer together with all of its active connections.
message PeerInfo {
  string peer_id = 1;
  repeated PeerConnInfo conns = 2;
}

message ListPeerRequest {}
message ListPeerResponse {
  repeated PeerInfo peer_infos = 1;
}
// NAT classification detected via STUN probing.
enum NatType {
  Unknown = 0;
  OpenInternet = 1;
  // has NAT; but own a single public IP, port is not changed
  // (comment moved here — it describes NoPAT, not Unknown)
  NoPAT = 2;
  FullCone = 3;
  Restricted = 4;
  PortRestricted = 5;
  Symmetric = 6;
  SymUdpFirewall = 7;
}

// Latest STUN probing results for this node.
message StunInfo {
  NatType udp_nat_type = 1;
  NatType tcp_nat_type = 2;
  int64 last_update_time = 3;
}
// One entry of the mesh routing table.
message Route {
  string peer_id = 1;
  string ipv4_addr = 2;
  // Peer to forward through to reach `peer_id`.
  string next_hop_peer_id = 3;
  int32 cost = 4;
  // CIDRs proxied (shared) by that peer.
  repeated string proxy_cidrs = 5;
  string hostname = 6;
  StunInfo stun_info = 7;
}

message ListRouteRequest {}
message ListRouteResponse {
  repeated Route routes = 1;
}

// Read-only queries about peers and routes.
service PeerManageRpc {
  rpc ListPeer (ListPeerRequest) returns (ListPeerResponse);
  rpc ListRoute (ListRouteRequest) returns (ListRouteResponse);
}

enum ConnectorStatus {
  CONNECTED = 0;
  DISCONNECTED = 1;
  CONNECTING = 2;
}

// An outbound connector (by url) and its current state.
message Connector {
  string url = 1;
  ConnectorStatus status = 2;
}

message ListConnectorRequest {}
message ListConnectorResponse {
  repeated Connector connectors = 1;
}

enum ConnectorManageAction {
  ADD = 0;
  REMOVE = 1;
}

message ManageConnectorRequest {
  ConnectorManageAction action = 1;
  string url = 2;
}
message ManageConnectorResponse { }

// Add/remove/list outbound connectors at runtime.
service ConnectorManageRpc {
  rpc ListConnector (ListConnectorRequest) returns (ListConnectorResponse);
  rpc ManageConnector (ManageConnectorRequest) returns (ManageConnectorResponse);
}
+161
View File
@@ -0,0 +1,161 @@
// use filesystem as a config store
use std::{
ffi::OsStr,
io::Write,
path::{Path, PathBuf},
};
// Base directory used when the caller does not supply one.
static DEFAULT_BASE_DIR: &str = "/var/lib/easytier";
// When a key is itself a directory, its own value is stored in this file inside it.
static DIR_ROOT_CONFIG_FILE_NAME: &str = "__root__";

/// Filesystem-backed key/value store: each key maps to a file (or a
/// directory holding sub-keys) below `db_path`.
pub struct ConfigFs {
    // Kept for identification/debugging only; not read after construction.
    _db_name: String,
    // Root directory of this store: `<base dir>/<db name>`.
    db_path: PathBuf,
}
impl ConfigFs {
    /// Open (creating if needed) the store named `db_name` under the default base dir.
    pub fn new(db_name: &str) -> Self {
        Self::new_with_dir(db_name, DEFAULT_BASE_DIR)
    }

    /// Open (creating if needed) the store rooted at `dir/db_name`.
    ///
    /// Panics if the directory cannot be created.
    pub fn new_with_dir(db_name: &str, dir: &str) -> Self {
        let p = Path::new(OsStr::new(dir)).join(OsStr::new(db_name));
        std::fs::create_dir_all(&p).unwrap();
        ConfigFs {
            _db_name: db_name.to_string(),
            db_path: p,
        }
    }

    /// Read the value stored under `key`.
    ///
    /// A directory key reads its `__root__` file; a file key reads the file
    /// itself. Returns `ErrorKind::NotFound` when the key does not exist.
    pub fn get(&self, key: &str) -> Result<String, std::io::Error> {
        let path = self.db_path.join(OsStr::new(key));
        // if path is dir, read the DIR_ROOT_CONFIG_FILE_NAME in it
        if path.is_dir() {
            let path = path.join(OsStr::new(DIR_ROOT_CONFIG_FILE_NAME));
            std::fs::read_to_string(path)
        } else if path.is_file() {
            return std::fs::read_to_string(path);
        } else {
            return Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "key not found",
            ));
        }
    }

    /// List the sub-keys of a directory key, hiding the internal `__root__` entry.
    pub fn list_keys(&self, key: &str) -> Result<Vec<String>, std::io::Error> {
        let path = self.db_path.join(OsStr::new(key));
        let mut keys = Vec::new();
        for entry in std::fs::read_dir(path)? {
            let entry = entry?;
            let path = entry.path();
            let key = path.file_name().unwrap().to_str().unwrap().to_string();
            if key != DIR_ROOT_CONFIG_FILE_NAME {
                keys.push(key);
            }
        }
        Ok(keys)
    }

    /// Remove `key`. A directory key removes the whole subtree (including
    /// all sub-keys), a file key removes just the file.
    #[allow(dead_code)]
    pub fn remove(&self, key: &str) -> Result<(), std::io::Error> {
        let path = self.db_path.join(OsStr::new(key));
        // NOTE: for a directory key this deletes the entire directory tree,
        // not just its __root__ value file.
        if path.is_dir() {
            std::fs::remove_dir_all(path)
        } else if path.is_file() {
            return std::fs::remove_file(path);
        } else {
            return Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "key not found",
            ));
        }
    }

    /// Create `key` as a directory key and return a writable handle to its
    /// `__root__` value file (truncating any existing one).
    ///
    /// Fails with `AlreadyExists` if `key` already exists as a plain file.
    pub fn add_dir(&self, key: &str) -> Result<std::fs::File, std::io::Error> {
        let path = self.db_path.join(OsStr::new(key));
        // if path is dir, write the DIR_ROOT_CONFIG_FILE_NAME in it
        if path.is_file() {
            Err(std::io::Error::new(
                std::io::ErrorKind::AlreadyExists,
                "key already exists",
            ))
        } else {
            std::fs::create_dir_all(&path)?;
            return std::fs::File::create(path.join(OsStr::new(DIR_ROOT_CONFIG_FILE_NAME)));
        }
    }

    /// Create (or truncate) `key` as a file key, creating intermediate
    /// directories on first use, and return a writable handle.
    pub fn add_file(&self, key: &str) -> Result<std::fs::File, std::io::Error> {
        let path = self.db_path.join(OsStr::new(key));
        let base_dir = path.parent().unwrap();
        if !path.is_file() {
            std::fs::create_dir_all(base_dir)?;
        }
        std::fs::File::create(path)
    }

    /// Return the existing value of `key`, or lazily compute one with
    /// `val_fn`, persist it (as a dir key when `add_dir` is true, else as a
    /// file key), and return it. Non-NotFound errors are propagated.
    pub fn get_or_add<F>(
        &self,
        key: &str,
        val_fn: F,
        add_dir: bool,
    ) -> Result<String, std::io::Error>
    where
        F: FnOnce() -> String,
    {
        let get_ret = self.get(key);
        match get_ret {
            Ok(v) => Ok(v),
            Err(e) => {
                if e.kind() == std::io::ErrorKind::NotFound {
                    let val = val_fn();
                    if add_dir {
                        let mut f = self.add_dir(key)?;
                        f.write_all(val.as_bytes())?;
                    } else {
                        let mut f = self.add_file(key)?;
                        f.write_all(val.as_bytes())?;
                    }
                    Ok(val)
                } else {
                    Err(e)
                }
            }
        }
    }

    /// [`Self::get_or_add`] with the key persisted as a directory key.
    #[allow(dead_code)]
    pub fn get_or_add_dir<F>(&self, key: &str, val_fn: F) -> Result<String, std::io::Error>
    where
        F: FnOnce() -> String,
    {
        self.get_or_add(key, val_fn, true)
    }

    /// [`Self::get_or_add`] with the key persisted as a file key.
    pub fn get_or_add_file<F>(&self, key: &str, val_fn: F) -> Result<String, std::io::Error>
    where
        F: FnOnce() -> String,
    {
        self.get_or_add(key, val_fn, false)
    }

    /// Like `get`, but on NotFound returns `default()` WITHOUT persisting it.
    pub fn get_or_default<F>(&self, key: &str, default: F) -> Result<String, std::io::Error>
    where
        F: FnOnce() -> String,
    {
        let get_ret = self.get(key);
        match get_ret {
            Ok(v) => Ok(v),
            Err(e) => {
                if e.kind() == std::io::ErrorKind::NotFound {
                    Ok(default())
                } else {
                    Err(e)
                }
            }
        }
    }
}
+28
View File
@@ -0,0 +1,28 @@
// Service id the direct connector registers under.
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
// Seconds an entry stays in the direct-connector blacklist — presumably
// before a failed url is retried; TODO confirm at the usage site.
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 60;
// Seconds the collected local/public IP list is cached before re-collection
// (used by the IPCollector refresh loop).
pub const DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC: u64 = 60;

/// Declare a lazily-initialized global protected by a `tokio::sync::Mutex`.
macro_rules! define_global_var {
    ($name:ident, $type:ty, $init:expr) => {
        pub static $name: once_cell::sync::Lazy<tokio::sync::Mutex<$type>> =
            once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new($init));
    };
}

/// Read (a clone of) a global declared with `define_global_var!`.
/// Expands to an `.await`, so it must be used inside an async context.
#[macro_export]
macro_rules! use_global_var {
    ($name:ident) => {
        crate::common::constants::$name.lock().await.to_owned()
    };
}

/// Overwrite a global declared with `define_global_var!`.
/// Expands to an `.await`, so it must be used inside an async context.
#[macro_export]
macro_rules! set_global_var {
    ($name:ident, $val:expr) => {
        *crate::common::constants::$name.lock().await = $val
    };
}

define_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS, u64, 1000);

// Service id the UDP hole-punch connector registers under.
pub const UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID: u32 = 2;
+43
View File
@@ -0,0 +1,43 @@
use std::{io, result};
use thiserror::Error;
use crate::tunnels;
/// Crate-wide error type; variants wrap the failure sources used across
/// easytier-core (io, tun device, tunnels, rpc, timeouts, ...).
#[derive(Error, Debug)]
pub enum Error {
    #[error("io error")]
    IOError(#[from] io::Error),
    #[error("rust tun error {0}")]
    TunError(#[from] tun::Error),
    #[error("tunnel error {0}")]
    TunnelError(#[from] tunnels::TunnelError),
    // A peer is known but currently has no usable connection.
    #[error("Peer has no conn, PeerId: {0}")]
    PeerNoConnectionError(uuid::Uuid),
    #[error("RouteError: {0}")]
    RouteError(String),
    #[error("Not found")]
    NotFound,
    #[error("Invalid Url: {0}")]
    InvalidUrl(String),
    // A spawned shell command exited non-zero; payload is stdout+stderr.
    #[error("Shell Command error: {0}")]
    ShellCommandError(String),
    // #[error("Rpc listen error: {0}")]
    // RpcListenError(String),
    #[error("Rpc connect error: {0}")]
    RpcConnectError(String),
    #[error("Rpc error: {0}")]
    RpcClientError(#[from] tarpc::client::RpcError),
    #[error("Timeout error: {0}")]
    Timeout(#[from] tokio::time::error::Elapsed),
    #[error("url in blacklist")]
    UrlInBlacklist,
    #[error("unknown data store error")]
    Unknown,
    #[error("anyhow error: {0}")]
    AnyhowError(#[from] anyhow::Error),
}

/// Crate-wide result alias using [`Error`].
pub type Result<T> = result::Result<T, Error>;
+259
View File
@@ -0,0 +1,259 @@
use std::{io::Write, sync::Arc};
use crossbeam::atomic::AtomicCell;
use easytier_rpc::PeerConnInfo;
use super::{
config_fs::ConfigFs,
netns::NetNS,
network::IPCollector,
stun::{StunInfoCollector, StunInfoCollectorTrait},
};
/// Events broadcast on the global context's event bus when the peer set
/// or connection set changes.
#[derive(Debug, Clone, PartialEq)]
pub enum GlobalCtxEvent {
    PeerAdded,
    PeerRemoved,
    PeerConnAdded(PeerConnInfo),
    PeerConnRemoved(PeerConnInfo),
}

// Broadcast sender/receiver aliases for the global event bus.
type EventBus = tokio::sync::broadcast::Sender<GlobalCtxEvent>;
type EventBusSubscriber = tokio::sync::broadcast::Receiver<GlobalCtxEvent>;
/// Per-instance shared state: identity, persisted config, network
/// namespace, event bus and cached network facts.
pub struct GlobalCtx {
    pub inst_name: String,
    // Stable instance id, persisted in the config store.
    pub id: uuid::Uuid,
    pub config_fs: ConfigFs,
    pub net_ns: NetNS,
    // Fan-out channel for GlobalCtxEvent notifications.
    event_bus: EventBus,
    // Lazily-populated caches; AtomicCell allows updates through &self.
    cached_ipv4: AtomicCell<Option<std::net::Ipv4Addr>>,
    cached_proxy_cidrs: AtomicCell<Option<Vec<cidr::IpCidr>>>,
    ip_collector: Arc<IPCollector>,
    // NOTE(review): field name looks like a typo for `hostname`.
    hotname: AtomicCell<Option<String>>,
    stun_info_collection: Box<dyn StunInfoCollectorTrait>,
}

impl std::fmt::Debug for GlobalCtx {
    // Manual impl: several fields (event bus, trait object, collector) are not Debug.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GlobalCtx")
            .field("inst_name", &self.inst_name)
            .field("id", &self.id)
            .field("net_ns", &self.net_ns.name())
            .field("event_bus", &"EventBus")
            .field("ipv4", &self.cached_ipv4)
            .finish()
    }
}

/// Shared, reference-counted handle to [`GlobalCtx`].
pub type ArcGlobalCtx = std::sync::Arc<GlobalCtx>;
impl GlobalCtx {
    /// Build the context for instance `inst_name`.
    ///
    /// The instance id is loaded from the "inst_id" config key; a fresh
    /// UUIDv4 is generated and persisted on first run, so the identity
    /// survives restarts. Panics if the config store is unusable.
    pub fn new(inst_name: &str, config_fs: ConfigFs, net_ns: NetNS) -> Self {
        let id = config_fs
            .get_or_add_file("inst_id", || uuid::Uuid::new_v4().to_string())
            .unwrap();
        let id = uuid::Uuid::parse_str(&id).unwrap();
        let (event_bus, _) = tokio::sync::broadcast::channel(100);
        // NOTICE: we may need to choose stun server based on geo location
        // stun server cross nation may return a external ip address with high latency and loss rate
        let default_stun_servers = vec![
            "stun.miwifi.com:3478".to_string(),
            "stun.qq.com:3478".to_string(),
            "stun.chat.bilibili.com:3478".to_string(),
            "fwa.lifesizecloud.com:3478".to_string(),
            "stun.isp.net.au:3478".to_string(),
            "stun.nextcloud.com:3478".to_string(),
            "stun.freeswitch.org:3478".to_string(),
            "stun.voip.blackberry.com:3478".to_string(),
            "stunserver.stunprotocol.org:3478".to_string(),
            "stun.sipnet.com:3478".to_string(),
            "stun.radiojar.com:3478".to_string(),
            "stun.sonetel.com:3478".to_string(),
            "stun.voipgate.com:3478".to_string(),
            "stun.counterpath.com:3478".to_string(),
            "180.235.108.91:3478".to_string(),
            "193.22.2.248:3478".to_string(),
        ];
        GlobalCtx {
            inst_name: inst_name.to_string(),
            id,
            config_fs,
            net_ns: net_ns.clone(),
            event_bus,
            cached_ipv4: AtomicCell::new(None),
            cached_proxy_cidrs: AtomicCell::new(None),
            ip_collector: Arc::new(IPCollector::new(net_ns)),
            hotname: AtomicCell::new(None),
            stun_info_collection: Box::new(StunInfoCollector::new(default_stun_servers)),
        }
    }

    /// Obtain a new receiver on the event bus.
    pub fn subscribe(&self) -> EventBusSubscriber {
        self.event_bus.subscribe()
    }

    /// Broadcast `event` to all subscribers; logs a warning when nobody listens.
    // NOTE(review): a subscriber could drop between the receiver_count check
    // and `send`, which would make the `unwrap` panic — consider ignoring
    // the send error instead.
    pub fn issue_event(&self, event: GlobalCtxEvent) {
        if self.event_bus.receiver_count() != 0 {
            self.event_bus.send(event).unwrap();
        } else {
            log::warn!("No subscriber for event: {:?}", event);
        }
    }

    /// Virtual IPv4 of this node, read from the "ipv4" config key and
    /// cached. Returns None when unset or unparsable.
    pub fn get_ipv4(&self) -> Option<std::net::Ipv4Addr> {
        if let Some(ret) = self.cached_ipv4.load() {
            return Some(ret);
        }
        let Ok(addr) = self.config_fs.get("ipv4") else {
            return None;
        };
        let Ok(addr) = addr.parse() else {
            tracing::error!("invalid ipv4 addr: {}", addr);
            return None;
        };
        self.cached_ipv4.store(Some(addr));
        return Some(addr);
    }

    /// Persist a new virtual IPv4 and invalidate the cache so the next
    /// `get_ipv4` re-reads it from the config store.
    pub fn set_ipv4(&mut self, addr: std::net::Ipv4Addr) {
        self.config_fs
            .add_file("ipv4")
            .unwrap()
            .write_all(addr.to_string().as_bytes())
            .unwrap();
        self.cached_ipv4.store(None);
    }

    /// Register a CIDR this node proxies. '/' is escaped to '_' so the
    /// CIDR can be used as a file name under the "proxy_cidrs" dir key.
    pub fn add_proxy_cidr(&self, cidr: cidr::IpCidr) -> Result<(), std::io::Error> {
        let escaped_cidr = cidr.to_string().replace("/", "_");
        self.config_fs
            .add_file(&format!("proxy_cidrs/{}", escaped_cidr))?;
        self.cached_proxy_cidrs.store(None);
        Ok(())
    }

    /// Unregister a proxied CIDR and invalidate the cache.
    pub fn remove_proxy_cidr(&self, cidr: cidr::IpCidr) -> Result<(), std::io::Error> {
        let escaped_cidr = cidr.to_string().replace("/", "_");
        self.config_fs
            .remove(&format!("proxy_cidrs/{}", escaped_cidr))?;
        self.cached_proxy_cidrs.store(None);
        Ok(())
    }

    /// All CIDRs this node proxies, loaded from config and cached.
    // AtomicCell<Option<Vec<_>>> has no `load` for non-Copy payloads, hence
    // the take-then-store pattern; a concurrent caller arriving between the
    // two may observe an empty cache and re-read from disk.
    pub fn get_proxy_cidrs(&self) -> Vec<cidr::IpCidr> {
        if let Some(proxy_cidrs) = self.cached_proxy_cidrs.take() {
            self.cached_proxy_cidrs.store(Some(proxy_cidrs.clone()));
            return proxy_cidrs;
        }
        let Ok(keys) = self.config_fs.list_keys("proxy_cidrs") else {
            return vec![];
        };
        let mut ret = Vec::new();
        for key in keys.iter() {
            let key = key.replace("_", "/");
            let Ok(cidr) = key.parse() else {
                tracing::error!("invalid proxy cidr: {}", key);
                continue;
            };
            ret.push(cidr);
        }
        self.cached_proxy_cidrs.store(Some(ret.clone()));
        ret
    }

    pub fn get_ip_collector(&self) -> Arc<IPCollector> {
        self.ip_collector.clone()
    }

    /// OS hostname, fetched once and cached (same take/store pattern as above).
    pub fn get_hostname(&self) -> Option<String> {
        if let Some(hostname) = self.hotname.take() {
            self.hotname.store(Some(hostname.clone()));
            return Some(hostname);
        }
        let hostname = gethostname::gethostname().to_string_lossy().to_string();
        self.hotname.store(Some(hostname.clone()));
        return Some(hostname);
    }

    pub fn get_stun_info_collector(&self) -> impl StunInfoCollectorTrait + '_ {
        self.stun_info_collection.as_ref()
    }

    #[cfg(test)]
    pub fn replace_stun_info_collector(&self, collector: Box<dyn StunInfoCollectorTrait>) {
        // force replace the stun_info_collection without mut and drop the old one
        // NOTE(review): casting &self's field to *mut and writing through it
        // mutates without interior mutability — undefined behavior even in
        // tests; an UnsafeCell/Mutex field would make this sound.
        let ptr = &self.stun_info_collection as *const Box<dyn StunInfoCollectorTrait>;
        let ptr = ptr as *mut Box<dyn StunInfoCollectorTrait>;
        unsafe {
            std::ptr::drop_in_place(ptr);
            std::ptr::write(ptr, collector);
        }
    }

    pub fn get_id(&self) -> uuid::Uuid {
        self.id
    }
}
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Events must be delivered to a subscriber in issue order.
    /// Uses a real on-disk store under /tmp.
    #[tokio::test]
    async fn test_global_ctx() {
        let config_fs = ConfigFs::new("/tmp/easytier");
        let net_ns = NetNS::new(None);
        let global_ctx = GlobalCtx::new("test", config_fs, net_ns);
        let mut subscriber = global_ctx.subscribe();
        global_ctx.issue_event(GlobalCtxEvent::PeerAdded);
        global_ctx.issue_event(GlobalCtxEvent::PeerRemoved);
        global_ctx.issue_event(GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default()));
        global_ctx.issue_event(GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default()));
        assert_eq!(subscriber.recv().await.unwrap(), GlobalCtxEvent::PeerAdded);
        assert_eq!(
            subscriber.recv().await.unwrap(),
            GlobalCtxEvent::PeerRemoved
        );
        assert_eq!(
            subscriber.recv().await.unwrap(),
            GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default())
        );
        assert_eq!(
            subscriber.recv().await.unwrap(),
            GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default())
        );
    }

    /// Shared test helper: a context with a random id in a /tmp-backed store.
    pub fn get_mock_global_ctx() -> ArcGlobalCtx {
        let node_id = uuid::Uuid::new_v4();
        let config_fs = ConfigFs::new_with_dir(node_id.to_string().as_str(), "/tmp/easytier");
        let net_ns = NetNS::new(None);
        std::sync::Arc::new(GlobalCtx::new(
            format!("test_{}", node_id).as_str(),
            config_fs,
            net_ns,
        ))
    }
}
+312
View File
@@ -0,0 +1,312 @@
use std::net::Ipv4Addr;
use async_trait::async_trait;
use tokio::process::Command;
use super::error::Error;
/// Platform-neutral interface configuration: add/remove routes and
/// addresses, toggle link state. Implemented per-OS by shelling out to
/// the native network tools.
#[async_trait]
pub trait IfConfiguerTrait {
    /// Route `address/cidr_prefix` through interface `name`.
    async fn add_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error>;
    /// Remove the route previously added with `add_ipv4_route`.
    async fn remove_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error>;
    /// Assign `address/cidr_prefix` to interface `name`.
    async fn add_ipv4_ip(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error>;
    /// Bring the interface up (`up == true`) or down.
    async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error>;
    /// Remove one address (`Some(ip)`) or all addresses (`None`) from the interface.
    async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error>;
    /// Block until the interface is visible to the OS; default is a no-op.
    async fn wait_interface_show(&self, _name: &str) -> Result<(), Error> {
        return Ok(());
    }
}
/// Convert a CIDR prefix length into a dotted-quad subnet mask
/// (e.g. 24 -> 255.255.255.0).
///
/// # Panics
/// Panics if `prefix_length` is greater than 32.
fn cidr_to_subnet_mask(prefix_length: u8) -> Ipv4Addr {
    assert!(prefix_length <= 32, "Invalid CIDR prefix length");
    // Shifting by 32 is UB-checked via checked_shl: a /0 prefix yields 0.
    let mask_bits = (!0u32)
        .checked_shl(32 - u32::from(prefix_length))
        .unwrap_or(0);
    // From<u32> interprets the value in network (big-endian) byte order.
    Ipv4Addr::from(mask_bits)
}
/// Run `cmd` through the platform shell (`cmd /C` on Windows, `sh -c`
/// elsewhere), log the full outcome, and map a non-zero exit into
/// `Error::ShellCommandError` carrying stdout+stderr.
async fn run_shell_cmd(cmd: &str) -> Result<(), Error> {
    let cmd_out = if cfg!(target_os = "windows") {
        Command::new("cmd").arg("/C").arg(cmd).output().await?
    } else {
        Command::new("sh").arg("-c").arg(cmd).output().await?
    };
    let stdout = String::from_utf8_lossy(cmd_out.stdout.as_slice());
    let stderr = String::from_utf8_lossy(cmd_out.stderr.as_slice());
    let ec = cmd_out.status.code();
    let succ = cmd_out.status.success();
    tracing::info!(?cmd, ?ec, ?succ, ?stdout, ?stderr, "run shell cmd");
    if !cmd_out.status.success() {
        return Err(Error::ShellCommandError(
            stdout.to_string() + &stderr.to_string(),
        ));
    }
    Ok(())
}
/// macOS implementation backed by `route` and `ifconfig`.
pub struct MacIfConfiger {}

#[async_trait]
impl IfConfiguerTrait for MacIfConfiger {
    async fn add_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        // -hopcount 7 de-prioritizes this route against more specific ones.
        run_shell_cmd(
            format!(
                "route -n add {} -netmask {} -interface {} -hopcount 7",
                address,
                cidr_to_subnet_mask(cidr_prefix),
                name
            )
            .as_str(),
        )
        .await
    }
    async fn remove_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        run_shell_cmd(
            format!(
                "route -n delete {} -netmask {} -interface {}",
                address,
                cidr_to_subnet_mask(cidr_prefix),
                name
            )
            .as_str(),
        )
        .await
    }
    async fn add_ipv4_ip(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        // NOTE(review): "10.8.8.8" is the point-to-point destination address
        // passed to ifconfig — looks like a placeholder; confirm intent.
        run_shell_cmd(
            format!(
                "ifconfig {} {:?}/{:?} 10.8.8.8 up",
                name, address, cidr_prefix,
            )
            .as_str(),
        )
        .await
    }
    async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
        run_shell_cmd(format!("ifconfig {} {}", name, if up { "up" } else { "down" }).as_str())
            .await
    }
    async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
        // None removes the (single) inet address; Some removes that address only.
        if ip.is_none() {
            run_shell_cmd(format!("ifconfig {} inet delete", name).as_str()).await
        } else {
            run_shell_cmd(
                format!("ifconfig {} inet {} delete", name, ip.unwrap().to_string()).as_str(),
            )
            .await
        }
    }
}
/// Linux implementation backed by iproute2 (`ip route` / `ip addr` / `ip link`).
pub struct LinuxIfConfiger {}

#[async_trait]
impl IfConfiguerTrait for LinuxIfConfiger {
    /// Route `address/cidr_prefix` through `name` with a very low priority.
    async fn add_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        run_shell_cmd(
            format!(
                "ip route add {}/{} dev {} metric 65535",
                address, cidr_prefix, name
            )
            .as_str(),
        )
        .await
    }

    async fn remove_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        run_shell_cmd(format!("ip route del {}/{} dev {}", address, cidr_prefix, name).as_str())
            .await
    }

    async fn add_ipv4_ip(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        // fix: use Display ({}) instead of Debug ({:?}) formatting.
        run_shell_cmd(format!("ip addr add {}/{} dev {}", address, cidr_prefix, name).as_str())
            .await
    }

    async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
        run_shell_cmd(format!("ip link set {} {}", name, if up { "up" } else { "down" }).as_str())
            .await
    }

    /// `None` flushes every address on the device; `Some(ip)` deletes one.
    async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
        if ip.is_none() {
            run_shell_cmd(format!("ip addr flush dev {}", name).as_str()).await
        } else {
            // fix: the original Debug-formatted ({:?}) a String here, which
            // embedded literal quotes in the shell command
            // (`ip addr del "1.2.3.4" dev x`); format the address directly.
            run_shell_cmd(format!("ip addr del {} dev {}", ip.unwrap(), name).as_str()).await
        }
    }
}
/// Windows implementation backed by `route` and `netsh`.
pub struct WindowsIfConfiger {}

#[async_trait]
impl IfConfiguerTrait for WindowsIfConfiger {
    async fn add_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        // NOTE(review): `route add` takes a gateway (and optionally IF index)
        // as the last argument; passing the interface *name* here may not be
        // accepted — verify on Windows.
        run_shell_cmd(
            format!(
                "route add {} mask {} {}",
                address,
                cidr_to_subnet_mask(cidr_prefix),
                name
            )
            .as_str(),
        )
        .await
    }
    async fn remove_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        run_shell_cmd(
            format!(
                "route delete {} mask {} {}",
                address,
                cidr_to_subnet_mask(cidr_prefix),
                name
            )
            .as_str(),
        )
        .await
    }
    async fn add_ipv4_ip(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        run_shell_cmd(
            format!(
                "netsh interface ipv4 add address {} address={} mask={}",
                name,
                address,
                cidr_to_subnet_mask(cidr_prefix)
            )
            .as_str(),
        )
        .await
    }
    async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
        run_shell_cmd(
            format!(
                "netsh interface set interface {} {}",
                name,
                if up { "enable" } else { "disable" }
            )
            .as_str(),
        )
        .await
    }
    async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
        // None removes all addresses on the interface; Some removes one.
        if ip.is_none() {
            run_shell_cmd(format!("netsh interface ipv4 delete address {}", name).as_str()).await
        } else {
            run_shell_cmd(
                format!(
                    "netsh interface ipv4 delete address {} address={}",
                    name,
                    ip.unwrap().to_string()
                )
                .as_str(),
            )
            .await
        }
    }
    /// Poll `netsh ... show interfaces` every 100 ms until the interface
    /// appears; gives up after 10 s (mapped to `Error::Timeout` via `?`).
    async fn wait_interface_show(&self, name: &str) -> Result<(), Error> {
        Ok(
            tokio::time::timeout(std::time::Duration::from_secs(10), async move {
                loop {
                    let Ok(_) = run_shell_cmd(
                        format!("netsh interface ipv4 show interfaces {}", name).as_str(),
                    )
                    .await
                    else {
                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                        continue;
                    };
                    break;
                }
            })
            .await?,
        )
    }
}
// Select the platform implementation at compile time.
#[cfg(target_os = "macos")]
pub type IfConfiger = MacIfConfiger;
#[cfg(target_os = "linux")]
pub type IfConfiger = LinuxIfConfiger;
#[cfg(target_os = "windows")]
pub type IfConfiger = WindowsIfConfiger;
+9
View File
@@ -0,0 +1,9 @@
// Common utilities shared across easytier-core.
pub mod config_fs; // filesystem-backed key/value config store
pub mod constants; // service ids, timeouts and global-var macros
pub mod error; // crate-wide Error/Result types
pub mod global_ctx; // per-instance shared state and event bus
pub mod ifcfg; // per-OS interface/route configuration
pub mod netns; // network-namespace switching helpers
pub mod network; // local/public IP collection
pub mod rkyv_util; // rkyv encode/decode helpers
pub mod stun; // STUN-based NAT detection
+118
View File
@@ -0,0 +1,118 @@
use futures::Future;
use once_cell::sync::Lazy;
use tokio::sync::Mutex;
#[cfg(target_os = "linux")]
use nix::sched::{setns, CloneFlags};
#[cfg(target_os = "linux")]
use std::os::fd::AsFd;
/// RAII guard that switches the calling thread into another network
/// namespace and (on Linux) switches back when dropped.
pub struct NetNSGuard {
    // Handle to the namespace the thread was in before switching,
    // used by Drop to restore it. None when no switch happened.
    #[cfg(target_os = "linux")]
    old_ns: Option<std::fs::File>,
}

// Serializes namespace switches; currently unused (see run_async's TODO).
type NetNSLock = Mutex<()>;
static LOCK: Lazy<NetNSLock> = Lazy::new(|| Mutex::new(()));

// Special name resolving to PID 1's namespace instead of /var/run/netns.
pub static ROOT_NETNS_NAME: &str = "_root_ns";
#[cfg(target_os = "linux")]
impl NetNSGuard {
    /// Enter the namespace `ns` (if given), remembering the current one so
    /// `Drop` can restore it. Panics if the namespace file cannot be opened.
    pub fn new(ns: Option<String>) -> Box<Self> {
        let old_ns = if ns.is_some() {
            // fix: this impl is already gated on linux, so the previous inner
            // `cfg!(target_os = "linux")` branch was always true — removed.
            let old_ns = Some(std::fs::File::open("/proc/self/ns/net").unwrap());
            Self::switch_ns(ns);
            old_ns
        } else {
            None
        };
        Box::new(NetNSGuard { old_ns })
    }

    /// Switch the calling thread into the named namespace.
    /// `ROOT_NETNS_NAME` maps to PID 1's namespace; anything else is looked
    /// up under /var/run/netns.
    fn switch_ns(name: Option<String>) {
        let Some(name) = name else {
            return;
        };
        let ns_path = if name == ROOT_NETNS_NAME {
            "/proc/1/ns/net".to_string()
        } else {
            format!("/var/run/netns/{}", name)
        };
        let ns = std::fs::File::open(ns_path).unwrap();
        log::info!(
            "[INIT NS] switching to new ns_name: {:?}, ns_file: {:?}",
            name,
            ns
        );
        setns(ns.as_fd(), CloneFlags::CLONE_NEWNET).unwrap();
    }
}
#[cfg(target_os = "linux")]
impl Drop for NetNSGuard {
    // Restore the namespace the thread was in when the guard was created.
    // Panics if setns fails; no-op when no switch was made.
    fn drop(&mut self) {
        if self.old_ns.is_none() {
            return;
        }
        log::info!("[INIT NS] switching back to old ns, ns: {:?}", self.old_ns);
        setns(
            self.old_ns.as_ref().unwrap().as_fd(),
            CloneFlags::CLONE_NEWNET,
        )
        .unwrap();
    }
}
// Non-Linux platforms have no network namespaces: the guard is a no-op.
#[cfg(not(target_os = "linux"))]
impl NetNSGuard {
    pub fn new(_ns: Option<String>) -> Box<Self> {
        Box::new(NetNSGuard {})
    }
}
/// Cheap, clonable handle naming a network namespace; `None` means the
/// current namespace (no switching performed).
#[derive(Clone)]
pub struct NetNS {
    name: Option<String>,
}
impl NetNS {
    /// Create a handle for namespace `name`; `None` means "stay where we are".
    pub fn new(name: Option<String>) -> Self {
        Self { name }
    }

    /// Run an async closure with this namespace entered for its whole duration.
    pub async fn run_async<F, Fut, Ret>(&self, f: F) -> Ret
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Ret>,
    {
        // TODO: do we really need this lock
        // let _lock = LOCK.lock().await;
        let _ns_guard = NetNSGuard::new(self.name.clone());
        f().await
    }

    /// Run a synchronous closure with this namespace entered.
    pub fn run<F, Ret>(&self, f: F) -> Ret
    where
        F: FnOnce() -> Ret,
    {
        let _ns_guard = NetNSGuard::new(self.name.clone());
        f()
    }

    /// Enter the namespace and hand the guard to the caller; the thread
    /// stays switched until the guard is dropped.
    pub fn guard(&self) -> Box<NetNSGuard> {
        NetNSGuard::new(self.name.clone())
    }

    /// The namespace name, if any.
    pub fn name(&self) -> Option<String> {
        self.name.clone()
    }
}
+218
View File
@@ -0,0 +1,218 @@
use std::{ops::Deref, sync::Arc};
use easytier_rpc::peer::GetIpListResponse;
use pnet::datalink::NetworkInterface;
use tokio::{
sync::{Mutex, RwLock},
task::JoinSet,
};
use super::{constants::DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC, netns::NetNS};
/// Wraps one pnet interface and decides whether it should be considered
/// a usable, real network interface for IP collection.
struct InterfaceFilter {
    iface: NetworkInterface,
}

#[cfg(target_os = "linux")]
impl InterfaceFilter {
    // A bridge device exposes /sys/class/net/<name>/bridge.
    async fn is_iface_bridge(&self) -> bool {
        let path = format!("/sys/class/net/{}/bridge", self.iface.name);
        tokio::fs::metadata(&path).await.is_ok()
    }

    // A physical device exposes /sys/class/net/<name>/device.
    async fn is_iface_phsical(&self) -> bool {
        let path = format!("/sys/class/net/{}/device", self.iface.name);
        tokio::fs::metadata(&path).await.is_ok()
    }

    /// Keep interfaces that are up (admin + carrier), not loopback, not
    /// point-to-point, and are either a bridge or a physical device.
    async fn filter_iface(&self) -> bool {
        tracing::trace!(
            "filter linux iface: {:?}, is_point_to_point: {}, is_loopback: {}, is_up: {}, is_lower_up: {}, is_bridge: {}, is_physical: {}",
            self.iface,
            self.iface.is_point_to_point(),
            self.iface.is_loopback(),
            self.iface.is_up(),
            self.iface.is_lower_up(),
            self.is_iface_bridge().await,
            self.is_iface_phsical().await,
        );
        !self.iface.is_point_to_point()
            && !self.iface.is_loopback()
            && self.iface.is_up()
            && self.iface.is_lower_up()
            && (self.is_iface_bridge().await || self.is_iface_phsical().await)
    }
}
#[cfg(target_os = "macos")]
impl InterfaceFilter {
    /// Ask `networksetup -listallhardwareports` whether the interface is a
    /// hardware port. The output lists ports in stanzas where a line
    /// containing "Device: <name>" is followed by a type line; a following
    /// "Virtual Interface" line marks it as non-physical.
    async fn is_interface_physical(interface_name: &str) -> bool {
        let output = tokio::process::Command::new("networksetup")
            .args(&["-listallhardwareports"])
            .output()
            .await
            .expect("Failed to execute command");
        let stdout = std::str::from_utf8(&output.stdout).expect("Invalid UTF-8");
        let lines: Vec<&str> = stdout.lines().collect();
        for (i, line) in lines.iter().enumerate() {
            if line.contains("Device:") && line.contains(interface_name) {
                // fix: the original indexed `lines[i + 1]` unchecked and
                // panicked when the match fell on the last output line.
                return match lines.get(i + 1) {
                    Some(next_line) => !next_line.contains("Virtual Interface"),
                    None => false,
                };
            }
        }
        false
    }

    /// Keep interfaces that are up, not loopback/point-to-point, and physical.
    async fn filter_iface(&self) -> bool {
        !self.iface.is_point_to_point()
            && !self.iface.is_loopback()
            && self.iface.is_up()
            && Self::is_interface_physical(&self.iface.name).await
    }
}
#[cfg(target_os = "windows")]
impl InterfaceFilter {
    // Windows has no cheap physical/virtual probe here; keep any interface
    // that is up and neither loopback nor point-to-point.
    async fn filter_iface(&self) -> bool {
        !self.iface.is_point_to_point() && !self.iface.is_loopback() && self.iface.is_up()
    }
}
/// Discover the preferred local IPv4 address by "connecting" a UDP socket
/// toward a public resolver (no packet is sent) and reading the address
/// the OS picked for the local endpoint.
pub async fn local_ipv4() -> std::io::Result<std::net::Ipv4Addr> {
    let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;
    socket.connect("8.8.8.8:80").await?;
    match socket.local_addr()?.ip() {
        std::net::IpAddr::V4(ip) => Ok(ip),
        std::net::IpAddr::V6(_) => Err(std::io::Error::new(
            std::io::ErrorKind::AddrNotAvailable,
            "no ipv4 address",
        )),
    }
}
/// Discover the preferred local IPv6 address by "connecting" a UDP socket
/// toward a public resolver (no packet is sent) and reading the address
/// the OS picked for the local endpoint.
pub async fn local_ipv6() -> std::io::Result<std::net::Ipv6Addr> {
    let socket = tokio::net::UdpSocket::bind("[::]:0").await?;
    socket
        .connect("[2001:4860:4860:0000:0000:0000:0000:8888]:80")
        .await?;
    let addr = socket.local_addr()?;
    match addr.ip() {
        std::net::IpAddr::V6(ip) => Ok(ip),
        // fix: the error message said "no ipv4 address" — this is the v6 probe.
        std::net::IpAddr::V4(_) => Err(std::io::Error::new(
            std::io::ErrorKind::AddrNotAvailable,
            "no ipv6 address",
        )),
    }
}
/// Collects the node's interface/public IP addresses inside the configured
/// network namespace and keeps them cached with a background refresher.
pub struct IPCollector {
    cached_ip_list: Arc<RwLock<GetIpListResponse>>,
    // Holds the single background refresh task; empty until first use.
    collect_ip_task: Mutex<JoinSet<()>>,
    net_ns: NetNS,
}

impl IPCollector {
    pub fn new(net_ns: NetNS) -> Self {
        Self {
            cached_ip_list: Arc::new(RwLock::new(GetIpListResponse::new())),
            collect_ip_task: Mutex::new(JoinSet::new()),
            net_ns,
        }
    }

    /// Return the cached address list. On first call, synchronously collect
    /// local addresses (without the slow public-IP lookup), then spawn a
    /// background loop that re-collects everything (public IPs included)
    /// every DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC seconds.
    pub async fn collect_ip_addrs(&self) -> GetIpListResponse {
        let mut task = self.collect_ip_task.lock().await;
        if task.is_empty() {
            let cached_ip_list = self.cached_ip_list.clone();
            *cached_ip_list.write().await =
                Self::do_collect_ip_addrs(false, self.net_ns.clone()).await;
            let net_ns = self.net_ns.clone();
            task.spawn(async move {
                loop {
                    let ip_addrs = Self::do_collect_ip_addrs(true, net_ns.clone()).await;
                    *cached_ip_list.write().await = ip_addrs;
                    tokio::time::sleep(std::time::Duration::from_secs(
                        DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC,
                    ))
                    .await;
                }
            });
        }
        return self.cached_ip_list.read().await.deref().clone();
    }

    /// Gather addresses: optionally the public v4/v6 (via external lookup),
    /// then all filtered interface addresses inside `net_ns`, plus the
    /// UDP-probe local v4/v6 if not already present.
    #[tracing::instrument(skip(net_ns))]
    async fn do_collect_ip_addrs(with_public: bool, net_ns: NetNS) -> GetIpListResponse {
        let mut ret = easytier_rpc::peer::GetIpListResponse {
            public_ipv4: "".to_string(),
            interface_ipv4s: vec![],
            public_ipv6: "".to_string(),
            interface_ipv6s: vec![],
        };
        if with_public {
            if let Some(v4_addr) =
                public_ip::addr_with(public_ip::http::ALL, public_ip::Version::V4).await
            {
                ret.public_ipv4 = v4_addr.to_string();
            }
            if let Some(v6_addr) = public_ip::addr_v6().await {
                ret.public_ipv6 = v6_addr.to_string();
            }
        }
        // Enumerate interfaces while switched into the target namespace.
        let _g = net_ns.guard();
        let ifaces = pnet::datalink::interfaces();
        for iface in ifaces {
            let f = InterfaceFilter {
                iface: iface.clone(),
            };
            if !f.filter_iface().await {
                continue;
            }
            for ip in iface.ips {
                let ip: std::net::IpAddr = ip.ip();
                if ip.is_loopback() || ip.is_multicast() {
                    continue;
                }
                if ip.is_ipv4() {
                    ret.interface_ipv4s.push(ip.to_string());
                } else if ip.is_ipv6() {
                    ret.interface_ipv6s.push(ip.to_string());
                }
            }
        }
        // Add the OS-preferred outbound addresses, de-duplicated.
        if let Ok(v4_addr) = local_ipv4().await {
            tracing::trace!("got local ipv4: {}", v4_addr);
            if !ret.interface_ipv4s.contains(&v4_addr.to_string()) {
                ret.interface_ipv4s.push(v4_addr.to_string());
            }
        }
        if let Ok(v6_addr) = local_ipv6().await {
            tracing::trace!("got local ipv6: {}", v6_addr);
            if !ret.interface_ipv6s.contains(&v6_addr.to_string()) {
                ret.interface_ipv6s.push(v6_addr.to_string());
            }
        }
        ret
    }
}
+54
View File
@@ -0,0 +1,54 @@
use rkyv::{
validation::{validators::DefaultValidator, CheckTypeError},
vec::ArchivedVec,
Archive, CheckBytes, Serialize,
};
use tokio_util::bytes::{Bytes, BytesMut};
/// Interprets `bytes` as the rkyv archive of `T`, validating the buffer first.
/// Safe for untrusted input: returns an error instead of misinterpreting
/// malformed bytes.
pub fn decode_from_bytes_checked<'a, T: Archive>(
    bytes: &'a [u8],
) -> Result<&'a T::Archived, CheckTypeError<T::Archived, DefaultValidator<'a>>>
where
    T::Archived: CheckBytes<DefaultValidator<'a>>,
{
    rkyv::check_archived_root::<T>(bytes)
}
/// Interprets `bytes` as the rkyv archive of `T` WITHOUT validation (the
/// checked variant is kept in the commented line below). This never returns
/// `Err`; the `Result` only mirrors the checked variant's signature.
///
/// SAFETY: the caller must guarantee `bytes` is a well-formed archive of `T`
/// (e.g. produced by `encode_to_bytes`); feeding untrusted bytes here is
/// undefined behavior — use `decode_from_bytes_checked` for those.
pub fn decode_from_bytes<'a, T: Archive>(
    bytes: &'a [u8],
) -> Result<&'a T::Archived, CheckTypeError<T::Archived, DefaultValidator<'a>>>
where
    T::Archived: CheckBytes<DefaultValidator<'a>>,
{
    // rkyv::check_archived_root::<T>(bytes)
    unsafe { Ok(rkyv::archived_root::<T>(bytes)) }
}
// Serialize a value of type `T` into a `Bytes` buffer.
/// Serializes `val` with rkyv and returns the archived representation as a
/// `Bytes` buffer. `N` is the scratch-space size hint for the serializer.
///
/// Panics if serialization fails, which indicates a bug in the archived type
/// rather than a recoverable runtime condition.
pub fn encode_to_bytes<T, const N: usize>(val: &T) -> Bytes
where
    T: Serialize<rkyv::ser::serializers::AllocSerializer<N>>,
{
    let encoded = rkyv::to_bytes::<_, N>(val).expect("rkyv serialization failed");
    // AlignedVec -> Box<[u8]> -> Bytes: hands ownership over without copying
    // the payload an extra time.
    encoded.into_boxed_slice().into()
}
/// Given the full `raw_data` buffer and an archived byte vector that lives
/// inside it, returns a zero-copy `Bytes` view covering exactly that vector.
///
/// `archived_data` must point into `raw_data`'s allocation, otherwise the
/// computed offsets are meaningless.
pub fn extract_bytes_from_archived_vec(raw_data: &Bytes, archived_data: &ArchivedVec<u8>) -> Bytes {
    let range = archived_data.as_ptr_range();
    let start = range.start as usize - raw_data.as_ptr() as usize;
    let end = start + (range.end as usize - range.start as usize);
    raw_data.slice(start..end)
}
/// Splits the mutable sub-buffer backing `archived_data` out of `raw_data`.
///
/// Side effect: `raw_data` is truncated to the bytes BEFORE the archived
/// vector, and any bytes after the vector are dropped (both are artifacts of
/// the `split_off`/`split_to` combination). Callers must not rely on
/// `raw_data` retaining its tail.
pub fn extract_bytes_mut_from_archived_vec(
    raw_data: &mut BytesMut,
    archived_data: &ArchivedVec<u8>,
) -> BytesMut {
    let ptr_range = archived_data.as_ptr_range();
    // Offset of the archived vector inside raw_data's allocation; only valid
    // if archived_data actually points into raw_data.
    let offset = ptr_range.start as usize - raw_data.as_ptr() as usize;
    let len = ptr_range.end as usize - ptr_range.start as usize;
    raw_data.split_off(offset).split_to(len)
}
+433
View File
@@ -0,0 +1,433 @@
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::sync::Arc;
use std::time::Duration;
use crossbeam::atomic::AtomicCell;
use easytier_rpc::{NatType, StunInfo};
use stun_format::Attr;
use tokio::net::{lookup_host, UdpSocket};
use tokio::sync::RwLock;
use tokio::task::JoinSet;
use crate::common::error::Error;
// A single STUN client session against one server.
struct Stun {
    // "host:port" of the STUN server; resolved on every bind request.
    stun_server: String,
    // How many identical requests to send per bind (UDP loss mitigation).
    req_repeat: u8,
    // Total time to wait for any matching response.
    resp_timeout: Duration,
}
// Result of one STUN binding request.
#[derive(Debug, Clone, Copy)]
struct BindRequestResponse {
    // Local address the request was sent from.
    source_addr: SocketAddr,
    // Our address as seen by the server (MAPPED-ADDRESS / XOR-MAPPED-ADDRESS).
    mapped_socket_addr: Option<SocketAddr>,
    // Server's alternate address (CHANGED-ADDRESS attribute), if reported.
    changed_socket_addr: Option<SocketAddr>,
    // Whether the response came from a different ip than the one we sent to.
    ip_changed: bool,
    // Whether the response came from a different port than the one we sent to.
    port_changed: bool,
}
impl BindRequestResponse {
    /// Returns the mapped address, panicking if the server did not report one
    /// — callers must only use this after checking `mapped_socket_addr`.
    pub fn get_mapped_addr_no_check(&self) -> &SocketAddr {
        self.mapped_socket_addr.as_ref().unwrap()
    }
}
impl Stun {
    /// Creates a client for `stun_server` with 3 request repeats and a 3s
    /// response timeout.
    pub fn new(stun_server: String) -> Self {
        Self {
            stun_server,
            req_repeat: 3,
            resp_timeout: Duration::from_millis(3000),
        }
    }

    /// Waits (until `resp_timeout`) for a BindingResponse whose transaction id
    /// is in `tids`; packets shorter than the 20-byte STUN header and
    /// non-matching messages are skipped.
    async fn wait_stun_response<'a, const N: usize>(
        &self,
        buf: &'a mut [u8; N],
        udp: &UdpSocket,
        tids: &Vec<u128>,
    ) -> Result<(stun_format::Msg<'a>, SocketAddr), Error> {
        let mut now = tokio::time::Instant::now();
        let deadline = now + self.resp_timeout;
        while now < deadline {
            let mut udp_buf = [0u8; 1500];
            let (len, remote_addr) =
                tokio::time::timeout(deadline - now, udp.recv_from(udp_buf.as_mut_slice()))
                    .await??;
            now = tokio::time::Instant::now();
            // 20 bytes is the minimum STUN header; anything shorter is noise.
            if len < 20 {
                continue;
            }
            // TODO:: we cannot borrow `buf` directly in udp recv_from, so we copy it here.
            // NOTE(review): writing through `buf.as_ptr() as *mut u8` while
            // `buf` is already borrowed for 'a by a previous iteration's `msg`
            // looks aliasing-unsound — worth revisiting with a safe design.
            unsafe { std::ptr::copy(udp_buf.as_ptr(), buf.as_ptr() as *mut u8, len) };
            let msg = stun_format::Msg::<'a>::from(&buf[..]);
            tracing::trace!(b = ?&udp_buf[..len], ?msg, ?tids, "recv stun response");
            if msg.typ().is_none() || msg.tid().is_none() {
                continue;
            }
            // Accept only a BindingResponse for one of our transaction ids.
            if matches!(
                msg.typ().as_ref().unwrap(),
                stun_format::MsgType::BindingResponse
            ) && tids.contains(msg.tid().as_ref().unwrap())
            {
                return Ok((msg, remote_addr));
            }
        }
        Err(Error::Unknown)
    }

    /// Converts the stun_format address type to a std `SocketAddr`.
    fn stun_addr(addr: stun_format::SocketAddr) -> SocketAddr {
        match addr {
            stun_format::SocketAddr::V4(ip, port) => {
                SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(ip), port))
            }
            stun_format::SocketAddr::V6(ip, port) => {
                SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(ip), port, 0, 0))
            }
        }
    }

    /// Extracts the first MAPPED-ADDRESS or XOR-MAPPED-ADDRESS attribute.
    /// (Name keeps the original "extrace" spelling; renaming would touch the
    /// call site in `bind_request`.)
    fn extrace_mapped_addr(msg: &stun_format::Msg) -> Option<SocketAddr> {
        let mut mapped_addr = None;
        for x in msg.attrs_iter() {
            match x {
                Attr::MappedAddress(addr) => {
                    if mapped_addr.is_none() {
                        let _ = mapped_addr.insert(Self::stun_addr(addr));
                    }
                }
                Attr::XorMappedAddress(addr) => {
                    if mapped_addr.is_none() {
                        let _ = mapped_addr.insert(Self::stun_addr(addr));
                    }
                }
                _ => {}
            }
        }
        mapped_addr
    }

    /// Extracts the first CHANGED-ADDRESS attribute, if any.
    fn extract_changed_addr(msg: &stun_format::Msg) -> Option<SocketAddr> {
        let mut changed_addr = None;
        for x in msg.attrs_iter() {
            match x {
                Attr::ChangedAddress(addr) => {
                    if changed_addr.is_none() {
                        let _ = changed_addr.insert(Self::stun_addr(addr));
                    }
                }
                _ => {}
            }
        }
        changed_addr
    }

    /// Sends a binding request from local port `source_port`, optionally
    /// asking the server to answer from a changed ip and/or port, and waits
    /// for the response.
    pub async fn bind_request(
        &self,
        source_port: u16,
        change_ip: bool,
        change_port: bool,
    ) -> Result<BindRequestResponse, Error> {
        // Use the first resolved address of the server.
        let stun_host = lookup_host(&self.stun_server)
            .await?
            .next()
            .ok_or(Error::NotFound)?;
        // let udp_socket = socket2::Socket::new(
        //     match stun_host {
        //         SocketAddr::V4(..) => socket2::Domain::IPV4,
        //         SocketAddr::V6(..) => socket2::Domain::IPV6,
        //     },
        //     socket2::Type::DGRAM,
        //     Some(socket2::Protocol::UDP),
        // )?;
        // udp_socket.set_reuse_port(true)?;
        // udp_socket.set_reuse_address(true)?;
        let udp = UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?;

        // repeat req in case of packet loss; each repetition gets its own tid.
        let mut tids = vec![];
        for _ in 0..self.req_repeat {
            let mut buf = [0u8; 28];
            // memset buf
            unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()) };
            let mut msg = stun_format::MsgBuilder::from(buf.as_mut_slice());
            msg.typ(stun_format::MsgType::BindingRequest).unwrap();
            let tid = rand::random::<u32>();
            msg.tid(tid as u128).unwrap();
            if change_ip || change_port {
                msg.add_attr(Attr::ChangeRequest {
                    change_ip,
                    change_port,
                })
                .unwrap();
            }
            tids.push(tid as u128);
            tracing::trace!(b = ?msg.as_bytes(), tid, "send stun request");
            udp.send_to(msg.as_bytes(), &stun_host).await?;
        }

        tracing::trace!("waiting stun response");
        let mut buf = [0; 1620];
        let (msg, recv_addr) = self.wait_stun_response(&mut buf, &udp, &tids).await?;

        let changed_socket_addr = Self::extract_changed_addr(&msg);
        // Compare the responder's address with the one we sent to, to learn
        // whether the server honored the change request.
        let ip_changed = stun_host.ip() != recv_addr.ip();
        let port_changed = stun_host.port() != recv_addr.port();

        let resp = BindRequestResponse {
            source_addr: udp.local_addr()?,
            mapped_socket_addr: Self::extrace_mapped_addr(&msg),
            changed_socket_addr,
            ip_changed,
            port_changed,
        };

        tracing::info!(
            ?stun_host,
            ?recv_addr,
            ?changed_socket_addr,
            "finish stun bind request"
        );

        Ok(resp)
    }
}
// Classifies the local NAT's UDP behavior using a list of STUN servers.
struct UdpNatTypeDetector {
    // At least two servers are needed for a meaningful classification.
    stun_servers: Vec<String>,
}
impl UdpNatTypeDetector {
    /// `stun_servers` should hold at least two reachable servers; the
    /// classification needs a server that honors change-ip/change-port.
    pub fn new(stun_servers: Vec<String>) -> Self {
        Self { stun_servers }
    }

    /// Detects the NAT behavior for UDP, in the spirit of classic STUN
    /// (rfc3489, slightly modified). A `source_port` of 0 means "pick a free
    /// local port first and reuse it for every probe".
    ///
    /// Returns `NatType::Unknown` when no usable STUN server was found.
    async fn get_udp_nat_type(&self, mut source_port: u16) -> NatType {
        // test1_1 / test1_2: plain bind requests against two different servers.
        // test2: bind request asking for both ip and port change.
        // test3: bind request asking for port change only.
        let mut ret_test1_1 = None;
        let mut ret_test1_2 = None;
        let mut ret_test2 = None;
        let mut ret_test3 = None;

        if source_port == 0 {
            // Grab an ephemeral port; the probe socket is dropped right away
            // and the port is re-bound inside each bind_request call.
            let udp = UdpSocket::bind("0.0.0.0:0").await.unwrap();
            source_port = udp.local_addr().unwrap().port();
        }

        let mut succ = false;
        for server_ip in &self.stun_servers {
            let stun = Stun::new(server_ip.clone());
            let ret = stun.bind_request(source_port, false, false).await;
            if ret.is_err() {
                // Try another STUN server
                continue;
            }
            if ret_test1_1.is_none() {
                ret_test1_1 = ret.ok();
                continue;
            }
            ret_test1_2 = ret.ok();

            let ret = stun.bind_request(source_port, true, true).await;
            if let Ok(resp) = ret {
                if !resp.ip_changed || !resp.port_changed {
                    // Server answered but did not actually change its ip/port;
                    // it is unreliable for test2 — try another server.
                    continue;
                }
            }
            ret_test2 = ret.ok();
            ret_test3 = stun.bind_request(source_port, false, true).await.ok();
            succ = true;
            break;
        }

        if !succ {
            return NatType::Unknown;
        }

        tracing::info!(
            ?ret_test1_1,
            ?ret_test1_2,
            ?ret_test2,
            ?ret_test3,
            "finish stun test, try to detect nat type"
        );

        let ret_test1_1 = ret_test1_1.unwrap();
        let ret_test1_2 = ret_test1_2.unwrap();

        // Different mapped addresses for the same source port against two
        // servers => mapping depends on destination => symmetric NAT.
        if ret_test1_1.mapped_socket_addr != ret_test1_2.mapped_socket_addr {
            return NatType::Symmetric;
        }

        if ret_test1_1.mapped_socket_addr.is_some()
            && ret_test1_1.source_addr == ret_test1_1.mapped_socket_addr.unwrap()
        {
            // Our local address is directly visible: either fully open, or a
            // firewall that drops traffic from unseen sources (test2 failed).
            if ret_test2.is_some() {
                NatType::OpenInternet
            } else {
                NatType::SymUdpFirewall
            }
        } else if let Some(ret_test2) = ret_test2 {
            // Changed-ip/port responses get through => cone NAT; an unchanged
            // mapped port means the NAT does no port translation.
            if source_port == ret_test2.get_mapped_addr_no_check().port()
                && source_port == ret_test1_1.get_mapped_addr_no_check().port()
            {
                NatType::NoPat
            } else {
                NatType::FullCone
            }
        } else if ret_test3.is_some() {
            // Changed-port-only responses get through => ip-restricted cone.
            NatType::Restricted
        } else {
            NatType::PortRestricted
        }
    }
}
// Read-only view onto STUN-derived information about this node.
#[async_trait::async_trait]
#[auto_impl::auto_impl(&, Arc, Box)]
pub trait StunInfoCollectorTrait: Send + Sync {
    // Last detected NAT info (non-blocking snapshot).
    fn get_stun_info(&self) -> StunInfo;
    // Asks a STUN server for the public mapping of `local_port`.
    async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error>;
}
// Periodically re-detects the local NAT type in a background task.
pub struct StunInfoCollector {
    // Server list used for detection; replaceable at runtime.
    stun_servers: Arc<RwLock<Vec<String>>>,
    // Last detected NAT type plus the time of detection.
    udp_nat_type: Arc<AtomicCell<(NatType, std::time::Instant)>>,
    // Signalled to force an immediate re-detection.
    redetect_notify: Arc<tokio::sync::Notify>,
    // Owns the background detection task (aborted on drop).
    tasks: JoinSet<()>,
}
#[async_trait::async_trait]
impl StunInfoCollectorTrait for StunInfoCollector {
    /// Snapshot of the most recently detected NAT type; `last_update_time` is
    /// the age of that detection in whole seconds.
    fn get_stun_info(&self) -> StunInfo {
        let (nat_type, detected_at) = self.udp_nat_type.load();
        StunInfo {
            udp_nat_type: nat_type as i32,
            tcp_nat_type: 0,
            last_update_time: detected_at.elapsed().as_secs() as i64,
        }
    }

    /// Walks the configured STUN servers and returns the first public mapping
    /// reported for `local_port`; fails with `NotFound` if none answers.
    async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error> {
        let servers = self.stun_servers.read().await.clone();
        for server in &servers {
            let resp = match Stun::new(server.clone())
                .bind_request(local_port, false, false)
                .await
            {
                Ok(resp) => resp,
                Err(_) => {
                    tracing::warn!(?server, "stun bind request failed");
                    continue;
                }
            };
            if let Some(mapped) = resp.mapped_socket_addr {
                return Ok(mapped);
            }
        }
        Err(Error::NotFound)
    }
}
impl StunInfoCollector {
    /// Creates the collector and immediately spawns the background detection
    /// routine; the first result arrives asynchronously (type starts Unknown).
    pub fn new(stun_servers: Vec<String>) -> Self {
        let mut ret = Self {
            stun_servers: Arc::new(RwLock::new(stun_servers)),
            udp_nat_type: Arc::new(AtomicCell::new((
                NatType::Unknown,
                std::time::Instant::now(),
            ))),
            redetect_notify: Arc::new(tokio::sync::Notify::new()),
            tasks: JoinSet::new(),
        };
        ret.start_stun_routine();
        ret
    }

    /// Spawns the loop that re-detects the NAT type: every 15s while it is
    /// still Unknown, every 60s once detected, or immediately when
    /// `update_stun_info` is called.
    fn start_stun_routine(&mut self) {
        let stun_servers = self.stun_servers.clone();
        let udp_nat_type = self.udp_nat_type.clone();
        let redetect_notify = self.redetect_notify.clone();
        self.tasks.spawn(async move {
            loop {
                let detector = UdpNatTypeDetector::new(stun_servers.read().await.clone());
                let ret = detector.get_udp_nat_type(0).await;
                udp_nat_type.store((ret, std::time::Instant::now()));
                // Unknown usually means no server was reachable; retry sooner.
                let sleep_sec = match ret {
                    NatType::Unknown => 15,
                    _ => 60,
                };
                tracing::info!(?ret, ?sleep_sec, "finish udp nat type detect");
                tokio::select! {
                    _ = redetect_notify.notified() => {}
                    _ = tokio::time::sleep(Duration::from_secs(sleep_sec)) => {}
                }
            }
        });
    }

    /// Triggers an immediate re-detection (non-blocking).
    pub fn update_stun_info(&self) {
        self.redetect_notify.notify_one();
    }

    /// Replaces the server list and re-runs detection against it.
    pub async fn set_stun_servers(&self, stun_servers: Vec<String>) {
        *self.stun_servers.write().await = stun_servers;
        self.update_stun_info();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): both tests depend on reachable third-party STUN servers
    // and on the local NAT situation, so they are inherently flaky and
    // environment-dependent; they should probably be gated/ignored in CI.
    #[tokio::test]
    async fn test_stun_bind_request() {
        // miwifi / qq seems not correctly responde to change_ip and change_port, they always try to change the src ip and port.
        // let stun = Stun::new("stun.counterpath.com:3478".to_string());
        let stun = Stun::new("180.235.108.91:3478".to_string());
        // let stun = Stun::new("193.22.2.248:3478".to_string());
        // let stun = Stun::new("stun.chat.bilibili.com:3478".to_string());
        // let stun = Stun::new("stun.miwifi.com:3478".to_string());

        // Exercise every combination of the change-ip / change-port flags and
        // check the server actually honored each request.
        let rs = stun.bind_request(12345, true, true).await.unwrap();
        assert!(rs.ip_changed);
        assert!(rs.port_changed);

        let rs = stun.bind_request(12345, true, false).await.unwrap();
        assert!(rs.ip_changed);
        assert!(!rs.port_changed);

        let rs = stun.bind_request(12345, false, true).await.unwrap();
        assert!(!rs.ip_changed);
        assert!(rs.port_changed);

        let rs = stun.bind_request(12345, false, false).await.unwrap();
        assert!(!rs.ip_changed);
        assert!(!rs.port_changed);
    }

    #[tokio::test]
    async fn test_udp_nat_type_detect() {
        let detector = UdpNatTypeDetector::new(vec![
            "stun.counterpath.com:3478".to_string(),
            "180.235.108.91:3478".to_string(),
        ]);
        let ret = detector.get_udp_nat_type(0).await;
        // NOTE(review): asserts the NAT type of the machine running the test.
        assert_eq!(ret, NatType::FullCone);
    }
}
+325
View File
@@ -0,0 +1,325 @@
// try connect peers directly, with either its public ip or lan ip
use std::sync::Arc;
use crate::{
common::{
constants::{self, DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC},
error::Error,
global_ctx::ArcGlobalCtx,
network::IPCollector,
},
peers::{peer_manager::PeerManager, peer_rpc::PeerRpcManager, PeerId},
};
use easytier_rpc::{peer::GetIpListResponse, PeerConnInfo};
use tokio::{task::JoinSet, time::timeout};
use tracing::Instrument;
use super::create_connector_by_url;
// rpc used by the direct connector: ask a remote peer for the list of its
// local/public ip addresses so we can try to dial it directly.
#[tarpc::service]
pub trait DirectConnectorRpc {
    async fn get_ip_list() -> GetIpListResponse;
}
// The subset of PeerManager functionality the direct connector needs; kept as
// a trait so tests can substitute a mock.
#[async_trait::async_trait]
pub trait PeerManagerForDirectConnector {
    async fn list_peers(&self) -> Vec<PeerId>;
    async fn list_peer_conns(&self, peer_id: &PeerId) -> Option<Vec<PeerConnInfo>>;
    fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager>;
}
#[async_trait::async_trait]
impl PeerManagerForDirectConnector for PeerManager {
    /// All peer ids currently present in the route table. Route entries carry
    /// peer ids as strings; parsing is expected to succeed (panics otherwise,
    /// matching the previous behavior).
    async fn list_peers(&self) -> Vec<PeerId> {
        self.list_routes()
            .await
            .iter()
            .map(|route| route.peer_id.parse().unwrap())
            .collect()
    }

    async fn list_peer_conns(&self, peer_id: &PeerId) -> Option<Vec<PeerConnInfo>> {
        self.get_peer_map().list_peer_conns(peer_id).await
    }

    /// Delegates to the inherent `PeerManager::get_peer_rpc_mgr`.
    fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
        self.get_peer_rpc_mgr()
    }
}
// Server side of DirectConnectorRpc: answers get_ip_list with our cached list.
#[derive(Clone)]
struct DirectConnectorManagerRpcServer {
    // TODO: this only cache for one src peer, should make it global
    ip_list_collector: Arc<IPCollector>,
}
#[tarpc::server]
impl DirectConnectorRpc for DirectConnectorManagerRpcServer {
    /// Returns this node's collected ip list (cached by IPCollector).
    async fn get_ip_list(self, _: tarpc::context::Context) -> GetIpListResponse {
        return self.ip_list_collector.collect_ip_addrs().await;
    }
}
impl DirectConnectorManagerRpcServer {
    /// Wraps the shared ip collector for use by the rpc service.
    pub fn new(ip_collector: Arc<IPCollector>) -> Self {
        Self {
            ip_list_collector: ip_collector,
        }
    }
}
// Blacklist key: (destination peer, address string) pairs we recently failed
// to connect to, so they are not retried until the entry expires.
#[derive(Hash, Eq, PartialEq, Clone)]
struct DstBlackListItem(PeerId, String);
// Shared state of the direct connector, handed to the spawned tasks.
struct DirectConnectorManagerData {
    global_ctx: ArcGlobalCtx,
    peer_manager: Arc<PeerManager>,
    // Time-bounded blacklist of (peer, addr) combinations that recently failed.
    dst_blacklist: timedmap::TimedMap<DstBlackListItem, ()>,
}
impl std::fmt::Debug for DirectConnectorManagerData {
    // Manual impl because the contained types are not all Debug; only the
    // peer manager is included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DirectConnectorManagerData")
            .field("peer_manager", &self.peer_manager)
            .finish()
    }
}
// Tries to establish direct (non-relayed) connections to peers, using the ip
// lists they report over rpc.
pub struct DirectConnectorManager {
    my_node_id: uuid::Uuid,
    global_ctx: ArcGlobalCtx,
    data: Arc<DirectConnectorManagerData>,
    // Owns the client-side background loop (aborted on drop).
    tasks: JoinSet<()>,
}
impl DirectConnectorManager {
    /// Creates a manager; nothing runs until `run`/`run_as_*` is called.
    pub fn new(
        my_node_id: uuid::Uuid,
        global_ctx: ArcGlobalCtx,
        peer_manager: Arc<PeerManager>,
    ) -> Self {
        Self {
            my_node_id,
            global_ctx: global_ctx.clone(),
            data: Arc::new(DirectConnectorManagerData {
                global_ctx,
                peer_manager,
                dst_blacklist: timedmap::TimedMap::new(),
            }),
            tasks: JoinSet::new(),
        }
    }

    /// Starts both the server (answer ip-list rpcs) and client (dial peers)
    /// roles.
    pub fn run(&mut self) {
        self.run_as_server();
        self.run_as_client();
    }

    /// Registers the rpc service that reports our local ip list to peers.
    pub fn run_as_server(&mut self) {
        self.data.peer_manager.get_peer_rpc_mgr().run_service(
            constants::DIRECT_CONNECTOR_SERVICE_ID,
            DirectConnectorManagerRpcServer::new(self.global_ctx.get_ip_collector()).serve(),
        );
    }

    /// Spawns the loop that, every 5 seconds, tries to establish a direct
    /// connection to every known peer except ourselves.
    pub fn run_as_client(&mut self) {
        let data = self.data.clone();
        // uuid::Uuid is Copy; no clone needed.
        let my_node_id = self.my_node_id;
        self.tasks.spawn(
            async move {
                loop {
                    let peers = data.peer_manager.list_peers().await;
                    let mut tasks = JoinSet::new();
                    for peer_id in peers {
                        if peer_id == my_node_id {
                            continue;
                        }
                        tasks.spawn(Self::do_try_direct_connect(data.clone(), peer_id));
                    }
                    while let Some(task_ret) = tasks.join_next().await {
                        tracing::trace!(?task_ret, "direct connect task ret");
                    }
                    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
                }
            }
            .instrument(tracing::info_span!("direct_connector_client", my_id = ?self.my_node_id)),
        );
    }

    /// One direct-connection attempt to `addr`. Honors the blacklist, enforces
    /// a 5s timeout, and verifies the connected peer is actually the one we
    /// intended to reach (closing the connection otherwise).
    async fn do_try_connect_to_ip(
        data: Arc<DirectConnectorManagerData>,
        dst_peer_id: PeerId,
        addr: String,
    ) -> Result<(), Error> {
        data.dst_blacklist.cleanup();
        if data
            .dst_blacklist
            .contains(&DstBlackListItem(dst_peer_id.clone(), addr.clone()))
        {
            tracing::trace!("try_connect_to_ip failed, addr in blacklist: {}", addr);
            return Err(Error::UrlInBlacklist);
        }

        let connector = create_connector_by_url(&addr, data.global_ctx.get_ip_collector()).await?;
        let (peer_id, conn_id) = timeout(
            std::time::Duration::from_secs(5),
            data.peer_manager.try_connect(connector),
        )
        .await??;

        if peer_id != dst_peer_id {
            // Reached some node at that address, but not the peer we wanted;
            // tear the connection down again.
            tracing::info!(
                "connect to ip succ: {}, but peer id mismatch, expect: {}, actual: {}",
                addr,
                dst_peer_id,
                peer_id
            );
            data.peer_manager
                .get_peer_map()
                .close_peer_conn(&peer_id, &conn_id)
                .await?;
            return Err(Error::InvalidUrl(addr));
        }
        Ok(())
    }

    /// Wrapper around `do_try_connect_to_ip` that records failures (other than
    /// blacklist hits) in the blacklist so they are not retried immediately.
    #[tracing::instrument]
    async fn try_connect_to_ip(
        data: Arc<DirectConnectorManagerData>,
        dst_peer_id: PeerId,
        addr: String,
    ) {
        let ret = Self::do_try_connect_to_ip(data.clone(), dst_peer_id, addr.clone()).await;
        if let Err(e) = ret {
            if !matches!(e, Error::UrlInBlacklist) {
                tracing::info!(
                    "try_connect_to_ip failed: {:?}, peer_id: {}",
                    e,
                    dst_peer_id
                );
                data.dst_blacklist.insert(
                    DstBlackListItem(dst_peer_id.clone(), addr.clone()),
                    (),
                    std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
                );
            }
        } else {
            log::info!("try_connect_to_ip success, peer_id: {}", dst_peer_id);
        }
    }

    /// Fetches the peer's ip list over rpc and races connection attempts to
    /// all of its interface addresses (plus its public ip, if it has one).
    #[tracing::instrument]
    async fn do_try_direct_connect(
        data: Arc<DirectConnectorManagerData>,
        dst_peer_id: PeerId,
    ) -> Result<(), Error> {
        let peer_manager = data.peer_manager.clone();
        // check if we have direct connection with dst_peer_id
        if let Some(c) = peer_manager.list_peer_conns(&dst_peer_id).await {
            // currently if we have any type of direct connection (udp or tcp), we will not try to connect
            if !c.is_empty() {
                return Ok(());
            }
        }

        log::trace!("try direct connect to peer: {}", dst_peer_id);

        let ip_list = peer_manager
            .get_peer_rpc_mgr()
            .do_client_rpc_scoped(1, dst_peer_id, |c| async {
                let client =
                    DirectConnectorRpcClient::new(tarpc::client::Config::default(), c).spawn();
                let ip_list = client.get_ip_list(tarpc::context::current()).await;
                tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list");
                ip_list
            })
            .await?;

        let mut tasks = JoinSet::new();
        // NOTE: port 11010 is the hard-coded tcp listener port — TODO confirm
        // it should come from configuration instead.
        ip_list.interface_ipv4s.iter().for_each(|ip| {
            let addr = format!("tcp://{}:{}", ip, 11010);
            tasks.spawn(Self::try_connect_to_ip(
                data.clone(),
                dst_peer_id.clone(),
                addr,
            ));
        });
        // Only try the public address if the peer actually reported one;
        // otherwise we would attempt to dial the invalid url "tcp://:11010".
        if !ip_list.public_ipv4.is_empty() {
            let addr = format!("tcp://{}:{}", ip_list.public_ipv4, 11010);
            tasks.spawn(Self::try_connect_to_ip(
                data.clone(),
                dst_peer_id.clone(),
                addr,
            ));
        }

        while let Some(ret) = tasks.join_next().await {
            if let Err(e) = ret {
                log::error!("join direct connect task failed: {:?}", e);
            }
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        connector::direct::DirectConnectorManager,
        instance::listeners::ListenerManager,
        peers::tests::{
            connect_peer_manager, create_mock_peer_manager, wait_route_appear,
            wait_route_appear_with_cost,
        },
        tunnels::tcp_tunnel::TcpTunnelListener,
    };

    /// Topology a <-> b <-> c: after running the direct connector (client on
    /// a, server+listener on c), a should reach c with route cost 1, i.e. a
    /// direct connection was established bypassing b.
    #[tokio::test]
    async fn direct_connector_basic_test() {
        let p_a = create_mock_peer_manager().await;
        let p_b = create_mock_peer_manager().await;
        let p_c = create_mock_peer_manager().await;
        connect_peer_manager(p_a.clone(), p_b.clone()).await;
        connect_peer_manager(p_b.clone(), p_c.clone()).await;

        wait_route_appear(p_a.clone(), p_c.my_node_id())
            .await
            .unwrap();

        let mut dm_a =
            DirectConnectorManager::new(p_a.my_node_id(), p_a.get_global_ctx(), p_a.clone());
        let mut dm_c =
            DirectConnectorManager::new(p_c.my_node_id(), p_c.get_global_ctx(), p_c.clone());

        dm_a.run_as_client();
        dm_c.run_as_server();

        // c must actually listen on the port the direct connector dials.
        let mut lis_c = ListenerManager::new(
            p_c.my_node_id(),
            p_c.get_global_ctx().net_ns.clone(),
            p_c.clone(),
        );
        lis_c
            .add_listener(TcpTunnelListener::new(
                "tcp://0.0.0.0:11010".parse().unwrap(),
            ))
            .await
            .unwrap();
        lis_c.run().await.unwrap();

        wait_route_appear_with_cost(p_a.clone(), p_c.my_node_id(), Some(1))
            .await
            .unwrap();
    }
}
+384
View File
@@ -0,0 +1,384 @@
use std::{collections::BTreeSet, sync::Arc};
use dashmap::{DashMap, DashSet};
use easytier_rpc::{
connector_manage_rpc_server::ConnectorManageRpc, Connector, ConnectorStatus,
ListConnectorRequest, ManageConnectorRequest,
};
use tokio::{
sync::{broadcast::Receiver, mpsc, Mutex},
task::JoinSet,
time::timeout,
};
use crate::{
common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
netns::NetNS,
},
connector::set_bind_addr_for_peer_connector,
peers::peer_manager::PeerManager,
tunnels::{Tunnel, TunnelConnector},
use_global_var,
};
use super::create_connector_by_url;
type ConnectorMap = Arc<DashMap<String, Box<dyn TunnelConnector + Send + Sync>>>;
// Outcome of a successful reconnection attempt.
#[derive(Debug, Clone)]
struct ReconnResult {
    // Url of the connector that was reconnected.
    dead_url: String,
    peer_id: uuid::Uuid,
    conn_id: uuid::Uuid,
}
// Shared state of the manual connector manager.
struct ConnectorManagerData {
    // All configured connectors, keyed by remote url.
    connectors: ConnectorMap,
    // Urls currently being reconnected (temporarily removed from `connectors`).
    reconnecting: DashSet<String>,
    peer_manager: Arc<PeerManager>,
    // Urls that currently have a live peer connection (kept via ctx events).
    alive_conn_urls: Arc<Mutex<BTreeSet<String>>>,
    // user removed connector urls
    removed_conn_urls: Arc<DashSet<String>>,
    net_ns: NetNS,
    global_ctx: ArcGlobalCtx,
}
// Manages user-configured connectors: keeps them connected, reconnects dead
// ones, and exposes add/remove/list operations.
pub struct ManualConnectorManager {
    my_node_id: uuid::Uuid,
    global_ctx: ArcGlobalCtx,
    data: Arc<ConnectorManagerData>,
    // Owns the background management routine (aborted on drop).
    tasks: JoinSet<()>,
}
impl ManualConnectorManager {
    /// Creates the manager and immediately spawns the management routine,
    /// which listens for connection events and periodically reconnects dead
    /// connectors.
    pub fn new(
        my_node_id: uuid::Uuid,
        global_ctx: ArcGlobalCtx,
        peer_manager: Arc<PeerManager>,
    ) -> Self {
        let connectors = Arc::new(DashMap::new());
        let tasks = JoinSet::new();
        let event_subscriber = global_ctx.subscribe();

        let mut ret = Self {
            my_node_id,
            global_ctx: global_ctx.clone(),
            data: Arc::new(ConnectorManagerData {
                connectors,
                reconnecting: DashSet::new(),
                peer_manager,
                alive_conn_urls: Arc::new(Mutex::new(BTreeSet::new())),
                removed_conn_urls: Arc::new(DashSet::new()),
                net_ns: global_ctx.net_ns.clone(),
                global_ctx,
            }),
            tasks,
        };

        ret.tasks
            .spawn(Self::conn_mgr_routine(ret.data.clone(), event_subscriber));

        ret
    }

    /// Registers a connector, keyed by its remote url.
    pub fn add_connector<T>(&self, connector: T)
    where
        T: TunnelConnector + Send + Sync + 'static,
    {
        log::info!("add_connector: {}", connector.remote_url());
        self.data
            .connectors
            .insert(connector.remote_url().into(), Box::new(connector));
    }

    /// Parses `url` and registers the resulting connector.
    pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> {
        self.add_connector(create_connector_by_url(url, self.global_ctx.get_ip_collector()).await?);
        Ok(())
    }

    /// Marks the connector for removal; the actual removal happens in the
    /// management routine (immediately, or after a pending reconnect ends).
    pub async fn remove_connector(&self, url: &str) -> Result<(), Error> {
        log::info!("remove_connector: {}", url);
        if !self.list_connectors().await.iter().any(|x| x.url == url) {
            return Err(Error::NotFound);
        }
        self.data.removed_conn_urls.insert(url.into());
        Ok(())
    }

    /// Lists all connectors with their status. Reconnecting entries are
    /// prepended, so they appear first in the result.
    pub async fn list_connectors(&self) -> Vec<Connector> {
        let conn_urls: BTreeSet<String> = self
            .data
            .connectors
            .iter()
            .map(|x| x.key().clone().into())
            .collect();

        let dead_urls: BTreeSet<String> = Self::collect_dead_conns(self.data.clone())
            .await
            .into_iter()
            .collect();

        let mut ret = Vec::new();

        for conn_url in conn_urls {
            let mut status = ConnectorStatus::Connected;
            if dead_urls.contains(&conn_url) {
                status = ConnectorStatus::Disconnected;
            }
            ret.insert(
                0,
                Connector {
                    url: conn_url,
                    status: status.into(),
                },
            );
        }

        let reconnecting_urls: BTreeSet<String> = self
            .data
            .reconnecting
            .iter()
            .map(|x| x.clone().into())
            .collect();

        for conn_url in reconnecting_urls {
            ret.insert(
                0,
                Connector {
                    url: conn_url,
                    status: ConnectorStatus::Connecting.into(),
                },
            );
        }

        ret
    }

    /// Main management loop: reacts to global connection events, periodically
    /// scans for dead connectors and spawns reconnect tasks for them, and
    /// reaps finished reconnect tasks.
    async fn conn_mgr_routine(
        data: Arc<ConnectorManagerData>,
        mut event_recv: Receiver<GlobalCtxEvent>,
    ) {
        log::warn!("conn_mgr_routine started");
        let mut reconn_interval = tokio::time::interval(std::time::Duration::from_millis(
            use_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS),
        ));
        let mut reconn_tasks = JoinSet::new();
        let (reconn_result_send, mut reconn_result_recv) = mpsc::channel(100);

        loop {
            tokio::select! {
                event = event_recv.recv() => {
                    if let Ok(event) = event {
                        Self::handle_event(&event, data.clone()).await;
                    } else {
                        // The event channel closing means the global ctx is
                        // gone; the routine cannot function without it.
                        log::warn!("event_recv closed");
                        panic!("event_recv closed");
                    }
                }

                _ = reconn_interval.tick() => {
                    let dead_urls = Self::collect_dead_conns(data.clone()).await;
                    if dead_urls.is_empty() {
                        continue;
                    }
                    for dead_url in dead_urls {
                        let data_clone = data.clone();
                        let sender = reconn_result_send.clone();
                        // Move the connector out of the map while it is being
                        // reconnected; conn_reconnect puts it back when done.
                        let (_, connector) = data.connectors.remove(&dead_url).unwrap();
                        let insert_succ = data.reconnecting.insert(dead_url.clone());
                        assert!(insert_succ);

                        reconn_tasks.spawn(async move {
                            sender.send(Self::conn_reconnect(data_clone.clone(), dead_url, connector).await).await.unwrap();
                        });
                    }
                    log::info!("reconn_interval tick, done");
                }

                ret = reconn_result_recv.recv() => {
                    log::warn!("reconn_tasks done, out: {:?}", ret);
                    // One result received => one task finished; reap it.
                    let _ = reconn_tasks.join_next().await.unwrap();
                }
            }
        }
    }

    /// Keeps `alive_conn_urls` in sync with peer connection add/remove events.
    async fn handle_event(event: &GlobalCtxEvent, data: Arc<ConnectorManagerData>) {
        match event {
            GlobalCtxEvent::PeerConnAdded(conn_info) => {
                let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
                data.alive_conn_urls.lock().await.insert(addr);
                log::warn!("peer conn added: {:?}", conn_info);
            }

            GlobalCtxEvent::PeerConnRemoved(conn_info) => {
                let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
                data.alive_conn_urls.lock().await.remove(&addr);
                log::warn!("peer conn removed: {:?}", conn_info);
            }

            // NOTE(review): these arms panic at runtime if such an event is
            // ever delivered — confirm they are unreachable or implement them.
            GlobalCtxEvent::PeerAdded => todo!(),
            GlobalCtxEvent::PeerRemoved => todo!(),
        }
    }

    /// Applies pending user removals. Urls whose connector is mid-reconnect
    /// are kept in `removed_conn_urls` to be retried on the next scan.
    fn handle_remove_connector(data: Arc<ConnectorManagerData>) {
        let remove_later = DashSet::new();
        for it in data.removed_conn_urls.iter() {
            let url = it.key();
            if let Some(_) = data.connectors.remove(url) {
                log::warn!("connector: {}, removed", url);
                continue;
            } else if data.reconnecting.contains(url) {
                log::warn!("connector: {}, reconnecting, remove later.", url);
                remove_later.insert(url.clone());
                continue;
            } else {
                log::warn!("connector: {}, not found", url);
            }
        }
        data.removed_conn_urls.clear();
        for it in remove_later.iter() {
            data.removed_conn_urls.insert(it.key().clone());
        }
    }

    /// Returns the urls of registered connectors that currently have no live
    /// peer connection. Also processes pending removals first.
    async fn collect_dead_conns(data: Arc<ConnectorManagerData>) -> BTreeSet<String> {
        Self::handle_remove_connector(data.clone());

        let curr_alive = data.alive_conn_urls.lock().await.clone();
        let all_urls: BTreeSet<String> = data
            .connectors
            .iter()
            .map(|x| x.key().clone().into())
            .collect();
        &all_urls - &curr_alive
    }

    /// Attempts one reconnect of `connector` (with a 1s overall timeout) and
    /// then returns the connector to the map and clears the reconnecting flag
    /// regardless of the outcome.
    async fn conn_reconnect(
        data: Arc<ConnectorManagerData>,
        dead_url: String,
        connector: Box<dyn TunnelConnector + Send + Sync>,
    ) -> Result<ReconnResult, Error> {
        // Wrapped in Arc<Mutex<Option<..>>> so the connector can be taken
        // back out even if the attempt times out mid-flight.
        let connector = Arc::new(Mutex::new(Some(connector)));
        let net_ns = data.net_ns.clone();
        log::info!("reconnect: {}", dead_url);

        let connector_clone = connector.clone();
        let data_clone = data.clone();
        let url_clone = dead_url.clone();
        let ip_collector = data.global_ctx.get_ip_collector();
        let reconn_task = async move {
            let mut locked = connector_clone.lock().await;
            let conn = locked.as_mut().unwrap();
            // TODO: should support set v6 here, use url in connector array
            set_bind_addr_for_peer_connector(conn, true, &ip_collector).await;

            // Connect inside the configured network namespace.
            let _g = net_ns.guard();
            log::info!("reconnect try connect... conn: {:?}", conn);
            let tunnel = conn.connect().await?;
            log::info!("reconnect get tunnel succ: {:?}", tunnel);
            assert_eq!(
                url_clone,
                tunnel.info().unwrap().remote_addr,
                "info: {:?}",
                tunnel.info()
            );
            let (peer_id, conn_id) = data_clone.peer_manager.add_client_tunnel(tunnel).await?;
            log::info!("reconnect succ: {} {} {}", peer_id, conn_id, url_clone);
            Ok(ReconnResult {
                dead_url: url_clone,
                peer_id,
                conn_id,
            })
        };

        // NOTE(review): 1 second is a tight budget for dns + tcp handshake —
        // confirm this is intentional.
        let ret = timeout(std::time::Duration::from_secs(1), reconn_task).await;
        log::info!("reconnect: {} done, ret: {:?}", dead_url, ret);

        // Always restore the connector and clear the reconnecting flag.
        let conn = connector.lock().await.take().unwrap();
        data.reconnecting.remove(&dead_url).unwrap();
        data.connectors.insert(dead_url.clone(), conn);

        ret?
    }
}
pub struct ConnectorManagerRpcService(pub Arc<ManualConnectorManager>);
#[tonic::async_trait]
impl ConnectorManageRpc for ConnectorManagerRpcService {
    /// Lists all connectors with their connected/connecting/disconnected
    /// status.
    async fn list_connector(
        &self,
        _request: tonic::Request<ListConnectorRequest>,
    ) -> Result<tonic::Response<easytier_rpc::ListConnectorResponse>, tonic::Status> {
        let mut ret = easytier_rpc::ListConnectorResponse::default();
        let connectors = self.0.list_connectors().await;
        ret.connectors = connectors;
        Ok(tonic::Response::new(ret))
    }

    /// Adds or removes a connector according to `req.action`.
    async fn manage_connector(
        &self,
        request: tonic::Request<ManageConnectorRequest>,
    ) -> Result<tonic::Response<easytier_rpc::ManageConnectorResponse>, tonic::Status> {
        let req = request.into_inner();
        let url = url::Url::parse(&req.url)
            .map_err(|_| tonic::Status::invalid_argument("invalid url"))?;
        if req.action == easytier_rpc::ConnectorManageAction::Remove as i32 {
            // NOTE(review): `url.path()` is empty ("") for urls like
            // "tcp://1.2.3.4:11010", while connectors are keyed by the full
            // url string — verify whether this should pass `url.as_str()`
            // (or `req.url`) instead.
            self.0.remove_connector(url.path()).await.map_err(|e| {
                tonic::Status::invalid_argument(format!("remove connector failed: {:?}", e))
            })?;
            return Ok(tonic::Response::new(
                easytier_rpc::ManageConnectorResponse::default(),
            ));
        } else {
            self.0
                .add_connector_by_url(url.as_str())
                .await
                .map_err(|e| {
                    tonic::Status::invalid_argument(format!("add connector failed: {:?}", e))
                })?;
        }
        Ok(tonic::Response::new(
            easytier_rpc::ManageConnectorResponse::default(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        peers::tests::create_mock_peer_manager,
        set_global_var,
        tunnels::{Tunnel, TunnelError},
    };

    use super::*;

    /// Regression test: a connector that is stuck in `connect()` must not
    /// wedge or panic the reconnect routine (runs the manager for 5s against
    /// an always-failing mock connector; success is "no panic/deadlock").
    #[tokio::test]
    async fn test_reconnect_with_connecting_addr() {
        set_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS, 1);

        let peer_mgr = create_mock_peer_manager().await;
        let my_node_id = uuid::Uuid::new_v4();

        let mgr = ManualConnectorManager::new(my_node_id, peer_mgr.get_global_ctx(), peer_mgr);

        // Connector that always fails after a short delay.
        struct MockConnector {}
        #[async_trait::async_trait]
        impl TunnelConnector for MockConnector {
            fn remote_url(&self) -> url::Url {
                url::Url::parse("tcp://aa.com").unwrap()
            }
            async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                Err(TunnelError::CommonError("fake error".into()))
            }
        }

        mgr.add_connector(MockConnector {});

        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    }
}
+73
View File
@@ -0,0 +1,73 @@
use std::{
net::{SocketAddr, SocketAddrV4, SocketAddrV6},
sync::Arc,
};
use crate::{
common::{error::Error, network::IPCollector},
tunnels::{
ring_tunnel::RingTunnelConnector, tcp_tunnel::TcpTunnelConnector,
udp_tunnel::UdpTunnelConnector, TunnelConnector,
},
};
pub mod direct;
pub mod manual;
pub mod udp_hole_punch;
/// Pre-sets the connector's bind addresses to every local interface address of
/// the requested ip family, so outgoing connections can originate from any
/// candidate address.
///
/// The address strings come from the ip collector and were produced from real
/// interface addresses, so parsing them is expected to succeed (panics
/// otherwise, matching the previous behavior).
async fn set_bind_addr_for_peer_connector(
    connector: &mut impl TunnelConnector,
    is_ipv4: bool,
    ip_collector: &Arc<IPCollector>,
) {
    let ips = ip_collector.collect_ip_addrs().await;
    let mut bind_addrs = vec![];
    if is_ipv4 {
        for ipv4 in ips.interface_ipv4s {
            bind_addrs.push(SocketAddrV4::new(ipv4.parse().unwrap(), 0).into());
        }
    } else {
        for ipv6 in ips.interface_ipv6s {
            bind_addrs.push(SocketAddrV6::new(ipv6.parse().unwrap(), 0, 0, 0).into());
        }
    }
    // Single call site instead of one per branch; the dead `let _ = connector;`
    // that used to follow has been removed.
    connector.set_bind_addrs(bind_addrs);
}
pub async fn create_connector_by_url(
url: &str,
ip_collector: Arc<IPCollector>,
) -> Result<Box<dyn TunnelConnector + Send + Sync + 'static>, Error> {
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
match url.scheme() {
"tcp" => {
let dst_addr =
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp")?;
let mut connector = TcpTunnelConnector::new(url);
set_bind_addr_for_peer_connector(&mut connector, dst_addr.is_ipv4(), &ip_collector)
.await;
return Ok(Box::new(connector));
}
"udp" => {
let dst_addr =
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp")?;
let mut connector = UdpTunnelConnector::new(url);
set_bind_addr_for_peer_connector(&mut connector, dst_addr.is_ipv4(), &ip_collector)
.await;
return Ok(Box::new(connector));
}
"ring" => {
crate::tunnels::check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring")?;
let connector = RingTunnelConnector::new(url);
return Ok(Box::new(connector));
}
_ => {
return Err(Error::InvalidUrl(url.into()));
}
}
}
@@ -0,0 +1,523 @@
use std::{net::SocketAddr, sync::Arc};
use anyhow::Context;
use crossbeam::atomic::AtomicCell;
use easytier_rpc::NatType;
use rand::{seq::SliceRandom, Rng, SeedableRng};
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use tracing::Instrument;
use crate::{
common::{
constants, error::Error, global_ctx::ArcGlobalCtx, rkyv_util::encode_to_bytes,
stun::StunInfoCollectorTrait,
},
peers::{peer_manager::PeerManager, PeerId},
tunnels::{
udp_tunnel::{UdpPacket, UdpTunnelConnector, UdpTunnelListener},
Tunnel, TunnelConnCounter, TunnelListener,
},
};
use super::direct::PeerManagerForDirectConnector;
/// Peer-to-peer RPC used to coordinate UDP hole punching.
#[tarpc::service]
pub trait UdpHolePunchService {
    /// Ask the remote node to prepare a rendezvous listener and (when its NAT
    /// is a restricted type) start sending punch packets to `local_mapped_addr`.
    /// Returns the listener's publicly mapped address, or `None` on failure.
    async fn try_punch_hole(local_mapped_addr: SocketAddr) -> Option<SocketAddr>;
}
// One UDP listener used as a hole-punching rendezvous point on this node.
#[derive(Debug)]
struct UdpHolePunchListener {
    socket: Arc<UdpSocket>,
    // background accept task
    tasks: JoinSet<()>,
    // set to false once the accept loop exits
    running: Arc<AtomicCell<bool>>,
    // public address of `socket` as observed via STUN
    mapped_addr: SocketAddr,
    // counts live tunnel connections accepted by this listener
    conn_counter: Arc<Box<dyn TunnelConnCounter>>,
    listen_time: std::time::Instant,
    // last time this listener was handed out by select_listener()
    last_select_time: AtomicCell<std::time::Instant>,
    // last time a peer connection was accepted
    last_connected_time: Arc<AtomicCell<std::time::Instant>>,
}
impl UdpHolePunchListener {
    // Bind an ephemeral UDP socket just to discover a free local port.
    // NOTE(review): the port is released before it is re-bound by the tunnel
    // listener below, so another process could in principle grab it in between.
    async fn get_avail_port() -> Result<u16, Error> {
        let socket = UdpSocket::bind("0.0.0.0:0").await?;
        Ok(socket.local_addr()?.port())
    }
    /// Start a UDP tunnel listener on a free port, resolve its public mapping
    /// via STUN, and register every accepted connection with `peer_mgr` as a
    /// server-side tunnel.
    pub async fn new(peer_mgr: Arc<PeerManager>) -> Result<Self, Error> {
        let port = Self::get_avail_port().await?;
        let listen_url = format!("udp://0.0.0.0:{}", port);
        let gctx = peer_mgr.get_global_ctx();
        let stun_info_collect = gctx.get_stun_info_collector();
        // learn the public (NAT-mapped) address of the chosen port
        let mapped_addr = stun_info_collect.get_udp_port_mapping(port).await?;
        let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap());
        listener.listen().await?;
        let socket = listener.get_socket().unwrap();
        let running = Arc::new(AtomicCell::new(true));
        let running_clone = running.clone();
        let last_connected_time = Arc::new(AtomicCell::new(std::time::Instant::now()));
        let last_connected_time_clone = last_connected_time.clone();
        let conn_counter = listener.get_conn_counter();
        let mut tasks = JoinSet::new();
        // accept loop: every hole-punched connection becomes a server tunnel
        tasks.spawn(async move {
            while let Ok(conn) = listener.accept().await {
                last_connected_time_clone.store(std::time::Instant::now());
                tracing::warn!(?conn, "udp hole punching listener got peer connection");
                if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
                    tracing::error!(?e, "failed to add tunnel as server in hole punch listener");
                }
            }
            // accept loop ended; mark the listener dead
            running_clone.store(false);
        });
        tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started");
        Ok(Self {
            tasks,
            socket,
            running,
            mapped_addr,
            conn_counter,
            listen_time: std::time::Instant::now(),
            last_select_time: AtomicCell::new(std::time::Instant::now()),
            last_connected_time,
        })
    }
    // Hand out the socket for punching and remember when it was last selected.
    pub async fn get_socket(&self) -> Arc<UdpSocket> {
        self.last_select_time.store(std::time::Instant::now());
        self.socket.clone()
    }
}
// State shared between the hole-punch RPC server and the client loop.
#[derive(Debug)]
struct UdpHolePunchConnectorData {
    global_ctx: ArcGlobalCtx,
    peer_mgr: Arc<PeerManager>,
    // rendezvous listeners, created on demand by the RPC server
    listeners: Arc<Mutex<Vec<UdpHolePunchListener>>>,
}
// Server-side RPC handle; cheap to clone, shares the connector state.
#[derive(Clone)]
struct UdpHolePunchRpcServer {
    data: Arc<UdpHolePunchConnectorData>,
    // holds the punch-packet sender tasks spawned per request
    tasks: Arc<Mutex<JoinSet<()>>>,
}
#[tarpc::server]
impl UdpHolePunchService for UdpHolePunchRpcServer {
    // Server side of a punch request: pick/create a rendezvous listener,
    // optionally fire punch packets towards the caller's mapped address, and
    // return the listener's public address for the caller to connect to.
    async fn try_punch_hole(
        self,
        _: tarpc::context::Context,
        local_mapped_addr: SocketAddr,
    ) -> Option<SocketAddr> {
        let (socket, mapped_addr) = self.select_listener().await?;
        tracing::warn!(?local_mapped_addr, ?mapped_addr, "start hole punching");
        let my_udp_nat_type = self
            .data
            .global_ctx
            .get_stun_info_collector()
            .get_stun_info()
            .udp_nat_type;
        // if we are restricted, we need to send hole punching resp to client
        if my_udp_nat_type == NatType::PortRestricted as i32
            || my_udp_nat_type == NatType::Restricted as i32
        {
            // send punch packets to local_mapped_addr: 10 packets, one every
            // 300 ms (~3 seconds total)
            self.tasks.lock().await.spawn(async move {
                for _ in 0..10 {
                    tracing::info!(?local_mapped_addr, "sending hole punching packet");
                    // generate a 128 bytes vec with random data
                    let mut rng = rand::rngs::StdRng::from_entropy();
                    let mut buf = vec![0u8; 128];
                    rng.fill(&mut buf[..]);
                    let udp_packet = UdpPacket::new_hole_punch_packet(buf);
                    let udp_packet_bytes = encode_to_bytes::<_, 256>(&udp_packet);
                    // best-effort: failures are ignored, the next packet may succeed
                    let _ = socket
                        .send_to(udp_packet_bytes.as_ref(), local_mapped_addr)
                        .await;
                    tokio::time::sleep(std::time::Duration::from_millis(300)).await;
                }
            });
        }
        Some(mapped_addr)
    }
}
impl UdpHolePunchRpcServer {
    pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
        Self {
            data,
            tasks: Arc::new(Mutex::new(JoinSet::new())),
        }
    }
    /// Pick (or lazily create) a hole-punch listener and return its socket
    /// together with its publicly mapped address.
    ///
    /// Fixes vs. original: the pruning predicate used `&&`, which evicted any
    /// listener with zero live connections — including freshly created ones —
    /// contradicting the stated intent ("remove listeners with no connection
    /// for 20 seconds"); a listener is now kept if it accepted a connection
    /// recently OR still serves live connections. The vector is also locked
    /// once for the whole check/create/select sequence, so concurrent calls
    /// can no longer interleave between those steps.
    async fn select_listener(&self) -> Option<(Arc<UdpSocket>, SocketAddr)> {
        let mut listeners = self.data.listeners.lock().await;
        // drop listeners that have been idle for 20s and serve no connections
        listeners.retain(|listener| {
            listener.last_connected_time.load().elapsed().as_secs() < 20
                || listener.conn_counter.get() > 0
        });
        // keep a small pool of listeners; create a new one when below target
        let mut use_last = false;
        if listeners.len() < 4 {
            tracing::warn!("creating new udp hole punching listener");
            listeners.push(
                UdpHolePunchListener::new(self.data.peer_mgr.clone())
                    .await
                    .ok()?,
            );
            use_last = true;
        }
        // prefer the freshly created listener, otherwise pick one at random
        let listener = if use_last {
            listeners.last()?
        } else {
            listeners.choose(&mut rand::rngs::StdRng::from_entropy())?
        };
        Some((listener.get_socket().await, listener.mapped_addr))
    }
}
/// Establishes direct UDP tunnels between NAT-ed peers via hole punching.
pub struct UdpHolePunchConnector {
    data: Arc<UdpHolePunchConnectorData>,
    // background client main-loop task
    tasks: JoinSet<()>,
}
// Currently support:
// Symmetric -> Full Cone
// Any Type of Full Cone -> Any Type of Full Cone
// if same level of full cone, node with smaller peer_id will be the initiator
// if different level of full cone, node with more strict level will be the initiator
impl UdpHolePunchConnector {
    /// Create a connector bound to the given global context and peer manager.
    pub fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc<PeerManager>) -> Self {
        Self {
            data: Arc::new(UdpHolePunchConnectorData {
                global_ctx,
                peer_mgr,
                listeners: Arc::new(Mutex::new(Vec::new())),
            }),
            tasks: JoinSet::new(),
        }
    }
    /// Start the client role: a background loop that finds punchable peers
    /// and initiates hole punching towards them.
    pub async fn run_as_client(&mut self) -> Result<(), Error> {
        let data = self.data.clone();
        self.tasks.spawn(async move {
            Self::main_loop(data).await;
        });
        Ok(())
    }
    /// Start the server role: register the hole-punch RPC service so peers
    /// can ask us to open rendezvous listeners.
    pub async fn run_as_server(&mut self) -> Result<(), Error> {
        self.data.peer_mgr.get_peer_rpc_mgr().run_service(
            constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
            UdpHolePunchRpcServer::new(self.data.clone()).serve(),
        );
        Ok(())
    }
    /// Run both client and server roles.
    pub async fn run(&mut self) -> Result<(), Error> {
        self.run_as_client().await?;
        self.run_as_server().await?;
        Ok(())
    }
    /// Collect ids of peers we should actively try to hole punch.
    async fn collect_peer_to_connect(data: Arc<UdpHolePunchConnectorData>) -> Vec<PeerId> {
        let mut peers_to_connect = Vec::new();
        // do not do anything if:
        // 1. our stun test has not finished
        // 2. our nat type is OpenInternet or NoPat, which means we can wait for other peers to connect us
        let my_nat_type = data
            .global_ctx
            .get_stun_info_collector()
            .get_stun_info()
            .udp_nat_type;
        let my_nat_type = NatType::try_from(my_nat_type).unwrap();
        if my_nat_type == NatType::Unknown
            || my_nat_type == NatType::OpenInternet
            || my_nat_type == NatType::NoPat
        {
            return peers_to_connect;
        }
        // collect peer list from peer manager and do some filter:
        // 1. peers without direct conns;
        // 2. peers is full cone (any restricted type);
        for route in data.peer_mgr.list_routes().await.iter() {
            let Some(peer_stun_info) = route.stun_info.as_ref() else {
                continue;
            };
            let Ok(peer_nat_type) = NatType::try_from(peer_stun_info.udp_nat_type) else {
                continue;
            };
            let peer_id: PeerId = route.peer_id.parse().unwrap();
            // skip peers we already have a direct connection to
            let conns = data.peer_mgr.list_peer_conns(&peer_id).await;
            if conns.map_or(false, |c| !c.is_empty()) {
                continue;
            }
            // if peer is symmetric, ignore it because we cannot connect to it
            // if peer is open internet or no pat, direct connector will connect to it
            if peer_nat_type == NatType::Unknown
                || peer_nat_type == NatType::OpenInternet
                || peer_nat_type == NatType::NoPat
                || peer_nat_type == NatType::Symmetric
                || peer_nat_type == NatType::SymUdpFirewall
            {
                continue;
            }
            // if we are symmetric, we can only connect to full cone
            // TODO: can also connect to restricted full cone, with some extra work
            if (my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall)
                && peer_nat_type != NatType::FullCone
            {
                continue;
            }
            // if we have the same level of full cone, the node with the smaller peer_id is the initiator
            if my_nat_type == peer_nat_type {
                if data.global_ctx.id > peer_id {
                    continue;
                }
            } else {
                // if we have different levels of full cone,
                // we are the initiator only if we have the more strict level
                if my_nat_type < peer_nat_type {
                    continue;
                }
            }
            tracing::info!(
                ?peer_id,
                ?peer_nat_type,
                ?my_nat_type,
                ?data.global_ctx.id,
                "found peer to do hole punching"
            );
            peers_to_connect.push(peer_id);
        }
        peers_to_connect
    }
    /// Punch a UDP hole towards `dst_peer_id` and return the resulting tunnel.
    #[tracing::instrument]
    async fn do_hole_punching(
        data: Arc<UdpHolePunchConnectorData>,
        dst_peer_id: PeerId,
    ) -> Result<Box<dyn Tunnel>, anyhow::Error> {
        tracing::info!(?dst_peer_id, "start hole punching");
        // client: choose a local udp port, and get the public mapped port from the stun server
        let socket = UdpSocket::bind("0.0.0.0:0")
            .await
            .with_context(|| "failed to bind local udp socket")?;
        let local_socket_addr = socket.local_addr()?;
        let local_port = socket.local_addr()?.port();
        drop(socket); // drop the socket to release the port
        let local_mapped_addr = data
            .global_ctx
            .get_stun_info_collector()
            .get_udp_port_mapping(local_port)
            .await
            .with_context(|| "failed to get udp port mapping")?;
        // client -> server: tell server the mapped port, server will return the mapped address of listening port.
        let Some(remote_mapped_addr) = data
            .peer_mgr
            .get_peer_rpc_mgr()
            .do_client_rpc_scoped(
                constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
                dst_peer_id,
                |c| async {
                    let client =
                        UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c).spawn();
                    let remote_mapped_addr = client
                        .try_punch_hole(tarpc::context::current(), local_mapped_addr)
                        .await;
                    tracing::info!(?remote_mapped_addr, ?dst_peer_id, "got remote mapped addr");
                    remote_mapped_addr
                },
            )
            .await?
        else {
            return Err(anyhow::anyhow!("failed to get remote mapped addr"));
        };
        // server: will send some punching resps, total 10 packets.
        // client: use the socket to create UdpTunnel with UdpTunnelConnector
        // NOTICE: UdpTunnelConnector will ignore the punching resp packet sent by remote.
        let connector = UdpTunnelConnector::new(
            format!(
                "udp://{}:{}",
                remote_mapped_addr.ip(),
                remote_mapped_addr.port()
            )
            .parse()
            .unwrap(),
        );
        // rebind the same local port so the NAT mapping established above is reused
        let socket = UdpSocket::bind(local_socket_addr)
            .await
            .with_context(|| "failed to rebind local udp socket")?;
        Ok(connector
            .try_connect_with_socket(socket)
            .await
            .with_context(|| "UdpTunnelConnector failed to connect remote")?)
    }
    /// Background loop: periodically collect punchable peers and punch them
    /// concurrently, then register successful tunnels as client tunnels.
    async fn main_loop(data: Arc<UdpHolePunchConnectorData>) {
        loop {
            let peers_to_connect = Self::collect_peer_to_connect(data.clone()).await;
            tracing::trace!(?peers_to_connect, "peers to connect");
            if peers_to_connect.is_empty() {
                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
                continue;
            }
            let mut tasks: JoinSet<Result<(), anyhow::Error>> = JoinSet::new();
            for peer_id in peers_to_connect {
                let data = data.clone();
                tasks.spawn(
                    async move {
                        let tunnel = Self::do_hole_punching(data.clone(), peer_id)
                            .await
                            .with_context(|| "failed to do hole punching")?;
                        let _ =
                            data.peer_mgr
                                .add_client_tunnel(tunnel)
                                .await
                                .with_context(|| {
                                    "failed to add tunnel as client in hole punch connector"
                                })?;
                        Ok(())
                    }
                    .instrument(tracing::info_span!("doing hole punching client", ?peer_id)),
                );
            }
            while let Some(res) = tasks.join_next().await {
                if let Err(e) = res {
                    // join error (task panicked or was cancelled)
                    tracing::error!(?e, "failed to join hole punching job");
                    continue;
                }
                match res.unwrap() {
                    Err(e) => {
                        tracing::error!(?e, "failed to do hole punching job");
                    }
                    Ok(_) => {
                        tracing::info!("hole punching job succeed");
                    }
                }
            }
            tokio::time::sleep(std::time::Duration::from_secs(10)).await;
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use easytier_rpc::{NatType, StunInfo};
    use crate::{
        common::{error::Error, stun::StunInfoCollectorTrait},
        connector::udp_hole_punch::UdpHolePunchConnector,
        peers::{
            peer_manager::PeerManager,
            tests::{
                connect_peer_manager, create_mock_peer_manager, wait_route_appear,
                wait_route_appear_with_cost,
            },
        },
        tests::enable_log,
    };
    // STUN stub reporting a fixed NAT type and mapping every port to localhost.
    struct MockStunInfoCollector {
        udp_nat_type: NatType,
    }
    #[async_trait::async_trait]
    impl StunInfoCollectorTrait for MockStunInfoCollector {
        fn get_stun_info(&self) -> StunInfo {
            StunInfo {
                udp_nat_type: self.udp_nat_type as i32,
                tcp_nat_type: NatType::Unknown as i32,
                // NOTE(review): Instant::now().elapsed() is ~0, not a wall-clock
                // timestamp — fine for a test stub, but confirm the field's meaning.
                last_update_time: std::time::Instant::now().elapsed().as_secs() as i64,
            }
        }
        async fn get_udp_port_mapping(&self, port: u16) -> Result<std::net::SocketAddr, Error> {
            // pretend the public mapping is simply localhost:port
            Ok(format!("127.0.0.1:{}", port).parse().unwrap())
        }
    }
    // Build a peer manager whose STUN collector is replaced by the mock above.
    async fn create_mock_peer_manager_with_mock_stun(udp_nat_type: NatType) -> Arc<PeerManager> {
        let p_a = create_mock_peer_manager().await;
        let collector = Box::new(MockStunInfoCollector { udp_nat_type });
        p_a.get_global_ctx().replace_stun_info_collector(collector);
        p_a
    }
    // Topology: a (PortRestricted) - b (Symmetric) - c (PortRestricted).
    // After hole punching, a and c should reach each other directly (cost 1).
    #[tokio::test]
    async fn hole_punching() {
        enable_log();
        let p_a = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
        let p_b = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await;
        let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
        connect_peer_manager(p_a.clone(), p_b.clone()).await;
        connect_peer_manager(p_b.clone(), p_c.clone()).await;
        wait_route_appear(p_a.clone(), p_c.my_node_id())
            .await
            .unwrap();
        println!("{:?}", p_a.list_routes().await);
        let mut hole_punching_a = UdpHolePunchConnector::new(p_a.get_global_ctx(), p_a.clone());
        let mut hole_punching_c = UdpHolePunchConnector::new(p_c.get_global_ctx(), p_c.clone());
        hole_punching_a.run().await.unwrap();
        hole_punching_c.run().await.unwrap();
        // a direct (cost-1) route between a and c proves the punch succeeded
        wait_route_appear_with_cost(p_a.clone(), p_c.my_node_id(), Some(1))
            .await
            .unwrap();
        println!("{:?}", p_a.list_routes().await);
    }
}
+301
View File
@@ -0,0 +1,301 @@
use std::{
mem::MaybeUninit,
net::{IpAddr, Ipv4Addr, SocketAddrV4},
sync::Arc,
thread,
};
use pnet::packet::{
icmp::{self, IcmpTypes},
ip::IpNextHeaderProtocols,
ipv4::{self, Ipv4Packet, MutableIpv4Packet},
Packet,
};
use socket2::Socket;
use tokio::{
sync::{mpsc::UnboundedSender, Mutex},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx},
peers::{
packet,
peer_manager::{PeerManager, PeerPacketFilter},
PeerId,
},
};
use super::CidrSet;
// Identifies an outstanding ICMP echo exchange: destination + id + sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct IcmpNatKey {
    dst_ip: std::net::IpAddr,
    icmp_id: u16,
    icmp_seq: u16,
}
// Records where the matching echo reply must be sent back to.
#[derive(Debug)]
struct IcmpNatEntry {
    // peer that originated the echo request
    src_peer_id: PeerId,
    my_peer_id: PeerId,
    // source IP inside the original request packet
    src_ip: IpAddr,
    // creation time, used for table expiry
    start_time: std::time::Instant,
}
impl IcmpNatEntry {
    // NOTE: currently infallible; the Result return is kept so callers can
    // keep using `.ok()?` if validation is added later.
    fn new(src_peer_id: PeerId, my_peer_id: PeerId, src_ip: IpAddr) -> Result<Self, Error> {
        Ok(Self {
            src_peer_id,
            my_peer_id,
            src_ip,
            start_time: std::time::Instant::now(),
        })
    }
}
// Maps an outstanding ICMP echo request (dst ip + id + seq) back to its origin.
type IcmpNatTable = Arc<dashmap::DashMap<IcmpNatKey, IcmpNatEntry>>;
// NOTE(review): these two channel aliases are not referenced in the visible code.
type NewPacketSender = tokio::sync::mpsc::UnboundedSender<IcmpNatKey>;
type NewPacketReceiver = tokio::sync::mpsc::UnboundedReceiver<IcmpNatKey>;
// Answers ICMP echo requests for proxied destinations using a raw socket.
#[derive(Debug)]
pub struct IcmpProxy {
    global_ctx: ArcGlobalCtx,
    peer_manager: Arc<PeerManager>,
    // destinations this node proxies for
    cidr_set: CidrSet,
    // raw ICMPv4 socket used to re-send requests and receive replies
    socket: socket2::Socket,
    // outstanding echo requests awaiting replies
    nat_table: IcmpNatTable,
    tasks: Mutex<JoinSet<()>>,
}
/// Receive a single datagram from the raw socket, returning its length and
/// the sender's IP (IPv4 UNSPECIFIED when no socket address is available).
fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, IpAddr), Error> {
    let (size, addr) = socket.recv_from(buf)?;
    let peer_ip = addr
        .as_socket()
        .map(|sock_addr| sock_addr.ip())
        .unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED));
    Ok((size, peer_ip))
}
// Blocking receive loop for the raw ICMP socket (runs on a dedicated OS
// thread). Matches incoming echo replies against the NAT table and forwards
// each matched reply back to the peer that originated the request.
fn socket_recv_loop(
    socket: Socket,
    nat_table: IcmpNatTable,
    sender: UnboundedSender<packet::Packet>,
) {
    let mut buf = [0u8; 4096];
    // SAFETY: u8 and MaybeUninit<u8> have identical layout, and the buffer is
    // fully initialized (zeroed) before the cast, so exposing it as
    // MaybeUninit for socket2's recv_from API is sound.
    // NOTE(review): the receive starts 12 bytes into the buffer — presumably
    // headroom reserved for an encapsulation header; confirm the intent.
    let data: &mut [MaybeUninit<u8>] = unsafe { std::mem::transmute(&mut buf[12..]) };
    loop {
        let Ok((len, peer_ip)) = socket_recv(&socket, data) else {
            continue;
        };
        // only IPv4 replies are handled
        if !peer_ip.is_ipv4() {
            continue;
        }
        let Some(mut ipv4_packet) = MutableIpv4Packet::new(&mut buf[12..12 + len]) else {
            continue;
        };
        let Some(icmp_packet) = icmp::echo_reply::EchoReplyPacket::new(ipv4_packet.payload())
        else {
            continue;
        };
        if icmp_packet.get_icmp_type() != IcmpTypes::EchoReply {
            continue;
        }
        // match the reply against an outstanding request; unmatched replies
        // are silently dropped
        let key = IcmpNatKey {
            dst_ip: peer_ip,
            icmp_id: icmp_packet.get_identifier(),
            icmp_seq: icmp_packet.get_sequence_number(),
        };
        let Some((_, v)) = nat_table.remove(&key) else {
            continue;
        };
        // send packet back to the peer where this request origin.
        let IpAddr::V4(dest_ip) = v.src_ip else {
            continue;
        };
        // rewrite the destination to the original requester and fix the checksum
        ipv4_packet.set_destination(dest_ip);
        ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
        let peer_packet = packet::Packet::new_data_packet(
            v.my_peer_id,
            v.src_peer_id,
            &ipv4_packet.to_immutable().packet(),
        );
        if let Err(e) = sender.send(peer_packet) {
            // channel closed: the async send loop is gone, stop this thread
            tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
            break;
        }
    }
}
#[async_trait::async_trait]
impl PeerPacketFilter for IcmpProxy {
    // Intercept ICMP echo requests from peers targeting a proxied CIDR,
    // record them in the NAT table, and re-issue them from the raw socket.
    // Returning Some(()) consumes the packet; None passes it on.
    async fn try_process_packet_from_peer(
        &self,
        packet: &packet::ArchivedPacket,
        _: &Bytes,
    ) -> Option<()> {
        // the proxy is only active once this node has a virtual IPv4
        let _ = self.global_ctx.get_ipv4()?;
        let packet::ArchivedPacketBody::Data(x) = &packet.body else {
            return None;
        };
        let ipv4 = Ipv4Packet::new(&x.data)?;
        if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp
        {
            return None;
        }
        if !self.cidr_set.contains_v4(ipv4.get_destination()) {
            return None;
        }
        let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?;
        if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
            // drop it because we do not support other icmp types
            tracing::trace!("unsupported icmp type: {:?}", icmp_packet.get_icmp_type());
            return Some(());
        }
        let icmp_id = icmp_packet.get_identifier();
        let icmp_seq = icmp_packet.get_sequence_number();
        let key = IcmpNatKey {
            dst_ip: ipv4.get_destination().into(),
            icmp_id,
            icmp_seq,
        };
        // a reply cannot be routed back without a destination peer
        if packet.to_peer.is_none() {
            return None;
        }
        // remember where the eventual echo reply must go
        let value = IcmpNatEntry::new(
            packet.from_peer.to_uuid(),
            packet.to_peer.as_ref().unwrap().to_uuid(),
            ipv4.get_source().into(),
        )
        .ok()?;
        if let Some(old) = self.nat_table.insert(key, value) {
            tracing::info!("icmp nat table entry replaced: {:?}", old);
        }
        if let Err(e) = self.send_icmp_packet(ipv4.get_destination(), &icmp_packet) {
            tracing::error!("send icmp packet failed: {:?}", e);
        }
        Some(())
    }
}
impl IcmpProxy {
    /// Create an ICMP proxy that answers echo requests for proxied CIDRs by
    /// re-issuing them from a local raw ICMPv4 socket.
    ///
    /// The raw socket is created and bound while inside the configured
    /// network namespace (requires raw-socket privileges).
    pub fn new(
        global_ctx: ArcGlobalCtx,
        peer_manager: Arc<PeerManager>,
    ) -> Result<Arc<Self>, Error> {
        let cidr_set = CidrSet::new(global_ctx.clone());
        // enter the target network namespace for socket creation/binding
        let _g = global_ctx.net_ns.guard();
        let socket = socket2::Socket::new(
            socket2::Domain::IPV4,
            socket2::Type::RAW,
            Some(socket2::Protocol::ICMPV4),
        )?;
        socket.bind(&socket2::SockAddr::from(SocketAddrV4::new(
            std::net::Ipv4Addr::UNSPECIFIED,
            0,
        )))?;
        let ret = Self {
            global_ctx,
            peer_manager,
            cidr_set,
            socket,
            nat_table: Arc::new(dashmap::DashMap::new()),
            tasks: Mutex::new(JoinSet::new()),
        };
        Ok(Arc::new(ret))
    }
    /// Start background tasks and register the peer-packet filter.
    pub async fn start(self: &Arc<Self>) -> Result<(), Error> {
        self.start_icmp_proxy().await?;
        self.start_nat_table_cleaner().await?;
        Ok(())
    }
    // Evict NAT entries older than 20s (requests whose reply never arrived).
    async fn start_nat_table_cleaner(self: &Arc<Self>) -> Result<(), Error> {
        let nat_table = self.nat_table.clone();
        self.tasks.lock().await.spawn(
            async move {
                loop {
                    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                    nat_table.retain(|_, v| v.start_time.elapsed().as_secs() < 20);
                }
            }
            .instrument(tracing::info_span!("icmp proxy nat table cleaner")),
        );
        Ok(())
    }
    // Wire up the receive thread, the reply-forwarding task and the packet filter.
    async fn start_icmp_proxy(self: &Arc<Self>) -> Result<(), Error> {
        let socket = self.socket.try_clone()?;
        let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
        let nat_table = self.nat_table.clone();
        // blocking raw-socket reads happen on a dedicated OS thread
        thread::spawn(|| {
            socket_recv_loop(socket, nat_table, sender);
        });
        let peer_manager = self.peer_manager.clone();
        // forward matched echo replies back through the peer manager
        self.tasks.lock().await.spawn(
            async move {
                while let Some(msg) = receiver.recv().await {
                    let to_peer_id: uuid::Uuid = msg.to_peer.as_ref().unwrap().clone().into();
                    let ret = peer_manager.send_msg(msg.into(), &to_peer_id).await;
                    if ret.is_err() {
                        tracing::error!("send icmp packet to peer failed: {:?}", ret);
                    }
                }
            }
            .instrument(tracing::info_span!("icmp proxy send loop")),
        );
        self.peer_manager
            .add_packet_process_pipeline(Box::new(self.clone()))
            .await;
        Ok(())
    }
    // Re-issue the intercepted echo request towards its real destination.
    fn send_icmp_packet(
        &self,
        dst_ip: Ipv4Addr,
        icmp_packet: &icmp::echo_request::EchoRequestPacket,
    ) -> Result<(), Error> {
        self.socket.send_to(
            icmp_packet.packet(),
            &SocketAddrV4::new(dst_ip.into(), 0).into(),
        )?;
        Ok(())
    }
}
+51
View File
@@ -0,0 +1,51 @@
use dashmap::DashSet;
use std::sync::Arc;
use tokio::task::JoinSet;
use crate::common::global_ctx::ArcGlobalCtx;
pub mod icmp_proxy;
pub mod tcp_proxy;
// Self-refreshing set of proxy CIDRs mirrored from the global context.
#[derive(Debug)]
struct CidrSet {
    global_ctx: ArcGlobalCtx,
    cidr_set: Arc<DashSet<cidr::IpCidr>>,
    // holds the background updater task
    tasks: JoinSet<()>,
}
impl CidrSet {
    /// Create a set that continuously mirrors the proxy CIDR list from the
    /// global context (refreshed once per second by a background task).
    pub fn new(global_ctx: ArcGlobalCtx) -> Self {
        let mut ret = Self {
            global_ctx,
            cidr_set: Arc::new(DashSet::new()),
            tasks: JoinSet::new(),
        };
        ret.run_cidr_updater();
        ret
    }
    // Spawn the task that keeps `cidr_set` in sync with the global context.
    fn run_cidr_updater(&mut self) {
        let global_ctx = self.global_ctx.clone();
        let cidr_set = self.cidr_set.clone();
        self.tasks.spawn(async move {
            let mut last_cidrs = vec![];
            loop {
                let cidrs = global_ctx.get_proxy_cidrs();
                // only rebuild the set when the configuration actually changed
                if cidrs != last_cidrs {
                    last_cidrs = cidrs.clone();
                    cidr_set.clear();
                    for cidr in cidrs.iter() {
                        cidr_set.insert(cidr.clone());
                    }
                }
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
            }
        });
    }
    /// Whether `ip` falls inside any of the configured proxy CIDRs.
    /// (Fix vs. original: tail expression instead of non-idiomatic `return …;`.)
    pub fn contains_v4(&self, ip: std::net::Ipv4Addr) -> bool {
        let ip = ip.into();
        self.cidr_set.iter().any(|cidr| cidr.contains(&ip))
    }
}
+402
View File
@@ -0,0 +1,402 @@
use crossbeam::atomic::AtomicCell;
use dashmap::DashMap;
use pnet::packet::ip::IpNextHeaderProtocols;
use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::atomic::AtomicU16;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::copy_bidirectional;
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::bytes::{Bytes, BytesMut};
use tracing::Instrument;
use crate::common::error::Result;
use crate::common::global_ctx::GlobalCtx;
use crate::common::netns::NetNS;
use crate::peers::packet::{self, ArchivedPacket};
use crate::peers::peer_manager::{NicPacketFilter, PeerManager, PeerPacketFilter};
use super::CidrSet;
// Lifecycle of a NAT-ed TCP connection tracked by the proxy.
#[derive(Debug, Clone, Copy, PartialEq)]
enum NatDstEntryState {
    // SYN packet received but local listener has not accepted yet
    SynReceived,
    // connecting to the real destination
    ConnectingDst,
    // connected to the real destination; relaying
    Connected,
    // connection closed (or expired while pending)
    Closed,
}
// One tracked NAT-ed TCP connection: peer source -> real destination.
#[derive(Debug)]
pub struct NatDstEntry {
    id: uuid::Uuid,
    // source addr as seen in the proxied packet
    src: SocketAddr,
    // real destination the proxy connects to
    dst: SocketAddr,
    // creation time, used for SYN-entry expiry
    start_time: Instant,
    // relay task for this connection
    tasks: Mutex<JoinSet<()>>,
    state: AtomicCell<NatDstEntryState>,
}
impl NatDstEntry {
    /// New entry in `SynReceived` state with a fresh random id.
    pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
        Self {
            id: uuid::Uuid::new_v4(),
            src,
            dst,
            start_time: Instant::now(),
            tasks: Mutex::new(JoinSet::new()),
            state: AtomicCell::new(NatDstEntryState::SynReceived),
        }
    }
}
type ArcNatDstEntry = Arc<NatDstEntry>;
// peer source addr -> entry, for connections still in SYN state
type SynSockMap = Arc<DashMap<SocketAddr, ArcNatDstEntry>>;
// entry id -> entry, for accepted connections
type ConnSockMap = Arc<DashMap<uuid::Uuid, ArcNatDstEntry>>;
// peer src addr to nat entry, when respond tcp packet, should modify the tcp src addr to the nat entry's dst addr
type AddrConnSockMap = Arc<DashMap<SocketAddr, ArcNatDstEntry>>;
// Transparent TCP proxy: rewrites proxied packets so connections terminate on
// a local listener, then bridges each one to its real destination.
#[derive(Debug)]
pub struct TcpProxy {
    global_ctx: Arc<GlobalCtx>,
    peer_manager: Arc<PeerManager>,
    // port of the local NAT listener (0 until run_listener binds it)
    local_port: AtomicU16,
    tasks: Arc<Mutex<JoinSet<()>>>,
    // connections that sent SYN but were not yet accepted locally
    syn_map: SynSockMap,
    // accepted entries keyed by entry id
    conn_map: ConnSockMap,
    // accepted entries keyed by the peer's source addr
    addr_conn_map: AddrConnSockMap,
    // destinations this node proxies for
    cidr_set: CidrSet,
}
#[async_trait::async_trait]
impl PeerPacketFilter for TcpProxy {
    // Intercept TCP packets from peers whose destination is in a proxied CIDR:
    // remember SYNs in the NAT table, then rewrite dst ip/port to the local
    // NAT listener and push the packet into the NIC channel.
    // Returning Some(()) consumes the packet; None passes it on.
    async fn try_process_packet_from_peer(&self, packet: &ArchivedPacket, _: &Bytes) -> Option<()> {
        // need a local virtual IPv4 to redirect to
        let ipv4_addr = self.global_ctx.get_ipv4()?;
        let packet::ArchivedPacketBody::Data(x) = &packet.body else {
            return None;
        };
        let ipv4 = Ipv4Packet::new(&x.data)?;
        if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp {
            return None;
        }
        if !self.cidr_set.contains_v4(ipv4.get_destination()) {
            return None;
        }
        tracing::trace!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received");
        // copy into a mutable buffer so addr/port and checksums can be rewritten
        // NOTE(review): `to_vec()` makes an extra copy before extend; check
        // whether the archived data can be sliced directly instead.
        let mut packet_buffer = BytesMut::with_capacity(x.data.len());
        packet_buffer.extend_from_slice(&x.data.to_vec());
        // IHL is expressed in 32-bit words
        let (ip_buffer, tcp_buffer) =
            packet_buffer.split_at_mut(ipv4.get_header_length() as usize * 4);
        let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
        let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
        let is_tcp_syn = tcp_packet.get_flags() & pnet::packet::tcp::TcpFlags::SYN != 0;
        if is_tcp_syn {
            // record src -> real dst so the accept loop can recover the
            // original destination when the redirected connection arrives
            let source_ip = ip_packet.get_source();
            let source_port = tcp_packet.get_source();
            let src = SocketAddr::V4(SocketAddrV4::new(source_ip, source_port));
            let dest_ip = ip_packet.get_destination();
            let dest_port = tcp_packet.get_destination();
            let dst = SocketAddr::V4(SocketAddrV4::new(dest_ip, dest_port));
            let old_val = self
                .syn_map
                .insert(src, Arc::new(NatDstEntry::new(src, dst)));
            tracing::trace!(src = ?src, dst = ?dst, old_entry = ?old_val, "tcp syn received");
        }
        // redirect the packet to our local NAT listener
        ip_packet.set_destination(ipv4_addr);
        tcp_packet.set_destination(self.get_local_port());
        Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
        tracing::trace!(ip_packet = ?ip_packet, tcp_packet = ?tcp_packet, "tcp packet forwarded");
        if let Err(e) = self
            .peer_manager
            .get_nic_channel()
            .send(packet_buffer.freeze())
            .await
        {
            tracing::error!("send to nic failed: {:?}", e);
        }
        Some(())
    }
}
#[async_trait::async_trait]
impl NicPacketFilter for TcpProxy {
    // Rewrite outgoing packets that originate from the local NAT listener so
    // the peer sees them as coming from the original proxied destination.
    async fn try_process_packet_from_nic(&self, mut data: BytesMut) -> BytesMut {
        let Some(my_ipv4) = self.global_ctx.get_ipv4() else {
            return data;
        };
        let header_len = {
            let Some(ipv4) = &Ipv4Packet::new(&data[..]) else {
                return data;
            };
            // only our own outgoing TCP traffic is of interest
            if ipv4.get_version() != 4
                || ipv4.get_source() != my_ipv4
                || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp
            {
                return data;
            }
            // IHL is expressed in 32-bit words
            ipv4.get_header_length() as usize * 4
        };
        let (ip_buffer, tcp_buffer) = data.split_at_mut(header_len);
        let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
        let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
        // packets not from our listener port pass through untouched
        if tcp_packet.get_source() != self.get_local_port() {
            return data;
        }
        let dst_addr = SocketAddr::V4(SocketAddrV4::new(
            ip_packet.get_destination(),
            tcp_packet.get_destination(),
        ));
        tracing::trace!(dst_addr = ?dst_addr, "tcp packet try find entry");
        // established entry first, fall back to a still-pending SYN entry
        let entry = if let Some(entry) = self.addr_conn_map.get(&dst_addr) {
            entry
        } else {
            let Some(syn_entry) = self.syn_map.get(&dst_addr) else {
                return data;
            };
            syn_entry
        };
        // clone out of the dashmap guard before mutating the packet
        let nat_entry = entry.clone();
        drop(entry);
        assert_eq!(nat_entry.src, dst_addr);
        let IpAddr::V4(ip) = nat_entry.dst.ip() else {
            panic!("v4 nat entry src ip is not v4");
        };
        // masquerade as the original destination
        ip_packet.set_source(ip);
        tcp_packet.set_source(nat_entry.dst.port());
        Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
        tracing::trace!(dst_addr = ?dst_addr, nat_entry = ?nat_entry, packet = ?ip_packet, "tcp packet after modified");
        data
    }
}
impl TcpProxy {
pub fn new(global_ctx: Arc<GlobalCtx>, peer_manager: Arc<PeerManager>) -> Arc<Self> {
Arc::new(Self {
global_ctx: global_ctx.clone(),
peer_manager,
local_port: AtomicU16::new(0),
tasks: Arc::new(Mutex::new(JoinSet::new())),
syn_map: Arc::new(DashMap::new()),
conn_map: Arc::new(DashMap::new()),
addr_conn_map: Arc::new(DashMap::new()),
cidr_set: CidrSet::new(global_ctx),
})
}
fn update_ipv4_packet_checksum(
ipv4_packet: &mut MutableIpv4Packet,
tcp_packet: &mut MutableTcpPacket,
) {
tcp_packet.set_checksum(ipv4_checksum(
&tcp_packet.to_immutable(),
&ipv4_packet.get_source(),
&ipv4_packet.get_destination(),
));
ipv4_packet.set_checksum(pnet::packet::ipv4::checksum(&ipv4_packet.to_immutable()));
}
pub async fn start(self: &Arc<Self>) -> Result<()> {
self.run_syn_map_cleaner().await?;
self.run_listener().await?;
self.peer_manager
.add_packet_process_pipeline(Box::new(self.clone()))
.await;
self.peer_manager
.add_nic_packet_process_pipeline(Box::new(self.clone()))
.await;
Ok(())
}
async fn run_syn_map_cleaner(&self) -> Result<()> {
let syn_map = self.syn_map.clone();
let tasks = self.tasks.clone();
let syn_map_cleaner_task = async move {
loop {
syn_map.retain(|_, entry| {
if entry.start_time.elapsed() > Duration::from_secs(30) {
tracing::warn!(entry = ?entry, "syn nat entry expired");
entry.state.store(NatDstEntryState::Closed);
false
} else {
true
}
});
tokio::time::sleep(Duration::from_secs(10)).await;
}
};
tasks.lock().await.spawn(syn_map_cleaner_task);
Ok(())
}
async fn run_listener(&self) -> Result<()> {
// bind on both v4 & v6
let listen_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
let net_ns = self.global_ctx.net_ns.clone();
let tcp_listener = net_ns
.run_async(|| async { TcpListener::bind(&listen_addr).await })
.await?;
self.local_port.store(
tcp_listener.local_addr()?.port(),
std::sync::atomic::Ordering::Relaxed,
);
let tasks = self.tasks.clone();
let syn_map = self.syn_map.clone();
let conn_map = self.conn_map.clone();
let addr_conn_map = self.addr_conn_map.clone();
let accept_task = async move {
tracing::info!(listener = ?tcp_listener, "tcp connection start accepting");
let conn_map = conn_map.clone();
while let Ok((tcp_stream, socket_addr)) = tcp_listener.accept().await {
let Some(entry) = syn_map.get(&socket_addr) else {
tracing::error!("tcp connection from unknown source: {:?}", socket_addr);
continue;
};
assert_eq!(entry.state.load(), NatDstEntryState::SynReceived);
let entry_clone = entry.clone();
drop(entry);
syn_map.remove_if(&socket_addr, |_, entry| entry.id == entry_clone.id);
entry_clone.state.store(NatDstEntryState::ConnectingDst);
let _ = addr_conn_map.insert(entry_clone.src, entry_clone.clone());
let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone());
assert!(old_nat_val.is_none());
tasks.lock().await.spawn(Self::connect_to_nat_dst(
net_ns.clone(),
tcp_stream,
conn_map.clone(),
addr_conn_map.clone(),
entry_clone,
));
}
tracing::error!("nat tcp listener exited");
panic!("nat tcp listener exited");
};
self.tasks
.lock()
.await
.spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener")));
Ok(())
}
fn remove_entry_from_all_conn_map(
conn_map: ConnSockMap,
addr_conn_map: AddrConnSockMap,
nat_entry: ArcNatDstEntry,
) {
conn_map.remove(&nat_entry.id);
addr_conn_map.remove_if(&nat_entry.src, |_, entry| entry.id == nat_entry.id);
}
async fn connect_to_nat_dst(
net_ns: NetNS,
src_tcp_stream: TcpStream,
conn_map: ConnSockMap,
addr_conn_map: AddrConnSockMap,
nat_entry: ArcNatDstEntry,
) {
if let Err(e) = src_tcp_stream.set_nodelay(true) {
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
}
let _guard = net_ns.guard();
let socket = TcpSocket::new_v4().unwrap();
if let Err(e) = socket.set_nodelay(true) {
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
}
let Ok(Ok(dst_tcp_stream)) = tokio::time::timeout(
Duration::from_secs(10),
TcpSocket::new_v4().unwrap().connect(nat_entry.dst),
)
.await
else {
tracing::error!("connect to dst failed: {:?}", nat_entry);
nat_entry.state.store(NatDstEntryState::Closed);
Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, nat_entry);
return;
};
drop(_guard);
assert_eq!(nat_entry.state.load(), NatDstEntryState::ConnectingDst);
nat_entry.state.store(NatDstEntryState::Connected);
Self::handle_nat_connection(
src_tcp_stream,
dst_tcp_stream,
conn_map,
addr_conn_map,
nat_entry,
)
.await;
}
/// Spawn the forwarding task that shuttles bytes between the proxied
/// source stream and the connected destination stream.
///
/// The task lives in the entry's own task set; once the copy loop ends
/// the entry is marked `Closed` and removed from both connection maps.
async fn handle_nat_connection(
    mut src_tcp_stream: TcpStream,
    mut dst_tcp_stream: TcpStream,
    conn_map: ConnSockMap,
    addr_conn_map: AddrConnSockMap,
    nat_entry: ArcNatDstEntry,
) {
    // One handle stays here for spawning; another moves into the task.
    let entry = nat_entry.clone();
    let forward = async move {
        let ret = copy_bidirectional(&mut src_tcp_stream, &mut dst_tcp_stream).await;
        tracing::trace!(nat_entry = ?entry, ret = ?ret, "nat tcp connection closed");
        entry.state.store(NatDstEntryState::Closed);
        Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, entry);
    };
    nat_entry.tasks.lock().await.spawn(forward);
}
/// Port the proxy listener is bound to, read with `Relaxed` ordering
/// (a plain atomic snapshot, no synchronization implied).
pub fn get_local_port(&self) -> u16 {
    use std::sync::atomic::Ordering;
    self.local_port.load(Ordering::Relaxed)
}
}
+413
View File
@@ -0,0 +1,413 @@
use std::borrow::BorrowMut;
use std::io::Write;
use std::sync::Arc;
use futures::StreamExt;
use pnet::packet::ethernet::EthernetPacket;
use pnet::packet::ipv4::Ipv4Packet;
use tokio::{sync::Mutex, task::JoinSet};
use tokio_util::bytes::{Bytes, BytesMut};
use tonic::transport::Server;
use uuid::Uuid;
use crate::common::config_fs::ConfigFs;
use crate::common::error::Error;
use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx};
use crate::common::netns::NetNS;
use crate::connector::direct::DirectConnectorManager;
use crate::connector::manual::{ConnectorManagerRpcService, ManualConnectorManager};
use crate::connector::udp_hole_punch::UdpHolePunchConnector;
use crate::gateway::icmp_proxy::IcmpProxy;
use crate::gateway::tcp_proxy::TcpProxy;
use crate::peers::peer_manager::PeerManager;
use crate::peers::rip_route::BasicRoute;
use crate::peers::rpc_service::PeerManagerRpcService;
use crate::tunnels::SinkItem;
use tokio_stream::wrappers::ReceiverStream;
use super::listeners::ListenerManager;
use super::virtual_nic;
/// Persists per-instance settings (network namespace, ipv4 address) into
/// the instance's on-disk config store before the instance itself starts.
pub struct InstanceConfigWriter {
    // Filesystem-backed config store keyed by instance name.
    config: ConfigFs,
}
impl InstanceConfigWriter {
    /// Open (or create) the config store for instance `inst_name`.
    pub fn new(inst_name: &str) -> Self {
        Self {
            config: ConfigFs::new(inst_name),
        }
    }

    /// Persist the network-namespace name. `None` is stored as an empty
    /// string, which `Instance::new` interprets as "root namespace".
    pub fn set_ns(self, net_ns: Option<String>) -> Self {
        let net_ns_in_conf = net_ns.unwrap_or_default();
        self.config
            .add_file("net_ns")
            .unwrap()
            .write_all(net_ns_in_conf.as_bytes())
            .unwrap();
        self
    }

    /// Persist the instance's ipv4 address string.
    pub fn set_addr(self, addr: String) -> Self {
        self.config
            .add_file("ipv4")
            .unwrap()
            .write_all(addr.as_bytes())
            .unwrap();
        self
    }
}
/// One running easytier node: owns the virtual NIC, the peer manager and
/// every background service (listeners, connectors, proxies, RPC server).
pub struct Instance {
    inst_name: String,
    // Node id, loaded from the global context's config.
    id: uuid::Uuid,
    // Set by `run()` once the tun device is created.
    virtual_nic: Option<Arc<virtual_nic::VirtualNic>>,
    // Packets from peers destined for the local NIC; taken (Option) by run().
    peer_packet_receiver: Option<ReceiverStream<SinkItem>>,
    // All background tasks owned by this instance.
    tasks: JoinSet<()>,
    peer_manager: Arc<PeerManager>,
    listener_manager: Arc<Mutex<ListenerManager<PeerManager>>>,
    conn_manager: Arc<ManualConnectorManager>,
    direct_conn_manager: Arc<DirectConnectorManager>,
    udp_hole_puncher: Arc<Mutex<UdpHolePunchConnector>>,
    tcp_proxy: Arc<TcpProxy>,
    icmp_proxy: Arc<IcmpProxy>,
    global_ctx: ArcGlobalCtx,
}
impl Instance {
    /// Build an instance from its on-disk config (`net_ns`, `ipv4`) and wire
    /// up all subsystems. Only `direct_conn_manager` is started here; the
    /// rest start in `run()`.
    pub fn new(inst_name: &str) -> Self {
        let config = ConfigFs::new(inst_name);
        // Empty string in config means "stay in the root namespace".
        let net_ns_in_conf = config.get_or_default("net_ns", || "".to_string()).unwrap();
        let net_ns = NetNS::new(if net_ns_in_conf.is_empty() {
            None
        } else {
            Some(net_ns_in_conf.clone())
        });
        let addr = config
            .get_or_default("ipv4", || "10.144.144.10".to_string())
            .unwrap();
        log::info!(
            "[INIT] instance creating. inst_name: {}, addr: {}, netns: {}",
            inst_name,
            addr,
            net_ns_in_conf
        );
        // Channel carrying packets received from peers toward the local NIC.
        let (peer_packet_sender, peer_packet_receiver) = tokio::sync::mpsc::channel(100);
        let global_ctx = Arc::new(GlobalCtx::new(inst_name, config, net_ns.clone()));
        let id = global_ctx.get_id();
        let peer_manager = Arc::new(PeerManager::new(
            global_ctx.clone(),
            peer_packet_sender.clone(),
        ));
        let listener_manager = Arc::new(Mutex::new(ListenerManager::new(
            id,
            net_ns.clone(),
            peer_manager.clone(),
        )));
        let conn_manager = Arc::new(ManualConnectorManager::new(
            id,
            global_ctx.clone(),
            peer_manager.clone(),
        ));
        let mut direct_conn_manager =
            DirectConnectorManager::new(id, global_ctx.clone(), peer_manager.clone());
        // Started eagerly, unlike the other services (started in run()).
        direct_conn_manager.run();
        let udp_hole_puncher = UdpHolePunchConnector::new(global_ctx.clone(), peer_manager.clone());
        let arc_tcp_proxy = TcpProxy::new(global_ctx.clone(), peer_manager.clone());
        let arc_icmp_proxy = IcmpProxy::new(global_ctx.clone(), peer_manager.clone()).unwrap();
        Instance {
            inst_name: inst_name.to_string(),
            id,
            virtual_nic: None,
            peer_packet_receiver: Some(ReceiverStream::new(peer_packet_receiver)),
            tasks: JoinSet::new(),
            peer_manager,
            listener_manager,
            conn_manager,
            direct_conn_manager: Arc::new(direct_conn_manager),
            udp_hole_puncher: Arc::new(Mutex::new(udp_hole_puncher)),
            tcp_proxy: arc_tcp_proxy,
            icmp_proxy: arc_icmp_proxy,
            global_ctx,
        }
    }

    /// Shared handle to the manual connector manager (used by main to add
    /// peers given on the command line).
    pub fn get_conn_manager(&self) -> Arc<ManualConnectorManager> {
        self.conn_manager.clone()
    }

    /// Forward one packet read from the tun device to the peer that owns
    /// its destination ipv4 address.
    async fn do_forward_nic_to_peers_ipv4(ret: BytesMut, mgr: &PeerManager) {
        if let Some(ipv4) = Ipv4Packet::new(&ret) {
            if ipv4.get_version() != 4 {
                // NOTE(review): a non-IPv4 packet is only logged here and is
                // still forwarded below — confirm whether an early return
                // was intended.
                tracing::info!("[USER_PACKET] not ipv4 packet: {:?}", ipv4);
            }
            let dst_ipv4 = ipv4.get_destination();
            tracing::trace!(
                ?ret,
                "[USER_PACKET] recv new packet from tun device and forward to peers."
            );
            // Send failures (e.g. no route to dst) are best-effort: trace only.
            let send_ret = mgr.send_msg_ipv4(ret, dst_ipv4).await;
            if send_ret.is_err() {
                tracing::trace!(?send_ret, "[USER_PACKET] send_msg_ipv4 failed")
            }
        } else {
            tracing::warn!(?ret, "[USER_PACKET] not ipv4 packet");
        }
    }

    /// L2 variant: strip the 14-byte ethernet header then forward as ipv4.
    /// Currently unused (see the commented-out call in
    /// do_forward_nic_to_peers).
    async fn do_forward_nic_to_peers_ethernet(mut ret: BytesMut, mgr: &PeerManager) {
        if let Some(eth) = EthernetPacket::new(&ret) {
            log::warn!("begin to forward: {:?}, type: {}", eth, eth.get_ethertype());
            Self::do_forward_nic_to_peers_ipv4(ret.split_off(14), mgr).await;
        } else {
            log::warn!("not ipv4 packet: {:?}", ret);
        }
    }

    /// Spawn the task that reads packets from the NIC and forwards them to
    /// peers. Requires `virtual_nic` to be set (panics otherwise).
    fn do_forward_nic_to_peers(&mut self) -> Result<(), Error> {
        // read from nic and write to corresponding tunnel
        let nic = self.virtual_nic.as_ref().unwrap();
        let nic = nic.clone();
        let mgr = self.peer_manager.clone();
        self.tasks.spawn(async move {
            let mut stream = nic.pin_recv_stream();
            while let Some(ret) = stream.next().await {
                if ret.is_err() {
                    // Device read error ends the forwarding task.
                    log::error!("read from nic failed: {:?}", ret);
                    break;
                }
                Self::do_forward_nic_to_peers_ipv4(ret.unwrap(), mgr.as_ref()).await;
                // Self::do_forward_nic_to_peers_ethernet(ret.into(), mgr.as_ref()).await;
            }
        });
        Ok(())
    }

    /// Spawn the task that drains packets received from peers (`channel`)
    /// into the NIC's send stream. Panics if the forward ever fails.
    fn do_forward_peers_to_nic(
        tasks: &mut JoinSet<()>,
        nic: Arc<virtual_nic::VirtualNic>,
        channel: Option<ReceiverStream<Bytes>>,
    ) {
        tasks.spawn(async move {
            let send = nic.pin_send_stream();
            let channel = channel.unwrap();
            let ret = channel
                .map(|packet| {
                    log::trace!(
                        "[USER_PACKET] forward packet from peers to nic. packet: {:?}",
                        packet
                    );
                    Ok(packet)
                })
                .forward(send)
                .await;
            if ret.is_err() {
                panic!("do_forward_tunnel_to_nic");
            }
        });
    }

    /// Bring the whole instance up: create and configure the tun device,
    /// start both forwarding directions, listeners, routing, RPC server,
    /// proxies and the UDP hole puncher. Call once.
    pub async fn run(&mut self) -> Result<(), Error> {
        let ipv4_addr = self.global_ctx.get_ipv4().unwrap();
        let mut nic = virtual_nic::VirtualNic::new(self.get_global_ctx())
            .create_dev()
            .await?
            .link_up()
            .await?
            .remove_ip(None)
            .await?
            .add_ip(ipv4_addr, 24)
            .await?;
        // macOS tun devices need an explicit route for the subnet.
        if cfg!(target_os = "macos") {
            nic = nic.add_route(ipv4_addr, 24).await?;
        }
        self.virtual_nic = Some(Arc::new(nic));

        self.do_forward_nic_to_peers().unwrap();
        Self::do_forward_peers_to_nic(
            self.tasks.borrow_mut(),
            self.virtual_nic.as_ref().unwrap().clone(),
            self.peer_packet_receiver.take(),
        );

        self.listener_manager
            .lock()
            .await
            .prepare_listeners()
            .await?;
        self.listener_manager.lock().await.run().await?;
        self.peer_manager.run().await?;

        let route = BasicRoute::new(self.id(), self.global_ctx.clone());
        self.peer_manager.set_route(route).await;

        self.run_rpc_server().unwrap();
        self.tcp_proxy.start().await.unwrap();
        self.icmp_proxy.start().await.unwrap();
        self.run_proxy_cidrs_route_updater();
        self.udp_hole_puncher.lock().await.run().await?;
        Ok(())
    }

    /// Shared handle to the peer manager.
    pub fn get_peer_manager(&self) -> Arc<PeerManager> {
        self.peer_manager.clone()
    }

    /// Close one connection to one peer by id pair.
    pub async fn close_peer_conn(&mut self, peer_id: &Uuid, conn_id: &Uuid) -> Result<(), Error> {
        self.peer_manager
            .get_peer_map()
            .close_peer_conn(peer_id, conn_id)
            .await?;
        Ok(())
    }

    /// Block until every background task has finished; panics if one of
    /// them panicked (join error is unwrapped).
    pub async fn wait(&mut self) {
        while let Some(ret) = self.tasks.join_next().await {
            log::info!("task finished: {:?}", ret);
            ret.unwrap();
        }
    }

    /// This node's id.
    pub fn id(&self) -> uuid::Uuid {
        self.id
    }

    /// Spawn the tonic gRPC server (peer + connector management) on the
    /// fixed address 0.0.0.0:15888 inside the instance's netns.
    fn run_rpc_server(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        let addr = "0.0.0.0:15888".parse()?;
        let peer_mgr = self.peer_manager.clone();
        let conn_manager = self.conn_manager.clone();
        let net_ns = self.global_ctx.net_ns.clone();
        self.tasks.spawn(async move {
            let _g = net_ns.guard();
            log::info!("[INIT RPC] start rpc server. addr: {}", addr);
            Server::builder()
                .add_service(
                    easytier_rpc::peer_manage_rpc_server::PeerManageRpcServer::new(
                        PeerManagerRpcService::new(peer_mgr),
                    ),
                )
                .add_service(
                    easytier_rpc::connector_manage_rpc_server::ConnectorManageRpcServer::new(
                        ConnectorManagerRpcService(conn_manager.clone()),
                    ),
                )
                .serve(addr)
                .await
                .unwrap();
        });
        Ok(())
    }

    /// Spawn a 1s-interval task that diffs the proxy CIDRs advertised by
    /// routes against the set currently installed on the NIC, removing
    /// stale kernel routes and adding new ones.
    fn run_proxy_cidrs_route_updater(&mut self) {
        let peer_mgr = self.peer_manager.clone();
        let net_ns = self.global_ctx.net_ns.clone();
        let nic = self.virtual_nic.as_ref().unwrap().clone();
        self.tasks.spawn(async move {
            let mut cur_proxy_cidrs = vec![];
            loop {
                // Collect every parseable proxy CIDR advertised by any route.
                let mut proxy_cidrs = vec![];
                let routes = peer_mgr.list_routes().await;
                for r in routes {
                    for cidr in r.proxy_cidrs {
                        let Ok(cidr) = cidr.parse::<cidr::Ipv4Cidr>() else {
                            continue;
                        };
                        proxy_cidrs.push(cidr);
                    }
                }
                // if route is in cur_proxy_cidrs but not in proxy_cidrs, delete it.
                for cidr in cur_proxy_cidrs.iter() {
                    if proxy_cidrs.contains(cidr) {
                        continue;
                    }
                    let _g = net_ns.guard();
                    let ret = nic
                        .get_ifcfg()
                        .remove_ipv4_route(
                            nic.ifname(),
                            cidr.first_address(),
                            cidr.network_length(),
                        )
                        .await;
                    if ret.is_err() {
                        tracing::trace!(
                            cidr = ?cidr,
                            err = ?ret,
                            "remove route failed.",
                        );
                    }
                }
                // Add routes that are newly advertised.
                for cidr in proxy_cidrs.iter() {
                    if cur_proxy_cidrs.contains(cidr) {
                        continue;
                    }
                    let _g = net_ns.guard();
                    let ret = nic
                        .get_ifcfg()
                        .add_ipv4_route(nic.ifname(), cidr.first_address(), cidr.network_length())
                        .await;
                    if ret.is_err() {
                        tracing::trace!(
                            cidr = ?cidr,
                            err = ?ret,
                            "add route failed.",
                        );
                    }
                }
                cur_proxy_cidrs = proxy_cidrs;
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
            }
        });
    }

    /// Shared handle to the global context.
    pub fn get_global_ctx(&self) -> ArcGlobalCtx {
        self.global_ctx.clone()
    }
}
+150
View File
@@ -0,0 +1,150 @@
use std::{fmt::Debug, sync::Arc};
use async_trait::async_trait;
use tokio::{sync::Mutex, task::JoinSet};
use crate::{
common::{error::Error, netns::NetNS},
peers::peer_manager::PeerManager,
tunnels::{
ring_tunnel::RingTunnelListener, tcp_tunnel::TcpTunnelListener,
udp_tunnel::UdpTunnelListener, Tunnel, TunnelListener,
},
};
/// Callback invoked by `ListenerManager` for every tunnel accepted by one
/// of its listeners.
#[async_trait]
pub trait TunnelHandlerForListener {
    // Take ownership of the accepted tunnel; an Err is logged by the caller.
    async fn handle_tunnel(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error>;
}
// Accepted tunnels become server-side peer connections.
#[async_trait]
impl TunnelHandlerForListener for PeerManager {
    #[tracing::instrument]
    async fn handle_tunnel(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
        self.add_tunnel_as_server(tunnel).await
    }
}
/// Owns every inbound tunnel listener (udp/tcp/ring) and the accept tasks
/// that feed accepted tunnels into the handler `H`.
pub struct ListenerManager<H> {
    // Used to build this node's unique ring-tunnel URL.
    my_node_id: uuid::Uuid,
    // Namespace the listeners must bind inside.
    net_ns: NetNS,
    listeners: Vec<Arc<Mutex<dyn TunnelListener>>>,
    peer_manager: Arc<H>,
    // One accept task per listener, spawned by run().
    tasks: JoinSet<()>,
}
impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManager<H> {
    /// Create an empty manager; call `prepare_listeners` / `add_listener`
    /// then `run`.
    pub fn new(my_node_id: uuid::Uuid, net_ns: NetNS, peer_manager: Arc<H>) -> Self {
        Self {
            my_node_id,
            net_ns,
            listeners: Vec::new(),
            peer_manager,
            tasks: JoinSet::new(),
        }
    }

    /// Register the default set of listeners: udp and tcp on port 11010,
    /// plus an in-process ring tunnel named after this node.
    pub async fn prepare_listeners(&mut self) -> Result<(), Error> {
        let udp_url = "udp://0.0.0.0:11010".parse().unwrap();
        self.add_listener(UdpTunnelListener::new(udp_url)).await?;

        let tcp_url = "tcp://0.0.0.0:11010".parse().unwrap();
        self.add_listener(TcpTunnelListener::new(tcp_url)).await?;

        let ring_url = format!("ring://{}", self.my_node_id).parse().unwrap();
        self.add_listener(RingTunnelListener::new(ring_url)).await?;

        Ok(())
    }

    /// Register one more listener; it starts listening only when `run` is
    /// called.
    pub async fn add_listener<Listener>(&mut self, listener: Listener) -> Result<(), Error>
    where
        Listener: TunnelListener + 'static,
    {
        self.listeners.push(Arc::new(Mutex::new(listener)));
        Ok(())
    }

    /// Accept loop for one listener: every accepted tunnel is handed to the
    /// handler; handler errors are logged and do not stop the loop.
    #[tracing::instrument]
    async fn run_listener(listener: Arc<Mutex<dyn TunnelListener>>, peer_manager: Arc<H>) {
        let mut guard = listener.lock().await;
        loop {
            let Ok(ret) = guard.accept().await else {
                break;
            };
            tracing::info!(ret = ?ret, "conn accepted");
            if let Err(e) = peer_manager.handle_tunnel(ret).await {
                tracing::error!(error = ?e, "handle conn error");
            }
        }
    }

    /// Bind every registered listener inside the netns and spawn its
    /// accept loop.
    pub async fn run(&mut self) -> Result<(), Error> {
        for listener in self.listeners.iter() {
            let _guard = self.net_ns.guard();
            log::warn!("run listener: {:?}", listener);
            listener.lock().await.listen().await?;
            self.tasks
                .spawn(Self::run_listener(listener.clone(), self.peer_manager.clone()));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use futures::{SinkExt, StreamExt};
    use tokio::time::timeout;
    use crate::tunnels::{ring_tunnel::RingTunnelConnector, TunnelConnector};
    use super::*;

    // Handler that greets each tunnel then fails, exercising the error
    // branch of the accept loop.
    #[derive(Debug)]
    struct MockListenerHandler {}
    #[async_trait]
    impl TunnelHandlerForListener for MockListenerHandler {
        async fn handle_tunnel(&self, _tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
            let data = "abc";
            _tunnel.pin_sink().send(data.into()).await.unwrap();
            Err(Error::Unknown)
        }
    }

    // A handler error on one connection must not stop the accept loop from
    // serving the next connection.
    #[tokio::test]
    async fn handle_error_in_accept() {
        let net_ns = NetNS::new(None);
        let handler = Arc::new(MockListenerHandler {});
        let mut listener_mgr =
            ListenerManager::new(uuid::Uuid::new_v4(), net_ns.clone(), handler.clone());
        let ring_id = format!("ring://{}", uuid::Uuid::new_v4());
        listener_mgr
            .add_listener(RingTunnelListener::new(ring_id.parse().unwrap()))
            .await
            .unwrap();
        listener_mgr.run().await.unwrap();
        let connect_once = |ring_id| async move {
            let tunnel = RingTunnelConnector::new(ring_id).connect().await.unwrap();
            assert_eq!(tunnel.pin_stream().next().await.unwrap().unwrap(), "abc");
            tunnel
        };
        timeout(std::time::Duration::from_secs(1), async move {
            connect_once(ring_id.parse().unwrap()).await;
            // handle tunnel fail should not impact the second connect
            connect_once(ring_id.parse().unwrap()).await;
        })
        .await
        .unwrap();
    }
}
+4
View File
@@ -0,0 +1,4 @@
pub mod instance;
pub mod listeners;
pub mod tun_codec;
pub mod virtual_nic;
+179
View File
@@ -0,0 +1,179 @@
use std::io;
use byteorder::{NativeEndian, NetworkEndian, WriteBytesExt};
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
/// A packet protocol IP version
#[derive(Debug, Clone, Copy, Default)]
enum PacketProtocol {
    #[default]
    IPv4,
    IPv6,
    // Any other value found in the packet's first nibble.
    Other(u8),
}
// Note: the protocol in the packet information header is platform dependent.
impl PacketProtocol {
    /// Value to place in the tun packet-information header on Linux/Android
    /// (ethertype constants). Errors for non-IP protocols.
    #[cfg(any(target_os = "linux", target_os = "android"))]
    fn into_pi_field(self) -> Result<u16, io::Error> {
        use nix::libc;
        match self {
            PacketProtocol::IPv4 => Ok(libc::ETH_P_IP as u16),
            PacketProtocol::IPv6 => Ok(libc::ETH_P_IPV6 as u16),
            PacketProtocol::Other(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "neither an IPv4 nor IPv6 packet",
            )),
        }
    }

    /// Value to place in the tun packet-information header on macOS/iOS
    /// (address-family constants). Errors for non-IP protocols.
    #[cfg(any(target_os = "macos", target_os = "ios"))]
    fn into_pi_field(self) -> Result<u16, io::Error> {
        use nix::libc;
        match self {
            PacketProtocol::IPv4 => Ok(libc::PF_INET as u16),
            PacketProtocol::IPv6 => Ok(libc::PF_INET6 as u16),
            PacketProtocol::Other(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "neither an IPv4 nor IPv6 packet",
            )),
        }
    }

    // Windows has no packet-information header support here yet.
    #[cfg(target_os = "windows")]
    fn into_pi_field(self) -> Result<u16, io::Error> {
        unimplemented!()
    }
}
/// Backing storage for a tun packet: either an immutable `Bytes` (outgoing,
/// zero-copy from the peer channel) or a `BytesMut` (freshly decoded).
#[derive(Debug)]
pub enum TunPacketBuffer {
    Bytes(Bytes),
    BytesMut(BytesMut),
}
impl From<TunPacketBuffer> for Bytes {
    /// Convert into immutable `Bytes`, freezing a mutable buffer in place.
    fn from(buf: TunPacketBuffer) -> Self {
        match buf {
            TunPacketBuffer::BytesMut(inner) => inner.freeze(),
            TunPacketBuffer::Bytes(inner) => inner,
        }
    }
}
impl AsRef<[u8]> for TunPacketBuffer {
    /// Borrow the raw packet bytes regardless of variant.
    fn as_ref(&self) -> &[u8] {
        match self {
            TunPacketBuffer::BytesMut(inner) => inner.as_ref(),
            TunPacketBuffer::Bytes(inner) => inner.as_ref(),
        }
    }
}
/// A Tun Packet to be sent or received on the TUN interface.
// Field 0: protocol inferred from the first nibble; field 1: the payload.
#[derive(Debug)]
pub struct TunPacket(PacketProtocol, TunPacketBuffer);
/// Infer the protocol based on the first nibble in the packet buffer.
///
/// An empty buffer is classified as `Other(0)` instead of panicking — the
/// previous `buf[0]` indexed unconditionally, and `TunPacket::new` calls
/// this without checking for emptiness.
fn infer_proto(buf: &[u8]) -> PacketProtocol {
    match buf.first().copied().map_or(0, |b| b >> 4) {
        4 => PacketProtocol::IPv4,
        6 => PacketProtocol::IPv6,
        p => PacketProtocol::Other(p),
    }
}
impl TunPacket {
    /// Create a new `TunPacket`, classifying it by its first nibble.
    pub fn new(buffer: TunPacketBuffer) -> TunPacket {
        TunPacket(infer_proto(buffer.as_ref()), buffer)
    }

    /// Borrow this packet's bytes.
    pub fn get_bytes(&self) -> &[u8] {
        self.1.as_ref()
    }

    /// Consume the packet, yielding immutable bytes.
    pub fn into_bytes(self) -> Bytes {
        self.1.into()
    }

    /// Consume the packet, yielding the mutable buffer.
    ///
    /// Panics when the packet wraps an immutable `Bytes` buffer.
    pub fn into_bytes_mut(self) -> BytesMut {
        match self.1 {
            TunPacketBuffer::BytesMut(bytes) => bytes,
            TunPacketBuffer::Bytes(_) => panic!("cannot into_bytes_mut from bytes"),
        }
    }
}
/// A TunPacket Encoder/Decoder.
// Field 0: device prepends a 4-byte packet-information header.
// Field 1: device MTU, used to size the next read buffer.
pub struct TunPacketCodec(bool, i32);
impl TunPacketCodec {
/// Create a new `TunPacketCodec` specifying whether the underlying
/// tunnel Device has enabled the packet information header.
pub fn new(pi: bool, mtu: i32) -> TunPacketCodec {
TunPacketCodec(pi, mtu)
}
}
impl Decoder for TunPacketCodec {
    type Item = TunPacket;
    type Error = io::Error;

    /// Take the whole current buffer as one packet, skip the 4-byte
    /// packet-information header when enabled, and classify the payload.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        if buf.is_empty() {
            return Ok(None);
        }
        let mut pkt = buf.split_to(buf.len());

        // Reserve room for the next read: MTU plus the PI header if present.
        let reserve = if self.0 {
            self.1 as usize + 4
        } else {
            self.1 as usize
        };
        buf.reserve(reserve);

        // With packet information enabled the first 4 bytes are metadata.
        if self.0 {
            let _ = pkt.split_to(4);
        }

        Ok(Some(TunPacket(
            infer_proto(pkt.as_ref()),
            TunPacketBuffer::BytesMut(pkt),
        )))
    }
}
impl Encoder<TunPacket> for TunPacketCodec {
    type Error = io::Error;

    /// Write the packet into `dst`, prepending the 4-byte packet-information
    /// header (native-endian flags + network-endian protocol) when the
    /// device expects one.
    fn encode(&mut self, item: TunPacket, dst: &mut BytesMut) -> Result<(), Self::Error> {
        // 4 extra bytes cover the optional PI header.
        dst.reserve(item.get_bytes().len() + 4);
        match item {
            TunPacket(proto, bytes) if self.0 => {
                // build the packet information header comprising of 2 u16
                // fields: flags and protocol.
                let mut buf = Vec::<u8>::with_capacity(4);
                // flags is always 0
                buf.write_u16::<NativeEndian>(0)?;
                // write the protocol as network byte order
                buf.write_u16::<NetworkEndian>(proto.into_pi_field()?)?;
                dst.put_slice(&buf);
                dst.put(Bytes::from(bytes));
            }
            TunPacket(_, bytes) => dst.put(Bytes::from(bytes)),
        }
        Ok(())
    }
}
+203
View File
@@ -0,0 +1,203 @@
use std::{net::Ipv4Addr, pin::Pin};
use crate::{
common::{
error::Result,
global_ctx::ArcGlobalCtx,
ifcfg::{IfConfiger, IfConfiguerTrait},
},
tunnels::{
codec::BytesCodec, common::FramedTunnel, DatagramSink, DatagramStream, Tunnel, TunnelError,
},
};
use futures::{SinkExt, StreamExt};
use tokio_util::{bytes::Bytes, codec::Framed};
use tun::Device;
use super::tun_codec::{TunPacket, TunPacketCodec};
/// Builder/owner of the local TUN device: creates it inside the instance's
/// netns, configures addresses/routes, and exposes it as a `Tunnel`.
pub struct VirtualNic {
    // Requested device name ("" lets the platform choose).
    dev_name: String,
    // Number of device queues; only 1 is supported (see create_dev_ret_err).
    queue_num: usize,
    global_ctx: ArcGlobalCtx,
    // Actual interface name, known after create_dev().
    ifname: Option<String>,
    // The device wrapped as a Tunnel, set after create_dev().
    tun: Option<Box<dyn Tunnel>>,
    ifcfg: Box<dyn IfConfiguerTrait + Send + Sync + 'static>,
}
impl VirtualNic {
    /// New builder with defaults: platform-chosen name, single queue.
    pub fn new(global_ctx: ArcGlobalCtx) -> Self {
        Self {
            dev_name: "".to_owned(),
            queue_num: 1,
            global_ctx,
            ifname: None,
            tun: None,
            ifcfg: Box::new(IfConfiger {}),
        }
    }

    /// Request a specific device name (applied on Linux only, see below).
    pub fn set_dev_name(mut self, dev_name: &str) -> Result<Self> {
        self.dev_name = dev_name.to_owned();
        Ok(self)
    }

    /// Request a queue count; anything other than 1 is currently todo!().
    pub fn set_queue_num(mut self, queue_num: usize) -> Result<Self> {
        self.queue_num = queue_num;
        Ok(self)
    }

    /// Create the tun device inside the netns and wrap it as a framed
    /// `Tunnel`, stored in `self.tun` / `self.ifname`.
    async fn create_dev_ret_err(&mut self) -> Result<()> {
        let mut config = tun::Configuration::default();
        // macOS tun devices always carry the 4-byte packet-info header.
        let has_packet_info = cfg!(target_os = "macos");
        config.layer(tun::Layer::L3);
        #[cfg(target_os = "linux")]
        {
            config.platform(|config| {
                // detect protocol by ourselves for cross platform
                config.packet_information(false);
            });
            config.name(self.dev_name.clone());
        }
        if self.queue_num != 1 {
            todo!("queue_num != 1")
        }
        config.queues(self.queue_num);
        config.up();
        let dev = {
            // Create the device inside the instance's netns.
            let _g = self.global_ctx.net_ns.guard();
            tun::create_as_async(&config)?
        };
        let ifname = dev.get_ref().name()?;
        self.ifcfg.wait_interface_show(ifname.as_str()).await?;
        let ft: Box<dyn Tunnel> = if has_packet_info {
            // PI-aware codec: strip the header on read, prepend on write.
            let framed = Framed::new(dev, TunPacketCodec::new(true, 2500));
            let (sink, stream) = framed.split();
            let new_stream = stream.map(|item| match item {
                Ok(item) => Ok(item.into_bytes_mut()),
                Err(err) => {
                    println!("tun stream error: {:?}", err);
                    Err(TunnelError::TunError(err.to_string()))
                }
            });
            let new_sink = Box::pin(sink.with(|item: Bytes| async move {
                // NOTE(review): dead branch — presumably here to pin the
                // closure's error type to TunnelError; confirm before removing.
                if false {
                    return Err(TunnelError::TunError("tun sink error".to_owned()));
                }
                Ok(TunPacket::new(super::tun_codec::TunPacketBuffer::Bytes(
                    item,
                )))
            }));
            Box::new(FramedTunnel::new(new_stream, new_sink, None))
        } else {
            // No PI header: plain length-free byte frames.
            let framed = Framed::new(dev, BytesCodec::new(2500));
            let (sink, stream) = framed.split();
            Box::new(FramedTunnel::new(stream, sink, None))
        };
        self.ifname = Some(ifname.to_owned());
        self.tun = Some(ft);
        Ok(())
    }

    /// Builder-style wrapper around `create_dev_ret_err`.
    pub async fn create_dev(mut self) -> Result<Self> {
        self.create_dev_ret_err().await?;
        Ok(self)
    }

    /// Interface name; panics if called before `create_dev`.
    pub fn ifname(&self) -> &str {
        self.ifname.as_ref().unwrap().as_str()
    }

    /// Bring the interface up (inside the netns).
    pub async fn link_up(self) -> Result<Self> {
        let _g = self.global_ctx.net_ns.guard();
        self.ifcfg.set_link_status(self.ifname(), true).await?;
        Ok(self)
    }

    /// Add a kernel route for `address/cidr` via this interface.
    pub async fn add_route(self, address: Ipv4Addr, cidr: u8) -> Result<Self> {
        let _g = self.global_ctx.net_ns.guard();
        self.ifcfg
            .add_ipv4_route(self.ifname(), address, cidr)
            .await?;
        Ok(self)
    }

    /// Remove an ip from the interface (all ips when `ip` is None).
    pub async fn remove_ip(self, ip: Option<Ipv4Addr>) -> Result<Self> {
        let _g = self.global_ctx.net_ns.guard();
        self.ifcfg.remove_ip(self.ifname(), ip).await?;
        Ok(self)
    }

    /// Assign `ip/cidr` to the interface.
    pub async fn add_ip(self, ip: Ipv4Addr, cidr: i32) -> Result<Self> {
        let _g = self.global_ctx.net_ns.guard();
        self.ifcfg
            .add_ipv4_ip(self.ifname(), ip, cidr as u8)
            .await?;
        Ok(self)
    }

    /// Stream of packets read from the device; panics before `create_dev`.
    pub fn pin_recv_stream(&self) -> Pin<Box<dyn DatagramStream>> {
        self.tun.as_ref().unwrap().pin_stream()
    }

    /// Sink of packets written to the device; panics before `create_dev`.
    pub fn pin_send_stream(&self) -> Pin<Box<dyn DatagramSink>> {
        self.tun.as_ref().unwrap().pin_sink()
    }

    /// Borrow the interface-configuration backend.
    pub fn get_ifcfg(&self) -> &dyn IfConfiguerTrait {
        self.ifcfg.as_ref()
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        common::{error::Error, global_ctx::tests::get_mock_global_ctx},
        tests::enable_log,
    };
    use super::VirtualNic;

    // Create + configure a device end to end (requires tun privileges).
    async fn run_test_helper() -> Result<VirtualNic, Error> {
        let dev = VirtualNic::new(get_mock_global_ctx()).create_dev().await?;
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        dev.link_up()
            .await?
            .remove_ip(None)
            .await?
            .add_ip("10.144.111.1".parse().unwrap(), 24)
            .await
    }

    // Smoke test: device creation and address setup must succeed.
    #[tokio::test]
    async fn tun_test() {
        enable_log();
        let _dev = run_test_helper().await.unwrap();
        // let mut stream = nic.pin_recv_stream();
        // while let Some(item) = stream.next().await {
        //     println!("item: {:?}", item);
        // }
        // let framed = dev.into_framed();
        // let (mut s, mut b) = framed.split();
        // loop {
        //     let tmp = b.next().await.unwrap().unwrap();
        //     let tmp = EthernetPacket::new(tmp.get_bytes());
        //     println!("ret: {:?}", tmp.unwrap());
        // }
    }
}
+103
View File
@@ -0,0 +1,103 @@
#![allow(dead_code)]
#[cfg(test)]
mod tests;
use clap::Parser;
mod common;
mod connector;
mod gateway;
mod instance;
mod peers;
mod tunnels;
use instance::instance::{Instance, InstanceConfigWriter};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
// Command-line arguments. NOTE: `///` comments on fields double as clap
// help text (runtime output), so new notes below use plain `//` only.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Cli {
    /// the instance name
    #[arg(short = 'n', long, default_value = "default")]
    instance_name: String,

    /// specify the network namespace, default is the root namespace
    #[arg(long)]
    net_ns: Option<String>,

    // ipv4 address written into the instance config before startup
    #[arg(short, long)]
    ipv4: Option<String>,

    // peer connector URLs dialed after the instance is running
    #[arg(short, long)]
    peers: Vec<String>,
}
/// Install the global tracing subscriber: a daily-rotating file layer
/// (default INFO, under /var/log/easytier) plus a pretty stderr layer
/// (default WARN). Both defaults can be overridden via the environment.
fn init_logger() {
    // File layer: daily rotation, keep at most 5 files, no ANSI colors.
    let file_appender = tracing_appender::rolling::Builder::new()
        .rotation(tracing_appender::rolling::Rotation::DAILY)
        .max_log_files(5)
        .filename_prefix("core.log")
        .build("/var/log/easytier")
        .expect("failed to initialize rolling file appender");
    let file_filter = EnvFilter::builder()
        .with_default_directive(LevelFilter::INFO.into())
        .from_env()
        .unwrap();
    let mut file_layer = tracing_subscriber::fmt::layer();
    file_layer.set_ansi(false);
    let file_layer = file_layer
        .with_writer(file_appender)
        .with_filter(file_filter);

    // Console layer: human-readable output on stderr.
    let console_filter = EnvFilter::builder()
        .with_default_directive(LevelFilter::WARN.into())
        .from_env()
        .unwrap();
    let console_layer = tracing_subscriber::fmt::layer()
        .pretty()
        .with_writer(std::io::stderr)
        .with_filter(console_filter);

    tracing_subscriber::Registry::default()
        .with(console_layer)
        .with(file_layer)
        .init();
}
/// Entry point: parse the CLI, persist settings into the instance config,
/// start the instance, dial the given peers, then wait forever.
#[tokio::main(flavor = "current_thread")]
#[tracing::instrument]
pub async fn main() {
    init_logger();

    let cli = Cli::parse();
    tracing::info!(cli = ?cli, "cli args parsed");

    // Persist CLI-provided settings before the instance reads its config.
    let writer = InstanceConfigWriter::new(cli.instance_name.as_str()).set_ns(cli.net_ns.clone());
    if let Some(ipv4) = &cli.ipv4 {
        writer.set_addr(ipv4.clone());
    }

    let mut inst = Instance::new(cli.instance_name.as_str());

    // Log every global event in the background.
    let mut events = inst.get_global_ctx().subscribe();
    tokio::spawn(async move {
        while let Ok(e) = events.recv().await {
            log::warn!("event: {:?}", e);
        }
    });

    inst.run().await.unwrap();

    // Dial peers given on the command line once the instance is running.
    for peer in &cli.peers {
        inst.get_conn_manager()
            .add_connector_by_url(peer)
            .await
            .unwrap();
    }

    inst.wait().await;
}
+14
View File
@@ -0,0 +1,14 @@
pub mod packet;
pub mod peer;
pub mod peer_conn;
pub mod peer_manager;
pub mod peer_map;
pub mod peer_rpc;
pub mod rip_route;
pub mod route_trait;
pub mod rpc_service;
#[cfg(test)]
pub mod tests;
/// Unique identifier of a peer node.
pub type PeerId = uuid::Uuid;
+205
View File
@@ -0,0 +1,205 @@
use rkyv::{Archive, Deserialize, Serialize};
use tokio_util::bytes::Bytes;
use crate::common::rkyv_util::{decode_from_bytes, encode_to_bytes};
// Magic number carried in handshake packets to reject foreign/corrupt data.
const MAGIC: u32 = 0xd1e1a5e1;
// Wire-protocol version carried in the handshake.
const VERSION: u32 = 1;
/// rkyv-serializable wrapper around the raw 16 bytes of a `uuid::Uuid`.
#[derive(Archive, Deserialize, Serialize, PartialEq, Clone)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct UUID(uuid::Bytes);
// Debug renders the canonical hyphenated UUID form instead of raw bytes.
impl std::fmt::Debug for UUID {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", uuid::Uuid::from_bytes(self.0))
    }
}
impl From<uuid::Uuid> for UUID {
fn from(uuid: uuid::Uuid) -> Self {
UUID(*uuid.as_bytes())
}
}
impl From<UUID> for uuid::Uuid {
    /// Rebuild a `uuid::Uuid` from the stored raw bytes.
    fn from(uuid: UUID) -> Self {
        Self::from_bytes(uuid.0)
    }
}
// `ArchivedUUID` is the zero-copy view generated by rkyv's `Archive` derive.
impl ArchivedUUID {
    /// Rebuild a full `uuid::Uuid` from the archived bytes.
    pub fn to_uuid(&self) -> uuid::Uuid {
        uuid::Uuid::from_bytes(self.0)
    }
}
// Copy an archived UUID out of a decoded packet into an owned wrapper.
impl From<&ArchivedUUID> for UUID {
    fn from(uuid: &ArchivedUUID) -> Self {
        UUID(uuid.0)
    }
}
/// First control message exchanged on a new peer connection.
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct HandShake {
    // Must equal MAGIC for the packet to be trusted.
    pub magic: u32,
    pub my_peer_id: UUID,
    // Must equal VERSION for compatible peers.
    pub version: u32,
    pub features: Vec<String>,
    // pub interfaces: Vec<String>,
}
/// Opaque routing-protocol payload, dispatched by `route_id`.
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))]
pub struct RoutePacket {
    pub route_id: u8,
    pub body: Vec<u8>,
}
/// Control-plane payloads (everything that is not user data).
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub enum CtrlPacketBody {
    HandShake(HandShake),
    RoutePacket(RoutePacket),
    Ping,
    Pong,
    TaRpc(u32, bool, Vec<u8>), // u32: service_id, bool: is_req, Vec<u8>: rpc body
}
/// User-data payload (e.g. an IP packet from the tun device).
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct DataPacketBody {
    pub data: Vec<u8>,
}
/// Top-level split between control plane and data plane.
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub enum PacketBody {
    Ctrl(CtrlPacketBody),
    Data(DataPacketBody),
}
/// The unit sent over tunnels between peers.
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct Packet {
    pub from_peer: UUID,
    // None means not addressed to a specific peer (e.g. handshake).
    pub to_peer: Option<UUID>,
    pub body: PacketBody,
}
impl Packet {
    /// Zero-copy view over an encoded packet.
    ///
    /// Panics (`unwrap`) when the bytes fail rkyv validation — callers
    /// currently assume well-formed input.
    pub fn decode(v: &[u8]) -> &ArchivedPacket {
        decode_from_bytes::<Packet>(v).unwrap()
    }
}
// Serialize with rkyv into a wire buffer (4096-byte scratch space).
impl From<Packet> for Bytes {
    fn from(val: Packet) -> Self {
        encode_to_bytes::<_, 4096>(&val)
    }
}
impl Packet {
pub fn new_handshake(from_peer: uuid::Uuid) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: None,
body: PacketBody::Ctrl(CtrlPacketBody::HandShake(HandShake {
magic: MAGIC,
my_peer_id: from_peer.into(),
version: VERSION,
features: Vec::new(),
})),
}
}
pub fn new_data_packet(from_peer: uuid::Uuid, to_peer: uuid::Uuid, data: &[u8]) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: Some(to_peer.into()),
body: PacketBody::Data(DataPacketBody {
data: data.to_vec(),
}),
}
}
pub fn new_route_packet(
from_peer: uuid::Uuid,
to_peer: uuid::Uuid,
route_id: u8,
data: &[u8],
) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: Some(to_peer.into()),
body: PacketBody::Ctrl(CtrlPacketBody::RoutePacket(RoutePacket {
route_id,
body: data.to_vec(),
})),
}
}
pub fn new_ping_packet(from_peer: uuid::Uuid, to_peer: uuid::Uuid) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: Some(to_peer.into()),
body: PacketBody::Ctrl(CtrlPacketBody::Ping),
}
}
pub fn new_pong_packet(from_peer: uuid::Uuid, to_peer: uuid::Uuid) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: Some(to_peer.into()),
body: PacketBody::Ctrl(CtrlPacketBody::Pong),
}
}
pub fn new_tarpc_packet(
from_peer: uuid::Uuid,
to_peer: uuid::Uuid,
service_id: u32,
is_req: bool,
body: Vec<u8>,
) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: Some(to_peer.into()),
body: PacketBody::Ctrl(CtrlPacketBody::TaRpc(service_id, is_req, body)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip smoke test: encode a data packet and decode it back.
    #[tokio::test]
    async fn serialize() {
        let a = "abcde";
        let out = Packet::new_data_packet(uuid::Uuid::new_v4(), uuid::Uuid::new_v4(), a.as_bytes());
        // let out = T::new(a.as_bytes());
        let out_bytes: Bytes = out.into();
        println!("out str: {:?}", a.as_bytes());
        println!("out bytes: {:?}", out_bytes);
        let archived = Packet::decode(&out_bytes[..]);
        println!("in packet: {:?}", archived);
    }
}
+218
View File
@@ -0,0 +1,218 @@
use std::sync::Arc;
use dashmap::DashMap;
use easytier_rpc::PeerConnInfo;
use tokio::{
select,
sync::{mpsc, Mutex},
task::JoinHandle,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use uuid::Uuid;
use crate::common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
};
use super::peer_conn::PeerConn;
// One shared, lockable connection to a peer.
type ArcPeerConn = Arc<Mutex<PeerConn>>;
// All live connections to one peer, keyed by connection id.
type ConnMap = Arc<DashMap<Uuid, ArcPeerConn>>;
/// A remote node and every live connection we hold to it.
pub struct Peer {
    pub peer_node_id: uuid::Uuid,
    conns: ConnMap,
    global_ctx: ArcGlobalCtx,
    // Where packets received from this peer are delivered.
    packet_recv_chan: mpsc::Sender<Bytes>,
    // Connections report their own close via this channel.
    close_event_sender: mpsc::Sender<Uuid>,
    // Background task draining close events; shut down on Drop.
    close_event_listener: JoinHandle<()>,
    // Notified in Drop so the listener task can exit.
    shutdown_notifier: Arc<tokio::sync::Notify>,
}
impl Peer {
    /// Create a peer and spawn its close-event listener: when a connection
    /// reports its id as closed, it is removed from the map and a
    /// `PeerConnRemoved` event is issued.
    pub fn new(
        peer_node_id: uuid::Uuid,
        packet_recv_chan: mpsc::Sender<Bytes>,
        global_ctx: ArcGlobalCtx,
    ) -> Self {
        let conns: ConnMap = Arc::new(DashMap::new());
        let (close_event_sender, mut close_event_receiver) = mpsc::channel(10);
        let shutdown_notifier = Arc::new(tokio::sync::Notify::new());
        let conns_copy = conns.clone();
        let shutdown_notifier_copy = shutdown_notifier.clone();
        let global_ctx_copy = global_ctx.clone();
        let close_event_listener = tokio::spawn(
            async move {
                loop {
                    select! {
                        ret = close_event_receiver.recv() => {
                            // None means the channel was closed -> exit.
                            if ret.is_none() {
                                break;
                            }
                            let ret = ret.unwrap();
                            tracing::warn!(
                                ?peer_node_id,
                                ?ret,
                                "notified that peer conn is closed",
                            );
                            if let Some((_, conn)) = conns_copy.remove(&ret) {
                                global_ctx_copy.issue_event(GlobalCtxEvent::PeerConnRemoved(
                                    conn.lock().await.get_conn_info(),
                                ));
                            }
                        }
                        _ = shutdown_notifier_copy.notified() => {
                            // Closing the receiver makes recv() drain then
                            // return None, ending the loop above.
                            close_event_receiver.close();
                            tracing::warn!(?peer_node_id, "peer close event listener notified");
                        }
                    }
                }
                tracing::info!("peer {} close event listener exit", peer_node_id);
            }
            .instrument(tracing::info_span!(
                "peer_close_event_listener",
                ?peer_node_id,
            )),
        );
        Peer {
            peer_node_id,
            conns: conns.clone(),
            packet_recv_chan,
            global_ctx,
            close_event_sender,
            close_event_listener,
            shutdown_notifier,
        }
    }

    /// Register a new connection: wire up its close reporting and recv
    /// loop, announce it, and store it in the map.
    pub async fn add_peer_conn(&self, mut conn: PeerConn) {
        conn.set_close_event_sender(self.close_event_sender.clone());
        conn.start_recv_loop(self.packet_recv_chan.clone());
        self.global_ctx
            .issue_event(GlobalCtxEvent::PeerConnAdded(conn.get_conn_info()));
        self.conns
            .insert(conn.get_conn_id(), Arc::new(Mutex::new(conn)));
    }

    /// Send a message over one of this peer's connections (whichever the
    /// map yields first); errors if no connection exists.
    pub async fn send_msg(&self, msg: Bytes) -> Result<(), Error> {
        let Some(conn) = self.conns.iter().next() else {
            return Err(Error::PeerNoConnectionError(self.peer_node_id));
        };
        // Clone the Arc and drop the dashmap guard before locking, to avoid
        // holding the shard lock across the await.
        let conn_clone = conn.clone();
        drop(conn);
        conn_clone.lock().await.send_msg(msg).await?;
        Ok(())
    }

    /// Request closing one connection by id; the actual removal happens in
    /// the close-event listener.
    pub async fn close_peer_conn(&self, conn_id: &Uuid) -> Result<(), Error> {
        let has_key = self.conns.contains_key(conn_id);
        if !has_key {
            return Err(Error::NotFound);
        }
        self.close_event_sender.send(conn_id.clone()).await.unwrap();
        Ok(())
    }

    /// Snapshot info for every live connection to this peer.
    pub async fn list_peer_conns(&self) -> Vec<PeerConnInfo> {
        let mut conns = vec![];
        for conn in self.conns.iter() {
            // do not lock here, otherwise it will cause dashmap deadlock
            conns.push(conn.clone());
        }
        let mut ret = Vec::new();
        for conn in conns {
            ret.push(conn.lock().await.get_conn_info());
        }
        ret
    }
}
// print a log line and stop the close-event listener on drop
impl Drop for Peer {
    /// Wakes the close-event listener so it closes its channel, drains any
    /// pending close events, and exits. The task handle itself is detached.
    fn drop(&mut self) {
        self.shutdown_notifier.notify_one();
        tracing::info!("peer {} drop", self.peer_node_id);
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use tokio::{sync::mpsc, time::timeout};
    use crate::{
        common::{config_fs::ConfigFs, global_ctx::GlobalCtx, netns::NetNS},
        peers::peer_conn::PeerConn,
        tunnels::ring_tunnel::create_ring_tunnel_pair,
    };
    use super::Peer;
    /// Closing one side of a ring-tunnel conn should make BOTH peers drop the
    /// connection: the local one via the close event, the remote one via its
    /// recv loop noticing the tunnel closed.
    #[tokio::test]
    async fn close_peer() {
        let (local_packet_send, _local_packet_recv) = mpsc::channel(10);
        let (remote_packet_send, _remote_packet_recv) = mpsc::channel(10);
        let global_ctx = Arc::new(GlobalCtx::new(
            "test",
            ConfigFs::new("/tmp/easytier-test"),
            NetNS::new(None),
        ));
        let local_peer = Peer::new(uuid::Uuid::new_v4(), local_packet_send, global_ctx.clone());
        let remote_peer = Peer::new(uuid::Uuid::new_v4(), remote_packet_send, global_ctx.clone());
        let (local_tunnel, remote_tunnel) = create_ring_tunnel_pair();
        let mut local_peer_conn =
            PeerConn::new(local_peer.peer_node_id, global_ctx.clone(), local_tunnel);
        let mut remote_peer_conn =
            PeerConn::new(remote_peer.peer_node_id, global_ctx.clone(), remote_tunnel);
        assert!(!local_peer_conn.handshake_done());
        assert!(!remote_peer_conn.handshake_done());
        // Both handshake directions must run concurrently to complete.
        let (a, b) = tokio::join!(
            local_peer_conn.do_handshake_as_client(),
            remote_peer_conn.do_handshake_as_server()
        );
        a.unwrap();
        b.unwrap();
        let local_conn_id = local_peer_conn.get_conn_id();
        local_peer.add_peer_conn(local_peer_conn).await;
        remote_peer.add_peer_conn(remote_peer_conn).await;
        assert_eq!(local_peer.list_peer_conns().await.len(), 1);
        assert_eq!(remote_peer.list_peer_conns().await.len(), 1);
        let close_handler =
            tokio::spawn(async move { local_peer.close_peer_conn(&local_conn_id).await });
        // wait for remote peer conn close
        timeout(std::time::Duration::from_secs(5), async {
            while (&remote_peer).list_peer_conns().await.len() != 0 {
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            }
        })
        .await
        .unwrap();
        println!("wait for close handler");
        close_handler.await.unwrap().unwrap();
    }
}
+484
View File
@@ -0,0 +1,484 @@
use std::{pin::Pin, sync::Arc};
use easytier_rpc::{PeerConnInfo, PeerConnStats};
use futures::{SinkExt, StreamExt};
use pnet::datalink::NetworkInterface;
use tokio::{
sync::{broadcast, mpsc},
task::JoinSet,
time::{timeout, Duration},
};
use tokio_util::{
bytes::{Bytes, BytesMut},
sync::PollSender,
};
use tracing::Instrument;
use crate::{
common::global_ctx::ArcGlobalCtx,
define_tunnel_filter_chain,
tunnels::{
stats::{Throughput, WindowLatency},
tunnel_filter::StatsRecorderTunnelFilter,
DatagramSink, Tunnel, TunnelError,
},
};
use super::packet::{self, ArchivedCtrlPacketBody, ArchivedHandShake, Packet};
pub type PacketRecvChan = mpsc::Sender<Bytes>;
/// Waits (1s timeout) for the next packet on `$stream`, decodes it, and
/// matches the body against `$pattern`, binding `$value` to `$out_var`.
///
/// Expands to statements that early-return `TunnelError::WaitRespError` from
/// the enclosing function on timeout, on a closed stream, or on a packet that
/// does not match `$pattern`; tunnel-level recv errors propagate via `?`.
macro_rules! wait_response {
    ($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
        // Handle all three failure modes explicitly. The previous version
        // called `.unwrap()` on the stream item and would panic if the peer
        // closed the tunnel before sending a handshake; now that case returns
        // an error like the others.
        let rsp_vec = match timeout(Duration::from_secs(1), $stream.next()).await {
            Err(_) => {
                return Err(TunnelError::WaitRespError(
                    "wait handshake response timeout".to_owned(),
                ));
            }
            Ok(None) => {
                return Err(TunnelError::WaitRespError(
                    "tunnel closed while waiting for response".to_owned(),
                ));
            }
            Ok(Some(ret)) => ret?,
        };
        let $out_var;
        let rsp_bytes = Packet::decode(&rsp_vec);
        match &rsp_bytes.body {
            $pattern => $out_var = $value,
            _ => {
                log::error!(
                    "unexpected packet: {:?}, pattern: {:?}",
                    rsp_bytes,
                    stringify!($pattern)
                );
                return Err(TunnelError::WaitRespError("unexpected packet".to_owned()));
            }
        }
    };
}
/// Identity information learned from a peer's handshake packet.
pub struct PeerInfo {
    // Protocol magic number carried by the handshake.
    magic: u32,
    // The remote node's id (despite the `my_` prefix, this is the SENDER's id
    // as written into the handshake by the remote side).
    pub my_peer_id: uuid::Uuid,
    // Protocol version of the remote side.
    version: u32,
    pub features: Vec<String>,
    // Not transmitted in the handshake; always empty after `From`.
    pub interfaces: Vec<NetworkInterface>,
}
/// Builds a `PeerInfo` from an archived (rkyv) handshake packet.
///
/// `interfaces` is not carried by the handshake, so it is left empty.
// Fixed: the impl previously declared an unused lifetime parameter `<'a>`
// that constrained nothing; it has been removed.
impl From<&ArchivedHandShake> for PeerInfo {
    fn from(hs: &ArchivedHandShake) -> Self {
        PeerInfo {
            magic: hs.magic.into(),
            my_peer_id: hs.my_peer_id.to_uuid(),
            version: hs.version.into(),
            features: hs.features.iter().map(|x| x.to_string()).collect(),
            interfaces: Vec::new(),
        }
    }
}
define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter);
/// One live tunnel to a remote peer: handshake state, recv loop, and a
/// ping/pong keepalive that feeds the latency window.
pub struct PeerConn {
    // Unique id of this connection (not of the peer).
    conn_id: uuid::Uuid,
    // Our own node id, sent in handshakes and pings.
    my_node_id: uuid::Uuid,
    global_ctx: ArcGlobalCtx,
    // Dedicated sink used by `send_msg`; the recv loop and pingpong task pin
    // their own sinks from the same tunnel.
    sink: Pin<Box<dyn DatagramSink>>,
    tunnel: Box<dyn Tunnel>,
    // Background tasks (recv loop, pingpong); aborted when self drops.
    tasks: JoinSet<Result<(), TunnelError>>,
    // Remote identity; `None` until a handshake completes.
    info: Option<PeerInfo>,
    // Where to report this conn's id when the conn dies; set by the owner.
    close_event_sender: Option<mpsc::Sender<uuid::Uuid>>,
    // Broadcasts ctrl responses (e.g. pong) from the recv loop to waiters.
    ctrl_resp_sender: broadcast::Sender<Bytes>,
    // Sliding window of ping RTTs.
    latency_stats: Arc<WindowLatency>,
    // Byte/packet counters recorded by the stats tunnel filter.
    throughput: Arc<Throughput>,
}
/// Classification of an inbound frame, decided by its byte prefix:
/// ctrl request, ctrl response, or ordinary data (no prefix).
enum PeerConnPacketType {
    Data(Bytes),
    CtrlReq(Bytes),
    CtrlResp(Bytes),
}
// 8-byte magic prefixes distinguishing ctrl frames from data frames on the
// wire; they differ only in the last byte (0xf0 = request, 0xf1 = response).
static CTRL_REQ_PACKET_PREFIX: &[u8] = &[0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0];
static CTRL_RESP_PACKET_PREFIX: &[u8] = &[0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf1];
impl PeerConn {
    /// Wraps `tunnel` with the stats-recording filter chain and prepares a
    /// fresh conn (random conn id, no handshake done, no tasks running).
    pub fn new(node_id: uuid::Uuid, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
        let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
        let peer_conn_tunnel = PeerConnTunnel::new();
        let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel);
        PeerConn {
            conn_id: uuid::Uuid::new_v4(),
            my_node_id: node_id,
            global_ctx,
            sink: tunnel.pin_sink(),
            tunnel: Box::new(tunnel),
            tasks: JoinSet::new(),
            info: None,
            close_event_sender: None,
            ctrl_resp_sender: ctrl_sender,
            latency_stats: Arc::new(WindowLatency::new(15)),
            throughput: peer_conn_tunnel.stats.get_throughput().clone(),
        }
    }
    /// Returns this connection's unique id.
    pub fn get_conn_id(&self) -> uuid::Uuid {
        self.conn_id
    }
    /// Server side of the handshake: wait for the client's handshake packet,
    /// record its identity, then answer with our own handshake.
    pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> {
        let mut stream = self.tunnel.pin_stream();
        let mut sink = self.tunnel.pin_sink();
        wait_response!(stream, hs_req, packet::ArchivedPacketBody::Ctrl(ArchivedCtrlPacketBody::HandShake(x)) => x);
        self.info = Some(PeerInfo::from(hs_req));
        log::info!("handshake request: {:?}", hs_req);
        // Built inside our netns — presumably because handshake construction
        // inspects local network state; TODO confirm.
        let hs_req = self
            .global_ctx
            .net_ns
            .run(|| packet::Packet::new_handshake(self.my_node_id));
        sink.send(hs_req.into()).await?;
        Ok(())
    }
    /// Client side of the handshake: send ours first, then wait for and
    /// record the server's reply.
    pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> {
        let mut stream = self.tunnel.pin_stream();
        let mut sink = self.tunnel.pin_sink();
        let hs_req = self
            .global_ctx
            .net_ns
            .run(|| packet::Packet::new_handshake(self.my_node_id));
        sink.send(hs_req.into()).await?;
        wait_response!(stream, hs_rsp, packet::ArchivedPacketBody::Ctrl(ArchivedCtrlPacketBody::HandShake(x)) => x);
        self.info = Some(PeerInfo::from(hs_rsp));
        log::info!("handshake response: {:?}", hs_rsp);
        Ok(())
    }
    /// True once either handshake direction has completed.
    pub fn handshake_done(&self) -> bool {
        self.info.is_some()
    }
    /// Sends one ping and waits (up to 4s) for a pong on the ctrl-response
    /// broadcast; returns the round-trip time in microseconds.
    async fn do_pingpong_once(
        my_node_id: uuid::Uuid,
        peer_id: uuid::Uuid,
        sink: &mut Pin<Box<dyn DatagramSink>>,
        receiver: &mut broadcast::Receiver<Bytes>,
    ) -> Result<u128, TunnelError> {
        // should add seq here. so latency can be calculated more accurately
        let req = Self::build_ctrl_msg(
            packet::Packet::new_ping_packet(my_node_id, peer_id).into(),
            true,
        );
        log::trace!("send ping packet: {:?}", req);
        sink.send(req).await?;
        let now = std::time::Instant::now();
        // wait until we get a pong packet in ctrl_resp_receiver
        let resp = timeout(Duration::from_secs(4), async {
            loop {
                match receiver.recv().await {
                    Ok(p) => {
                        // Ignore non-pong ctrl responses; keep waiting.
                        if let packet::ArchivedPacketBody::Ctrl(
                            packet::ArchivedCtrlPacketBody::Pong,
                        ) = &Packet::decode(&p).body
                        {
                            break;
                        }
                    }
                    Err(e) => {
                        log::warn!("recv pong resp error: {:?}", e);
                        return Err(TunnelError::WaitRespError(
                            "recv pong resp error".to_owned(),
                        ));
                    }
                }
            }
            Ok(())
        })
        .await;
        if resp.is_err() {
            return Err(TunnelError::WaitRespError(
                "wait ping response timeout".to_owned(),
            ));
        }
        if resp.as_ref().unwrap().is_err() {
            return Err(resp.unwrap().err().unwrap());
        }
        Ok(now.elapsed().as_micros())
    }
    /// Spawns the keepalive task: ping once per second, record the latency,
    /// and report this conn as closed when a ping fails.
    fn start_pingpong(&mut self) {
        let mut sink = self.tunnel.pin_sink();
        let my_node_id = self.my_node_id;
        let peer_id = self.get_peer_id();
        let receiver = self.ctrl_resp_sender.subscribe();
        // Panics if called before `set_close_event_sender`; callers go
        // through `start_recv_loop`, which the owning Peer invokes after
        // wiring the sender.
        let close_event_sender = self.close_event_sender.clone().unwrap();
        let conn_id = self.conn_id;
        let latency_stats = self.latency_stats.clone();
        self.tasks.spawn(async move {
            // Give the connection a second to settle before the first ping.
            tokio::time::sleep(Duration::from_secs(1)).await;
            loop {
                // Fresh receiver each round so stale broadcasts queued during
                // the sleep are not mistaken for this round's pong.
                let mut receiver = receiver.resubscribe();
                if let Ok(lat) =
                    Self::do_pingpong_once(my_node_id, peer_id, &mut sink, &mut receiver).await
                {
                    log::trace!(
                        "pingpong latency: {}us, my_node_id: {}, peer_id: {}",
                        lat,
                        my_node_id,
                        peer_id
                    );
                    latency_stats.record_latency(lat as u64);
                    tokio::time::sleep(Duration::from_secs(1)).await;
                } else {
                    break;
                }
            }
            log::warn!(
                "pingpong task exit, my_node_id: {}, peer_id: {}",
                my_node_id,
                peer_id,
            );
            // A failed ping means the conn is dead: ask the owner to remove it.
            if let Err(e) = close_event_sender.send(conn_id).await {
                log::warn!("close event sender error: {:?}", e);
            }
            Ok(())
        });
    }
    /// Classifies a frame by its magic prefix and strips the prefix from
    /// ctrl frames.
    fn get_packet_type(mut bytes_item: Bytes) -> PeerConnPacketType {
        if bytes_item.starts_with(CTRL_REQ_PACKET_PREFIX) {
            PeerConnPacketType::CtrlReq(bytes_item.split_off(CTRL_REQ_PACKET_PREFIX.len()))
        } else if bytes_item.starts_with(CTRL_RESP_PACKET_PREFIX) {
            PeerConnPacketType::CtrlResp(bytes_item.split_off(CTRL_RESP_PACKET_PREFIX.len()))
        } else {
            PeerConnPacketType::Data(bytes_item)
        }
    }
    /// Handles a ctrl request: a ping yields the pong response frame to send
    /// back; anything else is an error.
    fn handle_ctrl_req_packet(
        bytes_item: Bytes,
        conn_info: &PeerConnInfo,
    ) -> Result<Bytes, TunnelError> {
        let packet = Packet::decode(&bytes_item);
        match packet.body {
            packet::ArchivedPacketBody::Ctrl(packet::ArchivedCtrlPacketBody::Ping) => {
                log::trace!("recv ping packet: {:?}", packet);
                Ok(Self::build_ctrl_msg(
                    packet::Packet::new_pong_packet(
                        conn_info.my_node_id.parse().unwrap(),
                        conn_info.peer_id.parse().unwrap(),
                    )
                    .into(),
                    false,
                ))
            }
            _ => {
                log::error!("unexpected packet: {:?}", packet);
                Err(TunnelError::CommonError("unexpected packet".to_owned()))
            }
        }
    }
    /// Spawns the recv loop (data -> `packet_recv_chan`, ctrl req -> answered
    /// inline, ctrl resp -> broadcast) and then the pingpong task.
    /// Must be called after `set_close_event_sender`.
    pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
        let mut stream = self.tunnel.pin_stream();
        let mut sink = self.tunnel.pin_sink();
        let mut sender = PollSender::new(packet_recv_chan.clone());
        let close_event_sender = self.close_event_sender.clone().unwrap();
        let conn_id = self.conn_id;
        let ctrl_sender = self.ctrl_resp_sender.clone();
        let conn_info = self.get_conn_info();
        let conn_info_for_instrument = self.get_conn_info();
        self.tasks.spawn(
            async move {
                tracing::info!("start recving peer conn packet");
                while let Some(ret) = stream.next().await {
                    if ret.is_err() {
                        // Tunnel-level error: close our side and report the
                        // conn as dead before exiting the task.
                        tracing::error!(error = ?ret, "peer conn recv error");
                        if let Err(close_ret) = sink.close().await {
                            tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it");
                        }
                        if let Err(e) = close_event_sender.send(conn_id).await {
                            tracing::error!(error = ?e, "peer conn close event send error");
                        }
                        return Err(ret.err().unwrap());
                    }
                    match Self::get_packet_type(ret.unwrap().into()) {
                        PeerConnPacketType::Data(item) => sender.send(item).await.unwrap(),
                        PeerConnPacketType::CtrlReq(item) => {
                            let ret = Self::handle_ctrl_req_packet(item, &conn_info).unwrap();
                            if let Err(e) = sink.send(ret).await {
                                tracing::error!(?e, "peer conn send req error");
                            }
                        }
                        PeerConnPacketType::CtrlResp(item) => {
                            // Errors only mean no one is subscribed right now.
                            if let Err(e) = ctrl_sender.send(item) {
                                tracing::error!(?e, "peer conn send ctrl resp error");
                            }
                        }
                    }
                }
                tracing::info!("end recving peer conn packet");
                Ok(())
            }
            .instrument(
                tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
            ),
        );
        self.start_pingpong();
    }
    /// Sends one data frame over this conn's dedicated sink.
    pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), TunnelError> {
        self.sink.send(msg).await
    }
    /// Prepends the request/response ctrl magic prefix to `msg`.
    fn build_ctrl_msg(msg: Bytes, is_req: bool) -> Bytes {
        let prefix: &'static [u8] = if is_req {
            CTRL_REQ_PACKET_PREFIX
        } else {
            CTRL_RESP_PACKET_PREFIX
        };
        let mut new_msg = BytesMut::new();
        new_msg.reserve(prefix.len() + msg.len());
        new_msg.extend_from_slice(prefix);
        new_msg.extend_from_slice(&msg);
        new_msg.into()
    }
    /// The remote peer's node id. Panics if called before a handshake.
    pub fn get_peer_id(&self) -> uuid::Uuid {
        self.info.as_ref().unwrap().my_peer_id
    }
    /// Installs the channel used to report this conn's death to its owner.
    pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<uuid::Uuid>) {
        self.close_event_sender = Some(sender);
    }
    /// Snapshots latency and throughput counters.
    pub fn get_stats(&self) -> PeerConnStats {
        PeerConnStats {
            latency_us: self.latency_stats.get_latency_us(),
            tx_bytes: self.throughput.tx_bytes(),
            rx_bytes: self.throughput.rx_bytes(),
            tx_packets: self.throughput.tx_packets(),
            rx_packets: self.throughput.rx_packets(),
        }
    }
    /// Full description of this conn (ids, features, tunnel info, stats).
    /// Panics if called before a handshake (via `get_peer_id`).
    pub fn get_conn_info(&self) -> PeerConnInfo {
        PeerConnInfo {
            conn_id: self.conn_id.to_string(),
            my_node_id: self.my_node_id.to_string(),
            peer_id: self.get_peer_id().to_string(),
            features: self.info.as_ref().unwrap().features.clone(),
            tunnel: self.tunnel.info(),
            stats: Some(self.get_stats()),
        }
    }
}
impl Drop for PeerConn {
    /// Best-effort close of the tunnel from a spawned task (sink.close() is
    /// async and Drop cannot await).
    /// NOTE(review): `tokio::spawn` here requires an active tokio runtime at
    /// drop time — confirm all drop sites run inside the runtime.
    fn drop(&mut self) {
        let mut sink = self.tunnel.pin_sink();
        tokio::spawn(async move {
            let ret = sink.close().await;
            tracing::info!(error = ?ret, "peer conn tunnel closed.");
        });
        log::info!("peer conn {:?} drop", self.conn_id);
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use super::*;
    use crate::common::config_fs::ConfigFs;
    use crate::common::global_ctx::GlobalCtx;
    use crate::common::netns::NetNS;
    use crate::tunnels::tunnel_filter::{PacketRecorderTunnelFilter, TunnelWithFilter};
    /// Runs a client/server handshake over a ring tunnel and checks that
    /// exactly one packet crossed each direction and that each side learned
    /// the other's node id.
    #[tokio::test]
    async fn peer_conn_handshake() {
        use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
        let (c, s) = create_ring_tunnel_pair();
        // Recorders capture every packet so we can count handshake traffic.
        let c_recorder = Arc::new(PacketRecorderTunnelFilter::new());
        let s_recorder = Arc::new(PacketRecorderTunnelFilter::new());
        let c = TunnelWithFilter::new(c, c_recorder.clone());
        let s = TunnelWithFilter::new(s, s_recorder.clone());
        let c_uuid = uuid::Uuid::new_v4();
        let s_uuid = uuid::Uuid::new_v4();
        let mut c_peer = PeerConn::new(
            c_uuid,
            Arc::new(GlobalCtx::new(
                "c",
                ConfigFs::new_with_dir("c", "/tmp"),
                NetNS::new(None),
            )),
            Box::new(c),
        );
        let mut s_peer = PeerConn::new(
            s_uuid,
            Arc::new(GlobalCtx::new(
                "c",
                ConfigFs::new_with_dir("c", "/tmp"),
                NetNS::new(None),
            )),
            Box::new(s),
        );
        let (c_ret, s_ret) = tokio::join!(
            c_peer.do_handshake_as_client(),
            s_peer.do_handshake_as_server()
        );
        c_ret.unwrap();
        s_ret.unwrap();
        // One handshake packet sent and one received per side.
        assert_eq!(c_recorder.sent.lock().unwrap().len(), 1);
        assert_eq!(c_recorder.received.lock().unwrap().len(), 1);
        assert_eq!(s_recorder.sent.lock().unwrap().len(), 1);
        assert_eq!(s_recorder.received.lock().unwrap().len(), 1);
        assert_eq!(c_peer.get_peer_id(), s_uuid);
        assert_eq!(s_peer.get_peer_id(), c_uuid);
    }
}
+539
View File
@@ -0,0 +1,539 @@
use std::{
fmt::Debug,
net::Ipv4Addr,
sync::{atomic::AtomicU8, Arc},
};
use async_trait::async_trait;
use futures::{StreamExt, TryFutureExt};
use tokio::{
sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender},
Mutex, RwLock,
},
task::JoinSet,
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::bytes::{Bytes, BytesMut};
use uuid::Uuid;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_vec},
peers::{
packet::{self},
peer_conn::PeerConn,
peer_rpc::PeerRpcManagerTransport,
route_trait::RouteInterface,
},
tunnels::{SinkItem, Tunnel, TunnelConnector},
};
use super::{
peer_map::PeerMap,
peer_rpc::PeerRpcManager,
route_trait::{ArcRoute, Route},
PeerId,
};
/// Transport used by `PeerRpcManager` to move rpc payloads over the peer
/// network: outbound via `PeerMap` + route, inbound via a channel fed by the
/// peer-manager packet pipeline.
struct RpcTransport {
    my_peer_id: uuid::Uuid,
    peers: Arc<PeerMap>,
    // Inbound rpc packets; the Mutex serializes concurrent `recv` callers.
    packet_recv: Mutex<UnboundedReceiver<Bytes>>,
    // Sender half handed to the rpc packet pipeline processor.
    peer_rpc_tspt_sender: UnboundedSender<Bytes>,
    // Shared route; `None` until `PeerManager::set_route` installs one.
    route: Arc<Mutex<Option<ArcRoute>>>,
}
#[async_trait::async_trait]
impl PeerRpcManagerTransport for RpcTransport {
fn my_peer_id(&self) -> Uuid {
self.my_peer_id
}
async fn send(&self, msg: Bytes, dst_peer_id: &uuid::Uuid) -> Result<(), Error> {
let route = self.route.lock().await;
if route.is_none() {
log::error!("no route info when send rpc msg");
return Err(Error::RouteError("No route info".to_string()));
}
self.peers
.send_msg(msg, dst_peer_id, route.as_ref().unwrap().clone())
.map_err(|e| e.into())
.await
}
async fn recv(&self) -> Result<Bytes, Error> {
if let Some(o) = self.packet_recv.lock().await.recv().await {
Ok(o)
} else {
Err(Error::Unknown)
}
}
}
/// One stage of the inbound (from-peer) packet pipeline.
///
/// Returns `Some(())` when the stage consumed the packet (stops the
/// pipeline), `None` to pass it to the next stage.
#[async_trait::async_trait]
#[auto_impl::auto_impl(Arc)]
pub trait PeerPacketFilter {
    async fn try_process_packet_from_peer(
        &self,
        packet: &packet::ArchivedPacket,
        data: &Bytes,
    ) -> Option<()>;
}
/// One stage of the outbound (from-NIC) packet pipeline; each stage may
/// transform the packet and must return it (possibly unchanged).
#[async_trait::async_trait]
#[auto_impl::auto_impl(Arc)]
pub trait NicPacketFilter {
    async fn try_process_packet_from_nic(&self, data: BytesMut) -> BytesMut;
}
// Boxed, thread-safe pipeline stages stored in the processing pipelines.
type BoxPeerPacketFilter = Box<dyn PeerPacketFilter + Send + Sync>;
type BoxNicPacketFilter = Box<dyn NicPacketFilter + Send + Sync>;
/// Central hub of a node: owns the peer map, routing, the rpc manager, and
/// the packet pipelines between peers and the local NIC.
pub struct PeerManager {
    my_node_id: uuid::Uuid,
    global_ctx: ArcGlobalCtx,
    // Outbound channel toward the local tun/tap device.
    nic_channel: mpsc::Sender<SinkItem>,
    // Long-running background tasks (recv loop, cleanup routine).
    tasks: Arc<Mutex<JoinSet<()>>>,
    // Receiver for packets arriving from peers; taken (Option -> None) by
    // `start_peer_recv`, so `run` may only consume it once.
    packet_recv: Arc<Mutex<Option<mpsc::Receiver<Bytes>>>>,
    peers: Arc<PeerMap>,
    // Installed route implementation; `None` until `set_route`.
    route: Arc<Mutex<Option<ArcRoute>>>,
    // Id returned by the route's `open`; stored for diagnostics.
    cur_route_id: AtomicU8,
    peer_rpc_mgr: Arc<PeerRpcManager>,
    peer_rpc_tspt: Arc<RpcTransport>,
    // Pipelines run in reverse insertion order (newest stage first).
    peer_packet_process_pipeline: Arc<RwLock<Vec<BoxPeerPacketFilter>>>,
    nic_packet_process_pipeline: Arc<RwLock<Vec<BoxNicPacketFilter>>>,
}
impl Debug for PeerManager {
    /// Manual Debug: only identity fields, since most members (pipelines,
    /// channels, route objects) are not Debug.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PeerManager")
            .field("my_node_id", &self.my_node_id)
            .field("instance_name", &self.global_ctx.inst_name)
            .field("net_ns", &self.global_ctx.net_ns.name())
            .field("cur_route_id", &self.cur_route_id)
            .finish()
    }
}
impl PeerManager {
    /// Builds the manager with empty pipelines and no route; nothing runs
    /// until `run` is called.
    pub fn new(global_ctx: ArcGlobalCtx, nic_channel: mpsc::Sender<SinkItem>) -> Self {
        let (packet_send, packet_recv) = mpsc::channel(100);
        let peers = Arc::new(PeerMap::new(packet_send.clone()));
        // TODO: remove these because we have impl pipeline processor.
        let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
        let rpc_tspt = Arc::new(RpcTransport {
            my_peer_id: global_ctx.get_id(),
            peers: peers.clone(),
            packet_recv: Mutex::new(peer_rpc_tspt_recv),
            peer_rpc_tspt_sender,
            route: Arc::new(Mutex::new(None)),
        });
        PeerManager {
            my_node_id: global_ctx.get_id(),
            global_ctx,
            nic_channel,
            tasks: Arc::new(Mutex::new(JoinSet::new())),
            packet_recv: Arc::new(Mutex::new(Some(packet_recv))),
            peers: peers.clone(),
            route: Arc::new(Mutex::new(None)),
            cur_route_id: AtomicU8::new(0),
            peer_rpc_mgr: Arc::new(PeerRpcManager::new(rpc_tspt.clone())),
            peer_rpc_tspt: rpc_tspt,
            peer_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
            nic_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
        }
    }
    /// Performs the client handshake on `tunnel` and registers the resulting
    /// conn; returns `(peer_id, conn_id)`.
    pub async fn add_client_tunnel(&self, tunnel: Box<dyn Tunnel>) -> Result<(Uuid, Uuid), Error> {
        let mut peer = PeerConn::new(self.my_node_id, self.global_ctx.clone(), tunnel);
        peer.do_handshake_as_client().await?;
        let conn_id = peer.get_conn_id();
        let peer_id = peer.get_peer_id();
        self.peers
            .add_new_peer_conn(peer, self.global_ctx.clone())
            .await;
        Ok((peer_id, conn_id))
    }
    /// Connects via `connector` inside this instance's netns, then adds the
    /// tunnel as a client conn.
    #[tracing::instrument]
    pub async fn try_connect<C>(&self, mut connector: C) -> Result<(Uuid, Uuid), Error>
    where
        C: TunnelConnector + Debug,
    {
        let ns = self.global_ctx.net_ns.clone();
        let t = ns
            .run_async(|| async move { connector.connect().await })
            .await?;
        self.add_client_tunnel(t).await
    }
    /// Accepts an inbound tunnel: server handshake, then register the conn.
    #[tracing::instrument]
    pub async fn add_tunnel_as_server(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
        tracing::info!("add tunnel as server start");
        let mut peer = PeerConn::new(self.my_node_id, self.global_ctx.clone(), tunnel);
        peer.do_handshake_as_server().await?;
        self.peers
            .add_new_peer_conn(peer, self.global_ctx.clone())
            .await;
        tracing::info!("add tunnel as server done");
        Ok(())
    }
    /// Spawns the main dispatch loop for packets arriving from peers:
    /// forwards packets addressed to other nodes via the route, and runs the
    /// peer-packet pipeline for packets addressed to us.
    /// Takes the single packet receiver; calling twice would panic.
    async fn start_peer_recv(&self) {
        let mut recv = ReceiverStream::new(self.packet_recv.lock().await.take().unwrap());
        let my_node_id = self.my_node_id;
        let peers = self.peers.clone();
        let arc_route = self.route.clone();
        let pipe_line = self.peer_packet_process_pipeline.clone();
        self.tasks.lock().await.spawn(async move {
            log::trace!("start_peer_recv");
            while let Some(ret) = recv.next().await {
                log::trace!("peer recv a packet...: {:?}", ret);
                let packet = packet::Packet::decode(&ret);
                let from_peer_uuid = packet.from_peer.to_uuid();
                // NOTE(review): `to_peer` unwrap assumes every routed packet
                // carries a destination — confirm senders always set it.
                let to_peer_uuid = packet.to_peer.as_ref().unwrap().to_uuid();
                if to_peer_uuid != my_node_id {
                    // Not for us: look up the next hop and forward as-is.
                    let locked_arc_route = arc_route.lock().await;
                    if locked_arc_route.is_none() {
                        log::error!("no route info after recv a packet");
                        continue;
                    }
                    let route = locked_arc_route.as_ref().unwrap().clone();
                    // Release the route lock before the (awaiting) send.
                    drop(locked_arc_route);
                    log::trace!(
                        "need forward: to_peer_uuid: {:?}, my_uuid: {:?}",
                        to_peer_uuid,
                        my_node_id
                    );
                    let ret = peers
                        .send_msg(ret.clone(), &to_peer_uuid, route.clone())
                        .await;
                    if ret.is_err() {
                        log::error!(
                            "forward packet error: {:?}, dst: {:?}, from: {:?}",
                            ret,
                            to_peer_uuid,
                            from_peer_uuid
                        );
                    }
                } else {
                    // For us: offer the packet to pipeline stages, newest
                    // (last added) first, until one consumes it.
                    let mut processed = false;
                    for pipeline in pipe_line.read().await.iter().rev() {
                        if let Some(_) = pipeline.try_process_packet_from_peer(&packet, &ret).await
                        {
                            processed = true;
                            break;
                        }
                    }
                    if !processed {
                        tracing::error!("unexpected packet: {:?}", ret);
                    }
                }
            }
            // NOTE(review): reaching here (all senders dropped) aborts the
            // process; presumably treated as an unrecoverable invariant.
            panic!("done_peer_recv");
        });
    }
    /// Appends a from-peer pipeline stage (runs before older stages).
    pub async fn add_packet_process_pipeline(&self, pipeline: BoxPeerPacketFilter) {
        // newest pipeline will be executed first
        self.peer_packet_process_pipeline
            .write()
            .await
            .push(pipeline);
    }
    /// Appends a from-NIC pipeline stage (runs before older stages).
    pub async fn add_nic_packet_process_pipeline(&self, pipeline: BoxNicPacketFilter) {
        // newest pipeline will be executed first
        self.nic_packet_process_pipeline
            .write()
            .await
            .push(pipeline);
    }
    /// Installs the three built-in from-peer stages: NIC data delivery,
    /// route-packet dispatch, and rpc-packet dispatch.
    async fn init_packet_process_pipeline(&self) {
        use packet::ArchivedPacketBody;
        // for tun/tap ip/eth packet.
        struct NicPacketProcessor {
            nic_channel: mpsc::Sender<SinkItem>,
        }
        #[async_trait::async_trait]
        impl PeerPacketFilter for NicPacketProcessor {
            async fn try_process_packet_from_peer(
                &self,
                packet: &packet::ArchivedPacket,
                data: &Bytes,
            ) -> Option<()> {
                if let packet::ArchivedPacketBody::Data(x) = &packet.body {
                    // TODO: use a function to get the body ref directly for zero copy
                    self.nic_channel
                        .send(extract_bytes_from_archived_vec(&data, &x.data))
                        .await
                        .unwrap();
                    Some(())
                } else {
                    None
                }
            }
        }
        self.add_packet_process_pipeline(Box::new(NicPacketProcessor {
            nic_channel: self.nic_channel.clone(),
        }))
        .await;
        // for peer manager router packet
        struct RoutePacketProcessor {
            route: Arc<Mutex<Option<ArcRoute>>>,
        }
        #[async_trait::async_trait]
        impl PeerPacketFilter for RoutePacketProcessor {
            async fn try_process_packet_from_peer(
                &self,
                packet: &packet::ArchivedPacket,
                data: &Bytes,
            ) -> Option<()> {
                if let ArchivedPacketBody::Ctrl(packet::ArchivedCtrlPacketBody::RoutePacket(
                    route_packet,
                )) = &packet.body
                {
                    let r = self.route.lock().await;
                    match r.as_ref() {
                        Some(x) => {
                            // Clone and drop the lock before the await below.
                            let x = x.clone();
                            drop(r);
                            x.handle_route_packet(
                                packet.from_peer.to_uuid(),
                                extract_bytes_from_archived_vec(&data, &route_packet.body),
                            )
                            .await;
                        }
                        None => {
                            log::error!("no route info when handle route packet");
                        }
                    }
                    Some(())
                } else {
                    None
                }
            }
        }
        self.add_packet_process_pipeline(Box::new(RoutePacketProcessor {
            route: self.route.clone(),
        }))
        .await;
        // for peer rpc packet
        struct PeerRpcPacketProcessor {
            peer_rpc_tspt_sender: UnboundedSender<Bytes>,
        }
        #[async_trait::async_trait]
        impl PeerPacketFilter for PeerRpcPacketProcessor {
            async fn try_process_packet_from_peer(
                &self,
                packet: &packet::ArchivedPacket,
                data: &Bytes,
            ) -> Option<()> {
                if let ArchivedPacketBody::Ctrl(packet::ArchivedCtrlPacketBody::TaRpc(..)) =
                    &packet.body
                {
                    self.peer_rpc_tspt_sender.send(data.clone()).unwrap();
                    Some(())
                } else {
                    None
                }
            }
        }
        self.add_packet_process_pipeline(Box::new(PeerRpcPacketProcessor {
            peer_rpc_tspt_sender: self.peer_rpc_tspt.peer_rpc_tspt_sender.clone(),
        }))
        .await;
    }
    /// Installs `route` as the active routing implementation: opens it with an
    /// interface backed by our peer map, then shares it with the rpc
    /// transport.
    pub async fn set_route<T>(&self, route: T)
    where
        T: Route + Send + Sync + 'static,
    {
        // Adapter giving the route access to our peer list and a way to send
        // route-protocol packets to direct neighbors.
        struct Interface {
            my_node_id: uuid::Uuid,
            peers: Arc<PeerMap>,
        }
        #[async_trait]
        impl RouteInterface for Interface {
            async fn list_peers(&self) -> Vec<PeerId> {
                self.peers.list_peers_with_conn().await
            }
            async fn send_route_packet(
                &self,
                msg: Bytes,
                route_id: u8,
                dst_peer_id: &PeerId,
            ) -> Result<(), Error> {
                self.peers
                    .send_msg_directly(
                        packet::Packet::new_route_packet(
                            self.my_node_id,
                            *dst_peer_id,
                            route_id,
                            &msg,
                        )
                        .into(),
                        dst_peer_id,
                    )
                    .await
            }
        }
        let my_node_id = self.my_node_id;
        let route_id = route
            .open(Box::new(Interface {
                my_node_id,
                peers: self.peers.clone(),
            }))
            .await
            .unwrap();
        self.cur_route_id
            .store(route_id, std::sync::atomic::Ordering::Relaxed);
        let arc_route: ArcRoute = Arc::new(Box::new(route));
        self.route.lock().await.replace(arc_route.clone());
        self.peer_rpc_tspt
            .route
            .lock()
            .await
            .replace(arc_route.clone());
    }
    /// Lists routes known to the installed route impl; empty when no route
    /// has been set yet.
    pub async fn list_routes(&self) -> Vec<easytier_rpc::Route> {
        let route_info = self.route.lock().await;
        if route_info.is_none() {
            return Vec::new();
        }
        let route = route_info.as_ref().unwrap().clone();
        drop(route_info);
        route.list_routes().await
    }
    /// Runs every from-NIC pipeline stage over `data`, newest stage first.
    async fn run_nic_packet_process_pipeline(&self, mut data: BytesMut) -> BytesMut {
        for pipeline in self.nic_packet_process_pipeline.read().await.iter().rev() {
            data = pipeline.try_process_packet_from_nic(data).await;
        }
        data
    }
    /// Sends a raw rpc message toward `dst_peer_id` via the rpc transport.
    pub async fn send_msg(&self, msg: Bytes, dst_peer_id: &PeerId) -> Result<(), Error> {
        self.peer_rpc_tspt.send(msg, dst_peer_id).await
    }
    /// Wraps a NIC packet destined for `ipv4_addr` into a data packet and
    /// routes it to the owning peer; silently succeeds when no peer owns the
    /// address (packet dropped).
    pub async fn send_msg_ipv4(&self, msg: BytesMut, ipv4_addr: Ipv4Addr) -> Result<(), Error> {
        let route_info = self.route.lock().await;
        if route_info.is_none() {
            log::error!("no route info");
            return Err(Error::RouteError("No route info".to_string()));
        }
        let route = route_info.as_ref().unwrap().clone();
        drop(route_info);
        log::trace!(
            "do send_msg in peer manager, msg: {:?}, ipv4_addr: {}",
            msg,
            ipv4_addr
        );
        match route.get_peer_id_by_ipv4(&ipv4_addr).await {
            Some(peer_id) => {
                let msg = self.run_nic_packet_process_pipeline(msg).await;
                self.peers
                    .send_msg(
                        packet::Packet::new_data_packet(self.my_node_id, peer_id, &msg).into(),
                        &peer_id,
                        route.clone(),
                    )
                    .await?;
                log::trace!(
                    "do send_msg in peer manager done, dst_peer_id: {:?}",
                    peer_id
                );
            }
            None => {
                log::trace!("no peer id for ipv4: {}", ipv4_addr);
                return Ok(());
            }
        }
        Ok(())
    }
    /// Spawns a task that, every 10s, removes peers whose last connection is
    /// gone.
    async fn run_clean_peer_without_conn_routine(&self) {
        let peer_map = self.peers.clone();
        self.tasks.lock().await.spawn(async move {
            loop {
                let mut to_remove = vec![];
                for peer_id in peer_map.list_peers().await {
                    let conns = peer_map.list_peer_conns(&peer_id).await;
                    if conns.is_none() || conns.as_ref().unwrap().is_empty() {
                        to_remove.push(peer_id);
                    }
                }
                for peer_id in to_remove {
                    peer_map.close_peer(&peer_id).await.unwrap();
                }
                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
            }
        });
    }
    /// Starts all background machinery; call exactly once (the peer recv
    /// loop's receiver can only be taken once).
    pub async fn run(&self) -> Result<(), Error> {
        self.init_packet_process_pipeline().await;
        self.start_peer_recv().await;
        self.peer_rpc_mgr.run();
        self.run_clean_peer_without_conn_routine().await;
        Ok(())
    }
    pub fn get_peer_map(&self) -> Arc<PeerMap> {
        self.peers.clone()
    }
    pub fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
        self.peer_rpc_mgr.clone()
    }
    pub fn my_node_id(&self) -> uuid::Uuid {
        self.my_node_id
    }
    pub fn get_global_ctx(&self) -> ArcGlobalCtx {
        self.global_ctx.clone()
    }
    pub fn get_nic_channel(&self) -> mpsc::Sender<SinkItem> {
        self.nic_channel.clone()
    }
}
+140
View File
@@ -0,0 +1,140 @@
use std::sync::Arc;
use dashmap::DashMap;
use easytier_rpc::PeerConnInfo;
use tokio::sync::mpsc;
use tokio_util::bytes::Bytes;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx},
tunnels::TunnelError,
};
use super::{peer::Peer, peer_conn::PeerConn, route_trait::ArcRoute, PeerId};
/// All known peers of this node, keyed by peer id.
pub struct PeerMap {
    peer_map: DashMap<PeerId, Arc<Peer>>,
    // Cloned into each newly created `Peer` so every connection delivers its
    // received packets to the same channel.
    packet_send: mpsc::Sender<Bytes>,
}
impl PeerMap {
pub fn new(packet_send: mpsc::Sender<Bytes>) -> Self {
PeerMap {
peer_map: DashMap::new(),
packet_send,
}
}
async fn add_new_peer(&self, peer: Peer) {
self.peer_map.insert(peer.peer_node_id, Arc::new(peer));
}
pub async fn add_new_peer_conn(&self, peer_conn: PeerConn, global_ctx: ArcGlobalCtx) {
let peer_id = peer_conn.get_peer_id();
let no_entry = self.peer_map.get(&peer_id).is_none();
if no_entry {
let new_peer = Peer::new(peer_id, self.packet_send.clone(), global_ctx);
new_peer.add_peer_conn(peer_conn).await;
self.add_new_peer(new_peer).await;
} else {
let peer = self.peer_map.get(&peer_id).unwrap().clone();
peer.add_peer_conn(peer_conn).await;
}
}
fn get_peer_by_id(&self, peer_id: &PeerId) -> Option<Arc<Peer>> {
self.peer_map.get(peer_id).map(|v| v.clone())
}
pub async fn send_msg_directly(
&self,
msg: Bytes,
dst_peer_id: &uuid::Uuid,
) -> Result<(), Error> {
match self.get_peer_by_id(dst_peer_id) {
Some(peer) => {
peer.send_msg(msg).await?;
}
None => {
log::error!("no peer for dst_peer_id: {}", dst_peer_id);
return Ok(());
}
}
Ok(())
}
pub async fn send_msg(
&self,
msg: Bytes,
dst_peer_id: &uuid::Uuid,
route: ArcRoute,
) -> Result<(), Error> {
// get route info
let gateway_peer_id = route.get_next_hop(dst_peer_id).await;
if gateway_peer_id.is_none() {
log::error!("no gateway for dst_peer_id: {}", dst_peer_id);
return Ok(());
}
let gateway_peer_id = gateway_peer_id.unwrap();
self.send_msg_directly(msg, &gateway_peer_id).await?;
Ok(())
}
pub async fn list_peers(&self) -> Vec<PeerId> {
let mut ret = Vec::new();
for item in self.peer_map.iter() {
let peer_id = item.key();
ret.push(*peer_id);
}
ret
}
pub async fn list_peers_with_conn(&self) -> Vec<PeerId> {
let mut ret = Vec::new();
let peers = self.list_peers().await;
for peer_id in peers.iter() {
let Some(peer) = self.get_peer_by_id(peer_id) else {
continue;
};
if peer.list_peer_conns().await.len() > 0 {
ret.push(*peer_id);
}
}
ret
}
pub async fn list_peer_conns(&self, peer_id: &PeerId) -> Option<Vec<PeerConnInfo>> {
if let Some(p) = self.get_peer_by_id(peer_id) {
Some(p.list_peer_conns().await)
} else {
return None;
}
}
pub async fn close_peer_conn(
&self,
peer_id: &PeerId,
conn_id: &uuid::Uuid,
) -> Result<(), Error> {
if let Some(p) = self.get_peer_by_id(peer_id) {
p.close_peer_conn(conn_id).await
} else {
return Err(Error::NotFound);
}
}
pub async fn close_peer(&self, peer_id: &PeerId) -> Result<(), TunnelError> {
let remove_ret = self.peer_map.remove(peer_id);
tracing::info!(
?peer_id,
has_old_value = ?remove_ret.is_some(),
peer_ref_counter = ?remove_ret.map(|v| Arc::strong_count(&v.1)),
"peer is closed"
);
Ok(())
}
}
+509
View File
@@ -0,0 +1,509 @@
use std::sync::Arc;
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use rkyv::Deserialize;
use tarpc::{server::Channel, transport::channel::UnboundedChannel};
use tokio::{
sync::mpsc::{self, UnboundedSender},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use crate::{common::error::Error, peers::packet::Packet};
use super::packet::{CtrlPacketBody, PacketBody};
type PeerRpcServiceId = u32;
/// Abstraction over how rpc payloads move between nodes; implemented by the
/// peer manager's `RpcTransport`.
#[async_trait::async_trait]
#[auto_impl::auto_impl(Arc)]
pub trait PeerRpcManagerTransport: Send + Sync + 'static {
    fn my_peer_id(&self) -> uuid::Uuid;
    async fn send(&self, msg: Bytes, dst_peer_id: &uuid::Uuid) -> Result<(), Error>;
    async fn recv(&self) -> Result<Bytes, Error>;
}
// Channel for delivering decoded rpc packets into an endpoint.
type PacketSender = UnboundedSender<Packet>;
/// Per-(peer, service) server-side endpoint: its inbound packet channel plus
/// the tasks driving the tarpc channel.
struct PeerRpcEndPoint {
    peer_id: uuid::Uuid,
    packet_sender: PacketSender,
    tasks: JoinSet<()>,
}
// Factory that builds an endpoint for a given remote peer id; one is
// registered per service.
type PeerRpcEndPointCreator = Box<dyn Fn(uuid::Uuid) -> PeerRpcEndPoint + Send + Sync + 'static>;
// Key identifying a client-side rpc context: (destination peer, service id).
#[derive(Hash, Eq, PartialEq, Clone)]
struct PeerRpcClientCtxKey(uuid::Uuid, PeerRpcServiceId);
// handles rpc requests coming from peers
/// Dispatches tarpc-based rpc traffic between local services and remote
/// peers over a `PeerRpcManagerTransport`.
pub struct PeerRpcManager {
    // service id -> inbound packet channel of that service.
    service_map: Arc<DashMap<PeerRpcServiceId, PacketSender>>,
    tasks: JoinSet<()>,
    tspt: Arc<Box<dyn PeerRpcManagerTransport>>,
    // service id -> factory that builds a server endpoint per remote peer.
    service_registry: Arc<DashMap<PeerRpcServiceId, PeerRpcEndPointCreator>>,
    // Live server endpoints, keyed by (remote peer, service).
    peer_rpc_endpoints: Arc<DashMap<(uuid::Uuid, PeerRpcServiceId), PeerRpcEndPoint>>,
    // Pending client calls waiting for responses, keyed by (peer, service).
    client_resp_receivers: Arc<DashMap<PeerRpcClientCtxKey, PacketSender>>,
}
impl std::fmt::Debug for PeerRpcManager {
    /// Manual Debug showing only the local node id; the maps and transport
    /// are not Debug.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PeerRpcManager")
            .field("node_id", &self.tspt.my_peer_id())
            .finish()
    }
}
/// Decoded header of a TaRpc packet: routing ids, which service it targets,
/// request vs. response, and the serialized payload.
#[derive(Debug)]
struct TaRpcPacketInfo {
    from_peer: uuid::Uuid,
    to_peer: uuid::Uuid,
    service_id: PeerRpcServiceId,
    // True for a request; false for a response.
    is_req: bool,
    // bincode-serialized tarpc message body.
    content: Vec<u8>,
}
impl PeerRpcManager {
pub fn new(tspt: impl PeerRpcManagerTransport) -> Self {
Self {
service_map: Arc::new(DashMap::new()),
tasks: JoinSet::new(),
tspt: Arc::new(Box::new(tspt)),
service_registry: Arc::new(DashMap::new()),
peer_rpc_endpoints: Arc::new(DashMap::new()),
client_resp_receivers: Arc::new(DashMap::new()),
}
}
pub fn run_service<S, Req>(self: &Self, service_id: PeerRpcServiceId, s: S) -> ()
where
S: tarpc::server::Serve<Req> + Clone + Send + Sync + 'static,
Req: Send + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
S::Resp:
Send + std::fmt::Debug + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
S::Fut: Send + 'static,
{
let tspt = self.tspt.clone();
let creator = Box::new(move |peer_id: uuid::Uuid| {
let mut tasks = JoinSet::new();
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = tarpc::server::BaseChannel::with_defaults(server_transport);
let my_peer_id_clone = tspt.my_peer_id();
let peer_id_clone = peer_id.clone();
let o = server.execute(s.clone());
tasks.spawn(o);
let tspt = tspt.clone();
tasks.spawn(async move {
let mut cur_req_uuid = None;
loop {
tokio::select! {
Some(resp) = client_transport.next() => {
tracing::trace!(resp = ?resp, "recv packet from client");
if resp.is_err() {
tracing::warn!(err = ?resp.err(),
"[PEER RPC MGR] client_transport in server side got channel error, ignore it.");
continue;
}
let resp = resp.unwrap();
if cur_req_uuid.is_none() {
tracing::error!("[PEER RPC MGR] cur_req_uuid is none, ignore this resp");
continue;
}
let serialized_resp = bincode::serialize(&resp);
if serialized_resp.is_err() {
tracing::error!(error = ?serialized_resp.err(), "serialize resp failed");
continue;
}
let msg = Packet::new_tarpc_packet(
tspt.my_peer_id(),
cur_req_uuid.take().unwrap(),
service_id,
false,
serialized_resp.unwrap(),
);
if let Err(e) = tspt.send(msg.into(), &peer_id).await {
tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed");
}
}
Some(packet) = packet_receiver.recv() => {
let info = Self::parse_rpc_packet(&packet);
if let Err(e) = info {
tracing::error!(error = ?e, packet = ?packet, "parse rpc packet failed");
continue;
}
let info = info.unwrap();
assert_eq!(info.service_id, service_id);
cur_req_uuid = Some(packet.from_peer.clone().into());
tracing::trace!("recv packet from peer, packet: {:?}", packet);
let decoded_ret = bincode::deserialize(&info.content.as_slice());
if let Err(e) = decoded_ret {
tracing::error!(error = ?e, "decode rpc packet failed");
continue;
}
let decoded: tarpc::ClientMessage<Req> = decoded_ret.unwrap();
if let Err(e) = client_transport.send(decoded).await {
tracing::error!(error = ?e, "send to req to client transport failed");
}
}
else => {
tracing::warn!("[PEER RPC MGR] service runner destroy, peer_id: {}, service_id: {}", peer_id, service_id);
}
}
}
}.instrument(tracing::info_span!("service_runner", my_id = ?my_peer_id_clone, peer_id = ?peer_id_clone, service_id = ?service_id)));
tracing::info!(
"[PEER RPC MGR] create new service endpoint for peer {}, service {}",
peer_id,
service_id
);
return PeerRpcEndPoint {
peer_id,
packet_sender,
tasks,
};
// let resp = client_transport.next().await;
});
if let Some(_) = self.service_registry.insert(service_id, creator) {
panic!(
"[PEER RPC MGR] service {} is already registered",
service_id
);
}
log::info!(
"[PEER RPC MGR] register service {} succeed, my_node_id {}",
service_id,
self.tspt.my_peer_id()
)
}
fn parse_rpc_packet(packet: &Packet) -> Result<TaRpcPacketInfo, Error> {
match &packet.body {
PacketBody::Ctrl(CtrlPacketBody::TaRpc(id, is_req, body)) => Ok(TaRpcPacketInfo {
from_peer: packet.from_peer.clone().into(),
to_peer: packet.to_peer.clone().unwrap().into(),
service_id: *id,
is_req: *is_req,
content: body.clone(),
}),
_ => Err(Error::ShellCommandError("invalid packet".to_owned())),
}
}
pub fn run(&self) {
let tspt = self.tspt.clone();
let service_registry = self.service_registry.clone();
let peer_rpc_endpoints = self.peer_rpc_endpoints.clone();
let client_resp_receivers = self.client_resp_receivers.clone();
tokio::spawn(async move {
loop {
let o = tspt.recv().await.unwrap();
let packet = Packet::decode(&o);
let packet: Packet = packet.deserialize(&mut rkyv::Infallible).unwrap();
let info = Self::parse_rpc_packet(&packet).unwrap();
if info.is_req {
if !service_registry.contains_key(&info.service_id) {
log::warn!(
"service {} not found, my_node_id: {}",
info.service_id,
tspt.my_peer_id()
);
continue;
}
let endpoint = peer_rpc_endpoints
.entry((info.to_peer, info.service_id))
.or_insert_with(|| {
service_registry.get(&info.service_id).unwrap()(info.from_peer)
});
endpoint.packet_sender.send(packet).unwrap();
} else {
if let Some(a) = client_resp_receivers
.get(&PeerRpcClientCtxKey(info.from_peer, info.service_id))
{
log::trace!("recv resp: {:?}", packet);
if let Err(e) = a.send(packet) {
tracing::error!(error = ?e, "send resp to client failed");
}
} else {
log::warn!("client resp receiver not found, info: {:?}", info);
}
}
}
});
}
#[tracing::instrument(skip(f))]
pub async fn do_client_rpc_scoped<CM, Req, RpcRet, Fut>(
&self,
service_id: PeerRpcServiceId,
dst_peer_id: uuid::Uuid,
f: impl FnOnce(UnboundedChannel<CM, Req>) -> Fut,
) -> RpcRet
where
CM: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
Req: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
Fut: std::future::Future<Output = RpcRet>,
{
let mut tasks = JoinSet::new();
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
let (client_transport, server_transport) =
tarpc::transport::channel::unbounded::<CM, Req>();
let (mut server_s, mut server_r) = server_transport.split();
let tspt = self.tspt.clone();
tasks.spawn(async move {
while let Some(a) = server_r.next().await {
if a.is_err() {
tracing::error!(error = ?a.err(), "channel error");
continue;
}
let a = bincode::serialize(&a.unwrap());
if a.is_err() {
tracing::error!(error = ?a.err(), "bincode serialize failed");
continue;
}
let a = Packet::new_tarpc_packet(
tspt.my_peer_id(),
dst_peer_id,
service_id,
true,
a.unwrap(),
);
if let Err(e) = tspt.send(a.into(), &dst_peer_id).await {
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
}
}
tracing::warn!("[PEER RPC MGR] server trasport read aborted");
});
tasks.spawn(async move {
while let Some(packet) = packet_receiver.recv().await {
tracing::trace!("tunnel recv: {:?}", packet);
let info = PeerRpcManager::parse_rpc_packet(&packet);
if let Err(e) = info {
tracing::error!(error = ?e, "parse rpc packet failed");
continue;
}
let decoded = bincode::deserialize(&info.unwrap().content.as_slice());
if let Err(e) = decoded {
tracing::error!(error = ?e, "decode rpc packet failed");
continue;
}
if let Err(e) = server_s.send(decoded.unwrap()).await {
tracing::error!(error = ?e, "send to rpc server channel failed");
}
}
tracing::warn!("[PEER RPC MGR] server packet read aborted");
});
let _insert_ret = self
.client_resp_receivers
.insert(PeerRpcClientCtxKey(dst_peer_id, service_id), packet_sender);
f(client_transport).await
}
pub fn my_peer_id(&self) -> uuid::Uuid {
self.tspt.my_peer_id()
}
}
#[cfg(test)]
mod tests {
    use futures::{SinkExt, StreamExt};
    use tokio_util::bytes::Bytes;

    use crate::{
        common::error::Error,
        peers::{
            peer_rpc::PeerRpcManager,
            tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
        },
        tunnels::{self, ring_tunnel::create_ring_tunnel_pair},
    };

    use super::PeerRpcManagerTransport;

    // minimal tarpc service shared by all tests below
    #[tarpc::service]
    pub trait TestRpcService {
        async fn hello(s: String) -> String;
    }

    #[derive(Clone)]
    struct MockService {
        prefix: String,
    }

    #[tarpc::server]
    impl TestRpcService for MockService {
        async fn hello(self, _: tarpc::context::Context, s: String) -> String {
            format!("{} {}", self.prefix, s)
        }
    }

    // rpc managers wired directly through an in-memory ring tunnel,
    // bypassing the peer manager entirely
    #[tokio::test]
    async fn peer_rpc_basic_test() {
        struct MockTransport {
            tunnel: Box<dyn tunnels::Tunnel>,
            my_peer_id: uuid::Uuid,
        }

        #[async_trait::async_trait]
        impl PeerRpcManagerTransport for MockTransport {
            fn my_peer_id(&self) -> uuid::Uuid {
                self.my_peer_id
            }

            // the ring tunnel is point-to-point, so the destination id is ignored
            async fn send(&self, msg: Bytes, _dst_peer_id: &uuid::Uuid) -> Result<(), Error> {
                println!("rpc mgr send: {:?}", msg);
                self.tunnel.pin_sink().send(msg).await.unwrap();
                Ok(())
            }

            async fn recv(&self) -> Result<Bytes, Error> {
                let ret = self.tunnel.pin_stream().next().await.unwrap();
                println!("rpc mgr recv: {:?}", ret);
                return ret.map(|v| v.freeze()).map_err(|_| Error::Unknown);
            }
        }

        let (ct, st) = create_ring_tunnel_pair();

        let server_rpc_mgr = PeerRpcManager::new(MockTransport {
            tunnel: st,
            my_peer_id: uuid::Uuid::new_v4(),
        });
        server_rpc_mgr.run();
        let s = MockService {
            prefix: "hello".to_owned(),
        };
        server_rpc_mgr.run_service(1, s.serve());

        let client_rpc_mgr = PeerRpcManager::new(MockTransport {
            tunnel: ct,
            my_peer_id: uuid::Uuid::new_v4(),
        });
        client_rpc_mgr.run();

        let ret = client_rpc_mgr
            .do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async {
                let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
                let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
                ret
            })
            .await;

        println!("ret: {:?}", ret);
        assert_eq!(ret.unwrap(), "hello abc");
    }

    // rpc routed through two connected peer managers
    #[tokio::test]
    async fn test_rpc_with_peer_manager() {
        let peer_mgr_a = create_mock_peer_manager().await;
        let peer_mgr_b = create_mock_peer_manager().await;
        connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
        wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.my_node_id())
            .await
            .unwrap();

        assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
        assert_eq!(
            peer_mgr_a.get_peer_map().list_peers().await[0],
            peer_mgr_b.my_node_id()
        );

        let s = MockService {
            prefix: "hello".to_owned(),
        };
        peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());

        let ip_list = peer_mgr_a
            .get_peer_rpc_mgr()
            .do_client_rpc_scoped(1, peer_mgr_b.my_node_id(), |c| async {
                let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
                let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
                ret
            })
            .await;

        println!("ip_list: {:?}", ip_list);
        assert_eq!(ip_list.as_ref().unwrap(), "hello abc");
    }

    // two distinct services on the same peer must not interfere
    #[tokio::test]
    async fn test_multi_service_with_peer_manager() {
        let peer_mgr_a = create_mock_peer_manager().await;
        let peer_mgr_b = create_mock_peer_manager().await;
        connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
        wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.my_node_id())
            .await
            .unwrap();

        assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
        assert_eq!(
            peer_mgr_a.get_peer_map().list_peers().await[0],
            peer_mgr_b.my_node_id()
        );

        let s = MockService {
            prefix: "hello_a".to_owned(),
        };
        peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());

        let b = MockService {
            prefix: "hello_b".to_owned(),
        };
        peer_mgr_b.get_peer_rpc_mgr().run_service(2, b.serve());

        let ip_list = peer_mgr_a
            .get_peer_rpc_mgr()
            .do_client_rpc_scoped(1, peer_mgr_b.my_node_id(), |c| async {
                let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
                let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
                ret
            })
            .await;

        assert_eq!(ip_list.as_ref().unwrap(), "hello_a abc");

        let ip_list = peer_mgr_a
            .get_peer_rpc_mgr()
            .do_client_rpc_scoped(2, peer_mgr_b.my_node_id(), |c| async {
                let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
                let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
                ret
            })
            .await;

        assert_eq!(ip_list.as_ref().unwrap(), "hello_b abc");
    }
}
+480
View File
@@ -0,0 +1,480 @@
use std::{net::Ipv4Addr, sync::Arc, time::Duration};
use async_trait::async_trait;
use dashmap::DashMap;
use easytier_rpc::{NatType, StunInfo};
use rkyv::{Archive, Deserialize, Serialize};
use tokio::{sync::Mutex, task::JoinSet};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use uuid::Uuid;
use crate::{
common::{
error::Error,
global_ctx::ArcGlobalCtx,
rkyv_util::{decode_from_bytes, encode_to_bytes},
stun::StunInfoCollectorTrait,
},
peers::{
packet::{self, UUID},
route_trait::{Route, RouteInterfaceBox},
PeerId,
},
};
/// One route-table entry exchanged between peers in sync packets.
#[derive(Archive, Deserialize, Serialize, Clone, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct SyncPeerInfo {
    // means next hop in route table.
    pub peer_id: UUID,
    // hop count from the local node (0 = self)
    pub cost: u32,
    // virtual ipv4 assigned to the peer, if any
    pub ipv4_addr: Option<Ipv4Addr>,
    // cidrs the peer proxies, as parseable strings
    pub proxy_cidrs: Vec<String>,
    pub hostname: Option<String>,
    // NatType of the peer's udp stun probe, stored as its numeric value
    pub udp_stun_info: i8,
}
impl SyncPeerInfo {
pub fn new_self(from_peer: UUID, global_ctx: &ArcGlobalCtx) -> Self {
SyncPeerInfo {
peer_id: from_peer,
cost: 0,
ipv4_addr: global_ctx.get_ipv4(),
proxy_cidrs: global_ctx
.get_proxy_cidrs()
.iter()
.map(|x| x.to_string())
.collect(),
hostname: global_ctx.get_hostname(),
udp_stun_info: global_ctx
.get_stun_info_collector()
.get_stun_info()
.udp_nat_type as i8,
}
}
pub fn clone_for_route_table(&self, next_hop: &UUID, cost: u32, from: &Self) -> Self {
SyncPeerInfo {
peer_id: next_hop.clone(),
cost,
ipv4_addr: from.ipv4_addr.clone(),
proxy_cidrs: from.proxy_cidrs.clone(),
hostname: from.hostname.clone(),
udp_stun_info: from.udp_stun_info,
}
}
}
/// One route-sync message: the sender's own info plus every destination it
/// currently knows a route to.
#[derive(Archive, Deserialize, Serialize, Clone, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct SyncPeer {
    pub myself: SyncPeerInfo,
    pub neighbors: Vec<SyncPeerInfo>,
}
impl SyncPeer {
    /// Assemble a sync message: the local node's own info plus the supplied
    /// neighbor list. `_to_peer` is currently unused.
    pub fn new(
        from_peer: UUID,
        _to_peer: UUID,
        neighbors: Vec<SyncPeerInfo>,
        global_ctx: ArcGlobalCtx,
    ) -> Self {
        let myself = SyncPeerInfo::new_self(from_peer, &global_ctx);
        Self { myself, neighbors }
    }
}
// Latest sync packet received from one peer plus its arrival time
// (used for expiry — see `check_expired_sync_peer_from_remote`).
struct SyncPeerFromRemote {
    packet: SyncPeer,
    last_update: std::time::Instant,
}

// peer id -> latest sync packet from that peer
type SyncPeerFromRemoteMap = Arc<DashMap<uuid::Uuid, SyncPeerFromRemote>>;
// Derived routing state, rebuilt from the peers' sync packets:
// destination -> best entry, plus reverse lookups by ipv4 and proxied cidr.
#[derive(Clone, Debug)]
struct RouteTable {
    route_info: DashMap<uuid::Uuid, SyncPeerInfo>,
    ipv4_peer_id_map: DashMap<Ipv4Addr, uuid::Uuid>,
    cidr_peer_id_map: DashMap<cidr::IpCidr, uuid::Uuid>,
}
impl RouteTable {
    /// Empty route table.
    fn new() -> Self {
        Self {
            route_info: DashMap::new(),
            ipv4_peer_id_map: DashMap::new(),
            cidr_peer_id_map: DashMap::new(),
        }
    }

    /// Replace this table's contents with a copy of `other`'s.
    ///
    /// Each map is cleared and refilled in turn, so concurrent readers may
    /// briefly observe a partially-populated table.
    fn copy_from(&self, other: &Self) {
        self.route_info.clear();
        for item in other.route_info.iter() {
            self.route_info.insert(*item.key(), item.value().clone());
        }

        self.ipv4_peer_id_map.clear();
        for item in other.ipv4_peer_id_map.iter() {
            self.ipv4_peer_id_map.insert(*item.key(), *item.value());
        }

        self.cidr_peer_id_map.clear();
        for item in other.cidr_peer_id_map.iter() {
            self.cidr_peer_id_map.insert(*item.key(), *item.value());
        }
    }
}
/// Simple distance-vector ("RIP-like") route implementation: peers
/// periodically exchange their full route tables and the lowest-cost entry
/// per destination wins.
pub struct BasicRoute {
    my_peer_id: packet::UUID,
    global_ctx: ArcGlobalCtx,
    // set by `open`; used by the periodic sync task
    interface: Arc<Mutex<Option<RouteInterfaceBox>>>,
    route_table: Arc<RouteTable>,
    // latest sync packet per remote peer, with arrival timestamps
    sync_peer_from_remote: SyncPeerFromRemoteMap,
    tasks: Mutex<JoinSet<()>>,
    // woken to trigger an immediate re-sync after a route change
    need_sync_notifier: Arc<tokio::sync::Notify>,
}
impl BasicRoute {
pub fn new(my_peer_id: Uuid, global_ctx: ArcGlobalCtx) -> Self {
BasicRoute {
my_peer_id: my_peer_id.into(),
global_ctx,
interface: Arc::new(Mutex::new(None)),
route_table: Arc::new(RouteTable::new()),
sync_peer_from_remote: Arc::new(DashMap::new()),
tasks: Mutex::new(JoinSet::new()),
need_sync_notifier: Arc::new(tokio::sync::Notify::new()),
}
}
fn update_route_table(
my_id: packet::UUID,
sync_peer_reqs: SyncPeerFromRemoteMap,
route_table: Arc<RouteTable>,
) {
tracing::trace!(my_id = ?my_id, route_table = ?route_table, "update route table");
let new_route_table = Arc::new(RouteTable::new());
for item in sync_peer_reqs.iter() {
Self::update_route_table_with_req(
my_id.clone(),
&item.value().packet,
new_route_table.clone(),
);
}
route_table.copy_from(&new_route_table);
}
fn update_route_table_with_req(
my_id: packet::UUID,
packet: &SyncPeer,
route_table: Arc<RouteTable>,
) {
let peer_id = packet.myself.peer_id.clone();
let update = |cost: u32, peer_info: &SyncPeerInfo| {
let node_id: uuid::Uuid = peer_info.peer_id.clone().into();
let ret = route_table
.route_info
.entry(node_id.clone().into())
.and_modify(|info| {
if info.cost > cost {
*info = info.clone_for_route_table(&peer_id, cost, &peer_info);
}
})
.or_insert(
peer_info
.clone()
.clone_for_route_table(&peer_id, cost, &peer_info),
)
.value()
.clone();
if ret.cost > 32 {
log::error!(
"cost too large: {}, may lost connection, remove it",
ret.cost
);
route_table.route_info.remove(&node_id);
}
log::trace!(
"update route info, to: {:?}, gateway: {:?}, cost: {}, peer: {:?}",
node_id,
peer_id,
cost,
&peer_info
);
if let Some(ipv4) = peer_info.ipv4_addr {
route_table
.ipv4_peer_id_map
.insert(ipv4.clone(), node_id.clone().into());
}
for cidr in peer_info.proxy_cidrs.iter() {
let cidr: cidr::IpCidr = cidr.parse().unwrap();
route_table
.cidr_peer_id_map
.insert(cidr, node_id.clone().into());
}
};
for neighbor in packet.neighbors.iter() {
if neighbor.peer_id == my_id {
continue;
}
update(neighbor.cost + 1, &neighbor);
log::trace!("route info: {:?}", neighbor);
}
// add the sender peer to route info
update(1, &packet.myself);
log::trace!("my_id: {:?}, current route table: {:?}", my_id, route_table);
}
async fn send_sync_peer_request(
interface: &RouteInterfaceBox,
my_peer_id: packet::UUID,
global_ctx: ArcGlobalCtx,
peer_id: PeerId,
route_table: Arc<RouteTable>,
) -> Result<(), Error> {
let mut route_info_copy: Vec<SyncPeerInfo> = Vec::new();
// copy the route info
for item in route_table.route_info.iter() {
let (k, v) = item.pair();
route_info_copy.push(v.clone().clone_for_route_table(&(*k).into(), v.cost, &v));
}
let msg = SyncPeer::new(my_peer_id, peer_id.into(), route_info_copy, global_ctx);
// TODO: this may exceed the MTU of the tunnel
interface
.send_route_packet(encode_to_bytes::<_, 4096>(&msg), 1, &peer_id)
.await
}
async fn sync_peer_periodically(&self) {
let route_table = self.route_table.clone();
let global_ctx = self.global_ctx.clone();
let my_peer_id = self.my_peer_id.clone();
let interface = self.interface.clone();
let notifier = self.need_sync_notifier.clone();
self.tasks.lock().await.spawn(
async move {
loop {
let lockd_interface = interface.lock().await;
let interface = lockd_interface.as_ref().unwrap();
let peers = interface.list_peers().await;
for peer in peers.iter() {
let ret = Self::send_sync_peer_request(
interface,
my_peer_id.clone(),
global_ctx.clone(),
*peer,
route_table.clone(),
)
.await;
match &ret {
Ok(_) => {
log::trace!("send sync peer request to peer: {}", peer);
}
Err(Error::PeerNoConnectionError(_)) => {
log::trace!("peer {} no connection", peer);
}
Err(e) => {
log::error!(
"send sync peer request to peer: {} error: {:?}",
peer,
e
);
}
};
}
tokio::select! {
_ = notifier.notified() => {
log::trace!("sync peer request triggered by notifier");
}
_ = tokio::time::sleep(Duration::from_secs(1)) => {
log::trace!("sync peer request triggered by timeout");
}
}
}
}
.instrument(
tracing::info_span!("sync_peer_periodically", my_id = ?self.my_peer_id, global_ctx = ?self.global_ctx),
),
);
}
async fn check_expired_sync_peer_from_remote(&self) {
let route_table = self.route_table.clone();
let my_peer_id = self.my_peer_id.clone();
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
let notifier = self.need_sync_notifier.clone();
self.tasks.lock().await.spawn(async move {
loop {
let mut need_update_route = false;
let now = std::time::Instant::now();
let mut need_remove = Vec::new();
for item in sync_peer_from_remote.iter() {
let (k, v) = item.pair();
if now.duration_since(v.last_update).as_secs() > 5 {
need_update_route = true;
need_remove.insert(0, k.clone());
}
}
for k in need_remove.iter() {
log::warn!("remove expired sync peer: {:?}", k);
sync_peer_from_remote.remove(k);
}
if need_update_route {
Self::update_route_table(
my_peer_id.clone(),
sync_peer_from_remote.clone(),
route_table.clone(),
);
notifier.notify_one();
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
});
}
fn get_peer_id_for_proxy(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
let ipv4 = std::net::IpAddr::V4(*ipv4);
for item in self.route_table.cidr_peer_id_map.iter() {
let (k, v) = item.pair();
if k.contains(&ipv4) {
return Some(*v);
}
}
None
}
}
#[async_trait]
impl Route for BasicRoute {
    /// Store the interface and start the periodic sync / expiry tasks.
    async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()> {
        *self.interface.lock().await = Some(interface);
        self.sync_peer_periodically().await;
        self.check_expired_sync_peer_from_remote().await;
        Ok(1)
    }

    async fn close(&self) {}

    /// Ingest one sync packet from `src_peer_id`; rebuild the route table and
    /// trigger an immediate re-sync only when the peer's info changed.
    #[tracing::instrument(skip(self, packet), fields(my_id = ?self.my_peer_id, ctx = ?self.global_ctx))]
    async fn handle_route_packet(&self, src_peer_id: uuid::Uuid, packet: Bytes) {
        // NOTE(review): the unwraps and the assert below panic on a
        // malformed or mis-attributed packet — confirm peers are trusted.
        let packet = decode_from_bytes::<SyncPeer>(&packet).unwrap();
        let p: SyncPeer = packet.deserialize(&mut rkyv::Infallible).unwrap();
        let mut updated = true;
        assert_eq!(packet.myself.peer_id.to_uuid(), src_peer_id);
        self.sync_peer_from_remote
            .entry(packet.myself.peer_id.to_uuid())
            .and_modify(|v| {
                // compare against the archived packet to skip no-op updates;
                // always refresh the timestamp so the entry doesn't expire
                if v.packet == *packet {
                    updated = false;
                } else {
                    v.packet = p.clone();
                }
                v.last_update = std::time::Instant::now();
            })
            .or_insert(SyncPeerFromRemote {
                packet: p.clone(),
                last_update: std::time::Instant::now(),
            });

        if updated {
            Self::update_route_table(
                self.my_peer_id.clone(),
                self.sync_peer_from_remote.clone(),
                self.route_table.clone(),
            );
            self.need_sync_notifier.notify_one();
        }
    }

    /// Map a virtual ipv4 to its owning peer: direct assignment first, then
    /// proxied cidrs.
    async fn get_peer_id_by_ipv4(&self, ipv4_addr: &Ipv4Addr) -> Option<PeerId> {
        if let Some(peer_id) = self.route_table.ipv4_peer_id_map.get(ipv4_addr) {
            return Some(*peer_id);
        }

        if let Some(peer_id) = self.get_peer_id_for_proxy(ipv4_addr) {
            return Some(peer_id);
        }

        log::info!("no peer id for ipv4: {}", ipv4_addr);
        return None;
    }

    /// Next hop toward `dst_peer_id` per the current route table.
    async fn get_next_hop(&self, dst_peer_id: &PeerId) -> Option<PeerId> {
        match self.route_table.route_info.get(dst_peer_id) {
            Some(info) => {
                return Some(info.peer_id.clone().into());
            }
            None => {
                log::error!("no route info for dst_peer_id: {}", dst_peer_id);
                return None;
            }
        }
    }

    /// Export the route table in the rpc representation; missing optional
    /// fields become empty strings.
    async fn list_routes(&self) -> Vec<easytier_rpc::Route> {
        let mut routes = Vec::new();

        let parse_route_info = |real_peer_id: &Uuid, route_info: &SyncPeerInfo| {
            let mut route = easytier_rpc::Route::default();
            route.ipv4_addr = if let Some(ipv4_addr) = route_info.ipv4_addr {
                ipv4_addr.to_string()
            } else {
                "".to_string()
            };
            route.peer_id = real_peer_id.to_string();
            route.next_hop_peer_id = Uuid::from(route_info.peer_id.clone()).to_string();
            route.cost = route_info.cost as i32;
            route.proxy_cidrs = route_info.proxy_cidrs.clone();
            route.hostname = if let Some(hostname) = &route_info.hostname {
                hostname.clone()
            } else {
                "".to_string()
            };

            // an out-of-range stun value leaves the default NatType in place
            let mut stun_info = StunInfo::default();
            if let Ok(udp_nat_type) = NatType::try_from(route_info.udp_stun_info as i32) {
                stun_info.set_udp_nat_type(udp_nat_type);
            }
            route.stun_info = Some(stun_info);

            route
        };

        self.route_table.route_info.iter().for_each(|item| {
            routes.push(parse_route_info(item.key(), item.value()));
        });

        routes
    }
}
+36
View File
@@ -0,0 +1,36 @@
use std::{net::Ipv4Addr, sync::Arc};
use async_trait::async_trait;
use tokio_util::bytes::Bytes;
use crate::common::error::Error;
use super::PeerId;
/// Callbacks a routing implementation uses to talk to the peer manager.
#[async_trait]
pub trait RouteInterface {
    /// Directly connected peers.
    async fn list_peers(&self) -> Vec<PeerId>;
    /// Send a route-protocol packet to `dst_peer_id` on behalf of route
    /// `route_id`.
    async fn send_route_packet(
        &self,
        msg: Bytes,
        route_id: u8,
        dst_peer_id: &PeerId,
    ) -> Result<(), Error>;
}

pub type RouteInterfaceBox = Box<dyn RouteInterface + Send + Sync>;
/// A routing algorithm: fed route packets by the peer manager, answers
/// next-hop and address-resolution queries.
#[async_trait]
pub trait Route {
    /// Start the route with the given interface. The meaning of the returned
    /// `u8` is implementation-defined (BasicRoute returns 1).
    async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()>;
    async fn close(&self);
    /// Resolve a virtual ipv4 to the peer owning (or proxying) it.
    async fn get_peer_id_by_ipv4(&self, ipv4: &Ipv4Addr) -> Option<PeerId>;
    /// Next hop toward `peer_id`, if a route exists.
    async fn get_next_hop(&self, peer_id: &PeerId) -> Option<PeerId>;
    /// Handle one incoming route-protocol packet from `src_peer_id`.
    async fn handle_route_packet(&self, src_peer_id: PeerId, packet: Bytes);
    /// Current routes in the rpc representation.
    async fn list_routes(&self) -> Vec<easytier_rpc::Route>;
}

pub type ArcRoute = Arc<Box<dyn Route + Send + Sync>>;
+55
View File
@@ -0,0 +1,55 @@
use std::sync::Arc;
use easytier_rpc::peer_manage_rpc_server::PeerManageRpc;
use easytier_rpc::{ListPeerRequest, ListPeerResponse, ListRouteRequest, ListRouteResponse};
use tonic::{Request, Response, Status};
use super::peer_manager::PeerManager;
/// grpc service exposing peer-manager state (peers, routes) to the cli.
pub struct PeerManagerRpcService {
    peer_manager: Arc<PeerManager>,
}
impl PeerManagerRpcService {
pub fn new(peer_manager: Arc<PeerManager>) -> Self {
PeerManagerRpcService { peer_manager }
}
}
#[tonic::async_trait]
impl PeerManageRpc for PeerManagerRpcService {
    /// List every known peer together with its active connections.
    async fn list_peer(
        &self,
        _request: Request<ListPeerRequest>, // unused; listing takes no parameters
    ) -> Result<Response<ListPeerResponse>, Status> {
        let mut reply = ListPeerResponse::default();

        let peers = self.peer_manager.get_peer_map().list_peers().await;
        for peer in peers {
            let mut peer_info = easytier_rpc::PeerInfo::default();
            peer_info.peer_id = peer.to_string();

            // attach the connection list when the peer still has live conns
            if let Some(conns) = self
                .peer_manager
                .get_peer_map()
                .list_peer_conns(&peer)
                .await
            {
                peer_info.conns = conns;
            }

            reply.peer_infos.push(peer_info);
        }

        Ok(Response::new(reply))
    }

    /// Dump the current route table.
    async fn list_route(
        &self,
        _request: Request<ListRouteRequest>, // unused; listing takes no parameters
    ) -> Result<Response<ListRouteResponse>, Status> {
        let mut reply = ListRouteResponse::default();
        reply.routes = self.peer_manager.list_routes().await;
        Ok(Response::new(reply))
    }
}
+60
View File
@@ -0,0 +1,60 @@
use std::sync::Arc;
use crate::{
common::{error::Error, global_ctx::tests::get_mock_global_ctx},
peers::rip_route::BasicRoute,
tunnels::ring_tunnel::create_ring_tunnel_pair,
};
use super::peer_manager::PeerManager;
/// Build a running `PeerManager` for unit tests: mock global ctx, a
/// `BasicRoute`, and a nic-packet channel whose receiver is dropped.
pub async fn create_mock_peer_manager() -> Arc<PeerManager> {
    // receiver is dropped: packets written to the nic channel are discarded
    let (s, _r) = tokio::sync::mpsc::channel(1000);
    let peer_mgr = Arc::new(PeerManager::new(get_mock_global_ctx(), s));
    peer_mgr
        .set_route(BasicRoute::new(
            peer_mgr.my_node_id(),
            peer_mgr.get_global_ctx(),
        ))
        .await;
    peer_mgr.run().await.unwrap();
    peer_mgr
}
/// Connect two peer managers through an in-memory ring tunnel, `client`
/// acting as the connecting side and `server` as the accepting side.
/// The spawned task handles are intentionally detached.
pub async fn connect_peer_manager(client: Arc<PeerManager>, server: Arc<PeerManager>) {
    let (a_ring, b_ring) = create_ring_tunnel_pair();
    let a_mgr_copy = client.clone();
    tokio::spawn(async move {
        a_mgr_copy.add_client_tunnel(a_ring).await.unwrap();
    });
    let b_mgr_copy = server.clone();
    tokio::spawn(async move {
        b_mgr_copy.add_tunnel_as_server(b_ring).await.unwrap();
    });
}
pub async fn wait_route_appear_with_cost(
peer_mgr: Arc<PeerManager>,
node_id: uuid::Uuid,
cost: Option<i32>,
) -> Result<(), Error> {
let now = std::time::Instant::now();
while now.elapsed().as_secs() < 5 {
let route = peer_mgr.list_routes().await;
if route.iter().any(|r| {
r.peer_id.clone().parse::<uuid::Uuid>().unwrap() == node_id
&& (cost.is_none() || r.cost == cost.unwrap())
}) {
return Ok(());
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
return Err(Error::NotFound);
}
/// Convenience wrapper: wait for a route to `node_id` with any cost.
pub async fn wait_route_appear(
    peer_mgr: Arc<PeerManager>,
    node_id: uuid::Uuid,
) -> Result<(), Error> {
    wait_route_appear_with_cost(peer_mgr, node_id, None).await
}
+1
View File
@@ -0,0 +1 @@
tonic::include_proto!("cli"); // The string specified here must match the proto package name
+4
View File
@@ -0,0 +1,4 @@
pub mod cli;
pub use cli::*;
pub mod peer;
+20
View File
@@ -0,0 +1,20 @@
use serde::{Deserialize, Serialize};
/// Collected addresses of a node: public and per-interface, v4 and v6.
/// Unknown addresses are represented by empty strings / empty vectors.
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Default)]
pub struct GetIpListResponse {
    pub public_ipv4: String,
    pub interface_ipv4s: Vec<String>,
    pub public_ipv6: String,
    pub interface_ipv6s: Vec<String>,
}

impl GetIpListResponse {
    /// Empty response — all fields default ("" / []); kept for callers that
    /// prefer `new()` over `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
+174
View File
@@ -0,0 +1,174 @@
mod three_node;
/// Name of the in-namespace ("guest") end of the veth pair for `net_ns`.
/// The formatted name is intentionally leaked to hand out a long-lived
/// `&str` (acceptable: test-only helper over a bounded set of namespaces).
pub fn get_guest_veth_name(net_ns: &str) -> &str {
    format!("veth_{}_g", net_ns).leak()
}
/// Name of the root-namespace ("host") end of the veth pair for `net_ns`.
/// The formatted name is intentionally leaked to hand out a long-lived
/// `&str` (acceptable: test-only helper over a bounded set of namespaces).
pub fn get_host_veth_name(net_ns: &str) -> &str {
    format!("veth_{}_h", net_ns).leak()
}
/// Best-effort teardown of a test namespace: remove the host-side veth, then
/// the namespace itself. Command failures are ignored (they may not exist).
pub fn del_netns(name: &str) {
    let run_ip = |args: &[&str]| {
        let _ = std::process::Command::new("ip").args(args).output();
    };
    // del veth host
    run_ip(&["link", "del", get_host_veth_name(name)]);
    run_ip(&["netns", "del", name]);
}
/// Create network namespace `name` wired to the root namespace via a veth
/// pair: the guest end is moved inside and assigned `ipv4` (CIDR notation),
/// the host end stays outside. Panics if an `ip` command cannot be spawned
/// (command *failures* are still ignored — only `output()` is unwrapped).
pub fn create_netns(name: &str, ipv4: &str) {
    // create netns
    let _ = std::process::Command::new("ip")
        .args(&["netns", "add", name])
        .output()
        .unwrap();

    // set lo up
    let _ = std::process::Command::new("ip")
        .args(&["netns", "exec", name, "ip", "link", "set", "lo", "up"])
        .output()
        .unwrap();

    // create the veth pair (host end + guest end)
    let _ = std::process::Command::new("ip")
        .args(&[
            "link",
            "add",
            get_host_veth_name(name),
            "type",
            "veth",
            "peer",
            "name",
            get_guest_veth_name(name),
        ])
        .output()
        .unwrap();

    // move the guest end into the namespace
    let _ = std::process::Command::new("ip")
        .args(&["link", "set", get_guest_veth_name(name), "netns", name])
        .output()
        .unwrap();

    // bring the guest end up (inside the namespace)
    let _ = std::process::Command::new("ip")
        .args(&[
            "netns",
            "exec",
            name,
            "ip",
            "link",
            "set",
            get_guest_veth_name(name),
            "up",
        ])
        .output()
        .unwrap();

    // bring the host end up
    let _ = std::process::Command::new("ip")
        .args(&["link", "set", get_host_veth_name(name), "up"])
        .output()
        .unwrap();

    // assign the address to the guest end
    let _ = std::process::Command::new("ip")
        .args(&[
            "netns",
            "exec",
            name,
            "ip",
            "addr",
            "add",
            ipv4,
            "dev",
            get_guest_veth_name(name),
        ])
        .output()
        .unwrap();
}
/// Recreate bridge `name`: delete any existing bridge of that name, then add
/// a fresh one. Both operations are best-effort (errors ignored).
pub fn prepare_bridge(name: &str) {
    for sub in ["delbr", "addbr"] {
        let _ = std::process::Command::new("brctl")
            .args(&[sub, name])
            .output();
    }
}
/// Attach `ns_name`'s host-side veth to bridge `br_name` and bring the
/// bridge up. Panics if a command cannot be spawned.
pub fn add_ns_to_bridge(br_name: &str, ns_name: &str) {
    // use brctl to add ns to bridge
    let veth = get_host_veth_name(ns_name);
    let _ = std::process::Command::new("brctl")
        .args(&["addif", br_name, veth])
        .output()
        .unwrap();

    // set bridge up
    let _ = std::process::Command::new("ip")
        .args(&["link", "set", br_name, "up"])
        .output()
        .unwrap();
}
/// Install a global pretty-printing tracing subscriber. The level defaults
/// to INFO and can be overridden via the `RUST_LOG` environment variable.
/// NOTE(review): `init` panics if a global subscriber was already set —
/// call at most once per process.
pub fn enable_log() {
    let filter = tracing_subscriber::EnvFilter::builder()
        .with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
        .from_env()
        .unwrap();
    tracing_subscriber::fmt::fmt()
        .pretty()
        .with_env_filter(filter)
        .init();
}
/// Assert that `routes` contains an entry for `ipv4` and that every such
/// entry belongs to `dst_peer_id` (test helper — panics otherwise).
fn check_route(ipv4: &str, dst_peer_id: uuid::Uuid, routes: Vec<easytier_rpc::Route>) {
    let mut found = false;
    for r in routes.iter() {
        // String == &str comparison avoids allocating a fresh String per route
        if r.ipv4_addr == ipv4 {
            found = true;
            assert_eq!(r.peer_id, dst_peer_id.to_string(), "{:?}", routes);
        }
    }
    assert!(found, "no route with ipv4 {} in {:?}", ipv4, routes);
}
/// Poll `mgr`'s routes until one advertises `proxy_cidr`, then assert it
/// belongs to `dst_peer_id` / `ipv4`. Panics after ~5 seconds without a match.
async fn wait_proxy_route_appear(
    mgr: &std::sync::Arc<crate::peers::peer_manager::PeerManager>,
    ipv4: &str,
    dst_peer_id: uuid::Uuid,
    proxy_cidr: &str,
) {
    let now = std::time::Instant::now();
    loop {
        // (removed a dead `let r = r;` rebinding; compare without allocating
        // a fresh String per route)
        for r in mgr.list_routes().await.iter() {
            if r.proxy_cidrs.iter().any(|c| c == proxy_cidr) {
                assert_eq!(r.peer_id, dst_peer_id.to_string());
                assert_eq!(r.ipv4_addr, ipv4);
                return;
            }
        }

        if now.elapsed().as_secs() > 5 {
            panic!("wait proxy route appear timeout");
        }

        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    }
}
/// Toggle the guest-side veth of `net_ns` up or down — used by tests to
/// simulate link failures. Panics if the `ip` command cannot be spawned.
fn set_link_status(net_ns: &str, up: bool) {
    let _ = std::process::Command::new("ip")
        .args(&[
            "netns",
            "exec",
            net_ns,
            "ip",
            "link",
            "set",
            get_guest_veth_name(net_ns),
            if up { "up" } else { "down" },
        ])
        .output()
        .unwrap();
}
+227
View File
@@ -0,0 +1,227 @@
use super::*;
use crate::{
common::netns::{NetNS, ROOT_NETNS_NAME},
instance::instance::{Instance, InstanceConfigWriter},
tunnels::{
common::tests::_tunnel_pingpong_netns,
ring_tunnel::RingTunnelConnector,
tcp_tunnel::{TcpTunnelConnector, TcpTunnelListener},
udp_tunnel::UdpTunnelConnector,
},
};
/// (Re)create the network topology for the three-node tests: namespaces
/// net_a/net_b on bridge br_a (10.1.1.0/24) and net_c/net_d on bridge br_b
/// (10.1.2.0/24).
pub fn prepare_linux_namespaces() {
    let nodes = [
        ("net_a", "10.1.1.1/24", "br_a"),
        ("net_b", "10.1.1.2/24", "br_a"),
        ("net_c", "10.1.2.3/24", "br_b"),
        ("net_d", "10.1.2.4/24", "br_b"),
    ];

    // same order as before: delete all, create all, bridges, then attach
    for (ns, _, _) in nodes.iter() {
        del_netns(ns);
    }
    for (ns, ip, _) in nodes.iter() {
        create_netns(ns, ip);
    }
    for br in ["br_a", "br_b"] {
        prepare_bridge(br);
    }
    for (ns, _, br) in nodes.iter() {
        add_ns_to_bridge(br, ns);
    }
}
/// Write configs for the three test instances, pinning each to its network
/// namespace and virtual address.
/// NOTE(review): this async fn contains no awaits — presumably
/// `InstanceConfigWriter` persists on construction or drop; confirm.
pub async fn prepare_inst_configs() {
    InstanceConfigWriter::new("inst1")
        .set_ns(Some("net_a".into()))
        .set_addr("10.144.144.1".to_owned());

    InstanceConfigWriter::new("inst2")
        .set_ns(Some("net_b".into()))
        .set_addr("10.144.144.2".to_owned());

    InstanceConfigWriter::new("inst3")
        .set_ns(Some("net_c".into()))
        .set_addr("10.144.144.3".to_owned());
}
/// Boot three instances (inst1/inst2/inst3 in net_a/net_b/net_c), connect
/// inst2 -> inst1 over `proto` ("tcp"; anything else means udp) and
/// inst2 -> inst3 over a ring tunnel, then wait until inst2 sees routes to
/// both. Returns [inst1, inst2, inst3].
pub async fn init_three_node(proto: &str) -> Vec<Instance> {
    log::set_max_level(log::LevelFilter::Info);
    prepare_linux_namespaces();
    prepare_inst_configs().await;

    let mut inst1 = Instance::new("inst1");
    let mut inst2 = Instance::new("inst2");
    let mut inst3 = Instance::new("inst3");

    inst1.run().await.unwrap();
    inst2.run().await.unwrap();
    inst3.run().await.unwrap();

    if proto == "tcp" {
        inst2
            .get_conn_manager()
            .add_connector(TcpTunnelConnector::new(
                "tcp://10.1.1.1:11010".parse().unwrap(),
            ));
    } else {
        inst2
            .get_conn_manager()
            .add_connector(UdpTunnelConnector::new(
                "udp://10.1.1.1:11010".parse().unwrap(),
            ));
    }

    inst2
        .get_conn_manager()
        .add_connector(RingTunnelConnector::new(
            format!("ring://{}", inst3.id()).parse().unwrap(),
        ));

    // wait inst2 have two route.
    let now = std::time::Instant::now();
    loop {
        if inst2.get_peer_manager().list_routes().await.len() == 2 {
            break;
        }
        if now.elapsed().as_secs() > 5 {
            panic!("wait inst2 have two route timeout");
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    }

    vec![inst1, inst2, inst3]
}
#[tokio::test]
#[serial_test::serial]
pub async fn basic_three_node_test_tcp() {
    // After TCP bootstrap, inst1 must see routes to inst2 and inst3.
    let instances = init_three_node("tcp").await;
    let peer_mgr = instances[0].get_peer_manager();
    check_route("10.144.144.2", instances[1].id(), peer_mgr.list_routes().await);
    check_route("10.144.144.3", instances[2].id(), peer_mgr.list_routes().await);
}
#[tokio::test]
#[serial_test::serial]
pub async fn basic_three_node_test_udp() {
    // Same route checks as the TCP variant, but bootstrapped over UDP.
    let instances = init_three_node("udp").await;
    let peer_mgr = instances[0].get_peer_manager();
    check_route("10.144.144.2", instances[1].id(), peer_mgr.list_routes().await);
    check_route("10.144.144.3", instances[2].id(), peer_mgr.list_routes().await);
}
#[tokio::test]
#[serial_test::serial]
pub async fn tcp_proxy_three_node_test() {
    // inst3 advertises net_c/net_d's physical subnet, so peers route
    // 10.1.2.0/24 traffic through it.
    let insts = init_three_node("tcp").await;
    insts[2]
        .get_global_ctx()
        .add_proxy_cidr("10.1.2.0/24".parse().unwrap())
        .unwrap();
    assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
    // Block (up to 5s) until inst1 has learned the proxy route via inst3.
    wait_proxy_route_appear(
        &insts[0].get_peer_manager(),
        "10.144.144.3",
        insts[2].id(),
        "10.1.2.0/24",
    )
    .await;
    // wait updater — give the periodic updater time to install the rules
    // before sending real traffic.
    tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
    // Real TCP endpoints live in net_d (server) and net_a (client); the
    // traffic crosses the overlay and exits through inst3's proxy.
    let tcp_listener = TcpTunnelListener::new("tcp://10.1.2.4:22223".parse().unwrap());
    let tcp_connector = TcpTunnelConnector::new("tcp://10.1.2.4:22223".parse().unwrap());
    _tunnel_pingpong_netns(
        tcp_listener,
        tcp_connector,
        NetNS::new(Some("net_d".into())),
        NetNS::new(Some("net_a".into())),
    )
    .await;
}
#[tokio::test]
#[serial_test::serial]
pub async fn icmp_proxy_three_node_test() {
    // Same setup as the TCP proxy test, but validated with ICMP.
    let insts = init_three_node("tcp").await;
    insts[2]
        .get_global_ctx()
        .add_proxy_cidr("10.1.2.0/24".parse().unwrap())
        .unwrap();
    assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
    wait_proxy_route_appear(
        &insts[0].get_peer_manager(),
        "10.144.144.3",
        insts[2].id(),
        "10.1.2.0/24",
    )
    .await;
    // wait updater — allow the proxy/route updater to converge first.
    tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
    // send ping with shell in net_a to net_d
    // Run from the root namespace so `ip netns exec` can enter net_a.
    let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard();
    let code = tokio::process::Command::new("ip")
        .args(&[
            "netns", "exec", "net_a", "ping", "-c", "1", "-W", "1", "10.1.2.4",
        ])
        .status()
        .await
        .unwrap();
    // Exit code 0 means the echo reply came back through the proxy.
    assert_eq!(code.code().unwrap(), 0);
}
#[tokio::test]
#[serial_test::serial]
pub async fn proxy_three_node_disconnect_test() {
    // A fourth instance in net_d dials inst3's address (10.1.2.3) over TCP.
    // NOTE(review): this test never calls init_three_node, so the other
    // instances are not running here — presumably it only exercises inst4's
    // reconnect loop against a flapping link; confirm intent.
    InstanceConfigWriter::new("inst4")
        .set_ns(Some("net_d".into()))
        .set_addr("10.144.144.4".to_owned());
    let mut inst4 = Instance::new("inst4");
    inst4
        .get_conn_manager()
        .add_connector(TcpTunnelConnector::new(
            "tcp://10.1.2.3:11010".parse().unwrap(),
        ));
    inst4.run().await.unwrap();
    // Flap net_d's veth every 6s to exercise disconnect/reconnect handling.
    tokio::spawn(async {
        loop {
            tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
            set_link_status("net_d", false);
            tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
            set_link_status("net_d", true);
        }
    });
    // TODO: add some traffic here, also should check route & peer list
    tokio::time::sleep(tokio::time::Duration::from_secs(35)).await;
}
+54
View File
@@ -0,0 +1,54 @@
use std::result::Result;
use tokio::io;
use tokio_util::{
bytes::{BufMut, Bytes, BytesMut},
codec::{Decoder, Encoder},
};
/// Codec that treats whatever is currently buffered as one opaque frame.
///
/// `capacity` is how many bytes to pre-reserve in the read buffer after each
/// decoded frame, so the next network fill does not reallocate.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
pub struct BytesCodec {
    capacity: usize,
}
impl BytesCodec {
    /// Creates a new `BytesCodec` for shipping around raw bytes.
    pub fn new(capacity: usize) -> Self {
        Self { capacity }
    }
}
impl Decoder for BytesCodec {
    type Item = BytesMut;
    type Error = io::Error;
    /// Hands the caller everything currently buffered as a single frame.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
        if buf.is_empty() {
            // Nothing buffered yet; the framed reader will fetch more bytes.
            return Ok(None);
        }
        // Take the whole buffer as one datagram, then pre-reserve space so
        // the next read can fill the buffer without reallocating.
        let frame = buf.split_to(buf.len());
        buf.reserve(self.capacity);
        Ok(Some(frame))
    }
}
impl Encoder<Bytes> for BytesCodec {
    type Error = io::Error;
    /// Appends `data` verbatim to the output buffer; never fails.
    fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
        // extend_from_slice grows the buffer as needed and copies the payload.
        buf.extend_from_slice(&data);
        Ok(())
    }
}
impl Encoder<BytesMut> for BytesCodec {
    type Error = io::Error;
    /// Same as the `Bytes` encoder, for callers holding a mutable buffer.
    fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> {
        buf.extend_from_slice(&data);
        Ok(())
    }
}
+399
View File
@@ -0,0 +1,399 @@
use std::{
collections::VecDeque,
net::IpAddr,
sync::Arc,
task::{ready, Context, Poll},
};
use async_stream::stream;
use futures::{Future, FutureExt, Sink, SinkExt, Stream, StreamExt};
use tokio::{sync::Mutex, time::error::Elapsed};
use std::pin::Pin;
use crate::tunnels::{SinkError, TunnelError};
use super::{DatagramSink, DatagramStream, SinkItem, StreamT, Tunnel, TunnelInfo};
/// A tunnel built from a framed read half `R` and framed write half `W`.
///
/// Both halves sit behind `Arc<Mutex<..>>` so the streams/sinks handed out
/// by `recv_stream`/`send_sink` can share them across tasks.
pub struct FramedTunnel<R, W> {
    // Framed read half; locked only while polling for the next frame.
    read: Arc<Mutex<R>>,
    // Framed write half; locked for the duration of a buffered flush.
    write: Arc<Mutex<W>>,
    // Optional transport metadata (tunnel type, local/remote address).
    info: Option<TunnelInfo>,
}
impl<R, RE, W, WE> FramedTunnel<R, W>
where
    R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
    W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
    RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
    WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
{
    /// Builds a tunnel from a framed read/write pair plus optional metadata.
    pub fn new(read: R, write: W, info: Option<TunnelInfo>) -> Self {
        FramedTunnel {
            read: Arc::new(Mutex::new(read)),
            write: Arc::new(Mutex::new(write)),
            info,
        }
    }
    /// Boxed-constructor convenience used by the concrete transports.
    pub fn new_tunnel_with_info(read: R, write: W, info: TunnelInfo) -> Box<dyn Tunnel> {
        Box::new(FramedTunnel::new(read, write, Some(info)))
    }
    /// Returns a stream of inbound frames; transport errors and end-of-stream
    /// are both surfaced as `TunnelError::CommonError` items.
    // NOTE(review): the loop never terminates, so after end-of-stream or an
    // error it keeps yielding Err items on every poll — callers are expected
    // to stop polling after the first Err; confirm.
    pub fn recv_stream(&self) -> impl DatagramStream {
        let read = self.read.clone();
        let info = self.info.clone();
        stream! {
            loop {
                // The read lock is held only for one next() call.
                let read_ret = read.lock().await.next().await;
                if read_ret.is_none() {
                    tracing::info!(?info, "read_ret is none");
                    yield Err(TunnelError::CommonError("recv stream closed".to_string()));
                } else {
                    let read_ret = read_ret.unwrap();
                    if read_ret.is_err() {
                        let err = read_ret.err().unwrap();
                        tracing::info!(?info, "recv stream read error");
                        yield Err(TunnelError::CommonError(err.to_string()));
                    } else {
                        yield Ok(read_ret.unwrap());
                    }
                }
            }
        }
    }
    /// Returns a buffering sink for outbound frames.
    ///
    /// `start_send` only queues items into `sending_buffers`; a flush drains
    /// the queue in a single async task that holds the write lock, giving
    /// each underlying send a 1-second timeout.
    pub fn send_sink(&self) -> impl DatagramSink {
        struct SendSink<W, WE> {
            write: Arc<Mutex<W>>,
            // Backpressure threshold: poll_ready forces a flush once the
            // queue grows past this many items.
            max_buffer_size: usize,
            // None while a flush task owns the queued items.
            sending_buffers: Option<VecDeque<SinkItem>>,
            // In-flight flush future, kept across polls until Ready.
            send_task:
                Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
            // In-flight close future, kept across polls until Ready.
            close_task:
                Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
        }
        impl<W, WE> SendSink<W, WE>
        where
            W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
            WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
        {
            // Drives the flush: lazily creates the drain task, polls it to
            // completion, then restores an empty queue for new items.
            fn try_send_buffser(
                &mut self,
                cx: &mut Context<'_>,
            ) -> Poll<std::result::Result<(), WE>> {
                if self.send_task.is_none() {
                    // take() leaves None, which poll_ready reads as
                    // "flush in progress".
                    let mut buffers = self.sending_buffers.take().unwrap();
                    let tun = self.write.clone();
                    let send_task = async move {
                        if buffers.is_empty() {
                            return Ok(());
                        }
                        let mut locked_tun = tun.lock_owned().await;
                        while let Some(buf) = buffers.front() {
                            log::trace!(
                                "try_send buffer, len: {:?}, buf: {:?}",
                                buffers.len(),
                                &buf
                            );
                            // Each send gets a 1s timeout; Elapsed converts
                            // into WE via the From<Elapsed> bound.
                            let timeout_task = tokio::time::timeout(
                                std::time::Duration::from_secs(1),
                                locked_tun.send(buf.clone()),
                            );
                            let send_res = timeout_task.await;
                            let Ok(send_res) = send_res else {
                                // panic!("send timeout");
                                let err = send_res.err().unwrap();
                                return Err(err.into());
                            };
                            let Ok(_) = send_res else {
                                let err = send_res.err().unwrap();
                                println!("send error: {:?}", err);
                                return Err(err);
                            };
                            // Only drop the frame once it was sent.
                            buffers.pop_front();
                        }
                        return Ok(());
                    };
                    self.send_task = Some(Box::pin(send_task));
                }
                let ret = ready!(self.send_task.as_mut().unwrap().poll_unpin(cx));
                self.send_task = None;
                self.sending_buffers = Some(VecDeque::new());
                return Poll::Ready(ret);
            }
        }
        impl<W, WE> Sink<SinkItem> for SendSink<W, WE>
        where
            W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
            WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
        {
            type Error = SinkError;
            fn poll_ready(
                self: Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                let self_mut = self.get_mut();
                let sending_buf = self_mut.sending_buffers.as_ref();
                // if sending_buffers is None, must already be doing flush
                if sending_buf.is_none() || sending_buf.unwrap().len() > self_mut.max_buffer_size {
                    return self_mut.poll_flush_unpin(cx);
                } else {
                    return Poll::Ready(Ok(()));
                }
            }
            fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
                // poll_ready guarantees no flush is in flight here.
                assert!(self.send_task.is_none());
                let self_mut = self.get_mut();
                self_mut.sending_buffers.as_mut().unwrap().push_back(item);
                Ok(())
            }
            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                let self_mut = self.get_mut();
                let ret = self_mut.try_send_buffser(cx);
                match ret {
                    Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
                    // Transport errors are flattened into SinkError strings.
                    Poll::Ready(Err(e)) => Poll::Ready(Err(SinkError::CommonError(e.to_string()))),
                    Poll::Pending => {
                        return Poll::Pending;
                    }
                }
            }
            fn poll_close(
                self: Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                let self_mut = self.get_mut();
                if self_mut.close_task.is_none() {
                    let tun = self_mut.write.clone();
                    let close_task = async move {
                        let mut locked_tun = tun.lock_owned().await;
                        return locked_tun.close().await;
                    };
                    self_mut.close_task = Some(Box::pin(close_task));
                }
                let ret = ready!(self_mut.close_task.as_mut().unwrap().poll_unpin(cx));
                self_mut.close_task = None;
                if ret.is_err() {
                    return Poll::Ready(Err(SinkError::CommonError(
                        ret.err().unwrap().to_string(),
                    )));
                } else {
                    return Poll::Ready(Ok(()));
                }
            }
        }
        SendSink {
            write: self.write.clone(),
            max_buffer_size: 1000,
            sending_buffers: Some(VecDeque::new()),
            send_task: None,
            close_task: None,
        }
    }
}
impl<R, RE, W, WE> Tunnel for FramedTunnel<R, W>
where
    R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
    W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
    RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
    WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
{
    /// Fresh handle to the receive side (shares the locked read half).
    fn stream(&self) -> Box<dyn DatagramStream> {
        Box::new(self.recv_stream())
    }
    /// Fresh handle to the send side (shares the locked write half).
    fn sink(&self) -> Box<dyn DatagramSink> {
        Box::new(self.send_sink())
    }
    /// Transport metadata supplied at construction time, if any.
    fn info(&self) -> Option<TunnelInfo> {
        // Cloning the Option directly replaces the original
        // is_none()/unwrap() dance with the idiomatic form.
        self.info.clone()
    }
}
/// Decorator that overrides a tunnel's reported metadata while delegating
/// all traffic to the wrapped tunnel.
pub struct TunnelWithCustomInfo {
    tunnel: Box<dyn Tunnel>,
    // The metadata reported instead of the inner tunnel's own.
    info: TunnelInfo,
}
impl TunnelWithCustomInfo {
    /// Wraps `tunnel`, making it report `info` instead of its own metadata.
    pub fn new(tunnel: Box<dyn Tunnel>, info: TunnelInfo) -> Self {
        Self { tunnel, info }
    }
}
impl Tunnel for TunnelWithCustomInfo {
    // Traffic goes straight through to the wrapped tunnel...
    fn stream(&self) -> Box<dyn DatagramStream> {
        self.tunnel.stream()
    }
    fn sink(&self) -> Box<dyn DatagramSink> {
        self.tunnel.sink()
    }
    // ...but the reported metadata is the overriding copy.
    fn info(&self) -> Option<TunnelInfo> {
        let overridden = self.info.clone();
        Some(overridden)
    }
}
/// Finds the name of the local network interface that owns `local_ip`,
/// or `None` if no interface has that address assigned.
pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
    pnet::datalink::interfaces()
        .into_iter()
        .find(|iface| iface.ips.iter().any(|ip| ip.ip() == *local_ip))
        .map(|iface| iface.name)
}
/// Shared tunnel test helpers (echo server, ping-pong, bench).
// NOTE(review): not gated by #[cfg(test)] — sibling modules' test code
// imports these helpers (e.g. `_tunnel_pingpong_netns`); confirm this is
// intentional rather than an oversight.
pub mod tests {
    use std::time::Instant;
    use futures::SinkExt;
    use tokio_stream::StreamExt;
    use tokio_util::bytes::{BufMut, Bytes, BytesMut};
    use crate::{
        common::netns::NetNS,
        tunnels::{close_tunnel, TunnelConnector, TunnelListener},
    };
    /// Echoes every received frame back to the sender.
    /// Exits on the first receive error, or after one frame when `once`.
    pub async fn _tunnel_echo_server(tunnel: Box<dyn super::Tunnel>, once: bool) {
        let mut recv = Box::into_pin(tunnel.stream());
        let mut send = Box::into_pin(tunnel.sink());
        while let Some(ret) = recv.next().await {
            if ret.is_err() {
                log::trace!("recv error: {:?}", ret.err().unwrap());
                break;
            }
            let res = ret.unwrap();
            log::trace!("recv a msg, try echo back: {:?}", res);
            send.send(Bytes::from(res)).await.unwrap();
            if once {
                break;
            }
        }
        log::warn!("echo server exit...");
    }
    /// Ping-pong in the current namespace (both sides share it).
    pub(crate) async fn _tunnel_pingpong<L, C>(listener: L, connector: C)
    where
        L: TunnelListener + Send + Sync + 'static,
        C: TunnelConnector + Send + Sync + 'static,
    {
        _tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await
    }
    /// Full round trip: listen in `l_netns`, connect from `c_netns`, send
    /// one frame, assert the echo, then close and check shutdown behavior.
    pub(crate) async fn _tunnel_pingpong_netns<L, C>(
        mut listener: L,
        mut connector: C,
        l_netns: NetNS,
        c_netns: NetNS,
    ) where
        L: TunnelListener + Send + Sync + 'static,
        C: TunnelConnector + Send + Sync + 'static,
    {
        // Bind inside the listener's namespace.
        l_netns
            .run_async(|| async {
                listener.listen().await.unwrap();
            })
            .await;
        let lis = tokio::spawn(async move {
            let ret = listener.accept().await.unwrap();
            // The accepted tunnel must report the listen URL as local addr.
            assert_eq!(
                ret.info().unwrap().local_addr,
                listener.local_url().to_string()
            );
            _tunnel_echo_server(ret, false).await
        });
        let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();
        assert_eq!(
            tunnel.info().unwrap().remote_addr,
            connector.remote_url().to_string()
        );
        let mut send = tunnel.pin_sink();
        let mut recv = tunnel.pin_stream();
        let send_data = Bytes::from("abc");
        send.send(send_data).await.unwrap();
        // The echo must come back within 1s.
        let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next())
            .await
            .unwrap()
            .unwrap()
            .unwrap();
        println!("echo back: {:?}", ret);
        assert_eq!(ret, Bytes::from("abc"));
        close_tunnel(&tunnel).await.unwrap();
        if connector.remote_url().scheme() == "udp" {
            // UDP has no close notification, so the echo task is aborted.
            lis.abort();
        } else {
            // lis should finish in 1 second
            let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await;
            assert!(ret.is_ok());
        }
    }
    /// Rough throughput check: 4 KiB echo round trips for 3 seconds,
    /// printing the resulting rate. Not an assertion — just a smoke bench.
    pub(crate) async fn _tunnel_bench<L, C>(mut listener: L, mut connector: C)
    where
        L: TunnelListener + Send + Sync + 'static,
        C: TunnelConnector + Send + Sync + 'static,
    {
        listener.listen().await.unwrap();
        let lis = tokio::spawn(async move {
            let ret = listener.accept().await.unwrap();
            _tunnel_echo_server(ret, false).await
        });
        let tunnel = connector.connect().await.unwrap();
        let mut send = tunnel.pin_sink();
        let mut recv = tunnel.pin_stream();
        // prepare a 4k buffer with random data
        let mut send_buf = BytesMut::new();
        for _ in 0..64 {
            send_buf.put_i128(rand::random::<i128>());
        }
        let now = Instant::now();
        let mut count = 0;
        while now.elapsed().as_secs() < 3 {
            send.send(send_buf.clone().freeze()).await.unwrap();
            let _ = recv.next().await.unwrap().unwrap();
            count += 1;
        }
        println!("bps: {}", (count / 1024) * 4 / now.elapsed().as_secs());
        lis.abort();
    }
}
+159
View File
@@ -0,0 +1,159 @@
pub mod codec;
pub mod common;
pub mod ring_tunnel;
pub mod stats;
pub mod tcp_tunnel;
pub mod tunnel_filter;
pub mod udp_tunnel;
use std::{fmt::Debug, net::SocketAddr, pin::Pin, sync::Arc};
use async_trait::async_trait;
use easytier_rpc::TunnelInfo;
use futures::{Sink, SinkExt, Stream};
use thiserror::Error;
use tokio_util::bytes::{Bytes, BytesMut};
/// Errors produced by tunnel listeners, connectors, streams and sinks.
#[derive(Error, Debug)]
pub enum TunnelError {
    // Catch-all for failures without a dedicated variant.
    #[error("Error: {0}")]
    CommonError(String),
    // Underlying socket/file errors convert automatically via `?`.
    #[error("io error")]
    IOError(#[from] std::io::Error),
    // Error while waiting for a response.
    #[error("wait resp error")]
    WaitRespError(String),
    #[error("Connect Error: {0}")]
    ConnectError(String),
    // URL scheme did not match the transport (e.g. udp:// given to tcp).
    #[error("Invalid Protocol: {0}")]
    InvalidProtocol(String),
    // URL host/address could not be parsed or resolved.
    #[error("Invalid Addr: {0}")]
    InvalidAddr(String),
    #[error("Tun Error: {0}")]
    TunError(String),
    // tokio timeouts convert automatically via `?`.
    #[error("timeout")]
    Timeout(#[from] tokio::time::error::Elapsed),
}
// Payload types flowing through a tunnel: frames are received as `BytesMut`
// and sent as `Bytes`; both directions fail with `TunnelError`.
pub type StreamT = BytesMut;
pub type StreamItem = Result<StreamT, TunnelError>;
pub type SinkItem = Bytes;
pub type SinkError = TunnelError;
// Marker traits with blanket impls: anything with the matching Stream/Sink
// shape automatically works as a tunnel datagram stream/sink.
pub trait DatagramStream: Stream<Item = StreamItem> + Send + Sync {}
impl<T> DatagramStream for T where T: Stream<Item = StreamItem> + Send + Sync {}
pub trait DatagramSink: Sink<SinkItem, Error = SinkError> + Send + Sync {}
impl<T> DatagramSink for T where T: Sink<SinkItem, Error = SinkError> + Send + Sync {}
/// A bidirectional datagram transport.
///
/// `auto_impl` lets `Box<dyn Tunnel>` and `Arc<dyn Tunnel>` be used
/// anywhere a `Tunnel` is expected.
#[auto_impl::auto_impl(Box, Arc)]
pub trait Tunnel: Send + Sync {
    // Fresh handle to the receive side.
    fn stream(&self) -> Box<dyn DatagramStream>;
    // Fresh handle to the send side.
    fn sink(&self) -> Box<dyn DatagramSink>;
    // Pinned variants, ready for StreamExt/SinkExt polling.
    fn pin_stream(&self) -> Pin<Box<dyn DatagramStream>> {
        Box::into_pin(self.stream())
    }
    fn pin_sink(&self) -> Pin<Box<dyn DatagramSink>> {
        Box::into_pin(self.sink())
    }
    // Transport metadata (type, local/remote address), if known.
    fn info(&self) -> Option<TunnelInfo>;
}
/// Gracefully closes a tunnel by closing its sink side.
pub async fn close_tunnel(t: &Box<dyn Tunnel>) -> Result<(), TunnelError> {
    let mut sink = t.pin_sink();
    sink.close().await
}
/// Reports how many live connections a listener currently has.
#[auto_impl::auto_impl(Arc)]
pub trait TunnelConnCounter: 'static + Send + Sync + Debug {
    fn get(&self) -> u32;
}
/// Server side of a tunnel transport.
#[async_trait]
#[auto_impl::auto_impl(Box)]
pub trait TunnelListener: Send + Sync {
    /// Binds/registers the listening endpoint; must precede `accept`.
    async fn listen(&mut self) -> Result<(), TunnelError>;
    /// Waits for and returns the next inbound tunnel.
    async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
    /// The URL this listener was configured with.
    fn local_url(&self) -> url::Url;
    /// Live-connection counter; the default is a stub reporting 0.
    fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
        #[derive(Debug)]
        struct FakeTunnelConnCounter {}
        impl TunnelConnCounter for FakeTunnelConnCounter {
            fn get(&self) -> u32 {
                0
            }
        }
        Arc::new(Box::new(FakeTunnelConnCounter {}))
    }
}
/// Client side of a tunnel transport.
#[async_trait]
#[auto_impl::auto_impl(Box)]
pub trait TunnelConnector {
    /// Dials the remote endpoint and returns the established tunnel.
    async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
    /// The URL this connector dials.
    fn remote_url(&self) -> url::Url;
    /// Restricts which local addresses to bind when dialing (default: none).
    fn set_bind_addrs(&mut self, _addrs: Vec<SocketAddr>) {}
}
/// Builds `scheme://addr` as a URL; panics if the combination is unparsable.
pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url {
    let url_str = format!("{}://{}", scheme, addr);
    url::Url::parse(&url_str).unwrap()
}
// Manual Debug impls for the trait objects: they surface only the
// URL/metadata, since the underlying transports are not Debug themselves.
impl std::fmt::Debug for dyn Tunnel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Tunnel")
            .field("info", &self.info())
            .finish()
    }
}
impl std::fmt::Debug for dyn TunnelConnector + Sync + Send {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TunnelConnector")
            .field("remote_url", &self.remote_url())
            .finish()
    }
}
impl std::fmt::Debug for dyn TunnelListener {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TunnelListener")
            .field("local_url", &self.local_url())
            .finish()
    }
}
/// Conversion from a tunnel URL into a transport-specific address type.
pub(crate) trait FromUrl {
    fn from_url(url: url::Url) -> Result<Self, TunnelError>
    where
        Self: Sized;
}
/// Verifies that `url` uses `scheme`, then converts it into `T`.
///
/// Returns `InvalidProtocol` on a scheme mismatch, or whatever error the
/// `FromUrl` conversion produces.
pub(crate) fn check_scheme_and_get_socket_addr<T>(
    url: &url::Url,
    scheme: &str,
) -> Result<T, TunnelError>
where
    T: FromUrl,
{
    if url.scheme() != scheme {
        return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
    }
    // The original `Ok(...?)` was redundant; forward the result directly.
    T::from_url(url.clone())
}
impl FromUrl for SocketAddr {
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
Ok(url.socket_addrs(|| None)?.pop().unwrap())
}
}
impl FromUrl for uuid::Uuid {
    /// Parses the URL's host component as a UUID (used by `ring://` URLs).
    ///
    /// Returns `InvalidAddr` (instead of panicking, as the original
    /// `host_str().unwrap()` did) when the host is missing or malformed.
    fn from_url(url: url::Url) -> Result<Self, TunnelError> {
        let host = url
            .host_str()
            .ok_or_else(|| TunnelError::InvalidAddr(url.to_string()))?;
        uuid::Uuid::parse_str(host).map_err(|e| TunnelError::InvalidAddr(e.to_string()))
    }
}
+391
View File
@@ -0,0 +1,391 @@
use std::{
collections::HashMap,
sync::{atomic::AtomicBool, Arc},
task::Poll,
};
use async_stream::stream;
use crossbeam_queue::ArrayQueue;
use async_trait::async_trait;
use futures::Sink;
use once_cell::sync::Lazy;
use tokio::sync::{Mutex, Notify};
use futures::FutureExt;
use tokio_util::bytes::BytesMut;
use uuid::Uuid;
use crate::tunnels::{SinkError, SinkItem};
use super::{
build_url_from_socket_addr, check_scheme_and_get_socket_addr, DatagramSink, DatagramStream,
Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
};
static RING_TUNNEL_CAP: usize = 1000;
/// In-memory, single-direction datagram queue used to connect two
/// endpoints inside the same process.
pub struct RingTunnel {
    id: Uuid,
    // Bounded lock-free queue holding in-flight frames.
    ring: Arc<ArrayQueue<SinkItem>>,
    // Signaled by the reader after popping; wakes a writer waiting for room.
    consume_notify: Arc<Notify>,
    // Signaled by the writer after pushing; wakes a reader waiting for data.
    produce_notify: Arc<Notify>,
    // Set by the sink's poll_close; makes the recv stream yield errors.
    closed: Arc<AtomicBool>,
}
impl RingTunnel {
    /// Creates a ring with a random id and room for `cap` frames.
    pub fn new(cap: usize) -> Self {
        RingTunnel {
            id: Uuid::new_v4(),
            ring: Arc::new(ArrayQueue::new(cap)),
            consume_notify: Arc::new(Notify::new()),
            produce_notify: Arc::new(Notify::new()),
            closed: Arc::new(AtomicBool::new(false)),
        }
    }
    /// Like `new`, but with a caller-chosen id (used so the server-side
    /// ring's id matches the listener URL).
    pub fn new_with_id(id: Uuid, cap: usize) -> Self {
        let mut ret = Self::new(cap);
        ret.id = id;
        ret
    }
    /// Stream of frames popped from the ring; waits on `produce_notify`
    /// when the queue is empty.
    // NOTE(review): once `closed` is observed the loop still continues, so
    // the stream yields "Closed" errors repeatedly and frames left in the
    // queue may be skipped — confirm callers stop polling on the first Err.
    fn recv_stream(&self) -> impl DatagramStream {
        let ring = self.ring.clone();
        let produce_notify = self.produce_notify.clone();
        let consume_notify = self.consume_notify.clone();
        let closed = self.closed.clone();
        let id = self.id;
        stream! {
            loop {
                if closed.load(std::sync::atomic::Ordering::Relaxed) {
                    log::warn!("ring recv tunnel {:?} closed", id);
                    yield Err(TunnelError::CommonError("Closed".to_owned()));
                }
                match ring.pop() {
                    Some(v) => {
                        // Copy into a fresh BytesMut, the stream item type.
                        let mut out = BytesMut::new();
                        out.extend_from_slice(&v);
                        // Tell a possibly-blocked writer a slot is free.
                        consume_notify.notify_one();
                        log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v);
                        yield Ok(out);
                    },
                    None => {
                        log::trace!("waiting recv buffer, id: {}", id);
                        produce_notify.notified().await;
                    }
                }
            }
        }
    }
    /// Sink pushing frames into the ring; `poll_ready` waits (via a spawned
    /// task watching `consume_notify`) until at least one slot is free.
    fn send_sink(&self) -> impl DatagramSink {
        let ring = self.ring.clone();
        let produce_notify = self.produce_notify.clone();
        let consume_notify = self.consume_notify.clone();
        let closed = self.closed.clone();
        let id = self.id;
        // type T = RingTunnel;
        use tokio::task::JoinHandle;
        // Sink adapter: owns a clone of the ring plus the task used to wait
        // for the consumer to drain it.
        struct T {
            ring: RingTunnel,
            wait_consume_task: Option<JoinHandle<()>>,
        }
        impl T {
            // Resolves Ready once the queue length is <= expected_size,
            // parking on consume_notify in a spawned task otherwise.
            fn wait_ring_consume(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                expected_size: usize,
            ) -> std::task::Poll<()> {
                let self_mut = self.get_mut();
                if self_mut.ring.ring.len() <= expected_size {
                    return Poll::Ready(());
                }
                if self_mut.wait_consume_task.is_none() {
                    let id = self_mut.ring.id;
                    let consume_notify = self_mut.ring.consume_notify.clone();
                    let ring = self_mut.ring.ring.clone();
                    let task = async move {
                        log::trace!(
                            "waiting ring consume done, expected_size: {}, id: {}",
                            expected_size,
                            id
                        );
                        while ring.len() > expected_size {
                            consume_notify.notified().await;
                        }
                        log::trace!(
                            "ring consume done, expected_size: {}, id: {}",
                            expected_size,
                            id
                        );
                    };
                    self_mut.wait_consume_task = Some(tokio::spawn(task));
                }
                let task = self_mut.wait_consume_task.as_mut().unwrap();
                match task.poll_unpin(cx) {
                    Poll::Ready(_) => {
                        self_mut.wait_consume_task = None;
                        Poll::Ready(())
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
        }
        impl Sink<SinkItem> for T {
            type Error = SinkError;
            fn poll_ready(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                // capacity - 1 guarantees at least one free slot, so the
                // following start_send's push cannot fail.
                let expected_size = self.ring.ring.capacity() - 1;
                match self.wait_ring_consume(cx, expected_size) {
                    Poll::Ready(_) => Poll::Ready(Ok(())),
                    Poll::Pending => Poll::Pending,
                }
            }
            fn start_send(
                self: std::pin::Pin<&mut Self>,
                item: SinkItem,
            ) -> Result<(), Self::Error> {
                log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item);
                // unwrap() relies on poll_ready having freed a slot.
                self.ring.ring.push(item).unwrap();
                self.ring.produce_notify.notify_one();
                Ok(())
            }
            fn poll_flush(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                // Pushes are immediately visible; nothing to flush.
                Poll::Ready(Ok(()))
            }
            fn poll_close(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                // Mark closed and wake the reader so it observes the flag.
                self.ring
                    .closed
                    .store(true, std::sync::atomic::Ordering::Relaxed);
                log::warn!("ring tunnel send {:?} closed", self.ring.id);
                self.ring.produce_notify.notify_one();
                Poll::Ready(Ok(()))
            }
        }
        T {
            ring: RingTunnel {
                id,
                ring,
                consume_notify,
                produce_notify,
                closed,
            },
            wait_consume_task: None,
        }
    }
}
/// One rendezvous pair. Rings are unidirectional: each side reads its own
/// ring and writes into the peer's (server reads `server`, writes `client`;
/// client does the opposite).
struct Connection {
    client: RingTunnel,
    server: RingTunnel,
    // Signaled by connect() to wake the listener blocked in accept().
    connect_notify: Arc<Notify>,
}
impl Tunnel for RingTunnel {
    fn stream(&self) -> Box<dyn DatagramStream> {
        Box::new(self.recv_stream())
    }
    fn sink(&self) -> Box<dyn DatagramSink> {
        Box::new(self.send_sink())
    }
    /// A bare ring carries no address metadata; the ConnectionForServer/
    /// ConnectionForClient wrappers are the ones that report ring:// URLs.
    fn info(&self) -> Option<TunnelInfo> {
        None
    }
}
// Global rendezvous table: listen-URL UUID -> pending Connection pair.
// Written by RingTunnelListener (listen/accept) and read by
// RingTunnelConnector::connect.
static CONNECTION_MAP: Lazy<Arc<Mutex<HashMap<uuid::Uuid, Arc<Connection>>>>> =
    Lazy::new(|| Arc::new(Mutex::new(HashMap::new())));
/// Listener half of the in-process ring tunnel; rendezvous happens through
/// the global CONNECTION_MAP keyed by the UUID in the listen URL.
#[derive(Debug)]
pub struct RingTunnelListener {
    // NOTE(review): field name is a typo of "listener_addr"; renaming would
    // touch every use site, so it is only flagged here.
    listerner_addr: url::Url,
}
impl RingTunnelListener {
    /// `key` must be a `ring://<uuid>` URL; the UUID is the rendezvous key.
    pub fn new(key: url::Url) -> Self {
        Self {
            listerner_addr: key,
        }
    }
}
// Server-side view of a rendezvous Connection (returned by accept()).
struct ConnectionForServer {
    conn: Arc<Connection>,
}
impl Tunnel for ConnectionForServer {
    // The server reads its own ring...
    fn stream(&self) -> Box<dyn DatagramStream> {
        Box::new(self.conn.server.recv_stream())
    }
    // ...and writes into the client's ring (rings are unidirectional).
    fn sink(&self) -> Box<dyn DatagramSink> {
        Box::new(self.conn.client.send_sink())
    }
    fn info(&self) -> Option<TunnelInfo> {
        // Local/remote are reported as ring://<uuid> built from the ring ids.
        Some(TunnelInfo {
            tunnel_type: "ring".to_owned(),
            local_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
            remote_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
        })
    }
}
// Client-side view of a rendezvous Connection (returned by connect()).
struct ConnectionForClient {
    conn: Arc<Connection>,
}
impl Tunnel for ConnectionForClient {
    // Mirror image of ConnectionForServer: read own (client) ring...
    fn stream(&self) -> Box<dyn DatagramStream> {
        Box::new(self.conn.client.recv_stream())
    }
    // ...write the server's ring.
    fn sink(&self) -> Box<dyn DatagramSink> {
        Box::new(self.conn.server.send_sink())
    }
    fn info(&self) -> Option<TunnelInfo> {
        Some(TunnelInfo {
            tunnel_type: "ring".to_owned(),
            local_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
            remote_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
        })
    }
}
impl RingTunnelListener {
    /// Registers a fresh client/server ring pair under `listener_addr` so a
    /// connector can find it in the global CONNECTION_MAP.
    async fn add_connection(listener_addr: uuid::Uuid) {
        // Uuid is Copy, so it is passed by value — the original's .clone()
        // calls were redundant.
        CONNECTION_MAP.lock().await.insert(
            listener_addr,
            Arc::new(Connection {
                client: RingTunnel::new(RING_TUNNEL_CAP),
                // The server ring reuses the listener id so its reported
                // local ring:// address matches the listen URL.
                server: RingTunnel::new_with_id(listener_addr, RING_TUNNEL_CAP),
                connect_notify: Arc::new(Notify::new()),
            }),
        );
    }
    /// Parses the `ring://<uuid>` listen URL into the rendezvous key.
    fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
        check_scheme_and_get_socket_addr::<Uuid>(&self.listerner_addr, "ring")
    }
}
#[async_trait]
impl TunnelListener for RingTunnelListener {
    /// Publishes a pending connection pair keyed by this listener's UUID.
    async fn listen(&mut self) -> Result<(), TunnelError> {
        log::info!("listen new conn of key: {}", self.listerner_addr);
        Self::add_connection(self.get_addr()?).await;
        Ok(())
    }
    /// Waits until a connector signals the pending pair, then re-arms a
    /// fresh pair for the next connect.
    // NOTE(review): the map lookup unwrap()s — accept() panics if listen()
    // was never called for this key; confirm that precondition is enforced
    // by callers.
    async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
        log::info!("waiting accept new conn of key: {}", self.listerner_addr);
        let val = CONNECTION_MAP
            .lock()
            .await
            .get(&self.get_addr()?)
            .unwrap()
            .clone();
        // Woken by RingTunnelConnector::connect on this pair.
        val.connect_notify.notified().await;
        // Swap in a new pending pair so the next connect() finds one.
        CONNECTION_MAP.lock().await.remove(&self.get_addr()?);
        Self::add_connection(self.get_addr()?).await;
        log::info!("accept new conn of key: {}", self.listerner_addr);
        Ok(Box::new(ConnectionForServer { conn: val }))
    }
    fn local_url(&self) -> url::Url {
        self.listerner_addr.clone()
    }
}
/// Connects to an in-process ring listener identified by `ring://<uuid>`.
pub struct RingTunnelConnector {
    remote_addr: url::Url,
}
impl RingTunnelConnector {
    /// `remote_addr` must be a `ring://<uuid>` URL naming the listener.
    pub fn new(remote_addr: url::Url) -> Self {
        Self { remote_addr }
    }
}
#[async_trait]
impl TunnelConnector for RingTunnelConnector {
    /// Looks up the listener's pending connection pair and wakes accept().
    ///
    /// Returns `ConnectError` (instead of panicking, as the original
    /// `.unwrap()` did) when no listener is registered for this address.
    async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        let key = check_scheme_and_get_socket_addr::<Uuid>(&self.remote_addr, "ring")?;
        let val = CONNECTION_MAP
            .lock()
            .await
            .get(&key)
            .ok_or_else(|| {
                super::TunnelError::ConnectError(format!(
                    "no ring listener for {}",
                    self.remote_addr
                ))
            })?
            .clone();
        // Wake the listener blocked in accept() on this pair.
        val.connect_notify.notify_one();
        log::info!("connecting");
        Ok(Box::new(ConnectionForClient { conn: val }))
    }
    fn remote_url(&self) -> url::Url {
        self.remote_addr.clone()
    }
}
/// Builds a connected in-process tunnel pair sharing one `Connection`.
/// Returned as `(server_side, client_side)`.
pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
    let conn = Arc::new(Connection {
        client: RingTunnel::new(RING_TUNNEL_CAP),
        server: RingTunnel::new(RING_TUNNEL_CAP),
        connect_notify: Arc::new(Notify::new()),
    });
    let server_side: Box<dyn Tunnel> = Box::new(ConnectionForServer { conn: conn.clone() });
    let client_side: Box<dyn Tunnel> = Box::new(ConnectionForClient { conn });
    (server_side, client_side)
}
#[cfg(test)]
mod tests {
    use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
    use super::*;
    #[tokio::test]
    async fn ring_pingpong() {
        // Listener and connector rendezvous on the same ring:// UUID.
        let url: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
        _tunnel_pingpong(
            RingTunnelListener::new(url.clone()),
            RingTunnelConnector::new(url.clone()),
        )
        .await
    }
    #[tokio::test]
    async fn ring_bench() {
        let url: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
        _tunnel_bench(
            RingTunnelListener::new(url.clone()),
            RingTunnelConnector::new(url),
        )
        .await
    }
}
+101
View File
@@ -0,0 +1,101 @@
use std::sync::atomic::{AtomicU32, AtomicU64};
/// Fixed-size sliding window of latency samples with a lock-free average.
pub struct WindowLatency {
    // One slot per sample; 0 means "never written".
    latency_us_window: Vec<AtomicU64>,
    // Monotonic write cursor; wraps modulo the window size.
    latency_us_window_index: AtomicU32,
    // Window length; set once at construction.
    latency_us_window_size: AtomicU32,
}
impl WindowLatency {
    /// Creates a window holding up to `window_size` samples (all start at 0).
    pub fn new(window_size: u32) -> Self {
        Self {
            latency_us_window: (0..window_size).map(|_| AtomicU64::new(0)).collect(),
            latency_us_window_index: AtomicU32::new(0),
            latency_us_window_size: AtomicU32::new(window_size),
        }
    }
    /// Records one sample, overwriting the oldest slot once the window is
    /// full (ring-buffer behavior).
    pub fn record_latency(&self, latency_us: u64) {
        let slot = self
            .latency_us_window_index
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            % self
                .latency_us_window_size
                .load(std::sync::atomic::Ordering::Relaxed);
        self.latency_us_window[slot as usize]
            .store(latency_us, std::sync::atomic::Ordering::Relaxed);
    }
    /// Mean of all non-zero samples, or 0 when none were recorded.
    /// Zero-valued slots are treated as empty and skipped.
    pub fn get_latency_us(&self) -> u64 {
        let window_size = self
            .latency_us_window_size
            .load(std::sync::atomic::Ordering::Relaxed) as usize;
        let (sum, count) = self.latency_us_window[..window_size]
            .iter()
            .map(|slot| slot.load(std::sync::atomic::Ordering::Relaxed))
            .filter(|&latency| latency > 0)
            .fold((0u64, 0u64), |(s, c), latency| (s + latency, c + 1));
        if count == 0 {
            0
        } else {
            sum / count
        }
    }
}
/// Lock-free byte/packet counters for the two directions of a tunnel.
pub struct Throughput {
    tx_bytes: AtomicU64,
    rx_bytes: AtomicU64,
    tx_packets: AtomicU64,
    rx_packets: AtomicU64,
}
impl Throughput {
    // Relaxed is sufficient: these are independent statistics counters,
    // not synchronization points.
    const ORD: std::sync::atomic::Ordering = std::sync::atomic::Ordering::Relaxed;
    /// All counters start at zero.
    pub fn new() -> Self {
        Self {
            tx_bytes: AtomicU64::new(0),
            rx_bytes: AtomicU64::new(0),
            tx_packets: AtomicU64::new(0),
            rx_packets: AtomicU64::new(0),
        }
    }
    /// Total bytes sent so far.
    pub fn tx_bytes(&self) -> u64 {
        self.tx_bytes.load(Self::ORD)
    }
    /// Total bytes received so far.
    pub fn rx_bytes(&self) -> u64 {
        self.rx_bytes.load(Self::ORD)
    }
    /// Total packets sent so far.
    pub fn tx_packets(&self) -> u64 {
        self.tx_packets.load(Self::ORD)
    }
    /// Total packets received so far.
    pub fn rx_packets(&self) -> u64 {
        self.rx_packets.load(Self::ORD)
    }
    /// Accounts one sent packet of `bytes` bytes.
    pub fn record_tx_bytes(&self, bytes: u64) {
        self.tx_bytes.fetch_add(bytes, Self::ORD);
        self.tx_packets.fetch_add(1, Self::ORD);
    }
    /// Accounts one received packet of `bytes` bytes.
    pub fn record_rx_bytes(&self, bytes: u64) {
        self.rx_bytes.fetch_add(bytes, Self::ORD);
        self.rx_packets.fetch_add(1, Self::ORD);
    }
}
+284
View File
@@ -0,0 +1,284 @@
use std::net::SocketAddr;
use async_trait::async_trait;
use futures::{stream::FuturesUnordered, StreamExt};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use super::{
check_scheme_and_get_socket_addr, common::FramedTunnel, Tunnel, TunnelInfo, TunnelListener,
};
/// Listens for TCP tunnel connections on a `tcp://host:port` URL.
#[derive(Debug)]
pub struct TcpTunnelListener {
    addr: url::Url,
    // Populated by listen(); accept() unwraps it.
    listener: Option<TcpListener>,
}
impl TcpTunnelListener {
    /// `addr` must be a `tcp://host:port` URL; binding happens in `listen`.
    pub fn new(addr: url::Url) -> Self {
        Self {
            addr,
            listener: None,
        }
    }
}
#[async_trait]
impl TunnelListener for TcpTunnelListener {
    /// Binds and starts listening on the configured `tcp://` URL.
    async fn listen(&mut self) -> Result<(), super::TunnelError> {
        let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
        let socket = if addr.is_ipv4() {
            TcpSocket::new_v4()?
        } else {
            TcpSocket::new_v6()?
        };
        // SO_REUSEADDR/SO_REUSEPORT so a restart (or hole punching) can
        // rebind the same local address immediately.
        socket.set_reuseaddr(true)?;
        #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
        socket.set_reuseport(true)?;
        socket.bind(addr)?;
        self.listener = Some(socket.listen(1024)?);
        Ok(())
    }
    /// Accepts one inbound connection and wraps it in a length-delimited
    /// framed tunnel.
    ///
    /// Panics if `listen` has not been called first.
    async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        let listener = self.listener.as_ref().unwrap();
        let (stream, _) = listener.accept().await?;
        // Propagate setsockopt failures instead of panicking — io::Error
        // converts into TunnelError::IOError via #[from].
        stream.set_nodelay(true)?;
        let info = TunnelInfo {
            tunnel_type: "tcp".to_owned(),
            local_addr: self.local_url().into(),
            remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp")
                .into(),
        };
        let (r, w) = tokio::io::split(stream);
        Ok(FramedTunnel::new_tunnel_with_info(
            FramedRead::new(r, LengthDelimitedCodec::new()),
            FramedWrite::new(w, LengthDelimitedCodec::new()),
            info,
        ))
    }
    fn local_url(&self) -> url::Url {
        self.addr.clone()
    }
}
/// Wraps a connected `TcpStream` into a length-delimited framed tunnel.
fn get_tunnel_with_tcp_stream(
    stream: TcpStream,
    remote_url: url::Url,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
    // Propagate setsockopt failure instead of panicking (io::Error converts
    // into TunnelError::IOError via #[from]).
    stream.set_nodelay(true)?;
    let info = TunnelInfo {
        tunnel_type: "tcp".to_owned(),
        local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp")
            .into(),
        remote_addr: remote_url.into(),
    };
    let (r, w) = tokio::io::split(stream);
    // new_tunnel_with_info already returns Box<dyn Tunnel>; the original
    // wrapped it in another Box::new, producing a needless Box<Box<dyn
    // Tunnel>> (it only compiled because auto_impl makes Box<dyn Tunnel>
    // itself implement Tunnel).
    Ok(FramedTunnel::new_tunnel_with_info(
        FramedRead::new(r, LengthDelimitedCodec::new()),
        FramedWrite::new(w, LengthDelimitedCodec::new()),
        info,
    ))
}
/// Dials a remote `tcp://` URL, optionally racing one connection attempt per
/// configured local bind address.
#[derive(Debug)]
pub struct TcpTunnelConnector {
    addr: url::Url,
    // When non-empty, one connect attempt is raced per bind address.
    bind_addrs: Vec<SocketAddr>,
}
impl TcpTunnelConnector {
    /// Create a connector for `addr` (a `tcp://host:port` URL) with no
    /// explicit local bind addresses.
    pub fn new(addr: url::Url) -> Self {
        TcpTunnelConnector {
            addr,
            bind_addrs: vec![],
        }
    }

    /// Connect using the OS-chosen local address.
    async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        tracing::info!(addr = ?self.addr, "connect tcp start");
        let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
        let stream = TcpStream::connect(addr).await?;
        tracing::info!(addr = ?self.addr, "connect tcp succ");
        // `url::Url` was previously passed through a redundant `.into()`.
        get_tunnel_with_tcp_stream(stream, self.addr.clone())
    }

    /// Race one connect attempt per configured bind address and keep the
    /// first one that completes.
    ///
    /// `is_ipv4` selects the socket family; it must match the family of the
    /// configured bind addresses.
    async fn connect_with_custom_bind(
        &mut self,
        is_ipv4: bool,
    ) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        let mut futures = FuturesUnordered::new();
        let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;

        for bind_addr in self.bind_addrs.iter() {
            let socket = if is_ipv4 {
                TcpSocket::new_v4()?
            } else {
                TcpSocket::new_v6()?
            };
            socket.set_reuseaddr(true)?;
            #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
            socket.set_reuseport(true)?;
            socket.bind(*bind_addr)?;

            // linux does not use interface of bind_addr to send packet, so we need to bind device
            // mac can handle this with bind correctly
            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
            if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
                tracing::trace!(dev_name = ?dev_name, "bind device");
                socket.bind_device(Some(dev_name.as_bytes()))?;
            }

            // SocketAddr is Copy; the previous `.clone()` was noise.
            futures.push(socket.connect(dst_addr));
        }

        let Some(ret) = futures.next().await else {
            return Err(super::TunnelError::CommonError(
                "join connect futures failed".to_owned(),
            ));
        };
        get_tunnel_with_tcp_stream(ret?, self.addr.clone())
    }
}
#[async_trait]
impl super::TunnelConnector for TcpTunnelConnector {
    /// Connect to the remote peer, racing custom bind addresses when any are
    /// configured and falling back to the default OS bind otherwise.
    async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        match self.bind_addrs.first() {
            None => self.connect_with_default_bind().await,
            Some(first_bind) => {
                // The socket family is decided by the first bind address.
                let use_v4 = first_bind.is_ipv4();
                self.connect_with_custom_bind(use_v4).await
            }
        }
    }

    /// The remote URL this connector dials.
    fn remote_url(&self) -> url::Url {
        self.addr.clone()
    }

    /// Restrict outgoing connections to the given local addresses.
    fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
        self.bind_addrs = addrs;
    }
}
#[cfg(test)]
mod tests {
    use futures::SinkExt;

    use crate::tunnels::{
        common::tests::{_tunnel_bench, _tunnel_pingpong},
        TunnelConnector,
    };

    use super::*;

    #[tokio::test]
    async fn tcp_pingpong() {
        let listener = TcpTunnelListener::new("tcp://0.0.0.0:11011".parse().unwrap());
        let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11011".parse().unwrap());
        _tunnel_pingpong(listener, connector).await
    }

    #[tokio::test]
    async fn tcp_bench() {
        let listener = TcpTunnelListener::new("tcp://0.0.0.0:11012".parse().unwrap());
        let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11012".parse().unwrap());
        _tunnel_bench(listener, connector).await
    }

    #[tokio::test]
    async fn tcp_bench_with_bind() {
        let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap());
        let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap());
        connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    #[tokio::test]
    #[should_panic]
    async fn tcp_bench_with_bind_fail() {
        let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
        let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
        // 10.0.0.1 is (normally) not a local address, so binding must fail.
        connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    // test slow send lock in framed tunnel
    #[tokio::test]
    async fn tcp_multiple_sender_and_slow_receiver() {
        // console_subscriber::init();
        // Port 11015: this test previously used 11014, colliding with
        // tcp_bench_with_bind_fail when the tests run in parallel.
        let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:11015".parse().unwrap());
        let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11015".parse().unwrap());
        listener.listen().await.unwrap();

        // Slow receiver: consumes one frame every 100ms for ~5s, then stops.
        let t1 = tokio::spawn(async move {
            let t = listener.accept().await.unwrap();
            let mut stream = t.pin_stream();
            let now = tokio::time::Instant::now();
            while let Some(Ok(_)) = stream.next().await {
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
                if now.elapsed().as_secs() > 5 {
                    break;
                }
            }
            tracing::info!("t1 exit");
        });

        let tunnel = connector.connect().await.unwrap();

        // Two fast senders sharing the same tunnel's sink; they should block
        // on the slow receiver without deadlocking, and exit once closed.
        let mut sink1 = tunnel.pin_sink();
        let t2 = tokio::spawn(async move {
            for i in 0..1000000 {
                let a = sink1.send(b"hello".to_vec().into()).await;
                if a.is_err() {
                    tracing::info!(?a, "t2 exit with err");
                    break;
                }
                if i % 5000 == 0 {
                    tracing::info!(i, "send2 1000");
                }
            }
            tracing::info!("t2 exit");
        });

        let mut sink2 = tunnel.pin_sink();
        let t3 = tokio::spawn(async move {
            for i in 0..1000000 {
                let a = sink2.send(b"hello".to_vec().into()).await;
                if a.is_err() {
                    tracing::info!(?a, "t3 exit with err");
                    break;
                }
                if i % 5000 == 0 {
                    tracing::info!(i, "send2 1000");
                }
            }
            tracing::info!("t3 exit");
        });

        // Close the tunnel after 5s so the senders observe an error and exit.
        let t4 = tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
            tracing::info!("closing");
            let close_ret = tunnel.pin_sink().close().await;
            tracing::info!("closed {:?}", close_ret);
        });

        let _ = tokio::join!(t1, t2, t3, t4);
    }
}
+228
View File
@@ -0,0 +1,228 @@
use std::{
sync::Arc,
task::{Context, Poll},
};
use easytier_rpc::TunnelInfo;
use futures::{Sink, SinkExt, Stream, StreamExt};
use self::stats::Throughput;
use super::*;
use crate::tunnels::{DatagramSink, DatagramStream, SinkError, SinkItem, StreamItem, Tunnel};
/// Hook pair applied to every packet crossing a wrapped tunnel:
/// `before_send` may transform outgoing items, `after_received` incoming ones.
pub trait TunnelFilter {
    // Called with each outgoing packet before it reaches the inner sink.
    fn before_send(&self, data: SinkItem) -> Result<SinkItem, SinkError>;
    // Called with each incoming stream item before it reaches the consumer.
    fn after_received(&self, data: StreamItem) -> Result<BytesMut, TunnelError>;
}
/// A tunnel `T` whose sink and stream are routed through filter `F`.
pub struct TunnelWithFilter<T, F> {
    inner: T,
    // Shared so the sink wrapper and the stream wrapper can each hold a handle.
    filter: Arc<F>,
}
impl<T, F> Tunnel for TunnelWithFilter<T, F>
where
    T: Tunnel + Send + Sync + 'static,
    F: TunnelFilter + Send + Sync + 'static,
{
    /// A sink that runs `filter.before_send` on every item before handing it
    /// to the inner tunnel's sink; ready/flush/close polls delegate unchanged.
    fn sink(&self) -> Box<dyn DatagramSink> {
        // Local wrapper: owns a pinned inner sink plus a filter handle.
        struct SinkWrapper<F> {
            sink: Pin<Box<dyn DatagramSink>>,
            filter: Arc<F>,
        }
        impl<F> Sink<SinkItem> for SinkWrapper<F>
        where
            F: TunnelFilter + Send + Sync + 'static,
        {
            type Error = SinkError;
            fn poll_ready(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                self.get_mut().sink.poll_ready_unpin(cx)
            }
            // The only place the filter runs on the send path; a filter
            // error aborts the send.
            fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
                let item = self.filter.before_send(item)?;
                self.get_mut().sink.start_send_unpin(item)
            }
            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                self.get_mut().sink.poll_flush_unpin(cx)
            }
            fn poll_close(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                self.get_mut().sink.poll_close_unpin(cx)
            }
        }
        Box::new(SinkWrapper {
            sink: self.inner.pin_sink(),
            filter: self.filter.clone(),
        })
    }
    /// A stream that runs `filter.after_received` on every yielded item;
    /// end-of-stream and pending states pass through unchanged.
    fn stream(&self) -> Box<dyn DatagramStream> {
        struct StreamWrapper<F> {
            stream: Pin<Box<dyn DatagramStream>>,
            filter: Arc<F>,
        }
        impl<F> Stream for StreamWrapper<F>
        where
            F: TunnelFilter + Send + Sync + 'static,
        {
            type Item = StreamItem;
            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let self_mut = self.get_mut();
                match self_mut.stream.poll_next_unpin(cx) {
                    Poll::Ready(Some(ret)) => {
                        Poll::Ready(Some(self_mut.filter.after_received(ret)))
                    }
                    Poll::Ready(None) => Poll::Ready(None),
                    Poll::Pending => Poll::Pending,
                }
            }
        }
        Box::new(StreamWrapper {
            stream: self.inner.pin_stream(),
            filter: self.filter.clone(),
        })
    }
    /// Tunnel metadata is taken verbatim from the wrapped tunnel.
    fn info(&self) -> Option<TunnelInfo> {
        self.inner.info()
    }
}
impl<T, F> TunnelWithFilter<T, F>
where
    T: Tunnel + Send + Sync + 'static,
    F: TunnelFilter + Send + Sync + 'static,
{
    /// Wrap `inner` so all traffic passes through `filter`.
    pub fn new(inner: T, filter: Arc<F>) -> Self {
        Self { inner, filter }
    }
}
/// Test helper filter that records a copy of every packet passing through.
///
/// NOTE(review): the field names look swapped relative to the trait impl —
/// `before_send` pushes into `received` and `after_received` into `sent`.
/// The existing unit test asserts on `received` after sending, so renaming
/// the fields requires a coordinated change with the test.
pub struct PacketRecorderTunnelFilter {
    pub received: Arc<std::sync::Mutex<Vec<Bytes>>>,
    pub sent: Arc<std::sync::Mutex<Vec<Bytes>>>,
}
impl TunnelFilter for PacketRecorderTunnelFilter {
    /// Record a copy of each outgoing packet, then pass it through unchanged.
    // NOTE(review): pushes into `received` even though the packet is being
    // sent — see the struct-level note about the swapped field names.
    fn before_send(&self, data: SinkItem) -> Result<SinkItem, SinkError> {
        self.received.lock().unwrap().push(data.clone());
        Ok(data)
    }
    /// Record a copy of each successfully received packet; errors pass
    /// through untouched.
    // NOTE(review): mirrors the swap above — incoming packets land in `sent`.
    fn after_received(&self, data: StreamItem) -> Result<BytesMut, TunnelError> {
        match data {
            Ok(v) => {
                self.sent.lock().unwrap().push(v.clone().into());
                Ok(v)
            }
            Err(e) => Err(e),
        }
    }
}
impl PacketRecorderTunnelFilter {
    /// Create a recorder with empty capture buffers.
    pub fn new() -> Self {
        // Arc<Mutex<Vec<_>>>::default() is Arc::new(Mutex::new(vec![])).
        Self {
            received: Arc::default(),
            sent: Arc::default(),
        }
    }
}
/// Filter that accumulates tx/rx byte counters for a tunnel.
pub struct StatsRecorderTunnelFilter {
    // Shared counters; handed out via get_throughput().
    throughput: Arc<Throughput>,
}
impl TunnelFilter for StatsRecorderTunnelFilter {
    /// Count outgoing bytes, then pass the packet through untouched.
    fn before_send(&self, data: SinkItem) -> Result<SinkItem, SinkError> {
        self.throughput.record_tx_bytes(data.len() as u64);
        Ok(data)
    }

    /// Count incoming bytes on success; errors are forwarded untouched.
    fn after_received(&self, data: StreamItem) -> Result<BytesMut, TunnelError> {
        data.map(|v| {
            self.throughput.record_rx_bytes(v.len() as u64);
            v
        })
    }
}
impl StatsRecorderTunnelFilter {
    /// Create a stats filter with zeroed counters.
    pub fn new() -> Self {
        let throughput = Arc::new(Throughput::new());
        Self { throughput }
    }

    /// Shared handle to the counters; cloning the Arc is a cheap refcount bump.
    pub fn get_throughput(&self) -> Arc<Throughput> {
        Arc::clone(&self.throughput)
    }
}
/// Define a struct bundling several tunnel filters plus a `wrap_tunnel`
/// helper that applies them all.
///
/// `define_tunnel_filter_chain!(Name, a = FilterA, b = FilterB)` generates a
/// `Name` with `Arc`-wrapped fields `a` and `b`, a `new()` constructor, and a
/// `wrap_tunnel()` that nests one `TunnelWithFilter` per field around the
/// given tunnel (listed order = wrapping order).
#[macro_export]
macro_rules! define_tunnel_filter_chain {
    ($type_name:ident $(, $field_name:ident = $filter_type:ty)+) => (
        pub struct $type_name {
            $($field_name: std::sync::Arc<$filter_type>,)+
        }
        impl $type_name {
            pub fn new() -> Self {
                Self {
                    $($field_name: std::sync::Arc::new(<$filter_type>::new()),)+
                }
            }
            // Each repetition shadows `tunnel` with one more filter layer.
            pub fn wrap_tunnel(&self, tunnel: impl Tunnel + 'static) -> impl Tunnel {
                $(
                    let tunnel = crate::tunnels::tunnel_filter::TunnelWithFilter::new(tunnel, self.$field_name.clone());
                )+
                tunnel
            }
        }
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tunnels::ring_tunnel::RingTunnel;

    /// Verify that a chain of three recorder filters each sees a sent packet.
    #[tokio::test]
    async fn test_nested_filter() {
        define_tunnel_filter_chain!(
            Filter,
            a = PacketRecorderTunnelFilter,
            b = PacketRecorderTunnelFilter,
            c = PacketRecorderTunnelFilter
        );
        let filter = Filter::new();
        let tunnel = filter.wrap_tunnel(RingTunnel::new(1));
        let mut s = tunnel.pin_sink();
        s.send(Bytes::from("hello")).await.unwrap();
        // before_send records outgoing packets into `received`.
        assert_eq!(1, filter.a.received.lock().unwrap().len());
        assert_eq!(1, filter.b.received.lock().unwrap().len());
        assert_eq!(1, filter.c.received.lock().unwrap().len());
    }
}
+574
View File
@@ -0,0 +1,574 @@
use std::{fmt::Debug, pin::Pin, sync::Arc};
use async_trait::async_trait;
use dashmap::DashMap;
use easytier_rpc::TunnelInfo;
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use rkyv::{Archive, Deserialize, Serialize};
use std::net::SocketAddr;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use tokio_util::{
bytes::{Buf, Bytes, BytesMut},
udp::UdpFramed,
};
use tracing::Instrument;
use crate::{
common::rkyv_util::{self, encode_to_bytes},
tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
};
use super::{
codec::BytesCodec,
common::{FramedTunnel, TunnelWithCustomInfo},
ring_tunnel::create_ring_tunnel_pair,
DatagramSink, DatagramStream, Tunnel, TunnelListener,
};
// Maximum UDP payload size (bytes) handled by the codec and encode buffers.
pub const UDP_DATA_MTU: usize = 2500;
/// Wire payload of a tunnel UDP datagram.
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub enum UdpPacketPayload {
    /// Connection request (connector -> listener).
    Syn,
    /// Connection acknowledgement (listener -> connector).
    Sack,
    /// NAT hole-punching probe; carries opaque bytes.
    HolePunch(Vec<u8>),
    /// Tunnel data frame.
    Data(Vec<u8>),
}
/// One tunnel datagram: a connection id plus a typed payload.
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))]
pub struct UdpPacket {
    // Random id negotiated in the Syn/Sack handshake; 0 for hole punching.
    pub conn_id: u32,
    pub payload: UdpPacketPayload,
}
impl UdpPacket {
pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
Self {
conn_id,
payload: UdpPacketPayload::Data(data),
}
}
pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
Self {
conn_id: 0,
payload: UdpPacketPayload::HolePunch(data),
}
}
pub fn new_syn_packet(conn_id: u32) -> Self {
Self {
conn_id,
payload: UdpPacketPayload::Syn,
}
}
pub fn new_sack_packet(conn_id: u32) -> Self {
Self {
conn_id,
payload: UdpPacketPayload::Sack,
}
}
}
/// Decode a `UdpPacket` from `buf` and, if it is a `Data` packet for
/// `conn_id`, shrink `buf` in place so it covers exactly the payload bytes.
///
/// Returns `None` (after logging) for undecodable packets, mismatched
/// connection ids, or non-data payloads.
fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
    let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf) else {
        tracing::warn!(?buf, "udp decode error");
        return None;
    };

    // `conn_id` is Copy — the previous `.clone()` here was noise.
    if udp_packet.conn_id != conn_id {
        tracing::warn!(?udp_packet, ?conn_id, "udp conn id not match");
        return None;
    }

    let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else {
        tracing::warn!(?udp_packet, "udp payload not data");
        return None;
    };

    // The archived payload slice aliases `buf` (rkyv accesses the archive in
    // place), so its offset and length inside `buf` can be recovered with
    // pointer arithmetic and used to trim `buf` without copying.
    let ptr_range = payload.as_ptr_range();
    let offset = ptr_range.start as usize - buf.as_ptr() as usize;
    let len = ptr_range.end as usize - ptr_range.start as usize;
    buf.advance(offset);
    buf.truncate(len);
    tracing::trace!(?offset, ?len, ?buf, "udp payload data");
    Some(buf)
}
/// Build a framed tunnel for one (peer address, conn id) on top of a shared
/// UDP socket: incoming datagrams are filtered/unwrapped to their data
/// payload, outgoing bytes are wrapped into `UdpPacket::Data` frames.
fn get_tunnel_from_socket(
    socket: Arc<UdpSocket>,
    addr: SocketAddr,
    conn_id: u32,
) -> Box<dyn super::Tunnel> {
    let udp = UdpFramed::new(socket.clone(), BytesCodec::new(UDP_DATA_MTU));
    let (sink, stream) = udp.split();

    let recv_addr = addr;
    let stream = stream.filter_map(move |v| async move {
        tracing::trace!(?v, "udp stream recv something");
        if v.is_err() {
            tracing::warn!(?v, "udp stream error");
            return Some(Err(super::TunnelError::CommonError(
                "udp stream error".to_owned(),
            )));
        }
        let (buf, addr) = v.unwrap();
        // NOTE(review): a stray datagram from an unexpected peer on this
        // socket would trip this assert and panic the task — confirm whether
        // silently dropping such packets would be safer.
        assert_eq!(addr, recv_addr);
        Some(Ok(try_get_data_payload(buf, conn_id)?))
    });
    let stream = Box::pin(stream);

    let sender_addr = addr;
    let sink = Box::pin(sink.with(move |v: Bytes| async move {
        // TODO: two copy here, how to avoid?
        let udp_packet = UdpPacket::new_data_packet(conn_id, v.to_vec());
        tracing::trace!(?udp_packet, ?v, "udp send packet");
        let v = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
        // The turbofish pins the combinator's error type; the original used a
        // dead `if false { return Err(..) }` branch for the same purpose.
        Ok::<_, super::TunnelError>((v, sender_addr))
    }));

    FramedTunnel::new_tunnel_with_info(
        stream,
        sink,
        // TODO: this remote addr is not a url
        super::TunnelInfo {
            tunnel_type: "udp".to_owned(),
            local_addr: super::build_url_from_socket_addr(
                &socket.local_addr().unwrap().to_string(),
                "udp",
            )
            .into(),
            remote_addr: super::build_url_from_socket_addr(&addr.to_string(), "udp").into(),
        },
    )
}
/// Stream, sink, and connection id for one accepted UDP "connection"; the
/// listener's forward task uses it to push decoded payloads to the consumer.
struct StreamSinkPair(
    Pin<Box<dyn DatagramStream>>,
    Pin<Box<dyn DatagramSink>>,
    // conn_id negotiated in the Syn/Sack handshake.
    u32,
);
type ArcStreamSinkPair = Arc<Mutex<StreamSinkPair>>;
/// Emulates connection-oriented tunnels over a single UDP socket: a spawned
/// task demultiplexes incoming datagrams by sender address.
pub struct UdpTunnelListener {
    addr: url::Url,
    // Bound by listen(); shared with the demultiplexing task.
    socket: Option<Arc<UdpSocket>>,
    // Per-peer stream/sink pairs, keyed by remote address.
    sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
    forward_tasks: Arc<Mutex<JoinSet<()>>>,
    // accept() pulls fully-handshaked tunnels from this channel.
    conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>,
    // Moved into the demux task by listen(); None afterwards.
    conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>,
}
impl UdpTunnelListener {
pub fn new(addr: url::Url) -> Self {
let (conn_send, conn_recv) = tokio::sync::mpsc::channel(100);
Self {
addr,
socket: None,
sock_map: Arc::new(DashMap::new()),
forward_tasks: Arc::new(Mutex::new(JoinSet::new())),
conn_recv,
conn_send: Some(conn_send),
}
}
async fn try_forward_packet(
sock_map: &DashMap<SocketAddr, ArcStreamSinkPair>,
buf: BytesMut,
addr: SocketAddr,
) -> Result<(), super::TunnelError> {
let entry = sock_map.get_mut(&addr);
if entry.is_none() {
log::warn!("udp forward packet: {:?}, {:?}, no entry", addr, buf);
return Ok(());
}
log::trace!("udp forward packet: {:?}, {:?}", addr, buf);
let entry = entry.unwrap();
let pair = entry.value().clone();
drop(entry);
let Some(buf) = try_get_data_payload(buf, pair.lock().await.2) else {
return Ok(());
};
pair.lock().await.1.send(buf.freeze()).await?;
Ok(())
}
async fn handle_connect(
socket: Arc<UdpSocket>,
addr: SocketAddr,
forward_tasks: Arc<Mutex<JoinSet<()>>>,
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
local_url: url::Url,
conn_id: u32,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
tracing::info!(?conn_id, ?addr, "udp connection accept handling",);
let udp_packet = UdpPacket::new_sack_packet(conn_id);
let sack_buf = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
socket.send_to(&sack_buf, addr).await?;
let (ctunnel, stunnel) = create_ring_tunnel_pair();
let udp_tunnel = get_tunnel_from_socket(socket.clone(), addr, conn_id);
let ss_pair = StreamSinkPair(ctunnel.pin_stream(), ctunnel.pin_sink(), conn_id);
let addr_copy = addr.clone();
sock_map.insert(addr, Arc::new(Mutex::new(ss_pair)));
let ctunnel_stream = ctunnel.pin_stream();
forward_tasks.lock().await.spawn(async move {
let ret = ctunnel_stream
.map(|v| {
tracing::trace!(?v, "udp stream recv something in forward task");
if v.is_err() {
return Err(super::TunnelError::CommonError(
"udp stream error".to_owned(),
));
}
Ok(v.unwrap().freeze())
})
.forward(udp_tunnel.pin_sink())
.await;
if let None = sock_map.remove(&addr_copy) {
log::warn!("udp forward packet: {:?}, no entry", addr_copy);
}
close_tunnel(&ctunnel).await.unwrap();
log::warn!("udp connection forward done: {:?}, {:?}", addr_copy, ret);
});
Ok(Box::new(TunnelWithCustomInfo::new(
stunnel,
TunnelInfo {
tunnel_type: "udp".to_owned(),
local_addr: local_url.into(),
remote_addr: build_url_from_socket_addr(&addr.to_string(), "udp").into(),
},
)))
}
pub fn get_socket(&self) -> Option<Arc<UdpSocket>> {
self.socket.clone()
}
}
#[async_trait]
impl TunnelListener for UdpTunnelListener {
    /// Bind the UDP socket and spawn the demultiplexing loop.
    ///
    /// The loop reads every datagram: a `Syn` packet starts the accept
    /// handshake (pushing a new tunnel to the accept channel); everything
    /// else is forwarded to the stream/sink pair registered for the sender.
    ///
    /// NOTE(review): the loop unwraps `recv_from`, `handle_connect` and
    /// `try_forward_packet` results, so a transient socket or handshake error
    /// panics the demux task — confirm whether this fail-fast is intended.
    /// Also, `conn_send.take().unwrap()` means calling listen() twice panics.
    async fn listen(&mut self) -> Result<(), super::TunnelError> {
        let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
        self.socket = Some(Arc::new(UdpSocket::bind(addr).await?));
        let socket = self.socket.as_ref().unwrap().clone();
        let forward_tasks = self.forward_tasks.clone();
        let sock_map = self.sock_map.clone();
        let conn_send = self.conn_send.take().unwrap();
        let local_url = self.local_url().clone();
        self.forward_tasks.lock().await.spawn(
            async move {
                loop {
                    // One MTU-sized buffer per datagram, trimmed to the
                    // received size below.
                    let mut buf = BytesMut::new();
                    buf.resize(2500, 0);
                    let (_size, addr) = socket.recv_from(&mut buf).await.unwrap();
                    let _ = buf.split_off(_size);
                    log::trace!(
                        "udp recv packet: {:?}, buf: {:?}, size: {}",
                        addr,
                        buf,
                        _size
                    );
                    let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf)
                    else {
                        tracing::warn!(?buf, "udp decode error in forward task");
                        continue;
                    };
                    if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
                        // New peer: run the Sack handshake and queue the
                        // resulting tunnel for accept().
                        let conn = Self::handle_connect(
                            socket.clone(),
                            addr,
                            forward_tasks.clone(),
                            sock_map.clone(),
                            local_url.clone(),
                            udp_packet.conn_id.into(),
                        )
                        .await
                        .unwrap();
                        if let Err(e) = conn_send.send(conn).await {
                            tracing::warn!(?e, "udp send conn to accept channel error");
                        }
                    } else {
                        Self::try_forward_packet(sock_map.as_ref(), buf, addr)
                            .await
                            .unwrap();
                    }
                }
            }
            .instrument(tracing::info_span!("udp forward task", ?self.socket)),
        );
        // let forward_tasks_clone = self.forward_tasks.clone();
        // tokio::spawn(async move {
        //     loop {
        //         let mut locked_forward_tasks = forward_tasks_clone.lock().await;
        //         tokio::select! {
        //             ret = locked_forward_tasks.join_next() => {
        //                 tracing::warn!(?ret, "udp forward task exit");
        //             }
        //             else => {
        //                 drop(locked_forward_tasks);
        //                 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        //                 continue;
        //             }
        //         }
        //     }
        // });
        Ok(())
    }
    /// Wait for the next fully-handshaked tunnel from the demux task; errors
    /// only if the channel closes (i.e. the demux task is gone).
    async fn accept(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
        log::info!("start udp accept: {:?}", self.addr);
        while let Some(conn) = self.conn_recv.recv().await {
            return Ok(conn);
        }
        return Err(super::TunnelError::CommonError(
            "udp accept error".to_owned(),
        ));
    }
    /// The URL this listener was configured with.
    fn local_url(&self) -> url::Url {
        self.addr.clone()
    }
    /// Connection counter backed by the live entries in `sock_map`.
    fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
        struct UdpTunnelConnCounter {
            sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
        }
        impl Debug for UdpTunnelConnCounter {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct("UdpTunnelConnCounter")
                    .field("sock_map_len", &self.sock_map.len())
                    .finish()
            }
        }
        impl TunnelConnCounter for UdpTunnelConnCounter {
            fn get(&self) -> u32 {
                self.sock_map.len() as u32
            }
        }
        Arc::new(Box::new(UdpTunnelConnCounter {
            sock_map: self.sock_map.clone(),
        }))
    }
}
/// Dials a remote `udp://` URL using the Syn/Sack handshake.
pub struct UdpTunnelConnector {
    addr: url::Url,
    // When non-empty, one handshake attempt is raced per bind address.
    bind_addrs: Vec<SocketAddr>,
}
impl UdpTunnelConnector {
pub fn new(addr: url::Url) -> Self {
Self {
addr,
bind_addrs: vec![],
}
}
async fn wait_sack(
socket: &UdpSocket,
addr: SocketAddr,
conn_id: u32,
) -> Result<(), super::TunnelError> {
let mut buf = BytesMut::new();
buf.resize(128, 0);
let (usize, recv_addr) = tokio::time::timeout(
tokio::time::Duration::from_secs(3),
socket.recv_from(&mut buf),
)
.await??;
if recv_addr != addr {
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, unexpected sack addr: {:?}, {:?}",
recv_addr, addr
)));
}
let _ = buf.split_off(usize);
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf) else {
tracing::warn!(?buf, "udp decode error in wait sack");
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, decode error. buf: {:?}",
buf
)));
};
if conn_id != udp_packet.conn_id {
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, conn id not match. conn_id: {:?}, {:?}",
conn_id, udp_packet.conn_id
)));
}
if !matches!(udp_packet.payload, ArchivedUdpPacketPayload::Sack) {
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, unexpected payload. payload: {:?}",
udp_packet.payload
)));
}
Ok(())
}
async fn wait_sack_loop(
socket: &UdpSocket,
addr: SocketAddr,
conn_id: u32,
) -> Result<(), super::TunnelError> {
while let Err(err) = Self::wait_sack(socket, addr, conn_id).await {
tracing::warn!(?err, "udp wait sack error");
}
Ok(())
}
pub async fn try_connect_with_socket(
&self,
socket: UdpSocket,
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
log::warn!("udp connect: {:?}", self.addr);
// send syn
let conn_id = rand::random();
let udp_packet = UdpPacket::new_syn_packet(conn_id);
let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
let ret = socket.send_to(&b, &addr).await?;
tracing::warn!(?udp_packet, ?ret, "udp send syn");
// wait sack
tokio::time::timeout(
tokio::time::Duration::from_secs(3),
Self::wait_sack_loop(&socket, addr, conn_id),
)
.await??;
// sack done
let local_addr = socket.local_addr().unwrap().to_string();
Ok(Box::new(TunnelWithCustomInfo::new(
get_tunnel_from_socket(Arc::new(socket), addr, conn_id),
TunnelInfo {
tunnel_type: "udp".to_owned(),
local_addr: super::build_url_from_socket_addr(&local_addr, "udp").into(),
remote_addr: self.remote_url().into(),
},
)))
}
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
return self.try_connect_with_socket(socket).await;
}
async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let mut futures = FuturesUnordered::new();
for bind_addr in self.bind_addrs.iter() {
let socket = UdpSocket::bind(*bind_addr).await?;
// linux does not use interface of bind_addr to send packet, so we need to bind device
// mac can handle this with bind correctly
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
tracing::trace!(dev_name = ?dev_name, "bind device");
socket.bind_device(Some(dev_name.as_bytes()))?;
}
futures.push(self.try_connect_with_socket(socket));
}
let Some(ret) = futures.next().await else {
return Err(super::TunnelError::CommonError(
"join connect futures failed".to_owned(),
));
};
return ret;
}
}
#[async_trait]
impl super::TunnelConnector for UdpTunnelConnector {
    /// Connect to the remote peer: race the configured bind addresses when
    /// any are present, otherwise use a wildcard socket.
    async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
        let has_custom_binds = !self.bind_addrs.is_empty();
        if has_custom_binds {
            self.connect_with_custom_bind().await
        } else {
            self.connect_with_default_bind().await
        }
    }

    /// The remote URL this connector dials.
    fn remote_url(&self) -> url::Url {
        self.addr.clone()
    }

    /// Restrict outgoing handshakes to the given local addresses.
    fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
        self.bind_addrs = addrs;
    }
}
#[cfg(test)]
mod tests {
    use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};

    use super::*;

    /// Round-trip a payload through a UDP listener/connector pair.
    #[tokio::test]
    async fn udp_pingpong() {
        let server = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap());
        let client = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap());
        _tunnel_pingpong(server, client).await
    }

    /// Throughput benchmark over the UDP tunnel.
    #[tokio::test]
    async fn udp_bench() {
        let server = UdpTunnelListener::new("udp://0.0.0.0:5555".parse().unwrap());
        let client = UdpTunnelConnector::new("udp://127.0.0.1:5555".parse().unwrap());
        _tunnel_bench(server, client).await
    }

    /// Ping-pong with an explicit loopback bind address on the connector.
    #[tokio::test]
    async fn udp_bench_with_bind() {
        let server = UdpTunnelListener::new("udp://127.0.0.1:5554".parse().unwrap());
        let mut client = UdpTunnelConnector::new("udp://127.0.0.1:5554".parse().unwrap());
        client.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(server, client).await
    }

    /// Binding an address not owned by this host must make connect fail.
    #[tokio::test]
    #[should_panic]
    async fn udp_bench_with_bind_fail() {
        let server = UdpTunnelListener::new("udp://127.0.0.1:5553".parse().unwrap());
        let mut client = UdpTunnelConnector::new("udp://127.0.0.1:5553".parse().unwrap());
        client.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(server, client).await
    }
}