Compare commits

...

21 Commits

Author SHA1 Message Date
fanyang 9d7a938e93 Address review comments 2026-05-04 10:42:51 +08:00
fanyang 6229229b31 feat: support lzo compression 2026-05-04 10:42:51 +08:00
fanyang 6a63853bad fix: silence listener warning in feature builds 2026-05-04 10:42:51 +08:00
fanyang 362aa7a9cd fix: allow omitted ACL config fields (#2206) 2026-05-04 00:47:24 +08:00
KKRainbow 12a7b5a5c5 fix: scope peer center server data to instance (#2198)
Stop sharing PeerCenterServer state through a process-global map so local and foreign-network services cannot mix peer-center data when peer ids overlap.
2026-05-02 01:43:01 +08:00
fanyang 4eba9b07b6 fix(web-client): keep retrying unreachable config server (#2140)
Defer config-server connector creation into the web client retry loop so
service startup does not fail when network or DNS is unavailable.
2026-05-02 00:09:48 +08:00
KKRainbow 1b48029bdc fix: clean stale foreign network state (#2197)
- clear foreign-network traffic metric peer caches on peer removal and network cleanup
- release reserved foreign-network peer IDs on handshake/add-peer error paths
- avoid creating no-op foreign-network token buckets when limits are unlimited
- shrink relay/session maps after cleanup and remove unused peer-center global data entries
2026-05-01 23:30:51 +08:00
KKRainbow 3542e944cb fix(quic): prune stopped endpoints from pool (#2195)
* remove wss port 0 compatibility code
* fix(quic): prune stopped endpoints from pool
2026-05-01 18:51:39 +08:00
KKRainbow 852d1c9e14 feat(gui): add UPnP and public IPv6 advanced options (#2194)
Expose disable-upnp and ipv6_public_addr_auto in the shared web/GUI config editor
bump release metadata to 2.6.3.
2026-05-01 13:45:19 +08:00
KKRainbow 4958394469 fix: protect self peer during credential refresh and allow need-p2p peers through public server (#2192)
* fix: protect self peer during credential refresh

* fix: allow need-p2p peers through public server
2026-05-01 06:59:30 +08:00
KKRainbow 41b6d65604 fix faketcp filter on windows (#2190) 2026-04-30 23:55:56 +08:00
KKRainbow aae30894dd fix: keep file logger disabled by default (#2189) 2026-04-30 21:42:30 +08:00
fanyang 81d169abfc fix: fall back when CLI manage service is unavailable (#2185) 2026-04-30 19:50:50 +08:00
Luna Yao 9c6c210e89 fix: disable SO_EXCLUSIVEADDRUSE on Windows (#2180) 2026-04-30 19:48:54 +08:00
Mg Pig d1c6dcf754 fix: prevent URL input layout flicker with container queries (#2186) 2026-04-30 19:45:01 +08:00
KKRainbow 97c8c4f55a feat: support disabling relay data forwarding (#2188)
- add a disable_relay_data runtime/config patch option
- reuse the existing avoid_relay_data feature flag when relay data forwarding is disabled
2026-04-30 19:44:40 +08:00
KKRainbow ed8df2d58f prevent EasyTier-managed IPv6 from being used as underlay connections (#2181)
When a node has public IPv6 addresses allocated by EasyTier, those addresses
are installed on the host's network interfaces. The system would then pick
them up as candidate source/destination addresses for underlay connections
(direct peer, UDP hole punch, bind addresses), causing overlay traffic to
loop back into the overlay itself.

Add a central predicate is_ip_easytier_managed_ipv6() and apply it at every
point where IPv6 addresses are selected for underlay use:
- Filter managed IPv6 from DNS-resolved connector addresses, including a
  UDP socket getsockname check to detect whether the OS would route through
  the overlay to reach a destination
- Skip managed IPv6 in bind address selection and STUN candidate filtering
- Strip managed IPv6 from GetIpListResponse RPC so peers never learn them
- Pass pre-resolved addresses to tunnel connectors to avoid re-resolution

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-29 12:17:22 +08:00
lurenjia f66010e6f9 fix: preserve URL type in matches_scheme (#2179)
Avoid resolving Url::as_ref() to the full URL string before TunnelScheme
conversion. Add regression coverage for owned/borrowed URLs and the UDP
IPv6 hole-punch branch condition.

Co-authored-by: KKRainbow <443152178@qq.com>
2026-04-28 23:23:41 +08:00
Luna Yao d5c4700d32 utils: replace defer, ContextGuard, DetachableTask with guarden crate (#2163) 2026-04-27 18:29:46 +08:00
KKRainbow 969ecfc4ca fix(gui): refresh service after core version upgrade (#2172) 2026-04-27 15:54:52 +08:00
KKRainbow 8f862997eb feat: support allocating public IPv6 addresses from a provider (#2162)
* feat: support allocating public IPv6 addresses from a provider

Add a provider/leaser architecture for public IPv6 address allocation
between nodes in the same network:

- A node with `--ipv6-public-addr-provider` advertises a delegable
  public IPv6 prefix (auto-detected from kernel routes or manually
  configured via `--ipv6-public-addr-prefix`).
- Other nodes with `--ipv6-public-addr-auto` request a /128 lease from
  the selected provider via a new RPC service (PublicIpv6AddrRpc).
- Leases have a 30s TTL, renewed every 10s by the client routine.
- The provider allocates addresses deterministically from its prefix
  using instance-UUID-based hashing to prefer stable assignments.
- Routes to peer leases are installed on the TUN device, and each
  client's own /128 is assigned as its IPv6 address.

Also includes netlink IPv6 route table inspection, integration tests,
and event-driven route/address reconciliation.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-26 21:37:34 +08:00
88 changed files with 6751 additions and 1349 deletions
+1 -1
View File
@@ -11,7 +11,7 @@ on:
image_tag:
description: 'Tag for this image build'
type: string
default: 'v2.6.2'
default: 'v2.6.3'
required: true
mark_latest:
description: 'Mark this image as latest'
+1 -1
View File
@@ -18,7 +18,7 @@ on:
version:
description: 'Version for this release'
type: string
default: 'v2.6.2'
default: 'v2.6.3'
required: true
make_latest:
description: 'Mark this release as latest'
Generated
+53 -25
View File
@@ -2229,7 +2229,7 @@ checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555"
[[package]]
name = "easytier"
version = "2.6.2"
version = "2.6.3"
dependencies = [
"aes-gcm",
"anyhow",
@@ -2273,6 +2273,7 @@ dependencies = [
"gethostname 0.5.0",
"git-version",
"globwalk",
"guarden",
"hickory-client",
"hickory-proto",
"hickory-resolver",
@@ -2287,6 +2288,7 @@ dependencies = [
"indoc",
"itertools 0.14.0",
"kcp-sys",
"lzokay-native",
"machine-uid",
"maplit",
"mimalloc",
@@ -2404,7 +2406,7 @@ dependencies = [
[[package]]
name = "easytier-gui"
version = "2.6.2"
version = "2.6.3"
dependencies = [
"anyhow",
"async-trait",
@@ -2456,6 +2458,7 @@ dependencies = [
"dashmap",
"easytier",
"futures",
"guarden",
"jsonwebtoken",
"mimalloc",
"mockall",
@@ -2484,7 +2487,7 @@ dependencies = [
[[package]]
name = "easytier-web"
version = "2.6.2"
version = "2.6.3"
dependencies = [
"anyhow",
"async-trait",
@@ -3590,6 +3593,28 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "guarden"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca87812d87fa82896df1adfb5c111cdeaae3edb6da028f5df002dcbd7df71454"
dependencies = [
"futures",
"guarden-macros",
"tokio",
]
[[package]]
name = "guarden-macros"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b42f4b8de91cbd793ce8e6cf8d4821ef02d2d5b4468e0a55a36c65c5581de53"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "h2"
version = "0.4.7"
@@ -3705,12 +3730,6 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "hermit-abi"
version = "0.5.2"
@@ -4026,7 +4045,7 @@ dependencies = [
"libc",
"percent-encoding",
"pin-project-lite",
"socket2 0.6.1",
"socket2 0.5.10",
"tokio",
"tower-service",
"tracing",
@@ -4695,9 +4714,9 @@ dependencies = [
[[package]]
name = "libc"
version = "0.2.172"
version = "0.2.186"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66"
[[package]]
name = "libdbus-sys"
@@ -4856,6 +4875,16 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
[[package]]
name = "lzokay-native"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "792ba667add2798c6c3e988e630f4eb921b5cbc735044825b7111ef1582c8730"
dependencies = [
"byteorder",
"thiserror 1.0.63",
]
[[package]]
name = "mac"
version = "0.1.1"
@@ -5043,14 +5072,13 @@ dependencies = [
[[package]]
name = "mio"
version = "1.0.2"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec"
checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1"
dependencies = [
"hermit-abi 0.3.9",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -6551,7 +6579,7 @@ checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218"
dependencies = [
"cfg-if",
"concurrent-queue",
"hermit-abi 0.5.2",
"hermit-abi",
"pin-project-lite",
"rustix 1.0.7",
"windows-sys 0.61.2",
@@ -8650,12 +8678,12 @@ dependencies = [
[[package]]
name = "socket2"
version = "0.6.1"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17129e116933cf371d018bb80ae557e889637989d8638274fb25622827b03881"
checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e"
dependencies = [
"libc",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -9774,9 +9802,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.48.0"
version = "1.52.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408"
checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6"
dependencies = [
"bytes",
"libc",
@@ -9784,7 +9812,7 @@ dependencies = [
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
"socket2 0.6.1",
"socket2 0.6.3",
"tokio-macros",
"tracing",
"windows-sys 0.61.2",
@@ -9792,9 +9820,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.6.0"
version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5"
checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496"
dependencies = [
"proc-macro2",
"quote",
+1 -1
View File
@@ -1,6 +1,6 @@
id=easytier_magisk
name=EasyTier_Magisk
version=v2.6.2
version=v2.6.3
versionCode=1
author=EasyTier
description=easytier magisk module @EasyTier(https://github.com/EasyTier/EasyTier)
@@ -12,6 +12,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1.0", features = ["v4", "serde"] }
guarden = "0.1"
# Axum web framework
axum = { version = "0.8.4", features = ["macros"] }
@@ -10,9 +10,9 @@ use easytier::{
common::config::{
ConfigFileControl, ConfigLoader, NetworkIdentity, PeerConfig, TomlConfigLoader,
},
defer,
instance_manager::NetworkInstanceManager,
};
use guarden::defer;
use serde::{Deserialize, Serialize};
use sqlx::any;
use tokio_util::task::AbortOnDropHandle;
+1 -1
View File
@@ -1,7 +1,7 @@
{
"name": "easytier-gui",
"type": "module",
"version": "2.6.2",
"version": "2.6.3",
"private": true,
"packageManager": "pnpm@9.12.1+sha512.e5a7e52a4183a02d5931057f7a0dbff9d5e9ce3161e33fa68ae392125b79282a8a8a470a51dfc8a0ed86221442eb2fb57019b0990ed24fab519bf0e1bc5ccfc4",
"scripts": {
+1 -1
View File
@@ -1,6 +1,6 @@
[package]
name = "easytier-gui"
version = "2.6.2"
version = "2.6.3"
description = "EasyTier GUI"
authors = ["you"]
edition.workspace = true
+1 -1
View File
@@ -17,7 +17,7 @@
"createUpdaterArtifacts": false
},
"productName": "easytier-gui",
"version": "2.6.2",
"version": "2.6.3",
"identifier": "com.kkrainbow.easytier",
"plugins": {
"shell": {
+1
View File
@@ -18,6 +18,7 @@ export interface ServiceMode extends WebClientConfig {
rpc_portal: string
file_log_level: 'off' | 'warn' | 'info' | 'debug' | 'trace'
file_log_dir: string
installed_core_version?: string
}
export interface RemoteMode {
+20 -3
View File
@@ -16,7 +16,7 @@ import { useToast, useConfirm } from 'primevue'
import { loadMode, saveMode, WebClientConfig, type Mode } from '~/composables/mode'
import { saveLastNetworkInstanceId, loadLastNetworkInstanceId } from '~/composables/config'
import ModeSwitcher from '~/components/ModeSwitcher.vue'
import { getServiceStatus } from '~/composables/backend'
import { getEasytierVersion, getServiceStatus } from '~/composables/backend'
const { t, locale } = useI18n()
const confirm = useConfirm()
@@ -85,6 +85,20 @@ async function onUninstallService() {
});
}
function stripModeMetadata(mode: Mode) {
if (mode.mode !== 'service') {
return mode
}
const serviceConfig = { ...mode }
delete serviceConfig.installed_core_version
return serviceConfig
}
function modeConfigChanged(next: Mode) {
return JSON.stringify(stripModeMetadata(next)) !== JSON.stringify(stripModeMetadata(currentMode.value))
}
async function onStopService() {
isModeSaving.value = true
manualDisconnect.value = true
@@ -134,13 +148,14 @@ async function initWithMode(mode: Mode) {
}
url = mode.remote_rpc_address
break;
case 'service':
case 'service': {
if (!mode.config_dir || !mode.file_log_dir || !mode.file_log_level || !mode.rpc_portal) {
toast.add({ severity: 'error', summary: t('error'), detail: t('mode.service_config_empty'), life: 10000 })
return initWithMode({ ...mode, mode: 'normal' });
}
let serviceStatus = await getServiceStatus()
if (serviceStatus === "NotInstalled" || JSON.stringify(mode) !== JSON.stringify(currentMode.value)) {
const coreVersion = await getEasytierVersion()
if (serviceStatus === "NotInstalled" || modeConfigChanged(mode) || mode.installed_core_version !== coreVersion) {
mode.config_server_url = mode.config_server_url || undefined
await initService({
config_dir: mode.config_dir,
@@ -149,6 +164,7 @@ async function initWithMode(mode: Mode) {
rpc_portal: mode.rpc_portal,
config_server: mode.config_server_url,
})
mode.installed_core_version = coreVersion
serviceStatus = await getServiceStatus()
}
if (serviceStatus === "Stopped") {
@@ -157,6 +173,7 @@ async function initWithMode(mode: Mode) {
url = "tcp://" + mode.rpc_portal.replace("0.0.0.0", "127.0.0.1")
retrys = 5
break;
}
case 'normal':
url = mode.rpc_portal;
break;
+1 -1
View File
@@ -1,6 +1,6 @@
[package]
name = "easytier-web"
version = "2.6.2"
version = "2.6.3"
edition.workspace = true
description = "Config server for easytier. easytier-core gets config from this and web frontend use it as restful api server."
@@ -81,6 +81,7 @@ const bool_flags: BoolFlag[] = [
{ field: 'latency_first', help: 'latency_first_help' },
{ field: 'use_smoltcp', help: 'use_smoltcp_help' },
{ field: 'disable_ipv6', help: 'disable_ipv6_help' },
{ field: 'ipv6_public_addr_auto', help: 'ipv6_public_addr_auto_help' },
{ field: 'enable_kcp_proxy', help: 'enable_kcp_proxy_help' },
{ field: 'disable_kcp_input', help: 'disable_kcp_input_help' },
{ field: 'enable_quic_proxy', help: 'enable_quic_proxy_help' },
@@ -98,6 +99,7 @@ const bool_flags: BoolFlag[] = [
{ field: 'disable_encryption', help: 'disable_encryption_help' },
{ field: 'disable_tcp_hole_punching', help: 'disable_tcp_hole_punching_help' },
{ field: 'disable_udp_hole_punching', help: 'disable_udp_hole_punching_help' },
{ field: 'disable_upnp', help: 'disable_upnp_help' },
{ field: 'disable_sym_hole_punching', help: 'disable_sym_hole_punching_help' },
{ field: 'enable_magic_dns', help: 'enable_magic_dns_help' },
{ field: 'enable_private_mode', help: 'enable_private_mode_help' },
@@ -2,7 +2,7 @@
import { AutoComplete, Button, Dialog, InputNumber, InputText } from 'primevue'
import InputGroup from 'primevue/inputgroup'
import InputGroupAddon from 'primevue/inputgroupaddon'
import { computed, onMounted, onUnmounted, ref, watch } from 'vue'
import { computed, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n'
const props = defineProps<{
@@ -13,25 +13,8 @@ const props = defineProps<{
const { t } = useI18n()
const url = defineModel<string>({ required: true })
const editing = ref(false)
const container = ref<HTMLElement | null>(null)
const internalCompact = ref(false)
const hostFocused = ref(false)
onMounted(() => {
if (container.value) {
const observer = new ResizeObserver(entries => {
for (const entry of entries) {
internalCompact.value = entry.contentRect.width < 400
}
})
observer.observe(container.value)
onUnmounted(() => {
observer.disconnect()
})
}
})
const parseUrl = (val: string | null | undefined): { proto: string; host: string; port: number | null } => {
const getValidPort = (portStr: string, proto: string) => {
const p = parseInt(portStr)
@@ -169,28 +152,30 @@ const onProtoChange = (newProto: string) => {
</script>
<template>
<div ref="container" class="w-full">
<InputGroup v-if="!internalCompact" class="w-full">
<div class="url-input-container w-full min-w-0 overflow-hidden">
<InputGroup class="url-input-full w-full min-w-0">
<AutoComplete :model-value="internalValue.proto" :suggestions="filteredProtos" dropdown
class="max-w-32 proto-autocomplete-in-group" @complete="searchProtos"
@update:model-value="onProtoChange" />
<InputText v-model="internalValue.host" :placeholder="placeholder || '0.0.0.0'" class="grow"
<InputText v-model="internalValue.host" :placeholder="placeholder || '0.0.0.0'" class="grow min-w-0"
@focus="onHostFocus" @blur="onHostBlur" />
<template v-if="!isNoPortProto">
<InputGroupAddon>
<span style="font-weight: bold">:</span>
</InputGroupAddon>
<InputNumber v-model="internalValue.port" :format="false" :min="1" :max="65535" class="max-w-24"
:placeholder="String(protos[internalValue.proto] ?? 11010)"
fluid />
:placeholder="String(protos[internalValue.proto] ?? 11010)" fluid />
</template>
<!-- Rendered in both responsive branches; keep action slot content free of side effects and duplicate IDs. -->
<slot name="actions"></slot>
</InputGroup>
<div v-else class="flex justify-between items-center p-2 border rounded w-full">
<span class="truncate mr-2">{{ url }}</span>
<div class="flex items-center">
<Button icon="pi pi-pencil" class="p-button-sm p-button-text" @click="editing = true" />
<div
class="url-input-compact flex justify-between items-center p-2 border rounded w-full min-w-0 overflow-hidden">
<span class="truncate mr-2 min-w-0 flex-1 overflow-hidden">{{ url }}</span>
<div class="flex items-center shrink-0">
<Button icon="pi pi-pencil" class="p-button-sm p-button-text" :aria-label="t('web.common.edit')"
@click="editing = true" />
<slot name="actions"></slot>
</div>
</div>
@@ -222,6 +207,28 @@ const onProtoChange = (newProto: string) => {
</template>
<style scoped>
.url-input-container {
container-type: inline-size;
}
.url-input-full {
display: none;
}
.url-input-compact {
display: flex;
}
@container (min-width: 400px) {
.url-input-full {
display: flex;
}
.url-input-compact {
display: none;
}
}
.proto-autocomplete-in-group,
.proto-autocomplete-in-group :deep(.p-autocomplete-input),
.proto-autocomplete-in-group :deep(.p-autocomplete-dropdown) {
@@ -104,6 +104,9 @@ use_smoltcp_help: 使用用户态 TCP/IP 协议栈,避免操作系统防火墙
disable_ipv6: 禁用IPv6
disable_ipv6_help: 禁用此节点的IPv6功能,仅使用IPv4进行网络通信。
ipv6_public_addr_auto: 自动获取公网 IPv6
ipv6_public_addr_auto_help: 自动从共享了 IPv6 子网的对等节点获取一个公网 IPv6 地址。
enable_kcp_proxy: 启用 KCP 代理
enable_kcp_proxy_help: 将 TCP 流量转为 KCP 流量,降低传输延迟,提升传输速度。
@@ -157,6 +160,9 @@ disable_tcp_hole_punching_help: 禁用TCP打洞功能
disable_udp_hole_punching: 禁用UDP打洞
disable_udp_hole_punching_help: 禁用UDP打洞功能
disable_upnp: 禁用 UPnP
disable_upnp_help: 禁用符合条件监听器的运行时 UPnP/NAT-PMP 端口映射;自动端口映射默认开启。
disable_sym_hole_punching: 禁用对称NAT打洞
disable_sym_hole_punching_help: 禁用对称NAT的打洞(生日攻击),将对称NAT视为锥形NAT处理
@@ -103,6 +103,9 @@ use_smoltcp_help: Use a user-space TCP/IP stack to avoid issues with operating s
disable_ipv6: Disable IPv6
disable_ipv6_help: Disable IPv6 functionality for this node, only use IPv4 for network communication.
ipv6_public_addr_auto: Auto Public IPv6
ipv6_public_addr_auto_help: Auto-obtain a public IPv6 address from a peer that shares its IPv6 subnet.
enable_kcp_proxy: Enable KCP Proxy
enable_kcp_proxy_help: Convert TCP traffic to KCP traffic to reduce latency and boost transmission speed.
@@ -156,6 +159,9 @@ disable_tcp_hole_punching_help: Disable tcp hole punching
disable_udp_hole_punching: Disable UDP Hole Punching
disable_udp_hole_punching_help: Disable udp hole punching
disable_upnp: Disable UPnP
disable_upnp_help: Disable runtime UPnP/NAT-PMP port mapping for eligible listeners; automatic port mapping is enabled by default.
disable_sym_hole_punching: Disable Symmetric NAT Hole Punching
disable_sym_hole_punching_help: Disable special hole punching handling for symmetric NAT (based on birthday attack), treat symmetric NAT as cone NAT
@@ -115,6 +115,7 @@ export interface NetworkConfig {
use_smoltcp?: boolean
disable_ipv6?: boolean
ipv6_public_addr_auto?: boolean
enable_kcp_proxy?: boolean
disable_kcp_input?: boolean
enable_quic_proxy?: boolean
@@ -132,6 +133,7 @@ export interface NetworkConfig {
disable_encryption?: boolean
disable_tcp_hole_punching?: boolean
disable_udp_hole_punching?: boolean
disable_upnp?: boolean
disable_sym_hole_punching?: boolean
enable_relay_network_whitelist?: boolean
@@ -190,6 +192,7 @@ export function DEFAULT_NETWORK_CONFIG(): NetworkConfig {
use_smoltcp: false,
disable_ipv6: false,
ipv6_public_addr_auto: false,
enable_kcp_proxy: false,
disable_kcp_input: false,
enable_quic_proxy: false,
@@ -207,6 +210,7 @@ export function DEFAULT_NETWORK_CONFIG(): NetworkConfig {
disable_encryption: false,
disable_tcp_hole_punching: false,
disable_udp_hole_punching: false,
disable_upnp: false,
disable_sym_hole_punching: false,
enable_relay_network_whitelist: false,
relay_network_whitelist: [],
+7 -1
View File
@@ -3,7 +3,7 @@ name = "easytier"
description = "A full meshed p2p VPN, connecting all your devices in one network with one command."
homepage = "https://github.com/EasyTier/EasyTier"
repository = "https://github.com/EasyTier/EasyTier"
version = "2.6.2"
version = "2.6.3"
edition.workspace = true
rust-version.workspace = true
authors = ["kkrainbow"]
@@ -50,6 +50,8 @@ time = "0.3"
toml = "0.8.12"
chrono = { version = "0.4.37", features = ["serde"] }
guarden = "0.1"
delegate = "0.13.5"
itertools = "0.14.0"
@@ -219,6 +221,7 @@ async-ringbuf = "0.3.1"
service-manager = { git = "https://github.com/EasyTier/service-manager-rs.git", branch = "main" }
zstd = { version = "0.13", optional = true }
lzokay-native = { version = "0.1", optional = true }
kcp-sys = { git = "https://github.com/EasyTier/kcp-sys", rev = "94964794caaed5d388463137da59b97499619e5f", optional = true }
@@ -356,6 +359,7 @@ default = [
"faketcp",
"magic-dns",
"zstd",
"lzo",
]
full = [
"websocket",
@@ -370,6 +374,7 @@ full = [
"faketcp",
"magic-dns",
"zstd",
"lzo",
]
wireguard = ["dep:boringtun", "dep:ring"]
quic = ["dep:quinn", "dep:quinn-plaintext", "dep:rustls", "dep:rcgen"]
@@ -400,5 +405,6 @@ tracing = ["tokio/tracing", "dep:console-subscriber"]
magic-dns = ["dep:hickory-client", "dep:hickory-server"]
faketcp = ["dep:flume"]
zstd = ["dep:zstd"]
lzo = ["dep:lzokay-native"]
# For Network Extension on macOS
macos-ne = []
+5
View File
@@ -191,6 +191,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
)
.type_attribute("peer_rpc.RouteForeignNetworkSummary", "#[derive(Hash, Eq)]")
.type_attribute("common.RpcDescriptor", "#[derive(Hash, Eq)]")
.type_attribute("acl.Acl", "#[serde(default)]")
.type_attribute("acl.AclV1", "#[serde(default)]")
.type_attribute("acl.Chain", "#[serde(default)]")
.type_attribute("acl.Rule", "#[serde(default)]")
.type_attribute("acl.GroupInfo", "#[serde(default)]")
.field_attribute(".api.manage.NetworkConfig", "#[serde(default)]")
.service_generator(Box::new(easytier_rpc_build::ServiceGenerator::default()))
.btree_map(["."])
+11 -2
View File
@@ -39,6 +39,15 @@ core_clap:
ipv6:
en: "ipv6 address of this vpn node, can be used together with ipv4 for dual-stack operation"
zh-CN: "此VPN节点的IPv6地址,可与IPv4一起使用以进行双栈操作"
ipv6_public_addr_provider:
en: "share this node's public IPv6 subnet with other peers so they can obtain public IPv6 addresses (Linux only)"
zh-CN: "将此节点的公网 IPv6 子网共享给其他节点,使它们也能获得公网 IPv6 地址(仅 Linux 支持)"
ipv6_public_addr_auto:
en: "auto-obtain a public IPv6 address from a peer that shares its IPv6 subnet"
zh-CN: "自动从共享了 IPv6 子网的对等节点获取一个公网 IPv6 地址"
ipv6_public_addr_prefix:
en: "manually specify the public IPv6 subnet to share, instead of auto-detecting from system routes"
zh-CN: "手动指定要共享的公网 IPv6 子网,不自动从系统路由检测"
dhcp:
en: "automatically determine and set IP address by Easytier, and the IP address starts from 10.0.0.1 by default. Warning, if there is an IP conflict in the network when using DHCP, the IP will be automatically changed."
zh-CN: "由Easytier自动确定并设置IP地址,默认从10.0.0.1开始。警告:在使用DHCP时,如果网络中出现IP冲突,IP将自动更改。"
@@ -185,8 +194,8 @@ core_clap:
en: "the url of the ipv6 listener, e.g.: tcp://[::]:11010, if not set, will listen on random udp port"
zh-CN: "IPv6 监听器的URL,例如:tcp://[::]:11010,如果未设置,将在随机UDP端口上监听"
compression:
en: "compression algorithm to use, support none, zstd. default is none"
zh-CN: "要使用的压缩算法,支持 none、zstd。默认为 none"
en: "compression algorithm to use, supported: %{algorithms}. default is none"
zh-CN: "要使用的压缩算法,支持%{algorithms}。默认为 none"
mapped_listeners:
en: "manually specify the public address of the listener, other nodes can use this address to connect to this node. e.g.: tcp://123.123.123.123:11223, can specify multiple."
zh-CN: "手动指定监听器的公网地址,其他节点可以使用该地址连接到本节点。例如:tcp://123.123.123.123:11223,可以指定多个。"
+7 -6
View File
@@ -137,12 +137,13 @@ pub fn setup_socket_for_win<S: AsRawSocket>(
}
let socket = SOCKET(socket.as_raw_socket() as usize);
let optval = 1_i32.to_ne_bytes();
unsafe {
if setsockopt(socket, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, Some(&optval)) == SOCKET_ERROR {
return Err(io::Error::last_os_error());
}
}
// let optval = 1_i32.to_ne_bytes();
// unsafe {
// if setsockopt(socket, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, Some(&optval)) == SOCKET_ERROR {
// return Err(io::Error::last_os_error());
// }
// }
if let Some(iface) = bind_dev {
set_ip_unicast_if(socket, bind_addr, &iface)?;
+39
View File
@@ -1339,6 +1339,45 @@ mod tests {
assert_eq!(result.matched_rule, Some(RuleId::Priority(70)));
}
#[tokio::test]
async fn test_forward_acl_source_ip_whitelist() {
let mut acl_config = Acl::default();
let mut acl_v1 = AclV1::default();
let mut chain = Chain {
name: "subnet_proxy_protect".to_string(),
chain_type: ChainType::Forward as i32,
enabled: true,
default_action: Action::Drop as i32,
..Default::default()
};
chain.rules.push(Rule {
name: "allow_my_devices".to_string(),
priority: 1000,
enabled: true,
action: Action::Allow as i32,
protocol: Protocol::Any as i32,
source_ips: vec!["10.172.192.2/32".to_string()],
..Default::default()
});
acl_v1.chains.push(chain);
acl_config.acl_v1 = Some(acl_v1);
let processor = AclProcessor::new(acl_config);
let mut packet_info = create_test_packet_info();
packet_info.dst_ip = "192.168.1.10".parse().unwrap();
packet_info.src_ip = "10.172.192.2".parse().unwrap();
let result = processor.process_packet(&packet_info, ChainType::Forward);
assert_eq!(result.action, Action::Allow);
assert_eq!(result.matched_rule, Some(RuleId::Priority(1000)));
packet_info.src_ip = "10.172.192.3".parse().unwrap();
let result = processor.process_packet(&packet_info, ChainType::Forward);
assert_eq!(result.action, Action::Drop);
assert_eq!(result.matched_rule, Some(RuleId::Default));
}
fn create_test_acl_config() -> Acl {
let mut acl_config = Acl::default();
+46 -10
View File
@@ -1,4 +1,4 @@
#[cfg(feature = "zstd")]
#[cfg(any(feature = "zstd", feature = "lzo"))]
use anyhow::Context;
#[cfg(feature = "zstd")]
use dashmap::DashMap;
@@ -53,6 +53,13 @@ impl DefaultCompressor {
)
})
}),
#[cfg(feature = "lzo")]
CompressorAlgo::Lzo => lzokay_native::compress(data).with_context(|| {
format!(
"Failed to compress data with algorithm: {:?}",
compress_algo
)
}),
CompressorAlgo::None => Ok(data.to_vec()),
}
}
@@ -85,6 +92,13 @@ impl DefaultCompressor {
compress_algo
))
}),
#[cfg(feature = "lzo")]
CompressorAlgo::Lzo => lzokay_native::decompress_all(data, None).with_context(|| {
format!(
"Failed to decompress data with algorithm: {:?}",
compress_algo
)
}),
CompressorAlgo::None => Ok(data.to_vec()),
}
}
@@ -181,14 +195,13 @@ thread_local! {
static DCTX_MAP: RefCell<DashMap<CompressorAlgo, bulk::Decompressor<'static>>> = RefCell::new(DashMap::new());
}
#[cfg(all(test, feature = "zstd"))]
#[cfg(all(test, any(feature = "zstd", feature = "lzo")))]
pub mod tests {
use super::*;
#[tokio::test]
async fn test_compress() {
let text = b"12345670000000000000000000";
let mut packet = ZCPacket::new_with_payload(text);
async fn test_compress_algo(compress_algo: CompressorAlgo) {
let text = vec![b'a'; 4096];
let mut packet = ZCPacket::new_with_payload(&text);
packet.fill_peer_manager_hdr(0, 0, 0);
let compressor = DefaultCompressor {};
@@ -200,7 +213,7 @@ pub mod tests {
);
compressor
.compress(&mut packet, CompressorAlgo::ZstdDefault)
.compress(&mut packet, compress_algo)
.await
.unwrap();
println!(
@@ -215,8 +228,7 @@ pub mod tests {
assert!(!packet.peer_manager_header().unwrap().is_compressed());
}
#[tokio::test]
async fn test_short_text_compress() {
async fn test_short_text_compress_algo(compress_algo: CompressorAlgo) {
let text = b"1234";
let mut packet = ZCPacket::new_with_payload(text);
packet.fill_peer_manager_hdr(0, 0, 0);
@@ -225,7 +237,7 @@ pub mod tests {
// short text can't be compressed
compressor
.compress(&mut packet, CompressorAlgo::ZstdDefault)
.compress(&mut packet, compress_algo)
.await
.unwrap();
assert!(!packet.peer_manager_header().unwrap().is_compressed());
@@ -234,4 +246,28 @@ pub mod tests {
assert_eq!(packet.payload(), text);
assert!(!packet.peer_manager_header().unwrap().is_compressed());
}
#[cfg(feature = "zstd")]
#[tokio::test]
async fn test_zstd_compress() {
test_compress_algo(CompressorAlgo::ZstdDefault).await;
}
#[cfg(feature = "zstd")]
#[tokio::test]
async fn test_zstd_short_text_compress() {
test_short_text_compress_algo(CompressorAlgo::ZstdDefault).await;
}
#[cfg(feature = "lzo")]
#[tokio::test]
async fn test_lzo_compress() {
test_compress_algo(CompressorAlgo::Lzo).await;
}
#[cfg(feature = "lzo")]
#[tokio::test]
async fn test_lzo_short_text_compress() {
test_short_text_compress_algo(CompressorAlgo::Lzo).await;
}
}
+135
View File
@@ -71,6 +71,7 @@ pub fn gen_default_flags() -> Flags {
need_p2p: false,
instance_recv_bps_limit: u64::MAX,
disable_upnp: false,
disable_relay_data: false,
}
}
@@ -170,6 +171,15 @@ pub trait ConfigLoader: Send + Sync {
fn get_ipv6(&self) -> Option<cidr::Ipv6Inet>;
fn set_ipv6(&self, addr: Option<cidr::Ipv6Inet>);
fn get_ipv6_public_addr_provider(&self) -> bool;
fn set_ipv6_public_addr_provider(&self, enabled: bool);
fn get_ipv6_public_addr_auto(&self) -> bool;
fn set_ipv6_public_addr_auto(&self, enabled: bool);
fn get_ipv6_public_addr_prefix(&self) -> Option<cidr::Ipv6Cidr>;
fn set_ipv6_public_addr_prefix(&self, prefix: Option<cidr::Ipv6Cidr>);
fn get_dhcp(&self) -> bool;
fn set_dhcp(&self, dhcp: bool);
@@ -519,6 +529,9 @@ struct Config {
instance_id: Option<uuid::Uuid>,
ipv4: Option<String>,
ipv6: Option<String>,
ipv6_public_addr_provider: Option<bool>,
ipv6_public_addr_auto: Option<bool>,
ipv6_public_addr_prefix: Option<String>,
dhcp: Option<bool>,
network_identity: Option<NetworkIdentity>,
listeners: Option<Vec<url::Url>>,
@@ -700,6 +713,43 @@ impl ConfigLoader for TomlConfigLoader {
self.config.lock().unwrap().ipv6 = addr.map(|addr| addr.to_string());
}
fn get_ipv6_public_addr_provider(&self) -> bool {
self.config
.lock()
.unwrap()
.ipv6_public_addr_provider
.unwrap_or_default()
}
fn set_ipv6_public_addr_provider(&self, enabled: bool) {
self.config.lock().unwrap().ipv6_public_addr_provider = Some(enabled);
}
fn get_ipv6_public_addr_auto(&self) -> bool {
self.config
.lock()
.unwrap()
.ipv6_public_addr_auto
.unwrap_or_default()
}
fn set_ipv6_public_addr_auto(&self, enabled: bool) {
self.config.lock().unwrap().ipv6_public_addr_auto = Some(enabled);
}
fn get_ipv6_public_addr_prefix(&self) -> Option<cidr::Ipv6Cidr> {
let locked_config = self.config.lock().unwrap();
locked_config
.ipv6_public_addr_prefix
.as_ref()
.and_then(|s| s.parse().ok())
}
fn set_ipv6_public_addr_prefix(&self, prefix: Option<cidr::Ipv6Cidr>) {
self.config.lock().unwrap().ipv6_public_addr_prefix =
prefix.map(|prefix| prefix.to_string());
}
fn get_dhcp(&self) -> bool {
self.config.lock().unwrap().dhcp.unwrap_or_default()
}
@@ -1287,6 +1337,71 @@ stun_servers = [
assert!(err.to_string().contains("mapped listener port is missing"));
}
#[test]
fn test_acl_toml_rule_uses_defaults_for_omitted_fields() {
use crate::proto::acl::{Action, ChainType, Protocol};
let config_str = r#"
[[acl.acl_v1.chains]]
name = "subnet_proxy_protect"
chain_type = 3
enabled = true
default_action = 2
[[acl.acl_v1.chains.rules]]
name = "allow_my_devices"
priority = 1000
action = 1
source_ips = ["10.172.192.2/32"]
protocol = 5
enabled = true
"#;
let config = TomlConfigLoader::new_from_str(config_str).unwrap();
let acl = config.get_acl().unwrap();
let acl_v1 = acl.acl_v1.unwrap();
let chain = &acl_v1.chains[0];
let rule = &chain.rules[0];
assert_eq!(chain.chain_type, ChainType::Forward as i32);
assert_eq!(chain.default_action, Action::Drop as i32);
assert_eq!(rule.action, Action::Allow as i32);
assert_eq!(rule.protocol, Protocol::Any as i32);
assert_eq!(rule.source_ips, vec!["10.172.192.2/32"]);
assert!(rule.ports.is_empty());
assert!(rule.source_ports.is_empty());
assert!(rule.destination_ips.is_empty());
assert!(rule.source_groups.is_empty());
assert!(rule.destination_groups.is_empty());
assert_eq!(rule.rate_limit, 0);
assert_eq!(rule.burst_limit, 0);
assert!(!rule.stateful);
}
#[test]
fn test_acl_toml_group_can_omit_declares_or_members() {
let declares_only = r#"
[acl.acl_v1.group]
[[acl.acl_v1.group.declares]]
group_name = "admin"
group_secret = "admin-pw"
"#;
let config = TomlConfigLoader::new_from_str(declares_only).unwrap();
let group = config.get_acl().unwrap().acl_v1.unwrap().group.unwrap();
assert_eq!(group.declares.len(), 1);
assert!(group.members.is_empty());
let members_only = r#"
[acl.acl_v1.group]
members = ["admin"]
"#;
let config = TomlConfigLoader::new_from_str(members_only).unwrap();
let group = config.get_acl().unwrap().acl_v1.unwrap().group.unwrap();
assert!(group.declares.is_empty());
assert_eq!(group.members, vec!["admin"]);
}
#[test]
fn test_network_config_source_user_is_implicit() {
let config = TomlConfigLoader::default();
@@ -1312,6 +1427,26 @@ source = "user"
assert!(!explicit_user.dump().contains("[source]"));
}
#[test]
fn test_ipv6_public_addr_config_roundtrip() {
let config = TomlConfigLoader::default();
let prefix: cidr::Ipv6Cidr = "2001:db8:100::/64".parse().unwrap();
config.set_ipv6_public_addr_provider(true);
config.set_ipv6_public_addr_auto(true);
config.set_ipv6_public_addr_prefix(Some(prefix));
assert!(config.get_ipv6_public_addr_provider());
assert!(config.get_ipv6_public_addr_auto());
assert_eq!(config.get_ipv6_public_addr_prefix(), Some(prefix));
let dumped = config.dump();
let loaded = TomlConfigLoader::new_from_str(&dumped).unwrap();
assert!(loaded.get_ipv6_public_addr_provider());
assert!(loaded.get_ipv6_public_addr_auto());
assert_eq!(loaded.get_ipv6_public_addr_prefix(), Some(prefix));
}
#[tokio::test]
async fn full_example_test() {
let config_str = r#"
+20 -12
View File
@@ -73,16 +73,6 @@ pub async fn socket_addrs(
.port()
.or_else(default_port_number)
.ok_or(Error::InvalidUrl(url.to_string()))?;
// See https://github.com/EasyTier/EasyTier/pull/947
// here is for compatibility with old version
let port = match port {
0 => match url.scheme() {
"ws" => 80,
"wss" => 443,
_ => port,
},
_ => port,
};
// if host is an ip address, return it directly
match host {
@@ -121,9 +111,8 @@ pub async fn socket_addrs(
#[cfg(test)]
mod tests {
use crate::defer;
use super::*;
use guarden::defer;
#[tokio::test]
async fn test_socket_addrs() {
@@ -140,4 +129,23 @@ mod tests {
assert_eq!(2, addrs.len(), "addrs: {:?}", addrs);
println!("addrs2: {:?}", addrs);
}
#[tokio::test]
async fn socket_addrs_preserves_explicit_zero_port() {
let cases = [
("ws://127.0.0.1:0", 80, 0),
("wss://127.0.0.1:0", 443, 0),
("ws://127.0.0.1", 80, 80),
("wss://127.0.0.1", 443, 443),
];
for (raw_url, default_port, expected_port) in cases {
let url = url::Url::parse(raw_url).unwrap();
let addrs = socket_addrs(&url, || Some(default_port)).await.unwrap();
assert_eq!(
addrs,
vec![SocketAddr::from(([127, 0, 0, 1], expected_port))]
);
}
}
}
+237 -15
View File
@@ -1,5 +1,5 @@
use std::{
collections::{HashMap, hash_map::DefaultHasher},
collections::{BTreeSet, HashMap, hash_map::DefaultHasher},
hash::Hasher,
net::{IpAddr, SocketAddr},
sync::{Arc, Mutex},
@@ -68,6 +68,8 @@ pub enum GlobalCtxEvent {
DhcpIpv4Changed(Option<cidr::Ipv4Inet>, Option<cidr::Ipv4Inet>), // (old, new)
DhcpIpv4Conflicted(Option<cidr::Ipv4Inet>),
PublicIpv6Changed(Option<cidr::Ipv6Inet>, Option<cidr::Ipv6Inet>), // (old, new)
PublicIpv6RoutesUpdated(Vec<cidr::Ipv6Inet>, Vec<cidr::Ipv6Inet>), // (added, removed)
PortForwardAdded(PortForwardConfigPb),
@@ -200,6 +202,8 @@ pub struct GlobalCtx {
cached_ipv4: AtomicCell<Option<cidr::Ipv4Inet>>,
cached_ipv6: AtomicCell<Option<cidr::Ipv6Inet>>,
public_ipv6_lease: AtomicCell<Option<cidr::Ipv6Inet>>,
public_ipv6_routes: Mutex<BTreeSet<std::net::Ipv6Addr>>,
cached_proxy_cidrs: AtomicCell<Option<Vec<ProxyNetworkConfig>>>,
ip_collector: Mutex<Option<Arc<IPCollector>>>,
@@ -209,9 +213,16 @@ pub struct GlobalCtx {
stun_info_collection: Mutex<Arc<dyn StunInfoCollectorTrait>>,
running_listeners: Mutex<Vec<url::Url>>,
advertised_ipv6_public_addr_prefix: Mutex<Option<cidr::Ipv6Cidr>>,
flags: ArcSwap<Flags>,
// Runtime/base advertised feature flags before config-owned fields are
// overlaid by set_flags. Keep this separate so config patches do not erase
// runtime state such as public-server role, IPv6 provider status, or the
// non-whitelist avoid-relay preference.
base_feature_flags: AtomicCell<PeerFeatureFlag>,
feature_flags: AtomicCell<PeerFeatureFlag>,
token_bucket_manager: TokenBucketManager,
@@ -242,8 +253,17 @@ impl std::fmt::Debug for GlobalCtx {
pub type ArcGlobalCtx = std::sync::Arc<GlobalCtx>;
impl GlobalCtx {
fn derive_feature_flags(flags: &Flags, current: Option<PeerFeatureFlag>) -> PeerFeatureFlag {
let mut feature_flags = current.unwrap_or_default();
fn apply_disable_relay_data_flag(
flags: &Flags,
mut feature_flags: PeerFeatureFlag,
) -> PeerFeatureFlag {
if flags.disable_relay_data {
feature_flags.avoid_relay_data = true;
}
feature_flags
}
fn derive_feature_flags(flags: &Flags, mut feature_flags: PeerFeatureFlag) -> PeerFeatureFlag {
feature_flags.kcp_input = !flags.disable_kcp_input;
feature_flags.no_relay_kcp = flags.disable_relay_kcp;
feature_flags.support_conn_list_sync = true;
@@ -251,7 +271,7 @@ impl GlobalCtx {
feature_flags.no_relay_quic = flags.disable_relay_quic;
feature_flags.need_p2p = flags.need_p2p;
feature_flags.disable_p2p = flags.disable_p2p;
feature_flags
Self::apply_disable_relay_data_flag(flags, feature_flags)
}
pub fn new(config_fs: impl ConfigLoader + 'static) -> Self {
@@ -280,7 +300,8 @@ impl GlobalCtx {
let flags = config_fs.get_flags();
let feature_flags = Self::derive_feature_flags(&flags, None);
let base_feature_flags = PeerFeatureFlag::default();
let feature_flags = Self::derive_feature_flags(&flags, base_feature_flags);
let credential_storage_path = config_fs.get_credential_file();
let credential_manager = Arc::new(CredentialManager::new(credential_storage_path));
@@ -295,6 +316,8 @@ impl GlobalCtx {
event_bus,
cached_ipv4: AtomicCell::new(None),
cached_ipv6: AtomicCell::new(None),
public_ipv6_lease: AtomicCell::new(None),
public_ipv6_routes: Mutex::new(BTreeSet::new()),
cached_proxy_cidrs: AtomicCell::new(None),
ip_collector: Mutex::new(Some(Arc::new(IPCollector::new(
@@ -307,9 +330,12 @@ impl GlobalCtx {
stun_info_collection: Mutex::new(stun_info_collector),
running_listeners: Mutex::new(Vec::new()),
advertised_ipv6_public_addr_prefix: Mutex::new(None),
flags: ArcSwap::new(Arc::new(flags)),
base_feature_flags: AtomicCell::new(base_feature_flags),
feature_flags: AtomicCell::new(feature_flags),
token_bucket_manager: TokenBucketManager::new(),
@@ -381,6 +407,45 @@ impl GlobalCtx {
self.cached_ipv6.store(None);
}
pub fn get_public_ipv6_lease(&self) -> Option<cidr::Ipv6Inet> {
self.public_ipv6_lease.load()
}
pub fn set_public_ipv6_lease(&self, addr: Option<cidr::Ipv6Inet>) {
self.public_ipv6_lease.store(addr);
}
pub fn set_public_ipv6_routes(&self, routes: BTreeSet<cidr::Ipv6Inet>) {
*self.public_ipv6_routes.lock().unwrap() =
routes.into_iter().map(|route| route.address()).collect();
}
pub fn is_ip_local_ipv6(&self, ip: &std::net::Ipv6Addr) -> bool {
self.get_ipv6().map(|x| x.address() == *ip).unwrap_or(false)
|| self
.get_public_ipv6_lease()
.map(|x| x.address() == *ip)
.unwrap_or(false)
}
pub fn is_ip_easytier_managed_ipv6(&self, ip: &std::net::Ipv6Addr) -> bool {
self.is_ip_local_ipv6(ip) || self.public_ipv6_routes.lock().unwrap().contains(ip)
}
pub fn get_advertised_ipv6_public_addr_prefix(&self) -> Option<cidr::Ipv6Cidr> {
*self.advertised_ipv6_public_addr_prefix.lock().unwrap()
}
pub fn set_advertised_ipv6_public_addr_prefix(&self, prefix: Option<cidr::Ipv6Cidr>) -> bool {
let mut guard = self.advertised_ipv6_public_addr_prefix.lock().unwrap();
if *guard == prefix {
return false;
}
*guard = prefix;
true
}
pub fn get_id(&self) -> uuid::Uuid {
self.config.get_id()
}
@@ -395,7 +460,7 @@ impl GlobalCtx {
pub fn is_ip_local_virtual_ip(&self, ip: &IpAddr) -> bool {
match ip {
IpAddr::V4(v4) => self.get_ipv4().map(|x| x.address() == *v4).unwrap_or(false),
IpAddr::V6(v6) => self.get_ipv6().map(|x| x.address() == *v6).unwrap_or(false),
IpAddr::V6(v6) => self.is_ip_local_ipv6(v6),
}
}
@@ -466,7 +531,7 @@ impl GlobalCtx {
self.config.set_flags(flags.clone());
self.feature_flags.store(Self::derive_feature_flags(
&flags,
Some(self.feature_flags.load()),
self.base_feature_flags.load(),
));
self.flags.store(Arc::new(flags));
}
@@ -531,8 +596,53 @@ impl GlobalCtx {
self.feature_flags.load()
}
pub fn set_feature_flags(&self, flags: PeerFeatureFlag) {
self.feature_flags.store(flags);
/// Replace the runtime/base advertised flags as a complete snapshot.
///
/// This is intended for foreign scoped contexts that inherit an already
/// computed feature-flag snapshot from their parent. Most callers should use
/// a narrower setter so they do not accidentally overwrite unrelated runtime
/// state.
pub fn set_base_advertised_feature_flags(&self, feature_flags: PeerFeatureFlag) {
self.base_feature_flags.store(feature_flags);
let flags = self.flags.load();
self.feature_flags
.store(Self::apply_disable_relay_data_flag(
flags.as_ref(),
feature_flags,
));
}
/// Set the avoid-relay preference that is independent of disable_relay_data.
///
/// disable_relay_data still forces the effective advertised flag to true,
/// but this base preference is preserved when that config flag is toggled.
pub fn set_avoid_relay_data_preference(&self, avoid_relay_data: bool) -> bool {
let mut base_feature_flags = self.base_feature_flags.load();
base_feature_flags.avoid_relay_data = avoid_relay_data;
self.base_feature_flags.store(base_feature_flags);
let mut feature_flags = self.feature_flags.load();
let previous = feature_flags.avoid_relay_data;
feature_flags.avoid_relay_data = avoid_relay_data || self.flags.load().disable_relay_data;
self.feature_flags.store(feature_flags);
previous != feature_flags.avoid_relay_data
}
/// Set the runtime IPv6-provider advertised bit without touching
/// config-derived feature flags.
pub fn set_ipv6_public_addr_provider_feature_flag(&self, enabled: bool) -> bool {
let mut base_feature_flags = self.base_feature_flags.load();
base_feature_flags.ipv6_public_addr_provider = enabled;
self.base_feature_flags.store(base_feature_flags);
let mut feature_flags = self.feature_flags.load();
if feature_flags.ipv6_public_addr_provider == enabled {
return false;
}
feature_flags.ipv6_public_addr_provider = enabled;
self.feature_flags.store(feature_flags);
true
}
pub fn token_bucket_manager(&self) -> &TokenBucketManager {
@@ -645,23 +755,23 @@ impl GlobalCtx {
pub fn should_deny_proxy(&self, dst_addr: &SocketAddr, is_udp: bool) -> bool {
let _g = self.net_ns.guard();
let ip = dst_addr.ip();
// first check if ip is virtual ip
// first check if ip is an EasyTier-managed local address
// then try bind this ip, if succ means it is local ip
let dst_is_local_virtual_ip = self.is_ip_local_virtual_ip(&ip);
let dst_is_local_et_ip = self.is_ip_local_virtual_ip(&ip);
// this is an expensive operation, should be called sparingly
// 1. tcp/kcp/quic call this only after proxy conn is established
// 2. udp cache the result in nat entry
let dst_is_local_phy_ip = std::net::UdpSocket::bind(format!("{}:0", ip)).is_ok();
tracing::trace!(
"check should_deny_proxy: dst_addr={}, dst_is_local_virtual_ip={}, dst_is_local_phy_ip={}, is_udp={}",
"check should_deny_proxy: dst_addr={}, dst_is_local_et_ip={}, dst_is_local_phy_ip={}, is_udp={}",
dst_addr,
dst_is_local_virtual_ip,
dst_is_local_et_ip,
dst_is_local_phy_ip,
is_udp
);
if dst_is_local_virtual_ip || dst_is_local_phy_ip {
if dst_is_local_et_ip || dst_is_local_phy_ip {
// if is local ip, make sure the port is not one of the listening ports
self.is_port_in_running_listeners(dst_addr.port(), is_udp)
|| (!is_udp && protected_port::is_protected_tcp_port(dst_addr.port()))
@@ -749,7 +859,7 @@ pub mod tests {
let mut feature_flags = global_ctx.get_feature_flags();
feature_flags.avoid_relay_data = true;
feature_flags.is_public_server = true;
global_ctx.set_feature_flags(feature_flags);
global_ctx.set_base_advertised_feature_flags(feature_flags);
let mut flags = global_ctx.get_flags().clone();
flags.disable_kcp_input = true;
@@ -770,6 +880,84 @@ pub mod tests {
assert!(feature_flags.support_conn_list_sync);
assert!(feature_flags.avoid_relay_data);
assert!(feature_flags.is_public_server);
assert!(!feature_flags.ipv6_public_addr_provider);
}
#[tokio::test]
async fn set_base_advertised_feature_flags_applies_current_values() {
let config = TomlConfigLoader::default();
let global_ctx = GlobalCtx::new(config);
let feature_flags = PeerFeatureFlag {
kcp_input: false,
no_relay_kcp: true,
quic_input: false,
no_relay_quic: true,
is_public_server: true,
..Default::default()
};
global_ctx.set_base_advertised_feature_flags(feature_flags);
assert_eq!(global_ctx.get_feature_flags(), feature_flags);
}
#[tokio::test]
async fn set_base_advertised_feature_flags_keeps_disable_relay_data_effective() {
let config = TomlConfigLoader::default();
let global_ctx = GlobalCtx::new(config);
let mut flags = global_ctx.get_flags().clone();
flags.disable_relay_data = true;
global_ctx.set_flags(flags);
let mut feature_flags = global_ctx.get_feature_flags();
feature_flags.avoid_relay_data = false;
feature_flags.is_public_server = true;
global_ctx.set_base_advertised_feature_flags(feature_flags);
let advertised_feature_flags = global_ctx.get_feature_flags();
assert!(advertised_feature_flags.avoid_relay_data);
assert!(advertised_feature_flags.is_public_server);
let mut flags = global_ctx.get_flags().clone();
flags.disable_relay_data = false;
global_ctx.set_flags(flags);
let advertised_feature_flags = global_ctx.get_feature_flags();
assert!(!advertised_feature_flags.avoid_relay_data);
assert!(advertised_feature_flags.is_public_server);
}
#[tokio::test]
async fn disable_relay_data_sets_avoid_relay_feature_flag() {
let config = TomlConfigLoader::default();
let global_ctx = GlobalCtx::new(config);
let mut flags = global_ctx.get_flags().clone();
flags.disable_relay_data = true;
global_ctx.set_flags(flags);
assert!(global_ctx.get_feature_flags().avoid_relay_data);
let mut flags = global_ctx.get_flags().clone();
flags.disable_relay_data = false;
global_ctx.set_flags(flags);
assert!(!global_ctx.get_feature_flags().avoid_relay_data);
global_ctx.set_avoid_relay_data_preference(true);
let mut flags = global_ctx.get_flags().clone();
flags.disable_relay_data = true;
global_ctx.set_flags(flags);
assert!(global_ctx.get_feature_flags().avoid_relay_data);
let mut flags = global_ctx.get_flags().clone();
flags.disable_relay_data = false;
global_ctx.set_flags(flags);
assert!(global_ctx.get_feature_flags().avoid_relay_data);
}
#[tokio::test]
@@ -789,6 +977,40 @@ pub mod tests {
protected_port::clear_protected_tcp_ports_for_test();
}
#[tokio::test]
async fn virtual_ipv6_and_public_ipv6_lease_are_stored_separately() {
let config = TomlConfigLoader::default();
let global_ctx = GlobalCtx::new(config);
let virtual_ipv6 = "fd00::1/64".parse().unwrap();
let public_ipv6 = "2001:db8::2/64".parse().unwrap();
global_ctx.set_ipv6(Some(virtual_ipv6));
global_ctx.set_public_ipv6_lease(Some(public_ipv6));
assert_eq!(global_ctx.get_ipv6(), Some(virtual_ipv6));
assert_eq!(global_ctx.get_public_ipv6_lease(), Some(public_ipv6));
}
#[tokio::test]
async fn public_ipv6_lease_is_treated_as_local_ip() {
protected_port::clear_protected_tcp_ports_for_test();
let config = TomlConfigLoader::default();
let global_ctx = GlobalCtx::new(config);
let public_ipv6 = "2001:db8::2/64".parse().unwrap();
let listener: url::Url = "tcp://[2001:db8::2]:11010".parse().unwrap();
global_ctx.set_public_ipv6_lease(Some(public_ipv6));
global_ctx.add_running_listener(listener);
let ip = std::net::IpAddr::V6(public_ipv6.address());
let socket = SocketAddr::from((public_ipv6.address(), 11010));
assert!(global_ctx.is_ip_local_virtual_ip(&ip));
assert!(global_ctx.should_deny_proxy(&socket, false));
protected_port::clear_protected_tcp_ports_for_test();
}
pub fn get_mock_global_ctx_with_network(
network_identy: Option<NetworkIdentity>,
) -> ArcGlobalCtx {
+11
View File
@@ -166,3 +166,14 @@ pub type IfConfiger = DummyIfConfiger;
#[cfg(target_os = "windows")]
pub use windows::RegistryManager;
#[cfg(target_os = "linux")]
pub(crate) fn list_ipv6_route_messages()
-> Result<Vec<netlink_packet_route::route::RouteMessage>, Error> {
netlink::NetlinkIfConfiger::list_ipv6_route_messages()
}
#[cfg(target_os = "linux")]
pub(crate) fn get_interface_index(name: &str) -> Result<u32, Error> {
netlink::NetlinkIfConfiger::get_interface_index(name)
}
+200 -16
View File
@@ -160,7 +160,7 @@ impl From<RouteMessage> for Route {
pub struct NetlinkIfConfiger {}
impl NetlinkIfConfiger {
fn get_interface_index(name: &str) -> Result<u32, Error> {
pub(crate) fn get_interface_index(name: &str) -> Result<u32, Error> {
let name = CString::new(name).with_context(|| "failed to convert interface name")?;
match unsafe { libc::if_nametoindex(name.as_ptr()) } {
0 => Err(std::io::Error::last_os_error().into()),
@@ -311,7 +311,7 @@ impl NetlinkIfConfiger {
Self::set_flags_op(name, SIOCGIFFLAGS, InterfaceFlags::empty())
}
fn list_routes() -> Result<Vec<RouteMessage>, Error> {
fn list_route_messages(address_family: AddressFamily) -> Result<Vec<RouteMessage>, Error> {
let mut message = RouteMessage::default();
message.header.table = RouteHeader::RT_TABLE_UNSPEC;
@@ -320,7 +320,7 @@ impl NetlinkIfConfiger {
message.header.scope = RouteScope::Universe;
message.header.kind = RouteType::Unicast;
message.header.address_family = AddressFamily::Inet;
message.header.address_family = address_family;
message.header.destination_prefix_length = 0;
message.header.source_prefix_length = 0;
@@ -367,6 +367,14 @@ impl NetlinkIfConfiger {
Ok(ret_vec)
}
fn list_routes() -> Result<Vec<RouteMessage>, Error> {
Self::list_route_messages(AddressFamily::Inet)
}
pub(crate) fn list_ipv6_route_messages() -> Result<Vec<RouteMessage>, Error> {
Self::list_route_messages(AddressFamily::Inet6)
}
}
#[async_trait]
@@ -551,12 +559,9 @@ impl IfConfiguerTrait for NetlinkIfConfiger {
message.header.scope = RouteScope::Universe;
message.header.kind = RouteType::Unicast;
// Add metric (cost) if specified
if let Some(cost) = cost {
message
.attributes
.push(RouteAttribute::Priority(cost as u32));
}
message
.attributes
.push(RouteAttribute::Priority(cost.unwrap_or(65535) as u32));
message
.attributes
@@ -564,9 +569,11 @@ impl IfConfiguerTrait for NetlinkIfConfiger {
name,
)?));
message
.attributes
.push(RouteAttribute::Destination(RouteAddress::Inet6(address)));
if cidr_prefix != 0 {
message
.attributes
.push(RouteAttribute::Destination(RouteAddress::Inet6(address)));
}
send_netlink_req_and_wait_one_resp(RouteNetlinkMessage::NewRoute(message), false)
}
@@ -577,7 +584,7 @@ impl IfConfiguerTrait for NetlinkIfConfiger {
address: std::net::Ipv6Addr,
cidr_prefix: u8,
) -> Result<(), Error> {
let routes = Self::list_routes()?;
let routes = Self::list_route_messages(AddressFamily::Inet6)?;
let ifidx = NetlinkIfConfiger::get_interface_index(name)?;
for msg in routes {
@@ -598,29 +605,82 @@ impl IfConfiguerTrait for NetlinkIfConfiger {
#[cfg(test)]
mod tests {
use super::*;
use std::process::Command;
const DUMMY_IFACE_NAME: &str = "dummy";
fn run_cmd(cmd: &str) -> String {
let output = std::process::Command::new("sh")
let output = Command::new("sh")
.arg("-c")
.arg(cmd)
.output()
.expect("failed to execute process");
assert!(
output.status.success(),
"command failed: {cmd}\nstdout: {}\nstderr: {}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr),
);
String::from_utf8(output.stdout).unwrap()
}
fn run_ip(args: &[&str]) {
let output = Command::new("ip")
.args(args)
.output()
.expect("failed to execute ip process");
assert!(
output.status.success(),
"ip command failed: {:?}\nstdout: {}\nstderr: {}",
args,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr),
);
}
fn test_iface_name(tag: &str) -> String {
format!("et{}{:x}", tag, std::process::id() & 0xffff)
}
struct ScopedDummyLink {
name: String,
}
impl ScopedDummyLink {
fn new(name: &str) -> Self {
let _ = Command::new("ip").args(["link", "del", name]).output();
run_ip(&["link", "add", name, "type", "dummy"]);
run_ip(&["link", "set", name, "up"]);
Self {
name: name.to_string(),
}
}
}
impl Drop for ScopedDummyLink {
fn drop(&mut self) {
let _ = Command::new("ip")
.args(["link", "del", &self.name])
.output();
}
}
struct PrepareEnv {}
impl PrepareEnv {
fn new() -> Self {
let _ = run_cmd(&format!("sudo ip link add {} type dummy", DUMMY_IFACE_NAME));
let _ = Command::new("ip")
.args(["link", "del", DUMMY_IFACE_NAME])
.output();
let _ = run_cmd(&format!("ip link add {} type dummy", DUMMY_IFACE_NAME));
PrepareEnv {}
}
}
impl Drop for PrepareEnv {
fn drop(&mut self) {
let _ = run_cmd(&format!("sudo ip link del {}", DUMMY_IFACE_NAME));
let _ = Command::new("ip")
.args(["link", "del", DUMMY_IFACE_NAME])
.output();
}
}
@@ -701,4 +761,128 @@ mod tests {
.collect::<Vec<_>>();
assert!(!routes.contains(&IpAddr::V4("10.5.5.0".parse().unwrap())));
}
#[serial_test::serial]
#[tokio::test]
async fn ipv6_addr_readback_test() {
let iface = test_iface_name("a");
let _link = ScopedDummyLink::new(&iface);
run_ip(&["-6", "addr", "add", "2001:db8:1234::2/64", "dev", &iface]);
let addrs = NetlinkIfConfiger::list_addresses(&iface).unwrap();
assert!(addrs.iter().any(|addr| {
addr.address() == IpAddr::V6("2001:db8:1234::2".parse().unwrap())
&& addr.network_length() == 64
}));
}
#[serial_test::serial]
#[tokio::test]
async fn ipv6_route_readback_test() {
let wan_if = test_iface_name("rw");
let lan_if = test_iface_name("rl");
let _wan = ScopedDummyLink::new(&wan_if);
let _lan = ScopedDummyLink::new(&lan_if);
run_ip(&[
"-6",
"addr",
"add",
"2001:db8:100:ffff::2/64",
"dev",
&wan_if,
]);
run_ip(&[
"-6",
"route",
"add",
"default",
"from",
"2001:db8:100::/56",
"dev",
&wan_if,
]);
run_ip(&["-6", "route", "add", "2001:db8:100::/56", "dev", &lan_if]);
let wan_ifindex = NetlinkIfConfiger::get_interface_index(&wan_if).unwrap();
let lan_ifindex = NetlinkIfConfiger::get_interface_index(&lan_if).unwrap();
let routes = NetlinkIfConfiger::list_ipv6_route_messages().unwrap();
assert!(routes.iter().any(|route| {
route.header.kind == RouteType::Unicast
&& route.header.source_prefix_length == 56
&& route.attributes.iter().any(|attr| {
matches!(
attr,
RouteAttribute::Source(RouteAddress::Inet6(addr))
if *addr == "2001:db8:100::".parse::<std::net::Ipv6Addr>().unwrap()
)
})
&& route
.attributes
.iter()
.any(|attr| matches!(attr, RouteAttribute::Oif(index) if *index == wan_ifindex))
&& !route
.attributes
.iter()
.any(|attr| matches!(attr, RouteAttribute::Destination(_)))
}));
assert!(routes.iter().any(|route| {
route.header.kind == RouteType::Unicast
&& route.header.destination_prefix_length == 56
&& route.attributes.iter().any(|attr| {
matches!(
attr,
RouteAttribute::Destination(RouteAddress::Inet6(addr))
if *addr == "2001:db8:100::".parse::<std::net::Ipv6Addr>().unwrap()
)
})
&& route
.attributes
.iter()
.any(|attr| matches!(attr, RouteAttribute::Oif(index) if *index == lan_ifindex))
}));
}
#[serial_test::serial]
#[tokio::test]
async fn ipv6_route_remove_test() {
let iface = test_iface_name("rr");
let _link = ScopedDummyLink::new(&iface);
let ifcfg = NetlinkIfConfiger {};
let route_addr = "2001:db8:200::".parse::<std::net::Ipv6Addr>().unwrap();
ifcfg
.add_ipv6_route(&iface, route_addr, 56, None)
.await
.unwrap();
let ifindex = NetlinkIfConfiger::get_interface_index(&iface).unwrap();
let has_route = |routes: &[RouteMessage]| {
routes.iter().any(|route| {
route.header.destination_prefix_length == 56
&& route.attributes.iter().any(|attr| {
matches!(
attr,
RouteAttribute::Destination(RouteAddress::Inet6(addr)) if *addr == route_addr
)
})
&& route
.attributes
.iter()
.any(|attr| matches!(attr, RouteAttribute::Oif(index) if *index == ifindex))
})
};
let routes = NetlinkIfConfiger::list_ipv6_route_messages().unwrap();
assert!(has_route(&routes));
ifcfg
.remove_ipv6_route(&iface, route_addr, 56)
.await
.unwrap();
let routes = NetlinkIfConfiger::list_ipv6_route_messages().unwrap();
assert!(!has_route(&routes));
}
}
+193 -11
View File
@@ -58,6 +58,21 @@ fn parse_env_filter(default_level: Option<LevelFilter>) -> Result<EnvFilter, any
.with_context(|| "failed to create env filter")
}
fn parse_static_filter(level: LevelFilter) -> Result<EnvFilter, anyhow::Error> {
EnvFilter::builder()
.with_default_directive(level.into())
.parse("")
.with_context(|| "failed to create static filter")
}
fn parse_file_filter(level: LevelFilter) -> Result<EnvFilter, anyhow::Error> {
if matches!(level, LevelFilter::OFF) {
parse_static_filter(level)
} else {
parse_env_filter(Some(level))
}
}
fn is_log(meta: &Metadata) -> bool {
meta.target() == LOG_TARGET || meta.target().starts_with(&format!("{LOG_TARGET}::"))
}
@@ -165,14 +180,17 @@ fn file_layers(
) -> anyhow::Result<(Vec<BoxLayer>, Option<NewFilterSender>)> {
let mut layers = Vec::new();
let level = config.level.map(|s| s.parse().unwrap());
let level = config
.level
.map(|s| s.parse().unwrap())
.unwrap_or(LevelFilter::OFF);
if matches!(level, Some(LevelFilter::OFF)) && !reload {
if matches!(level, LevelFilter::OFF) && !reload {
return Ok((layers, None));
}
let (file_filter, file_filter_reloader) =
tracing_subscriber::reload::Layer::<_, Registry>::new(parse_env_filter(level)?);
tracing_subscriber::reload::Layer::<_, Registry>::new(parse_file_filter(level)?);
let layer = |wrapper| {
layer()
@@ -218,9 +236,7 @@ fn file_layers(
// 初始化全局状态
let _ = LOGGER_LEVEL_SENDER.set(std::sync::Mutex::new(tx.clone()));
if let Some(level) = level {
let _ = CURRENT_LOG_LEVEL.set(std::sync::Mutex::new(level.to_string()));
}
let _ = CURRENT_LOG_LEVEL.set(std::sync::Mutex::new(level.to_string()));
std::thread::spawn(move || {
while let Ok(lf) = rx.recv() {
@@ -232,11 +248,7 @@ fn file_layers(
}
};
let mut new_filter = match EnvFilter::builder()
.with_default_directive(parsed_level.into())
.from_env()
.with_context(|| "failed to create file filter")
{
let mut new_filter = match parse_file_filter(parsed_level) {
Ok(filter) => Some(filter),
Err(e) => {
error!("Failed to build new log filter for {:?}: {:?}", lf, e);
@@ -268,6 +280,36 @@ mod tests {
use super::*;
use crate::common::config::FileLoggerConfig;
const RUST_LOG: &str = "RUST_LOG";
struct EnvVarGuard {
key: &'static str,
previous: Option<std::ffi::OsString>,
}
impl EnvVarGuard {
fn set(key: &'static str, value: &str) -> Self {
let previous = std::env::var_os(key);
unsafe { std::env::set_var(key, value) };
Self { key, previous }
}
fn unset(key: &'static str) -> Self {
let previous = std::env::var_os(key);
unsafe { std::env::remove_var(key) };
Self { key, previous }
}
}
impl Drop for EnvVarGuard {
fn drop(&mut self) {
match &self.previous {
Some(value) => unsafe { std::env::set_var(self.key, value) },
None => unsafe { std::env::remove_var(self.key) },
}
}
}
#[ctor::ctor]
fn init() {
let _ = Registry::default()
@@ -276,7 +318,147 @@ mod tests {
}
#[test]
fn default_file_logger_level_is_off_without_reload() {
let (layers, sender) = file_layers(FileLoggerConfig::default(), false).unwrap();
assert!(layers.is_empty());
assert!(sender.is_none());
}
#[test]
#[serial_test::serial]
fn default_file_logger_level_filters_info_with_reload() {
let _guard = EnvVarGuard::set(RUST_LOG, "info");
let temp_dir = tempfile::tempdir().unwrap();
let log_file_name = "default-off-test.log".to_string();
let log_path = temp_dir.path().join(&log_file_name);
let cfg = FileLoggerConfig {
file: Some(log_file_name),
dir: Some(temp_dir.path().to_string_lossy().to_string()),
..Default::default()
};
let (layers, _sender) = file_layers(cfg, true).unwrap();
let marker = "default-file-logger-off-marker";
let subscriber = Registry::default().with(layers);
tracing::subscriber::with_default(subscriber, || {
tracing::info!(target: LOG_TARGET, "{}", marker);
std::thread::sleep(std::time::Duration::from_millis(300));
});
let content = std::fs::read_to_string(&log_path).unwrap_or_default();
assert!(
!content.contains(marker),
"default file logger level should filter info logs"
);
}
#[test]
#[serial_test::serial]
fn file_logger_level_uses_env_filter_when_enabled() {
let _guard = EnvVarGuard::set(RUST_LOG, "debug");
let temp_dir = tempfile::tempdir().unwrap();
let log_file_name = "env-filter-test.log".to_string();
let log_path = temp_dir.path().join(&log_file_name);
let cfg = FileLoggerConfig {
level: Some(LevelFilter::INFO.to_string()),
file: Some(log_file_name),
dir: Some(temp_dir.path().to_string_lossy().to_string()),
..Default::default()
};
let (layers, _sender) = file_layers(cfg, true).unwrap();
let marker = "file-logger-env-filter-marker";
let subscriber = Registry::default().with(layers);
tracing::subscriber::with_default(subscriber, || {
tracing::debug!(target: LOG_TARGET, "{}", marker);
std::thread::sleep(std::time::Duration::from_millis(300));
});
let content = std::fs::read_to_string(&log_path).unwrap_or_default();
assert!(
content.contains(marker),
"enabled file logger should use RUST_LOG directives"
);
}
#[test]
#[serial_test::serial]
fn file_logger_reload_uses_env_filter_when_enabled() {
let _guard = EnvVarGuard::set(RUST_LOG, "debug");
let temp_dir = tempfile::tempdir().unwrap();
let log_file_name = "reload-env-filter-test.log".to_string();
let log_path = temp_dir.path().join(&log_file_name);
let cfg = FileLoggerConfig {
file: Some(log_file_name),
dir: Some(temp_dir.path().to_string_lossy().to_string()),
..Default::default()
};
let (layers, sender) = file_layers(cfg, true).unwrap();
let sender = sender.expect("reload=true should return a sender");
let marker = "file-logger-reload-env-filter-marker";
let subscriber = Registry::default().with(layers);
tracing::subscriber::with_default(subscriber, || {
sender.send(LevelFilter::INFO.to_string()).unwrap();
std::thread::sleep(std::time::Duration::from_millis(300));
tracing::debug!(target: LOG_TARGET, "{}", marker);
std::thread::sleep(std::time::Duration::from_millis(300));
});
let content = std::fs::read_to_string(&log_path).unwrap_or_default();
assert!(
content.contains(marker),
"file logger enabled by reload should use RUST_LOG directives"
);
}
#[test]
#[serial_test::serial]
fn file_logger_reload_off_ignores_env_filter() {
let _guard = EnvVarGuard::set(RUST_LOG, "info");
let temp_dir = tempfile::tempdir().unwrap();
let log_file_name = "reload-off-test.log".to_string();
let log_path = temp_dir.path().join(&log_file_name);
let cfg = FileLoggerConfig {
level: Some(LevelFilter::INFO.to_string()),
file: Some(log_file_name),
dir: Some(temp_dir.path().to_string_lossy().to_string()),
..Default::default()
};
let (layers, sender) = file_layers(cfg, true).unwrap();
let sender = sender.expect("reload=true should return a sender");
let marker = "file-logger-reload-off-marker";
let subscriber = Registry::default().with(layers);
tracing::subscriber::with_default(subscriber, || {
sender.send(LevelFilter::OFF.to_string()).unwrap();
std::thread::sleep(std::time::Duration::from_millis(300));
tracing::info!(target: LOG_TARGET, "{}", marker);
std::thread::sleep(std::time::Duration::from_millis(300));
});
let content = std::fs::read_to_string(&log_path).unwrap_or_default();
assert!(
!content.contains(marker),
"disabled file logger should ignore RUST_LOG directives"
);
}
#[test]
#[serial_test::serial]
fn test_logger_reload() {
let _guard = EnvVarGuard::unset(RUST_LOG);
let temp_dir = tempfile::tempdir().unwrap();
let log_file_name = "reload-test.log".to_string();
let log_path = temp_dir.path().join(&log_file_name);
+70 -30
View File
@@ -64,6 +64,24 @@ async fn resolve_mapped_listener_addrs(listener: &url::Url) -> Result<Vec<Socket
socket_addrs(listener, || mapped_listener_port(listener)).await
}
fn is_usable_public_ipv6_candidate(ip: &Ipv6Addr, global_ctx: &ArcGlobalCtx) -> bool {
is_usable_public_ipv6_candidate_with_mode(ip, global_ctx, TESTING.load(Ordering::Relaxed))
}
fn is_usable_public_ipv6_candidate_with_mode(
ip: &Ipv6Addr,
global_ctx: &ArcGlobalCtx,
testing: bool,
) -> bool {
!global_ctx.is_ip_easytier_managed_ipv6(ip)
&& (testing
|| (!ip.is_loopback()
&& !ip.is_unspecified()
&& !ip.is_unique_local()
&& !ip.is_unicast_link_local()
&& !ip.is_multicast()))
}
#[async_trait::async_trait]
pub trait PeerManagerForDirectConnector {
async fn list_peers(&self) -> Vec<PeerId>;
@@ -190,34 +208,28 @@ impl DirectConnectorManagerData {
.with_context(|| format!("failed to bind local socket for {}", remote_url))?,
);
let connector_ip = self
.peer_manager
.get_global_ctx()
.global_ctx
.get_stun_info_collector()
.get_stun_info()
.public_ip
.iter()
.find(|x| x.contains(':'))
.ok_or(anyhow::anyhow!(
"failed to get public ipv6 address from stun info"
))?
.parse::<Ipv6Addr>()
.with_context(|| {
format!(
"failed to parse public ipv6 address from stun info: {:?}",
self.peer_manager
.get_global_ctx()
.get_stun_info_collector()
.get_stun_info()
)
})?;
let connector_addr =
SocketAddr::new(IpAddr::V6(connector_ip), local_socket.local_addr()?.port());
.filter_map(|ip| ip.parse::<Ipv6Addr>().ok())
.find(|ip| !self.global_ctx.is_ip_easytier_managed_ipv6(ip));
// ask remote to send v6 hole punch packet
// and no matter what the result is, continue to connect
let _ = self
.remote_send_udp_hole_punch_packet(dst_peer_id, connector_addr, remote_url)
.await;
if let Some(connector_ip) = connector_ip {
let connector_addr =
SocketAddr::new(IpAddr::V6(connector_ip), local_socket.local_addr()?.port());
let _ = self
.remote_send_udp_hole_punch_packet(dst_peer_id, connector_addr, remote_url)
.await;
} else {
tracing::debug!(
?remote_url,
"skip remote IPv6 hole-punch packet; no non-EasyTier public IPv6 in STUN info"
);
}
let udp_connector = UdpTunnelConnector::new(remote_url.clone());
let remote_addr = SocketAddr::from_url(remote_url.clone(), IpVersion::V6).await?;
@@ -479,14 +491,7 @@ impl DirectConnectorManagerData {
.iter()
.chain(ip_list.public_ipv6.iter())
.filter_map(|x| Ipv6Addr::from_str(&x.to_string()).ok())
.filter(|x| {
TESTING.load(Ordering::Relaxed)
|| (!x.is_loopback()
&& !x.is_unspecified()
&& !x.is_unique_local()
&& !x.is_unicast_link_local()
&& !x.is_multicast())
})
.filter(|x| is_usable_public_ipv6_candidate(x, &self.global_ctx))
.collect::<HashSet<_>>()
.iter()
.for_each(|ip| {
@@ -515,6 +520,11 @@ impl DirectConnectorManagerData {
);
}
});
} else if self.global_ctx.is_ip_easytier_managed_ipv6(s_addr.ip()) {
tracing::debug!(
?listener,
"skip EasyTier-managed IPv6 as direct-connect target"
);
} else if !s_addr.ip().is_loopback() || TESTING.load(Ordering::Relaxed) {
if self
.global_ctx
@@ -790,9 +800,10 @@ impl DirectConnectorManager {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::{collections::BTreeSet, sync::Arc};
use crate::{
common::global_ctx::tests::get_mock_global_ctx,
connector::direct::{
DirectConnectorManager, DirectConnectorManagerData, DstListenerUrlBlackListItem,
},
@@ -802,12 +813,41 @@ mod tests {
wait_route_appear_with_cost,
},
proto::peer_rpc::GetIpListResponse,
tunnel::{IpScheme, TunnelScheme, matches_scheme},
};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use super::{TESTING, mapped_listener_port, resolve_mapped_listener_addrs};
#[tokio::test]
async fn public_ipv6_candidate_rejects_easytier_managed_addr_even_in_tests() {
let global_ctx = get_mock_global_ctx();
let managed_ipv6: cidr::Ipv6Inet = "2001:db8::2/128".parse().unwrap();
global_ctx.set_public_ipv6_routes(BTreeSet::from([managed_ipv6]));
assert!(!super::is_usable_public_ipv6_candidate_with_mode(
&"2001:db8::2".parse().unwrap(),
&global_ctx,
true,
));
assert!(super::is_usable_public_ipv6_candidate_with_mode(
&"::1".parse().unwrap(),
&global_ctx,
true,
));
}
#[test]
fn udp_ipv6_url_matches_hole_punch_branch_condition() {
let remote_url: url::Url = "udp://[2001:db8::1]:11010".parse().unwrap();
let takes_udp_ipv6_hole_punch_branch =
matches_scheme!(remote_url, TunnelScheme::Ip(IpScheme::Udp))
&& matches!(remote_url.host(), Some(url::Host::Ipv6(_)));
assert!(takes_udp_ipv6_hole_punch_branch);
}
#[test]
fn mapped_listener_port_uses_ip_scheme_defaults() {
assert_eq!(
+180 -15
View File
@@ -1,19 +1,17 @@
use std::{
net::{SocketAddr, SocketAddrV4, SocketAddrV6},
sync::Arc,
};
use std::net::{IpAddr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, idn, network::IPCollector},
common::{dns::socket_addrs, error::Error, global_ctx::ArcGlobalCtx, idn},
connector::dns_connector::DnsTunnelConnector,
proto::common::PeerFeatureFlag,
tunnel::{
self, FromUrl, IpScheme, IpVersion, TunnelConnector, TunnelError, TunnelScheme,
self, IpScheme, IpVersion, TunnelConnector, TunnelError, TunnelScheme,
ring::RingTunnelConnector, tcp::TcpTunnelConnector, udp::UdpTunnelConnector,
},
utils::BoxExt,
};
use http_connector::HttpTunnelConnector;
use rand::seq::SliceRandom;
pub mod direct;
pub mod manual;
@@ -56,7 +54,7 @@ pub(crate) fn should_background_p2p_with_peer(
async fn set_bind_addr_for_peer_connector(
connector: &mut (impl TunnelConnector + ?Sized),
is_ipv4: bool,
ip_collector: &Arc<IPCollector>,
global_ctx: &ArcGlobalCtx,
) {
if cfg!(any(
target_os = "android",
@@ -69,7 +67,7 @@ async fn set_bind_addr_for_peer_connector(
return;
}
let ips = ip_collector.collect_ip_addrs().await;
let ips = global_ctx.get_ip_collector().collect_ip_addrs().await;
if is_ipv4 {
let mut bind_addrs = vec![];
for ipv4 in ips.interface_ipv4s {
@@ -80,7 +78,11 @@ async fn set_bind_addr_for_peer_connector(
} else {
let mut bind_addrs = vec![];
for ipv6 in ips.interface_ipv6s.iter().chain(ips.public_ipv6.iter()) {
let socket_addr = SocketAddrV6::new(std::net::Ipv6Addr::from(*ipv6), 0, 0, 0).into();
let ipv6 = std::net::Ipv6Addr::from(*ipv6);
if global_ctx.is_ip_easytier_managed_ipv6(&ipv6) {
continue;
}
let socket_addr = SocketAddrV6::new(ipv6, 0, 0, 0).into();
bind_addrs.push(socket_addr);
}
connector.set_bind_addrs(bind_addrs);
@@ -88,6 +90,144 @@ async fn set_bind_addr_for_peer_connector(
let _ = connector;
}
struct ResolvedConnectorAddr {
addr: SocketAddr,
ip_version: IpVersion,
}
fn connector_default_port(url: &url::Url) -> Option<u16> {
url.try_into()
.ok()
.and_then(|s: TunnelScheme| s.try_into().ok())
.map(IpScheme::default_port)
}
fn addr_matches_ip_version(addr: &SocketAddr, ip_version: IpVersion) -> bool {
match ip_version {
IpVersion::V4 => addr.is_ipv4(),
IpVersion::V6 => addr.is_ipv6(),
IpVersion::Both => true,
}
}
fn infer_effective_ip_version(addrs: &[SocketAddr], requested_ip_version: IpVersion) -> IpVersion {
match requested_ip_version {
IpVersion::Both if addrs.iter().all(SocketAddr::is_ipv4) => IpVersion::V4,
IpVersion::Both if addrs.iter().all(SocketAddr::is_ipv6) => IpVersion::V6,
_ => requested_ip_version,
}
}
async fn easytier_managed_ipv6_source_for_dst(
global_ctx: &ArcGlobalCtx,
dst_addr: SocketAddrV6,
) -> Result<Option<Ipv6Addr>, Error> {
let socket = {
let _g = global_ctx.net_ns.guard();
tokio::net::UdpSocket::bind("[::]:0").await?
};
socket.connect(SocketAddr::V6(dst_addr)).await?;
let IpAddr::V6(local_ip) = socket.local_addr()?.ip() else {
return Ok(None);
};
Ok(global_ctx
.is_ip_easytier_managed_ipv6(&local_ip)
.then_some(local_ip))
}
async fn ipv6_connector_reject_reason(
url: &url::Url,
global_ctx: &ArcGlobalCtx,
v6_addr: SocketAddrV6,
skip_source_validation_errors: bool,
) -> Result<Option<String>, Error> {
if global_ctx.is_ip_easytier_managed_ipv6(v6_addr.ip()) {
return Ok(Some(format!(
"{} resolves to EasyTier-managed IPv6 {}",
url,
v6_addr.ip()
)));
}
match easytier_managed_ipv6_source_for_dst(global_ctx, v6_addr).await {
Ok(Some(local_ip)) => Ok(Some(format!(
"{} would use EasyTier-managed IPv6 {} as local source for {}",
url, local_ip, v6_addr
))),
Ok(None) => Ok(None),
Err(err) if skip_source_validation_errors => Ok(Some(format!(
"{} IPv6 candidate {} could not be validated: {}",
url, v6_addr, err
))),
Err(err) => Err(err),
}
}
async fn resolve_connector_socket_addr(
url: &url::Url,
global_ctx: &ArcGlobalCtx,
ip_version: IpVersion,
) -> Result<ResolvedConnectorAddr, Error> {
let addrs = socket_addrs(url, || connector_default_port(url))
.await
.map_err(|e| {
TunnelError::InvalidAddr(format!(
"failed to resolve socket addr, url: {}, error: {}",
url, e
))
})?;
let mut usable_addrs = Vec::new();
let mut rejected_ipv6_reason = None;
let skip_source_validation_errors = ip_version == IpVersion::Both;
for addr in addrs
.into_iter()
.filter(|addr| addr_matches_ip_version(addr, ip_version))
{
if let SocketAddr::V6(v6_addr) = addr
&& let Some(reason) = ipv6_connector_reject_reason(
url,
global_ctx,
v6_addr,
skip_source_validation_errors,
)
.await?
{
rejected_ipv6_reason = Some(reason);
continue;
}
usable_addrs.push(addr);
}
if usable_addrs.is_empty() {
if let Some(reason) = rejected_ipv6_reason {
return Err(Error::InvalidUrl(format!(
"{}, refusing overlay-backed underlay connection",
reason
)));
}
return Err(Error::TunnelError(TunnelError::NoDnsRecordFound(
ip_version,
)));
}
let effective_ip_version = infer_effective_ip_version(&usable_addrs, ip_version);
let addr = usable_addrs
.choose(&mut rand::thread_rng())
.copied()
.ok_or_else(|| Error::TunnelError(TunnelError::NoDnsRecordFound(ip_version)))?;
Ok(ResolvedConnectorAddr {
addr,
ip_version: effective_ip_version,
})
}
pub async fn create_connector_by_url(
url: &str,
global_ctx: &ArcGlobalCtx,
@@ -98,9 +238,11 @@ pub async fn create_connector_by_url(
let scheme = (&url)
.try_into()
.map_err(|_| TunnelError::InvalidProtocol(url.scheme().to_owned()))?;
let mut effective_connector_ip_version = ip_version;
let mut connector: Box<dyn TunnelConnector + 'static> = match scheme {
TunnelScheme::Ip(scheme) => {
let dst_addr = SocketAddr::from_url(url.clone(), ip_version).await?;
let resolved_addr = resolve_connector_socket_addr(&url, global_ctx, ip_version).await?;
effective_connector_ip_version = resolved_addr.ip_version;
let mut connector: Box<dyn TunnelConnector> = match scheme {
IpScheme::Tcp => TcpTunnelConnector::new(url).boxed(),
IpScheme::Udp => UdpTunnelConnector::new(url).boxed(),
@@ -125,11 +267,12 @@ pub async fn create_connector_by_url(
#[cfg(feature = "faketcp")]
IpScheme::FakeTcp => tunnel::fake_tcp::FakeTcpTunnelConnector::new(url).boxed(),
};
connector.set_resolved_addr(resolved_addr.addr);
if global_ctx.config.get_flags().bind_device {
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
resolved_addr.addr.is_ipv4(),
global_ctx,
)
.await;
}
@@ -151,16 +294,38 @@ pub async fn create_connector_by_url(
DnsTunnelConnector::new(url, global_ctx.clone()).boxed()
}
};
connector.set_ip_version(ip_version);
connector.set_ip_version(effective_connector_ip_version);
Ok(connector)
}
#[cfg(test)]
mod tests {
use crate::proto::common::PeerFeatureFlag;
use std::collections::BTreeSet;
use super::{should_background_p2p_with_peer, should_try_p2p_with_peer};
use crate::{
common::global_ctx::tests::get_mock_global_ctx, proto::common::PeerFeatureFlag,
tunnel::IpVersion,
};
use super::{
create_connector_by_url, should_background_p2p_with_peer, should_try_p2p_with_peer,
};
#[tokio::test]
async fn connector_rejects_easytier_managed_ipv6_destination() {
let global_ctx = get_mock_global_ctx();
let public_route: cidr::Ipv6Inet = "2001:db8::2/128".parse().unwrap();
global_ctx.set_public_ipv6_routes(BTreeSet::from([public_route]));
let ret =
create_connector_by_url("tcp://[2001:db8::2]:11010", &global_ctx, IpVersion::V6).await;
assert!(matches!(
ret,
Err(crate::common::error::Error::InvalidUrl(_))
));
}
#[test]
fn lazy_background_p2p_requires_need_p2p() {
+41 -17
View File
@@ -6,6 +6,7 @@ use std::{
use crossbeam::atomic::AtomicCell;
use dashmap::{DashMap, DashSet};
use guarden::defer;
use rand::seq::SliceRandom as _;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use tracing::{Instrument, Level, instrument};
@@ -15,7 +16,6 @@ use crate::{
common::{
PeerId, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS, upnp,
},
defer,
peers::peer_manager::PeerManager,
proto::common::NatType,
tunnel::{
@@ -719,25 +719,31 @@ async fn check_udp_socket_local_addr(
) -> Result<(), Error> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.connect(remote_mapped_addr).await?;
if let Ok(local_addr) = socket.local_addr() {
// local_addr should not be equal to virtual ipv4 or virtual ipv6
match local_addr.ip() {
IpAddr::V4(ip) => {
if global_ctx.get_ipv4().map(|ip| ip.address()) == Some(ip) {
return Err(anyhow::anyhow!("local address is virtual ipv4").into());
}
}
IpAddr::V6(ip) => {
if global_ctx.get_ipv6().map(|ip| ip.address()) == Some(ip) {
return Err(anyhow::anyhow!("local address is virtual ipv6").into());
}
}
}
if let Ok(local_addr) = socket.local_addr()
&& let Some(err) = easytier_managed_local_addr_error(&global_ctx, local_addr)
{
return Err(anyhow::anyhow!(err).into());
}
Ok(())
}
fn easytier_managed_local_addr_error(
global_ctx: &ArcGlobalCtx,
local_addr: SocketAddr,
) -> Option<&'static str> {
// local_addr should not be equal to an EasyTier-managed virtual/public address.
match local_addr.ip() {
IpAddr::V4(ip) if global_ctx.get_ipv4().map(|ip| ip.address()) == Some(ip) => {
Some("local address is virtual ipv4")
}
IpAddr::V6(ip) if global_ctx.is_ip_easytier_managed_ipv6(&ip) => {
Some("local address is easytier-managed ipv6")
}
_ => None,
}
}
pub(crate) async fn try_connect_with_socket(
global_ctx: ArcGlobalCtx,
socket: Arc<UdpSocket>,
@@ -763,11 +769,29 @@ pub(crate) async fn try_connect_with_socket(
#[cfg(test)]
mod tests {
use std::{collections::BTreeSet, net::SocketAddr};
use crate::common::global_ctx::tests::get_mock_global_ctx;
use super::{
MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS, should_create_public_listener,
should_retry_public_listener_selection,
MAX_PUBLIC_UDP_HOLE_PUNCH_LISTENERS, easytier_managed_local_addr_error,
should_create_public_listener, should_retry_public_listener_selection,
};
#[tokio::test]
async fn local_addr_check_rejects_easytier_public_ipv6_route() {
let global_ctx = get_mock_global_ctx();
let public_route: cidr::Ipv6Inet = "2001:db8::4/128".parse().unwrap();
global_ctx.set_public_ipv6_routes(BTreeSet::from([public_route]));
let local_addr: SocketAddr = "[2001:db8::4]:1234".parse().unwrap();
assert_eq!(
easytier_managed_local_addr_error(&global_ctx, local_addr),
Some("local address is easytier-managed ipv6")
);
}
#[test]
fn listener_selection_prefers_reuse_before_cap() {
assert!(!should_create_public_listener(1, true, true, false, false));
@@ -9,6 +9,7 @@ use std::{
};
use anyhow::Context;
use guarden::defer;
use rand::{Rng, seq::SliceRandom};
use tokio::{net::UdpSocket, sync::RwLock};
use tokio_util::task::AbortOnDropHandle;
@@ -22,7 +23,6 @@ use crate::{
},
handle_rpc_result,
},
defer,
peers::peer_manager::PeerManager,
proto::{
peer_rpc::{
+89 -11
View File
@@ -12,7 +12,6 @@ use crate::{
constants::EASYTIER_VERSION,
log,
},
defer,
instance_manager::NetworkInstanceManager,
launcher::add_proxy_network_to_config,
proto::common::{CompressionAlgoPb, SecureModeConfig},
@@ -23,6 +22,7 @@ use crate::{
use anyhow::Context;
use cidr::IpCidr;
use clap::{CommandFactory, Parser};
use guarden::defer;
use rust_i18n::t;
use std::{
net::{IpAddr, SocketAddr},
@@ -37,6 +37,38 @@ use crate::tunnel::IpScheme;
#[cfg(feature = "jemalloc-prof")]
use jemalloc_ctl::{Access as _, AsName as _, epoch, stats};
fn supported_compression_algorithms() -> &'static str {
cfg_select! {
all(feature = "zstd", feature = "lzo") => "none, zstd, lzo",
feature = "zstd" => "none, zstd",
feature = "lzo" => "none, lzo",
_ => "none",
}
}
fn compression_help() -> String {
t!(
"core_clap.compression",
algorithms = supported_compression_algorithms()
)
.to_string()
}
fn parse_compression_algorithm(compression: &str) -> anyhow::Result<CompressionAlgoPb> {
match compression {
"none" => Ok(CompressionAlgoPb::None),
#[cfg(feature = "zstd")]
"zstd" => Ok(CompressionAlgoPb::Zstd),
#[cfg(feature = "lzo")]
"lzo" => Ok(CompressionAlgoPb::Lzo),
_ => anyhow::bail!(
"unknown compression algorithm: {}, supported: {}",
compression,
supported_compression_algorithms()
),
}
}
#[cfg(target_os = "windows")]
windows_service::define_windows_service!(ffi_service_main, win_service_main);
@@ -171,6 +203,31 @@ struct NetworkOptions {
)]
ipv6: Option<String>,
#[arg(
long,
env = "ET_IPV6_PUBLIC_ADDR_PROVIDER",
help = t!("core_clap.ipv6_public_addr_provider").to_string(),
num_args = 0..=1,
default_missing_value = "true"
)]
ipv6_public_addr_provider: Option<bool>,
#[arg(
long,
env = "ET_IPV6_PUBLIC_ADDR_AUTO",
help = t!("core_clap.ipv6_public_addr_auto").to_string(),
num_args = 0..=1,
default_missing_value = "true"
)]
ipv6_public_addr_auto: Option<bool>,
#[arg(
long,
env = "ET_IPV6_PUBLIC_ADDR_PREFIX",
help = t!("core_clap.ipv6_public_addr_prefix").to_string()
)]
ipv6_public_addr_prefix: Option<String>,
#[arg(
short,
long,
@@ -488,7 +545,7 @@ struct NetworkOptions {
#[arg(
long,
env = "ET_COMPRESSION",
help = t!("core_clap.compression").to_string(),
help = compression_help(),
)]
compression: Option<String>,
@@ -875,6 +932,20 @@ impl NetworkOptions {
})?))
}
if let Some(enabled) = self.ipv6_public_addr_provider {
cfg.set_ipv6_public_addr_provider(enabled);
}
if let Some(enabled) = self.ipv6_public_addr_auto {
cfg.set_ipv6_public_addr_auto(enabled);
}
if let Some(prefix) = &self.ipv6_public_addr_prefix {
cfg.set_ipv6_public_addr_prefix(Some(prefix.parse().with_context(|| {
format!("failed to parse ipv6 public address prefix: {}", prefix)
})?));
}
if !self.peers.is_empty() {
let mut peers = cfg.get_peers();
peers.reserve(peers.len() + self.peers.len());
@@ -1067,15 +1138,7 @@ impl NetworkOptions {
f.need_p2p = self.need_p2p.unwrap_or(f.need_p2p);
f.multi_thread = self.multi_thread.unwrap_or(f.multi_thread);
if let Some(compression) = &self.compression {
f.data_compress_algo = match compression.as_str() {
"none" => CompressionAlgoPb::None,
"zstd" => CompressionAlgoPb::Zstd,
_ => panic!(
"unknown compression algorithm: {}, supported: none, zstd",
compression
),
}
.into();
f.data_compress_algo = parse_compression_algorithm(compression)?.into();
}
f.bind_device = self.bind_device.unwrap_or(f.bind_device);
f.enable_kcp_proxy = self.enable_kcp_proxy.unwrap_or(f.enable_kcp_proxy);
@@ -1588,6 +1651,21 @@ async fn validate_config(cli: &Cli) -> anyhow::Result<()> {
mod tests {
use super::*;
#[test]
fn test_compression_help_uses_supported_algorithms() {
assert!(compression_help().contains(supported_compression_algorithms()));
}
#[test]
fn test_parse_compression_algorithm_rejects_unknown() {
let err = parse_compression_algorithm("snappy")
.unwrap_err()
.to_string();
assert!(err.contains("snappy"));
assert!(err.contains(supported_compression_algorithms()));
}
#[test]
fn test_parse_listeners() {
type IpSchemeMap = fn(&IpScheme) -> String;
+228 -8
View File
@@ -51,13 +51,14 @@ use easytier::{
ListCredentialsRequest, ListCredentialsResponse, ListForeignNetworkRequest,
ListGlobalForeignNetworkRequest, ListMappedListenerRequest, ListPeerRequest,
ListPeerResponse, ListPortForwardRequest, ListPortForwardResponse,
ListRouteRequest, ListRouteResponse, MappedListener, MappedListenerManageRpc,
ListPublicIpv6InfoRequest, ListPublicIpv6InfoResponse, ListRouteRequest,
ListRouteResponse, MappedListener, MappedListenerManageRpc,
MappedListenerManageRpcClientFactory, MetricSnapshot, NodeInfo, PeerManageRpc,
PeerManageRpcClientFactory, PortForwardManageRpc,
PortForwardManageRpcClientFactory, RevokeCredentialRequest, ShowNodeInfoRequest,
StatsRpc, StatsRpcClientFactory, TcpProxyEntryState, TcpProxyEntryTransportType,
TcpProxyRpc, TcpProxyRpcClientFactory, TrustedKeySourcePb, VpnPortalInfo,
VpnPortalRpc, VpnPortalRpcClientFactory,
PortForwardManageRpcClientFactory, RevokeCredentialRequest, Route as ApiRoute,
ShowNodeInfoRequest, StatsRpc, StatsRpcClientFactory, TcpProxyEntryState,
TcpProxyEntryTransportType, TcpProxyRpc, TcpProxyRpcClientFactory,
TrustedKeySourcePb, VpnPortalInfo, VpnPortalRpc, VpnPortalRpcClientFactory,
instance_identifier::{InstanceSelector, Selector},
list_global_foreign_network_response, list_peer_route_pair,
},
@@ -73,7 +74,7 @@ use easytier::{
common::{NatType, PortForwardConfigPb, SocketType},
peer_rpc::{GetGlobalPeerMapRequest, PeerCenterRpc, PeerCenterRpcClientFactory},
rpc_impl::standalone::StandAloneClient,
rpc_types::controller::BaseController,
rpc_types::{controller::BaseController, error::Error as RpcError},
},
tunnel::{TunnelScheme, tcp::TcpTunnelConnector},
utils::{PeerRoutePair, string::cost_to_str},
@@ -193,6 +194,7 @@ struct PeerArgs {
#[derive(Subcommand, Debug)]
enum PeerSubCommand {
List,
Ipv6,
ListForeign {
#[arg(
long,
@@ -524,6 +526,40 @@ type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, Error>> + 'a>
type ForeignNetworkMap = BTreeMap<String, ForeignNetworkEntryPb>;
type GlobalForeignNetworkMap = BTreeMap<u32, list_global_foreign_network_response::ForeignNetworks>;
fn is_missing_web_client_service(error: &RpcError) -> bool {
matches!(
error,
RpcError::InvalidServiceKey(service_name, _)
if service_name.trim_matches('"') == "WebClientService"
)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn missing_web_client_service_matches_raw_service_name() {
let error = RpcError::InvalidServiceKey("WebClientService".to_string(), "".to_string());
assert!(is_missing_web_client_service(&error));
}
#[test]
fn missing_web_client_service_matches_serialized_service_name() {
let error = RpcError::InvalidServiceKey("\"WebClientService\"".to_string(), "".to_string());
assert!(is_missing_web_client_service(&error));
}
#[test]
fn missing_web_client_service_rejects_other_services() {
let error = RpcError::InvalidServiceKey("PeerManageRpc".to_string(), "".to_string());
assert!(!is_missing_web_client_service(&error));
}
}
#[derive(serde::Serialize)]
struct PeerListData {
node_info: NodeInfo,
@@ -536,6 +572,12 @@ struct RouteListData {
peer_routes: Vec<PeerRoutePair>,
}
struct PeerIpv6DataRaw {
node_info: NodeInfo,
routes: Vec<ApiRoute>,
provider_info: ListPublicIpv6InfoResponse,
}
#[derive(serde::Serialize)]
struct PeerCenterRowData {
node_id: String,
@@ -591,9 +633,15 @@ impl<'a> CommandHandler<'a> {
}
let client = self.get_manage_client().await?;
let inst_ids = client
let list_response = match client
.list_network_instance(BaseController::default(), ListNetworkInstanceRequest {})
.await?
.await
{
Ok(response) => response,
Err(error) if is_missing_web_client_service(&error) => return Ok(None),
Err(error) => return Err(error.into()),
};
let inst_ids = list_response
.inst_ids
.into_iter()
.map(uuid::Uuid::from)
@@ -963,6 +1011,27 @@ impl<'a> CommandHandler<'a> {
})
}
async fn fetch_local_public_ipv6_info(&self) -> Result<ListPublicIpv6InfoResponse, Error> {
Ok(self
.get_peer_manager_client()
.await?
.list_public_ipv6_info(
BaseController::default(),
ListPublicIpv6InfoRequest {
instance: Some(self.instance_selector.clone()),
},
)
.await?)
}
async fn fetch_peer_ipv6_data(&self) -> Result<PeerIpv6DataRaw, Error> {
Ok(PeerIpv6DataRaw {
node_info: self.fetch_node_info().await?,
routes: self.list_routes().await?.routes,
provider_info: self.fetch_local_public_ipv6_info().await?,
})
}
async fn fetch_connector_list(&self) -> Result<Vec<Connector>, Error> {
Ok(self
.get_connector_manager_client()
@@ -1375,6 +1444,154 @@ impl<'a> CommandHandler<'a> {
})
}
async fn handle_peer_ipv6(&self) -> Result<(), Error> {
#[derive(tabled::Tabled, serde::Serialize)]
struct PeerIpv6NodeRow {
peer_id: u32,
hostname: String,
inst_id: String,
ipv4: String,
public_ipv6_addr: String,
provider_prefix: String,
}
#[derive(tabled::Tabled, serde::Serialize)]
struct ProviderLeaseRow {
peer_id: u32,
inst_id: String,
leased_addr: String,
valid_until: String,
reused: bool,
}
#[derive(serde::Serialize)]
struct ProviderLeaseSection {
provider_prefix: String,
leases: Vec<ProviderLeaseRow>,
}
#[derive(serde::Serialize)]
struct PeerIpv6View {
nodes: Vec<PeerIpv6NodeRow>,
local_provider: Option<ProviderLeaseSection>,
}
fn fmt_ipv6_inet(value: Option<easytier::proto::common::Ipv6Inet>) -> String {
value
.map(|value| value.to_string())
.unwrap_or_else(|| "-".to_string())
}
fn fmt_valid_until(unix_seconds: i64) -> String {
chrono::DateTime::<chrono::Utc>::from_timestamp(unix_seconds, 0)
.map(|ts| {
ts.with_timezone(&chrono::Local)
.format("%Y-%m-%d %H:%M:%S")
.to_string()
})
.unwrap_or_else(|| unix_seconds.to_string())
}
let build_view = |data: &PeerIpv6DataRaw| {
let mut nodes = Vec::with_capacity(data.routes.len() + 1);
nodes.push(PeerIpv6NodeRow {
peer_id: data.node_info.peer_id,
hostname: data.node_info.hostname.clone(),
inst_id: data.node_info.inst_id.clone(),
ipv4: data.node_info.ipv4_addr.clone(),
public_ipv6_addr: fmt_ipv6_inet(data.node_info.public_ipv6_addr),
provider_prefix: fmt_ipv6_inet(data.node_info.ipv6_public_addr_prefix),
});
nodes.extend(data.routes.iter().map(|route| {
PeerIpv6NodeRow {
peer_id: route.peer_id,
hostname: route.hostname.clone(),
inst_id: route.inst_id.clone(),
ipv4: route
.ipv4_addr
.map(|ipv4| ipv4.to_string())
.unwrap_or_else(|| "-".to_string()),
public_ipv6_addr: fmt_ipv6_inet(route.public_ipv6_addr),
provider_prefix: fmt_ipv6_inet(route.ipv6_public_addr_prefix),
}
}));
nodes.sort_by_key(|row| {
(
row.peer_id != data.node_info.peer_id,
row.peer_id,
row.inst_id.clone(),
)
});
let local_provider = data.provider_info.provider_prefix.map(|provider_prefix| {
let mut leases = data
.provider_info
.provider_leases
.iter()
.map(|lease| ProviderLeaseRow {
peer_id: lease.peer_id,
inst_id: lease.inst_id.clone(),
leased_addr: fmt_ipv6_inet(lease.leased_addr),
valid_until: fmt_valid_until(lease.valid_until_unix_seconds),
reused: lease.reused,
})
.collect::<Vec<_>>();
leases.sort_by_key(|lease| {
(
lease.peer_id,
lease.inst_id.clone(),
lease.leased_addr.clone(),
)
});
ProviderLeaseSection {
provider_prefix: provider_prefix.to_string(),
leases,
}
});
PeerIpv6View {
nodes,
local_provider,
}
};
let results = self
.collect_instance_results(|handler| Box::pin(handler.fetch_peer_ipv6_data()))
.await?;
if self.verbose || *self.output_format == OutputFormat::Json {
return self.print_json_results(
results
.into_iter()
.map(|result| result.map(|data| build_view(&data)))
.collect(),
);
}
self.print_results(&results, |data| {
let view = build_view(data);
print_output(&view.nodes, self.output_format, &[], &[], self.no_trunc)?;
if let Some(local_provider) = view.local_provider {
println!();
println!("Local provider prefix: {}", local_provider.provider_prefix);
if local_provider.leases.is_empty() {
println!("No active provider leases");
} else {
print_output(
&local_provider.leases,
self.output_format,
&[],
&[],
self.no_trunc,
)?;
}
}
Ok(())
})
}
async fn handle_route_dump(&self) -> Result<(), Error> {
let results = self
.collect_instance_results(|handler| Box::pin(handler.fetch_route_dump()))
@@ -2652,6 +2869,9 @@ async fn main() -> Result<(), Error> {
Some(PeerSubCommand::List) => {
handler.handle_peer_list().await?;
}
Some(PeerSubCommand::Ipv6) => {
handler.handle_peer_ipv6().await?;
}
Some(PeerSubCommand::ListForeign { trusted_keys }) => {
handler.handle_foreign_network_list(*trusted_keys).await?;
}
+2 -1
View File
@@ -7,6 +7,7 @@ use std::{
use anyhow::Context;
use bytes::Bytes;
use dashmap::DashMap;
use guarden::defer;
use kcp_sys::{
endpoint::{ConnId, KcpEndpoint, KcpPacketReceiver},
ffi_safe::KcpConfig,
@@ -359,7 +360,7 @@ impl KcpProxyDst {
transport_type: TcpProxyEntryTransportType::Kcp.into(),
},
);
crate::defer! {
defer! {
proxy_entries.remove(&conn_id);
if proxy_entries.capacity() - proxy_entries.len() > 16 {
proxy_entries.shrink_to_fit();
+2 -1
View File
@@ -24,6 +24,7 @@ use bytes::{BufMut, Bytes, BytesMut};
use dashmap::DashMap;
use derivative::Derivative;
use derive_more::{Constructor, Deref, DerefMut, From, Into};
use guarden::defer;
use prost::Message;
use quinn::udp::{EcnCodepoint, RecvMeta, Transmit};
use quinn::{
@@ -662,7 +663,7 @@ impl QuicStreamReceiver {
transport_type: TcpProxyEntryTransportType::Quic.into(),
},
);
crate::defer! {
defer! {
proxy_entries.remove(&handle);
if proxy_entries.capacity() - proxy_entries.len() > 16 {
proxy_entries.shrink_to_fit();
@@ -1,6 +1,5 @@
// translated from tailscale #32ce1bdb48078ec4cedaeeb5b1b2ff9c0ef61a49
use crate::defer;
use anyhow::{Context, Result};
use dbus::blocking::stdintf::org_freedesktop_dbus::Properties as _;
use std::fs;
@@ -167,6 +166,7 @@ fn new_os_configurator(_interface_name: String) -> Result<()> {
Ok(())
}
use guarden::defer;
use std::io::{self, BufRead, Cursor};
/// 返回 `resolv.conf` 内容的拥有者("systemd-resolved"、"NetworkManager"、"resolvconf" 或空字符串)
+143 -3
View File
@@ -9,7 +9,6 @@ use std::time::Duration;
use anyhow::Context;
use cidr::{IpCidr, Ipv4Inet};
use futures::FutureExt;
use tokio::sync::{Mutex, Notify};
#[cfg(feature = "tun")]
@@ -65,6 +64,11 @@ use crate::vpn_portal::{self, VpnPortal};
#[cfg(feature = "magic-dns")]
use super::dns_server::{MAGIC_DNS_FAKE_IP, runner::DnsRunner};
use super::listeners::ListenerManager;
use super::public_ipv6_provider::{
reconcile_public_ipv6_provider_runtime, run_public_ipv6_provider_reconcile_task,
should_run_public_ipv6_provider_reconcile, validate_public_ipv6_config,
validate_public_ipv6_config_values,
};
#[cfg(feature = "socks5")]
use crate::gateway::socks5::Socks5Server;
@@ -253,11 +257,64 @@ pub struct InstanceConfigPatcher {
}
impl InstanceConfigPatcher {
fn parse_ipv6_public_addr_prefix_patch(
prefix: Option<&str>,
) -> Result<Option<Option<cidr::Ipv6Cidr>>, anyhow::Error> {
let Some(prefix) = prefix else {
return Ok(None);
};
let prefix = prefix.trim();
if prefix.is_empty() {
return Ok(Some(None));
}
let parsed = prefix
.parse()
.with_context(|| format!("failed to parse ipv6 public address prefix: {prefix}"))?;
Ok(Some(Some(parsed)))
}
fn effective_ipv6_for_public_ipv6_validation(
global_ctx: &ArcGlobalCtx,
patch: &crate::proto::api::config::InstanceConfigPatch,
_auto_enabled: bool,
) -> Option<cidr::Ipv6Inet> {
if let Some(ipv6) = patch.ipv6 {
return Some(ipv6.into());
}
global_ctx.get_ipv6()
}
fn validate_public_ipv6_patch(
global_ctx: &ArcGlobalCtx,
patch: &crate::proto::api::config::InstanceConfigPatch,
) -> Result<Option<Option<cidr::Ipv6Cidr>>, anyhow::Error> {
let parsed_prefix =
Self::parse_ipv6_public_addr_prefix_patch(patch.ipv6_public_addr_prefix.as_deref())?;
let auto_enabled = patch
.ipv6_public_addr_auto
.unwrap_or(global_ctx.config.get_ipv6_public_addr_auto());
let provider_enabled = patch
.ipv6_public_addr_provider
.unwrap_or(global_ctx.config.get_ipv6_public_addr_provider());
let prefix =
parsed_prefix.unwrap_or_else(|| global_ctx.config.get_ipv6_public_addr_prefix());
let ipv6 = Self::effective_ipv6_for_public_ipv6_validation(global_ctx, patch, auto_enabled);
validate_public_ipv6_config_values(ipv6, provider_enabled, auto_enabled, prefix)?;
Ok(parsed_prefix)
}
pub async fn apply_patch(
&self,
patch: crate::proto::api::config::InstanceConfigPatch,
) -> Result<(), anyhow::Error> {
let patch_for_event = patch.clone();
let global_ctx = weak_upgrade(&self.global_ctx)?;
let parsed_ipv6_public_addr_prefix = Self::validate_public_ipv6_patch(&global_ctx, &patch)?;
self.patch_port_forwards(patch.port_forwards).await?;
self.patch_acl(patch.acl).await?;
@@ -267,7 +324,8 @@ impl InstanceConfigPatcher {
self.patch_mapped_listeners(patch.mapped_listeners).await?;
self.patch_connector(patch.connectors).await?;
let global_ctx = weak_upgrade(&self.global_ctx)?;
let provider_reconcile_was_running = should_run_public_ipv6_provider_reconcile(&global_ctx);
let mut provider_config_changed = false;
if let Some(hostname) = patch.hostname {
global_ctx.set_hostname(hostname.clone());
global_ctx.config.set_hostname(Some(hostname));
@@ -282,9 +340,35 @@ impl InstanceConfigPatcher {
global_ctx.set_ipv6(Some(ipv6.into()));
global_ctx.config.set_ipv6(Some(ipv6.into()));
}
if let Some(disable_relay_data) = patch.disable_relay_data {
let mut flags = global_ctx.get_flags();
flags.disable_relay_data = disable_relay_data;
global_ctx.set_flags(flags);
}
if let Some(enabled) = patch.ipv6_public_addr_provider {
global_ctx.config.set_ipv6_public_addr_provider(enabled);
provider_config_changed = true;
}
if let Some(enabled) = patch.ipv6_public_addr_auto {
global_ctx.config.set_ipv6_public_addr_auto(enabled);
}
if let Some(prefix) = parsed_ipv6_public_addr_prefix {
global_ctx.config.set_ipv6_public_addr_prefix(prefix);
provider_config_changed = true;
}
global_ctx.issue_event(GlobalCtxEvent::ConfigPatched(patch_for_event));
if provider_config_changed {
reconcile_public_ipv6_provider_runtime(&global_ctx).await;
let provider_reconcile_should_run =
should_run_public_ipv6_provider_reconcile(&global_ctx);
if !provider_reconcile_was_running && provider_reconcile_should_run {
run_public_ipv6_provider_reconcile_task(&global_ctx);
}
}
Ok(())
}
@@ -664,6 +748,12 @@ impl Instance {
Ok(())
}
async fn prepare_public_ipv6_config(&self) -> Result<(), Error> {
validate_public_ipv6_config(&self.global_ctx)?;
reconcile_public_ipv6_provider_runtime(&self.global_ctx).await;
Ok(())
}
// use a mock nic ctx to consume packets.
#[cfg(feature = "tun")]
async fn clear_nic_ctx(
@@ -932,6 +1022,7 @@ impl Instance {
}
pub async fn run(&mut self) -> Result<(), Error> {
self.prepare_public_ipv6_config().await?;
self.listener_manager
.lock()
.await
@@ -939,6 +1030,7 @@ impl Instance {
.await?;
self.listener_manager.lock().await.run().await?;
self.peer_manager.run().await?;
run_public_ipv6_provider_reconcile_task(&self.global_ctx);
#[cfg(feature = "tun")]
{
@@ -1544,7 +1636,9 @@ impl Drop for Instance {
#[cfg(test)]
mod tests {
use crate::{
instance::instance::InstanceRpcServerHook, proto::rpc_impl::standalone::RpcServerHook,
common::global_ctx::tests::get_mock_global_ctx,
instance::instance::{InstanceConfigPatcher, InstanceRpcServerHook},
proto::{api::config::InstanceConfigPatch, rpc_impl::standalone::RpcServerHook},
};
#[tokio::test]
@@ -1665,4 +1759,50 @@ mod tests {
}
}
}
#[tokio::test]
async fn validate_public_ipv6_patch_rejects_non_global_prefix() {
let global_ctx = get_mock_global_ctx();
let patch = InstanceConfigPatch {
ipv6_public_addr_provider: Some(true),
ipv6_public_addr_prefix: Some("fd00::/64".to_string()),
..Default::default()
};
let err =
InstanceConfigPatcher::validate_public_ipv6_patch(&global_ctx, &patch).unwrap_err();
assert!(
err.to_string()
.contains("not a valid global unicast IPv6 prefix")
);
}
#[tokio::test]
async fn validate_public_ipv6_patch_allows_enabling_auto_with_manual_ipv6() {
let global_ctx = get_mock_global_ctx();
global_ctx.set_ipv6(Some("fd00::1/64".parse().unwrap()));
let patch = InstanceConfigPatch {
ipv6_public_addr_auto: Some(true),
..Default::default()
};
assert!(InstanceConfigPatcher::validate_public_ipv6_patch(&global_ctx, &patch).is_ok());
}
#[tokio::test]
async fn validate_public_ipv6_patch_ignores_runtime_auto_ipv6_cache() {
let global_ctx = get_mock_global_ctx();
global_ctx.config.set_ipv6_public_addr_auto(true);
global_ctx.set_ipv6(Some("2001:db8::10/64".parse().unwrap()));
let patch = InstanceConfigPatch {
ipv6_public_addr_provider: Some(true),
ipv6_public_addr_prefix: Some("2001:db8:100::/64".to_string()),
..Default::default()
};
assert!(InstanceConfigPatcher::validate_public_ipv6_patch(&global_ctx, &patch).is_ok());
}
}
+3 -3
View File
@@ -25,7 +25,7 @@ use crate::{
pub fn create_listener_by_url(
l: &url::Url,
global_ctx: ArcGlobalCtx,
_global_ctx: ArcGlobalCtx,
) -> Result<Box<dyn TunnelListener>, Error> {
Ok(match l.try_into()? {
TunnelScheme::Ip(scheme) => match scheme {
@@ -34,7 +34,7 @@ pub fn create_listener_by_url(
#[cfg(feature = "wireguard")]
IpScheme::Wg => {
use crate::tunnel::wireguard::{WgConfig, WgTunnelListener};
let nid = global_ctx.get_network_identity();
let nid = _global_ctx.get_network_identity();
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
&nid.network_secret.unwrap_or_default(),
@@ -43,7 +43,7 @@ pub fn create_listener_by_url(
}
#[cfg(feature = "quic")]
IpScheme::Quic => {
tunnel::quic::QuicTunnelListener::new(l.clone(), global_ctx.clone()).boxed()
tunnel::quic::QuicTunnelListener::new(l.clone(), _global_ctx.clone()).boxed()
}
#[cfg(feature = "websocket")]
IpScheme::Ws | IpScheme::Wss => {
+2
View File
@@ -4,6 +4,8 @@ pub mod instance;
pub mod listeners;
mod public_ipv6_provider;
pub mod proxy_cidrs_monitor;
#[cfg(feature = "tun")]
@@ -0,0 +1,910 @@
use std::{path::Path, sync::Arc};
use anyhow::Context;
use cidr::{Ipv6Cidr, Ipv6Inet};
#[cfg(target_os = "linux")]
use netlink_packet_route::route::{RouteAddress, RouteAttribute, RouteMessage, RouteType};
#[cfg(target_os = "linux")]
use crate::common::ifcfg::{get_interface_index, list_ipv6_route_messages};
use crate::common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
};
const PUBLIC_IPV6_PROVIDER_RECONCILE_INTERVAL: std::time::Duration =
std::time::Duration::from_secs(5);
const PUBLIC_IPV6_PROVIDER_RECONCILE_MAX_RETRIES: usize = 3;
#[derive(Debug, Clone, PartialEq, Eq)]
enum PublicIpv6ProviderRuntimeState {
Disabled,
Pending(String),
Active(Ipv6Cidr),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct PublicIpv6ProviderConfigSnapshot {
provider_enabled: bool,
configured_prefix: Option<Ipv6Cidr>,
}
fn read_public_ipv6_provider_config_snapshot(
global_ctx: &ArcGlobalCtx,
) -> PublicIpv6ProviderConfigSnapshot {
PublicIpv6ProviderConfigSnapshot {
provider_enabled: global_ctx.config.get_ipv6_public_addr_provider(),
configured_prefix: global_ctx.config.get_ipv6_public_addr_prefix(),
}
}
fn should_run_public_ipv6_provider_reconcile_task(
config: PublicIpv6ProviderConfigSnapshot,
) -> bool {
config.provider_enabled && config.configured_prefix.is_none()
}
pub(super) fn should_run_public_ipv6_provider_reconcile(global_ctx: &ArcGlobalCtx) -> bool {
should_run_public_ipv6_provider_reconcile_task(read_public_ipv6_provider_config_snapshot(
global_ctx,
))
}
fn is_global_routable_public_ipv6_prefix(prefix: Ipv6Cidr) -> bool {
let addr = prefix.first_address();
!addr.is_loopback()
&& !addr.is_multicast()
&& !addr.is_unicast_link_local()
&& !addr.is_unique_local()
&& !addr.is_unspecified()
}
pub(super) fn validate_public_ipv6_config_values(
_ipv6: Option<Ipv6Inet>,
provider_enabled: bool,
_auto_enabled: bool,
prefix: Option<Ipv6Cidr>,
) -> Result<(), Error> {
if !provider_enabled {
return Ok(());
}
ensure_public_ipv6_provider_supported()?;
if let Some(prefix) = prefix
&& !is_global_routable_public_ipv6_prefix(prefix)
{
return Err(anyhow::anyhow!(
"the prefix {} is not a valid global unicast IPv6 prefix; it must be a routable address range, not a private, link-local, or multicast address",
prefix
)
.into());
}
Ok(())
}
pub(super) fn validate_public_ipv6_config(global_ctx: &ArcGlobalCtx) -> Result<(), Error> {
validate_public_ipv6_config_values(
global_ctx.get_ipv6(),
global_ctx.config.get_ipv6_public_addr_provider(),
global_ctx.config.get_ipv6_public_addr_auto(),
global_ctx.config.get_ipv6_public_addr_prefix(),
)
}
fn ensure_public_ipv6_provider_supported() -> Result<(), Error> {
if cfg!(target_os = "linux") {
return Ok(());
}
Err(anyhow::anyhow!(
"the provider feature requires Linux; run without --ipv6-public-addr-provider on this node, or move the provider role to a Linux node. client mode (--ipv6-public-addr-auto) works on all platforms"
)
.into())
}
fn public_ipv6_provider_auto_detect_error() -> Error {
anyhow::anyhow!(
"no public IPv6 prefix found on this system; set --ipv6-public-addr-prefix manually, or check that your ISP has delegated an IPv6 prefix and a default-from route exists in the kernel routing table"
)
.into()
}
#[cfg(target_os = "linux")]
fn read_linux_proc_bool(path: &Path) -> Result<bool, Error> {
let value = std::fs::read_to_string(path)
.with_context(|| format!("failed to read {}", path.display()))?;
match value.trim() {
"0" => Ok(false),
"1" => Ok(true),
other => Err(anyhow::anyhow!("unexpected value '{}' in {}", other, path.display()).into()),
}
}
#[cfg(target_os = "linux")]
fn write_linux_proc_bool(path: &Path, enabled: bool) -> Result<(), Error> {
let value = if enabled { "1\n" } else { "0\n" };
std::fs::write(path, value).with_context(|| format!("failed to write {}", path.display()))?;
Ok(())
}
#[cfg(target_os = "linux")]
fn ensure_linux_ipv6_forwarding_at_paths(
all_path: &Path,
default_path: &Path,
) -> Result<bool, Error> {
let all_enabled = read_linux_proc_bool(all_path)?;
let default_enabled = read_linux_proc_bool(default_path)?;
let mut changed = false;
if !all_enabled {
write_linux_proc_bool(all_path, true)?;
changed = true;
}
if !default_enabled {
write_linux_proc_bool(default_path, true)?;
changed = true;
}
if !read_linux_proc_bool(all_path)? || !read_linux_proc_bool(default_path)? {
return Err(anyhow::anyhow!(
"failed to enable Linux IPv6 forwarding in {} and {}",
all_path.display(),
default_path.display()
)
.into());
}
Ok(changed)
}
#[cfg(target_os = "linux")]
fn ensure_linux_ipv6_forwarding() -> Result<bool, Error> {
let all_path = Path::new("/proc/sys/net/ipv6/conf/all/forwarding");
let default_path = Path::new("/proc/sys/net/ipv6/conf/default/forwarding");
ensure_linux_ipv6_forwarding_at_paths(all_path, default_path).map_err(|err| {
anyhow::anyhow!(
"public IPv6 provider requires Linux IPv6 forwarding; failed to enable net.ipv6.conf.all.forwarding=1 and net.ipv6.conf.default.forwarding=1 automatically: {}. run with sufficient privileges or set them manually",
err
)
.into()
})
}
#[cfg(target_os = "linux")]
#[derive(Clone, Debug, PartialEq, Eq)]
struct DetectedIpv6Route {
dst: Option<Ipv6Cidr>,
src: Option<Ipv6Cidr>,
ifindex: Option<u32>,
kind: RouteType,
}
#[cfg(target_os = "linux")]
fn ipv6_cidr_from_route_addr(addr: RouteAddress, prefix_len: u8) -> Option<Ipv6Cidr> {
match addr {
RouteAddress::Inet6(addr) => Ipv6Cidr::new(addr, prefix_len).ok(),
_ => None,
}
}
#[cfg(target_os = "linux")]
impl TryFrom<RouteMessage> for DetectedIpv6Route {
type Error = Error;
fn try_from(message: RouteMessage) -> Result<Self, Self::Error> {
let dst = message.attributes.iter().find_map(|attr| match attr {
RouteAttribute::Destination(addr) => {
ipv6_cidr_from_route_addr(addr.clone(), message.header.destination_prefix_length)
}
_ => None,
});
let src = message.attributes.iter().find_map(|attr| match attr {
RouteAttribute::Source(addr) => {
ipv6_cidr_from_route_addr(addr.clone(), message.header.source_prefix_length)
}
_ => None,
});
let ifindex = message.attributes.iter().find_map(|attr| match attr {
RouteAttribute::Oif(index) => Some(*index),
_ => None,
});
Ok(Self {
dst,
src,
ifindex,
kind: message.header.kind,
})
}
}
#[cfg(target_os = "linux")]
fn is_ipv6_default_route(dst: Option<Ipv6Cidr>) -> bool {
dst.is_none() || dst == Some(Ipv6Cidr::new(std::net::Ipv6Addr::UNSPECIFIED, 0).unwrap())
}
#[cfg(target_os = "linux")]
fn detect_public_ipv6_prefix_from_routes(
routes: &[DetectedIpv6Route],
loopback_ifindex: u32,
) -> Option<Ipv6Cidr> {
routes
.iter()
.filter_map(|route| {
if !is_ipv6_default_route(route.dst) {
return None;
}
let prefix = route.src?;
let wan_ifindex = route.ifindex?;
if !is_global_routable_public_ipv6_prefix(prefix) {
return None;
}
let delegated = routes.iter().any(|candidate| {
candidate.dst == Some(prefix)
&& candidate.ifindex.is_some()
&& candidate.ifindex != Some(wan_ifindex)
&& candidate.ifindex != Some(loopback_ifindex)
&& candidate.kind == RouteType::Unicast
});
delegated.then_some(prefix)
})
.min_by_key(|prefix| prefix.network_length())
}
#[cfg(target_os = "linux")]
async fn detect_public_ipv6_prefix_linux() -> Result<Option<Ipv6Cidr>, Error> {
let routes = list_ipv6_route_messages().with_context(|| "failed to query linux ipv6 routes")?;
let routes = routes
.iter()
.cloned()
.map(DetectedIpv6Route::try_from)
.collect::<Result<Vec<_>, _>>()?;
let loopback_ifindex =
get_interface_index("lo").with_context(|| "failed to resolve linux loopback ifindex")?;
Ok(detect_public_ipv6_prefix_from_routes(
&routes,
loopback_ifindex,
))
}
#[cfg(not(target_os = "linux"))]
async fn detect_public_ipv6_prefix_linux() -> Result<Option<Ipv6Cidr>, Error> {
Ok(None)
}
fn invalid_public_ipv6_prefix_state(
prefix: Ipv6Cidr,
source: &str,
) -> PublicIpv6ProviderRuntimeState {
PublicIpv6ProviderRuntimeState::Pending(format!(
"the {} prefix {} is not a valid global unicast IPv6 prefix",
source, prefix
))
}
#[cfg(target_os = "linux")]
async fn resolve_public_ipv6_provider_runtime_state_linux(
global_ctx: &ArcGlobalCtx,
configured_prefix: Option<Ipv6Cidr>,
) -> PublicIpv6ProviderRuntimeState {
let _g = global_ctx.net_ns.guard();
if let Err(err) = ensure_linux_ipv6_forwarding() {
return PublicIpv6ProviderRuntimeState::Pending(err.to_string());
}
if let Some(prefix) = configured_prefix {
if !is_global_routable_public_ipv6_prefix(prefix) {
return invalid_public_ipv6_prefix_state(prefix, "configured");
}
return PublicIpv6ProviderRuntimeState::Active(prefix);
}
match detect_public_ipv6_prefix_linux().await {
Ok(Some(prefix)) if is_global_routable_public_ipv6_prefix(prefix) => {
PublicIpv6ProviderRuntimeState::Active(prefix)
}
Ok(Some(prefix)) => invalid_public_ipv6_prefix_state(prefix, "detected"),
Ok(None) => PublicIpv6ProviderRuntimeState::Pending(
public_ipv6_provider_auto_detect_error().to_string(),
),
Err(err) => PublicIpv6ProviderRuntimeState::Pending(err.to_string()),
}
}
async fn resolve_public_ipv6_provider_runtime_state(
global_ctx: &ArcGlobalCtx,
config: PublicIpv6ProviderConfigSnapshot,
) -> PublicIpv6ProviderRuntimeState {
if !config.provider_enabled {
return PublicIpv6ProviderRuntimeState::Disabled;
}
#[cfg(target_os = "linux")]
{
return resolve_public_ipv6_provider_runtime_state_linux(
global_ctx,
config.configured_prefix,
)
.await;
}
#[cfg(not(target_os = "linux"))]
{
let _ = config.configured_prefix;
PublicIpv6ProviderRuntimeState::Pending(
ensure_public_ipv6_provider_supported()
.unwrap_err()
.to_string(),
)
}
}
fn apply_public_ipv6_provider_runtime_state(
global_ctx: &ArcGlobalCtx,
state: &PublicIpv6ProviderRuntimeState,
) -> bool {
let next_prefix = match state {
PublicIpv6ProviderRuntimeState::Active(prefix) => Some(*prefix),
PublicIpv6ProviderRuntimeState::Disabled | PublicIpv6ProviderRuntimeState::Pending(_) => {
None
}
};
let prefix_changed = global_ctx.set_advertised_ipv6_public_addr_prefix(next_prefix);
let next_provider_enabled = matches!(state, PublicIpv6ProviderRuntimeState::Active(_));
let feature_changed =
global_ctx.set_ipv6_public_addr_provider_feature_flag(next_provider_enabled);
prefix_changed || feature_changed
}
fn try_apply_public_ipv6_provider_runtime_state(
global_ctx: &ArcGlobalCtx,
config: PublicIpv6ProviderConfigSnapshot,
state: &PublicIpv6ProviderRuntimeState,
) -> Option<bool> {
(read_public_ipv6_provider_config_snapshot(global_ctx) == config)
.then(|| apply_public_ipv6_provider_runtime_state(global_ctx, state))
}
fn current_public_ipv6_provider_runtime_state(
global_ctx: &ArcGlobalCtx,
) -> PublicIpv6ProviderRuntimeState {
match (
global_ctx.get_feature_flags().ipv6_public_addr_provider,
global_ctx.get_advertised_ipv6_public_addr_prefix(),
) {
(false, _) => PublicIpv6ProviderRuntimeState::Disabled,
(true, Some(prefix)) => PublicIpv6ProviderRuntimeState::Active(prefix),
(true, None) => PublicIpv6ProviderRuntimeState::Pending(
"public IPv6 provider runtime is missing an advertised prefix".to_string(),
),
}
}
async fn reconcile_public_ipv6_provider_runtime_with_state(
global_ctx: &ArcGlobalCtx,
) -> (PublicIpv6ProviderRuntimeState, bool) {
for attempt in 0..PUBLIC_IPV6_PROVIDER_RECONCILE_MAX_RETRIES {
let config = read_public_ipv6_provider_config_snapshot(global_ctx);
let next_state = resolve_public_ipv6_provider_runtime_state(global_ctx, config).await;
if let Some(changed) =
try_apply_public_ipv6_provider_runtime_state(global_ctx, config, &next_state)
{
return (next_state, changed);
}
tracing::debug!(
attempt = attempt + 1,
max_retries = PUBLIC_IPV6_PROVIDER_RECONCILE_MAX_RETRIES,
"public IPv6 provider config changed during reconcile, retrying"
);
}
tracing::warn!(
max_retries = PUBLIC_IPV6_PROVIDER_RECONCILE_MAX_RETRIES,
"skipping public IPv6 provider reconcile because config kept changing"
);
(
current_public_ipv6_provider_runtime_state(global_ctx),
false,
)
}
pub(super) async fn reconcile_public_ipv6_provider_runtime(global_ctx: &ArcGlobalCtx) -> bool {
reconcile_public_ipv6_provider_runtime_with_state(global_ctx)
.await
.1
}
pub(super) fn run_public_ipv6_provider_reconcile_task(global_ctx: &ArcGlobalCtx) {
if !should_run_public_ipv6_provider_reconcile_task(read_public_ipv6_provider_config_snapshot(
global_ctx,
)) {
return;
}
let global_ctx = Arc::downgrade(global_ctx);
tokio::spawn(async move {
let Some(initial_ctx) = global_ctx.upgrade() else {
return;
};
let mut event_receiver = initial_ctx.subscribe();
let mut last_state: Option<PublicIpv6ProviderRuntimeState> = None;
loop {
let Some(global_ctx) = global_ctx.upgrade() else {
tracing::debug!("global ctx dropped, stopping public ipv6 provider reconcile");
return;
};
let (next_state, changed) =
reconcile_public_ipv6_provider_runtime_with_state(&global_ctx).await;
if last_state.as_ref() != Some(&next_state) {
match &next_state {
PublicIpv6ProviderRuntimeState::Disabled if last_state.is_some() => {
tracing::info!("public IPv6 provider disabled");
}
PublicIpv6ProviderRuntimeState::Disabled => {}
PublicIpv6ProviderRuntimeState::Pending(reason) => {
tracing::warn!(reason = %reason, "public IPv6 provider not ready");
}
PublicIpv6ProviderRuntimeState::Active(prefix) => {
tracing::info!(prefix = %prefix, "public IPv6 provider is active");
}
}
} else if changed {
tracing::info!("public IPv6 provider runtime state changed");
}
last_state = Some(next_state);
if matches!(
last_state.as_ref(),
Some(PublicIpv6ProviderRuntimeState::Disabled)
) {
match event_receiver.recv().await {
Ok(GlobalCtxEvent::ConfigPatched(_)) => {}
Ok(_) => {}
Err(tokio::sync::broadcast::error::RecvError::Closed) => return,
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
event_receiver = event_receiver.resubscribe();
}
}
} else {
tokio::select! {
recv = event_receiver.recv() => match recv {
Ok(GlobalCtxEvent::ConfigPatched(_)) => {}
Ok(_) => {}
Err(tokio::sync::broadcast::error::RecvError::Closed) => return,
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
event_receiver = event_receiver.resubscribe();
}
},
_ = tokio::time::sleep(PUBLIC_IPV6_PROVIDER_RECONCILE_INTERVAL) => {}
}
}
}
});
}
#[cfg(test)]
mod tests {
#[cfg(target_os = "linux")]
use std::fs;
#[cfg(target_os = "linux")]
use std::path::PathBuf;
#[cfg(target_os = "linux")]
use std::process::Command;
use std::sync::Arc;
#[cfg(target_os = "linux")]
use netlink_packet_route::route::RouteType;
#[cfg(target_os = "linux")]
use super::{
DetectedIpv6Route, detect_public_ipv6_prefix_from_routes, detect_public_ipv6_prefix_linux,
ensure_linux_ipv6_forwarding_at_paths, ensure_public_ipv6_provider_supported,
public_ipv6_provider_auto_detect_error,
};
use super::{
PublicIpv6ProviderConfigSnapshot, PublicIpv6ProviderRuntimeState,
read_public_ipv6_provider_config_snapshot, should_run_public_ipv6_provider_reconcile_task,
try_apply_public_ipv6_provider_runtime_state,
};
#[cfg(not(target_os = "linux"))]
use super::{ensure_public_ipv6_provider_supported, public_ipv6_provider_auto_detect_error};
use crate::common::{
config::{ConfigLoader, TomlConfigLoader},
global_ctx::GlobalCtx,
};
#[cfg(target_os = "linux")]
fn run_ip(args: &[&str]) {
let output = Command::new("ip")
.args(args)
.output()
.expect("failed to execute ip process");
assert!(
output.status.success(),
"ip command failed: {:?}\nstdout: {}\nstderr: {}",
args,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr),
);
}
#[cfg(target_os = "linux")]
fn test_iface_name(tag: &str) -> String {
format!("et{}{:x}", tag, std::process::id() & 0xffff)
}
#[cfg(target_os = "linux")]
struct ScopedDummyLink {
name: String,
}
#[cfg(target_os = "linux")]
impl ScopedDummyLink {
fn new(name: &str) -> Self {
let _ = Command::new("ip").args(["link", "del", name]).output();
run_ip(&["link", "add", name, "type", "dummy"]);
run_ip(&["link", "set", name, "up"]);
Self {
name: name.to_string(),
}
}
}
#[cfg(target_os = "linux")]
impl Drop for ScopedDummyLink {
fn drop(&mut self) {
let _ = Command::new("ip")
.args(["link", "del", &self.name])
.output();
}
}
#[cfg(target_os = "linux")]
fn temp_forwarding_paths(
all_value: &str,
default_value: &str,
) -> (tempfile::TempDir, PathBuf, PathBuf) {
let dir = tempfile::tempdir().unwrap();
let all_path = dir.path().join("all_forwarding");
let default_path = dir.path().join("default_forwarding");
fs::write(&all_path, all_value).unwrap();
fs::write(&default_path, default_value).unwrap();
(dir, all_path, default_path)
}
#[cfg(target_os = "linux")]
fn route(
dst: Option<&str>,
src: Option<&str>,
ifindex: Option<u32>,
kind: RouteType,
) -> DetectedIpv6Route {
DetectedIpv6Route {
dst: dst.map(|cidr| cidr.parse().unwrap()),
src: src.map(|cidr| cidr.parse().unwrap()),
ifindex,
kind,
}
}
#[cfg(target_os = "linux")]
#[test]
fn test_detect_public_ipv6_prefix_from_routes_selects_delegated_prefix() {
let routes = vec![
route(None, Some("2001:db8:1::/56"), Some(2), RouteType::Unicast),
route(Some("2001:db8:1::/56"), None, Some(3), RouteType::Unicast),
];
assert_eq!(
detect_public_ipv6_prefix_from_routes(&routes, 1),
Some("2001:db8:1::/56".parse().unwrap())
);
}
#[cfg(target_os = "linux")]
#[test]
fn test_detect_public_ipv6_prefix_from_routes_rejects_non_public_prefixes() {
let routes = vec![
route(Some("::/0"), Some("fd00::/48"), Some(2), RouteType::Unicast),
route(Some("fd00::/48"), None, Some(3), RouteType::Unicast),
route(None, Some("fe80::/64"), Some(4), RouteType::Unicast),
route(Some("fe80::/64"), None, Some(5), RouteType::Unicast),
route(None, Some("ff00::/8"), Some(6), RouteType::Unicast),
route(Some("ff00::/8"), None, Some(7), RouteType::Unicast),
route(None, Some("::/0"), Some(8), RouteType::Unicast),
route(Some("::/0"), None, Some(9), RouteType::Unicast),
];
assert_eq!(detect_public_ipv6_prefix_from_routes(&routes, 1), None);
}
#[cfg(target_os = "linux")]
#[test]
fn test_detect_public_ipv6_prefix_from_routes_requires_delegated_route() {
let routes = vec![route(
None,
Some("2001:db8:1::/56"),
Some(2),
RouteType::Unicast,
)];
assert_eq!(detect_public_ipv6_prefix_from_routes(&routes, 1), None);
}
#[cfg(target_os = "linux")]
#[test]
fn test_detect_public_ipv6_prefix_from_routes_rejects_loopback_delegation() {
let routes = vec![
route(None, Some("2001:db8:1::/56"), Some(2), RouteType::Unicast),
route(Some("2001:db8:1::/56"), None, Some(1), RouteType::Unicast),
];
assert_eq!(detect_public_ipv6_prefix_from_routes(&routes, 1), None);
}
#[cfg(target_os = "linux")]
#[test]
fn test_detect_public_ipv6_prefix_from_routes_prefers_shortest_prefix() {
let routes = vec![
route(None, Some("2001:db8:1::/56"), Some(2), RouteType::Unicast),
route(Some("2001:db8:1::/56"), None, Some(3), RouteType::Unicast),
route(None, Some("2001:db8::/48"), Some(4), RouteType::Unicast),
route(Some("2001:db8::/48"), None, Some(5), RouteType::Unicast),
];
assert_eq!(
detect_public_ipv6_prefix_from_routes(&routes, 1),
Some("2001:db8::/48".parse().unwrap())
);
}
#[cfg(target_os = "linux")]
#[test]
fn test_detect_public_ipv6_prefix_from_routes_rejects_non_unicast_delegation() {
let routes = vec![
route(None, Some("2001:db8:1::/56"), Some(2), RouteType::Unicast),
route(Some("2001:db8:1::/56"), None, Some(3), RouteType::BlackHole),
];
assert_eq!(detect_public_ipv6_prefix_from_routes(&routes, 1), None);
}
#[test]
fn test_public_ipv6_provider_auto_detect_error_mentions_manual_prefix() {
let err = public_ipv6_provider_auto_detect_error();
let msg = err.to_string();
assert!(msg.contains("IPv6 prefix"), "{}", msg);
assert!(msg.contains("ipv6-public-addr-prefix"), "{}", msg);
}
fn test_global_ctx() -> Arc<GlobalCtx> {
Arc::new(GlobalCtx::new(TomlConfigLoader::default()))
}
#[tokio::test]
async fn test_read_public_ipv6_provider_config_snapshot_reads_provider_fields() {
let global_ctx = test_global_ctx();
let prefix = "2001:db8::/48".parse().unwrap();
global_ctx.config.set_ipv6_public_addr_provider(true);
global_ctx.config.set_ipv6_public_addr_prefix(Some(prefix));
assert_eq!(
read_public_ipv6_provider_config_snapshot(&global_ctx),
PublicIpv6ProviderConfigSnapshot {
provider_enabled: true,
configured_prefix: Some(prefix),
}
);
}
#[test]
fn test_reconcile_task_only_runs_for_auto_detect_provider() {
assert!(!should_run_public_ipv6_provider_reconcile_task(
PublicIpv6ProviderConfigSnapshot {
provider_enabled: false,
configured_prefix: None,
}
));
assert!(!should_run_public_ipv6_provider_reconcile_task(
PublicIpv6ProviderConfigSnapshot {
provider_enabled: true,
configured_prefix: Some("2001:db8::/48".parse().unwrap()),
}
));
assert!(should_run_public_ipv6_provider_reconcile_task(
PublicIpv6ProviderConfigSnapshot {
provider_enabled: true,
configured_prefix: None,
}
));
}
#[tokio::test]
async fn test_try_apply_public_ipv6_provider_runtime_state_rejects_stale_config() {
let global_ctx = test_global_ctx();
let prefix = "2001:db8::/48".parse().unwrap();
let config = PublicIpv6ProviderConfigSnapshot {
provider_enabled: true,
configured_prefix: Some(prefix),
};
global_ctx.config.set_ipv6_public_addr_provider(false);
global_ctx.config.set_ipv6_public_addr_prefix(None);
let changed = try_apply_public_ipv6_provider_runtime_state(
&global_ctx,
config,
&PublicIpv6ProviderRuntimeState::Active(prefix),
);
assert_eq!(changed, None);
assert_eq!(global_ctx.get_advertised_ipv6_public_addr_prefix(), None);
assert!(!global_ctx.get_feature_flags().ipv6_public_addr_provider);
}
#[tokio::test]
async fn test_try_apply_public_ipv6_provider_runtime_state_applies_matching_config() {
let global_ctx = test_global_ctx();
let prefix = "2001:db8::/48".parse().unwrap();
global_ctx.config.set_ipv6_public_addr_provider(true);
global_ctx.config.set_ipv6_public_addr_prefix(Some(prefix));
let config = read_public_ipv6_provider_config_snapshot(&global_ctx);
let changed = try_apply_public_ipv6_provider_runtime_state(
&global_ctx,
config,
&PublicIpv6ProviderRuntimeState::Active(prefix),
);
assert_eq!(changed, Some(true));
assert_eq!(
global_ctx.get_advertised_ipv6_public_addr_prefix(),
Some(prefix)
);
assert!(global_ctx.get_feature_flags().ipv6_public_addr_provider);
}
#[cfg(target_os = "linux")]
#[test]
fn test_public_ipv6_provider_platform_check_accepts_linux() {
assert!(ensure_public_ipv6_provider_supported().is_ok());
}
#[cfg(target_os = "linux")]
#[test]
fn test_ensure_linux_ipv6_forwarding_enables_all_and_default() {
let (_dir, all_path, default_path) = temp_forwarding_paths("0\n", "0\n");
let changed = ensure_linux_ipv6_forwarding_at_paths(&all_path, &default_path).unwrap();
assert!(changed);
assert_eq!(fs::read_to_string(&all_path).unwrap(), "1\n");
assert_eq!(fs::read_to_string(&default_path).unwrap(), "1\n");
}
#[cfg(target_os = "linux")]
#[test]
fn test_ensure_linux_ipv6_forwarding_is_noop_when_already_enabled() {
let (_dir, all_path, default_path) = temp_forwarding_paths("1\n", "1\n");
let changed = ensure_linux_ipv6_forwarding_at_paths(&all_path, &default_path).unwrap();
assert!(!changed);
assert_eq!(fs::read_to_string(&all_path).unwrap(), "1\n");
assert_eq!(fs::read_to_string(&default_path).unwrap(), "1\n");
}
#[cfg(not(target_os = "linux"))]
#[test]
fn test_public_ipv6_provider_platform_check_reports_linux_only() {
let err = ensure_public_ipv6_provider_supported().unwrap_err();
let msg = err.to_string();
assert!(msg.contains("Linux"), "{}", msg);
assert!(msg.contains("ipv6-public-addr-auto"), "{}", msg);
}
#[cfg(target_os = "linux")]
#[serial_test::serial]
#[tokio::test]
async fn test_detect_public_ipv6_prefix_linux_reads_netlink_routes_from_kernel() {
let wan_if = test_iface_name("dw");
let lan_if = test_iface_name("dl");
let _wan = ScopedDummyLink::new(&wan_if);
let _lan = ScopedDummyLink::new(&lan_if);
run_ip(&[
"-6",
"addr",
"add",
"2001:db8:100:ffff::1/64",
"dev",
&wan_if,
]);
run_ip(&[
"-6",
"route",
"add",
"default",
"from",
"2001:db8:100::/56",
"dev",
&wan_if,
]);
run_ip(&["-6", "route", "add", "2001:db8:100::/56", "dev", &lan_if]);
assert_eq!(
detect_public_ipv6_prefix_linux().await.unwrap(),
Some("2001:db8:100::/56".parse().unwrap())
);
}
#[cfg(target_os = "linux")]
#[serial_test::serial]
#[tokio::test]
async fn test_detect_public_ipv6_prefix_linux_prefers_shortest_prefix_from_kernel() {
let wan_if_1 = test_iface_name("sw1");
let lan_if_1 = test_iface_name("sl1");
let wan_if_2 = test_iface_name("sw2");
let lan_if_2 = test_iface_name("sl2");
let _wan_1 = ScopedDummyLink::new(&wan_if_1);
let _lan_1 = ScopedDummyLink::new(&lan_if_1);
let _wan_2 = ScopedDummyLink::new(&wan_if_2);
let _lan_2 = ScopedDummyLink::new(&lan_if_2);
run_ip(&[
"-6",
"addr",
"add",
"2001:db8:3000:ffff::1/64",
"dev",
&wan_if_1,
]);
run_ip(&[
"-6",
"route",
"add",
"default",
"from",
"2001:db8:3000::/56",
"dev",
&wan_if_1,
]);
run_ip(&["-6", "route", "add", "2001:db8:3000::/56", "dev", &lan_if_1]);
run_ip(&["-6", "addr", "add", "2001:db9:ffff::1/64", "dev", &wan_if_2]);
run_ip(&[
"-6",
"route",
"add",
"default",
"from",
"2001:db9::/48",
"dev",
&wan_if_2,
]);
run_ip(&["-6", "route", "add", "2001:db9::/48", "dev", &lan_if_2]);
assert_eq!(
detect_public_ipv6_prefix_linux().await.unwrap(),
Some("2001:db9::/48".parse().unwrap())
);
}
}
+194 -4
View File
@@ -735,9 +735,26 @@ impl VirtualNic {
}
pub async fn add_ipv6_route(&self, address: Ipv6Addr, cidr: u8) -> Result<(), Error> {
self.add_ipv6_route_with_cost(address, cidr, None).await
}
pub async fn add_ipv6_route_with_cost(
&self,
address: Ipv6Addr,
cidr: u8,
cost: Option<i32>,
) -> Result<(), Error> {
let _g = self.global_ctx.net_ns.guard();
self.ifcfg
.add_ipv6_route(self.ifname(), address, cidr, None)
.add_ipv6_route(self.ifname(), address, cidr, cost)
.await?;
Ok(())
}
pub async fn remove_ipv6_route(&self, address: Ipv6Addr, cidr: u8) -> Result<(), Error> {
let _g = self.global_ctx.net_ns.guard();
self.ifcfg
.remove_ipv6_route(self.ifname(), address, cidr)
.await?;
Ok(())
}
@@ -903,7 +920,7 @@ impl NicCtx {
}
let src_ipv6 = ipv6.get_source();
let dst_ipv6 = ipv6.get_destination();
let my_ipv6 = mgr.get_global_ctx().get_ipv6().map(|x| x.address());
let is_local_src = mgr.get_global_ctx().is_ip_local_ipv6(&src_ipv6);
tracing::trace!(
?ret,
?src_ipv6,
@@ -911,14 +928,14 @@ impl NicCtx {
"[USER_PACKET] recv new packet from tun device and forward to peers."
);
if src_ipv6.is_unicast_link_local() && Some(src_ipv6) != my_ipv6 {
if src_ipv6.is_unicast_link_local() && !is_local_src {
// do not route link local packet to other nodes unless the address is assigned by user
return;
}
// TODO: use zero-copy
let send_ret = mgr
.send_msg_by_ip(ret, IpAddr::V6(dst_ipv6), Some(src_ipv6) == my_ipv6)
.send_msg_by_ip(ret, IpAddr::V6(dst_ipv6), is_local_src)
.await;
if send_ret.is_err() {
tracing::trace!(?send_ret, "[USER_PACKET] send_msg failed")
@@ -1039,6 +1056,44 @@ impl NicCtx {
}
}
async fn apply_public_ipv6_route_changes(
ifcfg: &impl IfConfiguerTrait,
ifname: &str,
net_ns: &crate::common::netns::NetNS,
cur_routes: &mut BTreeSet<cidr::Ipv6Inet>,
added: Vec<cidr::Ipv6Inet>,
removed: Vec<cidr::Ipv6Inet>,
) {
for route in removed {
if !cur_routes.contains(&route) {
continue;
}
let _g = net_ns.guard();
let ret = ifcfg
.remove_ipv6_route(ifname, route.address(), route.network_length())
.await;
if ret.is_err() {
tracing::trace!(route = ?route, err = ?ret, "remove public ipv6 route failed");
}
cur_routes.remove(&route);
}
for route in added {
if cur_routes.contains(&route) {
continue;
}
let _g = net_ns.guard();
let ret = ifcfg
.add_ipv6_route(ifname, route.address(), route.network_length(), None)
.await;
if ret.is_err() {
tracing::trace!(route = ?route, err = ?ret, "add public ipv6 route failed");
} else {
cur_routes.insert(route);
}
}
}
async fn run_proxy_cidrs_route_updater(&mut self) -> Result<(), Error> {
let Some(peer_mgr) = self.peer_mgr.upgrade() else {
return Err(anyhow::anyhow!("peer manager not available").into());
@@ -1114,6 +1169,137 @@ impl NicCtx {
Ok(())
}
async fn run_public_ipv6_route_updater(&mut self) -> Result<(), Error> {
let Some(peer_mgr) = self.peer_mgr.upgrade() else {
return Err(anyhow::anyhow!("peer manager not available").into());
};
let global_ctx = self.global_ctx.clone();
let net_ns = self.global_ctx.net_ns.clone();
let nic = self.nic.lock().await;
let ifcfg = nic.get_ifcfg();
let ifname = nic.ifname().to_owned();
let mut event_receiver = global_ctx.subscribe();
self.tasks.spawn(async move {
let mut cur_routes = BTreeSet::<cidr::Ipv6Inet>::new();
let initial_routes = peer_mgr.list_public_ipv6_routes().await;
let initial_added = initial_routes.iter().copied().collect::<Vec<_>>();
Self::apply_public_ipv6_route_changes(
&ifcfg,
&ifname,
&net_ns,
&mut cur_routes,
initial_added,
Vec::new(),
)
.await;
loop {
let event = match event_receiver.recv().await {
Ok(event) => event,
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
event_receiver = event_receiver.resubscribe();
let latest = peer_mgr.list_public_ipv6_routes().await;
let added = latest.difference(&cur_routes).copied().collect::<Vec<_>>();
let removed = cur_routes.difference(&latest).copied().collect::<Vec<_>>();
GlobalCtxEvent::PublicIpv6RoutesUpdated(added, removed)
}
};
let (added, removed) = match event {
GlobalCtxEvent::PublicIpv6RoutesUpdated(added, removed) => (added, removed),
_ => continue,
};
Self::apply_public_ipv6_route_changes(
&ifcfg,
&ifname,
&net_ns,
&mut cur_routes,
added,
removed,
)
.await;
}
});
Ok(())
}
async fn run_public_ipv6_addr_updater(&mut self) -> Result<(), Error> {
let Some(peer_mgr) = self.peer_mgr.upgrade() else {
return Err(anyhow::anyhow!("peer manager not available").into());
};
let global_ctx = self.global_ctx.clone();
let nic = self.nic.clone();
let mut event_receiver = global_ctx.subscribe();
self.tasks.spawn(async move {
let mut current_addr = peer_mgr.get_my_public_ipv6_addr().await;
if let Some(addr) = current_addr {
let nic = nic.lock().await;
if let Err(err) = nic.link_up().await {
tracing::warn!(?err, "failed to bring public ipv6 nic link up");
}
if let Err(err) = nic.add_ipv6(addr.address(), addr.network_length() as i32).await {
tracing::warn!(addr = ?addr, ?err, "failed to add public ipv6 address");
}
if let Err(err) = nic
.add_ipv6_route_with_cost(Ipv6Addr::UNSPECIFIED, 0, Some(5))
.await
{
tracing::warn!(route = %Ipv6Addr::UNSPECIFIED, prefix = 0, ?err, "failed to add default public ipv6 route");
}
}
loop {
let event = match event_receiver.recv().await {
Ok(event) => event,
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
event_receiver = event_receiver.resubscribe();
let latest = peer_mgr.get_my_public_ipv6_addr().await;
GlobalCtxEvent::PublicIpv6Changed(current_addr, latest)
}
};
let (old, new) = match event {
GlobalCtxEvent::PublicIpv6Changed(old, new) => (old, new),
_ => continue,
};
current_addr = new;
let nic = nic.lock().await;
if let Err(err) = nic.link_up().await {
tracing::warn!(?err, "failed to bring public ipv6 nic link up");
}
if let Some(old) = old {
if let Err(err) = nic.remove_ipv6_route(Ipv6Addr::UNSPECIFIED, 0).await {
tracing::warn!(route = %Ipv6Addr::UNSPECIFIED, prefix = 0, ?err, "failed to remove default public ipv6 route");
}
if let Err(err) = nic.remove_ipv6(Some(old)).await {
tracing::warn!(addr = ?old, ?err, "failed to remove old public ipv6 address");
}
}
if let Some(new) = new {
if let Err(err) = nic.add_ipv6(new.address(), new.network_length() as i32).await
{
tracing::warn!(addr = ?new, ?err, "failed to add public ipv6 address");
}
if let Err(err) = nic
.add_ipv6_route_with_cost(Ipv6Addr::UNSPECIFIED, 0, Some(5))
.await
{
tracing::warn!(route = %Ipv6Addr::UNSPECIFIED, prefix = 0, ?err, "failed to add default public ipv6 route");
}
}
}
});
Ok(())
}
pub async fn run(
&mut self,
ipv4_addr: Option<cidr::Ipv4Inet>,
@@ -1169,6 +1355,10 @@ impl NicCtx {
}
self.run_proxy_cidrs_route_updater().await?;
self.run_public_ipv6_route_updater().await?;
// Keep the updater running so runtime config patches can enable auto mode
// without recreating the NIC.
self.run_public_ipv6_addr_updater().await?;
Ok(())
}
+14
View File
@@ -435,6 +435,20 @@ fn handle_event(
event!(info, ?ip, "[{}] dhcp ip conflict", instance_id);
}
GlobalCtxEvent::PublicIpv6Changed(old, new) => {
event!(info, ?old, ?new, "[{}] public ipv6 changed", instance_id);
}
GlobalCtxEvent::PublicIpv6RoutesUpdated(added, removed) => {
event!(
info,
?added,
?removed,
"[{}] public ipv6 routes updated",
instance_id
);
}
GlobalCtxEvent::PortForwardAdded(cfg) => {
event!(
info,
+34
View File
@@ -714,6 +714,24 @@ impl NetworkConfig {
flags.use_smoltcp = use_smoltcp;
}
if let Some(ipv6_public_addr_provider) = self.ipv6_public_addr_provider {
cfg.set_ipv6_public_addr_provider(ipv6_public_addr_provider);
}
if let Some(ipv6_public_addr_auto) = self.ipv6_public_addr_auto {
cfg.set_ipv6_public_addr_auto(ipv6_public_addr_auto);
}
if let Some(ipv6_public_addr_prefix) = self
.ipv6_public_addr_prefix
.as_ref()
.filter(|prefix| !prefix.is_empty())
{
cfg.set_ipv6_public_addr_prefix(Some(ipv6_public_addr_prefix.parse().with_context(
|| format!("failed to parse ipv6 public address prefix: {ipv6_public_addr_prefix}"),
)?));
}
if let Some(disable_ipv6) = self.disable_ipv6 {
flags.enable_ipv6 = !disable_ipv6;
}
@@ -798,6 +816,10 @@ impl NetworkConfig {
flags.disable_upnp = disable_upnp;
}
if let Some(disable_relay_data) = self.disable_relay_data {
flags.disable_relay_data = disable_relay_data;
}
if let Some(disable_sym_hole_punching) = self.disable_sym_hole_punching {
flags.disable_sym_hole_punching = disable_sym_hole_punching;
}
@@ -863,6 +885,17 @@ impl NetworkConfig {
result.network_length = Some(ipv4.network_length() as i32);
}
if config.get_ipv6_public_addr_provider() != default_config.get_ipv6_public_addr_provider()
{
result.ipv6_public_addr_provider = Some(config.get_ipv6_public_addr_provider());
}
if config.get_ipv6_public_addr_auto() != default_config.get_ipv6_public_addr_auto() {
result.ipv6_public_addr_auto = Some(config.get_ipv6_public_addr_auto());
}
result.ipv6_public_addr_prefix = config
.get_ipv6_public_addr_prefix()
.map(|prefix| prefix.to_string());
let peers = config.get_peers();
result.networking_method = Some(NetworkingMethod::Manual as i32);
if !peers.is_empty() {
@@ -961,6 +994,7 @@ impl NetworkConfig {
result.disable_tcp_hole_punching = Some(flags.disable_tcp_hole_punching);
result.disable_udp_hole_punching = Some(flags.disable_udp_hole_punching);
result.disable_upnp = Some(flags.disable_upnp);
result.disable_relay_data = Some(flags.disable_relay_data);
result.disable_sym_hole_punching = Some(flags.disable_sym_hole_punching);
result.enable_magic_dns = Some(flags.accept_dns);
result.mtu = Some(flags.mtu as i32);
+1 -24
View File
@@ -65,7 +65,7 @@ impl PeerCenterBase {
return Err(Error::Shutdown);
};
rpc_mgr.rpc_server().registry().register(
PeerCenterRpcServer::new(PeerCenterServer::new(self.peer_mgr.my_peer_id())),
PeerCenterRpcServer::new(PeerCenterServer::new()),
&self.peer_mgr.get_global_ctx().get_network_name(),
);
Ok(())
@@ -486,7 +486,6 @@ impl PeerCenterPeerManagerTrait for PeerMapWithPeerRpcManager {
#[cfg(test)]
mod tests {
use crate::{
peer_center::server::get_global_data,
peers::tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
tunnel::common::tests::wait_for_condition,
};
@@ -515,25 +514,6 @@ mod tests {
.await
.unwrap();
let center_peer = PeerCenterBase::select_center_peer(&peer_mgr_a)
.await
.unwrap();
let center_data = get_global_data(center_peer);
// wait center_data has 3 records for 10 seconds
wait_for_condition(
|| async {
if center_data.global_peer_map.len() == 4 {
println!("center data {:#?}", center_data.global_peer_map);
true
} else {
false
}
},
Duration::from_secs(20),
)
.await;
let mut digest = None;
for pc in peer_centers.iter() {
let rpc_service = pc.get_rpc_service();
@@ -578,8 +558,5 @@ mod tests {
route_cost.end_update();
assert!(!route_cost.need_update());
}
let global_digest = get_global_data(center_peer).digest.load();
assert_eq!(digest.as_ref().unwrap(), &global_digest);
}
}
+96 -30
View File
@@ -6,7 +6,6 @@ use std::{
use crossbeam::atomic::AtomicCell;
use dashmap::DashMap;
use once_cell::sync::Lazy;
use tokio::task::JoinSet;
use crate::{
@@ -35,50 +34,41 @@ pub(crate) struct PeerCenterInfoEntry {
update_time: std::time::Instant,
}
#[derive(Default)]
pub(crate) struct PeerCenterServerGlobalData {
pub(crate) global_peer_map: DashMap<SrcDstPeerPair, PeerCenterInfoEntry>,
pub(crate) peer_report_time: DashMap<PeerId, std::time::Instant>,
pub(crate) digest: AtomicCell<Digest>,
}
// a global unique instance for PeerCenterServer
pub(crate) static GLOBAL_DATA: Lazy<DashMap<PeerId, Arc<PeerCenterServerGlobalData>>> =
Lazy::new(DashMap::new);
pub(crate) fn get_global_data(node_id: PeerId) -> Arc<PeerCenterServerGlobalData> {
GLOBAL_DATA
.entry(node_id)
.or_insert_with(|| Arc::new(PeerCenterServerGlobalData::default()))
.value()
.clone()
#[derive(Debug, Default)]
struct PeerCenterServerData {
global_peer_map: DashMap<SrcDstPeerPair, PeerCenterInfoEntry>,
peer_report_time: DashMap<PeerId, std::time::Instant>,
digest: AtomicCell<Digest>,
}
#[derive(Clone, Debug)]
pub struct PeerCenterServer {
// every peer has its own server, so use per-struct dash map is ok.
my_node_id: PeerId,
data: Arc<PeerCenterServerData>,
tasks: Arc<JoinSet<()>>,
}
impl PeerCenterServer {
pub fn new(my_node_id: PeerId) -> Self {
pub fn new() -> Self {
let data = Arc::new(PeerCenterServerData::default());
let weak_data = Arc::downgrade(&data);
let mut tasks = JoinSet::new();
tasks.spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
PeerCenterServer::clean_outdated_peer(my_node_id).await;
let Some(data) = weak_data.upgrade() else {
break;
};
PeerCenterServer::clean_outdated_peer_data(&data).await;
}
});
PeerCenterServer {
my_node_id,
data,
tasks: Arc::new(tasks),
}
}
async fn clean_outdated_peer(my_node_id: PeerId) {
let data = get_global_data(my_node_id);
async fn clean_outdated_peer_data(data: &PeerCenterServerData) {
data.peer_report_time.retain(|_, v| {
std::time::Instant::now().duration_since(*v) < std::time::Duration::from_secs(180)
});
@@ -88,8 +78,7 @@ impl PeerCenterServer {
});
}
fn calc_global_digest(my_node_id: PeerId) -> Digest {
let data = get_global_data(my_node_id);
fn calc_global_digest_data(data: &PeerCenterServerData) -> Digest {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
data.global_peer_map
.iter()
@@ -117,7 +106,7 @@ impl PeerCenterRpc for PeerCenterServer {
tracing::debug!("receive report_peers");
let data = get_global_data(self.my_node_id);
let data = &self.data;
data.peer_report_time
.insert(my_peer_id, std::time::Instant::now());
@@ -134,7 +123,7 @@ impl PeerCenterRpc for PeerCenterServer {
}
data.digest
.store(PeerCenterServer::calc_global_digest(self.my_node_id));
.store(PeerCenterServer::calc_global_digest_data(data));
Ok(ReportPeersResponse::default())
}
@@ -147,7 +136,7 @@ impl PeerCenterRpc for PeerCenterServer {
) -> Result<GetGlobalPeerMapResponse, rpc_types::error::Error> {
let digest = req.digest;
let data = get_global_data(self.my_node_id);
let data = &self.data;
if digest == data.digest.load() && digest != 0 {
return Ok(GetGlobalPeerMapResponse::default());
}
@@ -171,3 +160,80 @@ impl PeerCenterRpc for PeerCenterServer {
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn server_clones_share_instance_data() {
let server = PeerCenterServer::new();
let server_clone = server.clone();
let mut peers = PeerInfoForGlobalMap::default();
peers
.direct_peers
.insert(100, DirectConnectedPeerInfo { latency_ms: 3 });
server
.report_peers(
BaseController::default(),
ReportPeersRequest {
my_peer_id: 99,
peer_infos: Some(peers),
},
)
.await
.unwrap();
let resp = server_clone
.get_global_peer_map(
BaseController::default(),
GetGlobalPeerMapRequest { digest: 0 },
)
.await
.unwrap();
assert_eq!(1, resp.global_peer_map.len());
assert!(resp.global_peer_map[&99].direct_peers.contains_key(&100));
}
#[tokio::test]
async fn independent_server_instances_do_not_share_data() {
let server_a = PeerCenterServer::new();
let server_b = PeerCenterServer::new();
let mut peers = PeerInfoForGlobalMap::default();
peers
.direct_peers
.insert(101, DirectConnectedPeerInfo { latency_ms: 5 });
server_a
.report_peers(
BaseController::default(),
ReportPeersRequest {
my_peer_id: 100,
peer_infos: Some(peers),
},
)
.await
.unwrap();
let resp_a = server_a
.get_global_peer_map(
BaseController::default(),
GetGlobalPeerMapRequest { digest: 0 },
)
.await
.unwrap();
assert_eq!(1, resp_a.global_peer_map.len());
let resp_b = server_b
.get_global_peer_map(
BaseController::default(),
GetGlobalPeerMapRequest { digest: 0 },
)
.await
.unwrap();
assert!(resp_b.global_peer_map.is_empty());
}
}
+113 -12
View File
@@ -94,6 +94,8 @@ impl AclFilter {
/// Preserves connection tracking and rate limiting state across reloads
/// Now lock-free and doesn't require &mut self!
pub fn reload_rules(&self, acl_config: Option<&Acl>) {
self.outbound_allow_records.clear();
let Some(acl_config) = acl_config else {
self.acl_enabled.store(false, Ordering::Relaxed);
return;
@@ -292,13 +294,33 @@ impl AclFilter {
processor.increment_stat(AclStatKey::PacketsTotal);
}
fn classify_chain_type(
is_in: bool,
packet_info: &PacketInfo,
my_ipv4: Option<Ipv4Addr>,
is_local_ipv6: impl Fn(Ipv6Addr) -> bool,
) -> ChainType {
if !is_in {
return ChainType::Outbound;
}
let is_local_dst = packet_info.dst_ip == my_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED)
|| matches!(packet_info.dst_ip, IpAddr::V6(dst) if is_local_ipv6(dst));
if is_local_dst {
ChainType::Inbound
} else {
ChainType::Forward
}
}
/// Common ACL processing logic
pub fn process_packet_with_acl(
&self,
packet: &ZCPacket,
is_in: bool,
my_ipv4: Option<Ipv4Addr>,
my_ipv6: Option<Ipv6Addr>,
is_local_ipv6: impl Fn(Ipv6Addr) -> bool,
route: &(dyn super::route_trait::Route + Send + Sync + 'static),
) -> bool {
if !self.acl_enabled.load(Ordering::Relaxed) {
@@ -323,17 +345,7 @@ impl AclFilter {
}
};
let chain_type = if is_in {
if packet_info.dst_ip == my_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED)
|| packet_info.dst_ip == my_ipv6.unwrap_or(Ipv6Addr::UNSPECIFIED)
{
ChainType::Inbound
} else {
ChainType::Forward
}
} else {
ChainType::Outbound
};
let chain_type = Self::classify_chain_type(is_in, &packet_info, my_ipv4, is_local_ipv6);
// Get current processor atomically
let processor = self.get_processor();
@@ -384,3 +396,92 @@ impl AclFilter {
}
}
}
#[cfg(test)]
mod tests {
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
sync::Arc,
time::Instant,
};
use crate::{
common::acl_processor::PacketInfo,
proto::acl::{Acl, ChainType, Protocol},
};
use super::{AclFilter, OutboundAllowRecord};
fn packet_info(dst_ip: IpAddr) -> PacketInfo {
PacketInfo {
src_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
dst_ip,
src_port: Some(1234),
dst_port: Some(80),
protocol: Protocol::Tcp,
packet_size: 64,
src_groups: Arc::new(Vec::new()),
dst_groups: Arc::new(Vec::new()),
}
}
#[test]
fn classify_chain_type_treats_public_ipv6_lease_as_inbound() {
let leased_ipv6 = Ipv6Addr::new(0x2001, 0xdb8, 0x100, 0, 0, 0, 0, 0x123);
let packet_info = packet_info(IpAddr::V6(leased_ipv6));
let chain =
AclFilter::classify_chain_type(true, &packet_info, None, |ip| ip == leased_ipv6);
assert_eq!(chain, ChainType::Inbound);
}
#[test]
fn classify_chain_type_keeps_non_local_ipv6_as_forward() {
let leased_ipv6 = Ipv6Addr::new(0x2001, 0xdb8, 0x100, 0, 0, 0, 0, 0x123);
let packet_info = packet_info(IpAddr::V6(Ipv6Addr::new(
0x2001, 0xdb8, 0xffff, 2, 0, 0, 0, 0x100,
)));
let chain =
AclFilter::classify_chain_type(true, &packet_info, None, |ip| ip == leased_ipv6);
assert_eq!(chain, ChainType::Forward);
}
#[tokio::test]
async fn reload_rules_clears_outbound_allow_records() {
let filter = AclFilter::new();
filter.outbound_allow_records.insert(
OutboundAllowRecord {
src_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
dst_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
src_port: Some(1234),
dst_port: Some(80),
protocol: Protocol::Tcp,
},
Instant::now(),
);
assert_eq!(filter.outbound_allow_records.len(), 1);
filter.reload_rules(Some(&Acl::default()));
assert_eq!(filter.outbound_allow_records.len(), 0);
filter.outbound_allow_records.insert(
OutboundAllowRecord {
src_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
dst_ip: IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
src_port: Some(4321),
dst_port: Some(443),
protocol: Protocol::Tcp,
},
Instant::now(),
);
assert_eq!(filter.outbound_allow_records.len(), 1);
filter.reload_rules(None);
assert_eq!(filter.outbound_allow_records.len(), 0);
}
}
+301 -20
View File
@@ -56,7 +56,7 @@ use super::{
route_trait::NextHopPolicy,
traffic_metrics::{
InstanceLabelKind, LogicalTrafficMetrics, TrafficKind, TrafficMetricRecorder,
route_peer_info_instance_id, traffic_kind,
is_relay_data_packet_type, route_peer_info_instance_id, traffic_kind,
},
};
@@ -69,11 +69,16 @@ pub trait GlobalForeignNetworkAccessor: Send + Sync + 'static {
struct ForeignNetworkEntry {
my_peer_id: PeerId,
// Node-global runtime flags, such as disable_relay_data, live on the parent
// context. The foreign context is scoped to the foreign network's OSPF view.
parent_global_ctx: ArcGlobalCtx,
global_ctx: ArcGlobalCtx,
network: NetworkIdentity,
peer_map: Arc<PeerMap>,
relay_peer_map: Arc<RelayPeerMap>,
peer_session_store: Arc<PeerSessionStore>,
// Static per-network permission from the whitelist check. disable_relay_data
// is the node-wide runtime override layered on top of this value.
relay_data: bool,
pm_packet_sender: Mutex<Option<PacketRecvChan>>,
@@ -82,7 +87,7 @@ struct ForeignNetworkEntry {
packet_recv: Mutex<Option<PacketRecvChanReceiver>>,
bps_limiter: Arc<TokenBucket>,
bps_limiter: Option<Arc<TokenBucket>>,
peer_center: Arc<PeerCenterInstance>,
@@ -186,14 +191,16 @@ impl ForeignNetworkEntry {
);
let relay_bps_limit = global_ctx.config.get_flags().foreign_relay_bps_limit;
let limiter_config = LimiterConfig {
burst_rate: None,
bps: Some(relay_bps_limit),
fill_duration_ms: None,
};
let bps_limiter = global_ctx
.token_bucket_manager()
.get_or_create(&network.network_name, limiter_config.into());
let bps_limiter = (relay_bps_limit != u64::MAX).then(|| {
let limiter_config = LimiterConfig {
burst_rate: None,
bps: Some(relay_bps_limit),
fill_duration_ms: None,
};
global_ctx
.token_bucket_manager()
.get_or_create(&network.network_name, limiter_config.into())
});
let peer_center = Arc::new(PeerCenterInstance::new(Arc::new(
PeerMapWithPeerRpcManager {
@@ -205,6 +212,7 @@ impl ForeignNetworkEntry {
Self {
my_peer_id,
parent_global_ctx: global_ctx.clone(),
global_ctx: foreign_global_ctx,
network,
peer_map,
@@ -231,6 +239,27 @@ impl ForeignNetworkEntry {
}
}
fn desired_avoid_relay_data_feature_flag(
parent_global_ctx: &ArcGlobalCtx,
relay_data: bool,
) -> bool {
!relay_data || parent_global_ctx.get_feature_flags().avoid_relay_data
}
fn sync_parent_relay_data_feature_flag(
parent_global_ctx: &ArcGlobalCtx,
global_ctx: &ArcGlobalCtx,
relay_data: bool,
) -> bool {
let avoid_relay_data =
Self::desired_avoid_relay_data_feature_flag(parent_global_ctx, relay_data);
if global_ctx.get_feature_flags().avoid_relay_data == avoid_relay_data {
return false;
}
global_ctx.set_avoid_relay_data_preference(avoid_relay_data)
}
fn build_foreign_global_ctx(
network: &NetworkIdentity,
global_ctx: ArcGlobalCtx,
@@ -258,10 +287,9 @@ impl ForeignNetworkEntry {
let mut feature_flag = global_ctx.get_feature_flags();
feature_flag.is_public_server = true;
if !relay_data {
feature_flag.avoid_relay_data = true;
}
foreign_global_ctx.set_feature_flags(feature_flag);
feature_flag.avoid_relay_data =
Self::desired_avoid_relay_data_feature_flag(&global_ctx, relay_data);
foreign_global_ctx.set_base_advertised_feature_flags(feature_flag);
for u in global_ctx.get_running_listeners().into_iter() {
foreign_global_ctx.add_running_listener(u);
@@ -412,6 +440,7 @@ impl ForeignNetworkEntry {
let peer_map = self.peer_map.clone();
let relay_peer_map = self.relay_peer_map.clone();
let traffic_metrics = self.traffic_metrics.clone();
let parent_global_ctx = self.parent_global_ctx.clone();
let relay_data = self.relay_data;
let pm_sender = self.pm_packet_sender.lock().await.take().unwrap();
let network_name = self.network.network_name.clone();
@@ -497,14 +526,21 @@ impl ForeignNetworkEntry {
"ignore packet in foreign network"
);
} else {
if packet_type == PacketType::Data as u8
|| packet_type == PacketType::KcpSrc as u8
|| packet_type == PacketType::KcpDst as u8
{
if !relay_data {
if is_relay_data_packet_type(packet_type) {
let disable_relay_data = parent_global_ctx.flags_arc().disable_relay_data;
if !relay_data || disable_relay_data {
tracing::debug!(
?from_peer_id,
?to_peer_id,
packet_type,
disable_relay_data,
"drop foreign network relay data"
);
continue;
}
if !bps_limiter.try_consume(len.into()) {
if let Some(bps_limiter) = bps_limiter.as_ref()
&& !bps_limiter.try_consume(len.into())
{
continue;
}
}
@@ -589,10 +625,31 @@ impl ForeignNetworkEntry {
});
}
async fn run_parent_feature_flag_sync_routine(&self) {
let parent_global_ctx = self.parent_global_ctx.clone();
let global_ctx = self.global_ctx.clone();
let relay_data = self.relay_data;
self.tasks.lock().await.spawn(async move {
let mut parent_events = parent_global_ctx.subscribe();
loop {
ForeignNetworkEntry::sync_parent_relay_data_feature_flag(
&parent_global_ctx,
&global_ctx,
relay_data,
);
if parent_events.recv().await.is_err() {
parent_events = parent_global_ctx.subscribe();
}
}
});
}
async fn prepare(&self, accessor: Box<dyn GlobalForeignNetworkAccessor>) {
self.prepare_route(accessor).await;
self.start_packet_recv().await;
self.run_relay_session_gc_routine().await;
self.run_parent_feature_flag_sync_routine().await;
self.peer_rpc.run();
self.peer_center.init().await;
}
@@ -660,6 +717,7 @@ impl ForeignNetworkManagerData {
fn remove_network(&self, network_name: &String) {
let _l = self.lock.lock().unwrap();
if let Some(old) = self.network_peer_maps.remove(network_name) {
old.1.traffic_metrics.clear_peer_cache();
let to_remove_peers = old.1.peer_map.list_peers();
for p in to_remove_peers {
self.peer_network_map.remove_if(&p, |_, v| {
@@ -669,6 +727,9 @@ impl ForeignNetworkManagerData {
}
}
self.network_peer_last_update.remove(network_name);
shrink_dashmap(&self.peer_network_map, None);
shrink_dashmap(&self.network_peer_maps, None);
shrink_dashmap(&self.network_peer_last_update, None);
}
#[allow(clippy::too_many_arguments)]
@@ -941,12 +1002,14 @@ impl ForeignNetworkManager {
async fn start_event_handler(&self, entry: &ForeignNetworkEntry) {
let data = self.data.clone();
let network_name = entry.network.network_name.clone();
let traffic_metrics = entry.traffic_metrics.clone();
let mut s = entry.global_ctx.subscribe();
self.tasks.lock().unwrap().spawn(async move {
while let Ok(e) = s.recv().await {
match &e {
GlobalCtxEvent::PeerRemoved(peer_id) => {
tracing::info!(?e, "remove peer from foreign network manager");
traffic_metrics.remove_peer(*peer_id);
data.remove_peer(*peer_id, &network_name);
data.network_peer_last_update
.insert(network_name.clone(), SystemTime::now());
@@ -965,6 +1028,7 @@ impl ForeignNetworkManager {
}
// if lagged or recv done just remove the network
tracing::error!("global event handler at foreign network manager exit");
traffic_metrics.clear_peer_cache();
data.remove_network(&network_name);
});
}
@@ -1397,6 +1461,92 @@ pub mod tests {
);
}
#[tokio::test]
async fn disable_relay_data_blocks_foreign_network_transit_data() {
let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
let pmb_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
connect_peer_manager(pma_net1.clone(), pm_center.clone()).await;
connect_peer_manager(pmb_net1.clone(), pm_center.clone()).await;
wait_route_appear(pma_net1.clone(), pmb_net1.clone())
.await
.unwrap();
let mut flags = pm_center.get_global_ctx().get_flags();
flags.disable_relay_data = true;
pm_center.get_global_ctx().set_flags(flags);
pm_center
.get_global_ctx()
.issue_event(GlobalCtxEvent::ConfigPatched(Default::default()));
let center_peer_id = pm_center
.get_foreign_network_manager()
.get_network_peer_id("net1")
.unwrap();
wait_for_condition(
|| {
let pma_net1 = pma_net1.clone();
async move {
pma_net1.list_routes().await.iter().any(|route| {
route.peer_id == center_peer_id
&& route
.feature_flag
.as_ref()
.map(|flag| flag.avoid_relay_data)
.unwrap_or(false)
})
}
},
Duration::from_secs(5),
)
.await;
let network_labels =
LabelSet::new().with_label_type(LabelType::NetworkName("net1".to_string()));
let forwarded_bytes_before = metric_value(
&pm_center,
MetricName::TrafficBytesForwarded,
network_labels.clone(),
);
let forwarded_packets_before = metric_value(
&pm_center,
MetricName::TrafficPacketsForwarded,
network_labels.clone(),
);
let mut transit_pkt = ZCPacket::new_with_payload(b"foreign-transit-disabled");
transit_pkt.fill_peer_manager_hdr(
pma_net1.my_peer_id(),
pmb_net1.my_peer_id(),
PacketType::Data as u8,
);
pma_net1
.get_foreign_network_client()
.send_msg(transit_pkt, center_peer_id)
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(300)).await;
assert_eq!(
metric_value(
&pm_center,
MetricName::TrafficBytesForwarded,
network_labels.clone()
),
forwarded_bytes_before
);
assert_eq!(
metric_value(
&pm_center,
MetricName::TrafficPacketsForwarded,
network_labels
),
forwarded_packets_before
);
}
#[tokio::test]
async fn foreign_network_transit_control_forwarding_records_control_forwarded_metrics() {
let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
@@ -1409,6 +1559,10 @@ pub mod tests {
.await
.unwrap();
let mut flags = pm_center.get_global_ctx().get_flags();
flags.disable_relay_data = true;
pm_center.get_global_ctx().set_flags(flags);
let center_peer_id = pm_center
.get_foreign_network_manager()
.get_network_peer_id("net1")
@@ -1461,6 +1615,58 @@ pub mod tests {
.await;
}
#[tokio::test]
async fn foreign_network_peer_removed_clears_traffic_metric_peer_cache() {
let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
connect_peer_manager(pma_net1.clone(), pm_center.clone()).await;
wait_for_condition(
|| {
let pm_center = pm_center.clone();
async move {
pm_center
.get_foreign_network_manager()
.get_network_peer_id("net1")
.is_some()
}
},
Duration::from_secs(5),
)
.await;
let entry = pm_center
.get_foreign_network_manager()
.data
.get_network_entry("net1")
.unwrap();
entry
.traffic_metrics
.record_rx(pma_net1.my_peer_id(), PacketType::Data as u8, 128)
.await;
assert!(
entry
.traffic_metrics
.contains_peer_cache(pma_net1.my_peer_id())
);
entry
.global_ctx
.issue_event(GlobalCtxEvent::PeerRemoved(pma_net1.my_peer_id()));
wait_for_condition(
|| {
let entry = entry.clone();
let peer_id = pma_net1.my_peer_id();
async move { !entry.traffic_metrics.contains_peer_cache(peer_id) }
},
Duration::from_secs(5),
)
.await;
}
#[tokio::test]
async fn foreign_network_encapsulated_forwarding_records_tx_metrics() {
set_global_var!(OSPF_UPDATE_MY_GLOBAL_FOREIGN_NETWORK_INTERVAL_SEC, 1);
@@ -1657,6 +1863,81 @@ pub mod tests {
));
}
#[tokio::test]
async fn foreign_entry_feature_flag_tracks_parent_disable_relay_data_toggle() {
let global_ctx = get_mock_global_ctx_with_network(Some(NetworkIdentity::new(
"__access__".to_string(),
"access_secret".to_string(),
)));
let foreign_network = NetworkIdentity::new("net1".to_string(), "net1_secret".to_string());
let (pm_packet_sender, _pm_packet_recv) = create_packet_recv_chan();
let entry = ForeignNetworkEntry::new(
foreign_network,
1,
global_ctx.clone(),
true,
Arc::new(PeerSessionStore::new()),
pm_packet_sender,
);
assert!(!entry.global_ctx.get_feature_flags().avoid_relay_data);
entry.run_parent_feature_flag_sync_routine().await;
let mut flags = global_ctx.get_flags();
flags.disable_relay_data = true;
global_ctx.set_flags(flags);
global_ctx.issue_event(GlobalCtxEvent::ConfigPatched(Default::default()));
wait_for_condition(
|| async { entry.global_ctx.get_feature_flags().avoid_relay_data },
Duration::from_secs(2),
)
.await;
let mut flags = global_ctx.get_flags();
flags.disable_relay_data = false;
global_ctx.set_flags(flags);
global_ctx.issue_event(GlobalCtxEvent::ConfigPatched(Default::default()));
wait_for_condition(
|| async { !entry.global_ctx.get_feature_flags().avoid_relay_data },
Duration::from_secs(2),
)
.await;
}
#[tokio::test]
async fn foreign_entry_without_relay_data_keeps_avoid_feature_flag() {
let global_ctx = get_mock_global_ctx_with_network(Some(NetworkIdentity::new(
"__access__".to_string(),
"access_secret".to_string(),
)));
let foreign_network = NetworkIdentity::new("net1".to_string(), "net1_secret".to_string());
let (pm_packet_sender, _pm_packet_recv) = create_packet_recv_chan();
let entry = ForeignNetworkEntry::new(
foreign_network,
1,
global_ctx.clone(),
false,
Arc::new(PeerSessionStore::new()),
pm_packet_sender,
);
assert!(entry.global_ctx.get_feature_flags().avoid_relay_data);
let mut flags = global_ctx.get_flags();
flags.disable_relay_data = false;
global_ctx.set_flags(flags);
ForeignNetworkEntry::sync_parent_relay_data_feature_flag(
&global_ctx,
&entry.global_ctx,
entry.relay_data,
);
assert!(entry.global_ctx.get_feature_flags().avoid_relay_data);
}
#[test]
fn credential_trust_path_rejects_admin_identity() {
assert!(ForeignNetworkManager::should_reject_credential_trust_path(
+1
View File
@@ -11,6 +11,7 @@ pub mod peer_ospf_route;
pub mod peer_rpc;
pub mod peer_rpc_service;
pub mod peer_session;
pub(crate) mod public_ipv6;
pub mod relay_peer_map;
pub mod route_trait;
pub mod rpc_service;
+4 -2
View File
@@ -12,6 +12,7 @@ use std::{
use base64::Engine as _;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use guarden::guard;
use hmac::Mac;
use prost::Message;
@@ -40,7 +41,6 @@ use crate::{
error::Error,
global_ctx::ArcGlobalCtx,
},
guard,
peers::peer_session::{PeerSessionStore, SessionKey, UpsertResponderSessionReturn},
proto::{
api::instance::{PeerConnInfo, PeerConnStats},
@@ -1352,7 +1352,9 @@ impl PeerConn {
let is_foreign_network = conn_info_for_instrument.network_name
!= self.global_ctx.get_network_identity().network_name;
let recv_limiter = if is_foreign_network {
let recv_limiter = if is_foreign_network
&& self.global_ctx.get_flags().foreign_relay_bps_limit != u64::MAX
{
let relay_network_bps_limit = self.global_ctx.get_flags().foreign_relay_bps_limit;
let limiter_config = LimiterConfig {
burst_rate: None,
+190 -22
View File
@@ -38,7 +38,7 @@ use crate::{
route_trait::{ForeignNetworkRouteInfoMap, MockRoute, NextHopPolicy, RouteInterface},
traffic_metrics::{
InstanceLabelKind, LogicalTrafficMetrics, TrafficKind, TrafficMetricRecorder,
route_peer_info_instance_id, traffic_kind,
is_relay_data_packet_type, route_peer_info_instance_id, traffic_kind,
},
},
proto::{
@@ -263,9 +263,7 @@ impl PeerManager {
.is_err()
{
// if local network is not in whitelist, avoid relay data when exist any other route path
let mut f = global_ctx.get_feature_flags();
f.avoid_relay_data = true;
global_ctx.set_feature_flags(f);
global_ctx.set_avoid_relay_data_preference(true);
}
let is_secure_mode_enabled = global_ctx
@@ -689,6 +687,11 @@ impl PeerManager {
Ok(())
}
fn release_reserved_peer_id(&self, network_name: &str) {
self.reserved_my_peer_id_map.remove(network_name);
shrink_dashmap(&self.reserved_my_peer_id_map, None);
}
#[tracing::instrument(ret)]
pub async fn add_tunnel_as_server(
&self,
@@ -704,7 +707,8 @@ impl PeerManager {
tunnel,
self.peer_session_store.clone(),
);
conn.do_handshake_as_server_ext(|peer, network_name:&str| {
let mut reserved_peer_id_network_name = None;
let handshake_ret = conn.do_handshake_as_server_ext(|peer, network_name:&str| {
if network_name
== self.global_ctx.get_network_identity().network_name
{
@@ -715,6 +719,7 @@ impl PeerManager {
.foreign_network_manager
.get_network_peer_id(network_name);
if peer_id.is_none() {
reserved_peer_id_network_name = Some(network_name.to_string());
peer_id = Some(*self.reserved_my_peer_id_map.entry(network_name.to_string()).or_insert_with(|| {
rand::random::<PeerId>()
}).value());
@@ -730,7 +735,14 @@ impl PeerManager {
Ok(())
})
.await?;
.await;
if let Err(err) = handshake_ret {
if let Some(network_name) = reserved_peer_id_network_name {
self.release_reserved_peer_id(&network_name);
}
return Err(err);
}
let peer_identity = conn.get_network_identity();
let peer_network_name = peer_identity.network_name.clone();
@@ -749,6 +761,7 @@ impl PeerManager {
if !is_local_network && self.global_ctx.get_flags().private_mode && !foreign_network_allowed
{
self.release_reserved_peer_id(&peer_network_name);
return Err(Error::SecretKeyError(
"private mode is turned on, foreign network secret mismatch".to_string(),
));
@@ -756,14 +769,18 @@ impl PeerManager {
conn.set_is_hole_punched(!is_directly_connected);
if is_local_network {
self.add_new_peer_conn(conn).await?;
let add_peer_ret = if is_local_network {
self.add_new_peer_conn(conn).await
} else {
self.foreign_network_manager.add_peer_conn(conn).await?;
self.foreign_network_manager.add_peer_conn(conn).await
};
if let Err(err) = add_peer_ret {
self.release_reserved_peer_id(&peer_network_name);
return Err(err);
}
self.reserved_my_peer_id_map.remove(&peer_network_name);
shrink_dashmap(&self.reserved_my_peer_id_map, None);
self.release_reserved_peer_id(&peer_network_name);
tracing::info!("add tunnel as server done");
Ok(())
@@ -774,6 +791,7 @@ impl PeerManager {
my_peer_id: PeerId,
peer_map: &PeerMap,
foreign_network_mgr: &ForeignNetworkManager,
disable_relay_data: bool,
) -> Result<(), ZCPacket> {
let pm_header = packet.peer_manager_header().unwrap();
if pm_header.packet_type != PacketType::ForeignNetworkPacket as u8 {
@@ -783,6 +801,16 @@ impl PeerManager {
let from_peer_id = pm_header.from_peer_id.get();
let to_peer_id = pm_header.to_peer_id.get();
if disable_relay_data && Self::is_relay_data_zc_packet(&packet) {
tracing::debug!(
?from_peer_id,
?to_peer_id,
inner_packet_type = ?packet.foreign_network_inner_packet_type(),
"drop foreign network relay data while relay data is disabled"
);
return Ok(());
}
let foreign_hdr = packet.foreign_network_hdr().unwrap();
let foreign_network_name = foreign_hdr.get_network_name(packet.payload());
let foreign_peer_id = foreign_hdr.get_dst_peer_id();
@@ -872,6 +900,29 @@ impl PeerManager {
}
}
fn is_relay_data_packet(packet_type: u8) -> bool {
is_relay_data_packet_type(packet_type)
}
fn is_relay_data_zc_packet(packet: &ZCPacket) -> bool {
let Some(hdr) = packet.peer_manager_header() else {
return false;
};
if hdr.packet_type == PacketType::ForeignNetworkPacket as u8 {
let inner_packet_type = packet.foreign_network_inner_packet_type();
if inner_packet_type.is_none() {
tracing::warn!(
?hdr,
"foreign network packet has unparseable inner peer manager header"
);
}
return inner_packet_type.is_none_or(Self::is_relay_data_packet);
}
Self::is_relay_data_packet(hdr.packet_type)
}
async fn start_peer_recv(&self) {
let mut recv = self.packet_recv.lock().await.take().unwrap();
let my_peer_id = self.my_peer_id;
@@ -925,14 +976,21 @@ impl PeerManager {
self.tasks.lock().await.spawn(async move {
tracing::trace!("start_peer_recv");
while let Ok(ret) = recv_packet_from_chan(&mut recv).await {
let Err(mut ret) =
Self::try_handle_foreign_network_packet(ret, my_peer_id, &peers, &foreign_mgr)
.await
let disable_relay_data = global_ctx.flags_arc().disable_relay_data;
let Err(mut ret) = Self::try_handle_foreign_network_packet(
ret,
my_peer_id,
&peers,
&foreign_mgr,
disable_relay_data,
)
.await
else {
continue;
};
let buf_len = ret.buf_len();
let is_relay_data_packet = Self::is_relay_data_zc_packet(&ret);
let Some(hdr) = ret.mut_peer_manager_header() else {
tracing::warn!(?ret, "invalid packet, skip");
continue;
@@ -944,6 +1002,16 @@ impl PeerManager {
let packet_type = hdr.packet_type;
let is_encrypted = hdr.is_encrypted();
if to_peer_id != my_peer_id {
if disable_relay_data && is_relay_data_packet {
tracing::debug!(
?from_peer_id,
?to_peer_id,
packet_type,
"drop forwarded relay data while relay data is disabled"
);
continue;
}
if hdr.forward_counter > 7 {
tracing::warn!(?hdr, "forward counter exceed, drop packet");
continue;
@@ -1062,7 +1130,7 @@ impl PeerManager {
&ret,
true,
global_ctx.get_ipv4().map(|x| x.address()),
global_ctx.get_ipv6().map(|x| x.address()),
|dst| global_ctx.is_ip_local_ipv6(&dst),
&route,
) {
continue;
@@ -1291,6 +1359,18 @@ impl PeerManager {
self.get_route().list_proxy_cidrs_v6().await
}
pub async fn list_public_ipv6_routes(&self) -> BTreeSet<cidr::Ipv6Inet> {
self.get_route().list_public_ipv6_routes().await
}
pub async fn get_my_public_ipv6_addr(&self) -> Option<cidr::Ipv6Inet> {
self.get_route().get_my_public_ipv6_addr().await
}
pub async fn get_local_public_ipv6_info(&self) -> instance::ListPublicIpv6InfoResponse {
self.get_route().get_local_public_ipv6_info().await
}
pub async fn dump_route(&self) -> String {
self.get_route().dump().await
}
@@ -1330,7 +1410,7 @@ impl PeerManager {
data,
false,
None,
None,
|_| false,
&self.get_route(),
) {
return false;
@@ -1532,6 +1612,10 @@ impl PeerManager {
dst_peers.extend(self.peers.list_routes().await.iter().map(|x| *x.key()));
} else if let Some(peer_id) = self.peers.get_peer_id_by_ipv6(ipv6_addr).await {
dst_peers.push(peer_id);
} else if !ipv6_addr.is_unicast_link_local()
&& let Some(peer_id) = self.get_route().get_public_ipv6_gateway_peer_id().await
{
dst_peers.push(peer_id);
} else if !ipv6_addr.is_unicast_link_local() {
// NOTE: never route link local address to exit node.
for exit_node in self.exit_nodes.read().await.iter() {
@@ -1662,7 +1746,7 @@ impl PeerManager {
&& !self.global_ctx.is_ip_local_virtual_ip(&ip_addr)
{
// Keep the loop-prevention flags for proxy-induced self-delivery where
// the destination is not this node's own virtual IP.
// the destination is not this node's own EasyTier-managed IP.
hdr.set_not_send_to_tun(true);
hdr.set_no_proxy(true);
}
@@ -1879,6 +1963,15 @@ impl PeerManager {
version: EASYTIER_VERSION.to_string(),
feature_flag: Some(self.global_ctx.get_feature_flags()),
ip_list: Some(self.global_ctx.get_ip_collector().collect_ip_addrs().await),
public_ipv6_addr: self.get_my_public_ipv6_addr().await.map(Into::into),
ipv6_public_addr_prefix: self
.global_ctx
.get_advertised_ipv6_public_addr_prefix()
.map(|prefix| {
cidr::Ipv6Inet::new(prefix.first_address(), prefix.network_length())
.unwrap()
.into()
}),
}
}
@@ -2055,7 +2148,7 @@ mod tests {
},
},
proto::{
common::{CompressionAlgoPb, NatType, PeerFeatureFlag},
common::{CompressionAlgoPb, NatType},
peer_rpc::SecureAuthLevel,
},
tunnel::{
@@ -2199,6 +2292,84 @@ mod tests {
assert_eq!(signal.version(), initial_version + 2);
}
#[test]
fn disable_relay_data_classifies_data_plane_packets_only() {
for packet_type in [
PacketType::Data,
PacketType::KcpSrc,
PacketType::KcpDst,
PacketType::QuicSrc,
PacketType::QuicDst,
PacketType::DataWithKcpSrcModified,
PacketType::DataWithQuicSrcModified,
PacketType::ForeignNetworkPacket,
] {
assert!(PeerManager::is_relay_data_packet(packet_type as u8));
}
for packet_type in [
PacketType::RpcReq,
PacketType::RpcResp,
PacketType::Ping,
PacketType::Pong,
PacketType::HandShake,
PacketType::NoiseHandshakeMsg1,
PacketType::NoiseHandshakeMsg2,
PacketType::NoiseHandshakeMsg3,
PacketType::RelayHandshake,
PacketType::RelayHandshakeAck,
] {
assert!(!PeerManager::is_relay_data_packet(packet_type as u8));
}
}
#[test]
fn disable_relay_data_inspects_foreign_network_inner_packet_type() {
let network_name = "net1".to_string();
let mut rpc_packet = ZCPacket::new_with_payload(b"rpc");
rpc_packet.fill_peer_manager_hdr(1, 2, PacketType::RpcReq as u8);
let mut foreign_rpc_packet =
ZCPacket::new_for_foreign_network(&network_name, 2, &rpc_packet);
foreign_rpc_packet.fill_peer_manager_hdr(10, 20, PacketType::ForeignNetworkPacket as u8);
assert_eq!(
foreign_rpc_packet.foreign_network_inner_packet_type(),
Some(PacketType::RpcReq as u8)
);
assert!(!PeerManager::is_relay_data_zc_packet(&foreign_rpc_packet));
let mut data_packet = ZCPacket::new_with_payload(b"data");
data_packet.fill_peer_manager_hdr(1, 2, PacketType::Data as u8);
let mut foreign_data_packet =
ZCPacket::new_for_foreign_network(&network_name, 2, &data_packet);
foreign_data_packet.fill_peer_manager_hdr(10, 20, PacketType::ForeignNetworkPacket as u8);
assert_eq!(
foreign_data_packet.foreign_network_inner_packet_type(),
Some(PacketType::Data as u8)
);
assert!(PeerManager::is_relay_data_zc_packet(&foreign_data_packet));
}
#[tokio::test]
async fn non_whitelisted_network_avoid_relay_survives_disable_relay_data_toggle() {
let global_ctx = get_mock_global_ctx();
let mut flags = global_ctx.get_flags();
flags.disable_relay_data = true;
flags.relay_network_whitelist = "other-network".to_string();
global_ctx.set_flags(flags);
let (packet_send, _packet_recv) = create_packet_recv_chan();
let _peer_mgr = PeerManager::new(RouteAlgoType::Ospf, global_ctx.clone(), packet_send);
let mut flags = global_ctx.get_flags();
flags.disable_relay_data = false;
global_ctx.set_flags(flags);
assert!(global_ctx.get_feature_flags().avoid_relay_data);
}
#[tokio::test]
async fn send_msg_internal_does_not_record_tx_metrics_on_failed_delivery() {
let peer_mgr = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
@@ -3096,10 +3267,7 @@ mod tests {
// when b's avoid_relay_data is true, a->c should route through d and e, cost is 3
peer_mgr_b
.get_global_ctx()
.set_feature_flags(PeerFeatureFlag {
avoid_relay_data: true,
..Default::default()
});
.set_avoid_relay_data_preference(true);
tokio::time::sleep(Duration::from_secs(2)).await;
if wait_route_appear_with_cost(peer_mgr_a.clone(), peer_mgr_c.my_peer_id, Some(3))
.await
+316 -10
View File
@@ -10,7 +10,7 @@ use std::{
};
use arc_swap::ArcSwap;
use cidr::{IpCidr, Ipv4Cidr, Ipv6Cidr};
use cidr::{IpCidr, Ipv4Cidr, Ipv6Cidr, Ipv6Inet};
use crossbeam::atomic::AtomicCell;
use dashmap::DashMap;
use ordered_hash_map::OrderedHashMap;
@@ -46,9 +46,10 @@ use crate::{
peer_rpc::{
ForeignNetworkRouteInfoEntry, ForeignNetworkRouteInfoKey, OspfRouteRpc,
OspfRouteRpcClientFactory, OspfRouteRpcServer, PeerGroupInfo, PeerIdVersion,
PeerIdentityType, RouteForeignNetworkInfos, RouteForeignNetworkSummary, RoutePeerInfo,
RoutePeerInfos, SyncRouteInfoError, SyncRouteInfoRequest, SyncRouteInfoResponse,
TrustedCredentialPubkey, TrustedCredentialPubkeyProof, route_foreign_network_infos,
PeerIdentityType, PublicIpv6AddrRpcServer, RouteForeignNetworkInfos,
RouteForeignNetworkSummary, RoutePeerInfo, RoutePeerInfos, SyncRouteInfoError,
SyncRouteInfoRequest, SyncRouteInfoResponse, TrustedCredentialPubkey,
TrustedCredentialPubkeyProof, route_foreign_network_infos,
route_foreign_network_summary, sync_route_info_request::ConnInfo,
},
rpc_types::{
@@ -63,6 +64,9 @@ use super::{
PeerPacketFilter,
graph_algo::dijkstra_with_first_hop,
peer_rpc::PeerRpcManager,
public_ipv6::{
PublicIpv6PeerRouteInfo, PublicIpv6RouteControl, PublicIpv6Service, PublicIpv6SyncTrigger,
},
route_trait::{
DefaultRouteCostCalculator, ForeignNetworkRouteInfoMap, NextHopPolicy, RouteCostCalculator,
RouteCostCalculatorInterface,
@@ -137,6 +141,10 @@ fn raw_credential_bytes_from_route_info(
.map(|credential| credential.encode_to_vec())
}
fn route_peer_inst_id(info: &RoutePeerInfo) -> Option<uuid::Uuid> {
info.inst_id.map(Into::into)
}
#[derive(Debug, Clone)]
struct AtomicVersion(Arc<AtomicU32>);
@@ -205,6 +213,8 @@ impl RoutePeerInfo {
quic_port: None,
noise_static_pubkey: Vec::new(),
trusted_credential_pubkeys: Vec::new(),
ipv6_public_addr_prefix: None,
ipv6_public_addr_lease: None,
}
}
@@ -221,6 +231,7 @@ impl RoutePeerInfo {
my_peer_id: PeerId,
peer_route_id: u64,
global_ctx: &ArcGlobalCtx,
public_ipv6_addr_lease: Option<Ipv6Inet>,
) -> Self {
let stun_info = global_ctx.get_stun_info_collector().get_stun_info();
let noise_static_pubkey = global_ctx
@@ -259,6 +270,14 @@ impl RoutePeerInfo {
.unwrap_or(24),
ipv6_addr: global_ctx.get_ipv6().map(|x| x.into()),
ipv6_public_addr_prefix: global_ctx.get_advertised_ipv6_public_addr_prefix().map(
|prefix| {
Ipv6Inet::new(prefix.first_address(), prefix.network_length())
.unwrap()
.into()
},
),
ipv6_public_addr_lease: public_ipv6_addr_lease.map(Into::into),
groups: global_ctx.get_acl_groups(my_peer_id),
@@ -349,6 +368,8 @@ impl From<RoutePeerInfo> for crate::proto::api::instance::Route {
path_latency_latency_first: None,
ipv6_addr: val.ipv6_addr,
public_ipv6_addr: val.ipv6_public_addr_lease,
ipv6_public_addr_prefix: val.ipv6_public_addr_prefix,
}
}
}
@@ -964,8 +985,14 @@ impl SyncedRouteInfo {
my_peer_id: PeerId,
my_peer_route_id: u64,
global_ctx: &ArcGlobalCtx,
public_ipv6_addr_lease: Option<Ipv6Inet>,
) -> bool {
let mut new = RoutePeerInfo::new_updated_self(my_peer_id, my_peer_route_id, global_ctx);
let mut new = RoutePeerInfo::new_updated_self(
my_peer_id,
my_peer_route_id,
global_ctx,
public_ipv6_addr_lease,
);
let mut guard = self.peer_infos.upgradable_read();
let old = guard.get(&my_peer_id);
let new_version = old.map(|x| x.version).unwrap_or(0) + 1;
@@ -1201,6 +1228,25 @@ impl SyncedRouteInfo {
Vec<PeerId>,
HashMap<Vec<u8>, crate::common::global_ctx::TrustedKeyMetadata>,
)
where
F: FnMut(PeerId) -> bool,
{
self.verify_and_update_credential_trusts_with_active_peers_protecting(
network_secret,
is_peer_active,
None,
)
}
fn verify_and_update_credential_trusts_with_active_peers_protecting<F>(
&self,
network_secret: Option<&str>,
is_peer_active: F,
protected_peer_id: Option<PeerId>,
) -> (
Vec<PeerId>,
HashMap<Vec<u8>, crate::common::global_ctx::TrustedKeyMetadata>,
)
where
F: FnMut(PeerId) -> bool,
{
@@ -1221,6 +1267,9 @@ impl SyncedRouteInfo {
let mut untrusted_peers =
Self::collect_revoked_credential_peers(&peer_infos, &prev_trusted, &all_trusted);
untrusted_peers.extend(duplicate_untrusted_peers);
if let Some(protected_peer_id) = protected_peer_id {
untrusted_peers.remove(&protected_peer_id);
}
// Remove untrusted peers from peer_infos so they won't appear in route graph
if !untrusted_peers.is_empty() {
@@ -1588,6 +1637,21 @@ impl RouteTable {
.or_insert(peer_id_and_version);
}
if let Some(ipv6_addr) = info
.ipv6_public_addr_lease
.as_ref()
.and_then(|addr| addr.address)
{
self.ipv6_peer_id_map
.entry(ipv6_addr.into())
.and_modify(|v| {
if is_new_peer_better(v) {
*v = peer_id_and_version;
}
})
.or_insert(peer_id_and_version);
}
for cidr in info.proxy_cidrs.iter() {
let Ok(cidr) = cidr.parse::<IpCidr>() else {
tracing::warn!("invalid proxy cidr: {:?}, from peer: {:?}", cidr, peer_id);
@@ -2019,6 +2083,8 @@ struct PeerRouteServiceImpl {
foreign_network_owner_map: DashMap<NetworkIdentity, Vec<PeerId>>,
foreign_network_my_peer_id_map: DashMap<(String, PeerId), PeerId>,
synced_route_info: SyncedRouteInfo,
public_ipv6_service: std::sync::Mutex<Weak<PublicIpv6Service>>,
self_public_ipv6_addr_lease: std::sync::Mutex<Option<Ipv6Inet>>,
cached_local_conn_map: std::sync::Mutex<RouteConnBitmap>,
cached_local_conn_map_version: AtomicVersion,
cached_interface_peer_snapshot: std::sync::Mutex<Arc<InterfacePeerSnapshot>>,
@@ -2081,6 +2147,8 @@ impl PeerRouteServiceImpl {
non_reusable_credential_owners: DashMap::new(),
version: AtomicVersion::new(),
},
public_ipv6_service: std::sync::Mutex::new(Weak::new()),
self_public_ipv6_addr_lease: std::sync::Mutex::new(None),
cached_local_conn_map: std::sync::Mutex::new(RouteConnBitmap::default()),
cached_local_conn_map_version: AtomicVersion::new(),
cached_interface_peer_snapshot: std::sync::Mutex::new(Arc::new(
@@ -2119,6 +2187,20 @@ impl PeerRouteServiceImpl {
.unwrap_or(false)
}
fn set_public_ipv6_service(&self, service: Weak<PublicIpv6Service>) {
*self.public_ipv6_service.lock().unwrap() = service;
}
fn public_ipv6_service(&self) -> Option<Arc<PublicIpv6Service>> {
self.public_ipv6_service.lock().unwrap().upgrade()
}
fn notify_public_ipv6_route_change(&self) -> bool {
self.public_ipv6_service()
.map(|service| service.handle_route_change())
.unwrap_or(false)
}
fn get_or_create_session(&self, dst_peer_id: PeerId) -> Arc<SyncRouteSession> {
self.sessions
.entry(dst_peer_id)
@@ -2230,6 +2312,7 @@ impl PeerRouteServiceImpl {
self.my_peer_id,
self.my_peer_route_id,
&self.global_ctx,
*self.self_public_ipv6_addr_lease.lock().unwrap(),
)
}
@@ -2618,14 +2701,19 @@ impl PeerRouteServiceImpl {
untrusted_changed = self.refresh_credential_trusts_and_disconnect().await;
}
let mut public_ipv6_state_updated = false;
if my_peer_info_updated || my_conn_info_updated || untrusted_changed {
self.update_route_table_and_cached_local_conn_bitmap();
self.update_foreign_network_owner_map();
public_ipv6_state_updated = self.notify_public_ipv6_route_change();
}
if my_peer_info_updated {
self.update_peer_info_last_update();
}
my_peer_info_updated || my_conn_info_updated || my_foreign_network_updated
my_peer_info_updated
|| my_conn_info_updated
|| my_foreign_network_updated
|| public_ipv6_state_updated
}
async fn refresh_acl_groups(&self) -> bool {
@@ -2652,22 +2740,28 @@ impl PeerRouteServiceImpl {
let untrusted = self.refresh_credential_trusts_with_current_topology();
self.disconnect_untrusted_peers(&untrusted).await;
let mut public_ipv6_state_updated = false;
if my_peer_info_updated || !untrusted.is_empty() {
self.update_route_table_and_cached_local_conn_bitmap();
self.update_foreign_network_owner_map();
public_ipv6_state_updated = self.notify_public_ipv6_route_change();
}
if my_peer_info_updated {
self.update_peer_info_last_update();
}
my_peer_info_updated || !untrusted.is_empty()
my_peer_info_updated || !untrusted.is_empty() || public_ipv6_state_updated
}
fn refresh_credential_trusts(&self) -> Vec<PeerId> {
let network_identity = self.global_ctx.get_network_identity();
let (untrusted, global_trusted_keys) = self
.synced_route_info
.verify_and_update_credential_trusts(network_identity.network_secret.as_deref());
.verify_and_update_credential_trusts_with_active_peers_protecting(
network_identity.network_secret.as_deref(),
|_| true,
Some(self.my_peer_id),
);
self.global_ctx
.update_trusted_keys(global_trusted_keys, &network_identity.network_name);
@@ -2683,9 +2777,10 @@ impl PeerRouteServiceImpl {
let (untrusted, global_trusted_keys) = self
.synced_route_info
.verify_and_update_credential_trusts_with_active_peers(
.verify_and_update_credential_trusts_with_active_peers_protecting(
network_identity.network_secret.as_deref(),
|peer_id| self.is_active_non_reusable_credential_peer(peer_id),
Some(self.my_peer_id),
);
self.global_ctx
.update_trusted_keys(global_trusted_keys, &network_identity.network_name);
@@ -2968,7 +3063,6 @@ impl PeerRouteServiceImpl {
session
.update_dst_saved_foreign_network_version(foreign_network, dst_peer_id);
}
session.update_last_sync_succ_timestamp(next_last_sync_succ_timestamp);
}
}
@@ -3493,7 +3587,13 @@ impl RouteSessionManager {
}
if need_update_route_table || foreign_network_changed {
service_impl.update_route_table_and_cached_local_conn_bitmap();
service_impl.update_foreign_network_owner_map();
if need_update_route_table
&& let Some(public_ipv6_service) = service_impl.public_ipv6_service()
{
public_ipv6_service.handle_route_change();
}
}
tracing::debug!(
@@ -3534,12 +3634,86 @@ impl RouteSessionManager {
}
}
struct OspfPublicIpv6RouteHandle {
service_impl: Weak<PeerRouteServiceImpl>,
}
impl PublicIpv6RouteControl for OspfPublicIpv6RouteHandle {
fn my_peer_id(&self) -> PeerId {
self.service_impl
.upgrade()
.map(|service_impl| service_impl.my_peer_id)
.unwrap_or_default()
}
fn peer_route_snapshot(&self) -> Vec<PublicIpv6PeerRouteInfo> {
let Some(service_impl) = self.service_impl.upgrade() else {
return Vec::new();
};
service_impl
.synced_route_info
.peer_infos
.read()
.iter()
.map(|(peer_id, info)| PublicIpv6PeerRouteInfo {
peer_id: *peer_id,
inst_id: route_peer_inst_id(info),
is_provider: info
.feature_flag
.as_ref()
.map(|flags| flags.ipv6_public_addr_provider)
.unwrap_or(false),
prefix: info
.ipv6_public_addr_prefix
.map(Into::into)
.map(|prefix: Ipv6Inet| prefix.network()),
lease: info.ipv6_public_addr_lease.map(Into::into),
reachable: *peer_id == service_impl.my_peer_id
|| service_impl.route_table.peer_reachable(*peer_id),
})
.collect()
}
fn publish_self_public_ipv6_lease(&self, lease: Option<Ipv6Inet>) -> bool {
let Some(service_impl) = self.service_impl.upgrade() else {
return false;
};
let mut current = service_impl.self_public_ipv6_addr_lease.lock().unwrap();
if *current == lease {
return false;
}
*current = lease;
drop(current);
let changed = service_impl.update_my_peer_info();
if changed {
service_impl.update_route_table_and_cached_local_conn_bitmap();
service_impl.update_foreign_network_owner_map();
}
changed
}
}
#[derive(Clone)]
struct OspfPublicIpv6SyncTrigger {
session_mgr: RouteSessionManager,
}
impl PublicIpv6SyncTrigger for OspfPublicIpv6SyncTrigger {
fn sync_now(&self, reason: &str) {
self.session_mgr.sync_now(reason);
}
}
pub struct PeerRoute {
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
peer_rpc: Weak<PeerRpcManager>,
service_impl: Arc<PeerRouteServiceImpl>,
public_ipv6_service: Arc<PublicIpv6Service>,
session_mgr: RouteSessionManager,
tasks: std::sync::Mutex<JoinSet<()>>,
@@ -3563,6 +3737,17 @@ impl PeerRoute {
) -> Arc<Self> {
let service_impl = Arc::new(PeerRouteServiceImpl::new(my_peer_id, global_ctx.clone()));
let session_mgr = RouteSessionManager::new(service_impl.clone(), peer_rpc.clone());
let public_ipv6_service = Arc::new(PublicIpv6Service::new(
global_ctx.clone(),
Arc::downgrade(&peer_rpc),
Arc::new(OspfPublicIpv6RouteHandle {
service_impl: Arc::downgrade(&service_impl),
}),
Arc::new(OspfPublicIpv6SyncTrigger {
session_mgr: session_mgr.clone(),
}),
));
service_impl.set_public_ipv6_service(Arc::downgrade(&public_ipv6_service));
Arc::new(PeerRoute {
my_peer_id,
@@ -3570,6 +3755,7 @@ impl PeerRoute {
peer_rpc: Arc::downgrade(&peer_rpc),
service_impl,
public_ipv6_service,
session_mgr,
tasks: std::sync::Mutex::new(JoinSet::new()),
@@ -3607,6 +3793,9 @@ impl PeerRoute {
tracing::debug!("cost_calculator_need_update");
service_impl.synced_route_info.version.inc();
service_impl.update_route_table();
if let Some(public_ipv6_service) = service_impl.public_ipv6_service() {
public_ipv6_service.handle_route_change();
}
}
select! {
@@ -3631,11 +3820,16 @@ impl PeerRoute {
// make sure my_peer_id is in the peer_infos.
self.service_impl.update_my_infos().await;
self.public_ipv6_service.handle_route_change();
peer_rpc.rpc_server().registry().register(
OspfRouteRpcServer::new(self.session_mgr.clone()),
&self.global_ctx.get_network_name(),
);
peer_rpc.rpc_server().registry().register(
PublicIpv6AddrRpcServer::new(self.public_ipv6_service.rpc_server()),
&self.global_ctx.get_network_name(),
);
self.tasks
.lock()
@@ -3657,6 +3851,16 @@ impl PeerRoute {
.lock()
.unwrap()
.spawn(Self::clear_expired_peer(self.service_impl.clone()));
self.tasks
.lock()
.unwrap()
.spawn(self.public_ipv6_service.clone().provider_gc_routine());
self.tasks
.lock()
.unwrap()
.spawn(self.public_ipv6_service.clone().client_routine());
}
}
@@ -3677,6 +3881,10 @@ impl Drop for PeerRoute {
OspfRouteRpcServer::new(self.session_mgr.clone()),
&self.global_ctx.get_network_name(),
);
peer_rpc.rpc_server().registry().unregister(
PublicIpv6AddrRpcServer::new(self.public_ipv6_service.rpc_server()),
&self.global_ctx.get_network_name(),
);
}
}
@@ -3765,6 +3973,51 @@ impl Route for PeerRoute {
.collect()
}
async fn list_public_ipv6_routes(&self) -> BTreeSet<Ipv6Inet> {
self.public_ipv6_service.list_routes()
}
async fn get_my_public_ipv6_addr(&self) -> Option<Ipv6Inet> {
self.public_ipv6_service.my_addr()
}
async fn get_public_ipv6_gateway_peer_id(&self) -> Option<PeerId> {
self.public_ipv6_service.provider_peer_id_for_client()
}
async fn get_local_public_ipv6_info(
&self,
) -> crate::proto::api::instance::ListPublicIpv6InfoResponse {
let Some((provider, leases)) = self.public_ipv6_service.local_provider_state() else {
return crate::proto::api::instance::ListPublicIpv6InfoResponse::default();
};
crate::proto::api::instance::ListPublicIpv6InfoResponse {
provider_prefix: Some(
Ipv6Inet::new(
provider.prefix.first_address(),
provider.prefix.network_length(),
)
.unwrap()
.into(),
),
provider_leases: leases
.into_iter()
.map(|lease| crate::proto::api::instance::PublicIpv6LeaseInfo {
peer_id: lease.peer_id,
inst_id: lease.inst_id.to_string(),
leased_addr: Some(lease.addr.into()),
valid_until_unix_seconds: lease
.valid_until
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64,
reused: lease.reused,
})
.collect(),
}
}
async fn get_peer_id_by_ipv4(&self, ipv4_addr: &Ipv4Addr) -> Option<PeerId> {
let route_table = &self.service_impl.route_table;
if let Some(p) = route_table.ipv4_peer_id_map.get(ipv4_addr) {
@@ -4821,6 +5074,58 @@ mod tests {
);
}
#[tokio::test]
async fn credential_trust_refresh_does_not_remove_self_peer() {
let my_peer_id = 11;
let remote_peer_id = 12;
let credential_key = vec![8; 32];
let service_impl = PeerRouteServiceImpl::new(my_peer_id, get_mock_global_ctx());
let self_info = make_credential_route_peer_info(my_peer_id, &credential_key);
let remote_info = make_credential_route_peer_info(remote_peer_id, &credential_key);
{
let mut guard = service_impl.synced_route_info.peer_infos.write();
guard.insert(self_info.peer_id, self_info);
guard.insert(remote_info.peer_id, remote_info);
}
service_impl
.synced_route_info
.trusted_credential_pubkeys
.insert(
credential_key.clone(),
TrustedCredentialPubkey {
pubkey: credential_key,
expiry_unix: i64::MAX,
..Default::default()
},
);
let (untrusted_peers, _) = service_impl
.synced_route_info
.verify_and_update_credential_trusts_with_active_peers_protecting(
None,
|_| true,
Some(my_peer_id),
);
assert_eq!(untrusted_peers, vec![remote_peer_id]);
assert!(
service_impl
.synced_route_info
.peer_infos
.read()
.contains_key(&my_peer_id)
);
assert!(
!service_impl
.synced_route_info
.peer_infos
.read()
.contains_key(&remote_peer_id)
);
}
#[tokio::test]
async fn credential_refresh_rebuilds_reachability_before_owner_election() {
const NETWORK_SECRET: &str = "sec1";
@@ -5180,6 +5485,7 @@ mod tests {
service_impl.my_peer_id,
service_impl.my_peer_route_id,
&service_impl.global_ctx,
None,
);
let mut self_info = self_info;
self_info.version = 1;
+55 -5
View File
@@ -12,6 +12,22 @@ use crate::{
tunnel::udp,
};
fn remove_easytier_managed_ipv6s(ret: &mut GetIpListResponse, global_ctx: &ArcGlobalCtx) {
ret.interface_ipv6s.retain(|ip| {
let ip = std::net::Ipv6Addr::from(*ip);
!global_ctx.is_ip_easytier_managed_ipv6(&ip)
});
if ret
.public_ipv6
.as_ref()
.map(|ip| std::net::Ipv6Addr::from(*ip))
.is_some_and(|ip| global_ctx.is_ip_easytier_managed_ipv6(&ip))
{
ret.public_ipv6 = None;
}
}
#[derive(Clone)]
pub struct DirectConnectorManagerRpcServer {
// TODO: this only cache for one src peer, should make it global
@@ -36,11 +52,7 @@ impl DirectConnectorRpc for DirectConnectorManagerRpcServer {
.chain(self.global_ctx.get_running_listeners())
.map(Into::into)
.collect();
// remove et ipv6 from the interface ipv6 list
if let Some(et_ipv6) = self.global_ctx.get_ipv6() {
let et_ipv6: crate::proto::common::Ipv6Addr = et_ipv6.address().into();
ret.interface_ipv6s.retain(|x| *x != et_ipv6);
}
remove_easytier_managed_ipv6s(&mut ret, &self.global_ctx);
tracing::trace!(
"get_ip_list: public_ipv4: {:?}, public_ipv6: {:?}, listeners: {:?}",
ret.public_ipv4,
@@ -84,3 +96,41 @@ impl DirectConnectorManagerRpcServer {
Self { global_ctx }
}
}
#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
use crate::{
common::global_ctx::tests::get_mock_global_ctx,
peers::peer_rpc_service::remove_easytier_managed_ipv6s, proto::peer_rpc::GetIpListResponse,
};
#[tokio::test]
async fn get_ip_list_sanitizer_removes_managed_ipv6_from_all_sources() {
let global_ctx = get_mock_global_ctx();
let virtual_ipv6 = "fd00::1/64".parse().unwrap();
let public_ipv6 = "2001:db8::2/128".parse().unwrap();
let physical_ipv6: std::net::Ipv6Addr = "2001:db8::3".parse().unwrap();
let routed_ipv6: cidr::Ipv6Inet = "2001:db8::4/128".parse().unwrap();
global_ctx.set_ipv6(Some(virtual_ipv6));
global_ctx.set_public_ipv6_lease(Some(public_ipv6));
global_ctx.set_public_ipv6_routes(BTreeSet::from([routed_ipv6]));
let mut ip_list = GetIpListResponse {
public_ipv6: Some(public_ipv6.address().into()),
interface_ipv6s: vec![
virtual_ipv6.address().into(),
public_ipv6.address().into(),
routed_ipv6.address().into(),
physical_ipv6.into(),
],
..Default::default()
};
remove_easytier_managed_ipv6s(&mut ip_list, &global_ctx);
assert_eq!(ip_list.public_ipv6, None);
assert_eq!(ip_list.interface_ipv6s, vec![physical_ipv6.into()]);
}
}
+5 -1
View File
@@ -7,7 +7,10 @@ use anyhow::anyhow;
use dashmap::DashMap;
use super::secure_datagram::{SecureDatagramDirection, SecureDatagramSession};
use crate::{common::PeerId, tunnel::packet_def::ZCPacket};
use crate::{
common::{PeerId, shrink_dashmap},
tunnel::packet_def::ZCPacket,
};
pub struct UpsertResponderSessionReturn {
pub session: Arc<PeerSession>,
@@ -78,6 +81,7 @@ impl PeerSessionStore {
pub fn evict_unused_sessions(&self) {
self.sessions
.retain(|_key, session| Arc::strong_count(session) > 1);
shrink_dashmap(&self.sessions, None);
}
#[tracing::instrument(skip(self))]
File diff suppressed because it is too large Load Diff
+9 -1
View File
@@ -9,7 +9,7 @@ use tokio::time::{Duration, timeout};
use crate::peers::foreign_network_client::ForeignNetworkClient;
use crate::{
common::error::Error,
common::{PeerId, global_ctx::ArcGlobalCtx},
common::{PeerId, global_ctx::ArcGlobalCtx, shrink_dashmap},
peers::peer_map::PeerMap,
peers::peer_session::{PeerSession, PeerSessionAction, PeerSessionStore, SessionKey},
peers::route_trait::NextHopPolicy,
@@ -652,6 +652,10 @@ impl RelayPeerMap {
self.handshake_locks.remove(&peer_id);
self.pending_packets.remove(&peer_id);
}
shrink_dashmap(&self.states, None);
shrink_dashmap(&self.pending_handshakes, None);
shrink_dashmap(&self.handshake_locks, None);
shrink_dashmap(&self.pending_packets, None);
}
pub fn has_state(&self, peer_id: PeerId) -> bool {
@@ -679,6 +683,10 @@ impl RelayPeerMap {
self.pending_handshakes.remove(&peer_id);
self.handshake_locks.remove(&peer_id);
self.pending_packets.remove(&peer_id);
shrink_dashmap(&self.states, None);
shrink_dashmap(&self.pending_handshakes, None);
shrink_dashmap(&self.handshake_locks, None);
shrink_dashmap(&self.pending_packets, None);
tracing::debug!(?peer_id, "RelayPeerMap removed peer relay state");
}
+31 -3
View File
@@ -1,3 +1,4 @@
use cidr::Ipv6Inet;
use cidr::{Ipv4Cidr, Ipv6Cidr};
use dashmap::DashMap;
use std::{
@@ -8,9 +9,12 @@ use std::{
use crate::{
common::{PeerId, global_ctx::NetworkIdentity},
proto::peer_rpc::{
ForeignNetworkRouteInfoEntry, ForeignNetworkRouteInfoKey, PeerIdentityType,
RouteForeignNetworkInfos, RouteForeignNetworkSummary, RoutePeerInfo,
proto::{
api::instance::ListPublicIpv6InfoResponse,
peer_rpc::{
ForeignNetworkRouteInfoEntry, ForeignNetworkRouteInfoKey, PeerIdentityType,
RouteForeignNetworkInfos, RouteForeignNetworkSummary, RoutePeerInfo,
},
},
};
@@ -93,6 +97,22 @@ pub trait Route {
// TODO: rewrite route management, remove this
async fn list_proxy_cidrs_v6(&self) -> BTreeSet<Ipv6Cidr>;
async fn list_public_ipv6_routes(&self) -> BTreeSet<Ipv6Inet> {
BTreeSet::new()
}
async fn get_my_public_ipv6_addr(&self) -> Option<Ipv6Inet> {
None
}
async fn get_public_ipv6_gateway_peer_id(&self) -> Option<PeerId> {
None
}
async fn get_local_public_ipv6_info(&self) -> ListPublicIpv6InfoResponse {
ListPublicIpv6InfoResponse::default()
}
async fn get_peer_id_by_ipv4(&self, _ipv4: &Ipv4Addr) -> Option<PeerId> {
None
}
@@ -194,6 +214,14 @@ impl Route for MockRoute {
unimplemented!()
}
async fn list_public_ipv6_routes(&self) -> BTreeSet<Ipv6Inet> {
unimplemented!()
}
async fn get_my_public_ipv6_addr(&self) -> Option<Ipv6Inet> {
panic!("mock route")
}
async fn get_peer_info(&self, _peer_id: PeerId) -> Option<RoutePeerInfo> {
panic!("mock route")
}
+13 -3
View File
@@ -13,9 +13,9 @@ use crate::{
GetWhitelistRequest, GetWhitelistResponse, ListCredentialsRequest,
ListCredentialsResponse, ListForeignNetworkRequest, ListForeignNetworkResponse,
ListGlobalForeignNetworkRequest, ListGlobalForeignNetworkResponse, ListPeerRequest,
ListPeerResponse, ListRouteRequest, ListRouteResponse, PeerInfo, PeerManageRpc,
RevokeCredentialRequest, RevokeCredentialResponse, ShowNodeInfoRequest,
ShowNodeInfoResponse,
ListPeerResponse, ListPublicIpv6InfoRequest, ListPublicIpv6InfoResponse,
ListRouteRequest, ListRouteResponse, PeerInfo, PeerManageRpc, RevokeCredentialRequest,
RevokeCredentialResponse, ShowNodeInfoRequest, ShowNodeInfoResponse,
},
rpc_types::{self, controller::BaseController},
},
@@ -99,6 +99,16 @@ impl PeerManageRpc for PeerManagerRpcService {
Ok(reply)
}
async fn list_public_ipv6_info(
&self,
_: BaseController,
_request: ListPublicIpv6InfoRequest,
) -> Result<ListPublicIpv6InfoResponse, rpc_types::error::Error> {
Ok(weak_upgrade(&self.peer_manager)?
.get_local_public_ipv6_info()
.await)
}
async fn list_route(
&self,
_: BaseController,
+20
View File
@@ -201,6 +201,11 @@ impl LogicalTrafficMetrics {
self.per_peer.len()
}
#[cfg(test)]
fn contains_peer_cache(&self, peer_id: PeerId) -> bool {
self.per_peer.contains_key(&peer_id)
}
fn build_peer_counters(&self, instance_id: &str) -> TrafficCounters {
let instance_label = match self.label_kind {
InstanceLabelKind::To => LabelType::ToInstanceId(instance_id.to_string()),
@@ -241,6 +246,13 @@ pub(crate) fn traffic_kind(packet_type: u8) -> TrafficKind {
}
}
pub(crate) fn is_relay_data_packet_type(packet_type: u8) -> bool {
// Relay handshakes are control-plane setup; payload data is blocked by its
// original packet type after the session exists.
traffic_kind(packet_type) == TrafficKind::Data
|| packet_type == PacketType::ForeignNetworkPacket as u8
}
#[derive(Clone)]
struct TrafficMetricGroup {
data: Arc<LogicalTrafficMetrics>,
@@ -326,6 +338,14 @@ impl TrafficMetricRecorder {
self.rx_metrics.control.clear_peer_cache();
}
#[cfg(test)]
pub(crate) fn contains_peer_cache(&self, peer_id: PeerId) -> bool {
self.tx_metrics.data.contains_peer_cache(peer_id)
|| self.tx_metrics.control.contains_peer_cache(peer_id)
|| self.rx_metrics.data.contains_peer_cache(peer_id)
|| self.rx_metrics.control.contains_peer_cache(peer_id)
}
fn resolve_instance_id(&self, peer_id: PeerId) -> BoxFuture<'static, Option<String>> {
(self.resolve_instance_id)(peer_id)
}
+4
View File
@@ -24,6 +24,10 @@ message InstanceConfigPatch {
repeated ExitNodePatch exit_nodes = 8;
repeated UrlPatch mapped_listeners = 9;
repeated UrlPatch connectors = 10;
optional bool ipv6_public_addr_provider = 11;
optional bool ipv6_public_addr_auto = 12;
optional string ipv6_public_addr_prefix = 13;
optional bool disable_relay_data = 14;
}
message PortForwardPatch {
+21
View File
@@ -81,6 +81,8 @@ message Route {
optional int32 path_latency_latency_first = 14;
common.Ipv6Inet ipv6_addr = 15;
common.Ipv6Inet public_ipv6_addr = 16;
common.Ipv6Inet ipv6_public_addr_prefix = 17;
}
message PeerRoutePair {
@@ -100,12 +102,29 @@ message NodeInfo {
string version = 9;
common.PeerFeatureFlag feature_flag = 10;
peer_rpc.GetIpListResponse ip_list = 11;
common.Ipv6Inet public_ipv6_addr = 12;
common.Ipv6Inet ipv6_public_addr_prefix = 13;
}
message ShowNodeInfoRequest { InstanceIdentifier instance = 1; }
message ShowNodeInfoResponse { NodeInfo node_info = 1; }
message PublicIpv6LeaseInfo {
uint32 peer_id = 1;
string inst_id = 2;
common.Ipv6Inet leased_addr = 3;
int64 valid_until_unix_seconds = 4;
bool reused = 5;
}
message ListPublicIpv6InfoRequest { InstanceIdentifier instance = 1; }
message ListPublicIpv6InfoResponse {
common.Ipv6Inet provider_prefix = 1;
repeated PublicIpv6LeaseInfo provider_leases = 2;
}
message ListRouteRequest { InstanceIdentifier instance = 1; }
message ListRouteResponse { repeated Route routes = 1; }
@@ -167,6 +186,8 @@ message GetForeignNetworkSummaryResponse {
service PeerManageRpc {
rpc ListPeer(ListPeerRequest) returns (ListPeerResponse);
rpc ListPublicIpv6Info(ListPublicIpv6InfoRequest)
returns (ListPublicIpv6InfoResponse);
rpc ListRoute(ListRouteRequest) returns (ListRouteResponse);
rpc DumpRoute(DumpRouteRequest) returns (DumpRouteResponse);
rpc ListForeignNetwork(ListForeignNetworkRequest)
+4
View File
@@ -96,6 +96,10 @@ message NetworkConfig {
optional bool need_p2p = 59;
optional uint64 instance_recv_bps_limit = 60;
optional bool disable_upnp = 61;
optional bool ipv6_public_addr_provider = 62;
optional bool ipv6_public_addr_auto = 63;
optional string ipv6_public_addr_prefix = 64;
optional bool disable_relay_data = 65;
}
message PortForwardConfig {
+3
View File
@@ -75,6 +75,7 @@ message FlagsInConfig {
bool need_p2p = 38;
uint64 instance_recv_bps_limit = 39;
bool disable_upnp = 40;
bool disable_relay_data = 41;
}
message RpcDescriptor {
@@ -104,6 +105,7 @@ enum CompressionAlgoPb {
Invalid = 0;
None = 1;
Zstd = 2;
Lzo = 3;
}
message RpcCompressionInfo {
@@ -225,6 +227,7 @@ message PeerFeatureFlag {
bool is_credential_peer = 8;
bool need_p2p = 9;
bool disable_p2p = 10;
bool ipv6_public_addr_provider = 11;
}
enum SocketType {
+4
View File
@@ -467,6 +467,8 @@ impl TryFrom<CompressionAlgoPb> for CompressorAlgo {
match value {
#[cfg(feature = "zstd")]
CompressionAlgoPb::Zstd => Ok(CompressorAlgo::ZstdDefault),
#[cfg(feature = "lzo")]
CompressionAlgoPb::Lzo => Ok(CompressorAlgo::Lzo),
CompressionAlgoPb::None => Ok(CompressorAlgo::None),
_ => Err(anyhow::anyhow!("Invalid CompressionAlgoPb")),
}
@@ -480,6 +482,8 @@ impl TryFrom<CompressorAlgo> for CompressionAlgoPb {
match value {
#[cfg(feature = "zstd")]
CompressorAlgo::ZstdDefault => Ok(CompressionAlgoPb::Zstd),
#[cfg(feature = "lzo")]
CompressorAlgo::Lzo => Ok(CompressionAlgoPb::Lzo),
CompressorAlgo::None => Ok(CompressionAlgoPb::None),
}
}
+43
View File
@@ -47,6 +47,9 @@ message RoutePeerInfo {
// Trusted credential public keys published by admin nodes (holding network_secret)
repeated TrustedCredentialPubkeyProof trusted_credential_pubkeys = 19;
optional common.Ipv6Inet ipv6_public_addr_prefix = 22;
optional common.Ipv6Inet ipv6_public_addr_lease = 24;
}
message PeerIdVersion {
@@ -133,6 +136,46 @@ service OspfRouteRpc {
rpc SyncRouteInfo(SyncRouteInfoRequest) returns (SyncRouteInfoResponse);
}
message AcquireIpv6PublicAddrLeaseRequest {
uint32 peer_id = 1;
common.UUID inst_id = 2;
}
message RenewIpv6PublicAddrLeaseRequest {
uint32 peer_id = 1;
common.UUID inst_id = 2;
common.Ipv6Inet leased_addr = 3;
}
message ReleaseIpv6PublicAddrLeaseRequest {
uint32 peer_id = 1;
common.UUID inst_id = 2;
}
message GetIpv6PublicAddrLeaseRequest {
uint32 peer_id = 1;
common.UUID inst_id = 2;
}
message Ipv6PublicAddrLeaseReply {
uint32 provider_peer_id = 1;
common.UUID provider_inst_id = 2;
common.Ipv6Inet provider_prefix = 3;
common.Ipv6Inet leased_addr = 4;
google.protobuf.Timestamp valid_until = 5;
bool reused = 6;
optional string error_msg = 7;
}
service PublicIpv6AddrRpc {
rpc AcquireLease(AcquireIpv6PublicAddrLeaseRequest)
returns (Ipv6PublicAddrLeaseReply);
rpc RenewLease(RenewIpv6PublicAddrLeaseRequest)
returns (Ipv6PublicAddrLeaseReply);
rpc ReleaseLease(ReleaseIpv6PublicAddrLeaseRequest) returns (common.Void);
rpc GetLease(GetIpv6PublicAddrLeaseRequest) returns (Ipv6PublicAddrLeaseReply);
}
message GetIpListRequest {}
message GetIpListResponse {
+1 -1
View File
@@ -1,10 +1,10 @@
use std::sync::{Arc, Mutex, atomic::AtomicBool};
use futures::{SinkExt as _, StreamExt};
use guarden::defer;
use tokio::{task::JoinSet, time::timeout};
use crate::{
defer,
proto::rpc_types::error::Error,
tunnel::{Tunnel, packet_def::PacketType, ring::create_ring_tunnel_pair},
};
+2 -3
View File
@@ -4,18 +4,17 @@ use std::sync::{Arc, Mutex};
use bytes::Bytes;
use dashmap::DashMap;
use guarden::defer;
use prost::Message;
use tokio::sync::mpsc;
use tokio::task::JoinSet;
use tokio::time::timeout;
use tokio_stream::StreamExt;
use crate::common::shrink_dashmap;
use crate::common::{
PeerId,
PeerId, shrink_dashmap,
stats_manager::{LabelSet, LabelType, MetricName, StatsManager},
};
use crate::defer;
use crate::proto::common::{
CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse,
};
+15 -1
View File
@@ -3,7 +3,10 @@ use std::sync::Arc;
use crate::{
instance_manager::NetworkInstanceManager,
proto::{
api::instance::{self, ListPeerRequest, ListPeerResponse, PeerManageRpc},
api::instance::{
self, ListPeerRequest, ListPeerResponse, ListPublicIpv6InfoRequest,
ListPublicIpv6InfoResponse, PeerManageRpc,
},
rpc_types::controller::BaseController,
},
};
@@ -34,6 +37,17 @@ impl PeerManageRpc for PeerManageRpcService {
.await
}
async fn list_public_ipv6_info(
&self,
ctrl: Self::Controller,
req: ListPublicIpv6InfoRequest,
) -> crate::proto::rpc_types::error::Result<ListPublicIpv6InfoResponse> {
super::get_instance_service(&self.instance_manager, &req.instance)?
.get_peer_manage_service()
.list_public_ipv6_info(ctrl, req)
.await
}
async fn list_route(
&self,
ctrl: Self::Controller,
+330 -4
View File
@@ -14,13 +14,17 @@ use crate::{
},
instance::instance::Instance,
tests::three_node::{generate_secure_mode_config, generate_secure_mode_config_with_key},
tunnel::{common::tests::wait_for_condition, tcp::TcpTunnelConnector},
tunnel::{common::tests::wait_for_condition, tcp::TcpTunnelConnector, udp::UdpTunnelConnector},
};
use super::{add_ns_to_bridge, create_netns, del_netns, drop_insts, ping_test};
use rstest::rstest;
const PUBLIC_SERVER_NETWORK_NAME: &str = "__public_server__";
const PUBLIC_SERVER_SHARED_SECRET: &str = "public-server-shared-secret";
const NEED_P2P_ADMIN_NETWORK_NAME: &str = "need_p2p_credential_test_network";
/// Prepare network namespaces for credential tests
/// Topology:
/// br_a (10.1.1.0/24): ns_adm (10.1.1.1), ns_c1 (10.1.1.2), ns_c2 (10.1.1.3), ns_c3 (10.1.1.4), ns_c4 (10.1.1.5)
@@ -221,6 +225,328 @@ fn create_shared_config(
config
}
fn create_public_server_config() -> TomlConfigLoader {
let config = TomlConfigLoader::default();
config.set_inst_name(PUBLIC_SERVER_NETWORK_NAME.to_string());
config.set_hostname(Some("public-server".to_string()));
config.set_netns(Some("ns_adm".to_string()));
config.set_listeners(vec!["udp://0.0.0.0:11010".parse().unwrap()]);
config.set_network_identity(NetworkIdentity::new(
PUBLIC_SERVER_NETWORK_NAME.to_string(),
PUBLIC_SERVER_SHARED_SECRET.to_string(),
));
config.set_secure_mode(Some(generate_secure_mode_config()));
let mut flags = config.get_flags();
flags.no_tun = true;
flags.private_mode = true;
flags.relay_all_peer_rpc = true;
flags.relay_network_whitelist = "".to_string();
config.set_flags(flags);
config
}
fn create_need_p2p_admin_config(listener_scheme: &str) -> TomlConfigLoader {
let config = TomlConfigLoader::default();
config.set_inst_name(NEED_P2P_ADMIN_NETWORK_NAME.to_string());
config.set_hostname(Some("need-p2p-admin".to_string()));
config.set_netns(Some("ns_c3".to_string()));
config.set_listeners(vec![
format!("{listener_scheme}://0.0.0.0:0").parse().unwrap(),
]);
config.set_network_identity(NetworkIdentity::new(
NEED_P2P_ADMIN_NETWORK_NAME.to_string(),
PUBLIC_SERVER_SHARED_SECRET.to_string(),
));
config.set_secure_mode(Some(generate_secure_mode_config()));
let mut flags = config.get_flags();
flags.no_tun = true;
flags.relay_all_peer_rpc = true;
flags.need_p2p = true;
flags.disable_udp_hole_punching = true;
flags.disable_tcp_hole_punching = true;
flags.disable_sym_hole_punching = true;
config.set_flags(flags);
config
}
#[allow(clippy::too_many_arguments)]
fn create_public_server_credential_config(
credential_secret: &str,
inst_name: &str,
hostname: &str,
ns: &str,
ipv4: &str,
ipv6: &str,
tcp_listener_port: u16,
udp_listener_port: u16,
proxy_cidrs: &[&str],
) -> TomlConfigLoader {
let config = create_credential_config_from_secret(
NEED_P2P_ADMIN_NETWORK_NAME.to_string(),
credential_secret,
inst_name,
Some(ns),
ipv4,
ipv6,
);
config.set_hostname(Some(hostname.to_string()));
config.set_listeners(vec![
format!("tcp://0.0.0.0:{tcp_listener_port}")
.parse()
.unwrap(),
format!("udp://0.0.0.0:{udp_listener_port}")
.parse()
.unwrap(),
]);
for cidr in proxy_cidrs {
config
.add_proxy_cidr((*cidr).parse().unwrap(), None)
.unwrap();
}
let mut flags = config.get_flags();
flags.disable_p2p = true;
config.set_flags(flags);
config
}
async fn wait_direct_peer(inst: &Instance, peer_id: u32, timeout: Duration, label: &str) {
wait_for_condition(
|| async {
let peers = inst.get_peer_manager().get_peer_map().list_peers();
let connected = peers.contains(&peer_id);
println!("{label}: direct peers={:?}, target={}", peers, peer_id);
connected
},
timeout,
)
.await;
}
async fn wait_running_listener(inst: &Instance, scheme: &str, timeout: Duration, label: &str) {
wait_for_condition(
|| async {
let listeners = inst.get_global_ctx().get_running_listeners();
let matched = listeners.iter().any(|listener| {
listener.scheme() == scheme && listener.port().is_some_and(|p| p != 0)
});
println!("{label}: running listeners={:?}", listeners);
matched
},
timeout,
)
.await;
}
async fn wait_route_cost(inst: &Instance, peer_id: u32, cost: i32, timeout: Duration, label: &str) {
wait_for_condition(
|| async {
let routes = inst.get_peer_manager().list_routes().await;
let matched = routes
.iter()
.any(|route| route.peer_id == peer_id && route.cost == cost);
println!(
"{label}: routes={:?}, target={}, cost={}",
routes
.iter()
.map(|route| (route.peer_id, route.cost))
.collect::<Vec<_>>(),
peer_id,
cost
);
matched
},
timeout,
)
.await;
}
async fn wait_foreign_network_count(inst: &Instance, expected: usize, timeout: Duration) {
wait_for_condition(
|| async {
let foreign_networks = inst
.get_peer_manager()
.get_foreign_network_manager()
.list_foreign_networks()
.await
.foreign_networks;
println!("foreign networks: {:?}", foreign_networks);
foreign_networks.len() == expected
},
timeout,
)
.await;
}
/// Regression coverage for a public-server-mediated credential topology:
/// Public server <- admin peer (need_p2p) <- two credential peers.
///
/// Credential peers set `disable_p2p=true`, while the admin peer advertises `need_p2p=true`.
/// The credential peers should still proactively build direct peers with the admin peer through
/// peer RPC forwarded by the public server, even when the admin listener binds an ephemeral port.
#[rstest]
#[case("quic")]
#[case("wss")]
#[case("tcp")]
#[case("udp")]
#[tokio::test]
#[serial_test::serial]
async fn credential_peers_p2p_to_need_p2p_admin_through_public_server(
#[case] admin_listener_scheme: &str,
) {
prepare_credential_network();
let mut public_server_inst = Instance::new(create_public_server_config());
public_server_inst.run().await.unwrap();
let mut admin_inst = Instance::new(create_need_p2p_admin_config(admin_listener_scheme));
admin_inst.run().await.unwrap();
wait_running_listener(
&admin_inst,
admin_listener_scheme,
Duration::from_secs(10),
"admin ephemeral listener",
)
.await;
admin_inst
.get_conn_manager()
.add_connector(UdpTunnelConnector::new(
"udp://10.1.1.1:11010".parse().unwrap(),
));
wait_foreign_network_count(&public_server_inst, 1, Duration::from_secs(10)).await;
let (_credential_a_id, credential_a_secret) = admin_inst
.get_global_ctx()
.get_credential_manager()
.generate_credential_with_options(
vec![],
false,
vec!["10.1.0.0/24".to_string()],
Duration::from_secs(3600),
Some("credential-peer-a".to_string()),
false,
);
let (_credential_b_id, credential_b_secret) = admin_inst
.get_global_ctx()
.get_credential_manager()
.generate_credential_with_options(
vec![],
false,
vec![],
Duration::from_secs(3600),
Some("credential-peer-b".to_string()),
false,
);
admin_inst
.get_global_ctx()
.issue_event(GlobalCtxEvent::CredentialChanged);
wait_foreign_network_count(&public_server_inst, 1, Duration::from_secs(10)).await;
let mut credential_a_inst = Instance::new(create_public_server_credential_config(
&credential_a_secret,
"credential-peer-a",
"credential-a",
"ns_c1",
"10.154.0.1",
"fd00::1/64",
11030,
11031,
&["10.1.0.0/24"],
));
let mut credential_b_inst = Instance::new(create_public_server_credential_config(
&credential_b_secret,
"credential-peer-b",
"credential-b",
"ns_c2",
"10.154.0.2",
"fd00::2/64",
11040,
11041,
&[],
));
credential_a_inst.run().await.unwrap();
credential_b_inst.run().await.unwrap();
credential_a_inst
.get_conn_manager()
.add_connector(UdpTunnelConnector::new(
"udp://10.1.1.1:11010".parse().unwrap(),
));
credential_b_inst
.get_conn_manager()
.add_connector(UdpTunnelConnector::new(
"udp://10.1.1.1:11010".parse().unwrap(),
));
let admin_peer_id = admin_inst.peer_id();
let credential_a_peer_id = credential_a_inst.peer_id();
let credential_b_peer_id = credential_b_inst.peer_id();
println!(
"admin={}, credential_a={}, credential_b={}, admin_listener_scheme={}",
admin_peer_id, credential_a_peer_id, credential_b_peer_id, admin_listener_scheme
);
wait_direct_peer(
&credential_a_inst,
admin_peer_id,
Duration::from_secs(30),
"credential_a -> admin",
)
.await;
wait_direct_peer(
&credential_b_inst,
admin_peer_id,
Duration::from_secs(30),
"credential_b -> admin",
)
.await;
wait_direct_peer(
&admin_inst,
credential_a_peer_id,
Duration::from_secs(10),
"admin -> credential_a",
)
.await;
wait_direct_peer(
&admin_inst,
credential_b_peer_id,
Duration::from_secs(10),
"admin -> credential_b",
)
.await;
wait_route_cost(
&credential_a_inst,
admin_peer_id,
1,
Duration::from_secs(10),
"credential_a route to admin",
)
.await;
wait_route_cost(
&credential_b_inst,
admin_peer_id,
1,
Duration::from_secs(10),
"credential_b route to admin",
)
.await;
drop_insts(vec![
public_server_inst,
admin_inst,
credential_a_inst,
credential_b_inst,
])
.await;
}
fn create_generated_credential_config(
admin_inst: &Instance,
inst_name: &str,
@@ -501,10 +827,10 @@ async fn credential_relay_capability(#[case] allow_relay: bool) {
// Create admin node
let admin_config = create_admin_config("admin", Some("ns_adm"), "10.144.144.1", "fd00::1/64");
let mut admin_inst = Instance::new(admin_config);
let mut ff = admin_inst.get_global_ctx().get_feature_flags();
// if cred c allow relay, we set admin inst avoid relay (if other same-cost path available, admin will not relay data)
ff.avoid_relay_data = allow_relay;
admin_inst.get_global_ctx().set_feature_flags(ff);
admin_inst
.get_global_ctx()
.set_avoid_relay_data_preference(allow_relay);
admin_inst.run().await.unwrap();
let admin_peer_id = admin_inst.peer_id();
+1 -1
View File
@@ -38,7 +38,7 @@ async fn test_route_peer_info_ipv6() {
global_ctx.set_ipv6(Some(ipv6_cidr));
// Create RoutePeerInfo with IPv6 support
let updated_info = RoutePeerInfo::new_updated_self(123, 456, &global_ctx);
let updated_info = RoutePeerInfo::new_updated_self(123, 456, &global_ctx, None);
// Verify IPv6 address is included
assert!(updated_info.ipv6_addr.is_some());
+718 -1
View File
@@ -402,6 +402,528 @@ async fn ping6_test(from_netns: &str, target_ip: &str, payload_size: Option<usiz
code.code().unwrap() == 0
}
fn run_cmd(program: &str, args: &[&str]) {
let output = std::process::Command::new(program)
.args(args)
.output()
.unwrap();
assert!(
output.status.success(),
"{} {:?} failed: stdout={}, stderr={}",
program,
args,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
}
fn run_cmd_output(program: &str, args: &[&str]) -> String {
let output = std::process::Command::new(program)
.args(args)
.output()
.unwrap();
assert!(
output.status.success(),
"{} {:?} failed: stdout={}, stderr={}",
program,
args,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
String::from_utf8(output.stdout).unwrap()
}
fn run_ip(args: &[&str]) {
run_cmd("ip", args);
}
fn run_ip_in_ns(ns: &str, args: &[&str]) {
let mut cmd = vec!["netns", "exec", ns, "ip"];
cmd.extend_from_slice(args);
run_cmd("ip", &cmd);
}
fn run_ip_in_ns_output(ns: &str, args: &[&str]) -> String {
let mut cmd = vec!["netns", "exec", ns, "ip"];
cmd.extend_from_slice(args);
run_cmd_output("ip", &cmd)
}
fn run_sysctl_in_ns(ns: &str, assignment: &str) {
run_cmd("ip", &["netns", "exec", ns, "sysctl", "-qw", assignment]);
}
fn create_empty_netns(name: &str) {
del_netns(name);
run_ip(&["netns", "add", name]);
run_ip(&["netns", "exec", name, "ip", "link", "set", "lo", "up"]);
}
fn connect_ns_to_bridge(ns: &str, guest_if: &str, host_if: &str, bridge: &str) {
let _ = std::process::Command::new("ip")
.args(["link", "del", host_if])
.status();
run_ip(&[
"link", "add", host_if, "type", "veth", "peer", "name", guest_if,
]);
run_ip(&["link", "set", guest_if, "netns", ns]);
run_ip(&["link", "set", host_if, "up"]);
run_cmd("brctl", &["addif", bridge, host_if]);
run_ip(&["netns", "exec", ns, "ip", "link", "set", guest_if, "up"]);
}
struct PublicIpv6Lab {
extra_namespaces: [&'static str; 2],
extra_bridges: [&'static str; 2],
}
impl PublicIpv6Lab {
const PROVIDER_NS: &'static str = "net_a";
const CLIENT_NS: &'static str = "net_b";
const UPSTREAM_NS: &'static str = "net_pubgw";
const SERVER_NS: &'static str = "net_pubsrv";
const WAN_BRIDGE: &'static str = "br_pubwan";
const SERVER_BRIDGE: &'static str = "br_pubsrv";
const PROVIDER_TUN: &'static str = "etpubv6p";
const CLIENT_TUN: &'static str = "etpubv6c";
const PROVIDER_PREFIX: &'static str = "2001:db8:100::/64";
const PROVIDER_DEFAULT_FROM: &'static str = "2001:db8:100::/64";
const PROVIDER_WAN_ADDR: &'static str = "2001:db8:ffff:1::2/64";
const UPSTREAM_WAN_ADDR: &'static str = "2001:db8:ffff:1::1/64";
const UPSTREAM_SERVER_ADDR: &'static str = "2001:db8:ffff:2::1/64";
const SERVER_ADDR: &'static str = "2001:db8:ffff:2::100/64";
const SERVER_IP: &'static str = "2001:db8:ffff:2::100";
fn setup() -> Self {
prepare_linux_namespaces();
del_netns(Self::UPSTREAM_NS);
del_netns(Self::SERVER_NS);
let _ = std::process::Command::new("ip")
.args(["link", "del", Self::WAN_BRIDGE])
.status();
let _ = std::process::Command::new("ip")
.args(["link", "del", Self::SERVER_BRIDGE])
.status();
let _ = std::process::Command::new("brctl")
.args(["delbr", Self::WAN_BRIDGE])
.status();
let _ = std::process::Command::new("brctl")
.args(["delbr", Self::SERVER_BRIDGE])
.status();
create_empty_netns(Self::UPSTREAM_NS);
create_empty_netns(Self::SERVER_NS);
prepare_bridge(Self::WAN_BRIDGE);
prepare_bridge(Self::SERVER_BRIDGE);
run_ip(&["link", "set", Self::WAN_BRIDGE, "up"]);
run_ip(&["link", "set", Self::SERVER_BRIDGE, "up"]);
connect_ns_to_bridge(
Self::PROVIDER_NS,
"pubwan0",
"veth_pubwan_p",
Self::WAN_BRIDGE,
);
connect_ns_to_bridge(
Self::UPSTREAM_NS,
"upwan0",
"veth_pubwan_u",
Self::WAN_BRIDGE,
);
connect_ns_to_bridge(
Self::UPSTREAM_NS,
"upsrv0",
"veth_pubsrv_u",
Self::SERVER_BRIDGE,
);
connect_ns_to_bridge(
Self::SERVER_NS,
"srv0",
"veth_pubsrv_s",
Self::SERVER_BRIDGE,
);
run_ip_in_ns(
Self::PROVIDER_NS,
&["addr", "add", Self::PROVIDER_WAN_ADDR, "dev", "pubwan0"],
);
run_ip_in_ns(
Self::UPSTREAM_NS,
&["addr", "add", Self::UPSTREAM_WAN_ADDR, "dev", "upwan0"],
);
run_ip_in_ns(
Self::UPSTREAM_NS,
&["addr", "add", Self::UPSTREAM_SERVER_ADDR, "dev", "upsrv0"],
);
run_ip_in_ns(
Self::SERVER_NS,
&["addr", "add", Self::SERVER_ADDR, "dev", "srv0"],
);
run_ip_in_ns(
Self::PROVIDER_NS,
&["link", "add", "pubprefix0", "type", "dummy"],
);
run_ip_in_ns(Self::PROVIDER_NS, &["link", "set", "pubprefix0", "up"]);
run_ip_in_ns(
Self::PROVIDER_NS,
&[
"-6",
"route",
"add",
Self::PROVIDER_PREFIX,
"dev",
"pubprefix0",
],
);
run_ip_in_ns(
Self::PROVIDER_NS,
&[
"-6",
"route",
"add",
"default",
"from",
Self::PROVIDER_DEFAULT_FROM,
"via",
"2001:db8:ffff:1::1",
"dev",
"pubwan0",
],
);
run_ip_in_ns(
Self::SERVER_NS,
&[
"-6",
"route",
"add",
"default",
"via",
"2001:db8:ffff:2::1",
"dev",
"srv0",
],
);
run_ip_in_ns(
Self::UPSTREAM_NS,
&[
"-6",
"route",
"add",
Self::PROVIDER_PREFIX,
"via",
"2001:db8:ffff:1::2",
"dev",
"upwan0",
],
);
run_sysctl_in_ns(Self::PROVIDER_NS, "net.ipv6.conf.all.forwarding=1");
run_sysctl_in_ns(Self::UPSTREAM_NS, "net.ipv6.conf.all.forwarding=1");
Self {
extra_namespaces: [Self::UPSTREAM_NS, Self::SERVER_NS],
extra_bridges: [Self::WAN_BRIDGE, Self::SERVER_BRIDGE],
}
}
}
impl Drop for PublicIpv6Lab {
fn drop(&mut self) {
for ns in self.extra_namespaces {
del_netns(ns);
}
for bridge in self.extra_bridges {
let _ = std::process::Command::new("ip")
.args(["link", "del", bridge])
.status();
let _ = std::process::Command::new("brctl")
.args(["delbr", bridge])
.status();
}
}
}
fn get_public_ipv6_config(
inst_name: &str,
netns: &str,
ipv4: &str,
dev_name: &str,
inst_id: uuid::Uuid,
) -> TomlConfigLoader {
let config = get_inst_config(inst_name, Some(netns), ipv4, "fd00::1/64");
config.set_id(inst_id);
config.set_ipv6(None);
config.set_socks5_portal(None);
config.set_network_identity(NetworkIdentity {
network_name: "public_ipv6_auto_addr_test".to_string(),
network_secret: Some("public_ipv6_auto_addr_secret".to_string()),
network_secret_digest: None,
});
config.set_listeners(vec!["tcp://0.0.0.0:11010".parse().unwrap()]);
let mut flags = config.get_flags();
flags.dev_name = dev_name.to_string();
config.set_flags(flags);
config
}
async fn init_public_ipv6_two_node(
client_inst_id: uuid::Uuid,
) -> (PublicIpv6Lab, Instance, Instance) {
let lab = PublicIpv6Lab::setup();
let provider_cfg = get_public_ipv6_config(
"provider_public_ipv6",
PublicIpv6Lab::PROVIDER_NS,
"10.144.144.1",
PublicIpv6Lab::PROVIDER_TUN,
uuid::Uuid::parse_str("11111111-1111-1111-1111-111111111111").unwrap(),
);
provider_cfg.set_ipv6_public_addr_provider(true);
let client_cfg = get_public_ipv6_config(
"client_public_ipv6",
PublicIpv6Lab::CLIENT_NS,
"10.144.144.2",
PublicIpv6Lab::CLIENT_TUN,
client_inst_id,
);
client_cfg.set_ipv6_public_addr_auto(true);
let mut provider = Instance::new(provider_cfg);
let mut client = Instance::new(client_cfg);
provider.run().await.unwrap();
client.run().await.unwrap();
provider
.get_conn_manager()
.add_connector(TcpTunnelConnector::new(
"tcp://10.1.1.2:11010".parse().unwrap(),
));
wait_for_condition(
|| async {
provider.get_peer_manager().list_routes().await.len() == 1
&& client.get_peer_manager().list_routes().await.len() == 1
},
Duration::from_secs(8),
)
.await;
(lab, provider, client)
}
async fn wait_for_public_ipv6_addr(inst: &Instance) -> cidr::Ipv6Inet {
wait_for_condition(
|| async {
inst.get_peer_manager()
.get_my_public_ipv6_addr()
.await
.is_some()
},
Duration::from_secs(10),
)
.await;
inst.get_peer_manager()
.get_my_public_ipv6_addr()
.await
.unwrap()
}
async fn wait_for_public_ipv6_route(inst: &Instance, target: cidr::Ipv6Inet) {
wait_for_condition(
|| async {
inst.get_peer_manager()
.list_public_ipv6_routes()
.await
.contains(&target)
},
Duration::from_secs(10),
)
.await;
}
fn route_exists_in_ns(ns: &str, needle: &str) -> bool {
run_ip_in_ns_output(ns, &["-6", "route", "show"])
.lines()
.any(|line| line.contains(needle))
}
fn addr_exists_in_ns(ns: &str, dev: &str, needle: &str) -> bool {
run_ip_in_ns_output(ns, &["-6", "addr", "show", "dev", dev]).contains(needle)
}
#[tokio::test]
#[serial_test::serial]
pub async fn public_ipv6_auto_addr_end_to_end() {
let client_id = uuid::Uuid::parse_str("22222222-2222-2222-2222-222222222222").unwrap();
let (_lab, provider, client) = init_public_ipv6_two_node(client_id).await;
wait_for_condition(
|| async {
provider
.get_global_ctx()
.get_advertised_ipv6_public_addr_prefix()
== Some(PublicIpv6Lab::PROVIDER_PREFIX.parse().unwrap())
},
Duration::from_secs(10),
)
.await;
let leased = wait_for_public_ipv6_addr(&client).await;
wait_for_public_ipv6_route(&provider, leased).await;
assert_eq!(
provider
.get_global_ctx()
.config
.get_ipv6_public_addr_prefix(),
None
);
assert_eq!(
provider
.get_global_ctx()
.get_advertised_ipv6_public_addr_prefix(),
Some(PublicIpv6Lab::PROVIDER_PREFIX.parse().unwrap())
);
let provider_prefix = PublicIpv6Lab::PROVIDER_PREFIX
.parse::<cidr::Ipv6Cidr>()
.unwrap();
assert_eq!(
provider
.get_peer_manager()
.get_my_info()
.await
.ipv6_public_addr_prefix,
Some(
cidr::Ipv6Inet::new(
provider_prefix.first_address(),
provider_prefix.network_length()
)
.unwrap()
.into()
)
);
let provider_info = provider
.get_peer_manager()
.get_local_public_ipv6_info()
.await;
let client_peer_id = client.get_peer_manager().get_my_info().await.peer_id;
assert_eq!(
provider_info.provider_prefix,
Some(
cidr::Ipv6Inet::new(
provider_prefix.first_address(),
provider_prefix.network_length()
)
.unwrap()
.into()
)
);
assert_eq!(provider_info.provider_leases.len(), 1);
assert_eq!(provider_info.provider_leases[0].peer_id, client_peer_id);
assert_eq!(
provider_info.provider_leases[0].inst_id,
client_id.to_string()
);
assert_eq!(
provider_info.provider_leases[0].leased_addr,
Some(leased.into())
);
assert!(
leased.address().segments()[0] & 0xfe00 != 0xfc00,
"leased address should not be unique-local: {leased}"
);
wait_for_condition(
|| async {
addr_exists_in_ns(
PublicIpv6Lab::CLIENT_NS,
PublicIpv6Lab::CLIENT_TUN,
&leased.to_string(),
) && route_exists_in_ns(
PublicIpv6Lab::CLIENT_NS,
&format!("default dev {}", PublicIpv6Lab::CLIENT_TUN),
) && route_exists_in_ns(
PublicIpv6Lab::PROVIDER_NS,
&format!("{} dev {}", leased.address(), PublicIpv6Lab::PROVIDER_TUN),
)
},
Duration::from_secs(10),
)
.await;
wait_for_condition(
|| async { ping6_test(PublicIpv6Lab::CLIENT_NS, PublicIpv6Lab::SERVER_IP, None).await },
Duration::from_secs(10),
)
.await;
wait_for_condition(
|| async {
ping6_test(
PublicIpv6Lab::SERVER_NS,
leased.address().to_string().as_str(),
None,
)
.await
},
Duration::from_secs(10),
)
.await;
drop_insts(vec![provider, client]).await;
}
#[tokio::test]
#[serial_test::serial]
pub async fn public_ipv6_auto_addr_reconnect_reuses_same_address() {
let client_id = uuid::Uuid::parse_str("33333333-3333-3333-3333-333333333333").unwrap();
let (_lab, provider, client) = init_public_ipv6_two_node(client_id).await;
let first = wait_for_public_ipv6_addr(&client).await;
drop_insts(vec![client]).await;
let client_cfg = get_public_ipv6_config(
"client_public_ipv6_reconnect",
PublicIpv6Lab::CLIENT_NS,
"10.144.144.2",
PublicIpv6Lab::CLIENT_TUN,
client_id,
);
client_cfg.set_ipv6_public_addr_auto(true);
let mut client = Instance::new(client_cfg);
client.run().await.unwrap();
provider
.get_conn_manager()
.add_connector(TcpTunnelConnector::new(
"tcp://10.1.1.2:11010".parse().unwrap(),
));
wait_for_condition(
|| async {
provider.get_peer_manager().list_routes().await.len() == 1
&& client.get_peer_manager().list_routes().await.len() == 1
},
Duration::from_secs(8),
)
.await;
let second = wait_for_public_ipv6_addr(&client).await;
assert_eq!(first, second);
wait_for_condition(
|| async { ping6_test(PublicIpv6Lab::CLIENT_NS, PublicIpv6Lab::SERVER_IP, None).await },
Duration::from_secs(10),
)
.await;
drop_insts(vec![provider, client]).await;
}
#[rstest::rstest]
#[tokio::test]
#[serial_test::serial]
@@ -3077,7 +3599,15 @@ pub async fn config_patch_test() {
};
use crate::tunnel::common::tests::_tunnel_pingpong_netns_with_timeout;
let insts = init_three_node("udp").await;
let insts = init_three_node_ex(
"udp",
|cfg| {
cfg.set_ipv6(None);
cfg
},
false,
)
.await;
check_route(
"10.144.144.2/24",
@@ -3124,6 +3654,46 @@ pub async fn config_patch_test() {
},
);
// 测试1.1:修改公网 IPv6 provider 相关配置
let public_prefix = "2001:db8:100::/64";
let patch = InstanceConfigPatch {
ipv6_public_addr_provider: Some(true),
ipv6_public_addr_auto: Some(true),
ipv6_public_addr_prefix: Some(public_prefix.to_string()),
..Default::default()
};
insts[1]
.get_config_patcher()
.apply_patch(patch)
.await
.unwrap();
assert!(
insts[1]
.get_global_ctx()
.config
.get_ipv6_public_addr_provider()
);
assert!(insts[1].get_global_ctx().config.get_ipv6_public_addr_auto());
assert_eq!(
insts[1]
.get_global_ctx()
.config
.get_ipv6_public_addr_prefix(),
Some(public_prefix.parse().unwrap())
);
assert!(
insts[1]
.get_global_ctx()
.get_feature_flags()
.ipv6_public_addr_provider
);
assert_eq!(
insts[1]
.get_global_ctx()
.get_advertised_ipv6_public_addr_prefix(),
Some(public_prefix.parse().unwrap())
);
// 测试2: 端口转发
let patch = InstanceConfigPatch {
port_forwards: vec![PortForwardPatch {
@@ -3160,6 +3730,153 @@ pub async fn config_patch_test() {
drop_insts(insts).await;
}
#[rstest::rstest]
#[tokio::test]
#[serial_test::serial]
pub async fn config_patch_disable_relay_data_test() {
use crate::proto::api::config::InstanceConfigPatch;
let insts = init_three_node_ex(
"udp",
|cfg| {
cfg.set_ipv6(None);
cfg
},
false,
)
.await;
let relay_peer_id = insts[1].peer_id();
let dst_peer_id = insts[2].peer_id();
assert!(!insts[1].get_global_ctx().get_flags().disable_relay_data);
assert!(
!insts[1]
.get_global_ctx()
.get_feature_flags()
.avoid_relay_data
);
check_route_ex(
insts[0].get_peer_manager().list_routes().await,
dst_peer_id,
|route| {
assert_eq!(route.next_hop_peer_id, relay_peer_id);
true
},
);
wait_for_condition(
|| async { ping_test("net_a", "10.144.144.3", None).await },
Duration::from_secs(5),
)
.await;
insts[1]
.get_config_patcher()
.apply_patch(InstanceConfigPatch {
disable_relay_data: Some(true),
..Default::default()
})
.await
.unwrap();
assert!(insts[1].get_global_ctx().get_flags().disable_relay_data);
assert!(
insts[1]
.get_global_ctx()
.config
.get_flags()
.disable_relay_data
);
assert!(
insts[1]
.get_global_ctx()
.get_feature_flags()
.avoid_relay_data
);
wait_for_condition(
|| {
let peer_mgr = insts[0].get_peer_manager().clone();
async move {
peer_mgr.list_routes().await.iter().any(|route| {
route.peer_id == relay_peer_id
&& route
.feature_flag
.as_ref()
.map(|flag| flag.avoid_relay_data)
.unwrap_or(false)
})
}
},
Duration::from_secs(5),
)
.await;
check_route_ex(
insts[0].get_peer_manager().list_routes().await,
dst_peer_id,
|route| {
assert_eq!(route.next_hop_peer_id, relay_peer_id);
true
},
);
assert!(
!ping_test("net_a", "10.144.144.3", None).await,
"traffic from inst1 to inst3 should be blocked while inst2 relay data is disabled"
);
insts[1]
.get_config_patcher()
.apply_patch(InstanceConfigPatch {
disable_relay_data: Some(false),
..Default::default()
})
.await
.unwrap();
assert!(!insts[1].get_global_ctx().get_flags().disable_relay_data);
assert!(
!insts[1]
.get_global_ctx()
.config
.get_flags()
.disable_relay_data
);
assert!(
!insts[1]
.get_global_ctx()
.get_feature_flags()
.avoid_relay_data
);
wait_for_condition(
|| {
let peer_mgr = insts[0].get_peer_manager().clone();
async move {
peer_mgr.list_routes().await.iter().any(|route| {
route.peer_id == relay_peer_id
&& route
.feature_flag
.as_ref()
.map(|flag| !flag.avoid_relay_data)
.unwrap_or(false)
})
}
},
Duration::from_secs(5),
)
.await;
wait_for_condition(
|| async { ping_test("net_a", "10.144.144.3", None).await },
Duration::from_secs(5),
)
.await;
drop_insts(insts).await;
}
/// Generate SecureModeConfig with specified x25519 private key
pub fn generate_secure_mode_config_with_key(
private_key: &x25519_dalek::StaticSecret,
+10 -1
View File
@@ -281,6 +281,7 @@ impl TunnelListener for FakeTcpTunnelListener {
pub struct FakeTcpTunnelConnector {
addr: url::Url,
ip_to_if_name: IpToIfNameCache,
resolved_addr: Option<SocketAddr>,
}
impl FakeTcpTunnelConnector {
@@ -288,6 +289,7 @@ impl FakeTcpTunnelConnector {
FakeTcpTunnelConnector {
addr,
ip_to_if_name: IpToIfNameCache::new(),
resolved_addr: None,
}
}
}
@@ -314,7 +316,10 @@ fn get_local_ip_for_destination(destination: IpAddr) -> Option<IpAddr> {
#[async_trait::async_trait]
impl TunnelConnector for FakeTcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let remote_addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?;
let remote_addr = match self.resolved_addr {
Some(addr) => addr,
None => SocketAddr::from_url(self.addr.clone(), IpVersion::Both).await?,
};
let local_ip = get_local_ip_for_destination(remote_addr.ip())
.ok_or(TunnelError::InternalError("Failed to get local ip".into()))?;
@@ -390,6 +395,10 @@ impl TunnelConnector for FakeTcpTunnelConnector {
fn remote_url(&self) -> url::Url {
self.addr.clone()
}
fn set_resolved_addr(&mut self, addr: SocketAddr) {
self.resolved_addr = Some(addr);
}
}
type RecvFut = Pin<Box<dyn Future<Output = Option<(BytesMut, usize)>> + Send + Sync>>;
@@ -57,21 +57,21 @@ cfg_select! {
pub mod windivert;
pub fn create_tun(
_interface_name: &str,
_src_addr: Option<SocketAddr>,
local_addr: SocketAddr,
interface_name: &str,
src_addr: Option<SocketAddr>,
dst_addr: SocketAddr,
) -> io::Result<Arc<dyn super::stack::Tun>> {
match windivert::WinDivertTun::new(local_addr) {
match windivert::WinDivertTun::new(src_addr, dst_addr) {
Ok(tun) => Ok(Arc::new(tun)),
Err(e) => {
tracing::warn!(
?e,
?local_addr,
?dst_addr,
"WinDivertTun init failed, falling back to PnetTun"
);
Ok(Arc::new(pnet::PnetTun::new(
local_addr.to_string().as_str(),
pnet::create_packet_filter(None, local_addr),
interface_name,
pnet::create_packet_filter(src_addr, dst_addr),
)?))
}
}
@@ -80,15 +80,11 @@ impl Drop for WinDivertTun {
}
impl WinDivertTun {
pub fn new(local_addr: SocketAddr) -> io::Result<Self> {
pub fn new(src_addr: Option<SocketAddr>, dst_addr: SocketAddr) -> io::Result<Self> {
let (tx, rx) = tokio::sync::mpsc::channel(1024);
let ip_filter = match local_addr {
SocketAddr::V4(addr) => format!("ip.DstAddr == {}", addr.ip()),
SocketAddr::V6(addr) => format!("ipv6.DstAddr == {}", addr.ip()),
};
// Filter: DstIP == LocalIP AND TCP.
let filter = format!("{} and tcp", ip_filter);
let filter = build_filter(src_addr, dst_addr)?;
tracing::debug!(%filter, "WinDivertTun created with filter");
// Sniff mode: 1 (WINDIVERT_FLAG_SNIFF)
// Layer: Network (0)
@@ -143,6 +139,46 @@ impl WinDivertTun {
}
}
fn build_filter(src_addr: Option<SocketAddr>, dst_addr: SocketAddr) -> io::Result<String> {
if let Some(src_addr) = src_addr
&& src_addr.is_ipv4() != dst_addr.is_ipv4()
{
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"src/dst addr family mismatch",
));
}
let mut filters = Vec::with_capacity(5);
filters.push("tcp".to_owned());
match dst_addr {
SocketAddr::V4(addr) => {
filters.push(format!("ip.DstAddr == {}", addr.ip()));
filters.push(format!("tcp.DstPort == {}", addr.port()));
}
SocketAddr::V6(addr) => {
filters.push(format!("ipv6.DstAddr == {}", addr.ip()));
filters.push(format!("tcp.DstPort == {}", addr.port()));
}
}
if let Some(src_addr) = src_addr {
match src_addr {
SocketAddr::V4(addr) => {
filters.push(format!("ip.SrcAddr == {}", addr.ip()));
filters.push(format!("tcp.SrcPort == {}", addr.port()));
}
SocketAddr::V6(addr) => {
filters.push(format!("ipv6.SrcAddr == {}", addr.ip()));
filters.push(format!("tcp.SrcPort == {}", addr.port()));
}
}
}
Ok(filters.join(" and "))
}
#[async_trait::async_trait]
impl stack::Tun for WinDivertTun {
async fn recv(&self, packet: &mut BytesMut) -> Result<usize, std::io::Error> {
-1
View File
@@ -128,7 +128,6 @@ pub fn build_tcp_packet(
eth_buf.freeze()
}
#[tracing::instrument(ret)]
pub fn parse_ip_packet(
buf: &Bytes,
) -> Option<(MacAddr, MacAddr, IPPacket<'_>, tcp::TcpPacket<'_>)> {
+4 -1
View File
@@ -517,9 +517,12 @@ impl Stack {
{
trace!(?tcp_packet, "Received SYN packet for port {}, ignoring", tcp_packet.get_destination());
continue;
} else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 {
} else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
info!("Unknown RST TCP packet from {}, ignoring", remote_addr);
continue;
} else {
trace!("Unknown TCP packet from {}, ignoring", remote_addr);
continue;
}
}
None => {
+25 -1
View File
@@ -141,6 +141,7 @@ pub trait TunnelConnector: Send {
fn remote_url(&self) -> url::Url;
fn set_bind_addrs(&mut self, _addrs: Vec<SocketAddr>) {}
fn set_ip_version(&mut self, _ip_version: IpVersion) {}
fn set_resolved_addr(&mut self, _addr: SocketAddr) {}
}
pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url {
@@ -371,9 +372,13 @@ impl TryFrom<&url::Url> for TunnelScheme {
}
}
pub(crate) fn get_scheme_by_url(l: &url::Url) -> Result<TunnelScheme, Error> {
l.try_into()
}
macro_rules! __matches_scheme__ {
($url:expr, $( $pattern:pat_param )|+ ) => {
matches!($crate::tunnel::TunnelScheme::try_from(($url).as_ref()), Ok($( $pattern )|+))
matches!($crate::tunnel::get_scheme_by_url(&$url), Ok($( $pattern )|+))
};
}
@@ -393,3 +398,22 @@ macro_rules! __matches_protocol__ {
}
pub(crate) use __matches_protocol__ as matches_protocol;
#[cfg(test)]
mod tests {
use super::{IpScheme, TunnelScheme, matches_scheme};
#[test]
fn matches_scheme_accepts_owned_url() {
let url: url::Url = "udp://[2001:db8::1]:11010".parse().unwrap();
assert!(matches_scheme!(url, TunnelScheme::Ip(IpScheme::Udp)));
}
#[test]
fn matches_scheme_accepts_borrowed_url() {
let url: url::Url = "udp://[2001:db8::1]:11010".parse().unwrap();
assert!(matches_scheme!(&url, TunnelScheme::Ip(IpScheme::Udp)));
}
}
+15
View File
@@ -309,6 +309,8 @@ pub enum CompressorAlgo {
None = 0,
#[cfg(feature = "zstd")]
ZstdDefault = 1,
#[cfg(feature = "lzo")]
Lzo = 2,
}
#[repr(C, packed)]
@@ -323,6 +325,8 @@ impl CompressorTail {
match self.algo {
#[cfg(feature = "zstd")]
1 => Some(CompressorAlgo::ZstdDefault),
#[cfg(feature = "lzo")]
2 => Some(CompressorAlgo::Lzo),
_ => None,
}
}
@@ -730,6 +734,17 @@ impl ZCPacket {
}
}
pub fn foreign_network_inner_packet_type(&self) -> Option<u8> {
if self.peer_manager_header()?.packet_type != PacketType::ForeignNetworkPacket as u8 {
return None;
}
let payload = self.payload();
let hdr = ForeignNetworkPacketHeader::ref_from_prefix(payload)?;
let inner_packet = payload.get(hdr.get_header_len()..)?;
PeerManagerHeader::ref_from_prefix(inner_packet).map(|hdr| hdr.packet_type)
}
pub fn foreign_network_packet(mut self) -> Self {
let hdr = self.foreign_network_hdr().unwrap();
let foreign_hdr_len = hdr.get_header_len();
+212 -17
View File
@@ -14,8 +14,8 @@ use derivative::Derivative;
use derive_more::{Deref, DerefMut};
use parking_lot::RwLock;
use quinn::{
ClientConfig, Connection, Endpoint, EndpointConfig, ServerConfig, TransportConfig,
congestion::BbrConfig, default_runtime,
ClientConfig, ConnectError, Connection, Endpoint, EndpointConfig, ServerConfig,
TransportConfig, congestion::BbrConfig, default_runtime,
};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::OnceLock;
@@ -135,6 +135,12 @@ impl<Item> RwPool<Item> {
self.resize();
}
fn len(&self) -> usize {
let persistent_len = self.persistent.read().len();
let ephemeral_len = self.ephemeral.read().len();
persistent_len + ephemeral_len
}
/// try to push an item to the ephemeral pool, return the item if full
fn try_push(&self, item: Item) -> Option<Item> {
let mut pool = self.ephemeral.write();
@@ -168,6 +174,49 @@ impl<Item> RwPool<Item> {
f(&mut persistent.iter().chain(ephemeral.iter()))
}
}
impl RwPool<Endpoint> {
fn retain_endpoints<F>(&self, mut keep: F) -> usize
where
F: FnMut(&Endpoint) -> bool,
{
let persistent_removed = {
let mut persistent = self.persistent.write();
let before = persistent.len();
persistent.retain(|endpoint| keep(endpoint));
before - persistent.len()
};
let ephemeral_removed = {
let mut ephemeral = self.ephemeral.write();
let before = ephemeral.len();
ephemeral.retain(|endpoint| keep(endpoint));
before - ephemeral.len()
};
let removed = persistent_removed + ephemeral_removed;
if removed > 0 {
self.resize();
}
removed
}
fn remove_by_local_addr(&self, local_addr: SocketAddr) -> usize {
self.retain_endpoints(|endpoint| endpoint.local_addr().ok() != Some(local_addr))
}
fn contains_local_addr(&self, local_addr: SocketAddr) -> bool {
self.persistent
.read()
.iter()
.any(|endpoint| endpoint.local_addr().ok() == Some(local_addr))
|| self
.ephemeral
.read()
.iter()
.any(|endpoint| endpoint.local_addr().ok() == Some(local_addr))
}
}
//endregion
//region endpoint manager
@@ -262,6 +311,20 @@ impl QuicEndpointManager {
QUIC_ENDPOINT_MANAGER.get().unwrap()
}
fn client_pool(&self, ip_version: IpVersion) -> &RwPool<Endpoint> {
let dual_stack = self.both.is_enabled();
match ip_version {
IpVersion::V4 if !dual_stack => &self.ipv4,
_ => {
if dual_stack {
&self.both
} else {
&self.ipv6
}
}
}
}
/// Get a QUIC endpoint to be used as a server
///
/// # Arguments
@@ -288,14 +351,8 @@ impl QuicEndpointManager {
Ok(endpoint)
}
/// Get a quic endpoint to be used as a client
///
/// # Arguments
/// * `ip_version`: the IP version of the remote address
fn client(global_ctx: &ArcGlobalCtx, ip_version: IpVersion) -> Result<Endpoint, TunnelError> {
let mgr = Self::load(global_ctx);
let (pool, endpoint) = mgr.create(|mgr| {
fn client_endpoint(&self, ip_version: IpVersion) -> Result<Endpoint, TunnelError> {
let (pool, endpoint) = self.create(|mgr| {
let dual_stack = mgr.both.is_enabled();
let (pool, addr) = match ip_version {
IpVersion::V4 if !dual_stack => (&mgr.ipv4, (Ipv4Addr::UNSPECIFIED, 0).into()),
@@ -318,6 +375,26 @@ impl QuicEndpointManager {
Ok(pool.with_iter(|iter| iter.min_by_key(|e| e.open_connections()).unwrap().clone()))
}
fn remove_endpoint(&self, endpoint: &Endpoint) -> usize {
let Ok(local_addr) = endpoint.local_addr() else {
return 0;
};
self.remove_endpoint_by_local_addr(local_addr)
}
fn remove_endpoint_by_local_addr(&self, local_addr: SocketAddr) -> usize {
[&self.ipv4, &self.ipv6, &self.both]
.into_iter()
.map(|pool| pool.remove_by_local_addr(local_addr))
.sum()
}
fn contains_local_addr(&self, local_addr: SocketAddr) -> bool {
[&self.ipv4, &self.ipv6, &self.both]
.into_iter()
.any(|pool| pool.contains_local_addr(local_addr))
}
async fn connect(
global_ctx: &ArcGlobalCtx,
addr: SocketAddr,
@@ -327,14 +404,52 @@ impl QuicEndpointManager {
} else {
IpVersion::V6
};
let endpoint = Self::client(global_ctx, ip_version)?;
let connection = endpoint
.connect(addr, "localhost")
.with_context(|| format!("failed to create connection to {}", addr))?
Self::load(global_ctx)
.connect_with_ip_version(addr, ip_version)
.await
.with_context(|| format!("failed to connect to {}", addr))?;
}
Ok((endpoint, connection))
async fn connect_with_ip_version(
&self,
addr: SocketAddr,
ip_version: IpVersion,
) -> Result<(Endpoint, Connection), TunnelError> {
let max_endpoint_stopping_retries = self.client_pool(ip_version).len().saturating_add(1);
let mut endpoint_stopping_retries = 0;
loop {
let endpoint = self.client_endpoint(ip_version)?;
let connecting = match endpoint.connect(addr, "localhost") {
Ok(connecting) => connecting,
Err(ConnectError::EndpointStopping) => {
let local_addr = endpoint.local_addr().ok();
let removed = self.remove_endpoint(&endpoint);
endpoint_stopping_retries += 1;
tracing::warn!(
?addr,
?local_addr,
removed,
"removed stopped quic endpoint and retry connect"
);
if endpoint_stopping_retries > max_endpoint_stopping_retries {
return Err(anyhow::Error::new(ConnectError::EndpointStopping)
.context(format!("failed to create connection to {}", addr))
.into());
}
continue;
}
Err(e) => {
return Err(anyhow::Error::new(e)
.context(format!("failed to create connection to {}", addr))
.into());
}
};
let connection = connecting
.await
.with_context(|| format!("failed to connect to {}", addr))?;
return Ok((endpoint, connection));
}
}
}
//endregion
@@ -398,6 +513,18 @@ impl QuicTunnelListener {
}
}
impl Drop for QuicTunnelListener {
fn drop(&mut self) {
let Some(endpoint) = &self.endpoint else {
return;
};
let Ok(local_addr) = endpoint.local_addr() else {
return;
};
QuicEndpointManager::load(&self.global_ctx).remove_endpoint_by_local_addr(local_addr);
}
}
#[async_trait::async_trait]
impl TunnelListener for QuicTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
@@ -432,6 +559,7 @@ pub struct QuicTunnelConnector {
addr: url::Url,
global_ctx: ArcGlobalCtx,
ip_version: IpVersion,
resolved_addr: Option<SocketAddr>,
}
impl QuicTunnelConnector {
@@ -440,6 +568,7 @@ impl QuicTunnelConnector {
addr,
global_ctx,
ip_version: IpVersion::Both,
resolved_addr: None,
}
}
}
@@ -447,7 +576,10 @@ impl QuicTunnelConnector {
#[async_trait::async_trait]
impl TunnelConnector for QuicTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
let addr = match self.resolved_addr {
Some(addr) => addr,
None => SocketAddr::from_url(self.addr.clone(), self.ip_version).await?,
};
let (endpoint, connection) = QuicEndpointManager::connect(&self.global_ctx, addr).await?;
let local_addr = endpoint.local_addr()?;
@@ -484,6 +616,10 @@ impl TunnelConnector for QuicTunnelConnector {
fn set_ip_version(&mut self, ip_version: IpVersion) {
self.ip_version = ip_version;
}
fn set_resolved_addr(&mut self, addr: SocketAddr) {
self.resolved_addr = Some(addr);
}
}
#[cfg(test)]
@@ -507,6 +643,20 @@ mod tests {
get_mock_global_ctx_with_network(Some(identity))
}
fn stopped_client_endpoint() -> (Endpoint, SocketAddr) {
let rt = Builder::new_current_thread().enable_all().build().unwrap();
let endpoint = rt.block_on(async {
QuicEndpointManager::try_create((Ipv4Addr::UNSPECIFIED, 0).into(), false).unwrap()
});
let local_addr = endpoint.local_addr().unwrap();
drop(rt);
assert!(matches!(
endpoint.connect("127.0.0.1:1".parse().unwrap(), "localhost"),
Err(ConnectError::EndpointStopping)
));
(endpoint, local_addr)
}
#[test]
fn quic_pingpong() {
RUNTIME.block_on(quic_pingpong_impl())
@@ -582,6 +732,51 @@ mod tests {
assert!(port > 0);
}
#[test]
fn listener_drop_removes_persistent_endpoint() {
RUNTIME.block_on(listener_drop_removes_persistent_endpoint_impl())
}
async fn listener_drop_removes_persistent_endpoint_impl() {
let global_ctx = global_ctx();
let endpoint_addr = {
let mut listener =
QuicTunnelListener::new("quic://127.0.0.1:0".parse().unwrap(), global_ctx.clone());
listener.listen().await.unwrap();
let endpoint_addr = listener.endpoint.as_ref().unwrap().local_addr().unwrap();
assert!(QuicEndpointManager::load(&global_ctx).contains_local_addr(endpoint_addr));
endpoint_addr
};
assert!(!QuicEndpointManager::load(&global_ctx).contains_local_addr(endpoint_addr));
}
#[test]
fn connect_removes_stopped_endpoints_and_retries() {
let (stopped_endpoint_a, stopped_addr_a) = stopped_client_endpoint();
let (stopped_endpoint_b, stopped_addr_b) = stopped_client_endpoint();
RUNTIME.block_on(async move {
let mgr = QuicEndpointManager::new(2);
mgr.both.push(stopped_endpoint_a);
mgr.both.push(stopped_endpoint_b);
assert!(mgr.contains_local_addr(stopped_addr_a));
assert!(mgr.contains_local_addr(stopped_addr_b));
let err = mgr
.connect_with_ip_version("127.0.0.1:0".parse().unwrap(), IpVersion::V4)
.await
.unwrap_err();
let err = format!("{:?}", err);
assert!(
err.contains("invalid remote address"),
"unexpected error: {}",
err
);
assert!(!mgr.contains_local_addr(stopped_addr_a));
assert!(!mgr.contains_local_addr(stopped_addr_b));
});
}
#[test]
fn invalid_peer_addr() {
RUNTIME.block_on(invalid_peer_addr_impl())
+35 -1
View File
@@ -129,6 +129,7 @@ pub struct TcpTunnelConnector {
bind_addrs: Vec<SocketAddr>,
ip_version: IpVersion,
resolved_addr: Option<SocketAddr>,
}
impl TcpTunnelConnector {
@@ -137,6 +138,7 @@ impl TcpTunnelConnector {
addr,
bind_addrs: vec![],
ip_version: IpVersion::Both,
resolved_addr: None,
}
}
@@ -175,7 +177,10 @@ impl TcpTunnelConnector {
#[async_trait]
impl super::TunnelConnector for TcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
let addr = match self.resolved_addr {
Some(addr) => addr,
None => SocketAddr::from_url(self.addr.clone(), self.ip_version).await?,
};
if self.bind_addrs.is_empty() {
self.connect_with_default_bind(addr).await
} else {
@@ -194,6 +199,10 @@ impl super::TunnelConnector for TcpTunnelConnector {
fn set_ip_version(&mut self, ip_version: IpVersion) {
self.ip_version = ip_version;
}
fn set_resolved_addr(&mut self, addr: SocketAddr) {
self.resolved_addr = Some(addr);
}
}
#[cfg(test)]
@@ -294,6 +303,31 @@ mod tests {
);
}
#[tokio::test]
async fn connector_uses_pre_resolved_addr_without_resolving_url() {
let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:0".parse().unwrap());
listener.listen().await.unwrap();
let port = listener.local_url().port().unwrap();
let source_url: url::Url = format!("tcp://unresolvable.invalid:{port}")
.parse()
.unwrap();
let resolved_addr: SocketAddr = format!("127.0.0.1:{port}").parse().unwrap();
let mut connector = TcpTunnelConnector::new(source_url.clone());
connector.set_resolved_addr(resolved_addr);
let accept_task = tokio::spawn(async move { listener.accept().await.unwrap() });
let tunnel = connector.connect().await.unwrap();
let _accepted_tunnel = accept_task.await.unwrap();
let info = tunnel.info().unwrap();
assert_eq!(info.remote_addr.unwrap().url, source_url.to_string());
let resolved_remote_addr: url::Url = info.resolved_remote_addr.unwrap().into();
assert_eq!(resolved_remote_addr.host_str(), Some("127.0.0.1"));
assert_eq!(resolved_remote_addr.port(), Some(port));
}
#[tokio::test]
async fn test_alloc_port() {
// v4
+10 -1
View File
@@ -682,6 +682,7 @@ pub struct UdpTunnelConnector {
addr: url::Url,
bind_addrs: Vec<SocketAddr>,
ip_version: IpVersion,
resolved_addr: Option<SocketAddr>,
}
impl UdpTunnelConnector {
@@ -690,6 +691,7 @@ impl UdpTunnelConnector {
addr,
bind_addrs: vec![],
ip_version: IpVersion::Both,
resolved_addr: None,
}
}
@@ -906,7 +908,10 @@ impl UdpTunnelConnector {
#[async_trait]
impl super::TunnelConnector for UdpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
let addr = match self.resolved_addr {
Some(addr) => addr,
None => SocketAddr::from_url(self.addr.clone(), self.ip_version).await?,
};
if self.bind_addrs.is_empty() || addr.is_ipv6() {
self.connect_with_default_bind(addr).await
} else {
@@ -925,6 +930,10 @@ impl super::TunnelConnector for UdpTunnelConnector {
fn set_ip_version(&mut self, ip_version: IpVersion) {
self.ip_version = ip_version;
}
fn set_resolved_addr(&mut self, addr: SocketAddr) {
self.resolved_addr = Some(addr);
}
}
#[cfg(test)]
+13 -9
View File
@@ -198,6 +198,7 @@ impl TunnelListener for WsTunnelListener {
pub struct WsTunnelConnector {
addr: url::Url,
ip_version: IpVersion,
resolved_addr: Option<SocketAddr>,
bind_addrs: Vec<SocketAddr>,
}
@@ -207,6 +208,7 @@ impl WsTunnelConnector {
WsTunnelConnector {
addr,
ip_version: IpVersion::Both,
resolved_addr: None,
bind_addrs: vec![],
}
@@ -214,11 +216,10 @@ impl WsTunnelConnector {
async fn connect_with(
addr: url::Url,
ip_version: IpVersion,
socket_addr: SocketAddr,
tcp_socket: TcpSocket,
) -> Result<Box<dyn Tunnel>, TunnelError> {
let is_wss = is_wss(&addr)?;
let socket_addr = SocketAddr::from_url(addr.clone(), ip_version).await?;
let stream = tcp_socket.connect(socket_addr).await?;
if let Err(error) = stream.set_nodelay(true) {
tracing::warn!(?error, "set_nodelay fail in ws connect");
@@ -273,7 +274,7 @@ impl WsTunnelConnector {
} else {
TcpSocket::new_v6()?
};
Self::connect_with(self.addr.clone(), self.ip_version, socket).await
Self::connect_with(self.addr.clone(), addr, socket).await
}
async fn connect_with_custom_bind(
@@ -285,11 +286,7 @@ impl WsTunnelConnector {
for bind_addr in self.bind_addrs.iter() {
tracing::info!(?bind_addr, ?addr, "bind addr");
match bind().addr(*bind_addr).only_v6(true).call() {
Ok(socket) => futures.push(Self::connect_with(
self.addr.clone(),
self.ip_version,
socket,
)),
Ok(socket) => futures.push(Self::connect_with(self.addr.clone(), addr, socket)),
Err(error) => {
tracing::error!(?bind_addr, ?addr, ?error, "bind addr fail");
continue;
@@ -304,7 +301,10 @@ impl WsTunnelConnector {
#[async_trait::async_trait]
impl TunnelConnector for WsTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
let addr = match self.resolved_addr {
Some(addr) => addr,
None => SocketAddr::from_url(self.addr.clone(), self.ip_version).await?,
};
if self.bind_addrs.is_empty() || addr.is_ipv6() {
self.connect_with_default_bind(addr).await
} else {
@@ -323,6 +323,10 @@ impl TunnelConnector for WsTunnelConnector {
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
self.bind_addrs = addrs;
}
fn set_resolved_addr(&mut self, addr: SocketAddr) {
self.resolved_addr = Some(addr);
}
}
#[cfg(test)]
+10 -1
View File
@@ -598,6 +598,7 @@ pub struct WgTunnelConnector {
bind_addrs: Vec<SocketAddr>,
ip_version: IpVersion,
resolved_addr: Option<SocketAddr>,
}
impl Debug for WgTunnelConnector {
@@ -617,6 +618,7 @@ impl WgTunnelConnector {
udp: None,
bind_addrs: vec![],
ip_version: IpVersion::Both,
resolved_addr: None,
}
}
@@ -702,7 +704,10 @@ impl WgTunnelConnector {
impl super::TunnelConnector for WgTunnelConnector {
#[tracing::instrument]
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version).await?;
let addr = match self.resolved_addr {
Some(addr) => addr,
None => SocketAddr::from_url(self.addr.clone(), self.ip_version).await?,
};
if addr.is_ipv6() {
return self.connect_with_ipv6(addr).await;
@@ -744,6 +749,10 @@ impl super::TunnelConnector for WgTunnelConnector {
fn set_ip_version(&mut self, ip_version: IpVersion) {
self.ip_version = ip_version;
}
fn set_resolved_addr(&mut self, addr: SocketAddr) {
self.resolved_addr = Some(addr);
}
}
#[cfg(test)]
-638
View File
@@ -1,638 +0,0 @@
//! # Guard Module Utilities
//!
//! This module provides mechanisms for scope-based resource management and deferred execution.
//!
//! ### ⚠️ Critical Usage Note: Diverging Expressions
//!
//! Do not use "naked" diverging expressions—such as `panic!`, `todo!`, or `loop {}`—as
//! the sole content of sync guard closure. This prevents the compiler from
//! distinguishing between synchronous (`ASYNC = false`) and asynchronous
//! (`ASYNC = true`) implementations, leading to a type inference error (E0277).
//!
//! ### Technical Context
//!
//! The `!` (Never Type) is a bottom type that can be coerced into any other type.
//! Because it satisfies both the `()` requirement for sync guards and the `Future`
//! requirement for async guards, the compiler encounters an inference deadlock.
//!
//! ### Workaround
//!
//! For macros like `guard!` or `guarded!`, force the closure to resolve to `()`
//! by explicitly setting the guard to `sync`:
//!
//! ```rust
//! let _g = guard!([val] sync {
//! panic!("critical failure");
//! });
//! ```
use crate::utils::task::{DetachableTask, TaskSpawner};
use std::fmt::Debug;
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
pub trait CallableGuard<const ASYNC: bool, Context> {
type Output;
fn call(self, context: Context) -> Self::Output;
}
impl<Context, Guard> CallableGuard<false, Context> for Guard
where
Guard: FnOnce(Context),
{
type Output = ();
fn call(self, context: Context) -> Self::Output {
self(context)
}
}
impl<Context, Guard, Task, _R> CallableGuard<true, Context> for Guard
where
Guard: FnOnce(Context) -> Task + Send + 'static,
Task: Future<Output = _R> + Send + 'static,
_R: Send + 'static,
{
type Output = DetachableTask<TaskSpawner<Task>, Task>;
fn call(self, context: Context) -> Self::Output {
DetachableTask::new(self(context))
}
}
pub struct ContextGuard<const ASYNC: bool, Context, Guard: CallableGuard<ASYNC, Context>> {
context: ManuallyDrop<Context>,
guard: ManuallyDrop<Guard>,
}
impl<const ASYNC: bool, Context, Guard: CallableGuard<ASYNC, Context>> Deref
for ContextGuard<ASYNC, Context, Guard>
{
type Target = Context;
fn deref(&self) -> &Self::Target {
&self.context
}
}
impl<const ASYNC: bool, Context, Guard: CallableGuard<ASYNC, Context>> DerefMut
for ContextGuard<ASYNC, Context, Guard>
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.context
}
}
impl<const ASYNC: bool, Context: Debug, Guard: CallableGuard<ASYNC, Context>> Debug
for ContextGuard<ASYNC, Context, Guard>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let name = if ASYNC {
"ContextGuard::Async"
} else {
"ContextGuard::Sync"
};
f.debug_struct(name)
.field("context", &self.context)
.finish_non_exhaustive()
}
}
impl<const ASYNC: bool, Context, Guard: CallableGuard<ASYNC, Context>>
ContextGuard<ASYNC, Context, Guard>
{
/// Creates a new `ContextGuard`.
///
/// **Note on generics:** The seemingly unused `_R` generic parameter and the
/// `Guard: FnOnce(Context) -> _R` trait bound are intentionally included.
/// They act as a hint to help the compiler infer closure types.
pub fn new<_R>(context: Context, guard: Guard) -> Self
where
Guard: FnOnce(Context) -> _R,
{
ContextGuard {
context: ManuallyDrop::new(context),
guard: ManuallyDrop::new(guard),
}
}
}
impl<const ASYNC: bool, Context, Guard: CallableGuard<ASYNC, Context>>
ContextGuard<ASYNC, Context, Guard>
{
unsafe fn call(&mut self) -> Guard::Output {
unsafe {
let context = ManuallyDrop::take(&mut self.context);
let guard = ManuallyDrop::take(&mut self.guard);
guard.call(context)
}
}
pub fn trigger(self) -> Guard::Output {
let mut this = ManuallyDrop::new(self);
unsafe { this.call() }
}
pub fn defuse(self) -> Context {
let mut this = ManuallyDrop::new(self);
unsafe {
ManuallyDrop::drop(&mut this.guard);
ManuallyDrop::take(&mut this.context)
}
}
}
impl<const ASYNC: bool, Context, Guard: CallableGuard<ASYNC, Context>> Drop
for ContextGuard<ASYNC, Context, Guard>
{
fn drop(&mut self) {
let _: Guard::Output = unsafe { self.call() };
}
}
// region macro
#[doc(hidden)]
#[macro_export]
macro_rules! __guarded {
(@parse@action $guard:ident => $($tt:tt)*) => {
$crate::__guarded! { @parse@async action: [ @stmt $guard ] ; $($tt)* }
};
(@parse@action $($tt:tt)*) => {
$crate::__guarded! { @parse@async action: [ @stmt __guard ] ; $($tt)* }
};
(@parse@async action: [ $($action:tt)* ] ; sync $($tt:tt)*) => {
$crate::__guarded! { @parse@move action: [ $($action)* ] ; async: [ false ] ; $($tt)* }
};
(@parse@async action: [ $($action:tt)* ] ; $($tt:tt)*) => {
$crate::__guarded! { @parse@move action: [ $($action)* ] ; async: [ _ ] ; $($tt)* }
};
(@parse@move action: [ $($action:tt)* ] ; async: [ $async:tt ] ; move $($tt:tt)*) => {
$crate::__guarded! { @parse action: [ $($action)* ] ; async: [ $async ] ; move: [ move ] ; $($tt)* }
};
(@parse@move action: [ $($action:tt)* ] ; async: [ $async:tt ] ; $($tt:tt)*) => {
$crate::__guarded! { @parse action: [ $($action)* ] ; async: [ $async ] ; move: [] ; $($tt)* }
};
(
@parse action: [ $($action:tt)* ] ; async: [ $async:tt ] ; move: [ $($move:tt)? ] ;
[ $($args:tt)* ] $body:block
) => {
$crate::__guarded! {
action: [ $($action)* ]
async: [ $async ]
move: [ $($move)? ]
mut: []
rest: [ $($args)* , ]
args: []
vars: []
body: [ $body ]
}
};
(
@parse action: [ $($action:tt)* ] ; async: [ $async:tt ] ; move: [ $($move:tt)? ] ;
$body:block
) => {
$crate::__guarded! {
@parse action: [ $($action)* ] ; async: [ $async ] ; move: [ $($move)? ] ;
[] $body
}
};
(
@parse action: [ $($action:tt)* ] ; async: [ $async:tt ] ; move: [ $($move:tt)? ] ;
[ $($args:tt)* ] $($body:tt)*
) => {
$crate::__guarded! {
@parse action: [ $($action)* ] ; async: [ $async ] ; move: [ $($move)? ] ;
[ $($args)* ] { $($body)* }
}
};
(
@parse action: [ $($action:tt)* ] ; async: [ $async:tt ] ; move: [ $($move:tt)? ] ;
$($body:tt)*
) => {
$crate::__guarded! {
@parse action: [ $($action)* ] ; async: [ $async ] ; move: [ $($move)? ] ;
[] { $($body)* }
}
};
(
action: [ $($action:tt)* ]
async: [ $async:tt ]
move: [ $($move:tt)? ]
mut: [ $($mut:tt)? ]
rest: [ mut $arg:ident , $($rest:tt)* ]
args: [ $($args:ident)* ]
vars: [ $($vars:tt)* ]
body: [ $body:expr ]
) => {
$crate::__guarded! {
action: [ $($action)* ]
async: [ $async ]
move: [ $($move)? ]
mut: [ mut ]
rest: [ $($rest)* ]
args: [ $($args)* $arg ]
vars: [ $($vars)* [mut $arg] ]
body: [ $body ]
}
};
(
action: [ $($action:tt)* ]
async: [ $async:tt ]
move: [ $($move:tt)? ]
mut: [ $($mut:tt)? ]
rest: [ $arg:ident , $($rest:tt)* ]
args: [ $($args:ident)* ]
vars: [ $($vars:tt)* ]
body: [ $body:expr ]
) => {
$crate::__guarded! {
action: [ $($action)* ]
async: [ $async ]
move: [ $($move)? ]
mut: [ $($mut)? ]
rest: [ $($rest)* ]
args: [ $($args)* $arg ]
vars: [ $($vars)* [$arg] ]
body: [ $body ]
}
};
(
action: [ @stmt $guard:ident ]
async: [ $async:tt ]
move: [ $($move:tt)? ]
mut: [ $($mut:tt)? ]
rest: [ $(,)* ]
args: [ $($args:ident)* ]
vars: [ $([$($vars:tt)*])* ]
body: [ $body:expr ]
) => {
let $($mut)? $guard = $crate::utils::guard::ContextGuard::<$async, _, _>::new(
( $($args),* ),
$($move)? |#[allow(unused_parens, unused_mut)] ( $($($vars)*),* )| $body
);
#[allow(unused_parens, unused_variables, clippy::toplevel_ref_arg)]
let ( $(ref $($vars)*),* ) = *$guard;
};
(
action: [ @expr ]
async: [ $async:tt ]
move: [ $($move:tt)? ]
mut: [ $($mut:tt)? ]
rest: [ $(,)* ]
args: [ $($args:ident)* ]
vars: [ $([$($vars:tt)*])* ]
body: [ $body:expr ]
) => {
$crate::utils::guard::ContextGuard::<$async, _, _>::new(
( $($args),* ),
$($move)? |#[allow(unused_parens)] ( $($($vars)*),* )| $body
)
};
}
/// Creates a [`ContextGuard`] object, binding it to a variable with the specified name (e.g., `_guard`).
/// Context variables specified in the macro invocation are available within and after the guard body.
///
/// **Note:** For usage with `panic!` or `loop`, see the [module-level documentation](self)
/// regarding type inference deadlocks.
#[macro_export]
macro_rules! guarded {
( $($tt:tt)* ) => {
$crate::__guarded! { @parse@action $($tt)* }
};
}
/// Creates a [`ContextGuard`] object, without binding it to a variable.
/// Context variables specified in the macro invocation are available within the guard body.
///
/// **Note:** For usage with `panic!` or `loop`, see the [module-level documentation](self)
/// regarding type inference deadlocks.
#[macro_export]
macro_rules! guard {
( $($tt:tt)* ) => {
$crate::__guarded! { @parse@async action: [ @expr ] ; $($tt)* }
};
}
// endregion
/// Alias for [`guarded!`].
///
/// **Note:** For usage with `panic!` or `loop`, see the [module-level documentation](self)
/// regarding type inference deadlocks.
#[macro_export]
macro_rules! defer {
( $($tt:tt)* ) => {
$crate::guarded! { $($tt)* }
};
}
#[cfg(test)]
mod tests {
use std::panic::catch_unwind;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::oneshot;
#[test]
fn trigger_sync_executes_once() {
let called = Arc::new(AtomicUsize::new(0));
let observed = Arc::new(AtomicUsize::new(0));
let value = 7usize;
let guard = {
let called = called.clone();
let observed = observed.clone();
crate::guard!(move [value] {
called.fetch_add(1, Ordering::SeqCst);
observed.store(value, Ordering::SeqCst);
})
};
guard.trigger();
assert_eq!(called.load(Ordering::SeqCst), 1);
assert_eq!(observed.load(Ordering::SeqCst), 7);
}
#[test]
fn defuse_sync_returns_context_without_running_guard() {
let called = Arc::new(AtomicUsize::new(0));
let value = String::from("hello");
let guard = {
let called = called.clone();
crate::guard!(move [mut value] {
value.push_str(" world");
called.fetch_add(1, Ordering::SeqCst);
})
};
let context = guard.defuse();
assert_eq!(context, "hello");
assert_eq!(called.load(Ordering::SeqCst), 0);
}
#[test]
fn drop_sync_triggers_guard() {
let called = Arc::new(AtomicUsize::new(0));
{
let called = called.clone();
crate::guarded!([called] {
called.fetch_add(1, Ordering::SeqCst);
});
}
assert_eq!(called.load(Ordering::SeqCst), 1);
}
#[test]
fn drop_propagates_guard_panic() {
let dropped = catch_unwind(|| {
guarded! {
sync {
panic!("boom");
}
}
});
assert!(dropped.is_err());
}
#[tokio::test]
async fn trigger_async_returns_runnable_task() {
let called = Arc::new(AtomicUsize::new(0));
let value = 5usize;
let guard = {
let called = called.clone();
crate::guard!(move [value] async move {
called.fetch_add(value, Ordering::SeqCst);
})
};
let task = guard.trigger();
task.await;
assert_eq!(called.load(Ordering::SeqCst), 5);
}
#[tokio::test]
async fn drop_async_detaches_task() {
let (tx, rx) = oneshot::channel();
{
let mut tx = Some(tx);
let value = 9usize;
let _guard = crate::guard!(move [value] {
let tx = tx.take();
async move {
if let Some(tx) = tx {
let _ = tx.send(value);
}
}
});
}
let value = tokio::time::timeout(Duration::from_secs(1), rx)
.await
.expect("detached task should run")
.expect("detached task should send value");
assert_eq!(value, 9);
}
#[tokio::test]
async fn defuse_async_does_not_execute() {
let called = Arc::new(AtomicUsize::new(0));
let value = 11usize;
let guard = {
let called = called.clone();
crate::guard!(move [value] async move {
called.fetch_add(value, Ordering::SeqCst);
})
};
let context = guard.defuse();
assert_eq!(context, 11);
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(called.load(Ordering::SeqCst), 0);
}
#[test]
fn guarded_named_mut_binding_updates_context_before_drop() {
let committed = Arc::new(AtomicUsize::new(0));
{
let value = 1usize;
let step = 2usize;
let committed = committed.clone();
crate::guarded!(scope_guard => [mut value, step] {
committed.store(value + step, Ordering::SeqCst);
});
*value += 10;
assert_eq!(*value, 11);
assert_eq!(*step, 2);
drop(scope_guard);
}
assert_eq!(committed.load(Ordering::SeqCst), 13);
}
#[test]
fn guard_expression_parses_without_braces() {
let observed = Arc::new(AtomicUsize::new(0));
let value = 3usize;
let observed_clone = observed.clone();
let guard = crate::guard!([value] observed_clone.store(value, Ordering::SeqCst));
guard.trigger();
assert_eq!(observed.load(Ordering::SeqCst), 3);
}
#[test]
fn defer_alias_behaves_like_guarded_statement() {
let called = Arc::new(AtomicUsize::new(0));
{
let n = 42usize;
let called = called.clone();
crate::defer!([n] {
called.store(n, Ordering::SeqCst);
});
}
assert_eq!(called.load(Ordering::SeqCst), 42);
}
#[tokio::test]
async fn guard_and_guarded_macro_usage_matrix() {
// 1) guard!: block body + trailing comma args + trigger()
let sink = Arc::new(AtomicUsize::new(0));
let v = 1usize;
let sink_clone = sink.clone();
let g1 = crate::guard!([v,] {
sink_clone.store(v, Ordering::SeqCst);
});
g1.trigger();
assert_eq!(sink.load(Ordering::SeqCst), 1);
// 2) guard!: expression body (no braces)
let sink = Arc::new(AtomicUsize::new(0));
let sink_clone = sink.clone();
let v = 2usize;
let g2 = crate::guard!([v] sink_clone.store(v, Ordering::SeqCst));
g2.trigger();
assert_eq!(sink.load(Ordering::SeqCst), 2);
// 3) guard!: explicit sync + no args form
let sink = Arc::new(AtomicUsize::new(0));
let sink_clone = sink.clone();
let g3 = crate::guard!(sync {
sink_clone.store(3, Ordering::SeqCst);
});
g3.trigger();
assert_eq!(sink.load(Ordering::SeqCst), 3);
// 4) guard!: move capture + defuse() prevents execution
let sink = Arc::new(AtomicUsize::new(0));
let owned = String::from("owned");
let sink_clone = sink.clone();
let g4 = crate::guard!(move [owned] {
if owned == "owned" {
sink_clone.store(4, Ordering::SeqCst);
}
});
let context = g4.defuse();
assert_eq!(context, "owned");
assert_eq!(sink.load(Ordering::SeqCst), 0);
// 5) guard!: async block inference + trigger() returns task
let sink = Arc::new(AtomicUsize::new(0));
let sink_clone = sink.clone();
let n = 5usize;
let g5 = crate::guard!([n] async move {
sink_clone.fetch_add(n, Ordering::SeqCst);
});
g5.trigger().await;
assert_eq!(sink.load(Ordering::SeqCst), 5);
// 6) guarded!: named binding + mut arg visible outside + explicit drop
let sink = Arc::new(AtomicUsize::new(0));
{
let value = 6usize;
let delta = 1usize;
let sink_clone = sink.clone();
crate::guarded!(named => [mut value, delta] {
sink_clone.store(value + delta, Ordering::SeqCst);
});
*value += 10;
assert_eq!(*value, 16);
assert_eq!(*delta, 1);
drop(named);
}
assert_eq!(sink.load(Ordering::SeqCst), 17);
// 7) guarded!: unnamed statement + expression body + implicit drop at scope end
let sink = Arc::new(AtomicUsize::new(0));
{
let n = 7usize;
let sink_clone = sink.clone();
crate::guarded!([n] sink_clone.store(n, Ordering::SeqCst));
}
assert_eq!(sink.load(Ordering::SeqCst), 7);
// 8) guarded!: explicit sync + panic path propagates on drop
let dropped = catch_unwind(|| {
guarded! {
sync {
panic!("matrix-boom");
}
}
});
assert!(dropped.is_err());
// 9) guarded!: async inference on drop detaches and executes
let (tx, rx) = oneshot::channel();
{
let tx = Some(tx);
crate::guarded!([mut tx] {
let tx = tx.take();
async move {
if let Some(tx) = tx {
let _ = tx.send(9usize);
}
}
});
}
let detached = tokio::time::timeout(Duration::from_secs(1), rx)
.await
.expect("detached task should complete")
.expect("detached task should send value");
assert_eq!(detached, 9);
}
}
-1
View File
@@ -1,4 +1,3 @@
pub mod guard;
pub mod panic;
pub mod string;
pub mod task;
-283
View File
@@ -1,7 +1,5 @@
use crate::utils::guard::ContextGuard;
use std::future::Future;
use std::io;
use std::ops::DerefMut;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
@@ -80,284 +78,3 @@ impl<Output> Future for CancellableTask<Output> {
}
// endregion
// region DetachableTask
/// A pinned, heap-allocated task.
///
/// **Why Box?** Heap allocation is required because if the task detaches,
/// it outlives the current stack frame. `Pin<Box<_>>` ensures its memory address
/// remains completely stable during and after the transfer.
type BoxTask<Task> = Pin<Box<Task>>;
struct DetachableTaskContext<Spawner, Task> {
spawner: Spawner,
task: Option<BoxTask<Task>>,
}
type DetachableTaskGuardHelper<Context> = ContextGuard<false, Context, fn(Context)>;
type DetachableTaskGuard<Spawner, Task> =
DetachableTaskGuardHelper<DetachableTaskContext<Spawner, Task>>;
/// A task wrapper that executes inline but automatically detaches to a background spawner
/// if the current execution context is interrupted or dropped.
///
/// `DetachableTask` ensures anti-cancellation. If the outer future is dropped (e.g., due to
/// a timeout or a `select!` branch failing), the underlying unfinished task is seamlessly
/// transferred to a background executor via an RAII guard.
///
/// # Advantages over `tokio::spawn` + `.await JoinHandle`
///
/// 1. **Zero Initial Scheduling Overhead**: Prioritizes inline execution. If the task
/// completes before being interrupted, it entirely bypasses the runtime's scheduling queue,
/// eliminating queuing latency and context-switching CPU costs. Spawning is strictly a fallback.
///
/// 2. **Context Locality**: Before detachment, the task is polled directly by the caller's thread.
/// This implicitly preserves the current execution context, including thread-local storage (TLS),
/// Tokio `task_local!` variables, and `tracing` spans, which would otherwise be immediately
/// lost or require explicit propagation across task boundaries.
pub struct DetachableTask<Spawner, Task> {
guard: DetachableTaskGuard<Spawner, Task>,
}
impl<Spawner, Task> DetachableTask<Spawner, Task> {
pub fn detach(self) {
self.guard.trigger()
}
pub fn reclaim(self) -> BoxTask<Task> {
self.guard.defuse().task.unwrap()
}
}
pub type TaskSpawner<Task, R = JoinHandle<<Task as Future>::Output>> = fn(BoxTask<Task>) -> R;
impl DetachableTask<fn(()), ()> {
pub fn with_spawner<Spawner, _R, Task>(
spawner: Spawner,
task: Task,
) -> DetachableTask<Spawner, Task>
where
Spawner: FnOnce(BoxTask<Task>) -> _R,
{
let context = DetachableTaskContext {
spawner,
task: Some(Box::pin(task)),
};
DetachableTask {
guard: crate::guard!([context] if let Some(task) = context.task {
(context.spawner)(task);
}),
}
}
pub fn new<Task>(task: Task) -> DetachableTask<TaskSpawner<Task>, Task>
where
Task: Future + Send + 'static,
<Task as Future>::Output: Send + 'static,
{
Self::with_spawner(|task| tokio::runtime::Handle::current().spawn(task), task)
}
}
impl<Spawner: FnOnce(BoxTask<Task>) -> _R, _R, Task> IntoFuture for DetachableTask<Spawner, Task>
where
Task: Future,
{
type Output = Task::Output;
type IntoFuture = DetachableTaskFuture<Spawner, Task>;
fn into_future(self) -> Self::IntoFuture {
DetachableTaskFuture { guard: self.guard }
}
}
pub struct DetachableTaskFuture<Spawner, Task> {
guard: DetachableTaskGuard<Spawner, Task>,
}
impl<Spawner: FnOnce(BoxTask<Task>) -> _R, _R, Task> Future for DetachableTaskFuture<Spawner, Task>
where
Task: Future,
{
type Output = Task::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// SAFETY:
// 1. We only access the outer struct's unpinned fields.
// 2. The inner task remains securely pinned on the heap via `BoxTask<Task>`.
// 3. We never expose a mutable, unpinned reference to the underlying task.
let this = unsafe { self.get_unchecked_mut() };
let context = this.guard.deref_mut();
let mut task = context.task.take().expect("polled after completion");
let poll = task.as_mut().poll(cx);
if poll.is_pending() {
context.task = Some(task);
}
poll
}
}
// endregion
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
#[tokio::test]
async fn spawn_when_dropped() {
let spawned = Arc::new(AtomicBool::new(false));
{
let spawned = spawned.clone();
let _task = DetachableTask::new(async move {
spawned.store(true, Ordering::SeqCst);
});
}
tokio::time::timeout(Duration::from_secs(1), async {
while !spawned.load(Ordering::SeqCst) {
tokio::task::yield_now().await;
}
})
.await
.expect("task should be spawned on drop");
}
#[tokio::test]
async fn await_completed_task_does_not_detach() {
let spawn_count = Arc::new(AtomicUsize::new(0));
let result = {
let spawn_count = spawn_count.clone();
DetachableTask::with_spawner(
move |_| {
spawn_count.fetch_add(1, Ordering::SeqCst);
},
async { 7usize },
)
.await
};
assert_eq!(result, 7);
assert_eq!(spawn_count.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn drop_without_await_and_runs_once() {
let spawn_count = Arc::new(AtomicUsize::new(0));
let (done_tx, done_rx) = oneshot::channel();
{
let spawn_count = spawn_count.clone();
let _task = DetachableTask::with_spawner(
move |f| {
spawn_count.fetch_add(1, Ordering::SeqCst);
tokio::spawn(async move {
let result = f.await;
let _ = done_tx.send(result);
});
},
async { 42usize },
);
}
let detached_result = tokio::time::timeout(Duration::from_secs(1), done_rx)
.await
.expect("detached task should finish")
.expect("detached task should send result");
assert_eq!(detached_result, 42);
assert_eq!(spawn_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn drop_after_await_still_detaches() {
let spawn_count = Arc::new(AtomicUsize::new(0));
let (value_tx, mut value_rx) = mpsc::channel(4);
let (done_tx, done_rx) = oneshot::channel();
let handle = {
let future = async move {
let mut sum = 0;
while let Some(value) = value_rx.recv().await {
sum += value;
}
sum
};
let spawn_count = spawn_count.clone();
let task = DetachableTask::with_spawner(
move |f| {
spawn_count.fetch_add(1, Ordering::SeqCst);
tokio::spawn(async move {
let result = f.await;
let _ = done_tx.send(result);
});
},
future,
);
tokio::spawn(task.into_future())
};
value_tx
.send(10)
.await
.expect("value receiver should still exist");
handle.abort();
value_tx
.send(11)
.await
.expect("value receiver should still exist");
drop(value_tx);
let detached_result = tokio::time::timeout(Duration::from_secs(1), done_rx)
.await
.expect("detached polled task should finish")
.expect("detached polled task should send result");
assert_eq!(detached_result, 21);
assert_eq!(spawn_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn panic_during_inline_poll_does_not_detach_on_drop() {
struct PanicOnPollFuture {
poll_count: Arc<AtomicUsize>,
}
impl Future for PanicOnPollFuture {
type Output = ();
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
self.poll_count.fetch_add(1, Ordering::SeqCst);
panic!("panic during inline poll")
}
}
let poll_count = Arc::new(AtomicUsize::new(0));
let detach_count = Arc::new(AtomicUsize::new(0));
let task = {
let detach_count = detach_count.clone();
DetachableTask::with_spawner(
move |_| {
detach_count.fetch_add(1, Ordering::SeqCst);
},
PanicOnPollFuture {
poll_count: poll_count.clone(),
},
)
};
let err = tokio::spawn(task.into_future())
.await
.expect_err("inline poll panic should propagate");
assert!(err.is_panic());
assert_eq!(poll_count.load(Ordering::SeqCst), 1);
assert_eq!(detach_count.load(Ordering::SeqCst), 0);
}
}
+63 -5
View File
@@ -2,13 +2,17 @@ use std::sync::Arc;
use crate::{
common::{
config::TomlConfigLoader, global_ctx::GlobalCtx, log, os_info::collect_device_os_info,
set_default_machine_id, stun::MockStunInfoCollector,
config::TomlConfigLoader,
global_ctx::{ArcGlobalCtx, GlobalCtx},
log,
os_info::collect_device_os_info,
set_default_machine_id,
stun::MockStunInfoCollector,
},
connector::create_connector_by_url,
instance_manager::{DaemonGuard, NetworkInstanceManager},
proto::common::NatType,
tunnel::{IpVersion, TunnelConnector},
tunnel::{IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelScheme},
};
use anyhow::{Context as _, Result};
use async_trait::async_trait;
@@ -49,6 +53,30 @@ pub struct WebClient {
connected: Arc<AtomicBool>,
}
struct ConfigServerConnector {
url: Url,
global_ctx: ArcGlobalCtx,
}
#[async_trait]
impl TunnelConnector for ConfigServerConnector {
async fn connect(&mut self) -> std::result::Result<Box<dyn Tunnel>, TunnelError> {
let mut connector =
create_connector_by_url(self.url.as_str(), &self.global_ctx, IpVersion::Both)
.await
.map_err(|err| match err {
crate::common::error::Error::TunnelError(err) => err,
err => TunnelError::Anyhow(err.into()),
})?;
connector.connect().await
}
fn remote_url(&self) -> Url {
self.url.clone()
}
}
impl WebClient {
pub fn new<T: TunnelConnector + 'static, S: ToString, H: ToString>(
connector: T,
@@ -218,6 +246,13 @@ pub async fn run_web_client(
.with_context(|| "failed to parse config server URL")?,
};
TunnelScheme::try_from(&config_server_url).map_err(|_| {
anyhow::anyhow!(
"unsupported config server scheme: {}",
config_server_url.scheme()
)
})?;
let mut c_url = config_server_url.clone();
if !matches!(c_url.scheme(), "ws" | "wss") {
c_url.set_path("");
@@ -243,16 +278,20 @@ pub async fn run_web_client(
let mut flags = global_ctx.get_flags();
flags.bind_device = false;
global_ctx.set_flags(flags);
let hostname = match hostname {
None => gethostname::gethostname().to_string_lossy().to_string(),
Some(hostname) => hostname,
};
Ok(WebClient::new(
create_connector_by_url(c_url.as_str(), &global_ctx, IpVersion::Both).await?,
ConfigServerConnector {
url: c_url,
global_ctx,
},
token.to_string(),
hostname,
secure_mode,
manager.clone(),
manager,
hooks,
))
}
@@ -292,4 +331,23 @@ mod tests {
assert!(sleep_finish.load(std::sync::atomic::Ordering::Relaxed));
println!("Manager stopped.");
}
#[tokio::test]
async fn test_run_web_client_with_unreachable_config_server() {
let manager = Arc::new(NetworkInstanceManager::new());
let client = super::run_web_client(
"udp://config-server.invalid:22020/test",
None,
None,
false,
manager,
None,
)
.await
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!client.is_connected());
drop(client);
}
}