Rayon: Apply par_iter after a take_while

#1

I’m trying out Rust now and am using rayon to try to speed up a simple prime-number counter.

fn get_primes(count: i32) -> Vec<i32> {
    let mut found = vec![];
    for cur in 2..count {
        let max = (cur as f32).sqrt() as i32;
        if !found.iter().take_while(|n| *n <= &max).par_iter().any(|n|{
            cur % n == 0
        }){
            found.push(cur);
        }
    }
    found
}

However, I understand that Rayon only works on array-like types (i.e. it can't work on an iterator). How do I apply a parallel iterator to only the part of “found” whose elements are smaller than max?

#2

However, I understand that Rayon only works on array-like types (i.e. it can't work on an iterator)

It actually can work on iterators using par_bridge. Here’s an example of doing that:

#3

Thanks for your quick response - it was just what I was looking for.
It now compiles:
// par_bridge comes from rayon's ParallelBridge trait (brought into scope by rayon::prelude::*)
if found.iter().filter(|n| *n <= &max).par_bridge().all(|n| {
    cur % n != 0
}) {
    found.push(cur);
}
However, the parallel version is much slower than the serial version (and the par_bridge version, which should be doing less work, is slower than the fully parallel version)!
Now I need to figure out how to rewrite this so that making it parallel actually improves performance.

#4

I’d be very surprised, at least for this example code, if par_bridge were faster than just running the any inline, since scheduling a task almost certainly costs more than the single modulo operation it performs.

A better formulation to experiment with might be something like this: https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=794f1a68ff3765d92e370a094c65b473
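One way to make threads pay for their scheduling cost is to give each one a large block of candidates instead of a single modulo. The sketch below swaps rayon for the standard library's std::thread::scope (so it needs no dependencies) and splits the range into one contiguous chunk per thread; count_primes_parallel and is_prime are illustrative names, and the thread count is a parameter you would tune. With rayon, the analogous move would be parallelizing the outer loop over candidates rather than the inner any.

```rust
use std::thread;

/// Trial division against `small_primes`, which must contain every prime
/// up to sqrt(n). Uses `p <= n / p` instead of `p * p <= n` to avoid overflow.
fn is_prime(n: u64, small_primes: &[u64]) -> bool {
    n >= 2
        && small_primes
            .iter()
            .take_while(|&&p| p <= n / p)
            .all(|&p| n % p != 0)
}

/// Counts the primes below `limit`, giving each thread one contiguous chunk
/// of candidates so that every spawned task does plenty of work.
fn count_primes_parallel(limit: u64, threads: u64) -> usize {
    // Serial phase: collect the primes up to sqrt(limit). This list is small.
    let root = (limit as f64).sqrt() as u64 + 1;
    let mut small_primes = Vec::new();
    for n in 2..=root {
        if is_prime(n, &small_primes) {
            small_primes.push(n);
        }
    }

    // Parallel phase: thread i tests its own range [start, end).
    let threads = threads.max(1);
    let chunk = (limit + threads - 1) / threads;
    let small: &[u64] = &small_primes;
    thread::scope(|s| {
        let handles: Vec<_> = (0..threads)
            .map(|i| {
                let start = i * chunk;
                let end = (start + chunk).min(limit);
                s.spawn(move || (start..end).filter(|&n| is_prime(n, small)).count())
            })
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).sum::<usize>()
    })
}
```

Because the chunks are contiguous, later chunks hold larger numbers that need more divisions, so the load is not perfectly balanced; smaller interleaved chunks (or rayon's work stealing) would even that out.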

#5

Thanks - as_slice with binary search is much better than take_while.
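A minimal sketch of the binary-search idea using only the standard library: slice::partition_point returns the index of the first element for which the predicate is false, so on the sorted found list it marks the end of the primes no greater than max. The result is an ordinary slice, which is exactly what rayon's par_iter() works on; the sketch keeps the serial iter() so it has no dependencies, and candidates_up_to is a hypothetical helper name rather than the code from the playground link.

```rust
/// Returns the prefix of the sorted slice `found` whose elements are <= `max`.
/// A binary search (`partition_point`) replaces the linear `take_while`.
fn candidates_up_to(found: &[u32], max: u32) -> &[u32] {
    let end = found.partition_point(|&p| p <= max);
    &found[..end]
}

fn get_primes(count: u32) -> Vec<u32> {
    let mut found: Vec<u32> = Vec::new();
    for cur in 2..count {
        let max = (cur as f64).sqrt() as u32;
        // With rayon, this slice could be traversed with .par_iter() instead.
        let divisible = candidates_up_to(&found, max).iter().any(|&p| cur % p == 0);
        if !divisible {
            found.push(cur);
        }
    }
    found
}
```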

BTW, cur >> 1 is equivalent to dividing by 2, not to taking the square root.
There is a crate that implements integer-sqrt, but I opted for the built-in (floating-point) sqrt for my starter project.
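For comparison, the floor of the square root can also be computed exactly in integer arithmetic. (n as f32).sqrt() as u32 can be off by one for large enough n, because f32 carries only a 24-bit significand; that does not matter for the ranges in this thread, but an integer version avoids the question entirely. Below is a minimal Newton's-method sketch; isqrt here is a hand-written helper, not the crate mentioned above, and Rust 1.84 and later also provide an isqrt method on the integer types.

```rust
/// Integer square root: the largest r such that r * r <= n.
/// Newton's method on integers, starting from an overestimate and
/// stopping as soon as the estimate stops decreasing.
fn isqrt(n: u32) -> u32 {
    if n < 2 {
        return n;
    }
    let mut x = n; // previous estimate
    let mut y = n / 2; // next estimate, already >= the true root for n >= 2
    while y < x {
        x = y;
        y = (x + n / x) / 2;
    }
    x
}
```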

However, you are right about the task scheduling: my inline version is still faster.

#6

For the sake of comparison, here is the serial version:

fn get_primes(count: i32) -> Vec<i32> {
    let mut found: Vec<i32> = vec![];
    for cur in 2..count {
        let max = (cur as f32).sqrt() as i32;
        let mut isprime = true;
        for prev in &found {
            if prev > &max {
                break;
            }
            if cur % prev == 0 {
                isprime = false;
                break;
            }
        }
        if isprime {
            found.push(cur);
        }
    }
    found
}

Done. Found 78498 prime numbers in 44.040482ms

#7

I experimented further and realised that LLVM was optimising away my code! :slight_smile:
I then improved my code based on the Wikipedia article on primality tests: after 2 and 3, all primes fit the pattern 6k±1, which means I only need to check 2 out of every 6 numbers (1/3 of them).

fn get_primes_serial(count: u32) -> Vec<u32> {
    let mut found = vec![];
    if count > 2 {
        found.push(2);
    }
    if count > 3 {
        found.push(3);
    }
    // After 2 and 3, every prime is 6k - 1 or 6k + 1,
    // so only test cur = 6k - 1 and cur + 2 = 6k + 1.
    for cur in (5..count).step_by(6) {
        let max = (cur as f32).sqrt() as u32;
        let is_prime = !found.iter().take_while(|n| *n <= &max).any(|prev| cur % prev == 0);
        if is_prime {
            found.push(cur);
        }
        let cur = cur + 2;
        // cur + 2 can run past the end of the range, so stop there.
        if cur >= count {
            break;
        }
        let max = (cur as f32).sqrt() as u32;
        let is_prime = !found.iter().take_while(|n| *n <= &max).any(|prev| cur % prev == 0);
        if is_prime {
            found.push(cur);
        }
    }
    found
}

However, it took exactly the same amount of time as my original naive version!

fn get_primes_serial(count: i32) -> Vec<i32> {
    let mut found = vec![];
    for cur in 2..count {
        let max = (cur as f32).sqrt() as i32;
        let is_prime = !found.iter().take_while(|n|*n <= &max).any(|prev| cur % prev == 0);
        if is_prime {
            found.push(cur);
        }
    }
    found
}
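On the LLVM point: when benchmarking, one common way to stop the optimizer from deleting or constant-folding the work being timed is std::hint::black_box (stable since Rust 1.66; its documentation describes it as best-effort rather than a guarantee). A minimal sketch of a timing helper, assuming the naive trial-division get_primes from this thread (rewritten with u32 and an f64 square root) and a hypothetical time_get_primes function; timings are only meaningful for release builds.

```rust
use std::hint::black_box;
use std::time::{Duration, Instant};

/// Naive trial division, like the serial versions in this thread.
fn get_primes(count: u32) -> Vec<u32> {
    let mut found: Vec<u32> = Vec::new();
    for cur in 2..count {
        let max = (cur as f64).sqrt() as u32;
        let is_prime = !found
            .iter()
            .take_while(|&&p| p <= max)
            .any(|&p| cur % p == 0);
        if is_prime {
            found.push(cur);
        }
    }
    found
}

/// Times one call of `get_primes(count)`. black_box on the input keeps the
/// compiler from treating `count` as a known constant, and black_box on the
/// output marks the result as used, so the call is not removed as dead code.
fn time_get_primes(count: u32) -> (usize, Duration) {
    let start = Instant::now();
    let primes = black_box(get_primes(black_box(count)));
    let elapsed = start.elapsed();
    (primes.len(), elapsed)
}
```

Usage: let (n, t) = time_get_primes(1_000_000); then println!("Done. Found {} prime numbers in {:?}", n, t);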
#8

BTW, cur >> 1 is equivalent to dividing by 2, not to taking the square root.

My excuse is that I was very tired when I wrote that! :slight_smile:

#9

Forgetting about Rust for a moment, let’s talk about prime-number algorithms.


If you really want to make it fast, a Sieve of Eratosthenes is even faster.

The premise is simple: You have a big boolean array. Each time you find a number that’s not crossed off, you cross off all of its multiples.

  • 2 is not crossed off, so cross off 4, 6, 8, 10, 12, 14, 16, 18, 20…
  • 3 is not crossed off, so cross off 6, 9, 12, 15, 18, 21, 24, 27…
  • 5 is not crossed off, so cross off 10, 15, 20, 25, 30, 35, 40, 45…

Notably, each time you encounter a new prime p, you’ve already crossed off all of its multiples smaller than p*p (each such multiple k*p with k < p has a prime factor smaller than p), so you can start from p*p:

  • 2 is not crossed off, so cross off 4, 6, 8, 10, 12, 14, 16, 18, 20…
  • 3 is not crossed off, so cross off 9, 12, 15, 18, 21, 24, 27…
  • 5 is not crossed off, so cross off 25, 30, 35, 40, 45…

This algorithm performs integer multiplication and addition where yours performed integer modulus, and it has a lower big-O complexity as well: the sieve does O(n log log n) work.
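For reference, the standard bound on the sieve's work: with n = limit and p ranging over primes, each prime p ≤ √n crosses off at most n/p + 1 entries (the multiples from p² up to n), and Mertens' second theorem (the sum of 1/p over primes p ≤ x is ln ln x + O(1)) turns the total into O(n log log n):

```latex
\sum_{p \le \sqrt{n}} \left\lceil \frac{n - p^2}{p} \right\rceil
  \;\le\; \sum_{p \le \sqrt{n}} \left( \frac{n}{p} + 1 \right)
  \;=\; n \sum_{p \le \sqrt{n}} \frac{1}{p} + \pi(\sqrt{n})
  \;=\; n \bigl( \ln\ln\sqrt{n} + O(1) \bigr) + O(\sqrt{n})
  \;=\; O(n \log\log n)
```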

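fn sieve_of_eratosthenes(limit: i64) -> Vec<bool> {
    // sieve[n] == true means n has not been crossed off; after the loop, n is prime.
    let mut sieve: Vec<_> = vec![true; limit.max(0) as usize];
    // 0 and 1 are not prime (written this way so tiny limits don't panic).
    for n in sieve.iter_mut().take(2) {
        *n = false;
    }

    for x in 2..limit {
        // Multiples of x below x*x were already crossed off by smaller primes.
        let first_new_multiple = x * x;
        if first_new_multiple > limit { break }
        if sieve[x as usize] {
            // Cross off x*x, x*x + x, x*x + 2x, ... below limit.
            for m in (first_new_multiple..limit).step_by(x as usize) {
                sieve[m as usize] = false;
            }
        }
    }
    sieve
}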

This can find all primes up to 10^8 in 1.183s on my machine.


There are a number of ways you can improve on this, which I’ll leave as an exercise for the reader:

  • Skipping multiples of 2 (and maybe 3) so that you cross off fewer numbers repeatedly, similar to what you did above (this is called a wheel optimization).
  • Similar to the above, but compress the representation so that it doesn’t even store entries for multiples of 2 or 3. Return a newtype struct Sieve(Vec<bool>) with methods that look up the right element for a given number (or produce a list of primes from the compressed mask and return that).
  • Turn it into a factor sieve. Basically, change the output to a Vec<i64>, where sieve[n] holds the smallest prime factor of n for n > 1 (so the primes are exactly the numbers where sieve[n] == n). This is extremely useful if you need to factorize many numbers.
  • I would not bother trying to parallelize it. The algorithm is inherently serial.
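For the factor-sieve bullet, one possible shape is sketched below. It keeps the post's p*p starting point and only writes an entry that no smaller prime has claimed yet, so spf[n] ends up holding the smallest prime factor of n. factor_sieve and factorize are illustrative names, and the sketch uses usize instead of i64 to avoid casts.

```rust
/// Smallest-prime-factor sieve: for 2 <= n < limit, spf[n] is the smallest
/// prime factor of n, so the primes are exactly the n >= 2 with spf[n] == n.
/// (spf[0] and spf[1] are left as 0 and 1 and carry no meaning.)
fn factor_sieve(limit: usize) -> Vec<usize> {
    // Start by assuming every n is its own smallest factor, i.e. prime.
    let mut spf: Vec<usize> = (0..limit).collect();
    let mut p = 2;
    while p * p < limit {
        if spf[p] == p {
            // p is prime: claim its multiples that no smaller prime has claimed.
            for m in (p * p..limit).step_by(p) {
                if spf[m] == m {
                    spf[m] = p;
                }
            }
        }
        p += 1;
    }
    spf
}

/// Factorizes n (1 <= n < spf.len()) by repeatedly dividing out spf[n].
fn factorize(mut n: usize, spf: &[usize]) -> Vec<usize> {
    let mut factors = Vec::new();
    while n > 1 {
        let p = spf[n];
        factors.push(p);
        n /= p;
    }
    factors
}
```

Building the table costs a usize per entry instead of a bool, but afterwards each factorization takes only as many steps as n has prime factors.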

For large enough limits, you’ll find that CPU cache misses begin to almost entirely dominate the cost of the algorithm. After all, every new prime requires it to scan through a significant portion of the list. So to push that time even lower, you’ll need to change the order of iteration to be cache-friendly.

This last upgrade will be far more difficult to implement than the others previously mentioned. It’s called the “segmented sieve,” and you can read about it here: https://primesieve.org/segmented_sieve.html
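To make the segmented idea concrete, here is a compact sketch (not primesieve's implementation): the primes up to √limit come from an ordinary sieve, and the rest of the range is processed in windows of segment_size entries, so the array being written stays small enough to remain in cache. count_primes_segmented and segment_size are illustrative names; a good window size depends on the machine's cache sizes.

```rust
/// Counts primes below `limit` with a segmented Sieve of Eratosthenes,
/// sieving one window of `segment_size` numbers at a time.
fn count_primes_segmented(limit: usize, segment_size: usize) -> usize {
    if limit <= 2 {
        return 0;
    }
    let segment_size = segment_size.max(1);

    // 1. Plain sieve for the small primes p with p * p < limit.
    let mut root = 0;
    while (root + 1) * (root + 1) < limit {
        root += 1;
    }
    let mut is_small_prime = vec![true; root + 1];
    let mut small_primes = Vec::new();
    for i in 2..=root {
        if is_small_prime[i] {
            small_primes.push(i);
            for m in (i * i..=root).step_by(i) {
                is_small_prime[m] = false;
            }
        }
    }

    // 2. Sieve [low, high) one cache-sized window at a time.
    let mut count = 0;
    let mut window = vec![true; segment_size];
    let mut low = 2;
    while low < limit {
        let high = (low + segment_size).min(limit);
        let len = high - low;
        window[..len].fill(true);
        for &p in &small_primes {
            // First multiple of p inside the window, but never below p * p.
            let start = (p * p).max((low + p - 1) / p * p);
            for m in (start..high).step_by(p) {
                window[m - low] = false;
            }
        }
        count += window[..len].iter().filter(|&&unmarked| unmarked).count();
        low = high;
    }
    count
}
```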
