How do I read the requested domain from a TCP stream in tokio?

I have a case where I need to dynamically load the appropriate TLS certificate for each domain name served by a single web server.

How do I do this in Tokio?
The code is something like:

```rust
extern crate tokio;
extern crate tokio_rustls;

use std::fs::File;
use std::io::BufReader;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use tokio::io;
use tokio::net::{TcpListener, TcpStream};
use tokio::prelude::{Future, Stream};
use tokio_rustls::{
    rustls::{
        internal::pemfile::{certs, pkcs8_private_keys},
        Certificate, NoClientAuth, PrivateKey, ServerConfig,
    },
    TlsAcceptor,
};


fn main() {
    let addr = "127.0.0.1:8081".to_socket_addrs().unwrap().next().unwrap();

    let socket = TcpListener::bind(&addr).unwrap();
    let done = socket.incoming().for_each(move |stream| {
        //TODO: read the domain here
        let domain = somehow_read_the_requested_domain_from(&stream);

        let cert_file = format!("/home/user/{}/cert1.pem", domain);
        let key_file = format!("/home/user/{}/privkey1.pem", domain);

        let mut config = ServerConfig::new(NoClientAuth::new());
        config
            .set_single_cert(load_certs(&cert_file), load_keys(&key_file).remove(0))
            .expect("invalid key or certificate");

        let config = TlsAcceptor::from(Arc::new(config));
        let addr = stream.peer_addr().ok();
        let done = config
            .accept(stream)
            .and_then(move |stream| {
                // Content-length must match the body, so build the body first.
                let body = format!("Hello world {}!", domain);
                io::write_all(
                    stream,
                    format!(
                        "HTTP/1.0 200 OK\r\n\
                         Connection: close\r\n\
                         Content-length: {}\r\n\
                         \r\n\
                         {}",
                        body.len(),
                        body
                    ),
                )
            })
            .and_then(|(stream, _)| io::flush(stream))
            .map(move |_| println!("Accept: {:?}", addr))
            .map_err(move |err| println!("Error: {:?} - {:?}", err, addr));
        tokio::spawn(done);

        Ok(())
    });

    tokio::run(done.map_err(drop));
}

fn load_certs(path: &str) -> Vec<Certificate> {
    certs(&mut BufReader::new(File::open(path).unwrap())).unwrap()
}

fn load_keys(path: &str) -> Vec<PrivateKey> {
    pkcs8_private_keys(&mut BufReader::new(File::open(path).unwrap())).unwrap()
}

// ACTUALLY READ THE DOMAIN NAME FROM TCP STREAM OR TCP LISTENER
fn somehow_read_the_requested_domain_from(_stream: &TcpStream) -> String {
    "domain.name".to_string()
}
```

Take a look at `rustls::ResolvesServerCertUsingSNI`.
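Conceptually, such a resolver keeps every certificate loaded up front, keyed by domain name, and rustls consults it during the handshake with the name the client sent in its SNI extension. The std-only sketch below models only that lookup, so it runs without any crates. `SniCertResolver` and `CertEntry` are hypothetical names, and `CertEntry` stores file paths where the real resolver would store a loaded `CertifiedKey`.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for a loaded certificate chain and private key.
// A real resolver would hold a rustls `CertifiedKey`; this only records
// file paths so the sketch runs with the standard library alone.
struct CertEntry {
    cert_file: String,
    key_file: String,
}

// Hypothetical model of SNI-keyed certificate selection.
struct SniCertResolver {
    by_name: HashMap<String, CertEntry>,
}

impl SniCertResolver {
    fn new() -> Self {
        SniCertResolver { by_name: HashMap::new() }
    }

    // Register one domain once at startup, using the question's directory layout.
    fn add(&mut self, domain: &str) {
        let entry = CertEntry {
            cert_file: format!("/home/user/{}/cert1.pem", domain),
            key_file: format!("/home/user/{}/privkey1.pem", domain),
        };
        self.by_name.insert(domain.to_string(), entry);
    }

    // Per handshake: look up the name the client sent via SNI, if any.
    // `None` (no SNI, or an unknown name) means there is no certificate to offer.
    fn resolve(&self, sni: Option<&str>) -> Option<&CertEntry> {
        sni.and_then(|name| self.by_name.get(name))
    }
}

fn main() {
    let mut resolver = SniCertResolver::new();
    resolver.add("example.com");
    resolver.add("example.org");

    if let Some(entry) = resolver.resolve(Some("example.org")) {
        // Prints: /home/user/example.org/cert1.pem /home/user/example.org/privkey1.pem
        println!("{} {}", entry.cert_file, entry.key_file);
    }
    assert!(resolver.resolve(Some("unknown.test")).is_none());
    assert!(resolver.resolve(None).is_none());
}
```

The design difference from the original code is that certificates are loaded once at startup instead of per connection, so each handshake is only a map lookup.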


This still doesn't let me read the requested domain myself, but it does the job. I don't think I actually need to know the domain name; picking the right certificate is up to the resolver.
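For anyone who does need the name itself: the requested domain travels in the server_name (SNI) extension of the TLS ClientHello, the first record the client sends. Below is a minimal std-only sketch of pulling it out of those bytes. It assumes the complete ClientHello fits in one TLS record and has already been buffered; on a live connection the bytes would have to be peeked (for example with `std::net::TcpStream::peek`) rather than consumed, so the TLS acceptor still sees them. `Reader`, `sni_from_client_hello` and `client_hello_with_sni` are hypothetical names, not tokio or rustls APIs.

```rust
// Bounds-checked cursor over a byte slice (hypothetical helper).
struct Reader<'a> { buf: &'a [u8], pos: usize }

impl<'a> Reader<'a> {
    fn new(buf: &'a [u8]) -> Self { Reader { buf, pos: 0 } }
    fn done(&self) -> bool { self.pos >= self.buf.len() }
    fn bytes(&mut self, n: usize) -> Option<&'a [u8]> {
        let buf: &'a [u8] = self.buf;
        let s = buf.get(self.pos..self.pos.checked_add(n)?)?;
        self.pos += n;
        Some(s)
    }
    // Big-endian unsigned integer stored in n bytes.
    fn uint(&mut self, n: usize) -> Option<usize> {
        Some(self.bytes(n)?.iter().fold(0, |acc, &b| (acc << 8) | b as usize))
    }
    // n-byte length prefix followed by that many bytes.
    fn vec(&mut self, n: usize) -> Option<&'a [u8]> {
        let len = self.uint(n)?;
        self.bytes(len)
    }
}

// Returns the SNI host_name from a single-record ClientHello, or None.
fn sni_from_client_hello(record: &[u8]) -> Option<String> {
    let mut rec = Reader::new(record);
    if rec.uint(1)? != 22 { return None; } // record type 22 = handshake
    rec.bytes(2)?;                          // record version
    let mut hs = Reader::new(rec.vec(2)?);
    if hs.uint(1)? != 1 { return None; }   // handshake type 1 = ClientHello
    let mut body = Reader::new(hs.vec(3)?);
    body.bytes(2 + 32)?; // client_version + random
    body.vec(1)?;        // session_id
    body.vec(2)?;        // cipher_suites
    body.vec(1)?;        // compression_methods
    let mut exts = Reader::new(body.vec(2)?);
    while !exts.done() {
        let ext_type = exts.uint(2)?;
        let data = exts.vec(2)?;
        if ext_type != 0 { continue; } // 0 = server_name extension
        let mut list = Reader::new(Reader::new(data).vec(2)?);
        while !list.done() {
            let name_type = list.uint(1)?;
            let name = list.vec(2)?;
            if name_type == 0 { // 0 = host_name
                return String::from_utf8(name.to_vec()).ok();
            }
        }
    }
    None
}

// Hypothetical test helper: a minimal ClientHello record carrying one SNI name.
fn client_hello_with_sni(host: &str) -> Vec<u8> {
    fn with_len(n: usize, data: &[u8]) -> Vec<u8> {
        let mut v = (data.len() as u32).to_be_bytes()[4 - n..].to_vec();
        v.extend_from_slice(data);
        v
    }
    let entry = [&[0u8][..], &with_len(2, host.as_bytes())[..]].concat();
    let ext = [&[0u8, 0][..], &with_len(2, &with_len(2, &entry))[..]].concat();
    let mut body = vec![0x03, 0x03];                      // client_version
    body.extend_from_slice(&[0u8; 32]);                   // random
    body.push(0);                                         // empty session_id
    body.extend_from_slice(&with_len(2, &[0x13, 0x01]));  // one cipher suite
    body.extend_from_slice(&with_len(1, &[0]));           // null compression
    body.extend_from_slice(&with_len(2, &ext));           // extensions
    let hs = [&[1u8][..], &with_len(3, &body)[..]].concat();
    [&[22u8, 0x03, 0x01][..], &with_len(2, &hs)[..]].concat()
}

fn main() {
    let record = client_hello_with_sni("example.com");
    println!("{:?}", sni_from_client_hello(&record)); // Some("example.com")
}
```

With rustls, a certificate resolver makes this unnecessary, since the library parses the ClientHello itself.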
