How can I test that two traits have the same set of function names?

Hello,

I am developing a library with both sync and async APIs, with mirrored traits for each paradigm.

For example:

trait FooSync {
  fn bar(&self) -> u32;
  fn baz(&self, delay: &SyncDelay) -> Result<(), ()>;
}
trait FooAsync {
  async fn bar(&self) -> u32;
  async fn baz(&self, delay: &AsyncDelay) -> Result<(), ()>;
}

The two traits are in different files, and are pretty big in the real code.

How can I write a test to ensure that the same set of function names is available on both traits? The signatures will not be identical, since one set is synchronous and the other is asynchronous.

I do not care if the implementations differ; there can be other tests for that. I just do not want a function to be added on only one side.

Thanks

Write a macro which generates an implementation of a trait with this shape. Call the macro twice. Then, any discrepancies (except for provided methods in one trait but not the other) will make one or the other implementation fail.

struct SyncDelay;
struct AsyncDelay;

trait FooSync {
    fn bar(&self) -> u32;
    fn baz(&self, delay: &SyncDelay) -> Result<(), ()>;
}
trait FooAsync {
    async fn bar(&self) -> u32;
    async fn baz(&self, delay: &AsyncDelay) -> Result<(), ()>;
}

#[cfg(test)]
mod tests {
    use super::*;
    
    struct Dummy;
    
    macro_rules! impl_foo {
        ($($modifier:ident)?, $trait_name:ident, $delay:ident)  => {
            impl $trait_name for Dummy {
              $($modifier)? fn bar(&self) -> u32 {
                  unimplemented!()
              }
              $($modifier)? fn baz(&self, _: &$delay) -> Result<(), ()> {
                  unimplemented!()
              }
            }
        }
    }

    impl_foo!(, FooSync, SyncDelay);
    impl_foo!(async, FooAsync, AsyncDelay);
}

I would rather not have to repeat all the functions a third time and parameterize them. I would also like to catch mismatches in default (provided) implementations.

Another problem I have is with inherent impl blocks on structs, which do not belong to any trait, for example:

struct SyncThing {...}
impl SyncThing {
  pub fn init(&mut self, bus: &SyncI2CBus) -> Result<(), ()> {
    bus.write(0x00)
  }
}
struct AsyncThing {...}
impl AsyncThing {
  pub async fn init(&mut self, bus: &AsyncI2CBus) -> Result<(), ()> {
    bus.write(0x00).await
  }
}

As these are not traits, I do not see how I can test their contents either.

I’m not sure what the latest discussions are on how evil it is to rely on global state in proc macros… but for now something like this could work:

[package]
name = "my_macro"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[lib]
proc-macro = true

[dependencies]
itertools = "0.12.0"
proc-macro2 = "1.0.76"
quote = "1.0.35"
syn = { version = "2.0.48", features = ["full"] }

use std::{
    collections::{hash_map::Entry, HashMap, HashSet},
    sync::Mutex,
};

use itertools::Itertools;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemTrait, LitStr, TraitItem};

#[proc_macro_attribute]
pub fn check_consistent_trait_fn_items(args: TokenStream, input: TokenStream) -> TokenStream {
    // Parse the input tokens into a syntax tree
    let original_input = proc_macro2::TokenStream::from(input.clone());
    let input = parse_macro_input!(input as ItemTrait);

    let methods = input
        .items
        .iter()
        .filter_map(|i| {
            let TraitItem::Fn(i) = i else { return None };
            Some(i.sig.ident.to_string())
        })
        .collect::<HashSet<_>>();

    // Parse the macro argument as a string literal
    let key = parse_macro_input!(args as LitStr).value();

    static METHODS_OF_TRAIT: Mutex<Option<HashMap<String, HashSet<String>>>> = Mutex::new(None);

    let mut methods_of_trait_guard = METHODS_OF_TRAIT.lock().unwrap();
    let methods_of_trait = methods_of_trait_guard.get_or_insert_with(HashMap::new);

    // Build the output, possibly using quasi-quotation
    let expanded = match methods_of_trait.entry(key.clone()) {
        Entry::Occupied(expected_methods) if *expected_methods.get() == methods => {
            quote! {
                #original_input
            }
        }
        Entry::Occupied(expected_methods) => {
            let error = format!("\
                Mismatching methods list between the traits marked as `#[check_consistent_trait_fn_items({key:?})]`:\n\n\
                Expected Methods: {}\n  \
                Actual Methods: {}\n\n",
                expected_methods.get().iter().sorted().format_with(", ", |i, f| f(&format_args!("{i}"))),
                methods.iter().sorted().format_with(", ", |i, f| f(&format_args!("{i}")))
            );
            quote! {
                #original_input
                compile_error!{#error}
            }
        }
        Entry::Vacant(e) => {
            e.insert(methods);
            quote! {
                #original_input
            }
        }
    };

    // Hand the output tokens back to the compiler
    TokenStream::from(expanded)
}

use my_macro::check_consistent_trait_fn_items;

fn main() {
    println!("Hello, world!");
}

#[check_consistent_trait_fn_items("Foo")]
trait FooSync {
    fn f();
    fn h();
}

#[check_consistent_trait_fn_items("Foo")]
trait FooAsync {
    fn f();
    fn g();
    fn h();
}

error: Mismatching methods list between the traits marked as `#[check_consistent_trait_fn_items("Foo")]`:
       
       Expected Methods: f, h
         Actual Methods: f, g, h
       
  --> src/main.rs:13:1
   |
13 | #[check_consistent_trait_fn_items("Foo")]
   | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   |
   = note: this error originates in the attribute macro `check_consistent_trait_fn_items` (in Nightly builds, run with -Z macro-backtrace for more info)
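
The core of the attribute macro is a small bookkeeping step: the first trait registered under a key defines the expected method set, and every later trait with the same key is compared against it. Below is a std-only sketch of that step pulled out into a plain function, `register_methods` (a hypothetical name, not part of the macro above), which also reports which names are missing or unexpected instead of printing both full lists. Because it does not touch any proc-macro types, it can be unit-tested on its own.

```rust
use std::collections::{hash_map::Entry, BTreeSet, HashMap};

/// Hypothetical helper: records the method names of the first trait seen for
/// `key`, and compares every later trait with the same key against them.
fn register_methods(
    registry: &mut HashMap<String, BTreeSet<String>>,
    key: &str,
    methods: BTreeSet<String>,
) -> Result<(), String> {
    match registry.entry(key.to_string()) {
        // First trait with this key: its methods become the expected set.
        Entry::Vacant(e) => {
            e.insert(methods);
            Ok(())
        }
        Entry::Occupied(e) => {
            let expected = e.get();
            if *expected == methods {
                return Ok(());
            }
            // Names only in the first trait, and names only in this one.
            let missing: Vec<&str> = expected.difference(&methods).map(String::as_str).collect();
            let unexpected: Vec<&str> = methods.difference(expected).map(String::as_str).collect();
            Err(format!(
                "methods for {key:?} differ: missing [{}], unexpected [{}]",
                missing.join(", "),
                unexpected.join(", "),
            ))
        }
    }
}

/// Hypothetical convenience for building a method-name set from string literals.
fn method_set(names: &[&str]) -> BTreeSet<String> {
    names.iter().map(|s| s.to_string()).collect()
}
```

To use it inside the proc macro, the method names would be collected into a `BTreeSet` instead of a `HashSet`, and an `Err(msg)` would be emitted as `compile_error!{#msg}` exactly as the original does.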

Edit:

Here’s a version that also checks that both traits have the same set of default-implemented methods:

use std::{
    collections::{hash_map::Entry, BTreeMap, HashMap},
    sync::Mutex,
};

use itertools::Itertools;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemTrait, LitStr, TraitItem};

#[derive(PartialEq, Eq)]
struct MethodInfo {
    name: String,
    provided_default: bool,
}

impl std::fmt::Display for MethodInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.provided_default {
            write!(f, "{} (has default)", self.name)
        } else {
            write!(f, "{}", self.name)
        }
    }
}

#[proc_macro_attribute]
pub fn check_consistent_trait_fn_items(args: TokenStream, input: TokenStream) -> TokenStream {
    // Parse the input tokens into a syntax tree
    let original_input = proc_macro2::TokenStream::from(input.clone());
    let input = parse_macro_input!(input as ItemTrait);

    let methods = input
        .items
        .iter()
        .filter_map(|i| {
            let TraitItem::Fn(i) = i else { return None };
            let name = i.sig.ident.to_string();
            Some((
                name.clone(),
                MethodInfo {
                    name,
                    provided_default: i.default.is_some(),
                },
            ))
        })
        .collect::<BTreeMap<_, _>>();

    // Parse the macro argument as a string literal
    let key = parse_macro_input!(args as LitStr).value();

    static METHODS_OF_TRAIT: Mutex<Option<HashMap<String, BTreeMap<String, MethodInfo>>>> =
        Mutex::new(None);

    let mut methods_of_trait_guard = METHODS_OF_TRAIT.lock().unwrap();
    let methods_of_trait = methods_of_trait_guard.get_or_insert_with(HashMap::new);

    // Build the output, possibly using quasi-quotation
    let expanded = match methods_of_trait.entry(key.clone()) {
        Entry::Occupied(expected_methods) if *expected_methods.get() == methods => {
            quote! {
                #original_input
            }
        }
        Entry::Occupied(expected_methods) => {
            let error = format!("\
                Mismatching methods list between the traits marked as `#[check_consistent_trait_fn_items({key:?})]`:\n\n\
                Expected Methods: {}\n  \
                Actual Methods: {}\n\n",
                expected_methods.get().iter().map(|(_, i)| i).format(", "),
                methods.iter().map(|(_, i)| i).format(", "),
            );
            quote! {
                #original_input
                compile_error!{#error}
            }
        }
        Entry::Vacant(e) => {
            e.insert(methods);
            quote! {
                #original_input
            }
        }
    };

    // Hand the output tokens back to the compiler
    TokenStream::from(expanded)
}

Also, if you don’t want this macro to run in normal builds and would rather have it only as a dev-dependency for testing, you can enable it only under cfg(test), like this:

#[cfg(test)]
use my_macro::check_consistent_trait_fn_items;

#[cfg_attr(test, check_consistent_trait_fn_items("Foo"))]
trait FooSync { … }

If you need to test (in general) the two APIs, you will need to write normal test code that calls all the methods in each API. You could use this test code to also ensure (roughly) that you have the same methods in each API.

There are various ways to do that. For example, you could have one test method for each API method that tests both the async and non-async version. Of course, you may forget to test an API method entirely. But at least if you do test it, you'll probably remember to test both versions just by following the pattern in the tests.
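
A minimal sketch of the one-test-per-method pattern, assuming hypothetical `SyncImpl` and `AsyncImpl` types that implement cut-down versions of the two traits. To stay std-only it uses a hand-rolled `block_on` that polls with a no-op waker; in a real project you would more likely use an existing executor such as `futures::executor::block_on` or a `#[tokio::test]` runtime. `async fn` in traits needs Rust 1.75 or newer.

```rust
use std::future::Future;
use std::pin::pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

// Hypothetical, cut-down mirror traits.
trait FooSync {
    fn bar(&self) -> u32;
}
trait FooAsync {
    async fn bar(&self) -> u32;
}

// Hypothetical implementations standing in for the real ones.
struct SyncImpl;
struct AsyncImpl;

impl FooSync for SyncImpl {
    fn bar(&self) -> u32 {
        42
    }
}
impl FooAsync for AsyncImpl {
    async fn bar(&self) -> u32 {
        42
    }
}

// A waker whose operations all do nothing.
fn noop_raw_waker() -> RawWaker {
    fn clone(_: *const ()) -> RawWaker {
        noop_raw_waker()
    }
    fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    RawWaker::new(std::ptr::null(), &VTABLE)
}

/// Test-only executor: busy-polls `fut` until it is ready. It spins forever on a
/// future that stays pending, so it only suits futures that never wait on I/O or timers.
fn block_on<F: Future>(fut: F) -> F::Output {
    // SAFETY: the vtable functions never dereference the (null) data pointer.
    let waker = unsafe { Waker::from_raw(noop_raw_waker()) };
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

/// One check per API method, exercising the sync and async versions side by side.
fn check_bar() {
    assert_eq!(SyncImpl.bar(), block_on(AsyncImpl.bar()));
}

#[cfg(test)]
mod tests {
    #[test]
    fn bar() {
        super::check_bar();
    }
}
```

Keeping both calls in the same test means that whoever adds a test for a new method is reminded of the other half by the pattern itself.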

I realize this is not the automatic, guaranteed sort of approach you are probably hoping for. But it is very practical and simple. There are other types of consistency between the two APIs that probably need testing, such as behavior consistency. So you can't really avoid writing the tests and doing the consistency checks.

Surely there are also non-test implementations of both traits, which will be similarly sized? If writing one more is a burden, perhaps your design needs a rethink.
