Any idea why the following does not work?
#![feature(generic_const_exprs)]
#![feature(negative_impls, with_negative_coherence)]
/// Uninhabited type-level carrier for a compile-time boolean `C`.
///
/// It has no variants, so it can never be constructed; it exists only so that
/// trait bounds can be attached to `Assert<true>` / `Assert<false>`.
enum Assert<const C: bool> {}
/// Marker trait with no items, used as a bound on `Assert<..>`.
trait IsTrue {}
// In this snippet the only positive `IsTrue` impl is for `Assert<true>`.
impl IsTrue for Assert<true> {}
/// Marker trait with no items, used as a bound on `Assert<..>`.
trait IsFalse {}
// In this snippet the only positive `IsFalse` impl is for `Assert<false>`.
impl IsFalse for Assert<false> {}
// Negative blanket impl: every type (sized or not) that implements `IsTrue`
// is declared to NOT implement `IsFalse`. Negative impls are unstable and are
// enabled by the `negative_impls` feature attribute at the top of the file.
impl<T: ?Sized + IsTrue> !IsFalse for T {}
// Mirror image: every type that implements `IsFalse` is declared to NOT
// implement `IsTrue`.
impl<T: ?Sized + IsFalse> !IsTrue for T {}
/// Zero-sized stand-in for a tensor, tagged only by the const parameter `R`.
///
/// It carries no data, so it is `Copy`: passing it by value (as `forward`
/// does) does not invalidate the caller's binding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Tensor<const R: usize>;
/// An operation that turns a `Tensor<RHS>` into a `Tensor<OUT>`.
///
/// Both the input tag `RHS` and the output tag `OUT` are const parameters of
/// the trait itself, so one type may implement `Module` for several
/// `(RHS, OUT)` combinations.
trait Module<const RHS: usize, const OUT: usize> {
    /// Takes the input tensor `x` by value and returns the output tensor.
    fn forward(&self, x: Tensor<RHS>) -> Tensor<OUT>;
}
/// A layer holding a single weight tensor; the const parameter `R` is the
/// tag of that weight.
struct Layer<const R: usize> {
    /// The layer's weight, a `Tensor<R>`.
    weight: Tensor<R>,
}
/// `forward` for the case `LHS >= RHS`: the output keeps the layer's own
/// const parameter `LHS`.
///
/// NOTE(review): when `LHS == RHS`, this impl's header
/// (`Module<RHS, LHS> for Layer<LHS>`) and the header of the `IsFalse`-bounded
/// impl (`Module<RHS, RHS> for Layer<LHS>`) name the same trait instance. The
/// two impls are therefore only disjoint if the compiler can prove the
/// `IsTrue` / `IsFalse` bounds mutually exclusive. That is presumably where
/// the reported failure comes from; confirm against the actual compiler error.
impl<const LHS: usize, const RHS: usize> Module<RHS, LHS> for Layer<LHS>
where
    Assert<{ LHS >= RHS }>: IsTrue,
{
    /// Prints which impl was selected and returns a `Tensor<LHS>`.
    ///
    /// The input is intentionally ignored (hence `_x`); only its type takes
    /// part in impl selection.
    fn forward(&self, _x: Tensor<RHS>) -> Tensor<LHS> {
        println!("impl 1: {} >= {}", LHS, RHS);
        Tensor::<LHS>
    }
}
/// `forward` for the case where `LHS >= RHS` is false: the output takes the
/// input's const parameter `RHS`.
///
/// NOTE(review): when `LHS == RHS`, this impl's header
/// (`Module<RHS, RHS> for Layer<LHS>`) and the header of the `IsTrue`-bounded
/// impl (`Module<RHS, LHS> for Layer<LHS>`) name the same trait instance. The
/// two impls are therefore only disjoint if the compiler can prove the
/// `IsTrue` / `IsFalse` bounds mutually exclusive. That is presumably where
/// the reported failure comes from; confirm against the actual compiler error.
impl<const LHS: usize, const RHS: usize> Module<RHS, RHS> for Layer<LHS>
where
    Assert<{ LHS >= RHS }>: IsFalse,
{
    /// Prints which impl was selected and returns a `Tensor<RHS>`.
    ///
    /// The input is intentionally ignored (hence `_x`); only its type takes
    /// part in impl selection.
    fn forward(&self, _x: Tensor<RHS>) -> Tensor<RHS> {
        println!("impl 2: {} !>= {}", LHS, RHS);
        Tensor::<RHS>
    }
}
/// Exercises both `Module` impls on a `Layer<2>`.
fn main() {
    let layer = Layer { weight: Tensor::<2> };

    // Input tagged 1: `2 >= 1` holds, so the output is expected to be `Tensor<2>`.
    assert_eq!(Tensor::<2>, layer.forward(Tensor::<1>));

    // Input tagged 3: `2 >= 3` does not hold, so the output is expected to be `Tensor<3>`.
    assert_eq!(Tensor::<3>, layer.forward(Tensor::<3>));
}
…while this version works: playground