add ability for condition middleware to consume an option and future

This commit is contained in:
Akos Vandra 2021-09-29 17:06:40 +02:00
parent a3806cde19
commit 6620fa5c4e
4 changed files with 194 additions and 31 deletions

View File

@ -3,10 +3,12 @@
## Unreleased - 2021-xx-xx
### Added
* Option to allow `Json` extractor to work without a `Content-Type` header present. [#2362]
* New construction possiblity for the `Condition` middleware using `middleware::conditionally`, `middleware::optionally`, `middleware::optionally_fut`, `middleware::futurally`
### Changed
* Associated type `FromRequest::Config` was removed. [#2233]
* Inner field made private on `web::Payload`. [#2384]
* `middleware::Condition::new` was removed in favour of `middleware::conditionally`
[#2233]: https://github.com/actix/actix-web/pull/2233
[#2362]: https://github.com/actix/actix-web/pull/2362

View File

@ -139,7 +139,7 @@ mod tests {
use crate::dev::ServiceRequest;
use crate::http::StatusCode;
use crate::middleware::{self, Condition, Logger};
use crate::middleware::{self, conditionally, Condition, Logger};
use crate::test::{call_service, init_service, TestRequest};
use crate::{web, App, HttpResponse};
@ -199,7 +199,7 @@ mod tests {
let logger = Logger::default();
let mw = Condition::new(true, Compat::new(logger))
let mw = conditionally(true, Compat::new(logger))
.new_transform(srv.into_service())
.await
.unwrap();

View File

@ -4,8 +4,11 @@ use std::task::{Context, Poll};
use actix_service::{Service, Transform};
use actix_utils::future::Either;
use futures_core::future::LocalBoxFuture;
use futures_util::future::FutureExt as _;
use std::future::{ready, Future, Ready};
use std::ops::DerefMut;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Mutex;
/// Middleware for conditionally enabling other middleware.
///
@ -15,52 +18,104 @@ use futures_util::future::FutureExt as _;
///
/// # Examples
/// ```
/// use actix_web::middleware::{Condition, NormalizePath};
/// use actix_web::middleware::{Condition, NormalizePath, TrailingSlash, conditionally, optionally, optionally_fut};
/// use actix_web::App;
/// use std::future::ready;
///
/// let enable_normalize = std::env::var("NORMALIZE_PATH").is_ok();
/// let config_opt = Some(TrailingSlash::Trim);
/// let config_opt_future = ready(Some(TrailingSlash::Always));
/// let future = ready(Some(NormalizePath::new(TrailingSlash::MergeOnly)));
///
/// let app = App::new()
/// .wrap(Condition::new(enable_normalize, NormalizePath::default()));
/// .wrap(conditionally(enable_normalize, NormalizePath::default()))
/// .wrap(optionally(config_opt, |mode| NormalizePath::new(mode)))
/// .wrap(optionally_fut(config_opt_future, |mode| NormalizePath::new(mode)))
/// .wrap(Condition::new(future));
/// ```
pub struct Condition<T> {
transformer: T,
enable: bool,
pub struct Condition<T, F>(Rc<Mutex<F>>)
where
F: Future<Output = Option<T>> + Unpin + 'static;
pub fn futurally<T, F>(transformer: F) -> Condition<T, F>
where
F: Future<Output = Option<T>> + Unpin + 'static,
{
Condition(Rc::new(Mutex::new(transformer)))
}
impl<T> Condition<T> {
pub fn new(enable: bool, transformer: T) -> Self {
Self {
transformer,
enable,
}
pub fn conditionally<T>(enable: bool, transformer: T) -> Condition<T, Ready<Option<T>>> {
if enable {
Condition::<T, Ready<Option<T>>>(Rc::new(Mutex::new(ready(Some(transformer)))))
} else {
Condition::<T, Ready<Option<T>>>(Rc::new(Mutex::new(ready(None))))
}
}
impl<S, T, Req> Transform<S, Req> for Condition<T>
pub fn optionally<T, A, FACTORY>(
condition: Option<A>,
transformer: FACTORY,
) -> Condition<T, impl Future<Output = Option<T>>>
where
FACTORY: FnOnce(A) -> T,
{
match condition {
Some(v) => {
Condition::<T, Ready<Option<T>>>(Rc::new(Mutex::new(ready(Some(transformer(v))))))
}
None => Condition::<T, Ready<Option<T>>>(Rc::new(Mutex::new(ready(None)))),
}
}
pub fn optionally_fut<T, A, F2, FACTORY>(
condition: F2,
transformer: FACTORY,
) -> Condition<T, impl Future<Output = Option<T>> + Unpin + 'static>
where
F2: Future<Output = Option<A>> + Unpin + 'static,
FACTORY: FnOnce(A) -> T,
{
Condition(Rc::new(Mutex::new(Box::pin(async move {
match condition.await {
Some(v) => Some(transformer(v)),
None => None,
}
}))))
}
impl<S, T, Req, F> Transform<S, Req> for Condition<T, F>
where
S: Service<Req> + 'static,
T: Transform<S, Req, Response = S::Response, Error = S::Error>,
T::Future: 'static,
T::InitError: 'static,
T::Transform: 'static,
F: Future<Output = Option<T>> + Unpin + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Transform = ConditionMiddleware<T::Transform, S>;
type InitError = T::InitError;
type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Transform, Self::InitError>>>>;
fn new_transform(&self, service: S) -> Self::Future {
if self.enable {
let fut = self.transformer.new_transform(service);
async move {
let mutex = self.0.clone();
Box::pin(async move {
let mut lock = mutex.lock().unwrap();
let transformer = lock.deref_mut().await;
match transformer {
Some(transformer) => {
let fut = transformer.new_transform(service);
let wrapped_svc = fut.await?;
Ok(ConditionMiddleware::Enable(wrapped_svc))
}
.boxed_local()
} else {
async move { Ok(ConditionMiddleware::Disable(service)) }.boxed_local()
None => Ok(ConditionMiddleware::Disable(service)),
}
})
}
}
@ -124,7 +179,7 @@ mod tests {
let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
let mw = Condition::new(true, mw)
let mw = conditionally(true, mw)
.new_transform(srv.into_service())
.await
.unwrap();
@ -140,7 +195,108 @@ mod tests {
let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
let mw = Condition::new(false, mw)
let mw = conditionally(false, mw)
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE), None);
}
#[actix_rt::test]
async fn test_handler_optional_some() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = optionally(Some(StatusCode::INTERNAL_SERVER_ERROR), |status| {
ErrorHandlers::new().handler(status, render_500)
})
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
}
#[actix_rt::test]
async fn test_handler_optional_none() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = optionally(None, |status| {
ErrorHandlers::new().handler(status, render_500)
})
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE), None);
}
#[actix_rt::test]
async fn test_handler_optional_future_some() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = optionally_fut(ready(Some(StatusCode::INTERNAL_SERVER_ERROR)), |status| {
ErrorHandlers::new().handler(status, render_500)
})
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
}
#[actix_rt::test]
async fn test_handler_optional_future_none() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = optionally_fut(ready(None), |status| {
ErrorHandlers::new().handler(status, render_500)
})
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE), None);
}
#[actix_rt::test]
async fn test_handler_futurally_enabled() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let mw = futurally(ready(Some(
ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500),
)))
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
}
#[actix_rt::test]
async fn test_handler_futurally_disabled() {
let srv = |req: ServiceRequest| {
ok(req.into_response(HttpResponse::InternalServerError().finish()))
};
let none: Option<ErrorHandlers<_>> = None;
let mw = futurally(ready(none))
.new_transform(srv.into_service())
.await
.unwrap();

View File

@ -8,6 +8,10 @@ mod logger;
mod normalize;
pub use self::compat::Compat;
pub use self::condition::conditionally;
pub use self::condition::futurally;
pub use self::condition::optionally;
pub use self::condition::optionally_fut;
pub use self::condition::Condition;
pub use self::default_headers::DefaultHeaders;
pub use self::err_handlers::{ErrorHandlerResponse, ErrorHandlers};
@ -25,6 +29,7 @@ mod tests {
use crate::{http::StatusCode, App};
use super::*;
use crate::middleware::conditionally;
#[test]
fn common_combinations() {
@ -32,7 +37,7 @@ mod tests {
let _ = App::new()
.wrap(Compat::new(Logger::default()))
.wrap(Condition::new(true, DefaultHeaders::new()))
.wrap(conditionally(true, DefaultHeaders::new()))
.wrap(DefaultHeaders::new().header("X-Test2", "X-Value2"))
.wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| {
Ok(ErrorHandlerResponse::Response(res))
@ -47,7 +52,7 @@ mod tests {
Ok(ErrorHandlerResponse::Response(res))
}))
.wrap(DefaultHeaders::new().header("X-Test2", "X-Value2"))
.wrap(Condition::new(true, DefaultHeaders::new()))
.wrap(conditionally(true, DefaultHeaders::new()))
.wrap(Compat::new(Logger::default()));
#[cfg(feature = "__compress")]
@ -55,7 +60,7 @@ mod tests {
let _ = App::new().wrap(Compress::default()).wrap(Logger::default());
let _ = App::new().wrap(Logger::default()).wrap(Compress::default());
let _ = App::new().wrap(Compat::new(Compress::default()));
let _ = App::new().wrap(Condition::new(true, Compat::new(Compress::default())));
let _ = App::new().wrap(conditionally(true, Compat::new(Compress::default())));
}
}
}