Memoization in Rust?

I'm trying to pass a mutable hashmap to a function which calls itself recursively. Any help getting this working would be much appreciated, or if you could just show an example of memoization in Rust, that would help. Thanks!

use std::collections::HashMap;

fn count_collatz(n: usize, h: &mut HashMap<usize, usize>) -> usize {
    if n == 1 {
        return 1;
    }
    match h.get(&n) {
        Some(val) => return *val,
        None => {
            let val = 1 + match n % 2 {
                0 => count_collatz(n / 2, h),
                _ => count_collatz(n * 3 + 1, h)
            };
            h.insert(n, val);
            return val
        }
    }
}

fn main() {
    let mut largest = 0;
    let mut largest_seq = 0;
    let mut hash: HashMap<usize, usize> = HashMap::new();
    for i in 2..1000001 {
        let seq_length = count_collatz(i, &mut hash);
        if seq_length > largest_seq {
            largest_seq = seq_length;
            largest = i;
        }
    }
    println!("{}", largest);
}

Hmm. This is a case where you have to spend a little effort to get around the borrow rules. Normally, if you see a pattern like this:

match h.get(&k) {
    Some(val) => *val,
    None => {
        let val = ...;
        h.insert(k, val);
        val
    }
}

The answer is to use the entry API to get an Entry, which represents a place in the hash table where an item could be, and lets you call or_insert or or_insert_with to insert a value for a missing key. This is quite convenient in many cases where you see this pattern. For instance, if you want to initialize a counter with 0 and increment it every time, you would write:

*h.entry(k).or_insert(0) += 1;
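
As a slightly fuller illustration of that one-liner, here is a minimal sketch of the counting pattern. The function name count_words and the sample input are made up for the example; only entry and or_insert come from the standard library.

```rust
use std::collections::HashMap;

// Hypothetical helper: count how often each whitespace-separated word occurs.
fn count_words(text: &str) -> HashMap<&str, usize> {
    let mut counts = HashMap::new();
    for word in text.split_whitespace() {
        // Insert 0 if the word is missing, then increment in place.
        *counts.entry(word).or_insert(0) += 1;
    }
    counts
}

fn main() {
    let counts = count_words("a b a c a b");
    println!("{:?}", counts.get("a"));
}
```

This works because the default value (0) does not need to touch the map while the Entry is held.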

However, in this case that doesn't help much: you would need or_insert_with to compute the value to insert, but that closure would need to re-borrow the hashmap, which it can't do because the Entry already holds a mutable borrow of it:

*h.entry(n).or_insert_with(|| {
    1 + match n % 2 {
        0 => count_collatz(n / 2, h),
        _ => count_collatz(n * 3 + 1, h),
    }
})

The above fails with the error "error: closure requires unique access to h but *h is already borrowed."

So I think that to do this you're going to need to break apart your borrows into a few distinct pieces: one for getting the value out if it is present, one for doing the recursive computation, and one for adding the result of the computation to the hashmap. Luckily, if let makes this not too painful: the borrow taken when getting the value only lasts within the if let, so afterwards you can use the reference again to recursively compute and insert the new value.

fn count_collatz(n: usize, h: &mut HashMap<usize, usize>) -> usize {
    if n == 1 {
        return 1;
    }
    if let Some(val) = h.get(&n) {
        return *val;
    }
    let val = 1 + match n % 2 {
        0 => count_collatz(n / 2, h),
        _ => count_collatz(n * 3 + 1, h)
    };
    h.insert(n, val);
    val
}

Thanks so much for the explanation, Lambda.

Just a random additional (probably unhelpful) comment: if you are going to do Collatz conjecture stuff, you might avoid recursion for full credit. ;) You may end up blowing the stack, and if you disprove the conjecture, you might end up in an infinite loop (just before blowing the stack). Maybe consider managing your own stack (unless that totally isn't the point of the question)?
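
One way the "manage your own stack" suggestion could look is sketched below. This is only an illustration, not the code from the thread: the names count_collatz_iter and longest_under are invented here, u64 is used instead of usize on the assumption that intermediate values should not depend on the platform's pointer width, and the sketch assumes n >= 1 and that no intermediate value overflows u64.

```rust
use std::collections::HashMap;

// Sequence length for n (counting n itself and the final 1), without recursion.
// Assumes n >= 1 and that no intermediate value overflows u64.
fn count_collatz_iter(n: u64, cache: &mut HashMap<u64, u64>) -> u64 {
    // Walk forward until we hit 1 or a cached value, remembering
    // every uncached value we pass through on an explicit stack.
    let mut path = Vec::new();
    let mut cur = n;
    let mut len = loop {
        if cur == 1 {
            break 1;
        }
        if let Some(&known) = cache.get(&cur) {
            break known;
        }
        path.push(cur);
        cur = if cur % 2 == 0 { cur / 2 } else { cur * 3 + 1 };
    };
    // Unwind the stack: each value is one step longer than its successor.
    while let Some(v) = path.pop() {
        len += 1;
        cache.insert(v, len);
    }
    len
}

// Hypothetical driver: the start below `limit` with the longest sequence.
fn longest_under(limit: u64) -> (u64, u64) {
    let mut cache = HashMap::new();
    let (mut best, mut best_len) = (1, 1);
    for i in 2..limit {
        let len = count_collatz_iter(i, &mut cache);
        if len > best_len {
            best = i;
            best_len = len;
        }
    }
    (best, best_len)
}

fn main() {
    let (start, len) = longest_under(1_000_001);
    println!("{} (length {})", start, len);
}
```

The Vec plays the role of the call stack, so the depth is limited by heap memory rather than by the thread's stack size. It would still loop forever on a counterexample to the conjecture.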

Thanks for your thoughts, frankmcsherry. In this case it was just a convenient example for learning how to do memoization in Rust. Interestingly, for this example the code without memoization is actually faster than the code with it.
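
For anyone who wants to check that observation on their own machine, a rough timing sketch might look like the following. The function names and the limit of 100_000 are assumptions for the example, not taken from the thread, and the numbers will vary with hardware and with debug versus release builds, so this makes no claim about which version wins.

```rust
use std::collections::HashMap;
use std::time::Instant;

// Plain loop with no cache. Assumes n >= 1 and no u64 overflow.
fn collatz_len_plain(mut n: u64) -> u64 {
    let mut len = 1;
    while n != 1 {
        n = if n % 2 == 0 { n / 2 } else { n * 3 + 1 };
        len += 1;
    }
    len
}

// Memoized version, using the same if let shape as the answer above.
fn collatz_len_memo(n: u64, h: &mut HashMap<u64, u64>) -> u64 {
    if n == 1 {
        return 1;
    }
    if let Some(val) = h.get(&n) {
        return *val;
    }
    let next = if n % 2 == 0 { n / 2 } else { n * 3 + 1 };
    let val = 1 + collatz_len_memo(next, h);
    h.insert(n, val);
    val
}

fn main() {
    let limit = 100_000u64; // hypothetical limit, smaller than the thread's range

    let start = Instant::now();
    let plain: u64 = (1..limit).map(collatz_len_plain).sum();
    let plain_time = start.elapsed();

    let start = Instant::now();
    let mut h = HashMap::new();
    let memo: u64 = (1..limit).map(|n| collatz_len_memo(n, &mut h)).sum();
    let memo_time = start.elapsed();

    // Both approaches must agree before the timings mean anything.
    assert_eq!(plain, memo);
    println!("plain: {:?}, memoized: {:?}", plain_time, memo_time);
}
```

Summing the lengths is just a cheap way to make sure both loops do real work and produce the same result.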