diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index a37208a2b..7233105b1 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -12,8 +12,6 @@ pub enum Message { Text(String), /// Binary message Binary(Bytes), - /// Continuation - Continuation(Item), /// Ping message Ping(Bytes), /// Pong message @@ -92,52 +90,52 @@ impl Codec { } impl Encoder for Codec { - type Item = Message; + type Item = Frame; type Error = ProtocolError; - fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { + fn encode(&mut self, item: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { - Message::Text(txt) => Parser::write_message( + Frame::Text(txt) => Parser::write_frame( dst, txt, OpCode::Text, true, !self.flags.contains(Flags::SERVER), ), - Message::Binary(bin) => Parser::write_message( + Frame::Binary(bin) => Parser::write_frame( dst, bin, OpCode::Binary, true, !self.flags.contains(Flags::SERVER), ), - Message::Ping(txt) => Parser::write_message( + Frame::Ping(txt) => Parser::write_frame( dst, txt, OpCode::Ping, true, !self.flags.contains(Flags::SERVER), ), - Message::Pong(txt) => Parser::write_message( + Frame::Pong(txt) => Parser::write_frame( dst, txt, OpCode::Pong, true, !self.flags.contains(Flags::SERVER), ), - Message::Close(reason) => { + Frame::Close(reason) => { Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER)) } - Message::Continuation(cont) => match cont { + Frame::Continuation(cont) => match cont { Item::FirstText(data) => { if self.flags.contains(Flags::W_CONTINUATION) { return Err(ProtocolError::ContinuationStarted); } else { self.flags.insert(Flags::W_CONTINUATION); - Parser::write_message( + Parser::write_frame( dst, &data[..], - OpCode::Binary, + OpCode::Text, false, !self.flags.contains(Flags::SERVER), ) @@ -148,10 +146,10 @@ impl Encoder for Codec { return Err(ProtocolError::ContinuationStarted); } else { self.flags.insert(Flags::W_CONTINUATION); - Parser::write_message( + Parser::write_frame( dst, &data[..], - OpCode::Text, + OpCode::Binary, false, !self.flags.contains(Flags::SERVER), ) @@ -159,7 +157,7 @@ impl Encoder for Codec { } Item::Continue(data) => { if self.flags.contains(Flags::W_CONTINUATION) { - Parser::write_message( + Parser::write_frame( dst, &data[..], OpCode::Continue, @@ -173,7 +171,7 @@ impl Encoder for Codec { Item::Last(data) => { if self.flags.contains(Flags::W_CONTINUATION) { self.flags.remove(Flags::W_CONTINUATION); - Parser::write_message( + Parser::write_frame( dst, &data[..], OpCode::Continue, @@ -185,7 +183,6 @@ impl Encoder for Codec { } } }, - Message::Nop => (), } Ok(()) } diff --git a/actix-http/src/ws/dispatcher.rs b/actix-http/src/ws/dispatcher.rs index 7a6b11b18..e125dba48 100644 --- a/actix-http/src/ws/dispatcher.rs +++ b/actix-http/src/ws/dispatcher.rs @@ -6,11 +6,11 @@ use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_service::{IntoService, Service}; use actix_utils::framed; -use super::{Codec, Frame, Message}; +use super::{Codec, Frame}; pub struct Dispatcher where - S: Service + 'static, + S: Service + 'static, T: AsyncRead + AsyncWrite, { inner: framed::Dispatcher, @@ -19,7 +19,7 @@ where impl Dispatcher where T: AsyncRead + AsyncWrite, - S: Service, + S: Service, S::Future: 'static, S::Error: 'static, { @@ -39,7 +39,7 @@ where impl Future for Dispatcher where T: AsyncRead + AsyncWrite, - S: Service, + S: Service, S::Future: 'static, S::Error: 'static, { diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs index 8f7004f18..457592897 100644 --- a/actix-http/src/ws/frame.rs +++ b/actix-http/src/ws/frame.rs @@ -153,7 +153,7 @@ impl Parser { } /// Generate binary representation - pub fn write_message>( + pub fn write_frame>( dst: &mut BytesMut, pl: B, op: OpCode, @@ -211,7 +211,7 @@ impl Parser { } }; - Parser::write_message(dst, payload, OpCode::Close, true, mask) + Parser::write_frame(dst, payload, OpCode::Close, true, mask) } } @@ -346,7 +346,7 @@ mod tests { #[test] fn test_ping_frame() { let mut buf = BytesMut::new(); - Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false); + Parser::write_frame(&mut buf, Vec::from("data"), OpCode::Ping, true, false); let mut v = vec![137u8, 4u8]; v.extend(b"data"); @@ -356,7 +356,7 @@ mod tests { #[test] fn test_pong_frame() { let mut buf = BytesMut::new(); - Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false); + Parser::write_frame(&mut buf, Vec::from("data"), OpCode::Pong, true, false); let mut v = vec![138u8, 4u8]; v.extend(b"data"); diff --git a/actix-http/src/ws/frame_iters.rs b/actix-http/src/ws/frame_iters.rs new file mode 100644 index 000000000..05b1c42bd --- /dev/null +++ b/actix-http/src/ws/frame_iters.rs @@ -0,0 +1,221 @@ +use super::{Frame, Item}; + +use bytes::Bytes; +use std::str::Chars; + +/// Convert binary message content into Frame::Continuation types +/// +/// This struct is an iterator over websocket frames +/// with a configurable maximum content size. +/// Original messages that are already within the size +/// limit will be rendered as an iterator over one single +/// binary frame. +/// Original messages that are larger than the size threshold +/// will be converted into an iterator over continuation +/// messages, where the first is a FirstBinary message. +pub struct ContinuationBins<'a> { + original: &'a [u8], + step: usize, + bs_i: usize, + bs_tot: usize, + max_frame_content_bytes: usize, +} + +impl<'a> ContinuationBins<'a> { + pub fn new(original: &'a [u8], max_frame_content_bytes: usize) -> Self { + let bs_tot = original.len(); + + Self { + original, + step: 0, + bs_i: 0, + bs_tot, + max_frame_content_bytes, + } + } +} + +impl<'a> Iterator for ContinuationBins<'a> { + type Item = Frame; + + fn next(&mut self) -> Option { + if self.bs_i >= self.bs_tot { + None + } else if self.bs_tot - self.bs_i <= self.max_frame_content_bytes { + if self.step == 0 { + // if there are fewer than max bytes remaining to send and + // we haven't sent anything yet, no continuation frame needed + self.bs_i += self.max_frame_content_bytes; + Some(Frame::Binary(Bytes::copy_from_slice(&self.original))) + } else { + // otherwise if there are fewer than max bytes remaining to send and + // we've already sent something, we send a final frame + let here = self.bs_i; + self.bs_i += self.max_frame_content_bytes; + Some(Frame::Continuation(Item::Last(Bytes::copy_from_slice( + &self.original[here..self.bs_tot], + )))) + } + } else { + let item = if self.step == 0 { + Item::FirstBinary(Bytes::copy_from_slice( + &self.original[self.bs_i..self.bs_i + self.max_frame_content_bytes], + )) + } else { + Item::Continue(Bytes::copy_from_slice( + &self.original[self.bs_i..self.bs_i + self.max_frame_content_bytes], + )) + }; + self.step += 1; + self.bs_i += self.max_frame_content_bytes; + + Some(Frame::Continuation(item)) + } + } +} + +/// Convert text message content into Frame::Continuation types +/// +/// This struct is an iterator over websocket frames +/// with a configurable maximum content size. +/// Original messages that are already within the size +/// limit will be rendered as an iterator over one single +/// text frame. +/// Original messages that are larger than the size threshold +/// will be converted into an iterator over continuation +/// frames, where the first is a FirstText message. +/// Note that for text frames, the maximum content size is +/// fuzzy -- the actual content size may exceed the configured +/// maximum content size by up to 7 bytes, depending on UTF-8 +/// encoding of the text string. +pub struct ContinuationTexts<'a> { + original: Chars<'a>, + step: usize, + bs_i: usize, + bs_tot: usize, + max_frame_content_bytes: usize, +} + +impl<'a> ContinuationTexts<'a> { + pub fn new(original: Chars<'a>, max_frame_content_bytes: usize) -> Self { + let bs_tot = original.as_str().len(); + + Self { + original, + step: 0, + bs_i: 0, + bs_tot, + max_frame_content_bytes, + } + } +} + +impl<'a> Iterator for ContinuationTexts<'a> { + type Item = Frame; + + fn next(&mut self) -> Option { + if self.bs_i >= self.bs_tot { + None + } else if self.bs_tot - self.bs_i <= self.max_frame_content_bytes { + let bs = Bytes::copy_from_slice(self.original.as_str().as_bytes()); + self.bs_i += self.max_frame_content_bytes; + let frm = if self.step == 0 { + // if there are fewer than max bytes remaining to send and + // we haven't sent anything yet, no continuation frame needed + Frame::Text(bs) + } else { + // otherwise if there are fewer than max bytes remaining to send and + // we've already sent something, we send a final frame + Frame::Continuation(Item::Last(bs)) + }; + Some(frm) + } else { + let mut s = String::new(); + let mut temp_i: usize = 0; + while temp_i < self.max_frame_content_bytes { + let c = self.original.next(); + if let Some(c) = c { + temp_i += c.len_utf8(); + self.bs_i += c.len_utf8(); + s.push(c); + } else { + self.bs_i = self.bs_tot; + break; + } + } + let item = if self.step == 0 { + Item::FirstText(Bytes::copy_from_slice(s.as_bytes())) + } else { + Item::Continue(Bytes::copy_from_slice(s.as_bytes())) + }; + self.step += 1; + + Some(Frame::Continuation(item)) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_continuation_bins() { + // render a single Frame::Binary when max size is greater than the payload len + let mut bins = ContinuationBins::new(b"one two three", 100); + assert_eq!(bins.next(), Some(Frame::Binary("one two three".into()))); + + let mut bins = ContinuationBins::new(b"one two three", 4); + assert_eq!( + bins.next(), + Some(Frame::Continuation(Item::FirstBinary("one ".into()))) + ); + assert_eq!( + bins.next(), + Some(Frame::Continuation(Item::Continue("two ".into()))) + ); + assert_eq!( + bins.next(), + Some(Frame::Continuation(Item::Continue("thre".into()))) + ); + assert_eq!( + bins.next(), + Some(Frame::Continuation(Item::Last("e".into()))) + ); + } + + #[test] + fn test_continuation_texts() { + // render a single Frame::Binary when max size is greater than the payload len + let mut texts = ContinuationTexts::new("one two three".chars(), 100); + assert_eq!(texts.next(), Some(Frame::Text("one two three".into()))); + + let mut texts = ContinuationTexts::new("one two three".chars(), 4); + assert_eq!( + texts.next(), + Some(Frame::Continuation(Item::FirstText("one ".into()))) + ); + assert_eq!( + texts.next(), + Some(Frame::Continuation(Item::Continue("two ".into()))) + ); + assert_eq!( + texts.next(), + Some(Frame::Continuation(Item::Continue("thre".into()))) + ); + assert_eq!( + texts.next(), + Some(Frame::Continuation(Item::Last("e".into()))) + ); + + let mut snowmen = ContinuationTexts::new("⛄⛄⛄".chars(), 5); + assert_eq!( + snowmen.next(), + Some(Frame::Continuation(Item::FirstText("⛄⛄".into()))) + ); + assert_eq!( + snowmen.next(), + Some(Frame::Continuation(Item::Last("⛄".into()))) + ); + } +} diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 3d83943c7..ec12d5ff3 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -15,12 +15,14 @@ use crate::response::{Response, ResponseBuilder}; mod codec; mod dispatcher; mod frame; +mod frame_iters; mod mask; mod proto; pub use self::codec::{Codec, Frame, Item, Message}; pub use self::dispatcher::Dispatcher; pub use self::frame::Parser; +pub use self::frame_iters::{ContinuationBins, ContinuationTexts}; pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; /// Websocket protocol errors diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs index 4b4b8f089..bde1190b3 100644 --- a/actix-http/tests/test_ws.rs +++ b/actix-http/tests/test_ws.rs @@ -67,16 +67,14 @@ where } } -async fn service(msg: ws::Frame) -> Result { +async fn service(msg: ws::Frame) -> Result { let msg = match msg { - ws::Frame::Ping(msg) => ws::Message::Pong(msg), - ws::Frame::Text(text) => { - ws::Message::Text(String::from_utf8_lossy(&text).to_string()) - } - ws::Frame::Binary(bin) => ws::Message::Binary(bin), - ws::Frame::Continuation(item) => ws::Message::Continuation(item), - ws::Frame::Close(reason) => ws::Message::Close(reason), - _ => panic!(), + ws::Frame::Ping(msg) => ws::Frame::Pong(msg), + ws::Frame::Text(bs) => ws::Frame::Text(bs), + ws::Frame::Binary(bin) => ws::Frame::Binary(bin), + ws::Frame::Continuation(item) => ws::Frame::Continuation(item), + ws::Frame::Close(reason) => ws::Frame::Close(reason), + ws::Frame::Pong(_) => panic!(), }; Ok(msg) } @@ -98,27 +96,21 @@ async fn test_simple() { // client service let mut framed = srv.ws().await.unwrap(); - framed - .send(ws::Message::Text("text".to_string())) - .await - .unwrap(); + framed.send(ws::Frame::Text("text".into())).await.unwrap(); let (item, mut framed) = framed.into_future().await; assert_eq!( item.unwrap().unwrap(), ws::Frame::Text(Bytes::from_static(b"text")) ); - framed - .send(ws::Message::Binary("text".into())) - .await - .unwrap(); + framed.send(ws::Frame::Binary("text".into())).await.unwrap(); let (item, mut framed) = framed.into_future().await; assert_eq!( item.unwrap().unwrap(), ws::Frame::Binary(Bytes::from_static(&b"text"[..])) ); - framed.send(ws::Message::Ping("text".into())).await.unwrap(); + framed.send(ws::Frame::Ping("text".into())).await.unwrap(); let (item, mut framed) = framed.into_future().await; assert_eq!( item.unwrap().unwrap(), @@ -126,9 +118,7 @@ async fn test_simple() { ); framed - .send(ws::Message::Continuation(ws::Item::FirstText( - "text".into(), - ))) + .send(ws::Frame::Continuation(ws::Item::FirstText("text".into()))) .await .unwrap(); let (item, mut framed) = framed.into_future().await; @@ -138,20 +128,18 @@ async fn test_simple() { ); assert!(framed - .send(ws::Message::Continuation(ws::Item::FirstText( - "text".into() - ))) + .send(ws::Frame::Continuation(ws::Item::FirstText("text".into()))) .await .is_err()); assert!(framed - .send(ws::Message::Continuation(ws::Item::FirstBinary( + .send(ws::Frame::Continuation(ws::Item::FirstBinary( "text".into() ))) .await .is_err()); framed - .send(ws::Message::Continuation(ws::Item::Continue("text".into()))) + .send(ws::Frame::Continuation(ws::Item::Continue("text".into()))) .await .unwrap(); let (item, mut framed) = framed.into_future().await; @@ -161,7 +149,7 @@ async fn test_simple() { ); framed - .send(ws::Message::Continuation(ws::Item::Last("text".into()))) + .send(ws::Frame::Continuation(ws::Item::Last("text".into()))) .await .unwrap(); let (item, mut framed) = framed.into_future().await; @@ -171,17 +159,17 @@ async fn test_simple() { ); assert!(framed - .send(ws::Message::Continuation(ws::Item::Continue("text".into()))) + .send(ws::Frame::Continuation(ws::Item::Continue("text".into()))) .await .is_err()); assert!(framed - .send(ws::Message::Continuation(ws::Item::Last("text".into()))) + .send(ws::Frame::Continuation(ws::Item::Last("text".into()))) .await .is_err()); framed - .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .send(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) .await .unwrap(); diff --git a/actix-web-actors/src/ws.rs b/actix-web-actors/src/ws.rs index b28aeade4..5e7a545dc 100644 --- a/actix-web-actors/src/ws.rs +++ b/actix-web-actors/src/ws.rs @@ -1,4 +1,6 @@ //! Websocket integration +//use super::{ContinuationBins, ContinuationTexts}; + use std::collections::VecDeque; use std::io; use std::pin::Pin; @@ -16,7 +18,8 @@ use actix::{ use actix_codec::{Decoder, Encoder}; use actix_http::ws::{hash_key, Codec}; pub use actix_http::ws::{ - CloseCode, CloseReason, Frame, HandshakeError, Message, ProtocolError, + CloseCode, CloseReason, ContinuationBins, ContinuationTexts, Frame, HandshakeError, + Item, Message, ProtocolError, }; use actix_web::dev::HttpResponseBuilder; use actix_web::error::{Error, PayloadError}; @@ -26,6 +29,8 @@ use bytes::{Bytes, BytesMut}; use futures::channel::oneshot::Sender; use futures::{Future, Stream}; +const DEFAULT_MAX_FRAME_SIZE: usize = 64_000; + /// Do websocket handshake and start ws actor. pub fn start(actor: A, req: &HttpRequest, stream: T) -> Result where @@ -180,7 +185,8 @@ where A: Actor>, { inner: ContextParts, - messages: VecDeque>, + frames: VecDeque>, + max_frame_content_bytes: usize, } impl ActorContext for WebsocketContext @@ -268,7 +274,8 @@ where let mb = Mailbox::default(); let mut ctx = WebsocketContext { inner: ContextParts::new(mb.sender_producer()), - messages: VecDeque::new(), + frames: VecDeque::new(), + max_frame_content_bytes: DEFAULT_MAX_FRAME_SIZE, }; ctx.add_stream(WsStream::new(stream, Codec::new())); @@ -291,7 +298,8 @@ where let mb = Mailbox::default(); let mut ctx = WebsocketContext { inner: ContextParts::new(mb.sender_producer()), - messages: VecDeque::new(), + frames: VecDeque::new(), + max_frame_content_bytes: DEFAULT_MAX_FRAME_SIZE, }; ctx.add_stream(WsStream::new(stream, codec)); @@ -311,7 +319,8 @@ where let mb = Mailbox::default(); let mut ctx = WebsocketContext { inner: ContextParts::new(mb.sender_producer()), - messages: VecDeque::new(), + frames: VecDeque::new(), + max_frame_content_bytes: DEFAULT_MAX_FRAME_SIZE, }; ctx.add_stream(WsStream::new(stream, Codec::new())); @@ -332,38 +341,46 @@ where /// data you should prefer the `text()` or `binary()` convenience functions /// that handle the framing for you. #[inline] - pub fn write_raw(&mut self, msg: Message) { - self.messages.push_back(Some(msg)); + pub fn write_raw(&mut self, frm: Frame) { + self.frames.push_back(Some(frm)); } /// Send text frame #[inline] pub fn text>(&mut self, text: T) { - self.write_raw(Message::Text(text.into())); + for frm in + ContinuationTexts::new(text.into().chars(), self.max_frame_content_bytes) + { + self.write_raw(frm); + } } /// Send binary frame #[inline] pub fn binary>(&mut self, data: B) { - self.write_raw(Message::Binary(data.into())); + for frm in + ContinuationBins::new(data.into().as_ref(), self.max_frame_content_bytes) + { + self.write_raw(frm); + } } /// Send ping frame #[inline] pub fn ping(&mut self, message: &[u8]) { - self.write_raw(Message::Ping(Bytes::copy_from_slice(message))); + self.write_raw(Frame::Ping(Bytes::copy_from_slice(message))); } /// Send pong frame #[inline] pub fn pong(&mut self, message: &[u8]) { - self.write_raw(Message::Pong(Bytes::copy_from_slice(message))); + self.write_raw(Frame::Pong(Bytes::copy_from_slice(message))); } /// Send close frame #[inline] pub fn close(&mut self, reason: Option) { - self.write_raw(Message::Close(reason)); + self.write_raw(Frame::Close(reason)); } /// Handle of the running future @@ -431,10 +448,10 @@ where let _ = Pin::new(&mut this.fut).poll(cx); } - // encode messages - while let Some(item) = this.fut.ctx().messages.pop_front() { - if let Some(msg) = item { - this.encoder.encode(msg, &mut this.buf)?; + // encode frames + while let Some(item) = this.fut.ctx().frames.pop_front() { + if let Some(frm) = item { + this.encoder.encode(frm, &mut this.buf)?; } else { this.closed = true; break; @@ -462,6 +479,12 @@ where } } +enum ContnBufType { + Unknown, + Text, + Binary, +} + #[pin_project::pin_project] struct WsStream { #[pin] @@ -469,6 +492,8 @@ struct WsStream { decoder: Codec, buf: BytesMut, closed: bool, + continuation_buf: BytesMut, + continuation_buf_type: ContnBufType, } impl WsStream @@ -481,6 +506,8 @@ where decoder: codec, buf: BytesMut::new(), closed: false, + continuation_buf: BytesMut::new(), + continuation_buf_type: ContnBufType::Unknown, } } } @@ -518,7 +545,7 @@ where } } - match this.decoder.decode(this.buf)? { + match this.decoder.decode(&mut this.buf)? { None => { if *this.closed { Poll::Ready(None) @@ -526,26 +553,69 @@ where Poll::Pending } } - Some(frm) => { - let msg = match frm { - Frame::Text(data) => Message::Text( - std::str::from_utf8(&data) - .map_err(|e| { + Some(frm) => match frm { + Frame::Text(data) => { + let txt = std::str::from_utf8(&data) + .map_err(|e| { + ProtocolError::Io(io::Error::new( + io::ErrorKind::Other, + format!("{}", e), + )) + })? + .to_string(); + Poll::Ready(Some(Ok(Message::Text(txt)))) + } + Frame::Binary(data) => Poll::Ready(Some(Ok(Message::Binary(data)))), + Frame::Ping(s) => Poll::Ready(Some(Ok(Message::Ping(s)))), + Frame::Pong(s) => Poll::Ready(Some(Ok(Message::Pong(s)))), + Frame::Close(reason) => Poll::Ready(Some(Ok(Message::Close(reason)))), + Frame::Continuation(item) => match item { + Item::FirstText(bs) => { + this.continuation_buf.clear(); + this.continuation_buf.extend_from_slice(&bs[..]); + *this.continuation_buf_type = ContnBufType::Text; + Poll::Pending + } + Item::FirstBinary(bs) => { + this.continuation_buf.clear(); + this.continuation_buf.extend_from_slice(&bs[..]); + *this.continuation_buf_type = ContnBufType::Binary; + Poll::Pending + } + Item::Continue(bs) => { + this.continuation_buf.extend_from_slice(&bs[..]); + Poll::Pending + } + Item::Last(bs) => { + this.continuation_buf.extend_from_slice(&bs[..]); + match this.continuation_buf_type { + ContnBufType::Text => { + let txt = + std::str::from_utf8(&this.continuation_buf[..]) + .map_err(|e| { + ProtocolError::Io(io::Error::new( + io::ErrorKind::Other, + format!("{}", e), + )) + })? + .to_string(); + Poll::Ready(Some(Ok(Message::Text(txt)))) + } + ContnBufType::Binary => { + let bts = + Bytes::copy_from_slice(&this.continuation_buf[..]); + Poll::Ready(Some(Ok(Message::Binary(bts)))) + } + ContnBufType::Unknown => Poll::Ready(Some(Err( ProtocolError::Io(io::Error::new( io::ErrorKind::Other, - format!("{}", e), - )) - })? - .to_string(), - ), - Frame::Binary(data) => Message::Binary(data), - Frame::Ping(s) => Message::Ping(s), - Frame::Pong(s) => Message::Pong(s), - Frame::Close(reason) => Message::Close(reason), - Frame::Continuation(item) => Message::Continuation(item), - }; - Poll::Ready(Some(Ok(msg))) - } + "Invalid decoder state".to_string(), + )), + ))), + } + } + }, + }, } } } diff --git a/actix-web-actors/tests/test_ws.rs b/actix-web-actors/tests/test_ws.rs index 076e375d3..4e7f33779 100644 --- a/actix-web-actors/tests/test_ws.rs +++ b/actix-web-actors/tests/test_ws.rs @@ -30,8 +30,42 @@ impl StreamHandler> for Ws { async fn test_simple() { let mut srv = test::start(|| { App::new().service(web::resource("/").to( - |req: HttpRequest, stream: web::Payload| { - async move { ws::start(Ws, &req, stream) } + |req: HttpRequest, stream: web::Payload| async move { + ws::start(Ws, &req, stream) + }, + )) + }); + + // client service + let mut framed = srv.ws().await.unwrap(); + framed.send(ws::Frame::Text("text".into())).await.unwrap(); + + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); + + framed.send(ws::Frame::Binary("text".into())).await.unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text").into())); + + framed.send(ws::Frame::Ping("text".into())).await.unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Pong(Bytes::copy_from_slice(b"text"))); + + framed + .send(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); + + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into()))); +} + +#[actix_rt::test] +async fn test_continuation_frames() { + let mut srv = test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::start(Ws, &req, stream) }, )) }); @@ -39,29 +73,68 @@ async fn test_simple() { // client service let mut framed = srv.ws().await.unwrap(); framed - .send(ws::Message::Text("text".to_string())) + .send(ws::Frame::Continuation(ws::Item::FirstText("first".into()))) + .await + .unwrap(); + framed + .send(ws::Frame::Continuation(ws::Item::Last(" text".into()))) .await .unwrap(); - let item = framed.next().await.unwrap().unwrap(); - assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); + assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"first text"))); framed - .send(ws::Message::Binary("text".into())) + .send(ws::Frame::Continuation(ws::Item::FirstBinary( + "first".into(), + ))) + .await + .unwrap(); + framed + .send(ws::Frame::Continuation(ws::Item::Last(" binary".into()))) .await .unwrap(); let item = framed.next().await.unwrap().unwrap(); - assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text").into())); - - framed.send(ws::Message::Ping("text".into())).await.unwrap(); - let item = framed.next().await.unwrap().unwrap(); - assert_eq!(item, ws::Frame::Pong(Bytes::copy_from_slice(b"text"))); + assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"first binary"))); framed - .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .send(ws::Frame::Continuation(ws::Item::FirstText("first".into()))) + .await + .unwrap(); + framed + .send(ws::Frame::Continuation(ws::Item::Continue( + " continuation".into(), + ))) + .await + .unwrap(); + framed + .send(ws::Frame::Continuation(ws::Item::Last(" text".into()))) .await .unwrap(); - let item = framed.next().await.unwrap().unwrap(); - assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into()))); + assert_eq!( + item, + ws::Frame::Text(Bytes::from_static(b"first continuation text")) + ); + + framed + .send(ws::Frame::Continuation(ws::Item::FirstBinary( + "first".into(), + ))) + .await + .unwrap(); + framed + .send(ws::Frame::Continuation(ws::Item::Continue( + " continuation".into(), + ))) + .await + .unwrap(); + framed + .send(ws::Frame::Continuation(ws::Item::Last(" binary".into()))) + .await + .unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!( + item, + ws::Frame::Binary(Bytes::from_static(b"first continuation binary")) + ); }