diff --git a/actix-web/src/middleware/condition.rs b/actix-web/src/middleware/condition.rs index 659f88bc9..acb4c96f2 100644 --- a/actix-web/src/middleware/condition.rs +++ b/actix-web/src/middleware/condition.rs @@ -1,18 +1,20 @@ //! For middleware documentation, see [`Condition`]. -use std::task::{Context, Poll}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; use actix_service::{Service, Transform}; -use actix_utils::future::Either; -use futures_core::future::LocalBoxFuture; +use futures_core::{future::LocalBoxFuture, ready}; use futures_util::future::FutureExt as _; +use pin_project_lite::pin_project; + +use crate::{body::EitherBody, dev::ServiceResponse}; /// Middleware for conditionally enabling other middleware. /// -/// The controlled middleware must not change the `Service` interfaces. This means you cannot -/// control such middlewares like `Logger` or `Compress` directly. See the [`Compat`](super::Compat) -/// middleware for a workaround. -/// /// # Examples /// ``` /// use actix_web::middleware::{Condition, NormalizePath}; @@ -36,16 +38,16 @@ impl Condition { } } -impl Transform for Condition +impl Transform for Condition where - S: Service + 'static, - T: Transform, + S: Service, Error = Err> + 'static, + T: Transform, Error = Err>, T::Future: 'static, T::InitError: 'static, T::Transform: 'static, { - type Response = S::Response; - type Error = S::Error; + type Response = ServiceResponse>; + type Error = Err; type Transform = ConditionMiddleware; type InitError = T::InitError; type Future = LocalBoxFuture<'static, Result>; @@ -69,14 +71,14 @@ pub enum ConditionMiddleware { Disable(D), } -impl Service for ConditionMiddleware +impl Service for ConditionMiddleware where - E: Service, - D: Service, + E: Service, Error = Err>, + D: Service, Error = Err>, { - type Response = E::Response; - type Error = E::Error; - type Future = Either; + type Response = ServiceResponse>; + type Error = Err; + type Future = ConditionMiddlewareFuture; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { match self { @@ -87,16 +89,45 @@ where fn call(&self, req: Req) -> Self::Future { match self { - ConditionMiddleware::Enable(service) => Either::left(service.call(req)), - ConditionMiddleware::Disable(service) => Either::right(service.call(req)), + ConditionMiddleware::Enable(service) => ConditionMiddlewareFuture::Left { + fut: service.call(req), + }, + ConditionMiddleware::Disable(service) => ConditionMiddlewareFuture::Right { + fut: service.call(req), + }, } } } +pin_project! { + #[project = EitherProj] + pub enum ConditionMiddlewareFuture { + Left { #[pin] fut: L, }, + Right { #[pin] fut: R, }, + } +} + +impl Future for ConditionMiddlewareFuture +where + L: Future, Err>>, + R: Future, Err>>, +{ + type Output = Result>, Err>; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let res = match self.project() { + EitherProj::Left { fut } => ready!(fut.poll(cx))?.map_into_left_body(), + EitherProj::Right { fut } => ready!(fut.poll(cx))?.map_into_right_body(), + }; + + Poll::Ready(Ok(res)) + } +} + #[cfg(test)] mod tests { use actix_service::IntoService; - use actix_utils::future::ok; use super::*; use crate::{ @@ -106,11 +137,13 @@ mod tests { header::{HeaderValue, CONTENT_TYPE}, StatusCode, }, - middleware::{err_handlers::*, Compat}, + middleware::err_handlers::*, test::{self, TestRequest}, HttpResponse, }; + fn assert_type(_: &T) {} + #[allow(clippy::unnecessary_wraps)] fn render_500(mut res: ServiceResponse) -> Result> { res.response_mut() @@ -122,31 +155,31 @@ mod tests { #[actix_rt::test] async fn test_handler_enabled() { - let srv = |req: ServiceRequest| { - ok(req.into_response(HttpResponse::InternalServerError().finish())) + let srv = |req: ServiceRequest| async move { + let resp = HttpResponse::InternalServerError().message_body(String::new())?; + Ok(req.into_response(resp)) }; - let mw = Compat::new( - ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500), - ); + let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mw = Condition::new(true, mw) .new_transform(srv.into_service()) .await .unwrap(); + let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; + assert_type::, String>>>(&resp); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } #[actix_rt::test] async fn test_handler_disabled() { - let srv = |req: ServiceRequest| { - ok(req.into_response(HttpResponse::InternalServerError().finish())) + let srv = |req: ServiceRequest| async move { + let resp = HttpResponse::InternalServerError().message_body(String::new())?; + Ok(req.into_response(resp)) }; - let mw = Compat::new( - ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500), - ); + let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mw = Condition::new(false, mw) .new_transform(srv.into_service()) @@ -154,6 +187,7 @@ mod tests { .unwrap(); let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; + assert_type::, String>>>(&resp); assert_eq!(resp.headers().get(CONTENT_TYPE), None); } }