diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs index 0f9215520..d2db18cec 100644 --- a/actix-http/src/client/h1proto.rs +++ b/actix-http/src/client/h1proto.rs @@ -68,16 +68,35 @@ where io: Some(io), }; - let is_expect = head.as_ref().headers.contains_key(EXPECT); - - // create Framed and send request + // create Framed and prepare sending request let mut framed = Framed::new(io, h1::ClientCodec::default()); + + // Check EXPECT header and enable expect handle flag accordingly. + // + // RFC: https://tools.ietf.org/html/rfc7231#section-5.1.1 + let is_expect = if head.as_ref().headers.contains_key(EXPECT) { + match body.size() { + BodySize::None | BodySize::Empty | BodySize::Sized(0) => { + let pin_framed = Pin::new(&mut framed); + + let force_close = !pin_framed.codec_ref().keepalive(); + release_connection(pin_framed, force_close); + + // TODO: use a new variant or a new type better describing error violate + // `Requirements for clients` session of above RFC + return Err(SendRequestError::Connect(ConnectError::Disconnected)); + } + _ => true, + } + } else { + false + }; + framed.send((head, body.size()).into()).await?; - // make Pin<&mut Framed> for polling it on stack. let mut pin_framed = Pin::new(&mut framed); - // special handler for EXPECT request. + // special handle for EXPECT request. let (do_send, mut res_head) = if is_expect { let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx)) .await diff --git a/actix-http/tests/test_client.rs b/actix-http/tests/test_client.rs index 38fd03814..a50f2404d 100644 --- a/actix-http/tests/test_client.rs +++ b/actix-http/tests/test_client.rs @@ -4,7 +4,10 @@ use actix_http::{ use actix_http_test::test_server; use actix_service::ServiceFactoryExt; use bytes::Bytes; -use futures_util::future::{self, ok}; +use futures_util::{ + future::{self, ok}, + StreamExt, +}; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ @@ -99,26 +102,46 @@ async fn test_h1_expect() { if req.headers().contains_key("AUTH") { Ok(req) } else { - Err(error::ErrorBadRequest("bad request")) + Err(error::ErrorExpectationFailed("expect failed")) } }) - .h1(|_| async { Ok::<_, ()>(Response::Ok().finish()) }) + .h1(|req: Request| async move { + let (_, mut body) = req.into_parts(); + let mut buf = Vec::new(); + while let Some(Ok(chunk)) = body.next().await { + buf.extend_from_slice(&chunk); + } + let str = std::str::from_utf8(&buf).unwrap(); + assert_eq!(str, "expect body"); + + Ok::<_, ()>(Response::Ok().finish()) + }) .tcp() }) .await; + // test expect without payload. let request = srv .request(http::Method::GET, srv.url("/")) .insert_header(("Expect", "100-continue")); - let response = request.send().await.unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let response = request.send().await; + assert!(response.is_err()); + // test expect would fail to continue + let request = srv + .request(http::Method::GET, srv.url("/")) + .insert_header(("Expect", "100-continue")); + + let response = request.send_body("expect body").await.unwrap(); + assert_eq!(response.status(), StatusCode::EXPECTATION_FAILED); + + // test exepct would continue let request = srv .request(http::Method::GET, srv.url("/")) .insert_header(("Expect", "100-continue")) .insert_header(("AUTH", "996")); - let response = request.send().await.unwrap(); + let response = request.send_body("expect body").await.unwrap(); assert!(response.status().is_success()); }