From 86583049fa59fbc8a2ae77dfa49335166fdd6219 Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Wed, 25 Oct 2017 16:25:26 -0700
Subject: [PATCH] Fix disconnection handling

---
 CHANGES.md      |   2 +
 Cargo.toml      |   2 +-
 src/body.rs     |  36 +++++-----
 src/context.rs  |  20 +++++-
 src/resource.rs |   2 +-
 src/server.rs   |  31 +++++++--
 src/task.rs     | 182 ++++++++++++++++++++++++++++++++----------------
 7 files changed, 188 insertions(+), 87 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 54c66adf..76866d2f 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -11,6 +11,8 @@
 
 * Re-use `BinaryBody` for `Frame::Payload`
 
+* Fix disconnection handling.
+
 
 ## 0.1.0 (2017-10-23)
 
diff --git a/Cargo.toml b/Cargo.toml
index c04ff115..bc67ab77 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -49,7 +49,7 @@ tokio-proto = "0.1"
 # h2 = { git = 'https://github.com/carllerche/h2', optional = true }
 
 [dependencies.actix]
-#version = "0.3"
+#version = ">=0.3.1"
 #path = "../actix"
 git = "https://github.com/actix/actix.git"
 default-features = false
diff --git a/src/body.rs b/src/body.rs
index 15571095..9e8b3ece 100644
--- a/src/body.rs
+++ b/src/body.rs
@@ -49,7 +49,7 @@ impl Body {
     }
 
     /// Create body from slice (copy)
-    pub fn from_slice<'a>(s: &'a [u8]) -> Body {
+    pub fn from_slice(s: &[u8]) -> Body {
         Body::Binary(BinaryBody::Bytes(Bytes::from(s)))
     }
 }
@@ -61,19 +61,23 @@ impl<T> From<T> for Body where T: Into<BinaryBody>{
 }
 
 impl BinaryBody {
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
     pub fn len(&self) -> usize {
-        match self {
-            &BinaryBody::Bytes(ref bytes) => bytes.len(),
-            &BinaryBody::Slice(slice) => slice.len(),
-            &BinaryBody::SharedBytes(ref bytes) => bytes.len(),
-            &BinaryBody::ArcSharedBytes(ref bytes) => bytes.len(),
-            &BinaryBody::SharedString(ref s) => s.len(),
-            &BinaryBody::ArcSharedString(ref s) => s.len(),
+        match *self {
+            BinaryBody::Bytes(ref bytes) => bytes.len(),
+            BinaryBody::Slice(slice) => slice.len(),
+            BinaryBody::SharedBytes(ref bytes) => bytes.len(),
+            BinaryBody::ArcSharedBytes(ref bytes) => bytes.len(),
+            BinaryBody::SharedString(ref s) => s.len(),
+            BinaryBody::ArcSharedString(ref s) => s.len(),
         }
     }
 
     /// Create binary body from slice
-    pub fn from_slice<'a>(s: &'a [u8]) -> BinaryBody {
+    pub fn from_slice(s: &[u8]) -> BinaryBody {
         BinaryBody::Bytes(Bytes::from(s))
     }
 }
@@ -164,13 +168,13 @@ impl<'a> From<&'a Arc<String>> for BinaryBody {
 
 impl AsRef<[u8]> for BinaryBody {
     fn as_ref(&self) -> &[u8] {
-        match self {
-            &BinaryBody::Bytes(ref bytes) => bytes.as_ref(),
-            &BinaryBody::Slice(slice) => slice,
-            &BinaryBody::SharedBytes(ref bytes) => bytes.as_ref(),
-            &BinaryBody::ArcSharedBytes(ref bytes) => bytes.as_ref(),
-            &BinaryBody::SharedString(ref s) => s.as_bytes(),
-            &BinaryBody::ArcSharedString(ref s) => s.as_bytes(),
+        match *self {
+            BinaryBody::Bytes(ref bytes) => bytes.as_ref(),
+            BinaryBody::Slice(slice) => slice,
+            BinaryBody::SharedBytes(ref bytes) => bytes.as_ref(),
+            BinaryBody::ArcSharedBytes(ref bytes) => bytes.as_ref(),
+            BinaryBody::SharedString(ref s) => s.as_bytes(),
+            BinaryBody::ArcSharedString(ref s) => s.as_bytes(),
         }
     }
 }
diff --git a/src/context.rs b/src/context.rs
index 20885206..3606db59 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -10,6 +10,7 @@ use actix::fut::ActorFuture;
 use actix::dev::{AsyncContextApi, ActorAddressCell, ActorItemsCell, ActorWaitCell, SpawnHandle,
                  Envelope, ToEnvelope, RemoteEnvelope};
 
+use task::IoContext;
 use body::BinaryBody;
 use route::{Route, Frame};
 use httpresponse::HttpResponse;
@@ -26,10 +27,20 @@ pub struct HttpContext<A> where A: Actor<Context=HttpContext<A>> + Route,
     stream: VecDeque<Frame>,
     wait: ActorWaitCell<A>,
     app_state: Rc<<A as Route>::State>,
+    disconnected: bool,
 }
 
+impl<A> IoContext for HttpContext<A> where A: Actor<Context=Self> + Route {
 
-impl<A> ActorContext<A> for HttpContext<A> where A: Actor<Context=Self> + Route
+    fn disconnected(&mut self) {
+        self.disconnected = true;
+        if self.state == ActorState::Running {
+            self.state = ActorState::Stopping;
+        }
+    }
+}
+
+impl<A> ActorContext for HttpContext<A> where A: Actor<Context=Self> + Route
 {
     /// Stop actor execution
     fn stop(&mut self) {
@@ -95,6 +106,7 @@ impl<A> HttpContext<A> where A: Actor<Context=Self> + Route {
             wait: ActorWaitCell::default(),
             stream: VecDeque::new(),
             app_state: state,
+            disconnected: false,
         }
     }
 
@@ -124,6 +136,11 @@ impl<A> HttpContext<A> where A: Actor<Context=Self> + Route {
     pub fn write_eof(&mut self) {
         self.stream.push_back(Frame::Payload(None))
     }
+
+    /// Check if connection still open
+    pub fn connected(&self) -> bool {
+        !self.disconnected
+    }
 }
 
 impl<A> HttpContext<A> where A: Actor<Context=Self> + Route {
@@ -157,7 +174,6 @@ impl<A> Stream for HttpContext<A> where A: Actor<Context=Self> + Route
         if self.act.is_none() {
             return Ok(Async::NotReady)
         }
-
         let act: &mut A = unsafe {
             std::mem::transmute(self.act.as_mut().unwrap() as &mut A)
         };
diff --git a/src/resource.rs b/src/resource.rs
index a69e27ca..a030d5ec 100644
--- a/src/resource.rs
+++ b/src/resource.rs
@@ -158,7 +158,7 @@ impl<A> Reply<A> where A: Actor + Route
             },
             ReplyItem::Actor(act) => {
                 ctx.set_actor(act);
-                Task::with_stream(ctx)
+                Task::with_context(ctx)
             }
         }
     }
diff --git a/src/server.rs b/src/server.rs
index 8331b48f..e1e6a228 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -171,11 +171,11 @@ pub struct HttpChannel<T: 'static, A: 'static, H: 'static> {
     keepalive_timer: Option<Timeout>,
 }
 
-/*impl<T: 'static, A: 'static> Drop for HttpChannel<T, A> {
+impl<T: 'static, A: 'static, H: 'static> Drop for HttpChannel<T, A, H> {
     fn drop(&mut self) {
         println!("Drop http channel");
     }
-}*/
+}
 
 impl<T, A, H> Actor for HttpChannel<T, A, H>
     where T: AsyncRead + AsyncWrite + 'static,
@@ -205,6 +205,8 @@ impl<T, A, H> Future for HttpChannel<T, A, H>
         }
 
         loop {
+            let mut not_ready = true;
+
             // check in-flight messages
             let mut idx = 0;
             while idx < self.items.len() {
@@ -218,6 +220,7 @@ impl<T, A, H> Future for HttpChannel<T, A, H>
                     match self.items[idx].task.poll_io(&mut self.stream, req)
                     {
                         Ok(Async::Ready(ready)) => {
+                            not_ready = false;
                             let mut item = self.items.pop_front().unwrap();
 
                             // overide keep-alive state
@@ -247,8 +250,10 @@ impl<T, A, H> Future for HttpChannel<T, A, H>
                 } else if !self.items[idx].finished && !self.items[idx].error {
                     match self.items[idx].task.poll() {
                         Ok(Async::NotReady) => (),
-                        Ok(Async::Ready(_)) =>
-                            self.items[idx].finished = true,
+                        Ok(Async::Ready(_)) => {
+                            not_ready = false;
+                            self.items[idx].finished = true;
+                        },
                         Err(_) =>
                             self.items[idx].error = true,
                     }
@@ -267,8 +272,10 @@ impl<T, A, H> Future for HttpChannel<T, A, H>
                 if !self.inactive[idx].finished && !self.inactive[idx].error {
                     match self.inactive[idx].task.poll() {
                         Ok(Async::NotReady) => (),
-                        Ok(Async::Ready(_)) =>
-                            self.inactive[idx].finished = true,
+                        Ok(Async::Ready(_)) => {
+                            not_ready = false;
+                            self.inactive[idx].finished = true
+                        }
                         Err(_) =>
                             self.inactive[idx].error = true,
                     }
@@ -280,6 +287,8 @@ impl<T, A, H> Future for HttpChannel<T, A, H>
             if !self.error && self.items.len() < MAX_PIPELINED_MESSAGES {
                 match self.reader.parse(&mut self.stream) {
                     Ok(Async::Ready((mut req, payload))) => {
+                        not_ready = false;
+
                         // stop keepalive timer
                         self.keepalive_timer.take();
 
@@ -300,6 +309,12 @@ impl<T, A, H> Future for HttpChannel<T, A, H>
                                    finished: false});
                     }
                     Err(err) => {
+                        // notify all tasks
+                        not_ready = false;
+                        for entry in &mut self.items {
+                            entry.task.disconnected()
+                        }
+
                         // kill keepalive
                         self.keepalive = false;
                         self.keepalive_timer.take();
@@ -344,6 +359,10 @@ impl<T, A, H> Future for HttpChannel<T, A, H>
             if self.items.is_empty() && self.inactive.is_empty() && self.error {
                 return Ok(Async::Ready(()))
             }
+
+            if not_ready {
+                return Ok(Async::NotReady)
+            }
         }
     }
 }
diff --git a/src/task.rs b/src/task.rs
index 7e29cbf5..f5f155e6 100644
--- a/src/task.rs
+++ b/src/task.rs
@@ -1,4 +1,4 @@
-use std::{cmp, io};
+use std::{mem, cmp, io};
 use std::rc::Rc;
 use std::fmt::Write;
 use std::collections::VecDeque;
@@ -47,16 +47,27 @@ impl TaskIOState {
     }
 }
 
+enum TaskStream {
+    None,
+    Stream(Box<FrameStream>),
+    Context(Box<IoContext<Item=Frame, Error=io::Error>>),
+}
+
+pub(crate) trait IoContext: Stream<Item=Frame, Error=io::Error> + 'static {
+    fn disconnected(&mut self);
+}
+
 pub struct Task {
     state: TaskRunningState,
     iostate: TaskIOState,
     frames: VecDeque<Frame>,
-    stream: Option<Box<FrameStream>>,
+    stream: TaskStream,
     encoder: Encoder,
     buffer: BytesMut,
     upgrade: bool,
     keepalive: bool,
     prepared: Option<HttpResponse>,
+    disconnected: bool,
     middlewares: Option<Rc<Vec<Box<Middleware>>>>,
 }
 
@@ -71,12 +82,13 @@ impl Task {
             state: TaskRunningState::Running,
             iostate: TaskIOState::Done,
             frames: frames,
-            stream: None,
+            stream: TaskStream::None,
             encoder: Encoder::length(0),
             buffer: BytesMut::new(),
             upgrade: false,
             keepalive: false,
             prepared: None,
+            disconnected: false,
             middlewares: None,
         }
     }
@@ -88,12 +100,30 @@ impl Task {
             state: TaskRunningState::Running,
             iostate: TaskIOState::ReadingMessage,
             frames: VecDeque::new(),
-            stream: Some(Box::new(stream)),
+            stream: TaskStream::Stream(Box::new(stream)),
             encoder: Encoder::length(0),
             buffer: BytesMut::new(),
             upgrade: false,
             keepalive: false,
             prepared: None,
+            disconnected: false,
+            middlewares: None,
+        }
+    }
+
+    pub(crate) fn with_context<C: IoContext>(ctx: C) -> Self
+    {
+        Task {
+            state: TaskRunningState::Running,
+            iostate: TaskIOState::ReadingMessage,
+            frames: VecDeque::new(),
+            stream: TaskStream::Context(Box::new(ctx)),
+            encoder: Encoder::length(0),
+            buffer: BytesMut::new(),
+            upgrade: false,
+            keepalive: false,
+            prepared: None,
+            disconnected: false,
             middlewares: None,
         }
     }
@@ -106,6 +136,15 @@ impl Task {
         self.middlewares = Some(middlewares);
     }
 
+    pub(crate) fn disconnected(&mut self) {
+        let len = self.buffer.len();
+        self.buffer.split_to(len);
+        self.disconnected = true;
+        if let TaskStream::Context(ref mut ctx) = self.stream {
+            ctx.disconnected();
+        }
+    }
+
     fn prepare(&mut self, req: &mut HttpRequest, msg: HttpResponse)
     {
         trace!("Prepare message status={:?}", msg.status);
@@ -252,20 +291,26 @@ impl Task {
                 trace!("IO Frame: {:?}", frame);
                 match frame {
                     Frame::Message(response) => {
-                        self.prepare(req, response);
+                        if !self.disconnected {
+                            self.prepare(req, response);
+                        }
                     }
                     Frame::Payload(Some(chunk)) => {
-                        if self.prepared.is_some() {
-                            // TODO: add warning, write after EOF
-                            self.encoder.encode(&mut self.buffer, chunk.as_ref());
-                        } else {
-                            // might be response for EXCEPT
-                            self.buffer.extend_from_slice(chunk.as_ref())
+                        if !self.disconnected {
+                            if self.prepared.is_some() {
+                                // TODO: add warning, write after EOF
+                                self.encoder.encode(&mut self.buffer, chunk.as_ref());
+                            } else {
+                                // might be response for EXCEPT
+                                self.buffer.extend_from_slice(chunk.as_ref())
+                            }
                         }
                     },
                     Frame::Payload(None) => {
-                        // TODO: add error "not eof""
-                        if !self.encoder.encode(&mut self.buffer, [].as_ref()) {
+                        if !self.disconnected &&
+                            !self.encoder.encode(&mut self.buffer, [].as_ref())
+                        {
+                            // TODO: add error "not eof""
                             debug!("last payload item, but it is not EOF ");
                             return Err(())
                         }
@@ -276,15 +321,17 @@ impl Task {
         }
 
         // write bytes to TcpStream
-        while !self.buffer.is_empty() {
-            match io.write(self.buffer.as_ref()) {
-                Ok(n) => {
-                    self.buffer.split_to(n);
-                },
-                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
-                    break
+        if !self.disconnected {
+            while !self.buffer.is_empty() {
+                match io.write(self.buffer.as_ref()) {
+                    Ok(n) => {
+                        self.buffer.split_to(n);
+                    },
+                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+                        break
+                    }
+                    Err(_) => return Err(()),
                 }
-                Err(_) => return Err(()),
             }
         }
 
@@ -295,10 +342,13 @@ impl Task {
             } else if self.state == TaskRunningState::Paused {
                 self.state = TaskRunningState::Running;
             }
+        } else {
+            // at this point we wont get any more Frames
+            self.iostate = TaskIOState::Done;
         }
 
         // response is completed
-        if self.buffer.is_empty() && self.iostate.is_done() {
+        if (self.buffer.is_empty() || self.disconnected) && self.iostate.is_done() {
             // run middlewares
             if let Some(ref mut resp) = self.prepared {
                 if let Some(middlewares) = self.middlewares.take() {
@@ -313,6 +363,46 @@ impl Task {
             Ok(Async::NotReady)
         }
     }
+
+    fn poll_stream<S>(&mut self, stream: &mut S) -> Poll<(), ()>
+        where S: Stream<Item=Frame, Error=io::Error>
+    {
+        loop {
+            match stream.poll() {
+                Ok(Async::Ready(Some(frame))) => {
+                    match frame {
+                        Frame::Message(ref msg) => {
+                            if self.iostate != TaskIOState::ReadingMessage {
+                                error!("Non expected frame {:?}", frame);
+                                return Err(())
+                            }
+                            self.upgrade = msg.upgrade();
+                            if self.upgrade || msg.body().has_body() {
+                                self.iostate = TaskIOState::ReadingPayload;
+                            } else {
+                                self.iostate = TaskIOState::Done;
+                            }
+                        },
+                        Frame::Payload(ref chunk) => {
+                            if chunk.is_none() {
+                                self.iostate = TaskIOState::Done;
+                            } else if self.iostate != TaskIOState::ReadingPayload {
+                                error!("Non expected frame {:?}", self.iostate);
+                                return Err(())
+                            }
+                        },
+                    }
+                    self.frames.push_back(frame)
+                },
+                Ok(Async::Ready(None)) =>
+                    return Ok(Async::Ready(())),
+                Ok(Async::NotReady) =>
+                    return Ok(Async::NotReady),
+                Err(_) =>
+                    return Err(())
+            }
+        }
+    }
 }
 
 impl Future for Task {
@@ -320,45 +410,15 @@ impl Future for Task {
     type Error = ();
 
     fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        if let Some(ref mut stream) = self.stream {
-            loop {
-                match stream.poll() {
-                    Ok(Async::Ready(Some(frame))) => {
-                        match frame {
-                            Frame::Message(ref msg) => {
-                                if self.iostate != TaskIOState::ReadingMessage {
-                                    error!("Non expected frame {:?}", frame);
-                                    return Err(())
-                                }
-                                self.upgrade = msg.upgrade();
-                                if self.upgrade || msg.body().has_body() {
-                                    self.iostate = TaskIOState::ReadingPayload;
-                                } else {
-                                    self.iostate = TaskIOState::Done;
-                                }
-                            },
-                            Frame::Payload(ref chunk) => {
-                                if chunk.is_none() {
-                                    self.iostate = TaskIOState::Done;
-                                } else if self.iostate != TaskIOState::ReadingPayload {
-                                    error!("Non expected frame {:?}", self.iostate);
-                                    return Err(())
-                                }
-                            },
-                        }
-                        self.frames.push_back(frame)
-                    },
-                    Ok(Async::Ready(None)) =>
-                        return Ok(Async::Ready(())),
-                    Ok(Async::NotReady) =>
-                        return Ok(Async::NotReady),
-                    Err(_) =>
-                        return Err(())
-                }
-            }
-        } else {
-            Ok(Async::Ready(()))
-        }
+        let mut s = mem::replace(&mut self.stream, TaskStream::None);
+
+        let result = match s {
+            TaskStream::None => Ok(Async::Ready(())),
+            TaskStream::Stream(ref mut stream) => self.poll_stream(stream),
+            TaskStream::Context(ref mut context) => self.poll_stream(context),
+        };
+        self.stream = s;
+        result
     }
 }