diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index 99edfd773..032db31ea 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -35,13 +35,17 @@ pub enum Frame { Pong(Bytes), /// Close message with optional reason Close(Option), + /// Active continuation + Continue, } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] /// WebSockets protocol codec pub struct Codec { max_size: usize, server: bool, + cont_code: Option, + buf: Vec, } impl Codec { @@ -50,6 +54,8 @@ impl Codec { Codec { max_size: 65_536, server: true, + buf: vec![], + cont_code: None, } } @@ -68,6 +74,27 @@ impl Codec { self.server = false; self } + + fn combine_payload(&mut self, payload: Option) -> Option { + let mut size: usize = if let Some(ref pl) = payload { + pl.len() + } else { + 0 + }; + size += self.buf.iter().map(|pl| pl.len()).sum::(); + if size > 0 { + let mut res = BytesMut::with_capacity(size); + for pl in self.buf.drain(..) { + res.extend_from_slice(&pl) + } + if let Some(pl) = payload { + res.extend_from_slice(&pl) + } + Some(res) + } else { + None + } + } } impl Encoder for Codec { @@ -101,12 +128,7 @@ impl Decoder for Codec { fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { match Parser::parse(src, self.server, self.max_size) { - Ok(Some((finished, rsv, opcode, payload))) => { - // continuation is not supported - if !finished { - return Err(ProtocolError::NoContinuation); - } - + Ok(Some((finished, rsv, mut opcode, mut payload))) => { // Since this is the default codec we have no extension // and should fail if rsv is set. // In an async context this will cause a NON-STRICT @@ -116,8 +138,49 @@ impl Decoder for Codec { return Err(ProtocolError::RSVSet); } + if !finished { + if (opcode == OpCode::Text || opcode == OpCode::Binary) + && self.cont_code.is_none() + { + // We are starting a new continuation + self.cont_code = Some(opcode); + if let Some(pl) = payload { + self.buf.push(pl); + } + return Ok(Some(Frame::Continue)); + } else if opcode == OpCode::Continue && self.cont_code.is_some() { + // We continue a continuation + if let Some(pl) = payload { + self.buf.push(pl); + }; + return Ok(Some(Frame::Continue)); + } else { + return Err(ProtocolError::NoContinuation); + } + } else if opcode == OpCode::Continue { + // We finish a continuation + if let Some(orig_opcode) = self.cont_code { + // reset saved opcode + self.cont_code = None; + // put cached code into current opciode + opcode = orig_opcode; + // Collect the payload + payload = self.combine_payload(payload) + } else { + // We have a continuation finish op code but nothing to continue, + // this is an error + return Err(ProtocolError::NoContinuation); + } + } else if self.cont_code.is_some() + && (opcode == OpCode::Binary || opcode == OpCode::Text) + { + // We are finished but this isn't a continuation and + // we still have a started continuation + return Err(ProtocolError::NoContinuation); + } + match opcode { - OpCode::Continue => Err(ProtocolError::NoContinuation), + OpCode::Continue => unreachable!(), OpCode::Bad => Err(ProtocolError::BadOpCode), OpCode::Close => { if let Some(ref pl) = payload { diff --git a/actix-web-actors/src/ws.rs b/actix-web-actors/src/ws.rs index b6b627962..e0b7877f7 100644 --- a/actix-web-actors/src/ws.rs +++ b/actix-web-actors/src/ws.rs @@ -288,7 +288,7 @@ where inner: ContextParts::new(mb.sender_producer()), messages: VecDeque::new(), }; - ctx.add_stream(WsStream::new(stream, codec)); + ctx.add_stream(WsStream::new(stream, codec.clone())); WebsocketContextFut::new(ctx, actor, mb, codec) } @@ -530,6 +530,7 @@ where Frame::Ping(s) => Message::Ping(s), Frame::Pong(s) => Message::Pong(s), Frame::Close(reason) => Message::Close(reason), + Frame::Continue => Message::Nop, }; Ok(Async::Ready(Some(msg))) }