add body size check when enabling expect flag

This commit is contained in:
fakeshadow 2021-03-08 00:24:41 +08:00
parent 375eadeb46
commit af45f12c52
2 changed files with 53 additions and 11 deletions

View File

@ -68,16 +68,35 @@ where
io: Some(io),
};
let is_expect = head.as_ref().headers.contains_key(EXPECT);
// create Framed and send request
// create Framed and prepare sending request
let mut framed = Framed::new(io, h1::ClientCodec::default());
// Check EXPECT header and enable expect handle flag accordingly.
//
// RFC: https://tools.ietf.org/html/rfc7231#section-5.1.1
let is_expect = if head.as_ref().headers.contains_key(EXPECT) {
match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => {
let pin_framed = Pin::new(&mut framed);
let force_close = !pin_framed.codec_ref().keepalive();
release_connection(pin_framed, force_close);
// TODO: use a new variant or a new type better describing error violate
// `Requirements for clients` session of above RFC
return Err(SendRequestError::Connect(ConnectError::Disconnected));
}
_ => true,
}
} else {
false
};
framed.send((head, body.size()).into()).await?;
// make Pin<&mut Framed> for polling it on stack.
let mut pin_framed = Pin::new(&mut framed);
// special handler for EXPECT request.
// special handle for EXPECT request.
let (do_send, mut res_head) = if is_expect {
let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx))
.await

View File

@ -4,7 +4,10 @@ use actix_http::{
use actix_http_test::test_server;
use actix_service::ServiceFactoryExt;
use bytes::Bytes;
use futures_util::future::{self, ok};
use futures_util::{
future::{self, ok},
StreamExt,
};
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
@ -99,26 +102,46 @@ async fn test_h1_expect() {
if req.headers().contains_key("AUTH") {
Ok(req)
} else {
Err(error::ErrorBadRequest("bad request"))
Err(error::ErrorExpectationFailed("expect failed"))
}
})
.h1(|_| async { Ok::<_, ()>(Response::Ok().finish()) })
.h1(|req: Request| async move {
let (_, mut body) = req.into_parts();
let mut buf = Vec::new();
while let Some(Ok(chunk)) = body.next().await {
buf.extend_from_slice(&chunk);
}
let str = std::str::from_utf8(&buf).unwrap();
assert_eq!(str, "expect body");
Ok::<_, ()>(Response::Ok().finish())
})
.tcp()
})
.await;
// test expect without payload.
let request = srv
.request(http::Method::GET, srv.url("/"))
.insert_header(("Expect", "100-continue"));
let response = request.send().await.unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
let response = request.send().await;
assert!(response.is_err());
// test expect would fail to continue
let request = srv
.request(http::Method::GET, srv.url("/"))
.insert_header(("Expect", "100-continue"));
let response = request.send_body("expect body").await.unwrap();
assert_eq!(response.status(), StatusCode::EXPECTATION_FAILED);
// test exepct would continue
let request = srv
.request(http::Method::GET, srv.url("/"))
.insert_header(("Expect", "100-continue"))
.insert_header(("AUTH", "996"));
let response = request.send().await.unwrap();
let response = request.send_body("expect body").await.unwrap();
assert!(response.status().is_success());
}