Iterator with branches?

To avoid the XY problem, here is my actual problem:

Currently, I'm working on a Rust plugin for R. One of the functions looks like this:

#[extendr]
fn batch_likelihood_m2i(m: &[i32]) -> Vec<i32> {
    let mut idx_counter = 0;
    m.iter()
        .flat_map(|&m| {
            let curr = idx_counter;
            idx_counter += 1;
            if m>1 {
                (1..m).flat_map(move |j| (0..j).map(move |_k| curr))
            } else {
                core::iter::once(curr) // fails to compile: this branch has a different type than the `if` branch
            }
        })
        .collect::<Vec<_>>()
}

But it fails to compile because the two branches return different iterator types.

Is there a technique for returning a different iterator from each branch inside a single iterator chain?

I know some dirty tricks, for example .skip(if condition_met {0} else {first_len}).take(if condition_met {first_len} else {second_len}), but is there a cleaner method?

In the general case you can use Box<dyn Iterator<Item = I>>.

            if m>1 {
                Box::new((1..m).flat_map(move |j| (0..j).map(move |_k| curr)))
                    as Box<dyn Iterator<Item = _>>
            } else {
                Box::new(core::iter::once(curr)) as _
            }
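
For completeness, here is a sketch of how that fragment could slot into the whole function. The name m2i_boxed is made up for this example and the #[extendr] attribute is left out so it builds without the R bindings; otherwise it follows the original code.

```rust
/// Hypothetical stand-alone variant of the original function, using boxed
/// trait objects so both branches have the same type.
pub fn m2i_boxed(m: &[i32]) -> Vec<i32> {
    let mut idx_counter = 0;
    m.iter()
        .flat_map(|&m| {
            let curr = idx_counter;
            idx_counter += 1;
            if m > 1 {
                // One `curr` for every pair (j, k) with 0 <= k < j < m.
                Box::new((1..m).flat_map(move |j| (0..j).map(move |_k| curr)))
                    as Box<dyn Iterator<Item = i32>>
            } else {
                // A single `curr`, coerced to the same boxed type.
                Box::new(core::iter::once(curr)) as _
            }
        })
        .collect()
}
```

The cost is one heap allocation and dynamic dispatch per input element, which is usually negligible next to the work of filling the Vec.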

For this particular example you could use

            // Could take more care wrt overflows and conversions here...
            let count = if m > 1 { m * (m - 1) / 2 } else { 1 } as usize;
            core::iter::repeat(curr).take(count)
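
The count comes from the nested ranges: for each j in 1..m the inner range 0..j yields j items, so the total is 1 + 2 + ... + (m - 1) = m * (m - 1) / 2. A small sketch that checks the closed form against the iteration follows. The function names are invented for this example, and widening to i64 is just one possible way to address the overflow remark in the comment above.

```rust
/// Closed-form number of items emitted for one entry `m`.
pub fn item_count(m: i32) -> usize {
    if m > 1 {
        // Widen before multiplying so that a large `m` cannot overflow i32.
        let m = i64::from(m);
        usize::try_from(m * (m - 1) / 2).expect("count does not fit in usize")
    } else {
        // The `else` branch of the original code emits exactly one item.
        1
    }
}

/// The same count, obtained by actually walking the nested ranges.
pub fn item_count_by_iteration(m: i32) -> usize {
    if m > 1 {
        (1..m).flat_map(|j| 0..j).count()
    } else {
        1
    }
}
```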

And another option is to use something like the Either type (from the either crate).

use either::Either::{Left, Right};

fn batch_likelihood_m2i(m: &[i32]) -> Vec<i32> {
    let mut idx_counter = 0;
    m.iter()
        .flat_map(|&m| {
            let curr = idx_counter;
            idx_counter += 1;
            if m>1 {
                Left((1..m).flat_map(move |j| (0..j).map(move |_k| curr)))
            } else {
                Right(core::iter::once(curr))
            }
        })
        .collect::<Vec<_>>()
}

Also, since you are mutating state, I wouldn't do that in an iterator chain; it's weird (readers generally expect iterator chains to be pure). I'd do this instead:


use std::iter;

pub fn batch_likelihood_m2i(m: &[i32]) -> Vec<i32> {
    let mut result = Vec::new();

    for (&m, curr) in iter::zip(m, 0..) {
        if m > 1 {
            result.extend((1..m).flat_map(|j| iter::repeat(curr).take(j as _)));
        } else {
            result.push(curr);
        }
    }

    result
}

I wrote it that way because the real function can be more complex:

use either::Either::{Left, Right};

#[extendr]
fn batch_likelihood_m2corr(corr: &[f64], m: &[i32]) -> Vec<f64> {
    let mut idx_counter = 0;
    m.iter()
        .flat_map(|&m| {
            let curr = idx_counter;
            idx_counter += m * m;
            if m>1 {
                Left((1..m).flat_map(move |j| (0..j).map(move |k| corr[(curr + j * m + k) as usize])))
            } else {
                Right(core::iter::once(1.))
            }
        })
        .collect::<Vec<_>>()
}

The inputs are the flattened correlation matrices and the dimension of each correlation matrix.

There might be no unified solution short of writing an ugly scan.
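
For what it's worth, such a scan might look roughly like the sketch below. It is only an illustration: the names lower_triangle and m2corr_scan are invented, it uses Box<dyn Iterator> instead of Either so that it needs only the standard library, and it assumes every dimension is non-negative.

```rust
/// Hypothetical helper: the strictly lower triangle of one m-by-m matrix that
/// starts at offset `curr` in `corr`, or a single 1.0 when m <= 1.
fn lower_triangle<'a>(
    corr: &'a [f64],
    curr: usize,
    m: usize,
) -> Box<dyn Iterator<Item = f64> + 'a> {
    if m > 1 {
        Box::new((1..m).flat_map(move |j| (0..j).map(move |k| corr[curr + j * m + k])))
    } else {
        Box::new(std::iter::once(1.0))
    }
}

/// `scan` carries the running offset, so no variable outside the chain is mutated.
pub fn m2corr_scan(corr: &[f64], m: &[i32]) -> Vec<f64> {
    m.iter()
        .scan(0usize, |offset, &m| {
            // Assumption: dimensions are never negative.
            let m = usize::try_from(m).expect("dimension must be non-negative");
            let curr = *offset;
            *offset += m * m;
            Some((curr, m))
        })
        .flat_map(|(curr, m)| lower_triangle(corr, curr, m))
        .collect()
}
```

With dimensions [2, 1, 3] the matrices start at offsets 0, 4 and 5, so the result picks corr[2], then the constant 1.0, then corr[8], corr[11] and corr[12].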

Then use enumerate() instead of a mutable variable.
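
A sketch of what that could look like for the first function, combined with the repeat/take idea from earlier in the thread. The name m2i_enumerate is made up, and the casts assume the slice length fits in i32 and the count fits in usize. Note that this only covers the case where the counter advances by one per element; the m * m offset in the second function is not a plain index.

```rust
use std::iter;

/// Hypothetical variant that gets the index from `enumerate()` instead of a
/// captured mutable counter.
pub fn m2i_enumerate(m: &[i32]) -> Vec<i32> {
    m.iter()
        .enumerate()
        .flat_map(|(idx, &m)| {
            // Number of pairs (j, k) with 0 <= k < j < m, or 1 when m <= 1.
            let count = if m > 1 {
                (m as usize) * (m as usize - 1) / 2
            } else {
                1
            };
            // Both cases are the same iterator type, so no branching on types is needed.
            iter::repeat(idx as i32).take(count)
        })
        .collect()
}
```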
