mirror of https://github.com/fafhrd91/actix-web
				
				
				
			unlink MessageBody from Unpin
This commit is contained in:
		
							parent
							
								
									2e2ea7ab80
								
							
						
					
					
						commit
						ec5c779732
					
				|  | @ -33,7 +33,7 @@ impl BodySize { | |||
| } | ||||
| 
 | ||||
| /// Type that provides this trait can be streamed to a peer.
 | ||||
| pub trait MessageBody: Unpin { | ||||
| pub trait MessageBody { | ||||
|     fn size(&self) -> BodySize; | ||||
| 
 | ||||
|     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>>; | ||||
|  | @ -53,14 +53,13 @@ impl MessageBody for () { | |||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<T: MessageBody> MessageBody for Box<T> { | ||||
| impl<T: MessageBody + Unpin> MessageBody for Box<T> { | ||||
|     fn size(&self) -> BodySize { | ||||
|         self.as_ref().size() | ||||
|     } | ||||
| 
 | ||||
|     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { | ||||
|         let a: Pin<&mut T> = Pin::new(self.get_mut().as_mut()); | ||||
|         a.poll_next(cx) | ||||
|         unsafe { self.map_unchecked_mut(|boxed| boxed.as_mut()) }.poll_next(cx) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  | @ -70,8 +69,7 @@ impl MessageBody for Box<dyn MessageBody> { | |||
|     } | ||||
| 
 | ||||
|     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { | ||||
|         let a: Pin<&mut dyn MessageBody> = Pin::new(self.get_mut().as_mut()); | ||||
|         a.poll_next(cx) | ||||
|         unsafe { Pin::new_unchecked(self.get_mut().as_mut()) }.poll_next(cx) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -10,6 +10,7 @@ use actix_service::Service; | |||
| use bitflags::bitflags; | ||||
| use bytes::{Buf, BytesMut}; | ||||
| use log::{error, trace}; | ||||
| use pin_project::pin_project; | ||||
| 
 | ||||
| use crate::body::{Body, BodySize, MessageBody, ResponseBody}; | ||||
| use crate::cloneable::CloneableService; | ||||
|  | @ -41,6 +42,7 @@ bitflags! { | |||
|     } | ||||
| } | ||||
| 
 | ||||
| #[pin_project::pin_project] | ||||
| /// Dispatcher for HTTP/1.1 protocol
 | ||||
| pub struct Dispatcher<T, S, B, X, U> | ||||
| where | ||||
|  | @ -52,9 +54,11 @@ where | |||
|     U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, | ||||
|     U::Error: fmt::Display, | ||||
| { | ||||
|     #[pin] | ||||
|     inner: DispatcherState<T, S, B, X, U>, | ||||
| } | ||||
| 
 | ||||
| #[pin_project] | ||||
| enum DispatcherState<T, S, B, X, U> | ||||
| where | ||||
|     S: Service<Request = Request>, | ||||
|  | @ -65,11 +69,12 @@ where | |||
|     U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, | ||||
|     U::Error: fmt::Display, | ||||
| { | ||||
|     Normal(InnerDispatcher<T, S, B, X, U>), | ||||
|     Upgrade(Pin<Box<U::Future>>), | ||||
|     Normal(#[pin] InnerDispatcher<T, S, B, X, U>), | ||||
|     Upgrade(#[pin] U::Future), | ||||
|     None, | ||||
| } | ||||
| 
 | ||||
| #[pin_project] | ||||
| struct InnerDispatcher<T, S, B, X, U> | ||||
| where | ||||
|     S: Service<Request = Request>, | ||||
|  | @ -88,6 +93,7 @@ where | |||
|     peer_addr: Option<net::SocketAddr>, | ||||
|     error: Option<DispatchError>, | ||||
| 
 | ||||
|     #[pin] | ||||
|     state: State<S, B, X>, | ||||
|     payload: Option<PayloadSender>, | ||||
|     messages: VecDeque<DispatcherMessage>, | ||||
|  | @ -107,6 +113,7 @@ enum DispatcherMessage { | |||
|     Error(Response<()>), | ||||
| } | ||||
| 
 | ||||
| #[pin_project] | ||||
| enum State<S, B, X> | ||||
| where | ||||
|     S: Service<Request = Request>, | ||||
|  | @ -114,9 +121,9 @@ where | |||
|     B: MessageBody, | ||||
| { | ||||
|     None, | ||||
|     ExpectCall(Pin<Box<X::Future>>), | ||||
|     ServiceCall(Pin<Box<S::Future>>), | ||||
|     SendPayload(ResponseBody<B>), | ||||
|     ExpectCall(#[pin] X::Future), | ||||
|     ServiceCall(#[pin] S::Future), | ||||
|     SendPayload(#[pin] ResponseBody<B>), | ||||
| } | ||||
| 
 | ||||
| impl<S, B, X> State<S, B, X> | ||||
|  | @ -142,6 +149,21 @@ where | |||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<T, S, B, X, U> DispatcherState<T, S, B, X, U> | ||||
| where | ||||
|     S: Service<Request = Request>, | ||||
|     S::Error: Into<Error>, | ||||
|     B: MessageBody, | ||||
|     X: Service<Request = Request, Response = Request>, | ||||
|     X::Error: Into<Error>, | ||||
|     U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, | ||||
|     U::Error: fmt::Display, | ||||
| { | ||||
|     fn take(self: Pin<&mut Self>) -> Self { | ||||
|         std::mem::replace(unsafe { self.get_unchecked_mut() }, Self::None) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| enum PollResponse { | ||||
|     Upgrade(Request), | ||||
|     DoNothing, | ||||
|  | @ -278,10 +300,11 @@ where | |||
|     } | ||||
| 
 | ||||
|     // if checked is set to true, delay disconnect until all tasks have finished.
 | ||||
|     fn client_disconnected(&mut self) { | ||||
|         self.flags | ||||
|     fn client_disconnected(self: Pin<&mut Self>) { | ||||
|         let this = self.project(); | ||||
|         this.flags | ||||
|             .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT); | ||||
|         if let Some(mut payload) = self.payload.take() { | ||||
|         if let Some(mut payload) = this.payload.take() { | ||||
|             payload.set_error(PayloadError::Incomplete(None)); | ||||
|         } | ||||
|     } | ||||
|  | @ -290,16 +313,18 @@ where | |||
|     ///
 | ||||
|     /// true - got whouldblock
 | ||||
|     /// false - didnt get whouldblock
 | ||||
|     fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<bool, DispatchError> { | ||||
|     #[pin_project::project] | ||||
|     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> { | ||||
|         if self.write_buf.is_empty() { | ||||
|             return Ok(false); | ||||
|         } | ||||
| 
 | ||||
|         let len = self.write_buf.len(); | ||||
|         let mut written = 0; | ||||
|         #[project] | ||||
|         let InnerDispatcher { mut io, write_buf, .. } = self.project(); | ||||
|         while written < len { | ||||
|             match Pin::new(&mut self.io) | ||||
|                 .poll_write(cx, &self.write_buf[written..]) | ||||
|             match Pin::new(&mut io).poll_write(cx, &write_buf[written..]) | ||||
|             { | ||||
|                 Poll::Ready(Ok(0)) => { | ||||
|                     return Err(DispatchError::Io(io::Error::new( | ||||
|  | @ -312,113 +337,120 @@ where | |||
|                 } | ||||
|                 Poll::Pending => { | ||||
|                     if written > 0 { | ||||
|                         self.write_buf.advance(written); | ||||
|                         write_buf.advance(written); | ||||
|                     } | ||||
|                     return Ok(true); | ||||
|                 } | ||||
|                 Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), | ||||
|             } | ||||
|         } | ||||
|         if written == self.write_buf.len() { | ||||
|             unsafe { self.write_buf.set_len(0) } | ||||
|         if written == write_buf.len() { | ||||
|             unsafe { write_buf.set_len(0) } | ||||
|         } else { | ||||
|             self.write_buf.advance(written); | ||||
|             write_buf.advance(written); | ||||
|         } | ||||
|         Ok(false) | ||||
|     } | ||||
| 
 | ||||
|     fn send_response( | ||||
|         &mut self, | ||||
|         self: Pin<&mut Self>, | ||||
|         message: Response<()>, | ||||
|         body: ResponseBody<B>, | ||||
|     ) -> Result<State<S, B, X>, DispatchError> { | ||||
|         self.codec | ||||
|             .encode(Message::Item((message, body.size())), &mut self.write_buf) | ||||
|         let mut this = self.project(); | ||||
|         this.codec | ||||
|             .encode(Message::Item((message, body.size())), &mut this.write_buf) | ||||
|             .map_err(|err| { | ||||
|                 if let Some(mut payload) = self.payload.take() { | ||||
|                 if let Some(mut payload) = this.payload.take() { | ||||
|                     payload.set_error(PayloadError::Incomplete(None)); | ||||
|                 } | ||||
|                 DispatchError::Io(err) | ||||
|             })?; | ||||
| 
 | ||||
|         self.flags.set(Flags::KEEPALIVE, self.codec.keepalive()); | ||||
|         this.flags.set(Flags::KEEPALIVE, this.codec.keepalive()); | ||||
|         match body.size() { | ||||
|             BodySize::None | BodySize::Empty => Ok(State::None), | ||||
|             _ => Ok(State::SendPayload(body)), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn send_continue(&mut self) { | ||||
|         self.write_buf | ||||
|     fn send_continue(self: Pin<&mut Self>) { | ||||
|         self.project().write_buf | ||||
|             .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); | ||||
|     } | ||||
| 
 | ||||
|     #[pin_project::project] | ||||
|     fn poll_response( | ||||
|         &mut self, | ||||
|         mut self: Pin<&mut Self>, | ||||
|         cx: &mut Context<'_>, | ||||
|     ) -> Result<PollResponse, DispatchError> { | ||||
|         loop { | ||||
|             let state = match self.state { | ||||
|                 State::None => match self.messages.pop_front() { | ||||
|             let mut this = self.as_mut().project(); | ||||
|             #[project] | ||||
|             let state = match this.state.project() { | ||||
|                 State::None => match this.messages.pop_front() { | ||||
|                     Some(DispatcherMessage::Item(req)) => { | ||||
|                         Some(self.handle_request(req, cx)?) | ||||
|                         Some(self.as_mut().handle_request(req, cx)?) | ||||
|                     } | ||||
|                     Some(DispatcherMessage::Error(res)) => { | ||||
|                         Some(self.send_response(res, ResponseBody::Other(Body::Empty))?) | ||||
|                         Some(self.as_mut().send_response(res, ResponseBody::Other(Body::Empty))?) | ||||
|                     } | ||||
|                     Some(DispatcherMessage::Upgrade(req)) => { | ||||
|                         return Ok(PollResponse::Upgrade(req)); | ||||
|                     } | ||||
|                     None => None, | ||||
|                 }, | ||||
|                 State::ExpectCall(ref mut fut) => { | ||||
|                     match fut.as_mut().poll(cx) { | ||||
|                 State::ExpectCall(fut) => { | ||||
|                     match fut.poll(cx) { | ||||
|                         Poll::Ready(Ok(req)) => { | ||||
|                             self.send_continue(); | ||||
|                             self.state = State::ServiceCall(Box::pin(self.service.call(req))); | ||||
|                             self.as_mut().send_continue(); | ||||
|                             this = self.as_mut().project(); | ||||
|                             this.state.set(State::ServiceCall(this.service.call(req))); | ||||
|                             continue; | ||||
|                         } | ||||
|                         Poll::Ready(Err(e)) => { | ||||
|                             let res: Response = e.into().into(); | ||||
|                             let (res, body) = res.replace_body(()); | ||||
|                             Some(self.send_response(res, body.into_body())?) | ||||
|                             Some(self.as_mut().send_response(res, body.into_body())?) | ||||
|                         } | ||||
|                         Poll::Pending => None, | ||||
|                     } | ||||
|                 } | ||||
|                 State::ServiceCall(ref mut fut) => { | ||||
|                     match fut.as_mut().poll(cx) { | ||||
|                 State::ServiceCall(fut) => { | ||||
|                     match fut.poll(cx) { | ||||
|                         Poll::Ready(Ok(res)) => { | ||||
|                             let (res, body) = res.into().replace_body(()); | ||||
|                             self.state = self.send_response(res, body)?; | ||||
|                             let state = self.as_mut().send_response(res, body)?; | ||||
|                             this = self.as_mut().project(); | ||||
|                             this.state.set(state); | ||||
|                             continue; | ||||
|                         } | ||||
|                         Poll::Ready(Err(e)) => { | ||||
|                             let res: Response = e.into().into(); | ||||
|                             let (res, body) = res.replace_body(()); | ||||
|                             Some(self.send_response(res, body.into_body())?) | ||||
|                             Some(self.as_mut().send_response(res, body.into_body())?) | ||||
|                         } | ||||
|                         Poll::Pending => None, | ||||
|                     } | ||||
|                 } | ||||
|                 State::SendPayload(ref mut stream) => { | ||||
|                     let mut stream = Pin::new(stream); | ||||
|                 State::SendPayload(mut stream) => { | ||||
|                     loop { | ||||
|                         if self.write_buf.len() < HW_BUFFER_SIZE { | ||||
|                         if this.write_buf.len() < HW_BUFFER_SIZE { | ||||
|                             match stream.as_mut().poll_next(cx) { | ||||
|                                 Poll::Ready(Some(Ok(item))) => { | ||||
|                                     self.codec.encode( | ||||
|                                     this.codec.encode( | ||||
|                                         Message::Chunk(Some(item)), | ||||
|                                         &mut self.write_buf, | ||||
|                                         &mut this.write_buf, | ||||
|                                     )?; | ||||
|                                     continue; | ||||
|                                 } | ||||
|                                 Poll::Ready(None) => { | ||||
|                                     self.codec.encode( | ||||
|                                     this.codec.encode( | ||||
|                                         Message::Chunk(None), | ||||
|                                         &mut self.write_buf, | ||||
|                                         &mut this.write_buf, | ||||
|                                     )?; | ||||
|                                     self.state = State::None; | ||||
|                                     this = self.as_mut().project(); | ||||
|                                     this.state.set(State::None); | ||||
|                                 } | ||||
|                                 Poll::Ready(Some(Err(_))) => { | ||||
|                                     return Err(DispatchError::Unknown) | ||||
|  | @ -434,9 +466,11 @@ where | |||
|                 } | ||||
|             }; | ||||
| 
 | ||||
|             this = self.as_mut().project(); | ||||
| 
 | ||||
|             // set new state
 | ||||
|             if let Some(state) = state { | ||||
|                 self.state = state; | ||||
|                 this.state.set(state); | ||||
|                 if !self.state.is_empty() { | ||||
|                     continue; | ||||
|                 } | ||||
|  | @ -444,7 +478,7 @@ where | |||
|                 // if read-backpressure is enabled and we consumed some data.
 | ||||
|                 // we may read more data and retry
 | ||||
|                 if self.state.is_call() { | ||||
|                     if self.poll_request(cx)? { | ||||
|                     if self.as_mut().poll_request(cx)? { | ||||
|                         continue; | ||||
|                     } | ||||
|                 } else if !self.messages.is_empty() { | ||||
|  | @ -458,16 +492,16 @@ where | |||
|     } | ||||
| 
 | ||||
|     fn handle_request( | ||||
|         &mut self, | ||||
|         mut self: Pin<&mut Self>, | ||||
|         req: Request, | ||||
|         cx: &mut Context<'_>, | ||||
|     ) -> Result<State<S, B, X>, DispatchError> { | ||||
|         // Handle `EXPECT: 100-Continue` header
 | ||||
|         let req = if req.head().expect() { | ||||
|             let mut task = Box::pin(self.expect.call(req)); | ||||
|             match task.as_mut().poll(cx) { | ||||
|             let mut task = self.as_mut().project().expect.call(req); | ||||
|             match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { | ||||
|                 Poll::Ready(Ok(req)) => { | ||||
|                     self.send_continue(); | ||||
|                     self.as_mut().send_continue(); | ||||
|                     req | ||||
|                 } | ||||
|                 Poll::Pending => return Ok(State::ExpectCall(task)), | ||||
|  | @ -483,8 +517,8 @@ where | |||
|         }; | ||||
| 
 | ||||
|         // Call service
 | ||||
|         let mut task = Box::pin(self.service.call(req)); | ||||
|         match task.as_mut().poll(cx) { | ||||
|         let mut task = self.as_mut().project().service.call(req); | ||||
|         match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { | ||||
|             Poll::Ready(Ok(res)) => { | ||||
|                 let (res, body) = res.into().replace_body(()); | ||||
|                 self.send_response(res, body) | ||||
|  | @ -500,7 +534,7 @@ where | |||
| 
 | ||||
|     /// Process one incoming requests
 | ||||
|     pub(self) fn poll_request( | ||||
|         &mut self, | ||||
|         mut self: Pin<&mut Self>, | ||||
|         cx: &mut Context<'_>, | ||||
|     ) -> Result<bool, DispatchError> { | ||||
|         // limit a mount of non processed requests
 | ||||
|  | @ -509,24 +543,25 @@ where | |||
|         } | ||||
| 
 | ||||
|         let mut updated = false; | ||||
|         let mut this = self.as_mut().project(); | ||||
|         loop { | ||||
|             match self.codec.decode(&mut self.read_buf) { | ||||
|             match this.codec.decode(&mut this.read_buf) { | ||||
|                 Ok(Some(msg)) => { | ||||
|                     updated = true; | ||||
|                     self.flags.insert(Flags::STARTED); | ||||
|                     this.flags.insert(Flags::STARTED); | ||||
| 
 | ||||
|                     match msg { | ||||
|                         Message::Item(mut req) => { | ||||
|                             let pl = self.codec.message_type(); | ||||
|                             req.head_mut().peer_addr = self.peer_addr; | ||||
|                             let pl = this.codec.message_type(); | ||||
|                             req.head_mut().peer_addr = *this.peer_addr; | ||||
| 
 | ||||
|                             // set on_connect data
 | ||||
|                             if let Some(ref on_connect) = self.on_connect { | ||||
|                             if let Some(ref on_connect) = this.on_connect { | ||||
|                                 on_connect.set(&mut req.extensions_mut()); | ||||
|                             } | ||||
| 
 | ||||
|                             if pl == MessageType::Stream && self.upgrade.is_some() { | ||||
|                                 self.messages.push_back(DispatcherMessage::Upgrade(req)); | ||||
|                             if pl == MessageType::Stream && this.upgrade.is_some() { | ||||
|                                 this.messages.push_back(DispatcherMessage::Upgrade(req)); | ||||
|                                 break; | ||||
|                             } | ||||
|                             if pl == MessageType::Payload || pl == MessageType::Stream { | ||||
|  | @ -534,41 +569,43 @@ where | |||
|                                 let (req1, _) = | ||||
|                                     req.replace_payload(crate::Payload::H1(pl)); | ||||
|                                 req = req1; | ||||
|                                 self.payload = Some(ps); | ||||
|                                 *this.payload = Some(ps); | ||||
|                             } | ||||
| 
 | ||||
|                             // handle request early
 | ||||
|                             if self.state.is_empty() { | ||||
|                                 self.state = self.handle_request(req, cx)?; | ||||
|                             if this.state.is_empty() { | ||||
|                                 let state = self.as_mut().handle_request(req, cx)?; | ||||
|                                 this = self.as_mut().project(); | ||||
|                                 this.state.set(state); | ||||
|                             } else { | ||||
|                                 self.messages.push_back(DispatcherMessage::Item(req)); | ||||
|                                 this.messages.push_back(DispatcherMessage::Item(req)); | ||||
|                             } | ||||
|                         } | ||||
|                         Message::Chunk(Some(chunk)) => { | ||||
|                             if let Some(ref mut payload) = self.payload { | ||||
|                             if let Some(ref mut payload) = this.payload { | ||||
|                                 payload.feed_data(chunk); | ||||
|                             } else { | ||||
|                                 error!( | ||||
|                                     "Internal server error: unexpected payload chunk" | ||||
|                                 ); | ||||
|                                 self.flags.insert(Flags::READ_DISCONNECT); | ||||
|                                 self.messages.push_back(DispatcherMessage::Error( | ||||
|                                 this.flags.insert(Flags::READ_DISCONNECT); | ||||
|                                 this.messages.push_back(DispatcherMessage::Error( | ||||
|                                     Response::InternalServerError().finish().drop_body(), | ||||
|                                 )); | ||||
|                                 self.error = Some(DispatchError::InternalError); | ||||
|                                 *this.error = Some(DispatchError::InternalError); | ||||
|                                 break; | ||||
|                             } | ||||
|                         } | ||||
|                         Message::Chunk(None) => { | ||||
|                             if let Some(mut payload) = self.payload.take() { | ||||
|                             if let Some(mut payload) = this.payload.take() { | ||||
|                                 payload.feed_eof(); | ||||
|                             } else { | ||||
|                                 error!("Internal server error: unexpected eof"); | ||||
|                                 self.flags.insert(Flags::READ_DISCONNECT); | ||||
|                                 self.messages.push_back(DispatcherMessage::Error( | ||||
|                                 this.flags.insert(Flags::READ_DISCONNECT); | ||||
|                                 this.messages.push_back(DispatcherMessage::Error( | ||||
|                                     Response::InternalServerError().finish().drop_body(), | ||||
|                                 )); | ||||
|                                 self.error = Some(DispatchError::InternalError); | ||||
|                                 *this.error = Some(DispatchError::InternalError); | ||||
|                                 break; | ||||
|                             } | ||||
|                         } | ||||
|  | @ -576,44 +613,46 @@ where | |||
|                 } | ||||
|                 Ok(None) => break, | ||||
|                 Err(ParseError::Io(e)) => { | ||||
|                     self.client_disconnected(); | ||||
|                     self.error = Some(DispatchError::Io(e)); | ||||
|                     self.as_mut().client_disconnected(); | ||||
|                     this = self.as_mut().project(); | ||||
|                     *this.error = Some(DispatchError::Io(e)); | ||||
|                     break; | ||||
|                 } | ||||
|                 Err(e) => { | ||||
|                     if let Some(mut payload) = self.payload.take() { | ||||
|                     if let Some(mut payload) = this.payload.take() { | ||||
|                         payload.set_error(PayloadError::EncodingCorrupted); | ||||
|                     } | ||||
| 
 | ||||
|                     // Malformed requests should be responded with 400
 | ||||
|                     self.messages.push_back(DispatcherMessage::Error( | ||||
|                     this.messages.push_back(DispatcherMessage::Error( | ||||
|                         Response::BadRequest().finish().drop_body(), | ||||
|                     )); | ||||
|                     self.flags.insert(Flags::READ_DISCONNECT); | ||||
|                     self.error = Some(e.into()); | ||||
|                     this.flags.insert(Flags::READ_DISCONNECT); | ||||
|                     *this.error = Some(e.into()); | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         if updated && self.ka_timer.is_some() { | ||||
|             if let Some(expire) = self.codec.config().keep_alive_expire() { | ||||
|                 self.ka_expire = expire; | ||||
|         if updated && this.ka_timer.is_some() { | ||||
|             if let Some(expire) = this.codec.config().keep_alive_expire() { | ||||
|                 *this.ka_expire = expire; | ||||
|             } | ||||
|         } | ||||
|         Ok(updated) | ||||
|     } | ||||
| 
 | ||||
|     /// keep-alive timer
 | ||||
|     fn poll_keepalive(&mut self, cx: &mut Context<'_>) -> Result<(), DispatchError> { | ||||
|         if self.ka_timer.is_none() { | ||||
|     fn poll_keepalive(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> { | ||||
|         let mut this = self.as_mut().project(); | ||||
|         if this.ka_timer.is_none() { | ||||
|             // shutdown timeout
 | ||||
|             if self.flags.contains(Flags::SHUTDOWN) { | ||||
|                 if let Some(interval) = self.codec.config().client_disconnect_timer() { | ||||
|                     self.ka_timer = Some(delay_until(interval)); | ||||
|             if this.flags.contains(Flags::SHUTDOWN) { | ||||
|                 if let Some(interval) = this.codec.config().client_disconnect_timer() { | ||||
|                     *this.ka_timer = Some(delay_until(interval)); | ||||
|                 } else { | ||||
|                     self.flags.insert(Flags::READ_DISCONNECT); | ||||
|                     if let Some(mut payload) = self.payload.take() { | ||||
|                     this.flags.insert(Flags::READ_DISCONNECT); | ||||
|                     if let Some(mut payload) = this.payload.take() { | ||||
|                         payload.set_error(PayloadError::Incomplete(None)); | ||||
|                     } | ||||
|                     return Ok(()); | ||||
|  | @ -623,55 +662,56 @@ where | |||
|             } | ||||
|         } | ||||
| 
 | ||||
|         match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) { | ||||
|         match Pin::new(&mut this.ka_timer.as_mut().unwrap()).poll(cx) { | ||||
|             Poll::Ready(()) => { | ||||
|                 // if we get timeout during shutdown, drop connection
 | ||||
|                 if self.flags.contains(Flags::SHUTDOWN) { | ||||
|                 if this.flags.contains(Flags::SHUTDOWN) { | ||||
|                     return Err(DispatchError::DisconnectTimeout); | ||||
|                 } else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire { | ||||
|                 } else if this.ka_timer.as_mut().unwrap().deadline() >= *this.ka_expire { | ||||
|                     // check for any outstanding tasks
 | ||||
|                     if self.state.is_empty() && self.write_buf.is_empty() { | ||||
|                         if self.flags.contains(Flags::STARTED) { | ||||
|                     if this.state.is_empty() && this.write_buf.is_empty() { | ||||
|                         if this.flags.contains(Flags::STARTED) { | ||||
|                             trace!("Keep-alive timeout, close connection"); | ||||
|                             self.flags.insert(Flags::SHUTDOWN); | ||||
|                             this.flags.insert(Flags::SHUTDOWN); | ||||
| 
 | ||||
|                             // start shutdown timer
 | ||||
|                             if let Some(deadline) = | ||||
|                                 self.codec.config().client_disconnect_timer() | ||||
|                                 this.codec.config().client_disconnect_timer() | ||||
|                             { | ||||
|                                 if let Some(mut timer) = self.ka_timer.as_mut() { | ||||
|                                 if let Some(mut timer) = this.ka_timer.as_mut() { | ||||
|                                     timer.reset(deadline); | ||||
|                                     let _ = Pin::new(&mut timer).poll(cx); | ||||
|                                 } | ||||
|                             } else { | ||||
|                                 // no shutdown timeout, drop socket
 | ||||
|                                 self.flags.insert(Flags::WRITE_DISCONNECT); | ||||
|                                 this.flags.insert(Flags::WRITE_DISCONNECT); | ||||
|                                 return Ok(()); | ||||
|                             } | ||||
|                         } else { | ||||
|                             // timeout on first request (slow request) return 408
 | ||||
|                             if !self.flags.contains(Flags::STARTED) { | ||||
|                             if !this.flags.contains(Flags::STARTED) { | ||||
|                                 trace!("Slow request timeout"); | ||||
|                                 let _ = self.send_response( | ||||
|                                 let _ = self.as_mut().send_response( | ||||
|                                     Response::RequestTimeout().finish().drop_body(), | ||||
|                                     ResponseBody::Other(Body::Empty), | ||||
|                                 ); | ||||
|                                 this = self.as_mut().project(); | ||||
|                             } else { | ||||
|                                 trace!("Keep-alive connection timeout"); | ||||
|                             } | ||||
|                             self.flags.insert(Flags::STARTED | Flags::SHUTDOWN); | ||||
|                             self.state = State::None; | ||||
|                             this.flags.insert(Flags::STARTED | Flags::SHUTDOWN); | ||||
|                             this.state.set(State::None); | ||||
|                         } | ||||
|                     } else if let Some(deadline) = | ||||
|                         self.codec.config().keep_alive_expire() | ||||
|                         this.codec.config().keep_alive_expire() | ||||
|                     { | ||||
|                         if let Some(mut timer) = self.ka_timer.as_mut() { | ||||
|                         if let Some(mut timer) = this.ka_timer.as_mut() { | ||||
|                             timer.reset(deadline); | ||||
|                             let _ = Pin::new(&mut timer).poll(cx); | ||||
|                         } | ||||
|                     } | ||||
|                 } else if let Some(mut timer) = self.ka_timer.as_mut() { | ||||
|                     timer.reset(self.ka_expire); | ||||
|                 } else if let Some(mut timer) = this.ka_timer.as_mut() { | ||||
|                     timer.reset(*this.ka_expire); | ||||
|                     let _ = Pin::new(&mut timer).poll(cx); | ||||
|                 } | ||||
|             } | ||||
|  | @ -696,22 +736,25 @@ where | |||
| { | ||||
|     type Output = Result<(), DispatchError>; | ||||
| 
 | ||||
|     #[pin_project::project] | ||||
|     #[inline] | ||||
|     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||||
|         match self.as_mut().inner { | ||||
|             DispatcherState::Normal(ref mut inner) => { | ||||
|                 inner.poll_keepalive(cx)?; | ||||
|         let this = self.as_mut().project(); | ||||
|         #[project] | ||||
|         match this.inner.project() { | ||||
|             DispatcherState::Normal(mut inner) => { | ||||
|                 inner.as_mut().poll_keepalive(cx)?; | ||||
| 
 | ||||
|                 if inner.flags.contains(Flags::SHUTDOWN) { | ||||
|                     if inner.flags.contains(Flags::WRITE_DISCONNECT) { | ||||
|                         Poll::Ready(Ok(())) | ||||
|                     } else { | ||||
|                         // flush buffer
 | ||||
|                         inner.poll_flush(cx)?; | ||||
|                         inner.as_mut().poll_flush(cx)?; | ||||
|                         if !inner.write_buf.is_empty() { | ||||
|                             Poll::Pending | ||||
|                         } else { | ||||
|                             match Pin::new(&mut inner.io).poll_shutdown(cx) { | ||||
|                             match Pin::new(inner.project().io).poll_shutdown(cx) { | ||||
|                                 Poll::Ready(res) => { | ||||
|                                     Poll::Ready(res.map_err(DispatchError::from)) | ||||
|                                 } | ||||
|  | @ -723,33 +766,34 @@ where | |||
|                     // read socket into a buf
 | ||||
|                     let should_disconnect = | ||||
|                         if !inner.flags.contains(Flags::READ_DISCONNECT) { | ||||
|                             read_available(cx, &mut inner.io, &mut inner.read_buf)? | ||||
|                             let mut inner_p = inner.as_mut().project(); | ||||
|                             read_available(cx, &mut inner_p.io, &mut inner_p.read_buf)? | ||||
|                         } else { | ||||
|                             None | ||||
|                         }; | ||||
| 
 | ||||
|                     inner.poll_request(cx)?; | ||||
|                     inner.as_mut().poll_request(cx)?; | ||||
|                     if let Some(true) = should_disconnect { | ||||
|                         inner.flags.insert(Flags::READ_DISCONNECT); | ||||
|                         if let Some(mut payload) = inner.payload.take() { | ||||
|                         let inner_p = inner.as_mut().project(); | ||||
|                         inner_p.flags.insert(Flags::READ_DISCONNECT); | ||||
|                         if let Some(mut payload) = inner_p.payload.take() { | ||||
|                             payload.feed_eof(); | ||||
|                         } | ||||
|                     }; | ||||
| 
 | ||||
|                     loop { | ||||
|                         let inner_p = inner.as_mut().project(); | ||||
|                         let remaining = | ||||
|                             inner.write_buf.capacity() - inner.write_buf.len(); | ||||
|                             inner_p.write_buf.capacity() - inner_p.write_buf.len(); | ||||
|                         if remaining < LW_BUFFER_SIZE { | ||||
|                             inner.write_buf.reserve(HW_BUFFER_SIZE - remaining); | ||||
|                             inner_p.write_buf.reserve(HW_BUFFER_SIZE - remaining); | ||||
|                         } | ||||
|                         let result = inner.poll_response(cx)?; | ||||
|                         let result = inner.as_mut().poll_response(cx)?; | ||||
|                         let drain = result == PollResponse::DrainWriteBuf; | ||||
| 
 | ||||
|                         // switch to upgrade handler
 | ||||
|                         if let PollResponse::Upgrade(req) = result { | ||||
|                             if let DispatcherState::Normal(inner) = | ||||
|                                 std::mem::replace(&mut self.inner, DispatcherState::None) | ||||
|                             { | ||||
|                             if let DispatcherState::Normal(inner) = self.as_mut().project().inner.take() { | ||||
|                                 let mut parts = FramedParts::with_read_buf( | ||||
|                                     inner.io, | ||||
|                                     inner.codec, | ||||
|  | @ -757,9 +801,8 @@ where | |||
|                                 ); | ||||
|                                 parts.write_buf = inner.write_buf; | ||||
|                                 let framed = Framed::from_parts(parts); | ||||
|                                 self.inner = DispatcherState::Upgrade( | ||||
|                                     Box::pin(inner.upgrade.unwrap().call((req, framed))), | ||||
|                                 ); | ||||
|                                 let upgrade = inner.upgrade.unwrap().call((req, framed)); | ||||
|                                 self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade)); | ||||
|                                 return self.poll(cx); | ||||
|                             } else { | ||||
|                                 panic!() | ||||
|  | @ -769,7 +812,7 @@ where | |||
|                         // we didnt get WouldBlock from write operation,
 | ||||
|                         // so data get written to kernel completely (OSX)
 | ||||
|                         // and we have to write again otherwise response can get stuck
 | ||||
|                         if inner.poll_flush(cx)? || !drain { | ||||
|                         if inner.as_mut().poll_flush(cx)? || !drain { | ||||
|                             break; | ||||
|                         } | ||||
|                     } | ||||
|  | @ -781,25 +824,26 @@ where | |||
| 
 | ||||
|                     let is_empty = inner.state.is_empty(); | ||||
| 
 | ||||
|                     let inner_p = inner.as_mut().project(); | ||||
|                     // read half is closed and we do not processing any responses
 | ||||
|                     if inner.flags.contains(Flags::READ_DISCONNECT) && is_empty { | ||||
|                         inner.flags.insert(Flags::SHUTDOWN); | ||||
|                     if inner_p.flags.contains(Flags::READ_DISCONNECT) && is_empty { | ||||
|                         inner_p.flags.insert(Flags::SHUTDOWN); | ||||
|                     } | ||||
| 
 | ||||
|                     // keep-alive and stream errors
 | ||||
|                     if is_empty && inner.write_buf.is_empty() { | ||||
|                         if let Some(err) = inner.error.take() { | ||||
|                     if is_empty && inner_p.write_buf.is_empty() { | ||||
|                         if let Some(err) = inner_p.error.take() { | ||||
|                             Poll::Ready(Err(err)) | ||||
|                         } | ||||
|                         // disconnect if keep-alive is not enabled
 | ||||
|                         else if inner.flags.contains(Flags::STARTED) | ||||
|                             && !inner.flags.intersects(Flags::KEEPALIVE) | ||||
|                         else if inner_p.flags.contains(Flags::STARTED) | ||||
|                             && !inner_p.flags.intersects(Flags::KEEPALIVE) | ||||
|                         { | ||||
|                             inner.flags.insert(Flags::SHUTDOWN); | ||||
|                             inner_p.flags.insert(Flags::SHUTDOWN); | ||||
|                             self.poll(cx) | ||||
|                         } | ||||
|                         // disconnect if shutdown
 | ||||
|                         else if inner.flags.contains(Flags::SHUTDOWN) { | ||||
|                         else if inner_p.flags.contains(Flags::SHUTDOWN) { | ||||
|                             self.poll(cx) | ||||
|                         } else { | ||||
|                             Poll::Pending | ||||
|  |  | |||
|  | @ -36,7 +36,7 @@ where | |||
| impl<T, B> Future for SendResponse<T, B> | ||||
| where | ||||
|     T: AsyncRead + AsyncWrite, | ||||
|     B: MessageBody, | ||||
|     B: MessageBody + Unpin, | ||||
| { | ||||
|     type Output = Result<Framed<T, Codec>, Error>; | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue