From 5efea652e342479d6d0a20afc58bd880a691a712 Mon Sep 17 00:00:00 2001
From: fakeshadow <24548779@qq.com>
Date: Wed, 17 Feb 2021 03:55:11 -0800
Subject: [PATCH] add ClientResponse::timeout (#1931)

---
 awc/CHANGES.md           |   4 ++
 awc/src/response.rs      | 131 ++++++++++++++++++++++++++++++++-------
 awc/src/sender.rs        |  32 +++++-----
 awc/tests/test_client.rs |  75 +++++++++++++++++++++-
 4 files changed, 202 insertions(+), 40 deletions(-)

diff --git a/awc/CHANGES.md b/awc/CHANGES.md
index 9224f414..c67f6556 100644
--- a/awc/CHANGES.md
+++ b/awc/CHANGES.md
@@ -1,9 +1,13 @@
 # Changes
 
 ## Unreleased - 2021-xx-xx
+### Added
+* `ClientResponse::timeout` for set the timeout of collecting response body. [#1931]
+
 ### Changed
 * Feature `cookies` is now optional and enabled by default. [#1981]
 
+[#1931]: https://github.com/actix/actix-web/pull/1931
 [#1981]: https://github.com/actix/actix-web/pull/1981
 
 
diff --git a/awc/src/response.rs b/awc/src/response.rs
index cf687329..514b8a90 100644
--- a/awc/src/response.rs
+++ b/awc/src/response.rs
@@ -1,20 +1,22 @@
-use std::fmt;
-use std::future::Future;
-use std::marker::PhantomData;
-use std::pin::Pin;
-use std::task::{Context, Poll};
 use std::{
     cell::{Ref, RefMut},
-    mem,
+    fmt,
+    future::Future,
+    io,
+    marker::PhantomData,
+    pin::Pin,
+    task::{Context, Poll},
+    time::{Duration, Instant},
 };
 
+use actix_http::{
+    error::PayloadError,
+    http::{header, HeaderMap, StatusCode, Version},
+    Extensions, HttpMessage, Payload, PayloadStream, ResponseHead,
+};
+use actix_rt::time::{sleep, Sleep};
 use bytes::{Bytes, BytesMut};
 use futures_core::{ready, Stream};
-
-use actix_http::error::PayloadError;
-use actix_http::http::header;
-use actix_http::http::{HeaderMap, StatusCode, Version};
-use actix_http::{Extensions, HttpMessage, Payload, PayloadStream, ResponseHead};
 use serde::de::DeserializeOwned;
 
 #[cfg(feature = "cookies")]
@@ -26,6 +28,38 @@ use crate::error::JsonPayloadError;
 pub struct ClientResponse<S = PayloadStream> {
     pub(crate) head: ResponseHead,
     pub(crate) payload: Payload<S>,
+    pub(crate) timeout: ResponseTimeout,
+}
+
+/// helper enum with reusable sleep passed from `SendClientResponse`.
+/// See `ClientResponse::_timeout` for reason.
+pub(crate) enum ResponseTimeout {
+    Disabled(Option<Pin<Box<Sleep>>>),
+    Enabled(Pin<Box<Sleep>>),
+}
+
+impl Default for ResponseTimeout {
+    fn default() -> Self {
+        Self::Disabled(None)
+    }
+}
+
+impl ResponseTimeout {
+    fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> {
+        match *self {
+            Self::Enabled(ref mut timeout) => {
+                if timeout.as_mut().poll(cx).is_ready() {
+                    Err(PayloadError::Io(io::Error::new(
+                        io::ErrorKind::TimedOut,
+                        "Response Payload IO timed out",
+                    )))
+                } else {
+                    Ok(())
+                }
+            }
+            Self::Disabled(_) => Ok(()),
+        }
+    }
 }
 
 impl<S> HttpMessage for ClientResponse<S> {
@@ -35,6 +69,10 @@ impl<S> HttpMessage for ClientResponse<S> {
         &self.head.headers
     }
 
+    fn take_payload(&mut self) -> Payload<S> {
+        std::mem::replace(&mut self.payload, Payload::None)
+    }
+
     fn extensions(&self) -> Ref<'_, Extensions> {
         self.head.extensions()
     }
@@ -43,10 +81,6 @@ impl<S> HttpMessage for ClientResponse<S> {
         self.head.extensions_mut()
     }
 
-    fn take_payload(&mut self) -> Payload<S> {
-        mem::replace(&mut self.payload, Payload::None)
-    }
-
     /// Load request cookies.
     #[cfg(feature = "cookies")]
     fn cookies(&self) -> Result<Ref<'_, Vec<Cookie<'static>>>, CookieParseError> {
@@ -69,7 +103,11 @@ impl<S> HttpMessage for ClientResponse<S> {
 impl<S> ClientResponse<S> {
     /// Create new Request instance
     pub(crate) fn new(head: ResponseHead, payload: Payload<S>) -> Self {
-        ClientResponse { head, payload }
+        ClientResponse {
+            head,
+            payload,
+            timeout: ResponseTimeout::default(),
+        }
     }
 
     #[inline]
@@ -105,8 +143,43 @@ impl<S> ClientResponse<S> {
         ClientResponse {
             payload,
             head: self.head,
+            timeout: self.timeout,
         }
     }
+
+    /// Set a timeout duration for [`ClientResponse`](self::ClientResponse).
+    ///
+    /// This duration covers the duration of processing the response body stream
+    /// and would end it as timeout error when deadline met.
+    ///
+    /// Disabled by default.
+    pub fn timeout(self, dur: Duration) -> Self {
+        let timeout = match self.timeout {
+            ResponseTimeout::Disabled(Some(mut timeout))
+            | ResponseTimeout::Enabled(mut timeout) => match Instant::now().checked_add(dur) {
+                Some(deadline) => {
+                    timeout.as_mut().reset(deadline.into());
+                    ResponseTimeout::Enabled(timeout)
+                }
+                None => ResponseTimeout::Enabled(Box::pin(sleep(dur))),
+            },
+            _ => ResponseTimeout::Enabled(Box::pin(sleep(dur))),
+        };
+
+        Self {
+            payload: self.payload,
+            head: self.head,
+            timeout,
+        }
+    }
+
+    /// This method does not enable timeout. It's used to pass the boxed `Sleep` from
+    /// `SendClientRequest` and reuse it's heap allocation together with it's slot in
+    /// timer wheel.
+    pub(crate) fn _timeout(mut self, timeout: Option<Pin<Box<Sleep>>>) -> Self {
+        self.timeout = ResponseTimeout::Disabled(timeout);
+        self
+    }
 }
 
 impl<S> ClientResponse<S>
@@ -137,7 +210,10 @@ where
     type Item = Result<Bytes, PayloadError>;
 
     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        Pin::new(&mut self.get_mut().payload).poll_next(cx)
+        let this = self.get_mut();
+        this.timeout.poll_timeout(cx)?;
+
+        Pin::new(&mut this.payload).poll_next(cx)
     }
 }
 
@@ -156,6 +232,7 @@ impl<S> fmt::Debug for ClientResponse<S> {
 pub struct MessageBody<S> {
     length: Option<usize>,
     err: Option<PayloadError>,
+    timeout: ResponseTimeout,
     fut: Option<ReadBody<S>>,
 }
 
@@ -181,6 +258,7 @@ where
         MessageBody {
             length: len,
             err: None,
+            timeout: std::mem::take(&mut res.timeout),
             fut: Some(ReadBody::new(res.take_payload(), 262_144)),
         }
     }
@@ -198,6 +276,7 @@ where
             fut: None,
             err: Some(e),
             length: None,
+            timeout: ResponseTimeout::default(),
         }
     }
 }
@@ -221,6 +300,8 @@ where
             }
         }
 
+        this.timeout.poll_timeout(cx)?;
+
         Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx)
     }
 }
@@ -234,6 +315,7 @@ where
 pub struct JsonBody<S, U> {
     length: Option<usize>,
     err: Option<JsonPayloadError>,
+    timeout: ResponseTimeout,
     fut: Option<ReadBody<S>>,
     _phantom: PhantomData<U>,
 }
@@ -244,9 +326,9 @@ where
     U: DeserializeOwned,
 {
     /// Create `JsonBody` for request.
-    pub fn new(req: &mut ClientResponse<S>) -> Self {
+    pub fn new(res: &mut ClientResponse<S>) -> Self {
         // check content-type
-        let json = if let Ok(Some(mime)) = req.mime_type() {
+        let json = if let Ok(Some(mime)) = res.mime_type() {
             mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
         } else {
             false
@@ -255,13 +337,15 @@ where
             return JsonBody {
                 length: None,
                 fut: None,
+                timeout: ResponseTimeout::default(),
                 err: Some(JsonPayloadError::ContentType),
                 _phantom: PhantomData,
             };
         }
 
         let mut len = None;
-        if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) {
+
+        if let Some(l) = res.headers().get(&header::CONTENT_LENGTH) {
             if let Ok(s) = l.to_str() {
                 if let Ok(l) = s.parse::<usize>() {
                     len = Some(l)
@@ -272,7 +356,8 @@ where
         JsonBody {
             length: len,
             err: None,
-            fut: Some(ReadBody::new(req.take_payload(), 65536)),
+            timeout: std::mem::take(&mut res.timeout),
+            fut: Some(ReadBody::new(res.take_payload(), 65536)),
             _phantom: PhantomData,
         }
     }
@@ -311,6 +396,10 @@ where
             }
         }
 
+        self.timeout
+            .poll_timeout(cx)
+            .map_err(JsonPayloadError::Payload)?;
+
         let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?;
         Poll::Ready(serde_json::from_slice::<U>(&body).map_err(JsonPayloadError::from))
     }
diff --git a/awc/src/sender.rs b/awc/src/sender.rs
index a72b129f..6bac401c 100644
--- a/awc/src/sender.rs
+++ b/awc/src/sender.rs
@@ -18,15 +18,11 @@ use actix_http::{
 use actix_rt::time::{sleep, Sleep};
 use bytes::Bytes;
 use derive_more::From;
-use futures_core::Stream;
+use futures_core::{ready, Stream};
 use serde::Serialize;
 
 #[cfg(feature = "compress")]
-use actix_http::encoding::Decoder;
-#[cfg(feature = "compress")]
-use actix_http::http::header::ContentEncoding;
-#[cfg(feature = "compress")]
-use actix_http::{Payload, PayloadStream};
+use actix_http::{encoding::Decoder, http::header::ContentEncoding, Payload, PayloadStream};
 
 use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError};
 use crate::response::ClientResponse;
@@ -61,7 +57,6 @@ impl From<PrepForSendingError> for SendRequestError {
 pub enum SendClientRequest {
     Fut(
         Pin<Box<dyn Future<Output = Result<ClientResponse, SendRequestError>>>>,
-        // FIXME: use a pinned Sleep instead of box.
         Option<Pin<Box<Sleep>>>,
         bool,
     ),
@@ -88,15 +83,14 @@ impl Future for SendClientRequest {
 
         match this {
             SendClientRequest::Fut(send, delay, response_decompress) => {
-                if delay.is_some() {
-                    match Pin::new(delay.as_mut().unwrap()).poll(cx) {
-                        Poll::Pending => {}
-                        _ => return Poll::Ready(Err(SendRequestError::Timeout)),
+                if let Some(delay) = delay {
+                    if delay.as_mut().poll(cx).is_ready() {
+                        return Poll::Ready(Err(SendRequestError::Timeout));
                     }
                 }
 
-                let res = futures_core::ready!(Pin::new(send).poll(cx)).map(|res| {
-                    res.map_body(|head, payload| {
+                let res = ready!(send.as_mut().poll(cx)).map(|res| {
+                    res._timeout(delay.take()).map_body(|head, payload| {
                         if *response_decompress {
                             Payload::Stream(Decoder::from_headers(payload, &head.headers))
                         } else {
@@ -123,13 +117,15 @@ impl Future for SendClientRequest {
         let this = self.get_mut();
         match this {
             SendClientRequest::Fut(send, delay, _) => {
-                if delay.is_some() {
-                    match Pin::new(delay.as_mut().unwrap()).poll(cx) {
-                        Poll::Pending => {}
-                        _ => return Poll::Ready(Err(SendRequestError::Timeout)),
+                if let Some(delay) = delay {
+                    if delay.as_mut().poll(cx).is_ready() {
+                        return Poll::Ready(Err(SendRequestError::Timeout));
                     }
                 }
-                Pin::new(send).poll(cx)
+
+                send.as_mut()
+                    .poll(cx)
+                    .map_ok(|res| res._timeout(delay.take()))
             }
             SendClientRequest::Err(ref mut e) => match e.take() {
                 Some(e) => Poll::Ready(Err(e)),
diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs
index 7e74d226..bcbaf3f4 100644
--- a/awc/tests/test_client.rs
+++ b/awc/tests/test_client.rs
@@ -24,7 +24,7 @@ use actix_web::{
     middleware::Compress,
     test, web, App, Error, HttpMessage, HttpRequest, HttpResponse,
 };
-use awc::error::SendRequestError;
+use awc::error::{JsonPayloadError, PayloadError, SendRequestError};
 
 const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
                    Hello World Hello World Hello World Hello World Hello World \
@@ -157,6 +157,79 @@ async fn test_timeout_override() {
     }
 }
 
+#[actix_rt::test]
+async fn test_response_timeout() {
+    use futures_util::stream::{once, StreamExt};
+
+    let srv = test::start(|| {
+        App::new().service(web::resource("/").route(web::to(|| async {
+            Ok::<_, Error>(
+                HttpResponse::Ok()
+                    .content_type("application/json")
+                    .streaming(Box::pin(once(async {
+                        actix_rt::time::sleep(Duration::from_millis(200)).await;
+                        Ok::<_, Error>(Bytes::from(STR))
+                    }))),
+            )
+        })))
+    });
+
+    let client = awc::Client::new();
+
+    let res = client
+        .get(srv.url("/"))
+        .send()
+        .await
+        .unwrap()
+        .timeout(Duration::from_millis(500))
+        .body()
+        .await
+        .unwrap();
+    assert_eq!(std::str::from_utf8(res.as_ref()).unwrap(), STR);
+
+    let res = client
+        .get(srv.url("/"))
+        .send()
+        .await
+        .unwrap()
+        .timeout(Duration::from_millis(100))
+        .next()
+        .await
+        .unwrap();
+    match res {
+        Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut),
+        _ => panic!("Response error type is not matched"),
+    }
+
+    let res = client
+        .get(srv.url("/"))
+        .send()
+        .await
+        .unwrap()
+        .timeout(Duration::from_millis(100))
+        .body()
+        .await;
+    match res {
+        Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut),
+        _ => panic!("Response error type is not matched"),
+    }
+
+    let res = client
+        .get(srv.url("/"))
+        .send()
+        .await
+        .unwrap()
+        .timeout(Duration::from_millis(100))
+        .json::<HashMap<String, String>>()
+        .await;
+    match res {
+        Err(JsonPayloadError::Payload(PayloadError::Io(e))) => {
+            assert_eq!(e.kind(), std::io::ErrorKind::TimedOut)
+        }
+        _ => panic!("Response error type is not matched"),
+    }
+}
+
 #[actix_rt::test]
 async fn test_connection_reuse() {
     let num = Arc::new(AtomicUsize::new(0));