I have been struggling to make Rust's native (unstable) generators ergonomic and flexible enough to use in my library.
These were my requirements:
- Seamless iteration over generators
- Easily returning generators from functions
- Allowing for recursive generator functions
- Allowing borrows across yields (self-referential generators)
- Playing nicely with explicit lifetimes
- Ability to create empty (and typed) generators
After many attempts and hours of research, I have come up with the following concise wrapper, which supports all of these features.
I am posting it here to share my work, but also to ask for feedback — particularly on the soundness of the minimal unsafe pinning, which is needed to support self-referential generators (those declared with the `static` keyword), and on whether I am using `AssertUnmoved` properly to prevent any unsafe usage.
use std::ops::{Generator, GeneratorState};
use std::pin::Pin;
use assert_unmoved::AssertUnmoved;
/// Wraps a generator expression in a [`Gen`] so it can be driven as an [`Iterator`].
///
/// Expands to `Gen::new(<expr>)`; `Gen` is resolved at the call site, so it must be
/// in scope there.
#[macro_export]
macro_rules! gen {
    ($generator:expr) => {
        Gen::new($generator)
    };
}
/// Creates a [`Gen`] that completes immediately without yielding anything.
///
/// - `gen_empty!()` produces a generator whose `Yield` type is `()`.
/// - `gen_empty!(T)` produces a generator whose `Yield` type is `T`. No trait bound is
///   required on `T` (previously `T: Default` was needed even though no value was
///   ever constructed).
///
/// The body contains a `yield` that can never execute. It exists only to make the
/// closure a generator and to fix its `Yield` type.
#[macro_export]
macro_rules! gen_empty {
    () => {
        Gen::new(|| {
            // Never taken: matching on a literal `None` cannot bind `Some`.
            if let ::core::option::Option::Some(v) = ::core::option::Option::<()>::None {
                yield v;
            }
        })
    };
    ($type:ty) => {
        Gen::new(|| {
            // Never taken; typing the `None` is what pins the yield type without
            // requiring the type to implement `Default`.
            if let ::core::option::Option::Some(v) = ::core::option::Option::<$type>::None {
                yield v;
            }
        })
    };
}
/// Like `gen!`, but first pins the generator in a `Box`.
///
/// The boxed generator can be coerced to a `Pin<Box<dyn Generator<..>>>`, which gives
/// recursive generator functions a nameable return type (see `t_gen_r!`).
/// `Gen` is resolved at the call site, so it must be in scope there.
#[macro_export]
macro_rules! gen_r {
    ($generator:expr) => {
        Gen::new(Box::pin($generator))
    };
}
/// Names the return type of a function that returns `gen!(...)`.
///
/// - `t_gen!(I)` — a generator yielding `I`.
/// - `t_gen!(I, 'a)` — the same, for a generator that borrows data for `'a`.
///
/// The trait is referenced by its absolute path, so callers do not need to import
/// `Generator` themselves. `Gen` must still be in scope at the call site.
#[macro_export]
macro_rules! t_gen {
    ($I:ty) => {
        Gen<$I, impl ::std::ops::Generator<Yield = $I, Return = ()>>
    };
    ($I:ty, $L:lifetime) => {
        Gen<$I, impl ::std::ops::Generator<Yield = $I, Return = ()> + $L>
    };
}
/// Names the return type of a function that returns `gen_empty!(...)`.
///
/// - `t_gen_empty!()` / `t_gen_empty!('a)` — an empty generator yielding `()`.
/// - `t_gen_empty!(I)` / `t_gen_empty!(I, 'a)` — an empty generator yielding `I`.
///
/// `Gen` must be in scope at the call site; the trait is referenced by absolute path.
#[macro_export]
macro_rules! t_gen_empty {
    () => {
        Gen<(), impl ::std::ops::Generator<Yield = (), Return = ()>>
    };
    // This arm must precede `($I:ty)`. macro_rules! tries arms in order, and a
    // lifetime token may begin a `ty` fragment (as a trait-object bound). With the
    // `ty` arm first, `t_gen_empty!('a)` would be captured as a type and never reach
    // this arm.
    ($L:lifetime) => {
        Gen<(), impl ::std::ops::Generator<Yield = (), Return = ()> + $L>
    };
    ($I:ty) => {
        Gen<$I, impl ::std::ops::Generator<Yield = $I, Return = ()>>
    };
    ($I:ty, $L:lifetime) => {
        Gen<$I, impl ::std::ops::Generator<Yield = $I, Return = ()> + $L>
    };
}
/// Names the return type of a function that returns `gen_r!(...)`.
///
/// The generator is type-erased behind `Pin<Box<dyn Generator<..>>>`, which is what
/// lets a generator function call itself recursively.
///
/// - `t_gen_r!(I)` — a boxed generator yielding `I`.
/// - `t_gen_r!(I, 'a)` — the same, for a generator that borrows data for `'a`.
///
/// `Gen` must be in scope at the call site; `Pin`, `Box` and the trait are referenced
/// by absolute path.
#[macro_export]
macro_rules! t_gen_r {
    ($I:ty) => {
        Gen<$I, ::std::pin::Pin<::std::boxed::Box<dyn ::std::ops::Generator<Yield = $I, Return = ()>>>>
    };
    // The yield type must be `$I` here too; it was previously hard-coded to `usize`,
    // which broke this arm for every other item type.
    ($I:ty, $L:lifetime) => {
        Gen<$I, ::std::pin::Pin<::std::boxed::Box<dyn ::std::ops::Generator<Yield = $I, Return = ()> + $L>>>
    };
}
pub struct Gen<I, G: Generator<Yield=I, Return=()>> {
gen: Option<AssertUnmoved<G>>,
}
impl<I, G: Generator<Yield=I, Return=()>> Gen<I, G> {
pub fn new(gen: G) -> Self {
Gen { gen: Some(AssertUnmoved::new(gen)) }
}
}
impl<I, G: Generator<Yield=I, Return=()>> Iterator for Gen<I, G> {
type Item = I;
fn next(&mut self) -> Option<Self::Item> {
match &mut self.gen {
Some(gen) => {
let pinned = unsafe { Pin::new_unchecked(gen) };
match pinned.get_pin_mut().resume(()) {
GeneratorState::Yielded(item) => Some(item),
_ => None
}
}
None => None
}
}
}
#[cfg(test)]
mod test {
    use super::*;

    // Each test compares the complete yielded sequence, so a generator that stops
    // early (or yields nothing) fails instead of passing vacuously.

    #[test]
    fn range_generator() {
        let generator = gen!(|| {
            for i in 0..10 {
                yield i;
            }
        });
        assert_eq!(generator.collect::<Vec<_>>(), (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn return_from_function() {
        fn generator() -> t_gen!(usize) {
            gen!(|| {
                for i in 0..10 {
                    yield i;
                }
            })
        }
        assert_eq!(generator().collect::<Vec<_>>(), (0..10).collect::<Vec<usize>>());
    }

    #[test]
    fn return_with_lifetime() {
        fn generator<'a>(until: &'a usize) -> t_gen!(usize, 'a) {
            gen!(move || {
                for i in 0..*until {
                    yield i;
                }
            })
        }
        assert_eq!(generator(&10).collect::<Vec<_>>(), (0..10).collect::<Vec<usize>>());
    }

    #[test]
    fn self_referential() {
        let generator = gen!(static || {
            let x: usize = 1;
            let ptr = &x;
            yield 0;
            yield *ptr;
        });
        assert_eq!(generator.collect::<Vec<usize>>(), vec![0, 1]);
    }

    #[test]
    fn empty() {
        assert_eq!(gen_empty!().count(), 0);
    }

    #[test]
    fn empty_with_type() {
        assert_eq!(gen_empty!(usize).count(), 0);
    }

    #[test]
    fn return_empty() {
        fn generator() -> t_gen_empty!(usize) {
            gen_empty!(usize)
        }
        assert_eq!(generator().count(), 0);
    }

    #[test]
    fn recursive() {
        fn generator(until: usize) -> t_gen_r!(usize) {
            gen_r!(move || {
                yield until;
                if until != 0 {
                    for i in generator(until - 1) {
                        yield i;
                    }
                }
            })
        }
        let until = 10;
        let expected: Vec<usize> = (0..=until).rev().collect();
        assert_eq!(generator(until).collect::<Vec<_>>(), expected);
    }

    #[test]
    fn recursive_self_referential_with_lifetime() {
        fn generator<'a>(until: &'a usize) -> t_gen_r!(usize, 'a) {
            gen_r!(static move || {
                yield *until;
                if *until != 0 {
                    for i in generator(&(*until - 1)) {
                        yield i;
                    }
                }
            })
        }
        let until = 10;
        let expected: Vec<usize> = (0..=until).rev().collect();
        assert_eq!(generator(&until).collect::<Vec<_>>(), expected);
    }
}