How do I read the requested domain from a TCP stream in Tokio?

I need to dynamically load the appropriate TLS certificate for each domain name served by a single web server.

How do I do this in Tokio? My code looks roughly like this:

```rust
extern crate tokio;
extern crate tokio_rustls;

use std::fs::File;
use std::io::BufReader;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use tokio::io;
use tokio::net::{TcpListener, TcpStream};
use tokio::prelude::{Future, Stream};
use tokio_rustls::{
    rustls::{
        internal::pemfile::{certs, pkcs8_private_keys},
        Certificate, NoClientAuth, PrivateKey, ServerConfig,
    },
    TlsAcceptor,
};

fn main() {
    let addr = "127.0.0.1:8081".to_socket_addrs().unwrap().next().unwrap();

    let socket = TcpListener::bind(&addr).unwrap();
    let done = socket.incoming().for_each(move |stream| {
        // TODO: read the domain here, before the TLS handshake starts
        let domain = somehow_read_the_requested_domain_from(&stream);

        // One certificate/key pair per domain directory
        let cert_file = format!("/home/user/{}/cert1.pem", domain);
        let key_file = format!("/home/user/{}/privkey1.pem", domain);

        let mut config = ServerConfig::new(NoClientAuth::new());
        config
            .set_single_cert(load_certs(&cert_file), load_keys(&key_file).remove(0))
            .expect("invalid key or certificate");

        let acceptor = TlsAcceptor::from(Arc::new(config));
        let addr = stream.peer_addr().ok();
        let done = acceptor
            .accept(stream)
            .and_then(move |stream| {
                // The body length depends on the domain, so compute
                // Content-length instead of hard-coding 12
                let body = format!("Hello world {}!", domain);
                io::write_all(
                    stream,
                    format!(
                        "HTTP/1.0 200 ok\r\n\
                         Connection: close\r\n\
                         Content-length: {}\r\n\
                         \r\n\
                         {}",
                        body.len(),
                        body
                    ),
                )
            })
            .and_then(|(stream, _)| io::flush(stream))
            .map(move |_| println!("Accept: {:?}", addr))
            .map_err(move |err| println!("Error: {:?} - {:?}", err, addr));
        tokio::spawn(done);

        Ok(())
    });

    tokio::run(done.map_err(drop));
}

fn load_certs(path: &str) -> Vec<Certificate> {
    certs(&mut BufReader::new(File::open(path).unwrap())).unwrap()
}

fn load_keys(path: &str) -> Vec<PrivateKey> {
    pkcs8_private_keys(&mut BufReader::new(File::open(path).unwrap())).unwrap()
}

// Placeholder: actually read the domain name from the TCP stream (or listener)
fn somehow_read_the_requested_domain_from(_stream: &TcpStream) -> String {
    "domain.name".to_string()
}
```

Take a look at https://docs.rs/rustls/0.16.0/rustls/struct.ResolvesServerCertUsingSNI.html
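`ResolvesServerCertUsingSNI` moves the per-domain choice into the TLS handshake: you register one certificate/key pair per host name up front, and rustls looks up the name the client sent via SNI while processing the ClientHello, so the accept loop never needs to know the domain. The std-only sketch below models that lookup; it is not the rustls API. `SniCertTable`, `CertPaths` and `add_domain` are hypothetical names, the table stores file paths where rustls would hold parsed certificates and keys, and the directory layout is the one from the question.

```rust
use std::collections::HashMap;
use std::path::PathBuf;

/// Hypothetical per-domain entry; rustls would store a parsed cert chain + signing key.
#[derive(Debug)]
struct CertPaths {
    cert: PathBuf,
    key: PathBuf,
}

/// Hypothetical stand-in for an SNI-keyed certificate resolver (NOT the rustls type).
#[derive(Default)]
struct SniCertTable {
    by_name: HashMap<String, CertPaths>,
}

impl SniCertTable {
    /// Registers a domain using the question's layout:
    /// `<base>/<domain>/cert1.pem` and `<base>/<domain>/privkey1.pem`.
    fn add_domain(&mut self, base: &str, domain: &str) {
        let dir = PathBuf::from(base).join(domain);
        self.by_name.insert(
            // DNS names are case-insensitive, so keys are stored lowercased
            domain.to_ascii_lowercase(),
            CertPaths {
                cert: dir.join("cert1.pem"),
                key: dir.join("privkey1.pem"),
            },
        );
    }

    /// Picks the entry for the host name the client sent via SNI.
    /// `None` models a client that sent no SNI extension at all.
    fn resolve(&self, sni: Option<&str>) -> Option<&CertPaths> {
        let name = sni?.to_ascii_lowercase();
        self.by_name.get(&name)
    }
}

fn main() {
    let mut table = SniCertTable::default();
    table.add_domain("/home/user", "domain.name");
    table.add_domain("/home/user", "other.example");

    match table.resolve(Some("Domain.Name")) {
        Some(p) => println!("cert: {}, key: {}", p.cert.display(), p.key.display()),
        None => println!("no certificate for this name"),
    }
}
```

With rustls 0.16 itself, the equivalent appears to be: load each pair into a `sign::CertifiedKey`, `add` it to a `ResolvesServerCertUsingSNI` under its domain, and set that resolver as the `ServerConfig`'s `cert_resolver`. A single `TlsAcceptor` can then be built once, outside the accept loop.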

This still doesn't let me read the requested domain myself, but it does the job. I don't think I actually need to know the domain name: it's up to the resolver to pick the right certificate for me.
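If you do need the name outside of rustls (for logging, or to route a connection before the handshake), you can read it yourself: the client sends it in plaintext in the `server_name` extension (RFC 6066) of its first TLS message, the ClientHello. The std-only sketch below extracts it from a byte slice. It assumes the whole ClientHello fits in the first TLS record and in the bytes you have (common, but not guaranteed), and it does no validation beyond bounds checks. `Reader`, `parse_sni` and `client_hello_with_sni` are hypothetical helpers, not part of rustls or Tokio; the builder only produces a synthetic ClientHello for the demo.

```rust
/// Minimal cursor over a byte slice; every read returns None if input runs out.
struct Reader<'a>(&'a [u8]);

impl<'a> Reader<'a> {
    /// Consumes `n` bytes, or returns None if fewer remain.
    fn take(&mut self, n: usize) -> Option<&'a [u8]> {
        if self.0.len() < n {
            return None;
        }
        let (head, tail) = self.0.split_at(n);
        self.0 = tail;
        Some(head)
    }
    /// Reads an `n`-byte big-endian unsigned integer (TLS uses 1, 2 or 3 bytes).
    fn uint(&mut self, n: usize) -> Option<usize> {
        Some(self.take(n)?.iter().fold(0usize, |acc, &b| (acc << 8) | b as usize))
    }
    /// Reads an `n`-byte length prefix and returns a reader over that many bytes.
    fn block(&mut self, n: usize) -> Option<Reader<'a>> {
        let len = self.uint(n)?;
        Some(Reader(self.take(len)?))
    }
}

/// Returns the SNI host name from the first TLS record, if present.
fn parse_sni(data: &[u8]) -> Option<String> {
    let mut rec = Reader(data);
    if rec.uint(1)? != 0x16 {
        return None; // record type 22 = handshake
    }
    rec.take(2)?; // legacy record version
    let mut hs = rec.block(2)?; // record payload
    if hs.uint(1)? != 0x01 {
        return None; // handshake type 1 = ClientHello
    }
    let mut hello = hs.block(3)?;
    hello.take(2 + 32)?; // client_version + random
    hello.block(1)?; // session_id
    hello.block(2)?; // cipher_suites
    hello.block(1)?; // compression_methods
    let mut exts = hello.block(2)?;
    while !exts.0.is_empty() {
        let ext_type = exts.uint(2)?;
        let mut body = exts.block(2)?;
        if ext_type != 0 {
            continue; // only extension type 0 (server_name) matters here
        }
        let mut names = body.block(2)?; // server_name_list
        while !names.0.is_empty() {
            let name_type = names.uint(1)?;
            let name = names.block(2)?;
            if name_type == 0 {
                return String::from_utf8(name.0.to_vec()).ok(); // 0 = host_name
            }
        }
    }
    None
}

/// Builds a synthetic, minimal ClientHello record carrying `host` as SNI (demo only).
fn client_hello_with_sni(host: &str) -> Vec<u8> {
    // Prefixes `body` with its length as an `n`-byte big-endian integer.
    fn prefixed(n: usize, body: &[u8]) -> Vec<u8> {
        let mut out = (body.len() as u32).to_be_bytes()[4 - n..].to_vec();
        out.extend_from_slice(body);
        out
    }
    let mut entry = vec![0u8]; // name_type 0 = host_name
    entry.extend(prefixed(2, host.as_bytes()));
    let mut ext = vec![0u8, 0]; // extension type 0 = server_name
    ext.extend(prefixed(2, &prefixed(2, &entry)));
    let mut hello = vec![0x03, 0x03]; // client_version
    hello.extend_from_slice(&[0u8; 32]); // random
    hello.extend(prefixed(1, &[])); // empty session_id
    hello.extend(prefixed(2, &[0x13, 0x01])); // one cipher suite
    hello.extend(prefixed(1, &[0])); // null compression
    hello.extend(prefixed(2, &ext)); // extensions
    let mut hs = vec![0x01]; // ClientHello
    hs.extend(prefixed(3, &hello));
    let mut record = vec![0x16, 0x03, 0x01]; // handshake record, legacy version
    record.extend(prefixed(2, &hs));
    record
}

fn main() {
    let bytes = client_hello_with_sni("domain.name");
    println!("{:?}", parse_sni(&bytes)); // Some("domain.name")
    println!("{:?}", parse_sni(b"GET / HTTP/1.1\r\n")); // None: not a TLS record
}
```

To feed it real bytes without breaking the handshake, peek at the socket instead of reading from it: std's `TcpStream::peek` copies incoming bytes without consuming them (Tokio's `TcpStream` has a `peek` method as well in recent versions), so the TLS acceptor still sees the complete ClientHello afterwards.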