diff --git a/actix-tls/src/accept/mod.rs b/actix-tls/src/accept/mod.rs index 39fd72e6..c9d6cd73 100644 --- a/actix-tls/src/accept/mod.rs +++ b/actix-tls/src/accept/mod.rs @@ -38,6 +38,7 @@ pub(crate) static MAX_CONN: AtomicUsize = AtomicUsize::new(256); feature = "rustls-0_20", feature = "rustls-0_21", feature = "rustls-0_22", + feature = "rustls-0_23", feature = "native-tls", ))] pub(crate) const DEFAULT_TLS_HANDSHAKE_TIMEOUT: std::time::Duration = diff --git a/actix-tls/src/accept/rustls_0_23.rs b/actix-tls/src/accept/rustls_0_23.rs index 64ab680b..9d2025ba 100644 --- a/actix-tls/src/accept/rustls_0_23.rs +++ b/actix-tls/src/accept/rustls_0_23.rs @@ -1,161 +1,198 @@ -//! Rustls based connector service. +//! `rustls` v0.23 based TLS connection acceptor service. //! -//! See [`TlsConnector`] for main connector service factory docs. +//! See [`Acceptor`] for main service factory docs. use std::{ + convert::Infallible, future::Future, - io, + io::{self, IoSlice}, pin::Pin, sync::Arc, task::{Context, Poll}, + time::Duration, }; -use actix_rt::net::ActixStream; -use actix_service::{Service, ServiceFactory}; -use actix_utils::future::{ok, Ready}; -use futures_core::ready; -use rustls_pki_types_1::ServerName; -use tokio_rustls::{ - client::TlsStream as AsyncTlsStream, rustls::ClientConfig, Connect as RustlsConnect, - TlsConnector as RustlsTlsConnector, +use actix_rt::{ + net::{ActixStream, Ready}, + time::{sleep, Sleep}, }; +use actix_service::{Service, ServiceFactory}; +use actix_utils::{ + counter::{Counter, CounterGuard}, + future::{ready, Ready as FutReady}, +}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::{Accept, TlsAcceptor}; use tokio_rustls_026 as tokio_rustls; -use crate::connect::{Connection, Host}; +use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER}; pub mod reexports { - //! Re-exports from the `rustls` v0.23 ecosystem that are useful for connectors. + //! Re-exports from `rustls` that are useful for acceptors. - pub use tokio_rustls_026::{client::TlsStream as AsyncTlsStream, rustls::ClientConfig}; - #[cfg(feature = "rustls-0_23-webpki-roots")] - pub use webpki_roots_026::TLS_SERVER_ROOTS; + pub use tokio_rustls_026::rustls::ServerConfig; } -/// Returns root certificates via `rustls-native-certs` crate as a rustls certificate store. -/// -/// See [`rustls_native_certs::load_native_certs()`] for more info on behavior and errors. -#[cfg(feature = "rustls-0_23-native-roots")] -pub fn native_roots_cert_store() -> io::Result { - let mut root_certs = tokio_rustls::rustls::RootCertStore::empty(); +/// Wraps a `rustls` based async TLS stream in order to implement [`ActixStream`]. +pub struct TlsStream(tokio_rustls::server::TlsStream); - for cert in rustls_native_certs_07::load_native_certs()? { - root_certs.add(cert).unwrap(); - } +impl_more::impl_from!( in tokio_rustls::server::TlsStream => TlsStream); +impl_more::impl_deref_and_mut!( in TlsStream => tokio_rustls::server::TlsStream); - Ok(root_certs) -} - -/// Returns standard root certificates from `webpki-roots` crate as a rustls certificate store. -#[cfg(feature = "rustls-0_23-webpki-roots")] -pub fn webpki_roots_cert_store() -> tokio_rustls::rustls::RootCertStore { - let mut root_certs = tokio_rustls::rustls::RootCertStore::empty(); - root_certs.extend(webpki_roots_026::TLS_SERVER_ROOTS.to_owned()); - root_certs -} - -/// Connector service factory using `rustls`. -#[derive(Clone)] -pub struct TlsConnector { - connector: Arc, -} - -impl TlsConnector { - /// Constructs new connector service factory from a `rustls` client configuration. - pub fn new(connector: Arc) -> Self { - TlsConnector { connector } - } - - /// Constructs new connector service from a `rustls` client configuration. - pub fn service(connector: Arc) -> TlsConnectorService { - TlsConnectorService { connector } +impl AsyncRead for TlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_read(cx, buf) } } -impl ServiceFactory> for TlsConnector -where - R: Host, - IO: ActixStream + 'static, -{ - type Response = Connection>; - type Error = io::Error; +impl AsyncWrite for TlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut **self.get_mut()).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + (**self).is_write_vectored() + } +} + +impl ActixStream for TlsStream { + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + IO::poll_read_ready((**self).get_ref().0, cx) + } + + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { + IO::poll_write_ready((**self).get_ref().0, cx) + } +} + +/// Accept TLS connections via the `rustls` crate. +pub struct Acceptor { + config: Arc, + handshake_timeout: Duration, +} + +impl Acceptor { + /// Constructs `rustls` based acceptor service factory. + pub fn new(config: reexports::ServerConfig) -> Self { + Acceptor { + config: Arc::new(config), + handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT, + } + } + + /// Limit the amount of time that the acceptor will wait for a TLS handshake to complete. + /// + /// Default timeout is 3 seconds. + pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self { + self.handshake_timeout = handshake_timeout; + self + } +} + +impl Clone for Acceptor { + fn clone(&self) -> Self { + Self { + config: self.config.clone(), + handshake_timeout: self.handshake_timeout, + } + } +} + +impl ServiceFactory for Acceptor { + type Response = TlsStream; + type Error = TlsError; type Config = (); - type Service = TlsConnectorService; + type Service = AcceptorService; type InitError = (); - type Future = Ready>; + type Future = FutReady>; fn new_service(&self, _: ()) -> Self::Future { - ok(TlsConnectorService { - connector: self.connector.clone(), - }) + let res = MAX_CONN_COUNTER.with(|conns| { + Ok(AcceptorService { + acceptor: self.config.clone().into(), + conns: conns.clone(), + handshake_timeout: self.handshake_timeout, + }) + }); + + ready(res) } } -/// Connector service using `rustls`. -#[derive(Clone)] -pub struct TlsConnectorService { - connector: Arc, +/// Rustls based acceptor service. +pub struct AcceptorService { + acceptor: TlsAcceptor, + conns: Counter, + handshake_timeout: Duration, } -impl Service> for TlsConnectorService -where - R: Host, - IO: ActixStream, -{ - type Response = Connection>; - type Error = io::Error; - type Future = ConnectFut; +impl Service for AcceptorService { + type Response = TlsStream; + type Error = TlsError; + type Future = AcceptFut; - actix_service::always_ready!(); + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + if self.conns.available(cx) { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } - fn call(&self, connection: Connection) -> Self::Future { - tracing::trace!("TLS handshake start for: {:?}", connection.hostname()); - let (stream, conn) = connection.replace_io(()); - - match ServerName::try_from(conn.hostname()) { - Ok(host) => ConnectFut::Future { - connect: RustlsTlsConnector::from(Arc::clone(&self.connector)) - .connect(host.to_owned(), stream), - connection: Some(conn), - }, - Err(_) => ConnectFut::InvalidServerName, + fn call(&self, req: IO) -> Self::Future { + AcceptFut { + fut: self.acceptor.accept(req), + timeout: sleep(self.handshake_timeout), + _guard: self.conns.get(), } } } -/// Connect future for Rustls service. -#[doc(hidden)] -#[allow(clippy::large_enum_variant)] -pub enum ConnectFut { - InvalidServerName, - Future { - connect: RustlsConnect, - connection: Option>, - }, +pin_project! { + /// Accept future for Rustls service. + #[doc(hidden)] + pub struct AcceptFut { + fut: Accept, + #[pin] + timeout: Sleep, + _guard: CounterGuard, + } } -impl Future for ConnectFut -where - R: Host, - IO: ActixStream, -{ - type Output = io::Result>>; +impl Future for AcceptFut { + type Output = Result, TlsError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.get_mut() { - Self::InvalidServerName => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "connection parameters specified invalid server name", - ))), - - Self::Future { - connect, - connection, - } => { - let stream = ready!(Pin::new(connect).poll(cx))?; - let connection = connection.take().unwrap(); - tracing::trace!("TLS handshake success: {:?}", connection.hostname()); - Poll::Ready(Ok(connection.replace_io(stream).1)) - } + let mut this = self.project(); + match Pin::new(&mut this.fut).poll(cx) { + Poll::Ready(Ok(stream)) => Poll::Ready(Ok(TlsStream(stream))), + Poll::Ready(Err(err)) => Poll::Ready(Err(TlsError::Tls(err))), + Poll::Pending => this.timeout.poll(cx).map(|_| Err(TlsError::Timeout)), } } }