Merge branch 'master' into tokio0.3

This commit is contained in:
fakeshadow 2020-12-21 00:55:47 +08:00 committed by GitHub
commit b4a8c6e1f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 377 additions and 308 deletions

View File

@ -1,3 +1,5 @@
comment: false
coverage: coverage:
status: status:
project: project:

View File

@ -4,7 +4,8 @@ use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use actix_http::error::Error; use actix_http::error::Error;
use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures_util::future::{ready, Ready};
use futures_util::ready;
use crate::dev::Payload; use crate::dev::Payload;
use crate::request::HttpRequest; use crate::request::HttpRequest;
@ -95,21 +96,41 @@ where
T: FromRequest, T: FromRequest,
T::Future: 'static, T::Future: 'static,
{ {
type Config = T::Config;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Option<T>, Error>>; type Future = FromRequestOptFuture<T::Future>;
type Config = T::Config;
#[inline] #[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
T::from_request(req, payload) FromRequestOptFuture {
.then(|r| match r { fut: T::from_request(req, payload),
Ok(v) => ok(Some(v)), }
Err(e) => { }
log::debug!("Error for Option<T> extractor: {}", e.into()); }
ok(None)
} #[pin_project::pin_project]
}) pub struct FromRequestOptFuture<Fut> {
.boxed_local() #[pin]
fut: Fut,
}
impl<Fut, T, E> Future for FromRequestOptFuture<Fut>
where
Fut: Future<Output = Result<T, E>>,
E: Into<Error>,
{
type Output = Result<Option<T>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let res = ready!(this.fut.poll(cx));
match res {
Ok(t) => Poll::Ready(Ok(Some(t))),
Err(e) => {
log::debug!("Error for Option<T> extractor: {}", e.into());
Poll::Ready(Ok(None))
}
}
} }
} }
@ -165,29 +186,45 @@ where
T::Error: 'static, T::Error: 'static,
T::Future: 'static, T::Future: 'static,
{ {
type Config = T::Config;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Result<T, T::Error>, Error>>; type Future = FromRequestResFuture<T::Future>;
type Config = T::Config;
#[inline] #[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
T::from_request(req, payload) FromRequestResFuture {
.then(|res| match res { fut: T::from_request(req, payload),
Ok(v) => ok(Ok(v)), }
Err(e) => ok(Err(e)), }
}) }
.boxed_local()
#[pin_project::pin_project]
pub struct FromRequestResFuture<Fut> {
#[pin]
fut: Fut,
}
impl<Fut, T, E> Future for FromRequestResFuture<Fut>
where
Fut: Future<Output = Result<T, E>>,
{
type Output = Result<Result<T, E>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let res = ready!(this.fut.poll(cx));
Poll::Ready(Ok(res))
} }
} }
#[doc(hidden)] #[doc(hidden)]
impl FromRequest for () { impl FromRequest for () {
type Config = ();
type Error = Error; type Error = Error;
type Future = Ready<Result<(), Error>>; type Future = Ready<Result<(), Error>>;
type Config = ();
fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future { fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future {
ok(()) ready(Ok(()))
} }
} }

View File

@ -90,26 +90,20 @@ where
} }
fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future { fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future {
HandlerServiceResponse { let fut = self.hnd.call(param);
fut: self.hnd.call(param), HandlerServiceResponse::Future(fut, Some(req))
fut2: None,
req: Some(req),
}
} }
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project] #[pin_project(project = HandlerProj)]
pub struct HandlerServiceResponse<T, R> pub enum HandlerServiceResponse<T, R>
where where
T: Future<Output = R>, T: Future<Output = R>,
R: Responder, R: Responder,
{ {
#[pin] Future(#[pin] T, Option<HttpRequest>),
fut: T, Responder(#[pin] R::Future, Option<HttpRequest>),
#[pin]
fut2: Option<R::Future>,
req: Option<HttpRequest>,
} }
impl<T, R> Future for HandlerServiceResponse<T, R> impl<T, R> Future for HandlerServiceResponse<T, R>
@ -120,28 +114,26 @@ where
type Output = Result<ServiceResponse, Infallible>; type Output = Result<ServiceResponse, Infallible>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().project(); loop {
match self.as_mut().project() {
if let Some(fut) = this.fut2.as_pin_mut() { HandlerProj::Future(fut, req) => {
return match fut.poll(cx) { let res = ready!(fut.poll(cx));
Poll::Ready(Ok(res)) => { let fut = res.respond_to(req.as_ref().unwrap());
Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) let state = HandlerServiceResponse::Responder(fut, req.take());
self.as_mut().set(state);
} }
Poll::Pending => Poll::Pending, HandlerProj::Responder(fut, req) => {
Poll::Ready(Err(e)) => { let res = ready!(fut.poll(cx));
let res: Response = e.into().into(); let req = req.take().unwrap();
Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) return match res {
Ok(res) => Poll::Ready(Ok(ServiceResponse::new(req, res))),
Err(e) => {
let res: Response = e.into().into();
Poll::Ready(Ok(ServiceResponse::new(req, res)))
}
};
} }
};
}
match this.fut.poll(cx) {
Poll::Ready(res) => {
let fut = res.respond_to(this.req.as_ref().unwrap());
self.as_mut().project().fut2.set(Some(fut));
self.poll(cx)
} }
Poll::Pending => Poll::Pending,
} }
} }
} }
@ -169,12 +161,12 @@ where
Error = Infallible, Error = Infallible,
> + Clone, > + Clone,
{ {
type Config = ();
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse; type Response = ServiceResponse;
type Error = (Error, ServiceRequest); type Error = Error;
type InitError = (); type Config = ();
type Service = ExtractService<T, S>; type Service = ExtractService<T, S>;
type InitError = ();
type Future = Ready<Result<Self::Service, ()>>; type Future = Ready<Result<Self::Service, ()>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
@ -200,7 +192,7 @@ where
{ {
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse; type Response = ServiceResponse;
type Error = (Error, ServiceRequest); type Error = Error;
type Future = ExtractResponse<T, S>; type Future = ExtractResponse<T, S>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -210,24 +202,14 @@ where
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&mut self, req: ServiceRequest) -> Self::Future {
let (req, mut payload) = req.into_parts(); let (req, mut payload) = req.into_parts();
let fut = T::from_request(&req, &mut payload); let fut = T::from_request(&req, &mut payload);
ExtractResponse::Future(fut, Some(req), self.service.clone())
ExtractResponse {
fut,
req,
fut_s: None,
service: self.service.clone(),
}
} }
} }
#[pin_project] #[pin_project(project = ExtractProj)]
pub struct ExtractResponse<T: FromRequest, S: Service> { pub enum ExtractResponse<T: FromRequest, S: Service> {
req: HttpRequest, Future(#[pin] T::Future, Option<HttpRequest>, S),
service: S, Response(#[pin] S::Future),
#[pin]
fut: T::Future,
#[pin]
fut_s: Option<S::Future>,
} }
impl<T: FromRequest, S> Future for ExtractResponse<T, S> impl<T: FromRequest, S> Future for ExtractResponse<T, S>
@ -238,24 +220,26 @@ where
Error = Infallible, Error = Infallible,
>, >,
{ {
type Output = Result<ServiceResponse, (Error, ServiceRequest)>; type Output = Result<ServiceResponse, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().project(); loop {
match self.as_mut().project() {
if let Some(fut) = this.fut_s.as_pin_mut() { ExtractProj::Future(fut, req, srv) => {
return fut.poll(cx).map_err(|_| panic!()); let res = ready!(fut.poll(cx));
} let req = req.take().unwrap();
match res {
match ready!(this.fut.poll(cx)) { Err(e) => {
Err(e) => { let req = ServiceRequest::new(req);
let req = ServiceRequest::new(this.req.clone()); return Poll::Ready(Ok(req.error_response(e.into())));
Poll::Ready(Err((e.into(), req))) }
} Ok(item) => {
Ok(item) => { let fut = srv.call((item, req));
let fut = Some(this.service.call((item, this.req.clone()))); self.as_mut().set(ExtractResponse::Response(fut));
self.as_mut().project().fut_s.set(fut); }
self.poll(cx) }
}
ExtractProj::Response(fut) => return fut.poll(cx).map_err(|_| panic!()),
} }
} }
} }

View File

@ -1,10 +1,14 @@
//! Middleware for setting default response headers //! Middleware for setting default response headers
use std::convert::TryFrom; use std::convert::TryFrom;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures_util::future::{ready, Ready};
use futures_util::ready;
use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
use crate::http::{Error as HttpError, HeaderMap}; use crate::http::{Error as HttpError, HeaderMap};
@ -97,15 +101,15 @@ where
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type InitError = ();
type Transform = DefaultHeadersMiddleware<S>; type Transform = DefaultHeadersMiddleware<S>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>; type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(DefaultHeadersMiddleware { ready(Ok(DefaultHeadersMiddleware {
service, service,
inner: self.inner.clone(), inner: self.inner.clone(),
}) }))
} }
} }
@ -122,36 +126,56 @@ where
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = DefaultHeaderFuture<S, B>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx) self.service.poll_ready(cx)
} }
#[allow(clippy::borrow_interior_mutable_const)]
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&mut self, req: ServiceRequest) -> Self::Future {
let inner = self.inner.clone(); let inner = self.inner.clone();
let fut = self.service.call(req); let fut = self.service.call(req);
async move { DefaultHeaderFuture {
let mut res = fut.await?; fut,
inner,
// set response headers _body: PhantomData,
for (key, value) in inner.headers.iter() {
if !res.headers().contains_key(key) {
res.headers_mut().insert(key.clone(), value.clone());
}
}
// default content-type
if inner.ct && !res.headers().contains_key(&CONTENT_TYPE) {
res.headers_mut().insert(
CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
}
Ok(res)
} }
.boxed_local() }
}
#[pin_project::pin_project]
pub struct DefaultHeaderFuture<S: Service, B> {
#[pin]
fut: S::Future,
inner: Rc<Inner>,
_body: PhantomData<B>,
}
impl<S, B> Future for DefaultHeaderFuture<S, B>
where
S: Service<Response = ServiceResponse<B>, Error = Error>,
{
type Output = <S::Future as Future>::Output;
#[allow(clippy::borrow_interior_mutable_const)]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut res = ready!(this.fut.poll(cx))?;
// set response headers
for (key, value) in this.inner.headers.iter() {
if !res.headers().contains_key(key) {
res.headers_mut().insert(key.clone(), value.clone());
}
}
// default content-type
if this.inner.ct && !res.headers().contains_key(&CONTENT_TYPE) {
res.headers_mut().insert(
CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
}
Poll::Ready(Ok(res))
} }
} }

View File

@ -234,7 +234,7 @@ impl Route {
struct RouteNewService<T> struct RouteNewService<T>
where where
T: ServiceFactory<Request = ServiceRequest, Error = (Error, ServiceRequest)>, T: ServiceFactory<Request = ServiceRequest, Error = Error>,
{ {
service: T, service: T,
} }
@ -245,7 +245,7 @@ where
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse, Response = ServiceResponse,
Error = (Error, ServiceRequest), Error = Error,
>, >,
T::Future: 'static, T::Future: 'static,
T::Service: 'static, T::Service: 'static,
@ -262,7 +262,7 @@ where
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse, Response = ServiceResponse,
Error = (Error, ServiceRequest), Error = Error,
>, >,
T::Future: 'static, T::Future: 'static,
T::Service: 'static, T::Service: 'static,
@ -297,11 +297,7 @@ struct RouteServiceWrapper<T: Service> {
impl<T> Service for RouteServiceWrapper<T> impl<T> Service for RouteServiceWrapper<T>
where where
T::Future: 'static, T::Future: 'static,
T: Service< T: Service<Request = ServiceRequest, Response = ServiceResponse, Error = Error>,
Request = ServiceRequest,
Response = ServiceResponse,
Error = (Error, ServiceRequest),
>,
{ {
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse; type Response = ServiceResponse;
@ -309,27 +305,11 @@ where
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx).map_err(|(e, _)| e) self.service.poll_ready(cx)
} }
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&mut self, req: ServiceRequest) -> Self::Future {
// let mut fut = self.service.call(req); Box::pin(self.service.call(req))
self.service
.call(req)
.map(|res| match res {
Ok(res) => Ok(res),
Err((err, req)) => Ok(req.error_response(err)),
})
.boxed_local()
// match fut.poll() {
// Poll::Ready(Ok(res)) => Either::Left(ok(res)),
// Poll::Ready(Err((e, req))) => Either::Left(ok(req.error_response(e))),
// Poll::Pending => Either::Right(Box::new(fut.then(|res| match res {
// Ok(res) => Ok(res),
// Err((err, req)) => Ok(req.error_response(err)),
// }))),
// }
} }
} }

View File

@ -1,14 +1,16 @@
//! Json extractor/responder //! Json extractor/responder
use std::future::Future; use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{fmt, ops}; use std::{fmt, ops};
use bytes::BytesMut; use bytes::BytesMut;
use futures_util::future::{err, ok, FutureExt, LocalBoxFuture, Ready}; use futures_util::future::{ready, Ready};
use futures_util::StreamExt; use futures_util::ready;
use futures_util::stream::Stream;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::Serialize; use serde::Serialize;
@ -127,12 +129,12 @@ impl<T: Serialize> Responder for Json<T> {
fn respond_to(self, _: &HttpRequest) -> Self::Future { fn respond_to(self, _: &HttpRequest) -> Self::Future {
let body = match serde_json::to_string(&self.0) { let body = match serde_json::to_string(&self.0) {
Ok(body) => body, Ok(body) => body,
Err(e) => return err(e.into()), Err(e) => return ready(Err(e.into())),
}; };
ok(Response::build(StatusCode::OK) ready(Ok(Response::build(StatusCode::OK)
.content_type("application/json") .content_type("application/json")
.body(body)) .body(body)))
} }
} }
@ -173,37 +175,64 @@ where
T: DeserializeOwned + 'static, T: DeserializeOwned + 'static,
{ {
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self, Error>>; type Future = JsonExtractFut<T>;
type Config = JsonConfig; type Config = JsonConfig;
#[inline] #[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
let req2 = req.clone();
let config = JsonConfig::from_req(req); let config = JsonConfig::from_req(req);
let limit = config.limit; let limit = config.limit;
let ctype = config.content_type.clone(); let ctype = config.content_type.clone();
let err_handler = config.err_handler.clone(); let err_handler = config.err_handler.clone();
JsonBody::new(req, payload, ctype) JsonExtractFut {
.limit(limit) req: Some(req.clone()),
.map(move |res| match res { fut: JsonBody::new(req, payload, ctype).limit(limit),
Err(e) => { err_handler,
log::debug!( }
"Failed to deserialize Json from payload. \ }
Request path: {}", }
req2.path()
);
if let Some(err) = err_handler { type JsonErrorHandler =
Err((*err)(e, &req2)) Option<Arc<dyn Fn(JsonPayloadError, &HttpRequest) -> Error + Send + Sync>>;
} else {
Err(e.into()) pub struct JsonExtractFut<T> {
} req: Option<HttpRequest>,
fut: JsonBody<T>,
err_handler: JsonErrorHandler,
}
impl<T> Future for JsonExtractFut<T>
where
T: DeserializeOwned + 'static,
{
type Output = Result<Json<T>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
let res = ready!(Pin::new(&mut this.fut).poll(cx));
let res = match res {
Err(e) => {
let req = this.req.take().unwrap();
log::debug!(
"Failed to deserialize Json from payload. \
Request path: {}",
req.path()
);
if let Some(err) = this.err_handler.as_ref() {
Err((*err)(e, &req))
} else {
Err(e.into())
} }
Ok(data) => Ok(Json(data)), }
}) Ok(data) => Ok(Json(data)),
.boxed_local() };
Poll::Ready(res)
} }
} }
@ -248,8 +277,7 @@ where
#[derive(Clone)] #[derive(Clone)]
pub struct JsonConfig { pub struct JsonConfig {
limit: usize, limit: usize,
err_handler: err_handler: JsonErrorHandler,
Option<Arc<dyn Fn(JsonPayloadError, &HttpRequest) -> Error + Send + Sync>>,
content_type: Option<Arc<dyn Fn(mime::Mime) -> bool + Send + Sync>>, content_type: Option<Arc<dyn Fn(mime::Mime) -> bool + Send + Sync>>,
} }
@ -308,17 +336,22 @@ impl Default for JsonConfig {
/// * content type is not `application/json` /// * content type is not `application/json`
/// (unless specified in [`JsonConfig`]) /// (unless specified in [`JsonConfig`])
/// * content length is greater than 256k /// * content length is greater than 256k
pub struct JsonBody<U> { pub enum JsonBody<U> {
limit: usize, Error(Option<JsonPayloadError>),
length: Option<usize>, Body {
#[cfg(feature = "compress")] limit: usize,
stream: Option<Decompress<Payload>>, length: Option<usize>,
#[cfg(not(feature = "compress"))] #[cfg(feature = "compress")]
stream: Option<Payload>, payload: Decompress<Payload>,
err: Option<JsonPayloadError>, #[cfg(not(feature = "compress"))]
fut: Option<LocalBoxFuture<'static, Result<U, JsonPayloadError>>>, payload: Payload,
buf: BytesMut,
_res: PhantomData<U>,
},
} }
impl<U> Unpin for JsonBody<U> {}
impl<U> JsonBody<U> impl<U> JsonBody<U>
where where
U: DeserializeOwned + 'static, U: DeserializeOwned + 'static,
@ -340,39 +373,58 @@ where
}; };
if !json { if !json {
return JsonBody { return JsonBody::Error(Some(JsonPayloadError::ContentType));
limit: 262_144,
length: None,
stream: None,
fut: None,
err: Some(JsonPayloadError::ContentType),
};
} }
let len = req let length = req
.headers() .headers()
.get(&CONTENT_LENGTH) .get(&CONTENT_LENGTH)
.and_then(|l| l.to_str().ok()) .and_then(|l| l.to_str().ok())
.and_then(|s| s.parse::<usize>().ok()); .and_then(|s| s.parse::<usize>().ok());
// Notice the content_length is not checked against limit of json config here.
// As the internal usage always call JsonBody::limit after JsonBody::new.
// And limit check to return an error variant of JsonBody happens there.
#[cfg(feature = "compress")] #[cfg(feature = "compress")]
let payload = Decompress::from_headers(payload.take(), req.headers()); let payload = Decompress::from_headers(payload.take(), req.headers());
#[cfg(not(feature = "compress"))] #[cfg(not(feature = "compress"))]
let payload = payload.take(); let payload = payload.take();
JsonBody { JsonBody::Body {
limit: 262_144, limit: 262_144,
length: len, length,
stream: Some(payload), payload,
fut: None, buf: BytesMut::with_capacity(8192),
err: None, _res: PhantomData,
} }
} }
/// Change max size of payload. By default max size is 256Kb /// Change max size of payload. By default max size is 256Kb
pub fn limit(mut self, limit: usize) -> Self { pub fn limit(self, limit: usize) -> Self {
self.limit = limit; match self {
self JsonBody::Body {
length,
payload,
buf,
..
} => {
if let Some(len) = length {
if len > limit {
return JsonBody::Error(Some(JsonPayloadError::Overflow));
}
}
JsonBody::Body {
limit,
length,
payload,
buf,
_res: PhantomData,
}
}
JsonBody::Error(e) => JsonBody::Error(e),
}
} }
} }
@ -382,41 +434,34 @@ where
{ {
type Output = Result<U, JsonPayloadError>; type Output = Result<U, JsonPayloadError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(ref mut fut) = self.fut { let this = self.get_mut();
return Pin::new(fut).poll(cx);
}
if let Some(err) = self.err.take() { match this {
return Poll::Ready(Err(err)); JsonBody::Body {
} limit,
buf,
let limit = self.limit; payload,
if let Some(len) = self.length.take() { ..
if len > limit { } => loop {
return Poll::Ready(Err(JsonPayloadError::Overflow)); let res = ready!(Pin::new(&mut *payload).poll_next(cx));
} match res {
} Some(chunk) => {
let mut stream = self.stream.take().unwrap(); let chunk = chunk?;
if (buf.len() + chunk.len()) > *limit {
self.fut = Some( return Poll::Ready(Err(JsonPayloadError::Overflow));
async move { } else {
let mut body = BytesMut::with_capacity(8192); buf.extend_from_slice(&chunk);
}
while let Some(item) = stream.next().await { }
let chunk = item?; None => {
if (body.len() + chunk.len()) > limit { let json = serde_json::from_slice::<U>(&buf)?;
return Err(JsonPayloadError::Overflow); return Poll::Ready(Ok(json));
} else {
body.extend_from_slice(&chunk);
} }
} }
Ok(serde_json::from_slice::<U>(&body)?) },
} JsonBody::Error(e) => Poll::Ready(Err(e.take().unwrap())),
.boxed_local(), }
);
self.poll(cx)
} }
} }

View File

@ -7,10 +7,12 @@ use std::task::{Context, Poll};
use actix_http::error::{Error, ErrorBadRequest, PayloadError}; use actix_http::error::{Error, ErrorBadRequest, PayloadError};
use actix_http::HttpMessage; use actix_http::HttpMessage;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use encoding_rs::UTF_8; use encoding_rs::{Encoding, UTF_8};
use futures_core::stream::Stream; use futures_core::stream::Stream;
use futures_util::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready}; use futures_util::{
use futures_util::StreamExt; future::{err, ok, Either, ErrInto, Ready, TryFutureExt as _},
ready,
};
use mime::Mime; use mime::Mime;
use crate::extract::FromRequest; use crate::extract::FromRequest;
@ -135,10 +137,7 @@ impl FromRequest for Payload {
impl FromRequest for Bytes { impl FromRequest for Bytes {
type Config = PayloadConfig; type Config = PayloadConfig;
type Error = Error; type Error = Error;
type Future = Either< type Future = Either<ErrInto<HttpMessageBody, Error>, Ready<Result<Bytes, Error>>>;
LocalBoxFuture<'static, Result<Bytes, Error>>,
Ready<Result<Bytes, Error>>,
>;
#[inline] #[inline]
fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
@ -151,7 +150,7 @@ impl FromRequest for Bytes {
let limit = cfg.limit; let limit = cfg.limit;
let fut = HttpMessageBody::new(req, payload).limit(limit); let fut = HttpMessageBody::new(req, payload).limit(limit);
Either::Left(async move { Ok(fut.await?) }.boxed_local()) Either::Left(fut.err_into())
} }
} }
@ -185,10 +184,7 @@ impl FromRequest for Bytes {
impl FromRequest for String { impl FromRequest for String {
type Config = PayloadConfig; type Config = PayloadConfig;
type Error = Error; type Error = Error;
type Future = Either< type Future = Either<StringExtractFut, Ready<Result<String, Error>>>;
LocalBoxFuture<'static, Result<String, Error>>,
Ready<Result<String, Error>>,
>;
#[inline] #[inline]
fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
@ -205,25 +201,40 @@ impl FromRequest for String {
Err(e) => return Either::Right(err(e.into())), Err(e) => return Either::Right(err(e.into())),
}; };
let limit = cfg.limit; let limit = cfg.limit;
let fut = HttpMessageBody::new(req, payload).limit(limit); let body_fut = HttpMessageBody::new(req, payload).limit(limit);
Either::Left( Either::Left(StringExtractFut { body_fut, encoding })
async move { }
let body = fut.await?; }
if encoding == UTF_8 { pub struct StringExtractFut {
Ok(str::from_utf8(body.as_ref()) body_fut: HttpMessageBody,
.map_err(|_| ErrorBadRequest("Can not decode body"))? encoding: &'static Encoding,
.to_owned()) }
} else {
Ok(encoding impl<'a> Future for StringExtractFut {
.decode_without_bom_handling_and_without_replacement(&body) type Output = Result<String, Error>;
.map(|s| s.into_owned())
.ok_or_else(|| ErrorBadRequest("Can not decode body"))?) fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
} let encoding = self.encoding;
}
.boxed_local(), Pin::new(&mut self.body_fut).poll(cx).map(|out| {
) let body = out?;
bytes_to_string(body, encoding)
})
}
}
fn bytes_to_string(body: Bytes, encoding: &'static Encoding) -> Result<String, Error> {
if encoding == UTF_8 {
Ok(str::from_utf8(body.as_ref())
.map_err(|_| ErrorBadRequest("Can not decode body"))?
.to_owned())
} else {
Ok(encoding
.decode_without_bom_handling_and_without_replacement(&body)
.map(|s| s.into_owned())
.ok_or_else(|| ErrorBadRequest("Can not decode body"))?)
} }
} }
@ -291,10 +302,12 @@ impl PayloadConfig {
// Allow shared refs to default. // Allow shared refs to default.
const DEFAULT_CONFIG: PayloadConfig = PayloadConfig { const DEFAULT_CONFIG: PayloadConfig = PayloadConfig {
limit: 262_144, // 2^18 bytes (~256kB) limit: DEFAULT_CONFIG_LIMIT,
mimetype: None, mimetype: None,
}; };
const DEFAULT_CONFIG_LIMIT: usize = 262_144; // 2^18 bytes (~256kB)
impl Default for PayloadConfig { impl Default for PayloadConfig {
fn default() -> Self { fn default() -> Self {
DEFAULT_CONFIG.clone() DEFAULT_CONFIG.clone()
@ -312,99 +325,83 @@ pub struct HttpMessageBody {
limit: usize, limit: usize,
length: Option<usize>, length: Option<usize>,
#[cfg(feature = "compress")] #[cfg(feature = "compress")]
stream: Option<dev::Decompress<dev::Payload>>, stream: dev::Decompress<dev::Payload>,
#[cfg(not(feature = "compress"))] #[cfg(not(feature = "compress"))]
stream: Option<dev::Payload>, stream: dev::Payload,
buf: BytesMut,
err: Option<PayloadError>, err: Option<PayloadError>,
fut: Option<LocalBoxFuture<'static, Result<Bytes, PayloadError>>>,
} }
impl HttpMessageBody { impl HttpMessageBody {
/// Create `MessageBody` for request. /// Create `MessageBody` for request.
#[allow(clippy::borrow_interior_mutable_const)] #[allow(clippy::borrow_interior_mutable_const)]
pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> HttpMessageBody { pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> HttpMessageBody {
let mut len = None; let mut length = None;
let mut err = None;
if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) { if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) {
if let Ok(s) = l.to_str() { match l.to_str() {
if let Ok(l) = s.parse::<usize>() { Ok(s) => match s.parse::<usize>() {
len = Some(l) Ok(l) if l > DEFAULT_CONFIG_LIMIT => {
} else { err = Some(PayloadError::Overflow)
return Self::err(PayloadError::UnknownLength); }
} Ok(l) => length = Some(l),
} else { Err(_) => err = Some(PayloadError::UnknownLength),
return Self::err(PayloadError::UnknownLength); },
Err(_) => err = Some(PayloadError::UnknownLength),
} }
} }
#[cfg(feature = "compress")] #[cfg(feature = "compress")]
let stream = Some(dev::Decompress::from_headers(payload.take(), req.headers())); let stream = dev::Decompress::from_headers(payload.take(), req.headers());
#[cfg(not(feature = "compress"))] #[cfg(not(feature = "compress"))]
let stream = Some(payload.take()); let stream = payload.take();
HttpMessageBody { HttpMessageBody {
stream, stream,
limit: 262_144, limit: DEFAULT_CONFIG_LIMIT,
length: len, length,
fut: None, buf: BytesMut::with_capacity(8192),
err: None, err,
} }
} }
/// Change max size of payload. By default max size is 256Kb /// Change max size of payload. By default max size is 256Kb
pub fn limit(mut self, limit: usize) -> Self { pub fn limit(mut self, limit: usize) -> Self {
if let Some(l) = self.length {
if l > limit {
self.err = Some(PayloadError::Overflow);
}
}
self.limit = limit; self.limit = limit;
self self
} }
fn err(e: PayloadError) -> Self {
HttpMessageBody {
stream: None,
limit: 262_144,
fut: None,
err: Some(e),
length: None,
}
}
} }
impl Future for HttpMessageBody { impl Future for HttpMessageBody {
type Output = Result<Bytes, PayloadError>; type Output = Result<Bytes, PayloadError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(ref mut fut) = self.fut { let this = self.get_mut();
return Pin::new(fut).poll(cx);
if let Some(e) = this.err.take() {
return Poll::Ready(Err(e));
} }
if let Some(err) = self.err.take() { loop {
return Poll::Ready(Err(err)); let res = ready!(Pin::new(&mut this.stream).poll_next(cx));
} match res {
Some(chunk) => {
if let Some(len) = self.length.take() { let chunk = chunk?;
if len > self.limit { if this.buf.len() + chunk.len() > this.limit {
return Poll::Ready(Err(PayloadError::Overflow)); return Poll::Ready(Err(PayloadError::Overflow));
}
}
// future
let limit = self.limit;
let mut stream = self.stream.take().unwrap();
self.fut = Some(
async move {
let mut body = BytesMut::with_capacity(8192);
while let Some(item) = stream.next().await {
let chunk = item?;
if body.len() + chunk.len() > limit {
return Err(PayloadError::Overflow);
} else { } else {
body.extend_from_slice(&chunk); this.buf.extend_from_slice(&chunk);
} }
} }
Ok(body.freeze()) None => return Poll::Ready(Ok(this.buf.split().freeze())),
} }
.boxed_local(), }
);
self.poll(cx)
} }
} }