Generic operator overloading

Hi, I have a wrapper around Rust's fixed-size array,

pub struct Array<T, const N: usize>(pub [T; N]);

and I want to support arithmetic operations such as Mul, Add, etc.

By implementing

// 1
impl<T, const N: usize> Mul for Array<T, N>

I got component-wise multiplication between two Arrays. Then I implemented

// 2
impl<T, const N: usize> Mul<T> for Array<T, N>

to support multiplication between an Array and its element type (which can be a scalar type or a nested Array).
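A minimal sketch of how #1 and #2 might look together. The bounds `T: Mul<Output = T> + Copy` are an assumption, since the post doesn't show them. The two impls don't conflict because #1's right-hand side `Array<T, N>` could only equal #2's `T` if `T` were an infinitely nested type, so the compiler can rule out any overlap:

```rust
use std::ops::Mul;

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Array<T, const N: usize>(pub [T; N]);

// #1: component-wise Array<T, N> * Array<T, N>
impl<T: Mul<Output = T> + Copy, const N: usize> Mul for Array<T, N> {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        Array(std::array::from_fn(|i| self.0[i] * rhs.0[i]))
    }
}

// #2: Array<T, N> * T, i.e. every element multiplied by one value of the
// element type (a scalar, or a whole inner Array when the Array is nested)
impl<T: Mul<Output = T> + Copy, const N: usize> Mul<T> for Array<T, N> {
    type Output = Self;
    fn mul(self, rhs: T) -> Self {
        Array(self.0.map(|x| x * rhs))
    }
}
```

With these two impls, `Array<Array<i32, 2>, 2> * Array<i32, 2>` goes through #2 with `T = Array<i32, 2>`, which in turn uses #1 on each row.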

So far so good, but when I tried to implement multiplication between an Array and a scalar like this:

// 3
impl<T, U, const N: usize> Mul<U> for Array<T, N>

the compiler complains about conflicting implementations between #2 and #3. If I'm not mistaken, that's because implementations #2 and #3 overlap when the Array is not nested (e.g. Array<f32, 4>). I tried adding a Num bound to U, but it doesn't help. So, what's wrong here?

Here is the full code: Rust Playground.

Mul<i32> for Array<i32, 5> is implemented by both #2 and #3.

In the complete code, I tried to constrain T so that it cannot be a scalar type (the trait NdArray is only implemented by Array):

impl<T: Copy + NdArray, U: Num, const N: usize> Mul<U> for Array<T, N>

It still gives me the same error.

It's a limitation of the type system: the coherence check has to stay robust to the possibility that you might later add impl NdArray for i32 or impl Num for Array, and then #2 and #3 would overlap.

Do you need #2 at all? #3 seems to cover the case of #2.

I thought #2 would handle cases like Array<Array<T, 3>, 3> * Array<T, 3>. It also supports scalar multiplication when the array is not nested, e.g. Array<i32, 128> * i32. But Array<Array<i32, 3>, 3> * i32 would not be implemented.

If I remove all the bounds in #3, like this

impl<T, U, const N: usize> Mul<U> for Array<T, N>

then it conflicts with #1, because U can also be an Array.
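One common workaround (not something proposed in this thread) is to drop the generic scalar impls and instead implement `Mul<S>` once per concrete scalar type `S` with a macro. Because the right-hand side is then a concrete primitive rather than a type parameter, it can never be an `Array`, so it can't overlap with #1, and the `T: Mul<S, Output = T>` bound lets it recurse through nested arrays. The macro name `impl_scalar_mul` and the list of scalar types are illustrative assumptions:

```rust
use std::ops::Mul;

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Array<T, const N: usize>(pub [T; N]);

// #1: component-wise Array * Array (unchanged)
impl<T: Mul<Output = T> + Copy, const N: usize> Mul for Array<T, N> {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        Array(std::array::from_fn(|i| self.0[i] * rhs.0[i]))
    }
}

// Hypothetical helper: one scalar impl per concrete type. For a nested
// Array<Array<f32, M>, N>, `T = Array<f32, M>` satisfies `T: Mul<f32>`
// through the same impl one level down, so any nesting depth works.
macro_rules! impl_scalar_mul {
    ($($s:ty),*) => {$(
        impl<T: Mul<$s, Output = T>, const N: usize> Mul<$s> for Array<T, N> {
            type Output = Self;
            fn mul(self, rhs: $s) -> Self {
                Array(self.0.map(|x| x * rhs))
            }
        }
    )*};
}

impl_scalar_mul!(i32, i64, f32, f64);
```

This covers Array<Array<f32, 2>, 2> * f32, the case #2 couldn't handle. The trade-off is that #2's generic Array<T, N> * T has to go, because with T = f32 it would overlap with the Mul<f32> impl.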

At a high level, I think you're trying to do too many things with the same operator. In particular, since it supports both point-wise multiplication and all-vs-one multiplication, it's not clear what Array<Array<T, 3>, 3> * Array<T, 3> is supposed to mean: it could mean either multiplying each element of the first array by the corresponding element of the second, or multiplying each element of the first array by the whole second array.
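To make the ambiguity concrete, here is a sketch with the two readings written as separate functions with hypothetical names (`mul_row_by_element` and `mul_row_by_whole`) on a square N x N array, where both readings type-check. For [[1, 2], [3, 4]] and [10, 100], the first gives [[10, 20], [300, 400]] and the second gives [[10, 200], [30, 400]]:

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Array<T, const N: usize>(pub [T; N]);

// Reading 1 (hypothetical name): pair row i of `a` with element i of `b`,
// i.e. treat `b` as one scalar per row.
pub fn mul_row_by_element<const N: usize>(
    a: Array<Array<i32, N>, N>,
    b: Array<i32, N>,
) -> Array<Array<i32, N>, N> {
    Array(std::array::from_fn(|i| Array(a.0[i].0.map(|x| x * b.0[i]))))
}

// Reading 2 (hypothetical name): multiply every row of `a` component-wise
// by the whole of `b`.
pub fn mul_row_by_whole<const N: usize>(
    a: Array<Array<i32, N>, N>,
    b: Array<i32, N>,
) -> Array<Array<i32, N>, N> {
    Array(a.0.map(|row| Array(std::array::from_fn(|j| row.0[j] * b.0[j]))))
}
```

A single `*` can only pick one of these for a given pair of types, which is why overloading it for both meanings gets confusing.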


Actually, I want to implement broadcasting when two arrays have different shapes.
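If the goal is NumPy-style broadcasting (aligning trailing dimensions), one possible sketch for the 2-D case, not taken from the thread, is to keep #1 and add a dedicated Array<Array<S, M>, N> * Array<S, M> impl that multiplies every row by the vector. It doesn't overlap with #1 for the same reason #1 and #2 don't (the right-hand side would have to contain itself), and since its right-hand side is always an Array, it should also coexist with concrete scalar impls such as Mul<f32>. This covers only two levels of nesting; 3-D and deeper shapes would need further impls:

```rust
use std::ops::Mul;

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Array<T, const N: usize>(pub [T; N]);

// #1: component-wise, shapes must match exactly: (N,) * (N,) or (N, M) * (N, M)
impl<T: Mul<Output = T> + Copy, const N: usize> Mul for Array<T, N> {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        Array(std::array::from_fn(|i| self.0[i] * rhs.0[i]))
    }
}

// Broadcast (N, M) * (M,): the vector is stretched along the outer axis,
// so each row is multiplied component-wise by `rhs` using #1.
impl<S: Mul<Output = S> + Copy, const M: usize, const N: usize> Mul<Array<S, M>>
    for Array<Array<S, M>, N>
{
    type Output = Self;
    fn mul(self, rhs: Array<S, M>) -> Self {
        Array(self.0.map(|row| row * rhs))
    }
}
```

Here the shape of the right-hand side alone decides which impl applies, so there is no ambiguity between (N, M) * (N, M) and (N, M) * (M,).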
