From 3650f6d7b888c823c2af60c507904e4700c77b3a Mon Sep 17 00:00:00 2001
From: Anton Lazarev <antonok35@gmail.com>
Date: Thu, 18 Jul 2019 21:28:43 -0700
Subject: [PATCH] Re-implement Host predicate (#989)

* update HostGuard implementation

* update/add tests for new HostGuard implementation
---
 src/guard.rs | 166 +++++++++++++++++++++++++++++++++++----------------
 1 file changed, 116 insertions(+), 50 deletions(-)

diff --git a/src/guard.rs b/src/guard.rs
index 6522a984..6fd6d1d2 100644
--- a/src/guard.rs
+++ b/src/guard.rs
@@ -26,7 +26,7 @@
 //! ```
 
 #![allow(non_snake_case)]
-use actix_http::http::{self, header, HttpTryFrom};
+use actix_http::http::{self, header, HttpTryFrom, uri::Uri};
 use actix_http::RequestHead;
 
 /// Trait defines resource guards. Guards are used for routes selection.
@@ -256,45 +256,68 @@ impl Guard for HeaderGuard {
     }
 }
 
-// /// Return predicate that matches if request contains specified Host name.
-// ///
-// /// ```rust,ignore
-// /// # extern crate actix_web;
-// /// use actix_web::{pred, App, HttpResponse};
-// ///
-// /// fn main() {
-// ///     App::new().resource("/index.html", |r| {
-// ///         r.route()
-// ///             .guard(pred::Host("www.rust-lang.org"))
-// ///             .f(|_| HttpResponse::MethodNotAllowed())
-// ///     });
-// /// }
-// /// ```
-// pub fn Host<H: AsRef<str>>(host: H) -> HostGuard {
-//     HostGuard(host.as_ref().to_string(), None)
-// }
+/// Return predicate that matches if request contains specified Host name.
+///
+/// ```rust,ignore
+/// # extern crate actix_web;
+/// use actix_web::{guard::Host, App, HttpResponse};
+///
+/// fn main() {
+///     App::new().resource("/index.html", |r| {
+///         r.route()
+///             .guard(Host("www.rust-lang.org"))
+///             .f(|_| HttpResponse::MethodNotAllowed())
+///     });
+/// }
+/// ```
+pub fn Host<H: AsRef<str>>(host: H) -> HostGuard {
+    HostGuard(host.as_ref().to_string(), None)
+}
 
-// #[doc(hidden)]
-// pub struct HostGuard(String, Option<String>);
+fn get_host_uri(req: &RequestHead) -> Option<Uri> {
+    use core::str::FromStr;
+    let host_value = req.headers.get(header::HOST)?;
+    let host = host_value.to_str().ok()?;
+    let uri = Uri::from_str(host).ok()?;
+    Some(uri)
+}
 
-// impl HostGuard {
-//     /// Set reuest scheme to match
-//     pub fn scheme<H: AsRef<str>>(&mut self, scheme: H) {
-//         self.1 = Some(scheme.as_ref().to_string())
-//     }
-// }
+#[doc(hidden)]
+pub struct HostGuard(String, Option<String>);
 
-// impl Guard for HostGuard {
-//     fn check(&self, _req: &RequestHead) -> bool {
-//         // let info = req.connection_info();
-//         // if let Some(ref scheme) = self.1 {
-//         //     self.0 == info.host() && scheme == info.scheme()
-//         // } else {
-//         //     self.0 == info.host()
-//         // }
-//         false
-//     }
-// }
+impl HostGuard {
+    /// Set request scheme to match
+    pub fn scheme<H: AsRef<str>>(mut self, scheme: H) -> HostGuard {
+        self.1 = Some(scheme.as_ref().to_string());
+        self
+    }
+}
+
+impl Guard for HostGuard {
+    fn check(&self, req: &RequestHead) -> bool {
+        let req_host_uri = if let Some(uri) = get_host_uri(req) {
+            uri
+        } else {
+            return false;
+        };
+
+        if let Some(uri_host) = req_host_uri.host() {
+            if self.0 != uri_host {
+                return false;
+            }
+        } else {
+            return false;
+        }
+
+        if let Some(ref scheme) = self.1 {
+            if let Some(ref req_host_uri_scheme) = req_host_uri.scheme_str() {
+                return scheme == req_host_uri_scheme;
+            }
+        }
+
+        true
+    }
+}
 
 #[cfg(test)]
 mod tests {
@@ -318,21 +341,64 @@ mod tests {
         assert!(!pred.check(req.head()));
     }
 
-    // #[test]
-    // fn test_host() {
-    //     let req = TestServiceRequest::default()
-    //         .header(
-    //             header::HOST,
-    //             header::HeaderValue::from_static("www.rust-lang.org"),
-    //         )
-    //         .request();
+    #[test]
+    fn test_host() {
+        let req = TestRequest::default()
+            .header(
+                header::HOST,
+                header::HeaderValue::from_static("www.rust-lang.org"),
+            )
+            .to_http_request();
 
-    //     let pred = Host("www.rust-lang.org");
-    //     assert!(pred.check(&req));
+        let pred = Host("www.rust-lang.org");
+        assert!(pred.check(req.head()));
 
-    //     let pred = Host("localhost");
-    //     assert!(!pred.check(&req));
-    // }
+        let pred = Host("www.rust-lang.org").scheme("https");
+        assert!(pred.check(req.head()));
+
+        let pred = Host("blog.rust-lang.org");
+        assert!(!pred.check(req.head()));
+
+        let pred = Host("blog.rust-lang.org").scheme("https");
+        assert!(!pred.check(req.head()));
+
+        let pred = Host("crates.io");
+        assert!(!pred.check(req.head()));
+
+        let pred = Host("localhost");
+        assert!(!pred.check(req.head()));
+    }
+
+    #[test]
+    fn test_host_scheme() {
+        let req = TestRequest::default()
+            .header(
+                header::HOST,
+                header::HeaderValue::from_static("https://www.rust-lang.org"),
+            )
+            .to_http_request();
+
+        let pred = Host("www.rust-lang.org").scheme("https");
+        assert!(pred.check(req.head()));
+
+        let pred = Host("www.rust-lang.org");
+        assert!(pred.check(req.head()));
+
+        let pred = Host("www.rust-lang.org").scheme("http");
+        assert!(!pred.check(req.head()));
+
+        let pred = Host("blog.rust-lang.org");
+        assert!(!pred.check(req.head()));
+
+        let pred = Host("blog.rust-lang.org").scheme("https");
+        assert!(!pred.check(req.head()));
+
+        let pred = Host("crates.io").scheme("https");
+        assert!(!pred.check(req.head()));
+
+        let pred = Host("localhost");
+        assert!(!pred.check(req.head()));
+    }
 
     #[test]
     fn test_methods() {