From 39c39028180511a6a135023bf54d9e4bf5f64318 Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Fri, 14 Sep 2018 13:07:38 -0700
Subject: [PATCH] rename Connections to more generic Counter and export it

---
 src/counter.rs       | 76 ++++++++++++++++++++++++++++++++++++
 src/lib.rs           |  1 +
 src/server/mod.rs    |  3 --
 src/server/worker.rs | 91 ++++++--------------------------------------
 src/ssl/mod.rs       |  4 +-
 src/ssl/openssl.rs   |  6 +--
 6 files changed, 93 insertions(+), 88 deletions(-)
 create mode 100644 src/counter.rs

diff --git a/src/counter.rs b/src/counter.rs
new file mode 100644
index 00000000..d2bd0226
--- /dev/null
+++ b/src/counter.rs
@@ -0,0 +1,76 @@
+use std::cell::Cell;
+use std::rc::Rc;
+
+use futures::task::AtomicTask;
+
+#[derive(Clone)]
+/// Simple counter with ability to notify task on reaching specific number
+///
+/// Counter could be cloned, total ncount is shared across all clones.
+pub struct Counter(Rc<CounterInner>);
+
+struct CounterInner {
+    count: Cell<usize>,
+    max: usize,
+    task: AtomicTask,
+}
+
+impl Counter {
+    /// Create `Counter` instance and set max value.
+    pub fn new(max: usize) -> Self {
+        Counter(Rc::new(CounterInner {
+            max,
+            count: Cell::new(0),
+            task: AtomicTask::new(),
+        }))
+    }
+
+    pub fn get(&self) -> CounterGuard {
+        CounterGuard::new(self.0.clone())
+    }
+
+    pub fn check(&self) -> bool {
+        self.0.check()
+    }
+
+    pub fn total(&self) -> usize {
+        self.0.count.get()
+    }
+}
+
+pub struct CounterGuard(Rc<CounterInner>);
+
+impl CounterGuard {
+    fn new(inner: Rc<CounterInner>) -> Self {
+        inner.inc();
+        CounterGuard(inner)
+    }
+}
+
+impl Drop for CounterGuard {
+    fn drop(&mut self) {
+        self.0.dec();
+    }
+}
+
+impl CounterInner {
+    fn inc(&self) {
+        let num = self.count.get() + 1;
+        self.count.set(num);
+        if num == self.max {
+            self.task.register();
+        }
+    }
+
+    fn dec(&self) {
+        let num = self.count.get();
+        self.count.set(num - 1);
+        if num == self.max {
+            self.task.notify();
+        }
+    }
+
+    fn check(&self) -> bool {
+        self.count.get() < self.max
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index ca4a6573..446222f7 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -56,6 +56,7 @@ extern crate webpki;
 extern crate webpki_roots;
 
 pub mod connector;
+pub mod counter;
 pub mod framed;
 pub mod resolver;
 pub mod server;
diff --git a/src/server/mod.rs b/src/server/mod.rs
index b080ffe5..0f8eb685 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -10,9 +10,6 @@ mod worker;
 pub use self::server::Server;
 pub use self::services::ServerServiceFactory;
 
-#[allow(unused_imports)]
-pub(crate) use self::worker::{Connections, ConnectionsGuard};
-
 /// Pause accepting incoming connections
 ///
 /// If socket contains some pending connection, they might be dropped.
diff --git a/src/server/worker.rs b/src/server/worker.rs
index eb8f7e48..a741c35d 100644
--- a/src/server/worker.rs
+++ b/src/server/worker.rs
@@ -1,12 +1,9 @@
-use std::cell::Cell;
-use std::rc::Rc;
 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
 use std::sync::Arc;
 use std::{mem, net, time};
 
 use futures::sync::mpsc::{UnboundedReceiver, UnboundedSender};
 use futures::sync::oneshot;
-use futures::task::AtomicTask;
 use futures::{future, Async, Future, Poll, Stream};
 use tokio_current_thread::spawn;
 use tokio_timer::{sleep, Delay};
@@ -17,6 +14,7 @@ use actix::{Arbiter, Message};
 use super::accept::AcceptNotify;
 use super::services::{BoxedServerService, InternalServerServiceFactory, ServerMessage};
 use super::Token;
+use counter::Counter;
 
 pub(crate) enum WorkerCommand {
     Message(Conn),
@@ -50,8 +48,8 @@ pub(crate) fn num_connections() -> usize {
 }
 
 thread_local! {
-    static MAX_CONNS_COUNTER: Connections =
-        Connections::new(MAX_CONNS.load(Ordering::Relaxed));
+    static MAX_CONNS_COUNTER: Counter =
+        Counter::new(MAX_CONNS.load(Ordering::Relaxed));
 }
 
 #[derive(Clone)]
@@ -122,7 +120,7 @@ pub(crate) struct Worker {
     rx: UnboundedReceiver<WorkerCommand>,
     services: Vec<BoxedServerService>,
     availability: WorkerAvailability,
-    conns: Connections,
+    conns: Counter,
     factories: Vec<Box<InternalServerServiceFactory>>,
     state: WorkerState,
 }
@@ -308,8 +306,7 @@ impl Future for Worker {
                     match self.rx.poll() {
                         // handle incoming tcp stream
                         Ok(Async::Ready(Some(WorkerCommand::Message(msg)))) => {
-                            match self.check_readiness()
-                            {
+                            match self.check_readiness() {
                                 Ok(true) => {
                                     let guard = self.conns.get();
                                     spawn(
@@ -320,7 +317,7 @@ impl Future for Worker {
                                                 val
                                             }),
                                     );
-                                    continue
+                                    continue;
                                 }
                                 Ok(false) => {
                                     trace!("Serveice is unsavailable");
@@ -330,12 +327,14 @@ impl Future for Worker {
                                 Err(idx) => {
                                     trace!("Serveice failed, restarting");
                                     self.availability.set(false);
-                                    self.state =
-                                        WorkerState::Restarting(idx, self.factories[idx].create());
+                                    self.state = WorkerState::Restarting(
+                                        idx,
+                                        self.factories[idx].create(),
+                                    );
                                 }
                             }
                             return self.poll();
-                        },
+                        }
                         // `StopWorker` message handler
                         Ok(Async::Ready(Some(WorkerCommand::Stop(graceful, tx)))) => {
                             self.availability.set(false);
@@ -379,71 +378,3 @@ impl Future for Worker {
         Ok(Async::NotReady)
     }
 }
-
-#[derive(Clone)]
-pub(crate) struct Connections(Rc<ConnectionsInner>);
-
-struct ConnectionsInner {
-    count: Cell<usize>,
-    maxconn: usize,
-    task: AtomicTask,
-}
-
-impl Connections {
-    pub fn new(maxconn: usize) -> Self {
-        Connections(Rc::new(ConnectionsInner {
-            maxconn,
-            count: Cell::new(0),
-            task: AtomicTask::new(),
-        }))
-    }
-
-    pub fn get(&self) -> ConnectionsGuard {
-        ConnectionsGuard::new(self.0.clone())
-    }
-
-    pub fn check(&self) -> bool {
-        self.0.check()
-    }
-
-    pub fn total(&self) -> usize {
-        self.0.count.get()
-    }
-}
-
-pub(crate) struct ConnectionsGuard(Rc<ConnectionsInner>);
-
-impl ConnectionsGuard {
-    fn new(inner: Rc<ConnectionsInner>) -> Self {
-        inner.inc();
-        ConnectionsGuard(inner)
-    }
-}
-
-impl Drop for ConnectionsGuard {
-    fn drop(&mut self) {
-        self.0.dec();
-    }
-}
-
-impl ConnectionsInner {
-    fn inc(&self) {
-        let num = self.count.get() + 1;
-        self.count.set(num);
-        if num == self.maxconn {
-            self.task.register();
-        }
-    }
-
-    fn dec(&self) {
-        let num = self.count.get();
-        self.count.set(num - 1);
-        if num == self.maxconn {
-            self.task.notify();
-        }
-    }
-
-    fn check(&self) -> bool {
-        self.count.get() < self.maxconn
-    }
-}
diff --git a/src/ssl/mod.rs b/src/ssl/mod.rs
index 8d56a891..f512ab29 100644
--- a/src/ssl/mod.rs
+++ b/src/ssl/mod.rs
@@ -1,7 +1,7 @@
 //! SSL Services
 use std::sync::atomic::{AtomicUsize, Ordering};
 
-use super::server::Connections;
+use super::counter::Counter;
 
 #[cfg(feature = "ssl")]
 mod openssl;
@@ -21,7 +21,7 @@ pub fn max_concurrent_ssl_connect(num: usize) {
 }
 
 thread_local! {
-    static MAX_CONN_COUNTER: Connections = Connections::new(MAX_CONN.load(Ordering::Relaxed));
+    static MAX_CONN_COUNTER: Counter = Counter::new(MAX_CONN.load(Ordering::Relaxed));
 }
 
 // #[cfg(feature = "tls")]
diff --git a/src/ssl/openssl.rs b/src/ssl/openssl.rs
index 8a84062d..17083c58 100644
--- a/src/ssl/openssl.rs
+++ b/src/ssl/openssl.rs
@@ -7,7 +7,7 @@ use tokio_openssl::{AcceptAsync, ConnectAsync, SslAcceptorExt, SslConnectorExt,
 
 use super::MAX_CONN_COUNTER;
 use connector::ConnectionInfo;
-use server::{Connections, ConnectionsGuard};
+use counter::{Counter, CounterGuard};
 use service::{NewService, Service};
 
 /// Support `SSL` connections via openssl package
@@ -59,7 +59,7 @@ impl<T: AsyncRead + AsyncWrite> NewService for OpensslAcceptor<T> {
 pub struct OpensslAcceptorService<T> {
     acceptor: SslAcceptor,
     io: PhantomData<T>,
-    conns: Connections,
+    conns: Counter,
 }
 
 impl<T: AsyncRead + AsyncWrite> Service for OpensslAcceptorService<T> {
@@ -89,7 +89,7 @@ where
     T: AsyncRead + AsyncWrite,
 {
     fut: AcceptAsync<T>,
-    _guard: ConnectionsGuard,
+    _guard: CounterGuard,
 }
 
 impl<T: AsyncRead + AsyncWrite> Future for OpensslAcceptorServiceFut<T> {