diff --git a/bbqtest/Cargo.toml b/bbqtest/Cargo.toml index c3801ab..411d6ad 100644 --- a/bbqtest/Cargo.toml +++ b/bbqtest/Cargo.toml @@ -22,6 +22,7 @@ crossbeam-utils = "0.7" crossbeam = "0.7" heapless = "0.5" cfg-if = "0.1" +futures = "0.3" [[bench]] name = "benches" diff --git a/bbqtest/src/async_usage.rs b/bbqtest/src/async_usage.rs new file mode 100644 index 0000000..71d0cfa --- /dev/null +++ b/bbqtest/src/async_usage.rs @@ -0,0 +1,47 @@ +#[cfg(test)] +mod tests { + use bbqueue::{consts::*, BBBuffer}; + use futures::executor::block_on; + + #[test] + fn test_read() { + let bb: BBBuffer = BBBuffer::new(); + let (mut prod, mut cons) = bb.try_split().unwrap(); + + { + let mut grant = prod.grant_exact(4).unwrap(); + let buf = grant.buf(); + buf[0] = 0xDE; + buf[1] = 0xAD; + buf[2] = 0xC0; + buf[3] = 0xDE; + grant.commit(4); + } + + let mut rx_buf = [0; 4]; + let result = block_on(cons.read_async(&mut rx_buf)); + + assert_eq!(4, result.unwrap()); + assert_eq!(rx_buf[0], 0xDE); + assert_eq!(rx_buf[1], 0xAD); + assert_eq!(rx_buf[2], 0xC0); + assert_eq!(rx_buf[3], 0xDE); + } + + #[test] + fn test_write() { + let bb: BBBuffer = BBBuffer::new(); + let (mut prod, mut cons) = bb.try_split().unwrap(); + + let result = block_on(prod.write_async(&[0xDE, 0xAD, 0xC0, 0xDE])); + assert_eq!(4, result.unwrap()); + + let grant = cons.read().unwrap(); + let rx_buf = grant.buf(); + assert_eq!(4, rx_buf.len()); + assert_eq!(rx_buf[0], 0xDE); + assert_eq!(rx_buf[1], 0xAD); + assert_eq!(rx_buf[2], 0xC0); + assert_eq!(rx_buf[3], 0xDE); + } +} diff --git a/bbqtest/src/lib.rs b/bbqtest/src/lib.rs index 63786aa..c529974 100644 --- a/bbqtest/src/lib.rs +++ b/bbqtest/src/lib.rs @@ -1,6 +1,7 @@ //! NOTE: this crate is really just a shim for testing //! the other no-std crate. +mod async_usage; mod framed; mod multi_thread; mod ring_around_the_senders; diff --git a/core/src/bbbuffer.rs b/core/src/bbbuffer.rs index 6a8fa29..ef3c1d1 100644 --- a/core/src/bbbuffer.rs +++ b/core/src/bbbuffer.rs @@ -1,13 +1,16 @@ use crate::{ framed::{FrameConsumer, FrameProducer}, + signal::Signal, Error, Result, }; use core::{ cell::UnsafeCell, cmp::min, + future::Future, marker::PhantomData, mem::{forget, transmute, MaybeUninit}, ops::{Deref, DerefMut}, + pin::Pin, ptr::NonNull, result::Result as CoreResult, slice::from_raw_parts_mut, @@ -15,6 +18,7 @@ use core::{ AtomicBool, AtomicUsize, Ordering::{AcqRel, Acquire, Release}, }, + task::{Context, Poll}, }; pub use generic_array::typenum::consts; use generic_array::{ArrayLength, GenericArray}; @@ -239,6 +243,12 @@ pub struct ConstBBBuffer { /// Have we already split? already_split: AtomicBool, + + /// Waker for async producer. + producer_waker: Signal<()>, + + /// Waker for async consumer. + consumer_waker: Signal<()>, } impl ConstBBBuffer { @@ -293,6 +303,12 @@ impl ConstBBBuffer { /// We haven't split at the start already_split: AtomicBool::new(false), + + /// Consumer waker + consumer_waker: Signal::new(), + + /// Producer waker + producer_waker: Signal::new(), } } } @@ -331,6 +347,17 @@ where unsafe impl<'a, N> Send for Producer<'a, N> where N: ArrayLength {} +/// TODO: Documentation for struct +pub struct AsyncWrite<'a, N> +where + N: ArrayLength, +{ + producer: &'a mut Producer<'a, N>, + buffer: &'a [u8], + remaining: usize, + cancelled: bool, +} + impl<'a, N> Producer<'a, N> where N: ArrayLength, @@ -529,6 +556,90 @@ where to_commit: 0, }) } + + /// Asynchronously write data and complete when write is finished. The buffer must outlive the future + /// that is returned. + pub fn write_async(&'a mut self, buffer: &'a [u8]) -> AsyncWrite<'a, N> { + AsyncWrite::new(self, buffer) + } + + fn poll(&self, cx: &mut Context<'_>) -> Poll<()> { + let inner = unsafe { &self.bbq.as_ref().0 }; + inner.producer_waker.poll_wait(cx) + } + + fn notify_consumer(&self) { + let inner = unsafe { &self.bbq.as_ref().0 }; + inner.consumer_waker.signal(()) + } +} + +impl<'a, N> AsyncWrite<'a, N> +where + N: ArrayLength, +{ + fn new(producer: &'a mut Producer<'a, N>, buffer: &'a [u8]) -> AsyncWrite<'a, N> { + Self { + producer, + remaining: buffer.len(), + buffer, + cancelled: false, + } + } + + /// Signal that this future is cancelled and should no longer poll data + pub fn cancel(&mut self) { + self.cancelled = true; + } +} + +impl<'a, N> Future for AsyncWrite<'a, N> +where + N: ArrayLength, +{ + type Output = Result; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.cancelled { + Poll::Ready(Ok(self.buffer.len() - self.remaining)) + } else { + loop { + let remaining = self.remaining; + match self.producer.grant_max_remaining(remaining) { + Ok(mut grant) => { + let buf = grant.buf(); + + let wp = self.buffer.len() - self.remaining; + let to_copy = core::cmp::min(self.remaining, buf.len()); + + buf[..to_copy].copy_from_slice(&self.buffer[wp..wp + to_copy]); + + self.remaining -= to_copy; + grant.commit(to_copy); + + self.producer.notify_consumer(); + + if self.remaining == 0 { + return Poll::Ready(Ok(self.buffer.len())); + } else { + match self.producer.poll(cx) { + Poll::Pending => { + return Poll::Pending; + } + _ => {} + } + } + } + Err(Error::InsufficientSize) => match self.producer.poll(cx) { + Poll::Pending => { + return Poll::Pending; + } + _ => {} + }, + Err(e) => return Poll::Ready(Err(e)), + } + } + } + } } /// `Consumer` is the primary interface for reading data from a `BBBuffer`. @@ -542,6 +653,83 @@ where unsafe impl<'a, N> Send for Consumer<'a, N> where N: ArrayLength {} +/// TODO: Documentation for struct +pub struct AsyncRead<'a, N> +where + N: ArrayLength, +{ + consumer: &'a mut Consumer<'a, N>, + buffer: &'a mut [u8], + remaining: usize, + cancelled: bool, +} + +impl<'a, N> AsyncRead<'a, N> +where + N: ArrayLength, +{ + fn new(consumer: &'a mut Consumer<'a, N>, buffer: &'a mut [u8]) -> Self { + Self { + consumer, + cancelled: false, + remaining: buffer.len(), + buffer, + } + } + + /// Signal that this future is cancelled and should no longer poll data + pub fn cancel(&mut self) { + self.cancelled = true; + } +} + +impl<'a, N> Future for AsyncRead<'a, N> +where + N: ArrayLength, +{ + type Output = Result; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.cancelled { + Poll::Ready(Ok(self.buffer.len() - self.remaining)) + } else { + loop { + match self.consumer.read() { + Ok(grant) => { + let buf = grant.buf(); + let rp = self.buffer.len() - self.remaining; + let to_copy = core::cmp::min(self.remaining, buf.len()); + + self.buffer[rp..rp + to_copy].copy_from_slice(&buf[..to_copy]); + self.remaining -= to_copy; + grant.release(to_copy); + + self.consumer.notify_producer(); + + if self.remaining == 0 { + return Poll::Ready(Ok(rp + to_copy)); + } else { + match self.consumer.poll(cx) { + Poll::Pending => { + return Poll::Pending; + } + _ => {} + } + } + } + // If there was no data available, but we got signaled in the meantime, try again + Err(Error::InsufficientSize) => match self.consumer.poll(cx) { + Poll::Pending => { + return Poll::Pending; + } + _ => {} + }, + Err(e) => return Poll::Ready(Err(e)), + } + } + } + } +} + impl<'a, N> Consumer<'a, N> where N: ArrayLength, @@ -680,6 +868,22 @@ where to_release: 0, }) } + + /// Asynchronously read data into a provided buffer until the buffer is filled. The buffer must outlive the future + /// that is returned. + pub fn read_async(&'a mut self, rx_buffer: &'a mut [u8]) -> AsyncRead<'a, N> { + AsyncRead::new(self, rx_buffer) + } + + fn poll(&self, cx: &mut Context<'_>) -> Poll<()> { + let inner = unsafe { &self.bbq.as_ref().0 }; + inner.consumer_waker.poll_wait(cx) + } + + fn notify_producer(&self) { + let inner = unsafe { &self.bbq.as_ref().0 }; + inner.producer_waker.signal(()) + } } impl BBBuffer diff --git a/core/src/lib.rs b/core/src/lib.rs index ed7048e..85808d3 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -106,6 +106,7 @@ #![deny(warnings)] mod bbbuffer; +mod signal; pub use bbbuffer::*; /// There are no longer separate `atomic` and `cm_mutex` modules. You can just use the types at the diff --git a/core/src/signal.rs b/core/src/signal.rs new file mode 100644 index 0000000..dfc2852 --- /dev/null +++ b/core/src/signal.rs @@ -0,0 +1,81 @@ +// Copyright The Embassy Project (https://github.com/akiles/embassy). Licensed under the Apache 2.0 +// license + +use core::cell::UnsafeCell; +use core::mem; +use core::task::{Context, Poll, Waker}; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))] +fn with_critical_section(f: F) -> R +where + F: FnOnce() -> R, +{ + use std::sync::Once; + static INIT: Once = Once::new(); + static mut BKL: Option> = None; + + INIT.call_once(|| unsafe { + BKL.replace(std::sync::Mutex::new(())); + }); + + let _guard = unsafe { BKL.as_ref().unwrap().lock() }; + f() +} + +#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))] +fn with_critical_section(&self, f: F) -> R +where + F: FnOnce() -> R, +{ + cortex_m::interrupt::free(|_| f()) +} + +pub struct Signal { + state: UnsafeCell>, +} + +enum State { + None, + Waiting(Waker), + Signaled(T), +} + +unsafe impl Send for Signal {} +unsafe impl Sync for Signal {} + +impl Signal { + pub const fn new() -> Self { + Self { + state: UnsafeCell::new(State::None), + } + } + + #[allow(clippy::single_match)] + pub fn signal(&self, val: T) { + with_critical_section(|| unsafe { + let state = &mut *self.state.get(); + match mem::replace(state, State::Signaled(val)) { + State::Waiting(waker) => waker.wake(), + _ => {} + } + }) + } + + pub fn poll_wait(&self, cx: &mut Context<'_>) -> Poll { + with_critical_section(|| unsafe { + let state = &mut *self.state.get(); + match state { + State::None => { + *state = State::Waiting(cx.waker().clone()); + Poll::Pending + } + State::Waiting(w) if w.will_wake(cx.waker()) => Poll::Pending, + State::Waiting(_) => Poll::Pending, + State::Signaled(_) => match mem::replace(state, State::None) { + State::Signaled(res) => Poll::Ready(res), + _ => Poll::Pending, + }, + } + }) + } +}