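Say I'm building a library and have the following:

```rust
// lib.rs
use std::sync::Mutex;
// `GradientTape::new` has to be a `const fn` to initialize a static.
static TAPE: Mutex<GradientTape> = Mutex::new(GradientTape::new());
```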
I know a static variable in a library is a workaround; for the moment, I'm just trying to get something out the door.
```rust
// gradient_tape.rs
pub struct GradientTape {
    pub nodes: Vec<Matrix>,
}
```
where the gradient tape stores every matrix created so far:
```rust
// matrix.rs
#[derive(Debug, Clone)]
pub struct Matrix {
    ...
    pub pullbacks: [Option<Pullback>; 2],
}
```
The reason I'm using the static `TAPE` is that every time I create a new matrix from existing ones (e.g. via a dot product, subtraction, etc.), I want it added to the running tape. Without the static, the code gets very clunky (see the bottom of the post for more on this).
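To make the intended behavior concrete, here is a minimal sketch of what I mean. It uses a hypothetical, simplified `Matrix` (a flat `Vec<f64>` plus its tape index, with no pullbacks) and a hypothetical `record` helper; neither is my real code. Every new matrix, including the result of `+`, gets pushed onto the global `TAPE`.

```rust
use std::ops::Add;
use std::sync::Mutex;

// Hypothetical, simplified stand-in for the real `Matrix`:
// a flat buffer plus the index it was assigned on the tape.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    pub data: Vec<f64>,
    pub index: usize,
}

pub struct GradientTape {
    pub nodes: Vec<Matrix>,
}

impl GradientTape {
    // Must be a `const fn` so it can initialize a `static`.
    pub const fn new() -> Self {
        GradientTape { nodes: Vec::new() }
    }
}

// Global tape; `Mutex::new` is usable in a `static` since Rust 1.63.
pub static TAPE: Mutex<GradientTape> = Mutex::new(GradientTape::new());

// Hypothetical helper: assigns the next tape index, stores a copy
// on the tape, and hands the new matrix back to the caller.
pub fn record(data: Vec<f64>) -> Matrix {
    let mut tape = TAPE.lock().unwrap();
    let res = Matrix { data, index: tape.nodes.len() };
    tape.nodes.push(res.clone());
    res
}

impl Add for Matrix {
    type Output = Matrix;

    // Element-wise addition (no broadcasting in this sketch); the result
    // is recorded on the global tape like any other operation.
    fn add(self, rhs: Matrix) -> Matrix {
        assert_eq!(self.data.len(), rhs.data.len(), "shape mismatch");
        let data: Vec<f64> = self.data.iter().zip(&rhs.data).map(|(a, b)| a + b).collect();
        record(data)
    }
}
```

Since `Add::add` only receives `self` and `rhs`, there is no way to pass the tape in, which is why the sketch reaches for the static.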
Anyway, on to the issue: my pullbacks have the following definitions so that they can be cloned:
```rust
pub(crate) type _PullbackNoArg = Box<dyn ClonableFnNoArg<Output=Matrix>>;

pub(crate) enum Pullback {
    PullbackSingleArg(_Pullback),
    PullbackNoArg(_PullbackNoArg),
}

pub(crate) trait ClonableFnNoArg: FnOnce() -> Matrix {
    fn clone_box(&self) -> Box<dyn ClonableFnNoArg<Output=Matrix>>;
}

impl Clone for Box<dyn ClonableFnNoArg<Output=Matrix>> {
    fn clone(&self) -> Self {
        <dyn ClonableFnNoArg<Output=Matrix>>::clone_box(&**self)
    }
}

impl<T> ClonableFnNoArg for T
where
    T: 'static + FnOnce() -> Matrix + Clone,
{
    fn clone_box(&self) -> Box<dyn ClonableFnNoArg<Output=Matrix>> {
        Box::new(T::clone(self))
    }
}
```
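For comparison, here is a self-contained version of the same clone-box pattern with one change I've been considering: `Send` added as a supertrait of `ClonableFnNoArg` and to the blanket impl's bounds. `Matrix` is a hypothetical stand-in, and this `GradientTape` only holds pullbacks; it's a sketch under those assumptions, not something I've checked against the full library.

```rust
use std::sync::Mutex;

// Hypothetical stand-in for the real `Matrix`.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix(pub Vec<f64>);

// `Send` as a supertrait: every `dyn ClonableFnNoArg` is then `Send`,
// and so is every `Box` holding one.
pub trait ClonableFnNoArg: FnOnce() -> Matrix + Send {
    fn clone_box(&self) -> Box<dyn ClonableFnNoArg<Output = Matrix>>;
}

pub type PullbackNoArg = Box<dyn ClonableFnNoArg<Output = Matrix>>;

// Blanket impl: any cloneable, sendable closure returning `Matrix` qualifies.
impl<T> ClonableFnNoArg for T
where
    T: 'static + FnOnce() -> Matrix + Clone + Send,
{
    fn clone_box(&self) -> PullbackNoArg {
        Box::new(self.clone())
    }
}

impl Clone for PullbackNoArg {
    fn clone(&self) -> Self {
        // Call through the trait object explicitly so this doesn't
        // recurse into `Clone for Box`.
        <dyn ClonableFnNoArg<Output = Matrix>>::clone_box(&**self)
    }
}

// Hypothetical tape that only stores pullbacks. The `static` compiles
// because `Mutex<T>` is `Sync` whenever `T: Send`.
pub struct GradientTape {
    pub pullbacks: Vec<PullbackNoArg>,
}

pub static TAPE: Mutex<GradientTape> = Mutex::new(GradientTape { pullbacks: Vec::new() });
```

If this is the right direction, the trade-off is that every closure stored as a pullback may only capture `Send` data (no `Rc`, for example).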
but I'm getting the following error:

```
(dyn ClonableFnNoArg<Output = Matrix> + 'static) cannot be sent between threads safely [E0277]
Help: the trait Send is not implemented for (dyn ClonableFnNoArg<Output = Matrix> + 'static)
```
so I tried implementing `Send` myself:

```rust
unsafe impl Send for Box<dyn ClonableFnNoArg<Output = Matrix> + 'static> {}
```
but then I get:

```
cross-crate traits with a default impl, like Send, can only be implemented for a struct/enum type defined in the current crate [E0321]
can't implement cross-crate trait for type in another crate
```

I don't really understand this error. Am I actually implementing a cross-crate trait here?
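To narrow the first error down, here is a minimal, self-contained check using plain closures instead of `ClonableFnNoArg`. It relies only on the standard library; `assert_send`, `run_on_thread`, and `not_send_example` are hypothetical helpers written for this check. A boxed closure whose object type names `+ Send` can be moved to another thread, while the same box without `+ Send` is not `Send`, even though the closure inside it is.

```rust
use std::thread;

// Compiles only for types that are `Send`; used as a compile-time check.
pub fn assert_send<T: Send>(_: &T) {}

// Accepts only boxed closures whose object type names `Send` explicitly.
pub fn run_on_thread(f: Box<dyn FnOnce() -> i32 + Send>) -> i32 {
    // `thread::spawn` requires the spawned closure, and therefore the
    // captured `f`, to be `Send`.
    thread::spawn(move || f()).join().unwrap()
}

// Without `+ Send` in the object type, the box is not `Send`.
// Uncommenting the `assert_send` call reproduces E0277.
pub fn not_send_example() {
    let _g: Box<dyn FnOnce() -> i32> = Box::new(|| 42);
    // assert_send(&_g);
}
```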
Clunky code
I know it's possible to define some of the operations as methods on `GradientTape`, so that their results still get added to the tape, but that doesn't work in every case.
One case where it does work:
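```rust
// matrix.rs
impl Matrix {
    ...
    pub fn dot(lhs: Matrix, rhs: Matrix, index: usize) -> Matrix {...}
    ...
}

// gradient_tape.rs
impl GradientTape {
    ...
    // Takes `&mut self` because it pushes onto `self.nodes`.
    pub fn dot(&mut self, lhs: Matrix, rhs: Matrix) -> Matrix {
        let res = Matrix::dot(lhs, rhs, self.nodes.len());
        // Store a clone so `res` can still be returned to the caller.
        self.nodes.push(res.clone());
        res
    }
    ...
}
```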
But there are cases where I don't think it works, e.g. my `Add` implementation:
```rust
use std::ops;

impl ops::Add<Self> for Matrix {
    type Output = Self;

    // `add` only receives the two operands, so there is no tape to push onto.
    fn add(self, rhs: Self) -> Self::Output {
        assert!(Self::is_broadcastable(&self, &rhs), "Could not broadcast addition between: {self:?} and {rhs:?}");
        Self::_broadcast_op(&self, &rhs, Op::Add)
    }
}
```
Or am I missing something, and there is a way to do this through `GradientTape` as well?