Abstraction causes overhead

Dear all,

While experimenting with the IntoIterator trait, I found that introducing this abstraction can produce a binary that is quite inefficient.

The full example can be found here, but what I am essentially doing is the following:

  1. Create a variant of an enum based on some input
  2. Convert the value into an iterator (two variations in the example)
  3. Consume the items in the iterator, summing the elements

One would hope that the usage of the iterator (the abstracted implementation) is just as fast as the "direct" implementation (see the function goal).
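For readers without access to the linked example, the three steps above can be sketched roughly as follows. This is a hypothetical reconstruction, not the original code: the variant payloads, the helper `make`, and the exact bodies of `goal` and `abstracted` are assumptions chosen to match the names used in this thread.

```rust
// Hypothetical reconstruction of the setup; the linked example may differ.
pub enum Enum {
    Empty,
    Single(u32),
    Multi(Vec<u32>),
}

pub enum EnumIntoIter {
    Empty(std::iter::Empty<u32>),
    Single(std::iter::Once<u32>),
    Multi(std::vec::IntoIter<u32>),
}

impl Iterator for EnumIntoIter {
    type Item = u32;

    // Only `next` is implemented here, so every call matches on the variant.
    fn next(&mut self) -> Option<u32> {
        match self {
            EnumIntoIter::Empty(iter) => iter.next(),
            EnumIntoIter::Single(iter) => iter.next(),
            EnumIntoIter::Multi(iter) => iter.next(),
        }
    }
}

impl IntoIterator for Enum {
    type Item = u32;
    type IntoIter = EnumIntoIter;

    fn into_iter(self) -> EnumIntoIter {
        match self {
            Enum::Empty => EnumIntoIter::Empty(std::iter::empty()),
            Enum::Single(x) => EnumIntoIter::Single(std::iter::once(x)),
            Enum::Multi(v) => EnumIntoIter::Multi(v.into_iter()),
        }
    }
}

// Step 1: pick a variant based on some input (assumed mapping).
fn make(input: u32) -> Enum {
    match input {
        0 => Enum::Empty,
        1 => Enum::Single(input),
        _ => Enum::Multi(vec![input, input + 1]),
    }
}

// The "direct" implementation: match once, sum per variant.
pub fn goal(input: u32) -> u32 {
    match make(input) {
        Enum::Empty => 0,
        Enum::Single(x) => x,
        Enum::Multi(v) => v.iter().sum(),
    }
}

// Steps 2 and 3: go through the iterator abstraction instead.
pub fn abstracted(input: u32) -> u32 {
    make(input).into_iter().sum()
}
```

Both functions compute the same value; the question is only whether the generated code for `abstracted` is as tight as for `goal`.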

All the functions already get inlined; that is not the problem. Is there anything that can be done to optimize this code? Is there a better way to write the EnumIntoIterator or its iterator implementation? Why can't the compiler figure this out?

You gotta implement fold as well

impl Iterator for EnumIntoIter {
    type Item = u32;

    fn next(&mut self) -> Option<u32> {
        match self {
            EnumIntoIter::Empty(iter) => iter.next(),
            EnumIntoIter::Single(iter) => iter.next(),
            EnumIntoIter::Multi(iter) => iter.next(),
        }
    }

    fn fold<B, F>(self, init: B, f: F) -> B
    where
        F: FnMut(B, Self::Item) -> B,
    {
        match self {
            EnumIntoIter::Empty(iter) => iter.fold(init, f),
            EnumIntoIter::Single(iter) => iter.fold(init, f),
            EnumIntoIter::Multi(iter) => iter.fold(init, f),
        }
    }
}

Compiler Explorer


std::iter – Rust ⇒ Implementing Iterator

Also note that Iterator provides a default implementation of methods such as nth and fold which call next internally. However, it is also possible to write a custom implementation of methods like nth and fold if an iterator can compute them more efficiently without calling next.

That seems like a reasonable solution.

So, basically, implementors should consider whether any of the default implementations of the Iterator trait might be inefficient (even after optimization)?

So I implemented fold, and swapped out the Vec iter for an array iter (this was giving the "goal" an unfair advantage as it had a statically known length while "abstracted" had a dynamic length).

I also tried "cheating" a bit by pulling out a closure in abstracted, allowing the sum to get inlined separately for all three branches. This had the largest effect on code size.

These changes were still not enough, unfortunately.

Looks like switching to an array actually results in the compiler not finding the optimized version from @steffahn (despite the fold implementation)... Compiler Explorer

The fundamental problem with any enum of multiple iterator types that implements Iterator is this: if you only implement next, then every call to next has to do the match again in order to figure out which enum variant you have. [Also, if any of the underlying iterators has an optimized fold implementation, that one wouldn’t be used unless you override fold yourself; but none of the three iterators vec::IntoIter, iter::Once and iter::Empty has a special fold implementation, so I guess that doesn’t apply here.]

By the way in case that wasn’t clear: .sum() is implemented in terms of fold (and .map() has an efficient fold implementation in terms of the fold of the underlying iterator as well).
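One way to check this yourself is a small probe iterator that records whether its fold override was reached. This is only a sketch: `Probe`, `sum_via_probe` and `mapped_sum_via_probe` are made-up names, and the behaviour it observes is an implementation detail of the current standard library rather than a documented guarantee.

```rust
use std::cell::Cell;

/// Hypothetical wrapper that records whether its `fold` override ran.
struct Probe<'a, I> {
    inner: I,
    fold_called: &'a Cell<bool>,
}

impl<'a, I: Iterator> Iterator for Probe<'a, I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<I::Item> {
        self.inner.next()
    }

    fn fold<B, F>(self, init: B, f: F) -> B
    where
        F: FnMut(B, Self::Item) -> B,
    {
        // Remember that the override was used, then delegate.
        self.fold_called.set(true);
        self.inner.fold(init, f)
    }
}

/// Sums the items with `.sum()` and reports whether `Probe::fold` ran.
fn sum_via_probe(items: Vec<u32>) -> (u32, bool) {
    let flag = Cell::new(false);
    let probe = Probe { inner: items.into_iter(), fold_called: &flag };
    let total: u32 = probe.sum();
    (total, flag.get())
}

/// Same, but with a `.map()` adapter between the probe and `.sum()`.
fn mapped_sum_via_probe(items: Vec<u32>) -> (u32, bool) {
    let flag = Cell::new(false);
    let probe = Probe { inner: items.into_iter(), fold_called: &flag };
    let total: u32 = probe.map(|x| x * 2).sum();
    (total, flag.get())
}
```

If the flag comes back true in both functions, then `.sum()` and `.map(..).sum()` both ended up in the custom fold, which is why overriding fold on the enum iterator affects the summing code at all.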

Are you using --release or (with rustc) -C opt-level=3?

Yes they are, see the flags in the compiler explorer.

Would it be desirable to provide a default implementation for all enum types?

I was confused, since the tab with rustc 1.54 didn't have the flag but the one with nightly did.

I don’t quite understand the question.

Something like #[derive(Iterator<Item = u32>)]?

Tested Chain<Once<u32>, Once<u32>>, too, that one optimizes well

Compiler Explorer
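For reference, a two-element iterator of that Chain type can be built like this. It is a minimal sketch; `pair_iter` is a hypothetical helper name, not something from the linked example.

```rust
use std::iter::{once, Chain, Once};

/// Hypothetical helper: yields both elements of `pair` in order,
/// without going through an array or Vec iterator.
fn pair_iter(pair: [u32; 2]) -> Chain<Once<u32>, Once<u32>> {
    let [a, b] = pair;
    once(a).chain(once(b))
}
```

Something like `pair_iter([3, 4]).sum::<u32>()` then takes the place of summing over the array's own iterator.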

The array::IntoIter isn’t even fully optimized if you remove the 0 => and 1 => branches completely… Anyways, array::IntoIter is still pretty new. Maybe some improvements to it will allow more optimization in the future. Perhaps the code examples from this thread are even worth opening an issue on GitHub. It seems a bit weird that the compiler is better at removing a full allocation (Vec) and its iterator than an array, or that a Chain<Once<T>, Once<T>> behaves better than array::IntoIter<T, 2>.

I don't see a way in which such an implementation would currently be possible; I'm just thinking that (almost?) any time you implement iterator on an enum, you will have to branch on the variant in next. It seems a shame that this is not optimized out.

Of course it is a bit tricky, as I guess that the variant might change during iteration.

Ah, maybe I wasn't clear enough in my explanation above.

This only describes the problem, not how overloading the fold implementation addresses it.

The default implementation of fold will repeatedly call next in a loop to do the folding, each call to next will need to figure out the enum variant again. A custom fold implementation can test the enum variant once and call fold in the underlying iterator of that variant once. Of course there is still a branch on the enum variant, but it doesn't happen repeatedly, only once at the start of the iteration.
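To make the difference concrete, here is a sketch that counts how often the variant is matched, ignoring whatever the optimizer might do afterwards. All names (`Inner`, `NextOnly`, `WithFold`, `make_inner` and the two driver functions) are made up for illustration.

```rust
use std::cell::Cell;
use std::iter::{once, Once};

enum Inner {
    One(Once<u32>),
    Many(std::vec::IntoIter<u32>),
}

// Assumed mapping from input to variant, just so both variants get used.
fn make_inner(items: Vec<u32>) -> Inner {
    if items.len() == 1 {
        Inner::One(once(items[0]))
    } else {
        Inner::Many(items.into_iter())
    }
}

/// Implements only `next`; every call matches on the variant.
struct NextOnly<'a> {
    inner: Inner,
    matches: &'a Cell<usize>,
}

impl Iterator for NextOnly<'_> {
    type Item = u32;

    fn next(&mut self) -> Option<u32> {
        self.matches.set(self.matches.get() + 1);
        match &mut self.inner {
            Inner::One(it) => it.next(),
            Inner::Many(it) => it.next(),
        }
    }
}

/// Also overrides `fold`; the variant is matched once per fold.
struct WithFold<'a> {
    inner: Inner,
    matches: &'a Cell<usize>,
}

impl Iterator for WithFold<'_> {
    type Item = u32;

    fn next(&mut self) -> Option<u32> {
        self.matches.set(self.matches.get() + 1);
        match &mut self.inner {
            Inner::One(it) => it.next(),
            Inner::Many(it) => it.next(),
        }
    }

    fn fold<B, F>(self, init: B, f: F) -> B
    where
        F: FnMut(B, Self::Item) -> B,
    {
        self.matches.set(self.matches.get() + 1);
        match self.inner {
            Inner::One(it) => it.fold(init, f),
            Inner::Many(it) => it.fold(init, f),
        }
    }
}

/// Returns (sum, number of variant matches) using the default `fold`.
fn matches_with_next_only(items: Vec<u32>) -> (u32, usize) {
    let counter = Cell::new(0);
    let total: u32 = NextOnly { inner: make_inner(items), matches: &counter }.sum();
    (total, counter.get())
}

/// Returns (sum, number of variant matches) using the custom `fold`.
fn matches_with_fold(items: Vec<u32>) -> (u32, usize) {
    let counter = Cell::new(0);
    let total: u32 = WithFold { inner: make_inner(items), matches: &counter }.sum();
    (total, counter.get())
}
```

With the default fold the count grows with the number of items (one match per next call, including the final call that returns None), while the custom fold matches a single time no matter how long the iterator is.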

Admittedly, even with the default implementation of fold (doing repeated next calls) there is a chance that the optimizer will figure out on its own that the branch is needed only once because the enum variant never changes. But it's a non-trivial optimization, so it might not always happen, or it might preclude other optimizations.

EDIT: Per @steffahn below, this is UB due to a silly mistake; leaving it here as a reminder of how easy it is to mess up unsafe code.

I was curious what I could make happen with some unsafe and came up with this; it codegens a bit shorter than @ExpHP's version.

struct Dispatch {
    get: fn(_:&mut EnumIntoIter)->&mut dyn Iterator<Item=u32>,
    iter: EnumIntoIter
}

impl Iterator for Dispatch {
    type Item = u32;
    
    #[inline(always)]
    fn next(&mut self) -> Option<u32> {
        (self.get)(&mut self.iter).next()
    }

    #[inline(always)]
    fn fold<B, F>(mut self, init: B, f: F) -> B
    where
        F: FnMut(B, Self::Item) -> B,
    {
        (self.get)(&mut self.iter).fold(init, f)
    }
}

impl IntoIterator for Enum {
    type Item = u32;
    type IntoIter = Dispatch;
  
    #[inline(always)]
    fn into_iter(self) -> Dispatch {
        use EnumIntoIter::*;
        match self {
            Enum::Empty =>
                Dispatch {
                    get: |i| {
                        if let Empty(iter) = i { iter }
                        else { unsafe { unreachable_unchecked() } }
                    },
                    iter: EnumIntoIter::Empty(std::iter::empty()),
                },
            Enum::Single(single) =>
                Dispatch {
                    get: |i| {
                        if let Single(iter) = i { iter }
                        else { unsafe { unreachable_unchecked() } }
                    },
                    iter: EnumIntoIter::Single(std::iter::once(single))
                },
            Enum::Multi(pair) =>
                Dispatch {
                    get: |i| {
                        if let Single(iter) = i { iter }
                        else { unsafe { unreachable_unchecked() } }
                    },
                    iter: EnumIntoIter::Multi(std::array::IntoIter::new(pair))
                },
        }
    }
}

Your code has UB. You're matching the Multi against Single.
