~300x more efficient flattening of Vec<Result<Vec<T>, E>> in the worst case

Story

I just created an account to reply to Elegantly flatten Vec<Result<Vec<T>, E>> into Result<Vec<T>, E>, but then I found that the thread was closed, so I'm posting here instead :smiley:

The two methods that were suggested are the following:

fn flatten_fold<T, E>(outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    outer.into_iter().try_fold(vec![], |mut unrolled, result| {
        unrolled.extend(result?);
        Ok(unrolled)
    })
}

fn flatten_itertools<T, E>(outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    itertools::process_results(outer, |i| i.flatten().collect())
}

I didn't dive deep into how itertools::process_results works and only include it in the benchmark results.

But the problems of flatten_fold are the following:

  1. No pre-allocation, so the accumulator reallocates repeatedly as it grows
  2. It fails too late: if the first error sits near the end of the outer vector, everything before it is moved into the accumulator and then thrown away
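Problem 2 is easy to see directly. In this sketch (using the flatten_fold from above), all 100,000 elements are moved into the accumulator before the fold ever reaches the error:

```rust
fn flatten_fold<T, E>(outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    outer.into_iter().try_fold(vec![], |mut unrolled, result| {
        unrolled.extend(result?);
        Ok(unrolled)
    })
}

fn main() {
    // 1000 Ok vectors of 100 elements each, then one Err at the end:
    // try_fold moves all 100_000 elements into the accumulator
    // (reallocating it several times along the way) before it ever
    // sees the error, and then throws all of that work away.
    let mut outer: Vec<Result<Vec<u64>, ()>> = vec![Ok(vec![0; 100]); 1000];
    outer.push(Err(()));
    assert_eq!(flatten_fold(outer), Err(()));
}
```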

Therefore, I pulled out my magic box unsafe and wrote a better function that is ~300x faster in the worst case, i.e. when the last element of the outer vector is an error:

use std::hint::unreachable_unchecked;
use std::ptr;

fn flatten_unsafe<T, E>(outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    let mut len = 0;
    let mut err = false;

    for result in &outer {
        match result {
            Ok(inner) => len += inner.len(),
            Err(_) => {
                err = true;
                break;
            }
        }
    }

    if err {
        for result in outer {
            result?;
        }

        // Safety: Can't be reached since we found an error in the vector `outer`.
        unsafe { unreachable_unchecked() };
    }

    let mut flat_v = Vec::with_capacity(len);

    let mut ptr = flat_v.as_mut_ptr();
    for result in outer {
        // Safety: If any element of `outer` is an error, we would have found it above.
        for value in unsafe { result.unwrap_unchecked() } {
            // Safety: we reserved capacity for `len` elements above and write at most that many.
            unsafe { ptr::write(ptr, value) };
            ptr = unsafe { ptr.add(1) };
        }
    }

    // Safety: exactly `len` elements were initialized above.
    unsafe { flat_v.set_len(len) };

    Ok(flat_v)
}
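For anyone who wants to try it, here is a quick standalone sanity check of both paths (the function body is copied from above so the snippet compiles on its own):

```rust
use std::hint::unreachable_unchecked;
use std::ptr;

fn flatten_unsafe<T, E>(outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    let mut len = 0;
    let mut err = false;

    for result in &outer {
        match result {
            Ok(inner) => len += inner.len(),
            Err(_) => {
                err = true;
                break;
            }
        }
    }

    if err {
        // Return the first error by value.
        for result in outer {
            result?;
        }
        // Safety: the loop above returned on the first error.
        unsafe { unreachable_unchecked() };
    }

    let mut flat_v = Vec::with_capacity(len);
    let mut p = flat_v.as_mut_ptr();
    for result in outer {
        // Safety: any Err would have been found in the first loop.
        for value in unsafe { result.unwrap_unchecked() } {
            // Safety: capacity for `len` elements was reserved above.
            unsafe { ptr::write(p, value) };
            p = unsafe { p.add(1) };
        }
    }
    // Safety: exactly `len` elements were written.
    unsafe { flat_v.set_len(len) };

    Ok(flat_v)
}

fn main() {
    // Happy path: everything is flattened in order.
    assert_eq!(
        flatten_unsafe::<u64, ()>(vec![Ok(vec![1, 2]), Ok(vec![3])]),
        Ok(vec![1, 2, 3])
    );
    // Error path: the error is returned without allocating the output.
    assert_eq!(flatten_unsafe(vec![Ok(vec![1u64]), Err("late error")]), Err("late error"));
}
```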

BTW: Do you know of a better way to return the error in the first loop if it doesn't implement Clone?

Benchmarks

I use divan for the benchmarks.

Let's start with the mentioned worst case:

use divan::{bench, Bencher};
use flatten::{flatten_fold, flatten_itertools, flatten_unsafe};
use std::hint::black_box;

const LENS: &[usize] = &[4, 16, 64, 246, 1024, 4096];

fn main() {
    divan::main()
}

fn bench_v(len: usize) -> Vec<Result<Vec<u64>, ()>> {
    let mut v = vec![Ok(vec![0; 1024]); len];
    v[len - 1] = Err(());

    v
}

#[bench(consts = LENS)]
fn bench_flatten_fold<const N: usize>(bencher: Bencher) {
    bencher
        .with_inputs(|| bench_v(N))
        .bench_values(|v| black_box(flatten_fold(black_box(v))))
}

#[bench(consts = LENS)]
fn bench_flatten_itertools<const N: usize>(bencher: Bencher) {
    bencher
        .with_inputs(|| bench_v(N))
        .bench_values(|v| black_box(flatten_itertools(black_box(v))))
}

#[bench(consts = LENS)]
fn bench_flatten_unsafe<const N: usize>(bencher: Bencher) {
    bencher
        .with_inputs(|| bench_v(N))
        .bench_values(|v| black_box(flatten_unsafe(black_box(v))))
}

Results:

bench                       fastest       β”‚ slowest       β”‚ median        β”‚ mean          β”‚ samples β”‚ iters
β”œβ”€ bench_flatten_fold                     β”‚               β”‚               β”‚               β”‚         β”‚
β”‚  β”œβ”€ 4                     2.604 Β΅s      β”‚ 13.27 Β΅s      β”‚ 2.665 Β΅s      β”‚ 2.833 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 16                    52.08 Β΅s      β”‚ 123.7 Β΅s      β”‚ 56.5 Β΅s       β”‚ 58.46 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 64                    14.14 Β΅s      β”‚ 331.8 Β΅s      β”‚ 19.82 Β΅s      β”‚ 24.43 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 246                   858.1 Β΅s      β”‚ 1.34 ms       β”‚ 866.5 Β΅s      β”‚ 879 Β΅s        β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 1024                  4.103 ms      β”‚ 6.132 ms      β”‚ 4.189 ms      β”‚ 4.248 ms      β”‚ 100     β”‚ 100
β”‚  ╰─ 4096                  21.67 ms      β”‚ 23.48 ms      β”‚ 21.86 ms      β”‚ 21.92 ms      β”‚ 100     β”‚ 100
β”œβ”€ bench_flatten_itertools                β”‚               β”‚               β”‚               β”‚         β”‚
β”‚  β”œβ”€ 4                     9.497 Β΅s      β”‚ 17.16 Β΅s      β”‚ 9.517 Β΅s      β”‚ 9.954 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 16                    47.77 Β΅s      β”‚ 93.53 Β΅s      β”‚ 48.28 Β΅s      β”‚ 50.79 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 64                    201.3 Β΅s      β”‚ 408.3 Β΅s      β”‚ 203.8 Β΅s      β”‚ 208.7 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 246                   746 Β΅s        β”‚ 1.573 ms      β”‚ 754.4 Β΅s      β”‚ 776.1 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 1024                  6.674 ms      β”‚ 7.772 ms      β”‚ 6.704 ms      β”‚ 6.778 ms      β”‚ 100     β”‚ 100
β”‚  ╰─ 4096                  31.29 ms      β”‚ 33.21 ms      β”‚ 31.52 ms      β”‚ 31.69 ms      β”‚ 100     β”‚ 100
╰─ bench_flatten_unsafe                   β”‚               β”‚               β”‚               β”‚         β”‚
   β”œβ”€ 4                     63 ns         β”‚ 145.9 ns      β”‚ 66.13 ns      β”‚ 67.1 ns       β”‚ 100     β”‚ 3200
   β”œβ”€ 16                    212.7 ns      β”‚ 683.5 ns      β”‚ 218.8 ns      β”‚ 254.5 ns      β”‚ 100     β”‚ 800
   β”œβ”€ 64                    833.7 ns      β”‚ 2.349 Β΅s      β”‚ 856.3 ns      β”‚ 983.3 ns      β”‚ 100     β”‚ 400
   β”œβ”€ 246                   3.034 Β΅s      β”‚ 6.201 Β΅s      β”‚ 3.135 Β΅s      β”‚ 3.257 Β΅s      β”‚ 100     β”‚ 100
   β”œβ”€ 1024                  12.29 Β΅s      β”‚ 19.69 Β΅s      β”‚ 12.79 Β΅s      β”‚ 13.35 Β΅s      β”‚ 100     β”‚ 100
   ╰─ 4096                  1.988 ms      β”‚ 3.045 ms      β”‚ 2.027 ms      β”‚ 2.052 ms      β”‚ 100     β”‚ 100

Looking for the clickbait number? Compare the mean values for N=1024: 4248 Β΅s / 13.35 Β΅s β‰ˆ 318x.

I know, this is obviously the worst case. How about the best case? (no errors)

Well, just comment out the line that sets the last element to Err(()) in the bench_v function:

bench                       fastest       β”‚ slowest       β”‚ median        β”‚ mean          β”‚ samples β”‚ iters
β”œβ”€ bench_flatten_fold                     β”‚               β”‚               β”‚               β”‚         β”‚
β”‚  β”œβ”€ 4                     2.644 Β΅s      β”‚ 14.16 Β΅s      β”‚ 2.704 Β΅s      β”‚ 2.968 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 16                    43.48 Β΅s      β”‚ 101.3 Β΅s      β”‚ 48.13 Β΅s      β”‚ 48.81 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 64                    15.04 Β΅s      β”‚ 339.5 Β΅s      β”‚ 21.68 Β΅s      β”‚ 27.39 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 246                   702.1 Β΅s      β”‚ 1.321 ms      β”‚ 714.7 Β΅s      β”‚ 735.7 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 1024                  3.321 ms      β”‚ 6.055 ms      β”‚ 3.372 ms      β”‚ 3.44 ms       β”‚ 100     β”‚ 100
β”‚  ╰─ 4096                  20.09 ms      β”‚ 21.88 ms      β”‚ 20.25 ms      β”‚ 20.37 ms      β”‚ 100     β”‚ 100
β”œβ”€ bench_flatten_itertools                β”‚               β”‚               β”‚               β”‚         β”‚
β”‚  β”œβ”€ 4                     11.62 Β΅s      β”‚ 12.28 Β΅s      β”‚ 12.19 Β΅s      β”‚ 11.95 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 16                    48.55 Β΅s      β”‚ 63.11 Β΅s      β”‚ 48.88 Β΅s      β”‚ 49.28 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 64                    195.7 Β΅s      β”‚ 410 Β΅s        β”‚ 197.2 Β΅s      β”‚ 203.7 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 246                   733.1 Β΅s      β”‚ 1.52 ms       β”‚ 738.4 Β΅s      β”‚ 751 Β΅s        β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 1024                  5.764 ms      β”‚ 6.672 ms      β”‚ 5.794 ms      β”‚ 5.842 ms      β”‚ 100     β”‚ 100
β”‚  ╰─ 4096                  29.5 ms       β”‚ 31.7 ms       β”‚ 29.64 ms      β”‚ 29.76 ms      β”‚ 100     β”‚ 100
╰─ bench_flatten_unsafe                   β”‚               β”‚               β”‚               β”‚         β”‚
   β”œβ”€ 4                     615.7 ns      β”‚ 1.863 Β΅s      β”‚ 638.5 ns      β”‚ 655.9 ns      β”‚ 100     β”‚ 200
   β”œβ”€ 16                    2.504 Β΅s      β”‚ 18.14 Β΅s      β”‚ 2.534 Β΅s      β”‚ 3.191 Β΅s      β”‚ 100     β”‚ 100
   β”œβ”€ 64                    10.83 Β΅s      β”‚ 267.2 Β΅s      β”‚ 11.29 Β΅s      β”‚ 15.07 Β΅s      β”‚ 100     β”‚ 100
   β”œβ”€ 246                   40.17 Β΅s      β”‚ 848.5 Β΅s      β”‚ 42.15 Β΅s      β”‚ 61.47 Β΅s      β”‚ 100     β”‚ 100
   β”œβ”€ 1024                  3.133 ms      β”‚ 3.825 ms      β”‚ 3.158 ms      β”‚ 3.208 ms      β”‚ 100     β”‚ 100
   ╰─ 4096                  15.59 ms      β”‚ 17.26 ms      β”‚ 15.78 ms      β”‚ 15.87 ms      β”‚ 100     β”‚ 100

The implementation with unsafe is always faster (up to 15x)!

Actions?

Should we implement something similar in the std or itertools?

Especially because itertools is slower in every benchmark…

Pinging some people from the last thread in case they are interested: @eee @cuviper


Two clear problems with this one:

  1. It should use with_capacity for a fair comparison with the alternative below
  2. Since you don't need ownership of the accumulator, you can ΞΌoptimize it by using try_for_each instead. (This helps because LLVM doesn't always manage to realize that wrapping the Vec up into a Result and pulling it out again isn't actually doing anything, and by just not doing that, it's more obvious what's happening in the loop.)

So for safe versions, try something like

fn flatten_fold_preallocate<T, E>(outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    use std::ops::ControlFlow::*;

    let (Continue(cap) | Break(cap)) = outer.iter().try_fold(0, |s, r| match r {
        Ok(v) => Continue(s + v.len()),
        Err(_) => Break(s),
    });

    let mut unrolled = Vec::with_capacity(cap);
    outer.into_iter().try_for_each(|result| {
        unrolled.extend(result?);
        Ok(())
    })?;
    Ok(unrolled)
}

which is optimized for assuming that there aren't going to be any errors.
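For anyone unfamiliar with the `let (Continue(cap) | Break(cap)) = …` line: try_fold with ControlFlow returns whichever variant ended the fold, and since ControlFlow only has those two variants, the or-pattern is irrefutable and extracts the accumulated value either way. A small illustration of just that pattern (the data here is made up):

```rust
use std::ops::ControlFlow::*;

fn main() {
    let lens: Vec<Result<usize, ()>> = vec![Ok(3), Ok(4), Err(()), Ok(100)];
    // try_fold stops at the first Break; the or-pattern extracts the
    // accumulated sum from whichever variant ended the fold.
    let (Continue(cap) | Break(cap)) = lens.iter().try_fold(0usize, |s, r| match r {
        Ok(n) => Continue(s + n),
        Err(_) => Break(s),
    });
    // The trailing Ok(100) after the Err is never counted.
    assert_eq!(cap, 7);
}
```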

Or if the "don't allocate and copy if there's an error" behavior is worth keeping, I think a safe version would look something like this:

fn flatten_fold_precheck<T, E>(mut outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    let r = outer.iter().enumerate().try_fold(0, |s, (i, r)| match r {
        Ok(v) => Ok(s + v.len()),
        Err(_) => Err(i),
    });
    let mut unrolled = match r {
        Err(i) => return outer.swap_remove(i).map(|_| unreachable!()),
        Ok(cap) => Vec::with_capacity(cap),
    };
    outer.into_iter().for_each(|result| {
        unrolled.extend(result.unwrap_or_else(|_| unreachable!()));
    });
    Ok(unrolled)
}

Up-levelling for a second: the reason itertools isn't going to help as much here is that if you have concrete types and can iterate multiple times, you can do better than if all you have is an iterator.

That's the difference between https://doc.rust-lang.org/nightly/std/primitive.slice.html#method.join and https://docs.rs/itertools/latest/itertools/trait.Itertools.html#method.join, for example.


On first glance, my only concern (in terms of correctness) is the (lack of) handling of overflows in

len += inner.len()

though maybe that can only happen for zero-sized T? (I’m not 100% sure if there’s a guarantee that the sum of all allocation sizes never exceeds usize::MAX.) For the important cases, like non-zero-sized T on 64bit architecture, this should be irrelevant though, so the benchmark shouldn’t suffer from any appropriately minimal fixes.
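A minimal sketch of an overflow-safe length pre-pass (the helper name is mine): saturating_add caps the sum at usize::MAX, so Vec::with_capacity would at worst fail to allocate rather than under-reserve.

```rust
// Hypothetical helper: compute the total length of all inner Ok
// vectors, or None if any element is an Err. The sum saturates
// instead of wrapping on overflow.
fn total_len<T, E>(outer: &[Result<Vec<T>, E>]) -> Option<usize> {
    let mut len: usize = 0;
    for result in outer {
        match result {
            Ok(inner) => len = len.saturating_add(inner.len()),
            Err(_) => return None,
        }
    }
    Some(len)
}

fn main() {
    let outer: Vec<Result<Vec<u8>, ()>> = vec![Ok(vec![1, 2]), Ok(vec![3])];
    assert_eq!(total_len(&outer), Some(3));
    assert_eq!(total_len::<u8, ()>(&[Err(())]), None);
}
```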

Also, the code structure around err confuses me: in principle the Err(_) branch could directly return the error, and I'd be somewhat surprised (but didn't test it) if that made performance worse compared to the err = true; break approach. Oh, it's about the Clone thing… got it! I didn't spot the remark.

Edit: If the enumerate doesn’t end up affecting performance, then

    for (i, result) in outer.iter().enumerate() {
        match result {
            Ok(inner) => len += inner.len(),
            Err(_) => {
                return outer.swap_remove(i);
            }
        }
    }

could be reasonable.


FWIW, a safe version of your code has nearly the same performance:

fn flatten_safe<T, E>(outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    let mut len = 0;
    let mut err = false;

    for result in &outer {
        match result {
            Ok(inner) => len += inner.len(),
            Err(_) => {
                err = true;
                break;
            }
        }
    }

    if err {
        for result in outer {
            result?;
        }
        unreachable!();
    }

    let mut flat_v = Vec::with_capacity(len);
    for result in outer {
        flat_v.append(&mut result?);
    }
    Ok(flat_v)
}

@scottmcm @steffahn swap_remove is really smart! I forgot about it :sweat_smile:

@steffahn You are right, I should use saturating_add!

Here is the performance with the two new functions without errors:

bench                           fastest       β”‚ slowest       β”‚ median        β”‚ mean          β”‚ samples β”‚ iters
β”œβ”€ bench_flatten_fold_precheck                β”‚               β”‚               β”‚               β”‚         β”‚
β”‚  β”œβ”€ 4                         2.634 Β΅s      β”‚ 17.7 Β΅s       β”‚ 2.694 Β΅s      β”‚ 3.177 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 16                        45.15 Β΅s      β”‚ 70.85 Β΅s      β”‚ 47.73 Β΅s      β”‚ 48.82 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 64                        175.4 Β΅s      β”‚ 253.4 Β΅s      β”‚ 186.1 Β΅s      β”‚ 186.6 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 256                       692 Β΅s        β”‚ 1.048 ms      β”‚ 704.2 Β΅s      β”‚ 713.2 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 1024                      3.268 ms      β”‚ 4.966 ms      β”‚ 3.29 ms       β”‚ 3.349 ms      β”‚ 100     β”‚ 100
β”‚  ╰─ 4096                      15.74 ms      β”‚ 16.72 ms      β”‚ 15.89 ms      β”‚ 15.94 ms      β”‚ 100     β”‚ 100
β”œβ”€ bench_flatten_safe                         β”‚               β”‚               β”‚               β”‚         β”‚
β”‚  β”œβ”€ 4                         2.764 Β΅s      β”‚ 3.535 Β΅s      β”‚ 2.804 Β΅s      β”‚ 2.809 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 16                        4.738 Β΅s      β”‚ 21.2 Β΅s       β”‚ 4.809 Β΅s      β”‚ 5.052 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 64                        14.23 Β΅s      β”‚ 292 Β΅s        β”‚ 14.62 Β΅s      β”‚ 18.42 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 256                       48.92 Β΅s      β”‚ 1.046 ms      β”‚ 50.31 Β΅s      β”‚ 71.23 Β΅s      β”‚ 100     β”‚ 100
β”‚  β”œβ”€ 1024                      3.27 ms       β”‚ 4.089 ms      β”‚ 3.3 ms        β”‚ 3.344 ms      β”‚ 100     β”‚ 100
β”‚  ╰─ 4096                      15.82 ms      β”‚ 17.47 ms      β”‚ 15.94 ms      β”‚ 16.02 ms      β”‚ 100     β”‚ 100
╰─ bench_flatten_unsafe                       β”‚               β”‚               β”‚               β”‚         β”‚
   β”œβ”€ 4                         580.6 ns      β”‚ 1.928 Β΅s      β”‚ 590.6 ns      β”‚ 604.5 ns      β”‚ 100     β”‚ 200
   β”œβ”€ 16                        2.303 Β΅s      β”‚ 18.29 Β΅s      β”‚ 2.434 Β΅s      β”‚ 3.023 Β΅s      β”‚ 100     β”‚ 100
   β”œβ”€ 64                        10.73 Β΅s      β”‚ 215.5 Β΅s      β”‚ 11.21 Β΅s      β”‚ 14.77 Β΅s      β”‚ 100     β”‚ 100
   β”œβ”€ 256                       40.9 Β΅s       β”‚ 911.6 Β΅s      β”‚ 42.21 Β΅s      β”‚ 64.65 Β΅s      β”‚ 100     β”‚ 100
   β”œβ”€ 1024                      3.073 ms      β”‚ 3.729 ms      β”‚ 3.098 ms      β”‚ 3.149 ms      β”‚ 100     β”‚ 100
   ╰─ 4096                      15.23 ms      β”‚ 16.31 ms      β”‚ 15.38 ms      β”‚ 15.45 ms      β”‚ 100     β”‚ 100

The unsafe implementation is still faster than flatten_safe.

The performance of the worst case (only one error at the end) is almost the same for all three.

The weird thing, though, is that when I use copy_nonoverlapping (which append, used by flatten_safe, calls internally) I get the same performance as flatten_safe. Isn't it weird that copy_nonoverlapping is slower than the element-wise writes? Is something wrong here?

For other people interested in an efficient and safe implementation, here is the solution by @cuviper slightly modified to use swap_remove:

pub fn flatten_safe<T, E>(mut outer: Vec<Result<Vec<T>, E>>) -> Result<Vec<T>, E> {
    let mut len = 0;
    for (ind, result) in outer.iter().enumerate() {
        match result {
            Ok(inner) => len += inner.len(),
            Err(_) => {
                return outer.swap_remove(ind);
            }
        }
    }

    let mut flat_v = Vec::with_capacity(len);
    for result in outer {
        flat_v.append(&mut result?);
    }

    Ok(flat_v)
}

The difference between this and the unsafe implementation is not significant.
