Password protecting my TCP stream is not working

I am trying to password-protect a TCP stream, but it is not working. I run handle_stream after spawning the stream. The setup is four regular channels plus the read half and write half of the TCP stream (I use four channels so that data I write internally is not written back out, which would cause an infinite feedback loop). In my AppState (wrapped in an Arc<RwLock>) I have an enum called AuthenticatedStreamState:


```rust
#[derive(Default, Debug)]
enum AuthenticatedStreamState {
    Authenticated,
    Terminate,
    #[default]
    None,
}
```

The idea is that the first message should be the password. handle_stream checks it and, if it is correct, sets the enum to Authenticated. Right after that it checks whether the stream is authenticated, and if not, it sets the enum to Terminate. The connection-handling code (inside the tokio::spawn below) is then supposed to see Terminate and end the connection, and it has to end it in a way that the other end notices immediately. I tried breaking out of the loops that await the next message, using a CancellationToken to signal this, but my code does not work at all. Does anyone know why?
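For reference, the flow described above (password first; on a mismatch, close the connection so the peer notices at once) can be sketched with the standard library alone, without tokio. This is a minimal sketch under stated assumptions, not a fix for the code below: the newline-terminated password line, the `OK` reply, the `PASSWORD` constant and the helpers `handle_client`, `spawn_server` and `try_password` are all hypothetical. Two choices differ from the question's design: authentication happens inline, once per connection, before anything else is read, so no `AuthenticatedStreamState` shared between connections is needed; and a failed check calls `TcpStream::shutdown(Shutdown::Both)`, so the client's next read returns 0 bytes (end of stream) straight away. In tokio, the rough counterpart would be shutting down or dropping the `OwnedWriteHalf`.

```rust
use std::io::{BufRead, BufReader, Read, Write};
use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream};
use std::thread;

/// Hypothetical password, standing in for the "test" literal in the question.
const PASSWORD: &str = "test";

/// Handles one connection. The first newline-terminated line must be the
/// password; anything else shuts the socket down in both directions, so the
/// client's next read sees end-of-stream (0 bytes) immediately.
fn handle_client(stream: TcpStream) -> std::io::Result<()> {
    let mut reader = BufReader::new(stream.try_clone()?);
    let mut writer = stream;

    let mut first_line = String::new();
    reader.read_line(&mut first_line)?;
    if first_line.trim_end() != PASSWORD {
        // Authentication failed: close both directions right away.
        writer.shutdown(Shutdown::Both)?;
        return Ok(());
    }

    writer.write_all(b"OK\n")?;
    // The authenticated session would continue here; this sketch just returns,
    // which drops the socket and closes the connection.
    Ok(())
}

/// Starts a server thread that handles `connections` clients one after another.
fn spawn_server(
    connections: usize,
) -> std::io::Result<(SocketAddr, thread::JoinHandle<std::io::Result<()>>)> {
    // Port 0 asks the OS for any free port.
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let addr = listener.local_addr()?;
    let handle = thread::spawn(move || -> std::io::Result<()> {
        for stream in listener.incoming().take(connections) {
            handle_client(stream?)?;
        }
        Ok(())
    });
    Ok((addr, handle))
}

/// Client side: sends one password line, then reads until the server closes.
fn try_password(addr: SocketAddr, password: &str) -> std::io::Result<String> {
    let mut client = TcpStream::connect(addr)?;
    client.write_all(format!("{password}\n").as_bytes())?;
    let mut reply = String::new();
    client.read_to_string(&mut reply)?;
    Ok(reply)
}

fn main() -> std::io::Result<()> {
    let (addr, server) = spawn_server(2)?;
    let rejected = try_password(addr, "wrong")?;
    let accepted = try_password(addr, PASSWORD)?;
    server.join().expect("server thread panicked")?;
    println!("wrong password -> {rejected:?}");
    println!("right password -> {accepted:?}");
    Ok(())
}
```

With the wrong password the client's `read_to_string` returns an empty string as soon as the server shuts the socket down, which is the "other end knows it immediately" behaviour the question asks for.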

```rust
tokio::spawn(async move {
    let listener = TcpListener::bind("0.0.0.0:8086")
        .await
        .expect("Failed to bind");

    loop {
        // One cancellation token per connection, shared by its reader and writer tasks.
        let token = CancellationToken::new();
        let token_read = token.clone();
        let token_write = token.clone();

        let inner_authenticated_stream_state = outer_authenticated_stream_state.clone();
        if let Ok((socket, _)) = listener.accept().await {
            println!("Got a new connection");
            let (mut read_half, mut write_half) = socket.into_split();
            let inner_website_tx = from_website_tx.clone();
            let mut inner_website_rx = to_website_rx.resubscribe();

            // Writer task: forwards messages from handle_stream to the TCP client.
            tokio::spawn(async move {
                loop {
                    tokio::select! {
                        _ = token_write.cancelled() => { break; }
                        result = inner_website_rx.recv() => {
                            match result {
                                Ok(data) => { let _ = write_half.write_all(&data).await; }
                                Err(_) => break,
                            }
                        }
                    }
                }
            });

            // Reader task: forwards bytes from the TCP client to handle_stream,
            // then checks whether handle_stream has asked for termination.
            tokio::spawn(async move {
                let mut buf = vec![0u8; 1024];
                loop {
                    match read_half.read(&mut buf).await {
                        Ok(0) | Err(_) => break,
                        Ok(n) => {
                            let _ = inner_website_tx.send(buf[..n].to_vec());
                            let stream_state = inner_authenticated_stream_state.read().await;
                            if matches!(*stream_state, AuthenticatedStreamState::Terminate) {
                                token_read.cancel();
                                break;
                            }
                        }
                    }
                }
            });
        }
    }
});
```
```rust
async fn handle_stream(
    arc_state: Arc<RwLock<AppState>>,
    ctx: Context,
    cache: Arc<Cache>,
    mut rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
    tx: tokio::sync::broadcast::Sender<Vec<u8>>,
) {
    let state = arc_state.write().await;

    let all_guilds = ctx.cache.guilds();
    let main_guild_option = all_guilds
        .iter()
        .find(|guild_id: &&GuildId| u64::from(**guild_id) == 1479002261811236984);

    while let Ok(bytes) = rx.recv().await {
        // Step 1: if this message is an authentication request, check the password.
        match serde_json::from_slice::<AuthenticateRequest>(&bytes) {
            Ok(request) => {
                if request.password == "test" {
                    println!("Password is correct for stream");
                    let mut authenticated_stream = state.authenticated_stream.write().await;
                    *authenticated_stream = AuthenticatedStreamState::Authenticated;
                } else {
                    println!("Password is incorrect for stream");
                }
            }
            Err(_) => {}
        }

        // Step 2: anything that is not authenticated yet is marked for termination.
        let mut authenticated_stream = state.authenticated_stream.write().await;
        if !matches!(*authenticated_stream, AuthenticatedStreamState::Authenticated) {
            println!("Terminating");
            *authenticated_stream = AuthenticatedStreamState::Terminate;
        }
        drop(authenticated_stream);

        // Step 3: answer member-lookup requests coming from the website.
        match serde_json::from_slice::<WebsiteUserRequest>(&bytes) {
            Ok(request) => {
                if let Some(main_guild) = main_guild_option {
                    if let Ok(members) = main_guild.members(&ctx.http, None, None).await {
                        if let Some(member) = members
                            .iter()
                            .find(|member| u64::from(member.user.id) == request.id)
                        {
                            println!("Found the member");
                            let response = WebsiteUserPresentResponse {
                                id: u64::from(member.user.id),
                            };
                            let bytes: Vec<u8> = serde_json::to_vec(&response).unwrap();
                            let _ = tx.send(bytes);
                        }
                    }
                }
            }
            Err(_) => {}
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
    }
}
```

Full code:

As a suggestion, whenever you phrase things like this, it means that you have put on your plate more than you can handle and you need to take a step back (or two). Try to break your implementation into smaller steps, so that you can identify where things break.


First question: is this for learning/education purposes? If not for learning, throw this out and use mutual TLS.


It would help greatly if you handled all the error cases in your code rather than just ignoring them. Then you might see where things are going wrong.
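To illustrate the point: in the question, results such as `inner_website_tx.send(...)` are discarded with `let _ =`, and parse failures land in empty `Err(_) => {}` arms. The sketch below uses `std::sync::mpsc` purely as a stand-in for tokio's broadcast channel so that it runs with the standard library alone; tokio's `broadcast::Sender::send` likewise returns an error, in its case when there are no active receivers. `forward` is a hypothetical helper, not something from the original code.

```rust
use std::sync::mpsc;

/// Forwards one chunk and reports a failure instead of discarding it with `let _ =`.
fn forward(tx: &mpsc::Sender<Vec<u8>>, data: Vec<u8>) -> Result<(), String> {
    tx.send(data).map_err(|mpsc::SendError(lost)| {
        format!("receiver is gone, {} bytes dropped", lost.len())
    })
}

fn main() {
    let (tx, rx) = mpsc::channel::<Vec<u8>>();

    if let Err(e) = forward(&tx, b"hello".to_vec()) {
        eprintln!("forward failed: {e}");
    }
    println!("received {:?}", rx.recv().unwrap());

    // Simulate the consuming task having exited.
    drop(rx);
    match forward(&tx, b"lost".to_vec()) {
        Ok(()) => println!("sent"),
        Err(e) => eprintln!("forward failed: {e}"),
    }
}
```

With `let _ =`, the second send would fail silently; logging the error shows immediately that nobody is listening on the other side.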

I have not studied your code very hard, but as far as I can tell you are sending messages in JSON format. For example, you get a bunch of bytes from rx.recv() and then try to deserialise them into a struct that contains a password field.

I strongly suspect that is not going to work. A stream of bytes received over a TCP/IP socket is just that, a stream: TCP does not preserve message boundaries. There is no guarantee that each read on the receiving side returns exactly what one write on the sending side sent; data can be merged together or split apart on its way from sender to receiver. So there is no guarantee that sending one JSON message will be received as one chunk containing exactly that JSON message. Since each rx.recv() in your code delivers whatever a single read_half.read() returned, you might receive part of a message, or several messages, with every rx.recv().

I see nothing in your code that caters for this problem.

One solution is to send a form feed character (0x0C) at the end of every JSON message. Then, on the receiving end, check every incoming byte for that form feed character so that you know when a complete JSON blob has been received. Then you can deserialise that blob.
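The delimiter approach can be sketched as a small decoder that buffers bytes until it sees 0x0C. This is a std-only sketch; `FrameDecoder`, `encode_frame` and `DELIMITER` are hypothetical names. In the question's code, each chunk passed to `push` would be whatever one `read_half.read()` call (and therefore one `rx.recv()`) delivered, and each returned frame is what would be handed to `serde_json::from_slice`.

```rust
/// Delimiter byte suggested above: ASCII form feed.
const DELIMITER: u8 = 0x0C;

/// Accumulates raw bytes from the socket and yields complete frames.
#[derive(Default)]
struct FrameDecoder {
    pending: Vec<u8>,
}

impl FrameDecoder {
    /// Feeds one chunk (whatever a single read returned) and returns every frame
    /// it completed. Bytes after the last delimiter are kept for the next call.
    fn push(&mut self, chunk: &[u8]) -> Vec<Vec<u8>> {
        self.pending.extend_from_slice(chunk);
        let mut frames = Vec::new();
        while let Some(pos) = self.pending.iter().position(|&b| b == DELIMITER) {
            // Take the frame plus its delimiter out of the buffer, then drop the delimiter.
            let mut frame: Vec<u8> = self.pending.drain(..=pos).collect();
            frame.pop();
            frames.push(frame);
        }
        frames
    }
}

/// Appends the delimiter to an already-serialised message before it is written.
fn encode_frame(message: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(message.len() + 1);
    out.extend_from_slice(message);
    out.push(DELIMITER);
    out
}

fn main() {
    // Two messages sent back to back, then received in awkward chunks.
    let mut wire = encode_frame(br#"{"password":"test"}"#);
    wire.extend(encode_frame(br#"{"id":42}"#));

    let mut decoder = FrameDecoder::default();
    let (first, rest) = wire.split_at(7); // a read that stops mid-message
    assert!(decoder.push(first).is_empty()); // nothing complete yet
    for frame in decoder.push(rest) {
        println!("{}", String::from_utf8_lossy(&frame));
    }
}
```

The delimiter cannot collide with message content: control characters must be escaped inside JSON strings and form feed is not JSON whitespace, so a valid JSON text never contains a raw 0x0C byte.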
