I've managed to improve the Rust entry for this benchmark:
Related to:
This code is faster and shorter, and it has no crate dependencies (no JSON library needed):
// Compile with: -C opt-level=3
#![feature(time2, hashmap_hasher)]
use std::collections::HashMap;
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::hash_state::DefaultState;
use std::default::Default;
use std::hash::BuildHasherDefault;
use std::hash::{Hash, Hasher};
use std::mem::transmute;
use std::ops::{Add, Sub};
/// A point in the plane, stored as `(x, y)` coordinates.
#[derive(PartialEq, Copy, Clone, Debug)]
struct Point(f64, f64);

/// Returns `x * x`.
fn sq(x: f64) -> f64 {
    x * x
}

impl Point {
    /// Euclidean length of the vector from the origin to this point.
    fn norm(&self) -> f64 {
        (sq(self.0) + sq(self.1)).sqrt()
    }
}

impl Add for Point {
    type Output = Point;
    /// Component-wise sum.
    fn add(self, other: Point) -> Point {
        Point(self.0 + other.0, self.1 + other.1)
    }
}

impl Sub for Point {
    type Output = Point;
    /// Component-wise difference.
    fn sub(self, other: Point) -> Point {
        Point(self.0 - other.0, self.1 - other.1)
    }
}

// `Eq` is only sound because coordinates are never NaN (NaN != NaN would
// break reflexivity). Callers must uphold that.
impl Eq for Point {}

impl Hash for Point {
    /// Hashes the IEEE-754 bit patterns of both coordinates.
    ///
    /// `PartialEq` treats `0.0` and `-0.0` as equal even though their bits
    /// differ, so `-0.0` is folded into `0.0` first. That keeps the
    /// `a == b => hash(a) == hash(b)` contract that `HashMap` relies on.
    fn hash<H: Hasher>(&self, state: &mut H) {
        fn canonical_bits(v: f64) -> u64 {
            // `-0.0 == 0.0`, so both zeros take this branch and become `+0.0`.
            let v = if v == 0.0 { 0.0 } else { v };
            // SAFETY: f64 and u64 have the same size, and every bit pattern
            // is a valid u64.
            unsafe { transmute(v) }
        }
        canonical_bits(self.0).hash(state);
        canonical_bits(self.1).hash(state);
    }
}
/// A `HashMap` that uses the FNV-1a hasher below instead of the default SipHash.
///
/// `BuildHasherDefault` is the stable replacement for the deprecated
/// `std::collections::hash_state::DefaultState`.
type FnvHashMap<K, V> = HashMap<K, V, BuildHasherDefault<FnvHasher>>;

/// 64-bit FNV-1a hasher.
///
/// It is unkeyed (it always starts from the same constant), so unlike SipHash
/// it gives no protection against deliberately colliding keys.
#[derive(Copy, Clone, Debug)]
pub struct FnvHasher(u64);

impl Default for FnvHasher {
    /// Starts from the 64-bit FNV offset basis.
    fn default() -> FnvHasher {
        FnvHasher(0xcbf29ce484222325)
    }
}

impl Hasher for FnvHasher {
    /// Mixes in each byte: XOR it into the state, then multiply by the
    /// 64-bit FNV prime (XOR-then-multiply is what makes this FNV-1a).
    fn write(&mut self, bytes: &[u8]) {
        let FnvHasher(mut hash) = *self;
        for byte in bytes {
            hash = hash ^ (*byte as u64);
            hash = hash.wrapping_mul(0x100000001b3);
        }
        *self = FnvHasher(hash);
    }

    /// Returns the current state as the hash value.
    fn finish(&self) -> u64 {
        self.0
    }
}
/// Euclidean distance between `v` and `w`.
fn dist(v: Point, w: Point) -> f64 {
    let offset = v - w;
    offset.norm()
}
/// Arithmetic mean (centroid) of `points`.
///
/// An empty slice divides by zero and yields NaN coordinates.
fn avg(points: &[Point]) -> Point {
    let count = points.len() as f64;
    // Accumulate left to right from the origin, the same order as a fold.
    let mut sum = Point(0.0, 0.0);
    for &p in points {
        sum = sum + p;
    }
    Point(sum.0 / count, sum.1 / count)
}
/// Returns the element of `ys` nearest to `x`.
///
/// On ties the earliest candidate wins (the comparison is strict `<`).
/// Panics if `ys` is empty, because it reads `ys[0]`.
fn closest(x: Point, ys: &[Point]) -> Point {
    let mut best = ys[0];
    let mut best_dist = dist(best, x);
    for &candidate in ys {
        let d = dist(candidate, x);
        if d < best_dist {
            best_dist = d;
            best = candidate;
        }
    }
    best
}
/// Assigns every point in `xs` to its nearest centroid and returns the groups.
///
/// There is one inner `Vec` per centroid that received at least one point,
/// in the map's iteration order.
fn clusters(xs: &[Point], centroids: &[Point]) -> Vec<Vec<Point>> {
    let mut by_centroid: FnvHashMap<Point, Vec<Point>> = Default::default();
    for x in xs.iter().cloned() {
        let nearest = closest(x, centroids);
        // The entry API does a single hash lookup per point.
        match by_centroid.entry(nearest) {
            Vacant(slot) => {
                slot.insert(vec![x]);
            }
            Occupied(slot) => slot.into_mut().push(x),
        }
    }
    by_centroid.into_iter().map(|(_, members)| members).collect()
}
/// Runs k-means on `points` and returns the final clustering.
///
/// The first `n` points (fewer if the slice is shorter) seed the centroids.
/// It then performs `iters` refinement passes, each replacing every centroid
/// with the mean of its current cluster.
fn run(points: &[Point], n: u32, iters: u32) -> Vec<Vec<Point>> {
    let mut centroids: Vec<Point> = points.iter().take(n as usize).cloned().collect();
    for _ in 0..iters {
        centroids = clusters(points, &centroids).iter().map(|g| avg(g)).collect();
    }
    clusters(points, &centroids)
}
//-----------------------------
/// Returns the mean wall-clock time, in seconds, of one `run(points, 10, 15)`.
///
/// The mean is taken over `times` repetitions. `times <= 0` divides by a
/// non-positive count, so the result is then not meaningful.
fn benchmark(points: &[Point], times: i32) -> f64 {
    use std::time::Instant;
    let start = Instant::now();
    for _ in 0..times {
        run(points, 10, 15);
    }
    // `Instant::elapsed` is stable. It replaces the unstable
    // `duration_from_earlier`, which needed `#![feature(time2)]`.
    let diff = start.elapsed();
    let t = diff.as_secs() as f64 + (diff.subsec_nanos() as f64 / 1_000_000_000.0);
    t / times as f64
}
/// Loads `points.json` and prints the average time of one k-means run.
///
/// The file is expected to hold a JSON array of `[x, y]` pairs. The average
/// is taken over 100 repetitions.
fn main() {
    use std::fs::File;
    use std::io::Read;
    let mut data = String::new();
    File::open("points.json")
        .expect("cannot open points.json")
        .read_to_string(&mut data)
        .expect("cannot read points.json");
    // Minimal parser for `[[x,y],[x,y],...]`. All whitespace is stripped
    // first, so pretty-printed files and a trailing newline still parse.
    let compact: String = data.chars().filter(|c| !c.is_whitespace()).collect();
    let points = compact
        .trim_matches(&['[', ']'][..])
        .split("],[")
        // An empty array `[]` leaves a single empty row; skip it.
        .filter(|r| !r.is_empty())
        .map(|r| {
            let mut coords = r
                .split(',')
                .map(|p| p.parse::<f64>().expect("invalid number in points.json"));
            Point(
                coords.next().expect("point is missing its x coordinate"),
                coords.next().expect("point is missing its y coordinate"),
            )
        })
        .collect::<Vec<Point>>();
    let iterations = 100;
    println!("The average time is {} ms", benchmark(&points, iterations) * 1000.0);
}
But it gives me two warnings (using a recent nightly rustc):
...\kmeans_rs3.rs:7:5: 7:47 warning: use of deprecated item: support moved to std::hash, #[warn(deprecated)] on by default
...\kmeans_rs3.rs:7 use std::collections::hash_state::DefaultState;
^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...\kmeans_rs3.rs:54:39: 54:62 warning: use of deprecated item: support moved to std::hash, #[warn(deprecated)] on by default
...\kmeans_rs3.rs:54 type FnvHashMap<K, V> = HashMap<K, V, DefaultState>;
^~~~~~~~~~~~~~~~~~~~~~~
How can I fix this code? Other suggestions are welcome too, but keep in mind that most of this code wasn't written by me.
By the way, Rust's hashing ergonomics are generally poor and need a lot of improvement. The above code is full of magic incantations and boilerplate just to hash some Points. Not being able to hash f64 by default isn't acceptable; a simpler solution is needed.
Once the two warnings are fixed, I think we can submit this code as a pull request to andreaferretti/kmeans.