add bounds checking to DecoderReader

This commit is contained in:
Ty Overby 2015-01-07 22:04:54 -08:00
parent b4023f5281
commit f6cbf9a70d
2 changed files with 52 additions and 5 deletions

View File

@ -1,4 +1,5 @@
use std::io::{Buffer, Reader, IoError, IoResult, OtherIoError}; use std::io::{Buffer, Reader, IoError, IoResult, OtherIoError};
use std::num::{cast, NumCast};
use std::error::Error; use std::error::Error;
use rustc_serialize::Decoder; use rustc_serialize::Decoder;
@ -7,18 +8,40 @@ use super::SizeLimit;
pub struct DecoderReader<'a, R: 'a> { pub struct DecoderReader<'a, R: 'a> {
reader: &'a mut R, reader: &'a mut R,
size_limit: SizeLimit size_limit: SizeLimit,
read: u64
} }
impl<'a, R: Reader+Buffer> DecoderReader<'a, R> { impl<'a, R: Reader+Buffer> DecoderReader<'a, R> {
pub fn new(r: &'a mut R, size_limit: SizeLimit) -> DecoderReader<'a, R> { pub fn new(r: &'a mut R, size_limit: SizeLimit) -> DecoderReader<'a, R> {
DecoderReader { DecoderReader {
reader: r, reader: r,
size_limit: size_limit size_limit: size_limit,
read: 0
} }
} }
} }
impl <'a, A> DecoderReader<'a, A> {
fn read_bytes<I>(&mut self, count: I) -> Result<(), IoError>
where I: NumCast {
self.read += cast(count).unwrap();
match self.size_limit {
SizeLimit::Infinite => Ok(()),
SizeLimit::UpperBound(x) if self.read <= x => Ok(()),
SizeLimit::UpperBound(_) => Err(IoError{
kind: OtherIoError,
desc: "The max number of bytes has been read from this reader.",
detail: None
})
}
}
fn read_type<T>(&mut self) -> Result<(), IoError> {
use std::intrinsics::size_of;
unsafe{ self.read_bytes(size_of::<T>()) }
}
}
impl<'a, R: Reader+Buffer> Decoder for DecoderReader<'a, R> { impl<'a, R: Reader+Buffer> Decoder for DecoderReader<'a, R> {
type Error = IoError; type Error = IoError;
@ -29,49 +52,62 @@ impl<'a, R: Reader+Buffer> Decoder for DecoderReader<'a, R> {
self.read_u64().map(|x| x as uint) self.read_u64().map(|x| x as uint)
} }
fn read_u64(&mut self) -> IoResult<u64> { fn read_u64(&mut self) -> IoResult<u64> {
try!(self.read_type::<u64>());
self.reader.read_be_u64() self.reader.read_be_u64()
} }
fn read_u32(&mut self) -> IoResult<u32> { fn read_u32(&mut self) -> IoResult<u32> {
try!(self.read_type::<u32>());
self.reader.read_be_u32() self.reader.read_be_u32()
} }
fn read_u16(&mut self) -> IoResult<u16> { fn read_u16(&mut self) -> IoResult<u16> {
try!(self.read_type::<u16>());
self.reader.read_be_u16() self.reader.read_be_u16()
} }
fn read_u8(&mut self) -> IoResult<u8> { fn read_u8(&mut self) -> IoResult<u8> {
try!(self.read_type::<u8>());
self.reader.read_u8() self.reader.read_u8()
} }
fn read_int(&mut self) -> IoResult<int> { fn read_int(&mut self) -> IoResult<int> {
self.read_i64().map(|x| x as int) self.read_i64().map(|x| x as int)
} }
fn read_i64(&mut self) -> IoResult<i64> { fn read_i64(&mut self) -> IoResult<i64> {
try!(self.read_type::<i64>());
self.reader.read_be_i64() self.reader.read_be_i64()
} }
fn read_i32(&mut self) -> IoResult<i32> { fn read_i32(&mut self) -> IoResult<i32> {
try!(self.read_type::<i32>());
self.reader.read_be_i32() self.reader.read_be_i32()
} }
fn read_i16(&mut self) -> IoResult<i16> { fn read_i16(&mut self) -> IoResult<i16> {
try!(self.read_type::<i16>());
self.reader.read_be_i16() self.reader.read_be_i16()
} }
fn read_i8(&mut self) -> IoResult<i8> { fn read_i8(&mut self) -> IoResult<i8> {
try!(self.read_type::<i8>());
self.reader.read_i8() self.reader.read_i8()
} }
fn read_bool(&mut self) -> IoResult<bool> { fn read_bool(&mut self) -> IoResult<bool> {
match try!(self.reader.read_i8()) { match try!(self.read_i8()) {
1 => Ok(true), 1 => Ok(true),
_ => Ok(false) _ => Ok(false)
} }
} }
fn read_f64(&mut self) -> IoResult<f64> { fn read_f64(&mut self) -> IoResult<f64> {
try!(self.read_type::<f64>());
self.reader.read_be_f64() self.reader.read_be_f64()
} }
fn read_f32(&mut self) -> IoResult<f32> { fn read_f32(&mut self) -> IoResult<f32> {
try!(self.read_type::<f32>());
self.reader.read_be_f32() self.reader.read_be_f32()
} }
fn read_char(&mut self) -> IoResult<char> { fn read_char(&mut self) -> IoResult<char> {
try!(self.read_type::<char>());
self.reader.read_char() self.reader.read_char()
} }
fn read_str(&mut self) -> IoResult<String> { fn read_str(&mut self) -> IoResult<String> {
let len = try!(self.read_uint()); let len = try!(self.read_uint());
try!(self.read_bytes(len));
let vector = try!(self.reader.read_exact(len)); let vector = try!(self.reader.read_exact(len));
String::from_utf8(vector).map_err(|e| IoError { String::from_utf8(vector).map_err(|e| IoError {
kind: OtherIoError, kind: OtherIoError,
@ -143,7 +179,7 @@ impl<'a, R: Reader+Buffer> Decoder for DecoderReader<'a, R> {
} }
fn read_option<T, F>(&mut self, mut f: F) -> IoResult<T> where fn read_option<T, F>(&mut self, mut f: F) -> IoResult<T> where
F: FnMut(&mut DecoderReader<'a, R>, bool) -> IoResult<T> { F: FnMut(&mut DecoderReader<'a, R>, bool) -> IoResult<T> {
match try!(self.reader.read_u8()) { match try!(self.read_u8()) {
1 => f(self, true), 1 => f(self, true),
_ => f(self, false) _ => f(self, false)
} }

View File

@ -14,7 +14,7 @@ use super::{
encode, encode,
decode, decode,
}; };
use super::SizeLimit::Infinite; use super::SizeLimit::{Infinite, UpperBound};
fn the_same<'a, V>(element: V) where V: Encodable, V: Decodable, V: PartialEq, V: Show { fn the_same<'a, V>(element: V) where V: Encodable, V: Decodable, V: PartialEq, V: Show {
assert!(element == decode(encode(&element, Infinite).unwrap(), Infinite).unwrap()); assert!(element == decode(encode(&element, Infinite).unwrap(), Infinite).unwrap());
@ -175,3 +175,14 @@ fn bad_unicode() {
assert!(decoded.is_err()); assert!(decoded.is_err());
} }
#[test]
fn too_big_decode() {
let encoded = vec![0,0,0,3];
let decoded: Result<u32, _> = decode(encoded, UpperBound(3));
assert!(decoded.is_err());
let encoded = vec![0,0,0,3];
let decoded: Result<u32, _> = decode(encoded, UpperBound(4));
assert!(decoded.is_ok());
}