diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index ee84c5e0..b77dc014 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -37,7 +37,7 @@ tracing-subscriber = { version = "0.3", features = [ "time", ] } derivative = "2.2.0" -derive_more = {version = "2.1.1", features = ["full"]} +derive_more = { version = "2.1.1", features = ["full"] } console-subscriber = { version = "0.4.1", optional = true } indoc = "2.0.7" regex = "1.8" @@ -79,12 +79,12 @@ quinn = { version = "0.11.8", optional = true, features = ["ring"] } quinn-plaintext = { version = "0.3.0", optional = true } rustls = { version = "0.23.0", features = [ - "ring","tls12" + "ring", "tls12" ], default-features = false, optional = true } rcgen = { version = "0.12.1", optional = true } # for websocket -tokio-websockets = { version = "0.8", optional = true, features = [ +tokio-websockets = { version = "0.13.2", optional = true, features = [ "rustls-webpki-roots", "client", "server", @@ -94,6 +94,7 @@ tokio-websockets = { version = "0.8", optional = true, features = [ http = { version = "1", default-features = false, features = [ "std", ], optional = true } +forwarded-header-value = { version = "0.1.1", optional = true } tokio-rustls = { version = "0.26", default-features = false, optional = true } # for tap device @@ -387,6 +388,7 @@ tun = ["dep:tun"] websocket = [ "dep:tokio-websockets", "dep:http", + "dep:forwarded-header-value", "dep:tokio-rustls", "dep:rustls", "dep:rcgen", diff --git a/easytier/src/tunnel/websocket.rs b/easytier/src/tunnel/websocket.rs index bf644210..71b9d01e 100644 --- a/easytier/src/tunnel/websocket.rs +++ b/easytier/src/tunnel/websocket.rs @@ -1,14 +1,17 @@ -use std::{net::SocketAddr, sync::Arc, time::Duration}; - use anyhow::Context; use bytes::BytesMut; +use forwarded_header_value::ForwardedHeaderValue; use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; +use pnet::ipnetwork::IpNetwork; +use std::sync::LazyLock; +use std::{net::SocketAddr, sync::Arc, time::Duration}; use tokio::{ net::{TcpListener, TcpSocket, TcpStream}, time::timeout, }; use tokio_rustls::TlsAcceptor; -use tokio_websockets::{ClientBuilder, Limits, MaybeTlsStream, Message}; +use tokio_util::either::Either; +use tokio_websockets::{ClientBuilder, Limits, MaybeTlsStream, Message, ServerBuilder}; use zerocopy::AsBytes; use super::TunnelInfo; @@ -59,6 +62,20 @@ async fn map_from_ws_message( ))) } +static TRUSTED_PROXIES: LazyLock> = LazyLock::new(|| { + [ + "127.0.0.0/8", // IPV4 Loopback + "10.0.0.0/8", // IPV4 Private Networks + "172.16.0.0/12", + "192.168.0.0/16", + "::1/128", // IPV6 Loopback + "fc00::/7", // IPV6 Private network + ] + .into_iter() + .map(|s| s.parse().unwrap()) + .collect() +}); + #[derive(Debug)] pub struct WSTunnelListener { addr: url::Url, @@ -74,47 +91,69 @@ impl WSTunnelListener { } async fn try_accept(&self, stream: TcpStream) -> Result, TunnelError> { - let info = TunnelInfo { - tunnel_type: self.addr.scheme().to_owned(), - local_addr: Some(self.local_url().into()), - remote_addr: Some( - super::build_url_from_socket_addr( - &stream.peer_addr()?.to_string(), - self.addr.scheme().to_string().as_str(), - ) - .into(), - ), - }; + let mut remote_addr = stream.peer_addr()?; - let server_bulder = tokio_websockets::ServerBuilder::new().limits(Limits::unlimited()); - - let ret: Box = if is_wss(&self.addr)? { + let stream = if is_wss(&self.addr)? { init_crypto_provider(); let (certs, key) = get_insecure_tls_cert(); let config = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key) .with_context(|| "Failed to create server config")?; - let acceptor = TlsAcceptor::from(Arc::new(config)); - let stream = acceptor.accept(stream).await?; - let (write, read) = server_bulder.accept(stream).await?.split(); - - Box::new(TunnelWrapper::new( - read.filter_map(map_from_ws_message), - write.with(sink_from_zc_packet), - Some(info), - )) + let stream = TlsAcceptor::from(Arc::new(config)).accept(stream).await?; + Either::Left(stream) } else { - let (write, read) = server_bulder.accept(stream).await?.split(); - Box::new(TunnelWrapper::new( - read.filter_map(map_from_ws_message), - write.with(sink_from_zc_packet), - Some(info), - )) + Either::Right(stream) }; - Ok(ret) + let (request, stream) = ServerBuilder::new() + .limits(Limits::unlimited()) + .accept(stream) + .await?; + + if TRUSTED_PROXIES + .iter() + .any(|net| net.contains(remote_addr.ip())) + { + if let Some(forwarded) = request + .headers() + .get("Forwarded") + .and_then(|f| f.to_str().ok()) + .and_then(|f| ForwardedHeaderValue::from_forwarded(f).ok()) + .or_else(|| { + request + .headers() + .get("X-Forwarded-For") + .and_then(|f| f.to_str().ok()) + .and_then(|f| ForwardedHeaderValue::from_x_forwarded_for(f).ok()) + }) + { + if let Some(ip) = forwarded.remotest_forwarded_for_ip() { + remote_addr = SocketAddr::new(ip, 0); + } + } + } + + let (write, read) = stream.split(); + + let info = TunnelInfo { + tunnel_type: self.addr.scheme().to_owned(), + local_addr: Some(self.local_url().into()), + remote_addr: Some( + super::build_url_from_socket_addr( + &remote_addr.to_string(), + self.addr.scheme().to_string().as_str(), + ) + .into(), + ), + }; + + Ok(Box::new(TunnelWrapper::new( + read.filter_map(map_from_ws_message), + write.with(sink_from_zc_packet), + Some(info), + ))) } } @@ -292,9 +331,9 @@ impl TunnelConnector for WSTunnelConnector { #[cfg(test)] pub mod tests { + use super::*; use crate::tunnel::common::tests::_tunnel_pingpong; - use crate::tunnel::websocket::{WSTunnelConnector, WSTunnelListener}; - use crate::tunnel::{TunnelConnector, TunnelListener}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[rstest::rstest] #[tokio::test] @@ -345,4 +384,48 @@ pub mod tests { j.abort(); } + + #[tokio::test] + async fn ws_forwarded() { + let mut listener = WSTunnelListener::new("ws://127.0.0.1:25559".parse().unwrap()); + listener.listen().await.unwrap(); + + let server_task = tokio::spawn(async move { + let tunnel = listener.accept().await.unwrap(); + + let remote_addr = tunnel + .info() + .unwrap() + .remote_addr + .unwrap() + .url + .parse::() + .unwrap(); + + assert_eq!(remote_addr.host_str().unwrap(), "203.0.113.5"); + + tunnel + }); + + let mut stream = TcpStream::connect("127.0.0.1:25559").await.unwrap(); + + let handshake = "GET / HTTP/1.1\r\n\ + Host: 127.0.0.1:25559\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Version: 13\r\n\ + X-Forwarded-For: 203.0.113.5, 192.168.1.1\r\n\ + \r\n"; + + stream.write_all(handshake.as_bytes()).await.unwrap(); + + let mut buf = [0u8; 1024]; + let bytes_read = stream.read(&mut buf).await.unwrap(); + let response = String::from_utf8_lossy(&buf[..bytes_read]); + + assert!(response.contains("101 Switching Protocols")); + + let _tunnel = server_task.await.unwrap(); + } }