I have a vector of items and I need to run a function on each of them, collecting the results in the same order. I could use `map` to do this, but I wanted to run the calls in parallel, in threads, so I wrote an implementation of `map_thread`. It was fun to write, but:
- I am quite sure a solution for this already exists that I could not find. I'd appreciate a pointer.
- I'd be happy to get suggestions for improving this code so I can learn from them.
use std::marker::Send;
use std::panic::{self, AssertUnwindSafe};
use std::sync::mpsc;
use std::thread;
/// Demo driver: prints the numbers 1 through 10, then their doubles computed
/// with a plain iterator `map`, then their doubles computed with `map_thread`
/// using a limit of three threads.
fn main() {
    let inputs = (1..=10).collect::<Vec<i32>>();
    println!("{:?}", inputs);

    // Sequential reference result.
    let sequential: Vec<i32> = inputs.iter().map(double_function).collect();
    println!("{:?}", sequential);

    // Same computation, spread over worker threads.
    let threaded = map_thread(&inputs, double_function, 3);
    println!("{:?}", threaded);
}
/// Returns twice the value behind `num`.
fn double_function(num: &i32) -> i32 {
    let value = *num;
    value * 2
}
/// Applies `func` to every element of `params` on worker threads and returns
/// the results in the same order as the inputs.
///
/// One thread is spawned per element, but no more than `max_threads` workers
/// are in flight at any time; a `max_threads` below 1 is treated as 1. An
/// empty `params` yields an empty vector without spawning anything.
///
/// # Panics
///
/// Panics if the operating system refuses to spawn a worker thread. If `func`
/// panics for some element, that panic is re-raised on the calling thread
/// instead of leaving the caller blocked waiting for a result.
fn map_thread<Tin, Tout>(params: &[Tin], func: fn(&Tin) -> Tout, max_threads: i32) -> Vec<Tout>
where
    Tin: Send + Copy + 'static,
    Tout: Send + 'static,
{
    // Clamped to at least 1, so the value is positive and the cast is lossless.
    let limit = max_threads.max(1) as usize;

    let (tx, rx) = mpsc::channel();

    // One slot per input; each worker's result lands at its input index, so
    // no sorting (and no `Ord` bound on the output) is needed.
    let mut slots: Vec<Option<Tout>> = Vec::new();
    slots.resize_with(params.len(), || None);

    // Stores a finished worker's value, or re-raises the worker's panic here.
    let mut record = |(slot, outcome): (usize, thread::Result<Tout>)| match outcome {
        Ok(value) => slots[slot] = Some(value),
        Err(payload) => panic::resume_unwind(payload),
    };

    let mut in_flight = 0usize;
    for (index, &param) in params.iter().enumerate() {
        // At the limit: wait for any one worker to finish before starting another.
        if in_flight >= limit {
            record(rx.recv().expect("this function still holds a sender"));
            in_flight -= 1;
        }

        let worker_tx = tx.clone();
        thread::Builder::new()
            // Thread names are the 1-based position of the element being processed.
            .name((index + 1).to_string())
            .spawn(move || {
                // Catch a panic in `func` so the caller always gets a message
                // for this element and can never block forever on it.
                let outcome = panic::catch_unwind(AssertUnwindSafe(|| func(&param)));
                // Sending only fails if the caller has already gone away
                // (it is unwinding); there is nobody left to notify then.
                let _ = worker_tx.send((index, outcome));
            })
            .expect("failed to spawn worker thread");
        in_flight += 1;
    }

    // Drop our own sender so the receiver iterator ends once every worker has
    // reported, instead of blocking forever when nothing is left to receive.
    drop(tx);
    for message in rx {
        record(message);
    }

    slots
        .into_iter()
        .map(|slot| slot.expect("every worker reports exactly one result"))
        .collect()
}