static dispatch redirect future

This commit is contained in:
fakeshadow 2021-02-19 00:20:26 -08:00
parent d7bdd04336
commit 68d079ba36
2 changed files with 131 additions and 66 deletions

View File

@ -18,7 +18,6 @@ use crate::{Client, ClientConfig, ConnectRequest, ConnectResponse, ConnectorServ
/// This type can be used to construct an instance of `Client` through a /// This type can be used to construct an instance of `Client` through a
/// builder-like pattern. /// builder-like pattern.
pub struct ClientBuilder<S = (), Io = (), M = ()> { pub struct ClientBuilder<S = (), Io = (), M = ()> {
middleware: M,
default_headers: bool, default_headers: bool,
max_http_version: Option<http::Version>, max_http_version: Option<http::Version>,
stream_window_size: Option<u32>, stream_window_size: Option<u32>,
@ -26,6 +25,7 @@ pub struct ClientBuilder<S = (), Io = (), M = ()> {
headers: HeaderMap, headers: HeaderMap,
timeout: Option<Duration>, timeout: Option<Duration>,
connector: Connector<S, Io>, connector: Connector<S, Io>,
middleware: M,
} }
impl ClientBuilder { impl ClientBuilder {

View File

@ -1,14 +1,19 @@
use std::rc::Rc; use std::{
future::Future,
net::SocketAddr,
pin::Pin,
rc::Rc,
task::{Context, Poll},
};
use actix_http::client::InvalidUrl;
use actix_http::{ use actix_http::{
body::Body, body::Body,
client::SendRequestError, client::{InvalidUrl, SendRequestError},
http::{header, StatusCode, Uri}, http::{header, StatusCode, Uri},
RequestHead, RequestHeadType, RequestHead, RequestHeadType,
}; };
use actix_service::Service; use actix_service::Service;
use futures_core::future::LocalBoxFuture; use futures_core::ready;
use super::Transform; use super::Transform;
@ -62,76 +67,134 @@ where
{ {
type Response = S::Response; type Response = S::Response;
type Error = S::Error; type Error = S::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = RedirectServiceFuture<S>;
actix_service::forward_ready!(connector); actix_service::forward_ready!(connector);
fn call(&self, req: ConnectRequest) -> Self::Future { fn call(&self, req: ConnectRequest) -> Self::Future {
let connector = self.connector.clone(); match req {
let mut max_redirect_times = self.max_redirect_times; ConnectRequest::Tunnel(head, addr) => {
let fut = self.connector.call(ConnectRequest::Tunnel(head, addr));
RedirectServiceFuture::Tunnel { fut }
}
ConnectRequest::Client(head, body, addr) => {
let connector = self.connector.clone();
let max_redirect_times = self.max_redirect_times;
Box::pin(async move { // backup the uri for reuse schema and authority.
match req { let uri = match head {
// tunnel request is skipped. RequestHeadType::Owned(ref head) => head.uri.clone(),
ConnectRequest::Tunnel(head, addr) => { RequestHeadType::Rc(ref head, ..) => head.uri.clone(),
return connector.call(ConnectRequest::Tunnel(head, addr)).await };
}
ConnectRequest::Client(mut head, mut body, addr) => {
// backup the uri for reuse schema and authority.
let uri = match head {
RequestHeadType::Owned(ref head) => head.uri.clone(),
RequestHeadType::Rc(ref head, ..) => head.uri.clone(),
};
loop { let fut = connector.call(ConnectRequest::Client(head, body, addr));
let res = connector
.call(ConnectRequest::Client(head, body, addr.clone()))
.await?;
match res {
ConnectResponse::Client(res) => match res.head().status {
StatusCode::MOVED_PERMANENTLY
| StatusCode::FOUND
| StatusCode::SEE_OTHER
| StatusCode::TEMPORARY_REDIRECT
| StatusCode::PERMANENT_REDIRECT
if max_redirect_times > 0 =>
{
// rebuild uri from the location header value.
let uri = res
.headers()
.get(header::LOCATION)
.map(|value| {
Uri::builder()
.scheme(uri.scheme().cloned().unwrap())
.authority(uri.authority().cloned().unwrap())
.path_and_query(value.as_bytes())
})
.ok_or(SendRequestError::Url(
InvalidUrl::MissingScheme,
))?
.build()?;
// use a new request head. RedirectServiceFuture::Client {
let mut head_new = RequestHead::default(); fut,
head_new.uri = uri; max_redirect_times,
uri: Some(uri),
head = RequestHeadType::Owned(head_new); addr,
connector: Some(connector),
// throw body
body = Body::None;
max_redirect_times -= 1;
}
_ => return Ok(ConnectResponse::Client(res)),
},
_ => unreachable!(
" ConnectRequest::Tunnel is not handled by Redirect"
),
}
}
} }
} }
}) }
}
}
pin_project_lite::pin_project! {
#[project = RedirectServiceProj]
pub enum RedirectServiceFuture<S>
where
S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError>,
S: 'static
{
Tunnel { #[pin] fut: S::Future },
Client {
#[pin]
fut: S::Future,
max_redirect_times: u8,
uri: Option<Uri>,
addr: Option<SocketAddr>,
connector: Option<Rc<S>>
}
}
}
impl<S> Future for RedirectServiceFuture<S>
where
S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
{
type Output = Result<ConnectResponse, SendRequestError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project() {
RedirectServiceProj::Tunnel { fut } => fut.poll(cx),
RedirectServiceProj::Client {
fut,
max_redirect_times,
uri,
addr,
connector,
} => {
match ready!(fut.poll(cx))? {
ConnectResponse::Client(res) => match res.head().status {
StatusCode::MOVED_PERMANENTLY
| StatusCode::FOUND
| StatusCode::SEE_OTHER
| StatusCode::TEMPORARY_REDIRECT
| StatusCode::PERMANENT_REDIRECT
if *max_redirect_times > 0 =>
{
let uri = uri.take().unwrap();
// rebuild uri from the location header value.
let uri = res
.headers()
.get(header::LOCATION)
.map(|value| {
Uri::builder()
.scheme(uri.scheme().cloned().unwrap())
.authority(uri.authority().cloned().unwrap())
.path_and_query(value.as_bytes())
})
.ok_or(SendRequestError::Url(InvalidUrl::MissingScheme))?
.build()?;
let addr = addr.take();
let connector = connector.take();
let mut max_redirect_times = *max_redirect_times;
// use a new request head.
let mut head = RequestHead::default();
head.uri = uri.clone();
let head = RequestHeadType::Owned(head);
// throw body
let body = Body::None;
max_redirect_times -= 1;
let fut = connector
.as_ref()
.unwrap()
.call(ConnectRequest::Client(head, body, addr));
self.as_mut().set(RedirectServiceFuture::Client {
fut,
max_redirect_times,
uri: Some(uri),
addr,
connector,
});
self.poll(cx)
}
_ => Poll::Ready(Ok(ConnectResponse::Client(res))),
},
_ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"),
}
}
}
} }
} }
@ -146,6 +209,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_basic_redirect() { async fn test_basic_redirect() {
let client = ClientBuilder::new() let client = ClientBuilder::new()
.connector(crate::Connector::new())
.wrap(RedirectMiddleware::new().max_redirect_times(10)) .wrap(RedirectMiddleware::new().max_redirect_times(10))
.finish(); .finish();
@ -172,6 +236,7 @@ mod tests {
async fn test_redirect_limit() { async fn test_redirect_limit() {
let client = ClientBuilder::new() let client = ClientBuilder::new()
.wrap(RedirectMiddleware::new().max_redirect_times(1)) .wrap(RedirectMiddleware::new().max_redirect_times(1))
.connector(crate::Connector::new())
.finish(); .finish();
let srv = start(|| { let srv = start(|| {