diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 2ed78cfca..73e3099b1 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -6,6 +6,7 @@ use std::{ pin::Pin, rc::Rc, task::{Context, Poll}, + time::{Duration, Instant}, }; use actix_codec::{Framed, FramedParts}; @@ -31,11 +32,13 @@ use crate::{ config::ServiceConfig, error::{DispatchError, ParseError, PayloadError}, service::HttpFlow, - Error, Extensions, HttpMessage, OnConnectData, Request, Response, StatusCode, + ConnectionType, Error, Extensions, HttpMessage, OnConnectData, Request, Response, + StatusCode, }; const LW_BUFFER_SIZE: usize = 1024; const HW_BUFFER_SIZE: usize = 1024 * 8; +const LINGER_TIMEOUT: Duration = Duration::from_secs(1); const MAX_PIPELINED_MESSAGES: usize = 16; bitflags! { @@ -58,6 +61,9 @@ bitflags! { /// Set if write-half is disconnected. const WRITE_DISCONNECT = 0b0010_0000; + + /// Set while lingering on a non-reusable connection after sending a response. + const LINGER = 0b0100_0000; } } @@ -361,6 +367,50 @@ where io.poll_flush(cx) } + fn enter_linger(mut self: Pin<&mut Self>) { + let this = self.as_mut().project(); + this.flags.remove(Flags::KEEP_ALIVE); + this.flags.insert(Flags::LINGER | Flags::FINISHED); + } + + fn ensure_linger_timer(mut self: Pin<&mut Self>, cx: &mut Context<'_>) { + let this = self.as_mut().project(); + if !matches!(this.shutdown_timer, TimerState::Active { .. }) { + let deadline = Instant::now() + LINGER_TIMEOUT; + + this.shutdown_timer + .set_and_init(cx, sleep_until(deadline.into()), line!()); + } + } + + fn poll_linger(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result, DispatchError> { + if self.as_mut().poll_flush(cx)?.is_pending() { + return Ok(Poll::Pending); + } + self.as_mut().ensure_linger_timer(cx); + + loop { + let should_disconnect = self.as_mut().read_available(cx)?; + let this = self.as_mut().project(); + let mut progressed = false; + + if !this.read_buf.is_empty() { + this.read_buf.clear(); + progressed = true; + } + + if should_disconnect { + this.flags.insert(Flags::READ_DISCONNECT | Flags::SHUTDOWN); + this.flags.remove(Flags::LINGER); + return Ok(Poll::Ready(())); + } + + if !progressed { + return Ok(Poll::Pending); + } + } + } + fn send_response_inner( self: Pin<&mut Self>, res: Response<()>, @@ -385,54 +435,68 @@ where fn send_response( mut self: Pin<&mut Self>, - res: Response<()>, + mut res: Response<()>, body: B, ) -> Result<(), DispatchError> { - let size = self.as_mut().send_response_inner(res, &body)?; - let mut this = self.project(); - this.state.set(match size { - BodySize::None | BodySize::Sized(0) => { - let payload_unfinished = this.payload.is_some(); - let drain_payload = this.payload.as_ref().is_some_and(|pl| pl.is_dropped()) - && *this.payload_drainable; + let linger = { + let this = self.as_mut().project(); + should_linger(this.payload.as_ref(), *this.payload_drainable) + }; - if payload_unfinished && !drain_payload { - this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED); + if linger { + res.head_mut().set_connection_type(ConnectionType::Close); + } + + let size = self.as_mut().send_response_inner(res, &body)?; + + match size { + BodySize::None | BodySize::Sized(0) => { + if linger { + self.as_mut().enter_linger(); } else { - this.flags.insert(Flags::FINISHED); + self.as_mut().project().flags.insert(Flags::FINISHED); } - State::None + self.as_mut().project().state.set(State::None); } - _ => State::SendPayload { body }, - }); + _ => self.as_mut().project().state.set(State::SendPayload { body }), + } Ok(()) } fn send_error_response( mut self: Pin<&mut Self>, - res: Response<()>, + mut res: Response<()>, body: BoxBody, ) -> Result<(), DispatchError> { - let size = self.as_mut().send_response_inner(res, &body)?; - let mut this = self.project(); - this.state.set(match size { - BodySize::None | BodySize::Sized(0) => { - let payload_unfinished = this.payload.is_some(); - let drain_payload = this.payload.as_ref().is_some_and(|pl| pl.is_dropped()) - && *this.payload_drainable; + let linger = { + let this = self.as_mut().project(); + should_linger(this.payload.as_ref(), *this.payload_drainable) + }; - if payload_unfinished && !drain_payload { - this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED); + if linger { + res.head_mut().set_connection_type(ConnectionType::Close); + } + + let size = self.as_mut().send_response_inner(res, &body)?; + + match size { + BodySize::None | BodySize::Sized(0) => { + if linger { + self.as_mut().enter_linger(); } else { - this.flags.insert(Flags::FINISHED); + self.as_mut().project().flags.insert(Flags::FINISHED); } - State::None + self.as_mut().project().state.set(State::None); } - _ => State::SendErrorPayload { body }, - }); + _ => self + .as_mut() + .project() + .state + .set(State::SendErrorPayload { body }), + } Ok(()) } @@ -534,18 +598,18 @@ where // this.payload was the payload for the request we just finished // responding to. We can check to see if we finished reading it // yet, and if not, shutdown the connection. - let payload_unfinished = this.payload.is_some(); - let drain_payload = - this.payload.as_ref().is_some_and(|pl| pl.is_dropped()) - && *this.payload_drainable; + let linger = should_linger( + this.payload.as_ref(), + *this.payload_drainable, + ); let not_pipelined = this.messages.is_empty(); // payload stream finished. // set state to None and handle next message this.state.set(State::None); - if not_pipelined && payload_unfinished && !drain_payload { - this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED); + if not_pipelined && linger { + self.as_mut().enter_linger(); } else { this.flags.insert(Flags::FINISHED); } @@ -588,18 +652,18 @@ where // this.payload was the payload for the request we just finished // responding to. We can check to see if we finished reading it // yet, and if not, shutdown the connection. - let payload_unfinished = this.payload.is_some(); - let drain_payload = - this.payload.as_ref().is_some_and(|pl| pl.is_dropped()) - && *this.payload_drainable; + let linger = should_linger( + this.payload.as_ref(), + *this.payload_drainable, + ); let not_pipelined = this.messages.is_empty(); // payload stream finished. // set state to None and handle next message this.state.set(State::None); - if not_pipelined && payload_unfinished && !drain_payload { - this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED); + if not_pipelined && linger { + self.as_mut().enter_linger(); } else { this.flags.insert(Flags::FINISHED); } @@ -960,14 +1024,20 @@ where let this = self.as_mut().project(); if let TimerState::Active { timer } = this.shutdown_timer { debug_assert!( - this.flags.contains(Flags::SHUTDOWN), - "shutdown flag should be set when timer is active", + this.flags.intersects(Flags::SHUTDOWN | Flags::LINGER), + "shutdown or linger flag should be set when timer is active", ); - // timed-out during shutdown; drop connection if timer.as_mut().poll(cx).is_ready() { - trace!("timed-out during shutdown"); - return Err(DispatchError::DisconnectTimeout); + if this.flags.contains(Flags::LINGER) { + trace!("timed-out during linger; shutting down connection"); + this.flags.remove(Flags::LINGER); + this.flags.insert(Flags::SHUTDOWN); + this.shutdown_timer.clear(line!()); + } else { + trace!("timed-out during shutdown"); + return Err(DispatchError::DisconnectTimeout); + } } } @@ -1133,7 +1203,15 @@ where inner.as_mut().poll_timers(cx)?; - let poll = if inner.flags.contains(Flags::SHUTDOWN) { + let poll = if inner.flags.contains(Flags::LINGER) { + match inner.as_mut().poll_linger(cx)? { + Poll::Ready(()) => { + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Pending => Poll::Pending, + } + } else if inner.flags.contains(Flags::SHUTDOWN) { if inner.flags.contains(Flags::WRITE_DISCONNECT) { Poll::Ready(Ok(())) } else { @@ -1281,7 +1359,7 @@ where inner_p.shutdown_timer, ); - if inner_p.flags.contains(Flags::SHUTDOWN) { + if inner_p.flags.intersects(Flags::SHUTDOWN | Flags::LINGER) { cx.waker().wake_by_ref(); } Poll::Pending @@ -1295,6 +1373,13 @@ where } } +fn should_linger(payload: Option<&PayloadSender>, payload_drainable: bool) -> bool { + let payload_unfinished = payload.is_some(); + let drain_payload = payload.is_some_and(|pl| pl.is_dropped()) && payload_drainable; + + payload_unfinished && !drain_payload +} + #[allow(dead_code)] fn trace_timer_states( label: &str, diff --git a/actix-http/src/h1/dispatcher_tests.rs b/actix-http/src/h1/dispatcher_tests.rs index e3a907e5c..dbb20327c 100644 --- a/actix-http/src/h1/dispatcher_tests.rs +++ b/actix-http/src/h1/dispatcher_tests.rs @@ -7,7 +7,10 @@ use std::{ }; use actix_codec::Framed; -use actix_rt::{pin, time::sleep}; +use actix_rt::{ + pin, + time::{sleep, timeout}, +}; use actix_service::{fn_service, Service}; use actix_utils::future::{ready, Ready}; use bytes::{Buf, Bytes, BytesMut}; @@ -84,6 +87,11 @@ fn drop_payload_service() -> impl Service impl Service, Error = Error> +{ + fn_service(|_req: Request| ready(Ok::<_, Error>(Response::with_body(StatusCode::OK, "ok")))) +} + fn echo_payload_service() -> impl Service, Error = Error> { fn_service(|mut req: Request| { Box::pin(async move { @@ -296,7 +304,7 @@ async fn keep_alive_timeout() { // polls: initial => keep-alive wake-up shutdown assert_eq!(h1.poll_count, 2); - if let DispatcherStateProj::Normal { inner } = h1.project().inner.project() { + if let DispatcherStateProj::Normal { inner } = h1.as_mut().project().inner.project() { // connection closed assert!(inner.flags.contains(Flags::SHUTDOWN)); assert!(inner.flags.contains(Flags::WRITE_DISCONNECT)); @@ -454,7 +462,7 @@ async fn req_parse_err() { Poll::Ready(res) => assert!(res.is_err()), } - if let DispatcherStateProj::Normal { inner } = h1.project().inner.project() { + if let DispatcherStateProj::Normal { inner } = h1.as_mut().project().inner.project() { assert!(inner.flags.contains(Flags::READ_DISCONNECT)); assert_eq!( &buf.write_buf_slice()[..26], @@ -536,7 +544,7 @@ async fn pipelining_ok_then_ok() { } #[actix_rt::test] -async fn early_response_with_payload_closes_connection() { +async fn early_response_with_payload_lingers_before_closing() { lazy(|cx| { let buf = TestBuffer::new( "\ @@ -569,12 +577,12 @@ async fn early_response_with_payload_closes_connection() { assert!(matches!(&h1.inner, DispatcherState::Normal { .. })); match h1.as_mut().poll(cx) { - Poll::Pending => panic!("Should have shut down"), - Poll::Ready(res) => assert!(res.is_ok()), + Poll::Pending => {} + Poll::Ready(res) => panic!("should still be lingering: {:?}", res), } - // polls: initial => shutdown - assert_eq!(h1.poll_count, 2); + // polls: initial + assert_eq!(h1.poll_count, 1); { let mut res = buf.write_buf_slice_mut(); @@ -584,6 +592,7 @@ async fn early_response_with_payload_closes_connection() { let exp = b"\ HTTP/1.1 200 OK\r\n\ content-length: 11\r\n\ + connection: close\r\n\ date: Thu, 01 Jan 1970 12:34:56 UTC\r\n\r\n\ /unfinished\ "; @@ -602,6 +611,138 @@ async fn early_response_with_payload_closes_connection() { .await; } +#[actix_rt::test] +async fn buffered_upload_ignored_by_handler_should_not_shutdown_immediately() { + lazy(|cx| { + let buf = TestSeqBuffer::new(http_msg( + r" + POST / HTTP/1.1 + Content-Length: 8 + + ab + ", + )); + + let cfg = ServiceConfig::new( + KeepAlive::Os, + Duration::from_millis(1), + Duration::from_millis(1), + false, + None, + ); + + let services = HttpFlow::new(ignore_payload_service(), ExpectHandler, None); + + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + buf.clone(), + services, + cfg, + None, + OnConnectData::default(), + ); + + pin!(h1); + + assert!(matches!(&h1.inner, DispatcherState::Normal { .. })); + + match h1.as_mut().poll(cx) { + Poll::Pending => {} + Poll::Ready(res) => panic!("closed connection early: {:?}", res), + } + + let mut res = BytesMut::from(buf.take_write_buf().as_ref()); + stabilize_date_header(&mut res); + let res = &res[..]; + + let exp = http_msg( + r" + HTTP/1.1 200 OK + content-length: 2 + connection: close + date: Thu, 01 Jan 1970 12:34:56 UTC + + ok + ", + ); + + assert_eq!( + res, + exp, + "\nexpected response not in write buffer:\n\ + response: {:?}\n\ + expected: {:?}", + String::from_utf8_lossy(res), + String::from_utf8_lossy(&exp) + ); + + buf.close_read(); + + assert!(h1.as_mut().poll(cx).is_pending()); + assert!(h1.as_mut().poll(cx).is_ready()); + }) + .await; +} + +#[actix_rt::test] +async fn lingering_timeout_uses_graceful_shutdown() { + let buf = TestSeqBuffer::new( + "\ + POST / HTTP/1.1\r\n\ + Content-Length: 8\r\n\ + \r\n\ + ab\ + ", + ); + + let cfg = ServiceConfig::new( + KeepAlive::Disabled, + Duration::ZERO, + Duration::ZERO, + false, + None, + ); + + let services = HttpFlow::new(ignore_payload_service(), ExpectHandler, None); + + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + buf.clone(), + services, + cfg, + None, + OnConnectData::default(), + ); + + // A linger timeout should close an idle connection even if the peer never sends anything + // after the partial body. This catches cases where the timeout is only armed on later peer + // activity, which leaves unread-body connections hanging indefinitely. + assert!(matches!( + timeout(Duration::from_millis(1500), h1).await, + Ok(Ok(())) + )); + + let mut res = buf.take_write_buf().to_vec(); + stabilize_date_header(&mut res); + let res = &res[..]; + + let exp = b"\ + HTTP/1.1 200 OK\r\n\ + content-length: 2\r\n\ + connection: close\r\n\ + date: Thu, 01 Jan 1970 12:34:56 UTC\r\n\r\n\ + ok\ + "; + + assert_eq!( + res, + exp, + "\nexpected response not in write buffer:\n\ + response: {:?}\n\ + expected: {:?}", + String::from_utf8_lossy(res), + String::from_utf8_lossy(exp) + ); +} + #[actix_rt::test] async fn pipelining_ok_then_bad() { lazy(|cx| { @@ -781,11 +922,11 @@ async fn expect_eager() { pin!(h1); - assert!(h1.as_mut().poll(cx).is_ready()); + assert!(h1.as_mut().poll(cx).is_pending()); assert!(matches!(&h1.inner, DispatcherState::Normal { .. })); - // polls: manual shutdown - assert_eq!(h1.poll_count, 2); + // polls: manual + assert_eq!(h1.poll_count, 1); if let DispatcherState::Normal { ref inner } = h1.inner { let io = inner.io.as_ref().unwrap(); @@ -810,6 +951,11 @@ async fn expect_eager() { " ); } + + buf.close_read(); + + assert!(h1.as_mut().poll(cx).is_pending()); + assert!(h1.as_mut().poll(cx).is_ready()); }) .await; }