From 32d74edb8f406753d660ba41e7d5b844f98a44c6 Mon Sep 17 00:00:00 2001 From: Arthur LE MOIGNE Date: Fri, 23 Jul 2021 23:27:51 +0200 Subject: [PATCH] Add payload in NOT_ACCEPTABLE compress response --- src/middleware/compress.rs | 60 +++++++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/src/middleware/compress.rs b/src/middleware/compress.rs index 8baac7adf..ed0588562 100644 --- a/src/middleware/compress.rs +++ b/src/middleware/compress.rs @@ -10,9 +10,10 @@ use std::{ }; use actix_http::{ - body::{AnyBody, ResponseBody}, + body::{MessageBody, ResponseBody}, encoding::Encoder, http::header::{ContentEncoding, ACCEPT_ENCODING}, + StatusCode, }; use actix_service::{Service, Transform}; use actix_utils::future::{ok, Either, Ready}; @@ -54,11 +55,12 @@ impl Default for Compress { } } -impl Transform for Compress +impl Transform for Compress where - S: Service, Error = Error>, + B: MessageBody + From, + S: Service, Error = Error>, { - type Response = ServiceResponse>>; + type Response = ServiceResponse>>; type Error = Error; type Transform = CompressMiddleware; type InitError = (); @@ -77,13 +79,39 @@ pub struct CompressMiddleware { encoding: ContentEncoding, } -impl Service for CompressMiddleware +fn supported_algorithm_names() -> String { + let mut encoding = vec![]; + + #[cfg(feature = "compress-brotli")] + { + encoding.push("br"); + } + + #[cfg(feature = "compress-gzip")] + { + encoding.push("gzip"); + encoding.push("deflate"); + } + + #[cfg(feature = "compress-zstd")] + encoding.push("zstd"); + + assert!( + !encoding.is_empty(), + "encoding can not be empty unless __compress feature has been explicitly enabled" + ); + + encoding.join(", ") +} + +impl Service for CompressMiddleware where - S: Service, Error = Error>, + B: MessageBody + From, + S: Service, Error = Error>, { - type Response = ServiceResponse>>; + type Response = ServiceResponse>>; type Error = Error; - type Future = Either, Ready>>; + type Future = Either, Ready>>; actix_service::forward_ready!(service); @@ -113,7 +141,10 @@ where // There is an HTTP header but we cannot match what client as asked for Some(Err(_)) => { - let res = HttpResponse::NotAcceptable().finish(); + let res = HttpResponse::with_body( + StatusCode::NOT_ACCEPTABLE, + supported_algorithm_names().into(), + ); let enc = ContentEncoding::Identity; Either::right(ok(req.into_response(res.map_body(move |head, body| { @@ -125,21 +156,22 @@ where } #[pin_project] -pub struct CompressResponse +pub struct CompressResponse where S: Service, { #[pin] fut: S::Future, encoding: ContentEncoding, - _phantom: PhantomData, + _phantom: PhantomData, } -impl Future for CompressResponse +impl Future for CompressResponse where - S: Service, Error = Error>, + B: MessageBody, + S: Service, Error = Error>, { - type Output = Result>>, Error>; + type Output = Result>>, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project();