Help -- Fastest Prime Sieve, in Rust

Can you share your latest code here? It seems to contain some unnecessary operations. For example, &primes.to_vec() takes a reference to primes as an immutable slice, allocates heap memory, memcpys all of the slice's elements into the new buffer, and then takes a reference to that buffer as an immutable slice. In short, it semantically does nothing but allocate, memcpy, and free the buffer at the end of its lifetime.

@Hyeonu please see the previous comment I made about &primes.to_vec(). On my system it's faster.

I did a slight modification to nextp_init, which makes it a tad faster (and consistent w/Nim|D version).

Here's the gist link again, with the current updated code, including @ExpHP's counter, et al., which makes this implementation functionally equivalent to the Nim|D versions (except that Rust computes the PG parameters at run time rather than at compile time).

And here's the source code too.

// Can compile as: $ cargo build --release
// Single val: $ echo val1 | ./twinprimes_ssoz
// Range vals: $ echo val1 val2 | ./twinprimes_ssoz
// Last update: 2019/8/11

extern crate rayon;
extern crate num_cpus;
use rayon::prelude::*;
use std::time::SystemTime;
use std::sync::atomic::{self, AtomicUsize};

// Shared progress counter. Increments are atomic but use relaxed ordering,
// so they impose no ordering on any other memory access.
struct RelaxedCounter(AtomicUsize);

impl RelaxedCounter {
    /// Create a counter that starts at zero.
    fn new() -> Self {
        Self(AtomicUsize::new(0))
    }

    /// Atomically add one and return the value the counter holds afterwards.
    fn increment(&self) -> usize {
        let previous = self.0.fetch_add(1, atomic::Ordering::Relaxed);
        previous + 1
    }
}

/// Print `title` followed by the seconds elapsed since `time`.
///
/// Panics (after the title has been written) if the system clock reports
/// that `time` lies in the future.
fn print_time(title: &str, time: SystemTime) {
    print!("{} = ", title);
    let elapsed = match time.elapsed() {
        Ok(duration) => duration,
        Err(e) => panic!("Timer error {:?}", e),
    };
    // whole seconds plus the fractional nanosecond part
    let whole = elapsed.as_secs() as f64;
    let fraction = f64::from(elapsed.subsec_nanos()) / 1_000_000_000f64;
    println!("{} secs", whole + fraction);
}

// Lookup table used to count the primes recorded in each 'seg' byte.
// PBITS[v] is the number of '0' bits (primes) in the byte value v, 0..=255.
// Laid out 16 entries per row: row = high nibble, column = low nibble.
static PBITS: [usize; 256] = [
    8, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4, // 0x00
    7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3, // 0x10
    7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3, // 0x20
    6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, // 0x30
    7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3, // 0x40
    6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, // 0x50
    6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, // 0x60
    5, 4, 4, 3, 4, 3, 3, 2, 4, 3, 3, 2, 3, 2, 2, 1, // 0x70
    7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3, // 0x80
    6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, // 0x90
    6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, // 0xA0
    5, 4, 4, 3, 4, 3, 3, 2, 4, 3, 3, 2, 3, 2, 2, 1, // 0xB0
    6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, // 0xC0
    5, 4, 4, 3, 4, 3, 3, 2, 4, 3, 3, 2, 3, 2, 2, 1, // 0xD0
    5, 4, 4, 3, 4, 3, 3, 2, 4, 3, 3, 2, 3, 2, 2, 1, // 0xE0
    4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0, // 0xF0
];

/// Greatest common divisor of `m` and `n` by the Euclidean algorithm.
/// Returns the other argument when one of them is 0.
fn gcd(mut m: usize, mut n: usize) -> usize {
    loop {
        if m == 0 {
            return n;
        }
        let remainder = n % m;
        n = m;
        m = remainder;
    }
}

/// Modular inverse of `a0` modulo `m0` via the extended Euclidean algorithm.
///
/// Returns 1 when `m0 == 1` or `a0 <= 1`. The arguments are expected to be
/// coprime; otherwise the loop ends up dividing by zero.
fn modinv(a0: usize, m0: usize) -> usize {
    if m0 == 1 {
        return 1;
    }
    let modulus = m0 as isize;
    let (mut num, mut den) = (a0 as isize, modulus);
    // Bezout coefficient of the current `den` and of the current `num`
    let (mut den_coef, mut num_coef) = (0isize, 1isize);
    while num > 1 {
        let quotient = num / den;
        let updated = num_coef - quotient * den_coef;
        let remainder = num % den;
        // rotate: the old divisor becomes the new dividend
        num = den;
        den = remainder;
        num_coef = den_coef;
        den_coef = updated;
    }
    // bring a negative coefficient back into 0..m0
    let inverse = if num_coef < 0 { num_coef + modulus } else { num_coef };
    inverse as usize
}

fn gen_pg_parameters(prime: usize) -> (usize, usize, usize, Vec<usize>, Vec<usize>, Vec<usize>) {
  // Build the constant parameters of prime generator P<prime>. Returns
  // (modulus, first residue, twin pair count, upper twin pair residues,
  //  residue inverses, residue value -> residue index table).
  println!("generating parameters for P{}", prime);

  // PG's modulus: product of the base primes up to and including `prime`
  let base_primes: [usize; 9] = [2, 3, 5, 7, 11, 13, 17, 19, 23];
  let mut modpg = 1usize;
  for &p in base_primes.iter() {
    modpg *= p;
    if p == prime { break; }
  }

  // PG's residues: walk the P3 candidates 5, 7, 11, 13, ... below modpg/2 and
  // record each coprime value together with its modular complement
  let half = modpg / 2;
  let mut residues: Vec<usize> = Vec::new();
  let mut candidate = 5usize;
  let mut step = 2usize;
  while candidate < half {
    if gcd(modpg, candidate) == 1 {
      residues.push(candidate);
      residues.push(modpg - candidate);
    }
    candidate += step;
    step = 6 - step;                        // alternate the step 2, 4, 2, 4, ...
  }
  residues.sort_unstable();
  residues.extend_from_slice(&[modpg - 1, modpg + 1]);

  // upper residue of every adjacent pair that differs by 2
  let mut restwins: Vec<usize> = Vec::new();
  let mut idx = 0usize;
  while idx + 1 < residues.len() {
    if residues[idx] + 2 == residues[idx + 1] {
      restwins.push(residues[idx + 1]);
      idx += 2;                             // step over the pair just recorded
    } else {
      idx += 1;
    }
  }

  // modular inverse of each residue, in residue order
  let inverses: Vec<usize> = residues.iter().map(|&res| modinv(res, modpg)).collect();

  // table mapping (residue value - 2) to the residue's index
  let mut pos = vec![0; modpg];
  for (i, &res) in residues.iter().enumerate() {
    pos[res - 2] = i;
  }

  let first_residue = residues[0];
  let pairs_count = restwins.len();
  (modpg, first_residue, pairs_count, restwins, inverses, pos)
}

fn select_pg(start_range: usize, end_range: usize) -> (usize, usize) {
  // Pick, at runtime, the segment size factor and the PG to use for the input
  // range; returned as (bn, pg). The thresholds are estimates derived from PG
  // data profiling and can be improved.
  let span = end_range - start_range;

  if end_range < 49 { return (1, 3); }
  if span < 10_000_000 { return (16, 5); }
  if span < 1_100_000_000 { return (32, 7); }
  if span < 35_500_000_000 { return (64, 11); }
  if span >= 15_000_000_000_000 { return (384, 17); }

  // P13 covers the remaining spans; only the segment factor varies
  let bn = match span {
    s if s > 7_000_000_000_000 => 384,
    s if s > 2_500_000_000_000 => 320,
    s if s >   250_000_000_000 => 196,
    _ => 128,
  };
  (bn, 13)
}

fn sozpg(val: usize, res_0: usize) -> Vec<usize> {
  // Return, in ascending order, the primes p with res_0 <= p <= val, found
  // with the Sieve of Zakiya for P5. The sieve only considers candidates that
  // are coprime to 30 and >= 7, so 2, 3 and 5 are never returned. The result
  // is empty when no such prime lies within the bounds.
  const MD: usize = 30;             // P5's modulus value
  const RSCNT: usize = 8;           // P5's residue count
  static RES: [usize; 8] = [7,11,13,17,19,23,29,31];
  static POSN: [usize; 30] = [0,0,0,0,0,0,0,0,0,1,0,2,0,0,0,3,0,4,0,0,0,5,0,0,0,0,0,6,0,7];

  if val < RES[0] { return Vec::new(); } // no candidate <= val; avoids `val - 7` underflow

  let kmax = (val - 7) / MD + 1;    // number of resgroups up to input value
  let mut prms = vec![0u8; kmax];   // one byte per resgroup, bit r set => RES[r] candidate nonprime
  let sqrt_n = (val as f64).sqrt() as usize; // integer sqrt of val
  let (mut modk, mut r, mut k) = (0usize, 0usize, 0usize); // resgroup base value, residue index, resgroup index

  // mark the multiples of the primes r1..sqrt(val) in 'prms'
  loop {
    if r == RSCNT {                 // past the last residue: move to next resgroup
      r = 0; modk += MD; k += 1;
      if k == kmax { break }        // defensive: never index past the sieve array
    }
    if (prms[k] & (1 << r)) != 0 { r += 1; continue } // skip candidate if not prime
    let prm_r = RES[r];             // residue value of this prime
    let prime = modk + prm_r;       // numerate the prime value
    if prime > sqrt_n { break }     // finished once the prime exceeds sqrt(val)
    for ri in &RES {                // mark prime's multiples on every residue track
      let prod = prm_r * ri - 2;    // cross-product for the prm_r|ri pair
      let bit_r = 1 << POSN[prod % MD];           // bit mask for prod's residue
      let mut kpm = k * (prime + ri) + prod / MD; // 1st resgroup holding a multiple
      while kpm < kmax { prms[kpm] |= bit_r; kpm += prime };
    }
    r += 1;
  }

  // Collect the unmarked candidates that lie within [res_0, val]. Filtering
  // here replaces trimming the finished vector from both ends, which panicked
  // when nothing was left and shifted the whole vector for each front removal.
  let mut primes = Vec::new();
  for (k, &group) in prms.iter().enumerate() {
    for (r, &res) in RES.iter().enumerate() {
      if group & (1 << r) == 0 {
        let candidate = MD * k + res;
        if candidate >= res_0 && candidate <= val { primes.push(candidate) }
      }
    }
  }
  primes
}

fn nextp_init(rhi: usize, kmin: usize, modpg: usize, start_num: usize,
    primes: &[usize], resinvrs: &[usize], pos: &[usize]) -> Vec<usize> {
  // Build the 'nextp' array for the twin pair whose upper residue is `rhi`.
  // For every sieving prime it holds two consecutive entries: the offset, in
  // resgroups from the first resgroup of the range, of the prime's first
  // multiple on the lower track (rhi - 2) and then on the upper track (rhi).
  let res_hi = rhi;
  let res_lo = rhi - 2;
  let mut k_start = kmin - 1;            // resgroup index for start_num
  if k_start * modpg + res_lo < start_num { k_start += 1; } // lower value must be in range

  let mut nextp = Vec::with_capacity(primes.len() * 2);
  for &p in primes {
    let k = (p - 2) / modpg;             // resgroup the prime is in
    let r = (p - 2) % modpg + 2;         // the prime's residue value
    let r_inv = resinvrs[pos[r - 2]];    // inverse of that residue
    for &track in &[res_lo, res_hi] {    // lower track first, then upper
      let ri = (track * r_inv - 2) % modpg + 2;             // residue with r * ri == track (mod modpg)
      let first = k * (p + ri) + (r * ri - 2) / modpg;      // resgroup of p * (k * modpg + ri)
      let offset = if first >= k_start {
        first - k_start
      } else {
        // first multiple lies before the range: advance to the next one inside it
        match (k_start - first) % p {
          0 => 0,
          behind => p - behind,
        }
      };
      nextp.push(offset);
    }
  }
  nextp
}

fn twins_sieve(indx: usize, kmin: usize, kmax: usize, kb: usize, start_num: usize, end_num: usize, modpg: usize,
   primes: &[usize], restwins: &[usize], resinvrs: &[usize], pos: &[usize]) -> (usize, usize) {
   // Segmented sieve for one twin pair: its upper residue is restwins[indx] and
   // its lower residue is 2 less. Called once per twin pair from the parallel map.
   // The resgroups of the range are processed in slices of at most `kb`. In 'seg',
   // the bit of a resgroup is set to '1' when either residue track of that resgroup
   // holds a multiple of a sieving prime; bits left at '0' are counted as twin primes.
   // 'nextp' holds, per prime, the next resgroup offset to mark on the lower track
   // (index 2*j) and on the upper track (index 2*j + 1); it carries over between slices.
   // Returns (largest twin prime value, twin primes count). The first item is forced
   // to 1 when the count is 0; otherwise it is 0 when r_hi > end_num.
   let (mut sum, mut ki, mut kn) = (0usize, kmin - 1, kb);
   let (mut hi_tp, mut k_max) = (0usize, kmax); // resgroup of largest twin found | resgroup upper bound
   let r_hi = restwins[indx];                   // upper residue value of this twin pair
   let mut seg = vec![0u8; ((kb-1) >> 3) + 1];  // seg byte array, one bit per resgroup for KB resgroups
   let mut nextp = nextp_init(r_hi, kmin, modpg, start_num, primes, resinvrs, pos);
   if ((ki * modpg) + r_hi - 2) < start_num  { ki += 1; }    // skip 1st resgroup if its lower value < start_num
   if ((k_max - 1) * modpg + r_hi) > end_num { k_max -= 1; } // drop last resgroup if its upper value > end_num
   while ki < k_max {                         // for Kn resgroup size slices up to k_max
     if kb > (k_max - ki) { kn = k_max - ki } // last slice may hold fewer than kb resgroups
     for b in 0..=((kn-1) >> 3) as usize { seg[b] = 0 } // clear the bytes this slice uses (all bits prime)
     for (j, prime) in primes.iter().enumerate() {      // for each sieving prime
                                        // lower twin pair residue track
       let mut k = nextp[2 * j];        // starting from this resgroup in seg
       while k < kn  {                  // step through the slice in strides of prime
         seg[k >> 3] |= 1 << (k & 7);   // mark resgroup bit as not a twin prime
         k += prime; }                  // resgroup of prime's next multiple
       nextp[2 * j] = k - kn;           // save offset of 1st multiple in the next slice
                                        // upper twin pair residue track
       k = nextp[2 * j | 1];            // starting from this resgroup in seg
       while k < kn  {                  // step through the slice in strides of prime
         seg[k >> 3] |= 1 << (k & 7);   // mark resgroup bit as not a twin prime
         k += prime; }                  // resgroup of prime's next multiple
       nextp[2 * j | 1] = k - kn;       // save offset of 1st multiple in the next slice
     }
                                        // set to '1' every bit above bit (kn-1)%8 in the last
                                        // byte used, so unused bits are not counted as primes
     seg[(kn-1) >> 3] |= !((2 << ((kn-1) % 8)) - 1) as u8;
     let mut cnt = 0usize;              // twin primes count for this slice
     for k in 0..=((kn-1) >> 3) { cnt += PBITS[seg[k] as usize] } // sum the '0' bits of each byte used
     if cnt > 0 {                       // if the slice has twin primes
       sum += cnt;                      // add the slice count to the total count
       let mut upk = kn - 1;            // scan back from the slice end to the last '0' bit
       while seg[upk >> 3] & (1 << (upk & 7)) != 0 { upk -= 1 } 
       hi_tp = ki + upk;                // its resgroup index within the whole range
     }
     ki += kb;                          // 1st resgroup index of the next slice
   }                                    // when done, numerate the largest twin prime found
   hi_tp = if r_hi > end_num { 0 } else { hi_tp * modpg + r_hi };
   if sum == 0 { hi_tp = 1;}            // no twins found for this pair: report 1 as largest
   (hi_tp, sum)
}

fn main() {
  // Read 1 or 2 range values < 2**64 from one line of stdin, sieve the odd
  // values between them for twin primes in parallel (one task per twin pair
  // residue of the selected PG), and print the count and the largest twin.
  let mut val = String::new();
  std::io::stdin().read_line (&mut val).expect("Failed to read line");
  let mut substr_iter = val.split_whitespace();
  let mut next_or_default = |def| -> usize {
      substr_iter.next().unwrap_or(def).parse().expect("Input is not a number")
  };
  // values below 3 are raised to 3; a missing value defaults to 3
  let mut end_num = std::cmp::max(next_or_default("3"), 3);
  let mut start_num = std::cmp::max(next_or_default("3"), 3);
  if start_num > end_num { std::mem::swap(&mut end_num, &mut start_num) }

  println!("threads = {}", num_cpus::get());
  let ts = SystemTime::now();      // start timing sieve setup execution

  start_num |= 1;                  // if start_num even increase by 1
  end_num = (end_num - 1) | 1;     // if end_num even decrease by 1

  // When both inputs are the same even value the range holds no odd number,
  // which leaves start_num > end_num here. Such a range has no twin primes;
  // continuing would underflow the range arithmetic below (and could request
  // a zero-resgroup segment), so report the empty result and stop.
  if start_num > end_num {
    println!("total twins = 0; last twin = 0+/-1");
    return;
  }
                                   // select PG and seg factor Bn for range
  let (bn, pg) = select_pg(start_num, end_num);
  let (modpg, res_0, pairscnt, restwins, resinvrs, pos) = gen_pg_parameters(pg);

  let kmin = (start_num-2) / modpg + 1; // number of resgroups to start_num
  let kmax = (end_num - 2) / modpg + 1; // number of resgroups to end_num
  let range = kmax - kmin + 1;          // number of range resgroups, at least 1
  let n = if range < 37_500_000_000_000 { 4 } else if range < 975_000_000_000_000 { 6 } else { 8 };
  let b = bn * 1024 * n;                // set seg size to optimize for selected PG
  let kb = if range < b { range } else { b }; // segments resgroups size

  println!("each thread segment is [1 x {}] bytes array", ((kb-1) >> 3) + 1);

  // This is not necessary for running the program but provides information
  // to determine the 'efficiency' of the used PG: (num of primes)/(num of pcs)
  let maxpairs = range * pairscnt;  // maximum number of twinprime pcs
  println!("twinprime candidates = {}; resgroups = {}", maxpairs, range);
  // End of non-essential code.

  let primes: Vec<usize> = if end_num < 49 { vec![5] }         // generate
      else { sozpg((end_num as f64).sqrt() as usize, res_0) }; // sieving primes

  println!("each of {} threads has nextp[2 x {}] array", pairscnt, primes.len());
  print_time("setup time", ts);    // sieve setup time

  let mut twinscnt = 0usize;       // count of twin primes in range
  let lo_range = restwins[0] - 3;  // lo_range = lo_tp - 1
  for tp in &[3, 5, 11, 17] {      // excluded low tp values PGs used
    if end_num == 3 { break };     // if 3 end of range, no twin primes
    if tp >= &start_num && tp <= &lo_range { twinscnt += 1 };
  }
  println!("perform twinprimes ssoz sieve");

  let t1 = SystemTime::now();      // start timing ssoz sieve execution
                                   // sieve each twin pair in parallel
  let (lastwins, cnts): (Vec<_>, Vec<_>) = {
    let counter = RelaxedCounter::new();
    (0..pairscnt).into_par_iter()
       .map( |index| {
          // 'primes' is only read by the sieve, so every task borrows the one
          // shared vector instead of cloning it for each twin pair
          let out = twins_sieve(index, kmin, kmax, kb, start_num, end_num, modpg,
          &primes, &restwins, &resinvrs, &pos);
          print!("\r{} of {} threads done", counter.increment(), pairscnt);
          out
       }).unzip()
  };
  twinscnt += cnts.iter().sum::<usize>();  // add the per-pair twin counts
  let mut last_twin: usize = lastwins.iter().cloned().max().unwrap_or(0); // largest twin over all pairs
  if end_num == 5 && twinscnt == 1 { last_twin = 5; }
  let mut kn = range % kb;         // set number of resgroups in last slice
  if kn == 0 { kn = kb };          // if multiple of seg size set to seg size

  print_time("\nsieve time", t1);  // ssoz sieve time
  print_time("total time", ts);    // setup + sieve time
  println!("last segment = {} resgroups; segment slices = {}", kn, (range - 1)/kb + 1);
  println!("total twins = {}; last twin = {}+/-1", twinscnt, last_twin - 1);
}

EDIT: This is a slightly refactored and cleaned-up version (now in the gist) of the originally posted code.

This topic was automatically closed 90 days after the last reply. New replies are no longer allowed.