diff --git a/examples/basic.rs b/examples/basic.rs index abd8575..28c8b61 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -2,7 +2,7 @@ extern crate serde_derive; extern crate bincode; -use bincode::{serialize, deserialize, SizeLimit}; +use bincode::{serialize, deserialize, Infinite}; #[derive(Serialize, Deserialize, PartialEq)] struct Entity { @@ -16,7 +16,7 @@ struct World(Vec); fn main() { let world = World(vec![Entity { x: 0.0, y: 4.0 }, Entity { x: 10.0, y: 20.5 }]); - let encoded: Vec = serialize(&world, SizeLimit::Infinite).unwrap(); + let encoded: Vec = serialize(&world, Infinite).unwrap(); // 8 bytes for the length of the vector, 4 bytes per float. assert_eq!(encoded.len(), 8 + 4 * 4); diff --git a/src/lib.rs b/src/lib.rs index 296ab4c..86fb997 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,12 +17,12 @@ //! //! ```rust //! extern crate bincode; -//! use bincode::{serialize, deserialize}; +//! use bincode::{serialize, deserialize, Bounded}; //! fn main() { //! // The object that we will serialize. //! let target = Some("hello world".to_string()); //! // The maximum size of the encoded message. -//! let limit = bincode::SizeLimit::Bounded(20); +//! let limit = Bounded(20); //! //! let encoded: Vec = serialize(&target, limit).unwrap(); //! let decoded: Option = deserialize(&encoded[..]).unwrap(); @@ -51,7 +51,7 @@ use std::io::{Read, Write}; pub use serde::{ErrorKind, Error, Result, serialized_size, serialized_size_bounded}; -pub type Deserializer = serde::Deserializer; +pub type Deserializer = serde::Deserializer; pub type Serializer = serde::Serializer; /// Deserializes a slice of bytes into an object. @@ -73,11 +73,10 @@ pub fn deserialize(bytes: &[u8]) -> serde::Result /// If this returns an `Error`, assume that the buffer that you passed /// in is in an invalid state, as the error could be returned during any point /// in the reading. -pub fn deserialize_from(reader: &mut R, size_limit: SizeLimit) -> serde::Result - where R: Read, - T: serde_crate::Deserialize, +pub fn deserialize_from(reader: &mut R, size_limit: S) -> serde::Result + where R: Read, T: serde_crate::Deserialize, S: SizeLimit { - serde::deserialize_from::<_, _, byteorder::LittleEndian>(reader, size_limit) + serde::deserialize_from::<_, _, _, byteorder::LittleEndian>(reader, size_limit) } /// Serializes an object directly into a `Writer`. @@ -88,20 +87,20 @@ pub fn deserialize_from(reader: &mut R, size_limit: SizeLimit) -> /// If this returns an `Error` (other than SizeLimit), assume that the /// writer is in an invalid state, as writing could bail out in the middle of /// serializing. -pub fn serialize_into(writer: &mut W, value: &T, size_limit: SizeLimit) -> serde::Result<()> - where W: Write, T: serde_crate::Serialize +pub fn serialize_into(writer: &mut W, value: &T, size_limit: S) -> serde::Result<()> + where W: Write, T: serde_crate::Serialize, S: SizeLimit { - serde::serialize_into::<_, _, byteorder::LittleEndian>(writer, value, size_limit) + serde::serialize_into::<_, _, _, byteorder::LittleEndian>(writer, value, size_limit) } /// Serializes a serializable object into a `Vec` of bytes. /// /// If the serialization would take more bytes than allowed by `size_limit`, /// an error is returned. -pub fn serialize(value: &T, size_limit: SizeLimit) -> serde::Result> - where T: serde_crate::Serialize +pub fn serialize(value: &T, size_limit: S) -> serde::Result> + where T: serde_crate::Serialize, S: SizeLimit { - serde::serialize::<_, byteorder::LittleEndian>(value, size_limit) + serde::serialize::<_, _, byteorder::LittleEndian>(value, size_limit) } /// A limit on the amount of bytes that can be read or written. @@ -122,8 +121,34 @@ pub fn serialize(value: &T, size_limit: SizeLimit) -> serde::Result Result<()>; + fn limit(&self) -> Option; +} + +#[derive(Copy, Clone)] +pub struct Bounded(pub u64); + +#[derive(Copy, Clone)] +pub struct Infinite; + +impl SizeLimit for Bounded { + #[inline(always)] + fn add(&mut self, n: u64) -> Result<()> { + if self.0 >= n { + self.0 -= n; + Ok(()) + } else { + Err(Box::new(ErrorKind::SizeLimit)) + } + } + #[inline(always)] + fn limit(&self) -> Option { Some(self.0) } +} + +impl SizeLimit for Infinite { + #[inline(always)] + fn add(&mut self, _: u64) -> Result<()> { Ok (()) } + #[inline(always)] + fn limit(&self) -> Option { None } } diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 9caeb0a..df78d17 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -109,7 +109,7 @@ impl serde::ser::Error for Error { fn custom(msg: T) -> Self { ErrorKind::Custom(msg.to_string()).into() } -} +} /// Serializes an object directly into a `Writer`. /// @@ -119,15 +119,11 @@ impl serde::ser::Error for Error { /// If this returns an `Error` (other than SizeLimit), assume that the /// writer is in an invalid state, as writing could bail out in the middle of /// serializing. -pub fn serialize_into(writer: &mut W, value: &T, size_limit: SizeLimit) -> Result<()> - where W: Write, T: serde::Serialize, E: ByteOrder +pub fn serialize_into(writer: &mut W, value: &T, size_limit: S) -> Result<()> + where W: Write, T: serde::Serialize, S: SizeLimit, E: ByteOrder { - match size_limit { - SizeLimit::Infinite => { } - SizeLimit::Bounded(x) => { - let mut size_checker = SizeChecker::new(x); - try!(value.serialize(&mut size_checker)) - } + if let Some(limit) = size_limit.limit() { + try!(serialized_size_bounded(value, limit).ok_or(ErrorKind::SizeLimit)); } let mut serializer = Serializer::<_, E>::new(writer); @@ -138,35 +134,59 @@ pub fn serialize_into(writer: &mut W, value: &T, size_l /// /// If the serialization would take more bytes than allowed by `size_limit`, /// an error is returned. -pub fn serialize(value: &T, size_limit: SizeLimit) -> Result> - where T: serde::Serialize +pub fn serialize(value: &T, size_limit: S) -> Result> + where T: serde::Serialize, S: SizeLimit, E: ByteOrder { // Since we are putting values directly into a vector, we can do size // computation out here and pre-allocate a buffer of *exactly* // the right size. - let mut writer = match size_limit { - SizeLimit::Bounded(size_limit) => { + let mut writer = match size_limit.limit() { + Some(size_limit) => { let actual_size = try!(serialized_size_bounded(value, size_limit).ok_or(ErrorKind::SizeLimit)); Vec::with_capacity(actual_size as usize) } - SizeLimit::Infinite => Vec::new() + None => Vec::new() }; - try!(serialize_into::<_, _, E>(&mut writer, value, SizeLimit::Infinite)); + try!(serialize_into::<_, _, _, E>(&mut writer, value, super::Infinite)); Ok(writer) } + +struct CountSize { + total: u64, + limit: Option, +} + +impl SizeLimit for CountSize { + fn add(&mut self, c: u64) -> Result<()> { + self.total += c; + if let Some(limit) = self.limit { + if self.total > limit { + return Err(Box::new(ErrorKind::SizeLimit)) + } + } + Ok(()) + } + + fn limit(&self) -> Option { + unreachable!(); + } +} + /// Returns the size that an object would be if serialized using bincode. /// /// This is used internally as part of the check for encode_into, but it can /// be useful for preallocating buffers if thats your style. -pub fn serialized_size(value: &T) -> u64 +pub fn serialized_size(value: &T) -> u64 where T: serde::Serialize { - use std::u64::MAX; - let mut size_checker = SizeChecker::new(MAX); - value.serialize(&mut size_checker).ok(); - size_checker.written + let mut size_counter = SizeChecker { + size_limit: CountSize { total: 0, limit: None } + }; + + value.serialize(&mut size_counter).ok(); + size_counter.size_limit.total } /// Given a maximum size limit, check how large an object would be if it @@ -177,8 +197,14 @@ pub fn serialized_size(value: &T) -> u64 pub fn serialized_size_bounded(value: &T, max: u64) -> Option where T: serde::Serialize { - let mut size_checker = SizeChecker::new(max); - value.serialize(&mut size_checker).ok().map(|_| size_checker.written) + let mut size_counter = SizeChecker { + size_limit: CountSize { total: 0, limit: Some(max) } + }; + + match value.serialize(&mut size_counter) { + Ok(_) => Some(size_counter.size_limit.total), + Err(_) => None, + } } /// Deserializes an object directly from a `Buffer`ed Reader. @@ -190,11 +216,10 @@ pub fn serialized_size_bounded(value: &T, max: u64) -> Option /// If this returns an `Error`, assume that the buffer that you passed /// in is in an invalid state, as the error could be returned during any point /// in the reading. -pub fn deserialize_from(reader: &mut R, size_limit: SizeLimit) -> Result - where R: Read, - T: serde::Deserialize, +pub fn deserialize_from(reader: &mut R, size_limit: S) -> Result + where R: Read, T: serde::Deserialize, S: SizeLimit, E: ByteOrder { - let mut deserializer = Deserializer::<_, E>::new(reader, size_limit); + let mut deserializer = Deserializer::<_, S, E>::new(reader, size_limit); serde::Deserialize::deserialize(&mut deserializer) } @@ -206,5 +231,5 @@ pub fn deserialize(bytes: &[u8]) -> Result where T: serde::Deserialize, { let mut reader = bytes; - deserialize_from::<_, _, E>(&mut reader, SizeLimit::Infinite) + deserialize_from::<_, _, _, E>(&mut reader, super::Infinite) } diff --git a/src/serde/reader.rs b/src/serde/reader.rs index 145c4c4..ca872cb 100644 --- a/src/serde/reader.rs +++ b/src/serde/reader.rs @@ -21,15 +21,15 @@ const BLOCK_SIZE: usize = 65536; /// serde::Deserialize::deserialize(&mut deserializer); /// let bytes_read = d.bytes_read(); /// ``` -pub struct Deserializer { +pub struct Deserializer { reader: R, - size_limit: SizeLimit, + size_limit: S, read: u64, _phantom: PhantomData, } -impl Deserializer { - pub fn new(r: R, size_limit: SizeLimit) -> Deserializer { +impl Deserializer { + pub fn new(r: R, size_limit: S) -> Deserializer { Deserializer { reader: r, size_limit: size_limit, @@ -44,12 +44,7 @@ impl Deserializer { } fn read_bytes(&mut self, count: u64) -> Result<()> { - self.read += count; - match self.size_limit { - SizeLimit::Infinite => Ok(()), - SizeLimit::Bounded(x) if self.read <= x => Ok(()), - SizeLimit::Bounded(_) => Err(ErrorKind::SizeLimit.into()) - } + self.size_limit.add(count) } fn read_type(&mut self) -> Result<()> { @@ -98,7 +93,8 @@ macro_rules! impl_nums { } } -impl<'a, R: Read, E: ByteOrder> serde::Deserializer for &'a mut Deserializer { +impl<'a, R, S, E> serde::Deserializer for &'a mut Deserializer +where R: Read, S: SizeLimit, E: ByteOrder { type Error = Error; #[inline] @@ -215,7 +211,8 @@ impl<'a, R: Read, E: ByteOrder> serde::Deserializer for &'a mut Deserializer Result where V: serde::de::Visitor, { - impl<'a, R: Read + 'a, E: ByteOrder> serde::de::EnumVisitor for &'a mut Deserializer { + impl<'a, R: 'a, S, E> serde::de::EnumVisitor for &'a mut Deserializer + where R: Read, S: SizeLimit, E: ByteOrder { type Error = Error; type Variant = Self; @@ -230,15 +227,13 @@ impl<'a, R: Read, E: ByteOrder> serde::Deserializer for &'a mut Deserializer(self, - _len: usize, - visitor: V) -> Result + + fn deserialize_tuple(self, _len: usize, visitor: V) -> Result where V: serde::de::Visitor, { - struct TupleVisitor<'a, R: Read + 'a, E: ByteOrder + 'a>(&'a mut Deserializer); + struct TupleVisitor<'a, R: Read + 'a, S: SizeLimit + 'a, E: ByteOrder + 'a>(&'a mut Deserializer); - impl<'a, 'b: 'a, R: Read + 'b, E: ByteOrder> serde::de::SeqVisitor for TupleVisitor<'a, R, E> { + impl<'a, 'b: 'a, R: Read + 'b, S: SizeLimit, E: ByteOrder> serde::de::SeqVisitor for TupleVisitor<'a, R, S, E> { type Error = Error; fn visit_seed(&mut self, seed: T) -> Result> @@ -257,12 +252,12 @@ impl<'a, R: Read, E: ByteOrder> serde::Deserializer for &'a mut Deserializer Result where V: serde::de::Visitor, { - struct SeqVisitor<'a, R: Read + 'a, E: ByteOrder + 'a> { - deserializer: &'a mut Deserializer, + struct SeqVisitor<'a, R: Read + 'a, S: SizeLimit + 'a, E: ByteOrder + 'a> { + deserializer: &'a mut Deserializer, len: usize, } - impl<'a, 'b: 'a, R: Read + 'b, E: ByteOrder> serde::de::SeqVisitor for SeqVisitor<'a, R, E> { + impl<'a, 'b: 'a, R: Read + 'b, S: SizeLimit, E: ByteOrder> serde::de::SeqVisitor for SeqVisitor<'a, R, S, E> { type Error = Error; fn visit_seed(&mut self, seed: T) -> Result> @@ -306,12 +301,12 @@ impl<'a, R: Read, E: ByteOrder> serde::Deserializer for &'a mut Deserializer(self, visitor: V) -> Result where V: serde::de::Visitor, { - struct MapVisitor<'a, R: Read + 'a, E: ByteOrder + 'a> { - deserializer: &'a mut Deserializer, + struct MapVisitor<'a, R: Read + 'a, S: SizeLimit + 'a, E: ByteOrder + 'a> { + deserializer: &'a mut Deserializer, len: usize, } - impl<'a, 'b: 'a, R: Read + 'b, E: ByteOrder> serde::de::MapVisitor for MapVisitor<'a, R, E> { + impl<'a, 'b: 'a, R: Read + 'b, S: SizeLimit, E: ByteOrder> serde::de::MapVisitor for MapVisitor<'a, R, S, E> { type Error = Error; fn visit_key_seed(&mut self, seed: K) -> Result> @@ -390,7 +385,8 @@ impl<'a, R: Read, E: ByteOrder> serde::Deserializer for &'a mut Deserializer serde::de::VariantVisitor for &'a mut Deserializer { +impl<'a, R, S, E> serde::de::VariantVisitor for &'a mut Deserializer +where R: Read, S: SizeLimit, E: ByteOrder { type Error = Error; fn visit_unit(self) -> Result<()> { diff --git a/src/serde/writer.rs b/src/serde/writer.rs index 3bb10fc..203b677 100644 --- a/src/serde/writer.rs +++ b/src/serde/writer.rs @@ -7,6 +7,7 @@ use serde_crate as serde; use byteorder::{WriteBytesExt, ByteOrder}; use super::{Result, Error, ErrorKind}; +use super::super::SizeLimit; /// An Serializer that encodes values directly into a Writer. /// @@ -193,31 +194,24 @@ impl<'a, W: Write, E: ByteOrder> serde::Serializer for &'a mut Serializer } } -pub struct SizeChecker { - pub size_limit: u64, - pub written: u64 +pub struct SizeChecker { + pub size_limit: S, } -impl SizeChecker { - pub fn new(limit: u64) -> SizeChecker { +impl SizeChecker { + pub fn new(size_limit: S) -> SizeChecker { SizeChecker { - size_limit: limit, - written: 0 + size_limit: size_limit } } - fn add_raw(&mut self, size: usize) -> Result<()> { - self.written += size as u64; - if self.written <= self.size_limit { - Ok(()) - } else { - Err(ErrorKind::SizeLimit.into()) - } + fn add_raw(&mut self, size: u64) -> Result<()> { + self.size_limit.add(size) } fn add_value(&mut self, t: T) -> Result<()> { use std::mem::size_of_val; - self.add_raw(size_of_val(&t)) + self.add_raw(size_of_val(&t) as u64) } fn add_enum_tag(&mut self, tag: usize) -> Result<()> { @@ -229,16 +223,16 @@ impl SizeChecker { } } -impl<'a> serde::Serializer for &'a mut SizeChecker { +impl<'a, S: SizeLimit> serde::Serializer for &'a mut SizeChecker { type Ok = (); type Error = Error; - type SerializeSeq = SizeCompound<'a>; - type SerializeTuple = SizeCompound<'a>; - type SerializeTupleStruct = SizeCompound<'a>; - type SerializeTupleVariant = SizeCompound<'a>; - type SerializeMap = SizeCompound<'a>; - type SerializeStruct = SizeCompound<'a>; - type SerializeStructVariant = SizeCompound<'a>; + type SerializeSeq = SizeCompound<'a, S>; + type SerializeTuple = SizeCompound<'a, S>; + type SerializeTupleStruct = SizeCompound<'a, S>; + type SerializeTupleVariant = SizeCompound<'a, S>; + type SerializeMap = SizeCompound<'a, S>; + type SerializeStruct = SizeCompound<'a, S>; + type SerializeStructVariant = SizeCompound<'a, S>; fn serialize_unit(self) -> Result<()> { Ok(()) } @@ -290,16 +284,16 @@ impl<'a> serde::Serializer for &'a mut SizeChecker { fn serialize_str(self, v: &str) -> Result<()> { try!(self.add_value(0 as u64)); - self.add_raw(v.len()) + self.add_raw(v.len() as u64) } fn serialize_char(self, c: char) -> Result<()> { - self.add_raw(encode_utf8(c).as_slice().len()) + self.add_raw(encode_utf8(c).as_slice().len() as u64) } fn serialize_bytes(self, v: &[u8]) -> Result<()> { try!(self.add_value(0 as u64)); - self.add_raw(v.len()) + self.add_raw(v.len() as u64) } fn serialize_none(self) -> Result<()> { @@ -532,11 +526,11 @@ impl<'a, W, E> serde::ser::SerializeStructVariant for Compound<'a, W, E> } #[doc(hidden)] -pub struct SizeCompound<'a> { - ser: &'a mut SizeChecker, +pub struct SizeCompound<'a, S: SizeLimit + 'a> { + ser: &'a mut SizeChecker, } -impl<'a> serde::ser::SerializeSeq for SizeCompound<'a> +impl<'a, S: SizeLimit> serde::ser::SerializeSeq for SizeCompound<'a, S> { type Ok = (); type Error = Error; @@ -554,7 +548,7 @@ impl<'a> serde::ser::SerializeSeq for SizeCompound<'a> } } -impl<'a> serde::ser::SerializeTuple for SizeCompound<'a> +impl<'a, S: SizeLimit> serde::ser::SerializeTuple for SizeCompound<'a, S> { type Ok = (); type Error = Error; @@ -572,7 +566,7 @@ impl<'a> serde::ser::SerializeTuple for SizeCompound<'a> } } -impl<'a> serde::ser::SerializeTupleStruct for SizeCompound<'a> +impl<'a, S: SizeLimit> serde::ser::SerializeTupleStruct for SizeCompound<'a, S> { type Ok = (); type Error = Error; @@ -590,7 +584,7 @@ impl<'a> serde::ser::SerializeTupleStruct for SizeCompound<'a> } } -impl<'a> serde::ser::SerializeTupleVariant for SizeCompound<'a> +impl<'a, S: SizeLimit> serde::ser::SerializeTupleVariant for SizeCompound<'a, S> { type Ok = (); type Error = Error; @@ -608,7 +602,7 @@ impl<'a> serde::ser::SerializeTupleVariant for SizeCompound<'a> } } -impl<'a> serde::ser::SerializeMap for SizeCompound<'a> +impl<'a, S: SizeLimit + 'a> serde::ser::SerializeMap for SizeCompound<'a, S> { type Ok = (); type Error = Error; @@ -633,7 +627,7 @@ impl<'a> serde::ser::SerializeMap for SizeCompound<'a> } } -impl<'a> serde::ser::SerializeStruct for SizeCompound<'a> +impl<'a, S: SizeLimit> serde::ser::SerializeStruct for SizeCompound<'a, S> { type Ok = (); type Error = Error; @@ -651,7 +645,7 @@ impl<'a> serde::ser::SerializeStruct for SizeCompound<'a> } } -impl<'a> serde::ser::SerializeStructVariant for SizeCompound<'a> +impl<'a, S: SizeLimit> serde::ser::SerializeStructVariant for SizeCompound<'a, S> { type Ok = (); type Error = Error; @@ -668,7 +662,6 @@ impl<'a> serde::ser::SerializeStructVariant for SizeCompound<'a> Ok(()) } } - const TAG_CONT: u8 = 0b1000_0000; const TAG_TWO_B: u8 = 0b1100_0000; const TAG_THREE_B: u8 = 0b1110_0000; diff --git a/tests/test.rs b/tests/test.rs index 914013c..9b31d3c 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -11,7 +11,7 @@ use std::ops::Deref; use bincode::refbox::{RefBox, StrBox, SliceBox}; -use bincode::SizeLimit::{Infinite, Bounded}; +use bincode::{Infinite, Bounded}; use bincode::{serialized_size, ErrorKind, Result}; use bincode::endian_choice::{serialize, deserialize}; @@ -34,7 +34,6 @@ fn the_same(element: V) } let size = serialized_size(&element); - { let encoded = serialize_little(&element, Infinite); let encoded = encoded.unwrap(); @@ -47,7 +46,7 @@ fn the_same(element: V) } { - let encoded = serialize::<_, byteorder::BigEndian>(&element, Infinite); + let encoded = serialize::<_, _, byteorder::BigEndian>(&element, Infinite); let encoded = encoded.unwrap(); let decoded = deserialize::<_, byteorder::BigEndian>(&encoded[..]); let decoded = decoded.unwrap(); @@ -235,11 +234,11 @@ fn deserializing_errors() { #[test] fn too_big_deserialize() { let serialized = vec![0,0,0,3]; - let deserialized: Result = deserialize_from_little::<_, _>(&mut &serialized[..], Bounded(3)); + let deserialized: Result = deserialize_from_little::<_, _, _>(&mut &serialized[..], Bounded(3)); assert!(deserialized.is_err()); let serialized = vec![0,0,0,3]; - let deserialized: Result = deserialize_from_little::<_, _>(&mut &serialized[..], Bounded(4)); + let deserialized: Result = deserialize_from_little::<_, _, _>(&mut &serialized[..], Bounded(4)); assert!(deserialized.is_ok()); } @@ -256,7 +255,7 @@ fn char_serialization() { #[test] fn too_big_char_deserialize() { let serialized = vec![0x41]; - let deserialized: Result = deserialize_from_little::<_, _>(&mut &serialized[..], Bounded(1)); + let deserialized: Result = deserialize_from_little::<_, _, _>(&mut &serialized[..], Bounded(1)); assert!(deserialized.is_ok()); assert_eq!(deserialized.unwrap(), 'A'); } @@ -404,6 +403,6 @@ fn bytes() { fn endian_difference() { let x = 10u64; let little = serialize_little(&x, Infinite).unwrap(); - let big = serialize::<_, byteorder::BigEndian>(&x, Infinite).unwrap(); + let big = serialize::<_, _, byteorder::BigEndian>(&x, Infinite).unwrap(); assert_ne!(little, big); }