merge master

This commit is contained in:
fakeshadow 2021-02-17 04:05:35 -08:00
commit 837f012b7e
6 changed files with 230 additions and 76 deletions

View File

@ -13,6 +13,7 @@ use actix_rt::time::{sleep_until, Instant, Sleep};
use actix_service::Service; use actix_service::Service;
use bitflags::bitflags; use bitflags::bitflags;
use bytes::{Buf, BytesMut}; use bytes::{Buf, BytesMut};
use futures_core::ready;
use log::{error, trace}; use log::{error, trace};
use pin_project::pin_project; use pin_project::pin_project;
@ -233,14 +234,10 @@ where
} }
} }
/// Flush stream
///
/// true - got WouldBlock
/// false - didn't get WouldBlock
fn poll_flush( fn poll_flush(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Result<bool, DispatchError> { ) -> Poll<Result<(), io::Error>> {
let InnerDispatcherProj { io, write_buf, .. } = self.project(); let InnerDispatcherProj { io, write_buf, .. } = self.project();
let mut io = Pin::new(io.as_mut().unwrap()); let mut io = Pin::new(io.as_mut().unwrap());
@ -248,19 +245,18 @@ where
let mut written = 0; let mut written = 0;
while written < len { while written < len {
match io.as_mut().poll_write(cx, &write_buf[written..]) { match io.as_mut().poll_write(cx, &write_buf[written..])? {
Poll::Ready(Ok(0)) => { Poll::Ready(0) => {
return Err(DispatchError::Io(io::Error::new( return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero, io::ErrorKind::WriteZero,
"", "",
))) )))
} }
Poll::Ready(Ok(n)) => written += n, Poll::Ready(n) => written += n,
Poll::Pending => { Poll::Pending => {
write_buf.advance(written); write_buf.advance(written);
return Ok(true); return Poll::Pending;
} }
Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)),
} }
} }
@ -268,9 +264,7 @@ where
write_buf.clear(); write_buf.clear();
// flush the io and check if get blocked. // flush the io and check if get blocked.
let blocked = io.poll_flush(cx)?.is_pending(); io.poll_flush(cx)
Ok(blocked)
} }
fn send_response( fn send_response(
@ -841,15 +835,12 @@ where
if inner.flags.contains(Flags::WRITE_DISCONNECT) { if inner.flags.contains(Flags::WRITE_DISCONNECT) {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} else { } else {
// flush buffer and wait on block. // flush buffer and wait on blocked.
if inner.as_mut().poll_flush(cx)? { ready!(inner.as_mut().poll_flush(cx))?;
Poll::Pending
} else {
Pin::new(inner.project().io.as_mut().unwrap()) Pin::new(inner.project().io.as_mut().unwrap())
.poll_shutdown(cx) .poll_shutdown(cx)
.map_err(DispatchError::from) .map_err(DispatchError::from)
} }
}
} else { } else {
// read from io stream and fill read buffer. // read from io stream and fill read buffer.
let should_disconnect = inner.as_mut().read_available(cx)?; let should_disconnect = inner.as_mut().read_available(cx)?;
@ -888,7 +879,7 @@ where
// //
// TODO: what? is WouldBlock good or bad? // TODO: what? is WouldBlock good or bad?
// want to find a reference for this macOS behavior // want to find a reference for this macOS behavior
if inner.as_mut().poll_flush(cx)? || !drain { if inner.as_mut().poll_flush(cx)?.is_pending() || !drain {
break; break;
} }
} }

View File

@ -1,9 +1,13 @@
# Changes # Changes
## Unreleased - 2021-xx-xx ## Unreleased - 2021-xx-xx
### Added
* `ClientResponse::timeout` for set the timeout of collecting response body. [#1931]
### Changed ### Changed
* Feature `cookies` is now optional and enabled by default. [#1981] * Feature `cookies` is now optional and enabled by default. [#1981]
[#1931]: https://github.com/actix/actix-web/pull/1931
[#1981]: https://github.com/actix/actix-web/pull/1981 [#1981]: https://github.com/actix/actix-web/pull/1981

View File

@ -41,14 +41,14 @@ pub enum ConnectRequest {
} }
pub enum ConnectResponse { pub enum ConnectResponse {
ClientResponse(ClientResponse), Client(ClientResponse),
TunnelResponse(ResponseHead, Framed<BoxedSocket, ClientCodec>), Tunnel(ResponseHead, Framed<BoxedSocket, ClientCodec>),
} }
impl ConnectResponse { impl ConnectResponse {
pub fn into_client_response(self) -> ClientResponse { pub fn into_client_response(self) -> ClientResponse {
match self { match self {
ConnectResponse::ClientResponse(res) => res, ConnectResponse::Client(res) => res,
_ => panic!( _ => panic!(
"ClientResponse only reachable with ConnectResponse::ClientResponse variant" "ClientResponse only reachable with ConnectResponse::ClientResponse variant"
), ),
@ -57,7 +57,7 @@ impl ConnectResponse {
pub fn into_tunnel_response(self) -> (ResponseHead, Framed<BoxedSocket, ClientCodec>) { pub fn into_tunnel_response(self) -> (ResponseHead, Framed<BoxedSocket, ClientCodec>) {
match self { match self {
ConnectResponse::TunnelResponse(head, framed) => (head, framed), ConnectResponse::Tunnel(head, framed) => (head, framed),
_ => panic!( _ => panic!(
"TunnelResponse only reachable with ConnectResponse::TunnelResponse variant" "TunnelResponse only reachable with ConnectResponse::TunnelResponse variant"
), ),
@ -99,9 +99,7 @@ where
// send request // send request
let (head, payload) = connection.send_request(head, body).await?; let (head, payload) = connection.send_request(head, body).await?;
Ok(ConnectResponse::ClientResponse(ClientResponse::new( Ok(ConnectResponse::Client(ClientResponse::new(head, payload)))
head, payload,
)))
} }
ConnectRequest::Tunnel(head, ..) => { ConnectRequest::Tunnel(head, ..) => {
// send request // send request
@ -109,7 +107,7 @@ where
connection.open_tunnel(RequestHeadType::from(head)).await?; connection.open_tunnel(RequestHeadType::from(head)).await?;
let framed = framed.into_map_io(|io| BoxedSocket(Box::new(Socket(io)))); let framed = framed.into_map_io(|io| BoxedSocket(Box::new(Socket(io))));
Ok(ConnectResponse::TunnelResponse(head, framed)) Ok(ConnectResponse::Tunnel(head, framed))
} }
} }
}) })

View File

@ -1,20 +1,22 @@
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{ use std::{
cell::{Ref, RefMut}, cell::{Ref, RefMut},
mem, fmt,
future::Future,
io,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
time::{Duration, Instant},
}; };
use actix_http::{
error::PayloadError,
http::{header, HeaderMap, StatusCode, Version},
Extensions, HttpMessage, Payload, PayloadStream, ResponseHead,
};
use actix_rt::time::{sleep, Sleep};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures_core::{ready, Stream}; use futures_core::{ready, Stream};
use actix_http::error::PayloadError;
use actix_http::http::header;
use actix_http::http::{HeaderMap, StatusCode, Version};
use actix_http::{Extensions, HttpMessage, Payload, PayloadStream, ResponseHead};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
@ -26,6 +28,38 @@ use crate::error::JsonPayloadError;
pub struct ClientResponse<S = PayloadStream> { pub struct ClientResponse<S = PayloadStream> {
pub(crate) head: ResponseHead, pub(crate) head: ResponseHead,
pub(crate) payload: Payload<S>, pub(crate) payload: Payload<S>,
pub(crate) timeout: ResponseTimeout,
}
/// helper enum with reusable sleep passed from `SendClientResponse`.
/// See `ClientResponse::_timeout` for reason.
pub(crate) enum ResponseTimeout {
Disabled(Option<Pin<Box<Sleep>>>),
Enabled(Pin<Box<Sleep>>),
}
impl Default for ResponseTimeout {
fn default() -> Self {
Self::Disabled(None)
}
}
impl ResponseTimeout {
fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> {
match *self {
Self::Enabled(ref mut timeout) => {
if timeout.as_mut().poll(cx).is_ready() {
Err(PayloadError::Io(io::Error::new(
io::ErrorKind::TimedOut,
"Response Payload IO timed out",
)))
} else {
Ok(())
}
}
Self::Disabled(_) => Ok(()),
}
}
} }
impl<S> HttpMessage for ClientResponse<S> { impl<S> HttpMessage for ClientResponse<S> {
@ -35,6 +69,10 @@ impl<S> HttpMessage for ClientResponse<S> {
&self.head.headers &self.head.headers
} }
fn take_payload(&mut self) -> Payload<S> {
std::mem::replace(&mut self.payload, Payload::None)
}
fn extensions(&self) -> Ref<'_, Extensions> { fn extensions(&self) -> Ref<'_, Extensions> {
self.head.extensions() self.head.extensions()
} }
@ -43,10 +81,6 @@ impl<S> HttpMessage for ClientResponse<S> {
self.head.extensions_mut() self.head.extensions_mut()
} }
fn take_payload(&mut self) -> Payload<S> {
mem::replace(&mut self.payload, Payload::None)
}
/// Load request cookies. /// Load request cookies.
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
fn cookies(&self) -> Result<Ref<'_, Vec<Cookie<'static>>>, CookieParseError> { fn cookies(&self) -> Result<Ref<'_, Vec<Cookie<'static>>>, CookieParseError> {
@ -69,7 +103,11 @@ impl<S> HttpMessage for ClientResponse<S> {
impl<S> ClientResponse<S> { impl<S> ClientResponse<S> {
/// Create new Request instance /// Create new Request instance
pub(crate) fn new(head: ResponseHead, payload: Payload<S>) -> Self { pub(crate) fn new(head: ResponseHead, payload: Payload<S>) -> Self {
ClientResponse { head, payload } ClientResponse {
head,
payload,
timeout: ResponseTimeout::default(),
}
} }
#[inline] #[inline]
@ -105,8 +143,43 @@ impl<S> ClientResponse<S> {
ClientResponse { ClientResponse {
payload, payload,
head: self.head, head: self.head,
timeout: self.timeout,
} }
} }
/// Set a timeout duration for [`ClientResponse`](self::ClientResponse).
///
/// This duration covers the duration of processing the response body stream
/// and would end it as timeout error when deadline met.
///
/// Disabled by default.
pub fn timeout(self, dur: Duration) -> Self {
let timeout = match self.timeout {
ResponseTimeout::Disabled(Some(mut timeout))
| ResponseTimeout::Enabled(mut timeout) => match Instant::now().checked_add(dur) {
Some(deadline) => {
timeout.as_mut().reset(deadline.into());
ResponseTimeout::Enabled(timeout)
}
None => ResponseTimeout::Enabled(Box::pin(sleep(dur))),
},
_ => ResponseTimeout::Enabled(Box::pin(sleep(dur))),
};
Self {
payload: self.payload,
head: self.head,
timeout,
}
}
/// This method does not enable timeout. It's used to pass the boxed `Sleep` from
/// `SendClientRequest` and reuse it's heap allocation together with it's slot in
/// timer wheel.
pub(crate) fn _timeout(mut self, timeout: Option<Pin<Box<Sleep>>>) -> Self {
self.timeout = ResponseTimeout::Disabled(timeout);
self
}
} }
impl<S> ClientResponse<S> impl<S> ClientResponse<S>
@ -137,7 +210,10 @@ where
type Item = Result<Bytes, PayloadError>; type Item = Result<Bytes, PayloadError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.get_mut().payload).poll_next(cx) let this = self.get_mut();
this.timeout.poll_timeout(cx)?;
Pin::new(&mut this.payload).poll_next(cx)
} }
} }
@ -156,6 +232,7 @@ impl<S> fmt::Debug for ClientResponse<S> {
pub struct MessageBody<S> { pub struct MessageBody<S> {
length: Option<usize>, length: Option<usize>,
err: Option<PayloadError>, err: Option<PayloadError>,
timeout: ResponseTimeout,
fut: Option<ReadBody<S>>, fut: Option<ReadBody<S>>,
} }
@ -181,6 +258,7 @@ where
MessageBody { MessageBody {
length: len, length: len,
err: None, err: None,
timeout: std::mem::take(&mut res.timeout),
fut: Some(ReadBody::new(res.take_payload(), 262_144)), fut: Some(ReadBody::new(res.take_payload(), 262_144)),
} }
} }
@ -198,6 +276,7 @@ where
fut: None, fut: None,
err: Some(e), err: Some(e),
length: None, length: None,
timeout: ResponseTimeout::default(),
} }
} }
} }
@ -221,6 +300,8 @@ where
} }
} }
this.timeout.poll_timeout(cx)?;
Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx)
} }
} }
@ -234,6 +315,7 @@ where
pub struct JsonBody<S, U> { pub struct JsonBody<S, U> {
length: Option<usize>, length: Option<usize>,
err: Option<JsonPayloadError>, err: Option<JsonPayloadError>,
timeout: ResponseTimeout,
fut: Option<ReadBody<S>>, fut: Option<ReadBody<S>>,
_phantom: PhantomData<U>, _phantom: PhantomData<U>,
} }
@ -244,9 +326,9 @@ where
U: DeserializeOwned, U: DeserializeOwned,
{ {
/// Create `JsonBody` for request. /// Create `JsonBody` for request.
pub fn new(req: &mut ClientResponse<S>) -> Self { pub fn new(res: &mut ClientResponse<S>) -> Self {
// check content-type // check content-type
let json = if let Ok(Some(mime)) = req.mime_type() { let json = if let Ok(Some(mime)) = res.mime_type() {
mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
} else { } else {
false false
@ -255,13 +337,15 @@ where
return JsonBody { return JsonBody {
length: None, length: None,
fut: None, fut: None,
timeout: ResponseTimeout::default(),
err: Some(JsonPayloadError::ContentType), err: Some(JsonPayloadError::ContentType),
_phantom: PhantomData, _phantom: PhantomData,
}; };
} }
let mut len = None; let mut len = None;
if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) {
if let Some(l) = res.headers().get(&header::CONTENT_LENGTH) {
if let Ok(s) = l.to_str() { if let Ok(s) = l.to_str() {
if let Ok(l) = s.parse::<usize>() { if let Ok(l) = s.parse::<usize>() {
len = Some(l) len = Some(l)
@ -272,7 +356,8 @@ where
JsonBody { JsonBody {
length: len, length: len,
err: None, err: None,
fut: Some(ReadBody::new(req.take_payload(), 65536)), timeout: std::mem::take(&mut res.timeout),
fut: Some(ReadBody::new(res.take_payload(), 65536)),
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
@ -311,6 +396,10 @@ where
} }
} }
self.timeout
.poll_timeout(cx)
.map_err(JsonPayloadError::Payload)?;
let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?;
Poll::Ready(serde_json::from_slice::<U>(&body).map_err(JsonPayloadError::from)) Poll::Ready(serde_json::from_slice::<U>(&body).map_err(JsonPayloadError::from))
} }

View File

@ -22,11 +22,7 @@ use futures_core::Stream;
use serde::Serialize; use serde::Serialize;
#[cfg(feature = "compress")] #[cfg(feature = "compress")]
use actix_http::encoding::Decoder; use actix_http::{encoding::Decoder, http::header::ContentEncoding, Payload, PayloadStream};
#[cfg(feature = "compress")]
use actix_http::http::header::ContentEncoding;
#[cfg(feature = "compress")]
use actix_http::{Payload, PayloadStream};
use crate::connect::{ConnectRequest, ConnectResponse}; use crate::connect::{ConnectRequest, ConnectResponse};
use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError}; use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError};
@ -89,21 +85,25 @@ impl Future for SendClientRequest {
match this { match this {
SendClientRequest::Fut(send, delay, response_decompress) => { SendClientRequest::Fut(send, delay, response_decompress) => {
if delay.is_some() { if let Some(delay) = delay {
match Pin::new(delay.as_mut().unwrap()).poll(cx) { if delay.as_mut().poll(cx).is_ready() {
Poll::Pending => {} return Poll::Ready(Err(SendRequestError::Timeout));
_ => return Poll::Ready(Err(SendRequestError::Timeout)),
} }
} }
let res = futures_core::ready!(Pin::new(send).poll(cx)).map(|res| { let res = futures_core::ready!(send.as_mut().poll(cx)).map(|res| {
res.into_client_response().map_body(|head, payload| { res.into_client_response()._timeout(delay.take()).map_body(
|head, payload| {
if *response_decompress { if *response_decompress {
Payload::Stream(Decoder::from_headers(payload, &head.headers)) Payload::Stream(Decoder::from_headers(payload, &head.headers))
} else { } else {
Payload::Stream(Decoder::new(payload, ContentEncoding::Identity)) Payload::Stream(Decoder::new(
payload,
ContentEncoding::Identity,
))
} }
}) },
)
}); });
Poll::Ready(res) Poll::Ready(res)
@ -124,15 +124,14 @@ impl Future for SendClientRequest {
let this = self.get_mut(); let this = self.get_mut();
match this { match this {
SendClientRequest::Fut(send, delay, _) => { SendClientRequest::Fut(send, delay, _) => {
if delay.is_some() { if let Some(delay) = delay {
match Pin::new(delay.as_mut().unwrap()).poll(cx) { if delay.as_mut().poll(cx).is_ready() {
Poll::Pending => {} return Poll::Ready(Err(SendRequestError::Timeout));
_ => return Poll::Ready(Err(SendRequestError::Timeout)),
} }
} }
Pin::new(send) send.as_mut()
.poll(cx) .poll(cx)
.map_ok(|res| res.into_client_response()) .map_ok(|res| res.into_client_response()._timeout(delay.take()))
} }
SendClientRequest::Err(ref mut e) => match e.take() { SendClientRequest::Err(ref mut e) => match e.take() {
Some(e) => Poll::Ready(Err(e)), Some(e) => Poll::Ready(Err(e)),

View File

@ -24,7 +24,7 @@ use actix_web::{
middleware::Compress, middleware::Compress,
test, web, App, Error, HttpMessage, HttpRequest, HttpResponse, test, web, App, Error, HttpMessage, HttpRequest, HttpResponse,
}; };
use awc::error::SendRequestError; use awc::error::{JsonPayloadError, PayloadError, SendRequestError};
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
@ -157,6 +157,79 @@ async fn test_timeout_override() {
} }
} }
#[actix_rt::test]
async fn test_response_timeout() {
use futures_util::stream::{once, StreamExt};
let srv = test::start(|| {
App::new().service(web::resource("/").route(web::to(|| async {
Ok::<_, Error>(
HttpResponse::Ok()
.content_type("application/json")
.streaming(Box::pin(once(async {
actix_rt::time::sleep(Duration::from_millis(200)).await;
Ok::<_, Error>(Bytes::from(STR))
}))),
)
})))
});
let client = awc::Client::new();
let res = client
.get(srv.url("/"))
.send()
.await
.unwrap()
.timeout(Duration::from_millis(500))
.body()
.await
.unwrap();
assert_eq!(std::str::from_utf8(res.as_ref()).unwrap(), STR);
let res = client
.get(srv.url("/"))
.send()
.await
.unwrap()
.timeout(Duration::from_millis(100))
.next()
.await
.unwrap();
match res {
Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut),
_ => panic!("Response error type is not matched"),
}
let res = client
.get(srv.url("/"))
.send()
.await
.unwrap()
.timeout(Duration::from_millis(100))
.body()
.await;
match res {
Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut),
_ => panic!("Response error type is not matched"),
}
let res = client
.get(srv.url("/"))
.send()
.await
.unwrap()
.timeout(Duration::from_millis(100))
.json::<HashMap<String, String>>()
.await;
match res {
Err(JsonPayloadError::Payload(PayloadError::Io(e))) => {
assert_eq!(e.kind(), std::io::ErrorKind::TimedOut)
}
_ => panic!("Response error type is not matched"),
}
}
#[actix_rt::test] #[actix_rt::test]
async fn test_connection_reuse() { async fn test_connection_reuse() {
let num = Arc::new(AtomicUsize::new(0)); let num = Arc::new(AtomicUsize::new(0));