Compare commits

..

4 Commits

Author SHA1 Message Date
fanyang 349dbf7d8d fix(web): avoid false default-password reminders
Only flag seeded accounts that still use the shipped password hash,
and keep auth status and password change responses stable during
review follow-up.
2026-04-05 17:54:12 +08:00
fanyang 7707b1cf5e fix(web): require password confirmation in auth forms
Require users to enter new passwords twice in the registration
and password change forms so typos are caught before credentials
are stored.
2026-04-05 17:31:22 +08:00
fanyang 2490bb9808 fix(web): enforce password strength in auth forms
Apply the same password policy to registration and password
changes so operators cannot replace default credentials with
another weak password and users see consistent guidance.
2026-04-05 17:31:22 +08:00
fanyang 3f3e36e653 feat(web): warn on default-password accounts
Track built-in admin and user accounts that still use their
seeded password so the web UI can prompt operators to
rotate credentials after deployment.

- Persist must-change-password state for seeded accounts.
- Clear the reminder after password changes and validate
  empty-password updates.
- Keep the migration and auth API behavior explicit.
2026-04-05 17:31:22 +08:00
32 changed files with 696 additions and 499 deletions
@@ -286,6 +286,9 @@ web:
logout: 退出登录 logout: 退出登录
language: 语言 language: 语言
change_password: 修改密码 change_password: 修改密码
change_password_now: 立即修改密码
default_password_warning: 当前账号仍在使用系统默认密码。为保障安全,请部署完成后立即修改密码。
password_changed_relogin: 密码已修改,请重新登录。
device: device:
list: 设备列表 list: 设备列表
@@ -360,6 +363,11 @@ web:
success: 成功 success: 成功
warning: 警告 warning: 警告
info: 提示 info: 提示
password_empty: 密码不能为空
password_min_length: 密码至少需要 8 位
password_too_weak: 密码强度不足
password_mismatch: 两次输入的密码不一致
password_strength_hint: 密码至少 8 位,且需包含大小写字母、数字、特殊字符中的至少 2 类
enable: 开启 enable: 开启
disable: 关闭 disable: 关闭
address: 地址 address: 地址
@@ -286,6 +286,9 @@ web:
logout: Logout logout: Logout
language: Language language: Language
change_password: Change Password change_password: Change Password
change_password_now: Change Password Now
default_password_warning: This account is still using the default password. Change it immediately after deployment to keep your instance secure.
password_changed_relogin: Password changed. Please log in again.
device: device:
list: Device List list: Device List
@@ -360,6 +363,11 @@ web:
success: Success success: Success
warning: Warning warning: Warning
info: Info info: Info
password_empty: Password cannot be empty
password_min_length: Password must be at least 8 characters long
password_too_weak: Password is too weak
password_mismatch: Passwords do not match
password_strength_hint: Password must be at least 8 characters and include at least 2 of uppercase letters, lowercase letters, numbers, or special characters
enable: Enable enable: Enable
disable: Disable disable: Disable
address: Address address: Address
@@ -1,17 +1,80 @@
<script lang="ts" setup> <script lang="ts" setup>
import { computed, inject, ref } from 'vue'; import { computed, inject, ref } from 'vue';
import { Card, Password, Button } from 'primevue'; import { Card, Password, Button } from 'primevue';
import { useToast } from 'primevue/usetoast';
import { useRouter } from 'vue-router';
import { useI18n } from 'vue-i18n';
import ApiClient from '../modules/api'; import ApiClient from '../modules/api';
import { clearMustChangePasswordFlag } from '../modules/auth-status';
import { validatePasswordStrength } from '../modules/password-policy';
const dialogRef = inject<any>('dialogRef'); const dialogRef = inject<any>('dialogRef');
const api = computed<ApiClient>(() => dialogRef.value.data.api); const api = computed<ApiClient>(() => dialogRef.value.data.api);
const password = ref(''); const password = ref('');
const confirmPassword = ref('');
const toast = useToast();
const router = useRouter();
const { t } = useI18n();
const passwordValidation = computed(() => validatePasswordStrength(password.value));
const passwordMatches = computed(() => password.value === confirmPassword.value);
const passwordErrorMessage = computed(() => {
if (password.value.length === 0 || passwordValidation.value.valid) {
return '';
}
return t(passwordValidation.value.reasonKey!);
});
const confirmPasswordErrorMessage = computed(() => {
if (confirmPassword.value.length === 0 || passwordMatches.value) {
return '';
}
return t('web.common.password_mismatch');
});
const canSubmit = computed(() => passwordValidation.value.valid && passwordMatches.value);
const changePassword = async () => { const changePassword = async () => {
await api.value.change_password(password.value); if (!passwordValidation.value.valid) {
dialogRef.value.close(); toast.add({
severity: 'warn',
summary: t('web.common.warning'),
detail: t(passwordValidation.value.reasonKey!),
life: 3000,
});
return;
}
if (!passwordMatches.value) {
toast.add({
severity: 'warn',
summary: t('web.common.warning'),
detail: t('web.common.password_mismatch'),
life: 3000,
});
return;
}
try {
await api.value.change_password(password.value);
toast.add({
severity: 'success',
summary: t('web.common.success'),
detail: t('web.main.password_changed_relogin'),
life: 3000,
});
clearMustChangePasswordFlag();
dialogRef.value.close();
router.push({ name: 'login' });
} catch (error) {
toast.add({
severity: 'error',
summary: t('web.common.error'),
detail: error instanceof Error ? error.message : String(error),
life: 3000,
});
}
} }
</script> </script>
@@ -19,15 +82,28 @@ const changePassword = async () => {
<div class="flex items-center justify-center"> <div class="flex items-center justify-center">
<Card class="w-full max-w-md p-6"> <Card class="w-full max-w-md p-6">
<template #header> <template #header>
<h2 class="text-2xl font-semibold text-center">Change Password <h2 class="text-2xl font-semibold text-center">{{ t('web.main.change_password') }}
</h2> </h2>
</template> </template>
<template #content> <template #content>
<div class="flex flex-col space-y-4"> <div class="flex flex-col space-y-4">
<Password v-model="password" placeholder="New Password" :feedback="false" toggleMask /> <Password v-model="password" :placeholder="t('web.settings.new_password')" :feedback="false"
<Button @click="changePassword" label="Ok" /> toggleMask />
<Password v-model="confirmPassword" :placeholder="t('web.settings.confirm_password')"
:feedback="false" toggleMask />
<small class="text-surface-500 dark:text-surface-400">
{{ t('web.common.password_strength_hint') }}
</small>
<small v-if="passwordErrorMessage" class="text-red-500 dark:text-red-400">
{{ passwordErrorMessage }}
</small>
<small v-if="confirmPasswordErrorMessage" class="text-red-500 dark:text-red-400">
{{ confirmPasswordErrorMessage }}
</small>
<Button @click="changePassword" :label="t('web.common.confirm')"
:disabled="!canSubmit" />
</div> </div>
</template> </template>
</Card> </Card>
</div> </div>
</template> </template>
+60 -1
View File
@@ -7,6 +7,8 @@ import { I18nUtils } from 'easytier-frontend-lib';
import { getInitialApiHost, cleanAndLoadApiHosts, saveApiHost } from "../modules/api-host" import { getInitialApiHost, cleanAndLoadApiHosts, saveApiHost } from "../modules/api-host"
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import ApiClient, { Credential, RegisterData } from '../modules/api'; import ApiClient, { Credential, RegisterData } from '../modules/api';
import { setMustChangePasswordFlag } from '../modules/auth-status';
import { validatePasswordStrength } from '../modules/password-policy';
const { t } = useI18n() const { t } = useI18n()
@@ -22,8 +24,26 @@ const username = ref('');
const password = ref(''); const password = ref('');
const registerUsername = ref(''); const registerUsername = ref('');
const registerPassword = ref(''); const registerPassword = ref('');
const registerConfirmPassword = ref('');
const captcha = ref(''); const captcha = ref('');
const captchaSrc = computed(() => api.value.captcha_url()); const captchaSrc = computed(() => api.value.captcha_url());
const registerPasswordValidation = computed(() => validatePasswordStrength(registerPassword.value));
const registerPasswordsMatch = computed(() => registerPassword.value === registerConfirmPassword.value);
const registerPasswordErrorMessage = computed(() => {
if (registerPassword.value.length === 0 || registerPasswordValidation.value.valid) {
return '';
}
return t(registerPasswordValidation.value.reasonKey!);
});
const registerConfirmPasswordErrorMessage = computed(() => {
if (registerConfirmPassword.value.length === 0 || registerPasswordsMatch.value) {
return '';
}
return t('web.common.password_mismatch');
});
const canRegister = computed(() => registerPasswordValidation.value.valid && registerPasswordsMatch.value);
const onSubmit = async () => { const onSubmit = async () => {
@@ -33,6 +53,7 @@ const onSubmit = async () => {
let ret = await api.value?.login(credential); let ret = await api.value?.login(credential);
if (ret.success) { if (ret.success) {
localStorage.setItem('apiHost', btoa(apiHost.value)); localStorage.setItem('apiHost', btoa(apiHost.value));
setMustChangePasswordFlag(Boolean(ret.mustChangePassword));
router.push({ router.push({
name: 'dashboard', name: 'dashboard',
params: { apiHost: btoa(apiHost.value) }, params: { apiHost: btoa(apiHost.value) },
@@ -43,6 +64,26 @@ const onSubmit = async () => {
}; };
const onRegister = async () => { const onRegister = async () => {
if (!registerPasswordValidation.value.valid) {
toast.add({
severity: 'warn',
summary: t('web.common.warning'),
detail: t(registerPasswordValidation.value.reasonKey!),
life: 3000,
});
return;
}
if (!registerPasswordsMatch.value) {
toast.add({
severity: 'warn',
summary: t('web.common.warning'),
detail: t('web.common.password_mismatch'),
life: 3000,
});
return;
}
saveApiHost(apiHost.value); saveApiHost(apiHost.value);
const credential: Credential = { username: registerUsername.value, password: registerPassword.value }; const credential: Credential = { username: registerUsername.value, password: registerPassword.value };
const registerReq: RegisterData = { credentials: credential, captcha: captcha.value }; const registerReq: RegisterData = { credentials: credential, captcha: captcha.value };
@@ -156,6 +197,23 @@ onBeforeUnmount(() => {
}}</label> }}</label>
<Password id="register-password" v-model="registerPassword" required toggleMask <Password id="register-password" v-model="registerPassword" required toggleMask
:feedback="false" class="w-full" /> :feedback="false" class="w-full" />
<small class="text-surface-500 dark:text-surface-400">
{{ t('web.common.password_strength_hint') }}
</small>
<small v-if="registerPasswordErrorMessage" class="block text-red-500 dark:text-red-400">
{{ registerPasswordErrorMessage }}
</small>
</div>
<div class="p-field">
<label for="register-confirm-password" class="block text-sm font-medium">
{{ t('web.settings.confirm_password') }}
</label>
<Password id="register-confirm-password" v-model="registerConfirmPassword" required toggleMask
:feedback="false" class="w-full" />
<small v-if="registerConfirmPasswordErrorMessage"
class="block text-red-500 dark:text-red-400">
{{ registerConfirmPasswordErrorMessage }}
</small>
</div> </div>
<div class="p-field"> <div class="p-field">
<label for="captcha" class="block text-sm font-medium">{{ t('web.login.captcha') }}</label> <label for="captcha" class="block text-sm font-medium">{{ t('web.login.captcha') }}</label>
@@ -163,7 +221,8 @@ onBeforeUnmount(() => {
<img :src="captchaSrc" alt="Captcha" class="mt-2 mb-2" /> <img :src="captchaSrc" alt="Captcha" class="mt-2 mb-2" />
</div> </div>
<div class="flex items-center justify-between"> <div class="flex items-center justify-between">
<Button :label="t('web.login.register')" type="submit" class="w-full" /> <Button :label="t('web.login.register')" type="submit" class="w-full"
:disabled="!canRegister" />
</div> </div>
<div class="flex items-center justify-between"> <div class="flex items-center justify-between">
<Button :label="t('web.login.back_to_login')" type="button" class="w-full" <Button :label="t('web.login.back_to_login')" type="button" class="w-full"
@@ -1,13 +1,18 @@
<script setup lang="ts"> <script setup lang="ts">
import { I18nUtils } from 'easytier-frontend-lib' import { I18nUtils } from 'easytier-frontend-lib'
import { computed, onMounted, ref, onUnmounted, nextTick } from 'vue'; import { computed, onMounted, ref, onUnmounted, nextTick } from 'vue';
import { Button, TieredMenu } from 'primevue'; import { Button, Message, TieredMenu } from 'primevue';
import { useRoute, useRouter } from 'vue-router'; import { useRoute, useRouter } from 'vue-router';
import { useDialog } from 'primevue/usedialog'; import { useDialog } from 'primevue/usedialog';
import ChangePassword from './ChangePassword.vue'; import ChangePassword from './ChangePassword.vue';
import Icon from '../assets/easytier.png' import Icon from '../assets/easytier.png'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import ApiClient from '../modules/api'; import ApiClient from '../modules/api';
import {
clearMustChangePasswordFlag,
getMustChangePasswordFlag,
setMustChangePasswordFlag,
} from '../modules/auth-status';
const { t } = useI18n() const { t } = useI18n()
const route = useRoute(); const route = useRoute();
@@ -15,6 +20,7 @@ const router = useRouter();
const api = computed<ApiClient | undefined>(() => { const api = computed<ApiClient | undefined>(() => {
try { try {
return new ApiClient(atob(route.params.apiHost as string), () => { return new ApiClient(atob(route.params.apiHost as string), () => {
clearMustChangePasswordFlag();
router.push({ name: 'login' }); router.push({ name: 'login' });
}) })
} catch (e) { } catch (e) {
@@ -23,25 +29,42 @@ const api = computed<ApiClient | undefined>(() => {
}); });
const dialog = useDialog(); const dialog = useDialog();
const mustChangePassword = ref(false);
const openChangePasswordDialog = () => {
dialog.open(ChangePassword, {
props: {
modal: true,
},
data: {
api: api.value,
}
});
};
const loadAuthStatus = async () => {
const cachedStatus = getMustChangePasswordFlag();
if (cachedStatus !== null) {
mustChangePassword.value = cachedStatus;
}
try {
const status = await api.value?.check_login_status();
mustChangePassword.value = Boolean(
status?.loggedIn && status?.mustChangePassword,
);
setMustChangePasswordFlag(mustChangePassword.value);
} catch (e) {
console.error('Failed to load auth status', e);
}
};
const userMenu = ref(); const userMenu = ref();
const userMenuItems = ref([ const userMenuItems = ref([
{ {
label: t('web.main.change_password'), label: t('web.main.change_password'),
icon: 'pi pi-key', icon: 'pi pi-key',
command: () => { command: openChangePasswordDialog,
console.log('File');
let ret = dialog.open(ChangePassword, {
props: {
modal: true,
},
data: {
api: api.value,
}
});
console.log("return", ret)
},
}, },
{ {
label: t('web.main.logout'), label: t('web.main.logout'),
@@ -52,6 +75,7 @@ const userMenuItems = ref([
} catch (e) { } catch (e) {
console.error("logout failed", e); console.error("logout failed", e);
} }
clearMustChangePasswordFlag();
router.push({ name: 'login' }); router.push({ name: 'login' });
}, },
}, },
@@ -92,6 +116,7 @@ onMounted(async () => {
// 等待 DOM 渲染完成后添加事件监听器 // 等待 DOM 渲染完成后添加事件监听器
await nextTick(); await nextTick();
document.addEventListener('click', handleClickOutside); document.addEventListener('click', handleClickOutside);
await loadAuthStatus();
}); });
onUnmounted(() => { onUnmounted(() => {
@@ -171,6 +196,13 @@ onUnmounted(() => {
<div class="p-4 sm:ml-64"> <div class="p-4 sm:ml-64">
<div class="p-4 border-2 border-gray-200 border-dashed rounded-lg dark:border-gray-700"> <div class="p-4 border-2 border-gray-200 border-dashed rounded-lg dark:border-gray-700">
<div class="grid grid-cols-1 gap-4"> <div class="grid grid-cols-1 gap-4">
<Message v-if="mustChangePassword" severity="warn" :closable="false">
<div class="flex flex-col gap-3 sm:flex-row sm:items-center sm:justify-between">
<span>{{ t('web.main.default_password_warning') }}</span>
<Button size="small" icon="pi pi-key" :label="t('web.main.change_password_now')"
@click="openChangePasswordDialog" />
</div>
</Message>
<RouterView v-slot="{ Component }"> <RouterView v-slot="{ Component }">
<component :is="Component" :api="api" /> <component :is="Component" :api="api" />
</RouterView> </RouterView>
+37 -14
View File
@@ -2,6 +2,8 @@ import axios, { AxiosError, AxiosInstance, AxiosResponse, InternalAxiosRequestCo
import { type Api, NetworkTypes, Utils } from 'easytier-frontend-lib'; import { type Api, NetworkTypes, Utils } from 'easytier-frontend-lib';
import { Md5 } from 'ts-md5'; import { Md5 } from 'ts-md5';
const hashAuthPassword = (password: string) => Md5.hashStr(password);
export interface ValidateConfigResponse { export interface ValidateConfigResponse {
toml_config: string; toml_config: string;
} }
@@ -14,6 +16,16 @@ export interface OidcConfigResponse {
export interface LoginResponse { export interface LoginResponse {
success: boolean; success: boolean;
message: string; message: string;
mustChangePassword?: boolean;
}
export interface AuthStatusResponse {
must_change_password: boolean;
}
export interface CheckLoginStatusResponse {
loggedIn: boolean;
mustChangePassword: boolean;
} }
export interface RegisterResponse { export interface RegisterResponse {
@@ -82,7 +94,6 @@ export class ApiClient {
// 添加响应拦截器 // 添加响应拦截器
this.client.interceptors.response.use((response: AxiosResponse) => { this.client.interceptors.response.use((response: AxiosResponse) => {
console.debug('Axios Response:', response);
return response.data; // 假设服务器返回的数据都在data属性中 return response.data; // 假设服务器返回的数据都在data属性中
}, (error: any) => { }, (error: any) => {
if (error.response) { if (error.response) {
@@ -108,9 +119,8 @@ export class ApiClient {
// 注册 // 注册
public async register(data: RegisterData): Promise<RegisterResponse> { public async register(data: RegisterData): Promise<RegisterResponse> {
try { try {
data.credentials.password = Md5.hashStr(data.credentials.password); data.credentials.password = hashAuthPassword(data.credentials.password);
const response = await this.client.post<RegisterResponse>('/auth/register', data); await this.client.post<RegisterResponse>('/auth/register', data);
console.log("register response:", response);
return { success: true, message: 'Register success', }; return { success: true, message: 'Register success', };
} catch (error) { } catch (error) {
if (error instanceof AxiosError) { if (error instanceof AxiosError) {
@@ -123,10 +133,13 @@ export class ApiClient {
// 登录 // 登录
public async login(data: Credential): Promise<LoginResponse> { public async login(data: Credential): Promise<LoginResponse> {
try { try {
data.password = Md5.hashStr(data.password); data.password = hashAuthPassword(data.password);
const response = await this.client.post<any>('/auth/login', data); const response = await this.client.post<any, AuthStatusResponse>('/auth/login', data);
console.log("login response:", response); return {
return { success: true, message: 'Login success', }; success: true,
message: 'Login success',
mustChangePassword: response.must_change_password,
};
} catch (error) { } catch (error) {
if (error instanceof AxiosError) { if (error instanceof AxiosError) {
if (error.response?.status === 401) { if (error.response?.status === 401) {
@@ -147,16 +160,26 @@ export class ApiClient {
} }
public async change_password(new_password: string) { public async change_password(new_password: string) {
await this.client.put('/auth/password', { new_password: Md5.hashStr(new_password) }); await this.client.put('/auth/password', { new_password: hashAuthPassword(new_password) });
} }
public async check_login_status() { public async check_login_status(): Promise<CheckLoginStatusResponse> {
try { try {
await this.client.get('/auth/check_login_status'); const response = await this.client.get<any, AuthStatusResponse>('/auth/check_login_status');
return true; return {
loggedIn: true,
mustChangePassword: response.must_change_password,
};
} catch (error) { } catch (error) {
return false; if (error instanceof AxiosError && error.response?.status === 401) {
} return {
loggedIn: false,
mustChangePassword: false,
};
}
throw error;
};
} }
public async list_session() { public async list_session() {
@@ -0,0 +1,18 @@
const MUST_CHANGE_PASSWORD_STORAGE_KEY = 'auth.mustChangePassword';
export const getMustChangePasswordFlag = (): boolean | null => {
const value = sessionStorage.getItem(MUST_CHANGE_PASSWORD_STORAGE_KEY);
if (value === null) {
return null;
}
return value === 'true';
};
export const setMustChangePasswordFlag = (value: boolean) => {
sessionStorage.setItem(MUST_CHANGE_PASSWORD_STORAGE_KEY, value ? 'true' : 'false');
};
export const clearMustChangePasswordFlag = () => {
sessionStorage.removeItem(MUST_CHANGE_PASSWORD_STORAGE_KEY);
};
@@ -0,0 +1,55 @@
export type PasswordValidationReasonKey =
| 'web.common.password_empty'
| 'web.common.password_min_length'
| 'web.common.password_too_weak';
export interface PasswordValidationResult {
valid: boolean;
reasonKey?: PasswordValidationReasonKey;
}
const PASSWORD_MIN_LENGTH = 8;
export const countPasswordClasses = (password: string) => {
let count = 0;
if (/[a-z]/.test(password)) {
count += 1;
}
if (/[A-Z]/.test(password)) {
count += 1;
}
if (/\d/.test(password)) {
count += 1;
}
if (/[^A-Za-z0-9\s]/.test(password)) {
count += 1;
}
return count;
};
export const validatePasswordStrength = (password: string): PasswordValidationResult => {
if (password.trim().length === 0) {
return {
valid: false,
reasonKey: 'web.common.password_empty',
};
}
if (password.length < PASSWORD_MIN_LENGTH) {
return {
valid: false,
reasonKey: 'web.common.password_min_length',
};
}
if (countPasswordClasses(password) < 2) {
return {
valid: false,
reasonKey: 'web.common.password_too_weak',
};
}
return { valid: true };
};
+1
View File
@@ -11,6 +11,7 @@ pub struct Model {
#[sea_orm(unique)] #[sea_orm(unique)]
pub username: String, pub username: String,
pub password: String, pub password: String,
pub must_change_password: bool,
} }
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+23 -1
View File
@@ -96,6 +96,7 @@ impl Db {
let user_active = users::ActiveModel { let user_active = users::ActiveModel {
username: Set(username.to_string()), username: Set(username.to_string()),
password: Set(password_hash), password: Set(password_hash),
must_change_password: Set(false),
..Default::default() ..Default::default()
}; };
let insert_result = users::Entity::insert(user_active).exec(&txn).await?; let insert_result = users::Entity::insert(user_active).exec(&txn).await?;
@@ -280,7 +281,28 @@ mod tests {
use easytier::{proto::api::manage::NetworkConfig, rpc_service::remote_client::Storage}; use easytier::{proto::api::manage::NetworkConfig, rpc_service::remote_client::Storage};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter as _}; use sea_orm::{ColumnTrait, EntityTrait, QueryFilter as _};
use crate::db::{entity::user_running_network_configs, Db, ListNetworkProps}; use crate::db::{
entity::{user_running_network_configs, users},
Db, ListNetworkProps,
};
#[tokio::test]
async fn created_users_default_to_not_requiring_password_change() {
let db = Db::memory_db().await;
let user = db
.create_user_and_join_users_group("created-user", "pre-hashed-password".to_string())
.await
.unwrap();
let stored = users::Entity::find_by_id(user.id)
.one(db.orm_db())
.await
.unwrap()
.unwrap();
assert!(!stored.must_change_password);
}
#[tokio::test] #[tokio::test]
async fn test_user_network_config_management() { async fn test_user_network_config_management() {
@@ -0,0 +1,129 @@
use sea_orm_migration::prelude::*;
pub struct Migration;
const DEFAULT_USER_PASSWORD_HASH: &str =
"$argon2i$v=19$m=16,t=2,p=1$aGVyRDBrcnRycnlaMDhkbw$449SEcv/qXf+0fnI9+fYVQ";
const DEFAULT_ADMIN_PASSWORD_HASH: &str =
"$argon2i$v=19$m=16,t=2,p=1$bW5idXl0cmY$61n+JxL4r3dwLPAEDlDdtg";
#[derive(DeriveIden)]
enum Users {
Table,
Username,
Password,
MustChangePassword,
}
impl MigrationName for Migration {
fn name(&self) -> &str {
"m20260405_000003_add_must_change_password"
}
}
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.alter_table(
Table::alter()
.table(Users::Table)
.add_column(
ColumnDef::new(Users::MustChangePassword)
.boolean()
.not_null()
.default(false),
)
.to_owned(),
)
.await?;
manager
.exec_stmt(
Query::update()
.table(Users::Table)
.value(Users::MustChangePassword, true)
.cond_where(any![
Expr::col(Users::Username)
.eq("admin")
.and(Expr::col(Users::Password).eq(DEFAULT_ADMIN_PASSWORD_HASH)),
Expr::col(Users::Username)
.eq("user")
.and(Expr::col(Users::Password).eq(DEFAULT_USER_PASSWORD_HASH)),
])
.to_owned(),
)
.await?;
Ok(())
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.alter_table(
Table::alter()
.table(Users::Table)
.drop_column(Users::MustChangePassword)
.to_owned(),
)
.await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter as _, SqlxSqliteConnector};
use sea_orm_migration::prelude::SchemaManager;
use sqlx::sqlite::SqlitePoolOptions;
use super::{Migration, MigrationTrait, DEFAULT_USER_PASSWORD_HASH};
use crate::db::entity::users;
async fn find_user(db: &sea_orm::DatabaseConnection, username: &str) -> users::Model {
users::Entity::find()
.filter(users::Column::Username.eq(username))
.one(db)
.await
.unwrap()
.unwrap()
}
#[tokio::test]
async fn migration_only_marks_seeded_accounts_still_using_default_passwords() {
let pool = SqlitePoolOptions::new()
.max_connections(1)
.connect("sqlite::memory:")
.await
.unwrap();
sqlx::query(
"CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password TEXT NOT NULL
)",
)
.execute(&pool)
.await
.unwrap();
let changed_admin_password = password_auth::generate_hash("already-changed");
sqlx::query("INSERT INTO users (username, password) VALUES (?, ?), (?, ?)")
.bind("admin")
.bind(changed_admin_password)
.bind("user")
.bind(DEFAULT_USER_PASSWORD_HASH)
.execute(&pool)
.await
.unwrap();
let db = SqlxSqliteConnector::from_sqlx_sqlite_pool(pool);
Migration.up(&SchemaManager::new(&db)).await.unwrap();
assert!(!find_user(&db, "admin").await.must_change_password);
assert!(find_user(&db, "user").await.must_change_password);
}
}
+2
View File
@@ -2,6 +2,7 @@ use sea_orm_migration::prelude::*;
mod m20241029_000001_init; mod m20241029_000001_init;
mod m20260403_000002_scope_network_config_unique; mod m20260403_000002_scope_network_config_unique;
mod m20260405_000003_add_must_change_password;
pub struct Migrator; pub struct Migrator;
@@ -11,6 +12,7 @@ impl MigratorTrait for Migrator {
vec![ vec![
Box::new(m20241029_000001_init::Migration), Box::new(m20241029_000001_init::Migration),
Box::new(m20260403_000002_scope_network_config_unique::Migration), Box::new(m20260403_000002_scope_network_config_unique::Migration),
Box::new(m20260405_000003_add_must_change_password::Migration),
] ]
} }
} }
+29 -17
View File
@@ -4,8 +4,7 @@ use axum::{
Router, Router,
}; };
use axum_login::login_required; use axum_login::login_required;
use axum_messages::Message; use serde::Serialize;
use serde::{Deserialize, Serialize};
use crate::restful::users::Backend; use crate::restful::users::Backend;
@@ -18,9 +17,9 @@ use super::{
AppStateInner, AppStateInner,
}; };
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Serialize)]
pub struct LoginResult { pub struct AuthStatusResponse {
messages: Vec<Message>, must_change_password: bool,
} }
pub fn router() -> Router<AppStateInner> { pub fn router() -> Router<AppStateInner> {
@@ -40,12 +39,15 @@ pub fn router() -> Router<AppStateInner> {
} }
mod put { mod put {
use crate::restful::{
other_error,
users::{ChangePassword, ChangePasswordError},
HttpHandleError,
};
use axum::Json; use axum::Json;
use axum_login::AuthUser; use axum_login::AuthUser;
use easytier::proto::common::Void; use easytier::proto::common::Void;
use crate::restful::{other_error, users::ChangePassword, HttpHandleError};
use super::*; use super::*;
pub async fn change_password( pub async fn change_password(
@@ -58,15 +60,21 @@ mod put {
.await .await
{ {
tracing::error!("Failed to change password: {:?}", e); tracing::error!("Failed to change password: {:?}", e);
return Err(( let (status, message) = match &e {
StatusCode::INTERNAL_SERVER_ERROR, ChangePasswordError::EmptyPassword => {
Json::from(other_error(format!("{:?}", e))), (StatusCode::BAD_REQUEST, "password cannot be empty")
)); }
ChangePasswordError::UserNotFound | ChangePasswordError::Db(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
"failed to change password",
),
};
return Err((status, Json::from(other_error(message.to_string()))));
} }
let _ = auth_session.logout().await; let _ = auth_session.logout().await;
Ok(Void::default().into()) Ok(Json(Void::default()))
} }
} }
@@ -86,7 +94,7 @@ mod post {
pub async fn login( pub async fn login(
mut auth_session: AuthSession, mut auth_session: AuthSession,
Json(creds): Json<Credentials>, Json(creds): Json<Credentials>,
) -> Result<Json<Void>, HttpHandleError> { ) -> Result<Json<AuthStatusResponse>, HttpHandleError> {
let user = match auth_session.authenticate(creds.clone()).await { let user = match auth_session.authenticate(creds.clone()).await {
Ok(Some(user)) => user, Ok(Some(user)) => user,
Ok(None) => { Ok(None) => {
@@ -110,7 +118,9 @@ mod post {
)); ));
} }
Ok(Void::default().into()) Ok(Json(AuthStatusResponse {
must_change_password: user.db_user.must_change_password,
}))
} }
pub async fn register( pub async fn register(
@@ -189,9 +199,11 @@ mod get {
pub async fn check_login_status( pub async fn check_login_status(
auth_session: AuthSession, auth_session: AuthSession,
) -> Result<Json<Void>, HttpHandleError> { ) -> Result<Json<AuthStatusResponse>, HttpHandleError> {
if auth_session.user.is_some() { if let Some(user) = auth_session.user {
Ok(Json(Void::default())) Ok(Json(AuthStatusResponse {
must_change_password: user.db_user.must_change_password,
}))
} else { } else {
Err(( Err((
StatusCode::UNAUTHORIZED, StatusCode::UNAUTHORIZED,
+125 -2
View File
@@ -12,6 +12,8 @@ use tokio::task;
use crate::db::{self, entity}; use crate::db::{self, entity};
const EMPTY_PASSWORD_MD5: &str = "d41d8cd98f00b204e9800998ecf8427e";
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
pub struct User { pub struct User {
pub(crate) db_user: entity::users::Model, pub(crate) db_user: entity::users::Model,
@@ -64,6 +66,18 @@ pub struct ChangePassword {
pub new_password: String, pub new_password: String,
} }
#[derive(Debug, thiserror::Error)]
pub enum ChangePasswordError {
#[error("Password cannot be empty")]
EmptyPassword,
#[error("User not found")]
UserNotFound,
#[error(transparent)]
Db(#[from] sea_orm::DbErr),
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Backend { pub struct Backend {
db: db::Db, db: db::Db,
@@ -119,7 +133,14 @@ impl Backend {
&self, &self,
id: <User as AuthUser>::Id, id: <User as AuthUser>::Id,
req: &ChangePassword, req: &ChangePassword,
) -> anyhow::Result<()> { ) -> Result<(), ChangePasswordError> {
// With the existing pre-hashed protocol the backend can only reject the
// exact empty-string digest; whitespace-only passwords must be blocked
// on the client before hashing.
if req.new_password == EMPTY_PASSWORD_MD5 {
return Err(ChangePasswordError::EmptyPassword);
}
let hashed_password = password_auth::generate_hash(req.new_password.as_str()); let hashed_password = password_auth::generate_hash(req.new_password.as_str());
use entity::users; use entity::users;
@@ -127,9 +148,10 @@ impl Backend {
let mut user = users::Entity::find_by_id(id) let mut user = users::Entity::find_by_id(id)
.one(self.db.orm_db()) .one(self.db.orm_db())
.await? .await?
.ok_or(anyhow::anyhow!("User not found"))? .ok_or(ChangePasswordError::UserNotFound)?
.into_active_model(); .into_active_model();
user.password = Set(hashed_password.clone()); user.password = Set(hashed_password.clone());
user.must_change_password = Set(false);
entity::users::Entity::update(user) entity::users::Entity::update(user)
.exec(self.db.orm_db()) .exec(self.db.orm_db())
@@ -242,6 +264,107 @@ impl AuthzBackend for Backend {
} }
} }
#[cfg(test)]
mod tests {
use axum_login::AuthnBackend;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter as _};
use super::{Backend, ChangePassword, ChangePasswordError, EMPTY_PASSWORD_MD5};
use crate::db::{entity::users, Db};
async fn find_user(db: &Db, username: &str) -> users::Model {
users::Entity::find()
.filter(users::Column::Username.eq(username))
.one(db.orm_db())
.await
.unwrap()
.unwrap()
}
#[tokio::test]
async fn seeded_default_users_require_password_change() {
let db = Db::memory_db().await;
assert!(find_user(&db, "admin").await.must_change_password);
assert!(find_user(&db, "user").await.must_change_password);
}
#[tokio::test]
async fn auto_created_user_does_not_require_password_change() {
let db = Db::memory_db().await;
db.auto_create_user("oidc-user").await.unwrap();
assert!(!find_user(&db, "oidc-user").await.must_change_password);
}
#[tokio::test]
async fn change_password_clears_must_change_password_flag() {
let db = Db::memory_db().await;
let backend = Backend::new(db.clone());
let admin = find_user(&db, "admin").await;
backend
.change_password(
admin.id,
&ChangePassword {
new_password: "f1086f68460b65771de50a970cd1242d".to_string(),
},
)
.await
.unwrap();
assert!(!find_user(&db, "admin").await.must_change_password);
}
#[tokio::test]
async fn change_password_rejects_empty_password_digest() {
let db = Db::memory_db().await;
let backend = Backend::new(db.clone());
let admin = find_user(&db, "admin").await;
let error = backend
.change_password(
admin.id,
&ChangePassword {
new_password: EMPTY_PASSWORD_MD5.to_string(),
},
)
.await
.unwrap_err();
assert!(matches!(error, ChangePasswordError::EmptyPassword));
assert!(find_user(&db, "admin").await.must_change_password);
}
#[tokio::test]
async fn can_authenticate_with_new_password_after_change() {
let db = Db::memory_db().await;
let backend = Backend::new(db.clone());
let admin = find_user(&db, "admin").await;
backend
.change_password(
admin.id,
&ChangePassword {
new_password: "f1086f68460b65771de50a970cd1242d".to_string(),
},
)
.await
.unwrap();
let authenticated = backend
.authenticate(super::Credentials {
username: "admin".to_string(),
password: "f1086f68460b65771de50a970cd1242d".to_string(),
})
.await
.unwrap();
assert!(authenticated.is_some());
}
}
// We use a type alias for convenience. // We use a type alias for convenience.
// //
// Note that we've supplied our concrete backend here. // Note that we've supplied our concrete backend here.
+5 -5
View File
@@ -194,11 +194,11 @@ impl super::TunnelConnector for DnsTunnelConnector {
TunnelInfo { TunnelInfo {
local_addr: info.local_addr.clone(), local_addr: info.local_addr.clone(),
remote_addr: Some(self.addr.clone().into()), remote_addr: Some(self.addr.clone().into()),
resolved_remote_addr: info tunnel_type: format!(
.resolved_remote_addr "{}-{}",
.clone() self.addr.scheme(),
.or(info.remote_addr.clone()), info.remote_addr.unwrap_or_default()
tunnel_type: format!("{}-{}", self.addr.scheme(), info.tunnel_type), ),
}, },
))) )))
} }
+5 -7
View File
@@ -229,11 +229,11 @@ impl super::TunnelConnector for HttpTunnelConnector {
TunnelInfo { TunnelInfo {
local_addr: info.local_addr.clone(), local_addr: info.local_addr.clone(),
remote_addr: Some(self.addr.clone().into()), remote_addr: Some(self.addr.clone().into()),
resolved_remote_addr: info tunnel_type: format!(
.resolved_remote_addr "{:?}-{}",
.clone() self.redirect_type,
.or(info.remote_addr.clone()), info.remote_addr.unwrap_or_default()
tunnel_type: format!("{}-{}", self.addr.scheme(), info.tunnel_type), ),
}, },
))) )))
} }
@@ -353,8 +353,6 @@ mod tests {
let info = t.info().unwrap(); let info = t.info().unwrap();
let remote_addr = info.remote_addr.unwrap(); let remote_addr = info.remote_addr.unwrap();
assert_eq!(remote_addr, test_url.into()); assert_eq!(remote_addr, test_url.into());
let resolved_remote_addr = info.resolved_remote_addr.unwrap();
assert_eq!(resolved_remote_addr.url, "tcp://127.0.0.1:25888");
tokio::join!(task).0.unwrap(); tokio::join!(task).0.unwrap();
} }
+1 -1
View File
@@ -1404,7 +1404,7 @@ impl<'a> CommandHandler<'a> {
"remote_addr: {}, rx_bytes: {}, tx_bytes: {}, latency_us: {}", "remote_addr: {}, rx_bytes: {}, tx_bytes: {}, latency_us: {}",
conn.tunnel conn.tunnel
.as_ref() .as_ref()
.and_then(|t| t.display_remote_addr()) .map(|t| t.remote_addr.clone().unwrap_or_default())
.unwrap_or_default(), .unwrap_or_default(),
conn.stats.as_ref().map(|s| s.rx_bytes).unwrap_or_default(), conn.stats.as_ref().map(|s| s.rx_bytes).unwrap_or_default(),
conn.stats.as_ref().map(|s| s.tx_bytes).unwrap_or_default(), conn.stats.as_ref().map(|s| s.tx_bytes).unwrap_or_default(),
@@ -232,7 +232,6 @@ async fn test_magic_dns_update_replaces_records_for_same_client() {
remote_addr: Some(crate::proto::common::Url { remote_addr: Some(crate::proto::common::Url {
url: "tcp://127.0.0.1:54321".to_string(), url: "tcp://127.0.0.1:54321".to_string(),
}), }),
resolved_remote_addr: None,
})); }));
dns_server_inst dns_server_inst
@@ -300,7 +299,6 @@ async fn test_magic_dns_update_replaces_records_for_same_client() {
remote_addr: Some(crate::proto::common::Url { remote_addr: Some(crate::proto::common::Url {
url: "tcp://127.0.0.1:54321".to_string(), url: "tcp://127.0.0.1:54321".to_string(),
}), }),
resolved_remote_addr: None,
})); }));
dns_server_inst dns_server_inst
+35 -37
View File
@@ -867,49 +867,47 @@ impl Instance {
tokio::spawn(async move { tokio::spawn(async move {
let mut output_tx = Some(first_round_output); let mut output_tx = Some(first_round_output);
loop { loop {
let close_notifier = Arc::new(Notify::new()); let Some(peer_manager) = peer_mgr.upgrade() else {
{ tracing::warn!("peer manager is dropped, stop static ip check.");
let Some(peer_mgr) = peer_mgr.upgrade() else { if let Some(output_tx) = output_tx.take() {
tracing::warn!("peer manager is dropped, stop static ip check."); let _ = output_tx.send(Err(Error::Unknown));
if let Some(output_tx) = output_tx.take() {
let _ = output_tx.send(Err(Error::Unknown));
return;
}
return; return;
};
let mut new_nic_ctx = NicCtx::new(
peer_mgr.get_global_ctx(),
&peer_mgr,
peer_packet_receiver.clone(),
close_notifier.clone(),
);
if let Err(e) = new_nic_ctx.run(ipv4_addr, ipv6_addr).await {
if let Some(output_tx) = output_tx.take() {
let _ = output_tx.send(Err(e));
return;
}
tracing::error!("failed to create new nic ctx, err: {:?}", e);
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
} }
return;
};
// Create Magic DNS runner only if we have IPv4 let close_notifier = Arc::new(Notify::new());
#[cfg(feature = "magic-dns")] let mut new_nic_ctx = NicCtx::new(
{ peer_manager.get_global_ctx(),
let ifname = new_nic_ctx.ifname().await; &peer_manager,
let dns_runner = if let Some(ipv4) = ipv4_addr { peer_packet_receiver.clone(),
Self::create_magic_dns_runner(peer_mgr, ifname, ipv4) close_notifier.clone(),
} else { );
None
}; if let Err(e) = new_nic_ctx.run(ipv4_addr, ipv6_addr).await {
Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx, dns_runner).await; if let Some(output_tx) = output_tx.take() {
let _ = output_tx.send(Err(e));
return;
} }
#[cfg(not(feature = "magic-dns"))] tracing::error!("failed to create new nic ctx, err: {:?}", e);
Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx).await; tokio::time::sleep(Duration::from_secs(1)).await;
continue;
} }
// Create Magic DNS runner only if we have IPv4
#[cfg(feature = "magic-dns")]
{
let ifname = new_nic_ctx.ifname().await;
let dns_runner = if let Some(ipv4) = ipv4_addr {
Self::create_magic_dns_runner(peer_manager, ifname, ipv4)
} else {
None
};
Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx, dns_runner).await;
}
#[cfg(not(feature = "magic-dns"))]
Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx).await;
if let Some(output_tx) = output_tx.take() { if let Some(output_tx) = output_tx.take() {
let _ = output_tx.send(Ok(())); let _ = output_tx.send(Ok(()));
} }
+6 -30
View File
@@ -659,8 +659,7 @@ impl SyncedRouteInfo {
} }
} }
fn update_foreign_network(&self, foreign_network: &RouteForeignNetworkInfos) -> bool { fn update_foreign_network(&self, foreign_network: &RouteForeignNetworkInfos) {
let mut changed = false;
for item in foreign_network.infos.iter().map(Clone::clone) { for item in foreign_network.infos.iter().map(Clone::clone) {
let Some(key) = item.key else { let Some(key) = item.key else {
continue; continue;
@@ -676,15 +675,10 @@ impl SyncedRouteInfo {
.and_modify(|old_entry| { .and_modify(|old_entry| {
if entry.version > old_entry.version { if entry.version > old_entry.version {
*old_entry = entry.clone(); *old_entry = entry.clone();
changed = true;
} }
}) })
.or_insert_with(|| { .or_insert_with(|| entry.clone());
changed = true;
entry.clone()
});
} }
changed
} }
fn update_my_peer_info( fn update_my_peer_info(
@@ -2853,14 +2847,8 @@ impl RouteSessionManager {
dst_peer_id: PeerId, dst_peer_id: PeerId,
mut sync_now: tokio::sync::broadcast::Receiver<()>, mut sync_now: tokio::sync::broadcast::Receiver<()>,
) { ) {
const RETRY_BASE_MS: u64 = 50;
const RETRY_MAX_MS: u64 = 5000;
let mut last_sync = Instant::now(); let mut last_sync = Instant::now();
let mut last_clean_dst_saved_map = Instant::now(); let mut last_clean_dst_saved_map = Instant::now();
// Keep retry_delay_ms across outer iterations so that rapid
// connect/disconnect flaps don't fully reset the backoff.
let mut retry_delay_ms = RETRY_BASE_MS;
loop { loop {
loop { loop {
let Some(service_impl) = service_impl.clone().upgrade() else { let Some(service_impl) = service_impl.clone().upgrade() else {
@@ -2887,18 +2875,13 @@ impl RouteSessionManager {
last_clean_dst_saved_map = Instant::now(); last_clean_dst_saved_map = Instant::now();
service_impl.clean_dst_saved_map(dst_peer_id); service_impl.clean_dst_saved_map(dst_peer_id);
} }
// Successful sync: decay backoff towards base so the next
// real failure still starts at a reasonable level, but
// don't fully reset to avoid 50ms bursts during flapping.
retry_delay_ms = (retry_delay_ms / 2).max(RETRY_BASE_MS);
break; break;
} }
drop(service_impl); drop(service_impl);
drop(peer_rpc); drop(peer_rpc);
tokio::time::sleep(Duration::from_millis(retry_delay_ms)).await; tokio::time::sleep(Duration::from_millis(50)).await;
retry_delay_ms = (retry_delay_ms * 2).min(RETRY_MAX_MS);
} }
sync_now = sync_now.resubscribe(); sync_now = sync_now.resubscribe();
@@ -3231,18 +3214,17 @@ impl RouteSessionManager {
service_impl.update_route_table_and_cached_local_conn_bitmap(); service_impl.update_route_table_and_cached_local_conn_bitmap();
} }
let mut foreign_network_changed = false;
if let Some(foreign_network) = &foreign_network { if let Some(foreign_network) = &foreign_network {
// Step 9b: credential peers' foreign_network_infos are always ignored // Step 9b: credential peers' foreign_network_infos are always ignored
if !from_is_credential { if !from_is_credential {
foreign_network_changed = service_impl service_impl
.synced_route_info .synced_route_info
.update_foreign_network(foreign_network); .update_foreign_network(foreign_network);
session.update_dst_saved_foreign_network_version(foreign_network, from_peer_id); session.update_dst_saved_foreign_network_version(foreign_network, from_peer_id);
} }
} }
if need_update_route_table || foreign_network_changed { if need_update_route_table || foreign_network.is_some() {
service_impl.update_foreign_network_owner_map(); service_impl.update_foreign_network_owner_map();
} }
@@ -3261,13 +3243,7 @@ impl RouteSessionManager {
.disconnect_untrusted_peers(&untrusted_peers) .disconnect_untrusted_peers(&untrusted_peers)
.await; .await;
// Only trigger reverse sync when we actually received new data that self.sync_now("sync_route_info");
// needs to be propagated to other peers. Previously this was
// unconditional, which created an A→B→A→B ping-pong storm even when
// there was nothing new to propagate.
if need_update_route_table || foreign_network_changed {
self.sync_now("sync_route_info");
}
Ok(SyncRouteInfoResponse { Ok(SyncRouteInfoResponse {
is_initiator, is_initiator,
+16 -1
View File
@@ -149,8 +149,23 @@ pub mod instance {
ret ret
} }
fn is_tunnel_ipv6(tunnel_info: &super::super::common::TunnelInfo) -> bool {
let Some(local_addr) = &tunnel_info.local_addr else {
return false;
};
let u: url::Url = local_addr.clone().into();
u.host()
.map(|h| matches!(h, url::Host::Ipv6(_)))
.unwrap_or(false)
}
fn get_tunnel_proto_str(tunnel_info: &super::super::common::TunnelInfo) -> String { fn get_tunnel_proto_str(tunnel_info: &super::super::common::TunnelInfo) -> String {
tunnel_info.display_tunnel_type() if Self::is_tunnel_ipv6(tunnel_info) {
format!("{}6", tunnel_info.tunnel_type)
} else {
tunnel_info.tunnel_type.clone()
}
} }
pub fn get_conn_protos(&self) -> Option<Vec<String>> { pub fn get_conn_protos(&self) -> Option<Vec<String>> {
-1
View File
@@ -201,7 +201,6 @@ message TunnelInfo {
string tunnel_type = 1; string tunnel_type = 1;
common.Url local_addr = 2; common.Url local_addr = 2;
common.Url remote_addr = 3; common.Url remote_addr = 3;
common.Url resolved_remote_addr = 4;
} }
message StunInfo { message StunInfo {
+1 -279
View File
@@ -5,9 +5,8 @@ use std::{
use anyhow::Context; use anyhow::Context;
use base64::{prelude::BASE64_STANDARD, Engine as _}; use base64::{prelude::BASE64_STANDARD, Engine as _};
use strum::VariantArray;
use crate::tunnel::{packet_def::CompressorAlgo, IpScheme}; use crate::tunnel::packet_def::CompressorAlgo;
include!(concat!(env!("OUT_DIR"), "/common.rs")); include!(concat!(env!("OUT_DIR"), "/common.rs"));
@@ -285,105 +284,6 @@ impl fmt::Display for Url {
} }
} }
fn split_tunnel_scheme(raw_scheme: &str) -> Option<(&str, &'static str, bool)> {
for scheme in IpScheme::VARIANTS {
let scheme: &'static str = scheme.into();
if let Some(base) = raw_scheme.strip_suffix('6') {
if let Some(prefix) = base.strip_suffix(scheme) {
if prefix.is_empty() || prefix.ends_with('-') {
return Some((prefix, scheme, true));
}
}
}
if let Some(prefix) = raw_scheme.strip_suffix(scheme) {
if prefix.is_empty() || prefix.ends_with('-') {
return Some((prefix, scheme, false));
}
}
}
None
}
fn normalize_tunnel_scheme(raw_scheme: &str, is_ipv6: bool) -> Option<String> {
let (prefix, scheme, had_ipv6_suffix) = split_tunnel_scheme(raw_scheme)?;
let suffix = if is_ipv6 || had_ipv6_suffix { "6" } else { "" };
Some(format!("{prefix}{scheme}{suffix}"))
}
fn infer_tunnel_ipv6(raw: &str) -> Option<bool> {
let (_, rest) = raw.split_once("://")?;
if rest.starts_with('[') {
return Some(true);
}
match url::Url::parse(raw).ok()?.host() {
Some(url::Host::Ipv4(_)) => Some(false),
Some(url::Host::Ipv6(_)) => Some(true),
Some(url::Host::Domain(_)) | None => None,
}
}
fn normalize_tunnel_port(raw_port: &str, is_ipv6: bool) -> Option<u16> {
if let Ok(port) = raw_port.parse::<u16>() {
return Some(port);
}
if is_ipv6 && raw_port.ends_with('6') {
return raw_port[..raw_port.len() - 1].parse::<u16>().ok();
}
None
}
fn normalize_tunnel_url(raw: &str, fallback_ipv6: Option<bool>) -> Option<String> {
let (raw_scheme, rest) = raw.split_once("://")?;
if let Some(rest) = rest.strip_prefix('[') {
let (host, remainder) = rest.split_once(']')?;
let scheme = normalize_tunnel_scheme(raw_scheme, true)?;
if remainder.is_empty() {
return Some(format!("{scheme}://[{host}]"));
}
let raw_port = remainder.strip_prefix(':')?;
let port = normalize_tunnel_port(raw_port, true)?;
return Some(format!("{scheme}://[{host}]:{port}"));
}
let is_ipv6 = infer_tunnel_ipv6(raw).or(fallback_ipv6).unwrap_or(false);
let scheme = normalize_tunnel_scheme(raw_scheme, is_ipv6)?;
if let Ok(url) = url::Url::parse(raw) {
let host = match url.host()? {
url::Host::Ipv4(host) => host.to_string(),
url::Host::Ipv6(host) => format!("[{host}]"),
url::Host::Domain(host) => host.to_string(),
};
return Some(match url.port_or_known_default() {
Some(port) => format!("{scheme}://{host}:{port}"),
None => format!("{scheme}://{host}"),
});
}
let (host, raw_port) = rest.rsplit_once(':')?;
let port = normalize_tunnel_port(raw_port, is_ipv6)?;
Some(format!("{scheme}://{host}:{port}"))
}
impl Url {
pub fn is_ipv6_tunnel_endpoint(&self) -> bool {
infer_tunnel_ipv6(&self.url).unwrap_or(false)
}
pub fn normalized_tunnel_display(&self) -> String {
normalize_tunnel_url(&self.url, None).unwrap_or_else(|| self.url.clone())
}
}
impl From<std::net::SocketAddr> for SocketAddr { impl From<std::net::SocketAddr> for SocketAddr {
fn from(value: std::net::SocketAddr) -> Self { fn from(value: std::net::SocketAddr) -> Self {
match value { match value {
@@ -425,38 +325,6 @@ impl Display for SocketAddr {
} }
} }
impl TunnelInfo {
pub fn effective_remote_addr(&self) -> Option<&Url> {
self.resolved_remote_addr
.as_ref()
.or(self.remote_addr.as_ref())
}
pub fn display_tunnel_type(&self) -> String {
let is_ipv6 = infer_tunnel_ipv6(&self.tunnel_type).or_else(|| {
self.resolved_remote_addr
.as_ref()
.or(self.local_addr.as_ref())
.or(self.remote_addr.as_ref())
.map(Url::is_ipv6_tunnel_endpoint)
});
if self.tunnel_type.contains("://") {
normalize_tunnel_url(&self.tunnel_type, is_ipv6)
.unwrap_or_else(|| self.tunnel_type.clone())
} else {
is_ipv6
.and_then(|is_ipv6| normalize_tunnel_scheme(&self.tunnel_type, is_ipv6))
.unwrap_or_else(|| self.tunnel_type.clone())
}
}
pub fn display_remote_addr(&self) -> Option<String> {
self.effective_remote_addr()
.map(Url::normalized_tunnel_display)
}
}
impl TryFrom<CompressionAlgoPb> for CompressorAlgo { impl TryFrom<CompressionAlgoPb> for CompressorAlgo {
type Error = anyhow::Error; type Error = anyhow::Error;
@@ -529,149 +397,3 @@ impl SecureModeConfig {
Ok(x25519_dalek::PublicKey::from(k)) Ok(x25519_dalek::PublicKey::from(k))
} }
} }
#[cfg(test)]
mod tests {
use super::{normalize_tunnel_url, TunnelInfo, Url};
fn assert_ipv6_tunnel_normalization(scheme: &str, port: u16) {
let expected = format!("{scheme}6://[2001:db8::1]:{port}");
assert_eq!(
normalize_tunnel_url(&format!("{scheme}://[2001:db8::1]:{port}"), None).as_deref(),
Some(expected.as_str())
);
}
#[test]
fn normalize_plain_ipv6_tunnel_url() {
let url = Url {
url: "tcp://[2001:db8::1]:11010".to_string(),
};
assert_eq!(
url.normalized_tunnel_display(),
"tcp6://[2001:db8::1]:11010"
);
assert!(url.is_ipv6_tunnel_endpoint());
}
#[test]
fn normalize_all_enabled_ipv6_tunnel_urls() {
assert_ipv6_tunnel_normalization("tcp", 11010);
assert_ipv6_tunnel_normalization("udp", 11010);
#[cfg(feature = "wireguard")]
assert_ipv6_tunnel_normalization("wg", 11011);
#[cfg(feature = "quic")]
assert_ipv6_tunnel_normalization("quic", 11012);
#[cfg(feature = "websocket")]
assert_ipv6_tunnel_normalization("ws", 80);
#[cfg(feature = "websocket")]
assert_ipv6_tunnel_normalization("wss", 443);
#[cfg(feature = "faketcp")]
assert_ipv6_tunnel_normalization("faketcp", 11013);
}
#[test]
fn normalize_composite_ipv6_tunnel_url() {
assert_eq!(
normalize_tunnel_url("txt-tcp://[2001:db8::1]:11010", None).as_deref(),
Some("txt-tcp6://[2001:db8::1]:11010")
);
}
#[test]
fn recover_malformed_composite_ipv6_tunnel_url() {
assert_eq!(
normalize_tunnel_url("txt-tcp://[2001:db8::1]:110106", None).as_deref(),
Some("txt-tcp6://[2001:db8::1]:11010")
);
}
#[test]
fn keep_normalized_ipv6_tunnel_url_stable() {
assert_eq!(
normalize_tunnel_url("tcp6://[2001:db8::1]:11010", None).as_deref(),
Some("tcp6://[2001:db8::1]:11010")
);
}
#[test]
fn normalize_ipv6_tunnel_url_without_explicit_port() {
assert_eq!(
normalize_tunnel_url("tcp://[2001:db8::1]", None).as_deref(),
Some("tcp6://[2001:db8::1]")
);
}
#[test]
fn keep_domain_host_unbracketed_when_ipv6_falls_back() {
assert_eq!(
normalize_tunnel_url("tcp://localhost:11010", Some(true)).as_deref(),
Some("tcp6://localhost:11010")
);
}
#[test]
fn tunnel_info_display_tunnel_type_preserves_composite_prefix() {
let tunnel = TunnelInfo {
tunnel_type: "txt-tcp://[2001:db8::2]:110106".to_string(),
local_addr: None,
remote_addr: Some(Url {
url: "txt://et.example.com".to_string(),
}),
resolved_remote_addr: None,
};
assert_eq!(
tunnel.display_tunnel_type(),
"txt-tcp6://[2001:db8::2]:11010"
);
}
#[test]
fn tunnel_info_display_tunnel_type_uses_remote_addr_fallback() {
let tunnel = TunnelInfo {
tunnel_type: "tcp".to_string(),
local_addr: None,
remote_addr: Some(Url {
url: "tcp://[2001:db8::2]:11010".to_string(),
}),
resolved_remote_addr: None,
};
assert_eq!(tunnel.display_tunnel_type(), "tcp6");
assert_eq!(
tunnel.display_remote_addr().as_deref(),
Some("tcp6://[2001:db8::2]:11010")
);
}
#[test]
fn tunnel_info_prefers_resolved_remote_addr() {
let tunnel = TunnelInfo {
tunnel_type: "txt-tcp".to_string(),
local_addr: None,
remote_addr: Some(Url {
url: "txt://et.example.com".to_string(),
}),
resolved_remote_addr: Some(Url {
url: "tcp://[2001:db8::3]:11010".to_string(),
}),
};
assert_eq!(tunnel.display_tunnel_type(), "txt-tcp6");
assert_eq!(
tunnel.display_remote_addr().as_deref(),
Some("tcp6://[2001:db8::3]:11010")
);
assert_eq!(
tunnel.effective_remote_addr().map(|url| url.url.as_str()),
Some("tcp://[2001:db8::3]:11010")
);
}
}
-11
View File
@@ -249,13 +249,6 @@ impl TunnelListener for FakeTcpTunnelListener {
) )
.into(), .into(),
), ),
resolved_remote_addr: Some(
crate::tunnel::build_url_from_socket_addr(
&socket.remote_addr().to_string(),
"faketcp",
)
.into(),
),
}; };
// We treat the fake tcp socket as a datagram tunnel directly // We treat the fake tcp socket as a datagram tunnel directly
@@ -373,10 +366,6 @@ impl TunnelConnector for FakeTcpTunnelConnector {
.into(), .into(),
), ),
remote_addr: Some(self.addr.clone().into()), remote_addr: Some(self.addr.clone().into()),
resolved_remote_addr: Some(
crate::tunnel::build_url_from_socket_addr(&remote_addr.to_string(), "faketcp")
.into(),
),
}; };
let socket = Arc::new(socket); let socket = Arc::new(socket);
+2 -2
View File
@@ -11,7 +11,7 @@ use derive_more::{From, TryInto};
use futures::{Sink, Stream}; use futures::{Sink, Stream};
use socket2::Protocol; use socket2::Protocol;
use std::fmt::Debug; use std::fmt::Debug;
use strum::{Display, EnumString, IntoStaticStr, VariantArray}; use strum::{Display, EnumString, VariantArray};
use tokio::time::error::Elapsed; use tokio::time::error::Elapsed;
use self::packet_def::ZCPacket; use self::packet_def::ZCPacket;
@@ -284,7 +284,7 @@ struct IpSchemeAttributes {
port_offset: u16, port_offset: u16,
} }
#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, IntoStaticStr, VariantArray)] #[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, VariantArray)]
#[strum(serialize_all = "lowercase")] #[strum(serialize_all = "lowercase")]
pub enum IpScheme { pub enum IpScheme {
Tcp, Tcp,
-7
View File
@@ -175,9 +175,6 @@ impl QuicTunnelListener {
remote_addr: Some( remote_addr: Some(
super::build_url_from_socket_addr(&remote_addr.to_string(), "quic").into(), super::build_url_from_socket_addr(&remote_addr.to_string(), "quic").into(),
), ),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&remote_addr.to_string(), "quic").into(),
),
}; };
Ok(Box::new(TunnelWrapper::new( Ok(Box::new(TunnelWrapper::new(
@@ -283,10 +280,6 @@ impl TunnelConnector for QuicTunnelConnector {
super::build_url_from_socket_addr(&local_addr.to_string(), "quic").into(), super::build_url_from_socket_addr(&local_addr.to_string(), "quic").into(),
), ),
remote_addr: Some(self.addr.clone().into()), remote_addr: Some(self.addr.clone().into()),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&connection.remote_address().to_string(), "quic")
.into(),
),
}; };
let arc_conn = Arc::new(ConnWrapper { conn: connection }); let arc_conn = Arc::new(ConnWrapper { conn: connection });
-6
View File
@@ -214,9 +214,6 @@ fn get_tunnel_for_client(conn: Arc<Connection>) -> impl Tunnel {
tunnel_type: "ring".to_owned(), tunnel_type: "ring".to_owned(),
local_addr: Some(build_url_from_socket_addr(&conn.client.id.into(), "ring").into()), local_addr: Some(build_url_from_socket_addr(&conn.client.id.into(), "ring").into()),
remote_addr: Some(build_url_from_socket_addr(&conn.server.id.into(), "ring").into()), remote_addr: Some(build_url_from_socket_addr(&conn.server.id.into(), "ring").into()),
resolved_remote_addr: Some(
build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
),
}), }),
) )
} }
@@ -229,9 +226,6 @@ fn get_tunnel_for_server(conn: Arc<Connection>) -> impl Tunnel {
tunnel_type: "ring".to_owned(), tunnel_type: "ring".to_owned(),
local_addr: Some(build_url_from_socket_addr(&conn.server.id.into(), "ring").into()), local_addr: Some(build_url_from_socket_addr(&conn.server.id.into(), "ring").into()),
remote_addr: Some(build_url_from_socket_addr(&conn.client.id.into(), "ring").into()), remote_addr: Some(build_url_from_socket_addr(&conn.client.id.into(), "ring").into()),
resolved_remote_addr: Some(
build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
),
}), }),
) )
} }
-34
View File
@@ -41,9 +41,6 @@ impl TcpTunnelListener {
remote_addr: Some( remote_addr: Some(
super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(), super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(),
), ),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(),
),
}; };
let (r, w) = stream.into_split(); let (r, w) = stream.into_split();
@@ -120,9 +117,6 @@ fn get_tunnel_with_tcp_stream(
super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp").into(), super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp").into(),
), ),
remote_addr: Some(remote_url.into()), remote_addr: Some(remote_url.into()),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(),
),
}; };
let (r, w) = stream.into_split(); let (r, w) = stream.into_split();
@@ -283,34 +277,6 @@ mod tests {
_tunnel_pingpong(listener, connector).await; _tunnel_pingpong(listener, connector).await;
} }
#[tokio::test]
async fn connector_keeps_source_addr_and_reports_resolved_addr() {
let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:0".parse().unwrap());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
let source_url: url::Url = format!("tcp://localhost:{port}").parse().unwrap();
let mut connector = TcpTunnelConnector::new(source_url.clone());
connector.set_ip_version(IpVersion::V4);
let accept_task = tokio::spawn(async move { listener.accept().await.unwrap() });
let tunnel = connector.connect().await.unwrap();
let accepted_tunnel = accept_task.await.unwrap();
let info = tunnel.info().unwrap();
assert_eq!(info.remote_addr.unwrap().url, source_url.to_string());
let resolved_remote_addr: url::Url = info.resolved_remote_addr.unwrap().into();
assert_eq!(resolved_remote_addr.host_str(), Some("127.0.0.1"));
assert_eq!(resolved_remote_addr.port(), Some(port));
let accepted_info = accepted_tunnel.info().unwrap();
assert_eq!(
accepted_info.remote_addr,
accepted_info.resolved_remote_addr,
);
}
#[tokio::test] #[tokio::test]
async fn test_alloc_port() { async fn test_alloc_port() {
// v4 // v4
-6
View File
@@ -428,9 +428,6 @@ impl UdpTunnelListenerData {
remote_addr: Some( remote_addr: Some(
build_url_from_socket_addr(&remote_addr.to_string(), "udp").into(), build_url_from_socket_addr(&remote_addr.to_string(), "udp").into(),
), ),
resolved_remote_addr: Some(
build_url_from_socket_addr(&remote_addr.to_string(), "udp").into(),
),
}), }),
)); ));
@@ -775,9 +772,6 @@ impl UdpTunnelConnector {
build_url_from_socket_addr(&socket.local_addr()?.to_string(), "udp").into(), build_url_from_socket_addr(&socket.local_addr()?.to_string(), "udp").into(),
), ),
remote_addr: Some(self.addr.clone().into()), remote_addr: Some(self.addr.clone().into()),
resolved_remote_addr: Some(
build_url_from_socket_addr(&dst_addr.to_string(), "udp").into(),
),
}), }),
))) )))
} }
+1 -3
View File
@@ -43,8 +43,7 @@ impl UnixSocketTunnelListener {
let info = TunnelInfo { let info = TunnelInfo {
tunnel_type: "unix".to_owned(), tunnel_type: "unix".to_owned(),
local_addr: Some(self.local_url().into()), local_addr: Some(self.local_url().into()),
remote_addr: remote_addr.clone().map(Into::into), remote_addr: remote_addr.map(Into::into),
resolved_remote_addr: remote_addr.map(Into::into),
}; };
let (r, w) = stream.into_split(); let (r, w) = stream.into_split();
@@ -123,7 +122,6 @@ impl super::TunnelConnector for UnixSocketTunnelConnector {
tunnel_type: "unix".to_owned(), tunnel_type: "unix".to_owned(),
local_addr: local_addr.map(Into::into), local_addr: local_addr.map(Into::into),
remote_addr: Some(self.addr.clone().into()), remote_addr: Some(self.addr.clone().into()),
resolved_remote_addr: Some(self.addr.clone().into()),
}; };
let (r, w) = stream.into_split(); let (r, w) = stream.into_split();
+1 -6
View File
@@ -143,13 +143,11 @@ impl WsTunnelListener {
} }
let (write, read) = stream.split(); let (write, read) = stream.split();
let remote_addr: crate::proto::common::Url = remote_addr.into();
let info = TunnelInfo { let info = TunnelInfo {
tunnel_type: self.addr.scheme().to_owned(), tunnel_type: self.addr.scheme().to_owned(),
local_addr: Some(self.local_url().into()), local_addr: Some(self.local_url().into()),
remote_addr: Some(remote_addr.clone()), remote_addr: Some(remote_addr.into()),
resolved_remote_addr: Some(remote_addr),
}; };
Ok(Box::new(TunnelWrapper::new( Ok(Box::new(TunnelWrapper::new(
@@ -237,9 +235,6 @@ impl WsTunnelConnector {
.into(), .into(),
), ),
remote_addr: Some(addr.clone().into()), remote_addr: Some(addr.clone().into()),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&socket_addr.to_string(), addr.scheme()).into(),
),
}; };
let c = ClientBuilder::from_uri(http::Uri::try_from(addr.to_string()).unwrap()); let c = ClientBuilder::from_uri(http::Uri::try_from(addr.to_string()).unwrap());
-6
View File
@@ -538,9 +538,6 @@ impl WgTunnelListener {
remote_addr: Some( remote_addr: Some(
build_url_from_socket_addr(&addr.to_string(), "wg").into(), build_url_from_socket_addr(&addr.to_string(), "wg").into(),
), ),
resolved_remote_addr: Some(
build_url_from_socket_addr(&addr.to_string(), "wg").into(),
),
}), }),
)); ));
if let Err(e) = conn_sender.send(tunnel) { if let Err(e) = conn_sender.send(tunnel) {
@@ -688,9 +685,6 @@ impl WgTunnelConnector {
tunnel_type: "wg".to_owned(), tunnel_type: "wg".to_owned(),
local_addr: Some(super::build_url_from_socket_addr(&local_addr, "wg").into()), local_addr: Some(super::build_url_from_socket_addr(&local_addr, "wg").into()),
remote_addr: Some(addr_url.into()), remote_addr: Some(addr_url.into()),
resolved_remote_addr: Some(
super::build_url_from_socket_addr(&addr.to_string(), "wg").into(),
),
}), }),
Some(Box::new(wg_peer)), Some(Box::new(wg_peer)),
)); ));