diff --git a/src/ws/client.rs b/src/ws/client.rs index 18789fef8..c79451c92 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -273,10 +273,21 @@ impl Client { } } +enum ContinuationOpCode { + Binary, + Text +} + +struct Continuation { + opcode: ContinuationOpCode, + buffer: Vec, +} + struct Inner { tx: UnboundedSender, rx: PayloadBuffer>, closed: bool, + continuation: Option, } /// Future that implementes client websocket handshake process. @@ -433,6 +444,7 @@ impl Future for ClientHandshake { tx: self.tx.take().unwrap(), rx: PayloadBuffer::new(resp.payload()), closed: false, + continuation: None, }; let inner = Rc::new(RefCell::new(inner)); @@ -475,13 +487,46 @@ impl Stream for ClientReader { // read match Frame::parse(&mut inner.rx, no_masking, max_size) { Ok(Async::Ready(Some(frame))) => { - let (_finished, opcode, payload) = frame.unpack(); + let (finished, opcode, payload) = frame.unpack(); match opcode { - // continuation is not supported OpCode::Continue => { - inner.closed = true; - Err(ProtocolError::NoContinuation) + if !finished { + let inner = &mut *inner; + match inner.continuation { + Some(ref mut continuation) => { + continuation.buffer.append(&mut Vec::from(payload.as_ref())); + Ok(Async::NotReady) + } + None => { + inner.closed = true; + Err(ProtocolError::BadContinuation) + } + } + } else { + match inner.continuation.take() { + Some(Continuation {opcode, mut buffer}) => { + buffer.append(&mut Vec::from(payload.as_ref())); + match opcode { + ContinuationOpCode::Binary => + Ok(Async::Ready(Some(Message::Binary(Binary::from(buffer))))), + ContinuationOpCode::Text => { + match String::from_utf8(buffer) { + Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))), + Err(_) => { + inner.closed = true; + Err(ProtocolError::BadEncoding) + } + } + } + } + } + None => { + inner.closed = true; + Err(ProtocolError::BadContinuation) + } + } + } } OpCode::Bad => { inner.closed = true; @@ -498,15 +543,33 @@ impl Stream for ClientReader { OpCode::Pong => Ok(Async::Ready(Some(Message::Pong( String::from_utf8_lossy(payload.as_ref()).into(), )))), - OpCode::Binary => Ok(Async::Ready(Some(Message::Binary(payload)))), + OpCode::Binary => { + if finished { + Ok(Async::Ready(Some(Message::Binary(payload)))) + } else { + inner.continuation = Some(Continuation { + opcode: ContinuationOpCode::Binary, + buffer: Vec::from(payload.as_ref()) + }); + Ok(Async::NotReady) + } + } OpCode::Text => { - let tmp = Vec::from(payload.as_ref()); - match String::from_utf8(tmp) { - Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))), - Err(_) => { - inner.closed = true; - Err(ProtocolError::BadEncoding) + if finished { + let tmp = Vec::from(payload.as_ref()); + match String::from_utf8(tmp) { + Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))), + Err(_) => { + inner.closed = true; + Err(ProtocolError::BadEncoding) + } } + } else { + inner.continuation = Some(Continuation { + opcode: ContinuationOpCode::Text, + buffer: Vec::from(payload.as_ref()) + }); + Ok(Async::NotReady) } } } diff --git a/src/ws/mod.rs b/src/ws/mod.rs index c16f8d6d2..0a65175c5 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -88,9 +88,9 @@ pub enum ProtocolError { /// A payload reached size limit. #[fail(display = "A payload reached size limit.")] Overflow, - /// Continuation is not supported - #[fail(display = "Continuation is not supported.")] - NoContinuation, + /// Bad continuation frame sequence. + #[fail(display = "Bad continuation frame sequence.")] + BadContinuation, /// Bad utf-8 encoding #[fail(display = "Bad utf-8 encoding.")] BadEncoding, @@ -250,11 +250,22 @@ pub fn handshake( .take()) } +enum ContinuationOpCode { + Binary, + Text +} + +struct Continuation { + opcode: ContinuationOpCode, + buffer: Vec, +} + /// Maps `Payload` stream into stream of `ws::Message` items pub struct WsStream { rx: PayloadBuffer, closed: bool, max_size: usize, + continuation: Option, } impl WsStream @@ -267,6 +278,7 @@ where rx: PayloadBuffer::new(stream), closed: false, max_size: 65_536, + continuation: None } } @@ -279,6 +291,8 @@ where } } + + impl Stream for WsStream where S: Stream, @@ -295,14 +309,44 @@ where Ok(Async::Ready(Some(frame))) => { let (finished, opcode, payload) = frame.unpack(); - // continuation is not supported - if !finished { - self.closed = true; - return Err(ProtocolError::NoContinuation); - } - match opcode { - OpCode::Continue => Err(ProtocolError::NoContinuation), + OpCode::Continue => { + if !finished { + match self.continuation { + Some(ref mut continuation) => { + continuation.buffer.append(&mut Vec::from(payload.as_ref())); + Ok(Async::NotReady) + } + None => { + self.closed = true; + Err(ProtocolError::BadContinuation) + } + } + } else { + match self.continuation.take() { + Some(Continuation {opcode, mut buffer}) => { + buffer.append(&mut Vec::from(payload.as_ref())); + match opcode { + ContinuationOpCode::Binary => + Ok(Async::Ready(Some(Message::Binary(Binary::from(buffer))))), + ContinuationOpCode::Text => { + match String::from_utf8(buffer) { + Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))), + Err(_) => { + self.closed = true; + Err(ProtocolError::BadEncoding) + } + } + } + } + } + None => { + self.closed = true; + Err(ProtocolError::BadContinuation) + } + } + } + } OpCode::Bad => { self.closed = true; Err(ProtocolError::BadOpCode) @@ -318,15 +362,33 @@ where OpCode::Pong => Ok(Async::Ready(Some(Message::Pong( String::from_utf8_lossy(payload.as_ref()).into(), )))), - OpCode::Binary => Ok(Async::Ready(Some(Message::Binary(payload)))), + OpCode::Binary => { + if finished { + Ok(Async::Ready(Some(Message::Binary(payload)))) + } else { + self.continuation = Some(Continuation { + opcode: ContinuationOpCode::Binary, + buffer: Vec::from(payload.as_ref()) + }); + Ok(Async::NotReady) + } + } OpCode::Text => { - let tmp = Vec::from(payload.as_ref()); - match String::from_utf8(tmp) { - Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))), - Err(_) => { - self.closed = true; - Err(ProtocolError::BadEncoding) + if finished { + let tmp = Vec::from(payload.as_ref()); + match String::from_utf8(tmp) { + Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))), + Err(_) => { + self.closed = true; + Err(ProtocolError::BadEncoding) + } } + } else { + self.continuation = Some(Continuation { + opcode: ContinuationOpCode::Text, + buffer: Vec::from(payload.as_ref()) + }); + Ok(Async::NotReady) } } }