summaryrefslogtreecommitdiff
path: root/embassy-usb
diff options
context:
space:
mode:
authoralexmoon <alex.r.moon@gmail.com>2022-03-30 14:17:15 -0400
committerDario Nieuwenhuis <dirbaio@dirbaio.net>2022-04-06 05:38:11 +0200
commitf5ba022257ccd9ddd371f1dcd10c0775cc5a3110 (patch)
treece48804d72e8a936a026a3479c49fede8a8352f9 /embassy-usb
parent77e0aca03b89ebc5f1e93b6c64b6c91ca10cedd1 (diff)
downloadembassy-f5ba022257ccd9ddd371f1dcd10c0775cc5a3110.zip
Refactor ControlPipe to use the typestate pattern for safety
Diffstat (limited to 'embassy-usb')
-rw-r--r--embassy-usb/src/control.rs121
-rw-r--r--embassy-usb/src/lib.rs158
2 files changed, 157 insertions, 122 deletions
diff --git a/embassy-usb/src/control.rs b/embassy-usb/src/control.rs
index b5077c73..9f1115ff 100644
--- a/embassy-usb/src/control.rs
+++ b/embassy-usb/src/control.rs
@@ -1,5 +1,7 @@
use core::mem;
+use crate::descriptor::DescriptorWriter;
+use crate::driver::{self, ReadError};
use crate::DEFAULT_ALTERNATE_SETTING;
use super::types::*;
@@ -191,3 +193,122 @@ pub trait ControlHandler {
InResponse::Accepted(&buf[0..2])
}
}
+
+/// Typestate representing a ControlPipe in the DATA IN stage
+#[derive(Debug)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub(crate) struct DataInStage {
+ length: usize,
+}
+
+/// Typestate representing a ControlPipe in the DATA OUT stage
+#[derive(Debug)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub(crate) struct DataOutStage {
+ length: usize,
+}
+
+/// Typestate representing a ControlPipe in the STATUS stage
+#[derive(Debug)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub(crate) struct StatusStage {}
+
+#[derive(Debug)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub(crate) enum Setup {
+ DataIn(Request, DataInStage),
+ DataOut(Request, DataOutStage),
+}
+
+pub(crate) struct ControlPipe<C: driver::ControlPipe> {
+ control: C,
+}
+
+impl<C: driver::ControlPipe> ControlPipe<C> {
+ pub(crate) fn new(control: C) -> Self {
+ ControlPipe { control }
+ }
+
+ pub(crate) async fn setup(&mut self) -> Setup {
+ let req = self.control.setup().await;
+ match (req.direction, req.length) {
+ (UsbDirection::Out, n) => Setup::DataOut(
+ req,
+ DataOutStage {
+ length: usize::from(n),
+ },
+ ),
+ (UsbDirection::In, n) => Setup::DataIn(
+ req,
+ DataInStage {
+ length: usize::from(n),
+ },
+ ),
+ }
+ }
+
+ pub(crate) async fn data_out<'a>(
+ &mut self,
+ buf: &'a mut [u8],
+ stage: DataOutStage,
+ ) -> Result<(&'a [u8], StatusStage), ReadError> {
+ if stage.length == 0 {
+ Ok((&[], StatusStage {}))
+ } else {
+ let req_length = stage.length;
+ let max_packet_size = self.control.max_packet_size();
+ let mut total = 0;
+
+ for chunk in buf.chunks_mut(max_packet_size) {
+ let size = self.control.data_out(chunk).await?;
+ total += size;
+ if size < max_packet_size || total == req_length {
+ break;
+ }
+ }
+
+ Ok((&buf[0..total], StatusStage {}))
+ }
+ }
+
+ pub(crate) async fn accept_in(&mut self, buf: &[u8], stage: DataInStage) {
+ #[cfg(feature = "defmt")]
+ debug!("control in accept {:x}", buf);
+ #[cfg(not(feature = "defmt"))]
+ debug!("control in accept {:x?}", buf);
+
+ let req_len = stage.length;
+ let len = buf.len().min(req_len);
+ let max_packet_size = self.control.max_packet_size();
+ let need_zlp = len != req_len && (len % usize::from(max_packet_size)) == 0;
+
+ let mut chunks = buf[0..len]
+ .chunks(max_packet_size)
+ .chain(need_zlp.then(|| -> &[u8] { &[] }));
+
+ while let Some(chunk) = chunks.next() {
+ self.control.data_in(chunk, chunks.size_hint().0 == 0).await;
+ }
+ }
+
+ pub(crate) async fn accept_in_writer(
+ &mut self,
+ req: Request,
+ stage: DataInStage,
+ f: impl FnOnce(&mut DescriptorWriter),
+ ) {
+ let mut buf = [0; 256];
+ let mut w = DescriptorWriter::new(&mut buf);
+ f(&mut w);
+ let pos = w.position().min(usize::from(req.length));
+ self.accept_in(&buf[..pos], stage).await
+ }
+
+ pub(crate) fn accept(&mut self, _: StatusStage) {
+ self.control.accept();
+ }
+
+ pub(crate) fn reject(&mut self) {
+ self.control.reject();
+ }
+}
diff --git a/embassy-usb/src/lib.rs b/embassy-usb/src/lib.rs
index 77a9c33b..067b5b07 100644
--- a/embassy-usb/src/lib.rs
+++ b/embassy-usb/src/lib.rs
@@ -16,7 +16,7 @@ use heapless::Vec;
use self::control::*;
use self::descriptor::*;
-use self::driver::*;
+use self::driver::{Bus, Driver, Event};
use self::types::*;
use self::util::*;
@@ -92,10 +92,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
Self {
bus: driver,
config,
- control: ControlPipe {
- control,
- request: None,
- },
+ control: ControlPipe::new(control),
device_descriptor,
config_descriptor,
bos_descriptor,
@@ -134,57 +131,50 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
Either::Right(req) => {
debug!("control request: {:x}", req);
- match req.direction {
- UsbDirection::In => self.handle_control_in(req).await,
- UsbDirection::Out => self.handle_control_out(req).await,
+ match req {
+ Setup::DataIn(req, stage) => self.handle_control_in(req, stage).await,
+ Setup::DataOut(req, stage) => self.handle_control_out(req, stage).await,
}
}
}
}
}
- async fn handle_control_out(&mut self, req: Request) {
+ async fn handle_control_out(&mut self, req: Request, stage: DataOutStage) {
const CONFIGURATION_NONE_U16: u16 = CONFIGURATION_NONE as u16;
const CONFIGURATION_VALUE_U16: u16 = CONFIGURATION_VALUE as u16;
- // If the request has a data state, we must read it.
- let data = if req.length > 0 {
- match self.control.data_out(self.control_buf).await {
- Ok(data) => data,
- Err(_) => {
- warn!("usb: failed to read CONTROL OUT data stage.");
- return;
- }
+ let (data, stage) = match self.control.data_out(self.control_buf, stage).await {
+ Ok(data) => data,
+ Err(_) => {
+ warn!("usb: failed to read CONTROL OUT data stage.");
+ return;
}
- } else {
- &[]
};
match (req.request_type, req.recipient) {
(RequestType::Standard, Recipient::Device) => match (req.request, req.value) {
(Request::CLEAR_FEATURE, Request::FEATURE_DEVICE_REMOTE_WAKEUP) => {
self.remote_wakeup_enabled = false;
- self.control.accept();
+ self.control.accept(stage)
}
(Request::SET_FEATURE, Request::FEATURE_DEVICE_REMOTE_WAKEUP) => {
self.remote_wakeup_enabled = true;
- self.control.accept();
+ self.control.accept(stage)
}
(Request::SET_ADDRESS, 1..=127) => {
self.pending_address = req.value as u8;
- self.control.accept();
+ self.control.accept(stage)
}
(Request::SET_CONFIGURATION, CONFIGURATION_VALUE_U16) => {
self.device_state = UsbDeviceState::Configured;
- self.control.accept();
+ self.control.accept(stage)
}
(Request::SET_CONFIGURATION, CONFIGURATION_NONE_U16) => match self.device_state {
- UsbDeviceState::Default => {
- self.control.accept();
- }
+ UsbDeviceState::Default => self.control.accept(stage),
_ => {
self.device_state = UsbDeviceState::Addressed;
- self.control.accept();
+ self.control.accept(stage)
}
},
_ => self.control.reject(),
@@ -193,12 +183,12 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
(Request::SET_FEATURE, Request::FEATURE_ENDPOINT_HALT) => {
let ep_addr = ((req.index as u8) & 0x8f).into();
self.bus.set_stalled(ep_addr, true);
- self.control.accept();
+ self.control.accept(stage)
}
(Request::CLEAR_FEATURE, Request::FEATURE_ENDPOINT_HALT) => {
let ep_addr = ((req.index as u8) & 0x8f).into();
self.bus.set_stalled(ep_addr, false);
- self.control.accept();
+ self.control.accept(stage)
}
_ => self.control.reject(),
},
@@ -218,7 +208,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
_ => handler.control_out(req, data),
};
match response {
- OutResponse::Accepted => self.control.accept(),
+ OutResponse::Accepted => self.control.accept(stage),
OutResponse::Rejected => self.control.reject(),
}
}
@@ -229,7 +219,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
}
}
- async fn handle_control_in(&mut self, req: Request) {
+ async fn handle_control_in(&mut self, req: Request, stage: DataInStage) {
match (req.request_type, req.recipient) {
(RequestType::Standard, Recipient::Device) => match req.request {
Request::GET_STATUS => {
@@ -240,17 +230,15 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
if self.remote_wakeup_enabled {
status |= 0x0002;
}
- self.control.accept_in(&status.to_le_bytes()).await;
- }
- Request::GET_DESCRIPTOR => {
- self.handle_get_descriptor(req).await;
+ self.control.accept_in(&status.to_le_bytes(), stage).await
}
+ Request::GET_DESCRIPTOR => self.handle_get_descriptor(req, stage).await,
Request::GET_CONFIGURATION => {
let status = match self.device_state {
UsbDeviceState::Configured => CONFIGURATION_VALUE,
_ => CONFIGURATION_NONE,
};
- self.control.accept_in(&status.to_le_bytes()).await;
+ self.control.accept_in(&status.to_le_bytes(), stage).await
}
_ => self.control.reject(),
},
@@ -261,7 +249,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
if self.bus.is_stalled(ep_addr) {
status |= 0x0001;
}
- self.control.accept_in(&status.to_le_bytes()).await;
+ self.control.accept_in(&status.to_le_bytes(), stage).await
}
_ => self.control.reject(),
},
@@ -285,7 +273,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
};
match response {
- InResponse::Accepted(data) => self.control.accept_in(data).await,
+ InResponse::Accepted(data) => self.control.accept_in(data, stage).await,
InResponse::Rejected => self.control.reject(),
}
}
@@ -296,17 +284,19 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
}
}
- async fn handle_get_descriptor(&mut self, req: Request) {
+ async fn handle_get_descriptor(&mut self, req: Request, stage: DataInStage) {
let (dtype, index) = req.descriptor_type_index();
match dtype {
- descriptor_type::BOS => self.control.accept_in(self.bos_descriptor).await,
- descriptor_type::DEVICE => self.control.accept_in(self.device_descriptor).await,
- descriptor_type::CONFIGURATION => self.control.accept_in(self.config_descriptor).await,
+ descriptor_type::BOS => self.control.accept_in(self.bos_descriptor, stage).await,
+ descriptor_type::DEVICE => self.control.accept_in(self.device_descriptor, stage).await,
+ descriptor_type::CONFIGURATION => {
+ self.control.accept_in(self.config_descriptor, stage).await
+ }
descriptor_type::STRING => {
if index == 0 {
self.control
- .accept_in_writer(req, |w| {
+ .accept_in_writer(req, stage, |w| {
w.write(descriptor_type::STRING, &lang_id::ENGLISH_US.to_le_bytes());
})
.await
@@ -324,7 +314,9 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
};
if let Some(s) = s {
- self.control.accept_in_writer(req, |w| w.string(s)).await;
+ self.control
+ .accept_in_writer(req, stage, |w| w.string(s))
+ .await
} else {
self.control.reject()
}
@@ -334,81 +326,3 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
}
}
}
-
-struct ControlPipe<C: driver::ControlPipe> {
- control: C,
- request: Option<Request>,
-}
-
-impl<C: driver::ControlPipe> ControlPipe<C> {
- async fn setup(&mut self) -> Request {
- assert!(self.request.is_none());
- let req = self.control.setup().await;
- self.request = Some(req);
- req
- }
-
- async fn data_out<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], ReadError> {
- let req = self.request.unwrap();
- assert_eq!(req.direction, UsbDirection::Out);
- assert!(req.length > 0);
- let req_length = usize::from(req.length);
-
- let max_packet_size = self.control.max_packet_size();
- let mut total = 0;
-
- for chunk in buf.chunks_mut(max_packet_size) {
- let size = self.control.data_out(chunk).await?;
- total += size;
- if size < max_packet_size || total == req_length {
- break;
- }
- }
-
- Ok(&buf[0..total])
- }
-
- async fn accept_in(&mut self, buf: &[u8]) -> () {
- #[cfg(feature = "defmt")]
- debug!("control in accept {:x}", buf);
- #[cfg(not(feature = "defmt"))]
- debug!("control in accept {:x?}", buf);
- let req = unwrap!(self.request);
- assert!(req.direction == UsbDirection::In);
-
- let req_len = usize::from(req.length);
- let len = buf.len().min(req_len);
- let max_packet_size = self.control.max_packet_size();
- let need_zlp = len != req_len && (len % usize::from(max_packet_size)) == 0;
-
- let mut chunks = buf[0..len]
- .chunks(max_packet_size)
- .chain(need_zlp.then(|| -> &[u8] { &[] }));
-
- while let Some(chunk) = chunks.next() {
- self.control.data_in(chunk, chunks.size_hint().0 == 0).await;
- }
-
- self.request = None;
- }
-
- async fn accept_in_writer(&mut self, req: Request, f: impl FnOnce(&mut DescriptorWriter)) {
- let mut buf = [0; 256];
- let mut w = DescriptorWriter::new(&mut buf);
- f(&mut w);
- let pos = w.position().min(usize::from(req.length));
- self.accept_in(&buf[..pos]).await;
- }
-
- fn accept(&mut self) {
- assert!(self.request.is_some());
- self.control.accept();
- self.request = None;
- }
-
- fn reject(&mut self) {
- assert!(self.request.is_some());
- self.control.reject();
- self.request = None;
- }
-}