From 18575ee1ee26bfc54bea891d452b0a51f4b42b73 Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Wed, 9 May 2018 16:27:31 -0700
Subject: [PATCH] Add Router::with_async() method for async handler
 registration

---
 CHANGES.md             |   5 ++
 build.rs               |  11 +++-
 src/resource.rs        |  21 +++++++
 src/route.rs           |  72 ++++++++++++++++++++-
 src/with.rs            | 139 +++++++++++++++++++++++++++++++++++++++++
 tests/test_handlers.rs |  78 +++++++++++++++++++++++
 6 files changed, 324 insertions(+), 2 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 51267c76..cf4df3e3 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,5 +1,10 @@
 # Changes
 
+## 0.6.3 (2018-05-xx)
+
+* Add `Router::with_async()` method for async handler registration.
+
+
 ## 0.6.2 (2018-05-09)
 
 * WsWriter trait is optional.
diff --git a/build.rs b/build.rs
index 3b3001f9..7cb25c73 100644
--- a/build.rs
+++ b/build.rs
@@ -1,8 +1,17 @@
 extern crate version_check;
 
 fn main() {
+    let mut has_impl_trait = true;
+
+    match version_check::is_min_version("1.26.0") {
+        Some((true, _)) => println!("cargo:rustc-cfg=actix_impl_trait"),
+        _ => (),
+    };
     match version_check::is_nightly() {
-        Some(true) => println!("cargo:rustc-cfg=actix_nightly"),
+        Some(true) => {
+            println!("cargo:rustc-cfg=actix_nightly");
+            println!("cargo:rustc-cfg=actix_impl_trait");
+        }
         Some(false) => (),
         None => (),
     };
diff --git a/src/resource.rs b/src/resource.rs
index fb08afd9..e52760f4 100644
--- a/src/resource.rs
+++ b/src/resource.rs
@@ -1,9 +1,11 @@
 use std::marker::PhantomData;
 use std::rc::Rc;
 
+use futures::Future;
 use http::{Method, StatusCode};
 use smallvec::SmallVec;
 
+use error::Error;
 use handler::{AsyncResult, FromRequest, Handler, Responder};
 use httprequest::HttpRequest;
 use httpresponse::HttpResponse;
@@ -183,6 +185,25 @@ impl<S: 'static> ResourceHandler<S> {
         self.routes.last_mut().unwrap().with(handler);
     }
 
+    /// Register a new route and add async handler.
+    ///
+    /// This is shortcut for:
+    ///
+    /// ```rust,ignore
+    /// Application::resource("/", |r| r.route().with_async(index)
+    /// ```
+    pub fn with_async<T, F, R, I, E>(&mut self, handler: F)
+    where
+        F: Fn(T) -> R + 'static,
+        R: Future<Item = I, Error = E> + 'static,
+        I: Responder + 'static,
+        E: Into<Error> + 'static,
+        T: FromRequest<S> + 'static,
+    {
+        self.routes.push(Route::default());
+        self.routes.last_mut().unwrap().with_async(handler);
+    }
+
     /// Register a resource middleware
     ///
     /// This is similar to `App's` middlewares, but
diff --git a/src/route.rs b/src/route.rs
index 215a7f22..4ff3279e 100644
--- a/src/route.rs
+++ b/src/route.rs
@@ -13,7 +13,7 @@ use httpresponse::HttpResponse;
 use middleware::{Finished as MiddlewareFinished, Middleware,
                  Response as MiddlewareResponse, Started as MiddlewareStarted};
 use pred::Predicate;
-use with::{ExtractorConfig, With, With2, With3};
+use with::{ExtractorConfig, With, With2, With3, WithAsync};
 
 /// Resource route definition
 ///
@@ -129,6 +129,34 @@ impl<S: 'static> Route<S> {
     ///        |r| r.method(http::Method::GET).with(index)); // <- use `with` extractor
     /// }
     /// ```
+    ///
+    /// It is possible to use tuples for specifing multiple extractors for one
+    /// handler function.
+    ///
+    /// ```rust
+    /// # extern crate bytes;
+    /// # extern crate actix_web;
+    /// # extern crate futures;
+    /// #[macro_use] extern crate serde_derive;
+    /// # use std::collections::HashMap;
+    /// use actix_web::{http, App, Query, Path, Result, Json};
+    ///
+    /// #[derive(Deserialize)]
+    /// struct Info {
+    ///     username: String,
+    /// }
+    ///
+    /// /// extract path info using serde
+    /// fn index(info: (Path<Info>, Query<HashMap<String, String>>, Json<Info>)) -> Result<String> {
+    ///     Ok(format!("Welcome {}!", info.0.username))
+    /// }
+    ///
+    /// fn main() {
+    ///     let app = App::new().resource(
+    ///        "/{username}/index.html",                     // <- define path parameters
+    ///        |r| r.method(http::Method::GET).with(index)); // <- use `with` extractor
+    /// }
+    /// ```
     pub fn with<T, F, R>(&mut self, handler: F) -> ExtractorConfig<S, T>
     where
         F: Fn(T) -> R + 'static,
@@ -140,6 +168,47 @@ impl<S: 'static> Route<S> {
         cfg
     }
 
+    /// Set async handler function, use request extractor for parameters.
+    ///
+    /// ```rust
+    /// # extern crate bytes;
+    /// # extern crate actix_web;
+    /// # extern crate futures;
+    /// #[macro_use] extern crate serde_derive;
+    /// use actix_web::{App, Path, Error, http};
+    /// use futures::Future;
+    ///
+    /// #[derive(Deserialize)]
+    /// struct Info {
+    ///     username: String,
+    /// }
+    ///
+    /// /// extract path info using serde
+    /// fn index(info: Path<Info>) -> Box<Future<Item=&'static str, Error=Error>> {
+    ///     unimplemented!()
+    /// }
+    ///
+    /// fn main() {
+    ///     let app = App::new().resource(
+    ///        "/{username}/index.html",           // <- define path parameters
+    ///        |r| r.method(http::Method::GET)
+    ///               .with_async(index));         // <- use `with` extractor
+    /// }
+    /// ```
+    pub fn with_async<T, F, R, I, E>(&mut self, handler: F) -> ExtractorConfig<S, T>
+    where
+        F: Fn(T) -> R + 'static,
+        R: Future<Item = I, Error = E> + 'static,
+        I: Responder + 'static,
+        E: Into<Error> + 'static,
+        T: FromRequest<S> + 'static,
+    {
+        let cfg = ExtractorConfig::default();
+        self.h(WithAsync::new(handler, Clone::clone(&cfg)));
+        cfg
+    }
+
+    #[doc(hidden)]
     /// Set handler function, use request extractor for both parameters.
     ///
     /// ```rust
@@ -189,6 +258,7 @@ impl<S: 'static> Route<S> {
         (cfg1, cfg2)
     }
 
+    #[doc(hidden)]
     /// Set handler function, use request extractor for all parameters.
     pub fn with3<T1, T2, T3, F, R>(
         &mut self, handler: F,
diff --git a/src/with.rs b/src/with.rs
index fa3d7d80..dca600bb 100644
--- a/src/with.rs
+++ b/src/with.rs
@@ -167,6 +167,145 @@ where
     }
 }
 
+pub struct WithAsync<T, S, F, R, I, E>
+where
+    F: Fn(T) -> R,
+    R: Future<Item = I, Error = E>,
+    I: Responder,
+    E: Into<E>,
+    T: FromRequest<S>,
+    S: 'static,
+{
+    hnd: Rc<UnsafeCell<F>>,
+    cfg: ExtractorConfig<S, T>,
+    _s: PhantomData<S>,
+}
+
+impl<T, S, F, R, I, E> WithAsync<T, S, F, R, I, E>
+where
+    F: Fn(T) -> R,
+    R: Future<Item = I, Error = E>,
+    I: Responder,
+    E: Into<Error>,
+    T: FromRequest<S>,
+    S: 'static,
+{
+    pub fn new(f: F, cfg: ExtractorConfig<S, T>) -> Self {
+        WithAsync {
+            cfg,
+            hnd: Rc::new(UnsafeCell::new(f)),
+            _s: PhantomData,
+        }
+    }
+}
+
+impl<T, S, F, R, I, E> Handler<S> for WithAsync<T, S, F, R, I, E>
+where
+    F: Fn(T) -> R + 'static,
+    R: Future<Item = I, Error = E> + 'static,
+    I: Responder + 'static,
+    E: Into<Error> + 'static,
+    T: FromRequest<S> + 'static,
+    S: 'static,
+{
+    type Result = AsyncResult<HttpResponse>;
+
+    fn handle(&mut self, req: HttpRequest<S>) -> Self::Result {
+        let mut fut = WithAsyncHandlerFut {
+            req,
+            started: false,
+            hnd: Rc::clone(&self.hnd),
+            cfg: self.cfg.clone(),
+            fut1: None,
+            fut2: None,
+            fut3: None,
+        };
+
+        match fut.poll() {
+            Ok(Async::Ready(resp)) => AsyncResult::ok(resp),
+            Ok(Async::NotReady) => AsyncResult::async(Box::new(fut)),
+            Err(e) => AsyncResult::err(e),
+        }
+    }
+}
+
+struct WithAsyncHandlerFut<T, S, F, R, I, E>
+where
+    F: Fn(T) -> R,
+    R: Future<Item = I, Error = E> + 'static,
+    I: Responder + 'static,
+    E: Into<Error> + 'static,
+    T: FromRequest<S> + 'static,
+    S: 'static,
+{
+    started: bool,
+    hnd: Rc<UnsafeCell<F>>,
+    cfg: ExtractorConfig<S, T>,
+    req: HttpRequest<S>,
+    fut1: Option<Box<Future<Item = T, Error = Error>>>,
+    fut2: Option<R>,
+    fut3: Option<Box<Future<Item = HttpResponse, Error = Error>>>,
+}
+
+impl<T, S, F, R, I, E> Future for WithAsyncHandlerFut<T, S, F, R, I, E>
+where
+    F: Fn(T) -> R,
+    R: Future<Item = I, Error = E> + 'static,
+    I: Responder + 'static,
+    E: Into<Error> + 'static,
+    T: FromRequest<S> + 'static,
+    S: 'static,
+{
+    type Item = HttpResponse;
+    type Error = Error;
+
+    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
+        if let Some(ref mut fut) = self.fut3 {
+            return fut.poll();
+        }
+
+        if self.fut2.is_some() {
+            return match self.fut2.as_mut().unwrap().poll() {
+                Ok(Async::NotReady) => Ok(Async::NotReady),
+                Ok(Async::Ready(r)) => match r.respond_to(&self.req) {
+                    Ok(r) => match r.into().into() {
+                        AsyncResultItem::Err(err) => Err(err),
+                        AsyncResultItem::Ok(resp) => Ok(Async::Ready(resp)),
+                        AsyncResultItem::Future(fut) => {
+                            self.fut3 = Some(fut);
+                            self.poll()
+                        }
+                    },
+                    Err(e) => Err(e.into()),
+                },
+                Err(e) => Err(e.into()),
+            };
+        }
+
+        let item = if !self.started {
+            self.started = true;
+            let reply = T::from_request(&self.req, self.cfg.as_ref()).into();
+            match reply.into() {
+                AsyncResultItem::Err(err) => return Err(err),
+                AsyncResultItem::Ok(msg) => msg,
+                AsyncResultItem::Future(fut) => {
+                    self.fut1 = Some(fut);
+                    return self.poll();
+                }
+            }
+        } else {
+            match self.fut1.as_mut().unwrap().poll()? {
+                Async::Ready(item) => item,
+                Async::NotReady => return Ok(Async::NotReady),
+            }
+        };
+
+        let hnd: &mut F = unsafe { &mut *self.hnd.get() };
+        self.fut2 = Some((*hnd)(item));
+        self.poll()
+    }
+}
+
 pub struct With2<T1, T2, S, F, R>
 where
     F: Fn(T1, T2) -> R,
diff --git a/tests/test_handlers.rs b/tests/test_handlers.rs
index 8aea34d0..42a9f3ac 100644
--- a/tests/test_handlers.rs
+++ b/tests/test_handlers.rs
@@ -9,6 +9,7 @@ extern crate tokio_core;
 extern crate serde_derive;
 extern crate serde_json;
 
+use std::io;
 use std::time::Duration;
 
 use actix::*;
@@ -377,6 +378,83 @@ fn test_path_and_query_extractor2_async4() {
     assert_eq!(response.status(), StatusCode::BAD_REQUEST);
 }
 
+#[cfg(actix_impl_trait)]
+fn test_impl_trait(
+    data: (Json<Value>, Path<PParam>, Query<PParam>),
+) -> impl Future<Item = String, Error = io::Error> {
+    Timeout::new(Duration::from_millis(10), &Arbiter::handle())
+        .unwrap()
+        .and_then(move |_| {
+            Ok(format!(
+                "Welcome {} - {}!",
+                data.1.username,
+                (data.0).0
+            ))
+        })
+}
+
+#[cfg(actix_impl_trait)]
+fn test_impl_trait_err(
+    _data: (Json<Value>, Path<PParam>, Query<PParam>),
+) -> impl Future<Item = String, Error = io::Error> {
+    Timeout::new(Duration::from_millis(10), &Arbiter::handle())
+        .unwrap()
+        .and_then(move |_| Err(io::Error::new(io::ErrorKind::Other, "other")))
+}
+
+#[cfg(actix_impl_trait)]
+#[test]
+fn test_path_and_query_extractor2_async4_impl_trait() {
+    let mut srv = test::TestServer::new(|app| {
+        app.resource("/{username}/index.html", |r| {
+            r.route().with_async(test_impl_trait)
+        });
+    });
+
+    // client request
+    let request = srv.post()
+        .uri(srv.url("/test1/index.html?username=test2"))
+        .header("content-type", "application/json")
+        .body("{\"test\": 1}")
+        .unwrap();
+    let response = srv.execute(request.send()).unwrap();
+    assert!(response.status().is_success());
+
+    // read response
+    let bytes = srv.execute(response.body()).unwrap();
+    assert_eq!(
+        bytes,
+        Bytes::from_static(b"Welcome test1 - {\"test\":1}!")
+    );
+
+    // client request
+    let request = srv.get()
+        .uri(srv.url("/test1/index.html"))
+        .finish()
+        .unwrap();
+    let response = srv.execute(request.send()).unwrap();
+    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
+}
+
+#[cfg(actix_impl_trait)]
+#[test]
+fn test_path_and_query_extractor2_async4_impl_trait_err() {
+    let mut srv = test::TestServer::new(|app| {
+        app.resource("/{username}/index.html", |r| {
+            r.route().with_async(test_impl_trait_err)
+        });
+    });
+
+    // client request
+    let request = srv.post()
+        .uri(srv.url("/test1/index.html?username=test2"))
+        .header("content-type", "application/json")
+        .body("{\"test\": 1}")
+        .unwrap();
+    let response = srv.execute(request.send()).unwrap();
+    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
+}
+
 #[test]
 fn test_non_ascii_route() {
     let mut srv = test::TestServer::new(|app| {