From bb1bf5f2be04fb3b1f0200c0912d507e9ba40a37 Mon Sep 17 00:00:00 2001 From: joshbenaron <73971531+joshbenaron@users.noreply.github.com> Date: Fri, 2 Apr 2021 13:27:30 +0100 Subject: [PATCH 1/4] Add awc retry middleware --- awc/src/middleware/mod.rs | 2 + awc/src/middleware/retry.rs | 472 ++++++++++++++++++++++++++++++++++++ 2 files changed, 474 insertions(+) create mode 100644 awc/src/middleware/retry.rs diff --git a/awc/src/middleware/mod.rs b/awc/src/middleware/mod.rs index 330e3b7f..df66ad70 100644 --- a/awc/src/middleware/mod.rs +++ b/awc/src/middleware/mod.rs @@ -1,6 +1,8 @@ mod redirect; +mod retry; pub use self::redirect::Redirect; +pub use self::retry::Retry; use std::marker::PhantomData; diff --git a/awc/src/middleware/retry.rs b/awc/src/middleware/retry.rs new file mode 100644 index 00000000..27bb3a1c --- /dev/null +++ b/awc/src/middleware/retry.rs @@ -0,0 +1,472 @@ +use super::Transform; +use std::rc::Rc; +use actix_http::RequestHeadType; +use actix_http::http::{StatusCode, HeaderMap}; +use std::ops::Deref; +use crate::{ConnectRequest, ConnectResponse}; +use actix_service::Service; +use actix_http::client::SendRequestError; +use std::task::{Context, Poll}; +use crate::RequestHead; +use futures_core::future::LocalBoxFuture; +use actix_http::body::Body; + +pub struct Retry(Inner); + +struct Inner { + /// Number of retries. So each request will be tried [max_retry + 1] times + max_retry: u8, + policies: Vec<RetryPolicy>, +} + +impl Retry { + pub fn new(retries: u8) -> Self { + Retry(Inner { + max_retry: retries, + policies: vec![], + }) + } + + /// Allows you to add a retry policy to the [`policies`] + /// It allows two types of policy: + /// - `Vec<StatusCode>` and will retry if one of them is received + /// - `Fn(&ResponseHead) -> bool` and will retry when this function resolves to false + /// + /// # example + /// + ///``` + /// + /// // Creates a policy which will try each request a max of 5 times if any policies resolve to true + /// // i.e. + /// // if you receive a 401 or 501 status code + /// // or + /// // the response doesn't have a [`SOME_HEADER`] header + /// use awc::http::{StatusCode, HeaderMap}; + /// use awc::middleware::Retry; + /// + /// let retry_policies = Retry::new(5) + /// .policy(vec![StatusCode::INTERNAL_SERVER_ERROR, StatusCode::UNAUTHORIZED]) + /// .policy(|code: StatusCode, headers: &HeaderMap| { + /// return if headers.contains_key("SOME_HEADER") { + /// true + /// } else { + /// false + /// }; + /// }); + /// + /// // Creates awc client + /// let client = awc::Client::builder() + /// .wrap(retry_policies) + /// .finish(); + ///``` + pub fn policy<T>(mut self, p: T) -> Self + where T: IntoRetryPolicy + { + self.0.policies.push(p.into_policy()); + self + } +} + +#[non_exhaustive] +pub enum RetryPolicy { + Status(Vec<StatusCode>), + Custom(Box<dyn Fn(StatusCode, &HeaderMap) -> bool>), +} + +pub trait IntoRetryPolicy { + fn into_policy(self) -> RetryPolicy; +} + +impl<T> IntoRetryPolicy for T + where T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static +{ + fn into_policy(self) -> RetryPolicy { + RetryPolicy::Custom(Box::new(self)) + } +} + +impl IntoRetryPolicy for Vec<StatusCode> { + fn into_policy(self) -> RetryPolicy { + RetryPolicy::Status(self) + } +} + +impl<S> Transform<S, ConnectRequest> for Retry + where + S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static, +{ + type Transform = RetryService<S>; + + fn new_transform(self, service: S) -> Self::Transform { + RetryService { + max_retry: self.0.max_retry, + policies: self.0.policies.into_boxed_slice().into(), + connector: service.into(), + } + } +} + +#[derive(Clone)] +pub struct RetryService<S> { + policies: Rc<[RetryPolicy]>, + max_retry: u8, + connector: Rc<S>, +} + +impl<S> Service<ConnectRequest> for RetryService<S> + where + S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; + + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.connector.poll_ready(ctx) + } + + fn call(&self, req: ConnectRequest) -> Self::Future { + let connector = self.connector.clone(); + let policies = self.policies.clone(); + let max_retry = self.max_retry; + + Box::pin(async move { + let mut tries = 0; + match req { + ConnectRequest::Client(head, body, addr) => { + match body { + Body::Bytes(b) => { + loop { + let h = clone_request_head_type(&head); + + match connector.call(ConnectRequest::Client(h, Body::Bytes(b.clone()), addr)).await + { + Ok(res) => { + // ConnectResponse + match &res { + ConnectResponse::Client(ref r) => { + if is_valid_response(policies.as_ref(), r.status(), r.headers()) { + return Ok(res); + } + + if tries == max_retry { + return Ok(res); + } + + tries += 1; + } + ConnectResponse::Tunnel(ref head, _) => { + if is_valid_response(policies.as_ref(), head.status, head.headers()) { + return Ok(res); + } + + if tries == max_retry { + return Ok(res); + } + + tries += 1; + } + } + } + // SendRequestError + Err(e) => { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Err(e); + } + + tries += 1; + } + } + } + } + Body::Empty => { + loop { + let h = clone_request_head_type(&head); + + match connector.call(ConnectRequest::Client(h, Body::Empty, addr)).await + { + Ok(res) => { + // ConnectResponse + match &res { + ConnectResponse::Client(ref r) => { + if is_valid_response(policies.as_ref(), r.status(), r.headers()) { + return Ok(res); + } + + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res); + } + + tries += 1; + } + ConnectResponse::Tunnel(ref head, _) => { + if is_valid_response(policies.as_ref(), head.status, head.headers()) { + return Ok(res); + } else { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res); + } + + tries += 1; + } + } + } + } + // SendRequestError + Err(e) => { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Err(e); + } + + tries += 1; + } + } + } + } + _ => { + log::debug!("Non cloneable body type given - defaulting to `Body::None`"); + loop { + let h = clone_request_head_type(&head); + + match connector.call(ConnectRequest::Client(h, Body::None, addr)).await + { + Ok(res) => { + // ConnectResponse + match &res { + ConnectResponse::Client(ref r) => { + if is_valid_response(policies.as_ref(), r.status(), r.headers()) { + return Ok(res); + } + + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res); + } + + tries += 1; + } + ConnectResponse::Tunnel(ref head, _) => { + if is_valid_response(policies.as_ref(), head.status, head.headers()) { + return Ok(res); + } else { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res); + } + + tries += 1; + } + } + } + } + // SendRequestError + Err(e) => { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Err(e); + } + + tries += 1; + } + } + } + } + } + } + ConnectRequest::Tunnel(head, addr) => { + loop { + let h = clone_request_head(&head); + + match connector.call(ConnectRequest::Tunnel(h, addr)).await { + Ok(res) => { + match &res { + ConnectResponse::Client(r) => { + if is_valid_response(&policies, r.status(), r.headers()) { + return Ok(res) + } + + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res) + } + + tries += 1; + } + ConnectResponse::Tunnel(head, _) => { + if is_valid_response(&policies, head.status, head.headers()) { + return Ok(res) + } + + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res) + } + + tries += 1; + } + } + }, + Err(e) => { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Err(e) + } + + tries += 1; + } + } + } + } + } + }) + } +} + +#[doc(hidden)] +/// Clones [RequestHeadType] except for the extensions (not required for this middleware) +fn clone_request_head_type(head_type: &RequestHeadType) -> RequestHeadType { + match head_type { + RequestHeadType::Owned(h) => { + let mut inner_head = RequestHead::default(); + inner_head.uri = h.uri.clone(); + inner_head.method = h.method.clone(); + inner_head.version = h.version; + inner_head.peer_addr = h.peer_addr; + inner_head.headers = h.headers.clone(); + + RequestHeadType::Owned(inner_head) + } + RequestHeadType::Rc(h, header_map) => { + RequestHeadType::Rc(h.clone(), header_map.clone()) + } + } +} + +#[doc(hidden)] +/// Clones [RequestHeadType] except for the extensions (not required for this middleware) +fn clone_request_head(head: &RequestHead) -> RequestHead { + let mut new_head = RequestHead::default(); + new_head.uri = head.uri.clone(); + new_head.method = head.method.clone(); + new_head.version = head.version; + new_head.headers = head.headers.clone(); + new_head.peer_addr = head.peer_addr; + + new_head +} + +#[doc(hidden)] +/// Checks whether the response matches the policies +fn is_valid_response(policies: &[RetryPolicy], status_code: StatusCode, headers: &HeaderMap) -> bool { + policies.iter().all(|policy| { + match policy { + RetryPolicy::Status(v) => { + // is valid if: + // - the list of status codes is empty + // or + // - the list doesn't contain the received status code + v.is_empty() || !v.contains(&status_code) + } + RetryPolicy::Custom(func) => { + // custom policy + (func.deref())(status_code, headers) + } + } + }) +} + +#[cfg(test)] +mod tests { + use actix_web::{web, App, Error, HttpResponse}; + + use super::*; + use crate::ClientBuilder; + + #[actix_rt::test] + async fn test_basic_policy() { + let client = ClientBuilder::new() + .disable_redirects() + .wrap(Retry::new(3) + .policy(vec![StatusCode::INTERNAL_SERVER_ERROR]) + ) + .finish(); + + let srv = actix_test::start(|| { + App::new() + .service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::InternalServerError() + .finish(), + ) + }))) + }); + + let res = client.get(srv.url("/test")).send().await.unwrap(); + + assert_eq!(res.status().as_u16(), 500); + } + + #[actix_rt::test] + async fn test_header_policy() { + std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug"); + env_logger::init(); + + let client = ClientBuilder::new() + .disable_redirects() + .wrap(Retry::new(3) + .policy(|code: StatusCode, headers: &HeaderMap| { + code.is_success() && headers.contains_key("SOME_HEADER") + }) + ) + .finish(); + + let srv = actix_test::start(|| { + App::new() + .service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .insert_header(("SOME_HEADER", "test")) + .finish(), + ) + }))) + }); + + let res = client.get(srv.url("/test")).send().await.unwrap(); + + assert_eq!(res.status().as_u16(), 200); + } + + #[actix_rt::test] + async fn test_bad_header_policy() { + std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug"); + env_logger::init(); + + let client = ClientBuilder::new() + .disable_redirects() + .wrap(Retry::new(3) + .policy(|code: StatusCode, headers: &HeaderMap| { + code.is_success() && headers.contains_key("WRONG_HEADER") + }) + ) + .finish(); + + let srv = actix_test::start(|| { + App::new() + .service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .insert_header(("SOME_HEADER", "test")) + .finish(), + ) + }))) + }); + + let res = client.get(srv.url("/test")).send().await.unwrap(); + + assert_eq!(res.status().as_u16(), 200); + } +} From 58904b9ebc69d57511ae4d718b21ef6dcb3b2b7c Mon Sep 17 00:00:00 2001 From: joshbenaron <73971531+joshbenaron@users.noreply.github.com> Date: Fri, 2 Apr 2021 14:50:43 +0100 Subject: [PATCH 2/4] Fixed invalid tests --- awc/src/middleware/retry.rs | 228 ++++++++++++++++++++---------------- 1 file changed, 128 insertions(+), 100 deletions(-) diff --git a/awc/src/middleware/retry.rs b/awc/src/middleware/retry.rs index 27bb3a1c..32b54ffc 100644 --- a/awc/src/middleware/retry.rs +++ b/awc/src/middleware/retry.rs @@ -1,15 +1,15 @@ use super::Transform; -use std::rc::Rc; -use actix_http::RequestHeadType; -use actix_http::http::{StatusCode, HeaderMap}; -use std::ops::Deref; -use crate::{ConnectRequest, ConnectResponse}; -use actix_service::Service; -use actix_http::client::SendRequestError; -use std::task::{Context, Poll}; use crate::RequestHead; -use futures_core::future::LocalBoxFuture; +use crate::{ConnectRequest, ConnectResponse}; use actix_http::body::Body; +use actix_http::client::SendRequestError; +use actix_http::http::{HeaderMap, StatusCode}; +use actix_http::RequestHeadType; +use actix_service::Service; +use futures_core::future::LocalBoxFuture; +use std::ops::Deref; +use std::rc::Rc; +use std::task::{Context, Poll}; pub struct Retry(Inner); @@ -60,7 +60,8 @@ impl Retry { /// .finish(); ///``` pub fn policy<T>(mut self, p: T) -> Self - where T: IntoRetryPolicy + where + T: IntoRetryPolicy, { self.0.policies.push(p.into_policy()); self @@ -78,7 +79,8 @@ pub trait IntoRetryPolicy { } impl<T> IntoRetryPolicy for T - where T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static +where + T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static, { fn into_policy(self) -> RetryPolicy { RetryPolicy::Custom(Box::new(self)) @@ -92,8 +94,8 @@ impl IntoRetryPolicy for Vec<StatusCode> { } impl<S> Transform<S, ConnectRequest> for Retry - where - S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static, +where + S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static, { type Transform = RetryService<S>; @@ -114,8 +116,8 @@ pub struct RetryService<S> { } impl<S> Service<ConnectRequest> for RetryService<S> - where - S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static, +where + S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static, { type Response = S::Response; type Error = S::Error; @@ -139,13 +141,23 @@ impl<S> Service<ConnectRequest> for RetryService<S> loop { let h = clone_request_head_type(&head); - match connector.call(ConnectRequest::Client(h, Body::Bytes(b.clone()), addr)).await + match connector + .call(ConnectRequest::Client( + h, + Body::Bytes(b.clone()), + addr, + )) + .await { Ok(res) => { // ConnectResponse match &res { ConnectResponse::Client(ref r) => { - if is_valid_response(policies.as_ref(), r.status(), r.headers()) { + if is_valid_response( + policies.as_ref(), + r.status(), + r.headers(), + ) { return Ok(res); } @@ -156,7 +168,11 @@ impl<S> Service<ConnectRequest> for RetryService<S> tries += 1; } ConnectResponse::Tunnel(ref head, _) => { - if is_valid_response(policies.as_ref(), head.status, head.headers()) { + if is_valid_response( + policies.as_ref(), + head.status, + head.headers(), + ) { return Ok(res); } @@ -184,13 +200,19 @@ impl<S> Service<ConnectRequest> for RetryService<S> loop { let h = clone_request_head_type(&head); - match connector.call(ConnectRequest::Client(h, Body::Empty, addr)).await + match connector + .call(ConnectRequest::Client(h, Body::Empty, addr)) + .await { Ok(res) => { // ConnectResponse match &res { ConnectResponse::Client(ref r) => { - if is_valid_response(policies.as_ref(), r.status(), r.headers()) { + if is_valid_response( + policies.as_ref(), + r.status(), + r.headers(), + ) { return Ok(res); } @@ -202,11 +224,17 @@ impl<S> Service<ConnectRequest> for RetryService<S> tries += 1; } ConnectResponse::Tunnel(ref head, _) => { - if is_valid_response(policies.as_ref(), head.status, head.headers()) { + if is_valid_response( + policies.as_ref(), + head.status, + head.headers(), + ) { return Ok(res); } else { if tries == max_retry { - log::debug!("Request max retry reached"); + log::debug!( + "Request max retry reached" + ); return Ok(res); } @@ -228,17 +256,25 @@ impl<S> Service<ConnectRequest> for RetryService<S> } } _ => { - log::debug!("Non cloneable body type given - defaulting to `Body::None`"); + log::debug!( + "Non cloneable body type given - defaulting to `Body::None`" + ); loop { let h = clone_request_head_type(&head); - match connector.call(ConnectRequest::Client(h, Body::None, addr)).await + match connector + .call(ConnectRequest::Client(h, Body::None, addr)) + .await { Ok(res) => { // ConnectResponse match &res { ConnectResponse::Client(ref r) => { - if is_valid_response(policies.as_ref(), r.status(), r.headers()) { + if is_valid_response( + policies.as_ref(), + r.status(), + r.headers(), + ) { return Ok(res); } @@ -250,11 +286,17 @@ impl<S> Service<ConnectRequest> for RetryService<S> tries += 1; } ConnectResponse::Tunnel(ref head, _) => { - if is_valid_response(policies.as_ref(), head.status, head.headers()) { + if is_valid_response( + policies.as_ref(), + head.status, + head.headers(), + ) { return Ok(res); } else { if tries == max_retry { - log::debug!("Request max retry reached"); + log::debug!( + "Request max retry reached" + ); return Ok(res); } @@ -277,50 +319,46 @@ impl<S> Service<ConnectRequest> for RetryService<S> } } } - ConnectRequest::Tunnel(head, addr) => { - loop { - let h = clone_request_head(&head); + ConnectRequest::Tunnel(head, addr) => loop { + let h = clone_request_head(&head); - match connector.call(ConnectRequest::Tunnel(h, addr)).await { - Ok(res) => { - match &res { - ConnectResponse::Client(r) => { - if is_valid_response(&policies, r.status(), r.headers()) { - return Ok(res) - } - - if tries == max_retry { - log::debug!("Request max retry reached"); - return Ok(res) - } - - tries += 1; - } - ConnectResponse::Tunnel(head, _) => { - if is_valid_response(&policies, head.status, head.headers()) { - return Ok(res) - } - - if tries == max_retry { - log::debug!("Request max retry reached"); - return Ok(res) - } - - tries += 1; - } + match connector.call(ConnectRequest::Tunnel(h, addr)).await { + Ok(res) => match &res { + ConnectResponse::Client(r) => { + if is_valid_response(&policies, r.status(), r.headers()) { + return Ok(res); } - }, - Err(e) => { + if tries == max_retry { log::debug!("Request max retry reached"); - return Err(e) + return Ok(res); } tries += 1; } + ConnectResponse::Tunnel(head, _) => { + if is_valid_response(&policies, head.status, head.headers()) { + return Ok(res); + } + + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res); + } + + tries += 1; + } + }, + Err(e) => { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Err(e); + } + + tries += 1; } } - } + }, } }) } @@ -361,7 +399,11 @@ fn clone_request_head(head: &RequestHead) -> RequestHead { #[doc(hidden)] /// Checks whether the response matches the policies -fn is_valid_response(policies: &[RetryPolicy], status_code: StatusCode, headers: &HeaderMap) -> bool { +fn is_valid_response( + policies: &[RetryPolicy], + status_code: StatusCode, + headers: &HeaderMap, +) -> bool { policies.iter().all(|policy| { match policy { RetryPolicy::Status(v) => { @@ -390,19 +432,13 @@ mod tests { async fn test_basic_policy() { let client = ClientBuilder::new() .disable_redirects() - .wrap(Retry::new(3) - .policy(vec![StatusCode::INTERNAL_SERVER_ERROR]) - ) + .wrap(Retry::new(3).policy(vec![StatusCode::INTERNAL_SERVER_ERROR])) .finish(); let srv = actix_test::start(|| { - App::new() - .service(web::resource("/test").route(web::to(|| async { - Ok::<_, Error>( - HttpResponse::InternalServerError() - .finish(), - ) - }))) + App::new().service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>(HttpResponse::InternalServerError().finish()) + }))) }); let res = client.get(srv.url("/test")).send().await.unwrap(); @@ -412,27 +448,23 @@ mod tests { #[actix_rt::test] async fn test_header_policy() { - std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug"); - env_logger::init(); - let client = ClientBuilder::new() .disable_redirects() - .wrap(Retry::new(3) - .policy(|code: StatusCode, headers: &HeaderMap| { + .wrap( + Retry::new(3).policy(|code: StatusCode, headers: &HeaderMap| { code.is_success() && headers.contains_key("SOME_HEADER") - }) + }), ) .finish(); let srv = actix_test::start(|| { - App::new() - .service(web::resource("/test").route(web::to(|| async { - Ok::<_, Error>( - HttpResponse::Ok() - .insert_header(("SOME_HEADER", "test")) - .finish(), - ) - }))) + App::new().service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .insert_header(("SOME_HEADER", "test")) + .finish(), + ) + }))) }); let res = client.get(srv.url("/test")).send().await.unwrap(); @@ -442,27 +474,23 @@ mod tests { #[actix_rt::test] async fn test_bad_header_policy() { - std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug"); - env_logger::init(); - let client = ClientBuilder::new() .disable_redirects() - .wrap(Retry::new(3) - .policy(|code: StatusCode, headers: &HeaderMap| { + .wrap( + Retry::new(3).policy(|code: StatusCode, headers: &HeaderMap| { code.is_success() && headers.contains_key("WRONG_HEADER") - }) + }), ) .finish(); let srv = actix_test::start(|| { - App::new() - .service(web::resource("/test").route(web::to(|| async { - Ok::<_, Error>( - HttpResponse::Ok() - .insert_header(("SOME_HEADER", "test")) - .finish(), - ) - }))) + App::new().service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .insert_header(("SOME_HEADER", "test")) + .finish(), + ) + }))) }); let res = client.get(srv.url("/test")).send().await.unwrap(); From 22e40613e87e5149a37595f3d116e477939a5bc7 Mon Sep 17 00:00:00 2001 From: joshbenaron <73971531+joshbenaron@users.noreply.github.com> Date: Fri, 2 Apr 2021 18:50:07 +0100 Subject: [PATCH 3/4] Improved readability and logic --- awc/src/middleware/retry.rs | 287 +++++++++--------------------------- 1 file changed, 70 insertions(+), 217 deletions(-) diff --git a/awc/src/middleware/retry.rs b/awc/src/middleware/retry.rs index 32b54ffc..6e595733 100644 --- a/awc/src/middleware/retry.rs +++ b/awc/src/middleware/retry.rs @@ -60,8 +60,8 @@ impl Retry { /// .finish(); ///``` pub fn policy<T>(mut self, p: T) -> Self - where - T: IntoRetryPolicy, + where + T: IntoRetryPolicy, { self.0.policies.push(p.into_policy()); self @@ -79,8 +79,8 @@ pub trait IntoRetryPolicy { } impl<T> IntoRetryPolicy for T -where - T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static, + where + T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static, { fn into_policy(self) -> RetryPolicy { RetryPolicy::Custom(Box::new(self)) @@ -94,8 +94,8 @@ impl IntoRetryPolicy for Vec<StatusCode> { } impl<S> Transform<S, ConnectRequest> for Retry -where - S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static, + where + S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static, { type Transform = RetryService<S>; @@ -116,8 +116,8 @@ pub struct RetryService<S> { } impl<S> Service<ConnectRequest> for RetryService<S> -where - S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static, + where + S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static, { type Response = S::Response; type Error = S::Error; @@ -133,237 +133,87 @@ where let max_retry = self.max_retry; Box::pin(async move { - let mut tries = 0; match req { ConnectRequest::Client(head, body, addr) => { - match body { - Body::Bytes(b) => { - loop { - let h = clone_request_head_type(&head); + for _ in 1..max_retry { + let h = clone_request_head_type(&head); - match connector - .call(ConnectRequest::Client( - h, - Body::Bytes(b.clone()), - addr, - )) - .await - { - Ok(res) => { - // ConnectResponse - match &res { - ConnectResponse::Client(ref r) => { - if is_valid_response( - policies.as_ref(), - r.status(), - r.headers(), - ) { - return Ok(res); - } + let result = connector + .call(ConnectRequest::Client(h, body_to_retry_body(&body), addr)) + .await; - if tries == max_retry { - return Ok(res); - } - - tries += 1; - } - ConnectResponse::Tunnel(ref head, _) => { - if is_valid_response( - policies.as_ref(), - head.status, - head.headers(), - ) { - return Ok(res); - } - - if tries == max_retry { - return Ok(res); - } - - tries += 1; - } - } - } - // SendRequestError - Err(e) => { - if tries == max_retry { - log::debug!("Request max retry reached"); - return Err(e); - } - - tries += 1; + if let Ok(res) = result { + match &res { + ConnectResponse::Client(ref r) => { + if is_valid_response( + policies.as_ref(), + r.status(), + r.headers(), + ) { + return Ok(res); } } - } - } - Body::Empty => { - loop { - let h = clone_request_head_type(&head); - - match connector - .call(ConnectRequest::Client(h, Body::Empty, addr)) - .await - { - Ok(res) => { - // ConnectResponse - match &res { - ConnectResponse::Client(ref r) => { - if is_valid_response( - policies.as_ref(), - r.status(), - r.headers(), - ) { - return Ok(res); - } - - if tries == max_retry { - log::debug!("Request max retry reached"); - return Ok(res); - } - - tries += 1; - } - ConnectResponse::Tunnel(ref head, _) => { - if is_valid_response( - policies.as_ref(), - head.status, - head.headers(), - ) { - return Ok(res); - } else { - if tries == max_retry { - log::debug!( - "Request max retry reached" - ); - return Ok(res); - } - - tries += 1; - } - } - } - } - // SendRequestError - Err(e) => { - if tries == max_retry { - log::debug!("Request max retry reached"); - return Err(e); - } - - tries += 1; - } - } - } - } - _ => { - log::debug!( - "Non cloneable body type given - defaulting to `Body::None`" - ); - loop { - let h = clone_request_head_type(&head); - - match connector - .call(ConnectRequest::Client(h, Body::None, addr)) - .await - { - Ok(res) => { - // ConnectResponse - match &res { - ConnectResponse::Client(ref r) => { - if is_valid_response( - policies.as_ref(), - r.status(), - r.headers(), - ) { - return Ok(res); - } - - if tries == max_retry { - log::debug!("Request max retry reached"); - return Ok(res); - } - - tries += 1; - } - ConnectResponse::Tunnel(ref head, _) => { - if is_valid_response( - policies.as_ref(), - head.status, - head.headers(), - ) { - return Ok(res); - } else { - if tries == max_retry { - log::debug!( - "Request max retry reached" - ); - return Ok(res); - } - - tries += 1; - } - } - } - } - // SendRequestError - Err(e) => { - if tries == max_retry { - log::debug!("Request max retry reached"); - return Err(e); - } - - tries += 1; + ConnectResponse::Tunnel(ref head, _) => { + if is_valid_response( + policies.as_ref(), + head.status, + head.headers(), + ) { + return Ok(res); } } } } } + + // Exceed max retry so just return whatever response is received + log::debug!("Request max retry reached"); + connector.call(ConnectRequest::Client( + head, + body, + addr, + )) + .await } - ConnectRequest::Tunnel(head, addr) => loop { - let h = clone_request_head(&head); + ConnectRequest::Tunnel(head, addr) => { + for _ in 1..max_retry { + let h = clone_request_head(&head); - match connector.call(ConnectRequest::Tunnel(h, addr)).await { - Ok(res) => match &res { - ConnectResponse::Client(r) => { - if is_valid_response(&policies, r.status(), r.headers()) { - return Ok(res); + let result = connector.call(ConnectRequest::Tunnel(h, addr)).await; + + if let Ok(res) = result { + match &res { + ConnectResponse::Client(r) => { + if is_valid_response(&policies, r.status(), r.headers()) { + return Ok(res); + } } - - if tries == max_retry { - log::debug!("Request max retry reached"); - return Ok(res); + ConnectResponse::Tunnel(head, _) => { + if is_valid_response(&policies, head.status, head.headers()) { + return Ok(res); + } } - - tries += 1; } - ConnectResponse::Tunnel(head, _) => { - if is_valid_response(&policies, head.status, head.headers()) { - return Ok(res); - } - - if tries == max_retry { - log::debug!("Request max retry reached"); - return Ok(res); - } - - tries += 1; - } - }, - Err(e) => { - if tries == max_retry { - log::debug!("Request max retry reached"); - return Err(e); - } - - tries += 1; } - } - }, + }; + + // Exceed max retry so just return whatever response is received + log::debug!("Request max retry reached"); + connector.call(ConnectRequest::Tunnel(head, addr)).await + } } }) } } +fn body_to_retry_body(body: &Body) -> Body { + match body { + Body::Empty => Body::Empty, + Body::Bytes(b) => Body::Bytes(b.clone()), + _ => Body::None + } +} + #[doc(hidden)] /// Clones [RequestHeadType] except for the extensions (not required for this middleware) fn clone_request_head_type(head_type: &RequestHeadType) -> RequestHeadType { @@ -430,6 +280,9 @@ mod tests { #[actix_rt::test] async fn test_basic_policy() { + std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug"); + env_logger::init(); + let client = ClientBuilder::new() .disable_redirects() .wrap(Retry::new(3).policy(vec![StatusCode::INTERNAL_SERVER_ERROR])) From e87e636e840ca768a5d583af91d2dd80e5fb0154 Mon Sep 17 00:00:00 2001 From: joshbenaron <73971531+joshbenaron@users.noreply.github.com> Date: Fri, 2 Apr 2021 19:06:39 +0100 Subject: [PATCH 4/4] Fix formatting --- awc/src/middleware/retry.rs | 39 ++++++++++++++++--------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/awc/src/middleware/retry.rs b/awc/src/middleware/retry.rs index 6e595733..1f052f22 100644 --- a/awc/src/middleware/retry.rs +++ b/awc/src/middleware/retry.rs @@ -60,8 +60,8 @@ impl Retry { /// .finish(); ///``` pub fn policy<T>(mut self, p: T) -> Self - where - T: IntoRetryPolicy, + where + T: IntoRetryPolicy, { self.0.policies.push(p.into_policy()); self @@ -79,8 +79,8 @@ pub trait IntoRetryPolicy { } impl<T> IntoRetryPolicy for T - where - T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static, +where + T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static, { fn into_policy(self) -> RetryPolicy { RetryPolicy::Custom(Box::new(self)) @@ -94,8 +94,8 @@ impl IntoRetryPolicy for Vec<StatusCode> { } impl<S> Transform<S, ConnectRequest> for Retry - where - S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static, +where + S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static, { type Transform = RetryService<S>; @@ -116,8 +116,8 @@ pub struct RetryService<S> { } impl<S> Service<ConnectRequest> for RetryService<S> - where - S: Service<ConnectRequest, Response=ConnectResponse, Error=SendRequestError> + 'static, +where + S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static, { type Response = S::Response; type Error = S::Error; @@ -135,7 +135,7 @@ impl<S> Service<ConnectRequest> for RetryService<S> Box::pin(async move { match req { ConnectRequest::Client(head, body, addr) => { - for _ in 1..max_retry { + for _ in 0..max_retry { let h = clone_request_head_type(&head); let result = connector @@ -168,11 +168,8 @@ impl<S> Service<ConnectRequest> for RetryService<S> // Exceed max retry so just return whatever response is received log::debug!("Request max retry reached"); - connector.call(ConnectRequest::Client( - head, - body, - addr, - )) + connector + .call(ConnectRequest::Client(head, body, addr)) .await } ConnectRequest::Tunnel(head, addr) => { @@ -189,13 +186,14 @@ impl<S> Service<ConnectRequest> for RetryService<S> } } ConnectResponse::Tunnel(head, _) => { - if is_valid_response(&policies, head.status, head.headers()) { + if is_valid_response(&policies, head.status, head.headers()) + { return Ok(res); } } } } - }; + } // Exceed max retry so just return whatever response is received log::debug!("Request max retry reached"); @@ -210,7 +208,7 @@ fn body_to_retry_body(body: &Body) -> Body { match body { Body::Empty => Body::Empty, Body::Bytes(b) => Body::Bytes(b.clone()), - _ => Body::None + _ => Body::None, } } @@ -280,12 +278,9 @@ mod tests { #[actix_rt::test] async fn test_basic_policy() { - std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug"); - env_logger::init(); - let client = ClientBuilder::new() .disable_redirects() - .wrap(Retry::new(3).policy(vec![StatusCode::INTERNAL_SERVER_ERROR])) + .wrap(Retry::new(1).policy(vec![StatusCode::INTERNAL_SERVER_ERROR])) .finish(); let srv = actix_test::start(|| { @@ -304,7 +299,7 @@ mod tests { let client = ClientBuilder::new() .disable_redirects() .wrap( - Retry::new(3).policy(|code: StatusCode, headers: &HeaderMap| { + Retry::new(2).policy(|code: StatusCode, headers: &HeaderMap| { code.is_success() && headers.contains_key("SOME_HEADER") }), )