To make `TcpListener` and `TlsAcceptor` easily interchangeable, I added a wrapper around `TlsAcceptor` that allows it to be polled like a stream. In my tests the wrapper appears to work, but I would like a critique from the async experts here (you know who you are!) in case there is a better way of doing this. In particular, please review the `Stream` implementation for `TlsListener`.
pub struct TlsListener {
inner: TcpListener,
tls_acceptor: Arc<TlsAcceptor>,
queue: Vec<Pin<Box<dyn Future<Output=Result<TlsStream<TcpStream>, tokio_native_tls::native_tls::Error>>>>>
}
impl TlsListener {
pub fn new(inner: TcpListener, identity: Identity) -> std::io::Result<Self> {
Ok(Self { inner, tls_acceptor: Arc::new(TlsAcceptor::from(tokio_native_tls::native_tls::TlsAcceptor::new(identity).map_err(|err| std::io::Error::new(std::io::ErrorKind::ConnectionRefused, err))?)), queue: Vec::new()})
}
pub fn new_pkcs<P: AsRef<Path>, T: AsRef<str>>(path: P, password: T, inner: TcpListener) -> std::io::Result<Self> {
let identity = Self::load_tls_pkcs(path, password).map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err.into_string()))?;
Self::new(inner, identity)
}
/// Given a path and password, returns the asymmetric crypto identity
pub fn load_tls_pkcs<P: AsRef<Path>, T: AsRef<str>>(path: P, password: T) -> Result<Identity, NetworkError> {
let bytes = std::fs::read(path).map_err(|err| NetworkError::Generic(err.to_string()))?;
Identity::from_pkcs12(&bytes, password.as_ref()).map_err(|err| NetworkError::Generic(err.to_string()))
}
fn poll_future(future: &mut Pin<Box<dyn Future<Output=Result<TlsStream<TcpStream>, tokio_native_tls::native_tls::Error>>>>, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> {
future.as_mut().poll(cx).map(|r| Some(r.map(|stream| {
let peer_addr = stream.get_ref().get_ref().get_ref().peer_addr().unwrap();
(stream, peer_addr)
}).map_err(|err| std::io::Error::new(std::io::ErrorKind::ConnectionRefused, err))))
}
}
impl Stream for TlsListener {
type Item = std::io::Result<(TlsStream<TcpStream>, SocketAddr)>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let Self {
inner,
tls_acceptor,
queue
} = &mut *self;
// on first poll, there won't be any futures loaded. We skip this immediately on first-poll
for (idx, future) in queue.iter_mut().enumerate() {
return match futures::ready!(Self::poll_future(future, cx)) {
Some(Ok((stream, peer_addr))) => {
let _ = queue.remove(idx);
Poll::Ready(Some(Ok((stream, peer_addr))))
}
Some(Err(err)) => {
let _ = queue.remove(idx);
Poll::Ready(Some(Err(err)))
}
None => {
let _ = queue.remove(idx);
log::error!("TlsListener: Polled none");
Poll::Ready(Some(Err(std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "Polled none"))))
}
}
}
// Now, poll the TcpListener
match futures::ready!(Pin::new(inner).poll_accept(cx)) {
Ok((stream, _peer_addr)) => {
// We received a TcpStream. Upgrade it to a TlsStream by storing a pollable future
let tls_acceptor = tls_acceptor.clone();
let future = async move {
tls_acceptor.accept(stream).await
};
let mut future = Box::pin(future) as Pin<Box<dyn Future<Output=Result<TlsStream<TcpStream>, tokio_native_tls::native_tls::Error>>>>;
// poll the future once to register any internal wakers
let poll_res = Self::poll_future(&mut future, cx);
queue.push(future);
poll_res
}
Err(err) => {
log::error!("TLS Listener error: {:?}", err);
Poll::Ready(None)
}
}
}
}