rename redirect middleware name. fix method not reset on 307/308

This commit is contained in:
fakeshadow 2021-02-23 17:53:37 +08:00
parent 9578b50bb2
commit 2a2fc6022b
3 changed files with 94 additions and 98 deletions

View File

@ -1,6 +1,6 @@
mod redirect; mod redirect;
pub use self::redirect::RedirectMiddleware; pub use self::redirect::Redirect;
use std::marker::PhantomData; use std::marker::PhantomData;

View File

@ -4,6 +4,7 @@ use std::{
pin::Pin, pin::Pin,
rc::Rc, rc::Rc,
task::{Context, Poll}, task::{Context, Poll},
convert::TryFrom
}; };
use actix_http::{ use actix_http::{
@ -21,17 +22,17 @@ use super::Transform;
use crate::connect::{ConnectRequest, ConnectResponse}; use crate::connect::{ConnectRequest, ConnectResponse};
use crate::ClientResponse; use crate::ClientResponse;
pub struct RedirectMiddleware { pub struct Redirect {
max_redirect_times: u8, max_redirect_times: u8,
} }
impl Default for RedirectMiddleware { impl Default for Redirect {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
} }
} }
impl RedirectMiddleware { impl Redirect {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
max_redirect_times: 10, max_redirect_times: 10,
@ -44,7 +45,7 @@ impl RedirectMiddleware {
} }
} }
impl<S> Transform<S, ConnectRequest> for RedirectMiddleware impl<S> Transform<S, ConnectRequest> for Redirect
where where
S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static, S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
{ {
@ -150,115 +151,112 @@ where
body, body,
addr, addr,
connector, connector,
} => { } => match ready!(fut.poll(cx))? {
match ready!(fut.poll(cx))? { ConnectResponse::Client(res) => match res.head().status {
ConnectResponse::Client(res) => match res.head().status { StatusCode::MOVED_PERMANENTLY
StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND
| StatusCode::FOUND | StatusCode::SEE_OTHER
| StatusCode::SEE_OTHER if *max_redirect_times > 0 =>
if *max_redirect_times > 0 => {
{ let org_uri = uri.take().unwrap();
let org_uri = uri.take().unwrap(); // rebuild uri from the location header value.
// rebuild uri from the location header value. let uri = rebuild_uri(&res, org_uri)?;
let uri = rebuild_uri(&res, org_uri)?;
// reset method // reset method
let method = method.take().unwrap(); let method = method.take().unwrap();
let method = match method { let method = match method {
Method::GET | Method::HEAD => method, Method::GET | Method::HEAD => method,
_ => Method::GET, _ => Method::GET,
}; };
// take ownership of states that could be reused // take ownership of states that could be reused
let addr = addr.take(); let addr = addr.take();
let connector = connector.take(); let connector = connector.take();
let mut max_redirect_times = *max_redirect_times; let mut max_redirect_times = *max_redirect_times;
// use a new request head. // use a new request head.
let mut head = RequestHead::default(); let mut head = RequestHead::default();
head.uri = uri.clone(); head.uri = uri.clone();
head.method = method.clone(); head.method = method.clone();
let head = RequestHeadType::Owned(head); let head = RequestHeadType::Owned(head);
max_redirect_times -= 1; max_redirect_times -= 1;
let fut = connector let fut = connector
.as_ref() .as_ref()
.unwrap() .unwrap()
// remove body // remove body
.call(ConnectRequest::Client(head, Body::None, addr)); .call(ConnectRequest::Client(head, Body::None, addr));
self.as_mut().set(RedirectServiceFuture::Client { self.as_mut().set(RedirectServiceFuture::Client {
fut, fut,
max_redirect_times, max_redirect_times,
uri: Some(uri), uri: Some(uri),
method: Some(method), method: Some(method),
// body is dropped on 301,302,303 // body is dropped on 301,302,303
body: None, body: None,
addr, addr,
connector, connector,
}); });
self.poll(cx) self.poll(cx)
} }
StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT
if *max_redirect_times > 0 => if *max_redirect_times > 0 =>
{ {
let org_uri = uri.take().unwrap(); let org_uri = uri.take().unwrap();
// rebuild uri from the location header value. // rebuild uri from the location header value.
let uri = rebuild_uri(&res, org_uri)?; let uri = rebuild_uri(&res, org_uri)?;
// try to reuse body // try to reuse body
let body = body.take(); let body = body.take();
let body_new = match body { let body_new = match body {
Some(ref bytes) => Body::Bytes(bytes.clone()), Some(ref bytes) => Body::Bytes(bytes.clone()),
// TODO: should this be Body::Empty or Body::None. // TODO: should this be Body::Empty or Body::None.
_ => Body::Empty, _ => Body::Empty,
}; };
let addr = addr.take(); let addr = addr.take();
let method = method.take(); let method = method.take().unwrap();
let connector = connector.take(); let connector = connector.take();
let mut max_redirect_times = *max_redirect_times; let mut max_redirect_times = *max_redirect_times;
// use a new request head. // use a new request head.
let mut head = RequestHead::default(); let mut head = RequestHead::default();
head.uri = uri.clone(); head.uri = uri.clone();
head.method = method.clone();
let head = RequestHeadType::Owned(head); let head = RequestHeadType::Owned(head);
max_redirect_times -= 1; max_redirect_times -= 1;
let fut = connector let fut = connector
.as_ref() .as_ref()
.unwrap() .unwrap()
.call(ConnectRequest::Client(head, body_new, addr)); .call(ConnectRequest::Client(head, body_new, addr));
self.as_mut().set(RedirectServiceFuture::Client { self.as_mut().set(RedirectServiceFuture::Client {
fut, fut,
max_redirect_times, max_redirect_times,
uri: Some(uri), uri: Some(uri),
method, method: Some(method),
body, body,
addr, addr,
connector, connector,
}); });
self.poll(cx) self.poll(cx)
} }
_ => Poll::Ready(Ok(ConnectResponse::Client(res))), _ => Poll::Ready(Ok(ConnectResponse::Client(res))),
}, },
_ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"), _ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"),
} },
}
} }
} }
} }
fn rebuild_uri(res: &ClientResponse, org_uri: Uri) -> Result<Uri, SendRequestError> { fn rebuild_uri(res: &ClientResponse, org_uri: Uri) -> Result<Uri, SendRequestError> {
use std::convert::TryFrom;
let uri = res let uri = res
.headers() .headers()
.get(header::LOCATION) .get(header::LOCATION)
@ -295,7 +293,7 @@ mod tests {
async fn test_basic_redirect() { async fn test_basic_redirect() {
let client = ClientBuilder::new() let client = ClientBuilder::new()
.connector(crate::Connector::new()) .connector(crate::Connector::new())
.wrap(RedirectMiddleware::new().max_redirect_times(10)) .wrap(Redirect::new().max_redirect_times(10))
.finish(); .finish();
let srv = start(|| { let srv = start(|| {
@ -320,7 +318,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_redirect_limit() { async fn test_redirect_limit() {
let client = ClientBuilder::new() let client = ClientBuilder::new()
.wrap(RedirectMiddleware::new().max_redirect_times(1)) .wrap(Redirect::new().max_redirect_times(1))
.connector(crate::Connector::new()) .connector(crate::Connector::new())
.finish(); .finish();

View File

@ -492,9 +492,7 @@ mod tests {
JsonPayloadError::Payload(PayloadError::Overflow) => { JsonPayloadError::Payload(PayloadError::Overflow) => {
matches!(other, JsonPayloadError::Payload(PayloadError::Overflow)) matches!(other, JsonPayloadError::Payload(PayloadError::Overflow))
} }
JsonPayloadError::ContentType => { JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType),
matches!(other, JsonPayloadError::ContentType)
}
_ => false, _ => false,
} }
} }