I'm working on an application that uses tensors whose element type is erased and only known at runtime. The element type might be one of the various integer and float types, or some other "special" non-primitive type.
What I would like to do is write a method that accepts a single function/closure to be executed for any primitive numeric type (i.e. accepting some `F` where `F: for<T: Num> FnMut(T)`). However, that requires some sort of higher-ranked function/type (a closure that is itself generic over `T`), and I have no idea how I might emulate that in Rust.
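For comparison, a `for<...>` binder (a higher-ranked trait bound) on stable Rust can only quantify over lifetimes, not types. This minimal sketch, using a hypothetical `apply_to_each` helper, compiles because the closure is required to work for *every* lifetime `'a`; replacing `'a` with a type parameter is exactly what the language rejects:

```rust
// Higher-ranked trait bound over a *lifetime*: valid on stable Rust.
// The closure must accept a reference with any lifetime 'a.
fn apply_to_each<F>(items: &[u8], mut f: F)
where
    F: for<'a> FnMut(&'a u8),
{
    for item in items {
        f(item);
    }
}

// The type-level version from the question is rejected by stable rustc:
// fn apply_generic<F>(f: F) where F: for<T: Num> FnMut(T) {}

fn main() {
    let mut total = 0u32;
    apply_to_each(&[1, 2, 3], |b| total += u32::from(*b));
    println!("{}", total); // prints 6
}
```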
For some context, here is a simplified definition:
```rust
#[derive(Debug, Copy, Clone, PartialEq)]
enum ElementType {
    U8,
    F32,
    I64,
    String,
}

#[derive(Debug, Clone, PartialEq)]
struct Tensor {
    element_type: ElementType,
    buffer: Box<[u8]>,
}
```
The idea is that `buffer` contains the underlying bytes from a slice of primitive values[1], and you use a `match` at runtime to reinterpret `buffer` as the correct type. You might construct a `Tensor` from primitives like so:
```rust
// Ties each supported primitive type to its runtime tag.
trait Primitive:
    bytemuck::NoUninit
    + bytemuck::AnyBitPattern
    + num_traits::Num
    + num_traits::ToPrimitive // lets `NumCast::from` accept any `T: Primitive`
{
    const ELEMENT_TYPE: ElementType;
}

impl Primitive for u8 { const ELEMENT_TYPE: ElementType = ElementType::U8; }
impl Primitive for f32 { const ELEMENT_TYPE: ElementType = ElementType::F32; }
impl Primitive for i64 { const ELEMENT_TYPE: ElementType = ElementType::I64; }

impl Tensor {
    fn new<T: Primitive>(items: &[T]) -> Self {
        // View the typed slice as raw bytes and copy them into the buffer.
        let bytes: &[u8] = bytemuck::cast_slice(items);
        Tensor {
            element_type: T::ELEMENT_TYPE,
            buffer: bytes.to_vec().into_boxed_slice(),
        }
    }

    fn as_primitive<T>(&self) -> Option<&[T]>
    where
        T: Primitive,
    {
        if self.element_type == T::ELEMENT_TYPE {
            // Caveat: `cast_slice` panics if the buffer isn't aligned for `T`,
            // and a `Box<[u8]>` is only guaranteed 1-byte alignment.
            Some(bytemuck::cast_slice(&self.buffer))
        } else {
            None
        }
    }
}
```
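To make the runtime `match` concrete, here is a std-only sketch that decodes `buffer` by matching on `element_type` and using each type's `from_ne_bytes` (the bytes written via `bytemuck::cast_slice` are in native byte order). It copies values out rather than borrowing, which also sidesteps the alignment requirement `bytemuck::cast_slice` places on the target type. The `Values` enum and `decode` method are hypothetical names for illustration:

```rust
use std::convert::TryInto;

#[derive(Debug, Copy, Clone, PartialEq)]
enum ElementType { U8, F32, I64, String }

struct Tensor {
    element_type: ElementType,
    buffer: Box<[u8]>,
}

// Hypothetical owned view of a decoded buffer, one variant per primitive.
#[derive(Debug, PartialEq)]
enum Values {
    U8(Vec<u8>),
    F32(Vec<f32>),
    I64(Vec<i64>),
}

impl Tensor {
    /// Reinterprets `buffer` according to the runtime `element_type`.
    /// Returns `None` for non-primitive element types such as `String`.
    fn decode(&self) -> Option<Values> {
        match self.element_type {
            ElementType::U8 => Some(Values::U8(self.buffer.to_vec())),
            ElementType::F32 => Some(Values::F32(
                self.buffer
                    .chunks_exact(4)
                    .map(|c| f32::from_ne_bytes(c.try_into().unwrap()))
                    .collect(),
            )),
            ElementType::I64 => Some(Values::I64(
                self.buffer
                    .chunks_exact(8)
                    .map(|c| i64::from_ne_bytes(c.try_into().unwrap()))
                    .collect(),
            )),
            ElementType::String => None,
        }
    }
}

fn main() {
    let bytes: Vec<u8> = [1.5f32, 2.5].iter().flat_map(|f| f.to_ne_bytes()).collect();
    let tensor = Tensor { element_type: ElementType::F32, buffer: bytes.into_boxed_slice() };
    println!("{:?}", tensor.decode()); // prints Some(F32([1.5, 2.5]))
}
```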
Now, ideally we'd be able to write a `Tensor::for_each_primitive()` method that tries each of the element types in turn and calls the correct monomorphisation:
```rust
impl Tensor {
    fn for_each_primitive<F>(&self, mut callable: F)
    where
        // Hypothetical syntax: stable Rust only allows lifetimes in `for<...>`.
        F: for<T: Primitive> FnMut(T),
    {
        // `|&x|` copies each element out so `callable` receives `T`, not `&T`.
        if let Some(bytes) = self.as_primitive::<u8>() {
            bytes.iter().for_each(|&b| callable(b));
        } else if let Some(longs) = self.as_primitive::<i64>() {
            longs.iter().for_each(|&l| callable(l));
        } else if let Some(floats) = self.as_primitive::<f32>() {
            floats.iter().for_each(|&f| callable(f));
        }
    }
}
```
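A common way to emulate `for<T: Primitive> FnMut(T)` on stable Rust is to replace the closure with a trait whose *method* is generic: because `V: PrimitiveVisitor` promises a `visit` that works for every `T: Primitive`, that bound plays the role of the higher-ranked closure. The sketch below is std-only; `PrimitiveVisitor`, `visit_as`, `Sum`, and the cut-down `Primitive` trait (with a hypothetical `to_f64` helper standing in for `num_traits`) are assumptions for illustration, not part of any library:

```rust
use std::convert::TryInto;

#[derive(Debug, Copy, Clone, PartialEq)]
enum ElementType { U8, F32, I64, String }

struct Tensor {
    element_type: ElementType,
    buffer: Box<[u8]>,
}

// Cut-down, std-only stand-in for the question's `Primitive` trait.
trait Primitive: Copy {
    fn from_bytes(bytes: &[u8]) -> Self; // decode one native-endian element
    fn to_f64(self) -> f64;              // hypothetical helper in place of NumCast
}

macro_rules! impl_primitive {
    ($($t:ty),*) => {$(
        impl Primitive for $t {
            fn from_bytes(bytes: &[u8]) -> Self {
                <$t>::from_ne_bytes(bytes.try_into().unwrap())
            }
            fn to_f64(self) -> f64 { self as f64 }
        }
    )*};
}
impl_primitive!(u8, f32, i64);

// Stand-in for `for<T: Primitive> FnMut(T)`: `visit` must accept every T.
trait PrimitiveVisitor {
    fn visit<T: Primitive>(&mut self, value: T);
}

impl Tensor {
    fn for_each_primitive<V: PrimitiveVisitor>(&self, visitor: &mut V) {
        match self.element_type {
            ElementType::U8 => self.visit_as::<u8, V>(visitor),
            ElementType::F32 => self.visit_as::<f32, V>(visitor),
            ElementType::I64 => self.visit_as::<i64, V>(visitor),
            ElementType::String => {} // non-primitive: not visited
        }
    }

    // Decodes each element as `T`; monomorphised once per primitive type.
    fn visit_as<T: Primitive, V: PrimitiveVisitor>(&self, visitor: &mut V) {
        for chunk in self.buffer.chunks_exact(std::mem::size_of::<T>()) {
            visitor.visit(T::from_bytes(chunk));
        }
    }
}

// Example visitor: sums every element as f64, whatever its concrete type.
struct Sum(f64);

impl PrimitiveVisitor for Sum {
    fn visit<T: Primitive>(&mut self, value: T) {
        self.0 += value.to_f64();
    }
}

fn main() {
    let buffer: Vec<u8> = [1.0f32, 2.0, 3.0].iter().flat_map(|f| f.to_ne_bytes()).collect();
    let tensor = Tensor { element_type: ElementType::F32, buffer: buffer.into_boxed_slice() };
    let mut sum = Sum(0.0);
    tensor.for_each_primitive(&mut sum);
    println!("{}", sum.0); // prints 6
}
```

The cost of this design is that the call site names a type (`Sum`) instead of writing a closure, and any captured state becomes fields of that type.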
The idea is that you could then call it like this:
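```rust
use num_traits::NumCast;

fn main() {
    let tensor = Tensor::new(&[1.0_f32, 2.0, 3.0]);

    // Hypothetical syntax
    let mut sum = 0.0_f32;
    // `NumCast::from` returns `Option<f32>`, so it must be unwrapped before adding.
    tensor.for_each_primitive(|n| sum += <f32 as NumCast>::from(n).unwrap());
}
```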
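Another workaround, if defining a visitor type for every call site feels heavy, is a `macro_rules!` macro that pastes the same closure-like body once per concrete type, so each copy is type-checked against a concrete `T`. The `for_each_primitive!` macro below is a hypothetical, std-only sketch (it decodes with `from_ne_bytes` rather than `bytemuck`) that mirrors the `sum` call from the hypothetical syntax above:

```rust
use std::convert::TryInto;

#[derive(Debug, Copy, Clone, PartialEq)]
enum ElementType { U8, F32, I64 }

struct Tensor {
    element_type: ElementType,
    buffer: Box<[u8]>,
}

// Expands `$body` once per primitive type, with `$n` bound to a value of that type.
macro_rules! for_each_primitive {
    ($tensor:expr, |$n:ident| $body:expr) => {{
        let t = &$tensor;
        match t.element_type {
            ElementType::U8 => for &$n in t.buffer.iter() { $body; },
            ElementType::F32 => for c in t.buffer.chunks_exact(4) {
                let $n = f32::from_ne_bytes(c.try_into().unwrap());
                $body;
            },
            ElementType::I64 => for c in t.buffer.chunks_exact(8) {
                let $n = i64::from_ne_bytes(c.try_into().unwrap());
                $body;
            },
        }
    }};
}

fn main() {
    let buffer: Vec<u8> = [1.0f32, 2.0, 3.0].iter().flat_map(|f| f.to_ne_bytes()).collect();
    let tensor = Tensor { element_type: ElementType::F32, buffer: buffer.into_boxed_slice() };
    let mut sum = 0.0f32;
    // `n` is u8, f32 or i64 depending on the arm; `as f32` compiles for all three.
    for_each_primitive!(tensor, |n| sum += n as f32);
    println!("{}", sum); // prints 6
}
```

The trade-off is that the body is duplicated per type, so it must compile for every type (here `n as f32` does), and compiler errors point into the macro expansion.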
I've prepared a playground link with the full (non-working) implementation.

[1] ... or strings, but they're handled in a very different way and aren't really relevant at the moment.