From 136dac135245daec70b6b68a116dca9132b3680f Mon Sep 17 00:00:00 2001
From: James Wright <Digital-Chaos@users.noreply.github.com>
Date: Thu, 3 Jun 2021 03:28:09 +0100
Subject: [PATCH] Additional test coverage and tidyup (middleware::normalize)
 (#2243)

---
 src/middleware/normalize.rs | 216 ++++++++++++++++++++++++------------
 1 file changed, 148 insertions(+), 68 deletions(-)

diff --git a/src/middleware/normalize.rs b/src/middleware/normalize.rs
index ec6c2a34..cbed7871 100644
--- a/src/middleware/normalize.rs
+++ b/src/middleware/normalize.rs
@@ -189,6 +189,7 @@ mod tests {
     use super::*;
     use crate::{
         dev::ServiceRequest,
+        guard::fn_guard,
         test::{call_service, init_service, TestRequest},
         web, App, HttpResponse,
     };
@@ -199,37 +200,34 @@ mod tests {
             App::new()
                 .wrap(NormalizePath::default())
                 .service(web::resource("/").to(HttpResponse::Ok))
-                .service(web::resource("/v1/something").to(HttpResponse::Ok)),
+                .service(web::resource("/v1/something").to(HttpResponse::Ok))
+                .service(
+                    web::resource("/v2/something")
+                        .guard(fn_guard(|req| req.uri.query() == Some("query=test")))
+                        .to(HttpResponse::Ok),
+                ),
         )
         .await;
 
-        let req = TestRequest::with_uri("/").to_request();
-        let res = call_service(&app, req).await;
-        assert!(res.status().is_success());
+        let test_uris = vec![
+            "/",
+            "/?query=test",
+            "///",
+            "/v1//something",
+            "/v1//something////",
+            "//v1/something",
+            "//v1//////something",
+            "/v2//something?query=test",
+            "/v2//something////?query=test",
+            "//v2/something?query=test",
+            "//v2//////something?query=test",
+        ];
 
-        let req = TestRequest::with_uri("/?query=test").to_request();
-        let res = call_service(&app, req).await;
-        assert!(res.status().is_success());
-
-        let req = TestRequest::with_uri("///").to_request();
-        let res = call_service(&app, req).await;
-        assert!(res.status().is_success());
-
-        let req = TestRequest::with_uri("/v1//something////").to_request();
-        let res = call_service(&app, req).await;
-        assert!(res.status().is_success());
-
-        let req2 = TestRequest::with_uri("//v1/something").to_request();
-        let res2 = call_service(&app, req2).await;
-        assert!(res2.status().is_success());
-
-        let req3 = TestRequest::with_uri("//v1//////something").to_request();
-        let res3 = call_service(&app, req3).await;
-        assert!(res3.status().is_success());
-
-        let req4 = TestRequest::with_uri("/v1//something").to_request();
-        let res4 = call_service(&app, req4).await;
-        assert!(res4.status().is_success());
+        for uri in test_uris {
+            let req = TestRequest::with_uri(uri).to_request();
+            let res = call_service(&app, req).await;
+            assert!(res.status().is_success(), "Failed uri: {}", uri);
+        }
     }
 
     #[actix_rt::test]
@@ -238,38 +236,114 @@ mod tests {
             App::new()
                 .wrap(NormalizePath(TrailingSlash::Trim))
                 .service(web::resource("/").to(HttpResponse::Ok))
-                .service(web::resource("/v1/something").to(HttpResponse::Ok)),
+                .service(web::resource("/v1/something").to(HttpResponse::Ok))
+                .service(
+                    web::resource("/v2/something")
+                        .guard(fn_guard(|req| req.uri.query() == Some("query=test")))
+                        .to(HttpResponse::Ok),
+                ),
         )
         .await;
 
-        // root paths should still work
-        let req = TestRequest::with_uri("/").to_request();
-        let res = call_service(&app, req).await;
-        assert!(res.status().is_success());
+        let test_uris = vec![
+            "/",
+            "///",
+            "/v1/something",
+            "/v1/something/",
+            "/v1/something////",
+            "//v1//something",
+            "//v1//something//",
+            "/v2/something?query=test",
+            "/v2/something/?query=test",
+            "/v2/something////?query=test",
+            "//v2//something?query=test",
+            "//v2//something//?query=test",
+        ];
 
-        let req = TestRequest::with_uri("/?query=test").to_request();
-        let res = call_service(&app, req).await;
-        assert!(res.status().is_success());
+        for uri in test_uris {
+            let req = TestRequest::with_uri(uri).to_request();
+            let res = call_service(&app, req).await;
+            assert!(res.status().is_success(), "Failed uri: {}", uri);
+        }
+    }
 
-        let req = TestRequest::with_uri("///").to_request();
-        let res = call_service(&app, req).await;
-        assert!(res.status().is_success());
+    #[actix_rt::test]
+    async fn trim_root_trailing_slashes_with_query() {
+        let app = init_service(
+            App::new().wrap(NormalizePath(TrailingSlash::Trim)).service(
+                web::resource("/")
+                    .guard(fn_guard(|req| req.uri.query() == Some("query=test")))
+                    .to(HttpResponse::Ok),
+            ),
+        )
+        .await;
 
-        let req = TestRequest::with_uri("/v1/something////").to_request();
-        let res = call_service(&app, req).await;
-        assert!(res.status().is_success());
+        let test_uris = vec!["/?query=test", "//?query=test", "///?query=test"];
 
-        let req2 = TestRequest::with_uri("/v1/something/").to_request();
-        let res2 = call_service(&app, req2).await;
-        assert!(res2.status().is_success());
+        for uri in test_uris {
+            let req = TestRequest::with_uri(uri).to_request();
+            let res = call_service(&app, req).await;
+            assert!(res.status().is_success(), "Failed uri: {}", uri);
+        }
+    }
 
-        let req3 = TestRequest::with_uri("//v1//something//").to_request();
-        let res3 = call_service(&app, req3).await;
-        assert!(res3.status().is_success());
+    #[actix_rt::test]
+    async fn ensure_trailing_slash() {
+        let app = init_service(
+            App::new()
+                .wrap(NormalizePath(TrailingSlash::Always))
+                .service(web::resource("/").to(HttpResponse::Ok))
+                .service(web::resource("/v1/something/").to(HttpResponse::Ok))
+                .service(
+                    web::resource("/v2/something/")
+                        .guard(fn_guard(|req| req.uri.query() == Some("query=test")))
+                        .to(HttpResponse::Ok),
+                ),
+        )
+        .await;
 
-        let req4 = TestRequest::with_uri("//v1//something").to_request();
-        let res4 = call_service(&app, req4).await;
-        assert!(res4.status().is_success());
+        let test_uris = vec![
+            "/",
+            "///",
+            "/v1/something",
+            "/v1/something/",
+            "/v1/something////",
+            "//v1//something",
+            "//v1//something//",
+            "/v2/something?query=test",
+            "/v2/something/?query=test",
+            "/v2/something////?query=test",
+            "//v2//something?query=test",
+            "//v2//something//?query=test",
+        ];
+
+        for uri in test_uris {
+            let req = TestRequest::with_uri(uri).to_request();
+            let res = call_service(&app, req).await;
+            assert!(res.status().is_success(), "Failed uri: {}", uri);
+        }
+    }
+
+    #[actix_rt::test]
+    async fn ensure_root_trailing_slash_with_query() {
+        let app = init_service(
+            App::new()
+                .wrap(NormalizePath(TrailingSlash::Always))
+                .service(
+                    web::resource("/")
+                        .guard(fn_guard(|req| req.uri.query() == Some("query=test")))
+                        .to(HttpResponse::Ok),
+                ),
+        )
+        .await;
+
+        let test_uris = vec!["/?query=test", "//?query=test", "///?query=test"];
+
+        for uri in test_uris {
+            let req = TestRequest::with_uri(uri).to_request();
+            let res = call_service(&app, req).await;
+            assert!(res.status().is_success(), "Failed uri: {}", uri);
+        }
     }
 
     #[actix_rt::test]
@@ -279,7 +353,12 @@ mod tests {
                 .wrap(NormalizePath(TrailingSlash::MergeOnly))
                 .service(web::resource("/").to(HttpResponse::Ok))
                 .service(web::resource("/v1/something").to(HttpResponse::Ok))
-                .service(web::resource("/v1/").to(HttpResponse::Ok)),
+                .service(web::resource("/v1/").to(HttpResponse::Ok))
+                .service(
+                    web::resource("/v2/something")
+                        .guard(fn_guard(|req| req.uri.query() == Some("query=test")))
+                        .to(HttpResponse::Ok),
+                ),
         )
         .await;
 
@@ -295,12 +374,16 @@ mod tests {
             ("/v1////", true),
             ("//v1//", true),
             ("///v1", false),
+            ("/v2/something?query=test", true),
+            ("/v2/something/?query=test", false),
+            ("/v2/something//?query=test", false),
+            ("//v2//something?query=test", true),
         ];
 
-        for (path, success) in tests {
-            let req = TestRequest::with_uri(path).to_request();
+        for (uri, success) in tests {
+            let req = TestRequest::with_uri(uri).to_request();
             let res = call_service(&app, req).await;
-            assert_eq!(res.status().is_success(), success);
+            assert_eq!(res.status().is_success(), success, "Failed uri: {}", uri);
         }
     }
 
@@ -316,21 +399,18 @@ mod tests {
             .await
             .unwrap();
 
-        let req = TestRequest::with_uri("/v1//something////").to_srv_request();
-        let res = normalize.call(req).await.unwrap();
-        assert!(res.status().is_success());
+        let test_uris = vec![
+            "/v1//something////",
+            "///v1/something",
+            "//v1///something",
+            "/v1//something",
+        ];
 
-        let req2 = TestRequest::with_uri("///v1/something").to_srv_request();
-        let res2 = normalize.call(req2).await.unwrap();
-        assert!(res2.status().is_success());
-
-        let req3 = TestRequest::with_uri("//v1///something").to_srv_request();
-        let res3 = normalize.call(req3).await.unwrap();
-        assert!(res3.status().is_success());
-
-        let req4 = TestRequest::with_uri("/v1//something").to_srv_request();
-        let res4 = normalize.call(req4).await.unwrap();
-        assert!(res4.status().is_success());
+        for uri in test_uris {
+            let req = TestRequest::with_uri(uri).to_srv_request();
+            let res = normalize.call(req).await.unwrap();
+            assert!(res.status().is_success(), "Failed uri: {}", uri);
+        }
     }
 
     #[actix_rt::test]