Rust equivalent of numpy.where()

What is the Rust equivalent of numpy.where in the following code? I want to find the indices for 4278.0 < Xall <= 4286.0. Here Xall is a NumPy array, but in my Rust code I am using a vector instead.
I tried writing the logic by hand with a for loop, but it is not fast enough. So does anybody know what the Rust equivalent of numpy.where in the following code is?

left=numpy.where(np.logical_and(Xall>=4278.0, Xall<=4286.0))

I am not fluent enough in numpy to know what numpy.where exactly returns, but what about something like:

vec.iter().filter(|x| x>= 100 && x<=120).collect::<Vec<_>>()

?

Also, what is the performance of the Python code, and how did you measure it? Did you run 'cargo run' or 'cargo run --release' (the former only runs a debug build of Rust, which is usually orders of magnitude slower than the release version)? And how did you time the Rust code?


I assume you’re working with 1D data if you say you’re using a vector. It might help if you also shared your previous attempt and how you can tell that it’s too slow. It might also help if you link to some documentation of the numpy method in question yourself and provide a type signature of what you’d like to do in Rust. Anyway, looking through the docs

When only condition is provided, this function is a shorthand for np.asarray(condition).nonzero().

numpy.nonzero(a)

Return the indices of the elements that are non-zero.

to fully understand what you’re after, it seems like this should be straightforward with iterators. E.g.

fn indices_between_4278_and_4286(x: &[i32]) -> Vec<usize> {
    x.iter()
        .enumerate()
        .filter_map(|(index, &value)| (4278 <= value && value <= 4286).then(|| index))
        .collect()
}

Not exactly what numpy does, but if OP doesn't need a vector we could also drop the collect and return an iterator :slight_smile:


Same idea as steffahn's code, but generalized to iterators/iterables (because I thought it was fun):

fn indices_where<T>(input: impl IntoIterator<Item = T>, mut where_: impl FnMut(T) -> bool)
    -> impl Iterator<Item = usize>
{
    input.into_iter().enumerate()
        .filter_map(move |(index, elt)| where_(elt).then(|| index))
}

fn main() {
    let data = vec![1000, 4282, 4293, 4281, 4271];
   
    for x in indices_where(&data, |&value| 4278 <= value && value <= 4286) {
        println!("{}", x);
    }
}

Thanks. This is the kind of approach I am looking for. I replaced the integers with floats, and now, at the position of the numbers in the expressions, it says error[E0308]: mismatched types: expected &&f64, found floating-point number.

The filter method borrows whatever item type the iterator yields, so if the iterator is already over &f64, filter will expect a function that takes an &&f64.
You can either dereference the parameter (using **x <= ...) or reference the constant twice (&&100.0).
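
For example, a minimal sketch of both fixes (the data here is hypothetical):

fn main() {
    let xall: Vec<f64> = vec![1000.0, 4282.0, 4293.0, 4281.0]; // hypothetical data

    // Option 1: dereference the &&f64 closure parameter twice.
    let a: Vec<&f64> = xall.iter().filter(|x| **x >= 4278.0 && **x <= 4286.0).collect();

    // Option 2: reference the constants twice instead.
    let b: Vec<&f64> = xall.iter().filter(|x| x >= &&4278.0 && x <= &&4286.0).collect();

    assert_eq!(a, b); // both contain references to 4282.0 and 4281.0
}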

It creates an object of type Vec<&f64>, but I want to get a Vec<f64>. I am using cargo run --release.

You can call cloned on the iterator to get from &f64 to f64.
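
For example (again with hypothetical data):

// cloned() (or copied(), which is equivalent for f64) turns the &f64 items
// back into plain f64, so collect() produces a Vec<f64>.
let xall: Vec<f64> = vec![1000.0, 4282.0, 4293.0, 4281.0];
let owned: Vec<f64> = xall
    .iter()
    .filter(|x| **x >= 4278.0 && **x <= 4286.0)
    .cloned()
    .collect();
assert_eq!(owned, vec![4282.0, 4281.0]);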

I am doing this for a large amount of data in parallel, so I am already getting very high swap usage and low CPU usage; it is an I/O bottleneck. So I am trying to avoid unnecessarily cloning data. Is it possible to just get the boundary indices so I can reference the data by slicing the vector, like below (but faster, using something like numpy.where() in Python)?

for j in 1..Xall.len() {
    if Xall[j] >= 4278.0 && Xall[j - 1] < 4278.0 {
        Xlow_left = Xall[j];
        Xlow_left_index = j;
    }
    if Xall[j] > 4286.0 && Xall[j - 1] <= 4286.0 {
        Xhigh_left = Xall[j];
        Xhigh_left_index = j;
    }
}
let Xdata_left = &Xall[Xlow_left_index..=Xhigh_left_index];

Otherwise, can you write the code for this suggestion for me?

@exoplanet_hunter It looks like you can save some time by breaking out of the loop once you find Xhigh_left_index. It also looks like a search in sorted data, is it? In that case it might be worth trying a binary search, as long as the vector is long enough (I'd only bother with it if the Xall sequence were longer than about 100 elements).

It is longer than 100 elements and sorted. I will try the binary search.

I'd possibly use partition_point instead of the usual binary search method (both do similar jobs) - it makes it easier to both express the predicate and work with floats in that case. It looks like you want the partition point of elt <= 4286.0 for the high bound?
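
For example, a minimal sketch with slice::partition_point, assuming Xall is sorted ascending (the data and bounds here are made up):

fn range_bounds(xall: &[f64], low: f64, high: f64) -> (usize, usize) {
    let start = xall.partition_point(|&x| x < low);  // first index with xall[i] >= low
    let end = xall.partition_point(|&x| x <= high);  // first index with xall[i] > high
    (start, end)
}

fn main() {
    let xall = [1000.0, 4270.0, 4279.5, 4283.0, 4286.0, 4290.0];
    let (start, end) = range_bounds(&xall, 4278.0, 4286.0);
    let xdata_left = &xall[start..end]; // all values in [4278.0, 4286.0], no copies
    println!("{:?}", xdata_left);       // [4279.5, 4283.0, 4286.0]
}

Both calls are O(log n) on sorted data, and the result is a slice into the original vector, so nothing gets cloned.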

fwiw, the equivalent of the formula you used in rust would be something like

fn lt_eq(a: &[f64], b: f64) -> Vec<bool>;
fn gt_eq(a: &[f64], b: f64) -> Vec<bool>;
fn and(a: &[bool], b: &[bool]) -> Vec<bool>;
fn where_(a: &[f64], b: &[bool]) -> Vec<f64>;

let a: Vec<f64> = ...;

let result = where_(&a, &and(&lt_eq(&a, 4286.0f64), &gt_eq(&a, 4278.0f64)));

i.e. numpy allocates a new array for every operation (each operation returns a fresh array rather than modifying one in place, unless you use the out parameter).
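
Purely for illustration, those stubs could be filled in along these lines (a sketch that makes each intermediate allocation explicit):

fn lt_eq(a: &[f64], b: f64) -> Vec<bool> {
    a.iter().map(|&x| x <= b).collect()
}

fn gt_eq(a: &[f64], b: f64) -> Vec<bool> {
    a.iter().map(|&x| x >= b).collect()
}

fn and(a: &[bool], b: &[bool]) -> Vec<bool> {
    a.iter().zip(b).map(|(&x, &y)| x && y).collect()
}

fn where_(a: &[f64], b: &[bool]) -> Vec<f64> {
    a.iter().zip(b).filter_map(|(&x, &m)| m.then(|| x)).collect()
}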

A semantically equivalent Rust implementation is

use std::time::Instant;

fn where_(lhs: &[f64], min: f64, max: f64) -> Vec<f64> {
    lhs.iter()
        .copied()
        .filter(|x| *x >= min && *x <= max)
        .collect()
}

fn main() {
    let size = 2usize.pow(20);
    let lhs = (0..size).map(|x| x as f64).collect::<Vec<_>>();

    let now = Instant::now();
    let result = where_(&lhs, 4278.0f64, 4286.0f64);
    println!("{} us", now.elapsed().as_micros());
    println!("{:?}", result);
}

on my computer with --release that is ~3x faster than numpy's equivalent semantics (1519 us vs 4644.87 us)

import time

import numpy as np

Xall = np.array(list(range(2 ** 20)))

start = time.time()
data = np.where(np.logical_and(Xall >= 4278.0, Xall <= 4286.0))
print(f"{(time.time() - start) * 1000 * 1000} us")

You could use par_iter from rayon to multi-thread the operation. What do you mean by I/O-bound? The expression you posted is pure CPU.
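
For example, a rough sketch assuming the rayon crate is added to Cargo.toml:

use rayon::prelude::*;

fn where_par(xall: &[f64], min: f64, max: f64) -> Vec<f64> {
    xall.par_iter()                         // split the slice across rayon's thread pool
        .copied()                           // &f64 -> f64
        .filter(|&x| x >= min && x <= max)  // same predicate as the sequential version
        .collect()                          // collect the matches into a Vec
}

Whether that actually helps depends on the data size and on where the time really goes: if the job is genuinely I/O-bound, adding threads will not make it faster.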

This topic was automatically closed 90 days after the last reply. We invite you to open a new topic if you have further questions or comments.