From 9559f6a1759c3605ef1243a510ff36ec15018586 Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Wed, 3 Jan 2018 23:41:55 -0800
Subject: [PATCH] introduce IoStream trait for low level stream operations

---
 src/channel.rs | 186 ++++++++++++++++++++++++++++++++++++++++---------
 src/h1.rs      |  47 +++++++++----
 src/server.rs  |  20 +++---
 src/worker.rs  |   4 +-
 4 files changed, 201 insertions(+), 56 deletions(-)

diff --git a/src/channel.rs b/src/channel.rs
index 963ef106..baae32fa 100644
--- a/src/channel.rs
+++ b/src/channel.rs
@@ -1,8 +1,8 @@
-use std::{ptr, mem, time};
+use std::{ptr, mem, time, io};
 use std::rc::Rc;
 use std::net::{SocketAddr, Shutdown};
 
-use bytes::Bytes;
+use bytes::{Bytes, Buf, BufMut};
 use futures::{Future, Poll, Async};
 use tokio_io::{AsyncRead, AsyncWrite};
 use tokio_core::net::TcpStream;
@@ -48,8 +48,7 @@ impl<T: HttpHandler> IntoHttpHandler for T {
     }
 }
 
-enum HttpProtocol<T, H>
-    where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static
+enum HttpProtocol<T: IoStream, H: 'static>
 {
     H1(h1::Http1<T, H>),
     H2(h2::Http2<T, H>),
@@ -57,22 +56,14 @@ enum HttpProtocol<T, H>
 
 #[doc(hidden)]
 pub struct HttpChannel<T, H>
-    where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static
+    where T: IoStream, H: HttpHandler + 'static
 {
     proto: Option<HttpProtocol<T, H>>,
     node: Option<Node<HttpChannel<T, H>>>,
 }
 
-impl<T, H> Drop for HttpChannel<T, H>
-    where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static
-{
-    fn drop(&mut self) {
-        self.shutdown()
-    }
-}
-
 impl<T, H> HttpChannel<T, H>
-    where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static
+    where T: IoStream, H: HttpHandler + 'static
 {
     pub(crate) fn new(h: Rc<WorkerSettings<H>>,
                       io: T, peer: Option<SocketAddr>, http2: bool) -> HttpChannel<T, H>
@@ -91,19 +82,12 @@ impl<T, H> HttpChannel<T, H>
         }
     }
 
-    fn io(&mut self) -> Option<&mut T> {
-        match self.proto {
-            Some(HttpProtocol::H1(ref mut h1)) => {
-                Some(h1.io())
-            }
-            _ => None,
-        }
-    }
-
     fn shutdown(&mut self) {
         match self.proto {
             Some(HttpProtocol::H1(ref mut h1)) => {
-                let _ = h1.io().shutdown();
+                let io = h1.io();
+                let _ = IoStream::set_linger(io, Some(time::Duration::new(0, 0)));
+                let _ = IoStream::shutdown(io, Shutdown::Both);
             }
             Some(HttpProtocol::H2(ref mut h2)) => {
                 h2.shutdown()
@@ -122,7 +106,7 @@ impl<T, H> HttpChannel<T, H>
 }*/
 
 impl<T, H> Future for HttpChannel<T, H>
-    where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static
+    where T: IoStream, H: HttpHandler + 'static
 {
     type Item = ();
     type Error = ();
@@ -242,7 +226,7 @@ impl Node<()> {
         }
     }
 
-    pub(crate) fn traverse<H>(&self) where H: HttpHandler + 'static {
+    pub(crate) fn traverse<T, H>(&self) where T: IoStream, H: HttpHandler + 'static {
         let mut next = self.next.as_ref();
         loop {
             if let Some(n) = next {
@@ -251,13 +235,8 @@ impl Node<()> {
                     next = n.next.as_ref();
 
                     if !n.element.is_null() {
-                        let ch: &mut HttpChannel<TcpStream, H> = mem::transmute(
+                        let ch: &mut HttpChannel<T, H> = mem::transmute(
                             &mut *(n.element as *mut _));
-                        if let Some(io) = ch.io() {
-                            let _ = TcpStream::set_linger(io, Some(time::Duration::new(0, 0)));
-                            let _ = TcpStream::shutdown(io, Shutdown::Both);
-                            continue;
-                        }
                         ch.shutdown();
                     }
                 }
@@ -267,3 +246,146 @@ impl Node<()> {
         }
     }
 }
+
+
+pub trait IoStream: AsyncRead + AsyncWrite + 'static {
+    fn shutdown(&mut self, how: Shutdown) -> io::Result<()>;
+
+    fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()>;
+
+    fn set_linger(&mut self, dur: Option<time::Duration>) -> io::Result<()>;
+}
+
+impl IoStream for TcpStream {
+    #[inline]
+    fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
+        TcpStream::shutdown(self, how)
+    }
+
+    #[inline]
+    fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> {
+        TcpStream::set_nodelay(self, nodelay)
+    }
+
+    #[inline]
+    fn set_linger(&mut self, dur: Option<time::Duration>) -> io::Result<()> {
+        TcpStream::set_linger(self, dur)
+    }
+}
+
+
+pub(crate) struct WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static {
+   io: T,
+}
+
+impl<T> WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static
+{
+    pub fn new(io: T) -> Self {
+        WrapperStream{io: io}
+    }
+}
+
+impl<T> IoStream for WrapperStream<T>
+    where T: AsyncRead + AsyncWrite + 'static
+{
+    #[inline]
+    fn shutdown(&mut self, _: Shutdown) -> io::Result<()> {
+        Ok(())
+    }
+
+    #[inline]
+    fn set_nodelay(&mut self, _: bool) -> io::Result<()> {
+        Ok(())
+    }
+
+    #[inline]
+    fn set_linger(&mut self, _: Option<time::Duration>) -> io::Result<()> {
+        Ok(())
+    }
+}
+
+impl<T> io::Read for WrapperStream<T>
+    where T: AsyncRead + AsyncWrite + 'static
+{
+    #[inline]
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        self.io.read(buf)
+    }
+}
+
+impl<T> io::Write for WrapperStream<T>
+    where T: AsyncRead + AsyncWrite + 'static
+{
+    #[inline]
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        self.io.write(buf)
+    }
+    #[inline]
+    fn flush(&mut self) -> io::Result<()> {
+        self.io.flush()
+    }
+}
+
+impl<T> AsyncRead for WrapperStream<T>
+    where T: AsyncRead + AsyncWrite + 'static
+{
+    fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
+        self.io.read_buf(buf)
+    }
+}
+
+impl<T> AsyncWrite for WrapperStream<T>
+    where T: AsyncRead + AsyncWrite + 'static
+{
+    fn shutdown(&mut self) -> Poll<(), io::Error> {
+        self.io.shutdown()
+    }
+    fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
+        self.io.write_buf(buf)
+    }
+}
+
+
+#[cfg(feature="alpn")]
+use tokio_openssl::SslStream;
+
+#[cfg(feature="alpn")]
+impl IoStream for SslStream<TcpStream> {
+    #[inline]
+    fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> {
+        let _ = self.get_mut().shutdown();
+        Ok(())
+    }
+
+    #[inline]
+    fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> {
+        self.get_mut().get_mut().set_nodelay(nodelay)
+    }
+
+    #[inline]
+    fn set_linger(&mut self, dur: Option<time::Duration>) -> io::Result<()> {
+        self.get_mut().get_mut().set_linger(dur)
+    }
+}
+
+#[cfg(feature="tls")]
+use tokio_tls::TlsStream;
+
+#[cfg(feature="tls")]
+impl IoStream for TlsStream<TcpStream> {
+    #[inline]
+    fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> {
+        let _ = self.get_mut().shutdown();
+        Ok(())
+    }
+
+    #[inline]
+    fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> {
+        self.get_mut().get_mut().set_nodelay(nodelay)
+    }
+
+    #[inline]
+    fn set_linger(&mut self, dur: Option<time::Duration>) -> io::Result<()> {
+        self.get_mut().get_mut().set_linger(dur)
+    }
+}
diff --git a/src/h1.rs b/src/h1.rs
index e0358a15..e5f592dd 100644
--- a/src/h1.rs
+++ b/src/h1.rs
@@ -10,12 +10,11 @@ use http::{Uri, Method, Version, HttpTryFrom, HeaderMap};
 use http::header::{self, HeaderName, HeaderValue};
 use bytes::{Bytes, BytesMut, BufMut};
 use futures::{Future, Poll, Async};
-use tokio_io::{AsyncRead, AsyncWrite};
 use tokio_core::reactor::Timeout;
 
 use pipeline::Pipeline;
 use encoding::PayloadType;
-use channel::{HttpHandler, HttpHandlerTask};
+use channel::{HttpHandler, HttpHandlerTask, IoStream};
 use h1writer::{Writer, H1Writer};
 use worker::WorkerSettings;
 use httpcodes::HTTPNotFound;
@@ -57,7 +56,7 @@ enum Item {
     Http2,
 }
 
-pub(crate) struct Http1<T: AsyncWrite + 'static, H: 'static> {
+pub(crate) struct Http1<T: IoStream, H: 'static> {
     flags: Flags,
     settings: Rc<WorkerSettings<H>>,
     addr: Option<SocketAddr>,
@@ -74,8 +73,7 @@ struct Entry {
 }
 
 impl<T, H> Http1<T, H>
-    where T: AsyncRead + AsyncWrite + 'static,
-          H: HttpHandler + 'static
+    where T: IoStream, H: HttpHandler + 'static
 {
     pub fn new(h: Rc<WorkerSettings<H>>, stream: T, addr: Option<SocketAddr>) -> Self {
         let bytes = h.get_shared_bytes();
@@ -417,7 +415,7 @@ impl Reader {
     pub fn parse<T, H>(&mut self, io: &mut T,
                        buf: &mut BytesMut,
                        settings: &WorkerSettings<H>) -> Poll<Item, ReaderError>
-        where T: AsyncRead
+        where T: IoStream
     {
         // read payload
         if self.payload.is_some() {
@@ -507,8 +505,8 @@ impl Reader {
         }
     }
 
-    fn read_from_io<T: AsyncRead>(&mut self, io: &mut T, buf: &mut BytesMut)
-                                  -> Poll<usize, io::Error>
+    fn read_from_io<T: IoStream>(&mut self, io: &mut T, buf: &mut BytesMut)
+                                 -> Poll<usize, io::Error>
     {
         unsafe {
             if buf.remaining_mut() < LW_BUFFER_SIZE {
@@ -894,14 +892,17 @@ impl ChunkedState {
 
 #[cfg(test)]
 mod tests {
-    use std::{io, cmp};
-    use bytes::{Bytes, BytesMut};
-    use futures::{Async};
-    use tokio_io::AsyncRead;
+    use std::{io, cmp, time};
+    use std::net::Shutdown;
+    use bytes::{Bytes, BytesMut, Buf};
+    use futures::Async;
+    use tokio_io::{AsyncRead, AsyncWrite};
     use http::{Version, Method};
+
     use super::*;
     use application::HttpApplication;
     use worker::WorkerSettings;
+    use channel::IoStream;
 
     struct Buffer {
         buf: Bytes,
@@ -940,6 +941,28 @@ mod tests {
         }
     }
 
+    impl IoStream for Buffer {
+        fn shutdown(&self, _: Shutdown) -> io::Result<()> {
+            Ok(())
+        }
+        fn set_nodelay(&self, _: bool) -> io::Result<()> {
+            Ok(())
+        }
+        fn set_linger(&self, _: Option<time::Duration>) -> io::Result<()> {
+            Ok(())
+        }
+    }
+    impl io::Write for Buffer {
+        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {Ok(buf.len())}
+        fn flush(&mut self) -> io::Result<()> {Ok(())}
+    }
+    impl AsyncWrite for Buffer {
+        fn shutdown(&mut self) -> Poll<(), io::Error> { Ok(Async::Ready(())) }
+        fn write_buf<B: Buf>(&mut self, _: &mut B) -> Poll<usize, io::Error> {
+            Ok(Async::NotReady)
+        }
+    }
+
     macro_rules! not_ready {
         ($e:expr) => (match $e {
             Ok(Async::NotReady) => (),
diff --git a/src/server.rs b/src/server.rs
index 1833e8ae..d602e276 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -31,7 +31,7 @@ use tokio_openssl::SslStream;
 use actix::actors::signal;
 
 use helpers;
-use channel::{HttpChannel, HttpHandler, IntoHttpHandler};
+use channel::{HttpChannel, HttpHandler, IntoHttpHandler, IoStream, WrapperStream};
 use worker::{Conn, Worker, WorkerSettings, StreamHandlerType, StopWorker};
 
 /// Various server settings
@@ -131,7 +131,7 @@ impl<T: 'static, A: 'static, H: HttpHandler + 'static, U: 'static>  HttpServer<T
 
 impl<T, A, H, U, V> HttpServer<T, A, H, U>
     where A: 'static,
-          T: AsyncRead + AsyncWrite + 'static,
+          T: IoStream,
           H: HttpHandler,
           U: IntoIterator<Item=V> + 'static,
           V: IntoHttpHandler<Handler=H>,
@@ -450,7 +450,7 @@ impl<H: HttpHandler, U, V> HttpServer<SslStream<TcpStream>, net::SocketAddr, H,
     }
 }
 
-impl<T, A, H, U, V> HttpServer<T, A, H, U>
+impl<T, A, H, U, V> HttpServer<WrapperStream<T>, A, H, U>
     where A: 'static,
           T: AsyncRead + AsyncWrite + 'static,
           H: HttpHandler,
@@ -488,7 +488,7 @@ impl<T, A, H, U, V> HttpServer<T, A, H, U>
         // start server
         HttpServer::create(move |ctx| {
             ctx.add_stream(stream.map(
-                move |(t, _)| Conn{io: t, peer: None, http2: false}));
+                move |(t, _)| Conn{io: WrapperStream::new(t), peer: None, http2: false}));
             self
         })
     }
@@ -499,7 +499,7 @@ impl<T, A, H, U, V> HttpServer<T, A, H, U>
 /// Handle `SIGINT`, `SIGTERM`, `SIGQUIT` signals and send `SystemExit(0)`
 /// message to `System` actor.
 impl<T, A, H, U> Handler<signal::Signal> for HttpServer<T, A, H, U>
-    where T: AsyncRead + AsyncWrite + 'static,
+    where T: IoStream,
           H: HttpHandler + 'static,
           U: 'static,
           A: 'static,
@@ -530,13 +530,13 @@ impl<T, A, H, U> Handler<signal::Signal> for HttpServer<T, A, H, U>
 }
 
 impl<T, A, H, U> StreamHandler<Conn<T>, io::Error> for HttpServer<T, A, H, U>
-    where T: AsyncRead + AsyncWrite + 'static,
+    where T: IoStream,
           H: HttpHandler + 'static,
           U: 'static,
           A: 'static {}
 
 impl<T, A, H, U> Handler<Conn<T>, io::Error> for HttpServer<T, A, H, U>
-    where T: AsyncRead + AsyncWrite + 'static,
+    where T: IoStream,
           H: HttpHandler + 'static,
           U: 'static,
           A: 'static,
@@ -573,7 +573,7 @@ pub struct StopServer {
 }
 
 impl<T, A, H, U> Handler<PauseServer> for HttpServer<T, A, H, U>
-    where T: AsyncRead + AsyncWrite + 'static,
+    where T: IoStream,
           H: HttpHandler + 'static,
           U: 'static,
           A: 'static,
@@ -589,7 +589,7 @@ impl<T, A, H, U> Handler<PauseServer> for HttpServer<T, A, H, U>
 }
 
 impl<T, A, H, U> Handler<ResumeServer> for HttpServer<T, A, H, U>
-    where T: AsyncRead + AsyncWrite + 'static,
+    where T: IoStream,
           H: HttpHandler + 'static,
           U: 'static,
           A: 'static,
@@ -605,7 +605,7 @@ impl<T, A, H, U> Handler<ResumeServer> for HttpServer<T, A, H, U>
 }
 
 impl<T, A, H, U> Handler<StopServer> for HttpServer<T, A, H, U>
-    where T: AsyncRead + AsyncWrite + 'static,
+    where T: IoStream,
           H: HttpHandler + 'static,
           U: 'static,
           A: 'static,
diff --git a/src/worker.rs b/src/worker.rs
index c6127d2a..d0f73f63 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -135,7 +135,7 @@ impl<H: HttpHandler + 'static> Worker<H> {
                 slf.shutdown_timeout(ctx, tx, d);
             } else {
                 info!("Force shutdown http worker, {} connections", num);
-                slf.settings.head().traverse::<H>();
+                slf.settings.head().traverse::<TcpStream, H>();
                 let _ = tx.send(false);
                 Arbiter::arbiter().send(StopArbiter(0));
             }
@@ -187,7 +187,7 @@ impl<H> Handler<StopWorker> for Worker<H>
             Self::async_reply(rx.map_err(|_| ()).actfuture())
         } else {
             info!("Force shutdown http worker, {} connections", num);
-            self.settings.head().traverse::<H>();
+            self.settings.head().traverse::<TcpStream, H>();
             Self::reply(false)
         }
     }