feat: relay peer end-to-end encryption via Noise IK handshake (#1960)

Enable encryption for non-direct nodes requiring relay forwarding.
When secure_mode is enabled, peers perform a Noise IK handshake to
establish an encrypted PeerSession. Relay packets are encrypted at
the sender and decrypted at the receiver. Intermediate forwarding
nodes cannot read plaintext data.

---------

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: KKRainbow <5665404+KKRainbow@users.noreply.github.com>
This commit is contained in:
KKRainbow
2026-03-07 14:47:22 +08:00
committed by GitHub
parent 22b4c4be2c
commit 59d4475743
14 changed files with 2081 additions and 73 deletions
+105 -8
View File
@@ -1,6 +1,6 @@
use std::{
sync::{
atomic::{AtomicU32, Ordering},
atomic::{AtomicBool, AtomicU32, Ordering},
Arc, Mutex, RwLock,
},
time::{SystemTime, UNIX_EPOCH},
@@ -70,7 +70,29 @@ impl PeerSessionStore {
}
/// Looks up the session stored under `key`, handing out only sessions that are still valid.
///
/// Returns `Some(session)` when an entry exists and its `is_valid()` check passes.
/// Returns `None` when there is no entry, or when the stored entry has been invalidated;
/// in the latter case the stale entry is also dropped from the store.
pub fn get(&self, key: &SessionKey) -> Option<Arc<PeerSession>> {
    // Clone the Arc inside the closure so the map's entry handle is released before the
    // store is touched again below.
    let session = self.sessions.get(key).map(|entry| entry.clone())?;
    if session.is_valid() {
        return Some(session);
    }
    // Re-check validity at removal time instead of removing unconditionally: another
    // thread may have replaced the stale entry with a fresh session since the lookup
    // above, and that replacement must not be evicted.
    // NOTE(review): relies on `sessions` being a DashMap (`remove_if`) — confirm.
    self.sessions.remove_if(key, |_, stored| !stored.is_valid());
    None
}
/// Drops the entry stored under `key`, if any; the removed value is discarded.
pub fn remove(&self, key: &SessionKey) {
    let _ = self.sessions.remove(key);
}
/// Stores `session` under `key`; whatever the underlying map returns from the insert is
/// discarded.
pub fn insert_session(&self, key: SessionKey, session: Arc<PeerSession>) {
    let _ = self.sessions.insert(key, session);
}
/// Drops every stored session that nothing outside the store still holds
/// (e.g. no PeerConn or RelayPeerMap keeps a handle to it).
///
/// An entry is kept only while its `Arc` strong count is above one; a count of exactly
/// one means the store's own handle is the last one left.
pub fn evict_unused_sessions(&self) {
    self.sessions
        .retain(|_, held| Arc::strong_count(held) >= 2);
}
pub fn upsert_responder_session(
@@ -79,8 +101,13 @@ impl PeerSessionStore {
a_session_generation: Option<u32>,
send_algorithm: String,
recv_algorithm: String,
peer_static_pubkey: Option<[u8; 32]>,
) -> Result<UpsertResponderSessionReturn, anyhow::Error> {
let existing = self.sessions.get(key).map(|v| v.clone());
let existing = self
.sessions
.get(key)
.map(|v| v.clone())
.filter(|s| s.is_valid());
match existing {
None => {
let root_key = PeerSession::new_root_key();
@@ -93,6 +120,7 @@ impl PeerSessionStore {
initial_epoch,
send_algorithm,
recv_algorithm,
peer_static_pubkey,
));
self.sessions.insert(key.clone(), session.clone());
Ok(UpsertResponderSessionReturn {
@@ -105,6 +133,7 @@ impl PeerSessionStore {
}
Some(session) => {
session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?;
session.check_or_set_peer_static_pubkey(peer_static_pubkey)?;
let local_gen = session.session_generation();
if a_session_generation.is_some_and(|g| g == local_gen) {
Ok(UpsertResponderSessionReturn {
@@ -139,6 +168,7 @@ impl PeerSessionStore {
initial_epoch: u32,
send_algorithm: String,
recv_algorithm: String,
peer_static_pubkey: Option<[u8; 32]>,
) -> Result<Arc<PeerSession>, anyhow::Error> {
tracing::info!(
"apply_initiator_action {:?}, send_algorithm: {}, recv_algorithm: {}",
@@ -152,6 +182,7 @@ impl PeerSessionStore {
return Err(anyhow!("no local session for JOIN"));
};
session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?;
session.check_or_set_peer_static_pubkey(peer_static_pubkey)?;
if session.session_generation() != b_session_generation {
return Err(anyhow!("JOIN generation mismatch"));
}
@@ -159,6 +190,13 @@ impl PeerSessionStore {
}
PeerSessionAction::Sync | PeerSessionAction::Create => {
let root_key = root_key_32.ok_or_else(|| anyhow!("missing root_key"))?;
// If the existing session is invalidated, remove it so we create a fresh one
if let Some(existing) = self.sessions.get(key) {
if !existing.is_valid() {
drop(existing);
self.sessions.remove(key);
}
}
let session = self
.sessions
.entry(key.clone())
@@ -170,10 +208,12 @@ impl PeerSessionStore {
initial_epoch,
send_algorithm.clone(),
recv_algorithm.clone(),
peer_static_pubkey,
))
})
.clone();
session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?;
session.check_or_set_peer_static_pubkey(peer_static_pubkey)?;
session.sync_root_key(root_key, b_session_generation, initial_epoch);
Ok(session)
}
@@ -318,6 +358,7 @@ pub struct PeerSession {
peer_id: PeerId,
root_key: RwLock<[u8; 32]>,
session_generation: AtomicU32,
peer_static_pubkey: RwLock<Option<[u8; 32]>>,
send_epoch: AtomicU32,
send_seq: [AtomicU64; 2],
@@ -329,6 +370,12 @@ pub struct PeerSession {
send_cipher_algorithm: String,
recv_cipher_algorithm: String,
/// Set to true when the session is detected as corrupted (persistent decrypt failures).
/// Holders of Arc<PeerSession> can check this to know the session should be discarded.
invalidated: AtomicBool,
/// Consecutive decrypt failure counter. Auto-invalidates when threshold is reached.
decrypt_fail_count: AtomicU32,
}
impl std::fmt::Debug for PeerSession {
@@ -337,6 +384,7 @@ impl std::fmt::Debug for PeerSession {
.field("peer_id", &self.peer_id)
.field("root_key", &self.root_key)
.field("session_generation", &self.session_generation)
.field("peer_static_pubkey", &self.peer_static_pubkey)
.field("send_epoch", &self.send_epoch)
.field("send_seq", &self.send_seq)
.field("send_epoch_started_ms", &self.send_epoch_started_ms)
@@ -381,6 +429,7 @@ impl PeerSession {
/// stricter security requirements may decrease it.
const ROTATE_AFTER_MS: u64 = 10 * 60 * 1000;
const MAX_ACCEPTED_RX_EPOCH_AHEAD: u32 = 3;
const DECRYPT_FAIL_THRESHOLD: u32 = 10;
pub fn new(
peer_id: PeerId,
@@ -389,11 +438,8 @@ impl PeerSession {
initial_epoch: u32,
send_cipher_algorithm: String,
recv_cipher_algorithm: String,
peer_static_pubkey: Option<[u8; 32]>,
) -> Self {
// let mut root_key_128 = [0u8; 16];
// root_key_128.copy_from_slice(&root_key[..16]);
// let send_cipher = create_encryptor(&send_algorithm, root_key_128, root_key);
// let recv_cipher = create_encryptor(&recv_algorithm, root_key_128, root_key);
let rx_slots = [
[EpochRxSlot::default(), EpochRxSlot::default()],
[EpochRxSlot::default(), EpochRxSlot::default()],
@@ -407,6 +453,7 @@ impl PeerSession {
peer_id,
root_key: RwLock::new(root_key),
session_generation: AtomicU32::new(session_generation),
peer_static_pubkey: RwLock::new(peer_static_pubkey),
send_epoch: AtomicU32::new(initial_epoch),
send_seq: [AtomicU64::new(0), AtomicU64::new(0)],
send_epoch_started_ms: AtomicU64::new(now_ms),
@@ -415,6 +462,8 @@ impl PeerSession {
key_cache: Mutex::new(key_cache),
send_cipher_algorithm,
recv_cipher_algorithm,
invalidated: AtomicBool::new(false),
decrypt_fail_count: AtomicU32::new(0),
}
}
@@ -422,6 +471,15 @@ impl PeerSession {
self.peer_id
}
/// Flags the session as no longer usable.
///
/// The flag lives inside the session itself, so every holder of the same
/// `Arc<PeerSession>` observes it through `is_valid()`.
pub fn invalidate(&self) {
    let flag = &self.invalidated;
    flag.store(true, Ordering::Relaxed);
}
/// Returns `true` as long as the session has not been flagged via `invalidate()`.
pub fn is_valid(&self) -> bool {
    let invalidated = self.invalidated.load(Ordering::Relaxed);
    !invalidated
}
/// Returns the current value of the session generation counter.
pub fn session_generation(&self) -> u32 {
    let generation = &self.session_generation;
    generation.load(Ordering::Relaxed)
}
@@ -466,6 +524,24 @@ impl PeerSession {
Ok(())
}
/// Pins the peer's static public key on first sight and verifies it afterwards.
///
/// * `None` input: nothing to check; returns `Ok(())` without taking the lock.
/// * No key recorded yet: records the supplied key and returns `Ok(())`.
/// * A key is already recorded: returns `Ok(())` if it equals the supplied key.
///
/// # Errors
/// Returns an error when a key is already recorded and differs from the supplied one.
///
/// # Panics
/// Panics if the lock guarding the stored key is poisoned.
pub fn check_or_set_peer_static_pubkey(
    &self,
    peer_static_pubkey: Option<[u8; 32]>,
) -> Result<(), anyhow::Error> {
    match peer_static_pubkey {
        None => Ok(()),
        Some(incoming) => {
            // Hold the write lock across both the comparison and the store so the
            // check-then-set is a single critical section.
            let mut slot = self.peer_static_pubkey.write().unwrap();
            let recorded: Option<[u8; 32]> = *slot;
            match recorded {
                Some(known) if known != incoming => {
                    Err(anyhow!("peer static pubkey mismatch"))
                }
                Some(_) => Ok(()),
                None => {
                    *slot = Some(incoming);
                    Ok(())
                }
            }
        }
    }
}
pub fn sync_root_key(&self, root_key: [u8; 32], session_generation: u32, initial_epoch: u32) {
{
let mut g = self.root_key.write().unwrap();
@@ -703,6 +779,9 @@ impl PeerSession {
receiver_peer_id: PeerId,
pkt: &mut ZCPacket,
) -> Result<(), anyhow::Error> {
if !self.is_valid() {
return Err(anyhow!("session invalidated"));
}
let dir = Self::dir_for_sender(sender_peer_id, receiver_peer_id);
let (epoch, _seq, nonce_bytes) = self.next_nonce(dir);
let encryptor = self
@@ -718,6 +797,9 @@ impl PeerSession {
receiver_peer_id: PeerId,
ciphertext_with_tail: &mut ZCPacket,
) -> Result<(), anyhow::Error> {
if !self.is_valid() {
return Err(anyhow!("session invalidated"));
}
let dir = Self::dir_for_sender(sender_peer_id, receiver_peer_id);
let nonce_bytes =
Self::parse_tail(ciphertext_with_tail.payload()).ok_or_else(|| anyhow!("no tail"))?;
@@ -732,7 +814,19 @@ impl PeerSession {
let encryptor = self
.get_encryptor(epoch, dir, false)
.ok_or_else(|| anyhow!("no key for epoch"))?;
encryptor.decrypt(ciphertext_with_tail)?;
if let Err(e) = encryptor.decrypt(ciphertext_with_tail) {
let count = self.decrypt_fail_count.fetch_add(1, Ordering::Relaxed) + 1;
if count >= Self::DECRYPT_FAIL_THRESHOLD {
self.invalidate();
tracing::warn!(
peer_id = ?self.peer_id,
count,
"session auto-invalidated after consecutive decrypt failures"
);
}
return Err(e.into());
}
self.decrypt_fail_count.store(0, Ordering::Relaxed);
Ok(())
}
@@ -764,6 +858,7 @@ mod tests {
initial_epoch,
"aes-256-gcm".to_string(),
"chacha20-poly1305".to_string(),
None,
);
let sb = PeerSession::new(
a,
@@ -772,6 +867,7 @@ mod tests {
initial_epoch,
"chacha20-poly1305".to_string(),
"aes-256-gcm".to_string(),
None,
);
let plaintext1 = b"hello from a";
@@ -802,6 +898,7 @@ mod tests {
initial_epoch,
"aes-256-gcm".to_string(),
"aes-256-gcm".to_string(),
None,
);
let now = now_ms();