From f6cbf9a70d4484c023f8f68a51f22fc556b9e047 Mon Sep 17 00:00:00 2001 From: Ty Overby Date: Wed, 7 Jan 2015 22:04:54 -0800 Subject: [PATCH] add bounds checking to DecoderReader --- src/reader.rs | 44 ++++++++++++++++++++++++++++++++++++++++---- src/test.rs | 13 ++++++++++++- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/src/reader.rs b/src/reader.rs index 9ff116a..953adba 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -1,4 +1,5 @@ use std::io::{Buffer, Reader, IoError, IoResult, OtherIoError}; +use std::num::{cast, NumCast}; use std::error::Error; use rustc_serialize::Decoder; @@ -7,18 +8,40 @@ use super::SizeLimit; pub struct DecoderReader<'a, R: 'a> { reader: &'a mut R, - size_limit: SizeLimit + size_limit: SizeLimit, + read: u64 } impl<'a, R: Reader+Buffer> DecoderReader<'a, R> { pub fn new(r: &'a mut R, size_limit: SizeLimit) -> DecoderReader<'a, R> { DecoderReader { reader: r, - size_limit: size_limit + size_limit: size_limit, + read: 0 } } } +impl <'a, A> DecoderReader<'a, A> { + fn read_bytes(&mut self, count: I) -> Result<(), IoError> + where I: NumCast { + self.read += cast(count).unwrap(); + match self.size_limit { + SizeLimit::Infinite => Ok(()), + SizeLimit::UpperBound(x) if self.read <= x => Ok(()), + SizeLimit::UpperBound(_) => Err(IoError{ + kind: OtherIoError, + desc: "The max number of bytes has been read from this reader.", + detail: None + }) + } + } + fn read_type(&mut self) -> Result<(), IoError> { + use std::intrinsics::size_of; + unsafe{ self.read_bytes(size_of::()) } + } +} + impl<'a, R: Reader+Buffer> Decoder for DecoderReader<'a, R> { type Error = IoError; @@ -29,49 +52,62 @@ impl<'a, R: Reader+Buffer> Decoder for DecoderReader<'a, R> { self.read_u64().map(|x| x as uint) } fn read_u64(&mut self) -> IoResult { + try!(self.read_type::()); self.reader.read_be_u64() } fn read_u32(&mut self) -> IoResult { + try!(self.read_type::()); self.reader.read_be_u32() } fn read_u16(&mut self) -> IoResult { + try!(self.read_type::()); self.reader.read_be_u16() } fn read_u8(&mut self) -> IoResult { + try!(self.read_type::()); self.reader.read_u8() } fn read_int(&mut self) -> IoResult { self.read_i64().map(|x| x as int) } fn read_i64(&mut self) -> IoResult { + try!(self.read_type::()); self.reader.read_be_i64() } fn read_i32(&mut self) -> IoResult { + try!(self.read_type::()); self.reader.read_be_i32() } fn read_i16(&mut self) -> IoResult { + try!(self.read_type::()); self.reader.read_be_i16() } fn read_i8(&mut self) -> IoResult { + try!(self.read_type::()); self.reader.read_i8() } fn read_bool(&mut self) -> IoResult { - match try!(self.reader.read_i8()) { + match try!(self.read_i8()) { 1 => Ok(true), _ => Ok(false) } } fn read_f64(&mut self) -> IoResult { + try!(self.read_type::()); self.reader.read_be_f64() } fn read_f32(&mut self) -> IoResult { + try!(self.read_type::()); self.reader.read_be_f32() } fn read_char(&mut self) -> IoResult { + try!(self.read_type::()); self.reader.read_char() } fn read_str(&mut self) -> IoResult { let len = try!(self.read_uint()); + + try!(self.read_bytes(len)); let vector = try!(self.reader.read_exact(len)); String::from_utf8(vector).map_err(|e| IoError { kind: OtherIoError, @@ -143,7 +179,7 @@ impl<'a, R: Reader+Buffer> Decoder for DecoderReader<'a, R> { } fn read_option(&mut self, mut f: F) -> IoResult where F: FnMut(&mut DecoderReader<'a, R>, bool) -> IoResult { - match try!(self.reader.read_u8()) { + match try!(self.read_u8()) { 1 => f(self, true), _ => f(self, false) } diff --git a/src/test.rs b/src/test.rs index ff41263..017287c 100644 --- a/src/test.rs +++ b/src/test.rs @@ -14,7 +14,7 @@ use super::{ encode, decode, }; -use super::SizeLimit::Infinite; +use super::SizeLimit::{Infinite, UpperBound}; fn the_same<'a, V>(element: V) where V: Encodable, V: Decodable, V: PartialEq, V: Show { assert!(element == decode(encode(&element, Infinite).unwrap(), Infinite).unwrap()); @@ -175,3 +175,14 @@ fn bad_unicode() { assert!(decoded.is_err()); } + +#[test] +fn too_big_decode() { + let encoded = vec![0,0,0,3]; + let decoded: Result = decode(encoded, UpperBound(3)); + assert!(decoded.is_err()); + + let encoded = vec![0,0,0,3]; + let decoded: Result = decode(encoded, UpperBound(4)); + assert!(decoded.is_ok()); +}