add partial body reuse

This commit is contained in:
fakeshadow 2021-02-19 02:39:36 -08:00
parent 68d079ba36
commit 6437474d3e
1 changed files with 109 additions and 24 deletions

View File

@ -9,15 +9,17 @@ use std::{
use actix_http::{ use actix_http::{
body::Body, body::Body,
client::{InvalidUrl, SendRequestError}, client::{InvalidUrl, SendRequestError},
http::{header, StatusCode, Uri}, http::{header, Method, StatusCode, Uri},
RequestHead, RequestHeadType, RequestHead, RequestHeadType,
}; };
use actix_service::Service; use actix_service::Service;
use bytes::Bytes;
use futures_core::ready; use futures_core::ready;
use super::Transform; use super::Transform;
use crate::connect::{ConnectRequest, ConnectResponse}; use crate::connect::{ConnectRequest, ConnectResponse};
use crate::ClientResponse;
pub struct RedirectMiddleware { pub struct RedirectMiddleware {
max_redirect_times: u8, max_redirect_times: u8,
@ -81,10 +83,17 @@ where
let connector = self.connector.clone(); let connector = self.connector.clone();
let max_redirect_times = self.max_redirect_times; let max_redirect_times = self.max_redirect_times;
// backup the uri for reuse schema and authority. // backup the uri and method for reuse schema and authority.
let uri = match head { let (uri, method) = match head {
RequestHeadType::Owned(ref head) => head.uri.clone(), RequestHeadType::Owned(ref head) => (head.uri.clone(), head.method.clone()),
RequestHeadType::Rc(ref head, ..) => head.uri.clone(), RequestHeadType::Rc(ref head, ..) => {
(head.uri.clone(), head.method.clone())
}
};
let body_opt = match body {
Body::Bytes(ref b) => Some(b.clone()),
_ => None,
}; };
let fut = connector.call(ConnectRequest::Client(head, body, addr)); let fut = connector.call(ConnectRequest::Client(head, body, addr));
@ -93,6 +102,8 @@ where
fut, fut,
max_redirect_times, max_redirect_times,
uri: Some(uri), uri: Some(uri),
method: Some(method),
body: body_opt,
addr, addr,
connector: Some(connector), connector: Some(connector),
} }
@ -114,6 +125,8 @@ pin_project_lite::pin_project! {
fut: S::Future, fut: S::Future,
max_redirect_times: u8, max_redirect_times: u8,
uri: Option<Uri>, uri: Option<Uri>,
method: Option<Method>,
body: Option<Bytes>,
addr: Option<SocketAddr>, addr: Option<SocketAddr>,
connector: Option<Rc<S>> connector: Option<Rc<S>>
} }
@ -133,6 +146,8 @@ where
fut, fut,
max_redirect_times, max_redirect_times,
uri, uri,
method,
body,
addr, addr,
connector, connector,
} => { } => {
@ -141,25 +156,20 @@ where
StatusCode::MOVED_PERMANENTLY StatusCode::MOVED_PERMANENTLY
| StatusCode::FOUND | StatusCode::FOUND
| StatusCode::SEE_OTHER | StatusCode::SEE_OTHER
| StatusCode::TEMPORARY_REDIRECT
| StatusCode::PERMANENT_REDIRECT
if *max_redirect_times > 0 => if *max_redirect_times > 0 =>
{ {
let uri = uri.take().unwrap(); let org_uri = uri.take().unwrap();
// rebuild uri from the location header value. // rebuild uri from the location header value.
let uri = res let uri = rebuild_uri(&res, org_uri)?;
.headers()
.get(header::LOCATION)
.map(|value| {
Uri::builder()
.scheme(uri.scheme().cloned().unwrap())
.authority(uri.authority().cloned().unwrap())
.path_and_query(value.as_bytes())
})
.ok_or(SendRequestError::Url(InvalidUrl::MissingScheme))?
.build()?;
// reset method
let method = method.take().unwrap();
let method = match method {
Method::GET | Method::HEAD => method,
_ => Method::GET,
};
// take ownership of states that could be reused
let addr = addr.take(); let addr = addr.take();
let connector = connector.take(); let connector = connector.take();
let mut max_redirect_times = *max_redirect_times; let mut max_redirect_times = *max_redirect_times;
@ -167,22 +177,70 @@ where
// use a new request head. // use a new request head.
let mut head = RequestHead::default(); let mut head = RequestHead::default();
head.uri = uri.clone(); head.uri = uri.clone();
let head = RequestHeadType::Owned(head); head.method = method.clone();
// throw body let head = RequestHeadType::Owned(head);
let body = Body::None;
max_redirect_times -= 1; max_redirect_times -= 1;
let fut = connector let fut = connector
.as_ref() .as_ref()
.unwrap() .unwrap()
.call(ConnectRequest::Client(head, body, addr)); // remove body
.call(ConnectRequest::Client(head, Body::None, addr));
self.as_mut().set(RedirectServiceFuture::Client { self.as_mut().set(RedirectServiceFuture::Client {
fut, fut,
max_redirect_times, max_redirect_times,
uri: Some(uri), uri: Some(uri),
method: Some(method),
// body is dropped on 301,302,303
body: None,
addr,
connector,
});
self.poll(cx)
}
StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT
if *max_redirect_times > 0 =>
{
let org_uri = uri.take().unwrap();
// rebuild uri from the location header value.
let uri = rebuild_uri(&res, org_uri)?;
// try to reuse body
let body = body.take();
let body_new = match body {
Some(ref bytes) => Body::Bytes(bytes.clone()),
// TODO: should this be Body::Empty or Body::None.
_ => Body::Empty,
};
let addr = addr.take();
let method = method.take();
let connector = connector.take();
let mut max_redirect_times = *max_redirect_times;
// use a new request head.
let mut head = RequestHead::default();
head.uri = uri.clone();
let head = RequestHeadType::Owned(head);
max_redirect_times -= 1;
let fut = connector
.as_ref()
.unwrap()
.call(ConnectRequest::Client(head, body_new, addr));
self.as_mut().set(RedirectServiceFuture::Client {
fut,
max_redirect_times,
uri: Some(uri),
method,
body,
addr, addr,
connector, connector,
}); });
@ -198,6 +256,33 @@ where
} }
} }
fn rebuild_uri(res: &ClientResponse, org_uri: Uri) -> Result<Uri, SendRequestError> {
use std::convert::TryFrom;
let uri = res
.headers()
.get(header::LOCATION)
.map(|value| {
// try to parse the location to a full uri
let uri = Uri::try_from(value.as_bytes())
.map_err(|e| SendRequestError::Url(InvalidUrl::HttpError(e.into())))?;
if uri.scheme().is_none() || uri.authority().is_none() {
let uri = Uri::builder()
.scheme(org_uri.scheme().cloned().unwrap())
.authority(org_uri.authority().cloned().unwrap())
.path_and_query(value.as_bytes())
.build()?;
Ok::<_, SendRequestError>(uri)
} else {
Ok(uri)
}
})
// TODO: this error type is wrong.
.ok_or(SendRequestError::Url(InvalidUrl::MissingScheme))??;
Ok(uri)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_web::{test::start, web, App, Error, HttpResponse}; use actix_web::{test::start, web, App, Error, HttpResponse};