From dfa795ff9da9e7064239b9bb4ea55a90d60bfd11 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Wed, 17 Feb 2021 03:18:31 -0800 Subject: [PATCH 1/2] return poll in poll_flush (#2005) --- actix-http/src/h1/dispatcher.rs | 37 +++++++++++++-------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index f7d7f32c3..839f75402 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -13,6 +13,7 @@ use actix_rt::time::{sleep_until, Instant, Sleep}; use actix_service::Service; use bitflags::bitflags; use bytes::{Buf, BytesMut}; +use futures_core::ready; use log::{error, trace}; use pin_project::pin_project; @@ -233,14 +234,10 @@ where } } - /// Flush stream - /// - /// true - got WouldBlock - /// false - didn't get WouldBlock fn poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Result { + ) -> Poll> { let InnerDispatcherProj { io, write_buf, .. } = self.project(); let mut io = Pin::new(io.as_mut().unwrap()); @@ -248,19 +245,18 @@ where let mut written = 0; while written < len { - match io.as_mut().poll_write(cx, &write_buf[written..]) { - Poll::Ready(Ok(0)) => { - return Err(DispatchError::Io(io::Error::new( + match io.as_mut().poll_write(cx, &write_buf[written..])? { + Poll::Ready(0) => { + return Poll::Ready(Err(io::Error::new( io::ErrorKind::WriteZero, "", ))) } - Poll::Ready(Ok(n)) => written += n, + Poll::Ready(n) => written += n, Poll::Pending => { write_buf.advance(written); - return Ok(true); + return Poll::Pending; } - Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), } } @@ -268,9 +264,7 @@ where write_buf.clear(); // flush the io and check if get blocked. - let blocked = io.poll_flush(cx)?.is_pending(); - - Ok(blocked) + io.poll_flush(cx) } fn send_response( @@ -841,14 +835,11 @@ where if inner.flags.contains(Flags::WRITE_DISCONNECT) { Poll::Ready(Ok(())) } else { - // flush buffer and wait on block. - if inner.as_mut().poll_flush(cx)? { - Poll::Pending - } else { - Pin::new(inner.project().io.as_mut().unwrap()) - .poll_shutdown(cx) - .map_err(DispatchError::from) - } + // flush buffer and wait on blocked. + ready!(inner.as_mut().poll_flush(cx))?; + Pin::new(inner.project().io.as_mut().unwrap()) + .poll_shutdown(cx) + .map_err(DispatchError::from) } } else { // read from io stream and fill read buffer. @@ -888,7 +879,7 @@ where // // TODO: what? is WouldBlock good or bad? // want to find a reference for this macOS behavior - if inner.as_mut().poll_flush(cx)? || !drain { + if inner.as_mut().poll_flush(cx)?.is_pending() || !drain { break; } } From 5efea652e342479d6d0a20afc58bd880a691a712 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Wed, 17 Feb 2021 03:55:11 -0800 Subject: [PATCH 2/2] add ClientResponse::timeout (#1931) --- awc/CHANGES.md | 4 ++ awc/src/response.rs | 131 ++++++++++++++++++++++++++++++++------- awc/src/sender.rs | 32 +++++----- awc/tests/test_client.rs | 75 +++++++++++++++++++++- 4 files changed, 202 insertions(+), 40 deletions(-) diff --git a/awc/CHANGES.md b/awc/CHANGES.md index 9224f414d..c67f65560 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -1,9 +1,13 @@ # Changes ## Unreleased - 2021-xx-xx +### Added +* `ClientResponse::timeout` for set the timeout of collecting response body. [#1931] + ### Changed * Feature `cookies` is now optional and enabled by default. [#1981] +[#1931]: https://github.com/actix/actix-web/pull/1931 [#1981]: https://github.com/actix/actix-web/pull/1981 diff --git a/awc/src/response.rs b/awc/src/response.rs index cf687329d..514b8a90b 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -1,20 +1,22 @@ -use std::fmt; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{Context, Poll}; use std::{ cell::{Ref, RefMut}, - mem, + fmt, + future::Future, + io, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, }; +use actix_http::{ + error::PayloadError, + http::{header, HeaderMap, StatusCode, Version}, + Extensions, HttpMessage, Payload, PayloadStream, ResponseHead, +}; +use actix_rt::time::{sleep, Sleep}; use bytes::{Bytes, BytesMut}; use futures_core::{ready, Stream}; - -use actix_http::error::PayloadError; -use actix_http::http::header; -use actix_http::http::{HeaderMap, StatusCode, Version}; -use actix_http::{Extensions, HttpMessage, Payload, PayloadStream, ResponseHead}; use serde::de::DeserializeOwned; #[cfg(feature = "cookies")] @@ -26,6 +28,38 @@ use crate::error::JsonPayloadError; pub struct ClientResponse { pub(crate) head: ResponseHead, pub(crate) payload: Payload, + pub(crate) timeout: ResponseTimeout, +} + +/// helper enum with reusable sleep passed from `SendClientResponse`. +/// See `ClientResponse::_timeout` for reason. +pub(crate) enum ResponseTimeout { + Disabled(Option>>), + Enabled(Pin>), +} + +impl Default for ResponseTimeout { + fn default() -> Self { + Self::Disabled(None) + } +} + +impl ResponseTimeout { + fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> { + match *self { + Self::Enabled(ref mut timeout) => { + if timeout.as_mut().poll(cx).is_ready() { + Err(PayloadError::Io(io::Error::new( + io::ErrorKind::TimedOut, + "Response Payload IO timed out", + ))) + } else { + Ok(()) + } + } + Self::Disabled(_) => Ok(()), + } + } } impl HttpMessage for ClientResponse { @@ -35,6 +69,10 @@ impl HttpMessage for ClientResponse { &self.head.headers } + fn take_payload(&mut self) -> Payload { + std::mem::replace(&mut self.payload, Payload::None) + } + fn extensions(&self) -> Ref<'_, Extensions> { self.head.extensions() } @@ -43,10 +81,6 @@ impl HttpMessage for ClientResponse { self.head.extensions_mut() } - fn take_payload(&mut self) -> Payload { - mem::replace(&mut self.payload, Payload::None) - } - /// Load request cookies. #[cfg(feature = "cookies")] fn cookies(&self) -> Result>>, CookieParseError> { @@ -69,7 +103,11 @@ impl HttpMessage for ClientResponse { impl ClientResponse { /// Create new Request instance pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self { - ClientResponse { head, payload } + ClientResponse { + head, + payload, + timeout: ResponseTimeout::default(), + } } #[inline] @@ -105,8 +143,43 @@ impl ClientResponse { ClientResponse { payload, head: self.head, + timeout: self.timeout, } } + + /// Set a timeout duration for [`ClientResponse`](self::ClientResponse). + /// + /// This duration covers the duration of processing the response body stream + /// and would end it as timeout error when deadline met. + /// + /// Disabled by default. + pub fn timeout(self, dur: Duration) -> Self { + let timeout = match self.timeout { + ResponseTimeout::Disabled(Some(mut timeout)) + | ResponseTimeout::Enabled(mut timeout) => match Instant::now().checked_add(dur) { + Some(deadline) => { + timeout.as_mut().reset(deadline.into()); + ResponseTimeout::Enabled(timeout) + } + None => ResponseTimeout::Enabled(Box::pin(sleep(dur))), + }, + _ => ResponseTimeout::Enabled(Box::pin(sleep(dur))), + }; + + Self { + payload: self.payload, + head: self.head, + timeout, + } + } + + /// This method does not enable timeout. It's used to pass the boxed `Sleep` from + /// `SendClientRequest` and reuse it's heap allocation together with it's slot in + /// timer wheel. + pub(crate) fn _timeout(mut self, timeout: Option>>) -> Self { + self.timeout = ResponseTimeout::Disabled(timeout); + self + } } impl ClientResponse @@ -137,7 +210,10 @@ where type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().payload).poll_next(cx) + let this = self.get_mut(); + this.timeout.poll_timeout(cx)?; + + Pin::new(&mut this.payload).poll_next(cx) } } @@ -156,6 +232,7 @@ impl fmt::Debug for ClientResponse { pub struct MessageBody { length: Option, err: Option, + timeout: ResponseTimeout, fut: Option>, } @@ -181,6 +258,7 @@ where MessageBody { length: len, err: None, + timeout: std::mem::take(&mut res.timeout), fut: Some(ReadBody::new(res.take_payload(), 262_144)), } } @@ -198,6 +276,7 @@ where fut: None, err: Some(e), length: None, + timeout: ResponseTimeout::default(), } } } @@ -221,6 +300,8 @@ where } } + this.timeout.poll_timeout(cx)?; + Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) } } @@ -234,6 +315,7 @@ where pub struct JsonBody { length: Option, err: Option, + timeout: ResponseTimeout, fut: Option>, _phantom: PhantomData, } @@ -244,9 +326,9 @@ where U: DeserializeOwned, { /// Create `JsonBody` for request. - pub fn new(req: &mut ClientResponse) -> Self { + pub fn new(res: &mut ClientResponse) -> Self { // check content-type - let json = if let Ok(Some(mime)) = req.mime_type() { + let json = if let Ok(Some(mime)) = res.mime_type() { mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) } else { false @@ -255,13 +337,15 @@ where return JsonBody { length: None, fut: None, + timeout: ResponseTimeout::default(), err: Some(JsonPayloadError::ContentType), _phantom: PhantomData, }; } let mut len = None; - if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) { + + if let Some(l) = res.headers().get(&header::CONTENT_LENGTH) { if let Ok(s) = l.to_str() { if let Ok(l) = s.parse::() { len = Some(l) @@ -272,7 +356,8 @@ where JsonBody { length: len, err: None, - fut: Some(ReadBody::new(req.take_payload(), 65536)), + timeout: std::mem::take(&mut res.timeout), + fut: Some(ReadBody::new(res.take_payload(), 65536)), _phantom: PhantomData, } } @@ -311,6 +396,10 @@ where } } + self.timeout + .poll_timeout(cx) + .map_err(JsonPayloadError::Payload)?; + let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) } diff --git a/awc/src/sender.rs b/awc/src/sender.rs index a72b129f8..6bac401c5 100644 --- a/awc/src/sender.rs +++ b/awc/src/sender.rs @@ -18,15 +18,11 @@ use actix_http::{ use actix_rt::time::{sleep, Sleep}; use bytes::Bytes; use derive_more::From; -use futures_core::Stream; +use futures_core::{ready, Stream}; use serde::Serialize; #[cfg(feature = "compress")] -use actix_http::encoding::Decoder; -#[cfg(feature = "compress")] -use actix_http::http::header::ContentEncoding; -#[cfg(feature = "compress")] -use actix_http::{Payload, PayloadStream}; +use actix_http::{encoding::Decoder, http::header::ContentEncoding, Payload, PayloadStream}; use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError}; use crate::response::ClientResponse; @@ -61,7 +57,6 @@ impl From for SendRequestError { pub enum SendClientRequest { Fut( Pin>>>, - // FIXME: use a pinned Sleep instead of box. Option>>, bool, ), @@ -88,15 +83,14 @@ impl Future for SendClientRequest { match this { SendClientRequest::Fut(send, delay, response_decompress) => { - if delay.is_some() { - match Pin::new(delay.as_mut().unwrap()).poll(cx) { - Poll::Pending => {} - _ => return Poll::Ready(Err(SendRequestError::Timeout)), + if let Some(delay) = delay { + if delay.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(SendRequestError::Timeout)); } } - let res = futures_core::ready!(Pin::new(send).poll(cx)).map(|res| { - res.map_body(|head, payload| { + let res = ready!(send.as_mut().poll(cx)).map(|res| { + res._timeout(delay.take()).map_body(|head, payload| { if *response_decompress { Payload::Stream(Decoder::from_headers(payload, &head.headers)) } else { @@ -123,13 +117,15 @@ impl Future for SendClientRequest { let this = self.get_mut(); match this { SendClientRequest::Fut(send, delay, _) => { - if delay.is_some() { - match Pin::new(delay.as_mut().unwrap()).poll(cx) { - Poll::Pending => {} - _ => return Poll::Ready(Err(SendRequestError::Timeout)), + if let Some(delay) = delay { + if delay.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(SendRequestError::Timeout)); } } - Pin::new(send).poll(cx) + + send.as_mut() + .poll(cx) + .map_ok(|res| res._timeout(delay.take())) } SendClientRequest::Err(ref mut e) => match e.take() { Some(e) => Poll::Ready(Err(e)), diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs index 7e74d226e..bcbaf3f41 100644 --- a/awc/tests/test_client.rs +++ b/awc/tests/test_client.rs @@ -24,7 +24,7 @@ use actix_web::{ middleware::Compress, test, web, App, Error, HttpMessage, HttpRequest, HttpResponse, }; -use awc::error::SendRequestError; +use awc::error::{JsonPayloadError, PayloadError, SendRequestError}; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ @@ -157,6 +157,79 @@ async fn test_timeout_override() { } } +#[actix_rt::test] +async fn test_response_timeout() { + use futures_util::stream::{once, StreamExt}; + + let srv = test::start(|| { + App::new().service(web::resource("/").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .content_type("application/json") + .streaming(Box::pin(once(async { + actix_rt::time::sleep(Duration::from_millis(200)).await; + Ok::<_, Error>(Bytes::from(STR)) + }))), + ) + }))) + }); + + let client = awc::Client::new(); + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(500)) + .body() + .await + .unwrap(); + assert_eq!(std::str::from_utf8(res.as_ref()).unwrap(), STR); + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .next() + .await + .unwrap(); + match res { + Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut), + _ => panic!("Response error type is not matched"), + } + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .body() + .await; + match res { + Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut), + _ => panic!("Response error type is not matched"), + } + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .json::>() + .await; + match res { + Err(JsonPayloadError::Payload(PayloadError::Io(e))) => { + assert_eq!(e.kind(), std::io::ErrorKind::TimedOut) + } + _ => panic!("Response error type is not matched"), + } +} + #[actix_rt::test] async fn test_connection_reuse() { let num = Arc::new(AtomicUsize::new(0));