mirror of https://github.com/fafhrd91/actix-web
				
				
				
			unlink MessageBody from Unpin
This commit is contained in:
		
							parent
							
								
									2e2ea7ab80
								
							
						
					
					
						commit
						ec5c779732
					
				|  | @ -33,7 +33,7 @@ impl BodySize { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// Type that provides this trait can be streamed to a peer.
 | /// Type that provides this trait can be streamed to a peer.
 | ||||||
| pub trait MessageBody: Unpin { | pub trait MessageBody { | ||||||
|     fn size(&self) -> BodySize; |     fn size(&self) -> BodySize; | ||||||
| 
 | 
 | ||||||
|     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>>; |     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>>; | ||||||
|  | @ -53,14 +53,13 @@ impl MessageBody for () { | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl<T: MessageBody> MessageBody for Box<T> { | impl<T: MessageBody + Unpin> MessageBody for Box<T> { | ||||||
|     fn size(&self) -> BodySize { |     fn size(&self) -> BodySize { | ||||||
|         self.as_ref().size() |         self.as_ref().size() | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { |     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { | ||||||
|         let a: Pin<&mut T> = Pin::new(self.get_mut().as_mut()); |         unsafe { self.map_unchecked_mut(|boxed| boxed.as_mut()) }.poll_next(cx) | ||||||
|         a.poll_next(cx) |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -70,8 +69,7 @@ impl MessageBody for Box<dyn MessageBody> { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { |     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { | ||||||
|         let a: Pin<&mut dyn MessageBody> = Pin::new(self.get_mut().as_mut()); |         unsafe { Pin::new_unchecked(self.get_mut().as_mut()) }.poll_next(cx) | ||||||
|         a.poll_next(cx) |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -10,6 +10,7 @@ use actix_service::Service; | ||||||
| use bitflags::bitflags; | use bitflags::bitflags; | ||||||
| use bytes::{Buf, BytesMut}; | use bytes::{Buf, BytesMut}; | ||||||
| use log::{error, trace}; | use log::{error, trace}; | ||||||
|  | use pin_project::pin_project; | ||||||
| 
 | 
 | ||||||
| use crate::body::{Body, BodySize, MessageBody, ResponseBody}; | use crate::body::{Body, BodySize, MessageBody, ResponseBody}; | ||||||
| use crate::cloneable::CloneableService; | use crate::cloneable::CloneableService; | ||||||
|  | @ -41,6 +42,7 @@ bitflags! { | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | #[pin_project::pin_project] | ||||||
| /// Dispatcher for HTTP/1.1 protocol
 | /// Dispatcher for HTTP/1.1 protocol
 | ||||||
| pub struct Dispatcher<T, S, B, X, U> | pub struct Dispatcher<T, S, B, X, U> | ||||||
| where | where | ||||||
|  | @ -52,9 +54,11 @@ where | ||||||
|     U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, |     U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, | ||||||
|     U::Error: fmt::Display, |     U::Error: fmt::Display, | ||||||
| { | { | ||||||
|  |     #[pin] | ||||||
|     inner: DispatcherState<T, S, B, X, U>, |     inner: DispatcherState<T, S, B, X, U>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | #[pin_project] | ||||||
| enum DispatcherState<T, S, B, X, U> | enum DispatcherState<T, S, B, X, U> | ||||||
| where | where | ||||||
|     S: Service<Request = Request>, |     S: Service<Request = Request>, | ||||||
|  | @ -65,11 +69,12 @@ where | ||||||
|     U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, |     U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, | ||||||
|     U::Error: fmt::Display, |     U::Error: fmt::Display, | ||||||
| { | { | ||||||
|     Normal(InnerDispatcher<T, S, B, X, U>), |     Normal(#[pin] InnerDispatcher<T, S, B, X, U>), | ||||||
|     Upgrade(Pin<Box<U::Future>>), |     Upgrade(#[pin] U::Future), | ||||||
|     None, |     None, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | #[pin_project] | ||||||
| struct InnerDispatcher<T, S, B, X, U> | struct InnerDispatcher<T, S, B, X, U> | ||||||
| where | where | ||||||
|     S: Service<Request = Request>, |     S: Service<Request = Request>, | ||||||
|  | @ -88,6 +93,7 @@ where | ||||||
|     peer_addr: Option<net::SocketAddr>, |     peer_addr: Option<net::SocketAddr>, | ||||||
|     error: Option<DispatchError>, |     error: Option<DispatchError>, | ||||||
| 
 | 
 | ||||||
|  |     #[pin] | ||||||
|     state: State<S, B, X>, |     state: State<S, B, X>, | ||||||
|     payload: Option<PayloadSender>, |     payload: Option<PayloadSender>, | ||||||
|     messages: VecDeque<DispatcherMessage>, |     messages: VecDeque<DispatcherMessage>, | ||||||
|  | @ -107,6 +113,7 @@ enum DispatcherMessage { | ||||||
|     Error(Response<()>), |     Error(Response<()>), | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | #[pin_project] | ||||||
| enum State<S, B, X> | enum State<S, B, X> | ||||||
| where | where | ||||||
|     S: Service<Request = Request>, |     S: Service<Request = Request>, | ||||||
|  | @ -114,9 +121,9 @@ where | ||||||
|     B: MessageBody, |     B: MessageBody, | ||||||
| { | { | ||||||
|     None, |     None, | ||||||
|     ExpectCall(Pin<Box<X::Future>>), |     ExpectCall(#[pin] X::Future), | ||||||
|     ServiceCall(Pin<Box<S::Future>>), |     ServiceCall(#[pin] S::Future), | ||||||
|     SendPayload(ResponseBody<B>), |     SendPayload(#[pin] ResponseBody<B>), | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl<S, B, X> State<S, B, X> | impl<S, B, X> State<S, B, X> | ||||||
|  | @ -142,6 +149,21 @@ where | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | impl<T, S, B, X, U> DispatcherState<T, S, B, X, U> | ||||||
|  | where | ||||||
|  |     S: Service<Request = Request>, | ||||||
|  |     S::Error: Into<Error>, | ||||||
|  |     B: MessageBody, | ||||||
|  |     X: Service<Request = Request, Response = Request>, | ||||||
|  |     X::Error: Into<Error>, | ||||||
|  |     U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, | ||||||
|  |     U::Error: fmt::Display, | ||||||
|  | { | ||||||
|  |     fn take(self: Pin<&mut Self>) -> Self { | ||||||
|  |         std::mem::replace(unsafe { self.get_unchecked_mut() }, Self::None) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| enum PollResponse { | enum PollResponse { | ||||||
|     Upgrade(Request), |     Upgrade(Request), | ||||||
|     DoNothing, |     DoNothing, | ||||||
|  | @ -278,10 +300,11 @@ where | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // if checked is set to true, delay disconnect until all tasks have finished.
 |     // if checked is set to true, delay disconnect until all tasks have finished.
 | ||||||
|     fn client_disconnected(&mut self) { |     fn client_disconnected(self: Pin<&mut Self>) { | ||||||
|         self.flags |         let this = self.project(); | ||||||
|  |         this.flags | ||||||
|             .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT); |             .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT); | ||||||
|         if let Some(mut payload) = self.payload.take() { |         if let Some(mut payload) = this.payload.take() { | ||||||
|             payload.set_error(PayloadError::Incomplete(None)); |             payload.set_error(PayloadError::Incomplete(None)); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | @ -290,16 +313,18 @@ where | ||||||
|     ///
 |     ///
 | ||||||
|     /// true - got whouldblock
 |     /// true - got whouldblock
 | ||||||
|     /// false - didnt get whouldblock
 |     /// false - didnt get whouldblock
 | ||||||
|     fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<bool, DispatchError> { |     #[pin_project::project] | ||||||
|  |     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> { | ||||||
|         if self.write_buf.is_empty() { |         if self.write_buf.is_empty() { | ||||||
|             return Ok(false); |             return Ok(false); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         let len = self.write_buf.len(); |         let len = self.write_buf.len(); | ||||||
|         let mut written = 0; |         let mut written = 0; | ||||||
|  |         #[project] | ||||||
|  |         let InnerDispatcher { mut io, write_buf, .. } = self.project(); | ||||||
|         while written < len { |         while written < len { | ||||||
|             match Pin::new(&mut self.io) |             match Pin::new(&mut io).poll_write(cx, &write_buf[written..]) | ||||||
|                 .poll_write(cx, &self.write_buf[written..]) |  | ||||||
|             { |             { | ||||||
|                 Poll::Ready(Ok(0)) => { |                 Poll::Ready(Ok(0)) => { | ||||||
|                     return Err(DispatchError::Io(io::Error::new( |                     return Err(DispatchError::Io(io::Error::new( | ||||||
|  | @ -312,113 +337,120 @@ where | ||||||
|                 } |                 } | ||||||
|                 Poll::Pending => { |                 Poll::Pending => { | ||||||
|                     if written > 0 { |                     if written > 0 { | ||||||
|                         self.write_buf.advance(written); |                         write_buf.advance(written); | ||||||
|                     } |                     } | ||||||
|                     return Ok(true); |                     return Ok(true); | ||||||
|                 } |                 } | ||||||
|                 Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), |                 Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         if written == self.write_buf.len() { |         if written == write_buf.len() { | ||||||
|             unsafe { self.write_buf.set_len(0) } |             unsafe { write_buf.set_len(0) } | ||||||
|         } else { |         } else { | ||||||
|             self.write_buf.advance(written); |             write_buf.advance(written); | ||||||
|         } |         } | ||||||
|         Ok(false) |         Ok(false) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn send_response( |     fn send_response( | ||||||
|         &mut self, |         self: Pin<&mut Self>, | ||||||
|         message: Response<()>, |         message: Response<()>, | ||||||
|         body: ResponseBody<B>, |         body: ResponseBody<B>, | ||||||
|     ) -> Result<State<S, B, X>, DispatchError> { |     ) -> Result<State<S, B, X>, DispatchError> { | ||||||
|         self.codec |         let mut this = self.project(); | ||||||
|             .encode(Message::Item((message, body.size())), &mut self.write_buf) |         this.codec | ||||||
|  |             .encode(Message::Item((message, body.size())), &mut this.write_buf) | ||||||
|             .map_err(|err| { |             .map_err(|err| { | ||||||
|                 if let Some(mut payload) = self.payload.take() { |                 if let Some(mut payload) = this.payload.take() { | ||||||
|                     payload.set_error(PayloadError::Incomplete(None)); |                     payload.set_error(PayloadError::Incomplete(None)); | ||||||
|                 } |                 } | ||||||
|                 DispatchError::Io(err) |                 DispatchError::Io(err) | ||||||
|             })?; |             })?; | ||||||
| 
 | 
 | ||||||
|         self.flags.set(Flags::KEEPALIVE, self.codec.keepalive()); |         this.flags.set(Flags::KEEPALIVE, this.codec.keepalive()); | ||||||
|         match body.size() { |         match body.size() { | ||||||
|             BodySize::None | BodySize::Empty => Ok(State::None), |             BodySize::None | BodySize::Empty => Ok(State::None), | ||||||
|             _ => Ok(State::SendPayload(body)), |             _ => Ok(State::SendPayload(body)), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn send_continue(&mut self) { |     fn send_continue(self: Pin<&mut Self>) { | ||||||
|         self.write_buf |         self.project().write_buf | ||||||
|             .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); |             .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     #[pin_project::project] | ||||||
|     fn poll_response( |     fn poll_response( | ||||||
|         &mut self, |         mut self: Pin<&mut Self>, | ||||||
|         cx: &mut Context<'_>, |         cx: &mut Context<'_>, | ||||||
|     ) -> Result<PollResponse, DispatchError> { |     ) -> Result<PollResponse, DispatchError> { | ||||||
|         loop { |         loop { | ||||||
|             let state = match self.state { |             let mut this = self.as_mut().project(); | ||||||
|                 State::None => match self.messages.pop_front() { |             #[project] | ||||||
|  |             let state = match this.state.project() { | ||||||
|  |                 State::None => match this.messages.pop_front() { | ||||||
|                     Some(DispatcherMessage::Item(req)) => { |                     Some(DispatcherMessage::Item(req)) => { | ||||||
|                         Some(self.handle_request(req, cx)?) |                         Some(self.as_mut().handle_request(req, cx)?) | ||||||
|                     } |                     } | ||||||
|                     Some(DispatcherMessage::Error(res)) => { |                     Some(DispatcherMessage::Error(res)) => { | ||||||
|                         Some(self.send_response(res, ResponseBody::Other(Body::Empty))?) |                         Some(self.as_mut().send_response(res, ResponseBody::Other(Body::Empty))?) | ||||||
|                     } |                     } | ||||||
|                     Some(DispatcherMessage::Upgrade(req)) => { |                     Some(DispatcherMessage::Upgrade(req)) => { | ||||||
|                         return Ok(PollResponse::Upgrade(req)); |                         return Ok(PollResponse::Upgrade(req)); | ||||||
|                     } |                     } | ||||||
|                     None => None, |                     None => None, | ||||||
|                 }, |                 }, | ||||||
|                 State::ExpectCall(ref mut fut) => { |                 State::ExpectCall(fut) => { | ||||||
|                     match fut.as_mut().poll(cx) { |                     match fut.poll(cx) { | ||||||
|                         Poll::Ready(Ok(req)) => { |                         Poll::Ready(Ok(req)) => { | ||||||
|                             self.send_continue(); |                             self.as_mut().send_continue(); | ||||||
|                             self.state = State::ServiceCall(Box::pin(self.service.call(req))); |                             this = self.as_mut().project(); | ||||||
|  |                             this.state.set(State::ServiceCall(this.service.call(req))); | ||||||
|                             continue; |                             continue; | ||||||
|                         } |                         } | ||||||
|                         Poll::Ready(Err(e)) => { |                         Poll::Ready(Err(e)) => { | ||||||
|                             let res: Response = e.into().into(); |                             let res: Response = e.into().into(); | ||||||
|                             let (res, body) = res.replace_body(()); |                             let (res, body) = res.replace_body(()); | ||||||
|                             Some(self.send_response(res, body.into_body())?) |                             Some(self.as_mut().send_response(res, body.into_body())?) | ||||||
|                         } |                         } | ||||||
|                         Poll::Pending => None, |                         Poll::Pending => None, | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|                 State::ServiceCall(ref mut fut) => { |                 State::ServiceCall(fut) => { | ||||||
|                     match fut.as_mut().poll(cx) { |                     match fut.poll(cx) { | ||||||
|                         Poll::Ready(Ok(res)) => { |                         Poll::Ready(Ok(res)) => { | ||||||
|                             let (res, body) = res.into().replace_body(()); |                             let (res, body) = res.into().replace_body(()); | ||||||
|                             self.state = self.send_response(res, body)?; |                             let state = self.as_mut().send_response(res, body)?; | ||||||
|  |                             this = self.as_mut().project(); | ||||||
|  |                             this.state.set(state); | ||||||
|                             continue; |                             continue; | ||||||
|                         } |                         } | ||||||
|                         Poll::Ready(Err(e)) => { |                         Poll::Ready(Err(e)) => { | ||||||
|                             let res: Response = e.into().into(); |                             let res: Response = e.into().into(); | ||||||
|                             let (res, body) = res.replace_body(()); |                             let (res, body) = res.replace_body(()); | ||||||
|                             Some(self.send_response(res, body.into_body())?) |                             Some(self.as_mut().send_response(res, body.into_body())?) | ||||||
|                         } |                         } | ||||||
|                         Poll::Pending => None, |                         Poll::Pending => None, | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|                 State::SendPayload(ref mut stream) => { |                 State::SendPayload(mut stream) => { | ||||||
|                     let mut stream = Pin::new(stream); |  | ||||||
|                     loop { |                     loop { | ||||||
|                         if self.write_buf.len() < HW_BUFFER_SIZE { |                         if this.write_buf.len() < HW_BUFFER_SIZE { | ||||||
|                             match stream.as_mut().poll_next(cx) { |                             match stream.as_mut().poll_next(cx) { | ||||||
|                                 Poll::Ready(Some(Ok(item))) => { |                                 Poll::Ready(Some(Ok(item))) => { | ||||||
|                                     self.codec.encode( |                                     this.codec.encode( | ||||||
|                                         Message::Chunk(Some(item)), |                                         Message::Chunk(Some(item)), | ||||||
|                                         &mut self.write_buf, |                                         &mut this.write_buf, | ||||||
|                                     )?; |                                     )?; | ||||||
|                                     continue; |                                     continue; | ||||||
|                                 } |                                 } | ||||||
|                                 Poll::Ready(None) => { |                                 Poll::Ready(None) => { | ||||||
|                                     self.codec.encode( |                                     this.codec.encode( | ||||||
|                                         Message::Chunk(None), |                                         Message::Chunk(None), | ||||||
|                                         &mut self.write_buf, |                                         &mut this.write_buf, | ||||||
|                                     )?; |                                     )?; | ||||||
|                                     self.state = State::None; |                                     this = self.as_mut().project(); | ||||||
|  |                                     this.state.set(State::None); | ||||||
|                                 } |                                 } | ||||||
|                                 Poll::Ready(Some(Err(_))) => { |                                 Poll::Ready(Some(Err(_))) => { | ||||||
|                                     return Err(DispatchError::Unknown) |                                     return Err(DispatchError::Unknown) | ||||||
|  | @ -434,9 +466,11 @@ where | ||||||
|                 } |                 } | ||||||
|             }; |             }; | ||||||
| 
 | 
 | ||||||
|  |             this = self.as_mut().project(); | ||||||
|  | 
 | ||||||
|             // set new state
 |             // set new state
 | ||||||
|             if let Some(state) = state { |             if let Some(state) = state { | ||||||
|                 self.state = state; |                 this.state.set(state); | ||||||
|                 if !self.state.is_empty() { |                 if !self.state.is_empty() { | ||||||
|                     continue; |                     continue; | ||||||
|                 } |                 } | ||||||
|  | @ -444,7 +478,7 @@ where | ||||||
|                 // if read-backpressure is enabled and we consumed some data.
 |                 // if read-backpressure is enabled and we consumed some data.
 | ||||||
|                 // we may read more data and retry
 |                 // we may read more data and retry
 | ||||||
|                 if self.state.is_call() { |                 if self.state.is_call() { | ||||||
|                     if self.poll_request(cx)? { |                     if self.as_mut().poll_request(cx)? { | ||||||
|                         continue; |                         continue; | ||||||
|                     } |                     } | ||||||
|                 } else if !self.messages.is_empty() { |                 } else if !self.messages.is_empty() { | ||||||
|  | @ -458,16 +492,16 @@ where | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn handle_request( |     fn handle_request( | ||||||
|         &mut self, |         mut self: Pin<&mut Self>, | ||||||
|         req: Request, |         req: Request, | ||||||
|         cx: &mut Context<'_>, |         cx: &mut Context<'_>, | ||||||
|     ) -> Result<State<S, B, X>, DispatchError> { |     ) -> Result<State<S, B, X>, DispatchError> { | ||||||
|         // Handle `EXPECT: 100-Continue` header
 |         // Handle `EXPECT: 100-Continue` header
 | ||||||
|         let req = if req.head().expect() { |         let req = if req.head().expect() { | ||||||
|             let mut task = Box::pin(self.expect.call(req)); |             let mut task = self.as_mut().project().expect.call(req); | ||||||
|             match task.as_mut().poll(cx) { |             match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { | ||||||
|                 Poll::Ready(Ok(req)) => { |                 Poll::Ready(Ok(req)) => { | ||||||
|                     self.send_continue(); |                     self.as_mut().send_continue(); | ||||||
|                     req |                     req | ||||||
|                 } |                 } | ||||||
|                 Poll::Pending => return Ok(State::ExpectCall(task)), |                 Poll::Pending => return Ok(State::ExpectCall(task)), | ||||||
|  | @ -483,8 +517,8 @@ where | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         // Call service
 |         // Call service
 | ||||||
|         let mut task = Box::pin(self.service.call(req)); |         let mut task = self.as_mut().project().service.call(req); | ||||||
|         match task.as_mut().poll(cx) { |         match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { | ||||||
|             Poll::Ready(Ok(res)) => { |             Poll::Ready(Ok(res)) => { | ||||||
|                 let (res, body) = res.into().replace_body(()); |                 let (res, body) = res.into().replace_body(()); | ||||||
|                 self.send_response(res, body) |                 self.send_response(res, body) | ||||||
|  | @ -500,7 +534,7 @@ where | ||||||
| 
 | 
 | ||||||
|     /// Process one incoming requests
 |     /// Process one incoming requests
 | ||||||
|     pub(self) fn poll_request( |     pub(self) fn poll_request( | ||||||
|         &mut self, |         mut self: Pin<&mut Self>, | ||||||
|         cx: &mut Context<'_>, |         cx: &mut Context<'_>, | ||||||
|     ) -> Result<bool, DispatchError> { |     ) -> Result<bool, DispatchError> { | ||||||
|         // limit a mount of non processed requests
 |         // limit a mount of non processed requests
 | ||||||
|  | @ -509,24 +543,25 @@ where | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         let mut updated = false; |         let mut updated = false; | ||||||
|  |         let mut this = self.as_mut().project(); | ||||||
|         loop { |         loop { | ||||||
|             match self.codec.decode(&mut self.read_buf) { |             match this.codec.decode(&mut this.read_buf) { | ||||||
|                 Ok(Some(msg)) => { |                 Ok(Some(msg)) => { | ||||||
|                     updated = true; |                     updated = true; | ||||||
|                     self.flags.insert(Flags::STARTED); |                     this.flags.insert(Flags::STARTED); | ||||||
| 
 | 
 | ||||||
|                     match msg { |                     match msg { | ||||||
|                         Message::Item(mut req) => { |                         Message::Item(mut req) => { | ||||||
|                             let pl = self.codec.message_type(); |                             let pl = this.codec.message_type(); | ||||||
|                             req.head_mut().peer_addr = self.peer_addr; |                             req.head_mut().peer_addr = *this.peer_addr; | ||||||
| 
 | 
 | ||||||
|                             // set on_connect data
 |                             // set on_connect data
 | ||||||
|                             if let Some(ref on_connect) = self.on_connect { |                             if let Some(ref on_connect) = this.on_connect { | ||||||
|                                 on_connect.set(&mut req.extensions_mut()); |                                 on_connect.set(&mut req.extensions_mut()); | ||||||
|                             } |                             } | ||||||
| 
 | 
 | ||||||
|                             if pl == MessageType::Stream && self.upgrade.is_some() { |                             if pl == MessageType::Stream && this.upgrade.is_some() { | ||||||
|                                 self.messages.push_back(DispatcherMessage::Upgrade(req)); |                                 this.messages.push_back(DispatcherMessage::Upgrade(req)); | ||||||
|                                 break; |                                 break; | ||||||
|                             } |                             } | ||||||
|                             if pl == MessageType::Payload || pl == MessageType::Stream { |                             if pl == MessageType::Payload || pl == MessageType::Stream { | ||||||
|  | @ -534,41 +569,43 @@ where | ||||||
|                                 let (req1, _) = |                                 let (req1, _) = | ||||||
|                                     req.replace_payload(crate::Payload::H1(pl)); |                                     req.replace_payload(crate::Payload::H1(pl)); | ||||||
|                                 req = req1; |                                 req = req1; | ||||||
|                                 self.payload = Some(ps); |                                 *this.payload = Some(ps); | ||||||
|                             } |                             } | ||||||
| 
 | 
 | ||||||
|                             // handle request early
 |                             // handle request early
 | ||||||
|                             if self.state.is_empty() { |                             if this.state.is_empty() { | ||||||
|                                 self.state = self.handle_request(req, cx)?; |                                 let state = self.as_mut().handle_request(req, cx)?; | ||||||
|  |                                 this = self.as_mut().project(); | ||||||
|  |                                 this.state.set(state); | ||||||
|                             } else { |                             } else { | ||||||
|                                 self.messages.push_back(DispatcherMessage::Item(req)); |                                 this.messages.push_back(DispatcherMessage::Item(req)); | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
|                         Message::Chunk(Some(chunk)) => { |                         Message::Chunk(Some(chunk)) => { | ||||||
|                             if let Some(ref mut payload) = self.payload { |                             if let Some(ref mut payload) = this.payload { | ||||||
|                                 payload.feed_data(chunk); |                                 payload.feed_data(chunk); | ||||||
|                             } else { |                             } else { | ||||||
|                                 error!( |                                 error!( | ||||||
|                                     "Internal server error: unexpected payload chunk" |                                     "Internal server error: unexpected payload chunk" | ||||||
|                                 ); |                                 ); | ||||||
|                                 self.flags.insert(Flags::READ_DISCONNECT); |                                 this.flags.insert(Flags::READ_DISCONNECT); | ||||||
|                                 self.messages.push_back(DispatcherMessage::Error( |                                 this.messages.push_back(DispatcherMessage::Error( | ||||||
|                                     Response::InternalServerError().finish().drop_body(), |                                     Response::InternalServerError().finish().drop_body(), | ||||||
|                                 )); |                                 )); | ||||||
|                                 self.error = Some(DispatchError::InternalError); |                                 *this.error = Some(DispatchError::InternalError); | ||||||
|                                 break; |                                 break; | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
|                         Message::Chunk(None) => { |                         Message::Chunk(None) => { | ||||||
|                             if let Some(mut payload) = self.payload.take() { |                             if let Some(mut payload) = this.payload.take() { | ||||||
|                                 payload.feed_eof(); |                                 payload.feed_eof(); | ||||||
|                             } else { |                             } else { | ||||||
|                                 error!("Internal server error: unexpected eof"); |                                 error!("Internal server error: unexpected eof"); | ||||||
|                                 self.flags.insert(Flags::READ_DISCONNECT); |                                 this.flags.insert(Flags::READ_DISCONNECT); | ||||||
|                                 self.messages.push_back(DispatcherMessage::Error( |                                 this.messages.push_back(DispatcherMessage::Error( | ||||||
|                                     Response::InternalServerError().finish().drop_body(), |                                     Response::InternalServerError().finish().drop_body(), | ||||||
|                                 )); |                                 )); | ||||||
|                                 self.error = Some(DispatchError::InternalError); |                                 *this.error = Some(DispatchError::InternalError); | ||||||
|                                 break; |                                 break; | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
|  | @ -576,44 +613,46 @@ where | ||||||
|                 } |                 } | ||||||
|                 Ok(None) => break, |                 Ok(None) => break, | ||||||
|                 Err(ParseError::Io(e)) => { |                 Err(ParseError::Io(e)) => { | ||||||
|                     self.client_disconnected(); |                     self.as_mut().client_disconnected(); | ||||||
|                     self.error = Some(DispatchError::Io(e)); |                     this = self.as_mut().project(); | ||||||
|  |                     *this.error = Some(DispatchError::Io(e)); | ||||||
|                     break; |                     break; | ||||||
|                 } |                 } | ||||||
|                 Err(e) => { |                 Err(e) => { | ||||||
|                     if let Some(mut payload) = self.payload.take() { |                     if let Some(mut payload) = this.payload.take() { | ||||||
|                         payload.set_error(PayloadError::EncodingCorrupted); |                         payload.set_error(PayloadError::EncodingCorrupted); | ||||||
|                     } |                     } | ||||||
| 
 | 
 | ||||||
|                     // Malformed requests should be responded with 400
 |                     // Malformed requests should be responded with 400
 | ||||||
|                     self.messages.push_back(DispatcherMessage::Error( |                     this.messages.push_back(DispatcherMessage::Error( | ||||||
|                         Response::BadRequest().finish().drop_body(), |                         Response::BadRequest().finish().drop_body(), | ||||||
|                     )); |                     )); | ||||||
|                     self.flags.insert(Flags::READ_DISCONNECT); |                     this.flags.insert(Flags::READ_DISCONNECT); | ||||||
|                     self.error = Some(e.into()); |                     *this.error = Some(e.into()); | ||||||
|                     break; |                     break; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         if updated && self.ka_timer.is_some() { |         if updated && this.ka_timer.is_some() { | ||||||
|             if let Some(expire) = self.codec.config().keep_alive_expire() { |             if let Some(expire) = this.codec.config().keep_alive_expire() { | ||||||
|                 self.ka_expire = expire; |                 *this.ka_expire = expire; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         Ok(updated) |         Ok(updated) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// keep-alive timer
 |     /// keep-alive timer
 | ||||||
|     fn poll_keepalive(&mut self, cx: &mut Context<'_>) -> Result<(), DispatchError> { |     fn poll_keepalive(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> { | ||||||
|         if self.ka_timer.is_none() { |         let mut this = self.as_mut().project(); | ||||||
|  |         if this.ka_timer.is_none() { | ||||||
|             // shutdown timeout
 |             // shutdown timeout
 | ||||||
|             if self.flags.contains(Flags::SHUTDOWN) { |             if this.flags.contains(Flags::SHUTDOWN) { | ||||||
|                 if let Some(interval) = self.codec.config().client_disconnect_timer() { |                 if let Some(interval) = this.codec.config().client_disconnect_timer() { | ||||||
|                     self.ka_timer = Some(delay_until(interval)); |                     *this.ka_timer = Some(delay_until(interval)); | ||||||
|                 } else { |                 } else { | ||||||
|                     self.flags.insert(Flags::READ_DISCONNECT); |                     this.flags.insert(Flags::READ_DISCONNECT); | ||||||
|                     if let Some(mut payload) = self.payload.take() { |                     if let Some(mut payload) = this.payload.take() { | ||||||
|                         payload.set_error(PayloadError::Incomplete(None)); |                         payload.set_error(PayloadError::Incomplete(None)); | ||||||
|                     } |                     } | ||||||
|                     return Ok(()); |                     return Ok(()); | ||||||
|  | @ -623,55 +662,56 @@ where | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) { |         match Pin::new(&mut this.ka_timer.as_mut().unwrap()).poll(cx) { | ||||||
|             Poll::Ready(()) => { |             Poll::Ready(()) => { | ||||||
|                 // if we get timeout during shutdown, drop connection
 |                 // if we get timeout during shutdown, drop connection
 | ||||||
|                 if self.flags.contains(Flags::SHUTDOWN) { |                 if this.flags.contains(Flags::SHUTDOWN) { | ||||||
|                     return Err(DispatchError::DisconnectTimeout); |                     return Err(DispatchError::DisconnectTimeout); | ||||||
|                 } else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire { |                 } else if this.ka_timer.as_mut().unwrap().deadline() >= *this.ka_expire { | ||||||
|                     // check for any outstanding tasks
 |                     // check for any outstanding tasks
 | ||||||
|                     if self.state.is_empty() && self.write_buf.is_empty() { |                     if this.state.is_empty() && this.write_buf.is_empty() { | ||||||
|                         if self.flags.contains(Flags::STARTED) { |                         if this.flags.contains(Flags::STARTED) { | ||||||
|                             trace!("Keep-alive timeout, close connection"); |                             trace!("Keep-alive timeout, close connection"); | ||||||
|                             self.flags.insert(Flags::SHUTDOWN); |                             this.flags.insert(Flags::SHUTDOWN); | ||||||
| 
 | 
 | ||||||
|                             // start shutdown timer
 |                             // start shutdown timer
 | ||||||
|                             if let Some(deadline) = |                             if let Some(deadline) = | ||||||
|                                 self.codec.config().client_disconnect_timer() |                                 this.codec.config().client_disconnect_timer() | ||||||
|                             { |                             { | ||||||
|                                 if let Some(mut timer) = self.ka_timer.as_mut() { |                                 if let Some(mut timer) = this.ka_timer.as_mut() { | ||||||
|                                     timer.reset(deadline); |                                     timer.reset(deadline); | ||||||
|                                     let _ = Pin::new(&mut timer).poll(cx); |                                     let _ = Pin::new(&mut timer).poll(cx); | ||||||
|                                 } |                                 } | ||||||
|                             } else { |                             } else { | ||||||
|                                 // no shutdown timeout, drop socket
 |                                 // no shutdown timeout, drop socket
 | ||||||
|                                 self.flags.insert(Flags::WRITE_DISCONNECT); |                                 this.flags.insert(Flags::WRITE_DISCONNECT); | ||||||
|                                 return Ok(()); |                                 return Ok(()); | ||||||
|                             } |                             } | ||||||
|                         } else { |                         } else { | ||||||
|                             // timeout on first request (slow request) return 408
 |                             // timeout on first request (slow request) return 408
 | ||||||
|                             if !self.flags.contains(Flags::STARTED) { |                             if !this.flags.contains(Flags::STARTED) { | ||||||
|                                 trace!("Slow request timeout"); |                                 trace!("Slow request timeout"); | ||||||
|                                 let _ = self.send_response( |                                 let _ = self.as_mut().send_response( | ||||||
|                                     Response::RequestTimeout().finish().drop_body(), |                                     Response::RequestTimeout().finish().drop_body(), | ||||||
|                                     ResponseBody::Other(Body::Empty), |                                     ResponseBody::Other(Body::Empty), | ||||||
|                                 ); |                                 ); | ||||||
|  |                                 this = self.as_mut().project(); | ||||||
|                             } else { |                             } else { | ||||||
|                                 trace!("Keep-alive connection timeout"); |                                 trace!("Keep-alive connection timeout"); | ||||||
|                             } |                             } | ||||||
|                             self.flags.insert(Flags::STARTED | Flags::SHUTDOWN); |                             this.flags.insert(Flags::STARTED | Flags::SHUTDOWN); | ||||||
|                             self.state = State::None; |                             this.state.set(State::None); | ||||||
|                         } |                         } | ||||||
|                     } else if let Some(deadline) = |                     } else if let Some(deadline) = | ||||||
|                         self.codec.config().keep_alive_expire() |                         this.codec.config().keep_alive_expire() | ||||||
|                     { |                     { | ||||||
|                         if let Some(mut timer) = self.ka_timer.as_mut() { |                         if let Some(mut timer) = this.ka_timer.as_mut() { | ||||||
|                             timer.reset(deadline); |                             timer.reset(deadline); | ||||||
|                             let _ = Pin::new(&mut timer).poll(cx); |                             let _ = Pin::new(&mut timer).poll(cx); | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                 } else if let Some(mut timer) = self.ka_timer.as_mut() { |                 } else if let Some(mut timer) = this.ka_timer.as_mut() { | ||||||
|                     timer.reset(self.ka_expire); |                     timer.reset(*this.ka_expire); | ||||||
|                     let _ = Pin::new(&mut timer).poll(cx); |                     let _ = Pin::new(&mut timer).poll(cx); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  | @ -696,22 +736,25 @@ where | ||||||
| { | { | ||||||
|     type Output = Result<(), DispatchError>; |     type Output = Result<(), DispatchError>; | ||||||
| 
 | 
 | ||||||
|  |     #[pin_project::project] | ||||||
|     #[inline] |     #[inline] | ||||||
|     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||||||
|         match self.as_mut().inner { |         let this = self.as_mut().project(); | ||||||
|             DispatcherState::Normal(ref mut inner) => { |         #[project] | ||||||
|                 inner.poll_keepalive(cx)?; |         match this.inner.project() { | ||||||
|  |             DispatcherState::Normal(mut inner) => { | ||||||
|  |                 inner.as_mut().poll_keepalive(cx)?; | ||||||
| 
 | 
 | ||||||
|                 if inner.flags.contains(Flags::SHUTDOWN) { |                 if inner.flags.contains(Flags::SHUTDOWN) { | ||||||
|                     if inner.flags.contains(Flags::WRITE_DISCONNECT) { |                     if inner.flags.contains(Flags::WRITE_DISCONNECT) { | ||||||
|                         Poll::Ready(Ok(())) |                         Poll::Ready(Ok(())) | ||||||
|                     } else { |                     } else { | ||||||
|                         // flush buffer
 |                         // flush buffer
 | ||||||
|                         inner.poll_flush(cx)?; |                         inner.as_mut().poll_flush(cx)?; | ||||||
|                         if !inner.write_buf.is_empty() { |                         if !inner.write_buf.is_empty() { | ||||||
|                             Poll::Pending |                             Poll::Pending | ||||||
|                         } else { |                         } else { | ||||||
|                             match Pin::new(&mut inner.io).poll_shutdown(cx) { |                             match Pin::new(inner.project().io).poll_shutdown(cx) { | ||||||
|                                 Poll::Ready(res) => { |                                 Poll::Ready(res) => { | ||||||
|                                     Poll::Ready(res.map_err(DispatchError::from)) |                                     Poll::Ready(res.map_err(DispatchError::from)) | ||||||
|                                 } |                                 } | ||||||
|  | @ -723,33 +766,34 @@ where | ||||||
|                     // read socket into a buf
 |                     // read socket into a buf
 | ||||||
|                     let should_disconnect = |                     let should_disconnect = | ||||||
|                         if !inner.flags.contains(Flags::READ_DISCONNECT) { |                         if !inner.flags.contains(Flags::READ_DISCONNECT) { | ||||||
|                             read_available(cx, &mut inner.io, &mut inner.read_buf)? |                             let mut inner_p = inner.as_mut().project(); | ||||||
|  |                             read_available(cx, &mut inner_p.io, &mut inner_p.read_buf)? | ||||||
|                         } else { |                         } else { | ||||||
|                             None |                             None | ||||||
|                         }; |                         }; | ||||||
| 
 | 
 | ||||||
|                     inner.poll_request(cx)?; |                     inner.as_mut().poll_request(cx)?; | ||||||
|                     if let Some(true) = should_disconnect { |                     if let Some(true) = should_disconnect { | ||||||
|                         inner.flags.insert(Flags::READ_DISCONNECT); |                         let inner_p = inner.as_mut().project(); | ||||||
|                         if let Some(mut payload) = inner.payload.take() { |                         inner_p.flags.insert(Flags::READ_DISCONNECT); | ||||||
|  |                         if let Some(mut payload) = inner_p.payload.take() { | ||||||
|                             payload.feed_eof(); |                             payload.feed_eof(); | ||||||
|                         } |                         } | ||||||
|                     }; |                     }; | ||||||
| 
 | 
 | ||||||
|                     loop { |                     loop { | ||||||
|  |                         let inner_p = inner.as_mut().project(); | ||||||
|                         let remaining = |                         let remaining = | ||||||
|                             inner.write_buf.capacity() - inner.write_buf.len(); |                             inner_p.write_buf.capacity() - inner_p.write_buf.len(); | ||||||
|                         if remaining < LW_BUFFER_SIZE { |                         if remaining < LW_BUFFER_SIZE { | ||||||
|                             inner.write_buf.reserve(HW_BUFFER_SIZE - remaining); |                             inner_p.write_buf.reserve(HW_BUFFER_SIZE - remaining); | ||||||
|                         } |                         } | ||||||
|                         let result = inner.poll_response(cx)?; |                         let result = inner.as_mut().poll_response(cx)?; | ||||||
|                         let drain = result == PollResponse::DrainWriteBuf; |                         let drain = result == PollResponse::DrainWriteBuf; | ||||||
| 
 | 
 | ||||||
|                         // switch to upgrade handler
 |                         // switch to upgrade handler
 | ||||||
|                         if let PollResponse::Upgrade(req) = result { |                         if let PollResponse::Upgrade(req) = result { | ||||||
|                             if let DispatcherState::Normal(inner) = |                             if let DispatcherState::Normal(inner) = self.as_mut().project().inner.take() { | ||||||
|                                 std::mem::replace(&mut self.inner, DispatcherState::None) |  | ||||||
|                             { |  | ||||||
|                                 let mut parts = FramedParts::with_read_buf( |                                 let mut parts = FramedParts::with_read_buf( | ||||||
|                                     inner.io, |                                     inner.io, | ||||||
|                                     inner.codec, |                                     inner.codec, | ||||||
|  | @ -757,9 +801,8 @@ where | ||||||
|                                 ); |                                 ); | ||||||
|                                 parts.write_buf = inner.write_buf; |                                 parts.write_buf = inner.write_buf; | ||||||
|                                 let framed = Framed::from_parts(parts); |                                 let framed = Framed::from_parts(parts); | ||||||
|                                 self.inner = DispatcherState::Upgrade( |                                 let upgrade = inner.upgrade.unwrap().call((req, framed)); | ||||||
|                                     Box::pin(inner.upgrade.unwrap().call((req, framed))), |                                 self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade)); | ||||||
|                                 ); |  | ||||||
|                                 return self.poll(cx); |                                 return self.poll(cx); | ||||||
|                             } else { |                             } else { | ||||||
|                                 panic!() |                                 panic!() | ||||||
|  | @ -769,7 +812,7 @@ where | ||||||
|                         // we didnt get WouldBlock from write operation,
 |                         // we didnt get WouldBlock from write operation,
 | ||||||
|                         // so data get written to kernel completely (OSX)
 |                         // so data get written to kernel completely (OSX)
 | ||||||
|                         // and we have to write again otherwise response can get stuck
 |                         // and we have to write again otherwise response can get stuck
 | ||||||
|                         if inner.poll_flush(cx)? || !drain { |                         if inner.as_mut().poll_flush(cx)? || !drain { | ||||||
|                             break; |                             break; | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|  | @ -781,25 +824,26 @@ where | ||||||
| 
 | 
 | ||||||
|                     let is_empty = inner.state.is_empty(); |                     let is_empty = inner.state.is_empty(); | ||||||
| 
 | 
 | ||||||
|  |                     let inner_p = inner.as_mut().project(); | ||||||
|                     // read half is closed and we do not processing any responses
 |                     // read half is closed and we do not processing any responses
 | ||||||
|                     if inner.flags.contains(Flags::READ_DISCONNECT) && is_empty { |                     if inner_p.flags.contains(Flags::READ_DISCONNECT) && is_empty { | ||||||
|                         inner.flags.insert(Flags::SHUTDOWN); |                         inner_p.flags.insert(Flags::SHUTDOWN); | ||||||
|                     } |                     } | ||||||
| 
 | 
 | ||||||
|                     // keep-alive and stream errors
 |                     // keep-alive and stream errors
 | ||||||
|                     if is_empty && inner.write_buf.is_empty() { |                     if is_empty && inner_p.write_buf.is_empty() { | ||||||
|                         if let Some(err) = inner.error.take() { |                         if let Some(err) = inner_p.error.take() { | ||||||
|                             Poll::Ready(Err(err)) |                             Poll::Ready(Err(err)) | ||||||
|                         } |                         } | ||||||
|                         // disconnect if keep-alive is not enabled
 |                         // disconnect if keep-alive is not enabled
 | ||||||
|                         else if inner.flags.contains(Flags::STARTED) |                         else if inner_p.flags.contains(Flags::STARTED) | ||||||
|                             && !inner.flags.intersects(Flags::KEEPALIVE) |                             && !inner_p.flags.intersects(Flags::KEEPALIVE) | ||||||
|                         { |                         { | ||||||
|                             inner.flags.insert(Flags::SHUTDOWN); |                             inner_p.flags.insert(Flags::SHUTDOWN); | ||||||
|                             self.poll(cx) |                             self.poll(cx) | ||||||
|                         } |                         } | ||||||
|                         // disconnect if shutdown
 |                         // disconnect if shutdown
 | ||||||
|                         else if inner.flags.contains(Flags::SHUTDOWN) { |                         else if inner_p.flags.contains(Flags::SHUTDOWN) { | ||||||
|                             self.poll(cx) |                             self.poll(cx) | ||||||
|                         } else { |                         } else { | ||||||
|                             Poll::Pending |                             Poll::Pending | ||||||
|  |  | ||||||
|  | @ -36,7 +36,7 @@ where | ||||||
| impl<T, B> Future for SendResponse<T, B> | impl<T, B> Future for SendResponse<T, B> | ||||||
| where | where | ||||||
|     T: AsyncRead + AsyncWrite, |     T: AsyncRead + AsyncWrite, | ||||||
|     B: MessageBody, |     B: MessageBody + Unpin, | ||||||
| { | { | ||||||
|     type Output = Result<Framed<T, Codec>, Error>; |     type Output = Result<Framed<T, Codec>, Error>; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue