From 015e1c7b4de2c4e393d94f7b9c06e4b09a5029d3 Mon Sep 17 00:00:00 2001 From: ibraheemdev Date: Thu, 22 Apr 2021 12:08:44 -0400 Subject: [PATCH] static condition middleware future --- src/middleware/condition.rs | 40 ++++++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/src/middleware/condition.rs b/src/middleware/condition.rs index d1ba7ee4d..fd0875014 100644 --- a/src/middleware/condition.rs +++ b/src/middleware/condition.rs @@ -1,11 +1,13 @@ //! For middleware documentation, see [`Condition`]. +use std::future::Future; +use std::pin::Pin; use std::task::{Context, Poll}; use actix_service::{Service, Transform}; use actix_utils::future::Either; -use futures_core::future::LocalBoxFuture; -use futures_util::future::FutureExt as _; + +use futures_core::ready; /// Middleware for conditionally enabling other middleware. /// @@ -48,22 +50,42 @@ where type Error = S::Error; type Transform = ConditionMiddleware; type InitError = T::InitError; - type Future = LocalBoxFuture<'static, Result>; + type Future = ConditionFut<::Future, S>; fn new_transform(&self, service: S) -> Self::Future { if self.enable { let fut = self.transformer.new_transform(service); - async move { - let wrapped_svc = fut.await?; - Ok(ConditionMiddleware::Enable(wrapped_svc)) - } - .boxed_local() + ConditionFut::Enable(fut) } else { - async move { Ok(ConditionMiddleware::Disable(service)) }.boxed_local() + ConditionFut::Disable(Some(service)) } } } +#[pin_project::pin_project(project = ConditionFutProj)] +pub enum ConditionFut { + Enable(#[pin] F), + Disable(Option), +} + +impl Future for ConditionFut +where + F: Future>, +{ + type Output = Result, Ie>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let middleware = match self.project() { + ConditionFutProj::Enable(fut) => ConditionMiddleware::Enable(ready!(fut.poll(cx))?), + ConditionFutProj::Disable(service) => { + ConditionMiddleware::Disable(service.take().unwrap()) + } + }; + + Poll::Ready(Ok(middleware)) + } +} + pub enum ConditionMiddleware { Enable(E), Disable(D),