Multithreaded directory traversal & hashing with a synchronized job stack

I'm experimenting with writing various implementations for computing a Merkle tree hash of a directory of files, and the following is the least straightforward implementation so far (yet also the fastest). It's inspired by this Python code for multithreaded directory traversal; besides translating it into Rust, I've made the job management more object-oriented (or at least more encapsulated) and also added in hashing of files within the threads.

Directories to traverse and files to hash are stored in a job stack that tracks how many unfinished jobs there are, where a job is only considered "finished" after the value is both retrieved from the stack and dropped; in between, new jobs can be added to the stack based on the current job. When fetching from an empty stack, the thread blocks until either a new job is added or until all jobs are marked finished, the latter event indicating that there will be no more jobs. Popping items from the job stack is done via an iterator that produces wrapper structs that decrement the job count on drop.
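To make the counting rule concrete, here is a minimal single-threaded sketch of the drop-guard idea. `Jobs` and `Guard` are hypothetical names, not the types from the code below. The blocking Condvar part is left out. The only point is that a job that has been popped but whose guard is still alive still counts as unfinished.

```rust
use std::sync::Mutex;

// Hypothetical, stripped-down illustration of the "finished on drop" idea.
// The second tuple field counts queued jobs *and* jobs that have been popped
// but whose guard is still alive.
struct Jobs {
    state: Mutex<(Vec<u32>, usize)>,
}

// Wrapper handed out by `pop`; dropping it marks the job as finished.
struct Guard<'a> {
    value: u32,
    jobs: &'a Jobs,
}

impl Jobs {
    fn new(items: Vec<u32>) -> Self {
        let n = items.len();
        Jobs { state: Mutex::new((items, n)) }
    }

    fn push(&self, value: u32) {
        let mut st = self.state.lock().unwrap();
        st.0.push(value);
        st.1 += 1;
    }

    // Non-blocking for brevity; the real JobStack waits on a Condvar instead.
    fn pop(&self) -> Option<Guard<'_>> {
        let value = self.state.lock().unwrap().0.pop()?;
        Some(Guard { value, jobs: self })
    }

    fn unfinished(&self) -> usize {
        self.state.lock().unwrap().1
    }
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        self.jobs.state.lock().unwrap().1 -= 1;
    }
}

fn main() {
    let jobs = Jobs::new(vec![1]);
    let job = jobs.pop().unwrap();
    // The queue is now empty, but the job is still in flight, so a waiting
    // worker must not yet conclude that all work is done.
    assert_eq!(jobs.unfinished(), 1);
    jobs.push(job.value + 1); // a running job may enqueue follow-ups
    assert_eq!(jobs.unfinished(), 2);
    drop(job);
    assert_eq!(jobs.unfinished(), 1);
    drop(jobs.pop()); // pop the follow-up and finish it immediately
    assert_eq!(jobs.unfinished(), 0);
}
```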

If an error occurs in a thread, all threads then (should) finish cleanly as soon as possible. This is accomplished by "shutting down" the job stack, clearing out its set of pending jobs and dropping any new jobs that would be added after that point.
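Here is a minimal sketch of the shutdown bookkeeping, assuming a single thread and hypothetical names (`Stack`, `State`, `take`, `finish`). On shutdown, pending jobs are subtracted from the unfinished count and cleared. Jobs already held by workers stay counted until they are finished, and later pushes are ignored. In the real code, `JobStackIterator::next` also returns `None` as soon as the shutdown flag is set.

```rust
use std::sync::Mutex;

// Hypothetical miniature of the shutdown logic: `unfinished` counts queued
// jobs plus jobs currently held by workers.
#[derive(Default)]
struct State {
    queue: Vec<String>,
    unfinished: usize,
    shutdown: bool,
}

struct Stack {
    state: Mutex<State>,
}

impl Stack {
    fn push(&self, job: &str) {
        let mut st = self.state.lock().unwrap();
        if !st.shutdown {
            st.queue.push(job.to_string());
            st.unfinished += 1;
        } // after shutdown, new jobs are silently dropped
    }

    // Simulates a worker popping a job that it will finish later.
    fn take(&self) -> Option<String> {
        self.state.lock().unwrap().queue.pop()
    }

    // Simulates the worker's drop guard going away.
    fn finish(&self) {
        self.state.lock().unwrap().unfinished -= 1;
    }

    fn shutdown(&self) {
        let mut st = self.state.lock().unwrap();
        if !st.shutdown {
            // Discard pending jobs; only in-flight jobs remain counted.
            let pending = st.queue.len();
            st.unfinished -= pending;
            st.queue.clear();
            st.shutdown = true;
        }
    }

    // Returns (pending jobs, unfinished jobs).
    fn counts(&self) -> (usize, usize) {
        let st = self.state.lock().unwrap();
        (st.queue.len(), st.unfinished)
    }
}

fn main() {
    let stack = Stack { state: Mutex::new(State::default()) };
    stack.push("a");
    stack.push("b");
    stack.push("c");
    let _in_flight = stack.take(); // a worker holds "c"
    assert_eq!(stack.counts(), (2, 3));
    stack.shutdown(); // e.g. another worker hit an error
    assert_eq!(stack.counts(), (0, 1));
    stack.push("d"); // ignored after shutdown
    assert_eq!(stack.counts(), (0, 1));
    stack.finish(); // the in-flight job completes
    assert_eq!(stack.counts(), (0, 0));
}
```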

Besides general code review stuff, the main things I'm wondering about are:

  • Is the error handling safe & clean?

  • Is there a way to make JobStack itself an iterator (instead of having a separate JobStackIterator) while still allowing the items it produces to manipulate JobStack on drop? My attempts at accomplishing this so far have run up against the fact that generic associated types are unstable, so this may not be possible, but there may also be another way to do this that I'm overlooking.

use log::{trace, warn};
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::mpsc::channel;
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

struct JobStack<T> {
    data: Mutex<JobStackData<T>>,
    cond: Condvar,
}

struct JobStackData<T> {
    queue: Vec<T>,
    jobs: usize,
    shutdown: bool,
}

impl<T> JobStack<T> {
    fn new<I: IntoIterator<Item = T>>(items: I) -> Self {
        let queue: Vec<T> = items.into_iter().collect();
        let jobs = queue.len();
        JobStack {
            data: Mutex::new(JobStackData {
                queue,
                jobs,
                shutdown: false,
            }),
            cond: Condvar::new(),
        }
    }

    // We can't impl Extend, as that requires the receiver to be mut
    fn extend<I: IntoIterator<Item = T>>(&self, iter: I) {
        let mut data = self.data.lock().unwrap();
        if !data.shutdown {
            let prelen = data.queue.len();
            data.queue.extend(iter);
            data.jobs += data.queue.len() - prelen;
            trace!("Job count incremented to {}", data.jobs);
            self.cond.notify_all();
        }
    }

    fn shutdown(&self) {
        let mut data = self.data.lock().unwrap();
        if !data.shutdown {
            trace!("Shutting down stack");
            data.jobs -= data.queue.len();
            data.queue.clear();
            data.shutdown = true;
            self.cond.notify_all();
        }
    }

    fn is_shutdown(&self) -> bool {
        self.data.lock().unwrap().shutdown
    }

    fn iter(&self) -> JobStackIterator<'_, T> {
        JobStackIterator { stack: self }
    }
}

struct JobStackIterator<'a, T> {
    stack: &'a JobStack<T>,
}

impl<'a, T> Iterator for JobStackIterator<'a, T> {
    type Item = JobStackItem<'a, T>;

    fn next(&mut self) -> Option<Self::Item> {
        let mut data = self.stack.data.lock().unwrap();
        loop {
            trace!("Looping through JobStackIterator::next");
            if data.jobs == 0 || data.shutdown {
                trace!("[JobStackIterator::next] no jobs; returning None");
                return None;
            }
            match data.queue.pop() {
                Some(value) => {
                    return Some(JobStackItem {
                        value,
                        stack: self.stack,
                    })
                }
                None => {
                    trace!("[JobStackIterator::next] queue is empty; waiting");
                    data = self.stack.cond.wait(data).unwrap();
                }
            }
        }
    }
}

struct JobStackItem<'a, T> {
    value: T,
    stack: &'a JobStack<T>,
}

impl<T> Deref for JobStackItem<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.value
    }
}

impl<T> Drop for JobStackItem<'_, T> {
    fn drop(&mut self) {
        let mut data = self.stack.data.lock().unwrap();
        data.jobs -= 1;
        trace!("Job count decremented to {}", data.jobs);
        if data.jobs == 0 {
            self.stack.cond.notify_all();
        }
    }
}

pub fn fastio_checksum<P: AsRef<Path>>(dirpath: P, threads: usize) -> Result<String, WalkError> {
    let dirpath = dirpath.as_ref();
    let stack = Arc::new(JobStack::new([DirEntry {
        path: dirpath.to_path_buf(),
        is_dir: true,
    }]));
    let (sender, receiver) = channel();
    for i in 0..threads {
        let basepath = dirpath.to_path_buf();
        let stack = Arc::clone(&stack);
        let sender = sender.clone();
        thread::spawn(move || {
            trace!("[{i}] Starting thread");
            for entry in stack.iter() {
                trace!("[{i}] Popped {:?} from stack", *entry);
                let output = if entry.is_dir {
                    match listdir(&entry.path) {
                        Ok(entries) => {
                            stack.extend(
                                entries
                                    .into_iter()
                                    .inspect(|n| trace!("[{i}] Pushing {n:?} onto stack")),
                            );
                            None
                        }
                        Err(e) => Some(Err(e)),
                    }
                } else {
                    Some(FileInfo::for_file(&entry.path, &basepath))
                };
                if let Some(v) = output {
                    // If we've shut down, don't send anything except Errs
                    if v.is_err() || !stack.is_shutdown() {
                        if v.is_err() {
                            stack.shutdown();
                        }
                        trace!("[{i}] Sending {v:?} to output");
                        match sender.send(v) {
                            Ok(_) => (),
                            Err(_) => {
                                warn!("[{i}] Failed to send; exiting");
                                stack.shutdown();
                                return;
                            }
                        }
                    }
                }
            }
            trace!("[{i}] Ending thread");
        });
    }
    drop(sender);
    // Force the receiver to receive everything (rather than breaking out early
    // on an Err) in order to ensure that all threads run to completion
    let mut infos = Vec::new();
    let mut err = None;
    for v in receiver {
        match v {
            Ok(i) => {
                infos.push(i);
            }
            Err(e) => {
                err.get_or_insert(e);
            }
        }
    }
    match err {
        Some(e) => Err(e),
        None => Ok(compile_checksum(infos)),
    }
}

// The following items are support code that aren't part of what I want
// reviewed, so I'm just going to include stubbed-out declarations.

#[derive(Debug)]
pub(crate) struct DirEntry {
    pub(crate) path: PathBuf,
    pub(crate) is_dir: bool,
}

pub(crate) fn listdir<P: AsRef<Path>>(dirpath: P) -> Result<Vec<DirEntry>, WalkError> {
    ...
}

enum WalkError {
    ...
}

pub fn compile_checksum<I: IntoIterator<Item = FileInfo>>(seq: I) -> String {
    ...
}

pub struct FileInfo {
    ...
}

impl FileInfo {
    pub fn for_file<P, Q>(path: P, basepath: Q) -> Result<FileInfo, WalkError>
    where
        P: AsRef<Path>,
        Q: AsRef<Path>,
    { ... }
}
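To check that the stack terminates once the last job is dropped, here is a condensed, self-contained copy of `JobStack` driven by a toy workload instead of the filesystem. Shutdown and logging are omitted, and `pop` stands in for `JobStackIterator::next`. `count_tree_jobs` is a hypothetical test harness, not part of the original code.

```rust
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

// Condensed version of the JobStack above (logging and shutdown left out)
// so the termination logic can be exercised with plain integers.
struct JobStack<T> {
    data: Mutex<(Vec<T>, usize)>, // (pending queue, unfinished job count)
    cond: Condvar,
}

struct JobStackItem<'a, T> {
    value: T,
    stack: &'a JobStack<T>,
}

impl<T> JobStack<T> {
    fn new(items: Vec<T>) -> Self {
        let jobs = items.len();
        JobStack { data: Mutex::new((items, jobs)), cond: Condvar::new() }
    }

    fn extend<I: IntoIterator<Item = T>>(&self, iter: I) {
        let mut data = self.data.lock().unwrap();
        let before = data.0.len();
        data.0.extend(iter);
        let added = data.0.len() - before;
        data.1 += added;
        self.cond.notify_all();
    }

    // Blocks until a job is available or every job has been finished.
    fn pop(&self) -> Option<JobStackItem<'_, T>> {
        let mut data = self.data.lock().unwrap();
        loop {
            if data.1 == 0 {
                return None;
            }
            if let Some(value) = data.0.pop() {
                return Some(JobStackItem { value, stack: self });
            }
            data = self.cond.wait(data).unwrap();
        }
    }
}

impl<T> Deref for JobStackItem<'_, T> {
    type Target = T;
    fn deref(&self) -> &T { &self.value }
}

impl<T> Drop for JobStackItem<'_, T> {
    fn drop(&mut self) {
        let mut data = self.stack.data.lock().unwrap();
        data.1 -= 1;
        if data.1 == 0 {
            // Wake idle workers so they can observe that all work is done.
            self.stack.cond.notify_all();
        }
    }
}

// Hypothetical workload: job `n > 0` enqueues two jobs `n - 1`, so a start
// value of `depth` yields a full binary tree of 2^(depth + 1) - 1 jobs.
fn count_tree_jobs(depth: u32, threads: usize) -> usize {
    let stack = Arc::new(JobStack::new(vec![depth]));
    let processed = Arc::new(AtomicUsize::new(0));
    let handles: Vec<_> = (0..threads)
        .map(|_| {
            let stack = Arc::clone(&stack);
            let processed = Arc::clone(&processed);
            thread::spawn(move || {
                while let Some(job) = stack.pop() {
                    processed.fetch_add(1, Ordering::SeqCst);
                    if *job > 0 {
                        stack.extend([*job - 1, *job - 1]);
                    }
                    // `job` is dropped here, after its children were queued,
                    // so the count never reaches zero while work remains.
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    processed.load(Ordering::SeqCst)
}

fn main() {
    println!("processed {} jobs", count_tree_jobs(3, 4)); // 1 + 2 + 4 + 8 = 15
}
```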

The trick for implementing an iterator that returns references without GATs is to implement the trait for a reference to your type:

impl<'a, T> Iterator for &'a JobStack<T> {
    type Item = JobStackItem<'a, T>;
    // etc... 
}

That's effectively what you were doing anyway; there's just a thin wrapper struct in between.
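Here is a self-contained sketch of that suggestion. It uses a simplified, non-blocking stand-in for `JobStack` with no Condvar or job count. The `finished` counter is hypothetical and only shows that the drop hook still runs for each item.

```rust
use std::ops::Deref;
use std::sync::{Arc, Mutex};

// Simplified, non-blocking stand-in for JobStack: no Condvar and no
// unfinished-job count, just enough to show the shape of the impl.
struct JobStack<T> {
    queue: Mutex<Vec<T>>,
    finished: Mutex<usize>, // hypothetical: number of items dropped so far
}

struct JobStackItem<'a, T> {
    value: T,
    stack: &'a JobStack<T>,
}

impl<T> Deref for JobStackItem<'_, T> {
    type Target = T;
    fn deref(&self) -> &T {
        &self.value
    }
}

impl<T> Drop for JobStackItem<'_, T> {
    fn drop(&mut self) {
        *self.stack.finished.lock().unwrap() += 1;
    }
}

// Iterator is implemented for the reference type. `next` receives
// `&mut &'a JobStack<T>`, copies out the `&'a` reference, and can therefore
// hand out items that borrow the stack for `'a`; no GAT is needed.
impl<'a, T> Iterator for &'a JobStack<T> {
    type Item = JobStackItem<'a, T>;

    fn next(&mut self) -> Option<Self::Item> {
        let stack: &'a JobStack<T> = *self;
        let value = stack.queue.lock().unwrap().pop()?;
        Some(JobStackItem { value, stack })
    }
}

// Pops every job, returning (sum of values, number of items dropped).
fn sum_jobs(values: Vec<u32>) -> (u32, usize) {
    let stack = Arc::new(JobStack {
        queue: Mutex::new(values),
        finished: Mutex::new(0),
    });
    let mut sum = 0;
    // `&*stack` goes through the Arc to a `&JobStack<u32>`, which is the
    // iterator; `for item in stack` would not compile.
    for item in &*stack {
        sum += *item;
    } // each `item` is dropped at the end of its iteration
    let finished = *stack.finished.lock().unwrap();
    (sum, finished)
}

fn main() {
    println!("{:?}", sum_jobs(vec![1, 2, 3])); // (6, 3)
}
```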

Generally, a "collection" type should implement IntoIterator rather than Iterator. But since your iterator consumes values from the collection rather than providing a view into it, I think it does make sense to skip the intermediary type.
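For comparison, here is a sketch of the conventional collection shape. It keeps the separate iterator type but implements `IntoIterator` for `&JobStack<T>`, similar to how `&Vec<T>` works. The stack is a minimal, hypothetical stand-in that yields plain values rather than drop guards.

```rust
use std::sync::Mutex;

// Hypothetical minimal stack: items are popped as plain values here; the
// drop-guard wrapper is left out to keep the focus on the trait wiring.
struct JobStack<T> {
    queue: Mutex<Vec<T>>,
}

struct JobStackIterator<'a, T> {
    stack: &'a JobStack<T>,
}

impl<'a, T> Iterator for JobStackIterator<'a, T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        self.stack.queue.lock().unwrap().pop()
    }
}

// The conventional "collection" shape: `&JobStack` is IntoIterator and
// hands out a separate iterator type.
impl<'a, T> IntoIterator for &'a JobStack<T> {
    type Item = T;
    type IntoIter = JobStackIterator<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        JobStackIterator { stack: self }
    }
}

fn drain(values: Vec<i32>) -> Vec<i32> {
    let stack = JobStack { queue: Mutex::new(values) };
    let mut out = Vec::new();
    for v in &stack {
        out.push(v);
    }
    out
}

fn main() {
    println!("{:?}", drain(vec![1, 2, 3])); // popped in LIFO order: [3, 2, 1]
}
```

This allows `for v in &stack` without calling `iter()`. When the stack sits behind an `Arc`, though, you still need `&*stack`.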


Neat! That works, but then it turns out I have to write for entry in &*stack, which doesn't seem very appealing.

Oh right, because there's a smart pointer involved. You'll need to do something to get from the Arc to a reference. The method call did that for you automatically (method calls auto-deref through the Arc), which is why you didn't need any deref gymnastics the old way.
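Here is one possible workaround, sketched under some assumptions. Keep `Iterator` implemented for `&JobStack<T>` and add a trivial method that returns `&Self`. Because it is a method call, auto-deref looks through the `Arc`, so callers can write `stack.jobs()` instead of `&*stack`. The `jobs` name is hypothetical, and the stack is again a minimal non-blocking stand-in.

```rust
use std::sync::{Arc, Mutex};

// Minimal non-blocking stand-in for JobStack.
struct JobStack<T> {
    queue: Mutex<Vec<T>>,
}

// Iterator implemented on the reference type, as in the reply above.
impl<'a, T> Iterator for &'a JobStack<T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        let stack: &'a JobStack<T> = *self;
        stack.queue.lock().unwrap().pop()
    }
}

impl<T> JobStack<T> {
    // Hypothetical helper: returning `&Self` lets method-call auto-deref
    // see through the Arc, so callers can skip writing `&*stack`.
    fn jobs(&self) -> &Self {
        self
    }
}

// Several threads drain the shared stack and the partial sums are added up.
fn total(values: Vec<u64>, threads: usize) -> u64 {
    let stack = Arc::new(JobStack { queue: Mutex::new(values) });
    let handles: Vec<_> = (0..threads)
        .map(|_| {
            let stack = Arc::clone(&stack);
            std::thread::spawn(move || stack.jobs().sum::<u64>())
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).sum()
}

fn main() {
    println!("{}", total((1..=100).collect(), 4)); // 5050
}
```

With this helper, the worker loop in `fastio_checksum` could read `for entry in stack.jobs()`.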
