Hi All,
I am new to Rust and still learning the language. I mostly program in Go these days and am exploring Rust as an alternative for writing networking applications.
I have a packet-processing program that carries packets over a QUIC connection from a client to a server; the server then forwards them to their destination with the appropriate NAT applied. The program works, but its performance is lower than that of a comparable Go program, and I was hoping it would be higher. I suspect I am missing something in my implementation that would improve performance, even though I have tried to use async-based constructs.
Any advice on how to make this better would be greatly appreciated.
The relevant part of the code is copy-pasted below.
#[macro_use]
extern crate log;
use std::{net::SocketAddr, error::Error, fs::File, io::BufReader};
use std::sync::Arc;
use quinn::{Endpoint, ServerConfig, Connection};
use ctrlc;
use tokio::io::{WriteHalf, ReadHalf};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use std::net::Ipv4Addr;
use std::path::Path;
use std::fs;
use libc;
use iptables;
use pnet::packet::{ipv4::{MutableIpv4Packet, checksum}, ip::IpNextHeaderProtocols, tcp::{MutableTcpPacket}, tcp, udp, Packet, udp::MutableUdpPacket};
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    // Logger setup comes from env_logger (filter taken from the environment).
    env_logger::init();
    debug!("Starting server!");
    // Run the accept loop; any error it returns is propagated as the exit result.
    server().await
}
/// Address the QUIC endpoint binds to: every IPv4 interface, port 443.
fn server_addr() -> SocketAddr {
    let any_ipv4 = [0u8, 0, 0, 0];
    let port: u16 = 443;
    SocketAddr::from((any_ipv4, port))
}
async fn server() -> Result<(), Box<dyn Error>> {
const ALPN_QUIC_TUNNEL: &[&[u8]] = &[b"quic-tunnel"];
let (certs, key) = read_certs_from_file().unwrap();
let mut server_crypto = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)?;
server_crypto.alpn_protocols = ALPN_QUIC_TUNNEL.iter().map(|&x| x.into()).collect();
let server_config = ServerConfig::with_crypto(Arc::new(server_crypto));
let endpoint = Endpoint::server(server_config, server_addr())?;
// Start iterating over incoming connections.
while let Some(conn) = endpoint.accept().await {
let connection = conn.await?;
debug!("Received connection!");
let tunnel_subnet = String::from("172.0.0.10");
let tunnel_ip = tunnel_subnet.parse::<Ipv4Addr>().unwrap();
let iface: tokio_tun::Tun = tokio_tun::TunBuilder::new()
.name("tun0")
.tap(false)
.packet_info(false)
.mtu(1100)
.up()
.address(tunnel_ip)
.netmask(Ipv4Addr::new(255, 255, 255, 254))
.try_build().unwrap();
let (reader, writer) = tokio::io::split(iface);
let connection2 = connection.clone();
let receive_fut = receive_datagram(connection, writer);
tokio::spawn(async move {
receive_fut.await;
});
let send_fut = send_datagram(connection2, reader);
tokio::spawn(async move {
send_fut.await;
});
}
Ok(())
}
async fn receive_datagram(connection: Connection, mut tun_writer: WriteHalf<tokio_tun::Tun>) {
let srcip = "172.0.0.11".parse::<Ipv4Addr>().unwrap();
while let Ok(received_bytes) = connection.read_datagram().await {
let mut data = received_bytes.to_vec();
// decode packet and replace src ip
let mut ipv4 = MutableIpv4Packet::new(&mut data).unwrap();
ipv4.set_source(srcip);
match ipv4.get_next_level_protocol() {
IpNextHeaderProtocols::Tcp => {
let destip = ipv4.get_destination();
let mut tcp = MutableTcpPacket::owned(ipv4.payload().to_owned()).unwrap();
tcp.set_checksum(tcp::ipv4_checksum(&tcp.to_immutable(), &srcip, &destip));
ipv4.set_payload(tcp.packet());
}
IpNextHeaderProtocols::Udp => {
let destip = ipv4.get_destination();
let mut udp = MutableUdpPacket::owned(ipv4.payload().to_owned()).unwrap();
udp.set_checksum(udp::ipv4_checksum(&udp.to_immutable(), &srcip, &destip));
ipv4.set_payload(udp.packet());
},
IpNextHeaderProtocols::Icmp => (),
_ => {
debug!("Unknown packet type!");
continue;
},
}
ipv4.set_checksum(checksum(&ipv4.to_immutable()));
let _num_written = tun_writer.write(&ipv4.packet()).await;
}
}
async fn send_datagram(connection: Connection, mut tun_reader: ReadHalf<tokio_tun::Tun>) {
let destip = "10.0.0.2".parse::<Ipv4Addr>().unwrap();
let mut buffer = vec![0; 1100];
loop {
let num_read = tun_reader.read(&mut buffer).await.unwrap();
let mut data = buffer[..num_read].to_owned();
// decode packet and replace src ip
let mut ipv4 = MutableIpv4Packet::new(&mut data).unwrap();
ipv4.set_destination(destip);
match ipv4.get_next_level_protocol() {
IpNextHeaderProtocols::Tcp => {
let srcip = ipv4.get_source();
let mut tcp = MutableTcpPacket::owned(ipv4.payload().to_owned()).unwrap();
tcp.set_checksum(tcp::ipv4_checksum(&tcp.to_immutable(), &srcip, &destip));
ipv4.set_payload(tcp.packet());
}
IpNextHeaderProtocols::Udp => {
let srcip = ipv4.get_source();
let mut udp = MutableUdpPacket::owned(ipv4.payload().to_owned()).unwrap();
udp.set_checksum(udp::ipv4_checksum(&udp.to_immutable(), &srcip, &destip));
ipv4.set_payload(udp.packet());
},
IpNextHeaderProtocols::Icmp => (),
_ => {
debug!("Unknown packet type!");
continue
},
}
ipv4.set_checksum(checksum(&ipv4.to_immutable()));
let out_data = ipv4.packet().to_owned();
if let Err(e) = connection.send_datagram(out_data.into()) {
error!("Error sending datagram back to client: {}", e.to_string());
}
}
}