How to emulate higher-kinded functions?

I'm working on an application that uses tensors whose element type is erased (it is only known at runtime). The element type might be one of the various integer and float types, or some other "special" non-primitive type.

What I would like to do is write a method that accepts a single function/closure that'll be executed for any primitive numeric type (i.e. accepting some F, where F: for<T: Num> FnMut(T)). However, that requires some sort of higher-kinded function/type (strictly speaking, a closure that is generic over T, i.e. a higher-ranked bound over types), and I have no idea how I might emulate that in Rust.
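For contrast, Rust's for<...> binder does exist today, but on stable Rust it only quantifies over lifetimes, not types. The sketch below (call_with_locals is a hypothetical name, std only) shows a higher-ranked lifetime bound that compiles, with the type-level version the question asks for left as a comment because it is not accepted:

```rust
// `for<'a>` works today: `f` must accept a `&str` of *any* lifetime,
// including one borrowed from a local that the caller never sees.
fn call_with_locals<F>(mut f: F) -> usize
where
    F: for<'a> FnMut(&'a str) -> usize,
{
    let owned = String::from("local");
    f(owned.as_str()) + f("static")
}

// The type-level analogue from the question is NOT valid stable Rust:
// fn call_with_numbers<F>(f: F) where F: for<T: Num> FnMut(T) {}

fn main() {
    // The closure's signature is inferred from the higher-ranked bound.
    let total = call_with_locals(|s| s.len());
    assert_eq!(total, "local".len() + "static".len());
}
```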

For some context, here is a simplified definition:

#[derive(Debug, Copy, Clone, PartialEq)]
enum ElementType {
    U8,
    F32,
    I64,
    String,
}

#[derive(Debug, Clone, PartialEq)]
struct Tensor {
    element_type: ElementType,
    buffer: Box<[u8]>,
}

The idea is that buffer contains the raw bytes of a slice of primitive values[1], and you use a match on element_type at runtime to reinterpret buffer as the correct type. You might construct a Tensor from primitives like so...

trait Primitive: bytemuck::NoUninit 
    + bytemuck::AnyBitPattern 
    + num_traits::Num 
{
    const ELEMENT_TYPE: ElementType;
}

impl Primitive for u8 { const ELEMENT_TYPE: ElementType = ElementType::U8; }
impl Primitive for f32 { const ELEMENT_TYPE: ElementType = ElementType::F32; }
impl Primitive for i64 { const ELEMENT_TYPE: ElementType = ElementType::I64; }

impl Tensor {
    fn new<T: Primitive>(items: &[T]) -> Self {
        let bytes: &[u8] = bytemuck::cast_slice(items);
        
        Tensor {
            element_type: T::ELEMENT_TYPE,
            buffer: bytes.to_vec().into_boxed_slice(),
        }
    }
    
    fn as_primitive<T>(&self) -> Option<&[T]>
    where T: Primitive,
    {
        if self.element_type == T::ELEMENT_TYPE {
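            // Note: cast_slice panics if buffer isn't suitably aligned for T,
            // and a Box<[u8]> is only guaranteed to be 1-byte aligned.
            Some(bytemuck::cast_slice(&self.buffer))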
        } else {
            None
        }
    }
}
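The runtime match described above can be sketched with only the standard library. This is a hypothetical helper (sum_as_f64 is not part of the original code) that decodes the buffer with from_ne_bytes instead of bytemuck. Because it copies byte chunks, it also works when the buffer is not aligned for f32/i64, a case where bytemuck::cast_slice would panic:

```rust
#[derive(Debug, Copy, Clone, PartialEq)]
enum ElementType {
    U8,
    F32,
    I64,
    String,
}

struct Tensor {
    element_type: ElementType,
    buffer: Box<[u8]>,
}

impl Tensor {
    /// Hypothetical helper: reinterpret `buffer` according to `element_type`
    /// and widen every element to `f64`. Returns `None` for non-numeric tensors.
    fn sum_as_f64(&self) -> Option<f64> {
        let buf: &[u8] = &self.buffer;
        let sum: f64 = match self.element_type {
            ElementType::U8 => buf.iter().map(|&b| f64::from(b)).sum(),
            ElementType::F32 => buf
                .chunks_exact(4)
                .map(|c| f64::from(f32::from_ne_bytes([c[0], c[1], c[2], c[3]])))
                .sum(),
            ElementType::I64 => buf
                .chunks_exact(8)
                .map(|c| {
                    let mut bytes = [0u8; 8];
                    bytes.copy_from_slice(c);
                    i64::from_ne_bytes(bytes) as f64
                })
                .sum(),
            // Strings are stored differently, so there is no numeric view.
            ElementType::String => return None,
        };
        Some(sum)
    }
}

fn main() {
    // Build the byte buffer by hand, in native byte order.
    let mut buffer = Vec::new();
    for f in [1.0_f32, 2.0, 3.0].iter() {
        buffer.extend_from_slice(&f.to_ne_bytes());
    }
    let tensor = Tensor { element_type: ElementType::F32, buffer: buffer.into_boxed_slice() };
    assert_eq!(tensor.sum_as_f64(), Some(6.0));
}
```

Each match arm is monomorphic, which is exactly the duplication the question wants to factor out into a single generic callback.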

Now, ideally we'd be able to write a Tensor::for_each_primitive() method which will try each of the element types in turn and call the correct monomorphisation.

impl Tensor {
    fn for_each_primitive<F>(&self, mut callable: F)
    where
        // Hypothetical: Rust has no generic closures, so this bound does not compile.
        F: for<T: Primitive> FnMut(T),
    {
        if let Some(bytes) = self.as_primitive::<u8>() {
            bytes.iter().for_each(|&b| callable(b));
        } else if let Some(longs) = self.as_primitive::<i64>() {
            longs.iter().for_each(|&l| callable(l));
        } else if let Some(floats) = self.as_primitive::<f32>() {
            floats.iter().for_each(|&f| callable(f));
        }
    }
}

With the idea being that you could call it like this:

use num_traits::NumCast;

fn main() {
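    let tensor = Tensor::new(&[1.0_f32, 2.0, 3.0]);
    
    // Hypothetical syntax: the closure would be generic over every T: Primitive
    // (NumCast::from also needs T: ToPrimitive and returns an Option).
    let mut sum = 0.0;
    tensor.for_each_primitive(|n| sum += <f32 as NumCast>::from(n).unwrap());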
}

I've prepared a playground link with the full (non-working) implementation.


  1. ... or strings, but they're handled in a very different way and aren't really relevant at the moment.

Rust doesn't currently have generic closures, so I would just use a regular trait with a generic method instead. Something like:

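use num_traits::ToPrimitive;

// A trait with a generic method plays the role of the generic closure.
trait ForEachPrimitiveFn {
    fn call<T: Primitive + ToPrimitive>(&mut self, value: T);
}

impl Tensor {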
    fn for_each_primitive<F>(&self, mut callable: F)
    where
        F: ForEachPrimitiveFn,
    {
        if let Some(bytes) = self.as_primitive::<u8>() {
            bytes.iter().for_each(|&b| callable.call(b));
        } else if let Some(longs) = self.as_primitive::<i64>() {
            longs.iter().for_each(|&l| callable.call(l));
        } else if let Some(floats) = self.as_primitive::<f32>() {
            floats.iter().for_each(|&f| callable.call(f));
        }
    }
}
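Below is a self-contained, std-only sketch of the same trait-based technique, so it can be run without bytemuck or num_traits. Assumptions: the Primitive trait here is a stand-in whose SIZE, decode, encode and to_f64 items are hypothetical replacements for the bytemuck/num_traits bounds, and as_primitive returns a decoded Vec<T> copy rather than a borrowed slice (which avoids the alignment panic). Summer is a hypothetical "closure written by hand": the captured &mut f64 becomes an explicit field.

```rust
#[derive(Debug, Copy, Clone, PartialEq)]
enum ElementType {
    U8,
    F32,
    I64,
}

/// Std-only stand-in for the thread's `Primitive` trait (hypothetical items).
trait Primitive: Copy {
    const ELEMENT_TYPE: ElementType;
    const SIZE: usize;
    fn decode(bytes: &[u8]) -> Self;
    fn encode(self, out: &mut Vec<u8>);
    fn to_f64(self) -> f64;
}

macro_rules! impl_primitive {
    ($t:ty, $variant:ident) => {
        impl Primitive for $t {
            const ELEMENT_TYPE: ElementType = ElementType::$variant;
            const SIZE: usize = std::mem::size_of::<$t>();
            fn decode(bytes: &[u8]) -> Self {
                // Copy into an array first, so alignment never matters.
                let mut arr = [0u8; std::mem::size_of::<$t>()];
                arr.copy_from_slice(bytes);
                <$t>::from_ne_bytes(arr)
            }
            fn encode(self, out: &mut Vec<u8>) {
                out.extend_from_slice(&self.to_ne_bytes());
            }
            fn to_f64(self) -> f64 {
                self as f64
            }
        }
    };
}

impl_primitive!(u8, U8);
impl_primitive!(f32, F32);
impl_primitive!(i64, I64);

/// The generic method stands in for a generic closure.
trait ForEachPrimitiveFn {
    fn call<T: Primitive>(&mut self, value: T);
}

struct Tensor {
    element_type: ElementType,
    buffer: Box<[u8]>,
}

impl Tensor {
    fn new<T: Primitive>(items: &[T]) -> Self {
        let mut buffer = Vec::with_capacity(items.len() * T::SIZE);
        for &item in items {
            item.encode(&mut buffer);
        }
        Tensor { element_type: T::ELEMENT_TYPE, buffer: buffer.into_boxed_slice() }
    }

    /// Decodes (copies) the elements if `T` matches the stored element type.
    fn as_primitive<T: Primitive>(&self) -> Option<Vec<T>> {
        if self.element_type == T::ELEMENT_TYPE {
            Some(self.buffer.chunks_exact(T::SIZE).map(T::decode).collect())
        } else {
            None
        }
    }

    /// Each branch monomorphises `callable.call` for one concrete type.
    fn for_each_primitive<F: ForEachPrimitiveFn>(&self, mut callable: F) {
        if let Some(bytes) = self.as_primitive::<u8>() {
            bytes.into_iter().for_each(|b| callable.call(b));
        } else if let Some(longs) = self.as_primitive::<i64>() {
            longs.into_iter().for_each(|l| callable.call(l));
        } else if let Some(floats) = self.as_primitive::<f32>() {
            floats.into_iter().for_each(|f| callable.call(f));
        }
    }
}

/// Hand-written "closure": the captured `&mut f64` is an explicit field.
struct Summer<'a>(&'a mut f64);

impl ForEachPrimitiveFn for Summer<'_> {
    fn call<T: Primitive>(&mut self, value: T) {
        *self.0 += value.to_f64();
    }
}

fn main() {
    let tensor = Tensor::new(&[1.0_f32, 2.0, 3.0]);
    let mut sum = 0.0;
    tensor.for_each_primitive(Summer(&mut sum));
    assert_eq!(sum, 6.0);
}
```

The cost of this design is boilerplate at the call site: every "closure" becomes a named struct plus an impl, because only trait methods (not closures) can be generic over T.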

Playground link with working implementation.
