Using coroutines to rewrite recursive functions into iterating ones

Hello everyone!

Currently, Rust has an unstable feature called coroutines. A coroutine is a special kind of closure that can either return for good (dropping all its local variables, with no possibility of resuming it) or return temporarily (yield), in which case all its local variables are preserved and the coroutine can be resumed later. This is similar to the futures produced by async functions. The difference is that an async function yields when the job cannot be finished right now because it depends on some external event that has not happened yet (e.g. the arrival of data from a server, or a network connection being established), while a coroutine yields either when the job is partially done or when it needs some new data.

The coroutine interface (core::ops::Coroutine) resembles the future interface, except that there is no "context object" as in core::future::Future; instead, a yielding coroutine may hand out a value, and a value of the argument type is passed in on every resume. It also resembles the iterator interface, except that a coroutine may be non-Unpin, so Pin<&mut Self> is required instead of &mut Self, a value is passed in on every resume, and a value is returned when the "iteration" completes.
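To make the shape of that interface concrete without a nightly compiler, here is a minimal stable-Rust sketch. It does not use the real core::ops::Coroutine trait: State and RunningTotal are made-up names, and the state machine is written by hand, whereas a real coroutine has the compiler generate it from a closure body containing yield.

```rust
/// Stand-in for `core::ops::CoroutineState` (hypothetical, stable-Rust model).
#[derive(Debug, PartialEq)]
enum State<Y, R> {
    Yielded(Y),
    Complete(R),
}

/// Hand-written state machine playing the role of a coroutine.
/// Its fields are the "local variables" that survive between resumes.
struct RunningTotal {
    total: i32,
    count: usize,
    done: bool,
}

impl RunningTotal {
    fn new() -> Self {
        RunningTotal { total: 0, count: 0, done: false }
    }

    /// Every resume passes a value in. A non-zero value is added and the
    /// running total is yielded; zero finishes and returns the item count.
    fn resume(&mut self, arg: i32) -> State<i32, usize> {
        assert!(!self.done, "resumed after completion");
        if arg == 0 {
            self.done = true;
            return State::Complete(self.count);
        }
        self.total += arg;
        self.count += 1;
        State::Yielded(self.total)
    }
}
```

Resuming a fresh RunningTotal with 3, 4 and then 0 gives Yielded(3), Yielded(7) and Complete(2). The real trait differs mainly in that resume takes Pin<&mut Self> rather than &mut self.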

If a project allows using unstable features (or once coroutines are stabilized), it can use coroutines to avoid recursion. The main reason a project may want to avoid recursion is that a recursive function may use a hard-to-predict amount of stack memory and cause a stack overflow. To prevent this, the recursive function can be rewritten as an iterating one that keeps its working state on the heap instead. One example of a recursive algorithm is finding all permutations of a given sequence. A recursive implementation may look like:

Recursive version
use std::collections::VecDeque;

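fn permutations(list: &mut [i32]) -> Vec<VecDeque<i32>>
{
   let mut result: Vec<VecDeque<i32>> = Vec::new();
   // Base case: a single element has exactly one permutation.
   if list.len() == 1
   {
      result.push(VecDeque::from(list.to_owned()));
      return result;
   }
   for i in 0..list.len()
   {
      // Move the i-th element to the front...
      if i > 0
      {
         list.swap(0, i);
      }
      // ...and prepend it to every permutation of the remaining elements.
      for mut j in permutations(&mut list[1..])
      {
         j.push_front(list[0]);
         result.push(j);
      }
      // Undo the swap so the next iteration starts from the original order.
      if i > 0
      {
         list.swap(0, i);
      }
   }
   result
}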

fn main()
{
   let mut args = std::env::args();
   let _name = args.next().unwrap();
   let mut values: Vec<i32> = Vec::new();
   for i in args
   {
      let num = match i32::from_str_radix(&i, 10)
      {
         Ok(v) => v,
         Err(_) => { break; },
      };
      values.push(num);
   }
   let result = permutations(&mut values);
   for i in result
   {
      for j in i
      {
         print!("{j} ");
      }
      println!();
   }
}

An iterating version may look like:

Iterating version
#![feature(coroutine_trait, coroutines, stmt_expr_attributes)]
use std::collections::VecDeque;
use core::ops::{Coroutine, CoroutineState};
use core::pin::Pin;

enum FnArg
{
   Initial(VecDeque<i32>),
   Resume(Vec<VecDeque<i32>>),
}

fn permutations() -> Pin<Box<dyn Coroutine<FnArg,Yield=VecDeque<i32>,Return=Vec<VecDeque<i32>>>>>
{
   Box::pin(#[coroutine] |list: FnArg| {
      let mut list = match list
      {
         FnArg::Initial(value) => value,
         FnArg::Resume(_) => panic!("permutations() coroutine is started with a resume value!"),
      };
      let mut result: Vec<VecDeque<i32>> = Vec::new();
      if list.len() == 1
      {
         result.push(VecDeque::from(list.to_owned()));
         return result;
      }
      for i in 0..list.len()
      {
         if i > 0
         {
            list.swap(0, i);
         }
         let mut part = list.clone();
         let _ = part.pop_front();
         let nextstep = match (yield part)
         {
            FnArg::Initial(_) => panic!("permutations() coroutine is resumed with a start value!"),
            FnArg::Resume(value) => value,
         };
         for j in nextstep
         {
            let mut j = j;
            j.push_front(list[0]);
            result.push(j);
         }
         if i > 0
         {
            list.swap(0, i);
         }
      }
      result
   })
}

enum StackFrameState
{
   Incomplete(Pin<Box<dyn Coroutine<FnArg,Yield=VecDeque<i32>,Return=Vec<VecDeque<i32>>>>>),
   Complete(Vec<VecDeque<i32>>),
}

#[inline]
fn unwrap_incomplete(v: &mut StackFrameState) -> &mut Pin<Box<dyn Coroutine<FnArg,Yield=VecDeque<i32>,Return=Vec<VecDeque<i32>>>>>
{
   match v
   {
      StackFrameState::Incomplete(value) => value,
      StackFrameState::Complete(_) => panic!("impossible state: two complete stack frames on a virtual stack"),
   }
}

struct StackFrame
{
   state: StackFrameState,
   data: VecDeque<i32>,
}

fn main()
{
   let mut args = std::env::args();
   let _name = args.next().unwrap();
   let mut values: VecDeque<i32> = VecDeque::new();
   for i in args
   {
      let num = match i32::from_str_radix(&i, 10)
      {
         Ok(v) => v,
         Err(_) => { break; },
      };
      values.push_back(num);
   }
   let mut result: Vec<VecDeque<i32>> = Vec::new();
   let mut permutator = permutations();
   let mut virtualstack: Vec<StackFrame> = Vec::new();
   let arg = FnArg::Initial(values);
   let data: VecDeque<i32> = match permutator.as_mut().resume(arg)
   {
      CoroutineState::Complete(value) => {
         for i in value
         {
            for j in i
            {
               print!("{j} ");
            }
            println!();
         }
         return;
      },
      CoroutineState::Yielded(value) => value,
   };
   virtualstack.push(StackFrame { state: StackFrameState::Incomplete(permutator), data });
   while let Some(mut frame) = virtualstack.pop()
   {
      match frame.state
      {
         StackFrameState::Complete(value) => {
            let oldframe = match virtualstack.last_mut()
            {
               Some(f) => f,
               None => {
                  result = value;
                  break;
               },
            };
            match unwrap_incomplete(&mut oldframe.state).as_mut().resume(FnArg::Resume(value))
            {
               CoroutineState::Complete(retval) => {
                  oldframe.state = StackFrameState::Complete(retval);
               },
               CoroutineState::Yielded(yvalue) => {
                  oldframe.data = yvalue;
               },
            }
         },
         StackFrameState::Incomplete(_) => {
            let yvalue = core::mem::take(&mut frame.data);
            virtualstack.push(frame);
            let mut perm = permutations();
            match perm.as_mut().resume(FnArg::Initial(yvalue))
            {
               CoroutineState::Complete(value) => {
                  virtualstack.push(StackFrame { state: StackFrameState::Complete(value), data: VecDeque::new() });
               },
               CoroutineState::Yielded(value) => {
                  virtualstack.push(StackFrame { state: StackFrameState::Incomplete(perm), data: value });
               },
            }
         },
      }
   }
   for i in result
   {
      for j in i
      {
         print!("{j} ");
      }
      println!();
   }
}

In the iterating version, the caller keeps a "virtual stack" on the heap. At each point where the permutation function would otherwise call itself, the coroutine yields the argument of that call instead. The caller then starts a new permutation coroutine with that argument and, once it completes, resumes the previous coroutine with the value that the recursive call would have returned.
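For readers without a nightly toolchain, the same scheme can be sketched on stable Rust by writing the suspended state by hand. This is not the coroutine-based code above but a model of it under stated assumptions: Step, Frame and permutations_explicit_stack are made-up names, and Frame holds exactly the locals that the coroutine would preserve across a yield.

```rust
use std::collections::VecDeque;

/// What a frame asks the driver to do next (hypothetical names).
enum Step {
    /// "Recurse on this sub-list and resume me with the result."
    Call(VecDeque<i32>),
    /// "I am finished; this is my return value."
    Return(Vec<VecDeque<i32>>),
}

/// One suspended invocation of the permutation routine.
struct Frame {
    list: VecDeque<i32>,
    i: usize,
    result: Vec<VecDeque<i32>>,
}

impl Frame {
    fn new(list: VecDeque<i32>) -> Self {
        Frame { list, i: 0, result: Vec::new() }
    }

    /// `returned` is `None` on the first resume and `Some(..)` after a `Call`.
    fn resume(&mut self, returned: Option<Vec<VecDeque<i32>>>) -> Step {
        match returned {
            Some(sub) => {
                // Finish iteration `i`: prefix each sub-permutation, undo the swap.
                for mut p in sub {
                    p.push_front(self.list[0]);
                    self.result.push(p);
                }
                if self.i > 0 {
                    self.list.swap(0, self.i);
                }
                self.i += 1;
            }
            // Base case: a single element has exactly one permutation.
            None if self.list.len() == 1 => return Step::Return(vec![self.list.clone()]),
            None => {}
        }
        if self.i == self.list.len() {
            return Step::Return(std::mem::take(&mut self.result));
        }
        // Start iteration `i`: bring element `i` to the front, then request
        // what would have been the recursive call on the remaining elements.
        if self.i > 0 {
            self.list.swap(0, self.i);
        }
        let mut part = self.list.clone();
        part.pop_front();
        Step::Call(part)
    }
}

/// Driver: the `Vec` of frames is the "virtual stack" living on the heap.
fn permutations_explicit_stack(values: VecDeque<i32>) -> Vec<VecDeque<i32>> {
    let mut stack = vec![Frame::new(values)];
    let mut returned = None;
    loop {
        let top = stack.last_mut().expect("the stack holds at least one frame");
        match top.resume(returned.take()) {
            Step::Call(arg) => stack.push(Frame::new(arg)),
            Step::Return(ret) => {
                stack.pop();
                if stack.is_empty() {
                    return ret;
                }
                returned = Some(ret);
            }
        }
    }
}
```

Calling permutations_explicit_stack(VecDeque::from(vec![1, 2, 3])) produces the six permutations in the same order as the recursive version. As in the recursive version, an empty input yields no permutations at all.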

Looks convoluted. As I understand it, the algorithm is still recursive; you’re just using a Vec<_> instead of the program stack. And inside the coroutine, yield (args) basically means recurse(args).

There are probably better ways to protect against stack overflows, but sticking with coroutines, I think there is some room for improvement. For example, you don’t need StackFrameState::Complete: there can be at most one on the stack anyway, and handling it is a chore. Similarly, you don’t need StackFrame::data, as you only ever use the topmost value, to pass it to the next call. With those removed, the iteration boils down to this:

let mut arg = FnArg::Initial(values);
let mut stack = Vec::new();
stack.push(permutations());
while let Some(permutator) = stack.last_mut() {
	match permutator.as_mut().resume(arg) {
		CoroutineState::Yielded(recurse_arg) => {
			arg = FnArg::Initial(recurse_arg);
			stack.push(permutations());
		}
		CoroutineState::Complete(ret) => {
			arg = FnArg::Resume(ret);
			stack.pop();
		}
	}
}
let FnArg::Resume(result) = arg else { panic!() };

This can probably be further improved but I’d rethink naming. For example Initial and Resume could be Call and Return because that’s what they represent in the algorithm.
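To see the stack-depth benefit of such a loop in isolation, here is a small stable-Rust sketch that uses the Call/Return naming. It replaces the coroutine with a hand-written frame, so it only illustrates the driver loop, not the coroutine feature itself. Msg, Frame and sum_to are names invented for this example; sum_to(n) computes n + sum_to(n - 1), a recursion whose depth grows with n.

```rust
/// Message passed between the driver loop and a frame (hypothetical names).
enum Msg {
    /// Start a new frame with this argument.
    Call(u64),
    /// Hand a finished result back to the caller's frame.
    Return(u64),
}

/// One suspended invocation of `sum_to(n) = n + sum_to(n - 1)`.
struct Frame {
    n: u64,
}

impl Frame {
    fn resume(&mut self, msg: Msg) -> Msg {
        match msg {
            // Base case: nothing left to add.
            Msg::Call(0) => Msg::Return(0),
            // Remember `n`, then request the "recursive call".
            Msg::Call(n) => {
                self.n = n;
                Msg::Call(n - 1)
            }
            // The callee finished: combine its result with our `n`.
            Msg::Return(sub) => Msg::Return(self.n + sub),
        }
    }
}

fn sum_to(n: u64) -> u64 {
    let mut msg = Msg::Call(n);
    let mut stack = vec![Frame { n: 0 }];
    while let Some(frame) = stack.last_mut() {
        msg = frame.resume(msg);
        match msg {
            Msg::Call(_) => stack.push(Frame { n: 0 }),
            Msg::Return(_) => {
                stack.pop();
            }
        }
    }
    match msg {
        Msg::Return(total) => total,
        Msg::Call(_) => unreachable!("the loop only ends after a Return"),
    }
}
```

sum_to(1_000_000) keeps a million frames in the Vec at its deepest point, all of them on the heap; the native call stack stays at constant depth.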
