I am trying to integrate the deep learning library Burn in my reinforcement learning codebase, which I have tried to keep generic over States like this boiled down example:
pub struct PointState {
    x: f32,
    y: f32,
}

pub trait State {
    fn ndims() -> usize;
}

impl State for PointState {
    fn ndims() -> usize {
        2
    }
}
use burn::tensor::{backend::Backend, Tensor};

pub struct Batch<B, S>
where
    B: Backend,
    S: State,
{
    // how do I do this?
    data: Tensor<B, S::ndims()>,
    labels: Tensor<B, 1>,
}
And now I find myself in over my head trying to figure out how to make the number of dimensions of the Tensor depend on the number of dimensions the State has.
I have an Agent struct generic over State, encapsulating the training logic as well as a replay buffer. I want Agent to periodically train the Network from the buffer, and I want to be able to initialize it with different Networks and train it on different States (i.e. environments).
I've decided to write my thesis in Rust to learn it the hard way. I come from the Python world, so please be gentle.
Even if you turn ndims into an associated const (const NDIMS: usize;), that unfortunately won't make it usable in a type in a generic context either (without nightly features):
error: generic parameters may not be used in const operations
  --> src/main.rs:LN:23
   |
LN |     data: Tensor<B, { S::NDIMS }>,
   |                       ^^^^^^^^ cannot perform const operation using `S`
   |
   = note: type parameters may not be used in const expressions
   = help: use `#![feature(generic_const_exprs)]` to allow generic const expressions
Ah, right. In this case, it's probably better to work directly with the types themselves:
pub trait State<B> {
    type Data;
}

impl<B: Backend> State<B> for PointState {
    type Data = Tensor<B, 2>;
}

pub struct Batch<B, S>
where
    B: Backend,
    S: State<B>,
{
    data: S::Data,
    labels: Tensor<B, 1>,
}
The generic argument can be moved around and/or replaced with an associated type, depending on how it needs to be used. (I can't tell from the context exactly where it needs to go; that's probably use-case-specific.)