From 8241e6c656495448c61007ac19c41376924d6773 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lena=20Hellstr=C3=B6m?= Date: Mon, 20 Sep 2021 15:58:16 +0200 Subject: [PATCH] Add generic bound support to derive --- derive/src/derive_struct.rs | 23 +++++++++++++++++------ derive/src/lib.rs | 6 ++++-- tests/derive.rs | 6 ++++-- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/derive/src/derive_struct.rs b/derive/src/derive_struct.rs index 7f922e9..1668657 100644 --- a/derive/src/derive_struct.rs +++ b/derive/src/derive_struct.rs @@ -1,15 +1,16 @@ use crate::Result; use proc_macro::TokenStream; -use quote::quote; -use syn::{spanned::Spanned, Ident}; +use quote::{quote, quote_spanned}; +use syn::{spanned::Spanned, Generics, Ident}; pub struct DeriveStruct { name: Ident, + generics: Generics, fields: Vec, } impl DeriveStruct { - pub fn parse(name: Ident, str: syn::DataStruct) -> Result { + pub fn parse(name: Ident, generics: Generics, str: syn::DataStruct) -> Result { let fields = match str.fields { syn::Fields::Named(fields) => fields .named @@ -24,11 +25,21 @@ impl DeriveStruct { .collect(), syn::Fields::Unit => Vec::new(), }; - Ok(Self { name, fields }) + Ok(Self { + name, + generics, + fields, + }) } pub fn to_encodable(self) -> Result { - let DeriveStruct { name, fields } = self; + let DeriveStruct { + name, + generics, + fields, + } = self; + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let fields = fields .into_iter() @@ -40,7 +51,7 @@ impl DeriveStruct { .collect::>(); let result = quote! { - impl bincode::enc::Encodeable for #name { + impl #impl_generics bincode::enc::Encodeable for #name #ty_generics #where_clause { fn encode(&self, mut encoder: E) -> Result<(), bincode::error::EncodeError> { #(#fields)* Ok(()) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 82be10a..c9cce74 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -21,7 +21,8 @@ pub fn derive_encodable(input: TokenStream) -> TokenStream { fn derive_encodable_inner(input: DeriveInput) -> Result { match input.data { syn::Data::Struct(struct_definition) => { - DeriveStruct::parse(input.ident, struct_definition).and_then(|str| str.to_encodable()) + DeriveStruct::parse(input.ident, input.generics, struct_definition) + .and_then(|str| str.to_encodable()) } syn::Data::Enum(enum_definition) => { DeriveEnum::parse(input.ident, enum_definition).and_then(|str| str.to_encodable()) @@ -39,7 +40,8 @@ pub fn derive_decodable(input: TokenStream) -> TokenStream { fn derive_decodable_inner(input: DeriveInput) -> Result { match input.data { syn::Data::Struct(struct_definition) => { - DeriveStruct::parse(input.ident, struct_definition).and_then(|str| str.to_decodable()) + DeriveStruct::parse(input.ident, input.generics, struct_definition) + .and_then(|str| str.to_decodable()) } syn::Data::Enum(enum_definition) => { DeriveEnum::parse(input.ident, enum_definition).and_then(|str| str.to_decodable()) diff --git a/tests/derive.rs b/tests/derive.rs index 4aa0156..155925b 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -1,6 +1,8 @@ +use bincode::enc::Encodeable; + #[derive(bincode::Encodable, PartialEq, Debug)] -pub struct Test { - a: i32, +pub struct Test { + a: T, b: u32, c: u8, }