diff --git a/examples/basic.rs b/examples/basic.rs index a08d3a2..1b5df8d 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -20,9 +20,11 @@ fn main() { }; let encoded: Vec = bincode::encode(&world, SizeLimit::Infinite).unwrap(); + // 8 bytes for the length of the vector, 4 bytes per float. assert_eq!(encoded.len(), 8 + 4 * 4); - let decoded: World = bincode::decode(encoded, SizeLimit::Infinite).unwrap(); + + let decoded: World = bincode::decode(encoded.as_slice()).unwrap(); assert!(world == decoded); } diff --git a/src/lib.rs b/src/lib.rs index 352c627..7aa1b45 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,18 +6,16 @@ extern crate "rustc-serialize" as rustc_serialize; -use std::io::Buffer; -use std::io::MemWriter; -use std::io::MemReader; -use std::io::IoResult; -use rustc_serialize::Encodable; -use rustc_serialize::Decodable; +use std::io::{Buffer, MemWriter}; +use rustc_serialize::{Encodable, Decodable}; -pub use writer::EncoderWriter; +pub use writer::{EncoderWriter, EncodingResult, EncodingError}; pub use reader::{DecoderReader, DecodingResult, DecodingError}; +use writer::SizeChecker; mod writer; mod reader; +#[cfg(test)] mod test; #[derive(Clone, Copy)] pub enum SizeLimit { @@ -25,7 +23,11 @@ pub enum SizeLimit { UpperBound(u64) } -pub fn encode(t: &T, size_limit: SizeLimit) -> IoResult> { +/// Encodes an encodable object into a `Vec` of bytes. +/// +/// If the encoding would take more bytes than allowed by `size_limit`, +/// an error is returned. +pub fn encode(t: &T, size_limit: SizeLimit) -> EncodingResult> { let mut w = MemWriter::new(); match encode_into(t, &mut w, size_limit) { Ok(()) => Ok(w.into_inner()), @@ -33,17 +35,34 @@ pub fn encode(t: &T, size_limit: SizeLimit) -> IoResult> { } } -pub fn decode(b: Vec, size_limit: SizeLimit) -> DecodingResult { - decode_from(&mut MemReader::new(b), size_limit) +/// Decodes a slice of bytes into an object. +pub fn decode(b: &[u8]) -> DecodingResult { + let mut b = b; + decode_from(&mut b, SizeLimit::Infinite) } -pub fn encode_into(t: &T, w: &mut W, size_limit: SizeLimit) -> IoResult<()> { +/// Encodes an object directly into a `Writer`. +/// +/// If the encoding would take more bytes than allowed by `size_limit`, an error +/// is returned and *no bytes* will be written into the `Writer`. +pub fn encode_into(t: &T, w: &mut W, size_limit: SizeLimit) -> EncodingResult<()> { + try!(match size_limit { + SizeLimit::Infinite => Ok(()), + SizeLimit::UpperBound(x) => { + let mut size_checker = SizeChecker::new(x); + t.encode(&mut size_checker) + } + }); + t.encode(&mut writer::EncoderWriter::new(w, size_limit)) } -pub fn decode_from(r: &mut R, size_limit: SizeLimit) -> DecodingResult { +/// Decoes an object directly from a Buffered Reader. +/// +/// If the provided `SizeLimit` is reached, the decode will bail immediately. +/// A SizeLimit can help prevent an attacker from flooding your server with +/// a neverending stream of values that runs your server out of memory. +pub fn decode_from(r: &mut R, size_limit: SizeLimit) -> +DecodingResult { Decodable::decode(&mut reader::DecoderReader::new(r, size_limit)) } - -#[cfg(test)] -mod test; diff --git a/src/test.rs b/src/test.rs index e6b872a..691fe19 100644 --- a/src/test.rs +++ b/src/test.rs @@ -13,13 +13,14 @@ use rustc_serialize::{ use super::{ encode, decode, + decode_from, DecodingError, DecodingResult }; use super::SizeLimit::{Infinite, UpperBound}; fn the_same<'a, V>(element: V) where V: Encodable, V: Decodable, V: PartialEq, V: Show { - assert!(element == decode(encode(&element, Infinite).unwrap(), Infinite).unwrap()); + assert!(element == decode(encode(&element, Infinite).unwrap().as_slice()).unwrap()); } #[test] @@ -178,32 +179,44 @@ fn is_invalid_bytes(res: DecodingResult) { #[test] fn decoding_errors() { - is_invalid_bytes(decode::(vec![0xA], Infinite)); - is_invalid_bytes(decode::(vec![0, 0, 0, 0, 0, 0, 0, 1, 0xFF], Infinite)); + is_invalid_bytes(decode::(vec![0xA].as_slice())); + is_invalid_bytes(decode::(vec![0, 0, 0, 0, 0, 0, 0, 1, 0xFF].as_slice())); // Out-of-bounds variant #[derive(RustcEncodable, RustcDecodable)] enum Test { One, Two, }; - is_invalid_bytes(decode::(vec![0, 0, 0, 5], Infinite)); - is_invalid_bytes(decode::>(vec![5, 0], Infinite)); + is_invalid_bytes(decode::(vec![0, 0, 0, 5].as_slice())); + is_invalid_bytes(decode::>(vec![5, 0].as_slice())); } #[test] fn too_big_decode() { let encoded = vec![0,0,0,3]; - let decoded: Result = decode(encoded, UpperBound(3)); + let mut encoded_ref = encoded.as_slice(); + let decoded: Result = decode_from(&mut encoded_ref, UpperBound(3)); assert!(decoded.is_err()); let encoded = vec![0,0,0,3]; - let decoded: Result = decode(encoded, UpperBound(4)); + let mut encoded_ref = encoded.as_slice(); + let decoded: Result = decode_from(&mut encoded_ref, UpperBound(4)); assert!(decoded.is_ok()); } #[test] -fn too_big_char() { +fn too_big_char_decode() { let encoded = vec![0x41]; - let decoded: Result = decode(encoded, UpperBound(1)); + let mut encoded_ref = encoded.as_slice(); + let decoded: Result = decode_from(&mut encoded_ref, UpperBound(1)); assert_eq!(decoded, Ok('A')); } + +#[test] +fn too_big_encode() { + assert!(encode(&0u32, UpperBound(3)).is_err()); + assert!(encode(&0u32, UpperBound(4)).is_ok()); + + assert!(encode(&"abcde", UpperBound(4)).is_err()); + assert!(encode(&"abcde", UpperBound(5)).is_ok()); +} diff --git a/src/writer.rs b/src/writer.rs index 73339f8..1ac30a0 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -1,17 +1,49 @@ -use std::io::{Writer, IoError, IoResult}; +use std::io::{Writer, IoError}; +use std::error::Error; use std::num::Int; use rustc_serialize::Encoder; use super::SizeLimit; -type EwResult = IoResult<()>; +pub type EncodingResult = Result; + +#[derive(Show)] +pub enum EncodingError { + IoError(IoError), + SizeLimit +} pub struct EncoderWriter<'a, W: 'a> { writer: &'a mut W, _size_limit: SizeLimit } +pub struct SizeChecker { + size_limit: u64, + written: u64 +} + +fn wrap_io(err: IoError) -> EncodingError { + EncodingError::IoError(err) +} + +impl Error for EncodingError { + fn description(&self) -> &str { + match *self { + EncodingError::IoError(ref err) => err.description(), + EncodingError::SizeLimit => "the size limit for decoding has been reached" + } + } + + fn detail(&self) -> Option { + match *self { + EncodingError::IoError(ref err) => err.detail(), + EncodingError::SizeLimit => None + } + } +} + impl <'a, W: Writer> EncoderWriter<'a, W> { pub fn new(w: &'a mut W, size_limit: SizeLimit) -> EncoderWriter<'a, W> { EncoderWriter { @@ -21,65 +53,88 @@ impl <'a, W: Writer> EncoderWriter<'a, W> { } } -impl<'a, W: Writer> Encoder for EncoderWriter<'a, W> { - type Error = IoError; +impl SizeChecker { + pub fn new(limit: u64) -> SizeChecker { + SizeChecker { + size_limit: limit, + written: 0 + } + } - fn emit_nil(&mut self) -> EwResult { Ok(()) } - fn emit_usize(&mut self, v: usize) -> EwResult { + fn add_raw(&mut self, size: usize) -> EncodingResult<()> { + self.written += size as u64; + if self.written <= self.size_limit { + Ok(()) + } else { + Err(EncodingError::SizeLimit) + } + } + + fn add_value(&mut self, t: T) -> EncodingResult<()> { + use std::mem::size_of_val; + self.add_raw(size_of_val(&t)) + } +} + +impl<'a, W: Writer> Encoder for EncoderWriter<'a, W> { + type Error = EncodingError; + + fn emit_nil(&mut self) -> EncodingResult<()> { Ok(()) } + fn emit_usize(&mut self, v: usize) -> EncodingResult<()> { self.emit_u64(v as u64) } - fn emit_u64(&mut self, v: u64) -> EwResult { - self.writer.write_be_u64(v) + fn emit_u64(&mut self, v: u64) -> EncodingResult<()> { + self.writer.write_be_u64(v).map_err(wrap_io) } - fn emit_u32(&mut self, v: u32) -> EwResult { - self.writer.write_be_u32(v) + fn emit_u32(&mut self, v: u32) -> EncodingResult<()> { + self.writer.write_be_u32(v).map_err(wrap_io) } - fn emit_u16(&mut self, v: u16) -> EwResult { - self.writer.write_be_u16(v) + fn emit_u16(&mut self, v: u16) -> EncodingResult<()> { + self.writer.write_be_u16(v).map_err(wrap_io) } - fn emit_u8(&mut self, v: u8) -> EwResult { - self.writer.write_u8(v) + fn emit_u8(&mut self, v: u8) -> EncodingResult<()> { + self.writer.write_u8(v).map_err(wrap_io) } - fn emit_isize(&mut self, v: isize) -> EwResult { + fn emit_isize(&mut self, v: isize) -> EncodingResult<()> { self.emit_i64(v as i64) } - fn emit_i64(&mut self, v: i64) -> EwResult { - self.writer.write_be_i64(v) + fn emit_i64(&mut self, v: i64) -> EncodingResult<()> { + self.writer.write_be_i64(v).map_err(wrap_io) } - fn emit_i32(&mut self, v: i32) -> EwResult { - self.writer.write_be_i32(v) + fn emit_i32(&mut self, v: i32) -> EncodingResult<()> { + self.writer.write_be_i32(v).map_err(wrap_io) } - fn emit_i16(&mut self, v: i16) -> EwResult { - self.writer.write_be_i16(v) + fn emit_i16(&mut self, v: i16) -> EncodingResult<()> { + self.writer.write_be_i16(v).map_err(wrap_io) } - fn emit_i8(&mut self, v: i8) -> EwResult { - self.writer.write_i8(v) + fn emit_i8(&mut self, v: i8) -> EncodingResult<()> { + self.writer.write_i8(v).map_err(wrap_io) } - fn emit_bool(&mut self, v: bool) -> EwResult { - self.writer.write_u8(if v {1} else {0}) + fn emit_bool(&mut self, v: bool) -> EncodingResult<()> { + self.writer.write_u8(if v {1} else {0}).map_err(wrap_io) } - fn emit_f64(&mut self, v: f64) -> EwResult { - self.writer.write_be_f64(v) + fn emit_f64(&mut self, v: f64) -> EncodingResult<()> { + self.writer.write_be_f64(v).map_err(wrap_io) } - fn emit_f32(&mut self, v: f32) -> EwResult { - self.writer.write_be_f32(v) + fn emit_f32(&mut self, v: f32) -> EncodingResult<()> { + self.writer.write_be_f32(v).map_err(wrap_io) } - fn emit_char(&mut self, v: char) -> EwResult { - self.writer.write_char(v) + fn emit_char(&mut self, v: char) -> EncodingResult<()> { + self.writer.write_char(v).map_err(wrap_io) } - fn emit_str(&mut self, v: &str) -> EwResult { + fn emit_str(&mut self, v: &str) -> EncodingResult<()> { try!(self.emit_usize(v.len())); - self.writer.write_str(v) + self.writer.write_str(v).map_err(wrap_io) } - fn emit_enum(&mut self, __: &str, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_enum(&mut self, __: &str, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } fn emit_enum_variant(&mut self, _: &str, v_id: usize, _: usize, - f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { let max_u32: u32 = Int::max_value(); if v_id > (max_u32 as usize) { panic!("Variant tag doesn't fit in a u32") @@ -87,80 +142,221 @@ impl<'a, W: Writer> Encoder for EncoderWriter<'a, W> { try!(self.emit_u32(v_id as u32)); f(self) } - fn emit_enum_variant_arg(&mut self, _: usize, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_enum_variant_arg(&mut self, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } fn emit_enum_struct_variant(&mut self, _: &str, _: usize, _: usize, - f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } fn emit_enum_struct_variant_field(&mut self, _: &str, _: usize, - f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_struct(&mut self, _: &str, _: usize, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_struct(&mut self, _: &str, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_struct_field(&mut self, _: &str, _: usize, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_struct_field(&mut self, _: &str, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_tuple(&mut self, _: usize, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_tuple(&mut self, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_tuple_arg(&mut self, _: usize, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_tuple_arg(&mut self, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_tuple_struct(&mut self, _: &str, len: usize, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_tuple_struct(&mut self, _: &str, len: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { self.emit_tuple(len, f) } - fn emit_tuple_struct_arg(&mut self, f_idx: usize, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_tuple_struct_arg(&mut self, f_idx: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { self.emit_tuple_arg(f_idx, f) } - fn emit_option(&mut self, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_option(&mut self, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_option_none(&mut self) -> EwResult { - self.writer.write_u8(0) + fn emit_option_none(&mut self) -> EncodingResult<()> { + self.writer.write_u8(0).map_err(wrap_io) } - fn emit_option_some(&mut self, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { - try!(self.writer.write_u8(1)); + fn emit_option_some(&mut self, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { + try!(self.writer.write_u8(1).map_err(wrap_io)); f(self) } - fn emit_seq(&mut self, len: usize, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_seq(&mut self, len: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { try!(self.emit_usize(len)); f(self) } - fn emit_seq_elt(&mut self, _: usize, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_seq_elt(&mut self, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_map(&mut self, len: usize, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_map(&mut self, len: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { try!(self.emit_usize(len)); f(self) } - fn emit_map_elt_key(&mut self, _: usize, mut f: F) -> EwResult where - F: FnMut(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_map_elt_key(&mut self, _: usize, mut f: F) -> EncodingResult<()> where + F: FnMut(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_map_elt_val(&mut self, _: usize, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_map_elt_val(&mut self, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } + +} + +impl Encoder for SizeChecker { + type Error = EncodingError; + + fn emit_nil(&mut self) -> EncodingResult<()> { Ok(()) } + fn emit_usize(&mut self, v: usize) -> EncodingResult<()> { + self.add_value(v as u64) + } + fn emit_u64(&mut self, v: u64) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_u32(&mut self, v: u32) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_u16(&mut self, v: u16) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_u8(&mut self, v: u8) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_isize(&mut self, v: isize) -> EncodingResult<()> { + self.add_value(v as i64) + } + fn emit_i64(&mut self, v: i64) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_i32(&mut self, v: i32) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_i16(&mut self, v: i16) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_i8(&mut self, v: i8) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_bool(&mut self, _: bool) -> EncodingResult<()> { + self.add_value(0 as u8) + } + fn emit_f64(&mut self, v: f64) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_f32(&mut self, v: f32) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_char(&mut self, v: char) -> EncodingResult<()> { + self.add_raw(v.len_utf8()) + } + fn emit_str(&mut self, v: &str) -> EncodingResult<()> { + self.add_raw(v.len()) + } + fn emit_enum(&mut self, __: &str, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_enum_variant(&mut self, _: &str, + v_id: usize, + _: usize, + f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + try!(self.add_value(v_id as u32)); + f(self) + } + fn emit_enum_variant_arg(&mut self, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_enum_struct_variant(&mut self, _: &str, + _: usize, + _: usize, + f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_enum_struct_variant_field(&mut self, + _: &str, + _: usize, + f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_struct(&mut self, _: &str, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_struct_field(&mut self, _: &str, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_tuple(&mut self, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_tuple_arg(&mut self, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_tuple_struct(&mut self, _: &str, len: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + self.emit_tuple(len, f) + } + fn emit_tuple_struct_arg(&mut self, f_idx: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + self.emit_tuple_arg(f_idx, f) + } + fn emit_option(&mut self, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_option_none(&mut self) -> EncodingResult<()> { + self.add_value(0 as u8) + } + fn emit_option_some(&mut self, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + try!(self.add_value(1 as u8)); + f(self) + } + fn emit_seq(&mut self, len: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + try!(self.emit_usize(len)); + f(self) + } + fn emit_seq_elt(&mut self, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_map(&mut self, len: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + try!(self.emit_usize(len)); + f(self) + } + fn emit_map_elt_key(&mut self, _: usize, mut f: F) -> EncodingResult<()> where + F: FnMut(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_map_elt_val(&mut self, _: usize, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + }