mirror of https://git.sr.ht/~stygianentity/bincode
Add enum encode/decode derive
This commit is contained in:
parent
2d0254405b
commit
c83c36333d
|
|
@ -1,19 +1,183 @@
|
|||
use crate::Result;
|
||||
use proc_macro::TokenStream;
|
||||
use syn::Ident;
|
||||
|
||||
pub struct DeriveEnum {}
|
||||
use proc_macro2::TokenStream as TokenStream2;
|
||||
use quote::quote;
|
||||
use quote::ToTokens;
|
||||
use syn::{spanned::Spanned, Field, Fields, Generics, Ident, Index, Variant};
|
||||
pub struct DeriveEnum {
|
||||
name: Ident,
|
||||
generics: Generics,
|
||||
variants: Vec<Variant>,
|
||||
}
|
||||
|
||||
impl DeriveEnum {
|
||||
pub fn parse(_name: Ident, _en: syn::DataEnum) -> Result<Self> {
|
||||
unimplemented!()
|
||||
pub fn parse(name: Ident, generics: Generics, en: syn::DataEnum) -> Result<Self> {
|
||||
let variants = en.variants.into_iter().collect();
|
||||
|
||||
Ok(DeriveEnum {
|
||||
name,
|
||||
generics,
|
||||
variants,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_encodable(self) -> Result<TokenStream> {
|
||||
unimplemented!()
|
||||
let DeriveEnum {
|
||||
name,
|
||||
generics,
|
||||
variants,
|
||||
} = self;
|
||||
|
||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||
|
||||
let match_arms = variants.iter().enumerate().map(|(index, variant)| {
|
||||
let fields_section = fields_to_match_arm(&variant.fields);
|
||||
let encode_statements = field_names_to_encodable(&fields_to_names(&variant.fields));
|
||||
let variant_name = variant.ident.clone();
|
||||
quote! {
|
||||
#name :: #variant_name #fields_section => {
|
||||
encoder.encode_u32(#index as u32)?;
|
||||
#(#encode_statements)*
|
||||
}
|
||||
}
|
||||
});
|
||||
let result = quote! {
|
||||
impl #impl_generics bincode::enc::Encodeable for #name #ty_generics #where_clause {
|
||||
fn encode<E: bincode::enc::Encode>(&self, mut encoder: E) -> Result<(), bincode::error::EncodeError> {
|
||||
match self {
|
||||
#(#match_arms)*
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
Ok(result.into())
|
||||
}
|
||||
|
||||
pub fn to_decodable(self) -> Result<TokenStream> {
|
||||
unimplemented!()
|
||||
let DeriveEnum {
|
||||
name,
|
||||
generics,
|
||||
variants,
|
||||
} = self;
|
||||
|
||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||
let max_variant = (variants.len() - 1) as u32;
|
||||
let match_arms = variants.iter().enumerate().map(|(index, variant)| {
|
||||
let index = index as u32;
|
||||
let decode_statements =
|
||||
field_names_to_decodable(&fields_to_constructable_names(&variant.fields));
|
||||
let variant_name = variant.ident.clone();
|
||||
quote! {
|
||||
#index => {
|
||||
#name :: #variant_name {
|
||||
#(#decode_statements)*
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
let result = quote! {
|
||||
impl #impl_generics bincode::de::Decodable for #name #ty_generics #where_clause {
|
||||
fn decode<D: bincode::de::Decode>(mut decoder: D) -> Result<#name #ty_generics, bincode::error::DecodeError> {
|
||||
let i = decoder.decode_u32()?;
|
||||
Ok(match i {
|
||||
#(#match_arms)*
|
||||
variant => return Err(bincode::error::DecodeError::UnexpectedVariant{
|
||||
min: 0,
|
||||
max: #max_variant,
|
||||
found: variant,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
Ok(result.into())
|
||||
}
|
||||
}
|
||||
|
||||
fn fields_to_match_arm(fields: &Fields) -> TokenStream2 {
|
||||
match fields {
|
||||
syn::Fields::Named(fields) => {
|
||||
let fields: Vec<_> = fields
|
||||
.named
|
||||
.iter()
|
||||
.map(|f| f.ident.clone().unwrap().to_token_stream())
|
||||
.collect();
|
||||
quote! {
|
||||
{#(#fields),*}
|
||||
}
|
||||
}
|
||||
syn::Fields::Unnamed(fields) => {
|
||||
let fields: Vec<_> = fields
|
||||
.unnamed
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, f)| Ident::new(&format!("_{}", i), f.span()))
|
||||
.collect();
|
||||
quote! {
|
||||
(#(#fields),*)
|
||||
}
|
||||
}
|
||||
syn::Fields::Unit => quote! {},
|
||||
}
|
||||
}
|
||||
|
||||
fn fields_to_names(fields: &Fields) -> Vec<TokenStream2> {
|
||||
match fields {
|
||||
syn::Fields::Named(fields) => fields
|
||||
.named
|
||||
.iter()
|
||||
.map(|f| f.ident.clone().unwrap().to_token_stream())
|
||||
.collect(),
|
||||
syn::Fields::Unnamed(fields) => fields
|
||||
.unnamed
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, f)| Ident::new(&format!("_{}", i), f.span()).to_token_stream())
|
||||
.collect(),
|
||||
syn::Fields::Unit => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn field_names_to_encodable(names: &[TokenStream2]) -> Vec<TokenStream2> {
|
||||
names
|
||||
.into_iter()
|
||||
.map(|field| {
|
||||
quote! {
|
||||
bincode::enc::Encodeable::encode(#field, &mut encoder)?;
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn fields_to_constructable_names(fields: &Fields) -> Vec<TokenStream2> {
|
||||
match fields {
|
||||
syn::Fields::Named(fields) => fields
|
||||
.named
|
||||
.iter()
|
||||
.map(|f| f.ident.clone().unwrap().to_token_stream())
|
||||
.collect(),
|
||||
syn::Fields::Unnamed(fields) => fields
|
||||
.unnamed
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, f)| Index::from(i).to_token_stream())
|
||||
.collect(),
|
||||
syn::Fields::Unit => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn field_names_to_decodable(names: &[TokenStream2]) -> Vec<TokenStream2> {
|
||||
names
|
||||
.into_iter()
|
||||
.map(|field| {
|
||||
quote! {
|
||||
#field: bincode::de::Decodable::decode(&mut decoder)?,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ fn derive_encodable_inner(input: DeriveInput) -> Result<TokenStream> {
|
|||
.and_then(|str| str.to_encodable())
|
||||
}
|
||||
syn::Data::Enum(enum_definition) => {
|
||||
DeriveEnum::parse(input.ident, enum_definition).and_then(|str| str.to_encodable())
|
||||
DeriveEnum::parse(input.ident, input.generics, enum_definition)
|
||||
.and_then(|str| str.to_encodable())
|
||||
}
|
||||
syn::Data::Union(_) => Err(Error::UnionNotSupported),
|
||||
}
|
||||
|
|
@ -44,7 +45,8 @@ fn derive_decodable_inner(input: DeriveInput) -> Result<TokenStream> {
|
|||
.and_then(|str| str.to_decodable())
|
||||
}
|
||||
syn::Data::Enum(enum_definition) => {
|
||||
DeriveEnum::parse(input.ident, enum_definition).and_then(|str| str.to_decodable())
|
||||
DeriveEnum::parse(input.ident, input.generics, enum_definition)
|
||||
.and_then(|str| str.to_decodable())
|
||||
}
|
||||
syn::Data::Union(_) => Err(Error::UnionNotSupported),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,11 @@ pub enum DecodeError {
|
|||
/// The type that was encoded in the data
|
||||
found: IntegerType,
|
||||
},
|
||||
UnexpectedVariant {
|
||||
min: u32,
|
||||
max: u32,
|
||||
found: u32,
|
||||
},
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
|
|
|
|||
|
|
@ -17,6 +17,13 @@ pub struct Test2<T: Decodable> {
|
|||
#[derive(bincode::Encodable, bincode::Decodable, PartialEq, Debug, Eq)]
|
||||
pub struct TestTupleStruct(u32, u32, u32);
|
||||
|
||||
#[derive(bincode::Encodable, bincode::Decodable, PartialEq, Debug, Eq)]
|
||||
pub enum TestEnum {
|
||||
Foo,
|
||||
Bar { name: u32 },
|
||||
Baz(u32, u32, u32),
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodable() {
|
||||
let start = Test {
|
||||
|
|
@ -58,3 +65,54 @@ fn test_decodable_tuple() {
|
|||
let result: TestTupleStruct = bincode::decode(&mut slice).unwrap();
|
||||
assert_eq!(result, start);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodable_enum_struct_variant() {
|
||||
let start = TestEnum::Bar { name: 5u32 };
|
||||
let mut slice = [0u8; 1024];
|
||||
let bytes_written = bincode::encode_into_slice(start, &mut slice).unwrap();
|
||||
assert_eq!(bytes_written, 2);
|
||||
assert_eq!(&slice[..bytes_written], &[1, 5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decodable_enum_struct_variant() {
|
||||
let start = TestEnum::Bar { name: 5u32 };
|
||||
let mut slice = [1, 5];
|
||||
let result: TestEnum = bincode::decode(&mut slice).unwrap();
|
||||
assert_eq!(result, start);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodable_enum_tuple_variant() {
|
||||
let start = TestEnum::Baz(5, 10, 1024);
|
||||
let mut slice = [0u8; 1024];
|
||||
let bytes_written = bincode::encode_into_slice(start, &mut slice).unwrap();
|
||||
assert_eq!(bytes_written, 6);
|
||||
assert_eq!(&slice[..bytes_written], &[2, 5, 10, 251, 0, 4]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decodable_enum_unit_variant() {
|
||||
let start = TestEnum::Foo;
|
||||
let mut slice = [0];
|
||||
let result: TestEnum = bincode::decode(&mut slice).unwrap();
|
||||
assert_eq!(result, start);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodable_enum_unit_variant() {
|
||||
let start = TestEnum::Foo;
|
||||
let mut slice = [0u8; 1024];
|
||||
let bytes_written = bincode::encode_into_slice(start, &mut slice).unwrap();
|
||||
assert_eq!(bytes_written, 1);
|
||||
assert_eq!(&slice[..bytes_written], &[0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decodable_enum_tuple_variant() {
|
||||
let start = TestEnum::Baz(5, 10, 1024);
|
||||
let mut slice = [2, 5, 10, 251, 0, 4];
|
||||
let result: TestEnum = bincode::decode(&mut slice).unwrap();
|
||||
assert_eq!(result, start);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue