Tokio_tungstenite: WebSocket protocol error: Handshake not finished

I have extended this code to also have a endpoint to connect with a websocket.

The function that is executed is this:

    async fn websocket(&self, req: Http, stream: &mut TcpStream) -> Result<HttpResponse, String> {

        let mut upgraded = match tokio_tungstenite::accept_async(stream).await {
            Ok(conn) => conn,
            Err(e) => {
                eprintln!("Err {}", e);
                return Err("".into());
            }
        };



        unimplemented!()
    }

The problem is that when using Postman to connect to ws://127.0.0.1:8080/ws, Postman keeps stuck on "Connecting...". Using a basic Python3 client gives the runtime error in Rust:


Err WebSocket protocol error: Handshake not finished
Error:
HTTP/1.1 500 Switching Protocols
Content-Type: text/plain

I thought this was maybe because I did not send the response to the client to upgrade the connection. So I tried this:

    async fn websocket(&self, req: Http, stream: &mut TcpStream) -> Result<HttpResponse, String> {

        // Get Websocket key from header
        let ws_key = match req.headers.get("Sec-WebSocket-Key") {
            Some(key) => key,
            None => {
                return Ok(HttpResponse {
                    status: 400,
                    headers: vec![
                        ("Content-Type".into(), "text/plain".into()),
                    ],
                    body: "".to_string(),
                });
            }
        };
       
        // Generate Upgrade response
        let response = HttpResponse {
            status: 101,
            headers: vec![
                ("Upgrade".into(), "websocket".into()),
                ("Connection".into(), "Upgrade".into()),
                ("Sec-WebSocket-Accept".into(), ws_key.into()),
            ],
            body: "".into(),
        };

        println!("{}", response.to_string());
        
        // Send response back to client to finish handshake
        stream.write_all(format!("{}\n", response).as_bytes()).await.unwrap();
        stream.flush().await.unwrap();
        
        // Now use upgrade connection to init websocket connection
        let mut upgraded = match tokio_tungstenite::accept_async(stream).await {
            Ok(conn) => conn,
            Err(e) => {
                eprintln!("Err {}", e);
                return Err("".into());
            }
        };



        unimplemented!()
    }

Response:

HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: dWiZ5+ofK9wu7ZZVK1IXfQ==

Rust gives the same runtime error and the Python client yields Error did not receive a valid HTTP response.

Code:

// Tokio
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
// Futures
use futures_util::{future, pin_mut, StreamExt, SinkExt};
use futures_channel::mpsc::{unbounded, UnboundedSender};
use futures::FutureExt;
use futures::future::BoxFuture;
// String manipulation (JSON, serialization, regexes,...)
use serde::{Deserialize, Serialize};
use regex::Regex;
use serde_json;
// std
use std::sync::Arc;
use tokio::sync::Mutex;
use std::error::Error;
// Other
use async_trait::async_trait;
use reqwest;
use url;

mod api;
use api::ExchangeAPI;

// Define type for tuple containing routes. (HTTP-method, path, fn pointer to route function).
//type Route<'a> = (&'static str, &'static str, fn(&'a Routes, Http, &'a mut TcpStream) -> BoxFuture<'a, Result<HttpResponse, String>>);
type Route = (
    &'static str,
    &'static str,
    for<'a> fn(&'a Routes, Http, &'a mut TcpStream)
        -> BoxFuture<'a, Result<HttpResponse, String>>
);

// Macros
//
// This macro improves the readability of the code. Instead of
// ("GET", "/", |r, h| Routes::index(r, h).boxed()) the programmer can just
// type route!("GET", "/", Routes::index).
macro_rules! route {
    ($m:expr, $p:expr, $fn:path) => {
        ($m, $p, |r, h, s| $fn(r, h, s).boxed())
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    
    // Start server.
    let server = TcpListener::bind("127.0.0.1:8080").await.expect("Failed to bind to 127.0.0.1:8080");
    println!("Listening on 127.0.0.1:8080...");

    // Define all your routes here.
    let routes: Arc<std::vec::Vec<Route>> = Arc::new(vec![
    /*    route!("GET", "/", Routes::index),
        route!("GET", "/welcome/{name}", Routes::welcome),
        route!("GET", "/welcome/{name}/{age}", Routes::welcome_age),
        route!("POST", "/print", Routes::print_name),
        route!("GET", "/html-page", Routes::html_page),*/
        route!("GET", "/ws", Routes::websocket),
    ]);

    // Spawn process to handle incoming streams.
    loop {
        let cloned_routes = routes.clone();
        let (mut stream, _) = server.accept().await.unwrap();
        tokio::spawn(async move {
            if let Err(e) = process(&mut stream, &cloned_routes).await {
                eprintln!("Error: {}", e);
            }
        });
    }

}

// Process the incoming streams.
async fn process(stream: &mut TcpStream, routes: &std::vec::Vec<Route>) -> Result<(), Box<dyn Error>> {
 

    // Buffer holding the stream in bytes.
    let mut buffer = [0; 1024];
    if let Err(e) = stream.read(&mut buffer).await {
        eprintln!("Error: {}", e);
        
        let response = Routes::internal_server_error().await;
        stream.write_all(response.to_string().as_bytes()).await.unwrap();
        stream.flush().await.unwrap();

        return Ok(());
    }
   
    // Create a string from the contents of the buffer.
    let stream_string = String::from_utf8_lossy(&buffer[..]);

    println!("{}", stream_string);

    // Turn the raw stream as string into a workable Http-object.
    let mut http = Http::from_str(&stream_string);
    
    // Check if there is a route matching the HTTP-method and path.
    let response = match Routes::goto(&mut http, stream, routes).await {
        Ok(r) => r,
        Err(e) => {
            eprintln!("Error: {}", e);
            Routes::internal_server_error().await
        }
    };

    println!("{}", response);

    // Write response to stream as answer.
    stream.write_all(format!("{}\n", response).as_bytes()).await.unwrap();
    stream.flush().await.unwrap();

    Ok(())
}

// Struct for incoming HTTP requests.
#[derive(Clone, Debug)]
struct Http {
    method: String,
    path: String,
    params: std::collections::HashMap<String, String>,
    headers: std::collections::HashMap<String, String>,
    body: String,
}

impl Http {
    // Generate Http-object from raw stream string.
    // Note: This function assumes that all incoming requests are well formed in the expected
    // format.
    fn from_str(str: &str) -> Self {
        let lines : std::vec::Vec<&str> = str.lines().collect();
        let request_line : std::vec::Vec<&str> = lines[0].split_whitespace().collect();
        
        let method = request_line[0];
        let path = request_line[1];
     
        let mut headers : std::collections::HashMap<String, String> = std::collections::HashMap::new();
        for line in &lines[1..] {
            if line.is_empty() {
                break;
            }

            let parts : std::vec::Vec<&str> = line.splitn(2, ":").collect();
            if parts.len() == 2 {
                headers.insert(
                    parts[0].trim().into(),
                    parts[1].trim().into()
                );
            }
        }

        let mut body = "".to_string();
        let mut sep_found = false;
        for line in lines {
            if line == "" {
                sep_found = true;
                continue;
            }

            if sep_found {
                body = format!("{}\n{}", body, line);
            }
        }

        // Remove trailing zero-bytes.
        body = body.chars().filter(|&c| c != '\0').collect();

        Http {
            method: method.into(),
            path: path.into(),
            headers: headers,
            // Params will be set in Route::goto where we find the matching route.
            params: std::collections::HashMap::new(),
            body: body,
        }
    }
}

// Struct for a HTTP Response.
// This way we can easily define a response and use the string formatter to write
// the response to the stream as string.
struct HttpResponse {
    status: u16,
    headers: std::vec::Vec<(String, String)>,
    body: String,
}

// Convert HttpResponse to string.
impl std::fmt::Display for HttpResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut headers : String = "".into();
        for (key, value) in self.headers.iter() {
            headers = format!("{}{}: {}\n", headers, key, value);
        }
        headers.pop();

        write!(
            f, 
            "HTTP/1.1 {status} Switching Protocols\n{headers}",
            status = self.status,
            headers = headers,
          //  body = self.body
        )
    }
}

// Struct Routes holds the following:
// - A method `goto` to find and execute a the fn pointer to a route for the matching Http-method
// and path.
// - Defined custom routes.
// - Predefined standard routes (e.g 404, 500)
struct Routes;
impl Routes {

    // Find fn pointer to execute the route given the Http-method and path from the incoming
    // request.
    // We pass the TcpStream in case a WebSocket-endpoint is called.
    async fn goto(req: &mut Http, stream: &mut TcpStream, routes: &std::vec::Vec<Route>) -> Result<HttpResponse, String> {
        
        if let Some(r) = routes.iter().find(|&&(m, p, _)| m == req.method && check_path(p, &*req.path)) {
            let (_, p, ro) = *r;
            req.params = extract_params(p, &*req.path);
            return ro(&Routes, req.clone(), stream).await;
        } else {
            return Ok(Routes::not_found().await);
        }

        // A route can contain custom parameters e.g /user/{name}. By formatting
        // the defined routes into a regex string we can check if the defined path
        // matches the given path incoming from the request.
        //
        // :param p: Path we predefined in our routes.
        // :param path: Path extracted from the incoming request.
        fn check_path(p: &str, path: &str) -> bool {
            
            let re_custom_param = Regex::new(r"\{(\w+)\}").unwrap();
            let formatted_path = re_custom_param.replace_all(p, "(\\w+)");
            let regex_str_path = format!("^{}$", formatted_path);

            let re = Regex::new(&regex_str_path).unwrap();
            re.is_match(path)
        }

        // Extract the parameters from the path given in the request and make it
        // workable in a HashMap.
        // E.g:
        // p = /user/{firstname}/{lastname}, path = /user/foo/bar, becomes:
        // {"firstname": "foo", "lastname": "bar"}
        //
        // :param p: Path we predefined in our routes.
        // :param path: Path extracted from the incoming request.
        fn extract_params(p: &str, path: &str) -> std::collections::HashMap<String, String> {
            let mut params : std::collections::HashMap<String, String> = std::collections::HashMap::new();

            // Split both paths in pieces.
            let p_pieces = p.split("/");
            let path_pieces = path.split("/").collect::<std::vec::Vec<&str>>();

            // Loop over the pieces of our predefined route path. If a piece
            // contains {...} AKA is a custom parameter:
            // Add the string between the brackets as key and add the value of the
            // path from the request as value to the hashmap.
            let mut i = 0;
            for p_piece in p_pieces {
                let re = Regex::new(r"\{(\w+)\}").unwrap();
                if re.is_match(p_piece) {
                    let param_name = p_piece.replace("{", "").replace("}", "");
                    params.insert(param_name, path_pieces[i].into());
                }

                i = i + 1;
            }

            params
        }
    }

    // Custom routes.
    //
    /*
    async fn index(&self, req: Http, stream: TcpStream) -> Result<HttpResponse, String> {
        Ok(HttpResponse {
            status: 200,
            headers: vec![
                ("Content-Type".into(), "text/plain".into()),
            ],
            body: "This is the homepage.".into(),
        })
    }
   
    async fn welcome(&self, req: Http, st) -> Result<HttpResponse, String> {
        let name = req.params.get("name").cloned().unwrap_or_else(|| "".to_string());

        Ok(HttpResponse {
            status: 200,
            headers: vec![
                    ("Content-Type".into(), "text/plain".into()),
            ],
            body: format!("Welcome {}", name.to_string()),
        })
    }
    
    async fn welcome_age(&self, req: Http) -> Result<HttpResponse, String> {
        let name = req.params.get("name").cloned().unwrap_or_else(|| "".to_string());
        let age = req.params.get("age").cloned().unwrap_or_else(|| "".to_string());

        Ok(HttpResponse {
            status: 200,
            headers: vec![
                ("Content-Type".into(), "text/plain".into()),
            ],
            body: format!("Welcome {}, your age is {}", name.to_string(), age.to_string()),
        })
    }
    
    async fn print_name(&self, req: Http) -> Result<HttpResponse, String> {
        // Define the expected data from the POST-request.
        #[derive(Serialize, Deserialize)]
        struct Data {
            name: String,
        };

        // Make Data object from POST request body. Return error 400
        // if sent data is malformed.
        let data : Data = match serde_json::from_str(&*req.body) {
            Ok(d) => d,
            Err(e) => {
                return Ok(HttpResponse {
                    status: 400,
                    headers: vec![
                        ("Content-Type".into(), "text/plain".into()),
                    ],
                    body: "".to_string(),
                });
            }
        };

        // Print out the name. Normally we would do something advanced
        // like a database query or smth.
        println!("Name {}", data.name);

        Ok(HttpResponse {
            status: 200,
            headers: vec![
                ("Content-Type".into(), "text/plain".into()),
            ],
            body: "".to_string(),
        })
    }

    async fn html_page(&self, req: Http) -> Result<HttpResponse, String> {
        let html = 
        r#"
            <h1>Hello</h1>
            <p>This is html.</p>
        "#;

        Ok(HttpResponse {
            status: 200,
            headers: vec![
                ("Content-Type".into(), "text/plain".into()),
            ],
            body: html.into(),
        })
    }*/

    async fn websocket(&self, req: Http, stream: &mut TcpStream) -> Result<HttpResponse, String> {

        // Get Websocket key from header
        let ws_key = match req.headers.get("Sec-WebSocket-Key") {
            Some(key) => key,
            None => {
                return Ok(HttpResponse {
                    status: 400,
                    headers: vec![
                        ("Content-Type".into(), "text/plain".into()),
                    ],
                    body: "".to_string(),
                });
            }
        };
       
        // Generate Upgrade response
        let response = HttpResponse {
            status: 101,
            headers: vec![
                ("Upgrade".into(), "websocket".into()),
                ("Connection".into(), "Upgrade".into()),
                ("Sec-WebSocket-Accept".into(), ws_key.into()),
            ],
            body: "".into(),
        };

        println!("{}", response.to_string());
        
        // Send response back to client to finish handshake
        stream.write_all(format!("{}\n", response).as_bytes()).await.unwrap();
        stream.flush().await.unwrap();
        
        // Now use upgrade connection to init websocket connection
        let mut upgraded = match tokio_tungstenite::accept_async(stream).await {
            Ok(conn) => conn,
            Err(e) => {
                eprintln!("Err {}", e);
                return Err("".into());
            }
        };



        unimplemented!()
    }

    // Standard routes.
    //
    async fn not_found() -> HttpResponse {
        HttpResponse {
            status: 404,
            headers: vec![
                ("Content-Type".into(), "text/plain".into()),
            ],
            body: "Not found.".into(),
        }
    }

    async fn internal_server_error() -> HttpResponse {
        HttpResponse {
            status: 500,
            headers: vec![
                ("Content-Type".into(), "text/plain".into()),
            ],
            body: "Internal server error.".into(),
        }
    }
}

(Click here for Cargo.toml)

Python3 client:

import websockets
import asyncio
import aioconsole

async def send_listener(ws):
    while True:
        msg = await ainput("> ")

async def recv_listener(ws):
    while True:
        req = await ws.recv()
        print(req)


async def conn():
    try:
        connection = websockets.connect("ws://127.0.0.1:8080/ws")

        async with connection as ws:
            send_task = asyncio.ensure_future(send_listener(ws))
            recv_task = asyncio.ensure_future(recv_listener(ws))

            await asyncio.gather(send_task, recv_task)
    except Exception as e:
        print(f"Error {e}")
        exit()

if __name__ == "__main__":
    asyncio.get_event_loop().run_until_complete(conn())


Edit: I also read in the docs

This function will internally call server::accept to create a handshake representation and returns a future representing the resolution of the WebSocket handshake. The returned future will resolve to either WebSocketStream<S> or Error depending if it’s successful or not.

So I shouldn't have to manually make the response?

This topic was automatically closed 90 days after the last reply. We invite you to open a new topic if you have further questions or comments.