How to wrap a non-object-safe trait in an object-safe one?

I originally encountered this issue when attempting to create a Rayon ParallelIterator to which a flat_map operation is applied a dynamic number of times (see this post on the Rayon issue tracker).

For standard sequential iterators, the following works:

fn fancy<Input, Output, Item>(input: Input, depth: usize) -> Output
where
    Input : IntoIterator<Item = Item>,
    Output : std::iter::FromIterator<Item>,
    Item : Copy + std::ops::Add<Output = Item>,
{
    let mut iter: Box<dyn Iterator<Item = Item>> = Box::new(input.into_iter());
    for _ in 0..depth {
        iter = Box::new(iter.flat_map(|x| vec![x, x + x].into_iter()));
    }
    iter.collect()
}

When working with Rayon's ParallelIterators, however, things get iffy: ParallelIterator has Sized (and Send) as supertraits, and most of its methods are generic, so the trait is not object-safe and cannot be used as a trait object. Thus the following direct translation does not compile:

use rayon::prelude::*;
fn fancy2<Input, Output, Item>(input: Input, depth: usize) -> Output
where
    Input : IntoParallelIterator<Item = Item>,
    Output : FromParallelIterator<Item>,
    Item : Copy + std::ops::Add<Output = Item> + Send,
{
    let mut iter: Box<dyn rayon::iter::ParallelIterator<Item = Item>> = Box::new(input.into_par_iter());
    for _ in 0..depth {
        iter = Box::new(iter.flat_map(|x| vec![x, x + x].into_par_iter()));
    }
    iter.collect()
}
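For contrast, the sequential version compiles because std's Iterator keeps its generic adapters (map, flat_map, ...) out of the vtable by declaring each of them with where Self: Sized, so dyn Iterator stays object-safe. A minimal std-only sketch of that pattern follows; the trait Expand and the types Pair and Single are hypothetical and have nothing to do with rayon:

```rust
/// Hypothetical trait: object-safe because its generic method is
/// excluded from the vtable with `where Self: Sized`, the same trick
/// `std::iter::Iterator` uses for `map`, `flat_map`, etc.
trait Expand {
    /// Non-generic and takes `&self`, so it can be called through `dyn Expand`.
    fn values(&self) -> Vec<u32>;

    /// Generic method: only callable on sized (concrete) types,
    /// so it does not prevent `dyn Expand` from existing.
    fn map_values<F>(&self, f: F) -> Vec<u32>
    where
        Self: Sized,
        F: Fn(u32) -> u32,
    {
        self.values().into_iter().map(f).collect()
    }
}

struct Pair(u32, u32);

impl Expand for Pair {
    fn values(&self) -> Vec<u32> {
        vec![self.0, self.1]
    }
}

struct Single(u32);

impl Expand for Single {
    fn values(&self) -> Vec<u32> {
        vec![self.0]
    }
}

fn main() {
    // Different concrete types behind one trait-object type.
    let items: Vec<Box<dyn Expand>> = vec![Box::new(Pair(1, 2)), Box::new(Single(3))];
    let all: Vec<u32> = items.iter().flat_map(|e| e.values()).collect();
    println!("{:?}", all); // [1, 2, 3]

    // The generic method still works on a concrete type.
    println!("{:?}", Pair(1, 2).map_values(|x| x * 10)); // [10, 20]
}
```

Declaring the trait as trait Expand: Sized instead, which is what ParallelIterator does, would turn Box<dyn Expand> into a compile error.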

My question is: How can we wrap a non-object-safe trait like ParallelIterator in a new trait that is object-safe?

I don't think it's generally possible.

If the possible types are known at compile time, you can use an enum wrapper. Rayon already implements its traits for the Either type, with Left and Right variants. If you need more than two, you can create your own enum for this without much trouble.
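A std-only sketch of the enum-wrapper idea. Source, Repeat, Span, OneOf and pick are hypothetical names: Source stands in for a non-object-safe trait like ParallelIterator (it has a Sized supertrait and a generic method), and OneOf plays the role that Either plays in rayon:

```rust
/// Hypothetical stand-in for a non-object-safe trait: the `Sized`
/// supertrait and the generic method both rule out `dyn Source`.
trait Source: Sized {
    fn fold_all<B, F>(self, init: B, f: F) -> B
    where
        F: FnMut(B, u32) -> B;
}

struct Repeat {
    value: u32,
    count: usize,
}

impl Source for Repeat {
    fn fold_all<B, F>(self, init: B, f: F) -> B
    where
        F: FnMut(B, u32) -> B,
    {
        std::iter::repeat(self.value).take(self.count).fold(init, f)
    }
}

struct Span(std::ops::Range<u32>);

impl Source for Span {
    fn fold_all<B, F>(self, init: B, f: F) -> B
    where
        F: FnMut(B, u32) -> B,
    {
        self.0.fold(init, f)
    }
}

/// Enum wrapper: one concrete type that can hold either implementor.
enum OneOf<L, R> {
    Left(L),
    Right(R),
}

impl<L: Source, R: Source> Source for OneOf<L, R> {
    fn fold_all<B, F>(self, init: B, f: F) -> B
    where
        F: FnMut(B, u32) -> B,
    {
        // Static dispatch through a match instead of a vtable.
        match self {
            OneOf::Left(l) => l.fold_all(init, f),
            OneOf::Right(r) => r.fold_all(init, f),
        }
    }
}

/// Chooses an implementation at run time but always returns one type.
fn pick(use_span: bool) -> OneOf<Repeat, Span> {
    if use_span {
        OneOf::Right(Span(1..4))
    } else {
        OneOf::Left(Repeat { value: 7, count: 2 })
    }
}

fn main() {
    let sum = pick(true).fold_all(0u32, |acc, x| acc + x);
    let items = pick(false).fold_all(Vec::<u32>::new(), |mut v, x| {
        v.push(x);
        v
    });
    println!("{} {:?}", sum, items); // 6 [7, 7]
}
```

pick decides at run time yet returns a single concrete type, so no trait object is needed; supporting a third type is one more variant and one more match arm.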

That won't work for your 0..depth dynamic nesting though. You might be able to implement a custom ParallelIterator type that approximates this, with an UnindexedProducer that splits your x and x + x parts. There might be other ways to refactor the code to do what you want too.
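A rough, sequential sketch of that splitting idea, without rayon. Node, split, fold_into and expand are hypothetical names; split only mirrors the shape of UnindexedProducer::split, which returns (Self, Option<Self>), and fold_into walks the whole split tree on one thread, whereas rayon would decide when to split and run the halves on different threads:

```rust
/// Hypothetical producer-like value: `value` still has `depth`
/// doubling steps left to apply.
#[derive(Clone, Copy, Debug)]
struct Node {
    value: u64,
    depth: u32,
}

impl Node {
    /// Same shape as `UnindexedProducer::split`: the left part and,
    /// if splitting was possible, the right part.
    fn split(self) -> (Node, Option<Node>) {
        if self.depth == 0 {
            (self, None)
        } else {
            let d = self.depth - 1;
            (
                Node { value: self.value, depth: d },                // the x part
                Some(Node { value: self.value + self.value, depth: d }), // the x + x part
            )
        }
    }

    /// Visits the split tree left to right on the current thread,
    /// pushing each leaf value into `out`.
    fn fold_into(self, out: &mut Vec<u64>) {
        match self.split() {
            (leaf, None) => out.push(leaf.value),
            (left, Some(right)) => {
                left.fold_into(out);
                right.fold_into(out);
            }
        }
    }
}

fn expand(inputs: &[u64], depth: u32) -> Vec<u64> {
    let mut out = Vec::new();
    for &x in inputs {
        Node { value: x, depth }.fold_into(&mut out);
    }
    out
}

fn main() {
    println!("{:?}", expand(&[1, 10], 2)); // [1, 2, 2, 4, 10, 20, 20, 40]
}
```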

Is there a way to make a wrapper struct/enum that can be wrapped in a Box<dyn ...>, or will it effectively get 'tainted' by the Sized requirement of its field that is an instance of ParallelIterator?

This produces the same result:

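use rayon::prelude::*;
// Fills `out` (length 2^depth) with the leaves of the x / x + x doubling tree.
fn make_vec<I>(x: I, out: &mut [I])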
where
    I: Copy + std::ops::Add<Output = I>,
{
    if out.len() == 1 {
        out[0] = x;
    } else {
        let (left, right) = out.split_at_mut(out.len() / 2);
        make_vec(x, left);
        make_vec(x + x, right)
    }
}
fn fancy2<Input, Output, Item>(input: Input, depth: usize) -> Output
where
    Input : IntoParallelIterator<Item = Item>,
    Output : FromParallelIterator<Item>,
    Item : Copy + std::ops::Add<Output = Item> + Send,
{
    input.into_par_iter()
        .flat_map(|x| {
            let mut vec = vec![x; 1 << depth];
            make_vec(x, &mut vec);
            vec
        })
        .collect()
}

except that each element's expansion runs on a single thread, so it isn't as parallel as you might want.
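To check the equivalence using only the standard library, the sketch below compares make_vec (copied from above) with the sequential nested flat_map from the question, for a single input element; nested is a hypothetical helper:

```rust
use std::ops::Add;

// make_vec as posted above: fills `out` (length 2^depth) with the
// leaves of the x / x + x doubling tree, left to right.
fn make_vec<I>(x: I, out: &mut [I])
where
    I: Copy + Add<Output = I>,
{
    if out.len() == 1 {
        out[0] = x;
    } else {
        let (left, right) = out.split_at_mut(out.len() / 2);
        make_vec(x, left);
        make_vec(x + x, right)
    }
}

// Sequential reference: apply the flat_map `depth` times to one element.
fn nested(x: u64, depth: usize) -> Vec<u64> {
    let mut v = vec![x];
    for _ in 0..depth {
        v = v.into_iter().flat_map(|y| vec![y, y + y]).collect();
    }
    v
}

fn main() {
    for depth in 0..6 {
        let mut buf = vec![5u64; 1 << depth];
        make_vec(5, &mut buf);
        assert_eq!(buf, nested(5, depth));
    }
    println!("make_vec matches the nested flat_map for depths 0..6");
}
```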

You can also create a struct and implement ParallelIterator for it manually; for an example, see this. This will be completely parallel, though I'm not sure it'd be faster.

You could use rayon::join on the pair of recursive make_vec calls to squeeze out more parallelism. I like your custom ParallelIterator too.
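A sketch of that idea that avoids depending on rayon: std::thread::scope (Rust 1.63+) stands in for rayon::join, and SEQUENTIAL_CUTOFF is an assumed threshold. Spawning an OS thread per split would be far too costly, so below the cutoff it falls back to the sequential make_vec; with rayon, the two recursive calls would instead be passed as two closures to rayon::join, whose work-stealing tasks are much cheaper than OS threads:

```rust
use std::ops::Add;
use std::thread;

// Assumed threshold: slices this short are filled sequentially.
const SEQUENTIAL_CUTOFF: usize = 1 << 10;

// Sequential make_vec, as posted above.
fn make_vec<I>(x: I, out: &mut [I])
where
    I: Copy + Add<Output = I>,
{
    if out.len() == 1 {
        out[0] = x;
    } else {
        let (left, right) = out.split_at_mut(out.len() / 2);
        make_vec(x, left);
        make_vec(x + x, right);
    }
}

// Parallel variant: fills the two halves concurrently above the cutoff.
fn make_vec_par<I>(x: I, out: &mut [I])
where
    I: Copy + Add<Output = I> + Send,
{
    if out.len() <= SEQUENTIAL_CUTOFF {
        make_vec(x, out);
        return;
    }
    let (left, right) = out.split_at_mut(out.len() / 2);
    // The right half runs on a scoped thread while this thread fills the left;
    // the scope joins the spawned thread before returning.
    thread::scope(|s| {
        s.spawn(move || make_vec_par(x + x, right));
        make_vec_par(x, left);
    });
}

fn main() {
    let len = 1usize << 14;
    let mut par = vec![0u64; len];
    make_vec_par(1, &mut par);
    let mut seq = vec![0u64; len];
    make_vec(1, &mut seq);
    assert_eq!(par, seq);
    println!("{} values, parallel and sequential results agree", par.len());
}
```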

However, I hesitate to focus on this particular example. I bet the CPU time will be dominated by allocation and data movement, rather than computation, which isn't a great use of parallelism in the first place. More realistic scenarios might be helped by seeing your solutions, but they'll probably need something custom.

@alice Thank you! That solution, where you create a new concrete struct that implements ParallelIterator (since it is a single type, no dynamic dispatch is needed), seems very clean. Hereby marking it as the solution.

@cuviper The real application I am working on does L-system expansion, which means that the vector returned by the flat_map closure will usually be much longer.
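For readers unfamiliar with L-systems, a minimal sequential sketch of expansion as one flat_map per generation. The rules used here are Lindenmayer's classic algae system (A → AB, B → A), chosen only as an illustration since the real application's rules aren't given; rule and expand are hypothetical names:

```rust
/// Lindenmayer's algae rules, used only as an example rule set.
fn rule(c: char) -> Vec<char> {
    match c {
        'A' => vec!['A', 'B'],
        'B' => vec!['A'],
        other => vec![other], // symbols without a rule are copied unchanged
    }
}

/// Applies the rules `depth` times, one flat_map per generation.
fn expand(axiom: &str, depth: usize) -> String {
    let mut symbols: Vec<char> = axiom.chars().collect();
    for _ in 0..depth {
        symbols = symbols.into_iter().flat_map(rule).collect();
    }
    symbols.into_iter().collect()
}

fn main() {
    for n in 0..5 {
        // A, AB, ABA, ABAAB, ABAABABA
        println!("{}: {}", n, expand("A", n));
    }
}
```

With these rules the string length follows the Fibonacci numbers, so each generation is larger than the last.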

As for speed, we'll probably only be able to know for sure by benchmarking 😅.

To improve efficiency, you could probably change the if to check whether depth is less than, say, 10, and use the recursive function I posted above to handle those shorter cases sequentially, since very small tasks are usually faster to do on one thread.