Help needed with Rust Threads

Hi All,

I am looking to write a multithreaded program in which many threads wait on a queue, using a mutex and condition variable combination.

  • The program first creates a thread pool and lets each thread wait on its own dedicated queue.
  • A job scheduler obtains the data to process (from user input, the network, etc.) and takes the lock on the queue of a specific worker thread in the thread pool. It then pushes the data onto that queue.
  • Once the data is pushed, the job scheduler wakes up the thread(s) by calling signal/notify.
  • Upon waking, that thread takes the queue lock, dequeues the work, unlocks, and executes the work.

In C++, this is quick to write with std::thread, std::mutex, std::condition_variable and std::queue.
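For comparison, that design maps fairly directly onto std::thread, std::sync::Mutex, std::sync::Condvar and std::collections::VecDeque. The sketch below is a minimal illustration under stated assumptions, not a drop-in solution: the Message enum, its u64 payload, the modulo routing rule, and the names WorkerQueue, worker_main and run_demo are all hypothetical. Each worker gets its own dedicated queue, and the thread body is an ordinary named function rather than a long closure.

```rust
use std::collections::VecDeque;
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle};

// Hypothetical message type: real code would carry job data or a boxed closure.
enum Message {
    Work(u64),
    Shutdown,
}

// One dedicated queue per worker: Mutex + Condvar + VecDeque play the roles of
// std::mutex, std::condition_variable and std::queue.
struct WorkerQueue {
    items: Mutex<VecDeque<Message>>,
    ready: Condvar,
}

impl WorkerQueue {
    fn new() -> Self {
        WorkerQueue { items: Mutex::new(VecDeque::new()), ready: Condvar::new() }
    }

    // Scheduler side: lock, push, release the lock, then wake the waiting worker.
    fn push(&self, msg: Message) {
        self.items.lock().unwrap().push_back(msg);
        self.ready.notify_one();
    }

    // Worker side: wait until the queue is non-empty, then dequeue. The guard is
    // dropped when this function returns, so the caller runs the work unlocked.
    fn pop(&self) -> Message {
        let mut items = self.items.lock().unwrap();
        loop {
            if let Some(msg) = items.pop_front() {
                return msg;
            }
            // wait() releases the lock while sleeping and reacquires it on wakeup;
            // the loop re-checks the queue to cope with spurious wakeups.
            items = self.ready.wait(items).unwrap();
        }
    }
}

// The thread function is a plain named function, like a C++ thread entry point.
fn worker_main(id: usize, queue: Arc<WorkerQueue>) -> u64 {
    let mut total = 0;
    loop {
        match queue.pop() {
            Message::Work(n) => total += n, // "execute" the job
            Message::Shutdown => break,
        }
    }
    println!("worker {id} done, total = {total}");
    total
}

fn run_demo(num_workers: usize) -> u64 {
    let queues: Vec<Arc<WorkerQueue>> =
        (0..num_workers).map(|_| Arc::new(WorkerQueue::new())).collect();

    let handles: Vec<JoinHandle<u64>> = queues
        .iter()
        .enumerate()
        .map(|(id, q)| {
            let q = Arc::clone(q);
            thread::spawn(move || worker_main(id, q))
        })
        .collect();

    // The "scheduler": route each job to one specific worker's queue.
    for job in 1..=10u64 {
        queues[(job as usize) % num_workers].push(Message::Work(job));
    }
    // Shutdown is queued after the work, so each worker drains its queue first.
    for q in &queues {
        q.push(Message::Shutdown);
    }

    handles.into_iter().map(|h| h.join().unwrap()).sum()
}

fn main() {
    println!("sum of all jobs = {}", run_demo(3));
}
```

Because the shutdown message travels through the same FIFO queue as the work, no separate flag is needed in this variant.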

I am finding it difficult to work out how to structure the thread function in Rust. Most of the examples use closures, and in large-scale code, closures that span more than 50 lines are hard to read and understand.

Is there an example similar to the C++ approach, or a link that provides some educational material on Rust multithreading, locking and data-sharing mechanisms? Much appreciated.

Here's a video on Rust threads and sync primitives:

There are multiple Rust libraries that provide thread pools. Most of them use closures as their interface because closures are the most flexible way to hand over arbitrary work. You can also create your own set of threads and use an MPMC[1] channel library (such as flume or crossbeam-channel) to distribute work to them, which effectively turns them into a thread pool. I would not recommend writing a thread pool and job queue from scratch unless your goal is specifically to gain the experience of doing so. Is your goal to write a thread pool, or to write a program that uses a thread pool?
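The standard library has no MPMC channel, but a common std-only stand-in (the one used in the Rust Book's thread pool chapter) is to share a single std::sync::mpsc::Receiver between the workers behind Arc<Mutex<...>>. This sketch assumes jobs are boxed closures returning a u64; the Job alias and the worker and run_demo names are illustrative only. With flume or crossbeam-channel you could instead clone the receiver directly and drop the Mutex.

```rust
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;

// Illustrative job type: a one-shot closure that produces a result.
type Job = Box<dyn FnOnce() -> u64 + Send + 'static>;

fn worker(rx: Arc<Mutex<mpsc::Receiver<Job>>>, results: mpsc::Sender<u64>) {
    loop {
        // The temporary MutexGuard is dropped at the end of this statement,
        // so the lock is held only while receiving, not while running the job.
        let job = rx.lock().unwrap().recv();
        match job {
            Ok(job) => results.send(job()).unwrap(),
            Err(_) => break, // every Sender is gone: no more work will arrive
        }
    }
}

fn run_demo(num_workers: usize) -> Vec<u64> {
    let (job_tx, job_rx) = mpsc::channel::<Job>();
    let job_rx = Arc::new(Mutex::new(job_rx));
    let (res_tx, res_rx) = mpsc::channel();

    let handles: Vec<_> = (0..num_workers)
        .map(|_| {
            let rx = Arc::clone(&job_rx);
            let tx = res_tx.clone();
            thread::spawn(move || worker(rx, tx))
        })
        .collect();
    drop(res_tx); // only the workers hold result senders now

    for i in 1..=5u64 {
        job_tx.send(Box::new(move || i * i)).unwrap();
    }
    drop(job_tx); // closing the job channel tells the workers to exit

    for h in handles {
        h.join().unwrap();
    }
    // Workers finish in no particular order, so sort for a stable result.
    let mut results: Vec<u64> = res_rx.iter().collect();
    results.sort();
    results
}

fn main() {
    println!("{:?}", run_demo(4));
}
```

Note that writing the loop as `while let Ok(job) = rx.lock().unwrap().recv()` would keep the guard alive for the whole loop body and serialize the workers, which is why the receive happens in its own statement here.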


  1. multiple producer, multiple consumer


My approach is to put the main work of a thread into a function (or a method) and call it from the closure passed to thread::spawn, as in most of the examples one sees.
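A minimal sketch of that approach, assuming each thread's state fits in a struct: the hypothetical Worker type owns its data, its run method holds all of the real logic, and the closure handed to thread::spawn is a one-line trampoline into it. However long the thread body grows, it then reads like any other method.

```rust
use std::thread::{self, JoinHandle};

// Hypothetical worker type: per-thread state lives in the struct,
// and the thread's main work lives in an ordinary method.
struct Worker {
    id: usize,
    data: Vec<u64>,
}

impl Worker {
    // The "thread function": takes self by value, so the thread owns its state.
    fn run(self) -> u64 {
        let sum: u64 = self.data.iter().sum();
        println!("worker {} summed {} items", self.id, self.data.len());
        sum
    }
}

fn spawn_workers(chunks: Vec<Vec<u64>>) -> Vec<JoinHandle<u64>> {
    chunks
        .into_iter()
        .enumerate()
        .map(|(id, data)| {
            let worker = Worker { id, data };
            // The closure only moves the worker into the thread and calls run().
            thread::spawn(move || worker.run())
        })
        .collect()
}

fn total(chunks: Vec<Vec<u64>>) -> u64 {
    spawn_workers(chunks)
        .into_iter()
        .map(|h| h.join().unwrap())
        .sum()
}

fn main() {
    let chunks = vec![vec![1, 2, 3], vec![4, 5], vec![6]];
    println!("total = {}", total(chunks));
}
```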

Typically I use Rust's standard channels (std::sync::mpsc) or crossbeam::channel for queueing things.
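The "one dedicated queue per worker" design from the original question can be sketched with one std::sync::mpsc channel per worker. This assumes a hypothetical Job enum and a simple modulo routing rule standing in for the real scheduler. The channel handles locking and wakeups internally: send plays the role of push-plus-notify, and iterating over the Receiver plays the role of wait-plus-dequeue. crossbeam::channel offers a similar shape (for example crossbeam_channel::unbounded()) if you need its extra features.

```rust
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread::{self, JoinHandle};

// Hypothetical job payload for illustration.
enum Job {
    Process(u64),
}

// Each worker owns the receiving end of its own channel: its dedicated queue.
fn worker_loop(id: usize, rx: Receiver<Job>) -> Vec<u64> {
    let mut processed = Vec::new();
    // Iteration blocks until a job arrives and ends once the Sender is dropped.
    for job in rx {
        match job {
            Job::Process(n) => processed.push(n),
        }
    }
    println!("worker {id} processed {} jobs", processed.len());
    processed
}

fn run_demo(num_workers: usize, jobs: &[u64]) -> Vec<Vec<u64>> {
    let mut senders: Vec<Sender<Job>> = Vec::new();
    let mut handles: Vec<JoinHandle<Vec<u64>>> = Vec::new();
    for id in 0..num_workers {
        let (tx, rx) = mpsc::channel();
        senders.push(tx);
        handles.push(thread::spawn(move || worker_loop(id, rx)));
    }

    // The scheduler picks a specific worker for each job. No explicit lock or
    // notify is needed: send() wakes the receiver if it is blocked.
    for &n in jobs {
        senders[(n as usize) % num_workers]
            .send(Job::Process(n))
            .unwrap();
    }
    drop(senders); // closing every channel lets each worker's loop finish

    handles.into_iter().map(|h| h.join().unwrap()).collect()
}

fn main() {
    let per_worker = run_demo(3, &[1, 2, 3, 4, 5, 6, 7]);
    println!("{per_worker:?}");
}
```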

Rust is no different: only a few dozen lines are needed for a basic thread pool:

use std::collections::VecDeque;
use std::panic::{catch_unwind, UnwindSafe};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle};

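// A job is any one-shot closure that can be sent to a worker thread; UnwindSafe
// is required so that catch_unwind can run it.
type Job = Box<dyn FnOnce() + Send + UnwindSafe>;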

struct Queue {
    jobs: VecDeque<Job>,
    shutdown: bool,
}

struct Inner {
    cvar: Condvar,
    queue: Mutex<Queue>,
}

impl Inner {
    fn new() -> Self {
        let queue = Queue {
            jobs: VecDeque::new(),
            shutdown: false,
        };

        Inner {
            cvar: Condvar::new(),
            queue: Mutex::new(queue),
        }
    }

    fn shutdown(&self) {
        let mut queue = self.queue.lock().unwrap();
        queue.shutdown = true;
        self.cvar.notify_all();
    }

    fn push(&self, job: Job) {
        let mut queue = self.queue.lock().unwrap();
        queue.jobs.push_back(job);
        self.cvar.notify_one();
    }

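    // Block until a job is available or the pool is shutting down. Queued jobs are
    // handed out before shutdown is honoured, so join() drains the queue first.
    fn pop(&self) -> Option<Job> {
        let mut queue = self.queue.lock().unwrap();
        loop {
            if let Some(job) = queue.jobs.pop_front() {
                return Some(job);
            } else if queue.shutdown {
                return None;
            } else {
                // wait() releases the lock while sleeping and reacquires it on wakeup;
                // the loop re-checks the state to cope with spurious wakeups.
                queue = self.cvar.wait(queue).unwrap();
            }
        }
    }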

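    // Worker loop: pop() has already released the lock by the time a job runs,
    // and catch_unwind keeps a panicking job from taking the worker thread down.
    fn run(&self) {
        while let Some(job) = self.pop() {
            let _ignore_panic = catch_unwind(job);
        }
    }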
}

pub struct ThreadPool {
    inner: Arc<Inner>,
    workers: Vec<JoinHandle<()>>,
}

impl ThreadPool {
    pub fn new(size: usize) -> Self {
        let inner = Arc::new(Inner::new());
        let mut workers = Vec::with_capacity(size);

        for _ in 0..size {
            let inner = Arc::clone(&inner);
            let worker = thread::spawn(move || inner.run());
            workers.push(worker);
        }

        ThreadPool { inner, workers }
    }

    pub fn push(&self, job: impl FnOnce() + Send + UnwindSafe + 'static) {
        self.inner.push(Box::new(job));
    }

    pub fn join(mut self) {
        self.inner.shutdown();
        for worker in self.workers.drain(..) {
            worker.join().unwrap();
        }
    }
}

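// Dropping the pool without join() only signals shutdown: the workers finish the
// queued jobs and exit on their own, but nothing waits for them.
impl Drop for ThreadPool {
    fn drop(&mut self) {
        self.inner.shutdown();
    }
}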

fn example_job(i: u64) {
    let thread_id = thread::current().id();

    if i == 5 {
        println!("{thread_id:?}: Processing task #{i} and panicking...");
        panic!("example panic");
    } else {
        println!("{thread_id:?}: Processing task #{i}...");
        thread::sleep(std::time::Duration::from_millis(i * 250));
    }
}

fn main() {
    let pool = ThreadPool::new(4);

    for i in 1..=10 {
        pool.push(move || example_job(i));
    }

    pool.join();

    println!("All tasks completed");
}
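As a possible refinement, the manual loop in pop() can be replaced with Condvar::wait_while, which blocks while a predicate holds and re-checks it after every wakeup, spurious ones included. The self-contained sketch below keeps the same jobs/shutdown idea with a generic item type; the Shared type and the drain_with_consumer helper are hypothetical names chosen for illustration.

```rust
use std::collections::VecDeque;
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

struct Queue<T> {
    jobs: VecDeque<T>,
    shutdown: bool,
}

struct Shared<T> {
    queue: Mutex<Queue<T>>,
    cvar: Condvar,
}

impl<T> Shared<T> {
    fn new() -> Self {
        Shared {
            queue: Mutex::new(Queue { jobs: VecDeque::new(), shutdown: false }),
            cvar: Condvar::new(),
        }
    }

    fn push(&self, item: T) {
        self.queue.lock().unwrap().jobs.push_back(item);
        self.cvar.notify_one();
    }

    fn shutdown(&self) {
        self.queue.lock().unwrap().shutdown = true;
        self.cvar.notify_all();
    }

    // Same semantics as the pool's pop(): sleep while the queue is empty and
    // shutdown has not been requested; no explicit loop is needed.
    fn pop(&self) -> Option<T> {
        let guard = self.queue.lock().unwrap();
        let mut queue = self
            .cvar
            .wait_while(guard, |q| q.jobs.is_empty() && !q.shutdown)
            .unwrap();
        // Non-empty: hand out a job. Still empty here means shutdown: None.
        queue.jobs.pop_front()
    }
}

fn drain_with_consumer(items: Vec<u32>) -> Vec<u32> {
    let shared: Arc<Shared<u32>> = Arc::new(Shared::new());
    let consumer = {
        let shared = Arc::clone(&shared);
        thread::spawn(move || {
            let mut seen = Vec::new();
            while let Some(item) = shared.pop() {
                seen.push(item);
            }
            seen
        })
    };
    for item in items {
        shared.push(item);
    }
    shared.shutdown();
    consumer.join().unwrap()
}

fn main() {
    println!("{:?}", drain_with_consumer(vec![1, 2, 3]));
}
```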
