Exploring RustFFT's SIMD Architecture

After releasing RustFFT 5.0 yesterday, a few people asked how RustFFT 5.0 achieved its speed improvements over RustFFT 4.0, and its newfound speed advantage over the C FFT library FFTW.

In this post, I'll start by giving a high-level overview about FFT computation, then discuss how RustFFT computes FFTs, and finally dive into exactly how the AVX code is structured to maximize speed.

FFT Overview

The FFT is fundamentally a recursive process. Nearly all FFT algorithms come down to finding a way to express an FFT as some combination of smaller FFTs. At the heart of FFT computation is the Mixed-Radix Algorithm, which depends on the prime factorization of the FFT input size.

If you can factor an FFT size N into integers A * B, you can use that factorization to compute the size-N FFT:

Algorithm 1:

  1. Treat the input array as a 2D array of width A, and height B.
  2. Compute size-B sub-FFTs down each column of our 2D array.
  3. Compute an O(N) processing step, by multiplying the array with precomputed complex numbers called "twiddle factors"
  4. Compute size-A sub-FFTs down the rows of our 2D array.
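Algorithm 1 can be sketched in a few lines of scalar Rust. This is a minimal sketch under a few stated assumptions. `C64`, `twiddle`, `dft` and `mixed_radix` are hypothetical names, not RustFFT code. The sub-FFTs are naive O(n^2) DFTs for brevity. Forward twiddles W_N^k = e^(-2πik/N) are used throughout. The sketch also makes explicit a detail that the four steps leave implicit: the row results come out in transposed order, so entry (k1, k2) belongs in output bin k1 + B * k2.

```rust
use std::f64::consts::PI;
use std::ops::{Add, Mul};

/// Minimal complex number so the sketch has no external dependencies.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct C64 {
    pub re: f64,
    pub im: f64,
}

impl Add for C64 {
    type Output = C64;
    fn add(self, o: C64) -> C64 {
        C64 { re: self.re + o.re, im: self.im + o.im }
    }
}

impl Mul for C64 {
    type Output = C64;
    fn mul(self, o: C64) -> C64 {
        C64 { re: self.re * o.re - self.im * o.im, im: self.re * o.im + self.im * o.re }
    }
}

/// Forward twiddle factor W_n^k = e^(-2*pi*i*k/n).
pub fn twiddle(k: usize, n: usize) -> C64 {
    let angle = -2.0 * PI * k as f64 / n as f64;
    C64 { re: angle.cos(), im: angle.sin() }
}

/// Naive O(n^2) DFT, used for the sub-FFTs and as a reference.
pub fn dft(x: &[C64]) -> Vec<C64> {
    let n = x.len();
    (0..n)
        .map(|k| x.iter().enumerate().fold(C64::default(), |acc, (j, &v)| acc + v * twiddle(j * k % n, n)))
        .collect()
}

/// Algorithm 1 with N = width * height (A = width, B = height).
/// Element (row r, column c) of the 2D view is input[r * width + c].
pub fn mixed_radix(input: &[C64], width: usize, height: usize) -> Vec<C64> {
    let n = width * height;
    assert_eq!(input.len(), n);
    let mut grid = input.to_vec();
    for c in 0..width {
        // Step 2: size-B DFT down column c (strided access, stride = width).
        let column: Vec<C64> = (0..height).map(|r| grid[r * width + c]).collect();
        for (k1, v) in dft(&column).into_iter().enumerate() {
            // Step 3: multiply by the twiddle factor W_N^(c * k1).
            grid[k1 * width + c] = v * twiddle(c * k1, n);
        }
    }
    let mut output = vec![C64::default(); n];
    for k1 in 0..height {
        // Step 4: size-A DFT across row k1 (contiguous access).
        let row = dft(&grid[k1 * width..(k1 + 1) * width]);
        for (k2, v) in row.into_iter().enumerate() {
            // Entry (k1, k2) is output bin k1 + height * k2, i.e. the
            // result is read out in transposed order.
            output[k1 + height * k2] = v;
        }
    }
    output
}
```

For example, `mixed_radix(&x, 4, 3)` computes a size-12 DFT using four size-3 column DFTs and three size-4 row DFTs.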

The most well-known FFT algorithm is probably the Cooley-Tukey Algorithm, and by reading the linked pseudocode, it becomes clear that Cooley-Tukey is just a special case of Algorithm 1 where A=2 and B=N/2, or vice versa.

If the size-A and size-B sub-FFTs are also computed with Algorithm 1, the result is an O(N log N) divide-and-conquer algorithm with prime-number base cases, corresponding to the original prime factorization of N. These prime number base cases can then be computed using optimized hard-coded functions, or by FFT algorithms that specialize in prime number sizes.

For a simple example of a hardcoded base case, a size-2 FFT is just some additions and subtractions:

output[0] = input[0] + input[1];
output[1] = input[0] - input[1];
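The two lines above can be wrapped into a runnable function. This is a hedged sketch that assumes complex numbers are stored as `(re, im)` tuples of `f32`; RustFFT uses its own complex type instead. The function names are hypothetical. The sketch includes the size-2 base case and a size-4 base case built from two layers of size-2 butterflies. In the size-4 case, the only twiddle factor that is not ±1 is -i, and multiplying by -i is just a swap plus a sign flip.

```rust
/// A complex number as (re, im).
type Cpx = (f32, f32);

/// Size-2 DFT: one addition and one subtraction.
pub fn butterfly2(x: [Cpx; 2]) -> [Cpx; 2] {
    [(x[0].0 + x[1].0, x[0].1 + x[1].1), (x[0].0 - x[1].0, x[0].1 - x[1].1)]
}

/// Size-4 DFT built from size-2 butterflies (radix-2 decimation in time).
pub fn butterfly4(x: [Cpx; 4]) -> [Cpx; 4] {
    // Size-2 DFTs of the even-indexed and odd-indexed inputs.
    let [s0, d0] = butterfly2([x[0], x[2]]);
    let [s1, d1] = butterfly2([x[1], x[3]]);
    // -i * (re + i*im) = im - i*re: a swap plus a sign flip, no multiplication.
    let d1_times_minus_i = (d1.1, -d1.0);
    let [out0, out2] = butterfly2([s0, s1]);
    let [out1, out3] = butterfly2([d0, d1_times_minus_i]);
    [out0, out1, out2, out3]
}
```

For example, `butterfly4` on the inputs 1, 2, 3, 4 returns 10, -2+2i, -2, -2-2i.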

RustFFT's Algorithm

RustFFT's FFT computation is based on a specific formulation of the Mixed Radix algorithm, optimized for modern systems which are heavily dependent on CPU cache: the Six-Step FFT. Like Algorithm 1, we're going to factor our FFT size into A * B, and then treat our input array as a 2D array with width A and height B:

Algorithm 2:

  1. Transpose our 2D array
  2. Compute sub-FFTs of size B down the rows of our transposed array
  3. Multiply with precomputed twiddle factors
  4. Transpose again
  5. Compute sub-FFTs of size A down the rows of our transposed array
  6. Transpose one more time.
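The six steps can be sketched in scalar Rust to make the index bookkeeping concrete. This is a minimal sketch, not the MixedRadixSmall source. `six_step`, `transpose` and `dft_rows` are hypothetical names, and naive DFTs stand in for the sub-FFTs. The input is treated as `height` rows of length `width`, as in Algorithm 1.

```rust
use std::f64::consts::PI;
use std::ops::{Add, Mul};

#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct C64 { pub re: f64, pub im: f64 }

impl Add for C64 {
    type Output = C64;
    fn add(self, o: C64) -> C64 { C64 { re: self.re + o.re, im: self.im + o.im } }
}

impl Mul for C64 {
    type Output = C64;
    fn mul(self, o: C64) -> C64 {
        C64 { re: self.re * o.re - self.im * o.im, im: self.re * o.im + self.im * o.re }
    }
}

/// Forward twiddle factor W_n^k = e^(-2*pi*i*k/n).
pub fn twiddle(k: usize, n: usize) -> C64 {
    let angle = -2.0 * PI * k as f64 / n as f64;
    C64 { re: angle.cos(), im: angle.sin() }
}

/// Naive O(n^2) DFT standing in for the sub-FFTs.
pub fn dft(x: &[C64]) -> Vec<C64> {
    let n = x.len();
    (0..n)
        .map(|k| x.iter().enumerate().fold(C64::default(), |acc, (j, &v)| acc + v * twiddle(j * k % n, n)))
        .collect()
}

/// Out-of-place transpose of a row-major array with `height` rows of length `width`.
pub fn transpose(input: &[C64], width: usize, height: usize) -> Vec<C64> {
    let mut out = vec![C64::default(); input.len()];
    for r in 0..height {
        for c in 0..width {
            out[c * height + r] = input[r * width + c];
        }
    }
    out
}

/// DFT every contiguous row of length `row_len`, in place.
fn dft_rows(data: &mut [C64], row_len: usize) {
    for row in data.chunks_exact_mut(row_len) {
        let result = dft(row);
        row.copy_from_slice(&result);
    }
}

/// Six-step FFT of size N = width * height (A = width, B = height).
pub fn six_step(input: &[C64], width: usize, height: usize) -> Vec<C64> {
    let n = width * height;
    assert_eq!(input.len(), n);
    // 1. Transpose: each column of length B becomes a contiguous row.
    let mut data = transpose(input, width, height);
    // 2. Size-B DFTs along the (now contiguous) rows.
    dft_rows(&mut data, height);
    // 3. Twiddles: element (row c, column k1) is multiplied by W_N^(c * k1).
    for (i, x) in data.iter_mut().enumerate() {
        let (c, k1) = (i / height, i % height);
        *x = *x * twiddle(c * k1, n);
    }
    // 4. Transpose back to B rows of length A.
    let mut data = transpose(&data, height, width);
    // 5. Size-A DFTs along the rows.
    dft_rows(&mut data, width);
    // 6. Final transpose puts the output in natural order.
    transpose(&data, width, height)
}
```

Every DFT in this sketch reads a contiguous row. All of the strided access is concentrated in `transpose`, which is what makes the scalar version cache-friendly.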

The MixedRadixSmall struct in the RustFFT source code has an easy-to-read real-world example of Algorithm 2.

The key difference between Algorithm 2 and Algorithm 1 from the previous section is the transpose before step 2. Instead of computing FFTs down the columns, which requires strided, cache-unfriendly memory access, we transpose the data first, so that the strided columns become contiguous rows. As a result, this FFT algorithm is inherently cache-friendly.

When the whole problem fits comfortably in cache, other algorithms such as RustFFT's Radix4 algorithm (which is also a special case of the Mixed-Radix Algorithm) are faster, but the larger your FFT size is, the faster the six-step FFT will be compared to anything else.

The six-step FFT appears to be fastest when the two factors A and B are as close to sqrt(N) as possible, ideally with A = B. RustFFT's planner has code to divide the factors equally when it plans a step of Mixed Radix.
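As a rough illustration of "as close to sqrt(N) as possible", the hypothetical helper below picks the largest divisor of N that does not exceed sqrt(N). This is a sketch under that assumption only, not RustFFT's planner code.

```rust
/// Hypothetical helper: split `n` into (a, b) with a * b == n, a >= b,
/// and both factors as close to sqrt(n) as possible.
/// Returns (n, 1) when n is prime.
pub fn balanced_split(n: usize) -> (usize, usize) {
    assert!(n > 0);
    // Track the largest divisor d with d * d <= n.
    let mut best = 1;
    let mut d = 1;
    while d * d <= n {
        if n % d == 0 {
            best = d;
        }
        d += 1;
    }
    (n / best, best)
}
```

For a power of four such as 1024, this gives the ideal A = B = 32. For 360, it gives 20 * 18.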

SIMD Challenges

Unfortunately, even though Algorithm 2 is very fast for scalar operations, it's just about the worst possible architecture for SIMD optimization. All SIMD instruction sets are optimized around the idea of loading a handful of blocks of memory into registers, doing parallel pairwise operations until your data has been completely processed, and then writing it out and never touching it again.

As soon as you have to do one of:

  1. Swap data around within a SIMD register
  2. Load and store the same data more than once
  3. Load more than 10-15 chunks of memory at once (resulting in a lot of loads/stores to temporary memory, which violates rule #2)

You're going to see slowdowns. Any nontrivial application will have to do at least some of these, but as a general rule of thumb, if you're writing SIMD code, you should structure your algorithms to minimize them.

The problem is, Algorithm 2 is almost intentionally designed to break these rules. 3 of the 6 steps are transposes, and during a transpose you're doing nothing but violating rule #1. If we're recursively using this same algorithm in our sub-FFTs, then 5 of our 6 steps spend almost all of their time violating rule #1.

SIMD Twiddle Factors

Step 3 of Algorithm 2, where we apply twiddle factors, is well-suited to SIMD. You load some data from the FFT array, load some twiddle factors, multiply them together, then write the result back to the FFT array. All three rules are satisfied.
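The twiddle step can be sketched as scalar code on interleaved `f32` data, like the layout RustFFT uses. The sketch walks the data in chunks of 8 floats (4 complex numbers), the width of one 256-bit AVX register, to show the load/multiply/store pattern. `apply_twiddles` is a hypothetical name, and this is plain Rust, not AVX intrinsics.

```rust
/// Multiply interleaved complex data [re0, im0, re1, im1, ...] in place by
/// twiddle factors stored in the same layout.
pub fn apply_twiddles(data: &mut [f32], twiddles: &[f32]) {
    assert_eq!(data.len(), twiddles.len());
    assert_eq!(data.len() % 2, 0);
    // One chunk of 8 f32 = 4 complex numbers = one 256-bit AVX register's worth.
    for (d, t) in data.chunks_mut(8).zip(twiddles.chunks(8)) {
        // "Load" the chunk, multiply each complex pair, "store" the result.
        for (x, w) in d.chunks_exact_mut(2).zip(t.chunks_exact(2)) {
            let (a, b) = (x[0], x[1]);
            let (c, s) = (w[0], w[1]);
            x[0] = a * c - b * s;
            x[1] = a * s + b * c;
        }
    }
}
```

Each value is read once, multiplied, and written once, which is why this step satisfies all three rules. It also hints at the register problem described next: each iteration touches only one chunk of data and one chunk of twiddles.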

When I first began experimenting with SIMD code back in April 2020, switching this step to use SIMD intrinsics was one of the first things I tried. Unfortunately, the overall performance barely changed. The most obvious reason is that the twiddle factor step is such a small percentage of the FFT's overall time. But there's an equally serious problem to overcome:

AVX has 16 registers to work with, letting you store a total of 16 * 8 floats at once - but when you're just doing a pairwise multiplication of 2 arrays, you're only using 3-4 of those registers, maximum, at a time. Most of the CPU's ability to pipeline operations is sitting unused.

There are a few tricks like auto-unrolling that help make pairwise multiplication more pipelined, but that doesn't solve the transpose problem. I quickly realized that a different architecture would be needed, one which satisfied the three SIMD rules above while taking advantage of as much of the CPU as possible.

AVX-Friendly Architecture

One approach would be to take Algorithm 2 and skip steps 1 and 4, giving us the "Four Step FFT" algorithm:

Algorithm 3:

  1. Compute sub-FFTs of size B down the columns of our 2D array
  2. Multiply with precomputed twiddle factors
  3. Compute sub-FFTs of size A down the rows of our 2D array
  4. Transpose

This setup gets us somewhere, because we can start computing the size-B column FFTs via AVX: Use _mm256_loadu_ps() to load 4 complex numbers from row 0, 4 complex numbers from row 1, 4 complex numbers from row 2, etc. Load one AVX register from each row, and we can compute 4 size-B FFTs in parallel.

This doesn't quite work, because we have to be able to handle arbitrary values of B. We only have 16 AVX registers to work with, so if B is more than 16, then we're violating rule #3 from our SIMD rules up above.

Hardcoding a Specific Column FFT Size

The solution is to make B a compile-time constant: Instead of having to worry about handling arbitrary values of B all in one implementation, we can monomorphize our FFT algorithms for various values of B: One FFT function that factors N into A=N/2, B=2, another FFT function that factors N into A=N/3, B=3, another that factors N into A=N/4, B=4, etc.
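In Rust, one way to express "B is a compile-time constant" is with const generics. The sketch below uses them only to illustrate the idea; it is not taken from RustFFT's source. `column_dfts` and `dft_fixed` are hypothetical names. Because B is known at compile time, each column lives in a fixed-size stack array `[C64; B]`, the scalar analogue of holding B vectors in registers.

```rust
use std::f64::consts::PI;
use std::ops::{Add, Mul};

#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct C64 { pub re: f64, pub im: f64 }

impl Add for C64 {
    type Output = C64;
    fn add(self, o: C64) -> C64 { C64 { re: self.re + o.re, im: self.im + o.im } }
}

impl Mul for C64 {
    type Output = C64;
    fn mul(self, o: C64) -> C64 {
        C64 { re: self.re * o.re - self.im * o.im, im: self.re * o.im + self.im * o.re }
    }
}

/// Forward twiddle factor W_n^k = e^(-2*pi*i*k/n).
pub fn twiddle(k: usize, n: usize) -> C64 {
    let angle = -2.0 * PI * k as f64 / n as f64;
    C64 { re: angle.cos(), im: angle.sin() }
}

/// Naive DFT of compile-time size B, entirely in stack arrays.
fn dft_fixed<const B: usize>(x: &[C64; B]) -> [C64; B] {
    let mut out = [C64::default(); B];
    for (k, slot) in out.iter_mut().enumerate() {
        for (j, v) in x.iter().enumerate() {
            *slot = *slot + *v * twiddle(j * k % B, B);
        }
    }
    out
}

/// Hypothetical helper: DFTs of compile-time height B down every column of a
/// row-major array with `width` columns. Each B gets its own monomorphized copy.
pub fn column_dfts<const B: usize>(data: &mut [C64], width: usize) {
    assert_eq!(data.len(), B * width);
    for c in 0..width {
        // Gather column c: one strided load per row.
        let mut column = [C64::default(); B];
        for (r, slot) in column.iter_mut().enumerate() {
            *slot = data[r * width + c];
        }
        let out = dft_fixed(&column);
        // Scatter the results back to where they were loaded from.
        for (r, v) in out.iter().enumerate() {
            data[r * width + c] = *v;
        }
    }
}
```

The calls `column_dfts::<2>`, `column_dfts::<3>` and `column_dfts::<12>` would each compile to a separate function specialized for that column height.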

If we pick 12 as the value of B, for example, we get the following FFT algorithm.

Algorithm 4:

  1. Compute sub-FFTs of size 12 down the columns of our 2D array
    1.1. Load a single AVX vector from each of the 12 rows. Each vector can hold 8 floats, so we can load 4 columns' worth of data at a time.
    1.2. Compute 4 parallel FFTs of size 12 using SIMD instructions.
    1.3. Write our 12 AVX vectors back to where we originally loaded them.
  2. Multiply with precomputed twiddle factors
  3. Compute sub-FFTs of size N / 12 down the rows of our 2D array
  4. Transpose

Part of the reason why this setup is so fast is that in step 1.2, we're using 12 of our 16 AVX registers for FFT data, giving us 4 registers for temporary data, which turns out to be just barely enough. So as we compute these size-12 FFTs, we're doing so in a way that gives the compiler and processor the best possible opportunity to pipeline, parallelize, and reorder everything to be as fast as possible.

Merging Twiddle Factors

There is one final problem: We're still violating rule #2 from the SIMD rules earlier. Step 1 processes a bunch of data and writes it back to memory, and step 2 immediately loads that data back to do more operations on it. So the final step is to merge step 2 into step 1, giving us the algorithm found in RustFFT's AVX code:

Algorithm 5:

  1. Compute sub-FFTs of size 12 down the columns of our 2D array
    1.1. Load a single AVX vector from each of the 12 rows. Each vector can hold 8 floats, so we can load 4 columns' worth of data at a time.
    1.2. Compute 4 parallel FFTs of size 12 using SIMD instructions.
    1.3. Multiply with precomputed twiddle factors
    1.4. Write our 12 AVX vectors back to where we originally loaded them.
  2. Compute sub-FFTs of size N / 12 down the rows of our 2D array
  3. Transpose

I guess we could call this the "Three Step FFT With Four Substeps" algorithm. Not quite as catchy as Six Step FFT or Four Step FFT. I think the latter two have some satisfying alliteration that this new name doesn't quite capture.

We're still violating Rule #1 in the Step 3 transpose, and benchmarking shows that we do indeed spend around 40%-50% of our time on the transpose. I suspect that transposing, bit reversing, or some other nontrivial reordering of the data is a fact of life with FFT computation, but at least this way, it's only one of three steps, rather than most of them.

The MixedRadix8xnAvx, MixedRadix9xnAvx, and MixedRadix12xnAvx structs (and several others in that file) all implement Algorithm 5. The "8xn", "9xn", etc. in their names indicate that they factor the FFT size into 8 * n, 9 * n, and so on.

Composing Algorithms Into a FFT Plan

In Algorithm 5, if step 2 were a static function call, then we could only ever process FFTs of size 12^U, or more generally, B^U, for any given radix B. In order to be composable, we need some form of dynamic dispatch for these inner FFTs.

The final special sauce of RustFFT is that the inner FFT call in step 2 is a trait object of the Fft trait. The inner FFT of a MixedRadix12xn can be a MixedRadix5xn. The inner FFT of that MixedRadix5xn can be a hardcoded base case, optimized for a specific size. Any struct that implements the Fft trait can be an inner FFT for any other FFT, letting us compose FFTs however we need.
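Here is a hedged sketch of the composition idea using trait objects. The `Transform` trait, `Dft` and `MixedRadix` are hypothetical names modeled loosely on the description above; they are not RustFFT's `Fft` trait or its structs. The outer radix's column DFTs are naive for brevity. The inner transform is an `Arc<dyn Transform>`, so any implementation can be plugged in at any level.

```rust
use std::f64::consts::PI;
use std::ops::{Add, Mul};
use std::sync::Arc;

#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct C64 { pub re: f64, pub im: f64 }

impl Add for C64 {
    type Output = C64;
    fn add(self, o: C64) -> C64 { C64 { re: self.re + o.re, im: self.im + o.im } }
}

impl Mul for C64 {
    type Output = C64;
    fn mul(self, o: C64) -> C64 {
        C64 { re: self.re * o.re - self.im * o.im, im: self.re * o.im + self.im * o.re }
    }
}

pub fn twiddle(k: usize, n: usize) -> C64 {
    let angle = -2.0 * PI * k as f64 / n as f64;
    C64 { re: angle.cos(), im: angle.sin() }
}

pub fn dft(x: &[C64]) -> Vec<C64> {
    let n = x.len();
    (0..n)
        .map(|k| x.iter().enumerate().fold(C64::default(), |acc, (j, &v)| acc + v * twiddle(j * k % n, n)))
        .collect()
}

/// Hypothetical trait: anything that transforms a buffer of `len()` values in place.
pub trait Transform {
    fn len(&self) -> usize;
    fn process(&self, buffer: &mut [C64]);
}

/// Base case: a naive DFT of a fixed size.
pub struct Dft { pub size: usize }

impl Transform for Dft {
    fn len(&self) -> usize { self.size }
    fn process(&self, buffer: &mut [C64]) {
        let out = dft(buffer);
        buffer.copy_from_slice(&out);
    }
}

/// One mixed-radix step: column DFTs of size `height` with fused twiddles,
/// then whatever `inner` is (dynamic dispatch) along the rows, then a transpose.
pub struct MixedRadix { pub height: usize, pub inner: Arc<dyn Transform> }

impl Transform for MixedRadix {
    fn len(&self) -> usize { self.height * self.inner.len() }
    fn process(&self, buffer: &mut [C64]) {
        let (b, a, n) = (self.height, self.inner.len(), self.len());
        for c in 0..a {
            let column: Vec<C64> = (0..b).map(|r| buffer[r * a + c]).collect();
            for (k, v) in dft(&column).into_iter().enumerate() {
                buffer[k * a + c] = v * twiddle(c * k, n);
            }
        }
        for row in buffer.chunks_exact_mut(a) {
            self.inner.process(row);
        }
        let tmp = buffer.to_vec();
        for k1 in 0..b {
            for k2 in 0..a {
                buffer[k2 * b + k1] = tmp[k1 * a + k2];
            }
        }
    }
}
```

For example, `MixedRadix { height: 4, inner: Arc::new(MixedRadix { height: 3, inner: Arc::new(Dft { size: 5 }) }) }` computes a size-60 DFT as 4 * (3 * 5).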

Once we have this capability, however, there's a new challenge: If we have an FFT size with a prime factorization of 2^10 * 3^8, there are 18 total prime factors that we could arrange in different ways, and the result is a combinatoric explosion of choices. How do we know where to start?
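To put a number on it: suppose each of the 18 prime factors became its own radix. The number of distinct orderings would then be the multinomial coefficient 18! / (10! * 8!) = 43,758. That count does not yet include grouping factors into composite radices such as 12, which adds even more choices. A small hypothetical helper that computes the count:

```rust
/// Number of distinct orderings of a multiset of prime factors, given the
/// exponent of each prime: (sum of exponents)! / (e1! * e2! * ...).
pub fn factor_orderings(exponents: &[u32]) -> u128 {
    let mut result: u128 = 1;
    let mut placed: u128 = 0;
    for &e in exponents {
        // Multiply by C(placed + e, e) one factor at a time, so every
        // intermediate division is exact.
        for i in 1..=e as u128 {
            placed += 1;
            result = result * placed / i;
        }
    }
    result
}
```

For example, `factor_orderings(&[10, 8])` corresponds to 2^10 * 3^8. `factor_orderings(&[2, 1])` counts the three orderings of 12 = 2 * 2 * 3.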

To begin with, it might be best to know which of our FFT algorithms is fastest.

Algorithm Performance

RustFFT's AVX code has implementations for Mixed-Radix 2xn, 3xn, 4xn, 5xn, 6xn, 7xn, 8xn, 9xn, 11xn, 12xn, and 16xn.

Of these, 12xn is the fastest in benchmarks, followed closely by 8xn and 9xn. I suspect that they're faster than the others because, like I mentioned before, these algorithms do the best job of using CPU registers to their fullest potential. Also recall something I said way back at the beginning of this article: The Six-Step FFT appears to perform best when N is factored into A and B where A and B are close, and ideally equal. It stands to reason that the larger we can make the constant value in our monomorphized algorithms, the faster the resulting FFT will be.

With that in mind, one might think that 16xn would be the fastest -- but unfortunately, it is just barely too large to fit into AVX registers. We need 16 vectors for just the FFT data, and more for temporaries. The end result is that the CPU has to spill data to temporary memory during the computation, which violates rule #2 of our SIMD rules up above. These problems hamper 16xn just enough to make it lose out.

(Footnote: AVX-512 has 32 SIMD registers instead of 16, but that doesn't offer much hope, because in this architecture, a Bxn algorithm requires both B AVX registers storing FFT data, and B general-purpose registers storing array pointers. Our architecture thus doesn't give us an obvious way to use all 32 of those AVX-512 registers, since AVX-512 still only has 16 general-purpose registers. If you have any ideas here, I'd love to hear them!)

Planning Heuristics

With the benchmarking results in mind, the strategy for our FFT planner becomes clear: Pass the ball to Tucker MixedRadix12xnAvx. If there are no factors of 3 remaining in the FFT size, begrudgingly use 8xn instead. If instead there are no factors of 4 left, use 9xn.
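The rule of thumb above can be sketched as a tiny greedy function. `choose_radix` and `plan` are hypothetical names, and sizes are plain integers. The sketch ignores everything else a real planner has to decide, such as base cases, prime sizes and scratch space.

```rust
/// Hypothetical: pick the next radix to peel off `n`, following the rule of thumb.
pub fn choose_radix(n: usize) -> Option<usize> {
    if n % 12 == 0 {
        Some(12) // fastest in the benchmarks described above
    } else if n % 8 == 0 {
        Some(8) // divisible by 8 but not 12, so no factor of 3 is left
    } else if n % 9 == 0 {
        Some(9) // divisible by 9 but not 12, so no factor of 4 is left
    } else {
        None // leave the rest to some other algorithm or a base case
    }
}

/// Hypothetical greedy plan: returns the chosen radices (outermost first)
/// and the leftover size for the innermost FFT.
pub fn plan(mut n: usize) -> (Vec<usize>, usize) {
    let mut radices = Vec::new();
    loop {
        match choose_radix(n) {
            // Only peel a radix if something is left over for the inner FFT.
            Some(r) if n > r => {
                radices.push(r);
                n /= r;
            }
            _ => break,
        }
    }
    (radices, n)
}
```

Under this sketch, 1728 = 12^3 becomes 12xn, 12xn, then a size-12 remainder. A power of two such as 1024 falls back to 8xn three times with a size-2 remainder.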

In practice, there is an entire other planning challenge around how to choose a recursion base case -- and there's an entire separate architecture for writing hardcoded base cases -- but those seem to have a less meaningful impact on performance than most of the things discussed so far, and this article is already too long, so I'll stop here.

Unanswered Questions

  • FFTW beats RustFFT at a size of 512. 512 is a hardcoded base case in RustFFT, so it doesn't even have anything to do with the architecture this article discusses. How is their size-512 FFT so fast?

  • Is there a smarter way to plan which algorithm to use? "Pass the ball to 12xn" gets us to the point where we beat FFTW, but that model of heuristics (i.e., run benchmarks, find a pattern, and then hardcode it to follow the pattern) is really labor-intensive, and is subject to human bias in a number of ways.

  • Like I mentioned above, this architecture doesn't extend well to the 32 registers of AVX-512. How can we take advantage of those registers?

  • One thing this article didn't even touch is split-vs-interleaved complex numbers.

    • RustFFT stores its complex numbers interleaved, i.e. the memory layout looks something like [0re, 0im, 1re, 1im, 2re, 2im]. This is very convenient to work with at a high level, but it makes multiplication a little more complicated, because we have to shuffle data between lanes. If you remember the SIMD rules from the middle of the article, shuffling data between lanes is a violation of Rule #1.
    • An alternative memory layout is to store the imaginary parts in a separate array from the real parts. This would mean no data reshuffling when multiplying.
    • Is this alternative faster, even considering that you have to run O(N) pre and postprocessing steps to split the data apart and put it back together? Does it make the algorithms harder to understand? In the transpose steps, you'd have to do 8x8 f32 transposes instead of 4x4 complex transposes, which are considerably more complicated. Do the extra costs of those transposes outweigh the benefits?
  • Another issue this article didn't touch is f64 FFTs. RustFFT supports them using the exact same architecture, but there's a unique challenge around how to be generic over floating point type. RustFFT has the AvxVector trait which has common operations on it, so that the MixedRadix8xn structs can call these generic methods instead, which has worked out well.

    • One challenge that I haven't been able to solve is hardcoded base cases: In order to compute a size-12 f32 FFT, you need an array of 8 twiddle factors, represented as [__m256; 2]. In order to compute a size-12 f64 FFT, you need an array of 6 twiddle factors, represented as [__m256d; 3]. As far as I know, there's no proposal for specializing the layout of a struct based on a generic parameter, and in general there are countless challenges in the way of making hardcoded base cases generic.
    • Another challenge is that putting your SIMD intrinsics inside trait methods more or less requires you to make every single line of your AVX project unsafe. This is unreasonable, but unfortunately there's no solution for now. I opened an RFC to fix this, but I think that's going to (very understandably) move pretty slowly.
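The split-vs-interleaved question above can be made concrete with scalar code. In this sketch, all function names are hypothetical. In the interleaved version, each output needs the real and imaginary parts of both inputs from neighbouring slots, which in SIMD terms means shuffling within a register. In the split version, every line pairs element i with element i. The `deinterleave`/`interleave` pair is the O(N) pre- and post-processing cost mentioned in the bullets.

```rust
/// Complex multiply, interleaved layout [re0, im0, re1, im1, ...].
/// Each output mixes neighbouring slots, i.e. lane shuffles in SIMD code.
pub fn mul_interleaved(a: &[f32], b: &[f32], out: &mut [f32]) {
    for ((x, y), z) in a.chunks_exact(2).zip(b.chunks_exact(2)).zip(out.chunks_exact_mut(2)) {
        z[0] = x[0] * y[0] - x[1] * y[1];
        z[1] = x[0] * y[1] + x[1] * y[0];
    }
}

/// Complex multiply, split layout: real and imaginary parts in separate arrays,
/// so every operation is purely element-wise (no shuffles needed).
pub fn mul_split(a_re: &[f32], a_im: &[f32], b_re: &[f32], b_im: &[f32], out_re: &mut [f32], out_im: &mut [f32]) {
    for i in 0..out_re.len() {
        out_re[i] = a_re[i] * b_re[i] - a_im[i] * b_im[i];
        out_im[i] = a_re[i] * b_im[i] + a_im[i] * b_re[i];
    }
}

/// O(N) pre-processing: split interleaved data into (re, im) arrays.
pub fn deinterleave(data: &[f32]) -> (Vec<f32>, Vec<f32>) {
    let re = data.iter().step_by(2).copied().collect();
    let im = data.iter().skip(1).step_by(2).copied().collect();
    (re, im)
}

/// O(N) post-processing: interleave (re, im) arrays again.
pub fn interleave(re: &[f32], im: &[f32]) -> Vec<f32> {
    re.iter().zip(im).flat_map(|(&r, &i)| [r, i]).collect()
}
```

Both versions compute the same products. The open question in the bullets is whether the cheaper element-wise math in the split layout pays for the extra splitting and for the harder transposes.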

Thanks for reading! If you want to help answer these questions, or just have comments in general, leave a reply here or on reddit, or send me a private message on either of the two.

That's a really interesting write-up, thanks for sharing! I guess if you wanted alliteration, you could call Algorithm 5 a "Free-Step FFT" (sic). :smiley:

Wow! What a comprehensive and fascinating write up. Not that I understand most of it.

Mind you it was a decade or two between my seeing an FFT algorithm for the first time, in BASIC, and understanding how the FFT worked well enough to write my own from scratch. All that bit-reversal stuff seemed to come out of the blue and I could never grasp any explanation of it I found.

My FFT was written in assembler for the 32-bit Parallax Inc. Propeller micro-controller. It had to fit into 512 instructions and use integer-only arithmetic.

Versions of that in Rust, C and other languages are here: GitHub - ZiCog/fftbench: A parallel integer only FFT for multi-core micro-controllers like the Parallax Propeller or XMOS devices. Not sure where the assembler version is just now.

Thanks for the great contribution.

Oh, there is an update to the crates.io page required. The camel case needs fixing.

Should be using FftPlanner instead of FFTplanner.

Bigger radices can be used to get better cache performance. I don't recall the exact values, but they were something like radix 256 for L1 cache and radix 4K for L2 on some very old machines.

Many years ago I wrote an article for SIAM Review called Bit Reversal on Uniprocessors. In that paper I compared the performance of some 20+ bit reversal algorithms. I was shocked to find so many ways to do such a simple thing.

The term "bit reversal" is misleading. What you're doing is reordering the elements of the one dimensional array so the bits of its indices are reversed. So, if I have a set of indices 0, 1, 2, 3 (00, 01, 10, 11), the bit reversal gives 0, 2, 1, 3 (00, 10, 01, 11). The third element in the original array becomes the second in the bit reversed array.
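The reordering described above can be sketched in a few lines, assuming a power-of-two length. `reverse_bits` and `bit_reverse_permute` are hypothetical names; the first builds on the standard library's `usize::reverse_bits`. Swapping only when i < j visits each pair once, so the permutation is done in place.

```rust
/// Reverse the low `bits` bits of `i`.
pub fn reverse_bits(i: usize, bits: u32) -> usize {
    if bits == 0 {
        return 0;
    }
    // Reverse all bits of the word, then shift the interesting ones back down.
    i.reverse_bits() >> (usize::BITS - bits)
}

/// Bit-reversal permutation of a power-of-two length slice, in place.
pub fn bit_reverse_permute<T>(data: &mut [T]) {
    let n = data.len();
    assert!(n.is_power_of_two());
    let bits = n.trailing_zeros();
    for i in 0..n {
        let j = reverse_bits(i, bits);
        // Swap each pair only once.
        if i < j {
            data.swap(i, j);
        }
    }
}
```

For indices 0, 1, 2, 3, this produces 0, 2, 1, 3, matching the example above. For a length of 8, it produces 0, 4, 2, 6, 1, 5, 3, 7.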

Once the array is reordered, you can construct an FFT algorithm that doesn't need additional storage. That was a very big deal in the days when a mainframe with a MB of core storage was a rare commodity. There are also FFT algorithms that don't need a bit reversal, but they cannot be done in place.

If anyone is interested, I have an (unpublished) algorithm for doing mixed-radix, multidimensional FFTs at small stride without a transpose. It does need a bit reversal on each row, but the row will fit in cache except for the very largest problems.

The way I see it is that the "bit-reversal" thing comes about because in the recursive algorithm we divide and conquer by taking all the even elements together and all the odd elements together and calling ourselves again twice with those two sets. Which happens to end up working on pairs of elements as if they came from bit-reversed array indices.

In the simple 3 nested loop iterative algorithm we do the bit-reversal reordering to achieve the same effect.

Gives me headache thinking about it.

The Parallax Propeller MCU has an instruction to reverse the bits of a word. Can't remember if I used it or not.
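The "simple 3 nested loop" iterative algorithm mentioned in this reply can be sketched for power-of-two lengths. One pass reorders the input into bit-reversed index order. Then three nested loops (over sub-FFT size, sub-FFT start, and butterfly index) combine pairs in place. `fft_radix2` and `C64` are hypothetical names, and the naive `dft` is included only as a reference for checking.

```rust
use std::f64::consts::PI;
use std::ops::{Add, Mul, Sub};

#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct C64 { pub re: f64, pub im: f64 }

impl Add for C64 {
    type Output = C64;
    fn add(self, o: C64) -> C64 { C64 { re: self.re + o.re, im: self.im + o.im } }
}

impl Sub for C64 {
    type Output = C64;
    fn sub(self, o: C64) -> C64 { C64 { re: self.re - o.re, im: self.im - o.im } }
}

impl Mul for C64 {
    type Output = C64;
    fn mul(self, o: C64) -> C64 {
        C64 { re: self.re * o.re - self.im * o.im, im: self.re * o.im + self.im * o.re }
    }
}

pub fn twiddle(k: usize, n: usize) -> C64 {
    let angle = -2.0 * PI * k as f64 / n as f64;
    C64 { re: angle.cos(), im: angle.sin() }
}

/// Naive DFT, used only as a reference.
pub fn dft(x: &[C64]) -> Vec<C64> {
    let n = x.len();
    (0..n)
        .map(|k| x.iter().enumerate().fold(C64::default(), |acc, (j, &v)| acc + v * twiddle(j * k % n, n)))
        .collect()
}

/// In-place iterative radix-2 FFT for power-of-two lengths.
pub fn fft_radix2(data: &mut [C64]) {
    let n = data.len();
    assert!(n.is_power_of_two());
    // Bit-reversal reordering: j tracks the bit-reversed value of i.
    let mut j = 0;
    for i in 1..n {
        let mut bit = n >> 1;
        while j & bit != 0 {
            j ^= bit;
            bit >>= 1;
        }
        j |= bit;
        if i < j {
            data.swap(i, j);
        }
    }
    // Butterfly passes over sub-FFT sizes 2, 4, 8, ..., n.
    let mut len = 2;
    while len <= n {
        for start in (0..n).step_by(len) {
            for k in 0..len / 2 {
                let w = twiddle(k, len);
                let a = data[start + k];
                let b = data[start + k + len / 2] * w;
                data[start + k] = a + b;
                data[start + k + len / 2] = a - b;
            }
        }
        len <<= 1;
    }
}
```

After the reordering, every pass combines adjacent sub-FFTs. That is the same even/odd split the recursive version performs, just done bottom-up.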

You could achieve this with something like RustFFT's MixedRadix struct, with 256 as the height, and N/256 as the width. Easily doable with scalar code, I'm not quite sure how to get it done with AVX code though.

The algorithm vectorized very nicely on the Cray Y-MP and IBM 3090 Vector Unit. (I told you it was a long time ago.)

You can "specialize" the layout of structs by using associated types. Something like:

struct Size12Fft<T: Size12FftInfo> {
    twiddle: T::Twiddle,
    // other fields
}

trait Size12FftInfo {
    type Twiddle;
    // in case these are compile time constants, you could do
    // const TWIDDLE: Self::Twiddle;

    fn apply_twiddle(twiddle: Self::Twiddle, something_else_here: Args) -> Output { ... }
}

impl Size12FftInfo for f32 {
    type Twiddle = [__m256; 2];

    fn apply_twiddle(twiddle: Self::Twiddle, something_else_here: Args) -> Output { ... }
}

impl Size12FftInfo for f64 {
    type Twiddle = [__m256d; 3];

    fn apply_twiddle(twiddle: Self::Twiddle, something_else_here: Args) -> Output { ... }
}

I haven't looked into RustFFT enough to know how exactly to apply this, but this is the general idea.
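Here is a runnable, portable version of that idea. It assumes plain arrays shaped like the AVX registers stand in for `__m256` and `__m256d`, so it builds on any target. `Size12Twiddles` and `Size12Fft` follow the names in the reply above. The twiddle values W_12^k are filled in purely for illustration; they are not claimed to be the ones RustFFT's size-12 kernel uses.

```rust
use std::f64::consts::PI;

/// W_12^k = e^(-2*pi*i*k/12) as (re, im), computed in f64.
fn w12(k: usize) -> (f64, f64) {
    let angle = -2.0 * PI * k as f64 / 12.0;
    (angle.cos(), angle.sin())
}

/// Each float type chooses its own twiddle storage layout.
pub trait Size12Twiddles {
    type Twiddles: Copy;
    fn twiddles() -> Self::Twiddles;
}

impl Size12Twiddles for f32 {
    // 2 "registers" of 8 f32 = 8 interleaved complex numbers, shaped like [__m256; 2].
    type Twiddles = [[f32; 8]; 2];
    fn twiddles() -> Self::Twiddles {
        let mut t = [[0.0f32; 8]; 2];
        for k in 0..8 {
            let (re, im) = w12(k);
            t[k / 4][2 * (k % 4)] = re as f32;
            t[k / 4][2 * (k % 4) + 1] = im as f32;
        }
        t
    }
}

impl Size12Twiddles for f64 {
    // 3 "registers" of 4 f64 = 6 interleaved complex numbers, shaped like [__m256d; 3].
    type Twiddles = [[f64; 4]; 3];
    fn twiddles() -> Self::Twiddles {
        let mut t = [[0.0f64; 4]; 3];
        for k in 0..6 {
            let (re, im) = w12(k);
            t[k / 2][2 * (k % 2)] = re;
            t[k / 2][2 * (k % 2) + 1] = im;
        }
        t
    }
}

/// The field's type, and therefore the struct's layout, depends on T.
pub struct Size12Fft<T: Size12Twiddles> {
    pub twiddles: T::Twiddles,
}

impl<T: Size12Twiddles> Size12Fft<T> {
    pub fn new() -> Self {
        Size12Fft { twiddles: T::twiddles() }
    }
}
```

The struct's size then differs per float type: 64 bytes for `f32` and 96 bytes for `f64`. That is the kind of per-type layout the original post asked about.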

Where are the benchmarks showing the speed advantage over FFTW? I'm curious how large it was, and what system it was measured on. It would be interesting to see if it holds on M1.

Not published anywhere, but you can clone this and test them yourself: GitHub - ejmahler/fourier: Fast Fourier transforms (FFTs) in Rust

Hmm. I got this running cargo bench in fourier-bench:

error[E0061]: this function takes 0 arguments but 1 argument was supplied
   --> fourier-bench/benches/fft_bench.rs:40:31
    |
40  |                   let rustfft = rustfft::FftPlanner::<$type>::new(!forward).plan_fft(size);
    |                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -------- supplied 1 argument
    |                                 |
    |                                 expected 0 arguments
...
153 | / create_scenarios! {
154 | |     [pow2, "powers of two", &mut (6..14).map(|x| 2usize.pow(x))]
155 | |     [pow3, "powers of three", &mut (4..8).map(|x| 3usize.pow(x))]
156 | |     [pow5, "powers of five", &mut (2..6).map(|x| 5usize.pow(x))]
...   |
159 | |     [prime, "primes", &mut [191, 439, 1013].iter().map(|x| *x)]
160 | | }
    | |_- in this macro invocation
    |
    = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0061]: this function takes 2 arguments but 1 argument was supplied
   --> fourier-bench/benches/fft_bench.rs:40:75
    |
40  |                   let rustfft = rustfft::FftPlanner::<$type>::new(!forward).plan_fft(size);
    |                                                                             ^^^^^^^^ ---- supplied 1 argument
    |                                                                             |
    |                                                                             expected 2 arguments
...
153 | / create_scenarios! {
154 | |     [pow2, "powers of two", &mut (6..14).map(|x| 2usize.pow(x))]
155 | |     [pow3, "powers of three", &mut (4..8).map(|x| 3usize.pow(x))]
156 | |     [pow5, "powers of five", &mut (2..6).map(|x| 5usize.pow(x))]
...   |
159 | |     [prime, "primes", &mut [191, 439, 1013].iter().map(|x| *x)]
160 | | }
    | |_- in this macro invocation
    |
    = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0599]: no method named `process_inplace_with_scratch` found for struct `Arc<dyn Fft<f32>>` in the current scope
   --> fourier-bench/benches/fft_bench.rs:45:39
    |
45  |                       b.iter(|| rustfft.process_inplace_with_scratch(&mut buffer, &mut scratch))
    |                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ help: there is an associated function with a similar name: `process_outofplace_with_scratch`
...
153 | / create_scenarios! {
154 | |     [pow2, "powers of two", &mut (6..14).map(|x| 2usize.pow(x))]
155 | |     [pow3, "powers of three", &mut (4..8).map(|x| 3usize.pow(x))]
156 | |     [pow5, "powers of five", &mut (2..6).map(|x| 5usize.pow(x))]
...   |
159 | |     [prime, "primes", &mut [191, 439, 1013].iter().map(|x| *x)]
160 | | }
    | |_- in this macro invocation
    |
    = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0061]: this function takes 0 arguments but 1 argument was supplied
   --> fourier-bench/benches/fft_bench.rs:40:31
    |
40  |                   let rustfft = rustfft::FftPlanner::<$type>::new(!forward).plan_fft(size);
    |                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -------- supplied 1 argument
    |                                 |
    |                                 expected 0 arguments
...
153 | / create_scenarios! {
154 | |     [pow2, "powers of two", &mut (6..14).map(|x| 2usize.pow(x))]
155 | |     [pow3, "powers of three", &mut (4..8).map(|x| 3usize.pow(x))]
156 | |     [pow5, "powers of five", &mut (2..6).map(|x| 5usize.pow(x))]
...   |
159 | |     [prime, "primes", &mut [191, 439, 1013].iter().map(|x| *x)]
160 | | }
    | |_- in this macro invocation
    |
    = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0061]: this function takes 2 arguments but 1 argument was supplied
   --> fourier-bench/benches/fft_bench.rs:40:75
    |
40  |                   let rustfft = rustfft::FftPlanner::<$type>::new(!forward).plan_fft(size);
    |                                                                             ^^^^^^^^ ---- supplied 1 argument
    |                                                                             |
    |                                                                             expected 2 arguments
...
153 | / create_scenarios! {
154 | |     [pow2, "powers of two", &mut (6..14).map(|x| 2usize.pow(x))]
155 | |     [pow3, "powers of three", &mut (4..8).map(|x| 3usize.pow(x))]
156 | |     [pow5, "powers of five", &mut (2..6).map(|x| 5usize.pow(x))]
...   |
159 | |     [prime, "primes", &mut [191, 439, 1013].iter().map(|x| *x)]
160 | | }
    | |_- in this macro invocation
    |
    = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0599]: no method named `process_inplace_with_scratch` found for struct `Arc<dyn Fft<f64>>` in the current scope
   --> fourier-bench/benches/fft_bench.rs:45:39
    |
45  |                       b.iter(|| rustfft.process_inplace_with_scratch(&mut buffer, &mut scratch))
    |                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ help: there is an associated function with a similar name: `process_outofplace_with_scratch`
...
153 | / create_scenarios! {
154 | |     [pow2, "powers of two", &mut (6..14).map(|x| 2usize.pow(x))]
155 | |     [pow3, "powers of three", &mut (4..8).map(|x| 3usize.pow(x))]
156 | |     [pow5, "powers of five", &mut (2..6).map(|x| 5usize.pow(x))]
...   |
159 | |     [prime, "primes", &mut [191, 439, 1013].iter().map(|x| *x)]
160 | | }
    | |_- in this macro invocation
    |
    = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)

error: aborting due to 6 previous errors