Serde: custom deserialize untagged enum with tagged variant

Note: The source code in this post is a complete example. To compile it:

cargo add indoc serde serde_yaml serde_derive

I have the following types describing my config file:

#[macro_use]
extern crate indoc;
#[macro_use]
extern crate serde_derive;

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Settings {
    potential: Potential,
}

#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum Potential {
    Single(PotentialKind),
    Sum(Vec<PotentialKind>),
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PotentialKind {
    #[serde(rename = "airebo")] Airebo(PotentialAirebo),
    #[serde(rename = "test-function-zero")] Zero,
    // many many more...
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct PotentialAirebo {
    pub lj_sigma: f64,
}

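I got rid of the #[derive(Deserialize)] on Potential because it swallows errors: if no variant matches, all you get is a generic "data did not match any variant of untagged enum Potential", with nothing about what was actually wrong. After some browsing around serde::de, I came up with the following manual implementation: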

use std::fmt;
use serde::de::{self, IntoDeserializer};

impl<'de> de::Deserialize<'de> for Potential {
    fn deserialize<D: de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        struct MyVisitor;
        impl<'de> de::Visitor<'de> for MyVisitor {
            type Value = Potential;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                write!(formatter, "a potential or array of potentials")
            }

            // A sequence: a sum of potentials.
            fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
                let mut vec = vec![];
                while let Some(pot) = seq.next_element()? {
                    vec.push(pot);
                }
                Ok(Potential::Sum(vec))
            }

            // A bare string such as `test-function-zero`: a unit variant of PotentialKind.
            fn visit_str<E: de::Error>(self, s: &str) -> Result<Self::Value, E> {
                de::Deserialize::deserialize(s.into_deserializer())
                    .map(Potential::Single)
            }

            // A map such as `airebo: {...}`: hand it to PotentialKind's own impl.
            fn visit_map<A: de::MapAccess<'de>>(self, map: A) -> Result<Self::Value, A::Error> {
                de::Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))
                    .map(Potential::Single)
            }
        }

        deserializer.deserialize_any(MyVisitor)
    }
}

This works great for string potentials and arrays, and brings back all of the amazing error messages that were lost when this untagged enum was first introduced. But it doesn’t work for singleton map enum variants (i.e. PotentialKind::Airebo):

#[test]
fn singleton() {
    assert_eq!(
        serde_yaml::from_str::<Settings>(indoc!(r"
            potential:
              airebo:
                lj-sigma: 14
        ")).unwrap(),
        Settings {
            potential: Potential::Single(
                PotentialKind::Airebo(PotentialAirebo { lj_sigma: 14.0 }),
            ),
        },
    );

    // Check that errors contain location info. This rules out hacks such as
    // deserializing to an intermediary like serde_yaml::Value.
    let loc = serde_yaml::from_str::<Settings>(indoc!(r"
        potential:
          airebo:
            lj-sigma: true
    ")).unwrap_err().location().expect("missing location information");

    assert_eq!(loc.line(), 3);
    assert_eq!(loc.column(), 15);
}

Running the test fails:

---- singleton stdout ----
thread 'singleton' panicked at 'called `Result::unwrap()` on an `Err` value:
Message("invalid type: map, expected enum PotentialKind", Some(Pos {
marker: Marker { index: 19, line: 2, col: 8 }, path: "potential" }))',
src/libcore/result.rs:997:5

It seems the problem is as follows:

  • To support arrays and strings, my custom Deserialize impl is forced to use deserialize_any.
  • For a map, deserialize_any only calls visit_map, not visit_enum (hence I didn't bother overriding the latter).
  • Deserializing from a MapAccessDeserializer almost certainly presents the input as a map, even when the target type asks for an enum.
  • Externally tagged enums like PotentialKind (serde's default representation) do not implement visit_map, because they expect to receive a call to visit_enum instead.

So I’m not sure what to do. How do I deserialize this singleton map as an enum variant?

Your code is fine – I released serde 1.0.91 with the fix.


Wow! A fix already online less than three hours after I posted. You're a wizard, David!