support encryption (#60)

This commit is contained in:
Sijie.Sun
2024-04-27 13:44:59 +08:00
committed by GitHub
parent 69651ae3fd
commit fcc73159b3
23 changed files with 489 additions and 81 deletions
+34
View File
@@ -0,0 +1,34 @@
use crate::tunnel::packet_def::ZCPacket;
pub mod ring_aes_gcm;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("packet is not encrypted")]
NotEcrypted,
#[error("packet is too short. len: {0}")]
PacketTooShort(usize),
#[error("decryption failed")]
DecryptionFailed,
#[error("encryption failed")]
EncryptionFailed,
#[error("invalid tag. tag: {0:?}")]
InvalidTag(Vec<u8>),
}
pub trait Encryptor: Send + Sync + 'static {
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error>;
fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error>;
}
pub struct NullCipher;
impl Encryptor for NullCipher {
fn encrypt(&self, _zc_packet: &mut ZCPacket) -> Result<(), Error> {
Ok(())
}
fn decrypt(&self, _zc_packet: &mut ZCPacket) -> Result<(), Error> {
Ok(())
}
}
+161
View File
@@ -0,0 +1,161 @@
use rand::RngCore;
use ring::aead::{self};
use ring::aead::{LessSafeKey, UnboundKey};
use zerocopy::{AsBytes, FromBytes};
use crate::tunnel::packet_def::{AesGcmTail, ZCPacket, AES_GCM_ENCRYPTION_RESERVED};
use super::{Encryptor, Error};
#[derive(Clone)]
pub struct AesGcmCipher {
pub(crate) cipher: AesGcmEnum,
}
pub enum AesGcmEnum {
AesGCM128(LessSafeKey, [u8; 16]),
AesGCM256(LessSafeKey, [u8; 32]),
}
impl Clone for AesGcmEnum {
fn clone(&self) -> Self {
match &self {
AesGcmEnum::AesGCM128(_, key) => {
let c =
LessSafeKey::new(UnboundKey::new(&aead::AES_128_GCM, key.as_slice()).unwrap());
AesGcmEnum::AesGCM128(c, *key)
}
AesGcmEnum::AesGCM256(_, key) => {
let c =
LessSafeKey::new(UnboundKey::new(&aead::AES_256_GCM, key.as_slice()).unwrap());
AesGcmEnum::AesGCM256(c, *key)
}
}
}
}
impl AesGcmCipher {
pub fn new_128(key: [u8; 16]) -> Self {
let cipher = LessSafeKey::new(UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
Self {
cipher: AesGcmEnum::AesGCM128(cipher, key),
}
}
pub fn new_256(key: [u8; 32]) -> Self {
let cipher = LessSafeKey::new(UnboundKey::new(&aead::AES_256_GCM, &key).unwrap());
Self {
cipher: AesGcmEnum::AesGCM256(cipher, key),
}
}
}
impl Encryptor for AesGcmCipher {
fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if !pm_header.is_encrypted() {
return Err(Error::NotEcrypted);
}
let payload_len = zc_packet.payload().len();
if payload_len < AES_GCM_ENCRYPTION_RESERVED {
return Err(Error::PacketTooShort(zc_packet.payload().len()));
}
let text_and_tag_len = payload_len - AES_GCM_ENCRYPTION_RESERVED + 16;
let aes_tail = AesGcmTail::ref_from_suffix(zc_packet.payload()).unwrap();
let nonce = aead::Nonce::assume_unique_for_key(aes_tail.nonce.clone());
let rs = match &self.cipher {
AesGcmEnum::AesGCM128(cipher, _) => cipher.open_in_place(
nonce,
aead::Aad::empty(),
&mut zc_packet.mut_payload()[..text_and_tag_len],
),
AesGcmEnum::AesGCM256(cipher, _) => cipher.open_in_place(
nonce,
aead::Aad::empty(),
&mut zc_packet.mut_payload()[..text_and_tag_len],
),
};
if let Err(_) = rs {
return Err(Error::DecryptionFailed);
}
let pm_header = zc_packet.mut_peer_manager_header().unwrap();
pm_header.set_encrypted(false);
let old_len = zc_packet.buf_len();
zc_packet
.mut_inner()
.truncate(old_len - AES_GCM_ENCRYPTION_RESERVED);
return Ok(());
}
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_encrypted() {
tracing::warn!(?zc_packet, "packet is already encrypted");
return Ok(());
}
let mut tail = AesGcmTail::default();
rand::thread_rng().fill_bytes(&mut tail.nonce);
let nonce = aead::Nonce::assume_unique_for_key(tail.nonce.clone());
let rs = match &self.cipher {
AesGcmEnum::AesGCM128(cipher, _) => cipher.seal_in_place_separate_tag(
nonce,
aead::Aad::empty(),
zc_packet.mut_payload(),
),
AesGcmEnum::AesGCM256(cipher, _) => cipher.seal_in_place_separate_tag(
nonce,
aead::Aad::empty(),
zc_packet.mut_payload(),
),
};
return match rs {
Ok(tag) => {
let tag = tag.as_ref();
if tag.len() != 16 {
return Err(Error::InvalidTag(tag.to_vec()));
}
tail.tag.copy_from_slice(tag);
let pm_header = zc_packet.mut_peer_manager_header().unwrap();
pm_header.set_encrypted(true);
zc_packet.mut_inner().extend_from_slice(tail.as_bytes());
Ok(())
}
Err(_) => Err(Error::EncryptionFailed),
};
}
}
#[cfg(test)]
mod tests {
use crate::{
peers::encrypt::{ring_aes_gcm::AesGcmCipher, Encryptor},
tunnel::packet_def::{ZCPacket, ZCPacketType, AES_GCM_ENCRYPTION_RESERVED},
};
#[test]
fn test_aes_gcm_cipher() {
let key = [0u8; 16];
let cipher = AesGcmCipher::new_128(key);
let text = b"1234567";
let mut packet = ZCPacket::new_with_payload(text);
packet.fill_peer_manager_hdr(0, 0, 0);
cipher.encrypt(&mut packet).unwrap();
assert_eq!(
packet.payload().len(),
text.len() + AES_GCM_ENCRYPTION_RESERVED
);
assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), true);
cipher.decrypt(&mut packet).unwrap();
assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), false);
}
}
@@ -141,6 +141,10 @@ impl ForeignNetworkClient {
self.get_next_hop(peer_id).is_some()
}
pub fn is_peer_public_node(&self, peer_id: &PeerId) -> bool {
self.peer_map.has_peer(*peer_id)
}
pub fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId> {
if self.peer_map.has_peer(peer_id) {
return Some(peer_id.clone());
+11 -6
View File
@@ -212,8 +212,13 @@ impl ForeignNetworkManager {
peer_conn.get_network_identity().network_name.clone(),
);
if entry.network.network_secret != peer_conn.get_network_identity().network_secret {
return Err(anyhow::anyhow!("network secret not match").into());
if entry.network != peer_conn.get_network_identity() {
return Err(anyhow::anyhow!(
"network secret not match. exp: {:?} real: {:?}",
entry.network,
peer_conn.get_network_identity()
)
.into());
}
Ok(entry.peer_map.add_new_peer_conn(peer_conn).await)
@@ -337,10 +342,10 @@ mod tests {
let (s, _r) = tokio::sync::mpsc::channel(1000);
let peer_mgr = Arc::new(PeerManager::new(
RouteAlgoType::Ospf,
get_mock_global_ctx_with_network(Some(NetworkIdentity {
network_name: network.to_string(),
network_secret: network.to_string(),
})),
get_mock_global_ctx_with_network(Some(NetworkIdentity::new(
network.to_string(),
network.to_string(),
))),
s,
));
replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
+2
View File
@@ -14,6 +14,8 @@ pub mod zc_peer_conn;
pub mod foreign_network_client;
pub mod foreign_network_manager;
pub mod encrypt;
#[cfg(test)]
pub mod tests;
+61 -8
View File
@@ -4,6 +4,7 @@ use std::{
sync::{Arc, Weak},
};
use anyhow::Context;
use async_trait::async_trait;
use futures::StreamExt;
@@ -31,6 +32,7 @@ use crate::{
};
use super::{
encrypt::{ring_aes_gcm::AesGcmCipher, Encryptor, NullCipher},
foreign_network_client::ForeignNetworkClient,
foreign_network_manager::ForeignNetworkManager,
peer_map::PeerMap,
@@ -49,6 +51,8 @@ struct RpcTransport {
packet_recv: Mutex<UnboundedReceiver<ZCPacket>>,
peer_rpc_tspt_sender: UnboundedSender<ZCPacket>,
encryptor: Arc<Box<dyn Encryptor>>,
}
#[async_trait::async_trait]
@@ -57,7 +61,7 @@ impl PeerRpcManagerTransport for RpcTransport {
self.my_peer_id
}
async fn send(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
async fn send(&self, mut msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
let foreign_peers = self
.foreign_peers
.lock()
@@ -75,8 +79,17 @@ impl PeerRpcManagerTransport for RpcTransport {
?self.my_peer_id,
"send msg to peer via gateway",
);
self.encryptor
.encrypt(&mut msg)
.with_context(|| "encrypt failed")?;
peers.send_msg_directly(msg, gateway_id).await
} else if foreign_peers.has_next_hop(dst_peer_id) {
if !foreign_peers.is_peer_public_node(&dst_peer_id) {
// do not encrypt for msg sending to public node
self.encryptor
.encrypt(&mut msg)
.with_context(|| "encrypt failed")?;
}
tracing::debug!(
?dst_peer_id,
?self.my_peer_id,
@@ -134,6 +147,8 @@ pub struct PeerManager {
foreign_network_manager: Arc<ForeignNetworkManager>,
foreign_network_client: Arc<ForeignNetworkClient>,
encryptor: Arc<Box<dyn Encryptor>>,
}
impl Debug for PeerManager {
@@ -161,6 +176,13 @@ impl PeerManager {
my_peer_id,
));
let encryptor: Arc<Box<dyn Encryptor>> =
Arc::new(if global_ctx.get_flags().enable_encryption {
Box::new(AesGcmCipher::new_128(global_ctx.get_128_key()))
} else {
Box::new(NullCipher)
});
// TODO: remove these because we have impl pipeline processor.
let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
let rpc_tspt = Arc::new(RpcTransport {
@@ -169,6 +191,7 @@ impl PeerManager {
foreign_peers: Mutex::new(None),
packet_recv: Mutex::new(peer_rpc_tspt_recv),
peer_rpc_tspt_sender,
encryptor: encryptor.clone(),
});
let peer_rpc_mgr = Arc::new(PeerRpcManager::new(rpc_tspt.clone()));
@@ -218,9 +241,20 @@ impl PeerManager {
foreign_network_manager,
foreign_network_client,
encryptor,
}
}
async fn add_new_peer_conn(&self, peer_conn: PeerConn) -> Result<(), Error> {
if self.global_ctx.get_network_identity() != peer_conn.get_network_identity() {
return Err(Error::SecretKeyError(
"network identity not match".to_string(),
));
}
Ok(self.peers.add_new_peer_conn(peer_conn).await)
}
pub async fn add_client_tunnel(
&self,
tunnel: Box<dyn Tunnel>,
@@ -229,8 +263,10 @@ impl PeerManager {
peer.do_handshake_as_client().await?;
let conn_id = peer.get_conn_id();
let peer_id = peer.get_peer_id();
if peer.get_network_identity() == self.global_ctx.get_network_identity() {
self.peers.add_new_peer_conn(peer).await;
if peer.get_network_identity().network_name
== self.global_ctx.get_network_identity().network_name
{
self.add_new_peer_conn(peer).await?;
} else {
self.foreign_network_client.add_new_peer_conn(peer).await;
}
@@ -254,8 +290,10 @@ impl PeerManager {
tracing::info!("add tunnel as server start");
let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
peer.do_handshake_as_server().await?;
if peer.get_network_identity() == self.global_ctx.get_network_identity() {
self.peers.add_new_peer_conn(peer).await;
if peer.get_network_identity().network_name
== self.global_ctx.get_network_identity().network_name
{
self.add_new_peer_conn(peer).await?;
} else {
self.foreign_network_manager.add_peer_conn(peer).await?;
}
@@ -268,9 +306,10 @@ impl PeerManager {
let my_peer_id = self.my_peer_id;
let peers = self.peers.clone();
let pipe_line = self.peer_packet_process_pipeline.clone();
let encryptor = self.encryptor.clone();
self.tasks.lock().await.spawn(async move {
log::trace!("start_peer_recv");
while let Some(ret) = recv.next().await {
while let Some(mut ret) = recv.next().await {
let Some(hdr) = ret.peer_manager_header() else {
tracing::warn!(?ret, "invalid packet, skip");
continue;
@@ -285,6 +324,13 @@ impl PeerManager {
tracing::error!(?ret, ?to_peer_id, ?from_peer_id, "forward packet error");
}
} else {
if let Err(e) = encryptor
.decrypt(&mut ret)
.with_context(|| "decrypt failed")
{
tracing::error!(?e, "decrypt failed");
}
let mut processed = false;
let mut zc_packet = Some(ret);
let mut idx = 0;
@@ -490,7 +536,12 @@ impl PeerManager {
return Ok(());
}
msg.fill_peer_manager_hdr(self.my_peer_id, 0, packet::PacketType::Data as u8);
self.run_nic_packet_process_pipeline(&mut msg).await;
self.encryptor
.encrypt(&mut msg)
.with_context(|| "encrypt failed")?;
let mut errs: Vec<Error> = vec![];
let mut msg = Some(msg);
@@ -503,8 +554,10 @@ impl PeerManager {
};
let peer_id = &dst_peers[i];
msg.fill_peer_manager_hdr(self.my_peer_id, *peer_id, packet::PacketType::Data as u8);
msg.mut_peer_manager_header()
.unwrap()
.to_peer_id
.set(*peer_id);
if let Some(gateway) = self.peers.get_gateway_peer_id(*peer_id).await {
if let Err(e) = self.peers.send_msg_directly(msg, gateway).await {
+25 -9
View File
@@ -24,8 +24,9 @@ use zerocopy::AsBytes;
use crate::{
common::{
config::{NetworkIdentity, NetworkSecretDigest},
error::Error,
global_ctx::{ArcGlobalCtx, NetworkIdentity},
global_ctx::ArcGlobalCtx,
PeerId,
},
peers::packet::PacketType,
@@ -129,10 +130,17 @@ impl PeerConn {
));
};
let rsp = rsp?;
let rsp = HandshakeRequest::decode(rsp.payload())
.map_err(|e| Error::WaitRespError(format!("decode handshake response error: {:?}", e)));
let rsp = HandshakeRequest::decode(rsp.payload()).map_err(|e| {
Error::WaitRespError(format!("decode handshake response error: {:?}", e))
})?;
return Ok(rsp.unwrap());
if rsp.network_secret_digrest.len() != std::mem::size_of::<NetworkSecretDigest>() {
return Err(Error::WaitRespError(
"invalid network secret digest".to_owned(),
));
}
return Ok(rsp);
}
async fn wait_handshake_loop(&mut self) -> Result<HandshakeRequest, Error> {
@@ -152,14 +160,16 @@ impl PeerConn {
async fn send_handshake(&mut self) -> Result<(), Error> {
let network = self.global_ctx.get_network_identity();
let req = HandshakeRequest {
let mut req = HandshakeRequest {
magic: MAGIC,
my_peer_id: self.my_peer_id,
version: VERSION,
features: Vec::new(),
network_name: network.network_name.clone(),
network_secret: network.network_secret.clone(),
..Default::default()
};
req.network_secret_digrest
.extend_from_slice(&network.network_secret_digest.unwrap_or_default());
let hs_req = req.encode_to_vec();
let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes());
@@ -297,10 +307,16 @@ impl PeerConn {
pub fn get_network_identity(&self) -> NetworkIdentity {
let info = self.info.as_ref().unwrap();
NetworkIdentity {
let mut ret = NetworkIdentity {
network_name: info.network_name.clone(),
network_secret: info.network_secret.clone(),
}
..Default::default()
};
ret.network_secret_digest = Some([0u8; 32]);
ret.network_secret_digest
.as_mut()
.unwrap()
.copy_from_slice(&info.network_secret_digrest);
ret
}
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {