return typed future for openssl connector.

This commit is contained in:
fakeshadow 2020-12-25 11:28:25 +08:00
parent 6c27c1598c
commit db6790a1ff
15 changed files with 191 additions and 156 deletions

View File

@ -18,8 +18,8 @@ path = "src/lib.rs"
[dependencies]
bitflags = "1.2.1"
bytes = "1"
futures-core = { version = "0.3.4", default-features = false }
futures-sink = { version = "0.3.4", default-features = false }
futures-core = { version = "0.3.7", default-features = false }
futures-sink = { version = "0.3.7", default-features = false }
log = "0.4"
pin-project = "1.0.0"
tokio = "1"

View File

@ -38,7 +38,7 @@ actix-rt = "1.1.1"
derive_more = "0.99.2"
either = "1.5.3"
futures-util = { version = "0.3.4", default-features = false }
futures-util = { version = "0.3.7", default-features = false }
# FIXME: update to 0.3
http = { version = "0.2.2", optional = true }
log = "0.4"
@ -58,4 +58,4 @@ webpki = { version = "0.21", optional = true }
[dev-dependencies]
bytes = "1"
actix-testing = "1.0.0"
futures-util = { version = "0.3.4", default-features = false, features = ["sink"] }
futures-util = { version = "0.3.7", default-features = false, features = ["sink"] }

View File

@ -8,7 +8,7 @@ use std::task::{Context, Poll};
use actix_rt::net::TcpStream;
use actix_service::{Service, ServiceFactory};
use futures_util::future::{err, ok, BoxFuture, Either, FutureExt, Ready};
use futures_util::future::{ready, Ready};
use super::connect::{Address, Connect, Connection};
use super::error::ConnectError;
@ -50,7 +50,7 @@ impl<T: Address> ServiceFactory for TcpConnectorFactory<T> {
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
ok(self.service())
ready(Ok(self.service()))
}
}
@ -74,8 +74,7 @@ impl<T: Address> Service for TcpConnector<T> {
type Request = Connect<T>;
type Response = Connection<T, TcpStream>;
type Error = ConnectError;
#[allow(clippy::type_complexity)]
type Future = Either<TcpConnectorResponse<T>, Ready<Result<Self::Response, Self::Error>>>;
type Future = TcpConnectorResponse<T>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
@ -86,21 +85,26 @@ impl<T: Address> Service for TcpConnector<T> {
let Connect { req, addr, .. } = req;
if let Some(addr) = addr {
Either::Left(TcpConnectorResponse::new(req, port, addr))
TcpConnectorResponse::new(req, port, addr)
} else {
error!("TCP connector: got unresolved address");
Either::Right(err(ConnectError::Unresolved))
TcpConnectorResponse::Error(Some(ConnectError::Unresolved))
}
}
}
type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
#[doc(hidden)]
/// TCP stream connector response future
pub struct TcpConnectorResponse<T> {
req: Option<T>,
port: u16,
addrs: Option<VecDeque<SocketAddr>>,
stream: Option<BoxFuture<'static, Result<TcpStream, io::Error>>>,
pub enum TcpConnectorResponse<T> {
Response {
req: Option<T>,
port: u16,
addrs: Option<VecDeque<SocketAddr>>,
stream: Option<LocalBoxFuture<'static, Result<TcpStream, io::Error>>>,
},
Error(Option<ConnectError>),
}
impl<T: Address> TcpConnectorResponse<T> {
@ -116,13 +120,13 @@ impl<T: Address> TcpConnectorResponse<T> {
);
match addr {
either::Either::Left(addr) => TcpConnectorResponse {
either::Either::Left(addr) => TcpConnectorResponse::Response {
req: Some(req),
port,
addrs: None,
stream: Some(TcpStream::connect(addr).boxed()),
stream: Some(Box::pin(TcpStream::connect(addr))),
},
either::Either::Right(addrs) => TcpConnectorResponse {
either::Either::Right(addrs) => TcpConnectorResponse::Response {
req: Some(req),
port,
addrs: Some(addrs),
@ -137,36 +141,43 @@ impl<T: Address> Future for TcpConnectorResponse<T> {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
// connect
loop {
if let Some(new) = this.stream.as_mut() {
match new.as_mut().poll(cx) {
Poll::Ready(Ok(sock)) => {
let req = this.req.take().unwrap();
trace!(
"TCP connector - successfully connected to connecting to {:?} - {:?}",
req.host(), sock.peer_addr()
);
return Poll::Ready(Ok(Connection::new(sock, req)));
}
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => {
trace!(
"TCP connector - failed to connect to connecting to {:?} port: {}",
this.req.as_ref().unwrap().host(),
this.port,
);
if this.addrs.is_none() || this.addrs.as_ref().unwrap().is_empty() {
return Poll::Ready(Err(err.into()));
match this {
TcpConnectorResponse::Error(e) => Poll::Ready(Err(e.take().unwrap())),
// connect
TcpConnectorResponse::Response {
req,
port,
addrs,
stream,
} => loop {
if let Some(new) = stream.as_mut() {
match new.as_mut().poll(cx) {
Poll::Ready(Ok(sock)) => {
let req = req.take().unwrap();
trace!(
"TCP connector - successfully connected to connecting to {:?} - {:?}",
req.host(), sock.peer_addr()
);
return Poll::Ready(Ok(Connection::new(sock, req)));
}
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => {
trace!(
"TCP connector - failed to connect to connecting to {:?} port: {}",
req.as_ref().unwrap().host(),
port,
);
if addrs.is_none() || addrs.as_ref().unwrap().is_empty() {
return Poll::Ready(Err(err.into()));
}
}
}
}
}
// try to connect
let addr = this.addrs.as_mut().unwrap().pop_front().unwrap();
this.stream = Some(TcpStream::connect(addr).boxed());
// try to connect
let addr = addrs.as_mut().unwrap().pop_front().unwrap();
*stream = Some(Box::pin(TcpStream::connect(addr)));
},
}
}
}

View File

@ -5,7 +5,7 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use actix_service::{Service, ServiceFactory};
use futures_util::future::{ok, Either, Ready};
use futures_util::future::{ready, Ready};
use trust_dns_resolver::TokioAsyncResolver as AsyncResolver;
use trust_dns_resolver::{error::ResolveError, lookup_ip::LookupIp};
@ -64,7 +64,7 @@ impl<T: Address> ServiceFactory for ResolverFactory<T> {
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
ok(self.service())
ready(Ok(self.service()))
}
}
@ -106,11 +106,7 @@ impl<T: Address> Service for Resolver<T> {
type Request = Connect<T>;
type Response = Connect<T>;
type Error = ConnectError;
#[allow(clippy::type_complexity)]
type Future = Either<
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>,
Ready<Result<Connect<T>, Self::Error>>,
>;
type Future = ResolverServiceFuture<T>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
@ -118,13 +114,13 @@ impl<T: Address> Service for Resolver<T> {
fn call(&mut self, mut req: Connect<T>) -> Self::Future {
if req.addr.is_some() {
Either::Right(ok(req))
ResolverServiceFuture::NoLookUp(Some(req))
} else if let Ok(ip) = req.host().parse() {
req.addr = Some(either::Either::Left(SocketAddr::new(ip, req.port())));
Either::Right(ok(req))
ResolverServiceFuture::NoLookUp(Some(req))
} else {
let resolver = self.resolver.as_ref().map(AsyncResolver::clone);
Either::Left(Box::pin(async move {
ResolverServiceFuture::LookUp(Box::pin(async move {
trace!("DNS resolver: resolving host {:?}", req.host());
let resolver = if let Some(resolver) = resolver {
resolver
@ -139,13 +135,30 @@ impl<T: Address> Service for Resolver<T> {
}
}
type LookupIpFuture = Pin<Box<dyn Future<Output = Result<LookupIp, ResolveError>>>>;
type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
#[doc(hidden)]
pub enum ResolverServiceFuture<T: Address> {
NoLookUp(Option<Connect<T>>),
LookUp(LocalBoxFuture<'static, Result<Connect<T>, ConnectError>>),
}
impl<T: Address> Future for ResolverServiceFuture<T> {
type Output = Result<Connect<T>, ConnectError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut() {
Self::NoLookUp(conn) => Poll::Ready(Ok(conn.take().unwrap())),
Self::LookUp(fut) => fut.as_mut().poll(cx),
}
}
}
#[doc(hidden)]
/// Resolver future
pub struct ResolverFuture<T: Address> {
req: Option<Connect<T>>,
lookup: LookupIpFuture,
lookup: LocalBoxFuture<'static, Result<LookupIp, ResolveError>>,
}
impl<T: Address> ResolverFuture<T> {

View File

@ -5,7 +5,7 @@ use std::task::{Context, Poll};
use actix_rt::net::TcpStream;
use actix_service::{Service, ServiceFactory};
use either::Either;
use futures_util::future::{ok, Ready};
use futures_util::future::{ready, Ready};
use trust_dns_resolver::TokioAsyncResolver as AsyncResolver;
use crate::connect::{Address, Connect, Connection};
@ -80,7 +80,7 @@ impl<T: Address> ServiceFactory for ConnectServiceFactory<T> {
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
ok(self.service())
ready(Ok(self.service()))
}
}

View File

@ -10,7 +10,8 @@ pub use tokio_openssl::SslStream;
use actix_codec::{AsyncRead, AsyncWrite};
use actix_rt::net::TcpStream;
use actix_service::{Service, ServiceFactory};
use futures_util::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready};
use futures_util::future::{ready, Ready};
use futures_util::ready;
use trust_dns_resolver::TokioAsyncResolver as AsyncResolver;
use crate::{
@ -68,10 +69,10 @@ where
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
ok(OpensslConnectorService {
ready(Ok(OpensslConnectorService {
connector: self.connector.clone(),
_t: PhantomData,
})
}))
}
}
@ -97,62 +98,79 @@ where
type Request = Connection<T, U>;
type Response = Connection<T, SslStream<U>>;
type Error = io::Error;
#[allow(clippy::type_complexity)]
type Future = Either<ConnectAsyncExt<T, U>, Ready<Result<Self::Response, Self::Error>>>;
type Future = OpensslConnectorServiceFuture<T, U>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, stream: Connection<T, U>) -> Self::Future {
trace!("SSL Handshake start for: {:?}", stream.host());
let (io, stream) = stream.replace(());
let host = stream.host().to_string();
match self.connector.configure() {
Err(e) => Either::Right(err(io::Error::new(io::ErrorKind::Other, e))),
Ok(config) => Either::Left(ConnectAsyncExt {
// TODO: unbox this future.
fut: Box::pin(async move {
let ssl = config.into_ssl(host.as_str())?;
let mut io = tokio_openssl::SslStream::new(ssl, io)?;
Pin::new(&mut io).connect().await?;
Ok(io)
}),
stream: Some(stream),
_t: PhantomData,
}),
match self.ssl_stream(stream) {
Ok(acc) => OpensslConnectorServiceFuture::Accept(Some(acc)),
Err(e) => OpensslConnectorServiceFuture::Error(Some(e)),
}
}
}
pub struct ConnectAsyncExt<T, U> {
fut: LocalBoxFuture<'static, Result<SslStream<U>, SslError>>,
stream: Option<Connection<T, ()>>,
_t: PhantomData<U>,
impl<T, U> OpensslConnectorService<T, U>
where
T: Address + 'static,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
{
// construct SslStream with connector.
// At this point SslStream does not perform any I/O.
// handshake would happen later in OpensslConnectorServiceFuture
fn ssl_stream(
&self,
stream: Connection<T, U>,
) -> Result<(SslStream<U>, Connection<T, ()>), SslError> {
trace!("SSL Handshake start for: {:?}", stream.host());
let (stream, connection) = stream.replace(());
let host = connection.host().to_string();
let config = self.connector.configure()?;
let ssl = config.into_ssl(host.as_str())?;
let stream = tokio_openssl::SslStream::new(ssl, stream)?;
Ok((stream, connection))
}
}
impl<T: Address, U> Future for ConnectAsyncExt<T, U>
#[doc(hidden)]
pub enum OpensslConnectorServiceFuture<T, U>
where
T: Address + 'static,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
{
Accept(Option<(SslStream<U>, Connection<T, ()>)>),
Error(Option<SslError>),
}
impl<T, U> Future for OpensslConnectorServiceFuture<T, U>
where
T: Address,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
{
type Output = Result<Connection<T, SslStream<U>>, io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
let e = match self.get_mut() {
Self::Error(e) => e.take().unwrap(),
Self::Accept(acc) => {
let (stream, _) = acc.as_mut().unwrap();
match ready!(Pin::new(stream).poll_accept(cx)) {
Ok(()) => {
let (stream, connection) = acc.take().unwrap();
trace!("SSL Handshake success: {:?}", connection.host());
let (_, connection) = connection.replace(stream);
return Poll::Ready(Ok(connection));
}
Err(e) => e,
}
}
};
match Pin::new(&mut this.fut).poll(cx) {
Poll::Ready(Ok(stream)) => {
let s = this.stream.take().unwrap();
trace!("SSL Handshake success: {:?}", s.host());
Poll::Ready(Ok(s.replace(stream).1))
}
Poll::Ready(Err(e)) => {
trace!("SSL Handshake error: {:?}", e);
Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
}
Poll::Pending => Poll::Pending,
}
trace!("SSL Handshake error: {:?}", e);
Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
}
}
@ -209,7 +227,7 @@ impl<T: Address + 'static> ServiceFactory for OpensslConnectServiceFactory<T> {
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
ok(self.service())
ready(Ok(self.service()))
}
}

View File

@ -10,7 +10,8 @@ pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig};
use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::{Service, ServiceFactory};
use futures_util::future::{ok, Ready};
use futures_util::future::{ready, Ready};
use futures_util::ready;
use tokio_rustls::{Connect, TlsConnector};
use webpki::DNSNameRef;
@ -66,10 +67,10 @@ where
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
ok(RustlsConnectorService {
ready(Ok(RustlsConnectorService {
connector: self.connector.clone(),
_t: PhantomData,
})
}))
}
}
@ -125,12 +126,9 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
Poll::Ready(
futures_util::ready!(Pin::new(&mut this.fut).poll(cx)).map(|stream| {
let s = this.stream.take().unwrap();
trace!("SSL Handshake success: {:?}", s.host());
s.replace(stream).1
}),
)
let stream = ready!(Pin::new(&mut this.fut).poll(cx))?;
let s = this.stream.take().unwrap();
trace!("SSL Handshake success: {:?}", s.host());
Poll::Ready(Ok(s.replace(stream).1))
}
}

View File

@ -18,8 +18,8 @@ path = "src/lib.rs"
[dependencies]
actix-macros = "0.1.0"
copyless = "0.1.4"
futures-channel = "0.3.4"
futures-util = { version = "0.3.4", default-features = false, features = ["alloc"] }
futures-channel = "0.3.7"
futures-util = { version = "0.3.7", default-features = false, features = ["alloc"] }
smallvec = "1"
tokio = { version = "1", features = ["rt", "net", "signal", "time"] }

View File

@ -26,8 +26,8 @@ actix-codec = "0.3.0"
actix-utils = "2.0.0"
concurrent-queue = "1.2.2"
futures-channel = { version = "0.3.4", default-features = false }
futures-util = { version = "0.3.4", default-features = false }
futures-channel = { version = "0.3.7", default-features = false }
futures-util = { version = "0.3.7", default-features = false }
log = "0.4"
mio = { version = "0.7.3", features = [ "os-poll", "tcp", "uds"] }
num_cpus = "1.13"
@ -37,5 +37,5 @@ slab = "0.4"
actix-testing = "1.0.0"
bytes = "1"
env_logger = "0.7"
futures-util = { version = "0.3.4", default-features = false, features = ["sink"] }
futures-util = { version = "0.3.7", default-features = false, features = ["sink"] }
tokio = { version = "1", features = ["full"] }

View File

@ -19,7 +19,7 @@ path = "src/lib.rs"
[dependencies]
derive_more = "0.99.2"
futures-channel = "0.3.1"
futures-channel = "0.3.7"
parking_lot = "0.11"
lazy_static = "1.3"
log = "0.4"

View File

@ -39,7 +39,7 @@ actix-service = "1.0.0"
actix-codec = "0.3.0"
actix-utils = "2.0.0"
futures-util = { version = "0.3.4", default-features = false }
futures-util = { version = "0.3.7", default-features = false }
# openssl
open-ssl = { package = "openssl", version = "0.10", optional = true }

View File

@ -1,10 +1,12 @@
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::{Service, ServiceFactory};
use actix_utils::counter::Counter;
use futures_util::future::{self, FutureExt, LocalBoxFuture, TryFutureExt};
use futures_util::future::{ready, Ready};
pub use native_tls::Error;
pub use tokio_native_tls::{TlsAcceptor, TlsStream};
@ -50,19 +52,19 @@ where
type Request = T;
type Response = TlsStream<T>;
type Error = Error;
type Service = NativeTlsAcceptorService<T>;
type Config = ();
type Service = NativeTlsAcceptorService<T>;
type InitError = ();
type Future = future::Ready<Result<Self::Service, Self::InitError>>;
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
MAX_CONN_COUNTER.with(|conns| {
future::ok(NativeTlsAcceptorService {
ready(Ok(NativeTlsAcceptorService {
acceptor: self.acceptor.clone(),
conns: conns.clone(),
io: PhantomData,
})
}))
})
}
}
@ -83,6 +85,8 @@ impl<T> Clone for NativeTlsAcceptorService<T> {
}
}
type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
impl<T> Service for NativeTlsAcceptorService<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
@ -103,12 +107,10 @@ where
fn call(&mut self, req: Self::Request) -> Self::Future {
let guard = self.conns.get();
let this = self.clone();
async move { this.acceptor.accept(req).await }
.map_ok(move |io| {
// Required to preserve `CounterGuard` until `Self::Future` is completely resolved.
let _ = guard;
io
})
.boxed_local()
Box::pin(async move {
let res = this.acceptor.accept(req).await;
drop(guard);
res
})
}
}

View File

@ -84,9 +84,13 @@ impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Service for AcceptorService<T>
}
fn call(&mut self, req: Self::Request) -> Self::Future {
let guard = self.conns.get();
let stream = self.ssl_stream(req);
AcceptorServiceResponse::Init(Some(stream), Some(guard))
match self.ssl_stream(req) {
Ok(stream) => {
let guard = self.conns.get();
AcceptorServiceResponse::Accept(Some(stream), Some(guard))
}
Err(e) => AcceptorServiceResponse::Error(Some(e)),
}
}
}
@ -105,32 +109,21 @@ pub enum AcceptorServiceResponse<T>
where
T: AsyncRead + AsyncWrite,
{
Init(Option<Result<SslStream<T>, Error>>, Option<CounterGuard>),
Accept(Option<SslStream<T>>, Option<CounterGuard>),
Error(Option<Error>),
}
impl<T: AsyncRead + AsyncWrite + Unpin> Future for AcceptorServiceResponse<T> {
type Output = Result<SslStream<T>, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match self.as_mut().get_mut() {
// Init branch only used to return the error in future
// on success goes to Accept branch directly.
AcceptorServiceResponse::Init(res, guard) => {
let guard = guard.take();
let stream = res.take().unwrap()?;
let state = AcceptorServiceResponse::Accept(Some(stream), guard);
self.as_mut().set(state);
}
AcceptorServiceResponse::Accept(stream, guard) => {
ready!(Pin::new(stream.as_mut().unwrap()).poll_accept(cx))?;
// drop counter guard a little early as the accept has finished
guard.take();
let stream = stream.take().unwrap();
return Poll::Ready(Ok(stream));
}
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut() {
AcceptorServiceResponse::Error(e) => Poll::Ready(Err(e.take().unwrap())),
AcceptorServiceResponse::Accept(stream, guard) => {
ready!(Pin::new(stream.as_mut().unwrap()).poll_accept(cx))?;
// drop counter guard a little early as the accept has finished
guard.take();
Poll::Ready(Ok(stream.take().unwrap()))
}
}
}

View File

@ -17,7 +17,7 @@ path = "src/lib.rs"
[dependencies]
actix-service = "1.0.4"
futures-util = { version = "0.3.4", default-features = false }
futures-util = { version = "0.3.7", default-features = false }
tracing = "0.1"
tracing-futures = "0.2"

View File

@ -22,9 +22,9 @@ actix-service = "1.0.6"
bitflags = "1.2.1"
bytes = "1"
either = "1.5.3"
futures-channel = { version = "0.3.4", default-features = false }
futures-sink = { version = "0.3.4", default-features = false }
futures-util = { version = "0.3.4", default-features = false }
futures-channel = { version = "0.3.7", default-features = false }
futures-sink = { version = "0.3.7", default-features = false }
futures-util = { version = "0.3.7", default-features = false }
log = "0.4"
pin-project = "1.0.0"
slab = "0.4"