support no tun mode (#141)

This commit is contained in:
Sijie.Sun
2024-06-10 10:27:24 +08:00
committed by GitHub
parent fede35cca4
commit 8aa57ebc22
21 changed files with 1722 additions and 170 deletions
@@ -0,0 +1,75 @@
use futures::{Sink, Stream};
use smoltcp::phy::DeviceCapabilities;
use std::{
io,
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio_util::sync::{PollSendError, PollSender};
use super::device::AsyncDevice;
/// A device that send and receive packets using a channel.
pub struct ChannelDevice {
recv: Receiver<io::Result<Vec<u8>>>,
send: PollSender<Vec<u8>>,
caps: DeviceCapabilities,
}
impl ChannelDevice {
/// Make a new `ChannelDevice` with the given `recv` and `send` channels.
///
/// The `caps` is used to determine the device capabilities. `DeviceCapabilities::max_transmission_unit` must be set.
pub fn new(caps: DeviceCapabilities) -> (Self, Sender<io::Result<Vec<u8>>>, Receiver<Vec<u8>>) {
let (tx1, rx1) = channel(1000);
let (tx2, rx2) = channel(1000);
(
ChannelDevice {
send: PollSender::new(tx1),
recv: rx2,
caps,
},
tx2,
rx1,
)
}
}
impl Stream for ChannelDevice {
type Item = io::Result<Vec<u8>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.recv.poll_recv(cx)
}
}
fn map_err(e: PollSendError<Vec<u8>>) -> io::Error {
io::Error::new(io::ErrorKind::Other, e)
}
impl Sink<Vec<u8>> for ChannelDevice {
type Error = io::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.send.poll_reserve(cx).map_err(map_err)
}
fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
self.send.send_item(item).map_err(map_err)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.send.poll_reserve(cx).map_err(map_err)
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
}
impl AsyncDevice for ChannelDevice {
fn capabilities(&self) -> &DeviceCapabilities {
&self.caps
}
}
@@ -0,0 +1,122 @@
use futures::{Sink, Stream};
pub use smoltcp::phy::DeviceCapabilities;
use smoltcp::{
phy::{Device, RxToken, TxToken},
time::Instant,
};
use std::{collections::VecDeque, io};
/// Default value of `max_burst_size`.
pub const DEFAULT_MAX_BURST_SIZE: usize = 100;
/// A packet used in `AsyncDevice`.
pub type Packet = Vec<u8>;
/// A device that send and receive packets asynchronously.
pub trait AsyncDevice:
Stream<Item = io::Result<Packet>> + Sink<Packet, Error = io::Error> + Send + Unpin
{
/// Returns the device capabilities.
fn capabilities(&self) -> &DeviceCapabilities;
}
impl<T> AsyncDevice for Box<T>
where
T: AsyncDevice,
{
fn capabilities(&self) -> &DeviceCapabilities {
(**self).capabilities()
}
}
/// A device that send and receive packets synchronously.
pub struct BufferDevice {
caps: DeviceCapabilities,
max_burst_size: usize,
recv_queue: VecDeque<Packet>,
send_queue: VecDeque<Packet>,
}
/// RxToken for `BufferDevice`.
pub struct BufferRxToken(Packet);
impl RxToken for BufferRxToken {
fn consume<R, F>(mut self, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
let p = &mut self.0;
let result = f(p);
result
}
}
/// TxToken for `BufferDevice`.
pub struct BufferTxToken<'a>(&'a mut BufferDevice);
impl<'d> TxToken for BufferTxToken<'d> {
fn consume<R, F>(self, len: usize, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
let mut buffer = vec![0u8; len];
let result = f(&mut buffer);
self.0.send_queue.push_back(buffer);
result
}
}
impl Device for BufferDevice {
type RxToken<'a> = BufferRxToken
where Self:'a;
type TxToken<'a> = BufferTxToken<'a>
where Self:'a;
fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
match self.recv_queue.pop_front() {
Some(p) => Some((BufferRxToken(p), BufferTxToken(self))),
None => None,
}
}
fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
if self.send_queue.len() < self.max_burst_size {
Some(BufferTxToken(self))
} else {
None
}
}
fn capabilities(&self) -> DeviceCapabilities {
self.caps.clone()
}
}
impl BufferDevice {
pub(crate) fn new(caps: DeviceCapabilities) -> BufferDevice {
let max_burst_size = caps.max_burst_size.unwrap_or(DEFAULT_MAX_BURST_SIZE);
BufferDevice {
caps,
max_burst_size,
recv_queue: VecDeque::with_capacity(max_burst_size),
send_queue: VecDeque::with_capacity(max_burst_size),
}
}
pub(crate) fn take_send_queue(&mut self) -> VecDeque<Packet> {
std::mem::replace(
&mut self.send_queue,
VecDeque::with_capacity(self.max_burst_size),
)
}
pub(crate) fn push_recv_queue(&mut self, p: impl Iterator<Item = Packet>) {
self.recv_queue.extend(p.take(self.avaliable_recv_queue()));
}
pub(crate) fn avaliable_recv_queue(&self) -> usize {
self.max_burst_size - self.recv_queue.len()
}
pub(crate) fn need_wait(&self) -> bool {
self.recv_queue.is_empty()
}
}
+220
View File
@@ -0,0 +1,220 @@
// most code is copied from https://github.com/spacemeowx2/tokio-smoltcp
//! An asynchronous wrapper for smoltcp.
use std::{
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
sync::{
atomic::{AtomicU16, Ordering},
Arc,
},
};
use device::BufferDevice;
use futures::Future;
use reactor::Reactor;
pub use smoltcp;
use smoltcp::{
iface::{Config, Interface, Routes},
time::{Duration, Instant},
wire::{HardwareAddress, IpAddress, IpCidr, IpProtocol, IpVersion},
};
pub use socket::{RawSocket, TcpListener, TcpStream, UdpSocket};
pub use socket_allocator::BufferSize;
use tokio::sync::Notify;
/// The async devices.
pub mod channel_device;
pub mod device;
mod reactor;
mod socket;
mod socket_allocator;
/// Can be used to create a forever timestamp in neighbor.
// The 60_000 is the same as NeighborCache::ENTRY_LIFETIME.
pub const FOREVER: Instant =
Instant::from_micros_const(i64::max_value() - Duration::from_millis(60_000).micros() as i64);
pub struct Neighbor {
pub protocol_addr: IpAddress,
pub hardware_addr: HardwareAddress,
pub timestamp: Instant,
}
/// A config for a `Net`.
///
/// This is used to configure the `Net`.
#[non_exhaustive]
pub struct NetConfig {
pub interface_config: Config,
pub ip_addr: IpCidr,
pub gateway: Vec<IpAddress>,
pub buffer_size: BufferSize,
}
impl NetConfig {
pub fn new(interface_config: Config, ip_addr: IpCidr, gateway: Vec<IpAddress>) -> Self {
Self {
interface_config,
ip_addr,
gateway,
buffer_size: Default::default(),
}
}
}
/// `Net` is the main interface to the network stack.
/// Socket creation and configuration is done through the `Net` interface.
///
/// When `Net` is dropped, all sockets are closed and the network stack is stopped.
pub struct Net {
reactor: Arc<Reactor>,
ip_addr: IpCidr,
from_port: AtomicU16,
stopper: Arc<Notify>,
}
impl std::fmt::Debug for Net {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Net")
.field("ip_addr", &self.ip_addr)
.field("from_port", &self.from_port)
.finish()
}
}
impl Net {
/// Creates a new `Net` instance. It panics if the medium is not supported.
pub fn new<D: device::AsyncDevice + 'static>(device: D, config: NetConfig) -> Net {
let (net, fut) = Self::new2(device, config);
tokio::spawn(fut);
net
}
fn new2<D: device::AsyncDevice + 'static>(
device: D,
config: NetConfig,
) -> (Net, impl Future<Output = io::Result<()>> + Send) {
let mut buffer_device = BufferDevice::new(device.capabilities().clone());
let mut iface = Interface::new(config.interface_config, &mut buffer_device, Instant::now());
let ip_addr = config.ip_addr;
iface.update_ip_addrs(|ip_addrs| {
ip_addrs.push(ip_addr).unwrap();
});
for gateway in config.gateway {
match gateway {
IpAddress::Ipv4(v4) => {
iface.routes_mut().add_default_ipv4_route(v4).unwrap();
}
IpAddress::Ipv6(v6) => {
iface.routes_mut().add_default_ipv6_route(v6).unwrap();
}
#[allow(unreachable_patterns)]
_ => panic!("Unsupported address"),
};
}
let stopper = Arc::new(Notify::new());
let (reactor, fut) = Reactor::new(
device,
iface,
buffer_device,
config.buffer_size,
stopper.clone(),
);
(
Net {
reactor: Arc::new(reactor),
ip_addr: config.ip_addr,
from_port: AtomicU16::new(10001),
stopper,
},
fut,
)
}
fn get_port(&self) -> u16 {
self.from_port
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| {
Some(if x > 60000 { 10000 } else { x + 1 })
})
.unwrap()
}
/// Creates a new TcpListener, which will be bound to the specified address.
pub async fn tcp_bind(&self, addr: SocketAddr) -> io::Result<TcpListener> {
let addr = self.set_address(addr);
TcpListener::new(self.reactor.clone(), addr.into()).await
}
/// Opens a TCP connection to a remote host.
pub async fn tcp_connect(&self, addr: SocketAddr) -> io::Result<TcpStream> {
TcpStream::connect(
self.reactor.clone(),
(self.ip_addr.address(), self.get_port()).into(),
addr.into(),
)
.await
}
/// This function will create a new UDP socket and attempt to bind it to the `addr` provided.
pub async fn udp_bind(&self, addr: SocketAddr) -> io::Result<UdpSocket> {
let addr = self.set_address(addr);
UdpSocket::new(self.reactor.clone(), addr.into()).await
}
/// Creates a new raw socket.
pub async fn raw_socket(
&self,
ip_version: IpVersion,
ip_protocol: IpProtocol,
) -> io::Result<RawSocket> {
RawSocket::new(self.reactor.clone(), ip_version, ip_protocol).await
}
fn set_address(&self, mut addr: SocketAddr) -> SocketAddr {
if addr.ip().is_unspecified() {
addr.set_ip(match self.ip_addr.address() {
IpAddress::Ipv4(ip) => Ipv4Addr::from(ip).into(),
IpAddress::Ipv6(ip) => Ipv6Addr::from(ip).into(),
#[allow(unreachable_patterns)]
_ => panic!("address must not be unspecified"),
});
}
if addr.port() == 0 {
addr.set_port(self.get_port());
}
addr
}
/// Enable or disable the AnyIP capability.
pub fn set_any_ip(&self, any_ip: bool) {
let iface = self.reactor.iface().clone();
let mut iface: parking_lot::lock_api::MutexGuard<'_, parking_lot::RawMutex, Interface> =
iface.lock();
iface.set_any_ip(any_ip);
}
/// Get whether AnyIP is enabled.
pub fn any_ip(&self) -> bool {
let iface = self.reactor.iface().clone();
let iface = iface.lock();
iface.any_ip()
}
pub fn routes<F: FnOnce(&Routes)>(&self, f: F) {
let iface = self.reactor.iface().clone();
let iface = iface.lock();
let routes = iface.routes();
f(routes)
}
pub fn routes_mut<F: FnOnce(&mut Routes)>(&self, f: F) {
let iface = self.reactor.iface().clone();
let mut iface = iface.lock();
let routes = iface.routes_mut();
f(routes)
}
}
impl Drop for Net {
fn drop(&mut self) {
self.stopper.notify_waiters()
}
}
@@ -0,0 +1,163 @@
use super::{
device::{BufferDevice, Packet},
socket_allocator::{BufferSize, SocketAlloctor},
};
use futures::{stream::iter, FutureExt, SinkExt, StreamExt};
use parking_lot::{MappedMutexGuard, Mutex, MutexGuard};
use smoltcp::{
iface::{Context, Interface, SocketHandle},
socket::{AnySocket, Socket},
time::{Duration, Instant},
};
use std::{collections::VecDeque, future::Future, io, sync::Arc};
use tokio::{pin, select, sync::Notify, time::sleep};
pub(crate) type BufferInterface = Arc<Mutex<Interface>>;
const MAX_BURST_SIZE: usize = 100;
pub(crate) struct Reactor {
notify: Arc<Notify>,
iface: BufferInterface,
socket_allocator: SocketAlloctor,
}
async fn receive(
async_iface: &mut impl super::device::AsyncDevice,
recv_buf: &mut VecDeque<Packet>,
) -> io::Result<()> {
if let Some(packet) = async_iface.next().await {
recv_buf.push_back(packet?);
}
Ok(())
}
async fn run(
mut async_iface: impl super::device::AsyncDevice,
iface: BufferInterface,
mut device: BufferDevice,
socket_allocator: SocketAlloctor,
notify: Arc<Notify>,
stopper: Arc<Notify>,
) -> io::Result<()> {
let default_timeout = Duration::from_secs(60);
let timer = sleep(default_timeout.into());
let max_burst_size = async_iface
.capabilities()
.max_burst_size
.unwrap_or(MAX_BURST_SIZE);
let mut recv_buf = VecDeque::with_capacity(max_burst_size);
pin!(timer);
loop {
let packets = device.take_send_queue();
async_iface
.send_all(&mut iter(packets).map(|p| Ok(p)))
.await?;
if recv_buf.is_empty() && device.need_wait() {
let start = Instant::now();
let deadline = {
iface
.lock()
.poll_delay(start, &socket_allocator.sockets().lock())
.unwrap_or(default_timeout)
};
timer
.as_mut()
.reset(tokio::time::Instant::now() + deadline.into());
select! {
_ = &mut timer => {},
_ = receive(&mut async_iface,&mut recv_buf) => {}
_ = notify.notified() => {}
_ = stopper.notified() => break,
};
while let (true, Some(Ok(p))) = (
recv_buf.len() < max_burst_size,
async_iface.next().now_or_never().flatten(),
) {
recv_buf.push_back(p);
}
}
let mut iface = iface.lock();
device.push_recv_queue(recv_buf.drain(..device.avaliable_recv_queue().min(recv_buf.len())));
iface.poll(
Instant::now(),
&mut device,
&mut socket_allocator.sockets().lock(),
);
}
Ok(())
}
impl Reactor {
pub fn new(
async_device: impl super::device::AsyncDevice,
iface: Interface,
device: BufferDevice,
buffer_size: BufferSize,
stopper: Arc<Notify>,
) -> (Self, impl Future<Output = io::Result<()>> + Send) {
let iface = Arc::new(Mutex::new(iface));
let notify = Arc::new(Notify::new());
let socket_allocator = SocketAlloctor::new(buffer_size);
let fut = run(
async_device,
iface.clone(),
device,
socket_allocator.clone(),
notify.clone(),
stopper,
);
(
Reactor {
notify,
iface: iface.clone(),
socket_allocator,
},
fut,
)
}
pub fn get_socket<T: AnySocket<'static>>(
&self,
handle: SocketHandle,
) -> MappedMutexGuard<'_, T> {
MutexGuard::map(
self.socket_allocator.sockets().lock(),
|sockets: &mut smoltcp::iface::SocketSet<'_>| sockets.get_mut::<T>(handle),
)
}
pub fn context(&self) -> MappedMutexGuard<'_, Context> {
MutexGuard::map(self.iface.lock(), |iface| iface.context())
}
pub fn socket_allocator(&self) -> &SocketAlloctor {
&self.socket_allocator
}
pub fn notify(&self) {
self.notify.notify_waiters();
}
pub fn iface(&self) -> &BufferInterface {
&self.iface
}
}
impl Drop for Reactor {
fn drop(&mut self) {
for (_, socket) in self.socket_allocator.sockets().lock().iter_mut() {
match socket {
Socket::Tcp(tcp) => tcp.close(),
Socket::Raw(_) => {}
Socket::Udp(udp) => udp.close(),
#[allow(unreachable_patterns)]
_ => {}
}
}
}
}
@@ -0,0 +1,377 @@
use super::{reactor::Reactor, socket_allocator::SocketHandle};
use futures::future::{self, poll_fn};
use futures::{ready, Stream};
pub use smoltcp::socket::{raw, tcp, udp};
use smoltcp::wire::{IpAddress, IpEndpoint, IpProtocol, IpVersion};
use std::mem::replace;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::{
io,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// A TCP socket server, listening for connections.
///
/// You can accept a new connection by using the accept method.
pub struct TcpListener {
handle: SocketHandle,
reactor: Arc<Reactor>,
local_addr: SocketAddr,
}
fn map_err<E: std::error::Error>(e: E) -> io::Error {
io::Error::new(io::ErrorKind::Other, e.to_string())
}
impl TcpListener {
pub(super) async fn new(
reactor: Arc<Reactor>,
local_endpoint: IpEndpoint,
) -> io::Result<TcpListener> {
let handle = reactor.socket_allocator().new_tcp_socket();
{
let mut socket = reactor.get_socket::<tcp::Socket>(*handle);
socket.listen(local_endpoint).map_err(map_err)?;
}
let local_addr = ep2sa(&local_endpoint);
Ok(TcpListener {
handle,
reactor,
local_addr,
})
}
pub fn poll_accept(
&mut self,
cx: &mut Context<'_>,
) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
if socket.state() == tcp::State::Established {
drop(socket);
return Poll::Ready(Ok(TcpStream::accept(self)?));
}
socket.register_send_waker(cx.waker());
Poll::Pending
}
pub async fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> {
poll_fn(|cx| self.poll_accept(cx)).await
}
pub fn incoming(self) -> Incoming {
Incoming(self)
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(self.local_addr)
}
}
pub struct Incoming(TcpListener);
impl Stream for Incoming {
type Item = io::Result<TcpStream>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let (tcp, _) = ready!(self.0.poll_accept(cx))?;
Poll::Ready(Some(Ok(tcp)))
}
}
fn ep2sa(ep: &IpEndpoint) -> SocketAddr {
match ep.addr {
IpAddress::Ipv4(v4) => SocketAddr::new(IpAddr::V4(Ipv4Addr::from(v4)), ep.port),
IpAddress::Ipv6(v6) => SocketAddr::new(IpAddr::V6(Ipv6Addr::from(v6)), ep.port),
#[allow(unreachable_patterns)]
_ => unreachable!(),
}
}
/// A TCP stream between a local and a remote socket.
pub struct TcpStream {
handle: SocketHandle,
reactor: Arc<Reactor>,
local_addr: SocketAddr,
peer_addr: SocketAddr,
}
impl TcpStream {
pub(super) async fn connect(
reactor: Arc<Reactor>,
local_endpoint: IpEndpoint,
remote_endpoint: IpEndpoint,
) -> io::Result<TcpStream> {
let handle = reactor.socket_allocator().new_tcp_socket();
reactor
.get_socket::<tcp::Socket>(*handle)
.connect(&mut reactor.context(), remote_endpoint, local_endpoint)
.map_err(map_err)?;
let local_addr = ep2sa(&local_endpoint);
let peer_addr = ep2sa(&remote_endpoint);
let tcp = TcpStream {
handle,
reactor,
local_addr,
peer_addr,
};
tcp.reactor.notify();
future::poll_fn(|cx| tcp.poll_connected(cx)).await?;
Ok(tcp)
}
fn accept(listener: &mut TcpListener) -> io::Result<(TcpStream, SocketAddr)> {
let reactor = listener.reactor.clone();
let new_handle = reactor.socket_allocator().new_tcp_socket();
{
let mut new_socket = reactor.get_socket::<tcp::Socket>(*new_handle);
new_socket.listen(listener.local_addr).map_err(map_err)?;
}
let (peer_addr, local_addr) = {
let socket = reactor.get_socket::<tcp::Socket>(*listener.handle);
(
// should be Some, because the state is Established
ep2sa(&socket.remote_endpoint().unwrap()),
ep2sa(&socket.local_endpoint().unwrap()),
)
};
Ok((
TcpStream {
handle: replace(&mut listener.handle, new_handle),
reactor: reactor.clone(),
local_addr,
peer_addr,
},
peer_addr,
))
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(self.local_addr)
}
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
Ok(self.peer_addr)
}
pub fn poll_connected(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
if socket.state() == tcp::State::Established {
return Poll::Ready(Ok(()));
}
socket.register_send_waker(cx.waker());
Poll::Pending
}
}
impl AsyncRead for TcpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
if !socket.may_recv() {
return Poll::Ready(Ok(()));
}
if socket.can_recv() {
let read = socket
.recv_slice(buf.initialize_unfilled())
.map_err(map_err)?;
self.reactor.notify();
buf.advance(read);
return Poll::Ready(Ok(()));
}
socket.register_recv_waker(cx.waker());
Poll::Pending
}
}
impl AsyncWrite for TcpStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
if !socket.may_send() {
return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
}
if socket.can_send() {
let r = socket.send_slice(buf).map_err(map_err)?;
self.reactor.notify();
return Poll::Ready(Ok(r));
}
socket.register_send_waker(cx.waker());
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
if socket.send_queue() == 0 {
return Poll::Ready(Ok(()));
}
socket.register_send_waker(cx.waker());
Poll::Pending
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let mut socket = self.reactor.get_socket::<tcp::Socket>(*self.handle);
if socket.is_open() {
socket.close();
self.reactor.notify();
}
if socket.state() == tcp::State::Closed {
return Poll::Ready(Ok(()));
}
socket.register_send_waker(cx.waker());
Poll::Pending
}
}
/// A UDP socket.
pub struct UdpSocket {
handle: SocketHandle,
reactor: Arc<Reactor>,
local_addr: SocketAddr,
}
impl UdpSocket {
pub(super) async fn new(
reactor: Arc<Reactor>,
local_endpoint: IpEndpoint,
) -> io::Result<UdpSocket> {
let handle = reactor.socket_allocator().new_udp_socket();
{
let mut socket = reactor.get_socket::<udp::Socket>(*handle);
socket.bind(local_endpoint).map_err(map_err)?;
}
let local_addr = ep2sa(&local_endpoint);
Ok(UdpSocket {
handle,
reactor,
local_addr,
})
}
/// Note that on multiple calls to a poll_* method in the send direction, only the Waker from the Context passed to the most recent call will be scheduled to receive a wakeup.
pub fn poll_send_to(
&self,
cx: &mut Context<'_>,
buf: &[u8],
target: SocketAddr,
) -> Poll<io::Result<usize>> {
let mut socket = self.reactor.get_socket::<udp::Socket>(*self.handle);
let target_ip: IpEndpoint = target.into();
match socket.send_slice(buf, target_ip) {
// the buffer is full
Err(udp::SendError::BufferFull) => {}
r => {
r.map_err(map_err)?;
self.reactor.notify();
return Poll::Ready(Ok(buf.len()));
}
}
socket.register_send_waker(cx.waker());
Poll::Pending
}
/// See note on `poll_send_to`
pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
poll_fn(|cx| self.poll_send_to(cx, buf, target)).await
}
/// Note that on multiple calls to a poll_* method in the recv direction, only the Waker from the Context passed to the most recent call will be scheduled to receive a wakeup.
pub fn poll_recv_from(
&self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<(usize, SocketAddr)>> {
let mut socket = self.reactor.get_socket::<udp::Socket>(*self.handle);
match socket.recv_slice(buf) {
// the buffer is empty
Err(udp::RecvError::Exhausted) => {}
r => {
let (size, metadata) = r.map_err(map_err)?;
self.reactor.notify();
return Poll::Ready(Ok((size, ep2sa(&metadata.endpoint))));
}
}
socket.register_recv_waker(cx.waker());
Poll::Pending
}
/// See note on `poll_recv_from`
pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
poll_fn(|cx| self.poll_recv_from(cx, buf)).await
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(self.local_addr)
}
}
/// A raw socket.
pub struct RawSocket {
handle: SocketHandle,
reactor: Arc<Reactor>,
}
impl RawSocket {
pub(super) async fn new(
reactor: Arc<Reactor>,
ip_version: IpVersion,
ip_protocol: IpProtocol,
) -> io::Result<RawSocket> {
let handle = reactor
.socket_allocator()
.new_raw_socket(ip_version, ip_protocol);
Ok(RawSocket { handle, reactor })
}
/// Note that on multiple calls to a poll_* method in the send direction, only the Waker from the Context passed to the most recent call will be scheduled to receive a wakeup.
pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
let mut socket = self.reactor.get_socket::<raw::Socket>(*self.handle);
match socket.send_slice(buf) {
// the buffer is full
Err(raw::SendError::BufferFull) => {}
r => {
r.map_err(map_err)?;
self.reactor.notify();
return Poll::Ready(Ok(buf.len()));
}
}
socket.register_send_waker(cx.waker());
Poll::Pending
}
/// See note on `poll_send`
pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
poll_fn(|cx| self.poll_send(cx, buf)).await
}
/// Note that on multiple calls to a poll_* method in the recv direction, only the Waker from the Context passed to the most recent call will be scheduled to receive a wakeup.
pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
let mut socket = self.reactor.get_socket::<raw::Socket>(*self.handle);
match socket.recv_slice(buf) {
// the buffer is empty
Err(raw::RecvError::Exhausted) => {}
r => {
let size = r.map_err(map_err)?;
return Poll::Ready(Ok(size));
}
}
socket.register_recv_waker(cx.waker());
Poll::Pending
}
/// See note on `poll_recv`
pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
poll_fn(|cx| self.poll_recv(cx, buf)).await
}
}
@@ -0,0 +1,145 @@
use parking_lot::Mutex;
use smoltcp::{
iface::{SocketHandle as InnerSocketHandle, SocketSet},
socket::{raw, tcp, udp},
wire::{IpProtocol, IpVersion},
};
use std::{
ops::{Deref, DerefMut},
sync::Arc,
};
/// `BufferSize` is used to configure the size of the socket buffer.
#[derive(Debug, Clone, Copy)]
pub struct BufferSize {
pub tcp_rx_size: usize,
pub tcp_tx_size: usize,
pub udp_rx_size: usize,
pub udp_tx_size: usize,
pub udp_rx_meta_size: usize,
pub udp_tx_meta_size: usize,
pub raw_rx_size: usize,
pub raw_tx_size: usize,
pub raw_rx_meta_size: usize,
pub raw_tx_meta_size: usize,
}
impl Default for BufferSize {
fn default() -> Self {
BufferSize {
tcp_rx_size: 8192,
tcp_tx_size: 8192,
udp_rx_size: 8192,
udp_tx_size: 8192,
udp_rx_meta_size: 32,
udp_tx_meta_size: 32,
raw_rx_size: 8192,
raw_tx_size: 8192,
raw_rx_meta_size: 32,
raw_tx_meta_size: 32,
}
}
}
type SharedSocketSet = Arc<Mutex<SocketSet<'static>>>;
#[derive(Clone)]
pub struct SocketAlloctor {
sockets: SharedSocketSet,
buffer_size: BufferSize,
}
impl SocketAlloctor {
pub(crate) fn new(buffer_size: BufferSize) -> SocketAlloctor {
let sockets = Arc::new(Mutex::new(SocketSet::new(Vec::new())));
SocketAlloctor {
sockets,
buffer_size,
}
}
pub(crate) fn sockets(&self) -> &SharedSocketSet {
&self.sockets
}
pub fn new_tcp_socket(&self) -> SocketHandle {
let mut set = self.sockets.lock();
let handle = set.add(self.alloc_tcp_socket());
SocketHandle::new(handle, self.sockets.clone())
}
pub fn new_udp_socket(&self) -> SocketHandle {
let mut set = self.sockets.lock();
let handle = set.add(self.alloc_udp_socket());
SocketHandle::new(handle, self.sockets.clone())
}
pub fn new_raw_socket(&self, ip_version: IpVersion, ip_protocol: IpProtocol) -> SocketHandle {
let mut set = self.sockets.lock();
let handle = set.add(self.alloc_raw_socket(ip_version, ip_protocol));
SocketHandle::new(handle, self.sockets.clone())
}
fn alloc_tcp_socket(&self) -> tcp::Socket<'static> {
let rx_buffer = tcp::SocketBuffer::new(vec![0; self.buffer_size.tcp_rx_size]);
let tx_buffer = tcp::SocketBuffer::new(vec![0; self.buffer_size.tcp_tx_size]);
let mut tcp = tcp::Socket::new(rx_buffer, tx_buffer);
tcp.set_nagle_enabled(false);
tcp
}
fn alloc_udp_socket(&self) -> udp::Socket<'static> {
let rx_buffer = udp::PacketBuffer::new(
vec![udp::PacketMetadata::EMPTY; self.buffer_size.udp_rx_meta_size],
vec![0; self.buffer_size.udp_rx_size],
);
let tx_buffer = udp::PacketBuffer::new(
vec![udp::PacketMetadata::EMPTY; self.buffer_size.udp_tx_meta_size],
vec![0; self.buffer_size.udp_tx_size],
);
let udp = udp::Socket::new(rx_buffer, tx_buffer);
udp
}
fn alloc_raw_socket(
&self,
ip_version: IpVersion,
ip_protocol: IpProtocol,
) -> raw::Socket<'static> {
let rx_buffer = raw::PacketBuffer::new(
vec![raw::PacketMetadata::EMPTY; self.buffer_size.raw_rx_meta_size],
vec![0; self.buffer_size.raw_rx_size],
);
let tx_buffer = raw::PacketBuffer::new(
vec![raw::PacketMetadata::EMPTY; self.buffer_size.raw_tx_meta_size],
vec![0; self.buffer_size.raw_tx_size],
);
let raw = raw::Socket::new(ip_version, ip_protocol, rx_buffer, tx_buffer);
raw
}
}
pub struct SocketHandle(InnerSocketHandle, SharedSocketSet);
impl SocketHandle {
fn new(inner: InnerSocketHandle, set: SharedSocketSet) -> SocketHandle {
SocketHandle(inner, set)
}
}
impl Drop for SocketHandle {
fn drop(&mut self) {
let mut iface = self.1.lock();
iface.remove(self.0);
}
}
impl Deref for SocketHandle {
type Target = InnerSocketHandle;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for SocketHandle {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}