Beginner: Code review of simple sockets server

TL;DR: Can someone look at my stupid code below and tell me how a more experienced Rust developer would make it nicer? Thanks!

Hi, I started learning Rust again today (I tried it once a year or two ago) and decided to build a sockets server with it for practice. I've been reading through the tutorial pages and the alternative approach to OOP in Rust and trying to wrap my head around it. Anyway, I built some code that works, but it feels like a touchy beast balancing on its toes, where any small change causes an outbreak of compiler errors that can only be resolved by a series of Stack Overflow searches and increasing layers of indirection. So I was wondering if someone who's been working with Rust for a while could give some suggestions on how the code design could be improved.

Thanks!

Project goal:
A program that uses a TcpListener to process messages from each client in a separate thread, while keeping track of how many of those threads are alive. The message format: the first 4 bytes contain the message length in little endian, and the following bytes contain the message itself.
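
For example, the message "hi" would go over the wire as the six bytes 02 00 00 00 68 69. A rough sketch of that framing (just an illustration, not part of the server code below):

fn frame_message(msg: &str) -> Vec<u8> {
    // Assumes the message fits in a u32, per the protocol above.
    let len = msg.len() as u32;
    let mut framed = Vec::with_capacity(4 + msg.len());
    framed.extend_from_slice(&len.to_le_bytes()); // 4-byte little-endian length prefix
    framed.extend_from_slice(msg.as_bytes());     // message body
    framed
}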

My implementation:

  • Client has a run method which will be executed in a thread
  • Client struct contains each stream and a way to detect if it is alive
  • The client holds a thread-safe, reference-counted variable (Arc<bool>, where the bool could be anything). Holding a weak reference to it makes it possible to detect when the thread finishes, because that's when the last strong reference to the struct member is dropped
  • Store these weak references in a vector and retain the ones that still point to something

The ugliest part I see in my implementation:
I feel like the whole Arc<bool> thing is hacky, but I couldn't figure out how to tame the mess that ensued when I tried keeping the client objects in a vector and implementing an is_alive() method, because the client instance used in the closure passed to thread::spawn has to outlive the scope of the variable.

Code:

use std::net::{TcpListener, TcpStream};
use std::io;
use std::io::Read;
use std::str;
use std::thread;
use std::sync::Arc;
use std::sync::Weak;
use std::error::Error;

type BoxResult<T> = Result<T, Box<dyn Error>>;

/// Interpret 4 bytes in little endian as a u32 number
fn bytes_to_u32(array: [u8; 4]) -> u32 {
    u32::from(array[0]) +
        (u32::from(array[1]) << 8) +
        (u32::from(array[2]) << 16) +
        (u32::from(array[3]) << 24)
}

/// TCP socket to a client using our protocol
struct Client {
    stream: TcpStream,
    on_message: fn(String),
    alive_handle: Arc<bool>  // Reference for detecting when object leaves scope
}

impl Client {
    fn new(stream: TcpStream, on_message: fn(String)) -> Client {
        Client {
            stream,
            on_message,
            alive_handle: Arc::new(false),
        }
    }

    fn run(&mut self) {
        loop {
            let op_size = self.read_message_size();
            if op_size.is_err() {
                break;
            }
            let size = op_size.unwrap();
            let next_msg = self.get_next_message(size);
            match next_msg {
                Ok(next_msg) => (self.on_message)(next_msg),
                Err(e) => println!("Error getting message: {}", e)
            }
        }
    }

    fn read_message_size(&mut self) -> BoxResult<usize> {
        let mut bytes = [0; 4];
        self.stream.read_exact(&mut bytes)?;
        Ok(bytes_to_u32(bytes) as usize)
    }

    fn get_next_message(&mut self, size: usize) -> BoxResult<String> {
        let mut data = vec![0; size];
        self.stream.read_exact(&mut data)?;
        Ok(str::from_utf8(&data)?.to_owned())
    }
}

fn main() -> io::Result<()> {
    let mut alive_flags: Vec<Weak<bool>> = vec![];
    for op_stream in TcpListener::bind("127.0.0.1:8000")?.incoming() {
        let stream = match op_stream {
            Ok(stream) => stream,
            Err(_) => continue
        };
        let mut client = Client::new(stream, |x| println!("Message = {}", x));
        let flag = Arc::downgrade(&client.alive_handle);
        thread::spawn(move || client.run());
        alive_flags.push(flag);
        // Drop the flags whose Client (and therefore thread) has already finished
        alive_flags.retain(|x| x.upgrade().is_some());
        println!("Number of threads: {}", alive_flags.len());
    }
    Ok(())
}

For reference, here's the Python code for simulating clients connecting and disconnecting.

Updates:

  • While I was hesitant at first, it seems like implementing a new method is pretty common, so I should probably have a Client::new
  • Errors should be properly handled; I've updated the code accordingly
  • Of course, usize is platform-dependent, so I've replaced it with the more appropriate u32
  • Pass small arrays by value instead of by reference
  • Removed redundant keys when initializing struct fields from variables of the same name

In Rust, we don't use return 3; as the last statement; instead we just write 3.

fn read_message_size(&mut self) -> usize {
    let mut bytes = [0; 4];
    self.stream.read(&mut bytes).unwrap();
    bytes_to_usize(&mut bytes)
}

fn get_next_message(&mut self, size: usize) -> String {
    let mut data = vec![0; size];
    let _ = self.stream.read(&mut data);
    str::from_utf8(&data).unwrap().to_owned()
}

Why is using return as the last statement in a function considered bad style? (on SO).
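
A quick made-up illustration of the difference:

fn three() -> i32 {
    3 // the final expression of the block is the function's return value
}

fn three_with_return() -> i32 {
    return 3; // compiles, but `return` as the last statement is considered un-idiomatic
}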


self.stream.read(&mut bytes).unwrap();
// ...
let _ = self.stream.read(&mut data);

is incorrect, because read does not necessarily read the requested number of bytes. It returns the number of bytes it actually read, wrapped in a Result (which you should not ignore, as you do right now!). Instead, use read_exact.
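
To make the difference concrete, here is a rough sketch (my own example, not your code) of what handling read by hand would look like, compared to just calling read_exact:

use std::io::{self, Read};
use std::net::TcpStream;

// Manual version: read may return fewer bytes than requested, so loop until the buffer is full.
fn read_4_bytes_manually(stream: &mut TcpStream) -> io::Result<[u8; 4]> {
    let mut bytes = [0u8; 4];
    let mut filled = 0;
    while filled < bytes.len() {
        let n = stream.read(&mut bytes[filled..])?;
        if n == 0 {
            // The peer closed the connection before sending 4 bytes.
            return Err(io::ErrorKind::UnexpectedEof.into());
        }
        filled += n;
    }
    Ok(bytes)
}

// read_exact does that loop for you and fails with UnexpectedEof if the stream ends early.
fn read_4_bytes(stream: &mut TcpStream) -> io::Result<[u8; 4]> {
    let mut bytes = [0u8; 4];
    stream.read_exact(&mut bytes)?;
    Ok(bytes)
}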

All in all, you should familiarise yourself with the error handling concepts in Rust, especially Result and Option. See chapter 9 of the book.


let mut bytes = [0; 4];
// ...
bytes_to_usize(&mut bytes)

I don't think that you understood what usize is. It's a platform-dependent type whose size is the number of bytes it takes to reference any location in memory, e.g. 8 bytes on a typical x86_64 desktop platform. Therefore your 4-byte array is wrong here. You either want to expand it to 8 bytes or use u32 instead.
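
A quick way to check that on whatever machine you're on:

fn main() {
    // Prints 8 on a typical 64-bit target and 4 on a 32-bit one.
    println!("usize is {} bytes here", std::mem::size_of::<usize>());
}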


After looking more closely at your code, you should read the book, because it seems that you haven't actually looked at it. There are some style decisions which are very un-Rusty and considered bad. Please read it :slight_smile:

Thanks for the suggestions! However, error handling and incorrect type usage don't directly address the core of what I'm asking, as they don't affect much of the design of the code. In any case, I've made the changes you suggested. I'll continue reading the book to see if I can figure out some of these un-Rusty style decisions. If you wouldn't mind, could you point out one or two of the biggest style improvements that could be made?

I'd recommend installing clippy - it'll help you find common issues with code style/correctness.

I ran it on your code via the playground (it's in the tools menu), and some of the suggestions made were:

  • When you're initializing a struct, instead of doing stream: stream, you can just say stream as a shorthand.
  • [u8; 4] implements Copy, and it's small enough that passing it by reference to bytes_to_u32 is probably less performant than just passing it by value.
  • The as operator will sometimes do lossy conversions, which can cause bugs if you change the type of a variable. You can use functions like u32::from to do lossless conversions (there's a small example of the difference after this list).
  • You're still not handling the returned amount of bytes read in read_message_size - Clippy actually catches this and throws up a big honking error message to make sure you know it's a no-no :slight_smile:
    • If you're not familiar with why - calling read sometimes will only partially fill the array, meaning that there could still be stale data from the previous read left in there. You either need to handle this explicitly, or say 'keep reading until the buffer is full' by calling read_exact.
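
To illustrate the as vs u32::from point with a tiny made-up example:

fn main() {
    let b: u8 = 200;
    let widened = u32::from(b); // u8 -> u32 can never lose information, so From is implemented
    assert_eq!(widened, 200);

    let n: u32 = 70_000;
    let narrowed = n as u8; // `as` silently truncates: 70_000 % 256 == 112
    assert_eq!(narrowed, 112);
}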

Here's a version of your code with these things fixed - there may be other changes that could be made (don't have time to dig too deep right now), but this should get some of the low-hanging fruit out of the way.


Thanks! Just updated with all those suggestions; clippy is really nice (wonder if the name pays tribute to Microsoft Office :P). Funny enough, I was actually using that exact same playground but didn't notice those tools.

So is the only reason for all of the alive_flags fun so that you can keep track of the total number of threads? To be honest it took me a little while to figure out what was going on there :-).

If all you're trying to do is keep track of the thread count, here's an implementation that uses Arc<Mutex<i32>> (Arc + Mutex go together a lot) and the Drop trait:

use std::error::Error;
use std::io;
use std::io::Read;
use std::net::{TcpListener, TcpStream};
use std::str;
use std::sync::{Arc, Mutex};
use std::thread;

type BoxResult<T> = Result<T, Box<dyn Error>>;

/// Interpret 4 bytes in little endian as a u32 number
fn bytes_to_u32(array: [u8; 4]) -> u32 {
    u32::from(array[0])
        + (u32::from(array[1]) << 8)
        + (u32::from(array[2]) << 16)
        + (u32::from(array[3]) << 24)
}

/// TCP socket to a client using our protocol
struct Client {
    stream: TcpStream,
    on_message: fn(String),
    total_count: Arc<Mutex<i32>>,
}

impl Client {
    fn new(stream: TcpStream, on_message: fn(String), total_count: Arc<Mutex<i32>>) -> Client {
        Client {
            stream,
            on_message,
            total_count,
        }
    }

    fn run(&mut self) {
        {
            let mut guard = self.total_count.lock().unwrap();
            *guard += 1;
        }
        loop {
            let op_size = self.read_message_size();
            if op_size.is_err() {
                break;
            }
            let size = op_size.unwrap();
            let next_msg = self.get_next_message(size);
            match next_msg {
                Ok(next_msg) => (self.on_message)(next_msg),
                Err(e) => println!("Error getting message: {}", e),
            }
        }
    }

    fn read_message_size(&mut self) -> BoxResult<usize> {
        let mut bytes = [0; 4];
        self.stream.read_exact(&mut bytes)?;
        Ok(bytes_to_u32(bytes) as usize)
    }

    fn get_next_message(&mut self, size: usize) -> BoxResult<String> {
        let mut data = vec![0; size];
        self.stream.read_exact(&mut data)?;
        Ok(str::from_utf8(&data)?.to_owned())
    }
}

impl Drop for Client {
    fn drop(&mut self) {
        let mut guard = self.total_count.lock().unwrap();
        *guard -= 1;
    }
}

fn main() -> io::Result<()> {
    let thread_count = Arc::new(Mutex::new(0));
    for op_stream in TcpListener::bind("127.0.0.1:8000")?.incoming() {
        let stream = match op_stream {
            Ok(stream) => stream,
            Err(_) => continue,
        };
        let mut client = Client::new(
            stream,
            |x| println!("Message = {}", x),
            thread_count.clone(),
        );
        thread::spawn(move || client.run());
        {
            let guard = thread_count.lock().unwrap();
            println!("Number of threads: {}", *guard);
        }
    }
    Ok(())
}

Although note that you might not want to call .unwrap() like I did :slight_smile:

Why not AtomicUsize?

There is u32::from_le_bytes.


Because I've almost only ever seen/used Arc<Mutex<>> with structs and completely forgot that AtomicUsize existed :slightly_smiling_face: Thanks for the tip!

Yup, that part was just copy-pasted from the original post. I just added my Arc<Mutex<>> stuff.


In case anyone else was curious, here's a version with @hellow's suggested changes:

use std::error::Error;
use std::io;
use std::io::Read;
use std::net::{TcpListener, TcpStream};
use std::str;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

type BoxResult<T> = Result<T, Box<dyn Error>>;

/// TCP socket to a client using our protocol
struct Client {
    stream: TcpStream,
    on_message: fn(String),
    total_count: Arc<AtomicUsize>,
}

impl Client {
    fn new(stream: TcpStream, on_message: fn(String), total_count: Arc<AtomicUsize>) -> Client {
        Client {
            stream,
            on_message,
            total_count,
        }
    }

    fn run(&mut self) {
        self.total_count.fetch_add(1, Ordering::SeqCst);

        loop {
            let op_size = self.read_message_size();
            if op_size.is_err() {
                break;
            }
            let size = op_size.unwrap();
            let next_msg = self.get_next_message(size);
            match next_msg {
                Ok(next_msg) => (self.on_message)(next_msg),
                Err(e) => println!("Error getting message: {}", e),
            }
        }
    }

    fn read_message_size(&mut self) -> BoxResult<usize> {
        let mut bytes = [0; 4];
        self.stream.read_exact(&mut bytes)?;
        Ok(u32::from_le_bytes(bytes) as usize)
    }

    fn get_next_message(&mut self, size: usize) -> BoxResult<String> {
        let mut data = vec![0; size];
        self.stream.read_exact(&mut data)?;
        Ok(str::from_utf8(&data)?.to_owned())
    }
}

impl Drop for Client {
    fn drop(&mut self) {
        self.total_count.fetch_sub(1, Ordering::SeqCst);
    }
}

fn main() -> io::Result<()> {
    let thread_count = Arc::new(AtomicUsize::new(0));

    for op_stream in TcpListener::bind("127.0.0.1:8000")?.incoming() {
        let stream = match op_stream {
            Ok(stream) => stream,
            Err(_) => continue,
        };
        let mut client = Client::new(
            stream,
            |x| println!("Message = {}", x),
            Arc::clone(&thread_count),
        );
        thread::spawn(move || client.run());
        println!("Number of threads: {}", thread_count.load(Ordering::SeqCst));
    }
    Ok(())
}

You might be able to relax some of the orderings, but I have no idea what I'm doing when it comes to multithreading, so I played it safe :slight_smile:
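
If I had to guess, a counter like this one, which doesn't guard any other data, could probably get away with Ordering::Relaxed. Here's a tiny standalone sketch of that idea (my own untested assumption, so don't take my word for it):

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

fn main() {
    let count = Arc::new(AtomicUsize::new(0));
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let count = Arc::clone(&count);
            // The counter is only read for display and doesn't synchronise any
            // other memory, so Relaxed should be sufficient here.
            thread::spawn(move || count.fetch_add(1, Ordering::Relaxed))
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    println!("threads run: {}", count.load(Ordering::Relaxed));
}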


Thanks everyone! That looks a lot less hacky using Arc<AtomicUsize> and implementing Drop. @17cupsofcoffee, as a small adjustment, would it make sense to put the total_count.fetch_add(1, Ordering::SeqCst) in the ::new() function, to prevent { Client::new(...); } from throwing total_count out of sync?

Yeah, that might be a good idea, as otherwise if you dropped your Client before running it, you'd end up with the counter out of sync. Good spot, hadn't considered that :slight_smile:
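
Something along these lines, as a rough, untested sketch of that change:

use std::net::TcpStream;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

struct Client {
    stream: TcpStream,
    on_message: fn(String),
    total_count: Arc<AtomicUsize>,
}

impl Client {
    fn new(stream: TcpStream, on_message: fn(String), total_count: Arc<AtomicUsize>) -> Client {
        // Count the client as soon as it exists, so the increment always pairs
        // with the decrement in Drop, even if run() is never called.
        total_count.fetch_add(1, Ordering::SeqCst);
        Client {
            stream,
            on_message,
            total_count,
        }
    }
}

impl Drop for Client {
    fn drop(&mut self) {
        self.total_count.fetch_sub(1, Ordering::SeqCst);
    }
}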
