I'm trying to translate some C code to Rust. It sets up a socket, then sends any number of requests and receives one reply per request (they are paired), then closes the socket.

- **Requests** are 16 * n + 2 bytes: a length byte (in units of 16-byte blocks), the n blocks of contents, then a trailing null byte.
- **Replies** are 16 * (n + 1) + 2 bytes: a null byte, a 16-byte block of unknown contents, the reply body (the same length as the request body), then a trailing null byte.

Conceptually this seems simple, but I don't know anything about sockets. I am using:

- `socket2`;
- `#![feature(generic_const_exprs)]`, to statically verify that request and reply lengths correspond;
- `1.60.0-nightly (5d8767cb2 2022-02-12)` (from `rustup show`).
#include "oracle.h"
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>
// Largest message accepted, in 16-byte blocks; sizes the request/response buffers.
#define MAX_BLOCK_LEN 4
// Bytes per block: the request's length byte counts in these units.
#define BLOCK_LENGTH 16
// Block-length(1) | Message(16*block_len) | Null-terminator(1)
// Parenthesised so the macro expands safely inside larger expressions.
#define MAX_REQUEST_LEN (1 + MAX_BLOCK_LEN * BLOCK_LENGTH + 1)
// Validity flag(1) | IV(16) | Ciphertext(16*block_len) | Null-terminator(1)
#define MAX_RESPONSE_LEN (1 + BLOCK_LENGTH + MAX_BLOCK_LEN * BLOCK_LENGTH + 1)
// `flags` argument passed to send()/recv().
#define NOFLAGS 0
/*
 * Opens a TCP connection to the oracle at host:port.
 *
 * host must be a numeric IPv4 address such as "127.0.0.1"; no name lookup
 * is performed. On success the returned oracle's sockfd is the connected
 * socket. On any failure sockfd is -1 and a diagnostic has been printed to
 * stderr.
 */
oracle Oracle_Connect(const char *const host, int port) {
  struct sockaddr_in servaddr;
  oracle oracle;
  oracle.sockfd = -1;

  if (port < 0 || port > 65535) {
    fprintf(stderr, "Failed to connect to oracle: invalid port %d\n", port);
    return oracle;
  }

  memset(&servaddr, 0, sizeof(servaddr));
  servaddr.sin_family = AF_INET;
  servaddr.sin_port = htons((unsigned short)port);
  // inet_pton reports malformed input explicitly. inet_addr's INADDR_NONE
  // error value is indistinguishable from 255.255.255.255.
  if (inet_pton(AF_INET, host, &servaddr.sin_addr) != 1) {
    fprintf(stderr, "Failed to connect to oracle: invalid IPv4 address '%s'\n",
            host);
    return oracle;
  }

  int fd = socket(AF_INET, SOCK_STREAM, 0);
  if (fd < 0) {
    perror("Failed to connect to oracle");
    return oracle;
  }
  if (connect(fd, (struct sockaddr *)&servaddr, sizeof(servaddr)) != 0) {
    perror("Failed to connect to oracle");
    close(fd); // don't leak the descriptor when the connection fails
    return oracle;
  }

  printf("Connected to server successfully.\n");
  // Keep the real descriptor. Overwriting it (previously with 0) would make
  // every later send/recv operate on stdin instead of the socket.
  oracle.sockfd = fd;
  return oracle;
}
/*
 * Closes the oracle's socket descriptor.
 *
 * Returns 0 on success. If close() fails (for example when sockfd is -1),
 * prints a warning via perror and returns -1.
 */
int Oracle_Disconnect(oracle o) {
  int status = close(o.sockfd);
  if (status != 0) {
    perror("[WARNING]: You haven't connected to the server yet");
    return -1;
  }
  printf("Connection closed successfully.\n");
  return 0;
}
// Packet Structure: < Block-length(1) | Plaintext(block-length*16) |
// Null-terminator(1) > Fills recvblock with returned ctext
int Oracle_Send(oracle oracle, const unsigned char *const message,
char block_length, unsigned char *const ctext) {
unsigned char request[MAX_REQUEST_LEN];
unsigned char response[MAX_RESPONSE_LEN];
request[0] = block_length;
memcpy(request + 1, message, block_length * 16);
request[block_length * 16 + 1] = '\0';
if (!send(oracle.sockfd, request, 1 + block_length * 16 + 1, NOFLAGS)) {
perror("[WARNING]: You haven't connected to the server yet");
return -1;
}
if (!recv(oracle.sockfd, response, 1 + 16 + block_length * 16 + 1, NOFLAGS)) {
perror("[ERROR]: Recv failed");
return -1;
}
if (response[0]) {
fprintf(stderr, "[ERROR]: Invalid request");
return -1;
}
memcpy(ctext, response + 1, 16 + block_length * 16);
return 0;
}
mod oracle {
use anyhow::bail;
use anyhow::Context;
use anyhow::Result;
use socket2::{Domain, Socket, Type};
use std::io::Read;
use std::io::Write;
use std::net::SocketAddr;
pub struct Oracle {
socket: Socket,
}
impl Oracle {
const BLOCK_SIZE: usize = 16;
pub fn new(addr: SocketAddr) -> Result<Self> {
Socket::new(Domain::IPV4, Type::STREAM, None)
.and_then(|socket| {
socket.connect(&addr.into())?;
Ok(Self { socket })
})
.context("Failed to connect")
}
pub fn encrypt<const N: usize>(
&mut self,
bytes: &[u8; N * Self::BLOCK_SIZE],
) -> Result<[u8; (N + 1) * Self::BLOCK_SIZE]>
where
[(); Self::BLOCK_SIZE * N + 2]: Sized,
[(); 1 + Self::BLOCK_SIZE + N * Self::BLOCK_SIZE + 1]: Sized,
{
let mut send = [0u8; Self::BLOCK_SIZE * N + 2];
send[0] = N.try_into()?;
send[1..N * Self::BLOCK_SIZE + 1].copy_from_slice(bytes);
send[send.len() - 1] = 0;
self.socket.write_all(&send)?;
self.socket.flush()?;
let mut recv = [0; 1 + Self::BLOCK_SIZE + N * Self::BLOCK_SIZE + 1];
self.socket.read_exact(&mut recv)?;
if dbg!(recv[0]) != 0 {
bail!("invalid request")
}
let mut out = [0; (N + 1) * Self::BLOCK_SIZE];
out.copy_from_slice(&recv[1..1 + (N + 1) * Self::BLOCK_SIZE]);
Ok(out)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_connect() -> Result<()> {
let addr = _;
let mut oracle = Oracle::new(addr)?;
let msg = b"0123456789abcdef0123456789abcdef";
let c = oracle.encrypt::<2>(msg)?;
dbg!(c);
Ok(())
}
}
}