From 63f46079925b24fdc922f3db76688b7417407241 Mon Sep 17 00:00:00 2001 From: Trangar Date: Sat, 11 Dec 2021 15:44:43 +0100 Subject: [PATCH] Feature/config limit (#439) * Added Limit and NoLimit to the configuration * Added a limit check to Decoder and DecoderImpl * Added test cases, added a helper function specialized for containers * Added a test to see if inlining makes the limit config faster, added inlining to the decoder --- Cargo.toml | 4 ++ benches/inline.rs | 18 +++++++++ src/config.rs | 66 ++++++++++++++++++++++++++++----- src/de/decoder.rs | 37 ++++++++++++++++++- src/de/impls.rs | 31 ++++++++++++++-- src/de/mod.rs | 75 +++++++++++++++++++++++++++++++++++++- src/error.rs | 3 ++ src/features/impl_alloc.rs | 25 +++++++++++++ src/features/impl_std.rs | 5 +++ tests/alloc.rs | 38 +++++++++++++++++++ 10 files changed, 287 insertions(+), 15 deletions(-) create mode 100644 benches/inline.rs diff --git a/Cargo.toml b/Cargo.toml index 2f5ae5d..6a98f8d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,10 @@ rand = "0.8" name = "varint" harness = false +[[bench]] +name = "inline" +harness = false + [profile.bench] codegen-units = 1 debug = 1 diff --git a/benches/inline.rs b/benches/inline.rs new file mode 100644 index 0000000..dbb5fec --- /dev/null +++ b/benches/inline.rs @@ -0,0 +1,18 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +use bincode::config::Configuration; + +fn inline_decoder_claim_bytes_read(c: &mut Criterion) { + let config = Configuration::standard().with_limit::<100000>(); + let slice = bincode::encode_to_vec(vec![String::from("Hello world"); 1000], config).unwrap(); + + c.bench_function("inline_decoder_claim_bytes_read", |b| { + b.iter(|| { + let _: Vec = + black_box(bincode::decode_from_slice(black_box(&slice), config).unwrap()); + }) + }); +} + +criterion_group!(benches, inline_decoder_claim_bytes_read); +criterion_main!(benches); diff --git a/src/config.rs b/src/config.rs index 7ff60a9..aec6a07 100644 --- a/src/config.rs +++ b/src/config.rs @@ -40,10 +40,11 @@ use core::marker::PhantomData; /// [skip_fixed_array_length]: #method.skip_fixed_array_length /// [write_fixed_array_length]: #method.write_fixed_array_length #[derive(Copy, Clone)] -pub struct Configuration { +pub struct Configuration { _e: PhantomData, _i: PhantomData, _a: PhantomData, + _l: PhantomData, } impl Configuration { @@ -59,17 +60,18 @@ impl Configuration { /// - Little endian /// - Fixed int length encoding /// - Write array lengths - pub fn legacy() -> Configuration { + pub fn legacy() -> Configuration { Self::generate() } } -impl Configuration { - fn generate<_E, _I, _A>() -> Configuration<_E, _I, _A> { +impl Configuration { + fn generate<_E, _I, _A, _L>() -> Configuration<_E, _I, _A, _L> { Configuration { _e: PhantomData, _i: PhantomData, _a: PhantomData, + _l: PhantomData, } } @@ -162,16 +164,36 @@ impl Configuration { pub fn write_fixed_array_length(self) -> Configuration { Self::generate() } + + /// Sets the byte limit to `limit`. + pub fn with_limit(self) -> Configuration> { + Self::generate() + } + + /// Clear the byte limit. + pub fn with_no_limit(self) -> Configuration { + Self::generate() + } } /// Indicates a type is valid for controlling the bincode configuration pub trait Config: - InternalEndianConfig + InternalArrayLengthConfig + InternalIntEncodingConfig + Copy + Clone + InternalEndianConfig + + InternalArrayLengthConfig + + InternalIntEncodingConfig + + InternalLimitConfig + + Copy + + Clone { } impl Config for T where - T: InternalEndianConfig + InternalArrayLengthConfig + InternalIntEncodingConfig + Copy + Clone + T: InternalEndianConfig + + InternalArrayLengthConfig + + InternalIntEncodingConfig + + InternalLimitConfig + + Copy + + Clone { } @@ -223,6 +245,20 @@ impl InternalArrayLengthConfig for WriteFixedArrayLength { const SKIP_FIXED_ARRAY_LENGTH: bool = false; } +#[doc(hidden)] +#[derive(Copy, Clone)] +pub struct NoLimit {} +impl InternalLimitConfig for NoLimit { + const LIMIT: Option = None; +} + +#[doc(hidden)] +#[derive(Copy, Clone)] +pub struct Limit {} +impl InternalLimitConfig for Limit { + const LIMIT: Option = Some(N); +} + mod internal { use super::Configuration; @@ -230,7 +266,7 @@ mod internal { const ENDIAN: Endian; } - impl InternalEndianConfig for Configuration { + impl InternalEndianConfig for Configuration { const ENDIAN: Endian = E::ENDIAN; } @@ -244,7 +280,9 @@ mod internal { const INT_ENCODING: IntEncoding; } - impl InternalIntEncodingConfig for Configuration { + impl InternalIntEncodingConfig + for Configuration + { const INT_ENCODING: IntEncoding = I::INT_ENCODING; } @@ -258,7 +296,17 @@ mod internal { const SKIP_FIXED_ARRAY_LENGTH: bool; } - impl InternalArrayLengthConfig for Configuration { + impl InternalArrayLengthConfig + for Configuration + { const SKIP_FIXED_ARRAY_LENGTH: bool = A::SKIP_FIXED_ARRAY_LENGTH; } + + pub trait InternalLimitConfig { + const LIMIT: Option; + } + + impl InternalLimitConfig for Configuration { + const LIMIT: Option = L::LIMIT; + } } diff --git a/src/de/decoder.rs b/src/de/decoder.rs index df2cab4..d68ad7c 100644 --- a/src/de/decoder.rs +++ b/src/de/decoder.rs @@ -2,7 +2,7 @@ use super::{ read::{BorrowReader, Reader}, BorrowDecoder, Decoder, }; -use crate::{config::Config, utils::Sealed}; +use crate::{config::Config, error::DecodeError, utils::Sealed}; /// A Decoder that reads bytes from a given reader `R`. /// @@ -24,12 +24,17 @@ use crate::{config::Config, utils::Sealed}; pub struct DecoderImpl { reader: R, config: C, + bytes_read: usize, } impl DecoderImpl { /// Construct a new Decoder pub fn new(reader: R, config: C) -> DecoderImpl { - DecoderImpl { reader, config } + DecoderImpl { + reader, + config, + bytes_read: 0, + } } } @@ -55,4 +60,32 @@ impl Decoder for DecoderImpl { fn config(&self) -> &Self::C { &self.config } + + #[inline] + fn claim_bytes_read(&mut self, n: usize) -> Result<(), DecodeError> { + // C::LIMIT is a const so this check should get compiled away + if let Some(limit) = C::LIMIT { + // Make sure we don't accidentally overflow `bytes_read` + self.bytes_read = self + .bytes_read + .checked_add(n) + .ok_or(DecodeError::LimitExceeded)?; + if self.bytes_read > limit { + Err(DecodeError::LimitExceeded) + } else { + Ok(()) + } + } else { + Ok(()) + } + } + + #[inline] + fn unclaim_bytes_read(&mut self, n: usize) { + // C::LIMIT is a const so this check should get compiled away + if C::LIMIT.is_some() { + // We should always be claiming more than we unclaim, so this should never underflow + self.bytes_read -= n; + } + } } diff --git a/src/de/impls.rs b/src/de/impls.rs index e0f2605..4bd5b64 100644 --- a/src/de/impls.rs +++ b/src/de/impls.rs @@ -33,6 +33,7 @@ impl Decode for bool { impl Decode for u8 { #[inline] fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(1)?; if let Some(buf) = decoder.reader().peek_read(1) { let byte = buf[0]; decoder.reader().consume(1); @@ -55,6 +56,7 @@ impl Decode for NonZeroU8 { impl Decode for u16 { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(2)?; match D::C::INT_ENCODING { IntEncoding::Variable => { crate::varint::varint_decode_u16(decoder.reader(), D::C::ENDIAN) @@ -81,6 +83,7 @@ impl Decode for NonZeroU16 { impl Decode for u32 { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(4)?; match D::C::INT_ENCODING { IntEncoding::Variable => { crate::varint::varint_decode_u32(decoder.reader(), D::C::ENDIAN) @@ -107,6 +110,7 @@ impl Decode for NonZeroU32 { impl Decode for u64 { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(8)?; match D::C::INT_ENCODING { IntEncoding::Variable => { crate::varint::varint_decode_u64(decoder.reader(), D::C::ENDIAN) @@ -133,6 +137,7 @@ impl Decode for NonZeroU64 { impl Decode for u128 { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(16)?; match D::C::INT_ENCODING { IntEncoding::Variable => { crate::varint::varint_decode_u128(decoder.reader(), D::C::ENDIAN) @@ -159,6 +164,7 @@ impl Decode for NonZeroU128 { impl Decode for usize { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(8)?; match D::C::INT_ENCODING { IntEncoding::Variable => { crate::varint::varint_decode_usize(decoder.reader(), D::C::ENDIAN) @@ -185,6 +191,7 @@ impl Decode for NonZeroUsize { impl Decode for i8 { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(1)?; let mut bytes = [0u8; 1]; decoder.reader().read(&mut bytes)?; Ok(bytes[0] as i8) @@ -201,6 +208,7 @@ impl Decode for NonZeroI8 { impl Decode for i16 { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(2)?; match D::C::INT_ENCODING { IntEncoding::Variable => { crate::varint::varint_decode_i16(decoder.reader(), D::C::ENDIAN) @@ -227,6 +235,7 @@ impl Decode for NonZeroI16 { impl Decode for i32 { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(4)?; match D::C::INT_ENCODING { IntEncoding::Variable => { crate::varint::varint_decode_i32(decoder.reader(), D::C::ENDIAN) @@ -253,6 +262,7 @@ impl Decode for NonZeroI32 { impl Decode for i64 { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(8)?; match D::C::INT_ENCODING { IntEncoding::Variable => { crate::varint::varint_decode_i64(decoder.reader(), D::C::ENDIAN) @@ -279,6 +289,7 @@ impl Decode for NonZeroI64 { impl Decode for i128 { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(16)?; match D::C::INT_ENCODING { IntEncoding::Variable => { crate::varint::varint_decode_i128(decoder.reader(), D::C::ENDIAN) @@ -305,6 +316,7 @@ impl Decode for NonZeroI128 { impl Decode for isize { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(8)?; match D::C::INT_ENCODING { IntEncoding::Variable => { crate::varint::varint_decode_isize(decoder.reader(), D::C::ENDIAN) @@ -331,6 +343,7 @@ impl Decode for NonZeroIsize { impl Decode for f32 { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(4)?; let mut bytes = [0u8; 4]; decoder.reader().read(&mut bytes)?; Ok(match D::C::ENDIAN { @@ -342,6 +355,7 @@ impl Decode for f32 { impl Decode for f64 { fn decode(mut decoder: D) -> Result { + decoder.claim_bytes_read(8)?; let mut bytes = [0u8; 8]; decoder.reader().read(&mut bytes)?; Ok(match D::C::ENDIAN { @@ -362,6 +376,10 @@ impl Decode for char { if width == 0 { return Err(DecodeError::InvalidCharEncoding(array)); } + // Normally we have to `.claim_bytes_read` before reading, however in this + // case the amount of bytes read from `char` can vary wildly, and it should + // only read up to 4 bytes too much. + decoder.claim_bytes_read(width)?; if width == 1 { return Ok(array[0] as char); } @@ -379,6 +397,7 @@ impl Decode for char { impl<'a, 'de: 'a> BorrowDecode<'de> for &'a [u8] { fn borrow_decode>(mut decoder: D) -> Result { let len = super::decode_slice_len(&mut decoder)?; + decoder.claim_bytes_read(len)?; decoder.borrow_reader().take_bytes(len) } } @@ -397,7 +416,7 @@ impl<'a, 'de: 'a> BorrowDecode<'de> for Option<&'a [u8]> { impl<'a, 'de: 'a> BorrowDecode<'de> for &'a str { fn borrow_decode>(decoder: D) -> Result { - let slice: &[u8] = BorrowDecode::borrow_decode(decoder)?; + let slice = <&[u8]>::borrow_decode(decoder)?; core::str::from_utf8(slice).map_err(DecodeError::Utf8) } } @@ -429,6 +448,9 @@ where } } + decoder.claim_bytes_read(core::mem::size_of::<[T; N]>())?; + + // Optimize for `[u8; N]` if TypeId::of::() == TypeId::of::() { let mut buf = [0u8; N]; decoder.reader().read(&mut buf)?; @@ -439,8 +461,11 @@ where let res = unsafe { ptr.read() }; Ok(res) } else { - let result = - super::impl_core::collect_into_array(&mut (0..N).map(|_| T::decode(&mut decoder))); + let result = super::impl_core::collect_into_array(&mut (0..N).map(|_| { + // See the documentation on `unclaim_bytes_read` as to why we're doing this here + decoder.unclaim_bytes_read(core::mem::size_of::()); + T::decode(&mut decoder) + })); // result is only None if N does not match the values of `(0..N)`, which it always should // So this unwrap should never occur diff --git a/src/de/mod.rs b/src/de/mod.rs index 057699e..ad3f032 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -6,7 +6,11 @@ mod impl_tuples; mod impls; use self::read::{BorrowReader, Reader}; -use crate::{config::Config, error::DecodeError, utils::Sealed}; +use crate::{ + config::{Config, InternalLimitConfig}, + error::DecodeError, + utils::Sealed, +}; pub mod read; @@ -53,6 +57,65 @@ pub trait Decoder: Sealed { /// Returns a mutable reference to the config fn config(&self) -> &Self::C; + + /// Claim that `n` bytes are going to be read from the decoder. + /// This can be used to validate `Configuration::Limit()`. + fn claim_bytes_read(&mut self, n: usize) -> Result<(), DecodeError>; + + /// Claim that we're going to read a container which contains `len` entries of `T`. + /// This will correctly handle overflowing if `len * size_of::() > usize::max_value` + fn claim_container_read(&mut self, len: usize) -> Result<(), DecodeError> { + if ::LIMIT.is_some() { + match len.checked_mul(core::mem::size_of::()) { + Some(val) => self.claim_bytes_read(val), + None => Err(DecodeError::LimitExceeded), + } + } else { + Ok(()) + } + } + + /// Notify the decoder that `n` bytes are being reclaimed. + /// + /// When decoding container types, a typical implementation would claim to read `len * size_of::()` bytes. + /// This is to ensure that bincode won't allocate several GB of memory while constructing the container. + /// + /// Because the implementation claims `len * size_of::()`, but then has to decode each `T`, this would be marked + /// as double. This function allows us to un-claim each `T` that gets decoded. + /// + /// We cannot check if `len * size_of::()` is valid without claiming it, because this would mean that if you have + /// a nested container (e.g. `Vec>`), it does not know how much memory is already claimed, and could easily + /// allocate much more than the user intends. + /// ``` + /// # use bincode::de::{Decode, Decoder}; + /// # use bincode::error::DecodeError; + /// # struct Container(Vec); + /// # impl Container { + /// # fn with_capacity(cap: usize) -> Self { + /// # Self(Vec::with_capacity(cap)) + /// # } + /// # + /// # fn push(&mut self, t: T) { + /// # self.0.push(t); + /// # } + /// # } + /// impl Decode for Container { + /// fn decode(mut decoder: D) -> Result { + /// let len = u64::decode(&mut decoder)? as usize; + /// // Make sure we don't allocate too much memory + /// decoder.claim_bytes_read(len * core::mem::size_of::()); + /// + /// let mut result = Container::with_capacity(len); + /// for _ in 0..len { + /// // un-claim the memory + /// decoder.unclaim_bytes_read(core::mem::size_of::()); + /// result.push(T::decode(&mut decoder)?) + /// } + /// Ok(result) + /// } + /// } + /// ``` + fn unclaim_bytes_read(&mut self, n: usize); } /// Any source that can decode basic types. This type is most notably implemented for [Decoder]. @@ -81,6 +144,16 @@ where fn config(&self) -> &Self::C { T::config(self) } + + #[inline] + fn claim_bytes_read(&mut self, n: usize) -> Result<(), DecodeError> { + T::claim_bytes_read(self, n) + } + + #[inline] + fn unclaim_bytes_read(&mut self, n: usize) { + T::unclaim_bytes_read(self, n) + } } impl<'a, 'de, T> BorrowDecoder<'de> for &'a mut T diff --git a/src/error.rs b/src/error.rs index b9c6a3b..600a4ee 100644 --- a/src/error.rs +++ b/src/error.rs @@ -70,6 +70,9 @@ pub enum DecodeError { /// The reader reached its end but more bytes were expected. UnexpectedEnd, + /// The given configuration limit was exceeded + LimitExceeded, + /// Invalid type was found. The decoder tried to read type `expected`, but found type `found` instead. InvalidIntegerType { /// The type that was being read from the reader diff --git a/src/features/impl_alloc.rs b/src/features/impl_alloc.rs index 366aa73..e8316f1 100644 --- a/src/features/impl_alloc.rs +++ b/src/features/impl_alloc.rs @@ -52,8 +52,13 @@ where { fn decode(mut decoder: D) -> Result { let len = crate::de::decode_slice_len(&mut decoder)?; + decoder.claim_container_read::(len)?; + let mut map = BinaryHeap::with_capacity(len); for _ in 0..len { + // See the documentation on `unclaim_bytes_read` as to why we're doing this here + decoder.unclaim_bytes_read(core::mem::size_of::()); + let key = T::decode(&mut decoder)?; map.push(key); } @@ -81,8 +86,13 @@ where { fn decode(mut decoder: D) -> Result { let len = crate::de::decode_slice_len(&mut decoder)?; + decoder.claim_container_read::<(K, V)>(len)?; + let mut map = BTreeMap::new(); for _ in 0..len { + // See the documentation on `unclaim_bytes_read` as to why we're doing this here + decoder.unclaim_bytes_read(core::mem::size_of::<(K, V)>()); + let key = K::decode(&mut decoder)?; let value = V::decode(&mut decoder)?; map.insert(key, value); @@ -112,8 +122,13 @@ where { fn decode(mut decoder: D) -> Result { let len = crate::de::decode_slice_len(&mut decoder)?; + decoder.claim_container_read::(len)?; + let mut map = BTreeSet::new(); for _ in 0..len { + // See the documentation on `unclaim_bytes_read` as to why we're doing this here + decoder.unclaim_bytes_read(core::mem::size_of::()); + let key = T::decode(&mut decoder)?; map.insert(key); } @@ -140,8 +155,13 @@ where { fn decode(mut decoder: D) -> Result { let len = crate::de::decode_slice_len(&mut decoder)?; + decoder.claim_container_read::(len)?; + let mut map = VecDeque::with_capacity(len); for _ in 0..len { + // See the documentation on `unclaim_bytes_read` as to why we're doing this here + decoder.unclaim_bytes_read(core::mem::size_of::()); + let key = T::decode(&mut decoder)?; map.push_back(key); } @@ -168,8 +188,13 @@ where { fn decode(mut decoder: D) -> Result { let len = crate::de::decode_slice_len(&mut decoder)?; + decoder.claim_container_read::(len)?; + let mut vec = Vec::with_capacity(len); for _ in 0..len { + // See the documentation on `unclaim_bytes_read` as to why we're doing this here + decoder.unclaim_bytes_read(core::mem::size_of::()); + vec.push(T::decode(&mut decoder)?); } Ok(vec) diff --git a/src/features/impl_std.rs b/src/features/impl_std.rs index fe53cda..9f4e103 100644 --- a/src/features/impl_std.rs +++ b/src/features/impl_std.rs @@ -369,8 +369,13 @@ where { fn decode(mut decoder: D) -> Result { let len = crate::de::decode_slice_len(&mut decoder)?; + decoder.claim_container_read::<(K, V)>(len)?; + let mut map = HashMap::with_capacity(len); for _ in 0..len { + // See the documentation on `unclaim_bytes_read` as to why we're doing this here + decoder.unclaim_bytes_read(core::mem::size_of::<(K, V)>()); + let k = K::decode(&mut decoder)?; let v = V::decode(&mut decoder)?; map.insert(k, v); diff --git a/tests/alloc.rs b/tests/alloc.rs index 1a2c457..6f326e0 100644 --- a/tests/alloc.rs +++ b/tests/alloc.rs @@ -89,3 +89,41 @@ fn test_alloc_commons() { set }); } + +#[test] +fn test_container_limits() { + use bincode::{error::DecodeError, Decode}; + + const DECODE_LIMIT: usize = 100_000; + + // for this test we'll create a malformed package of a lot of bytes + let test_cases = &[ + // u64::max_value(), should overflow + bincode::encode_to_vec(u64::max_value(), Configuration::standard()).unwrap(), + // A high value which doesn't overflow, but exceeds the decode limit + bincode::encode_to_vec(DECODE_LIMIT as u64, Configuration::standard()).unwrap(), + ]; + + fn validate_fail(slice: &[u8]) { + let result = bincode::decode_from_slice::( + slice, + Configuration::standard().with_limit::(), + ); + + assert_eq!(result.unwrap_err(), DecodeError::LimitExceeded); + } + + for slice in test_cases { + validate_fail::>(slice); + validate_fail::>(slice); + validate_fail::>(slice); + validate_fail::>(slice); + validate_fail::>(slice); + validate_fail::(slice); + validate_fail::>(slice); + #[cfg(feature = "std")] + { + validate_fail::>(slice); + } + } +}