diff options
author | Dario Nieuwenhuis <dirbaio@dirbaio.net> | 2021-02-03 05:09:37 +0100 |
---|---|---|
committer | Dario Nieuwenhuis <dirbaio@dirbaio.net> | 2021-02-03 05:09:37 +0100 |
commit | cb5931d583d283dda3a1b5ed2014c086bb8f98ae (patch) | |
tree | 19a669238e0d562bf74616fe38485388ec40b02a /embassy-net | |
download | embassy-cb5931d583d283dda3a1b5ed2014c086bb8f98ae.zip |
:rainbow:
Diffstat (limited to 'embassy-net')
-rw-r--r-- | embassy-net/Cargo.toml | 46 | ||||
-rw-r--r-- | embassy-net/src/config/dhcp.rs | 80 | ||||
-rw-r--r-- | embassy-net/src/config/mod.rs | 34 | ||||
-rw-r--r-- | embassy-net/src/config/statik.rs | 26 | ||||
-rw-r--r-- | embassy-net/src/device.rs | 103 | ||||
-rw-r--r-- | embassy-net/src/fmt.rs | 118 | ||||
-rw-r--r-- | embassy-net/src/lib.rs | 31 | ||||
-rw-r--r-- | embassy-net/src/packet_pool.rs | 88 | ||||
-rw-r--r-- | embassy-net/src/pool.rs | 245 | ||||
-rw-r--r-- | embassy-net/src/stack.rs | 212 | ||||
-rw-r--r-- | embassy-net/src/tcp_socket.rs | 178 |
11 files changed, 1161 insertions, 0 deletions
diff --git a/embassy-net/Cargo.toml b/embassy-net/Cargo.toml new file mode 100644 index 00000000..aec6b796 --- /dev/null +++ b/embassy-net/Cargo.toml @@ -0,0 +1,46 @@ +[package] +name = "embassy-net" +version = "0.1.0" +authors = ["Dario Nieuwenhuis <dirbaio@dirbaio.net>"] +edition = "2018" + +[features] +std = [] +defmt-trace = [] +defmt-debug = [] +defmt-info = [] +defmt-warn = [] +defmt-error = [] + +[dependencies] + +defmt = { version = "0.1.3", optional = true } +log = { version = "0.4.11", optional = true } + +embassy = { version = "0.1.0" } + +managed = { version = "0.8.0", default-features = false, features = [ "map" ]} +heapless = { version = "0.5.6", default-features = false } +as-slice = { version = "0.1.4" } +generic-array = { version = "0.14.4", default-features = false } +stable_deref_trait = { version = "1.2.0", default-features = false } +futures = { version = "0.3.5", default-features = false, features = [ "async-await" ]} + +[dependencies.smoltcp] +version = "0.6.0" +#git = "https://github.com/akiles/smoltcp" +#rev = "00952e2c5cdf5667a1dfb6142258055f58d3851c" +default-features = false +features = [ + "medium-ethernet", + "medium-ip", + "proto-ipv4", + "proto-dhcpv4", + #"proto-igmp", + #"proto-ipv6", + #"socket-raw", + #"socket-icmp", + #"socket-udp", + "socket-tcp", + "async", +] diff --git a/embassy-net/src/config/dhcp.rs b/embassy-net/src/config/dhcp.rs new file mode 100644 index 00000000..f5d598bd --- /dev/null +++ b/embassy-net/src/config/dhcp.rs @@ -0,0 +1,80 @@ +use embassy::util::Forever; +use heapless::consts::*; +use heapless::Vec; +use smoltcp::dhcp::Dhcpv4Client; +use smoltcp::socket::{RawPacketMetadata, RawSocketBuffer}; +use smoltcp::time::Instant; +use smoltcp::wire::{Ipv4Address, Ipv4Cidr}; + +use super::*; +use crate::{device::LinkState, fmt::*}; +use crate::{Interface, SocketSet}; + +pub struct DhcpResources { + rx_buffer: [u8; 900], + tx_buffer: [u8; 600], + rx_meta: [RawPacketMetadata; 1], + tx_meta: [RawPacketMetadata; 1], +} + +pub struct DhcpConfigurator { + client: Option<Dhcpv4Client>, +} + +impl DhcpConfigurator { + pub fn new() -> Self { + Self { client: None } + } +} + +static DHCP_RESOURCES: Forever<DhcpResources> = Forever::new(); + +impl Configurator for DhcpConfigurator { + fn poll( + &mut self, + iface: &mut Interface, + sockets: &mut SocketSet, + timestamp: Instant, + ) -> Option<Config> { + if self.client.is_none() { + let res = DHCP_RESOURCES.put(DhcpResources { + rx_buffer: [0; 900], + tx_buffer: [0; 600], + rx_meta: [RawPacketMetadata::EMPTY; 1], + tx_meta: [RawPacketMetadata::EMPTY; 1], + }); + let rx_buffer = RawSocketBuffer::new(&mut res.rx_meta[..], &mut res.rx_buffer[..]); + let tx_buffer = RawSocketBuffer::new(&mut res.tx_meta[..], &mut res.tx_buffer[..]); + let dhcp = Dhcpv4Client::new(sockets, rx_buffer, tx_buffer, timestamp); + info!("created dhcp"); + self.client = Some(dhcp) + } + + let client = self.client.as_mut().unwrap(); + + let link_up = iface.device_mut().device.link_state() == LinkState::Up; + if !link_up { + client.reset(timestamp); + return Some(Config::Down); + } + + let config = client.poll(iface, sockets, timestamp).unwrap_or(None)?; + + if config.address.is_none() { + return Some(Config::Down); + } + + let mut dns_servers = Vec::new(); + for s in &config.dns_servers { + if let Some(addr) = s { + dns_servers.push(addr.clone()).unwrap(); + } + } + + return Some(Config::Up(UpConfig { + address: config.address.unwrap(), + gateway: config.router.unwrap_or(Ipv4Address::UNSPECIFIED), + dns_servers, + })); + } +} diff --git a/embassy-net/src/config/mod.rs b/embassy-net/src/config/mod.rs new file mode 100644 index 00000000..596374f9 --- /dev/null +++ b/embassy-net/src/config/mod.rs @@ -0,0 +1,34 @@ +use heapless::consts::*; +use heapless::Vec; +use smoltcp::time::Instant; +use smoltcp::wire::{Ipv4Address, Ipv4Cidr}; + +use crate::fmt::*; +use crate::{Interface, SocketSet}; + +mod dhcp; +mod statik; +pub use dhcp::DhcpConfigurator; +pub use statik::StaticConfigurator; + +#[derive(Debug, Clone)] +pub enum Config { + Down, + Up(UpConfig), +} + +#[derive(Debug, Clone)] +pub struct UpConfig { + pub address: Ipv4Cidr, + pub gateway: Ipv4Address, + pub dns_servers: Vec<Ipv4Address, U3>, +} + +pub trait Configurator { + fn poll( + &mut self, + iface: &mut Interface, + sockets: &mut SocketSet, + timestamp: Instant, + ) -> Option<Config>; +} diff --git a/embassy-net/src/config/statik.rs b/embassy-net/src/config/statik.rs new file mode 100644 index 00000000..52196f48 --- /dev/null +++ b/embassy-net/src/config/statik.rs @@ -0,0 +1,26 @@ +use smoltcp::time::Instant; + +use super::*; +use crate::fmt::*; +use crate::{Interface, SocketSet}; + +pub struct StaticConfigurator { + config: UpConfig, +} + +impl StaticConfigurator { + pub fn new(config: UpConfig) -> Self { + Self { config } + } +} + +impl Configurator for StaticConfigurator { + fn poll( + &mut self, + _iface: &mut Interface, + _sockets: &mut SocketSet, + _timestamp: Instant, + ) -> Option<Config> { + Some(Config::Up(self.config.clone())) + } +} diff --git a/embassy-net/src/device.rs b/embassy-net/src/device.rs new file mode 100644 index 00000000..95a62e79 --- /dev/null +++ b/embassy-net/src/device.rs @@ -0,0 +1,103 @@ +use core::task::{Poll, Waker}; +use smoltcp::phy::Device as SmolDevice; +use smoltcp::phy::DeviceCapabilities; +use smoltcp::time::Instant as SmolInstant; +use smoltcp::Result; + +use crate::fmt::*; +use crate::{Packet, PacketBox, PacketBuf}; + +#[derive(PartialEq, Eq, Clone, Copy)] +pub enum LinkState { + Down, + Up, +} + +pub trait Device { + fn is_transmit_ready(&mut self) -> bool; + fn transmit(&mut self, pkt: PacketBuf); + fn receive(&mut self) -> Option<PacketBuf>; + + fn register_waker(&mut self, waker: &Waker); + fn capabilities(&mut self) -> DeviceCapabilities; + fn link_state(&mut self) -> LinkState; +} + +pub struct DeviceAdapter { + pub device: &'static mut dyn Device, + caps: DeviceCapabilities, +} + +impl DeviceAdapter { + pub(crate) fn new(device: &'static mut dyn Device) -> Self { + Self { + caps: device.capabilities(), + device, + } + } +} + +impl<'a> SmolDevice<'a> for DeviceAdapter { + type RxToken = RxToken; + type TxToken = TxToken<'a>; + + fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { + let rx_pkt = self.device.receive()?; + let tx_pkt = PacketBox::new(Packet::new()).unwrap(); // TODO: not sure about unwrap + let rx_token = RxToken { pkt: rx_pkt }; + let tx_token = TxToken { + device: self.device, + pkt: tx_pkt, + }; + + Some((rx_token, tx_token)) + } + + /// Construct a transmit token. + fn transmit(&'a mut self) -> Option<Self::TxToken> { + if !self.device.is_transmit_ready() { + return None; + } + + let tx_pkt = PacketBox::new(Packet::new())?; + Some(TxToken { + device: self.device, + pkt: tx_pkt, + }) + } + + /// Get a description of device capabilities. + fn capabilities(&self) -> DeviceCapabilities { + self.caps.clone() + } +} + +pub struct RxToken { + pkt: PacketBuf, +} + +impl smoltcp::phy::RxToken for RxToken { + fn consume<R, F>(mut self, _timestamp: SmolInstant, f: F) -> Result<R> + where + F: FnOnce(&mut [u8]) -> Result<R>, + { + f(&mut self.pkt) + } +} + +pub struct TxToken<'a> { + device: &'a mut dyn Device, + pkt: PacketBox, +} + +impl<'a> smoltcp::phy::TxToken for TxToken<'a> { + fn consume<R, F>(mut self, _timestamp: SmolInstant, len: usize, f: F) -> Result<R> + where + F: FnOnce(&mut [u8]) -> Result<R>, + { + let mut buf = self.pkt.slice(0..len); + let r = f(&mut buf)?; + self.device.transmit(buf); + Ok(r) + } +} diff --git a/embassy-net/src/fmt.rs b/embassy-net/src/fmt.rs new file mode 100644 index 00000000..4da69766 --- /dev/null +++ b/embassy-net/src/fmt.rs @@ -0,0 +1,118 @@ +#![macro_use] + +#[cfg(all(feature = "defmt", feature = "log"))] +compile_error!("You may not enable both `defmt` and `log` features."); + +pub use fmt::*; + +#[cfg(feature = "defmt")] +mod fmt { + pub use defmt::{ + assert, assert_eq, assert_ne, debug, debug_assert, debug_assert_eq, debug_assert_ne, error, + info, panic, todo, trace, unreachable, unwrap, warn, + }; +} + +#[cfg(feature = "log")] +mod fmt { + pub use core::{ + assert, assert_eq, assert_ne, debug_assert, debug_assert_eq, debug_assert_ne, panic, todo, + unreachable, + }; + pub use log::{debug, error, info, trace, warn}; +} + +#[cfg(not(any(feature = "defmt", feature = "log")))] +mod fmt { + #![macro_use] + + pub use core::{ + assert, assert_eq, assert_ne, debug_assert, debug_assert_eq, debug_assert_ne, panic, todo, + unreachable, + }; + + #[macro_export] + macro_rules! trace { + ($($msg:expr),+ $(,)?) => { + () + }; + } + + #[macro_export] + macro_rules! debug { + ($($msg:expr),+ $(,)?) => { + () + }; + } + + #[macro_export] + macro_rules! info { + ($($msg:expr),+ $(,)?) => { + () + }; + } + + #[macro_export] + macro_rules! warn { + ($($msg:expr),+ $(,)?) => { + () + }; + } + + #[macro_export] + macro_rules! error { + ($($msg:expr),+ $(,)?) => { + () + }; + } +} + +#[cfg(not(feature = "defmt"))] +#[macro_export] +macro_rules! unwrap { + ($arg:expr) => { + match $crate::fmt::Try::into_result($arg) { + ::core::result::Result::Ok(t) => t, + ::core::result::Result::Err(e) => { + ::core::panic!("unwrap of `{}` failed: {:?}", ::core::stringify!($arg), e); + } + } + }; + ($arg:expr, $($msg:expr),+ $(,)? ) => { + match $crate::fmt::Try::into_result($arg) { + ::core::result::Result::Ok(t) => t, + ::core::result::Result::Err(e) => { + ::core::panic!("unwrap of `{}` failed: {}: {:?}", ::core::stringify!($arg), ::core::format_args!($($msg,)*), e); + } + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct NoneError; + +pub trait Try { + type Ok; + type Error; + fn into_result(self) -> Result<Self::Ok, Self::Error>; +} + +impl<T> Try for Option<T> { + type Ok = T; + type Error = NoneError; + + #[inline] + fn into_result(self) -> Result<T, NoneError> { + self.ok_or(NoneError) + } +} + +impl<T, E> Try for Result<T, E> { + type Ok = T; + type Error = E; + + #[inline] + fn into_result(self) -> Self { + self + } +} diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs new file mode 100644 index 00000000..a2a320ad --- /dev/null +++ b/embassy-net/src/lib.rs @@ -0,0 +1,31 @@ +#![cfg_attr(not(feature = "std"), no_std)] +#![feature(const_fn)] +#![feature(const_in_array_repeat_expressions)] +#![feature(const_generics)] +#![feature(const_evaluatable_checked)] +#![allow(incomplete_features)] + +// This mod MUST go first, so that the others see its macros. +pub(crate) mod fmt; + +mod pool; // TODO extract to embassy, or to own crate + +mod config; +mod device; +mod packet_pool; +mod stack; +mod tcp_socket; + +pub use config::{Config, Configurator, DhcpConfigurator, StaticConfigurator, UpConfig}; +pub use device::{Device, LinkState}; +pub use packet_pool::{Packet, PacketBox, PacketBuf}; +pub use stack::{init, is_init, run}; +pub use tcp_socket::TcpSocket; + +// smoltcp reexports +pub use smoltcp::phy::{DeviceCapabilities, Medium}; +pub use smoltcp::time::Duration as SmolDuration; +pub use smoltcp::time::Instant as SmolInstant; +pub use smoltcp::wire::{IpAddress, IpCidr, Ipv4Address, Ipv4Cidr}; +pub type Interface = smoltcp::iface::Interface<'static, device::DeviceAdapter>; +pub type SocketSet = smoltcp::socket::SocketSet<'static>; diff --git a/embassy-net/src/packet_pool.rs b/embassy-net/src/packet_pool.rs new file mode 100644 index 00000000..24635643 --- /dev/null +++ b/embassy-net/src/packet_pool.rs @@ -0,0 +1,88 @@ +use as_slice::{AsMutSlice, AsSlice}; +use core::ops::{Deref, DerefMut, Range}; + +use super::pool::{BitPool, Box, StaticPool}; + +pub const MTU: usize = 1514; +pub const PACKET_POOL_SIZE: usize = 4; + +pool!(pub PacketPool: [Packet; PACKET_POOL_SIZE]); +pub type PacketBox = Box<PacketPool>; + +pub struct Packet(pub [u8; MTU]); + +impl Packet { + pub const fn new() -> Self { + Self([0; MTU]) + } +} + +impl Box<PacketPool> { + pub fn slice(self, range: Range<usize>) -> PacketBuf { + PacketBuf { + packet: self, + range, + } + } +} + +impl AsSlice for Packet { + type Element = u8; + + fn as_slice(&self) -> &[Self::Element] { + &self.deref()[..] + } +} + +impl AsMutSlice for Packet { + fn as_mut_slice(&mut self) -> &mut [Self::Element] { + &mut self.deref_mut()[..] + } +} + +impl Deref for Packet { + type Target = [u8; MTU]; + + fn deref(&self) -> &[u8; MTU] { + &self.0 + } +} + +impl DerefMut for Packet { + fn deref_mut(&mut self) -> &mut [u8; MTU] { + &mut self.0 + } +} + +pub struct PacketBuf { + packet: PacketBox, + range: Range<usize>, +} + +impl AsSlice for PacketBuf { + type Element = u8; + + fn as_slice(&self) -> &[Self::Element] { + &self.packet[self.range.clone()] + } +} + +impl AsMutSlice for PacketBuf { + fn as_mut_slice(&mut self) -> &mut [Self::Element] { + &mut self.packet[self.range.clone()] + } +} + +impl Deref for PacketBuf { + type Target = [u8]; + + fn deref(&self) -> &[u8] { + &self.packet[self.range.clone()] + } +} + +impl DerefMut for PacketBuf { + fn deref_mut(&mut self) -> &mut [u8] { + &mut self.packet[self.range.clone()] + } +} diff --git a/embassy-net/src/pool.rs b/embassy-net/src/pool.rs new file mode 100644 index 00000000..3ab36e4c --- /dev/null +++ b/embassy-net/src/pool.rs @@ -0,0 +1,245 @@ +#![macro_use] + +use as_slice::{AsMutSlice, AsSlice}; +use core::cmp; +use core::fmt; +use core::hash::{Hash, Hasher}; +use core::mem::MaybeUninit; +use core::ops::{Deref, DerefMut}; +use core::sync::atomic::{AtomicU32, Ordering}; + +use crate::fmt::{assert, *}; + +struct AtomicBitset<const N: usize> +where + [AtomicU32; (N + 31) / 32]: Sized, +{ + used: [AtomicU32; (N + 31) / 32], +} + +impl<const N: usize> AtomicBitset<N> +where + [AtomicU32; (N + 31) / 32]: Sized, +{ + const fn new() -> Self { + const Z: AtomicU32 = AtomicU32::new(0); + Self { + used: [Z; (N + 31) / 32], + } + } + + fn alloc(&self) -> Option<usize> { + for (i, val) in self.used.iter().enumerate() { + let res = val.fetch_update(Ordering::AcqRel, Ordering::Acquire, |val| { + let n = val.trailing_ones() as usize + i * 32; + if n >= N { + None + } else { + Some(val | (1 << n)) + } + }); + if let Ok(val) = res { + let n = val.trailing_ones() as usize + i * 32; + return Some(n); + } + } + None + } + fn free(&self, i: usize) { + assert!(i < N); + self.used[i / 32].fetch_and(!(1 << ((i % 32) as u32)), Ordering::AcqRel); + } +} + +pub trait Pool<T> { + fn alloc(&self) -> Option<*mut T>; + unsafe fn free(&self, p: *mut T); +} + +pub struct BitPool<T, const N: usize> +where + [AtomicU32; (N + 31) / 32]: Sized, +{ + used: AtomicBitset<N>, + data: MaybeUninit<[T; N]>, +} + +impl<T, const N: usize> BitPool<T, N> +where + [AtomicU32; (N + 31) / 32]: Sized, +{ + pub const fn new() -> Self { + Self { + used: AtomicBitset::new(), + data: MaybeUninit::uninit(), + } + } +} + +impl<T, const N: usize> Pool<T> for BitPool<T, N> +where + [AtomicU32; (N + 31) / 32]: Sized, +{ + fn alloc(&self) -> Option<*mut T> { + let n = self.used.alloc()?; + let origin = self.data.as_ptr() as *mut T; + Some(unsafe { origin.add(n) }) + } + + /// safety: p must be a pointer obtained from self.alloc that hasn't been freed yet. + unsafe fn free(&self, p: *mut T) { + let origin = self.data.as_ptr() as *mut T; + let n = p.offset_from(origin); + assert!(n >= 0); + assert!((n as usize) < N); + self.used.free(n as usize); + } +} + +pub trait StaticPool: 'static { + type Item: 'static; + type Pool: Pool<Self::Item>; + fn get() -> &'static Self::Pool; +} + +pub struct Box<P: StaticPool> { + ptr: *mut P::Item, +} + +impl<P: StaticPool> Box<P> { + pub fn new(item: P::Item) -> Option<Self> { + let p = match P::get().alloc() { + Some(p) => p, + None => { + warn!("alloc failed!"); + return None; + } + }; + //trace!("allocated {:u32}", p as u32); + unsafe { p.write(item) }; + Some(Self { ptr: p }) + } +} + +impl<P: StaticPool> Drop for Box<P> { + fn drop(&mut self) { + unsafe { + //trace!("dropping {:u32}", self.ptr as u32); + self.ptr.drop_in_place(); + P::get().free(self.ptr); + }; + } +} + +unsafe impl<P: StaticPool> Send for Box<P> where P::Item: Send {} + +unsafe impl<P: StaticPool> Sync for Box<P> where P::Item: Sync {} + +unsafe impl<P: StaticPool> stable_deref_trait::StableDeref for Box<P> {} + +impl<P: StaticPool> AsSlice for Box<P> +where + P::Item: AsSlice, +{ + type Element = <P::Item as AsSlice>::Element; + + fn as_slice(&self) -> &[Self::Element] { + self.deref().as_slice() + } +} + +impl<P: StaticPool> AsMutSlice for Box<P> +where + P::Item: AsMutSlice, +{ + fn as_mut_slice(&mut self) -> &mut [Self::Element] { + self.deref_mut().as_mut_slice() + } +} + +impl<P: StaticPool> Deref for Box<P> { + type Target = P::Item; + + fn deref(&self) -> &P::Item { + unsafe { &*self.ptr } + } +} + +impl<P: StaticPool> DerefMut for Box<P> { + fn deref_mut(&mut self) -> &mut P::Item { + unsafe { &mut *self.ptr } + } +} + +impl<P: StaticPool> fmt::Debug for Box<P> +where + P::Item: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + <P::Item as fmt::Debug>::fmt(self, f) + } +} + +impl<P: StaticPool> fmt::Display for Box<P> +where + P::Item: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + <P::Item as fmt::Display>::fmt(self, f) + } +} + +impl<P: StaticPool> PartialEq for Box<P> +where + P::Item: PartialEq, +{ + fn eq(&self, rhs: &Box<P>) -> bool { + <P::Item as PartialEq>::eq(self, rhs) + } +} + +impl<P: StaticPool> Eq for Box<P> where P::Item: Eq {} + +impl<P: StaticPool> PartialOrd for Box<P> +where + P::Item: PartialOrd, +{ + fn partial_cmp(&self, rhs: &Box<P>) -> Option<cmp::Ordering> { + <P::Item as PartialOrd>::partial_cmp(self, rhs) + } +} + +impl<P: StaticPool> Ord for Box<P> +where + P::Item: Ord, +{ + fn cmp(&self, rhs: &Box<P>) -> cmp::Ordering { + <P::Item as Ord>::cmp(self, rhs) + } +} + +impl<P: StaticPool> Hash for Box<P> +where + P::Item: Hash, +{ + fn hash<H>(&self, state: &mut H) + where + H: Hasher, + { + <P::Item as Hash>::hash(self, state) + } +} + +macro_rules! pool { + ($vis:vis $name:ident: [$ty:ty; $size:expr]) => { + $vis struct $name; + impl StaticPool for $name { + type Item = $ty; + type Pool = BitPool<$ty, $size>; + fn get() -> &'static Self::Pool { + static POOL: BitPool<$ty, $size> = BitPool::new(); + &POOL + } + } + }; +} diff --git a/embassy-net/src/stack.rs b/embassy-net/src/stack.rs new file mode 100644 index 00000000..c353f1bb --- /dev/null +++ b/embassy-net/src/stack.rs @@ -0,0 +1,212 @@ +use core::future::Future; +use core::task::Context; +use core::task::Poll; +use core::{cell::RefCell, future}; +use embassy::time::{Instant, Timer}; +use embassy::util::ThreadModeMutex; +use embassy::util::{Forever, WakerRegistration}; +use futures::pin_mut; +use smoltcp::iface::{InterfaceBuilder, Neighbor, NeighborCache, Route, Routes}; +use smoltcp::phy::Device as _; +use smoltcp::phy::Medium; +use smoltcp::socket::SocketSetItem; +use smoltcp::time::Instant as SmolInstant; +use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr, Ipv4Address}; + +use crate::device::{Device, DeviceAdapter}; +use crate::fmt::*; +use crate::{ + config::{Config, Configurator}, + device::LinkState, +}; +use crate::{Interface, SocketSet}; + +const ADDRESSES_LEN: usize = 1; +const NEIGHBOR_CACHE_LEN: usize = 8; +const SOCKETS_LEN: usize = 2; +const LOCAL_PORT_MIN: u16 = 1025; +const LOCAL_PORT_MAX: u16 = 65535; + +struct StackResources { + addresses: [IpCidr; ADDRESSES_LEN], + neighbor_cache: [Option<(IpAddress, Neighbor)>; NEIGHBOR_CACHE_LEN], + sockets: [Option<SocketSetItem<'static>>; SOCKETS_LEN], + routes: [Option<(IpCidr, Route)>; 1], +} + +static STACK_RESOURCES: Forever<StackResources> = Forever::new(); +static STACK: ThreadModeMutex<RefCell<Option<Stack>>> = ThreadModeMutex::new(RefCell::new(None)); + +pub(crate) struct Stack { + iface: Interface, + pub sockets: SocketSet, + link_up: bool, + next_local_port: u16, + configurator: &'static mut dyn Configurator, + waker: WakerRegistration, +} + +impl Stack { + pub(crate) fn with<R>(f: impl FnOnce(&mut Stack) -> R) -> R { + let mut stack = STACK.borrow().borrow_mut(); + let stack = stack.as_mut().unwrap(); + f(stack) + } + + pub fn get_local_port(&mut self) -> u16 { + let res = self.next_local_port; + self.next_local_port = if res >= LOCAL_PORT_MAX { + LOCAL_PORT_MIN + } else { + res + 1 + }; + res + } + + pub(crate) fn wake(&mut self) { + self.waker.wake() + } + + fn poll_configurator(&mut self, timestamp: SmolInstant) { + if let Some(config) = self + .configurator + .poll(&mut self.iface, &mut self.sockets, timestamp) + { + let medium = self.iface.device().capabilities().medium; + + let (addr, gateway) = match config { + Config::Up(config) => (config.address.into(), Some(config.gateway)), + Config::Down => (IpCidr::new(Ipv4Address::UNSPECIFIED.into(), 32), None), + }; + + self.iface.update_ip_addrs(|addrs| { + let curr_addr = &mut addrs[0]; + if *curr_addr != addr { + info!("IPv4 address: {:?} -> {:?}", *curr_addr, addr); + *curr_addr = addr; + } + }); + + if medium == Medium::Ethernet { + self.iface.routes_mut().update(|r| { + let cidr = IpCidr::new(IpAddress::v4(0, 0, 0, 0), 0); + let curr_gateway = r.get(&cidr).map(|r| r.via_router); + + if curr_gateway != gateway.map(|a| a.into()) { + info!("IPv4 gateway: {:?} -> {:?}", curr_gateway, gateway); + if let Some(gateway) = gateway { + r.insert(cidr, Route::new_ipv4_gateway(gateway)).unwrap(); + } else { + r.remove(&cidr); + } + } + }); + } + } + } + + fn poll(&mut self, cx: &mut Context<'_>) { + self.iface.device_mut().device.register_waker(cx.waker()); + self.waker.register(cx.waker()); + + let timestamp = instant_to_smoltcp(Instant::now()); + if let Err(e) = self.iface.poll(&mut self.sockets, timestamp) { + // If poll() returns error, it may not be done yet, so poll again later. + cx.waker().wake_by_ref(); + return; + } + + // Update link up + let old_link_up = self.link_up; + self.link_up = self.iface.device_mut().device.link_state() == LinkState::Up; + + // Print when changed + if old_link_up != self.link_up { + if self.link_up { + info!("Link up!"); + } else { + info!("Link down!"); + } + } + + if old_link_up || self.link_up { + self.poll_configurator(timestamp) + } + + if let Some(poll_at) = self.iface.poll_at(&mut self.sockets, timestamp) { + let t = Timer::at(instant_from_smoltcp(poll_at)); + pin_mut!(t); + if t.poll(cx).is_ready() { + cx.waker().wake_by_ref(); + } + } + } +} + +/// Initialize embassy_net. +/// This function must be called from thread mode. +pub fn init(device: &'static mut dyn Device, configurator: &'static mut dyn Configurator) { + let res = STACK_RESOURCES.put(StackResources { + addresses: [IpCidr::new(Ipv4Address::UNSPECIFIED.into(), 32)], + neighbor_cache: [None; NEIGHBOR_CACHE_LEN], + sockets: [None; SOCKETS_LEN], + routes: [None; 1], + }); + + let ethernet_addr = EthernetAddress([0x02, 0x02, 0x02, 0x02, 0x02, 0x02]); + + let medium = device.capabilities().medium; + + let mut b = InterfaceBuilder::new(DeviceAdapter::new(device)); + b = b.ip_addrs(&mut res.addresses[..]); + + if medium == Medium::Ethernet { + b = b.ethernet_addr(ethernet_addr); + b = b.neighbor_cache(NeighborCache::new(&mut res.neighbor_cache[..])); + b = b.routes(Routes::new(&mut res.routes[..])); + } + + let iface = b.finalize(); + + let sockets = SocketSet::new(&mut res.sockets[..]); + + let local_port = loop { + let mut res = [0u8; 2]; + embassy::rand::rand(&mut res); + let port = u16::from_le_bytes(res); + if port >= LOCAL_PORT_MIN && port <= LOCAL_PORT_MAX { + break port; + } + }; + + let stack = Stack { + iface, + sockets, + link_up: false, + configurator, + next_local_port: local_port, + waker: WakerRegistration::new(), + }; + + *STACK.borrow().borrow_mut() = Some(stack); +} + +pub fn is_init() -> bool { + STACK.borrow().borrow().is_some() +} + +pub async fn run() { + futures::future::poll_fn(|cx| { + Stack::with(|stack| stack.poll(cx)); + Poll::<()>::Pending + }) + .await +} + +fn instant_to_smoltcp(instant: Instant) -> SmolInstant { + SmolInstant::from_millis(instant.as_millis() as i64) +} + +fn instant_from_smoltcp(instant: SmolInstant) -> Instant { + Instant::from_millis(instant.total_millis() as u64) +} diff --git a/embassy-net/src/tcp_socket.rs b/embassy-net/src/tcp_socket.rs new file mode 100644 index 00000000..7f4eb014 --- /dev/null +++ b/embassy-net/src/tcp_socket.rs @@ -0,0 +1,178 @@ +use core::marker::PhantomData; +use core::mem; +use core::pin::Pin; +use core::task::{Context, Poll}; +use embassy::io; +use embassy::io::{AsyncBufRead, AsyncWrite}; +use smoltcp::socket::SocketHandle; +use smoltcp::socket::TcpSocket as SyncTcpSocket; +use smoltcp::socket::{TcpSocketBuffer, TcpState}; +use smoltcp::time::Duration; +use smoltcp::wire::IpEndpoint; +use smoltcp::{Error, Result}; + +use super::stack::Stack; +use crate::fmt::*; + +pub struct TcpSocket<'a> { + handle: SocketHandle, + ghost: PhantomData<&'a mut [u8]>, +} + +impl<'a> Unpin for TcpSocket<'a> {} + +impl<'a> TcpSocket<'a> { + pub fn new(rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { + let handle = Stack::with(|stack| { + let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; + let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; + stack.sockets.add(SyncTcpSocket::new( + TcpSocketBuffer::new(rx_buffer), + TcpSocketBuffer::new(tx_buffer), + )) + }); + + Self { + handle, + ghost: PhantomData, + } + } + + pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<()> + where + T: Into<IpEndpoint>, + { + let local_port = Stack::with(|stack| stack.get_local_port()); + self.with(|s| s.connect(remote_endpoint, local_port))?; + + futures::future::poll_fn(|cx| { + self.with(|s| match s.state() { + TcpState::Closed | TcpState::TimeWait => Poll::Ready(Err(Error::Unaddressable)), + TcpState::Listen => Poll::Ready(Err(Error::Illegal)), + TcpState::SynSent | TcpState::SynReceived => { + s.register_send_waker(cx.waker()); + Poll::Pending + } + _ => Poll::Ready(Ok(())), + }) + }) + .await + } + + pub fn set_timeout(&mut self, duration: Option<Duration>) { + self.with(|s| s.set_timeout(duration)) + } + + pub fn set_keep_alive(&mut self, interval: Option<Duration>) { + self.with(|s| s.set_keep_alive(interval)) + } + + pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { + self.with(|s| s.set_hop_limit(hop_limit)) + } + + pub fn local_endpoint(&self) -> IpEndpoint { + self.with(|s| s.local_endpoint()) + } + + pub fn remote_endpoint(&self) -> IpEndpoint { + self.with(|s| s.remote_endpoint()) + } + + pub fn state(&self) -> TcpState { + self.with(|s| s.state()) + } + + pub fn close(&mut self) { + self.with(|s| s.close()) + } + + pub fn abort(&mut self) { + self.with(|s| s.abort()) + } + + pub fn may_send(&self) -> bool { + self.with(|s| s.may_send()) + } + + pub fn may_recv(&self) -> bool { + self.with(|s| s.may_recv()) + } + + fn with<R>(&self, f: impl FnOnce(&mut SyncTcpSocket) -> R) -> R { + Stack::with(|stack| { + let res = { + let mut s = stack.sockets.get::<SyncTcpSocket>(self.handle); + f(&mut *s) + }; + stack.wake(); + res + }) + } +} + +fn to_ioerr(e: Error) -> io::Error { + warn!("smoltcp err: {:?}", e); + // todo + io::Error::Other +} + +impl<'a> Drop for TcpSocket<'a> { + fn drop(&mut self) { + Stack::with(|stack| { + stack.sockets.remove(self.handle); + }) + } +} + +impl<'a> AsyncBufRead for TcpSocket<'a> { + fn poll_fill_buf<'z>( + self: Pin<&'z mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<&'z [u8]>> { + self.with(|socket| match socket.peek(1 << 30) { + // No data ready + Ok(buf) if buf.len() == 0 => { + socket.register_recv_waker(cx.waker()); + Poll::Pending + } + // Data ready! + Ok(buf) => { + // Safety: + // - User can't touch the inner TcpSocket directly at all. + // - The socket itself won't touch these bytes until consume() is called, which + // requires the user to release this borrow. + let buf: &'z [u8] = unsafe { core::mem::transmute(&*buf) }; + Poll::Ready(Ok(buf)) + } + // EOF + Err(Error::Finished) => Poll::Ready(Ok(&[][..])), + // Error + Err(e) => Poll::Ready(Err(to_ioerr(e))), + }) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.with(|s| s.recv(|_| (amt, ()))).unwrap() + } +} + +impl<'a> AsyncWrite for TcpSocket<'a> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.with(|s| match s.send_slice(buf) { + // Not ready to send (no space in the tx buffer) + Ok(0) => { + s.register_send_waker(cx.waker()); + Poll::Pending + } + // Some data sent + Ok(n) => Poll::Ready(Ok(n)), + // Error + Err(e) => Poll::Ready(Err(to_ioerr(e))), + }) + } +} |