Files
Easytier/easytier/src/peers/encrypt/openssl.rs
T

202 lines
6.0 KiB
Rust

use crate::tunnel::packet_def::{StandardAeadTail, ZCPacket};
use openssl::symm::{Cipher, Crypter, Mode};
use rand::RngCore;
use zerocopy::{AsBytes, FromBytes, FromZeroes};
use crate::peers::encrypt::{Encryptor, Error};
/// Packet encryptor backed by OpenSSL authenticated (AEAD) ciphers.
///
/// The algorithm and the key it uses are held in [`OpenSslEnum`]; build
/// instances through the `new_*` constructors.
#[derive(Clone)]
pub struct OpenSslCipher {
    /// Selected algorithm together with its raw key material.
    pub(crate) cipher: OpenSslEnum,
}
/// The OpenSSL AEAD algorithms supported by [`OpenSslCipher`], each carrying its key.
///
/// The required key length is fixed by each variant's array type.
#[derive(Copy, Clone)]
pub enum OpenSslEnum {
    /// AES-128 in GCM mode, 16-byte key.
    Aes128Gcm([u8; 16]),
    /// AES-256 in GCM mode, 32-byte key.
    Aes256Gcm([u8; 32]),
    /// ChaCha20 with a 32-byte key. Although the variant name omits it, the
    /// cipher selected for this variant is the authenticated ChaCha20-Poly1305.
    ChaCha20([u8; 32]),
}
impl OpenSslCipher {
    /// Creates an AES-128-GCM encryptor from a 16-byte key.
    pub fn new_aes128_gcm(key: [u8; 16]) -> Self {
        let cipher = OpenSslEnum::Aes128Gcm(key);
        Self { cipher }
    }

    /// Creates an AES-256-GCM encryptor from a 32-byte key.
    pub fn new_aes256_gcm(key: [u8; 32]) -> Self {
        let cipher = OpenSslEnum::Aes256Gcm(key);
        Self { cipher }
    }

    /// Creates a ChaCha20-Poly1305 encryptor from a 32-byte key.
    pub fn new_chacha20(key: [u8; 32]) -> Self {
        let cipher = OpenSslEnum::ChaCha20(key);
        Self { cipher }
    }

    /// Maps the configured variant to its OpenSSL [`Cipher`] and borrows the key bytes.
    fn get_cipher_and_key(&self) -> (Cipher, &[u8]) {
        // Bind by `ref` so the returned slice borrows from `self`, not from a copy.
        match self.cipher {
            OpenSslEnum::Aes128Gcm(ref k) => (Cipher::aes_128_gcm(), &k[..]),
            OpenSslEnum::Aes256Gcm(ref k) => (Cipher::aes_256_gcm(), &k[..]),
            OpenSslEnum::ChaCha20(ref k) => (Cipher::chacha20_poly1305(), &k[..]),
        }
    }
}
impl Encryptor for OpenSslCipher {
    /// Decrypts `zc_packet` in place when its peer-manager header is flagged as encrypted.
    ///
    /// The payload must be laid out as `ciphertext || StandardAeadTail`, where the tail
    /// holds the nonce and authentication tag written by `encrypt_with_nonce`. On success
    /// the ciphertext is replaced by plaintext, the tail is truncated off and the
    /// encrypted flag is cleared. Packets without the flag are returned untouched.
    ///
    /// # Errors
    /// - [`Error::PacketTooShort`] if the payload is shorter than the AEAD tail.
    /// - [`Error::DecryptionFailed`] if OpenSSL rejects the parameters or the tag does
    ///   not verify. In both error cases the packet is left unmodified.
    ///
    /// # Panics
    /// Panics if the packet has no peer-manager header.
    fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
        let pm_header = zc_packet.peer_manager_header().unwrap();
        if !pm_header.is_encrypted() {
            return Ok(());
        }

        let payload = zc_packet.payload();
        let len = payload.len();
        if len < StandardAeadTail::SIZE {
            return Err(Error::PacketTooShort(len));
        }

        let (cipher, key) = self.get_cipher_and_key();
        // The nonce/IV and the tag live in a fixed-size tail at the end of the payload.
        let tail = StandardAeadTail::ref_from_suffix(payload).unwrap();
        let mut decrypter = Crypter::new(cipher, Mode::Decrypt, key, Some(&tail.nonce))
            .map_err(|_| Error::DecryptionFailed)?;
        decrypter
            .set_tag(&tail.tag)
            .map_err(|_| Error::DecryptionFailed)?;

        let text_len = len - StandardAeadTail::SIZE;
        // `Crypter::update` requires the output to have room for one extra block.
        let mut output = vec![0u8; text_len + cipher.block_size()];
        let mut count = decrypter
            .update(&payload[..text_len], &mut output)
            .map_err(|_| Error::DecryptionFailed)?;
        // `finalize` verifies the tag. Nothing has been written to the packet yet,
        // so a forged or corrupted packet is left exactly as it arrived.
        count += decrypter
            .finalize(&mut output[count..])
            .map_err(|_| Error::DecryptionFailed)?;

        // Commit: write the plaintext back, clear the flag and drop the tail bytes.
        zc_packet.mut_payload()[..count].copy_from_slice(&output[..count]);
        let pm_header = zc_packet.mut_peer_manager_header().unwrap();
        pm_header.set_encrypted(false);
        let new_buf_len = zc_packet.buf_len() - (len - count);
        zc_packet.mut_inner().truncate(new_buf_len);
        Ok(())
    }

    /// Encrypts `zc_packet` in place using a freshly generated random nonce.
    ///
    /// See `encrypt_with_nonce` for the resulting layout and error conditions.
    fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
        self.encrypt_with_nonce(zc_packet, None)
    }

    /// Encrypts `zc_packet` in place, producing `ciphertext || StandardAeadTail`.
    ///
    /// When `nonce` is `Some`, it must be exactly `StandardAeadTail::NONCE_SIZE` bytes.
    /// When it is `None`, a random nonce is drawn from `rand::thread_rng()`. Packets
    /// already flagged as encrypted are logged and returned unchanged.
    ///
    /// NOTE(review): no additional authenticated data is fed to the cipher, so the
    /// peer-manager header itself is not covered by the tag.
    ///
    /// # Errors
    /// Returns [`Error::EncryptionFailed`] for a wrong-length nonce, any OpenSSL
    /// failure, or an unexpected ciphertext length. Every fallible step runs before
    /// the packet is touched, so on error the packet is left unmodified.
    ///
    /// # Panics
    /// Panics if the packet has no peer-manager header.
    fn encrypt_with_nonce(
        &self,
        zc_packet: &mut ZCPacket,
        nonce: Option<&[u8]>,
    ) -> Result<(), Error> {
        let pm_header = zc_packet.peer_manager_header().unwrap();
        if pm_header.is_encrypted() {
            tracing::warn!(?zc_packet, "packet is already encrypted");
            return Ok(());
        }

        let (cipher, key) = self.get_cipher_and_key();
        let mut tail = StandardAeadTail::new_zeroed();
        if let Some(nonce) = nonce {
            if nonce.len() != StandardAeadTail::NONCE_SIZE {
                return Err(Error::EncryptionFailed);
            }
            tail.nonce.copy_from_slice(nonce);
        } else {
            rand::thread_rng().fill_bytes(&mut tail.nonce);
        }

        let mut encrypter = Crypter::new(cipher, Mode::Encrypt, key, Some(&tail.nonce))
            .map_err(|_| Error::EncryptionFailed)?;
        let payload = zc_packet.payload();
        let payload_len = payload.len();
        // `Crypter::update` requires the output to have room for one extra block.
        let mut output = vec![0u8; payload_len + cipher.block_size()];
        let mut count = encrypter
            .update(payload, &mut output)
            .map_err(|_| Error::EncryptionFailed)?;
        count += encrypter
            .finalize(&mut output[count..])
            .map_err(|_| Error::EncryptionFailed)?;
        // GCM and ChaCha20-Poly1305 are stream modes, so the ciphertext must be exactly
        // as long as the plaintext. A shorter result would leave trailing plaintext
        // bytes in an outgoing packet, so treat any mismatch as a failure.
        if count != payload_len {
            return Err(Error::EncryptionFailed);
        }
        // Fetch the tag before mutating the packet. A failure here can then no longer
        // leave a packet that holds ciphertext but has no tail and no encrypted flag.
        encrypter
            .get_tag(&mut tail.tag)
            .map_err(|_| Error::EncryptionFailed)?;

        // Commit: replace the plaintext with ciphertext and append the nonce/tag tail.
        zc_packet.mut_payload()[..count].copy_from_slice(&output[..count]);
        zc_packet.mut_inner().extend_from_slice(tail.as_bytes());
        let pm_header = zc_packet.mut_peer_manager_header().unwrap();
        pm_header.set_encrypted(true);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const TEXT: &[u8] = b"Hello, World! This is a standardized test message.";

    /// Builds a packet carrying `payload` with a zeroed peer-manager header.
    fn new_packet(payload: &[u8]) -> ZCPacket {
        let mut packet = ZCPacket::new_with_payload(payload);
        packet.fill_peer_manager_hdr(0, 0, 0);
        packet
    }

    /// One instance of every supported algorithm, each with a fixed test key.
    fn all_ciphers() -> Vec<OpenSslCipher> {
        vec![
            OpenSslCipher::new_aes128_gcm([1u8; 16]),
            OpenSslCipher::new_aes256_gcm([2u8; 32]),
            OpenSslCipher::new_chacha20([3u8; 32]),
        ]
    }

    /// Round-trips with a caller-chosen nonce and checks the exact wire layout.
    fn run_cipher_test_with_nonce(cipher: &OpenSslCipher) {
        let nonce: [u8; 12] = [101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112];
        let mut packet = new_packet(TEXT);
        cipher
            .encrypt_with_nonce(&mut packet, Some(&nonce))
            .unwrap();

        let payload = packet.payload();
        // Stream AEAD modes: ciphertext is as long as the plaintext, followed by the tail.
        assert_eq!(payload.len(), TEXT.len() + StandardAeadTail::SIZE);
        assert!(packet.peer_manager_header().unwrap().is_encrypted());
        let tail = StandardAeadTail::ref_from_suffix(payload).unwrap();
        assert_eq!(tail.nonce, nonce);

        cipher.decrypt(&mut packet).unwrap();
        assert_eq!(packet.payload(), TEXT);
        assert!(!packet.peer_manager_header().unwrap().is_encrypted());
    }

    /// Round-trips through the random-nonce `encrypt` path.
    fn run_cipher_test_random_nonce(cipher: &OpenSslCipher) {
        let mut packet = new_packet(TEXT);
        cipher.encrypt(&mut packet).unwrap();
        assert!(packet.peer_manager_header().unwrap().is_encrypted());
        assert_ne!(&packet.payload()[..TEXT.len()], TEXT);

        cipher.decrypt(&mut packet).unwrap();
        assert_eq!(packet.payload(), TEXT);
        assert!(!packet.peer_manager_header().unwrap().is_encrypted());
    }

    /// Flipping a ciphertext byte or a tail byte must make authentication fail.
    fn run_tamper_test(cipher: &OpenSslCipher) {
        let mut packet = new_packet(TEXT);
        cipher.encrypt(&mut packet).unwrap();
        packet.mut_payload()[0] ^= 0x01;
        assert!(cipher.decrypt(&mut packet).is_err());
        // A rejected packet must stay flagged as encrypted.
        assert!(packet.peer_manager_header().unwrap().is_encrypted());

        let mut packet = new_packet(TEXT);
        cipher.encrypt(&mut packet).unwrap();
        let last = packet.payload().len() - 1;
        packet.mut_payload()[last] ^= 0x01;
        assert!(cipher.decrypt(&mut packet).is_err());
        assert!(packet.peer_manager_header().unwrap().is_encrypted());
    }

    fn run_all(cipher: OpenSslCipher) {
        run_cipher_test_with_nonce(&cipher);
        run_cipher_test_random_nonce(&cipher);
        run_tamper_test(&cipher);
    }

    #[test]
    fn test_openssl_aes128_gcm() {
        run_all(OpenSslCipher::new_aes128_gcm([1u8; 16]));
    }

    #[test]
    fn test_openssl_aes256_gcm() {
        run_all(OpenSslCipher::new_aes256_gcm([2u8; 32]));
    }

    #[test]
    fn test_openssl_chacha20() {
        run_all(OpenSslCipher::new_chacha20([3u8; 32]));
    }

    #[test]
    fn test_openssl_wrong_key_rejected() {
        let sender = OpenSslCipher::new_aes256_gcm([2u8; 32]);
        let receiver = OpenSslCipher::new_aes256_gcm([4u8; 32]);
        let mut packet = new_packet(TEXT);
        sender.encrypt(&mut packet).unwrap();
        assert!(receiver.decrypt(&mut packet).is_err());
        assert!(packet.peer_manager_header().unwrap().is_encrypted());
    }

    #[test]
    fn test_openssl_packet_too_short() {
        for cipher in all_ciphers() {
            let mut packet = new_packet(&[1, 2, 3]);
            let pm_header = packet.mut_peer_manager_header().unwrap();
            pm_header.set_encrypted(true);
            assert!(matches!(
                cipher.decrypt(&mut packet),
                Err(Error::PacketTooShort(3))
            ));
        }
    }

    #[test]
    fn test_openssl_bad_nonce_length() {
        let short_nonce = [0u8; 8];
        for cipher in all_ciphers() {
            let mut packet = new_packet(TEXT);
            assert!(cipher
                .encrypt_with_nonce(&mut packet, Some(&short_nonce[..]))
                .is_err());
            // The packet must be left untouched on failure.
            assert!(!packet.peer_manager_header().unwrap().is_encrypted());
            assert_eq!(packet.payload(), TEXT);
        }
    }
}