Fixes for 427 (#428)

* Made bincode_derive handle empty lines of docs correctly
* Made bincode_derive properly support const generics
* Added support for enums with variants with fixed values
This commit is contained in:
Trangar 2021-11-07 10:31:15 +01:00 committed by GitHub
parent 7174f6422d
commit b4c46a789a
15 changed files with 570 additions and 244 deletions

View File

@ -1,4 +1,4 @@
use crate::generate::{FnSelfArg, Generator};
use crate::generate::{FnSelfArg, Generator, StreamBuilder};
use crate::parse::{EnumVariant, Fields};
use crate::prelude::*;
use crate::Result;
@ -10,9 +10,15 @@ pub struct DeriveEnum {
}
impl DeriveEnum {
pub fn generate_encode(self, generator: &mut Generator) -> Result<()> {
let DeriveEnum { variants } = self;
fn iter_fields(&self) -> EnumVariantIterator {
EnumVariantIterator {
idx: 0,
last_val: None,
variants: &self.variants,
}
}
pub fn generate_encode(self, generator: &mut Generator) -> Result<()> {
generator
.impl_for("bincode::enc::Encode")
.unwrap()
@ -25,7 +31,7 @@ impl DeriveEnum {
fn_body.ident_str("match");
fn_body.ident_str("self");
fn_body.group(Delimiter::Brace, |match_body| {
for (variant_index, variant) in variants.into_iter().enumerate() {
for (variant_index, variant) in self.iter_fields() {
// Self::Variant
match_body.ident_str("Self");
match_body.puncts("::");
@ -62,11 +68,16 @@ impl DeriveEnum {
// }
match_body.group(Delimiter::Brace, |body| {
// variant index
body.push_parsed(format!(
"<u32 as bincode::enc::Encode>::encode(&{}, &mut encoder)?;",
variant_index
))
.unwrap();
body.push_parsed("<u32 as bincode::enc::Encode>::encode")
.unwrap();
body.group(Delimiter::Parenthesis, |args| {
args.punct('&');
args.group(Delimiter::Parenthesis, |num| num.extend(variant_index));
args.punct(',');
args.push_parsed("&mut encoder").unwrap();
});
body.punct('?');
body.punct(';');
// If we have any fields, encode them all one by one
for field_name in variant.fields.names() {
body.push_parsed(format!(
@ -85,8 +96,73 @@ impl DeriveEnum {
Ok(())
}
/// Build the catch-all case for an int-to-enum decode implementation
fn invalid_variant_case(&self, enum_name: &str, result: &mut StreamBuilder) {
// we'll be generating:
// variant => Err(
// bincode::error::DecodeError::UnexpectedVariant {
// found: variant,
// type_name: <enum_name>
// allowed: ...,
// }
// )
//
// Where allowed is either:
// - bincode::error::AllowedEnumVariants::Range { min: 0, max: <max> }
// if we have no fixed value variants
// - bincode::error::AllowedEnumVariants::Allowed(&[<variant1>, <variant2>, ...])
// if we have fixed value variants
result.ident_str("variant");
result.puncts("=>");
result.ident_str("Err");
result.group(Delimiter::Parenthesis, |err_inner| {
err_inner
.push_parsed("bincode::error::DecodeError::UnexpectedVariant")
.unwrap();
err_inner.group(Delimiter::Brace, |variant_inner| {
variant_inner.ident_str("found");
variant_inner.punct(':');
variant_inner.ident_str("variant");
variant_inner.punct(',');
variant_inner.ident_str("type_name");
variant_inner.punct(':');
variant_inner.lit_str(enum_name);
variant_inner.punct(',');
variant_inner.ident_str("allowed");
variant_inner.punct(':');
if self.variants.iter().any(|i| i.has_fixed_value()) {
// we have fixed values, implement AllowedEnumVariants::Allowed
variant_inner
.push_parsed("bincode::error::AllowedEnumVariants::Allowed")
.unwrap();
variant_inner.group(Delimiter::Parenthesis, |allowed_inner| {
allowed_inner.punct('&');
allowed_inner.group(Delimiter::Bracket, |allowed_slice| {
for (idx, (ident, _)) in self.iter_fields().enumerate() {
if idx != 0 {
allowed_slice.punct(',');
}
allowed_slice.extend(ident);
}
});
});
} else {
// no fixed values, implement a range
variant_inner
.push_parsed(format!(
"bincode::error::AllowedEnumVariants::Range {{ min: 0, max: {} }}",
self.variants.len() - 1
))
.unwrap();
}
})
});
}
pub fn generate_decode(self, generator: &mut Generator) -> Result<()> {
let DeriveEnum { variants } = self;
let enum_name = generator.target_name().to_string();
if generator.has_lifetimes() {
@ -103,43 +179,44 @@ impl DeriveEnum {
.push_parsed("let variant_index = <u32 as bincode::de::Decode>::decode(&mut decoder)?;").unwrap();
fn_builder.push_parsed("match variant_index").unwrap();
fn_builder.group(Delimiter::Brace, |variant_case| {
for (idx, variant) in variants.iter().enumerate() {
// idx => Ok(..)
variant_case.lit_u32(idx as u32);
variant_case.puncts("=>");
variant_case.ident_str("Ok");
variant_case.group(Delimiter::Parenthesis, |variant_case_body| {
// Self::Variant { }
// Self::Variant { 0: ..., 1: ... 2: ... },
// Self::Variant { a: ..., b: ... c: ... },
variant_case_body.ident_str("Self");
variant_case_body.puncts("::");
variant_case_body.ident(variant.name.clone());
for (mut variant_index, variant) in self.iter_fields() {
// idx => Ok(..)
if variant_index.len() > 1 {
variant_case.push_parsed("x if x == ").unwrap();
variant_case.extend(variant_index);
} else {
variant_case.push(variant_index.remove(0));
}
variant_case.puncts("=>");
variant_case.ident_str("Ok");
variant_case.group(Delimiter::Parenthesis, |variant_case_body| {
// Self::Variant { }
// Self::Variant { 0: ..., 1: ... 2: ... },
// Self::Variant { a: ..., b: ... c: ... },
variant_case_body.ident_str("Self");
variant_case_body.puncts("::");
variant_case_body.ident(variant.name.clone());
variant_case_body.group(Delimiter::Brace, |variant_body| {
let is_tuple = matches!(variant.fields, Fields::Tuple(_));
for (idx, field) in variant.fields.names().into_iter().enumerate() {
if is_tuple {
variant_body.lit_usize(idx);
} else {
variant_body.ident(field.unwrap_ident().clone());
variant_case_body.group(Delimiter::Brace, |variant_body| {
let is_tuple = matches!(variant.fields, Fields::Tuple(_));
for (idx, field) in variant.fields.names().into_iter().enumerate() {
if is_tuple {
variant_body.lit_usize(idx);
} else {
variant_body.ident(field.unwrap_ident().clone());
}
variant_body.punct(':');
variant_body.push_parsed("bincode::de::BorrowDecode::borrow_decode(&mut decoder)?,").unwrap();
}
variant_body.punct(':');
variant_body.push_parsed("bincode::de::BorrowDecode::borrow_decode(&mut decoder)?,").unwrap();
}
});
});
});
variant_case.punct(',');
}
variant_case.punct(',');
}
// invalid idx
variant_case.push_parsed(format!(
"variant => return Err(bincode::error::DecodeError::UnexpectedVariant {{ min: 0, max: {}, found: variant, type_name: {:?} }})",
variants.len() - 1,
enum_name.to_string()
)).unwrap();
});
}).unwrap();
// invalid idx
self.invalid_variant_case(&enum_name, variant_case);
});
}).unwrap();
} else {
// enum has no lifetimes, implement Decode
generator.impl_for("bincode::de::Decode")
@ -153,45 +230,79 @@ impl DeriveEnum {
.push_parsed("let variant_index = <u32 as bincode::de::Decode>::decode(&mut decoder)?;").unwrap();
fn_builder.push_parsed("match variant_index").unwrap();
fn_builder.group(Delimiter::Brace, |variant_case| {
for (idx, variant) in variants.iter().enumerate() {
// idx => Ok(..)
variant_case.lit_u32(idx as u32);
variant_case.puncts("=>");
variant_case.ident_str("Ok");
variant_case.group(Delimiter::Parenthesis, |variant_case_body| {
// Self::Variant { }
// Self::Variant { 0: ..., 1: ... 2: ... },
// Self::Variant { a: ..., b: ... c: ... },
variant_case_body.ident_str("Self");
variant_case_body.puncts("::");
variant_case_body.ident(variant.name.clone());
for (mut variant_index, variant) in self.iter_fields() {
// idx => Ok(..)
if variant_index.len() > 1 {
variant_case.push_parsed("x if x == ").unwrap();
variant_case.extend(variant_index);
} else {
variant_case.push(variant_index.remove(0));
}
variant_case.puncts("=>");
variant_case.ident_str("Ok");
variant_case.group(Delimiter::Parenthesis, |variant_case_body| {
// Self::Variant { }
// Self::Variant { 0: ..., 1: ... 2: ... },
// Self::Variant { a: ..., b: ... c: ... },
variant_case_body.ident_str("Self");
variant_case_body.puncts("::");
variant_case_body.ident(variant.name.clone());
variant_case_body.group(Delimiter::Brace, |variant_body| {
let is_tuple = matches!(variant.fields, Fields::Tuple(_));
for (idx, field) in variant.fields.names().into_iter().enumerate() {
if is_tuple {
variant_body.lit_usize(idx);
} else {
variant_body.ident(field.unwrap_ident().clone());
variant_case_body.group(Delimiter::Brace, |variant_body| {
let is_tuple = matches!(variant.fields, Fields::Tuple(_));
for (idx, field) in variant.fields.names().into_iter().enumerate() {
if is_tuple {
variant_body.lit_usize(idx);
} else {
variant_body.ident(field.unwrap_ident().clone());
}
variant_body.punct(':');
variant_body.push_parsed("bincode::de::Decode::decode(&mut decoder)?,").unwrap();
}
variant_body.punct(':');
variant_body.push_parsed("bincode::de::Decode::decode(&mut decoder)?,").unwrap();
}
});
});
});
variant_case.punct(',');
}
variant_case.punct(',');
}
// invalid idx
variant_case.push_parsed(format!(
"variant => return Err(bincode::error::DecodeError::UnexpectedVariant {{ min: 0, max: {}, found: variant, type_name: {:?} }})",
variants.len() - 1,
enum_name.to_string()
)).unwrap();
});
}).unwrap();
// invalid idx
self.invalid_variant_case(&enum_name, variant_case);
});
}).unwrap();
}
Ok(())
}
}
struct EnumVariantIterator<'a> {
variants: &'a [EnumVariant],
idx: usize,
last_val: Option<(Literal, u32)>,
}
impl<'a> Iterator for EnumVariantIterator<'a> {
type Item = (Vec<TokenTree>, &'a EnumVariant);
fn next(&mut self) -> Option<Self::Item> {
let idx = self.idx;
let variant = self.variants.get(self.idx)?;
self.idx += 1;
let tokens = if let Fields::Integer(lit) = &variant.fields {
let tree = TokenTree::Literal(lit.clone());
self.last_val = Some((lit.clone(), 0));
vec![tree]
} else if let Some((lit, add)) = self.last_val.as_mut() {
*add += 1;
vec![
TokenTree::Literal(lit.clone()),
TokenTree::Punct(Punct::new('+', Spacing::Alone)),
TokenTree::Literal(Literal::u32_suffixed(*add)),
]
} else {
vec![TokenTree::Literal(Literal::u32_suffixed(idx as u32))]
};
Some((tokens, variant))
}
}

View File

@ -4,10 +4,19 @@ use std::fmt;
#[derive(Debug)]
pub enum Error {
UnknownDataType(Span),
InvalidRustSyntax(Span),
InvalidRustSyntax { span: Span, expected: String },
ExpectedIdent(Span),
}
impl Error {
pub fn wrong_token<T>(token: Option<&TokenTree>, expected: &'static str) -> Result<T, Self> {
Err(Self::InvalidRustSyntax {
span: token.map(|t| t.span()).unwrap_or_else(Span::call_site),
expected: format!("{}, got {:?}", expected, token),
})
}
}
// helper functions for the unit tests
#[cfg(test)]
impl Error {
@ -16,7 +25,7 @@ impl Error {
}
pub fn is_invalid_rust_syntax(&self) -> bool {
matches!(self, Error::InvalidRustSyntax(_))
matches!(self, Error::InvalidRustSyntax { .. })
}
}
@ -26,7 +35,9 @@ impl fmt::Display for Error {
Self::UnknownDataType(_) => {
write!(fmt, "Unknown data type, only enum and struct are supported")
}
Self::InvalidRustSyntax(_) => write!(fmt, "Invalid rust syntax"),
Self::InvalidRustSyntax { expected, .. } => {
write!(fmt, "Invalid rust syntax, expected {}", expected)
}
Self::ExpectedIdent(_) => write!(fmt, "Expected ident"),
}
}
@ -37,7 +48,7 @@ impl Error {
let maybe_span = match &self {
Error::UnknownDataType(span)
| Error::ExpectedIdent(span)
| Error::InvalidRustSyntax(span) => Some(*span),
| Error::InvalidRustSyntax { span, .. } => Some(*span),
};
self.throw_with_span(maybe_span.unwrap_or_else(Span::call_site))
}

View File

@ -125,12 +125,6 @@ impl StreamBuilder {
.extend([TokenTree::Literal(Literal::string(str.as_ref()))]);
}
/// Add an `u32` value to the stream.
pub fn lit_u32(&mut self, val: u32) {
self.stream
.extend([TokenTree::Literal(Literal::u32_unsuffixed(val))]);
}
/// Add an `usize` value to the stream.
pub fn lit_usize(&mut self, val: usize) {
self.stream

View File

@ -31,7 +31,7 @@ pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream
fn derive_encode_inner(input: TokenStream) -> Result<TokenStream> {
let source = &mut input.into_iter().peekable();
let _attributes = parse::Attributes::try_take(source)?;
let _attributes = parse::Attribute::try_take(source)?;
let _visibility = parse::Visibility::try_take(source)?;
let (datatype, name) = parse::DataType::take(source)?;
let generics = parse::Generics::try_take(source)?;
@ -72,7 +72,7 @@ pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream
fn derive_decode_inner(input: TokenStream) -> Result<TokenStream> {
let source = &mut input.into_iter().peekable();
let _attributes = parse::Attributes::try_take(source)?;
let _attributes = parse::Attribute::try_take(source)?;
let _visibility = parse::Visibility::try_take(source)?;
let (datatype, name) = parse::DataType::take(source)?;
let generics = parse::Generics::try_take(source)?;

View File

@ -1,38 +1,48 @@
use super::assume_group;
use super::{assume_group, assume_punct};
use crate::parse::consume_punct_if;
use crate::prelude::{Delimiter, Group, Punct, TokenTree};
use crate::{Error, Result};
use std::iter::Peekable;
#[derive(Debug)]
pub struct Attributes {
pub struct Attribute {
// we don't use these fields yet
#[allow(dead_code)]
punct: Punct,
#[allow(dead_code)]
tokens: Group,
tokens: Option<Group>,
}
impl Attributes {
pub fn try_take(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Option<Self>> {
if let Some(punct) = consume_punct_if(input, '#') {
// found attributes, next token should be a [] group
if let Some(TokenTree::Group(g)) = input.peek() {
if g.delimiter() != Delimiter::Bracket {
return Err(Error::InvalidRustSyntax(g.span()));
impl Attribute {
pub fn try_take(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Vec<Self>> {
let mut result = Vec::new();
while let Some(punct) = consume_punct_if(input, '#') {
match input.peek() {
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
result.push(Attribute {
punct,
tokens: Some(assume_group(input.next())),
});
}
return Ok(Some(Attributes {
punct,
tokens: assume_group(input.next()),
}));
Some(TokenTree::Group(g)) => {
return Err(Error::InvalidRustSyntax {
span: g.span(),
expected: format!("[] bracket, got {:?}", g.delimiter()),
});
}
Some(TokenTree::Punct(p)) if p.as_char() == '#' => {
// sometimes with empty lines of doc comments, we get two #'s in a row
// add an empty attributes and continue to the next loop
result.push(Attribute {
punct: assume_punct(input.next(), '#'),
tokens: None,
})
}
token => return Error::wrong_token(token, "[] group or next # attribute"),
}
// expected [] group, found something else
return Err(Error::InvalidRustSyntax(match input.peek() {
Some(next_token) => next_token.span(),
None => punct.span(),
}));
}
Ok(None)
Ok(result)
}
}
@ -41,14 +51,14 @@ fn test_attributes_try_take() {
use crate::token_stream;
let stream = &mut token_stream("struct Foo;");
assert!(Attributes::try_take(stream).unwrap().is_none());
assert!(Attribute::try_take(stream).unwrap().is_empty());
match stream.next().unwrap() {
TokenTree::Ident(i) => assert_eq!(i, "struct"),
x => panic!("Expected ident, found {:?}", x),
}
let stream = &mut token_stream("#[cfg(test)] struct Foo;");
assert!(Attributes::try_take(stream).unwrap().is_some());
assert!(!Attribute::try_take(stream).unwrap().is_empty());
match stream.next().unwrap() {
TokenTree::Ident(i) => assert_eq!(i, "struct"),
x => panic!("Expected ident, found {:?}", x),

View File

@ -1,6 +1,8 @@
use super::{assume_group, assume_ident, read_tokens_until_punct, Attributes, Visibility};
use super::{
assume_group, assume_ident, assume_punct, read_tokens_until_punct, Attribute, Visibility,
};
use crate::parse::consume_punct_if;
use crate::prelude::{Delimiter, Ident, Span, TokenTree};
use crate::prelude::{Delimiter, Ident, Literal, Span, TokenTree};
use crate::{Error, Result};
use std::iter::Peekable;
@ -18,19 +20,19 @@ impl StructBody {
fields: Fields::Unit,
})
}
Some(t) => {
return Err(Error::InvalidRustSyntax(t.span()));
}
_ => {
return Err(Error::InvalidRustSyntax(Span::call_site()));
}
token => return Error::wrong_token(token, "group or punct"),
}
let group = assume_group(input.next());
let mut stream = group.stream().into_iter().peekable();
let fields = match group.delimiter() {
Delimiter::Brace => Fields::Struct(UnnamedField::parse_with_name(&mut stream)?),
Delimiter::Parenthesis => Fields::Tuple(UnnamedField::parse(&mut stream)?),
_ => return Err(Error::InvalidRustSyntax(group.span())),
found => {
return Err(Error::InvalidRustSyntax {
span: group.span(),
expected: format!("brace or parenthesis, found {:?}", found),
})
}
};
Ok(StructBody { fields })
}
@ -124,37 +126,57 @@ impl EnumBody {
variants: Vec::new(),
})
}
Some(t) => {
return Err(Error::InvalidRustSyntax(t.span()));
}
_ => {
return Err(Error::InvalidRustSyntax(Span::call_site()));
}
token => return Error::wrong_token(token, "group or ;"),
}
let group = assume_group(input.next());
let mut variants = Vec::new();
let stream = &mut group.stream().into_iter().peekable();
while stream.peek().is_some() {
let attributes = Attributes::try_take(stream)?;
let attributes = Attribute::try_take(stream)?;
let ident = match stream.peek() {
Some(TokenTree::Ident(_)) => assume_ident(stream.next()),
Some(x) => return Err(Error::InvalidRustSyntax(x.span())),
None => return Err(Error::InvalidRustSyntax(Span::call_site())),
token => return Error::wrong_token(token, "ident"),
};
let mut fields = Fields::Unit;
if let Some(TokenTree::Group(_)) = stream.peek() {
let group = assume_group(stream.next());
let stream = &mut group.stream().into_iter().peekable();
match group.delimiter() {
Delimiter::Brace => {
fields = Fields::Struct(UnnamedField::parse_with_name(stream)?)
match stream.peek() {
Some(TokenTree::Group(_)) => {
let group = assume_group(stream.next());
let stream = &mut group.stream().into_iter().peekable();
match group.delimiter() {
Delimiter::Brace => {
fields = Fields::Struct(UnnamedField::parse_with_name(stream)?)
}
Delimiter::Parenthesis => {
fields = Fields::Tuple(UnnamedField::parse(stream)?)
}
delim => {
return Err(Error::InvalidRustSyntax {
span: group.span(),
expected: format!("Brace or parenthesis, found {:?}", delim),
})
}
}
Delimiter::Parenthesis => fields = Fields::Tuple(UnnamedField::parse(stream)?),
_ => return Err(Error::InvalidRustSyntax(group.span())),
}
Some(TokenTree::Punct(p)) if p.as_char() == '=' => {
assume_punct(stream.next(), '=');
match stream.next() {
Some(TokenTree::Literal(lit)) => {
fields = Fields::Integer(lit);
}
token => return Error::wrong_token(token.as_ref(), "literal"),
}
}
Some(TokenTree::Punct(p)) if p.as_char() == ',' => {
// next field
}
None => {
// group done
}
token => return Error::wrong_token(token, "group, comma or ="),
}
consume_punct_if(stream, ',');
variants.push(EnumVariant {
@ -209,7 +231,13 @@ fn test_enum_body_take() {
pub struct EnumVariant {
pub name: Ident,
pub fields: Fields,
pub attributes: Option<Attributes>,
pub attributes: Vec<Attribute>,
}
impl EnumVariant {
pub fn has_fixed_value(&self) -> bool {
matches!(&self.fields, Fields::Integer(_))
}
}
#[derive(Debug)]
@ -223,6 +251,14 @@ pub enum Fields {
/// ```
Unit,
/// Variant with an integer value.
/// ```rs
/// enum Foo {
/// Baz = 5,
/// }
/// ```
Integer(Literal),
/// Tuple-like variant
/// ```rs
/// enum Foo {
@ -258,7 +294,7 @@ impl Fields {
.iter()
.map(|(ident, _)| IdentOrIndex::Ident(ident))
.collect(),
Self::Unit => Vec::new(),
Self::Unit | Self::Integer(_) => Vec::new(),
}
}
@ -266,7 +302,7 @@ impl Fields {
match self {
Self::Tuple(_) => Some(Delimiter::Parenthesis),
Self::Struct(_) => Some(Delimiter::Brace),
Self::Unit => None,
Self::Unit | Self::Integer(_) => None,
}
}
}
@ -282,6 +318,7 @@ impl Fields {
Self::Tuple(fields) => fields.len(),
Self::Struct(fields) => fields.len(),
Self::Unit => 0,
Self::Integer(_) => 0,
}
}
@ -290,6 +327,7 @@ impl Fields {
Self::Tuple(fields) => fields.get(index).map(|f| (None, f)),
Self::Struct(fields) => fields.get(index).map(|(ident, field)| (Some(ident), field)),
Self::Unit => None,
Self::Integer(_) => None,
}
}
}
@ -298,7 +336,7 @@ impl Fields {
pub struct UnnamedField {
pub vis: Visibility,
pub r#type: Vec<TokenTree>,
pub attributes: Option<Attributes>,
pub attributes: Vec<Attribute>,
}
impl UnnamedField {
@ -307,20 +345,24 @@ impl UnnamedField {
) -> Result<Vec<(Ident, Self)>> {
let mut result = Vec::new();
loop {
let attributes = Attributes::try_take(input)?;
let attributes = Attribute::try_take(input)?;
let vis = Visibility::try_take(input)?;
let ident = match input.peek() {
Some(TokenTree::Ident(_)) => assume_ident(input.next()),
Some(x) => return Err(Error::InvalidRustSyntax(x.span())),
Some(x) => {
return Err(Error::InvalidRustSyntax {
span: x.span(),
expected: format!("ident or end of group, got {:?}", x),
})
}
None => break,
};
match input.peek() {
Some(TokenTree::Punct(p)) if p.as_char() == ':' => {
input.next();
}
Some(x) => return Err(Error::InvalidRustSyntax(x.span())),
None => return Err(Error::InvalidRustSyntax(Span::call_site())),
token => return Error::wrong_token(token, ":"),
}
let r#type = read_tokens_until_punct(input, &[','])?;
consume_punct_if(input, ',');
@ -339,7 +381,7 @@ impl UnnamedField {
pub fn parse(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Vec<Self>> {
let mut result = Vec::new();
while input.peek().is_some() {
let attributes = Attributes::try_take(input)?;
let attributes = Attribute::try_take(input)?;
let vis = Visibility::try_take(input)?;
let r#type = read_tokens_until_punct(input, &[','])?;

View File

@ -1,4 +1,4 @@
use crate::prelude::{Ident, Span, TokenTree};
use crate::prelude::{Ident, TokenTree};
use crate::{Error, Result};
use std::iter::Peekable;
@ -10,24 +10,19 @@ pub enum DataType {
impl DataType {
pub fn take(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<(Self, Ident)> {
if let Some(TokenTree::Ident(ident)) = input.peek() {
if let Some(TokenTree::Ident(_)) = input.peek() {
let ident = super::assume_ident(input.next());
let result = match ident.to_string().as_str() {
"struct" => DataType::Struct,
"enum" => DataType::Enum,
_ => return Err(Error::UnknownDataType(ident.span())),
};
let ident = super::assume_ident(input.next());
return match input.next() {
Some(TokenTree::Ident(ident)) => Ok((result, ident)),
Some(t) => Err(Error::InvalidRustSyntax(t.span())),
None => Err(Error::InvalidRustSyntax(ident.span())),
token => Error::wrong_token(token.as_ref(), "ident"),
};
}
let span = input
.peek()
.map(|t| t.span())
.unwrap_or_else(Span::call_site);
Err(Error::InvalidRustSyntax(span))
Error::wrong_token(input.peek(), "ident")
}
}

View File

@ -7,7 +7,7 @@ use std::iter::Peekable;
#[derive(Debug)]
pub struct Generics {
lifetimes_and_generics: Vec<LifetimeOrGeneric>,
generics: Vec<Generic>,
}
impl Generics {
@ -17,30 +17,31 @@ impl Generics {
if punct.as_char() == '<' {
let punct = super::assume_punct(input.next(), '<');
let mut result = Generics {
lifetimes_and_generics: Vec::new(),
generics: Vec::new(),
};
loop {
match input.peek() {
Some(TokenTree::Punct(punct)) if punct.as_char() == '\'' => {
result
.lifetimes_and_generics
.push(Lifetime::take(input)?.into());
result.generics.push(Lifetime::take(input)?.into());
super::consume_punct_if(input, ',');
}
Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => {
assume_punct(input.next(), '>');
break;
}
Some(TokenTree::Ident(ident)) if ident_eq(ident, "const") => {
result.generics.push(ConstGeneric::take(input)?.into());
super::consume_punct_if(input, ',');
}
Some(TokenTree::Ident(_)) => {
result
.lifetimes_and_generics
.push(Generic::take(input)?.into());
result.generics.push(SimpleGeneric::take(input)?.into());
super::consume_punct_if(input, ',');
}
x => {
return Err(Error::InvalidRustSyntax(
x.map(|x| x.span()).unwrap_or_else(|| punct.span()),
));
return Err(Error::InvalidRustSyntax {
span: x.map(|x| x.span()).unwrap_or_else(|| punct.span()),
expected: format!("', > or an ident, got {:?}", x),
});
}
}
}
@ -51,30 +52,19 @@ impl Generics {
}
pub fn has_lifetime(&self) -> bool {
self.lifetimes_and_generics
.iter()
.any(|lt| lt.is_lifetime())
self.generics.iter().any(|lt| lt.is_lifetime())
}
pub fn impl_generics(&self) -> StreamBuilder {
let mut result = StreamBuilder::new();
result.punct('<');
for (idx, generic) in self.lifetimes_and_generics.iter().enumerate() {
for (idx, generic) in self.generics.iter().enumerate() {
if idx > 0 {
result.punct(',');
}
if generic.is_lifetime() {
result.lifetime(generic.ident());
} else {
result.ident(generic.ident());
}
if generic.has_constraints() {
result.punct(':');
result.extend(generic.constraints());
}
generic.append_to_result_with_constraints(&mut result);
}
result.punct('>');
@ -91,7 +81,7 @@ impl Generics {
if self.has_lifetime() {
for (idx, lt) in self
.lifetimes_and_generics
.generics
.iter()
.filter_map(|lt| lt.as_lifetime())
.enumerate()
@ -101,19 +91,9 @@ impl Generics {
}
}
for generic in &self.lifetimes_and_generics {
for generic in &self.generics {
result.punct(',');
if generic.is_lifetime() {
result.lifetime(generic.ident());
} else {
result.ident(generic.ident());
}
if generic.has_constraints() {
result.punct(':');
result.extend(generic.constraints());
}
generic.append_to_result_with_constraints(&mut result);
}
result.punct('>');
@ -125,7 +105,7 @@ impl Generics {
let mut result = StreamBuilder::new();
result.punct('<');
for (idx, generic) in self.lifetimes_and_generics.iter().enumerate() {
for (idx, generic) in self.generics.iter().enumerate() {
if idx > 0 {
result.punct(',');
}
@ -142,27 +122,29 @@ impl Generics {
}
#[derive(Debug)]
enum LifetimeOrGeneric {
enum Generic {
Lifetime(Lifetime),
Generic(Generic),
Generic(SimpleGeneric),
Const(ConstGeneric),
}
impl LifetimeOrGeneric {
impl Generic {
fn is_lifetime(&self) -> bool {
matches!(self, LifetimeOrGeneric::Lifetime(_))
matches!(self, Generic::Lifetime(_))
}
fn ident(&self) -> Ident {
match self {
Self::Lifetime(lt) => lt.ident.clone(),
Self::Generic(gen) => gen.ident.clone(),
Self::Const(gen) => gen.ident.clone(),
}
}
fn as_lifetime(&self) -> Option<&Lifetime> {
match self {
Self::Lifetime(lt) => Some(lt),
Self::Generic(_) => None,
_ => None,
}
}
@ -170,6 +152,7 @@ impl LifetimeOrGeneric {
match self {
Self::Lifetime(lt) => !lt.constraint.is_empty(),
Self::Generic(gen) => !gen.constraints.is_empty(),
Self::Const(_) => true, // const generics always have a constraint
}
}
@ -177,22 +160,46 @@ impl LifetimeOrGeneric {
match self {
Self::Lifetime(lt) => lt.constraint.clone(),
Self::Generic(gen) => gen.constraints.clone(),
Self::Const(gen) => gen.constraints.clone(),
}
}
fn append_to_result_with_constraints(&self, builder: &mut StreamBuilder) {
match self {
Self::Lifetime(lt) => builder.lifetime(lt.ident.clone()),
Self::Generic(gen) => {
builder.ident(gen.ident.clone());
}
Self::Const(gen) => {
builder.ident(gen.const_token.clone());
builder.ident(gen.ident.clone());
}
}
if self.has_constraints() {
builder.punct(':');
builder.extend(self.constraints());
}
}
}
impl From<Lifetime> for LifetimeOrGeneric {
impl From<Lifetime> for Generic {
fn from(lt: Lifetime) -> Self {
Self::Lifetime(lt)
}
}
impl From<Generic> for LifetimeOrGeneric {
fn from(gen: Generic) -> Self {
impl From<SimpleGeneric> for Generic {
fn from(gen: SimpleGeneric) -> Self {
Self::Generic(gen)
}
}
impl From<ConstGeneric> for Generic {
fn from(gen: ConstGeneric) -> Self {
Self::Const(gen)
}
}
#[test]
fn test_generics_try_take() {
use crate::token_stream;
@ -210,18 +217,18 @@ fn test_generics_try_take() {
assert_eq!(data_type, super::DataType::Struct);
assert_eq!(ident, "Foo");
let generics = Generics::try_take(stream).unwrap().unwrap();
assert_eq!(generics.lifetimes_and_generics.len(), 2);
assert_eq!(generics.lifetimes_and_generics[0].ident(), "a");
assert_eq!(generics.lifetimes_and_generics[1].ident(), "T");
assert_eq!(generics.generics.len(), 2);
assert_eq!(generics.generics[0].ident(), "a");
assert_eq!(generics.generics[1].ident(), "T");
let stream = &mut token_stream("struct Foo<A, B>()");
let (data_type, ident) = super::DataType::take(stream).unwrap();
assert_eq!(data_type, super::DataType::Struct);
assert_eq!(ident, "Foo");
let generics = Generics::try_take(stream).unwrap().unwrap();
assert_eq!(generics.lifetimes_and_generics.len(), 2);
assert_eq!(generics.lifetimes_and_generics[0].ident(), "A");
assert_eq!(generics.lifetimes_and_generics[1].ident(), "B");
assert_eq!(generics.generics.len(), 2);
assert_eq!(generics.generics[0].ident(), "A");
assert_eq!(generics.generics[1].ident(), "B");
let stream = &mut token_stream("struct Foo<'a, T: Display>()");
let (data_type, ident) = super::DataType::take(stream).unwrap();
@ -229,18 +236,18 @@ fn test_generics_try_take() {
assert_eq!(ident, "Foo");
let generics = Generics::try_take(stream).unwrap().unwrap();
dbg!(&generics);
assert_eq!(generics.lifetimes_and_generics.len(), 2);
assert_eq!(generics.lifetimes_and_generics[0].ident(), "a");
assert_eq!(generics.lifetimes_and_generics[1].ident(), "T");
assert_eq!(generics.generics.len(), 2);
assert_eq!(generics.generics[0].ident(), "a");
assert_eq!(generics.generics[1].ident(), "T");
let stream = &mut token_stream("struct Foo<'a, T: for<'a> Bar<'a> + 'static>()");
let (data_type, ident) = super::DataType::take(stream).unwrap();
assert_eq!(data_type, super::DataType::Struct);
assert_eq!(ident, "Foo");
dbg!(&generics);
assert_eq!(generics.lifetimes_and_generics.len(), 2);
assert_eq!(generics.lifetimes_and_generics[0].ident(), "a");
assert_eq!(generics.lifetimes_and_generics[1].ident(), "T");
assert_eq!(generics.generics.len(), 2);
assert_eq!(generics.generics[0].ident(), "a");
assert_eq!(generics.generics[1].ident(), "T");
let stream = &mut token_stream(
"struct Baz<T: for<'a> Bar<'a, for<'b> Bar<'b, for<'c> Bar<'c, u32>>>> {}",
@ -250,8 +257,8 @@ fn test_generics_try_take() {
assert_eq!(ident, "Baz");
let generics = Generics::try_take(stream).unwrap().unwrap();
dbg!(&generics);
assert_eq!(generics.lifetimes_and_generics.len(), 1);
assert_eq!(generics.lifetimes_and_generics[0].ident(), "T");
assert_eq!(generics.generics.len(), 1);
assert_eq!(generics.generics[0].ident(), "T");
let stream = &mut token_stream("struct Baz<()> {}");
let (data_type, ident) = super::DataType::take(stream).unwrap();
@ -267,9 +274,9 @@ fn test_generics_try_take() {
assert_eq!(ident, "Bar");
let generics = Generics::try_take(stream).unwrap().unwrap();
dbg!(&generics);
assert_eq!(generics.lifetimes_and_generics.len(), 2);
assert_eq!(generics.lifetimes_and_generics[0].ident(), "A");
assert_eq!(generics.lifetimes_and_generics[1].ident(), "B");
assert_eq!(generics.generics.len(), 2);
assert_eq!(generics.generics[0].ident(), "A");
assert_eq!(generics.generics[1].ident(), "B");
}
#[derive(Debug)]
@ -325,12 +332,12 @@ fn test_lifetime_take() {
}
#[derive(Debug)]
pub struct Generic {
pub struct SimpleGeneric {
ident: Ident,
constraints: Vec<TokenTree>,
}
impl Generic {
impl SimpleGeneric {
pub fn take(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Self> {
let ident = super::assume_ident(input.next());
let mut constraints = Vec::new();
@ -340,7 +347,33 @@ impl Generic {
constraints = super::read_tokens_until_punct(input, &['>', ','])?;
}
}
Ok(Generic { ident, constraints })
Ok(Self { ident, constraints })
}
}
#[derive(Debug)]
pub struct ConstGeneric {
const_token: Ident,
ident: Ident,
constraints: Vec<TokenTree>,
}
impl ConstGeneric {
pub fn take(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Self> {
let const_token = super::assume_ident(input.next());
let ident = super::assume_ident(input.next());
let mut constraints = Vec::new();
if let Some(TokenTree::Punct(punct)) = input.peek() {
if punct.as_char() == ':' {
super::assume_punct(input.next(), ':');
constraints = super::read_tokens_until_punct(input, &['>', ','])?;
}
}
Ok(Self {
const_token,
ident,
constraints,
})
}
}
@ -410,8 +443,8 @@ fn test_generic_constraints_try_take() {
assert_eq!(data_type, DataType::Struct);
assert_eq!(ident, "Test");
let constraints = Generics::try_take(stream).unwrap().unwrap();
assert_eq!(constraints.lifetimes_and_generics.len(), 1);
assert_eq!(constraints.lifetimes_and_generics[0].ident(), "T");
assert_eq!(constraints.generics.len(), 1);
assert_eq!(constraints.generics[0].ident(), "T");
let body = StructBody::take(stream).unwrap();
assert_eq!(body.fields.len(), 0);
}

View File

@ -8,10 +8,10 @@ mod data_type;
mod generics;
mod visibility;
pub use self::attributes::Attributes;
pub use self::attributes::Attribute;
pub use self::body::{EnumBody, EnumVariant, Fields, StructBody, UnnamedField};
pub use self::data_type::DataType;
pub use self::generics::{Generic, GenericConstraints, Generics, Lifetime};
pub use self::generics::{GenericConstraints, Generics, Lifetime, SimpleGeneric};
pub use self::visibility::Visibility;
pub(self) fn assume_group(t: Option<TokenTree>) -> Group {
@ -103,7 +103,14 @@ pub(self) fn read_tokens_until_punct(
if expected_puncts.contains(&punct.as_char()) {
break;
}
return Err(Error::InvalidRustSyntax(punct.span()));
return Err(Error::InvalidRustSyntax {
span: punct.span(),
expected: format!(
"one of {:?}, got '{}'",
expected_puncts,
punct.as_char()
),
});
}
};
let expected = OPEN_BRACKETS[index];

View File

@ -445,8 +445,7 @@ where
}
x => Err(DecodeError::UnexpectedVariant {
found: x as u32,
max: 1,
min: 0,
allowed: crate::error::AllowedEnumVariants::Range { max: 1, min: 0 },
type_name: core::any::type_name::<Option<T>>(),
}),
}
@ -471,8 +470,7 @@ where
}
x => Err(DecodeError::UnexpectedVariant {
found: x as u32,
max: 1,
min: 0,
allowed: crate::error::AllowedEnumVariants::Range { max: 1, min: 0 },
type_name: core::any::type_name::<Result<T, U>>(),
}),
}
@ -539,8 +537,7 @@ where
1 => Ok(Bound::Included(T::decode(decoder)?)),
2 => Ok(Bound::Excluded(T::decode(decoder)?)),
x => Err(DecodeError::UnexpectedVariant {
min: 0,
max: 2,
allowed: crate::error::AllowedEnumVariants::Range { max: 2, min: 0 },
found: x,
type_name: core::any::type_name::<Bound<T>>(),
}),

View File

@ -50,7 +50,7 @@ pub enum EncodeError {
/// Errors that can be encounted by decoding a type
#[non_exhaustive]
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub enum DecodeError {
/// The reader reached its end but more bytes were expected.
UnexpectedEnd,
@ -74,11 +74,8 @@ pub enum DecodeError {
/// The type name that was being decoded.
type_name: &'static str,
/// The min index of the enum. Usually this is `0`.
min: u32,
/// the max index of the enum.
max: u32,
/// The variants that are allowed
allowed: AllowedEnumVariants,
/// The index of the enum that the decoder encountered
found: u32,
@ -126,9 +123,20 @@ impl DecodeError {
}
}
/// Indicates which enum variants are allowed
#[non_exhaustive]
#[derive(Debug, PartialEq)]
pub enum AllowedEnumVariants {
/// All values between `min` and `max` (inclusive) are allowed
#[allow(missing_docs)]
Range { min: u32, max: u32 },
/// Each one of these values is allowed
Allowed(&'static [u32]),
}
/// Integer types. Used by [DecodeError]. These types have no purpose other than being shown in errors.
#[non_exhaustive]
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq)]
#[allow(missing_docs)]
pub enum IntegerType {
U8,

View File

@ -252,8 +252,7 @@ impl Decode for IpAddr {
0 => Ok(IpAddr::V4(Ipv4Addr::decode(decoder)?)),
1 => Ok(IpAddr::V6(Ipv6Addr::decode(decoder)?)),
found => Err(DecodeError::UnexpectedVariant {
min: 0,
max: 1,
allowed: crate::error::AllowedEnumVariants::Range { min: 0, max: 1 },
found,
type_name: core::any::type_name::<IpAddr>(),
}),
@ -306,8 +305,7 @@ impl Decode for SocketAddr {
0 => Ok(SocketAddr::V4(SocketAddrV4::decode(decoder)?)),
1 => Ok(SocketAddr::V6(SocketAddrV6::decode(decoder)?)),
found => Err(DecodeError::UnexpectedVariant {
min: 0,
max: 1,
allowed: crate::error::AllowedEnumVariants::Range { min: 0, max: 1 },
found,
type_name: core::any::type_name::<SocketAddr>(),
}),

View File

@ -144,3 +144,50 @@ fn test_decode_enum_tuple_variant() {
bincode::decode_from_slice(&mut slice, Configuration::standard()).unwrap();
assert_eq!(result, start);
}
#[derive(bincode::Decode, bincode::Encode, PartialEq, Eq, Debug)]
enum CStyleEnum {
A = 1,
B = 2,
C,
D = 5,
E,
}
#[test]
fn test_c_style_enum() {
fn ser(e: CStyleEnum) -> u8 {
let mut slice = [0u8; 10];
let bytes_written =
bincode::encode_into_slice(e, &mut slice, Configuration::standard()).unwrap();
assert_eq!(bytes_written, 1);
slice[0]
}
assert_eq!(ser(CStyleEnum::A), 1);
assert_eq!(ser(CStyleEnum::B), 2);
assert_eq!(ser(CStyleEnum::C), 3);
assert_eq!(ser(CStyleEnum::D), 5);
assert_eq!(ser(CStyleEnum::E), 6);
fn de(num: u8) -> Result<CStyleEnum, bincode::error::DecodeError> {
bincode::decode_from_slice(&[num], Configuration::standard())
}
fn expected_err(idx: u32) -> Result<CStyleEnum, bincode::error::DecodeError> {
Err(bincode::error::DecodeError::UnexpectedVariant {
type_name: "CStyleEnum",
allowed: bincode::error::AllowedEnumVariants::Allowed(&[1, 2, 3, 5, 6]),
found: idx,
})
}
assert_eq!(de(0), expected_err(0));
assert_eq!(de(1).unwrap(), CStyleEnum::A);
assert_eq!(de(2).unwrap(), CStyleEnum::B);
assert_eq!(de(3).unwrap(), CStyleEnum::C);
assert_eq!(de(4), expected_err(4));
assert_eq!(de(5).unwrap(), CStyleEnum::D);
assert_eq!(de(6).unwrap(), CStyleEnum::E);
assert_eq!(de(7), expected_err(7));
}

4
tests/issues.rs Normal file
View File

@ -0,0 +1,4 @@
#![no_std]
#[path = "issues/issue_427.rs"]
mod issue_427;

69
tests/issues/issue_427.rs Normal file
View File

@ -0,0 +1,69 @@
#![cfg(feature = "derive")]
/// HID-IO Packet Buffer Struct
///
/// # Remarks
/// Used to store HID-IO data chunks. Will be chunked into individual packets on transmission.
#[repr(C)]
#[derive(PartialEq, Clone, Debug, bincode::Encode)]
pub struct HidIoPacketBuffer<const H: usize> {
/// Type of packet (Continued is automatically set if needed)
pub ptype: u32,
/// Packet Id
pub id: u32,
/// Packet length for serialization (in bytes)
pub max_len: u32,
/// Payload data, chunking is done automatically by serializer
pub data: [u8; H],
/// Set False if buffer is not complete, True if it is
pub done: bool,
}
#[repr(u32)]
#[derive(PartialEq, Clone, Copy, Debug, bincode::Encode)]
#[allow(dead_code)]
/// Requests for to perform a specific action
pub enum HidIoCommandId {
SupportedIds = 0x00,
GetInfo = 0x01,
TestPacket = 0x02,
ResetHidIo = 0x03,
Reserved = 0x04, // ... 0x0F
GetProperties = 0x10,
KeyState = 0x11,
KeyboardLayout = 0x12,
KeyLayout = 0x13,
KeyShapes = 0x14,
LedLayout = 0x15,
FlashMode = 0x16,
UnicodeText = 0x17,
UnicodeState = 0x18,
HostMacro = 0x19,
SleepMode = 0x1A,
KllState = 0x20,
PixelSetting = 0x21,
PixelSet1c8b = 0x22,
PixelSet3c8b = 0x23,
PixelSet1c16b = 0x24,
PixelSet3c16b = 0x25,
OpenUrl = 0x30,
TerminalCmd = 0x31,
GetInputLayout = 0x32,
SetInputLayout = 0x33,
TerminalOut = 0x34,
HidKeyboard = 0x40,
HidKeyboardLed = 0x41,
HidMouse = 0x42,
HidJoystick = 0x43,
HidSystemCtrl = 0x44,
HidConsumerCtrl = 0x45,
ManufacturingTest = 0x50,
ManufacturingResult = 0x51,
Unused = 0xFFFF,
}