mirror of https://github.com/fafhrd91/actix-web
add partial body reuse
This commit is contained in:
parent
68d079ba36
commit
6437474d3e
|
@ -9,15 +9,17 @@ use std::{
|
||||||
use actix_http::{
|
use actix_http::{
|
||||||
body::Body,
|
body::Body,
|
||||||
client::{InvalidUrl, SendRequestError},
|
client::{InvalidUrl, SendRequestError},
|
||||||
http::{header, StatusCode, Uri},
|
http::{header, Method, StatusCode, Uri},
|
||||||
RequestHead, RequestHeadType,
|
RequestHead, RequestHeadType,
|
||||||
};
|
};
|
||||||
use actix_service::Service;
|
use actix_service::Service;
|
||||||
|
use bytes::Bytes;
|
||||||
use futures_core::ready;
|
use futures_core::ready;
|
||||||
|
|
||||||
use super::Transform;
|
use super::Transform;
|
||||||
|
|
||||||
use crate::connect::{ConnectRequest, ConnectResponse};
|
use crate::connect::{ConnectRequest, ConnectResponse};
|
||||||
|
use crate::ClientResponse;
|
||||||
|
|
||||||
pub struct RedirectMiddleware {
|
pub struct RedirectMiddleware {
|
||||||
max_redirect_times: u8,
|
max_redirect_times: u8,
|
||||||
|
@ -81,10 +83,17 @@ where
|
||||||
let connector = self.connector.clone();
|
let connector = self.connector.clone();
|
||||||
let max_redirect_times = self.max_redirect_times;
|
let max_redirect_times = self.max_redirect_times;
|
||||||
|
|
||||||
// backup the uri for reuse schema and authority.
|
// backup the uri and method for reuse schema and authority.
|
||||||
let uri = match head {
|
let (uri, method) = match head {
|
||||||
RequestHeadType::Owned(ref head) => head.uri.clone(),
|
RequestHeadType::Owned(ref head) => (head.uri.clone(), head.method.clone()),
|
||||||
RequestHeadType::Rc(ref head, ..) => head.uri.clone(),
|
RequestHeadType::Rc(ref head, ..) => {
|
||||||
|
(head.uri.clone(), head.method.clone())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let body_opt = match body {
|
||||||
|
Body::Bytes(ref b) => Some(b.clone()),
|
||||||
|
_ => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let fut = connector.call(ConnectRequest::Client(head, body, addr));
|
let fut = connector.call(ConnectRequest::Client(head, body, addr));
|
||||||
|
@ -93,6 +102,8 @@ where
|
||||||
fut,
|
fut,
|
||||||
max_redirect_times,
|
max_redirect_times,
|
||||||
uri: Some(uri),
|
uri: Some(uri),
|
||||||
|
method: Some(method),
|
||||||
|
body: body_opt,
|
||||||
addr,
|
addr,
|
||||||
connector: Some(connector),
|
connector: Some(connector),
|
||||||
}
|
}
|
||||||
|
@ -114,6 +125,8 @@ pin_project_lite::pin_project! {
|
||||||
fut: S::Future,
|
fut: S::Future,
|
||||||
max_redirect_times: u8,
|
max_redirect_times: u8,
|
||||||
uri: Option<Uri>,
|
uri: Option<Uri>,
|
||||||
|
method: Option<Method>,
|
||||||
|
body: Option<Bytes>,
|
||||||
addr: Option<SocketAddr>,
|
addr: Option<SocketAddr>,
|
||||||
connector: Option<Rc<S>>
|
connector: Option<Rc<S>>
|
||||||
}
|
}
|
||||||
|
@ -133,6 +146,8 @@ where
|
||||||
fut,
|
fut,
|
||||||
max_redirect_times,
|
max_redirect_times,
|
||||||
uri,
|
uri,
|
||||||
|
method,
|
||||||
|
body,
|
||||||
addr,
|
addr,
|
||||||
connector,
|
connector,
|
||||||
} => {
|
} => {
|
||||||
|
@ -141,25 +156,20 @@ where
|
||||||
StatusCode::MOVED_PERMANENTLY
|
StatusCode::MOVED_PERMANENTLY
|
||||||
| StatusCode::FOUND
|
| StatusCode::FOUND
|
||||||
| StatusCode::SEE_OTHER
|
| StatusCode::SEE_OTHER
|
||||||
| StatusCode::TEMPORARY_REDIRECT
|
|
||||||
| StatusCode::PERMANENT_REDIRECT
|
|
||||||
if *max_redirect_times > 0 =>
|
if *max_redirect_times > 0 =>
|
||||||
{
|
{
|
||||||
let uri = uri.take().unwrap();
|
let org_uri = uri.take().unwrap();
|
||||||
|
|
||||||
// rebuild uri from the location header value.
|
// rebuild uri from the location header value.
|
||||||
let uri = res
|
let uri = rebuild_uri(&res, org_uri)?;
|
||||||
.headers()
|
|
||||||
.get(header::LOCATION)
|
|
||||||
.map(|value| {
|
|
||||||
Uri::builder()
|
|
||||||
.scheme(uri.scheme().cloned().unwrap())
|
|
||||||
.authority(uri.authority().cloned().unwrap())
|
|
||||||
.path_and_query(value.as_bytes())
|
|
||||||
})
|
|
||||||
.ok_or(SendRequestError::Url(InvalidUrl::MissingScheme))?
|
|
||||||
.build()?;
|
|
||||||
|
|
||||||
|
// reset method
|
||||||
|
let method = method.take().unwrap();
|
||||||
|
let method = match method {
|
||||||
|
Method::GET | Method::HEAD => method,
|
||||||
|
_ => Method::GET,
|
||||||
|
};
|
||||||
|
|
||||||
|
// take ownership of states that could be reused
|
||||||
let addr = addr.take();
|
let addr = addr.take();
|
||||||
let connector = connector.take();
|
let connector = connector.take();
|
||||||
let mut max_redirect_times = *max_redirect_times;
|
let mut max_redirect_times = *max_redirect_times;
|
||||||
|
@ -167,22 +177,70 @@ where
|
||||||
// use a new request head.
|
// use a new request head.
|
||||||
let mut head = RequestHead::default();
|
let mut head = RequestHead::default();
|
||||||
head.uri = uri.clone();
|
head.uri = uri.clone();
|
||||||
let head = RequestHeadType::Owned(head);
|
head.method = method.clone();
|
||||||
|
|
||||||
// throw body
|
let head = RequestHeadType::Owned(head);
|
||||||
let body = Body::None;
|
|
||||||
|
|
||||||
max_redirect_times -= 1;
|
max_redirect_times -= 1;
|
||||||
|
|
||||||
let fut = connector
|
let fut = connector
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.call(ConnectRequest::Client(head, body, addr));
|
// remove body
|
||||||
|
.call(ConnectRequest::Client(head, Body::None, addr));
|
||||||
|
|
||||||
self.as_mut().set(RedirectServiceFuture::Client {
|
self.as_mut().set(RedirectServiceFuture::Client {
|
||||||
fut,
|
fut,
|
||||||
max_redirect_times,
|
max_redirect_times,
|
||||||
uri: Some(uri),
|
uri: Some(uri),
|
||||||
|
method: Some(method),
|
||||||
|
// body is dropped on 301,302,303
|
||||||
|
body: None,
|
||||||
|
addr,
|
||||||
|
connector,
|
||||||
|
});
|
||||||
|
|
||||||
|
self.poll(cx)
|
||||||
|
}
|
||||||
|
StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT
|
||||||
|
if *max_redirect_times > 0 =>
|
||||||
|
{
|
||||||
|
let org_uri = uri.take().unwrap();
|
||||||
|
// rebuild uri from the location header value.
|
||||||
|
let uri = rebuild_uri(&res, org_uri)?;
|
||||||
|
|
||||||
|
// try to reuse body
|
||||||
|
let body = body.take();
|
||||||
|
let body_new = match body {
|
||||||
|
Some(ref bytes) => Body::Bytes(bytes.clone()),
|
||||||
|
// TODO: should this be Body::Empty or Body::None.
|
||||||
|
_ => Body::Empty,
|
||||||
|
};
|
||||||
|
|
||||||
|
let addr = addr.take();
|
||||||
|
let method = method.take();
|
||||||
|
let connector = connector.take();
|
||||||
|
let mut max_redirect_times = *max_redirect_times;
|
||||||
|
|
||||||
|
// use a new request head.
|
||||||
|
let mut head = RequestHead::default();
|
||||||
|
head.uri = uri.clone();
|
||||||
|
|
||||||
|
let head = RequestHeadType::Owned(head);
|
||||||
|
|
||||||
|
max_redirect_times -= 1;
|
||||||
|
|
||||||
|
let fut = connector
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.call(ConnectRequest::Client(head, body_new, addr));
|
||||||
|
|
||||||
|
self.as_mut().set(RedirectServiceFuture::Client {
|
||||||
|
fut,
|
||||||
|
max_redirect_times,
|
||||||
|
uri: Some(uri),
|
||||||
|
method,
|
||||||
|
body,
|
||||||
addr,
|
addr,
|
||||||
connector,
|
connector,
|
||||||
});
|
});
|
||||||
|
@ -198,6 +256,33 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn rebuild_uri(res: &ClientResponse, org_uri: Uri) -> Result<Uri, SendRequestError> {
|
||||||
|
use std::convert::TryFrom;
|
||||||
|
|
||||||
|
let uri = res
|
||||||
|
.headers()
|
||||||
|
.get(header::LOCATION)
|
||||||
|
.map(|value| {
|
||||||
|
// try to parse the location to a full uri
|
||||||
|
let uri = Uri::try_from(value.as_bytes())
|
||||||
|
.map_err(|e| SendRequestError::Url(InvalidUrl::HttpError(e.into())))?;
|
||||||
|
if uri.scheme().is_none() || uri.authority().is_none() {
|
||||||
|
let uri = Uri::builder()
|
||||||
|
.scheme(org_uri.scheme().cloned().unwrap())
|
||||||
|
.authority(org_uri.authority().cloned().unwrap())
|
||||||
|
.path_and_query(value.as_bytes())
|
||||||
|
.build()?;
|
||||||
|
Ok::<_, SendRequestError>(uri)
|
||||||
|
} else {
|
||||||
|
Ok(uri)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
// TODO: this error type is wrong.
|
||||||
|
.ok_or(SendRequestError::Url(InvalidUrl::MissingScheme))??;
|
||||||
|
|
||||||
|
Ok(uri)
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use actix_web::{test::start, web, App, Error, HttpResponse};
|
use actix_web::{test::start, web, App, Error, HttpResponse};
|
||||||
|
|
Loading…
Reference in New Issue