Cartesian product in pattern matching with enum

I have an enum storing various numeric types, let's say

enum MyNum {
    Int(u8),
    BigInt(u64),
    Float(f64),
}

(I have even more variants in my actual code.) I want to implement PartialOrd for this enum, but the only way I can think of is to enumerate all the cases manually:

use std::cmp::Ordering;

// PartialOrd requires PartialEq, so define equality via the same comparison
impl PartialEq for MyNum {
    fn eq(&self, other: &Self) -> bool {
        self.partial_cmp(other) == Some(Ordering::Equal)
    }
}

impl PartialOrd for MyNum {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        match (self, other) {
            (MyNum::Int(l), MyNum::Int(r)) => l.partial_cmp(r),
            (MyNum::Int(l), MyNum::BigInt(r)) => (*l as u64).partial_cmp(r),
            (MyNum::Int(l), MyNum::Float(r)) => (*l as f64).partial_cmp(r),
            (MyNum::BigInt(l), MyNum::Int(r)) => l.partial_cmp(&(*r as u64)),
            (MyNum::BigInt(l), MyNum::BigInt(r)) => l.partial_cmp(r),
            (MyNum::BigInt(l), MyNum::Float(r)) => (*l as f64).partial_cmp(r),
            (MyNum::Float(l), MyNum::Int(r)) => l.partial_cmp(&(*r as f64)),
            (MyNum::Float(l), MyNum::BigInt(r)) => l.partial_cmp(&(*r as f64)),
            (MyNum::Float(l), MyNum::Float(r)) => l.partial_cmp(r),
        }
    }
}

Is there an easier, simpler way to implement this? I have two possible solutions in mind, but I'm not sure whether they are feasible in Rust.

The first one is to somehow allow a single pattern that matches the same variant on both sides:

match (self, other) {
    // Hypothetical syntax: `MyNum::_(..)` is not valid Rust
    (MyNum::_(l), MyNum::_(r)) => l.partial_cmp(r),
    (MyNum::Int(l), MyNum::BigInt(r)) => (*l as u64).partial_cmp(r),
    (MyNum::Int(l), MyNum::Float(r)) => (*l as f64).partial_cmp(r),
    (MyNum::BigInt(l), MyNum::Int(r)) => l.partial_cmp(&(*r as u64)),
    (MyNum::BigInt(l), MyNum::Float(r)) => (*l as f64).partial_cmp(r),
    (MyNum::Float(l), MyNum::Int(r)) => l.partial_cmp(&(*r as f64)),
    (MyNum::Float(l), MyNum::BigInt(r)) => l.partial_cmp(&(*r as f64)),
}

This would save 3 branches and generalizes better; it's related to this post. The second idea is to somehow let a macro or function iterate over all the variants:

enumerate!(self, other, |l, r| (*l as f64).partial_cmp(&(*r as f64)))

I guess both solutions could be accomplished with macros, but I'm not that experienced with macros. If there is any lesser-known syntax or a crate that could help, please let me know! Thank you!
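Rust has no wildcard for the variant name in a pattern, but the first idea can be approximated with a small macro_rules! macro that writes the same-variant arms for you, while the mixed pairs are still spelled out by hand. A minimal sketch; the macro name match_same_variant and its input syntax are made up for illustration:

```rust
use std::cmp::Ordering;

enum MyNum {
    Int(u8),
    BigInt(u64),
    Float(f64),
}

// Hypothetical helper macro: generates one `(V(l), V(r)) => same` arm per
// listed variant, then pastes in the caller's hand-written arms for mixed pairs.
macro_rules! match_same_variant {
    (
        ($lhs:expr, $rhs:expr),
        [$($variant:path),* $(,)?],
        |$l:ident, $r:ident| $same:expr,
        { $($mixed:tt)* }
    ) => {
        match ($lhs, $rhs) {
            // `$same` is copied into every arm and type-checked separately,
            // so `l`/`r` may have a different payload type per variant.
            $( ($variant($l), $variant($r)) => $same, )*
            $($mixed)*
        }
    };
}

impl PartialEq for MyNum {
    fn eq(&self, other: &Self) -> bool {
        self.partial_cmp(other) == Some(Ordering::Equal)
    }
}

impl PartialOrd for MyNum {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        match_same_variant!(
            (self, other),
            [MyNum::Int, MyNum::BigInt, MyNum::Float],
            |l, r| l.partial_cmp(r),
            {
                (MyNum::Int(l), MyNum::BigInt(r)) => (*l as u64).partial_cmp(r),
                (MyNum::Int(l), MyNum::Float(r)) => (*l as f64).partial_cmp(r),
                (MyNum::BigInt(l), MyNum::Int(r)) => l.partial_cmp(&(*r as u64)),
                (MyNum::BigInt(l), MyNum::Float(r)) => (*l as f64).partial_cmp(r),
                (MyNum::Float(l), MyNum::Int(r)) => l.partial_cmp(&(*r as f64)),
                (MyNum::Float(l), MyNum::BigInt(r)) => l.partial_cmp(&(*r as f64)),
            }
        )
    }
}
```

With n variants this still leaves n(n-1) mixed arms to write by hand; only the n same-variant arms are generated.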

One problem with your example is that different types, like u8, u64 and f32, don’t support direct comparison (via PartialOrd) with each other, so your example code that enumerates all the cases manually doesn’t compile.

I’m happy to help with macros, but it would be easier if there was some kind of goal / desired resulting code that actually compiles.


Apologies for not checking whether the example compiles. I've updated the post, thanks for your help!

Note that unconditionally converting to f64 would be a bad idea for comparing u64 values above 2^53: f64 has a 53-bit significand, so beyond that point it can no longer represent every integer exactly, and distinct u64 values can round to the same f64.

E.g.

fn main() {
    dbg!(1_000_000_000_000_000_000_u64 < 1_000_000_000_000_000_001_u64);
    dbg!((1_000_000_000_000_000_000_u64 as f64) < (1_000_000_000_000_000_001_u64 as f64));
}
[src/main.rs:2] 1_000_000_000_000_000_000_u64 < 1_000_000_000_000_000_001_u64 = true
[src/main.rs:3] (1_000_000_000_000_000_000_u64 as f64) < (1_000_000_000_000_000_001_u64 as f64) = false

And comparing u64 with f64 is nontrivial, too, because you need to decide how to handle the f64 values. If you want to interpret large f64s as their exact integer value, you probably want comparing BigInt(1_000_000_000_000_000_001_u64) with Float(1_000_000_000_000_000_000_f64) not to return “equal”, analogous to the example above. This is also necessary if you want comparisons to be transitive, as the documentation of PartialOrd and PartialEq requires. If Float(1_000_000_000_000_000_000_f64) == BigInt(1_000_000_000_000_000_001_u64) returned true, and so did BigInt(1_000_000_000_000_000_000_u64) == Float(1_000_000_000_000_000_000_f64), while BigInt(1_000_000_000_000_000_001_u64) == BigInt(1_000_000_000_000_000_000_u64) returned false, then a == b and b == c would no longer imply a == c.

use std::cmp::Ordering;

enum MyNum {
    Int(u8),
    BigInt(u64),
    Float(f64),
}

impl PartialEq for MyNum {
    fn eq(&self, other: &Self) -> bool {
        self.partial_cmp(other) == Some(Ordering::Equal)
    }
}

impl PartialOrd for MyNum {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        match (self, other) {
            (MyNum::Int(l), MyNum::Int(r)) => l.partial_cmp(r),
            (MyNum::Int(l), MyNum::BigInt(r)) => (*l as u64).partial_cmp(r),
            (MyNum::Int(l), MyNum::Float(r)) => (*l as f64).partial_cmp(r),
            (MyNum::BigInt(l), MyNum::Int(r)) => l.partial_cmp(&(*r as u64)),
            (MyNum::BigInt(l), MyNum::BigInt(r)) => l.partial_cmp(r),
            (MyNum::BigInt(l), MyNum::Float(r)) => (*l as f64).partial_cmp(r),
            (MyNum::Float(l), MyNum::Int(r)) => l.partial_cmp(&(*r as f64)),
            (MyNum::Float(l), MyNum::BigInt(r)) => l.partial_cmp(&(*r as f64)),
            (MyNum::Float(l), MyNum::Float(r)) => l.partial_cmp(r),
        }
    }
}

fn main() {
    use MyNum::*;
    let a = BigInt(1_000_000_000_000_000_000_u64);
    let b = Float(1_000_000_000_000_000_000_f64);
    let c = BigInt(1_000_000_000_000_000_001_u64);
    dbg!(a == b, b == c, a == c);
}
[src/main.rs:36] a == b = true
[src/main.rs:36] b == c = true
[src/main.rs:36] a == c = false

Rust Playground
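One way to get the exact behavior described above is to never round the integer at all: handle NaN and out-of-range floats first, then compare integer parts and let a fractional remainder break ties. A minimal sketch covering only the BigInt(u64) vs Float(f64) pair; the function name cmp_u64_f64 is hypothetical:

```rust
use std::cmp::Ordering;

// Hypothetical helper: compares a u64 with an f64 exactly, without rounding the u64.
fn cmp_u64_f64(a: u64, b: f64) -> Option<Ordering> {
    // NaN is unordered with everything.
    if b.is_nan() {
        return None;
    }
    // Every u64 is >= 0, so any negative float (including -inf) is smaller.
    if b < 0.0 {
        return Some(Ordering::Greater);
    }
    // 2^64 is exactly representable as f64; anything at or above it exceeds u64::MAX.
    if b >= 18_446_744_073_709_551_616.0 {
        return Some(Ordering::Less);
    }
    // Now 0 <= b < 2^64, so its integer part converts to u64 without loss.
    let int_part = b.trunc();
    match a.cmp(&(int_part as u64)) {
        // Same integer part: a fractional remainder makes the float larger.
        Ordering::Equal if b > int_part => Some(Ordering::Less),
        ord => Some(ord),
    }
}
```

To use it in the enum, the BigInt/Float arm could call cmp_u64_f64(*l, *r) and the Float/BigInt arm cmp_u64_f64(*r, *l).map(Ordering::reverse). With that change, a == b in the example above stays true while b == c becomes false, so the three values compare transitively.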

Thanks for your detailed explanation!

I understand the risk of converting u64 to f64; I only used these numeric types as an example. If it's more convenient for you, we can use u32 for BigInt, or you can simply assume that I can handle comparisons between integers and floats safely (using some wrappers). What I want to focus on here is how to simplify the match expression, since I have a lot of custom numeric types.

Alright, I’ll write a macro then for a float-less case, for demonstration purposes 🙂


Actually, note that as long as there’s some canonical representation that all variants can be converted into, you could just define a helper function that produces that common representation. E.g. a function

impl MyNum {
    // Convert every variant into the common representation (u64)
    fn as_u64(&self) -> u64 {
        match self {
            MyNum::Int(x) => *x as u64,
            MyNum::BigInt(x) => *x,
        }
    }
}

would then be sufficient to implement PartialOrd.
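With such a helper in place, both traits reduce to a single comparison each, no matter how many variants there are. A sketch for the float-less, two-variant enum assumed in this part of the thread:

```rust
use std::cmp::Ordering;

// Float-less version of the enum, as assumed in this part of the thread.
enum MyNum {
    Int(u8),
    BigInt(u64),
}

impl MyNum {
    // Map every variant into the shared representation.
    fn as_u64(&self) -> u64 {
        match self {
            MyNum::Int(x) => *x as u64,
            MyNum::BigInt(x) => *x,
        }
    }
}

impl PartialEq for MyNum {
    fn eq(&self, other: &Self) -> bool {
        self.as_u64() == other.as_u64()
    }
}

impl PartialOrd for MyNum {
    // One comparison covers every pair of variants.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.as_u64().partial_cmp(&other.as_u64())
    }
}
```

Adding a variant then only means adding one arm to as_u64 instead of a whole row and column of match arms.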

Anyway, I’ll assume that there are use cases that really benefit from a Cartesian product of patterns, with one syntactically identical expression handling every case. Here’s an example macro for that:

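macro_rules! enumerate {
    // Entry point: `[variants], exprs..., => |vars...| body`.
    // Zips each expression with its variable into `(expr, var)` pairs.
    ([$($variant:path),* $(,)*], $($vals:expr,)+$(,)? => |$($vars:ident),+$(,)?| $e:expr) => {
        enumerate! {
            #[$($variant,)*][$(($vals, $vars))+]$e
        }
    };
    // Internal (`#`): at least one pair left; pass the variant list on both
    // as one token tree and as individual paths.
    (#[$($variant:path,)*][$pair:tt $($pairs:tt)*]$e:expr) => {
        enumerate! {
            ##[$($variant,)*][$($variant,)*]$pair[$($pairs)*]$e
        }
    };
    // Internal (`##`): match the first expression against every variant and
    // recurse with the remaining pairs inside each arm.
    (##$variants:tt[$($variant:path,)*]($val:expr, $var:ident)$pairs:tt$e:expr) => {
        match $val {
            $(
                $variant($var) => enumerate! {
                    #$variants$pairs$e
                },
            )*
        }
    };
    // Internal (`#`): no pairs left, so emit the body expression.
    (#$variants_:tt[]$e:expr) => {
        $e
    };
    (#$variants_:tt[$($__:tt)*]$e:expr) => {
        $e
    };
}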

impl PartialOrd for MyNum {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        enumerate!([MyNum::Int, MyNum::BigInt],
            self,
            other,
            => |l, r| (*l as u64).partial_cmp(&(*r as u64))
        )
    }
}

Rust Playground

This expands to nested match expressions instead of a single match on a tuple, because that’s slightly more straightforward to implement with a macro; it should behave the same.


If you don’t want to write out the whole list of enum variants at every call site, you can also define the list once with a macro, using callbacks. E.g.

macro_rules! define_variants_list {
    ($name:ident $($t:tt)*) => {
        define_variants_list!{# [$] $name $($t)*}
    };
    (#[$dollar:tt] $name:ident $($t:tt)*) => {
        macro_rules! $name {
            ($callback:path{$dollar($callback_args:tt)*}) => {
                $callback!{
                    $($t)*
                    $dollar($callback_args)*
                }
            }
        }
    };
}

define_variants_list! {
    my_num_variants
    [
        MyNum::Int,
        MyNum::BigInt,
    ]
}

macro_rules! enumerate {
    /* …… */
    ($list:path, $($t:tt)*) => {{
        $list! {
            enumerate {
                , $($t)*
            }
        }
    }};
}

impl PartialOrd for MyNum {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        enumerate!(my_num_variants,
            self,
            other,
            => |l, r| (*l as u64).partial_cmp(&(*r as u64))
        )
    }
}

Rust Playground


For more reliability across modules or crates, you can use absolute paths everywhere, too: Rust Playground

Thank you so much! These macro definitions look like magic to me, haha.

(Though I’d still find it interesting if Rust had some kind of wildcard for the enum variant in match patterns.)
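For what it's worth, today's Rust already allows a limited form of this: an or-pattern can cover several same-variant pairs in one arm, but only if every alternative binds its variables with the same type. So it doesn't help with u8 vs u64 payloads, but it works when variants share a payload type. A sketch with a hypothetical enum Id:

```rust
use std::cmp::Ordering;

// Hypothetical enum whose variants all carry the same payload type.
enum Id {
    User(u32),
    Group(u32),
}

fn cmp_same_kind(a: &Id, b: &Id) -> Option<Ordering> {
    match (a, b) {
        // One arm for all same-variant pairs: allowed because `l` and `r`
        // are `&u32` in every alternative of the or-pattern.
        (Id::User(l), Id::User(r)) | (Id::Group(l), Id::Group(r)) => Some(l.cmp(r)),
        // Different variants are treated as incomparable in this sketch.
        _ => None,
    }
}
```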

Admittedly, nontrivial macro_rules macros are generally super hard to read. Writing them is not that hard (though debugging them is nontrivial, too). My point is: writing macros is often easier than reading them.

The main idea is to take the first of the expressions (self, other) together with the first of the variables (|l, r|) and expand to something like

match EXPR {
    Variant1(VAR) => /* recursive macro call */,
    Variant2(VAR) => /* recursive macro call */,
    Variant3(VAR) => /* recursive macro call */,
}

where the recursive call has one fewer expression and one fewer variable. In the concrete example, the recursive call is essentially equivalent to calling

enumerate!(VARIANTS…, other => |r| (*l as u64).partial_cmp(&(*r as u64)))

This recursive call happens in the following rule:

    (##$variants:tt[$($variant:path,)*]($val:expr, $var:ident)$pairs:tt$e:expr) => {
        match $val {
            $(
                $variant($var) => enumerate! {
                    #$variants$pairs$e
                },
            )*
        }
    };

The # and ## prefixes select alternative “internal” rules of the macro, a bit like helper functions. Some of the complexity comes from how repetitions work: macro_rules! won't let you expand the whole list of $variant values inside a repetition that is already iterating over $variant. So the list of variants is both bundled into a single token tree $variants and expanded into individual $variant values; that way, each individual $variant can be paired with a recursive call that again receives a copy of the whole list $variants.
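The bundling trick can be seen in isolation in a smaller, hypothetical pairs! macro that builds the Cartesian product of a list of identifiers with itself. This sketch uses the common @ prefix for internal rules rather than #; the mechanism is the same:

```rust
// Hypothetical `pairs!` macro: Cartesian product of an identifier list with itself.
macro_rules! pairs {
    // Internal rule: `$all` is the whole list as ONE token tree `[a, b, ...]`,
    // while `$a` iterates over the same elements individually.
    (@outer $all:tt $($a:ident)*) => {{
        let mut v: Vec<(&str, &str)> = Vec::new();
        // `$all` is not repeated, so it is copied into every iteration over `$a`.
        $( pairs!(@inner v, $a, $all); )*
        v
    }};
    // Internal rule: unpack the bundled list and pair every element with `$a`.
    (@inner $v:ident, $a:ident, [$($b:ident),*]) => {
        $( $v.push((stringify!($a), stringify!($b))); )*
    };
    // Public entry point: pass the list along both bundled and expanded.
    ($($x:ident),* $(,)?) => {
        pairs!(@outer [$($x),*] $($x)*)
    };
}
```

For example, pairs!(a, b) evaluates to vec![("a", "a"), ("a", "b"), ("b", "a"), ("b", "b")].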


The concrete expansion of

impl PartialOrd for MyNum {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        enumerate!(my_num_variants,
            self,
            other,
            => |l, r| (*l as u64).partial_cmp(&(*r as u64))
        )
    }
}

is

impl PartialOrd for MyNum {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        {
            match self {
                MyNum::Int(l) =>
                    match other {
                        MyNum::Int(r) => (*l as u64).partial_cmp(&(*r as u64)),
                        MyNum::BigInt(r) => (*l as u64).partial_cmp(&(*r as u64)),
                    },
                MyNum::BigInt(l) =>
                    match other {
                        MyNum::Int(r) => (*l as u64).partial_cmp(&(*r as u64)),
                        MyNum::BigInt(r) => (*l as u64).partial_cmp(&(*r as u64)),
                    },
            }
        }
    }
}

as can be observed using the “Expand Macros” tool in the playground.


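The callback technique used to define the my_num_variants abbreviation for the variant list is rather advanced, and it's only needed to work around the fact that a macro can't expand a macro call in its input before matching on it: enumerate!(my_num_variants!(), …) would just see the tokens my_num_variants!() rather than the list they expand to.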
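A stripped-down illustration of the callback pattern, with hypothetical macros colors and count_idents. Writing count_idents!(colors!()) would fail, because count_idents! receives the unexpanded tokens colors!() instead of the list, so the list macro is handed the consumer's name instead:

```rust
// Hypothetical list macro: it can't "return" its list to another macro,
// so it invokes whichever macro it is given, passing the list along.
macro_rules! colors {
    ($callback:ident) => {
        $callback!(red, green, blue)
    };
}

// Hypothetical consumer: counts the identifiers it receives.
macro_rules! count_idents {
    () => { 0usize };
    ($head:ident $(, $tail:ident)*) => { 1usize + count_idents!($($tail),*) };
}
```

Here colors!(count_idents) expands to count_idents!(red, green, blue), which evaluates to 3.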
