How to properly use channels for coroutine communication

Hi, I'm new to Rust async/.await. While playing with it, I tried to port the concurrent-prime-sieve example from the Go homepage to Rust, but could not manage to build a solution that runs as fast.

I first tried channel / sync_channel in std::sync::mpsc, but they do not seem to provide an async API and were slow for this task. Then I tried async_std::channel. It works, but it is not as fast as the Go program. I'd appreciate it if anyone could point out where I'm obviously wrong and help improve the code. Thanks!
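
For comparison, a port built on `std::sync::mpsc` has to use one OS thread per stage, because its `send`/`recv` block the calling thread instead of yielding to an executor. The sketch below is an assumption about what such an attempt could look like, not the original code; `primes` is a hypothetical helper name, and `sync_channel(0)` is chosen to get a rendezvous channel that behaves like Go's unbuffered `make(chan int)`.

```rust
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::thread;

// Send 2, 3, 4, ... until the downstream receiver is dropped.
fn generate(tx: SyncSender<usize>) {
    let mut i = 2;
    // send() blocks this OS thread; it returns Err once the receiver is gone.
    while tx.send(i).is_ok() {
        i += 1;
    }
}

// Forward values not divisible by `prime`; stop when either side disconnects.
fn filter(rx: Receiver<usize>, tx: SyncSender<usize>, prime: usize) {
    for i in rx {
        if i % prime != 0 && tx.send(i).is_err() {
            break;
        }
    }
}

// Hypothetical helper: collect the first `n` primes, one thread per filter stage.
fn primes(n: usize) -> Vec<usize> {
    // Capacity 0 = rendezvous channel, like Go's unbuffered chan.
    let (tx, mut rx) = sync_channel::<usize>(0);
    let mut handles = vec![thread::spawn(move || generate(tx))];
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        let prime = rx.recv().expect("generator stopped early");
        out.push(prime);
        let (tx_next, rx_next) = sync_channel::<usize>(0);
        handles.push(thread::spawn(move || filter(rx, tx_next, prime)));
        rx = rx_next;
    }
    // Dropping the last receiver makes each stage's send fail in turn,
    // so every thread eventually exits and can be joined.
    drop(rx);
    for h in handles {
        h.join().unwrap();
    }
    out
}

fn main() {
    for p in primes(10) {
        println!("{}", p);
    }
}
```

Every filter here is a full OS thread with its own stack, so this shape is expected to scale worse than lightweight async tasks as `n` grows.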

The program creates another coroutine for every prime number found; it's mainly about testing coroutine performance rather than building the fastest prime-finding program.
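
Stripped of concurrency, the daisy chain is just a list of filters that each candidate must pass in order; a candidate that survives every stage is prime and becomes the next stage. The sequential sketch below (a hypothetical `first_primes` helper, not part of either program) shows what the chain of coroutines computes, and can be used to check that a concurrent port prints the same sequence.

```rust
// Sequential sketch of the daisy chain: `filters` holds one prime per stage,
// and each candidate is tested against the stages in order, like a value
// flowing through the chain of channels.
fn first_primes(n: usize) -> Vec<usize> {
    let mut filters: Vec<usize> = Vec::with_capacity(n);
    let mut candidate = 2; // what Generate would send next
    while filters.len() < n {
        // all() stops at the first stage that divides the candidate,
        // just as the value would be dropped by that filter.
        if filters.iter().all(|&p| candidate % p != 0) {
            // It reached the end of the chain: it is prime, so add a stage for it.
            filters.push(candidate);
        }
        candidate += 1;
    }
    filters
}

fn main() {
    // prints [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    println!("{:?}", first_primes(10));
}
```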

My results show that the Rust program below is ~2x slower than the Go version when the input n is 5000.

Disclaimer: I'm not here to claim which one is faster; I only want guidance on how to write it properly in Rust. Thanks!

```rust
use async_std::{
    channel::{self, Receiver, Sender},
    task,
};

fn main() {
    // Read n from the first CLI argument, defaulting to 100.
    let n = std::env::args_os()
        .nth(1)
        .and_then(|s| s.into_string().ok())
        .and_then(|s| s.parse().ok())
        .unwrap_or(100);

    task::block_on(async_main(n)).unwrap();
}

async fn async_main(n: usize) -> anyhow::Result<()> {
    let (sender, mut receiver) = channel::bounded::<usize>(2);
    task::spawn(generate(sender));
    for _ in 0..n {
        let prime = receiver.recv().await?;
        println!("{}", prime);
        // Chain a new filter stage behind the current one.
        let (sender_next, receiver_next) = channel::bounded::<usize>(2);
        task::spawn(filter(receiver, sender_next, prime));
        receiver = receiver_next;
    }
    Ok(())
}

// Send 2, 3, 4, ...; `?` returns an error once the receiver is dropped.
async fn generate(sender: Sender<usize>) -> anyhow::Result<()> {
    let mut i = 2;
    loop {
        sender.send(i).await?;
        i += 1;
    }
}

// Forward values not divisible by `prime`. The loop only exits via `?`
// when a channel closes, so no trailing `Ok(())` is needed.
async fn filter(
    receiver: Receiver<usize>,
    sender: Sender<usize>,
    prime: usize,
) -> anyhow::Result<()> {
    loop {
        let i = receiver.recv().await?;
        if i % prime != 0 {
            sender.send(i).await?;
        }
    }
}
```

Original Go version:

```go
// A concurrent prime sieve from the official Go example

package main

import (
	"fmt"
	"os"
	"strconv"
)

// Send the sequence 2, 3, 4, ... to channel 'ch'.
func Generate(ch chan<- int) {
	for i := 2; ; i++ {
		ch <- i // Send 'i' to channel 'ch'.
	}
}

// Copy the values from channel 'in' to channel 'out',
// removing those divisible by 'prime'.
func Filter(in <-chan int, out chan<- int, prime int) {
	for {
		i := <-in // Receive value from 'in'.
		if i%prime != 0 {
			out <- i // Send 'i' to 'out'.
		}
	}
}

// The prime sieve: Daisy-chain Filter processes.
func main() {
	n := 1000
	if len(os.Args) > 1 {
		if _n, err := strconv.Atoi(os.Args[1]); err == nil {
			n = _n
		}
	}
	ch := make(chan int) // Create a new channel.
	go Generate(ch)      // Launch Generate goroutine.
	for i := 0; i < n; i++ {
		prime := <-ch
		fmt.Println(prime)
		ch1 := make(chan int)
		go Filter(ch, ch1, prime)
		ch = ch1
	}
}
```

I rewrote it to use Tokio instead of async-std, and it became twice as fast on my laptop.

```console
$ time ./async-std 5000 > /dev/null
./async-std 5000 > /dev/null  9.03s user 0.26s system 760% cpu 1.221 total
$ time ./tokio 5000 > /dev/null
./tokio 5000 > /dev/null  4.66s user 0.61s system 775% cpu 0.680 total
```

```rust
use tokio::sync::mpsc::{self, Receiver, Sender};

fn main() {
    // Read n from the first CLI argument, defaulting to 100.
    let n = std::env::args_os()
        .nth(1)
        .and_then(|s| s.into_string().ok())
        .and_then(|s| s.parse().ok())
        .unwrap_or(100);

    async_main(n).unwrap();
}

// #[tokio::main] builds a runtime and blocks on this function's body,
// so it can be called from the synchronous `main` above.
#[tokio::main]
async fn async_main(n: usize) -> anyhow::Result<()> {
    let (sender, mut receiver) = mpsc::channel::<usize>(2);
    tokio::spawn(generate(sender));
    for _i in 0..n {
        let prime = receiver.recv().await.unwrap();
        println!("{}", prime);
        // Chain a new filter stage behind the current one.
        let (sender_next, receiver_next) = mpsc::channel::<usize>(2);
        tokio::spawn(filter(receiver, sender_next, prime));
        receiver = receiver_next;
    }
    Ok(())
}

// Send 2, 3, 4, ... and return once the first filter's receiver is dropped.
async fn generate(sender: Sender<usize>) -> anyhow::Result<()> {
    let mut i = 2;
    while sender.send(i).await.is_ok() {
        i += 1;
    }
    Ok(())
}

// Forward values not divisible by `prime`; return when either neighbour hangs up.
async fn filter(
    mut receiver: Receiver<usize>,
    sender: Sender<usize>,
    prime: usize,
) -> anyhow::Result<()> {
    while let Some(i) = receiver.recv().await {
        if i % prime != 0 {
            if sender.send(i).await.is_err() {
                return Ok(());
            }
        }
    }
    Ok(())
}
```
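
The `is_ok()`/`is_err()` checks in this version let each task return `Ok(())` as soon as the next stage's receiver is dropped, instead of treating the closed channel as an error. The same shutdown pattern can be seen with std's blocking channel in the sketch below, where `sync_channel` and a thread stand in for the Tokio channel and task; `take_then_hang_up` is a hypothetical name used only for illustration.

```rust
use std::sync::mpsc::sync_channel;
use std::thread;

// Hypothetical helper: take `k` values from a generator, hang up, and report
// (values received, first value the generator failed to send).
fn take_then_hang_up(k: usize) -> (Vec<usize>, usize) {
    let (tx, rx) = sync_channel::<usize>(0);
    let producer = thread::spawn(move || {
        let mut i = 2;
        // Same shape as the Tokio `generate`: stop as soon as send() fails.
        while tx.send(i).is_ok() {
            i += 1;
        }
        i
    });

    let got: Vec<usize> = rx.iter().take(k).collect();
    drop(rx); // disconnect: the producer's pending or next send() returns Err

    // Because the loop ends on Err, the thread finishes and join() returns.
    let stopped_at = producer.join().unwrap();
    (got, stopped_at)
}

fn main() {
    let (got, stopped_at) = take_then_hang_up(3);
    println!("received {:?}; generator stopped at {}", got, stopped_at);
}
```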

Thanks! I also get a ~2x performance boost with the change you suggested.