Shrinking bitset with compile time known length

C++ has std::bitset<N>, which stores N bits directly (not in a separate heap allocation). Unfortunately, it always takes a minimum of sizeof(size_t) bytes because it's implemented as an array of size_t, so on x86-64 sizeof(std::bitset<16>) == 8. For my own C++ projects I instead use a type I created that has specializations for different ranges of N, so that it uses the narrowest integer type that makes sense. How do I do that in Rust?

AFAICT specialization in Rust isn't going to happen. I also can't define overlapping impls. Is there a better way than using macros to generate impls of all sizes up to some maximum size I care about?

Are you on a platform with fairly-efficient unaligned reads? You could just use a [u8; N/8] and have a fast path that uses larger reads for larger N.

1 Like

You can do this with typenum but it's somewhat painful to implement:

#![allow(type_alias_bounds)]

use core::{
    marker::PhantomData as P,
    ops::{Add, Range, Sub},
};

use num_traits::PrimInt;
use typenum::{
    consts::{U0, U128, U16, U32, U64, U8},
    Diff, False, Gr, IsGreater, IsLessOrEqual, LeEq, Sum, True, Unsigned,
};

/// A type-level computation: every implementor names the type it evaluates to.
trait Compute {
    /// The type this computation evaluates to (allowed to be unsized).
    type Result: ?Sized;
}

/// Type-level `if`/`else`. The condition and both candidate types are carried
/// purely as type parameters; the `Compute` impls pick one of the candidates.
struct Ternary<Condition, Then: ?Sized, Otherwise: ?Sized>(P<(Condition, P<Then>, Otherwise)>);

/// A `True` condition selects the first candidate.
impl<Then, Otherwise> Compute for Ternary<True, Then, Otherwise>
where
    Then: ?Sized,
    Otherwise: ?Sized,
{
    type Result = Then;
}

/// A `False` condition selects the second candidate.
impl<Then, Otherwise> Compute for Ternary<False, Then, Otherwise>
where
    Then: ?Sized,
    Otherwise: ?Sized,
{
    type Result = Otherwise;
}

/// Shorthand for evaluating a `Ternary`: the first candidate when the condition
/// is `True`, the second when it is `False`.
type Cond<Condition, Then, Otherwise> = <Ternary<Condition, Then, Otherwise> as Compute>::Result;

/// Maps an implementor to the primitive integer type that backs it.
trait GetStorage {
    /// The selected backing integer; it must implement `PrimInt`.
    type Storage: PrimInt;
}

/// Carrier for a typenum length `N`, so that `GetStorage` can be implemented
/// per length. It stores only a `PhantomData` of `N`.
struct TypenumTyToStorage<N>(P<N>)
where
    N: Unsigned;

/// The backing integer type selected for a length of `Bits` bits.
type StorageForLen<Bits: Unsigned> = <TypenumTyToStorage<Bits> as GetStorage>::Storage;

// A chain of type-level conditionals, listed narrowest-first. Each alias falls
// through to the next wider one when `Bits` exceeds its limit.

/// `u8` when `Bits <= 8`, otherwise whatever `CmpU16<Bits>` selects.
type CmpU8<Bits> = Cond<LeEq<Bits, U8>, u8, CmpU16<Bits>>;

/// `u16` when `Bits <= 16`, otherwise whatever `CmpU32<Bits>` selects.
type CmpU16<Bits> = Cond<LeEq<Bits, U16>, u16, CmpU32<Bits>>;

/// `u32` when `Bits <= 32`, otherwise whatever `CmpU64<Bits>` selects.
type CmpU32<Bits> = Cond<LeEq<Bits, U32>, u32, CmpU64<Bits>>;

/// `u64` when `Bits <= 64`, otherwise `u128`.
type CmpU64<Bits> = Cond<LeEq<Bits, U64>, u64, u128>;

impl<N: Unsigned> GetStorage for TypenumTyToStorage<N>
where
    // N: IsLessOrEqual<U128, Output = True>, // it's up to the user of this trait to enforce this; we return 128 if N > 64
    N: IsGreater<U0, Output = True>,

    N: IsLessOrEqual<U64>,
    Ternary<LeEq<N, U64>, u64, u128>: Compute,
    CmpU64<N>: PrimInt,

    N: IsLessOrEqual<U32>,
    Ternary<LeEq<N, U32>, u32, CmpU64<N>>: Compute,
    CmpU32<N>: PrimInt,

    N: IsLessOrEqual<U16>,
    Ternary<LeEq<N, U16>, u16, CmpU32<N>>: Compute,
    CmpU16<N>: PrimInt,

    N: IsLessOrEqual<U8>,
    Ternary<LeEq<N, U8>, u8, CmpU16<N>>: Compute,
    CmpU8<N>: PrimInt,
{
    type Storage = CmpU8<N>;
}

/// Bit-level read/write access to some bit storage.
///
/// Implementors supply the fallible `try_get`/`try_set`; the panicking `get`
/// and `set` are provided on top of them.
trait BitSetStorageAccess {
    /// Reads the bit at index `bit`, or returns `Err(())` on failure.
    fn try_get(&self, bit: usize) -> Result<bool, ()>;

    /// Writes `val` to the bit at index `bit`, or returns `Err(())` on failure.
    fn try_set(&mut self, bit: usize, val: bool) -> Result<(), ()>;

    /// Reads the bit at index `bit`.
    ///
    /// # Panics
    ///
    /// Panics if `try_get` returns `Err`.
    fn get(&self, bit: usize) -> bool {
        let outcome = self.try_get(bit);
        outcome.unwrap()
    }

    /// Writes `val` to the bit at index `bit`.
    ///
    /// # Panics
    ///
    /// Panics if `try_set` returns `Err`.
    fn set(&mut self, bit: usize, val: bool) {
        let outcome = self.try_set(bit, val);
        outcome.unwrap()
    }
}

/// Placeholder "storage" for the terminating node of a chain; it has no fields.
#[derive(Clone, Copy, Debug, Default)]
struct Empty;

/// The terminating node of a chain: `Empty` storage, an offset of `End`, a
/// length of zero, and `()` in the `Rest` position.
type Sentinel<End: Unsigned> = BitSetStorageNode<Empty, End, U0, ()>;

/// Lets `()` stand in wherever a `BitSetStorageAccess` type is required by a
/// bound. Every method treats being called as impossible.
impl BitSetStorageAccess for () {
    fn try_get(&self, _bit: usize) -> Result<bool, ()> {
        unreachable!()
    }

    fn try_set(&mut self, _bit: usize, _val: bool) -> Result<(), ()> {
        unreachable!()
    }

    fn get(&self, _bit: usize) -> bool {
        unreachable!()
    }

    fn set(&mut self, _bit: usize, _val: bool) {
        unreachable!()
    }
}

/// End-of-chain behaviour: no index is ever served by the sentinel.
///
/// `End` is the sentinel's offset parameter; it is reported as the element
/// count in the panic messages.
impl<End: Unsigned> BitSetStorageAccess for Sentinel<End> {
    /// Always fails.
    fn try_get(&self, _bit: usize) -> Result<bool, ()> {
        Err(())
    }

    /// Always fails.
    fn try_set(&mut self, _bit: usize, _val: bool) -> Result<(), ()> {
        Err(())
    }

    /// Always panics with an out-of-bounds message naming `bit`.
    fn get(&self, bit: usize) -> bool {
        let len = End::USIZE;
        panic!("out of bounds: attempted to get index {bit} in a {} element bitset", len)
    }

    /// Always panics with an out-of-bounds message naming `bit`.
    fn set(&mut self, bit: usize, _val: bool) {
        let len = End::USIZE;
        panic!("out of bounds: attempted to set index {bit} in a {} element bitset", len)
    }
}

/// One link in a chain of bit storage.
///
/// `WidthOffset` is the absolute index of this node's first bit and `Len` is
/// the number of bits it covers; both exist only at the type level.
struct BitSetStorageNode<Storage, WidthOffset, Len, Rest>
where
    WidthOffset: Unsigned,
    Len: Unsigned,
    Rest: BitSetStorageAccess,
{
    /// The value holding this node's bits.
    inner: Storage,
    /// The remainder of the chain, consulted for indices outside this node.
    rest: Rest,
    /// Ties the otherwise-unused type-level offset and length to the struct.
    _width: P<(WidthOffset, Len)>,
}

impl<S, O, L, R> BitSetStorageNode<S, O, L, R>
where
    O: Unsigned,
    L: Unsigned,
    R: BitSetStorageAccess,
{
    /// The absolute bit indices covered by this node: `O .. O + L`.
    #[inline(always)]
    const fn range() -> Range<usize> {
        let start = O::USIZE;
        Range { start, end: start + L::USIZE }
    }

    /// Converts an absolute bit index into an index local to this node.
    ///
    /// Returns `None` when `bit` falls outside this node's range.
    #[inline(always)]
    fn relative(bit: usize) -> Option<usize> {
        // The closure runs only when `bit` is in range, so the subtraction
        // cannot underflow.
        Self::range().contains(&bit).then(|| bit - O::USIZE)
    }
}

/// Bit access for a storage node: indices inside this node's range are served
/// from `inner`; every other index is forwarded to `rest`.
impl<Storage, Offs, Len, Rest> BitSetStorageAccess for BitSetStorageNode<Storage, Offs, Len, Rest>
where
    Storage: PrimInt,
    Offs: Unsigned,
    Len: Unsigned,
    Rest: BitSetStorageAccess,
{
    /// Reads bit `bit`, returning whatever `rest` returns when the index is
    /// not covered by this node.
    #[inline(always)]
    fn try_get(&self, bit: usize) -> Result<bool, ()> {
        if let Some(bit_idx) = Self::relative(bit) {
            let mask = Storage::one() << bit_idx;
            Ok((self.inner & mask) != Storage::zero())
        } else {
            self.rest.try_get(bit)
        }
    }

    /// Writes `val` to bit `bit`, returning whatever `rest` returns when the
    /// index is not covered by this node.
    fn try_set(&mut self, bit: usize, val: bool) -> Result<(), ()> {
        if let Some(bit_idx) = Self::relative(bit) {
            let mask = Storage::one() << bit_idx;
            // OR-ing in a zero leaves a previously set bit untouched, so
            // writing `false` has to clear the bit explicitly.
            self.inner = if val { self.inner | mask } else { self.inner & !mask };

            Ok(())
        } else {
            self.rest.try_set(bit, val)
        }
    }
}

/// Maps an implementor to the root of a storage-node chain.
trait GetStorageNodes {
    /// The outermost node type; it must be default-constructible and provide
    /// bit access.
    type Top: Default + BitSetStorageAccess;
}

/// Strategy type that selects the storage-node chain for `Len` bits starting at
/// bit `Offset`.
///
/// This strategy favours a short chain over a compact one: an 80-bit bitset is
/// held in a single `u128` rather than in a `u64` followed by a `u16`.
///
/// It can be swapped out for another type that chains storage nodes
/// differently; the simple strategy was picked here for brevity.
struct LenToRootStorageNodeFewestNodes<Len, Offset = U0>((Len, Offset))
where
    Len: Unsigned,
    Offset: Unsigned;

/// The root node type that the fewest-nodes strategy selects for `Len` bits
/// starting at `Offset`.
type FewestNodes<Len: Unsigned, Offset = U0> =
    <LenToRootStorageNodeFewestNodes<Len, Offset> as GetStorageNodes>::Top;

/// Helper for the fewest-nodes strategy. The third parameter is a type-level
/// boolean stating whether `Len` exceeds 128, which lets the two cases be
/// written as separate, non-overlapping impls.
struct LenToRootStorageNodeFewestNodesRecurse<Len, Offset, GreaterThan128>(
    (Len, Offset, GreaterThan128),
)
where
    Len: Unsigned,
    Offset: Unsigned;
/// `Len` is not greater than 128: there is nothing to recurse into, so the
/// result is the unit type.
impl<Len, Offs> GetStorageNodes for LenToRootStorageNodeFewestNodesRecurse<Len, Offs, False>
where
    Len: Unsigned,
    Offs: Unsigned,
{
    type Top = ();
}
/// `Len` is greater than 128: the result is the chain chosen for the remaining
/// `Len - 128` bits, starting 128 bits further along.
impl<Len: Unsigned, Offs: Unsigned> GetStorageNodes
    for LenToRootStorageNodeFewestNodesRecurse<Len, Offs, True>
where
    // `Offs + 128` must exist and be a type-level unsigned number.
    Offs: Add<U128>,
    Sum<Offs, U128>: Unsigned,

    // `Len > 128` must hold, and `Len - 128` must exist and be unsigned.
    Len: IsGreater<U128, Output = True>,
    Len: Sub<U128>,
    Diff<Len, U128>: Unsigned,

    // The strategy must be able to produce a chain for the remainder.
    LenToRootStorageNodeFewestNodes<Diff<Len, U128>, Sum<Offs, U128>>: GetStorageNodes,
{
    type Top = FewestNodes<Diff<Len, U128>, Sum<Offs, U128>>;
}

/// Final storage node, used when at most 128 bits remain: a single integer
/// sized for `Len`, followed by a sentinel positioned at `Offs + Len`.
type LastNode<Offs, Len> =
    BitSetStorageNode<StorageForLen<Len>, Offs, Len, Sentinel<Sum<Offs, Len>>>;

/// The rest of the chain when more than 128 bits remain; the recursion helper
/// subtracts 128 from the length.
type RecurseNode<Offs, Len> =
    <LenToRootStorageNodeFewestNodesRecurse<Len, Offs, Gr<Len, U128>> as GetStorageNodes>::Top;

/// The node placed at `Offs` for `Len` remaining bits: a full `u128` node
/// followed by `RecurseNode` when `Len > 128`, otherwise `LastNode`.
type StorageTop<Offs, Len> = Cond<
    Gr<Len, U128>,
    BitSetStorageNode<u128, Offs, U128, RecurseNode<Offs, Len>>,
    LastNode<Offs, Len>,
>;

/// The fewest-nodes strategy itself: the root node for `Len` bits at offset
/// `Offs` is `StorageTop<Offs, Len>`.
///
/// The bounds spell out everything `StorageTop` needs in order to be named and
/// used for a generic `Len` and `Offs`.
impl<Len: Unsigned, Offs: Unsigned> GetStorageNodes for LenToRootStorageNodeFewestNodes<Len, Offs>
where
    // The `Len > 128` comparison must exist.
    Len: IsGreater<U128>,

    // `Offs + Len` (the sentinel's position) must exist and be unsigned.
    Offs: Add<Len>,
    Sum<Offs, Len>: Unsigned,

    // The "more than 128 bits" branch must resolve to some chain.
    LenToRootStorageNodeFewestNodesRecurse<Len, Offs, Gr<Len, U128>>: GetStorageNodes,

    // The "at most 128 bits" branch must have a backing integer and bit access.
    TypenumTyToStorage<Len>: GetStorage,
    LastNode<Offs, Len>: BitSetStorageAccess,

    // The conditional between the two branches must be evaluable, and its
    // result must satisfy the requirements on `Top`.
    Ternary<
        Gr<Len, U128>,
        BitSetStorageNode<u128, Offs, U128, RecurseNode<Offs, Len>>,
        LastNode<Offs, Len>,
    >: Compute,
    StorageTop<Offs, Len>: BitSetStorageAccess + Sized + Default,
{
    type Top = StorageTop<Offs, Len>;
}

/// A bitset whose length `Len` is a type-level number.
///
/// Unless another `Storage` is named, the chain chosen by the fewest-nodes
/// strategy at offset zero is used.
struct BitSet<Len, Storage = FewestNodes<Len, U0>>
where
    Len: Unsigned,
    Storage: BitSetStorageAccess,
{
    /// The storage that all bit reads and writes are forwarded to.
    inner: Storage,
    /// Ties the type-level length to the struct.
    _len: P<Len>,
}

impl<L: Unsigned, S: BitSetStorageAccess + Default> Default for BitSet<L, S> {
    fn default() -> Self {
        Self { inner: Default::default(), _len: Default::default() }
    }
}

/// `BitSet` forwards the fallible accessors straight to its storage; the
/// panicking `get`/`set` come from the trait's provided methods.
impl<L, S> BitSetStorageAccess for BitSet<L, S>
where
    L: Unsigned,
    S: BitSetStorageAccess,
{
    fn try_get(&self, bit: usize) -> Result<bool, ()> {
        let Self { inner, .. } = self;
        inner.try_get(bit)
    }

    fn try_set(&mut self, bit: usize, val: bool) -> Result<(), ()> {
        let Self { inner, .. } = self;
        inner.try_set(bit, val)
    }
}

impl BitSet<U0, ()> {
    pub fn new<L: Unsigned>() -> BitSet<L, FewestNodes<L>>
    where
        LenToRootStorageNodeFewestNodes<L, U0>: GetStorageNodes,
    {
        BitSet {
            inner: Default::default(),
            _len: Default::default(),
        }
    }
}

A more complete example here (playground); codegen ends up being okay as well.


If you're willing to use nightly, feature(generic_const_exprs) lets you avoid much of the ceremony in the above, but realistically (if @scottmcm's suggestion doesn't work for your use case) macros are probably the most reasonable solution here.

1 Like

This topic was automatically closed 90 days after the last reply. We invite you to open a new topic if you have further questions or comments.