On Concrete Type Monomorphization

Consider a layered software architecture where the internals of each layer are kept hidden and only exposed via dyn traits. We then realize that we have a hot loop that performs dynamic dispatch on methods of the dyn trait, which kills performance. We need to refactor the code so that logic from one layer is pushed down, in order to remove all the vtable lookups.

But wait: I came up with two small procedural macros, ‘wrappable’ and ‘wrapping’, that allow library authors to make the vtable boundary flexible. Basically, developers using the library can create extension traits with blanket implementations, so that the returned dyn trait objects gain extra methods, backed by concrete type monomorphization.

Basically, a function defined on the upper layer gets compiled and added to the vtable of each of the internal types from the underlying layer.

This is a pattern that can be applied with just a few lines of code (using these macros).

Before authoring this, I was having a hard time trying to find a solution to this. Any suggestions for an alternative way to add blanket implementations that get concrete type monomorphized (maybe a Nightly feature I need to know about)?

I don't completely understand the use case, or the architecture.

If the underlying layers expose some trait Trait, then the upper layers can still define new methods on any type implementing Trait (including dyn Trait). E.g., an extension trait (and a blanket impl) could accomplish this.
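A minimal sketch of that extension-trait suggestion, assuming a hypothetical lower-layer trait with `get` and `len` (the names `Small`, `open`, `TraitExt` and `sum_via_extension` are made up for illustration):

```rust
// Lower layer: a public trait with a private implementor.
pub trait Trait {
    fn get(&self, index: usize) -> u32;
    fn len(&self) -> usize;
}

struct Small([u32; 4]);

impl Trait for Small {
    fn get(&self, index: usize) -> u32 {
        self.0[index]
    }
    fn len(&self) -> usize {
        self.0.len()
    }
}

pub fn open() -> Box<dyn Trait> {
    Box::new(Small([1, 2, 3, 4]))
}

// Upper layer: an extension trait with a provided method.
pub trait TraitExt: Trait {
    fn sum(&self) -> u32 {
        let mut total = 0u32;
        for i in 0..self.len() {
            total = total.wrapping_add(self.get(i));
        }
        total
    }
}

// `?Sized` makes the blanket impl cover `dyn Trait` itself.
impl<T: Trait + ?Sized> TraitExt for T {}

pub fn sum_via_extension() -> u32 {
    let boxed = open();
    // Resolves to `<dyn Trait as TraitExt>::sum` through auto-deref.
    boxed.sum()
}
```

Note that when the receiver is `dyn Trait`, the calls to `len` and `get` inside `sum` still go through the vtable; the method is available, but it is not specialized per concrete type.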

And why is hiding internals via dyn Trait better than exposing one or more concrete PublicTraitImpl types in underlying layers, and using generic parameters and/or those concrete types in upper layers? Concrete types can still hide internal state, after all, and it sounds as though you've just shifted the concrete type from PublicTraitImpl to Box<dyn Trait>. If the library / underlying layer is the only thing which will provide Trait impls, then the concrete option works just fine; you can use enums as necessary to merge all the internal impls into one (see also enum_dispatch).

Personally, I love that the default option in Rust avoids vtable overhead, and I only use dyn if strictly necessary (or if performance is entirely unimportant compared to the cost in binary size of monomorphization). I recently reimplemented Google's C++ leveldb library in Rust, and I can't help but think of how they have an abstract base class for iterators which uses virtual functions (complete with workarounds to reduce the cost of using vtables in hot loops), even though the full set of iterators is statically known; my Rust equivalent is something like enum IterToMerge { Foo(FooIter), Bar(BarIter), Baz(BazIter) } and it works just fine.
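As a rough sketch of that enum approach: the iterator bodies below are invented placeholders (only the names `IterToMerge`, `FooIter` and `BarIter` come from the post), and the enum forwards `next` with a plain `match` instead of a vtable call.

```rust
// Hypothetical stand-ins for two statically known iterator types.
pub struct FooIter {
    current: u32,
    end: u32,
}

pub struct BarIter {
    items: std::vec::IntoIter<u32>,
}

impl Iterator for FooIter {
    type Item = u32;
    fn next(&mut self) -> Option<u32> {
        if self.current < self.end {
            let value = self.current;
            self.current += 1;
            Some(value)
        } else {
            None
        }
    }
}

impl Iterator for BarIter {
    type Item = u32;
    fn next(&mut self) -> Option<u32> {
        self.items.next()
    }
}

// The closed set of iterators, merged into one concrete type.
pub enum IterToMerge {
    Foo(FooIter),
    Bar(BarIter),
}

impl Iterator for IterToMerge {
    type Item = u32;
    fn next(&mut self) -> Option<u32> {
        // A branch on the discriminant; each arm is a direct, inlinable call.
        match self {
            IterToMerge::Foo(it) => it.next(),
            IterToMerge::Bar(it) => it.next(),
        }
    }
}

pub fn drain_all() -> Vec<u32> {
    let iters = vec![
        IterToMerge::Foo(FooIter { current: 0, end: 3 }),
        IterToMerge::Bar(BarIter { items: vec![10, 20].into_iter() }),
    ];
    iters.into_iter().flatten().collect()
}
```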

It could be my background in object-oriented programming that introduces complications that aren't there when using the appropriate Rust constructs.

One of the use cases is an implementation of Elias-Fano I am working on, which is basically a VERY compressed representation of a sorted array of numbers. We specifically need this for real-time processing (nanosecond lookups), so performance is extremely important. Depending on the number of items in the array, the memory layout, and thus the executing code, changes. You can think of it as a type with a const generic parameter. I guess we could have something like an enum with about 60 different implementations, but instead we chose to use a dyn trait.

Since the library internally determines the optimal implementation and returns, for instance, a Box<dyn Trait>, the set of methods in the vtable is (as far as I understand) limited to the methods known to the library. In other words, we cannot just add a method that gets automatically compiled for all the types.

But with the proposed pattern, the supplied blanket methods automatically get compiled for each of the underlying (and private) implementations behind the scenes.

If you never need to statically dispatch on the array size (requiring a branch on the array size on each operation), then something like struct Array(Box<[u8]>) is sufficient.

Assuming that you would like to be able to statically choose the array size, and that most of the different array sizes have a lot of code in common, my first try would be something like

// Underlying layer
pub struct Array<const N: usize>
// `()` is just a random zero-sized type; could define your own
where (): ArrayRepr<N>
{
    /* private fields */
}

impl<const N: usize> Array<N> where (): ArrayRepr<N> {
    // public methods here, which get monomorphized
}

// Assuming that only *some* `N` work as array sizes,
// and some of the code for different `N` is very different,
// I think something like this would be needed. I could be wrong about this.
trait ArrayRepr<const N: usize> {
     // Size-specific stuff; keep this as small as possible to avoid needless repetition in impls

     // Probably something like `[u8; M]` for some `M`, whatever you need.
     type Repr;
}

// Maybe define a `macro_rules!` macro to reduce repetition, as needed
impl ArrayRepr<1> for () { ... }
impl ArrayRepr<2> for () { ... }
impl ArrayRepr<3> for () { ... }
...

// Upper layer
// Optionally, use a `const N: usize` parameter
// if the additional methods need to know it.
trait ArrayExt {
    // additional methods
}

impl<const N: usize> ArrayExt for Array<N> where (): ArrayRepr<N> {
    // impls, which get monomorphized
}
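For reference, here is one way the outline above could be filled in so that it compiles. It is only a sketch under assumptions: the storage is taken to be `[u32; N]`, the supported sizes (1, 2, 4, 16) are arbitrary, and `impl_array_repr!`, `sum` and `sum_of_four` are hypothetical names.

```rust
// Underlying layer: `ArrayRepr` maps a supported size `N` to its storage type.
pub trait ArrayRepr<const N: usize> {
    type Repr: AsRef<[u32]>;
}

// A small `macro_rules!` helper keeps the per-size impls short.
macro_rules! impl_array_repr {
    ($($n:literal),* $(,)?) => {
        $(impl ArrayRepr<$n> for () { type Repr = [u32; $n]; })*
    };
}
impl_array_repr!(1, 2, 4, 16);

pub struct Array<const N: usize>
where
    (): ArrayRepr<N>,
{
    data: <() as ArrayRepr<N>>::Repr,
}

impl<const N: usize> Array<N>
where
    (): ArrayRepr<N>,
{
    pub fn new(data: <() as ArrayRepr<N>>::Repr) -> Self {
        Array { data }
    }

    pub fn get(&self, index: usize) -> u32 {
        self.data.as_ref()[index]
    }

    pub fn len(&self) -> usize {
        self.data.as_ref().len()
    }
}

// Upper layer: an extension trait, monomorphized once per supported `N` in use.
pub trait ArrayExt {
    fn sum(&self) -> u32;
}

impl<const N: usize> ArrayExt for Array<N>
where
    (): ArrayRepr<N>,
{
    fn sum(&self) -> u32 {
        (0..self.len()).fold(0u32, |acc, i| acc.wrapping_add(self.get(i)))
    }
}

pub fn sum_of_four() -> u32 {
    Array::<4>::new([1, 2, 3, 4]).sum()
}
```

An unsupported size such as `Array::<3>` is rejected at compile time, because `(): ArrayRepr<3>` has no impl in this sketch.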

If you need to dynamically choose the size on top of the above code, large enums may be necessary, and the cost of a branch is brought back.

Mmm, I still don't like this solution. It feels like too much boilerplate.

I forget where I saw a quote about this, but the general advice is: write Rust code as concretely as possible, only introduce abstraction where necessary.

Quickly looking at existing crates for reference, this one uses a lot of generics: EliasFano in sux::dict::elias_fano - Rust

These two use entirely concrete types: EliasFano in elias_fano_rust - Rust, EliasFano in elias_fano - Rust

On a different note, I did a variation of LevelDb in C#, that has been powering that company for about 15 years now (with hundreds of billions of records).

I added the possibility of including a custom aggregation function, which would automatically aggregate two values into one when the key was the same. This allows for quickly counting and sorting keys without having to have them in memory (or smarter aggregation such as hyperloglog for estimating number of unique visits to items without having to store them all).
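To make the aggregation idea concrete, here is a small Rust sketch (not the C# code, which isn't shown here). It merges two key-sorted runs and combines values with equal keys through a caller-supplied function; `merge_with` and `merged_counts` are hypothetical names, and each run is assumed to contain unique keys.

```rust
use std::cmp::Ordering;

/// Merges two key-sorted runs; values with equal keys are combined by `aggregate`.
pub fn merge_with<K: Ord, V>(
    left: Vec<(K, V)>,
    right: Vec<(K, V)>,
    aggregate: impl Fn(V, V) -> V,
) -> Vec<(K, V)> {
    let mut out = Vec::with_capacity(left.len() + right.len());
    let mut left = left.into_iter().peekable();
    let mut right = right.into_iter().peekable();
    loop {
        // Decide which side holds the smaller key without consuming anything yet.
        let order = match (left.peek(), right.peek()) {
            (Some((lk, _)), Some((rk, _))) => lk.cmp(rk),
            (Some(_), None) => Ordering::Less,
            (None, Some(_)) => Ordering::Greater,
            (None, None) => break,
        };
        match order {
            Ordering::Less => out.push(left.next().unwrap()),
            Ordering::Greater => out.push(right.next().unwrap()),
            Ordering::Equal => {
                // Same key on both sides: fold the two values into one record.
                let (key, lv) = left.next().unwrap();
                let (_, rv) = right.next().unwrap();
                out.push((key, aggregate(lv, rv)));
            }
        }
    }
    out
}

/// Counting use case: equal keys have their counts added together.
pub fn merged_counts() -> Vec<(&'static str, u64)> {
    let older = vec![("apple", 2), ("cherry", 1)];
    let newer = vec![("apple", 3), ("banana", 5)];
    merge_with(older, newer, |a, b| a + b)
}
```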

If, in LevelDB, you have a heap (binary tree merge) of iterators of various types, the room for compiler optimizations is limited. Performance will probably be about the same whether you use dynamic dispatch or enum dispatch.
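A sketch of such a heap-based merge, using boxed trait-object iterators and the standard `BinaryHeap` (the function names are hypothetical, and this is not LevelDB's actual merging iterator). Every `next` call is a vtable call, but each one is surrounded by heap pushes and pops, which is the part that plausibly limits what the optimizer can do either way.

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// K-way merge of already sorted sources of different concrete types.
pub fn merge_sorted(mut sources: Vec<Box<dyn Iterator<Item = u32>>>) -> Vec<u32> {
    // Min-heap of (value, source index), seeded with the head of each source.
    let mut heap = BinaryHeap::new();
    for (idx, source) in sources.iter_mut().enumerate() {
        if let Some(value) = source.next() {
            heap.push(Reverse((value, idx)));
        }
    }

    let mut out = Vec::new();
    while let Some(Reverse((value, idx))) = heap.pop() {
        out.push(value);
        // Refill from the source that just produced the smallest value.
        if let Some(next) = sources[idx].next() {
            heap.push(Reverse((next, idx)));
        }
    }
    out
}

pub fn merged_example() -> Vec<u32> {
    let sources: Vec<Box<dyn Iterator<Item = u32>>> = vec![
        Box::new(vec![1u32, 4, 9].into_iter()),
        Box::new((2u32..6).step_by(2)),
        Box::new(std::iter::once(3u32)),
    ];
    merge_sorted(sources)
}
```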

Requiring all types to be directly visible everywhere is also one of the reasons some complain about the compilation times.

I will make a few experiments with enum_dispatch to explore strengths and weaknesses. Thanks, that seems very related.

Those crates are great starting points. We are currently talking billions of collections with jointly 250 GB of memory backed by huge pages with platform-specific assembly instructions for various sizes and operations (e.g. AVX512 and SVE2). The platform requirements are a little extraordinary on this one. Looking at ARM assembly intrinsics is not the fun part.

I did a quick test.

This is the library code - lib.rs

FlexibleArray

use dynamic_wrapping::{wrappable, wrapping};
use enum_dispatch::enum_dispatch;

#[wrappable] // macro from dynamic_wrapping
pub trait FlexibleArray {
    fn get(&self, index: u32) -> u32;
    fn len(&self) -> u32;
}

EmptyArray

pub struct EmptyArray;

impl FlexibleArray for EmptyArray {
    fn get(&self, _index: u32) -> u32 {
        panic!("EmptyArray has no elements");
    }

    fn len(&self) -> u32 {
        0
    }
}

VectorArray

pub struct VectorArray {
    data: Vec<u32>,
}

impl VectorArray {
    pub fn new(data: Vec<u32>) -> Self {
        VectorArray { data }
    }
}

impl FlexibleArray for VectorArray {
    fn get(&self, index: u32) -> u32 {
        self.data[index as usize]
    }

    fn len(&self) -> u32 {
        self.data.len() as u32
    }
}

Array16

pub struct Array16 {
    data: [u32; 16],
}

impl Array16 {
    pub fn new(data: [u32; 16]) -> Self {
        Array16 { data }
    }
}

impl FlexibleArray for Array16 {
    fn get(&self, index: u32) -> u32 {
        self.data[index as usize]
    }

    fn len(&self) -> u32 {
        16
    }
}

Dynamic Wrapping

#[wrapping(
    FlexibleArray => Box<dyn FlexibleArray + 'a>, Box::new
)]
pub struct BoxDynWrapping;

FlexibleArrayEnum

Quite easy to declare:

#[enum_dispatch(FlexibleArray)]
pub enum FlexibleArrayEnum {
    EmptyArray(EmptyArray),
    VectorArray(VectorArray),
    Array16(Array16),
}
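For readers who have not used enum_dispatch, the following is a hand-written approximation of the kind of forwarding impl the attribute saves you from writing. It is not the macro's actual expansion; it is trimmed to two variants and redefines the trait so that it stands alone. `enum_sum` is a hypothetical helper.

```rust
pub trait FlexibleArray {
    fn get(&self, index: u32) -> u32;
    fn len(&self) -> u32;
}

pub struct EmptyArray;

pub struct Array16 {
    data: [u32; 16],
}

impl FlexibleArray for EmptyArray {
    fn get(&self, _index: u32) -> u32 {
        panic!("EmptyArray has no elements");
    }
    fn len(&self) -> u32 {
        0
    }
}

impl FlexibleArray for Array16 {
    fn get(&self, index: u32) -> u32 {
        self.data[index as usize]
    }
    fn len(&self) -> u32 {
        16
    }
}

pub enum FlexibleArrayEnum {
    EmptyArray(EmptyArray),
    Array16(Array16),
}

// Each trait method forwards to the active variant with a `match`,
// so every arm is a direct call on a concrete type.
impl FlexibleArray for FlexibleArrayEnum {
    fn get(&self, index: u32) -> u32 {
        match self {
            FlexibleArrayEnum::EmptyArray(inner) => inner.get(index),
            FlexibleArrayEnum::Array16(inner) => inner.get(index),
        }
    }
    fn len(&self) -> u32 {
        match self {
            FlexibleArrayEnum::EmptyArray(inner) => inner.len(),
            FlexibleArrayEnum::Array16(inner) => inner.len(),
        }
    }
}

pub fn enum_sum(array: &FlexibleArrayEnum) -> u32 {
    (0..array.len()).map(|i| array.get(i)).sum()
}
```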

VectorStorage

pub struct VectorStorage;

impl VectorStorage {
    pub fn open_with<'a, W: FlexibleArrayWrapper<'a>>(code: u8) -> W::Wrapped {
        match code % 3 {
            0 => W::wrap(EmptyArray),
            1 => W::wrap(VectorArray::new(create_vector_array_data(code))),
            _ => W::wrap(Array16::new(create_array16(code))),
        }
    }

    pub fn open_dyn(code: u8) -> Box<dyn FlexibleArray> {
        Self::open_with::<BoxDynWrapping>(code)
    }

    pub fn open_enum(code: u8) -> FlexibleArrayEnum {
        match code % 3 {
            0 => FlexibleArrayEnum::EmptyArray(EmptyArray),
            1 => FlexibleArrayEnum::VectorArray(VectorArray::new(create_vector_array_data(code))),
            _ => FlexibleArrayEnum::Array16(Array16::new(create_array16(code))),
        }
    }
}

fn create_vector_array_data(code: u8) -> Vec<u32> {
    (0..15).map(|i| (code as u32) + 13 * i).collect()
}

fn create_array16(code: u8) -> [u32; 16] {
    std::array::from_fn(|i| (code as u32) + 13 * i as u32)
}

Main program - main.rs

In my main program I find out that I need a sum function that computes the sum of the elements in an array. In this scenario, I am not allowed to make changes in lib.rs, because this algorithm is specific to my program.

Naïve Sum

So I add such a function to my main program:

#[inline(never)]
fn compute_sum(array: &dyn FlexibleArray) -> u32 {
    let mut sum = 0u32;
    for i in 0..array.len() {
        sum += array.get(i);
    }
    sum
}

This compiles to a single function, printed as 53 lines of assembly, which loads the addresses of len and get into two registers and repeatedly calls the second in a loop (dynamic dispatch).

Dynamic Wrapper

With the Dynamic Wrapper support in place, we can extend the trait.

pub trait ExtendedFlexibleArray: FlexibleArray {
    fn get_sum(&self) -> u32;
}

impl<T: FlexibleArray> ExtendedFlexibleArray for T {
    fn get_sum(&self) -> u32 {
        let mut sum = 0u32;
        for i in 0..self.len() {
            sum += self.get(i);
        }
        sum
    }
}

#[wrapping(FlexibleArray => Box<dyn ExtendedFlexibleArray + 'a>, Box::new)]
pub struct BoxDynExtendedWrapping;

When supplying BoxDynExtendedWrapping to the open_with method, we can now call get_sum on the returned array.

let extended_array = VectorStorage::open_with::<BoxDynExtendedWrapping>(code);

This effectively generates some new methods:

 1 "<test_enum_dispatch::Array16 as test_enum_dispatch::ExtendedFlexibleArray>::get_sum" [18]
 4 "<test_enum_dispatch::EmptyArray as test_enum_dispatch::ExtendedFlexibleArray>::get_sum" [7]
 5 "<test_enum_dispatch::VectorArray as test_enum_dispatch::ExtendedFlexibleArray>::get_sum" [135]

Most importantly, we see that the EmptyArray code is:

        xor eax, eax
        ret

and that the Array16 code is an optimized, vectorized (AVX) horizontal sum:

        vmovdqu ymm0, ymmword ptr [rdi + 32]
        vpaddd zmm0, zmm0, zmmword ptr [rdi]
        vextracti128 xmm1, ymm0, 1
        vpaddd xmm0, xmm0, xmm1
        vpshufd xmm1, xmm0, 238
        vpaddd xmm0, xmm0, xmm1
        vpshufd xmm1, xmm0, 85
        vpaddd xmm0, xmm0, xmm1
        vmovd eax, xmm0
        vzeroupper
        ret

This proves that we are getting highly specialized implementations of these methods (static dispatch). It doesn't matter that this method wasn't originally in the library.
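To show the mechanism without the macros, here is a hand-written sketch of the same idea. It is an assumption about the shape of the pattern, not the actual expansion of `wrappable` / `wrapping`: the wrapper trait below drops the `'a` lifetime and requires `'static` for brevity, the library is reduced to two private types, and `sum_for` is a hypothetical helper.

```rust
// --- library side ---
pub trait FlexibleArray {
    fn get(&self, index: u32) -> u32;
    fn len(&self) -> u32;
}

// The caller decides which type the concrete array is turned into.
pub trait FlexibleArrayWrapper {
    type Wrapped;
    fn wrap<T: FlexibleArray + 'static>(inner: T) -> Self::Wrapped;
}

struct EmptyArray; // private to the library

struct Array16 {
    data: [u32; 16],
}

impl FlexibleArray for EmptyArray {
    fn get(&self, _index: u32) -> u32 {
        panic!("EmptyArray has no elements");
    }
    fn len(&self) -> u32 {
        0
    }
}

impl FlexibleArray for Array16 {
    fn get(&self, index: u32) -> u32 {
        self.data[index as usize]
    }
    fn len(&self) -> u32 {
        16
    }
}

pub fn open_with<W: FlexibleArrayWrapper>(code: u8) -> W::Wrapped {
    // `wrap` is instantiated once per private concrete type.
    if code % 2 == 0 {
        W::wrap(EmptyArray)
    } else {
        W::wrap(Array16 { data: [u32::from(code); 16] })
    }
}

// --- application side ---
pub trait ExtendedFlexibleArray: FlexibleArray {
    fn get_sum(&self) -> u32;
}

impl<T: FlexibleArray> ExtendedFlexibleArray for T {
    fn get_sum(&self) -> u32 {
        let mut sum = 0u32;
        for i in 0..self.len() {
            sum += self.get(i);
        }
        sum
    }
}

pub struct BoxDynExtendedWrapping;

impl FlexibleArrayWrapper for BoxDynExtendedWrapping {
    type Wrapped = Box<dyn ExtendedFlexibleArray>;
    fn wrap<T: FlexibleArray + 'static>(inner: T) -> Self::Wrapped {
        // The unsizing coercion builds a vtable for `T` that includes
        // `get_sum`, compiled for this concrete `T`.
        Box::new(inner)
    }
}

pub fn sum_for(code: u8) -> u32 {
    open_with::<BoxDynExtendedWrapping>(code).get_sum()
}
```

The call to `get_sum` is still one virtual call, but the loop inside it is compiled against the concrete type, so `len` and `get` can be inlined.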

FlexibleArrayEnum

So we create a method that computes the sum from an instance of FlexibleArrayEnum. We are not in the library crate, so the internal enum details are hidden from us.

#[inline(never)]
fn compute_sum_enum(array: &FlexibleArrayEnum) -> u32 {
    let mut sum = 0u32;
    for i in 0..array.len() {
        sum += array.get(i);
    }
    sum
}

The assembly for this function is about 166 lines. It looks similar to what we get with the blanket implementation on the extended trait: each variant gets its own specialized code path. Promising.

.section .text.test_enum_dispatch::compute_sum_enum,"ax",@progbits
        .p2align        4
.type   test_enum_dispatch::compute_sum_enum,@function
test_enum_dispatch::compute_sum_enum:
                // src/main.rs:33
                fn compute_sum_enum(array: &FlexibleArrayEnum) -> u32 {
        .cfi_startproc
        push rax
        .cfi_def_cfa_offset 16
                // src/lib.rs:70
                #[enum_dispatch(FlexibleArray)]
        mov eax, dword ptr [rdi]
        test eax, eax
        je .LBB8_12
        cmp eax, 1
        jne .LBB8_11
        mov rax, rdi
                // ~/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/alloc/src/vec/mod.rs:3023
                let len = self.len;
        mov rdi, qword ptr [rdi + 24]
        mov ecx, 4294967295
                // ~/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/cmp.rs:1915
                fn lt(&self, other: &Self) -> bool { *self <  *other }
        and rcx, rdi
                // ~/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/iter/range.rs:900
                if self.start < self.end {
        je .LBB8_3
        mov rsi, qword ptr [rax + 24]
        mov rdx, qword ptr [rax + 16]
                // ~/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/slice/index.rs:272
                &(*slice)[self]
        lea r8, [rcx - 1]
        cmp rsi, r8
        cmovb r8, rsi
        inc r8
        cmp r8, 65
        jae .LBB8_6
        xor r8d, r8d
        xor eax, eax
        jmp .LBB8_9
.LBB8_11:
                // src/main.rs:36
                sum += array.get(i);
        vmovdqu xmm0, xmmword ptr [rdi + 4]
        mov rdx, qword ptr [rdi + 16]
        mov rax, qword ptr [rdi + 24]
                // src/lib.rs:57
                self.data[index as usize]
        mov r8d, dword ptr [rdi + 56]
                // src/main.rs:36
                sum += array.get(i);
        add r8d, dword ptr [rdi + 60]
        vpinsrd xmm0, xmm0, dword ptr [rdi + 12], 1
        vpunpcklqdq xmm0, xmm0, xmmword ptr [rdi + 32]
        mov ecx, edx
        shr rdx, 32
        add edx, dword ptr [rdi + 64]
        mov esi, eax
        shr rax, 32
        vpaddd xmm0, xmm0, xmmword ptr [rdi + 40]
        add eax, esi
        add eax, edx
        add eax, r8d
        add eax, ecx
        vpshufd xmm1, xmm0, 238
        vpaddd xmm0, xmm0, xmm1
        vpshufd xmm1, xmm0, 85
        vpaddd xmm0, xmm0, xmm1
        vmovd r9d, xmm0
        add r9d, dword ptr [rdi + 8]
        add eax, r9d
                // src/main.rs:39
                }
        pop rcx
        .cfi_def_cfa_offset 8
        ret
.LBB8_3:
        .cfi_def_cfa_offset 16
        xor eax, eax
        pop rcx
        .cfi_def_cfa_offset 8
        ret
.LBB8_6:
        .cfi_def_cfa_offset 16
        mov eax, r8d
        and eax, 63
        mov r9d, 64
        vpxor xmm0, xmm0, xmm0
        vpxor xmm1, xmm1, xmm1
        vpxor xmm2, xmm2, xmm2
        vpxor xmm3, xmm3, xmm3
        cmovne r9, rax
        xor eax, eax
        sub r8, r9
        .p2align        4
.LBB8_7:
                // src/main.rs:36
                sum += array.get(i);
        vpaddd zmm0, zmm0, zmmword ptr [rdx + 4*rax]
        vpaddd zmm1, zmm1, zmmword ptr [rdx + 4*rax + 64]
        vpaddd zmm2, zmm2, zmmword ptr [rdx + 4*rax + 128]
        vpaddd zmm3, zmm3, zmmword ptr [rdx + 4*rax + 192]
                // ~/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/num/uint_macros.rs:985
                intrinsics::unchecked_add(self, rhs)
        add rax, 64
                // ~/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/iter/range.rs:900
                if self.start < self.end {
        cmp r8, rax
        jne .LBB8_7
        vpaddd zmm0, zmm1, zmm0
        vpaddd zmm2, zmm3, zmm2
        vpaddd zmm0, zmm2, zmm0
        vextracti64x4 ymm1, zmm0, 1
        vpaddd zmm0, zmm0, zmm1
        vextracti128 xmm1, ymm0, 1
        vpaddd xmm0, xmm0, xmm1
        vpshufd xmm1, xmm0, 238
        vpaddd xmm0, xmm0, xmm1
        vpshufd xmm1, xmm0, 85
        vpaddd xmm0, xmm0, xmm1
        vmovd eax, xmm0
        .p2align        4
.LBB8_9:
                // ~/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/slice/index.rs:272
                &(*slice)[self]
        cmp rsi, r8
        je .LBB8_13
                // src/main.rs:36
                sum += array.get(i);
        add eax, dword ptr [rdx + 4*r8]
                // ~/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/num/uint_macros.rs:985
                intrinsics::unchecked_add(self, rhs)
        inc r8
                // ~/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/cmp.rs:1915
                fn lt(&self, other: &Self) -> bool { *self <  *other }
        cmp rcx, r8
                // ~/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/iter/range.rs:900
                if self.start < self.end {
        jne .LBB8_9
.LBB8_12:
                // src/main.rs:39
                }
        pop rcx
        .cfi_def_cfa_offset 8
        vzeroupper
        ret
.LBB8_13:
        .cfi_def_cfa_offset 16
                // ~/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/slice/index.rs:272
                &(*slice)[self]
        dec edi
        lea rdx, [rip + .Lanon.9e60ad0f5f82fa416d6bdfb18d32e3b4.10]
        cmp rsi, rdi
        cmovb rdi, rsi
        vzeroupper
        call qword ptr [rip + core::panicking::panic_bounds_check@GOTPCREL]

Summary

It looks like enum_dispatch is a valid alternative to my wrapper approach.

I can even access my extended trait here and call get_sum:

#[inline(never)]
fn compute_sum_enum(array: &FlexibleArrayEnum) -> u32 {
    match array {
        FlexibleArrayEnum::EmptyArray(_) => 0,
        FlexibleArrayEnum::VectorArray(v) => v.get_sum(),
        FlexibleArrayEnum::Array16(a) => a.get_sum(),
    }
}