Let me start by saying that I like Rust and its type system, and I have been using it for over 5 years. At the same time, there are definitely still some unfulfilled promises. In particular, Rust promises zero-cost abstractions, but in practice it is not so easy to write code where the abstractions are actually zero-cost. I have an example of this, and I would like to know whether it is possible to write safe code that the compiler optimizes as well as the unsafe version.
The example:
fn assert_input_valid(input: &[u8]) {
    assert_eq!(input.len() % 2, 0, "input length should be a multiple of 2");
}

fn transform(pair: [u8; 2]) -> u8 {
    (u16::from_le_bytes(pair) >> 2) as u8
}

fn func_safe(input: &[u8]) -> Vec<u8> {
    assert_input_valid(input);
    input
        .chunks(2)
        .map(|chunk| transform(chunk.try_into().unwrap()))
        .collect()
}
The function func_safe interprets a slice of bytes as a sequence of byte pairs, does some calculation on each pair, and returns a new vector of the transformed pairs. When I benchmark this on my machine I get around 4.4 GiB/s of throughput.
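As a quick sanity check of what the transformation does (the test below is only an illustration, not part of the benchmark): each pair is read as a little-endian u16, shifted right by two, and truncated back to a byte.

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn transform_examples() {
        // [0x04, 0x01] read little-endian is 0x0104 = 260; 260 >> 2 = 65 = 0x41.
        assert_eq!(transform([0x04, 0x01]), 0x41);
        // One output byte per input pair.
        assert_eq!(func_safe(&[0x04, 0x01, 0x08, 0x00]), vec![0x41, 0x02]);
    }
}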
Now consider this unsafe implementation:
fn func_unsafe(input: &[u8]) -> Vec<u8> {
    assert_input_valid(input);
    let pair_count = input.len() / 2;
    let mut output: Vec<u8> = Vec::with_capacity(pair_count);
    let mut r = input.as_ptr();
    let mut w = output.as_mut_ptr();
    let w_end = unsafe { w.add(pair_count) };
    // Read two input bytes and write one transformed byte until the output is full.
    while w != w_end {
        let b0 = unsafe { r.read_then_advance() };
        let b1 = unsafe { r.read_then_advance() };
        unsafe { w.write_then_advance(transform([b0, b1])) };
    }
    // SAFETY: exactly pair_count elements were written into the reserved capacity.
    unsafe {
        output.set_len(pair_count);
    }
    output
}
trait PtrExt {
    type Element;

    unsafe fn read_then_advance(&mut self) -> Self::Element
    where
        Self::Element: Sized;
}

impl<T> PtrExt for *const T {
    type Element = T;

    unsafe fn read_then_advance(&mut self) -> Self::Element
    where
        Self::Element: Sized,
    {
        let val = self.read();
        *self = self.add(1);
        val
    }
}

trait PtrMutExt {
    type Element;

    unsafe fn write_then_advance(&mut self, val: Self::Element)
    where
        Self::Element: Sized;
}

impl<T> PtrMutExt for *mut T {
    type Element = T;

    unsafe fn write_then_advance(&mut self, val: Self::Element)
    where
        Self::Element: Sized,
    {
        self.write(val);
        *self = self.add(1);
    }
}
Note: It is possible that I made a mistake in the implementation and that the function is not actually performing the same calculation.
This unsafe implementation gives 63.3 GiB/s of throughput on my machine, roughly 14x the safe implementation. The difference is visible in the generated assembly: the unsafe version is auto-vectorized into SIMD instructions, while the safe version is not.
For the unsafe implementation, we can get rid of the unsafety with respect to the output by creating vec![0; pair_count] and iterating over &mut output. This still allows the SIMD optimization to occur, but it comes at the cost of writing to each element twice (once for the initialization to 0 and once for the final value), reducing the throughput to 51.2 GiB/s.
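For reference, that zero-initialized variant looks roughly like this (a sketch of what I mean; the details may differ slightly from what I actually benchmarked):

fn func_zeroed(input: &[u8]) -> Vec<u8> {
    assert_input_valid(input);
    let pair_count = input.len() / 2;
    // Zero-initialized output: no uninitialized memory and no set_len.
    let mut output = vec![0u8; pair_count];
    let mut r = input.as_ptr();
    for out in &mut output {
        // The input side still uses the raw-pointer reads from above.
        let b0 = unsafe { r.read_then_advance() };
        let b1 = unsafe { r.read_then_advance() };
        *out = transform([b0, b1]);
    }
    output
}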
So here is the challenge: can we write this function without using unsafe and without having to write explicit SIMD code? I suppose using well-known libraries is acceptable, but ideally we would only use the standard library.
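To give an idea of the kind of answer I am after, the most obvious fully safe rewrite just swaps chunks for chunks_exact (sketched below); I make no claim about its performance, and whether something along these lines can be made to auto-vectorize is exactly the question.

fn func_safe_exact(input: &[u8]) -> Vec<u8> {
    assert_input_valid(input);
    // chunks_exact statically yields 2-byte chunks, so the fallible try_into goes away,
    // but that alone may not be enough for the compiler to vectorize the loop.
    input
        .chunks_exact(2)
        .map(|chunk| transform([chunk[0], chunk[1]]))
        .collect()
}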
You can fork the benchmark I used if you like.
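If you just want the gist of how the numbers above are computed, the throughput is essentially input bytes processed per unit of time; a rough timing sketch (not the actual harness) looks something like this:

use std::time::Instant;

// Rough GiB/s estimate; the real benchmark is more careful than this.
fn throughput_gib_per_s(f: impl Fn(&[u8]) -> Vec<u8>, input: &[u8], iters: u32) -> f64 {
    let start = Instant::now();
    for _ in 0..iters {
        std::hint::black_box(f(std::hint::black_box(input)));
    }
    let secs = start.elapsed().as_secs_f64();
    (input.len() as f64 * iters as f64) / secs / (1024.0 * 1024.0 * 1024.0)
}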