Memoizing fixed-point combinator

A little coding challenge: write a generic memoizing fixed-point combinator, then use it to compute the Fibonacci numbers. For reference, here it is in Python:

def fix(m, f):
    """Return a memoizing fixed point of ``f`` that caches results in ``m``.

    ``f`` is invoked as ``f(rec, x)``, where ``rec`` is the returned memoized
    function itself, so recursive calls made by ``f`` go through the cache.
    ``m`` is mutated in place: entries already present act as base cases, and
    every newly computed value is stored into it.
    """
    def memoized(x):
        # Cache hit: answer straight from the mapping.
        if x in m:
            return m[x]
        # Cache miss: compute via f (which may recurse through `memoized`), store, then read back.
        m[x] = f(memoized, x)
        return m[x]

    return memoized

# Seed fib(0) = fib(1) = 1; every other value recurses through the memoized function.
fib = fix({0: 1, 1: 1}, lambda rec, n: rec(n - 1) + rec(n - 2))
a = list(map(fib, range(10)))
print(a)
Spoiler: Ugly solution
use std::collections::HashMap;
use std::hash::Hash;
use std::cell::RefCell;

/// Builds a `HashMap` from `key => value` pairs, e.g. `hashmap!{0 => 1, 1 => 1}`.
///
/// A trailing comma is accepted. The map type is spelled with its absolute path, so the
/// macro works even where `HashMap` has not been imported at the call site.
macro_rules! hashmap {
    ( $( $key:expr => $value:expr ),* $(,)? ) => {{
        // `mut` is unused when the macro is invoked with no pairs.
        #[allow(unused_mut)]
        let mut map = ::std::collections::HashMap::new();
        $( map.insert($key, $value); )*
        map
    }};
}

/// A recursive function `f` paired with the table of results it has already produced.
///
/// `f` receives the `Fun` itself as its first argument, so recursive calls can go through
/// [`Fun::call`] and reuse the cache. The table sits in a `RefCell` because `call` only
/// has `&self` but still needs to record new results.
struct Fun<X, Y> {
    /// Memo table: input -> previously computed output.
    m: RefCell<HashMap<X, Y>>,
    /// The open-recursive body; its first argument is the memoized function itself.
    f: Box<dyn Fn(&Fun<X, Y>, X) -> Y>,
}

impl<X, Y> Fun<X, Y>
where
    X: Clone + Eq + Hash,
    Y: Clone,
{
    /// Returns the value for `x`: from the cache if present, otherwise computed with `f`
    /// and stored before being returned.
    ///
    /// No `RefCell` borrow is held while `f` runs, so `f` may call back into `call`.
    /// A definition where `x` depends on itself recurses without end.
    fn call(&self, x: X) -> Y {
        // The shared borrow ends with this statement, before `f` is invoked, so nested
        // `call`s made by `f` can borrow the table again.
        let cached = self.m.borrow().get(&x).cloned();
        if let Some(value) = cached {
            return value;
        }
        let value = (self.f)(self, x.clone());
        self.m.borrow_mut().insert(x, value.clone());
        value
    }
}

/// Returns a memoizing fixed point of `f`, seeded with the entries already in `m`.
///
/// `f` is called as `f(rec, x)`, where `rec` is the memoized function itself; recursive
/// calls should go through `rec.call(..)` so they share the cache. Keys only need to be
/// `Clone` (not `Copy`), matching what [`Fun::call`] requires.
fn fix<X: Clone + Eq + Hash, Y: Clone>(
    m: HashMap<X, Y>,
    f: impl Fn(&Fun<X, Y>, X) -> Y + 'static,
) -> impl Fn(X) -> Y {
    let memoized = Fun {
        m: RefCell::new(m),
        f: Box::new(f),
    };
    move |x| memoized.call(x)
}

/// Prints the first ten Fibonacci numbers (with fib(0) = fib(1) = 1) via `fix`.
fn main() {
    let seeds = hashmap!{0 => 1, 1 => 1};
    let fib = fix(seeds, |rec, n| rec.call(n - 1) + rec.call(n - 2));
    let values = (0..10).map(fib).collect::<Vec<u32>>();
    println!("{:?}", values);
}

Here’s what I came up with. Letting the function refer to its own memoized version proved trickier than I originally anticipated.

Edit: my cyclical-recursion check isn’t quite right; it can fire spuriously when two threads concurrently ask for the same value.

use std::{
    collections::{hash_map::Entry, HashMap},
    hash::Hash,
    sync::{Mutex, Once},
};

/// Wraps `f` so that each input's result is computed once and then served from a cache.
///
/// Results live in a `Mutex`-protected map; the lock is never held while `f` runs. While a
/// key is being computed, the map records which threads are working on it:
///
/// * if the *same* thread requests that key again before finishing, `f` depends on its own
///   result and the call panics with `"Cyclical recursion!"`;
/// * if a *different* thread requests it, that thread computes the value as well instead of
///   waiting (waiting could deadlock when two threads need each other's keys). Each finishing
///   thread stores its result, so the cache ends up holding the last one written.
///
/// If `f` panics, the key stays marked as in progress for that thread, so a later request
/// from the same thread is reported as a cycle.
///
/// # Panics
///
/// Panics with `"Cyclical recursion!"` when computing `i` (transitively) requests `i` again
/// on the same thread, and if the memo mutex is poisoned.
fn memoize<I: Clone + Hash + Eq, O: Clone, F: Fn(I) -> O>(f: F) -> impl Fn(I) -> O {
    /// Cache state for a single key.
    enum Slot<T> {
        /// The value has been computed.
        Ready(T),
        /// The value is being computed by these threads.
        Pending(Vec<std::thread::ThreadId>),
    }

    let memo = Mutex::new(HashMap::<I, Slot<O>>::new());
    move |i: I| {
        let me = std::thread::current().id();
        // Decide under the lock (the guard is dropped at the end of this statement), but
        // panic only afterwards so a detected cycle does not poison the mutex.
        let cyclic = match memo.lock().unwrap().entry(i.clone()) {
            Entry::Occupied(mut e) => match e.get_mut() {
                Slot::Ready(value) => return value.clone(),
                Slot::Pending(threads) if threads.contains(&me) => true,
                Slot::Pending(threads) => {
                    threads.push(me);
                    false
                }
            },
            Entry::Vacant(e) => {
                e.insert(Slot::Pending(vec![me]));
                false
            }
        };
        if cyclic {
            panic!("Cyclical recursion!");
        }
        let result = f(i.clone());
        memo.lock().unwrap().insert(i, Slot::Ready(result.clone()));
        result
    }
}

/// Defines a function whose recursive calls are memoized through `memoize`.
///
/// Inside the body, `$name` is rebound to the memoized closure, so a recursive call such as
/// `fib(x - 1)` consults the shared cache instead of re-entering the outer function. The
/// memoized closure is created lazily on the first call and kept in a function-local
/// `OnceLock`, which needs no `static mut` or `unsafe`. Requiring it to be `Send + Sync`
/// makes the compiler reject bodies whose cache could not safely be shared between the
/// threads that may call the generated function.
macro_rules! memoized_fn_rec {
    ($viz:vis $name:ident($arg:ident : $ty_in:ty) -> $ty_out:ty $body:block) => {
        $viz fn $name($arg: $ty_in) -> $ty_out {
            static MEMOIZED: ::std::sync::OnceLock<
                Box<dyn Fn($ty_in) -> $ty_out + Send + Sync>,
            > = ::std::sync::OnceLock::new();

            let memoized = MEMOIZED.get_or_init(|| {
                Box::new(memoize(|$arg: $ty_in| -> $ty_out {
                    // This closure only runs after `get_or_init` has stored it, so `get`
                    // always finds the initialised value.
                    let $name = &**MEMOIZED
                        .get()
                        .expect("memoized closure is initialised before it runs");
                    $body
                }))
            });
            memoized($arg)
        }
    };
}

// Fibonacci with fib(0) = fib(1) = 1; inner `fib` calls hit the memo table.
memoized_fn_rec! {
    fib(n: usize) -> usize {
        if n < 2 {
            1
        } else {
            fib(n - 1) + fib(n - 2)
        }
    }
}

fn main() {
    for x in 0..10 {
        println!("{}", fib(x));
    }
}

(Playground)

1 Like