diff --git a/awc/src/middleware/mod.rs b/awc/src/middleware/mod.rs index ae767c4e8..330e3b7fe 100644 --- a/awc/src/middleware/mod.rs +++ b/awc/src/middleware/mod.rs @@ -1,6 +1,6 @@ mod redirect; -pub use self::redirect::RedirectMiddleware; +pub use self::redirect::Redirect; use std::marker::PhantomData; diff --git a/awc/src/middleware/redirect.rs b/awc/src/middleware/redirect.rs index 4f7184b80..a08a0027e 100644 --- a/awc/src/middleware/redirect.rs +++ b/awc/src/middleware/redirect.rs @@ -4,6 +4,7 @@ use std::{ pin::Pin, rc::Rc, task::{Context, Poll}, + convert::TryFrom }; use actix_http::{ @@ -21,17 +22,17 @@ use super::Transform; use crate::connect::{ConnectRequest, ConnectResponse}; use crate::ClientResponse; -pub struct RedirectMiddleware { +pub struct Redirect { max_redirect_times: u8, } -impl Default for RedirectMiddleware { +impl Default for Redirect { fn default() -> Self { Self::new() } } -impl RedirectMiddleware { +impl Redirect { pub fn new() -> Self { Self { max_redirect_times: 10, @@ -44,7 +45,7 @@ impl RedirectMiddleware { } } -impl Transform for RedirectMiddleware +impl Transform for Redirect where S: Service + 'static, { @@ -150,115 +151,112 @@ where body, addr, connector, - } => { - match ready!(fut.poll(cx))? { - ConnectResponse::Client(res) => match res.head().status { - StatusCode::MOVED_PERMANENTLY - | StatusCode::FOUND - | StatusCode::SEE_OTHER - if *max_redirect_times > 0 => - { - let org_uri = uri.take().unwrap(); - // rebuild uri from the location header value. - let uri = rebuild_uri(&res, org_uri)?; + } => match ready!(fut.poll(cx))? { + ConnectResponse::Client(res) => match res.head().status { + StatusCode::MOVED_PERMANENTLY + | StatusCode::FOUND + | StatusCode::SEE_OTHER + if *max_redirect_times > 0 => + { + let org_uri = uri.take().unwrap(); + // rebuild uri from the location header value. + let uri = rebuild_uri(&res, org_uri)?; - // reset method - let method = method.take().unwrap(); - let method = match method { - Method::GET | Method::HEAD => method, - _ => Method::GET, - }; + // reset method + let method = method.take().unwrap(); + let method = match method { + Method::GET | Method::HEAD => method, + _ => Method::GET, + }; - // take ownership of states that could be reused - let addr = addr.take(); - let connector = connector.take(); - let mut max_redirect_times = *max_redirect_times; + // take ownership of states that could be reused + let addr = addr.take(); + let connector = connector.take(); + let mut max_redirect_times = *max_redirect_times; - // use a new request head. - let mut head = RequestHead::default(); - head.uri = uri.clone(); - head.method = method.clone(); + // use a new request head. + let mut head = RequestHead::default(); + head.uri = uri.clone(); + head.method = method.clone(); - let head = RequestHeadType::Owned(head); + let head = RequestHeadType::Owned(head); - max_redirect_times -= 1; + max_redirect_times -= 1; - let fut = connector - .as_ref() - .unwrap() - // remove body - .call(ConnectRequest::Client(head, Body::None, addr)); + let fut = connector + .as_ref() + .unwrap() + // remove body + .call(ConnectRequest::Client(head, Body::None, addr)); - self.as_mut().set(RedirectServiceFuture::Client { - fut, - max_redirect_times, - uri: Some(uri), - method: Some(method), - // body is dropped on 301,302,303 - body: None, - addr, - connector, - }); + self.as_mut().set(RedirectServiceFuture::Client { + fut, + max_redirect_times, + uri: Some(uri), + method: Some(method), + // body is dropped on 301,302,303 + body: None, + addr, + connector, + }); - self.poll(cx) - } - StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT - if *max_redirect_times > 0 => - { - let org_uri = uri.take().unwrap(); - // rebuild uri from the location header value. - let uri = rebuild_uri(&res, org_uri)?; + self.poll(cx) + } + StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT + if *max_redirect_times > 0 => + { + let org_uri = uri.take().unwrap(); + // rebuild uri from the location header value. + let uri = rebuild_uri(&res, org_uri)?; - // try to reuse body - let body = body.take(); - let body_new = match body { - Some(ref bytes) => Body::Bytes(bytes.clone()), - // TODO: should this be Body::Empty or Body::None. - _ => Body::Empty, - }; + // try to reuse body + let body = body.take(); + let body_new = match body { + Some(ref bytes) => Body::Bytes(bytes.clone()), + // TODO: should this be Body::Empty or Body::None. + _ => Body::Empty, + }; - let addr = addr.take(); - let method = method.take(); - let connector = connector.take(); - let mut max_redirect_times = *max_redirect_times; + let addr = addr.take(); + let method = method.take().unwrap(); + let connector = connector.take(); + let mut max_redirect_times = *max_redirect_times; - // use a new request head. - let mut head = RequestHead::default(); - head.uri = uri.clone(); + // use a new request head. + let mut head = RequestHead::default(); + head.uri = uri.clone(); + head.method = method.clone(); - let head = RequestHeadType::Owned(head); + let head = RequestHeadType::Owned(head); - max_redirect_times -= 1; + max_redirect_times -= 1; - let fut = connector - .as_ref() - .unwrap() - .call(ConnectRequest::Client(head, body_new, addr)); + let fut = connector + .as_ref() + .unwrap() + .call(ConnectRequest::Client(head, body_new, addr)); - self.as_mut().set(RedirectServiceFuture::Client { - fut, - max_redirect_times, - uri: Some(uri), - method, - body, - addr, - connector, - }); + self.as_mut().set(RedirectServiceFuture::Client { + fut, + max_redirect_times, + uri: Some(uri), + method: Some(method), + body, + addr, + connector, + }); - self.poll(cx) - } - _ => Poll::Ready(Ok(ConnectResponse::Client(res))), - }, - _ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"), - } - } + self.poll(cx) + } + _ => Poll::Ready(Ok(ConnectResponse::Client(res))), + }, + _ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"), + }, } } } fn rebuild_uri(res: &ClientResponse, org_uri: Uri) -> Result { - use std::convert::TryFrom; - let uri = res .headers() .get(header::LOCATION) @@ -295,7 +293,7 @@ mod tests { async fn test_basic_redirect() { let client = ClientBuilder::new() .connector(crate::Connector::new()) - .wrap(RedirectMiddleware::new().max_redirect_times(10)) + .wrap(Redirect::new().max_redirect_times(10)) .finish(); let srv = start(|| { @@ -320,7 +318,7 @@ mod tests { #[actix_rt::test] async fn test_redirect_limit() { let client = ClientBuilder::new() - .wrap(RedirectMiddleware::new().max_redirect_times(1)) + .wrap(Redirect::new().max_redirect_times(1)) .connector(crate::Connector::new()) .finish(); diff --git a/awc/src/response.rs b/awc/src/response.rs index 514b8a90b..40de3dc17 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -492,9 +492,7 @@ mod tests { JsonPayloadError::Payload(PayloadError::Overflow) => { matches!(other, JsonPayloadError::Payload(PayloadError::Overflow)) } - JsonPayloadError::ContentType => { - matches!(other, JsonPayloadError::ContentType) - } + JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType), _ => false, } }