Story
I just created an account to reply to "Elegantly flatten Vec<Result<Vec<T>, E>> into Result<Vec<T>, E>", but then I found that the thread had already been closed, which is why I'm posting it here.
The two methods that were suggested are the following:
fn flatten_fold<T, E>(outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    outer.into_iter().try_fold(vec![], |mut unrolled, result| {
        unrolled.extend(result?);
        Ok(unrolled)
    })
}

fn flatten_itertools<T, E>(outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    itertools::process_results(outer, |i| i.flatten().collect())
}
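For reference, here is a quick sanity check (my own example, not from the original thread) showing that both helpers return the flattened vector on success and the first error otherwise:

fn main() {
    // Both helpers agree on an input without errors...
    let ok: Vec<Result<Vec<u32>, &str>> = vec![Ok(vec![1, 2]), Ok(vec![3])];
    assert_eq!(flatten_fold(ok.clone()), Ok(vec![1, 2, 3]));
    assert_eq!(flatten_itertools(ok), Ok(vec![1, 2, 3]));

    // ...and on one where an inner Result is an error.
    let broken: Vec<Result<Vec<u32>, &str>> = vec![Ok(vec![1]), Err("boom")];
    assert_eq!(flatten_fold(broken.clone()), Err("boom"));
    assert_eq!(flatten_itertools(broken), Err("boom"));
}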
I didn't dive deep into how itertools::process_results works and will only show its benchmark results. The problems of flatten_fold are the following:
- No pre-allocation: the output vector has to grow repeatedly (a pre-allocating sketch follows this list)
- Fails too late: if the first error sits far into the outer vector, everything before it is flattened for nothing
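For illustration, here is a minimal sketch, mine and not from the thread (the name flatten_fold_prealloc is made up), of how the pre-allocation point alone could be fixed without unsafe, by summing the inner lengths in an extra pass over references:

fn flatten_fold_prealloc<T, E>(outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    // Extra pass over references: sum the lengths of the inner vectors that are Ok.
    let len: usize = outer.iter().map(|r| r.as_ref().map_or(0, |v| v.len())).sum();
    // Same fold as flatten_fold, but starting from a vector with the exact capacity.
    outer.into_iter().try_fold(Vec::with_capacity(len), |mut unrolled, result| {
        unrolled.extend(result?);
        Ok(unrolled)
    })
}

This still walks every inner vector before the first error is reported, though, so the second problem remains.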
Therefore, I pulled out my magic box, unsafe, and wrote a better function that is ~300x faster in the worst case, which is when the last element of the outer vector is an error:
use std::hint::unreachable_unchecked;
use std::ptr;

fn flatten_unsafe<T, E>(outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    let mut len = 0;
    let mut err = false;
    for result in &outer {
        match result {
            Ok(inner) => len += inner.len(),
            Err(_) => {
                err = true;
                break;
            }
        }
    }
    if err {
        for result in outer {
            result?;
        }
        // Safety: Can't be reached since we found an error in the vector `outer`.
        unsafe { unreachable_unchecked() };
    }
    let mut flat_v = Vec::with_capacity(len);
    let mut ptr = flat_v.as_mut_ptr();
    for result in outer {
        // Safety: If any element of `outer` were an error, we would have found it above.
        for value in unsafe { result.unwrap_unchecked() } {
            // Safety: The capacity is set to the total length.
            unsafe { ptr::write(ptr, value) };
            ptr = unsafe { ptr.add(1) };
        }
    }
    // Safety: `len` is the calculated total length, and all `len` slots were written above.
    unsafe { flat_v.set_len(len) };
    Ok(flat_v)
}
BTW: Do you know of a better way to return the error in the first loop if E doesn't implement Clone?
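One option that comes to mind, as my own suggestion rather than something from the post (the helper name first_error is made up): since the pass over &outer has already proven that an error exists, a second owned pass can move it out with find_map, so E never needs to be Clone:

// Hypothetical helper: move the first error out of `outer` by value.
// `Result::err` turns each element into `Option<E>`; `find_map` stops at the first `Some`.
fn first_error<T, E>(outer: Vec<Result<Vec<T>, E>>) -> Option<E> {
    outer.into_iter().find_map(Result::err)
}

In flatten_unsafe, the if err branch could then end in return Err(first_error(outer).expect("found above")), which also makes the unreachable_unchecked() unnecessary, because that branch now visibly returns.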
Benchmarks
I use divan for the benchmarks.
Let's start with the worst case mentioned above:
use divan::{bench, Bencher};
use flatten::{flatten_fold, flatten_itertools, flatten_unsafe};
use std::hint::black_box;

const LENS: &[usize] = &[4, 16, 64, 246, 1024, 4096];

fn main() {
    divan::main()
}

fn bench_v(len: usize) -> Vec<Result<Vec<u64>, ()>> {
    let mut v = vec![Ok(vec![0; 1024]); len];
    v[len - 1] = Err(());
    v
}

#[bench(consts = LENS)]
fn bench_flatten_fold<const N: usize>(bencher: Bencher) {
    bencher
        .with_inputs(|| bench_v(N))
        .bench_values(|v| black_box(flatten_fold(black_box(v))))
}

#[bench(consts = LENS)]
fn bench_flatten_itertools<const N: usize>(bencher: Bencher) {
    bencher
        .with_inputs(|| bench_v(N))
        .bench_values(|v| black_box(flatten_itertools(black_box(v))))
}

#[bench(consts = LENS)]
fn bench_flatten_unsafe<const N: usize>(bencher: Bencher) {
    bencher
        .with_inputs(|| bench_v(N))
        .bench_values(|v| black_box(flatten_unsafe(black_box(v))))
}
Results:
bench                      fastest  │ slowest  │ median   │ mean     │ samples │ iters
├─ bench_flatten_fold               │          │          │          │         │
│  ├─ 4                    2.604 µs │ 13.27 µs │ 2.665 µs │ 2.833 µs │ 100     │ 100
│  ├─ 16                   52.08 µs │ 123.7 µs │ 56.5 µs  │ 58.46 µs │ 100     │ 100
│  ├─ 64                   14.14 µs │ 331.8 µs │ 19.82 µs │ 24.43 µs │ 100     │ 100
│  ├─ 246                  858.1 µs │ 1.34 ms  │ 866.5 µs │ 879 µs   │ 100     │ 100
│  ├─ 1024                 4.103 ms │ 6.132 ms │ 4.189 ms │ 4.248 ms │ 100     │ 100
│  ╰─ 4096                 21.67 ms │ 23.48 ms │ 21.86 ms │ 21.92 ms │ 100     │ 100
├─ bench_flatten_itertools          │          │          │          │         │
│  ├─ 4                    9.497 µs │ 17.16 µs │ 9.517 µs │ 9.954 µs │ 100     │ 100
│  ├─ 16                   47.77 µs │ 93.53 µs │ 48.28 µs │ 50.79 µs │ 100     │ 100
│  ├─ 64                   201.3 µs │ 408.3 µs │ 203.8 µs │ 208.7 µs │ 100     │ 100
│  ├─ 246                  746 µs   │ 1.573 ms │ 754.4 µs │ 776.1 µs │ 100     │ 100
│  ├─ 1024                 6.674 ms │ 7.772 ms │ 6.704 ms │ 6.778 ms │ 100     │ 100
│  ╰─ 4096                 31.29 ms │ 33.21 ms │ 31.52 ms │ 31.69 ms │ 100     │ 100
╰─ bench_flatten_unsafe             │          │          │          │         │
   ├─ 4                    63 ns    │ 145.9 ns │ 66.13 ns │ 67.1 ns  │ 100     │ 3200
   ├─ 16                   212.7 ns │ 683.5 ns │ 218.8 ns │ 254.5 ns │ 100     │ 800
   ├─ 64                   833.7 ns │ 2.349 µs │ 856.3 ns │ 983.3 ns │ 100     │ 400
   ├─ 246                  3.034 µs │ 6.201 µs │ 3.135 µs │ 3.257 µs │ 100     │ 100
   ├─ 1024                 12.29 µs │ 19.69 µs │ 12.79 µs │ 13.35 µs │ 100     │ 100
   ╰─ 4096                 1.988 ms │ 3.045 ms │ 2.027 ms │ 2.052 ms │ 100     │ 100
Are you looking for the clickbait number? Compare the mean values for N = 1024: 4248 µs / 13.35 µs ≈ 318.
I know, this is obviously the worst case. How about the best case, i.e. no errors at all? Just comment out the line in bench_v that sets the last element to Err(()), as spelled out below, and re-run.
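Spelled out, the best-case input generator is just bench_v with that assignment commented out:

fn bench_v(len: usize) -> Vec<Result<Vec<u64>, ()>> {
    // `mut` is no longer needed once the error assignment is commented out.
    let v = vec![Ok(vec![0; 1024]); len];
    // v[len - 1] = Err(());
    v
}

The results for the error-free case: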
bench                      fastest  │ slowest  │ median   │ mean     │ samples │ iters
├─ bench_flatten_fold               │          │          │          │         │
│  ├─ 4                    2.644 µs │ 14.16 µs │ 2.704 µs │ 2.968 µs │ 100     │ 100
│  ├─ 16                   43.48 µs │ 101.3 µs │ 48.13 µs │ 48.81 µs │ 100     │ 100
│  ├─ 64                   15.04 µs │ 339.5 µs │ 21.68 µs │ 27.39 µs │ 100     │ 100
│  ├─ 246                  702.1 µs │ 1.321 ms │ 714.7 µs │ 735.7 µs │ 100     │ 100
│  ├─ 1024                 3.321 ms │ 6.055 ms │ 3.372 ms │ 3.44 ms  │ 100     │ 100
│  ╰─ 4096                 20.09 ms │ 21.88 ms │ 20.25 ms │ 20.37 ms │ 100     │ 100
├─ bench_flatten_itertools          │          │          │          │         │
│  ├─ 4                    11.62 µs │ 12.28 µs │ 12.19 µs │ 11.95 µs │ 100     │ 100
│  ├─ 16                   48.55 µs │ 63.11 µs │ 48.88 µs │ 49.28 µs │ 100     │ 100
│  ├─ 64                   195.7 µs │ 410 µs   │ 197.2 µs │ 203.7 µs │ 100     │ 100
│  ├─ 246                  733.1 µs │ 1.52 ms  │ 738.4 µs │ 751 µs   │ 100     │ 100
│  ├─ 1024                 5.764 ms │ 6.672 ms │ 5.794 ms │ 5.842 ms │ 100     │ 100
│  ╰─ 4096                 29.5 ms  │ 31.7 ms  │ 29.64 ms │ 29.76 ms │ 100     │ 100
╰─ bench_flatten_unsafe             │          │          │          │         │
   ├─ 4                    615.7 ns │ 1.863 µs │ 638.5 ns │ 655.9 ns │ 100     │ 200
   ├─ 16                   2.504 µs │ 18.14 µs │ 2.534 µs │ 3.191 µs │ 100     │ 100
   ├─ 64                   10.83 µs │ 267.2 µs │ 11.29 µs │ 15.07 µs │ 100     │ 100
   ├─ 246                  40.17 µs │ 848.5 µs │ 42.15 µs │ 61.47 µs │ 100     │ 100
   ├─ 1024                 3.133 ms │ 3.825 ms │ 3.158 ms │ 3.208 ms │ 100     │ 100
   ╰─ 4096                 15.59 ms │ 17.26 ms │ 15.78 ms │ 15.87 ms │ 100     │ 100
The implementation with unsafe is always faster, even in the best case (up to 15x)!
Actions?
Should we implement something similar in std or in itertools? Especially because itertools is slower in every benchmark…
Pinging some people from the last thread in case they are interested: @eee @cuviper