From 6f5f7943b4729527e56fb5e17515d3ce97c9d0f2 Mon Sep 17 00:00:00 2001 From: Horu <73709188+HigherOrderLogic@users.noreply.github.com> Date: Tue, 19 May 2026 17:43:04 +0000 Subject: [PATCH] feat(serde): allow defining data shape (#157) --- src/de.rs | 201 +++++++++++++++++++++++++++++--- src/se.rs | 334 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 513 insertions(+), 22 deletions(-) diff --git a/src/de.rs b/src/de.rs index b8fc307..99934bb 100644 --- a/src/de.rs +++ b/src/de.rs @@ -37,9 +37,11 @@ //! assert_eq!(config, Config { name: "my-app".into(), port: 8080 }); //! ``` +use std::borrow::Cow; use std::fmt; -use serde::de::{self, DeserializeSeed, Deserializer as _, MapAccess, SeqAccess, Visitor}; +use serde::de::value::StringDeserializer; +use serde::de::{self, DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor}; use serde::Deserialize; use crate::{KdlDocument, KdlEntry, KdlNode, KdlValue}; @@ -691,22 +693,21 @@ impl<'de, 'a> SeqAccess<'de> for NodeGroupSeqAccess<'a> { struct NodeMapAccess<'a> { /// Properties and children combined as (key, value_source). - entries: Vec<(&'a str, NodeMapValue<'a>)>, + entries: Vec<(Cow<'a, str>, NodeMapValue<'a>)>, idx: usize, } enum NodeMapValue<'a> { Arg(&'a KdlValue), + Args(Vec<&'a KdlEntry>), SingleNode(&'a KdlNode), MultiNode(Vec<&'a KdlNode>), } impl<'a> NodeMapAccess<'a> { fn new(node: &'a KdlNode) -> Self { - let mut entries: Vec<(&'a str, NodeMapValue<'a>)> = Vec::new(); + let mut entries: Vec<(Cow<'a, str>, NodeMapValue<'a>)> = Vec::new(); - // Add the first argument as "-" if there are args and also props/children - // (for mixed nodes). Otherwise arguments are not included in map mode. let args: Vec<_> = node .entries() .iter() @@ -718,18 +719,21 @@ impl<'a> NodeMapAccess<'a> { .filter(|e| e.name().is_some()) .collect(); - // If there are only args and no props/children, we shouldn't be in map mode. - // But if we are (struct requested), put args as "-" entries. - if !args.is_empty() && (props.is_empty() && node.children().is_none()) { - for arg in args.iter() { - entries.push(("-", NodeMapValue::Arg(arg.value()))); - } + for (i, arg) in args.iter().enumerate() { + entries.push(( + format!("$argument{}", i + 1).into(), + NodeMapValue::Arg(arg.value()), + )); + } + + if !args.is_empty() { + entries.push(("$arguments".into(), NodeMapValue::Args(args.clone()))); } - // Add properties for prop in &props { let name = prop.name().unwrap().value(); - entries.push((name, NodeMapValue::Arg(prop.value()))); + entries.push((format!("@{}", name).into(), NodeMapValue::Arg(prop.value()))); + entries.push((name.into(), NodeMapValue::Arg(prop.value()))); } // Add children (grouped by node name) @@ -747,9 +751,9 @@ impl<'a> NodeMapAccess<'a> { for key in child_keys { let group = child_groups.remove(key).unwrap(); if group.len() == 1 { - entries.push((key, NodeMapValue::SingleNode(group[0]))); + entries.push((key.into(), NodeMapValue::SingleNode(group[0]))); } else { - entries.push((key, NodeMapValue::MultiNode(group))); + entries.push((key.into(), NodeMapValue::MultiNode(group))); } } } @@ -768,8 +772,11 @@ impl<'de, 'a> MapAccess<'de> for NodeMapAccess<'a> { if self.idx >= self.entries.len() { return Ok(None); } - let key = self.entries[self.idx].0; - seed.deserialize(str_deserializer(key)).map(Some) + match &self.entries[self.idx].0 { + Cow::Borrowed(s) => seed.deserialize(str_deserializer(s)), + Cow::Owned(s) => seed.deserialize(StringDeserializer::new(s.clone())), + } + .map(Some) } fn next_value_seed>( @@ -780,12 +787,57 @@ impl<'de, 'a> MapAccess<'de> for NodeMapAccess<'a> { self.idx += 1; match value { NodeMapValue::Arg(v) => seed.deserialize(ValueDeserializer { value: v }), + NodeMapValue::Args(entries) => seed.deserialize(ArgsSeqDeserializer { + entries: entries.clone(), + }), NodeMapValue::SingleNode(node) => seed.deserialize(NodeDeserializer { node }), NodeMapValue::MultiNode(nodes) => seed.deserialize(NodeGroupDeserializer { nodes }), } } } +struct ArgsSeqDeserializer<'a> { + entries: Vec<&'a KdlEntry>, +} + +impl<'de, 'a> Deserializer<'de> for ArgsSeqDeserializer<'a> { + type Error = Error; + + fn deserialize_any>(self, visitor: V) -> Result { + visitor.visit_seq(ArgSeqAccess { + iter: self.entries.into_iter(), + }) + } + + fn deserialize_seq>(self, visitor: V) -> Result { + visitor.visit_seq(ArgSeqAccess { + iter: self.entries.into_iter(), + }) + } + + fn deserialize_option>(self, visitor: V) -> Result { + if self.entries.is_empty() { + visitor.visit_none() + } else { + visitor.visit_some(self) + } + } + + fn deserialize_newtype_struct>( + self, + _: &'static str, + visitor: V, + ) -> Result { + visitor.visit_newtype_struct(self) + } + + serde::forward_to_deserialize_any! { + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf unit unit_struct tuple + tuple_struct map struct enum identifier ignored_any + } +} + struct NodeEnumAccess<'a> { node: &'a KdlNode, } @@ -1289,4 +1341,119 @@ h 8 } ); } + + #[test] + fn rename_props() { + #[derive(Deserialize, Debug, PartialEq)] + struct Server { + #[serde(rename = "@host")] + host: String, + #[serde(rename = "@port")] + port: u16, + } + + #[derive(Deserialize, Debug, PartialEq)] + struct Config { + server: Server, + } + + let kdl = r#"server host="localhost" port=8080"#; + let config: Config = from_str(kdl).unwrap(); + assert_eq!( + config, + Config { + server: Server { + host: "localhost".into(), + port: 8080, + }, + } + ); + } + + #[test] + fn rename_args() { + #[derive(Deserialize, Debug, PartialEq)] + struct Server { + #[serde(rename = "$argument1")] + host: String, + #[serde(rename = "@port")] + port: u16, + #[serde(rename = "$arguments")] + all: Vec, + } + + #[derive(Deserialize, Debug, PartialEq)] + struct Config { + server: Server, + } + + let kdl = r#"server "localhost" port=8080 "remote""#; + let config: Config = from_str(kdl).unwrap(); + assert_eq!( + config, + Config { + server: Server { + host: "localhost".into(), + port: 8080, + all: Vec::from(["localhost".into(), "remote".into()]) + }, + } + ); + } + + #[test] + fn rename_all_args() { + #[derive(Deserialize, Debug, PartialEq)] + struct Command { + #[serde(rename = "@name")] + name: String, + #[serde(rename = "$arguments")] + args: Vec, + } + + #[derive(Deserialize, Debug, PartialEq)] + struct Config { + command: Command, + } + + let kdl = r#"command name="run" "--verbose" "--output" "result.txt""#; + let config: Config = from_str(kdl).unwrap(); + assert_eq!( + config, + Config { + command: Command { + name: "run".into(), + args: vec!["--verbose".into(), "--output".into(), "result.txt".into(),], + }, + } + ); + } + + #[test] + fn rename_children_args() { + #[derive(Deserialize, Debug, PartialEq)] + struct Server { + #[serde(rename = "@host")] + host: String, + #[serde(rename = "$arguments")] + ports: Vec, + } + + #[derive(Deserialize, Debug, PartialEq)] + struct Config { + server: Server, + } + + let kdl = r#"server host="localhost" 8080 8081 8082"#; + let config: Config = from_str(kdl).unwrap(); + assert_eq!( + config, + Config { + server: Server { + host: "localhost".into(), + ports: vec![8080, 8081, 8082], + }, + } + ); + } } diff --git a/src/se.rs b/src/se.rs index a08e60b..d16bf14 100644 --- a/src/se.rs +++ b/src/se.rs @@ -761,11 +761,24 @@ impl<'a> ser::SerializeStruct for NodeChildMapSerializer<'a> { key: &'static str, value: &T, ) -> Result<(), Self::Error> { - let mut child = KdlNode::new(key); - let mut child_ser = NodeValueSerializer { node: &mut child }; - value.serialize(&mut child_ser)?; - let children = self.node.ensure_children(); - children.nodes_mut().push(child); + if let Some(attr_name) = key.strip_prefix('@') { + let kdl_val = to_kdl_value(value)?; + self.node + .entries_mut() + .push(KdlEntry::new_prop(attr_name, kdl_val)); + } else if key == "$arguments" { + let mut ser = ArgsSerializer { node: self.node }; + value.serialize(&mut ser)?; + } else if key.starts_with("$argument") { + let kdl_val = to_kdl_value(value)?; + self.node.entries_mut().push(KdlEntry::new(kdl_val)); + } else { + let mut child = KdlNode::new(key); + let mut child_ser = NodeValueSerializer { node: &mut child }; + value.serialize(&mut child_ser)?; + let children = self.node.ensure_children(); + children.nodes_mut().push(child); + } Ok(()) } @@ -1170,6 +1183,235 @@ impl ser::Serializer for KdlValueSerializer { } } +struct ArgsSerializer<'a> { + node: &'a mut KdlNode, +} + +impl<'a> ser::Serializer for &'a mut ArgsSerializer<'a> { + type Ok = (); + type Error = Error; + + type SerializeSeq = ArgsSeqSerializer<'a>; + type SerializeTuple = ArgsSeqSerializer<'a>; + type SerializeTupleStruct = ArgsSeqSerializer<'a>; + type SerializeTupleVariant = ser::Impossible<(), Error>; + type SerializeMap = ser::Impossible<(), Error>; + type SerializeStruct = ser::Impossible<(), Error>; + type SerializeStructVariant = ser::Impossible<(), Error>; + + fn serialize_seq(self, _: Option) -> Result { + Ok(ArgsSeqSerializer { node: self.node }) + } + + fn serialize_tuple(self, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_struct( + self, + _: &'static str, + len: usize, + ) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_bool(self, v: bool) -> Result { + self.node.entries_mut().push(KdlEntry::new(v)); + Ok(()) + } + + fn serialize_i8(self, v: i8) -> Result { + self.serialize_i64(v as i64) + } + fn serialize_i16(self, v: i16) -> Result { + self.serialize_i64(v as i64) + } + fn serialize_i32(self, v: i32) -> Result { + self.serialize_i64(v as i64) + } + fn serialize_i64(self, v: i64) -> Result { + self.node + .entries_mut() + .push(KdlEntry::new(KdlValue::Integer(v as i128))); + Ok(()) + } + + fn serialize_u8(self, v: u8) -> Result { + self.serialize_u64(v as u64) + } + fn serialize_u16(self, v: u16) -> Result { + self.serialize_u64(v as u64) + } + fn serialize_u32(self, v: u32) -> Result { + self.serialize_u64(v as u64) + } + fn serialize_u64(self, v: u64) -> Result { + self.node + .entries_mut() + .push(KdlEntry::new(KdlValue::Integer(v as i128))); + Ok(()) + } + + fn serialize_f32(self, v: f32) -> Result { + self.serialize_f64(v as f64) + } + fn serialize_f64(self, v: f64) -> Result { + self.node + .entries_mut() + .push(KdlEntry::new(KdlValue::Float(v))); + Ok(()) + } + + fn serialize_char(self, v: char) -> Result { + self.serialize_str(&v.to_string()) + } + + fn serialize_str(self, v: &str) -> Result { + self.node + .entries_mut() + .push(KdlEntry::new(KdlValue::String(v.to_string()))); + Ok(()) + } + + fn serialize_bytes(self, _: &[u8]) -> Result { + Err(ser::Error::custom( + "bytes cannot be represented as KDL arguments", + )) + } + + fn serialize_none(self) -> Result { + self.node.entries_mut().push(KdlEntry::new(KdlValue::Null)); + Ok(()) + } + + fn serialize_some(self, value: &T) -> Result { + value.serialize(self) + } + + fn serialize_unit(self) -> Result { + self.node.entries_mut().push(KdlEntry::new(KdlValue::Null)); + Ok(()) + } + + fn serialize_unit_struct(self, _: &'static str) -> Result { + self.serialize_unit() + } + + fn serialize_unit_variant( + self, + _: &'static str, + _: u32, + variant: &'static str, + ) -> Result { + self.serialize_str(variant) + } + + fn serialize_newtype_struct( + self, + _: &'static str, + value: &T, + ) -> Result { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _: &'static str, + _: u32, + _: &'static str, + _: &T, + ) -> Result { + Err(ser::Error::custom( + "newtype variants cannot be represented as KDL arguments", + )) + } + + fn serialize_map(self, _: Option) -> Result { + Err(ser::Error::custom( + "maps are cannot be represented as KDL arguments", + )) + } + + fn serialize_struct( + self, + _: &'static str, + _: usize, + ) -> Result { + Err(ser::Error::custom( + "structs are cannot be represented as KDL arguments", + )) + } + + fn serialize_struct_variant( + self, + _: &'static str, + _: u32, + _: &'static str, + _: usize, + ) -> Result { + Err(ser::Error::custom( + "struct variants cannot be represented as KDL arguments", + )) + } + + fn serialize_tuple_variant( + self, + _: &'static str, + _: u32, + _: &'static str, + _: usize, + ) -> Result { + Err(ser::Error::custom( + "tuple variants cannot be represented as KDL arguments", + )) + } +} + +struct ArgsSeqSerializer<'a> { + node: &'a mut KdlNode, +} + +impl<'a> ser::SerializeSeq for ArgsSeqSerializer<'a> { + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> { + let kdl_val = to_kdl_value(value)?; + self.node.entries_mut().push(KdlEntry::new(kdl_val)); + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a> ser::SerializeTuple for ArgsSeqSerializer<'a> { + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> { + ser::SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result { + ser::SerializeSeq::end(self) + } +} + +impl<'a> ser::SerializeTupleStruct for ArgsSeqSerializer<'a> { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> { + ser::SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result { + ser::SerializeSeq::end(self) + } +} + #[cfg(test)] mod tests { use super::*; @@ -1323,4 +1565,86 @@ mod tests { assert_eq!(doc.get("name").unwrap().get(0), Some(&"test".into())); assert_eq!(doc.get("port").unwrap().get(0), Some(&3000.into())); } + + #[test] + fn rename_props() { + #[derive(Serialize)] + struct Server { + #[serde(rename = "@host")] + host: String, + #[serde(rename = "@port")] + port: u16, + } + + #[derive(Serialize)] + struct Config { + server: Server, + } + + let config = Config { + server: Server { + host: "localhost".into(), + port: 8080, + }, + }; + let kdl = to_string(&config).unwrap(); + assert!(kdl.contains("server")); + assert!(kdl.contains("host=localhost")); + assert!(kdl.contains("port=8080")); + } + + #[test] + fn rename_args() { + #[derive(Serialize)] + struct Server { + #[serde(rename = "$argument1")] + host: String, + #[serde(rename = "@port")] + port: u16, + } + + #[derive(Serialize)] + struct Config { + server: Server, + } + + let config = Config { + server: Server { + host: "localhost".into(), + port: 8080, + }, + }; + let kdl = to_string(&config).unwrap(); + assert!(kdl.contains("server")); + assert!(kdl.contains("localhost")); + assert!(kdl.contains("port=8080")); + } + + #[test] + fn rename_all_args() { + #[derive(Serialize)] + struct Command { + #[serde(rename = "@name")] + name: String, + #[serde(rename = "$arguments")] + args: Vec, + } + + #[derive(Serialize)] + struct Config { + command: Command, + } + + let config = Config { + command: Command { + name: "run".into(), + args: vec!["--verbose".into(), "--output".into()], + }, + }; + let kdl = to_string(&config).unwrap(); + assert!(kdl.contains("command")); + assert!(kdl.contains("name=run")); + assert!(kdl.contains("--verbose")); + assert!(kdl.contains("--output")); + } }