Simple conditional loop

I've been testing simple loops in Rust because I'd like to learn more about writing performant ones. Here are some very simple examples:

#[inline]
fn rust_cheating(high: isize) -> isize {
    let mut total = 0;
    let mut i = 4;
    let high = high * 2;
    while i <= high {
        total += i;
        i += 4;
    }
    total
}

fn rust_iter1(n: isize) -> isize {
    (1..=n).fold(0, |sum, item| {
        if item % 2 == 0 {
            sum + item * 2
        } else {
            sum
        }
    })
}

fn rust_iter2(n: isize) -> isize {
    (1..=n)
        .filter(|item| item % 2 == 0)
        .map(|item| item * 2)
        .fold(0, |sum, item| sum + item)
}

#[inline]
fn rust_iter3(n: isize) -> isize {
    (4..=n * 2).step_by(4).fold(0, |sum, item| sum + item)
}
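Before comparing timings, it's worth confirming that the four variants really compute the same value. Below is a minimal sketch that copies the functions above (without the `#[inline]` attributes) and checks them against `rust_cheating` for a few odd, even and edge-case inputs, plus the benchmark size. The `main` function and the chosen inputs are just for illustration:

```rust
fn rust_cheating(high: isize) -> isize {
    let mut total = 0;
    let mut i = 4;
    let high = high * 2;
    while i <= high {
        total += i;
        i += 4;
    }
    total
}

fn rust_iter1(n: isize) -> isize {
    (1..=n).fold(0, |sum, item| if item % 2 == 0 { sum + item * 2 } else { sum })
}

fn rust_iter2(n: isize) -> isize {
    (1..=n)
        .filter(|item| item % 2 == 0)
        .map(|item| item * 2)
        .fold(0, |sum, item| sum + item)
}

fn rust_iter3(n: isize) -> isize {
    (4..=n * 2).step_by(4).fold(0, |sum, item| sum + item)
}

fn main() {
    // Zero, odd and even inputs, plus the size used in the benchmark.
    for n in [0, 1, 2, 3, 7, 10, 1_000_000] {
        let expected = rust_cheating(n);
        assert_eq!(rust_iter1(n), expected, "rust_iter1 differs at n = {n}");
        assert_eq!(rust_iter2(n), expected, "rust_iter2 differs at n = {n}");
        assert_eq!(rust_iter3(n), expected, "rust_iter3 differs at n = {n}");
    }
    println!("all variants agree");
}
```

The result for n = 1_000_000 (500_001_000_000) needs a 64-bit isize; on a 32-bit target these functions would overflow.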

Using Criterion, these are the benchmark results:

rust_cheating           time:   [179.32 us 179.44 us 179.61 us]                          
Found 13 outliers among 100 measurements (13.00%)
  5 (5.00%) high mild
  8 (8.00%) high severe

rust_iter1              time:   [270.06 us 270.73 us 271.62 us]                       
Found 14 outliers among 100 measurements (14.00%)
  4 (4.00%) high mild
  10 (10.00%) high severe

rust_iter2              time:   [270.07 us 270.76 us 271.62 us]                       
Found 11 outliers among 100 measurements (11.00%)
  4 (4.00%) high mild
  7 (7.00%) high severe

rust_iter3              time:   [496.71 us 498.06 us 499.66 us]                       
Found 9 outliers among 100 measurements (9.00%)
  8 (8.00%) high mild
  1 (1.00%) high severe

Any tips on how I can improve these? I'm using n = 1_000_000.

I'm not entirely sure why the first one is "cheating", but if iterators are necessary, here's my attempt:

fn rust_iter4(n: isize) -> isize {
    (1..=n/2).fold(0, |sum, item| sum + item * 4)
}

This should be more or less equivalent to the following for loop:

fn rust_for(n: isize) -> isize {
    let mut total = 0;
    for i in 1..=n/2 {
        total += i * 4;
    }
    total
}

I don't know how it compares performance-wise.


That name comes from this article.

Ah. Yeah, when dealing with integers you're not going to beat a while loop no matter how clever you are.

Interesting results:

rust_for                time:   [478.60 us 493.62 us 509.58 us] 
rust_iter4              time:   [2.1201 ns 2.2015 ns 2.2839 ns]

The last one is very impressive.

Huh, I wonder why the for loop is so much slower.

That ns means the compiler optimized the loop away (performed the computation at compile time), right?

Correct.

It was probably able to see that the loop had no side effects, or that all its inputs were known and could be evaluated at compile time.

Whatever framework you use for benchmarking should provide some sort of "black box" for inhibiting those sorts of optimisations, possibly as part of parameterising the benchmark.
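In Rust, `std::hint::black_box` (stable since Rust 1.66) is one such tool, and Criterion also exposes a `criterion::black_box`. Below is a rough sketch of the idea using only the standard library. It times a single call with `Instant`, which is a crude measurement and not a replacement for Criterion's statistics. Hiding the input stops the compiler from folding the whole call into a constant, although LLVM may still replace the loop with a closed-form formula. `black_box` itself is documented as a best-effort hint rather than a guarantee:

```rust
use std::hint::black_box;
use std::time::Instant;

fn rust_iter4(n: isize) -> isize {
    (1..=n / 2).fold(0, |sum, item| sum + item * 4)
}

fn main() {
    let start = Instant::now();
    // Hide the input value from the optimizer so the call can't be
    // precomputed at compile time...
    let result = rust_iter4(black_box(1_000_000));
    // ...and pass the result through black_box too, so the computation
    // is not discarded as unused.
    let result = black_box(result);
    println!("result = {result}, took {:?}", start.elapsed());
}
```

Inside a Criterion benchmark, the same wrapping would typically go inside the closure passed to `b.iter`, for example `b.iter(|| rust_iter4(black_box(1_000_000)))`.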


Ha, I missed the units. I guess the literal fastest way would be to solve the equation and not loop at all.
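For reference, here is the closed form. Every variant adds up the multiples of 4 from 4 up to 2n, i.e. 4k for k = 1, ..., m with m = ⌊n/2⌋ (Rust's integer division for non-negative n), so the arithmetic-series formula gives:

```latex
\sum_{k=1}^{m} 4k \;=\; 4 \cdot \frac{m(m+1)}{2} \;=\; 2m(m+1),
\qquad m = \left\lfloor \frac{n}{2} \right\rfloor
```

For n = 1_000_000 this is 2 · 500000 · 500001 = 500001000000.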

Please never use a sum as your example for this stuff. LLVM has special closed-form optimizations for sums, so you're really just testing the very nuanced question of whether that optimization manages to kick in or not.

As a quick demo, see that there's no loop in the output of this:


IMO almost none of these functions can be properly optimized by LLVM because of the extra complexity introduced by the branch, by step_by, or by RangeInclusive (i.e. ..=).

rust_iter4 is the only one simple enough to be optimized, but even that happens only thanks to RangeInclusive's own implementation of fold (which is still somewhat complex and leaves a number of branches). rust_for uses a for loop, which calls next repeatedly; next is more general than fold and can't be optimized the same way.
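One reason `..=` is harder to optimize is that an inclusive range can't always be rewritten as a half-open one. For example, `1..=isize::MAX` has no representable `end + 1`, so the iterator needs an extra piece of state to remember that the last element has already been yielded. The hypothetical `InclusiveRange` below is a simplified sketch of that idea, not the standard library's actual code. std's `RangeInclusive` does track exhaustion with a similar extra boolean field, but its `next`, `fold` and `try_fold` differ in detail:

```rust
/// Hypothetical, simplified inclusive range over isize (illustration only).
struct InclusiveRange {
    start: isize,
    end: isize,
    exhausted: bool, // extra state that a half-open range doesn't need
}

impl InclusiveRange {
    fn new(start: isize, end: isize) -> Self {
        InclusiveRange { start, end, exhausted: false }
    }
}

impl Iterator for InclusiveRange {
    type Item = isize;

    fn next(&mut self) -> Option<isize> {
        // Check 1: empty range, or the last element was already yielded.
        if self.exhausted || self.start > self.end {
            return None;
        }
        let current = self.start;
        // Check 2: don't step past `end`, because `end + 1` could overflow
        // (e.g. end == isize::MAX); set the flag instead.
        if self.start < self.end {
            self.start += 1;
        } else {
            self.exhausted = true;
        }
        Some(current)
    }
}

fn main() {
    let v: Vec<isize> = InclusiveRange::new(1, 5).collect();
    assert_eq!(v, vec![1, 2, 3, 4, 5]);

    // Works right up to isize::MAX, where a half-open range would need end + 1.
    let tail: Vec<isize> = InclusiveRange::new(isize::MAX - 2, isize::MAX).collect();
    assert_eq!(tail, vec![isize::MAX - 2, isize::MAX - 1, isize::MAX]);
    println!("ok");
}
```

By comparison, a half-open range only needs a single `start < end` check per step, which gives the optimizer a much simpler loop to analyze.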

However, if you remove the RangeInclusive overhead, both are optimized even further, leaving only one branch:

fn rust_iter4_exclusive(n: isize) -> isize {
    (1..n/2+1).fold(0, |sum, item| sum + item * 4)
}

fn rust_for_exclusive(n: isize) -> isize {
    let mut total = 0;
    for i in 1..n/2+1 {
        total += i * 4;
    }
    total
}

They're essentially slightly less efficient versions of the following code:

fn rust_manual(n: isize) -> isize {
    if n/2 >= 1 {
        (n/2) * (n/2 + 1) * 2
    } else {
        0
    }
}
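A quick way to check that these really compute the same thing is to compare them over a range of inputs, including negative, zero and odd n. Below is a minimal sketch that copies the three functions above; the input range is arbitrary:

```rust
fn rust_iter4_exclusive(n: isize) -> isize {
    (1..n / 2 + 1).fold(0, |sum, item| sum + item * 4)
}

fn rust_for_exclusive(n: isize) -> isize {
    let mut total = 0;
    for i in 1..n / 2 + 1 {
        total += i * 4;
    }
    total
}

fn rust_manual(n: isize) -> isize {
    if n / 2 >= 1 {
        (n / 2) * (n / 2 + 1) * 2
    } else {
        0
    }
}

fn main() {
    // Negative, zero, odd and even inputs, plus the benchmark size.
    for n in (-5..=50).chain([1_000_000]) {
        let expected = rust_manual(n);
        assert_eq!(rust_iter4_exclusive(n), expected, "fold differs at n = {n}");
        assert_eq!(rust_for_exclusive(n), expected, "for differs at n = {n}");
    }
    println!("all versions agree");
}
```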

This last version compiles to code without any branch; however, small changes in the code (for example, changing that * 2 to * 4 / 2) may not result in the same optimizations. If you want to be sure there's no branch, integrate the condition into the math: (n/2 >= 1) as isize * (n/2) * (n/2 + 1) * 2
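Written out as a function, the branch-free expression looks like the sketch below. The name `rust_branchless` is just for illustration. Because `as` binds more tightly than `*`, the `bool` is cast to 0 or 1 before the multiplication. Having no `if` in the source doesn't by itself guarantee branch-free machine code, so inspecting the release-mode assembly (for example on Compiler Explorer) is the way to confirm it:

```rust
fn rust_manual(n: isize) -> isize {
    if n / 2 >= 1 { (n / 2) * (n / 2 + 1) * 2 } else { 0 }
}

/// Hypothetical name: the condition folded into the arithmetic.
fn rust_branchless(n: isize) -> isize {
    // `bool as isize` is 0 or 1, so the product is zeroed when the
    // condition is false, with no `if` in the source.
    (n / 2 >= 1) as isize * (n / 2) * (n / 2 + 1) * 2
}

fn main() {
    for n in -10..=1000 {
        assert_eq!(rust_branchless(n), rust_manual(n), "mismatch at n = {n}");
    }
    assert_eq!(rust_branchless(1_000_000), 500_001_000_000);
    println!("branchless version matches");
}
```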

So, in conclusion, the lessons learned are:

  • Avoid RangeInclusive as much as you can.
  • Prefer Iterator's terminating methods such as for_each/fold/try_fold/collect/sum over plain for loops.
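As an illustration of both points together, the same sum can be written over a half-open range and finished with a terminating method instead of a `for` loop. This is a sketch; the function names are hypothetical. Whether either version compiles to the same code as the closed form depends on the compiler version, so checking the assembly or benchmarking is still worthwhile:

```rust
/// Half-open range plus `map` and `sum` (hypothetical name).
fn rust_sum(n: isize) -> isize {
    (1..n / 2 + 1).map(|i| i * 4).sum()
}

/// Same computation with `for_each`, accumulating into a local variable.
fn rust_for_each(n: isize) -> isize {
    let mut total = 0;
    (1..n / 2 + 1).for_each(|i| total += i * 4);
    total
}

fn main() {
    for n in [0, 1, 7, 1_000_000] {
        assert_eq!(rust_sum(n), rust_for_each(n));
    }
    println!("{}", rust_sum(1_000_000));
}
```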

Thank you for the explanation :smiley: