diff --git a/awc/src/response.rs b/awc/src/response.rs index 3d78b84fb..0cbfd73d5 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -16,6 +16,7 @@ use actix_http::{ use actix_rt::time::{sleep, Sleep}; use bytes::{Bytes, BytesMut}; use futures_core::{ready, Stream}; +use pin_project_lite::pin_project; use serde::de::DeserializeOwned; #[cfg(feature = "cookies")] @@ -316,27 +317,30 @@ where } } -/// Response's payload json parser, it resolves to a deserialized `T` value. -/// -/// Returns error: -/// -/// * content type is not `application/json` -/// * content length is greater than 64k -pub struct JsonBody { - length: Option, - err: Option, - timeout: ResponseTimeout, - fut: Option>, - _phantom: PhantomData, +pin_project! { + /// Response's payload json parser, it resolves to a deserialized `T` value. + /// + /// # Errors + /// `Future` implementation returns error if: + /// * content type is not `application/json` + /// * content length is greater than 64k + pub struct JsonBody { + #[pin] + fut: Option>, + length: Option, + timeout: ResponseTimeout, + err: Option, + _phantom: PhantomData, + } } -impl JsonBody +impl JsonBody where - S: Stream>, - U: DeserializeOwned, + B: Stream>, + T: DeserializeOwned, { /// Create `JsonBody` for request. - pub fn new(res: &mut ClientResponse) -> Self { + pub fn new(res: &mut ClientResponse) -> Self { // check content-type let json = if let Ok(Some(mime)) = res.mime_type() { mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) @@ -364,10 +368,10 @@ where } JsonBody { - length: len, - err: None, - timeout: std::mem::take(&mut res.timeout), fut: Some(ReadBody::new(res.take_payload(), 65536)), + length: len, + timeout: std::mem::take(&mut res.timeout), + err: None, _phantom: PhantomData, } } @@ -381,41 +385,37 @@ where } } -impl Unpin for JsonBody +impl Future for JsonBody where - T: Stream> + Unpin, - U: DeserializeOwned, + B: Stream>, + T: DeserializeOwned, { -} + type Output = Result; -impl Future for JsonBody -where - T: Stream> + Unpin, - U: DeserializeOwned, -{ - type Output = Result; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if let Some(err) = self.err.take() { + if let Some(err) = this.err.take() { return Poll::Ready(Err(err)); } - if let Some(len) = self.length.take() { - if len > self.fut.as_ref().unwrap().limit { + if let Some(len) = this.length.take() { + let body = Option::as_ref(&this.fut).unwrap(); + if len > body.limit { return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Overflow))); } } - self.timeout + this.timeout .poll_timeout(cx) .map_err(JsonPayloadError::Payload)?; - let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; - Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) + let body = ready!(this.fut.as_pin_mut().unwrap().poll(cx))?; + Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) } } -pin_project_lite::pin_project! { +pin_project! { struct ReadBody { #[pin] stream: Payload, @@ -447,6 +447,7 @@ where if (this.buf.len() + chunk.len()) > *this.limit { return Poll::Ready(Err(PayloadError::Overflow)); } + this.buf.extend_from_slice(&chunk); } @@ -460,7 +461,12 @@ mod tests { use static_assertions::assert_impl_all; use super::*; - use crate::{http::header, test::TestResponse}; + use crate::{any_body::AnyBody, http::header, test::TestResponse}; + + assert_impl_all!(ReadBody<()>: Unpin); + assert_impl_all!(ReadBody: Unpin); + + assert_impl_all!(JsonBody: Unpin); assert_impl_all!(ClientResponse: Unpin);