From c83c36333d9ecc1d2763049190bd4f5bf5ae79a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lena=20Hellstr=C3=B6m?= Date: Mon, 20 Sep 2021 23:06:25 +0200 Subject: [PATCH] Add enum encode/decode derive --- derive/src/derive_enum.rs | 178 ++++++++++++++++++++++++++++++++++++-- derive/src/lib.rs | 6 +- src/error.rs | 5 ++ tests/derive.rs | 58 +++++++++++++ 4 files changed, 238 insertions(+), 9 deletions(-) diff --git a/derive/src/derive_enum.rs b/derive/src/derive_enum.rs index 73c0b08..cb1d563 100644 --- a/derive/src/derive_enum.rs +++ b/derive/src/derive_enum.rs @@ -1,19 +1,183 @@ use crate::Result; use proc_macro::TokenStream; -use syn::Ident; - -pub struct DeriveEnum {} +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; +use quote::ToTokens; +use syn::{spanned::Spanned, Field, Fields, Generics, Ident, Index, Variant}; +pub struct DeriveEnum { + name: Ident, + generics: Generics, + variants: Vec, +} impl DeriveEnum { - pub fn parse(_name: Ident, _en: syn::DataEnum) -> Result { - unimplemented!() + pub fn parse(name: Ident, generics: Generics, en: syn::DataEnum) -> Result { + let variants = en.variants.into_iter().collect(); + + Ok(DeriveEnum { + name, + generics, + variants, + }) } pub fn to_encodable(self) -> Result { - unimplemented!() + let DeriveEnum { + name, + generics, + variants, + } = self; + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let match_arms = variants.iter().enumerate().map(|(index, variant)| { + let fields_section = fields_to_match_arm(&variant.fields); + let encode_statements = field_names_to_encodable(&fields_to_names(&variant.fields)); + let variant_name = variant.ident.clone(); + quote! { + #name :: #variant_name #fields_section => { + encoder.encode_u32(#index as u32)?; + #(#encode_statements)* + } + } + }); + let result = quote! { + impl #impl_generics bincode::enc::Encodeable for #name #ty_generics #where_clause { + fn encode(&self, mut encoder: E) -> Result<(), bincode::error::EncodeError> { + match self { + #(#match_arms)* + } + Ok(()) + } + + } + }; + + Ok(result.into()) } pub fn to_decodable(self) -> Result { - unimplemented!() + let DeriveEnum { + name, + generics, + variants, + } = self; + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let max_variant = (variants.len() - 1) as u32; + let match_arms = variants.iter().enumerate().map(|(index, variant)| { + let index = index as u32; + let decode_statements = + field_names_to_decodable(&fields_to_constructable_names(&variant.fields)); + let variant_name = variant.ident.clone(); + quote! { + #index => { + #name :: #variant_name { + #(#decode_statements)* + } + } + } + }); + let result = quote! { + impl #impl_generics bincode::de::Decodable for #name #ty_generics #where_clause { + fn decode(mut decoder: D) -> Result<#name #ty_generics, bincode::error::DecodeError> { + let i = decoder.decode_u32()?; + Ok(match i { + #(#match_arms)* + variant => return Err(bincode::error::DecodeError::UnexpectedVariant{ + min: 0, + max: #max_variant, + found: variant, + }) + }) + } + + } + }; + + Ok(result.into()) } } + +fn fields_to_match_arm(fields: &Fields) -> TokenStream2 { + match fields { + syn::Fields::Named(fields) => { + let fields: Vec<_> = fields + .named + .iter() + .map(|f| f.ident.clone().unwrap().to_token_stream()) + .collect(); + quote! { + {#(#fields),*} + } + } + syn::Fields::Unnamed(fields) => { + let fields: Vec<_> = fields + .unnamed + .iter() + .enumerate() + .map(|(i, f)| Ident::new(&format!("_{}", i), f.span())) + .collect(); + quote! { + (#(#fields),*) + } + } + syn::Fields::Unit => quote! {}, + } +} + +fn fields_to_names(fields: &Fields) -> Vec { + match fields { + syn::Fields::Named(fields) => fields + .named + .iter() + .map(|f| f.ident.clone().unwrap().to_token_stream()) + .collect(), + syn::Fields::Unnamed(fields) => fields + .unnamed + .iter() + .enumerate() + .map(|(i, f)| Ident::new(&format!("_{}", i), f.span()).to_token_stream()) + .collect(), + syn::Fields::Unit => Vec::new(), + } +} + +fn field_names_to_encodable(names: &[TokenStream2]) -> Vec { + names + .into_iter() + .map(|field| { + quote! { + bincode::enc::Encodeable::encode(#field, &mut encoder)?; + } + }) + .collect::>() +} + +fn fields_to_constructable_names(fields: &Fields) -> Vec { + match fields { + syn::Fields::Named(fields) => fields + .named + .iter() + .map(|f| f.ident.clone().unwrap().to_token_stream()) + .collect(), + syn::Fields::Unnamed(fields) => fields + .unnamed + .iter() + .enumerate() + .map(|(i, f)| Index::from(i).to_token_stream()) + .collect(), + syn::Fields::Unit => Vec::new(), + } +} + +fn field_names_to_decodable(names: &[TokenStream2]) -> Vec { + names + .into_iter() + .map(|field| { + quote! { + #field: bincode::de::Decodable::decode(&mut decoder)?, + } + }) + .collect::>() +} diff --git a/derive/src/lib.rs b/derive/src/lib.rs index c9cce74..2925525 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -25,7 +25,8 @@ fn derive_encodable_inner(input: DeriveInput) -> Result { .and_then(|str| str.to_encodable()) } syn::Data::Enum(enum_definition) => { - DeriveEnum::parse(input.ident, enum_definition).and_then(|str| str.to_encodable()) + DeriveEnum::parse(input.ident, input.generics, enum_definition) + .and_then(|str| str.to_encodable()) } syn::Data::Union(_) => Err(Error::UnionNotSupported), } @@ -44,7 +45,8 @@ fn derive_decodable_inner(input: DeriveInput) -> Result { .and_then(|str| str.to_decodable()) } syn::Data::Enum(enum_definition) => { - DeriveEnum::parse(input.ident, enum_definition).and_then(|str| str.to_decodable()) + DeriveEnum::parse(input.ident, input.generics, enum_definition) + .and_then(|str| str.to_decodable()) } syn::Data::Union(_) => Err(Error::UnionNotSupported), } diff --git a/src/error.rs b/src/error.rs index 7ade7fc..47340c7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -16,6 +16,11 @@ pub enum DecodeError { /// The type that was encoded in the data found: IntegerType, }, + UnexpectedVariant { + min: u32, + max: u32, + found: u32, + }, } #[non_exhaustive] diff --git a/tests/derive.rs b/tests/derive.rs index e407837..e699aee 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -17,6 +17,13 @@ pub struct Test2 { #[derive(bincode::Encodable, bincode::Decodable, PartialEq, Debug, Eq)] pub struct TestTupleStruct(u32, u32, u32); +#[derive(bincode::Encodable, bincode::Decodable, PartialEq, Debug, Eq)] +pub enum TestEnum { + Foo, + Bar { name: u32 }, + Baz(u32, u32, u32), +} + #[test] fn test_encodable() { let start = Test { @@ -58,3 +65,54 @@ fn test_decodable_tuple() { let result: TestTupleStruct = bincode::decode(&mut slice).unwrap(); assert_eq!(result, start); } + +#[test] +fn test_encodable_enum_struct_variant() { + let start = TestEnum::Bar { name: 5u32 }; + let mut slice = [0u8; 1024]; + let bytes_written = bincode::encode_into_slice(start, &mut slice).unwrap(); + assert_eq!(bytes_written, 2); + assert_eq!(&slice[..bytes_written], &[1, 5]); +} + +#[test] +fn test_decodable_enum_struct_variant() { + let start = TestEnum::Bar { name: 5u32 }; + let mut slice = [1, 5]; + let result: TestEnum = bincode::decode(&mut slice).unwrap(); + assert_eq!(result, start); +} + +#[test] +fn test_encodable_enum_tuple_variant() { + let start = TestEnum::Baz(5, 10, 1024); + let mut slice = [0u8; 1024]; + let bytes_written = bincode::encode_into_slice(start, &mut slice).unwrap(); + assert_eq!(bytes_written, 6); + assert_eq!(&slice[..bytes_written], &[2, 5, 10, 251, 0, 4]); +} + +#[test] +fn test_decodable_enum_unit_variant() { + let start = TestEnum::Foo; + let mut slice = [0]; + let result: TestEnum = bincode::decode(&mut slice).unwrap(); + assert_eq!(result, start); +} + +#[test] +fn test_encodable_enum_unit_variant() { + let start = TestEnum::Foo; + let mut slice = [0u8; 1024]; + let bytes_written = bincode::encode_into_slice(start, &mut slice).unwrap(); + assert_eq!(bytes_written, 1); + assert_eq!(&slice[..bytes_written], &[0]); +} + +#[test] +fn test_decodable_enum_tuple_variant() { + let start = TestEnum::Baz(5, 10, 1024); + let mut slice = [2, 5, 10, 251, 0, 4]; + let result: TestEnum = bincode::decode(&mut slice).unwrap(); + assert_eq!(result, start); +}