refactor OpenWaitingConnection guard. use pin_project_lite

This commit is contained in:
fakeshadow 2021-02-08 13:32:22 -08:00
parent dddb623a11
commit 36830b98f6
20 changed files with 624 additions and 562 deletions

View File

@ -19,6 +19,7 @@
* `client::error::ConnectError` Resolver variant contains `Box<dyn std::error::Error>` type [#1905] * `client::error::ConnectError` Resolver variant contains `Box<dyn std::error::Error>` type [#1905]
* `client::ConnectorConfig` default timeout changed to 5 seconds. [#1905] * `client::ConnectorConfig` default timeout changed to 5 seconds. [#1905]
* Simplify `BlockingError` type to a struct. It's only triggered with blocking thread pool is dead. [#1957] * Simplify `BlockingError` type to a struct. It's only triggered with blocking thread pool is dead. [#1957]
* `body::ResponseBody` enum use named field.
### Removed ### Removed
* `ResponseBuilder::set`; use `ResponseBuilder::insert_header`. [#1869] * `ResponseBuilder::set`; use `ResponseBuilder::insert_header`. [#1869]

View File

@ -69,7 +69,7 @@ language-tags = "0.2"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
percent-encoding = "2.1" percent-encoding = "2.1"
pin-project = "1.0.0" pin-project-lite = "0.2"
rand = "0.8" rand = "0.8"
regex = "1.3" regex = "1.3"
serde = "1.0" serde = "1.0"

View File

@ -4,7 +4,7 @@ use std::{fmt, mem};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures_core::{ready, Stream}; use futures_core::{ready, Stream};
use pin_project::pin_project; use pin_project_lite::pin_project;
use crate::error::Error; use crate::error::Error;
@ -63,31 +63,33 @@ impl<T: MessageBody + Unpin> MessageBody for Box<T> {
} }
} }
#[pin_project(project = ResponseBodyProj)] pin_project! {
pub enum ResponseBody<B> { #[project = ResponseBodyProj]
Body(#[pin] B), pub enum ResponseBody<B> {
Other(Body), Body { #[pin] body: B },
Other { body: Body }
}
} }
impl ResponseBody<Body> { impl ResponseBody<Body> {
pub fn into_body<B>(self) -> ResponseBody<B> { pub fn into_body<B>(self) -> ResponseBody<B> {
match self { match self {
ResponseBody::Body(b) => ResponseBody::Other(b), ResponseBody::Body { body } => ResponseBody::Other { body },
ResponseBody::Other(b) => ResponseBody::Other(b), ResponseBody::Other { body } => ResponseBody::Other { body },
} }
} }
} }
impl<B> ResponseBody<B> { impl<B> ResponseBody<B> {
pub fn take_body(&mut self) -> ResponseBody<B> { pub fn take_body(&mut self) -> ResponseBody<B> {
std::mem::replace(self, ResponseBody::Other(Body::None)) std::mem::replace(self, ResponseBody::Other { body: Body::None })
} }
} }
impl<B: MessageBody> ResponseBody<B> { impl<B: MessageBody> ResponseBody<B> {
pub fn as_ref(&self) -> Option<&B> { pub fn as_ref(&self) -> Option<&B> {
if let ResponseBody::Body(ref b) = self { if let ResponseBody::Body { body } = self {
Some(b) Some(body)
} else { } else {
None None
} }
@ -97,8 +99,8 @@ impl<B: MessageBody> ResponseBody<B> {
impl<B: MessageBody> MessageBody for ResponseBody<B> { impl<B: MessageBody> MessageBody for ResponseBody<B> {
fn size(&self) -> BodySize { fn size(&self) -> BodySize {
match self { match self {
ResponseBody::Body(ref body) => body.size(), ResponseBody::Body { body } => body.size(),
ResponseBody::Other(ref body) => body.size(), ResponseBody::Other { body } => body.size(),
} }
} }
@ -107,8 +109,8 @@ impl<B: MessageBody> MessageBody for ResponseBody<B> {
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Error>>> { ) -> Poll<Option<Result<Bytes, Error>>> {
match self.project() { match self.project() {
ResponseBodyProj::Body(body) => body.poll_next(cx), ResponseBodyProj::Body { body } => body.poll_next(cx),
ResponseBodyProj::Other(body) => Pin::new(body).poll_next(cx), ResponseBodyProj::Other { body } => Pin::new(body).poll_next(cx),
} }
} }
} }
@ -121,8 +123,8 @@ impl<B: MessageBody> Stream for ResponseBody<B> {
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> { ) -> Poll<Option<Self::Item>> {
match self.project() { match self.project() {
ResponseBodyProj::Body(body) => body.poll_next(cx), ResponseBodyProj::Body { body } => body.poll_next(cx),
ResponseBodyProj::Other(body) => Pin::new(body).poll_next(cx), ResponseBodyProj::Other { body } => Pin::new(body).poll_next(cx),
} }
} }
} }
@ -468,8 +470,8 @@ mod tests {
impl ResponseBody<Body> { impl ResponseBody<Body> {
pub(crate) fn get_ref(&self) -> &[u8] { pub(crate) fn get_ref(&self) -> &[u8] {
match *self { match *self {
ResponseBody::Body(ref b) => b.get_ref(), ResponseBody::Body { ref body } => body.get_ref(),
ResponseBody::Other(ref b) => b.get_ref(), ResponseBody::Other { ref body } => body.get_ref(),
} }
} }
} }

View File

@ -10,7 +10,7 @@ use bytes::Bytes;
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use futures_util::future::{err, Either, FutureExt, Ready}; use futures_util::future::{err, Either, FutureExt, Ready};
use h2::client::SendRequest; use h2::client::SendRequest;
use pin_project::pin_project; use pin_project_lite::pin_project;
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::h1::ClientCodec; use crate::h1::ClientCodec;
@ -245,25 +245,30 @@ where
EitherConnection::A(con) => con EitherConnection::A(con) => con
.open_tunnel(head) .open_tunnel(head)
.map(|res| { .map(|res| {
res.map(|(head, framed)| (head, framed.into_map_io(EitherIo::A))) res.map(|(head, framed)| {
(head, framed.into_map_io(|a| EitherIo::A { a }))
})
}) })
.boxed_local(), .boxed_local(),
EitherConnection::B(con) => con EitherConnection::B(con) => con
.open_tunnel(head) .open_tunnel(head)
.map(|res| { .map(|res| {
res.map(|(head, framed)| (head, framed.into_map_io(EitherIo::B))) res.map(|(head, framed)| {
(head, framed.into_map_io(|b| EitherIo::B { b }))
})
}) })
.boxed_local(), .boxed_local(),
} }
} }
} }
#[pin_project(project = EitherIoProj)] pin_project! {
pub enum EitherIo<A, B> { #[project = EitherIoProj]
A(#[pin] A), pub enum EitherIo<A, B> {
B(#[pin] B), A { #[pin] a: A },
B { #[pin] b: B}
}
} }
impl<A, B> AsyncRead for EitherIo<A, B> impl<A, B> AsyncRead for EitherIo<A, B>
where where
A: AsyncRead, A: AsyncRead,
@ -275,8 +280,8 @@ where
buf: &mut ReadBuf<'_>, buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
match self.project() { match self.project() {
EitherIoProj::A(val) => val.poll_read(cx, buf), EitherIoProj::A { a } => a.poll_read(cx, buf),
EitherIoProj::B(val) => val.poll_read(cx, buf), EitherIoProj::B { b } => b.poll_read(cx, buf),
} }
} }
} }
@ -292,15 +297,15 @@ where
buf: &[u8], buf: &[u8],
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<usize>> {
match self.project() { match self.project() {
EitherIoProj::A(val) => val.poll_write(cx, buf), EitherIoProj::A { a } => a.poll_write(cx, buf),
EitherIoProj::B(val) => val.poll_write(cx, buf), EitherIoProj::B { b } => b.poll_write(cx, buf),
} }
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.project() { match self.project() {
EitherIoProj::A(val) => val.poll_flush(cx), EitherIoProj::A { a } => a.poll_flush(cx),
EitherIoProj::B(val) => val.poll_flush(cx), EitherIoProj::B { b } => b.poll_flush(cx),
} }
} }
@ -309,8 +314,8 @@ where
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
match self.project() { match self.project() {
EitherIoProj::A(val) => val.poll_shutdown(cx), EitherIoProj::A { a } => a.poll_shutdown(cx),
EitherIoProj::B(val) => val.poll_shutdown(cx), EitherIoProj::B { b } => b.poll_shutdown(cx),
} }
} }
} }

View File

@ -416,6 +416,7 @@ mod connect_impl {
use futures_core::ready; use futures_core::ready;
use futures_util::future::Either; use futures_util::future::Either;
use pin_project_lite::pin_project;
use super::*; use super::*;
use crate::client::connection::EitherConnection; use crate::client::connection::EitherConnection;
@ -478,15 +479,20 @@ mod connect_impl {
} }
} }
#[pin_project::pin_project] pin_project! {
pub(crate) struct InnerConnectorResponseA<T, Io1, Io2> pub(crate) struct InnerConnectorResponseA<T, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead,
T: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static, Io1: AsyncWrite,
Io1: Unpin,
Io1: 'static,
T: Service<Connect, Response = (Io1, Protocol), Error = ConnectError>,
T: 'static
{ {
#[pin] #[pin]
fut: <ConnectionPool<T, Io1> as Service<Connect>>::Future, fut: <ConnectionPool<T, Io1> as Service<Connect>>::Future,
_phantom: PhantomData<Io2>, _phantom: PhantomData<Io2>
}
} }
impl<T, Io1, Io2> Future for InnerConnectorResponseA<T, Io1, Io2> impl<T, Io1, Io2> Future for InnerConnectorResponseA<T, Io1, Io2>
@ -505,15 +511,20 @@ mod connect_impl {
} }
} }
#[pin_project::pin_project] pin_project! {
pub(crate) struct InnerConnectorResponseB<T, Io1, Io2> pub(crate) struct InnerConnectorResponseB<T, Io1, Io2>
where where
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead,
T: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static, Io2: AsyncWrite,
Io2: Unpin,
Io2: 'static,
T: Service<Connect, Response = (Io2, Protocol), Error = ConnectError>,
T: 'static
{ {
#[pin] #[pin]
fut: <ConnectionPool<T, Io2> as Service<Connect>>::Future, fut: <ConnectionPool<T, Io2> as Service<Connect>>::Future,
_phantom: PhantomData<Io1>, _phantom: PhantomData<Io1>
}
} }
impl<T, Io1, Io2> Future for InnerConnectorResponseB<T, Io1, Io2> impl<T, Io1, Io2> Future for InnerConnectorResponseB<T, Io1, Io2>

View File

@ -9,6 +9,7 @@ use bytes::{Bytes, BytesMut};
use futures_core::Stream; use futures_core::Stream;
use futures_util::future::poll_fn; use futures_util::future::poll_fn;
use futures_util::{pin_mut, SinkExt, StreamExt}; use futures_util::{pin_mut, SinkExt, StreamExt};
use pin_project_lite::pin_project;
use crate::error::PayloadError; use crate::error::PayloadError;
use crate::h1; use crate::h1;
@ -237,10 +238,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncWrite for H1Connection<T>
} }
} }
#[pin_project::pin_project] pin_project! {
pub(crate) struct PlStream<Io> { pub(crate) struct PlStream<Io> {
#[pin] #[pin]
framed: Option<Framed<Io, h1::ClientPayloadCodec>>, framed: Option<Framed<Io, h1::ClientPayloadCodec>>,
}
} }
impl<Io: ConnectionLifetime> PlStream<Io> { impl<Io: ConnectionLifetime> PlStream<Io> {

View File

@ -11,14 +11,12 @@ use actix_rt::time::{sleep, Sleep};
use actix_service::Service; use actix_service::Service;
use actix_utils::task::LocalWaker; use actix_utils::task::LocalWaker;
use ahash::AHashMap; use ahash::AHashMap;
use bytes::Bytes;
use futures_channel::oneshot; use futures_channel::oneshot;
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use futures_util::future::{poll_fn, FutureExt}; use futures_util::future::{poll_fn, FutureExt};
use h2::client::{Connection, SendRequest};
use http::uri::Authority; use http::uri::Authority;
use indexmap::IndexSet; use indexmap::IndexSet;
use pin_project::pin_project; use pin_project_lite::pin_project;
use slab::Slab; use slab::Slab;
use super::config::ConnectorConfig; use super::config::ConnectorConfig;
@ -386,11 +384,12 @@ where
} }
} }
#[pin_project::pin_project] pin_project! {
struct CloseConnection<T> { struct CloseConnection<T> {
io: T, io: T,
#[pin] #[pin]
timeout: Sleep, timeout: Sleep
}
} }
impl<T> CloseConnection<T> impl<T> CloseConnection<T>
@ -424,7 +423,6 @@ where
} }
} }
#[pin_project]
struct ConnectorPoolSupport<T, Io> struct ConnectorPoolSupport<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
@ -442,9 +440,9 @@ where
type Output = (); type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project(); let this = self.get_mut();
if Rc::strong_count(this.inner) == 1 { if Rc::strong_count(&this.inner) == 1 {
// If we are last copy of Inner<Io> it means the ConnectionPool is already gone // If we are last copy of Inner<Io> it means the ConnectionPool is already gone
// and we are safe to exit. // and we are safe to exit.
return Poll::Ready(()); return Poll::Ready(());
@ -498,55 +496,75 @@ where
} }
} }
#[pin_project::pin_project(PinnedDrop)] struct OpenWaitingConnection<Io>
struct OpenWaitingConnection<F, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
#[pin]
fut: F,
key: Key,
h2: Option<
LocalBoxFuture<
'static,
Result<(SendRequest<Bytes>, Connection<Io, Bytes>), h2::Error>,
>,
>,
rx: Option<oneshot::Sender<Result<IoConnection<Io>, ConnectError>>>,
inner: Option<Rc<RefCell<Inner<Io>>>>, inner: Option<Rc<RefCell<Inner<Io>>>>,
config: ConnectorConfig,
} }
impl<F, Io> OpenWaitingConnection<F, Io> impl<Io> OpenWaitingConnection<Io>
where where
F: Future<Output = Result<(Io, Protocol), ConnectError>> + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
fn spawn( fn spawn<F>(
key: Key, key: Key,
rx: oneshot::Sender<Result<IoConnection<Io>, ConnectError>>, rx: oneshot::Sender<Result<IoConnection<Io>, ConnectError>>,
inner: Rc<RefCell<Inner<Io>>>, inner: Rc<RefCell<Inner<Io>>>,
fut: F, fut: F,
config: ConnectorConfig, config: ConnectorConfig,
) { ) where
actix_rt::spawn(OpenWaitingConnection { F: Future<Output = Result<(Io, Protocol), ConnectError>> + 'static,
key, {
fut, // OpenWaitingConnection would guard the spawn task and release
h2: None, // permission/wake up support future when spawn task is canceled/generated error.
rx: Some(rx), let mut guard = OpenWaitingConnection { inner: Some(inner) };
inner: Some(inner),
config, actix_rt::spawn(async move {
let (io, proto) = match fut.await {
Ok((io, proto)) => (io, proto),
Err(e) => {
let _ = Option::take(&mut guard.inner);
let _ = rx.send(Err(e));
return;
}
};
match proto {
Protocol::Http1 => {
let inner = Option::take(&mut guard.inner);
let _ = rx.send(Ok(IoConnection::new(
ConnectionType::H1(io),
Instant::now(),
Some(Acquired(key, inner)),
)));
}
_ => match handshake(io, &config).await {
Ok((sender, connection)) => {
let inner = Option::take(&mut guard.inner);
let _ = rx.send(Ok(IoConnection::new(
ConnectionType::H2(H2Connection::new(sender, connection)),
Instant::now(),
Some(Acquired(key, inner)),
)));
}
Err(err) => {
let _ = Option::take(&mut guard.inner);
let _ = rx.send(Err(ConnectError::H2(err)));
}
},
}
}); });
} }
} }
#[pin_project::pinned_drop] impl<Io> Drop for OpenWaitingConnection<Io>
impl<F, Io> PinnedDrop for OpenWaitingConnection<F, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
fn drop(self: Pin<&mut Self>) { fn drop(&mut self) {
if let Some(inner) = self.project().inner.take() { // if inner is some it means OpenWaitingConnection did not finish
// it's task. release permission and try to wake up support future.
if let Some(inner) = self.inner.take() {
let mut inner = inner.as_ref().borrow_mut(); let mut inner = inner.as_ref().borrow_mut();
inner.release(); inner.release();
inner.check_availability(); inner.check_availability();
@ -554,65 +572,6 @@ where
} }
} }
impl<F, Io> Future for OpenWaitingConnection<F, Io>
where
F: Future<Output = Result<(Io, Protocol), ConnectError>>,
Io: AsyncRead + AsyncWrite + Unpin,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().project();
if let Some(ref mut h2) = this.h2 {
return match Pin::new(h2).poll(cx) {
Poll::Ready(Ok((sender, connection))) => {
let rx = this.rx.take().unwrap();
let _ = rx.send(Ok(IoConnection::new(
ConnectionType::H2(H2Connection::new(sender, connection)),
Instant::now(),
Some(Acquired(this.key.clone(), this.inner.take())),
)));
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
Poll::Ready(Err(err)) => {
let _ = this.inner.take();
if let Some(rx) = this.rx.take() {
let _ = rx.send(Err(ConnectError::H2(err)));
}
Poll::Ready(())
}
};
}
match this.fut.poll(cx) {
Poll::Ready(Err(err)) => {
let _ = this.inner.take();
if let Some(rx) = this.rx.take() {
let _ = rx.send(Err(err));
}
Poll::Ready(())
}
Poll::Ready(Ok((io, proto))) => {
if proto == Protocol::Http1 {
let rx = this.rx.take().unwrap();
let _ = rx.send(Ok(IoConnection::new(
ConnectionType::H1(io),
Instant::now(),
Some(Acquired(this.key.clone(), this.inner.take())),
)));
Poll::Ready(())
} else {
*this.h2 = Some(handshake(io, this.config).boxed_local());
self.poll(cx)
}
}
Poll::Pending => Poll::Pending,
}
}
}
pub(crate) struct Acquired<T>(Key, Option<Rc<RefCell<Inner<T>>>>); pub(crate) struct Acquired<T>(Key, Option<Rc<RefCell<Inner<T>>>>);
impl<T> Acquired<T> impl<T> Acquired<T>

View File

@ -9,7 +9,7 @@ use brotli2::write::BrotliEncoder;
use bytes::Bytes; use bytes::Bytes;
use flate2::write::{GzEncoder, ZlibEncoder}; use flate2::write::{GzEncoder, ZlibEncoder};
use futures_core::ready; use futures_core::ready;
use pin_project::pin_project; use pin_project_lite::pin_project;
use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::http::header::{ContentEncoding, CONTENT_ENCODING}; use crate::http::header::{ContentEncoding, CONTENT_ENCODING};
@ -21,13 +21,14 @@ use crate::error::BlockingError;
const INPLACE: usize = 1024; const INPLACE: usize = 1024;
#[pin_project] pin_project! {
pub struct Encoder<B> { pub struct Encoder<B> {
eof: bool, eof: bool,
#[pin] #[pin]
body: EncoderBody<B>, body: EncoderBody<B>,
encoder: Option<ContentEncoder>, encoder: Option<ContentEncoder>,
fut: Option<JoinHandle<Result<ContentEncoder, io::Error>>>, fut: Option<JoinHandle<Result<ContentEncoder, io::Error>>>,
}
} }
impl<B: MessageBody> Encoder<B> { impl<B: MessageBody> Encoder<B> {
@ -43,19 +44,21 @@ impl<B: MessageBody> Encoder<B> {
|| encoding == ContentEncoding::Auto); || encoding == ContentEncoding::Auto);
let body = match body { let body = match body {
ResponseBody::Other(b) => match b { ResponseBody::Other { body } => match body {
Body::None => return ResponseBody::Other(Body::None), Body::None => return ResponseBody::Other { body: Body::None },
Body::Empty => return ResponseBody::Other(Body::Empty), Body::Empty => return ResponseBody::Other { body: Body::Empty },
Body::Bytes(buf) => { Body::Bytes(bytes) => {
if can_encode { if can_encode {
EncoderBody::Bytes(buf) EncoderBody::Bytes { bytes }
} else { } else {
return ResponseBody::Other(Body::Bytes(buf)); return ResponseBody::Other {
body: Body::Bytes(bytes),
};
} }
} }
Body::Message(stream) => EncoderBody::BoxedStream(stream), Body::Message(stream) => EncoderBody::BoxedStream { stream },
}, },
ResponseBody::Body(stream) => EncoderBody::Stream(stream), ResponseBody::Body { body } => EncoderBody::Stream { stream: body },
}; };
if can_encode { if can_encode {
@ -63,36 +66,42 @@ impl<B: MessageBody> Encoder<B> {
if let Some(enc) = ContentEncoder::encoder(encoding) { if let Some(enc) = ContentEncoder::encoder(encoding) {
update_head(encoding, head); update_head(encoding, head);
head.no_chunking(false); head.no_chunking(false);
return ResponseBody::Body(Encoder { return ResponseBody::Body {
body: Encoder {
body, body,
eof: false, eof: false,
fut: None, fut: None,
encoder: Some(enc), encoder: Some(enc),
}); },
};
} }
} }
ResponseBody::Body(Encoder { ResponseBody::Body {
body: Encoder {
body, body,
eof: false, eof: false,
fut: None, fut: None,
encoder: None, encoder: None,
}) },
}
} }
} }
#[pin_project(project = EncoderBodyProj)] pin_project! {
enum EncoderBody<B> { #[project = EncoderBodyProj]
Bytes(Bytes), enum EncoderBody<B> {
Stream(#[pin] B), Bytes { bytes: Bytes },
BoxedStream(Box<dyn MessageBody + Unpin>), Stream { #[pin] stream: B },
BoxedStream { stream: Box<dyn MessageBody + Unpin> }
}
} }
impl<B: MessageBody> MessageBody for EncoderBody<B> { impl<B: MessageBody> MessageBody for EncoderBody<B> {
fn size(&self) -> BodySize { fn size(&self) -> BodySize {
match self { match self {
EncoderBody::Bytes(ref b) => b.size(), EncoderBody::Bytes { ref bytes } => bytes.size(),
EncoderBody::Stream(ref b) => b.size(), EncoderBody::Stream { ref stream } => stream.size(),
EncoderBody::BoxedStream(ref b) => b.size(), EncoderBody::BoxedStream { ref stream } => stream.size(),
} }
} }
@ -101,16 +110,16 @@ impl<B: MessageBody> MessageBody for EncoderBody<B> {
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Error>>> { ) -> Poll<Option<Result<Bytes, Error>>> {
match self.project() { match self.project() {
EncoderBodyProj::Bytes(b) => { EncoderBodyProj::Bytes { bytes } => {
if b.is_empty() { if bytes.is_empty() {
Poll::Ready(None) Poll::Ready(None)
} else { } else {
Poll::Ready(Some(Ok(std::mem::take(b)))) Poll::Ready(Some(Ok(std::mem::take(bytes))))
} }
} }
EncoderBodyProj::Stream(b) => b.poll_next(cx), EncoderBodyProj::Stream { stream } => stream.poll_next(cx),
EncoderBodyProj::BoxedStream(ref mut b) => { EncoderBodyProj::BoxedStream { stream } => {
Pin::new(b.as_mut()).poll_next(cx) Pin::new(stream.as_mut()).poll_next(cx)
} }
} }
} }

View File

@ -14,7 +14,7 @@ use actix_service::Service;
use bitflags::bitflags; use bitflags::bitflags;
use bytes::{Buf, BytesMut}; use bytes::{Buf, BytesMut};
use log::{error, trace}; use log::{error, trace};
use pin_project::pin_project; use pin_project_lite::pin_project;
use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
@ -45,51 +45,72 @@ bitflags! {
} }
} }
#[pin_project::pin_project] #[cfg(test)]
/// Dispatcher for HTTP/1.1 protocol pin_project! {
pub struct Dispatcher<T, S, B, X, U> /// Dispatcher for HTTP/1.1 protocol
where pub struct Dispatcher<T, S, B, X, U>
where
S: Service<Request>, S: Service<Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
B: MessageBody, B: MessageBody,
X: Service<Request, Response = Request>, X: Service<Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<(Request, Framed<T, Codec>), Response = ()>, U: Service<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display
{ {
#[pin] #[pin]
inner: DispatcherState<T, S, B, X, U>, inner: DispatcherState<T, S, B, X, U>,
poll_count: u64
#[cfg(test)] }
poll_count: u64,
} }
#[pin_project(project = DispatcherStateProj)] #[cfg(not(test))]
enum DispatcherState<T, S, B, X, U> pin_project! {
where /// Dispatcher for HTTP/1.1 protocol
pub struct Dispatcher<T, S, B, X, U>
where
S: Service<Request>, S: Service<Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
B: MessageBody, B: MessageBody,
X: Service<Request, Response = Request>, X: Service<Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<(Request, Framed<T, Codec>), Response = ()>, U: Service<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display
{ {
Normal(#[pin] InnerDispatcher<T, S, B, X, U>), #[pin]
Upgrade(#[pin] U::Future), inner: DispatcherState<T, S, B, X, U>
}
} }
#[pin_project(project = InnerDispatcherProj)] pin_project! {
struct InnerDispatcher<T, S, B, X, U> #[project = DispatcherStateProj]
where enum DispatcherState<T, S, B, X, U>
where
S: Service<Request>, S: Service<Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
B: MessageBody, B: MessageBody,
X: Service<Request, Response = Request>, X: Service<Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<(Request, Framed<T, Codec>), Response = ()>, U: Service<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display
{ {
Normal { #[pin] inner: InnerDispatcher<T, S, B, X, U>},
Upgrade { #[pin] upgrade: U::Future }
}
}
pin_project! {
#[project = InnerDispatcherProj]
struct InnerDispatcher<T, S, B, X, U>
where
S: Service<Request>,
S::Error: Into<Error>,
B: MessageBody,
X: Service<Request, Response = Request>,
X::Error: Into<Error>,
U: Service<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display
{
flow: Rc<HttpFlow<S, X, U>>, flow: Rc<HttpFlow<S, X, U>>,
on_connect_data: OnConnectData, on_connect_data: OnConnectData,
flags: Flags, flags: Flags,
@ -108,7 +129,8 @@ where
io: Option<T>, io: Option<T>,
read_buf: BytesMut, read_buf: BytesMut,
write_buf: BytesMut, write_buf: BytesMut,
codec: Codec, codec: Codec
}
} }
enum DispatcherMessage { enum DispatcherMessage {
@ -117,17 +139,19 @@ enum DispatcherMessage {
Error(Response<()>), Error(Response<()>),
} }
#[pin_project(project = StateProj)] pin_project! {
enum State<S, B, X> #[project = StateProj]
where enum State<S, B, X>
where
S: Service<Request>, S: Service<Request>,
X: Service<Request, Response = Request>, X: Service<Request, Response = Request>,
B: MessageBody, B: MessageBody
{ {
None, None,
ExpectCall(#[pin] X::Future), ExpectCall { #[pin] fut: X::Future },
ServiceCall(#[pin] S::Future), ServiceCall { #[pin] fut: S::Future },
SendPayload(#[pin] ResponseBody<B>), SendPayload { #[pin] body: ResponseBody<B> }
}
} }
impl<S, B, X> State<S, B, X> impl<S, B, X> State<S, B, X>
@ -141,7 +165,7 @@ where
} }
fn is_call(&self) -> bool { fn is_call(&self) -> bool {
matches!(self, State::ServiceCall(_)) matches!(self, State::ServiceCall { .. })
} }
} }
enum PollResponse { enum PollResponse {
@ -220,7 +244,8 @@ where
}; };
Dispatcher { Dispatcher {
inner: DispatcherState::Normal(InnerDispatcher { inner: DispatcherState::Normal {
inner: InnerDispatcher {
write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE), write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
payload: None, payload: None,
state: State::None, state: State::None,
@ -235,7 +260,8 @@ where
peer_addr, peer_addr,
ka_expire, ka_expire,
ka_timer, ka_timer,
}), },
},
#[cfg(test)] #[cfg(test)]
poll_count: 0, poll_count: 0,
@ -337,7 +363,7 @@ where
this.flags.set(Flags::KEEPALIVE, this.codec.keepalive()); this.flags.set(Flags::KEEPALIVE, this.codec.keepalive());
match body.size() { match body.size() {
BodySize::None | BodySize::Empty => this.state.set(State::None), BodySize::None | BodySize::Empty => this.state.set(State::None),
_ => this.state.set(State::SendPayload(body)), _ => this.state.set(State::SendPayload { body }),
}; };
Ok(()) Ok(())
} }
@ -363,8 +389,10 @@ where
true true
} }
Some(DispatcherMessage::Error(res)) => { Some(DispatcherMessage::Error(res)) => {
self.as_mut() self.as_mut().send_response(
.send_response(res, ResponseBody::Other(Body::Empty))?; res,
ResponseBody::Other { body: Body::Empty },
)?;
true true
} }
Some(DispatcherMessage::Upgrade(req)) => { Some(DispatcherMessage::Upgrade(req)) => {
@ -372,12 +400,12 @@ where
} }
None => false, None => false,
}, },
StateProj::ExpectCall(fut) => match fut.poll(cx) { StateProj::ExpectCall { fut } => match fut.poll(cx) {
Poll::Ready(Ok(req)) => { Poll::Ready(Ok(req)) => {
self.as_mut().send_continue(); self.as_mut().send_continue();
this = self.as_mut().project(); this = self.as_mut().project();
let fut = this.flow.service.call(req); let fut = this.flow.service.call(req);
this.state.set(State::ServiceCall(fut)); this.state.set(State::ServiceCall { fut });
continue; continue;
} }
Poll::Ready(Err(e)) => { Poll::Ready(Err(e)) => {
@ -388,7 +416,7 @@ where
} }
Poll::Pending => false, Poll::Pending => false,
}, },
StateProj::ServiceCall(fut) => match fut.poll(cx) { StateProj::ServiceCall { fut } => match fut.poll(cx) {
Poll::Ready(Ok(res)) => { Poll::Ready(Ok(res)) => {
let (res, body) = res.into().replace_body(()); let (res, body) = res.into().replace_body(());
self.as_mut().send_response(res, body)?; self.as_mut().send_response(res, body)?;
@ -402,10 +430,10 @@ where
} }
Poll::Pending => false, Poll::Pending => false,
}, },
StateProj::SendPayload(mut stream) => { StateProj::SendPayload { mut body } => {
loop { loop {
if this.write_buf.len() < super::payload::MAX_BUFFER_SIZE { if this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
match stream.as_mut().poll_next(cx) { match body.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(item))) => { Poll::Ready(Some(Ok(item))) => {
this.codec.encode( this.codec.encode(
Message::Chunk(Some(item)), Message::Chunk(Some(item)),
@ -466,26 +494,26 @@ where
if req.head().expect() { if req.head().expect() {
// set dispatcher state so the future is pinned. // set dispatcher state so the future is pinned.
let mut this = self.as_mut().project(); let mut this = self.as_mut().project();
let task = this.flow.expect.call(req); let fut = this.flow.expect.call(req);
this.state.set(State::ExpectCall(task)); this.state.set(State::ExpectCall { fut });
} else { } else {
// the same as above. // the same as above.
let mut this = self.as_mut().project(); let mut this = self.as_mut().project();
let task = this.flow.service.call(req); let fut = this.flow.service.call(req);
this.state.set(State::ServiceCall(task)); this.state.set(State::ServiceCall { fut });
}; };
// eagerly poll the future for once(or twice if expect is resolved immediately). // eagerly poll the future for once(or twice if expect is resolved immediately).
loop { loop {
match self.as_mut().project().state.project() { match self.as_mut().project().state.project() {
StateProj::ExpectCall(fut) => { StateProj::ExpectCall { fut } => {
match fut.poll(cx) { match fut.poll(cx) {
// expect is resolved. continue loop and poll the service call branch. // expect is resolved. continue loop and poll the service call branch.
Poll::Ready(Ok(req)) => { Poll::Ready(Ok(req)) => {
self.as_mut().send_continue(); self.as_mut().send_continue();
let mut this = self.as_mut().project(); let mut this = self.as_mut().project();
let task = this.flow.service.call(req); let fut = this.flow.service.call(req);
this.state.set(State::ServiceCall(task)); this.state.set(State::ServiceCall { fut });
continue; continue;
} }
// future is pending. return Ok(()) to notify that a new state is // future is pending. return Ok(()) to notify that a new state is
@ -502,7 +530,7 @@ where
} }
} }
} }
StateProj::ServiceCall(fut) => { StateProj::ServiceCall { fut } => {
// return no matter the service call future's result. // return no matter the service call future's result.
return match fut.poll(cx) { return match fut.poll(cx) {
// future is resolved. send response and return a result. On success // future is resolved. send response and return a result. On success
@ -707,7 +735,7 @@ where
trace!("Slow request timeout"); trace!("Slow request timeout");
let _ = self.as_mut().send_response( let _ = self.as_mut().send_response(
Response::RequestTimeout().finish().drop_body(), Response::RequestTimeout().finish().drop_body(),
ResponseBody::Other(Body::Empty), ResponseBody::Other { body: Body::Empty },
); );
this = self.project(); this = self.project();
} else { } else {
@ -832,7 +860,7 @@ where
} }
match this.inner.project() { match this.inner.project() {
DispatcherStateProj::Normal(mut inner) => { DispatcherStateProj::Normal { mut inner } => {
inner.as_mut().poll_keepalive(cx)?; inner.as_mut().poll_keepalive(cx)?;
if inner.flags.contains(Flags::SHUTDOWN) { if inner.flags.contains(Flags::SHUTDOWN) {
@ -876,7 +904,7 @@ where
self.as_mut() self.as_mut()
.project() .project()
.inner .inner
.set(DispatcherState::Upgrade(upgrade)); .set(DispatcherState::Upgrade { upgrade });
return self.poll(cx); return self.poll(cx);
} }
}; };
@ -928,7 +956,7 @@ where
} }
} }
} }
DispatcherStateProj::Upgrade(fut) => fut.poll(cx).map_err(|e| { DispatcherStateProj::Upgrade { upgrade } => upgrade.poll(cx).map_err(|e| {
error!("Upgrade handler error: {}", e); error!("Upgrade handler error: {}", e);
DispatchError::Upgrade DispatchError::Upgrade
}), }),
@ -1017,7 +1045,7 @@ mod tests {
Poll::Ready(res) => assert!(res.is_err()), Poll::Ready(res) => assert!(res.is_err()),
} }
if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() { if let DispatcherStateProj::Normal { inner } = h1.project().inner.project() {
assert!(inner.flags.contains(Flags::READ_DISCONNECT)); assert!(inner.flags.contains(Flags::READ_DISCONNECT));
assert_eq!( assert_eq!(
&inner.project().io.take().unwrap().write_buf[..26], &inner.project().io.take().unwrap().write_buf[..26],
@ -1052,7 +1080,7 @@ mod tests {
futures_util::pin_mut!(h1); futures_util::pin_mut!(h1);
assert!(matches!(&h1.inner, DispatcherState::Normal(_))); assert!(matches!(&h1.inner, DispatcherState::Normal { .. }));
match h1.as_mut().poll(cx) { match h1.as_mut().poll(cx) {
Poll::Pending => panic!("first poll should not be pending"), Poll::Pending => panic!("first poll should not be pending"),
@ -1062,7 +1090,7 @@ mod tests {
// polls: initial => shutdown // polls: initial => shutdown
assert_eq!(h1.poll_count, 2); assert_eq!(h1.poll_count, 2);
if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() { if let DispatcherStateProj::Normal { inner } = h1.project().inner.project() {
let res = &mut inner.project().io.take().unwrap().write_buf[..]; let res = &mut inner.project().io.take().unwrap().write_buf[..];
stabilize_date_header(res); stabilize_date_header(res);
@ -1106,7 +1134,7 @@ mod tests {
futures_util::pin_mut!(h1); futures_util::pin_mut!(h1);
assert!(matches!(&h1.inner, DispatcherState::Normal(_))); assert!(matches!(&h1.inner, DispatcherState::Normal { .. }));
match h1.as_mut().poll(cx) { match h1.as_mut().poll(cx) {
Poll::Pending => panic!("first poll should not be pending"), Poll::Pending => panic!("first poll should not be pending"),
@ -1116,7 +1144,7 @@ mod tests {
// polls: initial => shutdown // polls: initial => shutdown
assert_eq!(h1.poll_count, 1); assert_eq!(h1.poll_count, 1);
if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() { if let DispatcherStateProj::Normal { inner } = h1.project().inner.project() {
let res = &mut inner.project().io.take().unwrap().write_buf[..]; let res = &mut inner.project().io.take().unwrap().write_buf[..];
stabilize_date_header(res); stabilize_date_header(res);
@ -1166,13 +1194,13 @@ mod tests {
futures_util::pin_mut!(h1); futures_util::pin_mut!(h1);
assert!(h1.as_mut().poll(cx).is_pending()); assert!(h1.as_mut().poll(cx).is_pending());
assert!(matches!(&h1.inner, DispatcherState::Normal(_))); assert!(matches!(&h1.inner, DispatcherState::Normal { .. }));
// polls: manual // polls: manual
assert_eq!(h1.poll_count, 1); assert_eq!(h1.poll_count, 1);
eprintln!("poll count: {}", h1.poll_count); eprintln!("poll count: {}", h1.poll_count);
if let DispatcherState::Normal(ref inner) = h1.inner { if let DispatcherState::Normal { ref inner } = h1.inner {
let io = inner.io.as_ref().unwrap(); let io = inner.io.as_ref().unwrap();
let res = &io.write_buf()[..]; let res = &io.write_buf()[..];
assert_eq!( assert_eq!(
@ -1187,7 +1215,7 @@ mod tests {
// polls: manual manual shutdown // polls: manual manual shutdown
assert_eq!(h1.poll_count, 3); assert_eq!(h1.poll_count, 3);
if let DispatcherState::Normal(ref inner) = h1.inner { if let DispatcherState::Normal { ref inner } = h1.inner {
let io = inner.io.as_ref().unwrap(); let io = inner.io.as_ref().unwrap();
let mut res = (&io.write_buf()[..]).to_owned(); let mut res = (&io.write_buf()[..]).to_owned();
stabilize_date_header(&mut res); stabilize_date_header(&mut res);
@ -1238,12 +1266,12 @@ mod tests {
futures_util::pin_mut!(h1); futures_util::pin_mut!(h1);
assert!(h1.as_mut().poll(cx).is_ready()); assert!(h1.as_mut().poll(cx).is_ready());
assert!(matches!(&h1.inner, DispatcherState::Normal(_))); assert!(matches!(&h1.inner, DispatcherState::Normal { .. }));
// polls: manual shutdown // polls: manual shutdown
assert_eq!(h1.poll_count, 2); assert_eq!(h1.poll_count, 2);
if let DispatcherState::Normal(ref inner) = h1.inner { if let DispatcherState::Normal { ref inner } = h1.inner {
let io = inner.io.as_ref().unwrap(); let io = inner.io.as_ref().unwrap();
let mut res = (&io.write_buf()[..]).to_owned(); let mut res = (&io.write_buf()[..]).to_owned();
stabilize_date_header(&mut res); stabilize_date_header(&mut res);
@ -1299,7 +1327,7 @@ mod tests {
futures_util::pin_mut!(h1); futures_util::pin_mut!(h1);
assert!(h1.as_mut().poll(cx).is_ready()); assert!(h1.as_mut().poll(cx).is_ready());
assert!(matches!(&h1.inner, DispatcherState::Upgrade(_))); assert!(matches!(&h1.inner, DispatcherState::Upgrade { .. }));
// polls: manual shutdown // polls: manual shutdown
assert_eq!(h1.poll_count, 2); assert_eq!(h1.poll_count, 2);

View File

@ -10,6 +10,7 @@ use actix_rt::net::TcpStream;
use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory};
use futures_core::ready; use futures_core::ready;
use futures_util::future::ready; use futures_util::future::ready;
use pin_project_lite::pin_project;
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
@ -275,10 +276,10 @@ where
} }
} }
#[doc(hidden)] pin_project! {
#[pin_project::pin_project] #[doc(hidden)]
pub struct H1ServiceResponse<T, S, B, X, U> pub struct H1ServiceResponse<T, S, B, X, U>
where where
S: ServiceFactory<Request>, S: ServiceFactory<Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
@ -287,8 +288,8 @@ where
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Framed<T, Codec>), Response = ()>, U: ServiceFactory<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug
{ {
#[pin] #[pin]
fut: S::Future, fut: S::Future,
#[pin] #[pin]
@ -299,7 +300,8 @@ where
upgrade: Option<U::Service>, upgrade: Option<U::Service>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>, on_connect_ext: Option<Rc<ConnectCallback<T>>>,
cfg: Option<ServiceConfig>, cfg: Option<ServiceConfig>,
_phantom: PhantomData<B>, _phantom: PhantomData<B>
}
} }
impl<T, S, B, X, U> Future for H1ServiceResponse<T, S, B, X, U> impl<T, S, B, X, U> Future for H1ServiceResponse<T, S, B, X, U>

View File

@ -3,20 +3,22 @@ use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use pin_project_lite::pin_project;
use crate::body::{BodySize, MessageBody, ResponseBody}; use crate::body::{BodySize, MessageBody, ResponseBody};
use crate::error::Error; use crate::error::Error;
use crate::h1::{Codec, Message}; use crate::h1::{Codec, Message};
use crate::response::Response; use crate::response::Response;
/// Send HTTP/1 response pin_project! {
#[pin_project::pin_project] /// Send HTTP/1 response
pub struct SendResponse<T, B> { pub struct SendResponse<T, B> {
res: Option<Message<(Response<()>, BodySize)>>, res: Option<Message<(Response<()>, BodySize)>>,
#[pin] #[pin]
body: Option<ResponseBody<B>>, body: Option<ResponseBody<B>>,
#[pin] #[pin]
framed: Option<Framed<T, Codec>>, framed: Option<Framed<T, Codec>>
}
} }
impl<T, B> SendResponse<T, B> impl<T, B> SendResponse<T, B>

View File

@ -15,6 +15,7 @@ use h2::server::{Connection, SendResponse};
use h2::SendStream; use h2::SendStream;
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use log::{error, trace}; use log::{error, trace};
use pin_project_lite::pin_project;
use crate::body::{BodySize, MessageBody, ResponseBody}; use crate::body::{BodySize, MessageBody, ResponseBody};
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
@ -28,14 +29,16 @@ use crate::OnConnectData;
const CHUNK_SIZE: usize = 16_384; const CHUNK_SIZE: usize = 16_384;
/// Dispatcher for HTTP/2 protocol. pin_project! {
#[pin_project::pin_project] /// Dispatcher for HTTP/2 protocol.
pub struct Dispatcher<T, S, B, X, U> pub struct Dispatcher<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead,
T: AsyncWrite,
T: Unpin,
S: Service<Request>, S: Service<Request>,
B: MessageBody, B: MessageBody
{ {
flow: Rc<HttpFlow<S, X, U>>, flow: Rc<HttpFlow<S, X, U>>,
connection: Connection<T, Bytes>, connection: Connection<T, Bytes>,
on_connect_data: OnConnectData, on_connect_data: OnConnectData,
@ -43,7 +46,8 @@ where
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
ka_expire: Instant, ka_expire: Instant,
ka_timer: Option<Sleep>, ka_timer: Option<Sleep>,
_phantom: PhantomData<B>, _phantom: PhantomData<B>
}
} }
impl<T, S, B, X, U> Dispatcher<T, S, B, X, U> impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
@ -136,10 +140,10 @@ where
this.on_connect_data.merge_into(&mut req); this.on_connect_data.merge_into(&mut req);
let svc = ServiceResponse::<S::Future, S::Response, S::Error, B> { let svc = ServiceResponse::<S::Future, S::Response, S::Error, B> {
state: ServiceResponseState::ServiceCall( state: ServiceResponseState::ServiceCall {
this.flow.service.call(req), fut: this.flow.service.call(req),
Some(res), sender: Some(res),
), },
config: this.config.clone(), config: this.config.clone(),
buffer: None, buffer: None,
_phantom: PhantomData, _phantom: PhantomData,
@ -152,19 +156,30 @@ where
} }
} }
#[pin_project::pin_project] pin_project! {
struct ServiceResponse<F, I, E, B> { struct ServiceResponse<F, I, E, B> {
#[pin] #[pin]
state: ServiceResponseState<F, B>, state: ServiceResponseState<F, B>,
config: ServiceConfig, config: ServiceConfig,
buffer: Option<Bytes>, buffer: Option<Bytes>,
_phantom: PhantomData<(I, E)>, _phantom: PhantomData<(I, E)>,
}
} }
#[pin_project::pin_project(project = ServiceResponseStateProj)] pin_project! {
enum ServiceResponseState<F, B> { #[project = ServiceResponseStateProj]
ServiceCall(#[pin] F, Option<SendResponse<Bytes>>), enum ServiceResponseState<F, B> {
SendPayload(SendStream<Bytes>, #[pin] ResponseBody<B>), ServiceCall {
#[pin]
fut: F,
sender: Option<SendResponse<Bytes>>
},
SendPayload {
stream: SendStream<Bytes>,
#[pin]
body: ResponseBody<B>
},
}
} }
impl<F, I, E, B> ServiceResponse<F, I, E, B> impl<F, I, E, B> ServiceResponse<F, I, E, B>
@ -250,12 +265,12 @@ where
let mut this = self.as_mut().project(); let mut this = self.as_mut().project();
match this.state.project() { match this.state.project() {
ServiceResponseStateProj::ServiceCall(call, send) => { ServiceResponseStateProj::ServiceCall { fut, sender } => {
match ready!(call.poll(cx)) { match ready!(fut.poll(cx)) {
Ok(res) => { Ok(res) => {
let (res, body) = res.into().replace_body(()); let (res, body) = res.into().replace_body(());
let mut send = send.take().unwrap(); let mut send = sender.take().unwrap();
let mut size = body.size(); let mut size = body.size();
let h2_res = let h2_res =
self.as_mut().prepare_response(res.head(), &mut size); self.as_mut().prepare_response(res.head(), &mut size);
@ -273,7 +288,7 @@ where
Poll::Ready(()) Poll::Ready(())
} else { } else {
this.state this.state
.set(ServiceResponseState::SendPayload(stream, body)); .set(ServiceResponseState::SendPayload { stream, body });
self.poll(cx) self.poll(cx)
} }
} }
@ -282,7 +297,7 @@ where
let res: Response = e.into().into(); let res: Response = e.into().into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
let mut send = send.take().unwrap(); let mut send = sender.take().unwrap();
let mut size = body.size(); let mut size = body.size();
let h2_res = let h2_res =
self.as_mut().prepare_response(res.head(), &mut size); self.as_mut().prepare_response(res.head(), &mut size);
@ -299,22 +314,20 @@ where
if size.is_eof() { if size.is_eof() {
Poll::Ready(()) Poll::Ready(())
} else { } else {
this.state.set(ServiceResponseState::SendPayload( this.state.set(ServiceResponseState::SendPayload {
stream, stream,
body.into_body(), body: body.into_body(),
)); });
self.poll(cx) self.poll(cx)
} }
} }
} }
} }
ServiceResponseStateProj::SendPayload(ref mut stream, ref mut body) => { ServiceResponseStateProj::SendPayload { stream, mut body } => loop {
loop {
loop { loop {
match this.buffer { match this.buffer {
Some(ref mut buffer) => { Some(ref mut buffer) => match ready!(stream.poll_capacity(cx)) {
match ready!(stream.poll_capacity(cx)) {
None => return Poll::Ready(()), None => return Poll::Ready(()),
Some(Ok(cap)) => { Some(Ok(cap)) => {
@ -336,23 +349,19 @@ where
warn!("{:?}", e); warn!("{:?}", e);
return Poll::Ready(()); return Poll::Ready(());
} }
} },
}
None => match ready!(body.as_mut().poll_next(cx)) { None => match ready!(body.as_mut().poll_next(cx)) {
None => { None => {
if let Err(e) = stream.send_data(Bytes::new(), true) if let Err(e) = stream.send_data(Bytes::new(), true) {
{
warn!("{:?}", e); warn!("{:?}", e);
} }
return Poll::Ready(()); return Poll::Ready(());
} }
Some(Ok(chunk)) => { Some(Ok(chunk)) => {
stream.reserve_capacity(cmp::min( stream
chunk.len(), .reserve_capacity(cmp::min(chunk.len(), CHUNK_SIZE));
CHUNK_SIZE,
));
*this.buffer = Some(chunk); *this.buffer = Some(chunk);
} }
@ -363,8 +372,7 @@ where
}, },
} }
} }
} },
}
} }
} }
} }

View File

@ -15,6 +15,7 @@ use futures_core::ready;
use futures_util::future::ok; use futures_util::future::ok;
use h2::server::{self, Handshake}; use h2::server::{self, Handshake};
use log::error; use log::error;
use pin_project_lite::pin_project;
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
@ -205,17 +206,18 @@ where
} }
} }
#[doc(hidden)] pin_project! {
#[pin_project::pin_project] #[doc(hidden)]
pub struct H2ServiceResponse<T, S, B> pub struct H2ServiceResponse<T, S, B>
where where
S: ServiceFactory<Request>, S: ServiceFactory<Request>,
{ {
#[pin] #[pin]
fut: S::Future, fut: S::Future,
cfg: Option<ServiceConfig>, cfg: Option<ServiceConfig>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>, on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_phantom: PhantomData<B>, _phantom: PhantomData<B>
}
} }
impl<T, S, B> Future for H2ServiceResponse<T, S, B> impl<T, S, B> Future for H2ServiceResponse<T, S, B>

View File

@ -49,7 +49,7 @@ impl Response<Body> {
pub fn new(status: StatusCode) -> Response { pub fn new(status: StatusCode) -> Response {
Response { Response {
head: BoxedResponseHead::new(status), head: BoxedResponseHead::new(status),
body: ResponseBody::Body(Body::Empty), body: ResponseBody::Body { body: Body::Empty },
error: None, error: None,
} }
} }
@ -67,14 +67,14 @@ impl Response<Body> {
/// Convert response to response with body /// Convert response to response with body
pub fn into_body<B>(self) -> Response<B> { pub fn into_body<B>(self) -> Response<B> {
let b = match self.body { let body = match self.body {
ResponseBody::Body(b) => b, ResponseBody::Body { body } => body,
ResponseBody::Other(b) => b, ResponseBody::Other { body } => body,
}; };
Response { Response {
head: self.head, head: self.head,
error: self.error, error: self.error,
body: ResponseBody::Other(b), body: ResponseBody::Other { body },
} }
} }
} }
@ -85,7 +85,7 @@ impl<B> Response<B> {
pub fn with_body(status: StatusCode, body: B) -> Response<B> { pub fn with_body(status: StatusCode, body: B) -> Response<B> {
Response { Response {
head: BoxedResponseHead::new(status), head: BoxedResponseHead::new(status),
body: ResponseBody::Body(body), body: ResponseBody::Body { body },
error: None, error: None,
} }
} }
@ -210,7 +210,7 @@ impl<B> Response<B> {
pub fn set_body<B2>(self, body: B2) -> Response<B2> { pub fn set_body<B2>(self, body: B2) -> Response<B2> {
Response { Response {
head: self.head, head: self.head,
body: ResponseBody::Body(body), body: ResponseBody::Body { body },
error: None, error: None,
} }
} }
@ -220,7 +220,7 @@ impl<B> Response<B> {
( (
Response { Response {
head: self.head, head: self.head,
body: ResponseBody::Body(()), body: ResponseBody::Body { body: () },
error: self.error, error: self.error,
}, },
self.body, self.body,
@ -231,7 +231,7 @@ impl<B> Response<B> {
pub fn drop_body(self) -> Response<()> { pub fn drop_body(self) -> Response<()> {
Response { Response {
head: self.head, head: self.head,
body: ResponseBody::Body(()), body: ResponseBody::Body { body: () },
error: None, error: None,
} }
} }
@ -241,7 +241,7 @@ impl<B> Response<B> {
( (
Response { Response {
head: self.head, head: self.head,
body: ResponseBody::Body(body), body: ResponseBody::Body { body },
error: self.error, error: self.error,
}, },
self.body, self.body,
@ -635,7 +635,7 @@ impl ResponseBuilder {
Response { Response {
head: response, head: response,
body: ResponseBody::Body(body), body: ResponseBody::Body { body },
error: None, error: None,
} }
} }

View File

@ -9,7 +9,7 @@ use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactor
use bytes::Bytes; use bytes::Bytes;
use futures_core::{ready, Future}; use futures_core::{ready, Future};
use h2::server::{self, Handshake}; use h2::server::{self, Handshake};
use pin_project::pin_project; use pin_project_lite::pin_project;
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::builder::HttpServiceBuilder; use crate::builder::HttpServiceBuilder;
@ -351,14 +351,14 @@ where
} }
} }
#[doc(hidden)] pin_project! {
#[pin_project] #[doc(hidden)]
pub struct HttpServiceResponse<T, S, B, X, U> pub struct HttpServiceResponse<T, S, B, X, U>
where where
S: ServiceFactory<Request>, S: ServiceFactory<Request>,
X: ServiceFactory<Request>, X: ServiceFactory<Request>,
U: ServiceFactory<(Request, Framed<T, h1::Codec>)>, U: ServiceFactory<(Request, Framed<T, h1::Codec>)>
{ {
#[pin] #[pin]
fut: S::Future, fut: S::Future,
#[pin] #[pin]
@ -369,7 +369,8 @@ where
upgrade: Option<U::Service>, upgrade: Option<U::Service>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>, on_connect_ext: Option<Rc<ConnectCallback<T>>>,
cfg: ServiceConfig, cfg: ServiceConfig,
_phantom: PhantomData<B>, _phantom: PhantomData<B>
}
} }
impl<T, S, B, X, U> Future for HttpServiceResponse<T, S, B, X, U> impl<T, S, B, X, U> Future for HttpServiceResponse<T, S, B, X, U>
@ -561,23 +562,27 @@ where
match proto { match proto {
Protocol::Http2 => HttpServiceHandlerResponse { Protocol::Http2 => HttpServiceHandlerResponse {
state: State::H2Handshake(Some(( state: State::H2Handshake {
hds: Some((
server::handshake(io), server::handshake(io),
self.cfg.clone(), self.cfg.clone(),
self.flow.clone(), self.flow.clone(),
on_connect_data, on_connect_data,
peer_addr, peer_addr,
))), )),
},
}, },
Protocol::Http1 => HttpServiceHandlerResponse { Protocol::Http1 => HttpServiceHandlerResponse {
state: State::H1(h1::Dispatcher::new( state: State::H1 {
dsp: h1::Dispatcher::new(
io, io,
self.cfg.clone(), self.cfg.clone(),
self.flow.clone(), self.flow.clone(),
on_connect_data, on_connect_data,
peer_addr, peer_addr,
)), ),
},
}, },
proto => unimplemented!("Unsupported HTTP version: {:?}.", proto), proto => unimplemented!("Unsupported HTTP version: {:?}.", proto),
@ -585,48 +590,58 @@ where
} }
} }
#[pin_project(project = StateProj)] pin_project! {
enum State<T, S, B, X, U> #[project = StateProj]
where enum State<T, S, B, X, U>
where
S: Service<Request>, S: Service<Request>,
S::Future: 'static, S::Future: 'static,
S::Error: Into<Error>, S::Error: Into<Error>,
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead,
T: AsyncWrite,
T: Unpin,
B: MessageBody, B: MessageBody,
X: Service<Request, Response = Request>, X: Service<Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<(Request, Framed<T, h1::Codec>), Response = ()>, U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display
{ {
H1(#[pin] h1::Dispatcher<T, S, B, X, U>), H1 { #[pin] dsp: h1::Dispatcher<T, S, B, X, U> },
H2(#[pin] Dispatcher<T, S, B, X, U>), H2 { #[pin] dsp: Dispatcher<T, S, B, X, U> },
H2Handshake( H2Handshake {
Option<( hds: Option<(
Handshake<T, Bytes>, Handshake<T, Bytes>,
ServiceConfig, ServiceConfig,
Rc<HttpFlow<S, X, U>>, Rc<HttpFlow<S, X, U>>,
OnConnectData, OnConnectData,
Option<net::SocketAddr>, Option<net::SocketAddr>,
)>, )>
), }
}
} }
#[pin_project] pin_project! {
pub struct HttpServiceHandlerResponse<T, S, B, X, U> pub struct HttpServiceHandlerResponse<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead,
T: AsyncWrite,
T: Unpin,
S: Service<Request>, S: Service<Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Error: 'static,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, S::Response: 'static,
B: MessageBody,
B: 'static,
X: Service<Request, Response = Request>, X: Service<Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<(Request, Framed<T, h1::Codec>), Response = ()>, U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display
{ {
#[pin] #[pin]
state: State<T, S, B, X, U>, state: State<T, S, B, X, U>
}
} }
impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U> impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
@ -646,21 +661,23 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project().state.project() { match self.as_mut().project().state.project() {
StateProj::H1(disp) => disp.poll(cx), StateProj::H1 { dsp } => dsp.poll(cx),
StateProj::H2(disp) => disp.poll(cx), StateProj::H2 { dsp } => dsp.poll(cx),
StateProj::H2Handshake(data) => { StateProj::H2Handshake { hds } => {
match ready!(Pin::new(&mut data.as_mut().unwrap().0).poll(cx)) { match ready!(Pin::new(&mut hds.as_mut().unwrap().0).poll(cx)) {
Ok(conn) => { Ok(conn) => {
let (_, cfg, srv, on_connect_data, peer_addr) = let (_, cfg, srv, on_connect_data, peer_addr) =
data.take().unwrap(); hds.take().unwrap();
self.as_mut().project().state.set(State::H2(Dispatcher::new( self.as_mut().project().state.set(State::H2 {
dsp: Dispatcher::new(
srv, srv,
conn, conn,
on_connect_data, on_connect_data,
cfg, cfg,
None, None,
peer_addr, peer_addr,
))); ),
});
self.poll(cx) self.poll(cx)
} }
Err(err) => { Err(err) => {

View File

@ -5,17 +5,21 @@ use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_service::{IntoService, Service}; use actix_service::{IntoService, Service};
use actix_utils::dispatcher::{Dispatcher as InnerDispatcher, DispatcherError}; use actix_utils::dispatcher::{Dispatcher as InnerDispatcher, DispatcherError};
use pin_project_lite::pin_project;
use super::{Codec, Frame, Message}; use super::{Codec, Frame, Message};
#[pin_project::pin_project] pin_project! {
pub struct Dispatcher<S, T> pub struct Dispatcher<S, T>
where where
S: Service<Frame, Response = Message> + 'static, S: Service<Frame, Response = Message>,
T: AsyncRead + AsyncWrite, S: 'static,
{ T: AsyncRead,
T: AsyncWrite,
{
#[pin] #[pin]
inner: InnerDispatcher<S, T, Codec, Message>, inner: InnerDispatcher<S, T, Codec, Message>
}
} }
impl<S, T> Dispatcher<S, T> impl<S, T> Dispatcher<S, T>

View File

@ -117,7 +117,9 @@ pub trait MapServiceResponseBody {
impl<B: MessageBody + Unpin + 'static> MapServiceResponseBody for ServiceResponse<B> { impl<B: MessageBody + Unpin + 'static> MapServiceResponseBody for ServiceResponse<B> {
fn map_body(self) -> ServiceResponse { fn map_body(self) -> ServiceResponse {
self.map_body(|_, body| ResponseBody::Other(Body::from_message(body))) self.map_body(|_, body| ResponseBody::Other {
body: Body::from_message(body),
})
} }
} }

View File

@ -289,13 +289,13 @@ where
let time = *this.time; let time = *this.time;
let format = this.format.take(); let format = this.format.take();
Poll::Ready(Ok(res.map_body(move |_, body| { Poll::Ready(Ok(res.map_body(move |_, body| ResponseBody::Body {
ResponseBody::Body(StreamLog { body: StreamLog {
body, body,
time, time,
format, format,
size: 0, size: 0,
}) },
}))) })))
} }
} }

View File

@ -277,7 +277,9 @@ pub(crate) mod tests {
let resp = srv.call(req).await.unwrap(); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
match resp.response().body() { match resp.response().body() {
ResponseBody::Body(Body::Bytes(ref b)) => { ResponseBody::Body {
body: Body::Bytes(ref b),
} => {
let bytes = b.clone(); let bytes = b.clone();
assert_eq!(bytes, Bytes::from_static(b"some")); assert_eq!(bytes, Bytes::from_static(b"some"));
} }
@ -292,21 +294,21 @@ pub(crate) mod tests {
impl BodyTest for ResponseBody<Body> { impl BodyTest for ResponseBody<Body> {
fn bin_ref(&self) -> &[u8] { fn bin_ref(&self) -> &[u8] {
match self { match *self {
ResponseBody::Body(ref b) => match b { ResponseBody::Body { ref body } => match body {
Body::Bytes(ref bin) => &bin, Body::Bytes(ref bin) => &bin,
_ => panic!(), _ => panic!(),
}, },
ResponseBody::Other(ref b) => match b { ResponseBody::Other { ref body } => match body {
Body::Bytes(ref bin) => &bin, Body::Bytes(ref bin) => &bin,
_ => panic!(), _ => panic!(),
}, },
} }
} }
fn body(&self) -> &Body { fn body(&self) -> &Body {
match self { match *self {
ResponseBody::Body(ref b) => b, ResponseBody::Body { ref body } => body,
ResponseBody::Other(ref b) => b, ResponseBody::Other { ref body } => body,
} }
} }
} }

View File

@ -748,7 +748,9 @@ mod tests {
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
match resp.response().body() { match resp.response().body() {
ResponseBody::Body(Body::Bytes(ref b)) => { ResponseBody::Body {
body: Body::Bytes(ref b),
} => {
let bytes = b.clone(); let bytes = b.clone();
assert_eq!(bytes, Bytes::from_static(b"project: project1")); assert_eq!(bytes, Bytes::from_static(b"project: project1"));
} }
@ -849,7 +851,9 @@ mod tests {
assert_eq!(resp.status(), StatusCode::CREATED); assert_eq!(resp.status(), StatusCode::CREATED);
match resp.response().body() { match resp.response().body() {
ResponseBody::Body(Body::Bytes(ref b)) => { ResponseBody::Body {
body: Body::Bytes(ref b),
} => {
let bytes = b.clone(); let bytes = b.clone();
assert_eq!(bytes, Bytes::from_static(b"project: project_1")); assert_eq!(bytes, Bytes::from_static(b"project: project_1"));
} }
@ -877,7 +881,9 @@ mod tests {
assert_eq!(resp.status(), StatusCode::CREATED); assert_eq!(resp.status(), StatusCode::CREATED);
match resp.response().body() { match resp.response().body() {
ResponseBody::Body(Body::Bytes(ref b)) => { ResponseBody::Body {
body: Body::Bytes(ref b),
} => {
let bytes = b.clone(); let bytes = b.clone();
assert_eq!(bytes, Bytes::from_static(b"project: test - 1")); assert_eq!(bytes, Bytes::from_static(b"project: test - 1"));
} }