Question about generics and compile-time reflection from a beginner

I am trying to integrate the deep learning library Burn into my reinforcement learning codebase, which I have tried to keep generic over states, as in this boiled-down example:

pub struct PointState {
    x: f32,
    y: f32,
}

pub trait State {
    fn ndims() -> usize;
}

impl State for PointState {
    fn ndims() -> usize {
        2
    }
}


use burn::tensor::{backend::Backend, Tensor};

pub struct Batch<B, S>
where
    B: Backend,
    S: State,
{
    // how do I do this?
    data: Tensor<B, S::ndims()>,
    labels: Tensor<B, 1>,
}

And now I find myself a bit in over my head trying to figure out how to make the Tensor's number of dimensions depend on the number of dimensions the State has.

I have an Agent struct, generic over State, that encapsulates the training logic as well as a replay buffer. I want the Agent to periodically train the Network from the buffer, and I want to be able to initialize it with different Networks and train it on different States (i.e., environments).

I've decided to write my thesis in Rust to learn it the hard way. I come from the Python world, so please be gentle 🙂

It looks like you want ndims to be an associated const instead of an fn.
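A minimal sketch of that suggestion, using plain arrays instead of Burn tensors so it compiles without the crate. The names `NDIMS`, `point_as_array` and `flat_len` are illustrative only, not from the thread. The const works in types when the concrete state is known, and as an ordinary value inside generic code:

```rust
// Associated-const version of the question's `State` trait.
pub trait State {
    const NDIMS: usize;
}

pub struct PointState {
    x: f32,
    y: f32,
}

impl State for PointState {
    const NDIMS: usize = 2;
}

// Works: the type is concrete, so `PointState::NDIMS` is a plain constant
// and may appear in a type such as an array length.
fn point_as_array(p: &PointState) -> [f32; PointState::NDIMS] {
    [p.x, p.y]
}

// Also works: in generic code, `S::NDIMS` can be used as a runtime value...
fn flat_len<S: State>(batch_size: usize) -> usize {
    batch_size * S::NDIMS
}

// ...but NOT inside a type that depends on `S`, e.g. `[f32; S::NDIMS]` or
// `Tensor<B, { S::NDIMS }>`; on stable Rust that is rejected.

fn main() {
    let p = PointState { x: 1.0, y: 2.0 };
    println!("{:?} {}", point_as_array(&p), flat_len::<PointState>(32));
}
```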

Unfortunately that won’t make it usable in a type in a generic context, either (without nightly features).

error: generic parameters may not be used in const operations
  --> src/main.rs:LN:23
   |
LN |     data: Tensor<B, { S::NDIMS }>,
   |                       ^^^^^^^^ cannot perform const operation using `S`
   |
   = note: type parameters may not be used in const expressions
   = help: use `#![feature(generic_const_exprs)]` to allow generic const expressions

Ah, right. In this case, it's probably better to work directly with the types themselves:

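pub trait State<B> {
    type Data;
}

// Burn's `Tensor<B, D>` requires `B: Backend`, so the bound is needed here...
impl<B: Backend> State<B> for PointState {
    type Data = Tensor<B, 2>;
}

pub struct Batch<B, S>
where
    B: Backend, // ...and here, because of `labels: Tensor<B, 1>`
    S: State<B>,
{
    data: S::Data,
    labels: Tensor<B, 1>,
}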

The generic argument can be moved around and/or replaced with an associated type, depending on how it needs to be used. (I can't tell from the context exactly where it needs to go; that's probably use-case-specific.)
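One possible way to move the generic argument, sketched under assumptions: since Rust 1.65, generic associated types (GATs) let the backend parameter live on the associated type instead of on the trait, so `State` itself stays non-generic while each state still chooses its tensor rank. `Backend`, `Tensor`, and `NdArrayLike` below are hypothetical minimal stand-ins so the sketch compiles on its own; they are not Burn's real API, and `PointState`'s fields are omitted.

```rust
use std::marker::PhantomData;

// Hypothetical stand-ins for Burn's `Backend` trait and `Tensor<B, D>` type.
pub trait Backend {}
pub struct NdArrayLike;
impl Backend for NdArrayLike {}

pub struct Tensor<B: Backend, const D: usize> {
    shape: [usize; D],
    _backend: PhantomData<B>,
}

impl<B: Backend, const D: usize> Tensor<B, D> {
    pub fn zeros(shape: [usize; D]) -> Self {
        Tensor { shape, _backend: PhantomData }
    }

    pub fn dims(&self) -> [usize; D] {
        self.shape
    }
}

// The backend parameter sits on the associated type (a GAT), not on the trait.
pub trait State {
    type Data<B: Backend>;
}

pub struct PointState;

impl State for PointState {
    // Each state fixes the rank of its data tensor here.
    type Data<B: Backend> = Tensor<B, 2>;
}

pub struct Batch<B: Backend, S: State> {
    data: S::Data<B>,
    labels: Tensor<B, 1>,
}

fn main() {
    let batch: Batch<NdArrayLike, PointState> = Batch {
        data: Tensor::zeros([32, 2]),
        labels: Tensor::zeros([32]),
    };
    println!("{:?} {:?}", batch.data.dims(), batch.labels.dims());
}
```

In a real codebase the stand-ins would be replaced by the `burn::tensor` imports from the question. The trade-off versus `State<B>` is that an `Agent<S: State>` would not need to name the backend until it actually builds a `Batch`.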

This topic was automatically closed 90 days after the last reply. We invite you to open a new topic if you have further questions or comments.