From e81b148201ad0f1fda58e2b327394c7ba0ca556b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kat=20March=C3=A1n?= Date: Sat, 30 May 2026 15:07:41 -0700 Subject: [PATCH] feat(serde): Add support for flags and #rest (#163) Fixes: https://github.com/kdl-org/kdl-rs/issues/161 --- Cargo.toml | 2 +- src/de.rs | 331 +++++++++++++++++++++++++++++++++++++++------- src/entry.rs | 14 +- src/identifier.rs | 2 +- src/node.rs | 12 +- src/v2_parser.rs | 71 +++++----- 6 files changed, 340 insertions(+), 92 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6cbe7fe..c754b05 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ homepage = "https://kdl.dev" repository = "https://github.com/kdl-org/kdl-rs" keywords = ["kdl", "document", "serialization", "config"] rust-version = "1.95" -edition = "2021" +edition = "2024" [features] default = ["span", "serde"] diff --git a/src/de.rs b/src/de.rs index 32b6ce7..83d7d86 100644 --- a/src/de.rs +++ b/src/de.rs @@ -18,6 +18,8 @@ //! //! - A **KDL document** is treated as a **map** (or struct), where each top-level //! node name is a key and the node's content is the value. +//! - A **KDL node with no arguments** and no properties/children is treated as +//! a **boolean flag**. //! - A **KDL node with only a single argument** and no properties/children is //! treated as a **scalar value** (the argument itself). //! - A **KDL node with multiple arguments** is treated as a **sequence** of those arguments. @@ -36,20 +38,22 @@ //! struct Config { //! name: String, //! port: u16, +//! active: bool //! } //! //! let kdl = r#" //! name my-app //! port 8080 +//! active //! "#; //! //! let config: Config = kdl::de::from_str(kdl).unwrap(); -//! assert_eq!(config, Config { name: "my-app".into(), port: 8080 }); +//! assert_eq!(config, Config { name: "my-app".into(), port: 8080, active: true }); //! ``` //! //! ## Arguments mapping //! -//! This library supports two kinds of special renaming for arguments - `#{n}` and `#args`: +//! This library supports three kinds of special renaming for arguments - `#{n}`, `#args`, and `#rest`: //! //! - `#{n}`: This is used when you want to access argument at a specific //! position. The field name should start with `#`, then follow by the position @@ -103,6 +107,33 @@ //! assert_eq!(config, Config { server: Server { info: Vec::from(["my-app".into(), "https://example.com".into()]) }}); //! ``` //! +//! - `#rest`: This is used when you want to collect everything that wasn't +//! referenced by an `#{n}` rename. +//! +//! ```rust +//! use serde::Deserialize; +//! +//! #[derive(Deserialize, Debug, PartialEq)] +//! struct Config { +//! command: Command, +//! } +//! +//! #[derive(Deserialize, Debug, PartialEq)] +//! struct Command { +//! #[serde(rename = "#0")] +//! cmd: String, +//! #[serde(rename = "#rest")] +//! args: Vec, +//! } +//! +//! let kdl = r#" +//! command frob "https://example.com" x y +//! "#; +//! +//! let config: Config = kdl::de::from_str(kdl).unwrap(); +//! assert_eq!(config, Config { command: Command { cmd: "frob".into(), args: vec!["https://example.com".into(), "x".into(), "y".into()]}}); +//! ``` +//! //! ## Properties mapping //! //! You can use `#@field-name` on a field that will be used for a property. @@ -130,15 +161,52 @@ //! let config: Config = kdl::de::from_str(kdl).unwrap(); //! assert_eq!(config, Config { server: Server { name: "my-app".into(), port: 8080 }}); //! ``` +//! +//! ## Types mapping +//! +//! You can extract type annotations for nodes, arguments and properties by +//! using `#type`: +//! +//! ```rust +//! use serde::Deserialize; +//! +//! #[derive(Deserialize, Debug, PartialEq)] +//! struct Config { +//! server: Server, +//! } +//! +//! #[derive(Deserialize, Debug, PartialEq)] +//! #[serde(rename_all = "kebab-case")] +//! enum ServerKind { +//! Big, +//! Small, +//! } +//! #[derive(Deserialize, Debug, PartialEq)] +//! struct Server { +//! #[serde(rename = "#type")] +//! server_kind: ServerKind, +//! #[serde(rename = "#0#type")] +//! app_type: String, +//! #[serde(rename = "#@port#type")] +//! port_type: String, +//! } +//! +//! let kdl = r#" +//! (big)server (cool)my-app port=(internal)8080 +//! "#; +//! +//! let config: Config = kdl::de::from_str(kdl).unwrap(); +//! assert_eq!(config, Config { server: Server { server_kind: ServerKind::Big, app_type: "cool".into(), port_type: "internal".into() }}); +//! ``` use std::borrow::Cow; use std::sync::Arc; use std::{fmt, iter}; use miette::{Diagnostic, LabeledSpan, Severity, SourceCode, SourceSpan}; +use serde::Deserialize; use serde::de::value::StringDeserializer; use serde::de::{self, DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor}; -use serde::Deserialize; use crate::{KdlDiagnostic, KdlDocument, KdlEntry, KdlIdentifier, KdlNode, KdlValue}; @@ -698,8 +766,8 @@ impl<'de, 'a> de::Deserializer<'de> for NodeDeserializer<'a> { fn deserialize_any>(self, visitor: V) -> Result { if self.is_empty() { - // Node with no data → null/unit - return visitor.visit_unit(); + // Node with no data → bool flag + return visitor.visit_bool(true); } if self.is_scalar() { // Single argument → deserialize as scalar @@ -726,6 +794,8 @@ impl<'de, 'a> de::Deserializer<'de> for NodeDeserializer<'a> { fn deserialize_bool>(self, visitor: V) -> Result { if self.is_scalar() { ValueDeserializer::new(self.args()[0], self.input).deserialize_any(visitor) + } else if self.is_empty() { + visitor.visit_bool(true) } else { Err(Error::new("expected a boolean value")).map_err(|err| { err.with_span( @@ -866,16 +936,16 @@ impl<'de, 'a> de::Deserializer<'de> for NodeDeserializer<'a> { } fn deserialize_map>(self, visitor: V) -> Result { - visitor.visit_map(NodeMapAccess::new(self.node, self.input)) + visitor.visit_map(NodeMapAccess::new(self.node, self.input, None)) } fn deserialize_struct>( self, _name: &'static str, - _fields: &'static [&'static str], + fields: &'static [&'static str], visitor: V, ) -> Result { - self.deserialize_map(visitor) + visitor.visit_map(NodeMapAccess::new(self.node, self.input, Some(fields))) } fn deserialize_enum>( @@ -885,18 +955,18 @@ impl<'de, 'a> de::Deserializer<'de> for NodeDeserializer<'a> { visitor: V, ) -> Result { // If the node is a scalar string, treat it as a unit variant name. - if self.is_scalar() { - if let KdlValue::String(s) = self.args()[0].value() { - return visitor - .visit_enum(str_deserializer(s.as_str())) - .map_err(|err| { - err.with_span( - self.input, - entry_span(self.args()[0]), - "while deserializing this enum variant", - ) - }); - } + if self.is_scalar() + && let KdlValue::String(s) = self.args()[0].value() + { + return visitor + .visit_enum(str_deserializer(s.as_str())) + .map_err(|err| { + err.with_span( + self.input, + entry_span(self.args()[0]), + "while deserializing this enum variant", + ) + }); } // Otherwise, try treating the node as an externally-tagged enum: @@ -1042,6 +1112,7 @@ impl<'de, 'a> SeqAccess<'de> for NodeGroupSeqAccess<'a> { } } +#[derive(Debug)] struct NodeMapAccess<'a> { /// Properties and children combined as (key, value_source). entries: Vec<(Cow<'a, str>, NodeMapValue<'a>)>, @@ -1049,51 +1120,86 @@ struct NodeMapAccess<'a> { input: &'a Arc, } +#[derive(Debug)] enum NodeMapValue<'a> { Ident(&'a KdlIdentifier), Arg(&'a KdlEntry), Args(Vec<&'a KdlEntry>), SingleNode(&'a KdlNode), MultiNode(Vec<&'a KdlNode>), + Omitted, } impl<'a> NodeMapAccess<'a> { - fn new(node: &'a KdlNode, input: &'a Arc) -> Self { + fn new( + node: &'a KdlNode, + input: &'a Arc, + fields: Option<&'static [&'static str]>, + ) -> Self { let mut entries: Vec<(Cow<'a, str>, NodeMapValue<'a>)> = Vec::new(); - entries.push(("#name".into(), NodeMapValue::Ident(node.name()))); - if let Some(ty) = node.ty() { - entries.push(("#type".into(), NodeMapValue::Ident(ty))) + if let Some(fields) = fields { + if fields.contains(&"#name") { + entries.push(("#name".into(), NodeMapValue::Ident(node.name()))); + } + if fields.contains(&"#type") + && let Some(ty) = node.ty() + { + entries.push(("#type".into(), NodeMapValue::Ident(ty))); + } } - let args: Vec<_> = node - .entries() - .iter() - .filter(|e| e.name().is_none()) - .collect(); - let props: Vec<_> = node - .entries() - .iter() - .filter(|e| e.name().is_some()) - .collect(); - - for (i, arg) in args.iter().enumerate() { - entries.push((format!("#{}", i).into(), NodeMapValue::Arg(arg))); - if let Some(ty) = arg.ty() { - entries.push((format!("#{}#type", i).into(), NodeMapValue::Ident(ty))); + if let Some(fields) = fields + && let collect_all = fields.contains(&"#args") + && let collect_rest = fields.contains(&"#rest") + { + let mut args = Vec::new(); + let mut rest = Vec::new(); + for (i, arg) in node + .entries() + .iter() + .filter(|e| e.name().is_none()) + .enumerate() + { + if collect_all { + args.push(arg); + } + let idx_name = format!("#{i}"); + let ty_name = format!("{idx_name}#type"); + if let Some(ty) = arg.ty() { + entries.push((ty_name.into(), NodeMapValue::Ident(ty))); + } + if fields.contains(&idx_name.as_ref()) { + entries.push((idx_name.into(), NodeMapValue::Arg(arg))); + } else if collect_rest { + rest.push(arg); + } + } + if collect_all { + entries.push(("#args".into(), NodeMapValue::Args(args))); + } + if collect_rest { + entries.push(("#rest".into(), NodeMapValue::Args(rest))); } } - if !args.is_empty() { - entries.push(("#args".into(), NodeMapValue::Args(args.clone()))); - } - - for prop in &props { + let mut used_names = Vec::new(); + for prop in node.entries().iter().filter(|e| e.name().is_some()) { + // SAFETY: we just filtered for things with names. let name = prop.name().unwrap().value(); - entries.push((format!("#@{}", name).into(), NodeMapValue::Arg(prop))); - if let Some(ty) = prop.ty() { - entries.push((format!("#@{}#type", name).into(), NodeMapValue::Ident(ty))); + if let Some(fields) = fields { + let hash_name = format!("#@{name}"); + let ty_name = format!("{hash_name}#type"); + if fields.contains(&hash_name.as_ref()) { + entries.push((hash_name.into(), NodeMapValue::Arg(prop))); + } + if let Some(ty) = prop.ty() + && fields.contains(&ty_name.as_ref()) + { + entries.push((ty_name.into(), NodeMapValue::Ident(ty))); + } } entries.push((name.into(), NodeMapValue::Arg(prop))); + used_names.push(name); } // Add children (grouped by node name) @@ -1103,6 +1209,7 @@ impl<'a> NodeMapAccess<'a> { std::collections::HashMap::new(); for child in children.nodes() { let name = child.name().value(); + used_names.push(name); if !child_groups.contains_key(name) { child_keys.push(name); } @@ -1117,7 +1224,15 @@ impl<'a> NodeMapAccess<'a> { } } } + if let Some(fields) = fields { + for field in fields { + if !field.starts_with('#') && !used_names.contains(field) { + entries.push((String::from(*field).into(), NodeMapValue::Omitted)); + } + } + } + dbg!(&entries); NodeMapAccess { entries, idx: 0, @@ -1166,6 +1281,7 @@ impl<'de, 'a> MapAccess<'de> for NodeMapAccess<'a> { nodes, input: self.input, }), + NodeMapValue::Omitted => seed.deserialize(OmittedChildDeserializer), } } } @@ -1215,6 +1331,36 @@ impl<'de, 'a> Deserializer<'de> for ArgsSeqDeserializer<'a> { } } +struct OmittedChildDeserializer; + +impl<'de> Deserializer<'de> for OmittedChildDeserializer { + type Error = Error; + + fn deserialize_any>(self, visitor: V) -> Result { + visitor.visit_bool(false) + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_bool(false) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_none() + } + + serde::forward_to_deserialize_any! { + i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf unit unit_struct tuple newtype_struct seq + tuple_struct map struct enum identifier ignored_any + } +} + struct NodeEnumAccess<'a> { node: &'a KdlNode, input: &'a Arc, @@ -1885,6 +2031,34 @@ h 8 ); } + #[test] + fn rename_rest_args() { + #[derive(Deserialize, Debug, PartialEq)] + struct Command { + #[serde(rename = "#0")] + name: String, + #[serde(rename = "#rest")] + args: Vec, + } + + #[derive(Deserialize, Debug, PartialEq)] + struct Config { + command: Command, + } + + let kdl = "command run --verbose --output result.txt"; + let config: Config = from_str(kdl).unwrap(); + assert_eq!( + config, + Config { + command: Command { + name: "run".into(), + args: vec!["--verbose".into(), "--output".into(), "result.txt".into(),], + }, + } + ); + } + #[test] fn rename_children_args() { #[derive(Deserialize, Debug, PartialEq)] @@ -1913,6 +2087,69 @@ h 8 ); } + #[test] + fn bool_flag_children() { + #[derive(Deserialize, Debug, PartialEq)] + #[serde(rename_all = "kebab-case")] + struct Server { + host: String, + active: bool, + } + + #[derive(Deserialize, Debug, PartialEq)] + struct Config { + server: Server, + } + + let kdl = "server host=localhost { active }"; + let config: Config = from_str(kdl).unwrap(); + assert_eq!( + config, + Config { + server: Server { + host: "localhost".into(), + active: true, + }, + } + ); + + let kdl = "server host=localhost active=#true"; + let config: Config = from_str(kdl).unwrap(); + assert_eq!( + config, + Config { + server: Server { + host: "localhost".into(), + active: true, + }, + } + ); + + let kdl = "server host=localhost"; + let config: Config = from_str(kdl).unwrap(); + assert_eq!( + config, + Config { + server: Server { + host: "localhost".into(), + active: false, + }, + } + ); + + let kdl = "server host=localhost { }"; + let config: Config = from_str(kdl).unwrap(); + assert_eq!( + config, + Config { + server: Server { + host: "localhost".into(), + active: false, + }, + } + ); + } + // TODO(@zkat): Can't figure out how to do internally tagged stuff just yet... // #[test] // fn internal_tag_rename() { diff --git a/src/entry.rs b/src/entry.rs index 94f87b8..93fcef4 100644 --- a/src/entry.rs +++ b/src/entry.rs @@ -2,7 +2,7 @@ use miette::SourceSpan; use std::{fmt::Display, str::FromStr}; -use crate::{v2_parser, KdlError, KdlIdentifier, KdlValue}; +use crate::{KdlError, KdlIdentifier, KdlValue, v2_parser}; /// KDL Entries are the "arguments" to KDL nodes: either a (positional) /// [`Argument`](https://github.com/kdl-org/kdl/blob/main/SPEC.md#argument) or @@ -230,7 +230,8 @@ impl KdlEntry { let s = x.value_repr.trim(); // convert raw strings to new format let s = s.strip_prefix('r').unwrap_or(s); - let s = if crate::value::is_plain_ident(val) { + + if crate::value::is_plain_ident(val) { val.into() } else if s .find(|c| v2_parser::NEWLINES.iter().any(|nl| nl.contains(c))) @@ -258,8 +259,7 @@ impl KdlEntry { } else { // We're all good! Let's move on. s.to_string() - }; - s + } } // These have `#` prefixes now. The regular Display impl will // take care of that. @@ -301,7 +301,8 @@ impl KdlEntry { } else { s.to_string() }; - let s = if crate::value::is_plain_ident(val) + + if crate::value::is_plain_ident(val) && !s.starts_with('\"') && !s.starts_with("r#") { @@ -340,8 +341,7 @@ impl KdlEntry { } else { // We're all good! Let's move on. s.to_string() - }; - s + } } // No more # prefix for these KdlValue::Bool(b) => b.to_string(), diff --git a/src/identifier.rs b/src/identifier.rs index 1522cc8..49762f6 100644 --- a/src/identifier.rs +++ b/src/identifier.rs @@ -2,7 +2,7 @@ use miette::SourceSpan; use std::{fmt::Display, str::FromStr}; -use crate::{v2_parser, KdlError, KdlValue}; +use crate::{KdlError, KdlValue, v2_parser}; /// Represents a KDL /// [Identifier](https://github.com/kdl-org/kdl/blob/main/SPEC.md#identifier). diff --git a/src/node.rs b/src/node.rs index af9914b..6b3c75f 100644 --- a/src/node.rs +++ b/src/node.rs @@ -10,8 +10,8 @@ use std::{ use miette::SourceSpan; use crate::{ - v2_parser, FormatConfig, KdlDocument, KdlDocumentFormat, KdlEntry, KdlError, KdlIdentifier, - KdlValue, + FormatConfig, KdlDocument, KdlDocumentFormat, KdlEntry, KdlError, KdlIdentifier, KdlValue, + v2_parser, }; /// Represents an individual KDL @@ -291,10 +291,10 @@ impl KdlNode { if !terminator.starts_with('\n') { *terminator = "\n".into(); } - if let Some(c) = trailing.chars().next() { - if !c.is_whitespace() { - trailing.insert(0, ' '); - } + if let Some(c) = trailing.chars().next() + && !c.is_whitespace() + { + trailing.insert(0, ' '); } *before_children = " ".into(); diff --git a/src/v2_parser.rs b/src/v2_parser.rs index 281061b..d1fb22e 100644 --- a/src/v2_parser.rs +++ b/src/v2_parser.rs @@ -7,7 +7,8 @@ use miette::{Severity, SourceSpan}; use num_traits::CheckedMul; use winnow::{ - ascii::{digit1, hex_digit1, oct_digit1, Caseless}, + LocatingSlice, + ascii::{Caseless, digit1, hex_digit1, oct_digit1}, combinator::{ alt, cut_err, empty, eof, fail, not, opt, peek, preceded, repeat, repeat_till, separated, terminated, trace, @@ -16,7 +17,6 @@ use winnow::{ prelude::*, stream::{AsChar, Location, Recover, Recoverable, Stream}, token::{any, none_of, one_of, take_while}, - LocatingSlice, }; use crate::{ @@ -270,10 +270,10 @@ pub(crate) fn document(input: &mut Input<'_>) -> PResult { if badend { document.parse_next(input)?; } - if let Some(bom) = bom { - if let Some(fmt) = doc.format_mut() { - fmt.leading = format!("{bom}{}", fmt.leading); - } + if let Some(bom) = bom + && let Some(fmt) = doc.format_mut() + { + fmt.leading = format!("{bom}{}", fmt.leading); } Ok(doc) } @@ -300,11 +300,11 @@ fn nodes(input: &mut Input<'_>) -> PResult { // If there is a node, let it have the leading format // This gives more consistent behavior - if let Some(first_node) = ns.get_mut(0) { - if let Some(first_node_format) = first_node.format_mut() { - first_node_format.leading = leading.into(); - leading = ""; - } + if let Some(first_node) = ns.get_mut(0) + && let Some(first_node_format) = first_node.format_mut() + { + first_node_format.leading = leading.into(); + leading = ""; } Ok(KdlDocument { @@ -1467,9 +1467,11 @@ mod string_tests { Some(KdlValue::String("\"\"\"".into())) ); - assert!(string - .parse(new_input("\"\"\"\nfoo\n bar\n baz\n \"\"\"")) - .is_err()); + assert!( + string + .parse(new_input("\"\"\"\nfoo\n bar\n baz\n \"\"\"")) + .is_err() + ); } #[test] @@ -1507,9 +1509,11 @@ mod string_tests { .unwrap(), Some(KdlValue::String("foo\n \\nbar\n baz".into())) ); - assert!(string - .parse(new_input("#\"\"\"\nfoo\n bar\n baz\n \"\"\"#")) - .is_err()); + assert!( + string + .parse(new_input("#\"\"\"\nfoo\n bar\n baz\n \"\"\"#")) + .is_err() + ); assert!(string.parse(new_input("#\"\nfoo\nbar\nbaz\n\"#")).is_err()); assert!(string.parse(new_input("\"\nfoo\nbar\nbaz\n\"")).is_err()); @@ -1699,9 +1703,11 @@ fn multi_line_comment_test() { assert!(multi_line_comment.parse(new_input("/*\nfoo*/")).is_ok()); assert!(multi_line_comment.parse(new_input("/*foo\n*/")).is_ok()); assert!(multi_line_comment.parse(new_input("/* foo\n*/")).is_ok()); - assert!(multi_line_comment - .parse(new_input("/* /*bar*/ foo\n*/")) - .is_ok()); + assert!( + multi_line_comment + .parse(new_input("/* /*bar*/ foo\n*/")) + .is_ok() + ); } /// slashdash := '/-' (node-space | line-space)* @@ -1724,15 +1730,18 @@ fn slashdash_tests() { assert!(node_entry.parse(new_input("/-commented tada")).is_ok()); assert!(node.parse(new_input("foo /- { }")).is_ok()); assert!(node.parse(new_input("foo /- { bar }")).is_ok()); - assert!(node - .parse(new_input("/- foo bar\nnode /-1 2 { x }")) - .is_ok()); - assert!(node - .parse(new_input("/- foo bar\nnode 2 /-3 { x }")) - .is_ok()); - assert!(node - .parse(new_input("/- foo bar\nnode /-1 2 /-3 { x }")) - .is_ok()); + assert!( + node.parse(new_input("/- foo bar\nnode /-1 2 { x }")) + .is_ok() + ); + assert!( + node.parse(new_input("/- foo bar\nnode 2 /-3 { x }")) + .is_ok() + ); + assert!( + node.parse(new_input("/- foo bar\nnode /-1 2 /-3 { x }")) + .is_ok() + ); } /// `number := keyword-number | hex | octal | binary | decimal` @@ -2031,7 +2040,9 @@ macro_rules! impl_from_str_radix { }; } -impl_from_str_radix!(i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize); +impl_from_str_radix!( + i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize +); trait MaybeNegatable: CheckedMul { fn negated(&self) -> Option;