From 08e726cb10fe61cb568242a473108f4cc027faa9 Mon Sep 17 00:00:00 2001 From: Ty Overby Date: Thu, 8 Jan 2015 22:32:59 -0800 Subject: [PATCH] Added size checking on serialization. Specifying a bound on serialization is optional, so when specified, it will run a pass over the object and make sure that the serialized object will fit into the required amount of bytes. --- src/lib.rs | 22 ++-- src/test.rs | 11 +- src/writer.rs | 340 +++++++++++++++++++++++++++++++++++++++----------- 3 files changed, 291 insertions(+), 82 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 460d9ff..3903d3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,15 +4,12 @@ extern crate "rustc-serialize" as rustc_serialize; -use std::io::Buffer; -use std::io::MemWriter; -use std::io::MemReader; -use std::io::IoResult; -use rustc_serialize::Encodable; -use rustc_serialize::Decodable; +use std::io::{Buffer, MemWriter, MemReader}; +use rustc_serialize::{Encodable, Decodable}; -pub use writer::EncoderWriter; +pub use writer::{EncoderWriter, EncodingResult, EncodingError}; pub use reader::{DecoderReader, DecodingResult, DecodingError}; +use writer::SizeChecker; mod writer; mod reader; @@ -23,7 +20,7 @@ pub enum SizeLimit { UpperBound(u64) } -pub fn encode(t: &T, size_limit: SizeLimit) -> IoResult> { +pub fn encode(t: &T, size_limit: SizeLimit) -> EncodingResult> { let mut w = MemWriter::new(); match encode_into(t, &mut w, size_limit) { Ok(()) => Ok(w.into_inner()), @@ -35,7 +32,14 @@ pub fn decode(b: Vec, size_limit: SizeLimit) -> DecodingResult decode_from(&mut MemReader::new(b), size_limit) } -pub fn encode_into(t: &T, w: &mut W, size_limit: SizeLimit) -> IoResult<()> { +pub fn encode_into(t: &T, w: &mut W, size_limit: SizeLimit) -> EncodingResult<()> { + try!(match size_limit { + SizeLimit::Infinite => Ok(()), + SizeLimit::UpperBound(x) => { + let mut size_checker = SizeChecker::new(x); + t.encode(&mut size_checker) + } + }); t.encode(&mut writer::EncoderWriter::new(w, size_limit)) } diff --git a/src/test.rs b/src/test.rs index 62cca5c..96024f9 100644 --- a/src/test.rs +++ b/src/test.rs @@ -202,8 +202,17 @@ fn too_big_decode() { } #[test] -fn too_big_char() { +fn too_big_char_decode() { let encoded = vec![0x41]; let decoded: Result = decode(encoded, UpperBound(1)); assert_eq!(decoded, Ok('A')); } + +#[test] +fn too_big_encode() { + assert!(encode(&0u32, UpperBound(3)).is_err()); + assert!(encode(&0u32, UpperBound(4)).is_ok()); + + assert!(encode(&"abcde", UpperBound(4)).is_err()); + assert!(encode(&"abcde", UpperBound(5)).is_ok()); +} diff --git a/src/writer.rs b/src/writer.rs index 6b5bcce..30317a4 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -1,17 +1,49 @@ -use std::io::{Writer, IoError, IoResult}; +use std::io::{Writer, IoError}; +use std::error::Error; use std::num::Int; use rustc_serialize::Encoder; use super::SizeLimit; -type EwResult = IoResult<()>; +pub type EncodingResult = Result; + +#[derive(Show)] +pub enum EncodingError { + IoError(IoError), + SizeLimit +} pub struct EncoderWriter<'a, W: 'a> { writer: &'a mut W, _size_limit: SizeLimit } +pub struct SizeChecker { + size_limit: u64, + written: u64 +} + +fn wrap_io(err: IoError) -> EncodingError { + EncodingError::IoError(err) +} + +impl Error for EncodingError { + fn description(&self) -> &str { + match *self { + EncodingError::IoError(ref err) => err.description(), + EncodingError::SizeLimit => "the size limit for decoding has been reached" + } + } + + fn detail(&self) -> Option { + match *self { + EncodingError::IoError(ref err) => err.detail(), + EncodingError::SizeLimit => None + } + } +} + impl <'a, W: Writer> EncoderWriter<'a, W> { pub fn new(w: &'a mut W, size_limit: SizeLimit) -> EncoderWriter<'a, W> { EncoderWriter { @@ -21,65 +53,88 @@ impl <'a, W: Writer> EncoderWriter<'a, W> { } } -impl<'a, W: Writer> Encoder for EncoderWriter<'a, W> { - type Error = IoError; +impl SizeChecker { + pub fn new(limit: u64) -> SizeChecker { + SizeChecker { + size_limit: limit, + written: 0 + } + } - fn emit_nil(&mut self) -> EwResult { Ok(()) } - fn emit_uint(&mut self, v: uint) -> EwResult { + fn add_raw(&mut self, size: uint) -> EncodingResult<()> { + self.written += size as u64; + if self.written <= self.size_limit { + Ok(()) + } else { + Err(EncodingError::SizeLimit) + } + } + + fn add_value(&mut self, t: T) -> EncodingResult<()> { + use std::mem::size_of_val; + self.add_raw(size_of_val(&t)) + } +} + +impl<'a, W: Writer> Encoder for EncoderWriter<'a, W> { + type Error = EncodingError; + + fn emit_nil(&mut self) -> EncodingResult<()> { Ok(()) } + fn emit_uint(&mut self, v: uint) -> EncodingResult<()> { self.emit_u64(v as u64) } - fn emit_u64(&mut self, v: u64) -> EwResult { - self.writer.write_be_u64(v) + fn emit_u64(&mut self, v: u64) -> EncodingResult<()> { + self.writer.write_be_u64(v).map_err(wrap_io) } - fn emit_u32(&mut self, v: u32) -> EwResult { - self.writer.write_be_u32(v) + fn emit_u32(&mut self, v: u32) -> EncodingResult<()> { + self.writer.write_be_u32(v).map_err(wrap_io) } - fn emit_u16(&mut self, v: u16) -> EwResult { - self.writer.write_be_u16(v) + fn emit_u16(&mut self, v: u16) -> EncodingResult<()> { + self.writer.write_be_u16(v).map_err(wrap_io) } - fn emit_u8(&mut self, v: u8) -> EwResult { - self.writer.write_u8(v) + fn emit_u8(&mut self, v: u8) -> EncodingResult<()> { + self.writer.write_u8(v).map_err(wrap_io) } - fn emit_int(&mut self, v: int) -> EwResult { + fn emit_int(&mut self, v: int) -> EncodingResult<()> { self.emit_i64(v as i64) } - fn emit_i64(&mut self, v: i64) -> EwResult { - self.writer.write_be_i64(v) + fn emit_i64(&mut self, v: i64) -> EncodingResult<()> { + self.writer.write_be_i64(v).map_err(wrap_io) } - fn emit_i32(&mut self, v: i32) -> EwResult { - self.writer.write_be_i32(v) + fn emit_i32(&mut self, v: i32) -> EncodingResult<()> { + self.writer.write_be_i32(v).map_err(wrap_io) } - fn emit_i16(&mut self, v: i16) -> EwResult { - self.writer.write_be_i16(v) + fn emit_i16(&mut self, v: i16) -> EncodingResult<()> { + self.writer.write_be_i16(v).map_err(wrap_io) } - fn emit_i8(&mut self, v: i8) -> EwResult { - self.writer.write_i8(v) + fn emit_i8(&mut self, v: i8) -> EncodingResult<()> { + self.writer.write_i8(v).map_err(wrap_io) } - fn emit_bool(&mut self, v: bool) -> EwResult { - self.writer.write_u8(if v {1} else {0}) + fn emit_bool(&mut self, v: bool) -> EncodingResult<()> { + self.writer.write_u8(if v {1} else {0}).map_err(wrap_io) } - fn emit_f64(&mut self, v: f64) -> EwResult { - self.writer.write_be_f64(v) + fn emit_f64(&mut self, v: f64) -> EncodingResult<()> { + self.writer.write_be_f64(v).map_err(wrap_io) } - fn emit_f32(&mut self, v: f32) -> EwResult { - self.writer.write_be_f32(v) + fn emit_f32(&mut self, v: f32) -> EncodingResult<()> { + self.writer.write_be_f32(v).map_err(wrap_io) } - fn emit_char(&mut self, v: char) -> EwResult { - self.writer.write_char(v) + fn emit_char(&mut self, v: char) -> EncodingResult<()> { + self.writer.write_char(v).map_err(wrap_io) } - fn emit_str(&mut self, v: &str) -> EwResult { + fn emit_str(&mut self, v: &str) -> EncodingResult<()> { try!(self.emit_uint(v.len())); - self.writer.write_str(v) + self.writer.write_str(v).map_err(wrap_io) } - fn emit_enum(&mut self, __: &str, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_enum(&mut self, __: &str, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } fn emit_enum_variant(&mut self, _: &str, v_id: uint, _: uint, - f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { let max_u32: u32 = Int::max_value(); if v_id > (max_u32 as uint) { panic!("Variant tag doesn't fit in a u32") @@ -87,80 +142,221 @@ impl<'a, W: Writer> Encoder for EncoderWriter<'a, W> { try!(self.emit_u32(v_id as u32)); f(self) } - fn emit_enum_variant_arg(&mut self, _: uint, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_enum_variant_arg(&mut self, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } fn emit_enum_struct_variant(&mut self, _: &str, _: uint, _: uint, - f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } fn emit_enum_struct_variant_field(&mut self, _: &str, _: uint, - f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_struct(&mut self, _: &str, _: uint, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_struct(&mut self, _: &str, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_struct_field(&mut self, _: &str, _: uint, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_struct_field(&mut self, _: &str, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_tuple(&mut self, _: uint, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_tuple(&mut self, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_tuple_arg(&mut self, _: uint, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_tuple_arg(&mut self, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_tuple_struct(&mut self, _: &str, len: uint, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_tuple_struct(&mut self, _: &str, len: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { self.emit_tuple(len, f) } - fn emit_tuple_struct_arg(&mut self, f_idx: uint, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_tuple_struct_arg(&mut self, f_idx: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { self.emit_tuple_arg(f_idx, f) } - fn emit_option(&mut self, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_option(&mut self, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_option_none(&mut self) -> EwResult { - self.writer.write_u8(0) + fn emit_option_none(&mut self) -> EncodingResult<()> { + self.writer.write_u8(0).map_err(wrap_io) } - fn emit_option_some(&mut self, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { - try!(self.writer.write_u8(1)); + fn emit_option_some(&mut self, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { + try!(self.writer.write_u8(1).map_err(wrap_io)); f(self) } - fn emit_seq(&mut self, len: uint, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_seq(&mut self, len: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { try!(self.emit_uint(len)); f(self) } - fn emit_seq_elt(&mut self, _: uint, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_seq_elt(&mut self, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_map(&mut self, len: uint, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_map(&mut self, len: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { try!(self.emit_uint(len)); f(self) } - fn emit_map_elt_key(&mut self, _: uint, mut f: F) -> EwResult where - F: FnMut(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_map_elt_key(&mut self, _: uint, mut f: F) -> EncodingResult<()> where + F: FnMut(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } - fn emit_map_elt_val(&mut self, _: uint, f: F) -> EwResult where - F: FnOnce(&mut EncoderWriter<'a, W>) -> EwResult { + fn emit_map_elt_val(&mut self, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut EncoderWriter<'a, W>) -> EncodingResult<()> { f(self) } + +} + +impl Encoder for SizeChecker { + type Error = EncodingError; + + fn emit_nil(&mut self) -> EncodingResult<()> { Ok(()) } + fn emit_uint(&mut self, v: uint) -> EncodingResult<()> { + self.add_value(v as u64) + } + fn emit_u64(&mut self, v: u64) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_u32(&mut self, v: u32) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_u16(&mut self, v: u16) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_u8(&mut self, v: u8) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_int(&mut self, v: int) -> EncodingResult<()> { + self.add_value(v as i64) + } + fn emit_i64(&mut self, v: i64) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_i32(&mut self, v: i32) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_i16(&mut self, v: i16) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_i8(&mut self, v: i8) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_bool(&mut self, _: bool) -> EncodingResult<()> { + self.add_value(0 as u8) + } + fn emit_f64(&mut self, v: f64) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_f32(&mut self, v: f32) -> EncodingResult<()> { + self.add_value(v) + } + fn emit_char(&mut self, v: char) -> EncodingResult<()> { + self.add_raw(v.len_utf8()) + } + fn emit_str(&mut self, v: &str) -> EncodingResult<()> { + self.add_raw(v.len()) + } + fn emit_enum(&mut self, __: &str, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_enum_variant(&mut self, _: &str, + v_id: uint, + _: uint, + f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + try!(self.add_value(v_id as u32)); + f(self) + } + fn emit_enum_variant_arg(&mut self, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_enum_struct_variant(&mut self, _: &str, + _: uint, + _: uint, + f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_enum_struct_variant_field(&mut self, + _: &str, + _: uint, + f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_struct(&mut self, _: &str, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_struct_field(&mut self, _: &str, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_tuple(&mut self, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_tuple_arg(&mut self, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_tuple_struct(&mut self, _: &str, len: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + self.emit_tuple(len, f) + } + fn emit_tuple_struct_arg(&mut self, f_idx: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + self.emit_tuple_arg(f_idx, f) + } + fn emit_option(&mut self, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_option_none(&mut self) -> EncodingResult<()> { + self.add_value(0 as u8) + } + fn emit_option_some(&mut self, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + try!(self.add_value(1 as u8)); + f(self) + } + fn emit_seq(&mut self, len: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + try!(self.emit_uint(len)); + f(self) + } + fn emit_seq_elt(&mut self, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_map(&mut self, len: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + try!(self.emit_uint(len)); + f(self) + } + fn emit_map_elt_key(&mut self, _: uint, mut f: F) -> EncodingResult<()> where + F: FnMut(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + fn emit_map_elt_val(&mut self, _: uint, f: F) -> EncodingResult<()> where + F: FnOnce(&mut SizeChecker) -> EncodingResult<()> { + f(self) + } + }