Fold version slower than for version

The two functions below do the same thing, but the for version is significantly faster than the fold version.

I know this isn't a proper benchmark, but I ran the program several times and the difference was always significant, especially with larger values of limit.
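For anyone who wants slightly steadier numbers without a full benchmark framework, here is a minimal sketch of a timing helper. The name `best_of` and the summing workload are made up for illustration and are not part of the original program. It repeats a closure several times, keeps the fastest run, and uses `std::hint::black_box` so the compiler is less likely to optimize the measured work away.

```rust
use std::hint::black_box;
use std::time::{Duration, Instant};

/// Hypothetical helper: runs `f` `runs` times and returns the fastest
/// observed duration together with the result of the last run.
fn best_of<T>(runs: u32, mut f: impl FnMut() -> T) -> (Duration, T) {
    assert!(runs > 0, "need at least one run");
    let mut best = Duration::MAX;
    let mut last = None;
    for _ in 0..runs {
        let start = Instant::now();
        // black_box makes the result opaque to the optimizer.
        let out = black_box(f());
        best = best.min(start.elapsed());
        last = Some(out);
    }
    (best, last.expect("runs > 0 was checked above"))
}

fn main() {
    // Placeholder workload; swap in the function under test.
    let limit = black_box(1_000_000u64);
    let (time, sum) = best_of(5, || (1..=limit).sum::<u64>());
    println!("sum = {sum}, best of 5 = {time:?}");
}
```

This is still far from a rigorous benchmark (no warm-up control, no statistics), but taking the best of several runs reduces noise from whatever else the machine is doing.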

I ran the code in the playground on stable, beta, and nightly (all in release mode), and the results were the same.

I think it has to do with recursion, since my other small programs that combine fold with recursion show the same slowdown compared to their for counterparts. Those without recursion don't show this behavior.

fn count_of_50_smooths_for(limit: u64, cur_smooth: u64, primes: &[u64]) -> u64 {
    let mut acc = 0;
    
    for (i, &p) in primes.iter().enumerate().take_while(|(_, &p)| p <= limit / cur_smooth) {
        acc += 1 + count_of_50_smooths_for(limit, cur_smooth * p, &primes[i..]);
    }

    acc
}
fn count_of_50_smooths_fold(limit: u64, cur_smooth: u64, primes: &[u64]) -> u64 {
    primes.iter().enumerate().take_while(|(_, &p)| p <= limit / cur_smooth).fold(0, |acc, (i, &p)| {
        acc + 1 + count_of_50_smooths_fold(limit, cur_smooth * p, &primes[i..])
    })
}
fn main() {
    let timer = std::time::Instant::now();
    
    println!("using for loop:");
    println!("    result:  {}", count_of_50_smooths_for(2u64.pow(50), 1, &[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]));//, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]));
    println!("    time:    {:?}", timer.elapsed());

    let timer = std::time::Instant::now();
    
    println!("using fold:");
    println!("    result:  {}", count_of_50_smooths_fold(2u64.pow(50), 1, &[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]));//, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]));
    println!("    time:    {:?}", timer.elapsed());
}


probably the optimizer gave up because the iterator type is too complex. the assembly output shows that the TakeWhile<I, P> wrapper was not inlined.
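For context on why the two versions can compile differently at all: a for loop drives the iterator from the outside through repeated `next()` calls, while `fold` hands the whole loop to the adapter chain. The sketch below writes the for version in roughly desugared form. The function name `count_smooths_desugared` is made up for illustration, and this only shows the shape of the code the compiler sees, not why the inlining fails.

```rust
/// Same counting logic as the `for` version, with the loop written out as
/// explicit `next()` calls (roughly what a `for` loop desugars to).
fn count_smooths_desugared(limit: u64, cur_smooth: u64, primes: &[u64]) -> u64 {
    let mut acc = 0;
    let mut iter = primes
        .iter()
        .enumerate()
        .take_while(|&(_, &p)| p <= limit / cur_smooth);

    // The caller pulls one (index, prime) pair at a time.
    while let Some((i, &p)) = iter.next() {
        acc += 1 + count_smooths_desugared(limit, cur_smooth * p, &primes[i..]);
    }
    acc
}

fn main() {
    // Small input so the result is easy to check by hand.
    let primes = [2, 3, 5, 7];
    println!("{}", count_smooths_desugared(10, 1, &primes));
}
```

With `fold`, the recursive call sits inside a closure that is passed down through `TakeWhile` and `Enumerate`, so getting back to a plain loop depends on every layer being inlined.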

however, if I change the iterator like this, it gives similar runtimes in both cases:

 fn count_of_50_smooths_fold(limit: u64, cur_smooth: u64, primes: &[u64]) -> u64 {
-    primes.iter().enumerate().take_while(|(_, &p)| p <= limit / cur_smooth).fold(0, |acc, (i, &p)| {
+    primes.iter().take_while(|p| **p <= limit / cur_smooth).enumerate().fold(0, |acc, (i, &p)| {
         acc + 1 + count_of_50_smooths_fold(limit, cur_smooth * p, &primes[i..])
     })
 }
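The reordering is safe here because take_while only ever keeps a prefix of the slice, so positions counted after it still match positions in the original slice. A small sketch to check that, with helper names invented for illustration:

```rust
/// Pairs as the original fold sees them: enumerate first, then take_while.
fn enumerate_then_take_while(primes: &[u64], bound: u64) -> Vec<(usize, u64)> {
    primes
        .iter()
        .enumerate()
        .take_while(|&(_, &p)| p <= bound)
        .map(|(i, &p)| (i, p))
        .collect()
}

/// Pairs as the changed fold sees them: take_while first, then enumerate.
fn take_while_then_enumerate(primes: &[u64], bound: u64) -> Vec<(usize, u64)> {
    primes
        .iter()
        .take_while(|&&p| p <= bound)
        .enumerate()
        .map(|(i, &p)| (i, p))
        .collect()
}

fn main() {
    let primes = [2, 3, 5, 7, 11, 13];
    let a = enumerate_then_take_while(&primes, 7);
    let b = take_while_then_enumerate(&primes, 7);
    // Both orderings yield the same (index, prime) pairs.
    assert_eq!(a, b);
    println!("{a:?}");
}
```

Since the indices agree, `&primes[i..]` refers to the same subslice in both versions.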

Nice trick! And it makes more sense, since take_while doesn't need the index. I wish there were a clippy lint for using enumerate and zip too early.

Unfortunately, it seems the problem also exists with filter and the trick won't be applicable.
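To spell out why: filter can drop items from the middle, so enumerating after it numbers the surviving items rather than their positions in the original slice. A minimal sketch, with helper names and data made up for illustration:

```rust
/// enumerate before filter: indices are positions in the original slice.
fn enumerate_then_filter(values: &[u64], bound: u64) -> Vec<(usize, u64)> {
    values
        .iter()
        .enumerate()
        .filter(|&(_, &v)| v <= bound)
        .map(|(i, &v)| (i, v))
        .collect()
}

/// filter before enumerate: indices are positions among the kept items only.
fn filter_then_enumerate(values: &[u64], bound: u64) -> Vec<(usize, u64)> {
    values
        .iter()
        .filter(|&&v| v <= bound)
        .enumerate()
        .map(|(i, &v)| (i, v))
        .collect()
}

fn main() {
    let values = [7, 2, 9, 3];
    println!("{:?}", enumerate_then_filter(&values, 5)); // original positions
    println!("{:?}", filter_then_enumerate(&values, 5)); // renumbered from 0
}
```

With the renumbered indices, slicing the original slice by `i` would pick the wrong starting point, so enumerate has to stay in front of filter.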

Here's an example of for vs fold, but with filter instead of take_while. The for version was noticeably faster, though the difference was not as large as with take_while. I ran it in Termux (Android 13, Snapdragon 695, ARMv8, 6/128). I'm not sure whether the result will be the same on x86.

fn mults_count_primes_fold(primes: &[u64], limit: u64, sum: u64, val: u64, odd: bool) -> u64 {
    primes.iter().enumerate().filter(|&(_, &p)| p <= limit / val).fold(sum, |acc, (i, &p)| { 
        let new_val = val * p;
        let temp_acc = if odd { acc + limit / new_val } else { acc - limit / new_val };

        mults_count_primes_fold(&primes[(i + 1)..], limit, temp_acc, new_val, !odd)
    })
}
fn mults_count_primes_for(primes: &[u64], limit: u64, sum: u64, val: u64, odd: bool) -> u64 {
    let mut acc = sum;

    for (i, &p) in primes.iter().enumerate().filter(|&(_, &p)| p <= limit / val) {
        let new_val = val * p;
        let temp_acc = if odd { acc + limit / new_val } else { acc - limit / new_val };

        acc = mults_count_primes_for(&primes[(i + 1)..], limit, temp_acc, new_val, !odd);
    }
    
    acc
}
fn main() {
    let timer = std::time::Instant::now();
    println!("using fold (ascending primes):");
    println!("    result:  {:?}", mults_count_primes_fold(&[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113], u64::MAX, 0, 1, true));
    println!("    time:    {:?}", timer.elapsed());

    let timer = std::time::Instant::now();
    println!("using fold (descending primes):");
    println!("    result:  {:?}", mults_count_primes_fold(&[113, 109, 107, 103, 101, 97, 89, 83, 79, 73, 71, 67, 61, 59, 53, 47, 43, 41, 37, 31, 29, 23, 19, 17, 13, 11, 7, 5, 3, 2], u64::MAX, 0, 1, true));
    println!("    time:    {:?}", timer.elapsed());

    let timer = std::time::Instant::now();
    println!("using for (ascending primes):");
    println!("    result:  {:?}", mults_count_primes_for(&[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113], u64::MAX, 0, 1, true));
    println!("    time:    {:?}", timer.elapsed());

    let timer = std::time::Instant::now();
    println!("using for (descending primes):");
    println!("    result:  {:?}", mults_count_primes_for(&[113, 109, 107, 103, 101, 97, 89, 83, 79, 73, 71, 67, 61, 59, 53, 47, 43, 41, 37, 31, 29, 23, 19, 17, 13, 11, 7, 5, 3, 2], u64::MAX, 0, 1, true));
    println!("    time:    {:?}", timer.elapsed());
}

even if this code is not in the hot path, I don't think fold() is really better in terms of readability in this example. of course, it may be a different story for the real code, but in a real application you don't want to bet on the optimizer for performance-critical code anyway.

so, I'd just use a regular for loop and stop worrying about it.
