add body size check when enabling expect flag

This commit is contained in:
fakeshadow 2021-03-08 00:24:41 +08:00
parent 375eadeb46
commit af45f12c52
2 changed files with 53 additions and 11 deletions

View File

@ -68,16 +68,35 @@ where
io: Some(io), io: Some(io),
}; };
let is_expect = head.as_ref().headers.contains_key(EXPECT); // create Framed and prepare sending request
// create Framed and send request
let mut framed = Framed::new(io, h1::ClientCodec::default()); let mut framed = Framed::new(io, h1::ClientCodec::default());
// Check EXPECT header and enable expect handle flag accordingly.
//
// RFC: https://tools.ietf.org/html/rfc7231#section-5.1.1
let is_expect = if head.as_ref().headers.contains_key(EXPECT) {
match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => {
let pin_framed = Pin::new(&mut framed);
let force_close = !pin_framed.codec_ref().keepalive();
release_connection(pin_framed, force_close);
// TODO: use a new variant or a new type better describing error violate
// `Requirements for clients` session of above RFC
return Err(SendRequestError::Connect(ConnectError::Disconnected));
}
_ => true,
}
} else {
false
};
framed.send((head, body.size()).into()).await?; framed.send((head, body.size()).into()).await?;
// make Pin<&mut Framed> for polling it on stack.
let mut pin_framed = Pin::new(&mut framed); let mut pin_framed = Pin::new(&mut framed);
// special handler for EXPECT request. // special handle for EXPECT request.
let (do_send, mut res_head) = if is_expect { let (do_send, mut res_head) = if is_expect {
let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx)) let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx))
.await .await

View File

@ -4,7 +4,10 @@ use actix_http::{
use actix_http_test::test_server; use actix_http_test::test_server;
use actix_service::ServiceFactoryExt; use actix_service::ServiceFactoryExt;
use bytes::Bytes; use bytes::Bytes;
use futures_util::future::{self, ok}; use futures_util::{
future::{self, ok},
StreamExt,
};
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
@ -99,26 +102,46 @@ async fn test_h1_expect() {
if req.headers().contains_key("AUTH") { if req.headers().contains_key("AUTH") {
Ok(req) Ok(req)
} else { } else {
Err(error::ErrorBadRequest("bad request")) Err(error::ErrorExpectationFailed("expect failed"))
} }
}) })
.h1(|_| async { Ok::<_, ()>(Response::Ok().finish()) }) .h1(|req: Request| async move {
let (_, mut body) = req.into_parts();
let mut buf = Vec::new();
while let Some(Ok(chunk)) = body.next().await {
buf.extend_from_slice(&chunk);
}
let str = std::str::from_utf8(&buf).unwrap();
assert_eq!(str, "expect body");
Ok::<_, ()>(Response::Ok().finish())
})
.tcp() .tcp()
}) })
.await; .await;
// test expect without payload.
let request = srv let request = srv
.request(http::Method::GET, srv.url("/")) .request(http::Method::GET, srv.url("/"))
.insert_header(("Expect", "100-continue")); .insert_header(("Expect", "100-continue"));
let response = request.send().await.unwrap(); let response = request.send().await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST); assert!(response.is_err());
// test expect would fail to continue
let request = srv
.request(http::Method::GET, srv.url("/"))
.insert_header(("Expect", "100-continue"));
let response = request.send_body("expect body").await.unwrap();
assert_eq!(response.status(), StatusCode::EXPECTATION_FAILED);
// test exepct would continue
let request = srv let request = srv
.request(http::Method::GET, srv.url("/")) .request(http::Method::GET, srv.url("/"))
.insert_header(("Expect", "100-continue")) .insert_header(("Expect", "100-continue"))
.insert_header(("AUTH", "996")); .insert_header(("AUTH", "996"));
let response = request.send().await.unwrap(); let response = request.send_body("expect body").await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
} }