mirror of https://github.com/fafhrd91/actix-web
rename redirect middleware name. fix method not reset on 307/308
This commit is contained in:
parent
9578b50bb2
commit
2a2fc6022b
|
@ -1,6 +1,6 @@
|
|||
mod redirect;
|
||||
|
||||
pub use self::redirect::RedirectMiddleware;
|
||||
pub use self::redirect::Redirect;
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ use std::{
|
|||
pin::Pin,
|
||||
rc::Rc,
|
||||
task::{Context, Poll},
|
||||
convert::TryFrom
|
||||
};
|
||||
|
||||
use actix_http::{
|
||||
|
@ -21,17 +22,17 @@ use super::Transform;
|
|||
use crate::connect::{ConnectRequest, ConnectResponse};
|
||||
use crate::ClientResponse;
|
||||
|
||||
pub struct RedirectMiddleware {
|
||||
pub struct Redirect {
|
||||
max_redirect_times: u8,
|
||||
}
|
||||
|
||||
impl Default for RedirectMiddleware {
|
||||
impl Default for Redirect {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl RedirectMiddleware {
|
||||
impl Redirect {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
max_redirect_times: 10,
|
||||
|
@ -44,7 +45,7 @@ impl RedirectMiddleware {
|
|||
}
|
||||
}
|
||||
|
||||
impl<S> Transform<S, ConnectRequest> for RedirectMiddleware
|
||||
impl<S> Transform<S, ConnectRequest> for Redirect
|
||||
where
|
||||
S: Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError> + 'static,
|
||||
{
|
||||
|
@ -150,115 +151,112 @@ where
|
|||
body,
|
||||
addr,
|
||||
connector,
|
||||
} => {
|
||||
match ready!(fut.poll(cx))? {
|
||||
ConnectResponse::Client(res) => match res.head().status {
|
||||
StatusCode::MOVED_PERMANENTLY
|
||||
| StatusCode::FOUND
|
||||
| StatusCode::SEE_OTHER
|
||||
if *max_redirect_times > 0 =>
|
||||
{
|
||||
let org_uri = uri.take().unwrap();
|
||||
// rebuild uri from the location header value.
|
||||
let uri = rebuild_uri(&res, org_uri)?;
|
||||
} => match ready!(fut.poll(cx))? {
|
||||
ConnectResponse::Client(res) => match res.head().status {
|
||||
StatusCode::MOVED_PERMANENTLY
|
||||
| StatusCode::FOUND
|
||||
| StatusCode::SEE_OTHER
|
||||
if *max_redirect_times > 0 =>
|
||||
{
|
||||
let org_uri = uri.take().unwrap();
|
||||
// rebuild uri from the location header value.
|
||||
let uri = rebuild_uri(&res, org_uri)?;
|
||||
|
||||
// reset method
|
||||
let method = method.take().unwrap();
|
||||
let method = match method {
|
||||
Method::GET | Method::HEAD => method,
|
||||
_ => Method::GET,
|
||||
};
|
||||
// reset method
|
||||
let method = method.take().unwrap();
|
||||
let method = match method {
|
||||
Method::GET | Method::HEAD => method,
|
||||
_ => Method::GET,
|
||||
};
|
||||
|
||||
// take ownership of states that could be reused
|
||||
let addr = addr.take();
|
||||
let connector = connector.take();
|
||||
let mut max_redirect_times = *max_redirect_times;
|
||||
// take ownership of states that could be reused
|
||||
let addr = addr.take();
|
||||
let connector = connector.take();
|
||||
let mut max_redirect_times = *max_redirect_times;
|
||||
|
||||
// use a new request head.
|
||||
let mut head = RequestHead::default();
|
||||
head.uri = uri.clone();
|
||||
head.method = method.clone();
|
||||
// use a new request head.
|
||||
let mut head = RequestHead::default();
|
||||
head.uri = uri.clone();
|
||||
head.method = method.clone();
|
||||
|
||||
let head = RequestHeadType::Owned(head);
|
||||
let head = RequestHeadType::Owned(head);
|
||||
|
||||
max_redirect_times -= 1;
|
||||
max_redirect_times -= 1;
|
||||
|
||||
let fut = connector
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
// remove body
|
||||
.call(ConnectRequest::Client(head, Body::None, addr));
|
||||
let fut = connector
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
// remove body
|
||||
.call(ConnectRequest::Client(head, Body::None, addr));
|
||||
|
||||
self.as_mut().set(RedirectServiceFuture::Client {
|
||||
fut,
|
||||
max_redirect_times,
|
||||
uri: Some(uri),
|
||||
method: Some(method),
|
||||
// body is dropped on 301,302,303
|
||||
body: None,
|
||||
addr,
|
||||
connector,
|
||||
});
|
||||
self.as_mut().set(RedirectServiceFuture::Client {
|
||||
fut,
|
||||
max_redirect_times,
|
||||
uri: Some(uri),
|
||||
method: Some(method),
|
||||
// body is dropped on 301,302,303
|
||||
body: None,
|
||||
addr,
|
||||
connector,
|
||||
});
|
||||
|
||||
self.poll(cx)
|
||||
}
|
||||
StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT
|
||||
if *max_redirect_times > 0 =>
|
||||
{
|
||||
let org_uri = uri.take().unwrap();
|
||||
// rebuild uri from the location header value.
|
||||
let uri = rebuild_uri(&res, org_uri)?;
|
||||
self.poll(cx)
|
||||
}
|
||||
StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT
|
||||
if *max_redirect_times > 0 =>
|
||||
{
|
||||
let org_uri = uri.take().unwrap();
|
||||
// rebuild uri from the location header value.
|
||||
let uri = rebuild_uri(&res, org_uri)?;
|
||||
|
||||
// try to reuse body
|
||||
let body = body.take();
|
||||
let body_new = match body {
|
||||
Some(ref bytes) => Body::Bytes(bytes.clone()),
|
||||
// TODO: should this be Body::Empty or Body::None.
|
||||
_ => Body::Empty,
|
||||
};
|
||||
// try to reuse body
|
||||
let body = body.take();
|
||||
let body_new = match body {
|
||||
Some(ref bytes) => Body::Bytes(bytes.clone()),
|
||||
// TODO: should this be Body::Empty or Body::None.
|
||||
_ => Body::Empty,
|
||||
};
|
||||
|
||||
let addr = addr.take();
|
||||
let method = method.take();
|
||||
let connector = connector.take();
|
||||
let mut max_redirect_times = *max_redirect_times;
|
||||
let addr = addr.take();
|
||||
let method = method.take().unwrap();
|
||||
let connector = connector.take();
|
||||
let mut max_redirect_times = *max_redirect_times;
|
||||
|
||||
// use a new request head.
|
||||
let mut head = RequestHead::default();
|
||||
head.uri = uri.clone();
|
||||
// use a new request head.
|
||||
let mut head = RequestHead::default();
|
||||
head.uri = uri.clone();
|
||||
head.method = method.clone();
|
||||
|
||||
let head = RequestHeadType::Owned(head);
|
||||
let head = RequestHeadType::Owned(head);
|
||||
|
||||
max_redirect_times -= 1;
|
||||
max_redirect_times -= 1;
|
||||
|
||||
let fut = connector
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.call(ConnectRequest::Client(head, body_new, addr));
|
||||
let fut = connector
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.call(ConnectRequest::Client(head, body_new, addr));
|
||||
|
||||
self.as_mut().set(RedirectServiceFuture::Client {
|
||||
fut,
|
||||
max_redirect_times,
|
||||
uri: Some(uri),
|
||||
method,
|
||||
body,
|
||||
addr,
|
||||
connector,
|
||||
});
|
||||
self.as_mut().set(RedirectServiceFuture::Client {
|
||||
fut,
|
||||
max_redirect_times,
|
||||
uri: Some(uri),
|
||||
method: Some(method),
|
||||
body,
|
||||
addr,
|
||||
connector,
|
||||
});
|
||||
|
||||
self.poll(cx)
|
||||
}
|
||||
_ => Poll::Ready(Ok(ConnectResponse::Client(res))),
|
||||
},
|
||||
_ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"),
|
||||
}
|
||||
}
|
||||
self.poll(cx)
|
||||
}
|
||||
_ => Poll::Ready(Ok(ConnectResponse::Client(res))),
|
||||
},
|
||||
_ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rebuild_uri(res: &ClientResponse, org_uri: Uri) -> Result<Uri, SendRequestError> {
|
||||
use std::convert::TryFrom;
|
||||
|
||||
let uri = res
|
||||
.headers()
|
||||
.get(header::LOCATION)
|
||||
|
@ -295,7 +293,7 @@ mod tests {
|
|||
async fn test_basic_redirect() {
|
||||
let client = ClientBuilder::new()
|
||||
.connector(crate::Connector::new())
|
||||
.wrap(RedirectMiddleware::new().max_redirect_times(10))
|
||||
.wrap(Redirect::new().max_redirect_times(10))
|
||||
.finish();
|
||||
|
||||
let srv = start(|| {
|
||||
|
@ -320,7 +318,7 @@ mod tests {
|
|||
#[actix_rt::test]
|
||||
async fn test_redirect_limit() {
|
||||
let client = ClientBuilder::new()
|
||||
.wrap(RedirectMiddleware::new().max_redirect_times(1))
|
||||
.wrap(Redirect::new().max_redirect_times(1))
|
||||
.connector(crate::Connector::new())
|
||||
.finish();
|
||||
|
||||
|
|
|
@ -492,9 +492,7 @@ mod tests {
|
|||
JsonPayloadError::Payload(PayloadError::Overflow) => {
|
||||
matches!(other, JsonPayloadError::Payload(PayloadError::Overflow))
|
||||
}
|
||||
JsonPayloadError::ContentType => {
|
||||
matches!(other, JsonPayloadError::ContentType)
|
||||
}
|
||||
JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue