From 9c7fb85e0ef29033170d43a66220dad9702a23e5 Mon Sep 17 00:00:00 2001 From: Victor Koenders Date: Wed, 22 Sep 2021 12:02:42 +0200 Subject: [PATCH] Added support for slices, str, fixed size arrays. Added lifetime to Decode trait --- derive/Cargo.toml | 2 +- derive/src/derive_enum.rs | 37 ++++++++++++++-- derive/src/derive_struct.rs | 37 +++++++++++++--- docs/spec.md | 37 ++++++++++++++++ src/de/decoder.rs | 6 +-- src/de/impls.rs | 88 +++++++++++++++++++++++-------------- src/de/mod.rs | 8 ++-- src/enc/encoder.rs | 5 +++ src/enc/impls.rs | 15 +++++++ src/enc/mod.rs | 1 + src/error.rs | 2 + src/lib.rs | 8 ++-- tests/derive.rs | 5 ++- tests/test.rs | 50 ++++++++++++++++++++- 14 files changed, 245 insertions(+), 56 deletions(-) diff --git a/derive/Cargo.toml b/derive/Cargo.toml index 34d6a61..b4379a8 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -13,4 +13,4 @@ proc-macro2 = "1.0" [dependencies.syn] version = "1.0.74" default-features = false -features = ["parsing", "derive", "proc-macro", "printing"] +features = ["parsing", "derive", "proc-macro", "printing", "clone-impls"] diff --git a/derive/src/derive_enum.rs b/derive/src/derive_enum.rs index 9e14519..68ddda9 100644 --- a/derive/src/derive_enum.rs +++ b/derive/src/derive_enum.rs @@ -1,8 +1,12 @@ use crate::Result; use proc_macro::TokenStream; +use proc_macro2::Span; use proc_macro2::TokenStream as TokenStream2; use quote::quote; use quote::ToTokens; +use syn::GenericParam; +use syn::Lifetime; +use syn::LifetimeDef; use syn::{spanned::Spanned, Fields, Generics, Ident, Index, Variant}; pub struct DeriveEnum { name: Ident, @@ -63,7 +67,34 @@ impl DeriveEnum { variants, } = self; - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let (mut impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + // check if we don't already have a '__de lifetime + let mut should_insert_lifetime = true; + + for param in &generics.params { + if let GenericParam::Lifetime(lt) = param { + if lt.lifetime.ident == "__de" { + should_insert_lifetime = false; + break; + } + } + } + + // if we don't have a '__de lifetime, insert it + let mut generics_with_decode_lifetime; + if should_insert_lifetime { + generics_with_decode_lifetime = generics.clone(); + generics_with_decode_lifetime + .params + .push(GenericParam::Lifetime(LifetimeDef::new(Lifetime::new( + "'__de", + Span::call_site(), + )))); + + impl_generics = generics_with_decode_lifetime.split_for_impl().0; + } + let max_variant = (variants.len() - 1) as u32; let match_arms = variants.iter().enumerate().map(|(index, variant)| { let index = index as u32; @@ -79,8 +110,8 @@ impl DeriveEnum { } }); let result = quote! { - impl #impl_generics bincode::de::Decodable for #name #ty_generics #where_clause { - fn decode(mut decoder: D) -> Result<#name #ty_generics, bincode::error::DecodeError> { + impl #impl_generics bincode::de::Decodable<'__de> for #name #ty_generics #where_clause { + fn decode>(mut decoder: D) -> Result<#name #ty_generics, bincode::error::DecodeError> { let i = decoder.decode_u32()?; Ok(match i { #(#match_arms)* diff --git a/derive/src/derive_struct.rs b/derive/src/derive_struct.rs index ff63e74..1c43796 100644 --- a/derive/src/derive_struct.rs +++ b/derive/src/derive_struct.rs @@ -1,8 +1,8 @@ use crate::Result; use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens}; -use syn::{Generics, Ident, Index}; +use syn::{GenericParam, Generics, Ident, Index, Lifetime, LifetimeDef}; pub struct DeriveStruct { name: Ident, @@ -70,7 +70,33 @@ impl DeriveStruct { fields, } = self; - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let (mut impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + // check if we don't already have a '__de lifetime + let mut should_insert_lifetime = true; + + for param in &generics.params { + if let GenericParam::Lifetime(lt) = param { + if lt.lifetime.ident == "__de" { + should_insert_lifetime = false; + break; + } + } + } + + // if we don't have a '__de lifetime, insert it + let mut generics_with_decode_lifetime; + if should_insert_lifetime { + generics_with_decode_lifetime = generics.clone(); + generics_with_decode_lifetime + .params + .push(GenericParam::Lifetime(LifetimeDef::new(Lifetime::new( + "'__de", + Span::call_site(), + )))); + + impl_generics = generics_with_decode_lifetime.split_for_impl().0; + } let fields = fields .into_iter() @@ -82,8 +108,8 @@ impl DeriveStruct { .collect::>(); let result = quote! { - impl #impl_generics bincode::de::Decodable for #name #ty_generics #where_clause { - fn decode(mut decoder: D) -> Result<#name #ty_generics, bincode::error::DecodeError> { + impl #impl_generics bincode::de::Decodable< '__de > for #name #ty_generics #where_clause { + fn decode>(mut decoder: D) -> Result { Ok(#name { #(#fields)* }) @@ -91,6 +117,7 @@ impl DeriveStruct { } }; + Ok(result.into()) } } diff --git a/docs/spec.md b/docs/spec.md index fb24e24..8313c47 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -90,6 +90,8 @@ assert_eq!(encoded.as_slice(), &[ Collections are encoded with their length value first, following by each entry of the collection. The length value is based on your `IntEncoding`. +**note**: fixed array length do not have their `len` encoded. See [Arrays](#arrays) + ```rs let list = vec![ 0u8, @@ -121,3 +123,38 @@ assert_eq!(encoded.as_slice(), &[ b'H', b'e', b'l', b'l', b'o' ]); ``` + +# Arrays + +Arrays are encoded *without* a length. + +```rs +let arr: [u8; 5] = [10, 20, 30, 40, 50]; +let encoded = bincode::encode_to_vec(&list).unwrap(); +assert_eq!(encoded.as_slice(), &[10, 20, 30, 40 50]); +``` + +This applies to any type `T` that implements `Encodabl`/`Decodabl` + +```rs +#[derive(bincode::Encodabl)] +struct Foo { + first: u8, + second: u8 +}; + +let arr: [Foo; 2] = [ + Foo { + first: 10, + second: 20, + }, + Foo { + first: 30, + second: 40, + }, +]; + +let encoded = bincode::encode_to_vec(&list).unwrap(); +assert_eq!(encoded.as_slice(), &[10, 20, 30, 40]); +``` + diff --git a/src/de/decoder.rs b/src/de/decoder.rs index 7d04156..2819ec7 100644 --- a/src/de/decoder.rs +++ b/src/de/decoder.rs @@ -23,7 +23,7 @@ impl<'de, R: Reader<'de>, C: Config> Decoder { } } -impl<'a, 'de, R: Reader<'de>, C: Config> Decode for &'a mut Decoder { +impl<'a, 'de, R: Reader<'de>, C: Config> Decode<'de> for &'a mut Decoder { fn decode_u8(&mut self) -> Result { let mut bytes = [0u8; 1]; self.reader.read(&mut bytes)?; @@ -198,8 +198,8 @@ impl<'a, 'de, R: Reader<'de>, C: Config> Decode for &'a mut Decoder { }) } - fn decode_slice(&mut self, slice: &mut [u8]) -> Result<(), DecodeError> { - self.reader.read(slice) + fn decode_slice(&mut self, len: usize) -> Result<&'de [u8], DecodeError> { + self.reader.forward_read(len, |s| s) } fn decode_array(&mut self) -> Result<[u8; N], DecodeError> { diff --git a/src/de/impls.rs b/src/de/impls.rs index ce5a08c..9218939 100644 --- a/src/de/impls.rs +++ b/src/de/impls.rs @@ -1,99 +1,119 @@ use super::{Decodable, Decode}; use crate::error::DecodeError; -impl Decodable for u8 { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for u8 { + fn decode>(mut decoder: D) -> Result { decoder.decode_u8() } } -impl Decodable for u16 { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for u16 { + fn decode>(mut decoder: D) -> Result { decoder.decode_u16() } } -impl Decodable for u32 { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for u32 { + fn decode>(mut decoder: D) -> Result { decoder.decode_u32() } } -impl Decodable for u64 { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for u64 { + fn decode>(mut decoder: D) -> Result { decoder.decode_u64() } } -impl Decodable for u128 { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for u128 { + fn decode>(mut decoder: D) -> Result { decoder.decode_u128() } } -impl Decodable for usize { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for usize { + fn decode>(mut decoder: D) -> Result { decoder.decode_usize() } } -impl Decodable for i8 { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for i8 { + fn decode>(mut decoder: D) -> Result { decoder.decode_i8() } } -impl Decodable for i16 { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for i16 { + fn decode>(mut decoder: D) -> Result { decoder.decode_i16() } } -impl Decodable for i32 { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for i32 { + fn decode>(mut decoder: D) -> Result { decoder.decode_i32() } } -impl Decodable for i64 { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for i64 { + fn decode>(mut decoder: D) -> Result { decoder.decode_i64() } } -impl Decodable for i128 { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for i128 { + fn decode>(mut decoder: D) -> Result { decoder.decode_i128() } } -impl Decodable for isize { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for isize { + fn decode>(mut decoder: D) -> Result { decoder.decode_isize() } } -impl Decodable for f32 { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for f32 { + fn decode>(mut decoder: D) -> Result { decoder.decode_f32() } } -impl Decodable for f64 { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for f64 { + fn decode>(mut decoder: D) -> Result { decoder.decode_f64() } } -impl Decodable for [u8; N] { - fn decode(mut decoder: D) -> Result { +impl<'de> Decodable<'de> for &'de [u8] { + fn decode>(mut decoder: D) -> Result { + let len = usize::decode(&mut decoder)?; + decoder.decode_slice(len) + } +} + +impl<'de> Decodable<'de> for &'de str { + fn decode>(decoder: D) -> Result { + let slice: &[u8] = Decodable::decode(decoder)?; + core::str::from_utf8(slice).map_err(DecodeError::Utf8) + } +} + +impl<'de, const N: usize> Decodable<'de> for [u8; N] { + fn decode>(mut decoder: D) -> Result { decoder.decode_array() } } -impl<'a, T> Decode for &'a mut T +impl<'de, T> Decodable<'de> for core::marker::PhantomData { + fn decode>(_: D) -> Result { + Ok(core::marker::PhantomData) + } +} + +impl<'a, 'de, T> Decode<'de> for &'a mut T where - T: Decode, + T: Decode<'de>, { fn decode_u8(&mut self) -> Result { T::decode_u8(self) @@ -151,8 +171,8 @@ where T::decode_f64(self) } - fn decode_slice(&mut self, slice: &mut [u8]) -> Result<(), DecodeError> { - T::decode_slice(self, slice) + fn decode_slice(&mut self, len: usize) -> Result<&'de [u8], DecodeError> { + T::decode_slice(self, len) } fn decode_array(&mut self) -> Result<[u8; N], DecodeError> { diff --git a/src/de/mod.rs b/src/de/mod.rs index dca66a2..31f3b38 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -6,11 +6,11 @@ mod impls; pub mod read; pub use self::decoder::Decoder; -pub trait Decodable: Sized { - fn decode(decoder: D) -> Result; +pub trait Decodable<'de>: Sized { + fn decode>(decoder: D) -> Result; } -pub trait Decode { +pub trait Decode<'de> { fn decode_u8(&mut self) -> Result; fn decode_u16(&mut self) -> Result; fn decode_u32(&mut self) -> Result; @@ -27,6 +27,6 @@ pub trait Decode { fn decode_f32(&mut self) -> Result; fn decode_f64(&mut self) -> Result; - fn decode_slice(&mut self, slice: &mut [u8]) -> Result<(), DecodeError>; + fn decode_slice(&mut self, len: usize) -> Result<&'de [u8], DecodeError>; fn decode_array(&mut self) -> Result<[u8; N], DecodeError>; } diff --git a/src/enc/encoder.rs b/src/enc/encoder.rs index c01ba62..5ffdada 100644 --- a/src/enc/encoder.rs +++ b/src/enc/encoder.rs @@ -168,6 +168,11 @@ impl<'a, W: Writer, C: Config> Encode for &'a mut Encoder { } fn encode_slice(&mut self, val: &[u8]) -> Result<(), EncodeError> { + self.encode_usize(val.len())?; self.writer.write(val) } + + fn encode_array(&mut self, val: [u8; N]) -> Result<(), EncodeError> { + self.writer.write(&val) + } } diff --git a/src/enc/impls.rs b/src/enc/impls.rs index 35c9a69..5d199e0 100644 --- a/src/enc/impls.rs +++ b/src/enc/impls.rs @@ -91,6 +91,18 @@ impl Encodeable for &'_ [u8] { } } +impl Encodeable for &'_ str { + fn encode(&self, mut encoder: E) -> Result<(), EncodeError> { + encoder.encode_slice(self.as_bytes()) + } +} + +impl Encodeable for [u8; N] { + fn encode(&self, mut encoder: E) -> Result<(), EncodeError> { + encoder.encode_array(*self) + } +} + impl<'a, T> Encode for &'a mut T where T: Encode, @@ -142,4 +154,7 @@ where fn encode_slice(&mut self, val: &[u8]) -> Result<(), EncodeError> { T::encode_slice(self, val) } + fn encode_array(&mut self, val: [u8; N]) -> Result<(), EncodeError> { + T::encode_array(self, val) + } } diff --git a/src/enc/mod.rs b/src/enc/mod.rs index ca5ebfe..ad997a3 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -29,4 +29,5 @@ pub trait Encode { fn encode_f32(&mut self, val: f32) -> Result<(), EncodeError>; fn encode_f64(&mut self, val: f64) -> Result<(), EncodeError>; fn encode_slice(&mut self, val: &[u8]) -> Result<(), EncodeError>; + fn encode_array(&mut self, val: [u8; N]) -> Result<(), EncodeError>; } diff --git a/src/error.rs b/src/error.rs index 47340c7..04db286 100644 --- a/src/error.rs +++ b/src/error.rs @@ -21,6 +21,8 @@ pub enum DecodeError { max: u32, found: u32, }, + + Utf8(core::str::Utf8Error), } #[non_exhaustive] diff --git a/src/lib.rs b/src/lib.rs index 38c0f95..3857395 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,12 +43,14 @@ pub fn encode_into_slice_with_config( Ok(encoder.into_writer().bytes_written()) } -pub fn decode(src: &mut [u8]) -> Result { +pub fn decode<'__de, D: de::Decodable<'__de>>( + src: &'__de mut [u8], +) -> Result { decode_with_config(src, config::Default) } -pub fn decode_with_config( - src: &mut [u8], +pub fn decode_with_config<'__de, D: de::Decodable<'__de>, C: Config>( + src: &'__de mut [u8], _config: C, ) -> Result { let reader = de::read::SliceReader::new(src); diff --git a/tests/derive.rs b/tests/derive.rs index e699aee..a33c2a2 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -1,4 +1,5 @@ use bincode::{de::Decodable, enc::Encodeable}; +use core::marker::PhantomData; #[derive(bincode::Encodable, PartialEq, Debug)] pub struct Test { @@ -8,10 +9,11 @@ pub struct Test { } #[derive(bincode::Decodable, PartialEq, Debug, Eq)] -pub struct Test2 { +pub struct Test2<'__de, T: Decodable<'__de>> { a: T, b: u32, c: u32, + pd: PhantomData<&'__de ()>, } #[derive(bincode::Encodable, bincode::Decodable, PartialEq, Debug, Eq)] @@ -43,6 +45,7 @@ fn test_decodable() { a: 5u32, b: 10u32, c: 1024u32, + pd: PhantomData, }; let mut slice = [5, 10, 251, 0, 4]; let result: Test2 = bincode::decode(&mut slice).unwrap(); diff --git a/tests/test.rs b/tests/test.rs index 2664ec1..2850355 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -5,7 +5,12 @@ use core::fmt::Debug; fn the_same_with_config(element: V, config: C) where - V: bincode::enc::Encodeable + bincode::de::Decodable + PartialEq + Debug + Clone + 'static, + V: bincode::enc::Encodeable + + for<'de> bincode::de::Decodable<'de> + + PartialEq + + Debug + + Clone + + 'static, C: Config, { let mut buffer = [0u8; 32]; @@ -16,7 +21,12 @@ where } fn the_same(element: V) where - V: bincode::enc::Encodeable + bincode::de::Decodable + PartialEq + Debug + Clone + 'static, + V: bincode::enc::Encodeable + + for<'de> bincode::de::Decodable<'de> + + PartialEq + + Debug + + Clone + + 'static, { the_same_with_config( element.clone(), @@ -61,3 +71,39 @@ fn test_numbers() { the_same(5.0f32); the_same(5.0f64); } + +#[test] +fn test_slice() { + let mut buffer = [0u8; 32]; + let input: &[u8] = &[1, 2, 3, 4, 5, 6, 7]; + bincode::encode_into_slice(input, &mut buffer).unwrap(); + assert_eq!(&buffer[..8], &[7, 1, 2, 3, 4, 5, 6, 7]); + + let output: &[u8] = bincode::decode(&mut buffer[..8]).unwrap(); + assert_eq!(input, output); +} + +#[test] +fn test_str() { + let mut buffer = [0u8; 32]; + let input: &str = "Hello world"; + bincode::encode_into_slice(input, &mut buffer).unwrap(); + assert_eq!( + &buffer[..12], + &[11, 72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100] + ); + + let output: &str = bincode::decode(&mut buffer[..12]).unwrap(); + assert_eq!(input, output); +} + +#[test] +fn test_array() { + let mut buffer = [0u8; 32]; + let input: [u8; 10] = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]; + bincode::encode_into_slice(input, &mut buffer).unwrap(); + assert_eq!(&buffer[..10], &[10, 20, 30, 40, 50, 60, 70, 80, 90, 100]); + + let output: [u8; 10] = bincode::decode(&mut buffer[..10]).unwrap(); + assert_eq!(input, output); +}