From 27f3d8f51736d91c95016ec691fd6b6d4745896c Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Sat, 25 Dec 2021 01:11:33 +0000 Subject: [PATCH] split awc response module --- awc/CHANGES.md | 7 +- awc/src/connect.rs | 2 +- awc/src/lib.rs | 4 +- awc/src/response.rs | 568 ----------------------------- awc/src/responses/json_body.rs | 196 ++++++++++ awc/src/responses/mod.rs | 262 +++++++++++++ awc/src/responses/read_body.rs | 61 ++++ awc/src/responses/response_body.rs | 128 +++++++ awc/src/ws.rs | 3 +- 9 files changed, 656 insertions(+), 575 deletions(-) delete mode 100644 awc/src/response.rs create mode 100644 awc/src/responses/json_body.rs create mode 100644 awc/src/responses/mod.rs create mode 100644 awc/src/responses/read_body.rs create mode 100644 awc/src/responses/response_body.rs diff --git a/awc/CHANGES.md b/awc/CHANGES.md index 9dcca804e..2864c9112 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -3,10 +3,13 @@ ## Unreleased - 2021-xx-xx - Rename `Connector::{ssl => openssl}`. [#2503] - Improve `Client` instantiation efficiency when using `openssl` by only building connectors once. [#2503] -- `ClientRequest::send_body` now takes an `impl MessageBody`. [#????] +- `ClientRequest::send_body` now takes an `impl MessageBody`. [#2546] +- `impl Future` for `JsonBody` no longer requires the body type be `Unpin`. [#2546] +- `impl Stream` for `ClientResponse` no longer requires the body type be `Unpin`. [#2546] +- `ClientResponse` is no longer `Unpin`. [#2546] [#2503]: https://github.com/actix/actix-web/pull/2503 -[#????]: https://github.com/actix/actix-web/pull/???? +[#2546]: https://github.com/actix/actix-web/pull/2546 ## 3.0.0-beta.14 - 2021-12-17 diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 19870b069..f93014a67 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -16,7 +16,7 @@ use crate::{ client::{ Connect as ClientConnect, ConnectError, Connection, ConnectionIo, SendRequestError, }, - response::ClientResponse, + ClientResponse, }; pub type BoxConnectorService = Rc< diff --git a/awc/src/lib.rs b/awc/src/lib.rs index 00c559406..6c13bc129 100644 --- a/awc/src/lib.rs +++ b/awc/src/lib.rs @@ -113,7 +113,7 @@ pub mod error; mod frozen; pub mod middleware; mod request; -mod response; +mod responses; mod sender; pub mod test; pub mod ws; @@ -128,7 +128,7 @@ pub use self::client::Connector; pub use self::connect::{BoxConnectorService, BoxedSocket, ConnectRequest, ConnectResponse}; pub use self::frozen::{FrozenClientRequest, FrozenSendBuilder}; pub use self::request::ClientRequest; -pub use self::response::{ClientResponse, JsonBody, MessageBody}; +pub use self::responses::{ClientResponse, JsonBody, ResponseBody}; pub use self::sender::SendClientRequest; use std::{convert::TryFrom, rc::Rc, time::Duration}; diff --git a/awc/src/response.rs b/awc/src/response.rs deleted file mode 100644 index 0cbfd73d5..000000000 --- a/awc/src/response.rs +++ /dev/null @@ -1,568 +0,0 @@ -use std::{ - cell::{Ref, RefMut}, - fmt, - future::Future, - io, - marker::PhantomData, - pin::Pin, - task::{Context, Poll}, - time::{Duration, Instant}, -}; - -use actix_http::{ - error::PayloadError, header, header::HeaderMap, BoxedPayloadStream, Extensions, - HttpMessage, Payload, ResponseHead, StatusCode, Version, -}; -use actix_rt::time::{sleep, Sleep}; -use bytes::{Bytes, BytesMut}; -use futures_core::{ready, Stream}; -use pin_project_lite::pin_project; -use serde::de::DeserializeOwned; - -#[cfg(feature = "cookies")] -use crate::cookie::{Cookie, ParseError as CookieParseError}; -use crate::error::JsonPayloadError; - -/// Client Response -pub struct ClientResponse { - pub(crate) head: ResponseHead, - pub(crate) payload: Payload, - pub(crate) timeout: ResponseTimeout, -} - -/// helper enum with reusable sleep passed from `SendClientResponse`. -/// See `ClientResponse::_timeout` for reason. -pub(crate) enum ResponseTimeout { - Disabled(Option>>), - Enabled(Pin>), -} - -impl Default for ResponseTimeout { - fn default() -> Self { - Self::Disabled(None) - } -} - -impl ResponseTimeout { - fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> { - match *self { - Self::Enabled(ref mut timeout) => { - if timeout.as_mut().poll(cx).is_ready() { - Err(PayloadError::Io(io::Error::new( - io::ErrorKind::TimedOut, - "Response Payload IO timed out", - ))) - } else { - Ok(()) - } - } - Self::Disabled(_) => Ok(()), - } - } -} - -impl HttpMessage for ClientResponse { - type Stream = S; - - fn headers(&self) -> &HeaderMap { - &self.head.headers - } - - fn take_payload(&mut self) -> Payload { - std::mem::replace(&mut self.payload, Payload::None) - } - - fn extensions(&self) -> Ref<'_, Extensions> { - self.head.extensions() - } - - fn extensions_mut(&self) -> RefMut<'_, Extensions> { - self.head.extensions_mut() - } -} - -impl ClientResponse { - /// Create new Request instance - pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self { - ClientResponse { - head, - payload, - timeout: ResponseTimeout::default(), - } - } - - #[inline] - pub(crate) fn head(&self) -> &ResponseHead { - &self.head - } - - /// Read the Request Version. - #[inline] - pub fn version(&self) -> Version { - self.head().version - } - - /// Get the status from the server. - #[inline] - pub fn status(&self) -> StatusCode { - self.head().status - } - - #[inline] - /// Returns request's headers. - pub fn headers(&self) -> &HeaderMap { - &self.head().headers - } - - /// Set a body and return previous body value - pub fn map_body(mut self, f: F) -> ClientResponse - where - F: FnOnce(&mut ResponseHead, Payload) -> Payload, - { - let payload = f(&mut self.head, self.payload); - - ClientResponse { - payload, - head: self.head, - timeout: self.timeout, - } - } - - /// Set a timeout duration for [`ClientResponse`](self::ClientResponse). - /// - /// This duration covers the duration of processing the response body stream - /// and would end it as timeout error when deadline met. - /// - /// Disabled by default. - pub fn timeout(self, dur: Duration) -> Self { - let timeout = match self.timeout { - ResponseTimeout::Disabled(Some(mut timeout)) - | ResponseTimeout::Enabled(mut timeout) => match Instant::now().checked_add(dur) { - Some(deadline) => { - timeout.as_mut().reset(deadline.into()); - ResponseTimeout::Enabled(timeout) - } - None => ResponseTimeout::Enabled(Box::pin(sleep(dur))), - }, - _ => ResponseTimeout::Enabled(Box::pin(sleep(dur))), - }; - - Self { - payload: self.payload, - head: self.head, - timeout, - } - } - - /// This method does not enable timeout. It's used to pass the boxed `Sleep` from - /// `SendClientRequest` and reuse it's heap allocation together with it's slot in - /// timer wheel. - pub(crate) fn _timeout(mut self, timeout: Option>>) -> Self { - self.timeout = ResponseTimeout::Disabled(timeout); - self - } - - /// Load request cookies. - #[cfg(feature = "cookies")] - pub fn cookies(&self) -> Result>>, CookieParseError> { - struct Cookies(Vec>); - - if self.extensions().get::().is_none() { - let mut cookies = Vec::new(); - for hdr in self.headers().get_all(&header::SET_COOKIE) { - let s = std::str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?; - cookies.push(Cookie::parse_encoded(s)?.into_owned()); - } - self.extensions_mut().insert(Cookies(cookies)); - } - Ok(Ref::map(self.extensions(), |ext| { - &ext.get::().unwrap().0 - })) - } - - /// Return request cookie. - #[cfg(feature = "cookies")] - pub fn cookie(&self, name: &str) -> Option> { - if let Ok(cookies) = self.cookies() { - for cookie in cookies.iter() { - if cookie.name() == name { - return Some(cookie.to_owned()); - } - } - } - None - } -} - -impl ClientResponse -where - S: Stream>, -{ - /// Loads HTTP response's body. - pub fn body(&mut self) -> MessageBody { - MessageBody::new(self) - } - - /// Loads and parse `application/json` encoded body. - /// Return `JsonBody` future. It resolves to a `T` value. - /// - /// Returns error: - /// - /// * content type is not `application/json` - /// * content length is greater than 256k - pub fn json(&mut self) -> JsonBody { - JsonBody::new(self) - } -} - -impl Stream for ClientResponse -where - S: Stream> + Unpin, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - this.timeout.poll_timeout(cx)?; - - Pin::new(&mut this.payload).poll_next(cx) - } -} - -impl fmt::Debug for ClientResponse { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?; - writeln!(f, " headers:")?; - for (key, val) in self.headers().iter() { - writeln!(f, " {:?}: {:?}", key, val)?; - } - Ok(()) - } -} - -const DEFAULT_BODY_LIMIT: usize = 2 * 1024 * 1024; - -/// Future that resolves to a complete HTTP message body. -pub struct MessageBody { - length: Option, - timeout: ResponseTimeout, - body: Result, Option>, -} - -impl MessageBody -where - S: Stream>, -{ - /// Create `MessageBody` for request. - pub fn new(res: &mut ClientResponse) -> MessageBody { - let length = match res.headers().get(&header::CONTENT_LENGTH) { - Some(value) => { - let len = value.to_str().ok().and_then(|s| s.parse::().ok()); - - match len { - None => return Self::err(PayloadError::UnknownLength), - len => len, - } - } - None => None, - }; - - MessageBody { - length, - timeout: std::mem::take(&mut res.timeout), - body: Ok(ReadBody::new(res.take_payload(), DEFAULT_BODY_LIMIT)), - } - } - - /// Change max size of payload. By default max size is 2048kB - pub fn limit(mut self, limit: usize) -> Self { - if let Ok(ref mut body) = self.body { - body.limit = limit; - } - self - } - - fn err(e: PayloadError) -> Self { - MessageBody { - length: None, - timeout: ResponseTimeout::default(), - body: Err(Some(e)), - } - } -} - -impl Future for MessageBody -where - S: Stream> + Unpin, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - - match this.body { - Err(ref mut err) => Poll::Ready(Err(err.take().unwrap())), - Ok(ref mut body) => { - if let Some(len) = this.length.take() { - if len > body.limit { - return Poll::Ready(Err(PayloadError::Overflow)); - } - } - - this.timeout.poll_timeout(cx)?; - - Pin::new(body).poll(cx) - } - } - } -} - -pin_project! { - /// Response's payload json parser, it resolves to a deserialized `T` value. - /// - /// # Errors - /// `Future` implementation returns error if: - /// * content type is not `application/json` - /// * content length is greater than 64k - pub struct JsonBody { - #[pin] - fut: Option>, - length: Option, - timeout: ResponseTimeout, - err: Option, - _phantom: PhantomData, - } -} - -impl JsonBody -where - B: Stream>, - T: DeserializeOwned, -{ - /// Create `JsonBody` for request. - pub fn new(res: &mut ClientResponse) -> Self { - // check content-type - let json = if let Ok(Some(mime)) = res.mime_type() { - mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) - } else { - false - }; - if !json { - return JsonBody { - length: None, - fut: None, - timeout: ResponseTimeout::default(), - err: Some(JsonPayloadError::ContentType), - _phantom: PhantomData, - }; - } - - let mut len = None; - - if let Some(l) = res.headers().get(&header::CONTENT_LENGTH) { - if let Ok(s) = l.to_str() { - if let Ok(l) = s.parse::() { - len = Some(l) - } - } - } - - JsonBody { - fut: Some(ReadBody::new(res.take_payload(), 65536)), - length: len, - timeout: std::mem::take(&mut res.timeout), - err: None, - _phantom: PhantomData, - } - } - - /// Change max size of payload. By default max size is 64kB - pub fn limit(mut self, limit: usize) -> Self { - if let Some(ref mut fut) = self.fut { - fut.limit = limit; - } - self - } -} - -impl Future for JsonBody -where - B: Stream>, - T: DeserializeOwned, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - if let Some(err) = this.err.take() { - return Poll::Ready(Err(err)); - } - - if let Some(len) = this.length.take() { - let body = Option::as_ref(&this.fut).unwrap(); - if len > body.limit { - return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Overflow))); - } - } - - this.timeout - .poll_timeout(cx) - .map_err(JsonPayloadError::Payload)?; - - let body = ready!(this.fut.as_pin_mut().unwrap().poll(cx))?; - Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) - } -} - -pin_project! { - struct ReadBody { - #[pin] - stream: Payload, - buf: BytesMut, - limit: usize, - } -} - -impl ReadBody { - fn new(stream: Payload, limit: usize) -> Self { - Self { - stream, - buf: BytesMut::new(), - limit, - } - } -} - -impl Future for ReadBody -where - S: Stream>, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - while let Some(chunk) = ready!(this.stream.as_mut().poll_next(cx)?) { - if (this.buf.len() + chunk.len()) > *this.limit { - return Poll::Ready(Err(PayloadError::Overflow)); - } - - this.buf.extend_from_slice(&chunk); - } - - Poll::Ready(Ok(this.buf.split().freeze())) - } -} - -#[cfg(test)] -mod tests { - use serde::{Deserialize, Serialize}; - use static_assertions::assert_impl_all; - - use super::*; - use crate::{any_body::AnyBody, http::header, test::TestResponse}; - - assert_impl_all!(ReadBody<()>: Unpin); - assert_impl_all!(ReadBody: Unpin); - - assert_impl_all!(JsonBody: Unpin); - - assert_impl_all!(ClientResponse: Unpin); - - #[actix_rt::test] - async fn test_body() { - let mut req = TestResponse::with_header((header::CONTENT_LENGTH, "xxxx")).finish(); - match req.body().await.err().unwrap() { - PayloadError::UnknownLength => {} - _ => unreachable!("error"), - } - - let mut req = TestResponse::with_header((header::CONTENT_LENGTH, "10000000")).finish(); - match req.body().await.err().unwrap() { - PayloadError::Overflow => {} - _ => unreachable!("error"), - } - - let mut req = TestResponse::default() - .set_payload(Bytes::from_static(b"test")) - .finish(); - assert_eq!(req.body().await.ok().unwrap(), Bytes::from_static(b"test")); - - let mut req = TestResponse::default() - .set_payload(Bytes::from_static(b"11111111111111")) - .finish(); - match req.body().limit(5).await.err().unwrap() { - PayloadError::Overflow => {} - _ => unreachable!("error"), - } - } - - #[derive(Serialize, Deserialize, PartialEq, Debug)] - struct MyObject { - name: String, - } - - fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool { - match err { - JsonPayloadError::Payload(PayloadError::Overflow) => { - matches!(other, JsonPayloadError::Payload(PayloadError::Overflow)) - } - JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType), - _ => false, - } - } - - #[actix_rt::test] - async fn test_json_body() { - let mut req = TestResponse::default().finish(); - let json = JsonBody::<_, MyObject>::new(&mut req).await; - assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); - - let mut req = TestResponse::default() - .insert_header(( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/text"), - )) - .finish(); - let json = JsonBody::<_, MyObject>::new(&mut req).await; - assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); - - let mut req = TestResponse::default() - .insert_header(( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - )) - .insert_header(( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("10000"), - )) - .finish(); - - let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await; - assert!(json_eq( - json.err().unwrap(), - JsonPayloadError::Payload(PayloadError::Overflow) - )); - - let mut req = TestResponse::default() - .insert_header(( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - )) - .insert_header(( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - )) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .finish(); - - let json = JsonBody::<_, MyObject>::new(&mut req).await; - assert_eq!( - json.ok().unwrap(), - MyObject { - name: "test".to_owned() - } - ); - } -} diff --git a/awc/src/responses/json_body.rs b/awc/src/responses/json_body.rs new file mode 100644 index 000000000..1d81dd2fc --- /dev/null +++ b/awc/src/responses/json_body.rs @@ -0,0 +1,196 @@ +use std::{ + future::Future, + marker::PhantomData, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_http::{error::PayloadError, header, HttpMessage}; +use bytes::Bytes; +use futures_core::{ready, Stream}; +use pin_project_lite::pin_project; +use serde::de::DeserializeOwned; + +use super::{read_body::ReadBody, ResponseTimeout, DEFAULT_BODY_LIMIT}; +use crate::{error::JsonPayloadError, ClientResponse}; + +pin_project! { + /// Consumes body stream and parses JSON, resolving to a deserialized `T` value. + /// + /// # Errors + /// `Future` implementation returns error if: + /// - content type is not `application/json`; + /// - content length is greater than [limit](JsonBody::limit) (default: 2 MiB). + pub struct JsonBody { + #[pin] + body: Option>, + length: Option, + timeout: ResponseTimeout, + err: Option, + _phantom: PhantomData, + } +} + +impl JsonBody +where + B: Stream>, + T: DeserializeOwned, +{ + /// Create `JsonBody` for request. + pub fn new(res: &mut ClientResponse) -> Self { + // check content-type + let json = if let Ok(Some(mime)) = res.mime_type() { + mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) + } else { + false + }; + + if !json { + return JsonBody { + length: None, + body: None, + timeout: ResponseTimeout::default(), + err: Some(JsonPayloadError::ContentType), + _phantom: PhantomData, + }; + } + + let mut len = None; + + if let Some(l) = res.headers().get(&header::CONTENT_LENGTH) { + if let Ok(s) = l.to_str() { + if let Ok(l) = s.parse::() { + len = Some(l) + } + } + } + + JsonBody { + body: Some(ReadBody::new(res.take_payload(), DEFAULT_BODY_LIMIT)), + length: len, + timeout: mem::take(&mut res.timeout), + err: None, + _phantom: PhantomData, + } + } + + /// Change max size of payload. Default limit is 2 MiB. + pub fn limit(mut self, limit: usize) -> Self { + if let Some(ref mut fut) = self.body { + fut.limit = limit; + } + + self + } +} + +impl Future for JsonBody +where + B: Stream>, + T: DeserializeOwned, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if let Some(err) = this.err.take() { + return Poll::Ready(Err(err)); + } + + if let Some(len) = this.length.take() { + let body = Option::as_ref(&this.body).unwrap(); + if len > body.limit { + return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Overflow))); + } + } + + this.timeout + .poll_timeout(cx) + .map_err(JsonPayloadError::Payload)?; + + let body = ready!(this.body.as_pin_mut().unwrap().poll(cx))?; + Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) + } +} + +#[cfg(test)] +mod tests { + use actix_http::BoxedPayloadStream; + use serde::{Deserialize, Serialize}; + use static_assertions::assert_impl_all; + + use super::*; + use crate::{http::header, test::TestResponse}; + + assert_impl_all!(JsonBody: Unpin); + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct MyObject { + name: String, + } + + fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool { + match err { + JsonPayloadError::Payload(PayloadError::Overflow) => { + matches!(other, JsonPayloadError::Payload(PayloadError::Overflow)) + } + JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType), + _ => false, + } + } + + #[actix_rt::test] + async fn test_json_body() { + let mut req = TestResponse::default().finish(); + let json = JsonBody::<_, MyObject>::new(&mut req).await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + + let mut req = TestResponse::default() + .insert_header(( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/text"), + )) + .finish(); + let json = JsonBody::<_, MyObject>::new(&mut req).await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + + let mut req = TestResponse::default() + .insert_header(( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + )) + .insert_header(( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("10000"), + )) + .finish(); + + let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await; + assert!(json_eq( + json.err().unwrap(), + JsonPayloadError::Payload(PayloadError::Overflow) + )); + + let mut req = TestResponse::default() + .insert_header(( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + )) + .insert_header(( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + )) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .finish(); + + let json = JsonBody::<_, MyObject>::new(&mut req).await; + assert_eq!( + json.ok().unwrap(), + MyObject { + name: "test".to_owned() + } + ); + } +} diff --git a/awc/src/responses/mod.rs b/awc/src/responses/mod.rs new file mode 100644 index 000000000..82d8ada5c --- /dev/null +++ b/awc/src/responses/mod.rs @@ -0,0 +1,262 @@ +use std::{ + cell::{Ref, RefMut}, + fmt, + future::Future, + io, mem, + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use actix_http::{ + error::PayloadError, header, header::HeaderMap, BoxedPayloadStream, Extensions, + HttpMessage, Payload, ResponseHead, StatusCode, Version, +}; +use actix_rt::time::{sleep, Sleep}; +use bytes::Bytes; +use futures_core::Stream; +use serde::de::DeserializeOwned; + +#[cfg(feature = "cookies")] +use crate::cookie::{Cookie, ParseError as CookieParseError}; + +mod json_body; +mod read_body; +mod response_body; + +pub use self::json_body::JsonBody; +pub use self::response_body::ResponseBody; + +/// Client Response +pub struct ClientResponse { + pub(crate) head: ResponseHead, + pub(crate) payload: Payload, + pub(crate) timeout: ResponseTimeout, +} + +/// helper enum with reusable sleep passed from `SendClientResponse`. +/// See `ClientResponse::_timeout` for reason. +pub(crate) enum ResponseTimeout { + Disabled(Option>>), + Enabled(Pin>), +} + +impl Default for ResponseTimeout { + fn default() -> Self { + Self::Disabled(None) + } +} + +impl ResponseTimeout { + fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> { + match *self { + Self::Enabled(ref mut timeout) => { + if timeout.as_mut().poll(cx).is_ready() { + Err(PayloadError::Io(io::Error::new( + io::ErrorKind::TimedOut, + "Response Payload IO timed out", + ))) + } else { + Ok(()) + } + } + Self::Disabled(_) => Ok(()), + } + } +} + +impl HttpMessage for ClientResponse { + type Stream = S; + + fn headers(&self) -> &HeaderMap { + &self.head.headers + } + + fn take_payload(&mut self) -> Payload { + mem::replace(&mut self.payload, Payload::None) + } + + fn extensions(&self) -> Ref<'_, Extensions> { + self.head.extensions() + } + + fn extensions_mut(&self) -> RefMut<'_, Extensions> { + self.head.extensions_mut() + } +} + +impl ClientResponse { + /// Create new Request instance + pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self { + ClientResponse { + head, + payload, + timeout: ResponseTimeout::default(), + } + } + + #[inline] + pub(crate) fn head(&self) -> &ResponseHead { + &self.head + } + + /// Read the Request Version. + #[inline] + pub fn version(&self) -> Version { + self.head().version + } + + /// Get the status from the server. + #[inline] + pub fn status(&self) -> StatusCode { + self.head().status + } + + #[inline] + /// Returns request's headers. + pub fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + /// Set a body and return previous body value + pub fn map_body(mut self, f: F) -> ClientResponse + where + F: FnOnce(&mut ResponseHead, Payload) -> Payload, + { + let payload = f(&mut self.head, self.payload); + + ClientResponse { + payload, + head: self.head, + timeout: self.timeout, + } + } + + /// Set a timeout duration for [`ClientResponse`](self::ClientResponse). + /// + /// This duration covers the duration of processing the response body stream + /// and would end it as timeout error when deadline met. + /// + /// Disabled by default. + pub fn timeout(self, dur: Duration) -> Self { + let timeout = match self.timeout { + ResponseTimeout::Disabled(Some(mut timeout)) + | ResponseTimeout::Enabled(mut timeout) => match Instant::now().checked_add(dur) { + Some(deadline) => { + timeout.as_mut().reset(deadline.into()); + ResponseTimeout::Enabled(timeout) + } + None => ResponseTimeout::Enabled(Box::pin(sleep(dur))), + }, + _ => ResponseTimeout::Enabled(Box::pin(sleep(dur))), + }; + + Self { + payload: self.payload, + head: self.head, + timeout, + } + } + + /// This method does not enable timeout. It's used to pass the boxed `Sleep` from + /// `SendClientRequest` and reuse it's heap allocation together with it's slot in + /// timer wheel. + pub(crate) fn _timeout(mut self, timeout: Option>>) -> Self { + self.timeout = ResponseTimeout::Disabled(timeout); + self + } + + /// Load request cookies. + #[cfg(feature = "cookies")] + pub fn cookies(&self) -> Result>>, CookieParseError> { + struct Cookies(Vec>); + + if self.extensions().get::().is_none() { + let mut cookies = Vec::new(); + for hdr in self.headers().get_all(&header::SET_COOKIE) { + let s = std::str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?; + cookies.push(Cookie::parse_encoded(s)?.into_owned()); + } + self.extensions_mut().insert(Cookies(cookies)); + } + Ok(Ref::map(self.extensions(), |ext| { + &ext.get::().unwrap().0 + })) + } + + /// Return request cookie. + #[cfg(feature = "cookies")] + pub fn cookie(&self, name: &str) -> Option> { + if let Ok(cookies) = self.cookies() { + for cookie in cookies.iter() { + if cookie.name() == name { + return Some(cookie.to_owned()); + } + } + } + None + } +} + +impl ClientResponse +where + S: Stream>, +{ + /// Loads HTTP response's body. + /// Consumes body stream and parses JSON, resolving to a deserialized `T` value. + /// + /// # Errors + /// `Future` implementation returns error if: + /// - content type is not `application/json` + /// - content length is greater than [limit](JsonBody::limit) (default: 2 MiB) + pub fn body(&mut self) -> ResponseBody { + ResponseBody::new(self) + } + + /// Consumes body stream and parses JSON, resolving to a deserialized `T` value. + /// + /// # Errors + /// Future returns error if: + /// - content type is not `application/json`; + /// - content length is greater than [limit](JsonBody::limit) (default: 2 MiB). + pub fn json(&mut self) -> JsonBody { + JsonBody::new(self) + } +} + +impl Stream for ClientResponse +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + this.timeout.poll_timeout(cx)?; + + Pin::new(&mut this.payload).poll_next(cx) + } +} + +impl fmt::Debug for ClientResponse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?; + writeln!(f, " headers:")?; + for (key, val) in self.headers().iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + +/// Default body size limit: 2MiB +const DEFAULT_BODY_LIMIT: usize = 2 * 1024 * 1024; + +#[cfg(test)] +mod tests { + use static_assertions::assert_impl_all; + + use super::*; + + assert_impl_all!(ClientResponse: Unpin); +} diff --git a/awc/src/responses/read_body.rs b/awc/src/responses/read_body.rs new file mode 100644 index 000000000..a32bbb984 --- /dev/null +++ b/awc/src/responses/read_body.rs @@ -0,0 +1,61 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_http::{error::PayloadError, Payload}; +use bytes::{Bytes, BytesMut}; +use futures_core::{ready, Stream}; +use pin_project_lite::pin_project; + +pin_project! { + pub(crate) struct ReadBody { + #[pin] + pub(crate) stream: Payload, + pub(crate) buf: BytesMut, + pub(crate) limit: usize, + } +} + +impl ReadBody { + pub(crate) fn new(stream: Payload, limit: usize) -> Self { + Self { + stream, + buf: BytesMut::new(), + limit, + } + } +} + +impl Future for ReadBody +where + S: Stream>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + while let Some(chunk) = ready!(this.stream.as_mut().poll_next(cx)?) { + if (this.buf.len() + chunk.len()) > *this.limit { + return Poll::Ready(Err(PayloadError::Overflow)); + } + + this.buf.extend_from_slice(&chunk); + } + + Poll::Ready(Ok(this.buf.split().freeze())) + } +} + +#[cfg(test)] +mod tests { + use static_assertions::assert_impl_all; + + use super::*; + use crate::any_body::AnyBody; + + assert_impl_all!(ReadBody<()>: Unpin); + assert_impl_all!(ReadBody: Unpin); +} diff --git a/awc/src/responses/response_body.rs b/awc/src/responses/response_body.rs new file mode 100644 index 000000000..e5437760e --- /dev/null +++ b/awc/src/responses/response_body.rs @@ -0,0 +1,128 @@ +use std::{ + future::Future, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_http::{error::PayloadError, header, HttpMessage}; +use bytes::Bytes; +use futures_core::Stream; + +use super::{read_body::ReadBody, ResponseTimeout, DEFAULT_BODY_LIMIT}; +use crate::ClientResponse; + +/// Future that resolves to a complete response body. +pub struct ResponseBody { + body: Result, Option>, + length: Option, + timeout: ResponseTimeout, +} + +impl ResponseBody +where + S: Stream>, +{ + /// Create `MessageBody` for request. + pub fn new(res: &mut ClientResponse) -> ResponseBody { + let length = match res.headers().get(&header::CONTENT_LENGTH) { + Some(value) => { + let len = value.to_str().ok().and_then(|s| s.parse::().ok()); + + match len { + None => return Self::err(PayloadError::UnknownLength), + len => len, + } + } + None => None, + }; + + ResponseBody { + length, + timeout: mem::take(&mut res.timeout), + body: Ok(ReadBody::new(res.take_payload(), DEFAULT_BODY_LIMIT)), + } + } + + /// Change max size limit of payload. + /// + /// The default limit is 2 MiB. + pub fn limit(mut self, limit: usize) -> Self { + if let Ok(ref mut body) = self.body { + body.limit = limit; + } + self + } + + fn err(e: PayloadError) -> Self { + ResponseBody { + length: None, + timeout: ResponseTimeout::default(), + body: Err(Some(e)), + } + } +} + +impl Future for ResponseBody +where + S: Stream> + Unpin, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + match this.body { + Err(ref mut err) => Poll::Ready(Err(err.take().unwrap())), + Ok(ref mut body) => { + if let Some(len) = this.length.take() { + if len > body.limit { + return Poll::Ready(Err(PayloadError::Overflow)); + } + } + + this.timeout.poll_timeout(cx)?; + + Pin::new(body).poll(cx) + } + } + } +} + +#[cfg(test)] +mod tests { + use static_assertions::assert_impl_all; + + use super::*; + use crate::{http::header, test::TestResponse}; + + assert_impl_all!(ResponseBody<()>: Unpin); + + #[actix_rt::test] + async fn test_body() { + let mut req = TestResponse::with_header((header::CONTENT_LENGTH, "xxxx")).finish(); + match req.body().await.err().unwrap() { + PayloadError::UnknownLength => {} + _ => unreachable!("error"), + } + + let mut req = TestResponse::with_header((header::CONTENT_LENGTH, "10000000")).finish(); + match req.body().await.err().unwrap() { + PayloadError::Overflow => {} + _ => unreachable!("error"), + } + + let mut req = TestResponse::default() + .set_payload(Bytes::from_static(b"test")) + .finish(); + assert_eq!(req.body().await.ok().unwrap(), Bytes::from_static(b"test")); + + let mut req = TestResponse::default() + .set_payload(Bytes::from_static(b"11111111111111")) + .finish(); + match req.body().limit(5).await.err().unwrap() { + PayloadError::Overflow => {} + _ => unreachable!("error"), + } + } +} diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 06d54aadb..c63e22969 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -42,8 +42,7 @@ use crate::{ header::{self, HeaderName, HeaderValue, TryIntoHeaderValue, AUTHORIZATION}, ConnectionType, Method, StatusCode, Uri, Version, }, - response::ClientResponse, - ClientConfig, + ClientConfig, ClientResponse, }; #[cfg(feature = "cookies")]