diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index 7b5b58009..4c726414f 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -35,14 +35,21 @@ pub enum Frame { Pong(String), /// Close message with optional reason Close(Option), + /// First frame of a fragmented text message + BeginText(Option), + /// First frame of a fragmented binary message + BeginBinary(Option), + /// Subsequent frame of a fragmented message + Continue(Option), + /// Final frame of a fragmented message + End(Option), } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] /// WebSockets protocol codec pub struct Codec { max_size: usize, server: bool, - collector: Option, } impl Codec { @@ -51,7 +58,6 @@ impl Codec { Codec { max_size: 65_536, server: true, - collector: None, } } @@ -104,25 +110,29 @@ impl Decoder for Codec { fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { match Parser::parse(src, self.server, self.max_size) { Ok(Some((finished, opcode, payload))) => { + if !finished { + return match opcode { + OpCode::Continue => { + Ok(Some(Frame::Continue(payload))) + } + OpCode::Binary => { + Ok(Some(Frame::BeginBinary(payload))) + } + OpCode::Text => { + Ok(Some(Frame::BeginText(payload))) + } + _ => { + Err(ProtocolError::NoContinuation) + } + }; + } + match opcode { OpCode::Continue => { - match self.collector { - Some(ref mut prev) => { - if let Some(ref payload) = payload { - prev.extend_from_slice(payload); - } - } - None => self.collector = payload, - } - - if finished { - Ok(Some(Frame::Binary(self.collector.take()))) - } else { - Ok(None) - } + Ok(Some(Frame::End(payload))) } OpCode::Bad => { - error!("Bad opcode"); + debug!("Bad opcode"); Err(ProtocolError::BadOpCode) } OpCode::Close => { @@ -149,25 +159,10 @@ impl Decoder for Codec { } } OpCode::Binary => { - if finished { - Ok(Some(Frame::Binary(payload))) - } else { - self.collector = payload; - Ok(None) - } + Ok(Some(Frame::Binary(payload))) } OpCode::Text => { - if finished { - Ok(Some(Frame::Text(payload))) - } else { - self.collector = payload; - Ok(None) - } - //let tmp = Vec::from(payload.as_ref()); - //match String::from_utf8(tmp) { - // Ok(s) => Ok(Some(Message::Text(s))), - // Err(_) => Err(ProtocolError::BadEncoding), - //} + Ok(Some(Frame::Text(payload))) } } } diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 891d5110d..907080b6f 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -48,8 +48,8 @@ pub enum ProtocolError { #[display(fmt = "Continuation is not supported.")] NoContinuation, /// Bad utf-8 encoding - #[display(fmt = "Bad utf-8 encoding.")] - BadEncoding, + #[display(fmt = "Bad utf-8 encoding: {}", _0)] + BadEncoding(std::str::Utf8Error), /// Io error #[display(fmt = "io error: {}", _0)] Io(io::Error), diff --git a/actix-web-actors/src/ws.rs b/actix-web-actors/src/ws.rs index a99de1f56..e781570b7 100644 --- a/actix-web-actors/src/ws.rs +++ b/actix-web-actors/src/ws.rs @@ -288,7 +288,7 @@ where inner: ContextParts::new(mb.sender_producer()), messages: VecDeque::new(), }; - ctx.add_stream(WsStream::new(stream, codec.clone())); + ctx.add_stream(WsStream::new(stream, codec)); WebsocketContextFut::new(ctx, actor, mb, codec) } @@ -453,10 +453,35 @@ where } } +enum Collector { + Text(BytesMut), + Binary(BytesMut), + None +} + +impl Collector { + fn take(&mut self) -> Collector { + std::mem::replace(self, Collector::None) + } + + fn is_none(&self) -> bool { + match self { + Collector::None => true, + _ => false, + } + } +} + struct WsStream { + /// Source stream stream: S, + /// WS codec decoder: Codec, + /// Buffer collecting data to be parsed buf: BytesMut, + /// Collector used to concatenate fragmented messages + collector: Collector, + /// Whether or not the stream is closed closed: bool, } @@ -469,6 +494,7 @@ where stream, decoder: codec, buf: BytesMut::new(), + collector: Collector::None, closed: false, } } @@ -514,24 +540,73 @@ where Some(frm) => { let msg = match frm { Frame::Text(data) => { - if let Some(data) = data { - Message::Text( - std::str::from_utf8(&data) - .map_err(|_| ProtocolError::BadEncoding)? - .to_string(), - ) + Some(if let Some(data) = data { + Message::Text(std::str::from_utf8(&data)?.to_string()) } else { Message::Text(String::new()) + }) + } + Frame::Binary(data) => Some(Message::Binary( + data.map(|b| b.freeze()).unwrap_or_else(Bytes::new), + )), + Frame::Ping(s) => Some(Message::Ping(s)), + Frame::Pong(s) => Some(Message::Pong(s)), + Frame::Close(reason) => Some(Message::Close(reason)), + Frame::BeginText(data) => { + let data = data.unwrap_or_else(|| BytesMut::new()); + + if self.collector.is_none() { + // Previous collection was not finalized + return Err(ProtocolError::NoContinuation); + } + + self.collector = Collector::Text(data); + None + } + Frame::BeginBinary(data) => { + let data = data.unwrap_or_else(|| BytesMut::new()); + + if self.collector.is_none() { + // Previous collection was not finalized + return Err(ProtocolError::NoContinuation); + } + + self.collector = Collector::Binary(data); + None + } + Frame::Continue(data) => { + let data = data.as_ref().map(|d| &**d).unwrap_or_else(|| &[]); + + match self.collector { + Collector::Text(ref mut buf) | Collector::Binary(ref mut buf) => { + buf.extend_from_slice(data); + } + // Uninitialized continuation + _ => return Err(ProtocolError::NoContinuation), + } + + None + } + Frame::End(data) => { + let data = data.as_ref().map(|d| &**d).unwrap_or_else(|| &[]); + + match self.collector.take() { + Collector::Text(mut buf) => { + buf.extend_from_slice(data); + Some(Message::Text( + std::str::from_utf8(&buf)?.to_string() + )) + } + Collector::Binary(mut buf) => { + buf.extend_from_slice(data); + Some(Message::Binary(buf.freeze())) + } + // Uninitialized continuation + Collector::None => return Err(ProtocolError::NoContinuation), } } - Frame::Binary(data) => Message::Binary( - data.map(|b| b.freeze()).unwrap_or_else(Bytes::new), - ), - Frame::Ping(s) => Message::Ping(s), - Frame::Pong(s) => Message::Pong(s), - Frame::Close(reason) => Message::Close(reason), }; - Ok(Async::Ready(Some(msg))) + Ok(Async::Ready(msg)) } } }