fix token mismatch when using web (#871)

This commit is contained in:
Sijie.Sun
2025-05-24 00:36:00 +08:00
committed by GitHub
parent 5a2fd4465c
commit fec885c427
6 changed files with 108 additions and 90 deletions
+11 -5
View File
@@ -10,7 +10,7 @@ use easytier::{
use session::Session; use session::Session;
use storage::{Storage, StorageToken}; use storage::{Storage, StorageToken};
use crate::db::Db; use crate::db::{Db, UserIdInDb};
#[derive(Debug)] #[derive(Debug)]
pub struct ClientManager { pub struct ClientManager {
@@ -86,15 +86,21 @@ impl ClientManager {
ret ret
} }
pub fn get_session_by_machine_id(&self, machine_id: &uuid::Uuid) -> Option<Arc<Session>> { pub fn get_session_by_machine_id(
let c_url = self.storage.get_client_url_by_machine_id(machine_id)?; &self,
user_id: UserIdInDb,
machine_id: &uuid::Uuid,
) -> Option<Arc<Session>> {
let c_url = self
.storage
.get_client_url_by_machine_id(user_id, machine_id)?;
self.client_sessions self.client_sessions
.get(&c_url) .get(&c_url)
.map(|item| item.value().clone()) .map(|item| item.value().clone())
} }
pub async fn list_machine_by_token(&self, token: String) -> Vec<url::Url> { pub async fn list_machine_by_user_id(&self, user_id: UserIdInDb) -> Vec<url::Url> {
self.storage.list_token_clients(&token) self.storage.list_user_clients(user_id)
} }
pub async fn get_heartbeat_requests(&self, client_url: &url::Url) -> Option<HeartbeatRequest> { pub async fn get_heartbeat_requests(&self, client_url: &url::Url) -> Option<HeartbeatRequest> {
+34 -9
View File
@@ -1,5 +1,6 @@
use std::{fmt::Debug, str::FromStr as _, sync::Arc}; use std::{fmt::Debug, str::FromStr as _, sync::Arc};
use anyhow::Context;
use easytier::{ use easytier::{
common::scoped_task::ScopedTask, common::scoped_task::ScopedTask,
proto::{ proto::{
@@ -78,22 +79,47 @@ impl WebServerService for SessionRpcService {
req: HeartbeatRequest, req: HeartbeatRequest,
) -> rpc_types::error::Result<HeartbeatResponse> { ) -> rpc_types::error::Result<HeartbeatResponse> {
let mut data = self.data.write().await; let mut data = self.data.write().await;
let Ok(storage) = Storage::try_from(data.storage.clone()) else {
tracing::error!("Failed to get storage");
return Ok(HeartbeatResponse {});
};
let machine_id: uuid::Uuid =
req.machine_id
.clone()
.map(Into::into)
.ok_or(anyhow::anyhow!(
"Machine id is not set correctly, expect uuid but got: {:?}",
req.machine_id
))?;
let user_id = storage
.db()
.get_user_id_by_token(req.user_token.clone())
.await
.with_context(|| {
format!(
"Failed to get user id by token from db: {:?}",
req.user_token
)
})?
.ok_or(anyhow::anyhow!(
"User not found by token: {:?}",
req.user_token
))?;
if data.req.replace(req.clone()).is_none() { if data.req.replace(req.clone()).is_none() {
assert!(data.storage_token.is_none()); assert!(data.storage_token.is_none());
data.storage_token = Some(StorageToken { data.storage_token = Some(StorageToken {
token: req.user_token.clone().into(), token: req.user_token.clone().into(),
client_url: data.client_url.clone(), client_url: data.client_url.clone(),
machine_id: req machine_id,
.machine_id user_id,
.clone()
.map(Into::into)
.unwrap_or(uuid::Uuid::new_v4()),
}); });
} }
if let Ok(storage) = Storage::try_from(data.storage.clone()) { let Ok(report_time) = chrono::DateTime::<chrono::Local>::from_str(&req.report_time) else {
let Ok(report_time) = chrono::DateTime::<chrono::Local>::from_str(&req.report_time)
else {
tracing::error!("Failed to parse report time: {:?}", req.report_time); tracing::error!("Failed to parse report time: {:?}", req.report_time);
return Ok(HeartbeatResponse {}); return Ok(HeartbeatResponse {});
}; };
@@ -101,7 +127,6 @@ impl WebServerService for SessionRpcService {
data.storage_token.as_ref().unwrap().clone(), data.storage_token.as_ref().unwrap().clone(),
report_time.timestamp(), report_time.timestamp(),
); );
}
let _ = data.notifier.send(req); let _ = data.notifier.send(req);
Ok(HeartbeatResponse {}) Ok(HeartbeatResponse {})
+33 -33
View File
@@ -2,7 +2,7 @@ use std::sync::{Arc, Weak};
use dashmap::DashMap; use dashmap::DashMap;
use crate::db::Db; use crate::db::{Db, UserIdInDb};
// use this to maintain Storage // use this to maintain Storage
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
@@ -10,21 +10,19 @@ pub struct StorageToken {
pub token: String, pub token: String,
pub client_url: url::Url, pub client_url: url::Url,
pub machine_id: uuid::Uuid, pub machine_id: uuid::Uuid,
pub user_id: UserIdInDb,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct ClientInfo { struct ClientInfo {
client_url: url::Url, storage_token: StorageToken,
machine_id: uuid::Uuid,
token: String,
report_time: i64, report_time: i64,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct StorageInner { pub struct StorageInner {
// some map for indexing // some map for indexing
token_clients_map: DashMap<String, DashMap<uuid::Uuid, ClientInfo>>, user_clients_map: DashMap<UserIdInDb, DashMap<uuid::Uuid, ClientInfo>>,
machine_client_url_map: DashMap<uuid::Uuid, ClientInfo>,
pub db: Db, pub db: Db,
} }
@@ -43,8 +41,7 @@ impl TryFrom<WeakRefStorage> for Storage {
impl Storage { impl Storage {
pub fn new(db: Db) -> Self { pub fn new(db: Db) -> Self {
Storage(Arc::new(StorageInner { Storage(Arc::new(StorageInner {
token_clients_map: DashMap::new(), user_clients_map: DashMap::new(),
machine_client_url_map: DashMap::new(),
db, db,
})) }))
} }
@@ -54,17 +51,22 @@ impl Storage {
machine_id: &uuid::Uuid, machine_id: &uuid::Uuid,
client_url: &url::Url, client_url: &url::Url,
) { ) {
map.remove_if(&machine_id, |_, v| v.client_url == *client_url); map.remove_if(&machine_id, |_, v| {
v.storage_token.client_url == *client_url
});
} }
fn update_mid_to_client_info_map( fn update_mid_to_client_info_map(
map: &DashMap<uuid::Uuid, ClientInfo>, map: &DashMap<uuid::Uuid, ClientInfo>,
client_info: &ClientInfo, client_info: &ClientInfo,
) { ) {
map.entry(client_info.machine_id) map.entry(client_info.storage_token.machine_id)
.and_modify(|e| { .and_modify(|e| {
if e.report_time < client_info.report_time { if e.report_time < client_info.report_time {
assert_eq!(e.machine_id, client_info.machine_id); assert_eq!(
e.storage_token.machine_id,
client_info.storage_token.machine_id
);
*e = client_info.clone(); *e = client_info.clone();
} }
}) })
@@ -74,53 +76,51 @@ impl Storage {
pub fn update_client(&self, stoken: StorageToken, report_time: i64) { pub fn update_client(&self, stoken: StorageToken, report_time: i64) {
let inner = self let inner = self
.0 .0
.token_clients_map .user_clients_map
.entry(stoken.token.clone()) .entry(stoken.user_id)
.or_insert_with(DashMap::new); .or_insert_with(DashMap::new);
let client_info = ClientInfo { let client_info = ClientInfo {
client_url: stoken.client_url.clone(), storage_token: stoken.clone(),
machine_id: stoken.machine_id,
token: stoken.token.clone(),
report_time, report_time,
}; };
Self::update_mid_to_client_info_map(&inner, &client_info); Self::update_mid_to_client_info_map(&inner, &client_info);
Self::update_mid_to_client_info_map(&self.0.machine_client_url_map, &client_info);
} }
pub fn remove_client(&self, stoken: &StorageToken) { pub fn remove_client(&self, stoken: &StorageToken) {
self.0.token_clients_map.remove_if(&stoken.token, |_, set| { self.0
.user_clients_map
.remove_if(&stoken.user_id, |_, set| {
Self::remove_mid_to_client_info_map(set, &stoken.machine_id, &stoken.client_url); Self::remove_mid_to_client_info_map(set, &stoken.machine_id, &stoken.client_url);
set.is_empty() set.is_empty()
}); });
Self::remove_mid_to_client_info_map(
&self.0.machine_client_url_map,
&stoken.machine_id,
&stoken.client_url,
);
} }
pub fn weak_ref(&self) -> WeakRefStorage { pub fn weak_ref(&self) -> WeakRefStorage {
Arc::downgrade(&self.0) Arc::downgrade(&self.0)
} }
pub fn get_client_url_by_machine_id(&self, machine_id: &uuid::Uuid) -> Option<url::Url> { pub fn get_client_url_by_machine_id(
self.0 &self,
.machine_client_url_map user_id: UserIdInDb,
.get(&machine_id) machine_id: &uuid::Uuid,
.map(|info| info.client_url.clone()) ) -> Option<url::Url> {
self.0.user_clients_map.get(&user_id).and_then(|info_map| {
info_map
.get(machine_id)
.map(|info| info.storage_token.client_url.clone())
})
} }
pub fn list_token_clients(&self, token: &str) -> Vec<url::Url> { pub fn list_user_clients(&self, user_id: UserIdInDb) -> Vec<url::Url> {
self.0 self.0
.token_clients_map .user_clients_map
.get(token) .get(&user_id)
.map(|info_map| { .map(|info_map| {
info_map info_map
.iter() .iter()
.map(|info| info.value().client_url.clone()) .map(|info| info.value().storage_token.client_url.clone())
.collect() .collect()
}) })
.unwrap_or_default() .unwrap_or_default()
+1 -1
View File
@@ -12,7 +12,7 @@ use sqlx::{migrate::MigrateDatabase as _, types::chrono, Sqlite, SqlitePool};
use crate::migrator; use crate::migrator;
type UserIdInDb = i32; pub type UserIdInDb = i32;
pub enum ListNetworkProps { pub enum ListNetworkProps {
All, All,
+2 -16
View File
@@ -9,7 +9,7 @@ use axum::http::StatusCode;
use axum::routing::post; use axum::routing::post;
use axum::{extract::State, routing::get, Json, Router}; use axum::{extract::State, routing::get, Json, Router};
use axum_login::tower_sessions::{ExpiredDeletion, SessionManagerLayer}; use axum_login::tower_sessions::{ExpiredDeletion, SessionManagerLayer};
use axum_login::{login_required, AuthManagerLayerBuilder, AuthzBackend}; use axum_login::{login_required, AuthManagerLayerBuilder, AuthUser, AuthzBackend};
use axum_messages::MessagesManagerLayer; use axum_messages::MessagesManagerLayer;
use easytier::common::config::ConfigLoader; use easytier::common::config::ConfigLoader;
use easytier::common::scoped_task::ScopedTask; use easytier::common::scoped_task::ScopedTask;
@@ -24,7 +24,6 @@ use tower_sessions::Expiry;
use tower_sessions_sqlx_store::SqliteStore; use tower_sessions_sqlx_store::SqliteStore;
use users::{AuthSession, Backend}; use users::{AuthSession, Backend};
use crate::client_manager::session::Session;
use crate::client_manager::storage::StorageToken; use crate::client_manager::storage::StorageToken;
use crate::client_manager::ClientManager; use crate::client_manager::ClientManager;
use crate::db::Db; use crate::db::Db;
@@ -112,17 +111,6 @@ impl RestfulServer {
}) })
} }
async fn get_session_by_machine_id(
client_mgr: &ClientManager,
machine_id: &uuid::Uuid,
) -> Result<Arc<Session>, HttpHandleError> {
let Some(result) = client_mgr.get_session_by_machine_id(machine_id) else {
return Err((StatusCode::NOT_FOUND, other_error("No such session").into()));
};
Ok(result)
}
async fn handle_list_all_sessions( async fn handle_list_all_sessions(
auth_session: AuthSession, auth_session: AuthSession,
State(client_mgr): AppState, State(client_mgr): AppState,
@@ -145,9 +133,7 @@ impl RestfulServer {
return Err((StatusCode::UNAUTHORIZED, other_error("No such user").into())); return Err((StatusCode::UNAUTHORIZED, other_error("No such user").into()));
}; };
let machines = client_mgr let machines = client_mgr.list_machine_by_user_id(user.id().clone()).await;
.list_machine_by_token(user.tokens[0].clone())
.await;
Ok(GetSummaryJsonResp { Ok(GetSummaryJsonResp {
device_count: machines.len() as u32, device_count: machines.len() as u32,
+17 -16
View File
@@ -5,7 +5,6 @@ use axum::http::StatusCode;
use axum::routing::{delete, post}; use axum::routing::{delete, post};
use axum::{extract::State, routing::get, Json, Router}; use axum::{extract::State, routing::get, Json, Router};
use axum_login::AuthUser; use axum_login::AuthUser;
use dashmap::DashSet;
use easytier::launcher::NetworkConfig; use easytier::launcher::NetworkConfig;
use easytier::proto::common::Void; use easytier::proto::common::Void;
use easytier::proto::rpc_types::controller::BaseController; use easytier::proto::rpc_types::controller::BaseController;
@@ -13,7 +12,7 @@ use easytier::proto::web::*;
use crate::client_manager::session::Session; use crate::client_manager::session::Session;
use crate::client_manager::ClientManager; use crate::client_manager::ClientManager;
use crate::db::ListNetworkProps; use crate::db::{ListNetworkProps, UserIdInDb};
use super::users::AuthSession; use super::users::AuthSession;
use super::{ use super::{
@@ -81,12 +80,24 @@ impl NetworkApi {
Self {} Self {}
} }
fn get_user_id(auth_session: &AuthSession) -> Result<UserIdInDb, (StatusCode, Json<Error>)> {
let Some(user_id) = auth_session.user.as_ref().map(|x| x.id()) else {
return Err((
StatusCode::UNAUTHORIZED,
other_error(format!("No user id found")).into(),
));
};
Ok(user_id)
}
async fn get_session_by_machine_id( async fn get_session_by_machine_id(
auth_session: &AuthSession, auth_session: &AuthSession,
client_mgr: &ClientManager, client_mgr: &ClientManager,
machine_id: &uuid::Uuid, machine_id: &uuid::Uuid,
) -> Result<Arc<Session>, HttpHandleError> { ) -> Result<Arc<Session>, HttpHandleError> {
let Some(result) = client_mgr.get_session_by_machine_id(machine_id) else { let user_id = Self::get_user_id(auth_session)?;
let Some(result) = client_mgr.get_session_by_machine_id(user_id, machine_id) else {
return Err(( return Err((
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
other_error(format!("No such session: {}", machine_id)).into(), other_error(format!("No such session: {}", machine_id)).into(),
@@ -289,23 +300,13 @@ impl NetworkApi {
auth_session: AuthSession, auth_session: AuthSession,
State(client_mgr): AppState, State(client_mgr): AppState,
) -> Result<Json<ListMachineJsonResp>, HttpHandleError> { ) -> Result<Json<ListMachineJsonResp>, HttpHandleError> {
let tokens = auth_session let user_id = Self::get_user_id(&auth_session)?;
.user
.as_ref()
.map(|x| x.tokens.clone())
.unwrap_or_default();
let client_urls = DashSet::new(); let client_urls = client_mgr.list_machine_by_user_id(user_id).await;
for token in tokens {
let urls = client_mgr.list_machine_by_token(token).await;
for url in urls {
client_urls.insert(url);
}
}
let mut machines = vec![]; let mut machines = vec![];
for item in client_urls.iter() { for item in client_urls.iter() {
let client_url = item.key().clone(); let client_url = item.clone();
let session = client_mgr.get_heartbeat_requests(&client_url).await; let session = client_mgr.get_heartbeat_requests(&client_url).await;
machines.push(ListMachineItem { machines.push(ListMachineItem {
client_url: Some(client_url), client_url: Some(client_url),