Files
2026-04-23 13:44:18 +08:00

169 lines
5.1 KiB
Rust

use std::sync::{Arc, Weak};
use tokio::{
sync::{Mutex, broadcast},
task::JoinSet,
time::interval,
};
use crate::{
common::{constants::EASYTIER_VERSION, get_machine_id},
proto::{
rpc_impl::bidirect::BidirectRpcManager,
rpc_types::controller::BaseController,
web::{
GetFeatureRequest, GetFeatureResponse, HeartbeatRequest, HeartbeatResponse,
WebServerServiceClientFactory,
},
},
tunnel::Tunnel,
};
use super::controller::Controller;
/// State shared between a [`Session`] and its background heartbeat task.
///
/// Cloning is cheap: both fields are `Arc`s, so every clone observes the same
/// notifier and the same stored response.
#[derive(Debug, Clone)]
struct HeartbeatCtx {
    /// Broadcasts each successful heartbeat response to whoever is subscribed at
    /// the time it arrives.
    notifier: Arc<broadcast::Sender<HeartbeatResponse>>,
    /// Most recent successful heartbeat response; `None` until the first one arrives.
    /// (`Mutex` here is `tokio::sync::Mutex`.)
    resp: Arc<Mutex<Option<HeartbeatResponse>>>,
}
/// A client session running over a single tunnel.
///
/// The session serves the controller's RPC API on the tunnel. Once
/// [`Session::start_heartbeat`] is called, it also periodically sends heartbeats
/// through a `WebServerService` RPC client on the same tunnel.
///
/// Fields are dropped in declaration order; keep that in mind before reordering them.
pub struct Session {
    /// Bidirectional RPC manager driving the tunnel (both RPC client and server side).
    rpc_mgr: BidirectRpcManager,
    /// Controller whose API is exposed over RPC and whose state is reported in heartbeats.
    controller: Arc<Controller>,
    /// Shared state written by the heartbeat task.
    heartbeat_ctx: HeartbeatCtx,
    /// Set by the first `start_heartbeat` call so the heartbeat task is spawned at most once.
    heartbeat_started: std::sync::atomic::AtomicBool,
    /// Background routines; in the code shown here only the heartbeat loop is spawned into it.
    /// Wrapped in a (tokio) `Mutex` so `&self` methods can spawn into it.
    tasks: Mutex<JoinSet<()>>,
}
impl Session {
    /// Creates a session on top of `tunnel`.
    ///
    /// Starts the bidirectional RPC manager on the tunnel and registers the controller's
    /// API service with its RPC server. The heartbeat is *not* started here; call
    /// [`Session::start_heartbeat`] for that.
    pub fn new(tunnel: Box<dyn Tunnel>, controller: Arc<Controller>) -> Self {
        let rpc_mgr = BidirectRpcManager::new();
        rpc_mgr.run_with_tunnel(tunnel);
        controller.register_api_rpc_service(rpc_mgr.rpc_server().registry());

        // The initial receiver is dropped at the end of this function; consumers get their
        // own receiver via `subscribe()` in `wait_next_heartbeat`. Sending while nobody is
        // subscribed returns an error, which the heartbeat loop deliberately ignores.
        let (tx, _rx1) = broadcast::channel(2);
        let heartbeat_ctx = HeartbeatCtx {
            notifier: Arc::new(tx),
            resp: Arc::new(Mutex::new(None)),
        };

        Session {
            rpc_mgr,
            controller,
            heartbeat_ctx,
            heartbeat_started: std::sync::atomic::AtomicBool::new(false),
            tasks: Mutex::new(JoinSet::new()),
        }
    }

    /// Spawns the heartbeat loop into `tasks`.
    ///
    /// Once per second the loop sends a `HeartbeatRequest` containing:
    /// - the machine id and a per-call random instance id,
    /// - the controller's token, hostname and device OS,
    /// - the ids of its running network instances.
    ///
    /// Each successful response is broadcast on `ctx.notifier` and stored in `ctx.resp`.
    /// The loop ends when the controller has been dropped or a heartbeat RPC fails.
    ///
    /// If the controller is already gone when this is called, nothing is spawned.
    fn heartbeat_routine(
        rpc_mgr: &BidirectRpcManager,
        controller: Weak<Controller>,
        tasks: &mut JoinSet<()>,
        ctx: HeartbeatCtx,
    ) {
        // Upgrade once to read the per-session constants, instead of upgrading (and
        // potentially panicking on `unwrap`) separately for each value.
        let (token, hostname, device_os) = {
            let Some(ctrl) = controller.upgrade() else {
                tracing::warn!("controller dropped before heartbeat routine started");
                return;
            };
            (ctrl.token(), ctrl.hostname(), ctrl.device_os())
        };

        let mid = get_machine_id();
        let inst_id = uuid::Uuid::new_v4();
        let mut tick = interval(std::time::Duration::from_secs(1));
        let client = rpc_mgr
            .rpc_client()
            .scoped_client::<WebServerServiceClientFactory<BaseController>>(1, 1, "".to_string());

        tasks.spawn(async move {
            loop {
                tick.tick().await;

                // Only a weak reference is captured, so the loop does not keep the
                // controller alive; stop once it has been dropped.
                let Some(ctrl) = controller.upgrade() else {
                    break;
                };

                let req = HeartbeatRequest {
                    machine_id: Some(mid.into()),
                    inst_id: Some(inst_id.into()),
                    user_token: token.to_string(),
                    easytier_version: EASYTIER_VERSION.to_string(),
                    hostname: hostname.clone(),
                    report_time: chrono::Local::now().to_rfc3339(),
                    device_os: Some(device_os.clone()),
                    support_config_source: true,
                    running_network_instances: ctrl
                        .list_network_instance_ids()
                        .into_iter()
                        .map(Into::into)
                        .collect(),
                };

                match client.heartbeat(BaseController::default(), req).await {
                    Err(e) => {
                        tracing::error!("heartbeat failed: {:?}", e);
                        break;
                    }
                    Ok(resp) => {
                        tracing::debug!("heartbeat response: {:?}", resp);
                        // NOTE(review): `resp` is used again after `send`, which only compiles
                        // because `HeartbeatResponse` is `Copy` (presumably a field-less prost
                        // message). Pass `resp.clone()` to `send` if that ever changes.
                        let _ = ctx.notifier.send(resp);
                        ctx.resp.lock().await.replace(resp);
                    }
                }
            }
        });
    }

    /// Starts the background heartbeat loop. Calls after the first are no-ops.
    pub async fn start_heartbeat(&self) {
        if self
            .heartbeat_started
            .swap(true, std::sync::atomic::Ordering::AcqRel)
        {
            return;
        }
        let mut tasks = self.tasks.lock().await;
        Self::heartbeat_routine(
            &self.rpc_mgr,
            Arc::downgrade(&self.controller),
            &mut tasks,
            self.heartbeat_ctx.clone(),
        );
    }

    /// Resolves when the first background routine finishes, then aborts all others.
    ///
    /// NOTE(review): if no routine has been spawned yet (heartbeat not started),
    /// `join_next` yields `None` immediately, so this returns at once and `wait` ends.
    /// Confirm that callers always call `start_heartbeat` before `wait`.
    async fn wait_routines(&self) {
        self.tasks.lock().await.join_next().await;
        // if any task failed, we should abort all tasks
        self.tasks.lock().await.abort_all();
    }

    /// Waits until the session ends: either the RPC manager stops, or a background
    /// routine finishes (which aborts the remaining routines).
    pub async fn wait(&mut self) {
        tokio::select! {
            _ = self.rpc_mgr.wait() => {}
            _ = self.wait_routines() => {}
        }
    }

    /// Calls the `get_feature` RPC of the `WebServerService` peer over this session's tunnel.
    ///
    /// # Errors
    /// Returns the RPC layer's error if the call fails.
    pub async fn get_feature(
        &self,
    ) -> Result<GetFeatureResponse, crate::proto::rpc_types::error::Error> {
        let client = self
            .rpc_mgr
            .rpc_client()
            .scoped_client::<WebServerServiceClientFactory<BaseController>>(1, 1, "".to_string());
        client
            .get_feature(BaseController::default(), GetFeatureRequest {})
            .await
    }

    /// Waits for the next successful heartbeat response that arrives after this call.
    ///
    /// A receiver that lags behind skips the dropped messages and keeps waiting for a
    /// retained one instead of giving up. Returns `None` only when the channel is closed.
    /// The sender is owned by this session, so the channel stays open while the session
    /// exists. If the heartbeat loop has stopped, this call therefore keeps waiting.
    pub async fn wait_next_heartbeat(&self) -> Option<HeartbeatResponse> {
        let mut rx = self.heartbeat_ctx.notifier.subscribe();
        loop {
            match rx.recv().await {
                Ok(resp) => return Some(resp),
                Err(broadcast::error::RecvError::Lagged(_)) => continue,
                Err(broadcast::error::RecvError::Closed) => return None,
            }
        }
    }
}