From 2cf27863cb99a450d35a460c73a70919dfd29470 Mon Sep 17 00:00:00 2001
From: Rob Ede <robjtede@icloud.com>
Date: Fri, 17 Dec 2021 14:13:54 +0000
Subject: [PATCH] remove direct dep on pin-project in -http (#2524)

---
 actix-http/Cargo.toml           |   3 +-
 actix-http/src/h1/dispatcher.rs | 278 ++++++++++++++++++--------------
 2 files changed, 158 insertions(+), 123 deletions(-)

diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml
index 2e8ec1df..515574ab 100644
--- a/actix-http/Cargo.toml
+++ b/actix-http/Cargo.toml
@@ -45,7 +45,7 @@ __compress = []
 actix-service = "2.0.0"
 actix-codec = "0.4.1"
 actix-utils = "3.0.0"
-actix-rt = "2.2"
+actix-rt = { version = "2.2", default-features = false }
 
 ahash = "0.7"
 base64 = "0.13"
@@ -66,7 +66,6 @@ local-channel = "0.1"
 log = "0.4"
 mime = "0.3"
 percent-encoding = "2.1"
-pin-project = "1.0.0"
 pin-project-lite = "0.2"
 rand = "0.8"
 sha-1 = "0.9"
diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs
index 16d7c3c1..472845e6 100644
--- a/actix-http/src/h1/dispatcher.rs
+++ b/actix-http/src/h1/dispatcher.rs
@@ -15,7 +15,7 @@ use bitflags::bitflags;
 use bytes::{Buf, BytesMut};
 use futures_core::ready;
 use log::{error, trace};
-use pin_project::pin_project;
+use pin_project_lite::pin_project;
 
 use crate::{
     body::{BodySize, BoxBody, MessageBody},
@@ -46,79 +46,111 @@ bitflags! {
     }
 }
 
-#[pin_project]
-/// Dispatcher for HTTP/1.1 protocol
-pub struct Dispatcher<T, S, B, X, U>
-where
-    S: Service<Request>,
-    S::Error: Into<Response<BoxBody>>,
+// there's 2 versions of Dispatcher state because of:
+// https://github.com/taiki-e/pin-project-lite/issues/3
+//
+// tl;dr: pin-project-lite doesn't play well with other attribute macros
 
-    B: MessageBody,
+#[cfg(not(test))]
+pin_project! {
+    /// Dispatcher for HTTP/1.1 protocol
+    pub struct Dispatcher<T, S, B, X, U>
+    where
+        S: Service<Request>,
+        S::Error: Into<Response<BoxBody>>,
 
-    X: Service<Request, Response = Request>,
-    X::Error: Into<Response<BoxBody>>,
+        B: MessageBody,
 
-    U: Service<(Request, Framed<T, Codec>), Response = ()>,
-    U::Error: fmt::Display,
-{
-    #[pin]
-    inner: DispatcherState<T, S, B, X, U>,
+        X: Service<Request, Response = Request>,
+        X::Error: Into<Response<BoxBody>>,
 
-    #[cfg(test)]
-    poll_count: u64,
+        U: Service<(Request, Framed<T, Codec>), Response = ()>,
+        U::Error: fmt::Display,
+    {
+        #[pin]
+        inner: DispatcherState<T, S, B, X, U>,
+    }
 }
 
-#[pin_project(project = DispatcherStateProj)]
-enum DispatcherState<T, S, B, X, U>
-where
-    S: Service<Request>,
-    S::Error: Into<Response<BoxBody>>,
+#[cfg(test)]
+pin_project! {
+    /// Dispatcher for HTTP/1.1 protocol
+    pub struct Dispatcher<T, S, B, X, U>
+    where
+        S: Service<Request>,
+        S::Error: Into<Response<BoxBody>>,
 
-    B: MessageBody,
+        B: MessageBody,
 
-    X: Service<Request, Response = Request>,
-    X::Error: Into<Response<BoxBody>>,
+        X: Service<Request, Response = Request>,
+        X::Error: Into<Response<BoxBody>>,
 
-    U: Service<(Request, Framed<T, Codec>), Response = ()>,
-    U::Error: fmt::Display,
-{
-    Normal(#[pin] InnerDispatcher<T, S, B, X, U>),
-    Upgrade(#[pin] U::Future),
+        U: Service<(Request, Framed<T, Codec>), Response = ()>,
+        U::Error: fmt::Display,
+    {
+        #[pin]
+        inner: DispatcherState<T, S, B, X, U>,
+
+        // used in tests
+        poll_count: u64,
+    }
 }
 
-#[pin_project(project = InnerDispatcherProj)]
-struct InnerDispatcher<T, S, B, X, U>
-where
-    S: Service<Request>,
-    S::Error: Into<Response<BoxBody>>,
+pin_project! {
+    #[project = DispatcherStateProj]
+    enum DispatcherState<T, S, B, X, U>
+    where
+        S: Service<Request>,
+        S::Error: Into<Response<BoxBody>>,
 
-    B: MessageBody,
+        B: MessageBody,
 
-    X: Service<Request, Response = Request>,
-    X::Error: Into<Response<BoxBody>>,
+        X: Service<Request, Response = Request>,
+        X::Error: Into<Response<BoxBody>>,
 
-    U: Service<(Request, Framed<T, Codec>), Response = ()>,
-    U::Error: fmt::Display,
-{
-    flow: Rc<HttpFlow<S, X, U>>,
-    flags: Flags,
-    peer_addr: Option<net::SocketAddr>,
-    conn_data: Option<Rc<Extensions>>,
-    error: Option<DispatchError>,
+        U: Service<(Request, Framed<T, Codec>), Response = ()>,
+        U::Error: fmt::Display,
+    {
+        Normal { #[pin] inner: InnerDispatcher<T, S, B, X, U> },
+        Upgrade { #[pin] fut: U::Future },
+    }
+}
 
-    #[pin]
-    state: State<S, B, X>,
-    payload: Option<PayloadSender>,
-    messages: VecDeque<DispatcherMessage>,
+pin_project! {
+    #[project = InnerDispatcherProj]
+    struct InnerDispatcher<T, S, B, X, U>
+    where
+        S: Service<Request>,
+        S::Error: Into<Response<BoxBody>>,
 
-    ka_expire: Instant,
-    #[pin]
-    ka_timer: Option<Sleep>,
+        B: MessageBody,
 
-    io: Option<T>,
-    read_buf: BytesMut,
-    write_buf: BytesMut,
-    codec: Codec,
+        X: Service<Request, Response = Request>,
+        X::Error: Into<Response<BoxBody>>,
+
+        U: Service<(Request, Framed<T, Codec>), Response = ()>,
+        U::Error: fmt::Display,
+    {
+        flow: Rc<HttpFlow<S, X, U>>,
+        flags: Flags,
+        peer_addr: Option<net::SocketAddr>,
+        conn_data: Option<Rc<Extensions>>,
+        error: Option<DispatchError>,
+
+        #[pin]
+        state: State<S, B, X>,
+        payload: Option<PayloadSender>,
+        messages: VecDeque<DispatcherMessage>,
+
+        ka_expire: Instant,
+        #[pin]
+        ka_timer: Option<Sleep>,
+
+        io: Option<T>,
+        read_buf: BytesMut,
+        write_buf: BytesMut,
+        codec: Codec,
+    }
 }
 
 enum DispatcherMessage {
@@ -127,19 +159,21 @@ enum DispatcherMessage {
     Error(Response<()>),
 }
 
-#[pin_project(project = StateProj)]
-enum State<S, B, X>
-where
-    S: Service<Request>,
-    X: Service<Request, Response = Request>,
+pin_project! {
+    #[project = StateProj]
+    enum State<S, B, X>
+    where
+        S: Service<Request>,
+        X: Service<Request, Response = Request>,
 
-    B: MessageBody,
-{
-    None,
-    ExpectCall(#[pin] X::Future),
-    ServiceCall(#[pin] S::Future),
-    SendPayload(#[pin] B),
-    SendErrorPayload(#[pin] BoxBody),
+        B: MessageBody,
+    {
+        None,
+        ExpectCall { #[pin] fut: X::Future },
+        ServiceCall { #[pin] fut: S::Future },
+        SendPayload { #[pin] body: B },
+        SendErrorPayload { #[pin] body: BoxBody },
+    }
 }
 
 impl<S, B, X> State<S, B, X>
@@ -198,25 +232,27 @@ where
         };
 
         Dispatcher {
-            inner: DispatcherState::Normal(InnerDispatcher {
-                flow,
-                flags,
-                peer_addr,
-                conn_data: conn_data.0.map(Rc::new),
-                error: None,
+            inner: DispatcherState::Normal {
+                inner: InnerDispatcher {
+                    flow,
+                    flags,
+                    peer_addr,
+                    conn_data: conn_data.0.map(Rc::new),
+                    error: None,
 
-                state: State::None,
-                payload: None,
-                messages: VecDeque::new(),
+                    state: State::None,
+                    payload: None,
+                    messages: VecDeque::new(),
 
-                ka_expire,
-                ka_timer,
+                    ka_expire,
+                    ka_timer,
 
-                io: Some(io),
-                read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
-                write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
-                codec: Codec::new(config),
-            }),
+                    io: Some(io),
+                    read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
+                    write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
+                    codec: Codec::new(config),
+                },
+            },
 
             #[cfg(test)]
             poll_count: 0,
@@ -316,7 +352,7 @@ where
         let size = self.as_mut().send_response_inner(message, &body)?;
         let state = match size {
             BodySize::None | BodySize::Sized(0) => State::None,
-            _ => State::SendPayload(body),
+            _ => State::SendPayload { body },
         };
         self.project().state.set(state);
         Ok(())
@@ -330,7 +366,7 @@ where
         let size = self.as_mut().send_response_inner(message, &body)?;
         let state = match size {
             BodySize::None | BodySize::Sized(0) => State::None,
-            _ => State::SendErrorPayload(body),
+            _ => State::SendErrorPayload { body },
         };
         self.project().state.set(state);
         Ok(())
@@ -356,12 +392,12 @@ where
                         // Handle `EXPECT: 100-Continue` header
                         if req.head().expect() {
                             // set InnerDispatcher state and continue loop to poll it.
-                            let task = this.flow.expect.call(req);
-                            this.state.set(State::ExpectCall(task));
+                            let fut = this.flow.expect.call(req);
+                            this.state.set(State::ExpectCall { fut });
                         } else {
                             // the same as expect call.
-                            let task = this.flow.service.call(req);
-                            this.state.set(State::ServiceCall(task));
+                            let fut = this.flow.service.call(req);
+                            this.state.set(State::ServiceCall { fut });
                         };
                     }
 
@@ -381,7 +417,7 @@ where
                     // all messages are dealt with.
                     None => return Ok(PollResponse::DoNothing),
                 },
-                StateProj::ServiceCall(fut) => match fut.poll(cx) {
+                StateProj::ServiceCall { fut } => match fut.poll(cx) {
                     // service call resolved. send response.
                     Poll::Ready(Ok(res)) => {
                         let (res, body) = res.into().replace_body(());
@@ -407,11 +443,11 @@ where
                     }
                 },
 
-                StateProj::SendPayload(mut stream) => {
+                StateProj::SendPayload { mut body } => {
                     // keep populate writer buffer until buffer size limit hit,
                     // get blocked or finished.
                     while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
-                        match stream.as_mut().poll_next(cx) {
+                        match body.as_mut().poll_next(cx) {
                             Poll::Ready(Some(Ok(item))) => {
                                 this.codec
                                     .encode(Message::Chunk(Some(item)), this.write_buf)?;
@@ -437,13 +473,13 @@ where
                     return Ok(PollResponse::DrainWriteBuf);
                 }
 
-                StateProj::SendErrorPayload(mut stream) => {
+                StateProj::SendErrorPayload { mut body } => {
                     // TODO: de-dupe impl with SendPayload
 
                     // keep populate writer buffer until buffer size limit hit,
                     // get blocked or finished.
                     while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
-                        match stream.as_mut().poll_next(cx) {
+                        match body.as_mut().poll_next(cx) {
                             Poll::Ready(Some(Ok(item))) => {
                                 this.codec
                                     .encode(Message::Chunk(Some(item)), this.write_buf)?;
@@ -469,14 +505,14 @@ where
                     return Ok(PollResponse::DrainWriteBuf);
                 }
 
-                StateProj::ExpectCall(fut) => match fut.poll(cx) {
+                StateProj::ExpectCall { fut } => match fut.poll(cx) {
                     // expect resolved. write continue to buffer and set InnerDispatcher state
                     // to service call.
                     Poll::Ready(Ok(req)) => {
                         this.write_buf
                             .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
                         let fut = this.flow.service.call(req);
-                        this.state.set(State::ServiceCall(fut));
+                        this.state.set(State::ServiceCall { fut });
                     }
 
                     // send expect error as response
@@ -502,25 +538,25 @@ where
         let mut this = self.as_mut().project();
         if req.head().expect() {
             // set dispatcher state so the future is pinned.
-            let task = this.flow.expect.call(req);
-            this.state.set(State::ExpectCall(task));
+            let fut = this.flow.expect.call(req);
+            this.state.set(State::ExpectCall { fut });
         } else {
             // the same as above.
-            let task = this.flow.service.call(req);
-            this.state.set(State::ServiceCall(task));
+            let fut = this.flow.service.call(req);
+            this.state.set(State::ServiceCall { fut });
         };
 
         // eagerly poll the future for once(or twice if expect is resolved immediately).
         loop {
             match self.as_mut().project().state.project() {
-                StateProj::ExpectCall(fut) => {
+                StateProj::ExpectCall { fut } => {
                     match fut.poll(cx) {
                         // expect is resolved. continue loop and poll the service call branch.
                         Poll::Ready(Ok(req)) => {
                             self.as_mut().send_continue();
                             let mut this = self.as_mut().project();
-                            let task = this.flow.service.call(req);
-                            this.state.set(State::ServiceCall(task));
+                            let fut = this.flow.service.call(req);
+                            this.state.set(State::ServiceCall { fut });
                             continue;
                         }
                         // future is pending. return Ok(()) to notify that a new state is
@@ -536,7 +572,7 @@ where
                         }
                     }
                 }
-                StateProj::ServiceCall(fut) => {
+                StateProj::ServiceCall { fut } => {
                     // return no matter the service call future's result.
                     return match fut.poll(cx) {
                         // future is resolved. send response and return a result. On success
@@ -901,7 +937,7 @@ where
         }
 
         match this.inner.project() {
-            DispatcherStateProj::Normal(mut inner) => {
+            DispatcherStateProj::Normal { mut inner } => {
                 inner.as_mut().poll_keepalive(cx)?;
 
                 if inner.flags.contains(Flags::SHUTDOWN) {
@@ -941,7 +977,7 @@ where
                                 self.as_mut()
                                     .project()
                                     .inner
-                                    .set(DispatcherState::Upgrade(upgrade));
+                                    .set(DispatcherState::Upgrade { fut: upgrade });
                                 return self.poll(cx);
                             }
                         };
@@ -993,8 +1029,8 @@ where
                     }
                 }
             }
-            DispatcherStateProj::Upgrade(fut) => fut.poll(cx).map_err(|e| {
-                error!("Upgrade handler error: {}", e);
+            DispatcherStateProj::Upgrade { fut: upgrade } => upgrade.poll(cx).map_err(|err| {
+                error!("Upgrade handler error: {}", err);
                 DispatchError::Upgrade
             }),
         }
@@ -1088,7 +1124,7 @@ mod tests {
                 Poll::Ready(res) => assert!(res.is_err()),
             }
 
-            if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() {
+            if let DispatcherStateProj::Normal { inner } = h1.project().inner.project() {
                 assert!(inner.flags.contains(Flags::READ_DISCONNECT));
                 assert_eq!(
                     &inner.project().io.take().unwrap().write_buf[..26],
@@ -1123,7 +1159,7 @@ mod tests {
 
             actix_rt::pin!(h1);
 
-            assert!(matches!(&h1.inner, DispatcherState::Normal(_)));
+            assert!(matches!(&h1.inner, DispatcherState::Normal { .. }));
 
             match h1.as_mut().poll(cx) {
                 Poll::Pending => panic!("first poll should not be pending"),
@@ -1133,7 +1169,7 @@ mod tests {
             // polls: initial => shutdown
             assert_eq!(h1.poll_count, 2);
 
-            if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() {
+            if let DispatcherStateProj::Normal { inner } = h1.project().inner.project() {
                 let res = &mut inner.project().io.take().unwrap().write_buf[..];
                 stabilize_date_header(res);
 
@@ -1177,7 +1213,7 @@ mod tests {
 
             actix_rt::pin!(h1);
 
-            assert!(matches!(&h1.inner, DispatcherState::Normal(_)));
+            assert!(matches!(&h1.inner, DispatcherState::Normal { .. }));
 
             match h1.as_mut().poll(cx) {
                 Poll::Pending => panic!("first poll should not be pending"),
@@ -1187,7 +1223,7 @@ mod tests {
             // polls: initial => shutdown
             assert_eq!(h1.poll_count, 1);
 
-            if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() {
+            if let DispatcherStateProj::Normal { inner } = h1.project().inner.project() {
                 let res = &mut inner.project().io.take().unwrap().write_buf[..];
                 stabilize_date_header(res);
 
@@ -1237,13 +1273,13 @@ mod tests {
             actix_rt::pin!(h1);
 
             assert!(h1.as_mut().poll(cx).is_pending());
-            assert!(matches!(&h1.inner, DispatcherState::Normal(_)));
+            assert!(matches!(&h1.inner, DispatcherState::Normal { .. }));
 
             // polls: manual
             assert_eq!(h1.poll_count, 1);
             eprintln!("poll count: {}", h1.poll_count);
 
-            if let DispatcherState::Normal(ref inner) = h1.inner {
+            if let DispatcherState::Normal { ref inner } = h1.inner {
                 let io = inner.io.as_ref().unwrap();
                 let res = &io.write_buf()[..];
                 assert_eq!(
@@ -1258,7 +1294,7 @@ mod tests {
             // polls: manual manual shutdown
             assert_eq!(h1.poll_count, 3);
 
-            if let DispatcherState::Normal(ref inner) = h1.inner {
+            if let DispatcherState::Normal { ref inner } = h1.inner {
                 let io = inner.io.as_ref().unwrap();
                 let mut res = (&io.write_buf()[..]).to_owned();
                 stabilize_date_header(&mut res);
@@ -1309,12 +1345,12 @@ mod tests {
             actix_rt::pin!(h1);
 
             assert!(h1.as_mut().poll(cx).is_ready());
-            assert!(matches!(&h1.inner, DispatcherState::Normal(_)));
+            assert!(matches!(&h1.inner, DispatcherState::Normal { .. }));
 
             // polls: manual shutdown
             assert_eq!(h1.poll_count, 2);
 
-            if let DispatcherState::Normal(ref inner) = h1.inner {
+            if let DispatcherState::Normal { ref inner } = h1.inner {
                 let io = inner.io.as_ref().unwrap();
                 let mut res = (&io.write_buf()[..]).to_owned();
                 stabilize_date_header(&mut res);
@@ -1386,7 +1422,7 @@ mod tests {
             actix_rt::pin!(h1);
 
             assert!(h1.as_mut().poll(cx).is_ready());
-            assert!(matches!(&h1.inner, DispatcherState::Upgrade(_)));
+            assert!(matches!(&h1.inner, DispatcherState::Upgrade { .. }));
 
             // polls: manual shutdown
             assert_eq!(h1.poll_count, 2);