support respond stun request in udp tunnel (#484)

we can use this to help the hole punching. (getting public mapped address stablely)
This commit is contained in:
Sijie.Sun
2024-11-20 23:45:06 +08:00
committed by GitHub
parent 86600c6315
commit aed54f7318
3 changed files with 114 additions and 16 deletions
+75 -6
View File
@@ -3,11 +3,13 @@ use std::{
sync::{Arc, Weak},
};
use anyhow::Context;
use async_trait::async_trait;
use bytes::BytesMut;
use dashmap::DashMap;
use futures::{stream::FuturesUnordered, StreamExt};
use rand::{Rng, SeedableRng};
use zerocopy::AsBytes;
use std::net::SocketAddr;
use tokio::{
@@ -95,7 +97,60 @@ pub fn new_hole_punch_packet(tid: u32, buf_len: u16) -> ZCPacket {
)
}
fn get_zcpacket_from_buf(buf: BytesMut) -> Result<ZCPacket, TunnelError> {
fn is_stun_packet(b: &[u8]) -> bool {
// stun has following pattern:
// 1. first two bits are 0b00
// 2. magic cookie between 32-64 bits: 0x2112A442
b[4..8] == [0x21, 0x12, 0xA4, 0x42] && b[0] & 0xC0 == 0
}
async fn respond_stun_packet(
socket: Arc<UdpSocket>,
addr: SocketAddr,
req_buf: Vec<u8>,
) -> Result<(), anyhow::Error> {
use crate::common::stun_codec_ext::*;
use bytecodec::DecodeExt as _;
use bytecodec::EncodeExt as _;
use stun_codec::rfc5389::attributes::MappedAddress;
use stun_codec::rfc5389::methods::BINDING;
use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder};
let mut decoder = MessageDecoder::<Attribute>::new();
let req_msg = decoder
.decode_from_bytes(&req_buf)
.map_err(|e| anyhow::anyhow!("stun decode error: {:?}", e))?
.map_err(|e| anyhow::anyhow!("stun decode broken message error: {:?}", e))?;
let tid = req_msg.transaction_id();
// we only respond easytier stun req, whose tid has 0xdeadbeef prefix
if tid.as_bytes()[0..4] != [0xde, 0xad, 0xbe, 0xef] {
anyhow::bail!("stun req tid not from easytier");
}
let mut resp_msg = Message::<Attribute>::new(
MessageClass::SuccessResponse,
BINDING,
// we discard the prefix, make sure our implementation is not compatible with other stun client
u32_to_tid(tid_to_u32(&tid)),
);
resp_msg.add_attribute(Attribute::MappedAddress(MappedAddress::new(addr.clone())));
let mut encoder = MessageEncoder::new();
let rsp_buf = encoder
.encode_into_bytes(resp_msg.clone())
.map_err(|e| anyhow::anyhow!("stun encode error: {:?}", e))?;
socket
.send_to(&rsp_buf, addr.clone())
.await
.with_context(|| "send stun response error")?;
tracing::debug!(?addr, ?req_msg, "udp respond stun packet done");
Ok(())
}
fn get_zcpacket_from_buf(buf: BytesMut, allow_stun: bool) -> Result<ZCPacket, TunnelError> {
let dg_size = buf.len();
if dg_size < UDP_TUNNEL_HEADER_SIZE {
return Err(TunnelError::InvalidPacket(format!(
@@ -104,6 +159,10 @@ fn get_zcpacket_from_buf(buf: BytesMut) -> Result<ZCPacket, TunnelError> {
)));
}
if allow_stun && is_stun_packet(&buf[..UDP_TUNNEL_HEADER_SIZE]) {
return Ok(ZCPacket::new_from_buf(buf, ZCPacketType::UDP));
}
let zc_packet = ZCPacket::new_from_buf(buf, ZCPacketType::UDP);
let header = zc_packet.udp_tunnel_header().unwrap();
let payload_len = header.len.get() as usize;
@@ -154,7 +213,7 @@ async fn forward_from_ring_to_udp(
}
}
async fn udp_recv_from_socket_forward_task<F>(socket: Arc<UdpSocket>, mut f: F)
async fn udp_recv_from_socket_forward_task<F>(socket: Arc<UdpSocket>, allow_stun: bool, mut f: F)
where
F: FnMut(ZCPacket, SocketAddr) -> (),
{
@@ -175,7 +234,7 @@ where
dg_size
);
let zc_packet = match get_zcpacket_from_buf(buf.split()) {
let zc_packet = match get_zcpacket_from_buf(buf.split(), allow_stun) {
Ok(v) => v,
Err(e) => {
tracing::warn!(?e, "udp get zc packet from buf error");
@@ -337,6 +396,16 @@ impl UdpTunnelListenerData {
let header = zc_packet.udp_tunnel_header().unwrap();
if header.msg_type == UdpPacketType::Syn as u8 {
tokio::spawn(Self::handle_new_connect(self.clone(), addr, zc_packet));
} else if is_stun_packet(header.as_bytes()) {
// ignore stun packet
tracing::debug!("udp forward packet ignore stun packet");
let socket = self.socket.as_ref().unwrap().clone();
tokio::spawn(async move {
let ret = respond_stun_packet(socket, addr, zc_packet.inner().to_vec()).await;
if let Err(e) = ret {
tracing::error!(?e, "udp respond stun packet error");
}
});
} else if header.msg_type != UdpPacketType::HolePunch as u8 {
let Some(mut conn) = self.sock_map.get_mut(&addr) else {
tracing::trace!(?header, "udp forward packet error, connection not found");
@@ -350,7 +419,7 @@ impl UdpTunnelListenerData {
async fn do_forward_task(self: Self) {
let socket = self.socket.as_ref().unwrap().clone();
udp_recv_from_socket_forward_task(socket, |zc_packet, addr| {
udp_recv_from_socket_forward_task(socket, true, |zc_packet, addr| {
self.do_forward_one_packet_to_conn(zc_packet, addr);
})
.await;
@@ -501,7 +570,7 @@ impl UdpTunnelConnector {
socket.recv_buf_from(&mut buf),
)
.await??;
let zc_packet = get_zcpacket_from_buf(buf.split())?;
let zc_packet = get_zcpacket_from_buf(buf.split(), false)?;
if recv_addr != addr {
tracing::warn!(?recv_addr, ?addr, ?usize, "udp wait sack addr not match");
}
@@ -588,7 +657,7 @@ impl UdpTunnelConnector {
tracing::debug!("connector udp close event");
return;
}
_ = udp_recv_from_socket_forward_task(socket_clone, |zc_packet, addr| {
_ = udp_recv_from_socket_forward_task(socket_clone,false, |zc_packet, addr| {
tracing::debug!(?addr, "connector udp forward task done");
if let Err(e) = udp_conn.handle_packet_from_remote(zc_packet) {
tracing::trace!(?e, ?addr, "udp forward packet error");