From e254fe4f9c37da954c4bf544ca30207415dbd426 Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Wed, 27 Mar 2019 11:29:31 -0700
Subject: [PATCH] allow to override response body encoding

---
 actix-files/src/named.rs     | 12 ++++----
 actix-http/src/message.rs    |  3 --
 actix-http/src/response.rs   | 16 ++++++++++
 examples/basic.rs            |  2 +-
 src/middleware/compress.rs   | 40 ++++++++++++++++++++++++-
 src/middleware/decompress.rs | 16 ++++++++++
 src/middleware/mod.rs        |  9 +++---
 tests/test_server.rs         | 58 ++++++++++++++++++++++++++----------
 8 files changed, 125 insertions(+), 31 deletions(-)

diff --git a/actix-files/src/named.rs b/actix-files/src/named.rs
index 7bc37054..842a0e5e 100644
--- a/actix-files/src/named.rs
+++ b/actix-files/src/named.rs
@@ -15,6 +15,7 @@ use actix_web::http::header::{
     self, ContentDisposition, DispositionParam, DispositionType,
 };
 use actix_web::http::{ContentEncoding, Method, StatusCode};
+use actix_web::middleware::encoding::BodyEncoding;
 use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, Responder};
 
 use crate::range::HttpRange;
@@ -360,10 +361,10 @@ impl Responder for NamedFile {
                 header::CONTENT_DISPOSITION,
                 self.content_disposition.to_string(),
             );
-        // TODO blocking by compressing
-        // if let Some(current_encoding) = self.encoding {
-        //     resp.content_encoding(current_encoding);
-        // }
+        // default compressing
+        if let Some(current_encoding) = self.encoding {
+            resp.encoding(current_encoding);
+        }
 
         resp.if_some(last_modified, |lm, resp| {
             resp.set(header::LastModified(lm));
@@ -383,8 +384,7 @@ impl Responder for NamedFile {
                 if let Ok(rangesvec) = HttpRange::parse(rangesheader, length) {
                     length = rangesvec[0].length;
                     offset = rangesvec[0].start;
-                    // TODO blocking by compressing
-                    // resp.content_encoding(ContentEncoding::Identity);
+                    resp.encoding(ContentEncoding::Identity);
                     resp.header(
                         header::CONTENT_RANGE,
                         format!(
diff --git a/actix-http/src/message.rs b/actix-http/src/message.rs
index a1e9e3c6..3466f66d 100644
--- a/actix-http/src/message.rs
+++ b/actix-http/src/message.rs
@@ -24,9 +24,6 @@ bitflags! {
         const KEEP_ALIVE  = 0b0000_0010;
         const UPGRADE     = 0b0000_0100;
         const NO_CHUNKING = 0b0000_1000;
-        const ENC_BR      = 0b0001_0000;
-        const ENC_DEFLATE = 0b0010_0000;
-        const ENC_GZIP    = 0b0100_0000;
     }
 }
 
diff --git a/actix-http/src/response.rs b/actix-http/src/response.rs
index 3b33e1f9..29a850fa 100644
--- a/actix-http/src/response.rs
+++ b/actix-http/src/response.rs
@@ -1,4 +1,5 @@
 //! Http response
+use std::cell::{Ref, RefMut};
 use std::io::Write;
 use std::{fmt, str};
 
@@ -14,6 +15,7 @@ use serde_json;
 
 use crate::body::{Body, BodyStream, MessageBody, ResponseBody};
 use crate::error::Error;
+use crate::extensions::Extensions;
 use crate::header::{Header, IntoHeaderValue};
 use crate::message::{ConnectionType, Message, ResponseHead};
 
@@ -577,6 +579,20 @@ impl ResponseBuilder {
         self
     }
 
+    /// Responses extensions
+    #[inline]
+    pub fn extensions(&self) -> Ref<Extensions> {
+        let head = self.head.as_ref().expect("cannot reuse response builder");
+        head.extensions.borrow()
+    }
+
+    /// Mutable reference to a the response's extensions
+    #[inline]
+    pub fn extensions_mut(&mut self) -> RefMut<Extensions> {
+        let head = self.head.as_ref().expect("cannot reuse response builder");
+        head.extensions.borrow_mut()
+    }
+
     /// Set a body and generate `Response`.
     ///
     /// `ResponseBuilder` can not be used after this call.
diff --git a/examples/basic.rs b/examples/basic.rs
index 91119657..1191b371 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -27,7 +27,7 @@ fn main() -> std::io::Result<()> {
     HttpServer::new(|| {
         App::new()
             .wrap(middleware::DefaultHeaders::new().header("X-Version", "0.2"))
-            .wrap(middleware::Compress::default())
+            .wrap(middleware::encoding::Compress::default())
             .wrap(middleware::Logger::default())
             .service(index)
             .service(no_params)
diff --git a/src/middleware/compress.rs b/src/middleware/compress.rs
index 5ffe9afb..5c6bad87 100644
--- a/src/middleware/compress.rs
+++ b/src/middleware/compress.rs
@@ -6,14 +6,46 @@ use std::str::FromStr;
 use actix_http::body::MessageBody;
 use actix_http::encoding::Encoder;
 use actix_http::http::header::{ContentEncoding, ACCEPT_ENCODING};
+use actix_http::ResponseBuilder;
 use actix_service::{Service, Transform};
 use futures::future::{ok, FutureResult};
 use futures::{Async, Future, Poll};
 
 use crate::service::{ServiceRequest, ServiceResponse};
 
+struct Enc(ContentEncoding);
+
+/// Helper trait that allows to set specific encoding for response.
+pub trait BodyEncoding {
+    fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self;
+}
+
+impl BodyEncoding for ResponseBuilder {
+    fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self {
+        self.extensions_mut().insert(Enc(encoding));
+        self
+    }
+}
+
 #[derive(Debug, Clone)]
 /// `Middleware` for compressing response body.
+///
+/// Use `BodyEncoding` trait for overriding response compression.
+/// To disable compression set encoding to `ContentEncoding::Identity` value.
+///
+/// ```rust
+/// use actix_web::{web, middleware::encoding, App, HttpResponse};
+///
+/// fn main() {
+///     let app = App::new()
+///         .wrap(encoding::Compress::default())
+///         .service(
+///             web::resource("/test")
+///                 .route(web::get().to(|| HttpResponse::Ok()))
+///                 .route(web::head().to(|| HttpResponse::MethodNotAllowed()))
+///         );
+/// }
+/// ```
 pub struct Compress(ContentEncoding);
 
 impl Compress {
@@ -118,8 +150,14 @@ where
     fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
         let resp = futures::try_ready!(self.fut.poll());
 
+        let enc = if let Some(enc) = resp.head().extensions().get::<Enc>() {
+            enc.0
+        } else {
+            self.encoding
+        };
+
         Ok(Async::Ready(resp.map_body(move |head, body| {
-            Encoder::response(self.encoding, head, body)
+            Encoder::response(enc, head, body)
         })))
     }
 }
diff --git a/src/middleware/decompress.rs b/src/middleware/decompress.rs
index d0a9bfd2..eaffbbdb 100644
--- a/src/middleware/decompress.rs
+++ b/src/middleware/decompress.rs
@@ -12,6 +12,22 @@ use crate::error::{Error, PayloadError};
 use crate::service::ServiceRequest;
 use crate::HttpMessage;
 
+/// `Middleware` for decompressing request's payload.
+/// `Decompress` middleware must be added with `App::chain()` method.
+///
+/// ```rust
+/// use actix_web::{web, middleware::encoding, App, HttpResponse};
+///
+/// fn main() {
+///     let app = App::new()
+///         .chain(encoding::Decompress::new())
+///         .service(
+///             web::resource("/test")
+///                 .route(web::get().to(|| HttpResponse::Ok()))
+///                 .route(web::head().to(|| HttpResponse::MethodNotAllowed()))
+///         );
+/// }
+/// ```
 pub struct Decompress<P>(PhantomData<P>);
 
 impl<P> Decompress<P>
diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs
index aee0ae3d..037d0006 100644
--- a/src/middleware/mod.rs
+++ b/src/middleware/mod.rs
@@ -1,13 +1,14 @@
 //! Middlewares
 #[cfg(any(feature = "brotli", feature = "flate2-zlib", feature = "flate2-rust"))]
 mod compress;
-#[cfg(any(feature = "brotli", feature = "flate2-zlib", feature = "flate2-rust"))]
-pub use self::compress::Compress;
-
 #[cfg(any(feature = "brotli", feature = "flate2-zlib", feature = "flate2-rust"))]
 mod decompress;
 #[cfg(any(feature = "brotli", feature = "flate2-zlib", feature = "flate2-rust"))]
-pub use self::decompress::Decompress;
+pub mod encoding {
+    //! Middlewares for compressing/decompressing payloads.
+    pub use super::compress::{BodyEncoding, Compress};
+    pub use super::decompress::Decompress;
+}
 
 pub mod cors;
 mod defaultheaders;
diff --git a/tests/test_server.rs b/tests/test_server.rs
index 29998bc0..364f9262 100644
--- a/tests/test_server.rs
+++ b/tests/test_server.rs
@@ -14,7 +14,7 @@ use flate2::Compression;
 use futures::stream::once; //Future, Stream
 use rand::{distributions::Alphanumeric, Rng};
 
-use actix_web::{dev::HttpMessageBody, middleware, web, App};
+use actix_web::{dev::HttpMessageBody, middleware::encoding, web, App};
 
 const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
                    Hello World Hello World Hello World Hello World Hello World \
@@ -60,7 +60,7 @@ fn test_body_gzip() {
     let mut srv = TestServer::new(|| {
         h1::H1Service::new(
             App::new()
-                .wrap(middleware::Compress::new(ContentEncoding::Gzip))
+                .wrap(encoding::Compress::new(ContentEncoding::Gzip))
                 .service(web::resource("/").route(web::to(|| Response::Ok().body(STR)))),
         )
     });
@@ -78,6 +78,32 @@ fn test_body_gzip() {
     assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref()));
 }
 
+#[test]
+fn test_body_encoding_override() {
+    let mut srv = TestServer::new(|| {
+        h1::H1Service::new(
+            App::new()
+                .wrap(encoding::Compress::new(ContentEncoding::Gzip))
+                .service(web::resource("/").route(web::to(|| {
+                    use actix_web::middleware::encoding::BodyEncoding;
+                    Response::Ok().encoding(ContentEncoding::Deflate).body(STR)
+                }))),
+        )
+    });
+
+    let mut response = srv.block_on(srv.get().no_decompress().send()).unwrap();
+    assert!(response.status().is_success());
+
+    // read response
+    let bytes = srv.block_on(HttpMessageBody::new(&mut response)).unwrap();
+
+    // decode
+    let mut e = ZlibDecoder::new(Vec::new());
+    e.write_all(bytes.as_ref()).unwrap();
+    let dec = e.finish().unwrap();
+    assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref()));
+}
+
 #[test]
 fn test_body_gzip_large() {
     let data = STR.repeat(10);
@@ -87,7 +113,7 @@ fn test_body_gzip_large() {
         let data = srv_data.clone();
         h1::H1Service::new(
             App::new()
-                .wrap(middleware::Compress::new(ContentEncoding::Gzip))
+                .wrap(encoding::Compress::new(ContentEncoding::Gzip))
                 .service(
                     web::resource("/")
                         .route(web::to(move || Response::Ok().body(data.clone()))),
@@ -120,7 +146,7 @@ fn test_body_gzip_large_random() {
         let data = srv_data.clone();
         h1::H1Service::new(
             App::new()
-                .wrap(middleware::Compress::new(ContentEncoding::Gzip))
+                .wrap(encoding::Compress::new(ContentEncoding::Gzip))
                 .service(
                     web::resource("/")
                         .route(web::to(move || Response::Ok().body(data.clone()))),
@@ -147,7 +173,7 @@ fn test_body_chunked_implicit() {
     let mut srv = TestServer::new(move || {
         h1::H1Service::new(
             App::new()
-                .wrap(middleware::Compress::new(ContentEncoding::Gzip))
+                .wrap(encoding::Compress::new(ContentEncoding::Gzip))
                 .service(web::resource("/").route(web::get().to(move || {
                     Response::Ok().streaming(once(Ok::<_, Error>(Bytes::from_static(
                         STR.as_ref(),
@@ -178,7 +204,7 @@ fn test_body_br_streaming() {
     let mut srv = TestServer::new(move || {
         h1::H1Service::new(
             App::new()
-                .wrap(middleware::Compress::new(ContentEncoding::Br))
+                .wrap(encoding::Compress::new(ContentEncoding::Br))
                 .service(web::resource("/").route(web::to(move || {
                     Response::Ok().streaming(once(Ok::<_, Error>(Bytes::from_static(
                         STR.as_ref(),
@@ -255,7 +281,7 @@ fn test_body_deflate() {
     let mut srv = TestServer::new(move || {
         h1::H1Service::new(
             App::new()
-                .wrap(middleware::Compress::new(ContentEncoding::Deflate))
+                .wrap(encoding::Compress::new(ContentEncoding::Deflate))
                 .service(
                     web::resource("/").route(web::to(move || Response::Ok().body(STR))),
                 ),
@@ -281,7 +307,7 @@ fn test_body_brotli() {
     let mut srv = TestServer::new(move || {
         h1::H1Service::new(
             App::new()
-                .wrap(middleware::Compress::new(ContentEncoding::Br))
+                .wrap(encoding::Compress::new(ContentEncoding::Br))
                 .service(
                     web::resource("/").route(web::to(move || Response::Ok().body(STR))),
                 ),
@@ -313,7 +339,7 @@ fn test_body_brotli() {
 fn test_gzip_encoding() {
     let mut srv = TestServer::new(move || {
         HttpService::new(
-            App::new().chain(middleware::Decompress::new()).service(
+            App::new().chain(encoding::Decompress::new()).service(
                 web::resource("/")
                     .route(web::to(move |body: Bytes| Response::Ok().body(body))),
             ),
@@ -342,7 +368,7 @@ fn test_gzip_encoding_large() {
     let data = STR.repeat(10);
     let mut srv = TestServer::new(move || {
         h1::H1Service::new(
-            App::new().chain(middleware::Decompress::new()).service(
+            App::new().chain(encoding::Decompress::new()).service(
                 web::resource("/")
                     .route(web::to(move |body: Bytes| Response::Ok().body(body))),
             ),
@@ -375,7 +401,7 @@ fn test_reading_gzip_encoding_large_random() {
 
     let mut srv = TestServer::new(move || {
         HttpService::new(
-            App::new().chain(middleware::Decompress::new()).service(
+            App::new().chain(encoding::Decompress::new()).service(
                 web::resource("/")
                     .route(web::to(move |body: Bytes| Response::Ok().body(body))),
             ),
@@ -404,7 +430,7 @@ fn test_reading_gzip_encoding_large_random() {
 fn test_reading_deflate_encoding() {
     let mut srv = TestServer::new(move || {
         h1::H1Service::new(
-            App::new().chain(middleware::Decompress::new()).service(
+            App::new().chain(encoding::Decompress::new()).service(
                 web::resource("/")
                     .route(web::to(move |body: Bytes| Response::Ok().body(body))),
             ),
@@ -433,7 +459,7 @@ fn test_reading_deflate_encoding_large() {
     let data = STR.repeat(10);
     let mut srv = TestServer::new(move || {
         h1::H1Service::new(
-            App::new().chain(middleware::Decompress::new()).service(
+            App::new().chain(encoding::Decompress::new()).service(
                 web::resource("/")
                     .route(web::to(move |body: Bytes| Response::Ok().body(body))),
             ),
@@ -466,7 +492,7 @@ fn test_reading_deflate_encoding_large_random() {
 
     let mut srv = TestServer::new(move || {
         h1::H1Service::new(
-            App::new().chain(middleware::Decompress::new()).service(
+            App::new().chain(encoding::Decompress::new()).service(
                 web::resource("/")
                     .route(web::to(move |body: Bytes| Response::Ok().body(body))),
             ),
@@ -496,7 +522,7 @@ fn test_reading_deflate_encoding_large_random() {
 fn test_brotli_encoding() {
     let mut srv = TestServer::new(move || {
         h1::H1Service::new(
-            App::new().chain(middleware::Decompress::new()).service(
+            App::new().chain(encoding::Decompress::new()).service(
                 web::resource("/")
                     .route(web::to(move |body: Bytes| Response::Ok().body(body))),
             ),
@@ -526,7 +552,7 @@ fn test_brotli_encoding_large() {
     let data = STR.repeat(10);
     let mut srv = TestServer::new(move || {
         h1::H1Service::new(
-            App::new().chain(middleware::Decompress::new()).service(
+            App::new().chain(encoding::Decompress::new()).service(
                 web::resource("/")
                     .route(web::to(move |body: Bytes| Response::Ok().body(body))),
             ),