diff --git a/actix-files/src/service.rs b/actix-files/src/service.rs index f6e1c2e11..057dbe5a3 100644 --- a/actix-files/src/service.rs +++ b/actix-files/src/service.rs @@ -1,8 +1,8 @@ use std::{fmt, io, ops::Deref, path::PathBuf, rc::Rc}; -use actix_service::Service; use actix_web::{ - dev::{ServiceRequest, ServiceResponse}, + body::BoxBody, + dev::{Service, ServiceRequest, ServiceResponse}, error::Error, guard::Guard, http::{header, Method}, @@ -94,7 +94,7 @@ impl fmt::Debug for FilesService { } impl Service for FilesService { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = Error; type Future = LocalBoxFuture<'static, Result>; @@ -103,7 +103,7 @@ impl Service for FilesService { fn call(&self, req: ServiceRequest) -> Self::Future { let is_method_valid = if let Some(guard) = &self.guards { // execute user defined guards - (**guard).check(req.head()) + (**guard).check(&req.guard_ctx()) } else { // default behavior matches!(*req.method(), Method::HEAD | Method::GET) diff --git a/actix-http/src/message.rs b/actix-http/src/message.rs index 34213f68a..ecd08fbb3 100644 --- a/actix-http/src/message.rs +++ b/actix-http/src/message.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, rc::Rc}; +use std::{cell::RefCell, ops, rc::Rc}; use bitflags::bitflags; @@ -49,7 +49,7 @@ impl Message { } } -impl std::ops::Deref for Message { +impl ops::Deref for Message { type Target = T; fn deref(&self) -> &Self::Target { @@ -57,7 +57,7 @@ impl std::ops::Deref for Message { } } -impl std::ops::DerefMut for Message { +impl ops::DerefMut for Message { fn deref_mut(&mut self) -> &mut Self::Target { Rc::get_mut(&mut self.head).expect("Multiple copies exist") } diff --git a/src/app_service.rs b/src/app_service.rs index e0d424390..56b24f0d8 100644 --- a/src/app_service.rs +++ b/src/app_service.rs @@ -1,14 +1,16 @@ use std::{cell::RefCell, mem, rc::Rc}; -use actix_http::{Extensions, Request}; +use actix_http::Request; use actix_router::{Path, ResourceDef, Router, Url}; use actix_service::{boxed, fn_service, Service, ServiceFactory}; use futures_core::future::LocalBoxFuture; use futures_util::future::join_all; use crate::{ + body::BoxBody, config::{AppConfig, AppService}, data::FnDataFactory, + dev::Extensions, guard::Guard, request::{HttpRequest, HttpRequestPool}, rmap::ResourceMap, @@ -297,7 +299,7 @@ pub struct AppRouting { } impl Service for AppRouting { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = Error; type Future = LocalBoxFuture<'static, Result>; @@ -306,12 +308,15 @@ impl Service for AppRouting { fn call(&self, mut req: ServiceRequest) -> Self::Future { let res = self.router.recognize_fn(&mut req, |req, guards| { if let Some(ref guards) = guards { - for f in guards { - if !f.check(req.head()) { + let guard_ctx = req.guard_ctx(); + + for guard in guards { + if !guard.check(&guard_ctx) { return false; } } } + true }); diff --git a/src/dev.rs b/src/dev.rs index 6e1970467..bb1385bde 100644 --- a/src/dev.rs +++ b/src/dev.rs @@ -3,6 +3,16 @@ //! Most users will not have to interact with the types in this module, but it is useful for those //! writing extractors, middleware, libraries, or interacting with the service API directly. +pub use actix_http::{Extensions, Payload, RequestHead, Response, ResponseHead}; +pub use actix_router::{Path, ResourceDef, ResourcePath, Url}; +pub use actix_server::{Server, ServerHandle}; +pub use actix_service::{ + always_ready, fn_factory, fn_service, forward_ready, Service, ServiceFactory, Transform, +}; + +#[cfg(feature = "__compress")] +pub use actix_http::encoding::Decoder as Decompress; + pub use crate::config::{AppConfig, AppService}; #[doc(hidden)] pub use crate::handler::Handler; @@ -14,16 +24,6 @@ pub use crate::types::form::UrlEncoded; pub use crate::types::json::JsonBody; pub use crate::types::readlines::Readlines; -pub use actix_http::{Extensions, Payload, RequestHead, Response, ResponseHead}; -pub use actix_router::{Path, ResourceDef, ResourcePath, Url}; -pub use actix_server::{Server, ServerHandle}; -pub use actix_service::{ - always_ready, fn_factory, fn_service, forward_ready, Service, ServiceFactory, Transform, -}; - -#[cfg(feature = "__compress")] -pub use actix_http::encoding::Decoder as Decompress; - use crate::http::header::ContentEncoding; use actix_router::Patterns; @@ -46,7 +46,6 @@ pub(crate) fn ensure_leading_slash(mut patterns: Patterns) -> Patterns { patterns } -struct Enc(ContentEncoding); /// Helper trait that allows to set specific encoding for response. pub trait BodyEncoding { @@ -70,6 +69,8 @@ impl BodyEncoding for actix_http::ResponseBuilder { } } +struct Enc(ContentEncoding); + impl BodyEncoding for actix_http::Response { fn get_encoding(&self) -> Option { self.extensions().get::().map(|enc| enc.0) diff --git a/src/guard.rs b/src/guard.rs index db7f06987..b541653e7 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -1,16 +1,14 @@ -//! Route match guards. +//! Route guards. //! -//! Guards are one of the ways how actix-web router chooses a -//! handler service. In essence it is just a function that accepts a -//! reference to a `RequestHead` instance and returns a boolean. -//! It is possible to add guards to *scopes*, *resources* -//! and *routes*. Actix provide several guards by default, like various -//! http methods, header, etc. To become a guard, type must implement `Guard` -//! trait. Simple functions could be guards as well. +//! Guards are one of the ways how actix-web router chooses a handler service. In essence it is just +//! a function that accepts a reference to a `RequestHead` instance and returns a boolean. It is +//! possible to add guards to *scopes*, *resources* and *routes*. Actix provide several guards by +//! default, like various HTTP methods, header, etc. To become a guard, type must implement the +//! `Guard` trait. Simple functions could be guards as well. //! -//! Guards can not modify the request object. But it is possible -//! to store extra attributes on a request by using the `Extensions` container. -//! Extensions containers are available via the `RequestHead::extensions()` method. +//! Guards can not modify the request object. But it is possible to store extra attributes on a +//! request by using the `Extensions` container. Extensions containers are available via the +//! `RequestHead::extensions()` method. //! //! ``` //! use actix_web::{web, http, dev, guard, App, HttpResponse}; @@ -18,31 +16,56 @@ //! App::new().service(web::resource("/index.html").route( //! web::route() //! .guard(guard::Post()) -//! .guard(guard::fn_guard(|head| head.method == http::Method::GET)) +//! .guard(guard::fn_guard(|ctx| ctx.head().method == http::Method::GET)) //! .to(|| HttpResponse::MethodNotAllowed())) //! ); //! ``` -#![allow(non_snake_case)] +use std::{ + cell::{Ref, RefMut}, + convert::TryFrom, + rc::Rc, +}; -use std::rc::Rc; -use std::{convert::TryFrom, ops::Deref}; +use actix_http::{header, uri::Uri, Extensions, Method as HttpMethod, RequestHead}; -use actix_http::{header, uri::Uri, Method as HttpMethod, RequestHead}; +use crate::service::ServiceRequest; + +#[derive(Debug)] +pub struct GuardContext<'a> { + pub(crate) req: &'a ServiceRequest, +} + +impl<'a> GuardContext<'a> { + #[inline] + pub fn head(&self) -> &RequestHead { + self.req.head() + } + + #[inline] + pub fn req_data(&self) -> Ref<'a, Extensions> { + self.req.req_data() + } + + #[inline] + pub fn req_data_mut(&self) -> RefMut<'a, Extensions> { + self.req.req_data_mut() + } +} /// Trait defines resource guards. Guards are used for route selection. /// -/// Guards can not modify the request object. But it is possible -/// to store extra attributes on a request by using the `Extensions` container. -/// Extensions containers are available via the `RequestHead::extensions()` method. +/// Guards can not modify the request object. But it is possible to store extra attributes on a +/// request by using the `Extensions` container. Extensions containers are available via the +/// `RequestHead::extensions()` method. pub trait Guard { /// Check if request matches predicate - fn check(&self, request: &RequestHead) -> bool; + fn check(&self, ctx: &GuardContext<'_>) -> bool; } impl Guard for Rc { - fn check(&self, request: &RequestHead) -> bool { - self.deref().check(request) + fn check(&self, ctx: &GuardContext<'_>) -> bool { + (**self).check(ctx) } } @@ -51,39 +74,40 @@ impl Guard for Rc { /// ``` /// use actix_web::{guard, web, App, HttpResponse}; /// -/// App::new().service(web::resource("/index.html").route( -/// web::route() -/// .guard( -/// guard::fn_guard( -/// |req| req.headers() -/// .contains_key("content-type"))) -/// .to(|| HttpResponse::MethodNotAllowed())) -/// ); +/// App::new().service( +/// web::resource("/index.html").route( +/// web::route() +/// .guard(guard::fn_guard(|ctx| { +/// ctx.head().headers().contains_key("content-type") +/// })) +/// .to(|| HttpResponse::MethodNotAllowed()), +/// ), +/// ); /// ``` pub fn fn_guard(f: F) -> impl Guard where - F: Fn(&RequestHead) -> bool, + F: Fn(&GuardContext<'_>) -> bool, { FnGuard(f) } -struct FnGuard bool>(F); +struct FnGuard) -> bool>(F); impl Guard for FnGuard where - F: Fn(&RequestHead) -> bool, + F: Fn(&GuardContext<'_>) -> bool, { - fn check(&self, head: &RequestHead) -> bool { - (self.0)(head) + fn check(&self, ctx: &GuardContext<'_>) -> bool { + (self.0)(ctx) } } impl Guard for F where - F: Fn(&RequestHead) -> bool, + F: Fn(&GuardContext<'_>) -> bool, { - fn check(&self, head: &RequestHead) -> bool { - (self)(head) + fn check(&self, ctx: &GuardContext<'_>) -> bool { + (self)(ctx) } } @@ -98,28 +122,34 @@ where /// .to(|| HttpResponse::MethodNotAllowed())) /// ); /// ``` +#[allow(non_snake_case)] pub fn Any(guard: F) -> AnyGuard { - AnyGuard(vec![Box::new(guard)]) + AnyGuard { + guards: vec![Box::new(guard)], + } } /// Matches any of supplied guards. -pub struct AnyGuard(Vec>); +pub struct AnyGuard { + guards: Vec>, +} impl AnyGuard { /// Add guard to a list of guards to check pub fn or(mut self, guard: F) -> Self { - self.0.push(Box::new(guard)); + self.guards.push(Box::new(guard)); self } } impl Guard for AnyGuard { - fn check(&self, req: &RequestHead) -> bool { - for p in &self.0 { - if p.check(req) { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + for guard in &self.guards { + if guard.check(ctx) { return true; } } + false } } @@ -136,25 +166,30 @@ impl Guard for AnyGuard { /// .to(|| HttpResponse::MethodNotAllowed())) /// ); /// ``` +#[allow(non_snake_case)] pub fn All(guard: F) -> AllGuard { - AllGuard(vec![Box::new(guard)]) + AllGuard { + guards: vec![Box::new(guard)], + } } /// Matches if all of supplied guards. -pub struct AllGuard(Vec>); +pub struct AllGuard { + guards: Vec>, +} impl AllGuard { /// Add new guard to the list of guards to check pub fn and(mut self, guard: F) -> Self { - self.0.push(Box::new(guard)); + self.guards.push(Box::new(guard)); self } } impl Guard for AllGuard { - fn check(&self, request: &RequestHead) -> bool { - for p in &self.0 { - if !p.check(request) { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + for guard in &self.guards { + if !guard.check(ctx) { return false; } } @@ -163,82 +198,91 @@ impl Guard for AllGuard { } /// Return guard that matches if supplied guard does not match. -pub fn Not(guard: F) -> NotGuard { +#[allow(non_snake_case)] +pub fn Not(guard: F) -> impl Guard { NotGuard(Box::new(guard)) } -#[doc(hidden)] -pub struct NotGuard(Box); +struct NotGuard(Box); impl Guard for NotGuard { - fn check(&self, request: &RequestHead) -> bool { - !self.0.check(request) + fn check(&self, ctx: &GuardContext<'_>) -> bool { + !self.0.check(ctx) } } /// HTTP method guard. -#[doc(hidden)] -pub struct MethodGuard(HttpMethod); +struct MethodGuard(HttpMethod); impl Guard for MethodGuard { - fn check(&self, request: &RequestHead) -> bool { - request.method == self.0 + fn check(&self, ctx: &GuardContext<'_>) -> bool { + ctx.head().method == self.0 } } /// Guard to match *GET* HTTP method. -pub fn Get() -> MethodGuard { +#[allow(non_snake_case)] +pub fn Get() -> impl Guard { MethodGuard(HttpMethod::GET) } /// Predicate to match *POST* HTTP method. -pub fn Post() -> MethodGuard { +#[allow(non_snake_case)] +pub fn Post() -> impl Guard { MethodGuard(HttpMethod::POST) } /// Predicate to match *PUT* HTTP method. -pub fn Put() -> MethodGuard { +#[allow(non_snake_case)] +pub fn Put() -> impl Guard { MethodGuard(HttpMethod::PUT) } /// Predicate to match *DELETE* HTTP method. -pub fn Delete() -> MethodGuard { +#[allow(non_snake_case)] +pub fn Delete() -> impl Guard { MethodGuard(HttpMethod::DELETE) } /// Predicate to match *HEAD* HTTP method. -pub fn Head() -> MethodGuard { +#[allow(non_snake_case)] +pub fn Head() -> impl Guard { MethodGuard(HttpMethod::HEAD) } /// Predicate to match *OPTIONS* HTTP method. -pub fn Options() -> MethodGuard { +#[allow(non_snake_case)] +pub fn Options() -> impl Guard { MethodGuard(HttpMethod::OPTIONS) } /// Predicate to match *CONNECT* HTTP method. -pub fn Connect() -> MethodGuard { +#[allow(non_snake_case)] +pub fn Connect() -> impl Guard { MethodGuard(HttpMethod::CONNECT) } /// Predicate to match *PATCH* HTTP method. -pub fn Patch() -> MethodGuard { +#[allow(non_snake_case)] +pub fn Patch() -> impl Guard { MethodGuard(HttpMethod::PATCH) } /// Predicate to match *TRACE* HTTP method. -pub fn Trace() -> MethodGuard { +#[allow(non_snake_case)] +pub fn Trace() -> impl Guard { MethodGuard(HttpMethod::TRACE) } /// Predicate to match specified HTTP method. -pub fn Method(method: HttpMethod) -> MethodGuard { +#[allow(non_snake_case)] +pub fn Method(method: HttpMethod) -> impl Guard { MethodGuard(method) } -/// Return predicate that matches if request contains specified header and -/// value. -pub fn Header(name: &'static str, value: &'static str) -> HeaderGuard { +/// Return predicate that matches if request contains specified header and value. +#[allow(non_snake_case)] +pub fn Header(name: &'static str, value: &'static str) -> impl Guard { HeaderGuard( header::HeaderName::try_from(name).unwrap(), header::HeaderValue::from_static(value), @@ -246,13 +290,14 @@ pub fn Header(name: &'static str, value: &'static str) -> HeaderGuard { } #[doc(hidden)] -pub struct HeaderGuard(header::HeaderName, header::HeaderValue); +struct HeaderGuard(header::HeaderName, header::HeaderValue); impl Guard for HeaderGuard { - fn check(&self, req: &RequestHead) -> bool { - if let Some(val) = req.headers.get(&self.0) { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + if let Some(val) = ctx.head().headers.get(&self.0) { return val == self.1; } + false } } @@ -268,48 +313,54 @@ impl Guard for HeaderGuard { /// .to(|| HttpResponse::MethodNotAllowed()) /// ); /// ``` +#[allow(non_snake_case)] pub fn Host>(host: H) -> HostGuard { - HostGuard(host.as_ref().to_string(), None) + HostGuard { + host: host.as_ref().to_string(), + scheme: None, + } } fn get_host_uri(req: &RequestHead) -> Option { - use core::str::FromStr; req.headers .get(header::HOST) .and_then(|host_value| host_value.to_str().ok()) .or_else(|| req.uri.host()) - .map(|host: &str| Uri::from_str(host).ok()) + .map(|host| host.parse().ok()) .and_then(|host_success| host_success) } #[doc(hidden)] -pub struct HostGuard(String, Option); +pub struct HostGuard { + host: String, + scheme: Option, +} impl HostGuard { /// Set request scheme to match pub fn scheme>(mut self, scheme: H) -> HostGuard { - self.1 = Some(scheme.as_ref().to_string()); + self.scheme = Some(scheme.as_ref().to_string()); self } } impl Guard for HostGuard { - fn check(&self, req: &RequestHead) -> bool { - let req_host_uri = if let Some(uri) = get_host_uri(req) { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + let req_host_uri = if let Some(uri) = get_host_uri(ctx.head()) { uri } else { return false; }; if let Some(uri_host) = req_host_uri.host() { - if self.0 != uri_host { + if self.host != uri_host { return false; } } else { return false; } - if let Some(ref scheme) = self.1 { + if let Some(ref scheme) = self.scheme { if let Some(ref req_host_uri_scheme) = req_host_uri.scheme_str() { return scheme == req_host_uri_scheme; } @@ -330,16 +381,16 @@ mod tests { fn test_header() { let req = TestRequest::default() .insert_header((header::TRANSFER_ENCODING, "chunked")) - .to_http_request(); + .to_srv_request(); let pred = Header("transfer-encoding", "chunked"); - assert!(pred.check(req.head())); + assert!(pred.check(&req.guard_ctx())); let pred = Header("transfer-encoding", "other"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); let pred = Header("content-type", "other"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); } #[test] @@ -349,25 +400,25 @@ mod tests { header::HOST, header::HeaderValue::from_static("www.rust-lang.org"), )) - .to_http_request(); + .to_srv_request(); let pred = Host("www.rust-lang.org"); - assert!(pred.check(req.head())); + assert!(pred.check(&req.guard_ctx())); let pred = Host("www.rust-lang.org").scheme("https"); - assert!(pred.check(req.head())); + assert!(pred.check(&req.guard_ctx())); let pred = Host("blog.rust-lang.org"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); let pred = Host("blog.rust-lang.org").scheme("https"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); let pred = Host("crates.io"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); let pred = Host("localhost"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); } #[test] @@ -377,121 +428,117 @@ mod tests { header::HOST, header::HeaderValue::from_static("https://www.rust-lang.org"), )) - .to_http_request(); + .to_srv_request(); let pred = Host("www.rust-lang.org").scheme("https"); - assert!(pred.check(req.head())); + assert!(pred.check(&req.guard_ctx())); let pred = Host("www.rust-lang.org"); - assert!(pred.check(req.head())); + assert!(pred.check(&req.guard_ctx())); let pred = Host("www.rust-lang.org").scheme("http"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); let pred = Host("blog.rust-lang.org"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); let pred = Host("blog.rust-lang.org").scheme("https"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); let pred = Host("crates.io").scheme("https"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); let pred = Host("localhost"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); } #[test] fn test_host_without_header() { let req = TestRequest::default() .uri("www.rust-lang.org") - .to_http_request(); + .to_srv_request(); let pred = Host("www.rust-lang.org"); - assert!(pred.check(req.head())); + assert!(pred.check(&req.guard_ctx())); let pred = Host("www.rust-lang.org").scheme("https"); - assert!(pred.check(req.head())); + assert!(pred.check(&req.guard_ctx())); let pred = Host("blog.rust-lang.org"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); let pred = Host("blog.rust-lang.org").scheme("https"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); let pred = Host("crates.io"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); let pred = Host("localhost"); - assert!(!pred.check(req.head())); + assert!(!pred.check(&req.guard_ctx())); } #[test] fn test_methods() { - let req = TestRequest::default().to_http_request(); - let req2 = TestRequest::default() - .method(Method::POST) - .to_http_request(); + let req = TestRequest::default().to_srv_request(); + let req2 = TestRequest::default().method(Method::POST).to_srv_request(); - assert!(Get().check(req.head())); - assert!(!Get().check(req2.head())); - assert!(Post().check(req2.head())); - assert!(!Post().check(req.head())); + assert!(Get().check(&req.guard_ctx())); + assert!(!Get().check(&req2.guard_ctx())); + assert!(Post().check(&req2.guard_ctx())); + assert!(!Post().check(&req.guard_ctx())); - let r = TestRequest::default().method(Method::PUT).to_http_request(); - assert!(Put().check(r.head())); - assert!(!Put().check(req.head())); + let r = TestRequest::default().method(Method::PUT).to_srv_request(); + assert!(Put().check(&r.guard_ctx())); + assert!(!Put().check(&req.guard_ctx())); let r = TestRequest::default() .method(Method::DELETE) - .to_http_request(); - assert!(Delete().check(r.head())); - assert!(!Delete().check(req.head())); + .to_srv_request(); + assert!(Delete().check(&r.guard_ctx())); + assert!(!Delete().check(&req.guard_ctx())); - let r = TestRequest::default() - .method(Method::HEAD) - .to_http_request(); - assert!(Head().check(r.head())); - assert!(!Head().check(req.head())); + let r = TestRequest::default().method(Method::HEAD).to_srv_request(); + assert!(Head().check(&r.guard_ctx())); + assert!(!Head().check(&req.guard_ctx())); let r = TestRequest::default() .method(Method::OPTIONS) - .to_http_request(); - assert!(Options().check(r.head())); - assert!(!Options().check(req.head())); + .to_srv_request(); + assert!(Options().check(&r.guard_ctx())); + assert!(!Options().check(&req.guard_ctx())); let r = TestRequest::default() .method(Method::CONNECT) - .to_http_request(); - assert!(Connect().check(r.head())); - assert!(!Connect().check(req.head())); + .to_srv_request(); + assert!(Connect().check(&r.guard_ctx())); + assert!(!Connect().check(&req.guard_ctx())); let r = TestRequest::default() .method(Method::PATCH) - .to_http_request(); - assert!(Patch().check(r.head())); - assert!(!Patch().check(req.head())); + .to_srv_request(); + assert!(Patch().check(&r.guard_ctx())); + assert!(!Patch().check(&req.guard_ctx())); let r = TestRequest::default() .method(Method::TRACE) - .to_http_request(); - assert!(Trace().check(r.head())); - assert!(!Trace().check(req.head())); + .to_srv_request(); + assert!(Trace().check(&r.guard_ctx())); + assert!(!Trace().check(&req.guard_ctx())); } #[test] fn test_preds() { let r = TestRequest::default() .method(Method::TRACE) - .to_http_request(); + .to_srv_request(); - assert!(Not(Get()).check(r.head())); - assert!(!Not(Trace()).check(r.head())); + assert!(Not(Get()).check(&r.guard_ctx())); + assert!(!Not(Trace()).check(&r.guard_ctx())); - assert!(All(Trace()).and(Trace()).check(r.head())); - assert!(!All(Get()).and(Trace()).check(r.head())); + assert!(All(Trace()).and(Trace()).check(&r.guard_ctx())); + assert!(!All(Get()).and(Trace()).check(&r.guard_ctx())); - assert!(Any(Get()).or(Trace()).check(r.head())); - assert!(!Any(Get()).or(Get()).check(r.head())); + assert!(Any(Get()).or(Trace()).check(&r.guard_ctx())); + assert!(!Any(Get()).or(Get()).check(&r.guard_ctx())); } } diff --git a/src/middleware/normalize.rs b/src/middleware/normalize.rs index 18dcaeefa..3ab908481 100644 --- a/src/middleware/normalize.rs +++ b/src/middleware/normalize.rs @@ -225,7 +225,7 @@ mod tests { .service(web::resource("/v1/something").to(HttpResponse::Ok)) .service( web::resource("/v2/something") - .guard(fn_guard(|req| req.uri.query() == Some("query=test"))) + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) .to(HttpResponse::Ok), ), ) @@ -261,7 +261,7 @@ mod tests { .service(web::resource("/v1/something").to(HttpResponse::Ok)) .service( web::resource("/v2/something") - .guard(fn_guard(|req| req.uri.query() == Some("query=test"))) + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) .to(HttpResponse::Ok), ), ) @@ -294,7 +294,7 @@ mod tests { let app = init_service( App::new().wrap(NormalizePath(TrailingSlash::Trim)).service( web::resource("/") - .guard(fn_guard(|req| req.uri.query() == Some("query=test"))) + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) .to(HttpResponse::Ok), ), ) @@ -318,7 +318,7 @@ mod tests { .service(web::resource("/v1/something/").to(HttpResponse::Ok)) .service( web::resource("/v2/something/") - .guard(fn_guard(|req| req.uri.query() == Some("query=test"))) + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) .to(HttpResponse::Ok), ), ) @@ -353,7 +353,7 @@ mod tests { .wrap(NormalizePath(TrailingSlash::Always)) .service( web::resource("/") - .guard(fn_guard(|req| req.uri.query() == Some("query=test"))) + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) .to(HttpResponse::Ok), ), ) @@ -378,7 +378,7 @@ mod tests { .service(web::resource("/v1/").to(HttpResponse::Ok)) .service( web::resource("/v2/something") - .guard(fn_guard(|req| req.uri.query() == Some("query=test"))) + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) .to(HttpResponse::Ok), ), ) diff --git a/src/route.rs b/src/route.rs index 6d6fca4b7..6be1431cf 100644 --- a/src/route.rs +++ b/src/route.rs @@ -65,9 +65,12 @@ pub struct RouteService { } impl RouteService { + // TODO: does this need to take &mut ? pub fn check(&self, req: &mut ServiceRequest) -> bool { - for f in self.guards.iter() { - if !f.check(req.head()) { + for guard in self.guards.iter() { + let guard_ctx = req.guard_ctx(); + + if !guard.check(&guard_ctx) { return false; } } diff --git a/src/scope.rs b/src/scope.rs index 176e0d5a0..b4618bb6c 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -538,12 +538,15 @@ impl Service for ScopeService { fn call(&self, mut req: ServiceRequest) -> Self::Future { let res = self.router.recognize_fn(&mut req, |req, guards| { if let Some(ref guards) = guards { - for f in guards { - if !f.check(req.head()) { + let guard_ctx = req.guard_ctx(); + + for guard in guards { + if !guard.check(&guard_ctx) { return false; } } } + true }); diff --git a/src/service.rs b/src/service.rs index d5c381fa4..975556197 100644 --- a/src/service.rs +++ b/src/service.rs @@ -21,7 +21,7 @@ use cookie::{Cookie, ParseError as CookieParseError}; use crate::{ config::{AppConfig, AppService}, dev::ensure_leading_slash, - guard::Guard, + guard::{Guard, GuardContext}, info::ConnectionInfo, rmap::ResourceMap, Error, HttpRequest, HttpResponse, @@ -172,7 +172,7 @@ impl ServiceRequest { self.head().uri.path() } - /// Counterpart to [`HttpRequest::query_string`](super::HttpRequest::query_string()). + /// Counterpart to [`HttpRequest::query_string`]. #[inline] pub fn query_string(&self) -> &str { self.req.query_string() @@ -208,13 +208,13 @@ impl ServiceRequest { self.req.match_info() } - /// Counterpart to [`HttpRequest::match_name`](super::HttpRequest::match_name()). + /// Counterpart to [`HttpRequest::match_name`]. #[inline] pub fn match_name(&self) -> Option<&str> { self.req.match_name() } - /// Counterpart to [`HttpRequest::match_pattern`](super::HttpRequest::match_pattern()). + /// Counterpart to [`HttpRequest::match_pattern`]. #[inline] pub fn match_pattern(&self) -> Option { self.req.match_pattern() @@ -238,7 +238,7 @@ impl ServiceRequest { self.req.app_config() } - /// Counterpart to [`HttpRequest::app_data`](super::HttpRequest::app_data()). + /// Counterpart to [`HttpRequest::app_data`]. #[inline] pub fn app_data(&self) -> Option<&T> { for container in self.req.inner.app_data.iter().rev() { @@ -250,19 +250,33 @@ impl ServiceRequest { None } - /// Counterpart to [`HttpRequest::conn_data`](super::HttpRequest::conn_data()). + /// Counterpart to [`HttpRequest::conn_data`]. #[inline] pub fn conn_data(&self) -> Option<&T> { self.req.conn_data() } + /// Counterpart to [`HttpRequest::req_data`]. + #[inline] + pub fn req_data(&self) -> Ref<'_, Extensions> { + self.req.req_data() + } + + /// Counterpart to [`HttpRequest::req_data_mut`]. + #[inline] + pub fn req_data_mut(&self) -> RefMut<'_, Extensions> { + self.req.req_data_mut() + } + #[cfg(feature = "cookies")] + #[inline] pub fn cookies(&self) -> Result>>, CookieParseError> { self.req.cookies() } /// Return request cookie. #[cfg(feature = "cookies")] + #[inline] pub fn cookie(&self, name: &str) -> Option> { self.req.cookie(name) } @@ -283,6 +297,14 @@ impl ServiceRequest { .app_data .push(extensions); } + + /// Creates a context object for use with a [guard](crate::guard). + /// + /// Useful if you are implementing + #[inline] + pub fn guard_ctx(&self) -> GuardContext<'_> { + GuardContext { req: self } + } } impl Resource for ServiceRequest {