From ab597dd98a4bd2f768e2ee419fa752e97ddcec2e Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Tue, 26 Mar 2019 20:57:06 -0700
Subject: [PATCH] Added HTTP Authentication for Client #540

---
 awc/Cargo.toml           |  1 +
 awc/src/request.rs       | 24 +++++++++++
 awc/tests/test_client.rs | 89 +++++++++++++++++++++++++---------------
 3 files changed, 82 insertions(+), 32 deletions(-)

diff --git a/awc/Cargo.toml b/awc/Cargo.toml
index 88c3be42..e08169c9 100644
--- a/awc/Cargo.toml
+++ b/awc/Cargo.toml
@@ -41,6 +41,7 @@ flate2-rust = ["actix-http/flate2-rust"]
 [dependencies]
 actix-service = "0.3.4"
 actix-http = { path = "../actix-http/" }
+base64 = "0.10.1"
 bytes = "0.4"
 futures = "0.1"
 log =" 0.4"
diff --git a/awc/src/request.rs b/awc/src/request.rs
index 90f9a1ab..649797df 100644
--- a/awc/src/request.rs
+++ b/awc/src/request.rs
@@ -242,6 +242,30 @@ impl ClientRequest {
         self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze())
     }
 
+    /// Set HTTP basic authorization
+    pub fn basic_auth<U, P>(self, username: U, password: Option<P>) -> Self
+    where
+        U: fmt::Display,
+        P: fmt::Display,
+    {
+        let auth = match password {
+            Some(password) => format!("{}:{}", username, password),
+            None => format!("{}", username),
+        };
+        self.header(
+            header::AUTHORIZATION,
+            format!("Basic {}", base64::encode(&auth)),
+        )
+    }
+
+    /// Set HTTP bearer authentication
+    pub fn bearer_auth<T>(self, token: T) -> Self
+    where
+        T: fmt::Display,
+    {
+        self.header(header::AUTHORIZATION, format!("Bearer {}", token))
+    }
+
     #[cfg(feature = "cookies")]
     /// Set a cookie
     ///
diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs
index f7605b59..ac07eb6d 100644
--- a/awc/tests/test_client.rs
+++ b/awc/tests/test_client.rs
@@ -1,17 +1,14 @@
-use std::io::{Read, Write};
-use std::{net, thread};
+use std::io::Write;
 
 use brotli2::write::BrotliEncoder;
 use bytes::Bytes;
-use flate2::write::{GzEncoder, ZlibEncoder};
+use flate2::write::GzEncoder;
 use flate2::Compression;
-use futures::stream::once;
-use futures::Future;
 use rand::Rng;
 
 use actix_http::HttpService;
 use actix_http_test::TestServer;
-use actix_web::{middleware, web, App, HttpRequest, HttpResponse};
+use actix_web::{http::header, web, App, HttpMessage, HttpRequest, HttpResponse};
 
 const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
                    Hello World Hello World Hello World Hello World Hello World \
@@ -479,30 +476,58 @@ fn test_client_brotli_encoding() {
 //     assert_eq!(bytes, Bytes::from_static(b"welcome!"));
 // }
 
-// #[test]
-// fn client_basic_auth() {
-//     let mut srv =
-//         test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok().body(STR)));
-//     /// set authorization header to Basic <base64 encoded username:password>
-//     let request = srv
-//         .get()
-//         .basic_auth("username", Some("password"))
-//         .finish()
-//         .unwrap();
-//     let repr = format!("{:?}", request);
-//     assert!(repr.contains("Basic dXNlcm5hbWU6cGFzc3dvcmQ="));
-// }
+#[test]
+fn client_basic_auth() {
+    let mut srv = TestServer::new(|| {
+        HttpService::new(App::new().route(
+            "/",
+            web::to(|req: HttpRequest| {
+                if req
+                    .headers()
+                    .get(header::AUTHORIZATION)
+                    .unwrap()
+                    .to_str()
+                    .unwrap()
+                    == "Basic dXNlcm5hbWU6cGFzc3dvcmQ="
+                {
+                    HttpResponse::Ok()
+                } else {
+                    HttpResponse::BadRequest()
+                }
+            }),
+        ))
+    });
 
-// #[test]
-// fn client_bearer_auth() {
-//     let mut srv =
-//         test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok().body(STR)));
-//     /// set authorization header to Bearer <token>
-//     let request = srv
-//         .get()
-//         .bearer_auth("someS3cr3tAutht0k3n")
-//         .finish()
-//         .unwrap();
-//     let repr = format!("{:?}", request);
-//     assert!(repr.contains("Bearer someS3cr3tAutht0k3n"));
-// }
+    // set authorization header to Basic <base64 encoded username:password>
+    let request = srv.get().basic_auth("username", Some("password"));
+    let response = srv.block_on(request.send()).unwrap();
+    assert!(response.status().is_success());
+}
+
+#[test]
+fn client_bearer_auth() {
+    let mut srv = TestServer::new(|| {
+        HttpService::new(App::new().route(
+            "/",
+            web::to(|req: HttpRequest| {
+                if req
+                    .headers()
+                    .get(header::AUTHORIZATION)
+                    .unwrap()
+                    .to_str()
+                    .unwrap()
+                    == "Bearer someS3cr3tAutht0k3n"
+                {
+                    HttpResponse::Ok()
+                } else {
+                    HttpResponse::BadRequest()
+                }
+            }),
+        ))
+    });
+
+    // set authorization header to Bearer <token>
+    let request = srv.get().bearer_auth("someS3cr3tAutht0k3n");
+    let response = srv.block_on(request.send()).unwrap();
+    assert!(response.status().is_success());
+}