Add enum encode/decode derive

This commit is contained in:
Lena Hellström 2021-09-20 23:06:25 +02:00
parent 2d0254405b
commit c83c36333d
4 changed files with 238 additions and 9 deletions

View File

@ -1,19 +1,183 @@
use crate::Result;
use proc_macro::TokenStream;
use syn::Ident;
pub struct DeriveEnum {}
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use quote::ToTokens;
use syn::{spanned::Spanned, Field, Fields, Generics, Ident, Index, Variant};
pub struct DeriveEnum {
name: Ident,
generics: Generics,
variants: Vec<Variant>,
}
impl DeriveEnum {
pub fn parse(_name: Ident, _en: syn::DataEnum) -> Result<Self> {
unimplemented!()
pub fn parse(name: Ident, generics: Generics, en: syn::DataEnum) -> Result<Self> {
let variants = en.variants.into_iter().collect();
Ok(DeriveEnum {
name,
generics,
variants,
})
}
pub fn to_encodable(self) -> Result<TokenStream> {
unimplemented!()
let DeriveEnum {
name,
generics,
variants,
} = self;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let match_arms = variants.iter().enumerate().map(|(index, variant)| {
let fields_section = fields_to_match_arm(&variant.fields);
let encode_statements = field_names_to_encodable(&fields_to_names(&variant.fields));
let variant_name = variant.ident.clone();
quote! {
#name :: #variant_name #fields_section => {
encoder.encode_u32(#index as u32)?;
#(#encode_statements)*
}
}
});
let result = quote! {
impl #impl_generics bincode::enc::Encodeable for #name #ty_generics #where_clause {
fn encode<E: bincode::enc::Encode>(&self, mut encoder: E) -> Result<(), bincode::error::EncodeError> {
match self {
#(#match_arms)*
}
Ok(())
}
}
};
Ok(result.into())
}
pub fn to_decodable(self) -> Result<TokenStream> {
unimplemented!()
let DeriveEnum {
name,
generics,
variants,
} = self;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let max_variant = (variants.len() - 1) as u32;
let match_arms = variants.iter().enumerate().map(|(index, variant)| {
let index = index as u32;
let decode_statements =
field_names_to_decodable(&fields_to_constructable_names(&variant.fields));
let variant_name = variant.ident.clone();
quote! {
#index => {
#name :: #variant_name {
#(#decode_statements)*
}
}
}
});
let result = quote! {
impl #impl_generics bincode::de::Decodable for #name #ty_generics #where_clause {
fn decode<D: bincode::de::Decode>(mut decoder: D) -> Result<#name #ty_generics, bincode::error::DecodeError> {
let i = decoder.decode_u32()?;
Ok(match i {
#(#match_arms)*
variant => return Err(bincode::error::DecodeError::UnexpectedVariant{
min: 0,
max: #max_variant,
found: variant,
})
})
}
}
};
Ok(result.into())
}
}
fn fields_to_match_arm(fields: &Fields) -> TokenStream2 {
match fields {
syn::Fields::Named(fields) => {
let fields: Vec<_> = fields
.named
.iter()
.map(|f| f.ident.clone().unwrap().to_token_stream())
.collect();
quote! {
{#(#fields),*}
}
}
syn::Fields::Unnamed(fields) => {
let fields: Vec<_> = fields
.unnamed
.iter()
.enumerate()
.map(|(i, f)| Ident::new(&format!("_{}", i), f.span()))
.collect();
quote! {
(#(#fields),*)
}
}
syn::Fields::Unit => quote! {},
}
}
fn fields_to_names(fields: &Fields) -> Vec<TokenStream2> {
match fields {
syn::Fields::Named(fields) => fields
.named
.iter()
.map(|f| f.ident.clone().unwrap().to_token_stream())
.collect(),
syn::Fields::Unnamed(fields) => fields
.unnamed
.iter()
.enumerate()
.map(|(i, f)| Ident::new(&format!("_{}", i), f.span()).to_token_stream())
.collect(),
syn::Fields::Unit => Vec::new(),
}
}
fn field_names_to_encodable(names: &[TokenStream2]) -> Vec<TokenStream2> {
names
.into_iter()
.map(|field| {
quote! {
bincode::enc::Encodeable::encode(#field, &mut encoder)?;
}
})
.collect::<Vec<_>>()
}
fn fields_to_constructable_names(fields: &Fields) -> Vec<TokenStream2> {
match fields {
syn::Fields::Named(fields) => fields
.named
.iter()
.map(|f| f.ident.clone().unwrap().to_token_stream())
.collect(),
syn::Fields::Unnamed(fields) => fields
.unnamed
.iter()
.enumerate()
.map(|(i, f)| Index::from(i).to_token_stream())
.collect(),
syn::Fields::Unit => Vec::new(),
}
}
fn field_names_to_decodable(names: &[TokenStream2]) -> Vec<TokenStream2> {
names
.into_iter()
.map(|field| {
quote! {
#field: bincode::de::Decodable::decode(&mut decoder)?,
}
})
.collect::<Vec<_>>()
}

View File

@ -25,7 +25,8 @@ fn derive_encodable_inner(input: DeriveInput) -> Result<TokenStream> {
.and_then(|str| str.to_encodable())
}
syn::Data::Enum(enum_definition) => {
DeriveEnum::parse(input.ident, enum_definition).and_then(|str| str.to_encodable())
DeriveEnum::parse(input.ident, input.generics, enum_definition)
.and_then(|str| str.to_encodable())
}
syn::Data::Union(_) => Err(Error::UnionNotSupported),
}
@ -44,7 +45,8 @@ fn derive_decodable_inner(input: DeriveInput) -> Result<TokenStream> {
.and_then(|str| str.to_decodable())
}
syn::Data::Enum(enum_definition) => {
DeriveEnum::parse(input.ident, enum_definition).and_then(|str| str.to_decodable())
DeriveEnum::parse(input.ident, input.generics, enum_definition)
.and_then(|str| str.to_decodable())
}
syn::Data::Union(_) => Err(Error::UnionNotSupported),
}

View File

@ -16,6 +16,11 @@ pub enum DecodeError {
/// The type that was encoded in the data
found: IntegerType,
},
UnexpectedVariant {
min: u32,
max: u32,
found: u32,
},
}
#[non_exhaustive]

View File

@ -17,6 +17,13 @@ pub struct Test2<T: Decodable> {
#[derive(bincode::Encodable, bincode::Decodable, PartialEq, Debug, Eq)]
pub struct TestTupleStruct(u32, u32, u32);
#[derive(bincode::Encodable, bincode::Decodable, PartialEq, Debug, Eq)]
pub enum TestEnum {
Foo,
Bar { name: u32 },
Baz(u32, u32, u32),
}
#[test]
fn test_encodable() {
let start = Test {
@ -58,3 +65,54 @@ fn test_decodable_tuple() {
let result: TestTupleStruct = bincode::decode(&mut slice).unwrap();
assert_eq!(result, start);
}
#[test]
fn test_encodable_enum_struct_variant() {
let start = TestEnum::Bar { name: 5u32 };
let mut slice = [0u8; 1024];
let bytes_written = bincode::encode_into_slice(start, &mut slice).unwrap();
assert_eq!(bytes_written, 2);
assert_eq!(&slice[..bytes_written], &[1, 5]);
}
#[test]
fn test_decodable_enum_struct_variant() {
let start = TestEnum::Bar { name: 5u32 };
let mut slice = [1, 5];
let result: TestEnum = bincode::decode(&mut slice).unwrap();
assert_eq!(result, start);
}
#[test]
fn test_encodable_enum_tuple_variant() {
let start = TestEnum::Baz(5, 10, 1024);
let mut slice = [0u8; 1024];
let bytes_written = bincode::encode_into_slice(start, &mut slice).unwrap();
assert_eq!(bytes_written, 6);
assert_eq!(&slice[..bytes_written], &[2, 5, 10, 251, 0, 4]);
}
#[test]
fn test_decodable_enum_unit_variant() {
let start = TestEnum::Foo;
let mut slice = [0];
let result: TestEnum = bincode::decode(&mut slice).unwrap();
assert_eq!(result, start);
}
#[test]
fn test_encodable_enum_unit_variant() {
let start = TestEnum::Foo;
let mut slice = [0u8; 1024];
let bytes_written = bincode::encode_into_slice(start, &mut slice).unwrap();
assert_eq!(bytes_written, 1);
assert_eq!(&slice[..bytes_written], &[0]);
}
#[test]
fn test_decodable_enum_tuple_variant() {
let start = TestEnum::Baz(5, 10, 1024);
let mut slice = [2, 5, 10, 251, 0, 4];
let result: TestEnum = bincode::decode(&mut slice).unwrap();
assert_eq!(result, start);
}