I have my own Matrix implementation, and its .dot() operation allocates an enormous amount of memory:
use std::fmt::Debug;
use std::io::{Error, ErrorKind};
use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign};
use rand::{Rng, thread_rng};
use rand::distributions::uniform::SampleUniform;
trait Numeric:
    Div<Output = Self>
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Default
    + Debug
    + PartialOrd
    + SampleUniform
    + Clone
    + Copy
    + Neg<Output = Self>
    + AddAssign
    + SubAssign {}

impl<T> Numeric for T where T:
    Div<Output = Self>
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Default
    + Debug
    + PartialOrd
    + SampleUniform
    + Clone
    + Copy
    + Neg<Output = Self>
    + AddAssign
    + SubAssign {}
// Implements `scalar <op> Matrix` for all primitive numeric types,
// so expressions like `1.0 - m` work.
macro_rules! impl_matrix_ops_for_scalar {
    ($($T:ty),*) => {
        $(
            impl Add<Matrix<$T>> for $T {
                type Output = Matrix<$T>;

                fn add(self, mut matrix: Matrix<$T>) -> Self::Output {
                    matrix.assign(matrix.data.iter().map(|&item| item + self).collect()).unwrap();
                    matrix
                }
            }

            impl Sub<Matrix<$T>> for $T {
                type Output = Matrix<$T>;

                fn sub(self, mut matrix: Matrix<$T>) -> Self::Output {
                    matrix.assign(matrix.data.iter().map(|&item| self - item).collect()).unwrap();
                    matrix
                }
            }

            impl<'a> Sub<&'a mut Matrix<$T>> for $T {
                type Output = Matrix<$T>;

                // Mutates the borrowed matrix in place and returns a clone of it.
                fn sub(self, matrix: &'a mut Matrix<$T>) -> Self::Output {
                    matrix.assign(matrix.data.iter().map(|&item| self - item).collect()).unwrap();
                    matrix.clone()
                }
            }

            impl Mul<Matrix<$T>> for $T {
                type Output = Matrix<$T>;

                fn mul(self, mut matrix: Matrix<$T>) -> Self::Output {
                    matrix.assign(matrix.data.iter().map(|&item| item * self).collect()).unwrap();
                    matrix
                }
            }

            impl Div<Matrix<$T>> for $T {
                type Output = Matrix<$T>;

                fn div(self, mut matrix: Matrix<$T>) -> Self::Output {
                    matrix.assign(matrix.data.iter().map(|&item| self / item).collect()).unwrap();
                    matrix
                }
            }
        )*
    };
}

impl_matrix_ops_for_scalar!(f64, f32, i8, i16, i32, i64, i128, u8, u16, u32, u64, u128, usize);
#[derive(Debug, Clone, Default)]
pub struct Shape {
    pub rows: usize,
    pub cols: usize,
}

#[derive(Debug, Clone, Default)]
pub struct Matrix<N> {
    pub shape: Shape,
    data: Vec<N>,
}
impl<N> Matrix<N> where N: Default + Copy + Clone + Debug {
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            shape: Shape { rows, cols },
            data: (0..rows * cols).map(|_| N::default()).collect(),
        }
    }

    // Sets the shape up front but leaves `data` empty; callers are
    // expected to fill it (see `dot`).
    pub fn new_empty(rows: usize, cols: usize) -> Self {
        Self {
            shape: Shape { rows, cols },
            data: vec![],
        }
    }

    pub fn assign(&mut self, data: Vec<N>) -> Result<(), Error> {
        if data.len() != self.shape.rows * self.shape.cols {
            return Err(Error::new(ErrorKind::Other, "Dataset size is different from matrix size"));
        }
        self.data = data;
        Ok(())
    }

    pub fn get_item(&mut self, row: usize, col: usize) -> Option<&mut N> {
        self.data.get_mut(row * self.shape.cols + col)
    }

    pub fn get_row(&mut self, row: usize) -> &mut [N] {
        &mut self.data[row * self.shape.cols..row * self.shape.cols + self.shape.cols]
    }

    pub fn transpose(&mut self) -> Self {
        let mut output = Matrix::new(self.shape.cols, self.shape.rows);
        for i in 0..self.shape.rows {
            for j in 0..self.shape.cols {
                *output.get_item(j, i).unwrap() = self.get_item(i, j).unwrap().clone();
            }
        }
        output
    }
}
impl<N> Matrix<N> where N: Numeric {
    pub fn randomize(mut self, min: N, max: N) -> Self {
        let mut rng = thread_rng();
        self.data = (0..self.shape.rows * self.shape.cols)
            .map(|_| rng.gen_range(min..max))
            .collect();
        self
    }

    pub fn dot(&mut self, other: &mut Self) -> Self {
        assert_eq!(
            self.shape.cols, other.shape.rows,
            "Cannot multiply matrices A ({:?}) and B ({:?}): \
             the column count of the first matrix must equal the row count of the second",
            self.shape, other.shape
        );
        // The output starts with an empty `data` vector and is filled
        // one element at a time.
        let mut output = Matrix::new_empty(self.shape.rows, other.shape.cols);
        for i in 0..self.shape.rows {
            for j in 0..output.shape.cols {
                let mut item = N::default();
                for k in 0..other.shape.rows {
                    item += self.get_item(i, k).unwrap().clone() * other.get_item(k, j).unwrap().clone();
                }
                output.data.push(item);
            }
        }
        output
    }
}
I implemented a minimal example to reproduce it.
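The listing above is self-contained; the driver is only a few lines on top of it. As a sketch (the dimensions below are illustrative placeholders, not the exact ones from my reproduction):

fn main() {
    // Illustrative sizes; the real ones come from my training data.
    let mut a = Matrix::<f64>::new(512, 784).randomize(-1.0, 1.0);
    let mut b = Matrix::<f64>::new(784, 256).randomize(-1.0, 1.0);
    // `dot` allocates a fresh output matrix and fills it element by element.
    let c = a.dot(&mut b);
    println!("{:?}", c.shape);
}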
The error I got is:
memory allocation of 42949672960 bytes failed
To give more context: I use this Matrix type in my own neural network implementation.
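The hot path is the forward pass of a dense layer, roughly along these lines (a sketch only; `Layer`, `weights`, and `forward` are illustrative names rather than my exact code):

struct Layer {
    // Weight matrix with shape (neurons x inputs).
    weights: Matrix<f64>,
}

impl Layer {
    // Every forward pass calls `dot`, which allocates a brand-new output matrix.
    fn forward(&mut self, input: &mut Matrix<f64>) -> Matrix<f64> {
        self.weights.dot(input)
    }
}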
Could somebody help me figure out what I did wrong?
P.S. I know I could use ndarray and the like, but I am interested in how this works under the hood (and also in how to fix memory leak issues).