Idiomatic way to populate vector from branched loop

I am trying to optimize the function below by avoiding:

  • the pre-allocation with zeros
  • bound checks

Is there an idiomatic way of achieving this without too much unsafe code?

/// `indices` becomes `[a0, a1, a2, ..., b0, b1, b2]` where `a` denotes 
/// indices unset and `b` indices set, in `is_set`. `set_count` has been previously computed
fn split_indices(is_set: &[bool], set_count: usize) -> Vec<usize> {
    debug_assert_eq!(is_set.iter().filter(|x|**x).count(), set_count);
    
    let unset_count = is_set.len() - set_count;
    
    let mut valids = 0;
    let mut nulls = 0;
    let mut indices = vec![0; is_set.len()];
    is_set
        .iter()
        .enumerate()
        .for_each(|(index, is_valid)| {
            if *is_valid {
                indices[unset_count + valids] = index;
                valids += 1;
            } else {
                indices[nulls] = index;
                nulls += 1;
            }
        });
    indices
}

I'd probably write something like this. It iterates over the input vector twice, but doesn't require precalculating set_count:

fn split_indices(is_set: &[bool]) -> Vec<usize> {
    let mut out = Vec::with_capacity(is_set.len());
 
    out.extend(is_set.iter().enumerate().filter(|&(_, &b)| !b).map(|(i, _)| i));
    out.extend(is_set.iter().enumerate().filter(|&(_, &b)| b).map(|(i, _)| i));

    out
}
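If the second pass over `is_set` is a concern, one safe single-pass alternative (a sketch, not from the thread; the name `split_indices_two_vecs` is hypothetical) pushes into two vectors and then appends one to the other. It avoids the zero-fill and all indexing, at the cost of one temporary allocation; each `push` still does its own capacity check.

```rust
/// Single-pass, fully safe sketch: collect unset and set indices separately,
/// then append the set ones after the unset ones.
/// `set_count` is only a capacity hint here: a wrong value can cost a
/// reallocation of `set`, but never memory safety.
fn split_indices_two_vecs(is_set: &[bool], set_count: usize) -> Vec<usize> {
    // Reserve room for *all* indices so the final append never reallocates.
    let mut unset = Vec::with_capacity(is_set.len());
    let mut set = Vec::with_capacity(set_count);
    for (i, &b) in is_set.iter().enumerate() {
        if b {
            set.push(i);
        } else {
            unset.push(i);
        }
    }
    unset.extend_from_slice(&set);
    unset
}
```

Unlike the nightly trick further down, this keeps the set indices in their original order.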

If you need to iterate over the input list only once, you can do something like this (untested) on nightly. I wouldn't exactly call it "idiomatic," though: in almost all circumstances, the performance savings from unsafe tricks like this are tiny compared to the extra maintenance burden.

#![feature(vec_spare_capacity, option_result_unwrap_unchecked,maybe_uninit_extra)]

/// Returns `[a0, a1, a2, ..., b2, b1, b0]`, where `a` denotes the indices unset
/// and `b` the indices set in `is_set`; the set indices end up in reverse order.
fn split_indices(is_set: &[bool]) -> Vec<usize> {
    let mut out = Vec::with_capacity(is_set.len());
    
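    // `with_capacity` only guarantees *at least* `is_set.len()` slots, so take exactly that many.
    let mut dest_iter = out.spare_capacity_mut()[..is_set.len()].iter_mut();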
    
    for (i, &b) in is_set.iter().enumerate() {
        let dest = match b {
            true => dest_iter.next_back(),
            false => dest_iter.next(),
        };
        debug_assert!(dest.is_some());
        unsafe { dest.unwrap_unchecked() }.write(i);
    }
    
    debug_assert!(dest_iter.next().is_none());
    unsafe { out.set_len(is_set.len()) }
 
    out
}

NB: The set indices here will appear in reverse order
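The nightly features used above (`Vec::spare_capacity_mut`, `Option::unwrap_unchecked`, `MaybeUninit::write`) have since been stabilized, so the same front/back filling can be written on stable Rust. The sketch below (hypothetical name `split_indices_spare`) also restores the original order of the set indices by reversing the tail, and swaps `unwrap_unchecked` for a plain `expect` so the loop body stays safe; whether the optimizer removes that check is not guaranteed.

```rust
/// Stable sketch of the spare-capacity trick: unset indices are written from
/// the front of the buffer, set indices from the back, then the set block is
/// reversed so both groups keep their original order.
fn split_indices_spare(is_set: &[bool]) -> Vec<usize> {
    let mut out: Vec<usize> = Vec::with_capacity(is_set.len());
    let mut unset_count = 0;
    {
        // Exactly `is_set.len()` uninitialized slots (capacity may be larger).
        let spare = &mut out.spare_capacity_mut()[..is_set.len()];
        let mut dest_iter = spare.iter_mut();
        for (i, &b) in is_set.iter().enumerate() {
            let dest = if b {
                dest_iter.next_back()
            } else {
                unset_count += 1;
                dest_iter.next()
            };
            // The iterator yields one slot per input element, so this never fires.
            dest.expect("one slot per input element").write(i);
        }
    }
    // SAFETY: every slot in 0..is_set.len() was written exactly once above.
    unsafe { out.set_len(is_set.len()) };
    // The set indices were written back-to-front; flip them into ascending order.
    out[unset_count..].reverse();
    out
}
```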


Similar to the last one, but on stable. However, I'd go with the short and 100% safe version.


You can replicate unwrap_unchecked on stable like this:

let dest = match dest {
    Some(uninit) => uninit.as_mut_ptr(),
    None => unsafe { std::hint::unreachable_unchecked() }
};
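A sketch of how that stable stand-in fits into a version that takes `set_count`, which is what makes the UB possible. Here the spare capacity is split at `len - set_count` and each half is filled in order, so the set indices keep their order. The function name is hypothetical, and `Vec::spare_capacity_mut` is assumed to be available (it is on current stable Rust).

```rust
use std::hint::unreachable_unchecked;
use std::mem::MaybeUninit;

/// Hypothetical sketch: fill the unset half and the set half of the spare
/// capacity independently, using `unreachable_unchecked` instead of `unwrap`.
///
/// # Safety
/// `set_count` must equal the number of `true` entries in `is_set`; otherwise
/// one half runs out of slots and `unreachable_unchecked` is reached (UB).
unsafe fn split_indices_unchecked(is_set: &[bool], set_count: usize) -> Vec<usize> {
    debug_assert_eq!(is_set.iter().filter(|b| **b).count(), set_count);
    let len = is_set.len();
    let mut out: Vec<usize> = Vec::with_capacity(len);
    {
        let spare: &mut [MaybeUninit<usize>] = &mut out.spare_capacity_mut()[..len];
        let (unset_part, set_part) = spare.split_at_mut(len - set_count);
        let mut unset_iter = unset_part.iter_mut();
        let mut set_iter = set_part.iter_mut();
        for (i, &b) in is_set.iter().enumerate() {
            let dest = if b { set_iter.next() } else { unset_iter.next() };
            let dest = match dest {
                Some(uninit) => uninit.as_mut_ptr(),
                // SAFETY: the caller guarantees `set_count`, so neither half runs dry.
                None => unsafe { unreachable_unchecked() },
            };
            // SAFETY: `dest` points to an uninitialized slot inside `out`'s buffer.
            unsafe { dest.write(i) };
        }
    }
    // SAFETY: both halves were filled exactly, covering 0..len.
    unsafe { out.set_len(len) };
    out
}
```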

If you do, though, the whole function needs to be marked unsafe: Providing an incorrect value of set_count is UB.


My concern with this is that .extend will not specialize to TrustedLen since filter is not TrustedLen, thereby introducing a third check: that we have enough space for the new item.
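The root of the concern is visible in `size_hint`: `Filter` cannot know how many items will pass its predicate, so it reports a lower bound of 0 and only forwards the upper bound. A small check (the helper name is hypothetical) on the iterator chain from the two-pass version:

```rust
/// Returns the `size_hint` of the filtered iterator used in the two-pass version.
fn filtered_size_hint(is_set: &[bool]) -> (usize, Option<usize>) {
    is_set
        .iter()
        .enumerate()
        // `Filter` reports (0, upper): the exact count is unknown up front.
        .filter(|&(_, &b)| !b)
        .map(|(i, _)| i)
        .size_hint()
}
```

Because the lower bound is 0 rather than an exact length, `extend` cannot treat the iterator as exact-size and has to keep checking capacity as it pushes.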

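fn split_indices1(is_set: &[bool], set_count: usize) -> Vec<usize> {
    debug_assert_eq!(is_set.iter().filter(|x| **x).count(), set_count);

    // isize rather than i32, so inputs longer than i32::MAX don't overflow the casts
    let unset_count = (is_set.len() - set_count) as isize;

    let mut valids: isize = 0;
    let mut indices: Vec<usize> = Vec::with_capacity(is_set.len());
    // Write through a raw pointer instead of calling `set_len` up front:
    // `set_len` requires every element in 0..len to already be initialized.
    let ptr = indices.as_mut_ptr();
    is_set
        .iter()
        .enumerate()
        .for_each(|(index, is_valid)| {
            let set_bool = *is_valid;
            let set = set_bool as isize;
            let unset = (!set_bool) as isize;
            // nulls = index - valids
            // v = set ? valids : nulls is equivalent to the branchless
            // v = valids - unset * (valids - nulls) = valids - unset * (2 * valids - index)
            let v = valids - (unset * (valids * 2 - index as isize));
            // set indices are shifted past the `unset_count` unset slots
            let j = v + (set * unset_count);

            // SAFETY: 0 <= j < is_set.len() holds as long as `set_count` is correct.
            unsafe {
                ptr.add(j as usize).write(index);
            }
            valids += set;
        });
    // SAFETY: each slot in 0..is_set.len() was written exactly once above.
    unsafe { indices.set_len(is_set.len()) };
    indices
}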
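To sanity-check the branchless index formula itself, one option is a safe, bounds-checked variant (a hypothetical test helper, not a replacement: it brings back the zero-fill and bounds checks) compared against the two-pass version. With a wrong `set_count` it gives a wrong result or panics, but never UB.

```rust
/// Safe variant of the branchless scatter, for checking the index arithmetic.
fn split_indices_branchless_checked(is_set: &[bool], set_count: usize) -> Vec<usize> {
    let unset_count = (is_set.len() - set_count) as isize;
    let mut valids: isize = 0;
    let mut indices = vec![0usize; is_set.len()];
    for (index, &b) in is_set.iter().enumerate() {
        let set = b as isize;
        let unset = (!b) as isize;
        // v = set ? valids : nulls, with nulls = index - valids
        let v = valids - unset * (2 * valids - index as isize);
        // set entries land after the unset block
        let j = v + set * unset_count;
        indices[j as usize] = index; // bounds-checked write
        valids += set;
    }
    indices
}

/// The two-pass reference version from earlier in the thread.
fn split_indices_two_pass(is_set: &[bool]) -> Vec<usize> {
    let mut out = Vec::with_capacity(is_set.len());
    out.extend(is_set.iter().enumerate().filter(|&(_, &b)| !b).map(|(i, _)| i));
    out.extend(is_set.iter().enumerate().filter(|&(_, &b)| b).map(|(i, _)| i));
    out
}
```

For `[true, false, false, true, true, false]` with `set_count = 3`, both produce `[1, 2, 5, 0, 3, 4]`.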
