From 044942891f41c2a3dbe47ea6ece0ee15b246ac4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lena=20Hellstr=C3=B6m?= Date: Wed, 22 Sep 2021 22:15:35 +0200 Subject: [PATCH] Clean up the borrow crimes --- derive/src/derive_enum.rs | 87 +++++++++++++++++++++++++------------ derive/src/derive_struct.rs | 64 ++++++++++++++++++--------- src/de/decoder.rs | 2 +- src/de/impls.rs | 77 +++++++++++++++++--------------- src/de/mod.rs | 12 ++--- src/features/impl_std.rs | 4 +- tests/derive.rs | 23 +++++++--- tests/test.rs | 4 +- 8 files changed, 173 insertions(+), 100 deletions(-) diff --git a/derive/src/derive_enum.rs b/derive/src/derive_enum.rs index d3b0e6c..e08c63c 100644 --- a/derive/src/derive_enum.rs +++ b/derive/src/derive_enum.rs @@ -66,15 +66,13 @@ impl DeriveEnum { let (mut impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - // check if we don't already have a '__de lifetime - let mut should_insert_lifetime = true; + // check if the type has lifetimes + let mut should_insert_lifetime = false; for param in &generics.params { - if let GenericParam::Lifetime(lt) = param { - if lt.lifetime.ident == "__de" { - should_insert_lifetime = false; - break; - } + if let GenericParam::Lifetime(_) = param { + should_insert_lifetime = true; + break; } } @@ -82,12 +80,18 @@ impl DeriveEnum { let mut generics_with_decode_lifetime; if should_insert_lifetime { generics_with_decode_lifetime = generics.clone(); + + let mut new_lifetime = LifetimeDef::new(Lifetime::new("'__de", Span::call_site())); + + for param in &generics.params { + if let GenericParam::Lifetime(lt) = param { + new_lifetime.bounds.push(lt.lifetime.clone()) + } + } + generics_with_decode_lifetime .params - .push(GenericParam::Lifetime(LifetimeDef::new(Lifetime::new( - "'__de", - Span::call_site(), - )))); + .push(GenericParam::Lifetime(new_lifetime)); impl_generics = generics_with_decode_lifetime.split_for_impl().0; } @@ -95,8 +99,10 @@ impl DeriveEnum { let max_variant = (variants.len() - 1) as u32; let match_arms = variants.iter().enumerate().map(|(index, variant)| { let index = index as u32; - let decode_statements = - field_names_to_decodable(&fields_to_constructable_names(&variant.fields)); + let decode_statements = field_names_to_decodable( + &fields_to_constructable_names(&variant.fields), + should_insert_lifetime, + ); let variant_name = variant.ident.clone(); quote! { #index => { @@ -106,20 +112,39 @@ impl DeriveEnum { } } }); - let result = quote! { - impl #impl_generics bincode::de::Decodable<'__de> for #name #ty_generics #where_clause { - fn decode>(mut decoder: D) -> Result<#name #ty_generics, bincode::error::DecodeError> { - let i = decoder.decode_u32()?; - Ok(match i { - #(#match_arms)* - variant => return Err(bincode::error::DecodeError::UnexpectedVariant{ - min: 0, - max: #max_variant, - found: variant, + let result = if should_insert_lifetime { + quote! { + impl #impl_generics bincode::de::BorrowDecodable<'__de> for #name #ty_generics #where_clause { + fn borrow_decode>(mut decoder: D) -> Result { + let i = decoder.decode_u32()?; + Ok(match i { + #(#match_arms)* + variant => return Err(bincode::error::DecodeError::UnexpectedVariant{ + min: 0, + max: #max_variant, + found: variant, + }) }) - }) - } + } + } + } + } else { + quote! { + impl #impl_generics bincode::de::Decodable for #name #ty_generics #where_clause { + fn decode(mut decoder: D) -> Result { + let i = decoder.decode_u32()?; + Ok(match i { + #(#match_arms)* + variant => return Err(bincode::error::DecodeError::UnexpectedVariant{ + min: 0, + max: #max_variant, + found: variant, + }) + }) + } + + } } }; @@ -199,12 +224,18 @@ fn fields_to_constructable_names(fields: &Fields) -> Vec { } } -fn field_names_to_decodable(names: &[TokenStream2]) -> Vec { +fn field_names_to_decodable(names: &[TokenStream2], borrowed: bool) -> Vec { names .iter() .map(|field| { - quote! { - #field: bincode::de::Decodable::decode(&mut decoder)?, + if borrowed { + quote! { + #field: bincode::de::BorrowDecodable::borrow_decode(&mut decoder)?, + } + } else { + quote! { + #field: bincode::de::Decodable::decode(&mut decoder)?, + } } }) .collect::>() diff --git a/derive/src/derive_struct.rs b/derive/src/derive_struct.rs index 1c43796..881545f 100644 --- a/derive/src/derive_struct.rs +++ b/derive/src/derive_struct.rs @@ -26,6 +26,7 @@ impl DeriveStruct { .collect(), syn::Fields::Unit => Vec::new(), }; + Ok(Self { name, generics, @@ -72,28 +73,30 @@ impl DeriveStruct { let (mut impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - // check if we don't already have a '__de lifetime - let mut should_insert_lifetime = true; + // check if the type has lifetimes + let mut should_insert_lifetime = false; for param in &generics.params { - if let GenericParam::Lifetime(lt) = param { - if lt.lifetime.ident == "__de" { - should_insert_lifetime = false; - break; - } + if let GenericParam::Lifetime(_) = param { + should_insert_lifetime = true; + break; } } - // if we don't have a '__de lifetime, insert it + // if the type has lifetimes, insert '__de and bound it to the lifetimes let mut generics_with_decode_lifetime; if should_insert_lifetime { generics_with_decode_lifetime = generics.clone(); + let mut new_lifetime = LifetimeDef::new(Lifetime::new("'__de", Span::call_site())); + + for param in &generics.params { + if let GenericParam::Lifetime(lt) = param { + new_lifetime.bounds.push(lt.lifetime.clone()) + } + } generics_with_decode_lifetime .params - .push(GenericParam::Lifetime(LifetimeDef::new(Lifetime::new( - "'__de", - Span::call_site(), - )))); + .push(GenericParam::Lifetime(new_lifetime)); impl_generics = generics_with_decode_lifetime.split_for_impl().0; } @@ -101,20 +104,39 @@ impl DeriveStruct { let fields = fields .into_iter() .map(|field| { - quote! { - #field: bincode::de::Decodable::decode(&mut decoder)?, + if should_insert_lifetime { + quote! { + #field: bincode::de::BorrowDecodable::borrow_decode(&mut decoder)?, + } + } else { + quote! { + #field: bincode::de::Decodable::decode(&mut decoder)?, + } } }) .collect::>(); - let result = quote! { - impl #impl_generics bincode::de::Decodable< '__de > for #name #ty_generics #where_clause { - fn decode>(mut decoder: D) -> Result { - Ok(#name { - #(#fields)* - }) - } + let result = if should_insert_lifetime { + quote! { + impl #impl_generics bincode::de::BorrowDecodable<'__de> for #name #ty_generics #where_clause { + fn borrow_decode>(mut decoder: D) -> Result { + Ok(#name { + #(#fields)* + }) + } + } + } + } else { + quote! { + impl #impl_generics bincode::de::Decodable for #name #ty_generics #where_clause { + fn decode(mut decoder: D) -> Result { + Ok(#name { + #(#fields)* + }) + } + + } } }; diff --git a/src/de/decoder.rs b/src/de/decoder.rs index 2599112..ff10d37 100644 --- a/src/de/decoder.rs +++ b/src/de/decoder.rs @@ -32,7 +32,7 @@ impl<'a, 'de, R: BorrowReader<'de>, C: Config> BorrowDecode<'de> for &'a mut Dec } } -impl<'a, 'de, R: Reader<'de>, C: Config> Decode<'de> for &'a mut Decoder { +impl<'a, 'de, R: Reader<'de>, C: Config> Decode for &'a mut Decoder { fn decode_u8(&mut self) -> Result { let mut bytes = [0u8; 1]; self.reader.read(&mut bytes)?; diff --git a/src/de/impls.rs b/src/de/impls.rs index afbb55f..21164be 100644 --- a/src/de/impls.rs +++ b/src/de/impls.rs @@ -1,86 +1,86 @@ use super::{BorrowDecodable, BorrowDecode, Decodable, Decode}; use crate::error::DecodeError; -impl<'de> Decodable<'de> for u8 { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for u8 { + fn decode(mut decoder: D) -> Result { decoder.decode_u8() } } -impl<'de> Decodable<'de> for u16 { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for u16 { + fn decode(mut decoder: D) -> Result { decoder.decode_u16() } } -impl<'de> Decodable<'de> for u32 { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for u32 { + fn decode(mut decoder: D) -> Result { decoder.decode_u32() } } -impl<'de> Decodable<'de> for u64 { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for u64 { + fn decode(mut decoder: D) -> Result { decoder.decode_u64() } } -impl<'de> Decodable<'de> for u128 { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for u128 { + fn decode(mut decoder: D) -> Result { decoder.decode_u128() } } -impl<'de> Decodable<'de> for usize { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for usize { + fn decode(mut decoder: D) -> Result { decoder.decode_usize() } } -impl<'de> Decodable<'de> for i8 { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for i8 { + fn decode(mut decoder: D) -> Result { decoder.decode_i8() } } -impl<'de> Decodable<'de> for i16 { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for i16 { + fn decode(mut decoder: D) -> Result { decoder.decode_i16() } } -impl<'de> Decodable<'de> for i32 { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for i32 { + fn decode(mut decoder: D) -> Result { decoder.decode_i32() } } -impl<'de> Decodable<'de> for i64 { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for i64 { + fn decode(mut decoder: D) -> Result { decoder.decode_i64() } } -impl<'de> Decodable<'de> for i128 { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for i128 { + fn decode(mut decoder: D) -> Result { decoder.decode_i128() } } -impl<'de> Decodable<'de> for isize { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for isize { + fn decode(mut decoder: D) -> Result { decoder.decode_isize() } } -impl<'de> Decodable<'de> for f32 { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for f32 { + fn decode(mut decoder: D) -> Result { decoder.decode_f32() } } -impl<'de> Decodable<'de> for f64 { - fn decode>(mut decoder: D) -> Result { +impl<'de> Decodable for f64 { + fn decode(mut decoder: D) -> Result { decoder.decode_f64() } } @@ -99,21 +99,21 @@ impl<'a, 'de: 'a> BorrowDecodable<'de> for &'a str { } } -impl<'de, const N: usize> Decodable<'de> for [u8; N] { - fn decode>(mut decoder: D) -> Result { +impl<'de, const N: usize> Decodable for [u8; N] { + fn decode(mut decoder: D) -> Result { decoder.decode_array() } } -impl<'de, T> Decodable<'de> for core::marker::PhantomData { - fn decode>(_: D) -> Result { +impl<'de, T> Decodable for core::marker::PhantomData { + fn decode(_: D) -> Result { Ok(core::marker::PhantomData) } } -impl<'a, 'de, T> Decode<'de> for &'a mut T +impl<'a, 'de, T> Decode for &'a mut T where - T: Decode<'de>, + T: Decode, { fn decode_u8(&mut self) -> Result { T::decode_u8(self) @@ -175,3 +175,12 @@ where T::decode_array::(self) } } + +impl<'a, 'de, T> BorrowDecode<'de> for &'a mut T +where + T: BorrowDecode<'de>, +{ + fn decode_slice(&mut self, len: usize) -> Result<&'de [u8], DecodeError> { + T::decode_slice(self, len) + } +} diff --git a/src/de/mod.rs b/src/de/mod.rs index d0dc730..199889e 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -6,21 +6,21 @@ mod impls; pub mod read; pub use self::decoder::Decoder; -pub trait Decodable<'de>: Sized + BorrowDecodable<'de> { - fn decode>(decoder: D) -> Result; +pub trait Decodable: for<'de> BorrowDecodable<'de> { + fn decode(decoder: D) -> Result; } pub trait BorrowDecodable<'de>: Sized { fn borrow_decode>(decoder: D) -> Result; } -impl<'de, T: Decodable<'de>> BorrowDecodable<'de> for T { - fn borrow_decode>(decoder: D) -> Result { +impl<'de, T: Decodable> BorrowDecodable<'de> for T { + fn borrow_decode(decoder: D) -> Result { Decodable::decode(decoder) } } -pub trait Decode<'de> { +pub trait Decode { fn decode_u8(&mut self) -> Result; fn decode_u16(&mut self) -> Result; fn decode_u32(&mut self) -> Result; @@ -40,6 +40,6 @@ pub trait Decode<'de> { fn decode_array(&mut self) -> Result<[u8; N], DecodeError>; } -pub trait BorrowDecode<'de>: Decode<'de> { +pub trait BorrowDecode<'de>: Decode { fn decode_slice(&mut self, len: usize) -> Result<&'de [u8], DecodeError>; } diff --git a/src/features/impl_std.rs b/src/features/impl_std.rs index 9cb3463..22b0ac1 100644 --- a/src/features/impl_std.rs +++ b/src/features/impl_std.rs @@ -4,13 +4,13 @@ use crate::{ error::DecodeError, }; -pub fn decode_from<'__de, D: Decodable<'__de>, R: std::io::Read>( +pub fn decode_from<'__de, D: Decodable, R: std::io::Read>( src: &'__de mut R, ) -> Result { decode_from_with_config(src, config::Default) } -pub fn decode_from_with_config<'__de, D: Decodable<'__de>, C: Config, R: std::io::Read>( +pub fn decode_from_with_config<'__de, D: Decodable, C: Config, R: std::io::Read>( src: &'__de mut R, _config: C, ) -> Result { diff --git a/tests/derive.rs b/tests/derive.rs index 2bf0606..c32701b 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -1,7 +1,6 @@ #![cfg(feature = "derive")] use bincode::{de::Decodable, enc::Encodeable}; -use core::marker::PhantomData; #[derive(bincode::Encodable, PartialEq, Debug)] pub struct Test { @@ -11,11 +10,17 @@ pub struct Test { } #[derive(bincode::Decodable, PartialEq, Debug, Eq)] -pub struct Test2<'__de, T: Decodable<'__de>> { +pub struct Test2 { a: T, b: u32, c: u32, - pd: PhantomData<&'__de ()>, +} + +#[derive(bincode::Decodable, PartialEq, Debug, Eq)] +pub struct Test3<'a> { + a: &'a str, + b: u32, + c: u32, } #[derive(bincode::Encodable, bincode::Decodable, PartialEq, Debug, Eq)] @@ -28,6 +33,13 @@ pub enum TestEnum { Baz(u32, u32, u32), } +#[derive(bincode::Encodable, bincode::Decodable, PartialEq, Debug, Eq)] +pub enum TestEnum2<'a> { + Foo, + Bar { name: &'a str }, + Baz(u32, u32, u32), +} + #[test] fn test_encodable() { let start = Test { @@ -47,10 +59,9 @@ fn test_decodable() { a: 5u32, b: 10u32, c: 1024u32, - pd: PhantomData, }; - let mut slice = [5, 10, 251, 0, 4]; - let result: Test2 = bincode::decode(&mut slice).unwrap(); + let slice = [5, 10, 251, 0, 4]; + let result: Test2 = bincode::decode_from(&mut slice.as_ref()).unwrap(); assert_eq!(result, start); } diff --git a/tests/test.rs b/tests/test.rs index f09625c..b0ca166 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -4,7 +4,7 @@ use core::fmt::Debug; fn the_same_with_config(element: V, config: C) where V: bincode::enc::Encodeable - + for<'de> bincode::de::Decodable<'de> + + for<'de> bincode::de::Decodable + PartialEq + Debug + Clone @@ -20,7 +20,7 @@ where fn the_same(element: V) where V: bincode::enc::Encodeable - + for<'de> bincode::de::Decodable<'de> + + for<'de> bincode::de::Decodable + PartialEq + Debug + Clone