Write an Iterator returning multiple mutable references (unsafe)

I'm trying to provide mutable access to a data structure (it's basically the opposite of zip()).

An iterator shall return accessors to that slice, which need mutable access to the underlying data structure. Application logic will make sure that two accessors cannot target the same memory location.

struct Container<'a> {
    slice: &'a mut [()],
}

impl<'a> Iterator for Container<'a> {
    type Item = &'a mut [()];

    fn next(&mut self) -> Option<Self::Item> {
        // try to duplicate the reference
        let mut_clone = unsafe { &mut *self.slice };
        // No worries! The real application will return a wrapper ensuring all access will be disjoint.
        Some(mut_clone)
    }
}

I borrowed the unsafe pointer cloning from the chunks-method, but it's still complaining that I somehow messed up the lifetimes as if I didn't use unsafe.

Can someone explain to me how to handle that?

Lifetimes of regular references don't disappear in unsafe. Notably, unsafe doesn't mean "YOLO, just pretend we are in C". An unsafe block only grants you a very narrow set of additional capabilities, such as dereferencing raw pointers, calling unsafe functions, and implementing unsafe traits (and off the top of my head, that's basically all of it). Code that relies on safe Rust types and functions still gets fully type- and borrow-checked, exactly as in safe code.

So what you can do here is create a raw pointer and dereference it (in which case you'll get a free lifetime parameter, which you can have inferred to be whatever). But I doubt that's actually what you want. If you can tell us more about how your iterator will be used, we can probably suggest a better, and likely safe, solution.
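
To make the "free lifetime parameter" point concrete, here is a minimal sketch that is not from the thread. The helper name `split_halves` is made up for illustration (the standard library's `split_at_mut` already does this safely). The reference produced by dereferencing the raw pointer has no lifetime of its own, so the function signature is what pins it back to `'a`, and the safety comment has to justify why the two results never overlap.

```rust
/// Hypothetical helper: splits a slice into two disjoint mutable halves.
/// (Real code would simply call `slice::split_at_mut`.)
fn split_halves<'a>(slice: &'a mut [f32]) -> (&'a mut [f32], &'a mut [f32]) {
    let len = slice.len();
    let mid = len / 2;
    // The raw pointer carries no lifetime; the borrow checker stops tracking here.
    let ptr: *mut f32 = slice.as_mut_ptr();
    unsafe {
        // SAFETY: [0, mid) and [mid, len) both lie inside the original slice
        // and do not overlap. The return type ties both halves back to 'a.
        (
            std::slice::from_raw_parts_mut(ptr, mid),
            std::slice::from_raw_parts_mut(ptr.add(mid), len - mid),
        )
    }
}

fn split_halves_demo() -> Vec<f32> {
    let mut data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let (left, right) = split_halves(&mut data);
    // Both halves are usable at the same time because they are disjoint.
    left[0] = 10.0;
    right[0] = 30.0;
    data
}
```

Without the explicit `'a` in the signature, nothing would stop a caller from keeping the halves around longer than the buffer they point into.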

Thanks for your explanation. What I'm trying to do is the following:

I have a buffer of audio samples which are stored in an interleaved fashion: [L, R, L, R, L, R, …]. The containing struct looks like this (simplified):

struct AudioBuffer<'a> {
  samples: &'a mut [f32],
  channel_count: usize, // 2 in my example above
}

One usual access pattern is to iterate through this buffer channel by channel:

for channel_mut in audio_buffer.channels_mut() {
  // channel 0 would return all `L`s
  // channel 1 would return all `R`s
  for sample_mut in channel_mut.samples_mut() {
    *sample_mut = some_value;
  }
}

I'm struggling to find a non-unsafe solution for the implementation of channels_mut() and I guess it's impossible: each resulting value would need access to (almost) the entire range of the buffer which is something the compiler cannot allow.

One possible solution would be to not use the Iterator interface, so I could return shorter lifetimes (making it impossible to have multiple channels at the same time). Another option would be to implement for_each_channel(…). But in both cases ergonomics would suffer.
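
For reference, a closure-based `for_each_channel` can be written without any unsafe code. The following is only a sketch of one possible shape. The signature (a channel index plus a `&mut dyn Iterator`) is an assumption for illustration, not something the original post specifies.

```rust
struct AudioBuffer<'a> {
    samples: &'a mut [f32],
    channel_count: usize,
}

impl AudioBuffer<'_> {
    /// Hypothetical API: calls `f` once per channel, passing the channel index
    /// and an iterator over that channel's samples.
    fn for_each_channel<F>(&mut self, mut f: F)
    where
        F: FnMut(usize, &mut dyn Iterator<Item = &mut f32>),
    {
        for ch in 0..self.channel_count {
            // Only one channel iterator exists at a time, so safe Rust suffices.
            let mut channel = self
                .samples
                .iter_mut()
                .skip(ch)
                .step_by(self.channel_count);
            f(ch, &mut channel);
        }
    }
}

fn for_each_channel_demo() -> Vec<f32> {
    let mut samples = vec![0.0_f32; 6];
    let mut buffer = AudioBuffer { samples: &mut samples, channel_count: 2 };
    buffer.for_each_channel(|ch, channel| {
        for sample in channel {
            // Channel 0 (L) gets -1.0, channel 1 (R) gets 1.0.
            *sample = if ch == 0 { -1.0 } else { 1.0 };
        }
    });
    samples
}
```

The ergonomic cost is visible here: the caller writes a closure instead of a nested `for` loop, and two channels can never be held at once.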

Creating aliasing mutable references is UB anyway, even if you use them in a disjoint way. What you can do is to make your wrappers contain a raw pointer (and a PhantomData, for the lifetime) and then use unsafe to access that raw pointer only where you know it's safe. Also, don't store a mutable reference inside Container as it would invalidate the pointers in the wrappers, instead store a raw pointer and a PhantomData there too.
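
A rough sketch of what such a wrapper might look like follows. The names `ChannelMut` and `channel_accessors` are invented here, and the code assumes an interleaved `f32` buffer as in the question. It has not been audited.

```rust
use std::marker::PhantomData;

/// Hypothetical accessor for one channel of an interleaved buffer.
struct ChannelMut<'a> {
    first: *mut f32, // first sample belonging to this channel
    frames: usize,   // number of samples in this channel
    stride: usize,   // distance between consecutive samples (= channel count)
    _borrow: PhantomData<&'a mut [f32]>, // keeps the buffer exclusively borrowed
}

impl ChannelMut<'_> {
    fn get_mut(&mut self, frame: usize) -> Option<&mut f32> {
        if frame >= self.frames {
            return None;
        }
        // SAFETY: `frame < frames`, so the offset stays inside the buffer.
        // Accessors for different channels never compute the same address,
        // and `&mut self` prevents two references from one accessor.
        Some(unsafe { &mut *self.first.add(frame * self.stride) })
    }
}

fn channel_accessors(samples: &mut [f32], channel_count: usize) -> Vec<ChannelMut<'_>> {
    assert!(channel_count > 0 && samples.len() % channel_count == 0);
    let frames = samples.len() / channel_count;
    let base = samples.as_mut_ptr();
    (0..channel_count)
        .map(|ch| ChannelMut {
            // `wrapping_add` avoids UB for an empty buffer; the pointer is
            // never dereferenced when `frames == 0`.
            first: base.wrapping_add(ch),
            frames,
            stride: channel_count,
            _borrow: PhantomData,
        })
        .collect()
}

fn channel_accessors_demo() -> Vec<f32> {
    let mut samples = vec![0.0_f32; 6];
    {
        let mut channels = channel_accessors(&mut samples, 2);
        for (ch, channel) in channels.iter_mut().enumerate() {
            for frame in 0..3 {
                *channel.get_mut(frame).unwrap() = ch as f32 + 1.0;
            }
        }
    }
    samples
}
```

Note that the accessors hold only raw pointers, so no `&mut [f32]` covering the whole buffer exists while they are alive.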

You don't need to scrap Iterator for this. This works:

struct AudioBuffer<'a> {
    samples: &'a mut [f32],
    channel_count: usize,
}

impl AudioBuffer<'_> {
    fn channel_mut(&mut self, index: usize) -> ChannelIterator<'_> {
        ChannelIterator {
            iter: self.samples.chunks_mut(self.channel_count),
            index,
        }
    }
}

struct ChannelIterator<'a> {
    iter: core::slice::ChunksMut<'a, f32>,
    index: usize,
}

impl<'a> Iterator for ChannelIterator<'a> {
    type Item = &'a mut f32;
    
    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next().map(|frame| &mut frame[self.index])
    }
}

If you don't need multiple channel iterators to coexist, then go with @H2CO3's solution. If you do need them to coexist, you'll need to write some unsafe code yourself¹. One way would look something like this:

(passes Miri, but not properly audited)

struct AudioBuffer<'a> {
    samples: &'a mut [f32],
    channel_count: usize, // 2 in my example above
}

impl AudioBuffer<'_> {
    pub fn channels_mut(&mut self)->Vec<ChannelIter<'_>> {
        // NB: Required for soundness; do not remove.
        assert!(self.samples.len() > 0);
        assert!(self.samples.len() % self.channel_count == 0);
        
        let range = self.samples.as_mut_ptr_range();
        let start = range.start;
        let end = range.end;
        (0..self.channel_count).map(|ch| ChannelIter {
            stride: self.channel_count,
            // Safety: see asserts above
            start: unsafe { start.offset(ch as isize) },
            end: end,
            lt: std::marker::PhantomData
        }).collect()
    }
}

struct ChannelIter<'a> {
    stride: usize,
    start: *mut f32,
    end: *mut f32,
    lt: std::marker::PhantomData<&'a mut f32>
}

impl<'a> Iterator for ChannelIter<'a> {
    type Item = &'a mut f32;
    fn next(&mut self)->Option<&'a mut f32> {
        if self.start.is_null() { return None; }
        Some(unsafe {
            // Safety: Pointers are valid.
            // No two iterators created from a single call to 
            // AudioBuffer::channels_mut will ever access the same
            // element, so they are free to coexist.  More iterators
            // to the same buffer are prevented by the lifetime 'a, which
            // holds an exclusive lock on the buffer.
            let val = &mut *self.start;
            if self.end.offset_from(self.start) > self.stride as isize {
                self.start = self.start.offset(self.stride as isize);
            } else {
                self.start = std::ptr::null_mut();
            }
            val
        })
    }
}

¹ You could also do something safe with samples.iter_mut().collect::<Vec<_>>(), but that will involve more heap traffic than you probably want in an audio application.
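
A sketch of what that safe variant could look like follows. The function name `split_channels` is made up, and the allocation cost is the one the footnote warns about.

```rust
/// Safe sketch: sorts the `&mut f32`s of an interleaved buffer into one
/// `Vec` per channel. Costs one heap allocation per channel.
fn split_channels(samples: &mut [f32], channel_count: usize) -> Vec<Vec<&mut f32>> {
    assert!(channel_count > 0);
    let mut channels: Vec<Vec<&mut f32>> =
        (0..channel_count).map(|_| Vec::new()).collect();
    for (i, sample) in samples.iter_mut().enumerate() {
        // Sample `i` belongs to channel `i % channel_count`.
        channels[i % channel_count].push(sample);
    }
    channels
}

fn split_channels_demo() -> Vec<f32> {
    let mut samples = vec![0.0_f32; 6];
    {
        let mut channels = split_channels(&mut samples, 2);
        let right = channels.pop().unwrap();
        let left = channels.pop().unwrap();
        // Both channels are alive and writable at the same time.
        for (l, r) in left.into_iter().zip(right) {
            *l = -1.0;
            *r = 1.0;
        }
    }
    samples
}
```

Because `iter_mut()` already hands out non-overlapping references, the borrow checker accepts this without any unsafe code.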

I decided to write out a version of this as well:

use std::cell::RefCell;
use std::collections::VecDeque;
use std::rc::Rc;

use std::iter::{Zip,Cycle};
use std::ops::Range;
use std::slice;

struct AudioBuffer<'a> {
    samples: &'a mut [f32],
    channel_count: usize,
}

impl AudioBuffer<'_> {
    fn channels_mut(&mut self) -> impl Iterator<Item = ChannelIter<'_>> {
        let channels = Rc::new(RefCell::new(Channels {
            queues: (0..self.channel_count).map(|_| VecDeque::new()).collect(),
            iter: (0..self.channel_count).cycle().zip(self.samples.iter_mut()),
        }));

        (0..self.channel_count).map(move |idx| ChannelIter {
            channels: Rc::clone(&channels),
            idx,
        })
    }
}

struct Channels<'a> {
    queues: Vec<VecDeque<&'a mut f32>>,
    iter: Zip<Cycle<Range<usize>>, slice::IterMut<'a, f32>>,
}

impl<'a> Channels<'a> {
    fn pump(&mut self) -> Option<()> {
        let (ch, x) = self.iter.next()?;
        self.queues[ch].push_back(x);
        Some(())
    }

    pub fn next(&mut self, ch: usize) -> Option<&'a mut f32> {
        while self.queues[ch].len() == 0 {
            self.pump()?
        }
        self.queues[ch].pop_front()
    }
}

struct ChannelIter<'a> {
    channels: Rc<RefCell<Channels<'a>>>,
    idx: usize,
}

impl<'a> Iterator for ChannelIter<'a> {
    type Item = &'a mut f32;
    fn next(&mut self) -> Option<&'a mut f32> {
        self.channels.borrow_mut().next(self.idx)
    }
}

Thanks but your solution scrapped the outer Iterator - which was exactly the workaround I had in mind :wink:
But I think this is still the best compromise.

@2e71828 Thanks a lot for your efforts! I rarely work with unsafe code and there's a lot for me to digest in your code.

  • The PhantomData introduces kind of a custom lifetime because the raw pointers don't have it any more, right?
  • Is there a special reason why your channels_mut returns a Vec rather than an Iterator?

I don't believe that simultaneous write access to multiple channels will be needed in practice (there's another API for accessing the samples frame-wise). My main intention to use an Iterator was to offer a convenient (and symmetric) API. But seeing how hard this is to get right I will likely switch to a single-channel-at-a-time solution.

Oh, that's interesting. So as a rule of thumb references and raw pointers to the same location should never co-exist? Otherwise UB, I guess?

Not really; just development path dependence. This also appears to work:

impl AudioBuffer<'_> {
    pub fn channels_mut(&mut self)->impl Iterator<Item=ChannelIter<'_>> {
        // NB: Required for soundness; do not remove.
        assert!(self.samples.len() > 0);
        assert!(self.samples.len() % self.channel_count == 0);
        
        let range = self.samples.as_mut_ptr_range();
        let start = range.start;
        let end = range.end;
        let stride = self.channel_count;
        (0..self.channel_count).map(move |ch| ChannelIter {
            stride,
            // Safety: see asserts above
            start: unsafe { start.offset(ch as isize) },
            end: end,
            lt: std::marker::PhantomData
        })
    }
}

Yes; I’ve effectively split the reference into two parts:

  • The raw pointer, which has weaker aliasing restrictions than the original reference, and
  • The borrow/lifetime (represented by the PhantomData), which ensures that the original reference’s exclusive-access region extends beyond the last use of these values.

When I yield an &mut f32, the two parts get put back together.

Looking at that code, I can't find any soundness issues with it. Since you start with a safe slice and check the channel count, there's nowhere for overflow issues to creep in. The careful offset_from() math prevents creating out-of-bounds pointers. And the borrow stack is conceptually pretty simple: just the disjoint Unique pointers created from the SharedReadWrite (SRW) raw pointers returned by as_mut_ptr_range(). Optionally, ChannelIter can be marked Send and Sync, since it requires a unique borrow to use.
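
As an illustration of the Send/Sync remark, here is a simplified, counter-based variant of ChannelIter. It is not the exact code posted above, and the free function `channels_mut` and the demo function are assumptions made for this sketch. The `unsafe impl`s rest on the claim that the iterator behaves like a collection of `&'a mut f32`.

```rust
use std::marker::PhantomData;

struct ChannelIter<'a> {
    next: *mut f32,
    remaining: usize,
    stride: usize,
    lt: PhantomData<&'a mut f32>,
}

impl<'a> Iterator for ChannelIter<'a> {
    type Item = &'a mut f32;
    fn next(&mut self) -> Option<&'a mut f32> {
        if self.remaining == 0 {
            return None;
        }
        self.remaining -= 1;
        // SAFETY: while `remaining > 0`, `next` points at a sample inside the
        // borrowed buffer, and every sample is yielded at most once.
        let sample = unsafe { &mut *self.next };
        // `wrapping_add` so stepping past the end is not UB by itself.
        self.next = self.next.wrapping_add(self.stride);
        Some(sample)
    }
}

// SAFETY (assumption of this sketch): the iterator is equivalent to a set of
// disjoint `&'a mut f32`, which would be Send and Sync themselves.
unsafe impl Send for ChannelIter<'_> {}
unsafe impl Sync for ChannelIter<'_> {}

fn channels_mut(samples: &mut [f32], channel_count: usize) -> Vec<ChannelIter<'_>> {
    assert!(channel_count > 0 && samples.len() % channel_count == 0);
    let frames = samples.len() / channel_count;
    let base = samples.as_mut_ptr();
    (0..channel_count)
        .map(|ch| ChannelIter {
            next: base.wrapping_add(ch),
            remaining: frames,
            stride: channel_count,
            lt: PhantomData,
        })
        .collect()
}

fn fill_channels_in_parallel() -> Vec<f32> {
    let mut samples = vec![0.0_f32; 8];
    let channels = channels_mut(&mut samples, 2);
    std::thread::scope(|scope| {
        for (ch, channel) in channels.into_iter().enumerate() {
            // Moving the iterator into another thread requires `Send`.
            scope.spawn(move || {
                for sample in channel {
                    *sample = ch as f32;
                }
            });
        }
    });
    samples
}
```

Without the `unsafe impl Send`, the `scope.spawn` call would be rejected because raw pointers are not `Send` by default.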

The exact semantics are still not stable, but the Stacked Borrows experiment aims to define exactly how raw pointers and references interact. Basically if you create a raw pointer from a reference, the raw pointer is "valid" for as long as the original reference isn't used. If you use the reference again the raw pointers are "invalidated".

Nope, that's not right. That would make it impossible to ever use APIs like <[T]>::as_mut_ptr() correctly, or even primitive-coerce a mutable reference into a raw pointer without instantly causing UB.

The problem is not the existence of an aliasing raw pointer. Raw pointers are interior mutability primitives; even *mut raw pointers are allowed to co-exist with each other or with a (unique) mutable reference. The problem arises when you re-create a mutable reference from a raw pointer, since two mutable references to the same value aren't allowed to co-exist.

(Of course, this rough first approximation is not even slightly complicated by the right to temporarily reborrow a mutable borrow in a subregion of the original borrow.)
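
A small sketch of that distinction follows. The function name is made up, and the commented-out lines mark the pattern that would be a problem.

```rust
fn raw_pointer_aliasing_demo() -> [f32; 4] {
    let mut data = [1.0_f32, 2.0, 3.0, 4.0];
    let slice: &mut [f32] = &mut data;

    // A raw pointer derived from `slice` may exist alongside it.
    let ptr = slice.as_mut_ptr();

    unsafe {
        // Fine: two `&mut f32` pointing at *different* elements.
        let first = &mut *ptr;
        let second = &mut *ptr.add(1);
        *first += 10.0;
        *second += 20.0;

        // Not fine (left commented out): re-creating a second `&mut` to the
        // same element while `first` is still in use gives two aliasing
        // mutable references.
        // let first_again = &mut *ptr;
        // *first_again += 1.0;
        // *first += 1.0;
    }

    // Going back to the original reference is fine. After this point `ptr`
    // should be treated as invalidated and not used again.
    slice[2] += 30.0;
    data
}
```

The raw pointer itself never causes trouble here; only the references re-created from it have to obey the aliasing rules.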

I threw together a concrete version too.

I realized that I was doing a lot of things manually that we have iterator adapters for. Unfortunately, Range<*mut ...> doesn't implement Iterator, so I had to write that part myself. Here's the cleaned-up version:

impl AudioBuffer<'_> {
    pub fn channels_mut(&mut self)
        ->impl Iterator<Item=impl Iterator<Item=&'_ mut f32>>
    {
        let pointers = PtrIter::from(&mut *self.samples);
        let stride = self.channel_count;
        (0..self.channel_count).map(move |ch| 
            pointers
                .clone()
                .skip(ch).step_by(stride)
                .map(|p| unsafe {&mut *p})
        )
    }
}

#[derive(Clone)]
/// Necessary b/c Range<*mut T> isn't iterable :(
struct PtrIter<T>(std::ops::Range<*mut T>);

impl<T> Iterator for PtrIter<T> {
    type Item = *mut T;
    fn next(&mut self) -> Option<*mut T> {
        if self.0.is_empty() { None }
        else {
            let out = self.0.start;
            self.0.start = unsafe { out.offset(1) };
            Some(out)
        }
    }
}

impl<'a, T> From<&'a mut [T]> for PtrIter<T> {
    fn from(slice:&'a mut [T])->Self {
        PtrIter(slice.as_mut_ptr_range())
    }
}

Awesome :heart: Thank you all for your help!

I just found a second use case where the workaround (not using an outer Iterator at all) feels cumbersome: There's another scenario where the channels are in separate buffers and I want to iterate over frames (zipping all channels and provide mutable access). So the unsafe solution is still on the table.

Thanks to your explanations I think I can come up with a good solution.

My first go was (an attempt at) using Range<*mut _> as an iterator too, heh
