Here's a version with the patch indexing fixed:
#![feature(test)]
extern crate netcdf;
#[macro_use(s)]
extern crate ndarray;
extern crate rayon;
extern crate test;
use netcdf::variable::Variable;
use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD, IxDyn, Axis};
use rayon::prelude::*;
use test::Bencher;
/// Fills every row (axis 0) of `split` with one block read from `var`.
///
/// Row `loc_i` of `split` corresponds to the global sample index
/// `glob_i = loc_i + offset`. For each row, a `[9, 1, 64, 64]` hyperslab is read
/// from `var` starting at `[glob_i, 0, patch[0], patch[1]]`, its length-1 second
/// axis is removed, and the result is assigned to that row.
///
/// * `split`   - mutable view whose rows are overwritten.
/// * `var`     - NetCDF variable to read from.
/// * `offset`  - global index of the first row of `split`.
/// * `patches` - table of `[start, start]` pairs for the last two axes; rows
///   cycle through it by `glob_i % patches.len()`.
///
/// # Panics
///
/// Panics if `patches` is empty while `split` has at least one row, or if the
/// read from `var` fails.
fn spawn(mut split: ArrayViewMutD<f32>, var: &Variable, offset: usize, patches: &[[usize; 2]]) {
    for loc_i in 0..split.len_of(Axis(0)) {
        let glob_i = loc_i + offset;
        // Cycle by the table's real length instead of a hard-coded 30, so a
        // shorter table cannot index out of bounds and a longer one is fully used.
        let patch = patches[glob_i % patches.len()];
        let start: [usize; 4] = [glob_i, 0, patch[0], patch[1]];
        let count: [usize; 4] = [9, 1, 64, 64];
        split.slice_mut(s![loc_i, .., .., ..]).assign(
            &var.array_at(&start, &count)
                .expect("failed to read patch from NetCDF variable")
                .remove_axis(Axis(1)),
        );
    }
}
fn get_batch_pool(var: &Variable, offset: usize, batch_size: usize, patches: &[[usize;2]]) -> ndarray::ArrayD<f32> {
let mut batch = ArrayD::<f32>::zeros(IxDyn(&[batch_size, 9, 64, 64]));
{
// We will parallelize the inner loop using a TBB-like splittable task
struct Splittable<'a> {
view: ArrayViewMutD<'a, f32>,
offset: usize,
}
let split_data = Splittable { view: batch.view_mut(), offset };
// Here is how we split one of our tasks
fn split_fun(s: Splittable) -> (Splittable, Option<Splittable>) {
// We split data along the first axis, if it has >= 2 items
let size = s.view.len_of(Axis(0));
if size < 2 { return (s, None); }
let half_size = size / 2;
let (view1, view2) = s.view.split_at(Axis(0), half_size);
// The offset is preserved across splits for NetCDF interaction
let s1 = Splittable { view: view1, offset: s.offset };
let s2 = Splittable { view: view2, offset: s.offset + half_size };
(s1, Some(s2))
};
// Run a parallel loop using this strategy
rayon::iter::split(split_data, split_fun).for_each(|s| {
spawn(s.view, var, s.offset, patches);
});
}
batch
}
/// Opens `original_pre_train_0.nc`, prints the name and length of the first
/// four dimensions of the `thetao` variable, then loads 300 batches of 100
/// rows each and prints the shape of the first batch.
///
/// # Panics
///
/// Panics if the file cannot be opened, if it has no `thetao` variable, if
/// that variable has fewer than four dimensions, or if any read fails.
fn main() {
    let batch_size = 100;
    // Number of batches to load and the distance between consecutive batch offsets.
    // NOTE(review): the stride (5) is smaller than the batch size (100), so
    // consecutive batches cover overlapping sample indices — confirm this is intended.
    let batch_count = 300;
    let batch_stride = 5;
    // Start indices along the last two axes, one pair per patch.
    let patches = [[60, 144], [60, 204], [60, 264], [60, 324], [60, 384],
                   [60, 444], [60, 504], [60, 564], [60, 624], [120, 144],
                   [120, 204], [120, 264], [120, 324], [120, 384], [120, 444],
                   [120, 624], [180, 204], [180, 264], [180, 324], [180, 384],
                   [180, 444], [180, 504], [180, 564], [180, 624], [240, 324],
                   [240, 384], [240, 444], [240, 504], [240, 564], [240, 624]];

    // Fail with a message that names what was missing rather than a bare unwrap panic.
    let file = netcdf::open("original_pre_train_0.nc")
        .expect("failed to open original_pre_train_0.nc");
    let temp = file.root.variables.get("thetao")
        .expect("variable `thetao` not found in original_pre_train_0.nc");

    for axis in 0..4 {
        let dim = &temp.dimensions[axis];
        println!("{} - {}", dim.name, dim.len);
    }

    // The number of batches is known up front, so reserve the space once.
    let mut batches = Vec::with_capacity(batch_count);
    for o in 0..batch_count {
        println!("{}", o);
        let batch = get_batch_pool(temp, o * batch_stride, batch_size, &patches[..]);
        batches.push(batch);
    }
    println!("{:?}", batches[0].shape());
}