Types not constrained in associated types

I have an implementation of a trait with some associated types:

```rust
impl <const BATCH: u16, const IN: u16, const OUT: u16>Module for Linear<IN, OUT> {
    type Input = Tensor<BATCH, IN>;
    type Output = Tensor<BATCH, OUT>;

    fn train(&mut self) {}

    fn eval(&mut self) {}

    fn forward(&mut self, input: Self::Input) -> Self::Output {
        input.matmul(&self.ws.tr()) + &self.bs
    }
}
```

For some reason this doesn't compile, because BATCH is not constrained. How can I constrain BATCH?

What is the full error message?

The const parameter must be a parameter of the implementing type, or of the trait, or an associated item derivable from otherwise constrained parameters and the trait and the implementing type. That is, an implementation must be identifiable from the implementing type and the trait alone. (RFC 0447.)

For this particular example, if it compiled, there would be 65,536 implementations of Module for each Linear<IN, OUT> (one for every possible BATCH). Modulo specialization (which is not stable and wouldn't cover this particular case anyway), there can be at most one implementation of a trait per type.

More generally, note that associated types are considered the "outputs" of an implementation: there can only be one of each for a given implementation.

Perhaps BATCH should be a parameter ("input") of Module instead. (Or of fn forward.)
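A minimal sketch of the "parameter of fn forward" option on stable Rust. Without generic associated types, the trait signature cannot mention IN and OUT through associated types, so this sketch (an assumption, not code from the thread) lifts them into trait parameters instead. `Tensor::shape` is a hypothetical helper added only so the result can be checked, and the body of `forward` is a placeholder.

```rust
// Stand-ins for the thread's types; the real ones hold data.
struct Tensor<const BATCH: u16, const OTHER: u16> {}

impl<const BATCH: u16, const OTHER: u16> Tensor<BATCH, OTHER> {
    // Hypothetical helper: report the compile-time shape.
    fn shape(&self) -> (u16, u16) {
        (BATCH, OTHER)
    }
}

struct Linear<const IN: u16, const OUT: u16> {}

// IN and OUT appear in the trait, so the impl below is fully constrained.
// BATCH is a parameter of the method, so it is chosen per call.
trait Module<const IN: u16, const OUT: u16> {
    fn forward<const BATCH: u16>(&mut self, input: Tensor<BATCH, IN>) -> Tensor<BATCH, OUT>;
}

impl<const IN: u16, const OUT: u16> Module<IN, OUT> for Linear<IN, OUT> {
    fn forward<const BATCH: u16>(&mut self, _input: Tensor<BATCH, IN>) -> Tensor<BATCH, OUT> {
        // Placeholder: a real layer would compute input.matmul(&self.ws.tr()) + &self.bs.
        Tensor {}
    }
}
```

With this shape, a single `Linear::<3, 2>` accepts both a `Tensor<10, 3>` (returning a `Tensor<10, 2>`) and a `Tensor<4, 3>` (returning a `Tensor<4, 2>`), because the batch size is a per-call input rather than part of the impl.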

```
error[E0207]: the const parameter `BATCH` is not constrained by the impl trait, self type, or predicates
  --> src/modules/linear.rs:41:7
   |
41 | impl <const BATCH: u16, const IN: u16, const OUT: u16>Module for Linear<IN, OUT> {
   |       ^^^^^^^^^^^^^^^^ unconstrained const parameter
   |
   = note: expressions using a const parameter must map each value to a distinct output value
   = note: proving the result of expressions other than the parameter are unique is not supported

For more information about this error, try `rustc --explain E0207`.
```

Is there any way to make BATCH a parameter of Input/Output instead of a parameter of the impl itself?

You can on nightly, using GATs (generic associated types). GATs have since been stabilized, in Rust 1.65.

That's close to what I want, but I need the BATCH in both the input and output to be the same. If someone passes in a tensor with a batch of 10, the output needs a batch of 10. I don't think this method enforces that.

The method enforces that through its signature:

```rust
fn forward<const BATCH: u16>(&mut self, input: Self::Input<BATCH>) -> Self::Output<BATCH>;
```

It takes a single BATCH parameter, then forwards that parameter to both Self::Input and Self::Output. This way, the input and output batch sizes must be identical for the call to type-check.
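A fuller sketch built around that signature, assuming a trait shaped like the one in the question (`train` and `eval` omitted for brevity). The trait and impl here are an illustration, not the exact code from the answer, and `Tensor::shape` is a hypothetical helper. Since GATs are stable as of Rust 1.65, this should compile on stable.

```rust
// Stand-ins for the thread's types; the real ones hold data.
struct Tensor<const BATCH: u16, const OTHER: u16> {}

impl<const BATCH: u16, const OTHER: u16> Tensor<BATCH, OTHER> {
    // Hypothetical helper: report the compile-time shape.
    fn shape(&self) -> (u16, u16) {
        (BATCH, OTHER)
    }
}

struct Linear<const IN: u16, const OUT: u16> {}

trait Module {
    // Generic associated types: the batch size is supplied where the type is used,
    // so it no longer has to be a parameter of the impl.
    type Input<const BATCH: u16>;
    type Output<const BATCH: u16>;

    // One BATCH drives both types, tying the input and output batch sizes together.
    fn forward<const BATCH: u16>(&mut self, input: Self::Input<BATCH>) -> Self::Output<BATCH>;
}

impl<const IN: u16, const OUT: u16> Module for Linear<IN, OUT> {
    type Input<const BATCH: u16> = Tensor<BATCH, IN>;
    type Output<const BATCH: u16> = Tensor<BATCH, OUT>;

    fn forward<const BATCH: u16>(&mut self, _input: Self::Input<BATCH>) -> Self::Output<BATCH> {
        // Placeholder result; a real layer would do the matmul and bias add here.
        Tensor {}
    }
}
```

For example, `layer.forward::<10>(Tensor {})` on a `Linear::<3, 2>` takes a `Tensor<10, 3>` and returns a `Tensor<10, 2>`; there is no way to request a different batch size for the output than for the input.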

As a side note, one alternative with stable Rust would be to add the BATCH parameter to Module itself (Rust Playground):

```rust
struct Tensor<const BATCH: u16, const OTHER: u16> {}
struct Linear<const IN: u16, const OUT: u16> {}

trait Module<const BATCH: u16> {
    type Input;
    type Output;
    fn train(&mut self);
    fn eval(&mut self);
    fn forward(&mut self, input: Self::Input) -> Self::Output;
}

impl<const BATCH: u16, const IN: u16, const OUT: u16> Module<BATCH> for Linear<IN, OUT> {
    type Input = Tensor<BATCH, IN>;
    type Output = Tensor<BATCH, OUT>;

    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn forward(&mut self, _input: Self::Input) -> Self::Output {
        todo!()
    }
}
```

But this might not be flexible enough, depending on how you are using the trait.
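To illustrate the trade-off: with `Module<BATCH>`, each batch size selects a different impl, so one `Linear` value can serve several batch sizes, but any code that is generic over modules has to carry BATCH in its bounds. A minimal sketch based on the playground code above, with `todo!()` replaced by a placeholder so it runs; `run` and `Tensor::shape` are hypothetical helpers, not from the thread.

```rust
struct Tensor<const BATCH: u16, const OTHER: u16> {}

impl<const BATCH: u16, const OTHER: u16> Tensor<BATCH, OTHER> {
    // Hypothetical helper: report the compile-time shape.
    fn shape(&self) -> (u16, u16) {
        (BATCH, OTHER)
    }
}

struct Linear<const IN: u16, const OUT: u16> {}

trait Module<const BATCH: u16> {
    type Input;
    type Output;
    fn forward(&mut self, input: Self::Input) -> Self::Output;
}

impl<const BATCH: u16, const IN: u16, const OUT: u16> Module<BATCH> for Linear<IN, OUT> {
    type Input = Tensor<BATCH, IN>;
    type Output = Tensor<BATCH, OUT>;

    fn forward(&mut self, _input: Self::Input) -> Self::Output {
        // Placeholder result instead of todo!(), so the example can run.
        Tensor {}
    }
}

// Hypothetical generic helper: BATCH must appear in the bound `M: Module<BATCH>`,
// so code that is generic over modules is also generic over the batch size.
fn run<const BATCH: u16, M: Module<BATCH>>(module: &mut M, input: M::Input) -> M::Output {
    module.forward(input)
}
```

Here the same `Linear::<3, 2>` can be driven through `Module<10>` and `Module<4>`, for instance via the fully qualified call `<Linear<3, 2> as Module<4>>::forward(&mut layer, Tensor {})`.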

