From a7fdac1043a0a13985e46a5935c9eebd2834e4f4 Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Mon, 8 Apr 2019 10:31:29 -0700
Subject: [PATCH] fix expect service registration and tests

---
 actix-http/src/builder.rs       | 25 +++++++++++++--
 actix-http/tests/test_server.rs | 56 ++++++++++++++++++++++++++++++++-
 2 files changed, 78 insertions(+), 3 deletions(-)

diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs
index 2a8a8360..6d93c156 100644
--- a/actix-http/src/builder.rs
+++ b/actix-http/src/builder.rs
@@ -87,6 +87,27 @@ where
         self
     }
 
+    /// Provide service for `EXPECT: 100-Continue` support.
+    ///
+    /// Service get called with request that contains `EXPECT` header.
+    /// Service must return request in case of success, in that case
+    /// request will be forwarded to main service.
+    pub fn expect<F, U>(self, expect: F) -> HttpServiceBuilder<T, S, U>
+    where
+        F: IntoNewService<U>,
+        U: NewService<Request = Request, Response = Request>,
+        U::Error: Into<Error>,
+        U::InitError: fmt::Debug,
+    {
+        HttpServiceBuilder {
+            keep_alive: self.keep_alive,
+            client_timeout: self.client_timeout,
+            client_disconnect: self.client_disconnect,
+            expect: expect.into_new_service(),
+            _t: PhantomData,
+        }
+    }
+
     // #[cfg(feature = "ssl")]
     // /// Configure alpn protocols for SslAcceptorBuilder.
     // pub fn configure_openssl(
@@ -142,7 +163,7 @@ where
     }
 
     /// Finish service configuration and create `HttpService` instance.
-    pub fn finish<F, P, B>(self, service: F) -> HttpService<T, P, S, B>
+    pub fn finish<F, P, B>(self, service: F) -> HttpService<T, P, S, B, X>
     where
         B: MessageBody + 'static,
         F: IntoNewService<S, SrvConfig>,
@@ -156,6 +177,6 @@ where
             self.client_timeout,
             self.client_disconnect,
         );
-        HttpService::with_config(cfg, service.into_new_service())
+        HttpService::with_config(cfg, service.into_new_service()).expect(self.expect)
     }
 }
diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs
index da41492f..e7b53937 100644
--- a/actix-http/tests/test_server.rs
+++ b/actix-http/tests/test_server.rs
@@ -5,7 +5,7 @@ use std::{net, thread};
 use actix_codec::{AsyncRead, AsyncWrite};
 use actix_http_test::TestServer;
 use actix_server_config::ServerConfig;
-use actix_service::{fn_cfg_factory, NewService};
+use actix_service::{fn_cfg_factory, fn_service, NewService};
 use bytes::{Bytes, BytesMut};
 use futures::future::{self, ok, Future};
 use futures::stream::{once, Stream};
@@ -153,6 +153,60 @@ fn test_h2_body() -> std::io::Result<()> {
     Ok(())
 }
 
+#[test]
+fn test_expect_continue() {
+    let srv = TestServer::new(|| {
+        HttpService::build()
+            .expect(fn_service(|req: Request| {
+                if req.head().uri.query() == Some("yes=") {
+                    Ok(req)
+                } else {
+                    Err(error::ErrorPreconditionFailed("error"))
+                }
+            }))
+            .finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
+    });
+
+    let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
+    let _ = stream.write_all(b"GET /test HTTP/1.1\r\nexpect: 100-continue\r\n\r\n");
+    let mut data = String::new();
+    let _ = stream.read_to_string(&mut data);
+    assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length"));
+
+    let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
+    let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n");
+    let mut data = String::new();
+    let _ = stream.read_to_string(&mut data);
+    assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n"));
+}
+
+#[test]
+fn test_expect_continue_h1() {
+    let srv = TestServer::new(|| {
+        HttpService::build()
+            .expect(fn_service(|req: Request| {
+                if req.head().uri.query() == Some("yes=") {
+                    Ok(req)
+                } else {
+                    Err(error::ErrorPreconditionFailed("error"))
+                }
+            }))
+            .h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
+    });
+
+    let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
+    let _ = stream.write_all(b"GET /test HTTP/1.1\r\nexpect: 100-continue\r\n\r\n");
+    let mut data = String::new();
+    let _ = stream.read_to_string(&mut data);
+    assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length"));
+
+    let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
+    let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n");
+    let mut data = String::new();
+    let _ = stream.read_to_string(&mut data);
+    assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n"));
+}
+
 #[test]
 fn test_slow_request() {
     let srv = TestServer::new(|| {