Introduce secure mode (part 1) (#1808)

Use noise protocol on handshake. Check peer's public key if needed. Also support rekey and replay attack prevention.

E2EE and temporary password will be implemented based on this.
This commit is contained in:
KKRainbow
2026-01-25 20:16:51 +08:00
committed by GitHub
parent ffa08d1c43
commit 101f416268
29 changed files with 3320 additions and 91 deletions
+56 -7
View File
@@ -84,6 +84,14 @@ impl Encryptor for AesGcmCipher {
}
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
self.encrypt_with_nonce(zc_packet, None)
}
fn encrypt_with_nonce(
&self,
zc_packet: &mut ZCPacket,
nonce: Option<&[u8]>,
) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_encrypted() {
tracing::warn!(?zc_packet, "packet is already encrypted");
@@ -91,16 +99,28 @@ impl Encryptor for AesGcmCipher {
}
let mut tail = AesGcmTail::default();
if let Some(nonce) = nonce {
if nonce.len() != tail.nonce.len() {
return Err(Error::EncryptionFailed);
}
tail.nonce.copy_from_slice(nonce);
}
let rs = match &self.cipher {
AesGcmEnum::AES128GCM(aes_gcm) => {
let nonce = Aes128Gcm::generate_nonce(&mut OsRng);
tail.nonce.copy_from_slice(nonce.as_slice());
aes_gcm.encrypt_in_place_detached(&nonce, &[], zc_packet.mut_payload())
if nonce.is_none() {
let nonce = Aes128Gcm::generate_nonce(&mut OsRng);
tail.nonce.copy_from_slice(nonce.as_slice());
}
let nonce = Nonce::from_slice(&tail.nonce);
aes_gcm.encrypt_in_place_detached(nonce, &[], zc_packet.mut_payload())
}
AesGcmEnum::AES256GCM(aes_gcm) => {
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
tail.nonce.copy_from_slice(nonce.as_slice());
aes_gcm.encrypt_in_place_detached(&nonce, &[], zc_packet.mut_payload())
if nonce.is_none() {
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
tail.nonce.copy_from_slice(nonce.as_slice());
}
let nonce = Nonce::from_slice(&tail.nonce);
aes_gcm.encrypt_in_place_detached(nonce, &[], zc_packet.mut_payload())
}
};
@@ -122,8 +142,9 @@ impl Encryptor for AesGcmCipher {
mod tests {
use crate::{
peers::encrypt::{aes_gcm::AesGcmCipher, Encryptor},
tunnel::packet_def::{ZCPacket, AES_GCM_ENCRYPTION_RESERVED},
tunnel::packet_def::{AesGcmTail, ZCPacket, AES_GCM_ENCRYPTION_RESERVED},
};
use zerocopy::FromBytes;
#[test]
fn test_aes_gcm_cipher() {
@@ -143,4 +164,32 @@ mod tests {
assert_eq!(packet.payload(), text);
assert!(!packet.peer_manager_header().unwrap().is_encrypted());
}
#[test]
fn test_aes_gcm_cipher_with_nonce() {
let key = [7u8; 16];
let cipher = AesGcmCipher::new_128(key);
let text = b"Hello";
let nonce = [3u8; 12];
let mut packet1 = ZCPacket::new_with_payload(text);
packet1.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet1, Some(&nonce))
.unwrap();
let mut packet2 = ZCPacket::new_with_payload(text);
packet2.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet2, Some(&nonce))
.unwrap();
assert_eq!(packet1.payload(), packet2.payload());
let tail = AesGcmTail::ref_from_suffix(packet1.payload()).unwrap();
assert_eq!(tail.nonce, nonce);
cipher.decrypt(&mut packet1).unwrap();
assert_eq!(packet1.payload(), text);
}
}
+7
View File
@@ -30,6 +30,13 @@ pub enum Error {
pub trait Encryptor: Send + Sync + 'static {
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error>;
fn encrypt_with_nonce(
&self,
zc_packet: &mut ZCPacket,
_nonce: Option<&[u8]>,
) -> Result<(), Error> {
self.encrypt(zc_packet)
}
fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error>;
}
+48 -1
View File
@@ -142,6 +142,14 @@ impl Encryptor for OpenSslCipher {
}
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
self.encrypt_with_nonce(zc_packet, None)
}
fn encrypt_with_nonce(
&self,
zc_packet: &mut ZCPacket,
nonce: Option<&[u8]>,
) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_encrypted() {
tracing::warn!(?zc_packet, "packet is already encrypted");
@@ -153,7 +161,14 @@ impl Encryptor for OpenSslCipher {
let nonce_size = self.get_nonce_size();
let mut tail = OpenSslTail::default();
rand::thread_rng().fill_bytes(&mut tail.nonce[..nonce_size]);
if let Some(nonce) = nonce {
if nonce.len() != nonce_size {
return Err(Error::EncryptionFailed);
}
tail.nonce[..nonce_size].copy_from_slice(nonce);
} else {
rand::thread_rng().fill_bytes(&mut tail.nonce[..nonce_size]);
}
let mut encrypter =
Crypter::new(cipher, Mode::Encrypt, key, Some(&tail.nonce[..nonce_size]))
@@ -198,6 +213,7 @@ mod tests {
peers::encrypt::{openssl_cipher::OpenSslCipher, Encryptor},
tunnel::packet_def::ZCPacket,
};
use zerocopy::FromBytes;
use super::OPENSSL_ENCRYPTION_RESERVED;
@@ -220,6 +236,37 @@ mod tests {
assert!(!packet.peer_manager_header().unwrap().is_encrypted());
}
#[test]
fn test_openssl_aes128_gcm_with_nonce() {
let key = [7u8; 16];
let cipher = OpenSslCipher::new_aes128_gcm(key);
let text = b"Hello";
let nonce = [3u8; 12];
let mut packet1 = ZCPacket::new_with_payload(text);
packet1.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet1, Some(&nonce))
.unwrap();
let mut packet2 = ZCPacket::new_with_payload(text);
packet2.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet2, Some(&nonce))
.unwrap();
assert_eq!(packet1.payload(), packet2.payload());
assert!(packet1.payload().len() > text.len() + OPENSSL_ENCRYPTION_RESERVED);
let tail = super::OpenSslTail::ref_from_suffix(packet1.payload())
.unwrap()
.clone();
assert_eq!(&tail.nonce[..nonce.len()], nonce);
cipher.decrypt(&mut packet1).unwrap();
assert_eq!(packet1.payload(), text);
}
#[test]
fn test_openssl_chacha20() {
let key = [0u8; 32];
+46 -2
View File
@@ -93,6 +93,14 @@ impl Encryptor for AesGcmCipher {
}
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
self.encrypt_with_nonce(zc_packet, None)
}
fn encrypt_with_nonce(
&self,
zc_packet: &mut ZCPacket,
nonce: Option<&[u8]>,
) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_encrypted() {
tracing::warn!(?zc_packet, "packet is already encrypted");
@@ -100,7 +108,14 @@ impl Encryptor for AesGcmCipher {
}
let mut tail = AesGcmTail::default();
rand::thread_rng().fill_bytes(&mut tail.nonce);
if let Some(nonce) = nonce {
if nonce.len() != tail.nonce.len() {
return Err(Error::EncryptionFailed);
}
tail.nonce.copy_from_slice(nonce);
} else {
rand::thread_rng().fill_bytes(&mut tail.nonce);
}
let nonce = aead::Nonce::assume_unique_for_key(tail.nonce);
let rs = match &self.cipher {
@@ -137,8 +152,9 @@ impl Encryptor for AesGcmCipher {
mod tests {
use crate::{
peers::encrypt::{ring_aes_gcm::AesGcmCipher, Encryptor},
tunnel::packet_def::{ZCPacket, AES_GCM_ENCRYPTION_RESERVED},
tunnel::packet_def::{AesGcmTail, ZCPacket, AES_GCM_ENCRYPTION_RESERVED},
};
use zerocopy::FromBytes;
#[test]
fn test_aes_gcm_cipher() {
@@ -158,4 +174,32 @@ mod tests {
assert_eq!(packet.payload(), text);
assert!(!packet.peer_manager_header().unwrap().is_encrypted());
}
#[test]
fn test_aes_gcm_cipher_with_nonce() {
let key = [7u8; 16];
let cipher = AesGcmCipher::new_128(key);
let text = b"Hello";
let nonce = [3u8; 12];
let mut packet1 = ZCPacket::new_with_payload(text);
packet1.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet1, Some(&nonce))
.unwrap();
let mut packet2 = ZCPacket::new_with_payload(text);
packet2.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet2, Some(&nonce))
.unwrap();
assert_eq!(packet1.payload(), packet2.payload());
let tail = AesGcmTail::ref_from_suffix(packet1.payload()).unwrap();
assert_eq!(tail.nonce, nonce);
cipher.decrypt(&mut packet1).unwrap();
assert_eq!(packet1.payload(), text);
}
}
+46 -1
View File
@@ -67,6 +67,14 @@ impl Encryptor for RingChaCha20Cipher {
}
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
self.encrypt_with_nonce(zc_packet, None)
}
fn encrypt_with_nonce(
&self,
zc_packet: &mut ZCPacket,
nonce: Option<&[u8]>,
) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_encrypted() {
tracing::warn!(?zc_packet, "packet is already encrypted");
@@ -74,7 +82,14 @@ impl Encryptor for RingChaCha20Cipher {
}
let mut tail = ChaCha20Poly1305Tail::default();
rand::thread_rng().fill_bytes(&mut tail.nonce);
if let Some(nonce) = nonce {
if nonce.len() != tail.nonce.len() {
return Err(Error::EncryptionFailed);
}
tail.nonce.copy_from_slice(nonce);
} else {
rand::thread_rng().fill_bytes(&mut tail.nonce);
}
let nonce = Nonce::assume_unique_for_key(tail.nonce);
let rs =
@@ -100,6 +115,7 @@ mod tests {
peers::encrypt::{ring_chacha20::RingChaCha20Cipher, Encryptor},
tunnel::packet_def::ZCPacket,
};
use zerocopy::FromBytes;
use super::CHACHA20_POLY1305_ENCRYPTION_RESERVED;
@@ -122,4 +138,33 @@ mod tests {
assert_eq!(packet.payload(), text);
assert!(!packet.peer_manager_header().unwrap().is_encrypted());
}
#[test]
fn test_ring_chacha20_cipher_with_nonce() {
let key = [9u8; 32];
let cipher = RingChaCha20Cipher::new(key);
let text = b"Hello";
let nonce = [5u8; 12];
let mut packet1 = ZCPacket::new_with_payload(text);
packet1.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet1, Some(&nonce))
.unwrap();
let mut packet2 = ZCPacket::new_with_payload(text);
packet2.fill_peer_manager_hdr(0, 0, 0);
cipher
.encrypt_with_nonce(&mut packet2, Some(&nonce))
.unwrap();
assert_eq!(packet1.payload(), packet2.payload());
let tail = super::ChaCha20Poly1305Tail::ref_from_suffix(packet1.payload()).unwrap();
assert_eq!(tail.nonce, nonce);
cipher.decrypt(&mut packet1).unwrap();
assert_eq!(packet1.payload(), text);
assert!(!packet1.peer_manager_header().unwrap().is_encrypted());
}
}
+1 -1
View File
@@ -2,7 +2,6 @@ mod graph_algo;
pub mod acl_filter;
pub mod peer;
// pub mod peer_conn;
pub mod peer_conn;
pub mod peer_conn_ping;
pub mod peer_manager;
@@ -10,6 +9,7 @@ pub mod peer_map;
pub mod peer_ospf_route;
pub mod peer_rpc;
pub mod peer_rpc_service;
pub mod peer_session;
pub mod route_trait;
pub mod rpc_service;
+15 -6
View File
@@ -238,12 +238,12 @@ impl Drop for Peer {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use tokio::time::timeout;
use crate::{
common::{global_ctx::tests::get_mock_global_ctx, new_peer_id},
peers::{create_packet_recv_chan, peer_conn::PeerConn},
peers::{create_packet_recv_chan, peer_conn::PeerConn, peer_session::PeerSessionStore},
tunnel::ring::create_ring_tunnel_pair,
};
@@ -257,11 +257,20 @@ mod tests {
let local_peer = Peer::new(new_peer_id(), local_packet_send, global_ctx.clone());
let remote_peer = Peer::new(new_peer_id(), remote_packet_send, global_ctx.clone());
let ps = Arc::new(PeerSessionStore::new());
let (local_tunnel, remote_tunnel) = create_ring_tunnel_pair();
let mut local_peer_conn =
PeerConn::new(local_peer.peer_node_id, global_ctx.clone(), local_tunnel);
let mut remote_peer_conn =
PeerConn::new(remote_peer.peer_node_id, global_ctx.clone(), remote_tunnel);
let mut local_peer_conn = PeerConn::new(
local_peer.peer_node_id,
global_ctx.clone(),
local_tunnel,
ps.clone(),
);
let mut remote_peer_conn = PeerConn::new(
remote_peer.peer_node_id,
global_ctx.clone(),
remote_tunnel,
ps.clone(),
);
assert!(!local_peer_conn.handshake_done());
assert!(!remote_peer_conn.handshake_done());
File diff suppressed because it is too large Load Diff
+328 -11
View File
@@ -32,6 +32,7 @@ use crate::{
peers::{
peer_conn::PeerConn,
peer_rpc::PeerRpcManagerTransport,
peer_session::PeerSessionStore,
recv_packet_from_chan,
route_trait::{ForeignNetworkRouteInfoMap, MockRoute, NextHopPolicy, RouteInterface},
PeerPacketFilter,
@@ -160,6 +161,8 @@ pub struct PeerManager {
allow_loopback_tunnel: AtomicBool,
self_tx_counters: SelfTxCounters,
peer_session_store: Arc<PeerSessionStore>,
}
impl Debug for PeerManager {
@@ -312,6 +315,8 @@ impl PeerManager {
allow_loopback_tunnel: AtomicBool::new(true),
self_tx_counters,
peer_session_store: Arc::new(PeerSessionStore::new()),
}
}
@@ -363,7 +368,23 @@ impl PeerManager {
tunnel: Box<dyn Tunnel>,
is_directly_connected: bool,
) -> Result<(PeerId, PeerConnId), Error> {
let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
self.add_client_tunnel_with_peer_id_hint(tunnel, is_directly_connected, None)
.await
}
pub async fn add_client_tunnel_with_peer_id_hint(
&self,
tunnel: Box<dyn Tunnel>,
is_directly_connected: bool,
peer_id_hint: Option<PeerId>,
) -> Result<(PeerId, PeerConnId), Error> {
let mut peer = PeerConn::new_with_peer_id_hint(
self.my_peer_id,
self.global_ctx.clone(),
tunnel,
peer_id_hint,
self.peer_session_store.clone(),
);
peer.set_is_hole_punched(!is_directly_connected);
peer.do_handshake_as_client().await?;
let conn_id = peer.get_conn_id();
@@ -387,9 +408,19 @@ impl PeerManager {
}
#[tracing::instrument]
pub async fn try_direct_connect<C>(
pub async fn try_direct_connect<C>(&self, connector: C) -> Result<(PeerId, PeerConnId), Error>
where
C: TunnelConnector + Debug,
{
self.try_direct_connect_with_peer_id_hint(connector, None)
.await
}
#[tracing::instrument]
pub async fn try_direct_connect_with_peer_id_hint<C>(
&self,
mut connector: C,
peer_id_hint: Option<PeerId>,
) -> Result<(PeerId, PeerConnId), Error>
where
C: TunnelConnector + Debug,
@@ -398,7 +429,8 @@ impl PeerManager {
let t = ns
.run_async(|| async move { connector.connect().await })
.await?;
self.add_client_tunnel(t, true).await
self.add_client_tunnel_with_peer_id_hint(t, true, peer_id_hint)
.await
}
// avoid loop back to virtual network
@@ -447,9 +479,14 @@ impl PeerManager {
tracing::info!("add tunnel as server start");
self.check_remote_addr_not_from_virtual_network(&tunnel)?;
let mut conn = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
conn.do_handshake_as_server_ext(|peer, msg| {
if msg.network_name
let mut conn = PeerConn::new(
self.my_peer_id,
self.global_ctx.clone(),
tunnel,
self.peer_session_store.clone(),
);
conn.do_handshake_as_server_ext(|peer, network_name:&str| {
if network_name
== self.global_ctx.get_network_identity().network_name
{
return Ok(());
@@ -463,9 +500,9 @@ impl PeerManager {
let mut peer_id = self
.foreign_network_manager
.get_network_peer_id(&msg.network_name);
.get_network_peer_id(network_name);
if peer_id.is_none() {
peer_id = Some(*self.reserved_my_peer_id_map.entry(msg.network_name.clone()).or_insert_with(|| {
peer_id = Some(*self.reserved_my_peer_id_map.entry(network_name.to_string()).or_insert_with(|| {
rand::random::<PeerId>()
}).value());
}
@@ -473,7 +510,7 @@ impl PeerManager {
tracing::info!(
?peer_id,
?msg.network_name,
?network_name,
"handshake as server with foreign network, new peer id: {}, peer id in foreign manager: {:?}",
peer.get_my_peer_id(), peer_id
);
@@ -1464,7 +1501,10 @@ mod tests {
use std::{fmt::Debug, sync::Arc, time::Duration};
use crate::{
common::{config::Flags, global_ctx::tests::get_mock_global_ctx},
common::{
config::Flags,
global_ctx::{tests::get_mock_global_ctx, NetworkIdentity},
},
connector::{
create_connector_by_url, direct::PeerManagerForDirectConnector,
udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun,
@@ -1472,6 +1512,7 @@ mod tests {
instance::listeners::get_listener_by_url,
peers::{
create_packet_recv_chan,
peer_conn::tests::set_secure_mode_cfg,
peer_manager::RouteAlgoType,
peer_rpc::tests::register_service,
route_trait::NextHopPolicy,
@@ -1480,7 +1521,10 @@ mod tests {
wait_route_appear_with_cost,
},
},
proto::common::{CompressionAlgoPb, NatType, PeerFeatureFlag},
proto::{
common::{CompressionAlgoPb, NatType, PeerFeatureFlag},
peer_rpc::SecureAuthLevel,
},
tunnel::{
common::tests::wait_for_condition,
filter::{tests::DropSendTunnelFilter, TunnelWithFilter},
@@ -1523,6 +1567,279 @@ mod tests {
.await;
}
#[tokio::test]
async fn peer_manager_safe_mode_connect_between_peers() {
let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
peer_mgr_a
.get_global_ctx()
.config
.set_network_identity(NetworkIdentity::new("net1".to_string(), "sec1".to_string()));
peer_mgr_b
.get_global_ctx()
.config
.set_network_identity(NetworkIdentity::new("net1".to_string(), "sec1".to_string()));
set_secure_mode_cfg(&peer_mgr_a.get_global_ctx(), true);
set_secure_mode_cfg(&peer_mgr_b.get_global_ctx(), true);
let (a_ring, b_ring) = create_ring_tunnel_pair();
let (a_ret, b_ret) = tokio::join!(
peer_mgr_a.add_client_tunnel(a_ring, false),
peer_mgr_b.add_tunnel_as_server(b_ring, true)
);
let (peer_b_id, _) = a_ret.unwrap();
b_ret.unwrap();
wait_for_condition(
|| {
let peer_mgr_a = peer_mgr_a.clone();
async move {
if !peer_mgr_a
.get_peer_map()
.list_peers_with_conn()
.await
.contains(&peer_b_id)
{
return false;
}
let Some(conns) = peer_mgr_a.get_peer_map().list_peer_conns(peer_b_id).await
else {
return false;
};
conns.iter().any(|c| {
c.noise_local_static_pubkey.len() == 32
&& c.noise_remote_static_pubkey.len() == 32
&& c.secure_auth_level == SecureAuthLevel::NetworkSecretConfirmed as i32
})
}
},
Duration::from_secs(10),
)
.await;
let peer_a_id = peer_mgr_a.my_peer_id();
wait_for_condition(
|| {
let peer_mgr_b = peer_mgr_b.clone();
async move {
if !peer_mgr_b
.get_peer_map()
.list_peers_with_conn()
.await
.contains(&peer_a_id)
{
return false;
}
let Some(conns) = peer_mgr_b.get_peer_map().list_peer_conns(peer_a_id).await
else {
return false;
};
conns.iter().any(|c| {
c.noise_local_static_pubkey.len() == 32
&& c.noise_remote_static_pubkey.len() == 32
&& c.secure_auth_level == SecureAuthLevel::NetworkSecretConfirmed as i32
})
}
},
Duration::from_secs(10),
)
.await;
}
#[tokio::test]
async fn peer_manager_safe_server_accept_legacy_client() {
let peer_mgr_client = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let peer_mgr_server = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
peer_mgr_client
.get_global_ctx()
.config
.set_network_identity(NetworkIdentity::new("net1".to_string(), "sec1".to_string()));
peer_mgr_server
.get_global_ctx()
.config
.set_network_identity(NetworkIdentity::new("net1".to_string(), "sec1".to_string()));
set_secure_mode_cfg(&peer_mgr_server.get_global_ctx(), true);
let (c_ring, s_ring) = create_ring_tunnel_pair();
let (c_ret, s_ret) = tokio::join!(
peer_mgr_client.add_client_tunnel(c_ring, false),
peer_mgr_server.add_tunnel_as_server(s_ring, true)
);
let (server_id, _) = c_ret.unwrap();
s_ret.unwrap();
wait_for_condition(
|| {
let peer_mgr_client = peer_mgr_client.clone();
async move {
if !peer_mgr_client
.get_peer_map()
.list_peers_with_conn()
.await
.contains(&server_id)
{
return false;
}
let Some(conns) = peer_mgr_client
.get_peer_map()
.list_peer_conns(server_id)
.await
else {
return false;
};
conns.iter().any(|c| {
c.noise_local_static_pubkey.is_empty()
&& c.noise_remote_static_pubkey.is_empty()
&& c.secure_auth_level == SecureAuthLevel::None as i32
})
}
},
Duration::from_secs(10),
)
.await;
let client_id = peer_mgr_client.my_peer_id();
wait_for_condition(
|| {
let peer_mgr_server = peer_mgr_server.clone();
async move {
if !peer_mgr_server
.get_peer_map()
.list_peers_with_conn()
.await
.contains(&client_id)
{
return false;
}
let Some(conns) = peer_mgr_server
.get_peer_map()
.list_peer_conns(client_id)
.await
else {
return false;
};
conns.iter().any(|c| {
c.noise_local_static_pubkey.is_empty()
&& c.noise_remote_static_pubkey.is_empty()
&& c.secure_auth_level == SecureAuthLevel::None as i32
})
}
},
Duration::from_secs(5),
)
.await;
}
#[tokio::test]
async fn peer_manager_safe_mode_shared_node_pinning_connect() {
let peer_mgr_client = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let peer_mgr_server = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
peer_mgr_client
.get_global_ctx()
.config
.set_network_identity(NetworkIdentity::new("user".to_string(), "sec1".to_string()));
peer_mgr_server
.get_global_ctx()
.config
.set_network_identity(NetworkIdentity {
network_name: "shared".to_string(),
network_secret: None,
network_secret_digest: None,
});
set_secure_mode_cfg(&peer_mgr_client.get_global_ctx(), true);
set_secure_mode_cfg(&peer_mgr_server.get_global_ctx(), true);
let server_pub_b64 = peer_mgr_server
.get_global_ctx()
.config
.get_secure_mode()
.unwrap()
.local_public_key
.unwrap();
let (a_ring, b_ring) = create_ring_tunnel_pair();
let server_remote_url: url::Url = a_ring
.info()
.unwrap()
.remote_addr
.unwrap()
.url
.parse()
.unwrap();
peer_mgr_client.get_global_ctx().config.set_peers(vec![
crate::common::config::PeerConfig {
uri: server_remote_url,
peer_public_key: Some(server_pub_b64.clone()),
},
]);
let (c_ret, s_ret) = tokio::join!(
peer_mgr_client.add_client_tunnel(a_ring, false),
peer_mgr_server.add_tunnel_as_server(b_ring, true)
);
c_ret.unwrap();
s_ret.unwrap();
wait_for_condition(
|| {
let peer_mgr_client = peer_mgr_client.clone();
async move {
let foreign_peer_map =
peer_mgr_client.get_foreign_network_client().get_peer_map();
if foreign_peer_map.list_peers_with_conn().await.len() != 1 {
return false;
}
let Some(peer_id) = foreign_peer_map
.list_peers_with_conn()
.await
.into_iter()
.next()
else {
return false;
};
let Some(conns) = foreign_peer_map.list_peer_conns(peer_id).await else {
return false;
};
conns.iter().any(|c| {
c.secure_auth_level == SecureAuthLevel::SharedNodePubkeyVerified as i32
&& c.noise_local_static_pubkey.len() == 32
&& c.noise_remote_static_pubkey.len() == 32
})
}
},
Duration::from_secs(10),
)
.await;
wait_for_condition(
|| {
let peer_mgr_server = peer_mgr_server.clone();
async move {
let foreigns = peer_mgr_server
.get_foreign_network_manager()
.list_foreign_networks()
.await;
let Some(entry) = foreigns.foreign_networks.get("user") else {
return false;
};
entry.peers.iter().any(|p| {
p.conns
.iter()
.any(|c| c.noise_local_static_pubkey.len() == 32)
})
}
},
Duration::from_secs(10),
)
.await;
}
async fn connect_peer_manager_with<C: TunnelConnector + Debug + 'static, L: TunnelListener>(
client_mgr: Arc<PeerManager>,
server_mgr: &Arc<PeerManager>,
+817
View File
@@ -0,0 +1,817 @@
use std::{
sync::{
atomic::{AtomicU32, Ordering},
Arc, Mutex, RwLock,
},
time::{SystemTime, UNIX_EPOCH},
};
use atomic_shim::AtomicU64;
use anyhow::anyhow;
use dashmap::DashMap;
use hmac::{Hmac, Mac as _};
use rand::RngCore as _;
use sha2::Sha256;
use crate::{
common::PeerId,
peers::encrypt::{create_encryptor, Encryptor},
tunnel::packet_def::{AesGcmTail, ZCPacket},
};
type HmacSha256 = Hmac<Sha256>;
pub struct UpsertResponderSessionReturn {
pub session: Arc<PeerSession>,
pub action: PeerSessionAction,
pub session_generation: u32,
pub root_key: Option<[u8; 32]>,
pub initial_epoch: u32,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PeerSessionAction {
Join,
Sync,
Create,
}
#[derive(PartialEq, Clone, Eq, Hash)]
pub struct SessionKey {
network_name: String,
peer_id: PeerId,
}
impl SessionKey {
pub fn new(network_name: String, peer_id: PeerId) -> Self {
Self {
network_name,
peer_id,
}
}
}
#[derive(Clone)]
pub struct PeerSessionStore {
sessions: Arc<DashMap<SessionKey, Arc<PeerSession>>>,
}
impl Default for PeerSessionStore {
fn default() -> Self {
Self {
sessions: Arc::new(DashMap::new()),
}
}
}
impl PeerSessionStore {
pub fn new() -> Self {
Self::default()
}
pub fn get(&self, key: &SessionKey) -> Option<Arc<PeerSession>> {
self.sessions.get(key).map(|v| v.clone())
}
pub fn upsert_responder_session(
&self,
key: &SessionKey,
a_session_generation: Option<u32>,
send_algorithm: String,
recv_algorithm: String,
) -> Result<UpsertResponderSessionReturn, anyhow::Error> {
let existing = self.sessions.get(key).map(|v| v.clone());
match existing {
None => {
let root_key = PeerSession::new_root_key();
let session_generation = 1u32;
let initial_epoch = 0u32;
let session = Arc::new(PeerSession::new(
key.peer_id,
root_key,
session_generation,
initial_epoch,
send_algorithm,
recv_algorithm,
));
self.sessions.insert(key.clone(), session.clone());
Ok(UpsertResponderSessionReturn {
session,
action: PeerSessionAction::Create,
session_generation,
root_key: Some(root_key),
initial_epoch,
})
}
Some(session) => {
session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?;
let local_gen = session.session_generation();
if a_session_generation.is_some_and(|g| g == local_gen) {
Ok(UpsertResponderSessionReturn {
session,
action: PeerSessionAction::Join,
session_generation: local_gen,
root_key: None,
initial_epoch: 0,
})
} else {
let initial_epoch = session.next_sync_epoch();
let root_key = session.root_key();
Ok(UpsertResponderSessionReturn {
session,
action: PeerSessionAction::Sync,
session_generation: local_gen,
root_key: Some(root_key),
initial_epoch,
})
}
}
}
}
#[allow(clippy::too_many_arguments)]
pub fn apply_initiator_action(
&self,
key: &SessionKey,
action: PeerSessionAction,
b_session_generation: u32,
root_key_32: Option<[u8; 32]>,
initial_epoch: u32,
send_algorithm: String,
recv_algorithm: String,
) -> Result<Arc<PeerSession>, anyhow::Error> {
tracing::info!(
"apply_initiator_action {:?}, send_algorithm: {}, recv_algorithm: {}",
action,
send_algorithm,
recv_algorithm
);
match action {
PeerSessionAction::Join => {
let Some(session) = self.get(key) else {
return Err(anyhow!("no local session for JOIN"));
};
session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?;
if session.session_generation() != b_session_generation {
return Err(anyhow!("JOIN generation mismatch"));
}
Ok(session)
}
PeerSessionAction::Sync | PeerSessionAction::Create => {
let root_key = root_key_32.ok_or_else(|| anyhow!("missing root_key"))?;
let session = self
.sessions
.entry(key.clone())
.or_insert_with(|| {
Arc::new(PeerSession::new(
key.peer_id,
root_key,
b_session_generation,
initial_epoch,
send_algorithm.clone(),
recv_algorithm.clone(),
))
})
.clone();
session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?;
session.sync_root_key(root_key, b_session_generation, initial_epoch);
Ok(session)
}
}
}
}
#[derive(Clone, Default)]
struct EpochKeySlot {
epoch: u32,
generation: u32,
valid: bool,
send_cipher: Option<Arc<dyn Encryptor>>,
recv_cipher: Option<Arc<dyn Encryptor>>,
}
impl std::fmt::Debug for EpochKeySlot {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EpochKeySlot")
.field("epoch", &self.epoch)
.field("generation", &self.generation)
.field("valid", &self.valid)
.finish()
}
}
impl EpochKeySlot {
fn get_encryptor(&self, is_send: bool) -> Arc<dyn Encryptor> {
if is_send {
self.send_cipher.as_ref().unwrap().clone()
} else {
self.recv_cipher.as_ref().unwrap().clone()
}
}
}
#[derive(Debug, Clone, Copy, Default)]
struct ReplayWindow256 {
max_seq: u64,
bitmap: [u8; 32],
valid: bool,
}
impl ReplayWindow256 {
fn clear(&mut self) {
self.max_seq = 0;
self.bitmap.fill(0);
self.valid = false;
}
fn test_bit(&self, idx: usize) -> bool {
let byte = idx / 8;
let bit = idx % 8;
(self.bitmap[byte] >> bit) & 1 == 1
}
fn set_bit(&mut self, idx: usize) {
let byte = idx / 8;
let bit = idx % 8;
self.bitmap[byte] |= 1u8 << bit;
}
fn shift_right(&mut self, shift: usize) {
if shift == 0 {
return;
}
let total_bits = 256usize;
if shift >= total_bits {
self.bitmap.fill(0);
return;
}
let byte_shift = shift / 8;
let bit_shift = shift % 8;
if byte_shift > 0 {
for i in (0..self.bitmap.len()).rev() {
self.bitmap[i] = if i >= byte_shift {
self.bitmap[i - byte_shift]
} else {
0
};
}
}
if bit_shift > 0 {
let mut carry = 0u8;
for b in self.bitmap.iter_mut().rev() {
let new_carry = *b << (8 - bit_shift);
*b = (*b >> bit_shift) | carry;
carry = new_carry;
}
}
}
fn accept(&mut self, seq: u64) -> bool {
if !self.valid {
self.valid = true;
self.max_seq = seq;
self.set_bit(0);
return true;
}
if seq > self.max_seq {
let shift = (seq - self.max_seq) as usize;
self.shift_right(shift);
self.max_seq = seq;
self.set_bit(0);
return true;
}
let delta = (self.max_seq - seq) as usize;
if delta >= 256 {
return false;
}
if self.test_bit(delta) {
return false;
}
self.set_bit(delta);
true
}
}
#[derive(Debug, Clone, Copy, Default)]
struct EpochRxSlot {
epoch: u32,
window: ReplayWindow256,
last_rx_ms: u64,
valid: bool,
}
impl EpochRxSlot {
fn clear(&mut self) {
self.epoch = 0;
self.window.clear();
self.last_rx_ms = 0;
self.valid = false;
}
}
pub struct PeerSession {
peer_id: PeerId,
root_key: RwLock<[u8; 32]>,
session_generation: AtomicU32,
send_epoch: AtomicU32,
send_seq: [AtomicU64; 2],
send_epoch_started_ms: AtomicU64,
send_packets_since_epoch: AtomicU64,
rx_slots: Mutex<[[EpochRxSlot; 2]; 2]>,
key_cache: Mutex<[[EpochKeySlot; 2]; 2]>,
send_cipher_algorithm: String,
recv_cipher_algorithm: String,
}
impl std::fmt::Debug for PeerSession {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PeerSession")
.field("peer_id", &self.peer_id)
.field("root_key", &self.root_key)
.field("session_generation", &self.session_generation)
.field("send_epoch", &self.send_epoch)
.field("send_seq", &self.send_seq)
.field("send_epoch_started_ms", &self.send_epoch_started_ms)
.field("send_packets_since_epoch", &self.send_packets_since_epoch)
.field("rx_slots", &self.rx_slots)
.field("key_cache", &self.key_cache)
.field("send_cipher_algorithm", &self.send_cipher_algorithm)
.field("recv_cipher_algorithm", &self.recv_cipher_algorithm)
.finish()
}
}
impl PeerSession {
/// Idle-eviction timeout for receive slots, in milliseconds.
///
/// If no packets are received for this period (~30 seconds), the
/// corresponding RX slot is considered idle and may be cleared/reused.
/// This helps reclaim state for dead peers or paths while still tolerating
/// short network stalls. Environments with very bursty or high-latency
/// traffic may want to increase this value; low-latency or tightly
/// resource-constrained deployments may lower it.
const EVICT_IDLE_AFTER_MS: u64 = 30_000;
/// Maximum number of packets to send in a single epoch before forcing
/// a key/epoch rotation.
///
/// This bounds the amount of traffic protected under a single set of
/// derived keys, which is a common best practice for long-lived secure
/// channels. The current value (~1 million packets) is a conservative
/// default chosen to balance security (more frequent rotation) and
/// performance (avoiding excessive rekeying). Deployments with very high
/// or very low packet rates may tune this threshold accordingly.
const ROTATE_AFTER_PACKETS: u64 = 1_000_000;
/// Maximum wall-clock lifetime of a send epoch, in milliseconds.
///
/// Even if the packet-based limit is not reached, epochs are rotated
/// after this duration (~10 minutes) to avoid long-lived keys and keep
/// replay windows bounded in time. This also limits the impact of a
/// compromised key. Installations that prioritize lower overhead over
/// more aggressive key rotation may increase this value; those with
/// stricter security requirements may decrease it.
const ROTATE_AFTER_MS: u64 = 10 * 60 * 1000;
const MAX_ACCEPTED_RX_EPOCH_AHEAD: u32 = 3;
pub fn new(
peer_id: PeerId,
root_key: [u8; 32],
session_generation: u32,
initial_epoch: u32,
send_cipher_algorithm: String,
recv_cipher_algorithm: String,
) -> Self {
// let mut root_key_128 = [0u8; 16];
// root_key_128.copy_from_slice(&root_key[..16]);
// let send_cipher = create_encryptor(&send_algorithm, root_key_128, root_key);
// let recv_cipher = create_encryptor(&recv_algorithm, root_key_128, root_key);
let rx_slots = [
[EpochRxSlot::default(), EpochRxSlot::default()],
[EpochRxSlot::default(), EpochRxSlot::default()],
];
let key_cache = [
[EpochKeySlot::default(), EpochKeySlot::default()],
[EpochKeySlot::default(), EpochKeySlot::default()],
];
let now_ms = now_ms();
Self {
peer_id,
root_key: RwLock::new(root_key),
session_generation: AtomicU32::new(session_generation),
send_epoch: AtomicU32::new(initial_epoch),
send_seq: [AtomicU64::new(0), AtomicU64::new(0)],
send_epoch_started_ms: AtomicU64::new(now_ms),
send_packets_since_epoch: AtomicU64::new(0),
rx_slots: Mutex::new(rx_slots),
key_cache: Mutex::new(key_cache),
send_cipher_algorithm,
recv_cipher_algorithm,
}
}
pub fn peer_id(&self) -> PeerId {
self.peer_id
}
pub fn session_generation(&self) -> u32 {
self.session_generation.load(Ordering::Relaxed)
}
pub fn root_key(&self) -> [u8; 32] {
*self.root_key.read().unwrap()
}
pub fn new_root_key() -> [u8; 32] {
let mut out = [0u8; 32];
rand::rngs::OsRng.fill_bytes(&mut out);
out
}
pub fn next_sync_epoch(&self) -> u32 {
let send_epoch = self.send_epoch.load(Ordering::Relaxed);
let rx = self.rx_slots.lock().unwrap();
let mut max_epoch = send_epoch;
for dir in 0..2 {
let cur = rx[dir][0];
if cur.valid {
max_epoch = max_epoch.max(cur.epoch);
}
let prev = rx[dir][1];
if prev.valid {
max_epoch = max_epoch.max(prev.epoch);
}
}
max_epoch.wrapping_add(1)
}
pub fn check_encrypt_algo_same(
&self,
send_algorithm: &str,
recv_algorithm: &str,
) -> Result<(), anyhow::Error> {
if self.send_cipher_algorithm != send_algorithm
|| self.recv_cipher_algorithm != recv_algorithm
{
return Err(anyhow!("encrypt algorithm not same"));
}
Ok(())
}
pub fn sync_root_key(&self, root_key: [u8; 32], session_generation: u32, initial_epoch: u32) {
{
let mut g = self.root_key.write().unwrap();
*g = root_key;
}
self.session_generation
.store(session_generation, Ordering::Relaxed);
self.send_epoch.store(initial_epoch, Ordering::Relaxed);
self.send_seq[0].store(0, Ordering::Relaxed);
self.send_seq[1].store(0, Ordering::Relaxed);
self.send_epoch_started_ms
.store(now_ms(), Ordering::Relaxed);
self.send_packets_since_epoch.store(0, Ordering::Relaxed);
{
let mut rx = self.rx_slots.lock().unwrap();
for dir in 0..2 {
rx[dir][0] = EpochRxSlot {
epoch: initial_epoch,
window: ReplayWindow256::default(),
last_rx_ms: 0,
valid: true,
};
rx[dir][1].clear();
}
}
self.key_cache
.lock()
.unwrap()
.fill([EpochKeySlot::default(), EpochKeySlot::default()]);
}
pub fn dir_for_sender(sender_peer_id: PeerId, receiver_peer_id: PeerId) -> usize {
if sender_peer_id < receiver_peer_id {
0
} else {
1
}
}
fn hkdf_traffic_key(&self, epoch: u32, dir: usize) -> [u8; 32] {
let root_key = self.root_key();
let salt = [0u8; 32];
let mut extract = HmacSha256::new_from_slice(&salt).unwrap();
extract.update(&root_key);
let prk = extract.finalize().into_bytes();
let mut info = Vec::with_capacity(9 + 4 + 1);
info.extend_from_slice(b"et-traffic");
info.extend_from_slice(&epoch.to_be_bytes());
info.push(dir as u8);
let mut expand = HmacSha256::new_from_slice(&prk).unwrap();
expand.update(&info);
expand.update(&[1u8]);
let okm = expand.finalize().into_bytes();
let mut key = [0u8; 32];
key.copy_from_slice(&okm[..32]);
key
}
fn get_encryptor(&self, epoch: u32, dir: usize, is_send: bool) -> Option<Arc<dyn Encryptor>> {
let generation = self.session_generation();
let rx = self.rx_slots.lock().unwrap();
let send_epoch = self.send_epoch.load(Ordering::Relaxed);
let allowed = epoch == send_epoch
|| rx[dir][0].valid && rx[dir][0].epoch == epoch
|| rx[dir][1].valid && rx[dir][1].epoch == epoch;
if !allowed {
return None;
}
let mut guard = self.key_cache.lock().unwrap();
for slot in guard[dir].iter_mut() {
if slot.valid && slot.epoch == epoch && slot.generation == generation {
return Some(slot.get_encryptor(is_send));
}
}
let key = self.hkdf_traffic_key(epoch, dir);
let mut key_128 = [0u8; 16];
key_128.copy_from_slice(&key[..16]);
let slot = EpochKeySlot {
epoch,
generation,
valid: true,
send_cipher: Some(create_encryptor(&self.send_cipher_algorithm, key_128, key)),
recv_cipher: Some(create_encryptor(&self.recv_cipher_algorithm, key_128, key)),
};
let ret = slot.get_encryptor(is_send);
if !guard[dir][0].valid || guard[dir][0].epoch == epoch {
guard[dir][0] = slot;
} else {
guard[dir][1] = slot;
}
Some(ret)
}
fn maybe_rotate_epoch(&self, now_ms: u64) {
let packets = self
.send_packets_since_epoch
.fetch_add(1, Ordering::Relaxed)
+ 1;
let started = self.send_epoch_started_ms.load(Ordering::Relaxed);
if packets < Self::ROTATE_AFTER_PACKETS
&& now_ms.saturating_sub(started) < Self::ROTATE_AFTER_MS
{
return;
}
let cur = self.send_epoch.load(Ordering::Relaxed);
let next = cur.wrapping_add(1);
if self
.send_epoch
.compare_exchange(cur, next, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
self.send_epoch_started_ms.store(now_ms, Ordering::Relaxed);
self.send_packets_since_epoch.store(0, Ordering::Relaxed);
}
}
fn next_nonce(&self, dir: usize) -> (u32, u64, [u8; 12]) {
let now_ms = now_ms();
self.maybe_rotate_epoch(now_ms);
let epoch = self.send_epoch.load(Ordering::Relaxed);
let seq = self.send_seq[dir].fetch_add(1, Ordering::Relaxed);
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&epoch.to_be_bytes());
nonce[4..].copy_from_slice(&seq.to_be_bytes());
(epoch, seq, nonce)
}
fn parse_tail(payload: &[u8]) -> Option<[u8; 12]> {
if payload.len() < std::mem::size_of::<AesGcmTail>() {
return None;
}
let tail_off = payload.len() - std::mem::size_of::<AesGcmTail>();
let tail = &payload[tail_off..];
let mut nonce = [0u8; 12];
nonce.copy_from_slice(&tail[16..]);
Some(nonce)
}
fn evict_old_rx_slots(rx: &mut [[EpochRxSlot; 2]; 2], now_ms: u64) {
for dir_slots in rx.iter_mut() {
for slot in dir_slots.iter_mut() {
if !slot.valid {
continue;
}
let last = slot.last_rx_ms;
if last != 0 && now_ms.saturating_sub(last) > Self::EVICT_IDLE_AFTER_MS {
slot.clear();
}
}
}
}
fn check_replay(&self, epoch: u32, seq: u64, dir: usize, now_ms: u64) -> bool {
let mut rx = self.rx_slots.lock().unwrap();
Self::evict_old_rx_slots(&mut rx, now_ms);
let send_epoch = self.send_epoch.load(Ordering::Relaxed);
{
let mut key_cache = self.key_cache.lock().unwrap();
for d in 0..2 {
for s in 0..2 {
if !key_cache[d][s].valid {
continue;
}
let e = key_cache[d][s].epoch;
let allowed = e == send_epoch
|| rx[d][0].valid && rx[d][0].epoch == e
|| rx[d][1].valid && rx[d][1].epoch == e;
if !allowed {
key_cache[d][s].valid = false;
}
}
}
}
if !rx[dir][0].valid {
rx[dir][0] = EpochRxSlot {
epoch,
window: ReplayWindow256::default(),
last_rx_ms: now_ms,
valid: true,
};
}
if rx[dir][0].valid && epoch == rx[dir][0].epoch {
rx[dir][0].last_rx_ms = now_ms;
return rx[dir][0].window.accept(seq);
}
if rx[dir][1].valid && epoch == rx[dir][1].epoch {
rx[dir][1].last_rx_ms = now_ms;
return rx[dir][1].window.accept(seq);
}
if rx[dir][0].valid && epoch > rx[dir][0].epoch {
let mut baseline_epoch = send_epoch;
if rx[dir][0].valid {
baseline_epoch = baseline_epoch.max(rx[dir][0].epoch);
}
if rx[dir][1].valid {
baseline_epoch = baseline_epoch.max(rx[dir][1].epoch);
}
let max_allowed_epoch =
baseline_epoch.saturating_add(Self::MAX_ACCEPTED_RX_EPOCH_AHEAD);
if epoch > max_allowed_epoch {
return false;
}
rx[dir][1] = rx[dir][0];
rx[dir][0] = EpochRxSlot {
epoch,
window: ReplayWindow256::default(),
last_rx_ms: now_ms,
valid: true,
};
return rx[dir][0].window.accept(seq);
}
false
}
pub fn encrypt_payload(
&self,
sender_peer_id: PeerId,
receiver_peer_id: PeerId,
pkt: &mut ZCPacket,
) -> Result<(), anyhow::Error> {
let dir = Self::dir_for_sender(sender_peer_id, receiver_peer_id);
let (epoch, _seq, nonce_bytes) = self.next_nonce(dir);
let encryptor = self
.get_encryptor(epoch, dir, true)
.ok_or_else(|| anyhow!("no key for epoch"))?;
let _ = encryptor.encrypt_with_nonce(pkt, Some(nonce_bytes.as_slice()));
Ok(())
}
pub fn decrypt_payload(
&self,
sender_peer_id: PeerId,
receiver_peer_id: PeerId,
ciphertext_with_tail: &mut ZCPacket,
) -> Result<(), anyhow::Error> {
let dir = Self::dir_for_sender(sender_peer_id, receiver_peer_id);
let nonce_bytes =
Self::parse_tail(ciphertext_with_tail.payload()).ok_or_else(|| anyhow!("no tail"))?;
let epoch = u32::from_be_bytes(nonce_bytes[..4].try_into().unwrap());
let seq = u64::from_be_bytes(nonce_bytes[4..].try_into().unwrap());
let now_ms = now_ms();
if !self.check_replay(epoch, seq, dir, now_ms) {
return Err(anyhow!("replay rejected"));
}
let encryptor = self
.get_encryptor(epoch, dir, false)
.ok_or_else(|| anyhow!("no key for epoch"))?;
encryptor.decrypt(ciphertext_with_tail)?;
Ok(())
}
}
fn now_ms() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn peer_session_supports_asymmetric_algorithms() {
let a: PeerId = 10;
let b: PeerId = 20;
let root_key = PeerSession::new_root_key();
let generation = 1u32;
let initial_epoch = 0u32;
let sa = PeerSession::new(
b,
root_key,
generation,
initial_epoch,
"aes-256-gcm".to_string(),
"chacha20-poly1305".to_string(),
);
let sb = PeerSession::new(
a,
root_key,
generation,
initial_epoch,
"chacha20-poly1305".to_string(),
"aes-256-gcm".to_string(),
);
let plaintext1 = b"hello from a";
let mut pkt1 = ZCPacket::new_with_payload(plaintext1);
pkt1.fill_peer_manager_hdr(a as u32, b as u32, 0);
sa.encrypt_payload(a, b, &mut pkt1).unwrap();
sb.decrypt_payload(a, b, &mut pkt1).unwrap();
assert_eq!(pkt1.payload(), plaintext1);
let plaintext2 = b"hello from b";
let mut pkt2 = ZCPacket::new_with_payload(plaintext2);
pkt2.fill_peer_manager_hdr(b as u32, a as u32, 0);
sb.encrypt_payload(b, a, &mut pkt2).unwrap();
sa.decrypt_payload(b, a, &mut pkt2).unwrap();
assert_eq!(pkt2.payload(), plaintext2);
}
#[test]
fn replay_rejects_far_future_epoch_without_poisoning_window() {
let peer_id: PeerId = 10;
let root_key = PeerSession::new_root_key();
let generation = 1u32;
let initial_epoch = 0u32;
let s = PeerSession::new(
peer_id,
root_key,
generation,
initial_epoch,
"aes-256-gcm".to_string(),
"aes-256-gcm".to_string(),
);
let now = now_ms();
assert!(s.check_replay(0, 1, 0, now));
assert!(s.check_replay(0, 2, 0, now));
assert!(!s.check_replay(1000, 1, 0, now));
assert!(s.check_replay(1, 1, 0, now + 1));
assert!(s.check_replay(1, 2, 0, now + 2));
}
}