summaryrefslogtreecommitdiff
path: root/embassy-net
diff options
context:
space:
mode:
authorArtur Kowalski <artur.kowalski@3mdeb.com>2022-07-28 10:25:47 +0200
committerArtur Kowalski <arturkow2000@gmail.com>2022-08-10 19:40:35 +0200
commitd5ab0d3ebb119c7ffd95da4b67325f75cae05b7e (patch)
treeda0c94e370118c9df5e3a1e9582aa684035063fa /embassy-net
parent0e524247fa4adc524c546b0d073e7061ad6c1b83 (diff)
downloadembassy-d5ab0d3ebb119c7ffd95da4b67325f75cae05b7e.zip
Add UDP socket support
Diffstat (limited to 'embassy-net')
-rw-r--r--embassy-net/Cargo.toml1
-rw-r--r--embassy-net/src/lib.rs5
-rw-r--r--embassy-net/src/udp.rs227
3 files changed, 233 insertions, 0 deletions
diff --git a/embassy-net/Cargo.toml b/embassy-net/Cargo.toml
index fface207..e4d8c2c2 100644
--- a/embassy-net/Cargo.toml
+++ b/embassy-net/Cargo.toml
@@ -18,6 +18,7 @@ std = []
defmt = ["dep:defmt", "smoltcp/defmt"]
+udp = ["smoltcp/socket-udp"]
tcp = ["smoltcp/socket-tcp"]
dns = ["smoltcp/socket-dns"]
dhcpv4 = ["medium-ethernet", "smoltcp/socket-dhcpv4"]
diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs
index 1c5ba103..83d36471 100644
--- a/embassy-net/src/lib.rs
+++ b/embassy-net/src/lib.rs
@@ -16,6 +16,9 @@ pub use stack::{Config, ConfigStrategy, Stack, StackResources};
#[cfg(feature = "tcp")]
pub mod tcp;
+#[cfg(feature = "udp")]
+pub mod udp;
+
// smoltcp reexports
pub use smoltcp::phy::{DeviceCapabilities, Medium};
pub use smoltcp::time::{Duration as SmolDuration, Instant as SmolInstant};
@@ -24,3 +27,5 @@ pub use smoltcp::wire::{EthernetAddress, HardwareAddress};
pub use smoltcp::wire::{IpAddress, IpCidr, Ipv4Address, Ipv4Cidr};
#[cfg(feature = "proto-ipv6")]
pub use smoltcp::wire::{Ipv6Address, Ipv6Cidr};
+#[cfg(feature = "udp")]
+pub use smoltcp::{socket::udp::PacketMetadata, wire::IpListenEndpoint};
diff --git a/embassy-net/src/udp.rs b/embassy-net/src/udp.rs
new file mode 100644
index 00000000..6b15805c
--- /dev/null
+++ b/embassy-net/src/udp.rs
@@ -0,0 +1,227 @@
+use core::cell::UnsafeCell;
+use core::mem;
+use core::task::Poll;
+
+use futures::future::poll_fn;
+use smoltcp::iface::{Interface, SocketHandle};
+use smoltcp::socket::udp::{self, PacketMetadata};
+use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
+
+use super::stack::SocketStack;
+use crate::{Device, Stack};
+
+#[derive(PartialEq, Eq, Clone, Copy, Debug)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub enum BindError {
+ /// The socket was already open.
+ InvalidState,
+ /// No route to host.
+ NoRoute,
+}
+
+#[derive(PartialEq, Eq, Clone, Copy, Debug)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub enum Error {
+ /// No route to host.
+ NoRoute,
+}
+
+pub struct UdpSocket<'a> {
+ io: UdpIo<'a>,
+}
+
+pub struct UdpReader<'a> {
+ io: UdpIo<'a>,
+}
+
+pub struct UdpWriter<'a> {
+ io: UdpIo<'a>,
+}
+
+impl<'a> UdpSocket<'a> {
+ pub fn new<D: Device>(
+ stack: &'a Stack<D>,
+ rx_meta: &'a mut [PacketMetadata],
+ rx_buffer: &'a mut [u8],
+ tx_meta: &'a mut [PacketMetadata],
+ tx_buffer: &'a mut [u8],
+ ) -> Self {
+ // safety: not accessed reentrantly.
+ let s = unsafe { &mut *stack.socket.get() };
+
+ let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) };
+ let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
+ let tx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(tx_meta) };
+ let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) };
+ let handle = s.sockets.add(udp::Socket::new(
+ udp::PacketBuffer::new(rx_meta, rx_buffer),
+ udp::PacketBuffer::new(tx_meta, tx_buffer),
+ ));
+
+ Self {
+ io: UdpIo {
+ stack: &stack.socket,
+ handle,
+ },
+ }
+ }
+
+ pub fn split(&mut self) -> (UdpReader<'_>, UdpWriter<'_>) {
+ (UdpReader { io: self.io }, UdpWriter { io: self.io })
+ }
+
+ pub fn bind<T>(&mut self, endpoint: T) -> Result<(), BindError>
+ where
+ T: Into<IpListenEndpoint>,
+ {
+ let mut endpoint = endpoint.into();
+
+ // safety: not accessed reentrantly.
+ if endpoint.port == 0 {
+ // If user didn't specify port allocate a dynamic port.
+ endpoint.port = unsafe { &mut *self.io.stack.get() }.get_local_port();
+ }
+
+ // safety: not accessed reentrantly.
+ match unsafe { self.io.with_mut(|s, _| s.bind(endpoint)) } {
+ Ok(()) => Ok(()),
+ Err(udp::BindError::InvalidState) => Err(BindError::InvalidState),
+ Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute),
+ }
+ }
+
+ pub async fn send_to<T>(&mut self, buf: &[u8], remote_endpoint: T) -> Result<(), Error>
+ where
+ T: Into<IpEndpoint>,
+ {
+ self.io.write(buf, remote_endpoint.into()).await
+ }
+
+ pub async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> {
+ self.io.read(buf).await
+ }
+
+ pub async fn flush(&mut self) -> Result<(), Error> {
+ self.io.flush().await
+ }
+
+ pub fn endpoint(&self) -> IpListenEndpoint {
+ unsafe { self.io.with(|s, _| s.endpoint()) }
+ }
+
+ pub fn is_open(&self) -> bool {
+ unsafe { self.io.with(|s, _| s.is_open()) }
+ }
+
+ pub fn close(&mut self) {
+ unsafe { self.io.with_mut(|s, _| s.close()) }
+ }
+
+ pub fn may_send(&self) -> bool {
+ unsafe { self.io.with(|s, _| s.can_send()) }
+ }
+
+ pub fn may_recv(&self) -> bool {
+ unsafe { self.io.with(|s, _| s.can_recv()) }
+ }
+}
+
+impl Drop for UdpSocket<'_> {
+ fn drop(&mut self) {
+ // safety: not accessed reentrantly.
+ let s = unsafe { &mut *self.io.stack.get() };
+ s.sockets.remove(self.io.handle);
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct UdpIo<'a> {
+ stack: &'a UnsafeCell<SocketStack>,
+ handle: SocketHandle,
+}
+
+impl UdpIo<'_> {
+ /// SAFETY: must not call reentrantly.
+ unsafe fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R {
+ let s = &*self.stack.get();
+ let socket = s.sockets.get::<udp::Socket>(self.handle);
+ f(socket, &s.iface)
+ }
+
+ /// SAFETY: must not call reentrantly.
+ unsafe fn with_mut<R>(&mut self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R {
+ let s = &mut *self.stack.get();
+ let socket = s.sockets.get_mut::<udp::Socket>(self.handle);
+ let res = f(socket, &mut s.iface);
+ s.waker.wake();
+ res
+ }
+
+ async fn read(&mut self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> {
+ poll_fn(move |cx| unsafe {
+ self.with_mut(|s, _| match s.recv_slice(buf) {
+ Ok(x) => Poll::Ready(Ok(x)),
+ // No data ready
+ Err(udp::RecvError::Exhausted) => {
+ //s.register_recv_waker(cx.waker());
+ cx.waker().wake_by_ref();
+ Poll::Pending
+ }
+ })
+ })
+ .await
+ }
+
+ async fn write(&mut self, buf: &[u8], ep: IpEndpoint) -> Result<(), Error> {
+ poll_fn(move |cx| unsafe {
+ self.with_mut(|s, _| match s.send_slice(buf, ep) {
+ // Entire datagram has been sent
+ Ok(()) => Poll::Ready(Ok(())),
+ Err(udp::SendError::BufferFull) => {
+ s.register_send_waker(cx.waker());
+ Poll::Pending
+ }
+ Err(udp::SendError::Unaddressable) => Poll::Ready(Err(Error::NoRoute)),
+ })
+ })
+ .await
+ }
+
+ async fn flush(&mut self) -> Result<(), Error> {
+ poll_fn(move |_| {
+ Poll::Ready(Ok(())) // TODO: Is there a better implementation for this?
+ })
+ .await
+ }
+}
+
+impl UdpReader<'_> {
+ pub async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> {
+ self.io.read(buf).await
+ }
+}
+
+impl UdpWriter<'_> {
+ pub async fn send_to<T>(&mut self, buf: &[u8], remote_endpoint: T) -> Result<(), Error>
+ where
+ T: Into<IpEndpoint>,
+ {
+ self.io.write(buf, remote_endpoint.into()).await
+ }
+
+ pub async fn flush(&mut self) -> Result<(), Error> {
+ self.io.flush().await
+ }
+}
+
+impl embedded_io::Error for BindError {
+ fn kind(&self) -> embedded_io::ErrorKind {
+ embedded_io::ErrorKind::Other
+ }
+}
+
+impl embedded_io::Error for Error {
+ fn kind(&self) -> embedded_io::ErrorKind {
+ embedded_io::ErrorKind::Other
+ }
+}