@quinedot Thanks for your help, it really helped me a lot. I think the original error I mentioned at the top of this page is caused by a self-join: the main thread would exit immediately while the worker pool was still working, and my jobs hold a reference count (an Arc) to the worker pool itself, which caused the worker pool to be dropped inside one of the worker threads. So I added a channel to wait for all workers to finish their jobs, and it worked without errors.
use std::sync::{Arc, Mutex};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread::{self, JoinHandle};

pub type Job = Box<dyn FnOnce() + Send + 'static>;

pub struct Woker {
    pub id: usize,
    pub thread: Option<JoinHandle<()>>,
}

impl Woker {
    pub fn new(
        id: usize,
        job_receiver: Arc<Mutex<Receiver<Job>>>,
        finish_sender: Arc<Mutex<Sender<Result<(), ()>>>>,
        counter: Arc<Mutex<usize>>,
        size: usize,
    ) -> Self {
        Self {
            id,
            thread: Some(thread::spawn(move || loop {
                let result = job_receiver.lock().unwrap().recv();
                match result {
                    Ok(job) => {
                        job();
                        let mut counter_guard = counter.lock().unwrap();
                        *counter_guard += 1;
                        if *counter_guard == size {
                            finish_sender.lock().unwrap().send(Ok(())).unwrap();
                        }
                        drop(counter_guard);
                    }
                    Err(_) => break,
                }
            })),
        }
    }
}
pub struct WokerPool {
    pub counter: Arc<Mutex<usize>>,
    workers: Vec<Woker>,
    job_sender: Option<Sender<Job>>,
}

pub type ThreadSafeWorkPool = Arc<WokerPool>;

impl WokerPool {
    pub fn new(size: usize, finish_sender: Sender<Result<(), ()>>) -> Self {
        let (job_sender, job_receiver) = channel::<Job>();
        let thread_safe_job_receiver = Arc::new(Mutex::new(job_receiver));
        let thread_safe_finish_sender = Arc::new(Mutex::new(finish_sender));
        let thread_safe_counter = Arc::new(Mutex::new(size));
        let mut workers = Vec::with_capacity(size);
        for i in 0..size {
            workers.push(Woker::new(
                i,
                Arc::clone(&thread_safe_job_receiver),
                Arc::clone(&thread_safe_finish_sender),
                Arc::clone(&thread_safe_counter),
                size,
            ));
        }
        Self {
            job_sender: Some(job_sender),
            workers,
            counter: thread_safe_counter,
        }
    }

    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        self.job_sender.as_ref().unwrap().send(Box::new(f)).unwrap();
    }
}

impl Drop for WokerPool {
    fn drop(&mut self) {
        // Dropping the sender closes the channel, so workers break out of
        // their receive loop once the remaining jobs are drained.
        drop(self.job_sender.take());
        for worker in &mut self.workers {
            // println!("Work id {} join", worker.id);
            if let Some(thread) = worker.thread.take() {
                if thread.join().is_err() {
                    println!("Error when join {}", worker.id);
                }
            }
        }
    }
}
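For completeness, here is a minimal sketch (not my actual pool code) of the waiting pattern described above: the main thread blocks on the finish channel, so the last strong reference is dropped on the main thread rather than inside a worker. The names are just illustrative.

```rust
use std::sync::mpsc::channel;
use std::thread;

fn main() {
    // finish channel: the pool signals here when the last job is done
    let (finish_sender, finish_receiver) = channel::<Result<(), ()>>();

    // stand-in for the pool's worker threads
    let worker = thread::spawn(move || {
        // ... run jobs ...
        finish_sender.send(Ok(())).unwrap();
    });

    // Block the main thread until completion is signalled, so the pool
    // (and the Arc cycle held by the jobs) is dropped here, not inside
    // a worker thread.
    finish_receiver.recv().unwrap().unwrap();
    worker.join().unwrap();
}
```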
But one more problem occurred: my parallel visit is much slower than the sequential one (parallel: 2.4 ms, recursive: 390 µs, benchmarked with criterion).
pub type ThreadSafeWorkPool = Arc<WokerPool>;

pub fn parallel_visit(root: &ThreadSafeNode, worker_pool: &ThreadSafeWorkPool) {
    let tsn_guard = root.lock().unwrap();
    // println!("index : {}", tsn_guard.value);
    let Node { left, right, .. } = &*tsn_guard;
    match (left.as_ref(), right.as_ref()) {
        (None, None) => {}
        (Some(single), None) | (None, Some(single)) => {
            parallel_visit(single, worker_pool)
        }
        (Some(left), Some(right)) => {
            let mut counter_guard = worker_pool.counter.lock().unwrap();
            if *counter_guard == 0 {
                drop(counter_guard);
                parallel_visit(left, worker_pool);
                parallel_visit(right, worker_pool);
            } else {
                *counter_guard -= 1;
                drop(counter_guard);
                let wp_arc = Arc::clone(worker_pool);
                let left_arc = Arc::clone(left);
                worker_pool.execute(move || {
                    parallel_visit(&left_arc, &wp_arc);
                });
                parallel_visit(right, worker_pool)
            }
        }
    }
}
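(The post does not show the `Node` / `ThreadSafeNode` types; a guessed minimal definition, consistent with the `root.lock()` and `Node { left, right, .. }` destructuring above, would look like this.)

```rust
use std::sync::{Arc, Mutex};

// Hypothetical tree types; field names match the destructuring used
// in parallel_visit / recursive_visit.
pub struct Node {
    pub value: usize,
    pub left: Option<ThreadSafeNode>,
    pub right: Option<ThreadSafeNode>,
}
pub type ThreadSafeNode = Arc<Mutex<Node>>;

fn main() {
    // Build a two-node tree and lock the root, as the visit functions do.
    let leaf = Arc::new(Mutex::new(Node { value: 1, left: None, right: None }));
    let root: ThreadSafeNode =
        Arc::new(Mutex::new(Node { value: 0, left: Some(leaf), right: None }));
    assert_eq!(root.lock().unwrap().value, 0);
}
```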
But when I replace part of the parallel visit function with the recursive (sequential) function, it is faster than purely sequential (parallel: 280 µs, recursive: 403 µs).
pub type ThreadSafeWorkPool = Arc<WokerPool>;

pub fn parallel_visit(root: &ThreadSafeNode, worker_pool: &ThreadSafeWorkPool) {
    let tsn_guard = root.lock().unwrap();
    // println!("index : {}", tsn_guard.value);
    let Node { left, right, .. } = &*tsn_guard;
    match (left.as_ref(), right.as_ref()) {
        (None, None) => {}
        (Some(single), None) | (None, Some(single)) => {
            parallel_visit(single, worker_pool)
        }
        (Some(left), Some(right)) => {
            let mut counter_guard = worker_pool.counter.lock().unwrap();
            if *counter_guard == 0 {
                drop(counter_guard);
                parallel_visit(left, worker_pool);
                parallel_visit(right, worker_pool);
            } else {
                *counter_guard -= 1;
                drop(counter_guard);
                // let wp_arc = Arc::clone(worker_pool);
                let left_arc = Arc::clone(left);
                worker_pool.execute(move || {
                    // parallel_visit(&left_arc, &wp_arc);
                    recursive_visit(&left_arc);
                });
                parallel_visit(right, worker_pool)
            }
        }
    }
}

pub fn recursive_visit(root: &ThreadSafeNode) {
    let tsn_guard = root.lock().unwrap();
    // println!("index: {}", tsn_guard.value);
    let Node { left, right, .. } = &*tsn_guard;
    match (left.as_ref(), right.as_ref()) {
        (None, None) => {}
        (Some(single), None) | (None, Some(single)) => {
            recursive_visit(single)
        }
        (Some(left), Some(right)) => {
            recursive_visit(left);
            recursive_visit(right);
        }
    }
}
I think the difference is caused by the overhead of locking the counter and sending job messages. I would like to hear your advice ~ thanks ~
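To make that overhead guess a bit more concrete, here is a rough standalone sketch (numbers will vary by machine; this is not my benchmark) that times the two costs separately: an uncontended mutex lock/unlock, which every visited node pays, versus a boxed-closure channel send, which every dispatched job additionally pays.

```rust
use std::sync::Mutex;
use std::sync::mpsc::channel;
use std::time::Instant;

fn main() {
    let n = 100_000u32;

    // Per-node work in the visit is roughly one uncontended lock/unlock.
    let counter = Mutex::new(0u32);
    let t = Instant::now();
    for _ in 0..n {
        *counter.lock().unwrap() += 1;
    }
    println!("{} lock/unlock pairs: {:?}", n, t.elapsed());

    // Each dispatched job additionally pays a heap-allocated closure
    // plus a channel send.
    let (tx, rx) = channel::<Box<dyn FnOnce() + Send>>();
    let t = Instant::now();
    for _ in 0..n {
        tx.send(Box::new(|| {})).unwrap();
    }
    println!("{} job sends: {:?}", n, t.elapsed());
    drop(rx);

    assert_eq!(*counter.lock().unwrap(), n);
}
```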
By the way, I added Instant to measure the time cost of each job in the worker pool:
let now = Instant::now();
job();
println!("Cost of job: {} (Work id {:?})", now.elapsed().as_nanos(), id);
I found that in the version that combines parallel_visit with recursive_visit, each worker's job takes more time than in the purely parallel version. Maybe the purely parallel version is slower overall because it sends many more jobs, and each one pays for a channel send, a lock, and waiting on the counter.