diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 6df579c0a..40c509fdc 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -6,10 +6,14 @@ use std::{ pin::Pin, rc::Rc, task::{Context, Poll}, + time::Duration, }; -use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; -use actix_rt::time::{sleep_until, Instant, Sleep}; +use actix_codec::{Decoder, Encoder, Framed, FramedParts}; +use actix_rt::{ + net::ActixStream, + time::{sleep_until, Instant, Sleep}, +}; use actix_service::Service; use bitflags::bitflags; use bytes::{Buf, BytesMut}; @@ -41,6 +45,7 @@ bitflags! { const SHUTDOWN = 0b0000_0100; const READ_DISCONNECT = 0b0000_1000; const WRITE_DISCONNECT = 0b0001_0000; + const PAYLOAD_PENDING = 0b0010_0000; } } @@ -148,7 +153,7 @@ enum PollResponse { impl Dispatcher where - T: AsyncRead + AsyncWrite + Unpin, + T: ActixStream, S: Service, S::Error: Into, S::Response: Into>, @@ -204,7 +209,7 @@ where impl InnerDispatcher where - T: AsyncRead + AsyncWrite + Unpin, + T: ActixStream, S: Service, S::Error: Into, S::Response: Into>, @@ -369,6 +374,9 @@ where while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE { match stream.as_mut().poll_next(cx) { Poll::Ready(Some(Ok(item))) => { + // any state other than pending should + // remove the PAYLOAD_PENDING flag + this.flags.remove(Flags::PAYLOAD_PENDING); this.codec.encode( Message::Chunk(Some(item)), &mut this.write_buf, @@ -376,6 +384,7 @@ where } Poll::Ready(None) => { + this.flags.remove(Flags::PAYLOAD_PENDING); this.codec .encode(Message::Chunk(None), &mut this.write_buf)?; // payload stream finished. @@ -385,10 +394,35 @@ where } Poll::Ready(Some(Err(err))) => { - return Err(DispatchError::Service(err)) + this.flags.remove(Flags::PAYLOAD_PENDING); + return Err(DispatchError::Service(err)); } - Poll::Pending => return Ok(PollResponse::DoNothing), + // Payload is pending. register a timer for wake up + // dispatcher in interval and check connection status. + Poll::Pending => { + // write pending flag and configure ka_timer to the + // nearest deadline + if !this.flags.contains(Flags::PAYLOAD_PENDING) { + this.flags.insert(Flags::PAYLOAD_PENDING); + + // pending check use 1 second timer interval. + let target = Instant::now() + Duration::from_secs(1); + + // reset the ka_timer to be used as interval. + match this.ka_timer.as_mut().as_pin_mut() { + Some(timer) => timer.reset(target), + None => { + this.ka_timer.set(Some(sleep_until(target))) + } + } + + // poll the timer to register the interval. + self.poll_keepalive(cx)?; + } + + return Ok(PollResponse::DoNothing); + } } } // buffer is beyond max size. @@ -659,8 +693,19 @@ where Some(mut timer) => { // only operate when keep-alive timer is resolved. if timer.as_mut().poll(cx).is_ready() { + // payload is pending and it's time to check the ready state of io. + if this.flags.contains(Flags::PAYLOAD_PENDING) { + // only interest in the error type. + // The io is ready or not is not important. + let _ = + Pin::new(this.io.as_mut().unwrap()).poll_read_ready(cx)?; + // reset the interval and check again after 1 second. + timer + .as_mut() + .reset(Instant::now() + Duration::from_secs(1)); + let _ = timer.poll(cx); // got timeout during shutdown, drop connection - if this.flags.contains(Flags::SHUTDOWN) { + } else if this.flags.contains(Flags::SHUTDOWN) { return Err(DispatchError::DisconnectTimeout); // exceed deadline. check for any outstanding tasks } else if timer.deadline() >= *this.ka_expire { @@ -824,7 +869,7 @@ where impl Future for Dispatcher where - T: AsyncRead + AsyncWrite + Unpin, + T: ActixStream, S: Service, S::Error: Into, S::Response: Into>, diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index b79453ebd..d8426305c 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -5,8 +5,8 @@ use std::rc::Rc; use std::task::{Context, Poll}; use std::{fmt, net}; -use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use actix_rt::net::TcpStream; +use actix_codec::Framed; +use actix_rt::net::{ActixStream, TcpStream}; use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; use futures_core::ready; use futures_util::future::ready; @@ -94,10 +94,10 @@ mod openssl { use super::*; use actix_service::ServiceFactoryExt; - use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; + use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream}; use actix_tls::accept::TlsError; - impl H1Service, S, B, X, U> + impl H1Service, S, B, X, U> where S: ServiceFactory, S::Error: Into, @@ -108,7 +108,7 @@ mod openssl { X::Error: Into, X::InitError: fmt::Debug, U: ServiceFactory< - (Request, Framed, Codec>), + (Request, Framed, Codec>), Config = (), Response = (), >, @@ -131,7 +131,7 @@ mod openssl { .map_err(TlsError::Tls) .map_init_err(|_| panic!()), ) - .and_then(|io: SslStream| { + .and_then(|io: TlsStream| { let peer_addr = io.get_ref().peer_addr().ok(); ready(Ok((io, peer_addr))) }) @@ -241,7 +241,7 @@ where impl ServiceFactory<(T, Option)> for H1Service where - T: AsyncRead + AsyncWrite + Unpin, + T: ActixStream, S: ServiceFactory, S::Error: Into, S::Response: Into>, @@ -304,7 +304,7 @@ where impl Future for H1ServiceResponse where - T: AsyncRead + AsyncWrite + Unpin, + T: ActixStream, S: ServiceFactory, S::Error: Into, S::Response: Into>, @@ -402,7 +402,7 @@ where impl Service<(T, Option)> for H1ServiceHandler where - T: AsyncRead + AsyncWrite + Unpin, + T: ActixStream, S: Service, S::Error: Into, S::Response: Into>, diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index e00c8d968..0984b3f23 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -93,12 +93,12 @@ where #[cfg(feature = "openssl")] mod openssl { use actix_service::{fn_factory, fn_service, ServiceFactoryExt}; - use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; + use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream}; use actix_tls::accept::TlsError; use super::*; - impl H2Service, S, B> + impl H2Service, S, B> where S: ServiceFactory, S::Error: Into + 'static, @@ -123,7 +123,7 @@ mod openssl { .map_init_err(|_| panic!()), ) .and_then(fn_factory(|| { - ok::<_, S::InitError>(fn_service(|io: SslStream| { + ok::<_, S::InitError>(fn_service(|io: TlsStream| { let peer_addr = io.get_ref().peer_addr().ok(); ok((io, peer_addr)) })) diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index fee26dcc3..0e399bdbf 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -3,8 +3,8 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::{fmt, net, rc::Rc}; -use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use actix_rt::net::TcpStream; +use actix_codec::Framed; +use actix_rt::net::{ActixStream, TcpStream}; use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; use bytes::Bytes; use futures_core::{ready, Future}; @@ -185,10 +185,10 @@ where mod openssl { use super::*; use actix_service::ServiceFactoryExt; - use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; + use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream}; use actix_tls::accept::TlsError; - impl HttpService, S, B, X, U> + impl HttpService, S, B, X, U> where S: ServiceFactory, S::Error: Into + 'static, @@ -201,13 +201,13 @@ mod openssl { X::InitError: fmt::Debug, >::Future: 'static, U: ServiceFactory< - (Request, Framed, h1::Codec>), + (Request, Framed, h1::Codec>), Config = (), Response = (), >, U::Error: fmt::Display + Into, U::InitError: fmt::Debug, - , h1::Codec>)>>::Future: 'static, + , h1::Codec>)>>::Future: 'static, { /// Create openssl based service pub fn openssl( @@ -225,7 +225,7 @@ mod openssl { .map_err(TlsError::Tls) .map_init_err(|_| panic!()), ) - .and_then(|io: SslStream| async { + .and_then(|io: TlsStream| async { let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() { if protos.windows(2).any(|window| window == b"h2") { Protocol::Http2 @@ -314,7 +314,7 @@ mod rustls { impl ServiceFactory<(T, Protocol, Option)> for HttpService where - T: AsyncRead + AsyncWrite + Unpin, + T: ActixStream, S: ServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, @@ -374,7 +374,7 @@ where impl Future for HttpServiceResponse where - T: AsyncRead + AsyncWrite + Unpin, + T: ActixStream, S: ServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, @@ -493,7 +493,7 @@ where impl Service<(T, Protocol, Option)> for HttpServiceHandler where - T: AsyncRead + AsyncWrite + Unpin, + T: ActixStream, S: Service, S::Error: Into + 'static, S::Future: 'static, @@ -591,7 +591,7 @@ where S: Service, S::Future: 'static, S::Error: Into, - T: AsyncRead + AsyncWrite + Unpin, + T: ActixStream, B: MessageBody, X: Service, X::Error: Into, @@ -614,7 +614,7 @@ where #[pin_project] pub struct HttpServiceHandlerResponse where - T: AsyncRead + AsyncWrite + Unpin, + T: ActixStream, S: Service, S::Error: Into + 'static, S::Future: 'static, @@ -631,7 +631,7 @@ where impl Future for HttpServiceHandlerResponse where - T: AsyncRead + AsyncWrite + Unpin, + T: ActixStream, S: Service, S::Error: Into + 'static, S::Future: 'static, diff --git a/actix-http/src/test.rs b/actix-http/src/test.rs index 870a656df..77bd459f1 100644 --- a/actix-http/src/test.rs +++ b/actix-http/src/test.rs @@ -10,6 +10,7 @@ use std::{ }; use actix_codec::{AsyncRead, AsyncWrite, ReadBuf}; +use actix_rt::net::ActixStream; use bytes::{Bytes, BytesMut}; use http::{Method, Uri, Version}; @@ -396,3 +397,23 @@ impl AsyncWrite for TestSeqBuffer { Poll::Ready(Ok(())) } } + +impl ActixStream for TestBuffer { + fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_write_ready(&self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl ActixStream for TestSeqBuffer { + fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_write_ready(&self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +}