diff --git a/actix-http/src/encoding/encoder.rs b/actix-http/src/encoding/encoder.rs index eefe16c7e..422993bcd 100644 --- a/actix-http/src/encoding/encoder.rs +++ b/actix-http/src/encoding/encoder.rs @@ -11,7 +11,7 @@ use bytes::{Buf, Bytes}; use derive_more::Display; #[cfg(feature = "compress-gzip")] use flate2::write::{GzEncoder, ZlibEncoder}; -use futures_core::ready; +use futures_core::{ready, Stream}; use pin_project_lite::pin_project; use tracing::trace; #[cfg(feature = "compress-zstd")] @@ -28,7 +28,7 @@ pin_project! { pub struct Encoder { #[pin] body: EncoderBody, - encoder: Option>, + encoder: Option>, eof: bool, } } @@ -62,12 +62,12 @@ impl Encoder { if should_encode { // wrap body only if encoder is feature-enabled - if let Some(selected_encoder) = ContentEncoder::select(encoding) { + if let Some(coop_encoder) = CooperativeContentEncoder::select(encoding) { update_head(encoding, head); return Encoder { body, - encoder: Some(selected_encoder), + encoder: Some(coop_encoder), eof: false, }; } @@ -82,8 +82,8 @@ impl Encoder { pub fn with_encode_chunk_size(mut self, size: usize) -> Self { if size > 0 { - if let Some(selected_encoder) = self.encoder.as_mut() { - selected_encoder.preferred_chunk_size = size; + if let Some(coop_encoder) = self.encoder.as_mut() { + coop_encoder.preferred_chunk_size = size; } } self @@ -170,24 +170,16 @@ where return Poll::Ready(None); } - if let Some(selected_encoder) = this.encoder.as_deref_mut() { - if let Some(chunk) = selected_encoder.chunk_ready_to_encode.as_mut() { - let encode_len = chunk.len().min(selected_encoder.preferred_chunk_size); - selected_encoder - .content_encoder - .write(&chunk[..encode_len]) - .map_err(EncoderError::Io)?; - chunk.advance(encode_len); - - if chunk.is_empty() { - selected_encoder.chunk_ready_to_encode = None; + if let Some(cooperative_encoder) = this.encoder.as_deref_mut() { + match ready!(Pin::new(cooperative_encoder).poll_next(cx)) { + Some(Ok(Some(chunk))) => return Poll::Ready(Some(Ok(chunk))), + Some(Ok(None)) => { + // Need more data from uncompressed body } - - let encoded_chunk = selected_encoder.content_encoder.take(); - if encoded_chunk.is_empty() { - continue; + Some(Err(err)) => return Poll::Ready(Some(Err(err))), + None => { + unreachable!() } - return Poll::Ready(Some(Ok(encoded_chunk))); } } @@ -199,13 +191,15 @@ where Some(Ok(chunk)) => match this.encoder.as_deref_mut() { None => return Poll::Ready(Some(Ok(chunk))), Some(encoder) => { + debug_assert!(encoder.chunk_ready_to_encode.is_none()); encoder.chunk_ready_to_encode = Some(chunk); + encoder.budget_used = 0; } }, None => { - if let Some(selected_encoder) = this.encoder.take() { - let chunk = selected_encoder + if let Some(coop_encoder) = this.encoder.take() { + let chunk = coop_encoder .content_encoder .finish() .map_err(EncoderError::Io)?; @@ -268,14 +262,15 @@ enum ContentEncoder { Zstd(ZstdEncoder<'static, Writer>), } -struct ChunkedContentEncoder { +struct CooperativeContentEncoder { content_encoder: ContentEncoder, preferred_chunk_size: usize, chunk_ready_to_encode: Option, + budget_used: u8, } -impl ContentEncoder { - fn select(encoding: ContentEncoding) -> Option> { +impl CooperativeContentEncoder { + fn select(encoding: ContentEncoding) -> Option> { // Chunk size picked as max chunk size which took less that 50 µs to compress on "cargo bench --bench compression-chunk-size" // Rust 1.72 linux/arm64 in Docker on Apple M2 Pro: "time to compress chunk/deflate-16384" time: [39.114 µs 39.283 µs 39.457 µs] @@ -289,49 +284,55 @@ impl ContentEncoder { match encoding { #[cfg(feature = "compress-gzip")] - ContentEncoding::Deflate => Some(Box::new(ChunkedContentEncoder { + ContentEncoding::Deflate => Some(Box::new(CooperativeContentEncoder { content_encoder: ContentEncoder::Deflate(ZlibEncoder::new( Writer::new(), flate2::Compression::fast(), )), preferred_chunk_size: MAX_DEFLATE_CHUNK_SIZE, chunk_ready_to_encode: None, + budget_used: 0, })), #[cfg(feature = "compress-gzip")] - ContentEncoding::Gzip => Some(Box::new(ChunkedContentEncoder { + ContentEncoding::Gzip => Some(Box::new(CooperativeContentEncoder { content_encoder: ContentEncoder::Gzip(GzEncoder::new( Writer::new(), flate2::Compression::fast(), )), preferred_chunk_size: MAX_GZIP_CHUNK_SIZE, chunk_ready_to_encode: None, + budget_used: 0, })), #[cfg(feature = "compress-brotli")] - ContentEncoding::Brotli => Some(Box::new(ChunkedContentEncoder { + ContentEncoding::Brotli => Some(Box::new(CooperativeContentEncoder { content_encoder: ContentEncoder::Brotli(new_brotli_compressor()), preferred_chunk_size: MAX_BROTLI_CHUNK_SIZE, chunk_ready_to_encode: None, + budget_used: 0, })), #[cfg(feature = "compress-zstd")] ContentEncoding::Zstd => { let encoder = ZstdEncoder::new(Writer::new(), 3).ok()?; - Some(Box::new(ChunkedContentEncoder { + Some(Box::new(CooperativeContentEncoder { content_encoder: ContentEncoder::Zstd(encoder), preferred_chunk_size: MAX_ZSTD_CHUNK_SIZE, chunk_ready_to_encode: None, + budget_used: 0, })) } _ => None, } } +} +impl ContentEncoder { #[inline] pub(crate) fn take(&mut self) -> Bytes { - match *self { + match self { #[cfg(feature = "compress-brotli")] ContentEncoder::Brotli(ref mut encoder) => encoder.get_mut().take(), @@ -375,7 +376,7 @@ impl ContentEncoder { } fn write(&mut self, data: &[u8]) -> Result<(), io::Error> { - match *self { + match self { #[cfg(feature = "compress-brotli")] ContentEncoder::Brotli(ref mut encoder) => match encoder.write_all(data) { Ok(_) => Ok(()), @@ -415,6 +416,42 @@ impl ContentEncoder { } } +impl futures_core::Stream for CooperativeContentEncoder { + type Item = Result, EncoderError>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + loop { + if this.budget_used > 8 { + this.budget_used = 0; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + if let Some(mut chunk) = this.chunk_ready_to_encode.take() { + let encode_len = chunk.len().min(this.preferred_chunk_size); + this.content_encoder + .write(&chunk[..encode_len]) + .map_err(EncoderError::Io)?; + chunk.advance(encode_len); + + if !chunk.is_empty() { + this.chunk_ready_to_encode = Some(chunk); + } + + let encoded_chunk = this.content_encoder.take(); + if encoded_chunk.is_empty() { + continue; + } + + this.budget_used += 1; + return Poll::Ready(Some(Ok(Some(encoded_chunk)))); + } else { + return Poll::Ready(Some(Ok(None))); + } + } + } +} + #[cfg(feature = "compress-brotli")] fn new_brotli_compressor() -> Box> { Box::new(brotli::CompressorWriter::new( @@ -477,10 +514,9 @@ mod tests { let compressed_body = Encoder::response(encoding, &mut head, body_to_compress.clone()) .with_encode_chunk_size(rand::thread_rng().gen_range(32..128)); - let encoder = ContentEncoder::select(encoding).unwrap(); - let mut compressor = encoder.content_encoder; - compressor.write(&body_to_compress).unwrap(); - let reference_compressed_bytes = compressor.finish().unwrap(); + let mut encoder = CooperativeContentEncoder::select(encoding).unwrap(); + encoder.content_encoder.write(&body_to_compress).unwrap(); + let reference_compressed_bytes = encoder.content_encoder.finish().unwrap(); let compressed_bytes = body::to_bytes_limited(compressed_body, 256 + body_to_compress.len())