How safely can this communication channel pattern be implemented in Rust?

Problem

Consider an iterative parallel algorithm that works on two structs Sender and Receiver. Every iteration, one thread performs some computation on Sender, and in parallel another thread performs some computation on Receiver.

Every iteration, Sender writes a struct Data, and every iteration, Receiver reads the struct Data that was produced by Sender in the previous iteration. In other words, the struct Data written by Sender in iteration x gets read by Receiver in iteration x+1.

Approach using raw pointers

To transmit Data from Sender to Receiver, I want to use a struct Channel. It has the following structure:

struct Channel {
    current: Data,
    previous: Data,
}

Now, Sender has a pointer to current, and Receiver has a pointer to previous. This way, they share no data and can work concurrently without any synchronisation mechanism. After each iteration, Channel swaps the two Datas, which is how the data gets transmitted from one side to the other.

I can guarantee that the swap operation runs independently of the computations on Sender and Receiver. Hence it also does not require any synchronisation mechanisms. All synchronisation is done "externally" in the main loop of the algorithm, that alternates between performing computations on Sender and Receiver and performing the swap operation on Channel.
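
For concreteness, here is a minimal sketch of the idea, with a placeholder Data standing in for the real payload. This is exactly the pattern whose soundness is in question below:

use std::mem;

struct Data(u64);

struct Channel {
    current: Data,
    previous: Data,
}

impl Channel {
    // Handed to Sender, which writes through it during an iteration.
    fn current_ptr(&mut self) -> *mut Data {
        &mut self.current
    }

    // Handed to Receiver, which reads through it during an iteration.
    fn previous_ptr(&mut self) -> *mut Data {
        &mut self.previous
    }

    // Run by the main loop between iterations, never concurrently with
    // accesses through the two pointers above.
    fn swap(&mut self) {
        mem::swap(&mut self.current, &mut self.previous);
    }
}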

Avoiding raw pointers

Using pointers implies using unsafe which I want to avoid. Ideally I would like to build this whole construct using safe Rust, but that seems impossible. My construction requires that there are two sets of exclusive references to Channel:

  1. Sender has exclusive access to current and Receiver has exclusive access to previous.
  2. The main loop has exclusive access to the whole Channel.

This is undefined behaviour, as far as I know. I can avoid this by using shared references and wrapping the Datas into mutexes, but that seems wasteful since there are no concurrent accesses happening.

So the next idea would be to wrap the raw pointers in special types that keep the unsafe code contained. There would be a pointer type each for current, previous, and Channel. The former two only allow access to the respective field, and the latter only allows swapping the two Datas. Let's call the pointer types CurrentPointer, PreviousPointer, and SwapPointer.

These types could contain all the unsafe code, but that feels dangerous. In contrast to e.g. a Box, which also contains unsafe code internally but cannot be misused to produce undefined behaviour, my special pointer types can easily be used wrongly, e.g. by accidentally accessing current while a different thread swaps the Datas. Their contract can only be expressed in documentation, not in code.
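
For illustration, the three pointer types could look roughly like this (just a sketch; the comments carry the contract that the compiler cannot check):

struct Data(u64);

struct Channel {
    current: Data,
    previous: Data,
}

// The contract lives in documentation only: swap must never run
// concurrently with an access through CurrentPointer or PreviousPointer.
struct CurrentPointer(*mut Channel);
struct PreviousPointer(*mut Channel);
struct SwapPointer(*mut Channel);

// Needed so the pointers can be moved to worker threads; this is already
// a promise the compiler cannot verify.
unsafe impl Send for CurrentPointer {}
unsafe impl Send for PreviousPointer {}
unsafe impl Send for SwapPointer {}

impl CurrentPointer {
    fn get(&mut self) -> &mut Data {
        // Unsound if the swap runs concurrently.
        unsafe { &mut (*self.0).current }
    }
}

impl PreviousPointer {
    fn get(&mut self) -> &mut Data {
        unsafe { &mut (*self.0).previous }
    }
}

impl SwapPointer {
    fn swap(&mut self) {
        let channel = unsafe { &mut *self.0 };
        std::mem::swap(&mut channel.current, &mut channel.previous);
    }
}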

Question

Is it possible to somehow add more restrictions to the pointer types to avoid undefined behaviour, while not using any locking, and permitting my use-case? Something like SwapPointer can only be used in a region of code in which the other two pointers are "dormant", i.e. are not touched or borrowed from.

Is there a reason the main loop of the program can't just own the two copies of the Data and give mutable references to them to the tasks when they are executing?

There are multiple iterations which are to be run in parallel. So, the swap has to communicate with other threads. Here are three safe ways to do it:

  • Use scoped threads. Cost: spawning new threads per iteration.
  • Use rayon tasks. Cost: synchronization done by the rayon task system, and a new dependency.
  • Use two threads which exchange their buffers using a pair of channels (mpsc or similar). Cost: more elaborate synchronization ops than in principle necessary.

I'm working on some example code for these.

I don't think that there's a way that's both safe and completely free of overhead; to use existing threads, something has to tell the threads to go (because either Sender or Receiver will finish an iteration first and have to wait).

Sample code showing all three strategies I described above. You can run it on the Rust Playground.

use std::{mem, thread};

struct Data {}
struct Sender {}
impl Sender {
    fn process(&mut self, data: &mut Data) {
        // TODO: do some actual work...
    }
}

struct Receiver {}
impl Receiver {
    fn process(&mut self, data: &mut Data) {
        // TODO: do some actual work...
    }
}

#[test]
pub fn use_threads() {
    let mut s_data = Data {};
    let mut r_data = Data {};
    let mut sender = Sender {};
    let mut receiver = Receiver {};
    for _ in 1..100 {
        thread::scope(|scope| {
            scope.spawn(|| sender.process(&mut s_data));
            scope.spawn(|| receiver.process(&mut r_data));
            // At the end of scope, `thread::scope` waits for all threads to
            // finish.
        });
        mem::swap(&mut s_data, &mut r_data);
    }
}

#[test]
pub fn use_rayon() {
    let mut s_data = Data {};
    let mut r_data = Data {};
    let mut sender = Sender {};
    let mut receiver = Receiver {};
    for _ in 1..100 {
        // This is like the scoped thread option but uses a global thread pool
        // (which can be reconfigured).
        rayon::join(
            || sender.process(&mut s_data),
            || receiver.process(&mut r_data),
        );
        mem::swap(&mut s_data, &mut r_data);
    }
}

#[test]
pub fn use_channels() {
    use std::sync::mpsc;
    let (data_tx, data_rx) = mpsc::sync_channel(1);
    let (circulate_tx, circulate_rx) = mpsc::sync_channel(2);
    
    // Allocate 2 buffers for use, and stuff them into the channel for unused
    // buffers.
    for _ in 0..2 {
        circulate_tx.send(Box::new(Data {})).unwrap();
    }
    
    // Dedicated Sender thread
    thread::spawn(move || {
        let mut sender = Sender {};
        for _ in 1..100 {
            let mut data = circulate_rx.recv().unwrap();
            sender.process(&mut data);
            data_tx.send(data).unwrap();
        }
    });

    // Dedicated Receiver thread
    let rx_thread = thread::spawn(move || {
        let mut receiver = Receiver {};
        loop {
            match data_rx.recv() {
                Ok(mut data) => {
                    receiver.process(&mut data);
                    // Recirculate the buffer, but ignore errors in case
                    // the sender is already done.
                    let _ = circulate_tx.send(data);
                }
                Err(mpsc::RecvError) => {
                    // All data has been delivered to us and processed by us,
                    // so exit.
                    break;
                }
            }
        }
    });
    
    // Wait till the receiver has processed all items
    rx_thread.join().unwrap();
}

I wrote up that one already (you could use the main thread for one of them, but I think this is easier to understand):

#[derive(Debug)]
struct Data(u64);

impl Data {
    fn mutate(&mut self) {
        self.0 += 1;
    }
}

fn main() {
    let (mut d1, mut d2) = (Data(0), Data(10));
    let (s1, r1) = std::sync::mpsc::sync_channel(0);
    let (s2, r2) = std::sync::mpsc::sync_channel(0);

    let t1 = std::thread::spawn(move || loop {
        if d1.0 >= 1000 {
            break d1;
        }
        d1.mutate();

        let new = r1.recv().unwrap();
        s2.send(d1).unwrap();
        d1 = new;
    });

    let t2 = std::thread::spawn(move || loop {
        d2.mutate();

        if let Err(std::sync::mpsc::SendError(d2)) = s1.send(d2) {
            break d2;
        }
        d2 = r2.recv().unwrap();
    });

    let (d1, d2) = (t1.join().unwrap(), t2.join().unwrap());

    println!("{d1:?} {d2:?}");
}

There are only two things that require unsafe: threads and sending data between them. The standard library has threads covered.

The three options are different ways to send data:

  • Spawning a new thread for each iteration uses the thread spawning machinery for sending a closure containing your data.
  • Rayon is able to do the same, but can send the closure to existing threads.
  • Channels allow you to send only the data. This has the advantage of not needing to reenter the closure every time.

But they said:

So there is already some sort of synchronization going on, my question is why not use that?

The question is not how to do the communication but how to do it safely. All safe mechanisms will involve some synchronization — but that synchronization can, implicitly or explicitly, manage the data and the control flow, rather than being a separate mechanism layered on top of them. In my examples use_threads() and use_rayon(), the synchronization is waiting for the threads/tasks to complete, and the data flow is statically checked by the borrow checker to agree with the control flow; and in use_channels(), the control flow is driven by waiting for the arrival of Box<Data>s in the channels.

@isibboi didn't say what synchronization they already have, but presumably it is equivalent to waiting for the task to complete. So my proposal is not to add more synchronization, but to replace the existing synchronization with something that is aware of the data flow. But perhaps there's some detail that makes that more complex. @isibboi, if you share the code of your main loop as it currently stands — what synchronization you have already got — then we can give more specific advice.

That is technically possible, but in my application it is a trade-off between more code and more unsafe code. The constraints I gave were chosen on purpose to avoid the "more code" approach.

In practice, I have not just one sender and one receiver, but a set of tens of different types. Each type acts as sender or receiver for multiple channels, with different data each. And there are many instances of each type. And, between iterations, new instances of types can be added or removed.

So it would take another big complicated data structure to keep track of which instance consumes which data from which specific other instance. This is what I call the "more code" approach that I am trying to avoid here.

By using the "unsafe code" approach with pointers, I don't need to separately store which data goes where, but I would just have a vector with pointers to all channels, and simply update the channels through that vector after each iteration.


I'll type out some sample code for the main loop that others have requested once I have access to a PC hopefully later today.

In practice, I have not just one sender and one receiver, but a set of tens of different types. Each type acts as sender or receiver for multiple channels, with different data each. And there are many instances of each type. And, between iterations, new instances of types can be added or removed.

Then using explicit channels as demonstrated in my use_channels() function above sounds like a great plan. You explicitly wire up the data-flow relationships in whatever shape you want, and once the threads are spawned they each do their work as soon as it is possible without needing any other coordination.

You seem to ignore the fact that adding compile-time constraints to your type doesn't automagically make it usable as before while the compiler checks they are satisfied. Instead, most of the time you'll have to explicitly prove those constraints in some way, which in your case ultimately requires embedding that proof in the synchronization framework, since that's what really guarantees the constraints are satisfied. In its simplest form this proof is just an exclusive reference to the data you want them to have access to.
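
To make that concrete with a trivial sketch: a function that takes &mut Data already carries the proof, and the borrow checker forces every caller to establish it.

struct Data(u64);

// Taking `&mut Data` is the entire proof: the caller must already hold
// exclusive access, so the body needs no runtime checks at all.
fn compute(data: &mut Data) {
    data.0 += 1;
}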

That covers compile-time checks, but you can also have runtime checks, the simplest being locking, which you however want to avoid for some reason. As others mentioned there are other kinds of runtime checks (e.g. channels), but personally I don't see a reason to use them, given that the locks will be uncontended (hence just a single CAS operation; you can't get much cheaper than that) if your claims are correct. Have you measured your code with those locks to reach the conclusion you want to avoid them?
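
For reference, the lock-based variant I mean is just this (a sketch, with a placeholder Data):

use std::sync::Mutex;

struct Data(u64);

struct Channel {
    current: Mutex<Data>,
    previous: Mutex<Data>,
}

impl Channel {
    // Shared references suffice everywhere, and if the phases really never
    // overlap, every lock taken here is uncontended.
    fn swap(&self) {
        let mut current = self.current.lock().unwrap();
        let mut previous = self.previous.lock().unwrap();
        std::mem::swap(&mut *current, &mut *previous);
    }
}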

I think the solution using channels is very nice, as it can be applied directly and requires no unsafe in my own code. However it does come with unnecessary synchronisation, where unnecessary means "unnecessary assuming the program is correct". I am wondering if there are other approaches that make it harder to introduce bugs than using raw pointers, but are still zero cost.

The main loop code looks something like this, if I were to use raw pointers:

// Structs that perform some computations via the `Computation` trait.
let mut computations: Vec<*mut dyn Computation> = ...;
// The channels.
// I am using `dyn` notation here to abstract over the fact that
// the data may be different for each channel.
let mut channels: Vec<*mut dyn Channel> = ...;

for _ in 0..iterations {
    // Perform the computations in parallel.
    computations
        .par_iter_mut()
        .for_each(|computation| unsafe { (**computation).compute() });
    // Swap the current and previous data in each channel in parallel.
    channels
        .par_iter_mut()
        .for_each(|channel| unsafe { (**channel).swap() });
}

After some thinking I had the idea of using designated "key" types: one key type for the computations and one for the swaps. By having them borrow exclusively from a common master key, I can make sure that at most one key exists at a time. And the master key itself can ensure that it exists only once.

Since the master key takes care of uniqueness, and the key types are constrained by borrowing from it, the key types can be zero-sized and hence passed around without overhead.

I am thinking of something like this:

mod channel_keys {
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::Ordering;
    use core::marker::PhantomData;

    static MASTER_KEY_EXISTS: AtomicBool = AtomicBool::new(false);

    pub struct MasterKey {
        /// A meaningless private field to ensure that
        /// this struct can only be constructed via its constructor.
        dummy: (),
    }

    impl MasterKey {
        pub fn new() -> Self {
            // Assert that the master key does not exist
            // and set it as existing.
            assert!(!MASTER_KEY_EXISTS.swap(true, Ordering::Relaxed));

            // Return a new master key.
            Self { dummy: () }
        }

        pub fn get_data_key(&mut self) -> DataKey<'_> {
            DataKey {
                scope: Default::default(),
            }
        }

        pub fn get_channel_key(&mut self) -> ChannelKey<'_> {
            ChannelKey {
                scope: Default::default(),
            }
        }
    }

    impl Drop for MasterKey {
        fn drop(&mut self) {
            // Assert that the master key exists
            // and set it as not existing.
            assert!(MASTER_KEY_EXISTS.swap(false, Ordering::Relaxed));
        }
    }

    pub struct DataKey<'master_key> {
        scope: PhantomData<&'master_key mut MasterKey>,
    }

    pub struct ChannelKey<'master_key> {
        scope: PhantomData<&'master_key mut MasterKey>,
    }
}

use channel_keys::{ChannelKey, DataKey, MasterKey};

fn use_data_key(key: &DataKey) {
    println!("doing work with data key");
}

fn use_channel_key(key: &ChannelKey) {
    println!("doing work with channel key");
}

fn main() {
    let mut master_key = MasterKey::new();
    
    // Fails because the master key exists already.
    // let mut master_key2 = MasterKey::new();

    for _ in 0..3 {
        let data_key = master_key.get_data_key();
        assert_eq!(std::mem::size_of::<DataKey>(), 0);
        use_data_key(&data_key);

        let channel_key = master_key.get_channel_key();
        assert_eq!(std::mem::size_of::<ChannelKey>(), 0);
        use_channel_key(&channel_key);
        
        // Fails because it implies two simultaneous mutable borrows.
        // use_data_key(&data_key);
    }
}

(Playground)

The keys can then be passed to some custom pointer types to ensure that the pointers can only be dereferenced at the right time.
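
Roughly like this (a sketch, not a soundness proof, assuming the DataKey from the module above and the Data type from before):

struct CurrentPointer(*mut Data);

impl CurrentPointer {
    // Dereferencing requires borrowing a DataKey, and a DataKey cannot
    // coexist with a ChannelKey, so no access can overlap a swap.
    fn get(&mut self, _key: &DataKey<'_>) -> &mut Data {
        unsafe { &mut *self.0 }
    }
}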

I started implementing this in a crate. Maybe if I have time I will benchmark it against a mutex-based solution.

However it does come with unnecessary synchronisation

While I don't mean to discourage you from using the strategy you've come up with, do note that in your code,

for _ in 0..iterations {
    // Perform the computations in parallel.
    computations
        .par_iter_mut()
        .for_each(|computation| unsafe { (**computation).compute() });
    // Swap the current and previous data in each channel in parallel.
    channels
        .par_iter_mut()
        .for_each(|channel| unsafe { (**channel).swap() });
}

there are two synchronization points per iteration of the loop, between the main loop and all of its tasks: each par_iter_mut() must wait for every task it creates before it can return. This is, at least abstractly, just as much synchronization as all the waiting for channels in my channel-based solution.

And, because the main loop controls all tasks, each iteration cannot start until all parts of the previous one have finished, which may reduce parallelism (because there can't be overlap between finishing one iteration and starting the next). On the other hand, if that's not a significant factor, you get the benefits of Rayon's thread pool (assuming that's Rayon's par_iter_mut(), and you have more computations than cores).

I'd recommend benchmarking both. You may be surprised — or not.

Thanks for your replies! I now have an architecture I am satisfied with.

About proving things to the compiler

I am using the borrow mechanics to prove that my accesses to the shared data are "sound".

About performance in a real setting

Since I plan to implement this on tens of types, with different computations and varying numbers of connected channels, it is hard to benchmark at small scale and draw conclusions that would apply at large scale. At large scale, implementing different variants takes time, unless I introduce another layer of abstraction. I don't have the time right now, but the abstraction with the two key types and the special pointer types I proposed above seems to also allow using Mutexes as a "verification". So if I wanted to, I could still benchmark this later, when performance provably becomes an issue.

About the synchronisation required by par_iter_mut

I am aiming at the case where there is a low number of threads but a high number of computations. I then plan to use chunking for the computations, such that rayon does not need that much synchronisation for the par_iter_mut. Then I actually save synchronisation points compared to using the Mutex solution inside the Channel, because each computation needs to access some channel at least once.
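
Roughly like this (a sketch, assuming the computations are boxed trait objects that are Send):

use rayon::prelude::*;

trait Computation {
    fn compute(&mut self);
}

// One rayon task per chunk instead of one per computation, so the
// per-iteration synchronisation shrinks with the chunk size.
fn run_iteration(computations: &mut [Box<dyn Computation + Send>], chunk_size: usize) {
    computations.par_chunks_mut(chunk_size).for_each(|chunk| {
        for computation in chunk {
            computation.compute();
        }
    });
}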