Inspired by the code from Knuth Morris Pratt algorithm with C++ and Rust [Solved], I'd like to ask the following question: what is the fastest way to find the first difference between two byte-slices? That is, can the following be made faster?
Compiler Explorer shows that this code compiles to a loop that compares a single byte at a time: Compiler Explorer. I expect this could be made quite a bit faster by using SIMD, or simply by comparing a u64 at a time.
I am surprised that there isn't a function for this already. There's memcmp, but it doesn't reveal the location of the mismatch (although it presumably needs to have this information internally?).
I still don't know a good answer to this question, but I got annoyed enough to quickly code up a SIMD solution:
/// Returns the index of the first byte at which `xs` and `ys` differ.
///
/// Only the common prefix of the two slices is compared; when no differing
/// byte is found there, the length of the shorter slice is returned.
#[inline(never)]
pub fn mismatch(xs: &[u8], ys: &[u8]) -> usize {
    let shared = xs.len().min(ys.len());
    xs.iter()
        .zip(ys.iter())
        .position(|(a, b)| a != b)
        .unwrap_or(shared)
}
/// Returns the index of the first byte at which `xs` and `ys` differ.
///
/// Only the common prefix of the two slices is compared; when no differing
/// byte is found there, the length of the shorter slice is returned.
///
/// On x86_64 the comparison processes 32 bytes per step with AVX2, but only
/// after confirming at run time that the CPU supports it. On every other
/// CPU or architecture a portable byte-by-byte scan is used, so the function
/// is safe to call everywhere.
#[inline(never)]
pub fn mismatch_simd(xs: &[u8], ys: &[u8]) -> usize {
    /// Portable byte-by-byte scan; also handles the tail shorter than one vector.
    fn scalar(xs: &[u8], ys: &[u8]) -> usize {
        xs.iter().zip(ys).take_while(|(x, y)| x == y).count()
    }

    /// AVX2 scan over 32-byte chunks, finishing the remainder with `scalar`.
    ///
    /// # Safety
    ///
    /// The caller must ensure the running CPU supports AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn avx2(xs: &[u8], ys: &[u8]) -> usize {
        use std::arch::x86_64::*;

        const LANES: usize = 32;
        let mut off = 0;
        // Zipping the chunk iterators stops at the shorter slice, so `off`
        // never exceeds the length of either input.
        for (x, y) in xs.chunks_exact(LANES).zip(ys.chunks_exact(LANES)) {
            // SAFETY: each chunk is exactly 32 readable bytes, the unaligned
            // load has no alignment requirement, and AVX2 availability is
            // guaranteed by this function's contract.
            let eq_mask = unsafe {
                let x = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
                let y = _mm256_loadu_si256(y.as_ptr() as *const __m256i);
                // Bit i of the mask is set when byte i of both chunks is equal.
                _mm256_movemask_epi8(_mm256_cmpeq_epi8(x, y))
            };
            let matched = eq_mask.trailing_ones() as usize;
            if matched < LANES {
                return off + matched;
            }
            off += LANES;
        }
        off + scalar(&xs[off..], &ys[off..])
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was verified on the line above.
            return unsafe { avx2(xs, ys) };
        }
    }
    scalar(xs, ys)
}
/// Checks that `mismatch_simd` agrees with `mismatch` for every possible
/// position of a single corrupted byte.
#[test]
fn smoke() {
    // Short input: the phrase three times. Long input: the short input twice.
    let short = "all the work and no play made jack a dull boy"
        .repeat(3)
        .into_bytes();
    let long = short.repeat(2);
    // Zero out each byte of the long input in turn; positions past the end of
    // the short input exercise the "no difference in the common prefix" case.
    for pos in 0..long.len() {
        let mut corrupted = long.clone();
        corrupted[pos] = 0;
        let expected = mismatch(&short, &corrupted);
        let actual = mismatch_simd(&short, &corrupted);
        assert_eq!(expected, actual)
    }
}
/// Ad-hoc timing comparison of `mismatch` and `mismatch_simd`.
///
/// Both are run the same number of times over two large inputs that share a
/// long common prefix; elapsed times go to stderr and the summed results of
/// the two implementations must be equal.
#[test]
fn bench() {
    const REPEATS: usize = 100_000;
    const ROUNDS: usize = 100;

    // Identical prefix, then one side gets "x" and the other "ijk" appended.
    let mut xs = "Hello, world".repeat(REPEATS).into_bytes();
    let mut ys = xs.clone();
    xs.push(b'x');
    ys.extend(b"ijk");

    let started = std::time::Instant::now();
    let naive_total: usize = (0..ROUNDS).map(|_| mismatch(&xs, &ys)).sum();
    eprintln!("naive {:0.2?}", started.elapsed());

    let started = std::time::Instant::now();
    let simd_total: usize = (0..ROUNDS).map(|_| mismatch_simd(&xs, &ys)).sum();
    eprintln!("simd {:0.2?}", started.elapsed());

    assert_eq!(naive_total, simd_total);
}
λ RUSTFLAGS="-Ctarget-cpu=native" cargo t -r --lib -- bench --nocapture
Compiling m v0.1.0 (/home/matklad/tmp/m)
Finished release [optimized] target(s) in 0.26s
Running unittests src/lib.rs (target/release/deps/m-6997e6fd8122178e)
running 1 test
naive 26.98ms
simd 3.59ms
test bench ... ok
Note: I think this is my first SIMD program, and it is definitely not production-ready.
λ cargo t -r -- bench --nocapture
naive 662.42ms
simd 741.04ms
chunk 83.34ms
test bench ... ok
λ RUSTFLAGS="-Ctarget-cpu=native" cargo t -r -- bench --nocapture
naive 654.31ms
simd 80.46ms
chunk 81.20ms
test bench ... ok
use std::iter;
/// Counts how many leading bytes of `xs` and `ys` are equal, i.e. returns the
/// index of the first differing byte.
///
/// The scan covers only the common prefix, so the result is at most the
/// length of the shorter slice.
#[inline(never)]
pub fn mismatch(xs: &[u8], ys: &[u8]) -> usize {
    let mut matched = 0;
    for (a, b) in xs.iter().zip(ys) {
        if a != b {
            break;
        }
        matched += 1;
    }
    matched
}
/// Returns the index of the first byte at which `xs` and `ys` differ.
///
/// Only the common prefix of the two slices is compared; when no differing
/// byte is found there, the length of the shorter slice is returned.
///
/// On x86_64 the comparison processes 32 bytes per step with AVX2, but only
/// after confirming at run time that the CPU supports it. On every other
/// CPU or architecture a portable byte-by-byte scan is used, so the function
/// is safe to call everywhere.
#[inline(never)]
pub fn mismatch_simd(xs: &[u8], ys: &[u8]) -> usize {
    /// Portable byte-by-byte scan; also handles the tail shorter than one vector.
    fn scalar(xs: &[u8], ys: &[u8]) -> usize {
        xs.iter().zip(ys).take_while(|(x, y)| x == y).count()
    }

    /// AVX2 scan over 32-byte chunks, finishing the remainder with `scalar`.
    ///
    /// # Safety
    ///
    /// The caller must ensure the running CPU supports AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn avx2(xs: &[u8], ys: &[u8]) -> usize {
        use std::arch::x86_64::*;

        const LANES: usize = 32;
        let mut off = 0;
        // Zipping the chunk iterators stops at the shorter slice, so `off`
        // never exceeds the length of either input.
        for (x, y) in xs.chunks_exact(LANES).zip(ys.chunks_exact(LANES)) {
            // SAFETY: each chunk is exactly 32 readable bytes, the unaligned
            // load has no alignment requirement, and AVX2 availability is
            // guaranteed by this function's contract.
            let eq_mask = unsafe {
                let x = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
                let y = _mm256_loadu_si256(y.as_ptr() as *const __m256i);
                // Bit i of the mask is set when byte i of both chunks is equal.
                _mm256_movemask_epi8(_mm256_cmpeq_epi8(x, y))
            };
            let matched = eq_mask.trailing_ones() as usize;
            if matched < LANES {
                return off + matched;
            }
            off += LANES;
        }
        off + scalar(&xs[off..], &ys[off..])
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was verified on the line above.
            return unsafe { avx2(xs, ys) };
        }
    }
    scalar(xs, ys)
}
/// Returns the index of the first byte at which `xs` and `ys` differ,
/// comparing whole 128-byte chunks first and single bytes afterwards.
///
/// Only the common prefix of the two slices is compared; when no differing
/// byte is found there, the length of the shorter slice is returned.
#[inline(never)]
pub fn mismatch_chunked(xs: &[u8], ys: &[u8]) -> usize {
    /// Skips leading `WIDTH`-byte chunks that are equal as a whole, then
    /// locates the exact position byte by byte from there.
    fn common_prefix<const WIDTH: usize>(lhs: &[u8], rhs: &[u8]) -> usize {
        let mut matched = 0;
        for (a, b) in iter::zip(lhs.chunks_exact(WIDTH), rhs.chunks_exact(WIDTH)) {
            if a != b {
                break;
            }
            matched += WIDTH;
        }
        // `matched` counts only chunks present in both slices, so it is in
        // bounds for each of them.
        let tail = iter::zip(&lhs[matched..], &rhs[matched..]);
        matched + tail.take_while(|(a, b)| a == b).count()
    }

    common_prefix::<128>(xs, ys)
}
/// Ad-hoc timing comparison of the three mismatch implementations.
///
/// Each implementation is timed over the same workload; the elapsed time is
/// printed to stderr and the summed result is checked against a fixed total.
#[test]
fn bench() {
    /// Times `f` over freshly built inputs and prints the result under `name`.
    fn bench_mismatch(name: &str, f: fn(&[u8], &[u8]) -> usize) {
        const REPEATS: usize = 500_000;
        const ROUNDS: usize = 500;

        // Identical prefix, then one side gets "x" and the other "ijk" appended.
        let mut xs = "Hello, world".repeat(REPEATS).into_bytes();
        let mut ys = xs.clone();
        xs.push(b'x');
        ys.extend(b"ijk");

        let started = std::time::Instant::now();
        let total: usize = (0..ROUNDS).map(|_| f(&xs, &ys)).sum();
        eprintln!("{name:10} {:0.2?}", started.elapsed());
        // 500 rounds, each reporting the first difference at index 6_000_000.
        assert_eq!(total, 3000000000);
    }

    bench_mismatch("naive", mismatch);
    bench_mismatch("simd ", mismatch_simd);
    bench_mismatch("chunk ", mismatch_chunked);
}