Passing a reference into a procedural macro in Rust

As an example, take calculating the median of a sequence of integers at compile time.

I have implemented it both in C++ and in Rust.

C++: Calculating median at compile time and writing a C++ function that returns different types. · GitHub
Rust: https://github.com/jymchng/median-proc-macro/blob/master/crates/median-pm-core/src/lib.rs#L9

In C++, I am able to first instantiate an array, then pass the array as a reference into the overloaded median function and it works. Like so:

std::array<int, 5> arr1 = {1, 2, 3, 4, 5};
auto median_seven = median<5>(arr1);

However, in Rust, I cannot think of any way to do that. Like so:

let my_vec = vec![1,2,3,4,5];
let median = median!(&my_vec);

When I tried doing that, I get a syn::Ident with value "my_vec".

Generally, what is the mechanism/idiom to pass a reference or a value into a procedural macro and have the macro treat it as a reference or a value respectively?

You want a function, not a macro.


To address this point specifically: this is impossible. The compiler simply doesn't have that information at the point macros are expanded. The macro gets the tokens you pass to it, and that's it. You can't check types or look at values because none of it has been determined yet.


A macro can't compute the median of a sequence, because it has no access to the sequence. The sequence doesn't even exist at the point in compilation where the macro is expanded - it's working with source code, whereas the Vec created by the vec! macro will only exist at runtime. A macro could, at best, expand to code that will evaluate the median when run, but it's not really obvious why that would be an improvement over a function.
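To make the "tokens only" point concrete, here is a small sketch using a declarative macro (the name tokens_seen is made up for illustration). A procedural macro is in the same position: it receives a token stream, not values.

```rust
// Hypothetical helper: turns whatever tokens it is given into a string.
macro_rules! tokens_seen {
    ($($t:tt)*) => { stringify!($($t)*) };
}

fn main() {
    let my_vec = vec![1, 2, 3, 4, 5];
    // The macro receives the tokens `&` and `my_vec`, never the vector's contents.
    let seen: &str = tokens_seen!(&my_vec);
    println!("macro saw: {seen}");
    // The contents only exist once the program runs.
    println!("runtime value: {my_vec:?}");
    // The name does not even have to refer to anything for the macro to expand.
    println!("macro saw: {}", tokens_seen!(&no_such_variable));
}
```

The last line compiles even though no_such_variable is never declared, which shows that expansion happens before names and types are resolved.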

The way I'd do this kind of overloaded function in Rust would be to define a function generic in <T> that receives a slice &[T] (a &Vec<T> coerces to one), where T is constrained to types which implement Ord or PartialOrd. That function can be free-standing (median(&my_vec)), or defined via a trait as a method on vectors (my_vec.median()). Something like:

fn median<T>(vec: &[T]) -> Option<T>
where T: PartialOrd
{
  // left as an exercise
}
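One possible way to fill in that exercise, as a sketch only. The extra Copy bound and the choice to return the lower of the two middle elements for even lengths are assumptions made here, not part of the suggestion above.

```rust
use std::cmp::Ordering;

// Returns the middle element of the sorted values, or None for an empty slice.
// Assumption: for an even number of elements the lower middle element is returned,
// because Option<T> cannot hold the average of two T values in general.
fn median<T>(values: &[T]) -> Option<T>
where
    T: PartialOrd + Copy,
{
    if values.is_empty() {
        return None;
    }
    // Sort a copy so the caller's data is left untouched.
    let mut sorted = values.to_vec();
    // Incomparable pairs (such as NaN) are treated as equal in this sketch.
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
    Some(sorted[(sorted.len() - 1) / 2])
}

fn main() {
    let my_vec = vec![5, 1, 4, 2, 3];
    // A &Vec<i32> coerces to &[i32] at the call site.
    println!("{:?}", median(&my_vec));
    println!("{:?}", median(&[2.5, 0.5, 1.5]));
}
```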

Why is it that you want to use a macro instead of a function for this?


Rust doesn't have a close equivalent to C++ metaprogramming. On the flip side, C++ doesn't have a close equivalent to Rust macro metaprogramming. There's some overlap in capability, but this particular case doesn't hit that overlap unless you reformulate it:

let median = median![1, 2, 3, 4, 5];

You may be looking for this function if you want access to the length:

fn median<T, const LEN: usize>(ary: &[T; LEN]) -> Option<&T>
  where T: PartialOrd,
        [T; LEN]: Sized /* actually a wf constraint but i don't want to be confusing */
{
  todo!() // implementation omitted
}

The reason I am using a macro and not a const fn is that in C++ I can write constexpr std::conditional<(N % 2 == 0), float, int>, so that the return type of the C++ median function depends on the number of arguments passed to it.

My Rust macro is able to return i32 or f64 depending on the number of arguments passed, but it seems unable to determine whether the argument passed in is an array or a vector. I am, of course, aware that a const fn can calculate the median at compile time.

In short, I am trying to achieve two things:

  1. Computation of median at compile time. Both C++ and Rust are able to do it.

  2. Return different types based on the number of arguments passed in, i.e. if Even, return float/f64 else int/i32. Both C++ and Rust are able to do it.

BUT, C++ can do both, Rust cannot.
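For what it's worth, when the values are written as literals inside the macro call, rather than passed as a reference to an existing array, both goals can be met with a declarative macro alone. The following is a minimal sketch, not the code from the linked repository: the names sorted, middle, middle_pair_sum and the internal @parity rules are made up for illustration, and it only handles i64 values.

```rust
// Insertion sort on a fixed-size array; written with `while` so it works in const contexts.
const fn sorted<const N: usize>(mut a: [i64; N]) -> [i64; N] {
    let mut i = 1;
    while i < N {
        let mut j = i;
        while j > 0 && a[j - 1] > a[j] {
            let tmp = a[j];
            a[j] = a[j - 1];
            a[j - 1] = tmp;
            j -= 1;
        }
        i += 1;
    }
    a
}

// Middle element of the sorted array (used for odd lengths).
const fn middle<const N: usize>(a: [i64; N]) -> i64 {
    sorted(a)[N / 2]
}

// Sum of the two middle elements of the sorted array (used for even lengths).
const fn middle_pair_sum<const N: usize>(a: [i64; N]) -> i64 {
    let s = sorted(a);
    s[N / 2 - 1] + s[N / 2]
}

macro_rules! median {
    // More than two arguments left: drop a pair and recurse, which preserves parity.
    (@parity $arr:tt; $a:expr, $b:expr, $($rest:expr),+) => {
        median!(@parity $arr; $($rest),+)
    };
    // Exactly two left: the count was even, so the result is an f64.
    (@parity $arr:tt; $a:expr, $b:expr) => {{
        const MEDIAN: f64 = middle_pair_sum($arr) as f64 / 2.0;
        MEDIAN
    }};
    // Exactly one left: the count was odd, so the result is an i64.
    (@parity $arr:tt; $a:expr) => {{
        const MEDIAN: i64 = middle($arr);
        MEDIAN
    }};
    // Entry point: keep the whole array as one token tree and count the elements.
    ($($x:expr),+ $(,)?) => {
        median!(@parity [$($x),+]; $($x),+)
    };
}

fn main() {
    let odd: i64 = median![1, 2, 3, 4, 5];
    let even: f64 = median![4, 1, 3, 2];
    println!("{odd} {even}"); // prints "3 2.5"
}
```

Because each result is bound to a const item, the sorting runs during compilation. The limitation discussed in this thread still applies: the macro only sees the literal tokens, so it cannot do the same for median!(&my_vec).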

I believe the prevailing view among the Rust community is that attempting this leads quickly to madness.

Rust type-level functions take the form of traits and associated types. The trait is the "function", the type it's implemented on and the generic parameters are the "inputs", and the associated types are the "outputs".

I could start writing that up, but there's a problem: we don't really want this function to be returning i32 or f64. We want the function to be returning {integer} or {float} (the unresolved type of integer & float constants in the source code before they're resolved into a concrete type). To do that, macros are mandatory.

Also you can't write completeness proofs for const generic arithmetic yet.

trait MedianOutput {
  type Output;
}
// Even lengths produce f64, odd lengths produce the element type.
impl<T> MedianOutput for [T; 0]
  where T: num_traits::Num,
  f64: From<T>
{
  type Output = f64;
}
impl<T> MedianOutput for [T; 1]
  where T: num_traits::Num
{
  type Output = T;
}
impl<T> MedianOutput for [T; 2]
  where T: num_traits::Num,
  f64: From<T>
{
  type Output = f64;
}
impl<T> MedianOutput for [T; 3]
  where T: num_traits::Num
{
  type Output = T;
}
// repeat until you reach 12

fn main() {
    let _x: <[i32; 3] as MedianOutput>::Output;
}

Oh, I didn't know you can do

f64: From<T>

What does this mean? Usually I see trait bounds on generic types and not on concrete types like f64.

And I suppose the trait MedianOutput could then have a method called median that does the calculation of the median?

That's a terrible idea. You'll surprise users of your code, in a bad way.


That said, it is doable with some type-level boolean-inversion-based parity checking.

It means exactly what it says. f64 must implement From<T> for the given T. This still constrains T, not f64.

Bounds don't constrain the LHS – they constrain the type variables that appear in them. Just like the inequality 1 < x doesn't constrain the number 1, it constrains the variable x.
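A small sketch of such a bound in use. The function name to_f64 is made up for illustration; From and f64::from are the standard library's.

```rust
// The bound is written on the concrete type f64, but what it restricts is T:
// only types that f64 can be built from are accepted.
fn to_f64<T>(value: T) -> f64
where
    f64: From<T>,
{
    f64::from(value)
}

fn main() {
    println!("{}", to_f64(3_i32));
    println!("{}", to_f64(2.5_f32));
    // to_f64(3_i64) would be rejected: f64 does not implement From<i64>,
    // because that conversion can lose precision.
}
```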

trait ArrayLengthIsEven<const N: usize> {
    fn calc_median(arr: &mut [usize; N]) -> f64 {
        arr.sort();
        (arr[N/2] as f64 + arr[N/2 - 1] as f64) / 2.0
    }
}

trait ArrayLengthIsOdd<const N: usize> {
    fn calc_median(arr: &mut [usize; N]) -> usize {
        arr.sort();
        arr[N/2]
    }
}

impl ArrayLengthIsEven<2> for &mut [usize; 2] {}
impl ArrayLengthIsEven<4> for &mut [usize; 4] {}
impl ArrayLengthIsEven<6> for &mut [usize; 6] {}
impl ArrayLengthIsEven<8> for &mut [usize; 8] {}
impl ArrayLengthIsEven<10> for &mut [usize; 10] {}
impl ArrayLengthIsOdd<1> for &mut [usize; 1] {}
impl ArrayLengthIsOdd<3> for &mut [usize; 3] {}
impl ArrayLengthIsOdd<5> for &mut [usize; 5] {}
impl ArrayLengthIsOdd<7> for &mut [usize; 7] {}
impl ArrayLengthIsOdd<9> for &mut [usize; 9] {}

fn main() {
    const arr: [usize; 3] = [1, 2, 3];
    const arr_two: [usize; 4] = [1, 2, 3, 4];
    let median_one = <&mut [usize; 3] as ArrayLengthIsOdd::<{arr.len()}>>::calc_median(&mut arr);
    let median_two = <&mut [usize; 4] as ArrayLengthIsEven::<{arr_two.len()}>>::calc_median(&mut arr_two);
    println!("{median_one}, {median_two}");
}

I managed to write something like this: Rust Playground

It does give different types but it doesn't compute it at compile time.

Did you look at my Playground above?

Yes, I saw it here Rust Playground

Your solution is definitely way more elegant than mine. It seems way simpler too and without the use of proc-macro.


Would you mind explaining why on Line 47 you are implementing the HasMedian trait for (i64, T)?

Thank you.

The "array" of numbers is represented as a type-level linked list of the form (i64, (i64, (…, (…, i64)))). That step recursively implements the trait for every such list.

That's the whole point of the entire implementation, in fact. The definition of its associated type:

type Median = <T::Median as Invertible>::Inverse;

ensures that a list of length N+1 has a median of type f64 if the list of length N has a median of type i64, and vice versa.
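For readers without the Playground open, here is a reduced sketch of that idea, reconstructed from the description in this thread rather than copied from the Playground. The trait and associated type names follow the ones mentioned here; the helper median_is_f64 is made up for checking the result, and no actual median is computed.

```rust
use std::any::TypeId;

// Type-level "flip": i64 becomes f64 and f64 becomes i64.
trait Invertible {
    type Inverse;
}
impl Invertible for i64 {
    type Inverse = f64;
}
impl Invertible for f64 {
    type Inverse = i64;
}

// Type-level function from a list type to the type of its median.
trait HasMedian {
    type Median;
}
// Base case: a list of length 1 has an i64 median.
impl HasMedian for i64 {
    type Median = i64;
}
// Recursive case: prepending one element flips the median type of the tail.
impl<T> HasMedian for (i64, T)
where
    T: HasMedian,
    T::Median: Invertible,
{
    type Median = <T::Median as Invertible>::Inverse;
}

// Hypothetical helper: reports whether the median type of list L is f64.
fn median_is_f64<L>() -> bool
where
    L: HasMedian,
    L::Median: 'static,
{
    TypeId::of::<L::Median>() == TypeId::of::<f64>()
}

fn main() {
    println!("length 1: {}", median_is_f64::<i64>());
    println!("length 2: {}", median_is_f64::<(i64, i64)>());
    println!("length 3: {}", median_is_f64::<(i64, (i64, i64))>());
}
```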


Yeah, I noted that is the genius part of your solution! It flips the associated type via Invertible with every increment of the length of the sequence passed into the declarative macro list!(...). A single i64 has Median = i64, which serves as the base case, and each additional element flips the associated type once, so for a sequence of length N it flips N-1 times. If N is even, N-1 is odd and the flips end on f64; if N is odd, they end on i64.

Sir, would you mind sharing your thought process on arriving at this solution? Thank you.