Reading tokio TcpStream without mutating it

So, the problem is that my server needs to handle two types of connections: basic HTTP requests and WebSocket handshakes. I'm using the tokio-tungstenite library to perform WS handshakes. It takes ownership of the entire TcpStream and converts it into a WebSocketStream. But I also need to handle normal HTTP requests. To do that, I want to read just the HTTP header from the TcpStream and, if it looks like the start of a WS handshake, pass the TcpStream (unmodified) to the tungstenite library. Otherwise, handle the HTTP request as normal.
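What "looks like the start of a WS handshake" means is defined by RFC 6455's opening handshake: the client sends a GET request carrying, among others, an Upgrade: websocket header and a Sec-WebSocket-Key header. Below is a minimal std-only sketch of that check, assuming the raw request head is already available as a byte slice (getting it without consuming the stream is the separate question in this thread). The function name and the exact set of headers tested are illustrative choices, not part of tungstenite or any other library.

```rust
/// Heuristic check (illustrative, not a library API): does this raw HTTP request head
/// look like a WebSocket opening handshake? Requires a GET request line, an
/// `Upgrade: websocket` header and a non-empty `Sec-WebSocket-Key` header.
fn looks_like_websocket_handshake(head: &[u8]) -> bool {
    let text = String::from_utf8_lossy(head);
    // Only the header block matters; ignore anything after the blank line.
    let header_block = text.split("\r\n\r\n").next().unwrap_or("");
    let mut lines = header_block.split("\r\n");

    let is_get = lines
        .next()
        .map_or(false, |request_line| request_line.starts_with("GET "));

    let mut has_upgrade = false;
    let mut has_key = false;
    for line in lines {
        if let Some((name, value)) = line.split_once(':') {
            let (name, value) = (name.trim(), value.trim());
            // Header names and the "websocket" token are case-insensitive.
            if name.eq_ignore_ascii_case("upgrade") && value.eq_ignore_ascii_case("websocket") {
                has_upgrade = true;
            } else if name.eq_ignore_ascii_case("sec-websocket-key") && !value.is_empty() {
                has_key = true;
            }
        }
    }
    is_get && has_upgrade && has_key
}

fn main() {
    let ws = b"GET /chat HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\n\
Connection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\r\n";
    let plain = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n";
    println!("{}", looks_like_websocket_handshake(ws)); // true
    println!("{}", looks_like_websocket_handshake(plain)); // false
}
```

A real server would also have to cope with a header that arrives split across several reads; this sketch only inspects whatever bytes it is given.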

Here's a sample of the code to give some clarity:

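use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio_tungstenite::accept_async;

// State and Logger are this application's own types.
async fn handle_connection(mut stream: TcpStream, state: Arc<State>) {
  // TODO: read without mutation somehow. AsyncReadExt::read removes the bytes
  // from the stream, so accept_async below would no longer see the handshake.
  let mut buffer = [0u8; 1024];
  if let Ok(size) = stream.read(&mut buffer).await {
    /* handle &buffer[..size] */
  }

  // TODO: only in case when it looks like WebSocket handshake
  let ws = match accept_async(stream).await {
    Ok(websocket) => websocket,
    Err(error) => {
      Logger::warning(format!("Failed to accept WebSocket connection: {}", error));
      return;
    }
  };
  /* handle WS connection using ws */
}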

Is this possible? Or do I need to find a way to pass just the buffer with the header to the WS library?

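No. Instead, let hyper parse the HTTP request and use the magic function from_raw_socket, which turns an already-upgraded connection into a WebSocketStream without performing the handshake again. Here is part of some working code I have that uses a hyper web server.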

async fn listen(shared: Arc<Shared>) {
    let addr = shared.settings.listen;

    let make_service = make_service_fn(move |client: &AddrStream| {
        // This closure is called per connection.
        let ip = client.remote_addr();
        let shared = shared.clone();
        async move {
            Ok::<_, Infallible>(service_fn(move |req| {
                // This closure is called per request on the connection.
                let shared = shared.clone();
                handle_request(ip, req, shared)
            }))
        }
    });
    let server = Server::bind(&addr).serve(make_service);

    if let Err(e) = server.await {
        eprintln!("server error: {}", e);
    }
}

async fn handle_request(
    ip: SocketAddr,
    req: Request<Body>,
    shared: Arc<Shared>,
) -> Result<Response<Body>, BoxError> {
    match req.headers().get(SEC_WEBSOCKET_KEY).cloned() {
        Some(key) => start_websocket(ip, req, shared, key).await,
        None => ordinary_request(ip, req, shared).await,
    }
}

async fn start_websocket(
    ip: SocketAddr,
    req: Request<Body>,
    shared: Arc<Shared>,
    key: HeaderValue,
) -> Result<Response<Body>, BoxError> {
    tokio::spawn(async move {
        let upgraded = req.into_body().on_upgrade().await?;
        let ws_stream = WebSocketStream::from_raw_socket(
            upgraded,
            tokio_tungstenite::tungstenite::protocol::Role::Server,
            None,
        ).await;
        let id = shared.id_count.fetch_add(1, atomic::Ordering::Relaxed);
        let (sender, receiver) = channel(32);
        let _ = shared.sender_command.send(Command::NewConnection {
            channel: sender,
            id,
        });
        websocket_handle_events(ip, shared, ws_stream, id, receiver).await
    });

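    use hyper::header::{UPGRADE, CONNECTION, SEC_WEBSOCKET_ACCEPT};
    // derive_accept_key is exported by recent tungstenite versions; check yours.
    use tokio_tungstenite::tungstenite::handshake::derive_accept_key;

    let mut upgrade_response = Response::builder()
        .status(StatusCode::SWITCHING_PROTOCOLS)
        .body(Body::empty())
        .unwrap();

    // Sec-WebSocket-Accept must be derived from the client's key (RFC 6455),
    // not a copy of it; clients reject the handshake otherwise.
    let accept = derive_accept_key(key.as_bytes());

    let headers = upgrade_response.headers_mut();
    headers.insert(UPGRADE, HeaderValue::from_static("WebSocket"));
    headers.insert(CONNECTION, HeaderValue::from_static("Upgrade"));
    headers.insert(
        SEC_WEBSOCKET_ACCEPT,
        HeaderValue::from_str(&accept).expect("base64 is a valid header value"),
    );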

    Ok(upgrade_response)
}

async fn ordinary_request(
    ip: SocketAddr,
    req: Request<Body>,
    _shared: Arc<Shared>,
) -> Result<Response<Body>, BoxError> {
    ...
}
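The 101 response has to carry a Sec-WebSocket-Accept value that is computed from the client's Sec-WebSocket-Key, not copied from it: per RFC 6455 it is base64(SHA-1(key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")). tungstenite's handshake::derive_accept_key does this for you; the std-only sketch below just spells out the computation, with hand-written SHA-1 and base64 helpers that are for illustration only (use a vetted crate in real code).

```rust
// Fixed GUID from RFC 6455, appended to the client's key.
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Minimal SHA-1, for illustration only.
fn sha1(data: &[u8]) -> [u8; 20] {
    let mut h: [u32; 5] = [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0];

    // Padding: 0x80, zeros up to 56 mod 64, then the bit length as a big-endian u64.
    let mut msg = data.to_vec();
    let bit_len = (data.len() as u64).wrapping_mul(8);
    msg.push(0x80);
    while msg.len() % 64 != 56 {
        msg.push(0);
    }
    msg.extend_from_slice(&bit_len.to_be_bytes());

    for block in msg.chunks(64) {
        // Message schedule: 16 big-endian words, expanded to 80.
        let mut w = [0u32; 80];
        for i in 0..16 {
            w[i] = u32::from_be_bytes([block[4 * i], block[4 * i + 1], block[4 * i + 2], block[4 * i + 3]]);
        }
        for i in 16..80 {
            w[i] = (w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]).rotate_left(1);
        }

        let (mut a, mut b, mut c, mut d, mut e) = (h[0], h[1], h[2], h[3], h[4]);
        for (i, &wi) in w.iter().enumerate() {
            let (f, k): (u32, u32) = match i {
                0..=19 => ((b & c) | (!b & d), 0x5A827999),
                20..=39 => (b ^ c ^ d, 0x6ED9EBA1),
                40..=59 => ((b & c) | (b & d) | (c & d), 0x8F1BBCDC),
                _ => (b ^ c ^ d, 0xCA62C1D6),
            };
            let temp = a.rotate_left(5).wrapping_add(f).wrapping_add(e).wrapping_add(k).wrapping_add(wi);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = temp;
        }
        h[0] = h[0].wrapping_add(a);
        h[1] = h[1].wrapping_add(b);
        h[2] = h[2].wrapping_add(c);
        h[3] = h[3].wrapping_add(d);
        h[4] = h[4].wrapping_add(e);
    }

    let mut digest = [0u8; 20];
    for (i, word) in h.iter().enumerate() {
        digest[4 * i..4 * i + 4].copy_from_slice(&word.to_be_bytes());
    }
    digest
}

/// Standard base64 (RFC 4648 alphabet) with '=' padding.
fn base64_encode(input: &[u8]) -> String {
    const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut out = String::with_capacity((input.len() + 2) / 3 * 4);
    for chunk in input.chunks(3) {
        let b0 = chunk[0] as u32;
        let b1 = chunk.get(1).copied().unwrap_or(0) as u32;
        let b2 = chunk.get(2).copied().unwrap_or(0) as u32;
        let n = (b0 << 16) | (b1 << 8) | b2;
        out.push(TABLE[(n >> 18) as usize & 63] as char);
        out.push(TABLE[(n >> 12) as usize & 63] as char);
        out.push(if chunk.len() > 1 { TABLE[(n >> 6) as usize & 63] as char } else { '=' });
        out.push(if chunk.len() > 2 { TABLE[n as usize & 63] as char } else { '=' });
    }
    out
}

/// Value for the Sec-WebSocket-Accept response header.
fn sec_websocket_accept(key: &str) -> String {
    let input = format!("{}{}", key.trim(), WEBSOCKET_GUID);
    base64_encode(&sha1(input.as_bytes()))
}

fn main() {
    // Sample key from RFC 6455; the RFC gives "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" as the result.
    println!("{}", sec_websocket_accept("dGhlIHNhbXBsZSBub25jZQ=="));
}
```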

Thanks!

Ah, here's the import list.

use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{Arc, atomic};
use std::path::{PathBuf, Path};
use tokio::runtime::Runtime;

use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Server, Request, Response, Body, StatusCode};
use hyper::header::{HeaderValue, SEC_WEBSOCKET_KEY, CONTENT_TYPE, CONTENT_LENGTH};
use tokio_tungstenite::WebSocketStream;

use crate::websocket::websocket_handle_events;
use crate::{BoxError, Settings, Shared};
use crate::state::Command;

use tokio::sync::mpsc::channel;
use std::sync::atomic::AtomicU64;

You might find that useful too.
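The Arc, AtomicU64 and channel imports serve the per-connection bookkeeping in start_websocket: each new socket takes a unique ID from id_count and registers its own channel with a central state task through Command::NewConnection. The state side isn't shown in the thread, so this is only a std-only sketch under assumptions: OS threads stand in for tokio tasks, std::sync::mpsc (unbounded) stands in for tokio's bounded channel(32), the command sender is handed to each connection directly instead of living in Shared, and the HashMap registry, the run helper and the welcome message are all hypothetical.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::mpsc::{channel, Sender};
use std::sync::Arc;
use std::thread;

// Hypothetical stand-ins for crate::Shared and crate::state::Command.
struct Shared {
    id_count: AtomicU64,
}

enum Command {
    NewConnection { channel: Sender<String>, id: u64 },
}

/// Simulates `n` connections registering with a central state loop.
/// Returns (id, first message received) per connection, sorted by id.
fn run(n: usize) -> Vec<(u64, String)> {
    let shared = Arc::new(Shared { id_count: AtomicU64::new(0) });
    let (command_tx, command_rx) = channel::<Command>();

    let handles: Vec<_> = (0..n)
        .map(|_| {
            // One Arc clone and one command sender per connection.
            let shared = Arc::clone(&shared);
            let command_tx = command_tx.clone();
            thread::spawn(move || {
                // Relaxed suffices: only atomicity is needed for the IDs to be unique.
                let id = shared.id_count.fetch_add(1, Ordering::Relaxed);
                let (sender, receiver) = channel::<String>();
                command_tx
                    .send(Command::NewConnection { channel: sender, id })
                    .expect("state loop stopped");
                let first = receiver.recv().expect("state loop dropped the sender");
                (id, first)
            })
        })
        .collect();
    // Drop the original sender so the loop below ends once every connection is done.
    drop(command_tx);

    // Central state loop: keeps each connection's sender, keyed by id.
    let mut connections: HashMap<u64, Sender<String>> = HashMap::new();
    for command in command_rx {
        match command {
            Command::NewConnection { channel, id } => {
                let _ = channel.send(format!("welcome, connection {}", id));
                connections.insert(id, channel);
            }
        }
    }

    let mut results: Vec<(u64, String)> = handles.into_iter().map(|h| h.join().unwrap()).collect();
    results.sort();
    results
}

fn main() {
    for (id, message) in run(4) {
        println!("{}: {}", id, message);
    }
}
```

Keeping one sender per connection in the state loop is what lets that loop push messages to a specific WebSocket later, while the connection task only ever owns its receiver.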
