Ergonomic nightly generators

I have been struggling to make the native generators ergonomic and flexible enough to be used in my library.

These were my requirements:

  • Seamless iteration over generators
  • Easily returning generators from functions
  • Allowing for recursive generator functions
  • Allowing borrows across yields (self-referential generators)
  • Playing nicely with explicit lifetimes
  • Ability to create empty (and typed) generators

After many tries and hours of research, I have come up with the following concise wrapper, which provides all of these features.

I am posting it here to share my work, but also to ask for feedback, particularly on the soundness of the minimal unsafe pinning, which is needed to support self-referential generators (those declared with the `static` keyword), and to confirm whether I am using `AssertUnmoved` properly to prevent any unsound usage.

use std::ops::{Generator, GeneratorState};
use std::pin::Pin;
use assert_unmoved::AssertUnmoved;

/// Wraps a generator expression in a [`Gen`] so it can be driven as an `Iterator`.
///
/// The generator is stored by value inside the returned `Gen`; use `gen_r!`
/// for a boxed, type-erasable variant.
#[macro_export]
macro_rules! gen {
    ($generator:expr) => {
        Gen::new($generator)
    };
}

/// Creates a [`Gen`] that completes immediately without yielding anything.
///
/// * `gen_empty!()` yields `()`.
/// * `gen_empty!(T)` yields `T`; `T` does not need to implement `Default`.
///
/// A closure only becomes a generator if its body contains `yield`. Yielding
/// from a loop over an always-empty iterator makes the `yield` type-check (which
/// also fixes the yield type) while guaranteeing it never runs. No placeholder
/// value is ever constructed.
#[macro_export]
macro_rules! gen_empty {
    () => {
        Gen::new(|| {
            for item in ::std::iter::empty::<()>() {
                yield item;
            }
        })
    };
    ($type:ty) => {
        Gen::new(|| {
            for item in ::std::iter::empty::<$type>() {
                yield item;
            }
        })
    };
}

/// Like `gen!`, but first heap-allocates and pins the generator with `Box::pin`.
///
/// The resulting `Pin<Box<_>>` can be coerced to the type-erased
/// `Pin<Box<dyn Generator<..>>>` named by `t_gen_r!`. This lets recursive
/// generator functions spell out their return type.
#[macro_export]
macro_rules! gen_r {
    ($generator:expr) => {
        Gen::new(Box::pin($generator))
    };
}

/// Expands to the type of a `gen!` generator yielding `$I`, for return position.
///
/// `fn f() -> t_gen!(usize)`, or with a lifetime bound on captured borrows:
/// `fn f<'a>(x: &'a usize) -> t_gen!(usize, 'a)`.
#[macro_export]
macro_rules! t_gen {
    ($I:ty) => {
        Gen<$I, impl Generator<Yield = $I, Return = ()>>
    };
    ($I:ty, $L:lifetime) => {
        Gen<$I, impl Generator<Yield = $I, Return = ()> + $L>
    };
}

/// Expands to the type of a `gen_empty!` generator, for use in return position.
///
/// Accepted forms: `t_gen_empty!()`, `t_gen_empty!('a)`, `t_gen_empty!(T)` and
/// `t_gen_empty!(T, 'a)`. Without a type argument the yield type is `()`.
///
/// Arm order matters: a lifetime token is accepted as the start of a `ty`
/// fragment (as a trait-object bound). If `($I:ty)` came first, it would capture
/// `t_gen_empty!('a)` and the `($L:lifetime)` arm would never be reached. The
/// lifetime arm is therefore listed first.
#[macro_export]
macro_rules! t_gen_empty {
    () => {
        Gen<(), impl Generator<Yield = (), Return = ()>>
    };
    ($L:lifetime) => {
        Gen<(), impl Generator<Yield = (), Return = ()> + $L>
    };
    ($I:ty) => {
        Gen<$I, impl Generator<Yield = $I, Return = ()>>
    };
    ($I:ty, $L:lifetime) => {
        Gen<$I, impl Generator<Yield = $I, Return = ()> + $L>
    };
}

/// Expands to the type of a `gen_r!` generator yielding `$I`.
///
/// The generator is boxed, pinned and type-erased (`dyn Generator`), which is
/// what allows a generator function to call itself recursively.
#[macro_export]
macro_rules! t_gen_r {
    ($I:ty) => {
        Gen<$I, Pin<Box<dyn Generator<Yield = $I, Return = ()>>>>
    };
    // The yield type must be `$I` here too. It was previously hard-coded as
    // `usize`, so this arm only type-checked when `$I` was `usize`.
    ($I:ty, $L:lifetime) => {
        Gen<$I, Pin<Box<dyn Generator<Yield = $I, Return = ()> + $L>>>
    };
}



/// Iterator adapter over a generator `G` that yields `I` and returns `()`.
///
/// The generator is held by value, wrapped in `AssertUnmoved`. That wrapper is
/// meant to detect, by panicking, a move of the value after it has been pinned.
/// NOTE(review): this is a runtime check, not a static pinning guarantee.
pub struct Gen<I, G>
where
    G: Generator<Yield = I, Return = ()>,
{
    gen: Option<AssertUnmoved<G>>,
}

impl<I, G> Gen<I, G>
where
    G: Generator<Yield = I, Return = ()>,
{
    /// Wraps `gen` so it can be driven as an `Iterator`.
    ///
    /// Nothing is resumed here; the generator first runs on the first `next`.
    pub fn new(gen: G) -> Self {
        let guarded = AssertUnmoved::new(gen);
        Self { gen: Some(guarded) }
    }
}

impl<I, G> Iterator for Gen<I, G>
where
    G: Generator<Yield = I, Return = ()>,
{
    type Item = I;

    /// Resumes the generator and returns its next yielded value.
    ///
    /// When the generator completes, it is dropped and this call returns
    /// `None`. Every later call also returns `None`, instead of resuming a
    /// completed generator (which panics).
    ///
    /// # Panics
    ///
    /// Via `AssertUnmoved`, if this `Gen` was moved after the first `next`.
    fn next(&mut self) -> Option<Self::Item> {
        let guarded = self.gen.as_mut()?;
        // SAFETY: `Pin::new_unchecked` requires that the generator is never moved
        // again until it is dropped. Nothing here enforces that statically. The
        // invariant is only checked at run time by `AssertUnmoved`, which panics
        // on the next access if the `Gen` was moved after being pinned.
        // NOTE(review): soundness therefore rests on that runtime check.
        // Consumers that take the iterator by value (e.g. `collect`) may move it
        // between `next` calls and trip the check. Pinning with `Box::pin`
        // (see `gen_r!`) gives a stable address without any `unsafe`.
        let pinned = unsafe { Pin::new_unchecked(guarded) };
        let state = pinned.get_pin_mut().resume(());
        match state {
            GeneratorState::Yielded(item) => Some(item),
            GeneratorState::Complete(()) => {
                // Drop the finished generator in place, so later calls
                // short-circuit on the `?` above.
                self.gen = None;
                None
            }
        }
    }
}

// Once `next` has returned `None`, the generator has been dropped and every
// further call returns `None`.
impl<I, G> std::iter::FusedIterator for Gen<I, G> where G: Generator<Yield = I, Return = ()> {}

#[cfg(test)]
mod test {
    use super::*;

    /// Collects every item with a plain `for` loop.
    ///
    /// `Iterator::collect` is avoided on purpose for multi-item generators.
    /// std's `Vec` collection may move the iterator after its first `next`,
    /// which `AssertUnmoved` would report by panicking.
    fn drain<T>(items: impl Iterator<Item = T>) -> Vec<T> {
        let mut out = Vec::new();
        for item in items {
            out.push(item);
        }
        out
    }

    #[test]
    fn range_generator() {
        let generator = gen!(|| {
            for i in 0..10 {
                yield i;
            }
        });

        let got: Vec<usize> = drain(generator);
        assert_eq!(got, (0..10).collect::<Vec<usize>>());
    }

    #[test]
    fn return_from_function() {
        fn generator() -> t_gen!(usize) {
            gen!(|| {
                for i in 0..10 {
                    yield i;
                }
            })
        }

        assert_eq!(drain(generator()), (0..10).collect::<Vec<usize>>());
    }

    #[test]
    fn return_with_lifetime() {
        fn generator<'a>(until: &'a usize) -> t_gen!(usize, 'a) {
            gen!(move || {
                for i in 0..*until {
                    yield i;
                }
            })
        }

        assert_eq!(drain(generator(&10)), (0..10).collect::<Vec<usize>>());
    }

    #[test]
    fn self_referential() {
        let generator = gen!(static || {
            let x: usize = 1;
            let ptr = &x;
            yield 0;
            yield *ptr;
        });

        let got: Vec<usize> = drain(generator);
        assert_eq!(got, vec![0, 1]);
    }

    #[test]
    fn empty() {
        assert_eq!(gen_empty!().count(), 0);
    }

    #[test]
    fn empty_with_type() {
        assert_eq!(gen_empty!(usize).count(), 0);
    }

    #[test]
    fn return_empty() {
        fn generator() -> t_gen_empty!(usize) {
            gen_empty!(usize)
        }

        assert_eq!(generator().count(), 0);
    }

    #[test]
    fn recursive() {
        fn generator(until: usize) -> t_gen_r!(usize) {
            gen_r!(move || {
                yield until;
                if until != 0 {
                    for i in generator(until - 1) {
                        yield i;
                    }
                }
            })
        }

        let until = 10;
        let expected: Vec<usize> = (0..=until).rev().collect();
        assert_eq!(drain(generator(until)), expected);
    }

    #[test]
    fn recursive_self_referential_with_lifetime() {
        fn generator<'a>(until: &'a usize) -> t_gen_r!(usize, 'a) {
            gen_r!(static move || {
                yield *until;
                if *until != 0 {
                    for i in generator(&(*until - 1)) {
                        yield i;
                    }
                }
            })
        }

        let until = 10;
        let expected: Vec<usize> = (0..=until).rev().collect();
        assert_eq!(drain(generator(&until)), expected);
    }
}

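Based on feedback from @alice, I managed to make it even simpler, with no need for unsafe code: the generator is now pinned on the heap with `Box::pin`, which gives it a stable address even when the wrapper itself is moved. The solution is rather trivial now, but it is still quite a convenient, ergonomic wrapper.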

Let me know if anyone can think of a cleaner option for empty generators; this solution is rather hackish.

use std::ops::{Generator, GeneratorState};
use std::pin::Pin;


/// Creates a [`Gen`] that completes immediately without yielding anything.
///
/// * `gen_empty!()` yields `()`.
/// * `gen_empty!(T)` yields `T`; `T` does not need to implement `Default`.
///
/// A closure only becomes a generator if its body contains `yield`. Yielding
/// from a loop over an always-empty iterator makes the `yield` type-check (which
/// also fixes the yield type) while guaranteeing it never runs. No placeholder
/// value is ever constructed.
#[macro_export]
macro_rules! gen_empty {
    () => {
        Gen::new(|| {
            for item in ::std::iter::empty::<()>() {
                yield item;
            }
        })
    };
    ($type:ty) => {
        Gen::new(|| {
            for item in ::std::iter::empty::<$type>() {
                yield item;
            }
        })
    };
}

/// Iterator adapter over a heap-pinned, type-erased generator that yields `I`.
///
/// Pinning with `Box::pin` gives the generator a stable address, so
/// self-referential (`static`) generators work without `unsafe`. The erased
/// `dyn` type lets recursive generator functions name their return type. `'a`
/// bounds any borrows the generator captures.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Gen<'a, I: 'a> {
    // `None` once the generator has completed, or for `Gen::empty()`.
    gen: Option<Pin<Box<dyn Generator<Yield = I, Return = ()> + 'a>>>,
}

impl<'a, I: 'a> Gen<'a, I> {
    /// Pins `gen` on the heap and wraps it.
    ///
    /// The generator does not run until the first call to `next`.
    pub fn new(gen: impl Generator<Yield = I, Return = ()> + 'a) -> Self {
        Gen { gen: Some(Box::pin(gen)) }
    }

    /// Creates a `Gen` that yields nothing, with no generator behind it.
    ///
    /// The item type is usually inferred, e.g.
    /// `fn f() -> Gen<'static, usize> { Gen::empty() }`.
    pub fn empty() -> Self {
        Gen { gen: None }
    }
}

impl<'a, I: 'a> Default for Gen<'a, I> {
    /// Same as [`Gen::empty`].
    fn default() -> Self {
        Self::empty()
    }
}

impl<'a, I: 'a> Iterator for Gen<'a, I> {
    type Item = I;

    /// Resumes the generator and returns its next yielded value.
    ///
    /// When the generator completes, it is dropped and `None` is returned. Every
    /// later call also returns `None`, instead of resuming a completed
    /// generator (which panics).
    fn next(&mut self) -> Option<Self::Item> {
        let state = self.gen.as_mut()?.as_mut().resume(());
        match state {
            GeneratorState::Yielded(item) => Some(item),
            GeneratorState::Complete(()) => {
                self.gen = None;
                None
            }
        }
    }
}

// After the first `None`, `gen` is `None` and `next` keeps returning `None`.
impl<'a, I: 'a> std::iter::FusedIterator for Gen<'a, I> {}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn range_generator() {
        let generator = Gen::new(|| {
            for i in 0..10 {
                yield i;
            }
        });

        let items: Vec<usize> = generator.collect();
        assert_eq!(items, (0..10).collect::<Vec<usize>>());
    }

    #[test]
    fn return_from_function() {
        fn generator() -> Gen<'static, usize> {
            Gen::new(|| {
                for i in 0..10 {
                    yield i;
                }
            })
        }

        let items: Vec<usize> = generator().collect();
        assert_eq!(items, (0..10).collect::<Vec<usize>>());
    }

    #[test]
    fn return_with_borrow() {
        fn generator(until: &usize) -> Gen<'_, usize> {
            Gen::new(move || {
                for i in 0..*until {
                    yield i;
                }
            })
        }

        let items: Vec<usize> = generator(&10).collect();
        assert_eq!(items, (0..10).collect::<Vec<usize>>());
    }

    #[test]
    fn self_referential() {
        let generator = Gen::new(static || {
            let x: usize = 1;
            let ptr = &x;
            yield 0;
            yield *ptr;
        });

        let items: Vec<usize> = generator.collect();
        assert_eq!(items, vec![0, 1]);
    }

    #[test]
    fn empty() {
        assert_eq!(gen_empty!().count(), 0);
    }

    #[test]
    fn empty_with_type() {
        assert_eq!(gen_empty!(usize).count(), 0);
    }

    #[test]
    fn return_empty() {
        fn generator() -> Gen<'static, usize> {
            gen_empty!(usize)
        }

        assert_eq!(generator().count(), 0);
    }

    #[test]
    fn recursive() {
        fn generator(until: usize) -> Gen<'static, usize> {
            Gen::new(move || {
                yield until;
                if until != 0 {
                    for i in generator(until - 1) {
                        yield i;
                    }
                }
            })
        }

        let until = 10;
        let items: Vec<usize> = generator(until).collect();
        let expected: Vec<usize> = (0..=until).rev().collect();
        assert_eq!(items, expected);
    }

    #[test]
    fn recursive_self_referential_with_borrow() {
        fn generator(until: &usize) -> Gen<'_, usize> {
            Gen::new(static move || {
                yield *until;
                if *until != 0 {
                    for i in generator(&(*until - 1)) {
                        yield i;
                    }
                }
            })
        }

        let until = 10;
        let items: Vec<usize> = generator(&until).collect();
        let expected: Vec<usize> = (0..=until).rev().collect();
        assert_eq!(items, expected);
    }
}