From 477e1d69533a515419978377e48428779964c36f Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Thu, 20 Sep 2018 11:16:12 -0700
Subject: [PATCH] add keep-alive service

---
 src/keepalive.rs                 | 123 +++++++++++++++++++++++++++++++
 src/lib.rs                       |   6 +-
 src/{lowrestimer.rs => timer.rs} |  67 +++++++++--------
 3 files changed, 165 insertions(+), 31 deletions(-)
 create mode 100644 src/keepalive.rs
 rename src/{lowrestimer.rs => timer.rs} (67%)

diff --git a/src/keepalive.rs b/src/keepalive.rs
new file mode 100644
index 00000000..fc330f18
--- /dev/null
+++ b/src/keepalive.rs
@@ -0,0 +1,123 @@
+use std::marker::PhantomData;
+use std::time::{Duration, Instant};
+
+use futures::future::{ok, FutureResult};
+use futures::{Async, Future, Poll};
+use tokio_timer::Delay;
+
+use super::service::{NewService, Service};
+use super::timer::{LowResTimer, LowResTimerService};
+use super::Never;
+
+pub struct KeepAlive<R, E, F> {
+    f: F,
+    ka: Duration,
+    timer: LowResTimer,
+    _t: PhantomData<(R, E)>,
+}
+
+impl<R, E, F> KeepAlive<R, E, F>
+where
+    F: Fn() -> E + Clone,
+{
+    pub fn new(ka: Duration, timer: LowResTimer, f: F) -> Self {
+        KeepAlive {
+            f,
+            ka,
+            timer,
+            _t: PhantomData,
+        }
+    }
+}
+
+impl<R, E, F> Clone for KeepAlive<R, E, F>
+where
+    F: Fn() -> E + Clone,
+{
+    fn clone(&self) -> Self {
+        KeepAlive {
+            f: self.f.clone(),
+            ka: self.ka,
+            timer: self.timer.clone(),
+            _t: PhantomData,
+        }
+    }
+}
+
+impl<R, E, F> NewService for KeepAlive<R, E, F>
+where
+    F: Fn() -> E + Clone,
+{
+    type Request = R;
+    type Response = R;
+    type Error = E;
+    type InitError = Never;
+    type Service = KeepAliveService<R, E, F>;
+    type Future = FutureResult<Self::Service, Self::InitError>;
+
+    fn new_service(&self) -> Self::Future {
+        ok(KeepAliveService::new(
+            self.ka,
+            self.timer.timer(),
+            self.f.clone(),
+        ))
+    }
+}
+
+pub struct KeepAliveService<R, E, F> {
+    f: F,
+    ka: Duration,
+    timer: LowResTimerService,
+    delay: Delay,
+    expire: Instant,
+    _t: PhantomData<(R, E)>,
+}
+
+impl<R, E, F> KeepAliveService<R, E, F>
+where
+    F: Fn() -> E,
+{
+    pub fn new(ka: Duration, mut timer: LowResTimerService, f: F) -> Self {
+        let expire = timer.now() + ka;
+        KeepAliveService {
+            f,
+            ka,
+            timer,
+            delay: Delay::new(expire),
+            expire,
+            _t: PhantomData,
+        }
+    }
+}
+
+impl<R, E, F> Service for KeepAliveService<R, E, F>
+where
+    F: Fn() -> E,
+{
+    type Request = R;
+    type Response = R;
+    type Error = E;
+    type Future = FutureResult<R, E>;
+
+    fn poll_ready(&mut self) -> Poll<(), Self::Error> {
+        match self.delay.poll() {
+            Ok(Async::Ready(_)) => {
+                let now = self.timer.now();
+                if self.expire <= now {
+                    Err((self.f)())
+                } else {
+                    self.delay = Delay::new(self.expire);
+                    let _ = self.delay.poll();
+                    Ok(Async::Ready(()))
+                }
+            }
+            Ok(Async::NotReady) => Ok(Async::Ready(())),
+            Err(_) => panic!(),
+        }
+    }
+
+    fn call(&mut self, req: Self::Request) -> Self::Future {
+        self.expire = self.timer.now() + self.ka;
+        ok(req)
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index ae6f3d9c..a1204730 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -60,9 +60,13 @@ pub mod connector;
 pub mod counter;
 pub mod framed;
 pub mod inflight;
+pub mod keepalive;
 pub mod resolver;
 pub mod server;
 pub mod service;
 pub mod ssl;
 pub mod stream;
-pub mod lowrestimer;
+pub mod timer;
+
+#[derive(Copy, Clone, Debug)]
+pub enum Never {}
diff --git a/src/lowrestimer.rs b/src/timer.rs
similarity index 67%
rename from src/lowrestimer.rs
rename to src/timer.rs
index 32087165..df8f83a6 100644
--- a/src/lowrestimer.rs
+++ b/src/timer.rs
@@ -2,12 +2,13 @@ use std::cell::RefCell;
 use std::rc::Rc;
 use std::time::{Duration, Instant};
 
-use futures::{Future, Poll, Async};
 use futures::future::{ok, FutureResult};
+use futures::{Async, Future, Poll};
 use tokio_current_thread::spawn;
 use tokio_timer::sleep;
 
-use super::service::{Service, NewService};
+use super::service::{NewService, Service};
+use super::Never;
 
 #[derive(Clone, Debug)]
 pub struct LowResTimer(Rc<RefCell<Inner>>);
@@ -22,7 +23,7 @@ impl Inner {
     fn new(interval: Duration) -> Self {
         Inner {
             interval,
-            current: None
+            current: None,
         }
     }
 }
@@ -31,6 +32,10 @@ impl LowResTimer {
     pub fn with_interval(interval: Duration) -> LowResTimer {
         LowResTimer(Rc::new(RefCell::new(Inner::new(interval))))
     }
+
+    pub fn timer(&self) -> LowResTimerService {
+        LowResTimerService(self.0.clone())
+    }
 }
 
 impl Default for LowResTimer {
@@ -42,40 +47,30 @@ impl Default for LowResTimer {
 impl NewService for LowResTimer {
     type Request = ();
     type Response = Instant;
-    type Error = ();
-    type InitError = ();
+    type Error = Never;
+    type InitError = Never;
     type Service = LowResTimerService;
     type Future = FutureResult<Self::Service, Self::InitError>;
 
     fn new_service(&self) -> Self::Future {
-        ok(LowResTimerService(self.0.clone()))
+        ok(self.timer())
     }
 }
 
-
 #[derive(Clone, Debug)]
 pub struct LowResTimerService(Rc<RefCell<Inner>>);
 
 impl LowResTimerService {
-    pub fn with_interval(interval: Duration) -> LowResTimerService {
-        LowResTimerService(Rc::new(RefCell::new(Inner::new(interval))))
-    }
-}
-
-impl Service for LowResTimerService {
-    type Request = ();
-    type Response = Instant;
-    type Error = ();
-    type Future = FutureResult<Self::Response, Self::Error>;
-
-    fn poll_ready(&mut self) -> Poll<(), Self::Error> {
-        Ok(Async::Ready(()))
+    pub fn with_resolution(resolution: Duration) -> LowResTimerService {
+        LowResTimerService(Rc::new(RefCell::new(Inner::new(resolution))))
     }
 
-    fn call(&mut self, _: ()) -> Self::Future {
+    /// Get current time. This function has to be called from
+    /// future's poll method, otherwise it panics.
+    pub fn now(&mut self) -> Instant {
         let cur = self.0.borrow().current.clone();
         if let Some(cur) = cur {
-            ok(cur)
+            cur
         } else {
             let now = Instant::now();
             let inner = self.0.clone();
@@ -85,14 +80,26 @@ impl Service for LowResTimerService {
                 b.interval
             };
 
-            spawn(
-                sleep(interval)
-                    .map_err(|_| panic!())
-                    .and_then(move|_| {
-                        inner.borrow_mut().current.take();
-                        Ok(())
-                    }));
-            ok(now)
+            spawn(sleep(interval).map_err(|_| panic!()).and_then(move |_| {
+                inner.borrow_mut().current.take();
+                Ok(())
+            }));
+            now
         }
     }
 }
+
+impl Service for LowResTimerService {
+    type Request = ();
+    type Response = Instant;
+    type Error = Never;
+    type Future = FutureResult<Self::Response, Self::Error>;
+
+    fn poll_ready(&mut self) -> Poll<(), Self::Error> {
+        Ok(Async::Ready(()))
+    }
+
+    fn call(&mut self, _: ()) -> Self::Future {
+        ok(self.now())
+    }
+}