Checking whether four 2-bit uints are unique: How to optimize?

The input is a u8 which is four 2-bit uints packed together.

I want to check whether the four bit fields are exactly 0, 1, 2, 3 in any order. Inputs that meet this requirement are considered valid.

(Additionally, if the highest-order bit field is 0 and the other three are exactly 0, 1, 2 in any order, the input is also considered valid. That makes 4! = 24 permutations plus 3! = 6 special-case values, so 30 of the 256 possible inputs are valid.)

Here's a naive implementation which gets the correct result (Rust Playground):

fn is_valid(packed: u8) -> bool {
    let u0 = packed & 0b11;
    let u1 = (packed >> 2) & 0b11;
    let u2 = (packed >> 4) & 0b11;
    let u3 = (packed >> 6) & 0b11;
    u0 != u1 && u0 != u2 && u1 != u2
        && (u0 != u3 && u1 != u3 && u2 != u3 || u3 == 0 && u0 != 3 && u1 != 3 && u2 != 3)
}

The inputs will mostly (>99%) be valid, so I want to optimize for the happy path; short-circuiting on invalid inputs buys practically nothing.

And here's how I benchmarked it using Criterion:

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use my_crate::is_valid; // hypothetical path: import is_valid from the crate under test

fn criterion_benchmark(c: &mut Criterion) {
    let valid_values: Vec<_> = (0..=255).filter(|&u| is_valid(u)).collect();
    let mut samples = vec![0];
    for _ in 0..3 {
        samples.extend_from_slice(&valid_values);
    }
    c.bench_function("is_valid", |b| {
        b.iter(|| {
            samples[0] += 1;
            for &u in &samples {
                black_box(is_valid(u));
            }
        })
    });
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

By the way, I've tried using enums with num_derive or num_enum; here are the results in brief (the first entry is the naive implementation):

is_valid                time:   [176.17 ns 178.18 ns 180.44 ns]
                        change: [-4.0889% -2.6412% -1.2874%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 4 outliers among 100 measurements (4.00%)
  1 (1.00%) low mild
  1 (1.00%) high mild
  2 (2.00%) high severe

is_valid_enum_num_enum  time:   [196.51 ns 198.01 ns 199.60 ns]
                        change: [-2.2483% -1.0571% +0.1369%] (p = 0.09 > 0.05)
                        No change in performance detected.
Found 4 outliers among 100 measurements (4.00%)
  4 (4.00%) high mild

is_valid_enum_num_derive
                        time:   [346.85 ns 351.05 ns 356.00 ns]
                        change: [-1.0824% +0.4368% +1.9960%] (p = 0.58 > 0.05)
                        No change in performance detected.
Found 5 outliers among 100 measurements (5.00%)
  5 (5.00%) high mild

Store the valid patterns in a 256-bit bitvector and index it using the u8 input.
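
For reference, here's a minimal sketch of that suggestion: four u64 words hold one bit per possible u8, the table is built at compile time from a const copy of the naive check, and all names here are made up.

pub const fn is_valid_naive(packed: u8) -> bool {
    // same pairwise-comparison logic as the naive implementation above
    let u0 = packed & 0b11;
    let u1 = (packed >> 2) & 0b11;
    let u2 = (packed >> 4) & 0b11;
    let u3 = (packed >> 6) & 0b11;
    u0 != u1 && u0 != u2 && u1 != u2
        && (u0 != u3 && u1 != u3 && u2 != u3 || u3 == 0 && u0 != 3 && u1 != 3 && u2 != 3)
}

// 256-bit bitvector: bit i is set iff i is a valid packed value.
const TABLE: [u64; 4] = {
    let mut t = [0u64; 4];
    let mut i = 0usize;
    while i < 256 {
        t[i / 64] |= (is_valid_naive(i as u8) as u64) << (i % 64);
        i += 1;
    }
    t
};

pub fn is_valid_bitvec(packed: u8) -> bool {
    // high 2 bits pick the word, low 6 bits pick the bit within it
    (TABLE[(packed >> 6) as usize] >> (packed & 63)) & 1 == 1
}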

7 Likes

Good point. However, I have one more question.

I've checked the asm generated by is_valid_enum_num_enum based on the TryFrom impl derived with num_enum, and it seems to use a jump table.

Why is it so much slower than directly using an array then?

I was curious, so I experimented with the bitvector approach. I noticed that its runtime performance is slightly slower than using an array of bool. I also experimented with different word types (u16, u32, u64, u128) for the bitvector; u32 performed best.

#[inline]
pub const fn is_valid_bitvec_u32(packed: u8) -> bool {
    const BITVEC: [u32; 8] = const {
        let mut table = [0u32; 8];
        let mut i = 0;
        while i < 256 {
            // high 3 bits select the u32 word, low 5 bits select the bit
            let outer = (i as u8) >> 5;
            let inner = (i as u8) & 0b00011111;
            table[outer as usize] |= (is_valid_inner(i as u8) as u32) << inner;
            i += 1;
        }
        table
    };
    (BITVEC[(packed >> 5) as usize] >> (packed & 0b00011111)) & 1 == 1
}

#[inline]
pub const fn is_valid_bool_array(value: u8) -> bool {
    const VALID: [bool; 256] = const {
        let mut table = [false; 256];
        let mut i = 0;
        while i < 256 {
            table[i] = is_valid_inner(i as u8);
            i += 1;
        }
        table
    };
    VALID[value as usize]
}

pub const fn is_valid_inner(packed: u8) -> bool {
    let u0 = packed & 0b11;
    let u1 = (packed >> 2) & 0b11;
    let u2 = (packed >> 4) & 0b11;
    let u3 = (packed >> 6) & 0b11;
    u0 != u1
        && u0 != u2
        && u1 != u2
        && (u0 != u3 && u1 != u3 && u2 != u3 || u3 == 0 && u0 != 3 && u1 != 3 && u2 != 3)
}

#[test]
fn all_valid() {
    for i in 0..=255 {
        assert_eq!(is_valid_inner(i), is_valid_bitvec_u32(i), "{i}");
        assert_eq!(is_valid_inner(i), is_valid_bool_array(i), "{i}");
    }
}

     Running benches/bench.rs (target/release/deps/bench-5a85348521db05f5)
Gnuplot not found, using plotters backend
is_valid_inner          time:   [1.3831 µs 1.3919 µs 1.4015 µs]
                        change: [-6.5454% -1.0176% +4.6803%] (p = 0.73 > 0.05)
                        No change in performance detected.
Found 2 outliers among 100 measurements (2.00%)
  1 (1.00%) high mild
  1 (1.00%) high severe

is_valid_bool_array     time:   [381.54 ns 381.76 ns 381.99 ns]
                        change: [-3.2278% -2.2605% -1.4711%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 11 outliers among 100 measurements (11.00%)
  5 (5.00%) high mild
  6 (6.00%) high severe

is_valid_bitvec_u32     time:   [464.61 ns 465.00 ns 465.51 ns]
                        change: [-3.3932% -2.8046% -2.3223%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 15 outliers among 100 measurements (15.00%)
  7 (7.00%) low mild
  4 (4.00%) high mild
  4 (4.00%) high severe

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use my_crate::{is_valid_inner, is_valid_bool_array, is_valid_bitvec_u32}; // hypothetical path: import the functions under test

fn criterion_benchmark(c: &mut Criterion) {
    let mut samples = vec![0];
    samples.extend(0..=255);
    samples.extend(0..=255);
    samples.extend(0..=255);

    c.bench_function("is_valid_inner", |b| {
        b.iter(|| {
            samples[0] += 1;
            for &u in &samples {
                black_box(is_valid_inner(u));
            }
        })
    });
    c.bench_function("is_valid_bool_array", |b| {
        b.iter(|| {
            samples[0] += 1;
            for &u in &samples {
                black_box(is_valid_bool_array(u));
            }
        })
    });
    c.bench_function("is_valid_bitvec_u32", |b| {
        b.iter(|| {
            samples[0] += 1;
            for &u in &samples {
                black_box(is_valid_bitvec_u32(u));
            }
        })
    });
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
1 Like

Note that I modified the benchmark a bit to remove the part where you filter for only "valid" values.

I deliberately used 90 valid samples plus 1 changing sample, which is valid only 30/256 of the time.

Because that's what I would get in practice: valid data most of the time, with corrupted data as a rare case.

I just don't want to go the unsafe fn new_unchecked way.

This may reverse if the table doesn't always sit in the L1 cache due to other things going on in the program.

6 Likes

Is your data in a [u8]? If so, you should have your function take a whole &[u8] so that you can try things like checking large chunks of bytes, then if there's an invalid byte, going back one-by-one to see which one it was. It also lets you exploit patterns in the input (if there are any) and write SIMD-friendly code. It's not guaranteed that the fastest function on one byte is also the fastest function when checking many bytes in a row.

Thanks for the idea. Unfortunately, the packed u8 values are scattered around before getting filtered, and after the validity check each one gets embedded into a much larger struct.

While I could collect them into a [u8], try what you suggested, and map them back afterwards, the collecting and mapping introduce non-trivial overhead, and it would break encapsulation badly, which is what I'm trying to avoid.

1 Like

This is branchless (no branches in the valid case) and may be faster than the lookup table:

pub fn is_valid(packed: u8) -> bool {
    let u0 = packed & 0b11;
    let u1 = (packed >> 2) & 0b11;
    let u2 = (packed >> 4) & 0b11;
    let u3 = (packed >> 6) & 0b11;
    // One bit per field value: the fields are a permutation of 0..=3
    // exactly when all four low bits end up set.
    let m012 = (1 << u0) | (1 << u1) | (1 << u2);
    // The second clause covers the special case from the problem
    // statement (u3 == 0 and the other three are exactly {0, 1, 2});
    // non-short-circuiting | and & keep the valid path branch-free.
    ((m012 | (1 << u3)) == 0b1111) | ((u3 == 0) & (m012 == 0b0111))
}

Then you can add u64-based group processing like this:

pub const fn are_valid_u64(x: u64) -> bool {
    const MASK: u64 = u64::from_ne_bytes([0b11; 8]);
    const PARITY_MASK: u64 = u64::from_ne_bytes([1; 8]);
    const PARITY_SUM: u64 = u64::from_ne_bytes([2; 8]);
    const SUM: u64 = u64::from_ne_bytes([1 + 2 + 3; 8]);
    // Extract the k-th 2-bit field of every byte into that byte's low
    // bits; written this way so the fn is const.
    let masked = [
        (x >> 0) & MASK,
        (x >> 2) & MASK,
        (x >> 4) & MASK,
        (x >> 6) & MASK,
    ];
    // Per-byte parity of each field (low bit XOR high bit).
    let parity = [
        (masked[0] ^ (masked[0] >> 1)) & PARITY_MASK,
        (masked[1] ^ (masked[1] >> 1)) & PARITY_MASK,
        (masked[2] ^ (masked[2] >> 1)) & PARITY_MASK,
        (masked[3] ^ (masked[3] >> 1)) & PARITY_MASK,
    ];
    // Per-byte sums cannot carry into the next byte (max 12 and 4).
    let sum = masked[0] + masked[1] + masked[2] + masked[3];
    let parity_sum = parity[0] + parity[1] + parity[2] + parity[3];
    // {0,1,2,3} is the only multiset of four 2-bit values with field
    // sum 6 and parity sum 2 (the other sum-6 multisets {3,3,0,0},
    // {3,1,1,1}, {2,2,2,0}, {2,2,1,1} all fail the parity test).
    // Note: bytes that are valid only via the u3 == 0 special case are
    // rejected here and must be re-checked per byte.
    (sum == SUM) & (parity_sum == PARITY_SUM)
}
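
A quick sanity check (test name made up) that the packed check agrees with a per-byte permutation check when one byte is broadcast to all eight lanes:

#[test]
fn u64_matches_scalar() {
    for i in 0..=255u8 {
        // The chunk is all copies of one byte, so it is valid exactly
        // when that byte's four fields are a permutation of 0..=3.
        let x = u64::from_ne_bytes([i; 8]);
        let mask = (1u8 << (i & 0b11))
            | (1 << ((i >> 2) & 0b11))
            | (1 << ((i >> 4) & 0b11))
            | (1 << ((i >> 6) & 0b11));
        assert_eq!(are_valid_u64(x), mask == 0b1111, "{i}");
    }
}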

and combine them for fast processing of slices until an invalid input is found:

pub fn until_invalid(mut x: &[u8]) -> &[u8] {
    // Fast path: validate eight bytes at a time.
    while let Some((chunk, rest)) = x.split_first_chunk() {
        if are_valid_u64(u64::from_ne_bytes(*chunk)) {
            x = rest;
        } else {
            break;
        }
    }
    // Slow path: re-check the remaining bytes one at a time (this also
    // catches bytes that are valid only via the u3 == 0 special case).
    while let Some((first, rest)) = x.split_first() {
        if is_valid(*first) {
            x = rest;
        } else {
            break;
        }
    }
    x
}
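
For example (made-up data), until_invalid returns the suffix starting at the first invalid byte, so an empty result means the whole slice validated:

#[test]
fn until_invalid_example() {
    // 0b11_10_01_00 packs the fields 0, 1, 2, 3 (valid).
    let data = [0b11_10_01_00u8; 16];
    assert!(until_invalid(&data).is_empty());

    // Corrupt one byte: 0 packs the fields 0, 0, 0, 0 (invalid).
    let mut corrupted = data;
    corrupted[9] = 0;
    assert_eq!(until_invalid(&corrupted), &corrupted[9..]);
}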
2 Likes

What does the final destination look like? What happens if invalid values are found? Is this actually a bottleneck? Your problem description sounds like gathering the values themselves will be the bottleneck, because accessing scattered values is typically not cache-friendly.

A Vec of structs. The packed u8 (wrapped into a newtype after validity check) is one of the struct's fields.

The construction of the struct fails. We can choose to skip the corrupted entry or to fail fast when we fill the Vec.

Honestly, no. I'm just trying to learn some optimization techniques by asking.

It's json deserialization. The json data itself is a continuous Bytes buffer (Deref<Target = [u8]>), but the (packed) numeric fields are scattered within it.

The json data provider is out of our control.
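
For what it's worth, a minimal sketch of the newtype construction being described, using the thread's is_valid for the check (the type name is made up):

#[derive(Clone, Copy, Debug)]
pub struct PackedFields(u8);

impl TryFrom<u8> for PackedFields {
    type Error = u8;

    // Validate once at the boundary; afterwards a PackedFields is
    // valid by construction, with no unsafe new_unchecked needed.
    fn try_from(raw: u8) -> Result<Self, Self::Error> {
        if is_valid(raw) {
            Ok(PackedFields(raw))
        } else {
            Err(raw)
        }
    }
}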

Without group processing, this is even slower than the naive approach:

pub const fn is_valid_branchless(packed: u8) -> bool {
    let packed = packed as u64;

    const MASK: u64 = 0b11;
    const PARITY_MASK: u64 = 1;
    const PARITY_SUM: u64 = 2;
    const SUM: u64 = 1 + 2 + 3;
    // written this way so fn is const
    let masked = [
        packed & MASK,
        (packed >> 2) & MASK,
        (packed >> 4) & MASK,
        (packed >> 6) & MASK,
    ];
    let parity = [
        (masked[0] ^ (masked[0] >> 1)) & PARITY_MASK,
        (masked[1] ^ (masked[1] >> 1)) & PARITY_MASK,
        (masked[2] ^ (masked[2] >> 1)) & PARITY_MASK,
        (masked[3] ^ (masked[3] >> 1)) & PARITY_MASK,
    ];
    let sum = masked[0] + masked[1] + masked[2] + masked[3];
    let parity_sum = parity[0] + parity[1] + parity[2] + parity[3];
    // Like are_valid_u64, this only accepts permutations of 0..=3 and
    // misses the u3 == 0 special case.
    (sum == SUM) & (parity_sum == PARITY_SUM)
}

I wonder how we can parallelize those checks for a single u8.

Which is slower? The is_valid with the bitset? This surprises me, because the pairwise-comparison variant has three branches in the success case: Compiler Explorer

It might be that your CPU's branch predictor learns the order of valid values, or rather the frequency of taken branches, in the benchmark. From your problem description I would assume that you cannot measure a consistent difference in the actual application. I would go for the most readable solution.

Generally, from a performance perspective you have to take into consideration that the CPU resources you use for this problem, like ALUs, the branch prediction buffer, and cache, are not available for the other tasks you need to perform "at the same time". This is the largest shortcoming of microbenchmarking subtasks.

3 Likes

If you're willing to use nightly, there are a couple of tricks you can use with llvm-mca and comments to get it to show you not just the assembly, but also llvm-mca's analysis of how many CPU resources each function takes.

The comments delimit assembly blocks to analyze; the llvm-mca documentation tells you what it can do, and what the output shows you where it's not obvious (e.g. in this case, is_valid_bitset takes 3.82 clock cycles per iteration, while is_valid_compare takes 9.10 clock cycles per iteration).
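
For example, one way to get such comment markers into the assembly is with empty inline-asm comments around the code under analysis (a sketch; the wrapper name is made up, and # starts a comment in x86 GAS syntax only):

use std::arch::asm;

pub fn is_valid_marked(packed: u8) -> bool {
    // llvm-mca analyzes only the region between these two markers in
    // the emitted assembly (e.g. cargo rustc --release -- --emit asm).
    unsafe { asm!("# LLVM-MCA-BEGIN is_valid", options(nomem, nostack, preserves_flags)) };
    let result = is_valid(packed);
    unsafe { asm!("# LLVM-MCA-END", options(nomem, nostack, preserves_flags)) };
    result
}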

2 Likes

This topic was automatically closed 90 days after the last reply. We invite you to open a new topic if you have further questions or comments.