From 7d2f63eba2d6d95eb0d0ee3d40732731e51fb95b Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Tue, 26 Jan 2021 00:41:48 -0800 Subject: [PATCH] add ClientResponse::timeout --- awc/src/response.rs | 114 ++++++++++++++++++++++++++++++++++++--- awc/src/sender.rs | 43 +++++++-------- awc/tests/test_client.rs | 75 +++++++++++++++++++++++++- 3 files changed, 202 insertions(+), 30 deletions(-) diff --git a/awc/src/response.rs b/awc/src/response.rs index c3e7d71ce..847b908ae 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -1,9 +1,11 @@ use std::cell::{Ref, RefMut}; use std::fmt; use std::future::Future; +use std::io; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; +use std::time::{Duration, Instant}; use bytes::{Bytes, BytesMut}; use futures_core::{ready, Stream}; @@ -13,6 +15,7 @@ use actix_http::error::{CookieParseError, PayloadError}; use actix_http::http::header::{CONTENT_LENGTH, SET_COOKIE}; use actix_http::http::{HeaderMap, StatusCode, Version}; use actix_http::{Extensions, HttpMessage, Payload, PayloadStream, ResponseHead}; +use actix_rt::time::{sleep, Sleep}; use serde::de::DeserializeOwned; use crate::error::JsonPayloadError; @@ -21,6 +24,41 @@ use crate::error::JsonPayloadError; pub struct ClientResponse { pub(crate) head: ResponseHead, pub(crate) payload: Payload, + pub(crate) timeout: ResponseTimeout, +} + +// a helper enum for response timeout for reusing the boxed sleep. +// It's pass from `SendClientRequest`. +pub(crate) enum ResponseTimeout { + Disabled(Option>>), + Enabled(Option>>), +} + +impl ResponseTimeout { + fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), io::Error> { + match *self { + Self::Disabled(_) => Ok(()), + Self::Enabled(Some(ref mut timeout)) => { + if timeout.as_mut().poll(cx).is_ready() { + Ok(()) + } else { + Err(Self::err()) + } + } + Self::Enabled(None) => Err(Self::err()), + } + } + + fn take(&mut self) -> Option>> { + match *self { + Self::Disabled(_) => None, + Self::Enabled(ref mut timeout) => timeout.take(), + } + } + + fn err() -> io::Error { + io::Error::new(io::ErrorKind::TimedOut, "Response Payload IO timed out") + } } impl HttpMessage for ClientResponse { @@ -30,6 +68,10 @@ impl HttpMessage for ClientResponse { &self.head.headers } + fn take_payload(&mut self) -> Payload { + std::mem::replace(&mut self.payload, Payload::None) + } + fn extensions(&self) -> Ref<'_, Extensions> { self.head.extensions() } @@ -38,10 +80,6 @@ impl HttpMessage for ClientResponse { self.head.extensions_mut() } - fn take_payload(&mut self) -> Payload { - std::mem::replace(&mut self.payload, Payload::None) - } - /// Load request cookies. #[inline] fn cookies(&self) -> Result>>, CookieParseError> { @@ -65,7 +103,11 @@ impl HttpMessage for ClientResponse { impl ClientResponse { /// Create new Request instance pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self { - ClientResponse { head, payload } + ClientResponse { + head, + payload, + timeout: ResponseTimeout::Disabled(None), + } } #[inline] @@ -101,8 +143,45 @@ impl ClientResponse { ClientResponse { payload, head: self.head, + timeout: self.timeout, } } + + /// Set a timeout duration for [`ClientResponse`](self::ClientResponse). + /// + /// This duration covers the duration of processing the response body stream + /// and would end it as timeout error when deadline met. + /// + /// Disabled by default. + pub fn timeout(self, dur: Duration) -> Self { + let timeout = match self.timeout { + ResponseTimeout::Disabled(Some(mut timeout)) + | ResponseTimeout::Enabled(Some(mut timeout)) => { + match Instant::now().checked_add(dur) { + Some(deadline) => { + timeout.as_mut().reset(deadline.into()); + ResponseTimeout::Enabled(Some(timeout)) + } + None => ResponseTimeout::Enabled(Some(Box::pin(sleep(dur)))), + } + } + _ => ResponseTimeout::Enabled(Some(Box::pin(sleep(dur)))), + }; + + Self { + payload: self.payload, + head: self.head, + timeout, + } + } + + // this method does not enable timeout. It's used to pass the boxed Sleep from + // `SendClientRequest` and reuse it's heap allocation together with it's slot + // in timer wheel. + pub(crate) fn _timeout(mut self, timeout: Option>>) -> Self { + self.timeout = ResponseTimeout::Disabled(timeout); + self + } } impl ClientResponse @@ -136,7 +215,10 @@ where self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - Pin::new(&mut self.get_mut().payload).poll_next(cx) + let this = self.get_mut(); + this.timeout.poll_timeout(cx).map_err(PayloadError::Io)?; + + Pin::new(&mut this.payload).poll_next(cx) } } @@ -155,6 +237,7 @@ impl fmt::Debug for ClientResponse { pub struct MessageBody { length: Option, err: Option, + timeout: Option>>, fut: Option>, } @@ -180,6 +263,7 @@ where MessageBody { length: len, err: None, + timeout: res.timeout.take(), fut: Some(ReadBody::new(res.take_payload(), 262_144)), } } @@ -197,6 +281,7 @@ where fut: None, err: Some(e), length: None, + timeout: None, } } } @@ -220,6 +305,12 @@ where } } + if let Some(ref mut timeout) = this.timeout { + if timeout.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(PayloadError::Io(ResponseTimeout::err()))); + } + } + Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) } } @@ -233,6 +324,7 @@ where pub struct JsonBody { length: Option, err: Option, + timeout: Option>>, fut: Option>, _phantom: PhantomData, } @@ -254,6 +346,7 @@ where return JsonBody { length: None, fut: None, + timeout: None, err: Some(JsonPayloadError::ContentType), _phantom: PhantomData, }; @@ -271,6 +364,7 @@ where JsonBody { length: len, err: None, + timeout: req.timeout.take(), fut: Some(ReadBody::new(req.take_payload(), 65536)), _phantom: PhantomData, } @@ -312,6 +406,14 @@ where } } + if let Some(ref mut timeout) = self.timeout { + if timeout.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Io( + ResponseTimeout::err(), + )))); + } + } + let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) } diff --git a/awc/src/sender.rs b/awc/src/sender.rs index 9fb821a0e..c0f89075e 100644 --- a/awc/src/sender.rs +++ b/awc/src/sender.rs @@ -5,23 +5,20 @@ use std::rc::Rc; use std::task::{Context, Poll}; use std::time::Duration; -use actix_rt::time::{sleep, Sleep}; -use bytes::Bytes; -use derive_more::From; -use futures_core::Stream; -use serde::Serialize; - use actix_http::body::{Body, BodyStream}; use actix_http::http::header::{self, IntoHeaderValue}; use actix_http::http::{Error as HttpError, HeaderMap, HeaderName}; use actix_http::{Error, RequestHead}; +use actix_rt::time::{sleep, Sleep}; +use bytes::Bytes; +use derive_more::From; +use futures_core::{ready, Stream}; +use serde::Serialize; #[cfg(feature = "compress")] -use actix_http::encoding::Decoder; -#[cfg(feature = "compress")] -use actix_http::http::header::ContentEncoding; -#[cfg(feature = "compress")] -use actix_http::{Payload, PayloadStream}; +use actix_http::{ + encoding::Decoder, http::header::ContentEncoding, Payload, PayloadStream, +}; use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError}; use crate::response::ClientResponse; @@ -56,7 +53,6 @@ impl From for SendRequestError { pub enum SendClientRequest { Fut( Pin>>>, - // FIXME: use a pinned Sleep instead of box. Option>>, bool, ), @@ -84,15 +80,14 @@ impl Future for SendClientRequest { match this { SendClientRequest::Fut(send, delay, response_decompress) => { - if delay.is_some() { - match Pin::new(delay.as_mut().unwrap()).poll(cx) { - Poll::Pending => {} - _ => return Poll::Ready(Err(SendRequestError::Timeout)), + if let Some(delay) = delay { + if delay.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(SendRequestError::Timeout)); } } - let res = futures_core::ready!(Pin::new(send).poll(cx)).map(|res| { - res.map_body(|head, payload| { + let res = ready!(send.as_mut().poll(cx)).map(|res| { + res._timeout(delay.take()).map_body(|head, payload| { if *response_decompress { Payload::Stream(Decoder::from_headers( payload, @@ -125,13 +120,15 @@ impl Future for SendClientRequest { let this = self.get_mut(); match this { SendClientRequest::Fut(send, delay, _) => { - if delay.is_some() { - match Pin::new(delay.as_mut().unwrap()).poll(cx) { - Poll::Pending => {} - _ => return Poll::Ready(Err(SendRequestError::Timeout)), + if let Some(delay) = delay { + if delay.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(SendRequestError::Timeout)); } } - Pin::new(send).poll(cx) + + send.as_mut() + .poll(cx) + .map_ok(|res| res._timeout(delay.take())) } SendClientRequest::Err(ref mut e) => match e.take() { Some(e) => Poll::Ready(Err(e)), diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs index 88987e639..71a025c95 100644 --- a/awc/tests/test_client.rs +++ b/awc/tests/test_client.rs @@ -24,7 +24,7 @@ use actix_web::{ middleware::Compress, test, web, App, Error, HttpMessage, HttpRequest, HttpResponse, }; -use awc::error::SendRequestError; +use awc::error::{JsonPayloadError, PayloadError, SendRequestError}; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ @@ -160,6 +160,79 @@ async fn test_timeout_override() { } } +#[actix_rt::test] +async fn test_response_timeout() { + use futures_util::stream::{once, StreamExt}; + + let srv = test::start(|| { + App::new().service(web::resource("/").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .content_type("application/json") + .streaming(Box::pin(once(async { + actix_rt::time::sleep(Duration::from_millis(200)).await; + Ok::<_, Error>(Bytes::from(STR)) + }))), + ) + }))) + }); + + let client = awc::Client::new(); + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(210)) + .body() + .await + .unwrap(); + assert_eq!(std::str::from_utf8(res.as_ref()).unwrap(), STR); + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .next() + .await + .unwrap(); + match res { + Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut), + _ => panic!("Response error type is not matched"), + } + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .body() + .await; + match res { + Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut), + _ => panic!("Response error type is not matched"), + } + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .json::>() + .await; + match res { + Err(JsonPayloadError::Payload(PayloadError::Io(e))) => { + assert_eq!(e.kind(), std::io::ErrorKind::TimedOut) + } + _ => panic!("Response error type is not matched"), + } +} + #[actix_rt::test] async fn test_connection_reuse() { let num = Arc::new(AtomicUsize::new(0));