From 38c667b0d0296f1664f34345ba93b2cf4aac80f1 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Fri, 1 Jan 2021 12:11:57 +0800 Subject: [PATCH] add scoped middleware --- src/middleware/condition.rs | 4 +- src/middleware/mod.rs | 2 + src/middleware/scoped.rs | 202 ++++++++++++++++++++++++++++++++++++ 3 files changed, 207 insertions(+), 1 deletion(-) create mode 100644 src/middleware/scoped.rs diff --git a/src/middleware/condition.rs b/src/middleware/condition.rs index 9061c7458..4e6b2a943 100644 --- a/src/middleware/condition.rs +++ b/src/middleware/condition.rs @@ -6,7 +6,9 @@ use futures_util::future::{ok, Either, FutureExt, LocalBoxFuture}; /// `Middleware` for conditionally enables another middleware. /// The controlled middleware must not change the `Service` interfaces. -/// This means you cannot control such middlewares like `Logger` or `Compress`. +/// +/// This means you cannot control such middlewares like `Logger` or `Compress` directly. +/// *. See `Scoped` middleware for alternative. /// /// ## Usage /// diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 12c12a98c..c7aa026a9 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -10,8 +10,10 @@ mod defaultheaders; pub mod errhandlers; mod logger; pub mod normalize; +mod scoped; pub use self::condition::Condition; pub use self::defaultheaders::DefaultHeaders; pub use self::logger::Logger; pub use self::normalize::NormalizePath; +pub use self::scoped::Scoped; diff --git a/src/middleware/scoped.rs b/src/middleware/scoped.rs new file mode 100644 index 000000000..54c5b49bb --- /dev/null +++ b/src/middleware/scoped.rs @@ -0,0 +1,202 @@ +//! `Middleware` for enabling any middleware to be used in `Resource`, `Scope` and `Condition`. +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_http::body::Body; +use actix_http::body::{MessageBody, ResponseBody}; +use actix_service::{Service, Transform}; +use futures_core::future::LocalBoxFuture; +use futures_core::ready; + +use crate::error::Error; +use crate::service::ServiceResponse; + +/// `Middleware` for enabling any middleware to be used in `Resource`, `Scope` and `Condition`. +/// +/// +/// ## Usage +/// +/// ```rust +/// use actix_web::middleware::{Logger, Scoped}; +/// use actix_web::{App, web}; +/// +/// # fn main() { +/// let logger = Logger::default(); +/// +/// // this would not compile +/// // let app = App::new().service(web::scope("scoped").wrap(logger)); +/// +/// // by using scoped middleware we can use logger in scope. +/// let app = App::new().service(web::scope("scoped").wrap(Scoped::new(logger))); +/// # } +/// ``` +pub struct Scoped { + transform: T, +} + +impl Scoped { + pub fn new(transform: T) -> Self { + Self { transform } + } +} + +impl Transform for Scoped +where + S: Service, + T: Transform, + T::Future: 'static, + T::Response: MapServiceResponseBody, + Error: From, +{ + type Request = T::Request; + type Response = ServiceResponse; + type Error = Error; + type Transform = ScopedMiddleware; + type InitError = T::InitError; + type Future = LocalBoxFuture<'static, Result>; + + fn new_transform(&self, service: S) -> Self::Future { + let fut = self.transform.new_transform(service); + Box::pin(async move { + let service = fut.await?; + Ok(ScopedMiddleware { service }) + }) + } +} + +pub struct ScopedMiddleware { + service: S, +} + +impl Service for ScopedMiddleware +where + S: Service, + S::Response: MapServiceResponseBody, + Error: From, +{ + type Request = S::Request; + type Response = ServiceResponse; + type Error = Error; + type Future = ScopedFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx).map_err(From::from) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + let fut = self.service.call(req); + ScopedFuture { fut } + } +} + +#[doc(hidden)] +#[pin_project::pin_project] +pub struct ScopedFuture +where + S: Service, +{ + #[pin] + fut: S::Future, +} + +impl Future for ScopedFuture +where + S: Service, + S::Response: MapServiceResponseBody, + Error: From, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let res = ready!(self.project().fut.poll(cx))?; + Poll::Ready(Ok(res.map_body())) + } +} + +// private trait for convert ServiceResponse's ResponseBody type +// to ResponseBody::Other(Body::Message) +#[doc(hidden)] +pub trait MapServiceResponseBody { + fn map_body(self) -> ServiceResponse; +} + +impl MapServiceResponseBody for ServiceResponse { + fn map_body(self) -> ServiceResponse { + self.map_body(|_, body| ResponseBody::Other(Body::from_message(body))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use actix_service::IntoService; + + use crate::dev::ServiceRequest; + use crate::http::StatusCode; + use crate::middleware::{Compress, Condition, Logger}; + use crate::test::{call_service, init_service, TestRequest}; + use crate::App; + use crate::{web, HttpResponse}; + + #[actix_rt::test] + async fn test_scope_middleware() { + let logger = Logger::default(); + let compress = Compress::default(); + + let mut srv = init_service( + App::new().service( + web::scope("app") + .wrap(Scoped::new(logger)) + .wrap(Scoped::new(compress)) + .service( + web::resource("/test").route(web::get().to(HttpResponse::Ok)), + ), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_resource_scope_middleware() { + let logger = Logger::default(); + let compress = Compress::default(); + + let mut srv = init_service( + App::new().service( + web::resource("app/test") + .wrap(Scoped::new(logger)) + .wrap(Scoped::new(compress)) + .route(web::get().to(HttpResponse::Ok)), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_condition_scope_middleware() { + let srv = |req: ServiceRequest| { + Box::pin(async move { + Ok(req.into_response(HttpResponse::InternalServerError().finish())) + }) + }; + + let logger = Logger::default(); + + let mut mw = Condition::new(true, Scoped::new(logger)) + .new_transform(srv.into_service()) + .await + .unwrap(); + let resp = call_service(&mut mw, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } +}