diff options
Diffstat (limited to 'embassy-sync')
-rw-r--r-- | embassy-sync/Cargo.toml | 34 | ||||
-rw-r--r-- | embassy-sync/build.rs | 29 | ||||
-rw-r--r-- | embassy-sync/src/blocking_mutex/mod.rs | 189 | ||||
-rw-r--r-- | embassy-sync/src/blocking_mutex/raw.rs | 149 | ||||
-rw-r--r-- | embassy-sync/src/channel/mod.rs | 5 | ||||
-rw-r--r-- | embassy-sync/src/channel/mpmc.rs | 596 | ||||
-rw-r--r-- | embassy-sync/src/channel/pubsub/mod.rs | 542 | ||||
-rw-r--r-- | embassy-sync/src/channel/pubsub/publisher.rs | 182 | ||||
-rw-r--r-- | embassy-sync/src/channel/pubsub/subscriber.rs | 152 | ||||
-rw-r--r-- | embassy-sync/src/channel/signal.rs | 100 | ||||
-rw-r--r-- | embassy-sync/src/fmt.rs | 228 | ||||
-rw-r--r-- | embassy-sync/src/lib.rs | 17 | ||||
-rw-r--r-- | embassy-sync/src/mutex.rs | 167 | ||||
-rw-r--r-- | embassy-sync/src/pipe.rs | 551 | ||||
-rw-r--r-- | embassy-sync/src/ring_buffer.rs | 146 | ||||
-rw-r--r-- | embassy-sync/src/waitqueue/mod.rs | 7 | ||||
-rw-r--r-- | embassy-sync/src/waitqueue/multi_waker.rs | 33 | ||||
-rw-r--r-- | embassy-sync/src/waitqueue/waker.rs | 92 |
18 files changed, 3219 insertions, 0 deletions
diff --git a/embassy-sync/Cargo.toml b/embassy-sync/Cargo.toml new file mode 100644 index 00000000..0d14bba5 --- /dev/null +++ b/embassy-sync/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "embassy-sync" +version = "0.1.0" +edition = "2021" + +[package.metadata.embassy_docs] +src_base = "https://github.com/embassy-rs/embassy/blob/embassy-sync-v$VERSION/embassy-sync/src/" +src_base_git = "https://github.com/embassy-rs/embassy/blob/$COMMIT/embassy-sync/src/" +features = ["nightly"] +target = "thumbv7em-none-eabi" + +[features] +nightly = ["embedded-io/async"] + +[dependencies] +defmt = { version = "0.3", optional = true } +log = { version = "0.4.14", optional = true } + +futures-util = { version = "0.3.17", default-features = false } +atomic-polyfill = "1.0.1" +critical-section = "1.1" +heapless = "0.7.5" +cfg-if = "1.0.0" +embedded-io = "0.3.0" + +[dev-dependencies] +futures-executor = { version = "0.3.17", features = [ "thread-pool" ] } +futures-test = "0.3.17" +futures-timer = "3.0.2" +futures-util = { version = "0.3.17", features = [ "channel" ] } + +# Enable critical-section implementation for std, for tests +critical-section = { version = "1.1", features = ["std"] } +static_cell = "1.0" diff --git a/embassy-sync/build.rs b/embassy-sync/build.rs new file mode 100644 index 00000000..6fe82b44 --- /dev/null +++ b/embassy-sync/build.rs @@ -0,0 +1,29 @@ +use std::env; + +fn main() { + let target = env::var("TARGET").unwrap(); + + if target.starts_with("thumbv6m-") { + println!("cargo:rustc-cfg=cortex_m"); + println!("cargo:rustc-cfg=armv6m"); + } else if target.starts_with("thumbv7m-") { + println!("cargo:rustc-cfg=cortex_m"); + println!("cargo:rustc-cfg=armv7m"); + } else if target.starts_with("thumbv7em-") { + println!("cargo:rustc-cfg=cortex_m"); + println!("cargo:rustc-cfg=armv7m"); + println!("cargo:rustc-cfg=armv7em"); // (not currently used) + } else if target.starts_with("thumbv8m.base") { + println!("cargo:rustc-cfg=cortex_m"); + println!("cargo:rustc-cfg=armv8m"); + println!("cargo:rustc-cfg=armv8m_base"); + } else if target.starts_with("thumbv8m.main") { + println!("cargo:rustc-cfg=cortex_m"); + println!("cargo:rustc-cfg=armv8m"); + println!("cargo:rustc-cfg=armv8m_main"); + } + + if target.ends_with("-eabihf") { + println!("cargo:rustc-cfg=has_fpu"); + } +} diff --git a/embassy-sync/src/blocking_mutex/mod.rs b/embassy-sync/src/blocking_mutex/mod.rs new file mode 100644 index 00000000..8a4a4c64 --- /dev/null +++ b/embassy-sync/src/blocking_mutex/mod.rs @@ -0,0 +1,189 @@ +//! Blocking mutex. +//! +//! This module provides a blocking mutex that can be used to synchronize data. +pub mod raw; + +use core::cell::UnsafeCell; + +use self::raw::RawMutex; + +/// Blocking mutex (not async) +/// +/// Provides a blocking mutual exclusion primitive backed by an implementation of [`raw::RawMutex`]. +/// +/// Which implementation you select depends on the context in which you're using the mutex, and you can choose which kind +/// of interior mutability fits your use case. +/// +/// Use [`CriticalSectionMutex`] when data can be shared between threads and interrupts. +/// +/// Use [`NoopMutex`] when data is only shared between tasks running on the same executor. +/// +/// Use [`ThreadModeMutex`] when data is shared between tasks running on the same executor but you want a global singleton. +/// +/// In all cases, the blocking mutex is intended to be short lived and not held across await points. +/// Use the async [`Mutex`](crate::mutex::Mutex) if you need a lock that is held across await points. +pub struct Mutex<R, T: ?Sized> { + // NOTE: `raw` must be FIRST, so when using ThreadModeMutex the "can't drop in non-thread-mode" gets + // to run BEFORE dropping `data`. + raw: R, + data: UnsafeCell<T>, +} + +unsafe impl<R: RawMutex + Send, T: ?Sized + Send> Send for Mutex<R, T> {} +unsafe impl<R: RawMutex + Sync, T: ?Sized + Send> Sync for Mutex<R, T> {} + +impl<R: RawMutex, T> Mutex<R, T> { + /// Creates a new mutex in an unlocked state ready for use. + #[inline] + pub const fn new(val: T) -> Mutex<R, T> { + Mutex { + raw: R::INIT, + data: UnsafeCell::new(val), + } + } + + /// Creates a critical section and grants temporary access to the protected data. + pub fn lock<U>(&self, f: impl FnOnce(&T) -> U) -> U { + self.raw.lock(|| { + let ptr = self.data.get() as *const T; + let inner = unsafe { &*ptr }; + f(inner) + }) + } +} + +impl<R, T> Mutex<R, T> { + /// Creates a new mutex based on a pre-existing raw mutex. + /// + /// This allows creating a mutex in a constant context on stable Rust. + #[inline] + pub const fn const_new(raw_mutex: R, val: T) -> Mutex<R, T> { + Mutex { + raw: raw_mutex, + data: UnsafeCell::new(val), + } + } + + /// Consumes this mutex, returning the underlying data. + #[inline] + pub fn into_inner(self) -> T { + self.data.into_inner() + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `Mutex` mutably, no actual locking needs to + /// take place---the mutable borrow statically guarantees no locks exist. + #[inline] + pub fn get_mut(&mut self) -> &mut T { + unsafe { &mut *self.data.get() } + } +} + +/// A mutex that allows borrowing data across executors and interrupts. +/// +/// # Safety +/// +/// This mutex is safe to share between different executors and interrupts. +pub type CriticalSectionMutex<T> = Mutex<raw::CriticalSectionRawMutex, T>; + +/// A mutex that allows borrowing data in the context of a single executor. +/// +/// # Safety +/// +/// **This Mutex is only safe within a single executor.** +pub type NoopMutex<T> = Mutex<raw::NoopRawMutex, T>; + +impl<T> Mutex<raw::CriticalSectionRawMutex, T> { + /// Borrows the data for the duration of the critical section + pub fn borrow<'cs>(&'cs self, _cs: critical_section::CriticalSection<'cs>) -> &'cs T { + let ptr = self.data.get() as *const T; + unsafe { &*ptr } + } +} + +impl<T> Mutex<raw::NoopRawMutex, T> { + /// Borrows the data + pub fn borrow(&self) -> &T { + let ptr = self.data.get() as *const T; + unsafe { &*ptr } + } +} + +// ThreadModeMutex does NOT use the generic mutex from above because it's special: +// it's Send+Sync even if T: !Send. There's no way to do that without specialization (I think?). +// +// There's still a ThreadModeRawMutex for use with the generic Mutex (handy with Channel, for example), +// but that will require T: Send even though it shouldn't be needed. + +#[cfg(any(cortex_m, feature = "std"))] +pub use thread_mode_mutex::*; +#[cfg(any(cortex_m, feature = "std"))] +mod thread_mode_mutex { + use super::*; + + /// A "mutex" that only allows borrowing from thread mode. + /// + /// # Safety + /// + /// **This Mutex is only safe on single-core systems.** + /// + /// On multi-core systems, a `ThreadModeMutex` **is not sufficient** to ensure exclusive access. + pub struct ThreadModeMutex<T: ?Sized> { + inner: UnsafeCell<T>, + } + + // NOTE: ThreadModeMutex only allows borrowing from one execution context ever: thread mode. + // Therefore it cannot be used to send non-sendable stuff between execution contexts, so it can + // be Send+Sync even if T is not Send (unlike CriticalSectionMutex) + unsafe impl<T: ?Sized> Sync for ThreadModeMutex<T> {} + unsafe impl<T: ?Sized> Send for ThreadModeMutex<T> {} + + impl<T> ThreadModeMutex<T> { + /// Creates a new mutex + pub const fn new(value: T) -> Self { + ThreadModeMutex { + inner: UnsafeCell::new(value), + } + } + } + + impl<T: ?Sized> ThreadModeMutex<T> { + /// Lock the `ThreadModeMutex`, granting access to the data. + /// + /// # Panics + /// + /// This will panic if not currently running in thread mode. + pub fn lock<R>(&self, f: impl FnOnce(&T) -> R) -> R { + f(self.borrow()) + } + + /// Borrows the data + /// + /// # Panics + /// + /// This will panic if not currently running in thread mode. + pub fn borrow(&self) -> &T { + assert!( + raw::in_thread_mode(), + "ThreadModeMutex can only be borrowed from thread mode." + ); + unsafe { &*self.inner.get() } + } + } + + impl<T: ?Sized> Drop for ThreadModeMutex<T> { + fn drop(&mut self) { + // Only allow dropping from thread mode. Dropping calls drop on the inner `T`, so + // `drop` needs the same guarantees as `lock`. `ThreadModeMutex<T>` is Send even if + // T isn't, so without this check a user could create a ThreadModeMutex in thread mode, + // send it to interrupt context and drop it there, which would "send" a T even if T is not Send. + assert!( + raw::in_thread_mode(), + "ThreadModeMutex can only be dropped from thread mode." + ); + + // Drop of the inner `T` happens after this. + } + } +} diff --git a/embassy-sync/src/blocking_mutex/raw.rs b/embassy-sync/src/blocking_mutex/raw.rs new file mode 100644 index 00000000..15796f1b --- /dev/null +++ b/embassy-sync/src/blocking_mutex/raw.rs @@ -0,0 +1,149 @@ +//! Mutex primitives. +//! +//! This module provides a trait for mutexes that can be used in different contexts. +use core::marker::PhantomData; + +/// Raw mutex trait. +/// +/// This mutex is "raw", which means it does not actually contain the protected data, it +/// just implements the mutex mechanism. For most uses you should use [`super::Mutex`] instead, +/// which is generic over a RawMutex and contains the protected data. +/// +/// Note that, unlike other mutexes, implementations only guarantee no +/// concurrent access from other threads: concurrent access from the current +/// thread is allwed. For example, it's possible to lock the same mutex multiple times reentrantly. +/// +/// Therefore, locking a `RawMutex` is only enough to guarantee safe shared (`&`) access +/// to the data, it is not enough to guarantee exclusive (`&mut`) access. +/// +/// # Safety +/// +/// RawMutex implementations must ensure that, while locked, no other thread can lock +/// the RawMutex concurrently. +/// +/// Unsafe code is allowed to rely on this fact, so incorrect implementations will cause undefined behavior. +pub unsafe trait RawMutex { + /// Create a new `RawMutex` instance. + /// + /// This is a const instead of a method to allow creating instances in const context. + const INIT: Self; + + /// Lock this `RawMutex`. + fn lock<R>(&self, f: impl FnOnce() -> R) -> R; +} + +/// A mutex that allows borrowing data across executors and interrupts. +/// +/// # Safety +/// +/// This mutex is safe to share between different executors and interrupts. +pub struct CriticalSectionRawMutex { + _phantom: PhantomData<()>, +} +unsafe impl Send for CriticalSectionRawMutex {} +unsafe impl Sync for CriticalSectionRawMutex {} + +impl CriticalSectionRawMutex { + /// Create a new `CriticalSectionRawMutex`. + pub const fn new() -> Self { + Self { _phantom: PhantomData } + } +} + +unsafe impl RawMutex for CriticalSectionRawMutex { + const INIT: Self = Self::new(); + + fn lock<R>(&self, f: impl FnOnce() -> R) -> R { + critical_section::with(|_| f()) + } +} + +// ================ + +/// A mutex that allows borrowing data in the context of a single executor. +/// +/// # Safety +/// +/// **This Mutex is only safe within a single executor.** +pub struct NoopRawMutex { + _phantom: PhantomData<*mut ()>, +} + +unsafe impl Send for NoopRawMutex {} + +impl NoopRawMutex { + /// Create a new `NoopRawMutex`. + pub const fn new() -> Self { + Self { _phantom: PhantomData } + } +} + +unsafe impl RawMutex for NoopRawMutex { + const INIT: Self = Self::new(); + fn lock<R>(&self, f: impl FnOnce() -> R) -> R { + f() + } +} + +// ================ + +#[cfg(any(cortex_m, feature = "std"))] +mod thread_mode { + use super::*; + + /// A "mutex" that only allows borrowing from thread mode. + /// + /// # Safety + /// + /// **This Mutex is only safe on single-core systems.** + /// + /// On multi-core systems, a `ThreadModeRawMutex` **is not sufficient** to ensure exclusive access. + pub struct ThreadModeRawMutex { + _phantom: PhantomData<()>, + } + + unsafe impl Send for ThreadModeRawMutex {} + unsafe impl Sync for ThreadModeRawMutex {} + + impl ThreadModeRawMutex { + /// Create a new `ThreadModeRawMutex`. + pub const fn new() -> Self { + Self { _phantom: PhantomData } + } + } + + unsafe impl RawMutex for ThreadModeRawMutex { + const INIT: Self = Self::new(); + fn lock<R>(&self, f: impl FnOnce() -> R) -> R { + assert!(in_thread_mode(), "ThreadModeMutex can only be locked from thread mode."); + + f() + } + } + + impl Drop for ThreadModeRawMutex { + fn drop(&mut self) { + // Only allow dropping from thread mode. Dropping calls drop on the inner `T`, so + // `drop` needs the same guarantees as `lock`. `ThreadModeMutex<T>` is Send even if + // T isn't, so without this check a user could create a ThreadModeMutex in thread mode, + // send it to interrupt context and drop it there, which would "send" a T even if T is not Send. + assert!( + in_thread_mode(), + "ThreadModeMutex can only be dropped from thread mode." + ); + + // Drop of the inner `T` happens after this. + } + } + + pub(crate) fn in_thread_mode() -> bool { + #[cfg(feature = "std")] + return Some("main") == std::thread::current().name(); + + #[cfg(not(feature = "std"))] + // ICSR.VECTACTIVE == 0 + return unsafe { (0xE000ED04 as *const u32).read_volatile() } & 0x1FF == 0; + } +} +#[cfg(any(cortex_m, feature = "std"))] +pub use thread_mode::*; diff --git a/embassy-sync/src/channel/mod.rs b/embassy-sync/src/channel/mod.rs new file mode 100644 index 00000000..5df1f5c5 --- /dev/null +++ b/embassy-sync/src/channel/mod.rs @@ -0,0 +1,5 @@ +//! Async channels + +pub mod mpmc; +pub mod pubsub; +pub mod signal; diff --git a/embassy-sync/src/channel/mpmc.rs b/embassy-sync/src/channel/mpmc.rs new file mode 100644 index 00000000..7bebd341 --- /dev/null +++ b/embassy-sync/src/channel/mpmc.rs @@ -0,0 +1,596 @@ +//! A queue for sending values between asynchronous tasks. +//! +//! It can be used concurrently by multiple producers (senders) and multiple +//! consumers (receivers), i.e. it is an "MPMC channel". +//! +//! Receivers are competing for messages. So a message that is received by +//! one receiver is not received by any other. +//! +//! This queue takes a Mutex type so that various +//! targets can be attained. For example, a ThreadModeMutex can be used +//! for single-core Cortex-M targets where messages are only passed +//! between tasks running in thread mode. Similarly, a CriticalSectionMutex +//! can also be used for single-core targets where messages are to be +//! passed from exception mode e.g. out of an interrupt handler. +//! +//! This module provides a bounded channel that has a limit on the number of +//! messages that it can store, and if this limit is reached, trying to send +//! another message will result in an error being returned. +//! + +use core::cell::RefCell; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use heapless::Deque; + +use crate::blocking_mutex::raw::RawMutex; +use crate::blocking_mutex::Mutex; +use crate::waitqueue::WakerRegistration; + +/// Send-only access to a [`Channel`]. +#[derive(Copy)] +pub struct Sender<'ch, M, T, const N: usize> +where + M: RawMutex, +{ + channel: &'ch Channel<M, T, N>, +} + +impl<'ch, M, T, const N: usize> Clone for Sender<'ch, M, T, N> +where + M: RawMutex, +{ + fn clone(&self) -> Self { + Sender { channel: self.channel } + } +} + +impl<'ch, M, T, const N: usize> Sender<'ch, M, T, N> +where + M: RawMutex, +{ + /// Sends a value. + /// + /// See [`Channel::send()`] + pub fn send(&self, message: T) -> SendFuture<'ch, M, T, N> { + self.channel.send(message) + } + + /// Attempt to immediately send a message. + /// + /// See [`Channel::send()`] + pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { + self.channel.try_send(message) + } +} + +/// Send-only access to a [`Channel`] without knowing channel size. +#[derive(Copy)] +pub struct DynamicSender<'ch, T> { + channel: &'ch dyn DynamicChannel<T>, +} + +impl<'ch, T> Clone for DynamicSender<'ch, T> { + fn clone(&self) -> Self { + DynamicSender { channel: self.channel } + } +} + +impl<'ch, M, T, const N: usize> From<Sender<'ch, M, T, N>> for DynamicSender<'ch, T> +where + M: RawMutex, +{ + fn from(s: Sender<'ch, M, T, N>) -> Self { + Self { channel: s.channel } + } +} + +impl<'ch, T> DynamicSender<'ch, T> { + /// Sends a value. + /// + /// See [`Channel::send()`] + pub fn send(&self, message: T) -> DynamicSendFuture<'ch, T> { + DynamicSendFuture { + channel: self.channel, + message: Some(message), + } + } + + /// Attempt to immediately send a message. + /// + /// See [`Channel::send()`] + pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { + self.channel.try_send_with_context(message, None) + } +} + +/// Receive-only access to a [`Channel`]. +#[derive(Copy)] +pub struct Receiver<'ch, M, T, const N: usize> +where + M: RawMutex, +{ + channel: &'ch Channel<M, T, N>, +} + +impl<'ch, M, T, const N: usize> Clone for Receiver<'ch, M, T, N> +where + M: RawMutex, +{ + fn clone(&self) -> Self { + Receiver { channel: self.channel } + } +} + +impl<'ch, M, T, const N: usize> Receiver<'ch, M, T, N> +where + M: RawMutex, +{ + /// Receive the next value. + /// + /// See [`Channel::recv()`]. + pub fn recv(&self) -> RecvFuture<'_, M, T, N> { + self.channel.recv() + } + + /// Attempt to immediately receive the next value. + /// + /// See [`Channel::try_recv()`] + pub fn try_recv(&self) -> Result<T, TryRecvError> { + self.channel.try_recv() + } +} + +/// Receive-only access to a [`Channel`] without knowing channel size. +#[derive(Copy)] +pub struct DynamicReceiver<'ch, T> { + channel: &'ch dyn DynamicChannel<T>, +} + +impl<'ch, T> Clone for DynamicReceiver<'ch, T> { + fn clone(&self) -> Self { + DynamicReceiver { channel: self.channel } + } +} + +impl<'ch, T> DynamicReceiver<'ch, T> { + /// Receive the next value. + /// + /// See [`Channel::recv()`]. + pub fn recv(&self) -> DynamicRecvFuture<'_, T> { + DynamicRecvFuture { channel: self.channel } + } + + /// Attempt to immediately receive the next value. + /// + /// See [`Channel::try_recv()`] + pub fn try_recv(&self) -> Result<T, TryRecvError> { + self.channel.try_recv_with_context(None) + } +} + +impl<'ch, M, T, const N: usize> From<Receiver<'ch, M, T, N>> for DynamicReceiver<'ch, T> +where + M: RawMutex, +{ + fn from(s: Receiver<'ch, M, T, N>) -> Self { + Self { channel: s.channel } + } +} + +/// Future returned by [`Channel::recv`] and [`Receiver::recv`]. +pub struct RecvFuture<'ch, M, T, const N: usize> +where + M: RawMutex, +{ + channel: &'ch Channel<M, T, N>, +} + +impl<'ch, M, T, const N: usize> Future for RecvFuture<'ch, M, T, N> +where + M: RawMutex, +{ + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> { + match self.channel.try_recv_with_context(Some(cx)) { + Ok(v) => Poll::Ready(v), + Err(TryRecvError::Empty) => Poll::Pending, + } + } +} + +/// Future returned by [`DynamicReceiver::recv`]. +pub struct DynamicRecvFuture<'ch, T> { + channel: &'ch dyn DynamicChannel<T>, +} + +impl<'ch, T> Future for DynamicRecvFuture<'ch, T> { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> { + match self.channel.try_recv_with_context(Some(cx)) { + Ok(v) => Poll::Ready(v), + Err(TryRecvError::Empty) => Poll::Pending, + } + } +} + +/// Future returned by [`Channel::send`] and [`Sender::send`]. +pub struct SendFuture<'ch, M, T, const N: usize> +where + M: RawMutex, +{ + channel: &'ch Channel<M, T, N>, + message: Option<T>, +} + +impl<'ch, M, T, const N: usize> Future for SendFuture<'ch, M, T, N> +where + M: RawMutex, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match self.message.take() { + Some(m) => match self.channel.try_send_with_context(m, Some(cx)) { + Ok(..) => Poll::Ready(()), + Err(TrySendError::Full(m)) => { + self.message = Some(m); + Poll::Pending + } + }, + None => panic!("Message cannot be None"), + } + } +} + +impl<'ch, M, T, const N: usize> Unpin for SendFuture<'ch, M, T, N> where M: RawMutex {} + +/// Future returned by [`DynamicSender::send`]. +pub struct DynamicSendFuture<'ch, T> { + channel: &'ch dyn DynamicChannel<T>, + message: Option<T>, +} + +impl<'ch, T> Future for DynamicSendFuture<'ch, T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match self.message.take() { + Some(m) => match self.channel.try_send_with_context(m, Some(cx)) { + Ok(..) => Poll::Ready(()), + Err(TrySendError::Full(m)) => { + self.message = Some(m); + Poll::Pending + } + }, + None => panic!("Message cannot be None"), + } + } +} + +impl<'ch, T> Unpin for DynamicSendFuture<'ch, T> {} + +trait DynamicChannel<T> { + fn try_send_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>>; + + fn try_recv_with_context(&self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError>; +} + +/// Error returned by [`try_recv`](Channel::try_recv). +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum TryRecvError { + /// A message could not be received because the channel is empty. + Empty, +} + +/// Error returned by [`try_send`](Channel::try_send). +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum TrySendError<T> { + /// The data could not be sent on the channel because the channel is + /// currently full and sending would require blocking. + Full(T), +} + +struct ChannelState<T, const N: usize> { + queue: Deque<T, N>, + receiver_waker: WakerRegistration, + senders_waker: WakerRegistration, +} + +impl<T, const N: usize> ChannelState<T, N> { + const fn new() -> Self { + ChannelState { + queue: Deque::new(), + receiver_waker: WakerRegistration::new(), + senders_waker: WakerRegistration::new(), + } + } + + fn try_recv(&mut self) -> Result<T, TryRecvError> { + self.try_recv_with_context(None) + } + + fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> { + if self.queue.is_full() { + self.senders_waker.wake(); + } + + if let Some(message) = self.queue.pop_front() { + Ok(message) + } else { + if let Some(cx) = cx { + self.receiver_waker.register(cx.waker()); + } + Err(TryRecvError::Empty) + } + } + + fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> { + self.try_send_with_context(message, None) + } + + fn try_send_with_context(&mut self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>> { + match self.queue.push_back(message) { + Ok(()) => { + self.receiver_waker.wake(); + Ok(()) + } + Err(message) => { + if let Some(cx) = cx { + self.senders_waker.register(cx.waker()); + } + Err(TrySendError::Full(message)) + } + } + } +} + +/// A bounded channel for communicating between asynchronous tasks +/// with backpressure. +/// +/// The channel will buffer up to the provided number of messages. Once the +/// buffer is full, attempts to `send` new messages will wait until a message is +/// received from the channel. +/// +/// All data sent will become available in the same order as it was sent. +pub struct Channel<M, T, const N: usize> +where + M: RawMutex, +{ + inner: Mutex<M, RefCell<ChannelState<T, N>>>, +} + +impl<M, T, const N: usize> Channel<M, T, N> +where + M: RawMutex, +{ + /// Establish a new bounded channel. For example, to create one with a NoopMutex: + /// + /// ``` + /// use embassy_sync::channel::mpmc::Channel; + /// use embassy_sync::blocking_mutex::raw::NoopRawMutex; + /// + /// // Declare a bounded channel of 3 u32s. + /// let mut channel = Channel::<NoopRawMutex, u32, 3>::new(); + /// ``` + pub const fn new() -> Self { + Self { + inner: Mutex::new(RefCell::new(ChannelState::new())), + } + } + + fn lock<R>(&self, f: impl FnOnce(&mut ChannelState<T, N>) -> R) -> R { + self.inner.lock(|rc| f(&mut *rc.borrow_mut())) + } + + fn try_recv_with_context(&self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> { + self.lock(|c| c.try_recv_with_context(cx)) + } + + fn try_send_with_context(&self, m: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>> { + self.lock(|c| c.try_send_with_context(m, cx)) + } + + /// Get a sender for this channel. + pub fn sender(&self) -> Sender<'_, M, T, N> { + Sender { channel: self } + } + + /// Get a receiver for this channel. + pub fn receiver(&self) -> Receiver<'_, M, T, N> { + Receiver { channel: self } + } + + /// Send a value, waiting until there is capacity. + /// + /// Sending completes when the value has been pushed to the channel's queue. + /// This doesn't mean the value has been received yet. + pub fn send(&self, message: T) -> SendFuture<'_, M, T, N> { + SendFuture { + channel: self, + message: Some(message), + } + } + + /// Attempt to immediately send a message. + /// + /// This method differs from [`send`](Channel::send) by returning immediately if the channel's + /// buffer is full, instead of waiting. + /// + /// # Errors + /// + /// If the channel capacity has been reached, i.e., the channel has `n` + /// buffered values where `n` is the argument passed to [`Channel`], then an + /// error is returned. + pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { + self.lock(|c| c.try_send(message)) + } + + /// Receive the next value. + /// + /// If there are no messages in the channel's buffer, this method will + /// wait until a message is sent. + pub fn recv(&self) -> RecvFuture<'_, M, T, N> { + RecvFuture { channel: self } + } + + /// Attempt to immediately receive a message. + /// + /// This method will either receive a message from the channel immediately or return an error + /// if the channel is empty. + pub fn try_recv(&self) -> Result<T, TryRecvError> { + self.lock(|c| c.try_recv()) + } +} + +/// Implements the DynamicChannel to allow creating types that are unaware of the queue size with the +/// tradeoff cost of dynamic dispatch. +impl<M, T, const N: usize> DynamicChannel<T> for Channel<M, T, N> +where + M: RawMutex, +{ + fn try_send_with_context(&self, m: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>> { + Channel::try_send_with_context(self, m, cx) + } + + fn try_recv_with_context(&self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> { + Channel::try_recv_with_context(self, cx) + } +} + +#[cfg(test)] +mod tests { + use core::time::Duration; + + use futures_executor::ThreadPool; + use futures_timer::Delay; + use futures_util::task::SpawnExt; + use static_cell::StaticCell; + + use super::*; + use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex}; + + fn capacity<T, const N: usize>(c: &ChannelState<T, N>) -> usize { + c.queue.capacity() - c.queue.len() + } + + #[test] + fn sending_once() { + let mut c = ChannelState::<u32, 3>::new(); + assert!(c.try_send(1).is_ok()); + assert_eq!(capacity(&c), 2); + } + + #[test] + fn sending_when_full() { + let mut c = ChannelState::<u32, 3>::new(); + let _ = c.try_send(1); + let _ = c.try_send(1); + let _ = c.try_send(1); + match c.try_send(2) { + Err(TrySendError::Full(2)) => assert!(true), + _ => assert!(false), + } + assert_eq!(capacity(&c), 0); + } + + #[test] + fn receiving_once_with_one_send() { + let mut c = ChannelState::<u32, 3>::new(); + assert!(c.try_send(1).is_ok()); + assert_eq!(c.try_recv().unwrap(), 1); + assert_eq!(capacity(&c), 3); + } + + #[test] + fn receiving_when_empty() { + let mut c = ChannelState::<u32, 3>::new(); + match c.try_recv() { + Err(TryRecvError::Empty) => assert!(true), + _ => assert!(false), + } + assert_eq!(capacity(&c), 3); + } + + #[test] + fn simple_send_and_receive() { + let c = Channel::<NoopRawMutex, u32, 3>::new(); + assert!(c.try_send(1).is_ok()); + assert_eq!(c.try_recv().unwrap(), 1); + } + + #[test] + fn cloning() { + let c = Channel::<NoopRawMutex, u32, 3>::new(); + let r1 = c.receiver(); + let s1 = c.sender(); + + let _ = r1.clone(); + let _ = s1.clone(); + } + + #[test] + fn dynamic_dispatch() { + let c = Channel::<NoopRawMutex, u32, 3>::new(); + let s: DynamicSender<'_, u32> = c.sender().into(); + let r: DynamicReceiver<'_, u32> = c.receiver().into(); + + assert!(s.try_send(1).is_ok()); + assert_eq!(r.try_recv().unwrap(), 1); + } + + #[futures_test::test] + async fn receiver_receives_given_try_send_async() { + let executor = ThreadPool::new().unwrap(); + + static CHANNEL: StaticCell<Channel<CriticalSectionRawMutex, u32, 3>> = StaticCell::new(); + let c = &*CHANNEL.init(Channel::new()); + let c2 = c; + assert!(executor + .spawn(async move { + assert!(c2.try_send(1).is_ok()); + }) + .is_ok()); + assert_eq!(c.recv().await, 1); + } + + #[futures_test::test] + async fn sender_send_completes_if_capacity() { + let c = Channel::<CriticalSectionRawMutex, u32, 1>::new(); + c.send(1).await; + assert_eq!(c.recv().await, 1); + } + + #[futures_test::test] + async fn senders_sends_wait_until_capacity() { + let executor = ThreadPool::new().unwrap(); + + static CHANNEL: StaticCell<Channel<CriticalSectionRawMutex, u32, 1>> = StaticCell::new(); + let c = &*CHANNEL.init(Channel::new()); + assert!(c.try_send(1).is_ok()); + + let c2 = c; + let send_task_1 = executor.spawn_with_handle(async move { c2.send(2).await }); + let c2 = c; + let send_task_2 = executor.spawn_with_handle(async move { c2.send(3).await }); + // Wish I could think of a means of determining that the async send is waiting instead. + // However, I've used the debugger to observe that the send does indeed wait. + Delay::new(Duration::from_millis(500)).await; + assert_eq!(c.recv().await, 1); + assert!(executor + .spawn(async move { + loop { + c.recv().await; + } + }) + .is_ok()); + send_task_1.unwrap().await; + send_task_2.unwrap().await; + } +} diff --git a/embassy-sync/src/channel/pubsub/mod.rs b/embassy-sync/src/channel/pubsub/mod.rs new file mode 100644 index 00000000..f62b4d11 --- /dev/null +++ b/embassy-sync/src/channel/pubsub/mod.rs @@ -0,0 +1,542 @@ +//! Implementation of [PubSubChannel], a queue where published messages get received by all subscribers. + +#![deny(missing_docs)] + +use core::cell::RefCell; +use core::fmt::Debug; +use core::task::{Context, Poll, Waker}; + +use heapless::Deque; + +use self::publisher::{ImmediatePub, Pub}; +use self::subscriber::Sub; +use crate::blocking_mutex::raw::RawMutex; +use crate::blocking_mutex::Mutex; +use crate::waitqueue::MultiWakerRegistration; + +pub mod publisher; +pub mod subscriber; + +pub use publisher::{DynImmediatePublisher, DynPublisher, ImmediatePublisher, Publisher}; +pub use subscriber::{DynSubscriber, Subscriber}; + +/// A broadcast channel implementation where multiple publishers can send messages to multiple subscribers +/// +/// Any published message can be read by all subscribers. +/// A publisher can choose how it sends its message. +/// +/// - With [Pub::publish()] the publisher has to wait until there is space in the internal message queue. +/// - With [Pub::publish_immediate()] the publisher doesn't await and instead lets the oldest message +/// in the queue drop if necessary. This will cause any [Subscriber] that missed the message to receive +/// an error to indicate that it has lagged. +/// +/// ## Example +/// +/// ``` +/// # use embassy_sync::blocking_mutex::raw::NoopRawMutex; +/// # use embassy_sync::channel::pubsub::WaitResult; +/// # use embassy_sync::channel::pubsub::PubSubChannel; +/// # use futures_executor::block_on; +/// # let test = async { +/// // Create the channel. This can be static as well +/// let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new(); +/// +/// // This is a generic subscriber with a direct reference to the channel +/// let mut sub0 = channel.subscriber().unwrap(); +/// // This is a dynamic subscriber with a dynamic (trait object) reference to the channel +/// let mut sub1 = channel.dyn_subscriber().unwrap(); +/// +/// let pub0 = channel.publisher().unwrap(); +/// +/// // Publish a message, but wait if the queue is full +/// pub0.publish(42).await; +/// +/// // Publish a message, but if the queue is full, just kick out the oldest message. +/// // This may cause some subscribers to miss a message +/// pub0.publish_immediate(43); +/// +/// // Wait for a new message. If the subscriber missed a message, the WaitResult will be a Lag result +/// assert_eq!(sub0.next_message().await, WaitResult::Message(42)); +/// assert_eq!(sub1.next_message().await, WaitResult::Message(42)); +/// +/// // Wait again, but this time ignore any Lag results +/// assert_eq!(sub0.next_message_pure().await, 43); +/// assert_eq!(sub1.next_message_pure().await, 43); +/// +/// // There's also a polling interface +/// assert_eq!(sub0.try_next_message(), None); +/// assert_eq!(sub1.try_next_message(), None); +/// # }; +/// # +/// # block_on(test); +/// ``` +/// +pub struct PubSubChannel<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> { + inner: Mutex<M, RefCell<PubSubState<T, CAP, SUBS, PUBS>>>, +} + +impl<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> + PubSubChannel<M, T, CAP, SUBS, PUBS> +{ + /// Create a new channel + pub const fn new() -> Self { + Self { + inner: Mutex::const_new(M::INIT, RefCell::new(PubSubState::new())), + } + } + + /// Create a new subscriber. It will only receive messages that are published after its creation. + /// + /// If there are no subscriber slots left, an error will be returned. + pub fn subscriber(&self) -> Result<Subscriber<M, T, CAP, SUBS, PUBS>, Error> { + self.inner.lock(|inner| { + let mut s = inner.borrow_mut(); + + if s.subscriber_count >= SUBS { + Err(Error::MaximumSubscribersReached) + } else { + s.subscriber_count += 1; + Ok(Subscriber(Sub::new(s.next_message_id, self))) + } + }) + } + + /// Create a new subscriber. It will only receive messages that are published after its creation. + /// + /// If there are no subscriber slots left, an error will be returned. + pub fn dyn_subscriber(&self) -> Result<DynSubscriber<'_, T>, Error> { + self.inner.lock(|inner| { + let mut s = inner.borrow_mut(); + + if s.subscriber_count >= SUBS { + Err(Error::MaximumSubscribersReached) + } else { + s.subscriber_count += 1; + Ok(DynSubscriber(Sub::new(s.next_message_id, self))) + } + }) + } + + /// Create a new publisher + /// + /// If there are no publisher slots left, an error will be returned. + pub fn publisher(&self) -> Result<Publisher<M, T, CAP, SUBS, PUBS>, Error> { + self.inner.lock(|inner| { + let mut s = inner.borrow_mut(); + + if s.publisher_count >= PUBS { + Err(Error::MaximumPublishersReached) + } else { + s.publisher_count += 1; + Ok(Publisher(Pub::new(self))) + } + }) + } + + /// Create a new publisher + /// + /// If there are no publisher slots left, an error will be returned. + pub fn dyn_publisher(&self) -> Result<DynPublisher<'_, T>, Error> { + self.inner.lock(|inner| { + let mut s = inner.borrow_mut(); + + if s.publisher_count >= PUBS { + Err(Error::MaximumPublishersReached) + } else { + s.publisher_count += 1; + Ok(DynPublisher(Pub::new(self))) + } + }) + } + + /// Create a new publisher that can only send immediate messages. + /// This kind of publisher does not take up a publisher slot. + pub fn immediate_publisher(&self) -> ImmediatePublisher<M, T, CAP, SUBS, PUBS> { + ImmediatePublisher(ImmediatePub::new(self)) + } + + /// Create a new publisher that can only send immediate messages. + /// This kind of publisher does not take up a publisher slot. + pub fn dyn_immediate_publisher(&self) -> DynImmediatePublisher<T> { + DynImmediatePublisher(ImmediatePub::new(self)) + } +} + +impl<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubBehavior<T> + for PubSubChannel<M, T, CAP, SUBS, PUBS> +{ + fn get_message_with_context(&self, next_message_id: &mut u64, cx: Option<&mut Context<'_>>) -> Poll<WaitResult<T>> { + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + + // Check if we can read a message + match s.get_message(*next_message_id) { + // Yes, so we are done polling + Some(WaitResult::Message(message)) => { + *next_message_id += 1; + Poll::Ready(WaitResult::Message(message)) + } + // No, so we need to reregister our waker and sleep again + None => { + if let Some(cx) = cx { + s.register_subscriber_waker(cx.waker()); + } + Poll::Pending + } + // We missed a couple of messages. We must do our internal bookkeeping and return that we lagged + Some(WaitResult::Lagged(amount)) => { + *next_message_id += amount; + Poll::Ready(WaitResult::Lagged(amount)) + } + } + }) + } + + fn publish_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), T> { + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + // Try to publish the message + match s.try_publish(message) { + // We did it, we are ready + Ok(()) => Ok(()), + // The queue is full, so we need to reregister our waker and go to sleep + Err(message) => { + if let Some(cx) = cx { + s.register_publisher_waker(cx.waker()); + } + Err(message) + } + } + }) + } + + fn publish_immediate(&self, message: T) { + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + s.publish_immediate(message) + }) + } + + fn unregister_subscriber(&self, subscriber_next_message_id: u64) { + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + s.unregister_subscriber(subscriber_next_message_id) + }) + } + + fn unregister_publisher(&self) { + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + s.unregister_publisher() + }) + } +} + +/// Internal state for the PubSub channel +struct PubSubState<T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> { + /// The queue contains the last messages that have been published and a countdown of how many subscribers are yet to read it + queue: Deque<(T, usize), CAP>, + /// Every message has an id. + /// Don't worry, we won't run out. + /// If a million messages were published every second, then the ID's would run out in about 584942 years. + next_message_id: u64, + /// Collection of wakers for Subscribers that are waiting. + subscriber_wakers: MultiWakerRegistration<SUBS>, + /// Collection of wakers for Publishers that are waiting. + publisher_wakers: MultiWakerRegistration<PUBS>, + /// The amount of subscribers that are active + subscriber_count: usize, + /// The amount of publishers that are active + publisher_count: usize, +} + +impl<T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubState<T, CAP, SUBS, PUBS> { + /// Create a new internal channel state + const fn new() -> Self { + Self { + queue: Deque::new(), + next_message_id: 0, + subscriber_wakers: MultiWakerRegistration::new(), + publisher_wakers: MultiWakerRegistration::new(), + subscriber_count: 0, + publisher_count: 0, + } + } + + fn try_publish(&mut self, message: T) -> Result<(), T> { + if self.subscriber_count == 0 { + // We don't need to publish anything because there is no one to receive it + return Ok(()); + } + + if self.queue.is_full() { + return Err(message); + } + // We just did a check for this + self.queue.push_back((message, self.subscriber_count)).ok().unwrap(); + + self.next_message_id += 1; + + // Wake all of the subscribers + self.subscriber_wakers.wake(); + + Ok(()) + } + + fn publish_immediate(&mut self, message: T) { + // Make space in the queue if required + if self.queue.is_full() { + self.queue.pop_front(); + } + + // This will succeed because we made sure there is space + self.try_publish(message).ok().unwrap(); + } + + fn get_message(&mut self, message_id: u64) -> Option<WaitResult<T>> { + let start_id = self.next_message_id - self.queue.len() as u64; + + if message_id < start_id { + return Some(WaitResult::Lagged(start_id - message_id)); + } + + let current_message_index = (message_id - start_id) as usize; + + if current_message_index >= self.queue.len() { + return None; + } + + // We've checked that the index is valid + let queue_item = self.queue.iter_mut().nth(current_message_index).unwrap(); + + // We're reading this item, so decrement the counter + queue_item.1 -= 1; + let message = queue_item.0.clone(); + + if current_message_index == 0 && queue_item.1 == 0 { + self.queue.pop_front(); + self.publisher_wakers.wake(); + } + + Some(WaitResult::Message(message)) + } + + fn register_subscriber_waker(&mut self, waker: &Waker) { + match self.subscriber_wakers.register(waker) { + Ok(()) => {} + Err(_) => { + // All waker slots were full. This can only happen when there was a subscriber that now has dropped. + // We need to throw it away. It's a bit inefficient, but we can wake everything. + // Any future that is still active will simply reregister. + // This won't happen a lot, so it's ok. + self.subscriber_wakers.wake(); + self.subscriber_wakers.register(waker).unwrap(); + } + } + } + + fn register_publisher_waker(&mut self, waker: &Waker) { + match self.publisher_wakers.register(waker) { + Ok(()) => {} + Err(_) => { + // All waker slots were full. This can only happen when there was a publisher that now has dropped. + // We need to throw it away. It's a bit inefficient, but we can wake everything. + // Any future that is still active will simply reregister. + // This won't happen a lot, so it's ok. + self.publisher_wakers.wake(); + self.publisher_wakers.register(waker).unwrap(); + } + } + } + + fn unregister_subscriber(&mut self, subscriber_next_message_id: u64) { + self.subscriber_count -= 1; + + // All messages that haven't been read yet by this subscriber must have their counter decremented + let start_id = self.next_message_id - self.queue.len() as u64; + if subscriber_next_message_id >= start_id { + let current_message_index = (subscriber_next_message_id - start_id) as usize; + self.queue + .iter_mut() + .skip(current_message_index) + .for_each(|(_, counter)| *counter -= 1); + } + } + + fn unregister_publisher(&mut self) { + self.publisher_count -= 1; + } +} + +/// Error type for the [PubSubChannel] +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Error { + /// All subscriber slots are used. To add another subscriber, first another subscriber must be dropped or + /// the capacity of the channels must be increased. + MaximumSubscribersReached, + /// All publisher slots are used. To add another publisher, first another publisher must be dropped or + /// the capacity of the channels must be increased. + MaximumPublishersReached, +} + +/// 'Middle level' behaviour of the pubsub channel. +/// This trait is used so that Sub and Pub can be generic over the channel. +pub trait PubSubBehavior<T> { + /// Try to get a message from the queue with the given message id. + /// + /// If the message is not yet present and a context is given, then its waker is registered in the subsriber wakers. + fn get_message_with_context(&self, next_message_id: &mut u64, cx: Option<&mut Context<'_>>) -> Poll<WaitResult<T>>; + + /// Try to publish a message to the queue. + /// + /// If the queue is full and a context is given, then its waker is registered in the publisher wakers. + fn publish_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), T>; + + /// Publish a message immediately + fn publish_immediate(&self, message: T); + + /// Let the channel know that a subscriber has dropped + fn unregister_subscriber(&self, subscriber_next_message_id: u64); + + /// Let the channel know that a publisher has dropped + fn unregister_publisher(&self); +} + +/// The result of the subscriber wait procedure +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum WaitResult<T> { + /// The subscriber did not receive all messages and lagged by the given amount of messages. + /// (This is the amount of messages that were missed) + Lagged(u64), + /// A message was received + Message(T), +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::blocking_mutex::raw::NoopRawMutex; + + #[futures_test::test] + async fn dyn_pub_sub_works() { + let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new(); + + let mut sub0 = channel.dyn_subscriber().unwrap(); + let mut sub1 = channel.dyn_subscriber().unwrap(); + let pub0 = channel.dyn_publisher().unwrap(); + + pub0.publish(42).await; + + assert_eq!(sub0.next_message().await, WaitResult::Message(42)); + assert_eq!(sub1.next_message().await, WaitResult::Message(42)); + + assert_eq!(sub0.try_next_message(), None); + assert_eq!(sub1.try_next_message(), None); + } + + #[futures_test::test] + async fn all_subscribers_receive() { + let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new(); + + let mut sub0 = channel.subscriber().unwrap(); + let mut sub1 = channel.subscriber().unwrap(); + let pub0 = channel.publisher().unwrap(); + + pub0.publish(42).await; + + assert_eq!(sub0.next_message().await, WaitResult::Message(42)); + assert_eq!(sub1.next_message().await, WaitResult::Message(42)); + + assert_eq!(sub0.try_next_message(), None); + assert_eq!(sub1.try_next_message(), None); + } + + #[futures_test::test] + async fn lag_when_queue_full_on_immediate_publish() { + let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new(); + + let mut sub0 = channel.subscriber().unwrap(); + let pub0 = channel.publisher().unwrap(); + + pub0.publish_immediate(42); + pub0.publish_immediate(43); + pub0.publish_immediate(44); + pub0.publish_immediate(45); + pub0.publish_immediate(46); + pub0.publish_immediate(47); + + assert_eq!(sub0.try_next_message(), Some(WaitResult::Lagged(2))); + assert_eq!(sub0.next_message().await, WaitResult::Message(44)); + assert_eq!(sub0.next_message().await, WaitResult::Message(45)); + assert_eq!(sub0.next_message().await, WaitResult::Message(46)); + assert_eq!(sub0.next_message().await, WaitResult::Message(47)); + assert_eq!(sub0.try_next_message(), None); + } + + #[test] + fn limited_subs_and_pubs() { + let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new(); + + let sub0 = channel.subscriber(); + let sub1 = channel.subscriber(); + let sub2 = channel.subscriber(); + let sub3 = channel.subscriber(); + let sub4 = channel.subscriber(); + + assert!(sub0.is_ok()); + assert!(sub1.is_ok()); + assert!(sub2.is_ok()); + assert!(sub3.is_ok()); + assert_eq!(sub4.err().unwrap(), Error::MaximumSubscribersReached); + + drop(sub0); + + let sub5 = channel.subscriber(); + assert!(sub5.is_ok()); + + // publishers + + let pub0 = channel.publisher(); + let pub1 = channel.publisher(); + let pub2 = channel.publisher(); + let pub3 = channel.publisher(); + let pub4 = channel.publisher(); + + assert!(pub0.is_ok()); + assert!(pub1.is_ok()); + assert!(pub2.is_ok()); + assert!(pub3.is_ok()); + assert_eq!(pub4.err().unwrap(), Error::MaximumPublishersReached); + + drop(pub0); + + let pub5 = channel.publisher(); + assert!(pub5.is_ok()); + } + + #[test] + fn publisher_wait_on_full_queue() { + let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new(); + + let pub0 = channel.publisher().unwrap(); + + // There are no subscribers, so the queue will never be full + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + + let sub0 = channel.subscriber().unwrap(); + + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Err(0)); + + drop(sub0); + } +} diff --git a/embassy-sync/src/channel/pubsub/publisher.rs b/embassy-sync/src/channel/pubsub/publisher.rs new file mode 100644 index 00000000..705797f6 --- /dev/null +++ b/embassy-sync/src/channel/pubsub/publisher.rs @@ -0,0 +1,182 @@ +//! Implementation of anything directly publisher related + +use core::future::Future; +use core::marker::PhantomData; +use core::ops::{Deref, DerefMut}; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use super::{PubSubBehavior, PubSubChannel}; +use crate::blocking_mutex::raw::RawMutex; + +/// A publisher to a channel +pub struct Pub<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> { + /// The channel we are a publisher for + channel: &'a PSB, + _phantom: PhantomData<T>, +} + +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Pub<'a, PSB, T> { + pub(super) fn new(channel: &'a PSB) -> Self { + Self { + channel, + _phantom: Default::default(), + } + } + + /// Publish a message right now even when the queue is full. + /// This may cause a subscriber to miss an older message. + pub fn publish_immediate(&self, message: T) { + self.channel.publish_immediate(message) + } + + /// Publish a message. But if the message queue is full, wait for all subscribers to have read the last message + pub fn publish<'s>(&'s self, message: T) -> PublisherWaitFuture<'s, 'a, PSB, T> { + PublisherWaitFuture { + message: Some(message), + publisher: self, + } + } + + /// Publish a message if there is space in the message queue + pub fn try_publish(&self, message: T) -> Result<(), T> { + self.channel.publish_with_context(message, None) + } +} + +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Drop for Pub<'a, PSB, T> { + fn drop(&mut self) { + self.channel.unregister_publisher() + } +} + +/// A publisher that holds a dynamic reference to the channel +pub struct DynPublisher<'a, T: Clone>(pub(super) Pub<'a, dyn PubSubBehavior<T> + 'a, T>); + +impl<'a, T: Clone> Deref for DynPublisher<'a, T> { + type Target = Pub<'a, dyn PubSubBehavior<T> + 'a, T>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, T: Clone> DerefMut for DynPublisher<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// A publisher that holds a generic reference to the channel +pub struct Publisher<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize>( + pub(super) Pub<'a, PubSubChannel<M, T, CAP, SUBS, PUBS>, T>, +); + +impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> Deref + for Publisher<'a, M, T, CAP, SUBS, PUBS> +{ + type Target = Pub<'a, PubSubChannel<M, T, CAP, SUBS, PUBS>, T>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> DerefMut + for Publisher<'a, M, T, CAP, SUBS, PUBS> +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// A publisher that can only use the `publish_immediate` function, but it doesn't have to be registered with the channel. +/// (So an infinite amount is possible) +pub struct ImmediatePub<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> { + /// The channel we are a publisher for + channel: &'a PSB, + _phantom: PhantomData<T>, +} + +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> ImmediatePub<'a, PSB, T> { + pub(super) fn new(channel: &'a PSB) -> Self { + Self { + channel, + _phantom: Default::default(), + } + } + /// Publish the message right now even when the queue is full. + /// This may cause a subscriber to miss an older message. + pub fn publish_immediate(&self, message: T) { + self.channel.publish_immediate(message) + } + + /// Publish a message if there is space in the message queue + pub fn try_publish(&self, message: T) -> Result<(), T> { + self.channel.publish_with_context(message, None) + } +} + +/// An immediate publisher that holds a dynamic reference to the channel +pub struct DynImmediatePublisher<'a, T: Clone>(pub(super) ImmediatePub<'a, dyn PubSubBehavior<T> + 'a, T>); + +impl<'a, T: Clone> Deref for DynImmediatePublisher<'a, T> { + type Target = ImmediatePub<'a, dyn PubSubBehavior<T> + 'a, T>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, T: Clone> DerefMut for DynImmediatePublisher<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// An immediate publisher that holds a generic reference to the channel +pub struct ImmediatePublisher<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize>( + pub(super) ImmediatePub<'a, PubSubChannel<M, T, CAP, SUBS, PUBS>, T>, +); + +impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> Deref + for ImmediatePublisher<'a, M, T, CAP, SUBS, PUBS> +{ + type Target = ImmediatePub<'a, PubSubChannel<M, T, CAP, SUBS, PUBS>, T>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> DerefMut + for ImmediatePublisher<'a, M, T, CAP, SUBS, PUBS> +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// Future for the publisher wait action +pub struct PublisherWaitFuture<'s, 'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> { + /// The message we need to publish + message: Option<T>, + publisher: &'s Pub<'a, PSB, T>, +} + +impl<'s, 'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Future for PublisherWaitFuture<'s, 'a, PSB, T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let message = self.message.take().unwrap(); + match self.publisher.channel.publish_with_context(message, Some(cx)) { + Ok(()) => Poll::Ready(()), + Err(message) => { + self.message = Some(message); + Poll::Pending + } + } + } +} + +impl<'s, 'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Unpin for PublisherWaitFuture<'s, 'a, PSB, T> {} diff --git a/embassy-sync/src/channel/pubsub/subscriber.rs b/embassy-sync/src/channel/pubsub/subscriber.rs new file mode 100644 index 00000000..b9a2cbe1 --- /dev/null +++ b/embassy-sync/src/channel/pubsub/subscriber.rs @@ -0,0 +1,152 @@ +//! Implementation of anything directly subscriber related + +use core::future::Future; +use core::marker::PhantomData; +use core::ops::{Deref, DerefMut}; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use super::{PubSubBehavior, PubSubChannel, WaitResult}; +use crate::blocking_mutex::raw::RawMutex; + +/// A subscriber to a channel +pub struct Sub<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> { + /// The message id of the next message we are yet to receive + next_message_id: u64, + /// The channel we are a subscriber to + channel: &'a PSB, + _phantom: PhantomData<T>, +} + +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Sub<'a, PSB, T> { + pub(super) fn new(next_message_id: u64, channel: &'a PSB) -> Self { + Self { + next_message_id, + channel, + _phantom: Default::default(), + } + } + + /// Wait for a published message + pub fn next_message<'s>(&'s mut self) -> SubscriberWaitFuture<'s, 'a, PSB, T> { + SubscriberWaitFuture { subscriber: self } + } + + /// Wait for a published message (ignoring lag results) + pub async fn next_message_pure(&mut self) -> T { + loop { + match self.next_message().await { + WaitResult::Lagged(_) => continue, + WaitResult::Message(message) => break message, + } + } + } + + /// Try to see if there's a published message we haven't received yet. + /// + /// This function does not peek. The message is received if there is one. + pub fn try_next_message(&mut self) -> Option<WaitResult<T>> { + match self.channel.get_message_with_context(&mut self.next_message_id, None) { + Poll::Ready(result) => Some(result), + Poll::Pending => None, + } + } + + /// Try to see if there's a published message we haven't received yet (ignoring lag results). + /// + /// This function does not peek. The message is received if there is one. + pub fn try_next_message_pure(&mut self) -> Option<T> { + loop { + match self.try_next_message() { + Some(WaitResult::Lagged(_)) => continue, + Some(WaitResult::Message(message)) => break Some(message), + None => break None, + } + } + } +} + +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Drop for Sub<'a, PSB, T> { + fn drop(&mut self) { + self.channel.unregister_subscriber(self.next_message_id) + } +} + +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Unpin for Sub<'a, PSB, T> {} + +/// Warning: The stream implementation ignores lag results and returns all messages. +/// This might miss some messages without you knowing it. +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> futures_util::Stream for Sub<'a, PSB, T> { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + match self + .channel + .get_message_with_context(&mut self.next_message_id, Some(cx)) + { + Poll::Ready(WaitResult::Message(message)) => Poll::Ready(Some(message)), + Poll::Ready(WaitResult::Lagged(_)) => { + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Pending => Poll::Pending, + } + } +} + +/// A subscriber that holds a dynamic reference to the channel +pub struct DynSubscriber<'a, T: Clone>(pub(super) Sub<'a, dyn PubSubBehavior<T> + 'a, T>); + +impl<'a, T: Clone> Deref for DynSubscriber<'a, T> { + type Target = Sub<'a, dyn PubSubBehavior<T> + 'a, T>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, T: Clone> DerefMut for DynSubscriber<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// A subscriber that holds a generic reference to the channel +pub struct Subscriber<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize>( + pub(super) Sub<'a, PubSubChannel<M, T, CAP, SUBS, PUBS>, T>, +); + +impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> Deref + for Subscriber<'a, M, T, CAP, SUBS, PUBS> +{ + type Target = Sub<'a, PubSubChannel<M, T, CAP, SUBS, PUBS>, T>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> DerefMut + for Subscriber<'a, M, T, CAP, SUBS, PUBS> +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// Future for the subscriber wait action +pub struct SubscriberWaitFuture<'s, 'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> { + subscriber: &'s mut Sub<'a, PSB, T>, +} + +impl<'s, 'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Future for SubscriberWaitFuture<'s, 'a, PSB, T> { + type Output = WaitResult<T>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.subscriber + .channel + .get_message_with_context(&mut self.subscriber.next_message_id, Some(cx)) + } +} + +impl<'s, 'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Unpin for SubscriberWaitFuture<'s, 'a, PSB, T> {} diff --git a/embassy-sync/src/channel/signal.rs b/embassy-sync/src/channel/signal.rs new file mode 100644 index 00000000..9279266c --- /dev/null +++ b/embassy-sync/src/channel/signal.rs @@ -0,0 +1,100 @@ +//! A synchronization primitive for passing the latest value to a task. +use core::cell::UnsafeCell; +use core::future::Future; +use core::mem; +use core::task::{Context, Poll, Waker}; + +/// Single-slot signaling primitive. +/// +/// This is similar to a [`Channel`](crate::channel::mpmc::Channel) with a buffer size of 1, except +/// "sending" to it (calling [`Signal::signal`]) when full will overwrite the previous value instead +/// of waiting for the receiver to pop the previous value. +/// +/// It is useful for sending data between tasks when the receiver only cares about +/// the latest data, and therefore it's fine to "lose" messages. This is often the case for "state" +/// updates. +/// +/// For more advanced use cases, you might want to use [`Channel`](crate::channel::mpmc::Channel) instead. +/// +/// Signals are generally declared as `static`s and then borrowed as required. +/// +/// ``` +/// use embassy_sync::channel::signal::Signal; +/// +/// enum SomeCommand { +/// On, +/// Off, +/// } +/// +/// static SOME_SIGNAL: Signal<SomeCommand> = Signal::new(); +/// ``` +pub struct Signal<T> { + state: UnsafeCell<State<T>>, +} + +enum State<T> { + None, + Waiting(Waker), + Signaled(T), +} + +unsafe impl<T: Send> Send for Signal<T> {} +unsafe impl<T: Send> Sync for Signal<T> {} + +impl<T> Signal<T> { + /// Create a new `Signal`. + pub const fn new() -> Self { + Self { + state: UnsafeCell::new(State::None), + } + } +} + +impl<T: Send> Signal<T> { + /// Mark this Signal as signaled. + pub fn signal(&self, val: T) { + critical_section::with(|_| unsafe { + let state = &mut *self.state.get(); + if let State::Waiting(waker) = mem::replace(state, State::Signaled(val)) { + waker.wake(); + } + }) + } + + /// Remove the queued value in this `Signal`, if any. + pub fn reset(&self) { + critical_section::with(|_| unsafe { + let state = &mut *self.state.get(); + *state = State::None + }) + } + + /// Manually poll the Signal future. + pub fn poll_wait(&self, cx: &mut Context<'_>) -> Poll<T> { + critical_section::with(|_| unsafe { + let state = &mut *self.state.get(); + match state { + State::None => { + *state = State::Waiting(cx.waker().clone()); + Poll::Pending + } + State::Waiting(w) if w.will_wake(cx.waker()) => Poll::Pending, + State::Waiting(_) => panic!("waker overflow"), + State::Signaled(_) => match mem::replace(state, State::None) { + State::Signaled(res) => Poll::Ready(res), + _ => unreachable!(), + }, + } + }) + } + + /// Future that completes when this Signal has been signaled. + pub fn wait(&self) -> impl Future<Output = T> + '_ { + futures_util::future::poll_fn(move |cx| self.poll_wait(cx)) + } + + /// non-blocking method to check whether this signal has been signaled. + pub fn signaled(&self) -> bool { + critical_section::with(|_| matches!(unsafe { &*self.state.get() }, State::Signaled(_))) + } +} diff --git a/embassy-sync/src/fmt.rs b/embassy-sync/src/fmt.rs new file mode 100644 index 00000000..f8bb0a03 --- /dev/null +++ b/embassy-sync/src/fmt.rs @@ -0,0 +1,228 @@ +#![macro_use] +#![allow(unused_macros)] + +#[cfg(all(feature = "defmt", feature = "log"))] +compile_error!("You may not enable both `defmt` and `log` features."); + +macro_rules! assert { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::assert!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::assert!($($x)*); + } + }; +} + +macro_rules! assert_eq { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::assert_eq!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::assert_eq!($($x)*); + } + }; +} + +macro_rules! assert_ne { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::assert_ne!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::assert_ne!($($x)*); + } + }; +} + +macro_rules! debug_assert { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::debug_assert!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::debug_assert!($($x)*); + } + }; +} + +macro_rules! debug_assert_eq { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::debug_assert_eq!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::debug_assert_eq!($($x)*); + } + }; +} + +macro_rules! debug_assert_ne { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::debug_assert_ne!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::debug_assert_ne!($($x)*); + } + }; +} + +macro_rules! todo { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::todo!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::todo!($($x)*); + } + }; +} + +macro_rules! unreachable { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::unreachable!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::unreachable!($($x)*); + } + }; +} + +macro_rules! panic { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::panic!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::panic!($($x)*); + } + }; +} + +macro_rules! trace { + ($s:literal $(, $x:expr)* $(,)?) => { + { + #[cfg(feature = "log")] + ::log::trace!($s $(, $x)*); + #[cfg(feature = "defmt")] + ::defmt::trace!($s $(, $x)*); + #[cfg(not(any(feature = "log", feature="defmt")))] + let _ = ($( & $x ),*); + } + }; +} + +macro_rules! debug { + ($s:literal $(, $x:expr)* $(,)?) => { + { + #[cfg(feature = "log")] + ::log::debug!($s $(, $x)*); + #[cfg(feature = "defmt")] + ::defmt::debug!($s $(, $x)*); + #[cfg(not(any(feature = "log", feature="defmt")))] + let _ = ($( & $x ),*); + } + }; +} + +macro_rules! info { + ($s:literal $(, $x:expr)* $(,)?) => { + { + #[cfg(feature = "log")] + ::log::info!($s $(, $x)*); + #[cfg(feature = "defmt")] + ::defmt::info!($s $(, $x)*); + #[cfg(not(any(feature = "log", feature="defmt")))] + let _ = ($( & $x ),*); + } + }; +} + +macro_rules! warn { + ($s:literal $(, $x:expr)* $(,)?) => { + { + #[cfg(feature = "log")] + ::log::warn!($s $(, $x)*); + #[cfg(feature = "defmt")] + ::defmt::warn!($s $(, $x)*); + #[cfg(not(any(feature = "log", feature="defmt")))] + let _ = ($( & $x ),*); + } + }; +} + +macro_rules! error { + ($s:literal $(, $x:expr)* $(,)?) => { + { + #[cfg(feature = "log")] + ::log::error!($s $(, $x)*); + #[cfg(feature = "defmt")] + ::defmt::error!($s $(, $x)*); + #[cfg(not(any(feature = "log", feature="defmt")))] + let _ = ($( & $x ),*); + } + }; +} + +#[cfg(feature = "defmt")] +macro_rules! unwrap { + ($($x:tt)*) => { + ::defmt::unwrap!($($x)*) + }; +} + +#[cfg(not(feature = "defmt"))] +macro_rules! unwrap { + ($arg:expr) => { + match $crate::fmt::Try::into_result($arg) { + ::core::result::Result::Ok(t) => t, + ::core::result::Result::Err(e) => { + ::core::panic!("unwrap of `{}` failed: {:?}", ::core::stringify!($arg), e); + } + } + }; + ($arg:expr, $($msg:expr),+ $(,)? ) => { + match $crate::fmt::Try::into_result($arg) { + ::core::result::Result::Ok(t) => t, + ::core::result::Result::Err(e) => { + ::core::panic!("unwrap of `{}` failed: {}: {:?}", ::core::stringify!($arg), ::core::format_args!($($msg,)*), e); + } + } + } +} + +#[cfg(feature = "defmt-timestamp-uptime")] +defmt::timestamp! {"{=u64:us}", crate::time::Instant::now().as_micros() } + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct NoneError; + +pub trait Try { + type Ok; + type Error; + fn into_result(self) -> Result<Self::Ok, Self::Error>; +} + +impl<T> Try for Option<T> { + type Ok = T; + type Error = NoneError; + + #[inline] + fn into_result(self) -> Result<T, NoneError> { + self.ok_or(NoneError) + } +} + +impl<T, E> Try for Result<T, E> { + type Ok = T; + type Error = E; + + #[inline] + fn into_result(self) -> Self { + self + } +} diff --git a/embassy-sync/src/lib.rs b/embassy-sync/src/lib.rs new file mode 100644 index 00000000..7d881590 --- /dev/null +++ b/embassy-sync/src/lib.rs @@ -0,0 +1,17 @@ +#![cfg_attr(not(any(feature = "std", feature = "wasm")), no_std)] +#![cfg_attr(feature = "nightly", feature(generic_associated_types, type_alias_impl_trait))] +#![allow(clippy::new_without_default)] +#![doc = include_str!("../../README.md")] +#![warn(missing_docs)] + +// This mod MUST go first, so that the others see its macros. +pub(crate) mod fmt; + +// internal use +mod ring_buffer; + +pub mod blocking_mutex; +pub mod channel; +pub mod mutex; +pub mod pipe; +pub mod waitqueue; diff --git a/embassy-sync/src/mutex.rs b/embassy-sync/src/mutex.rs new file mode 100644 index 00000000..75a6e8dd --- /dev/null +++ b/embassy-sync/src/mutex.rs @@ -0,0 +1,167 @@ +//! Async mutex. +//! +//! This module provides a mutex that can be used to synchronize data between asynchronous tasks. +use core::cell::{RefCell, UnsafeCell}; +use core::ops::{Deref, DerefMut}; +use core::task::Poll; + +use futures_util::future::poll_fn; + +use crate::blocking_mutex::raw::RawMutex; +use crate::blocking_mutex::Mutex as BlockingMutex; +use crate::waitqueue::WakerRegistration; + +/// Error returned by [`Mutex::try_lock`] +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct TryLockError; + +struct State { + locked: bool, + waker: WakerRegistration, +} + +/// Async mutex. +/// +/// The mutex is generic over a blocking [`RawMutex`](crate::blocking_mutex::raw::RawMutex). +/// The raw mutex is used to guard access to the internal "is locked" flag. It +/// is held for very short periods only, while locking and unlocking. It is *not* held +/// for the entire time the async Mutex is locked. +/// +/// Which implementation you select depends on the context in which you're using the mutex. +/// +/// Use [`CriticalSectionRawMutex`](crate::blocking_mutex::raw::CriticalSectionRawMutex) when data can be shared between threads and interrupts. +/// +/// Use [`NoopRawMutex`](crate::blocking_mutex::raw::NoopRawMutex) when data is only shared between tasks running on the same executor. +/// +/// Use [`ThreadModeRawMutex`](crate::blocking_mutex::raw::ThreadModeRawMutex) when data is shared between tasks running on the same executor but you want a singleton. +/// +pub struct Mutex<M, T> +where + M: RawMutex, + T: ?Sized, +{ + state: BlockingMutex<M, RefCell<State>>, + inner: UnsafeCell<T>, +} + +unsafe impl<M: RawMutex + Send, T: ?Sized + Send> Send for Mutex<M, T> {} +unsafe impl<M: RawMutex + Sync, T: ?Sized + Send> Sync for Mutex<M, T> {} + +/// Async mutex. +impl<M, T> Mutex<M, T> +where + M: RawMutex, +{ + /// Create a new mutex with the given value. + pub const fn new(value: T) -> Self { + Self { + inner: UnsafeCell::new(value), + state: BlockingMutex::new(RefCell::new(State { + locked: false, + waker: WakerRegistration::new(), + })), + } + } +} + +impl<M, T> Mutex<M, T> +where + M: RawMutex, + T: ?Sized, +{ + /// Lock the mutex. + /// + /// This will wait for the mutex to be unlocked if it's already locked. + pub async fn lock(&self) -> MutexGuard<'_, M, T> { + poll_fn(|cx| { + let ready = self.state.lock(|s| { + let mut s = s.borrow_mut(); + if s.locked { + s.waker.register(cx.waker()); + false + } else { + s.locked = true; + true + } + }); + + if ready { + Poll::Ready(MutexGuard { mutex: self }) + } else { + Poll::Pending + } + }) + .await + } + + /// Attempt to immediately lock the mutex. + /// + /// If the mutex is already locked, this will return an error instead of waiting. + pub fn try_lock(&self) -> Result<MutexGuard<'_, M, T>, TryLockError> { + self.state.lock(|s| { + let mut s = s.borrow_mut(); + if s.locked { + Err(TryLockError) + } else { + s.locked = true; + Ok(()) + } + })?; + + Ok(MutexGuard { mutex: self }) + } +} + +/// Async mutex guard. +/// +/// Owning an instance of this type indicates having +/// successfully locked the mutex, and grants access to the contents. +/// +/// Dropping it unlocks the mutex. +pub struct MutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + mutex: &'a Mutex<M, T>, +} + +impl<'a, M, T> Drop for MutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + fn drop(&mut self) { + self.mutex.state.lock(|s| { + let mut s = s.borrow_mut(); + s.locked = false; + s.waker.wake(); + }) + } +} + +impl<'a, M, T> Deref for MutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + type Target = T; + fn deref(&self) -> &Self::Target { + // Safety: the MutexGuard represents exclusive access to the contents + // of the mutex, so it's OK to get it. + unsafe { &*(self.mutex.inner.get() as *const T) } + } +} + +impl<'a, M, T> DerefMut for MutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + // Safety: the MutexGuard represents exclusive access to the contents + // of the mutex, so it's OK to get it. + unsafe { &mut *(self.mutex.inner.get()) } + } +} diff --git a/embassy-sync/src/pipe.rs b/embassy-sync/src/pipe.rs new file mode 100644 index 00000000..7d64b648 --- /dev/null +++ b/embassy-sync/src/pipe.rs @@ -0,0 +1,551 @@ +//! Async byte stream pipe. + +use core::cell::RefCell; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use crate::blocking_mutex::raw::RawMutex; +use crate::blocking_mutex::Mutex; +use crate::ring_buffer::RingBuffer; +use crate::waitqueue::WakerRegistration; + +/// Write-only access to a [`Pipe`]. +#[derive(Copy)] +pub struct Writer<'p, M, const N: usize> +where + M: RawMutex, +{ + pipe: &'p Pipe<M, N>, +} + +impl<'p, M, const N: usize> Clone for Writer<'p, M, N> +where + M: RawMutex, +{ + fn clone(&self) -> Self { + Writer { pipe: self.pipe } + } +} + +impl<'p, M, const N: usize> Writer<'p, M, N> +where + M: RawMutex, +{ + /// Writes a value. + /// + /// See [`Pipe::write()`] + pub fn write<'a>(&'a self, buf: &'a [u8]) -> WriteFuture<'a, M, N> { + self.pipe.write(buf) + } + + /// Attempt to immediately write a message. + /// + /// See [`Pipe::write()`] + pub fn try_write(&self, buf: &[u8]) -> Result<usize, TryWriteError> { + self.pipe.try_write(buf) + } +} + +/// Future returned by [`Pipe::write`] and [`Writer::write`]. +pub struct WriteFuture<'p, M, const N: usize> +where + M: RawMutex, +{ + pipe: &'p Pipe<M, N>, + buf: &'p [u8], +} + +impl<'p, M, const N: usize> Future for WriteFuture<'p, M, N> +where + M: RawMutex, +{ + type Output = usize; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match self.pipe.try_write_with_context(Some(cx), self.buf) { + Ok(n) => Poll::Ready(n), + Err(TryWriteError::Full) => Poll::Pending, + } + } +} + +impl<'p, M, const N: usize> Unpin for WriteFuture<'p, M, N> where M: RawMutex {} + +/// Read-only access to a [`Pipe`]. +#[derive(Copy)] +pub struct Reader<'p, M, const N: usize> +where + M: RawMutex, +{ + pipe: &'p Pipe<M, N>, +} + +impl<'p, M, const N: usize> Clone for Reader<'p, M, N> +where + M: RawMutex, +{ + fn clone(&self) -> Self { + Reader { pipe: self.pipe } + } +} + +impl<'p, M, const N: usize> Reader<'p, M, N> +where + M: RawMutex, +{ + /// Reads a value. + /// + /// See [`Pipe::read()`] + pub fn read<'a>(&'a self, buf: &'a mut [u8]) -> ReadFuture<'a, M, N> { + self.pipe.read(buf) + } + + /// Attempt to immediately read a message. + /// + /// See [`Pipe::read()`] + pub fn try_read(&self, buf: &mut [u8]) -> Result<usize, TryReadError> { + self.pipe.try_read(buf) + } +} + +/// Future returned by [`Pipe::read`] and [`Reader::read`]. +pub struct ReadFuture<'p, M, const N: usize> +where + M: RawMutex, +{ + pipe: &'p Pipe<M, N>, + buf: &'p mut [u8], +} + +impl<'p, M, const N: usize> Future for ReadFuture<'p, M, N> +where + M: RawMutex, +{ + type Output = usize; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match self.pipe.try_read_with_context(Some(cx), self.buf) { + Ok(n) => Poll::Ready(n), + Err(TryReadError::Empty) => Poll::Pending, + } + } +} + +impl<'p, M, const N: usize> Unpin for ReadFuture<'p, M, N> where M: RawMutex {} + +/// Error returned by [`try_read`](Pipe::try_read). +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum TryReadError { + /// No data could be read from the pipe because it is currently + /// empty, and reading would require blocking. + Empty, +} + +/// Error returned by [`try_write`](Pipe::try_write). +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum TryWriteError { + /// No data could be written to the pipe because it is + /// currently full, and writing would require blocking. + Full, +} + +struct PipeState<const N: usize> { + buffer: RingBuffer<N>, + read_waker: WakerRegistration, + write_waker: WakerRegistration, +} + +impl<const N: usize> PipeState<N> { + const fn new() -> Self { + PipeState { + buffer: RingBuffer::new(), + read_waker: WakerRegistration::new(), + write_waker: WakerRegistration::new(), + } + } + + fn clear(&mut self) { + self.buffer.clear(); + self.write_waker.wake(); + } + + fn try_read(&mut self, buf: &mut [u8]) -> Result<usize, TryReadError> { + self.try_read_with_context(None, buf) + } + + fn try_read_with_context(&mut self, cx: Option<&mut Context<'_>>, buf: &mut [u8]) -> Result<usize, TryReadError> { + if self.buffer.is_full() { + self.write_waker.wake(); + } + + let available = self.buffer.pop_buf(); + if available.is_empty() { + if let Some(cx) = cx { + self.read_waker.register(cx.waker()); + } + return Err(TryReadError::Empty); + } + + let n = available.len().min(buf.len()); + buf[..n].copy_from_slice(&available[..n]); + self.buffer.pop(n); + Ok(n) + } + + fn try_write(&mut self, buf: &[u8]) -> Result<usize, TryWriteError> { + self.try_write_with_context(None, buf) + } + + fn try_write_with_context(&mut self, cx: Option<&mut Context<'_>>, buf: &[u8]) -> Result<usize, TryWriteError> { + if self.buffer.is_empty() { + self.read_waker.wake(); + } + + let available = self.buffer.push_buf(); + if available.is_empty() { + if let Some(cx) = cx { + self.write_waker.register(cx.waker()); + } + return Err(TryWriteError::Full); + } + + let n = available.len().min(buf.len()); + available[..n].copy_from_slice(&buf[..n]); + self.buffer.push(n); + Ok(n) + } +} + +/// A bounded pipe for communicating between asynchronous tasks +/// with backpressure. +/// +/// The pipe will buffer up to the provided number of messages. Once the +/// buffer is full, attempts to `write` new messages will wait until a message is +/// read from the pipe. +/// +/// All data written will become available in the same order as it was written. +pub struct Pipe<M, const N: usize> +where + M: RawMutex, +{ + inner: Mutex<M, RefCell<PipeState<N>>>, +} + +impl<M, const N: usize> Pipe<M, N> +where + M: RawMutex, +{ + /// Establish a new bounded pipe. For example, to create one with a NoopMutex: + /// + /// ``` + /// use embassy_sync::pipe::Pipe; + /// use embassy_sync::blocking_mutex::raw::NoopRawMutex; + /// + /// // Declare a bounded pipe, with a buffer of 256 bytes. + /// let mut pipe = Pipe::<NoopRawMutex, 256>::new(); + /// ``` + pub const fn new() -> Self { + Self { + inner: Mutex::new(RefCell::new(PipeState::new())), + } + } + + fn lock<R>(&self, f: impl FnOnce(&mut PipeState<N>) -> R) -> R { + self.inner.lock(|rc| f(&mut *rc.borrow_mut())) + } + + fn try_read_with_context(&self, cx: Option<&mut Context<'_>>, buf: &mut [u8]) -> Result<usize, TryReadError> { + self.lock(|c| c.try_read_with_context(cx, buf)) + } + + fn try_write_with_context(&self, cx: Option<&mut Context<'_>>, buf: &[u8]) -> Result<usize, TryWriteError> { + self.lock(|c| c.try_write_with_context(cx, buf)) + } + + /// Get a writer for this pipe. + pub fn writer(&self) -> Writer<'_, M, N> { + Writer { pipe: self } + } + + /// Get a reader for this pipe. + pub fn reader(&self) -> Reader<'_, M, N> { + Reader { pipe: self } + } + + /// Write a value, waiting until there is capacity. + /// + /// Writeing completes when the value has been pushed to the pipe's queue. + /// This doesn't mean the value has been read yet. + pub fn write<'a>(&'a self, buf: &'a [u8]) -> WriteFuture<'a, M, N> { + WriteFuture { pipe: self, buf } + } + + /// Attempt to immediately write a message. + /// + /// This method differs from [`write`](Pipe::write) by returning immediately if the pipe's + /// buffer is full, instead of waiting. + /// + /// # Errors + /// + /// If the pipe capacity has been reached, i.e., the pipe has `n` + /// buffered values where `n` is the argument passed to [`Pipe`], then an + /// error is returned. + pub fn try_write(&self, buf: &[u8]) -> Result<usize, TryWriteError> { + self.lock(|c| c.try_write(buf)) + } + + /// Receive the next value. + /// + /// If there are no messages in the pipe's buffer, this method will + /// wait until a message is written. + pub fn read<'a>(&'a self, buf: &'a mut [u8]) -> ReadFuture<'a, M, N> { + ReadFuture { pipe: self, buf } + } + + /// Attempt to immediately read a message. + /// + /// This method will either read a message from the pipe immediately or return an error + /// if the pipe is empty. + pub fn try_read(&self, buf: &mut [u8]) -> Result<usize, TryReadError> { + self.lock(|c| c.try_read(buf)) + } + + /// Clear the data in the pipe's buffer. + pub fn clear(&self) { + self.lock(|c| c.clear()) + } + + /// Return whether the pipe is full (no free space in the buffer) + pub fn is_full(&self) -> bool { + self.len() == N + } + + /// Return whether the pipe is empty (no data buffered) + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Total byte capacity. + /// + /// This is the same as the `N` generic param. + pub fn capacity(&self) -> usize { + N + } + + /// Used byte capacity. + pub fn len(&self) -> usize { + self.lock(|c| c.buffer.len()) + } + + /// Free byte capacity. + /// + /// This is equivalent to `capacity() - len()` + pub fn free_capacity(&self) -> usize { + N - self.len() + } +} + +#[cfg(feature = "nightly")] +mod io_impls { + use core::convert::Infallible; + + use futures_util::FutureExt; + + use super::*; + + impl<M: RawMutex, const N: usize> embedded_io::Io for Pipe<M, N> { + type Error = Infallible; + } + + impl<M: RawMutex, const N: usize> embedded_io::asynch::Read for Pipe<M, N> { + type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> + where + Self: 'a; + + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { + Pipe::read(self, buf).map(Ok) + } + } + + impl<M: RawMutex, const N: usize> embedded_io::asynch::Write for Pipe<M, N> { + type WriteFuture<'a> = impl Future<Output = Result<usize, Self::Error>> + where + Self: 'a; + + fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { + Pipe::write(self, buf).map(Ok) + } + + type FlushFuture<'a> = impl Future<Output = Result<(), Self::Error>> + where + Self: 'a; + + fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { + futures_util::future::ready(Ok(())) + } + } + + impl<M: RawMutex, const N: usize> embedded_io::Io for &Pipe<M, N> { + type Error = Infallible; + } + + impl<M: RawMutex, const N: usize> embedded_io::asynch::Read for &Pipe<M, N> { + type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> + where + Self: 'a; + + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { + Pipe::read(self, buf).map(Ok) + } + } + + impl<M: RawMutex, const N: usize> embedded_io::asynch::Write for &Pipe<M, N> { + type WriteFuture<'a> = impl Future<Output = Result<usize, Self::Error>> + where + Self: 'a; + + fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { + Pipe::write(self, buf).map(Ok) + } + + type FlushFuture<'a> = impl Future<Output = Result<(), Self::Error>> + where + Self: 'a; + + fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { + futures_util::future::ready(Ok(())) + } + } + + impl<M: RawMutex, const N: usize> embedded_io::Io for Reader<'_, M, N> { + type Error = Infallible; + } + + impl<M: RawMutex, const N: usize> embedded_io::asynch::Read for Reader<'_, M, N> { + type ReadFuture<'a> = impl Future<Output = Result<usize, Self::Error>> + where + Self: 'a; + + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { + Reader::read(self, buf).map(Ok) + } + } + + impl<M: RawMutex, const N: usize> embedded_io::Io for Writer<'_, M, N> { + type Error = Infallible; + } + + impl<M: RawMutex, const N: usize> embedded_io::asynch::Write for Writer<'_, M, N> { + type WriteFuture<'a> = impl Future<Output = Result<usize, Self::Error>> + where + Self: 'a; + + fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { + Writer::write(self, buf).map(Ok) + } + + type FlushFuture<'a> = impl Future<Output = Result<(), Self::Error>> + where + Self: 'a; + + fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { + futures_util::future::ready(Ok(())) + } + } +} + +#[cfg(test)] +mod tests { + use futures_executor::ThreadPool; + use futures_util::task::SpawnExt; + use static_cell::StaticCell; + + use super::*; + use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex}; + + fn capacity<const N: usize>(c: &PipeState<N>) -> usize { + N - c.buffer.len() + } + + #[test] + fn writing_once() { + let mut c = PipeState::<3>::new(); + assert!(c.try_write(&[1]).is_ok()); + assert_eq!(capacity(&c), 2); + } + + #[test] + fn writing_when_full() { + let mut c = PipeState::<3>::new(); + assert_eq!(c.try_write(&[42]), Ok(1)); + assert_eq!(c.try_write(&[43]), Ok(1)); + assert_eq!(c.try_write(&[44]), Ok(1)); + assert_eq!(c.try_write(&[45]), Err(TryWriteError::Full)); + assert_eq!(capacity(&c), 0); + } + + #[test] + fn receiving_once_with_one_send() { + let mut c = PipeState::<3>::new(); + assert!(c.try_write(&[42]).is_ok()); + let mut buf = [0; 16]; + assert_eq!(c.try_read(&mut buf), Ok(1)); + assert_eq!(buf[0], 42); + assert_eq!(capacity(&c), 3); + } + + #[test] + fn receiving_when_empty() { + let mut c = PipeState::<3>::new(); + let mut buf = [0; 16]; + assert_eq!(c.try_read(&mut buf), Err(TryReadError::Empty)); + assert_eq!(capacity(&c), 3); + } + + #[test] + fn simple_send_and_receive() { + let c = Pipe::<NoopRawMutex, 3>::new(); + assert!(c.try_write(&[42]).is_ok()); + let mut buf = [0; 16]; + assert_eq!(c.try_read(&mut buf), Ok(1)); + assert_eq!(buf[0], 42); + } + + #[test] + fn cloning() { + let c = Pipe::<NoopRawMutex, 3>::new(); + let r1 = c.reader(); + let w1 = c.writer(); + + let _ = r1.clone(); + let _ = w1.clone(); + } + + #[futures_test::test] + async fn receiver_receives_given_try_write_async() { + let executor = ThreadPool::new().unwrap(); + + static CHANNEL: StaticCell<Pipe<CriticalSectionRawMutex, 3>> = StaticCell::new(); + let c = &*CHANNEL.init(Pipe::new()); + let c2 = c; + let f = async move { + assert_eq!(c2.try_write(&[42]), Ok(1)); + }; + executor.spawn(f).unwrap(); + let mut buf = [0; 16]; + assert_eq!(c.read(&mut buf).await, 1); + assert_eq!(buf[0], 42); + } + + #[futures_test::test] + async fn sender_send_completes_if_capacity() { + let c = Pipe::<CriticalSectionRawMutex, 1>::new(); + c.write(&[42]).await; + let mut buf = [0; 16]; + assert_eq!(c.read(&mut buf).await, 1); + assert_eq!(buf[0], 42); + } +} diff --git a/embassy-sync/src/ring_buffer.rs b/embassy-sync/src/ring_buffer.rs new file mode 100644 index 00000000..52108402 --- /dev/null +++ b/embassy-sync/src/ring_buffer.rs @@ -0,0 +1,146 @@ +pub struct RingBuffer<const N: usize> { + buf: [u8; N], + start: usize, + end: usize, + empty: bool, +} + +impl<const N: usize> RingBuffer<N> { + pub const fn new() -> Self { + Self { + buf: [0; N], + start: 0, + end: 0, + empty: true, + } + } + + pub fn push_buf(&mut self) -> &mut [u8] { + if self.start == self.end && !self.empty { + trace!(" ringbuf: push_buf empty"); + return &mut self.buf[..0]; + } + + let n = if self.start <= self.end { + self.buf.len() - self.end + } else { + self.start - self.end + }; + + trace!(" ringbuf: push_buf {:?}..{:?}", self.end, self.end + n); + &mut self.buf[self.end..self.end + n] + } + + pub fn push(&mut self, n: usize) { + trace!(" ringbuf: push {:?}", n); + if n == 0 { + return; + } + + self.end = self.wrap(self.end + n); + self.empty = false; + } + + pub fn pop_buf(&mut self) -> &mut [u8] { + if self.empty { + trace!(" ringbuf: pop_buf empty"); + return &mut self.buf[..0]; + } + + let n = if self.end <= self.start { + self.buf.len() - self.start + } else { + self.end - self.start + }; + + trace!(" ringbuf: pop_buf {:?}..{:?}", self.start, self.start + n); + &mut self.buf[self.start..self.start + n] + } + + pub fn pop(&mut self, n: usize) { + trace!(" ringbuf: pop {:?}", n); + if n == 0 { + return; + } + + self.start = self.wrap(self.start + n); + self.empty = self.start == self.end; + } + + pub fn is_full(&self) -> bool { + self.start == self.end && !self.empty + } + + pub fn is_empty(&self) -> bool { + self.empty + } + + #[allow(unused)] + pub fn len(&self) -> usize { + if self.empty { + 0 + } else if self.start < self.end { + self.end - self.start + } else { + N + self.end - self.start + } + } + + pub fn clear(&mut self) { + self.start = 0; + self.end = 0; + self.empty = true; + } + + fn wrap(&self, n: usize) -> usize { + assert!(n <= self.buf.len()); + if n == self.buf.len() { + 0 + } else { + n + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn push_pop() { + let mut rb: RingBuffer<4> = RingBuffer::new(); + let buf = rb.push_buf(); + assert_eq!(4, buf.len()); + buf[0] = 1; + buf[1] = 2; + buf[2] = 3; + buf[3] = 4; + rb.push(4); + + let buf = rb.pop_buf(); + assert_eq!(4, buf.len()); + assert_eq!(1, buf[0]); + rb.pop(1); + + let buf = rb.pop_buf(); + assert_eq!(3, buf.len()); + assert_eq!(2, buf[0]); + rb.pop(1); + + let buf = rb.pop_buf(); + assert_eq!(2, buf.len()); + assert_eq!(3, buf[0]); + rb.pop(1); + + let buf = rb.pop_buf(); + assert_eq!(1, buf.len()); + assert_eq!(4, buf[0]); + rb.pop(1); + + let buf = rb.pop_buf(); + assert_eq!(0, buf.len()); + + let buf = rb.push_buf(); + assert_eq!(4, buf.len()); + } +} diff --git a/embassy-sync/src/waitqueue/mod.rs b/embassy-sync/src/waitqueue/mod.rs new file mode 100644 index 00000000..6661a6b6 --- /dev/null +++ b/embassy-sync/src/waitqueue/mod.rs @@ -0,0 +1,7 @@ +//! Async low-level wait queues + +mod waker; +pub use waker::*; + +mod multi_waker; +pub use multi_waker::*; diff --git a/embassy-sync/src/waitqueue/multi_waker.rs b/embassy-sync/src/waitqueue/multi_waker.rs new file mode 100644 index 00000000..325d2cb3 --- /dev/null +++ b/embassy-sync/src/waitqueue/multi_waker.rs @@ -0,0 +1,33 @@ +use core::task::Waker; + +use super::WakerRegistration; + +/// Utility struct to register and wake multiple wakers. +pub struct MultiWakerRegistration<const N: usize> { + wakers: [WakerRegistration; N], +} + +impl<const N: usize> MultiWakerRegistration<N> { + /// Create a new empty instance + pub const fn new() -> Self { + const WAKER: WakerRegistration = WakerRegistration::new(); + Self { wakers: [WAKER; N] } + } + + /// Register a waker. If the buffer is full the function returns it in the error + pub fn register<'a>(&mut self, w: &'a Waker) -> Result<(), &'a Waker> { + if let Some(waker_slot) = self.wakers.iter_mut().find(|waker_slot| !waker_slot.occupied()) { + waker_slot.register(w); + Ok(()) + } else { + Err(w) + } + } + + /// Wake all registered wakers. This clears the buffer + pub fn wake(&mut self) { + for waker_slot in self.wakers.iter_mut() { + waker_slot.wake() + } + } +} diff --git a/embassy-sync/src/waitqueue/waker.rs b/embassy-sync/src/waitqueue/waker.rs new file mode 100644 index 00000000..64e300eb --- /dev/null +++ b/embassy-sync/src/waitqueue/waker.rs @@ -0,0 +1,92 @@ +use core::cell::Cell; +use core::mem; +use core::task::Waker; + +use crate::blocking_mutex::raw::CriticalSectionRawMutex; +use crate::blocking_mutex::Mutex; + +/// Utility struct to register and wake a waker. +#[derive(Debug)] +pub struct WakerRegistration { + waker: Option<Waker>, +} + +impl WakerRegistration { + /// Create a new `WakerRegistration`. + pub const fn new() -> Self { + Self { waker: None } + } + + /// Register a waker. Overwrites the previous waker, if any. + pub fn register(&mut self, w: &Waker) { + match self.waker { + // Optimization: If both the old and new Wakers wake the same task, we can simply + // keep the old waker, skipping the clone. (In most executor implementations, + // cloning a waker is somewhat expensive, comparable to cloning an Arc). + Some(ref w2) if (w2.will_wake(w)) => {} + _ => { + // clone the new waker and store it + if let Some(old_waker) = mem::replace(&mut self.waker, Some(w.clone())) { + // We had a waker registered for another task. Wake it, so the other task can + // reregister itself if it's still interested. + // + // If two tasks are waiting on the same thing concurrently, this will cause them + // to wake each other in a loop fighting over this WakerRegistration. This wastes + // CPU but things will still work. + // + // If the user wants to have two tasks waiting on the same thing they should use + // a more appropriate primitive that can store multiple wakers. + old_waker.wake() + } + } + } + } + + /// Wake the registered waker, if any. + pub fn wake(&mut self) { + if let Some(w) = self.waker.take() { + w.wake() + } + } + + /// Returns true if a waker is currently registered + pub fn occupied(&self) -> bool { + self.waker.is_some() + } +} + +/// Utility struct to register and wake a waker. +pub struct AtomicWaker { + waker: Mutex<CriticalSectionRawMutex, Cell<Option<Waker>>>, +} + +impl AtomicWaker { + /// Create a new `AtomicWaker`. + pub const fn new() -> Self { + Self { + waker: Mutex::const_new(CriticalSectionRawMutex::new(), Cell::new(None)), + } + } + + /// Register a waker. Overwrites the previous waker, if any. + pub fn register(&self, w: &Waker) { + critical_section::with(|cs| { + let cell = self.waker.borrow(cs); + cell.set(match cell.replace(None) { + Some(w2) if (w2.will_wake(w)) => Some(w2), + _ => Some(w.clone()), + }) + }) + } + + /// Wake the registered waker, if any. + pub fn wake(&self) { + critical_section::with(|cs| { + let cell = self.waker.borrow(cs); + if let Some(w) = cell.replace(None) { + w.wake_by_ref(); + cell.set(Some(w)); + } + }) + } +} |