`Weak` reference returning `None` in a tree

Hi

I am trying to code a neural network from scratch. My data structure is a tree with each node wrapping an f32.

#[derive(Clone, Debug)]
pub struct Scalar {
    node: Rc<RefCell<Node>>,
}

#[derive(Clone, Debug)]
struct Node {
    value: f32,
    children: Vec<Rc<Scalar>>,
    parent: Weak<Scalar>,
    label: String,
}

These are created with the following code:

impl Scalar {
    pub fn new(value: f32) -> Self {
        Self {
            node: Rc::new(RefCell::new(Node::new(value))),
        }
    }

    pub fn with_label(&self, label: &str) -> Self {
        {
            self.node.borrow_mut().label = label.to_string();
        }
        self.clone()
    }
}

impl Node {
    fn new(value: f32) -> Self {
        Self {
            value,
            children: vec![],
            parent: Weak::new(),
            label: "".to_string(),
        }
    }
}

When performing a basic maths operation such as adding, for example c = a + b, the operands a and b become children of c with strong references, while each of a and b link back to c with weak references. This is so that I can later implement automatic function differentiation for back propagation. I implement Add for the Scalar like this:

impl Add<&Scalar> for &Scalar {
    type Output = Scalar;

    fn add(self, rhs: &Scalar) -> Self::Output {
        let sum = self.node.borrow().value + rhs.node.borrow().value;
        let scalar = Scalar::new(sum);

        // make children
        scalar.node.borrow_mut().children = vec![Rc::new(self.clone()), Rc::new(rhs.clone())];

        // children need parents
        // FIXME: THIS DOES NOT WORK
        self.node.borrow_mut().parent = Rc::downgrade(&Rc::new(scalar.clone()));
        rhs.node.borrow_mut().parent = Rc::downgrade(&Rc::new(scalar.clone()));
        eprintln!(
            "trying to upgrade............. {:?}",
            self.node.borrow().parent
        );

        scalar
    }
}

The weak references back to the parent always end up returning None when I upgrade it. Reading the documentation and the chapter on reference cycles from the Rust book, this happens when the parent goes out of scope. I cannot identify why this happens and I think I am running around in circles trying to find a solution. My test case is the following:

#[cfg(test)]
mod tests {
    use super::*;

    fn float_compare(a: f32, b: f32) -> bool {
        const TOLERANCE: f32 = 1e-6;

        (a - b).abs() < TOLERANCE
    }

    #[test]
    fn add_two_scalars() {
        let a = Scalar::new(3.1).with_label("a");
        let b = Scalar::new(4.2).with_label("b");
        let c = (&a + &b).with_label("c");

        assert!(float_compare(c.node.borrow().value, 7.3));

        let label_child_0 = c.node.borrow().children[0].node.borrow().label.clone();
        let label_child_1 = c.node.borrow().children[1].node.borrow().label.clone();
        assert_eq!(label_child_0, "a");
        assert_eq!(label_child_1, "b");

        assert!(float_compare(a.node.borrow().value, 3.1));

        // TODO: this should not work
        assert!(a.node.borrow().parent.upgrade().is_none());
    }
}

I would like help in identifying what I do wrong with my code. Suggestions for crates that do similar tasks are appreciated, but I would prefer to solve this using as few dependencies as possible because my main goal is to expose myself to the standard library and learn to be a better Rust programmer. This is the first time I am learning about interior mutability (I come from the world of Python) so if there is a better way to be more idiomatic in Rust please feel free to critique my code.

Thanks!

You create a new Rc here and downgrade it immediately. That temporary Rc is the only strong reference, so when it is dropped at the end of the statement the strong count goes from 1 to 0 and the value is dropped, leaving a Weak that can never be upgraded. You need to downgrade a clone of the Rc that stores the parent node, not create a new Rc. This might be fixable just by removing the Rc::new call here, but I haven't read your snippets fully.
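To see the effect in isolation, here is a minimal sketch using a plain f32 rather than your Scalar. The function names are made up for illustration; only Rc::new, Rc::downgrade and Weak::upgrade come from the standard library.

```rust
use std::rc::{Rc, Weak};

/// Downgrades a temporary `Rc`. The temporary is the only strong reference
/// and is dropped as soon as the expression has been evaluated, so the
/// returned `Weak` can never be upgraded.
fn weak_from_temporary(value: f32) -> Weak<f32> {
    Rc::downgrade(&Rc::new(value))
}

/// Downgrades an `Rc` that the caller keeps alive. The returned `Weak`
/// upgrades successfully for as long as `owner` (or a clone of it) exists.
fn weak_from_owner(owner: &Rc<f32>) -> Weak<f32> {
    Rc::downgrade(owner)
}
```

Calling upgrade() on the result of weak_from_temporary gives None, which is what your add implementation runs into. The result of weak_from_owner upgrades to Some until the last strong Rc is dropped.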

I am not sure how to do that because I create scalar in the current scope which is not wrapped in an Rc. Removing Rc::new() gives an error because I am trying to downgrade a Scalar instead.

Yes, I realized after I created a playground from your snippets. Here is a possible approach where we link nodes to each other rather than linking nodes to scalars.

You have a design problem here; Scalar is internally reference counted, and thus you can't get a weak reference to your parent, only a strong reference, because the reference counting is not exposed.

There are two ways out of this:

  1. Move the Rc outside struct Scalar, and implement things for Rc<Scalar> instead of for Scalar, leaving Node unchanged.
  2. Add a WeakScalar struct, with promotion to Scalar, and exploit the fact that Scalar is internally reference counted.

For the second approach, you'd write something like:

#[derive(Clone, Debug)]
pub struct Scalar {
    node: Rc<RefCell<Node>>,
}

impl Scalar {
    pub fn downgrade(this: &Scalar) -> WeakScalar {
        // Rc::downgrade takes a reference to the Rc, not the Rc itself.
        let node = Rc::downgrade(&this.node);
        WeakScalar { node }
    }
}

#[derive(Clone, Debug)]
pub struct WeakScalar {
    node: Weak<RefCell<Node>>,
}

impl WeakScalar {
    pub fn upgrade(&self) -> Option<Scalar> {
        self.node.upgrade().map(|node| Scalar { node })
    }
}

#[derive(Clone, Debug)]
struct Node {
    value: f32,
    children: Vec<Scalar>,
    parent: WeakScalar,
    label: String,
}
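Putting those pieces together with your Add impl, a trimmed-down sketch could look like the following. The label field is left out for brevity, and the value(), children() and parent() accessors are hypothetical helpers added only so the behaviour can be checked from outside; they are not part of your code.

```rust
use std::cell::RefCell;
use std::ops::Add;
use std::rc::{Rc, Weak};

#[derive(Clone, Debug)]
pub struct Scalar {
    node: Rc<RefCell<Node>>,
}

#[derive(Clone, Debug)]
pub struct WeakScalar {
    node: Weak<RefCell<Node>>,
}

#[derive(Debug)]
struct Node {
    value: f32,
    children: Vec<Scalar>,
    parent: WeakScalar,
}

impl Scalar {
    pub fn new(value: f32) -> Self {
        let node = Node {
            value,
            children: Vec::new(),
            // An empty Weak: upgrading it yields None until a parent is set.
            parent: WeakScalar { node: Weak::new() },
        };
        Self { node: Rc::new(RefCell::new(node)) }
    }

    pub fn downgrade(this: &Scalar) -> WeakScalar {
        WeakScalar { node: Rc::downgrade(&this.node) }
    }

    // Hypothetical accessors, added for this sketch only.
    pub fn value(&self) -> f32 {
        self.node.borrow().value
    }

    pub fn children(&self) -> Vec<Scalar> {
        self.node.borrow().children.clone()
    }

    pub fn parent(&self) -> Option<Scalar> {
        self.node.borrow().parent.upgrade()
    }
}

impl WeakScalar {
    pub fn upgrade(&self) -> Option<Scalar> {
        self.node.upgrade().map(|node| Scalar { node })
    }
}

impl Add<&Scalar> for &Scalar {
    type Output = Scalar;

    fn add(self, rhs: &Scalar) -> Scalar {
        let result = Scalar::new(self.value() + rhs.value());

        // Strong references downwards: the result keeps its operands alive.
        result.node.borrow_mut().children = vec![self.clone(), rhs.clone()];

        // Weak references upwards: they point at the node that `result`
        // owns, so they stay valid for as long as the caller keeps `result`.
        self.node.borrow_mut().parent = Scalar::downgrade(&result);
        rhs.node.borrow_mut().parent = Scalar::downgrade(&result);

        result
    }
}
```

With this, a.parent() is Some(..) after let c = &a + &b; and goes back to None once every strong handle to c has been dropped.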

Thank you! This works. I never realised that Node needs to be an Rc instead of the Scalar. Is chaining two RefCells common in Rust?

Thank you! I understand what you are saying but I am not sure I understand what the code is doing. I will have to play around with it.

There is no chaining of RefCells going on as we only ever create clones of the outer Rc. All clones of an Rc point to the same RefCell. You can abstract the Rc<RefCell<Node>> type into a type alias like:

type NodeOuter = Rc<RefCell<Node>>;
type NodeOuterWeak = Weak<RefCell<Node>>;

#[derive(Clone, Debug)]
struct Node {
    value: f32,
    children: Vec<NodeOuter>,
    parent: NodeOuterWeak,
    label: String,
}

Some find that to be nicer to read.
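As a rough sketch of how linking nodes directly works with those aliases (label omitted; leaf and add_nodes are hypothetical helper names, not taken from the playground):

```rust
use std::cell::RefCell;
use std::rc::{Rc, Weak};

type NodeOuter = Rc<RefCell<Node>>;
type NodeOuterWeak = Weak<RefCell<Node>>;

#[derive(Debug)]
struct Node {
    value: f32,
    children: Vec<NodeOuter>,
    parent: NodeOuterWeak,
}

/// Creates a node with no children and no parent.
fn leaf(value: f32) -> NodeOuter {
    Rc::new(RefCell::new(Node {
        value,
        children: Vec::new(),
        parent: Weak::new(),
    }))
}

/// Builds a parent node holding the sum and links it to both operands.
fn add_nodes(lhs: &NodeOuter, rhs: &NodeOuter) -> NodeOuter {
    let sum = lhs.borrow().value + rhs.borrow().value;
    let parent = leaf(sum);

    // Rc::clone only bumps the strong count; no new RefCell is created.
    parent.borrow_mut().children = vec![Rc::clone(lhs), Rc::clone(rhs)];

    // Downgrade the Rc that is actually returned to the caller.
    lhs.borrow_mut().parent = Rc::downgrade(&parent);
    rhs.borrow_mut().parent = Rc::downgrade(&parent);

    parent
}
```

Because children[0] and the original handle are clones of the same Rc, a change made through one is visible through the other: there is exactly one RefCell per node.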

To make it easier to play with, I've put the second option in the Rust Playground - I've changed the assert at the end of the test case, since it was firing saying that a had no parent, which it now does (c is a parent of a in the test case).

Note that I've only considered the reference counting in this code - in a full code review, I'd be highly suspicious of with_label because (a) it returns a value, but modifies the owned value, too, and (b) it takes &str only to convert it unconditionally to String, which is normally an anti-pattern (just take String and let the caller handle the conversion when needed). And I'm not convinced by the parent naming - but I'd need to see the full algorithm to know whether I'm being overly cautious, or if it's a problematic name.

Thanks! Seeing the full code makes sense to me now.

That is an artifact from working with TensorFlow, where each layer is capable of being assigned a name. Except in my case, I am working with individual nodes. with_label() seemed to me to be a convenience function to give a node a new name while it is created from a mathematical operation. Would this be a better approach?

let c: Scalar = &a + &b;
c.add_label("c");

As for your second point (b), how would you approach labelling a node? Store it as &str?

I agree, the names are too generic. I need parent and children so that I can compute functions going two ways. Take for example g = a + b * c:
intermediate_1 = b * c
intermediate_2 = a + intermediate_1
g = intermediate_2
Then, in the backward pass, I compute the gradients:
dg/dg = 1
dg/d intermediate_2 = 1
d intermediate_1/db = c
and so on. Normally all this is vectorised with hundreds if not thousands of nodes per layer (and one deals at the layer abstraction instead of the individual node abstraction) so these effectively become matrix multiplications. I think I digress. The reason I chose those names was that they were in the Rust book and I wanted to get this tree data structure to work.
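Written out by hand for plain f32 values, with g taken to be the result of a + intermediate_1, the two passes would look roughly like this. It is only a sketch of the arithmetic; forward, backward and Gradients are names I made up here, and nothing is wired into the tree yet.

```rust
/// Gradients of g = a + b * c with respect to each input.
#[derive(Debug, PartialEq)]
struct Gradients {
    da: f32,
    db: f32,
    dc: f32,
}

/// Forward pass: evaluates g = a + b * c through the intermediate nodes.
fn forward(a: f32, b: f32, c: f32) -> f32 {
    let intermediate_1 = b * c;
    let intermediate_2 = a + intermediate_1;
    intermediate_2
}

/// Backward pass: applies the chain rule from g down to the inputs.
/// The gradients do not depend on the value of `a`, so it is not a parameter.
fn backward(b: f32, c: f32) -> Gradients {
    let dg_dg = 1.0_f32;

    // g is intermediate_2, so the gradient passes through unchanged.
    let dg_dintermediate_2 = dg_dg;

    // intermediate_2 = a + intermediate_1: addition hands the incoming
    // gradient to both operands unchanged.
    let dg_da = dg_dintermediate_2;
    let dg_dintermediate_1 = dg_dintermediate_2;

    // intermediate_1 = b * c: each operand receives the incoming gradient
    // multiplied by the other operand.
    let dg_db = dg_dintermediate_1 * c;
    let dg_dc = dg_dintermediate_1 * b;

    Gradients { da: dg_da, db: dg_db, dc: dg_dc }
}
```

For a = 2, b = 3, c = 4 the forward pass gives 14, and the gradients are dg/da = 1, dg/db = 4 and dg/dc = 3.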

Thank you for a solution and also for doing a brief code review.

I'd make two changes to with_label, assuming the use in the test case is typical:

First, make it take String instead of &str:

    pub fn with_label(&self, label: String) -> Self {
        {
            self.node.borrow_mut().label = label;
        }
        self.clone()
    }

This means that a caller looks like: c.with_label("c".into()).

Second, I'd change it to be a move of this value, not a clone function:

    pub fn with_label(self, label: String) -> Self {
        self.node.borrow_mut().label = label;
        self
    }

If you want a clone, you'd call a.clone().with_label("a".into()) or similar.

These changes are in this Rust Playground, with all three options for converting &str to String used for you to compare. The point of moving the conversion of &str to String out of the function is to avoid copying labels where you're generating them programmatically - for example, if you label something with format!("{} + {}", a.label(), b.label()), you don't want to create a String then clone it, then deallocate it, when you can create it and move it into place.
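As a self-contained sketch of that point (children and parent left out; value(), label() and labelled_sum are hypothetical helpers for this example, and the three conversions shown in the usage note are my guess at which three are meant):

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug)]
struct Node {
    value: f32,
    label: String,
}

#[derive(Debug)]
pub struct Scalar {
    node: Rc<RefCell<Node>>,
}

impl Scalar {
    pub fn new(value: f32) -> Self {
        let node = Node { value, label: String::new() };
        Self { node: Rc::new(RefCell::new(node)) }
    }

    /// Takes ownership of both the Scalar and the label, so the String is
    /// moved into place without an extra allocation.
    pub fn with_label(self, label: String) -> Self {
        self.node.borrow_mut().label = label;
        self
    }

    // Hypothetical accessors for this sketch.
    pub fn value(&self) -> f32 {
        self.node.borrow().value
    }

    pub fn label(&self) -> String {
        self.node.borrow().label.clone()
    }
}

/// Builds a label programmatically: the String produced by `format!` is
/// moved straight into the node rather than being copied from a &str.
fn labelled_sum(a: &Scalar, b: &Scalar) -> Scalar {
    let label = format!("{} + {}", a.label(), b.label());
    Scalar::new(a.value() + b.value()).with_label(label)
}
```

A caller holding a string literal can write any of with_label("a".into()), with_label("a".to_string()) or with_label(String::from("a")); all three produce the same String.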

With the label changes made, I end up with this Playground

Is this to signify intent that I want to change self? Or is there a functional difference compared with taking &self and returning self.clone()?

Because you're reference-counted, this is all about making the semantics of the operation clear to a later reader.

Semantically, your uses of with_label are not "take a shared reference to a Scalar that belongs to someone else, do something with it", but instead "take ownership of this Scalar and return a new Scalar that's the same as the owned one, but with a label".

And part of the reason I'm thinking this way is that as I look at your code, I get a strong sense that #[derive(Clone)] is doing the wrong thing for Scalar, because the Rc is an implementation detail needed for automatic differentiation. I suspect (but I'd need to see much more code to be confident of this) that if you have need of a Scalar::clone() in user code, the expectation is going to be that this is a fresh Scalar with the same value, not a reference to the existing Scalar.

By making the changes to remove self.clone() wherever possible, I open up a refactor to hand-implement Clone such that it "disconnects" the copy from the original.
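To make that concrete, here is one way such a hand-written Clone might look. This is a sketch under my own assumptions: children and parent are left out, share() is a hypothetical name for "give me another handle to the same node" (which is what the derived Clone does today), and whether a copy should keep its label is a design decision I have not thought through.

```rust
use std::cell::RefCell;
use std::rc::Rc;

#[derive(Debug)]
struct Node {
    value: f32,
    label: String,
}

#[derive(Debug)]
pub struct Scalar {
    node: Rc<RefCell<Node>>,
}

impl Scalar {
    pub fn new(value: f32) -> Self {
        let node = Node { value, label: String::new() };
        Self { node: Rc::new(RefCell::new(node)) }
    }

    pub fn value(&self) -> f32 {
        self.node.borrow().value
    }

    pub fn set_value(&self, value: f32) {
        self.node.borrow_mut().value = value;
    }

    /// Hypothetical: another handle to the same underlying node.
    fn share(&self) -> Self {
        Self { node: Rc::clone(&self.node) }
    }
}

impl Clone for Scalar {
    /// Allocates a fresh node with the same value and label, so later
    /// changes to the original are not visible through the copy.
    fn clone(&self) -> Self {
        let node = self.node.borrow();
        Self {
            node: Rc::new(RefCell::new(Node {
                value: node.value,
                label: node.label.clone(),
            })),
        }
    }
}
```

After original.set_value(5.0), a handle obtained with share() also reads 5.0, while a clone() taken beforehand still reads the old value.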
