feat/web: Patchset 3 (#455)

https://apifox.com/apidoc/shared-ceda7a60-e817-4ea8-827b-de4e874dc45e

implement all backend API
This commit is contained in:
Sijie.Sun
2024-11-02 15:13:19 +08:00
committed by GitHub
parent 18da94bf33
commit 8aca5851f2
41 changed files with 4621 additions and 217 deletions
+171
View File
@@ -0,0 +1,171 @@
use axum::{
http::StatusCode,
routing::{get, post, put},
Router,
};
use axum_login::login_required;
use axum_messages::Message;
use serde::{Deserialize, Serialize};
use crate::restful::users::Backend;
use super::{
users::{AuthSession, Credentials},
AppStateInner,
};
#[derive(Debug, Deserialize, Serialize)]
pub struct LoginResult {
messages: Vec<Message>,
}
pub fn router() -> Router<AppStateInner> {
let r = Router::new()
.route("/api/v1/auth/password", put(self::put::change_password))
.route_layer(login_required!(Backend));
Router::new()
.merge(r)
.route("/api/v1/auth/login", post(self::post::login))
.route("/api/v1/auth/logout", get(self::get::logout))
.route("/api/v1/auth/captcha", get(self::get::get_captcha))
.route("/api/v1/auth/register", post(self::post::register))
}
mod put {
use axum::Json;
use axum_login::AuthUser;
use easytier::proto::common::Void;
use crate::restful::{other_error, users::ChangePassword, HttpHandleError};
use super::*;
pub async fn change_password(
mut auth_session: AuthSession,
Json(req): Json<ChangePassword>,
) -> Result<Json<Void>, HttpHandleError> {
if let Err(e) = auth_session
.backend
.change_password(auth_session.user.as_ref().unwrap().id(), &req)
.await
{
tracing::error!("Failed to change password: {:?}", e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(other_error(format!("{:?}", e))),
));
}
let _ = auth_session.logout().await;
Ok(Void::default().into())
}
}
mod post {
use axum::Json;
use easytier::proto::common::Void;
use crate::restful::{
captcha::extension::{axum_tower_sessions::CaptchaAxumTowerSessionStaticExt, CaptchaUtil},
other_error,
users::RegisterNewUser,
HttpHandleError,
};
use super::*;
pub async fn login(
mut auth_session: AuthSession,
Json(creds): Json<Credentials>,
) -> Result<Json<Void>, HttpHandleError> {
let user = match auth_session.authenticate(creds.clone()).await {
Ok(Some(user)) => user,
Ok(None) => {
return Err((
StatusCode::UNAUTHORIZED,
Json::from(other_error("Invalid credentials")),
));
}
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(other_error(format!("{:?}", e))),
))
}
};
if let Err(e) = auth_session.login(&user).await {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(other_error(format!("{:?}", e))),
));
}
Ok(Void::default().into())
}
pub async fn register(
auth_session: AuthSession,
captcha_session: tower_sessions::Session,
Json(req): Json<RegisterNewUser>,
) -> Result<Json<Void>, HttpHandleError> {
// 调用CaptchaUtil的静态方法验证验证码是否正确
if !CaptchaUtil::ver(&req.captcha, &captcha_session).await {
return Err((
StatusCode::BAD_REQUEST,
other_error(format!("captcha verify error, input: {}", req.captcha)).into(),
));
}
if let Err(e) = auth_session.backend.register_new_user(&req).await {
tracing::error!("Failed to register new user: {:?}", e);
return Err((
StatusCode::BAD_REQUEST,
other_error(format!("{:?}", e)).into(),
));
}
Ok(Void::default().into())
}
}
mod get {
use crate::restful::{
captcha::{
captcha::spec::SpecCaptcha,
extension::{axum_tower_sessions::CaptchaAxumTowerSessionExt as _, CaptchaUtil},
NewCaptcha as _,
},
other_error, HttpHandleError,
};
use axum::{response::Response, Json};
use easytier::proto::common::Void;
use tower_sessions::Session;
use super::*;
pub async fn logout(mut auth_session: AuthSession) -> Result<Json<Void>, HttpHandleError> {
match auth_session.logout().await {
Ok(_) => Ok(Json(Void::default())),
Err(e) => {
tracing::error!("Failed to logout: {:?}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(other_error(format!("{:?}", e))),
))
}
}
}
pub async fn get_captcha(session: Session) -> Result<Response, HttpHandleError> {
let mut captcha: CaptchaUtil<SpecCaptcha> = CaptchaUtil::with_size_and_len(127, 48, 4);
match captcha.out(&session).await {
Ok(response) => Ok(response),
Err(e) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(other_error(format!("{:?}", e))),
)),
}
}
}
@@ -0,0 +1,308 @@
use super::super::base::randoms::Randoms;
use super::super::utils::color::Color;
use super::super::utils::font;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use rusttype::Font;
use std::fmt::Debug;
use std::io::Write;
use std::sync::Arc;
/// 验证码抽象类
pub(crate) struct Captcha {
/// 随机数工具类
pub(crate) randoms: Randoms,
/// 常用颜色
color: Vec<Color>,
/// 字体名称
font_names: [&'static str; 1],
/// 验证码的字体
font_name: String,
/// 验证码的字体大小
font_size: f32,
/// 验证码随机字符长度
pub len: usize,
/// 验证码显示宽度
pub width: i32,
/// 验证码显示高度
pub height: i32,
/// 验证码类型
char_type: CaptchaType,
/// 当前验证码
pub(crate) chars: Option<String>,
}
/// 验证码文本类型 The character type of the captcha
pub enum CaptchaType {
/// 字母数字混合
TypeDefault = 1,
/// 纯数字
TypeOnlyNumber,
/// 纯字母
TypeOnlyChar,
/// 纯大写字母
TypeOnlyUpper,
/// 纯小写字母
TypeOnlyLower,
/// 数字大写字母
TypeNumAndUpper,
}
/// 内置字体 Fonts shipped with the library
pub enum CaptchaFont {
/// actionj
Font1,
/// epilog
Font2,
/// fresnel
Font3,
/// headache
Font4,
/// lexo
Font5,
/// prefix
Font6,
/// progbot
Font7,
/// ransom
Font8,
/// robot
Font9,
/// scandal
Font10,
}
impl Captcha {
/// 生成随机验证码
pub fn alphas(&mut self) -> Vec<char> {
let mut cs = vec!['\0'; self.len];
for i in 0..self.len {
match self.char_type {
CaptchaType::TypeDefault => cs[i] = self.randoms.alpha(),
CaptchaType::TypeOnlyNumber => {
cs[i] = self.randoms.alpha_under(self.randoms.num_max_index)
}
CaptchaType::TypeOnlyChar => {
cs[i] = self
.randoms
.alpha_between(self.randoms.char_min_index, self.randoms.char_max_index)
}
CaptchaType::TypeOnlyUpper => {
cs[i] = self
.randoms
.alpha_between(self.randoms.upper_min_index, self.randoms.upper_max_index)
}
CaptchaType::TypeOnlyLower => {
cs[i] = self
.randoms
.alpha_between(self.randoms.lower_min_index, self.randoms.lower_max_index)
}
CaptchaType::TypeNumAndUpper => {
cs[i] = self.randoms.alpha_under(self.randoms.upper_max_index)
}
}
}
self.chars = Some(cs.iter().collect());
cs
}
/// 获取当前的验证码
pub fn text(&mut self) -> String {
self.check_alpha();
self.chars.clone().unwrap()
}
/// 获取当前验证码的字符数组
pub fn text_char(&mut self) -> Vec<char> {
self.check_alpha();
self.chars.clone().unwrap().chars().collect()
}
/// 检查验证码是否生成,没有则立即生成
pub fn check_alpha(&mut self) {
if self.chars.is_none() {
self.alphas();
}
}
pub fn get_font(&mut self) -> Arc<Font> {
if let Some(font) = font::get_font(&self.font_name) {
font
} else {
font::get_font(self.font_names[0]).unwrap()
}
}
pub fn get_font_size(&mut self) -> f32 {
self.font_size
}
pub fn set_font_by_enum(&mut self, font: CaptchaFont, size: Option<f32>) {
let font_name = self.font_names[font as usize];
self.font_name = font_name.into();
self.font_size = size.unwrap_or(32.);
}
}
/// 初始化验证码的抽象方法 Traits for initialize a Captcha instance.
pub trait NewCaptcha
where
Self: Sized,
{
/// 用默认参数初始化
///
/// Initialize the Captcha with the default properties.
fn new() -> Self;
/// 使用输出图像大小初始化
///
/// Initialize the Captcha with the size of output image.
fn with_size(width: i32, height: i32) -> Self;
/// 使用输出图像大小和验证码字符长度初始化
///
/// Initialize the Captcha with the size of output image and the character length of the Captcha.
///
/// <br/>
///
/// 特别地/In particular:
///
/// - 对算术验证码[ArithmeticCaptcha](crate::captcha::arithmetic::ArithmeticCaptcha)而言,这里的`len`是验证码中数字的数量。
/// For [ArithmeticCaptcha](crate::captcha::arithmetic::ArithmeticCaptcha), the `len` presents the count of the digits
/// in the Captcha.
fn with_size_and_len(width: i32, height: i32, len: usize) -> Self;
/// 使用完整的参数来初始化,包括输出图像大小、验证码字符长度和输出字体及其大小
///
/// Initialize the Captcha with full properties, including the size of output image, the character length of the Captcha,
/// and the font used in Captcha with the font size.
///
/// 关于`len`字段的注意事项,请参见[with_size_and_len](Self::with_size_and_len)中的说明。Refer to the document of
/// [with_size_and_len](Self::with_size_and_len) for the precautions of the `len` property.
fn with_all(width: i32, height: i32, len: usize, font: CaptchaFont, font_size: f32) -> Self;
}
impl NewCaptcha for Captcha {
fn new() -> Self {
let color = [
(0, 135, 255),
(51, 153, 51),
(255, 102, 102),
(255, 153, 0),
(153, 102, 0),
(153, 102, 153),
(51, 153, 153),
(102, 102, 255),
(0, 102, 204),
(204, 51, 51),
(0, 153, 204),
(0, 51, 102),
]
.iter()
.map(|v| (*v).into())
.collect();
let font_names = ["robot.ttf"];
let font_name = font_names[0].into();
let font_size = 32.;
let len = 5;
let width = 130;
let height = 48;
let char_type = CaptchaType::TypeDefault;
let chars = None;
Self {
randoms: Randoms::new(),
color,
font_names,
font_name,
font_size,
len,
width,
height,
char_type,
chars,
}
}
fn with_size(width: i32, height: i32) -> Self {
let mut _self = Self::new();
_self.width = width;
_self.height = height;
_self
}
fn with_size_and_len(width: i32, height: i32, len: usize) -> Self {
let mut _self = Self::new();
_self.width = width;
_self.height = height;
_self.len = len;
_self
}
fn with_all(width: i32, height: i32, len: usize, font: CaptchaFont, font_size: f32) -> Self {
let mut _self = Self::new();
_self.width = width;
_self.height = height;
_self.len = len;
_self.set_font_by_enum(font, None);
_self.font_size = font_size;
_self
}
}
/// 验证码的抽象方法 Traits which a Captcha must implements.
pub trait AbstractCaptcha: NewCaptcha {
/// 错误类型
type Error: std::error::Error + Debug + Send + Sync + 'static;
/// 输出验证码到指定位置
///
/// Write the Captcha image to the specified place.
fn out(&mut self, out: impl Write) -> Result<(), Self::Error>;
/// 获取验证码中的字符(即正确答案)
///
/// Get the characters (i.e. the correct answer) of the Captcha
fn get_chars(&mut self) -> Vec<char>;
/// 输出Base64编码。注意,返回值会带编码头(例如`data:image/png;base64,`),可以直接在浏览器中显示;如不需要编码头,
/// 请使用[base64_with_head](Self::base64_with_head)方法并传入空参数以去除编码头。
///
/// Get the Base64 encoded image. Reminds: the returned Base64 strings will begin with an encoding head like
/// `data:image/png;base64,`, which make it possible to display in browsers directly. If you don't need it, you may
/// use [base64_with_head](Self::base64_with_head) and pass a null string.
fn base64(&mut self) -> Result<String, Self::Error>;
/// 获取验证码的MIME类型
///
/// Get the MIME Content type of the Captcha.
fn get_content_type(&mut self) -> String;
/// 输出Base64编码(指定编码头)
///
/// Get the Base64 encoded image, with specified encoding head.
fn base64_with_head(&mut self, head: &str) -> Result<String, Self::Error> {
let mut output_stream = Vec::new();
self.out(&mut output_stream)?;
Ok(String::from(head) + &BASE64_STANDARD.encode(&output_stream))
}
}
@@ -0,0 +1,4 @@
//! Base traits
pub(crate) mod captcha;
pub(crate) mod randoms;
@@ -0,0 +1,86 @@
use rand::{random};
/// 随机数工具类
pub(crate) struct Randoms {
/// 定义验证码字符.去除了0、O、I、L等容易混淆的字母
pub alpha: [char; 54],
/// 数字的最大索引,不包括最大值
pub num_max_index: usize,
/// 字符的最小索引,包括最小值
pub char_min_index: usize,
/// 字符的最大索引,不包括最大值
pub char_max_index: usize,
/// 大写字符最小索引
pub upper_min_index: usize,
/// 大写字符最大索引
pub upper_max_index: usize,
/// 小写字母最小索引
pub lower_min_index: usize,
/// 小写字母最大索引
pub lower_max_index: usize,
}
impl Randoms {
pub fn new() -> Self {
// Defines the Captcha characters, removing characters like 0, O, I, l, etc.
let alpha = [
'2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J',
'K', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c',
'd', 'e', 'f', 'g', 'h', 'j', 'k', 'm', 'n', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w',
'x', 'y', 'z',
];
let num_max_index = 8;
let char_min_index = num_max_index;
let char_max_index = alpha.len();
let upper_min_index = char_min_index;
let upper_max_index = upper_min_index + 23;
let lower_min_index = upper_max_index;
let lower_max_index = char_max_index;
Self {
alpha,
num_max_index,
char_min_index,
char_max_index,
upper_min_index,
upper_max_index,
lower_min_index,
lower_max_index,
}
}
/// 产生两个数之间的随机数
pub fn num_between(&mut self, min: i32, max: i32) -> i32 {
min + (random::<usize>() % (max - min) as usize) as i32
}
/// 产生0-num的随机数,不包括num
pub fn num(&mut self, num: usize) -> usize {
random::<usize>() % num
}
/// 返回ALPHA中的随机字符
pub fn alpha(&mut self) -> char {
self.alpha[self.num(self.alpha.len())]
}
/// 返回ALPHA中第0位到第num位的随机字符
pub fn alpha_under(&mut self, num: usize) -> char {
self.alpha[self.num(num)]
}
/// 返回ALPHA中第min位到第max位的随机字符
pub fn alpha_between(&mut self, min: usize, max: usize) -> char {
self.alpha[self.num_between(min as i32, max as i32) as usize]
}
}
@@ -0,0 +1 @@
pub mod spec;
@@ -0,0 +1,318 @@
//! Static alphabetical PNG Captcha
//!
//! PNG格式验证码
//!
use super::super::base::captcha::{AbstractCaptcha, Captcha};
use super::super::{CaptchaFont, NewCaptcha};
use image::{ImageBuffer, Rgba};
use imageproc::drawing;
use rand::{rngs::ThreadRng, Rng};
use rusttype::{Font, Scale};
use std::io::{Cursor, Write};
use std::sync::Arc;
mod color {
use image::Rgba;
use rand::{rngs::ThreadRng, Rng};
pub fn gen_background_color(rng: &mut ThreadRng) -> Rgba<u8> {
let red = rng.gen_range(200..=255);
let green = rng.gen_range(200..=255);
let blue = rng.gen_range(200..=255);
//let a=rng.gen_range(0..255);
Rgba([red, green, blue, 255])
}
pub fn gen_text_color(rng: &mut ThreadRng) -> Rgba<u8> {
let red = rng.gen_range(0..=150);
let green = rng.gen_range(0..=150);
let blue = rng.gen_range(0..=150);
Rgba([red, green, blue, 255])
}
pub fn gen_line_color(rng: &mut ThreadRng) -> Rgba<u8> {
let red = rng.gen_range(100..=255);
let green = rng.gen_range(100..=255);
let blue = rng.gen_range(100..=255);
Rgba([red, green, blue, 255])
}
}
///the builder of captcha
pub struct CaptchaBuilder<'a, 'b> {
///captcha image width
pub width: u32,
///captcha image height
pub height: u32,
///random string length.
pub length: u32,
///source is a unicode which is the rand string from.
pub source: String,
///image background color (optional)
pub background_color: Option<Rgba<u8>>,
///fonts collection for text
pub fonts: &'b [Arc<Font<'a>>],
///The maximum number of lines to draw behind of the image
pub max_behind_lines: Option<u32>,
///The maximum number of lines to draw in front of the image
pub max_front_lines: Option<u32>,
///The maximum number of ellipse lines to draw in front of the image
pub max_ellipse_lines: Option<u32>,
}
impl<'a, 'b> Default for CaptchaBuilder<'a, 'b> {
fn default() -> Self {
Self {
width: 150,
height: 40,
length: 5,
source: String::from("1234567890qwertyuioplkjhgfdsazxcvbnm"),
background_color: None,
fonts: &[],
max_behind_lines: None,
max_front_lines: None,
max_ellipse_lines: None,
}
}
}
impl<'a, 'b> CaptchaBuilder<'a, 'b> {
fn write_phrase(
&self,
image: &mut ImageBuffer<Rgba<u8>, Vec<u8>>,
rng: &mut ThreadRng,
phrase: &str,
) {
//println!("phrase={}", phrase);
//println!("width={}, height={}", self.width, self.height);
let font_size = (self.width as f32) / (self.length as f32) - rng.gen_range(1.0..=4.0);
let scale = Scale::uniform(font_size);
if self.fonts.is_empty() {
panic!("no fonts loaded");
}
let font_index = rng.gen_range(0..self.fonts.len());
let font = &self.fonts[font_index];
let glyphs: Vec<_> = font
.layout(phrase, scale, rusttype::point(0.0, 0.0))
.collect();
let text_height = {
let v_metrics = font.v_metrics(scale);
(v_metrics.ascent - v_metrics.descent).ceil() as u32
};
let text_width = {
let min_x = glyphs.first().unwrap().pixel_bounding_box().unwrap().min.x;
let max_x = glyphs.last().unwrap().pixel_bounding_box().unwrap().max.x;
let last_x_pos = glyphs.last().unwrap().position().x as i32;
(max_x + last_x_pos - min_x) as u32
};
let node_width = text_width / self.length;
//println!("text_width={}, text_height={}", text_width, text_height);
let mut x = ((self.width as i32) - (text_width as i32)) / 2;
let y = ((self.height as i32) - (text_height as i32)) / 2;
//
for s in phrase.chars() {
let text_color = color::gen_text_color(rng);
let offset = rng.gen_range(-5..=5);
//println!("x={}, y={}", x, y);
drawing::draw_text_mut(
image,
text_color,
x,
y + offset,
scale,
font,
&s.to_string(),
);
x += node_width as i32;
}
}
fn draw_line(&self, image: &mut ImageBuffer<Rgba<u8>, Vec<u8>>, rng: &mut ThreadRng) {
let line_color = color::gen_line_color(rng);
let is_h = rng.gen();
let (start, end) = if is_h {
let xa = rng.gen_range(0.0..(self.width as f32) / 2.0);
let ya = rng.gen_range(0.0..(self.height as f32));
let xb = rng.gen_range((self.width as f32) / 2.0..(self.width as f32));
let yb = rng.gen_range(0.0..(self.height as f32));
((xa, ya), (xb, yb))
} else {
let xa = rng.gen_range(0.0..(self.width as f32));
let ya = rng.gen_range(0.0..(self.height as f32) / 2.0);
let xb = rng.gen_range(0.0..(self.width as f32));
let yb = rng.gen_range((self.height as f32) / 2.0..(self.height as f32));
((xa, ya), (xb, yb))
};
let thickness = rng.gen_range(2..4);
for i in 0..thickness {
let offset = i as f32;
if is_h {
drawing::draw_line_segment_mut(
image,
(start.0, start.1 + offset),
(end.0, end.1 + offset),
line_color,
);
} else {
drawing::draw_line_segment_mut(
image,
(start.0 + offset, start.1),
(end.0 + offset, end.1),
line_color,
);
}
}
}
fn draw_ellipse(&self, image: &mut ImageBuffer<Rgba<u8>, Vec<u8>>, rng: &mut ThreadRng) {
let line_color = color::gen_line_color(rng);
let thickness = rng.gen_range(2..4);
for i in 0..thickness {
let center = (
rng.gen_range(-(self.width as i32) / 4..(self.width as i32) * 5 / 4),
rng.gen_range(-(self.height as i32) / 4..(self.height as i32) * 5 / 4),
);
drawing::draw_hollow_ellipse_mut(
image,
(center.0, center.1 + i),
(self.width * 6 / 7) as i32,
(self.height * 5 / 8) as i32,
line_color,
);
}
}
fn build_image(&self, phrase: String) -> ImageBuffer<Rgba<u8>, Vec<u8>> {
let mut rng = rand::thread_rng();
let bgc = match self.background_color {
Some(v) => v,
None => color::gen_background_color(&mut rng),
};
let mut image = ImageBuffer::from_fn(self.width, self.height, |_, _| bgc);
//draw behind line
let square = self.width * self.height;
let effects = match self.max_behind_lines {
Some(s) => {
if s > 0 {
rng.gen_range(square / 3000..square / 2000).min(s)
} else {
0
}
}
None => rng.gen_range(square / 3000..square / 2000),
};
for _ in 0..effects {
self.draw_line(&mut image, &mut rng);
}
//write phrase
self.write_phrase(&mut image, &mut rng, &phrase);
//draw front line
let effects = match self.max_front_lines {
Some(s) => {
if s > 0 {
rng.gen_range(square / 3000..=square / 2000).min(s)
} else {
0
}
}
None => rng.gen_range(square / 3000..=square / 2000),
};
for _ in 0..effects {
self.draw_line(&mut image, &mut rng);
}
//draw ellipse
let effects = match self.max_front_lines {
Some(s) => {
if s > 0 {
rng.gen_range(square / 4000..=square / 3000).min(s)
} else {
0
}
}
None => rng.gen_range(square / 4000..=square / 3000),
};
for _ in 0..effects {
self.draw_ellipse(&mut image, &mut rng);
}
image
}
}
/// PNG格式验证码
pub struct SpecCaptcha {
pub(crate) captcha: Captcha,
}
impl NewCaptcha for SpecCaptcha {
fn new() -> Self {
Self {
captcha: Captcha::new(),
}
}
fn with_size(width: i32, height: i32) -> Self {
Self {
captcha: Captcha::with_size(width, height),
}
}
fn with_size_and_len(width: i32, height: i32, len: usize) -> Self {
Self {
captcha: Captcha::with_size_and_len(width, height, len),
}
}
fn with_all(width: i32, height: i32, len: usize, font: CaptchaFont, font_size: f32) -> Self {
Self {
captcha: Captcha::with_all(width, height, len, font, font_size),
}
}
}
impl AbstractCaptcha for SpecCaptcha {
type Error = image::ImageError;
fn out(&mut self, mut out: impl Write) -> Result<(), Self::Error> {
let phrase = self.captcha.text_char();
let builder = CaptchaBuilder {
width: self.captcha.width as u32,
height: self.captcha.height as u32,
length: self.captcha.len as u32,
background_color: None,
fonts: &[self.captcha.get_font()],
max_behind_lines: Some(0),
max_front_lines: Some(0),
max_ellipse_lines: Some(0),
..Default::default()
};
let image = builder.build_image(phrase.iter().collect());
let format = image::ImageOutputFormat::Png;
let mut raw_data: Vec<u8> = Vec::new();
image.write_to(&mut Cursor::new(&mut raw_data), format)?;
out.write_all(&raw_data)?;
Ok(())
}
fn get_chars(&mut self) -> Vec<char> {
self.captcha.text_char()
}
fn base64(&mut self) -> Result<String, Self::Error> {
self.base64_with_head("data:image/png;base64,")
}
fn get_content_type(&mut self) -> String {
"image/png".into()
}
}
#[cfg(test)]
mod test {
#[test]
fn it_works() {}
}
@@ -0,0 +1,69 @@
//! Axum & Tower_sessions 组合
//!
//! - Axum: [axum](https://docs.rs/axum)
//! - Tower Sessions: [axum](https://docs.rs/tower-sessions)
use super::AbstractCaptcha;
use super::CaptchaUtil;
use async_trait::async_trait;
use axum::response::Response;
use std::fmt::Debug;
use tower_sessions::Session;
const CAPTCHA_KEY: &'static str = "ez-captcha";
/// Axum & Tower_Sessions
#[async_trait]
pub trait CaptchaAxumTowerSessionExt {
/// 错误类型
type Error: Debug + Send + Sync + 'static;
/// 将验证码图片写入响应,并将用户的验证码信息保存至Session中
///
/// Write the Captcha Image into the response and save the Captcha information into the user's Session.
async fn out(&mut self, session: &Session) -> Result<Response, Self::Error>;
}
/// Axum & Tower_Sessions - 静态方法
#[async_trait]
pub trait CaptchaAxumTowerSessionStaticExt {
/// 验证验证码,返回的布尔值代表验证码是否正确
///
/// Verify the Captcha code, and return whether user's code is correct.
async fn ver(code: &str, session: &Session) -> bool {
match session.get::<String>(CAPTCHA_KEY).await {
Ok(Some(ans)) => ans.to_ascii_lowercase() == code.to_ascii_lowercase(),
_ => false,
}
}
/// 清除Session中的验证码
///
/// Clear the Captcha in the session.
async fn clear(session: &Session) {
if session.remove::<String>(CAPTCHA_KEY).await.is_err() {
tracing::warn!("Exception occurs during clearing the session.")
}
}
}
#[async_trait]
impl<T: AbstractCaptcha + Send> CaptchaAxumTowerSessionExt for CaptchaUtil<T> {
type Error = anyhow::Error;
async fn out(&mut self, session: &Session) -> Result<Response, Self::Error> {
let mut data = vec![];
self.captcha_instance.out(&mut data)?;
let ans: String = self.captcha_instance.get_chars().iter().collect();
session.insert(CAPTCHA_KEY, ans).await?;
let resp = Response::builder()
.header("Content-Type", self.captcha_instance.get_content_type())
.body(data.into())?;
Ok(resp)
}
}
#[async_trait]
impl CaptchaAxumTowerSessionStaticExt for CaptchaUtil {}
@@ -0,0 +1,41 @@
pub mod axum_tower_sessions;
use super::base::captcha::AbstractCaptcha;
use super::captcha::spec::SpecCaptcha;
use super::{CaptchaFont, NewCaptcha};
/// 验证码工具类 - Captcha Utils
///
/// 默认使用[SpecCaptcha](静态PNG字母验证码)作为验证码实现,用户也可以指定其他实现了[AbstractCaptcha]的类型。
///
/// Use [SpecCaptcha] (static PNG-format alphabetical Captcha) as the default implement of the Captcha service. Users may use other implementation of [AbstractCaptcha] they prefer.
///
pub struct CaptchaUtil<T: AbstractCaptcha = SpecCaptcha> {
captcha_instance: T,
}
impl<T: AbstractCaptcha> NewCaptcha for CaptchaUtil<T> {
fn new() -> Self {
Self {
captcha_instance: T::new(),
}
}
fn with_size(width: i32, height: i32) -> Self {
Self {
captcha_instance: T::with_size(width, height),
}
}
fn with_size_and_len(width: i32, height: i32, len: usize) -> Self {
Self {
captcha_instance: T::with_size_and_len(width, height, len),
}
}
fn with_all(width: i32, height: i32, len: usize, font: CaptchaFont, font_size: f32) -> Self {
Self {
captcha_instance: T::with_all(width, height, len, font, font_size),
}
}
}
+134
View File
@@ -0,0 +1,134 @@
//! Rust图形验证码,由Java同名开源库[whvcse/EasyCaptcha](https://github.com/ele-admin/EasyCaptcha)移植而来👏,100%纯Rust实现,支持gif、算术等类型。
//!
//! Rust Captcha library, which is ported from Java's same-name library [whvcse/EasyCaptcha](https://github.com/ele-admin/EasyCaptcha),
//! implemented in 100% pure Rust, supporting GIF and arithmetic problems.
//!
//! <br/>
//!
//! 目前已适配框架 / Frameworks which is adapted now:
//!
//! - `axum` + `tower-sessions`
//!
//! 更多框架欢迎您提交PR,参与适配🙏 PR for new frameworks are welcomed
//!
//! <br/>
//!
//! ## 安装 Install
//!
//! 请参考Github README为Linux系统安装依赖。
//!
//! If you are compiling this project in linux, please refer to README in repository to install
//! dependencies into you system.
//!
//! ## 使用 Usage
//!
//! 若您正在使用的框架已适配,您可直接通过[CaptchaUtil](extension::CaptchaUtil)类(并导入相应框架的trait)来使用验证码:
//!
//! If your framework is adapted, you can just use [CaptchaUtil](extension::CaptchaUtil) and importing traits of your
//! framework to use the Captcha:
//!
//! ```
//! use std::collections::HashMap;
//! use axum::extract::Query;
//! use axum::response::IntoResponse;
//! use easy_captcha::captcha::gif::GifCaptcha;
//! use easy_captcha::extension::axum_tower_sessions::{
//! CaptchaAxumTowerSessionExt, CaptchaAxumTowerSessionStaticExt,
//! };
//! use easy_captcha::extension::CaptchaUtil;
//! use easy_captcha::NewCaptcha;
//!
//! /// 接口:获取验证码
//! /// Handler: Get a captcha
//! async fn get_captcha(session: tower_sessions::Session) -> Result<axum::response::Response, axum::http::StatusCode> {
//! let mut captcha: CaptchaUtil<GifCaptcha> = CaptchaUtil::new();
//! match captcha.out(&session).await {
//! Ok(response) => Ok(response),
//! Err(_) => Err(axum::http::StatusCode::INTERNAL_SERVER_ERROR),
//! }
//! }
//!
//! /// 接口:验证验证码
//! /// Handler: Verify captcha codes
//! async fn verify_captcha(
//! session: tower_sessions::Session,
//! Query(query): Query<HashMap<String, String>>,
//! ) -> axum::response::Response {
//! // 从请求中获取验证码 Getting code from the request.
//! if let Some(code) = query.get("code") {
//! // 调用CaptchaUtil的静态方法验证验证码是否正确 Use a static method in CaptchaUtil to verify.
//! if CaptchaUtil::ver(code, &session).await {
//! CaptchaUtil::clear(&session).await; // 如果愿意的话,你可以从Session中清理掉验证码 You may clear the Captcha from the Session if you want
//! "Your code is valid, thank you.".into_response()
//! } else {
//! "Your code is not valid, I'm sorry.".into_response()
//! }
//! } else {
//! "You didn't provide the code.".into_response()
//! }
//! }
//! ```
//!
//! 您也可以自定义验证码的各项属性
//!
//! You can also specify properties of the Captcha.
//!
//! ```rust
//! use easy_captcha::captcha::gif::GifCaptcha;
//! use easy_captcha::extension::axum_tower_sessions::CaptchaAxumTowerSessionExt;
//! use easy_captcha::extension::CaptchaUtil;
//! use easy_captcha::NewCaptcha;
//!
//! async fn get_captcha(session: tower_sessions::Session) -> Result<axum::response::Response, axum::http::StatusCode> {
//! let mut captcha: CaptchaUtil<GifCaptcha> = CaptchaUtil::with_size_and_len(127, 48, 4);
//! match captcha.out(&session).await {
//! Ok(response) => Ok(response),
//! Err(_) => Err(axum::http::StatusCode::INTERNAL_SERVER_ERROR),
//! }
//! }
//! ```
//!
//! 项目当前提供了三种验证码实现:[SpecCaptcha](captcha::spec::SpecCaptcha)(静态PNG)、[GifCaptcha](captcha::gif::GifCaptcha)(动态GIF
//! 、[ArithmeticCaptcha](captcha::arithmetic::ArithmeticCaptcha)(算术PNG),您可按需使用。
//!
//! There is three implementation of Captcha currently, which are [SpecCaptcha](captcha::spec::SpecCaptcha)(static PNG),
//! [GifCaptcha](captcha::gif::GifCaptcha)(GIF), [ArithmeticCaptcha](captcha::arithmetic::ArithmeticCaptcha)(Arithmetic problems),
//! you can use them according to your need.
//!
//! <br/>
//!
//! 自带字体效果 / Fonts shipped
//!
//! | 字体/Fonts | 效果/Preview |
//! |---------------------|------------------------------------------------|
//! | CaptchaFont::Font1 | ![](https://s2.ax1x.com/2019/08/23/msMe6U.png) |
//! | CaptchaFont::Font2 | ![](https://s2.ax1x.com/2019/08/23/msMAf0.png) |
//! | CaptchaFont::Font3 | ![](https://s2.ax1x.com/2019/08/23/msMCwj.png) |
//! | CaptchaFont::Font4 | ![](https://s2.ax1x.com/2019/08/23/msM9mQ.png) |
//! | CaptchaFont::Font5 | ![](https://s2.ax1x.com/2019/08/23/msKz6S.png) |
//! | CaptchaFont::Font6 | ![](https://s2.ax1x.com/2019/08/23/msKxl8.png) |
//! | CaptchaFont::Font7 | ![](https://s2.ax1x.com/2019/08/23/msMPTs.png) |
//! | CaptchaFont::Font8 | ![](https://s2.ax1x.com/2019/08/23/msMmXF.png) |
//! | CaptchaFont::Font9 | ![](https://s2.ax1x.com/2019/08/23/msMVpV.png) |
//! | CaptchaFont::Font10 | ![](https://s2.ax1x.com/2019/08/23/msMZlT.png) |
//!
#![warn(missing_docs)]
#![allow(dead_code)]
pub(crate) mod base;
pub mod captcha;
pub mod extension;
mod utils;
pub use base::captcha::*;
// #[cfg(test)]
// mod tests {
// use super::*;
//
// #[test]
// fn it_works() {
//
// }
// }
@@ -0,0 +1,53 @@
//! RGBA颜色
use std::fmt::{Debug, Formatter};
#[derive(Clone)]
pub struct Color(f64, f64, f64, f64);
impl Color {
pub fn set_alpha(&mut self, a: f64) {
self.3 = a;
}
}
impl Debug for Color {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Color")
.field("r", &self.0)
.field("g", &self.1)
.field("b", &self.2)
.field("a", &self.3)
.finish()
}
}
impl From<(u8, u8, u8)> for Color {
fn from(value: (u8, u8, u8)) -> Self {
Self(
value.0 as f64 / 255.0,
value.1 as f64 / 255.0,
value.2 as f64 / 255.0,
1.0,
)
}
}
impl Into<(u8, u8, u8, u8)> for Color {
fn into(self) -> (u8, u8, u8, u8) {
(
(self.0 * 255.0) as u8,
(self.1 * 255.0) as u8,
(self.2 * 255.0) as u8,
(self.3 * 255.0) as u8,
)
}
}
impl Into<u32> for Color {
fn into(self) -> u32 {
let color: (u8, u8, u8, u8) = self.into();
(color.0 as u32) << 24 + (color.1 as u32) << 16 + (color.2 as u32) << 8 + (color.3 as u32)
}
}
impl Color {}
@@ -0,0 +1,45 @@
use rust_embed::RustEmbed;
use rusttype::Font;
use std::error::Error;
use std::sync::Arc;
#[derive(RustEmbed)]
#[folder = "resources/"]
struct FontAssets;
// lazy_static! {
// pub(crate) static ref FONTS: RwLock<HashMap<String, Arc<Font>>> = Default::default();
// }
pub fn get_font(font_name: &str) -> Option<Arc<Font>> {
// let fonts_cell = FONTS.get_or_init(|| Default::default());
// let guard = fonts_cell.read();
//
// if guard.contains_key(font_name) {
// Some(guard.get(font_name).unwrap().clone())
// } else {
// drop(guard);
if let Ok(Some(font)) = load_font(font_name) {
// let mut guard = fonts_cell.write();
let font = Arc::new(font);
// guard.insert(String::from(font_name), font.clone());
Some(font)
} else {
None
}
// }
}
pub fn load_font(font_name: &str) -> Result<Option<Font>, Box<dyn Error>> {
match FontAssets::get(font_name) {
Some(assets) => {
let font = Font::try_from_vec(Vec::from(assets.data)).unwrap();
Ok(Some(font))
}
None => {
tracing::error!("Unable to find the specified font.");
Ok(None)
}
}
}
@@ -0,0 +1,4 @@
//! Utilities
pub(crate) mod color;
pub(crate) mod font;
+97 -181
View File
@@ -1,23 +1,40 @@
use std::vec;
mod auth;
pub(crate) mod captcha;
mod network;
mod users;
use std::{net::SocketAddr, sync::Arc};
use axum::extract::{Path, Query};
use axum::http::StatusCode;
use axum::routing::post;
use axum::{extract::State, routing::get, Json, Router};
use easytier::proto::{self, rpc_types, web::*};
use easytier::{common::scoped_task::ScopedTask, proto::rpc_types::controller::BaseController};
use axum_login::tower_sessions::{ExpiredDeletion, SessionManagerLayer};
use axum_login::{login_required, AuthManagerLayerBuilder, AuthzBackend};
use axum_messages::MessagesManagerLayer;
use easytier::common::scoped_task::ScopedTask;
use easytier::proto::{self, rpc_types};
use network::NetworkApi;
use sea_orm::DbErr;
use tokio::net::TcpListener;
use tower_sessions::cookie::time::Duration;
use tower_sessions::cookie::Key;
use tower_sessions::Expiry;
use tower_sessions_sqlx_store::SqliteStore;
use users::{AuthSession, Backend};
use crate::client_manager::session::Session;
use crate::client_manager::storage::StorageToken;
use crate::client_manager::ClientManager;
use crate::db::Db;
pub struct RestfulServer {
bind_addr: SocketAddr,
client_mgr: Arc<ClientManager>,
db: Db,
serve_task: Option<ScopedTask<()>>,
delete_task: Option<ScopedTask<tower_sessions::session_store::Result<()>>>,
network_api: NetworkApi,
}
type AppStateInner = Arc<ClientManager>;
@@ -26,52 +43,44 @@ type AppState = State<AppStateInner>;
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ListSessionJsonResp(Vec<StorageToken>);
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ValidateConfigJsonReq {
config: String,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct RunNetworkJsonReq {
config: String,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ColletNetworkInfoJsonReq {
inst_ids: Option<Vec<uuid::Uuid>>,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct RemoveNetworkJsonReq {
inst_ids: Vec<uuid::Uuid>,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ListNetworkInstanceIdsJsonResp(Vec<uuid::Uuid>);
type Error = proto::error::Error;
type ErrorKind = proto::error::error::ErrorKind;
pub type Error = proto::error::Error;
pub type ErrorKind = proto::error::error::ErrorKind;
type RpcError = rpc_types::error::Error;
type HttpHandleError = (StatusCode, Json<Error>);
fn convert_rpc_error(e: RpcError) -> (StatusCode, Json<Error>) {
let status_code = match &e {
RpcError::ExecutionError(_) => StatusCode::BAD_REQUEST,
RpcError::Timeout(_) => StatusCode::GATEWAY_TIMEOUT,
_ => StatusCode::BAD_GATEWAY,
};
let error = Error::from(&e);
(status_code, Json(error))
pub fn other_error<T: ToString>(error_message: T) -> Error {
Error {
error_kind: Some(ErrorKind::OtherError(proto::error::OtherError {
error_message: error_message.to_string(),
})),
}
}
pub fn convert_db_error(e: DbErr) -> HttpHandleError {
(
StatusCode::INTERNAL_SERVER_ERROR,
other_error(format!("DB Error: {:#}", e)).into(),
)
}
impl RestfulServer {
pub fn new(bind_addr: SocketAddr, client_mgr: Arc<ClientManager>) -> Self {
pub async fn new(
bind_addr: SocketAddr,
client_mgr: Arc<ClientManager>,
db: Db,
) -> anyhow::Result<Self> {
assert!(client_mgr.is_running());
RestfulServer {
let network_api = NetworkApi::new();
Ok(RestfulServer {
bind_addr,
client_mgr,
db,
serve_task: None,
}
delete_task: None,
network_api,
})
}
async fn get_session_by_machine_id(
@@ -79,162 +88,69 @@ impl RestfulServer {
machine_id: &uuid::Uuid,
) -> Result<Arc<Session>, HttpHandleError> {
let Some(result) = client_mgr.get_session_by_machine_id(machine_id) else {
return Err((
StatusCode::NOT_FOUND,
Error {
error_kind: Some(ErrorKind::OtherError(proto::error::OtherError {
error_message: "No such session".to_string(),
})),
}
.into(),
));
return Err((StatusCode::NOT_FOUND, other_error("No such session").into()));
};
Ok(result)
}
async fn handle_list_all_sessions(
auth_session: AuthSession,
State(client_mgr): AppState,
) -> Result<Json<ListSessionJsonResp>, HttpHandleError> {
let pers = auth_session
.backend
.get_group_permissions(auth_session.user.as_ref().unwrap())
.await
.unwrap();
println!("{:?}", pers);
let ret = client_mgr.list_sessions().await;
Ok(ListSessionJsonResp(ret).into())
}
async fn handle_validate_config(
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
Json(payload): Json<ValidateConfigJsonReq>,
) -> Result<(), HttpHandleError> {
let config = payload.config;
let result = Self::get_session_by_machine_id(&client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
c.validate_config(BaseController::default(), ValidateConfigRequest { config })
.await
.map_err(convert_rpc_error)?;
Ok(())
}
async fn handle_run_network_instance(
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
Json(payload): Json<RunNetworkJsonReq>,
) -> Result<(), HttpHandleError> {
let config = payload.config;
let result = Self::get_session_by_machine_id(&client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
c.run_network_instance(
BaseController::default(),
RunNetworkInstanceRequest { config },
)
.await
.map_err(convert_rpc_error)?;
Ok(())
}
async fn handle_collect_one_network_info(
State(client_mgr): AppState,
Path((machine_id, inst_id)): Path<(uuid::Uuid, uuid::Uuid)>,
) -> Result<Json<CollectNetworkInfoResponse>, HttpHandleError> {
let result = Self::get_session_by_machine_id(&client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let ret = c
.collect_network_info(
BaseController::default(),
CollectNetworkInfoRequest {
inst_ids: vec![inst_id.into()],
},
)
.await
.map_err(convert_rpc_error)?;
Ok(ret.into())
}
async fn handle_collect_network_info(
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
Query(payload): Query<ColletNetworkInfoJsonReq>,
) -> Result<Json<CollectNetworkInfoResponse>, HttpHandleError> {
let result = Self::get_session_by_machine_id(&client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let ret = c
.collect_network_info(
BaseController::default(),
CollectNetworkInfoRequest {
inst_ids: payload
.inst_ids
.unwrap_or_default()
.into_iter()
.map(Into::into)
.collect(),
},
)
.await
.map_err(convert_rpc_error)?;
Ok(ret.into())
}
async fn handle_list_network_instance_ids(
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
) -> Result<Json<ListNetworkInstanceIdsJsonResp>, HttpHandleError> {
let result = Self::get_session_by_machine_id(&client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let ret = c
.list_network_instance(BaseController::default(), ListNetworkInstanceRequest {})
.await
.map_err(convert_rpc_error)?;
Ok(
ListNetworkInstanceIdsJsonResp(ret.inst_ids.into_iter().map(Into::into).collect())
.into(),
)
}
async fn handle_remove_network_instance(
State(client_mgr): AppState,
Path((machine_id, inst_id)): Path<(uuid::Uuid, uuid::Uuid)>,
) -> Result<(), HttpHandleError> {
let result = Self::get_session_by_machine_id(&client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
c.delete_network_instance(
BaseController::default(),
DeleteNetworkInstanceRequest {
inst_ids: vec![inst_id.into()],
},
)
.await
.map_err(convert_rpc_error)?;
Ok(())
}
pub async fn start(&mut self) -> Result<(), anyhow::Error> {
let listener = TcpListener::bind(self.bind_addr).await.unwrap();
let listener = TcpListener::bind(self.bind_addr).await?;
// Session layer.
//
// This uses `tower-sessions` to establish a layer that will provide the session
// as a request extension.
let session_store = SqliteStore::new(self.db.inner());
session_store.migrate().await?;
self.delete_task.replace(
tokio::task::spawn(
session_store
.clone()
.continuously_delete_expired(tokio::time::Duration::from_secs(60)),
)
.into(),
);
// Generate a cryptographic key to sign the session cookie.
let key = Key::generate();
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(false)
.with_expiry(Expiry::OnInactivity(Duration::days(1)))
.with_signed(key);
// Auth service.
//
// This combines the session layer with our backend to establish the auth
// service which will provide the auth session as a request extension.
let backend = Backend::new(self.db.clone());
let auth_layer = AuthManagerLayerBuilder::new(backend, session_layer).build();
let app = Router::new()
.route("/api/v1/sessions", get(Self::handle_list_all_sessions))
.route(
"/api/v1/network/:machine-id/validate-config",
post(Self::handle_validate_config),
)
.route(
"/api/v1/network/:machine-id",
post(Self::handle_run_network_instance).get(Self::handle_list_network_instance_ids),
)
.route(
"/api/v1/network/:machine-id/info",
get(Self::handle_collect_network_info),
)
.route(
"/api/v1/network/:machine-id/:inst-id",
get(Self::handle_collect_one_network_info)
.delete(Self::handle_remove_network_instance),
)
.with_state(self.client_mgr.clone());
.merge(self.network_api.build_route())
.route_layer(login_required!(Backend))
.merge(auth::router())
.with_state(self.client_mgr.clone())
.layer(MessagesManagerLayer)
.layer(auth_layer)
.layer(tower_http::cors::CorsLayer::very_permissive());
let task = tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
+321
View File
@@ -0,0 +1,321 @@
use std::sync::Arc;
use axum::extract::{Path, Query};
use axum::http::StatusCode;
use axum::routing::{delete, post};
use axum::{extract::State, routing::get, Json, Router};
use axum_login::AuthUser;
use dashmap::DashSet;
use easytier::proto::common::Void;
use easytier::proto::rpc_types::controller::BaseController;
use easytier::proto::{self, web::*};
use crate::client_manager::session::Session;
use crate::client_manager::ClientManager;
use super::users::AuthSession;
use super::{
convert_db_error, AppState, AppStateInner, Error, ErrorKind, HttpHandleError, RpcError,
};
fn convert_rpc_error(e: RpcError) -> (StatusCode, Json<Error>) {
let status_code = match &e {
RpcError::ExecutionError(_) => StatusCode::BAD_REQUEST,
RpcError::Timeout(_) => StatusCode::GATEWAY_TIMEOUT,
_ => StatusCode::BAD_GATEWAY,
};
let error = Error::from(&e);
(status_code, Json(error))
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ValidateConfigJsonReq {
config: String,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct RunNetworkJsonReq {
config: String,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ColletNetworkInfoJsonReq {
inst_ids: Option<Vec<uuid::Uuid>>,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct RemoveNetworkJsonReq {
inst_ids: Vec<uuid::Uuid>,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ListNetworkInstanceIdsJsonResp(Vec<uuid::Uuid>);
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ListMachineItem {
client_url: Option<url::Url>,
info: Option<HeartbeatRequest>,
}
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ListMachineJsonResp {
machines: Vec<ListMachineItem>,
}
pub struct NetworkApi {}
impl NetworkApi {
pub fn new() -> Self {
Self {}
}
async fn get_session_by_machine_id(
auth_session: &AuthSession,
client_mgr: &ClientManager,
machine_id: &uuid::Uuid,
) -> Result<Arc<Session>, HttpHandleError> {
let Some(result) = client_mgr.get_session_by_machine_id(machine_id) else {
return Err((
StatusCode::NOT_FOUND,
Error {
error_kind: Some(ErrorKind::OtherError(proto::error::OtherError {
error_message: format!("No such session: {}", machine_id),
})),
}
.into(),
));
};
let Some(token) = result.get_token().await else {
return Err((
StatusCode::UNAUTHORIZED,
Error {
error_kind: Some(ErrorKind::OtherError(proto::error::OtherError {
error_message: "No token reported".to_string(),
})),
}
.into(),
));
};
if !auth_session
.user
.as_ref()
.map(|x| x.tokens.contains(&token.token))
.unwrap_or(false)
{
return Err((
StatusCode::FORBIDDEN,
Error {
error_kind: Some(ErrorKind::OtherError(proto::error::OtherError {
error_message: "Token mismatch".to_string(),
})),
}
.into(),
));
}
Ok(result)
}
async fn handle_validate_config(
auth_session: AuthSession,
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
Json(payload): Json<ValidateConfigJsonReq>,
) -> Result<Json<Void>, HttpHandleError> {
let config = payload.config;
let result =
Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
c.validate_config(BaseController::default(), ValidateConfigRequest { config })
.await
.map_err(convert_rpc_error)?;
Ok(Void::default().into())
}
async fn handle_run_network_instance(
auth_session: AuthSession,
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
Json(payload): Json<RunNetworkJsonReq>,
) -> Result<Json<Void>, HttpHandleError> {
let config = payload.config;
let result =
Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let resp = c
.run_network_instance(
BaseController::default(),
RunNetworkInstanceRequest {
inst_id: None,
config: config.clone(),
},
)
.await
.map_err(convert_rpc_error)?;
client_mgr
.db()
.insert_or_update_user_network_config(
auth_session.user.as_ref().unwrap().id(),
resp.inst_id.clone().unwrap_or_default().into(),
config,
)
.await
.map_err(convert_db_error)?;
Ok(Void::default().into())
}
async fn handle_collect_one_network_info(
auth_session: AuthSession,
State(client_mgr): AppState,
Path((machine_id, inst_id)): Path<(uuid::Uuid, uuid::Uuid)>,
) -> Result<Json<CollectNetworkInfoResponse>, HttpHandleError> {
let result =
Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let ret = c
.collect_network_info(
BaseController::default(),
CollectNetworkInfoRequest {
inst_ids: vec![inst_id.into()],
},
)
.await
.map_err(convert_rpc_error)?;
Ok(ret.into())
}
async fn handle_collect_network_info(
auth_session: AuthSession,
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
Query(payload): Query<ColletNetworkInfoJsonReq>,
) -> Result<Json<CollectNetworkInfoResponse>, HttpHandleError> {
let result =
Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let ret = c
.collect_network_info(
BaseController::default(),
CollectNetworkInfoRequest {
inst_ids: payload
.inst_ids
.unwrap_or_default()
.into_iter()
.map(Into::into)
.collect(),
},
)
.await
.map_err(convert_rpc_error)?;
Ok(ret.into())
}
async fn handle_list_network_instance_ids(
auth_session: AuthSession,
State(client_mgr): AppState,
Path(machine_id): Path<uuid::Uuid>,
) -> Result<Json<ListNetworkInstanceIdsJsonResp>, HttpHandleError> {
let result =
Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;
let c = result.scoped_rpc_client();
let ret = c
.list_network_instance(BaseController::default(), ListNetworkInstanceRequest {})
.await
.map_err(convert_rpc_error)?;
Ok(
ListNetworkInstanceIdsJsonResp(ret.inst_ids.into_iter().map(Into::into).collect())
.into(),
)
}
async fn handle_remove_network_instance(
auth_session: AuthSession,
State(client_mgr): AppState,
Path((machine_id, inst_id)): Path<(uuid::Uuid, uuid::Uuid)>,
) -> Result<(), HttpHandleError> {
let result =
Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;
client_mgr
.db()
.delete_network_config(auth_session.user.as_ref().unwrap().id(), inst_id)
.await
.map_err(convert_db_error)?;
let c = result.scoped_rpc_client();
c.delete_network_instance(
BaseController::default(),
DeleteNetworkInstanceRequest {
inst_ids: vec![inst_id.into()],
},
)
.await
.map_err(convert_rpc_error)?;
Ok(())
}
async fn handle_list_machines(
auth_session: AuthSession,
State(client_mgr): AppState,
) -> Result<Json<ListMachineJsonResp>, HttpHandleError> {
let tokens = auth_session
.user
.as_ref()
.map(|x| x.tokens.clone())
.unwrap_or_default();
let client_urls = DashSet::new();
for token in tokens {
let urls = client_mgr.list_machine_by_token(token);
for url in urls {
client_urls.insert(url);
}
}
let mut machines = vec![];
for item in client_urls.iter() {
let client_url = item.key().clone();
let session = client_mgr.get_heartbeat_requests(&client_url).await;
machines.push(ListMachineItem {
client_url: Some(client_url),
info: session,
});
}
Ok(Json(ListMachineJsonResp { machines }))
}
pub fn build_route(&mut self) -> Router<AppStateInner> {
Router::new()
.route("/api/v1/machines", get(Self::handle_list_machines))
.route(
"/api/v1/machines/:machine-id/validate-config",
post(Self::handle_validate_config),
)
.route(
"/api/v1/machines/:machine-id/networks",
post(Self::handle_run_network_instance).get(Self::handle_list_network_instance_ids),
)
.route(
"/api/v1/machines/:machine-id/networks/:inst-id",
delete(Self::handle_remove_network_instance),
)
.route(
"/api/v1/machines/:machine-id/networks/info",
get(Self::handle_collect_network_info),
)
.route(
"/api/v1/machines/:machine-id/networks/info/:inst-id",
get(Self::handle_collect_one_network_info),
)
}
}
+241
View File
@@ -0,0 +1,241 @@
use std::collections::HashSet;
use async_trait::async_trait;
use axum_login::{AuthUser, AuthnBackend, AuthzBackend, UserId};
use password_auth::verify_password;
use sea_orm::{
ActiveModelTrait as _, ColumnTrait, EntityTrait, FromQueryResult, IntoActiveModel, JoinType,
QueryFilter, QuerySelect as _, RelationTrait, Set, TransactionTrait,
};
use serde::{Deserialize, Serialize};
use tokio::task;
use crate::db::{self, entity};
#[derive(Clone, Serialize, Deserialize)]
pub struct User {
db_user: entity::users::Model,
pub tokens: Vec<String>,
}
// Here we've implemented `Debug` manually to avoid accidentally logging the
// password hash.
impl std::fmt::Debug for User {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("User")
.field("id", &self.db_user.id)
.field("username", &self.db_user.username)
.field("password", &"[redacted]")
.finish()
}
}
impl AuthUser for User {
type Id = i32;
fn id(&self) -> Self::Id {
self.db_user.id
}
fn session_auth_hash(&self) -> &[u8] {
self.db_user.password.as_bytes() // We use the password hash as the auth
// hash--what this means
// is when the user changes their password the
// auth session becomes invalid.
}
}
// This allows us to extract the authentication fields from forms. We use this
// to authenticate requests with the backend.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Credentials {
pub username: String,
pub password: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RegisterNewUser {
pub credentials: Credentials,
pub captcha: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChangePassword {
pub new_password: String,
}
#[derive(Debug, Clone)]
pub struct Backend {
db: db::Db,
}
impl Backend {
pub fn new(db: db::Db) -> Self {
Self { db }
}
pub async fn register_new_user(&self, new_user: &RegisterNewUser) -> anyhow::Result<()> {
let hashed_password = password_auth::generate_hash(new_user.credentials.password.as_str());
let mut txn = self.db.orm_db().begin().await?;
entity::users::ActiveModel {
username: Set(new_user.credentials.username.clone()),
password: Set(hashed_password.clone()),
..Default::default()
}
.save(&mut txn)
.await?;
entity::users_groups::ActiveModel {
user_id: Set(entity::users::Entity::find()
.filter(entity::users::Column::Username.eq(new_user.credentials.username.as_str()))
.one(&mut txn)
.await?
.unwrap()
.id),
group_id: Set(entity::groups::Entity::find()
.filter(entity::groups::Column::Name.eq("users"))
.one(&mut txn)
.await?
.unwrap()
.id),
..Default::default()
}
.save(&mut txn)
.await?;
txn.commit().await?;
Ok(())
}
pub async fn change_password(
&self,
id: <User as AuthUser>::Id,
req: &ChangePassword,
) -> anyhow::Result<()> {
let hashed_password = password_auth::generate_hash(req.new_password.as_str());
use entity::users;
let mut user = users::Entity::find_by_id(id)
.one(self.db.orm_db())
.await?
.ok_or(anyhow::anyhow!("User not found"))?
.into_active_model();
user.password = Set(hashed_password.clone());
entity::users::Entity::update(user)
.exec(self.db.orm_db())
.await?;
Ok(())
}
}
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error(transparent)]
Sqlx(#[from] sea_orm::DbErr),
#[error(transparent)]
TaskJoin(#[from] task::JoinError),
}
#[async_trait]
impl AuthnBackend for Backend {
type User = User;
type Credentials = Credentials;
type Error = Error;
async fn authenticate(
&self,
creds: Self::Credentials,
) -> Result<Option<Self::User>, Self::Error> {
let user = entity::users::Entity::find()
.filter(entity::users::Column::Username.eq(creds.username))
.one(self.db.orm_db())
.await?;
task::spawn_blocking(|| {
// We're using password-based authentication--this works by comparing our form
// input with an argon2 password hash.
Ok(user
.filter(|user| verify_password(creds.password, &user.password).is_ok())
.map(|user| User {
db_user: user.clone(),
tokens: vec![user.username.clone()],
}))
})
.await?
}
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
let mut user = entity::users::Entity::find()
.filter(entity::users::Column::Id.eq(*user_id))
.one(self.db.orm_db())
.await?;
if let Some(u) = &mut user {
let mut user = User {
db_user: u.clone(),
tokens: vec![],
};
// username is a token
user.tokens.push(u.username.clone());
Ok(Some(user))
} else {
Ok(None)
}
}
}
#[derive(Debug, Clone, Eq, PartialEq, Hash, FromQueryResult)]
pub struct Permission {
pub name: String,
}
impl From<&str> for Permission {
fn from(name: &str) -> Self {
Permission {
name: name.to_string(),
}
}
}
#[async_trait]
impl AuthzBackend for Backend {
type Permission = Permission;
async fn get_group_permissions(
&self,
_user: &Self::User,
) -> Result<HashSet<Self::Permission>, Self::Error> {
let permissions = entity::users::Entity::find()
.column_as(entity::permissions::Column::Name, "name")
.join(
JoinType::LeftJoin,
entity::users::Relation::UsersGroups.def(),
)
.join(
JoinType::LeftJoin,
entity::users_groups::Relation::Groups.def(),
)
.join(
JoinType::LeftJoin,
entity::groups::Relation::GroupsPermissions.def(),
)
.join(
JoinType::LeftJoin,
entity::groups_permissions::Relation::Permissions.def(),
)
.into_model::<Self::Permission>()
.all(self.db.orm_db())
.await?;
Ok(permissions.into_iter().collect())
}
}
// We use a type alias for convenience.
//
// Note that we've supplied our concrete backend here.
pub type AuthSession = axum_login::AuthSession<Backend>;