add ability for condition middleware to consume an option and future

This commit is contained in:
Akos Vandra 2021-09-29 17:06:40 +02:00
parent a3806cde19
commit 6620fa5c4e
4 changed files with 194 additions and 31 deletions

View File

@ -3,10 +3,12 @@
## Unreleased - 2021-xx-xx ## Unreleased - 2021-xx-xx
### Added ### Added
* Option to allow `Json` extractor to work without a `Content-Type` header present. [#2362] * Option to allow `Json` extractor to work without a `Content-Type` header present. [#2362]
* New construction possiblity for the `Condition` middleware using `middleware::conditionally`, `middleware::optionally`, `middleware::optionally_fut`, `middleware::futurally`
### Changed ### Changed
* Associated type `FromRequest::Config` was removed. [#2233] * Associated type `FromRequest::Config` was removed. [#2233]
* Inner field made private on `web::Payload`. [#2384] * Inner field made private on `web::Payload`. [#2384]
* `middleware::Condition::new` was removed in favour of `middleware::conditionally`
[#2233]: https://github.com/actix/actix-web/pull/2233 [#2233]: https://github.com/actix/actix-web/pull/2233
[#2362]: https://github.com/actix/actix-web/pull/2362 [#2362]: https://github.com/actix/actix-web/pull/2362

View File

@ -139,7 +139,7 @@ mod tests {
use crate::dev::ServiceRequest; use crate::dev::ServiceRequest;
use crate::http::StatusCode; use crate::http::StatusCode;
use crate::middleware::{self, Condition, Logger}; use crate::middleware::{self, conditionally, Condition, Logger};
use crate::test::{call_service, init_service, TestRequest}; use crate::test::{call_service, init_service, TestRequest};
use crate::{web, App, HttpResponse}; use crate::{web, App, HttpResponse};
@ -199,7 +199,7 @@ mod tests {
let logger = Logger::default(); let logger = Logger::default();
let mw = Condition::new(true, Compat::new(logger)) let mw = conditionally(true, Compat::new(logger))
.new_transform(srv.into_service()) .new_transform(srv.into_service())
.await .await
.unwrap(); .unwrap();

View File

@ -4,8 +4,11 @@ use std::task::{Context, Poll};
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use actix_utils::future::Either; use actix_utils::future::Either;
use futures_core::future::LocalBoxFuture; use std::future::{ready, Future, Ready};
use futures_util::future::FutureExt as _; use std::ops::DerefMut;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Mutex;
/// Middleware for conditionally enabling other middleware. /// Middleware for conditionally enabling other middleware.
/// ///
@ -15,52 +18,104 @@ use futures_util::future::FutureExt as _;
/// ///
/// # Examples /// # Examples
/// ``` /// ```
/// use actix_web::middleware::{Condition, NormalizePath}; /// use actix_web::middleware::{Condition, NormalizePath, TrailingSlash, conditionally, optionally, optionally_fut};
/// use actix_web::App; /// use actix_web::App;
/// use std::future::ready;
/// ///
/// let enable_normalize = std::env::var("NORMALIZE_PATH").is_ok(); /// let enable_normalize = std::env::var("NORMALIZE_PATH").is_ok();
/// let config_opt = Some(TrailingSlash::Trim);
/// let config_opt_future = ready(Some(TrailingSlash::Always));
/// let future = ready(Some(NormalizePath::new(TrailingSlash::MergeOnly)));
///
/// let app = App::new() /// let app = App::new()
/// .wrap(Condition::new(enable_normalize, NormalizePath::default())); /// .wrap(conditionally(enable_normalize, NormalizePath::default()))
/// .wrap(optionally(config_opt, |mode| NormalizePath::new(mode)))
/// .wrap(optionally_fut(config_opt_future, |mode| NormalizePath::new(mode)))
/// .wrap(Condition::new(future));
/// ``` /// ```
pub struct Condition<T> {
transformer: T, pub struct Condition<T, F>(Rc<Mutex<F>>)
enable: bool, where
F: Future<Output = Option<T>> + Unpin + 'static;
pub fn futurally<T, F>(transformer: F) -> Condition<T, F>
where
F: Future<Output = Option<T>> + Unpin + 'static,
{
Condition(Rc::new(Mutex::new(transformer)))
} }
impl<T> Condition<T> { pub fn conditionally<T>(enable: bool, transformer: T) -> Condition<T, Ready<Option<T>>> {
pub fn new(enable: bool, transformer: T) -> Self { if enable {
Self { Condition::<T, Ready<Option<T>>>(Rc::new(Mutex::new(ready(Some(transformer)))))
transformer, } else {
enable, Condition::<T, Ready<Option<T>>>(Rc::new(Mutex::new(ready(None))))
}
} }
} }
impl<S, T, Req> Transform<S, Req> for Condition<T> pub fn optionally<T, A, FACTORY>(
condition: Option<A>,
transformer: FACTORY,
) -> Condition<T, impl Future<Output = Option<T>>>
where
FACTORY: FnOnce(A) -> T,
{
match condition {
Some(v) => {
Condition::<T, Ready<Option<T>>>(Rc::new(Mutex::new(ready(Some(transformer(v))))))
}
None => Condition::<T, Ready<Option<T>>>(Rc::new(Mutex::new(ready(None)))),
}
}
pub fn optionally_fut<T, A, F2, FACTORY>(
condition: F2,
transformer: FACTORY,
) -> Condition<T, impl Future<Output = Option<T>> + Unpin + 'static>
where
F2: Future<Output = Option<A>> + Unpin + 'static,
FACTORY: FnOnce(A) -> T,
{
Condition(Rc::new(Mutex::new(Box::pin(async move {
match condition.await {
Some(v) => Some(transformer(v)),
None => None,
}
}))))
}
impl<S, T, Req, F> Transform<S, Req> for Condition<T, F>
where where
S: Service<Req> + 'static, S: Service<Req> + 'static,
T: Transform<S, Req, Response = S::Response, Error = S::Error>, T: Transform<S, Req, Response = S::Response, Error = S::Error>,
T::Future: 'static, T::Future: 'static,
T::InitError: 'static, T::InitError: 'static,
T::Transform: 'static, T::Transform: 'static,
F: Future<Output = Option<T>> + Unpin + 'static,
{ {
type Response = S::Response; type Response = S::Response;
type Error = S::Error; type Error = S::Error;
type Transform = ConditionMiddleware<T::Transform, S>; type Transform = ConditionMiddleware<T::Transform, S>;
type InitError = T::InitError; type InitError = T::InitError;
type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>; type Future = Pin<Box<dyn Future<Output = Result<Self::Transform, Self::InitError>>>>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
if self.enable { let mutex = self.0.clone();
let fut = self.transformer.new_transform(service);
async move { Box::pin(async move {
let mut lock = mutex.lock().unwrap();
let transformer = lock.deref_mut().await;
match transformer {
Some(transformer) => {
let fut = transformer.new_transform(service);
let wrapped_svc = fut.await?; let wrapped_svc = fut.await?;
Ok(ConditionMiddleware::Enable(wrapped_svc)) Ok(ConditionMiddleware::Enable(wrapped_svc))
} }
.boxed_local() None => Ok(ConditionMiddleware::Disable(service)),
} else {
async move { Ok(ConditionMiddleware::Disable(service)) }.boxed_local()
} }
})
} }
} }
@ -124,7 +179,7 @@ mod tests {
let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
let mw = Condition::new(true, mw) let mw = conditionally(true, mw)
.new_transform(srv.into_service()) .new_transform(srv.into_service())
.await .await
.unwrap(); .unwrap();
@ -140,7 +195,108 @@ mod tests {
let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
let mw = Condition::new(false, mw) let mw = conditionally(false, mw)
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE), None);
}
#[actix_rt::test]
async fn test_handler_optional_some() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = optionally(Some(StatusCode::INTERNAL_SERVER_ERROR), |status| {
ErrorHandlers::new().handler(status, render_500)
})
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
}
#[actix_rt::test]
async fn test_handler_optional_none() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = optionally(None, |status| {
ErrorHandlers::new().handler(status, render_500)
})
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE), None);
}
#[actix_rt::test]
async fn test_handler_optional_future_some() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = optionally_fut(ready(Some(StatusCode::INTERNAL_SERVER_ERROR)), |status| {
ErrorHandlers::new().handler(status, render_500)
})
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
}
#[actix_rt::test]
async fn test_handler_optional_future_none() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = optionally_fut(ready(None), |status| {
ErrorHandlers::new().handler(status, render_500)
})
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE), None);
}
#[actix_rt::test]
async fn test_handler_futurally_enabled() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = futurally(ready(Some(
ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500),
)))
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
}
#[actix_rt::test]
async fn test_handler_futurally_disabled() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let none: Option<ErrorHandlers<_>> = None;
let mw = futurally(ready(none))
.new_transform(srv.into_service()) .new_transform(srv.into_service())
.await .await
.unwrap(); .unwrap();

View File

@ -8,6 +8,10 @@ mod logger;
mod normalize; mod normalize;
pub use self::compat::Compat; pub use self::compat::Compat;
pub use self::condition::conditionally;
pub use self::condition::futurally;
pub use self::condition::optionally;
pub use self::condition::optionally_fut;
pub use self::condition::Condition; pub use self::condition::Condition;
pub use self::default_headers::DefaultHeaders; pub use self::default_headers::DefaultHeaders;
pub use self::err_handlers::{ErrorHandlerResponse, ErrorHandlers}; pub use self::err_handlers::{ErrorHandlerResponse, ErrorHandlers};
@ -25,6 +29,7 @@ mod tests {
use crate::{http::StatusCode, App}; use crate::{http::StatusCode, App};
use super::*; use super::*;
use crate::middleware::conditionally;
#[test] #[test]
fn common_combinations() { fn common_combinations() {
@ -32,7 +37,7 @@ mod tests {
let _ = App::new() let _ = App::new()
.wrap(Compat::new(Logger::default())) .wrap(Compat::new(Logger::default()))
.wrap(Condition::new(true, DefaultHeaders::new())) .wrap(conditionally(true, DefaultHeaders::new()))
.wrap(DefaultHeaders::new().header("X-Test2", "X-Value2")) .wrap(DefaultHeaders::new().header("X-Test2", "X-Value2"))
.wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| { .wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| {
Ok(ErrorHandlerResponse::Response(res)) Ok(ErrorHandlerResponse::Response(res))
@ -47,7 +52,7 @@ mod tests {
Ok(ErrorHandlerResponse::Response(res)) Ok(ErrorHandlerResponse::Response(res))
})) }))
.wrap(DefaultHeaders::new().header("X-Test2", "X-Value2")) .wrap(DefaultHeaders::new().header("X-Test2", "X-Value2"))
.wrap(Condition::new(true, DefaultHeaders::new())) .wrap(conditionally(true, DefaultHeaders::new()))
.wrap(Compat::new(Logger::default())); .wrap(Compat::new(Logger::default()));
#[cfg(feature = "__compress")] #[cfg(feature = "__compress")]
@ -55,7 +60,7 @@ mod tests {
let _ = App::new().wrap(Compress::default()).wrap(Logger::default()); let _ = App::new().wrap(Compress::default()).wrap(Logger::default());
let _ = App::new().wrap(Logger::default()).wrap(Compress::default()); let _ = App::new().wrap(Logger::default()).wrap(Compress::default());
let _ = App::new().wrap(Compat::new(Compress::default())); let _ = App::new().wrap(Compat::new(Compress::default()));
let _ = App::new().wrap(Condition::new(true, Compat::new(Compress::default()))); let _ = App::new().wrap(conditionally(true, Compat::new(Compress::default())));
} }
} }
} }