Map with threads

I have a vector of items and need to run a function on each of them, collecting the results in the same order. I could use map for this, but I wanted the calls to run in parallel, in threads, so I wrote my own map_thread. It was fun to write, but:

  • I am quite sure a solution for this already exists that I could not find. I'd appreciate a pointer.
  • I'd be happy to get suggestions for improving this code so I can learn from them.
use std::sync::mpsc;
use std::thread;
use std::marker::Send;


fn main() {
    let numbers: Vec<i32> = (1..=10).collect();
    println!("{:?}", numbers);

    let doubles = numbers.iter().map(double_function).collect::<Vec<i32>>();
    println!("{:?}", doubles);

    let doubles = map_thread(&numbers, double_function, 3);
    println!("{:?}", doubles);
}

fn double_function(num: &i32) -> i32 {
    2 * num
}

fn map_thread<Tin:Send+Copy+'static, Tout:Ord+Send+Copy+'static>(params: &[Tin], func: fn(&Tin) -> Tout, max_threads: i32) -> Vec<Tout> {
    let (tx, rx) = mpsc::channel();
    let mut thread_count = 0;
    let mut started = 0;
    let mut finished = 0;
    let mut results: Vec<(i32, Tout)> = vec![];

    for paramx in params.iter() {
        let number = *paramx;
        started += 1;
        let mytx = tx.clone();
        thread::Builder::new().name(format!("{}", started)).spawn(move || {
            let id: i32 = thread::current().name().unwrap().parse().unwrap();
            let res = func(&number);
            mytx.send((id, res)).unwrap();
        }).unwrap();

        thread_count += 1;
        if thread_count >= max_threads {
            let received = rx.recv().unwrap();
            results.push(received);
            finished += 1;
        }
    }

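    // Drop the original sender so this loop also terminates when every result
    // was already received above (e.g. max_threads == 1) or params is empty.
    drop(tx);
    for received in rx {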
        finished += 1;
        results.push(received);
        if finished >= started {
            break;
        }
    }

    results.sort();
    results.iter().map(|item| item.1).collect::<Vec<Tout>>()

}

You are looking for rayon. No code example, as I'm on my phone.


Your map_thread is doing more work than it needs to:

  • The id is smuggled through the thread name and parsed back from a string, instead of simply being captured in the closure
  • Sorting the results afterwards can be avoided, because you know exactly which position each calculated result belongs in
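
To illustrate both points with the smallest change to the original channel-based approach, here is a rough sketch. The name map_thread_indexed, the usize thread limit, and the Option slots are illustrative choices, not part of the code above. Each closure captures its index directly, and every result is written into a pre-sized slot, so neither thread names nor a final sort are needed. It assumes func does not panic; a panicking worker would never send its result.

```rust
use std::sync::mpsc;
use std::thread;

// Hypothetical variant of map_thread: the index travels with the closure,
// and each result is stored at its final position as soon as it arrives.
fn map_thread_indexed<Tin, Tout>(
    params: &[Tin],
    func: fn(&Tin) -> Tout,
    max_threads: usize,
) -> Vec<Tout>
where
    Tin: Send + Copy + 'static,
    Tout: Send + 'static,
{
    let (tx, rx) = mpsc::channel::<(usize, Tout)>();
    // One empty slot per input; slot i will hold the result for params[i]
    let mut slots: Vec<Option<Tout>> = params.iter().map(|_| None).collect();
    // Treat a limit of 0 as 1 so the loop can always make progress
    let limit = max_threads.max(1);
    let mut running = 0;

    for (index, param) in params.iter().enumerate() {
        // At the limit: wait for one worker to report before starting another
        if running >= limit {
            let (i, res) = rx.recv().unwrap();
            slots[i] = Some(res);
            running -= 1;
        }

        let value = *param;
        let tx = tx.clone();
        thread::spawn(move || {
            // `index` is captured by the closure; no thread name parsing needed
            tx.send((index, func(&value))).unwrap();
        });
        running += 1;
    }

    // Drop the original sender so the loop below ends once all workers are done
    drop(tx);
    for (i, res) in rx {
        slots[i] = Some(res);
    }

    slots
        .into_iter()
        .map(|slot| slot.expect("every worker sends exactly one result"))
        .collect()
}
```

Because the output no longer has to be sorted, Tout needs neither Ord nor Copy in this variant.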

If I were writing this from scratch, I'd do something like this:

use std::sync::{Condvar, Mutex};
use std::thread; // thread::scope requires Rust 1.63 or later

fn map_thread<Tin: Sync, Tout: Send>(params: &[Tin], func: fn(&Tin) -> Tout, max_threads: i32) -> Vec<Tout> {
    let thread_count = Mutex::new(0);
    let notify_finish = Condvar::new();
    
    // Use scoped threads so that we can pass references around
    thread::scope(|scope| {
        let mut handles = vec![];
        for xref in params {
            // Wait for earlier threads to finish, if necessary
            let mut guard = thread_count.lock().unwrap();
            while *guard >= max_threads {
                guard = notify_finish.wait(guard).unwrap();
            }
            
            // Increment running thread count
            *guard += 1;
            
            handles.push(scope.spawn(|| {
                // Run calculation
                let res = func(xref);
                
                // Report success
                *(thread_count.lock().unwrap()) -= 1;
                notify_finish.notify_one();
                
                // Return calculation result
                res
            }));
        }
        
        // Read the return values of each thread, in order
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    })
}

Alternatively, you can use more of a work-stealing style, where a fixed set of max_threads workers pulls jobs from a shared iterator:

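use std::sync::Mutex;
use std::thread; // thread::scope requires Rust 1.63 or later

fn map_thread<Tin: Sync, Tout: Send + Default + Clone>(params: &[Tin], func: fn(&Tin) -> Tout, max_threads: i32) -> Vec<Tout> {
    // Pre-fill the output so each job can write straight into its own slot
    let mut result = vec![Tout::default(); params.len()];
    // Shared queue of (input, output slot) pairs; workers pull one job at a time
    let jobs = Mutex::new(params.iter().zip(result.iter_mut()));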
    
    thread::scope(|scope| {
        for _ in 0..max_threads {
            scope.spawn(|| {
                // Closure used to prevent mutex being held during loop body
                let next = || jobs.lock().unwrap().next();
                
                while let Some((i,o)) = next() {
                    *o = func(i);
                }
            });
        }
    });
    
    result
}

Edit: The previous work-stealing version was accidentally serialized due to drop timing rules: the mutex stayed locked during the loop body, so only one worker could run at a time.
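
The earlier version is not shown here, so the following is only a sketch of the general pitfall, assuming it locked the mutex directly in the while let condition. A temporary created in that condition, such as the guard returned by lock(), is dropped only at the end of each iteration, which keeps the mutex locked while the body runs. The function names are made up for illustration, and try_lock is used purely as a probe to observe whether the lock is still held.

```rust
use std::sync::Mutex;

// Hypothetical probe: locks directly in the `while let` condition.
// Returns true if the mutex was still locked while the loop body ran.
fn held_with_inline_lock() -> bool {
    let jobs = Mutex::new(vec![1, 2, 3].into_iter());
    let mut held = false;
    // The guard temporary created here lives until the end of the iteration
    while let Some(_job) = jobs.lock().unwrap().next() {
        // try_lock fails if the guard from the condition is still alive
        held |= jobs.try_lock().is_err();
    }
    held
}

// Hypothetical probe: same loop, but the lock is taken inside a closure,
// so the guard is dropped when the closure returns.
fn held_with_closure_lock() -> bool {
    let jobs = Mutex::new(vec![1, 2, 3].into_iter());
    let next = || jobs.lock().unwrap().next();
    let mut held = false;
    while let Some(_job) = next() {
        // The guard is already gone here, so try_lock succeeds
        held |= jobs.try_lock().is_err();
    }
    held
}
```

With the inline lock, every worker would wait for the one currently inside the loop body; wrapping the lock in a closure releases it before the body starts.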
