As a private project for fun, I chose to implement a fully-connected neural network in Rust. To be precise, I'm implementing a multilayer perceptron, one of the most basic neural networks there is. My goal is to create, for a given network (and the features of my computer), the fastest possible feed-forward function. This is my first time using SIMD and I don't know a whole lot about it. I think I've created a fairly fast function, but I'd like to know if there's anything I can improve performance-wise.
Aside from raw speed, my secondary goal is to make the code less dependent on my specific computer's features without losing performance. I don't mind using third-party crates for that purpose.
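For goal two specifically, the direction I have in mind is runtime feature detection with std's is_x86_feature_detected! macro, so the binary could pick between the AVX path and a scalar fallback at startup. A standalone sketch (ordinary prelude assumed; the messages stand in for two hypothetical code paths I have not written yet):

fn report_simd_path() {
    // is_x86_feature_detected! is real std API; everything else here is a
    // placeholder for an AVX/scalar dispatch.
    if is_x86_feature_detected!("avx") {
        eprintln!("AVX available: the __m256 feed-forward is safe to call");
    } else {
        eprintln!("no AVX: a scalar fall-back would have to run instead");
    }
}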
My third goal is to write more ergonomic Rust code. For example, the usage of Cell is more or less a band-aid right now, because Rust doesn't have a slice::windows_mut function.
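What I would like to write instead is sketched below: a windows_mut substitute for pairs, built on split_at_mut, which would hand out the next layer as &mut and make the Cell unnecessary (the helper name is mine):

fn windows2_mut<T>(slice: &mut [T], mut f: impl FnMut(&T, &mut T)) {
    // Visit each overlapping pair (i - 1, i): a shared reference to the first
    // element and an exclusive one to the second, which windows() cannot give.
    for i in 1..slice.len() {
        let (head, tail) = slice.split_at_mut(i);
        f(&head[i - 1], &mut tail[0]);
    }
}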
Thanks in advance for any tips!
#![no_implicit_prelude]
use ::std::arch::x86_64::__m256;
use ::std::arch::x86_64::_mm256_add_ps;
use ::std::arch::x86_64::_mm256_mul_ps;
use ::std::arch::x86_64::_mm256_rcp_ps;
use ::std::arch::x86_64::_mm256_set1_ps;
use ::std::arch::x86_64::_mm256_setzero_ps;
use ::std::boxed::Box;
use ::std::cell::Cell;
use ::std::eprintln;
use ::std::iter::IntoIterator;
use ::std::iter::Iterator;
use ::std::time::Instant;
use ::std::vec;
fn main() {
    unsafe {
        // A tiny 2-2-1 network of eight-lane vertex vectors: fully connected,
        // so the first weight matrix holds 2 * 2 = 4 edge vectors and the
        // second 2 * 1 = 2.
        let mut edges: Box<[Box<[EdgeVector]>]> = vec![
            vec![
                EdgeVector {
                    weight: _mm256_set1_ps(1.0)
                };
                4
            ]
            .into_boxed_slice(),
            vec![
                EdgeVector {
                    weight: _mm256_set1_ps(1.0)
                };
                2
            ]
            .into_boxed_slice(),
        ]
        .into_boxed_slice();
        // The input layer starts at 1.0; the hidden and output layers start
        // zeroed so feed_forward can accumulate into them.
        let mut vertices: Box<[Box<[VertexVector]>]> = vec![
            vec![
                VertexVector {
                    value: Cell::new(_mm256_set1_ps(1.0))
                };
                2
            ]
            .into_boxed_slice(),
            vec![
                VertexVector {
                    value: Cell::new(_mm256_setzero_ps())
                };
                2
            ]
            .into_boxed_slice(),
            vec![
                VertexVector {
                    value: Cell::new(_mm256_setzero_ps())
                };
                1
            ]
            .into_boxed_slice(),
        ]
        .into_boxed_slice();
        let instant = Instant::now();
        feed_forward(&mut vertices, &mut edges);
        eprintln!("feed_forward: {} ns", instant.elapsed().as_nanos());
    }
}
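One thing I am not sure about with this micro-benchmark: nothing ever reads the network's output, so the optimizer could in principle discard the work being timed. If that turns out to happen, the timing lines could be pinned down with std::hint::black_box (stable since Rust 1.66), roughly:

        let instant = Instant::now();
        feed_forward(::std::hint::black_box(&mut vertices), &mut edges);
        ::std::hint::black_box(&vertices);
        eprintln!("feed_forward: {} ns", instant.elapsed().as_nanos());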
unsafe fn feed_forward(vertices: &mut [Box<[VertexVector]>], edges: &mut [Box<[EdgeVector]>]) {
    for (layers, weight_matrix) in vertices.windows(2).zip(edges.into_iter()) {
        let layer = layers.get_unchecked(0);
        let next_layer = layers.get_unchecked(1);
        let next_layer_len = next_layer.len();
        let mut start = 0;
        // Accumulate weighted inputs: each source vertex vector contributes to
        // every target vertex vector (targets are assumed to start at zero).
        for source_vertex_vec in layer.into_iter() {
            for (edge_vec, target_vertex_vec) in (&weight_matrix[start..start + next_layer_len])
                .into_iter()
                .zip(next_layer.into_iter())
            {
                target_vertex_vec.value.set(_mm256_add_ps(
                    target_vertex_vec.value.get(),
                    _mm256_mul_ps(source_vertex_vec.value.get(), edge_vec.weight),
                ));
            }
            start += next_layer_len;
        }
        // Apply the logistic activation to the freshly summed layer.
        for target_vertex_vec in next_layer.into_iter() {
            target_vertex_vec
                .value
                .set(_mm256_lgc_ps(target_vertex_vec.value.get()));
        }
    }
}
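A follow-up thought on the inner loop above: the separate multiply and add could fuse into a single instruction on CPUs with FMA support. A sketch of what the accumulation step would look like, assuming the fma target feature is available (the helper name is mine):

use ::std::arch::x86_64::_mm256_fmadd_ps;

// Sketch: target += source * weight in one fused instruction. Requires FMA
// support, e.g. guarded by is_x86_feature_detected!("fma") at startup.
unsafe fn accumulate(target: &Cell<__m256>, source: __m256, weight: __m256) {
    target.set(_mm256_fmadd_ps(source, weight, target.get()));
}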
#[derive(Clone)]
struct EdgeVector {
    weight: __m256,
}
#[derive(Clone)]
struct VertexVector {
    value: Cell<__m256>,
}
/// Compute the logistic value of packed single-precision (32-bit) floating-point elements in a, and store the results in dst.
///
/// <b>Operation</b>
/// ```
/// FOR j := 0 to 7
///     i := j*32
///     dst[i+31:i] := lgc(a[i+31:i])
/// ENDFOR
/// dst[MAX:256] := 0
/// ```
#[inline]
unsafe fn _mm256_lgc_ps(a: __m256) -> __m256 {
    // 1 / (e^(-a) + 1), with the division done by the fast (approximate)
    // reciprocal instruction.
    _mm256_rcp_ps(_mm256_add_ps(
        _mm256_exp_ps(_mm256_mul_ps(a, _mm256_set1_ps(-1.0))),
        _mm256_set1_ps(1.0),
    ))
}
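Worth noting for reviewers: _mm256_rcp_ps is only an approximate reciprocal (Intel documents a maximum relative error of 1.5 * 2^-12, roughly 12 bits). To measure what that approximation costs in accuracy, an exact-division variant could be kept alongside it (the name is mine):

use ::std::arch::x86_64::_mm256_div_ps;

// Sketch: the same logistic, but with a true division instead of the fast
// reciprocal; slower, but an exact baseline for accuracy comparisons.
unsafe fn _mm256_lgc_exact_ps(a: __m256) -> __m256 {
    _mm256_div_ps(
        _mm256_set1_ps(1.0),
        _mm256_add_ps(
            _mm256_exp_ps(_mm256_mul_ps(a, _mm256_set1_ps(-1.0))),
            _mm256_set1_ps(1.0),
        ),
    )
}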
/// Compute the exponential value of e raised to the power of packed single-precision (32-bit) floating-point elements in a, and store the results in dst.
///
/// <b>Operation</b>
/// ```
/// FOR j := 0 to 7
///     i := j*32
///     dst[i+31:i] := e^(a[i+31:i])
/// ENDFOR
/// dst[MAX:256] := 0
/// ```
///
/// <b>References</b>
/// [Intel® Intrinsics Guide](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_exp_ps&expand=2273)
#[inline]
unsafe fn _mm256_exp_ps(a: __m256) -> __m256 {
    // std has no AVX exp intrinsic, so compute f32::exp lane by lane through
    // the union below.
    let mut scalar = F32x8 { simd: a }.scalar;
    for float in &mut scalar {
        *float = float.exp();
    }
    F32x8 { scalar }.simd
}
#[repr(C)]
union F32x8 {
    simd: __m256,
    scalar: [f32; 8],
}
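Since _mm256_exp_ps just calls f32::exp lane by lane, the union round-trip can be sanity-checked with an exact comparison. A sketch of that as a test (assuming the ordinary prelude, i.e. placed outside this no_implicit_prelude file):

#[test]
fn exp_matches_scalar_exp() {
    unsafe {
        let input = [-2.0f32, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0];
        let output = F32x8 {
            simd: _mm256_exp_ps(F32x8 { scalar: input }.simd),
        }
        .scalar;
        for (x, y) in input.iter().zip(output.iter()) {
            // Bitwise equality holds because both sides compute f32::exp.
            assert_eq!(x.exp().to_bits(), y.to_bits());
        }
    }
}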