Challenge: write safe code that can be SIMD optimized

Let me start by saying that I like Rust and its type system, and I've been using the language for over 5 years. At the same time, there are definitely still some unfulfilled promises. In particular, Rust promises zero-cost abstractions, but in practice it is not so easy to write code where the abstractions are actually zero-cost. I have an example of this, and I would like to know whether it is possible to write safe code that the compiler optimizes as well as it does the unsafe version.

The example:


fn assert_input_valid(input: &[u8]) {
    assert_eq!(input.len() % 2, 0, "input length should be a multiple of 2");
}

fn transform(pair: [u8; 2]) -> u8 {
    (u16::from_le_bytes(pair) >> 2) as u8
}

fn func_safe(input: &[u8]) -> Vec<u8> {
    assert_input_valid(input);
    input
        .chunks(2)
        .map(|chunk| transform(chunk.try_into().unwrap()))
        .collect()
}

The function func_safe interprets an array of bytes as pairs of bytes, does some calculation on each pair, and returns a new array with one transformed byte per pair. When I benchmark this on my machine, I get around 4.4 GiB/s throughput.

Now consider this unsafe implementation:


fn func_unsafe(input: &[u8]) -> Vec<u8> {
    assert_input_valid(input);
    let pair_count = input.len() / 2;
    let mut output: Vec<u8> = Vec::with_capacity(pair_count);

    let mut r = input.as_ptr();
    let mut w = output.as_mut_ptr();
    let w_end = unsafe { w.add(pair_count) };
    while w != w_end {
        let b0 = unsafe { r.read_then_advance() };
        let b1 = unsafe { r.read_then_advance() };
        unsafe { w.write_then_advance(transform([b0, b1])) };
    }

    unsafe {
        output.set_len(pair_count);
    }
    output
}

trait PtrExt {
    type Element;

    unsafe fn read_then_advance(&mut self) -> Self::Element
    where
        Self::Element: Sized;
}

impl<T> PtrExt for *const T {
    type Element = T;

    unsafe fn read_then_advance(&mut self) -> Self::Element
    where
        Self::Element: Sized,
    {
        let val = self.read();
        *self = self.add(1);
        val
    }
}

trait PtrMutExt {
    type Element;

    unsafe fn write_then_advance(&mut self, val: Self::Element)
    where
        Self::Element: Sized;
}

impl<T> PtrMutExt for *mut T {
    type Element = T;

    unsafe fn write_then_advance(&mut self, val: Self::Element)
    where
        Self::Element: Sized,
    {
        self.write(val);
        *self = self.add(1);
    }
}

Note: It is possible that I made a mistake in the implementation and that the function is not actually performing the same calculation.

This unsafe implementation gives 63.3 GiB/s throughput on my machine, about 8x more than the safe implementation. The difference is visible in the assembly: the unsafe version uses SIMD instructions, while the safe version does not.

For the unsafe implementation, we can get rid of the unsafety on the output side by creating vec![0; pair_count] and iterating over &mut output. This still allows the SIMD optimization to occur, but comes at the cost of writing each element twice (once to initialize it to 0 and once with the final value), reducing the throughput to 51.2 GiB/s.
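A sketch of that zero-initialized variant, as I understand it (my own reconstruction, not the exact benchmark code; the function name is made up):

```rust
fn transform(pair: [u8; 2]) -> u8 {
    (u16::from_le_bytes(pair) >> 2) as u8
}

fn func_zeroed(input: &[u8]) -> Vec<u8> {
    assert_eq!(input.len() % 2, 0, "input length should be a multiple of 2");
    let pair_count = input.len() / 2;
    // First write: zero-initialize, so the Vec is fully valid without unsafe.
    let mut output = vec![0u8; pair_count];
    // Second write: overwrite each element with its final value.
    for (out, chunk) in output.iter_mut().zip(input.chunks_exact(2)) {
        *out = transform(chunk.try_into().unwrap());
    }
    output
}
```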

So here is the challenge, can we write this function without using unsafe and without having to write explicit SIMD code? I suppose using well-known libraries is acceptable, but ideally we only use the std library.

You can fork the benchmark I used if you like.


Try chunks_exact instead of chunks.

You might try reserve_exact and extend instead of collect. When computing capacity you need to use checked arithmetic, because otherwise LLVM will suspect it can overflow.
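Something like this sketch (my own illustration of the suggestion; the function name is made up):

```rust
fn transform(pair: [u8; 2]) -> u8 {
    (u16::from_le_bytes(pair) >> 2) as u8
}

fn func_extend(input: &[u8]) -> Vec<u8> {
    let chunks = input.chunks_exact(2);
    assert!(chunks.remainder().is_empty());
    let mut output = Vec::new();
    // Reserve the exact capacity up front instead of relying on collect's
    // internal sizing; len / 2 cannot overflow, so no checked math is
    // needed for this particular capacity computation.
    output.reserve_exact(input.len() / 2);
    output.extend(chunks.map(|chunk| transform(chunk.try_into().unwrap())));
    output
}
```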

Also maybe unrolling would help? chunks_exact(16) and either flat_map or manually unroll each chunk.
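A sketch of the flat_map flavor of that idea (my own illustration; for simplicity it assumes the length is a multiple of 16, whereas a real version would also process chunks_exact(16).remainder()):

```rust
fn transform(pair: [u8; 2]) -> u8 {
    (u16::from_le_bytes(pair) >> 2) as u8
}

fn func_unrolled(input: &[u8]) -> Vec<u8> {
    assert_eq!(input.len() % 16, 0, "this sketch assumes a multiple of 16");
    input
        .chunks_exact(16)
        // Each 16-byte chunk holds 8 pairs; flat_map emits the 8 output
        // bytes, giving the optimizer a fixed-size inner block to vectorize.
        .flat_map(|block| {
            let mut out = [0u8; 8];
            for (o, pair) in out.iter_mut().zip(block.chunks_exact(2)) {
                *o = transform(pair.try_into().unwrap());
            }
            out
        })
        .collect()
}
```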


Wow, chunks_exact did the trick.

fn func_safe(input: &[u8]) -> Vec<u8> {
    let chunks = input.chunks_exact(2);
    assert!(
        chunks.remainder().is_empty(),
        "input length should be a multiple of 2"
    );
    chunks
        .map(|chunk| transform(chunk.try_into().unwrap()))
        .collect()
}

It seems like reserve_exact and extend do not make a difference in this case. The chunks_exact iterator implements ExactSizeIterator, which collect might use to allocate the right capacity, or the cost is just negligible.
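The exact-size part is easy to check, for what it's worth (a small demo of my own, not from the thread):

```rust
// chunks_exact returns an iterator that knows its exact length, which
// Vec's FromIterator implementation can use to allocate once up front.
fn chunks_len_demo() -> (usize, (usize, Option<usize>)) {
    let buf = [0u8; 10];
    let it = buf.chunks_exact(2);
    (it.len(), it.size_hint()) // ExactSizeIterator::len, Iterator::size_hint
}
```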

