diff --git a/actix-tls/Cargo.toml b/actix-tls/Cargo.toml index 00082278..df97b83b 100755 --- a/actix-tls/Cargo.toml +++ b/actix-tls/Cargo.toml @@ -49,6 +49,7 @@ derive_more = "0.99.5" futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } http = { version = "0.2.3", optional = true } log = "0.4" +pin-project-lite = "0.2.7" tokio-util = { version = "0.6.3", default-features = false } # openssl diff --git a/actix-tls/src/accept/rustls.rs b/actix-tls/src/accept/rustls.rs index ffac687a..a5a0b461 100644 --- a/actix-tls/src/accept/rustls.rs +++ b/actix-tls/src/accept/rustls.rs @@ -5,13 +5,18 @@ use std::{ pin::Pin, sync::Arc, task::{Context, Poll}, + time::Duration, }; use actix_codec::{AsyncRead, AsyncWrite, ReadBuf}; -use actix_rt::net::{ActixStream, Ready}; +use actix_rt::{ + net::{ActixStream, Ready}, + time::{sleep, Sleep}, +}; use actix_service::{Service, ServiceFactory}; use actix_utils::counter::{Counter, CounterGuard}; use futures_core::future::LocalBoxFuture; +use pin_project_lite::pin_project; use tokio_rustls::{Accept, TlsAcceptor}; pub use tokio_rustls::rustls::{ServerConfig, Session}; @@ -158,22 +163,40 @@ impl Service for AcceptorService { fn call(&self, req: T) -> Self::Future { AcceptorServiceFut { - _guard: self.conns.get(), fut: self.acceptor.accept(req), + // default tls accept timeout is 3 seconds. + // TODO: make it configurable with service builder. + timeout: sleep(Duration::from_secs(3)), + _guard: self.conns.get(), } } } -pub struct AcceptorServiceFut { - fut: Accept, - _guard: CounterGuard, +pin_project! { + pub struct AcceptorServiceFut { + fut: Accept, + #[pin] + timeout: Sleep, + _guard: CounterGuard, + } } impl Future for AcceptorServiceFut { type Output = Result, io::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - Pin::new(&mut this.fut).poll(cx).map_ok(TlsStream) + let mut this = self.project(); + match Pin::new(&mut this.fut).poll(cx) { + Poll::Ready(res) => Poll::Ready(res.map(TlsStream)), + Poll::Pending => { + this.timeout.poll(cx).map(|_| { + // TODO: make the error message typed. + Err(io::Error::new( + io::ErrorKind::TimedOut, + "Tls Handshake timedout", + )) + }) + } + } } }