actix-http: linger after early responses (#3985)

Co-authored-by: Ophir LOJKINE <contact@ophir.dev>
This commit is contained in:
Yuki Okushi 2026-04-18 12:46:44 +09:00 committed by GitHub
parent 0fb89457ed
commit 10609f749d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 359 additions and 80 deletions

View File

@ -2,6 +2,10 @@
## Unreleased ## Unreleased
- When configured, gracefully close HTTP/1 connections after early responses to unread request bodies. [#3967]
[#3967]: https://github.com/actix/actix-web/issues/3967
## 3.12.1 ## 3.12.1
**Notice: This release contains a security fix. Users are encouraged to update to this version ASAP.** **Notice: This release contains a security fix. Users are encouraged to update to this version ASAP.**

View File

@ -31,7 +31,7 @@ use crate::{
config::ServiceConfig, config::ServiceConfig,
error::{DispatchError, ParseError, PayloadError}, error::{DispatchError, ParseError, PayloadError},
service::HttpFlow, service::HttpFlow,
Error, Extensions, HttpMessage, OnConnectData, Request, Response, StatusCode, ConnectionType, Error, Extensions, HttpMessage, OnConnectData, Request, Response, StatusCode,
}; };
const LW_BUFFER_SIZE: usize = 1024; const LW_BUFFER_SIZE: usize = 1024;
@ -58,6 +58,9 @@ bitflags! {
/// Set if write-half is disconnected. /// Set if write-half is disconnected.
const WRITE_DISCONNECT = 0b0010_0000; const WRITE_DISCONNECT = 0b0010_0000;
/// Set while gracefully closing a connection after an early response.
const LINGER = 0b0100_0000;
} }
} }
@ -361,6 +364,65 @@ where
io.poll_flush(cx) io.poll_flush(cx)
} }
fn enter_linger(mut self: Pin<&mut Self>) {
let this = self.as_mut().project();
this.flags.remove(Flags::KEEP_ALIVE);
this.flags.insert(Flags::LINGER | Flags::FINISHED);
}
fn ensure_linger_timer(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool {
let this = self.as_mut().project();
if matches!(this.shutdown_timer, TimerState::Active { .. }) {
return true;
}
if let Some(deadline) = this.config.client_disconnect_deadline() {
this.shutdown_timer
.set_and_init(cx, sleep_until(deadline.into()), line!());
true
} else {
false
}
}
fn poll_linger(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Result<Poll<()>, DispatchError> {
if self.as_mut().poll_flush(cx)?.is_pending() {
return Ok(Poll::Pending);
}
if !self.as_mut().ensure_linger_timer(cx) {
let this = self.as_mut().project();
this.flags.remove(Flags::LINGER);
this.flags.insert(Flags::SHUTDOWN);
return Ok(Poll::Ready(()));
}
loop {
let should_disconnect = self.as_mut().read_available(cx)?;
let this = self.as_mut().project();
let mut progressed = false;
if !this.read_buf.is_empty() {
this.read_buf.clear();
progressed = true;
}
if should_disconnect {
this.flags.remove(Flags::LINGER);
this.flags.insert(Flags::READ_DISCONNECT | Flags::SHUTDOWN);
return Ok(Poll::Ready(()));
}
if !progressed {
return Ok(Poll::Pending);
}
}
}
fn send_response_inner( fn send_response_inner(
self: Pin<&mut Self>, self: Pin<&mut Self>,
res: Response<()>, res: Response<()>,
@ -385,54 +447,90 @@ where
fn send_response( fn send_response(
mut self: Pin<&mut Self>, mut self: Pin<&mut Self>,
res: Response<()>, mut res: Response<()>,
body: B, body: B,
) -> Result<(), DispatchError> { ) -> Result<(), DispatchError> {
let size = self.as_mut().send_response_inner(res, &body)?; let close_after_response = {
let mut this = self.project(); let this = self.as_mut().project();
this.state.set(match size { should_close_after_response(this.payload.as_ref(), *this.payload_drainable)
BodySize::None | BodySize::Sized(0) => { };
let payload_unfinished = this.payload.is_some();
let drain_payload = this.payload.as_ref().is_some_and(|pl| pl.is_dropped())
&& *this.payload_drainable;
if payload_unfinished && !drain_payload { if close_after_response {
this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED); res.head_mut().set_connection_type(ConnectionType::Close);
}
let size = self.as_mut().send_response_inner(res, &body)?;
match size {
BodySize::None | BodySize::Sized(0) => {
let this = self.as_mut().project();
if close_after_response {
if this.config.client_disconnect_deadline().is_some() {
drop(this);
self.as_mut().enter_linger();
} else {
self.as_mut()
.project()
.flags
.insert(Flags::SHUTDOWN | Flags::FINISHED);
}
} else { } else {
this.flags.insert(Flags::FINISHED); this.flags.insert(Flags::FINISHED);
} }
State::None self.as_mut().project().state.set(State::None);
} }
_ => State::SendPayload { body }, _ => self
}); .as_mut()
.project()
.state
.set(State::SendPayload { body }),
}
Ok(()) Ok(())
} }
fn send_error_response( fn send_error_response(
mut self: Pin<&mut Self>, mut self: Pin<&mut Self>,
res: Response<()>, mut res: Response<()>,
body: BoxBody, body: BoxBody,
) -> Result<(), DispatchError> { ) -> Result<(), DispatchError> {
let size = self.as_mut().send_response_inner(res, &body)?; let close_after_response = {
let mut this = self.project(); let this = self.as_mut().project();
this.state.set(match size { should_close_after_response(this.payload.as_ref(), *this.payload_drainable)
BodySize::None | BodySize::Sized(0) => { };
let payload_unfinished = this.payload.is_some();
let drain_payload = this.payload.as_ref().is_some_and(|pl| pl.is_dropped())
&& *this.payload_drainable;
if payload_unfinished && !drain_payload { if close_after_response {
this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED); res.head_mut().set_connection_type(ConnectionType::Close);
}
let size = self.as_mut().send_response_inner(res, &body)?;
match size {
BodySize::None | BodySize::Sized(0) => {
let this = self.as_mut().project();
if close_after_response {
if this.config.client_disconnect_deadline().is_some() {
drop(this);
self.as_mut().enter_linger();
} else {
self.as_mut()
.project()
.flags
.insert(Flags::SHUTDOWN | Flags::FINISHED);
}
} else { } else {
this.flags.insert(Flags::FINISHED); this.flags.insert(Flags::FINISHED);
} }
State::None self.as_mut().project().state.set(State::None);
} }
_ => State::SendErrorPayload { body }, _ => self
}); .as_mut()
.project()
.state
.set(State::SendErrorPayload { body }),
}
Ok(()) Ok(())
} }
@ -534,18 +632,26 @@ where
// this.payload was the payload for the request we just finished // this.payload was the payload for the request we just finished
// responding to. We can check to see if we finished reading it // responding to. We can check to see if we finished reading it
// yet, and if not, shutdown the connection. // yet, and if not, shutdown the connection.
let payload_unfinished = this.payload.is_some(); let close_after_response = should_close_after_response(
let drain_payload = this.payload.as_ref(),
this.payload.as_ref().is_some_and(|pl| pl.is_dropped()) *this.payload_drainable,
&& *this.payload_drainable; );
let not_pipelined = this.messages.is_empty(); let not_pipelined = this.messages.is_empty();
// payload stream finished. // payload stream finished.
// set state to None and handle next message // set state to None and handle next message
this.state.set(State::None); this.state.set(State::None);
if not_pipelined && payload_unfinished && !drain_payload { if not_pipelined && close_after_response {
this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED); if this.config.client_disconnect_deadline().is_some() {
drop(this);
self.as_mut().enter_linger();
} else {
self.as_mut()
.project()
.flags
.insert(Flags::SHUTDOWN | Flags::FINISHED);
}
} else { } else {
this.flags.insert(Flags::FINISHED); this.flags.insert(Flags::FINISHED);
} }
@ -588,18 +694,26 @@ where
// this.payload was the payload for the request we just finished // this.payload was the payload for the request we just finished
// responding to. We can check to see if we finished reading it // responding to. We can check to see if we finished reading it
// yet, and if not, shutdown the connection. // yet, and if not, shutdown the connection.
let payload_unfinished = this.payload.is_some(); let close_after_response = should_close_after_response(
let drain_payload = this.payload.as_ref(),
this.payload.as_ref().is_some_and(|pl| pl.is_dropped()) *this.payload_drainable,
&& *this.payload_drainable; );
let not_pipelined = this.messages.is_empty(); let not_pipelined = this.messages.is_empty();
// payload stream finished. // payload stream finished.
// set state to None and handle next message // set state to None and handle next message
this.state.set(State::None); this.state.set(State::None);
if not_pipelined && payload_unfinished && !drain_payload { if not_pipelined && close_after_response {
this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED); if this.config.client_disconnect_deadline().is_some() {
drop(this);
self.as_mut().enter_linger();
} else {
self.as_mut()
.project()
.flags
.insert(Flags::SHUTDOWN | Flags::FINISHED);
}
} else { } else {
this.flags.insert(Flags::FINISHED); this.flags.insert(Flags::FINISHED);
} }
@ -960,14 +1074,20 @@ where
let this = self.as_mut().project(); let this = self.as_mut().project();
if let TimerState::Active { timer } = this.shutdown_timer { if let TimerState::Active { timer } = this.shutdown_timer {
debug_assert!( debug_assert!(
this.flags.contains(Flags::SHUTDOWN), this.flags.intersects(Flags::LINGER | Flags::SHUTDOWN),
"shutdown flag should be set when timer is active", "shutdown or linger flag should be set when timer is active",
); );
// timed-out during shutdown; drop connection
if timer.as_mut().poll(cx).is_ready() { if timer.as_mut().poll(cx).is_ready() {
trace!("timed-out during shutdown"); if this.flags.contains(Flags::LINGER) {
return Err(DispatchError::DisconnectTimeout); trace!("timed-out during linger; shutting down connection");
this.flags.remove(Flags::LINGER);
this.flags.insert(Flags::SHUTDOWN);
this.shutdown_timer.clear(line!());
} else {
trace!("timed-out during shutdown");
return Err(DispatchError::DisconnectTimeout);
}
} }
} }
@ -1133,7 +1253,15 @@ where
inner.as_mut().poll_timers(cx)?; inner.as_mut().poll_timers(cx)?;
let poll = if inner.flags.contains(Flags::SHUTDOWN) { let poll = if inner.flags.contains(Flags::LINGER) {
match inner.as_mut().poll_linger(cx)? {
Poll::Ready(()) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Pending => Poll::Pending,
}
} else if inner.flags.contains(Flags::SHUTDOWN) {
if inner.flags.contains(Flags::WRITE_DISCONNECT) { if inner.flags.contains(Flags::WRITE_DISCONNECT) {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} else { } else {
@ -1281,7 +1409,7 @@ where
inner_p.shutdown_timer, inner_p.shutdown_timer,
); );
if inner_p.flags.contains(Flags::SHUTDOWN) { if inner_p.flags.intersects(Flags::LINGER | Flags::SHUTDOWN) {
cx.waker().wake_by_ref(); cx.waker().wake_by_ref();
} }
Poll::Pending Poll::Pending
@ -1295,6 +1423,13 @@ where
} }
} }
fn should_close_after_response(payload: Option<&PayloadSender>, payload_drainable: bool) -> bool {
let payload_unfinished = payload.is_some();
let drain_payload = payload.is_some_and(|pl| pl.is_dropped()) && payload_drainable;
payload_unfinished && !drain_payload
}
#[allow(dead_code)] #[allow(dead_code)]
fn trace_timer_states( fn trace_timer_states(
label: &str, label: &str,

View File

@ -7,7 +7,10 @@ use std::{
}; };
use actix_codec::Framed; use actix_codec::Framed;
use actix_rt::{pin, time::sleep}; use actix_rt::{
pin,
time::{sleep, timeout},
};
use actix_service::{fn_service, Service}; use actix_service::{fn_service, Service};
use actix_utils::future::{ready, Ready}; use actix_utils::future::{ready, Ready};
use bytes::{Buf, Bytes, BytesMut}; use bytes::{Buf, Bytes, BytesMut};
@ -84,6 +87,11 @@ fn drop_payload_service() -> impl Service<Request, Response = Response<&'static
}) })
} }
fn ignore_payload_service(
) -> impl Service<Request, Response = Response<&'static str>, Error = Error> {
fn_service(|_req: Request| ready(Ok::<_, Error>(Response::with_body(StatusCode::OK, "ok"))))
}
fn echo_payload_service() -> impl Service<Request, Response = Response<Bytes>, Error = Error> { fn echo_payload_service() -> impl Service<Request, Response = Response<Bytes>, Error = Error> {
fn_service(|mut req: Request| { fn_service(|mut req: Request| {
Box::pin(async move { Box::pin(async move {
@ -536,15 +544,14 @@ async fn pipelining_ok_then_ok() {
} }
#[actix_rt::test] #[actix_rt::test]
async fn early_response_with_payload_closes_connection() { async fn early_response_with_payload_lingers_before_closing() {
lazy(|cx| { lazy(|cx| {
let buf = TestBuffer::new( let buf = TestSeqBuffer::new(http_msg(
"\ r"
GET /unfinished HTTP/1.1\r\n\ GET /unfinished HTTP/1.1
Content-Length: 2\r\n\ Content-Length: 2
\r\n\ ",
", ));
);
let cfg = ServiceConfig::new( let cfg = ServiceConfig::new(
KeepAlive::Os, KeepAlive::Os,
@ -569,39 +576,172 @@ async fn early_response_with_payload_closes_connection() {
assert!(matches!(&h1.inner, DispatcherState::Normal { .. })); assert!(matches!(&h1.inner, DispatcherState::Normal { .. }));
match h1.as_mut().poll(cx) { match h1.as_mut().poll(cx) {
Poll::Pending => panic!("Should have shut down"), Poll::Pending => {}
Poll::Ready(res) => assert!(res.is_ok()), Poll::Ready(res) => panic!("should still be lingering: {:?}", res),
} }
// polls: initial => shutdown // polls: initial
assert_eq!(h1.poll_count, 2); assert_eq!(h1.poll_count, 1);
{ let mut res = buf.take_write_buf().to_vec();
let mut res = buf.write_buf_slice_mut(); stabilize_date_header(&mut res);
stabilize_date_header(&mut res); let res = &res[..];
let res = &res[..];
let exp = b"\ let exp = b"\
HTTP/1.1 200 OK\r\n\ HTTP/1.1 200 OK\r\n\
content-length: 11\r\n\ content-length: 11\r\n\
date: Thu, 01 Jan 1970 12:34:56 UTC\r\n\r\n\ connection: close\r\n\
/unfinished\ date: Thu, 01 Jan 1970 12:34:56 UTC\r\n\r\n\
"; /unfinished\
";
assert_eq!( assert_eq!(
res, res,
exp, exp,
"\nexpected response not in write buffer:\n\ "\nexpected response not in write buffer:\n\
response: {:?}\n\ response: {:?}\n\
expected: {:?}", expected: {:?}",
String::from_utf8_lossy(res), String::from_utf8_lossy(res),
String::from_utf8_lossy(exp) String::from_utf8_lossy(exp)
); );
}
buf.close_read();
assert!(h1.as_mut().poll(cx).is_pending());
assert!(h1.as_mut().poll(cx).is_ready());
}) })
.await; .await;
} }
#[actix_rt::test]
async fn buffered_upload_ignored_by_handler_should_not_shutdown_immediately() {
lazy(|cx| {
let buf = TestSeqBuffer::new(http_msg(
r"
POST / HTTP/1.1
Content-Length: 8
ab
",
));
let cfg = ServiceConfig::new(
KeepAlive::Os,
Duration::from_millis(1),
Duration::from_millis(1),
false,
None,
);
let services = HttpFlow::new(ignore_payload_service(), ExpectHandler, None);
let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new(
buf.clone(),
services,
cfg,
None,
OnConnectData::default(),
);
pin!(h1);
assert!(matches!(&h1.inner, DispatcherState::Normal { .. }));
match h1.as_mut().poll(cx) {
Poll::Pending => {}
Poll::Ready(res) => panic!("closed connection early: {:?}", res),
}
let mut res = BytesMut::from(buf.take_write_buf().as_ref());
stabilize_date_header(&mut res);
let res = &res[..];
let exp = http_msg(
r"
HTTP/1.1 200 OK
content-length: 2
connection: close
date: Thu, 01 Jan 1970 12:34:56 UTC
ok
",
);
assert_eq!(
res,
exp,
"\nexpected response not in write buffer:\n\
response: {:?}\n\
expected: {:?}",
String::from_utf8_lossy(res),
String::from_utf8_lossy(&exp)
);
buf.close_read();
assert!(h1.as_mut().poll(cx).is_pending());
assert!(h1.as_mut().poll(cx).is_ready());
})
.await;
}
#[actix_rt::test]
async fn lingering_timeout_uses_graceful_shutdown() {
let buf = TestSeqBuffer::new(
"\
POST / HTTP/1.1\r\n\
Content-Length: 8\r\n\
\r\n\
ab\
",
);
let cfg = ServiceConfig::new(
KeepAlive::Disabled,
Duration::ZERO,
Duration::from_millis(1),
false,
None,
);
let services = HttpFlow::new(ignore_payload_service(), ExpectHandler, None);
let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new(
buf.clone(),
services,
cfg,
None,
OnConnectData::default(),
);
assert!(matches!(
timeout(Duration::from_millis(100), h1).await,
Ok(Ok(()))
));
let mut res = buf.take_write_buf().to_vec();
stabilize_date_header(&mut res);
let res = &res[..];
let exp = b"\
HTTP/1.1 200 OK\r\n\
content-length: 2\r\n\
connection: close\r\n\
date: Thu, 01 Jan 1970 12:34:56 UTC\r\n\r\n\
ok\
";
assert_eq!(
res,
exp,
"\nexpected response not in write buffer:\n\
response: {:?}\n\
expected: {:?}",
String::from_utf8_lossy(res),
String::from_utf8_lossy(exp)
);
}
#[actix_rt::test] #[actix_rt::test]
async fn pipelining_ok_then_bad() { async fn pipelining_ok_then_bad() {
lazy(|cx| { lazy(|cx| {

View File

@ -245,7 +245,7 @@ where
/// ///
/// To disable timeout set value to 0. /// To disable timeout set value to 0.
/// ///
/// By default client timeout is set to 5000 milliseconds. /// By default client timeout is set to 1000 milliseconds.
pub fn client_disconnect_timeout(self, dur: Duration) -> Self { pub fn client_disconnect_timeout(self, dur: Duration) -> Self {
self.config.lock().unwrap().client_disconnect_timeout = dur; self.config.lock().unwrap().client_disconnect_timeout = dur;
self self