use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; use std::{fmt, io}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_rt::net::TcpStream; use actix_service::{Service, ServiceFactory}; use futures_util::{ future::{ready, Either, Ready}, ready, }; use log::trace; pub use openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; pub use tokio_openssl::SslStream; use trust_dns_resolver::TokioAsyncResolver as AsyncResolver; use crate::connect::{ Address, Connect, ConnectError, ConnectService, ConnectServiceFactory, Connection, }; /// OpenSSL connector factory pub struct OpensslConnector { connector: SslConnector, } impl OpensslConnector { pub fn new(connector: SslConnector) -> Self { OpensslConnector { connector } } } impl OpensslConnector { pub fn service(connector: SslConnector) -> OpensslConnectorService { OpensslConnectorService { connector } } } impl Clone for OpensslConnector { fn clone(&self) -> Self { Self { connector: self.connector.clone(), } } } impl ServiceFactory> for OpensslConnector where T: Address + 'static, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, { type Response = Connection>; type Error = io::Error; type Config = (); type Service = OpensslConnectorService; type InitError = (); type Future = Ready>; fn new_service(&self, _: ()) -> Self::Future { ready(Ok(OpensslConnectorService { connector: self.connector.clone(), })) } } pub struct OpensslConnectorService { connector: SslConnector, } impl Clone for OpensslConnectorService { fn clone(&self) -> Self { Self { connector: self.connector.clone(), } } } impl Service> for OpensslConnectorService where T: Address + 'static, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, { type Response = Connection>; type Error = io::Error; #[allow(clippy::type_complexity)] type Future = Either, Ready>>; actix_service::always_ready!(); fn call(&self, stream: Connection) -> Self::Future { trace!("SSL Handshake start for: {:?}", stream.host()); let (io, stream) = stream.replace(()); let host = stream.host().to_string(); match self.connector.configure() { Err(e) => Either::Right(ready(Err(io::Error::new(io::ErrorKind::Other, e)))), Ok(config) => { let ssl = config .into_ssl(&host) .expect("SSL connect configuration was invalid."); Either::Left(ConnectAsyncExt { io: Some(SslStream::new(ssl, io).unwrap()), stream: Some(stream), _t: PhantomData, }) } } } } pub struct ConnectAsyncExt { io: Option>, stream: Option>, _t: PhantomData, } impl Future for ConnectAsyncExt where U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, { type Output = Result>, io::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); match ready!(Pin::new(this.io.as_mut().unwrap()).poll_connect(cx)) { Ok(_) => { let stream = this.stream.take().unwrap(); trace!("SSL Handshake success: {:?}", stream.host()); Poll::Ready(Ok(stream.replace(this.io.take().unwrap()).1)) } Err(e) => { trace!("SSL Handshake error: {:?}", e); Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)))) } } } } pub struct OpensslConnectServiceFactory { tcp: ConnectServiceFactory, openssl: OpensslConnector, } impl OpensslConnectServiceFactory { /// Construct new OpensslConnectService factory pub fn new(connector: SslConnector) -> Self { OpensslConnectServiceFactory { tcp: ConnectServiceFactory::default(), openssl: OpensslConnector::new(connector), } } /// Construct new connect service with custom DNS resolver pub fn with_resolver(connector: SslConnector, resolver: AsyncResolver) -> Self { OpensslConnectServiceFactory { tcp: ConnectServiceFactory::with_resolver(resolver), openssl: OpensslConnector::new(connector), } } /// Construct OpenSSL connect service pub fn service(&self) -> OpensslConnectService { OpensslConnectService { tcp: self.tcp.service(), openssl: OpensslConnectorService { connector: self.openssl.connector.clone(), }, } } } impl Clone for OpensslConnectServiceFactory { fn clone(&self) -> Self { OpensslConnectServiceFactory { tcp: self.tcp.clone(), openssl: self.openssl.clone(), } } } impl ServiceFactory> for OpensslConnectServiceFactory { type Response = SslStream; type Error = ConnectError; type Config = (); type Service = OpensslConnectService; type InitError = (); type Future = Ready>; fn new_service(&self, _: ()) -> Self::Future { ready(Ok(self.service())) } } #[derive(Clone)] pub struct OpensslConnectService { tcp: ConnectService, openssl: OpensslConnectorService, } impl Service> for OpensslConnectService { type Response = SslStream; type Error = ConnectError; type Future = OpensslConnectServiceResponse; actix_service::always_ready!(); fn call(&self, req: Connect) -> Self::Future { OpensslConnectServiceResponse { fut1: Some(self.tcp.call(req)), fut2: None, openssl: self.openssl.clone(), } } } pub struct OpensslConnectServiceResponse { fut1: Option< as Service>>::Future>, fut2: Option<>>::Future>, openssl: OpensslConnectorService, } impl Future for OpensslConnectServiceResponse { type Output = Result, ConnectError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if let Some(ref mut fut) = self.fut1 { match ready!(Pin::new(fut).poll(cx)) { Ok(res) => { let _ = self.fut1.take(); self.fut2 = Some(self.openssl.call(res)); } Err(e) => return Poll::Ready(Err(e)), } } if let Some(ref mut fut) = self.fut2 { match ready!(Pin::new(fut).poll(cx)) { Ok(connect) => Poll::Ready(Ok(connect.into_parts().0)), Err(e) => Poll::Ready(Err(ConnectError::Io(io::Error::new( io::ErrorKind::Other, e, )))), } } else { Poll::Pending } } }