fix the return type in poll_timeout. return payload error directly

This commit is contained in:
fakeshadow 2021-01-26 02:10:38 -08:00
parent 7d2f63eba2
commit 000e80aafe
2 changed files with 32 additions and 38 deletions

View File

@ -34,30 +34,32 @@ pub(crate) enum ResponseTimeout {
Enabled(Option<Pin<Box<Sleep>>>), Enabled(Option<Pin<Box<Sleep>>>),
} }
impl Default for ResponseTimeout {
fn default() -> Self {
Self::Disabled(None)
}
}
impl ResponseTimeout { impl ResponseTimeout {
fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), io::Error> { fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> {
match *self { match *self {
Self::Disabled(_) => Ok(()), Self::Disabled(_) => Ok(()),
Self::Enabled(Some(ref mut timeout)) => { Self::Enabled(Some(ref mut timeout)) => {
if timeout.as_mut().poll(cx).is_ready() { if timeout.as_mut().poll(cx).is_ready() {
Ok(())
} else {
Err(Self::err()) Err(Self::err())
} else {
Ok(())
} }
} }
Self::Enabled(None) => Err(Self::err()), Self::Enabled(None) => Err(Self::err()),
} }
} }
fn take(&mut self) -> Option<Pin<Box<Sleep>>> { fn err() -> PayloadError {
match *self { PayloadError::Io(io::Error::new(
Self::Disabled(_) => None, io::ErrorKind::TimedOut,
Self::Enabled(ref mut timeout) => timeout.take(), "Response Payload IO timed out",
} ))
}
fn err() -> io::Error {
io::Error::new(io::ErrorKind::TimedOut, "Response Payload IO timed out")
} }
} }
@ -106,7 +108,7 @@ impl<S> ClientResponse<S> {
ClientResponse { ClientResponse {
head, head,
payload, payload,
timeout: ResponseTimeout::Disabled(None), timeout: ResponseTimeout::default(),
} }
} }
@ -216,7 +218,7 @@ where
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> { ) -> Poll<Option<Self::Item>> {
let this = self.get_mut(); let this = self.get_mut();
this.timeout.poll_timeout(cx).map_err(PayloadError::Io)?; this.timeout.poll_timeout(cx)?;
Pin::new(&mut this.payload).poll_next(cx) Pin::new(&mut this.payload).poll_next(cx)
} }
@ -237,7 +239,7 @@ impl<S> fmt::Debug for ClientResponse<S> {
pub struct MessageBody<S> { pub struct MessageBody<S> {
length: Option<usize>, length: Option<usize>,
err: Option<PayloadError>, err: Option<PayloadError>,
timeout: Option<Pin<Box<Sleep>>>, timeout: ResponseTimeout,
fut: Option<ReadBody<S>>, fut: Option<ReadBody<S>>,
} }
@ -263,7 +265,7 @@ where
MessageBody { MessageBody {
length: len, length: len,
err: None, err: None,
timeout: res.timeout.take(), timeout: std::mem::take(&mut res.timeout),
fut: Some(ReadBody::new(res.take_payload(), 262_144)), fut: Some(ReadBody::new(res.take_payload(), 262_144)),
} }
} }
@ -281,7 +283,7 @@ where
fut: None, fut: None,
err: Some(e), err: Some(e),
length: None, length: None,
timeout: None, timeout: ResponseTimeout::default(),
} }
} }
} }
@ -305,11 +307,7 @@ where
} }
} }
if let Some(ref mut timeout) = this.timeout { this.timeout.poll_timeout(cx)?;
if timeout.as_mut().poll(cx).is_ready() {
return Poll::Ready(Err(PayloadError::Io(ResponseTimeout::err())));
}
}
Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx)
} }
@ -324,7 +322,7 @@ where
pub struct JsonBody<S, U> { pub struct JsonBody<S, U> {
length: Option<usize>, length: Option<usize>,
err: Option<JsonPayloadError>, err: Option<JsonPayloadError>,
timeout: Option<Pin<Box<Sleep>>>, timeout: ResponseTimeout,
fut: Option<ReadBody<S>>, fut: Option<ReadBody<S>>,
_phantom: PhantomData<U>, _phantom: PhantomData<U>,
} }
@ -335,9 +333,9 @@ where
U: DeserializeOwned, U: DeserializeOwned,
{ {
/// Create `JsonBody` for request. /// Create `JsonBody` for request.
pub fn new(req: &mut ClientResponse<S>) -> Self { pub fn new(res: &mut ClientResponse<S>) -> Self {
// check content-type // check content-type
let json = if let Ok(Some(mime)) = req.mime_type() { let json = if let Ok(Some(mime)) = res.mime_type() {
mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
} else { } else {
false false
@ -346,14 +344,14 @@ where
return JsonBody { return JsonBody {
length: None, length: None,
fut: None, fut: None,
timeout: None, timeout: ResponseTimeout::default(),
err: Some(JsonPayloadError::ContentType), err: Some(JsonPayloadError::ContentType),
_phantom: PhantomData, _phantom: PhantomData,
}; };
} }
let mut len = None; let mut len = None;
if let Some(l) = req.headers().get(&CONTENT_LENGTH) { if let Some(l) = res.headers().get(&CONTENT_LENGTH) {
if let Ok(s) = l.to_str() { if let Ok(s) = l.to_str() {
if let Ok(l) = s.parse::<usize>() { if let Ok(l) = s.parse::<usize>() {
len = Some(l) len = Some(l)
@ -364,8 +362,8 @@ where
JsonBody { JsonBody {
length: len, length: len,
err: None, err: None,
timeout: req.timeout.take(), timeout: std::mem::take(&mut res.timeout),
fut: Some(ReadBody::new(req.take_payload(), 65536)), fut: Some(ReadBody::new(res.take_payload(), 65536)),
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
@ -406,13 +404,9 @@ where
} }
} }
if let Some(ref mut timeout) = self.timeout { self.timeout
if timeout.as_mut().poll(cx).is_ready() { .poll_timeout(cx)
return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Io( .map_err(JsonPayloadError::Payload)?;
ResponseTimeout::err(),
))));
}
}
let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?;
Poll::Ready(serde_json::from_slice::<U>(&body).map_err(JsonPayloadError::from)) Poll::Ready(serde_json::from_slice::<U>(&body).map_err(JsonPayloadError::from))

View File

@ -795,9 +795,9 @@ async fn test_client_cookie_handling() {
async fn client_unread_response() { async fn client_unread_response() {
let addr = test::unused_addr(); let addr = test::unused_addr();
std::thread::spawn(move || {
let lst = std::net::TcpListener::bind(addr).unwrap(); let lst = std::net::TcpListener::bind(addr).unwrap();
std::thread::spawn(move || {
for stream in lst.incoming() { for stream in lst.incoming() {
let mut stream = stream.unwrap(); let mut stream = stream.unwrap();
let mut b = [0; 1000]; let mut b = [0; 1000];