fix machine uid and easytier-web panic (#2215)

1. fix(web-client): persist and migrate machine id
2. fix panic when easytier-web session receive malformat packet
This commit is contained in:
KKRainbow
2026-05-07 00:57:42 +08:00
committed by GitHub
parent 4342c8d7a2
commit baeee40b79
13 changed files with 725 additions and 100 deletions
-2
View File
@@ -23,8 +23,6 @@ define_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS, u64, 1000);
define_global_var!(OSPF_UPDATE_MY_GLOBAL_FOREIGN_NETWORK_INTERVAL_SEC, u64, 10);
define_global_var!(MACHINE_UID, Option<String>, None);
define_global_var!(MAX_DIRECT_CONNS_PER_PEER_IN_FOREIGN_NETWORK, u32, 3);
define_global_var!(DIRECT_CONNECT_TO_PUBLIC_SERVER, bool, true);
+596
View File
@@ -0,0 +1,596 @@
use std::{
env,
ffi::OsString,
io::Write as _,
path::{Path, PathBuf},
time::{Duration, Instant},
};
use anyhow::Context as _;
#[cfg(unix)]
use nix::{
errno::Errno,
fcntl::{Flock, FlockArg},
};
#[derive(Debug, Clone, Default)]
pub struct MachineIdOptions {
pub explicit_machine_id: Option<String>,
pub state_dir: Option<PathBuf>,
}
pub fn resolve_machine_id(opts: &MachineIdOptions) -> anyhow::Result<uuid::Uuid> {
if let Some(explicit_machine_id) = opts.explicit_machine_id.as_deref() {
return Ok(parse_or_hash_machine_id(explicit_machine_id));
}
let state_file = resolve_machine_id_state_file(opts.state_dir.as_deref())?;
let allow_legacy_machine_uid_migration =
should_attempt_legacy_machine_uid_migration(&state_file);
if let Some(machine_id) = read_state_machine_id(&state_file)? {
return Ok(machine_id);
}
if let Some(machine_id) = read_legacy_machine_id_file() {
return persist_machine_id(&state_file, machine_id);
}
if allow_legacy_machine_uid_migration
&& let Some(machine_id) = resolve_legacy_machine_uid_hash()
{
return persist_machine_id(&state_file, machine_id);
}
let machine_id = resolve_new_machine_id().unwrap_or_else(uuid::Uuid::new_v4);
persist_machine_id(&state_file, machine_id)
}
fn parse_or_hash_machine_id(raw: &str) -> uuid::Uuid {
if let Ok(mid) = uuid::Uuid::parse_str(raw.trim()) {
return mid;
}
digest_uuid_from_str(raw)
}
fn digest_uuid_from_str(raw: &str) -> uuid::Uuid {
let mut b = [0u8; 16];
crate::tunnel::generate_digest_from_str("", raw, &mut b);
uuid::Uuid::from_bytes(b)
}
fn resolve_machine_id_state_file(state_dir: Option<&Path>) -> anyhow::Result<PathBuf> {
let state_dir = match state_dir {
Some(dir) => dir.to_path_buf(),
None => default_machine_id_state_dir()?,
};
Ok(state_dir.join("machine_id"))
}
fn non_empty_os_string(value: Option<OsString>) -> Option<OsString> {
value.filter(|value| !value.is_empty())
}
#[cfg(target_os = "linux")]
fn default_linux_machine_id_state_dir(
xdg_data_home: Option<OsString>,
home: Option<OsString>,
) -> PathBuf {
if let Some(path) = non_empty_os_string(xdg_data_home) {
return PathBuf::from(path).join("easytier");
}
if let Some(home) = non_empty_os_string(home) {
return PathBuf::from(home)
.join(".local")
.join("share")
.join("easytier");
}
PathBuf::from("/var/lib/easytier")
}
fn default_machine_id_state_dir() -> anyhow::Result<PathBuf> {
cfg_select! {
target_os = "linux" => Ok(default_linux_machine_id_state_dir(
env::var_os("XDG_DATA_HOME"),
env::var_os("HOME"),
)),
all(target_os = "macos", not(feature = "macos-ne")) => {
let home = non_empty_os_string(env::var_os("HOME"))
.ok_or_else(|| anyhow::anyhow!("HOME is not set, cannot resolve machine id state directory"))?;
Ok(PathBuf::from(home)
.join("Library")
.join("Application Support")
.join("com.easytier"))
},
target_os = "windows" => {
let local_app_data = non_empty_os_string(env::var_os("LOCALAPPDATA")).ok_or_else(|| {
anyhow::anyhow!("LOCALAPPDATA is not set, cannot resolve machine id state directory")
})?;
Ok(PathBuf::from(local_app_data).join("easytier"))
},
target_os = "freebsd" => {
let home = non_empty_os_string(env::var_os("HOME"))
.ok_or_else(|| anyhow::anyhow!("HOME is not set, cannot resolve machine id state directory"))?;
Ok(PathBuf::from(home).join(".local").join("share").join("easytier"))
},
target_os = "android" => {
anyhow::bail!("machine id state directory must be provided explicitly on Android");
},
_ => anyhow::bail!("machine id state directory is unsupported on this platform"),
}
}
fn read_state_machine_id(path: &Path) -> anyhow::Result<Option<uuid::Uuid>> {
let Some(contents) = read_optional_file(path)? else {
return Ok(None);
};
let machine_id = uuid::Uuid::parse_str(contents.trim())
.with_context(|| format!("invalid machine id in state file {}", path.display()))?;
Ok(Some(machine_id))
}
fn read_legacy_machine_id_file() -> Option<uuid::Uuid> {
let path = legacy_machine_id_file_path()?;
read_legacy_machine_id_file_at(&path)
}
fn read_legacy_machine_id_file_at(path: &Path) -> Option<uuid::Uuid> {
let contents = match std::fs::read_to_string(path) {
Ok(contents) => contents,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return None,
Err(err) => {
tracing::warn!(
path = %path.display(),
%err,
"ignoring unreadable legacy machine id file"
);
return None;
}
};
match uuid::Uuid::parse_str(contents.trim()) {
Ok(machine_id) => Some(machine_id),
Err(err) => {
tracing::warn!(
path = %path.display(),
%err,
"ignoring invalid legacy machine id file"
);
None
}
}
}
fn legacy_machine_id_file_path() -> Option<PathBuf> {
std::env::current_exe()
.ok()
.map(|path| path.with_file_name("et_machine_id"))
}
fn read_optional_file(path: &Path) -> anyhow::Result<Option<String>> {
match std::fs::read_to_string(path) {
Ok(contents) => Ok(Some(contents)),
Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
Err(err) => Err(err).with_context(|| format!("failed to read {}", path.display())),
}
}
fn should_attempt_legacy_machine_uid_migration(state_file: &Path) -> bool {
let Some(state_dir) = state_file.parent() else {
return false;
};
let Ok(mut entries) = std::fs::read_dir(state_dir) else {
return false;
};
entries.any(|entry| entry.is_ok())
}
fn resolve_legacy_machine_uid_hash() -> Option<uuid::Uuid> {
machine_uid_seed().map(|seed| digest_uuid_from_str(seed.as_str()))
}
fn resolve_new_machine_id() -> Option<uuid::Uuid> {
let seed = machine_uid_seed()?;
#[cfg(target_os = "linux")]
{
let seed = linux_machine_id_seed(&seed);
Some(digest_uuid_from_str(&seed))
}
#[cfg(not(target_os = "linux"))]
{
Some(digest_uuid_from_str(&seed))
}
}
#[cfg(any(
target_os = "linux",
all(target_os = "macos", not(feature = "macos-ne")),
target_os = "windows",
target_os = "freebsd"
))]
fn machine_uid_seed() -> Option<String> {
machine_uid::get()
.ok()
.filter(|value| !value.trim().is_empty())
}
#[cfg(not(any(
target_os = "linux",
all(target_os = "macos", not(feature = "macos-ne")),
target_os = "windows",
target_os = "freebsd"
)))]
fn machine_uid_seed() -> Option<String> {
None
}
#[cfg(target_os = "linux")]
fn linux_machine_id_seed(machine_uid: &str) -> String {
let mut seed = format!("machine_uid={machine_uid}");
let hostname = gethostname::gethostname()
.to_string_lossy()
.trim()
.to_string();
if !hostname.is_empty() {
seed.push_str("\nhostname=");
seed.push_str(&hostname);
}
let mac_addresses = collect_linux_mac_addresses();
if !mac_addresses.is_empty() {
seed.push_str("\nmacs=");
seed.push_str(&mac_addresses.join(","));
}
seed
}
#[cfg(target_os = "linux")]
fn collect_linux_mac_addresses() -> Vec<String> {
let mut macs = Vec::new();
let Ok(entries) = std::fs::read_dir("/sys/class/net") else {
return macs;
};
for entry in entries.flatten() {
let Ok(name) = entry.file_name().into_string() else {
continue;
};
if name == "lo" {
continue;
}
let address_path = entry.path().join("address");
let Ok(address) = std::fs::read_to_string(address_path) else {
continue;
};
let address = address.trim().to_ascii_lowercase();
if address.is_empty() || address == "00:00:00:00:00:00" {
continue;
}
macs.push(address);
}
macs.sort();
macs.dedup();
macs.truncate(3);
macs
}
fn persist_machine_id(path: &Path, machine_id: uuid::Uuid) -> anyhow::Result<uuid::Uuid> {
if let Some(existing) = read_state_machine_id(path)? {
return Ok(existing);
}
let _lock = MachineIdWriteLock::acquire(path)?;
if let Some(existing) = read_state_machine_id(path)? {
return Ok(existing);
}
write_uuid_file_atomically(path, machine_id)?;
Ok(machine_id)
}
fn write_uuid_file_atomically(path: &Path, machine_id: uuid::Uuid) -> anyhow::Result<()> {
let parent = path.parent().ok_or_else(|| {
anyhow::anyhow!(
"machine id state file {} has no parent directory",
path.display()
)
})?;
std::fs::create_dir_all(parent).with_context(|| {
format!(
"failed to create machine id state directory {}",
parent.display()
)
})?;
let tmp_path = parent.join(format!(
".machine_id.tmp-{}-{}",
std::process::id(),
uuid::Uuid::new_v4()
));
{
let mut file = std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.open(&tmp_path)
.with_context(|| format!("failed to create {}", tmp_path.display()))?;
file.write_all(machine_id.to_string().as_bytes())
.with_context(|| format!("failed to write {}", tmp_path.display()))?;
file.sync_all()
.with_context(|| format!("failed to flush {}", tmp_path.display()))?;
}
if let Err(err) = std::fs::rename(&tmp_path, path) {
let _ = std::fs::remove_file(&tmp_path);
return Err(err).with_context(|| {
format!(
"failed to move machine id state file into place at {}",
path.display()
)
});
}
Ok(())
}
struct MachineIdWriteLock {
#[cfg(unix)]
_lock: Flock<std::fs::File>,
#[cfg(not(unix))]
path: PathBuf,
}
impl MachineIdWriteLock {
fn acquire(path: &Path) -> anyhow::Result<Self> {
let parent = path.parent().ok_or_else(|| {
anyhow::anyhow!(
"machine id state file {} has no parent directory",
path.display()
)
})?;
std::fs::create_dir_all(parent).with_context(|| {
format!(
"failed to create machine id state directory {}",
parent.display()
)
})?;
#[cfg(unix)]
{
Self::acquire_unix(path)
}
#[cfg(not(unix))]
{
Self::acquire_fallback(path)
}
}
#[cfg(unix)]
fn acquire_unix(path: &Path) -> anyhow::Result<Self> {
let lock_path = path.with_extension("lock");
let deadline = Instant::now() + Duration::from_secs(5);
let mut lock_file = std::fs::OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(&lock_path)
.with_context(|| format!("failed to open machine id lock {}", lock_path.display()))?;
loop {
match Flock::lock(lock_file, FlockArg::LockExclusiveNonblock) {
Ok(lock) => return Ok(Self { _lock: lock }),
Err((file, Errno::EAGAIN)) => {
if Instant::now() >= deadline {
anyhow::bail!(
"timed out waiting for machine id lock {}",
lock_path.display()
);
}
lock_file = file;
std::thread::sleep(Duration::from_millis(50));
}
Err((_file, err)) => {
anyhow::bail!(
"failed to acquire machine id lock {}: {}",
lock_path.display(),
err
);
}
}
}
}
#[cfg(not(unix))]
fn acquire_fallback(path: &Path) -> anyhow::Result<Self> {
let lock_path = path.with_extension("lock");
let deadline = Instant::now() + Duration::from_secs(5);
loop {
match std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.open(&lock_path)
{
Ok(mut file) => {
writeln!(file, "pid={}", std::process::id()).ok();
return Ok(Self { path: lock_path });
}
Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
if should_reap_stale_lock_file(&lock_path) {
let _ = std::fs::remove_file(&lock_path);
continue;
}
if Instant::now() >= deadline {
anyhow::bail!(
"timed out waiting for machine id lock {}",
lock_path.display()
);
}
std::thread::sleep(Duration::from_millis(50));
}
Err(err) => {
return Err(err).with_context(|| {
format!("failed to acquire machine id lock {}", lock_path.display())
});
}
}
}
}
}
#[cfg(not(unix))]
fn should_reap_stale_lock_file(lock_path: &Path) -> bool {
const STALE_LOCK_AGE: Duration = Duration::from_secs(30);
let Ok(metadata) = std::fs::metadata(lock_path) else {
return false;
};
let Ok(modified) = metadata.modified() else {
return false;
};
modified
.elapsed()
.is_ok_and(|elapsed| elapsed >= STALE_LOCK_AGE)
}
impl Drop for MachineIdWriteLock {
fn drop(&mut self) {
#[cfg(not(unix))]
let _ = std::fs::remove_file(&self.path);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_resolve_machine_id_uses_uuid_seed_verbatim() {
let raw = "33333333-3333-3333-3333-333333333333".to_string();
let opts = MachineIdOptions {
explicit_machine_id: Some(raw.clone()),
state_dir: None,
};
assert_eq!(
resolve_machine_id(&opts).unwrap(),
uuid::Uuid::parse_str(&raw).unwrap()
);
}
#[test]
fn test_resolve_machine_id_reads_state_file() {
let temp_dir = tempfile::tempdir().unwrap();
let expected = uuid::Uuid::new_v4();
std::fs::write(temp_dir.path().join("machine_id"), expected.to_string()).unwrap();
let opts = MachineIdOptions {
explicit_machine_id: None,
state_dir: Some(temp_dir.path().to_path_buf()),
};
assert_eq!(resolve_machine_id(&opts).unwrap(), expected);
}
#[test]
fn test_read_legacy_machine_id_file_ignores_read_errors() {
let temp_dir = tempfile::tempdir().unwrap();
assert_eq!(read_legacy_machine_id_file_at(temp_dir.path()), None);
}
#[test]
fn test_write_uuid_file_atomically_writes_expected_contents() {
let temp_dir = tempfile::tempdir().unwrap();
let machine_id = uuid::Uuid::new_v4();
let state_file = temp_dir.path().join("machine_id");
write_uuid_file_atomically(&state_file, machine_id).unwrap();
assert_eq!(
std::fs::read_to_string(state_file).unwrap(),
machine_id.to_string()
);
}
#[test]
fn test_non_empty_os_string_filters_empty_values() {
assert_eq!(non_empty_os_string(Some(OsString::new())), None);
assert_eq!(
non_empty_os_string(Some(OsString::from("foo"))),
Some(OsString::from("foo"))
);
}
#[cfg(target_os = "linux")]
#[test]
fn test_default_linux_machine_id_state_dir_falls_back_in_order() {
assert_eq!(
default_linux_machine_id_state_dir(
Some(OsString::from("/tmp/xdg")),
Some(OsString::from("/tmp/home"))
),
PathBuf::from("/tmp/xdg").join("easytier")
);
assert_eq!(
default_linux_machine_id_state_dir(
Some(OsString::new()),
Some(OsString::from("/tmp/home"))
),
PathBuf::from("/tmp/home")
.join(".local")
.join("share")
.join("easytier")
);
assert_eq!(
default_linux_machine_id_state_dir(Some(OsString::new()), Some(OsString::new())),
PathBuf::from("/var/lib/easytier")
);
}
#[test]
fn test_persist_machine_id_creates_missing_state_dir() {
let temp_dir = tempfile::tempdir().unwrap();
let state_file = temp_dir.path().join("nested").join("machine_id");
let machine_id = uuid::Uuid::new_v4();
assert_eq!(
persist_machine_id(&state_file, machine_id).unwrap(),
machine_id
);
assert_eq!(
std::fs::read_to_string(state_file).unwrap(),
machine_id.to_string()
);
}
#[test]
fn test_legacy_machine_uid_migration_requires_existing_state_dir_content() {
let temp_dir = tempfile::tempdir().unwrap();
let missing_state_file = temp_dir.path().join("missing").join("machine_id");
assert!(!should_attempt_legacy_machine_uid_migration(
&missing_state_file
));
let empty_dir = temp_dir.path().join("empty");
std::fs::create_dir_all(&empty_dir).unwrap();
assert!(!should_attempt_legacy_machine_uid_migration(
&empty_dir.join("machine_id")
));
std::fs::write(empty_dir.join("config.toml"), "x=1").unwrap();
assert!(should_attempt_legacy_machine_uid_migration(
&empty_dir.join("machine_id")
));
}
}
+3 -76
View File
@@ -1,15 +1,12 @@
use std::{
fmt::Debug,
future,
io::Write as _,
sync::{Arc, Mutex},
};
use time::util::refresh_tz;
use tokio::{task::JoinSet, time::timeout};
use tracing::Instrument;
use crate::{set_global_var, use_global_var};
pub mod acl_processor;
pub mod compressor;
pub mod config;
@@ -21,6 +18,7 @@ pub mod global_ctx;
pub mod idn;
pub mod ifcfg;
pub mod log;
pub mod machine_id;
pub mod netns;
pub mod network;
pub mod os_info;
@@ -31,6 +29,8 @@ pub mod token_bucket;
pub mod tracing_rolling_appender;
pub mod upnp;
pub use machine_id::{MachineIdOptions, resolve_machine_id};
pub fn get_logger_timer<F: time::formatting::Formattable>(
format: F,
) -> tracing_subscriber::fmt::time::OffsetTime<F> {
@@ -96,71 +96,6 @@ pub fn join_joinset_background<T: Debug + Send + Sync + 'static>(
);
}
pub fn set_default_machine_id(mid: Option<String>) {
set_global_var!(MACHINE_UID, mid);
}
pub fn get_machine_id() -> uuid::Uuid {
if let Some(default_mid) = use_global_var!(MACHINE_UID) {
if let Ok(mid) = uuid::Uuid::parse_str(default_mid.trim()) {
return mid;
}
let mut b = [0u8; 16];
crate::tunnel::generate_digest_from_str("", &default_mid, &mut b);
return uuid::Uuid::from_bytes(b);
}
// a path same as the binary
let machine_id_file = std::env::current_exe()
.map(|x| x.with_file_name("et_machine_id"))
.unwrap_or_else(|_| std::path::PathBuf::from("et_machine_id"));
// try load from local file
if let Ok(mid) = std::fs::read_to_string(&machine_id_file)
&& let Ok(mid) = uuid::Uuid::parse_str(mid.trim())
{
return mid;
}
#[cfg(any(
target_os = "linux",
all(target_os = "macos", not(feature = "macos-ne")),
target_os = "windows",
target_os = "freebsd"
))]
let gen_mid = machine_uid::get()
.map(|x| {
if x.is_empty() {
return uuid::Uuid::new_v4();
}
let mut b = [0u8; 16];
crate::tunnel::generate_digest_from_str("", x.as_str(), &mut b);
uuid::Uuid::from_bytes(b)
})
.ok();
#[cfg(not(any(
target_os = "linux",
all(target_os = "macos", not(feature = "macos-ne")),
target_os = "windows",
target_os = "freebsd"
)))]
let gen_mid = None;
if let Some(mid) = gen_mid {
return mid;
}
let gen_mid = uuid::Uuid::new_v4();
// try save to local file
if let Ok(mut file) = std::fs::File::create(machine_id_file) {
let _ = file.write_all(gen_mid.to_string().as_bytes());
}
gen_mid
}
pub fn shrink_dashmap<K: Eq + std::hash::Hash, V>(
map: &dashmap::DashMap<K, V>,
threshold: Option<usize>,
@@ -210,12 +145,4 @@ mod tests {
assert_eq!(weak_js.weak_count(), 0);
assert_eq!(weak_js.strong_count(), 0);
}
#[test]
fn test_get_machine_id_uses_uuid_seed_verbatim() {
let raw = "33333333-3333-3333-3333-333333333333".to_string();
set_default_machine_id(Some(raw.clone()));
assert_eq!(get_machine_id(), uuid::Uuid::parse_str(&raw).unwrap());
set_default_machine_id(None);
}
}
+4 -1
View File
@@ -1336,7 +1336,10 @@ async fn run_main(cli: Cli) -> anyhow::Result<()> {
let _web_client = if let Some(config_server_url_s) = cli.config_server.as_ref() {
let wc = web_client::run_web_client(
config_server_url_s,
cli.machine_id.clone(),
crate::common::MachineIdOptions {
explicit_machine_id: cli.machine_id.clone(),
state_dir: None,
},
cli.network_options.hostname.clone(),
cli.network_options.secure_mode.unwrap_or(false),
manager.clone(),
+26
View File
@@ -115,6 +115,12 @@ impl<R> FramedReader<R> {
return Some(Err(TunnelError::InvalidPacket("body too long".to_string())));
}
if body_len < PEER_MANAGER_HEADER_SIZE {
return Some(Err(TunnelError::InvalidPacket(
"body too short".to_string(),
)));
}
if buf.len() < TCP_TUNNEL_HEADER_SIZE + body_len {
// body is not complete
return None;
@@ -555,6 +561,26 @@ pub mod tests {
tunnel::{TunnelConnector, TunnelListener, packet_def::ZCPacket},
};
#[cfg(test)]
use crate::tunnel::{
TunnelError,
packet_def::{PEER_MANAGER_HEADER_SIZE, TCP_TUNNEL_HEADER_SIZE},
};
#[test]
fn framed_reader_rejects_short_peer_manager_body() {
let mut buf = BytesMut::new();
buf.put_u32_le((PEER_MANAGER_HEADER_SIZE - 1) as u32);
buf.resize(TCP_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE - 1, 0);
let ret = super::FramedReader::<tokio::io::Empty>::extract_one_packet(&mut buf, 2000);
assert!(matches!(
ret,
Some(Err(TunnelError::InvalidPacket(msg))) if msg == "body too short"
));
}
pub async fn _tunnel_echo_server(tunnel: Box<dyn super::Tunnel>, once: bool) {
let (mut recv, mut send) = tunnel.split();
+7
View File
@@ -9,6 +9,7 @@ use crate::{
pub struct Controller {
token: String,
machine_id: uuid::Uuid,
hostname: String,
device_os: DeviceOsInfo,
manager: Arc<NetworkInstanceManager>,
@@ -18,6 +19,7 @@ pub struct Controller {
impl Controller {
pub fn new(
token: String,
machine_id: uuid::Uuid,
hostname: String,
device_os: DeviceOsInfo,
manager: Arc<NetworkInstanceManager>,
@@ -25,6 +27,7 @@ impl Controller {
) -> Self {
Controller {
token,
machine_id,
hostname,
device_os,
manager,
@@ -44,6 +47,10 @@ impl Controller {
self.hostname.clone()
}
pub fn machine_id(&self) -> uuid::Uuid {
self.machine_id
}
pub fn device_os(&self) -> DeviceOsInfo {
self.device_os.clone()
}
+19 -6
View File
@@ -2,11 +2,12 @@ use std::sync::Arc;
use crate::{
common::{
MachineIdOptions,
config::TomlConfigLoader,
global_ctx::{ArcGlobalCtx, GlobalCtx},
log,
os_info::collect_device_os_info,
set_default_machine_id,
resolve_machine_id,
stun::MockStunInfoCollector,
},
connector::create_connector_by_url,
@@ -81,6 +82,7 @@ impl WebClient {
pub fn new<T: TunnelConnector + 'static, S: ToString, H: ToString>(
connector: T,
token: S,
machine_id: Uuid,
hostname: H,
secure_mode: bool,
manager: Arc<NetworkInstanceManager>,
@@ -90,6 +92,7 @@ impl WebClient {
let hooks = hooks.unwrap_or_else(|| Arc::new(DefaultHooks));
let controller = Arc::new(controller::Controller::new(
token.to_string(),
machine_id,
hostname.to_string(),
collect_device_os_info(),
manager,
@@ -229,13 +232,14 @@ impl WebClient {
pub async fn run_web_client(
config_server_url_s: &str,
machine_id: Option<String>,
machine_id_opts: MachineIdOptions,
hostname: Option<String>,
secure_mode: bool,
manager: Arc<NetworkInstanceManager>,
hooks: Option<Arc<dyn WebClientHooks>>,
) -> Result<WebClient> {
set_default_machine_id(machine_id);
let machine_id = resolve_machine_id(&machine_id_opts)
.with_context(|| "failed to resolve machine id for web client")?;
let config_server_url = match Url::parse(config_server_url_s) {
Ok(u) => u,
Err(_) => format!(
@@ -289,6 +293,7 @@ pub async fn run_web_client(
global_ctx,
},
token.to_string(),
machine_id,
hostname,
secure_mode,
manager,
@@ -300,14 +305,18 @@ pub async fn run_web_client(
mod tests {
use std::sync::{Arc, atomic::AtomicBool};
use crate::instance_manager::NetworkInstanceManager;
use crate::{common::MachineIdOptions, instance_manager::NetworkInstanceManager};
#[tokio::test]
async fn test_manager_wait() {
let manager = Arc::new(NetworkInstanceManager::new());
let temp_dir = tempfile::tempdir().unwrap();
let client = super::run_web_client(
format!("ring://{}/test", uuid::Uuid::new_v4()).as_str(),
None,
MachineIdOptions {
explicit_machine_id: None,
state_dir: Some(temp_dir.path().to_path_buf()),
},
None,
false,
manager.clone(),
@@ -335,9 +344,13 @@ mod tests {
#[tokio::test]
async fn test_run_web_client_with_unreachable_config_server() {
let manager = Arc::new(NetworkInstanceManager::new());
let temp_dir = tempfile::tempdir().unwrap();
let client = super::run_web_client(
"udp://config-server.invalid:22020/test",
None,
MachineIdOptions {
explicit_machine_id: None,
state_dir: Some(temp_dir.path().to_path_buf()),
},
None,
false,
manager,
+48 -7
View File
@@ -101,7 +101,11 @@ impl TunnelFilter for SecureDatagramTunnelFilter {
Err(e) => return Some(Err(e)),
};
let mut cipher = ZCPacket::new_with_payload(packet.payload());
let payload = match checked_payload(&packet, "secure packet") {
Ok(v) => v,
Err(e) => return Some(Err(e)),
};
let mut cipher = ZCPacket::new_with_payload(payload);
cipher.fill_peer_manager_hdr(0, 0, PacketType::Data as u8);
cipher
.mut_peer_manager_header()
@@ -116,15 +120,27 @@ impl TunnelFilter for SecureDatagramTunnelFilter {
))));
}
Some(Ok(ZCPacket::new_from_buf(
cipher.payload_bytes(),
ZCPacketType::DummyTunnel,
)))
let packet = ZCPacket::new_from_buf(cipher.payload_bytes(), ZCPacketType::DummyTunnel);
if packet.peer_manager_header().is_none() {
return Some(Err(TunnelError::InvalidPacket(
"decrypted secure packet too short".to_string(),
)));
}
Some(Ok(packet))
}
fn filter_output(&self) {}
}
fn checked_payload<'a>(packet: &'a ZCPacket, context: &str) -> Result<&'a [u8], TunnelError> {
if packet.peer_manager_header().is_none() {
return Err(TunnelError::InvalidPacket(format!("{context} too short")));
}
Ok(packet.payload())
}
fn pack_control_packet(payload: &[u8]) -> ZCPacket {
let mut packet = ZCPacket::new_with_payload(payload);
packet.fill_peer_manager_hdr(0, 0, PacketType::Data as u8);
@@ -214,7 +230,8 @@ pub async fn upgrade_client_tunnel(
Ok(None) => return Err(TunnelError::Shutdown),
Err(error) => return Err(error.into()),
};
let msg2_cipher = decode_noise_payload(msg2_packet.payload())
let msg2_payload = checked_payload(&msg2_packet, "noise msg2 packet")?;
let msg2_cipher = decode_noise_payload(msg2_payload)
.ok_or_else(|| TunnelError::InvalidPacket("invalid noise msg2 magic".to_string()))?;
let mut root_key_buf = [0u8; 32];
let root_key_len = state
@@ -254,7 +271,8 @@ pub async fn accept_or_upgrade_server_tunnel(
));
}
};
let Some(msg1_cipher) = decode_noise_payload(first_packet.payload()) else {
let first_payload = checked_payload(&first_packet, "first packet")?;
let Some(msg1_cipher) = decode_noise_payload(first_payload) else {
let stream = Box::pin(futures::stream::once(async move { Ok(first_packet) }).chain(stream));
return Ok((
Box::new(RawSplitTunnel::new(info, stream, sink)) as Box<dyn Tunnel>,
@@ -303,6 +321,7 @@ pub async fn accept_or_upgrade_server_tunnel(
mod tests {
use super::*;
use crate::tunnel::ring::create_ring_tunnel_pair;
use bytes::BytesMut;
#[test]
fn web_secure_cipher_algorithm_matches_support_flag() {
@@ -338,6 +357,28 @@ mod tests {
assert!(matches!(err, TunnelError::Timeout(_)));
}
#[tokio::test]
async fn accept_secure_tunnel_rejects_short_first_packet() {
let (server_tunnel, client_tunnel) = create_ring_tunnel_pair();
let server_task =
tokio::spawn(async move { accept_or_upgrade_server_tunnel(server_tunnel).await });
let (_stream, mut sink) = client_tunnel.split();
sink.send(ZCPacket::new_from_buf(
BytesMut::from(&b"\x01"[..]),
ZCPacketType::TCP,
))
.await
.unwrap();
let err = server_task.await.unwrap().unwrap_err();
assert!(matches!(
err,
TunnelError::InvalidPacket(msg) if msg == "first packet too short"
));
}
#[tokio::test]
async fn accept_secure_tunnel_after_short_client_delay() {
let (server_tunnel, client_tunnel) = create_ring_tunnel_pair();
+7 -5
View File
@@ -7,7 +7,7 @@ use tokio::{
};
use crate::{
common::{constants::EASYTIER_VERSION, get_machine_id},
common::constants::EASYTIER_VERSION,
proto::{
rpc_impl::bidirect::BidirectRpcManager,
rpc_types::controller::BaseController,
@@ -65,11 +65,13 @@ impl Session {
tasks: &mut JoinSet<()>,
ctx: HeartbeatCtx,
) {
let mid = get_machine_id();
let controller = controller.upgrade().unwrap();
let mid = controller.machine_id();
let inst_id = uuid::Uuid::new_v4();
let token = controller.upgrade().unwrap().token();
let hostname = controller.upgrade().unwrap().hostname();
let device_os = controller.upgrade().unwrap().device_os();
let token = controller.token();
let hostname = controller.hostname();
let device_os = controller.device_os();
let controller = Arc::downgrade(&controller);
let ctx_clone = ctx.clone();
let mut tick = interval(std::time::Duration::from_secs(1));