feat(web): add OIDC SSO login support (#1943)

This commit is contained in:
Mg Pig
2026-03-03 18:23:31 +08:00
committed by GitHub
parent d4ff0b1767
commit ff24332e23
16 changed files with 1300 additions and 156 deletions
+20 -4
View File
@@ -17,6 +17,8 @@ use easytier::{
use maxminddb::geoip2;
use session::{Location, Session};
use storage::{Storage, StorageToken};
use crate::FeatureFlags;
use tokio::task::JoinSet;
use crate::db::{entity::user_running_network_configs, Db, UserIdInDb};
@@ -55,11 +57,13 @@ pub struct ClientManager {
client_sessions: Arc<DashMap<url::Url, Arc<Session>>>,
storage: Storage,
feature_flags: Arc<FeatureFlags>,
geoip_db: Arc<Option<maxminddb::Reader<Vec<u8>>>>,
}
impl ClientManager {
pub fn new(db: Db, geoip_db: Option<String>) -> Self {
pub fn new(db: Db, geoip_db: Option<String>, feature_flags: Arc<FeatureFlags>) -> Self {
let client_sessions = Arc::new(DashMap::new());
let sessions: Arc<DashMap<url::Url, Arc<Session>>> = client_sessions.clone();
let mut tasks = JoinSet::new();
@@ -76,6 +80,8 @@ impl ClientManager {
client_sessions,
storage: Storage::new(db),
feature_flags,
geoip_db: Arc::new(load_geoip_db(geoip_db)),
}
}
@@ -90,6 +96,7 @@ impl ClientManager {
let storage = self.storage.weak_ref();
let listeners_cnt = self.listeners_cnt.clone();
let geoip_db = self.geoip_db.clone();
let feature_flags = self.feature_flags.clone();
self.tasks.spawn(async move {
while let Ok(tunnel) = listener.accept().await {
let info = tunnel.info().unwrap();
@@ -100,7 +107,12 @@ impl ClientManager {
client_url,
location
);
let mut session = Session::new(storage.clone(), client_url.clone(), location);
let mut session = Session::new(
storage.clone(),
client_url.clone(),
location,
feature_flags.clone(),
);
session.serve(tunnel).await;
sessions.insert(client_url, Arc::new(session));
}
@@ -291,12 +303,16 @@ mod tests {
};
use sqlx::Executor;
use crate::{client_manager::ClientManager, db::Db};
use crate::{client_manager::ClientManager, db::Db, FeatureFlags};
#[tokio::test]
async fn test_client() {
let listener = UdpTunnelListener::new("udp://0.0.0.0:54333".parse().unwrap());
let mut mgr = ClientManager::new(Db::memory_db().await, None);
let mut mgr = ClientManager::new(
Db::memory_db().await,
None,
Arc::new(FeatureFlags::default()),
);
mgr.add_listener(Box::new(listener)).await.unwrap();
mgr.db()
+29 -9
View File
@@ -18,6 +18,7 @@ use easytier::{
use tokio::sync::{broadcast, RwLock};
use super::storage::{Storage, StorageToken, WeakRefStorage};
use crate::FeatureFlags;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Location {
@@ -29,6 +30,7 @@ pub struct Location {
#[derive(Debug)]
pub struct SessionData {
storage: WeakRefStorage,
feature_flags: Arc<FeatureFlags>,
client_url: url::Url,
storage_token: Option<StorageToken>,
@@ -38,11 +40,17 @@ pub struct SessionData {
}
impl SessionData {
fn new(storage: WeakRefStorage, client_url: url::Url, location: Option<Location>) -> Self {
fn new(
storage: WeakRefStorage,
client_url: url::Url,
location: Option<Location>,
feature_flags: Arc<FeatureFlags>,
) -> Self {
let (tx, _rx1) = broadcast::channel(2);
SessionData {
storage,
feature_flags,
client_url,
storage_token: None,
notifier: tx,
@@ -98,7 +106,7 @@ impl SessionRpcService {
req.machine_id
))?;
let user_id = storage
let user_id = match storage
.db()
.get_user_id_by_token(req.user_token.clone())
.await
@@ -107,11 +115,18 @@ impl SessionRpcService {
"Failed to get user id by token from db: {:?}",
req.user_token
)
})?
.ok_or(anyhow::anyhow!(
"User not found by token: {:?}",
req.user_token
))?;
})? {
Some(id) => id,
None if data.feature_flags.allow_auto_create_user => storage
.auto_create_user(&req.user_token)
.await
.with_context(|| format!("Failed to auto-create user: {:?}", req.user_token))?,
None => {
return Err(
anyhow::anyhow!("User not found by token: {:?}", req.user_token).into(),
);
}
};
if data.req.replace(req.clone()).is_none() {
assert!(data.storage_token.is_none());
@@ -173,8 +188,13 @@ impl Debug for Session {
type SessionRpcClient = Box<dyn WebClientService<Controller = BaseController> + Send>;
impl Session {
pub fn new(storage: WeakRefStorage, client_url: url::Url, location: Option<Location>) -> Self {
let session_data = SessionData::new(storage, client_url, location);
pub fn new(
storage: WeakRefStorage,
client_url: url::Url,
location: Option<Location>,
feature_flags: Arc<FeatureFlags>,
) -> Self {
let session_data = SessionData::new(storage, client_url, location, feature_flags);
let data = Arc::new(RwLock::new(session_data));
let rpc_mgr =
+6 -1
View File
@@ -21,7 +21,6 @@ struct ClientInfo {
#[derive(Debug)]
pub struct StorageInner {
// some map for indexing
user_clients_map: DashMap<UserIdInDb, DashMap<uuid::Uuid, ClientInfo>>,
pub db: Db,
}
@@ -123,4 +122,10 @@ impl Storage {
pub fn db(&self) -> &Db {
&self.0.db
}
pub async fn auto_create_user(&self, username: &str) -> anyhow::Result<UserIdInDb> {
let new_user = self.db().auto_create_user(username).await?;
tracing::info!("Auto-created user '{}' with id {}", username, new_user.id);
Ok(new_user.id)
}
}
+52 -1
View File
@@ -9,7 +9,7 @@ use easytier::{
use entity::user_running_network_configs;
use sea_orm::{
prelude::Expr, sea_query::OnConflict, ColumnTrait as _, DatabaseConnection, DbErr, EntityTrait,
QueryFilter as _, SqlxSqliteConnector, TransactionTrait as _,
QueryFilter as _, Set, SqlxSqliteConnector, TransactionTrait as _,
};
use sea_orm_migration::MigratorTrait as _;
use sqlx::{migrate::MigrateDatabase as _, types::chrono, Sqlite, SqlitePool};
@@ -82,6 +82,57 @@ impl Db {
Ok(user.map(|u| u.id))
}
/// `password_hash` must be pre-hashed by the caller.
/// Creates user + joins "users" group in one transaction. Returns the created user model.
pub async fn create_user_and_join_users_group(
&self,
username: &str,
password_hash: String,
) -> Result<entity::users::Model, DbErr> {
use entity::{groups, users, users_groups};
let txn = self.orm_db().begin().await?;
let user_active = users::ActiveModel {
username: Set(username.to_string()),
password: Set(password_hash),
..Default::default()
};
let insert_result = users::Entity::insert(user_active).exec(&txn).await?;
let new_user = users::Entity::find_by_id(insert_result.last_insert_id)
.one(&txn)
.await?
.ok_or_else(|| DbErr::Custom("Failed to find newly created user".to_string()))?;
let users_group = groups::Entity::find()
.filter(groups::Column::Name.eq("users"))
.one(&txn)
.await?
.ok_or_else(|| DbErr::Custom("Users group not found".to_string()))?;
let ug_active = users_groups::ActiveModel {
user_id: Set(new_user.id),
group_id: Set(users_group.id),
..Default::default()
};
users_groups::Entity::insert(ug_active).exec(&txn).await?;
txn.commit().await?;
Ok(new_user)
}
pub async fn auto_create_user(&self, username: &str) -> Result<entity::users::Model, DbErr> {
let random_password = uuid::Uuid::new_v4().to_string();
let hashed_password =
tokio::task::spawn_blocking(move || password_auth::generate_hash(&random_password))
.await
.map_err(|e| DbErr::Custom(format!("Failed to hash password: {}", e)))?;
self.create_user_and_join_users_group(username, hashed_password)
.await
}
// TODO: currently we don't have a token system, so we just use the user name as token
pub async fn get_user_id_by_token<T: ToString>(
&self,
+60 -8
View File
@@ -126,12 +126,22 @@ struct Cli {
)]
api_host: Option<url::Url>,
#[arg(
long,
default_value = "false",
help = t!("cli.disable_registration").to_string(),
)]
disable_registration: bool,
#[command(flatten)]
feature_flags: FeatureFlags,
#[command(flatten)]
oidc: restful::oidc::OidcOptions,
}
#[derive(Debug, Clone, Default, clap::Args)]
pub struct FeatureFlags {
/// Whether user registration via the web UI is disabled.
#[arg(long, default_value = "false", help = t!("cli.disable_registration").to_string())]
pub disable_registration: bool,
/// Whether to auto-create users when they connect via heartbeat with an unknown token.
#[arg(long, default_value = "false", help = t!("cli.allow_auto_create_user").to_string())]
pub allow_auto_create_user: bool,
}
impl LoggingConfigLoader for &Cli {
@@ -197,9 +207,37 @@ async fn main() {
let cli = Cli::parse();
init_logger(&cli, false).unwrap();
// Validate OIDC configuration: check split-deploy specific requirements
// Basic OIDC parameter validation is handled in OidcConfig::from_params
if cli.oidc.any_param_provided() {
let is_split_deploy = {
#[cfg(feature = "embed")]
{
let embed_split_by_port = cli.web_server_port.is_some()
&& cli.web_server_port != Some(cli.api_server_port);
cli.no_web || embed_split_by_port
}
#[cfg(not(feature = "embed"))]
{
true
}
};
if is_split_deploy && cli.oidc.oidc_frontend_base_url.is_none() {
eprintln!("Error: --oidc-frontend-base-url is required in split-deploy mode");
eprintln!(
"When frontend and API are deployed separately, you must specify the frontend URL"
);
eprintln!("Example: --oidc-frontend-base-url http://your-frontend-domain.com");
std::process::exit(1);
}
}
// let db = db::Db::new(":memory:").await.unwrap();
let db = db::Db::new(cli.db).await.unwrap();
let mut mgr = client_manager::ClientManager::new(db.clone(), cli.geoip_db);
let feature_flags = Arc::new(cli.feature_flags);
let mut mgr =
client_manager::ClientManager::new(db.clone(), cli.geoip_db, feature_flags.clone());
let (v6_listener, v4_listener) =
get_dual_stack_listener(&cli.config_server_protocol, cli.config_server_port)
.await
@@ -233,12 +271,26 @@ async fn main() {
#[cfg(not(feature = "embed"))]
let web_router_restful = None;
let oidc_config = if cli.oidc.oidc_issuer_url.is_some() {
match restful::oidc::OidcConfig::from_params(cli.oidc).await {
Ok(config) => config,
Err(e) => {
eprintln!("Failed to initialize OIDC: {:?}", e);
eprintln!("Please check your OIDC configuration (issuer URL, client ID, etc.)");
std::process::exit(1);
}
}
} else {
restful::oidc::OidcConfig::disabled()
};
let _restful_server_tasks = restful::RestfulServer::new(
std::net::SocketAddr::new(cli.api_server_addr, cli.api_server_port),
mgr.clone(),
db,
web_router_restful,
cli.disable_registration,
feature_flags,
oidc_config,
)
.await
.unwrap()
+5 -8
View File
@@ -9,18 +9,15 @@ use serde::{Deserialize, Serialize};
use crate::restful::users::Backend;
use std::sync::Arc;
use crate::FeatureFlags;
use super::{
users::{AuthSession, Credentials},
AppStateInner,
};
/// Feature flags for the web server
#[derive(Clone, Default)]
pub struct FeatureFlags {
/// Whether user registration is disabled
pub disable_registration: bool,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct LoginResult {
messages: Vec<Message>,
@@ -117,7 +114,7 @@ mod post {
}
pub async fn register(
Extension(feature_flags): Extension<FeatureFlags>,
Extension(feature_flags): Extension<Arc<FeatureFlags>>,
auth_session: AuthSession,
captcha_session: tower_sessions::Session,
Json(req): Json<RegisterNewUser>,
+13 -7
View File
@@ -1,6 +1,7 @@
mod auth;
pub(crate) mod captcha;
mod network;
pub(crate) mod oidc;
mod users;
use std::{net::SocketAddr, sync::Arc};
@@ -19,7 +20,7 @@ use network::NetworkApi;
use sea_orm::DbErr;
use tokio::net::TcpListener;
use tower_sessions::cookie::time::Duration;
use tower_sessions::cookie::Key;
use tower_sessions::cookie::{Key, SameSite};
use tower_sessions::Expiry;
use tower_sessions_sqlx_store::SqliteStore;
use users::{AuthSession, Backend};
@@ -27,6 +28,7 @@ use users::{AuthSession, Backend};
use crate::client_manager::storage::StorageToken;
use crate::client_manager::ClientManager;
use crate::db::Db;
use crate::FeatureFlags;
/// Embed assets for web dashboard, build frontend first
#[cfg(feature = "embed")]
@@ -37,8 +39,9 @@ struct Assets;
pub struct RestfulServer {
bind_addr: SocketAddr,
client_mgr: Arc<ClientManager>,
registration_disabled: bool,
feature_flags: Arc<FeatureFlags>,
db: Db,
oidc_config: oidc::OidcConfig,
// serve_task: Option<ScopedTask<()>>,
// delete_task: Option<ScopedTask<tower_sessions::session_store::Result<()>>>,
@@ -105,7 +108,8 @@ impl RestfulServer {
client_mgr: Arc<ClientManager>,
db: Db,
web_router: Option<Router>,
registration_disabled: bool,
feature_flags: Arc<FeatureFlags>,
oidc_config: oidc::OidcConfig,
) -> anyhow::Result<Self> {
assert!(client_mgr.is_running());
@@ -114,8 +118,9 @@ impl RestfulServer {
Ok(RestfulServer {
bind_addr,
client_mgr,
registration_disabled,
feature_flags,
db,
oidc_config,
// serve_task: None,
// delete_task: None,
// network_api,
@@ -222,6 +227,7 @@ impl RestfulServer {
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(false)
.with_same_site(SameSite::Lax)
.with_expiry(Expiry::OnInactivity(Duration::days(1)))
.with_signed(key);
@@ -243,15 +249,15 @@ impl RestfulServer {
.route("/api/v1/sessions", get(Self::handle_list_all_sessions))
.merge(NetworkApi::build_route())
.route_layer(login_required!(Backend))
.merge(auth::router().layer(Extension(auth::FeatureFlags {
disable_registration: self.registration_disabled,
})))
.merge(auth::router().layer(Extension(self.feature_flags.clone())))
.merge(oidc::router())
.with_state(self.client_mgr.clone())
.route(
"/api/v1/generate-config",
post(Self::handle_generate_config),
)
.route("/api/v1/parse-config", post(Self::handle_parse_config))
.layer(Extension(self.oidc_config.clone()))
.layer(MessagesManagerLayer)
.layer(auth_layer)
.layer(tower_http::cors::CorsLayer::very_permissive())
+734
View File
@@ -0,0 +1,734 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use subtle::ConstantTimeEq;
use axum::routing::get;
use axum::Router;
use openidconnect::core::{
CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey,
CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreProviderMetadata,
CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenType,
};
use openidconnect::{
Client, ClientId, ClientSecret, EmptyExtraTokenFields, EndpointMaybeSet, EndpointNotSet,
EndpointSet, IdTokenFields, IssuerUrl, RedirectUrl, StandardErrorResponse,
StandardTokenResponse,
};
use serde::{Deserialize, Serialize};
use super::AppStateInner;
const DEFAULT_OIDC_SCOPES: [&str; 2] = ["openid", "profile"];
fn normalize_oidc_scopes(scopes: &[String]) -> Vec<String> {
let mut normalized: Vec<String> = scopes
.iter()
.map(|scope| scope.trim().to_string())
.filter(|scope| !scope.is_empty())
.collect();
if normalized.is_empty() {
normalized = DEFAULT_OIDC_SCOPES
.iter()
.map(|scope| scope.to_string())
.collect();
}
if !normalized.iter().any(|scope| scope == "openid") {
normalized.insert(0, "openid".to_string());
}
normalized
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct JsonAdditionalClaims {
#[serde(flatten)]
pub claims: HashMap<String, serde_json::Value>,
}
impl openidconnect::AdditionalClaims for JsonAdditionalClaims {}
pub type AppIdTokenFields = IdTokenFields<
JsonAdditionalClaims,
EmptyExtraTokenFields,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJwsSigningAlgorithm,
>;
pub type AppTokenResponse = StandardTokenResponse<AppIdTokenFields, CoreTokenType>;
pub type AppClient<
HasAuthUrl = EndpointNotSet,
HasDeviceAuthUrl = EndpointNotSet,
HasIntrospectionUrl = EndpointNotSet,
HasRevocationUrl = EndpointNotSet,
HasTokenUrl = EndpointNotSet,
HasUserInfoUrl = EndpointNotSet,
> = Client<
JsonAdditionalClaims,
CoreAuthDisplay,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJsonWebKey,
CoreAuthPrompt,
StandardErrorResponse<CoreErrorResponseType>,
AppTokenResponse,
CoreTokenIntrospectionResponse,
CoreRevocableToken,
CoreRevocationErrorResponse,
HasAuthUrl,
HasDeviceAuthUrl,
HasIntrospectionUrl,
HasRevocationUrl,
HasTokenUrl,
HasUserInfoUrl,
>;
pub type ConfiguredAppClient = AppClient<
EndpointSet,
EndpointNotSet,
EndpointNotSet,
EndpointNotSet,
EndpointMaybeSet,
EndpointMaybeSet,
>;
/// Convert a dot-path (e.g. `realm_access.roles.0`) to a JSON Pointer (e.g. `/realm_access/roles/0`).
/// Each segment is escaped per RFC 6901: `~` → `~0`, `/` → `~1`.
fn dot_path_to_json_pointer(dot_path: &str) -> String {
let mut pointer = String::new();
for segment in dot_path.split('.') {
pointer.push('/');
for ch in segment.chars() {
match ch {
'~' => pointer.push_str("~0"),
'/' => pointer.push_str("~1"),
_ => pointer.push(ch),
}
}
}
pointer
}
/// Timing-safe string comparison via constant-time equality check.
/// Prevents timing side-channel attacks on CSRF token verification.
fn timing_safe_eq(a: &str, b: &str) -> bool {
if a.len() != b.len() {
return false;
}
a.as_bytes().ct_eq(b.as_bytes()).into()
}
#[derive(Debug, Clone, clap::Args)]
pub struct OidcOptions {
#[arg(long, help = t!("cli.oidc_issuer_url").to_string())]
pub oidc_issuer_url: Option<String>,
#[arg(long, help = t!("cli.oidc_client_id").to_string())]
pub oidc_client_id: Option<String>,
#[arg(long, env = "OIDC_CLIENT_SECRET", help = t!("cli.oidc_client_secret").to_string())]
pub oidc_client_secret: Option<String>,
#[arg(long, default_value = "preferred_username", help = t!("cli.oidc_username_claim").to_string())]
pub oidc_username_claim: String,
#[arg(
long,
value_delimiter = ',',
default_values = DEFAULT_OIDC_SCOPES,
help = t!("cli.oidc_scopes").to_string()
)]
pub oidc_scopes: Vec<String>,
#[arg(long, help = t!("cli.oidc_redirect_url").to_string())]
pub oidc_redirect_url: Option<String>,
#[arg(long, default_value = "false", help = t!("cli.oidc_disable_pkce").to_string())]
pub oidc_disable_pkce: bool,
#[arg(long, help = t!("cli.oidc_frontend_base_url").to_string())]
pub oidc_frontend_base_url: Option<String>,
}
impl OidcOptions {
pub fn any_param_provided(&self) -> bool {
self.oidc_issuer_url.is_some()
|| self.oidc_client_id.is_some()
|| self.oidc_client_secret.is_some()
|| self.oidc_redirect_url.is_some()
|| self.oidc_frontend_base_url.is_some()
|| self.oidc_username_claim != "preferred_username"
|| self.oidc_scopes != DEFAULT_OIDC_SCOPES
|| self.oidc_disable_pkce
}
}
#[derive(Clone)]
pub struct OidcConfig {
pub enabled: bool,
pub provider_metadata: Option<Arc<CoreProviderMetadata>>,
pub client_id: String,
pub client_secret: Option<String>,
pub redirect_url: Option<RedirectUrl>,
pub username_claim: String,
pub scopes: Vec<String>,
pub pkce_enabled: bool,
pub frontend_base_url: Option<String>,
pub http_client: Option<reqwest::Client>,
cached_client: Option<Arc<ConfiguredAppClient>>,
}
impl OidcConfig {
pub fn disabled() -> Self {
Self {
enabled: false,
provider_metadata: None,
client_id: String::new(),
client_secret: None,
redirect_url: None,
username_claim: "preferred_username".to_string(),
scopes: DEFAULT_OIDC_SCOPES
.iter()
.map(|scope| scope.to_string())
.collect(),
pkce_enabled: false,
frontend_base_url: None,
http_client: None,
cached_client: None,
}
}
pub async fn from_params(opts: OidcOptions) -> anyhow::Result<Self> {
let OidcOptions {
oidc_issuer_url,
oidc_client_id,
oidc_client_secret,
oidc_username_claim,
oidc_scopes,
oidc_redirect_url,
oidc_disable_pkce,
oidc_frontend_base_url,
} = opts;
if oidc_issuer_url.is_none() || oidc_client_id.is_none() || oidc_redirect_url.is_none() {
return Err(anyhow::anyhow!("--oidc-issuer-url, --oidc-client-id and --oidc-redirect-url are required when using OIDC authentication"));
}
if oidc_username_claim.trim().is_empty() {
return Err(anyhow::anyhow!("--oidc-username-claim cannot be empty"));
}
let http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.timeout(Duration::from_secs(30))
.build()?;
let issuer_url = oidc_issuer_url.ok_or_else(|| {
anyhow::anyhow!("--oidc-issuer-url is required when using OIDC authentication")
})?;
let provider_metadata =
CoreProviderMetadata::discover_async(IssuerUrl::new(issuer_url)?, &http_client).await?;
let client_id = oidc_client_id.ok_or_else(|| {
anyhow::anyhow!("--oidc-client-id is required when using OIDC authentication")
})?;
let redirect_url = oidc_redirect_url
.ok_or_else(|| anyhow::anyhow!("--oidc-redirect-url is required when using OIDC authentication. The redirect URL must match exactly what is registered with your Identity Provider. Example: --oidc-redirect-url http://your-domain.com:11211/api/v1/auth/oidc/callback"))?;
let provider_metadata = Arc::new(provider_metadata);
let redirect_url = RedirectUrl::new(redirect_url)?;
let client_secret = oidc_client_secret;
let cached_client = {
let c = AppClient::from_provider_metadata(
provider_metadata.as_ref().clone(),
ClientId::new(client_id.clone()),
client_secret.as_ref().map(|s| ClientSecret::new(s.clone())),
)
.set_redirect_uri(redirect_url.clone());
Arc::new(c)
};
Ok(Self {
enabled: true,
provider_metadata: Some(provider_metadata),
client_id,
client_secret,
redirect_url: Some(redirect_url),
username_claim: oidc_username_claim,
scopes: normalize_oidc_scopes(&oidc_scopes),
pkce_enabled: !oidc_disable_pkce,
frontend_base_url: oidc_frontend_base_url,
http_client: Some(http_client),
cached_client: Some(cached_client),
})
}
pub fn client(&self) -> Option<&ConfiguredAppClient> {
self.cached_client.as_deref()
}
}
pub fn router() -> Router<AppStateInner> {
Router::new()
.route("/api/v1/auth/oidc/config", get(self::route::oidc_config))
.route("/api/v1/auth/oidc/login", get(self::route::oidc_login))
.route(
"/api/v1/auth/oidc/callback",
get(self::route::oidc_callback),
)
}
mod route {
use axum::extract::Query;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Redirect, Response};
use axum::{Extension, Json};
use openidconnect::core::CoreAuthenticationFlow;
use openidconnect::{
AccessTokenHash, AuthorizationCode, CsrfToken, Nonce, OAuth2TokenResponse,
PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse,
};
use serde::Deserialize;
use crate::restful::other_error;
use crate::restful::users::AuthSession;
use super::OidcConfig;
pub async fn oidc_config(Extension(oidc): Extension<OidcConfig>) -> Json<serde_json::Value> {
Json(serde_json::json!({ "enabled": oidc.enabled }))
}
pub async fn oidc_login(
Extension(oidc): Extension<OidcConfig>,
session: tower_sessions::Session,
) -> Response {
if !oidc.enabled {
return (
StatusCode::BAD_REQUEST,
Json(other_error("OIDC is not enabled")),
)
.into_response();
}
let client = match oidc.client() {
Some(c) => c,
None => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("OIDC client not initialized")),
)
.into_response();
}
};
let scopes = oidc.scopes.clone();
let pkce_enabled = oidc.pkce_enabled;
let (pkce_challenge, pkce_verifier) = if pkce_enabled {
let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();
(Some(challenge), Some(verifier))
} else {
(None, None)
};
let mut auth_request = client.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
);
for scope in &scopes {
auth_request = auth_request.add_scope(Scope::new(scope.clone()));
}
if let Some(challenge) = pkce_challenge {
auth_request = auth_request.set_pkce_challenge(challenge);
}
let (auth_url, csrf_token, nonce) = auth_request.url();
if let Err(e) = session
.insert("oidc_csrf_token", csrf_token.secret().clone())
.await
{
tracing::error!("Failed to store csrf_token in session: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("Session error")),
)
.into_response();
}
if let Err(e) = session.insert("oidc_nonce", nonce.secret().clone()).await {
tracing::error!("Failed to store nonce in session: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("Session error")),
)
.into_response();
}
if let Some(verifier) = pkce_verifier {
if let Err(e) = session
.insert("oidc_pkce_verifier", verifier.secret().clone())
.await
{
tracing::error!("Failed to store pkce_verifier in session: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("Session error")),
)
.into_response();
}
}
if let Err(e) = session.insert("oidc_pkce_used", pkce_enabled).await {
tracing::error!("Failed to store pkce_used in session: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("Session error")),
)
.into_response();
}
Redirect::temporary(auth_url.as_str()).into_response()
}
#[derive(Deserialize)]
pub struct CallbackParams {
code: Option<String>,
state: Option<String>,
error: Option<String>,
error_description: Option<String>,
}
async fn cleanup_oidc_session(session: &tower_sessions::Session) {
let _ = session.remove::<String>("oidc_csrf_token").await;
let _ = session.remove::<String>("oidc_nonce").await;
let _ = session.remove::<String>("oidc_pkce_verifier").await;
let _ = session.remove::<bool>("oidc_pkce_used").await;
}
pub async fn oidc_callback(
Extension(oidc): Extension<OidcConfig>,
Query(params): Query<CallbackParams>,
session: tower_sessions::Session,
mut auth_session: AuthSession,
) -> Response {
if !oidc.enabled {
return (
StatusCode::BAD_REQUEST,
Json(other_error("OIDC is not enabled")),
)
.into_response();
}
if let Some(ref error) = params.error {
tracing::error!(
"OIDC provider returned error: {}, description: {:?}",
error,
params.error_description
);
return (
StatusCode::BAD_REQUEST,
Json(other_error(
"Authentication failed at the identity provider",
)),
)
.into_response();
}
let code = match params.code {
Some(ref c) => c.clone(),
None => {
return (
StatusCode::BAD_REQUEST,
Json(other_error("Missing authorization code")),
)
.into_response();
}
};
let callback_state = match params.state {
Some(ref s) => s.clone(),
None => {
return (
StatusCode::BAD_REQUEST,
Json(other_error("Missing state parameter in callback")),
)
.into_response();
}
};
let stored_csrf: String = match session.get("oidc_csrf_token").await {
Ok(Some(v)) => v,
_ => {
return (
StatusCode::BAD_REQUEST,
Json(other_error("Missing or invalid CSRF token in session")),
)
.into_response();
}
};
if !super::timing_safe_eq(&stored_csrf, &callback_state) {
return (
StatusCode::BAD_REQUEST,
Json(other_error("CSRF state mismatch")),
)
.into_response();
}
let stored_nonce: String = match session.get("oidc_nonce").await {
Ok(Some(v)) => v,
_ => {
return (
StatusCode::BAD_REQUEST,
Json(other_error("Missing nonce in session")),
)
.into_response();
}
};
let stored_pkce_verifier: Option<String> =
session.get("oidc_pkce_verifier").await.ok().flatten();
let pkce_was_used: Option<bool> = session.get("oidc_pkce_used").await.ok().flatten();
cleanup_oidc_session(&session).await;
let client = match oidc.client() {
Some(c) => c,
None => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("OIDC client not initialized")),
)
.into_response();
}
};
let http_client = match oidc.http_client.as_ref() {
Some(c) => c,
None => {
tracing::error!("HTTP client not initialized in OIDC config");
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("OIDC internal error")),
)
.into_response();
}
};
let mut token_request = match client.exchange_code(AuthorizationCode::new(code)) {
Ok(req) => req,
Err(e) => {
tracing::error!("Failed to create token request: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("Failed to create token exchange request")),
)
.into_response();
}
};
if let Some(stored_pkce_verifier) = stored_pkce_verifier {
token_request =
token_request.set_pkce_verifier(PkceCodeVerifier::new(stored_pkce_verifier));
} else if pkce_was_used == Some(true) {
return (
StatusCode::BAD_REQUEST,
Json(other_error(
"PKCE was enabled but verifier is missing from session (session may have expired)",
)),
)
.into_response();
}
let token_response = match token_request.request_async(http_client).await {
Ok(resp) => resp,
Err(e) => {
tracing::error!("Failed to exchange code for token: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("Token exchange failed")),
)
.into_response();
}
};
let id_token = match token_response.id_token() {
Some(t) => t,
None => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("No ID token in response")),
)
.into_response();
}
};
let claims = match id_token.claims(&client.id_token_verifier(), &Nonce::new(stored_nonce)) {
Ok(c) => c,
Err(e) => {
tracing::error!("Failed to verify ID token: {:?}", e);
return (
StatusCode::UNAUTHORIZED,
Json(other_error("ID token verification failed")),
)
.into_response();
}
};
if let Some(expected_at_hash) = claims.access_token_hash() {
let id_token_verifier = client.id_token_verifier();
let (Ok(signing_alg), Ok(signing_key)) = (
id_token.signing_alg(),
id_token.signing_key(&id_token_verifier),
) else {
tracing::error!("Failed to get signing algorithm or key for at_hash verification");
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("Failed to determine token signing algorithm")),
)
.into_response();
};
let actual_at_hash = match AccessTokenHash::from_token(
token_response.access_token(),
signing_alg,
signing_key,
) {
Ok(hash) => hash,
Err(e) => {
tracing::error!("Failed to compute access token hash: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("Failed to verify access token hash")),
)
.into_response();
}
};
if actual_at_hash != *expected_at_hash {
tracing::error!("Access token hash mismatch");
return (
StatusCode::UNAUTHORIZED,
Json(other_error("Access token hash mismatch")),
)
.into_response();
}
}
let claims_json = match serde_json::to_value(claims) {
Ok(v) => v,
Err(e) => {
tracing::error!("Failed to serialize claims: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("Failed to process ID token claims")),
)
.into_response();
}
};
let pointer = super::dot_path_to_json_pointer(&oidc.username_claim);
let username: Option<String> = claims_json
.pointer(&pointer)
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let username = match username {
Some(u) if !u.is_empty() => u,
_ => {
tracing::error!(
"Could not extract username from claim '{}' in token",
oidc.username_claim
);
return (
StatusCode::BAD_REQUEST,
Json(other_error("Could not extract username from token claims")),
)
.into_response();
}
};
let user = match auth_session
.backend
.find_or_create_oidc_user(&username)
.await
{
Ok(u) => u,
Err(e) => {
tracing::error!("Failed to find or create OIDC user '{}': {:?}", username, e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("Failed to provision user account")),
)
.into_response();
}
};
if let Err(e) = auth_session.login(&user).await {
tracing::error!("Failed to login user via OIDC: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(other_error("Failed to establish session")),
)
.into_response();
}
if let Err(e) = session.cycle_id().await {
tracing::error!("Failed to cycle session ID after OIDC login: {:?}", e);
}
if let Some(frontend_url) = &oidc.frontend_base_url {
Redirect::temporary(frontend_url).into_response()
} else {
Redirect::temporary("/").into_response()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dot_path_to_json_pointer() {
use serde_json::json;
let cases = vec![
(
"realm_access.roles.0",
"/realm_access/roles/0",
json!({ "realm_access": { "roles": ["admin", "user"] } }),
"admin",
),
(
"preferred_username",
"/preferred_username",
json!({ "preferred_username": "bob" }),
"bob",
),
("a~b.c", "/a~0b/c", json!({ "a~b": { "c": "v" } }), "v"),
("a/b.c", "/a~1b/c", json!({ "a/b": { "c": "w" } }), "w"),
("~/.x", "/~0~1/x", json!({ "~/": { "x": "z" } }), "z"),
("a..b", "/a//b", json!({ "a": { "": { "b": "x" } } }), "x"),
("", "/", json!({ "": "root" }), "root"),
];
for (path, expected_ptr, json_val, expected_val) in cases {
let ptr = dot_path_to_json_pointer(path);
assert_eq!(ptr, expected_ptr, "Pointer mismatch for path: {}", path);
assert_eq!(
json_val.pointer(&ptr).and_then(|v| v.as_str()),
Some(expected_val),
"Value extraction failed for path: {}, pointer: {}",
path,
ptr
);
}
}
}
+39 -32
View File
@@ -4,8 +4,8 @@ use async_trait::async_trait;
use axum_login::{AuthUser, AuthnBackend, AuthzBackend, UserId};
use password_auth::verify_password;
use sea_orm::{
ActiveModelTrait as _, ColumnTrait, EntityTrait, FromQueryResult, IntoActiveModel, JoinType,
QueryFilter, QuerySelect as _, RelationTrait, Set, TransactionTrait,
ColumnTrait, EntityTrait, FromQueryResult, IntoActiveModel, JoinType, QueryFilter,
QuerySelect as _, RelationTrait, Set,
};
use serde::{Deserialize, Serialize};
use tokio::task;
@@ -14,7 +14,7 @@ use crate::db::{self, entity};
#[derive(Clone, Serialize, Deserialize)]
pub struct User {
db_user: entity::users::Model,
pub(crate) db_user: entity::users::Model,
pub tokens: Vec<String>,
}
@@ -74,40 +74,47 @@ impl Backend {
Self { db }
}
pub fn db(&self) -> &db::Db {
&self.db
}
pub async fn register_new_user(&self, new_user: &RegisterNewUser) -> anyhow::Result<()> {
let hashed_password = password_auth::generate_hash(new_user.credentials.password.as_str());
let txn = self.db.orm_db().begin().await?;
entity::users::ActiveModel {
username: Set(new_user.credentials.username.clone()),
password: Set(hashed_password.clone()),
..Default::default()
}
.save(&txn)
.await?;
entity::users_groups::ActiveModel {
user_id: Set(entity::users::Entity::find()
.filter(entity::users::Column::Username.eq(new_user.credentials.username.as_str()))
.one(&txn)
.await?
.unwrap()
.id),
group_id: Set(entity::groups::Entity::find()
.filter(entity::groups::Column::Name.eq("users"))
.one(&txn)
.await?
.unwrap()
.id),
..Default::default()
}
.save(&txn)
.await?;
txn.commit().await?;
self.db
.create_user_and_join_users_group(&new_user.credentials.username, hashed_password)
.await?;
Ok(())
}
/// Find a user by username, or auto-create one for OIDC-authenticated users.
///
/// Unlike the heartbeat auto-creation path (controlled by `allow_auto_create_user`),
/// OIDC users are always provisioned automatically because their identity has already
/// been verified by a trusted external Identity Provider (IdP).
pub async fn find_or_create_oidc_user(&self, username: &str) -> anyhow::Result<User> {
use entity::users;
// Try to find an existing user first.
if let Some(db_user) = users::Entity::find()
.filter(users::Column::Username.eq(username))
.one(self.db.orm_db())
.await?
{
return Ok(User {
tokens: vec![db_user.username.clone()],
db_user,
});
}
// User not found auto-provision a local account backed by the IdP identity.
let db_user = self.db.auto_create_user(username).await?;
tracing::info!("Auto-provisioned OIDC user '{username}'");
Ok(User {
tokens: vec![db_user.username.clone()],
db_user,
})
}
pub async fn change_password(
&self,
id: <User as AuthUser>::Id,