diff --git a/awc/src/builder.rs b/awc/src/builder.rs index bfef55c1e..92bef9d1e 100644 --- a/awc/src/builder.rs +++ b/awc/src/builder.rs @@ -18,7 +18,6 @@ use crate::{Client, ClientConfig, ConnectRequest, ConnectResponse, ConnectorServ /// This type can be used to construct an instance of `Client` through a /// builder-like pattern. pub struct ClientBuilder { - middleware: M, default_headers: bool, max_http_version: Option, stream_window_size: Option, @@ -26,6 +25,7 @@ pub struct ClientBuilder { headers: HeaderMap, timeout: Option, connector: Connector, + middleware: M, } impl ClientBuilder { diff --git a/awc/src/middleware/redirect.rs b/awc/src/middleware/redirect.rs index 02d37c3b4..a92892676 100644 --- a/awc/src/middleware/redirect.rs +++ b/awc/src/middleware/redirect.rs @@ -1,14 +1,19 @@ -use std::rc::Rc; +use std::{ + future::Future, + net::SocketAddr, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; -use actix_http::client::InvalidUrl; use actix_http::{ body::Body, - client::SendRequestError, + client::{InvalidUrl, SendRequestError}, http::{header, StatusCode, Uri}, RequestHead, RequestHeadType, }; use actix_service::Service; -use futures_core::future::LocalBoxFuture; +use futures_core::ready; use super::Transform; @@ -62,76 +67,134 @@ where { type Response = S::Response; type Error = S::Error; - type Future = LocalBoxFuture<'static, Result>; + type Future = RedirectServiceFuture; actix_service::forward_ready!(connector); fn call(&self, req: ConnectRequest) -> Self::Future { - let connector = self.connector.clone(); - let mut max_redirect_times = self.max_redirect_times; + match req { + ConnectRequest::Tunnel(head, addr) => { + let fut = self.connector.call(ConnectRequest::Tunnel(head, addr)); + RedirectServiceFuture::Tunnel { fut } + } + ConnectRequest::Client(head, body, addr) => { + let connector = self.connector.clone(); + let max_redirect_times = self.max_redirect_times; - Box::pin(async move { - match req { - // tunnel request is skipped. - ConnectRequest::Tunnel(head, addr) => { - return connector.call(ConnectRequest::Tunnel(head, addr)).await - } - ConnectRequest::Client(mut head, mut body, addr) => { - // backup the uri for reuse schema and authority. - let uri = match head { - RequestHeadType::Owned(ref head) => head.uri.clone(), - RequestHeadType::Rc(ref head, ..) => head.uri.clone(), - }; + // backup the uri for reuse schema and authority. + let uri = match head { + RequestHeadType::Owned(ref head) => head.uri.clone(), + RequestHeadType::Rc(ref head, ..) => head.uri.clone(), + }; - loop { - let res = connector - .call(ConnectRequest::Client(head, body, addr.clone())) - .await?; - match res { - ConnectResponse::Client(res) => match res.head().status { - StatusCode::MOVED_PERMANENTLY - | StatusCode::FOUND - | StatusCode::SEE_OTHER - | StatusCode::TEMPORARY_REDIRECT - | StatusCode::PERMANENT_REDIRECT - if max_redirect_times > 0 => - { - // rebuild uri from the location header value. - let uri = res - .headers() - .get(header::LOCATION) - .map(|value| { - Uri::builder() - .scheme(uri.scheme().cloned().unwrap()) - .authority(uri.authority().cloned().unwrap()) - .path_and_query(value.as_bytes()) - }) - .ok_or(SendRequestError::Url( - InvalidUrl::MissingScheme, - ))? - .build()?; + let fut = connector.call(ConnectRequest::Client(head, body, addr)); - // use a new request head. - let mut head_new = RequestHead::default(); - head_new.uri = uri; - - head = RequestHeadType::Owned(head_new); - - // throw body - body = Body::None; - - max_redirect_times -= 1; - } - _ => return Ok(ConnectResponse::Client(res)), - }, - _ => unreachable!( - " ConnectRequest::Tunnel is not handled by Redirect" - ), - } - } + RedirectServiceFuture::Client { + fut, + max_redirect_times, + uri: Some(uri), + addr, + connector: Some(connector), } } - }) + } + } +} + +pin_project_lite::pin_project! { + #[project = RedirectServiceProj] + pub enum RedirectServiceFuture + where + S: Service, + S: 'static + { + Tunnel { #[pin] fut: S::Future }, + Client { + #[pin] + fut: S::Future, + max_redirect_times: u8, + uri: Option, + addr: Option, + connector: Option> + } + } +} + +impl Future for RedirectServiceFuture +where + S: Service + 'static, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project() { + RedirectServiceProj::Tunnel { fut } => fut.poll(cx), + RedirectServiceProj::Client { + fut, + max_redirect_times, + uri, + addr, + connector, + } => { + match ready!(fut.poll(cx))? { + ConnectResponse::Client(res) => match res.head().status { + StatusCode::MOVED_PERMANENTLY + | StatusCode::FOUND + | StatusCode::SEE_OTHER + | StatusCode::TEMPORARY_REDIRECT + | StatusCode::PERMANENT_REDIRECT + if *max_redirect_times > 0 => + { + let uri = uri.take().unwrap(); + + // rebuild uri from the location header value. + let uri = res + .headers() + .get(header::LOCATION) + .map(|value| { + Uri::builder() + .scheme(uri.scheme().cloned().unwrap()) + .authority(uri.authority().cloned().unwrap()) + .path_and_query(value.as_bytes()) + }) + .ok_or(SendRequestError::Url(InvalidUrl::MissingScheme))? + .build()?; + + let addr = addr.take(); + let connector = connector.take(); + let mut max_redirect_times = *max_redirect_times; + + // use a new request head. + let mut head = RequestHead::default(); + head.uri = uri.clone(); + let head = RequestHeadType::Owned(head); + + // throw body + let body = Body::None; + + max_redirect_times -= 1; + + let fut = connector + .as_ref() + .unwrap() + .call(ConnectRequest::Client(head, body, addr)); + + self.as_mut().set(RedirectServiceFuture::Client { + fut, + max_redirect_times, + uri: Some(uri), + addr, + connector, + }); + + self.poll(cx) + } + _ => Poll::Ready(Ok(ConnectResponse::Client(res))), + }, + _ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"), + } + } + } } } @@ -146,6 +209,7 @@ mod tests { #[actix_rt::test] async fn test_basic_redirect() { let client = ClientBuilder::new() + .connector(crate::Connector::new()) .wrap(RedirectMiddleware::new().max_redirect_times(10)) .finish(); @@ -172,6 +236,7 @@ mod tests { async fn test_redirect_limit() { let client = ClientBuilder::new() .wrap(RedirectMiddleware::new().max_redirect_times(1)) + .connector(crate::Connector::new()) .finish(); let srv = start(|| {