Using macros to modify the AST and add lines of code to a function

Hi Everyone,

I would like to write a macro that, when applied to a function or trait, modifies the function and generates a line of code inside it (for example, incrementing a global counter). So, assuming my understanding of Rust macros is correct, I need to write a macro that modifies the AST and then returns a TokenStream as well.

Yes, that's about right. You'll probably want to write a proc macro, which takes in a TokenStream and returns a new TokenStream. Proc macros are often paired with syn, which converts the input TokenStream into a syntax tree, and quote is commonly used to turn Rust code written in the macro definition into an output TokenStream.

So, would it be correct to implement it like this:

#[proc_macro]
pub fn macro_demo(item: TokenStream) -> TokenStream {
    // Parse the input into a syntax tree (AST) that I can modify?
    let item_fn = syn::parse::<ItemFn>(item).expect("Failed to parse");

    if some_condition {
        // Insert the line of code I want here, e.g. with quote!(...)?
    }

    // Anything that does not match the condition is returned unchanged.
    let tokens = quote!(#item_fn);
    TokenStream::from(tokens)
}

Then, when used, it should look like this?

#[macro_demo]
fn target_function(...) {
   // does something
}

So the type of procedural macro I should use is a derive macro?

If you want to manipulate the function body, you'll want an attribute macro. A derive macro is specifically for use as #[derive(XXX)] on a struct, enum, or union; it can only generate additional code based on the item's definition and cannot modify the original item.

I think your code is in the right ballpark for what you want, though if you want to insert code inside a function, you'll need to do a bit of work to manipulate item_fn so that your code ends up inside the function body.

Here's a fairly simple example of inserting some code into a function. I'm not in any way an expert in writing macros, so there might be a cleaner way to do this, but hopefully it gives you a starting point to work from.

counter-macro/Cargo.toml

[package]
name = "counter-macro"
version = "0.1.0"
authors = ["drewkett"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[lib]
proc-macro = true

[dependencies]
syn = { version = "1", features = ["full"] }
quote = "1"

counter-macro/src/lib.rs

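use proc_macro::TokenStream;
use quote::{quote, ToTokens};

#[proc_macro_attribute]
pub fn counter(_args: TokenStream, input: TokenStream) -> TokenStream {
    // Parse the annotated item into a syn syntax tree.
    let mut item: syn::Item = syn::parse(input).unwrap();
    // Only functions are supported.
    let fn_item = match &mut item {
        syn::Item::Fn(fn_item) => fn_item,
        _ => panic!("expected fn"),
    };
    // Build the statement to inject and put it first in the function body.
    let stmt: syn::Stmt = syn::parse(quote!(println!("count me in");).into()).unwrap();
    fn_item.block.stmts.insert(0, stmt);

    // Turn the modified syntax tree back into tokens for the compiler.
    item.into_token_stream().into()
}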

counter/Cargo.toml

[package]
name = "counter"
version = "0.1.0"
authors = ["drewkett"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
counter-macro = { path="../counter-macro" }

counter/src/main.rs

#[counter_macro::counter]
fn dummy() -> i32 {
    42
}
fn main() {
    dummy();
}

Output:

> cargo run
count me in
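Since the original goal was to increment a global counter rather than print, here is a hand-written sketch of what the expanded dummy could look like if the injected statement bumped a counter instead. It is std-only and involves no macro at all; the static CALL_COUNT and the choice of AtomicUsize with Ordering::Relaxed are assumptions for illustration, not part of the code above.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical global counter that an injected statement could increment.
// An atomic keeps the increment safe even if the function runs on several threads.
static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);

// Hand-written stand-in for what `#[counter_macro::counter] fn dummy() -> i32 { 42 }`
// might expand to if the macro inserted a counter increment as the first statement.
fn dummy() -> i32 {
    CALL_COUNT.fetch_add(1, Ordering::Relaxed); // the "injected" line
    42 // the original body, unchanged
}

fn main() {
    dummy();
    dummy();
    println!("dummy was called {} times", CALL_COUNT.load(Ordering::Relaxed));
}
```

Inside the macro, that first line would be generated with quote! just like the println! statement; the path to the counter would have to resolve from the crate that uses the attribute, so where the static lives is a design choice left open here.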

Wow! Thank you so much, @drewkett! I appreciate it. This certainly makes things clearer, and it helps me see which type of macro I actually needed and how to manipulate the AST better. I believe this will be a great starting point. Let me try it out, and if I have any questions I will come back to this thread! :slight_smile:

Hi Drew, I have a question. I tried to emulate your match implementation, and I also read up on how to select one particular case and do nothing with the rest. Here is a snippet of what I am trying to do. I am now trying to go one level lower in the AST: I match on the statements, and then from each statement I try to obtain the expression. My particular interest is the unsafe_token field of syn::ExprUnsafe. Here is an example:

for i in 0..fn_item.block.stmts.len() {
    let expr = match &mut fn_item.block.stmts[i] {
        syn::Stmt::Expr(expr) => expr,
        _ => {},
    };

    let target = match expr {
        syn::Expr::Unsafe(target) => target,
        _ => {},
    };
}

Am I iterating through this correctly?

The trouble with those match statements is that their arms don't return the same type. The first one returns expr for Stmt::Expr(..) but returns () if it doesn't match. The compiler rejects that, because let expr must have a single type. One way to fix it is to change _ => {} to _ => continue, which skips to the next iteration of the loop (so let expr is never assigned in that case). The following is the more idiomatic way of accomplishing this (with a few other tweaks):

for stmt in &mut fn_item.block.stmts {
    // `stmt` is already a `&mut syn::Stmt`, so match on it directly.
    if let syn::Stmt::Expr(syn::Expr::Unsafe(target)) = stmt {
        // use target here (e.g. target.unsafe_token)
    }
}
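For comparison, here is the _ => continue fix applied to the original index-based loop, sketched with small stand-in enums so it runs without syn. Stmt, Expr, and ExprUnsafe below are hypothetical, heavily simplified mirrors of the syn types, and count_unsafe is an illustrative helper.

```rust
// Hypothetical, simplified stand-ins for syn::Stmt, syn::Expr and syn::ExprUnsafe.
#[allow(dead_code)]
struct ExprUnsafe {
    stmts: Vec<Stmt>,
}

#[allow(dead_code)]
enum Expr {
    Unsafe(ExprUnsafe),
    Lit(i32),
}

#[allow(dead_code)]
enum Stmt {
    Expr(Expr),
    Local(String),
}

// The loop from the question, fixed: the non-matching arms `continue`
// (type `!`, which coerces to any type) instead of evaluating to `()`,
// so each `let` binding gets a single type.
fn count_unsafe(stmts: &mut [Stmt]) -> usize {
    let mut count = 0;
    for i in 0..stmts.len() {
        let expr = match &mut stmts[i] {
            Stmt::Expr(expr) => expr,
            _ => continue,
        };
        let _target = match expr {
            Expr::Unsafe(target) => target,
            _ => continue,
        };
        // `_target` is a `&mut ExprUnsafe` here; this sketch only counts it.
        count += 1;
    }
    count
}

fn main() {
    let mut stmts = vec![
        Stmt::Local(String::from("x")),
        Stmt::Expr(Expr::Unsafe(ExprUnsafe { stmts: Vec::new() })),
        Stmt::Expr(Expr::Lit(42)),
    ];
    println!("{}", count_unsafe(&mut stmts));
}
```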

The first change was to stop iterating with indexes. You almost never want to do that in Rust. Indexing is easier to get wrong, and every index operation is bounds-checked, so it is often slower as well. Rust has powerful iterators that should be used for iterating over collections.
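A small std-only sketch of the difference, using a plain slice of i32 instead of syn's statement list; the function names are illustrative. Each v[i] in the indexed version is bounds-checked (the optimizer can sometimes remove those checks, but that is not guaranteed), while the iterator versions hand out each element directly, and enumerate covers the cases where the position is still needed.

```rust
// Index-based loop, as in the original snippet: each `v[i]` is bounds-checked,
// and a wrong range or index is easy to write.
fn add_one_indexed(v: &mut [i32]) {
    for i in 0..v.len() {
        v[i] += 1;
    }
}

// Iterator-based loop: yields each element as `&mut i32` directly.
fn add_one_iter(v: &mut [i32]) {
    for x in v.iter_mut() {
        *x += 1;
    }
}

// When the position is still needed, `enumerate` pairs it with the element.
fn add_index(v: &mut [i32]) {
    for (i, x) in v.iter_mut().enumerate() {
        *x += i as i32;
    }
}

fn main() {
    let mut v = vec![1, 2, 3];
    add_one_indexed(&mut v);
    add_one_iter(&mut v);
    add_index(&mut v);
    println!("{:?}", v);
}
```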

The second change was to collapse your two match statements into a single if let. if let matches patterns exactly the way match does, except that it only checks one pattern, which seems fine in this case. Additionally, matching two levels of structure in a single pattern is what I'm calling structural matching (the Rust book covers it under destructuring nested structs and enums), and it's very helpful for working with two levels of data at once. (You often see it used with Option<Result<..>> or Result<Option<..>> types.)
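A minimal std-only sketch of a nested pattern on the Option<Result<..>> case mentioned above. The helpers describe and only_ok are illustrative names; the second one uses the same single-case if let shape as the syn::Stmt::Expr(syn::Expr::Unsafe(target)) pattern.

```rust
// A nested pattern looks through two layers at once: here an Option
// wrapping a Result, with one arm per combination of interest.
fn describe(value: Option<Result<i32, String>>) -> String {
    match value {
        Some(Ok(n)) => format!("got {}", n),
        Some(Err(e)) => format!("error: {}", e),
        None => String::from("nothing"),
    }
}

// `if let` with a nested pattern handles a single case and ignores the rest.
fn only_ok(value: Option<Result<i32, String>>) -> Option<i32> {
    if let Some(Ok(n)) = value {
        Some(n)
    } else {
        None
    }
}

fn main() {
    println!("{}", describe(Some(Ok(42))));
    println!("{}", describe(Some(Err(String::from("bad input")))));
    println!("{}", describe(None));
    println!("{:?}", only_ok(Some(Ok(7))));
}
```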

Final note: if you're looking for unsafe anywhere in the function definition, you'll probably need to look inside ExprBlocks, among other things, for statements that don't sit directly at the top level. This will make the code more complicated, but I'll stop at what I wrote for now.

Thank you!

I guess iterating by index is a habit I picked up from Python. I have been trying to learn Rust coming from a background with little C experience. Now I understand why the Rust book says so little about iterating with indexes.

I will probably read up more on structural matching, because I had not heard of it before.

Can I also clarify something? Since you mentioned that I have to look into ExprBlock, will I need some form of recursion to fully traverse the code and visit every single expression? While looking at the documentation, I noticed that even ExprUnsafe has a block field, which means I can look into that block and find more stmts, which in turn contain expressions that can contain other blocks. So, if I understand correctly, the AST looks something like this:

Item -> ItemFn -> Block -> Stmts -> Expr -> Block -> Stmts ...
                                                  -> Item -> ItemFn -> ...

As I am writing this, it feels like the problem I am looking at is a tree search problem.
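Yes, it is a tree walk. A std-only sketch of the recursion, using a hypothetical, heavily simplified AST that follows the shape sketched above (a Block holds Stmts, a Stmt is an Expr or a nested fn item, and some Exprs hold Blocks of their own). None of these types are syn's; the count_unsafe_in_* functions are illustrative helpers.

```rust
// Hypothetical, simplified AST; not syn's real types.
struct Block {
    stmts: Vec<Stmt>,
}

struct ItemFn {
    block: Block,
}

enum Stmt {
    Expr(Expr),
    Item(ItemFn),
}

#[allow(dead_code)]
enum Expr {
    Unsafe(Block),
    Block(Block),
    Lit(i32),
}

// Depth-first walk: every Block is searched, however deeply it is nested.
fn count_unsafe_in_block(block: &Block) -> usize {
    block.stmts.iter().map(count_unsafe_in_stmt).sum()
}

fn count_unsafe_in_stmt(stmt: &Stmt) -> usize {
    match stmt {
        Stmt::Expr(expr) => count_unsafe_in_expr(expr),
        // A nested fn item has its own body to search.
        Stmt::Item(item) => count_unsafe_in_block(&item.block),
    }
}

fn count_unsafe_in_expr(expr: &Expr) -> usize {
    match expr {
        // Count this unsafe block, then keep looking inside it.
        Expr::Unsafe(block) => 1 + count_unsafe_in_block(block),
        Expr::Block(block) => count_unsafe_in_block(block),
        Expr::Lit(_) => 0,
    }
}

fn main() {
    // Roughly: fn outer() { unsafe { { unsafe {} } } fn inner() { unsafe {} } 1 }
    let outer = ItemFn {
        block: Block {
            stmts: vec![
                Stmt::Expr(Expr::Unsafe(Block {
                    stmts: vec![Stmt::Expr(Expr::Block(Block {
                        stmts: vec![Stmt::Expr(Expr::Unsafe(Block { stmts: Vec::new() }))],
                    }))],
                })),
                Stmt::Item(ItemFn {
                    block: Block {
                        stmts: vec![Stmt::Expr(Expr::Unsafe(Block { stmts: Vec::new() }))],
                    },
                }),
                Stmt::Expr(Expr::Lit(1)),
            ],
        },
    };
    println!("unsafe blocks: {}", count_unsafe_in_block(&outer.block));
}
```

Writing one function per node kind like this is what syn's visitor traits automate, since they already know how to reach every child of every node.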

Yes. For that, syn offers a few convenience abstractions that let you either "map" a specific kind of node (Fold) or "mutate" it in place (VisitMut). The difference is a matter of style, and of working with &mut access versus owned access. In practice, and from experience, I find VisitMut far more convenient than Fold: e.g., to change a field within an enum variant, with Fold you need to reconstruct the whole enum.
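To show the mechanics behind VisitMut without pulling in syn, here is a hypothetical miniature of the same pattern: each trait method defaults to a free function that recurses into the node's children, so an implementor overrides only the node kinds it cares about. The Expr enum, Doubler, and eval are illustrative, not syn's API. A Fold-style version would instead take each node by value and return a rebuilt one.

```rust
// Hypothetical miniature of the VisitMut pattern; not syn's real types.
enum Expr {
    Lit(i64),
    Add(Box<Expr>, Box<Expr>),
    Block(Vec<Expr>),
}

trait VisitMut {
    // Default: just recurse into the children.
    fn visit_expr_mut(&mut self, node: &mut Expr) {
        visit_expr_mut(self, node);
    }
    fn visit_lit_mut(&mut self, _value: &mut i64) {}
}

// Default traversal: route every child back through the visitor,
// so overridden methods are called at any depth.
fn visit_expr_mut<V: VisitMut + ?Sized>(v: &mut V, node: &mut Expr) {
    match node {
        Expr::Lit(value) => v.visit_lit_mut(value),
        Expr::Add(lhs, rhs) => {
            v.visit_expr_mut(lhs);
            v.visit_expr_mut(rhs);
        }
        Expr::Block(exprs) => {
            for e in exprs {
                v.visit_expr_mut(e);
            }
        }
    }
}

// Overrides a single method; the recursion comes for free.
struct Doubler;

impl VisitMut for Doubler {
    fn visit_lit_mut(&mut self, value: &mut i64) {
        *value *= 2;
    }
}

fn eval(e: &Expr) -> i64 {
    match e {
        Expr::Lit(n) => *n,
        Expr::Add(a, b) => eval(a) + eval(b),
        Expr::Block(es) => es.iter().map(eval).sum(),
    }
}

fn main() {
    // A block containing (1 + 2) and 4.
    let mut tree = Expr::Block(vec![
        Expr::Add(Box::new(Expr::Lit(1)), Box::new(Expr::Lit(2))),
        Expr::Lit(4),
    ]);
    let mut doubler = Doubler;
    doubler.visit_expr_mut(&mut tree);
    println!("{}", eval(&tree));
}
```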


Here is an example that replaces every string literal equal to "Answer to the Ultimate Question of Life, the Universe, and Everything" with 42:

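use proc_macro::TokenStream;
use quote::ToTokens;
// Requires syn = { version = "1", features = ["full", "visit-mut"] }
use syn::visit_mut::{self, VisitMut};
use syn::{parse_macro_input, parse_quote, ExprLit, ItemFn, Lit};

// `answer` is a placeholder name for the attribute.
#[proc_macro_attribute]
pub fn answer(_args: TokenStream, input: TokenStream) -> TokenStream {
    let mut input: ItemFn = parse_macro_input!(input);

    struct Visitor {
        /* state here */
    }

    impl VisitMut for Visitor {
        fn visit_expr_lit_mut(&mut self, node: &mut ExprLit) {
            // Sub-recurse (not really needed here since there aren't
            // sub-expressions within an ExprLit):
            visit_mut::visit_expr_lit_mut(self, node);

            if matches!(
                node.lit,
                Lit::Str(ref s)
                if s.value() == "\
                    Answer to the Ultimate Question of Life, the Universe, and Everything\
                "
            ) {
                *node = parse_quote!(42);
            }
        }
    }

    let mut visitor = Visitor { /* initial state */ };

    // Walk the whole function, visiting every literal expression in it.
    visitor.visit_item_fn_mut(&mut input);

    input.into_token_stream().into()
}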

Oh wow, I did not know about this. I was already attempting to split the code into functions and then recursively call visit_fn_item and visit_expr accordingly. I will take a look at the documentation! Just to clarify, what do you mean by state? Does it refer to the start and end point of the Visitor? Also, I noticed that the code you left in the Rust Playground used proc_macro2. Is there a difference, or is Rust planning to shift from proc_macro to proc_macro2 and deprecate the former?

Sometimes you also want to accumulate some state while visiting, e.g., a counter that is incremented each time 42 is seen (see syn's visit module for better examples of that). In this instance, I don't think you need to carry any actual state: an empty struct will do the trick just fine.
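As a std-only sketch of what "state" means here: the visitor is a plain struct, and its fields persist across the recursive calls because the same &mut self is passed down. The Expr enum and FortyTwoCounter are hypothetical stand-ins, and the walk uses read-only & access, in the spirit of syn's Visit trait rather than VisitMut.

```rust
// Hypothetical stand-in for a syntax tree.
enum Expr {
    Lit(i64),
    Call(Vec<Expr>),
}

// The visitor's state: a running count of `42` literals.
struct FortyTwoCounter {
    seen: usize,
}

impl FortyTwoCounter {
    fn visit_expr(&mut self, node: &Expr) {
        match node {
            Expr::Lit(42) => self.seen += 1,
            Expr::Lit(_) => {}
            Expr::Call(args) => {
                for arg in args {
                    // Recurse with the same `self`, so the count accumulates.
                    self.visit_expr(arg);
                }
            }
        }
    }
}

fn main() {
    let tree = Expr::Call(vec![
        Expr::Lit(42),
        Expr::Call(vec![Expr::Lit(1), Expr::Lit(42)]),
    ]);
    let mut counter = FortyTwoCounter { seen: 0 }; // initial state
    counter.visit_expr(&tree);
    println!("saw 42 {} times", counter.seen);
}
```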

For technical reasons related to how proc macros are implemented (they are, after all, a form of compiler plugin), ::proc_macro functions are currently only usable within a proc-macro = true crate. This means that neither integration tests nor regular crates (such as the Playground environment) can call them.

::proc_macro2 is thus a wrapper that falls back to a polyfill when the code using its types and functions is not running in a proc-macro = true context. That's why your code will be more flexible if it uses ::proc_macro2 types such as its TokenStream (often renamed to TokenStream2), and only performs the trivial .into() conversions to and from ::proc_macro::TokenStream within the #[proc_macro]-annotated functions :slightly_smiling_face:

Ohhh, got it! Thank you so much, @Yandros! I appreciate your additions to this. I've been tinkering with it and I think it is working for me! The state you mentioned for the Visitor struct is something I will definitely need for my use case.

I see. I was slightly confused and a bit concerned when I saw proc-macro2; I thought I would be working with deprecated functions, hahah :sweat_smile:

Now I need to work on differentiating the branches, but I definitely learnt a lot more about working with the AST and macros in Rust from this! Thank you so much, guys! :slight_smile: I really appreciate it!
