Should I use a RefCell to share things between all instances of a class?

Hi all! I was hoping to get some help with the following issue. I'm learning Rust, and the project I'm working on is a library that involves something like the following:


```rust
// `Pullback` is defined elsewhere in the library; it must also implement Debug and Clone.
#[derive(Debug, Clone, PartialEq)]
pub enum TensorType {
    ScalarT(f64),
    VectorT(Vec<f64>),
    MatrixT(Vec<Vec<f64>>),
}

#[derive(Debug, Clone)]
pub struct Matrix {
    // Forward pass computations
    pub data: TensorType,
    pub shape: (usize, usize),

    // Used to determine if we should apply a pullback or stop going back through the chain here
    pub trainable: bool,
    pub index: usize, // this matrix's position in the gradient tape

    // The parent(s), if any. Source nodes have no parents, and some nodes
    // (e.g. the result of a unary op) have only one parent
    pub parents: [Option<usize>; 2],
    pub pullbacks: [Option<Pullback>; 2],
    pub gradient_tape: Vec<Matrix>,
}
```

This is related to automatic differentiation. Basically, each time we create a matrix, we want to "store" it on a shared chain of data, which lets us know about each matrix's parents so that we can propagate gradients back through the chain. The issue I'm running into is the following:

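```rust
impl Matrix {
    pub fn from_scalar(v: f64) -> Self {
        Self {
            data: TensorType::ScalarT(v),
            shape: (1, 1),
            trainable: true,
            index: todo!(),         // ?? this matrix's position on the shared tape
            parents: [None, None],
            pullbacks: [None, None],
            gradient_tape: todo!(), // ?? every matrix should see the same tape
        }
    }
}
```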

and the ?? sections highlight my problem. Ideally, all created instances would have access to the shared chain so that I could do something like

```rust
index: gradient_tape.len(),
gradient_tape: persistent_gradient_tape,
```

but I can't. Could a RefCell address my issue? Sorry, I'm still relatively new to all of this. Ideally, an end user would have a workflow that looks something like

```rust
let m1 = Matrix::from_scalar(5.0);
let m2 = Matrix::from_vector(vec![1.0, 2.0, 3.0]);
let m3 = m2 * m1;
```

such that all three matrices (m1, m2, and m3) have access to the same gradient_tape, which now has length 3.

I'm happy to restructure the program with your insights. Thank you!

RefCell should be used when you need "interior mutability", i.e., mutating shared values (see the std::cell documentation). This isn't directly relevant here.
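A minimal illustration of interior mutability, using a hypothetical `record` helper: the function only receives a shared `&RefCell`, yet it can push to the vector inside, because `RefCell::borrow_mut` checks the borrowing rules at runtime instead of at compile time.

```rust
use std::cell::RefCell;

// Hypothetical helper: appends to a log it can only see through a shared reference.
fn record(log: &RefCell<Vec<String>>, msg: &str) {
    // borrow_mut() checks at runtime that no other borrow is active,
    // then hands out a mutable view even though `log` is a shared `&`.
    log.borrow_mut().push(msg.to_string());
}

fn main() {
    let log = RefCell::new(Vec::new());
    // Two shared references to the same RefCell; neither is `&mut`.
    let a = &log;
    let b = &log;
    record(a, "first");
    record(b, "second");
    println!("{:?}", log.borrow()); // ["first", "second"]
}
```

If a second `borrow_mut()` were taken while the first was still alive, the program would panic rather than fail to compile; that runtime check is the price of mutating through `&`.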

My first remark is: are you 1000% sure that you want all the created matrices to be part of the same chain? That users of your library will never want two separate matrix chains? If so, you can store the Vec<Matrix> in a static variable (what other languages call a "global variable"). However, in my experience, static/global variables are rarely needed and seldom the right choice, except perhaps for certain kinds of caches, or in high-level languages to offer a very simple interface alongside a more full-fledged one. If a static/global variable is not appropriate, you should store the vector separately. Consider creating a pub struct GradientTape { pub contents: Vec<Matrix> } with methods that create matrices linked to your "gradient tape".
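For the static-variable option, a hedged sketch: Rust won't let you mutate a plain `static` without `unsafe`, so a global tape needs interior mutability, which is where a RefCell would actually come in. This version uses `thread_local!` with a `RefCell` (one tape per thread); a `static` holding a `Mutex` would be the process-wide equivalent. `Node`, `push_node` and `value_at` are hypothetical names, and nodes hold only a value.

```rust
use std::cell::RefCell;

// Hypothetical, stripped-down tape entry (a real one would also hold parents and pullbacks).
#[derive(Debug, Clone)]
struct Node {
    value: f64,
}

thread_local! {
    // One global tape per thread. Mutating a global requires interior mutability:
    // RefCell here, or a Mutex in a `static` if the tape must be shared across threads.
    static TAPE: RefCell<Vec<Node>> = RefCell::new(Vec::new());
}

// Pushes a node onto the global tape and returns its index.
fn push_node(value: f64) -> usize {
    TAPE.with(|tape| {
        let mut tape = tape.borrow_mut();
        tape.push(Node { value });
        tape.len() - 1
    })
}

// Reads back the value stored at `index`.
fn value_at(index: usize) -> f64 {
    TAPE.with(|tape| tape.borrow()[index].value)
}

fn main() {
    let m1 = push_node(5.0);
    let m2 = push_node(2.0);
    println!("m1 at {m1}, m2 at {m2}, m2 = {}", value_at(m2));
}
```

The catch is the one raised above: everything on that thread shares one tape, so two independent chains built on the same thread would interfere with each other.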

By storing a Vec<Matrix> in the Matrix struct, you necessarily make different matrices use different vectors (a field of type Vec<Matrix> owns its vector, and a value cannot have more than one owner). To share the vector, you would use a reference (&Vec<Matrix>) or shared ownership (Rc<Vec<Matrix>>), for example. But this still doesn't work, because you're trying to build a self-referential data structure. Essentially, if the Vec<Matrix> gets moved, the references to it inside the Matrix elements it contains would need to be updated too (since the vector's address changes), but Rust does not let you customize what happens when a value is moved. With Rc, you instead get a reference cycle: the tape owns matrices that own the tape, so it is never freed. The right search keywords are "Rust self-referential structure". There are complicated ways around this restriction (see std::pin), but basically you should just avoid this pattern as much as possible.
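To see why switching to Rc doesn't rescue the design, here is a minimal sketch (the `Node` type and its fields are hypothetical) in which each node holds an `Rc` handle to the tape that stores it. It compiles, but the tape ends up keeping itself alive.

```rust
use std::cell::RefCell;
use std::rc::Rc;

type Tape = Rc<RefCell<Vec<Node>>>;

// Hypothetical node that keeps a shared handle to the tape it is stored on.
#[allow(dead_code)]
struct Node {
    value: f64,
    tape: Tape,
}

fn main() {
    let tape: Tape = Rc::new(RefCell::new(Vec::new()));

    // The node holds a clone of the Rc and is then pushed into that same tape.
    let node = Node { value: 5.0, tape: Rc::clone(&tape) };
    tape.borrow_mut().push(node);

    // Two strong references: our local `tape` and the one inside the stored node.
    println!("strong count = {}", Rc::strong_count(&tape)); // 2

    // After dropping our handle, the tape is still kept alive by its own element,
    // so it is never freed: a reference cycle, i.e. a memory leak.
    let weak = Rc::downgrade(&tape);
    drop(tape);
    println!("still alive: {}", weak.upgrade().is_some()); // true
}
```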

Bottom line: don't do that; instead, make your users manipulate a GradientTape directly.
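A minimal sketch of that design, under heavy simplifying assumptions: values are plain `f64` scalars rather than `TensorType`, local derivatives are stored at record time instead of pullbacks, and every name here (`GradientTape`, `Var`, `scalar`, `mul`, `gradients`) is hypothetical. The tape owns all nodes and users hold small `Copy` handles that are just indices, so no node needs a pointer back to the tape, and the three-matrix workflow from the question leaves the tape with length 3.

```rust
// Hypothetical tape entry: a scalar value plus (parent index, local derivative) pairs.
#[derive(Debug, Clone)]
struct Node {
    value: f64,
    parents: Vec<(usize, f64)>,
}

// The tape owns every node; nothing points back at it.
#[derive(Debug, Default)]
pub struct GradientTape {
    nodes: Vec<Node>,
}

// A lightweight handle: just this value's index on the tape that created it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Var(usize);

impl GradientTape {
    pub fn new() -> Self {
        Self::default()
    }

    fn push(&mut self, value: f64, parents: Vec<(usize, f64)>) -> Var {
        self.nodes.push(Node { value, parents });
        Var(self.nodes.len() - 1)
    }

    // Source node: no parents.
    pub fn scalar(&mut self, v: f64) -> Var {
        self.push(v, Vec::new())
    }

    pub fn value(&self, x: Var) -> f64 {
        self.nodes[x.0].value
    }

    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    pub fn add(&mut self, a: Var, b: Var) -> Var {
        let v = self.value(a) + self.value(b);
        self.push(v, vec![(a.0, 1.0), (b.0, 1.0)])
    }

    pub fn mul(&mut self, a: Var, b: Var) -> Var {
        let (va, vb) = (self.value(a), self.value(b));
        // d(a*b)/da = b and d(a*b)/db = a
        self.push(va * vb, vec![(a.0, vb), (b.0, va)])
    }

    // Reverse pass: d(out)/d(node) for every node on the tape.
    pub fn gradients(&self, out: Var) -> Vec<f64> {
        let mut grads = vec![0.0; self.nodes.len()];
        grads[out.0] = 1.0;
        // Parents are always pushed before their children, so walking the tape
        // backwards finishes each node's gradient before passing it on.
        for i in (0..=out.0).rev() {
            let g = grads[i];
            for &(p, local) in &self.nodes[i].parents {
                grads[p] += g * local;
            }
        }
        grads
    }
}

fn main() {
    let mut tape = GradientTape::new();
    let m1 = tape.scalar(5.0);
    let m2 = tape.scalar(3.0);
    let m3 = tape.mul(m2, m1);
    let grads = tape.gradients(m3);
    println!("tape length = {}, m3 = {}", tape.len(), tape.value(m3)); // 3, 15
    println!("dm3/dm1 = {}, dm3/dm2 = {}", grads[m1.0], grads[m2.0]); // 3, 5
}
```

Because operations go through `&mut GradientTape`, the borrow checker also prevents two parts of the program from appending to the same tape at once, without any RefCell.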

Also, instead of [Option<usize>; 2], I would personally use an enum Parents { SourceNode, Unary(usize), Binary(usize, usize) }.
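A sketch of that enum, with `Binary` holding two indices since a binary op has two parents; the `indices` helper is a hypothetical addition so that a backward pass can iterate over parents uniformly.

```rust
// Each variant records how many parents a node has; the indices point into the tape.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Parents {
    SourceNode,
    Unary(usize),
    Binary(usize, usize),
}

impl Parents {
    // Collects the parent indices so a backward pass can loop over them uniformly.
    pub fn indices(&self) -> Vec<usize> {
        match *self {
            Parents::SourceNode => Vec::new(),
            Parents::Unary(p) => vec![p],
            Parents::Binary(a, b) => vec![a, b],
        }
    }
}

fn main() {
    let nodes = [Parents::SourceNode, Parents::Unary(0), Parents::Binary(0, 1)];
    for (i, parents) in nodes.iter().enumerate() {
        println!("node {i}: parents {:?}", parents.indices());
    }
}
```

Compared with `[Option<usize>; 2]`, the enum cannot represent meaningless states such as `[None, Some(i)]`, and `match` forces every case to be handled.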

RefCell

> RefCell should be used when you need "interior mutability", i.e., mutating shared values (see the std::cell documentation). This isn't directly relevant here.

Oh, I guess I misread the docs, then. There's this example in the Rc<T> guide

and I assumed that, because I might need to change the shared data, Rc<T> was necessary.
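For reference, a hedged sketch of the distinction (not the Rc<T> guide's own example): `Rc<T>` on its own gives several owners read-only access, and it is `Rc<RefCell<T>>` that adds mutation of the shared data.

```rust
use std::cell::RefCell;
use std::rc::Rc;

fn main() {
    // Rc<T> alone: several owners, but only shared (read-only) access to the data.
    let shared = Rc::new(vec![1.0, 2.0]);
    let other = Rc::clone(&shared);
    // other.push(3.0); // does not compile: an Rc only hands out shared references
    println!("owners: {}, len: {}", Rc::strong_count(&shared), other.len()); // owners: 2, len: 2

    // Rc<RefCell<T>>: several owners and mutation, checked at runtime.
    let tape = Rc::new(RefCell::new(Vec::new()));
    let handle = Rc::clone(&tape);
    handle.borrow_mut().push(3.0);
    println!("seen through the other owner: {:?}", tape.borrow()); // [3.0]
}
```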

Global Chain

> My first remark is: are you 1000% sure that you want all the created matrices to be part of one same chain?

There are parts we can optimize out: for example, computations that produce many intermediate matrices but have known, easy derivatives can just be done as "raw" computations, with the appropriate pullbacks applied afterwards.
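As a hedged illustration of that optimization (sigmoid is my choice of example, not something from the thread): computing σ(x) = 1 / (1 + e^(-x)) step by step would record a negation, an exponential, an addition and a reciprocal, but the whole thing has the well-known derivative σ(x)(1 - σ(x)). A fused op can therefore compute the forward value raw and record a single pullback. The function names are hypothetical.

```rust
// Forward pass computed "raw": no intermediate nodes would be recorded on the tape.
fn sigmoid_forward(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

// Single pullback for the fused op, using d(sigma)/dx = sigma(x) * (1 - sigma(x)).
// It only needs the forward output, so nothing else has to be stored.
fn sigmoid_pullback(output: f64, upstream_grad: f64) -> f64 {
    upstream_grad * output * (1.0 - output)
}

fn main() {
    let x = 0.7;
    let y = sigmoid_forward(x);
    let dx = sigmoid_pullback(y, 1.0);

    // Compare with a central finite difference of the raw forward pass.
    let h = 1e-6;
    let numeric = (sigmoid_forward(x + h) - sigmoid_forward(x - h)) / (2.0 * h);
    println!("sigmoid({x}) = {y:.6}, analytic d/dx = {dx:.6}, numeric d/dx = {numeric:.6}");
}
```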

Users manually manipulate the chain

> Consider creating a pub struct GradientTape { pub contents: Vec<Matrix> } with methods to create matrices linked to your "gradient tape".
> ... make your users manipulate a GradientTape directly.

  • Part of the issue is that we, the library authors, might want to manipulate the chain too, so that we can provide higher-level abstractions. For example, categorical cross-entropy is commonly used in ML, and making users define it themselves is a little clunky, since it just "makes sense" for the library to provide it. Its derivative isn't especially hard to find, so it isn't a great example of the possible complexity, but it does show how a loss is composed of many smaller calculations with well-known derivatives that we could just add to the chain, letting the system differentiate automatically.
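To make that concrete, a worked derivation, assuming cross-entropy is applied to softmax outputs p and that the targets y sum to 1 (e.g. one-hot): each step is a primitive with a known local derivative, and chaining them gives the compact gradient a tape would compute automatically.

```latex
L(\mathbf{z}) = -\sum_i y_i \log p_i, \qquad p_i = \frac{e^{z_i}}{\sum_j e^{z_j}}

\frac{\partial L}{\partial p_i} = -\frac{y_i}{p_i}, \qquad
\frac{\partial p_i}{\partial z_k} = p_i\,(\delta_{ik} - p_k)

\frac{\partial L}{\partial z_k}
  = \sum_i \left(-\frac{y_i}{p_i}\right) p_i\,(\delta_{ik} - p_k)
  = -y_k + p_k \sum_i y_i
  = p_k - y_k \quad \text{when } \sum_i y_i = 1
```

The final form p_k - y_k is also a natural candidate for the "raw computation plus single pullback" shortcut mentioned earlier in the thread.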

> Also, instead of [Option<usize>; 2], I would personally use an enum Parents { SourceNode, Unary(usize), Binary(usize, usize) }.

Noted! Thank you


The easy solution here is to define the cross-entropy function in an impl GradientTape block instead of on Vec<Matrix> (an inherent impl Vec<Matrix> isn't allowed anyway, since Vec is defined outside your crate).
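A sketch of that idea, using a hypothetical index-based tape (scalar `f64` values, `Var` handles, local derivatives stored per node): `categorical_cross_entropy` is an ordinary method in `impl GradientTape`, built only from primitives (`ln`, `mul_const`, `add`), so its gradient with respect to each probability, -y_i / p_i, falls out of the backward pass without a hand-written pullback. Treating the targets as constants rather than tape nodes is an assumption.

```rust
// Hypothetical handle: an index into the tape.
#[derive(Debug, Clone, Copy)]
pub struct Var(usize);

// Index-based tape: values plus (parent index, local derivative) pairs per node.
#[derive(Debug, Default)]
pub struct GradientTape {
    values: Vec<f64>,
    parents: Vec<Vec<(usize, f64)>>,
}

impl GradientTape {
    fn push(&mut self, v: f64, parents: Vec<(usize, f64)>) -> Var {
        self.values.push(v);
        self.parents.push(parents);
        Var(self.values.len() - 1)
    }

    pub fn scalar(&mut self, v: f64) -> Var {
        self.push(v, Vec::new())
    }

    pub fn value(&self, x: Var) -> f64 {
        self.values[x.0]
    }

    // Primitives, each recording its local derivative.
    pub fn ln(&mut self, x: Var) -> Var {
        let v = self.value(x);
        self.push(v.ln(), vec![(x.0, 1.0 / v)])
    }

    pub fn mul_const(&mut self, c: f64, x: Var) -> Var {
        let v = self.value(x);
        self.push(c * v, vec![(x.0, c)])
    }

    pub fn add(&mut self, a: Var, b: Var) -> Var {
        let v = self.value(a) + self.value(b);
        self.push(v, vec![(a.0, 1.0), (b.0, 1.0)])
    }

    // A loss defined on the tape: -sum_i y_i * ln(p_i), built only from the
    // primitives above, so no hand-written pullback is needed.
    pub fn categorical_cross_entropy(&mut self, probs: &[Var], targets: &[f64]) -> Var {
        let mut total = self.scalar(0.0);
        for (&p, &y) in probs.iter().zip(targets) {
            let log_p = self.ln(p);
            let term = self.mul_const(-y, log_p);
            total = self.add(total, term);
        }
        total
    }

    // Reverse pass: walk the tape backwards, accumulating into parents.
    pub fn gradients(&self, out: Var) -> Vec<f64> {
        let mut grads = vec![0.0; self.values.len()];
        grads[out.0] = 1.0;
        for i in (0..=out.0).rev() {
            let g = grads[i];
            for &(p, local) in &self.parents[i] {
                grads[p] += g * local;
            }
        }
        grads
    }
}

fn main() {
    let mut tape = GradientTape::default();
    let probs: Vec<Var> = [0.7, 0.2, 0.1].iter().map(|&p| tape.scalar(p)).collect();
    let loss = tape.categorical_cross_entropy(&probs, &[1.0, 0.0, 0.0]);
    let grads = tape.gradients(loss);
    println!("loss = {:.4}", tape.value(loss)); // 0.3567, i.e. -ln(0.7)
    println!("dL/dp0 = {:.4}", grads[probs[0].0]); // -1.4286, i.e. -1/0.7
}
```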

Ahh, yeah, I suppose that this would be the most straightforward thing. Thank you for your insight!

One quick follow-up, actually. If I were to place all the operations in the scope of GradientTape, the file would grow extremely large. Is there a way for me to separate out the operations, so that I can have something like

losses
primitives
weight_initialization
...

where primitives are operations like +, -, *, and /, and losses and the rest build on top of the primitives?

Thank you!

Sure, there is no problem with having several impl GradientTape { ... } in different files with different functions. So you can put impl Add for GradientTape { ... } and the like in primitives.rs, impl GradientTape { ... } with some methods in losses.rs, and impl GradientTape { ... } with different methods in weight_initialization.rs.
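A single-file sketch of that layout: inline `mod` blocks stand in for `primitives.rs` and `losses.rs`, and the tape only stores values (gradient recording is omitted for brevity). All names are hypothetical. Inherent `impl` blocks can live in any module of the crate that defines the type; their methods need `pub` (or `pub(crate)`) to be callable from other modules.

```rust
// The type is defined once; its methods are spread across modules below.
#[derive(Debug, Default)]
pub struct GradientTape {
    pub values: Vec<f64>,
}

// Stands in for src/primitives.rs.
mod primitives {
    use super::GradientTape;

    impl GradientTape {
        // `pub` so the method is callable outside this module.
        pub fn scalar(&mut self, v: f64) -> usize {
            self.values.push(v);
            self.values.len() - 1
        }

        pub fn add(&mut self, a: usize, b: usize) -> usize {
            let v = self.values[a] + self.values[b];
            self.scalar(v)
        }
    }
}

// Stands in for src/losses.rs.
mod losses {
    use super::GradientTape;

    impl GradientTape {
        // Placeholder "loss" built only from primitives defined in another module.
        pub fn sum_all(&mut self, xs: &[usize]) -> usize {
            let mut acc = self.scalar(0.0);
            for &x in xs {
                acc = self.add(acc, x);
            }
            acc
        }
    }
}

fn main() {
    let mut tape = GradientTape::default();
    let a = tape.scalar(1.0);
    let b = tape.scalar(2.0);
    let total = tape.sum_all(&[a, b]);
    println!("total = {}", tape.values[total]); // 3
}
```

In a real crate, each inline module becomes `mod primitives;` in `lib.rs` plus a `src/primitives.rs` file that starts with `use crate::GradientTape;` and contains the same `impl` block.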
