mirror of
https://github.com/EasyTier/EasyTier.git
synced 2026-05-07 10:14:35 +00:00
@@ -2,6 +2,7 @@ use std::{
|
||||
net::{Ipv4Addr, SocketAddr},
|
||||
path::PathBuf,
|
||||
sync::{Arc, Mutex},
|
||||
u64,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
@@ -41,6 +42,7 @@ pub fn gen_default_flags() -> Flags {
|
||||
private_mode: false,
|
||||
enable_quic_proxy: false,
|
||||
disable_quic_input: false,
|
||||
foreign_relay_bps_limit: u64::MAX,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::{
|
||||
};
|
||||
|
||||
use crate::common::config::ProxyNetworkConfig;
|
||||
use crate::common::token_bucket::TokenBucketManager;
|
||||
use crate::proto::cli::PeerConnInfo;
|
||||
use crate::proto::common::{PeerFeatureFlag, PortForwardConfigPb};
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
@@ -77,6 +78,8 @@ pub struct GlobalCtx {
|
||||
feature_flags: AtomicCell<PeerFeatureFlag>,
|
||||
|
||||
quic_proxy_port: AtomicCell<Option<u16>>,
|
||||
|
||||
token_bucket_manager: TokenBucketManager,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for GlobalCtx {
|
||||
@@ -140,6 +143,8 @@ impl GlobalCtx {
|
||||
|
||||
feature_flags: AtomicCell::new(feature_flags),
|
||||
quic_proxy_port: AtomicCell::new(None),
|
||||
|
||||
token_bucket_manager: TokenBucketManager::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -292,6 +297,10 @@ impl GlobalCtx {
|
||||
pub fn set_quic_proxy_port(&self, port: Option<u16>) {
|
||||
self.quic_proxy_port.store(port);
|
||||
}
|
||||
|
||||
pub fn token_bucket_manager(&self) -> &TokenBucketManager {
|
||||
&self.token_bucket_manager
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -23,6 +23,7 @@ pub mod network;
|
||||
pub mod scoped_task;
|
||||
pub mod stun;
|
||||
pub mod stun_codec_ext;
|
||||
pub mod token_bucket;
|
||||
|
||||
pub fn get_logger_timer<F: time::formatting::Formattable>(
|
||||
format: F,
|
||||
|
||||
@@ -0,0 +1,312 @@
|
||||
use atomic_shim::AtomicU64;
|
||||
use dashmap::DashMap;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::time;
|
||||
|
||||
use crate::common::scoped_task::ScopedTask;
|
||||
use crate::proto::common::LimiterConfig;
|
||||
|
||||
/// Token Bucket rate limiter using atomic operations
|
||||
pub struct TokenBucket {
|
||||
available_tokens: AtomicU64, // Current token count (atomic)
|
||||
last_refill_time: AtomicU64, // Last refill time as micros since epoch
|
||||
config: BucketConfig, // Immutable configuration
|
||||
refill_task: Mutex<Option<ScopedTask<()>>>, // Background refill task
|
||||
start_time: Instant, // Bucket creation time
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct BucketConfig {
|
||||
capacity: u64, // Maximum token capacity
|
||||
fill_rate: u64, // Tokens added per second
|
||||
refill_interval: Duration, // Time between refill operations
|
||||
}
|
||||
|
||||
impl From<LimiterConfig> for BucketConfig {
|
||||
fn from(cfg: LimiterConfig) -> Self {
|
||||
let burst_rate = 1.max(cfg.burst_rate.unwrap_or(1));
|
||||
let fill_rate = 8196.max(cfg.bps.unwrap_or(u64::MAX / burst_rate));
|
||||
let refill_interval = cfg
|
||||
.fill_duration_ms
|
||||
.map(|x| Duration::from_millis(1.max(x)))
|
||||
.unwrap_or(Duration::from_millis(10));
|
||||
BucketConfig {
|
||||
capacity: burst_rate * fill_rate,
|
||||
fill_rate: fill_rate,
|
||||
refill_interval,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenBucket {
|
||||
pub fn new(capacity: u64, bps: u64, refill_interval: Duration) -> Arc<Self> {
|
||||
let config = BucketConfig {
|
||||
capacity,
|
||||
fill_rate: bps,
|
||||
refill_interval,
|
||||
};
|
||||
Self::new_from_cfg(config)
|
||||
}
|
||||
|
||||
/// Creates a new Token Bucket rate limiter
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `capacity` - Bucket capacity in bytes
|
||||
/// * `bps` - Bandwidth limit in bytes per second
|
||||
/// * `refill_interval` - Refill interval (recommended 10-50ms)
|
||||
pub fn new_from_cfg(config: BucketConfig) -> Arc<Self> {
|
||||
// Create Arc instance with placeholder task
|
||||
let arc_self = Arc::new(Self {
|
||||
available_tokens: AtomicU64::new(config.capacity),
|
||||
last_refill_time: AtomicU64::new(0),
|
||||
config,
|
||||
refill_task: Mutex::new(None),
|
||||
start_time: std::time::Instant::now(),
|
||||
});
|
||||
|
||||
// Start background refill task
|
||||
let arc_clone = arc_self.clone();
|
||||
let refill_task = tokio::spawn(async move {
|
||||
let mut interval = time::interval(arc_clone.config.refill_interval);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
arc_clone.refill();
|
||||
}
|
||||
});
|
||||
|
||||
// Replace placeholder task with actual one
|
||||
arc_self
|
||||
.refill_task
|
||||
.lock()
|
||||
.unwrap()
|
||||
.replace(refill_task.into());
|
||||
arc_self
|
||||
}
|
||||
|
||||
/// Internal refill method (called only by background task)
|
||||
fn refill(&self) {
|
||||
let now_micros = self.elapsed_micros();
|
||||
let prev_time = self.last_refill_time.swap(now_micros, Ordering::Acquire);
|
||||
|
||||
// Calculate elapsed time in seconds
|
||||
let elapsed_secs = (now_micros.saturating_sub(prev_time)) as f64 / 1_000_000.0;
|
||||
|
||||
// Calculate tokens to add
|
||||
let tokens_to_add = (self.config.fill_rate as f64 * elapsed_secs) as u64;
|
||||
if tokens_to_add == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
// Add tokens without exceeding capacity
|
||||
let mut current = self.available_tokens.load(Ordering::Relaxed);
|
||||
loop {
|
||||
let new = current
|
||||
.saturating_add(tokens_to_add)
|
||||
.min(self.config.capacity);
|
||||
match self.available_tokens.compare_exchange_weak(
|
||||
current,
|
||||
new,
|
||||
Ordering::Release,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => break,
|
||||
Err(actual) => current = actual,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate microseconds since bucket creation
|
||||
fn elapsed_micros(&self) -> u64 {
|
||||
self.start_time.elapsed().as_micros() as u64
|
||||
}
|
||||
|
||||
/// Attempt to consume tokens without blocking
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if tokens were consumed, `false` if insufficient tokens
|
||||
pub fn try_consume(&self, tokens: u64) -> bool {
|
||||
// Fast path for oversized packets
|
||||
if tokens > self.config.capacity {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut current = self.available_tokens.load(Ordering::Relaxed);
|
||||
loop {
|
||||
if current < tokens {
|
||||
return false;
|
||||
}
|
||||
|
||||
let new = current - tokens;
|
||||
match self.available_tokens.compare_exchange_weak(
|
||||
current,
|
||||
new,
|
||||
Ordering::AcqRel,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => return true,
|
||||
Err(actual) => current = actual,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TokenBucketManager {
|
||||
buckets: Arc<DashMap<String, Arc<TokenBucket>>>,
|
||||
|
||||
retain_task: ScopedTask<()>,
|
||||
}
|
||||
|
||||
impl TokenBucketManager {
|
||||
/// Creates a new TokenBucketManager
|
||||
pub fn new() -> Self {
|
||||
let buckets = Arc::new(DashMap::new());
|
||||
|
||||
let buckets_clone = buckets.clone();
|
||||
let retain_task = tokio::spawn(async move {
|
||||
loop {
|
||||
// Retain only buckets that are still in use
|
||||
buckets_clone.retain(|_, bucket| Arc::<TokenBucket>::strong_count(bucket) <= 1);
|
||||
// Sleep for a while before next retention check
|
||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
buckets,
|
||||
retain_task: retain_task.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get or create a token bucket for the given key
|
||||
pub fn get_or_create(&self, key: &str, cfg: BucketConfig) -> Arc<TokenBucket> {
|
||||
self.buckets
|
||||
.entry(key.to_string())
|
||||
.or_insert_with(|| TokenBucket::new_from_cfg(cfg))
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
||||
/// Test initial state after creation
|
||||
#[tokio::test]
|
||||
async fn test_initial_state() {
|
||||
let bucket = TokenBucket::new(1000, 1000, Duration::from_millis(10));
|
||||
|
||||
// Should have full capacity initially
|
||||
assert!(bucket.try_consume(1000));
|
||||
assert!(!bucket.try_consume(1)); // Should be empty now
|
||||
}
|
||||
|
||||
/// Test token consumption behavior
|
||||
#[tokio::test]
|
||||
async fn test_consumption() {
|
||||
let bucket = TokenBucket::new(1500, 1000, Duration::from_millis(10));
|
||||
|
||||
// First packet should succeed
|
||||
assert!(bucket.try_consume(1000));
|
||||
|
||||
// Second packet should fail (only 500 left)
|
||||
assert!(!bucket.try_consume(600));
|
||||
|
||||
// Should be able to take remaining tokens
|
||||
assert!(bucket.try_consume(500));
|
||||
}
|
||||
|
||||
/// Test background refill functionality
|
||||
#[tokio::test]
|
||||
async fn test_refill() {
|
||||
let bucket = TokenBucket::new(1000, 1000, Duration::from_millis(10));
|
||||
|
||||
// Drain the bucket
|
||||
assert!(bucket.try_consume(1000));
|
||||
assert!(!bucket.try_consume(1));
|
||||
|
||||
// Wait for refill (1 refill interval + buffer)
|
||||
sleep(Duration::from_millis(25)).await;
|
||||
|
||||
// Should have approximately 20 tokens (1000 tokens/s * 0.02s)
|
||||
assert!(bucket.try_consume(15));
|
||||
assert!(!bucket.try_consume(10)); // But not full capacity
|
||||
}
|
||||
|
||||
/// Test capacity enforcement
|
||||
#[tokio::test]
|
||||
async fn test_capacity_limit() {
|
||||
let bucket = TokenBucket::new(500, 1000, Duration::from_millis(10));
|
||||
|
||||
// Wait longer than refill interval
|
||||
sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// Should not exceed capacity despite time passed
|
||||
assert!(bucket.try_consume(500));
|
||||
assert!(!bucket.try_consume(1));
|
||||
}
|
||||
|
||||
/// Test high load with concurrent access
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_access() {
|
||||
let bucket = TokenBucket::new(10_000, 1_000_000, Duration::from_millis(10));
|
||||
let mut handles = vec![];
|
||||
|
||||
// Spawn 100 tasks to consume tokens concurrently
|
||||
for _ in 0..100 {
|
||||
let bucket = bucket.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
for _ in 0..100 {
|
||||
let _ = bucket.try_consume(10);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
// Wait for all tasks to complete
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
// Verify we didn't exceed capacity
|
||||
let tokens_left = bucket.available_tokens.load(Ordering::Relaxed);
|
||||
assert!(
|
||||
tokens_left <= 10_000,
|
||||
"Tokens exceeded capacity: {}",
|
||||
tokens_left
|
||||
);
|
||||
}
|
||||
|
||||
/// Test behavior when packet size exceeds capacity
|
||||
#[tokio::test]
|
||||
async fn test_oversized_packet() {
|
||||
let bucket = TokenBucket::new(1500, 1000, Duration::from_millis(10));
|
||||
|
||||
// Packet larger than capacity should be rejected
|
||||
assert!(!bucket.try_consume(1600));
|
||||
|
||||
// Regular packets should still work
|
||||
assert!(bucket.try_consume(1000));
|
||||
}
|
||||
|
||||
/// Test refill precision with small intervals
|
||||
#[tokio::test]
|
||||
async fn test_refill_precision() {
|
||||
let bucket = TokenBucket::new(10_000, 10_000, Duration::from_micros(100)); // 100μs interval
|
||||
|
||||
// Drain most tokens
|
||||
assert!(bucket.try_consume(9900));
|
||||
|
||||
// Wait for multiple refills
|
||||
sleep(Duration::from_millis(1)).await;
|
||||
|
||||
// Should have accumulated about 100 tokens (10,000 tokens/s * 0.001s)
|
||||
let tokens = bucket.available_tokens.load(Ordering::Relaxed);
|
||||
assert!(
|
||||
tokens >= 100 && tokens <= 200,
|
||||
"Unexpected token count: {}",
|
||||
tokens
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user