I'm trying to use const generics to check tensor shapes at compile time. I have a fairly simple example where I need separate impls of Mul for each combination of tensor shapes: 2 dims times 2 dims, 1 dim times 1 dim, 2 dims times 1 dim, and 1 dim times 2 dims. So I wrote this:
#![feature(specialization)] // nightly only; the `default` items below need this gate

use std::ops::Mul;

struct Tensor<const D1: usize, const D2: usize> {}

// 2 x 2
impl<const D1: usize, const D2: usize> Mul<&Tensor<D1, D2>> for Tensor<D1, D2> {
    default type Output = Tensor<D1, D2>;

    default fn mul(self, rhs: &Tensor<D1, D2>) -> Self::Output {
        todo!()
    }
}

// 2 x 1
impl<const D1: usize, const D2: usize> Mul<&Tensor<D1, 1>> for Tensor<D1, D2> {
    type Output = Tensor<D1, 1>;

    fn mul(self, rhs: &Tensor<D1, 1>) -> Self::Output {
        todo!()
    }
}
but I get an error on the 2x1 impl:
conflicting implementations of trait std::ops::Mul<&Tensor<{_: usize}, 1_usize>> for type Tensor<{_: usize}, 1_usize>
conflicting implementation for Tensor<{_: usize}, 1_usize>
I also tried to mark the 2x2 impl with default:
default impl<const D1: usize, const D2: usize> Mul<&Tensor<D1, D2>> for Tensor<D1, D2> {
    default type Output = Tensor<D1, D2>;

    default fn mul(self, rhs: &Tensor<D1, D2>) -> Self::Output {
        todo!()
    }
}
but this seems to make trait resolution recurse, and I get this error instead:
overflow evaluating the requirement Tensor<{_: usize}, 1_usize>: std::ops::Mul<&Tensor<{_: usize}, 1_usize>>
I'm not sure why these conflict, since one implementation is more specific than the other (D2 is fixed to 1 in the 2x1 impl).
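To narrow down what the compiler might be objecting to, here is a small stable-Rust sketch of how the two impl headers relate. GeneralHeader, ColumnHeader and the two helper functions are made-up names for illustration only (they are not in my real code); each marker impl just copies one of the Mul impl headers above. If I'm reading it right, both headers cover Tensor<3, 1> times &Tensor<3, 1>, but each one also covers a pair that the other does not, so the 2x1 header may not be a strict subset of the 2x2 one after all:

```rust
// Same shape as the struct in the question.
struct Tensor<const D1: usize, const D2: usize> {}

// Hypothetical marker traits (illustration only); each impl below copies
// one of the two Mul impl headers.
trait GeneralHeader<Rhs> {}
trait ColumnHeader<Rhs> {}

// Header of the "2 x 2" impl: Self and Rhs share both dimensions.
impl<const D1: usize, const D2: usize> GeneralHeader<&Tensor<D1, D2>> for Tensor<D1, D2> {}
// Header of the "2 x 1" impl: Rhs has its second dimension fixed to 1.
impl<const D1: usize, const D2: usize> ColumnHeader<&Tensor<D1, 1>> for Tensor<D1, D2> {}

// These calls only type-check for (L, R) pairs that the matching header covers.
fn general_covers<L: GeneralHeader<R>, R>() -> bool {
    true
}
fn column_covers<L: ColumnHeader<R>, R>() -> bool {
    true
}

fn main() {
    // Covered by both headers: this is the pair named in the error message.
    assert!(general_covers::<Tensor<3, 1>, &Tensor<3, 1>>());
    assert!(column_covers::<Tensor<3, 1>, &Tensor<3, 1>>());

    // Covered only by the 2 x 2 header.
    assert!(general_covers::<Tensor<3, 2>, &Tensor<3, 2>>());
    // column_covers::<Tensor<3, 2>, &Tensor<3, 2>>(); // does not compile

    // Covered only by the 2 x 1 header.
    assert!(column_covers::<Tensor<3, 2>, &Tensor<3, 1>>());
    // general_covers::<Tensor<3, 2>, &Tensor<3, 1>>(); // does not compile
}
```

So my guess is that "more specific" has to hold for every Self/Rhs pair the 2x1 impl covers, not just for D2, but I'd appreciate confirmation.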
Any help is much appreciated!