Implementation of hybrid partial sorting algorithm

// OPTIMIZE THIS FUNCTION TO RUN AS FAST AS POSSIBLE
// Result must work on play.rust-lang.org

/// Returns the 8 smallest numbers found in the supplied vector
/// in order smallest to largest.
fn least_8(l: &Vec<u32>) -> Vec<u32> {
    let mut ll = l.clone();
    ll.sort();
    ll[0..8].iter().cloned().collect()
}

// DO NOT CHANGE ANYTHING BELOW THIS LINE

fn make_list() -> Vec<u32> {
    const SIZE: usize = 1<<16;
    let mut out = Vec::with_capacity(SIZE);
    let mut num = 998_244_353_u32; // prime
    for i in 0..SIZE {
        out.push(num);
        // rotate and add to produce some pseudorandomness
        num = (((num << 1) | (num >> 31)) as u64 + (i as u64)) as u32;
    }
    return out;
}

fn main() {
    let l = make_list();
    let start = std::time::Instant::now();
    let l8 = least_8(&l);
    let end = std::time::Instant::now();
    assert_eq!(vec![4, 5, 15, 22, 28, 31, 37, 38], l8);
    println!("Took {:?}", end.duration_since(start));
}

You don't need to allocate a vector if you always want exactly 8 elements; a fixed-size array will do. You don't need to sort the whole thing either: that's O(n log n), whereas naïvely finding the k smallest elements with repeated linear scans is O(n · k), which can be faster when k is sufficiently small compared to log n. Furthermore, the constant factors of a linear scan are much smaller than those of sorting, because sorting does a lot of non-sequential memory access, whereas the sequential access pattern of a linear scan is easy for the CPU's hardware prefetcher to predict.
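To put rough numbers on the "k small compared to log n" condition for this particular input (make_list uses SIZE = 1<<16), treating comparison counts as a crude proxy and ignoring constant factors:

$$ n = 2^{16} = 65\,536, \qquad \log_2 n = 16, \qquad k = 8 $$
$$ n \log_2 n = 1\,048\,576 \quad \text{vs.} \quad n \cdot k = 524\,288 $$

So k is below log2 n here, which is the regime where O(n · k) scanning can plausibly beat a full sort.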

All in all, a possible improvement would be

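fn least_8(l: &[u32]) -> [u32; 8] {
    // Assumes the 8 smallest values are distinct: the strict `>` filter
    // below skips repeats, and unwrap() panics on fewer than 8 distinct values.
    let mut arr = [0; 8];
    // Pass 0: the overall minimum.
    arr[0] = l.iter().copied().min().unwrap();

    // Pass i: the smallest value strictly greater than the one from pass i - 1.
    for i in 1..arr.len() {
        arr[i] = l.iter().copied().filter(|&x| x > arr[i - 1]).min().unwrap();
    }

    arr
}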

which takes about 400 µs vs. 3 ms for the original in release mode, i.e. around 8 times faster.
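One caveat with the repeated-scan version: because each pass keeps only values strictly greater than the previous minimum, duplicates are skipped, so an input whose 8 smallest values contain repeats gives a different answer than sorting. Below is a minimal sketch of a hypothetical duplicate-aware variant (not from the thread), still O(n · k): each pass also counts how often the new minimum occurs and fills that many slots. It assumes the input has at least 8 elements and has not been timed.

```rust
/// Duplicate-aware repeated scans (hypothetical variant, not from the thread).
/// Assumes `l.len() >= 8`; panics otherwise.
fn least_8(l: &[u32]) -> [u32; 8] {
    assert!(l.len() >= 8, "need at least 8 elements");
    let mut arr = [0u32; 8];
    let mut filled = 0;
    let mut prev: Option<u32> = None;
    while filled < arr.len() {
        // Smallest value strictly greater than `prev` (any value on the first pass).
        let m = l
            .iter()
            .copied()
            .filter(|&x| prev.map_or(true, |p| x > p))
            .min()
            .expect("a larger value exists while fewer than 8 slots are filled");
        // Copy `m` once per occurrence, without overflowing the 8 slots.
        let count = l.iter().filter(|&&x| x == m).count();
        for _ in 0..count.min(arr.len() - filled) {
            arr[filled] = m;
            filled += 1;
        }
        prev = Some(m);
    }
    arr
}

fn main() {
    let v: Vec<u32> = vec![3, 1, 2, 1, 5, 4, 9, 8, 7, 6];
    println!("{:?}", least_8(&v)); // prints [1, 1, 2, 3, 4, 5, 6, 7]
}
```

The extra counting scan roughly doubles the work per pass; it is the price of matching the sort-based output on inputs with repeated values.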

Can you check it?
Here is my own playground version.
Execution time: 200 µs

Building on the repeated linear scans from @H2CO3, I've modified the sorting function to do everything in one pass. The idea is to maintain an always-sorted list of the 8 smallest numbers we've seen so far, and insert into it only when we see a number small enough to replace an element in this list of 8.

Execution time of the algorithm is around 40µs on average, which is about 10 times faster than the repeated scans and 75 times faster than the original.

fn least_8(l: &[u32]) -> [u32; 8] {
    // Initialize 8-element return value with first
    // 8 elements from the input vector, and then
    // sort this subset to keep the return vec sorted
    // at all times
    // (assume input vector length >= 8)
    let mut arr = [0; 8];
    arr.copy_from_slice(&l[0..8]);
    arr.sort_unstable();
    // Iterate through remaining elements of input vector
    // and insert any small enough elements into return vec
    for &x in &l[8..] {
        // current element should be inserted somewhere into the
        // return vec
        if x < arr[7] {
            // find the correct insertion point (p)
            for p in 0..=7 {
                if x < arr[p] {
                    // c is the current index of the 8-element
                    // array as we iterate through it.
                    // if c < p, leave this number alone
                    // if c == p, it is replaced with x
                    // if c > p, move element one place to the right
                    // and remove it if it is the largest element
                    // in the list
                    for c in (p + 1..=7).rev() {
                        arr[c] = arr[c - 1];
                    }
                    arr[p] = x; // replace element at p
                    break;
                }
            }
        }
    }
    arr
}
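The insertion-point search and the shift in the loop above can also be written with slice helpers: `partition_point` (a binary search for the first element greater than x) and `copy_within` (shift the tail right by one slot). A sketch of the same one-pass algorithm in that form, assuming at least 8 input elements; it is untimed, and with only 8 slots the binary search may or may not beat the linear search.

```rust
/// Same one-pass idea as above, written with slice helpers.
/// Assumes `l.len() >= 8`; panics otherwise.
fn least_8(l: &[u32]) -> [u32; 8] {
    let mut arr = [0u32; 8];
    arr.copy_from_slice(&l[..8]);
    arr.sort_unstable();
    for &x in &l[8..] {
        if x < arr[7] {
            // First index whose element is greater than x (arr stays sorted).
            let p = arr.partition_point(|&y| y <= x);
            // Shift arr[p..7] one slot right, dropping the old arr[7].
            arr.copy_within(p..7, p + 1);
            arr[p] = x;
        }
    }
    arr
}

fn main() {
    let v: Vec<u32> = vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0];
    println!("{:?}", least_8(&v)); // prints [0, 1, 2, 3, 4, 5, 6, 7]
}
```

Using `<=` in the predicate places a new value after any equal ones already stored, which matches the original `x < arr[p]` search.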
Yes, this is essentially the bounded-heap (top-k) algorithm, with a small sorted array playing the role of the heap.
When k << n, it works well.
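For comparison, the textbook heap formulation keeps the k smallest values seen so far in a max-heap, so the largest of them sits on top and is the one evicted. A minimal sketch using `std::collections::BinaryHeap` (a max-heap) and `peek_mut`, assuming at least 8 input elements; it has not been timed against the sorted-array version, and for k = 8 the heap bookkeeping may well cost more than simple insertion.

```rust
use std::collections::BinaryHeap;

/// Bounded max-heap variant: keeps the 8 smallest values seen so far, with
/// the largest of them (the eviction candidate) at the top.
/// Assumes `l.len() >= 8`; panics otherwise.
fn least_8(l: &[u32]) -> [u32; 8] {
    assert!(l.len() >= 8, "need at least 8 elements");
    let mut heap: BinaryHeap<u32> = l[..8].iter().copied().collect();
    for &x in &l[8..] {
        // peek_mut() exposes the largest of the current 8; replace it if x is smaller.
        if let Some(mut top) = heap.peek_mut() {
            if x < *top {
                *top = x; // PeekMut restores the heap property when dropped
            }
        }
    }
    let sorted = heap.into_sorted_vec(); // ascending order, length 8
    let mut arr = [0u32; 8];
    arr.copy_from_slice(&sorted);
    arr
}

fn main() {
    let v: Vec<u32> = vec![5, 1, 5, 1, 5, 1, 5, 1, 0];
    println!("{:?}", least_8(&v)); // prints [0, 1, 1, 1, 1, 5, 5, 5]
}
```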

Also, don't rely on the Rust Playground for benchmarks, as its timings are very noisy: running your code a few times, I got results ranging from 195 µs to 318 µs. Either run the code locally in release mode, or, even better, use a benchmarking crate such as criterion.
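If pulling in criterion isn't an option, a rough std-only stand-in is to run the function many times and report the median, using `std::hint::black_box` so the optimizer can't discard the work. A minimal sketch; `median_time`, `least_8_sort`, and the generated data are hypothetical names and inputs for illustration (the data is not make_list's sequence), and this is no substitute for criterion's statistics and warm-up handling.

```rust
use std::hint::black_box;
use std::time::{Duration, Instant};

/// Runs `f` `runs` times and returns the median wall-clock time.
fn median_time<F: FnMut()>(runs: usize, mut f: F) -> Duration {
    assert!(runs > 0, "need at least one run");
    let mut times: Vec<Duration> = (0..runs)
        .map(|_| {
            let start = Instant::now();
            f();
            start.elapsed()
        })
        .collect();
    times.sort_unstable();
    times[runs / 2]
}

/// Baseline from the original question: copy, sort, take the first 8.
fn least_8_sort(l: &[u32]) -> [u32; 8] {
    let mut v = l.to_vec();
    v.sort_unstable();
    let mut arr = [0u32; 8];
    arr.copy_from_slice(&v[..8]);
    arr
}

fn main() {
    // Hypothetical test data: 2^16 pseudo-random values.
    let data: Vec<u32> = (0..1u32 << 16)
        .map(|i| i.wrapping_mul(2_654_435_761))
        .collect();
    // black_box on input and output keeps the call from being optimized away.
    let t = median_time(101, || {
        black_box(least_8_sort(black_box(&data)));
    });
    println!("median of 101 runs: {:?}", t);
}
```

Any `least_8` from this thread with a `&[u32]` parameter can be swapped in for `least_8_sort` to compare versions under the same conditions.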

Might also be useful to compare the performance to the std solution, select_nth_unstable, which is based on an introselect algorithm.

Quick modification of @hax10's code to use that.

// SPDX-License-Identifier: Apache-2.0 OR MIT

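/// Note: mutates (reorders) the input, and panics if it has fewer than 8 elements.
fn least_8(list: &mut [u32]) -> [u32; 8] {
    // Partition so that list[7] holds the 8th-smallest value and every
    // element before it is <= list[7].
    let (head, _, _) = list.select_nth_unstable(8 - 1);  // parameter is index, not count
    // Only the 7 elements before index 7 need sorting; list[7] already bounds them.
    head.sort_unstable();

    let mut arr = [0; 8];
    arr.copy_from_slice(&list[0..8]);
    arr
}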

Have not yet timed it.
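Since `main` (below the "DO NOT CHANGE" line) passes an immutable `&l`, the `&mut [u32]` signature above won't drop in as-is. A minimal sketch that keeps an immutable `&[u32]` signature by partitioning a private copy, assuming at least 8 input elements. The copy adds one O(n) allocation, so this may well lose to the one-pass version, but that is untested.

```rust
/// select_nth_unstable on a private copy, so the caller's data is untouched.
/// Assumes `l.len() >= 8`; select_nth_unstable panics otherwise.
fn least_8(l: &[u32]) -> [u32; 8] {
    let mut scratch = l.to_vec();
    // After this, scratch[7] is the 8th-smallest and scratch[..7] are all <= it.
    let (head, _, _) = scratch.select_nth_unstable(7);
    head.sort_unstable(); // only the 7 smaller values need sorting
    let mut arr = [0u32; 8];
    arr.copy_from_slice(&scratch[..8]);
    arr
}

fn main() {
    let v: Vec<u32> = vec![4, 4, 1, 9, 1, 0, 3, 3, 2];
    println!("{:?}", least_8(&v)); // prints [0, 1, 1, 2, 3, 3, 4, 4]
    println!("{:?}", v); // input order is unchanged
}
```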
