I've had some success. Here is a standalone implementation of a channel that can carry values of an arbitrary (serializable) type; I hope it helps someone:
[dependencies]
derive-error = "0.0.4"
futures = { version = "0.3" }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "0.2", features = ["full"] }
tokio-serde = { version = "0.6", features = ["bincode"] }
tokio-util = { version = "0.2", features = ["codec"] }
use derive_error::Error;
use futures::prelude::*;
use serde::{Deserialize, Serialize};
use std::net::Ipv4Addr;
use tokio::net::{TcpListener, TcpStream};
use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio_serde::SymmetricallyFramed;
use tokio_serde::formats::*;
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
// Error type returned by every fallible channel operation.
//
// NOTE(review): plain `//` comments are used here on purpose; derive-error may
// use `///` doc comments on variants as the error description, so adding them
// could change runtime output — confirm against the derive-error docs before
// converting these to doc comments.
#[derive(Debug, Error)]
pub enum Error {
    // Wraps a `std::io::Error` from socket or framed I/O (used by `connect`,
    // `accept` and `Sender::send` via `map_err(Error::IO)`).
    IO(std::io::Error),
    // Generic read failure that carries no detail about the underlying cause.
    ReaderError,
}
/// Deserializing stream of `T` values read from the read half of a TCP socket.
///
/// Incoming bytes are split into length-delimited frames, and each frame body is
/// decoded with bincode.
type SymmetricalReader<'a, T> =
    SymmetricallyFramed<FramedRead<ReadHalf<'a>, LengthDelimitedCodec>, T, SymmetricalBincode<T>>;

/// Serializing sink of `T` values written to the write half of a TCP socket.
///
/// Each item is bincode-encoded and emitted as one length-delimited frame.
type SymmetricalWriter<'a, T> =
    SymmetricallyFramed<FramedWrite<WriteHalf<'a>, LengthDelimitedCodec>, T, SymmetricalBincode<T>>;
pub struct Receiver<'a, T> {
pub reader: SymmetricalReader<'a, T>,
}
impl<'a, T> Receiver<'a, T> where
SymmetricalReader<'a, T> : TryStream<Ok=T> + Unpin,
{
pub async fn recv(&mut self) -> Result<Option<T>, Error> {
// TODO: ReaderError should capture the error returned by `try_next`.
if let Ok(msg) = self.reader.try_next().await {
Ok(msg)
} else {
Err(Error::ReaderError)
}
}
}
pub struct Sender<'a, T> {
pub writer: SymmetricalWriter<'a, T>,
}
impl<T> Sender<'_, T> where
T: for<'a> Deserialize<'a> + Serialize + Unpin
{
pub async fn send(&mut self, item: T) -> Result<(), Error> {
Ok(self.writer.send(item).await.map_err(Error::IO)?)
}
}
pub struct Channel<T> {
socket: TcpStream,
ghost: std::marker::PhantomData<T>,
}
impl<T> Channel<T> where
T: for<'a> Deserialize<'a> + Serialize,
{
pub async fn connect(address: &Ipv4Addr, port: u16) ->
Result<Channel<T>, Error>
{
let address = format!("{}:{}", address, port);
let socket = TcpStream::connect(&address).await.map_err(Error::IO)?;
println!("connection established: {:?}", address);
Ok(Channel{ socket, ghost: Default::default() })
}
pub async fn accept(address: &Ipv4Addr, port: u16) ->
Result<Channel<T>, Error>
{
let address = format!("{}:{}", address, port);
let mut listener = TcpListener::bind(&address).await.map_err(Error::IO)?;
let (socket, address) = listener.accept().await.map_err(Error::IO)?;
println!("connection accepted: {:?}", address);
Ok(Channel{ socket, ghost: Default::default() })
}
pub fn split(&mut self) ->
(Sender<'_, T>, Receiver<'_, T>)
{
let (reader, writer) = self.socket.split();
let reader: FramedRead<
ReadHalf,
LengthDelimitedCodec,
> = FramedRead::new(reader, LengthDelimitedCodec::new());
let reader: SymmetricalReader<T> = SymmetricallyFramed::new(
reader, SymmetricalBincode::default());
let writer: FramedWrite<
WriteHalf,
LengthDelimitedCodec,
> = FramedWrite::new(writer, LengthDelimitedCodec::new());
let writer: SymmetricalWriter<T> = SymmetricallyFramed::new(
writer, SymmetricalBincode::default());
(Sender{ writer }, Receiver{ reader })
}
}
#[cfg(test)]
mod tests {
    use crate::channel::*;
    use std::str::FromStr;
    use std::time::Duration;

    /// Loopback port shared by both ends of the test connection.
    const PORT: u16 = 21000;

    /// Connects to the test listener.
    ///
    /// The accepting task runs concurrently, so the first attempts can be refused
    /// before its listener is bound. Retry briefly instead of failing spuriously.
    async fn connect_with_retry(address: &Ipv4Addr) -> Channel<String> {
        const MAX_ATTEMPTS: u32 = 50;
        let mut attempt = 1;
        loop {
            let outcome = Channel::connect(address, PORT).await;
            match outcome {
                Ok(channel) => return channel,
                Err(err) if attempt >= MAX_ATTEMPTS => {
                    panic!("failed to establish connection: {:?}", err)
                }
                Err(_) => {
                    attempt += 1;
                    tokio::time::delay_for(Duration::from_millis(20)).await;
                }
            }
        }
    }

    #[test]
    fn send_recv() {
        tokio::runtime::Runtime::new()
            .expect("failed to create Tokio runtime")
            .block_on(async {
                let handle_1 = tokio::spawn(async {
                    let address =
                        Ipv4Addr::from_str("127.0.0.1").expect("failed to construct address");
                    let mut channel: Channel<String> = Channel::accept(&address, PORT)
                        .await
                        .expect("failed to accept connection");
                    let (mut sender, mut receiver) = channel.split();
                    // Send message:
                    sender.send(String::from("123")).await.unwrap();
                    // Receive message:
                    let msg = receiver.recv().await.unwrap();
                    assert_eq!(msg, Some(String::from("321")));
                    // Send message:
                    sender.send(String::from("456")).await.unwrap();
                    // Receive message:
                    let msg = receiver.recv().await.unwrap();
                    assert_eq!(msg, Some(String::from("654")));
                });
                let handle_2 = tokio::spawn(async {
                    let address =
                        Ipv4Addr::from_str("127.0.0.1").expect("failed to construct address");
                    let mut channel = connect_with_retry(&address).await;
                    let (mut sender, mut receiver) = channel.split();
                    // Receive message:
                    let msg = receiver.recv().await.unwrap();
                    assert_eq!(msg, Some(String::from("123")));
                    // Send message:
                    sender.send(String::from("321")).await.unwrap();
                    // Receive message:
                    let msg = receiver.recv().await.unwrap();
                    assert_eq!(msg, Some(String::from("456")));
                    // Send message:
                    sender.send(String::from("654")).await.unwrap();
                });
                handle_2.await.unwrap();
                handle_1.await.unwrap();
            });
    }
}
The unit test passes and I'm happy with this implementation. The only thing I'm stuck on is how to capture/wrap the error returned by `self.reader.try_next().await`. I'm getting this compiler error:
error[E0277]: `?` couldn't convert the error to `channel::Error`
--> common/src/channel.rs:35:40
|
35 | Ok(self.reader.try_next().await?)
| ^ the trait `std::convert::From<<tokio_serde::Framed<tokio_util::codec::framed_read::FramedRead<tokio::net::tcp::split::ReadHalf<'a>, tokio_util::codec::length_delimited::LengthDelimitedCodec>, T, T, tokio_serde::formats::bincode::Bincode<T, T>> as futures_core::stream::TryStream>::Error>` is not implemented for `channel::Error`
|
= note: the question mark operation (`?`) implicitly performs a conversion on the error value using the `From` trait
= help: the following implementations were found:
<channel::Error as std::convert::From<std::io::Error>>
= note: required by `std::convert::From::from`
I'd appreciate any advice or suggestions. Thank you!