How to calculate prime sum and avoid stack overflow

I am trying to achieve the problem here: efficient way to sum all the primes below 10^9

According to the formula in the reference, list the following recursive algorithms:

use num::{integer::Roots, BigUint};
use primal::is_prime;
use std::collections::BTreeMap;

fn prime_sum_s(i: u64, j: u64) -> BigUint {
    fn memoization(v: u64, p: u64, m: &mut BTreeMap<(u64, u64), BigUint>) -> BigUint {
        if v == 1 {
            BigUint::from(0u64)
        }
        else if v == 2 {
            BigUint::from(2u64)
        }
        else if p == 1 {
            BigUint::from((2 + v) * (v - 1) / 2)
        }
        else if p.pow(2) <= v && is_prime(p) {
            match m.get(&(v, p)) {
                Some(s) => s.clone(),
                None => {
                    let a = memoization(v, p - 1, m);
                    let b = memoization(v / p, p - 1, m);
                    let c = memoization(p - 1, p - 1, m);
                    let result = a - (b - c)*p;
                    m.insert((v, p), result.clone());
                    result
                }
            }
        }
        else {
            match m.get(&(v, p)) {
                Some(s) => s.clone(),
                None => {
                    let result = memoization(v, p - 1, m);
                    m.insert((v, p), result.clone());
                    result
                }
            }
        }
    }
    memoization(i, j, &mut BTreeMap::new())
}

#[test]
fn test() {
    println!("{}", prime_sum_s(1_000_000_000, 1_000_000_000.sqrt()));
}

It performs well when calculating one million, but it has overflowed its stack for larger inputs.

It seems that memoization is necessary for this problem, how do I write a memoization non-recursive algorithm?

Recursion can be refactored to use a stack data structure (Vec push/pop)

fn prime_sum_s2(i: u64, j: u64) -> BigUint {
  // The vast majority are clones of others so sticking them into a rc to avoid the allocations
    let mut m: BTreeMap<(u64, u64), Rc<BigUint>> = BTreeMap::new();
    let mut to_do = vec![(i, j)];
   
    while let Some((v, p)) = to_do.pop() {
        let result = if v == 1 {
            Rc::new(BigUint::from(0u64))
        } else if v == 2 {
            Rc::new(BigUint::from(2u64))
        } else if p == 1 {
            Rc::new(BigUint::from((2 + v) * (v - 1) / 2))
        } else if m.contains_key(&(v, p)) {
            continue;
        } else if p.pow(2) <= v && is_prime(p) {
            let a = if let Some(a) = m.get(&(v, p - 1)) {
                a.as_ref()
            } else {
                to_do.push((v, p));
                to_do.push((v, p - 1));
                continue;
            };

            let b = if let Some(b) = m.get(&(v / p, p - 1)) {
                b.as_ref()
            } else {
                to_do.push((v, p));
                to_do.push((v / p, p - 1));
                continue;
            };

            let c = if let Some(c) = m.get(&(p - 1, p - 1)) {
                c.as_ref()
            } else {
                to_do.push((v, p));
                to_do.push((p - 1, p - 1));
                continue;
            };
            Rc::new(a - (b - c) * p)
        } else {
            if let Some(a) = m.get(&(v, p - 1)) {
                Rc::clone(&a)
            } else {
                to_do.push((v, p));
                to_do.push((v, p - 1));
                continue;
            }
        };

        m.insert((v, p), result);
    }

    m.get(&(i, j)).unwrap().as_ref().clone()
}
2 Likes

Also there's a much faster form of the algorithm (also without recursion) linked in the stack overflow. Here's that ported to Rust

fn sum_of_primes_under(n: u64) -> u64 {
    let r = n.sqrt() as u64;
    assert!(r * r <= n && (r + 1).pow(2) > n);
    let mut v: Vec<_> = (1..r + 1).map(|i| n / i).collect();
    v.extend((0..*v.last().unwrap()).rev());
    let mut s: BTreeMap<u64, u64> = v
        .iter()
        .copied()
        .map(|i| (i, i * (i + 1) / 2 - 1))
        .collect();
    for p in 2..r {
        if s[&p] > s[&(p - 1)] {
            // p is prime
            let sp = s[&(p - 1)];
            let p2 = p * p;
            for &ve in &v {
                if ve < p2 {
                    break;
                }
                *s.get_mut(&ve).unwrap() -= p * (s[&(ve / p)] - sp);
            }
        }
    }
    s[&n]
}
2 Likes

Thank you for your help, this is a great improvement!

There is a small problem with the boundary conditions, which caused an error when calculating 1000:

Fixed version:

pub fn sum_of_primes_under(n: u64) -> u64 {
    let r = n.sqrt() as u64;
    assert!(r * r <= n && (r + 1).pow(2) > n);
    let mut v: Vec<_> = (1..=r).map(|i| n / i).collect();
    v.extend((0..*v.last().unwrap()).rev());
    let mut s: BTreeMap<u64, u64> = v.iter().copied().map(|i| (i, ((i + 1) * i / 2).wrapping_sub(1))).collect();
    for p in 2..=r {
        if s[&p] > s[&(p - 1)] {
            // p is prime
            let sp = s[&(p - 1)];
            let p2 = p * p;
            for &ve in &v {
                if ve < p2 {break}
                *s.get_mut(&ve).unwrap() -= p * (s[&(ve / p)] - sp);
            }
        }
    }
    return s[&n];
}
1 Like