add interval check for payload pending

This commit is contained in:
fakeshadow 2021-02-27 22:29:40 +08:00
parent 8f2a97c6e3
commit 2bd42d7002
5 changed files with 99 additions and 33 deletions

View File

@ -6,10 +6,14 @@ use std::{
pin::Pin, pin::Pin,
rc::Rc, rc::Rc,
task::{Context, Poll}, task::{Context, Poll},
time::Duration,
}; };
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; use actix_codec::{Decoder, Encoder, Framed, FramedParts};
use actix_rt::time::{sleep_until, Instant, Sleep}; use actix_rt::{
net::ActixStream,
time::{sleep_until, Instant, Sleep},
};
use actix_service::Service; use actix_service::Service;
use bitflags::bitflags; use bitflags::bitflags;
use bytes::{Buf, BytesMut}; use bytes::{Buf, BytesMut};
@ -41,6 +45,7 @@ bitflags! {
const SHUTDOWN = 0b0000_0100; const SHUTDOWN = 0b0000_0100;
const READ_DISCONNECT = 0b0000_1000; const READ_DISCONNECT = 0b0000_1000;
const WRITE_DISCONNECT = 0b0001_0000; const WRITE_DISCONNECT = 0b0001_0000;
const PAYLOAD_PENDING = 0b0010_0000;
} }
} }
@ -148,7 +153,7 @@ enum PollResponse {
impl<T, S, B, X, U> Dispatcher<T, S, B, X, U> impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: ActixStream,
S: Service<Request>, S: Service<Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
@ -204,7 +209,7 @@ where
impl<T, S, B, X, U> InnerDispatcher<T, S, B, X, U> impl<T, S, B, X, U> InnerDispatcher<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: ActixStream,
S: Service<Request>, S: Service<Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
@ -369,6 +374,9 @@ where
while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE { while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
match stream.as_mut().poll_next(cx) { match stream.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(item))) => { Poll::Ready(Some(Ok(item))) => {
// any state other than pending should
// remove the PAYLOAD_PENDING flag
this.flags.remove(Flags::PAYLOAD_PENDING);
this.codec.encode( this.codec.encode(
Message::Chunk(Some(item)), Message::Chunk(Some(item)),
&mut this.write_buf, &mut this.write_buf,
@ -376,6 +384,7 @@ where
} }
Poll::Ready(None) => { Poll::Ready(None) => {
this.flags.remove(Flags::PAYLOAD_PENDING);
this.codec this.codec
.encode(Message::Chunk(None), &mut this.write_buf)?; .encode(Message::Chunk(None), &mut this.write_buf)?;
// payload stream finished. // payload stream finished.
@ -385,10 +394,35 @@ where
} }
Poll::Ready(Some(Err(err))) => { Poll::Ready(Some(Err(err))) => {
return Err(DispatchError::Service(err)) this.flags.remove(Flags::PAYLOAD_PENDING);
return Err(DispatchError::Service(err));
} }
Poll::Pending => return Ok(PollResponse::DoNothing), // Payload is pending. register a timer for wake up
// dispatcher in interval and check connection status.
Poll::Pending => {
// write pending flag and configure ka_timer to the
// nearest deadline
if !this.flags.contains(Flags::PAYLOAD_PENDING) {
this.flags.insert(Flags::PAYLOAD_PENDING);
// pending check use 1 second timer interval.
let target = Instant::now() + Duration::from_secs(1);
// reset the ka_timer to be used as interval.
match this.ka_timer.as_mut().as_pin_mut() {
Some(timer) => timer.reset(target),
None => {
this.ka_timer.set(Some(sleep_until(target)))
}
}
// poll the timer to register the interval.
self.poll_keepalive(cx)?;
}
return Ok(PollResponse::DoNothing);
}
} }
} }
// buffer is beyond max size. // buffer is beyond max size.
@ -659,8 +693,19 @@ where
Some(mut timer) => { Some(mut timer) => {
// only operate when keep-alive timer is resolved. // only operate when keep-alive timer is resolved.
if timer.as_mut().poll(cx).is_ready() { if timer.as_mut().poll(cx).is_ready() {
// payload is pending and it's time to check the ready state of io.
if this.flags.contains(Flags::PAYLOAD_PENDING) {
// only interest in the error type.
// The io is ready or not is not important.
let _ =
Pin::new(this.io.as_mut().unwrap()).poll_read_ready(cx)?;
// reset the interval and check again after 1 second.
timer
.as_mut()
.reset(Instant::now() + Duration::from_secs(1));
let _ = timer.poll(cx);
// got timeout during shutdown, drop connection // got timeout during shutdown, drop connection
if this.flags.contains(Flags::SHUTDOWN) { } else if this.flags.contains(Flags::SHUTDOWN) {
return Err(DispatchError::DisconnectTimeout); return Err(DispatchError::DisconnectTimeout);
// exceed deadline. check for any outstanding tasks // exceed deadline. check for any outstanding tasks
} else if timer.deadline() >= *this.ka_expire { } else if timer.deadline() >= *this.ka_expire {
@ -824,7 +869,7 @@ where
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U> impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: ActixStream,
S: Service<Request>, S: Service<Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,

View File

@ -5,8 +5,8 @@ use std::rc::Rc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{fmt, net}; use std::{fmt, net};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::Framed;
use actix_rt::net::TcpStream; use actix_rt::net::{ActixStream, TcpStream};
use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory};
use futures_core::ready; use futures_core::ready;
use futures_util::future::ready; use futures_util::future::ready;
@ -94,10 +94,10 @@ mod openssl {
use super::*; use super::*;
use actix_service::ServiceFactoryExt; use actix_service::ServiceFactoryExt;
use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream};
use actix_tls::accept::TlsError; use actix_tls::accept::TlsError;
impl<S, B, X, U> H1Service<SslStream<TcpStream>, S, B, X, U> impl<S, B, X, U> H1Service<TlsStream<TcpStream>, S, B, X, U>
where where
S: ServiceFactory<Request, Config = ()>, S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Error>, S::Error: Into<Error>,
@ -108,7 +108,7 @@ mod openssl {
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory< U: ServiceFactory<
(Request, Framed<SslStream<TcpStream>, Codec>), (Request, Framed<TlsStream<TcpStream>, Codec>),
Config = (), Config = (),
Response = (), Response = (),
>, >,
@ -131,7 +131,7 @@ mod openssl {
.map_err(TlsError::Tls) .map_err(TlsError::Tls)
.map_init_err(|_| panic!()), .map_init_err(|_| panic!()),
) )
.and_then(|io: SslStream<TcpStream>| { .and_then(|io: TlsStream<TcpStream>| {
let peer_addr = io.get_ref().peer_addr().ok(); let peer_addr = io.get_ref().peer_addr().ok();
ready(Ok((io, peer_addr))) ready(Ok((io, peer_addr)))
}) })
@ -241,7 +241,7 @@ where
impl<T, S, B, X, U> ServiceFactory<(T, Option<net::SocketAddr>)> impl<T, S, B, X, U> ServiceFactory<(T, Option<net::SocketAddr>)>
for H1Service<T, S, B, X, U> for H1Service<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: ActixStream,
S: ServiceFactory<Request, Config = ()>, S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
@ -304,7 +304,7 @@ where
impl<T, S, B, X, U> Future for H1ServiceResponse<T, S, B, X, U> impl<T, S, B, X, U> Future for H1ServiceResponse<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: ActixStream,
S: ServiceFactory<Request>, S: ServiceFactory<Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
@ -402,7 +402,7 @@ where
impl<T, S, B, X, U> Service<(T, Option<net::SocketAddr>)> impl<T, S, B, X, U> Service<(T, Option<net::SocketAddr>)>
for H1ServiceHandler<T, S, B, X, U> for H1ServiceHandler<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: ActixStream,
S: Service<Request>, S: Service<Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,

View File

@ -93,12 +93,12 @@ where
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
mod openssl { mod openssl {
use actix_service::{fn_factory, fn_service, ServiceFactoryExt}; use actix_service::{fn_factory, fn_service, ServiceFactoryExt};
use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream};
use actix_tls::accept::TlsError; use actix_tls::accept::TlsError;
use super::*; use super::*;
impl<S, B> H2Service<SslStream<TcpStream>, S, B> impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
where where
S: ServiceFactory<Request, Config = ()>, S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Error> + 'static, S::Error: Into<Error> + 'static,
@ -123,7 +123,7 @@ mod openssl {
.map_init_err(|_| panic!()), .map_init_err(|_| panic!()),
) )
.and_then(fn_factory(|| { .and_then(fn_factory(|| {
ok::<_, S::InitError>(fn_service(|io: SslStream<TcpStream>| { ok::<_, S::InitError>(fn_service(|io: TlsStream<TcpStream>| {
let peer_addr = io.get_ref().peer_addr().ok(); let peer_addr = io.get_ref().peer_addr().ok();
ok((io, peer_addr)) ok((io, peer_addr))
})) }))

View File

@ -3,8 +3,8 @@ use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{fmt, net, rc::Rc}; use std::{fmt, net, rc::Rc};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::Framed;
use actix_rt::net::TcpStream; use actix_rt::net::{ActixStream, TcpStream};
use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory};
use bytes::Bytes; use bytes::Bytes;
use futures_core::{ready, Future}; use futures_core::{ready, Future};
@ -185,10 +185,10 @@ where
mod openssl { mod openssl {
use super::*; use super::*;
use actix_service::ServiceFactoryExt; use actix_service::ServiceFactoryExt;
use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream};
use actix_tls::accept::TlsError; use actix_tls::accept::TlsError;
impl<S, B, X, U> HttpService<SslStream<TcpStream>, S, B, X, U> impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
where where
S: ServiceFactory<Request, Config = ()>, S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Error> + 'static, S::Error: Into<Error> + 'static,
@ -201,13 +201,13 @@ mod openssl {
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
<X::Service as Service<Request>>::Future: 'static, <X::Service as Service<Request>>::Future: 'static,
U: ServiceFactory< U: ServiceFactory<
(Request, Framed<SslStream<TcpStream>, h1::Codec>), (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
Config = (), Config = (),
Response = (), Response = (),
>, >,
U::Error: fmt::Display + Into<Error>, U::Error: fmt::Display + Into<Error>,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
<U::Service as Service<(Request, Framed<SslStream<TcpStream>, h1::Codec>)>>::Future: 'static, <U::Service as Service<(Request, Framed<TlsStream<TcpStream>, h1::Codec>)>>::Future: 'static,
{ {
/// Create openssl based service /// Create openssl based service
pub fn openssl( pub fn openssl(
@ -225,7 +225,7 @@ mod openssl {
.map_err(TlsError::Tls) .map_err(TlsError::Tls)
.map_init_err(|_| panic!()), .map_init_err(|_| panic!()),
) )
.and_then(|io: SslStream<TcpStream>| async { .and_then(|io: TlsStream<TcpStream>| async {
let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() { let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() {
if protos.windows(2).any(|window| window == b"h2") { if protos.windows(2).any(|window| window == b"h2") {
Protocol::Http2 Protocol::Http2
@ -314,7 +314,7 @@ mod rustls {
impl<T, S, B, X, U> ServiceFactory<(T, Protocol, Option<net::SocketAddr>)> impl<T, S, B, X, U> ServiceFactory<(T, Protocol, Option<net::SocketAddr>)>
for HttpService<T, S, B, X, U> for HttpService<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: ActixStream,
S: ServiceFactory<Request, Config = ()>, S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Error> + 'static, S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
@ -374,7 +374,7 @@ where
impl<T, S, B, X, U> Future for HttpServiceResponse<T, S, B, X, U> impl<T, S, B, X, U> Future for HttpServiceResponse<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: ActixStream,
S: ServiceFactory<Request>, S: ServiceFactory<Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
@ -493,7 +493,7 @@ where
impl<T, S, B, X, U> Service<(T, Protocol, Option<net::SocketAddr>)> impl<T, S, B, X, U> Service<(T, Protocol, Option<net::SocketAddr>)>
for HttpServiceHandler<T, S, B, X, U> for HttpServiceHandler<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: ActixStream,
S: Service<Request>, S: Service<Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error> + 'static,
S::Future: 'static, S::Future: 'static,
@ -591,7 +591,7 @@ where
S: Service<Request>, S: Service<Request>,
S::Future: 'static, S::Future: 'static,
S::Error: Into<Error>, S::Error: Into<Error>,
T: AsyncRead + AsyncWrite + Unpin, T: ActixStream,
B: MessageBody, B: MessageBody,
X: Service<Request, Response = Request>, X: Service<Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
@ -614,7 +614,7 @@ where
#[pin_project] #[pin_project]
pub struct HttpServiceHandlerResponse<T, S, B, X, U> pub struct HttpServiceHandlerResponse<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: ActixStream,
S: Service<Request>, S: Service<Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error> + 'static,
S::Future: 'static, S::Future: 'static,
@ -631,7 +631,7 @@ where
impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U> impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: ActixStream,
S: Service<Request>, S: Service<Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error> + 'static,
S::Future: 'static, S::Future: 'static,

View File

@ -10,6 +10,7 @@ use std::{
}; };
use actix_codec::{AsyncRead, AsyncWrite, ReadBuf}; use actix_codec::{AsyncRead, AsyncWrite, ReadBuf};
use actix_rt::net::ActixStream;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use http::{Method, Uri, Version}; use http::{Method, Uri, Version};
@ -396,3 +397,23 @@ impl AsyncWrite for TestSeqBuffer {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
} }
impl ActixStream for TestBuffer {
fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_write_ready(&self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
impl ActixStream for TestSeqBuffer {
fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_write_ready(&self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}