diff --git a/derive/src/derive_enum.rs b/derive/src/derive_enum.rs index b6c0508..c21052d 100644 --- a/derive/src/derive_enum.rs +++ b/derive/src/derive_enum.rs @@ -162,114 +162,124 @@ impl DeriveEnum { }); } - pub fn generate_decode(self, generator: &mut Generator) -> Result<()> { + pub fn generate_decode(&self, generator: &mut Generator) -> Result<()> { + // Remember to keep this mostly in sync with generate_borrow_decode + let enum_name = generator.target_name().to_string(); - if generator.has_lifetimes() { - // enum has a lifetime, implement BorrowDecode - - generator.impl_for_with_de_lifetime("bincode::de::BorrowDecode<'__de>") - .unwrap() - .generate_fn("borrow_decode") - .with_generic("D", ["bincode::de::BorrowDecoder<'__de>"]) - .with_arg("mut decoder", "D") - .with_return_type("core::result::Result") - .body(|fn_builder| { - fn_builder - .push_parsed("let variant_index = ::decode(&mut decoder)?;").unwrap(); - fn_builder.push_parsed("match variant_index").unwrap(); - fn_builder.group(Delimiter::Brace, |variant_case| { - for (mut variant_index, variant) in self.iter_fields() { - // idx => Ok(..) - if variant_index.len() > 1 { - variant_case.push_parsed("x if x == ").unwrap(); - variant_case.extend(variant_index); - } else { - variant_case.push(variant_index.remove(0)); - } - variant_case.puncts("=>"); - variant_case.ident_str("Ok"); - variant_case.group(Delimiter::Parenthesis, |variant_case_body| { - // Self::Variant { } - // Self::Variant { 0: ..., 1: ... 2: ... }, - // Self::Variant { a: ..., b: ... c: ... }, - variant_case_body.ident_str("Self"); - variant_case_body.puncts("::"); - variant_case_body.ident(variant.name.clone()); - - variant_case_body.group(Delimiter::Brace, |variant_body| { - let is_tuple = matches!(variant.fields, Fields::Tuple(_)); - for (idx, field) in variant.fields.names().into_iter().enumerate() { - if is_tuple { - variant_body.lit_usize(idx); - } else { - variant_body.ident(field.unwrap_ident().clone()); - } - variant_body.punct(':'); - variant_body.push_parsed("bincode::de::BorrowDecode::borrow_decode(&mut decoder)?,").unwrap(); - } - }); - }); - variant_case.punct(','); - } - - // invalid idx - self.invalid_variant_case(&enum_name, variant_case); - }); - }).unwrap(); - } else { - // enum has no lifetimes, implement Decode - generator.impl_for("bincode::de::Decode") + generator + .impl_for("bincode::Decode") .unwrap() - .generate_fn("decode") - .with_generic("D", ["bincode::de::Decoder"]) - .with_arg("mut decoder", "D") - .with_return_type("core::result::Result") - .body(|fn_builder| { - fn_builder - .push_parsed("let variant_index = ::decode(&mut decoder)?;").unwrap(); - fn_builder.push_parsed("match variant_index").unwrap(); - fn_builder.group(Delimiter::Brace, |variant_case| { - for (mut variant_index, variant) in self.iter_fields() { - // idx => Ok(..) - if variant_index.len() > 1 { - variant_case.push_parsed("x if x == ").unwrap(); - variant_case.extend(variant_index); - } else { - variant_case.push(variant_index.remove(0)); - } - variant_case.puncts("=>"); - variant_case.ident_str("Ok"); - variant_case.group(Delimiter::Parenthesis, |variant_case_body| { - // Self::Variant { } - // Self::Variant { 0: ..., 1: ... 2: ... }, - // Self::Variant { a: ..., b: ... c: ... }, - variant_case_body.ident_str("Self"); - variant_case_body.puncts("::"); - variant_case_body.ident(variant.name.clone()); - - variant_case_body.group(Delimiter::Brace, |variant_body| { - let is_tuple = matches!(variant.fields, Fields::Tuple(_)); - for (idx, field) in variant.fields.names().into_iter().enumerate() { - if is_tuple { - variant_body.lit_usize(idx); - } else { - variant_body.ident(field.unwrap_ident().clone()); - } - variant_body.punct(':'); - variant_body.push_parsed("bincode::de::Decode::decode(&mut decoder)?,").unwrap(); - } - }); - }); - variant_case.punct(','); + .generate_fn("decode") + .with_generic("D", ["bincode::de::Decoder"]) + .with_arg("mut decoder", "D") + .with_return_type("core::result::Result") + .body(|fn_builder| { + fn_builder + .push_parsed( + "let variant_index = ::decode(&mut decoder)?;", + ) + .unwrap(); + fn_builder.push_parsed("match variant_index").unwrap(); + fn_builder.group(Delimiter::Brace, |variant_case| { + for (mut variant_index, variant) in self.iter_fields() { + // idx => Ok(..) + if variant_index.len() > 1 { + variant_case.push_parsed("x if x == ").unwrap(); + variant_case.extend(variant_index); + } else { + variant_case.push(variant_index.remove(0)); } + variant_case.puncts("=>"); + variant_case.ident_str("Ok"); + variant_case.group(Delimiter::Parenthesis, |variant_case_body| { + // Self::Variant { } + // Self::Variant { 0: ..., 1: ... 2: ... }, + // Self::Variant { a: ..., b: ... c: ... }, + variant_case_body.ident_str("Self"); + variant_case_body.puncts("::"); + variant_case_body.ident(variant.name.clone()); - // invalid idx - self.invalid_variant_case(&enum_name, variant_case); - }); - }).unwrap(); - } + variant_case_body.group(Delimiter::Brace, |variant_body| { + let is_tuple = matches!(variant.fields, Fields::Tuple(_)); + for (idx, field) in variant.fields.names().into_iter().enumerate() { + if is_tuple { + variant_body.lit_usize(idx); + } else { + variant_body.ident(field.unwrap_ident().clone()); + } + variant_body.punct(':'); + variant_body + .push_parsed("bincode::Decode::decode(&mut decoder)?,") + .unwrap(); + } + }); + }); + variant_case.punct(','); + } + // invalid idx + self.invalid_variant_case(&enum_name, variant_case); + }); + }) + .unwrap(); + Ok(()) + } + + pub fn generate_borrow_decode(self, generator: &mut Generator) -> Result<()> { + // Remember to keep this mostly in sync with generate_decode + + let enum_name = generator.target_name().to_string(); + + generator.impl_for_with_de_lifetime("bincode::de::BorrowDecode<'__de>") + .unwrap() + .generate_fn("borrow_decode") + .with_generic("D", ["bincode::de::BorrowDecoder<'__de>"]) + .with_arg("mut decoder", "D") + .with_return_type("core::result::Result") + .body(|fn_builder| { + fn_builder + .push_parsed("let variant_index = ::decode(&mut decoder)?;").unwrap(); + fn_builder.push_parsed("match variant_index").unwrap(); + fn_builder.group(Delimiter::Brace, |variant_case| { + for (mut variant_index, variant) in self.iter_fields() { + // idx => Ok(..) + if variant_index.len() > 1 { + variant_case.push_parsed("x if x == ").unwrap(); + variant_case.extend(variant_index); + } else { + variant_case.push(variant_index.remove(0)); + } + variant_case.puncts("=>"); + variant_case.ident_str("Ok"); + variant_case.group(Delimiter::Parenthesis, |variant_case_body| { + // Self::Variant { } + // Self::Variant { 0: ..., 1: ... 2: ... }, + // Self::Variant { a: ..., b: ... c: ... }, + variant_case_body.ident_str("Self"); + variant_case_body.puncts("::"); + variant_case_body.ident(variant.name.clone()); + + variant_case_body.group(Delimiter::Brace, |variant_body| { + let is_tuple = matches!(variant.fields, Fields::Tuple(_)); + for (idx, field) in variant.fields.names().into_iter().enumerate() { + if is_tuple { + variant_body.lit_usize(idx); + } else { + variant_body.ident(field.unwrap_ident().clone()); + } + variant_body.punct(':'); + variant_body.push_parsed("bincode::de::BorrowDecode::borrow_decode(&mut decoder)?,").unwrap(); + } + }); + }); + variant_case.punct(','); + } + + // invalid idx + self.invalid_variant_case(&enum_name, variant_case); + }); + }).unwrap(); Ok(()) } } diff --git a/derive/src/derive_struct.rs b/derive/src/derive_struct.rs index d341e63..617106f 100644 --- a/derive/src/derive_struct.rs +++ b/derive/src/derive_struct.rs @@ -36,74 +36,74 @@ impl DeriveStruct { } pub fn generate_decode(self, generator: &mut Generator) -> Result<()> { + // Remember to keep this mostly in sync with generate_borrow_decode let DeriveStruct { fields } = self; - if generator.has_lifetimes() { - // struct has a lifetime, implement BorrowDecode + generator + .impl_for("bincode::Decode") + .unwrap() + .generate_fn("decode") + .with_generic("D", ["bincode::de::Decoder"]) + .with_arg("mut decoder", "D") + .with_return_type("core::result::Result") + .body(|fn_body| { + // Ok(Self { + fn_body.ident_str("Ok"); + fn_body.group(Delimiter::Parenthesis, |ok_group| { + ok_group.ident_str("Self"); + ok_group.group(Delimiter::Brace, |struct_body| { + // Fields + // { + // a: bincode::Decode::decode(&mut decoder)?, + // b: bincode::Decode::decode(&mut decoder)?, + // ... + // } + for field in fields.names() { + struct_body + .push_parsed(format!( + "{}: bincode::Decode::decode(&mut decoder)?,", + field.to_string() + )) + .unwrap(); + } + }); + }); + }) + .unwrap(); - generator - .impl_for_with_de_lifetime("bincode::de::BorrowDecode<'__de>") - .unwrap() - .generate_fn("borrow_decode") - .with_generic("D", ["bincode::de::BorrowDecoder<'__de>"]) - .with_arg("mut decoder", "D") - .with_return_type("core::result::Result") - .body(|fn_body| { - // Ok(Self { - fn_body.ident_str("Ok"); - fn_body.group(Delimiter::Parenthesis, |ok_group| { - ok_group.ident_str("Self"); - ok_group.group(Delimiter::Brace, |struct_body| { - for field in fields.names() { - struct_body - .push_parsed(format!( + Ok(()) + } + + pub fn generate_borrow_decode(self, generator: &mut Generator) -> Result<()> { + // Remember to keep this mostly in sync with generate_decode + let DeriveStruct { fields } = self; + + generator + .impl_for_with_de_lifetime("bincode::de::BorrowDecode<'__de>") + .unwrap() + .generate_fn("borrow_decode") + .with_generic("D", ["bincode::de::BorrowDecoder<'__de>"]) + .with_arg("mut decoder", "D") + .with_return_type("core::result::Result") + .body(|fn_body| { + // Ok(Self { + fn_body.ident_str("Ok"); + fn_body.group(Delimiter::Parenthesis, |ok_group| { + ok_group.ident_str("Self"); + ok_group.group(Delimiter::Brace, |struct_body| { + for field in fields.names() { + struct_body + .push_parsed(format!( "{}: bincode::de::BorrowDecode::borrow_decode(&mut decoder)?,", field.to_string() )) - .unwrap(); - } - }); + .unwrap(); + } }); - }) - .unwrap(); + }); + }) + .unwrap(); - Ok(()) - } else { - // struct has no lifetimes, implement Decode - - generator - .impl_for("bincode::de::Decode") - .unwrap() - .generate_fn("decode") - .with_generic("D", ["bincode::de::Decoder"]) - .with_arg("mut decoder", "D") - .with_return_type("core::result::Result") - .body(|fn_body| { - // Ok(Self { - fn_body.ident_str("Ok"); - fn_body.group(Delimiter::Parenthesis, |ok_group| { - ok_group.ident_str("Self"); - ok_group.group(Delimiter::Brace, |struct_body| { - // Fields - // { - // a: bincode::de::Decode::decode(&mut decoder)?, - // b: bincode::de::Decode::decode(&mut decoder)?, - // ... - // } - for field in fields.names() { - struct_body - .push_parsed(format!( - "{}: bincode::de::Decode::decode(&mut decoder)?,", - field.to_string() - )) - .unwrap(); - } - }); - }); - }) - .unwrap(); - - Ok(()) - } + Ok(()) } } diff --git a/derive/src/generate/generator.rs b/derive/src/generate/generator.rs index 7cc0bc0..957ba9a 100644 --- a/derive/src/generate/generator.rs +++ b/derive/src/generate/generator.rs @@ -43,14 +43,6 @@ impl Generator { ImplFor::new_with_de_lifetime(self, trait_name) } - /// Returns `true` if the struct or enum has lifetimes. - pub fn has_lifetimes(&self) -> bool { - self.generics - .as_ref() - .map(|g| g.has_lifetime()) - .unwrap_or(false) - } - /// Consume the contents of this generator. This *must* be called, or else the generator will panic on drop. pub fn take_stream(mut self) -> TokenStream { std::mem::take(&mut self.stream).stream diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 1560208..b0b1cd7 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -102,6 +102,47 @@ fn derive_decode_inner(input: TokenStream) -> Result { Ok(stream) } +#[proc_macro_derive(BorrowDecode)] +pub fn derive_brrow_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + #[allow(clippy::useless_conversion)] + derive_borrow_decode_inner(input.into()) + .unwrap_or_else(|e| e.into_token_stream()) + .into() +} + +fn derive_borrow_decode_inner(input: TokenStream) -> Result { + let source = &mut input.into_iter().peekable(); + + let _attributes = parse::Attribute::try_take(source)?; + let _visibility = parse::Visibility::try_take(source)?; + let (datatype, name) = parse::DataType::take(source)?; + let generics = parse::Generics::try_take(source)?; + let generic_constraints = parse::GenericConstraints::try_take(source)?; + + let mut generator = generate::Generator::new(name.clone(), generics, generic_constraints); + + match datatype { + parse::DataType::Struct => { + let body = parse::StructBody::take(source)?; + derive_struct::DeriveStruct { + fields: body.fields, + } + .generate_borrow_decode(&mut generator)?; + } + parse::DataType::Enum => { + let body = parse::EnumBody::take(source)?; + derive_enum::DeriveEnum { + variants: body.variants, + } + .generate_borrow_decode(&mut generator)?; + } + } + + let stream = generator.take_stream(); + dump_output(name, "BorrowDecode", &stream); + Ok(stream) +} + fn dump_output(name: crate::prelude::Ident, derive: &str, stream: &crate::prelude::TokenStream) { use std::io::Write; diff --git a/src/features/derive.rs b/src/features/derive.rs index 57b3a58..1d07ba1 100644 --- a/src/features/derive.rs +++ b/src/features/derive.rs @@ -1,2 +1,2 @@ #[cfg_attr(docsrs, doc(cfg(feature = "derive")))] -pub use bincode_derive::{Decode, Encode}; +pub use bincode_derive::{BorrowDecode, Decode, Encode}; diff --git a/src/features/impl_alloc.rs b/src/features/impl_alloc.rs index f33a786..b916310 100644 --- a/src/features/impl_alloc.rs +++ b/src/features/impl_alloc.rs @@ -6,7 +6,14 @@ use crate::{ }; #[cfg(feature = "atomic")] use alloc::sync::Arc; -use alloc::{borrow::Cow, boxed::Box, collections::*, rc::Rc, string::String, vec::Vec}; +use alloc::{ + borrow::{Cow, ToOwned}, + boxed::Box, + collections::*, + rc::Rc, + string::String, + vec::Vec, +}; #[derive(Default)] struct VecWriter { @@ -229,12 +236,27 @@ where } } +// BlockedTODO: https://github.com/rust-lang/rust/issues/31844 +// Cow should be able to decode a borrowed value +// Currently this conflicts with the owned `Decode` implementation below + +// impl<'cow, T> BorrowDecode<'cow> for Cow<'cow, T> +// where +// T: BorrowDecode<'cow>, +// { +// fn borrow_decode>(decoder: D) -> Result { +// let t = T::borrow_decode(decoder)?; +// Ok(Cow::Borrowed(t)) +// } +// } + impl<'cow, T> Decode for Cow<'cow, T> where - T: Decode + Clone, + T: ToOwned, + ::Owned: Decode, { fn decode(decoder: D) -> Result { - let t = T::decode(decoder)?; + let t = ::Owned::decode(decoder)?; Ok(Cow::Owned(t)) } } diff --git a/src/lib.rs b/src/lib.rs index 512f8a1..476df17 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,7 +18,7 @@ //! |std | Yes ||`decode_from_reader` and `encode_into_writer`| //! |alloc | Yes |All common containers in alloc, like `Vec`, `String`, `Box`|`encode_to_vec`| //! |atomic| Yes |All `Atomic*` integer types, e.g. `AtomicUsize`, and `AtomicBool`|| -//! |derive| Yes |||Enables the `Encode` and `Decode` derive macro| +//! |derive| Yes |||Enables the `BorrowDecode`, `Decode` and `Encode` derive macros| //! |serde | No |TODO|TODO|TODO| //! //! # Example @@ -77,6 +77,9 @@ pub mod de; pub mod enc; pub mod error; +pub use de::{BorrowDecode, Decode}; +pub use enc::Encode; + use config::Config; /// Encode the given value into the given slice. Returns the amount of bytes that have been written. diff --git a/tests/alloc.rs b/tests/alloc.rs index baad469..1a2c457 100644 --- a/tests/alloc.rs +++ b/tests/alloc.rs @@ -28,13 +28,13 @@ impl bincode::enc::Encode for Foo { } } -impl bincode::de::Decode for Foo { +impl bincode::Decode for Foo { fn decode( mut decoder: D, ) -> Result { Ok(Self { - a: bincode::de::Decode::decode(&mut decoder)?, - b: bincode::de::Decode::decode(&mut decoder)?, + a: bincode::Decode::decode(&mut decoder)?, + b: bincode::Decode::decode(&mut decoder)?, }) } } diff --git a/tests/derive.rs b/tests/derive.rs index 6501235..e06c80e 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -17,7 +17,7 @@ pub struct Test2 { c: u32, } -#[derive(bincode::Decode, PartialEq, Debug, Eq)] +#[derive(bincode::BorrowDecode, PartialEq, Debug, Eq)] pub struct Test3<'a> { a: &'a str, b: u32, @@ -34,7 +34,7 @@ pub enum TestEnum { Baz(u32, u32, u32), } -#[derive(bincode::Encode, bincode::Decode, PartialEq, Debug, Eq)] +#[derive(bincode::Encode, bincode::BorrowDecode, PartialEq, Debug, Eq)] pub enum TestEnum2<'a> { Foo, Bar { name: &'a str }, diff --git a/tests/issues.rs b/tests/issues.rs index b7fa4b3..51a8ebd 100644 --- a/tests/issues.rs +++ b/tests/issues.rs @@ -1,4 +1,7 @@ #![no_std] +#[path = "issues/issue_431.rs"] +mod issue_431; + #[path = "issues/issue_427.rs"] mod issue_427; diff --git a/tests/issues/issue_431.rs b/tests/issues/issue_431.rs new file mode 100644 index 0000000..cd4537f --- /dev/null +++ b/tests/issues/issue_431.rs @@ -0,0 +1,32 @@ +#![cfg(all(feature = "std", feature = "derive"))] + +extern crate std; + +use bincode::{config::Configuration, Decode, Encode}; +use std::borrow::Cow; +use std::string::String; + +#[derive(Encode, Decode, PartialEq, Debug)] +struct T<'a, A: Clone + Encode + Decode> { + t: Cow<'a, U<'a, A>>, +} + +#[derive(Clone, Encode, Decode, PartialEq, Debug)] +struct U<'a, A: Clone + Encode + Decode> { + u: Cow<'a, A>, +} + +#[test] +fn test() { + let u = U { + u: Cow::Owned(String::from("Hello world")), + }; + let t = T { + t: Cow::Borrowed(&u), + }; + let vec = bincode::encode_to_vec(&t, Configuration::standard()).unwrap(); + + let decoded: T = bincode::decode_from_slice(&vec, Configuration::standard()).unwrap(); + + assert_eq!(t, decoded); +} diff --git a/tests/std.rs b/tests/std.rs index 66b4e7b..7a669a5 100644 --- a/tests/std.rs +++ b/tests/std.rs @@ -30,13 +30,13 @@ impl bincode::enc::Encode for Foo { } } -impl bincode::de::Decode for Foo { +impl bincode::Decode for Foo { fn decode( mut decoder: D, ) -> Result { Ok(Self { - a: bincode::de::Decode::decode(&mut decoder)?, - b: bincode::de::Decode::decode(&mut decoder)?, + a: bincode::Decode::decode(&mut decoder)?, + b: bincode::Decode::decode(&mut decoder)?, }) } } diff --git a/tests/utils.rs b/tests/utils.rs index 015ab1b..7a64ad7 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -3,7 +3,7 @@ use core::fmt::Debug; fn the_same_with_config(element: &V, config: C, cmp: CMP) where - V: bincode::enc::Encode + bincode::de::Decode + Debug + 'static, + V: bincode::enc::Encode + bincode::Decode + Debug + 'static, C: Config, CMP: Fn(&V, &V) -> bool, { @@ -28,7 +28,7 @@ where pub fn the_same_with_comparer(element: V, cmp: CMP) where - V: bincode::enc::Encode + bincode::de::Decode + Debug + 'static, + V: bincode::enc::Encode + bincode::Decode + Debug + 'static, CMP: Fn(&V, &V) -> bool, { // A matrix of each different config option possible @@ -101,7 +101,7 @@ where #[allow(dead_code)] // This is not used in every test pub fn the_same(element: V) where - V: bincode::enc::Encode + bincode::de::Decode + PartialEq + Debug + 'static, + V: bincode::enc::Encode + bincode::Decode + PartialEq + Debug + 'static, { the_same_with_comparer(element, |a, b| a == b); }