Introduce secure mode (part 1) (#1808)

Use noise protocol on handshake. Check peer's public key if needed. Also support rekey and replay attack prevention.

E2EE and temporary password will be implemented based on this.
This commit is contained in:
KKRainbow
2026-01-25 20:16:51 +08:00
committed by GitHub
parent ffa08d1c43
commit 101f416268
29 changed files with 3320 additions and 91 deletions
+15 -1
View File
@@ -14,7 +14,7 @@ use crate::{
instance::dns_server::DEFAULT_ET_DNS_ZONE,
proto::{
acl::Acl,
common::{CompressionAlgoPb, PortForwardConfigPb, SocketType},
common::{CompressionAlgoPb, PortForwardConfigPb, SecureModeConfig, SocketType},
},
tunnel::generate_digest_from_str,
};
@@ -209,6 +209,9 @@ pub trait ConfigLoader: Send + Sync {
fn get_stun_servers_v6(&self) -> Option<Vec<String>>;
fn set_stun_servers_v6(&self, servers: Option<Vec<String>>);
fn get_secure_mode(&self) -> Option<SecureModeConfig>;
fn set_secure_mode(&self, secure_mode: Option<SecureModeConfig>);
fn dump(&self) -> String;
}
@@ -300,6 +303,7 @@ impl Default for NetworkIdentity {
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct PeerConfig {
pub uri: url::Url,
pub peer_public_key: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
@@ -407,6 +411,8 @@ struct Config {
port_forward: Option<Vec<PortForwardConfig>>,
secure_mode: Option<SecureModeConfig>,
flags: Option<serde_json::Map<String, serde_json::Value>>,
#[serde(skip)]
@@ -802,6 +808,14 @@ impl ConfigLoader for TomlConfigLoader {
self.config.lock().unwrap().stun_servers_v6 = servers;
}
fn get_secure_mode(&self) -> Option<SecureModeConfig> {
self.config.lock().unwrap().secure_mode.clone()
}
fn set_secure_mode(&self, secure_mode: Option<SecureModeConfig>) {
self.config.lock().unwrap().secure_mode = secure_mode;
}
fn dump(&self) -> String {
let default_flags_json = serde_json::to_string(&gen_default_flags()).unwrap();
let default_flags_hashmap =
+3
View File
@@ -29,6 +29,9 @@ define_global_var!(MAX_DIRECT_CONNS_PER_PEER_IN_FOREIGN_NETWORK, u32, 3);
define_global_var!(DIRECT_CONNECT_TO_PUBLIC_SERVER, bool, true);
// must make it true in future.
define_global_var!(HMAC_SECRET_DIGEST, bool, false);
pub const UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID: u32 = 2;
pub const WIN_SERVICE_WORK_DIR_REG_KEY: &str = "SOFTWARE\\EasyTier\\Service\\WorkDir";
+3
View File
@@ -48,6 +48,9 @@ pub enum Error {
#[error("secret key error: {0}")]
SecretKeyError(String),
#[error("noise protocol error: {0}")]
NoiseError(#[from] snow::Error),
}
pub type Result<T> = result::Result<T, Error>;
+11
View File
@@ -15,6 +15,8 @@ use crate::proto::api::instance::PeerConnInfo;
use crate::proto::common::{PeerFeatureFlag, PortForwardConfigPb};
use crate::proto::peer_rpc::PeerGroupInfo;
use crossbeam::atomic::AtomicCell;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use super::{
config::{ConfigLoader, Flags},
@@ -268,6 +270,15 @@ impl GlobalCtx {
self.config.get_network_identity()
}
pub fn get_secret_proof(&self, challenge: &[u8]) -> Option<Hmac<Sha256>> {
let network_secret = self.get_network_identity().network_secret?;
let key = network_secret.as_bytes();
let mut mac = Hmac::<Sha256>::new_from_slice(key).unwrap();
mac.update(b"easytier secret proof");
mac.update(challenge);
Some(mac)
}
pub fn get_network_name(&self) -> String {
self.get_network_identity().network_name
}
+5 -2
View File
@@ -186,7 +186,9 @@ impl DirectConnectorManagerData {
.await?;
// NOTICE: must add as directly connected tunnel
self.peer_manager.add_client_tunnel(ret, true).await
self.peer_manager
.add_client_tunnel_with_peer_id_hint(ret, true, Some(dst_peer_id))
.await
}
async fn do_try_connect_to_ip(&self, dst_peer_id: PeerId, addr: String) -> Result<(), Error> {
@@ -199,7 +201,8 @@ impl DirectConnectorManagerData {
} else {
timeout(
std::time::Duration::from_secs(3),
self.peer_manager.try_direct_connect(connector),
self.peer_manager
.try_direct_connect_with_peer_id_hint(connector, Some(dst_peer_id)),
)
.await??
};
+75 -1
View File
@@ -19,15 +19,17 @@ use crate::{
defer,
instance_manager::NetworkInstanceManager,
launcher::add_proxy_network_to_config,
proto::common::CompressionAlgoPb,
proto::common::{CompressionAlgoPb, SecureModeConfig},
rpc_service::ApiRpcServer,
tunnel::PROTO_PORT_OFFSET,
utils::{init_logger, setup_panic_handler},
web_client, ShellType,
};
use anyhow::Context;
use base64::{prelude::BASE64_STANDARD, Engine as _};
use cidr::IpCidr;
use clap::{CommandFactory, Parser};
use rand::rngs::OsRng;
use rust_i18n::t;
use tokio::io::AsyncReadExt;
@@ -600,6 +602,29 @@ struct NetworkOptions {
num_args = 0..
)]
stun_servers_v6: Option<Vec<String>>,
#[arg(
long,
env = "ET_SECURE_MODE",
help = t!("core_clap.secure_mode").to_string(),
num_args = 0..=1,
default_missing_value = "true"
)]
secure_mode: Option<bool>,
#[arg(
long,
env = "ET_LOCAL_PRIVATE_KEY",
help = t!("core_clap.local_private_key").to_string()
)]
local_private_key: Option<String>,
#[arg(
long,
env = "ET_LOCAL_PUBLIC_KEY",
help = t!("core_clap.local_public_key").to_string()
)]
local_public_key: Option<String>,
}
#[derive(Parser, Debug)]
@@ -723,6 +748,42 @@ impl NetworkOptions {
false
}
fn process_secure_mode_cfg(mut user_cfg: SecureModeConfig) -> anyhow::Result<SecureModeConfig> {
if !user_cfg.enabled {
return Ok(user_cfg);
}
let private_key = if user_cfg.local_private_key.is_none() {
// if no private key, generate random one
let private = x25519_dalek::StaticSecret::random_from_rng(OsRng);
user_cfg.local_private_key = Some(BASE64_STANDARD.encode(private.clone().as_bytes()));
private
} else {
// check if private key is valid
user_cfg.private_key()?
};
let public = x25519_dalek::PublicKey::from(&private_key);
match user_cfg.local_public_key {
None => {
user_cfg.local_public_key = Some(BASE64_STANDARD.encode(public.as_bytes()));
}
Some(ref user_pub) => {
let public = user_cfg.public_key()?;
if *user_pub != BASE64_STANDARD.encode(public.as_bytes()) {
return Err(anyhow::anyhow!(
"local public key {} does not match generated public key {}",
user_pub,
BASE64_STANDARD.encode(public.as_bytes())
));
}
}
}
Ok(user_cfg)
}
fn merge_into(&self, cfg: &TomlConfigLoader) -> anyhow::Result<()> {
if self.hostname.is_some() {
cfg.set_hostname(self.hostname.clone());
@@ -760,6 +821,7 @@ impl NetworkOptions {
uri: p
.parse()
.with_context(|| format!("failed to parse peer uri: {}", p))?,
peer_public_key: None,
});
}
cfg.set_peers(peers);
@@ -820,6 +882,7 @@ impl NetworkOptions {
uri: external_nodes.parse().with_context(|| {
format!("failed to parse external node uri: {}", external_nodes)
})?,
peer_public_key: None,
});
cfg.set_peers(old_peers);
}
@@ -902,6 +965,17 @@ impl NetworkOptions {
cfg.set_port_forwards(old);
}
if let Some(secure_mode) = self.secure_mode {
if secure_mode {
let c = SecureModeConfig {
enabled: secure_mode,
local_private_key: self.local_private_key.clone(),
local_public_key: self.local_public_key.clone(),
};
cfg.set_secure_mode(Some(Self::process_secure_mode_cfg(c)?));
}
}
let mut f = cfg.get_flags();
if let Some(default_protocol) = &self.default_protocol {
f.default_protocol = default_protocol.clone()
+19 -2
View File
@@ -536,6 +536,7 @@ impl NetworkConfig {
uri: public_server_url.parse().with_context(|| {
format!("failed to parse public server uri: {}", public_server_url)
})?,
peer_public_key: None,
}]);
}
NetworkingMethod::Manual => {
@@ -548,6 +549,7 @@ impl NetworkConfig {
uri: peer_url
.parse()
.with_context(|| format!("failed to parse peer uri: {}", peer_url))?,
peer_public_key: None,
});
}
@@ -673,6 +675,8 @@ impl NetworkConfig {
));
}
cfg.set_secure_mode(self.secure_mode.clone());
let mut flags = gen_default_flags();
if let Some(latency_first) = self.latency_first {
flags.latency_first = latency_first;
@@ -897,6 +901,8 @@ impl NetworkConfig {
result.mapped_listeners = mapped_listeners.iter().map(|l| l.to_string()).collect();
}
result.secure_mode = config.get_secure_mode();
let flags = config.get_flags();
result.latency_first = Some(flags.latency_first);
result.dev_name = Some(flags.dev_name.clone());
@@ -944,7 +950,7 @@ impl NetworkConfig {
#[cfg(test)]
mod tests {
use crate::common::config::ConfigLoader;
use crate::{common::config::ConfigLoader, proto::common::SecureModeConfig};
use rand::Rng;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
@@ -1018,7 +1024,10 @@ mod tests {
let uri = format!("{}://127.0.0.1:{}", protocol, port)
.parse()
.unwrap();
peers.push(crate::common::config::PeerConfig { uri });
peers.push(crate::common::config::PeerConfig {
uri,
peer_public_key: None,
});
}
config.set_peers(peers);
@@ -1140,6 +1149,14 @@ mod tests {
config.set_mapped_listeners(Some(mapped_listeners));
}
if rng.gen_bool(0.3) {
config.set_secure_mode(Some(SecureModeConfig {
enabled: true,
local_private_key: None,
local_public_key: None,
}));
}
if rng.gen_bool(0.9) {
let mut flags = crate::common::config::gen_default_flags();
flags.latency_first = rng.gen_bool(0.5);
+56 -7
View File
@@ -84,6 +84,14 @@ impl Encryptor for AesGcmCipher {
}
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
self.encrypt_with_nonce(zc_packet, None)
}
fn encrypt_with_nonce(
&self,
zc_packet: &mut ZCPacket,
nonce: Option<&[u8]>,
) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_encrypted() {
tracing::warn!(?zc_packet, "packet is already encrypted");
@@ -91,16 +99,28 @@ impl Encryptor for AesGcmCipher {
}
let mut tail = AesGcmTail::default();
if let Some(nonce) = nonce {
if nonce.len() != tail.nonce.len() {
return Err(Error::EncryptionFailed);
}
tail.nonce.copy_from_slice(nonce);
}
let rs = match &self.cipher {
AesGcmEnum::AES128GCM(aes_gcm) => {
let nonce = Aes128Gcm::generate_nonce(&mut OsRng);
tail.nonce.copy_from_slice(nonce.as_slice());
aes_gcm.encrypt_in_place_detached(&nonce, &[], zc_packet.mut_payload())
if nonce.is_none() {
let nonce = Aes128Gcm::generate_nonce(&mut OsRng);
tail.nonce.copy_from_slice(nonce.as_slice());
}
let nonce = Nonce::from_slice(&tail.nonce);
aes_gcm.encrypt_in_place_detached(nonce, &[], zc_packet.mut_payload())
}
AesGcmEnum::AES256GCM(aes_gcm) => {
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
tail.nonce.copy_from_slice(nonce.as_slice());
aes_gcm.encrypt_in_place_detached(&nonce, &[], zc_packet.mut_payload())
if nonce.is_none() {
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
tail.nonce.copy_from_slice(nonce.as_slice());
}
let nonce = Nonce::from_slice(&tail.nonce);
aes_gcm.encrypt_in_place_detached(nonce, &[], zc_packet.mut_payload())
}
};
@@ -122,8 +142,9 @@ impl Encryptor for AesGcmCipher {
mod tests {
use crate::{
peers::encrypt::{aes_gcm::AesGcmCipher, Encryptor},
tunnel::packet_def::{ZCPacket, AES_GCM_ENCRYPTION_RESERVED},
tunnel::packet_def::{AesGcmTail, ZCPacket, AES_GCM_ENCRYPTION_RESERVED},
};
use zerocopy::FromBytes;
#[test]
fn test_aes_gcm_cipher() {
@@ -143,4 +164,32 @@ mod tests {
assert_eq!(packet.payload(), text);
assert!(!packet.peer_manager_header().unwrap().is_encrypted());
}
#[test]
fn test_aes_gcm_cipher_with_nonce() {
let key = [7u8; 16];
let cipher = AesGcmCipher::new_128(key);
let text = b"Hello";
let nonce = [3u8; 12];
let mut packet1 = ZCPacket::new_with_payload(text);
packet1.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet1, Some(&nonce))
.unwrap();
let mut packet2 = ZCPacket::new_with_payload(text);
packet2.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet2, Some(&nonce))
.unwrap();
assert_eq!(packet1.payload(), packet2.payload());
let tail = AesGcmTail::ref_from_suffix(packet1.payload()).unwrap();
assert_eq!(tail.nonce, nonce);
cipher.decrypt(&mut packet1).unwrap();
assert_eq!(packet1.payload(), text);
}
}
+7
View File
@@ -30,6 +30,13 @@ pub enum Error {
pub trait Encryptor: Send + Sync + 'static {
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error>;
fn encrypt_with_nonce(
&self,
zc_packet: &mut ZCPacket,
_nonce: Option<&[u8]>,
) -> Result<(), Error> {
self.encrypt(zc_packet)
}
fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error>;
}
+48 -1
View File
@@ -142,6 +142,14 @@ impl Encryptor for OpenSslCipher {
}
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
self.encrypt_with_nonce(zc_packet, None)
}
fn encrypt_with_nonce(
&self,
zc_packet: &mut ZCPacket,
nonce: Option<&[u8]>,
) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_encrypted() {
tracing::warn!(?zc_packet, "packet is already encrypted");
@@ -153,7 +161,14 @@ impl Encryptor for OpenSslCipher {
let nonce_size = self.get_nonce_size();
let mut tail = OpenSslTail::default();
rand::thread_rng().fill_bytes(&mut tail.nonce[..nonce_size]);
if let Some(nonce) = nonce {
if nonce.len() != nonce_size {
return Err(Error::EncryptionFailed);
}
tail.nonce[..nonce_size].copy_from_slice(nonce);
} else {
rand::thread_rng().fill_bytes(&mut tail.nonce[..nonce_size]);
}
let mut encrypter =
Crypter::new(cipher, Mode::Encrypt, key, Some(&tail.nonce[..nonce_size]))
@@ -198,6 +213,7 @@ mod tests {
peers::encrypt::{openssl_cipher::OpenSslCipher, Encryptor},
tunnel::packet_def::ZCPacket,
};
use zerocopy::FromBytes;
use super::OPENSSL_ENCRYPTION_RESERVED;
@@ -220,6 +236,37 @@ mod tests {
assert!(!packet.peer_manager_header().unwrap().is_encrypted());
}
#[test]
fn test_openssl_aes128_gcm_with_nonce() {
let key = [7u8; 16];
let cipher = OpenSslCipher::new_aes128_gcm(key);
let text = b"Hello";
let nonce = [3u8; 12];
let mut packet1 = ZCPacket::new_with_payload(text);
packet1.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet1, Some(&nonce))
.unwrap();
let mut packet2 = ZCPacket::new_with_payload(text);
packet2.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet2, Some(&nonce))
.unwrap();
assert_eq!(packet1.payload(), packet2.payload());
assert!(packet1.payload().len() > text.len() + OPENSSL_ENCRYPTION_RESERVED);
let tail = super::OpenSslTail::ref_from_suffix(packet1.payload())
.unwrap()
.clone();
assert_eq!(&tail.nonce[..nonce.len()], nonce);
cipher.decrypt(&mut packet1).unwrap();
assert_eq!(packet1.payload(), text);
}
#[test]
fn test_openssl_chacha20() {
let key = [0u8; 32];
+46 -2
View File
@@ -93,6 +93,14 @@ impl Encryptor for AesGcmCipher {
}
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
self.encrypt_with_nonce(zc_packet, None)
}
fn encrypt_with_nonce(
&self,
zc_packet: &mut ZCPacket,
nonce: Option<&[u8]>,
) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_encrypted() {
tracing::warn!(?zc_packet, "packet is already encrypted");
@@ -100,7 +108,14 @@ impl Encryptor for AesGcmCipher {
}
let mut tail = AesGcmTail::default();
rand::thread_rng().fill_bytes(&mut tail.nonce);
if let Some(nonce) = nonce {
if nonce.len() != tail.nonce.len() {
return Err(Error::EncryptionFailed);
}
tail.nonce.copy_from_slice(nonce);
} else {
rand::thread_rng().fill_bytes(&mut tail.nonce);
}
let nonce = aead::Nonce::assume_unique_for_key(tail.nonce);
let rs = match &self.cipher {
@@ -137,8 +152,9 @@ impl Encryptor for AesGcmCipher {
mod tests {
use crate::{
peers::encrypt::{ring_aes_gcm::AesGcmCipher, Encryptor},
tunnel::packet_def::{ZCPacket, AES_GCM_ENCRYPTION_RESERVED},
tunnel::packet_def::{AesGcmTail, ZCPacket, AES_GCM_ENCRYPTION_RESERVED},
};
use zerocopy::FromBytes;
#[test]
fn test_aes_gcm_cipher() {
@@ -158,4 +174,32 @@ mod tests {
assert_eq!(packet.payload(), text);
assert!(!packet.peer_manager_header().unwrap().is_encrypted());
}
#[test]
fn test_aes_gcm_cipher_with_nonce() {
let key = [7u8; 16];
let cipher = AesGcmCipher::new_128(key);
let text = b"Hello";
let nonce = [3u8; 12];
let mut packet1 = ZCPacket::new_with_payload(text);
packet1.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet1, Some(&nonce))
.unwrap();
let mut packet2 = ZCPacket::new_with_payload(text);
packet2.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet2, Some(&nonce))
.unwrap();
assert_eq!(packet1.payload(), packet2.payload());
let tail = AesGcmTail::ref_from_suffix(packet1.payload()).unwrap();
assert_eq!(tail.nonce, nonce);
cipher.decrypt(&mut packet1).unwrap();
assert_eq!(packet1.payload(), text);
}
}
+46 -1
View File
@@ -67,6 +67,14 @@ impl Encryptor for RingChaCha20Cipher {
}
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
self.encrypt_with_nonce(zc_packet, None)
}
fn encrypt_with_nonce(
&self,
zc_packet: &mut ZCPacket,
nonce: Option<&[u8]>,
) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_encrypted() {
tracing::warn!(?zc_packet, "packet is already encrypted");
@@ -74,7 +82,14 @@ impl Encryptor for RingChaCha20Cipher {
}
let mut tail = ChaCha20Poly1305Tail::default();
rand::thread_rng().fill_bytes(&mut tail.nonce);
if let Some(nonce) = nonce {
if nonce.len() != tail.nonce.len() {
return Err(Error::EncryptionFailed);
}
tail.nonce.copy_from_slice(nonce);
} else {
rand::thread_rng().fill_bytes(&mut tail.nonce);
}
let nonce = Nonce::assume_unique_for_key(tail.nonce);
let rs =
@@ -100,6 +115,7 @@ mod tests {
peers::encrypt::{ring_chacha20::RingChaCha20Cipher, Encryptor},
tunnel::packet_def::ZCPacket,
};
use zerocopy::FromBytes;
use super::CHACHA20_POLY1305_ENCRYPTION_RESERVED;
@@ -122,4 +138,33 @@ mod tests {
assert_eq!(packet.payload(), text);
assert!(!packet.peer_manager_header().unwrap().is_encrypted());
}
#[test]
fn test_ring_chacha20_cipher_with_nonce() {
let key = [9u8; 32];
let cipher = RingChaCha20Cipher::new(key);
let text = b"Hello";
let nonce = [5u8; 12];
let mut packet1 = ZCPacket::new_with_payload(text);
packet1.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet1, Some(&nonce))
.unwrap();
let mut packet2 = ZCPacket::new_with_payload(text);
packet2.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet2, Some(&nonce))
.unwrap();
assert_eq!(packet1.payload(), packet2.payload());
let tail = super::ChaCha20Poly1305Tail::ref_from_suffix(packet1.payload()).unwrap();
assert_eq!(tail.nonce, nonce);
cipher.decrypt(&mut packet1).unwrap();
assert_eq!(packet1.payload(), text);
assert!(!packet1.peer_manager_header().unwrap().is_encrypted());
}
}
+1 -1
View File
@@ -2,7 +2,6 @@ mod graph_algo;
pub mod acl_filter;
pub mod peer;
// pub mod peer_conn;
pub mod peer_conn;
pub mod peer_conn_ping;
pub mod peer_manager;
@@ -10,6 +9,7 @@ pub mod peer_map;
pub mod peer_ospf_route;
pub mod peer_rpc;
pub mod peer_rpc_service;
pub mod peer_session;
pub mod route_trait;
pub mod rpc_service;
+15 -6
View File
@@ -238,12 +238,12 @@ impl Drop for Peer {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use tokio::time::timeout;
use crate::{
common::{global_ctx::tests::get_mock_global_ctx, new_peer_id},
peers::{create_packet_recv_chan, peer_conn::PeerConn},
peers::{create_packet_recv_chan, peer_conn::PeerConn, peer_session::PeerSessionStore},
tunnel::ring::create_ring_tunnel_pair,
};
@@ -257,11 +257,20 @@ mod tests {
let local_peer = Peer::new(new_peer_id(), local_packet_send, global_ctx.clone());
let remote_peer = Peer::new(new_peer_id(), remote_packet_send, global_ctx.clone());
let ps = Arc::new(PeerSessionStore::new());
let (local_tunnel, remote_tunnel) = create_ring_tunnel_pair();
let mut local_peer_conn =
PeerConn::new(local_peer.peer_node_id, global_ctx.clone(), local_tunnel);
let mut remote_peer_conn =
PeerConn::new(remote_peer.peer_node_id, global_ctx.clone(), remote_tunnel);
let mut local_peer_conn = PeerConn::new(
local_peer.peer_node_id,
global_ctx.clone(),
local_tunnel,
ps.clone(),
);
let mut remote_peer_conn = PeerConn::new(
remote_peer.peer_node_id,
global_ctx.clone(),
remote_tunnel,
ps.clone(),
);
assert!(!local_peer_conn.handshake_done());
assert!(!remote_peer_conn.handshake_done());
File diff suppressed because it is too large Load Diff
+328 -11
View File
@@ -32,6 +32,7 @@ use crate::{
peers::{
peer_conn::PeerConn,
peer_rpc::PeerRpcManagerTransport,
peer_session::PeerSessionStore,
recv_packet_from_chan,
route_trait::{ForeignNetworkRouteInfoMap, MockRoute, NextHopPolicy, RouteInterface},
PeerPacketFilter,
@@ -160,6 +161,8 @@ pub struct PeerManager {
allow_loopback_tunnel: AtomicBool,
self_tx_counters: SelfTxCounters,
peer_session_store: Arc<PeerSessionStore>,
}
impl Debug for PeerManager {
@@ -312,6 +315,8 @@ impl PeerManager {
allow_loopback_tunnel: AtomicBool::new(true),
self_tx_counters,
peer_session_store: Arc::new(PeerSessionStore::new()),
}
}
@@ -363,7 +368,23 @@ impl PeerManager {
tunnel: Box<dyn Tunnel>,
is_directly_connected: bool,
) -> Result<(PeerId, PeerConnId), Error> {
let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
self.add_client_tunnel_with_peer_id_hint(tunnel, is_directly_connected, None)
.await
}
pub async fn add_client_tunnel_with_peer_id_hint(
&self,
tunnel: Box<dyn Tunnel>,
is_directly_connected: bool,
peer_id_hint: Option<PeerId>,
) -> Result<(PeerId, PeerConnId), Error> {
let mut peer = PeerConn::new_with_peer_id_hint(
self.my_peer_id,
self.global_ctx.clone(),
tunnel,
peer_id_hint,
self.peer_session_store.clone(),
);
peer.set_is_hole_punched(!is_directly_connected);
peer.do_handshake_as_client().await?;
let conn_id = peer.get_conn_id();
@@ -387,9 +408,19 @@ impl PeerManager {
}
#[tracing::instrument]
pub async fn try_direct_connect<C>(
pub async fn try_direct_connect<C>(&self, connector: C) -> Result<(PeerId, PeerConnId), Error>
where
C: TunnelConnector + Debug,
{
self.try_direct_connect_with_peer_id_hint(connector, None)
.await
}
#[tracing::instrument]
pub async fn try_direct_connect_with_peer_id_hint<C>(
&self,
mut connector: C,
peer_id_hint: Option<PeerId>,
) -> Result<(PeerId, PeerConnId), Error>
where
C: TunnelConnector + Debug,
@@ -398,7 +429,8 @@ impl PeerManager {
let t = ns
.run_async(|| async move { connector.connect().await })
.await?;
self.add_client_tunnel(t, true).await
self.add_client_tunnel_with_peer_id_hint(t, true, peer_id_hint)
.await
}
// avoid loop back to virtual network
@@ -447,9 +479,14 @@ impl PeerManager {
tracing::info!("add tunnel as server start");
self.check_remote_addr_not_from_virtual_network(&tunnel)?;
let mut conn = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
conn.do_handshake_as_server_ext(|peer, msg| {
if msg.network_name
let mut conn = PeerConn::new(
self.my_peer_id,
self.global_ctx.clone(),
tunnel,
self.peer_session_store.clone(),
);
conn.do_handshake_as_server_ext(|peer, network_name:&str| {
if network_name
== self.global_ctx.get_network_identity().network_name
{
return Ok(());
@@ -463,9 +500,9 @@ impl PeerManager {
let mut peer_id = self
.foreign_network_manager
.get_network_peer_id(&msg.network_name);
.get_network_peer_id(network_name);
if peer_id.is_none() {
peer_id = Some(*self.reserved_my_peer_id_map.entry(msg.network_name.clone()).or_insert_with(|| {
peer_id = Some(*self.reserved_my_peer_id_map.entry(network_name.to_string()).or_insert_with(|| {
rand::random::<PeerId>()
}).value());
}
@@ -473,7 +510,7 @@ impl PeerManager {
tracing::info!(
?peer_id,
?msg.network_name,
?network_name,
"handshake as server with foreign network, new peer id: {}, peer id in foreign manager: {:?}",
peer.get_my_peer_id(), peer_id
);
@@ -1464,7 +1501,10 @@ mod tests {
use std::{fmt::Debug, sync::Arc, time::Duration};
use crate::{
common::{config::Flags, global_ctx::tests::get_mock_global_ctx},
common::{
config::Flags,
global_ctx::{tests::get_mock_global_ctx, NetworkIdentity},
},
connector::{
create_connector_by_url, direct::PeerManagerForDirectConnector,
udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun,
@@ -1472,6 +1512,7 @@ mod tests {
instance::listeners::get_listener_by_url,
peers::{
create_packet_recv_chan,
peer_conn::tests::set_secure_mode_cfg,
peer_manager::RouteAlgoType,
peer_rpc::tests::register_service,
route_trait::NextHopPolicy,
@@ -1480,7 +1521,10 @@ mod tests {
wait_route_appear_with_cost,
},
},
proto::common::{CompressionAlgoPb, NatType, PeerFeatureFlag},
proto::{
common::{CompressionAlgoPb, NatType, PeerFeatureFlag},
peer_rpc::SecureAuthLevel,
},
tunnel::{
common::tests::wait_for_condition,
filter::{tests::DropSendTunnelFilter, TunnelWithFilter},
@@ -1523,6 +1567,279 @@ mod tests {
.await;
}
#[tokio::test]
async fn peer_manager_safe_mode_connect_between_peers() {
let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
peer_mgr_a
.get_global_ctx()
.config
.set_network_identity(NetworkIdentity::new("net1".to_string(), "sec1".to_string()));
peer_mgr_b
.get_global_ctx()
.config
.set_network_identity(NetworkIdentity::new("net1".to_string(), "sec1".to_string()));
set_secure_mode_cfg(&peer_mgr_a.get_global_ctx(), true);
set_secure_mode_cfg(&peer_mgr_b.get_global_ctx(), true);
let (a_ring, b_ring) = create_ring_tunnel_pair();
let (a_ret, b_ret) = tokio::join!(
peer_mgr_a.add_client_tunnel(a_ring, false),
peer_mgr_b.add_tunnel_as_server(b_ring, true)
);
let (peer_b_id, _) = a_ret.unwrap();
b_ret.unwrap();
wait_for_condition(
|| {
let peer_mgr_a = peer_mgr_a.clone();
async move {
if !peer_mgr_a
.get_peer_map()
.list_peers_with_conn()
.await
.contains(&peer_b_id)
{
return false;
}
let Some(conns) = peer_mgr_a.get_peer_map().list_peer_conns(peer_b_id).await
else {
return false;
};
conns.iter().any(|c| {
c.noise_local_static_pubkey.len() == 32
&& c.noise_remote_static_pubkey.len() == 32
&& c.secure_auth_level == SecureAuthLevel::NetworkSecretConfirmed as i32
})
}
},
Duration::from_secs(10),
)
.await;
let peer_a_id = peer_mgr_a.my_peer_id();
wait_for_condition(
|| {
let peer_mgr_b = peer_mgr_b.clone();
async move {
if !peer_mgr_b
.get_peer_map()
.list_peers_with_conn()
.await
.contains(&peer_a_id)
{
return false;
}
let Some(conns) = peer_mgr_b.get_peer_map().list_peer_conns(peer_a_id).await
else {
return false;
};
conns.iter().any(|c| {
c.noise_local_static_pubkey.len() == 32
&& c.noise_remote_static_pubkey.len() == 32
&& c.secure_auth_level == SecureAuthLevel::NetworkSecretConfirmed as i32
})
}
},
Duration::from_secs(10),
)
.await;
}
#[tokio::test]
async fn peer_manager_safe_server_accept_legacy_client() {
let peer_mgr_client = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let peer_mgr_server = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
peer_mgr_client
.get_global_ctx()
.config
.set_network_identity(NetworkIdentity::new("net1".to_string(), "sec1".to_string()));
peer_mgr_server
.get_global_ctx()
.config
.set_network_identity(NetworkIdentity::new("net1".to_string(), "sec1".to_string()));
set_secure_mode_cfg(&peer_mgr_server.get_global_ctx(), true);
let (c_ring, s_ring) = create_ring_tunnel_pair();
let (c_ret, s_ret) = tokio::join!(
peer_mgr_client.add_client_tunnel(c_ring, false),
peer_mgr_server.add_tunnel_as_server(s_ring, true)
);
let (server_id, _) = c_ret.unwrap();
s_ret.unwrap();
wait_for_condition(
|| {
let peer_mgr_client = peer_mgr_client.clone();
async move {
if !peer_mgr_client
.get_peer_map()
.list_peers_with_conn()
.await
.contains(&server_id)
{
return false;
}
let Some(conns) = peer_mgr_client
.get_peer_map()
.list_peer_conns(server_id)
.await
else {
return false;
};
conns.iter().any(|c| {
c.noise_local_static_pubkey.is_empty()
&& c.noise_remote_static_pubkey.is_empty()
&& c.secure_auth_level == SecureAuthLevel::None as i32
})
}
},
Duration::from_secs(10),
)
.await;
let client_id = peer_mgr_client.my_peer_id();
wait_for_condition(
|| {
let peer_mgr_server = peer_mgr_server.clone();
async move {
if !peer_mgr_server
.get_peer_map()
.list_peers_with_conn()
.await
.contains(&client_id)
{
return false;
}
let Some(conns) = peer_mgr_server
.get_peer_map()
.list_peer_conns(client_id)
.await
else {
return false;
};
conns.iter().any(|c| {
c.noise_local_static_pubkey.is_empty()
&& c.noise_remote_static_pubkey.is_empty()
&& c.secure_auth_level == SecureAuthLevel::None as i32
})
}
},
Duration::from_secs(5),
)
.await;
}
#[tokio::test]
async fn peer_manager_safe_mode_shared_node_pinning_connect() {
let peer_mgr_client = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let peer_mgr_server = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
peer_mgr_client
.get_global_ctx()
.config
.set_network_identity(NetworkIdentity::new("user".to_string(), "sec1".to_string()));
peer_mgr_server
.get_global_ctx()
.config
.set_network_identity(NetworkIdentity {
network_name: "shared".to_string(),
network_secret: None,
network_secret_digest: None,
});
set_secure_mode_cfg(&peer_mgr_client.get_global_ctx(), true);
set_secure_mode_cfg(&peer_mgr_server.get_global_ctx(), true);
let server_pub_b64 = peer_mgr_server
.get_global_ctx()
.config
.get_secure_mode()
.unwrap()
.local_public_key
.unwrap();
let (a_ring, b_ring) = create_ring_tunnel_pair();
let server_remote_url: url::Url = a_ring
.info()
.unwrap()
.remote_addr
.unwrap()
.url
.parse()
.unwrap();
peer_mgr_client.get_global_ctx().config.set_peers(vec![
crate::common::config::PeerConfig {
uri: server_remote_url,
peer_public_key: Some(server_pub_b64.clone()),
},
]);
let (c_ret, s_ret) = tokio::join!(
peer_mgr_client.add_client_tunnel(a_ring, false),
peer_mgr_server.add_tunnel_as_server(b_ring, true)
);
c_ret.unwrap();
s_ret.unwrap();
wait_for_condition(
|| {
let peer_mgr_client = peer_mgr_client.clone();
async move {
let foreign_peer_map =
peer_mgr_client.get_foreign_network_client().get_peer_map();
if foreign_peer_map.list_peers_with_conn().await.len() != 1 {
return false;
}
let Some(peer_id) = foreign_peer_map
.list_peers_with_conn()
.await
.into_iter()
.next()
else {
return false;
};
let Some(conns) = foreign_peer_map.list_peer_conns(peer_id).await else {
return false;
};
conns.iter().any(|c| {
c.secure_auth_level == SecureAuthLevel::SharedNodePubkeyVerified as i32
&& c.noise_local_static_pubkey.len() == 32
&& c.noise_remote_static_pubkey.len() == 32
})
}
},
Duration::from_secs(10),
)
.await;
wait_for_condition(
|| {
let peer_mgr_server = peer_mgr_server.clone();
async move {
let foreigns = peer_mgr_server
.get_foreign_network_manager()
.list_foreign_networks()
.await;
let Some(entry) = foreigns.foreign_networks.get("user") else {
return false;
};
entry.peers.iter().any(|p| {
p.conns
.iter()
.any(|c| c.noise_local_static_pubkey.len() == 32)
})
}
},
Duration::from_secs(10),
)
.await;
}
async fn connect_peer_manager_with<C: TunnelConnector + Debug + 'static, L: TunnelListener>(
client_mgr: Arc<PeerManager>,
server_mgr: &Arc<PeerManager>,
+817
View File
@@ -0,0 +1,817 @@
use std::{
sync::{
atomic::{AtomicU32, Ordering},
Arc, Mutex, RwLock,
},
time::{SystemTime, UNIX_EPOCH},
};
use atomic_shim::AtomicU64;
use anyhow::anyhow;
use dashmap::DashMap;
use hmac::{Hmac, Mac as _};
use rand::RngCore as _;
use sha2::Sha256;
use crate::{
common::PeerId,
peers::encrypt::{create_encryptor, Encryptor},
tunnel::packet_def::{AesGcmTail, ZCPacket},
};
type HmacSha256 = Hmac<Sha256>;
pub struct UpsertResponderSessionReturn {
pub session: Arc<PeerSession>,
pub action: PeerSessionAction,
pub session_generation: u32,
pub root_key: Option<[u8; 32]>,
pub initial_epoch: u32,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PeerSessionAction {
Join,
Sync,
Create,
}
#[derive(PartialEq, Clone, Eq, Hash)]
pub struct SessionKey {
network_name: String,
peer_id: PeerId,
}
impl SessionKey {
pub fn new(network_name: String, peer_id: PeerId) -> Self {
Self {
network_name,
peer_id,
}
}
}
#[derive(Clone)]
pub struct PeerSessionStore {
sessions: Arc<DashMap<SessionKey, Arc<PeerSession>>>,
}
impl Default for PeerSessionStore {
fn default() -> Self {
Self {
sessions: Arc::new(DashMap::new()),
}
}
}
impl PeerSessionStore {
pub fn new() -> Self {
Self::default()
}
pub fn get(&self, key: &SessionKey) -> Option<Arc<PeerSession>> {
self.sessions.get(key).map(|v| v.clone())
}
pub fn upsert_responder_session(
&self,
key: &SessionKey,
a_session_generation: Option<u32>,
send_algorithm: String,
recv_algorithm: String,
) -> Result<UpsertResponderSessionReturn, anyhow::Error> {
let existing = self.sessions.get(key).map(|v| v.clone());
match existing {
None => {
let root_key = PeerSession::new_root_key();
let session_generation = 1u32;
let initial_epoch = 0u32;
let session = Arc::new(PeerSession::new(
key.peer_id,
root_key,
session_generation,
initial_epoch,
send_algorithm,
recv_algorithm,
));
self.sessions.insert(key.clone(), session.clone());
Ok(UpsertResponderSessionReturn {
session,
action: PeerSessionAction::Create,
session_generation,
root_key: Some(root_key),
initial_epoch,
})
}
Some(session) => {
session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?;
let local_gen = session.session_generation();
if a_session_generation.is_some_and(|g| g == local_gen) {
Ok(UpsertResponderSessionReturn {
session,
action: PeerSessionAction::Join,
session_generation: local_gen,
root_key: None,
initial_epoch: 0,
})
} else {
let initial_epoch = session.next_sync_epoch();
let root_key = session.root_key();
Ok(UpsertResponderSessionReturn {
session,
action: PeerSessionAction::Sync,
session_generation: local_gen,
root_key: Some(root_key),
initial_epoch,
})
}
}
}
}
#[allow(clippy::too_many_arguments)]
pub fn apply_initiator_action(
&self,
key: &SessionKey,
action: PeerSessionAction,
b_session_generation: u32,
root_key_32: Option<[u8; 32]>,
initial_epoch: u32,
send_algorithm: String,
recv_algorithm: String,
) -> Result<Arc<PeerSession>, anyhow::Error> {
tracing::info!(
"apply_initiator_action {:?}, send_algorithm: {}, recv_algorithm: {}",
action,
send_algorithm,
recv_algorithm
);
match action {
PeerSessionAction::Join => {
let Some(session) = self.get(key) else {
return Err(anyhow!("no local session for JOIN"));
};
session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?;
if session.session_generation() != b_session_generation {
return Err(anyhow!("JOIN generation mismatch"));
}
Ok(session)
}
PeerSessionAction::Sync | PeerSessionAction::Create => {
let root_key = root_key_32.ok_or_else(|| anyhow!("missing root_key"))?;
let session = self
.sessions
.entry(key.clone())
.or_insert_with(|| {
Arc::new(PeerSession::new(
key.peer_id,
root_key,
b_session_generation,
initial_epoch,
send_algorithm.clone(),
recv_algorithm.clone(),
))
})
.clone();
session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?;
session.sync_root_key(root_key, b_session_generation, initial_epoch);
Ok(session)
}
}
}
}
#[derive(Clone, Default)]
struct EpochKeySlot {
epoch: u32,
generation: u32,
valid: bool,
send_cipher: Option<Arc<dyn Encryptor>>,
recv_cipher: Option<Arc<dyn Encryptor>>,
}
impl std::fmt::Debug for EpochKeySlot {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EpochKeySlot")
.field("epoch", &self.epoch)
.field("generation", &self.generation)
.field("valid", &self.valid)
.finish()
}
}
impl EpochKeySlot {
fn get_encryptor(&self, is_send: bool) -> Arc<dyn Encryptor> {
if is_send {
self.send_cipher.as_ref().unwrap().clone()
} else {
self.recv_cipher.as_ref().unwrap().clone()
}
}
}
#[derive(Debug, Clone, Copy, Default)]
struct ReplayWindow256 {
max_seq: u64,
bitmap: [u8; 32],
valid: bool,
}
impl ReplayWindow256 {
fn clear(&mut self) {
self.max_seq = 0;
self.bitmap.fill(0);
self.valid = false;
}
fn test_bit(&self, idx: usize) -> bool {
let byte = idx / 8;
let bit = idx % 8;
(self.bitmap[byte] >> bit) & 1 == 1
}
fn set_bit(&mut self, idx: usize) {
let byte = idx / 8;
let bit = idx % 8;
self.bitmap[byte] |= 1u8 << bit;
}
fn shift_right(&mut self, shift: usize) {
if shift == 0 {
return;
}
let total_bits = 256usize;
if shift >= total_bits {
self.bitmap.fill(0);
return;
}
let byte_shift = shift / 8;
let bit_shift = shift % 8;
if byte_shift > 0 {
for i in (0..self.bitmap.len()).rev() {
self.bitmap[i] = if i >= byte_shift {
self.bitmap[i - byte_shift]
} else {
0
};
}
}
if bit_shift > 0 {
let mut carry = 0u8;
for b in self.bitmap.iter_mut().rev() {
let new_carry = *b << (8 - bit_shift);
*b = (*b >> bit_shift) | carry;
carry = new_carry;
}
}
}
fn accept(&mut self, seq: u64) -> bool {
if !self.valid {
self.valid = true;
self.max_seq = seq;
self.set_bit(0);
return true;
}
if seq > self.max_seq {
let shift = (seq - self.max_seq) as usize;
self.shift_right(shift);
self.max_seq = seq;
self.set_bit(0);
return true;
}
let delta = (self.max_seq - seq) as usize;
if delta >= 256 {
return false;
}
if self.test_bit(delta) {
return false;
}
self.set_bit(delta);
true
}
}
#[derive(Debug, Clone, Copy, Default)]
struct EpochRxSlot {
epoch: u32,
window: ReplayWindow256,
last_rx_ms: u64,
valid: bool,
}
impl EpochRxSlot {
fn clear(&mut self) {
self.epoch = 0;
self.window.clear();
self.last_rx_ms = 0;
self.valid = false;
}
}
pub struct PeerSession {
peer_id: PeerId,
root_key: RwLock<[u8; 32]>,
session_generation: AtomicU32,
send_epoch: AtomicU32,
send_seq: [AtomicU64; 2],
send_epoch_started_ms: AtomicU64,
send_packets_since_epoch: AtomicU64,
rx_slots: Mutex<[[EpochRxSlot; 2]; 2]>,
key_cache: Mutex<[[EpochKeySlot; 2]; 2]>,
send_cipher_algorithm: String,
recv_cipher_algorithm: String,
}
impl std::fmt::Debug for PeerSession {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PeerSession")
.field("peer_id", &self.peer_id)
.field("root_key", &self.root_key)
.field("session_generation", &self.session_generation)
.field("send_epoch", &self.send_epoch)
.field("send_seq", &self.send_seq)
.field("send_epoch_started_ms", &self.send_epoch_started_ms)
.field("send_packets_since_epoch", &self.send_packets_since_epoch)
.field("rx_slots", &self.rx_slots)
.field("key_cache", &self.key_cache)
.field("send_cipher_algorithm", &self.send_cipher_algorithm)
.field("recv_cipher_algorithm", &self.recv_cipher_algorithm)
.finish()
}
}
impl PeerSession {
/// Idle-eviction timeout for receive slots, in milliseconds.
///
/// If no packets are received for this period (~30 seconds), the
/// corresponding RX slot is considered idle and may be cleared/reused.
/// This helps reclaim state for dead peers or paths while still tolerating
/// short network stalls. Environments with very bursty or high-latency
/// traffic may want to increase this value; low-latency or tightly
/// resource-constrained deployments may lower it.
const EVICT_IDLE_AFTER_MS: u64 = 30_000;
/// Maximum number of packets to send in a single epoch before forcing
/// a key/epoch rotation.
///
/// This bounds the amount of traffic protected under a single set of
/// derived keys, which is a common best practice for long-lived secure
/// channels. The current value (~1 million packets) is a conservative
/// default chosen to balance security (more frequent rotation) and
/// performance (avoiding excessive rekeying). Deployments with very high
/// or very low packet rates may tune this threshold accordingly.
const ROTATE_AFTER_PACKETS: u64 = 1_000_000;
/// Maximum wall-clock lifetime of a send epoch, in milliseconds.
///
/// Even if the packet-based limit is not reached, epochs are rotated
/// after this duration (~10 minutes) to avoid long-lived keys and keep
/// replay windows bounded in time. This also limits the impact of a
/// compromised key. Installations that prioritize lower overhead over
/// more aggressive key rotation may increase this value; those with
/// stricter security requirements may decrease it.
const ROTATE_AFTER_MS: u64 = 10 * 60 * 1000;
const MAX_ACCEPTED_RX_EPOCH_AHEAD: u32 = 3;
pub fn new(
peer_id: PeerId,
root_key: [u8; 32],
session_generation: u32,
initial_epoch: u32,
send_cipher_algorithm: String,
recv_cipher_algorithm: String,
) -> Self {
// let mut root_key_128 = [0u8; 16];
// root_key_128.copy_from_slice(&root_key[..16]);
// let send_cipher = create_encryptor(&send_algorithm, root_key_128, root_key);
// let recv_cipher = create_encryptor(&recv_algorithm, root_key_128, root_key);
let rx_slots = [
[EpochRxSlot::default(), EpochRxSlot::default()],
[EpochRxSlot::default(), EpochRxSlot::default()],
];
let key_cache = [
[EpochKeySlot::default(), EpochKeySlot::default()],
[EpochKeySlot::default(), EpochKeySlot::default()],
];
let now_ms = now_ms();
Self {
peer_id,
root_key: RwLock::new(root_key),
session_generation: AtomicU32::new(session_generation),
send_epoch: AtomicU32::new(initial_epoch),
send_seq: [AtomicU64::new(0), AtomicU64::new(0)],
send_epoch_started_ms: AtomicU64::new(now_ms),
send_packets_since_epoch: AtomicU64::new(0),
rx_slots: Mutex::new(rx_slots),
key_cache: Mutex::new(key_cache),
send_cipher_algorithm,
recv_cipher_algorithm,
}
}
pub fn peer_id(&self) -> PeerId {
self.peer_id
}
pub fn session_generation(&self) -> u32 {
self.session_generation.load(Ordering::Relaxed)
}
pub fn root_key(&self) -> [u8; 32] {
*self.root_key.read().unwrap()
}
pub fn new_root_key() -> [u8; 32] {
let mut out = [0u8; 32];
rand::rngs::OsRng.fill_bytes(&mut out);
out
}
pub fn next_sync_epoch(&self) -> u32 {
let send_epoch = self.send_epoch.load(Ordering::Relaxed);
let rx = self.rx_slots.lock().unwrap();
let mut max_epoch = send_epoch;
for dir in 0..2 {
let cur = rx[dir][0];
if cur.valid {
max_epoch = max_epoch.max(cur.epoch);
}
let prev = rx[dir][1];
if prev.valid {
max_epoch = max_epoch.max(prev.epoch);
}
}
max_epoch.wrapping_add(1)
}
pub fn check_encrypt_algo_same(
&self,
send_algorithm: &str,
recv_algorithm: &str,
) -> Result<(), anyhow::Error> {
if self.send_cipher_algorithm != send_algorithm
|| self.recv_cipher_algorithm != recv_algorithm
{
return Err(anyhow!("encrypt algorithm not same"));
}
Ok(())
}
pub fn sync_root_key(&self, root_key: [u8; 32], session_generation: u32, initial_epoch: u32) {
{
let mut g = self.root_key.write().unwrap();
*g = root_key;
}
self.session_generation
.store(session_generation, Ordering::Relaxed);
self.send_epoch.store(initial_epoch, Ordering::Relaxed);
self.send_seq[0].store(0, Ordering::Relaxed);
self.send_seq[1].store(0, Ordering::Relaxed);
self.send_epoch_started_ms
.store(now_ms(), Ordering::Relaxed);
self.send_packets_since_epoch.store(0, Ordering::Relaxed);
{
let mut rx = self.rx_slots.lock().unwrap();
for dir in 0..2 {
rx[dir][0] = EpochRxSlot {
epoch: initial_epoch,
window: ReplayWindow256::default(),
last_rx_ms: 0,
valid: true,
};
rx[dir][1].clear();
}
}
self.key_cache
.lock()
.unwrap()
.fill([EpochKeySlot::default(), EpochKeySlot::default()]);
}
pub fn dir_for_sender(sender_peer_id: PeerId, receiver_peer_id: PeerId) -> usize {
if sender_peer_id < receiver_peer_id {
0
} else {
1
}
}
fn hkdf_traffic_key(&self, epoch: u32, dir: usize) -> [u8; 32] {
let root_key = self.root_key();
let salt = [0u8; 32];
let mut extract = HmacSha256::new_from_slice(&salt).unwrap();
extract.update(&root_key);
let prk = extract.finalize().into_bytes();
let mut info = Vec::with_capacity(9 + 4 + 1);
info.extend_from_slice(b"et-traffic");
info.extend_from_slice(&epoch.to_be_bytes());
info.push(dir as u8);
let mut expand = HmacSha256::new_from_slice(&prk).unwrap();
expand.update(&info);
expand.update(&[1u8]);
let okm = expand.finalize().into_bytes();
let mut key = [0u8; 32];
key.copy_from_slice(&okm[..32]);
key
}
fn get_encryptor(&self, epoch: u32, dir: usize, is_send: bool) -> Option<Arc<dyn Encryptor>> {
let generation = self.session_generation();
let rx = self.rx_slots.lock().unwrap();
let send_epoch = self.send_epoch.load(Ordering::Relaxed);
let allowed = epoch == send_epoch
|| rx[dir][0].valid && rx[dir][0].epoch == epoch
|| rx[dir][1].valid && rx[dir][1].epoch == epoch;
if !allowed {
return None;
}
let mut guard = self.key_cache.lock().unwrap();
for slot in guard[dir].iter_mut() {
if slot.valid && slot.epoch == epoch && slot.generation == generation {
return Some(slot.get_encryptor(is_send));
}
}
let key = self.hkdf_traffic_key(epoch, dir);
let mut key_128 = [0u8; 16];
key_128.copy_from_slice(&key[..16]);
let slot = EpochKeySlot {
epoch,
generation,
valid: true,
send_cipher: Some(create_encryptor(&self.send_cipher_algorithm, key_128, key)),
recv_cipher: Some(create_encryptor(&self.recv_cipher_algorithm, key_128, key)),
};
let ret = slot.get_encryptor(is_send);
if !guard[dir][0].valid || guard[dir][0].epoch == epoch {
guard[dir][0] = slot;
} else {
guard[dir][1] = slot;
}
Some(ret)
}
fn maybe_rotate_epoch(&self, now_ms: u64) {
let packets = self
.send_packets_since_epoch
.fetch_add(1, Ordering::Relaxed)
+ 1;
let started = self.send_epoch_started_ms.load(Ordering::Relaxed);
if packets < Self::ROTATE_AFTER_PACKETS
&& now_ms.saturating_sub(started) < Self::ROTATE_AFTER_MS
{
return;
}
let cur = self.send_epoch.load(Ordering::Relaxed);
let next = cur.wrapping_add(1);
if self
.send_epoch
.compare_exchange(cur, next, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
self.send_epoch_started_ms.store(now_ms, Ordering::Relaxed);
self.send_packets_since_epoch.store(0, Ordering::Relaxed);
}
}
fn next_nonce(&self, dir: usize) -> (u32, u64, [u8; 12]) {
let now_ms = now_ms();
self.maybe_rotate_epoch(now_ms);
let epoch = self.send_epoch.load(Ordering::Relaxed);
let seq = self.send_seq[dir].fetch_add(1, Ordering::Relaxed);
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&epoch.to_be_bytes());
nonce[4..].copy_from_slice(&seq.to_be_bytes());
(epoch, seq, nonce)
}
fn parse_tail(payload: &[u8]) -> Option<[u8; 12]> {
if payload.len() < std::mem::size_of::<AesGcmTail>() {
return None;
}
let tail_off = payload.len() - std::mem::size_of::<AesGcmTail>();
let tail = &payload[tail_off..];
let mut nonce = [0u8; 12];
nonce.copy_from_slice(&tail[16..]);
Some(nonce)
}
fn evict_old_rx_slots(rx: &mut [[EpochRxSlot; 2]; 2], now_ms: u64) {
for dir_slots in rx.iter_mut() {
for slot in dir_slots.iter_mut() {
if !slot.valid {
continue;
}
let last = slot.last_rx_ms;
if last != 0 && now_ms.saturating_sub(last) > Self::EVICT_IDLE_AFTER_MS {
slot.clear();
}
}
}
}
fn check_replay(&self, epoch: u32, seq: u64, dir: usize, now_ms: u64) -> bool {
let mut rx = self.rx_slots.lock().unwrap();
Self::evict_old_rx_slots(&mut rx, now_ms);
let send_epoch = self.send_epoch.load(Ordering::Relaxed);
{
let mut key_cache = self.key_cache.lock().unwrap();
for d in 0..2 {
for s in 0..2 {
if !key_cache[d][s].valid {
continue;
}
let e = key_cache[d][s].epoch;
let allowed = e == send_epoch
|| rx[d][0].valid && rx[d][0].epoch == e
|| rx[d][1].valid && rx[d][1].epoch == e;
if !allowed {
key_cache[d][s].valid = false;
}
}
}
}
if !rx[dir][0].valid {
rx[dir][0] = EpochRxSlot {
epoch,
window: ReplayWindow256::default(),
last_rx_ms: now_ms,
valid: true,
};
}
if rx[dir][0].valid && epoch == rx[dir][0].epoch {
rx[dir][0].last_rx_ms = now_ms;
return rx[dir][0].window.accept(seq);
}
if rx[dir][1].valid && epoch == rx[dir][1].epoch {
rx[dir][1].last_rx_ms = now_ms;
return rx[dir][1].window.accept(seq);
}
if rx[dir][0].valid && epoch > rx[dir][0].epoch {
let mut baseline_epoch = send_epoch;
if rx[dir][0].valid {
baseline_epoch = baseline_epoch.max(rx[dir][0].epoch);
}
if rx[dir][1].valid {
baseline_epoch = baseline_epoch.max(rx[dir][1].epoch);
}
let max_allowed_epoch =
baseline_epoch.saturating_add(Self::MAX_ACCEPTED_RX_EPOCH_AHEAD);
if epoch > max_allowed_epoch {
return false;
}
rx[dir][1] = rx[dir][0];
rx[dir][0] = EpochRxSlot {
epoch,
window: ReplayWindow256::default(),
last_rx_ms: now_ms,
valid: true,
};
return rx[dir][0].window.accept(seq);
}
false
}
pub fn encrypt_payload(
&self,
sender_peer_id: PeerId,
receiver_peer_id: PeerId,
pkt: &mut ZCPacket,
) -> Result<(), anyhow::Error> {
let dir = Self::dir_for_sender(sender_peer_id, receiver_peer_id);
let (epoch, _seq, nonce_bytes) = self.next_nonce(dir);
let encryptor = self
.get_encryptor(epoch, dir, true)
.ok_or_else(|| anyhow!("no key for epoch"))?;
let _ = encryptor.encrypt_with_nonce(pkt, Some(nonce_bytes.as_slice()));
Ok(())
}
pub fn decrypt_payload(
&self,
sender_peer_id: PeerId,
receiver_peer_id: PeerId,
ciphertext_with_tail: &mut ZCPacket,
) -> Result<(), anyhow::Error> {
let dir = Self::dir_for_sender(sender_peer_id, receiver_peer_id);
let nonce_bytes =
Self::parse_tail(ciphertext_with_tail.payload()).ok_or_else(|| anyhow!("no tail"))?;
let epoch = u32::from_be_bytes(nonce_bytes[..4].try_into().unwrap());
let seq = u64::from_be_bytes(nonce_bytes[4..].try_into().unwrap());
let now_ms = now_ms();
if !self.check_replay(epoch, seq, dir, now_ms) {
return Err(anyhow!("replay rejected"));
}
let encryptor = self
.get_encryptor(epoch, dir, false)
.ok_or_else(|| anyhow!("no key for epoch"))?;
encryptor.decrypt(ciphertext_with_tail)?;
Ok(())
}
}
fn now_ms() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn peer_session_supports_asymmetric_algorithms() {
let a: PeerId = 10;
let b: PeerId = 20;
let root_key = PeerSession::new_root_key();
let generation = 1u32;
let initial_epoch = 0u32;
let sa = PeerSession::new(
b,
root_key,
generation,
initial_epoch,
"aes-256-gcm".to_string(),
"chacha20-poly1305".to_string(),
);
let sb = PeerSession::new(
a,
root_key,
generation,
initial_epoch,
"chacha20-poly1305".to_string(),
"aes-256-gcm".to_string(),
);
let plaintext1 = b"hello from a";
let mut pkt1 = ZCPacket::new_with_payload(plaintext1);
pkt1.fill_peer_manager_hdr(a as u32, b as u32, 0);
sa.encrypt_payload(a, b, &mut pkt1).unwrap();
sb.decrypt_payload(a, b, &mut pkt1).unwrap();
assert_eq!(pkt1.payload(), plaintext1);
let plaintext2 = b"hello from b";
let mut pkt2 = ZCPacket::new_with_payload(plaintext2);
pkt2.fill_peer_manager_hdr(b as u32, a as u32, 0);
sb.encrypt_payload(b, a, &mut pkt2).unwrap();
sa.decrypt_payload(b, a, &mut pkt2).unwrap();
assert_eq!(pkt2.payload(), plaintext2);
}
#[test]
fn replay_rejects_far_future_epoch_without_poisoning_window() {
let peer_id: PeerId = 10;
let root_key = PeerSession::new_root_key();
let generation = 1u32;
let initial_epoch = 0u32;
let s = PeerSession::new(
peer_id,
root_key,
generation,
initial_epoch,
"aes-256-gcm".to_string(),
"aes-256-gcm".to_string(),
);
let now = now_ms();
assert!(s.check_replay(0, 1, 0, now));
assert!(s.check_replay(0, 2, 0, now));
assert!(!s.check_replay(1000, 1, 0, now));
assert!(s.check_replay(1, 1, 0, now + 1));
assert!(s.check_replay(1, 2, 0, now + 2));
}
}
+3
View File
@@ -41,6 +41,9 @@ message PeerConnInfo {
bool is_client = 8;
string network_name = 9;
bool is_closed = 10;
bytes noise_local_static_pubkey = 11;
bytes noise_remote_static_pubkey = 12;
peer_rpc.SecureAuthLevel secure_auth_level = 13;
}
message PeerInfo {
+2
View File
@@ -81,6 +81,8 @@ message NetworkConfig {
optional common.CompressionAlgoPb data_compress_algo = 52;
optional string encryption_algorithm = 53;
optional bool disable_tcp_hole_punching = 54;
common.SecureModeConfig secure_mode = 55;
}
message PortForwardConfig {
+10
View File
@@ -230,3 +230,13 @@ message LimiterConfig {
optional uint64 fill_duration_ms =
3; // default 10ms, the period to fill the bucket
}
message SecureModeConfig {
bool enabled = 1;
// base64(X25519 private key), used by shared node to present a stable identity
optional string local_private_key = 2;
// base64(X25519 public key), required if local_private_key is set
optional string local_public_key = 3;
}
+35
View File
@@ -4,6 +4,7 @@ use std::{
};
use anyhow::Context;
use base64::{prelude::BASE64_STANDARD, Engine as _};
use crate::tunnel::packet_def::CompressorAlgo;
@@ -360,3 +361,37 @@ impl fmt::Debug for Ipv6Addr {
write!(f, "{}", std_ipv6_addr)
}
}
impl SecureModeConfig {
pub fn private_key(&self) -> anyhow::Result<x25519_dalek::StaticSecret> {
let local_private_key = self
.local_private_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("local private key is not set"))?;
let k = BASE64_STANDARD
.decode(local_private_key)
.with_context(|| format!("failed to decode private key: {}", local_private_key))?;
// convert vec to 32b array
let len = k.len();
let k: [u8; 32] = k
.try_into()
.map_err(|_| anyhow::anyhow!("invalid private key length: {}", len))?;
Ok(x25519_dalek::StaticSecret::from(k))
}
pub fn public_key(&self) -> anyhow::Result<x25519_dalek::PublicKey> {
let local_public_key = self
.local_public_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("local public key is not set"))?;
let k = BASE64_STANDARD
.decode(local_public_key)
.with_context(|| format!("failed to decode public key: {}", local_public_key))?;
// convert vec to 32b array
let len = k.len();
let k: [u8; 32] = k
.try_into()
.map_err(|_| anyhow::anyhow!("invalid public key length: {}", len))?;
Ok(x25519_dalek::PublicKey::from(k))
}
}
+42 -1
View File
@@ -251,10 +251,51 @@ message HandshakeRequest {
uint32 version = 3;
repeated string features = 4;
string network_name = 5;
bytes network_secret_digrest = 6;
bytes network_secret_digest = 6;
}
message KcpConnData {
common.SocketAddr src = 1;
common.SocketAddr dst = 4;
}
enum SecureAuthLevel {
None = 0;
EncryptedUnauthenticated = 1;
SharedNodePubkeyVerified = 2;
NetworkSecretConfirmed = 3;
}
enum PeerConnSessionActionPb {
Join = 0;
Sync = 1;
Create = 2;
}
message PeerConnNoiseMsg1Pb {
uint32 version = 1;
string a_network_name = 2;
optional uint32 a_session_generation = 3;
common.UUID a_conn_id = 4;
string client_encryption_algorithm = 5;
}
message PeerConnNoiseMsg2Pb {
string b_network_name = 1;
uint32 role_hint = 2;
PeerConnSessionActionPb action = 3;
uint32 b_session_generation = 4;
optional bytes root_key_32 = 5;
uint32 initial_epoch = 6;
common.UUID b_conn_id = 7;
common.UUID a_conn_id_echo = 8;
optional bytes secret_proof_32 = 9;
string server_encryption_algorithm = 10;
}
message PeerConnNoiseMsg3Pb {
common.UUID a_conn_id_echo = 1;
common.UUID b_conn_id_echo = 2;
optional bytes secret_proof_32 = 3;
bytes secret_digest = 4;
}
+3 -3
View File
@@ -220,14 +220,14 @@ impl TunnelFilter for PacketRecorderTunnelFilter {
type FilterOutput = (Vec<ZCPacket>, Vec<ZCPacket>);
fn before_send(&self, data: SinkItem) -> Option<SinkItem> {
self.received.lock().unwrap().push(data.clone());
self.sent.lock().unwrap().push(data.clone());
Some(data)
}
fn after_received(&self, data: StreamItem) -> Option<StreamItem> {
match data {
Ok(v) => {
self.sent.lock().unwrap().push(v.clone());
self.received.lock().unwrap().push(v.clone());
Some(Ok(v))
}
Err(e) => Some(Err(e)),
@@ -236,8 +236,8 @@ impl TunnelFilter for PacketRecorderTunnelFilter {
fn filter_output(&self) -> Self::FilterOutput {
(
self.received.lock().unwrap().clone(),
self.sent.lock().unwrap().clone(),
self.received.lock().unwrap().clone(),
)
}
}
+4 -1
View File
@@ -56,7 +56,7 @@ pub struct WGTunnelHeader {
}
pub const WG_TUNNEL_HEADER_SIZE: usize = std::mem::size_of::<WGTunnelHeader>();
#[derive(AsBytes, FromZeroes, Clone, Debug)]
#[derive(AsBytes, FromZeroes, Copy, Clone, Debug)]
#[repr(u8)]
pub enum PacketType {
Invalid = 0,
@@ -72,6 +72,9 @@ pub enum PacketType {
ForeignNetworkPacket = 10,
KcpSrc = 11,
KcpDst = 12,
NoiseHandshakeMsg1 = 13,
NoiseHandshakeMsg2 = 14,
NoiseHandshakeMsg3 = 15,
}
bitflags::bitflags! {