Optimizing linear algebra code

I have a benchmark of sorts that comes from university problems (I'm no programmer, I do computational physics). The code solves equations for matrices with special structures, where the structure is hard-coded into functions to save memory. The code was ported by hand from Matlab to Julia to C++ to Rust, so it is not very idiomatic anyway.

The benchmarks for Rust are rather disappointing, probably because Rust cannot optimize away the bounds checks and therefore the loops cannot be aggressively optimized for SIMD instructions:

[Image: benchmark results]

So Rust is almost a factor of 3 slower than C++ (the comparison is unfair because the C++ code was also compiled with the fast-math options, but even allowing for that there is still a factor of 2). I also use target-cpu=native for Rust, but it did not change the runtimes.

I am aware that there are unsafe functions that do unchecked access on a slice, but it would be better for readability if I could still use the [] operator. In C++ the [] operator does not do bounds checks, and in Julia bounds checks can be turned off globally or for whole functions after the code has been developed and tested.

A full excerpt of the code follows below. The code consists of for loops over ranges of the following form:

    for i in l.n-l.n_x+2..=l.n-1 {
        r2 += square(-b[i-1] + x[i-1]*l.diag as f64 + x[i+1-1]*l.tri_diag as f64 + x[i-1-1]*l.tri_diag as f64 + x[i-l.n_x-1]*l.side_diag as f64);
    }

By construction the code is always in bounds, but the compiler does not seem to be aware of that. My question is how one would formulate this code in a safe way. I am aware of .iter().enumerate() for iterating over elements together with their indices, but it is not clear to me how this might be used for a more complex example.

The standard benchmark harness in nightly Rust measures this function at 283 ns/iter (+/- 51). In Julia it is 108 ns and in C++ 151 ns. This comparison needs to be taken with a grain of salt, since it was done with three different benchmark libraries, but it reflects the general result from above.

pub fn square<T>(i: T) -> T
where
    // Mul<Output = T> already implies Mul, so one bound is enough
    T: std::ops::Mul<Output = T> + Copy,
{
    i * i
}

#[derive(Debug)]
pub struct Laplace2dMatrix {
    pub n_x       : usize ,
    pub n_y       : usize ,
    pub n         : usize ,
    pub diag      : i64   ,
    pub tri_diag  : i64   ,
    pub side_diag : i64   ,
}

impl Laplace2dMatrix {
    pub fn rectangular(n_x: usize, n_y: usize) -> Laplace2dMatrix {
        Laplace2dMatrix {
                  n_x: n_x                                  ,
                  n_y: n_y                                  ,
                    n: n_x*n_y                              ,
                 diag: -2*( square(n_x+1) + square(n_y+1) ) as i64 ,
             tri_diag: square(n_x+1) as i64                 ,
            side_diag: square(n_y+1) as i64                 ,
        }
    }

    pub fn quadratic(n_xy: usize) -> Laplace2dMatrix {
        Laplace2dMatrix::rectangular(n_xy, n_xy)
    }
}

/// calculates the residual r^2 = ||l*x - b||_2^2 
pub fn calculate_residual_squared(l: &Laplace2dMatrix, x: &[f64], b: &[f64]) -> f64 {
    // Assumptions that must hold for l
    // Laplace2dMatrix represents a block diagonal matrix where the block matrix
    // is the same for all blocks and the block matrix is a band matrix
    // n_x and n_y are the size of the block matrix and l is a n x n matrix 
    // where n = n_x*n_y 
    assert!(l.n == x.len());
    assert!(l.n == b.len());
    assert!(l.n == l.n_x*l.n_y);

    let mut r2 = 0.0;

    r2 += square(-b[1-1]  + x[1-1]*l.diag as f64 + x[2-1]   * l.tri_diag as f64 + x[1+l.n_x-1] * l.side_diag as f64);
    for i in 2..=l.n_x-1 {
        r2 += square(-b[i-1] + x[i-1]*l.diag as f64 + x[i+1-1]*l.tri_diag as f64 + x[i-1-1]*l.tri_diag as f64 + x[i+l.n_x-1]*l.side_diag as f64);
    }
    r2 += square(-b[l.n_x-1] + x[l.n_x-1]*l.diag as f64 + x[l.n_x-1-1]*l.tri_diag as f64 + x[l.n_x+l.n_x-1]*l.side_diag as f64);

    for outer_base in (l.n_x..=l.n_x*(l.n_y-2)).step_by(l.n_x) {
        r2 += square(-b[outer_base+1 -1] + x[outer_base+1 -1]*l.diag as f64 + x[outer_base+2   -1]*l.tri_diag as f64 + x[outer_base+1-l.n_x -1]*l.side_diag as f64 + x[outer_base+1+l.n_x-1]*l.side_diag as f64);
        for i in 2..=l.n_x-1 {
            r2 += square(-b[outer_base+i-1] + x[outer_base+i-1]*l.diag as f64 + x[outer_base+i+1-1]*l.tri_diag as f64 + x[outer_base+i-1-1]*l.tri_diag as f64 + x[outer_base+i+l.n_x-1]*l.side_diag as f64 + x[outer_base+i-l.n_x-1]*l.side_diag as f64);
        }
        r2 += square(-b[outer_base+l.n_x-1] + x[outer_base+l.n_x-1]*l.diag as f64 + x[outer_base+l.n_x-1-1]*l.tri_diag as f64 + x[outer_base+l.n_x+l.n_x-1]*l.side_diag as f64 + x[outer_base     -1]*l.side_diag as f64);
    }

    r2 += square(-b[l.n-l.n_x+1-1] + x[l.n-l.n_x+1-1]*l.diag as f64 + x[l.n-l.n_x+2-1]*l.tri_diag as f64 + x[l.n-l.n_x-l.n_x+1-1]*l.side_diag as f64);
    for i in l.n-l.n_x+2..=l.n-1 {
        r2 += square(-b[i-1] + x[i-1]*l.diag as f64 + x[i+1-1]*l.tri_diag as f64 + x[i-1-1]*l.tri_diag as f64 + x[i-l.n_x-1]*l.side_diag as f64);
    }
    r2 += square(-b[l.n     -1] + x[l.n     -1]*l.diag as f64 + x[l.n-1   -1]*l.tri_diag as f64 + x[l.n-l.n_x     -1]*l.side_diag as f64);


    r2 /= l.n as f64;
    return r2;
}

fn main() {
	let l = Laplace2dMatrix::quadratic(10);
	let x = vec![1.2; 100];
	let b = vec![1.2; 100];

	println!("squared residual: {:?}", calculate_residual_squared(&l, &x, &b));
}

[] is slow. Avoid it. It adds a bounds check. That check itself isn't very slow, but for numeric code it has a very costly side effect: because it can panic, it prevents auto-vectorization and code-reordering optimizations.

Because of that, an indexed loop of the form for i in 0..n { ... arr[i] ... } is the slowest kind of loop in Rust. You should avoid it as much as you can.

For fast loops you have some options:

  1. Use iterators instead. They automatically eliminate redundant bounds checks and optimize very well.

  2. Use get_unchecked() instead of []. get_unchecked is the direct equivalent of [] in C/C++, and it requires an unsafe block. Rust's [] is the equivalent of C++'s .at().

  3. Make loop conditions and indexing very very simple and obvious to the optimizer, so that it can eliminate bounds checks. For simple cases it means making a slice of expected length:

    let tmp = &arr[0..len];
    let mut sum = 0.0;
    for i in 0..len {
        // fast if the optimizer can see the relationship between tmp.len() and i
        sum += tmp[i];
    }
    

    rust.godbolt.org is nice for checking whether bounds checks are optimized out or not (just remember to add an optimization flag such as -C opt-level=3).
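
As a rough illustration of option 1 applied to the stencil in the question, here is a minimal sketch for the interior of the first grid row only. The function name first_row_interior and the parameters diag, tri and side (standing in for l.diag, l.tri_diag and l.side_diag already converted to f64) are hypothetical, not part of the original code. The idea is to cut the sub-slices once up front, so the only bounds checks are the ones at the slicing sites, and then zip the pieces together:

```rust
/// Sum of squared residuals over the interior of the first grid row
/// (0-based positions 1..n_x-1), written without indexing into `x` or `b`
/// inside the loop. `diag`, `tri`, `side` are hypothetical stand-ins for
/// the matrix coefficients, already converted to f64.
pub fn first_row_interior(
    x: &[f64],
    b: &[f64],
    n_x: usize,
    diag: f64,
    tri: f64,
    side: f64,
) -> f64 {
    // One up-front check replaces the per-element checks.
    assert!(n_x >= 3 && x.len() >= 2 * n_x && b.len() >= n_x);

    let row = &x[..n_x];                   // left, center, right neighbours
    let below = &x[n_x + 1..2 * n_x - 1];  // neighbour in the next grid row
    let rhs = &b[1..n_x - 1];              // right-hand side, interior only

    row.windows(3)       // each window is [left, center, right]
        .zip(below)
        .zip(rhs)
        .map(|((w, &s), &bi)| {
            let r = -bi + w[1] * diag + w[2] * tri + w[0] * tri + s * side;
            r * r
        })
        .sum()
}
```

The middle rows and the last row would follow the same pattern with one more zipped slice for the neighbour in the previous row. Whether the remaining checks really disappear and the loop vectorizes is worth confirming in the generated assembly rather than assuming.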

You can join #rust-sci on oftc.net (IRC) or #rust-sci:matrix.org (Matrix) for a more interactive discussion; the two are bridged.
I will try to see what I can find out about this…

Removing bounds checks gets it from 261ns to 192ns on my machine.

How did you do it? Did you use the unsafe get_unchecked()?

WARNING: THE FOLLOWING CODE IS UNSAFE AND FOR EXPERIMENTAL PURPOSES ONLY

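use std::ops::Index;

// Wrapper whose [] operator skips the bounds check (unsound for out-of-range indices)
struct Arr<'a> (&'a [f64]);
impl<'a> Index<usize> for Arr<'a> {
    type Output = f64;
    fn index(&self, idx: usize) -> &Self::Output {
        unsafe {
            self.0.get_unchecked(idx)
        }
    }
}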

and then I wrapped the arrays:

    let x = Arr(x);
    let b = Arr(b);

It's kinda unsatisfying because it is unsafe, but it is way more readable at least.

I will see what happens if I apply this to the whole program suite.

To be honest, I was too lazy to replace all the indices…

I am now at 165ns using https://gitlab.com/kornelski/ffast-math/-/blob/master/src/lib.rs (converted to f64)

Just to be clear in case anybody sees this and thinks they should apply it to their own code - this shouldn't be used as a general solution. I think the rule around using unsafe is that usage of the word unsafe should be very close lexically to the code that is guaranteeing your code is actually safe. This does not do that.

At the very least, you should make sure that instantiating Arr requires unsafe (e.g. unsafe fn new(...) -> Arr), and only instantiate Arr in the same routine as the code that indexes the array. Then write lots of comments explaining how you can be positive that you never read past the end of the array.
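
A minimal sketch of what that advice could look like, reusing the Arr name from the experiment above. The function sum_first is a hypothetical caller added only to show where the unsafe keyword and the justifying comment would sit:

```rust
use std::ops::Index;

/// Slice wrapper whose `[]` operator skips the bounds check.
pub struct Arr<'a>(&'a [f64]);

impl<'a> Arr<'a> {
    /// # Safety
    /// The caller must guarantee that every index later passed to `[]`
    /// on this value is strictly less than `data.len()`.
    pub unsafe fn new(data: &'a [f64]) -> Arr<'a> {
        Arr(data)
    }
}

impl<'a> Index<usize> for Arr<'a> {
    type Output = f64;

    #[inline]
    fn index(&self, idx: usize) -> &f64 {
        // SAFETY: upheld by whoever called `Arr::new`.
        unsafe { self.0.get_unchecked(idx) }
    }
}

/// Hypothetical caller: sums the first `n` elements.
pub fn sum_first(data: &[f64], n: usize) -> f64 {
    assert!(n <= data.len());
    // SAFETY: the loop below only uses indices in 0..n, and
    // n <= data.len() was checked on the line above.
    let a = unsafe { Arr::new(data) };
    let mut s = 0.0;
    for i in 0..n {
        s += a[i];
    }
    s
}
```

The wrapper is created in the same function that does the indexing, so the assertion, the safety comment and the loop can all be reviewed together.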

Absolutely! It is unsafe and a terrible idea. But it works to find out how it affects performance.

Rust bounds checking is the HPC killer.

That is interesting. Needing to use a special container instead of a switch is not ergonomic, though.

It is actually not that terrible. Plus, it allows you to control where you want speed vs. accuracy.

However the biggest improvement will be to vectorize your code… manually. LLVM has clearly given up on any attempts at doing so.
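
One step short of explicit SIMD intrinsics, and not the same thing as writing them, is to restructure the reduction by hand so that it no longer depends on a single running sum. Without fast-math the compiler has to preserve the order of floating-point additions, which serializes a loop like r2 += ...; several independent accumulators remove that dependency. The sketch below uses a plain sum of squares rather than the full stencil, and the function name sum_of_squares_4 is hypothetical. Note that it changes the summation order, so results can differ from the sequential loop in the last bits:

```rust
/// Sum of squares using four independent accumulators, so that the
/// additions in different lanes do not depend on each other.
pub fn sum_of_squares_4(v: &[f64]) -> f64 {
    let mut acc = [0.0f64; 4];

    let chunks = v.chunks_exact(4);
    let tail = chunks.remainder(); // the 0..=3 elements that do not fill a chunk

    for c in chunks {
        // `c` always has exactly 4 elements, so these indices are in range.
        for k in 0..4 {
            acc[k] += c[k] * c[k];
        }
    }

    // Combine the lanes, then handle the leftover elements sequentially.
    let mut total = (acc[0] + acc[1]) + (acc[2] + acc[3]);
    for &t in tail {
        total += t * t;
    }
    total
}
```

Whether LLVM turns this into vector instructions on a given target is something to check in the assembly output, not something this sketch guarantees.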

I seem to recall that there's been a discussion of a switch before, and I think it was decided against because of how easy it is to abuse. The core benefit of Rust is that if you never utter the word unsafe, you're statically guaranteed that your code is free of a whole class of bugs that lead to security vulnerabilities and general application unreliability. If you tell people that there's a "go fast" flag that disables this core feature, you start to lose that guarantee.

I see the point, but for scientific computing and other fields that is bad. And it is sad, too. I really like Rust for the clarity of the programs.

Someday, somebody will find a way to write computational kernels in Rust that guarantee via the type system that an index is bounded. But such a technology does not yet exist, I think...

That's a developer time vs. computer time thing. If I can get more than a 30% speedup from a switch, I'm going to take it. But if I need to invest days to manually vectorize something and sacrifice portability along the way, most of the time I will just let the machine run longer.

You are discounting the fact that writing unsafe without putting in the effort to validate it just defers most of that effort to debugging the undefined behavior that is inevitably produced. "Save 30% of my time now to waste 60% of my time later" sounds like a poor economic tradeoff.

The approach I was talking about is to develop and test the code with bounds checks on smaller loads, and after that to turn off the bounds checks for real-world problems.
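
There is no compiler switch for this, but something close can be approximated per access with debug_assert!, which is compiled in when debug assertions are enabled (the default for debug builds) and compiled out otherwise (the default for release builds). The sketch below is only an illustration: the helper at and the function dot are hypothetical names, and in a release build an out-of-range index is still undefined behavior, which is why the helper is marked unsafe:

```rust
/// Reads `s[i]`, checking the index only when debug assertions are enabled.
///
/// # Safety
/// The caller must guarantee `i < s.len()`.
#[inline(always)]
unsafe fn at(s: &[f64], i: usize) -> f64 {
    debug_assert!(i < s.len(), "index {} out of bounds for length {}", i, s.len());
    // SAFETY: guaranteed by the caller, see the contract above.
    unsafe { *s.get_unchecked(i) }
}

/// Hypothetical example kernel: dot product of two equally long slices.
pub fn dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len());
    let mut s = 0.0;
    for i in 0..a.len() {
        // SAFETY: i < a.len(), and a.len() == b.len() was asserted above.
        s += unsafe { at(a, i) * at(b, i) };
    }
    s
}
```

This gives the develop-with-checks, run-without-checks workflow, but only for inputs that were actually exercised in the debug runs; it is not a proof that the indices are always in range.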

The reason runtime bounds checks exist at all is that the compiler cannot statically guarantee that the memory access is in bounds. For cases where it can make this guarantee, the checks are optimized away. This is the point that Kornel was making early on in the thread. Turning off checks is not the same as sufficiently proving that they are unnecessary.