U32/integer square root

I am studying Rust and writing code that checks whether a number is prime. Currently I test whether the integer provided is divisible by any value in the interval 2..num. I wanted to optimize this and check only the interval 2..sqrt(num), but I can't do this because num is a u32 value and not a float.

Is there a proper way to get this sqrt without using other libraries?

use std::io;

fn is_prime(i:u32) -> bool{
    for x in 2..i {
        if i%x == 0{
            return false;
        }
    }
    return true;
}


fn main() {
    let mut input_text = String::new();
    let mut i:u32 = 0;

    println!("Type an integer:");
    io::stdin().read_line(&mut input_text).expect("Failure to read integer");

    let trimmed = input_text.trim();
    match trimmed.parse::<u32>(){
        Ok(num) => i = num,
        Err(..) => println!("The number is not an integer {}", trimmed),
    };


    if is_prime(i){
        println!("The number is prime!");
    }
    else{
        println!("The number is NOT prime!");
    }

}

Rather than trial division up to sqrt(n), it's generally better to use a sieve of some sort (e.g. of Eratosthenes), perhaps with a wheel for small primes, to avoid trial division by composite numbers that you already know can't divide your candidate number.
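To make the sieve suggestion concrete, here is a minimal sketch of a plain sieve of Eratosthenes (no wheel). It assumes you want every prime up to some limit rather than a test of one number, and the name `primes_up_to` is made up for this example.

```rust
/// Returns every prime `p` with `p <= limit`.
/// `primes_up_to` is a hypothetical helper name, not a std function.
fn primes_up_to(limit: usize) -> Vec<usize> {
    if limit < 2 {
        return Vec::new();
    }
    // is_composite[n] becomes true once n is known to have a smaller factor.
    let mut is_composite = vec![false; limit + 1];
    let mut primes = Vec::new();
    for n in 2..=limit {
        if is_composite[n] {
            continue;
        }
        primes.push(n);
        // Cross off multiples of n, starting at n * n: smaller multiples
        // were already crossed off by a smaller prime.
        let mut m = n.saturating_mul(n);
        while m <= limit {
            is_composite[m] = true;
            m += n;
        }
    }
    primes
}
```

The memory cost is one flag per number up to the limit, so this pays off when you need many primes at once, not for a single query near u32::MAX.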

std doesn't provide any built-in way to get an integer square root. If this is something you actually want, your best bet is either to just use floating-point math ((n as f64).sqrt() as u32) and cap your allowed input at 9×10^15, the largest integer that fits in f64 without skipping (which is larger than u32::MAX at around 4×10^9), or to pull in a crate implementation of isqrt (or implement it yourself).
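As a rough sketch of the floating-point route applied to the loop from the question: the function name `is_prime_sqrt` is made up here, and treating 0 and 1 as non-prime is my assumption (the original loop would report them as prime).

```rust
/// Trial division up to the floor of sqrt(i), with the root taken via f64.
/// `is_prime_sqrt` is a hypothetical name for this sketch.
fn is_prime_sqrt(i: u32) -> bool {
    // Assumption: 0 and 1 are not prime.
    if i < 2 {
        return false;
    }
    // Every u32 converts to f64 exactly; the cast back truncates toward zero.
    let limit = f64::from(i).sqrt() as u32;
    // Inclusive range, so that perfect squares such as 49 are caught.
    for x in 2..=limit {
        if i % x == 0 {
            return false;
        }
    }
    true
}
```

The square root is computed once, outside the loop, so the loop body stays a single remainder operation.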

The num crate, which is practically standard, includes a flooring square root. The algorithm isn't that complicated if you want to implement it yourself. You can use Newton's method. There is also a binary search method described in the book Hacker's Delight that only requires addition in the inner loop.
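For illustration, here is a plain bisection sketch. To be clear, this is not the addition-only variant from Hacker's Delight: it does one multiplication per step, and the name `isqrt_bsearch` is just something I picked.

```rust
/// Floor square root by binary search on the answer.
/// `isqrt_bsearch` is a hypothetical name; this is ordinary bisection.
fn isqrt_bsearch(n: u32) -> u32 {
    let n = u64::from(n);
    // Invariant: lo * lo <= n < hi * hi (squares taken in u64).
    let mut lo: u64 = 0;
    let mut hi: u64 = 1 << 16; // 65_536^2 = 2^32, larger than any u32
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        if mid * mid <= n {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    lo as u32
}
```

Since the search interval starts at width 2^16 and halves each time, the loop runs exactly 16 times regardless of the input.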

The standard trick is to flip the comparison and instead of asking "is x smaller than the square root of i?", ask the equivalent question of "is x * x smaller than i?". (For the primality loop the comparison needs to be inclusive, x * x <= i, so that perfect squares such as 9 are caught.) This comes up all the time; for example, distances in computer graphics are usually compared the same way.
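A minimal sketch of that trick for the loop in the question. The name `is_prime_squared`, the widening to u64, and returning false for 0 and 1 are my choices, not part of the original code.

```rust
/// Trial division that compares x * x against the input instead of
/// computing a square root. `is_prime_squared` is a hypothetical name.
fn is_prime_squared(i: u32) -> bool {
    // Assumption: 0 and 1 are not prime.
    if i < 2 {
        return false;
    }
    // Widen to u64: for the largest inputs x reaches 65_536, and
    // 65_536 * 65_536 does not fit in a u32.
    let n = u64::from(i);
    let mut x: u64 = 2;
    while x * x <= n {
        if n % x == 0 {
            return false;
        }
        x += 1;
    }
    true
}
```

Staying in u32 would also work if the product is computed with `checked_mul` and an overflow is treated as "past the limit".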

H2CO3's answer is the best one for your actual use case.

But if you ever need integer square root in the future, here's a previous thread about it: Integer square root algorithm - #5 by scottmcm

IIRC the easiest is f64::from(x).sqrt() as u32, because f64 has enough mantissa precision for this. But rounding modes are complicated, so you'd want to double-check that. (There are only about 4 billion u32 values, so you can just check that isqrt(x).pow(2) <= x for all of them without trouble.)
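Here is a sketch of a cheaper version of that double-check. Instead of all 4 billion inputs it only tests the points where the answer changes (k^2 and k^2 - 1), which is enough as long as the floating-point square root is monotonic. All three function names are made up for this example.

```rust
/// The candidate under test: floor square root via f64.
fn isqrt_f64(n: u32) -> u32 {
    f64::from(n).sqrt() as u32
}

/// True if `r` is the floor square root of `n`: r^2 <= n < (r + 1)^2.
/// The arithmetic is done in u64 so that (r + 1)^2 cannot overflow.
fn is_floor_sqrt(n: u32, r: u32) -> bool {
    let (n, r) = (u64::from(n), u64::from(r));
    r * r <= n && n < (r + 1) * (r + 1)
}

/// Checks every perfect square k^2 and its predecessor k^2 - 1,
/// plus the two ends of the u32 range.
fn check_boundaries() -> bool {
    let ends_ok = is_floor_sqrt(0, isqrt_f64(0))
        && is_floor_sqrt(u32::MAX, isqrt_f64(u32::MAX));
    ends_ok
        && (1..=65_535u32).all(|k| {
            let sq = k * k; // at most 65_535^2, which fits in a u32
            is_floor_sqrt(sq, isqrt_f64(sq))
                && is_floor_sqrt(sq - 1, isqrt_f64(sq - 1))
        })
}
```

The exhaustive loop over every u32 is still feasible in a release build if you want the stronger guarantee; `is_floor_sqrt` can be reused for it unchanged.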

The cost of the square root is amortized over the entire loop, but testing the square each time requires a multiply (or accumulating 2*i + 1 where i is the induction variable). It doesn't matter so much in this case because there are much better algorithms for primality testing, but it's worth keeping in mind.
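For reference, the accumulating variant could look roughly like this. The name `is_prime_incremental` is hypothetical, and the u64 widening and the handling of 0 and 1 are my assumptions.

```rust
/// Trial division that keeps a running value of x * x and updates it
/// with additions only. `is_prime_incremental` is a hypothetical name.
fn is_prime_incremental(i: u32) -> bool {
    // Assumption: 0 and 1 are not prime.
    if i < 2 {
        return false;
    }
    let n = u64::from(i);
    let mut x: u64 = 2;
    let mut square: u64 = 4; // invariant: square == x * x
    while square <= n {
        if n % x == 0 {
            return false;
        }
        // (x + 1)^2 = x^2 + 2x + 1, so no multiplication by x is needed.
        square += 2 * x + 1;
        x += 1;
    }
    true
}
```

Whether this beats a plain multiply is something to measure; the remainder operation in the loop body is likely to dominate either way.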

If the floor of the floating-point square root in this range is incorrect, that is a bug. The worst case is 2^32 - 1 and its square root is 2^16 - error where the error is about 2^-17, which comfortably fits in an f64.

You need at most 6 iterations of Newton's method to get the exact answer:

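fn sqrt(a: u32) -> u32 {
    if a == 0 {
        return 0;
    }
    // Initial guess: a power of two that is at least the true square root.
    let mut x = 1 << ((33 - a.leading_zeros()) / 2);
    loop {
        // One Newton step for x^2 = a, in integer arithmetic.
        let y = (x + a / x) / 2;
        // The iterates decrease until they reach the floor of the root.
        if y >= x {
            return x;
        }
        x = y;
    }
}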

This is one of those cases where you can use a trick to make a mathematical function MUCH faster.

If you iterate over the integers from 1 to 10_000 and take their square roots, 99% of the time the answer will be the same number as last time; it seems silly to start from scratch each time. And for those 100 out of 10_000 cases where the answer really does change, you don't actually need to calculate the square root from scratch -- you just have to add one to the previous answer.

Here is a function that makes a very fast integer square root function for such cases.

pub fn mk_isqrt(init: u16) -> impl FnMut(u32) -> u16
{
    let mut sqrt = init; // the current square root
    let s = init as u32;
    let mut hi = s * (s + 2);   // (s + 1)^2 - 1 without overflowing
    move |n: u32| {
        // If the current sqrt doesn't work for this n, increment it.
        if n > hi {
            sqrt += 1;
            let s = sqrt as u32;
            hi += s + s + 1;
        }
        sqrt
    }
}

#[cfg(test)]
mod tests {
    #[test]
    fn test() {
        let isqrt = super::mk_isqrt(0);
        let result: Vec<_> = (0u32..17)
            .map(isqrt)
            .collect();
        let expected: Vec<u16> = vec![
        //  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16
            0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4
        ];
        assert_eq!(result, expected);
    }
}

You can of course make this more general (e.g. ascending/descending/both, and not assuming that the sqrt will only change by 1), and in fact I have started working on a crate that would provide macros to allow you to generate an integer square root function like this with the exact amount of generality you need. If anyone knows of such a facility -- in any language -- please let me know. Or perhaps you know of a name for this kind of technique; I have been calling it "input-local math" since it takes advantage of the fact that the inputs change only gradually. You could use the same trick for lots of functions.

And here's a slightly faster version of that:

pub fn mk_isqrt(init: u16) -> impl FnMut(u32) -> u16
{
    let mut sqrt = init; // the current square root
    let s = init as u32;
    let mut hi = s * (s + 2);   // (s + 1)^2 - 1 without overflowing
    let mut inc = s + s + 1;
    move |n: u32| {
        // If the current sqrt doesn't work for this n, increment it.
        if n > hi {
            sqrt += 1;
            inc += 2;
            hi += inc;
        }
        sqrt
    }
}

In case it's not clear what is happening, consider that

(s + 1)^2 = s^2 + 2s + 1

So for inputs from s^2 to s^2 + 2s, we know the isqrt is s. And when the input exceeds s^2 + 2s, we just tee up the next range.
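As a sketch of one of the generalizations mentioned earlier (not assuming the sqrt changes by only 1), the `if` can become a `while`. The name `mk_isqrt_jump` is made up here, and the closure still assumes that inputs never decrease and that the first input is at least init^2.

```rust
/// Like the faster `mk_isqrt`, but tolerates inputs that skip ahead.
/// `mk_isqrt_jump` is a hypothetical name for this sketch.
pub fn mk_isqrt_jump(init: u16) -> impl FnMut(u32) -> u16 {
    let mut sqrt = init; // the current square root
    let s = u32::from(init);
    let mut hi = s * (s + 2); // (s + 1)^2 - 1 without overflowing
    let mut inc = s + s + 1; // width of the current range: 2s + 1
    move |n: u32| {
        // Step through as many ranges as needed to reach n.
        while n > hi {
            sqrt += 1;
            inc += 2;
            hi += inc;
        }
        sqrt
    }
}
```

The cost of a call is proportional to how far the answer moves, so this only stays cheap while the inputs change gradually. For a large jump, recomputing the root from scratch and restarting from there would be the better choice.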