Create AST with std::ops::Add/Mul without global state

Hi. New user of Rust here trying to implement my first Rust program.

I want to build an abstract syntax tree (AST) using operator overloading with std::ops::Add etc., so that I can write code like let f = x * w + b * 3.0; and then mutate the values in the tree. I want to store the AST nodes in a stack so that they are topologically sorted as they are created.

My current solution works, but I need a static mut STATIC_NS to access my stack from std::ops::Add. I tried changing the NodePtr struct in the code below to:

struct NodePtr {
    i: usize,
    ns: Rc<RefCell<NodeStack>>,
}

so that I could create the nodes with std::ops::Add without global state. But then I can't figure out how to mutate the nodes in the tree. I also want to keep the syntax clean when creating expressions. Any ideas how to do it? Is using Rc<RefCell<>> the right way to go?

This is a running example of my current code for creating the AST.

#[derive(Debug, Clone, Default)]
struct NodeData {
    value: f32,
    label: String
}

#[derive(Debug, Clone)]
enum Op {
    Add,
    Mul,
}
#[derive(Debug, Clone)]
enum UnaryOp {
    Log,
}

#[derive(Debug, Clone)]
enum NodeType {
    Value,
    BinaryOp { op: Op, left: usize, right: usize },
    UnaryOp { op: UnaryOp, left: usize }
}

#[derive(Debug, Clone)]
struct Node {
    data: NodeData,
    node: NodeType
}

#[derive(Debug)]
struct NodeStack {
    stack: Vec<Node>,
}

#[derive(Debug, Clone, Copy)]
struct NodePtr { // Pointer to an index in the NodeStack
    i: usize,
}

static mut STATIC_NS: NodeStack = NodeStack { stack: Vec::new() };
impl NodeStack {
    fn get_static() -> &'static mut NodeStack {
        unsafe { &mut STATIC_NS }
    }

    fn get(&self, node: &NodePtr) -> &Node {
        &self.stack[node.i]
    }

    fn get_mut(&mut self, node: &NodePtr) -> &mut Node {
        &mut self.stack[node.i]
    }

    fn val(&mut self, data: f32) -> NodePtr {
        let node = Node {
            data: NodeData { value: data, ..NodeData::default() },
            node: NodeType::Value
        };
        self.stack.push(node);
        NodePtr { i: self.stack.len() - 1 }
    }

    fn set_label(&mut self, node: &NodePtr, label: &str) {
        self.stack[node.i].data.label = String::from(label);
    }

    fn create_op(&mut self, op: Op, left: &NodePtr, right: &NodePtr) -> NodePtr {
        let node = Node {
            data: NodeData::default(),
            node: NodeType::BinaryOp { op, left: left.i, right: right.i }
        };
        self.stack.push(node);
        let node_ptr = NodePtr { i: self.stack.len() - 1 };
        node_ptr
    }
}

macro_rules! val {
    ($ns:expr, $var:ident, $value:expr) => {
        let mut $var = $ns.val($value);
        $ns.set_label(&$var, stringify!($var));
    };
}

macro_rules! binary_operator_overload {
    ($op:ident, $f:ident) => {
        impl std::ops::$op<NodePtr> for NodePtr {
            type Output = NodePtr;

            fn $f(self, other: NodePtr) -> NodePtr {
                unsafe { STATIC_NS.create_op(Op::$op, &self, &other) }
            }
        }
        impl std::ops::$op<f32> for NodePtr {
            type Output = NodePtr;

            fn $f(self, other: f32) -> NodePtr {
                unsafe { STATIC_NS.create_op(Op::$op, &self, &STATIC_NS.val(other)) }
            }
        }
    }
}
binary_operator_overload!(Add, add);
binary_operator_overload!(Mul, mul);

fn main() {
    let ns = NodeStack::get_static();
    val!(ns, w, 1.);
    val!(ns, b, 1.);
    val!(ns, x, 1.);

    let f = x * w + b * 3.;
    ns.get_mut(&b).data.value = 2.;
    println!("b: {:?}", ns.get(&b));
    println!("f = w * x + b: {:?}.", ns.get(&f));
}

static mut is pretty much impossible to use soundly (experts get it wrong), and will probably be deprecated once the common use cases have reasonable defaults in std. If I run your code against Miri in the playground (under Tools, upper-right), it detects undefined behavior.

Here's a pretty mechanical translation of your Rc<RefCell<NodeStack>> approach. I ended up adding a bunch of operator variations, and you now have to use references (or clone) in the expressions (or give up ownership of the variables).
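
As a rough sketch of what an Rc<RefCell<NodeStack>> translation might look like (this is not the code from the linked playground): labels are dropped for brevity, and the names Shared, push, binary, set, eval, eval_at and demo are hypothetical helpers introduced only for this illustration. Each handle carries a shared pointer to the stack, so the operators and the mutation both go through borrow_mut() instead of a global.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};
use std::rc::Rc;

#[derive(Debug, Clone, Copy)]
enum Op { Add, Mul }

#[derive(Debug, Clone)]
enum NodeType {
    Value,
    BinaryOp { op: Op, left: usize, right: usize },
}

#[derive(Debug, Clone)]
struct Node { value: f32, node: NodeType }

#[derive(Debug, Default)]
struct NodeStack { stack: Vec<Node> }

// Hypothetical alias, only to keep signatures short.
type Shared = Rc<RefCell<NodeStack>>;

// A handle is an index plus a shared pointer to the stack it indexes into.
// It is Clone but no longer Copy, because Rc is not Copy.
#[derive(Debug, Clone)]
struct NodePtr { i: usize, ns: Shared }

impl NodePtr {
    fn push(ns: &Shared, value: f32, node: NodeType) -> NodePtr {
        let mut guard = ns.borrow_mut(); // released when `push` returns
        guard.stack.push(Node { value, node });
        NodePtr { i: guard.stack.len() - 1, ns: Rc::clone(ns) }
    }
    fn val(ns: &Shared, value: f32) -> NodePtr {
        Self::push(ns, value, NodeType::Value)
    }
    fn binary(&self, op: Op, rhs: &NodePtr) -> NodePtr {
        Self::push(&self.ns, 0.0, NodeType::BinaryOp { op, left: self.i, right: rhs.i })
    }
    // Mutate a leaf through the shared stack; no `&mut` handle is needed.
    fn set(&self, value: f32) {
        self.ns.borrow_mut().stack[self.i].value = value;
    }
    fn eval(&self) -> f32 {
        eval_at(&self.ns.borrow(), self.i)
    }
}

// Recursively evaluate the node at index `i` (an addition for this sketch).
fn eval_at(ns: &NodeStack, i: usize) -> f32 {
    match &ns.stack[i].node {
        NodeType::Value => ns.stack[i].value,
        NodeType::BinaryOp { op, left, right } => {
            let (l, r) = (eval_at(ns, *left), eval_at(ns, *right));
            match op { Op::Add => l + r, Op::Mul => l * r }
        }
    }
}

macro_rules! binary_op {
    ($tr:ident, $f:ident) => {
        impl $tr<&NodePtr> for &NodePtr {
            type Output = NodePtr;
            fn $f(self, rhs: &NodePtr) -> NodePtr { self.binary(Op::$tr, rhs) }
        }
        impl $tr<NodePtr> for NodePtr {
            type Output = NodePtr;
            fn $f(self, rhs: NodePtr) -> NodePtr { self.binary(Op::$tr, &rhs) }
        }
        impl $tr<f32> for &NodePtr {
            type Output = NodePtr;
            fn $f(self, rhs: f32) -> NodePtr {
                // The constant is pushed first, then the operator node.
                self.binary(Op::$tr, &NodePtr::val(&self.ns, rhs))
            }
        }
    };
}
binary_op!(Add, add);
binary_op!(Mul, mul);

fn demo() -> (f32, f32) {
    let ns: Shared = Rc::new(RefCell::new(NodeStack::default()));
    let (w, b, x) = (NodePtr::val(&ns, 1.0), NodePtr::val(&ns, 1.0), NodePtr::val(&ns, 1.0));
    let f = &x * &w + &b * 3.0; // borrow the leaves so they stay usable
    let before = f.eval(); // 1*1 + 1*3
    b.set(2.0);
    (before, f.eval()) // 1*1 + 2*3
}

fn main() {
    println!("{:?}", demo()); // prints (4.0, 7.0)
}
```

Only three operator combinations are implemented here; mixed forms such as NodePtr * &NodePtr or f32 * NodePtr would need their own impls.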


How about using one of those static mut replacements I mentioned? Here's another mostly mechanical translation using OnceLock and a Mutex. The first version I had that compiled deadlocked, which is why I added:

    let mut ns = NodeStack::get_static();
    val!(ns, w, 1.);
    val!(ns, b, 1.);
    val!(ns, x, 1.);
    drop(ns);
//  ^^^^^^^^^

...because your operators lock the Mutex (and trying to obtain more than one lock on the same thread results in a deadlock or panic or similar). You'd have to take care to never hold on to the lock while you're performing operations (generally you don't want to hold on to the lock much at all). Moving that into your val macro helps some.

But it's still foot-gunny enough I'd probably prefer the other attempt. Or a re-entrant mutex (but there isn't one in std).
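
A minimal sketch of the global OnceLock plus Mutex variant, assuming the lock is taken inside every helper so that no guard is alive while an operator runs (again, not the playground code; NS, push, val, binary, eval and demo are hypothetical names, and labels are left out):

```rust
use std::ops::{Add, Mul};
use std::sync::{Mutex, MutexGuard, OnceLock};

#[derive(Debug, Clone, Copy)]
enum Op { Add, Mul }

#[derive(Debug, Clone)]
enum NodeType {
    Value,
    BinaryOp { op: Op, left: usize, right: usize },
}

#[derive(Debug, Clone)]
struct Node { value: f32, node: NodeType }

#[derive(Debug, Default)]
struct NodeStack { stack: Vec<Node> }

#[derive(Debug, Clone, Copy)]
struct NodePtr { i: usize } // stays Copy: it is only an index

// Safe replacement for `static mut STATIC_NS`.
static NS: OnceLock<Mutex<NodeStack>> = OnceLock::new();

impl NodeStack {
    fn get_static() -> MutexGuard<'static, NodeStack> {
        NS.get_or_init(|| Mutex::new(NodeStack::default())).lock().unwrap()
    }
    fn push(&mut self, value: f32, node: NodeType) -> NodePtr {
        self.stack.push(Node { value, node });
        NodePtr { i: self.stack.len() - 1 }
    }
}

// Each helper locks, pushes, and unlocks before returning.
fn val(value: f32) -> NodePtr {
    NodeStack::get_static().push(value, NodeType::Value)
}

fn binary(op: Op, left: NodePtr, right: NodePtr) -> NodePtr {
    NodeStack::get_static().push(0.0, NodeType::BinaryOp { op, left: left.i, right: right.i })
}

// Recursively evaluate the node at index `i` (an addition for this sketch).
fn eval(ns: &NodeStack, i: usize) -> f32 {
    match &ns.stack[i].node {
        NodeType::Value => ns.stack[i].value,
        NodeType::BinaryOp { op, left, right } => {
            let (l, r) = (eval(ns, *left), eval(ns, *right));
            match op { Op::Add => l + r, Op::Mul => l * r }
        }
    }
}

macro_rules! binary_op {
    ($tr:ident, $f:ident) => {
        impl $tr<NodePtr> for NodePtr {
            type Output = NodePtr;
            fn $f(self, rhs: NodePtr) -> NodePtr { binary(Op::$tr, self, rhs) }
        }
        impl $tr<f32> for NodePtr {
            type Output = NodePtr;
            // `val` has already released the lock when `binary` takes it.
            fn $f(self, rhs: f32) -> NodePtr { binary(Op::$tr, self, val(rhs)) }
        }
    };
}
binary_op!(Add, add);
binary_op!(Mul, mul);

fn demo() -> (f32, f32) {
    let (w, b, x) = (val(1.0), val(1.0), val(1.0));
    let f = x * w + b * 3.0; // no guard is held here, so the operators can lock
    let mut ns = NodeStack::get_static(); // hold the guard only while reading/mutating
    let before = eval(&ns, f.i); // 1*1 + 1*3
    ns.stack[b.i].value = 2.0;
    (before, eval(&ns, f.i)) // 1*1 + 2*3
}

fn main() {
    println!("{:?}", demo()); // prints (4.0, 7.0)
}
```

The expression syntax stays as clean as in the original, but writing x * w while a guard from get_static() is still in scope would deadlock or panic, which is the foot-gun described above.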


I suspect I haven't found the best solution to your question, but that's where I got to.


Thanks, that was an amazing reply, I learned a lot!

In my own attempt using Rc<RefCell<NodeStack>> I didn't think of overloading &NodePtr for the Add and Mul operations, so I got an issue with:

    let f = x * w + b * 3.;
    let g = &f * 10.;

since f was moved when I created g. I changed val! to return a &NodePtr so that I can skip some of the & when creating nodes.

I think it's a bit interesting that RustRover cannot handle ns.get_mut(&b).data.value = 2.; and shows an error for that line of code.

Thanks a lot!
