From d7ce6484457be229c4c9b4fcd3581d599b7ff9f1 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Thu, 17 Dec 2020 02:34:10 +0800 Subject: [PATCH 1/8] remove boxed future for Option and Result extract type (#1829) * remove boxed future for Option and Result extract type * use ready macro * fix fmt --- src/extract.rs | 81 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 22 deletions(-) diff --git a/src/extract.rs b/src/extract.rs index df9c34cb3..5916b1bc5 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -4,7 +4,8 @@ use std::pin::Pin; use std::task::{Context, Poll}; use actix_http::error::Error; -use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready}; +use futures_util::future::{ready, Ready}; +use futures_util::ready; use crate::dev::Payload; use crate::request::HttpRequest; @@ -95,21 +96,41 @@ where T: FromRequest, T::Future: 'static, { - type Config = T::Config; type Error = Error; - type Future = LocalBoxFuture<'static, Result, Error>>; + type Future = FromRequestOptFuture; + type Config = T::Config; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - T::from_request(req, payload) - .then(|r| match r { - Ok(v) => ok(Some(v)), - Err(e) => { - log::debug!("Error for Option extractor: {}", e.into()); - ok(None) - } - }) - .boxed_local() + FromRequestOptFuture { + fut: T::from_request(req, payload), + } + } +} + +#[pin_project::pin_project] +pub struct FromRequestOptFuture { + #[pin] + fut: Fut, +} + +impl Future for FromRequestOptFuture +where + Fut: Future>, + E: Into, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let res = ready!(this.fut.poll(cx)); + match res { + Ok(t) => Poll::Ready(Ok(Some(t))), + Err(e) => { + log::debug!("Error for Option extractor: {}", e.into()); + Poll::Ready(Ok(None)) + } + } } } @@ -165,29 +186,45 @@ where T::Error: 'static, T::Future: 'static, { - type Config = T::Config; type Error = Error; - type Future = LocalBoxFuture<'static, Result, Error>>; + type Future = FromRequestResFuture; + type Config = T::Config; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - T::from_request(req, payload) - .then(|res| match res { - Ok(v) => ok(Ok(v)), - Err(e) => ok(Err(e)), - }) - .boxed_local() + FromRequestResFuture { + fut: T::from_request(req, payload), + } + } +} + +#[pin_project::pin_project] +pub struct FromRequestResFuture { + #[pin] + fut: Fut, +} + +impl Future for FromRequestResFuture +where + Fut: Future>, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let res = ready!(this.fut.poll(cx)); + Poll::Ready(Ok(res)) } } #[doc(hidden)] impl FromRequest for () { - type Config = (); type Error = Error; type Future = Ready>; + type Config = (); fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future { - ok(()) + ready(Ok(())) } } From 1a361273e7bf15fb43b8f7a10334a0edae4fde0a Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Wed, 16 Dec 2020 22:40:26 +0000 Subject: [PATCH 2/8] optimize bytes and string payload extractors (#1831) --- src/types/payload.rs | 72 ++++++++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 29 deletions(-) diff --git a/src/types/payload.rs b/src/types/payload.rs index acb8b9a82..fd4d3e945 100644 --- a/src/types/payload.rs +++ b/src/types/payload.rs @@ -7,10 +7,15 @@ use std::task::{Context, Poll}; use actix_http::error::{Error, ErrorBadRequest, PayloadError}; use actix_http::HttpMessage; use bytes::{Bytes, BytesMut}; -use encoding_rs::UTF_8; +use encoding_rs::{Encoding, UTF_8}; use futures_core::stream::Stream; -use futures_util::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready}; -use futures_util::StreamExt; +use futures_util::{ + future::{ + err, ok, Either, ErrInto, FutureExt as _, LocalBoxFuture, Ready, + TryFutureExt as _, + }, + stream::StreamExt as _, +}; use mime::Mime; use crate::extract::FromRequest; @@ -135,10 +140,7 @@ impl FromRequest for Payload { impl FromRequest for Bytes { type Config = PayloadConfig; type Error = Error; - type Future = Either< - LocalBoxFuture<'static, Result>, - Ready>, - >; + type Future = Either, Ready>>; #[inline] fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { @@ -151,7 +153,7 @@ impl FromRequest for Bytes { let limit = cfg.limit; let fut = HttpMessageBody::new(req, payload).limit(limit); - Either::Left(async move { Ok(fut.await?) }.boxed_local()) + Either::Left(fut.err_into()) } } @@ -185,10 +187,7 @@ impl FromRequest for Bytes { impl FromRequest for String { type Config = PayloadConfig; type Error = Error; - type Future = Either< - LocalBoxFuture<'static, Result>, - Ready>, - >; + type Future = Either>>; #[inline] fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { @@ -205,25 +204,40 @@ impl FromRequest for String { Err(e) => return Either::Right(err(e.into())), }; let limit = cfg.limit; - let fut = HttpMessageBody::new(req, payload).limit(limit); + let body_fut = HttpMessageBody::new(req, payload).limit(limit); - Either::Left( - async move { - let body = fut.await?; + Either::Left(StringExtractFut { body_fut, encoding }) + } +} - if encoding == UTF_8 { - Ok(str::from_utf8(body.as_ref()) - .map_err(|_| ErrorBadRequest("Can not decode body"))? - .to_owned()) - } else { - Ok(encoding - .decode_without_bom_handling_and_without_replacement(&body) - .map(|s| s.into_owned()) - .ok_or_else(|| ErrorBadRequest("Can not decode body"))?) - } - } - .boxed_local(), - ) +pub struct StringExtractFut { + body_fut: HttpMessageBody, + encoding: &'static Encoding, +} + +impl<'a> Future for StringExtractFut { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let encoding = self.encoding; + + Pin::new(&mut self.body_fut).poll(cx).map(|out| { + let body = out?; + bytes_to_string(body, encoding) + }) + } +} + +fn bytes_to_string(body: Bytes, encoding: &'static Encoding) -> Result { + if encoding == UTF_8 { + Ok(str::from_utf8(body.as_ref()) + .map_err(|_| ErrorBadRequest("Can not decode body"))? + .to_owned()) + } else { + Ok(encoding + .decode_without_bom_handling_and_without_replacement(&body) + .map(|s| s.into_owned()) + .ok_or_else(|| ErrorBadRequest("Can not decode body"))?) } } From 97f615c245a9ad86a129a94e50a600f2006da306 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Thu, 17 Dec 2020 07:34:33 +0800 Subject: [PATCH 3/8] remove boxed futures on Json extract type (#1832) --- src/types/json.rs | 211 ++++++++++++++++++++++++++++------------------ 1 file changed, 128 insertions(+), 83 deletions(-) diff --git a/src/types/json.rs b/src/types/json.rs index 83c9f21b0..95613a0ce 100644 --- a/src/types/json.rs +++ b/src/types/json.rs @@ -1,14 +1,16 @@ //! Json extractor/responder use std::future::Future; +use std::marker::PhantomData; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::{fmt, ops}; use bytes::BytesMut; -use futures_util::future::{err, ok, FutureExt, LocalBoxFuture, Ready}; -use futures_util::StreamExt; +use futures_util::future::{ready, Ready}; +use futures_util::ready; +use futures_util::stream::Stream; use serde::de::DeserializeOwned; use serde::Serialize; @@ -127,12 +129,12 @@ impl Responder for Json { fn respond_to(self, _: &HttpRequest) -> Self::Future { let body = match serde_json::to_string(&self.0) { Ok(body) => body, - Err(e) => return err(e.into()), + Err(e) => return ready(Err(e.into())), }; - ok(Response::build(StatusCode::OK) + ready(Ok(Response::build(StatusCode::OK) .content_type("application/json") - .body(body)) + .body(body))) } } @@ -173,37 +175,64 @@ where T: DeserializeOwned + 'static, { type Error = Error; - type Future = LocalBoxFuture<'static, Result>; + type Future = JsonExtractFut; type Config = JsonConfig; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - let req2 = req.clone(); let config = JsonConfig::from_req(req); let limit = config.limit; let ctype = config.content_type.clone(); let err_handler = config.err_handler.clone(); - JsonBody::new(req, payload, ctype) - .limit(limit) - .map(move |res| match res { - Err(e) => { - log::debug!( - "Failed to deserialize Json from payload. \ - Request path: {}", - req2.path() - ); + JsonExtractFut { + req: Some(req.clone()), + fut: JsonBody::new(req, payload, ctype).limit(limit), + err_handler, + } + } +} - if let Some(err) = err_handler { - Err((*err)(e, &req2)) - } else { - Err(e.into()) - } +type JsonErrorHandler = + Option Error + Send + Sync>>; + +pub struct JsonExtractFut { + req: Option, + fut: JsonBody, + err_handler: JsonErrorHandler, +} + +impl Future for JsonExtractFut +where + T: DeserializeOwned + 'static, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + let res = ready!(Pin::new(&mut this.fut).poll(cx)); + + let res = match res { + Err(e) => { + let req = this.req.take().unwrap(); + log::debug!( + "Failed to deserialize Json from payload. \ + Request path: {}", + req.path() + ); + + if let Some(err) = this.err_handler.as_ref() { + Err((*err)(e, &req)) + } else { + Err(e.into()) } - Ok(data) => Ok(Json(data)), - }) - .boxed_local() + } + Ok(data) => Ok(Json(data)), + }; + + Poll::Ready(res) } } @@ -248,8 +277,7 @@ where #[derive(Clone)] pub struct JsonConfig { limit: usize, - err_handler: - Option Error + Send + Sync>>, + err_handler: JsonErrorHandler, content_type: Option bool + Send + Sync>>, } @@ -308,17 +336,22 @@ impl Default for JsonConfig { /// * content type is not `application/json` /// (unless specified in [`JsonConfig`]) /// * content length is greater than 256k -pub struct JsonBody { - limit: usize, - length: Option, - #[cfg(feature = "compress")] - stream: Option>, - #[cfg(not(feature = "compress"))] - stream: Option, - err: Option, - fut: Option>>, +pub enum JsonBody { + Error(Option), + Body { + limit: usize, + length: Option, + #[cfg(feature = "compress")] + payload: Decompress, + #[cfg(not(feature = "compress"))] + payload: Payload, + buf: BytesMut, + _res: PhantomData, + }, } +impl Unpin for JsonBody {} + impl JsonBody where U: DeserializeOwned + 'static, @@ -340,39 +373,58 @@ where }; if !json { - return JsonBody { - limit: 262_144, - length: None, - stream: None, - fut: None, - err: Some(JsonPayloadError::ContentType), - }; + return JsonBody::Error(Some(JsonPayloadError::ContentType)); } - let len = req + let length = req .headers() .get(&CONTENT_LENGTH) .and_then(|l| l.to_str().ok()) .and_then(|s| s.parse::().ok()); + // Notice the content_length is not checked against limit of json config here. + // As the internal usage always call JsonBody::limit after JsonBody::new. + // And limit check to return an error variant of JsonBody happens there. + #[cfg(feature = "compress")] let payload = Decompress::from_headers(payload.take(), req.headers()); #[cfg(not(feature = "compress"))] let payload = payload.take(); - JsonBody { + JsonBody::Body { limit: 262_144, - length: len, - stream: Some(payload), - fut: None, - err: None, + length, + payload, + buf: BytesMut::with_capacity(8192), + _res: PhantomData, } } /// Change max size of payload. By default max size is 256Kb - pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; - self + pub fn limit(self, limit: usize) -> Self { + match self { + JsonBody::Body { + length, + payload, + buf, + .. + } => { + if let Some(len) = length { + if len > limit { + return JsonBody::Error(Some(JsonPayloadError::Overflow)); + } + } + + JsonBody::Body { + limit, + length, + payload, + buf, + _res: PhantomData, + } + } + JsonBody::Error(e) => JsonBody::Error(e), + } } } @@ -382,41 +434,34 @@ where { type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if let Some(ref mut fut) = self.fut { - return Pin::new(fut).poll(cx); - } + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); - if let Some(err) = self.err.take() { - return Poll::Ready(Err(err)); - } - - let limit = self.limit; - if let Some(len) = self.length.take() { - if len > limit { - return Poll::Ready(Err(JsonPayloadError::Overflow)); - } - } - let mut stream = self.stream.take().unwrap(); - - self.fut = Some( - async move { - let mut body = BytesMut::with_capacity(8192); - - while let Some(item) = stream.next().await { - let chunk = item?; - if (body.len() + chunk.len()) > limit { - return Err(JsonPayloadError::Overflow); - } else { - body.extend_from_slice(&chunk); + match this { + JsonBody::Body { + limit, + buf, + payload, + .. + } => loop { + let res = ready!(Pin::new(&mut *payload).poll_next(cx)); + match res { + Some(chunk) => { + let chunk = chunk?; + if (buf.len() + chunk.len()) > *limit { + return Poll::Ready(Err(JsonPayloadError::Overflow)); + } else { + buf.extend_from_slice(&chunk); + } + } + None => { + let json = serde_json::from_slice::(&buf)?; + return Poll::Ready(Ok(json)); } } - Ok(serde_json::from_slice::(&body)?) - } - .boxed_local(), - ); - - self.poll(cx) + }, + JsonBody::Error(e) => Poll::Ready(Err(e.take().unwrap())), + } } } From 2a5215c1d6cce12ff3f4bdc4e7ac73190d4aa9e0 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Thu, 17 Dec 2020 19:40:49 +0800 Subject: [PATCH 4/8] Remove boxed future from HttpMessage (#1834) --- src/types/payload.rs | 109 ++++++++++++++++++------------------------- 1 file changed, 46 insertions(+), 63 deletions(-) diff --git a/src/types/payload.rs b/src/types/payload.rs index fd4d3e945..9228b37aa 100644 --- a/src/types/payload.rs +++ b/src/types/payload.rs @@ -10,11 +10,8 @@ use bytes::{Bytes, BytesMut}; use encoding_rs::{Encoding, UTF_8}; use futures_core::stream::Stream; use futures_util::{ - future::{ - err, ok, Either, ErrInto, FutureExt as _, LocalBoxFuture, Ready, - TryFutureExt as _, - }, - stream::StreamExt as _, + future::{err, ok, Either, ErrInto, Ready, TryFutureExt as _}, + ready, }; use mime::Mime; @@ -305,10 +302,12 @@ impl PayloadConfig { // Allow shared refs to default. const DEFAULT_CONFIG: PayloadConfig = PayloadConfig { - limit: 262_144, // 2^18 bytes (~256kB) + limit: DEFAULT_CONFIG_LIMIT, mimetype: None, }; +const DEFAULT_CONFIG_LIMIT: usize = 262_144; // 2^18 bytes (~256kB) + impl Default for PayloadConfig { fn default() -> Self { DEFAULT_CONFIG.clone() @@ -326,99 +325,83 @@ pub struct HttpMessageBody { limit: usize, length: Option, #[cfg(feature = "compress")] - stream: Option>, + stream: dev::Decompress, #[cfg(not(feature = "compress"))] - stream: Option, + stream: dev::Payload, + buf: BytesMut, err: Option, - fut: Option>>, } impl HttpMessageBody { /// Create `MessageBody` for request. #[allow(clippy::borrow_interior_mutable_const)] pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> HttpMessageBody { - let mut len = None; + let mut length = None; + let mut err = None; + if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) { - if let Ok(s) = l.to_str() { - if let Ok(l) = s.parse::() { - len = Some(l) - } else { - return Self::err(PayloadError::UnknownLength); - } - } else { - return Self::err(PayloadError::UnknownLength); + match l.to_str() { + Ok(s) => match s.parse::() { + Ok(l) if l > DEFAULT_CONFIG_LIMIT => { + err = Some(PayloadError::Overflow) + } + Ok(l) => length = Some(l), + Err(_) => err = Some(PayloadError::UnknownLength), + }, + Err(_) => err = Some(PayloadError::UnknownLength), } } #[cfg(feature = "compress")] - let stream = Some(dev::Decompress::from_headers(payload.take(), req.headers())); + let stream = dev::Decompress::from_headers(payload.take(), req.headers()); #[cfg(not(feature = "compress"))] - let stream = Some(payload.take()); + let stream = payload.take(); HttpMessageBody { stream, - limit: 262_144, - length: len, - fut: None, - err: None, + limit: DEFAULT_CONFIG_LIMIT, + length, + buf: BytesMut::with_capacity(8192), + err, } } /// Change max size of payload. By default max size is 256Kb pub fn limit(mut self, limit: usize) -> Self { + if let Some(l) = self.length { + if l > limit { + self.err = Some(PayloadError::Overflow); + } + } self.limit = limit; self } - - fn err(e: PayloadError) -> Self { - HttpMessageBody { - stream: None, - limit: 262_144, - fut: None, - err: Some(e), - length: None, - } - } } impl Future for HttpMessageBody { type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if let Some(ref mut fut) = self.fut { - return Pin::new(fut).poll(cx); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + if let Some(e) = this.err.take() { + return Poll::Ready(Err(e)); } - if let Some(err) = self.err.take() { - return Poll::Ready(Err(err)); - } - - if let Some(len) = self.length.take() { - if len > self.limit { - return Poll::Ready(Err(PayloadError::Overflow)); - } - } - - // future - let limit = self.limit; - let mut stream = self.stream.take().unwrap(); - self.fut = Some( - async move { - let mut body = BytesMut::with_capacity(8192); - - while let Some(item) = stream.next().await { - let chunk = item?; - if body.len() + chunk.len() > limit { - return Err(PayloadError::Overflow); + loop { + let res = ready!(Pin::new(&mut this.stream).poll_next(cx)); + match res { + Some(chunk) => { + let chunk = chunk?; + if this.buf.len() + chunk.len() > this.limit { + return Poll::Ready(Err(PayloadError::Overflow)); } else { - body.extend_from_slice(&chunk); + this.buf.extend_from_slice(&chunk); } } - Ok(body.freeze()) + None => return Poll::Ready(Ok(this.buf.split().freeze())), } - .boxed_local(), - ); - self.poll(cx) + } } } From c7b4c6edfa0a3e7e6c6618584ec189ed5b3d99bd Mon Sep 17 00:00:00 2001 From: Yuki Okushi Date: Thu, 17 Dec 2020 21:38:52 +0900 Subject: [PATCH 5/8] Disable PR comment from codecov --- codecov.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/codecov.yml b/codecov.yml index 102e8969d..e6bc40203 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,3 +1,5 @@ +comment: false + coverage: status: project: From a4dbaa8ed11651b33125d39d8cee0f90a3bfed61 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Sat, 19 Dec 2020 07:08:59 +0800 Subject: [PATCH 6/8] remove boxed future in DefaultHeaders middleware (#1838) --- src/middleware/defaultheaders.rs | 72 +++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/src/middleware/defaultheaders.rs b/src/middleware/defaultheaders.rs index 6d43aba95..a6f1a4336 100644 --- a/src/middleware/defaultheaders.rs +++ b/src/middleware/defaultheaders.rs @@ -1,10 +1,14 @@ //! Middleware for setting default response headers use std::convert::TryFrom; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; use std::rc::Rc; use std::task::{Context, Poll}; use actix_service::{Service, Transform}; -use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready}; +use futures_util::future::{ready, Ready}; +use futures_util::ready; use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; use crate::http::{Error as HttpError, HeaderMap}; @@ -97,15 +101,15 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type InitError = (); type Transform = DefaultHeadersMiddleware; + type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - ok(DefaultHeadersMiddleware { + ready(Ok(DefaultHeadersMiddleware { service, inner: self.inner.clone(), - }) + })) } } @@ -122,36 +126,56 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = LocalBoxFuture<'static, Result>; + type Future = DefaultHeaderFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } - #[allow(clippy::borrow_interior_mutable_const)] fn call(&mut self, req: ServiceRequest) -> Self::Future { let inner = self.inner.clone(); let fut = self.service.call(req); - async move { - let mut res = fut.await?; - - // set response headers - for (key, value) in inner.headers.iter() { - if !res.headers().contains_key(key) { - res.headers_mut().insert(key.clone(), value.clone()); - } - } - // default content-type - if inner.ct && !res.headers().contains_key(&CONTENT_TYPE) { - res.headers_mut().insert( - CONTENT_TYPE, - HeaderValue::from_static("application/octet-stream"), - ); - } - Ok(res) + DefaultHeaderFuture { + fut, + inner, + _body: PhantomData, } - .boxed_local() + } +} + +#[pin_project::pin_project] +pub struct DefaultHeaderFuture { + #[pin] + fut: S::Future, + inner: Rc, + _body: PhantomData, +} + +impl Future for DefaultHeaderFuture +where + S: Service, Error = Error>, +{ + type Output = ::Output; + + #[allow(clippy::borrow_interior_mutable_const)] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let mut res = ready!(this.fut.poll(cx))?; + // set response headers + for (key, value) in this.inner.headers.iter() { + if !res.headers().contains_key(key) { + res.headers_mut().insert(key.clone(), value.clone()); + } + } + // default content-type + if this.inner.ct && !res.headers().contains_key(&CONTENT_TYPE) { + res.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); + } + Poll::Ready(Ok(res)) } } From 79de04d8625deff301fe76fdf8649543f2976ae5 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Sun, 20 Dec 2020 00:33:34 +0800 Subject: [PATCH 7/8] optimise Extract service (#1841) --- src/handler.rs | 110 +++++++++++++++++++++---------------------------- 1 file changed, 47 insertions(+), 63 deletions(-) diff --git a/src/handler.rs b/src/handler.rs index 669512ab3..db6c5ce0a 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -90,26 +90,20 @@ where } fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future { - HandlerServiceResponse { - fut: self.hnd.call(param), - fut2: None, - req: Some(req), - } + let fut = self.hnd.call(param); + HandlerServiceResponse::Future(fut, Some(req)) } } #[doc(hidden)] -#[pin_project] -pub struct HandlerServiceResponse +#[pin_project(project = HandlerProj)] +pub enum HandlerServiceResponse where T: Future, R: Responder, { - #[pin] - fut: T, - #[pin] - fut2: Option, - req: Option, + Future(#[pin] T, Option), + Responder(#[pin] R::Future, Option), } impl Future for HandlerServiceResponse @@ -120,28 +114,26 @@ where type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().project(); - - if let Some(fut) = this.fut2.as_pin_mut() { - return match fut.poll(cx) { - Poll::Ready(Ok(res)) => { - Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) + loop { + match self.as_mut().project() { + HandlerProj::Future(fut, req) => { + let res = ready!(fut.poll(cx)); + let fut = res.respond_to(req.as_ref().unwrap()); + let state = HandlerServiceResponse::Responder(fut, req.take()); + self.as_mut().set(state); } - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - let res: Response = e.into().into(); - Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) + HandlerProj::Responder(fut, req) => { + let res = ready!(fut.poll(cx)); + let req = req.take().unwrap(); + return match res { + Ok(res) => Poll::Ready(Ok(ServiceResponse::new(req, res))), + Err(e) => { + let res: Response = e.into().into(); + Poll::Ready(Ok(ServiceResponse::new(req, res))) + } + }; } - }; - } - - match this.fut.poll(cx) { - Poll::Ready(res) => { - let fut = res.respond_to(this.req.as_ref().unwrap()); - self.as_mut().project().fut2.set(Some(fut)); - self.poll(cx) } - Poll::Pending => Poll::Pending, } } } @@ -169,12 +161,12 @@ where Error = Infallible, > + Clone, { - type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; type Error = (Error, ServiceRequest); - type InitError = (); + type Config = (); type Service = ExtractService; + type InitError = (); type Future = Ready>; fn new_service(&self, _: ()) -> Self::Future { @@ -210,24 +202,14 @@ where fn call(&mut self, req: ServiceRequest) -> Self::Future { let (req, mut payload) = req.into_parts(); let fut = T::from_request(&req, &mut payload); - - ExtractResponse { - fut, - req, - fut_s: None, - service: self.service.clone(), - } + ExtractResponse::Future(fut, Some(req), self.service.clone()) } } -#[pin_project] -pub struct ExtractResponse { - req: HttpRequest, - service: S, - #[pin] - fut: T::Future, - #[pin] - fut_s: Option, +#[pin_project(project = ExtractProj)] +pub enum ExtractResponse { + Future(#[pin] T::Future, Option, S), + Response(#[pin] S::Future), } impl Future for ExtractResponse @@ -241,21 +223,23 @@ where type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().project(); - - if let Some(fut) = this.fut_s.as_pin_mut() { - return fut.poll(cx).map_err(|_| panic!()); - } - - match ready!(this.fut.poll(cx)) { - Err(e) => { - let req = ServiceRequest::new(this.req.clone()); - Poll::Ready(Err((e.into(), req))) - } - Ok(item) => { - let fut = Some(this.service.call((item, this.req.clone()))); - self.as_mut().project().fut_s.set(fut); - self.poll(cx) + loop { + match self.as_mut().project() { + ExtractProj::Future(fut, req, srv) => { + let res = ready!(fut.poll(cx)); + let req = req.take().unwrap(); + match res { + Err(e) => { + let req = ServiceRequest::new(req); + return Poll::Ready(Err((e.into(), req))); + } + Ok(item) => { + let fut = srv.call((item, req)); + self.as_mut().set(ExtractResponse::Response(fut)); + } + } + } + ExtractProj::Response(fut) => return fut.poll(cx).map_err(|_| panic!()), } } } From 6cbf27508af7e8c38d2fe174fdb26062b35340ed Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Sun, 20 Dec 2020 10:20:29 +0800 Subject: [PATCH 8/8] simplify ExtractService's return type (#1842) --- src/handler.rs | 8 ++++---- src/route.rs | 32 ++++++-------------------------- 2 files changed, 10 insertions(+), 30 deletions(-) diff --git a/src/handler.rs b/src/handler.rs index db6c5ce0a..0dc06b3ce 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -163,7 +163,7 @@ where { type Request = ServiceRequest; type Response = ServiceResponse; - type Error = (Error, ServiceRequest); + type Error = Error; type Config = (); type Service = ExtractService; type InitError = (); @@ -192,7 +192,7 @@ where { type Request = ServiceRequest; type Response = ServiceResponse; - type Error = (Error, ServiceRequest); + type Error = Error; type Future = ExtractResponse; fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { @@ -220,7 +220,7 @@ where Error = Infallible, >, { - type Output = Result; + type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { @@ -231,7 +231,7 @@ where match res { Err(e) => { let req = ServiceRequest::new(req); - return Poll::Ready(Err((e.into(), req))); + return Poll::Ready(Ok(req.error_response(e.into()))); } Ok(item) => { let fut = srv.call((item, req)); diff --git a/src/route.rs b/src/route.rs index 45efd9e3c..f8ef458f9 100644 --- a/src/route.rs +++ b/src/route.rs @@ -234,7 +234,7 @@ impl Route { struct RouteNewService where - T: ServiceFactory, + T: ServiceFactory, { service: T, } @@ -245,7 +245,7 @@ where Config = (), Request = ServiceRequest, Response = ServiceResponse, - Error = (Error, ServiceRequest), + Error = Error, >, T::Future: 'static, T::Service: 'static, @@ -262,7 +262,7 @@ where Config = (), Request = ServiceRequest, Response = ServiceResponse, - Error = (Error, ServiceRequest), + Error = Error, >, T::Future: 'static, T::Service: 'static, @@ -297,11 +297,7 @@ struct RouteServiceWrapper { impl Service for RouteServiceWrapper where T::Future: 'static, - T: Service< - Request = ServiceRequest, - Response = ServiceResponse, - Error = (Error, ServiceRequest), - >, + T: Service, { type Request = ServiceRequest; type Response = ServiceResponse; @@ -309,27 +305,11 @@ where type Future = LocalBoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx).map_err(|(e, _)| e) + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { - // let mut fut = self.service.call(req); - self.service - .call(req) - .map(|res| match res { - Ok(res) => Ok(res), - Err((err, req)) => Ok(req.error_response(err)), - }) - .boxed_local() - - // match fut.poll() { - // Poll::Ready(Ok(res)) => Either::Left(ok(res)), - // Poll::Ready(Err((e, req))) => Either::Left(ok(req.error_response(e))), - // Poll::Pending => Either::Right(Box::new(fut.then(|res| match res { - // Ok(res) => Ok(res), - // Err((err, req)) => Ok(req.error_response(err)), - // }))), - // } + Box::pin(self.service.call(req)) } }