From 76c16a4e7b91a2c6f1dba1dd0a6f74aa2ba6af31 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Mon, 29 Nov 2021 18:16:03 +0000 Subject: [PATCH] reorganise and document --- actix-tls/CHANGES.md | 13 +- actix-tls/Cargo.toml | 7 +- actix-tls/src/accept/mod.rs | 13 +- actix-tls/src/accept/native_tls.rs | 45 +-- actix-tls/src/accept/openssl.rs | 43 +-- actix-tls/src/accept/rustls.rs | 37 +-- actix-tls/src/connect/address.rs | 22 ++ actix-tls/src/connect/connect.rs | 359 --------------------- actix-tls/src/connect/connect_addrs.rs | 81 +++++ actix-tls/src/connect/connection.rs | 54 ++++ actix-tls/src/connect/connector.rs | 242 ++++++-------- actix-tls/src/connect/info.rs | 257 +++++++++++++++ actix-tls/src/connect/mod.rs | 53 +-- actix-tls/src/connect/native_tls.rs | 90 ++++++ actix-tls/src/connect/{tls => }/openssl.rs | 68 ++-- actix-tls/src/connect/resolve.rs | 207 +----------- actix-tls/src/connect/resolver.rs | 212 ++++++++++++ actix-tls/src/connect/{tls => }/rustls.rs | 92 +++--- actix-tls/src/connect/service.rs | 129 -------- actix-tls/src/connect/tcp.rs | 224 +++++++++++-- actix-tls/src/connect/tls/mod.rs | 10 - actix-tls/src/connect/tls/native_tls.rs | 89 ----- actix-tls/src/lib.rs | 6 +- actix-tls/tests/test_connect.rs | 10 +- actix-tls/tests/test_resolvers.rs | 6 +- 25 files changed, 1202 insertions(+), 1167 deletions(-) create mode 100644 actix-tls/src/connect/address.rs delete mode 100755 actix-tls/src/connect/connect.rs create mode 100644 actix-tls/src/connect/connect_addrs.rs create mode 100644 actix-tls/src/connect/connection.rs create mode 100755 actix-tls/src/connect/info.rs create mode 100644 actix-tls/src/connect/native_tls.rs rename actix-tls/src/connect/{tls => }/openssl.rs (59%) mode change 100755 => 100644 actix-tls/src/connect/resolve.rs create mode 100755 actix-tls/src/connect/resolver.rs rename actix-tls/src/connect/{tls => }/rustls.rs (63%) delete mode 100755 actix-tls/src/connect/service.rs mode change 100644 => 100755 actix-tls/src/connect/tcp.rs delete mode 100644 actix-tls/src/connect/tls/mod.rs delete mode 100644 actix-tls/src/connect/tls/native_tls.rs diff --git a/actix-tls/CHANGES.md b/actix-tls/CHANGES.md index 52ce7751..edc22c90 100644 --- a/actix-tls/CHANGES.md +++ b/actix-tls/CHANGES.md @@ -6,6 +6,16 @@ * Remove redundant `connect::Connection::from_parts` method. [#422] * Rename TLS acceptor service future types and hide from docs. [#422] * Implement `Error` for `ConnectError`. [#422] +* Implement `Error` for `TlsError` where both types also implement `Error`. [#422] +* Rename `accept::native_tls::{NativeTlsAcceptorService => AcceptorService}`. [#422] +* Make `ConnectAddrsIter` private. [#422] +* Rename method `connect::Connection::{host => hostname}`. [#422] +* Rename struct `connect::{Connect => ConnectionInfo}`. [#422] +* Rename struct `connect::{ConnectServiceFactory => Connector}`. [#422] +* Rename struct `connect::{ConnectService => ConnectorService}`. [#422] +* Remove `connect::{new_connector, new_connector_factory, default_connector, default_connector_factory}` methods. [#422] +* Convert `connect::ResolverService` from enum to struct. [#422] +* Remove `connect::native_tls::Connector::service` method. [#422] [#422]: https://github.com/actix/actix-net/pull/422 @@ -48,8 +58,7 @@ * Remove `connect::ssl::openssl::OpensslConnectService`. [#297] * Add `connect::ssl::native_tls` module for native tls support. [#295] * Rename `accept::{nativetls => native_tls}`. [#295] -* Remove `connect::TcpConnectService` type. Service caller expecting a `TcpStream` should use - `connect::ConnectService` instead and call `Connection::into_parts`. [#299] +* Remove `connect::TcpConnectService` type. Service caller expecting a `TcpStream` should use `connect::ConnectService` instead and call `Connection::into_parts`. [#299] [#295]: https://github.com/actix/actix-net/pull/295 [#296]: https://github.com/actix/actix-net/pull/296 diff --git a/actix-tls/Cargo.toml b/actix-tls/Cargo.toml index b9b75d35..a878dcdc 100755 --- a/actix-tls/Cargo.toml +++ b/actix-tls/Cargo.toml @@ -13,7 +13,8 @@ license = "MIT OR Apache-2.0" edition = "2018" [package.metadata.docs.rs] -features = ["openssl", "rustls", "native-tls", "accept", "connect", "uri"] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] [lib] name = "actix_tls" @@ -48,11 +49,13 @@ actix-utils = "3.0.0" derive_more = "0.99.5" futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } -http = { version = "0.2.3", optional = true } log = "0.4" pin-project-lite = "0.2.7" tokio-util = { version = "0.6.3", default-features = false } +# uri +http = { version = "0.2.3", optional = true } + # openssl tls-openssl = { package = "openssl", version = "0.10.9", optional = true } tokio-openssl = { version = "0.6", optional = true } diff --git a/actix-tls/src/accept/mod.rs b/actix-tls/src/accept/mod.rs index 300e1767..de220ac5 100644 --- a/actix-tls/src/accept/mod.rs +++ b/actix-tls/src/accept/mod.rs @@ -1,4 +1,4 @@ -//! TLS acceptor services. +//! TLS connection acceptor services. use std::{ convert::Infallible, @@ -6,14 +6,18 @@ use std::{ }; use actix_utils::counter::Counter; +use derive_more::{Display, Error}; #[cfg(feature = "openssl")] +#[cfg_attr(docsrs, doc(cfg(feature = "openssl")))] pub mod openssl; #[cfg(feature = "rustls")] +#[cfg_attr(docsrs, doc(cfg(feature = "rustls")))] pub mod rustls; #[cfg(feature = "native-tls")] +#[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))] pub mod native_tls; pub(crate) static MAX_CONN: AtomicUsize = AtomicUsize::new(256); @@ -41,15 +45,18 @@ pub fn max_concurrent_tls_connect(num: usize) { /// All TLS acceptors from this crate will return the `SvcErr` type parameter as [`Infallible`], /// which can be cast to your own service type, inferred or otherwise, /// using [`into_service_error`](Self::into_service_error). -#[derive(Debug)] +#[derive(Debug, Display, Error)] pub enum TlsError { /// TLS handshake has timed-out. + #[display(fmt = "TLS handshake has timed-out")] Timeout, /// Wraps TLS service errors. + #[display(fmt = "TLS handshake error")] Tls(TlsErr), - /// Wraps inner service errors. + /// Wraps service errors. + #[display(fmt = "Service error")] Service(SvcErr), } diff --git a/actix-tls/src/accept/native_tls.rs b/actix-tls/src/accept/native_tls.rs index c6b77a39..57ac287a 100644 --- a/actix-tls/src/accept/native_tls.rs +++ b/actix-tls/src/accept/native_tls.rs @@ -1,9 +1,10 @@ -//! Native-TLS based acceptor service. +//! `native-tls` based TLS connection acceptor service. +//! +//! See [`Acceptor`] for main service factory docs. use std::{ convert::Infallible, io::{self, IoSlice}, - ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, time::Duration, @@ -16,35 +17,16 @@ use actix_rt::{ }; use actix_service::{Service, ServiceFactory}; use actix_utils::counter::Counter; +use derive_more::{Deref, DerefMut, From}; use futures_core::future::LocalBoxFuture; - pub use tokio_native_tls::{native_tls::Error, TlsAcceptor}; use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER}; -/// Wraps a [`tokio_native_tls::TlsStream`] in order to impl [`ActixStream`] trait. +/// Wraps a `native-tls` based async TLS stream in order to implement [`ActixStream`]. +#[derive(Deref, DerefMut, From)] pub struct TlsStream(tokio_native_tls::TlsStream); -impl From> for TlsStream { - fn from(stream: tokio_native_tls::TlsStream) -> Self { - Self(stream) - } -} - -impl Deref for TlsStream { - type Target = tokio_native_tls::TlsStream; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for TlsStream { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - impl AsyncRead for TlsStream { fn poll_read( self: Pin<&mut Self>, @@ -95,17 +77,14 @@ impl ActixStream for TlsStream { } } -/// Accept TLS connections via `native-tls` package. -/// -/// `native-tls` feature enables this `Acceptor` type. +/// Accept TLS connections via the `native-tls` crate. pub struct Acceptor { acceptor: TlsAcceptor, handshake_timeout: Duration, } impl Acceptor { - /// Create `native-tls` based `Acceptor` service factory. - #[inline] + /// Constructs `native-tls` based `Acceptor` service factory. pub fn new(acceptor: TlsAcceptor) -> Self { Acceptor { acceptor, @@ -136,13 +115,13 @@ impl ServiceFactory for Acceptor { type Response = TlsStream; type Error = TlsError; type Config = (); - type Service = NativeTlsAcceptorService; + type Service = AcceptorService; type InitError = (); type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { let res = MAX_CONN_COUNTER.with(|conns| { - Ok(NativeTlsAcceptorService { + Ok(AcceptorService { acceptor: self.acceptor.clone(), conns: conns.clone(), handshake_timeout: self.handshake_timeout, @@ -154,13 +133,13 @@ impl ServiceFactory for Acceptor { } /// Native-TLS based acceptor service. -pub struct NativeTlsAcceptorService { +pub struct AcceptorService { acceptor: TlsAcceptor, conns: Counter, handshake_timeout: Duration, } -impl Service for NativeTlsAcceptorService { +impl Service for AcceptorService { type Response = TlsStream; type Error = TlsError; type Future = LocalBoxFuture<'static, Result>; diff --git a/actix-tls/src/accept/openssl.rs b/actix-tls/src/accept/openssl.rs index 3320fcce..777e09f8 100644 --- a/actix-tls/src/accept/openssl.rs +++ b/actix-tls/src/accept/openssl.rs @@ -1,10 +1,11 @@ -//! OpenSSL based acceptor service. +//! `openssl` based TLS acceptor service. +//! +//! See [`Acceptor`] for main service factory docs. use std::{ convert::Infallible, future::Future, io::{self, IoSlice}, - ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, time::Duration, @@ -17,37 +18,19 @@ use actix_rt::{ }; use actix_service::{Service, ServiceFactory}; use actix_utils::counter::{Counter, CounterGuard}; +use derive_more::{Deref, DerefMut, From}; use futures_core::future::LocalBoxFuture; pub use openssl::ssl::{ - AlpnError, Error as SslError, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder, + AlpnError, Error, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder, }; use pin_project_lite::pin_project; use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER}; -/// Wraps a [`tokio_openssl::SslStream`] in order to impl [`ActixStream`] trait. +/// Wraps an `openssl` based async TLS stream in order to implement [`ActixStream`]. +#[derive(Deref, DerefMut, From)] pub struct TlsStream(tokio_openssl::SslStream); -impl From> for TlsStream { - fn from(stream: tokio_openssl::SslStream) -> Self { - Self(stream) - } -} - -impl Deref for TlsStream { - type Target = tokio_openssl::SslStream; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for TlsStream { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - impl AsyncRead for TlsStream { fn poll_read( self: Pin<&mut Self>, @@ -98,9 +81,7 @@ impl ActixStream for TlsStream { } } -/// Accept TLS connections via `openssl` package. -/// -/// `openssl` feature enables this `Acceptor` type. +/// Accept TLS connections via the `openssl` crate. pub struct Acceptor { acceptor: SslAcceptor, handshake_timeout: Duration, @@ -137,7 +118,7 @@ impl Clone for Acceptor { impl ServiceFactory for Acceptor { type Response = TlsStream; - type Error = TlsError; + type Error = TlsError; type Config = (); type Service = AcceptorService; type InitError = (); @@ -165,7 +146,7 @@ pub struct AcceptorService { impl Service for AcceptorService { type Response = TlsStream; - type Error = TlsError; + type Error = TlsError; type Future = AcceptFut; fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll> { @@ -189,7 +170,7 @@ impl Service for AcceptorService { } pin_project! { - /// Accept future for Rustls service. + /// Accept future for OpenSSL service. #[doc(hidden)] pub struct AcceptFut { stream: Option>, @@ -200,7 +181,7 @@ pin_project! { } impl Future for AcceptFut { - type Output = Result, TlsError>; + type Output = Result, TlsError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); diff --git a/actix-tls/src/accept/rustls.rs b/actix-tls/src/accept/rustls.rs index e7c34d41..e3f6fdad 100644 --- a/actix-tls/src/accept/rustls.rs +++ b/actix-tls/src/accept/rustls.rs @@ -1,10 +1,11 @@ -//! Rustls based acceptor service. +//! `rustls` based TLS connection acceptor service. +//! +//! See [`Acceptor`] for main service factory docs. use std::{ convert::Infallible, future::Future, io::{self, IoSlice}, - ops::{Deref, DerefMut}, pin::Pin, sync::Arc, task::{Context, Poll}, @@ -18,6 +19,7 @@ use actix_rt::{ }; use actix_service::{Service, ServiceFactory}; use actix_utils::counter::{Counter, CounterGuard}; +use derive_more::{Deref, DerefMut, From}; use futures_core::future::LocalBoxFuture; use pin_project_lite::pin_project; pub use tokio_rustls::rustls::ServerConfig; @@ -25,29 +27,10 @@ use tokio_rustls::{Accept, TlsAcceptor}; use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER}; -/// Wraps a [`tokio_rustls::server::TlsStream`] in order to impl [`ActixStream`] trait. +/// Wraps a `rustls` based async TLS stream in order to implement [`ActixStream`]. +#[derive(Deref, DerefMut, From)] pub struct TlsStream(tokio_rustls::server::TlsStream); -impl From> for TlsStream { - fn from(stream: tokio_rustls::server::TlsStream) -> Self { - Self(stream) - } -} - -impl Deref for TlsStream { - type Target = tokio_rustls::server::TlsStream; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for TlsStream { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - impl AsyncRead for TlsStream { fn poll_read( self: Pin<&mut Self>, @@ -98,17 +81,14 @@ impl ActixStream for TlsStream { } } -/// Accept TLS connections via `rustls` package. -/// -/// `rustls` feature enables this `Acceptor` type. +/// Accept TLS connections via the `rustls` crate. pub struct Acceptor { config: Arc, handshake_timeout: Duration, } impl Acceptor { - /// Create Rustls based `Acceptor` service factory. - #[inline] + /// Constructs Rustls based acceptor service factory. pub fn new(config: ServerConfig) -> Self { Acceptor { config: Arc::new(config), @@ -126,7 +106,6 @@ impl Acceptor { } impl Clone for Acceptor { - #[inline] fn clone(&self) -> Self { Self { config: self.config.clone(), diff --git a/actix-tls/src/connect/address.rs b/actix-tls/src/connect/address.rs new file mode 100644 index 00000000..8530a7c3 --- /dev/null +++ b/actix-tls/src/connect/address.rs @@ -0,0 +1,22 @@ +/// An interface for types where host parts (hostname and port) can be derived. +pub trait Address: Unpin + 'static { + /// Returns hostname part. + fn hostname(&self) -> &str; + + /// Returns optional port part. + fn port(&self) -> Option { + None + } +} + +impl Address for String { + fn hostname(&self) -> &str { + self + } +} + +impl Address for &'static str { + fn hostname(&self) -> &str { + self + } +} diff --git a/actix-tls/src/connect/connect.rs b/actix-tls/src/connect/connect.rs deleted file mode 100755 index 86f702da..00000000 --- a/actix-tls/src/connect/connect.rs +++ /dev/null @@ -1,359 +0,0 @@ -use std::{ - collections::{vec_deque, VecDeque}, - fmt, - iter::{self, FromIterator as _}, - mem, - net::{IpAddr, SocketAddr}, - ops, -}; - -/// Parse a host into parts (hostname and port). -pub trait Address: Unpin + 'static { - /// Get hostname part. - fn hostname(&self) -> &str; - - /// Get optional port part. - fn port(&self) -> Option { - None - } -} - -impl Address for String { - fn hostname(&self) -> &str { - self - } -} - -impl Address for &'static str { - fn hostname(&self) -> &str { - self - } -} - -#[derive(Debug, Eq, PartialEq, Hash)] -pub(crate) enum ConnectAddrs { - None, - One(SocketAddr), - Multi(VecDeque), -} - -impl ConnectAddrs { - pub(crate) fn is_none(&self) -> bool { - matches!(self, Self::None) - } - - pub(crate) fn is_some(&self) -> bool { - !self.is_none() - } -} - -impl Default for ConnectAddrs { - fn default() -> Self { - Self::None - } -} - -impl From> for ConnectAddrs { - fn from(addr: Option) -> Self { - match addr { - Some(addr) => ConnectAddrs::One(addr), - None => ConnectAddrs::None, - } - } -} - -/// Connection info. -#[derive(Debug, PartialEq, Eq, Hash)] -pub struct Connect { - pub(crate) req: R, - pub(crate) port: u16, - pub(crate) addr: ConnectAddrs, - pub(crate) local_addr: Option, -} - -impl Connect { - /// Create `Connect` instance by splitting the string by ':' and convert the second part to u16 - pub fn new(req: R) -> Connect { - let (_, port) = parse_host(req.hostname()); - - Connect { - req, - port: port.unwrap_or(0), - addr: ConnectAddrs::None, - local_addr: None, - } - } - - /// Create new `Connect` instance from host and address. Connector skips name resolution stage - /// for such connect messages. - pub fn with_addr(req: R, addr: SocketAddr) -> Connect { - Connect { - req, - port: 0, - addr: ConnectAddrs::One(addr), - local_addr: None, - } - } - - /// Use port if address does not provide one. - /// - /// Default value is 0. - pub fn set_port(mut self, port: u16) -> Self { - self.port = port; - self - } - - /// Set address. - pub fn set_addr(mut self, addr: Option) -> Self { - self.addr = ConnectAddrs::from(addr); - self - } - - /// Set list of addresses. - pub fn set_addrs(mut self, addrs: I) -> Self - where - I: IntoIterator, - { - let mut addrs = VecDeque::from_iter(addrs); - self.addr = if addrs.len() < 2 { - ConnectAddrs::from(addrs.pop_front()) - } else { - ConnectAddrs::Multi(addrs) - }; - self - } - - /// Set local_addr of connect. - pub fn set_local_addr(mut self, addr: impl Into) -> Self { - self.local_addr = Some(addr.into()); - self - } - - /// Get hostname. - pub fn hostname(&self) -> &str { - self.req.hostname() - } - - /// Get request port. - pub fn port(&self) -> u16 { - self.req.port().unwrap_or(self.port) - } - - /// Get resolved request addresses. - pub fn addrs(&self) -> ConnectAddrsIter<'_> { - match self.addr { - ConnectAddrs::None => ConnectAddrsIter::None, - ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), - ConnectAddrs::Multi(ref addrs) => ConnectAddrsIter::Multi(addrs.iter()), - } - } - - /// Take resolved request addresses. - pub fn take_addrs(&mut self) -> ConnectAddrsIter<'static> { - match mem::take(&mut self.addr) { - ConnectAddrs::None => ConnectAddrsIter::None, - ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), - ConnectAddrs::Multi(addrs) => ConnectAddrsIter::MultiOwned(addrs.into_iter()), - } - } - - /// Returns a reference to the connection request. - pub fn request(&self) -> &R { - &self.req - } -} - -impl From for Connect { - fn from(addr: R) -> Self { - Connect::new(addr) - } -} - -impl fmt::Display for Connect { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}:{}", self.hostname(), self.port()) - } -} - -/// Iterator over addresses in a [`Connect`] request. -#[derive(Clone)] -pub enum ConnectAddrsIter<'a> { - None, - One(SocketAddr), - Multi(vec_deque::Iter<'a, SocketAddr>), - MultiOwned(vec_deque::IntoIter), -} - -impl Iterator for ConnectAddrsIter<'_> { - type Item = SocketAddr; - - fn next(&mut self) -> Option { - match *self { - Self::None => None, - Self::One(addr) => { - *self = Self::None; - Some(addr) - } - Self::Multi(ref mut iter) => iter.next().copied(), - Self::MultiOwned(ref mut iter) => iter.next(), - } - } - - fn size_hint(&self) -> (usize, Option) { - match *self { - Self::None => (0, Some(0)), - Self::One(_) => (1, Some(1)), - Self::Multi(ref iter) => iter.size_hint(), - Self::MultiOwned(ref iter) => iter.size_hint(), - } - } -} - -impl fmt::Debug for ConnectAddrsIter<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.clone()).finish() - } -} - -impl iter::ExactSizeIterator for ConnectAddrsIter<'_> {} - -impl iter::FusedIterator for ConnectAddrsIter<'_> {} - -/// Holds underlying I/O and original connection request. -#[derive(Debug)] -pub struct Connection { - req: R, - io: IO, -} - -impl Connection { - /// Construct new `Connection` from - pub fn new(io: IO, req: R) -> Self { - Self { io, req } - } -} - -impl Connection { - /// Deconstructs into parts. - pub fn into_parts(self) -> (IO, R) { - (self.io, self.req) - } - - /// Replaces underlying IO, returning old UI and new `Connection`. - pub fn replace_io(self, io: IO2) -> (IO, Connection) { - (self.io, Connection { io, req: self.req }) - } - - /// Returns a shared reference to the underlying IO. - pub fn io_ref(&self) -> &IO { - &self.io - } - - /// Returns a mutable reference to the underlying IO. - pub fn io_mut(&mut self) -> &mut IO { - &mut self.io - } - - /// Returns a reference to the connection request. - pub fn request(&self) -> &R { - &self.req - } -} - -impl Connection { - /// Get hostname. - pub fn host(&self) -> &str { - self.req.hostname() - } -} - -impl ops::Deref for Connection { - type Target = IO; - - fn deref(&self) -> &IO { - &self.io - } -} - -impl ops::DerefMut for Connection { - fn deref_mut(&mut self) -> &mut IO { - &mut self.io - } -} - -fn parse_host(host: &str) -> (&str, Option) { - let mut parts_iter = host.splitn(2, ':'); - - match parts_iter.next() { - Some(hostname) => { - let port_str = parts_iter.next().unwrap_or(""); - let port = port_str.parse::().ok(); - (hostname, port) - } - - None => (host, None), - } -} - -#[cfg(test)] -mod tests { - use std::net::Ipv4Addr; - - use super::*; - - #[test] - fn test_host_parser() { - assert_eq!(parse_host("example.com"), ("example.com", None)); - assert_eq!(parse_host("example.com:8080"), ("example.com", Some(8080))); - assert_eq!(parse_host("example:8080"), ("example", Some(8080))); - assert_eq!(parse_host("example.com:false"), ("example.com", None)); - assert_eq!(parse_host("example.com:false:false"), ("example.com", None)); - } - - #[test] - fn test_addr_iter_multi() { - let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080)); - let unspecified = SocketAddr::from((IpAddr::from(Ipv4Addr::UNSPECIFIED), 8080)); - - let mut addrs = VecDeque::new(); - addrs.push_back(localhost); - addrs.push_back(unspecified); - - let mut iter = ConnectAddrsIter::Multi(addrs.iter()); - assert_eq!(iter.next(), Some(localhost)); - assert_eq!(iter.next(), Some(unspecified)); - assert_eq!(iter.next(), None); - - let mut iter = ConnectAddrsIter::MultiOwned(addrs.into_iter()); - assert_eq!(iter.next(), Some(localhost)); - assert_eq!(iter.next(), Some(unspecified)); - assert_eq!(iter.next(), None); - } - - #[test] - fn test_addr_iter_single() { - let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080)); - - let mut iter = ConnectAddrsIter::One(localhost); - assert_eq!(iter.next(), Some(localhost)); - assert_eq!(iter.next(), None); - - let mut iter = ConnectAddrsIter::None; - assert_eq!(iter.next(), None); - } - - #[test] - fn test_local_addr() { - let conn = Connect::new("hello").set_local_addr([127, 0, 0, 1]); - assert_eq!( - conn.local_addr.unwrap(), - IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) - ) - } - - #[test] - fn request_ref() { - let conn = Connect::new("hello"); - assert_eq!(conn.request(), &"hello") - } -} diff --git a/actix-tls/src/connect/connect_addrs.rs b/actix-tls/src/connect/connect_addrs.rs new file mode 100644 index 00000000..7b15a8e5 --- /dev/null +++ b/actix-tls/src/connect/connect_addrs.rs @@ -0,0 +1,81 @@ +use std::{ + collections::{vec_deque, VecDeque}, + fmt, iter, + net::SocketAddr, +}; + +#[derive(Debug, Eq, PartialEq, Hash)] +pub(crate) enum ConnectAddrs { + None, + One(SocketAddr), + Multi(VecDeque), +} + +impl ConnectAddrs { + pub(crate) fn is_none(&self) -> bool { + matches!(self, Self::None) + } + + pub(crate) fn is_some(&self) -> bool { + !self.is_none() + } +} + +impl Default for ConnectAddrs { + fn default() -> Self { + Self::None + } +} + +impl From> for ConnectAddrs { + fn from(addr: Option) -> Self { + match addr { + Some(addr) => ConnectAddrs::One(addr), + None => ConnectAddrs::None, + } + } +} + +/// Iterator over addresses in a [`Connect`] request. +#[derive(Clone)] +pub(crate) enum ConnectAddrsIter<'a> { + None, + One(SocketAddr), + Multi(vec_deque::Iter<'a, SocketAddr>), + MultiOwned(vec_deque::IntoIter), +} + +impl Iterator for ConnectAddrsIter<'_> { + type Item = SocketAddr; + + fn next(&mut self) -> Option { + match *self { + Self::None => None, + Self::One(addr) => { + *self = Self::None; + Some(addr) + } + Self::Multi(ref mut iter) => iter.next().copied(), + Self::MultiOwned(ref mut iter) => iter.next(), + } + } + + fn size_hint(&self) -> (usize, Option) { + match *self { + Self::None => (0, Some(0)), + Self::One(_) => (1, Some(1)), + Self::Multi(ref iter) => iter.size_hint(), + Self::MultiOwned(ref iter) => iter.size_hint(), + } + } +} + +impl fmt::Debug for ConnectAddrsIter<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.clone()).finish() + } +} + +impl iter::ExactSizeIterator for ConnectAddrsIter<'_> {} + +impl iter::FusedIterator for ConnectAddrsIter<'_> {} diff --git a/actix-tls/src/connect/connection.rs b/actix-tls/src/connect/connection.rs new file mode 100644 index 00000000..47fa90c2 --- /dev/null +++ b/actix-tls/src/connect/connection.rs @@ -0,0 +1,54 @@ +use derive_more::{Deref, DerefMut}; + +use super::Address; + +/// Wraps underlying I/O and the connection request that initiated it. +#[derive(Debug, Deref, DerefMut)] +pub struct Connection { + pub(crate) req: R, + + #[deref] + #[deref_mut] + pub(crate) io: IO, +} + +impl Connection { + /// Construct new `Connection` from + pub(crate) fn new(io: IO, req: R) -> Self { + Self { io, req } + } +} + +impl Connection { + /// Deconstructs into parts. + pub fn into_parts(self) -> (IO, R) { + (self.io, self.req) + } + + /// Replaces underlying IO, returning old UI and new `Connection`. + pub fn replace_io(self, io: IO2) -> (IO, Connection) { + (self.io, Connection { io, req: self.req }) + } + + /// Returns a shared reference to the underlying IO. + pub fn io_ref(&self) -> &IO { + &self.io + } + + /// Returns a mutable reference to the underlying IO. + pub fn io_mut(&mut self) -> &mut IO { + &mut self.io + } + + /// Returns a reference to the connection request. + pub fn request(&self) -> &R { + &self.req + } +} + +impl Connection { + /// Get hostname. + pub fn hostname(&self) -> &str { + self.req.hostname() + } +} diff --git a/actix-tls/src/connect/connector.rs b/actix-tls/src/connect/connector.rs index ab9b08be..ac9dfe54 100755 --- a/actix-tls/src/connect/connector.rs +++ b/actix-tls/src/connect/connector.rs @@ -1,196 +1,148 @@ use std::{ - collections::VecDeque, future::Future, - io, - net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}, pin::Pin, task::{Context, Poll}, }; -use actix_rt::net::{TcpSocket, TcpStream}; +use actix_rt::net::TcpStream; use actix_service::{Service, ServiceFactory}; use futures_core::{future::LocalBoxFuture, ready}; -use log::{error, trace}; -use tokio_util::sync::ReusableBoxFuture; use super::{ - connect::{Address, Connect, ConnectAddrs, Connection}, error::ConnectError, + resolver::{Resolver, ResolverService}, + tcp::{TcpConnector, TcpConnectorService}, + Address, Connection, ConnectionInfo, }; -/// TCP connector service factory -#[derive(Debug, Copy, Clone)] -pub struct TcpConnectorFactory; +/// Combined resolver and TCP connector service factory. +/// +/// Used to create [`ConnectService`]s which receive connection information, resolve DNS if +/// required, and return a TCP stream. +pub struct Connector { + tcp: TcpConnector, + resolver: Resolver, +} -impl TcpConnectorFactory { - /// Create TCP connector service - pub fn service(&self) -> TcpConnector { - TcpConnector +impl Connector { + /// Constructs new connector factory. + pub fn new(resolver: Resolver) -> Self { + Connector { + tcp: TcpConnector, + resolver, + } + } + + /// Build connector service. + pub fn service(&self) -> ConnectorService { + ConnectorService { + tcp: self.tcp.service(), + resolver: self.resolver.service(), + } } } -impl ServiceFactory> for TcpConnectorFactory { +impl Clone for Connector { + fn clone(&self) -> Self { + Connector { + tcp: self.tcp, + resolver: self.resolver.clone(), + } + } +} + +impl Default for Connector { + fn default() -> Self { + Self { + tcp: TcpConnector, + resolver: Resolver::default(), + } + } +} + +impl ServiceFactory> for Connector { type Response = Connection; type Error = ConnectError; type Config = (); - type Service = TcpConnector; + type Service = ConnectorService; type InitError = (); type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { let service = self.service(); - Box::pin(async move { Ok(service) }) + Box::pin(async { Ok(service) }) } } -/// TCP connector service. -#[derive(Debug, Copy, Clone)] -pub struct TcpConnector; +/// Combined resolver and TCP connector service. +/// +/// Service implementation receives connection information, resolves DNS if required, and returns +/// a TCP stream. +#[derive(Clone)] +pub struct ConnectorService { + tcp: TcpConnectorService, + resolver: ResolverService, +} -impl Service> for TcpConnector { +impl Service> for ConnectorService { type Response = Connection; type Error = ConnectError; - type Future = TcpConnectorResponse; + type Future = ConnectServiceResponse; actix_service::always_ready!(); - fn call(&self, req: Connect) -> Self::Future { - let port = req.port(); - let Connect { - req, - addr, - local_addr, - .. - } = req; - - TcpConnectorResponse::new(req, port, local_addr, addr) + fn call(&self, req: ConnectionInfo) -> Self::Future { + ConnectServiceResponse { + fut: ConnectFuture::Resolve(self.resolver.call(req)), + tcp: self.tcp, + } } } -/// TCP stream connector response future -pub enum TcpConnectorResponse { - Response { - req: Option, - port: u16, - local_addr: Option, - addrs: Option>, - stream: ReusableBoxFuture>, - }, - Error(Option), +// helper enum to generic over futures of resolve and connect phase. +pub(crate) enum ConnectFuture { + Resolve(>>::Future), + Connect(>>::Future), } -impl TcpConnectorResponse { - pub(crate) fn new( - req: R, - port: u16, - local_addr: Option, - addr: ConnectAddrs, - ) -> TcpConnectorResponse { - if addr.is_none() { - error!("TCP connector: unresolved connection address"); - return TcpConnectorResponse::Error(Some(ConnectError::Unresolved)); - } +/// Helper enum to contain the future output of `ConnectFuture`. +pub(crate) enum ConnectOutput { + Resolved(ConnectionInfo), + Connected(Connection), +} - trace!( - "TCP connector: connecting to {} on port {}", - req.hostname(), - port - ); - - match addr { - ConnectAddrs::None => unreachable!("none variant already checked"), - - ConnectAddrs::One(addr) => TcpConnectorResponse::Response { - req: Some(req), - port, - local_addr, - addrs: None, - stream: ReusableBoxFuture::new(connect(addr, local_addr)), - }, - - // when resolver returns multiple socket addr for request they would be popped from - // front end of queue and returns with the first successful tcp connection. - ConnectAddrs::Multi(mut addrs) => { - let addr = addrs.pop_front().unwrap(); - - TcpConnectorResponse::Response { - req: Some(req), - port, - local_addr, - addrs: Some(addrs), - stream: ReusableBoxFuture::new(connect(addr, local_addr)), - } +impl ConnectFuture { + fn poll_connect( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, ConnectError>> { + match self { + ConnectFuture::Resolve(ref mut fut) => { + Pin::new(fut).poll(cx).map_ok(ConnectOutput::Resolved) + } + ConnectFuture::Connect(ref mut fut) => { + Pin::new(fut).poll(cx).map_ok(ConnectOutput::Connected) } } } } -impl Future for TcpConnectorResponse { +pub struct ConnectServiceResponse { + fut: ConnectFuture, + tcp: TcpConnectorService, +} + +impl Future for ConnectServiceResponse { type Output = Result, ConnectError>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.get_mut() { - TcpConnectorResponse::Error(err) => Poll::Ready(Err(err.take().unwrap())), - - TcpConnectorResponse::Response { - req, - port, - local_addr, - addrs, - stream, - } => loop { - match ready!(stream.poll(cx)) { - Ok(sock) => { - let req = req.take().unwrap(); - trace!( - "TCP connector: successfully connected to {:?} - {:?}", - req.hostname(), - sock.peer_addr() - ); - return Poll::Ready(Ok(Connection::new(sock, req))); - } - - Err(err) => { - trace!( - "TCP connector: failed to connect to {:?} port: {}", - req.as_ref().unwrap().hostname(), - port, - ); - - if let Some(addr) = addrs.as_mut().and_then(|addrs| addrs.pop_front()) { - stream.set(connect(addr, *local_addr)); - } else { - return Poll::Ready(Err(ConnectError::Io(err))); - } - } + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + match ready!(self.fut.poll_connect(cx))? { + ConnectOutput::Resolved(res) => { + self.fut = ConnectFuture::Connect(self.tcp.call(res)); } - }, + ConnectOutput::Connected(res) => return Poll::Ready(Ok(res)), + } } } } - -async fn connect(addr: SocketAddr, local_addr: Option) -> io::Result { - // use local addr if connect asks for it. - match local_addr { - Some(ip_addr) => { - let socket = match ip_addr { - IpAddr::V4(ip_addr) => { - let socket = TcpSocket::new_v4()?; - let addr = SocketAddr::V4(SocketAddrV4::new(ip_addr, 0)); - socket.bind(addr)?; - socket - } - IpAddr::V6(ip_addr) => { - let socket = TcpSocket::new_v6()?; - let addr = SocketAddr::V6(SocketAddrV6::new(ip_addr, 0, 0, 0)); - socket.bind(addr)?; - socket - } - }; - - socket.connect(addr).await - } - - None => TcpStream::connect(addr).await, - } -} diff --git a/actix-tls/src/connect/info.rs b/actix-tls/src/connect/info.rs new file mode 100755 index 00000000..62f6cfa9 --- /dev/null +++ b/actix-tls/src/connect/info.rs @@ -0,0 +1,257 @@ +//! Connection info struct. + +use std::{ + collections::VecDeque, + fmt, + iter::{self, FromIterator as _}, + mem, + net::{IpAddr, SocketAddr}, +}; + +use super::{ + connect_addrs::{ConnectAddrs, ConnectAddrsIter}, + Address, +}; + +/// Connection request information. +/// +/// May contain known/pre-resolved socket address(es) or a host that needs resolving with DNS. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ConnectionInfo { + pub(crate) req: R, + pub(crate) port: u16, + pub(crate) addr: ConnectAddrs, + pub(crate) local_addr: Option, +} + +impl ConnectionInfo { + /// Create `Connect` instance by splitting the host at ':' and convert the second part to u16. + // TODO: assess usage and find nicer API + pub fn new(req: R) -> ConnectionInfo { + let (_, port) = parse_host(req.hostname()); + + ConnectionInfo { + req, + port: port.unwrap_or(0), + addr: ConnectAddrs::None, + local_addr: None, + } + } + + /// Create new `Connect` instance from host and socket address. + /// + /// Since socket address is known, Connector will skip name resolution stage. + pub fn with_addr(req: R, addr: SocketAddr) -> ConnectionInfo { + ConnectionInfo { + req, + port: 0, + addr: ConnectAddrs::One(addr), + local_addr: None, + } + } + + /// Set port if address does not provide one. + pub fn set_port(mut self, port: u16) -> Self { + self.port = port; + self + } + + /// Set connect address. + pub fn set_addr(mut self, addr: impl Into>) -> Self { + self.addr = ConnectAddrs::from(addr.into()); + self + } + + /// Set list of addresses. + pub fn set_addrs(mut self, addrs: I) -> Self + where + I: IntoIterator, + { + let mut addrs = VecDeque::from_iter(addrs); + self.addr = if addrs.len() < 2 { + ConnectAddrs::from(addrs.pop_front()) + } else { + ConnectAddrs::Multi(addrs) + }; + self + } + + /// Set local_addr of connect. + pub fn set_local_addr(mut self, addr: impl Into) -> Self { + self.local_addr = Some(addr.into()); + self + } + + /// Get hostname. + pub fn hostname(&self) -> &str { + self.req.hostname() + } + + /// Get request port. + pub fn port(&self) -> u16 { + self.req.port().unwrap_or(self.port) + } + + /** + Get resolved request addresses. + + # Examples + ``` + # use std::net::SocketAddr; + # use actix_tls::connect::ConnectionInfo; + let addr = SocketAddr::from(([127, 0, 0, 1], 4242)); + + let conn = ConnectionInfo::with_addr("localhost").set_addr(None); + let mut addrs = conn.addrs(); + assert!(addrs.next().is_none()); + ``` + */ + pub fn addrs( + &self, + ) -> impl Iterator + + ExactSizeIterator + + iter::FusedIterator + + Clone + + fmt::Debug + + '_ { + match self.addr { + ConnectAddrs::None => ConnectAddrsIter::None, + ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), + ConnectAddrs::Multi(ref addrs) => ConnectAddrsIter::Multi(addrs.iter()), + } + } + + /** + Take resolved request addresses. + + # Examples + ``` + + ``` + */ + pub fn take_addrs( + &mut self, + ) -> impl Iterator + + ExactSizeIterator + + iter::FusedIterator + + Clone + + fmt::Debug + + 'static { + match mem::take(&mut self.addr) { + ConnectAddrs::None => ConnectAddrsIter::None, + ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), + ConnectAddrs::Multi(addrs) => ConnectAddrsIter::MultiOwned(addrs.into_iter()), + } + } + + /// Returns a reference to the connection request. + pub fn request(&self) -> &R { + &self.req + } +} + +impl From for ConnectionInfo { + fn from(addr: R) -> Self { + ConnectionInfo::new(addr) + } +} + +impl fmt::Display for ConnectionInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}:{}", self.hostname(), self.port()) + } +} + +fn parse_host(host: &str) -> (&str, Option) { + let mut parts_iter = host.splitn(2, ':'); + + match parts_iter.next() { + Some(hostname) => { + let port_str = parts_iter.next().unwrap_or(""); + let port = port_str.parse::().ok(); + (hostname, port) + } + + None => (host, None), + } +} + +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + + use super::*; + + #[test] + fn test_host_parser() { + assert_eq!(parse_host("example.com"), ("example.com", None)); + assert_eq!(parse_host("example.com:8080"), ("example.com", Some(8080))); + assert_eq!(parse_host("example:8080"), ("example", Some(8080))); + assert_eq!(parse_host("example.com:false"), ("example.com", None)); + assert_eq!(parse_host("example.com:false:false"), ("example.com", None)); + } + + #[test] + fn test_addr_iter_multi() { + let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080)); + let unspecified = SocketAddr::from((IpAddr::from(Ipv4Addr::UNSPECIFIED), 8080)); + + let mut addrs = VecDeque::new(); + addrs.push_back(localhost); + addrs.push_back(unspecified); + + let mut iter = ConnectAddrsIter::Multi(addrs.iter()); + assert_eq!(iter.next(), Some(localhost)); + assert_eq!(iter.next(), Some(unspecified)); + assert_eq!(iter.next(), None); + + let mut iter = ConnectAddrsIter::MultiOwned(addrs.into_iter()); + assert_eq!(iter.next(), Some(localhost)); + assert_eq!(iter.next(), Some(unspecified)); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_addr_iter_single() { + let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080)); + + let mut iter = ConnectAddrsIter::One(localhost); + assert_eq!(iter.next(), Some(localhost)); + assert_eq!(iter.next(), None); + + let mut iter = ConnectAddrsIter::None; + assert_eq!(iter.next(), None); + } + + #[test] + fn test_local_addr() { + let conn = ConnectionInfo::new("hello").set_local_addr([127, 0, 0, 1]); + assert_eq!( + conn.local_addr.unwrap(), + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) + ) + } + + #[test] + fn request_ref() { + let conn = ConnectionInfo::new("hello"); + assert_eq!(conn.request(), &"hello") + } + + #[test] + fn set_connect_addr_into_option() { + let addr = SocketAddr::from(([127, 0, 0, 1], 4242)); + + let conn = ConnectionInfo::new("hello").set_addr(None); + let mut addrs = conn.addrs(); + assert!(addrs.next().is_none()); + + let conn = ConnectionInfo::new("hello").set_addr(addr); + let mut addrs = conn.addrs(); + assert_eq!(addrs.next().unwrap(), addr); + + let conn = ConnectionInfo::new("hello").set_addr(Some(addr)); + let mut addrs = conn.addrs(); + assert_eq!(addrs.next().unwrap(), addr); + } +} diff --git a/actix-tls/src/connect/mod.rs b/actix-tls/src/connect/mod.rs index c2aee9c7..8ff96275 100644 --- a/actix-tls/src/connect/mod.rs +++ b/actix-tls/src/connect/mod.rs @@ -1,35 +1,48 @@ //! TCP and TLS connector services. //! //! # Stages of the TCP connector service: -//! - Resolve [`Address`] with given [`Resolver`] and collect list of socket addresses. -//! - Establish TCP connection and return [`TcpStream`]. +//! 1. Resolve [`Address`] with given [`Resolver`] and collect list of socket addresses. +//! 1. Establish TCP connection and return [`TcpStream`]. //! //! # Stages of TLS connector services: -//! - Establish [`TcpStream`] with connector service. -//! - Wrap the stream and perform connect handshake with remote peer. -//! - Return certain stream type that impls `AsyncRead` and `AsyncWrite`. +//! 1. Resolve DNS and establish a [`TcpStream`] with the TCP connector service. +//! 1. Wrap the stream and perform connect handshake with remote peer. +//! 1. Return wrapped stream type that implements [`AsyncRead`] and [`AsyncWrite`]. //! //! [`TcpStream`]: actix_rt::net::TcpStream +//! [`AsyncRead`]: actix_rt::net::AsyncRead +//! [`AsyncWrite`]: actix_rt::net::AsyncWrite -#[allow(clippy::module_inception)] -mod connect; +mod address; +mod connect_addrs; +mod connection; mod connector; mod error; +mod info; mod resolve; -mod service; -pub mod tls; -// TODO: remove `ssl` mod re-export in next break change -#[doc(hidden)] -pub use tls as ssl; -mod tcp; +mod resolver; +pub mod tcp; + #[cfg(feature = "uri")] +#[cfg_attr(docsrs, doc(cfg(feature = "uri")))] mod uri; -pub use self::connect::{Address, Connect, Connection}; -pub use self::connector::{TcpConnector, TcpConnectorFactory}; +#[cfg(feature = "openssl")] +#[cfg_attr(docsrs, doc(cfg(feature = "openssl")))] +pub mod openssl; + +#[cfg(feature = "rustls")] +#[cfg_attr(docsrs, doc(cfg(feature = "rustls")))] +pub mod rustls; + +#[cfg(feature = "native-tls")] +#[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))] +pub mod native_tls; + +pub use self::address::Address; +pub use self::connection::Connection; +pub use self::connector::{Connector, ConnectorService}; pub use self::error::ConnectError; -pub use self::resolve::{Resolve, Resolver, ResolverFactory}; -pub use self::service::{ConnectService, ConnectServiceFactory}; -pub use self::tcp::{ - default_connector, default_connector_factory, new_connector, new_connector_factory, -}; +pub use self::info::ConnectionInfo; +pub use self::resolve::Resolve; +pub use self::resolver::{Resolver, ResolverService}; diff --git a/actix-tls/src/connect/native_tls.rs b/actix-tls/src/connect/native_tls.rs new file mode 100644 index 00000000..a325562d --- /dev/null +++ b/actix-tls/src/connect/native_tls.rs @@ -0,0 +1,90 @@ +//! Native-TLS based connector service. +//! +//! See [`Connector`] for main connector service factory docs. + +use std::io; + +use actix_rt::net::ActixStream; +use actix_service::{Service, ServiceFactory}; +use actix_utils::future::{ok, Ready}; +use futures_core::future::LocalBoxFuture; +use log::trace; +use tokio_native_tls::{ + native_tls::TlsConnector as NativeTlsConnector, TlsConnector as TokioNativeTlsConnector, + TlsStream, +}; + +use crate::connect::{Address, Connection}; + +pub mod reexports { + //! Re-exports from `native-tls` that are useful for connectors. + + pub use tokio_native_tls::native_tls::TlsConnector; +} + +/// Connector service and factory using `native-tls`. +#[derive(Clone)] +pub struct TlsConnector { + connector: TokioNativeTlsConnector, +} + +impl TlsConnector { + /// Constructs new connector service from a `native-tls` connector. + /// + /// This type is it's own service factory, so it can be used in that setting, too. + pub fn new(connector: NativeTlsConnector) -> Self { + Self { + connector: TokioNativeTlsConnector::from(connector), + } + } +} + +impl ServiceFactory> for TlsConnector +where + IO: ActixStream + 'static, +{ + type Response = Connection>; + type Error = io::Error; + type Config = (); + type Service = Self; + type InitError = (); + type Future = Ready>; + + fn new_service(&self, _: ()) -> Self::Future { + ok(self.clone()) + } +} + +/// The `native-tls` connector is both it's ServiceFactory and Service impl type. +/// As the factory and service share the same type and state. +impl Service> for TlsConnector +where + R: Address, + IO: ActixStream + 'static, +{ + type Response = Connection>; + type Error = io::Error; + type Future = LocalBoxFuture<'static, Result>; + + actix_service::always_ready!(); + + fn call(&self, stream: Connection) -> Self::Future { + let (io, stream) = stream.replace_io(()); + let connector = self.connector.clone(); + + Box::pin(async move { + trace!("SSL Handshake start for: {:?}", stream.hostname()); + connector + .connect(stream.hostname(), io) + .await + .map(|res| { + trace!("SSL Handshake success: {:?}", stream.hostname()); + stream.replace_io(res).1 + }) + .map_err(|e| { + trace!("SSL Handshake error: {:?}", e); + io::Error::new(io::ErrorKind::Other, format!("{}", e)) + }) + }) + } +} diff --git a/actix-tls/src/connect/tls/openssl.rs b/actix-tls/src/connect/openssl.rs similarity index 59% rename from actix-tls/src/connect/tls/openssl.rs rename to actix-tls/src/connect/openssl.rs index 6048e0ab..69e0eda8 100755 --- a/actix-tls/src/connect/tls/openssl.rs +++ b/actix-tls/src/connect/openssl.rs @@ -1,3 +1,7 @@ +//! OpenSSL based connector service. +//! +//! See [`Connector`] for main connector service factory docs. + use std::{ future::Future, io, @@ -7,30 +11,38 @@ use std::{ use actix_rt::net::ActixStream; use actix_service::{Service, ServiceFactory}; -use futures_core::{future::LocalBoxFuture, ready}; +use actix_utils::future::{ok, Ready}; +use futures_core::ready; use log::trace; - -pub use openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; -pub use tokio_openssl::SslStream; +use openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; +use tokio_openssl::SslStream; use crate::connect::{Address, Connection}; -/// OpenSSL connector factory -pub struct OpensslConnector { +pub mod reexports { + //! Re-exports from `openssl` that are useful for connectors. + + pub use openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; +} + +/// Connector service factory using `openssl`. +pub struct Connector { connector: SslConnector, } -impl OpensslConnector { +impl Connector { + /// Constructs new connector service factory from an `openssl` connector. pub fn new(connector: SslConnector) -> Self { - OpensslConnector { connector } + Connector { connector } } - pub fn service(connector: SslConnector) -> OpensslConnectorService { - OpensslConnectorService { connector } + /// Constructs new connector service from an `openssl` connector. + pub fn service(connector: SslConnector) -> ConnectorService { + ConnectorService { connector } } } -impl Clone for OpensslConnector { +impl Clone for Connector { fn clone(&self) -> Self { Self { connector: self.connector.clone(), @@ -38,7 +50,7 @@ impl Clone for OpensslConnector { } } -impl ServiceFactory> for OpensslConnector +impl ServiceFactory> for Connector where R: Address, IO: ActixStream + 'static, @@ -46,21 +58,23 @@ where type Response = Connection>; type Error = io::Error; type Config = (); - type Service = OpensslConnectorService; + type Service = ConnectorService; type InitError = (); - type Future = LocalBoxFuture<'static, Result>; + type Future = Ready>; fn new_service(&self, _: ()) -> Self::Future { - let connector = self.connector.clone(); - Box::pin(async { Ok(OpensslConnectorService { connector }) }) + ok(ConnectorService { + connector: self.connector.clone(), + }) } } -pub struct OpensslConnectorService { +/// Connector service using `openssl`. +pub struct ConnectorService { connector: SslConnector, } -impl Clone for OpensslConnectorService { +impl Clone for ConnectorService { fn clone(&self) -> Self { Self { connector: self.connector.clone(), @@ -68,21 +82,21 @@ impl Clone for OpensslConnectorService { } } -impl Service> for OpensslConnectorService +impl Service> for ConnectorService where R: Address, IO: ActixStream, { type Response = Connection>; type Error = io::Error; - type Future = ConnectAsyncExt; + type Future = ConnectFut; actix_service::always_ready!(); fn call(&self, stream: Connection) -> Self::Future { - trace!("SSL Handshake start for: {:?}", stream.host()); + trace!("SSL Handshake start for: {:?}", stream.hostname()); let (io, stream) = stream.replace_io(()); - let host = stream.host(); + let host = stream.hostname(); let config = self .connector @@ -93,19 +107,21 @@ where .into_ssl(host) .expect("SSL connect configuration was invalid."); - ConnectAsyncExt { + ConnectFut { io: Some(SslStream::new(ssl, io).unwrap()), stream: Some(stream), } } } -pub struct ConnectAsyncExt { +/// Connect future for OpenSSL service. +#[doc(hidden)] +pub struct ConnectFut { io: Option>, stream: Option>, } -impl Future for ConnectAsyncExt +impl Future for ConnectFut where R: Address, IO: ActixStream, @@ -118,7 +134,7 @@ where match ready!(Pin::new(this.io.as_mut().unwrap()).poll_connect(cx)) { Ok(_) => { let stream = this.stream.take().unwrap(); - trace!("SSL Handshake success: {:?}", stream.host()); + trace!("SSL Handshake success: {:?}", stream.hostname()); Poll::Ready(Ok(stream.replace_io(this.io.take().unwrap()).1)) } Err(e) => { diff --git a/actix-tls/src/connect/resolve.rs b/actix-tls/src/connect/resolve.rs old mode 100755 new mode 100644 index c9639bcf..5e79f63c --- a/actix-tls/src/connect/resolve.rs +++ b/actix-tls/src/connect/resolve.rs @@ -1,54 +1,10 @@ -use std::{ - future::Future, - io, - net::SocketAddr, - pin::Pin, - rc::Rc, - task::{Context, Poll}, - vec::IntoIter, -}; +//! [`Resolve`] trait. -use actix_rt::task::{spawn_blocking, JoinHandle}; -use actix_service::{Service, ServiceFactory}; -use futures_core::{future::LocalBoxFuture, ready}; -use log::trace; +use std::{error::Error as StdError, net::SocketAddr}; -use super::connect::{Address, Connect}; -use super::error::ConnectError; +use futures_core::future::LocalBoxFuture; -/// DNS resolver service factory. -#[derive(Clone)] -pub struct ResolverFactory { - resolver: Resolver, -} - -impl ResolverFactory { - /// Constructs a new resolver factory with the given resolver. - pub fn new(resolver: Resolver) -> Self { - Self { resolver } - } - - /// Returns a reference to the inner resolver. - pub fn service(&self) -> Resolver { - self.resolver.clone() - } -} - -impl ServiceFactory> for ResolverFactory { - type Response = Connect; - type Error = ConnectError; - type Config = (); - type Service = Resolver; - type InitError = (); - type Future = LocalBoxFuture<'static, Result>; - - fn new_service(&self, _: ()) -> Self::Future { - let service = self.resolver.clone(); - Box::pin(async { Ok(service) }) - } -} - -/// An interface for custom async DNS resolvers. +/// Custom async DNS resolvers. /// /// # Usage /// ``` @@ -105,158 +61,5 @@ pub trait Resolve { &'a self, host: &'a str, port: u16, - ) -> LocalBoxFuture<'a, Result, Box>>; -} - -/// DNS resolver service -#[derive(Clone)] -pub enum Resolver { - /// Built-in DNS resolver. - /// - /// See [`std::net::ToSocketAddrs`] trait. - Default, - - /// Custom, user-provided DNS resolver. - Custom(Rc), -} - -impl Default for Resolver { - fn default() -> Self { - Self::Default - } -} - -impl Resolver { - /// Constructor for custom Resolve trait object and use it as resolver. - pub fn new_custom(resolver: impl Resolve + 'static) -> Self { - Self::Custom(Rc::new(resolver)) - } - - /// Resolve DNS with default resolver. - fn look_up(req: &Connect) -> JoinHandle>> { - let host = req.hostname(); - // TODO: Connect should always return host(name?) with port if possible; basically try to - // reduce ability to create conflicting lookup info by having port in host string being - // different from numeric port in connect - - let host = if req - .hostname() - .split_once(':') - .and_then(|(_, port)| port.parse::().ok()) - .map(|port| port == req.port()) - .unwrap_or(false) - { - // if hostname contains port and also matches numeric port then just use the hostname - host.to_string() - } else { - // concatenate domain-only hostname and port together - format!("{}:{}", host, req.port()) - }; - - // run blocking DNS lookup in thread pool since DNS lookups can take upwards of seconds on - // some platforms if conditions are poor and OS-level cache is not populated - spawn_blocking(move || std::net::ToSocketAddrs::to_socket_addrs(&host)) - } -} - -impl Service> for Resolver { - type Response = Connect; - type Error = ConnectError; - type Future = ResolverFuture; - - actix_service::always_ready!(); - - fn call(&self, req: Connect) -> Self::Future { - if req.addr.is_some() { - ResolverFuture::Connected(Some(req)) - } else if let Ok(ip) = req.hostname().parse() { - let addr = SocketAddr::new(ip, req.port()); - let req = req.set_addr(Some(addr)); - ResolverFuture::Connected(Some(req)) - } else { - trace!("DNS resolver: resolving host {:?}", req.hostname()); - - match self { - Self::Default => { - let fut = Self::look_up(&req); - ResolverFuture::LookUp(fut, Some(req)) - } - - Self::Custom(resolver) => { - let resolver = Rc::clone(resolver); - ResolverFuture::LookupCustom(Box::pin(async move { - let addrs = resolver - .lookup(req.hostname(), req.port()) - .await - .map_err(ConnectError::Resolver)?; - - let req = req.set_addrs(addrs); - - if req.addr.is_none() { - Err(ConnectError::NoRecords) - } else { - Ok(req) - } - })) - } - } - } - } -} - -pub enum ResolverFuture { - Connected(Option>), - LookUp( - JoinHandle>>, - Option>, - ), - LookupCustom(LocalBoxFuture<'static, Result, ConnectError>>), -} - -impl Future for ResolverFuture { - type Output = Result, ConnectError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.get_mut() { - Self::Connected(conn) => Poll::Ready(Ok(conn - .take() - .expect("ResolverFuture polled after finished"))), - - Self::LookUp(fut, req) => { - let res = match ready!(Pin::new(fut).poll(cx)) { - Ok(Ok(res)) => Ok(res), - Ok(Err(e)) => Err(ConnectError::Resolver(Box::new(e))), - Err(e) => Err(ConnectError::Io(e.into())), - }; - - let req = req.take().unwrap(); - - let addrs = res.map_err(|err| { - trace!( - "DNS resolver: failed to resolve host {:?} err: {:?}", - req.hostname(), - err - ); - - err - })?; - - let req = req.set_addrs(addrs); - - trace!( - "DNS resolver: host {:?} resolved to {:?}", - req.hostname(), - req.addrs() - ); - - if req.addr.is_none() { - Poll::Ready(Err(ConnectError::NoRecords)) - } else { - Poll::Ready(Ok(req)) - } - } - - Self::LookupCustom(fut) => fut.as_mut().poll(cx), - } - } + ) -> LocalBoxFuture<'a, Result, Box>>; } diff --git a/actix-tls/src/connect/resolver.rs b/actix-tls/src/connect/resolver.rs new file mode 100755 index 00000000..1691a16b --- /dev/null +++ b/actix-tls/src/connect/resolver.rs @@ -0,0 +1,212 @@ +use std::{ + future::Future, + io, + net::SocketAddr, + pin::Pin, + rc::Rc, + task::{Context, Poll}, + vec::IntoIter, +}; + +use actix_rt::task::{spawn_blocking, JoinHandle}; +use actix_service::{Service, ServiceFactory}; +use actix_utils::future::{ok, Ready}; +use futures_core::{future::LocalBoxFuture, ready}; +use log::trace; + +use super::{Address, ConnectError, ConnectionInfo, Resolve}; + +/// DNS resolver service factory. +#[derive(Clone, Default)] +pub struct Resolver { + resolver: ResolverService, +} + +impl Resolver { + /// Constructs a new resolver factory with a custom resolver. + pub fn custom(resolver: impl Resolve + 'static) -> Self { + Self { + resolver: ResolverService::custom(resolver), + } + } + + /// Returns a new resolver service. + pub fn service(&self) -> ResolverService { + self.resolver.clone() + } +} + +impl ServiceFactory> for Resolver { + type Response = ConnectionInfo; + type Error = ConnectError; + type Config = (); + type Service = ResolverService; + type InitError = (); + type Future = Ready>; + + fn new_service(&self, _: ()) -> Self::Future { + ok(self.resolver.clone()) + } +} + +#[derive(Clone)] +enum ResolverKind { + /// Built-in DNS resolver. + /// + /// See [`std::net::ToSocketAddrs`] trait. + Default, + + /// Custom, user-provided DNS resolver. + Custom(Rc), +} + +impl Default for ResolverKind { + fn default() -> Self { + Self::Default + } +} + +/// DNS resolver service. +#[derive(Clone, Default)] +pub struct ResolverService { + kind: ResolverKind, +} + +impl ResolverService { + /// Constructor for custom Resolve trait object and use it as resolver. + pub fn custom(resolver: impl Resolve + 'static) -> Self { + Self { + kind: ResolverKind::Custom(Rc::new(resolver)), + } + } + + /// Resolve DNS with default resolver. + fn look_up( + req: &ConnectionInfo, + ) -> JoinHandle>> { + let host = req.hostname(); + // TODO: Connect should always return host(name?) with port if possible; basically try to + // reduce ability to create conflicting lookup info by having port in host string being + // different from numeric port in connect + + let host = if req + .hostname() + .split_once(':') + .and_then(|(_, port)| port.parse::().ok()) + .map(|port| port == req.port()) + .unwrap_or(false) + { + // if hostname contains port and also matches numeric port then just use the hostname + host.to_string() + } else { + // concatenate domain-only hostname and port together + format!("{}:{}", host, req.port()) + }; + + // run blocking DNS lookup in thread pool since DNS lookups can take upwards of seconds on + // some platforms if conditions are poor and OS-level cache is not populated + spawn_blocking(move || std::net::ToSocketAddrs::to_socket_addrs(&host)) + } +} + +impl Service> for ResolverService { + type Response = ConnectionInfo; + type Error = ConnectError; + type Future = ResolverFuture; + + actix_service::always_ready!(); + + fn call(&self, req: ConnectionInfo) -> Self::Future { + if req.addr.is_some() { + ResolverFuture::Connected(Some(req)) + } else if let Ok(ip) = req.hostname().parse() { + let addr = SocketAddr::new(ip, req.port()); + let req = req.set_addr(Some(addr)); + ResolverFuture::Connected(Some(req)) + } else { + trace!("DNS resolver: resolving host {:?}", req.hostname()); + + match &self.kind { + ResolverKind::Default => { + let fut = Self::look_up(&req); + ResolverFuture::LookUp(fut, Some(req)) + } + + ResolverKind::Custom(resolver) => { + let resolver = Rc::clone(resolver); + ResolverFuture::LookupCustom(Box::pin(async move { + let addrs = resolver + .lookup(req.hostname(), req.port()) + .await + .map_err(ConnectError::Resolver)?; + + let req = req.set_addrs(addrs); + + if req.addr.is_none() { + Err(ConnectError::NoRecords) + } else { + Ok(req) + } + })) + } + } + } + } +} + +pub enum ResolverFuture { + Connected(Option>), + LookUp( + JoinHandle>>, + Option>, + ), + LookupCustom(LocalBoxFuture<'static, Result, ConnectError>>), +} + +impl Future for ResolverFuture { + type Output = Result, ConnectError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.get_mut() { + Self::Connected(conn) => Poll::Ready(Ok(conn + .take() + .expect("ResolverFuture polled after finished"))), + + Self::LookUp(fut, req) => { + let res = match ready!(Pin::new(fut).poll(cx)) { + Ok(Ok(res)) => Ok(res), + Ok(Err(e)) => Err(ConnectError::Resolver(Box::new(e))), + Err(e) => Err(ConnectError::Io(e.into())), + }; + + let req = req.take().unwrap(); + + let addrs = res.map_err(|err| { + trace!( + "DNS resolver: failed to resolve host {:?} err: {:?}", + req.hostname(), + err + ); + + err + })?; + + let req = req.set_addrs(addrs); + + trace!( + "DNS resolver: host {:?} resolved to {:?}", + req.hostname(), + req.addrs() + ); + + if req.addr.is_none() { + Poll::Ready(Err(ConnectError::NoRecords)) + } else { + Poll::Ready(Ok(req)) + } + } + + Self::LookupCustom(fut) => fut.as_mut().poll(cx), + } + } +} diff --git a/actix-tls/src/connect/tls/rustls.rs b/actix-tls/src/connect/rustls.rs similarity index 63% rename from actix-tls/src/connect/tls/rustls.rs rename to actix-tls/src/connect/rustls.rs index e621b8a0..14329b14 100755 --- a/actix-tls/src/connect/tls/rustls.rs +++ b/actix-tls/src/connect/rustls.rs @@ -1,3 +1,7 @@ +//! Rustls based connector service. +//! +//! See [`Connector`] for main connector service factory docs. + use std::{ convert::TryFrom, future::Future, @@ -7,18 +11,26 @@ use std::{ task::{Context, Poll}, }; -pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig}; -pub use webpki_roots::TLS_SERVER_ROOTS; - use actix_rt::net::ActixStream; use actix_service::{Service, ServiceFactory}; -use futures_core::{future::LocalBoxFuture, ready}; +use actix_utils::future::{ok, Ready}; +use futures_core::ready; use log::trace; use tokio_rustls::rustls::{client::ServerName, OwnedTrustAnchor, RootCertStore}; +use tokio_rustls::{client::TlsStream, rustls::ClientConfig}; use tokio_rustls::{Connect as RustlsConnect, TlsConnector as RustlsTlsConnector}; +use webpki_roots::TLS_SERVER_ROOTS; use crate::connect::{Address, Connection}; +pub mod reexports { + //! Re-exports from `rustls` and `webpki_roots` that are useful for connectors. + + pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig}; + + pub use webpki_roots::TLS_SERVER_ROOTS; +} + /// Returns standard root certificates from `webpki-roots` crate as a rustls certificate store. pub fn webpki_roots_cert_store() -> RootCertStore { let mut root_certs = RootCertStore::empty(); @@ -34,32 +46,25 @@ pub fn webpki_roots_cert_store() -> RootCertStore { root_certs } -/// Rustls connector factory -pub struct RustlsConnector { +/// Connector service factory using `rustls`. +#[derive(Clone)] +pub struct Connector { connector: Arc, } -impl RustlsConnector { +impl Connector { + /// Constructs new connector service factory from a `rustls` client configuration. pub fn new(connector: Arc) -> Self { - RustlsConnector { connector } + Connector { connector } + } + + /// Constructs new connector service from a `rustls` client configuration. + pub fn service(connector: Arc) -> ConnectorService { + ConnectorService { connector } } } -impl RustlsConnector { - pub fn service(connector: Arc) -> RustlsConnectorService { - RustlsConnectorService { connector } - } -} - -impl Clone for RustlsConnector { - fn clone(&self) -> Self { - Self { - connector: self.connector.clone(), - } - } -} - -impl ServiceFactory> for RustlsConnector +impl ServiceFactory> for Connector where R: Address, IO: ActixStream + 'static, @@ -67,54 +72,51 @@ where type Response = Connection>; type Error = io::Error; type Config = (); - type Service = RustlsConnectorService; + type Service = ConnectorService; type InitError = (); - type Future = LocalBoxFuture<'static, Result>; + type Future = Ready>; fn new_service(&self, _: ()) -> Self::Future { - let connector = self.connector.clone(); - Box::pin(async { Ok(RustlsConnectorService { connector }) }) + ok(ConnectorService { + connector: self.connector.clone(), + }) } } -pub struct RustlsConnectorService { +/// Connector service using `rustls`. +#[derive(Clone)] +pub struct ConnectorService { connector: Arc, } -impl Clone for RustlsConnectorService { - fn clone(&self) -> Self { - Self { - connector: self.connector.clone(), - } - } -} - -impl Service> for RustlsConnectorService +impl Service> for ConnectorService where R: Address, IO: ActixStream, { type Response = Connection>; type Error = io::Error; - type Future = RustlsConnectorServiceFuture; + type Future = ConnectFut; actix_service::always_ready!(); fn call(&self, connection: Connection) -> Self::Future { - trace!("SSL Handshake start for: {:?}", connection.host()); + trace!("SSL Handshake start for: {:?}", connection.hostname()); let (stream, connection) = connection.replace_io(()); - match ServerName::try_from(connection.host()) { - Ok(host) => RustlsConnectorServiceFuture::Future { + match ServerName::try_from(connection.hostname()) { + Ok(host) => ConnectFut::Future { connect: RustlsTlsConnector::from(self.connector.clone()).connect(host, stream), connection: Some(connection), }, - Err(_) => RustlsConnectorServiceFuture::InvalidDns, + Err(_) => ConnectFut::InvalidDns, } } } -pub enum RustlsConnectorServiceFuture { +/// Connect future for Rustls service. +#[doc(hidden)] +pub enum ConnectFut { /// See issue InvalidDns, Future { @@ -123,7 +125,7 @@ pub enum RustlsConnectorServiceFuture { }, } -impl Future for RustlsConnectorServiceFuture +impl Future for ConnectFut where R: Address, IO: ActixStream, @@ -138,7 +140,7 @@ where Self::Future { connect, connection } => { let stream = ready!(Pin::new(connect).poll(cx))?; let connection = connection.take().unwrap(); - trace!("SSL Handshake success: {:?}", connection.host()); + trace!("SSL Handshake success: {:?}", connection.hostname()); Poll::Ready(Ok(connection.replace_io(stream).1)) } } diff --git a/actix-tls/src/connect/service.rs b/actix-tls/src/connect/service.rs deleted file mode 100755 index 0bfa8302..00000000 --- a/actix-tls/src/connect/service.rs +++ /dev/null @@ -1,129 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use actix_rt::net::TcpStream; -use actix_service::{Service, ServiceFactory}; -use futures_core::{future::LocalBoxFuture, ready}; - -use super::connect::{Address, Connect, Connection}; -use super::connector::{TcpConnector, TcpConnectorFactory}; -use super::error::ConnectError; -use super::resolve::{Resolver, ResolverFactory}; - -pub struct ConnectServiceFactory { - tcp: TcpConnectorFactory, - resolver: ResolverFactory, -} - -impl ConnectServiceFactory { - /// Constructs new ConnectService factory. - pub fn new(resolver: Resolver) -> Self { - ConnectServiceFactory { - tcp: TcpConnectorFactory, - resolver: ResolverFactory::new(resolver), - } - } - - /// Constructs new service. - pub fn service(&self) -> ConnectService { - ConnectService { - tcp: self.tcp.service(), - resolver: self.resolver.service(), - } - } -} - -impl Clone for ConnectServiceFactory { - fn clone(&self) -> Self { - ConnectServiceFactory { - tcp: self.tcp, - resolver: self.resolver.clone(), - } - } -} - -impl ServiceFactory> for ConnectServiceFactory { - type Response = Connection; - type Error = ConnectError; - type Config = (); - type Service = ConnectService; - type InitError = (); - type Future = LocalBoxFuture<'static, Result>; - - fn new_service(&self, _: ()) -> Self::Future { - let service = self.service(); - Box::pin(async { Ok(service) }) - } -} - -#[derive(Clone)] -pub struct ConnectService { - tcp: TcpConnector, - resolver: Resolver, -} - -impl Service> for ConnectService { - type Response = Connection; - type Error = ConnectError; - type Future = ConnectServiceResponse; - - actix_service::always_ready!(); - - fn call(&self, req: Connect) -> Self::Future { - ConnectServiceResponse { - fut: ConnectFuture::Resolve(self.resolver.call(req)), - tcp: self.tcp, - } - } -} - -// helper enum to generic over futures of resolve and connect phase. -pub(crate) enum ConnectFuture { - Resolve(>>::Future), - Connect(>>::Future), -} - -/// Helper enum to contain the future output of `ConnectFuture`. -pub(crate) enum ConnectOutput { - Resolved(Connect), - Connected(Connection), -} - -impl ConnectFuture { - fn poll_connect( - &mut self, - cx: &mut Context<'_>, - ) -> Poll, ConnectError>> { - match self { - ConnectFuture::Resolve(ref mut fut) => { - Pin::new(fut).poll(cx).map_ok(ConnectOutput::Resolved) - } - ConnectFuture::Connect(ref mut fut) => { - Pin::new(fut).poll(cx).map_ok(ConnectOutput::Connected) - } - } - } -} - -pub struct ConnectServiceResponse { - fut: ConnectFuture, - tcp: TcpConnector, -} - -impl Future for ConnectServiceResponse { - type Output = Result, ConnectError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - match ready!(self.fut.poll_connect(cx))? { - ConnectOutput::Resolved(res) => { - self.fut = ConnectFuture::Connect(self.tcp.call(res)); - } - ConnectOutput::Connected(res) => return Poll::Ready(Ok(res)), - } - } - } -} diff --git a/actix-tls/src/connect/tcp.rs b/actix-tls/src/connect/tcp.rs old mode 100644 new mode 100755 index 2efda791..0b6f890e --- a/actix-tls/src/connect/tcp.rs +++ b/actix-tls/src/connect/tcp.rs @@ -1,43 +1,201 @@ -use actix_rt::net::TcpStream; +//! TCP connector service. +//! +//! See [`TcpConnector`] for main connector service factory docs. + +use std::{ + collections::VecDeque, + future::Future, + io, + net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_rt::net::{TcpSocket, TcpStream}; use actix_service::{Service, ServiceFactory}; +use futures_core::{future::LocalBoxFuture, ready}; +use log::{error, trace}; +use tokio_util::sync::ReusableBoxFuture; -use super::{Address, Connect, ConnectError, ConnectServiceFactory, Connection, Resolver}; +use super::{ + connect_addrs::ConnectAddrs, error::ConnectError, Address, Connection, ConnectionInfo, +}; -/// Create TCP connector service. -pub fn new_connector( - resolver: Resolver, -) -> impl Service, Response = Connection, Error = ConnectError> + Clone -{ - ConnectServiceFactory::new(resolver).service() +/// TCP connector service factory. +#[derive(Debug, Copy, Clone)] +pub struct TcpConnector; + +impl TcpConnector { + /// Create TCP connector service + pub fn service(&self) -> TcpConnectorService { + TcpConnectorService + } } -/// Create TCP connector service factory. -pub fn new_connector_factory( - resolver: Resolver, -) -> impl ServiceFactory< - Connect, - Config = (), - Response = Connection, - Error = ConnectError, - InitError = (), -> + Clone { - ConnectServiceFactory::new(resolver) +impl ServiceFactory> for TcpConnector { + type Response = Connection; + type Error = ConnectError; + type Config = (); + type Service = TcpConnectorService; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; + + fn new_service(&self, _: ()) -> Self::Future { + let service = self.service(); + Box::pin(async move { Ok(service) }) + } } -/// Create TCP connector service with default parameters. -pub fn default_connector( -) -> impl Service, Response = Connection, Error = ConnectError> + Clone -{ - new_connector(Resolver::Default) +/// TCP connector service. +#[derive(Debug, Copy, Clone)] +pub struct TcpConnectorService; + +impl Service> for TcpConnectorService { + type Response = Connection; + type Error = ConnectError; + type Future = TcpConnectorFut; + + actix_service::always_ready!(); + + fn call(&self, req: ConnectionInfo) -> Self::Future { + let port = req.port(); + let ConnectionInfo { + req, + addr, + local_addr, + .. + } = req; + + TcpConnectorFut::new(req, port, local_addr, addr) + } } -/// Create TCP connector service factory with default parameters. -pub fn default_connector_factory() -> impl ServiceFactory< - Connect, - Config = (), - Response = Connection, - Error = ConnectError, - InitError = (), -> + Clone { - new_connector_factory(Resolver::Default) +/// Connect future for TCP service. +#[doc(hidden)] +pub enum TcpConnectorFut { + Response { + req: Option, + port: u16, + local_addr: Option, + addrs: Option>, + stream: ReusableBoxFuture>, + }, + + Error(Option), +} + +impl TcpConnectorFut { + pub(crate) fn new( + req: R, + port: u16, + local_addr: Option, + addr: ConnectAddrs, + ) -> TcpConnectorFut { + if addr.is_none() { + error!("TCP connector: unresolved connection address"); + return TcpConnectorFut::Error(Some(ConnectError::Unresolved)); + } + + trace!( + "TCP connector: connecting to {} on port {}", + req.hostname(), + port + ); + + match addr { + ConnectAddrs::None => unreachable!("none variant already checked"), + + ConnectAddrs::One(addr) => TcpConnectorFut::Response { + req: Some(req), + port, + local_addr, + addrs: None, + stream: ReusableBoxFuture::new(connect(addr, local_addr)), + }, + + // when resolver returns multiple socket addr for request they would be popped from + // front end of queue and returns with the first successful tcp connection. + ConnectAddrs::Multi(mut addrs) => { + let addr = addrs.pop_front().unwrap(); + + TcpConnectorFut::Response { + req: Some(req), + port, + local_addr, + addrs: Some(addrs), + stream: ReusableBoxFuture::new(connect(addr, local_addr)), + } + } + } + } +} + +impl Future for TcpConnectorFut { + type Output = Result, ConnectError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.get_mut() { + TcpConnectorFut::Error(err) => Poll::Ready(Err(err.take().unwrap())), + + TcpConnectorFut::Response { + req, + port, + local_addr, + addrs, + stream, + } => loop { + match ready!(stream.poll(cx)) { + Ok(sock) => { + let req = req.take().unwrap(); + trace!( + "TCP connector: successfully connected to {:?} - {:?}", + req.hostname(), + sock.peer_addr() + ); + return Poll::Ready(Ok(Connection::new(sock, req))); + } + + Err(err) => { + trace!( + "TCP connector: failed to connect to {:?} port: {}", + req.as_ref().unwrap().hostname(), + port, + ); + + if let Some(addr) = addrs.as_mut().and_then(|addrs| addrs.pop_front()) { + stream.set(connect(addr, *local_addr)); + } else { + return Poll::Ready(Err(ConnectError::Io(err))); + } + } + } + }, + } + } +} + +async fn connect(addr: SocketAddr, local_addr: Option) -> io::Result { + // use local addr if connect asks for it. + match local_addr { + Some(ip_addr) => { + let socket = match ip_addr { + IpAddr::V4(ip_addr) => { + let socket = TcpSocket::new_v4()?; + let addr = SocketAddr::V4(SocketAddrV4::new(ip_addr, 0)); + socket.bind(addr)?; + socket + } + IpAddr::V6(ip_addr) => { + let socket = TcpSocket::new_v6()?; + let addr = SocketAddr::V6(SocketAddrV6::new(ip_addr, 0, 0, 0)); + socket.bind(addr)?; + socket + } + }; + + socket.connect(addr).await + } + + None => TcpStream::connect(addr).await, + } } diff --git a/actix-tls/src/connect/tls/mod.rs b/actix-tls/src/connect/tls/mod.rs deleted file mode 100644 index 7f48d06c..00000000 --- a/actix-tls/src/connect/tls/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -//! TLS Services - -#[cfg(feature = "openssl")] -pub mod openssl; - -#[cfg(feature = "rustls")] -pub mod rustls; - -#[cfg(feature = "native-tls")] -pub mod native_tls; diff --git a/actix-tls/src/connect/tls/native_tls.rs b/actix-tls/src/connect/tls/native_tls.rs deleted file mode 100644 index ffb13754..00000000 --- a/actix-tls/src/connect/tls/native_tls.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::io; - -use actix_rt::net::ActixStream; -use actix_service::{Service, ServiceFactory}; -use futures_core::future::LocalBoxFuture; -use log::trace; -use tokio_native_tls::{TlsConnector as TokioNativetlsConnector, TlsStream}; - -pub use tokio_native_tls::native_tls::TlsConnector; - -use crate::connect::{Address, Connection}; - -/// Native-tls connector factory and service -pub struct NativetlsConnector { - connector: TokioNativetlsConnector, -} - -impl NativetlsConnector { - pub fn new(connector: TlsConnector) -> Self { - Self { - connector: TokioNativetlsConnector::from(connector), - } - } -} - -impl NativetlsConnector { - pub fn service(connector: TlsConnector) -> Self { - Self::new(connector) - } -} - -impl Clone for NativetlsConnector { - fn clone(&self) -> Self { - Self { - connector: self.connector.clone(), - } - } -} - -impl ServiceFactory> for NativetlsConnector -where - IO: ActixStream + 'static, -{ - type Response = Connection>; - type Error = io::Error; - type Config = (); - type Service = Self; - type InitError = (); - type Future = LocalBoxFuture<'static, Result>; - - fn new_service(&self, _: ()) -> Self::Future { - let connector = self.clone(); - Box::pin(async { Ok(connector) }) - } -} - -// NativetlsConnector is both it's ServiceFactory and Service impl type. -// As the factory and service share the same type and state. -impl Service> for NativetlsConnector -where - R: Address, - IO: ActixStream + 'static, -{ - type Response = Connection>; - type Error = io::Error; - type Future = LocalBoxFuture<'static, Result>; - - actix_service::always_ready!(); - - fn call(&self, stream: Connection) -> Self::Future { - let (io, stream) = stream.replace_io(()); - let connector = self.connector.clone(); - - Box::pin(async move { - trace!("SSL Handshake start for: {:?}", stream.host()); - connector - .connect(stream.host(), io) - .await - .map(|res| { - trace!("SSL Handshake success: {:?}", stream.host()); - stream.replace_io(res).1 - }) - .map_err(|e| { - trace!("SSL Handshake error: {:?}", e); - io::Error::new(io::ErrorKind::Other, format!("{}", e)) - }) - }) - } -} diff --git a/actix-tls/src/lib.rs b/actix-tls/src/lib.rs index f617d57b..68ca0e35 100644 --- a/actix-tls/src/lib.rs +++ b/actix-tls/src/lib.rs @@ -1,15 +1,19 @@ -//! TLS acceptor and connector services for Actix ecosystem +//! TLS acceptor and connector services for the Actix ecosystem. #![deny(rust_2018_idioms, nonstandard_style)] #![warn(missing_docs)] #![doc(html_logo_url = "https://actix.rs/img/logo.png")] #![doc(html_favicon_url = "https://actix.rs/favicon.ico")] +#![cfg_attr(docsrs, feature(doc_cfg))] #[cfg(feature = "openssl")] #[allow(unused_extern_crates)] extern crate tls_openssl as openssl; #[cfg(feature = "accept")] +#[cfg_attr(docsrs, doc(cfg(feature = "accept")))] pub mod accept; + #[cfg(feature = "connect")] +#[cfg_attr(docsrs, doc(cfg(feature = "connect")))] pub mod connect; diff --git a/actix-tls/tests/test_connect.rs b/actix-tls/tests/test_connect.rs index 564151ce..fdf468de 100755 --- a/actix-tls/tests/test_connect.rs +++ b/actix-tls/tests/test_connect.rs @@ -12,7 +12,7 @@ use actix_service::{fn_service, Service, ServiceFactory}; use bytes::Bytes; use futures_util::sink::SinkExt; -use actix_tls::connect::{self as actix_connect, Connect}; +use actix_tls::connect::{self as actix_connect, ConnectionInfo}; #[cfg(feature = "openssl")] #[actix_rt::test] @@ -61,12 +61,12 @@ async fn test_static_str() { let conn = actix_connect::default_connector(); let con = conn - .call(Connect::with_addr("10", srv.addr())) + .call(ConnectionInfo::with_addr("10", srv.addr())) .await .unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr()); - let connect = Connect::new(srv.host().to_owned()); + let connect = ConnectionInfo::new(srv.host().to_owned()); let conn = actix_connect::default_connector(); let con = conn.call(connect).await; @@ -87,7 +87,7 @@ async fn test_new_service() { let conn = factory.new_service(()).await.unwrap(); let con = conn - .call(Connect::with_addr("10", srv.addr())) + .call(ConnectionInfo::with_addr("10", srv.addr())) .await .unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr()); @@ -145,7 +145,7 @@ async fn test_local_addr() { let local = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3)); let (con, _) = conn - .call(Connect::with_addr("10", srv.addr()).set_local_addr(local)) + .call(ConnectionInfo::with_addr("10", srv.addr()).set_local_addr(local)) .await .unwrap() .into_parts(); diff --git a/actix-tls/tests/test_resolvers.rs b/actix-tls/tests/test_resolvers.rs index 40ee21fa..eddcd98f 100644 --- a/actix-tls/tests/test_resolvers.rs +++ b/actix-tls/tests/test_resolvers.rs @@ -10,7 +10,7 @@ use actix_server::TestServer; use actix_service::{fn_service, Service, ServiceFactory}; use futures_core::future::LocalBoxFuture; -use actix_tls::connect::{new_connector_factory, Connect, Resolve, Resolver}; +use actix_tls::connect::{new_connector_factory, ConnectionInfo, Resolve, ResolverService}; #[actix_rt::test] async fn custom_resolver() { @@ -68,12 +68,12 @@ async fn custom_resolver_connect() { trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(), }; - let resolver = Resolver::new_custom(resolver); + let resolver = ResolverService::custom(resolver); let factory = new_connector_factory(resolver); let conn = factory.new_service(()).await.unwrap(); let con = conn - .call(Connect::with_addr("example.com", srv.addr())) + .call(ConnectionInfo::with_addr("example.com", srv.addr())) .await .unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr());