How to deserialize untagged enums fast?

I'm currently implementing an existing network protocol in Rust. Enums seemed like a perfect fit to represent the possible packet types. To serialize them correctly, serde has to be told to treat the enum as untagged.

So far so good: serializing and deserializing work, even against other implementations. The actual protocol requires CBOR encoding, but for debugging I can easily switch to JSON; thanks to serde, this takes only minimal code changes.

While benchmarking my implementation, I realized that deserialization was painfully slow. After some fiddling with my source, I learned that the order of the variants within the enum is responsible for the slowdown, so I sorted them by probability of occurrence, which made everything about 6x faster in the average case.
This still doesn't feel right and probably is not very idiomatic.
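A rough, std-only model of why the order matters: the derived `Deserialize` for an `#[serde(untagged)]` enum first buffers the input and then tries each variant in declaration order, keeping the first one that succeeds. The `Value` type, the per-variant decoder functions and `decode_untagged` below are hypothetical stand-ins, not serde's internals; the sketch only shows that a packet matching the last variant pays for every failed attempt before it.

```rust
use std::convert::TryFrom;

// Hypothetical stand-in for an already-buffered CBOR/JSON value (not serde's type).
#[derive(Debug, Clone, PartialEq)]
enum Value {
    Int(u64),
    Text(String),
    Array(Vec<Value>),
}

#[derive(Debug, PartialEq)]
enum PacketVariants {
    Hello(u32, u8, String),
    Bye(u32, String),
    Bla(String),
    Blub(u32),
}

fn as_u32(v: &Value) -> Option<u32> {
    match v {
        Value::Int(n) => u32::try_from(*n).ok(),
        _ => None,
    }
}

fn as_u8(v: &Value) -> Option<u8> {
    match v {
        Value::Int(n) => u8::try_from(*n).ok(),
        _ => None,
    }
}

fn as_text(v: &Value) -> Option<String> {
    match v {
        Value::Text(s) => Some(s.clone()),
        _ => None,
    }
}

// One hypothetical decoder per variant: each either matches the whole value or fails.
fn hello(v: &Value) -> Option<PacketVariants> {
    let Value::Array(items) = v else { return None };
    let [a, b, c] = items.as_slice() else { return None };
    Some(PacketVariants::Hello(as_u32(a)?, as_u8(b)?, as_text(c)?))
}

fn bye(v: &Value) -> Option<PacketVariants> {
    let Value::Array(items) = v else { return None };
    let [a, b] = items.as_slice() else { return None };
    Some(PacketVariants::Bye(as_u32(a)?, as_text(b)?))
}

fn bla(v: &Value) -> Option<PacketVariants> {
    as_text(v).map(PacketVariants::Bla)
}

fn blub(v: &Value) -> Option<PacketVariants> {
    as_u32(v).map(PacketVariants::Blub)
}

type Decoder = fn(&Value) -> Option<PacketVariants>;

// Same order as the enum declaration in the post.
const DECLARATION_ORDER: [Decoder; 4] = [hello, bye, bla, blub];

/// Tries each decoder in turn; returns the packet and how many attempts it took.
fn decode_untagged(v: &Value, order: &[Decoder]) -> Option<(PacketVariants, usize)> {
    order
        .iter()
        .enumerate()
        .find_map(|(i, decode)| decode(v).map(|p| (p, i + 1)))
}

fn main() {
    let samples = [
        Value::Array(vec![Value::Int(1), Value::Int(3), Value::Text("AAA".into())]),
        Value::Array(vec![Value::Int(2), Value::Text("AAA".into())]),
        Value::Text("AAA".into()),
        Value::Int(3),
    ];
    for s in &samples {
        let (packet, attempts) = decode_untagged(s, &DECLARATION_ORDER).unwrap();
        println!("{:?} decoded after {} attempt(s)", packet, attempts);
    }
}
```

Reordering `DECLARATION_ORDER` changes only the attempt counts, because no value here matches more than one variant; that is the same trade-off as reordering the enum.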

A simple Rust program that behaves similarly is the following:

use serde::{Deserialize, Serialize};
use std::io::stdout;
use std::io::Write;

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)] // Ordered by probability of occurrence: for untagged enums, serde tries the variants in declaration order
enum PacketVariants {
    Hello(u32, u8, String),
    Bye(u32, String),
    Bla(String),
    Blub(u32),
}

fn to_cbor(p: &PacketVariants) -> Vec<u8> {
    serde_cbor::to_vec(p).expect("Error serializing packet as cbor.")
}
fn to_json(p: &PacketVariants) -> String {
    serde_json::to_string(p).expect("Error serializing packet as json.")
}

fn deserialize_packet_cbor(runs: i64, buf: Vec<u8>) {
    print!("Loading {} packets: \t", runs);
    stdout().flush().unwrap();

    use std::time::Instant;
    let bench_now = Instant::now();

    for _ in 0..runs {
        let _p: PacketVariants = serde_cbor::from_slice(&buf).expect("Decoding packet failed");
    }
    let elapsed = bench_now.elapsed();
    let sec = elapsed.as_secs_f64();
    println!("{} packets/second", (runs as f64 / sec) as i64);
}
fn deserialize_packet_json(runs: i64, buf: String) {
    print!("Loading {} packets: \t", runs);
    stdout().flush().unwrap();

    use std::time::Instant;
    let bench_now = Instant::now();

    for _ in 0..runs {
        let _p: PacketVariants = serde_json::from_str(&buf).expect("Decoding packet failed");
    }
    let elapsed = bench_now.elapsed();
    let sec = elapsed.as_secs_f64();
    println!("{} packets/second", (runs as f64 / sec) as i64);
}
fn main() {
    let hello = PacketVariants::Hello(1, 3, "AAA".into());
    let bye = PacketVariants::Bye(2, "AAA".into());
    let bla = PacketVariants::Bla("AAA".into());
    let blub = PacketVariants::Blub(3);

    // CBOR tests
    let p1 = to_cbor(&hello);
    let p2 = to_cbor(&bye);
    let p3 = to_cbor(&bla);
    let p4 = to_cbor(&blub);
    println!("{:02x?}", p1);
    println!("{:02x?}", p2);
    println!("{:02x?}", p3);
    println!("{:02x?}", p4);

    deserialize_packet_cbor(100_000, p1);
    deserialize_packet_cbor(100_000, p2);
    deserialize_packet_cbor(100_000, p3);
    deserialize_packet_cbor(100_000, p4);

    // JSON tests
    let p1 = to_json(&hello);
    let p2 = to_json(&bye);
    let p3 = to_json(&bla);
    let p4 = to_json(&blub);
    println!("{:?}", p1);
    println!("{:?}", p2);
    println!("{:?}", p3);
    println!("{:?}", p4);

    deserialize_packet_json(100_000, p1);
    deserialize_packet_json(100_000, p2);
    deserialize_packet_json(100_000, p3);
    deserialize_packet_json(100_000, p4);
}

These dependencies are required:

serde = { version = "1.0", features = ["derive"] }
serde_derive = "1"
serde_cbor = "0.9"
serde_json = "1.0"

And here is the CBOR benchmark output:

Loading 100000 packets: 	3192034 packets/second
Loading 100000 packets: 	1094580 packets/second
Loading 100000 packets: 	693023 packets/second
Loading 100000 packets: 	594142 packets/second

Each line corresponds to one packet variant, in declaration order (Hello, Bye, Bla, Blub).

Is my only alternative to write a manual packet parser instead of using serde? Is there a way to make serde's parsing smarter, so that it looks ahead to see whether another field is coming or whether the type fits one of the expected ones? I guess the biggest problem for serde is that JSON, for example, doesn't announce array lengths or types up front, whereas CBOR or msgpack usually provide rough type information or the length of an array. Should solutions for this be implemented in, say, serde_cbor?
On the other hand, maybe I am just too new to Rust and missing something really obvious to speed up my program...


You could implement Deserialize manually for your type. I had a go at it, and at least on the playground, decoding from JSON seems faster: Rust Playground (serde_cbor is not available on the playground, so I couldn't test CBOR).

Serde is mainly used for self-describing formats like JSON, YAML and msgpack. If you have to deal with a traditional binary protocol, you may end up writing your own parser/writer, which is actually not that hard in Rust. There are plenty of great crates to help you build your own parser, like "nom".

Anyway, you can still #[derive(Serialize, Deserialize)] on your protocol's in-memory representation and print it as JSON, which can be helpful for debugging.
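For the own-parser route, here is a minimal hand-written sketch (plain std, no nom) for just these four packets. CBOR puts the major type in the top three bits of each item's initial byte and a small argument (the value or a length) in the low five bits, where 24 to 27 mean that a 1-, 2-, 4- or 8-byte argument follows (RFC 8949). Assuming serde_cbor encodes the tuple variants as definite-length arrays, `Bla` as a bare text string and `Blub` as a bare unsigned integer, the first header alone identifies the variant, so nothing is tried and thrown away. `Reader` and `decode_packet` are hypothetical names, and this is not a general CBOR decoder (no indefinite lengths, negative integers, maps, and so on).

```rust
use std::convert::TryFrom;

#[derive(Debug, PartialEq)]
enum PacketVariants {
    Hello(u32, u8, String),
    Bye(u32, String),
    Bla(String),
    Blub(u32),
}

// CBOR major types used by these packets.
const MAJOR_UINT: u8 = 0;
const MAJOR_TEXT: u8 = 3;
const MAJOR_ARRAY: u8 = 4;

/// Hypothetical cursor over a CBOR buffer; supports only the subset above.
struct Reader<'a> {
    buf: &'a [u8],
    pos: usize,
}

impl<'a> Reader<'a> {
    fn byte(&mut self) -> Option<u8> {
        let b = *self.buf.get(self.pos)?;
        self.pos += 1;
        Some(b)
    }

    /// Reads an item header and returns (major type, argument).
    fn header(&mut self) -> Option<(u8, u64)> {
        let initial = self.byte()?;
        let (major, info) = (initial >> 5, initial & 0x1f);
        let arg = match info {
            0..=23 => u64::from(info),
            24..=27 => {
                // 1, 2, 4 or 8 big-endian bytes follow the initial byte.
                let mut v = 0u64;
                for _ in 0..(1u32 << (info - 24)) {
                    v = (v << 8) | u64::from(self.byte()?);
                }
                v
            }
            _ => return None, // indefinite lengths and reserved values: unsupported
        };
        Some((major, arg))
    }

    fn uint(&mut self) -> Option<u64> {
        match self.header()? {
            (MAJOR_UINT, v) => Some(v),
            _ => None,
        }
    }

    fn text_body(&mut self, len: u64) -> Option<String> {
        let end = self.pos.checked_add(usize::try_from(len).ok()?)?;
        let bytes = self.buf.get(self.pos..end)?;
        self.pos = end;
        String::from_utf8(bytes.to_vec()).ok()
    }

    fn text(&mut self) -> Option<String> {
        match self.header()? {
            (MAJOR_TEXT, len) => self.text_body(len),
            _ => None,
        }
    }
}

/// Picks the variant from the first header alone, then decodes it in one pass.
fn decode_packet(buf: &[u8]) -> Option<PacketVariants> {
    let mut r = Reader { buf, pos: 0 };
    let packet = match r.header()? {
        (MAJOR_ARRAY, 3) => PacketVariants::Hello(
            u32::try_from(r.uint()?).ok()?,
            u8::try_from(r.uint()?).ok()?,
            r.text()?,
        ),
        (MAJOR_ARRAY, 2) => PacketVariants::Bye(u32::try_from(r.uint()?).ok()?, r.text()?),
        (MAJOR_TEXT, len) => PacketVariants::Bla(r.text_body(len)?),
        (MAJOR_UINT, v) => PacketVariants::Blub(u32::try_from(v).ok()?),
        _ => return None,
    };
    // Like a whole-buffer decode, reject trailing bytes.
    if r.pos == buf.len() {
        Some(packet)
    } else {
        None
    }
}

fn main() {
    // Bytes assumed to match serde_cbor's output for Hello(1, 3, "AAA") and Blub(3).
    let hello = [0x83, 0x01, 0x03, 0x63, b'A', b'A', b'A'];
    println!("{:?}", decode_packet(&hello));
    println!("{:?}", decode_packet(&[0x03]));
}
```

Because the variant is chosen from the first header, the decoding cost no longer depends on the order of the variants; the price is maintaining the decoding code yourself.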

Here is what I would write for performance, along the lines of @Nemo157's implementation.

Compared with the derived untagged impl, it is 2x faster for PacketVariants::Hello, 5x faster for PacketVariants::Bye, 19x faster for PacketVariants::Bla, and 4x faster for PacketVariants::Blub.


Thanks for the great examples!

I was always a bit afraid of writing visitor code, but with these examples I at least have an idea of how it should look for my use case. I'll probably try something similar to @Nemo157's solution with a single visitor, because the packets have between 8 and 12 fields, and splitting them across many visitors as @dtolnay did seems even more tedious. It is also a bit like writing a manual parser: most of the logic ends up in one large visit_seq function.
Maybe in the future macros could autogenerate this code, as there is no real magic involved; but writing it all still takes quite some time, and doing it manually is error-prone.
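When the format gives no length up front (as with JSON), a single visitor can still avoid backtracking by branching on the type of the second element. The sketch below models that with a plain iterator of already-decoded elements; `Elem`, `SeqPacket` and this `visit_seq` function are hypothetical and not serde's API. In a real `Visitor::visit_seq`, each element is read with `seq.next_element::<T>()?` and `T` has to be chosen before reading, so accepting "integer or string" in the second position would need a small helper type of its own.

```rust
use std::convert::TryFrom;

/// Hypothetical already-decoded element, standing in for what a SeqAccess yields.
#[derive(Debug, Clone, PartialEq)]
enum Elem {
    Int(u64),
    Text(String),
}

/// The two sequence-shaped variants of PacketVariants (Bla and Blub arrive as a
/// bare string / integer and would be handled by other visitor methods).
#[derive(Debug, PartialEq)]
enum SeqPacket {
    Hello(u32, u8, String),
    Bye(u32, String),
}

fn to_u32(e: Elem) -> Result<u32, String> {
    match e {
        Elem::Int(n) => u32::try_from(n).map_err(|_| format!("{} does not fit in u32", n)),
        other => Err(format!("expected integer, found {:?}", other)),
    }
}

/// Single pass over a sequence of unknown length: the type of the second
/// element decides between Hello and Bye, so nothing is decoded twice.
fn visit_seq<I: Iterator<Item = Elem>>(mut seq: I) -> Result<SeqPacket, String> {
    let first = to_u32(seq.next().ok_or("empty sequence")?)?;
    let packet = match seq.next().ok_or("missing second element")? {
        Elem::Text(s) => SeqPacket::Bye(first, s),
        Elem::Int(n) => {
            let second = u8::try_from(n).map_err(|_| format!("{} does not fit in u8", n))?;
            match seq.next().ok_or("missing third element")? {
                Elem::Text(s) => SeqPacket::Hello(first, second, s),
                other => return Err(format!("expected string, found {:?}", other)),
            }
        }
    };
    if seq.next().is_some() {
        return Err("trailing elements".to_string());
    }
    Ok(packet)
}

fn main() {
    let hello = vec![Elem::Int(1), Elem::Int(3), Elem::Text("AAA".into())];
    let bye = vec![Elem::Int(2), Elem::Text("AAA".into())];
    println!("{:?}", visit_seq(hello.into_iter()));
    println!("{:?}", visit_seq(bye.into_iter()));
}
```

With 8 to 12 fields per packet, the same idea applies: find the earliest position where the candidate layouts differ in type and branch there.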

Writing the parser by hand would be straightforward, but I'd rather not reimplement CBOR decoding; the most stable and best-maintained CBOR implementation for Rust is the serde-based one.

Interestingly enough, my not very idiomatic Rust version is still twice as fast as our Go implementation. I love zero-cost abstractions and whatever other magic is responsible for this 🙂

