From bb1bf5f2be04fb3b1f0200c0912d507e9ba40a37 Mon Sep 17 00:00:00 2001
From: joshbenaron <73971531+joshbenaron@users.noreply.github.com>
Date: Fri, 2 Apr 2021 13:27:30 +0100
Subject: [PATCH 1/4] Add awc retry middleware

---
 awc/src/middleware/mod.rs   |   2 +
 awc/src/middleware/retry.rs | 472 ++++++++++++++++++++++++++++++++++++
 2 files changed, 474 insertions(+)
 create mode 100644 awc/src/middleware/retry.rs

diff --git a/awc/src/middleware/mod.rs b/awc/src/middleware/mod.rs
index 330e3b7f..df66ad70 100644
--- a/awc/src/middleware/mod.rs
+++ b/awc/src/middleware/mod.rs
@@ -1,6 +1,8 @@
 mod redirect;
+mod retry;
 
 pub use self::redirect::Redirect;
+pub use self::retry::Retry;
 
 use std::marker::PhantomData;
 
diff --git a/awc/src/middleware/retry.rs b/awc/src/middleware/retry.rs
new file mode 100644
index 00000000..27bb3a1c
--- /dev/null
+++ b/awc/src/middleware/retry.rs
@@ -0,0 +1,472 @@
+use super::Transform;
+use std::rc::Rc;
+use actix_http::RequestHeadType;
+use actix_http::http::{StatusCode, HeaderMap};
+use std::ops::Deref;
+use crate::{ConnectRequest, ConnectResponse};
+use actix_service::Service;
+use actix_http::client::SendRequestError;
+use std::task::{Context, Poll};
+use crate::RequestHead;
+use futures_core::future::LocalBoxFuture;
+use actix_http::body::Body;
+
+pub struct Retry(Inner);
+
+struct Inner {
+    /// Number of retries. So each request will be tried [max_retry + 1] times
+    max_retry: u8,
+    policies: Vec<RetryPolicy>,
+}
+
+impl Retry {
+    pub fn new(retries: u8) -> Self {
+        Retry(Inner {
+            max_retry: retries,
+            policies: vec![],
+        })
+    }
+
+    /// Allows you to add a retry policy to the [`policies`]
+    /// It allows two types of policy:
+    ///  - `Vec<StatusCode>` and will retry if one of them is received
+    ///  - `Fn(&ResponseHead) -> bool` and will retry when this function resolves to false
+    ///
+    /// # example
+    ///
+    ///```
+    ///
+    /// // Creates a policy which will try each request a max of 5 times if any policies resolve to true
+    /// // i.e.
+    /// // if you receive a 401 or 501 status code
+    /// // or
+    /// // the response doesn't have a [`SOME_HEADER`] header
+    /// use awc::http::{StatusCode, HeaderMap};
+    /// use awc::middleware::Retry;
+    ///
+    /// let retry_policies = Retry::new(5)
+    ///     .policy(vec![StatusCode::INTERNAL_SERVER_ERROR, StatusCode::UNAUTHORIZED])
+    ///     .policy(|code: StatusCode, headers: &HeaderMap| {
+    ///         return if headers.contains_key("SOME_HEADER") {
+    ///             true
+    ///         } else {
+    ///             false
+    ///         };
+    ///     });
+    ///
+    /// // Creates awc client
+    /// let client = awc::Client::builder()
+    ///     .wrap(retry_policies)
+    ///     .finish();
+    ///```
+    pub fn policy<T>(mut self, p: T) -> Self
+        where T: IntoRetryPolicy
+    {
+        self.0.policies.push(p.into_policy());
+        self
+    }
+}
+
+#[non_exhaustive]
+pub enum RetryPolicy {
+    Status(Vec<StatusCode>),
+    Custom(Box<dyn Fn(StatusCode, &HeaderMap) -> bool>),
+}
+
+pub trait IntoRetryPolicy {
+    fn into_policy(self) -> RetryPolicy;
+}
+
+impl<T> IntoRetryPolicy for T
+    where T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static
+{
+    fn into_policy(self) -> RetryPolicy {
+        RetryPolicy::Custom(Box::new(self))
+    }
+}
+
+impl IntoRetryPolicy for Vec<StatusCode> {
+    fn into_policy(self) -> RetryPolicy {
+        RetryPolicy::Status(self)
+    }
+}
+
+impl<S> Transform<S, ConnectRequest> for Retry
+    where
+        S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static,
+{
+    type Transform = RetryService<S>;
+
+    fn new_transform(self, service: S) -> Self::Transform {
+        RetryService {
+            max_retry: self.0.max_retry,
+            policies: self.0.policies.into_boxed_slice().into(),
+            connector: service.into(),
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct RetryService<S> {
+    policies: Rc<[RetryPolicy]>,
+    max_retry: u8,
+    connector: Rc<S>,
+}
+
+impl<S> Service<ConnectRequest> for RetryService<S>
+    where
+        S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static,
+{
+    type Response = S::Response;
+    type Error = S::Error;
+    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
+
+    fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.connector.poll_ready(ctx)
+    }
+
+    fn call(&self, req: ConnectRequest) -> Self::Future {
+        let connector = self.connector.clone();
+        let policies = self.policies.clone();
+        let max_retry = self.max_retry;
+
+        Box::pin(async move {
+            let mut tries = 0;
+            match req {
+                ConnectRequest::Client(head, body, addr) => {
+                    match body {
+                        Body::Bytes(b) => {
+                            loop {
+                                let h = clone_request_head_type(&head);
+
+                                match connector.call(ConnectRequest::Client(h, Body::Bytes(b.clone()), addr)).await
+                                {
+                                    Ok(res) => {
+                                        // ConnectResponse
+                                        match &res {
+                                            ConnectResponse::Client(ref r) => {
+                                                if is_valid_response(policies.as_ref(), r.status(), r.headers()) {
+                                                    return Ok(res);
+                                                }
+
+                                                if tries == max_retry {
+                                                    return Ok(res);
+                                                }
+
+                                                tries += 1;
+                                            }
+                                            ConnectResponse::Tunnel(ref head, _) => {
+                                                if is_valid_response(policies.as_ref(), head.status, head.headers()) {
+                                                    return Ok(res);
+                                                }
+
+                                                if tries == max_retry {
+                                                    return Ok(res);
+                                                }
+
+                                                tries += 1;
+                                            }
+                                        }
+                                    }
+                                    // SendRequestError
+                                    Err(e) => {
+                                        if tries == max_retry {
+                                            log::debug!("Request max retry reached");
+                                            return Err(e);
+                                        }
+
+                                        tries += 1;
+                                    }
+                                }
+                            }
+                        }
+                        Body::Empty => {
+                            loop {
+                                let h = clone_request_head_type(&head);
+
+                                match connector.call(ConnectRequest::Client(h, Body::Empty, addr)).await
+                                {
+                                    Ok(res) => {
+                                        // ConnectResponse
+                                        match &res {
+                                            ConnectResponse::Client(ref r) => {
+                                                if is_valid_response(policies.as_ref(), r.status(), r.headers()) {
+                                                    return Ok(res);
+                                                }
+
+                                                if tries == max_retry {
+                                                    log::debug!("Request max retry reached");
+                                                    return Ok(res);
+                                                }
+
+                                                tries += 1;
+                                            }
+                                            ConnectResponse::Tunnel(ref head, _) => {
+                                                if is_valid_response(policies.as_ref(), head.status, head.headers()) {
+                                                    return Ok(res);
+                                                } else {
+                                                    if tries == max_retry {
+                                                        log::debug!("Request max retry reached");
+                                                        return Ok(res);
+                                                    }
+
+                                                    tries += 1;
+                                                }
+                                            }
+                                        }
+                                    }
+                                    // SendRequestError
+                                    Err(e) => {
+                                        if tries == max_retry {
+                                            log::debug!("Request max retry reached");
+                                            return Err(e);
+                                        }
+
+                                        tries += 1;
+                                    }
+                                }
+                            }
+                        }
+                        _ => {
+                            log::debug!("Non cloneable body type given - defaulting to `Body::None`");
+                            loop {
+                                let h = clone_request_head_type(&head);
+
+                                match connector.call(ConnectRequest::Client(h, Body::None, addr)).await
+                                {
+                                    Ok(res) => {
+                                        // ConnectResponse
+                                        match &res {
+                                            ConnectResponse::Client(ref r) => {
+                                                if is_valid_response(policies.as_ref(), r.status(), r.headers()) {
+                                                    return Ok(res);
+                                                }
+
+                                                if tries == max_retry {
+                                                    log::debug!("Request max retry reached");
+                                                    return Ok(res);
+                                                }
+
+                                                tries += 1;
+                                            }
+                                            ConnectResponse::Tunnel(ref head, _) => {
+                                                if is_valid_response(policies.as_ref(), head.status, head.headers()) {
+                                                    return Ok(res);
+                                                } else {
+                                                    if tries == max_retry {
+                                                        log::debug!("Request max retry reached");
+                                                        return Ok(res);
+                                                    }
+
+                                                    tries += 1;
+                                                }
+                                            }
+                                        }
+                                    }
+                                    // SendRequestError
+                                    Err(e) => {
+                                        if tries == max_retry {
+                                            log::debug!("Request max retry reached");
+                                            return Err(e);
+                                        }
+
+                                        tries += 1;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                ConnectRequest::Tunnel(head, addr) => {
+                    loop {
+                        let h = clone_request_head(&head);
+
+                        match connector.call(ConnectRequest::Tunnel(h, addr)).await {
+                            Ok(res) => {
+                                match &res {
+                                    ConnectResponse::Client(r) => {
+                                        if is_valid_response(&policies, r.status(), r.headers()) {
+                                            return Ok(res)
+                                        }
+
+                                        if tries == max_retry {
+                                            log::debug!("Request max retry reached");
+                                            return Ok(res)
+                                        }
+
+                                        tries += 1;
+                                    }
+                                    ConnectResponse::Tunnel(head, _) => {
+                                        if is_valid_response(&policies, head.status, head.headers()) {
+                                            return Ok(res)
+                                        }
+
+                                        if tries == max_retry {
+                                            log::debug!("Request max retry reached");
+                                            return Ok(res)
+                                        }
+
+                                        tries += 1;
+                                    }
+                                }
+                            },
+                            Err(e) => {
+                                if tries == max_retry {
+                                    log::debug!("Request max retry reached");
+                                    return Err(e)
+                                }
+
+                                tries += 1;
+                            }
+                        }
+                    }
+                }
+            }
+        })
+    }
+}
+
+#[doc(hidden)]
+/// Clones [RequestHeadType] except for the extensions (not required for this middleware)
+fn clone_request_head_type(head_type: &RequestHeadType) -> RequestHeadType {
+    match head_type {
+        RequestHeadType::Owned(h) => {
+            let mut inner_head = RequestHead::default();
+            inner_head.uri = h.uri.clone();
+            inner_head.method = h.method.clone();
+            inner_head.version = h.version;
+            inner_head.peer_addr = h.peer_addr;
+            inner_head.headers = h.headers.clone();
+
+            RequestHeadType::Owned(inner_head)
+        }
+        RequestHeadType::Rc(h, header_map) => {
+            RequestHeadType::Rc(h.clone(), header_map.clone())
+        }
+    }
+}
+
+#[doc(hidden)]
+/// Clones [RequestHeadType] except for the extensions (not required for this middleware)
+fn clone_request_head(head: &RequestHead) -> RequestHead {
+    let mut new_head = RequestHead::default();
+    new_head.uri = head.uri.clone();
+    new_head.method = head.method.clone();
+    new_head.version = head.version;
+    new_head.headers = head.headers.clone();
+    new_head.peer_addr = head.peer_addr;
+
+    new_head
+}
+
+#[doc(hidden)]
+/// Checks whether the response matches the policies
+fn is_valid_response(policies: &[RetryPolicy], status_code: StatusCode, headers: &HeaderMap) -> bool {
+    policies.iter().all(|policy| {
+        match policy {
+            RetryPolicy::Status(v) => {
+                // is valid if:
+                // - the list of status codes is empty
+                // or
+                // - the list doesn't contain the received status code
+                v.is_empty() || !v.contains(&status_code)
+            }
+            RetryPolicy::Custom(func) => {
+                // custom policy
+                (func.deref())(status_code, headers)
+            }
+        }
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use actix_web::{web, App, Error, HttpResponse};
+
+    use super::*;
+    use crate::ClientBuilder;
+
+    #[actix_rt::test]
+    async fn test_basic_policy() {
+        let client = ClientBuilder::new()
+            .disable_redirects()
+            .wrap(Retry::new(3)
+                .policy(vec![StatusCode::INTERNAL_SERVER_ERROR])
+            )
+            .finish();
+
+        let srv = actix_test::start(|| {
+            App::new()
+                .service(web::resource("/test").route(web::to(|| async {
+                    Ok::<_, Error>(
+                        HttpResponse::InternalServerError()
+                            .finish(),
+                    )
+                })))
+        });
+
+        let res = client.get(srv.url("/test")).send().await.unwrap();
+
+        assert_eq!(res.status().as_u16(), 500);
+    }
+
+    #[actix_rt::test]
+    async fn test_header_policy() {
+        std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug");
+        env_logger::init();
+
+        let client = ClientBuilder::new()
+            .disable_redirects()
+            .wrap(Retry::new(3)
+                .policy(|code: StatusCode, headers: &HeaderMap| {
+                    code.is_success() && headers.contains_key("SOME_HEADER")
+                })
+            )
+            .finish();
+
+        let srv = actix_test::start(|| {
+            App::new()
+                .service(web::resource("/test").route(web::to(|| async {
+                    Ok::<_, Error>(
+                        HttpResponse::Ok()
+                            .insert_header(("SOME_HEADER", "test"))
+                            .finish(),
+                    )
+                })))
+        });
+
+        let res = client.get(srv.url("/test")).send().await.unwrap();
+
+        assert_eq!(res.status().as_u16(), 200);
+    }
+
+    #[actix_rt::test]
+    async fn test_bad_header_policy() {
+        std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug");
+        env_logger::init();
+
+        let client = ClientBuilder::new()
+            .disable_redirects()
+            .wrap(Retry::new(3)
+                .policy(|code: StatusCode, headers: &HeaderMap| {
+                    code.is_success() && headers.contains_key("WRONG_HEADER")
+                })
+            )
+            .finish();
+
+        let srv = actix_test::start(|| {
+            App::new()
+                .service(web::resource("/test").route(web::to(|| async {
+                    Ok::<_, Error>(
+                        HttpResponse::Ok()
+                            .insert_header(("SOME_HEADER", "test"))
+                            .finish(),
+                    )
+                })))
+        });
+
+        let res = client.get(srv.url("/test")).send().await.unwrap();
+
+        assert_eq!(res.status().as_u16(), 200);
+    }
+}

From 58904b9ebc69d57511ae4d718b21ef6dcb3b2b7c Mon Sep 17 00:00:00 2001
From: joshbenaron <73971531+joshbenaron@users.noreply.github.com>
Date: Fri, 2 Apr 2021 14:50:43 +0100
Subject: [PATCH 2/4] Fixed invalid tests

---
 awc/src/middleware/retry.rs | 228 ++++++++++++++++++++----------------
 1 file changed, 128 insertions(+), 100 deletions(-)

diff --git a/awc/src/middleware/retry.rs b/awc/src/middleware/retry.rs
index 27bb3a1c..32b54ffc 100644
--- a/awc/src/middleware/retry.rs
+++ b/awc/src/middleware/retry.rs
@@ -1,15 +1,15 @@
 use super::Transform;
-use std::rc::Rc;
-use actix_http::RequestHeadType;
-use actix_http::http::{StatusCode, HeaderMap};
-use std::ops::Deref;
-use crate::{ConnectRequest, ConnectResponse};
-use actix_service::Service;
-use actix_http::client::SendRequestError;
-use std::task::{Context, Poll};
 use crate::RequestHead;
-use futures_core::future::LocalBoxFuture;
+use crate::{ConnectRequest, ConnectResponse};
 use actix_http::body::Body;
+use actix_http::client::SendRequestError;
+use actix_http::http::{HeaderMap, StatusCode};
+use actix_http::RequestHeadType;
+use actix_service::Service;
+use futures_core::future::LocalBoxFuture;
+use std::ops::Deref;
+use std::rc::Rc;
+use std::task::{Context, Poll};
 
 pub struct Retry(Inner);
 
@@ -60,7 +60,8 @@ impl Retry {
     ///     .finish();
     ///```
     pub fn policy<T>(mut self, p: T) -> Self
-        where T: IntoRetryPolicy
+    where
+        T: IntoRetryPolicy,
     {
         self.0.policies.push(p.into_policy());
         self
@@ -78,7 +79,8 @@ pub trait IntoRetryPolicy {
 }
 
 impl<T> IntoRetryPolicy for T
-    where T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static
+where
+    T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static,
 {
     fn into_policy(self) -> RetryPolicy {
         RetryPolicy::Custom(Box::new(self))
@@ -92,8 +94,8 @@ impl IntoRetryPolicy for Vec<StatusCode> {
 }
 
 impl<S> Transform<S, ConnectRequest> for Retry
-    where
-        S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static,
+where
+    S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
 {
     type Transform = RetryService<S>;
 
@@ -114,8 +116,8 @@ pub struct RetryService<S> {
 }
 
 impl<S> Service<ConnectRequest> for RetryService<S>
-    where
-        S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static,
+where
+    S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
 {
     type Response = S::Response;
     type Error = S::Error;
@@ -139,13 +141,23 @@ impl<S> Service<ConnectRequest> for RetryService<S>
                             loop {
                                 let h = clone_request_head_type(&head);
 
-                                match connector.call(ConnectRequest::Client(h, Body::Bytes(b.clone()), addr)).await
+                                match connector
+                                    .call(ConnectRequest::Client(
+                                        h,
+                                        Body::Bytes(b.clone()),
+                                        addr,
+                                    ))
+                                    .await
                                 {
                                     Ok(res) => {
                                         // ConnectResponse
                                         match &res {
                                             ConnectResponse::Client(ref r) => {
-                                                if is_valid_response(policies.as_ref(), r.status(), r.headers()) {
+                                                if is_valid_response(
+                                                    policies.as_ref(),
+                                                    r.status(),
+                                                    r.headers(),
+                                                ) {
                                                     return Ok(res);
                                                 }
 
@@ -156,7 +168,11 @@ impl<S> Service<ConnectRequest> for RetryService<S>
                                                 tries += 1;
                                             }
                                             ConnectResponse::Tunnel(ref head, _) => {
-                                                if is_valid_response(policies.as_ref(), head.status, head.headers()) {
+                                                if is_valid_response(
+                                                    policies.as_ref(),
+                                                    head.status,
+                                                    head.headers(),
+                                                ) {
                                                     return Ok(res);
                                                 }
 
@@ -184,13 +200,19 @@ impl<S> Service<ConnectRequest> for RetryService<S>
                             loop {
                                 let h = clone_request_head_type(&head);
 
-                                match connector.call(ConnectRequest::Client(h, Body::Empty, addr)).await
+                                match connector
+                                    .call(ConnectRequest::Client(h, Body::Empty, addr))
+                                    .await
                                 {
                                     Ok(res) => {
                                         // ConnectResponse
                                         match &res {
                                             ConnectResponse::Client(ref r) => {
-                                                if is_valid_response(policies.as_ref(), r.status(), r.headers()) {
+                                                if is_valid_response(
+                                                    policies.as_ref(),
+                                                    r.status(),
+                                                    r.headers(),
+                                                ) {
                                                     return Ok(res);
                                                 }
 
@@ -202,11 +224,17 @@ impl<S> Service<ConnectRequest> for RetryService<S>
                                                 tries += 1;
                                             }
                                             ConnectResponse::Tunnel(ref head, _) => {
-                                                if is_valid_response(policies.as_ref(), head.status, head.headers()) {
+                                                if is_valid_response(
+                                                    policies.as_ref(),
+                                                    head.status,
+                                                    head.headers(),
+                                                ) {
                                                     return Ok(res);
                                                 } else {
                                                     if tries == max_retry {
-                                                        log::debug!("Request max retry reached");
+                                                        log::debug!(
+                                                            "Request max retry reached"
+                                                        );
                                                         return Ok(res);
                                                     }
 
@@ -228,17 +256,25 @@ impl<S> Service<ConnectRequest> for RetryService<S>
                             }
                         }
                         _ => {
-                            log::debug!("Non cloneable body type given - defaulting to `Body::None`");
+                            log::debug!(
+                                "Non cloneable body type given - defaulting to `Body::None`"
+                            );
                             loop {
                                 let h = clone_request_head_type(&head);
 
-                                match connector.call(ConnectRequest::Client(h, Body::None, addr)).await
+                                match connector
+                                    .call(ConnectRequest::Client(h, Body::None, addr))
+                                    .await
                                 {
                                     Ok(res) => {
                                         // ConnectResponse
                                         match &res {
                                             ConnectResponse::Client(ref r) => {
-                                                if is_valid_response(policies.as_ref(), r.status(), r.headers()) {
+                                                if is_valid_response(
+                                                    policies.as_ref(),
+                                                    r.status(),
+                                                    r.headers(),
+                                                ) {
                                                     return Ok(res);
                                                 }
 
@@ -250,11 +286,17 @@ impl<S> Service<ConnectRequest> for RetryService<S>
                                                 tries += 1;
                                             }
                                             ConnectResponse::Tunnel(ref head, _) => {
-                                                if is_valid_response(policies.as_ref(), head.status, head.headers()) {
+                                                if is_valid_response(
+                                                    policies.as_ref(),
+                                                    head.status,
+                                                    head.headers(),
+                                                ) {
                                                     return Ok(res);
                                                 } else {
                                                     if tries == max_retry {
-                                                        log::debug!("Request max retry reached");
+                                                        log::debug!(
+                                                            "Request max retry reached"
+                                                        );
                                                         return Ok(res);
                                                     }
 
@@ -277,50 +319,46 @@ impl<S> Service<ConnectRequest> for RetryService<S>
                         }
                     }
                 }
-                ConnectRequest::Tunnel(head, addr) => {
-                    loop {
-                        let h = clone_request_head(&head);
+                ConnectRequest::Tunnel(head, addr) => loop {
+                    let h = clone_request_head(&head);
 
-                        match connector.call(ConnectRequest::Tunnel(h, addr)).await {
-                            Ok(res) => {
-                                match &res {
-                                    ConnectResponse::Client(r) => {
-                                        if is_valid_response(&policies, r.status(), r.headers()) {
-                                            return Ok(res)
-                                        }
-
-                                        if tries == max_retry {
-                                            log::debug!("Request max retry reached");
-                                            return Ok(res)
-                                        }
-
-                                        tries += 1;
-                                    }
-                                    ConnectResponse::Tunnel(head, _) => {
-                                        if is_valid_response(&policies, head.status, head.headers()) {
-                                            return Ok(res)
-                                        }
-
-                                        if tries == max_retry {
-                                            log::debug!("Request max retry reached");
-                                            return Ok(res)
-                                        }
-
-                                        tries += 1;
-                                    }
+                    match connector.call(ConnectRequest::Tunnel(h, addr)).await {
+                        Ok(res) => match &res {
+                            ConnectResponse::Client(r) => {
+                                if is_valid_response(&policies, r.status(), r.headers()) {
+                                    return Ok(res);
                                 }
-                            },
-                            Err(e) => {
+
                                 if tries == max_retry {
                                     log::debug!("Request max retry reached");
-                                    return Err(e)
+                                    return Ok(res);
                                 }
 
                                 tries += 1;
                             }
+                            ConnectResponse::Tunnel(head, _) => {
+                                if is_valid_response(&policies, head.status, head.headers()) {
+                                    return Ok(res);
+                                }
+
+                                if tries == max_retry {
+                                    log::debug!("Request max retry reached");
+                                    return Ok(res);
+                                }
+
+                                tries += 1;
+                            }
+                        },
+                        Err(e) => {
+                            if tries == max_retry {
+                                log::debug!("Request max retry reached");
+                                return Err(e);
+                            }
+
+                            tries += 1;
                         }
                     }
-                }
+                },
             }
         })
     }
@@ -361,7 +399,11 @@ fn clone_request_head(head: &RequestHead) -> RequestHead {
 
 #[doc(hidden)]
 /// Checks whether the response matches the policies
-fn is_valid_response(policies: &[RetryPolicy], status_code: StatusCode, headers: &HeaderMap) -> bool {
+fn is_valid_response(
+    policies: &[RetryPolicy],
+    status_code: StatusCode,
+    headers: &HeaderMap,
+) -> bool {
     policies.iter().all(|policy| {
         match policy {
             RetryPolicy::Status(v) => {
@@ -390,19 +432,13 @@ mod tests {
     async fn test_basic_policy() {
         let client = ClientBuilder::new()
             .disable_redirects()
-            .wrap(Retry::new(3)
-                .policy(vec![StatusCode::INTERNAL_SERVER_ERROR])
-            )
+            .wrap(Retry::new(3).policy(vec![StatusCode::INTERNAL_SERVER_ERROR]))
             .finish();
 
         let srv = actix_test::start(|| {
-            App::new()
-                .service(web::resource("/test").route(web::to(|| async {
-                    Ok::<_, Error>(
-                        HttpResponse::InternalServerError()
-                            .finish(),
-                    )
-                })))
+            App::new().service(web::resource("/test").route(web::to(|| async {
+                Ok::<_, Error>(HttpResponse::InternalServerError().finish())
+            })))
         });
 
         let res = client.get(srv.url("/test")).send().await.unwrap();
@@ -412,27 +448,23 @@ mod tests {
 
     #[actix_rt::test]
     async fn test_header_policy() {
-        std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug");
-        env_logger::init();
-
         let client = ClientBuilder::new()
             .disable_redirects()
-            .wrap(Retry::new(3)
-                .policy(|code: StatusCode, headers: &HeaderMap| {
+            .wrap(
+                Retry::new(3).policy(|code: StatusCode, headers: &HeaderMap| {
                     code.is_success() && headers.contains_key("SOME_HEADER")
-                })
+                }),
             )
             .finish();
 
         let srv = actix_test::start(|| {
-            App::new()
-                .service(web::resource("/test").route(web::to(|| async {
-                    Ok::<_, Error>(
-                        HttpResponse::Ok()
-                            .insert_header(("SOME_HEADER", "test"))
-                            .finish(),
-                    )
-                })))
+            App::new().service(web::resource("/test").route(web::to(|| async {
+                Ok::<_, Error>(
+                    HttpResponse::Ok()
+                        .insert_header(("SOME_HEADER", "test"))
+                        .finish(),
+                )
+            })))
         });
 
         let res = client.get(srv.url("/test")).send().await.unwrap();
@@ -442,27 +474,23 @@ mod tests {
 
     #[actix_rt::test]
     async fn test_bad_header_policy() {
-        std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug");
-        env_logger::init();
-
         let client = ClientBuilder::new()
             .disable_redirects()
-            .wrap(Retry::new(3)
-                .policy(|code: StatusCode, headers: &HeaderMap| {
+            .wrap(
+                Retry::new(3).policy(|code: StatusCode, headers: &HeaderMap| {
                     code.is_success() && headers.contains_key("WRONG_HEADER")
-                })
+                }),
             )
             .finish();
 
         let srv = actix_test::start(|| {
-            App::new()
-                .service(web::resource("/test").route(web::to(|| async {
-                    Ok::<_, Error>(
-                        HttpResponse::Ok()
-                            .insert_header(("SOME_HEADER", "test"))
-                            .finish(),
-                    )
-                })))
+            App::new().service(web::resource("/test").route(web::to(|| async {
+                Ok::<_, Error>(
+                    HttpResponse::Ok()
+                        .insert_header(("SOME_HEADER", "test"))
+                        .finish(),
+                )
+            })))
         });
 
         let res = client.get(srv.url("/test")).send().await.unwrap();

From 22e40613e87e5149a37595f3d116e477939a5bc7 Mon Sep 17 00:00:00 2001
From: joshbenaron <73971531+joshbenaron@users.noreply.github.com>
Date: Fri, 2 Apr 2021 18:50:07 +0100
Subject: [PATCH 3/4] Improved readability and logic

---
 awc/src/middleware/retry.rs | 287 +++++++++---------------------------
 1 file changed, 70 insertions(+), 217 deletions(-)

diff --git a/awc/src/middleware/retry.rs b/awc/src/middleware/retry.rs
index 32b54ffc..6e595733 100644
--- a/awc/src/middleware/retry.rs
+++ b/awc/src/middleware/retry.rs
@@ -60,8 +60,8 @@ impl Retry {
     ///     .finish();
     ///```
     pub fn policy<T>(mut self, p: T) -> Self
-    where
-        T: IntoRetryPolicy,
+        where
+            T: IntoRetryPolicy,
     {
         self.0.policies.push(p.into_policy());
         self
@@ -79,8 +79,8 @@ pub trait IntoRetryPolicy {
 }
 
 impl<T> IntoRetryPolicy for T
-where
-    T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static,
+    where
+        T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static,
 {
     fn into_policy(self) -> RetryPolicy {
         RetryPolicy::Custom(Box::new(self))
@@ -94,8 +94,8 @@ impl IntoRetryPolicy for Vec<StatusCode> {
 }
 
 impl<S> Transform<S, ConnectRequest> for Retry
-where
-    S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
+    where
+        S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static,
 {
     type Transform = RetryService<S>;
 
@@ -116,8 +116,8 @@ pub struct RetryService<S> {
 }
 
 impl<S> Service<ConnectRequest> for RetryService<S>
-where
-    S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
+    where
+        S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static,
 {
     type Response = S::Response;
     type Error = S::Error;
@@ -133,237 +133,87 @@ where
         let max_retry = self.max_retry;
 
         Box::pin(async move {
-            let mut tries = 0;
             match req {
                 ConnectRequest::Client(head, body, addr) => {
-                    match body {
-                        Body::Bytes(b) => {
-                            loop {
-                                let h = clone_request_head_type(&head);
+                    for _ in 1..max_retry {
+                        let h = clone_request_head_type(&head);
 
-                                match connector
-                                    .call(ConnectRequest::Client(
-                                        h,
-                                        Body::Bytes(b.clone()),
-                                        addr,
-                                    ))
-                                    .await
-                                {
-                                    Ok(res) => {
-                                        // ConnectResponse
-                                        match &res {
-                                            ConnectResponse::Client(ref r) => {
-                                                if is_valid_response(
-                                                    policies.as_ref(),
-                                                    r.status(),
-                                                    r.headers(),
-                                                ) {
-                                                    return Ok(res);
-                                                }
+                        let result = connector
+                            .call(ConnectRequest::Client(h, body_to_retry_body(&body), addr))
+                            .await;
 
-                                                if tries == max_retry {
-                                                    return Ok(res);
-                                                }
-
-                                                tries += 1;
-                                            }
-                                            ConnectResponse::Tunnel(ref head, _) => {
-                                                if is_valid_response(
-                                                    policies.as_ref(),
-                                                    head.status,
-                                                    head.headers(),
-                                                ) {
-                                                    return Ok(res);
-                                                }
-
-                                                if tries == max_retry {
-                                                    return Ok(res);
-                                                }
-
-                                                tries += 1;
-                                            }
-                                        }
-                                    }
-                                    // SendRequestError
-                                    Err(e) => {
-                                        if tries == max_retry {
-                                            log::debug!("Request max retry reached");
-                                            return Err(e);
-                                        }
-
-                                        tries += 1;
+                        if let Ok(res) = result {
+                            match &res {
+                                ConnectResponse::Client(ref r) => {
+                                    if is_valid_response(
+                                        policies.as_ref(),
+                                        r.status(),
+                                        r.headers(),
+                                    ) {
+                                        return Ok(res);
                                     }
                                 }
-                            }
-                        }
-                        Body::Empty => {
-                            loop {
-                                let h = clone_request_head_type(&head);
-
-                                match connector
-                                    .call(ConnectRequest::Client(h, Body::Empty, addr))
-                                    .await
-                                {
-                                    Ok(res) => {
-                                        // ConnectResponse
-                                        match &res {
-                                            ConnectResponse::Client(ref r) => {
-                                                if is_valid_response(
-                                                    policies.as_ref(),
-                                                    r.status(),
-                                                    r.headers(),
-                                                ) {
-                                                    return Ok(res);
-                                                }
-
-                                                if tries == max_retry {
-                                                    log::debug!("Request max retry reached");
-                                                    return Ok(res);
-                                                }
-
-                                                tries += 1;
-                                            }
-                                            ConnectResponse::Tunnel(ref head, _) => {
-                                                if is_valid_response(
-                                                    policies.as_ref(),
-                                                    head.status,
-                                                    head.headers(),
-                                                ) {
-                                                    return Ok(res);
-                                                } else {
-                                                    if tries == max_retry {
-                                                        log::debug!(
-                                                            "Request max retry reached"
-                                                        );
-                                                        return Ok(res);
-                                                    }
-
-                                                    tries += 1;
-                                                }
-                                            }
-                                        }
-                                    }
-                                    // SendRequestError
-                                    Err(e) => {
-                                        if tries == max_retry {
-                                            log::debug!("Request max retry reached");
-                                            return Err(e);
-                                        }
-
-                                        tries += 1;
-                                    }
-                                }
-                            }
-                        }
-                        _ => {
-                            log::debug!(
-                                "Non cloneable body type given - defaulting to `Body::None`"
-                            );
-                            loop {
-                                let h = clone_request_head_type(&head);
-
-                                match connector
-                                    .call(ConnectRequest::Client(h, Body::None, addr))
-                                    .await
-                                {
-                                    Ok(res) => {
-                                        // ConnectResponse
-                                        match &res {
-                                            ConnectResponse::Client(ref r) => {
-                                                if is_valid_response(
-                                                    policies.as_ref(),
-                                                    r.status(),
-                                                    r.headers(),
-                                                ) {
-                                                    return Ok(res);
-                                                }
-
-                                                if tries == max_retry {
-                                                    log::debug!("Request max retry reached");
-                                                    return Ok(res);
-                                                }
-
-                                                tries += 1;
-                                            }
-                                            ConnectResponse::Tunnel(ref head, _) => {
-                                                if is_valid_response(
-                                                    policies.as_ref(),
-                                                    head.status,
-                                                    head.headers(),
-                                                ) {
-                                                    return Ok(res);
-                                                } else {
-                                                    if tries == max_retry {
-                                                        log::debug!(
-                                                            "Request max retry reached"
-                                                        );
-                                                        return Ok(res);
-                                                    }
-
-                                                    tries += 1;
-                                                }
-                                            }
-                                        }
-                                    }
-                                    // SendRequestError
-                                    Err(e) => {
-                                        if tries == max_retry {
-                                            log::debug!("Request max retry reached");
-                                            return Err(e);
-                                        }
-
-                                        tries += 1;
+                                ConnectResponse::Tunnel(ref head, _) => {
+                                    if is_valid_response(
+                                        policies.as_ref(),
+                                        head.status,
+                                        head.headers(),
+                                    ) {
+                                        return Ok(res);
                                     }
                                 }
                             }
                         }
                     }
+
+                    // Exceed max retry so just return whatever response is received
+                    log::debug!("Request max retry reached");
+                    connector.call(ConnectRequest::Client(
+                        head,
+                        body,
+                        addr,
+                    ))
+                        .await
                 }
-                ConnectRequest::Tunnel(head, addr) => loop {
-                    let h = clone_request_head(&head);
+                ConnectRequest::Tunnel(head, addr) => {
+                    for _ in 1..max_retry {
+                        let h = clone_request_head(&head);
 
-                    match connector.call(ConnectRequest::Tunnel(h, addr)).await {
-                        Ok(res) => match &res {
-                            ConnectResponse::Client(r) => {
-                                if is_valid_response(&policies, r.status(), r.headers()) {
-                                    return Ok(res);
+                        let result = connector.call(ConnectRequest::Tunnel(h, addr)).await;
+
+                        if let Ok(res) = result {
+                            match &res {
+                                ConnectResponse::Client(r) => {
+                                    if is_valid_response(&policies, r.status(), r.headers()) {
+                                        return Ok(res);
+                                    }
                                 }
-
-                                if tries == max_retry {
-                                    log::debug!("Request max retry reached");
-                                    return Ok(res);
+                                ConnectResponse::Tunnel(head, _) => {
+                                    if is_valid_response(&policies, head.status, head.headers()) {
+                                        return Ok(res);
+                                    }
                                 }
-
-                                tries += 1;
                             }
-                            ConnectResponse::Tunnel(head, _) => {
-                                if is_valid_response(&policies, head.status, head.headers()) {
-                                    return Ok(res);
-                                }
-
-                                if tries == max_retry {
-                                    log::debug!("Request max retry reached");
-                                    return Ok(res);
-                                }
-
-                                tries += 1;
-                            }
-                        },
-                        Err(e) => {
-                            if tries == max_retry {
-                                log::debug!("Request max retry reached");
-                                return Err(e);
-                            }
-
-                            tries += 1;
                         }
-                    }
-                },
+                    };
+
+                    // Exceed max retry so just return whatever response is received
+                    log::debug!("Request max retry reached");
+                    connector.call(ConnectRequest::Tunnel(head, addr)).await
+                }
             }
         })
     }
 }
 
+fn body_to_retry_body(body: &Body) -> Body {
+    match body {
+        Body::Empty => Body::Empty,
+        Body::Bytes(b) => Body::Bytes(b.clone()),
+        _ => Body::None
+    }
+}
+
 #[doc(hidden)]
 /// Clones [RequestHeadType] except for the extensions (not required for this middleware)
 fn clone_request_head_type(head_type: &RequestHeadType) -> RequestHeadType {
@@ -430,6 +280,9 @@ mod tests {
 
     #[actix_rt::test]
     async fn test_basic_policy() {
+        std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug");
+        env_logger::init();
+
         let client = ClientBuilder::new()
             .disable_redirects()
             .wrap(Retry::new(3).policy(vec![StatusCode::INTERNAL_SERVER_ERROR]))

From e87e636e840ca768a5d583af91d2dd80e5fb0154 Mon Sep 17 00:00:00 2001
From: joshbenaron <73971531+joshbenaron@users.noreply.github.com>
Date: Fri, 2 Apr 2021 19:06:39 +0100
Subject: [PATCH 4/4] Fix formatting

---
 awc/src/middleware/retry.rs | 39 ++++++++++++++++---------------------
 1 file changed, 17 insertions(+), 22 deletions(-)

diff --git a/awc/src/middleware/retry.rs b/awc/src/middleware/retry.rs
index 6e595733..1f052f22 100644
--- a/awc/src/middleware/retry.rs
+++ b/awc/src/middleware/retry.rs
@@ -60,8 +60,8 @@ impl Retry {
     ///     .finish();
     ///```
     pub fn policy<T>(mut self, p: T) -> Self
-        where
-            T: IntoRetryPolicy,
+    where
+        T: IntoRetryPolicy,
     {
         self.0.policies.push(p.into_policy());
         self
@@ -79,8 +79,8 @@ pub trait IntoRetryPolicy {
 }
 
 impl<T> IntoRetryPolicy for T
-    where
-        T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static,
+where
+    T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static,
 {
     fn into_policy(self) -> RetryPolicy {
         RetryPolicy::Custom(Box::new(self))
@@ -94,8 +94,8 @@ impl IntoRetryPolicy for Vec<StatusCode> {
 }
 
 impl<S> Transform<S, ConnectRequest> for Retry
-    where
-        S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static,
+where
+    S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
 {
     type Transform = RetryService<S>;
 
@@ -116,8 +116,8 @@ pub struct RetryService<S> {
 }
 
 impl<S> Service<ConnectRequest> for RetryService<S>
-    where
-        S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static,
+where
+    S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
 {
     type Response = S::Response;
     type Error = S::Error;
@@ -135,7 +135,7 @@ impl<S> Service<ConnectRequest> for RetryService<S>
         Box::pin(async move {
             match req {
                 ConnectRequest::Client(head, body, addr) => {
-                    for _ in 1..max_retry {
+                    for _ in 0..max_retry {
                         let h = clone_request_head_type(&head);
 
                         let result = connector
@@ -168,11 +168,8 @@ impl<S> Service<ConnectRequest> for RetryService<S>
 
                     // Exceed max retry so just return whatever response is received
                     log::debug!("Request max retry reached");
-                    connector.call(ConnectRequest::Client(
-                        head,
-                        body,
-                        addr,
-                    ))
+                    connector
+                        .call(ConnectRequest::Client(head, body, addr))
                         .await
                 }
                 ConnectRequest::Tunnel(head, addr) => {
@@ -189,13 +186,14 @@ impl<S> Service<ConnectRequest> for RetryService<S>
                                     }
                                 }
                                 ConnectResponse::Tunnel(head, _) => {
-                                    if is_valid_response(&policies, head.status, head.headers()) {
+                                    if is_valid_response(&policies, head.status, head.headers())
+                                    {
                                         return Ok(res);
                                     }
                                 }
                             }
                         }
-                    };
+                    }
 
                     // Exceed max retry so just return whatever response is received
                     log::debug!("Request max retry reached");
@@ -210,7 +208,7 @@ fn body_to_retry_body(body: &Body) -> Body {
     match body {
         Body::Empty => Body::Empty,
         Body::Bytes(b) => Body::Bytes(b.clone()),
-        _ => Body::None
+        _ => Body::None,
     }
 }
 
@@ -280,12 +278,9 @@ mod tests {
 
     #[actix_rt::test]
     async fn test_basic_policy() {
-        std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug");
-        env_logger::init();
-
         let client = ClientBuilder::new()
             .disable_redirects()
-            .wrap(Retry::new(3).policy(vec![StatusCode::INTERNAL_SERVER_ERROR]))
+            .wrap(Retry::new(1).policy(vec![StatusCode::INTERNAL_SERVER_ERROR]))
             .finish();
 
         let srv = actix_test::start(|| {
@@ -304,7 +299,7 @@ mod tests {
         let client = ClientBuilder::new()
             .disable_redirects()
             .wrap(
-                Retry::new(3).policy(|code: StatusCode, headers: &HeaderMap| {
+                Retry::new(2).policy(|code: StatusCode, headers: &HeaderMap| {
                     code.is_success() && headers.contains_key("SOME_HEADER")
                 }),
             )