Translating C socket code to Rust

I'm trying to translate some C code to Rust. It sets up a socket, sends any number of requests and receives one reply for each (they are paired), then closes the socket. Each request is 16 * n + 2 bytes: a length byte (the number n of 16-byte blocks), the n blocks of contents, then a trailing null byte. Each reply is 16 * (n + 1) + 2 bytes: a status byte (null on success), a 16-byte block of unknown contents (the C code calls it the IV), the reply body (the same length as the request body), and a trailing null byte. Conceptually this seems simple, but I don't know anything about sockets. I am using the `socket2` crate, `#![feature(generic_const_exprs)]` to statically verify that request and reply lengths correspond, and `1.60.0-nightly (5d8767cb2 2022-02-12)` (from `rustup show`).

#include "oracle.h"
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

// Maximum number of 16-byte blocks per request; sizes the buffers below.
#define MAX_BLOCK_LEN 4
// Block-length(1) | Message(16*block_len) | Null-terminator(1)
// Parenthesized so the expansion stays correct inside larger expressions.
#define MAX_REQUEST_LEN (1 + MAX_BLOCK_LEN * 16 + 1)
// Validity flag(1) | IV(16) | Ciphertext(16*block_len) | Null-terminator(1)
#define MAX_RESPONSE_LEN (1 + 16 + MAX_BLOCK_LEN * 16 + 1)
// Flags argument passed to send()/recv().
#define NOFLAGS 0
// Bytes per block.
#define BLOCK_LENGTH 16

// Opens a TCP connection to the oracle at `host` (a dotted-quad IPv4 string) and
// `port`.
//
// Returns an `oracle` whose `sockfd` is the connected socket on success. On failure
// `sockfd` is -1, any socket created along the way has been closed, and the reason
// has been printed (via perror, or to stderr for a malformed address).
oracle Oracle_Connect(const char *const host, int port) {
  struct sockaddr_in servaddr;
  oracle oracle;

  oracle.sockfd = socket(AF_INET, SOCK_STREAM, 0);
  if (oracle.sockfd < 0) {
    perror("Failed to create socket");
    oracle.sockfd = -1;
    return oracle;
  }

  memset(&servaddr, 0, sizeof(servaddr));
  servaddr.sin_family = AF_INET;
  servaddr.sin_port = htons(port);
  // inet_pton reports malformed input instead of silently mapping it to
  // INADDR_NONE (255.255.255.255) the way inet_addr does.
  if (inet_pton(AF_INET, host, &servaddr.sin_addr) != 1) {
    fprintf(stderr, "Failed to connect to oracle: invalid address '%s'\n", host);
    close(oracle.sockfd);
    oracle.sockfd = -1;
    return oracle;
  }

  if (connect(oracle.sockfd, (struct sockaddr *)&servaddr, sizeof(servaddr)) == 0) {
    printf("Connected to server successfully.\n");
    // Keep the connected descriptor; Oracle_Send and Oracle_Disconnect use it.
    // (Overwriting it with 0 would redirect all traffic to stdin.)
  } else {
    perror("Failed to connect to oracle");
    close(oracle.sockfd); // don't leak the unconnected socket
    oracle.sockfd = -1;
  }
  return oracle;
}

// Closes the oracle's socket descriptor.
//
// If close() fails, prints a warning via perror (which appends the errno text) and
// returns -1. Otherwise prints a confirmation and returns 0.
int Oracle_Disconnect(oracle o) {
  if (close(o.sockfd) != 0) {
    perror("[WARNING]: You haven't connected to the server yet");
    return -1;
  }
  printf("Connection closed successfully.\n");
  return 0;
}

// Packet Structure: < Block-length(1) | Plaintext(block-length*16) |
// Null-terminator(1) > Fills recvblock with returned ctext
int Oracle_Send(oracle oracle, const unsigned char *const message,
                char block_length, unsigned char *const ctext) {
  unsigned char request[MAX_REQUEST_LEN];
  unsigned char response[MAX_RESPONSE_LEN];

  request[0] = block_length;
  memcpy(request + 1, message, block_length * 16);
  request[block_length * 16 + 1] = '\0';

  if (!send(oracle.sockfd, request, 1 + block_length * 16 + 1, NOFLAGS)) {
    perror("[WARNING]: You haven't connected to the server yet");
    return -1;
  }

  if (!recv(oracle.sockfd, response, 1 + 16 + block_length * 16 + 1, NOFLAGS)) {
    perror("[ERROR]: Recv failed");
    return -1;
  }
  if (response[0]) {
    fprintf(stderr, "[ERROR]: Invalid request");
    return -1;
  }

  memcpy(ctext, response + 1, 16 + block_length * 16);

  return 0;
}

mod oracle {
    use anyhow::bail;
    use anyhow::Context;
    use anyhow::Result;
    use socket2::{Domain, Socket, Type};
    use std::io::Read;
    use std::io::Write;
    use std::net::SocketAddr;

    pub struct Oracle {
        socket: Socket,
    }

    impl Oracle {
        const BLOCK_SIZE: usize = 16;

        pub fn new(addr: SocketAddr) -> Result<Self> {
            Socket::new(Domain::IPV4, Type::STREAM, None)
                .and_then(|socket| {
                    socket.connect(&addr.into())?;
                    Ok(Self { socket })
                })
                .context("Failed to connect")
        }

        pub fn encrypt<const N: usize>(
            &mut self,
            bytes: &[u8; N * Self::BLOCK_SIZE],
        ) -> Result<[u8; (N + 1) * Self::BLOCK_SIZE]>
        where
            [(); Self::BLOCK_SIZE * N + 2]: Sized,
            [(); 1 + Self::BLOCK_SIZE + N * Self::BLOCK_SIZE + 1]: Sized,
        {
            let mut send = [0u8; Self::BLOCK_SIZE * N + 2];
            send[0] = N.try_into()?;
            send[1..N * Self::BLOCK_SIZE + 1].copy_from_slice(bytes);
            send[send.len() - 1] = 0;
            self.socket.write_all(&send)?;
            self.socket.flush()?;
            let mut recv = [0; 1 + Self::BLOCK_SIZE + N * Self::BLOCK_SIZE + 1];
            self.socket.read_exact(&mut recv)?;
            if dbg!(recv[0]) != 0 {
                bail!("invalid request")
            }
            let mut out = [0; (N + 1) * Self::BLOCK_SIZE];
            out.copy_from_slice(&recv[1..1 + (N + 1) * Self::BLOCK_SIZE]);
            Ok(out)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_connect() -> Result<()> {
            let addr = _;
            let mut oracle = Oracle::new(addr)?;
            let msg = b"0123456789abcdef0123456789abcdef";
            let c = oracle.encrypt::<2>(msg)?;
            dbg!(c);
            Ok(())
        }
    }
}

I was sending messages of the wrong length. The client code works with any number of blocks, but the server only accepts N = 1.

This topic was automatically closed 90 days after the last reply. We invite you to open a new topic if you have further questions or comments.