From 6620fa5c4e6d72daf369564f364fde5b68abe373 Mon Sep 17 00:00:00 2001 From: Akos Vandra Date: Wed, 29 Sep 2021 17:06:40 +0200 Subject: [PATCH] add ability for condition middleware to consume an option and future --- CHANGES.md | 2 + src/middleware/compat.rs | 4 +- src/middleware/condition.rs | 208 +++++++++++++++++++++++++++++++----- src/middleware/mod.rs | 11 +- 4 files changed, 194 insertions(+), 31 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 05055a517..b2cc1402e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,10 +3,12 @@ ## Unreleased - 2021-xx-xx ### Added * Option to allow `Json` extractor to work without a `Content-Type` header present. [#2362] +* New construction possiblity for the `Condition` middleware using `middleware::conditionally`, `middleware::optionally`, `middleware::optionally_fut`, `middleware::futurally` ### Changed * Associated type `FromRequest::Config` was removed. [#2233] * Inner field made private on `web::Payload`. [#2384] +* `middleware::Condition::new` was removed in favour of `middleware::conditionally` [#2233]: https://github.com/actix/actix-web/pull/2233 [#2362]: https://github.com/actix/actix-web/pull/2362 diff --git a/src/middleware/compat.rs b/src/middleware/compat.rs index 0a6256fe2..85166e485 100644 --- a/src/middleware/compat.rs +++ b/src/middleware/compat.rs @@ -139,7 +139,7 @@ mod tests { use crate::dev::ServiceRequest; use crate::http::StatusCode; - use crate::middleware::{self, Condition, Logger}; + use crate::middleware::{self, conditionally, Condition, Logger}; use crate::test::{call_service, init_service, TestRequest}; use crate::{web, App, HttpResponse}; @@ -199,7 +199,7 @@ mod tests { let logger = Logger::default(); - let mw = Condition::new(true, Compat::new(logger)) + let mw = conditionally(true, Compat::new(logger)) .new_transform(srv.into_service()) .await .unwrap(); diff --git a/src/middleware/condition.rs b/src/middleware/condition.rs index d1ba7ee4d..490385169 100644 --- a/src/middleware/condition.rs +++ b/src/middleware/condition.rs @@ -4,8 +4,11 @@ use std::task::{Context, Poll}; use actix_service::{Service, Transform}; use actix_utils::future::Either; -use futures_core::future::LocalBoxFuture; -use futures_util::future::FutureExt as _; +use std::future::{ready, Future, Ready}; +use std::ops::DerefMut; +use std::pin::Pin; +use std::rc::Rc; +use std::sync::Mutex; /// Middleware for conditionally enabling other middleware. /// @@ -15,52 +18,104 @@ use futures_util::future::FutureExt as _; /// /// # Examples /// ``` -/// use actix_web::middleware::{Condition, NormalizePath}; +/// use actix_web::middleware::{Condition, NormalizePath, TrailingSlash, conditionally, optionally, optionally_fut}; /// use actix_web::App; +/// use std::future::ready; /// /// let enable_normalize = std::env::var("NORMALIZE_PATH").is_ok(); +/// let config_opt = Some(TrailingSlash::Trim); +/// let config_opt_future = ready(Some(TrailingSlash::Always)); +/// let future = ready(Some(NormalizePath::new(TrailingSlash::MergeOnly))); +/// /// let app = App::new() -/// .wrap(Condition::new(enable_normalize, NormalizePath::default())); +/// .wrap(conditionally(enable_normalize, NormalizePath::default())) +/// .wrap(optionally(config_opt, |mode| NormalizePath::new(mode))) +/// .wrap(optionally_fut(config_opt_future, |mode| NormalizePath::new(mode))) +/// .wrap(Condition::new(future)); /// ``` -pub struct Condition { - transformer: T, - enable: bool, + +pub struct Condition(Rc>) +where + F: Future> + Unpin + 'static; + +pub fn futurally(transformer: F) -> Condition +where + F: Future> + Unpin + 'static, +{ + Condition(Rc::new(Mutex::new(transformer))) } -impl Condition { - pub fn new(enable: bool, transformer: T) -> Self { - Self { - transformer, - enable, - } +pub fn conditionally(enable: bool, transformer: T) -> Condition>> { + if enable { + Condition::>>(Rc::new(Mutex::new(ready(Some(transformer))))) + } else { + Condition::>>(Rc::new(Mutex::new(ready(None)))) } } -impl Transform for Condition +pub fn optionally( + condition: Option, + transformer: FACTORY, +) -> Condition>> +where + FACTORY: FnOnce(A) -> T, +{ + match condition { + Some(v) => { + Condition::>>(Rc::new(Mutex::new(ready(Some(transformer(v)))))) + } + None => Condition::>>(Rc::new(Mutex::new(ready(None)))), + } +} + +pub fn optionally_fut( + condition: F2, + transformer: FACTORY, +) -> Condition> + Unpin + 'static> +where + F2: Future> + Unpin + 'static, + FACTORY: FnOnce(A) -> T, +{ + Condition(Rc::new(Mutex::new(Box::pin(async move { + match condition.await { + Some(v) => Some(transformer(v)), + None => None, + } + })))) +} + +impl Transform for Condition where S: Service + 'static, T: Transform, T::Future: 'static, T::InitError: 'static, T::Transform: 'static, + F: Future> + Unpin + 'static, { type Response = S::Response; type Error = S::Error; type Transform = ConditionMiddleware; type InitError = T::InitError; - type Future = LocalBoxFuture<'static, Result>; + type Future = Pin>>>; fn new_transform(&self, service: S) -> Self::Future { - if self.enable { - let fut = self.transformer.new_transform(service); - async move { - let wrapped_svc = fut.await?; - Ok(ConditionMiddleware::Enable(wrapped_svc)) + let mutex = self.0.clone(); + + Box::pin(async move { + let mut lock = mutex.lock().unwrap(); + + let transformer = lock.deref_mut().await; + + match transformer { + Some(transformer) => { + let fut = transformer.new_transform(service); + let wrapped_svc = fut.await?; + Ok(ConditionMiddleware::Enable(wrapped_svc)) + } + None => Ok(ConditionMiddleware::Disable(service)), } - .boxed_local() - } else { - async move { Ok(ConditionMiddleware::Disable(service)) }.boxed_local() - } + }) } } @@ -124,7 +179,7 @@ mod tests { let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); - let mw = Condition::new(true, mw) + let mw = conditionally(true, mw) .new_transform(srv.into_service()) .await .unwrap(); @@ -140,7 +195,108 @@ mod tests { let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); - let mw = Condition::new(false, mw) + let mw = conditionally(false, mw) + .new_transform(srv.into_service()) + .await + .unwrap(); + let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE), None); + } + + #[actix_rt::test] + async fn test_handler_optional_some() { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; + + let mw = optionally(Some(StatusCode::INTERNAL_SERVER_ERROR), |status| { + ErrorHandlers::new().handler(status, render_500) + }) + .new_transform(srv.into_service()) + .await + .unwrap(); + + let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + } + + #[actix_rt::test] + async fn test_handler_optional_none() { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; + + let mw = optionally(None, |status| { + ErrorHandlers::new().handler(status, render_500) + }) + .new_transform(srv.into_service()) + .await + .unwrap(); + + let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE), None); + } + + #[actix_rt::test] + async fn test_handler_optional_future_some() { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; + + let mw = optionally_fut(ready(Some(StatusCode::INTERNAL_SERVER_ERROR)), |status| { + ErrorHandlers::new().handler(status, render_500) + }) + .new_transform(srv.into_service()) + .await + .unwrap(); + + let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + } + + #[actix_rt::test] + async fn test_handler_optional_future_none() { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; + + let mw = optionally_fut(ready(None), |status| { + ErrorHandlers::new().handler(status, render_500) + }) + .new_transform(srv.into_service()) + .await + .unwrap(); + + let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE), None); + } + + #[actix_rt::test] + async fn test_handler_futurally_enabled() { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; + + let mw = futurally(ready(Some( + ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500), + ))) + .new_transform(srv.into_service()) + .await + .unwrap(); + + let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + } + + #[actix_rt::test] + async fn test_handler_futurally_disabled() { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; + + let none: Option> = None; + + let mw = futurally(ready(none)) .new_transform(srv.into_service()) .await .unwrap(); diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index d19cb64e9..cc9d59372 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -8,6 +8,10 @@ mod logger; mod normalize; pub use self::compat::Compat; +pub use self::condition::conditionally; +pub use self::condition::futurally; +pub use self::condition::optionally; +pub use self::condition::optionally_fut; pub use self::condition::Condition; pub use self::default_headers::DefaultHeaders; pub use self::err_handlers::{ErrorHandlerResponse, ErrorHandlers}; @@ -25,6 +29,7 @@ mod tests { use crate::{http::StatusCode, App}; use super::*; + use crate::middleware::conditionally; #[test] fn common_combinations() { @@ -32,7 +37,7 @@ mod tests { let _ = App::new() .wrap(Compat::new(Logger::default())) - .wrap(Condition::new(true, DefaultHeaders::new())) + .wrap(conditionally(true, DefaultHeaders::new())) .wrap(DefaultHeaders::new().header("X-Test2", "X-Value2")) .wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| { Ok(ErrorHandlerResponse::Response(res)) @@ -47,7 +52,7 @@ mod tests { Ok(ErrorHandlerResponse::Response(res)) })) .wrap(DefaultHeaders::new().header("X-Test2", "X-Value2")) - .wrap(Condition::new(true, DefaultHeaders::new())) + .wrap(conditionally(true, DefaultHeaders::new())) .wrap(Compat::new(Logger::default())); #[cfg(feature = "__compress")] @@ -55,7 +60,7 @@ mod tests { let _ = App::new().wrap(Compress::default()).wrap(Logger::default()); let _ = App::new().wrap(Logger::default()).wrap(Compress::default()); let _ = App::new().wrap(Compat::new(Compress::default())); - let _ = App::new().wrap(Condition::new(true, Compat::new(Compress::default()))); + let _ = App::new().wrap(conditionally(true, Compat::new(Compress::default()))); } } }