support compress for rpc and tun data (#473)

* support compress for rpc and tun data
* add compression layer to easytier-web
This commit is contained in:
Sijie.Sun
2024-11-16 11:23:18 +08:00
committed by GitHub
parent 9d455e22fa
commit 6cdea38284
22 changed files with 623 additions and 82 deletions
+191
View File
@@ -0,0 +1,191 @@
use async_compression::tokio::write::{ZstdDecoder, ZstdEncoder};
use tokio::io::AsyncWriteExt;
use zerocopy::{AsBytes as _, FromBytes as _};
use crate::tunnel::packet_def::{CompressorAlgo, CompressorTail, ZCPacket, COMPRESSOR_TAIL_SIZE};
type Error = anyhow::Error;
#[async_trait::async_trait]
pub trait Compressor {
async fn compress(
&self,
packet: &mut ZCPacket,
compress_algo: CompressorAlgo,
) -> Result<(), Error>;
async fn decompress(&self, packet: &mut ZCPacket) -> Result<(), Error>;
}
pub struct DefaultCompressor {}
impl DefaultCompressor {
pub fn new() -> Self {
DefaultCompressor {}
}
pub async fn compress_raw(
&self,
data: &[u8],
compress_algo: CompressorAlgo,
) -> Result<Vec<u8>, Error> {
let buf = match compress_algo {
CompressorAlgo::ZstdDefault => {
let mut o = ZstdEncoder::new(Vec::new());
o.write_all(data).await?;
o.shutdown().await?;
o.into_inner()
}
CompressorAlgo::None => data.to_vec(),
};
Ok(buf)
}
pub async fn decompress_raw(
&self,
data: &[u8],
compress_algo: CompressorAlgo,
) -> Result<Vec<u8>, Error> {
let buf = match compress_algo {
CompressorAlgo::ZstdDefault => {
let mut o = ZstdDecoder::new(Vec::new());
o.write_all(data).await?;
o.shutdown().await?;
o.into_inner()
}
CompressorAlgo::None => data.to_vec(),
};
Ok(buf)
}
}
#[async_trait::async_trait]
impl Compressor for DefaultCompressor {
async fn compress(
&self,
zc_packet: &mut ZCPacket,
compress_algo: CompressorAlgo,
) -> Result<(), Error> {
if matches!(compress_algo, CompressorAlgo::None) {
return Ok(());
}
let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_compressed() {
return Ok(());
}
let tail = CompressorTail::new(compress_algo);
let buf = self
.compress_raw(zc_packet.payload(), compress_algo)
.await?;
if buf.len() + COMPRESSOR_TAIL_SIZE > pm_header.len.get() as usize {
// Compressed data is larger than original data, don't compress
return Ok(());
}
zc_packet
.mut_peer_manager_header()
.unwrap()
.set_compressed(true);
let payload_offset = zc_packet.payload_offset();
zc_packet.mut_inner().truncate(payload_offset);
zc_packet.mut_inner().extend_from_slice(&buf);
zc_packet.mut_inner().extend_from_slice(tail.as_bytes());
Ok(())
}
async fn decompress(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if !pm_header.is_compressed() {
return Ok(());
}
let payload_len = zc_packet.payload().len();
if payload_len < COMPRESSOR_TAIL_SIZE {
return Err(anyhow::anyhow!("Packet too short: {}", payload_len));
}
let text_len = payload_len - COMPRESSOR_TAIL_SIZE;
let tail = CompressorTail::ref_from_suffix(zc_packet.payload())
.unwrap()
.clone();
let algo = tail
.get_algo()
.ok_or(anyhow::anyhow!("Unknown algo: {:?}", tail))?;
let buf = self
.decompress_raw(&zc_packet.payload()[..text_len], algo)
.await?;
if buf.len() != pm_header.len.get() as usize {
anyhow::bail!(
"Decompressed length mismatch: decompressed len {} != pm header len {}",
buf.len(),
pm_header.len.get()
);
}
zc_packet
.mut_peer_manager_header()
.unwrap()
.set_compressed(false);
let payload_offset = zc_packet.payload_offset();
zc_packet.mut_inner().truncate(payload_offset);
zc_packet.mut_inner().extend_from_slice(&buf);
Ok(())
}
}
#[cfg(test)]
pub mod tests {
use super::*;
#[tokio::test]
async fn test_compress() {
let text = b"12345670000000000000000000";
let mut packet = ZCPacket::new_with_payload(text);
packet.fill_peer_manager_hdr(0, 0, 0);
let compressor = DefaultCompressor {};
compressor
.compress(&mut packet, CompressorAlgo::ZstdDefault)
.await
.unwrap();
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), true);
compressor.decompress(&mut packet).await.unwrap();
assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false);
}
#[tokio::test]
async fn test_short_text_compress() {
let text = b"1234";
let mut packet = ZCPacket::new_with_payload(text);
packet.fill_peer_manager_hdr(0, 0, 0);
let compressor = DefaultCompressor {};
// short text can't be compressed
compressor
.compress(&mut packet, CompressorAlgo::ZstdDefault)
.await
.unwrap();
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false);
compressor.decompress(&mut packet).await.unwrap();
assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false);
}
}
+2 -1
View File
@@ -7,7 +7,7 @@ use std::{
use anyhow::Context;
use serde::{Deserialize, Serialize};
use crate::tunnel::generate_digest_from_str;
use crate::{proto::common::CompressionAlgoPb, tunnel::generate_digest_from_str};
pub type Flags = crate::proto::common::FlagsInConfig;
@@ -28,6 +28,7 @@ pub fn gen_default_flags() -> Flags {
disable_udp_hole_punching: false,
ipv6_listener: "udp://[::]:0".to_string(),
multi_thread: false,
data_compress_algo: CompressionAlgoPb::None.into(),
}
}
+1
View File
@@ -6,6 +6,7 @@ use std::{
use tokio::task::JoinSet;
use tracing::Instrument;
pub mod compressor;
pub mod config;
pub mod constants;
pub mod defer;