From 6cd40df38769bd046529c62a56660791b7407b93 Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Mon, 19 Mar 2018 17:27:03 -0700
Subject: [PATCH] Fix server websockets big payloads support

---
 .travis.yml            |   3 -
 CHANGES.md             |   2 +
 src/client/parser.rs   |   4 +-
 src/pipeline.rs        | 280 +++++++++++++++++++++--------------------
 src/server/h1writer.rs |   6 +-
 src/ws/client.rs       |   4 +-
 src/ws/mod.rs          |   3 +-
 tests/test_ws.rs       |  30 ++++-
 8 files changed, 178 insertions(+), 154 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index dfa93d40e..aa7f0c1e5 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -12,9 +12,6 @@ matrix:
     - rust: stable
     - rust: beta
     - rust: nightly
-  allow_failures:
-    - rust: nightly
-    - rust: beta
 
 #rust:
 #  - 1.21.0
diff --git a/CHANGES.md b/CHANGES.md
index 7ff63f669..ab798f06c 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -8,6 +8,8 @@
 
 * Allow to set client websocket handshake timeout
 
+* Fix server websockets big payloads support
+
 
 ## 0.4.9 (2018-03-16)
 
diff --git a/src/client/parser.rs b/src/client/parser.rs
index 6ffcd76e4..8fe399009 100644
--- a/src/client/parser.rs
+++ b/src/client/parser.rs
@@ -145,9 +145,7 @@ impl HttpResponseParser {
         // convert headers
         let mut hdrs = HeaderMap::new();
         for header in headers[..headers_len].iter() {
-            let n_start = header.name.as_ptr() as usize - bytes_ptr;
-            let n_end = n_start + header.name.len();
-            if let Ok(name) = HeaderName::try_from(slice.slice(n_start, n_end)) {
+            if let Ok(name) = HeaderName::try_from(header.name) {
                 let v_start = header.value.as_ptr() as usize - bytes_ptr;
                 let v_end = v_start + header.value.len();
                 let value = unsafe {
diff --git a/src/pipeline.rs b/src/pipeline.rs
index e92e16f54..b5772e9a3 100644
--- a/src/pipeline.rs
+++ b/src/pipeline.rs
@@ -453,167 +453,171 @@ impl<S: 'static, H> ProcessResponse<S, H> {
     fn poll_io(mut self, io: &mut Writer, info: &mut PipelineInfo<S>)
                -> Result<PipelineState<S, H>, PipelineState<S, H>>
     {
-        if self.drain.is_none() && self.running != RunningState::Paused {
-            // if task is paused, write buffer is probably full
-            'outter: loop {
-                let result = match mem::replace(&mut self.iostate, IOState::Done) {
-                    IOState::Response => {
-                        let encoding = self.resp.content_encoding().unwrap_or(info.encoding);
+        loop {
+            if self.drain.is_none() && self.running != RunningState::Paused {
+                // if task is paused, write buffer is probably full
+                'inner: loop {
+                    let result = match mem::replace(&mut self.iostate, IOState::Done) {
+                        IOState::Response => {
+                            let encoding = self.resp.content_encoding().unwrap_or(info.encoding);
 
-                        let result = match io.start(info.req_mut().get_inner(),
-                                                    &mut self.resp, encoding)
-                        {
-                            Ok(res) => res,
-                            Err(err) => {
-                                info.error = Some(err.into());
-                                return Ok(FinishingMiddlewares::init(info, self.resp))
-                            }
-                        };
-
-                        if let Some(err) = self.resp.error() {
-                            if self.resp.status().is_server_error() {
-                                error!("Error occured during request handling: {}", err);
-                            } else {
-                                warn!("Error occured during request handling: {}", err);
-                            }
-                            if log_enabled!(Debug) {
-                                debug!("{:?}", err);
-                            }
-                        }
-
-                        // always poll stream or actor for the first time
-                        match self.resp.replace_body(Body::Empty) {
-                            Body::Streaming(stream) => {
-                                self.iostate = IOState::Payload(stream);
-                                continue
-                            },
-                            Body::Actor(ctx) => {
-                                self.iostate = IOState::Actor(ctx);
-                                continue
-                            },
-                            _ => (),
-                        }
-
-                        result
-                    },
-                    IOState::Payload(mut body) => {
-                        match body.poll() {
-                            Ok(Async::Ready(None)) => {
-                                if let Err(err) = io.write_eof() {
+                            let result = match io.start(info.req_mut().get_inner(),
+                                                        &mut self.resp, encoding)
+                            {
+                                Ok(res) => res,
+                                Err(err) => {
                                     info.error = Some(err.into());
                                     return Ok(FinishingMiddlewares::init(info, self.resp))
                                 }
-                                break
+                            };
+
+                            if let Some(err) = self.resp.error() {
+                                if self.resp.status().is_server_error() {
+                                    error!("Error occured during request handling: {}", err);
+                                } else {
+                                    warn!("Error occured during request handling: {}", err);
+                                }
+                                if log_enabled!(Debug) {
+                                    debug!("{:?}", err);
+                                }
+                            }
+
+                            // always poll stream or actor for the first time
+                            match self.resp.replace_body(Body::Empty) {
+                                Body::Streaming(stream) => {
+                                    self.iostate = IOState::Payload(stream);
+                                    continue 'inner
+                                },
+                            Body::Actor(ctx) => {
+                                self.iostate = IOState::Actor(ctx);
+                                continue 'inner
                             },
-                            Ok(Async::Ready(Some(chunk))) => {
-                                self.iostate = IOState::Payload(body);
-                                match io.write(chunk.into()) {
-                                    Err(err) => {
+                                _ => (),
+                            }
+
+                            result
+                        },
+                        IOState::Payload(mut body) => {
+                            match body.poll() {
+                                Ok(Async::Ready(None)) => {
+                                    if let Err(err) = io.write_eof() {
                                         info.error = Some(err.into());
                                         return Ok(FinishingMiddlewares::init(info, self.resp))
-                                    },
-                                    Ok(result) => result
-                                }
-                            }
-                            Ok(Async::NotReady) => {
-                                self.iostate = IOState::Payload(body);
-                                break
-                            },
-                            Err(err) => {
-                                info.error = Some(err);
-                                return Ok(FinishingMiddlewares::init(info, self.resp))
-                            }
-                        }
-                    },
-                    IOState::Actor(mut ctx) => {
-                        if info.disconnected.take().is_some() {
-                            ctx.disconnected();
-                        }
-                        match ctx.poll() {
-                            Ok(Async::Ready(Some(vec))) => {
-                                if vec.is_empty() {
-                                    self.iostate = IOState::Actor(ctx);
+                                    }
                                     break
-                                }
-                                let mut res = None;
-                                for frame in vec {
-                                    match frame {
-                                        Frame::Chunk(None) => {
-                                            info.context = Some(ctx);
-                                            if let Err(err) = io.write_eof() {
-                                                info.error = Some(err.into());
-                                                return Ok(
-                                                    FinishingMiddlewares::init(info, self.resp))
-                                            }
-                                            break 'outter
+                                },
+                                Ok(Async::Ready(Some(chunk))) => {
+                                    self.iostate = IOState::Payload(body);
+                                    match io.write(chunk.into()) {
+                                        Err(err) => {
+                                            info.error = Some(err.into());
+                                            return Ok(FinishingMiddlewares::init(info, self.resp))
                                         },
-                                        Frame::Chunk(Some(chunk)) => {
-                                            match io.write(chunk) {
-                                                Err(err) => {
+                                        Ok(result) => result
+                                    }
+                                }
+                                Ok(Async::NotReady) => {
+                                    self.iostate = IOState::Payload(body);
+                                    break
+                                },
+                                Err(err) => {
+                                    info.error = Some(err);
+                                    return Ok(FinishingMiddlewares::init(info, self.resp))
+                                }
+                            }
+                        },
+                        IOState::Actor(mut ctx) => {
+                            if info.disconnected.take().is_some() {
+                                ctx.disconnected();
+                            }
+                            match ctx.poll() {
+                                Ok(Async::Ready(Some(vec))) => {
+                                    if vec.is_empty() {
+                                        self.iostate = IOState::Actor(ctx);
+                                        break
+                                    }
+                                    let mut res = None;
+                                    for frame in vec {
+                                        match frame {
+                                            Frame::Chunk(None) => {
+                                                info.context = Some(ctx);
+                                                if let Err(err) = io.write_eof() {
                                                     info.error = Some(err.into());
                                                     return Ok(
                                                         FinishingMiddlewares::init(info, self.resp))
-                                                },
-                                                Ok(result) => res = Some(result),
-                                            }
-                                        },
-                                        Frame::Drain(fut) => self.drain = Some(fut),
+                                                }
+                                                break 'inner
+                                            },
+                                            Frame::Chunk(Some(chunk)) => {
+                                                match io.write(chunk) {
+                                                    Err(err) => {
+                                                        info.error = Some(err.into());
+                                                        return Ok(
+                                                            FinishingMiddlewares::init(info, self.resp))
+                                                    },
+                                                    Ok(result) => res = Some(result),
+                                                }
+                                            },
+                                            Frame::Drain(fut) => self.drain = Some(fut),
+                                        }
                                     }
+                                    self.iostate = IOState::Actor(ctx);
+                                    if self.drain.is_some() {
+                                        self.running.resume();
+                                        break 'inner
+                                    }
+                                    res.unwrap()
+                                },
+                                Ok(Async::Ready(None)) => {
+                                    break
                                 }
-                                self.iostate = IOState::Actor(ctx);
-                                if self.drain.is_some() {
-                                    self.running.resume();
-                                    break 'outter
+                                Ok(Async::NotReady) => {
+                                    self.iostate = IOState::Actor(ctx);
+                                    break
+                                }
+                                Err(err) => {
+                                    info.error = Some(err);
+                                    return Ok(FinishingMiddlewares::init(info, self.resp))
                                 }
-                                res.unwrap()
-                            },
-                            Ok(Async::Ready(None)) => {
-                                break
-                            }
-                            Ok(Async::NotReady) => {
-                                self.iostate = IOState::Actor(ctx);
-                                break
-                            }
-                            Err(err) => {
-                                info.error = Some(err);
-                                return Ok(FinishingMiddlewares::init(info, self.resp))
                             }
                         }
-                    }
-                    IOState::Done => break,
-                };
+                        IOState::Done => break,
+                    };
 
-                match result {
-                    WriterState::Pause => {
-                        self.running.pause();
-                        break
+                    match result {
+                        WriterState::Pause => {
+                            self.running.pause();
+                            break
+                        }
+                        WriterState::Done => {
+                            self.running.resume()
+                        },
                     }
-                    WriterState::Done => {
-                        self.running.resume()
+                }
+            }
+
+            // flush io but only if we need to
+            if self.running == RunningState::Paused || self.drain.is_some() {
+                match io.poll_completed(false) {
+                    Ok(Async::Ready(_)) => {
+                        self.running.resume();
+
+                        // resolve drain futures
+                        if let Some(tx) = self.drain.take() {
+                            let _ = tx.send(());
+                        }
+                        // restart io processing
+                        continue
                     },
-                }
-            }
-        }
-
-        // flush io but only if we need to
-        if self.running == RunningState::Paused || self.drain.is_some() {
-            match io.poll_completed(false) {
-                Ok(Async::Ready(_)) => {
-                    self.running.resume();
-
-                    // resolve drain futures
-                    if let Some(tx) = self.drain.take() {
-                        let _ = tx.send(());
+                    Ok(Async::NotReady) =>
+                        return Err(PipelineState::Response(self)),
+                    Err(err) => {
+                        info.error = Some(err.into());
+                        return Ok(FinishingMiddlewares::init(info, self.resp))
                     }
-                    // restart io processing
-                    return self.poll_io(io, info);
-                },
-                Ok(Async::NotReady) => return Err(PipelineState::Response(self)),
-                Err(err) => {
-                    info.error = Some(err.into());
-                    return Ok(FinishingMiddlewares::init(info, self.resp))
                 }
             }
+            break
         }
 
         // response is completed
diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs
index 531e3c8d5..c3eb5dc93 100644
--- a/src/server/h1writer.rs
+++ b/src/server/h1writer.rs
@@ -82,7 +82,9 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
                     self.disconnected();
                     return Err(io::Error::new(io::ErrorKind::WriteZero, ""))
                 },
-                Ok(n) => written += n,
+                Ok(n) => {
+                    written += n;
+                },
                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                     return Ok(written)
                 }
@@ -229,7 +231,7 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
                     if self.buffer.is_empty() {
                         let pl: &[u8] = payload.as_ref();
                         let n = self.write_data(pl)?;
-                        if pl.len() < n {
+                        if n < pl.len() {
                             self.buffer.extend_from_slice(&pl[n..]);
                             return Ok(WriterState::Done);
                         }
diff --git a/src/ws/client.rs b/src/ws/client.rs
index c5fdcf798..595930989 100644
--- a/src/ws/client.rs
+++ b/src/ws/client.rs
@@ -454,13 +454,13 @@ impl Stream for ClientReader {
         // read
         match Frame::parse(&mut inner.rx, false, max_size) {
             Ok(Async::Ready(Some(frame))) => {
-                let (finished, opcode, payload) = frame.unpack();
+                let (_finished, opcode, payload) = frame.unpack();
 
                 match opcode {
                     // continuation is not supported
                     OpCode::Continue => {
                         inner.closed = true;
-                        return Err(ProtocolError::NoContinuation)
+                        Err(ProtocolError::NoContinuation)
                     },
                     OpCode::Bad => {
                         inner.closed = true;
diff --git a/src/ws/mod.rs b/src/ws/mod.rs
index 12fb4d709..7b41cf253 100644
--- a/src/ws/mod.rs
+++ b/src/ws/mod.rs
@@ -329,7 +329,8 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
                         match String::from_utf8(tmp) {
                             Ok(s) =>
                                 Ok(Async::Ready(Some(Message::Text(s)))),
-                            Err(_) => {
+                            Err(e) => {
+                                println!("ENC: {:?}", e);
                                 self.closed = true;
                                 Err(ProtocolError::BadEncoding)
                             }
diff --git a/tests/test_ws.rs b/tests/test_ws.rs
index 4d3ed4729..a4dd2c230 100644
--- a/tests/test_ws.rs
+++ b/tests/test_ws.rs
@@ -93,7 +93,8 @@ fn test_large_bin() {
 }
 
 struct Ws2 {
-    count: usize
+    count: usize,
+    bin: bool,
 }
 
 impl Actor for Ws2 {
@@ -106,10 +107,14 @@ impl Actor for Ws2 {
 
 impl Ws2 {
     fn send(&mut self, ctx: &mut ws::WebsocketContext<Self>) {
-        ctx.text("0".repeat(65_536));
+        if self.bin {
+            ctx.binary(Vec::from("0".repeat(65_536)));
+        } else {
+            ctx.text("0".repeat(65_536));
+        }
         ctx.drain().and_then(|_, act, ctx| {
             act.count += 1;
-            if act.count != 100 {
+            if act.count != 10_000 {
                 act.send(ctx);
             }
             actix::fut::ok(())
@@ -135,10 +140,25 @@ fn test_server_send_text() {
     let data = Some(ws::Message::Text("0".repeat(65_536)));
 
     let mut srv = test::TestServer::new(
-        |app| app.handler(|req| ws::start(req, Ws2{count:0})));
+        |app| app.handler(|req| ws::start(req, Ws2{count:0, bin: false})));
     let (mut reader, _writer) = srv.ws().unwrap();
 
-    for _ in 0..100 {
+    for _ in 0..10_000 {
+        let (item, r) = srv.execute(reader.into_future()).unwrap();
+        reader = r;
+        assert_eq!(item, data);
+    }
+}
+
+#[test]
+fn test_server_send_bin() {
+    let data = Some(ws::Message::Binary(Binary::from("0".repeat(65_536))));
+
+    let mut srv = test::TestServer::new(
+        |app| app.handler(|req| ws::start(req, Ws2{count:0, bin: true})));
+    let (mut reader, _writer) = srv.ws().unwrap();
+
+    for _ in 0..10_000 {
         let (item, r) = srv.execute(reader.into_future()).unwrap();
         reader = r;
         assert_eq!(item, data);