From 000e80aafebaa06e538a547dbf04a2343ca83bb5 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Tue, 26 Jan 2021 02:10:38 -0800 Subject: [PATCH] fix the return type in poll_timeout. return payload error directly --- awc/src/response.rs | 66 ++++++++++++++++++---------------------- awc/tests/test_client.rs | 4 +-- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/awc/src/response.rs b/awc/src/response.rs index 847b908ae..9762f770b 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -34,30 +34,32 @@ pub(crate) enum ResponseTimeout { Enabled(Option>>), } +impl Default for ResponseTimeout { + fn default() -> Self { + Self::Disabled(None) + } +} + impl ResponseTimeout { - fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), io::Error> { + fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> { match *self { Self::Disabled(_) => Ok(()), Self::Enabled(Some(ref mut timeout)) => { if timeout.as_mut().poll(cx).is_ready() { - Ok(()) - } else { Err(Self::err()) + } else { + Ok(()) } } Self::Enabled(None) => Err(Self::err()), } } - fn take(&mut self) -> Option>> { - match *self { - Self::Disabled(_) => None, - Self::Enabled(ref mut timeout) => timeout.take(), - } - } - - fn err() -> io::Error { - io::Error::new(io::ErrorKind::TimedOut, "Response Payload IO timed out") + fn err() -> PayloadError { + PayloadError::Io(io::Error::new( + io::ErrorKind::TimedOut, + "Response Payload IO timed out", + )) } } @@ -106,7 +108,7 @@ impl ClientResponse { ClientResponse { head, payload, - timeout: ResponseTimeout::Disabled(None), + timeout: ResponseTimeout::default(), } } @@ -216,7 +218,7 @@ where cx: &mut Context<'_>, ) -> Poll> { let this = self.get_mut(); - this.timeout.poll_timeout(cx).map_err(PayloadError::Io)?; + this.timeout.poll_timeout(cx)?; Pin::new(&mut this.payload).poll_next(cx) } @@ -237,7 +239,7 @@ impl fmt::Debug for ClientResponse { pub struct MessageBody { length: Option, err: Option, - timeout: Option>>, + timeout: ResponseTimeout, fut: Option>, } @@ -263,7 +265,7 @@ where MessageBody { length: len, err: None, - timeout: res.timeout.take(), + timeout: std::mem::take(&mut res.timeout), fut: Some(ReadBody::new(res.take_payload(), 262_144)), } } @@ -281,7 +283,7 @@ where fut: None, err: Some(e), length: None, - timeout: None, + timeout: ResponseTimeout::default(), } } } @@ -305,11 +307,7 @@ where } } - if let Some(ref mut timeout) = this.timeout { - if timeout.as_mut().poll(cx).is_ready() { - return Poll::Ready(Err(PayloadError::Io(ResponseTimeout::err()))); - } - } + this.timeout.poll_timeout(cx)?; Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) } @@ -324,7 +322,7 @@ where pub struct JsonBody { length: Option, err: Option, - timeout: Option>>, + timeout: ResponseTimeout, fut: Option>, _phantom: PhantomData, } @@ -335,9 +333,9 @@ where U: DeserializeOwned, { /// Create `JsonBody` for request. - pub fn new(req: &mut ClientResponse) -> Self { + pub fn new(res: &mut ClientResponse) -> Self { // check content-type - let json = if let Ok(Some(mime)) = req.mime_type() { + let json = if let Ok(Some(mime)) = res.mime_type() { mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) } else { false @@ -346,14 +344,14 @@ where return JsonBody { length: None, fut: None, - timeout: None, + timeout: ResponseTimeout::default(), err: Some(JsonPayloadError::ContentType), _phantom: PhantomData, }; } let mut len = None; - if let Some(l) = req.headers().get(&CONTENT_LENGTH) { + if let Some(l) = res.headers().get(&CONTENT_LENGTH) { if let Ok(s) = l.to_str() { if let Ok(l) = s.parse::() { len = Some(l) @@ -364,8 +362,8 @@ where JsonBody { length: len, err: None, - timeout: req.timeout.take(), - fut: Some(ReadBody::new(req.take_payload(), 65536)), + timeout: std::mem::take(&mut res.timeout), + fut: Some(ReadBody::new(res.take_payload(), 65536)), _phantom: PhantomData, } } @@ -406,13 +404,9 @@ where } } - if let Some(ref mut timeout) = self.timeout { - if timeout.as_mut().poll(cx).is_ready() { - return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Io( - ResponseTimeout::err(), - )))); - } - } + self.timeout + .poll_timeout(cx) + .map_err(JsonPayloadError::Payload)?; let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs index 71a025c95..82ef7fe92 100644 --- a/awc/tests/test_client.rs +++ b/awc/tests/test_client.rs @@ -795,9 +795,9 @@ async fn test_client_cookie_handling() { async fn client_unread_response() { let addr = test::unused_addr(); - std::thread::spawn(move || { - let lst = std::net::TcpListener::bind(addr).unwrap(); + let lst = std::net::TcpListener::bind(addr).unwrap(); + std::thread::spawn(move || { for stream in lst.incoming() { let mut stream = stream.unwrap(); let mut b = [0; 1000];