Building a thread-safe (lazy) cache with invalidation/reset

For a project, I need a thread-safe way of caching the result of a calculation. The cache can be reset/invalidated when new elements are added, in which case the result needs to be recalculated (lazily) the next time it is accessed.

In pseudo-code, something along the lines of:

struct Validator<T, R> {
  elements: Vec<T>,
  cache: LazyCache<R>,
}

impl<T, R> Validator<T, R> {
  fn add(&self, val: T) {
    self.elements.push(val);
    self.cache.invalidate();  
  }
  
  fn long_calculation(&self) -> R {
    if !self.cache.is_valid() {
        let ret = do_calculation(&self.elements);
        self.cache.store(ret);
    }

    self.cache.get()
  }
}

Of course, a naive implementation of this wouldn't be thread-safe, and if the cache is simply guarded with a Mutex, it wouldn't allow many concurrent reads.
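
For reference, the naive Mutex version would be something along these lines (a sketch; MutexContext is just an illustrative name): correct, but even a cached read has to take the exclusive lock, so readers serialize against each other:

use parking_lot::Mutex;

pub struct MutexContext<T> {
    cache: Mutex<Option<T>>,
}

impl<T: Copy> MutexContext<T> {
    pub fn new() -> Self {
        Self { cache: Mutex::new(None) }
    }

    pub fn reset(&self) {
        *self.cache.lock() = None;
    }

    pub fn validate<F: FnOnce() -> T>(&self, f: F) -> T {
        // The exclusive lock is held even when the value is already cached.
        let mut guard = self.cache.lock();
        *guard.get_or_insert_with(f)
    }
}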

The main issue is that I need exclusive access to the cache during recalculation (it's mutated), but I also need to check its validity before every read, and I don't want to run the calculation multiple times.

Therefore, inspired by double-checked locking, I came up with the following RwLock-based implementation.

use parking_lot::RwLock;

pub struct Context<T> {
    lock: RwLock<Option<T>>,
}

impl<T> Context<T> where T: Copy {
    pub fn new() -> Self {
        Self {
            lock: RwLock::new(None),
        }
    }

    pub fn reset(&self) {
        let mut writer = self.lock.write();
        *writer = None;
    }

    pub fn validate<F>(&self, f: F) -> T
    where
        F: FnOnce() -> T,
    {
        let reader = self.lock.read(); // <-- (1)
        if let Some(value) = &*reader {
            *value
        } else {
            drop(reader); // <-- (2)
            let mut writer = self.lock.write(); // <-- (3)
            if writer.is_some() {
                return *writer.as_ref().unwrap(); // <-- (4)
            }
            *writer = Some(f());
            *writer.as_ref().unwrap()
        }
    }
}

The "pattern" of double checking can be seen after acquiring the write() lock, because multiple threads could see a None value when acquiring the read lock (at (1)), then drop it (at (2)) and wait for the write lock to be acquired (at (3)). One of the (potentially) many threads will acquire it, update the string in the critical section, before dropping the lock again. At which point, another thread will grab the write lock, and sees that there is a value already in the cache and return it directly (at (4)).

Miri believes this is sound, and my preliminary tests also confirmed that.

Is that true? Is there a better way to do this?

Fully working test case can be found in this playground.

Am I reading this correctly that the value inside the lock will only ever be None or Some("done")?

In this example, yes. In reality, the String is replaced by the result of a graph algorithm (detection of cycles in a directed graph).

I don't want to run this on every access, so I only recalculate it when the graph changes; therefore, I need to cache the result.

I don't see anything wrong here -- I think you have correctly handled the race condition. BTW, I believe Miri would not detect an error here, because you have no unsafe code. Miri only detects data races and other UB, and those can occur only with unsafe code.

There are non-blocking approaches such as using arc-swap. But I don't think they give you the semantics you're after. I think you want invalidating threads to block while another thread is replacing the value, so no invalidations are missed.
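
To illustrate, a rough sketch of what an arc-swap based variant could look like (hypothetical, not a drop-in replacement), with the places where the semantics differ marked:

use std::sync::Arc;
use arc_swap::ArcSwapOption;

pub struct LockFreeContext<T> {
    cache: ArcSwapOption<T>,
}

impl<T> LockFreeContext<T> {
    pub fn new() -> Self {
        Self { cache: ArcSwapOption::empty() }
    }

    pub fn reset(&self) {
        // Never blocks, but also never waits for an in-flight recomputation.
        self.cache.store(None);
    }

    pub fn validate<F>(&self, f: F) -> Arc<T>
    where
        F: FnOnce() -> T,
    {
        if let Some(cached) = self.cache.load_full() {
            return cached;
        }
        // Several threads can reach this point and run f() concurrently, and a
        // reset() that lands between load_full() and store() gets overwritten,
        // i.e. that invalidation is silently lost.
        let fresh = Arc::new(f());
        self.cache.store(Some(fresh.clone()));
        fresh
    }
}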

How about something like this?

use std::sync::OnceLock;
use parking_lot::RwLock;

pub struct Context {
    lock: RwLock<OnceLock<String>>,
}

impl Context {
    pub fn new() -> Self {
        Self {
            lock: RwLock::new(OnceLock::new()),
        }
    }

    pub fn reset(&self) {
        let pre = std::time::Instant::now();
        // Exclusive lock, held only for the brief moment it takes to clear the cell.
        let _ = self.lock.write().take();
        eprintln!(
            "{:?}: reset {:?}",
            std::thread::current().id(),
            pre.elapsed()
        );
    }

    pub fn validate<F>(&self, f: F) -> bool
    where
        F: FnOnce() -> String,
    {
        let pre = std::time::Instant::now();
        // A shared lock is enough: OnceLock synchronizes initialization itself,
        // so concurrent readers can call get_or_init and only one closure runs.
        let result = self.lock.read().get_or_init(f) == "done";

        eprintln!(
            "{:?}: validate {:?}",
            std::thread::current().id(),
            pre.elapsed()
        );
        result
    }
}

That is such a great use of OnceLock.

Yes, I was looking into arc-swap as well, however, as you noted correctly, it doesn't provide the necessary semantics that I was after.

I couldn't find this pattern anywhere. Is what I'm doing here so out of the ordinary?

Very elegant solution!

I ran a very simple benchmark, timing the average time a reset takes, and it turns out the OnceLock solution is ~200 ns faster on average.

Benchmark code below, run with cargo run --release on an Intel(R) Xeon(R) Platinum 8488C.

Code
#![allow(unused, clippy::new_without_default)]
use std::borrow::BorrowMut;
use std::cell::RefCell;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock};
use std::thread::sleep;
use std::time::Duration;

use parking_lot::{Mutex, RwLock};
use simple_moving_average::{SingleSumSMA, SMA};

thread_local! {
    pub static SUM_RESET: RefCell<SingleSumSMA::<Duration, u32, 10>> = RefCell::new(SingleSumSMA::from_zero(Duration::ZERO));
}

pub struct Context {
    lock: RwLock<Option<String>>,
    lock1: RwLock<OnceLock<String>>,
}

impl Context {
    pub fn new() -> Self {
        Self {
            lock: RwLock::new(None),
            lock1: RwLock::new(OnceLock::new()),
        }
    }

    pub fn reset1(&self) {
        let pre = std::time::Instant::now();
        let _ = self.lock1.write().take();
        SUM_RESET.with_borrow_mut(move |sum| sum.add_sample(pre.elapsed()));
        eprintln!(
            "{:?}: reset {:?}",
            std::thread::current().id(),
            pre.elapsed()
        );
    }

    pub fn validate1<F>(&self, f: F) -> bool
    where
        F: FnOnce() -> String,
    {
        let result = (self.lock1.read().get_or_init(f) == "done");
        result
    }

    pub fn reset(&self) {
        let pre = std::time::Instant::now();
        let mut writer = self.lock.write();
        *writer = None;
        drop(writer);
        SUM_RESET.with_borrow_mut(move |sum| sum.add_sample(pre.elapsed()));
        eprintln!(
            "{:?}: reset {:?}",
            std::thread::current().id(),
            pre.elapsed()
        );
    }

    pub fn validate<F>(&self, f: F) -> bool
    where
        F: FnOnce() -> String,
    {
        // let pre = std::time::Instant::now();
        let reader = self.lock.read();
        if let Some(value) = &*reader {
            // println!(
            //     "{:?}: cached {:?}",
            //     std::thread::current().id(),
            //     pre.elapsed()
            // );
            value == "done"
        } else {
            drop(reader);
            let pre = std::time::Instant::now();
            let mut writer = self.lock.write();
            if writer.is_some() {
                return *writer.as_ref().unwrap() == "done";
            }
            eprintln!(
                "{:?}: init {:?}",
                std::thread::current().id(),
                pre.elapsed()
            );
            *writer = Some(f());
            writer.as_ref().unwrap() == "done"
        }
    }
}

fn main() {
    // The spawned threads below exercise the OnceLock-based variant
    // (validate1/reset1); swap in validate/reset to measure the other one.
    let context = Arc::new(Context::new());
    let terminate = Arc::new(AtomicBool::new(false));
    let sum_all = Arc::new(Mutex::new(SingleSumSMA::<Duration, u32, 10>::from_zero(
        Duration::ZERO,
    )));

    let handle0 = {
        let context = Arc::clone(&context);
        let terminate = Arc::clone(&terminate);
        let sum_all = Arc::clone(&sum_all);
        std::thread::spawn(move || {
            let mut count = 0_u32;
            while !terminate.load(Ordering::Relaxed) {
                let v = context.validate1(|| String::from("done"));
                assert!(v);
                count += 1;
                if count == 3_000_000 {
                    count = 0;
                    context.reset1();
                }
            }

            {
                SUM_RESET.with_borrow(move |sum| {
                    let mut lock = sum_all.lock();
                    lock.add_sample(sum.get_average());
                })
            }
        })
    };

    let handle1 = {
        let context = Arc::clone(&context);
        let terminate = Arc::clone(&terminate);
        let sum_all = Arc::clone(&sum_all);
        std::thread::spawn(move || {
            // context.reset1();
            let mut count = 0_u32;
            while !terminate.load(Ordering::Relaxed) {
                let v = context.validate1(|| String::from("done"));
                assert!(v);
                count += 1;
                if count == 1_000_000 {
                    // if count == 10_000 {
                    count = 0;
                    context.reset1();
                }
            }

            {
                SUM_RESET.with_borrow(move |sum| {
                    let mut lock = sum_all.lock();
                    lock.add_sample(sum.get_average());
                })
            }
        })
    };

    let handle2 = {
        let context = Arc::clone(&context);
        let terminate = Arc::clone(&terminate);
        let sum_all = Arc::clone(&sum_all);
        std::thread::spawn(move || {
            let mut count = 0_u32;
            while !terminate.load(Ordering::Relaxed) {
                let v = context.validate1(|| String::from("done"));
                assert!(v);
                count += 1;
                if count == 5_000_000 {
                    count = 0;
                    context.reset1();
                }
            }

            {
                SUM_RESET.with_borrow(move |sum| {
                    let mut lock = sum_all.lock();
                    lock.add_sample(sum.get_average());
                })
            }
        })
    };

    sleep(Duration::from_secs(10));
    terminate.store(true, Ordering::Relaxed);
    handle0.join().unwrap();
    handle1.join().unwrap();
    handle2.join().unwrap();

    println!("average reset took: {:?}", sum_all.lock().get_average());
}

When you get around to integrating this solution into your main program, strongly consider protecting the element list with the same RwLock as the cache to ensure consistency. Going back to your initial pseudocode, you could translate it like this:

// Can be used as any of:
//   - Owned `Validator<…>` for single-threaded applications
//   - `Arc<RwLock<Validator<…>>>` for multi-threaded applications
//   - `static VALIDATOR: RwLock<Validator<…>>` as a global cache
struct Validator<T, R> {
  elements: Vec<T>,
  cache: OnceLock<R>,
}

impl<T, R> Validator<T, R> {
  const fn new() -> Self {
    Validator { elements: Vec::new(), cache: OnceLock::new() }
  }

  pub fn add(&mut self, val: T) {
    self.elements.push(val);
    // Invalidate the cache whenever the element list changes.
    let _ = self.cache.take();
  }

  pub fn long_calculation(&self) -> &R {
    self.cache.get_or_init(|| do_calculation(&self.elements))
  }
}
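
Used behind a single lock, it could look like this (a usage sketch with made-up element and result types, using std's RwLock):

use std::sync::{Arc, RwLock};

fn example(validator: Arc<RwLock<Validator<u32, usize>>>) {
    // Writers take the exclusive lock: pushing an element and clearing the
    // cache happen in one critical section.
    validator.write().unwrap().add(17);

    // Readers share the lock; the first read after an invalidation pays for
    // the recalculation, later reads get the cached value.
    let guard = validator.read().unwrap();
    let _result = guard.long_calculation();
}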

I do that, sort of. The code is open source on GitHub.

The Validator is the cycle detection mechanism for a dependency injection framework I'm working on.

The validate function is here.

The list of elements is a list of visitors; each one visits a dependency once, registering it as a node in a (directed) graph and adding all of its dependencies as edges.

When recalculating the cached result, I grab a (non-exclusive) lock on the visitors here, which means no new visitors can be added while all dependencies are being visited.

When adding new dependencies, the cache is reset and a new visitor is added (e.g., here).

The potential issue here is that there’s a moment between these two blocks where neither lock is held and context contains a value inconsistent with visitors. If the new visitors create a cycle and another thread checks at that moment, it can receive a false negative which was cached before the new visitors were added.

I believe this can be fixed by holding the visitor write lock until after the context reset is complete, but you always need to be careful of deadlocks if you’re holding multiple locks at the same time. Guarding both with the same lock is less flexible but easier to reason about.
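
Sketched with invented names (your real types will differ), holding the visitors guard across the reset would look roughly like this:

use parking_lot::RwLock;

struct Visitor;

struct Injector {
    visitors: RwLock<Vec<Visitor>>,
    context: Context<bool>, // the cache type from earlier in the thread
}

impl Injector {
    fn add_visitor(&self, visitor: Visitor) {
        let mut visitors = self.visitors.write();
        visitors.push(visitor);
        // Invalidate while the visitors lock is still held, so no thread can
        // observe the new visitor together with a stale cached result.
        self.context.reset();
        drop(visitors);
    }
}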

That's a good point. Fortunately, it can be fixed, although not easily. :wink:

Thank you!

Actually, it's surprisingly difficult to move everything under a single lock, because we borrow the visitor list during visitation and also need a mutable context. I'd therefore need an Option<(Visitors, Context)> that I .take() during visitation and put back before relinquishing the lock.

Not sure that makes things easier. Unfortunately, holding the visitor lock longer also deadlocks.

Back to the drawing board.
