Compare commits

..

4 Commits

Author SHA1 Message Date
fanyang 349dbf7d8d fix(web): avoid false default-password reminders
Only flag seeded accounts that still use the shipped password hash,
and keep auth status and password change responses stable during
review follow-up.
2026-04-05 17:54:12 +08:00
fanyang 7707b1cf5e fix(web): require password confirmation in auth forms
Require users to enter new passwords twice in the registration
and password change forms so typos are caught before credentials
are stored.
2026-04-05 17:31:22 +08:00
fanyang 2490bb9808 fix(web): enforce password strength in auth forms
Apply the same password policy to registration and password
changes so operators cannot replace default credentials with
another weak password and users see consistent guidance.
2026-04-05 17:31:22 +08:00
fanyang 3f3e36e653 feat(web): warn on default-password accounts
Track built-in admin and user accounts that still use their
seeded password so the web UI can prompt operators to
rotate credentials after deployment.

- Persist must-change-password state for seeded accounts.
- Clear the reminder after password changes and validate
  empty-password updates.
- Keep the migration and auth API behavior explicit.
2026-04-05 17:31:22 +08:00
32 changed files with 696 additions and 499 deletions
@@ -286,6 +286,9 @@ web:
logout: 退出登录
language: 语言
change_password: 修改密码
change_password_now: 立即修改密码
default_password_warning: 当前账号仍在使用系统默认密码。为保障安全,请部署完成后立即修改密码。
password_changed_relogin: 密码已修改,请重新登录。
device:
list: 设备列表
@@ -360,6 +363,11 @@ web:
success: 成功
warning: 警告
info: 提示
password_empty: 密码不能为空
password_min_length: 密码至少需要 8 位
password_too_weak: 密码强度不足
password_mismatch: 两次输入的密码不一致
password_strength_hint: 密码至少 8 位,且需包含大小写字母、数字、特殊字符中的至少 2 类
enable: 开启
disable: 关闭
address: 地址
@@ -286,6 +286,9 @@ web:
logout: Logout
language: Language
change_password: Change Password
change_password_now: Change Password Now
default_password_warning: This account is still using the default password. Change it immediately after deployment to keep your instance secure.
password_changed_relogin: Password changed. Please log in again.
device:
list: Device List
@@ -360,6 +363,11 @@ web:
success: Success
warning: Warning
info: Info
password_empty: Password cannot be empty
password_min_length: Password must be at least 8 characters long
password_too_weak: Password is too weak
password_mismatch: Passwords do not match
password_strength_hint: Password must be at least 8 characters and include at least 2 of uppercase letters, lowercase letters, numbers, or special characters
enable: Enable
disable: Disable
address: Address
@@ -1,17 +1,80 @@
<script lang="ts" setup>
import { computed, inject, ref } from 'vue';
import { Card, Password, Button } from 'primevue';
import { useToast } from 'primevue/usetoast';
import { useRouter } from 'vue-router';
import { useI18n } from 'vue-i18n';
import ApiClient from '../modules/api';
import { clearMustChangePasswordFlag } from '../modules/auth-status';
import { validatePasswordStrength } from '../modules/password-policy';
const dialogRef = inject<any>('dialogRef');
const api = computed<ApiClient>(() => dialogRef.value.data.api);
const password = ref('');
const confirmPassword = ref('');
const toast = useToast();
const router = useRouter();
const { t } = useI18n();
const passwordValidation = computed(() => validatePasswordStrength(password.value));
const passwordMatches = computed(() => password.value === confirmPassword.value);
const passwordErrorMessage = computed(() => {
if (password.value.length === 0 || passwordValidation.value.valid) {
return '';
}
return t(passwordValidation.value.reasonKey!);
});
const confirmPasswordErrorMessage = computed(() => {
if (confirmPassword.value.length === 0 || passwordMatches.value) {
return '';
}
return t('web.common.password_mismatch');
});
const canSubmit = computed(() => passwordValidation.value.valid && passwordMatches.value);
const changePassword = async () => {
await api.value.change_password(password.value);
dialogRef.value.close();
if (!passwordValidation.value.valid) {
toast.add({
severity: 'warn',
summary: t('web.common.warning'),
detail: t(passwordValidation.value.reasonKey!),
life: 3000,
});
return;
}
if (!passwordMatches.value) {
toast.add({
severity: 'warn',
summary: t('web.common.warning'),
detail: t('web.common.password_mismatch'),
life: 3000,
});
return;
}
try {
await api.value.change_password(password.value);
toast.add({
severity: 'success',
summary: t('web.common.success'),
detail: t('web.main.password_changed_relogin'),
life: 3000,
});
clearMustChangePasswordFlag();
dialogRef.value.close();
router.push({ name: 'login' });
} catch (error) {
toast.add({
severity: 'error',
summary: t('web.common.error'),
detail: error instanceof Error ? error.message : String(error),
life: 3000,
});
}
}
</script>
@@ -19,15 +82,28 @@ const changePassword = async () => {
<div class="flex items-center justify-center">
<Card class="w-full max-w-md p-6">
<template #header>
<h2 class="text-2xl font-semibold text-center">Change Password
<h2 class="text-2xl font-semibold text-center">{{ t('web.main.change_password') }}
</h2>
</template>
<template #content>
<div class="flex flex-col space-y-4">
<Password v-model="password" placeholder="New Password" :feedback="false" toggleMask />
<Button @click="changePassword" label="Ok" />
<Password v-model="password" :placeholder="t('web.settings.new_password')" :feedback="false"
toggleMask />
<Password v-model="confirmPassword" :placeholder="t('web.settings.confirm_password')"
:feedback="false" toggleMask />
<small class="text-surface-500 dark:text-surface-400">
{{ t('web.common.password_strength_hint') }}
</small>
<small v-if="passwordErrorMessage" class="text-red-500 dark:text-red-400">
{{ passwordErrorMessage }}
</small>
<small v-if="confirmPasswordErrorMessage" class="text-red-500 dark:text-red-400">
{{ confirmPasswordErrorMessage }}
</small>
<Button @click="changePassword" :label="t('web.common.confirm')"
:disabled="!canSubmit" />
</div>
</template>
</Card>
</div>
</template>
</template>
+60 -1
View File
@@ -7,6 +7,8 @@ import { I18nUtils } from 'easytier-frontend-lib';
import { getInitialApiHost, cleanAndLoadApiHosts, saveApiHost } from "../modules/api-host"
import { useI18n } from 'vue-i18n'
import ApiClient, { Credential, RegisterData } from '../modules/api';
import { setMustChangePasswordFlag } from '../modules/auth-status';
import { validatePasswordStrength } from '../modules/password-policy';
const { t } = useI18n()
@@ -22,8 +24,26 @@ const username = ref('');
const password = ref('');
const registerUsername = ref('');
const registerPassword = ref('');
const registerConfirmPassword = ref('');
const captcha = ref('');
const captchaSrc = computed(() => api.value.captcha_url());
const registerPasswordValidation = computed(() => validatePasswordStrength(registerPassword.value));
const registerPasswordsMatch = computed(() => registerPassword.value === registerConfirmPassword.value);
const registerPasswordErrorMessage = computed(() => {
if (registerPassword.value.length === 0 || registerPasswordValidation.value.valid) {
return '';
}
return t(registerPasswordValidation.value.reasonKey!);
});
const registerConfirmPasswordErrorMessage = computed(() => {
if (registerConfirmPassword.value.length === 0 || registerPasswordsMatch.value) {
return '';
}
return t('web.common.password_mismatch');
});
const canRegister = computed(() => registerPasswordValidation.value.valid && registerPasswordsMatch.value);
const onSubmit = async () => {
@@ -33,6 +53,7 @@ const onSubmit = async () => {
let ret = await api.value?.login(credential);
if (ret.success) {
localStorage.setItem('apiHost', btoa(apiHost.value));
setMustChangePasswordFlag(Boolean(ret.mustChangePassword));
router.push({
name: 'dashboard',
params: { apiHost: btoa(apiHost.value) },
@@ -43,6 +64,26 @@ const onSubmit = async () => {
};
const onRegister = async () => {
if (!registerPasswordValidation.value.valid) {
toast.add({
severity: 'warn',
summary: t('web.common.warning'),
detail: t(registerPasswordValidation.value.reasonKey!),
life: 3000,
});
return;
}
if (!registerPasswordsMatch.value) {
toast.add({
severity: 'warn',
summary: t('web.common.warning'),
detail: t('web.common.password_mismatch'),
life: 3000,
});
return;
}
saveApiHost(apiHost.value);
const credential: Credential = { username: registerUsername.value, password: registerPassword.value };
const registerReq: RegisterData = { credentials: credential, captcha: captcha.value };
@@ -156,6 +197,23 @@ onBeforeUnmount(() => {
}}</label>
<Password id="register-password" v-model="registerPassword" required toggleMask
:feedback="false" class="w-full" />
<small class="text-surface-500 dark:text-surface-400">
{{ t('web.common.password_strength_hint') }}
</small>
<small v-if="registerPasswordErrorMessage" class="block text-red-500 dark:text-red-400">
{{ registerPasswordErrorMessage }}
</small>
</div>
<div class="p-field">
<label for="register-confirm-password" class="block text-sm font-medium">
{{ t('web.settings.confirm_password') }}
</label>
<Password id="register-confirm-password" v-model="registerConfirmPassword" required toggleMask
:feedback="false" class="w-full" />
<small v-if="registerConfirmPasswordErrorMessage"
class="block text-red-500 dark:text-red-400">
{{ registerConfirmPasswordErrorMessage }}
</small>
</div>
<div class="p-field">
<label for="captcha" class="block text-sm font-medium">{{ t('web.login.captcha') }}</label>
@@ -163,7 +221,8 @@ onBeforeUnmount(() => {
<img :src="captchaSrc" alt="Captcha" class="mt-2 mb-2" />
</div>
<div class="flex items-center justify-between">
<Button :label="t('web.login.register')" type="submit" class="w-full" />
<Button :label="t('web.login.register')" type="submit" class="w-full"
:disabled="!canRegister" />
</div>
<div class="flex items-center justify-between">
<Button :label="t('web.login.back_to_login')" type="button" class="w-full"
@@ -1,13 +1,18 @@
<script setup lang="ts">
import { I18nUtils } from 'easytier-frontend-lib'
import { computed, onMounted, ref, onUnmounted, nextTick } from 'vue';
import { Button, TieredMenu } from 'primevue';
import { Button, Message, TieredMenu } from 'primevue';
import { useRoute, useRouter } from 'vue-router';
import { useDialog } from 'primevue/usedialog';
import ChangePassword from './ChangePassword.vue';
import Icon from '../assets/easytier.png'
import { useI18n } from 'vue-i18n'
import ApiClient from '../modules/api';
import {
clearMustChangePasswordFlag,
getMustChangePasswordFlag,
setMustChangePasswordFlag,
} from '../modules/auth-status';
const { t } = useI18n()
const route = useRoute();
@@ -15,6 +20,7 @@ const router = useRouter();
const api = computed<ApiClient | undefined>(() => {
try {
return new ApiClient(atob(route.params.apiHost as string), () => {
clearMustChangePasswordFlag();
router.push({ name: 'login' });
})
} catch (e) {
@@ -23,25 +29,42 @@ const api = computed<ApiClient | undefined>(() => {
});
const dialog = useDialog();
const mustChangePassword = ref(false);
const openChangePasswordDialog = () => {
dialog.open(ChangePassword, {
props: {
modal: true,
},
data: {
api: api.value,
}
});
};
const loadAuthStatus = async () => {
const cachedStatus = getMustChangePasswordFlag();
if (cachedStatus !== null) {
mustChangePassword.value = cachedStatus;
}
try {
const status = await api.value?.check_login_status();
mustChangePassword.value = Boolean(
status?.loggedIn && status?.mustChangePassword,
);
setMustChangePasswordFlag(mustChangePassword.value);
} catch (e) {
console.error('Failed to load auth status', e);
}
};
const userMenu = ref();
const userMenuItems = ref([
{
label: t('web.main.change_password'),
icon: 'pi pi-key',
command: () => {
console.log('File');
let ret = dialog.open(ChangePassword, {
props: {
modal: true,
},
data: {
api: api.value,
}
});
console.log("return", ret)
},
command: openChangePasswordDialog,
},
{
label: t('web.main.logout'),
@@ -52,6 +75,7 @@ const userMenuItems = ref([
} catch (e) {
console.error("logout failed", e);
}
clearMustChangePasswordFlag();
router.push({ name: 'login' });
},
},
@@ -92,6 +116,7 @@ onMounted(async () => {
// 等待 DOM 渲染完成后添加事件监听器
await nextTick();
document.addEventListener('click', handleClickOutside);
await loadAuthStatus();
});
onUnmounted(() => {
@@ -171,6 +196,13 @@ onUnmounted(() => {
<div class="p-4 sm:ml-64">
<div class="p-4 border-2 border-gray-200 border-dashed rounded-lg dark:border-gray-700">
<div class="grid grid-cols-1 gap-4">
<Message v-if="mustChangePassword" severity="warn" :closable="false">
<div class="flex flex-col gap-3 sm:flex-row sm:items-center sm:justify-between">
<span>{{ t('web.main.default_password_warning') }}</span>
<Button size="small" icon="pi pi-key" :label="t('web.main.change_password_now')"
@click="openChangePasswordDialog" />
</div>
</Message>
<RouterView v-slot="{ Component }">
<component :is="Component" :api="api" />
</RouterView>
+37 -14
View File
@@ -2,6 +2,8 @@ import axios, { AxiosError, AxiosInstance, AxiosResponse, InternalAxiosRequestCo
import { type Api, NetworkTypes, Utils } from 'easytier-frontend-lib';
import { Md5 } from 'ts-md5';
const hashAuthPassword = (password: string) => Md5.hashStr(password);
export interface ValidateConfigResponse {
toml_config: string;
}
@@ -14,6 +16,16 @@ export interface OidcConfigResponse {
export interface LoginResponse {
success: boolean;
message: string;
mustChangePassword?: boolean;
}
export interface AuthStatusResponse {
must_change_password: boolean;
}
export interface CheckLoginStatusResponse {
loggedIn: boolean;
mustChangePassword: boolean;
}
export interface RegisterResponse {
@@ -82,7 +94,6 @@ export class ApiClient {
// 添加响应拦截器
this.client.interceptors.response.use((response: AxiosResponse) => {
console.debug('Axios Response:', response);
return response.data; // 假设服务器返回的数据都在data属性中
}, (error: any) => {
if (error.response) {
@@ -108,9 +119,8 @@ export class ApiClient {
// 注册
public async register(data: RegisterData): Promise<RegisterResponse> {
try {
data.credentials.password = Md5.hashStr(data.credentials.password);
const response = await this.client.post<RegisterResponse>('/auth/register', data);
console.log("register response:", response);
data.credentials.password = hashAuthPassword(data.credentials.password);
await this.client.post<RegisterResponse>('/auth/register', data);
return { success: true, message: 'Register success', };
} catch (error) {
if (error instanceof AxiosError) {
@@ -123,10 +133,13 @@ export class ApiClient {
// 登录
public async login(data: Credential): Promise<LoginResponse> {
try {
data.password = Md5.hashStr(data.password);
const response = await this.client.post<any>('/auth/login', data);
console.log("login response:", response);
return { success: true, message: 'Login success', };
data.password = hashAuthPassword(data.password);
const response = await this.client.post<any, AuthStatusResponse>('/auth/login', data);
return {
success: true,
message: 'Login success',
mustChangePassword: response.must_change_password,
};
} catch (error) {
if (error instanceof AxiosError) {
if (error.response?.status === 401) {
@@ -147,16 +160,26 @@ export class ApiClient {
}
public async change_password(new_password: string) {
await this.client.put('/auth/password', { new_password: Md5.hashStr(new_password) });
await this.client.put('/auth/password', { new_password: hashAuthPassword(new_password) });
}
public async check_login_status() {
public async check_login_status(): Promise<CheckLoginStatusResponse> {
try {
await this.client.get('/auth/check_login_status');
return true;
const response = await this.client.get<any, AuthStatusResponse>('/auth/check_login_status');
return {
loggedIn: true,
mustChangePassword: response.must_change_password,
};
} catch (error) {
return false;
}
if (error instanceof AxiosError && error.response?.status === 401) {
return {
loggedIn: false,
mustChangePassword: false,
};
}
throw error;
};
}
public async list_session() {
@@ -0,0 +1,18 @@
const MUST_CHANGE_PASSWORD_STORAGE_KEY = 'auth.mustChangePassword';
export const getMustChangePasswordFlag = (): boolean | null => {
const value = sessionStorage.getItem(MUST_CHANGE_PASSWORD_STORAGE_KEY);
if (value === null) {
return null;
}
return value === 'true';
};
export const setMustChangePasswordFlag = (value: boolean) => {
sessionStorage.setItem(MUST_CHANGE_PASSWORD_STORAGE_KEY, value ? 'true' : 'false');
};
export const clearMustChangePasswordFlag = () => {
sessionStorage.removeItem(MUST_CHANGE_PASSWORD_STORAGE_KEY);
};
@@ -0,0 +1,55 @@
export type PasswordValidationReasonKey =
| 'web.common.password_empty'
| 'web.common.password_min_length'
| 'web.common.password_too_weak';
export interface PasswordValidationResult {
valid: boolean;
reasonKey?: PasswordValidationReasonKey;
}
const PASSWORD_MIN_LENGTH = 8;
export const countPasswordClasses = (password: string) => {
let count = 0;
if (/[a-z]/.test(password)) {
count += 1;
}
if (/[A-Z]/.test(password)) {
count += 1;
}
if (/\d/.test(password)) {
count += 1;
}
if (/[^A-Za-z0-9\s]/.test(password)) {
count += 1;
}
return count;
};
export const validatePasswordStrength = (password: string): PasswordValidationResult => {
if (password.trim().length === 0) {
return {
valid: false,
reasonKey: 'web.common.password_empty',
};
}
if (password.length < PASSWORD_MIN_LENGTH) {
return {
valid: false,
reasonKey: 'web.common.password_min_length',
};
}
if (countPasswordClasses(password) < 2) {
return {
valid: false,
reasonKey: 'web.common.password_too_weak',
};
}
return { valid: true };
};
+1
View File
@@ -11,6 +11,7 @@ pub struct Model {
#[sea_orm(unique)]
pub username: String,
pub password: String,
pub must_change_password: bool,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+23 -1
View File
@@ -96,6 +96,7 @@ impl Db {
let user_active = users::ActiveModel {
username: Set(username.to_string()),
password: Set(password_hash),
must_change_password: Set(false),
..Default::default()
};
let insert_result = users::Entity::insert(user_active).exec(&txn).await?;
@@ -280,7 +281,28 @@ mod tests {
use easytier::{proto::api::manage::NetworkConfig, rpc_service::remote_client::Storage};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter as _};
use crate::db::{entity::user_running_network_configs, Db, ListNetworkProps};
use crate::db::{
entity::{user_running_network_configs, users},
Db, ListNetworkProps,
};
#[tokio::test]
async fn created_users_default_to_not_requiring_password_change() {
let db = Db::memory_db().await;
let user = db
.create_user_and_join_users_group("created-user", "pre-hashed-password".to_string())
.await
.unwrap();
let stored = users::Entity::find_by_id(user.id)
.one(db.orm_db())
.await
.unwrap()
.unwrap();
assert!(!stored.must_change_password);
}
#[tokio::test]
async fn test_user_network_config_management() {
@@ -0,0 +1,129 @@
use sea_orm_migration::prelude::*;
pub struct Migration;
const DEFAULT_USER_PASSWORD_HASH: &str =
"$argon2i$v=19$m=16,t=2,p=1$aGVyRDBrcnRycnlaMDhkbw$449SEcv/qXf+0fnI9+fYVQ";
const DEFAULT_ADMIN_PASSWORD_HASH: &str =
"$argon2i$v=19$m=16,t=2,p=1$bW5idXl0cmY$61n+JxL4r3dwLPAEDlDdtg";
#[derive(DeriveIden)]
enum Users {
Table,
Username,
Password,
MustChangePassword,
}
impl MigrationName for Migration {
fn name(&self) -> &str {
"m20260405_000003_add_must_change_password"
}
}
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.alter_table(
Table::alter()
.table(Users::Table)
.add_column(
ColumnDef::new(Users::MustChangePassword)
.boolean()
.not_null()
.default(false),
)
.to_owned(),
)
.await?;
manager
.exec_stmt(
Query::update()
.table(Users::Table)
.value(Users::MustChangePassword, true)
.cond_where(any![
Expr::col(Users::Username)
.eq("admin")
.and(Expr::col(Users::Password).eq(DEFAULT_ADMIN_PASSWORD_HASH)),
Expr::col(Users::Username)
.eq("user")
.and(Expr::col(Users::Password).eq(DEFAULT_USER_PASSWORD_HASH)),
])
.to_owned(),
)
.await?;
Ok(())
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.alter_table(
Table::alter()
.table(Users::Table)
.drop_column(Users::MustChangePassword)
.to_owned(),
)
.await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter as _, SqlxSqliteConnector};
use sea_orm_migration::prelude::SchemaManager;
use sqlx::sqlite::SqlitePoolOptions;
use super::{Migration, MigrationTrait, DEFAULT_USER_PASSWORD_HASH};
use crate::db::entity::users;
async fn find_user(db: &sea_orm::DatabaseConnection, username: &str) -> users::Model {
users::Entity::find()
.filter(users::Column::Username.eq(username))
.one(db)
.await
.unwrap()
.unwrap()
}
#[tokio::test]
async fn migration_only_marks_seeded_accounts_still_using_default_passwords() {
let pool = SqlitePoolOptions::new()
.max_connections(1)
.connect("sqlite::memory:")
.await
.unwrap();
sqlx::query(
"CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password TEXT NOT NULL
)",
)
.execute(&pool)
.await
.unwrap();
let changed_admin_password = password_auth::generate_hash("already-changed");
sqlx::query("INSERT INTO users (username, password) VALUES (?, ?), (?, ?)")
.bind("admin")
.bind(changed_admin_password)
.bind("user")
.bind(DEFAULT_USER_PASSWORD_HASH)
.execute(&pool)
.await
.unwrap();
let db = SqlxSqliteConnector::from_sqlx_sqlite_pool(pool);
Migration.up(&SchemaManager::new(&db)).await.unwrap();
assert!(!find_user(&db, "admin").await.must_change_password);
assert!(find_user(&db, "user").await.must_change_password);
}
}
+2
View File
@@ -2,6 +2,7 @@ use sea_orm_migration::prelude::*;
mod m20241029_000001_init;
mod m20260403_000002_scope_network_config_unique;
mod m20260405_000003_add_must_change_password;
pub struct Migrator;
@@ -11,6 +12,7 @@ impl MigratorTrait for Migrator {
vec![
Box::new(m20241029_000001_init::Migration),
Box::new(m20260403_000002_scope_network_config_unique::Migration),
Box::new(m20260405_000003_add_must_change_password::Migration),
]
}
}
+29 -17
View File
@@ -4,8 +4,7 @@ use axum::{
Router,
};
use axum_login::login_required;
use axum_messages::Message;
use serde::{Deserialize, Serialize};
use serde::Serialize;
use crate::restful::users::Backend;
@@ -18,9 +17,9 @@ use super::{
AppStateInner,
};
#[derive(Debug, Deserialize, Serialize)]
pub struct LoginResult {
messages: Vec<Message>,
#[derive(Debug, Serialize)]
pub struct AuthStatusResponse {
must_change_password: bool,
}
pub fn router() -> Router<AppStateInner> {
@@ -40,12 +39,15 @@ pub fn router() -> Router<AppStateInner> {
}
mod put {
use crate::restful::{
other_error,
users::{ChangePassword, ChangePasswordError},
HttpHandleError,
};
use axum::Json;
use axum_login::AuthUser;
use easytier::proto::common::Void;
use crate::restful::{other_error, users::ChangePassword, HttpHandleError};
use super::*;
pub async fn change_password(
@@ -58,15 +60,21 @@ mod put {
.await
{
tracing::error!("Failed to change password: {:?}", e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json::from(other_error(format!("{:?}", e))),
));
let (status, message) = match &e {
ChangePasswordError::EmptyPassword => {
(StatusCode::BAD_REQUEST, "password cannot be empty")
}
ChangePasswordError::UserNotFound | ChangePasswordError::Db(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
"failed to change password",
),
};
return Err((status, Json::from(other_error(message.to_string()))));
}
let _ = auth_session.logout().await;
Ok(Void::default().into())
Ok(Json(Void::default()))
}
}
@@ -86,7 +94,7 @@ mod post {
pub async fn login(
mut auth_session: AuthSession,
Json(creds): Json<Credentials>,
) -> Result<Json<Void>, HttpHandleError> {
) -> Result<Json<AuthStatusResponse>, HttpHandleError> {
let user = match auth_session.authenticate(creds.clone()).await {
Ok(Some(user)) => user,
Ok(None) => {
@@ -110,7 +118,9 @@ mod post {
));
}
Ok(Void::default().into())
Ok(Json(AuthStatusResponse {
must_change_password: user.db_user.must_change_password,
}))
}
pub async fn register(
@@ -189,9 +199,11 @@ mod get {
pub async fn check_login_status(
auth_session: AuthSession,
) -> Result<Json<Void>, HttpHandleError> {
if auth_session.user.is_some() {
Ok(Json(Void::default()))
) -> Result<Json<AuthStatusResponse>, HttpHandleError> {
if let Some(user) = auth_session.user {
Ok(Json(AuthStatusResponse {
must_change_password: user.db_user.must_change_password,
}))
} else {
Err((
StatusCode::UNAUTHORIZED,
+125 -2
View File
@@ -12,6 +12,8 @@ use tokio::task;
use crate::db::{self, entity};
const EMPTY_PASSWORD_MD5: &str = "d41d8cd98f00b204e9800998ecf8427e";
#[derive(Clone, Serialize, Deserialize)]
pub struct User {
pub(crate) db_user: entity::users::Model,
@@ -64,6 +66,18 @@ pub struct ChangePassword {
pub new_password: String,
}
#[derive(Debug, thiserror::Error)]
pub enum ChangePasswordError {
#[error("Password cannot be empty")]
EmptyPassword,
#[error("User not found")]
UserNotFound,
#[error(transparent)]
Db(#[from] sea_orm::DbErr),
}
#[derive(Debug, Clone)]
pub struct Backend {
db: db::Db,
@@ -119,7 +133,14 @@ impl Backend {
&self,
id: <User as AuthUser>::Id,
req: &ChangePassword,
) -> anyhow::Result<()> {
) -> Result<(), ChangePasswordError> {
// With the existing pre-hashed protocol the backend can only reject the
// exact empty-string digest; whitespace-only passwords must be blocked
// on the client before hashing.
if req.new_password == EMPTY_PASSWORD_MD5 {
return Err(ChangePasswordError::EmptyPassword);
}
let hashed_password = password_auth::generate_hash(req.new_password.as_str());
use entity::users;
@@ -127,9 +148,10 @@ impl Backend {
let mut user = users::Entity::find_by_id(id)
.one(self.db.orm_db())
.await?
.ok_or(anyhow::anyhow!("User not found"))?
.ok_or(ChangePasswordError::UserNotFound)?
.into_active_model();
user.password = Set(hashed_password.clone());
user.must_change_password = Set(false);
entity::users::Entity::update(user)
.exec(self.db.orm_db())
@@ -242,6 +264,107 @@ impl AuthzBackend for Backend {
}
}
#[cfg(test)]
mod tests {
use axum_login::AuthnBackend;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter as _};
use super::{Backend, ChangePassword, ChangePasswordError, EMPTY_PASSWORD_MD5};
use crate::db::{entity::users, Db};
async fn find_user(db: &Db, username: &str) -> users::Model {
users::Entity::find()
.filter(users::Column::Username.eq(username))
.one(db.orm_db())
.await
.unwrap()
.unwrap()
}
#[tokio::test]
async fn seeded_default_users_require_password_change() {
let db = Db::memory_db().await;
assert!(find_user(&db, "admin").await.must_change_password);
assert!(find_user(&db, "user").await.must_change_password);
}
#[tokio::test]
async fn auto_created_user_does_not_require_password_change() {
let db = Db::memory_db().await;
db.auto_create_user("oidc-user").await.unwrap();
assert!(!find_user(&db, "oidc-user").await.must_change_password);
}
#[tokio::test]
async fn change_password_clears_must_change_password_flag() {
let db = Db::memory_db().await;
let backend = Backend::new(db.clone());
let admin = find_user(&db, "admin").await;
backend
.change_password(
admin.id,
&ChangePassword {
new_password: "f1086f68460b65771de50a970cd1242d".to_string(),
},
)
.await
.unwrap();
assert!(!find_user(&db, "admin").await.must_change_password);
}
#[tokio::test]
async fn change_password_rejects_empty_password_digest() {
let db = Db::memory_db().await;
let backend = Backend::new(db.clone());
let admin = find_user(&db, "admin").await;
let error = backend
.change_password(
admin.id,
&ChangePassword {
new_password: EMPTY_PASSWORD_MD5.to_string(),
},
)
.await
.unwrap_err();
assert!(matches!(error, ChangePasswordError::EmptyPassword));
assert!(find_user(&db, "admin").await.must_change_password);
}
#[tokio::test]
async fn can_authenticate_with_new_password_after_change() {
let db = Db::memory_db().await;
let backend = Backend::new(db.clone());
let admin = find_user(&db, "admin").await;
backend
.change_password(
admin.id,
&ChangePassword {
new_password: "f1086f68460b65771de50a970cd1242d".to_string(),
},
)
.await
.unwrap();
let authenticated = backend
.authenticate(super::Credentials {
username: "admin".to_string(),
password: "f1086f68460b65771de50a970cd1242d".to_string(),
})
.await
.unwrap();
assert!(authenticated.is_some());
}
}
// We use a type alias for convenience.
//
// Note that we've supplied our concrete backend here.
+5 -5
View File
@@ -194,11 +194,11 @@ impl super::TunnelConnector for DnsTunnelConnector {
TunnelInfo {
local_addr: info.local_addr.clone(),
remote_addr: Some(self.addr.clone().into()),
resolved_remote_addr: info
.resolved_remote_addr
.clone()
.or(info.remote_addr.clone()),
tunnel_type: format!("{}-{}", self.addr.scheme(), info.tunnel_type),
tunnel_type: format!(
"{}-{}",
self.addr.scheme(),
info.remote_addr.unwrap_or_default()
),
},
)))
}
+5 -7
View File
@@ -229,11 +229,11 @@ impl super::TunnelConnector for HttpTunnelConnector {
TunnelInfo {
local_addr: info.local_addr.clone(),
remote_addr: Some(self.addr.clone().into()),
resolved_remote_addr: info
.resolved_remote_addr
.clone()
.or(info.remote_addr.clone()),
tunnel_type: format!("{}-{}", self.addr.scheme(), info.tunnel_type),
tunnel_type: format!(
"{:?}-{}",
self.redirect_type,
info.remote_addr.unwrap_or_default()
),
},
)))
}
@@ -353,8 +353,6 @@ mod tests {
let info = t.info().unwrap();
let remote_addr = info.remote_addr.unwrap();
assert_eq!(remote_addr, test_url.into());
let resolved_remote_addr = info.resolved_remote_addr.unwrap();
assert_eq!(resolved_remote_addr.url, "tcp://127.0.0.1:25888");
tokio::join!(task).0.unwrap();
}
+1 -1
View File
@@ -1404,7 +1404,7 @@ impl<'a> CommandHandler<'a> {
"remote_addr: {}, rx_bytes: {}, tx_bytes: {}, latency_us: {}",
conn.tunnel
.as_ref()
.and_then(|t| t.display_remote_addr())
.map(|t| t.remote_addr.clone().unwrap_or_default())
.unwrap_or_default(),
conn.stats.as_ref().map(|s| s.rx_bytes).unwrap_or_default(),
conn.stats.as_ref().map(|s| s.tx_bytes).unwrap_or_default(),
@@ -232,7 +232,6 @@ async fn test_magic_dns_update_replaces_records_for_same_client() {
remote_addr: Some(crate::proto::common::Url {
url: "tcp://127.0.0.1:54321".to_string(),
}),
resolved_remote_addr: None,
}));
dns_server_inst
@@ -300,7 +299,6 @@ async fn test_magic_dns_update_replaces_records_for_same_client() {
remote_addr: Some(crate::proto::common::Url {
url: "tcp://127.0.0.1:54321".to_string(),
}),
resolved_remote_addr: None,
}));
dns_server_inst
+35 -37
View File
@@ -867,49 +867,47 @@ impl Instance {
tokio::spawn(async move {
let mut output_tx = Some(first_round_output);
loop {
let close_notifier = Arc::new(Notify::new());
{
let Some(peer_mgr) = peer_mgr.upgrade() else {
tracing::warn!("peer manager is dropped, stop static ip check.");
if let Some(output_tx) = output_tx.take() {
let _ = output_tx.send(Err(Error::Unknown));
return;
}
let Some(peer_manager) = peer_mgr.upgrade() else {
tracing::warn!("peer manager is dropped, stop static ip check.");
if let Some(output_tx) = output_tx.take() {
let _ = output_tx.send(Err(Error::Unknown));
return;
};
let mut new_nic_ctx = NicCtx::new(
peer_mgr.get_global_ctx(),
&peer_mgr,
peer_packet_receiver.clone(),
close_notifier.clone(),
);
if let Err(e) = new_nic_ctx.run(ipv4_addr, ipv6_addr).await {
if let Some(output_tx) = output_tx.take() {
let _ = output_tx.send(Err(e));
return;
}
tracing::error!("failed to create new nic ctx, err: {:?}", e);
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
return;
};
// Create Magic DNS runner only if we have IPv4
#[cfg(feature = "magic-dns")]
{
let ifname = new_nic_ctx.ifname().await;
let dns_runner = if let Some(ipv4) = ipv4_addr {
Self::create_magic_dns_runner(peer_mgr, ifname, ipv4)
} else {
None
};
Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx, dns_runner).await;
let close_notifier = Arc::new(Notify::new());
let mut new_nic_ctx = NicCtx::new(
peer_manager.get_global_ctx(),
&peer_manager,
peer_packet_receiver.clone(),
close_notifier.clone(),
);
if let Err(e) = new_nic_ctx.run(ipv4_addr, ipv6_addr).await {
if let Some(output_tx) = output_tx.take() {
let _ = output_tx.send(Err(e));
return;
}
#[cfg(not(feature = "magic-dns"))]
Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx).await;
tracing::error!("failed to create new nic ctx, err: {:?}", e);
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
// Create Magic DNS runner only if we have IPv4
#[cfg(feature = "magic-dns")]
{
let ifname = new_nic_ctx.ifname().await;
let dns_runner = if let Some(ipv4) = ipv4_addr {
Self::create_magic_dns_runner(peer_manager, ifname, ipv4)
} else {
None
};
Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx, dns_runner).await;
}
#[cfg(not(feature = "magic-dns"))]
Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx).await;
if let Some(output_tx) = output_tx.take() {
let _ = output_tx.send(Ok(()));
}
+6 -30
View File
@@ -659,8 +659,7 @@ impl SyncedRouteInfo {
}
}
fn update_foreign_network(&self, foreign_network: &RouteForeignNetworkInfos) -> bool {
let mut changed = false;
fn update_foreign_network(&self, foreign_network: &RouteForeignNetworkInfos) {
for item in foreign_network.infos.iter().map(Clone::clone) {
let Some(key) = item.key else {
continue;
@@ -676,15 +675,10 @@ impl SyncedRouteInfo {
.and_modify(|old_entry| {
if entry.version > old_entry.version {
*old_entry = entry.clone();
changed = true;
}
})
.or_insert_with(|| {
changed = true;
entry.clone()
});
.or_insert_with(|| entry.clone());
}
changed
}
fn update_my_peer_info(
@@ -2853,14 +2847,8 @@ impl RouteSessionManager {
dst_peer_id: PeerId,
mut sync_now: tokio::sync::broadcast::Receiver<()>,
) {
const RETRY_BASE_MS: u64 = 50;
const RETRY_MAX_MS: u64 = 5000;
let mut last_sync = Instant::now();
let mut last_clean_dst_saved_map = Instant::now();
// Keep retry_delay_ms across outer iterations so that rapid
// connect/disconnect flaps don't fully reset the backoff.
let mut retry_delay_ms = RETRY_BASE_MS;
loop {
loop {
let Some(service_impl) = service_impl.clone().upgrade() else {
@@ -2887,18 +2875,13 @@ impl RouteSessionManager {
last_clean_dst_saved_map = Instant::now();
service_impl.clean_dst_saved_map(dst_peer_id);
}
// Successful sync: decay backoff towards base so the next
// real failure still starts at a reasonable level, but
// don't fully reset to avoid 50ms bursts during flapping.
retry_delay_ms = (retry_delay_ms / 2).max(RETRY_BASE_MS);
break;
}
drop(service_impl);
drop(peer_rpc);
tokio::time::sleep(Duration::from_millis(retry_delay_ms)).await;
retry_delay_ms = (retry_delay_ms * 2).min(RETRY_MAX_MS);
tokio::time::sleep(Duration::from_millis(50)).await;
}
sync_now = sync_now.resubscribe();
@@ -3231,18 +3214,17 @@ impl RouteSessionManager {
service_impl.update_route_table_and_cached_local_conn_bitmap();
}
let mut foreign_network_changed = false;
if let Some(foreign_network) = &foreign_network {
// Step 9b: credential peers' foreign_network_infos are always ignored
if !from_is_credential {
foreign_network_changed = service_impl
service_impl
.synced_route_info
.update_foreign_network(foreign_network);
session.update_dst_saved_foreign_network_version(foreign_network, from_peer_id);
}
}
if need_update_route_table || foreign_network_changed {
if need_update_route_table || foreign_network.is_some() {
service_impl.update_foreign_network_owner_map();
}
@@ -3261,13 +3243,7 @@ impl RouteSessionManager {
.disconnect_untrusted_peers(&untrusted_peers)
.await;
// Only trigger reverse sync when we actually received new data that
// needs to be propagated to other peers. Previously this was
// unconditional, which created an A→B→A→B ping-pong storm even when
// there was nothing new to propagate.
if need_update_route_table || foreign_network_changed {
self.sync_now("sync_route_info");
}
self.sync_now("sync_route_info");
Ok(SyncRouteInfoResponse {
is_initiator,
+16 -1
View File
@@ -149,8 +149,23 @@ pub mod instance {
ret
}
fn is_tunnel_ipv6(tunnel_info: &super::super::common::TunnelInfo) -> bool {
let Some(local_addr) = &tunnel_info.local_addr else {
return false;
};
let u: url::Url = local_addr.clone().into();
u.host()
.map(|h| matches!(h, url::Host::Ipv6(_)))
.unwrap_or(false)
}
fn get_tunnel_proto_str(tunnel_info: &super::super::common::TunnelInfo) -> String {
tunnel_info.display_tunnel_type()
if Self::is_tunnel_ipv6(tunnel_info) {
format!("{}6", tunnel_info.tunnel_type)
} else {
tunnel_info.tunnel_type.clone()
}
}
pub fn get_conn_protos(&self) -> Option<Vec<String>> {
-1
View File
@@ -201,7 +201,6 @@ message TunnelInfo {
string tunnel_type = 1;
common.Url local_addr = 2;
common.Url remote_addr = 3;
common.Url resolved_remote_addr = 4;
}
message StunInfo {
+1 -279
View File
@@ -5,9 +5,8 @@ use std::{
use anyhow::Context;
use base64::{prelude::BASE64_STANDARD, Engine as _};
use strum::VariantArray;
use crate::tunnel::{packet_def::CompressorAlgo, IpScheme};
use crate::tunnel::packet_def::CompressorAlgo;
include!(concat!(env!("OUT_DIR"), "/common.rs"));
@@ -285,105 +284,6 @@ impl fmt::Display for Url {
}
}
fn split_tunnel_scheme(raw_scheme: &str) -> Option<(&str, &'static str, bool)> {
for scheme in IpScheme::VARIANTS {
let scheme: &'static str = scheme.into();
if let Some(base) = raw_scheme.strip_suffix('6') {
if let Some(prefix) = base.strip_suffix(scheme) {
if prefix.is_empty() || prefix.ends_with('-') {
return Some((prefix, scheme, true));
}
}
}
if let Some(prefix) = raw_scheme.strip_suffix(scheme) {
if prefix.is_empty() || prefix.ends_with('-') {
return Some((prefix, scheme, false));
}
}
}
None
}
fn normalize_tunnel_scheme(raw_scheme: &str, is_ipv6: bool) -> Option<String> {
let (prefix, scheme, had_ipv6_suffix) = split_tunnel_scheme(raw_scheme)?;
let suffix = if is_ipv6 || had_ipv6_suffix { "6" } else { "" };
Some(format!("{prefix}{scheme}{suffix}"))
}
fn infer_tunnel_ipv6(raw: &str) -> Option<bool> {
let (_, rest) = raw.split_once("://")?;
if rest.starts_with('[') {
return Some(true);
}
match url::Url::parse(raw).ok()?.host() {
Some(url::Host::Ipv4(_)) => Some(false),
Some(url::Host::Ipv6(_)) => Some(true),
Some(url::Host::Domain(_)) | None => None,
}
}
fn normalize_tunnel_port(raw_port: &str, is_ipv6: bool) -> Option<u16> {
if let Ok(port) = raw_port.parse::<u16>() {
return Some(port);
}
if is_ipv6 && raw_port.ends_with('6') {
return raw_port[..raw_port.len() - 1].parse::<u16>().ok();
}
None
}
fn normalize_tunnel_url(raw: &str, fallback_ipv6: Option<bool>) -> Option<String> {
let (raw_scheme, rest) = raw.split_once("://")?;
if let Some(rest) = rest.strip_prefix('[') {
let (host, remainder) = rest.split_once(']')?;
let scheme = normalize_tunnel_scheme(raw_scheme, true)?;
if remainder.is_empty() {
return Some(format!("{scheme}://[{host}]"));
}
let raw_port = remainder.strip_prefix(':')?;
let port = normalize_tunnel_port(raw_port, true)?;
return Some(format!("{scheme}://[{host}]:{port}"));
}
let is_ipv6 = infer_tunnel_ipv6(raw).or(fallback_ipv6).unwrap_or(false);
let scheme = normalize_tunnel_scheme(raw_scheme, is_ipv6)?;
if let Ok(url) = url::Url::parse(raw) {
let host = match url.host()? {
url::Host::Ipv4(host) => host.to_string(),
url::Host::Ipv6(host) => format!("[{host}]"),
url::Host::Domain(host) => host.to_string(),
};
return Some(match url.port_or_known_default() {
Some(port) => format!("{scheme}://{host}:{port}"),
None => format!("{scheme}://{host}"),
});
}
let (host, raw_port) = rest.rsplit_once(':')?;
let port = normalize_tunnel_port(raw_port, is_ipv6)?;
Some(format!("{scheme}://{host}:{port}"))
}
impl Url {
pub fn is_ipv6_tunnel_endpoint(&self) -> bool {
infer_tunnel_ipv6(&self.url).unwrap_or(false)
}
pub fn normalized_tunnel_display(&self) -> String {
normalize_tunnel_url(&self.url, None).unwrap_or_else(|| self.url.clone())
}
}
impl From<std::net::SocketAddr> for SocketAddr {
fn from(value: std::net::SocketAddr) -> Self {
match value {
@@ -425,38 +325,6 @@ impl Display for SocketAddr {
}
}
impl TunnelInfo {
pub fn effective_remote_addr(&self) -> Option<&Url> {
self.resolved_remote_addr
.as_ref()
.or(self.remote_addr.as_ref())
}
pub fn display_tunnel_type(&self) -> String {
let is_ipv6 = infer_tunnel_ipv6(&self.tunnel_type).or_else(|| {
self.resolved_remote_addr
.as_ref()
.or(self.local_addr.as_ref())
.or(self.remote_addr.as_ref())
.map(Url::is_ipv6_tunnel_endpoint)
});
if self.tunnel_type.contains("://") {
normalize_tunnel_url(&self.tunnel_type, is_ipv6)
.unwrap_or_else(|| self.tunnel_type.clone())
} else {
is_ipv6
.and_then(|is_ipv6| normalize_tunnel_scheme(&self.tunnel_type, is_ipv6))
.unwrap_or_else(|| self.tunnel_type.clone())
}
}
pub fn display_remote_addr(&self) -> Option<String> {
self.effective_remote_addr()
.map(Url::normalized_tunnel_display)
}
}
impl TryFrom<CompressionAlgoPb> for CompressorAlgo {
type Error = anyhow::Error;
@@ -529,149 +397,3 @@ impl SecureModeConfig {
Ok(x25519_dalek::PublicKey::from(k))
}
}
#[cfg(test)]
mod tests {
use super::{normalize_tunnel_url, TunnelInfo, Url};
fn assert_ipv6_tunnel_normalization(scheme: &str, port: u16) {
let expected = format!("{scheme}6://[2001:db8::1]:{port}");
assert_eq!(
normalize_tunnel_url(&format!("{scheme}://[2001:db8::1]:{port}"), None).as_deref(),
Some(expected.as_str())
);
}
#[test]
fn normalize_plain_ipv6_tunnel_url() {
let url = Url {
url: "tcp://[2001:db8::1]:11010".to_string(),
};
assert_eq!(
url.normalized_tunnel_display(),
"tcp6://[2001:db8::1]:11010"
);
assert!(url.is_ipv6_tunnel_endpoint());
}
#[test]
fn normalize_all_enabled_ipv6_tunnel_urls() {
assert_ipv6_tunnel_normalization("tcp", 11010);
assert_ipv6_tunnel_normalization("udp", 11010);
#[cfg(feature = "wireguard")]
assert_ipv6_tunnel_normalization("wg", 11011);
#[cfg(feature = "quic")]
assert_ipv6_tunnel_normalization("quic", 11012);
#[cfg(feature = "websocket")]
assert_ipv6_tunnel_normalization("ws", 80);
#[cfg(feature = "websocket")]
assert_ipv6_tunnel_normalization("wss", 443);
#[cfg(feature = "faketcp")]
assert_ipv6_tunnel_normalization("faketcp", 11013);
}
#[test]
fn normalize_composite_ipv6_tunnel_url() {
assert_eq!(
normalize_tunnel_url("txt-tcp://[2001:db8::1]:11010", None).as_deref(),
Some("txt-tcp6://[2001:db8::1]:11010")
);
}
#[test]
fn recover_malformed_composite_ipv6_tunnel_url() {
assert_eq!(
normalize_tunnel_url("txt-tcp://[2001:db8::1]:110106", None).as_deref(),
Some("txt-tcp6://[2001:db8::1]:11010")
);
}
#[test]
fn keep_normalized_ipv6_tunnel_url_stable() {
assert_eq!(
normalize_tunnel_url("tcp6://[2001:db8::1]:11010", None).as_deref(),
Some("tcp6://[2001:db8::1]:11010")
);
}
#[test]
fn normalize_ipv6_tunnel_url_without_explicit_port() {
assert_eq!(
normalize_tunnel_url("tcp://[2001:db8::1]", None).as_deref(),
Some("tcp6://[2001:db8::1]")
);
}
#[test]
fn keep_domain_host_unbracketed_when_ipv6_falls_back() {
assert_eq!(
normalize_tunnel_url("tcp://localhost:11010", Some(true)).as_deref(),
Some("tcp6://localhost:11010")
);
}
#[test]
fn tunnel_info_display_tunnel_type_preserves_composite_prefix() {
let tunnel = TunnelInfo {
tunnel_type: "txt-tcp://[2001:db8::2]:110106".to_string(),
local_addr: None,
remote_addr: Some(Url {
url: "txt://et.example.com".to_string(),
}),
resolved_remote_addr: None,
};
assert_eq!(
tunnel.display_tunnel_type(),
"txt-tcp6://[2001:db8::2]:11010"
);
}
#[test]
fn tunnel_info_display_tunnel_type_uses_remote_addr_fallback() {
let tunnel = TunnelInfo {
tunnel_type: "tcp".to_string(),
local_addr: None,
remote_addr: Some(Url {
url: "tcp://[2001:db8::2]:11010".to_string(),
}),
resolved_remote_addr: None,
};
assert_eq!(tunnel.display_tunnel_type(), "tcp6");
assert_eq!(
tunnel.display_remote_addr().as_deref(),
Some("tcp6://[2001:db8::2]:11010")
);
}
#[test]
fn tunnel_info_prefers_resolved_remote_addr() {
let tunnel = TunnelInfo {
tunnel_type: "txt-tcp".to_string(),
local_addr: None,
remote_addr: Some(Url {
url: "txt://et.example.com".to_string(),
}),
resolved_remote_addr: Some(Url {
url: "tcp://[2001:db8::3]:11010".to_string(),
}),
};
assert_eq!(tunnel.display_tunnel_type(), "txt-tcp6");
assert_eq!(
tunnel.display_remote_addr().as_deref(),
Some("tcp6://[2001:db8::3]:11010")
);
assert_eq!(
tunnel.effective_remote_addr().map(|url| url.url.as_str()),
Some("tcp://[2001:db8::3]:11010")
);
}
}
-11
View File
@@ -249,13 +249,6 @@ impl TunnelListener for FakeTcpTunnelListener {
)
.into(),
),
resolved_remote_addr: Some(
crate::tunnel::build_url_from_socket_addr(
&socket.remote_addr().to_string(),
"faketcp",
)
.into(),
),
};
// We treat the fake tcp socket as a datagram tunnel directly
@@ -373,10 +366,6 @@ impl TunnelConnector for FakeTcpTunnelConnector {
.into(),
),
remote_addr: Some(self.addr.clone().into()),
resolved_remote_addr: Some(
crate::tunnel::build_url_from_socket_addr(&remote_addr.to_string(), "faketcp")
.into(),
),
};
let socket = Arc::new(socket);
+2 -2
View File
@@ -11,7 +11,7 @@ use derive_more::{From, TryInto};
use futures::{Sink, Stream};
use socket2::Protocol;
use std::fmt::Debug;
use strum::{Display, EnumString, IntoStaticStr, VariantArray};
use strum::{Display, EnumString, VariantArray};
use tokio::time::error::Elapsed;
use self::packet_def::ZCPacket;
@@ -284,7 +284,7 @@ struct IpSchemeAttributes {
port_offset: u16,
}
#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, IntoStaticStr, VariantArray)]
#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, VariantArray)]
#[strum(serialize_all = "lowercase")]
pub enum IpScheme {
Tcp,
-7
View File
@@ -175,9 +175,6 @@ impl QuicTunnelListener {
remote_addr: Some(
super::build_url_from_socket_addr(&remote_addr.to_string(), "quic").into(),
),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&remote_addr.to_string(), "quic").into(),
),
};
Ok(Box::new(TunnelWrapper::new(
@@ -283,10 +280,6 @@ impl TunnelConnector for QuicTunnelConnector {
super::build_url_from_socket_addr(&local_addr.to_string(), "quic").into(),
),
remote_addr: Some(self.addr.clone().into()),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&connection.remote_address().to_string(), "quic")
.into(),
),
};
let arc_conn = Arc::new(ConnWrapper { conn: connection });
-6
View File
@@ -214,9 +214,6 @@ fn get_tunnel_for_client(conn: Arc<Connection>) -> impl Tunnel {
tunnel_type: "ring".to_owned(),
local_addr: Some(build_url_from_socket_addr(&conn.client.id.into(), "ring").into()),
remote_addr: Some(build_url_from_socket_addr(&conn.server.id.into(), "ring").into()),
resolved_remote_addr: Some(
build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
),
}),
)
}
@@ -229,9 +226,6 @@ fn get_tunnel_for_server(conn: Arc<Connection>) -> impl Tunnel {
tunnel_type: "ring".to_owned(),
local_addr: Some(build_url_from_socket_addr(&conn.server.id.into(), "ring").into()),
remote_addr: Some(build_url_from_socket_addr(&conn.client.id.into(), "ring").into()),
resolved_remote_addr: Some(
build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
),
}),
)
}
-34
View File
@@ -41,9 +41,6 @@ impl TcpTunnelListener {
remote_addr: Some(
super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(),
),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(),
),
};
let (r, w) = stream.into_split();
@@ -120,9 +117,6 @@ fn get_tunnel_with_tcp_stream(
super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp").into(),
),
remote_addr: Some(remote_url.into()),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(),
),
};
let (r, w) = stream.into_split();
@@ -283,34 +277,6 @@ mod tests {
_tunnel_pingpong(listener, connector).await;
}
#[tokio::test]
async fn connector_keeps_source_addr_and_reports_resolved_addr() {
let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:0".parse().unwrap());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
let source_url: url::Url = format!("tcp://localhost:{port}").parse().unwrap();
let mut connector = TcpTunnelConnector::new(source_url.clone());
connector.set_ip_version(IpVersion::V4);
let accept_task = tokio::spawn(async move { listener.accept().await.unwrap() });
let tunnel = connector.connect().await.unwrap();
let accepted_tunnel = accept_task.await.unwrap();
let info = tunnel.info().unwrap();
assert_eq!(info.remote_addr.unwrap().url, source_url.to_string());
let resolved_remote_addr: url::Url = info.resolved_remote_addr.unwrap().into();
assert_eq!(resolved_remote_addr.host_str(), Some("127.0.0.1"));
assert_eq!(resolved_remote_addr.port(), Some(port));
let accepted_info = accepted_tunnel.info().unwrap();
assert_eq!(
accepted_info.remote_addr,
accepted_info.resolved_remote_addr,
);
}
#[tokio::test]
async fn test_alloc_port() {
// v4
-6
View File
@@ -428,9 +428,6 @@ impl UdpTunnelListenerData {
remote_addr: Some(
build_url_from_socket_addr(&remote_addr.to_string(), "udp").into(),
),
resolved_remote_addr: Some(
build_url_from_socket_addr(&remote_addr.to_string(), "udp").into(),
),
}),
));
@@ -775,9 +772,6 @@ impl UdpTunnelConnector {
build_url_from_socket_addr(&socket.local_addr()?.to_string(), "udp").into(),
),
remote_addr: Some(self.addr.clone().into()),
resolved_remote_addr: Some(
build_url_from_socket_addr(&dst_addr.to_string(), "udp").into(),
),
}),
)))
}
+1 -3
View File
@@ -43,8 +43,7 @@ impl UnixSocketTunnelListener {
let info = TunnelInfo {
tunnel_type: "unix".to_owned(),
local_addr: Some(self.local_url().into()),
remote_addr: remote_addr.clone().map(Into::into),
resolved_remote_addr: remote_addr.map(Into::into),
remote_addr: remote_addr.map(Into::into),
};
let (r, w) = stream.into_split();
@@ -123,7 +122,6 @@ impl super::TunnelConnector for UnixSocketTunnelConnector {
tunnel_type: "unix".to_owned(),
local_addr: local_addr.map(Into::into),
remote_addr: Some(self.addr.clone().into()),
resolved_remote_addr: Some(self.addr.clone().into()),
};
let (r, w) = stream.into_split();
+1 -6
View File
@@ -143,13 +143,11 @@ impl WsTunnelListener {
}
let (write, read) = stream.split();
let remote_addr: crate::proto::common::Url = remote_addr.into();
let info = TunnelInfo {
tunnel_type: self.addr.scheme().to_owned(),
local_addr: Some(self.local_url().into()),
remote_addr: Some(remote_addr.clone()),
resolved_remote_addr: Some(remote_addr),
remote_addr: Some(remote_addr.into()),
};
Ok(Box::new(TunnelWrapper::new(
@@ -237,9 +235,6 @@ impl WsTunnelConnector {
.into(),
),
remote_addr: Some(addr.clone().into()),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&socket_addr.to_string(), addr.scheme()).into(),
),
};
let c = ClientBuilder::from_uri(http::Uri::try_from(addr.to_string()).unwrap());
-6
View File
@@ -538,9 +538,6 @@ impl WgTunnelListener {
remote_addr: Some(
build_url_from_socket_addr(&addr.to_string(), "wg").into(),
),
resolved_remote_addr: Some(
build_url_from_socket_addr(&addr.to_string(), "wg").into(),
),
}),
));
if let Err(e) = conn_sender.send(tunnel) {
@@ -688,9 +685,6 @@ impl WgTunnelConnector {
tunnel_type: "wg".to_owned(),
local_addr: Some(super::build_url_from_socket_addr(&local_addr, "wg").into()),
remote_addr: Some(addr_url.into()),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&addr.to_string(), "wg").into(),
),
}),
Some(Box::new(wg_peer)),
));