return typed future for openssl connector.

This commit is contained in:
fakeshadow 2020-12-25 11:28:25 +08:00
parent 6c27c1598c
commit db6790a1ff
15 changed files with 191 additions and 156 deletions

View File

@ -18,8 +18,8 @@ path = "src/lib.rs"
[dependencies] [dependencies]
bitflags = "1.2.1" bitflags = "1.2.1"
bytes = "1" bytes = "1"
futures-core = { version = "0.3.4", default-features = false } futures-core = { version = "0.3.7", default-features = false }
futures-sink = { version = "0.3.4", default-features = false } futures-sink = { version = "0.3.7", default-features = false }
log = "0.4" log = "0.4"
pin-project = "1.0.0" pin-project = "1.0.0"
tokio = "1" tokio = "1"

View File

@ -38,7 +38,7 @@ actix-rt = "1.1.1"
derive_more = "0.99.2" derive_more = "0.99.2"
either = "1.5.3" either = "1.5.3"
futures-util = { version = "0.3.4", default-features = false } futures-util = { version = "0.3.7", default-features = false }
# FIXME: update to 0.3 # FIXME: update to 0.3
http = { version = "0.2.2", optional = true } http = { version = "0.2.2", optional = true }
log = "0.4" log = "0.4"
@ -58,4 +58,4 @@ webpki = { version = "0.21", optional = true }
[dev-dependencies] [dev-dependencies]
bytes = "1" bytes = "1"
actix-testing = "1.0.0" actix-testing = "1.0.0"
futures-util = { version = "0.3.4", default-features = false, features = ["sink"] } futures-util = { version = "0.3.7", default-features = false, features = ["sink"] }

View File

@ -8,7 +8,7 @@ use std::task::{Context, Poll};
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_util::future::{err, ok, BoxFuture, Either, FutureExt, Ready}; use futures_util::future::{ready, Ready};
use super::connect::{Address, Connect, Connection}; use super::connect::{Address, Connect, Connection};
use super::error::ConnectError; use super::error::ConnectError;
@ -50,7 +50,7 @@ impl<T: Address> ServiceFactory for TcpConnectorFactory<T> {
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
ok(self.service()) ready(Ok(self.service()))
} }
} }
@ -74,8 +74,7 @@ impl<T: Address> Service for TcpConnector<T> {
type Request = Connect<T>; type Request = Connect<T>;
type Response = Connection<T, TcpStream>; type Response = Connection<T, TcpStream>;
type Error = ConnectError; type Error = ConnectError;
#[allow(clippy::type_complexity)] type Future = TcpConnectorResponse<T>;
type Future = Either<TcpConnectorResponse<T>, Ready<Result<Self::Response, Self::Error>>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
@ -86,21 +85,26 @@ impl<T: Address> Service for TcpConnector<T> {
let Connect { req, addr, .. } = req; let Connect { req, addr, .. } = req;
if let Some(addr) = addr { if let Some(addr) = addr {
Either::Left(TcpConnectorResponse::new(req, port, addr)) TcpConnectorResponse::new(req, port, addr)
} else { } else {
error!("TCP connector: got unresolved address"); error!("TCP connector: got unresolved address");
Either::Right(err(ConnectError::Unresolved)) TcpConnectorResponse::Error(Some(ConnectError::Unresolved))
} }
} }
} }
type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
#[doc(hidden)] #[doc(hidden)]
/// TCP stream connector response future /// TCP stream connector response future
pub struct TcpConnectorResponse<T> { pub enum TcpConnectorResponse<T> {
Response {
req: Option<T>, req: Option<T>,
port: u16, port: u16,
addrs: Option<VecDeque<SocketAddr>>, addrs: Option<VecDeque<SocketAddr>>,
stream: Option<BoxFuture<'static, Result<TcpStream, io::Error>>>, stream: Option<LocalBoxFuture<'static, Result<TcpStream, io::Error>>>,
},
Error(Option<ConnectError>),
} }
impl<T: Address> TcpConnectorResponse<T> { impl<T: Address> TcpConnectorResponse<T> {
@ -116,13 +120,13 @@ impl<T: Address> TcpConnectorResponse<T> {
); );
match addr { match addr {
either::Either::Left(addr) => TcpConnectorResponse { either::Either::Left(addr) => TcpConnectorResponse::Response {
req: Some(req), req: Some(req),
port, port,
addrs: None, addrs: None,
stream: Some(TcpStream::connect(addr).boxed()), stream: Some(Box::pin(TcpStream::connect(addr))),
}, },
either::Either::Right(addrs) => TcpConnectorResponse { either::Either::Right(addrs) => TcpConnectorResponse::Response {
req: Some(req), req: Some(req),
port, port,
addrs: Some(addrs), addrs: Some(addrs),
@ -137,13 +141,19 @@ impl<T: Address> Future for TcpConnectorResponse<T> {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut(); let this = self.get_mut();
match this {
TcpConnectorResponse::Error(e) => Poll::Ready(Err(e.take().unwrap())),
// connect // connect
loop { TcpConnectorResponse::Response {
if let Some(new) = this.stream.as_mut() { req,
port,
addrs,
stream,
} => loop {
if let Some(new) = stream.as_mut() {
match new.as_mut().poll(cx) { match new.as_mut().poll(cx) {
Poll::Ready(Ok(sock)) => { Poll::Ready(Ok(sock)) => {
let req = this.req.take().unwrap(); let req = req.take().unwrap();
trace!( trace!(
"TCP connector - successfully connected to connecting to {:?} - {:?}", "TCP connector - successfully connected to connecting to {:?} - {:?}",
req.host(), sock.peer_addr() req.host(), sock.peer_addr()
@ -154,10 +164,10 @@ impl<T: Address> Future for TcpConnectorResponse<T> {
Poll::Ready(Err(err)) => { Poll::Ready(Err(err)) => {
trace!( trace!(
"TCP connector - failed to connect to connecting to {:?} port: {}", "TCP connector - failed to connect to connecting to {:?} port: {}",
this.req.as_ref().unwrap().host(), req.as_ref().unwrap().host(),
this.port, port,
); );
if this.addrs.is_none() || this.addrs.as_ref().unwrap().is_empty() { if addrs.is_none() || addrs.as_ref().unwrap().is_empty() {
return Poll::Ready(Err(err.into())); return Poll::Ready(Err(err.into()));
} }
} }
@ -165,8 +175,9 @@ impl<T: Address> Future for TcpConnectorResponse<T> {
} }
// try to connect // try to connect
let addr = this.addrs.as_mut().unwrap().pop_front().unwrap(); let addr = addrs.as_mut().unwrap().pop_front().unwrap();
this.stream = Some(TcpStream::connect(addr).boxed()); *stream = Some(Box::pin(TcpStream::connect(addr)));
},
} }
} }
} }

View File

@ -5,7 +5,7 @@ use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_util::future::{ok, Either, Ready}; use futures_util::future::{ready, Ready};
use trust_dns_resolver::TokioAsyncResolver as AsyncResolver; use trust_dns_resolver::TokioAsyncResolver as AsyncResolver;
use trust_dns_resolver::{error::ResolveError, lookup_ip::LookupIp}; use trust_dns_resolver::{error::ResolveError, lookup_ip::LookupIp};
@ -64,7 +64,7 @@ impl<T: Address> ServiceFactory for ResolverFactory<T> {
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
ok(self.service()) ready(Ok(self.service()))
} }
} }
@ -106,11 +106,7 @@ impl<T: Address> Service for Resolver<T> {
type Request = Connect<T>; type Request = Connect<T>;
type Response = Connect<T>; type Response = Connect<T>;
type Error = ConnectError; type Error = ConnectError;
#[allow(clippy::type_complexity)] type Future = ResolverServiceFuture<T>;
type Future = Either<
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>,
Ready<Result<Connect<T>, Self::Error>>,
>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
@ -118,13 +114,13 @@ impl<T: Address> Service for Resolver<T> {
fn call(&mut self, mut req: Connect<T>) -> Self::Future { fn call(&mut self, mut req: Connect<T>) -> Self::Future {
if req.addr.is_some() { if req.addr.is_some() {
Either::Right(ok(req)) ResolverServiceFuture::NoLookUp(Some(req))
} else if let Ok(ip) = req.host().parse() { } else if let Ok(ip) = req.host().parse() {
req.addr = Some(either::Either::Left(SocketAddr::new(ip, req.port()))); req.addr = Some(either::Either::Left(SocketAddr::new(ip, req.port())));
Either::Right(ok(req)) ResolverServiceFuture::NoLookUp(Some(req))
} else { } else {
let resolver = self.resolver.as_ref().map(AsyncResolver::clone); let resolver = self.resolver.as_ref().map(AsyncResolver::clone);
Either::Left(Box::pin(async move { ResolverServiceFuture::LookUp(Box::pin(async move {
trace!("DNS resolver: resolving host {:?}", req.host()); trace!("DNS resolver: resolving host {:?}", req.host());
let resolver = if let Some(resolver) = resolver { let resolver = if let Some(resolver) = resolver {
resolver resolver
@ -139,13 +135,30 @@ impl<T: Address> Service for Resolver<T> {
} }
} }
type LookupIpFuture = Pin<Box<dyn Future<Output = Result<LookupIp, ResolveError>>>>; type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
#[doc(hidden)]
pub enum ResolverServiceFuture<T: Address> {
NoLookUp(Option<Connect<T>>),
LookUp(LocalBoxFuture<'static, Result<Connect<T>, ConnectError>>),
}
impl<T: Address> Future for ResolverServiceFuture<T> {
type Output = Result<Connect<T>, ConnectError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut() {
Self::NoLookUp(conn) => Poll::Ready(Ok(conn.take().unwrap())),
Self::LookUp(fut) => fut.as_mut().poll(cx),
}
}
}
#[doc(hidden)] #[doc(hidden)]
/// Resolver future /// Resolver future
pub struct ResolverFuture<T: Address> { pub struct ResolverFuture<T: Address> {
req: Option<Connect<T>>, req: Option<Connect<T>>,
lookup: LookupIpFuture, lookup: LocalBoxFuture<'static, Result<LookupIp, ResolveError>>,
} }
impl<T: Address> ResolverFuture<T> { impl<T: Address> ResolverFuture<T> {

View File

@ -5,7 +5,7 @@ use std::task::{Context, Poll};
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use either::Either; use either::Either;
use futures_util::future::{ok, Ready}; use futures_util::future::{ready, Ready};
use trust_dns_resolver::TokioAsyncResolver as AsyncResolver; use trust_dns_resolver::TokioAsyncResolver as AsyncResolver;
use crate::connect::{Address, Connect, Connection}; use crate::connect::{Address, Connect, Connection};
@ -80,7 +80,7 @@ impl<T: Address> ServiceFactory for ConnectServiceFactory<T> {
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
ok(self.service()) ready(Ok(self.service()))
} }
} }

View File

@ -10,7 +10,8 @@ pub use tokio_openssl::SslStream;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_util::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready}; use futures_util::future::{ready, Ready};
use futures_util::ready;
use trust_dns_resolver::TokioAsyncResolver as AsyncResolver; use trust_dns_resolver::TokioAsyncResolver as AsyncResolver;
use crate::{ use crate::{
@ -68,10 +69,10 @@ where
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
ok(OpensslConnectorService { ready(Ok(OpensslConnectorService {
connector: self.connector.clone(), connector: self.connector.clone(),
_t: PhantomData, _t: PhantomData,
}) }))
} }
} }
@ -97,63 +98,80 @@ where
type Request = Connection<T, U>; type Request = Connection<T, U>;
type Response = Connection<T, SslStream<U>>; type Response = Connection<T, SslStream<U>>;
type Error = io::Error; type Error = io::Error;
#[allow(clippy::type_complexity)] type Future = OpensslConnectorServiceFuture<T, U>;
type Future = Either<ConnectAsyncExt<T, U>, Ready<Result<Self::Response, Self::Error>>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, stream: Connection<T, U>) -> Self::Future { fn call(&mut self, stream: Connection<T, U>) -> Self::Future {
trace!("SSL Handshake start for: {:?}", stream.host()); match self.ssl_stream(stream) {
let (io, stream) = stream.replace(()); Ok(acc) => OpensslConnectorServiceFuture::Accept(Some(acc)),
let host = stream.host().to_string(); Err(e) => OpensslConnectorServiceFuture::Error(Some(e)),
match self.connector.configure() {
Err(e) => Either::Right(err(io::Error::new(io::ErrorKind::Other, e))),
Ok(config) => Either::Left(ConnectAsyncExt {
// TODO: unbox this future.
fut: Box::pin(async move {
let ssl = config.into_ssl(host.as_str())?;
let mut io = tokio_openssl::SslStream::new(ssl, io)?;
Pin::new(&mut io).connect().await?;
Ok(io)
}),
stream: Some(stream),
_t: PhantomData,
}),
} }
} }
} }
pub struct ConnectAsyncExt<T, U> { impl<T, U> OpensslConnectorService<T, U>
fut: LocalBoxFuture<'static, Result<SslStream<U>, SslError>>,
stream: Option<Connection<T, ()>>,
_t: PhantomData<U>,
}
impl<T: Address, U> Future for ConnectAsyncExt<T, U>
where where
T: Address + 'static,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
{
// construct SslStream with connector.
// At this point SslStream does not perform any I/O.
// handshake would happen later in OpensslConnectorServiceFuture
fn ssl_stream(
&self,
stream: Connection<T, U>,
) -> Result<(SslStream<U>, Connection<T, ()>), SslError> {
trace!("SSL Handshake start for: {:?}", stream.host());
let (stream, connection) = stream.replace(());
let host = connection.host().to_string();
let config = self.connector.configure()?;
let ssl = config.into_ssl(host.as_str())?;
let stream = tokio_openssl::SslStream::new(ssl, stream)?;
Ok((stream, connection))
}
}
#[doc(hidden)]
pub enum OpensslConnectorServiceFuture<T, U>
where
T: Address + 'static,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
{
Accept(Option<(SslStream<U>, Connection<T, ()>)>),
Error(Option<SslError>),
}
impl<T, U> Future for OpensslConnectorServiceFuture<T, U>
where
T: Address,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
{ {
type Output = Result<Connection<T, SslStream<U>>, io::Error>; type Output = Result<Connection<T, SslStream<U>>, io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut(); let e = match self.get_mut() {
Self::Error(e) => e.take().unwrap(),
match Pin::new(&mut this.fut).poll(cx) { Self::Accept(acc) => {
Poll::Ready(Ok(stream)) => { let (stream, _) = acc.as_mut().unwrap();
let s = this.stream.take().unwrap(); match ready!(Pin::new(stream).poll_accept(cx)) {
trace!("SSL Handshake success: {:?}", s.host()); Ok(()) => {
Poll::Ready(Ok(s.replace(stream).1)) let (stream, connection) = acc.take().unwrap();
trace!("SSL Handshake success: {:?}", connection.host());
let (_, connection) = connection.replace(stream);
return Poll::Ready(Ok(connection));
} }
Poll::Ready(Err(e)) => { Err(e) => e,
}
}
};
trace!("SSL Handshake error: {:?}", e); trace!("SSL Handshake error: {:?}", e);
Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)))) Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
} }
Poll::Pending => Poll::Pending,
}
}
} }
pub struct OpensslConnectServiceFactory<T> { pub struct OpensslConnectServiceFactory<T> {
@ -209,7 +227,7 @@ impl<T: Address + 'static> ServiceFactory for OpensslConnectServiceFactory<T> {
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
ok(self.service()) ready(Ok(self.service()))
} }
} }

View File

@ -10,7 +10,8 @@ pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_util::future::{ok, Ready}; use futures_util::future::{ready, Ready};
use futures_util::ready;
use tokio_rustls::{Connect, TlsConnector}; use tokio_rustls::{Connect, TlsConnector};
use webpki::DNSNameRef; use webpki::DNSNameRef;
@ -66,10 +67,10 @@ where
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
ok(RustlsConnectorService { ready(Ok(RustlsConnectorService {
connector: self.connector.clone(), connector: self.connector.clone(),
_t: PhantomData, _t: PhantomData,
}) }))
} }
} }
@ -125,12 +126,9 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut(); let this = self.get_mut();
Poll::Ready( let stream = ready!(Pin::new(&mut this.fut).poll(cx))?;
futures_util::ready!(Pin::new(&mut this.fut).poll(cx)).map(|stream| {
let s = this.stream.take().unwrap(); let s = this.stream.take().unwrap();
trace!("SSL Handshake success: {:?}", s.host()); trace!("SSL Handshake success: {:?}", s.host());
s.replace(stream).1 Poll::Ready(Ok(s.replace(stream).1))
}),
)
} }
} }

View File

@ -18,8 +18,8 @@ path = "src/lib.rs"
[dependencies] [dependencies]
actix-macros = "0.1.0" actix-macros = "0.1.0"
copyless = "0.1.4" copyless = "0.1.4"
futures-channel = "0.3.4" futures-channel = "0.3.7"
futures-util = { version = "0.3.4", default-features = false, features = ["alloc"] } futures-util = { version = "0.3.7", default-features = false, features = ["alloc"] }
smallvec = "1" smallvec = "1"
tokio = { version = "1", features = ["rt", "net", "signal", "time"] } tokio = { version = "1", features = ["rt", "net", "signal", "time"] }

View File

@ -26,8 +26,8 @@ actix-codec = "0.3.0"
actix-utils = "2.0.0" actix-utils = "2.0.0"
concurrent-queue = "1.2.2" concurrent-queue = "1.2.2"
futures-channel = { version = "0.3.4", default-features = false } futures-channel = { version = "0.3.7", default-features = false }
futures-util = { version = "0.3.4", default-features = false } futures-util = { version = "0.3.7", default-features = false }
log = "0.4" log = "0.4"
mio = { version = "0.7.3", features = [ "os-poll", "tcp", "uds"] } mio = { version = "0.7.3", features = [ "os-poll", "tcp", "uds"] }
num_cpus = "1.13" num_cpus = "1.13"
@ -37,5 +37,5 @@ slab = "0.4"
actix-testing = "1.0.0" actix-testing = "1.0.0"
bytes = "1" bytes = "1"
env_logger = "0.7" env_logger = "0.7"
futures-util = { version = "0.3.4", default-features = false, features = ["sink"] } futures-util = { version = "0.3.7", default-features = false, features = ["sink"] }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }

View File

@ -19,7 +19,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
derive_more = "0.99.2" derive_more = "0.99.2"
futures-channel = "0.3.1" futures-channel = "0.3.7"
parking_lot = "0.11" parking_lot = "0.11"
lazy_static = "1.3" lazy_static = "1.3"
log = "0.4" log = "0.4"

View File

@ -39,7 +39,7 @@ actix-service = "1.0.0"
actix-codec = "0.3.0" actix-codec = "0.3.0"
actix-utils = "2.0.0" actix-utils = "2.0.0"
futures-util = { version = "0.3.4", default-features = false } futures-util = { version = "0.3.7", default-features = false }
# openssl # openssl
open-ssl = { package = "openssl", version = "0.10", optional = true } open-ssl = { package = "openssl", version = "0.10", optional = true }

View File

@ -1,10 +1,12 @@
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use actix_utils::counter::Counter; use actix_utils::counter::Counter;
use futures_util::future::{self, FutureExt, LocalBoxFuture, TryFutureExt}; use futures_util::future::{ready, Ready};
pub use native_tls::Error; pub use native_tls::Error;
pub use tokio_native_tls::{TlsAcceptor, TlsStream}; pub use tokio_native_tls::{TlsAcceptor, TlsStream};
@ -50,19 +52,19 @@ where
type Request = T; type Request = T;
type Response = TlsStream<T>; type Response = TlsStream<T>;
type Error = Error; type Error = Error;
type Service = NativeTlsAcceptorService<T>;
type Config = (); type Config = ();
type Service = NativeTlsAcceptorService<T>;
type InitError = (); type InitError = ();
type Future = future::Ready<Result<Self::Service, Self::InitError>>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
MAX_CONN_COUNTER.with(|conns| { MAX_CONN_COUNTER.with(|conns| {
future::ok(NativeTlsAcceptorService { ready(Ok(NativeTlsAcceptorService {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
conns: conns.clone(), conns: conns.clone(),
io: PhantomData, io: PhantomData,
}) }))
}) })
} }
} }
@ -83,6 +85,8 @@ impl<T> Clone for NativeTlsAcceptorService<T> {
} }
} }
type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
impl<T> Service for NativeTlsAcceptorService<T> impl<T> Service for NativeTlsAcceptorService<T>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
@ -103,12 +107,10 @@ where
fn call(&mut self, req: Self::Request) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {
let guard = self.conns.get(); let guard = self.conns.get();
let this = self.clone(); let this = self.clone();
async move { this.acceptor.accept(req).await } Box::pin(async move {
.map_ok(move |io| { let res = this.acceptor.accept(req).await;
// Required to preserve `CounterGuard` until `Self::Future` is completely resolved. drop(guard);
let _ = guard; res
io
}) })
.boxed_local()
} }
} }

View File

@ -84,9 +84,13 @@ impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Service for AcceptorService<T>
} }
fn call(&mut self, req: Self::Request) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {
match self.ssl_stream(req) {
Ok(stream) => {
let guard = self.conns.get(); let guard = self.conns.get();
let stream = self.ssl_stream(req); AcceptorServiceResponse::Accept(Some(stream), Some(guard))
AcceptorServiceResponse::Init(Some(stream), Some(guard)) }
Err(e) => AcceptorServiceResponse::Error(Some(e)),
}
} }
} }
@ -105,32 +109,21 @@ pub enum AcceptorServiceResponse<T>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
{ {
Init(Option<Result<SslStream<T>, Error>>, Option<CounterGuard>),
Accept(Option<SslStream<T>>, Option<CounterGuard>), Accept(Option<SslStream<T>>, Option<CounterGuard>),
Error(Option<Error>),
} }
impl<T: AsyncRead + AsyncWrite + Unpin> Future for AcceptorServiceResponse<T> { impl<T: AsyncRead + AsyncWrite + Unpin> Future for AcceptorServiceResponse<T> {
type Output = Result<SslStream<T>, Error>; type Output = Result<SslStream<T>, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop { match self.get_mut() {
match self.as_mut().get_mut() { AcceptorServiceResponse::Error(e) => Poll::Ready(Err(e.take().unwrap())),
// Init branch only used to return the error in future
// on success goes to Accept branch directly.
AcceptorServiceResponse::Init(res, guard) => {
let guard = guard.take();
let stream = res.take().unwrap()?;
let state = AcceptorServiceResponse::Accept(Some(stream), guard);
self.as_mut().set(state);
}
AcceptorServiceResponse::Accept(stream, guard) => { AcceptorServiceResponse::Accept(stream, guard) => {
ready!(Pin::new(stream.as_mut().unwrap()).poll_accept(cx))?; ready!(Pin::new(stream.as_mut().unwrap()).poll_accept(cx))?;
// drop counter guard a little early as the accept has finished // drop counter guard a little early as the accept has finished
guard.take(); guard.take();
Poll::Ready(Ok(stream.take().unwrap()))
let stream = stream.take().unwrap();
return Poll::Ready(Ok(stream));
}
} }
} }
} }

View File

@ -17,7 +17,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
actix-service = "1.0.4" actix-service = "1.0.4"
futures-util = { version = "0.3.4", default-features = false } futures-util = { version = "0.3.7", default-features = false }
tracing = "0.1" tracing = "0.1"
tracing-futures = "0.2" tracing-futures = "0.2"

View File

@ -22,9 +22,9 @@ actix-service = "1.0.6"
bitflags = "1.2.1" bitflags = "1.2.1"
bytes = "1" bytes = "1"
either = "1.5.3" either = "1.5.3"
futures-channel = { version = "0.3.4", default-features = false } futures-channel = { version = "0.3.7", default-features = false }
futures-sink = { version = "0.3.4", default-features = false } futures-sink = { version = "0.3.7", default-features = false }
futures-util = { version = "0.3.4", default-features = false } futures-util = { version = "0.3.7", default-features = false }
log = "0.4" log = "0.4"
pin-project = "1.0.0" pin-project = "1.0.0"
slab = "0.4" slab = "0.4"