`no_alloc` MPMC Queue

Hello,
I've written an MPMC queue with variable-sized receiver queues that is `no_std` and `no_alloc`.
This is my first time diving into unsafe territory, so please correct my use of `unsafe`.
I'm especially unsure whether I need to use `Pin` anywhere in this implementation.

The queue uses a circular doubly linked list in which every sender and receiver is an element. The list items all live on the stack and remove themselves from the list when dropped.
When sending a message, a sender walks the list until it arrives back at itself, calling the `ListItemBehaviour::handle` function on each element it passes.
Receivers add the message to their own internal queue, from which it can be picked up later.

use core::cell::RefCell;
use core::future::{poll_fn, Future};
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use core::ptr::NonNull;
use core::task::{Poll, Waker};

use critical_section::Mutex;
use defmt::Format;
use heapless::Deque;

/// Errors reported by the receive methods (`recv` / `try_recv`).
///
/// The payload is a plain counter, so the type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Format)]
pub enum Error {
    /// The receiver's queue was full and messages were dropped.
    ///
    /// Carries the number of messages that were not queued since the overflow counter was last
    /// reset with `resolve_overflow`.
    Overflowed(usize),
}

/// Caller-provided storage for a sender.
///
/// Create it with `Default` (which zeroes it) and turn it into a linked sender with
/// `new_channel` or one of the `sender` methods.
pub struct UninitSender<T>(MaybeUninit<Sender<T>>);
/// A sender that is linked into a channel; it unlinks itself in its `Drop` impl.
///
/// NOTE(review): unlinking happens only in `Drop`. If this handle is leaked (e.g. with
/// `core::mem::forget`), the borrow of the storage ends while the node is still linked, so other
/// members keep raw pointers into storage that may then be moved or freed. Guarding against this
/// is what `Pin`'s drop guarantee is for — confirm and redesign the storage type accordingly.
/// NOTE(review): other members also reach this sender through raw pointers while this `&mut`
/// is live, which likely conflicts with Rust's aliasing model — check with Miri.
pub struct SenderRef<'sender, T>(&'sender mut Sender<T>);
/// A channel member that only sends.
struct Sender<T> {
    /// This sender's node in the channel's circular list.
    item: ListItem<T>,
    /// Target of `item`'s behaviour pointer; its `handle` ignores messages.
    inner: SenderInner<T>,
}

/// Caller-provided storage for a receiver whose queue holds up to `N` messages.
///
/// Create it with `Default` (which zeroes it) and turn it into a linked receiver with
/// `new_channel` or one of the `receiver` methods.
pub struct UninitReceiver<T, const N: usize>(MaybeUninit<Receiver<T, N>>);
/// A receiver that is linked into a channel; it unlinks itself in its `Drop` impl.
///
/// NOTE(review): as with senders, leaking this handle (e.g. `core::mem::forget`) leaves the node
/// linked after the storage borrow ends, and other members access the receiver through raw
/// pointers while this `&mut` is live — both need review (`Pin`, Miri).
pub struct ReceiverRef<'receiver, T, const N: usize>(&'receiver mut Receiver<T, N>);
/// A channel member that queues copies of messages sent by other members.
struct Receiver<T, const N: usize> {
    /// This receiver's node in the channel's circular list.
    item: ListItem<T>,
    /// Queue, waker and overflow counter behind a critical-section mutex; also the target of
    /// `item`'s behaviour pointer.
    inner: Mutex<RefCell<ReceiverInner<T, N>>>,
}

/// A node of the channel's circular doubly linked list.
struct ListItem<T> {
    /// The owning sender's or receiver's `inner` field; `SenderRef::send` calls `handle` on it.
    ///
    /// NOTE(review): the trait object here has the implicit `'static` bound, but it points into
    /// caller-provided, non-`'static` storage; `ListItem::init` erases the real lifetime with a
    /// pointer cast.
    inner: NonNull<dyn ListItemBehaviour<T>>,
    /// Links to the neighbouring nodes; a node that is alone links to itself.
    nav: Mutex<RefCell<ListItemNav<T>>>,
}
/// Neighbour links of a `ListItem`.
struct ListItemNav<T> {
    next: NonNull<ListItem<T>>,
    prev: NonNull<ListItem<T>>,
}

/// What a channel member does with a message sent by another member.
trait ListItemBehaviour<T> {
    /// Handles one message; receivers keep a clone, senders ignore it.
    fn handle(&self, val: &T);
}

/// Behaviour target of a sender; zero-sized, `PhantomData<T>` only ties it to the message type.
struct SenderInner<T>(PhantomData<T>);
/// Mutable state of a receiver.
struct ReceiverInner<T, const N: usize> {
    /// Pending messages: `handle` pushes at the front and the receive methods pop from the back,
    /// so messages come out in the order they arrived.
    queue: Deque<T, N>,
    /// Waker registered by a pending `recv` future; taken and woken by the next `handle`.
    waker: Option<Waker>,
    /// Messages not queued since the last `resolve_overflow`; once non-zero, every new message is
    /// counted here instead of being queued.
    overflowed: usize,
}

impl<T> Default for UninitSender<T> {
    fn default() -> Self {
        Self(MaybeUninit::zeroed())
    }
}

impl<T> UninitSender<T> {
    /// Starts a new channel whose only member is this sender.
    ///
    /// The returned handle borrows `self`; further senders and receivers are attached through it.
    pub fn new_channel(&mut self) -> SenderRef<'_, T> {
        let storage = &mut self.0;
        // SAFETY: `UninitSender` values are created zeroed by `Default` (the only constructor
        // shown) and afterwards hold an initialised `Sender`, which is what `Sender::init` requires.
        SenderRef(unsafe { Sender::<T>::init(storage) })
    }
}

impl<T> Sender<T> {
    /// Initialises a `Sender` in the caller-provided storage and returns it.
    ///
    /// The embedded `ListItem` starts as a one-element list whose behaviour pointer refers to
    /// this sender's own `inner` field.
    ///
    /// # Safety
    ///
    /// The list stores raw pointers into `this`, so the storage must not be moved, reused or
    /// freed while the returned sender is linked into a channel. The previous contents of `this`
    /// are neither read nor dropped, so they need not be valid.
    unsafe fn init(this: &mut MaybeUninit<Self>) -> &mut Self {
        let sender = this.as_mut_ptr();

        // SAFETY: `sender` comes from a live `&mut MaybeUninit<Self>`, so it is aligned and valid
        // for writes. `addr_of_mut!` forms the field pointer without creating a reference to the
        // not-yet-initialised struct, and `write` does not drop the bytes it overwrites.
        unsafe { core::ptr::addr_of_mut!((*sender).inner).write(SenderInner(PhantomData)) };

        // SAFETY: `MaybeUninit<ListItem<T>>` has the same layout as `ListItem<T>`, so the `item`
        // field can be handed to `ListItem::init` as uninitialised storage. `inner` is a
        // different field and was initialised above, so the two borrows are disjoint.
        unsafe {
            let item = &mut *core::ptr::addr_of_mut!((*sender).item)
                .cast::<MaybeUninit<ListItem<T>>>();
            let inner = &*core::ptr::addr_of!((*sender).inner);
            ListItem::<T>::init(item, inner as &dyn ListItemBehaviour<T>);
        }

        // SAFETY: both fields have now been initialised.
        unsafe { &mut *sender }
    }
}

impl<T> ListItemBehaviour<T> for SenderInner<T> {
    /// Ignores the message: a sender only sends, it never stores anything.
    fn handle(&self, _: &T) {}
}

impl<'a, T> SenderRef<'a, T> {
    /// Delivers `value` to every other member of this channel.
    ///
    /// Starting after this sender, walks the circular list until it arrives back at this sender
    /// and calls `ListItemBehaviour::handle` on each member it passes. The whole walk runs inside
    /// one critical section, so the list cannot change while it is traversed.
    pub fn send(&self, value: T) {
        let origin: *const ListItem<T> = &self.0.item;
        critical_section::with(|cs| {
            let mut cursor = self.0.item.nav.borrow_ref(cs).next;
            while !core::ptr::eq(cursor.as_ptr(), origin) {
                // SAFETY: linked nodes unlink themselves in their handle's `Drop`, and the
                // critical section keeps the list stable, so `cursor` points at a live node.
                // NOTE(review): this breaks if a handle is leaked with `mem::forget`.
                let member = unsafe { cursor.as_ref() };
                // SAFETY: `member.inner` points at its owner's `inner` field, which lives as
                // long as the node itself.
                let behaviour = unsafe { member.inner.as_ref() };
                behaviour.handle(&value);
                cursor = member.nav.borrow_ref(cs).next;
            }
        });
    }

    /// Attaches a new sender, stored in `sender`, to this sender's channel.
    pub fn sender<'b>(&self, sender: &'b mut UninitSender<T>) -> SenderRef<'b, T> {
        // SAFETY: `UninitSender` values are created zeroed by `Default` (the only constructor
        // shown) and afterwards hold an initialised `Sender`, which is what `Sender::init` requires.
        let joined = unsafe { Sender::<T>::init(&mut sender.0) };
        joined.item.insert_into(&self.0.item);
        SenderRef(joined)
    }
}

impl<'a, T: Clone> SenderRef<'a, T> {
    /// Attaches a new receiver with queue capacity `N`, stored in `receiver`, to this channel.
    pub fn receiver<'b, const N: usize>(
        &self,
        receiver: &'b mut UninitReceiver<T, N>,
    ) -> ReceiverRef<'b, T, N> {
        let storage = &mut receiver.0;
        // SAFETY: `UninitReceiver` values are created zeroed by `Default` (the only constructor
        // shown) and afterwards hold an initialised `Receiver`, as `Receiver::init` requires.
        let joined = unsafe { Receiver::<T, N>::init(storage) };
        joined.item.insert_into(&self.0.item);
        ReceiverRef(joined)
    }
}
impl<'a, T> Drop for SenderRef<'a, T> {
    fn drop(&mut self) {
        self.0.item.pop_self();
    }
}

impl<T, const N: usize> Default for UninitReceiver<T, N> {
    fn default() -> Self {
        Self(MaybeUninit::zeroed())
    }
}
impl<T: Clone, const N: usize> UninitReceiver<T, N> {
    /// Starts a new channel whose only member is this receiver (queue capacity `N`).
    pub fn new_channel(&mut self) -> ReceiverRef<'_, T, N> {
        let storage = &mut self.0;
        // SAFETY: `UninitReceiver` values are created zeroed by `Default` (the only constructor
        // shown) and afterwards hold an initialised `Receiver`, as `Receiver::init` requires.
        ReceiverRef(unsafe { Receiver::<T, N>::init(storage) })
    }
}

impl<T: Clone, const N: usize> Receiver<T, N> {
    /// Initialises a `Receiver` in the caller-provided storage and returns it.
    ///
    /// The queue starts empty with no waker and a zero overflow count. The embedded `ListItem`
    /// starts as a one-element list whose behaviour pointer refers to this receiver's `inner`.
    ///
    /// # Safety
    ///
    /// The list stores raw pointers into `this`, so the storage must not be moved, reused or
    /// freed while the returned receiver is linked into a channel. The previous contents of
    /// `this` are neither read nor dropped (anything left over from an earlier use is leaked).
    unsafe fn init(this: &mut MaybeUninit<Self>) -> &mut Self {
        let receiver = this.as_mut_ptr();

        // SAFETY: `receiver` comes from a live `&mut MaybeUninit<Self>`, so it is aligned and
        // valid for writes. `addr_of_mut!` forms the field pointer without creating a reference
        // to the not-yet-initialised struct, and `write` does not drop the bytes it overwrites
        // (which may be a zeroed `Deque`/`Waker` that must not be interpreted).
        unsafe {
            core::ptr::addr_of_mut!((*receiver).inner).write(Mutex::new(RefCell::new(
                ReceiverInner {
                    queue: Deque::new(),
                    waker: None,
                    overflowed: 0,
                },
            )));
        }

        // SAFETY: `MaybeUninit<ListItem<T>>` has the same layout as `ListItem<T>`, so the `item`
        // field can be handed to `ListItem::init` as uninitialised storage. `inner` is a
        // different field and was initialised above, so the two borrows are disjoint.
        unsafe {
            let item = &mut *core::ptr::addr_of_mut!((*receiver).item)
                .cast::<MaybeUninit<ListItem<T>>>();
            let inner = &*core::ptr::addr_of!((*receiver).inner);
            ListItem::<T>::init(item, inner as &dyn ListItemBehaviour<T>);
        }

        // SAFETY: both fields have now been initialised.
        unsafe { &mut *receiver }
    }
}

impl<T: Clone, const N: usize> ListItemBehaviour<T> for Mutex<RefCell<ReceiverInner<T, N>>> {
    fn handle(&self, val: &T) {
        let waker = critical_section::with(|cs| {
            let mut inner = self.borrow_ref_mut(cs);

            if inner.overflowed > 0 || inner.queue.push_front(val.clone()).is_err() {
                inner.overflowed += 1;
            }

            inner.waker.take()
        });

        if let Some(waker) = waker {
            waker.wake();
        }
    }
}

impl<'a, T, const N: usize> ReceiverRef<'a, T, N> {
    pub fn recv(&mut self) -> impl Future<Output = Result<T, Error>> + '_ {
        poll_fn(|cx| {
            let (ousted_waker, item) = critical_section::with(|cs| {
                let mut inner = self.0.inner.borrow_ref_mut(cs);

                if inner.overflowed > 0 {
                    return (None, Err(Error::Overflowed(inner.overflowed)));
                }

                let mut waker = None;
                if inner.queue.is_empty()
                    && !inner
                        .waker
                        .as_ref()
                        .map(|x| x.will_wake(cx.waker()))
                        .unwrap_or_default()
                {
                    waker = Some(cx.waker().clone());
                    core::mem::swap(&mut waker, &mut inner.waker);
                }

                let item = match (inner.queue.pop_back(), inner.overflowed > 0) {
                    (Some(x), _) => Ok(Some(x)),
                    (None, false) => Ok(None),
                    (None, true) => Err(Error::Overflowed(inner.overflowed)),
                };
                (waker, item)
            });
            if let Some(waker) = ousted_waker {
                waker.wake();
            }

            match item {
                Ok(Some(item)) => Poll::Ready(Ok(item)),
                Err(e) => Poll::Ready(Err(e)),
                Ok(None) => Poll::Pending,
            }
        })
    }

    pub fn try_recv(&mut self) -> Result<Option<T>, Error> {
        critical_section::with(|cs| {
            let mut inner = self.0.inner.borrow_ref_mut(cs);

            match (inner.queue.pop_back(), inner.overflowed > 0) {
                (Some(x), _) => Ok(Some(x)),
                (None, false) => Ok(None),
                (None, true) => Err(Error::Overflowed(inner.overflowed)),
            }
        })
    }

    /// clears queue and resets overflow counter
    pub fn resolve_overflow(&mut self) {
        critical_section::with(|cs| {
            let mut inner = self.0.inner.borrow_ref_mut(cs);
            inner.queue.clear();
            inner.overflowed = 0;
        })
    }

    pub fn sender<'b>(&self, sender: &'b mut UninitSender<T>) -> SenderRef<'b, T> {
        //SAEFTY UninitSender<T> can only be initialized with zeroed MaybeUninit
        let sender = unsafe { Sender::<T>::init(&mut sender.0) };

        sender.item.insert_into(&self.0.item);

        SenderRef(sender)
    }
}

impl<'a, T: Clone, const N: usize> ReceiverRef<'a, T, N> {
    /// Attaches a new receiver, stored in `receiver`, to this receiver's channel.
    ///
    /// The new receiver's queue capacity `M` is independent of this receiver's `N`.
    pub fn receiver<'b, const M: usize>(
        &self,
        receiver: &'b mut UninitReceiver<T, M>,
    ) -> ReceiverRef<'b, T, M> {
        let storage = &mut receiver.0;
        // SAFETY: `UninitReceiver` values are created zeroed by `Default` (the only constructor
        // shown) and afterwards hold an initialised `Receiver`, as `Receiver::init` requires.
        let joined = unsafe { Receiver::<T, M>::init(storage) };
        joined.item.insert_into(&self.0.item);
        ReceiverRef(joined)
    }
}

impl<'a, T, const N: usize> Drop for ReceiverRef<'a, T, N> {
    fn drop(&mut self) {
        self.0.item.pop_self();
    }
}

impl<T> ListItem<T> {
    /// Writes a fresh node into `this` and returns it.
    ///
    /// The node forms a one-element circular list (both links point at itself), and its
    /// behaviour pointer is `inner`.
    fn init<'a>(
        this: &'a mut MaybeUninit<Self>,
        inner: &'a dyn ListItemBehaviour<T>,
    ) -> &'a mut Self {
        // SAFETY: a reference is never null.
        // NOTE(review): this cast also widens the trait object's lifetime to the implicit
        // `'static` of the `inner` field, although the pointee only lives as long as its owner.
        let inner = unsafe { NonNull::new_unchecked(inner as *const _ as *mut _) };

        // SAFETY: `this` is a reference, so its address is never null.
        let this_ptr = unsafe { NonNull::new_unchecked(this.as_ptr() as *const _ as *mut _) };

        this.write(Self {
            inner,
            nav: Mutex::new(RefCell::new(ListItemNav {
                prev: this_ptr,
                next: this_ptr,
            })),
        })
    }

    /// Links `self` into `other`'s list, directly after `other`.
    ///
    /// `self` is expected to be alone (freshly initialised): its previous links are overwritten,
    /// not unlinked.
    fn insert_into(&self, other: &Self) {
        let other_ptr = NonNull::from(other);
        let own_ptr = NonNull::from(self);

        critical_section::with(|cs| {
            let next_ptr = {
                let mut other_nav = other.nav.borrow_ref_mut(cs);
                let next_ptr = other_nav.next;

                other_nav.next = own_ptr;

                next_ptr
            };
            {
                // SAFETY: `next_ptr` was read from a linked node inside the critical section;
                // linked nodes unlink themselves before their storage goes away.
                // NOTE(review): this breaks if a handle is leaked with `mem::forget`.
                let next = unsafe { next_ptr.as_ref() };
                let mut next_nav = next.nav.borrow_ref_mut(cs);
                next_nav.prev = own_ptr;
            }

            {
                let mut own_nav = self.nav.borrow_ref_mut(cs);
                own_nav.next = next_ptr;
                own_nav.prev = other_ptr;
            }
        });
    }

    /// Unlinks `self` from its list and leaves it as a one-element list.
    ///
    /// The node's own links are reset to point at itself, which makes a repeated call a no-op.
    /// Without the reset, a second call would follow stale links and could splice out whatever
    /// had been inserted between the former neighbours in the meantime.
    fn pop_self(&self) {
        let own_ptr = NonNull::from(self);
        critical_section::with(|cs| {
            let (next_ptr, prev_ptr) = {
                let own_nav = self.nav.borrow_ref(cs);
                (own_nav.next, own_nav.prev)
            };
            if next_ptr == own_ptr && prev_ptr == own_ptr {
                // Already alone: nothing to unlink.
                return;
            }

            // SAFETY: the neighbours are linked nodes, and linked nodes unlink themselves before
            // their storage goes away; the critical section keeps the list stable meanwhile.
            // NOTE(review): this breaks if a handle is leaked with `mem::forget`.
            let next = unsafe { next_ptr.as_ref() };
            // SAFETY: as above.
            let prev = unsafe { prev_ptr.as_ref() };

            next.nav.borrow_ref_mut(cs).prev = prev_ptr;
            prev.nav.borrow_ref_mut(cs).next = next_ptr;

            let mut own_nav = self.nav.borrow_ref_mut(cs);
            own_nav.next = own_ptr;
            own_nav.prev = own_ptr;
        })
    }
}

// SAFETY: a `ListItem` holds raw pointers to other channel members and lets its holder call
// `handle` on them, which moves clones of `T` into their queues; that is only sound across
// threads when `T` itself is `Send`. All shared link state is behind a critical-section mutex.
unsafe impl<T: Send> Send for ListItem<T> {}

#[cfg(test)]
mod test {

    use super::{Error, UninitReceiver, UninitSender};

    /// A channel with one sender and one receiver can be built and torn down.
    #[test]
    fn create() {
        let mut sender_slot = UninitSender::<u32>::default();
        let mut receiver_slot = UninitReceiver::default();

        let tx = sender_slot.new_channel();
        let _rx = tx.receiver::<10>(&mut receiver_slot);
    }

    /// A single message travels from the sender to the receiver.
    #[test]
    fn oneshot() {
        let mut sender_slot = UninitSender::default();
        let mut receiver_slot = UninitReceiver::default();
        let tx = sender_slot.new_channel();
        let mut rx = tx.receiver::<10>(&mut receiver_slot);

        assert_eq!(rx.try_recv(), Ok(None));
        tx.send(0);
        assert_eq!(rx.try_recv(), Ok(Some(0)));
    }

    /// Queued messages drain first, then the overflow count is reported and can be cleared.
    #[test]
    fn overflow() {
        let mut sender_slot = UninitSender::default();
        let mut receiver_slot = UninitReceiver::default();
        let tx = sender_slot.new_channel();
        let mut rx = tx.receiver::<10>(&mut receiver_slot);

        assert_eq!(rx.try_recv(), Ok(None));
        (0..11).for_each(|i| tx.send(i));
        for expected in 0..10 {
            assert_eq!(rx.try_recv(), Ok(Some(expected)));
        }
        assert_eq!(rx.try_recv(), Err(Error::Overflowed(1)));
        rx.resolve_overflow();
        assert_eq!(rx.try_recv(), Ok(None));
    }

    /// Every receiver sees every message, and dropping one sender keeps the channel intact.
    #[test]
    fn multi() {
        let mut sender1_slot = UninitSender::default();
        let mut sender2_slot = UninitSender::default();
        let mut receiver1_slot = UninitReceiver::default();
        let mut receiver2_slot = UninitReceiver::default();

        let tx1 = sender1_slot.new_channel();
        let tx2 = tx1.sender(&mut sender2_slot);
        let mut rx1 = tx2.receiver::<10>(&mut receiver1_slot);
        let mut rx2 = rx1.receiver::<10>(&mut receiver2_slot);

        for rx in [&mut rx1, &mut rx2] {
            assert_eq!(rx.try_recv(), Ok(None));
        }
        tx1.send(0);
        tx1.send(1);
        for rx in [&mut rx1, &mut rx2] {
            assert_eq!(rx.try_recv(), Ok(Some(0)));
            assert_eq!(rx.try_recv(), Ok(Some(1)));
        }

        drop(tx1);

        for rx in [&mut rx1, &mut rx2] {
            assert_eq!(rx.try_recv(), Ok(None));
        }
        tx2.send(2);
        tx2.send(3);
        for rx in [&mut rx1, &mut rx2] {
            assert_eq!(rx.try_recv(), Ok(Some(2)));
            assert_eq!(rx.try_recv(), Ok(Some(3)));
        }
    }
}
1 Like

This topic was automatically closed 90 days after the last reply. We invite you to open a new topic if you have further questions or comments.