Handling enums and type flexibility without repeating match arm code?

I suspect I don't fully understand the idioms for working with the type system efficiently and properly.

Consider the following: I have a number system that defines mathematical operations between f64 and my own types, call them n1 and n2. The necessary operator overloading allows these types to be mixed in operations such as +, -, *, /, etc.

These often get wrapped in an enum to allow the user flexibility on input:

enum NorF { F64(f64), N1(n1), N2(n2) }

When dealing with containers of these I also need to define an enum for that:

enum NorFArray { F64(Array<f64>), N1(Array<n1>), N2(Array<n2>) }

(I avoid Array<NorF> because it loses the efficiency of BLAS in the f64 case.)

What I find, however, is that I am constantly writing match statements and, within the arms, repeating almost exactly the same code (which might be long). For example:

fn mutate_array<T>(mut input: ArrayViewMut1<T>) {
    // do something with the generic array
}

fn user_function(input: NorF) -> NorFArray {
    match input {
        NorF::F64(_) => {
            let mut arr: Array1<f64> = Array1::ones(2);
            mutate_array(arr.view_mut());
            NorFArray::F64(arr)
        }
        NorF::N1(_) => {
            let mut arr: Array1<n1> = Array1::ones(2);
            mutate_array(arr.view_mut());
            NorFArray::N1(arr)
        }
        NorF::N2(_) => {
            let mut arr: Array1<n2> = Array1::ones(2);
            mutate_array(arr.view_mut());
            NorFArray::N2(arr)
        }
    }
}

The issue is that each concrete type only exists in the local scope of its arm, so the matching enum variant has to be constructed at the end of every arm to give the function a single return type.

This toy example is obviously manageable, but typically I have more combinations, or the repeated code blocks may be more complex or longer. How do more experienced Rust users typically handle this?


Playground Example

use ndarray::prelude::*; // 0.15.6
use std::ops::Mul;
use num_traits::One;

#[derive(Debug, Clone)]
struct n1 {
    val: f64
}
#[derive(Debug, Clone)]
struct n2 {
    val2: f64 
}
impl Mul for n1 {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        Self {val: self.val * rhs.val}
    }
}
impl Mul for n2 {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        Self {val2: self.val2 * rhs.val2}
    }
}
impl One for n1 {
    fn one() -> n1 {
        n1 {val: 1.0_f64}
    }
}
impl One for n2 {
    fn one() -> n2 {
        n2 {val2: 1.0_f64}
    }
}
#[derive(Debug)]
enum NorF {F64(f64), N1(n1), N2(n2)}
#[derive(Debug)]
enum NorFArray {F64(Array1<f64>), N1(Array1<n1>), N2(Array1<n2>)}

fn mutate_array<T>(mut arr: ArrayViewMut1<T>)
where
    T: One,
{
    arr[0] = T::one();
}

fn create_array(input: NorF) -> NorFArray {
    match input {
        NorF::F64(_) => {
            let mut arr: Array1<f64> = Array1::ones(2);
            mutate_array(arr.view_mut());
            NorFArray::F64(arr)
        }
        NorF::N1(_) => {
            let mut arr: Array1<n1> = Array1::ones(2);
            mutate_array(arr.view_mut());
            NorFArray::N1(arr)
        }
        NorF::N2(_) => {
            let mut arr: Array1<n2> = Array1::ones(2);
            mutate_array(arr.view_mut());
            NorFArray::N2(arr)
        }
    }
}
fn main() {
    let norfarray = create_array(NorF::N1(n1{val: 2.5_f64}));
    println!("{:?}", norfarray)
}

I have run into similar problems myself and have not seen any truly elegant solutions. The situation could probably be improved by defining separate types that implement common trait(s) and using dynamic dispatch (Box<dyn ...>), but I am assuming we want to avoid that for performance reasons.
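For comparison, here is a minimal sketch of the dynamic-dispatch route, assuming all element types can share one trait. To keep it dependency-free, `Vec<T>` stands in for ndarray's `Array1<T>` and a small local `One` trait stands in for `num_traits::One`; the trait `NumArray` and the helper `ones_array` are hypothetical names.

```rust
use std::fmt::Debug;

// Local stand-in for `num_traits::One`, to keep the sketch dependency-free.
trait One {
    fn one() -> Self;
}
impl One for f64 {
    fn one() -> Self { 1.0 }
}

#[allow(non_camel_case_types)]
#[derive(Debug, Clone)]
struct n1 { val: f64 }
impl One for n1 {
    fn one() -> Self { n1 { val: 1.0 } }
}

// Object-safe trait listing the operations every array kind supports.
// `Debug` as a supertrait lets `Box<dyn NumArray>` be printed.
trait NumArray: Debug {
    fn size(&self) -> usize;
    fn mutate(&mut self);
}

// A single blanket impl replaces the per-type match arms.
// `Vec<T>` stands in for ndarray's `Array1<T>`.
impl<T: One + Clone + Debug> NumArray for Vec<T> {
    fn size(&self) -> usize {
        self.len()
    }
    fn mutate(&mut self) {
        // Mirrors `mutate_array` from the question: overwrite element 0.
        if let Some(first) = self.first_mut() {
            *first = T::one();
        }
    }
}

// Hypothetical constructor: the concrete element type is chosen once, here.
fn ones_array<T: One + Clone + Debug + 'static>(n: usize) -> Box<dyn NumArray> {
    Box::new(vec![T::one(); n])
}

fn main() {
    let mut arrays: Vec<Box<dyn NumArray>> = vec![ones_array::<f64>(2), ones_array::<n1>(3)];
    for arr in arrays.iter_mut() {
        arr.mutate(); // one virtual call per array, not per element
        println!("{} elements: {:?}", arr.size(), arr);
    }
}
```

Since each trait-object call operates on a whole array, the virtual call should happen once per operation rather than once per element, and the loop inside each impl is still monomorphized for its element type. The costs are a heap allocation and losing the concrete type at the call site.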

Staying with the enum case, we can use macros to avoid repeating the match arms:

macro_rules! match_norf {
    ($input:expr, $f:ident) => {
        match $input {
            NorF::F64(_) => $f!(f64, F64),
            NorF::N1(_) => $f!(n1, N1),
            NorF::N2(_) => $f!(n2, N2),
        }
    };
}

fn create_array(input: NorF) -> NorFArray {
    macro_rules! f {
        ($ty:ty, $var:ident) => {{
            let mut arr: Array1<$ty> = Array1::ones(2);
            mutate_array(arr.view_mut());
            NorFArray::$var(arr)
        }};
    }
    match_norf!(input, f)
}
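The macro pays off once it is reused: `match_norf!` is written once and each function supplies only its own body. A minimal sketch, assuming `Vec<T>` in place of `Array1<T>` and a local `One` trait in place of `num_traits::One`; `ones_like` and `type_name` are hypothetical example functions.

```rust
// Local stand-in for `num_traits::One`; `Vec<T>` stands in for ndarray's `Array1<T>`.
trait One {
    fn one() -> Self;
}
impl One for f64 {
    fn one() -> Self { 1.0 }
}

#[allow(non_camel_case_types)]
#[derive(Debug, Clone, PartialEq)]
struct n1 { val: f64 }
impl One for n1 {
    fn one() -> Self { n1 { val: 1.0 } }
}

#[allow(non_camel_case_types)]
#[derive(Debug, Clone, PartialEq)]
struct n2 { val2: f64 }
impl One for n2 {
    fn one() -> Self { n2 { val2: 1.0 } }
}

#[allow(dead_code)]
enum NorF { F64(f64), N1(n1), N2(n2) }
#[derive(Debug)]
enum NorFArray { F64(Vec<f64>), N1(Vec<n1>), N2(Vec<n2>) }

// The dispatch macro, defined once at module level.
macro_rules! match_norf {
    ($input:expr, $f:ident) => {
        match $input {
            NorF::F64(_) => $f!(f64, F64),
            NorF::N1(_) => $f!(n1, N1),
            NorF::N2(_) => $f!(n2, N2),
        }
    };
}

// First reuse: build `n` ones with the element type of `input`.
fn ones_like(input: &NorF, n: usize) -> NorFArray {
    macro_rules! body {
        ($ty:ty, $var:ident) => {
            NorFArray::$var(vec![<$ty as One>::one(); n])
        };
    }
    match_norf!(input, body)
}

// Second reuse: a completely different body, same single match.
fn type_name(input: &NorF) -> &'static str {
    macro_rules! body {
        ($ty:ty, $var:ident) => {
            stringify!($ty)
        };
    }
    match_norf!(input, body)
}

fn main() {
    let input = NorF::N2(n2 { val2: 2.5 });
    println!("{} -> {:?}", type_name(&input), ones_like(&input, 2));
}
```

Adding a fourth number type then means editing `match_norf!` and the two enums, not every function.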

Of course, if the full computation is only ever going to happen on one type of number at a time (since you are using Array1<f64> and not Array1<NorF>), you can make all functions generic over the number type:

fn create_array<T: One + Clone>() -> Array1<T> {
    let mut arr: Array1<T> = Array1::ones(2);
    mutate_array(arr.view_mut());
    arr
}

We can then choose the specific type at the top level (I removed the argument since it is unused):

fn main() {
    let norfarray = create_array::<f64>();
    println!("{:?}", norfarray)
}
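If callers still need the `NorFArray` enum back, one option is a small helper trait that maps each element type to its variant, so the long body stays generic and each match arm shrinks to choosing a type. This is a sketch under the same stand-in assumptions (`Vec<T>` for `Array1<T>`, a local `One` trait); `IntoNorFArray` and `wrap` are hypothetical names, and `n2` is omitted for brevity.

```rust
// Local stand-in for `num_traits::One`; `Vec<T>` stands in for ndarray's `Array1<T>`.
trait One {
    fn one() -> Self;
}
impl One for f64 {
    fn one() -> Self { 1.0 }
}

#[allow(non_camel_case_types)]
#[derive(Debug, Clone, PartialEq)]
struct n1 { val: f64 }
impl One for n1 {
    fn one() -> Self { n1 { val: 1.0 } }
}

#[allow(dead_code)]
enum NorF { F64(f64), N1(n1) }
#[derive(Debug)]
enum NorFArray { F64(Vec<f64>), N1(Vec<n1>) }

// Hypothetical helper trait: each element type knows which variant wraps it.
trait IntoNorFArray: Sized {
    fn wrap(arr: Vec<Self>) -> NorFArray;
}
impl IntoNorFArray for f64 {
    fn wrap(arr: Vec<Self>) -> NorFArray { NorFArray::F64(arr) }
}
impl IntoNorFArray for n1 {
    fn wrap(arr: Vec<Self>) -> NorFArray { NorFArray::N1(arr) }
}

fn mutate_array<T: One>(arr: &mut [T]) {
    if let Some(first) = arr.first_mut() {
        *first = T::one();
    }
}

// The (possibly long) body is written once, generically, and returns the enum.
fn create_array<T: One + Clone + IntoNorFArray>(n: usize) -> NorFArray {
    let mut arr = vec![T::one(); n];
    mutate_array(&mut arr);
    T::wrap(arr)
}

// Each arm now only selects the element type.
fn create_array_enum(input: NorF) -> NorFArray {
    match input {
        NorF::F64(_) => create_array::<f64>(2),
        NorF::N1(_) => create_array::<n1>(2),
    }
}

fn main() {
    println!("{:?}", create_array_enum(NorF::N1(n1 { val: 2.5 })));
}
```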

With generic associated types (GATs, stable since Rust 1.65) we can also make the NorF enum generic over the container type:

trait NorFType {
    type Type<T>;
}

#[derive(Debug)]
struct Number;

impl NorFType for Number {
    type Type<T> = T;
}

#[derive(Debug)]
struct Array;

impl NorFType for Array {
    type Type<T> = Array1<T>;
}

#[allow(dead_code)]
#[derive(Debug)]
enum NorF<T: NorFType> {
    F64(T::Type<f64>),
    N1(T::Type<n1>),
    N2(T::Type<n2>),
}

This can be used as follows to reproduce your original example:

fn create_array_enum(input: NorF<Number>) -> NorF<Array> {
    match input {
        NorF::F64(_) => NorF::F64(create_array()),
        NorF::N1(_) => NorF::N1(create_array()),
        NorF::N2(_) => NorF::N2(create_array()),
    }
}

fn main() {
    let input: NorF<Number> = NorF::N1(n1 { val: 2.5_f64 });
    let norfarray = create_array_enum(input);
    println!("{:?}", norfarray)
}
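Going one step further, the match itself can be written only once by passing a "generic closure": Rust closures cannot be generic over the element type, but a trait with a generic method can. A sketch, assuming Rust 1.65+, `Vec<T>` in place of `Array1<T>`, a local `One` trait, and the hypothetical names `NorFMap`, `map`, `OnesArray` and `First`; `n2` is omitted for brevity.

```rust
// Needs Rust 1.65+ for generic associated types. `Vec<T>` stands in for
// ndarray's `Array1<T>` and `One` for `num_traits::One`.
trait One {
    fn one() -> Self;
}
impl One for f64 {
    fn one() -> Self { 1.0 }
}

#[allow(non_camel_case_types)]
#[derive(Debug, Clone, PartialEq)]
struct n1 { val: f64 }
impl One for n1 {
    fn one() -> Self { n1 { val: 1.0 } }
}

// Everything the generic operations need from an element type.
trait Num: One + Clone {}
impl<T: One + Clone> Num for T {}

trait NorFType {
    type Type<T>;
}
struct Number;
impl NorFType for Number {
    type Type<T> = T;
}
struct Array;
impl NorFType for Array {
    type Type<T> = Vec<T>;
}

enum NorF<K: NorFType> {
    F64(K::Type<f64>),
    N1(K::Type<n1>),
}

// Hypothetical "generic closure": one method that works for any element type.
trait NorFMap<Src: NorFType, Dst: NorFType> {
    fn apply<T: Num>(&self, x: Src::Type<T>) -> Dst::Type<T>;
}

impl<Src: NorFType> NorF<Src> {
    // The only match over the variants; every operation reuses it.
    fn map<Dst: NorFType, M: NorFMap<Src, Dst>>(self, m: &M) -> NorF<Dst> {
        match self {
            NorF::F64(x) => NorF::F64(m.apply::<f64>(x)),
            NorF::N1(x) => NorF::N1(m.apply::<n1>(x)),
        }
    }
}

// Number -> Array: the body of `create_array`, written once.
struct OnesArray {
    len: usize,
}
impl NorFMap<Number, Array> for OnesArray {
    fn apply<T: Num>(&self, _x: T) -> Vec<T> {
        let mut arr = vec![T::one(); self.len];
        if let Some(first) = arr.first_mut() {
            *first = T::one(); // stand-in for `mutate_array`
        }
        arr
    }
}

// Array -> Number: a second operation that needs no new match.
struct First;
impl NorFMap<Array, Number> for First {
    fn apply<T: Num>(&self, x: Vec<T>) -> T {
        x.into_iter().next().expect("array is empty")
    }
}

fn main() {
    let arrays: NorF<Array> = NorF::<Number>::N1(n1 { val: 2.5 }).map(&OnesArray { len: 2 });
    if let NorF::N1(v) = &arrays {
        println!("{:?}", v);
    }
}
```

Each new operation is then one small impl block, and adding a variant means touching only `map`. The turbofish in `m.apply::<f64>(x)` is needed because the compiler cannot infer `T` back out of the projection `Src::Type<T>`.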

Thanks, this was helpful, especially the part about generic associated types.
