Miller-Rabin for 64-bit integers

I implemented a deterministic version of Miller-Rabin primality test for 64-bit integers. Please help me improve it (or correct it if there are mistakes).

fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    a.checked_mul(b).map_or_else(|| (u128::from(a) * u128::from(b) % u128::from(m)) as u64, |p| p % m)
}
fn pow_mod(base: u64, exp: u64, m: u64) -> u64 {
    let mut exp @ 2.. = exp else { return if exp == 1 { base % m } else { 1 % m } };
    let mut base @ 2.. = base % m else { return base % m };
    let mut result = 1;

    while exp > 1 {
        if exp & 1 == 1 {
            result = mul_mod(result, base, m);
        }

        exp >>= 1;
        base = mul_mod(base, base, m);
    }

    mul_mod(result, base, m)    
}
fn is_prime(n: u64) -> bool {
    if let Some(&p) = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37].iter().find(|&x| n % x == 0) {
        return n == p
    } else if n < 41u64.pow(2) {
        return n > 1
    }

    let d = (n - 1) >> (n - 1).trailing_zeros();
    let bases = match n < 4_759_123_141 {
        true => [2, 7, 61].iter(),
        false => [2, 325, 9_375, 28_178, 450_775, 9_780_504, 1_795_265_022].iter(),
    };
  
    for mut x in bases.map(|&b| pow_mod(b, d, n)).filter(|&x| x > 1 && x < n - 1) {
        let mut j = d;

        while x != n - 1 {
            x = mul_mod(x, x, n);
            j *= 2;

            if x == 1 || j == n - 1 {
                return false
            }
        }
    }

    true
}
fn main() {
    let now = std::time::Instant::now();
    println!("{}", (0..=10_000_000).filter(|&x| is_prime(x)).count());
    println!("{:?}", now.elapsed());

    let now = std::time::Instant::now();
    println!("{}", (u64::from(u32::MAX) - 10_000_000..=u64::from(u32::MAX)).filter(|&x| is_prime(x)).count());
    println!("{:?}", now.elapsed());
}

#[cfg(test)]
#[allow(clippy::cognitive_complexity)]
mod tests {
    use super::*;

    #[test]
    fn test_mod_pow() {
        assert_eq!(pow_mod(0, 0, 1), 0);
        assert_eq!(pow_mod(0, 0, 2), 1);
        assert_eq!(pow_mod(0, 1, 1), 0);
        assert_eq!(pow_mod(0, 1, 2), 0);
        assert_eq!(pow_mod(1, 0, 1), 0);
        assert_eq!(pow_mod(1, 0, 2), 1);
        assert_eq!(pow_mod(1, 1, 2), 1);
        assert_eq!(pow_mod(2, 1, 1), 0);
        assert_eq!(pow_mod(2, 1, 2), 0);
        assert_eq!(pow_mod(2, 1, 3), 2);
        assert_eq!(pow_mod(2, 2, 2), 0);
        assert_eq!(pow_mod(2, 2, 1), 0);
        assert_eq!(pow_mod(2, 2, 3), 1);
        assert_eq!(pow_mod(2, 3, 3), 2);
        assert_eq!(pow_mod(2, 6, 5), 4);
        assert_eq!(pow_mod(2, 63, u64::MAX), 2u64.pow(63));
        assert_eq!(pow_mod(2, 64, u64::MAX), 1);
        assert_eq!(pow_mod(3, 1, 2), 1);
        assert_eq!(pow_mod(3, 2, 2), 1);
        assert_eq!(pow_mod(5, 2, 20), 5);
        assert_eq!(pow_mod(5, 2, 19), 6);
        assert_eq!(pow_mod(5, 3, 19), 11);
        assert_eq!(pow_mod(10, 18, u64::MAX), 10u64.pow(18));
        assert_eq!(pow_mod(10, 19, u64::MAX), 10u64.pow(19));
        assert_eq!(pow_mod(10, 20, u64::MAX), 7_766_279_631_452_241_925);
        assert_eq!(pow_mod(u64::MAX, u64::MAX, u64::MAX), 0);
    }

    #[test]
    fn test_is_prime_intervals() {
        let mut total = 0u32;

        for i in 0..=1_000_000 {
            if is_prime(i) {
                total += 1;
            }
            match i {
                100 => assert_eq!(total, 25),
                1_000 => assert_eq!(total, 168),
                10_000 => assert_eq!(total, 1_229),
                100_000 => assert_eq!(total, 9_592),
                1_000_000 => assert_eq!(total, 78_498),
                _ => continue,
            }
        }

        assert_eq!((u64::from(u32::MAX) - 1_000_000..=u64::from(u32::MAX)).filter(|&x| is_prime(x)).count(), 44_872);
        assert_eq!((10u64.pow(12) - 1_000_000..=10u64.pow(12)).filter(|&x| is_prime(x)).count(), 36_400);
        assert_eq!((10u64.pow(15) - 1_000_000..=10u64.pow(15)).filter(|&x| is_prime(x)).count(), 28_910);
        assert_eq!((u64::MAX - 1_000_000..=u64::MAX).filter(|&x| is_prime(x)).count(), 22_475);
    }
    
    #[test]
    fn test_is_prime_true() {
        assert!(is_prime(2));
        assert!(is_prime(3));
        assert!(is_prime(5));
        assert!(is_prime(97));
        assert!(is_prime(103));
        assert!(is_prime(2_147_483_647));
        assert!(is_prime(4_294_967_279));
        assert!(is_prime(4_294_967_291));
        assert!(is_prime(1_000_000_000_100_011));
        assert!(is_prime(1_003_229_774_283_941));
        assert!(is_prime(1_011_001_110_001_111));
        assert!(is_prime(1_311_870_831_664_661));
        assert!(is_prime(2_035_802_523_820_057));
        assert!(is_prime(3_391_382_115_599_173));
        assert!(is_prime(10_000_000_002_065_383));
        assert!(is_prime(37_033_804_397_792_473));
        assert!(is_prime(599_999_999_999_899_999));
        assert!(is_prime(1_000_000_000_000_000_003));
        assert!(is_prime(8_512_677_386_048_191_063));
        assert!(is_prime(9_181_531_581_341_931_811));
        assert!(is_prime(9_876_534_021_204_356_789));
        assert!(is_prime(9_876_543_212_123_456_789));
        assert!(is_prime(9_876_543_218_123_456_789));
        assert!(is_prime(9_988_776_655_443_322_001));
        assert!(is_prime(9_999_999_992_999_999_999));
        assert!(is_prime(9_999_999_997_777_777_333));
        assert!(is_prime(18_446_744_073_709_551_427));
        assert!(is_prime(18_446_744_073_709_551_437));
        assert!(is_prime(18_446_744_073_709_551_521));
        assert!(is_prime(18_446_744_073_709_551_533));
        assert!(is_prime(18_446_744_073_709_551_557));
    }

    #[test]
    fn test_is_prime_false() {
        assert!(!is_prime(0));
        assert!(!is_prime(1));
        assert!(!is_prime(4));
        assert!(!is_prime(93));
        assert!(!is_prime(111));
        assert!(!is_prime(4_294_967_271));
        assert!(!is_prime(4_294_967_273));
        assert!(!is_prime(4_294_967_277));
        assert!(!is_prime(1_000_112_354_597_667));
        assert!(!is_prime(1_158_174_141_556_287));
        assert!(!is_prime(1_483_892_396_791_177));
        assert!(!is_prime(2_225_124_216_112_318));
        assert!(!is_prime(2_695_965_911_118_727));
        assert!(!is_prime(4_391_491_991_635_087));
        assert!(!is_prime(99_999_989u64.pow(2)));
        assert!(!is_prime(11_000_011_101_101_111));
        assert!(!is_prime(12_345_689_798_654_321));
        assert!(!is_prime(13_030_323_000_581_525));
        assert!(!is_prime(111_111_111_111_111_111));
        assert!(!is_prime(99_999_989 * (i32::MAX as u64)));
        assert!(!is_prime(430_442_854_738_298_199));
        assert!(!is_prime(889_091_889_880_616_081));
        assert!(!is_prime(1_034_429_177_995_381_247));
        assert!(!is_prime(1_234_567_888_887_654_321));
        assert!(!is_prime((i32::MAX as u64).pow(2)));
        assert!(!is_prime(7_540_113_804_746_346_429));
        assert!(!is_prime(8_650_415_921_358_664_919));
        assert!(!is_prime(9_876_503_214_123_056_789));
    }
}

Immediately, before even trying to understand your code, I noticed that you do not use rustfmt to format your code in a way that other Rust devs are accustomed to. Following the style guidelines helps readability a lot. So in my opinion, running rustfmt on your code would be a low hanging fruit when it comes to improving it:

fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    a.checked_mul(b).map_or_else(
        || (u128::from(a) * u128::from(b) % u128::from(m)) as u64,
        |p| p % m,
    )
}

fn pow_mod(base: u64, exp: u64, m: u64) -> u64 {
    let mut exp @ 2.. = exp else {
        return if exp == 1 { base % m } else { 1 % m };
    };
    
    let mut base @ 2.. = base % m else {
        return base % m;
    };
    
    let mut result = 1;

    while exp > 1 {
        if exp & 1 == 1 {
            result = mul_mod(result, base, m);
        }

        exp >>= 1;
        base = mul_mod(base, base, m);
    }

    mul_mod(result, base, m)
}

fn is_prime(n: u64) -> bool {
    if let Some(&p) = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]
        .iter()
        .find(|&x| n % x == 0)
    {
        return n == p;
    } else if n < 41u64.pow(2) {
        return n > 1;
    }

    let d = (n - 1) >> (n - 1).trailing_zeros();
    
    let bases = match n < 4_759_123_141 {
        true => [2, 7, 61].iter(),
        false => [2, 325, 9_375, 28_178, 450_775, 9_780_504, 1_795_265_022].iter(),
    };

    for mut x in bases
        .map(|&b| pow_mod(b, d, n))
        .filter(|&x| x > 1 && x < n - 1)
    {
        let mut j = d;

        while x != n - 1 {
            x = mul_mod(x, x, n);
            j *= 2;

            if x == 1 || j == n - 1 {
                return false;
            }
        }
    }

    true
}

fn main() {
    let now = std::time::Instant::now();
    println!("{}", (0..=10_000_000).filter(|&x| is_prime(x)).count());
    println!("{:?}", now.elapsed());

    let now = std::time::Instant::now();
    println!(
        "{}",
        (u64::from(u32::MAX) - 10_000_000..=u64::from(u32::MAX))
            .filter(|&x| is_prime(x))
            .count()
    );
    println!("{:?}", now.elapsed());
}

#[cfg(test)]
#[allow(clippy::cognitive_complexity)]
mod tests {
    use super::*;

    #[test]
    fn test_mod_pow() {
        assert_eq!(pow_mod(0, 0, 1), 0);
        assert_eq!(pow_mod(0, 0, 2), 1);
        assert_eq!(pow_mod(0, 1, 1), 0);
        assert_eq!(pow_mod(0, 1, 2), 0);
        assert_eq!(pow_mod(1, 0, 1), 0);
        assert_eq!(pow_mod(1, 0, 2), 1);
        assert_eq!(pow_mod(1, 1, 2), 1);
        assert_eq!(pow_mod(2, 1, 1), 0);
        assert_eq!(pow_mod(2, 1, 2), 0);
        assert_eq!(pow_mod(2, 1, 3), 2);
        assert_eq!(pow_mod(2, 2, 2), 0);
        assert_eq!(pow_mod(2, 2, 1), 0);
        assert_eq!(pow_mod(2, 2, 3), 1);
        assert_eq!(pow_mod(2, 3, 3), 2);
        assert_eq!(pow_mod(2, 6, 5), 4);
        assert_eq!(pow_mod(2, 63, u64::MAX), 2u64.pow(63));
        assert_eq!(pow_mod(2, 64, u64::MAX), 1);
        assert_eq!(pow_mod(3, 1, 2), 1);
        assert_eq!(pow_mod(3, 2, 2), 1);
        assert_eq!(pow_mod(5, 2, 20), 5);
        assert_eq!(pow_mod(5, 2, 19), 6);
        assert_eq!(pow_mod(5, 3, 19), 11);
        assert_eq!(pow_mod(10, 18, u64::MAX), 10u64.pow(18));
        assert_eq!(pow_mod(10, 19, u64::MAX), 10u64.pow(19));
        assert_eq!(pow_mod(10, 20, u64::MAX), 7_766_279_631_452_241_925);
        assert_eq!(pow_mod(u64::MAX, u64::MAX, u64::MAX), 0);
    }

    #[test]
    fn test_is_prime_intervals() {
        let mut total = 0u32;

        for i in 0..=1_000_000 {
            if is_prime(i) {
                total += 1;
            }
            match i {
                100 => assert_eq!(total, 25),
                1_000 => assert_eq!(total, 168),
                10_000 => assert_eq!(total, 1_229),
                100_000 => assert_eq!(total, 9_592),
                1_000_000 => assert_eq!(total, 78_498),
                _ => continue,
            }
        }

        assert_eq!(
            (u64::from(u32::MAX) - 1_000_000..=u64::from(u32::MAX))
                .filter(|&x| is_prime(x))
                .count(),
            44_872
        );
        assert_eq!(
            (10u64.pow(12) - 1_000_000..=10u64.pow(12))
                .filter(|&x| is_prime(x))
                .count(),
            36_400
        );
        assert_eq!(
            (10u64.pow(15) - 1_000_000..=10u64.pow(15))
                .filter(|&x| is_prime(x))
                .count(),
            28_910
        );
        assert_eq!(
            (u64::MAX - 1_000_000..=u64::MAX)
                .filter(|&x| is_prime(x))
                .count(),
            22_475
        );
    }

    #[test]
    fn test_is_prime_true() {
        assert!(is_prime(2));
        assert!(is_prime(3));
        assert!(is_prime(5));
        assert!(is_prime(97));
        assert!(is_prime(103));
        assert!(is_prime(2_147_483_647));
        assert!(is_prime(4_294_967_279));
        assert!(is_prime(4_294_967_291));
        assert!(is_prime(1_000_000_000_100_011));
        assert!(is_prime(1_003_229_774_283_941));
        assert!(is_prime(1_011_001_110_001_111));
        assert!(is_prime(1_311_870_831_664_661));
        assert!(is_prime(2_035_802_523_820_057));
        assert!(is_prime(3_391_382_115_599_173));
        assert!(is_prime(10_000_000_002_065_383));
        assert!(is_prime(37_033_804_397_792_473));
        assert!(is_prime(599_999_999_999_899_999));
        assert!(is_prime(1_000_000_000_000_000_003));
        assert!(is_prime(8_512_677_386_048_191_063));
        assert!(is_prime(9_181_531_581_341_931_811));
        assert!(is_prime(9_876_534_021_204_356_789));
        assert!(is_prime(9_876_543_212_123_456_789));
        assert!(is_prime(9_876_543_218_123_456_789));
        assert!(is_prime(9_988_776_655_443_322_001));
        assert!(is_prime(9_999_999_992_999_999_999));
        assert!(is_prime(9_999_999_997_777_777_333));
        assert!(is_prime(18_446_744_073_709_551_427));
        assert!(is_prime(18_446_744_073_709_551_437));
        assert!(is_prime(18_446_744_073_709_551_521));
        assert!(is_prime(18_446_744_073_709_551_533));
        assert!(is_prime(18_446_744_073_709_551_557));
    }

    #[test]
    fn test_is_prime_false() {
        assert!(!is_prime(0));
        assert!(!is_prime(1));
        assert!(!is_prime(4));
        assert!(!is_prime(93));
        assert!(!is_prime(111));
        assert!(!is_prime(4_294_967_271));
        assert!(!is_prime(4_294_967_273));
        assert!(!is_prime(4_294_967_277));
        assert!(!is_prime(1_000_112_354_597_667));
        assert!(!is_prime(1_158_174_141_556_287));
        assert!(!is_prime(1_483_892_396_791_177));
        assert!(!is_prime(2_225_124_216_112_318));
        assert!(!is_prime(2_695_965_911_118_727));
        assert!(!is_prime(4_391_491_991_635_087));
        assert!(!is_prime(99_999_989u64.pow(2)));
        assert!(!is_prime(11_000_011_101_101_111));
        assert!(!is_prime(12_345_689_798_654_321));
        assert!(!is_prime(13_030_323_000_581_525));
        assert!(!is_prime(111_111_111_111_111_111));
        assert!(!is_prime(99_999_989 * (i32::MAX as u64)));
        assert!(!is_prime(430_442_854_738_298_199));
        assert!(!is_prime(889_091_889_880_616_081));
        assert!(!is_prime(1_034_429_177_995_381_247));
        assert!(!is_prime(1_234_567_888_887_654_321));
        assert!(!is_prime((i32::MAX as u64).pow(2)));
        assert!(!is_prime(7_540_113_804_746_346_429));
        assert!(!is_prime(8_650_415_921_358_664_919));
        assert!(!is_prime(9_876_503_214_123_056_789));
    }
}

Playground.

Here is a more readable version of this:

fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    match exp {
        0 => return 1 % m,
        1 => return base % m,
        _ => {}
    }

    base %= m;
    if base < 2 {
        return base;
    }
1 Like