diff --git a/derive/src/derive_enum.rs b/derive/src/derive_enum.rs index 9a870bc..b6c0508 100644 --- a/derive/src/derive_enum.rs +++ b/derive/src/derive_enum.rs @@ -1,4 +1,4 @@ -use crate::generate::{FnSelfArg, Generator}; +use crate::generate::{FnSelfArg, Generator, StreamBuilder}; use crate::parse::{EnumVariant, Fields}; use crate::prelude::*; use crate::Result; @@ -10,9 +10,15 @@ pub struct DeriveEnum { } impl DeriveEnum { - pub fn generate_encode(self, generator: &mut Generator) -> Result<()> { - let DeriveEnum { variants } = self; + fn iter_fields(&self) -> EnumVariantIterator { + EnumVariantIterator { + idx: 0, + last_val: None, + variants: &self.variants, + } + } + pub fn generate_encode(self, generator: &mut Generator) -> Result<()> { generator .impl_for("bincode::enc::Encode") .unwrap() @@ -25,7 +31,7 @@ impl DeriveEnum { fn_body.ident_str("match"); fn_body.ident_str("self"); fn_body.group(Delimiter::Brace, |match_body| { - for (variant_index, variant) in variants.into_iter().enumerate() { + for (variant_index, variant) in self.iter_fields() { // Self::Variant match_body.ident_str("Self"); match_body.puncts("::"); @@ -62,11 +68,16 @@ impl DeriveEnum { // } match_body.group(Delimiter::Brace, |body| { // variant index - body.push_parsed(format!( - "::encode(&{}, &mut encoder)?;", - variant_index - )) - .unwrap(); + body.push_parsed("::encode") + .unwrap(); + body.group(Delimiter::Parenthesis, |args| { + args.punct('&'); + args.group(Delimiter::Parenthesis, |num| num.extend(variant_index)); + args.punct(','); + args.push_parsed("&mut encoder").unwrap(); + }); + body.punct('?'); + body.punct(';'); // If we have any fields, encode them all one by one for field_name in variant.fields.names() { body.push_parsed(format!( @@ -85,8 +96,73 @@ impl DeriveEnum { Ok(()) } + /// Build the catch-all case for an int-to-enum decode implementation + fn invalid_variant_case(&self, enum_name: &str, result: &mut StreamBuilder) { + // we'll be generating: + // variant => Err( + // bincode::error::DecodeError::UnexpectedVariant { + // found: variant, + // type_name: + // allowed: ..., + // } + // ) + // + // Where allowed is either: + // - bincode::error::AllowedEnumVariants::Range { min: 0, max: } + // if we have no fixed value variants + // - bincode::error::AllowedEnumVariants::Allowed(&[, , ...]) + // if we have fixed value variants + result.ident_str("variant"); + result.puncts("=>"); + result.ident_str("Err"); + result.group(Delimiter::Parenthesis, |err_inner| { + err_inner + .push_parsed("bincode::error::DecodeError::UnexpectedVariant") + .unwrap(); + err_inner.group(Delimiter::Brace, |variant_inner| { + variant_inner.ident_str("found"); + variant_inner.punct(':'); + variant_inner.ident_str("variant"); + variant_inner.punct(','); + + variant_inner.ident_str("type_name"); + variant_inner.punct(':'); + variant_inner.lit_str(enum_name); + variant_inner.punct(','); + + variant_inner.ident_str("allowed"); + variant_inner.punct(':'); + + if self.variants.iter().any(|i| i.has_fixed_value()) { + // we have fixed values, implement AllowedEnumVariants::Allowed + variant_inner + .push_parsed("bincode::error::AllowedEnumVariants::Allowed") + .unwrap(); + variant_inner.group(Delimiter::Parenthesis, |allowed_inner| { + allowed_inner.punct('&'); + allowed_inner.group(Delimiter::Bracket, |allowed_slice| { + for (idx, (ident, _)) in self.iter_fields().enumerate() { + if idx != 0 { + allowed_slice.punct(','); + } + allowed_slice.extend(ident); + } + }); + }); + } else { + // no fixed values, implement a range + variant_inner + .push_parsed(format!( + "bincode::error::AllowedEnumVariants::Range {{ min: 0, max: {} }}", + self.variants.len() - 1 + )) + .unwrap(); + } + }) + }); + } + pub fn generate_decode(self, generator: &mut Generator) -> Result<()> { - let DeriveEnum { variants } = self; let enum_name = generator.target_name().to_string(); if generator.has_lifetimes() { @@ -103,43 +179,44 @@ impl DeriveEnum { .push_parsed("let variant_index = ::decode(&mut decoder)?;").unwrap(); fn_builder.push_parsed("match variant_index").unwrap(); fn_builder.group(Delimiter::Brace, |variant_case| { - for (idx, variant) in variants.iter().enumerate() { - // idx => Ok(..) - variant_case.lit_u32(idx as u32); - variant_case.puncts("=>"); - variant_case.ident_str("Ok"); - variant_case.group(Delimiter::Parenthesis, |variant_case_body| { - // Self::Variant { } - // Self::Variant { 0: ..., 1: ... 2: ... }, - // Self::Variant { a: ..., b: ... c: ... }, - variant_case_body.ident_str("Self"); - variant_case_body.puncts("::"); - variant_case_body.ident(variant.name.clone()); + for (mut variant_index, variant) in self.iter_fields() { + // idx => Ok(..) + if variant_index.len() > 1 { + variant_case.push_parsed("x if x == ").unwrap(); + variant_case.extend(variant_index); + } else { + variant_case.push(variant_index.remove(0)); + } + variant_case.puncts("=>"); + variant_case.ident_str("Ok"); + variant_case.group(Delimiter::Parenthesis, |variant_case_body| { + // Self::Variant { } + // Self::Variant { 0: ..., 1: ... 2: ... }, + // Self::Variant { a: ..., b: ... c: ... }, + variant_case_body.ident_str("Self"); + variant_case_body.puncts("::"); + variant_case_body.ident(variant.name.clone()); - variant_case_body.group(Delimiter::Brace, |variant_body| { - let is_tuple = matches!(variant.fields, Fields::Tuple(_)); - for (idx, field) in variant.fields.names().into_iter().enumerate() { - if is_tuple { - variant_body.lit_usize(idx); - } else { - variant_body.ident(field.unwrap_ident().clone()); + variant_case_body.group(Delimiter::Brace, |variant_body| { + let is_tuple = matches!(variant.fields, Fields::Tuple(_)); + for (idx, field) in variant.fields.names().into_iter().enumerate() { + if is_tuple { + variant_body.lit_usize(idx); + } else { + variant_body.ident(field.unwrap_ident().clone()); + } + variant_body.punct(':'); + variant_body.push_parsed("bincode::de::BorrowDecode::borrow_decode(&mut decoder)?,").unwrap(); } - variant_body.punct(':'); - variant_body.push_parsed("bincode::de::BorrowDecode::borrow_decode(&mut decoder)?,").unwrap(); - } + }); }); - }); - variant_case.punct(','); - } + variant_case.punct(','); + } - // invalid idx - variant_case.push_parsed(format!( - "variant => return Err(bincode::error::DecodeError::UnexpectedVariant {{ min: 0, max: {}, found: variant, type_name: {:?} }})", - variants.len() - 1, - enum_name.to_string() - )).unwrap(); - }); - }).unwrap(); + // invalid idx + self.invalid_variant_case(&enum_name, variant_case); + }); + }).unwrap(); } else { // enum has no lifetimes, implement Decode generator.impl_for("bincode::de::Decode") @@ -153,45 +230,79 @@ impl DeriveEnum { .push_parsed("let variant_index = ::decode(&mut decoder)?;").unwrap(); fn_builder.push_parsed("match variant_index").unwrap(); fn_builder.group(Delimiter::Brace, |variant_case| { - for (idx, variant) in variants.iter().enumerate() { - // idx => Ok(..) - variant_case.lit_u32(idx as u32); - variant_case.puncts("=>"); - variant_case.ident_str("Ok"); - variant_case.group(Delimiter::Parenthesis, |variant_case_body| { - // Self::Variant { } - // Self::Variant { 0: ..., 1: ... 2: ... }, - // Self::Variant { a: ..., b: ... c: ... }, - variant_case_body.ident_str("Self"); - variant_case_body.puncts("::"); - variant_case_body.ident(variant.name.clone()); + for (mut variant_index, variant) in self.iter_fields() { + // idx => Ok(..) + if variant_index.len() > 1 { + variant_case.push_parsed("x if x == ").unwrap(); + variant_case.extend(variant_index); + } else { + variant_case.push(variant_index.remove(0)); + } + variant_case.puncts("=>"); + variant_case.ident_str("Ok"); + variant_case.group(Delimiter::Parenthesis, |variant_case_body| { + // Self::Variant { } + // Self::Variant { 0: ..., 1: ... 2: ... }, + // Self::Variant { a: ..., b: ... c: ... }, + variant_case_body.ident_str("Self"); + variant_case_body.puncts("::"); + variant_case_body.ident(variant.name.clone()); - variant_case_body.group(Delimiter::Brace, |variant_body| { - let is_tuple = matches!(variant.fields, Fields::Tuple(_)); - for (idx, field) in variant.fields.names().into_iter().enumerate() { - if is_tuple { - variant_body.lit_usize(idx); - } else { - variant_body.ident(field.unwrap_ident().clone()); + variant_case_body.group(Delimiter::Brace, |variant_body| { + let is_tuple = matches!(variant.fields, Fields::Tuple(_)); + for (idx, field) in variant.fields.names().into_iter().enumerate() { + if is_tuple { + variant_body.lit_usize(idx); + } else { + variant_body.ident(field.unwrap_ident().clone()); + } + variant_body.punct(':'); + variant_body.push_parsed("bincode::de::Decode::decode(&mut decoder)?,").unwrap(); } - variant_body.punct(':'); - variant_body.push_parsed("bincode::de::Decode::decode(&mut decoder)?,").unwrap(); - } + }); }); - }); - variant_case.punct(','); - } + variant_case.punct(','); + } - // invalid idx - variant_case.push_parsed(format!( - "variant => return Err(bincode::error::DecodeError::UnexpectedVariant {{ min: 0, max: {}, found: variant, type_name: {:?} }})", - variants.len() - 1, - enum_name.to_string() - )).unwrap(); - }); - }).unwrap(); + // invalid idx + self.invalid_variant_case(&enum_name, variant_case); + }); + }).unwrap(); } Ok(()) } } + +struct EnumVariantIterator<'a> { + variants: &'a [EnumVariant], + idx: usize, + last_val: Option<(Literal, u32)>, +} + +impl<'a> Iterator for EnumVariantIterator<'a> { + type Item = (Vec, &'a EnumVariant); + + fn next(&mut self) -> Option { + let idx = self.idx; + let variant = self.variants.get(self.idx)?; + self.idx += 1; + + let tokens = if let Fields::Integer(lit) = &variant.fields { + let tree = TokenTree::Literal(lit.clone()); + self.last_val = Some((lit.clone(), 0)); + vec![tree] + } else if let Some((lit, add)) = self.last_val.as_mut() { + *add += 1; + vec![ + TokenTree::Literal(lit.clone()), + TokenTree::Punct(Punct::new('+', Spacing::Alone)), + TokenTree::Literal(Literal::u32_suffixed(*add)), + ] + } else { + vec![TokenTree::Literal(Literal::u32_suffixed(idx as u32))] + }; + + Some((tokens, variant)) + } +} diff --git a/derive/src/error.rs b/derive/src/error.rs index b44dcda..e16ea83 100644 --- a/derive/src/error.rs +++ b/derive/src/error.rs @@ -4,10 +4,19 @@ use std::fmt; #[derive(Debug)] pub enum Error { UnknownDataType(Span), - InvalidRustSyntax(Span), + InvalidRustSyntax { span: Span, expected: String }, ExpectedIdent(Span), } +impl Error { + pub fn wrong_token(token: Option<&TokenTree>, expected: &'static str) -> Result { + Err(Self::InvalidRustSyntax { + span: token.map(|t| t.span()).unwrap_or_else(Span::call_site), + expected: format!("{}, got {:?}", expected, token), + }) + } +} + // helper functions for the unit tests #[cfg(test)] impl Error { @@ -16,7 +25,7 @@ impl Error { } pub fn is_invalid_rust_syntax(&self) -> bool { - matches!(self, Error::InvalidRustSyntax(_)) + matches!(self, Error::InvalidRustSyntax { .. }) } } @@ -26,7 +35,9 @@ impl fmt::Display for Error { Self::UnknownDataType(_) => { write!(fmt, "Unknown data type, only enum and struct are supported") } - Self::InvalidRustSyntax(_) => write!(fmt, "Invalid rust syntax"), + Self::InvalidRustSyntax { expected, .. } => { + write!(fmt, "Invalid rust syntax, expected {}", expected) + } Self::ExpectedIdent(_) => write!(fmt, "Expected ident"), } } @@ -37,7 +48,7 @@ impl Error { let maybe_span = match &self { Error::UnknownDataType(span) | Error::ExpectedIdent(span) - | Error::InvalidRustSyntax(span) => Some(*span), + | Error::InvalidRustSyntax { span, .. } => Some(*span), }; self.throw_with_span(maybe_span.unwrap_or_else(Span::call_site)) } diff --git a/derive/src/generate/stream_builder.rs b/derive/src/generate/stream_builder.rs index 3c44647..ca383fe 100644 --- a/derive/src/generate/stream_builder.rs +++ b/derive/src/generate/stream_builder.rs @@ -125,12 +125,6 @@ impl StreamBuilder { .extend([TokenTree::Literal(Literal::string(str.as_ref()))]); } - /// Add an `u32` value to the stream. - pub fn lit_u32(&mut self, val: u32) { - self.stream - .extend([TokenTree::Literal(Literal::u32_unsuffixed(val))]); - } - /// Add an `usize` value to the stream. pub fn lit_usize(&mut self, val: usize) { self.stream diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 12c09cb..1560208 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -31,7 +31,7 @@ pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream fn derive_encode_inner(input: TokenStream) -> Result { let source = &mut input.into_iter().peekable(); - let _attributes = parse::Attributes::try_take(source)?; + let _attributes = parse::Attribute::try_take(source)?; let _visibility = parse::Visibility::try_take(source)?; let (datatype, name) = parse::DataType::take(source)?; let generics = parse::Generics::try_take(source)?; @@ -72,7 +72,7 @@ pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream fn derive_decode_inner(input: TokenStream) -> Result { let source = &mut input.into_iter().peekable(); - let _attributes = parse::Attributes::try_take(source)?; + let _attributes = parse::Attribute::try_take(source)?; let _visibility = parse::Visibility::try_take(source)?; let (datatype, name) = parse::DataType::take(source)?; let generics = parse::Generics::try_take(source)?; diff --git a/derive/src/parse/attributes.rs b/derive/src/parse/attributes.rs index 74d194f..7f6dfc9 100644 --- a/derive/src/parse/attributes.rs +++ b/derive/src/parse/attributes.rs @@ -1,38 +1,48 @@ -use super::assume_group; +use super::{assume_group, assume_punct}; use crate::parse::consume_punct_if; use crate::prelude::{Delimiter, Group, Punct, TokenTree}; use crate::{Error, Result}; use std::iter::Peekable; #[derive(Debug)] -pub struct Attributes { +pub struct Attribute { // we don't use these fields yet #[allow(dead_code)] punct: Punct, #[allow(dead_code)] - tokens: Group, + tokens: Option, } -impl Attributes { - pub fn try_take(input: &mut Peekable>) -> Result> { - if let Some(punct) = consume_punct_if(input, '#') { - // found attributes, next token should be a [] group - if let Some(TokenTree::Group(g)) = input.peek() { - if g.delimiter() != Delimiter::Bracket { - return Err(Error::InvalidRustSyntax(g.span())); +impl Attribute { + pub fn try_take(input: &mut Peekable>) -> Result> { + let mut result = Vec::new(); + + while let Some(punct) = consume_punct_if(input, '#') { + match input.peek() { + Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { + result.push(Attribute { + punct, + tokens: Some(assume_group(input.next())), + }); } - return Ok(Some(Attributes { - punct, - tokens: assume_group(input.next()), - })); + Some(TokenTree::Group(g)) => { + return Err(Error::InvalidRustSyntax { + span: g.span(), + expected: format!("[] bracket, got {:?}", g.delimiter()), + }); + } + Some(TokenTree::Punct(p)) if p.as_char() == '#' => { + // sometimes with empty lines of doc comments, we get two #'s in a row + // add an empty attributes and continue to the next loop + result.push(Attribute { + punct: assume_punct(input.next(), '#'), + tokens: None, + }) + } + token => return Error::wrong_token(token, "[] group or next # attribute"), } - // expected [] group, found something else - return Err(Error::InvalidRustSyntax(match input.peek() { - Some(next_token) => next_token.span(), - None => punct.span(), - })); } - Ok(None) + Ok(result) } } @@ -41,14 +51,14 @@ fn test_attributes_try_take() { use crate::token_stream; let stream = &mut token_stream("struct Foo;"); - assert!(Attributes::try_take(stream).unwrap().is_none()); + assert!(Attribute::try_take(stream).unwrap().is_empty()); match stream.next().unwrap() { TokenTree::Ident(i) => assert_eq!(i, "struct"), x => panic!("Expected ident, found {:?}", x), } let stream = &mut token_stream("#[cfg(test)] struct Foo;"); - assert!(Attributes::try_take(stream).unwrap().is_some()); + assert!(!Attribute::try_take(stream).unwrap().is_empty()); match stream.next().unwrap() { TokenTree::Ident(i) => assert_eq!(i, "struct"), x => panic!("Expected ident, found {:?}", x), diff --git a/derive/src/parse/body.rs b/derive/src/parse/body.rs index ff9689d..f6f7b29 100644 --- a/derive/src/parse/body.rs +++ b/derive/src/parse/body.rs @@ -1,6 +1,8 @@ -use super::{assume_group, assume_ident, read_tokens_until_punct, Attributes, Visibility}; +use super::{ + assume_group, assume_ident, assume_punct, read_tokens_until_punct, Attribute, Visibility, +}; use crate::parse::consume_punct_if; -use crate::prelude::{Delimiter, Ident, Span, TokenTree}; +use crate::prelude::{Delimiter, Ident, Literal, Span, TokenTree}; use crate::{Error, Result}; use std::iter::Peekable; @@ -18,19 +20,19 @@ impl StructBody { fields: Fields::Unit, }) } - Some(t) => { - return Err(Error::InvalidRustSyntax(t.span())); - } - _ => { - return Err(Error::InvalidRustSyntax(Span::call_site())); - } + token => return Error::wrong_token(token, "group or punct"), } let group = assume_group(input.next()); let mut stream = group.stream().into_iter().peekable(); let fields = match group.delimiter() { Delimiter::Brace => Fields::Struct(UnnamedField::parse_with_name(&mut stream)?), Delimiter::Parenthesis => Fields::Tuple(UnnamedField::parse(&mut stream)?), - _ => return Err(Error::InvalidRustSyntax(group.span())), + found => { + return Err(Error::InvalidRustSyntax { + span: group.span(), + expected: format!("brace or parenthesis, found {:?}", found), + }) + } }; Ok(StructBody { fields }) } @@ -124,37 +126,57 @@ impl EnumBody { variants: Vec::new(), }) } - Some(t) => { - return Err(Error::InvalidRustSyntax(t.span())); - } - _ => { - return Err(Error::InvalidRustSyntax(Span::call_site())); - } + token => return Error::wrong_token(token, "group or ;"), } let group = assume_group(input.next()); let mut variants = Vec::new(); let stream = &mut group.stream().into_iter().peekable(); while stream.peek().is_some() { - let attributes = Attributes::try_take(stream)?; + let attributes = Attribute::try_take(stream)?; let ident = match stream.peek() { Some(TokenTree::Ident(_)) => assume_ident(stream.next()), - Some(x) => return Err(Error::InvalidRustSyntax(x.span())), - None => return Err(Error::InvalidRustSyntax(Span::call_site())), + token => return Error::wrong_token(token, "ident"), }; let mut fields = Fields::Unit; - if let Some(TokenTree::Group(_)) = stream.peek() { - let group = assume_group(stream.next()); - let stream = &mut group.stream().into_iter().peekable(); - match group.delimiter() { - Delimiter::Brace => { - fields = Fields::Struct(UnnamedField::parse_with_name(stream)?) + match stream.peek() { + Some(TokenTree::Group(_)) => { + let group = assume_group(stream.next()); + let stream = &mut group.stream().into_iter().peekable(); + match group.delimiter() { + Delimiter::Brace => { + fields = Fields::Struct(UnnamedField::parse_with_name(stream)?) + } + Delimiter::Parenthesis => { + fields = Fields::Tuple(UnnamedField::parse(stream)?) + } + delim => { + return Err(Error::InvalidRustSyntax { + span: group.span(), + expected: format!("Brace or parenthesis, found {:?}", delim), + }) + } } - Delimiter::Parenthesis => fields = Fields::Tuple(UnnamedField::parse(stream)?), - _ => return Err(Error::InvalidRustSyntax(group.span())), } + Some(TokenTree::Punct(p)) if p.as_char() == '=' => { + assume_punct(stream.next(), '='); + match stream.next() { + Some(TokenTree::Literal(lit)) => { + fields = Fields::Integer(lit); + } + token => return Error::wrong_token(token.as_ref(), "literal"), + } + } + Some(TokenTree::Punct(p)) if p.as_char() == ',' => { + // next field + } + None => { + // group done + } + token => return Error::wrong_token(token, "group, comma or ="), } + consume_punct_if(stream, ','); variants.push(EnumVariant { @@ -209,7 +231,13 @@ fn test_enum_body_take() { pub struct EnumVariant { pub name: Ident, pub fields: Fields, - pub attributes: Option, + pub attributes: Vec, +} + +impl EnumVariant { + pub fn has_fixed_value(&self) -> bool { + matches!(&self.fields, Fields::Integer(_)) + } } #[derive(Debug)] @@ -223,6 +251,14 @@ pub enum Fields { /// ``` Unit, + /// Variant with an integer value. + /// ```rs + /// enum Foo { + /// Baz = 5, + /// } + /// ``` + Integer(Literal), + /// Tuple-like variant /// ```rs /// enum Foo { @@ -258,7 +294,7 @@ impl Fields { .iter() .map(|(ident, _)| IdentOrIndex::Ident(ident)) .collect(), - Self::Unit => Vec::new(), + Self::Unit | Self::Integer(_) => Vec::new(), } } @@ -266,7 +302,7 @@ impl Fields { match self { Self::Tuple(_) => Some(Delimiter::Parenthesis), Self::Struct(_) => Some(Delimiter::Brace), - Self::Unit => None, + Self::Unit | Self::Integer(_) => None, } } } @@ -282,6 +318,7 @@ impl Fields { Self::Tuple(fields) => fields.len(), Self::Struct(fields) => fields.len(), Self::Unit => 0, + Self::Integer(_) => 0, } } @@ -290,6 +327,7 @@ impl Fields { Self::Tuple(fields) => fields.get(index).map(|f| (None, f)), Self::Struct(fields) => fields.get(index).map(|(ident, field)| (Some(ident), field)), Self::Unit => None, + Self::Integer(_) => None, } } } @@ -298,7 +336,7 @@ impl Fields { pub struct UnnamedField { pub vis: Visibility, pub r#type: Vec, - pub attributes: Option, + pub attributes: Vec, } impl UnnamedField { @@ -307,20 +345,24 @@ impl UnnamedField { ) -> Result> { let mut result = Vec::new(); loop { - let attributes = Attributes::try_take(input)?; + let attributes = Attribute::try_take(input)?; let vis = Visibility::try_take(input)?; let ident = match input.peek() { Some(TokenTree::Ident(_)) => assume_ident(input.next()), - Some(x) => return Err(Error::InvalidRustSyntax(x.span())), + Some(x) => { + return Err(Error::InvalidRustSyntax { + span: x.span(), + expected: format!("ident or end of group, got {:?}", x), + }) + } None => break, }; match input.peek() { Some(TokenTree::Punct(p)) if p.as_char() == ':' => { input.next(); } - Some(x) => return Err(Error::InvalidRustSyntax(x.span())), - None => return Err(Error::InvalidRustSyntax(Span::call_site())), + token => return Error::wrong_token(token, ":"), } let r#type = read_tokens_until_punct(input, &[','])?; consume_punct_if(input, ','); @@ -339,7 +381,7 @@ impl UnnamedField { pub fn parse(input: &mut Peekable>) -> Result> { let mut result = Vec::new(); while input.peek().is_some() { - let attributes = Attributes::try_take(input)?; + let attributes = Attribute::try_take(input)?; let vis = Visibility::try_take(input)?; let r#type = read_tokens_until_punct(input, &[','])?; diff --git a/derive/src/parse/data_type.rs b/derive/src/parse/data_type.rs index 64d0760..ae7b19b 100644 --- a/derive/src/parse/data_type.rs +++ b/derive/src/parse/data_type.rs @@ -1,4 +1,4 @@ -use crate::prelude::{Ident, Span, TokenTree}; +use crate::prelude::{Ident, TokenTree}; use crate::{Error, Result}; use std::iter::Peekable; @@ -10,24 +10,19 @@ pub enum DataType { impl DataType { pub fn take(input: &mut Peekable>) -> Result<(Self, Ident)> { - if let Some(TokenTree::Ident(ident)) = input.peek() { + if let Some(TokenTree::Ident(_)) = input.peek() { + let ident = super::assume_ident(input.next()); let result = match ident.to_string().as_str() { "struct" => DataType::Struct, "enum" => DataType::Enum, _ => return Err(Error::UnknownDataType(ident.span())), }; - let ident = super::assume_ident(input.next()); return match input.next() { Some(TokenTree::Ident(ident)) => Ok((result, ident)), - Some(t) => Err(Error::InvalidRustSyntax(t.span())), - None => Err(Error::InvalidRustSyntax(ident.span())), + token => Error::wrong_token(token.as_ref(), "ident"), }; } - let span = input - .peek() - .map(|t| t.span()) - .unwrap_or_else(Span::call_site); - Err(Error::InvalidRustSyntax(span)) + Error::wrong_token(input.peek(), "ident") } } diff --git a/derive/src/parse/generics.rs b/derive/src/parse/generics.rs index 9495c4c..b574a00 100644 --- a/derive/src/parse/generics.rs +++ b/derive/src/parse/generics.rs @@ -7,7 +7,7 @@ use std::iter::Peekable; #[derive(Debug)] pub struct Generics { - lifetimes_and_generics: Vec, + generics: Vec, } impl Generics { @@ -17,30 +17,31 @@ impl Generics { if punct.as_char() == '<' { let punct = super::assume_punct(input.next(), '<'); let mut result = Generics { - lifetimes_and_generics: Vec::new(), + generics: Vec::new(), }; loop { match input.peek() { Some(TokenTree::Punct(punct)) if punct.as_char() == '\'' => { - result - .lifetimes_and_generics - .push(Lifetime::take(input)?.into()); + result.generics.push(Lifetime::take(input)?.into()); super::consume_punct_if(input, ','); } Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => { assume_punct(input.next(), '>'); break; } + Some(TokenTree::Ident(ident)) if ident_eq(ident, "const") => { + result.generics.push(ConstGeneric::take(input)?.into()); + super::consume_punct_if(input, ','); + } Some(TokenTree::Ident(_)) => { - result - .lifetimes_and_generics - .push(Generic::take(input)?.into()); + result.generics.push(SimpleGeneric::take(input)?.into()); super::consume_punct_if(input, ','); } x => { - return Err(Error::InvalidRustSyntax( - x.map(|x| x.span()).unwrap_or_else(|| punct.span()), - )); + return Err(Error::InvalidRustSyntax { + span: x.map(|x| x.span()).unwrap_or_else(|| punct.span()), + expected: format!("', > or an ident, got {:?}", x), + }); } } } @@ -51,30 +52,19 @@ impl Generics { } pub fn has_lifetime(&self) -> bool { - self.lifetimes_and_generics - .iter() - .any(|lt| lt.is_lifetime()) + self.generics.iter().any(|lt| lt.is_lifetime()) } pub fn impl_generics(&self) -> StreamBuilder { let mut result = StreamBuilder::new(); result.punct('<'); - for (idx, generic) in self.lifetimes_and_generics.iter().enumerate() { + for (idx, generic) in self.generics.iter().enumerate() { if idx > 0 { result.punct(','); } - if generic.is_lifetime() { - result.lifetime(generic.ident()); - } else { - result.ident(generic.ident()); - } - - if generic.has_constraints() { - result.punct(':'); - result.extend(generic.constraints()); - } + generic.append_to_result_with_constraints(&mut result); } result.punct('>'); @@ -91,7 +81,7 @@ impl Generics { if self.has_lifetime() { for (idx, lt) in self - .lifetimes_and_generics + .generics .iter() .filter_map(|lt| lt.as_lifetime()) .enumerate() @@ -101,19 +91,9 @@ impl Generics { } } - for generic in &self.lifetimes_and_generics { + for generic in &self.generics { result.punct(','); - - if generic.is_lifetime() { - result.lifetime(generic.ident()); - } else { - result.ident(generic.ident()); - } - - if generic.has_constraints() { - result.punct(':'); - result.extend(generic.constraints()); - } + generic.append_to_result_with_constraints(&mut result); } result.punct('>'); @@ -125,7 +105,7 @@ impl Generics { let mut result = StreamBuilder::new(); result.punct('<'); - for (idx, generic) in self.lifetimes_and_generics.iter().enumerate() { + for (idx, generic) in self.generics.iter().enumerate() { if idx > 0 { result.punct(','); } @@ -142,27 +122,29 @@ impl Generics { } #[derive(Debug)] -enum LifetimeOrGeneric { +enum Generic { Lifetime(Lifetime), - Generic(Generic), + Generic(SimpleGeneric), + Const(ConstGeneric), } -impl LifetimeOrGeneric { +impl Generic { fn is_lifetime(&self) -> bool { - matches!(self, LifetimeOrGeneric::Lifetime(_)) + matches!(self, Generic::Lifetime(_)) } fn ident(&self) -> Ident { match self { Self::Lifetime(lt) => lt.ident.clone(), Self::Generic(gen) => gen.ident.clone(), + Self::Const(gen) => gen.ident.clone(), } } fn as_lifetime(&self) -> Option<&Lifetime> { match self { Self::Lifetime(lt) => Some(lt), - Self::Generic(_) => None, + _ => None, } } @@ -170,6 +152,7 @@ impl LifetimeOrGeneric { match self { Self::Lifetime(lt) => !lt.constraint.is_empty(), Self::Generic(gen) => !gen.constraints.is_empty(), + Self::Const(_) => true, // const generics always have a constraint } } @@ -177,22 +160,46 @@ impl LifetimeOrGeneric { match self { Self::Lifetime(lt) => lt.constraint.clone(), Self::Generic(gen) => gen.constraints.clone(), + Self::Const(gen) => gen.constraints.clone(), + } + } + + fn append_to_result_with_constraints(&self, builder: &mut StreamBuilder) { + match self { + Self::Lifetime(lt) => builder.lifetime(lt.ident.clone()), + Self::Generic(gen) => { + builder.ident(gen.ident.clone()); + } + Self::Const(gen) => { + builder.ident(gen.const_token.clone()); + builder.ident(gen.ident.clone()); + } + } + if self.has_constraints() { + builder.punct(':'); + builder.extend(self.constraints()); } } } -impl From for LifetimeOrGeneric { +impl From for Generic { fn from(lt: Lifetime) -> Self { Self::Lifetime(lt) } } -impl From for LifetimeOrGeneric { - fn from(gen: Generic) -> Self { +impl From for Generic { + fn from(gen: SimpleGeneric) -> Self { Self::Generic(gen) } } +impl From for Generic { + fn from(gen: ConstGeneric) -> Self { + Self::Const(gen) + } +} + #[test] fn test_generics_try_take() { use crate::token_stream; @@ -210,18 +217,18 @@ fn test_generics_try_take() { assert_eq!(data_type, super::DataType::Struct); assert_eq!(ident, "Foo"); let generics = Generics::try_take(stream).unwrap().unwrap(); - assert_eq!(generics.lifetimes_and_generics.len(), 2); - assert_eq!(generics.lifetimes_and_generics[0].ident(), "a"); - assert_eq!(generics.lifetimes_and_generics[1].ident(), "T"); + assert_eq!(generics.generics.len(), 2); + assert_eq!(generics.generics[0].ident(), "a"); + assert_eq!(generics.generics[1].ident(), "T"); let stream = &mut token_stream("struct Foo()"); let (data_type, ident) = super::DataType::take(stream).unwrap(); assert_eq!(data_type, super::DataType::Struct); assert_eq!(ident, "Foo"); let generics = Generics::try_take(stream).unwrap().unwrap(); - assert_eq!(generics.lifetimes_and_generics.len(), 2); - assert_eq!(generics.lifetimes_and_generics[0].ident(), "A"); - assert_eq!(generics.lifetimes_and_generics[1].ident(), "B"); + assert_eq!(generics.generics.len(), 2); + assert_eq!(generics.generics[0].ident(), "A"); + assert_eq!(generics.generics[1].ident(), "B"); let stream = &mut token_stream("struct Foo<'a, T: Display>()"); let (data_type, ident) = super::DataType::take(stream).unwrap(); @@ -229,18 +236,18 @@ fn test_generics_try_take() { assert_eq!(ident, "Foo"); let generics = Generics::try_take(stream).unwrap().unwrap(); dbg!(&generics); - assert_eq!(generics.lifetimes_and_generics.len(), 2); - assert_eq!(generics.lifetimes_and_generics[0].ident(), "a"); - assert_eq!(generics.lifetimes_and_generics[1].ident(), "T"); + assert_eq!(generics.generics.len(), 2); + assert_eq!(generics.generics[0].ident(), "a"); + assert_eq!(generics.generics[1].ident(), "T"); let stream = &mut token_stream("struct Foo<'a, T: for<'a> Bar<'a> + 'static>()"); let (data_type, ident) = super::DataType::take(stream).unwrap(); assert_eq!(data_type, super::DataType::Struct); assert_eq!(ident, "Foo"); dbg!(&generics); - assert_eq!(generics.lifetimes_and_generics.len(), 2); - assert_eq!(generics.lifetimes_and_generics[0].ident(), "a"); - assert_eq!(generics.lifetimes_and_generics[1].ident(), "T"); + assert_eq!(generics.generics.len(), 2); + assert_eq!(generics.generics[0].ident(), "a"); + assert_eq!(generics.generics[1].ident(), "T"); let stream = &mut token_stream( "struct Baz Bar<'a, for<'b> Bar<'b, for<'c> Bar<'c, u32>>>> {}", @@ -250,8 +257,8 @@ fn test_generics_try_take() { assert_eq!(ident, "Baz"); let generics = Generics::try_take(stream).unwrap().unwrap(); dbg!(&generics); - assert_eq!(generics.lifetimes_and_generics.len(), 1); - assert_eq!(generics.lifetimes_and_generics[0].ident(), "T"); + assert_eq!(generics.generics.len(), 1); + assert_eq!(generics.generics[0].ident(), "T"); let stream = &mut token_stream("struct Baz<()> {}"); let (data_type, ident) = super::DataType::take(stream).unwrap(); @@ -267,9 +274,9 @@ fn test_generics_try_take() { assert_eq!(ident, "Bar"); let generics = Generics::try_take(stream).unwrap().unwrap(); dbg!(&generics); - assert_eq!(generics.lifetimes_and_generics.len(), 2); - assert_eq!(generics.lifetimes_and_generics[0].ident(), "A"); - assert_eq!(generics.lifetimes_and_generics[1].ident(), "B"); + assert_eq!(generics.generics.len(), 2); + assert_eq!(generics.generics[0].ident(), "A"); + assert_eq!(generics.generics[1].ident(), "B"); } #[derive(Debug)] @@ -325,12 +332,12 @@ fn test_lifetime_take() { } #[derive(Debug)] -pub struct Generic { +pub struct SimpleGeneric { ident: Ident, constraints: Vec, } -impl Generic { +impl SimpleGeneric { pub fn take(input: &mut Peekable>) -> Result { let ident = super::assume_ident(input.next()); let mut constraints = Vec::new(); @@ -340,7 +347,33 @@ impl Generic { constraints = super::read_tokens_until_punct(input, &['>', ','])?; } } - Ok(Generic { ident, constraints }) + Ok(Self { ident, constraints }) + } +} + +#[derive(Debug)] +pub struct ConstGeneric { + const_token: Ident, + ident: Ident, + constraints: Vec, +} + +impl ConstGeneric { + pub fn take(input: &mut Peekable>) -> Result { + let const_token = super::assume_ident(input.next()); + let ident = super::assume_ident(input.next()); + let mut constraints = Vec::new(); + if let Some(TokenTree::Punct(punct)) = input.peek() { + if punct.as_char() == ':' { + super::assume_punct(input.next(), ':'); + constraints = super::read_tokens_until_punct(input, &['>', ','])?; + } + } + Ok(Self { + const_token, + ident, + constraints, + }) } } @@ -410,8 +443,8 @@ fn test_generic_constraints_try_take() { assert_eq!(data_type, DataType::Struct); assert_eq!(ident, "Test"); let constraints = Generics::try_take(stream).unwrap().unwrap(); - assert_eq!(constraints.lifetimes_and_generics.len(), 1); - assert_eq!(constraints.lifetimes_and_generics[0].ident(), "T"); + assert_eq!(constraints.generics.len(), 1); + assert_eq!(constraints.generics[0].ident(), "T"); let body = StructBody::take(stream).unwrap(); assert_eq!(body.fields.len(), 0); } diff --git a/derive/src/parse/mod.rs b/derive/src/parse/mod.rs index e888abb..e480a45 100644 --- a/derive/src/parse/mod.rs +++ b/derive/src/parse/mod.rs @@ -8,10 +8,10 @@ mod data_type; mod generics; mod visibility; -pub use self::attributes::Attributes; +pub use self::attributes::Attribute; pub use self::body::{EnumBody, EnumVariant, Fields, StructBody, UnnamedField}; pub use self::data_type::DataType; -pub use self::generics::{Generic, GenericConstraints, Generics, Lifetime}; +pub use self::generics::{GenericConstraints, Generics, Lifetime, SimpleGeneric}; pub use self::visibility::Visibility; pub(self) fn assume_group(t: Option) -> Group { @@ -103,7 +103,14 @@ pub(self) fn read_tokens_until_punct( if expected_puncts.contains(&punct.as_char()) { break; } - return Err(Error::InvalidRustSyntax(punct.span())); + return Err(Error::InvalidRustSyntax { + span: punct.span(), + expected: format!( + "one of {:?}, got '{}'", + expected_puncts, + punct.as_char() + ), + }); } }; let expected = OPEN_BRACKETS[index]; diff --git a/src/de/impls.rs b/src/de/impls.rs index 4efff04..dc2143e 100644 --- a/src/de/impls.rs +++ b/src/de/impls.rs @@ -445,8 +445,7 @@ where } x => Err(DecodeError::UnexpectedVariant { found: x as u32, - max: 1, - min: 0, + allowed: crate::error::AllowedEnumVariants::Range { max: 1, min: 0 }, type_name: core::any::type_name::>(), }), } @@ -471,8 +470,7 @@ where } x => Err(DecodeError::UnexpectedVariant { found: x as u32, - max: 1, - min: 0, + allowed: crate::error::AllowedEnumVariants::Range { max: 1, min: 0 }, type_name: core::any::type_name::>(), }), } @@ -539,8 +537,7 @@ where 1 => Ok(Bound::Included(T::decode(decoder)?)), 2 => Ok(Bound::Excluded(T::decode(decoder)?)), x => Err(DecodeError::UnexpectedVariant { - min: 0, - max: 2, + allowed: crate::error::AllowedEnumVariants::Range { max: 2, min: 0 }, found: x, type_name: core::any::type_name::>(), }), diff --git a/src/error.rs b/src/error.rs index 856dc8a..ec97dff 100644 --- a/src/error.rs +++ b/src/error.rs @@ -50,7 +50,7 @@ pub enum EncodeError { /// Errors that can be encounted by decoding a type #[non_exhaustive] -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum DecodeError { /// The reader reached its end but more bytes were expected. UnexpectedEnd, @@ -74,11 +74,8 @@ pub enum DecodeError { /// The type name that was being decoded. type_name: &'static str, - /// The min index of the enum. Usually this is `0`. - min: u32, - - /// the max index of the enum. - max: u32, + /// The variants that are allowed + allowed: AllowedEnumVariants, /// The index of the enum that the decoder encountered found: u32, @@ -126,9 +123,20 @@ impl DecodeError { } } +/// Indicates which enum variants are allowed +#[non_exhaustive] +#[derive(Debug, PartialEq)] +pub enum AllowedEnumVariants { + /// All values between `min` and `max` (inclusive) are allowed + #[allow(missing_docs)] + Range { min: u32, max: u32 }, + /// Each one of these values is allowed + Allowed(&'static [u32]), +} + /// Integer types. Used by [DecodeError]. These types have no purpose other than being shown in errors. #[non_exhaustive] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] #[allow(missing_docs)] pub enum IntegerType { U8, diff --git a/src/features/impl_std.rs b/src/features/impl_std.rs index 649c27c..9d41889 100644 --- a/src/features/impl_std.rs +++ b/src/features/impl_std.rs @@ -252,8 +252,7 @@ impl Decode for IpAddr { 0 => Ok(IpAddr::V4(Ipv4Addr::decode(decoder)?)), 1 => Ok(IpAddr::V6(Ipv6Addr::decode(decoder)?)), found => Err(DecodeError::UnexpectedVariant { - min: 0, - max: 1, + allowed: crate::error::AllowedEnumVariants::Range { min: 0, max: 1 }, found, type_name: core::any::type_name::(), }), @@ -306,8 +305,7 @@ impl Decode for SocketAddr { 0 => Ok(SocketAddr::V4(SocketAddrV4::decode(decoder)?)), 1 => Ok(SocketAddr::V6(SocketAddrV6::decode(decoder)?)), found => Err(DecodeError::UnexpectedVariant { - min: 0, - max: 1, + allowed: crate::error::AllowedEnumVariants::Range { min: 0, max: 1 }, found, type_name: core::any::type_name::(), }), diff --git a/tests/derive.rs b/tests/derive.rs index 0ec0e49..6501235 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -144,3 +144,50 @@ fn test_decode_enum_tuple_variant() { bincode::decode_from_slice(&mut slice, Configuration::standard()).unwrap(); assert_eq!(result, start); } + +#[derive(bincode::Decode, bincode::Encode, PartialEq, Eq, Debug)] +enum CStyleEnum { + A = 1, + B = 2, + C, + D = 5, + E, +} + +#[test] +fn test_c_style_enum() { + fn ser(e: CStyleEnum) -> u8 { + let mut slice = [0u8; 10]; + let bytes_written = + bincode::encode_into_slice(e, &mut slice, Configuration::standard()).unwrap(); + assert_eq!(bytes_written, 1); + slice[0] + } + + assert_eq!(ser(CStyleEnum::A), 1); + assert_eq!(ser(CStyleEnum::B), 2); + assert_eq!(ser(CStyleEnum::C), 3); + assert_eq!(ser(CStyleEnum::D), 5); + assert_eq!(ser(CStyleEnum::E), 6); + + fn de(num: u8) -> Result { + bincode::decode_from_slice(&[num], Configuration::standard()) + } + + fn expected_err(idx: u32) -> Result { + Err(bincode::error::DecodeError::UnexpectedVariant { + type_name: "CStyleEnum", + allowed: bincode::error::AllowedEnumVariants::Allowed(&[1, 2, 3, 5, 6]), + found: idx, + }) + } + + assert_eq!(de(0), expected_err(0)); + assert_eq!(de(1).unwrap(), CStyleEnum::A); + assert_eq!(de(2).unwrap(), CStyleEnum::B); + assert_eq!(de(3).unwrap(), CStyleEnum::C); + assert_eq!(de(4), expected_err(4)); + assert_eq!(de(5).unwrap(), CStyleEnum::D); + assert_eq!(de(6).unwrap(), CStyleEnum::E); + assert_eq!(de(7), expected_err(7)); +} diff --git a/tests/issues.rs b/tests/issues.rs new file mode 100644 index 0000000..b7fa4b3 --- /dev/null +++ b/tests/issues.rs @@ -0,0 +1,4 @@ +#![no_std] + +#[path = "issues/issue_427.rs"] +mod issue_427; diff --git a/tests/issues/issue_427.rs b/tests/issues/issue_427.rs new file mode 100644 index 0000000..5d63899 --- /dev/null +++ b/tests/issues/issue_427.rs @@ -0,0 +1,69 @@ +#![cfg(feature = "derive")] + +/// HID-IO Packet Buffer Struct +/// +/// # Remarks +/// Used to store HID-IO data chunks. Will be chunked into individual packets on transmission. +#[repr(C)] +#[derive(PartialEq, Clone, Debug, bincode::Encode)] +pub struct HidIoPacketBuffer { + /// Type of packet (Continued is automatically set if needed) + pub ptype: u32, + /// Packet Id + pub id: u32, + /// Packet length for serialization (in bytes) + pub max_len: u32, + /// Payload data, chunking is done automatically by serializer + pub data: [u8; H], + /// Set False if buffer is not complete, True if it is + pub done: bool, +} + +#[repr(u32)] +#[derive(PartialEq, Clone, Copy, Debug, bincode::Encode)] +#[allow(dead_code)] +/// Requests for to perform a specific action +pub enum HidIoCommandId { + SupportedIds = 0x00, + GetInfo = 0x01, + TestPacket = 0x02, + ResetHidIo = 0x03, + Reserved = 0x04, // ... 0x0F + + GetProperties = 0x10, + KeyState = 0x11, + KeyboardLayout = 0x12, + KeyLayout = 0x13, + KeyShapes = 0x14, + LedLayout = 0x15, + FlashMode = 0x16, + UnicodeText = 0x17, + UnicodeState = 0x18, + HostMacro = 0x19, + SleepMode = 0x1A, + + KllState = 0x20, + PixelSetting = 0x21, + PixelSet1c8b = 0x22, + PixelSet3c8b = 0x23, + PixelSet1c16b = 0x24, + PixelSet3c16b = 0x25, + + OpenUrl = 0x30, + TerminalCmd = 0x31, + GetInputLayout = 0x32, + SetInputLayout = 0x33, + TerminalOut = 0x34, + + HidKeyboard = 0x40, + HidKeyboardLed = 0x41, + HidMouse = 0x42, + HidJoystick = 0x43, + HidSystemCtrl = 0x44, + HidConsumerCtrl = 0x45, + + ManufacturingTest = 0x50, + ManufacturingResult = 0x51, + + Unused = 0xFFFF, +}