rework Guard trait

This commit is contained in:
Rob Ede 2021-12-27 22:03:28 +00:00
parent 36193b0a50
commit 53ae8ec1ba
No known key found for this signature in database
GPG Key ID: 97C636207D3EF933
9 changed files with 267 additions and 186 deletions

View File

@ -1,8 +1,8 @@
use std::{fmt, io, ops::Deref, path::PathBuf, rc::Rc}; use std::{fmt, io, ops::Deref, path::PathBuf, rc::Rc};
use actix_service::Service;
use actix_web::{ use actix_web::{
dev::{ServiceRequest, ServiceResponse}, body::BoxBody,
dev::{Service, ServiceRequest, ServiceResponse},
error::Error, error::Error,
guard::Guard, guard::Guard,
http::{header, Method}, http::{header, Method},
@ -94,7 +94,7 @@ impl fmt::Debug for FilesService {
} }
impl Service<ServiceRequest> for FilesService { impl Service<ServiceRequest> for FilesService {
type Response = ServiceResponse; type Response = ServiceResponse<BoxBody>;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
@ -103,7 +103,7 @@ impl Service<ServiceRequest> for FilesService {
fn call(&self, req: ServiceRequest) -> Self::Future { fn call(&self, req: ServiceRequest) -> Self::Future {
let is_method_valid = if let Some(guard) = &self.guards { let is_method_valid = if let Some(guard) = &self.guards {
// execute user defined guards // execute user defined guards
(**guard).check(req.head()) (**guard).check(&req.guard_ctx())
} else { } else {
// default behavior // default behavior
matches!(*req.method(), Method::HEAD | Method::GET) matches!(*req.method(), Method::HEAD | Method::GET)

View File

@ -1,4 +1,4 @@
use std::{cell::RefCell, rc::Rc}; use std::{cell::RefCell, ops, rc::Rc};
use bitflags::bitflags; use bitflags::bitflags;
@ -49,7 +49,7 @@ impl<T: Head> Message<T> {
} }
} }
impl<T: Head> std::ops::Deref for Message<T> { impl<T: Head> ops::Deref for Message<T> {
type Target = T; type Target = T;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
@ -57,7 +57,7 @@ impl<T: Head> std::ops::Deref for Message<T> {
} }
} }
impl<T: Head> std::ops::DerefMut for Message<T> { impl<T: Head> ops::DerefMut for Message<T> {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
Rc::get_mut(&mut self.head).expect("Multiple copies exist") Rc::get_mut(&mut self.head).expect("Multiple copies exist")
} }

View File

@ -1,14 +1,16 @@
use std::{cell::RefCell, mem, rc::Rc}; use std::{cell::RefCell, mem, rc::Rc};
use actix_http::{Extensions, Request}; use actix_http::Request;
use actix_router::{Path, ResourceDef, Router, Url}; use actix_router::{Path, ResourceDef, Router, Url};
use actix_service::{boxed, fn_service, Service, ServiceFactory}; use actix_service::{boxed, fn_service, Service, ServiceFactory};
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use futures_util::future::join_all; use futures_util::future::join_all;
use crate::{ use crate::{
body::BoxBody,
config::{AppConfig, AppService}, config::{AppConfig, AppService},
data::FnDataFactory, data::FnDataFactory,
dev::Extensions,
guard::Guard, guard::Guard,
request::{HttpRequest, HttpRequestPool}, request::{HttpRequest, HttpRequestPool},
rmap::ResourceMap, rmap::ResourceMap,
@ -297,7 +299,7 @@ pub struct AppRouting {
} }
impl Service<ServiceRequest> for AppRouting { impl Service<ServiceRequest> for AppRouting {
type Response = ServiceResponse; type Response = ServiceResponse<BoxBody>;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
@ -306,12 +308,15 @@ impl Service<ServiceRequest> for AppRouting {
fn call(&self, mut req: ServiceRequest) -> Self::Future { fn call(&self, mut req: ServiceRequest) -> Self::Future {
let res = self.router.recognize_fn(&mut req, |req, guards| { let res = self.router.recognize_fn(&mut req, |req, guards| {
if let Some(ref guards) = guards { if let Some(ref guards) = guards {
for f in guards { let guard_ctx = req.guard_ctx();
if !f.check(req.head()) {
for guard in guards {
if !guard.check(&guard_ctx) {
return false; return false;
} }
} }
} }
true true
}); });

View File

@ -3,6 +3,16 @@
//! Most users will not have to interact with the types in this module, but it is useful for those //! Most users will not have to interact with the types in this module, but it is useful for those
//! writing extractors, middleware, libraries, or interacting with the service API directly. //! writing extractors, middleware, libraries, or interacting with the service API directly.
pub use actix_http::{Extensions, Payload, RequestHead, Response, ResponseHead};
pub use actix_router::{Path, ResourceDef, ResourcePath, Url};
pub use actix_server::{Server, ServerHandle};
pub use actix_service::{
always_ready, fn_factory, fn_service, forward_ready, Service, ServiceFactory, Transform,
};
#[cfg(feature = "__compress")]
pub use actix_http::encoding::Decoder as Decompress;
pub use crate::config::{AppConfig, AppService}; pub use crate::config::{AppConfig, AppService};
#[doc(hidden)] #[doc(hidden)]
pub use crate::handler::Handler; pub use crate::handler::Handler;
@ -14,16 +24,6 @@ pub use crate::types::form::UrlEncoded;
pub use crate::types::json::JsonBody; pub use crate::types::json::JsonBody;
pub use crate::types::readlines::Readlines; pub use crate::types::readlines::Readlines;
pub use actix_http::{Extensions, Payload, RequestHead, Response, ResponseHead};
pub use actix_router::{Path, ResourceDef, ResourcePath, Url};
pub use actix_server::{Server, ServerHandle};
pub use actix_service::{
always_ready, fn_factory, fn_service, forward_ready, Service, ServiceFactory, Transform,
};
#[cfg(feature = "__compress")]
pub use actix_http::encoding::Decoder as Decompress;
use crate::http::header::ContentEncoding; use crate::http::header::ContentEncoding;
use actix_router::Patterns; use actix_router::Patterns;
@ -46,7 +46,6 @@ pub(crate) fn ensure_leading_slash(mut patterns: Patterns) -> Patterns {
patterns patterns
} }
struct Enc(ContentEncoding);
/// Helper trait that allows to set specific encoding for response. /// Helper trait that allows to set specific encoding for response.
pub trait BodyEncoding { pub trait BodyEncoding {
@ -70,6 +69,8 @@ impl BodyEncoding for actix_http::ResponseBuilder {
} }
} }
struct Enc(ContentEncoding);
impl<B> BodyEncoding for actix_http::Response<B> { impl<B> BodyEncoding for actix_http::Response<B> {
fn get_encoding(&self) -> Option<ContentEncoding> { fn get_encoding(&self) -> Option<ContentEncoding> {
self.extensions().get::<Enc>().map(|enc| enc.0) self.extensions().get::<Enc>().map(|enc| enc.0)

View File

@ -1,16 +1,14 @@
//! Route match guards. //! Route guards.
//! //!
//! Guards are one of the ways how actix-web router chooses a //! Guards are one of the ways how actix-web router chooses a handler service. In essence it is just
//! handler service. In essence it is just a function that accepts a //! a function that accepts a reference to a `RequestHead` instance and returns a boolean. It is
//! reference to a `RequestHead` instance and returns a boolean. //! possible to add guards to *scopes*, *resources* and *routes*. Actix provide several guards by
//! It is possible to add guards to *scopes*, *resources* //! default, like various HTTP methods, header, etc. To become a guard, type must implement the
//! and *routes*. Actix provide several guards by default, like various //! `Guard` trait. Simple functions could be guards as well.
//! http methods, header, etc. To become a guard, type must implement `Guard`
//! trait. Simple functions could be guards as well.
//! //!
//! Guards can not modify the request object. But it is possible //! Guards can not modify the request object. But it is possible to store extra attributes on a
//! to store extra attributes on a request by using the `Extensions` container. //! request by using the `Extensions` container. Extensions containers are available via the
//! Extensions containers are available via the `RequestHead::extensions()` method. //! `RequestHead::extensions()` method.
//! //!
//! ``` //! ```
//! use actix_web::{web, http, dev, guard, App, HttpResponse}; //! use actix_web::{web, http, dev, guard, App, HttpResponse};
@ -18,31 +16,56 @@
//! App::new().service(web::resource("/index.html").route( //! App::new().service(web::resource("/index.html").route(
//! web::route() //! web::route()
//! .guard(guard::Post()) //! .guard(guard::Post())
//! .guard(guard::fn_guard(|head| head.method == http::Method::GET)) //! .guard(guard::fn_guard(|ctx| ctx.head().method == http::Method::GET))
//! .to(|| HttpResponse::MethodNotAllowed())) //! .to(|| HttpResponse::MethodNotAllowed()))
//! ); //! );
//! ``` //! ```
#![allow(non_snake_case)] use std::{
cell::{Ref, RefMut},
convert::TryFrom,
rc::Rc,
};
use std::rc::Rc; use actix_http::{header, uri::Uri, Extensions, Method as HttpMethod, RequestHead};
use std::{convert::TryFrom, ops::Deref};
use actix_http::{header, uri::Uri, Method as HttpMethod, RequestHead}; use crate::service::ServiceRequest;
#[derive(Debug)]
pub struct GuardContext<'a> {
pub(crate) req: &'a ServiceRequest,
}
impl<'a> GuardContext<'a> {
#[inline]
pub fn head(&self) -> &RequestHead {
self.req.head()
}
#[inline]
pub fn req_data(&self) -> Ref<'a, Extensions> {
self.req.req_data()
}
#[inline]
pub fn req_data_mut(&self) -> RefMut<'a, Extensions> {
self.req.req_data_mut()
}
}
/// Trait defines resource guards. Guards are used for route selection. /// Trait defines resource guards. Guards are used for route selection.
/// ///
/// Guards can not modify the request object. But it is possible /// Guards can not modify the request object. But it is possible to store extra attributes on a
/// to store extra attributes on a request by using the `Extensions` container. /// request by using the `Extensions` container. Extensions containers are available via the
/// Extensions containers are available via the `RequestHead::extensions()` method. /// `RequestHead::extensions()` method.
pub trait Guard { pub trait Guard {
/// Check if request matches predicate /// Check if request matches predicate
fn check(&self, request: &RequestHead) -> bool; fn check(&self, ctx: &GuardContext<'_>) -> bool;
} }
impl Guard for Rc<dyn Guard> { impl Guard for Rc<dyn Guard> {
fn check(&self, request: &RequestHead) -> bool { fn check(&self, ctx: &GuardContext<'_>) -> bool {
self.deref().check(request) (**self).check(ctx)
} }
} }
@ -51,39 +74,40 @@ impl Guard for Rc<dyn Guard> {
/// ``` /// ```
/// use actix_web::{guard, web, App, HttpResponse}; /// use actix_web::{guard, web, App, HttpResponse};
/// ///
/// App::new().service(web::resource("/index.html").route( /// App::new().service(
/// web::route() /// web::resource("/index.html").route(
/// .guard( /// web::route()
/// guard::fn_guard( /// .guard(guard::fn_guard(|ctx| {
/// |req| req.headers() /// ctx.head().headers().contains_key("content-type")
/// .contains_key("content-type"))) /// }))
/// .to(|| HttpResponse::MethodNotAllowed())) /// .to(|| HttpResponse::MethodNotAllowed()),
/// ); /// ),
/// );
/// ``` /// ```
pub fn fn_guard<F>(f: F) -> impl Guard pub fn fn_guard<F>(f: F) -> impl Guard
where where
F: Fn(&RequestHead) -> bool, F: Fn(&GuardContext<'_>) -> bool,
{ {
FnGuard(f) FnGuard(f)
} }
struct FnGuard<F: Fn(&RequestHead) -> bool>(F); struct FnGuard<F: Fn(&GuardContext<'_>) -> bool>(F);
impl<F> Guard for FnGuard<F> impl<F> Guard for FnGuard<F>
where where
F: Fn(&RequestHead) -> bool, F: Fn(&GuardContext<'_>) -> bool,
{ {
fn check(&self, head: &RequestHead) -> bool { fn check(&self, ctx: &GuardContext<'_>) -> bool {
(self.0)(head) (self.0)(ctx)
} }
} }
impl<F> Guard for F impl<F> Guard for F
where where
F: Fn(&RequestHead) -> bool, F: Fn(&GuardContext<'_>) -> bool,
{ {
fn check(&self, head: &RequestHead) -> bool { fn check(&self, ctx: &GuardContext<'_>) -> bool {
(self)(head) (self)(ctx)
} }
} }
@ -98,28 +122,34 @@ where
/// .to(|| HttpResponse::MethodNotAllowed())) /// .to(|| HttpResponse::MethodNotAllowed()))
/// ); /// );
/// ``` /// ```
#[allow(non_snake_case)]
pub fn Any<F: Guard + 'static>(guard: F) -> AnyGuard { pub fn Any<F: Guard + 'static>(guard: F) -> AnyGuard {
AnyGuard(vec![Box::new(guard)]) AnyGuard {
guards: vec![Box::new(guard)],
}
} }
/// Matches any of supplied guards. /// Matches any of supplied guards.
pub struct AnyGuard(Vec<Box<dyn Guard>>); pub struct AnyGuard {
guards: Vec<Box<dyn Guard>>,
}
impl AnyGuard { impl AnyGuard {
/// Add guard to a list of guards to check /// Add guard to a list of guards to check
pub fn or<F: Guard + 'static>(mut self, guard: F) -> Self { pub fn or<F: Guard + 'static>(mut self, guard: F) -> Self {
self.0.push(Box::new(guard)); self.guards.push(Box::new(guard));
self self
} }
} }
impl Guard for AnyGuard { impl Guard for AnyGuard {
fn check(&self, req: &RequestHead) -> bool { fn check(&self, ctx: &GuardContext<'_>) -> bool {
for p in &self.0 { for guard in &self.guards {
if p.check(req) { if guard.check(ctx) {
return true; return true;
} }
} }
false false
} }
} }
@ -136,25 +166,30 @@ impl Guard for AnyGuard {
/// .to(|| HttpResponse::MethodNotAllowed())) /// .to(|| HttpResponse::MethodNotAllowed()))
/// ); /// );
/// ``` /// ```
#[allow(non_snake_case)]
pub fn All<F: Guard + 'static>(guard: F) -> AllGuard { pub fn All<F: Guard + 'static>(guard: F) -> AllGuard {
AllGuard(vec![Box::new(guard)]) AllGuard {
guards: vec![Box::new(guard)],
}
} }
/// Matches if all of supplied guards. /// Matches if all of supplied guards.
pub struct AllGuard(Vec<Box<dyn Guard>>); pub struct AllGuard {
guards: Vec<Box<dyn Guard>>,
}
impl AllGuard { impl AllGuard {
/// Add new guard to the list of guards to check /// Add new guard to the list of guards to check
pub fn and<F: Guard + 'static>(mut self, guard: F) -> Self { pub fn and<F: Guard + 'static>(mut self, guard: F) -> Self {
self.0.push(Box::new(guard)); self.guards.push(Box::new(guard));
self self
} }
} }
impl Guard for AllGuard { impl Guard for AllGuard {
fn check(&self, request: &RequestHead) -> bool { fn check(&self, ctx: &GuardContext<'_>) -> bool {
for p in &self.0 { for guard in &self.guards {
if !p.check(request) { if !guard.check(ctx) {
return false; return false;
} }
} }
@ -163,82 +198,91 @@ impl Guard for AllGuard {
} }
/// Return guard that matches if supplied guard does not match. /// Return guard that matches if supplied guard does not match.
pub fn Not<F: Guard + 'static>(guard: F) -> NotGuard { #[allow(non_snake_case)]
pub fn Not<F: Guard + 'static>(guard: F) -> impl Guard {
NotGuard(Box::new(guard)) NotGuard(Box::new(guard))
} }
#[doc(hidden)] struct NotGuard(Box<dyn Guard>);
pub struct NotGuard(Box<dyn Guard>);
impl Guard for NotGuard { impl Guard for NotGuard {
fn check(&self, request: &RequestHead) -> bool { fn check(&self, ctx: &GuardContext<'_>) -> bool {
!self.0.check(request) !self.0.check(ctx)
} }
} }
/// HTTP method guard. /// HTTP method guard.
#[doc(hidden)] struct MethodGuard(HttpMethod);
pub struct MethodGuard(HttpMethod);
impl Guard for MethodGuard { impl Guard for MethodGuard {
fn check(&self, request: &RequestHead) -> bool { fn check(&self, ctx: &GuardContext<'_>) -> bool {
request.method == self.0 ctx.head().method == self.0
} }
} }
/// Guard to match *GET* HTTP method. /// Guard to match *GET* HTTP method.
pub fn Get() -> MethodGuard { #[allow(non_snake_case)]
pub fn Get() -> impl Guard {
MethodGuard(HttpMethod::GET) MethodGuard(HttpMethod::GET)
} }
/// Predicate to match *POST* HTTP method. /// Predicate to match *POST* HTTP method.
pub fn Post() -> MethodGuard { #[allow(non_snake_case)]
pub fn Post() -> impl Guard {
MethodGuard(HttpMethod::POST) MethodGuard(HttpMethod::POST)
} }
/// Predicate to match *PUT* HTTP method. /// Predicate to match *PUT* HTTP method.
pub fn Put() -> MethodGuard { #[allow(non_snake_case)]
pub fn Put() -> impl Guard {
MethodGuard(HttpMethod::PUT) MethodGuard(HttpMethod::PUT)
} }
/// Predicate to match *DELETE* HTTP method. /// Predicate to match *DELETE* HTTP method.
pub fn Delete() -> MethodGuard { #[allow(non_snake_case)]
pub fn Delete() -> impl Guard {
MethodGuard(HttpMethod::DELETE) MethodGuard(HttpMethod::DELETE)
} }
/// Predicate to match *HEAD* HTTP method. /// Predicate to match *HEAD* HTTP method.
pub fn Head() -> MethodGuard { #[allow(non_snake_case)]
pub fn Head() -> impl Guard {
MethodGuard(HttpMethod::HEAD) MethodGuard(HttpMethod::HEAD)
} }
/// Predicate to match *OPTIONS* HTTP method. /// Predicate to match *OPTIONS* HTTP method.
pub fn Options() -> MethodGuard { #[allow(non_snake_case)]
pub fn Options() -> impl Guard {
MethodGuard(HttpMethod::OPTIONS) MethodGuard(HttpMethod::OPTIONS)
} }
/// Predicate to match *CONNECT* HTTP method. /// Predicate to match *CONNECT* HTTP method.
pub fn Connect() -> MethodGuard { #[allow(non_snake_case)]
pub fn Connect() -> impl Guard {
MethodGuard(HttpMethod::CONNECT) MethodGuard(HttpMethod::CONNECT)
} }
/// Predicate to match *PATCH* HTTP method. /// Predicate to match *PATCH* HTTP method.
pub fn Patch() -> MethodGuard { #[allow(non_snake_case)]
pub fn Patch() -> impl Guard {
MethodGuard(HttpMethod::PATCH) MethodGuard(HttpMethod::PATCH)
} }
/// Predicate to match *TRACE* HTTP method. /// Predicate to match *TRACE* HTTP method.
pub fn Trace() -> MethodGuard { #[allow(non_snake_case)]
pub fn Trace() -> impl Guard {
MethodGuard(HttpMethod::TRACE) MethodGuard(HttpMethod::TRACE)
} }
/// Predicate to match specified HTTP method. /// Predicate to match specified HTTP method.
pub fn Method(method: HttpMethod) -> MethodGuard { #[allow(non_snake_case)]
pub fn Method(method: HttpMethod) -> impl Guard {
MethodGuard(method) MethodGuard(method)
} }
/// Return predicate that matches if request contains specified header and /// Return predicate that matches if request contains specified header and value.
/// value. #[allow(non_snake_case)]
pub fn Header(name: &'static str, value: &'static str) -> HeaderGuard { pub fn Header(name: &'static str, value: &'static str) -> impl Guard {
HeaderGuard( HeaderGuard(
header::HeaderName::try_from(name).unwrap(), header::HeaderName::try_from(name).unwrap(),
header::HeaderValue::from_static(value), header::HeaderValue::from_static(value),
@ -246,13 +290,14 @@ pub fn Header(name: &'static str, value: &'static str) -> HeaderGuard {
} }
#[doc(hidden)] #[doc(hidden)]
pub struct HeaderGuard(header::HeaderName, header::HeaderValue); struct HeaderGuard(header::HeaderName, header::HeaderValue);
impl Guard for HeaderGuard { impl Guard for HeaderGuard {
fn check(&self, req: &RequestHead) -> bool { fn check(&self, ctx: &GuardContext<'_>) -> bool {
if let Some(val) = req.headers.get(&self.0) { if let Some(val) = ctx.head().headers.get(&self.0) {
return val == self.1; return val == self.1;
} }
false false
} }
} }
@ -268,48 +313,54 @@ impl Guard for HeaderGuard {
/// .to(|| HttpResponse::MethodNotAllowed()) /// .to(|| HttpResponse::MethodNotAllowed())
/// ); /// );
/// ``` /// ```
#[allow(non_snake_case)]
pub fn Host<H: AsRef<str>>(host: H) -> HostGuard { pub fn Host<H: AsRef<str>>(host: H) -> HostGuard {
HostGuard(host.as_ref().to_string(), None) HostGuard {
host: host.as_ref().to_string(),
scheme: None,
}
} }
fn get_host_uri(req: &RequestHead) -> Option<Uri> { fn get_host_uri(req: &RequestHead) -> Option<Uri> {
use core::str::FromStr;
req.headers req.headers
.get(header::HOST) .get(header::HOST)
.and_then(|host_value| host_value.to_str().ok()) .and_then(|host_value| host_value.to_str().ok())
.or_else(|| req.uri.host()) .or_else(|| req.uri.host())
.map(|host: &str| Uri::from_str(host).ok()) .map(|host| host.parse().ok())
.and_then(|host_success| host_success) .and_then(|host_success| host_success)
} }
#[doc(hidden)] #[doc(hidden)]
pub struct HostGuard(String, Option<String>); pub struct HostGuard {
host: String,
scheme: Option<String>,
}
impl HostGuard { impl HostGuard {
/// Set request scheme to match /// Set request scheme to match
pub fn scheme<H: AsRef<str>>(mut self, scheme: H) -> HostGuard { pub fn scheme<H: AsRef<str>>(mut self, scheme: H) -> HostGuard {
self.1 = Some(scheme.as_ref().to_string()); self.scheme = Some(scheme.as_ref().to_string());
self self
} }
} }
impl Guard for HostGuard { impl Guard for HostGuard {
fn check(&self, req: &RequestHead) -> bool { fn check(&self, ctx: &GuardContext<'_>) -> bool {
let req_host_uri = if let Some(uri) = get_host_uri(req) { let req_host_uri = if let Some(uri) = get_host_uri(ctx.head()) {
uri uri
} else { } else {
return false; return false;
}; };
if let Some(uri_host) = req_host_uri.host() { if let Some(uri_host) = req_host_uri.host() {
if self.0 != uri_host { if self.host != uri_host {
return false; return false;
} }
} else { } else {
return false; return false;
} }
if let Some(ref scheme) = self.1 { if let Some(ref scheme) = self.scheme {
if let Some(ref req_host_uri_scheme) = req_host_uri.scheme_str() { if let Some(ref req_host_uri_scheme) = req_host_uri.scheme_str() {
return scheme == req_host_uri_scheme; return scheme == req_host_uri_scheme;
} }
@ -330,16 +381,16 @@ mod tests {
fn test_header() { fn test_header() {
let req = TestRequest::default() let req = TestRequest::default()
.insert_header((header::TRANSFER_ENCODING, "chunked")) .insert_header((header::TRANSFER_ENCODING, "chunked"))
.to_http_request(); .to_srv_request();
let pred = Header("transfer-encoding", "chunked"); let pred = Header("transfer-encoding", "chunked");
assert!(pred.check(req.head())); assert!(pred.check(&req.guard_ctx()));
let pred = Header("transfer-encoding", "other"); let pred = Header("transfer-encoding", "other");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
let pred = Header("content-type", "other"); let pred = Header("content-type", "other");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
} }
#[test] #[test]
@ -349,25 +400,25 @@ mod tests {
header::HOST, header::HOST,
header::HeaderValue::from_static("www.rust-lang.org"), header::HeaderValue::from_static("www.rust-lang.org"),
)) ))
.to_http_request(); .to_srv_request();
let pred = Host("www.rust-lang.org"); let pred = Host("www.rust-lang.org");
assert!(pred.check(req.head())); assert!(pred.check(&req.guard_ctx()));
let pred = Host("www.rust-lang.org").scheme("https"); let pred = Host("www.rust-lang.org").scheme("https");
assert!(pred.check(req.head())); assert!(pred.check(&req.guard_ctx()));
let pred = Host("blog.rust-lang.org"); let pred = Host("blog.rust-lang.org");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
let pred = Host("blog.rust-lang.org").scheme("https"); let pred = Host("blog.rust-lang.org").scheme("https");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
let pred = Host("crates.io"); let pred = Host("crates.io");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
let pred = Host("localhost"); let pred = Host("localhost");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
} }
#[test] #[test]
@ -377,121 +428,117 @@ mod tests {
header::HOST, header::HOST,
header::HeaderValue::from_static("https://www.rust-lang.org"), header::HeaderValue::from_static("https://www.rust-lang.org"),
)) ))
.to_http_request(); .to_srv_request();
let pred = Host("www.rust-lang.org").scheme("https"); let pred = Host("www.rust-lang.org").scheme("https");
assert!(pred.check(req.head())); assert!(pred.check(&req.guard_ctx()));
let pred = Host("www.rust-lang.org"); let pred = Host("www.rust-lang.org");
assert!(pred.check(req.head())); assert!(pred.check(&req.guard_ctx()));
let pred = Host("www.rust-lang.org").scheme("http"); let pred = Host("www.rust-lang.org").scheme("http");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
let pred = Host("blog.rust-lang.org"); let pred = Host("blog.rust-lang.org");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
let pred = Host("blog.rust-lang.org").scheme("https"); let pred = Host("blog.rust-lang.org").scheme("https");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
let pred = Host("crates.io").scheme("https"); let pred = Host("crates.io").scheme("https");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
let pred = Host("localhost"); let pred = Host("localhost");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
} }
#[test] #[test]
fn test_host_without_header() { fn test_host_without_header() {
let req = TestRequest::default() let req = TestRequest::default()
.uri("www.rust-lang.org") .uri("www.rust-lang.org")
.to_http_request(); .to_srv_request();
let pred = Host("www.rust-lang.org"); let pred = Host("www.rust-lang.org");
assert!(pred.check(req.head())); assert!(pred.check(&req.guard_ctx()));
let pred = Host("www.rust-lang.org").scheme("https"); let pred = Host("www.rust-lang.org").scheme("https");
assert!(pred.check(req.head())); assert!(pred.check(&req.guard_ctx()));
let pred = Host("blog.rust-lang.org"); let pred = Host("blog.rust-lang.org");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
let pred = Host("blog.rust-lang.org").scheme("https"); let pred = Host("blog.rust-lang.org").scheme("https");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
let pred = Host("crates.io"); let pred = Host("crates.io");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
let pred = Host("localhost"); let pred = Host("localhost");
assert!(!pred.check(req.head())); assert!(!pred.check(&req.guard_ctx()));
} }
#[test] #[test]
fn test_methods() { fn test_methods() {
let req = TestRequest::default().to_http_request(); let req = TestRequest::default().to_srv_request();
let req2 = TestRequest::default() let req2 = TestRequest::default().method(Method::POST).to_srv_request();
.method(Method::POST)
.to_http_request();
assert!(Get().check(req.head())); assert!(Get().check(&req.guard_ctx()));
assert!(!Get().check(req2.head())); assert!(!Get().check(&req2.guard_ctx()));
assert!(Post().check(req2.head())); assert!(Post().check(&req2.guard_ctx()));
assert!(!Post().check(req.head())); assert!(!Post().check(&req.guard_ctx()));
let r = TestRequest::default().method(Method::PUT).to_http_request(); let r = TestRequest::default().method(Method::PUT).to_srv_request();
assert!(Put().check(r.head())); assert!(Put().check(&r.guard_ctx()));
assert!(!Put().check(req.head())); assert!(!Put().check(&req.guard_ctx()));
let r = TestRequest::default() let r = TestRequest::default()
.method(Method::DELETE) .method(Method::DELETE)
.to_http_request(); .to_srv_request();
assert!(Delete().check(r.head())); assert!(Delete().check(&r.guard_ctx()));
assert!(!Delete().check(req.head())); assert!(!Delete().check(&req.guard_ctx()));
let r = TestRequest::default() let r = TestRequest::default().method(Method::HEAD).to_srv_request();
.method(Method::HEAD) assert!(Head().check(&r.guard_ctx()));
.to_http_request(); assert!(!Head().check(&req.guard_ctx()));
assert!(Head().check(r.head()));
assert!(!Head().check(req.head()));
let r = TestRequest::default() let r = TestRequest::default()
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_http_request(); .to_srv_request();
assert!(Options().check(r.head())); assert!(Options().check(&r.guard_ctx()));
assert!(!Options().check(req.head())); assert!(!Options().check(&req.guard_ctx()));
let r = TestRequest::default() let r = TestRequest::default()
.method(Method::CONNECT) .method(Method::CONNECT)
.to_http_request(); .to_srv_request();
assert!(Connect().check(r.head())); assert!(Connect().check(&r.guard_ctx()));
assert!(!Connect().check(req.head())); assert!(!Connect().check(&req.guard_ctx()));
let r = TestRequest::default() let r = TestRequest::default()
.method(Method::PATCH) .method(Method::PATCH)
.to_http_request(); .to_srv_request();
assert!(Patch().check(r.head())); assert!(Patch().check(&r.guard_ctx()));
assert!(!Patch().check(req.head())); assert!(!Patch().check(&req.guard_ctx()));
let r = TestRequest::default() let r = TestRequest::default()
.method(Method::TRACE) .method(Method::TRACE)
.to_http_request(); .to_srv_request();
assert!(Trace().check(r.head())); assert!(Trace().check(&r.guard_ctx()));
assert!(!Trace().check(req.head())); assert!(!Trace().check(&req.guard_ctx()));
} }
#[test] #[test]
fn test_preds() { fn test_preds() {
let r = TestRequest::default() let r = TestRequest::default()
.method(Method::TRACE) .method(Method::TRACE)
.to_http_request(); .to_srv_request();
assert!(Not(Get()).check(r.head())); assert!(Not(Get()).check(&r.guard_ctx()));
assert!(!Not(Trace()).check(r.head())); assert!(!Not(Trace()).check(&r.guard_ctx()));
assert!(All(Trace()).and(Trace()).check(r.head())); assert!(All(Trace()).and(Trace()).check(&r.guard_ctx()));
assert!(!All(Get()).and(Trace()).check(r.head())); assert!(!All(Get()).and(Trace()).check(&r.guard_ctx()));
assert!(Any(Get()).or(Trace()).check(r.head())); assert!(Any(Get()).or(Trace()).check(&r.guard_ctx()));
assert!(!Any(Get()).or(Get()).check(r.head())); assert!(!Any(Get()).or(Get()).check(&r.guard_ctx()));
} }
} }

View File

@ -225,7 +225,7 @@ mod tests {
.service(web::resource("/v1/something").to(HttpResponse::Ok)) .service(web::resource("/v1/something").to(HttpResponse::Ok))
.service( .service(
web::resource("/v2/something") web::resource("/v2/something")
.guard(fn_guard(|req| req.uri.query() == Some("query=test"))) .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok), .to(HttpResponse::Ok),
), ),
) )
@ -261,7 +261,7 @@ mod tests {
.service(web::resource("/v1/something").to(HttpResponse::Ok)) .service(web::resource("/v1/something").to(HttpResponse::Ok))
.service( .service(
web::resource("/v2/something") web::resource("/v2/something")
.guard(fn_guard(|req| req.uri.query() == Some("query=test"))) .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok), .to(HttpResponse::Ok),
), ),
) )
@ -294,7 +294,7 @@ mod tests {
let app = init_service( let app = init_service(
App::new().wrap(NormalizePath(TrailingSlash::Trim)).service( App::new().wrap(NormalizePath(TrailingSlash::Trim)).service(
web::resource("/") web::resource("/")
.guard(fn_guard(|req| req.uri.query() == Some("query=test"))) .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok), .to(HttpResponse::Ok),
), ),
) )
@ -318,7 +318,7 @@ mod tests {
.service(web::resource("/v1/something/").to(HttpResponse::Ok)) .service(web::resource("/v1/something/").to(HttpResponse::Ok))
.service( .service(
web::resource("/v2/something/") web::resource("/v2/something/")
.guard(fn_guard(|req| req.uri.query() == Some("query=test"))) .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok), .to(HttpResponse::Ok),
), ),
) )
@ -353,7 +353,7 @@ mod tests {
.wrap(NormalizePath(TrailingSlash::Always)) .wrap(NormalizePath(TrailingSlash::Always))
.service( .service(
web::resource("/") web::resource("/")
.guard(fn_guard(|req| req.uri.query() == Some("query=test"))) .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok), .to(HttpResponse::Ok),
), ),
) )
@ -378,7 +378,7 @@ mod tests {
.service(web::resource("/v1/").to(HttpResponse::Ok)) .service(web::resource("/v1/").to(HttpResponse::Ok))
.service( .service(
web::resource("/v2/something") web::resource("/v2/something")
.guard(fn_guard(|req| req.uri.query() == Some("query=test"))) .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
.to(HttpResponse::Ok), .to(HttpResponse::Ok),
), ),
) )

View File

@ -65,9 +65,12 @@ pub struct RouteService {
} }
impl RouteService { impl RouteService {
// TODO: does this need to take &mut ?
pub fn check(&self, req: &mut ServiceRequest) -> bool { pub fn check(&self, req: &mut ServiceRequest) -> bool {
for f in self.guards.iter() { for guard in self.guards.iter() {
if !f.check(req.head()) { let guard_ctx = req.guard_ctx();
if !guard.check(&guard_ctx) {
return false; return false;
} }
} }

View File

@ -538,12 +538,15 @@ impl Service<ServiceRequest> for ScopeService {
fn call(&self, mut req: ServiceRequest) -> Self::Future { fn call(&self, mut req: ServiceRequest) -> Self::Future {
let res = self.router.recognize_fn(&mut req, |req, guards| { let res = self.router.recognize_fn(&mut req, |req, guards| {
if let Some(ref guards) = guards { if let Some(ref guards) = guards {
for f in guards { let guard_ctx = req.guard_ctx();
if !f.check(req.head()) {
for guard in guards {
if !guard.check(&guard_ctx) {
return false; return false;
} }
} }
} }
true true
}); });

View File

@ -21,7 +21,7 @@ use cookie::{Cookie, ParseError as CookieParseError};
use crate::{ use crate::{
config::{AppConfig, AppService}, config::{AppConfig, AppService},
dev::ensure_leading_slash, dev::ensure_leading_slash,
guard::Guard, guard::{Guard, GuardContext},
info::ConnectionInfo, info::ConnectionInfo,
rmap::ResourceMap, rmap::ResourceMap,
Error, HttpRequest, HttpResponse, Error, HttpRequest, HttpResponse,
@ -172,7 +172,7 @@ impl ServiceRequest {
self.head().uri.path() self.head().uri.path()
} }
/// Counterpart to [`HttpRequest::query_string`](super::HttpRequest::query_string()). /// Counterpart to [`HttpRequest::query_string`].
#[inline] #[inline]
pub fn query_string(&self) -> &str { pub fn query_string(&self) -> &str {
self.req.query_string() self.req.query_string()
@ -208,13 +208,13 @@ impl ServiceRequest {
self.req.match_info() self.req.match_info()
} }
/// Counterpart to [`HttpRequest::match_name`](super::HttpRequest::match_name()). /// Counterpart to [`HttpRequest::match_name`].
#[inline] #[inline]
pub fn match_name(&self) -> Option<&str> { pub fn match_name(&self) -> Option<&str> {
self.req.match_name() self.req.match_name()
} }
/// Counterpart to [`HttpRequest::match_pattern`](super::HttpRequest::match_pattern()). /// Counterpart to [`HttpRequest::match_pattern`].
#[inline] #[inline]
pub fn match_pattern(&self) -> Option<String> { pub fn match_pattern(&self) -> Option<String> {
self.req.match_pattern() self.req.match_pattern()
@ -238,7 +238,7 @@ impl ServiceRequest {
self.req.app_config() self.req.app_config()
} }
/// Counterpart to [`HttpRequest::app_data`](super::HttpRequest::app_data()). /// Counterpart to [`HttpRequest::app_data`].
#[inline] #[inline]
pub fn app_data<T: 'static>(&self) -> Option<&T> { pub fn app_data<T: 'static>(&self) -> Option<&T> {
for container in self.req.inner.app_data.iter().rev() { for container in self.req.inner.app_data.iter().rev() {
@ -250,19 +250,33 @@ impl ServiceRequest {
None None
} }
/// Counterpart to [`HttpRequest::conn_data`](super::HttpRequest::conn_data()). /// Counterpart to [`HttpRequest::conn_data`].
#[inline] #[inline]
pub fn conn_data<T: 'static>(&self) -> Option<&T> { pub fn conn_data<T: 'static>(&self) -> Option<&T> {
self.req.conn_data() self.req.conn_data()
} }
/// Counterpart to [`HttpRequest::req_data`].
#[inline]
pub fn req_data(&self) -> Ref<'_, Extensions> {
self.req.req_data()
}
/// Counterpart to [`HttpRequest::req_data_mut`].
#[inline]
pub fn req_data_mut(&self) -> RefMut<'_, Extensions> {
self.req.req_data_mut()
}
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
#[inline]
pub fn cookies(&self) -> Result<Ref<'_, Vec<Cookie<'static>>>, CookieParseError> { pub fn cookies(&self) -> Result<Ref<'_, Vec<Cookie<'static>>>, CookieParseError> {
self.req.cookies() self.req.cookies()
} }
/// Return request cookie. /// Return request cookie.
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
#[inline]
pub fn cookie(&self, name: &str) -> Option<Cookie<'static>> { pub fn cookie(&self, name: &str) -> Option<Cookie<'static>> {
self.req.cookie(name) self.req.cookie(name)
} }
@ -283,6 +297,14 @@ impl ServiceRequest {
.app_data .app_data
.push(extensions); .push(extensions);
} }
/// Creates a context object for use with a [guard](crate::guard).
///
/// Useful if you are implementing
#[inline]
pub fn guard_ctx(&self) -> GuardContext<'_> {
GuardContext { req: self }
}
} }
impl Resource<Url> for ServiceRequest { impl Resource<Url> for ServiceRequest {