Rust equivalent of numpy.where()

What is the Rust equivalent of numpy.where in the following code? I want to find the indices where 4278.0 <= Xall <= 4286.0. Here Xall is a NumPy array, but in my Rust code I am using a Vec instead.
I tried writing the logic by hand with a for loop, but it is not fast enough. Does anybody know the Rust equivalent of numpy.where for this code?

left = np.where(np.logical_and(Xall >= 4278.0, Xall <= 4286.0))

I'm not fluent enough in NumPy to know exactly what numpy.where returns, but what about something like:

vec.iter().filter(|x| x>= 100 && x<=120).collect::<Vec<_>>()

?

Also, how fast is the Python code, and how did you measure it? Did you run 'cargo run' or 'cargo run --release'? The former only builds and runs an unoptimized debug build, which can be an order of magnitude or more slower than the release build. And how did you time the Rust code?
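For reference, one way to time just the search on the Rust side is std::time::Instant, with the program built via 'cargo run --release'. This is only a minimal sketch: the helper indices_between and the generated test data are hypothetical, made up for illustration, and a single measurement like this is rough (a benchmarking crate gives more reliable numbers).

```rust
use std::time::Instant;

/// Hypothetical helper: indices i with lo <= data[i] <= hi.
fn indices_between(data: &[f64], lo: f64, hi: f64) -> Vec<usize> {
    data.iter()
        .enumerate()
        .filter(|&(_, &v)| lo <= v && v <= hi)
        .map(|(i, _)| i)
        .collect()
}

fn main() {
    // Made-up test data: one million evenly spaced values from 4000.0 toward 5000.0.
    let n = 1_000_000;
    let data: Vec<f64> = (0..n).map(|i| 4000.0 + 1000.0 * i as f64 / n as f64).collect();

    // Time only the search, not the data generation.
    let start = Instant::now();
    let left = indices_between(&data, 4278.0, 4286.0);
    let elapsed = start.elapsed();

    // Use the result so the optimizer cannot discard the work.
    println!("found {} indices in {:?}", left.len(), elapsed);
}
```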


Since you say you’re using a vector, I assume you’re working with 1-D data. It might help if you also shared your previous attempt and explained how you determined that it’s too slow. It might also help to link to the documentation of the NumPy function in question and to give a type signature for what you’d like to do in Rust. Anyway, looking through the docs

When only condition is provided, this function is a shorthand for np.asarray(condition).nonzero().

numpy.nonzero(a)

Return the indices of the elements that are non-zero.

to fully understand what you’re after, this should be straightforward with iterators. For example:

fn indices_between_4278_and_4286(x: &[i32]) -> Vec<usize> {
    x.iter()
        .enumerate()
        .filter_map(|(index, &value)| (4278 <= value && value <= 4286).then(|| index))
        .collect()
}
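Following the docs quoted above literally, np.where(condition) is np.asarray(condition).nonzero(): first build a boolean mask, then collect the indices of its true entries. (NumPy returns a tuple with one index array per dimension; for 1-D data that is a single array, which a Vec<usize> stands in for here.) A sketch of that two-step version, assuming f64 data; between_mask and nonzero are hypothetical names chosen to mirror the NumPy calls. It allocates an intermediate Vec<bool>, which the single-pass filter_map version avoids.

```rust
/// Hypothetical mirror of np.logical_and(x >= lo, x <= hi): one bool per element.
fn between_mask(x: &[f64], lo: f64, hi: f64) -> Vec<bool> {
    x.iter().map(|&v| v >= lo && v <= hi).collect()
}

/// Hypothetical mirror of np.nonzero on a 1-D boolean array:
/// the indices of the true entries.
fn nonzero(mask: &[bool]) -> Vec<usize> {
    mask.iter()
        .enumerate()
        .filter_map(|(i, &m)| m.then_some(i))
        .collect()
}

fn main() {
    let xall = vec![4270.0, 4278.0, 4281.5, 4286.0, 4290.0];
    let left = nonzero(&between_mask(&xall, 4278.0, 4286.0));
    println!("{:?}", left); // prints [1, 2, 3]
}
```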

Not exactly what NumPy does, but if OP doesn't need a vector, we could also drop the collect() and return an iterator instead.
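A sketch of that iterator-returning variant, assuming f64 data and a hypothetical name indices_between_lazy. The + '_ ties the returned iterator to the borrowed slice, and move copies lo and hi into the closure. Nothing is allocated unless the caller collects.

```rust
/// Hypothetical helper: lazily yields indices i where lo <= x[i] <= hi.
fn indices_between_lazy(x: &[f64], lo: f64, hi: f64) -> impl Iterator<Item = usize> + '_ {
    x.iter()
        .enumerate()
        .filter_map(move |(i, &v)| (lo <= v && v <= hi).then_some(i))
}

fn main() {
    let xall = vec![4277.9, 4278.0, 4300.0, 4285.5];

    // Consume the indices directly...
    for i in indices_between_lazy(&xall, 4278.0, 4286.0) {
        println!("{}", i);
    }

    // ...or just count them without building a Vec.
    let n = indices_between_lazy(&xall, 4278.0, 4286.0).count();
    println!("{} matches", n);
}
```

If a Vec is needed after all, appending .collect::<Vec<_>>() at the call site recovers the previous behaviour.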


Same as steffahn's code, but generalized to work with any iterator/iterable (because I thought it was fun):

fn indices_where<T>(input: impl IntoIterator<Item = T>, mut where_: impl FnMut(T) -> bool)
    -> impl Iterator<Item = usize>
{
    input.into_iter().enumerate()
        .filter_map(move |(index, elt)| where_(elt).then(|| index))
}

fn main() {
    let data = vec![1000, 4282, 4293, 4281, 4271];
   
    for x in indices_where(&data, |&value| 4278 <= value && value <= 4286) {
        println!("{}", x);
    }
}
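With floats, one option for the 4278.0 <= value <= 4286.0 test is an inclusive range and RangeInclusive::contains, which keeps both bounds in one place. A sketch reusing indices_where as posted in this thread; the data values are made up, and the closure parameter is annotated as &f64 because iterating over &data yields references.

```rust
/// indices_where as posted in this thread: indices of the items
/// for which `where_` returns true.
fn indices_where<T>(
    input: impl IntoIterator<Item = T>,
    mut where_: impl FnMut(T) -> bool,
) -> impl Iterator<Item = usize> {
    input
        .into_iter()
        .enumerate()
        .filter_map(move |(index, elt)| where_(elt).then(|| index))
}

fn main() {
    let data: Vec<f64> = vec![1000.0, 4282.0, 4293.0, 4281.0, 4271.0];

    // The inclusive range 4278.0..=4286.0 means 4278.0 <= value <= 4286.0.
    let wanted = 4278.0_f64..=4286.0;

    for x in indices_where(&data, |value: &f64| wanted.contains(value)) {
        println!("{}", x); // prints 1, then 3
    }
}
```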

Thanks, this is the kind of approach I was looking for. When I replaced the integers with floats, the compiler reported error[E0308]: mismatched types (expected &&f64, found floating-point number) at each numeric literal in the comparisons.

The filter method passes each item to its closure by reference, so if the iterator already yields &f64, the closure receives a &&f64.
You can either dereference the parameter (e.g. **x <= ...) or reference the constant twice (&&100.0).
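Both fixes side by side, plus a third common option (destructuring the references in the closure's parameter pattern). A sketch with made-up data; note that filter like this collects the matching values, not their indices as np.where does, so for indices you would still combine it with enumerate as in the earlier answers.

```rust
fn main() {
    let xall: Vec<f64> = vec![4270.0, 4280.5, 4286.0, 4299.0];

    // xall.iter() yields &f64 and filter passes each item to the closure
    // by reference, so `x` is a &&f64 in options 1 and 2.

    // Option 1: dereference the parameter twice.
    let a: Vec<&f64> = xall.iter().filter(|x| **x >= 4278.0 && **x <= 4286.0).collect();

    // Option 2: reference the constants twice so both sides are &&f64.
    let b: Vec<&f64> = xall.iter().filter(|x| x >= &&4278.0 && x <= &&4286.0).collect();

    // Option 3: strip both references in the pattern, so `x` is a plain f64.
    let c: Vec<&f64> = xall.iter().filter(|&&x| 4278.0 <= x && x <= 4286.0).collect();

    println!("{:?} {:?} {:?}", a, b, c); // prints [4280.5, 4286.0] three times
}
```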