Merge branch 'master' into tokio0.3

This commit is contained in:
fakeshadow 2020-12-21 00:55:47 +08:00 committed by GitHub
commit b4a8c6e1f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 377 additions and 308 deletions

View File

@ -1,3 +1,5 @@
comment: false
coverage:
status:
project:

View File

@ -4,7 +4,8 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use actix_http::error::Error;
use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready};
use futures_util::future::{ready, Ready};
use futures_util::ready;
use crate::dev::Payload;
use crate::request::HttpRequest;
@ -95,21 +96,41 @@ where
T: FromRequest,
T::Future: 'static,
{
type Config = T::Config;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Option<T>, Error>>;
type Future = FromRequestOptFuture<T::Future>;
type Config = T::Config;
#[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
T::from_request(req, payload)
.then(|r| match r {
Ok(v) => ok(Some(v)),
Err(e) => {
log::debug!("Error for Option<T> extractor: {}", e.into());
ok(None)
}
})
.boxed_local()
FromRequestOptFuture {
fut: T::from_request(req, payload),
}
}
}
#[pin_project::pin_project]
pub struct FromRequestOptFuture<Fut> {
#[pin]
fut: Fut,
}
impl<Fut, T, E> Future for FromRequestOptFuture<Fut>
where
Fut: Future<Output = Result<T, E>>,
E: Into<Error>,
{
type Output = Result<Option<T>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let res = ready!(this.fut.poll(cx));
match res {
Ok(t) => Poll::Ready(Ok(Some(t))),
Err(e) => {
log::debug!("Error for Option<T> extractor: {}", e.into());
Poll::Ready(Ok(None))
}
}
}
}
@ -165,29 +186,45 @@ where
T::Error: 'static,
T::Future: 'static,
{
type Config = T::Config;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Result<T, T::Error>, Error>>;
type Future = FromRequestResFuture<T::Future>;
type Config = T::Config;
#[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
T::from_request(req, payload)
.then(|res| match res {
Ok(v) => ok(Ok(v)),
Err(e) => ok(Err(e)),
})
.boxed_local()
FromRequestResFuture {
fut: T::from_request(req, payload),
}
}
}
#[pin_project::pin_project]
pub struct FromRequestResFuture<Fut> {
#[pin]
fut: Fut,
}
impl<Fut, T, E> Future for FromRequestResFuture<Fut>
where
Fut: Future<Output = Result<T, E>>,
{
type Output = Result<Result<T, E>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let res = ready!(this.fut.poll(cx));
Poll::Ready(Ok(res))
}
}
#[doc(hidden)]
impl FromRequest for () {
type Config = ();
type Error = Error;
type Future = Ready<Result<(), Error>>;
type Config = ();
fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future {
ok(())
ready(Ok(()))
}
}

View File

@ -90,26 +90,20 @@ where
}
fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future {
HandlerServiceResponse {
fut: self.hnd.call(param),
fut2: None,
req: Some(req),
}
let fut = self.hnd.call(param);
HandlerServiceResponse::Future(fut, Some(req))
}
}
#[doc(hidden)]
#[pin_project]
pub struct HandlerServiceResponse<T, R>
#[pin_project(project = HandlerProj)]
pub enum HandlerServiceResponse<T, R>
where
T: Future<Output = R>,
R: Responder,
{
#[pin]
fut: T,
#[pin]
fut2: Option<R::Future>,
req: Option<HttpRequest>,
Future(#[pin] T, Option<HttpRequest>),
Responder(#[pin] R::Future, Option<HttpRequest>),
}
impl<T, R> Future for HandlerServiceResponse<T, R>
@ -120,28 +114,26 @@ where
type Output = Result<ServiceResponse, Infallible>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().project();
if let Some(fut) = this.fut2.as_pin_mut() {
return match fut.poll(cx) {
Poll::Ready(Ok(res)) => {
Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res)))
loop {
match self.as_mut().project() {
HandlerProj::Future(fut, req) => {
let res = ready!(fut.poll(cx));
let fut = res.respond_to(req.as_ref().unwrap());
let state = HandlerServiceResponse::Responder(fut, req.take());
self.as_mut().set(state);
}
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
let res: Response = e.into().into();
Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res)))
HandlerProj::Responder(fut, req) => {
let res = ready!(fut.poll(cx));
let req = req.take().unwrap();
return match res {
Ok(res) => Poll::Ready(Ok(ServiceResponse::new(req, res))),
Err(e) => {
let res: Response = e.into().into();
Poll::Ready(Ok(ServiceResponse::new(req, res)))
}
};
}
};
}
match this.fut.poll(cx) {
Poll::Ready(res) => {
let fut = res.respond_to(this.req.as_ref().unwrap());
self.as_mut().project().fut2.set(Some(fut));
self.poll(cx)
}
Poll::Pending => Poll::Pending,
}
}
}
@ -169,12 +161,12 @@ where
Error = Infallible,
> + Clone,
{
type Config = ();
type Request = ServiceRequest;
type Response = ServiceResponse;
type Error = (Error, ServiceRequest);
type InitError = ();
type Error = Error;
type Config = ();
type Service = ExtractService<T, S>;
type InitError = ();
type Future = Ready<Result<Self::Service, ()>>;
fn new_service(&self, _: ()) -> Self::Future {
@ -200,7 +192,7 @@ where
{
type Request = ServiceRequest;
type Response = ServiceResponse;
type Error = (Error, ServiceRequest);
type Error = Error;
type Future = ExtractResponse<T, S>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -210,24 +202,14 @@ where
fn call(&mut self, req: ServiceRequest) -> Self::Future {
let (req, mut payload) = req.into_parts();
let fut = T::from_request(&req, &mut payload);
ExtractResponse {
fut,
req,
fut_s: None,
service: self.service.clone(),
}
ExtractResponse::Future(fut, Some(req), self.service.clone())
}
}
#[pin_project]
pub struct ExtractResponse<T: FromRequest, S: Service> {
req: HttpRequest,
service: S,
#[pin]
fut: T::Future,
#[pin]
fut_s: Option<S::Future>,
#[pin_project(project = ExtractProj)]
pub enum ExtractResponse<T: FromRequest, S: Service> {
Future(#[pin] T::Future, Option<HttpRequest>, S),
Response(#[pin] S::Future),
}
impl<T: FromRequest, S> Future for ExtractResponse<T, S>
@ -238,24 +220,26 @@ where
Error = Infallible,
>,
{
type Output = Result<ServiceResponse, (Error, ServiceRequest)>;
type Output = Result<ServiceResponse, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().project();
if let Some(fut) = this.fut_s.as_pin_mut() {
return fut.poll(cx).map_err(|_| panic!());
}
match ready!(this.fut.poll(cx)) {
Err(e) => {
let req = ServiceRequest::new(this.req.clone());
Poll::Ready(Err((e.into(), req)))
}
Ok(item) => {
let fut = Some(this.service.call((item, this.req.clone())));
self.as_mut().project().fut_s.set(fut);
self.poll(cx)
loop {
match self.as_mut().project() {
ExtractProj::Future(fut, req, srv) => {
let res = ready!(fut.poll(cx));
let req = req.take().unwrap();
match res {
Err(e) => {
let req = ServiceRequest::new(req);
return Poll::Ready(Ok(req.error_response(e.into())));
}
Ok(item) => {
let fut = srv.call((item, req));
self.as_mut().set(ExtractResponse::Response(fut));
}
}
}
ExtractProj::Response(fut) => return fut.poll(cx).map_err(|_| panic!()),
}
}
}

View File

@ -1,10 +1,14 @@
//! Middleware for setting default response headers
use std::convert::TryFrom;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{Service, Transform};
use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready};
use futures_util::future::{ready, Ready};
use futures_util::ready;
use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
use crate::http::{Error as HttpError, HeaderMap};
@ -97,15 +101,15 @@ where
type Request = ServiceRequest;
type Response = ServiceResponse<B>;
type Error = Error;
type InitError = ();
type Transform = DefaultHeadersMiddleware<S>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(DefaultHeadersMiddleware {
ready(Ok(DefaultHeadersMiddleware {
service,
inner: self.inner.clone(),
})
}))
}
}
@ -122,36 +126,56 @@ where
type Request = ServiceRequest;
type Response = ServiceResponse<B>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
type Future = DefaultHeaderFuture<S, B>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
#[allow(clippy::borrow_interior_mutable_const)]
fn call(&mut self, req: ServiceRequest) -> Self::Future {
let inner = self.inner.clone();
let fut = self.service.call(req);
async move {
let mut res = fut.await?;
// set response headers
for (key, value) in inner.headers.iter() {
if !res.headers().contains_key(key) {
res.headers_mut().insert(key.clone(), value.clone());
}
}
// default content-type
if inner.ct && !res.headers().contains_key(&CONTENT_TYPE) {
res.headers_mut().insert(
CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
}
Ok(res)
DefaultHeaderFuture {
fut,
inner,
_body: PhantomData,
}
.boxed_local()
}
}
#[pin_project::pin_project]
pub struct DefaultHeaderFuture<S: Service, B> {
#[pin]
fut: S::Future,
inner: Rc<Inner>,
_body: PhantomData<B>,
}
impl<S, B> Future for DefaultHeaderFuture<S, B>
where
S: Service<Response = ServiceResponse<B>, Error = Error>,
{
type Output = <S::Future as Future>::Output;
#[allow(clippy::borrow_interior_mutable_const)]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut res = ready!(this.fut.poll(cx))?;
// set response headers
for (key, value) in this.inner.headers.iter() {
if !res.headers().contains_key(key) {
res.headers_mut().insert(key.clone(), value.clone());
}
}
// default content-type
if this.inner.ct && !res.headers().contains_key(&CONTENT_TYPE) {
res.headers_mut().insert(
CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
}
Poll::Ready(Ok(res))
}
}

View File

@ -234,7 +234,7 @@ impl Route {
struct RouteNewService<T>
where
T: ServiceFactory<Request = ServiceRequest, Error = (Error, ServiceRequest)>,
T: ServiceFactory<Request = ServiceRequest, Error = Error>,
{
service: T,
}
@ -245,7 +245,7 @@ where
Config = (),
Request = ServiceRequest,
Response = ServiceResponse,
Error = (Error, ServiceRequest),
Error = Error,
>,
T::Future: 'static,
T::Service: 'static,
@ -262,7 +262,7 @@ where
Config = (),
Request = ServiceRequest,
Response = ServiceResponse,
Error = (Error, ServiceRequest),
Error = Error,
>,
T::Future: 'static,
T::Service: 'static,
@ -297,11 +297,7 @@ struct RouteServiceWrapper<T: Service> {
impl<T> Service for RouteServiceWrapper<T>
where
T::Future: 'static,
T: Service<
Request = ServiceRequest,
Response = ServiceResponse,
Error = (Error, ServiceRequest),
>,
T: Service<Request = ServiceRequest, Response = ServiceResponse, Error = Error>,
{
type Request = ServiceRequest;
type Response = ServiceResponse;
@ -309,27 +305,11 @@ where
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx).map_err(|(e, _)| e)
self.service.poll_ready(cx)
}
fn call(&mut self, req: ServiceRequest) -> Self::Future {
// let mut fut = self.service.call(req);
self.service
.call(req)
.map(|res| match res {
Ok(res) => Ok(res),
Err((err, req)) => Ok(req.error_response(err)),
})
.boxed_local()
// match fut.poll() {
// Poll::Ready(Ok(res)) => Either::Left(ok(res)),
// Poll::Ready(Err((e, req))) => Either::Left(ok(req.error_response(e))),
// Poll::Pending => Either::Right(Box::new(fut.then(|res| match res {
// Ok(res) => Ok(res),
// Err((err, req)) => Ok(req.error_response(err)),
// }))),
// }
Box::pin(self.service.call(req))
}
}

View File

@ -1,14 +1,16 @@
//! Json extractor/responder
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{fmt, ops};
use bytes::BytesMut;
use futures_util::future::{err, ok, FutureExt, LocalBoxFuture, Ready};
use futures_util::StreamExt;
use futures_util::future::{ready, Ready};
use futures_util::ready;
use futures_util::stream::Stream;
use serde::de::DeserializeOwned;
use serde::Serialize;
@ -127,12 +129,12 @@ impl<T: Serialize> Responder for Json<T> {
fn respond_to(self, _: &HttpRequest) -> Self::Future {
let body = match serde_json::to_string(&self.0) {
Ok(body) => body,
Err(e) => return err(e.into()),
Err(e) => return ready(Err(e.into())),
};
ok(Response::build(StatusCode::OK)
ready(Ok(Response::build(StatusCode::OK)
.content_type("application/json")
.body(body))
.body(body)))
}
}
@ -173,37 +175,64 @@ where
T: DeserializeOwned + 'static,
{
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self, Error>>;
type Future = JsonExtractFut<T>;
type Config = JsonConfig;
#[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
let req2 = req.clone();
let config = JsonConfig::from_req(req);
let limit = config.limit;
let ctype = config.content_type.clone();
let err_handler = config.err_handler.clone();
JsonBody::new(req, payload, ctype)
.limit(limit)
.map(move |res| match res {
Err(e) => {
log::debug!(
"Failed to deserialize Json from payload. \
Request path: {}",
req2.path()
);
JsonExtractFut {
req: Some(req.clone()),
fut: JsonBody::new(req, payload, ctype).limit(limit),
err_handler,
}
}
}
if let Some(err) = err_handler {
Err((*err)(e, &req2))
} else {
Err(e.into())
}
type JsonErrorHandler =
Option<Arc<dyn Fn(JsonPayloadError, &HttpRequest) -> Error + Send + Sync>>;
pub struct JsonExtractFut<T> {
req: Option<HttpRequest>,
fut: JsonBody<T>,
err_handler: JsonErrorHandler,
}
impl<T> Future for JsonExtractFut<T>
where
T: DeserializeOwned + 'static,
{
type Output = Result<Json<T>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
let res = ready!(Pin::new(&mut this.fut).poll(cx));
let res = match res {
Err(e) => {
let req = this.req.take().unwrap();
log::debug!(
"Failed to deserialize Json from payload. \
Request path: {}",
req.path()
);
if let Some(err) = this.err_handler.as_ref() {
Err((*err)(e, &req))
} else {
Err(e.into())
}
Ok(data) => Ok(Json(data)),
})
.boxed_local()
}
Ok(data) => Ok(Json(data)),
};
Poll::Ready(res)
}
}
@ -248,8 +277,7 @@ where
#[derive(Clone)]
pub struct JsonConfig {
limit: usize,
err_handler:
Option<Arc<dyn Fn(JsonPayloadError, &HttpRequest) -> Error + Send + Sync>>,
err_handler: JsonErrorHandler,
content_type: Option<Arc<dyn Fn(mime::Mime) -> bool + Send + Sync>>,
}
@ -308,17 +336,22 @@ impl Default for JsonConfig {
/// * content type is not `application/json`
/// (unless specified in [`JsonConfig`])
/// * content length is greater than 256k
pub struct JsonBody<U> {
limit: usize,
length: Option<usize>,
#[cfg(feature = "compress")]
stream: Option<Decompress<Payload>>,
#[cfg(not(feature = "compress"))]
stream: Option<Payload>,
err: Option<JsonPayloadError>,
fut: Option<LocalBoxFuture<'static, Result<U, JsonPayloadError>>>,
pub enum JsonBody<U> {
Error(Option<JsonPayloadError>),
Body {
limit: usize,
length: Option<usize>,
#[cfg(feature = "compress")]
payload: Decompress<Payload>,
#[cfg(not(feature = "compress"))]
payload: Payload,
buf: BytesMut,
_res: PhantomData<U>,
},
}
impl<U> Unpin for JsonBody<U> {}
impl<U> JsonBody<U>
where
U: DeserializeOwned + 'static,
@ -340,39 +373,58 @@ where
};
if !json {
return JsonBody {
limit: 262_144,
length: None,
stream: None,
fut: None,
err: Some(JsonPayloadError::ContentType),
};
return JsonBody::Error(Some(JsonPayloadError::ContentType));
}
let len = req
let length = req
.headers()
.get(&CONTENT_LENGTH)
.and_then(|l| l.to_str().ok())
.and_then(|s| s.parse::<usize>().ok());
// Notice the content_length is not checked against limit of json config here.
// As the internal usage always call JsonBody::limit after JsonBody::new.
// And limit check to return an error variant of JsonBody happens there.
#[cfg(feature = "compress")]
let payload = Decompress::from_headers(payload.take(), req.headers());
#[cfg(not(feature = "compress"))]
let payload = payload.take();
JsonBody {
JsonBody::Body {
limit: 262_144,
length: len,
stream: Some(payload),
fut: None,
err: None,
length,
payload,
buf: BytesMut::with_capacity(8192),
_res: PhantomData,
}
}
/// Change max size of payload. By default max size is 256Kb
pub fn limit(mut self, limit: usize) -> Self {
self.limit = limit;
self
pub fn limit(self, limit: usize) -> Self {
match self {
JsonBody::Body {
length,
payload,
buf,
..
} => {
if let Some(len) = length {
if len > limit {
return JsonBody::Error(Some(JsonPayloadError::Overflow));
}
}
JsonBody::Body {
limit,
length,
payload,
buf,
_res: PhantomData,
}
}
JsonBody::Error(e) => JsonBody::Error(e),
}
}
}
@ -382,41 +434,34 @@ where
{
type Output = Result<U, JsonPayloadError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(ref mut fut) = self.fut {
return Pin::new(fut).poll(cx);
}
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
if let Some(err) = self.err.take() {
return Poll::Ready(Err(err));
}
let limit = self.limit;
if let Some(len) = self.length.take() {
if len > limit {
return Poll::Ready(Err(JsonPayloadError::Overflow));
}
}
let mut stream = self.stream.take().unwrap();
self.fut = Some(
async move {
let mut body = BytesMut::with_capacity(8192);
while let Some(item) = stream.next().await {
let chunk = item?;
if (body.len() + chunk.len()) > limit {
return Err(JsonPayloadError::Overflow);
} else {
body.extend_from_slice(&chunk);
match this {
JsonBody::Body {
limit,
buf,
payload,
..
} => loop {
let res = ready!(Pin::new(&mut *payload).poll_next(cx));
match res {
Some(chunk) => {
let chunk = chunk?;
if (buf.len() + chunk.len()) > *limit {
return Poll::Ready(Err(JsonPayloadError::Overflow));
} else {
buf.extend_from_slice(&chunk);
}
}
None => {
let json = serde_json::from_slice::<U>(&buf)?;
return Poll::Ready(Ok(json));
}
}
Ok(serde_json::from_slice::<U>(&body)?)
}
.boxed_local(),
);
self.poll(cx)
},
JsonBody::Error(e) => Poll::Ready(Err(e.take().unwrap())),
}
}
}

View File

@ -7,10 +7,12 @@ use std::task::{Context, Poll};
use actix_http::error::{Error, ErrorBadRequest, PayloadError};
use actix_http::HttpMessage;
use bytes::{Bytes, BytesMut};
use encoding_rs::UTF_8;
use encoding_rs::{Encoding, UTF_8};
use futures_core::stream::Stream;
use futures_util::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready};
use futures_util::StreamExt;
use futures_util::{
future::{err, ok, Either, ErrInto, Ready, TryFutureExt as _},
ready,
};
use mime::Mime;
use crate::extract::FromRequest;
@ -135,10 +137,7 @@ impl FromRequest for Payload {
impl FromRequest for Bytes {
type Config = PayloadConfig;
type Error = Error;
type Future = Either<
LocalBoxFuture<'static, Result<Bytes, Error>>,
Ready<Result<Bytes, Error>>,
>;
type Future = Either<ErrInto<HttpMessageBody, Error>, Ready<Result<Bytes, Error>>>;
#[inline]
fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
@ -151,7 +150,7 @@ impl FromRequest for Bytes {
let limit = cfg.limit;
let fut = HttpMessageBody::new(req, payload).limit(limit);
Either::Left(async move { Ok(fut.await?) }.boxed_local())
Either::Left(fut.err_into())
}
}
@ -185,10 +184,7 @@ impl FromRequest for Bytes {
impl FromRequest for String {
type Config = PayloadConfig;
type Error = Error;
type Future = Either<
LocalBoxFuture<'static, Result<String, Error>>,
Ready<Result<String, Error>>,
>;
type Future = Either<StringExtractFut, Ready<Result<String, Error>>>;
#[inline]
fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
@ -205,25 +201,40 @@ impl FromRequest for String {
Err(e) => return Either::Right(err(e.into())),
};
let limit = cfg.limit;
let fut = HttpMessageBody::new(req, payload).limit(limit);
let body_fut = HttpMessageBody::new(req, payload).limit(limit);
Either::Left(
async move {
let body = fut.await?;
Either::Left(StringExtractFut { body_fut, encoding })
}
}
if encoding == UTF_8 {
Ok(str::from_utf8(body.as_ref())
.map_err(|_| ErrorBadRequest("Can not decode body"))?
.to_owned())
} else {
Ok(encoding
.decode_without_bom_handling_and_without_replacement(&body)
.map(|s| s.into_owned())
.ok_or_else(|| ErrorBadRequest("Can not decode body"))?)
}
}
.boxed_local(),
)
pub struct StringExtractFut {
body_fut: HttpMessageBody,
encoding: &'static Encoding,
}
impl<'a> Future for StringExtractFut {
type Output = Result<String, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let encoding = self.encoding;
Pin::new(&mut self.body_fut).poll(cx).map(|out| {
let body = out?;
bytes_to_string(body, encoding)
})
}
}
fn bytes_to_string(body: Bytes, encoding: &'static Encoding) -> Result<String, Error> {
if encoding == UTF_8 {
Ok(str::from_utf8(body.as_ref())
.map_err(|_| ErrorBadRequest("Can not decode body"))?
.to_owned())
} else {
Ok(encoding
.decode_without_bom_handling_and_without_replacement(&body)
.map(|s| s.into_owned())
.ok_or_else(|| ErrorBadRequest("Can not decode body"))?)
}
}
@ -291,10 +302,12 @@ impl PayloadConfig {
// Allow shared refs to default.
const DEFAULT_CONFIG: PayloadConfig = PayloadConfig {
limit: 262_144, // 2^18 bytes (~256kB)
limit: DEFAULT_CONFIG_LIMIT,
mimetype: None,
};
const DEFAULT_CONFIG_LIMIT: usize = 262_144; // 2^18 bytes (~256kB)
impl Default for PayloadConfig {
fn default() -> Self {
DEFAULT_CONFIG.clone()
@ -312,99 +325,83 @@ pub struct HttpMessageBody {
limit: usize,
length: Option<usize>,
#[cfg(feature = "compress")]
stream: Option<dev::Decompress<dev::Payload>>,
stream: dev::Decompress<dev::Payload>,
#[cfg(not(feature = "compress"))]
stream: Option<dev::Payload>,
stream: dev::Payload,
buf: BytesMut,
err: Option<PayloadError>,
fut: Option<LocalBoxFuture<'static, Result<Bytes, PayloadError>>>,
}
impl HttpMessageBody {
/// Create `MessageBody` for request.
#[allow(clippy::borrow_interior_mutable_const)]
pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> HttpMessageBody {
let mut len = None;
let mut length = None;
let mut err = None;
if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) {
if let Ok(s) = l.to_str() {
if let Ok(l) = s.parse::<usize>() {
len = Some(l)
} else {
return Self::err(PayloadError::UnknownLength);
}
} else {
return Self::err(PayloadError::UnknownLength);
match l.to_str() {
Ok(s) => match s.parse::<usize>() {
Ok(l) if l > DEFAULT_CONFIG_LIMIT => {
err = Some(PayloadError::Overflow)
}
Ok(l) => length = Some(l),
Err(_) => err = Some(PayloadError::UnknownLength),
},
Err(_) => err = Some(PayloadError::UnknownLength),
}
}
#[cfg(feature = "compress")]
let stream = Some(dev::Decompress::from_headers(payload.take(), req.headers()));
let stream = dev::Decompress::from_headers(payload.take(), req.headers());
#[cfg(not(feature = "compress"))]
let stream = Some(payload.take());
let stream = payload.take();
HttpMessageBody {
stream,
limit: 262_144,
length: len,
fut: None,
err: None,
limit: DEFAULT_CONFIG_LIMIT,
length,
buf: BytesMut::with_capacity(8192),
err,
}
}
/// Change max size of payload. By default max size is 256Kb
pub fn limit(mut self, limit: usize) -> Self {
if let Some(l) = self.length {
if l > limit {
self.err = Some(PayloadError::Overflow);
}
}
self.limit = limit;
self
}
fn err(e: PayloadError) -> Self {
HttpMessageBody {
stream: None,
limit: 262_144,
fut: None,
err: Some(e),
length: None,
}
}
}
impl Future for HttpMessageBody {
type Output = Result<Bytes, PayloadError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(ref mut fut) = self.fut {
return Pin::new(fut).poll(cx);
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
if let Some(e) = this.err.take() {
return Poll::Ready(Err(e));
}
if let Some(err) = self.err.take() {
return Poll::Ready(Err(err));
}
if let Some(len) = self.length.take() {
if len > self.limit {
return Poll::Ready(Err(PayloadError::Overflow));
}
}
// future
let limit = self.limit;
let mut stream = self.stream.take().unwrap();
self.fut = Some(
async move {
let mut body = BytesMut::with_capacity(8192);
while let Some(item) = stream.next().await {
let chunk = item?;
if body.len() + chunk.len() > limit {
return Err(PayloadError::Overflow);
loop {
let res = ready!(Pin::new(&mut this.stream).poll_next(cx));
match res {
Some(chunk) => {
let chunk = chunk?;
if this.buf.len() + chunk.len() > this.limit {
return Poll::Ready(Err(PayloadError::Overflow));
} else {
body.extend_from_slice(&chunk);
this.buf.extend_from_slice(&chunk);
}
}
Ok(body.freeze())
None => return Poll::Ready(Ok(this.buf.split().freeze())),
}
.boxed_local(),
);
self.poll(cx)
}
}
}