Add WireGuard Client to Readme (#44)

* Add README for Wireguard Client

* add default protocol flag

* wireguard connector support bind device
This commit is contained in:
Sijie.Sun
2024-03-31 21:10:59 +08:00
committed by GitHub
parent 05cabb2651
commit 25a7603990
14 changed files with 281 additions and 46 deletions
+26
View File
@@ -45,6 +45,9 @@ pub trait ConfigLoader: Send + Sync {
fn get_vpn_portal_config(&self) -> Option<VpnPortalConfig>;
fn set_vpn_portal_config(&self, config: VpnPortalConfig);
fn get_flags(&self) -> Flags;
fn set_flags(&self, flags: Flags);
fn dump(&self) -> String;
}
@@ -96,6 +99,14 @@ pub struct VpnPortalConfig {
pub wireguard_listen: SocketAddr,
}
// Flags is used to control the behavior of the program
#[derive(derivative::Derivative, Deserialize, Serialize)]
#[derivative(Debug, Clone, PartialEq, Default)]
pub struct Flags {
#[derivative(Default(value = "\"tcp\".to_string()"))]
pub default_protocol: String,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
struct Config {
netns: Option<String>,
@@ -114,6 +125,8 @@ struct Config {
rpc_portal: Option<SocketAddr>,
vpn_portal_config: Option<VpnPortalConfig>,
flags: Option<Flags>,
}
#[derive(Debug, Clone)]
@@ -332,6 +345,19 @@ impl ConfigLoader for TomlConfigLoader {
self.config.lock().unwrap().vpn_portal_config = Some(config);
}
fn get_flags(&self) -> Flags {
self.config
.lock()
.unwrap()
.flags
.clone()
.unwrap_or_default()
}
fn set_flags(&self, flags: Flags) {
self.config.lock().unwrap().flags = Some(flags);
}
fn dump(&self) -> String {
toml::to_string_pretty(&*self.config.lock().unwrap()).unwrap()
}
+5 -1
View File
@@ -4,7 +4,7 @@ use crate::rpc::PeerConnInfo;
use crossbeam::atomic::AtomicCell;
use super::{
config::ConfigLoader,
config::{ConfigLoader, Flags},
netns::NetNS,
network::IPCollector,
stun::{StunInfoCollector, StunInfoCollectorTrait},
@@ -199,6 +199,10 @@ impl GlobalCtx {
pub fn get_vpn_portal_cidr(&self) -> Option<cidr::Ipv4Cidr> {
self.config.get_vpn_portal_config().map(|x| x.client_cidr)
}
pub fn get_flags(&self) -> Flags {
self.config.get_flags()
}
}
#[cfg(test)]
+7 -1
View File
@@ -246,10 +246,16 @@ impl DirectConnectorManager {
.filter_map(|l| if l.scheme() != "ring" { Some(l) } else { None })
.collect::<Vec<_>>();
let listener = available_listeners
let mut listener = available_listeners
.get(0)
.ok_or(anyhow::anyhow!("peer {} have no listener", dst_peer_id))?;
// if have default listener, use it first
listener = available_listeners
.iter()
.find(|l| l.scheme() == data.global_ctx.get_flags().default_protocol)
.unwrap_or(listener);
let mut tasks = JoinSet::new();
ip_list.interface_ipv4s.iter().for_each(|ip| {
let addr = format!(
+9 -2
View File
@@ -78,11 +78,18 @@ pub async fn create_connector_by_url(
return Ok(Box::new(connector));
}
"wg" => {
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg")?;
let dst_addr =
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg")?;
let nid = global_ctx.get_network_identity();
let wg_config =
WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret);
let connector = WgTunnelConnector::new(url, wg_config);
let mut connector = WgTunnelConnector::new(url, wg_config);
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
return Ok(Box::new(connector));
}
_ => {
+9
View File
@@ -114,6 +114,9 @@ example: wg://0.0.0.0:11010/10.14.14.0/24, means the vpn portal is a wireguard s
and the vpn client is in network of 10.14.14.0/24"
)]
vpn_portal: Option<String>,
#[arg(long, help = "default protocol to use when connecting to peers")]
default_protocol: Option<String>,
}
impl From<Cli> for TomlConfigLoader {
@@ -238,6 +241,12 @@ impl From<Cli> for TomlConfigLoader {
});
}
if cli.default_protocol.is_some() {
let mut f = cfg.get_flags();
f.default_protocol = cli.default_protocol.as_ref().unwrap().clone();
cfg.set_flags(f);
}
cfg
}
}
+1 -1
View File
@@ -204,7 +204,7 @@ impl UdpNatEntry {
else {
break;
};
ip_id += 1;
ip_id = ip_id.wrapping_add(1);
}
self.stop();
+11
View File
@@ -8,6 +8,7 @@ use std::{
time::{Duration, SystemTime},
};
use crossbeam::atomic::AtomicCell;
use futures::Future;
use tokio::{
sync::{Mutex, RwLock},
@@ -37,6 +38,7 @@ static SERVICE_ID: u32 = 5;
struct PeridicJobCtx<T> {
peer_mgr: Arc<PeerManager>,
center_peer: AtomicCell<PeerId>,
job_ctx: T,
}
@@ -81,6 +83,7 @@ impl PeerCenterBase {
async move {
let ctx = Arc::new(PeridicJobCtx {
peer_mgr: peer_mgr.clone(),
center_peer: AtomicCell::new(PeerId::default()),
job_ctx,
});
loop {
@@ -89,6 +92,7 @@ impl PeerCenterBase {
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
};
ctx.center_peer.store(center_peer.clone());
tracing::trace!(?center_peer, "run periodic job");
let rpc_mgr = peer_mgr.get_peer_rpc_mgr();
let _g = lock.lock().await;
@@ -226,11 +230,13 @@ impl PeerCenterInstance {
service: PeerManagerRpcService,
need_send_peers: AtomicBool,
last_report_peers: Mutex<PeerInfoForGlobalMap>,
last_center_peer: AtomicCell<PeerId>,
}
let ctx = Arc::new(Ctx {
service: PeerManagerRpcService::new(self.peer_mgr.clone()),
need_send_peers: AtomicBool::new(true),
last_report_peers: Mutex::new(PeerInfoForGlobalMap::default()),
last_center_peer: AtomicCell::new(PeerId::default()),
});
self.client
@@ -241,6 +247,10 @@ impl PeerCenterInstance {
let mut peers = PeerInfoForGlobalMap::default();
for _ in 1..10 {
peers = ctx.job_ctx.service.list_peers().await.into();
if ctx.center_peer.load() != ctx.job_ctx.last_center_peer.load() {
// if center peer changed, report peers immediately
break;
}
if peers == *ctx.job_ctx.last_report_peers.lock().await {
return Ok(3000);
}
@@ -276,6 +286,7 @@ impl PeerCenterInstance {
return Ok(500);
}
ctx.job_ctx.last_center_peer.store(ctx.center_peer.load());
ctx.job_ctx.need_send_peers.store(false, Ordering::Relaxed);
Ok(3000)
})
+1 -1
View File
@@ -231,7 +231,7 @@ impl PeerConnPinger {
};
});
req_seq += 1;
req_seq = req_seq.wrapping_add(1);
tokio::time::sleep(Duration::from_millis(1000)).await;
}
});
+7 -10
View File
@@ -551,16 +551,13 @@ impl UdpTunnelConnector {
let mut futures = FuturesUnordered::new();
for bind_addr in self.bind_addrs.iter() {
let socket = UdpSocket::bind(*bind_addr).await?;
// linux does not use interface of bind_addr to send packet, so we need to bind device
// mac can handle this with bind correctly
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
tracing::trace!(dev_name = ?dev_name, "bind device");
socket.bind_device(Some(dev_name.as_bytes()))?;
}
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(*bind_addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_sokcet2(&socket2_socket, &bind_addr)?;
let socket = UdpSocket::from_std(socket2_socket.into())?;
futures.push(self.try_connect_with_socket(socket));
}
+81 -28
View File
@@ -16,7 +16,7 @@ use boringtun::{
x25519::{PublicKey, StaticSecret},
};
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use rand::RngCore;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
@@ -559,6 +559,8 @@ pub struct WgTunnelConnector {
addr: url::Url,
config: WgConfig,
udp: Option<Arc<UdpSocket>>,
bind_addrs: Vec<SocketAddr>,
}
impl Debug for WgTunnelConnector {
@@ -576,6 +578,7 @@ impl WgTunnelConnector {
addr,
config,
udp: None,
bind_addrs: vec![],
}
}
@@ -609,20 +612,20 @@ impl WgTunnelConnector {
keepalive.into()
}
}
#[async_trait]
impl super::TunnelConnector for WgTunnelConnector {
#[tracing::instrument]
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg")?;
tracing::warn!("wg connect: {:?}", self.addr);
let udp = UdpSocket::bind("0.0.0.0:0").await?;
#[tracing::instrument(skip(config))]
async fn connect_with_socket(
addr_url: url::Url,
config: WgConfig,
udp: UdpSocket,
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&addr_url, "wg")?;
tracing::warn!("wg connect: {:?}", addr);
let local_addr = udp.local_addr().unwrap().to_string();
let mut my_tun = Tunn::new(
self.config.my_secret_key.clone(),
self.config.peer_public_key.clone(),
config.my_secret_key.clone(),
config.peer_public_key.clone(),
None,
None,
rand::thread_rng().next_u32(),
@@ -638,7 +641,7 @@ impl super::TunnelConnector for WgTunnelConnector {
let keepalive = Self::parse_handshake_resp(&mut my_tun, &buf[..n]);
udp.send_to(&keepalive, addr).await?;
let mut wg_peer = WgPeer::new(Arc::new(udp), self.config.clone(), addr);
let mut wg_peer = WgPeer::new(Arc::new(udp), config.clone(), addr);
let tunnel = wg_peer.start_and_get_tunnel();
let data = wg_peer.data.as_ref().unwrap().clone();
@@ -659,16 +662,57 @@ impl super::TunnelConnector for WgTunnelConnector {
info: TunnelInfo {
tunnel_type: "wg".to_owned(),
local_addr: super::build_url_from_socket_addr(&local_addr, "wg").into(),
remote_addr: self.remote_url().into(),
remote_addr: addr_url.to_string(),
},
});
Ok(ret)
}
}
#[async_trait]
impl super::TunnelConnector for WgTunnelConnector {
#[tracing::instrument]
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
let bind_addrs = if self.bind_addrs.is_empty() {
vec!["0.0.0.0:0".parse().unwrap()]
} else {
self.bind_addrs.clone()
};
let mut futures = FuturesUnordered::new();
for bind_addr in bind_addrs.into_iter() {
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(bind_addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_sokcet2(&socket2_socket, &bind_addr)?;
let socket = UdpSocket::from_std(socket2_socket.into())?;
tracing::info!(?bind_addr, ?self.addr, "prepare wg connect task");
futures.push(Self::connect_with_socket(
self.addr.clone(),
self.config.clone(),
socket,
));
}
let Some(ret) = futures.next().await else {
return Err(super::TunnelError::CommonError(
"join connect futures failed".to_owned(),
));
};
return ret;
}
fn remote_url(&self) -> url::Url {
self.addr.clone()
}
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
self.bind_addrs = addrs;
}
}
#[cfg(test)]
@@ -676,19 +720,7 @@ pub mod tests {
use boringtun::*;
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
use crate::tunnels::wireguard::*;
pub fn enable_log() {
let filter = tracing_subscriber::EnvFilter::builder()
.with_default_directive(tracing::level_filters::LevelFilter::DEBUG.into())
.from_env()
.unwrap()
.add_directive("tarpc=error".parse().unwrap());
tracing_subscriber::fmt::fmt()
.pretty()
.with_env_filter(filter)
.init();
}
use crate::tunnels::{wireguard::*, TunnelConnector};
pub fn create_wg_config() -> (WgConfig, WgConfig) {
let my_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng());
@@ -717,7 +749,7 @@ pub mod tests {
}
#[tokio::test]
async fn test_wg() {
async fn wg_pingpong() {
let (server_cfg, client_cfg) = create_wg_config();
let listener = WgTunnelListener::new("wg://0.0.0.0:5599".parse().unwrap(), server_cfg);
let connector = WgTunnelConnector::new("wg://127.0.0.1:5599".parse().unwrap(), client_cfg);
@@ -725,10 +757,31 @@ pub mod tests {
}
#[tokio::test]
async fn udp_bench() {
async fn wg_bench() {
let (server_cfg, client_cfg) = create_wg_config();
let listener = WgTunnelListener::new("wg://0.0.0.0:5598".parse().unwrap(), server_cfg);
let connector = WgTunnelConnector::new("wg://127.0.0.1:5598".parse().unwrap(), client_cfg);
_tunnel_bench(listener, connector).await
}
#[tokio::test]
async fn wg_bench_with_bind() {
let (server_cfg, client_cfg) = create_wg_config();
let listener = WgTunnelListener::new("wg://127.0.0.1:5597".parse().unwrap(), server_cfg);
let mut connector =
WgTunnelConnector::new("wg://127.0.0.1:5597".parse().unwrap(), client_cfg);
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
#[should_panic]
async fn wg_bench_with_bind_fail() {
let (server_cfg, client_cfg) = create_wg_config();
let listener = WgTunnelListener::new("wg://127.0.0.1:5596".parse().unwrap(), server_cfg);
let mut connector =
WgTunnelConnector::new("wg://127.0.0.1:5596".parse().unwrap(), client_cfg);
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
}
+7
View File
@@ -218,7 +218,14 @@ impl VpnPortal for WireGuard {
}
async fn dump_client_config(&self, peer_mgr: Arc<PeerManager>) -> String {
if self.inner.is_none() {
return "ERROR: Wireguard VPN Portal Not Started".to_string();
}
let global_ctx = self.inner.as_ref().unwrap().global_ctx.clone();
if global_ctx.config.get_vpn_portal_config().is_none() {
return "ERROR: VPN Portal Config Not Set".to_string();
}
let routes = peer_mgr.list_routes().await;
let mut allow_ips = routes
.iter()