From 0c98775b51e4935dfc5e20312dc58fd645794ba6 Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Fri, 9 Feb 2018 22:26:48 -0800
Subject: [PATCH] refactor h1 stream polling

---
 src/payload.rs    |  35 +++---
 src/server/h1.rs  | 290 ++++++++++++++++++++++------------------------
 src/ws/context.rs |   4 +-
 src/ws/mod.rs     |   2 +-
 tests/test_ws.rs  |   2 +-
 5 files changed, 164 insertions(+), 169 deletions(-)

diff --git a/src/payload.rs b/src/payload.rs
index c5c63e78..97e59a48 100644
--- a/src/payload.rs
+++ b/src/payload.rs
@@ -40,7 +40,8 @@ impl fmt::Debug for PayloadItem {
 /// Buffered stream of bytes chunks
 ///
 /// Payload stores chunks in a vector. First chunk can be received with `.readany()` method.
-/// Payload stream is not thread safe.
+/// Payload stream is not thread safe. Payload does not notify current task when
+/// new data is available.
 ///
 /// Payload stream can be used as `HttpResponse` body stream.
 #[derive(Debug)]
@@ -148,7 +149,7 @@ impl Stream for Payload {
 
     #[inline]
     fn poll(&mut self) -> Poll<Option<PayloadItem>, PayloadError> {
-        self.inner.borrow_mut().readany()
+        self.inner.borrow_mut().readany(false)
     }
 }
 
@@ -166,7 +167,7 @@ impl Stream for ReadAny {
     type Error = PayloadError;
 
     fn poll(&mut self) -> Poll<Option<Bytes>, Self::Error> {
-        match self.0.borrow_mut().readany()? {
+        match self.0.borrow_mut().readany(false)? {
             Async::Ready(Some(item)) => Ok(Async::Ready(Some(item.0))),
             Async::Ready(None) => Ok(Async::Ready(None)),
             Async::NotReady => Ok(Async::NotReady),
@@ -182,7 +183,7 @@ impl Future for ReadExactly {
     type Error = PayloadError;
 
     fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        match self.0.borrow_mut().readexactly(self.1)? {
+        match self.0.borrow_mut().readexactly(self.1, false)? {
             Async::Ready(chunk) => Ok(Async::Ready(chunk)),
             Async::NotReady => Ok(Async::NotReady),
         }
@@ -197,7 +198,7 @@ impl Future for ReadLine {
     type Error = PayloadError;
 
     fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        match self.0.borrow_mut().readline()? {
+        match self.0.borrow_mut().readline(false)? {
             Async::Ready(chunk) => Ok(Async::Ready(chunk)),
             Async::NotReady => Ok(Async::NotReady),
         }
@@ -212,7 +213,7 @@ impl Future for ReadUntil {
     type Error = PayloadError;
 
     fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        match self.0.borrow_mut().readuntil(&self.1)? {
+        match self.0.borrow_mut().readuntil(&self.1, false)? {
             Async::Ready(chunk) => Ok(Async::Ready(chunk)),
             Async::NotReady => Ok(Async::NotReady),
         }
@@ -324,7 +325,7 @@ impl Inner {
         self.len
     }
 
-    fn readany(&mut self) -> Poll<Option<PayloadItem>, PayloadError> {
+    fn readany(&mut self, notify: bool) -> Poll<Option<PayloadItem>, PayloadError> {
         if let Some(data) = self.items.pop_front() {
             self.len -= data.len();
             Ok(Async::Ready(Some(PayloadItem(data))))
@@ -333,12 +334,14 @@ impl Inner {
         } else if self.eof {
             Ok(Async::Ready(None))
         } else {
-            self.task = Some(current_task());
+            if notify {
+                self.task = Some(current_task());
+            }
             Ok(Async::NotReady)
         }
     }
 
-    fn readexactly(&mut self, size: usize) -> Result<Async<Bytes>, PayloadError> {
+    fn readexactly(&mut self, size: usize, notify: bool) -> Result<Async<Bytes>, PayloadError> {
         if size <= self.len {
             let mut buf = BytesMut::with_capacity(size);
             while buf.len() < size {
@@ -356,12 +359,14 @@ impl Inner {
         if let Some(err) = self.err.take() {
             Err(err)
         } else {
-            self.task = Some(current_task());
+            if notify {
+                self.task = Some(current_task());
+            }
             Ok(Async::NotReady)
         }
     }
 
-    fn readuntil(&mut self, line: &[u8]) -> Result<Async<Bytes>, PayloadError> {
+    fn readuntil(&mut self, line: &[u8], notify: bool) -> Result<Async<Bytes>, PayloadError> {
         let mut idx = 0;
         let mut num = 0;
         let mut offset = 0;
@@ -411,13 +416,15 @@ impl Inner {
         if let Some(err) = self.err.take() {
             Err(err)
         } else {
-            self.task = Some(current_task());
+            if notify {
+                self.task = Some(current_task());
+            }
             Ok(Async::NotReady)
         }
     }
 
-    fn readline(&mut self) -> Result<Async<Bytes>, PayloadError> {
-        self.readuntil(b"\n")
+    fn readline(&mut self, notify: bool) -> Result<Async<Bytes>, PayloadError> {
+        self.readuntil(b"\n", notify)
     }
 
     pub fn readall(&mut self) -> Option<Bytes> {
diff --git a/src/server/h1.rs b/src/server/h1.rs
index f2578b3b..4ce403cb 100644
--- a/src/server/h1.rs
+++ b/src/server/h1.rs
@@ -96,8 +96,6 @@ impl<T, H> Http1<T, H>
         }
     }
 
-    // TODO: refactor
-    #[cfg_attr(feature = "cargo-clippy", allow(cyclomatic_complexity))]
     pub fn poll(&mut self) -> Poll<(), ()> {
         // keep-alive timer
         if let Some(ref mut timer) = self.keepalive_timer {
@@ -111,99 +109,19 @@ impl<T, H> Http1<T, H>
             }
         }
 
-        loop {
-            let mut not_ready = true;
+        self.poll_io()
+    }
 
-            // check in-flight messages
-            let mut io = false;
-            let mut idx = 0;
-            while idx < self.tasks.len() {
-                let item = &mut self.tasks[idx];
-
-                if !io && !item.flags.contains(EntryFlags::EOF) {
-                    if item.flags.contains(EntryFlags::ERROR) {
-                        // check stream state
-                        if let Ok(Async::NotReady) = self.stream.poll_completed(true) {
-                            return Ok(Async::NotReady)
-                        }
-                        return Err(())
-                    }
-
-                    match item.pipe.poll_io(&mut self.stream) {
-                        Ok(Async::Ready(ready)) => {
-                            not_ready = false;
-
-                            // override keep-alive state
-                            if self.stream.keepalive() {
-                                self.flags.insert(Flags::KEEPALIVE);
-                            } else {
-                                self.flags.remove(Flags::KEEPALIVE);
-                            }
-                            self.stream.reset();
-
-                            item.flags.insert(EntryFlags::EOF);
-                            if ready {
-                                item.flags.insert(EntryFlags::FINISHED);
-                            }
-                        },
-                        // no more IO for this iteration
-                        Ok(Async::NotReady) => io = true,
-                        Err(err) => {
-                            // it is not possible to recover from error
-                            // during pipe handling, so just drop connection
-                            error!("Unhandled error: {}", err);
-                            item.flags.insert(EntryFlags::ERROR);
-
-                            // check stream state, we still can have valid data in buffer
-                            if let Ok(Async::NotReady) = self.stream.poll_completed(true) {
-                                return Ok(Async::NotReady)
-                            }
-                            return Err(())
-                        }
-                    }
-                } else if !item.flags.contains(EntryFlags::FINISHED) {
-                    match item.pipe.poll() {
-                        Ok(Async::NotReady) => (),
-                        Ok(Async::Ready(_)) => {
-                            not_ready = false;
-                            item.flags.insert(EntryFlags::FINISHED);
-                        },
-                        Err(err) => {
-                            item.flags.insert(EntryFlags::ERROR);
-                            error!("Unhandled error: {}", err);
-                        }
-                    }
-                }
-                idx += 1;
-            }
-
-            // cleanup finished tasks
-            while !self.tasks.is_empty() {
-                if self.tasks[0].flags.contains(EntryFlags::EOF) &&
-                    self.tasks[0].flags.contains(EntryFlags::FINISHED)
-                {
-                    self.tasks.pop_front();
-                } else {
-                    break
-                }
-            }
-
-            // no keep-alive
-            if !self.flags.contains(Flags::KEEPALIVE) && self.tasks.is_empty() {
-                // check stream state
-                if !self.poll_completed(true)? {
-                    return Ok(Async::NotReady)
-                }
-                return Ok(Async::Ready(()))
-            }
-
-            // read incoming data
-            while !self.flags.contains(Flags::ERROR) && self.tasks.len() < MAX_PIPELINED_MESSAGES {
+    // TODO: refactor
+    pub fn poll_io(&mut self) -> Poll<(), ()> {
+        // read incoming data
+        let need_read =
+            if !self.flags.contains(Flags::ERROR) && self.tasks.len() < MAX_PIPELINED_MESSAGES
+        {
+            'outer: loop {
                 match self.reader.parse(self.stream.get_mut(),
                                         &mut self.read_buf, &self.settings) {
                     Ok(Async::Ready(mut req)) => {
-                        not_ready = false;
-
                         // set remote addr
                         req.set_peer_addr(self.addr);
 
@@ -211,58 +129,24 @@ impl<T, H> Http1<T, H>
                         self.keepalive_timer.take();
 
                         // start request processing
-                        let mut pipe = None;
                         for h in self.settings.handlers().iter_mut() {
                             req = match h.handle(req) {
                                 Ok(t) => {
-                                    pipe = Some(t);
-                                    break
+                                    self.tasks.push_back(
+                                        Entry {pipe: t, flags: EntryFlags::empty()});
+                                    continue 'outer
                                 },
                                 Err(req) => req,
                             }
                         }
 
                         self.tasks.push_back(
-                            Entry {pipe: pipe.unwrap_or_else(|| Pipeline::error(HTTPNotFound)),
+                            Entry {pipe: Pipeline::error(HTTPNotFound),
                                    flags: EntryFlags::empty()});
+                        continue
                     },
-                    Ok(Async::NotReady) => {
-                        // start keep-alive timer, this also is slow request timeout
-                        if self.tasks.is_empty() {
-                            if self.settings.keep_alive_enabled() {
-                                let keep_alive = self.settings.keep_alive();
-                                if keep_alive > 0 && self.flags.contains(Flags::KEEPALIVE) {
-                                    if self.keepalive_timer.is_none() {
-                                        trace!("Start keep-alive timer");
-                                        let mut to = Timeout::new(
-                                            Duration::new(keep_alive, 0),
-                                            Arbiter::handle()).unwrap();
-                                        // register timeout
-                                        let _ = to.poll();
-                                        self.keepalive_timer = Some(to);
-                                    }
-                                } else {
-                                    // check stream state
-                                    if !self.poll_completed(true)? {
-                                        return Ok(Async::NotReady)
-                                    }
-                                    // keep-alive disable, drop connection
-                                    return Ok(Async::Ready(()))
-                                }
-                            } else if !self.poll_completed(false)? ||
-                                self.flags.contains(Flags::KEEPALIVE)
-                            {
-                                // check stream state or
-                                // if keep-alive unset, rely on operating system
-                                return Ok(Async::NotReady)
-                            } else {
-                                return Ok(Async::Ready(()))
-                            }
-                        }
-                        break
-                    },
+                    Ok(Async::NotReady) => (),
                     Err(ReaderError::Disconnect) => {
-                        not_ready = false;
                         self.flags.insert(Flags::ERROR);
                         self.stream.disconnected();
                         for entry in &mut self.tasks {
@@ -271,7 +155,6 @@ impl<T, H> Http1<T, H>
                     },
                     Err(err) => {
                         // notify all tasks
-                        not_ready = false;
                         self.stream.disconnected();
                         for entry in &mut self.tasks {
                             entry.pipe.disconnected()
@@ -293,20 +176,132 @@ impl<T, H> Http1<T, H>
                         }
                     },
                 }
+                break
+            }
+            false
+        } else {
+            true
+        };
+
+        loop {
+            // check in-flight messages
+            let mut io = false;
+            let mut idx = 0;
+            while idx < self.tasks.len() {
+                let item = &mut self.tasks[idx];
+
+                if !io && !item.flags.contains(EntryFlags::EOF) {
+                    // io is corrupted, send buffer
+                    if item.flags.contains(EntryFlags::ERROR) {
+                        if let Ok(Async::NotReady) = self.stream.poll_completed(true) {
+                            return Ok(Async::NotReady)
+                        }
+                        return Err(())
+                    }
+
+                    match item.pipe.poll_io(&mut self.stream) {
+                        Ok(Async::Ready(ready)) => {
+                            // override keep-alive state
+                            if self.stream.keepalive() {
+                                self.flags.insert(Flags::KEEPALIVE);
+                            } else {
+                                self.flags.remove(Flags::KEEPALIVE);
+                            }
+                            // prepare stream for next response
+                            self.stream.reset();
+
+                            if ready {
+                                item.flags.insert(EntryFlags::EOF | EntryFlags::FINISHED);
+                            } else {
+                                item.flags.insert(EntryFlags::FINISHED);
+                            }
+                        },
+                        // no more IO for this iteration
+                        Ok(Async::NotReady) => io = true,
+                        Err(err) => {
+                            // it is not possible to recover from error
+                            // during pipe handling, so just drop connection
+                            error!("Unhandled error: {}", err);
+                            item.flags.insert(EntryFlags::ERROR);
+
+                            // check stream state, we still can have valid data in buffer
+                            if let Ok(Async::NotReady) = self.stream.poll_completed(true) {
+                                return Ok(Async::NotReady)
+                            }
+                            return Err(())
+                        }
+                    }
+                } else if !item.flags.contains(EntryFlags::FINISHED) {
+                    match item.pipe.poll() {
+                        Ok(Async::NotReady) => (),
+                        Ok(Async::Ready(_)) => item.flags.insert(EntryFlags::FINISHED),
+                        Err(err) => {
+                            item.flags.insert(EntryFlags::ERROR);
+                            error!("Unhandled error: {}", err);
+                        }
+                    }
+                }
+                idx += 1;
             }
 
-            // check for parse error
-            if self.tasks.is_empty() {
+            // cleanup finished tasks
+            let mut popped = false;
+            while !self.tasks.is_empty() {
+                if self.tasks[0].flags.contains(EntryFlags::EOF | EntryFlags::FINISHED) {
+                    popped = true;
+                    self.tasks.pop_front();
+                } else {
+                    break
+                }
+            }
+            if need_read && popped {
+                return self.poll_io()
+            }
+
+            // no keep-alive
+            if !self.flags.contains(Flags::KEEPALIVE) && self.tasks.is_empty() {
                 // check stream state
                 if !self.poll_completed(true)? {
                     return Ok(Async::NotReady)
                 }
-                if self.flags.contains(Flags::ERROR) || self.keepalive_timer.is_none() {
-                    return Ok(Async::Ready(()))
-                }
+                return Ok(Async::Ready(()))
             }
 
-            if not_ready {
+            // start keep-alive timer, this also is slow request timeout
+            if self.tasks.is_empty() {
+                // check stream state
+                if self.flags.contains(Flags::ERROR) {
+                    return Ok(Async::Ready(()))
+                }
+
+                if self.settings.keep_alive_enabled() {
+                    let keep_alive = self.settings.keep_alive();
+                    if keep_alive > 0 && self.flags.contains(Flags::KEEPALIVE) {
+                        if self.keepalive_timer.is_none() {
+                            trace!("Start keep-alive timer");
+                            let mut to = Timeout::new(
+                                Duration::new(keep_alive, 0), Arbiter::handle()).unwrap();
+                            // register timeout
+                            let _ = to.poll();
+                            self.keepalive_timer = Some(to);
+                        }
+                    } else {
+                        // check stream state
+                        if !self.poll_completed(true)? {
+                            return Ok(Async::NotReady)
+                        }
+                        // keep-alive is disabled, drop connection
+                        return Ok(Async::Ready(()))
+                    }
+                } else if !self.poll_completed(false)? ||
+                    self.flags.contains(Flags::KEEPALIVE) {
+                        // check stream state or
+                        // if keep-alive unset, rely on operating system
+                        return Ok(Async::NotReady)
+                    } else {
+                        return Ok(Async::Ready(()))
+                    }
+            } else {
                 self.poll_completed(false)?;
                 return Ok(Async::NotReady)
             }
@@ -344,7 +339,7 @@ impl Reader {
 
     #[inline]
     fn decode(&mut self, buf: &mut BytesMut, payload: &mut PayloadInfo)
-              -> std::result::Result<Decoding, ReaderError>
+              -> Result<Decoding, ReaderError>
     {
         loop {
             match payload.decoder.decode(buf) {
@@ -416,15 +411,10 @@ impl Reader {
         // if buf is empty parse_message will always return NotReady, let's avoid that
         let read = if buf.is_empty() {
             match utils::read_from_io(io, buf) {
-                Ok(Async::Ready(0)) => {
-                    // debug!("Ignored premature client disconnection");
-                    return Err(ReaderError::Disconnect);
-                },
+                Ok(Async::Ready(0)) => return Err(ReaderError::Disconnect),
                 Ok(Async::Ready(_)) => (),
-                Ok(Async::NotReady) =>
-                    return Ok(Async::NotReady),
-                Err(err) =>
-                    return Err(ReaderError::Error(err.into()))
+                Ok(Async::NotReady) => return Ok(Async::NotReady),
+                Err(err) => return Err(ReaderError::Error(err.into()))
             }
             false
         } else {
@@ -455,10 +445,8 @@ impl Reader {
                                 return Err(ReaderError::Disconnect);
                             },
                             Ok(Async::Ready(_)) => (),
-                            Ok(Async::NotReady) =>
-                                return Ok(Async::NotReady),
-                            Err(err) =>
-                                return Err(ReaderError::Error(err.into()))
+                            Ok(Async::NotReady) => return Ok(Async::NotReady),
+                            Err(err) => return Err(ReaderError::Error(err.into())),
                         }
                     } else {
                         return Ok(Async::NotReady)
diff --git a/src/ws/context.rs b/src/ws/context.rs
index c74410aa..a903a890 100644
--- a/src/ws/context.rs
+++ b/src/ws/context.rs
@@ -139,8 +139,8 @@ impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> {
 
     /// Send text frame
     #[inline]
-    pub fn text(&mut self, text: &str) {
-        self.write(Frame::message(Vec::from(text), OpCode::Text, true).generate(false));
+    pub fn text<T: Into<String>>(&mut self, text: T) {
+        self.write(Frame::message(text.into(), OpCode::Text, true).generate(false));
     }
 
     /// Send binary frame
diff --git a/src/ws/mod.rs b/src/ws/mod.rs
index 17501a7d..07b845cc 100644
--- a/src/ws/mod.rs
+++ b/src/ws/mod.rs
@@ -30,7 +30,7 @@
 //!     fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
 //!         match msg {
 //!             ws::Message::Ping(msg) => ctx.pong(&msg),
-//!             ws::Message::Text(text) => ctx.text(&text),
+//!             ws::Message::Text(text) => ctx.text(text),
 //!             ws::Message::Binary(bin) => ctx.binary(bin),
 //!             _ => (),
 //!         }
diff --git a/tests/test_ws.rs b/tests/test_ws.rs
index 29db0b9e..cb5a7426 100644
--- a/tests/test_ws.rs
+++ b/tests/test_ws.rs
@@ -22,7 +22,7 @@ impl Handler<ws::Message> for Ws {
     fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
         match msg {
             ws::Message::Ping(msg) => ctx.pong(&msg),
-            ws::Message::Text(text) => ctx.text(&text),
+            ws::Message::Text(text) => ctx.text(text),
             ws::Message::Binary(bin) => ctx.binary(bin),
             _ => (),
         }