use std::{ fmt, future::Future, pin::Pin, sync::Arc, task::{Context, Poll}, }; pub use rustls::Session; pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig}; pub use webpki_roots::TLS_SERVER_ROOTS; use actix_codec::{AsyncRead, AsyncWrite}; use actix_service::{Service, ServiceFactory}; use futures_core::{future::LocalBoxFuture, ready}; use log::trace; use tokio_rustls::{Connect, TlsConnector}; use webpki::DNSNameRef; use crate::connect::{Address, Connection}; /// Rustls connector factory pub struct RustlsConnector { connector: Arc, } impl RustlsConnector { pub fn new(connector: Arc) -> Self { RustlsConnector { connector } } } impl RustlsConnector { pub fn service(connector: Arc) -> RustlsConnectorService { RustlsConnectorService { connector } } } impl Clone for RustlsConnector { fn clone(&self) -> Self { Self { connector: self.connector.clone(), } } } impl ServiceFactory> for RustlsConnector where U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, { type Response = Connection>; type Error = std::io::Error; type Config = (); type Service = RustlsConnectorService; type InitError = (); type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { let connector = self.connector.clone(); Box::pin(async { Ok(RustlsConnectorService { connector }) }) } } pub struct RustlsConnectorService { connector: Arc, } impl Clone for RustlsConnectorService { fn clone(&self) -> Self { Self { connector: self.connector.clone(), } } } impl Service> for RustlsConnectorService where T: Address, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, { type Response = Connection>; type Error = std::io::Error; type Future = ConnectAsyncExt; actix_service::always_ready!(); fn call(&mut self, stream: Connection) -> Self::Future { trace!("SSL Handshake start for: {:?}", stream.host()); let (io, stream) = stream.replace(()); let host = DNSNameRef::try_from_ascii_str(stream.host()) .expect("rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54"); ConnectAsyncExt { fut: TlsConnector::from(self.connector.clone()).connect(host, io), stream: Some(stream), } } } pub struct ConnectAsyncExt { fut: Connect, stream: Option>, } impl Future for ConnectAsyncExt where T: Address, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, { type Output = Result>, std::io::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); let stream = ready!(Pin::new(&mut this.fut).poll(cx))?; let s = this.stream.take().unwrap(); trace!("SSL Handshake success: {:?}", s.host()); Poll::Ready(Ok(s.replace(stream).1)) } }