diff --git a/src/config/legacy.rs b/src/config/legacy.rs index 2d4e51a..a9ac065 100644 --- a/src/config/legacy.rs +++ b/src/config/legacy.rs @@ -49,6 +49,7 @@ macro_rules! config_map { (Unlimited, Little) => { let $opts = DefaultOptions::new() .with_fixint_encoding() + .allow_trailing_bytes() .with_no_limit() .with_little_endian(); $call @@ -56,6 +57,7 @@ macro_rules! config_map { (Unlimited, Big) => { let $opts = DefaultOptions::new() .with_fixint_encoding() + .allow_trailing_bytes() .with_no_limit() .with_big_endian(); $call @@ -63,6 +65,7 @@ macro_rules! config_map { (Unlimited, Native) => { let $opts = DefaultOptions::new() .with_fixint_encoding() + .allow_trailing_bytes() .with_no_limit() .with_native_endian(); $call @@ -71,6 +74,7 @@ macro_rules! config_map { (Limited(l), Little) => { let $opts = DefaultOptions::new() .with_fixint_encoding() + .allow_trailing_bytes() .with_limit(l) .with_little_endian(); $call @@ -78,6 +82,7 @@ macro_rules! config_map { (Limited(l), Big) => { let $opts = DefaultOptions::new() .with_fixint_encoding() + .allow_trailing_bytes() .with_limit(l) .with_big_endian(); $call @@ -85,6 +90,7 @@ macro_rules! config_map { (Limited(l), Native) => { let $opts = DefaultOptions::new() .with_fixint_encoding() + .allow_trailing_bytes() .with_limit(l) .with_native_endian(); $call diff --git a/src/config/mod.rs b/src/config/mod.rs index 3520ccf..9fdc154 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -8,16 +8,19 @@ pub(crate) use self::endian::BincodeByteOrder; pub(crate) use self::int::IntEncoding; pub(crate) use self::internal::*; pub(crate) use self::limit::SizeLimit; +pub(crate) use self::trailing::TrailingBytes; pub use self::endian::{BigEndian, LittleEndian, NativeEndian}; pub use self::int::{FixintEncoding, VarintEncoding}; pub use self::legacy::*; pub use self::limit::{Bounded, Infinite}; +pub use self::trailing::{AllowTrailing, RejectTrailing}; mod endian; mod int; mod legacy; mod limit; +mod trailing; /// The default options for bincode serialization/deserialization. /// @@ -50,6 +53,7 @@ impl InternalOptions for DefaultOptions { type Limit = Infinite; type Endian = LittleEndian; type IntEncoding = VarintEncoding; + type Trailing = RejectTrailing; #[inline(always)] fn limit(&mut self) -> &mut Infinite { @@ -111,6 +115,16 @@ pub trait Options: InternalOptions + Sized { WithOtherIntEncoding::new(self) } + /// Sets the deserializer to reject trailing bytes + fn reject_trailing_bytes(self) -> WithOtherTrailing { + WithOtherTrailing::new(self) + } + + /// Sets the deserializer to allow trailing bytes + fn allow_trailing_bytes(self) -> WithOtherTrailing { + WithOtherTrailing::new(self) + } + /// Serializes a serializable object into a `Vec` of bytes using this configuration #[inline(always)] fn serialize(self, t: &S) -> Result> { @@ -229,6 +243,12 @@ pub struct WithOtherIntEncoding { _length: PhantomData, } +/// A configuration struct with a user-specified trailing bytes behavior. +pub struct WithOtherTrailing { + options: O, + _trailing: PhantomData, +} + impl WithOtherLimit { #[inline(always)] pub(crate) fn new(options: O, limit: L) -> WithOtherLimit { @@ -259,11 +279,21 @@ impl WithOtherIntEncoding { } } +impl WithOtherTrailing { + #[inline(always)] + pub(crate) fn new(options: O) -> WithOtherTrailing { + WithOtherTrailing { + options, + _trailing: PhantomData, + } + } +} + impl InternalOptions for WithOtherEndian { type Limit = O::Limit; type Endian = E; type IntEncoding = O::IntEncoding; - + type Trailing = O::Trailing; #[inline(always)] fn limit(&mut self) -> &mut O::Limit { self.options.limit() @@ -274,7 +304,7 @@ impl InternalOptions for WithOtherLimit &mut L { &mut self.new_limit } @@ -284,6 +314,18 @@ impl InternalOptions for WithOtherIntEncod type Limit = O::Limit; type Endian = O::Endian; type IntEncoding = I; + type Trailing = O::Trailing; + + fn limit(&mut self) -> &mut O::Limit { + self.options.limit() + } +} + +impl InternalOptions for WithOtherTrailing { + type Limit = O::Limit; + type Endian = O::Endian; + type IntEncoding = O::IntEncoding; + type Trailing = T; fn limit(&mut self) -> &mut O::Limit { self.options.limit() @@ -297,6 +339,7 @@ mod internal { type Limit: SizeLimit + 'static; type Endian: BincodeByteOrder + 'static; type IntEncoding: IntEncoding + 'static; + type Trailing: TrailingBytes + 'static; fn limit(&mut self) -> &mut Self::Limit; } @@ -305,6 +348,7 @@ mod internal { type Limit = O::Limit; type Endian = O::Endian; type IntEncoding = O::IntEncoding; + type Trailing = O::Trailing; #[inline(always)] fn limit(&mut self) -> &mut Self::Limit { diff --git a/src/config/trailing.rs b/src/config/trailing.rs new file mode 100644 index 0000000..413bd08 --- /dev/null +++ b/src/config/trailing.rs @@ -0,0 +1,37 @@ +use de::read::SliceReader; +use {ErrorKind, Result}; + +/// A trait for erroring deserialization if not all bytes were read. +pub trait TrailingBytes { + /// Checks a given slice reader to determine if deserialization used all bytes in the slice. + fn check_end(reader: &SliceReader) -> Result<()>; +} + +/// A TrailingBytes config that will allow trailing bytes in slices after deserialization. +#[derive(Copy, Clone)] +pub struct AllowTrailing; + +/// A TrailingBytes config that will cause bincode to produce an error if bytes are left over in the slice when deserialization is complete. + +#[derive(Copy, Clone)] +pub struct RejectTrailing; + +impl TrailingBytes for AllowTrailing { + #[inline(always)] + fn check_end(_reader: &SliceReader) -> Result<()> { + Ok(()) + } +} + +impl TrailingBytes for RejectTrailing { + #[inline(always)] + fn check_end(reader: &SliceReader) -> Result<()> { + if reader.is_finished() { + Ok(()) + } else { + Err(Box::new(ErrorKind::Custom( + "Slice had bytes remaining after deserialization".to_string(), + ))) + } + } +} diff --git a/src/de/mod.rs b/src/de/mod.rs index 080ea24..e1179ea 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -26,7 +26,7 @@ pub mod read; /// let bytes_read = d.bytes_read(); /// ``` pub struct Deserializer { - reader: R, + pub(crate) reader: R, options: O, } diff --git a/src/de/read.rs b/src/de/read.rs index fdb6a1d..4179ad8 100644 --- a/src/de/read.rs +++ b/src/de/read.rs @@ -52,6 +52,10 @@ impl<'storage> SliceReader<'storage> { self.slice = remaining; Ok(read_slice) } + + pub(crate) fn is_finished(&self) -> bool { + self.slice.is_empty() + } } impl IoReader { diff --git a/src/internal.rs b/src/internal.rs index be13d31..ac7ee55 100644 --- a/src/internal.rs +++ b/src/internal.rs @@ -2,7 +2,7 @@ use serde; use std::io::{Read, Write}; use std::marker::PhantomData; -use config::{Infinite, InternalOptions, Options, SizeLimit}; +use config::{Infinite, InternalOptions, Options, SizeLimit, TrailingBytes}; use de::read::BincodeRead; use Result; @@ -111,7 +111,14 @@ where T: serde::de::DeserializeSeed<'a>, O: InternalOptions, { - let reader = ::de::read::SliceReader::new(bytes); let options = ::config::WithOtherLimit::new(options, Infinite); - deserialize_from_custom_seed(seed, reader, options) + + let reader = ::de::read::SliceReader::new(bytes); + let mut deserializer = ::de::Deserializer::with_bincode_read(reader, options); + let val = seed.deserialize(&mut deserializer)?; + + match O::Trailing::check_end(&deserializer.reader) { + Ok(_) => Ok(val), + Err(err) => Err(err), + } } diff --git a/src/lib.rs b/src/lib.rs index 6241811..f8c1968 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,6 +94,7 @@ where { DefaultOptions::new() .with_fixint_encoding() + .allow_trailing_bytes() .serialize(value) } @@ -107,6 +108,7 @@ where { DefaultOptions::new() .with_fixint_encoding() + .allow_trailing_bytes() .deserialize_from(reader) } @@ -122,6 +124,7 @@ where { DefaultOptions::new() .with_fixint_encoding() + .allow_trailing_bytes() .deserialize_from_custom(reader) } @@ -136,6 +139,7 @@ where { DefaultOptions::new() .with_fixint_encoding() + .allow_trailing_bytes() .deserialize_in_place(reader, place) } @@ -146,6 +150,7 @@ where { DefaultOptions::new() .with_fixint_encoding() + .allow_trailing_bytes() .deserialize(bytes) } @@ -156,5 +161,6 @@ where { DefaultOptions::new() .with_fixint_encoding() + .allow_trailing_bytes() .serialized_size(value) } diff --git a/tests/test.rs b/tests/test.rs index 4debc73..fd6c829 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -277,6 +277,17 @@ fn deserializing_errors() { } } +#[test] +fn trailing_bytes() { + match DefaultOptions::new() + .deserialize::(b"1x") + .map_err(|e| *e) + { + Err(ErrorKind::Custom(_)) => {} + other => panic!("Expecting TrailingBytes, got {:?}", other), + } +} + #[test] fn too_big_deserialize() { let serialized = vec![0, 0, 0, 3];