How to find the common prefix of two byte slices efficiently?

Inspired by the code from the thread "Knuth Morris Pratt algorithm with C++ and Rust [Solved]", I'd like to ask the following question: what is the fastest way to find the first difference between two byte slices? That is, can the following be made faster?

pub fn prefix(xs: &[u8], ys: &[u8]) -> usize {
    xs.iter().zip(ys)
      .take_while(|(x, y)| x == y)
      .count()
}

Compiler Explorer shows that this code compiles to a loop that compares a single byte at a time. I expect this could be done quite a bit faster using SIMD, or just by comparing a u64 at a time?

I am surprised that there isn't a function for this already. There's memcmp, but it doesn't reveal the location of the mismatch (although internally it must have this information?).
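To make the "u64 at a time" idea concrete, here is a minimal sketch of what such a version might look like. The names `prefix_u64` and `load_u64` are made up for illustration, and nothing here is benchmarked; it only shows the technique of XOR-ing eight bytes at once and locating the first differing byte from the lowest set bit.

```rust
/// Reads the first eight bytes of `bytes` as a little-endian u64.
/// (Hypothetical helper; panics if `bytes` is shorter than 8.)
fn load_u64(bytes: &[u8]) -> u64 {
    let mut buf = [0u8; 8];
    buf.copy_from_slice(&bytes[..8]);
    u64::from_le_bytes(buf)
}

/// Length of the common prefix, comparing eight bytes per step.
pub fn prefix_u64(xs: &[u8], ys: &[u8]) -> usize {
    let len = xs.len().min(ys.len());
    let mut off = 0;
    while off + 8 <= len {
        let diff = load_u64(&xs[off..]) ^ load_u64(&ys[off..]);
        if diff != 0 {
            // With a little-endian load, the lowest set bit lies in the
            // first byte (in memory order) where the slices differ.
            return off + (diff.trailing_zeros() / 8) as usize;
        }
        off += 8;
    }
    // Fewer than eight bytes remain: finish byte by byte.
    off + xs[off..len]
        .iter()
        .zip(&ys[off..len])
        .take_while(|(x, y)| x == y)
        .count()
}
```

Loading with `from_le_bytes` rather than a native-endian read keeps the bit arithmetic the same on every target.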


If you use array references instead of slices you will not see a loop at all.
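For illustration, an array-reference version might look like the sketch below. The name `prefix_fixed` and the length 16 are arbitrary choices, not something from the original question. Since the length is known at compile time, the compiler is free to unroll the comparison completely, though what it actually emits will depend on the compiler version and the array length.

```rust
/// Common prefix length of two fixed-size arrays.
/// The length (16) is an arbitrary choice for this sketch.
pub fn prefix_fixed(xs: &[u8; 16], ys: &[u8; 16]) -> usize {
    xs.iter().zip(ys).take_while(|(x, y)| x == y).count()
}
```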

I still don't know a good answer to this question, but I got annoyed enough to quickly code up a SIMD solution:

#[inline(never)]
pub fn mismatch(xs: &[u8], ys: &[u8]) -> usize {
    xs.iter().zip(ys).take_while(|(x, y)| x == y).count()
}

#[inline(never)]
pub fn mismatch_simd(xs: &[u8], ys: &[u8]) -> usize {
    let l = xs.len().min(ys.len());
    let mut xs = &xs[..l];
    let mut ys = &ys[..l];
    let mut off = 0;

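    // SAFETY: assumes the CPU supports AVX2 (e.g. when built with
    // -Ctarget-cpu=native); every load reads 32 bytes from slices that
    // are at least 32 bytes long.
    unsafe {
        use std::arch::x86_64::*;

        let zero = _mm256_setzero_si256();
        while xs.len() >= 32 {
            // Unaligned 32-byte loads from each slice.
            let x = _mm256_loadu_si256(xs.as_ptr() as _);
            let y = _mm256_loadu_si256(ys.as_ptr() as _);

            // XOR is zero exactly where the bytes match; after the
            // compare and movemask, bit i of r is set iff byte i matched.
            let r = _mm256_xor_si256(x, y);
            let r = _mm256_cmpeq_epi8(r, zero);
            let r = _mm256_movemask_epi8(r);
            if r.trailing_ones() < 32 {
                return off + r.trailing_ones() as usize;
            }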

            xs = &xs[32..];
            ys = &ys[32..];
            off += 32;
        }
    }
    off + mismatch(xs, ys)
}

#[test]
fn smoke() {
    let xs = "all the work and no play made jack a dull boy"
        .repeat(3)
        .into_bytes();
    let ys = [xs.clone(), xs.clone()].concat();
    for i in 0..ys.len() {
        let mut ys = ys.clone();
        ys[i] = 0;
        assert_eq!(mismatch(&xs, &ys), mismatch_simd(&xs, &ys))
    }
}

#[test]
fn bench() {
    let n = 100_000;
    let m = 100;
    let mut xs = "Hello, world".repeat(n).into_bytes();
    let mut ys = xs.clone();
    xs.push(b'x');
    ys.extend(b"ijk");

    let t = std::time::Instant::now();
    let mut res1 = 0;
    for _ in 0..m {
        res1 += mismatch(&xs, &ys);
    }
    eprintln!("naive {:0.2?}", t.elapsed());

    let t = std::time::Instant::now();
    let mut res2 = 0;
    for _ in 0..m {
        res2 += mismatch_simd(&xs, &ys);
    }
    eprintln!("simd  {:0.2?}", t.elapsed());
    assert_eq!(res1, res2);
}

λ RUSTFLAGS="-Ctarget-cpu=native" cargo t -r --lib -- bench --nocapture
   Compiling m v0.1.0 (/home/matklad/tmp/m)
    Finished release [optimized] target(s) in 0.26s
     Running unittests src/lib.rs (target/release/deps/m-6997e6fd8122178e)

running 1 test
naive 26.98ms
simd  3.59ms
test bench ... ok

Note: I think this is my first SIMD program, so it is definitely not production-ready.
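One thing that would need addressing before production use is that the function above calls AVX2 intrinsics unconditionally. A possible way to guard this, sketched below, is to compile the AVX2 loop with `#[target_feature(enable = "avx2")]` and select it at runtime with `is_x86_feature_detected!`. The function names (`mismatch_dispatch`, `mismatch_avx2`, `mismatch_scalar`) are hypothetical, and this is an untuned sketch rather than a drop-in replacement. It also compares the two vectors directly instead of XOR-ing them first.

```rust
/// Byte-at-a-time fallback.
fn mismatch_scalar(xs: &[u8], ys: &[u8]) -> usize {
    xs.iter().zip(ys).take_while(|(x, y)| x == y).count()
}

/// AVX2 loop; callers must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mismatch_avx2(xs: &[u8], ys: &[u8]) -> usize {
    use std::arch::x86_64::*;

    let len = xs.len().min(ys.len());
    let mut off = 0;
    while off + 32 <= len {
        // SAFETY: off + 32 <= len, so both 32-byte loads stay in bounds.
        let eq = unsafe {
            let x = _mm256_loadu_si256(xs.as_ptr().add(off) as *const __m256i);
            let y = _mm256_loadu_si256(ys.as_ptr().add(off) as *const __m256i);
            // Bit i of the mask is set iff byte i of x equals byte i of y.
            _mm256_movemask_epi8(_mm256_cmpeq_epi8(x, y))
        };
        if eq != -1 {
            // Not all 32 bytes matched; the run of low set bits is the
            // number of matching bytes before the first mismatch.
            return off + eq.trailing_ones() as usize;
        }
        off += 32;
    }
    off + mismatch_scalar(&xs[off..len], &ys[off..len])
}

/// Uses the AVX2 loop when the running CPU has it, else the scalar loop.
pub fn mismatch_dispatch(xs: &[u8], ys: &[u8]) -> usize {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just verified at runtime.
            return unsafe { mismatch_avx2(xs, ys) };
        }
    }
    mismatch_scalar(xs, ys)
}
```

With this shape the intrinsics can be inlined into `mismatch_avx2` even when the crate is built without `-Ctarget-cpu=native`, and other CPUs or architectures fall back to the scalar loop.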


nkkarpov (Nikolai Karpov) found a much better solution:

use std::iter;

pub fn mismatch(xs: &[u8], ys: &[u8]) -> usize {
    mismatch_chunks::<128>(xs, ys)
}

fn mismatch_chunks<const N: usize>(xs: &[u8], ys: &[u8]) -> usize {
    let off = iter::zip(xs.chunks_exact(N), ys.chunks_exact(N))
        .take_while(|(x, y)| x == y)
        .count()
        * N;
    off + iter::zip(&xs[off..], &ys[off..])
        .take_while(|(x, y)| x == y)
        .count()
}

By just manually chunking the iterator, we give the compiler enough of a hint to auto-vectorise this.
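Since the chunked version has two phases (whole chunks, then a byte-wise tail), it seems worth cross-checking it against the naive loop at every mismatch position, in the spirit of the earlier smoke test. The sketch below does that; the names `naive`, `chunked` and `agrees_everywhere` are made up for this example, and the chunk sizes 16 and 128 are arbitrary.

```rust
use std::iter;

/// Byte-at-a-time reference implementation.
fn naive(xs: &[u8], ys: &[u8]) -> usize {
    xs.iter().zip(ys).take_while(|(x, y)| x == y).count()
}

/// Chunked variant: skip equal N-byte chunks, then scan the tail.
fn chunked<const N: usize>(xs: &[u8], ys: &[u8]) -> usize {
    let off = iter::zip(xs.chunks_exact(N), ys.chunks_exact(N))
        .take_while(|(x, y)| x == y)
        .count()
        * N;
    off + iter::zip(&xs[off..], &ys[off..])
        .take_while(|(x, y)| x == y)
        .count()
}

/// Returns true if both versions agree for a mismatch planted at
/// every position of the longer input.
pub fn agrees_everywhere() -> bool {
    let xs = "all the work and no play made jack a dull boy"
        .repeat(3)
        .into_bytes();
    let ys = [xs.clone(), xs.clone()].concat();
    (0..ys.len()).all(|i| {
        let mut ys = ys.clone();
        ys[i] = 0;
        let expected = naive(&xs, &ys);
        expected == chunked::<16>(&xs, &ys) && expected == chunked::<128>(&xs, &ys)
    })
}
```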


I would be interested to see the same benchmark run on that third solution.
That would make it easier to compare with the naive and SIMD solutions.

As an aside: where can I read up on actually using SIMD in Rust? It would be useful for me to get SIMD up and running with wasm for example.

λ cargo t -r -- bench --nocapture
naive      662.42ms
simd       741.04ms
chunk      83.34ms
test bench ... ok

λ RUSTFLAGS="-Ctarget-cpu=native" cargo t -r -- bench --nocapture
naive      654.31ms
simd       80.46ms
chunk      81.20ms
test bench ... ok

use std::iter;

#[inline(never)]
pub fn mismatch(xs: &[u8], ys: &[u8]) -> usize {
    xs.iter().zip(ys).take_while(|(x, y)| x == y).count()
}

#[inline(never)]
pub fn mismatch_simd(xs: &[u8], ys: &[u8]) -> usize {
    let l = xs.len().min(ys.len());
    let mut xs = &xs[..l];
    let mut ys = &ys[..l];
    let mut off = 0;

    unsafe {
        use std::arch::x86_64::*;

        let zero = _mm256_setzero_si256();
        while xs.len() >= 32 {
            let x = _mm256_loadu_si256(xs.as_ptr() as _);
            let y = _mm256_loadu_si256(ys.as_ptr() as _);

            let r = _mm256_xor_si256(x, y);
            let r = _mm256_cmpeq_epi8(r, zero);
            let r = _mm256_movemask_epi8(r);
            if r.trailing_ones() < 32 {
                return off + r.trailing_ones() as usize;
            }

            xs = &xs[32..];
            ys = &ys[32..];
            off += 32;
        }
    }
    off + mismatch(xs, ys)
}

#[inline(never)]
pub fn mismatch_chunked(xs: &[u8], ys: &[u8]) -> usize {
    fn inner<const N: usize>(xs: &[u8], ys: &[u8]) -> usize {
        let off = iter::zip(xs.chunks_exact(N), ys.chunks_exact(N))
            .take_while(|(x, y)| x == y)
            .count()
            * N;
        off + iter::zip(&xs[off..], &ys[off..])
            .take_while(|(x, y)| x == y)
            .count()
    }

    inner::<128>(xs, ys)
}

#[test]
fn bench() {
    fn bench_mismatch(name: &str, f: fn(&[u8], &[u8]) -> usize) {
        let n = 500_000;
        let m = 500;
        let mut xs = "Hello, world".repeat(n).into_bytes();
        let mut ys = xs.clone();
        xs.push(b'x');
        ys.extend(b"ijk");

        let t = std::time::Instant::now();
        let mut res = 0;
        for _ in 0..m {
            res += f(&xs, &ys);
        }
        eprintln!("{name:10} {:0.2?}", t.elapsed());
        assert_eq!(res, 3_000_000_000); // m runs * (12 * n) matching bytes
    }

    bench_mismatch("naive", mismatch);
    bench_mismatch("simd ", mismatch_simd);
    bench_mismatch("chunk ", mismatch_chunked);
}
