mirror of https://github.com/fafhrd91/actix-net
return typed future for openssl connector.
This commit is contained in:
parent
6c27c1598c
commit
db6790a1ff
|
@ -18,8 +18,8 @@ path = "src/lib.rs"
|
|||
[dependencies]
|
||||
bitflags = "1.2.1"
|
||||
bytes = "1"
|
||||
futures-core = { version = "0.3.4", default-features = false }
|
||||
futures-sink = { version = "0.3.4", default-features = false }
|
||||
futures-core = { version = "0.3.7", default-features = false }
|
||||
futures-sink = { version = "0.3.7", default-features = false }
|
||||
log = "0.4"
|
||||
pin-project = "1.0.0"
|
||||
tokio = "1"
|
||||
|
|
|
@ -38,7 +38,7 @@ actix-rt = "1.1.1"
|
|||
|
||||
derive_more = "0.99.2"
|
||||
either = "1.5.3"
|
||||
futures-util = { version = "0.3.4", default-features = false }
|
||||
futures-util = { version = "0.3.7", default-features = false }
|
||||
# FIXME: update to 0.3
|
||||
http = { version = "0.2.2", optional = true }
|
||||
log = "0.4"
|
||||
|
@ -58,4 +58,4 @@ webpki = { version = "0.21", optional = true }
|
|||
[dev-dependencies]
|
||||
bytes = "1"
|
||||
actix-testing = "1.0.0"
|
||||
futures-util = { version = "0.3.4", default-features = false, features = ["sink"] }
|
||||
futures-util = { version = "0.3.7", default-features = false, features = ["sink"] }
|
||||
|
|
|
@ -8,7 +8,7 @@ use std::task::{Context, Poll};
|
|||
|
||||
use actix_rt::net::TcpStream;
|
||||
use actix_service::{Service, ServiceFactory};
|
||||
use futures_util::future::{err, ok, BoxFuture, Either, FutureExt, Ready};
|
||||
use futures_util::future::{ready, Ready};
|
||||
|
||||
use super::connect::{Address, Connect, Connection};
|
||||
use super::error::ConnectError;
|
||||
|
@ -50,7 +50,7 @@ impl<T: Address> ServiceFactory for TcpConnectorFactory<T> {
|
|||
type Future = Ready<Result<Self::Service, Self::InitError>>;
|
||||
|
||||
fn new_service(&self, _: ()) -> Self::Future {
|
||||
ok(self.service())
|
||||
ready(Ok(self.service()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -74,8 +74,7 @@ impl<T: Address> Service for TcpConnector<T> {
|
|||
type Request = Connect<T>;
|
||||
type Response = Connection<T, TcpStream>;
|
||||
type Error = ConnectError;
|
||||
#[allow(clippy::type_complexity)]
|
||||
type Future = Either<TcpConnectorResponse<T>, Ready<Result<Self::Response, Self::Error>>>;
|
||||
type Future = TcpConnectorResponse<T>;
|
||||
|
||||
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
|
@ -86,21 +85,26 @@ impl<T: Address> Service for TcpConnector<T> {
|
|||
let Connect { req, addr, .. } = req;
|
||||
|
||||
if let Some(addr) = addr {
|
||||
Either::Left(TcpConnectorResponse::new(req, port, addr))
|
||||
TcpConnectorResponse::new(req, port, addr)
|
||||
} else {
|
||||
error!("TCP connector: got unresolved address");
|
||||
Either::Right(err(ConnectError::Unresolved))
|
||||
TcpConnectorResponse::Error(Some(ConnectError::Unresolved))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
|
||||
|
||||
#[doc(hidden)]
|
||||
/// TCP stream connector response future
|
||||
pub struct TcpConnectorResponse<T> {
|
||||
req: Option<T>,
|
||||
port: u16,
|
||||
addrs: Option<VecDeque<SocketAddr>>,
|
||||
stream: Option<BoxFuture<'static, Result<TcpStream, io::Error>>>,
|
||||
pub enum TcpConnectorResponse<T> {
|
||||
Response {
|
||||
req: Option<T>,
|
||||
port: u16,
|
||||
addrs: Option<VecDeque<SocketAddr>>,
|
||||
stream: Option<LocalBoxFuture<'static, Result<TcpStream, io::Error>>>,
|
||||
},
|
||||
Error(Option<ConnectError>),
|
||||
}
|
||||
|
||||
impl<T: Address> TcpConnectorResponse<T> {
|
||||
|
@ -116,13 +120,13 @@ impl<T: Address> TcpConnectorResponse<T> {
|
|||
);
|
||||
|
||||
match addr {
|
||||
either::Either::Left(addr) => TcpConnectorResponse {
|
||||
either::Either::Left(addr) => TcpConnectorResponse::Response {
|
||||
req: Some(req),
|
||||
port,
|
||||
addrs: None,
|
||||
stream: Some(TcpStream::connect(addr).boxed()),
|
||||
stream: Some(Box::pin(TcpStream::connect(addr))),
|
||||
},
|
||||
either::Either::Right(addrs) => TcpConnectorResponse {
|
||||
either::Either::Right(addrs) => TcpConnectorResponse::Response {
|
||||
req: Some(req),
|
||||
port,
|
||||
addrs: Some(addrs),
|
||||
|
@ -137,36 +141,43 @@ impl<T: Address> Future for TcpConnectorResponse<T> {
|
|||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// connect
|
||||
loop {
|
||||
if let Some(new) = this.stream.as_mut() {
|
||||
match new.as_mut().poll(cx) {
|
||||
Poll::Ready(Ok(sock)) => {
|
||||
let req = this.req.take().unwrap();
|
||||
trace!(
|
||||
"TCP connector - successfully connected to connecting to {:?} - {:?}",
|
||||
req.host(), sock.peer_addr()
|
||||
);
|
||||
return Poll::Ready(Ok(Connection::new(sock, req)));
|
||||
}
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(err)) => {
|
||||
trace!(
|
||||
"TCP connector - failed to connect to connecting to {:?} port: {}",
|
||||
this.req.as_ref().unwrap().host(),
|
||||
this.port,
|
||||
);
|
||||
if this.addrs.is_none() || this.addrs.as_ref().unwrap().is_empty() {
|
||||
return Poll::Ready(Err(err.into()));
|
||||
match this {
|
||||
TcpConnectorResponse::Error(e) => Poll::Ready(Err(e.take().unwrap())),
|
||||
// connect
|
||||
TcpConnectorResponse::Response {
|
||||
req,
|
||||
port,
|
||||
addrs,
|
||||
stream,
|
||||
} => loop {
|
||||
if let Some(new) = stream.as_mut() {
|
||||
match new.as_mut().poll(cx) {
|
||||
Poll::Ready(Ok(sock)) => {
|
||||
let req = req.take().unwrap();
|
||||
trace!(
|
||||
"TCP connector - successfully connected to connecting to {:?} - {:?}",
|
||||
req.host(), sock.peer_addr()
|
||||
);
|
||||
return Poll::Ready(Ok(Connection::new(sock, req)));
|
||||
}
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(err)) => {
|
||||
trace!(
|
||||
"TCP connector - failed to connect to connecting to {:?} port: {}",
|
||||
req.as_ref().unwrap().host(),
|
||||
port,
|
||||
);
|
||||
if addrs.is_none() || addrs.as_ref().unwrap().is_empty() {
|
||||
return Poll::Ready(Err(err.into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// try to connect
|
||||
let addr = this.addrs.as_mut().unwrap().pop_front().unwrap();
|
||||
this.stream = Some(TcpStream::connect(addr).boxed());
|
||||
// try to connect
|
||||
let addr = addrs.as_mut().unwrap().pop_front().unwrap();
|
||||
*stream = Some(Box::pin(TcpStream::connect(addr)));
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::pin::Pin;
|
|||
use std::task::{Context, Poll};
|
||||
|
||||
use actix_service::{Service, ServiceFactory};
|
||||
use futures_util::future::{ok, Either, Ready};
|
||||
use futures_util::future::{ready, Ready};
|
||||
use trust_dns_resolver::TokioAsyncResolver as AsyncResolver;
|
||||
use trust_dns_resolver::{error::ResolveError, lookup_ip::LookupIp};
|
||||
|
||||
|
@ -64,7 +64,7 @@ impl<T: Address> ServiceFactory for ResolverFactory<T> {
|
|||
type Future = Ready<Result<Self::Service, Self::InitError>>;
|
||||
|
||||
fn new_service(&self, _: ()) -> Self::Future {
|
||||
ok(self.service())
|
||||
ready(Ok(self.service()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -106,11 +106,7 @@ impl<T: Address> Service for Resolver<T> {
|
|||
type Request = Connect<T>;
|
||||
type Response = Connect<T>;
|
||||
type Error = ConnectError;
|
||||
#[allow(clippy::type_complexity)]
|
||||
type Future = Either<
|
||||
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>,
|
||||
Ready<Result<Connect<T>, Self::Error>>,
|
||||
>;
|
||||
type Future = ResolverServiceFuture<T>;
|
||||
|
||||
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
|
@ -118,13 +114,13 @@ impl<T: Address> Service for Resolver<T> {
|
|||
|
||||
fn call(&mut self, mut req: Connect<T>) -> Self::Future {
|
||||
if req.addr.is_some() {
|
||||
Either::Right(ok(req))
|
||||
ResolverServiceFuture::NoLookUp(Some(req))
|
||||
} else if let Ok(ip) = req.host().parse() {
|
||||
req.addr = Some(either::Either::Left(SocketAddr::new(ip, req.port())));
|
||||
Either::Right(ok(req))
|
||||
ResolverServiceFuture::NoLookUp(Some(req))
|
||||
} else {
|
||||
let resolver = self.resolver.as_ref().map(AsyncResolver::clone);
|
||||
Either::Left(Box::pin(async move {
|
||||
ResolverServiceFuture::LookUp(Box::pin(async move {
|
||||
trace!("DNS resolver: resolving host {:?}", req.host());
|
||||
let resolver = if let Some(resolver) = resolver {
|
||||
resolver
|
||||
|
@ -139,13 +135,30 @@ impl<T: Address> Service for Resolver<T> {
|
|||
}
|
||||
}
|
||||
|
||||
type LookupIpFuture = Pin<Box<dyn Future<Output = Result<LookupIp, ResolveError>>>>;
|
||||
type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
|
||||
|
||||
#[doc(hidden)]
|
||||
pub enum ResolverServiceFuture<T: Address> {
|
||||
NoLookUp(Option<Connect<T>>),
|
||||
LookUp(LocalBoxFuture<'static, Result<Connect<T>, ConnectError>>),
|
||||
}
|
||||
|
||||
impl<T: Address> Future for ResolverServiceFuture<T> {
|
||||
type Output = Result<Connect<T>, ConnectError>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.get_mut() {
|
||||
Self::NoLookUp(conn) => Poll::Ready(Ok(conn.take().unwrap())),
|
||||
Self::LookUp(fut) => fut.as_mut().poll(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
/// Resolver future
|
||||
pub struct ResolverFuture<T: Address> {
|
||||
req: Option<Connect<T>>,
|
||||
lookup: LookupIpFuture,
|
||||
lookup: LocalBoxFuture<'static, Result<LookupIp, ResolveError>>,
|
||||
}
|
||||
|
||||
impl<T: Address> ResolverFuture<T> {
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::task::{Context, Poll};
|
|||
use actix_rt::net::TcpStream;
|
||||
use actix_service::{Service, ServiceFactory};
|
||||
use either::Either;
|
||||
use futures_util::future::{ok, Ready};
|
||||
use futures_util::future::{ready, Ready};
|
||||
use trust_dns_resolver::TokioAsyncResolver as AsyncResolver;
|
||||
|
||||
use crate::connect::{Address, Connect, Connection};
|
||||
|
@ -80,7 +80,7 @@ impl<T: Address> ServiceFactory for ConnectServiceFactory<T> {
|
|||
type Future = Ready<Result<Self::Service, Self::InitError>>;
|
||||
|
||||
fn new_service(&self, _: ()) -> Self::Future {
|
||||
ok(self.service())
|
||||
ready(Ok(self.service()))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -10,7 +10,8 @@ pub use tokio_openssl::SslStream;
|
|||
use actix_codec::{AsyncRead, AsyncWrite};
|
||||
use actix_rt::net::TcpStream;
|
||||
use actix_service::{Service, ServiceFactory};
|
||||
use futures_util::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready};
|
||||
use futures_util::future::{ready, Ready};
|
||||
use futures_util::ready;
|
||||
use trust_dns_resolver::TokioAsyncResolver as AsyncResolver;
|
||||
|
||||
use crate::{
|
||||
|
@ -68,10 +69,10 @@ where
|
|||
type Future = Ready<Result<Self::Service, Self::InitError>>;
|
||||
|
||||
fn new_service(&self, _: ()) -> Self::Future {
|
||||
ok(OpensslConnectorService {
|
||||
ready(Ok(OpensslConnectorService {
|
||||
connector: self.connector.clone(),
|
||||
_t: PhantomData,
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -97,62 +98,79 @@ where
|
|||
type Request = Connection<T, U>;
|
||||
type Response = Connection<T, SslStream<U>>;
|
||||
type Error = io::Error;
|
||||
#[allow(clippy::type_complexity)]
|
||||
type Future = Either<ConnectAsyncExt<T, U>, Ready<Result<Self::Response, Self::Error>>>;
|
||||
type Future = OpensslConnectorServiceFuture<T, U>;
|
||||
|
||||
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, stream: Connection<T, U>) -> Self::Future {
|
||||
trace!("SSL Handshake start for: {:?}", stream.host());
|
||||
let (io, stream) = stream.replace(());
|
||||
let host = stream.host().to_string();
|
||||
|
||||
match self.connector.configure() {
|
||||
Err(e) => Either::Right(err(io::Error::new(io::ErrorKind::Other, e))),
|
||||
Ok(config) => Either::Left(ConnectAsyncExt {
|
||||
// TODO: unbox this future.
|
||||
fut: Box::pin(async move {
|
||||
let ssl = config.into_ssl(host.as_str())?;
|
||||
let mut io = tokio_openssl::SslStream::new(ssl, io)?;
|
||||
Pin::new(&mut io).connect().await?;
|
||||
Ok(io)
|
||||
}),
|
||||
stream: Some(stream),
|
||||
_t: PhantomData,
|
||||
}),
|
||||
match self.ssl_stream(stream) {
|
||||
Ok(acc) => OpensslConnectorServiceFuture::Accept(Some(acc)),
|
||||
Err(e) => OpensslConnectorServiceFuture::Error(Some(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConnectAsyncExt<T, U> {
|
||||
fut: LocalBoxFuture<'static, Result<SslStream<U>, SslError>>,
|
||||
stream: Option<Connection<T, ()>>,
|
||||
_t: PhantomData<U>,
|
||||
impl<T, U> OpensslConnectorService<T, U>
|
||||
where
|
||||
T: Address + 'static,
|
||||
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
|
||||
{
|
||||
// construct SslStream with connector.
|
||||
// At this point SslStream does not perform any I/O.
|
||||
// handshake would happen later in OpensslConnectorServiceFuture
|
||||
fn ssl_stream(
|
||||
&self,
|
||||
stream: Connection<T, U>,
|
||||
) -> Result<(SslStream<U>, Connection<T, ()>), SslError> {
|
||||
trace!("SSL Handshake start for: {:?}", stream.host());
|
||||
let (stream, connection) = stream.replace(());
|
||||
let host = connection.host().to_string();
|
||||
|
||||
let config = self.connector.configure()?;
|
||||
let ssl = config.into_ssl(host.as_str())?;
|
||||
let stream = tokio_openssl::SslStream::new(ssl, stream)?;
|
||||
Ok((stream, connection))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address, U> Future for ConnectAsyncExt<T, U>
|
||||
#[doc(hidden)]
|
||||
pub enum OpensslConnectorServiceFuture<T, U>
|
||||
where
|
||||
T: Address + 'static,
|
||||
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
|
||||
{
|
||||
Accept(Option<(SslStream<U>, Connection<T, ()>)>),
|
||||
Error(Option<SslError>),
|
||||
}
|
||||
|
||||
impl<T, U> Future for OpensslConnectorServiceFuture<T, U>
|
||||
where
|
||||
T: Address,
|
||||
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
|
||||
{
|
||||
type Output = Result<Connection<T, SslStream<U>>, io::Error>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
let e = match self.get_mut() {
|
||||
Self::Error(e) => e.take().unwrap(),
|
||||
Self::Accept(acc) => {
|
||||
let (stream, _) = acc.as_mut().unwrap();
|
||||
match ready!(Pin::new(stream).poll_accept(cx)) {
|
||||
Ok(()) => {
|
||||
let (stream, connection) = acc.take().unwrap();
|
||||
trace!("SSL Handshake success: {:?}", connection.host());
|
||||
let (_, connection) = connection.replace(stream);
|
||||
return Poll::Ready(Ok(connection));
|
||||
}
|
||||
Err(e) => e,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match Pin::new(&mut this.fut).poll(cx) {
|
||||
Poll::Ready(Ok(stream)) => {
|
||||
let s = this.stream.take().unwrap();
|
||||
trace!("SSL Handshake success: {:?}", s.host());
|
||||
Poll::Ready(Ok(s.replace(stream).1))
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
trace!("SSL Handshake error: {:?}", e);
|
||||
Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
trace!("SSL Handshake error: {:?}", e);
|
||||
Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -209,7 +227,7 @@ impl<T: Address + 'static> ServiceFactory for OpensslConnectServiceFactory<T> {
|
|||
type Future = Ready<Result<Self::Service, Self::InitError>>;
|
||||
|
||||
fn new_service(&self, _: ()) -> Self::Future {
|
||||
ok(self.service())
|
||||
ready(Ok(self.service()))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -10,7 +10,8 @@ pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig};
|
|||
|
||||
use actix_codec::{AsyncRead, AsyncWrite};
|
||||
use actix_service::{Service, ServiceFactory};
|
||||
use futures_util::future::{ok, Ready};
|
||||
use futures_util::future::{ready, Ready};
|
||||
use futures_util::ready;
|
||||
use tokio_rustls::{Connect, TlsConnector};
|
||||
use webpki::DNSNameRef;
|
||||
|
||||
|
@ -66,10 +67,10 @@ where
|
|||
type Future = Ready<Result<Self::Service, Self::InitError>>;
|
||||
|
||||
fn new_service(&self, _: ()) -> Self::Future {
|
||||
ok(RustlsConnectorService {
|
||||
ready(Ok(RustlsConnectorService {
|
||||
connector: self.connector.clone(),
|
||||
_t: PhantomData,
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -125,12 +126,9 @@ where
|
|||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
Poll::Ready(
|
||||
futures_util::ready!(Pin::new(&mut this.fut).poll(cx)).map(|stream| {
|
||||
let s = this.stream.take().unwrap();
|
||||
trace!("SSL Handshake success: {:?}", s.host());
|
||||
s.replace(stream).1
|
||||
}),
|
||||
)
|
||||
let stream = ready!(Pin::new(&mut this.fut).poll(cx))?;
|
||||
let s = this.stream.take().unwrap();
|
||||
trace!("SSL Handshake success: {:?}", s.host());
|
||||
Poll::Ready(Ok(s.replace(stream).1))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,8 +18,8 @@ path = "src/lib.rs"
|
|||
[dependencies]
|
||||
actix-macros = "0.1.0"
|
||||
copyless = "0.1.4"
|
||||
futures-channel = "0.3.4"
|
||||
futures-util = { version = "0.3.4", default-features = false, features = ["alloc"] }
|
||||
futures-channel = "0.3.7"
|
||||
futures-util = { version = "0.3.7", default-features = false, features = ["alloc"] }
|
||||
smallvec = "1"
|
||||
tokio = { version = "1", features = ["rt", "net", "signal", "time"] }
|
||||
|
||||
|
|
|
@ -26,8 +26,8 @@ actix-codec = "0.3.0"
|
|||
actix-utils = "2.0.0"
|
||||
|
||||
concurrent-queue = "1.2.2"
|
||||
futures-channel = { version = "0.3.4", default-features = false }
|
||||
futures-util = { version = "0.3.4", default-features = false }
|
||||
futures-channel = { version = "0.3.7", default-features = false }
|
||||
futures-util = { version = "0.3.7", default-features = false }
|
||||
log = "0.4"
|
||||
mio = { version = "0.7.3", features = [ "os-poll", "tcp", "uds"] }
|
||||
num_cpus = "1.13"
|
||||
|
@ -37,5 +37,5 @@ slab = "0.4"
|
|||
actix-testing = "1.0.0"
|
||||
bytes = "1"
|
||||
env_logger = "0.7"
|
||||
futures-util = { version = "0.3.4", default-features = false, features = ["sink"] }
|
||||
futures-util = { version = "0.3.7", default-features = false, features = ["sink"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
|
|
@ -19,7 +19,7 @@ path = "src/lib.rs"
|
|||
|
||||
[dependencies]
|
||||
derive_more = "0.99.2"
|
||||
futures-channel = "0.3.1"
|
||||
futures-channel = "0.3.7"
|
||||
parking_lot = "0.11"
|
||||
lazy_static = "1.3"
|
||||
log = "0.4"
|
||||
|
|
|
@ -39,7 +39,7 @@ actix-service = "1.0.0"
|
|||
actix-codec = "0.3.0"
|
||||
actix-utils = "2.0.0"
|
||||
|
||||
futures-util = { version = "0.3.4", default-features = false }
|
||||
futures-util = { version = "0.3.7", default-features = false }
|
||||
|
||||
# openssl
|
||||
open-ssl = { package = "openssl", version = "0.10", optional = true }
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
use std::future::Future;
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use actix_codec::{AsyncRead, AsyncWrite};
|
||||
use actix_service::{Service, ServiceFactory};
|
||||
use actix_utils::counter::Counter;
|
||||
use futures_util::future::{self, FutureExt, LocalBoxFuture, TryFutureExt};
|
||||
use futures_util::future::{ready, Ready};
|
||||
|
||||
pub use native_tls::Error;
|
||||
pub use tokio_native_tls::{TlsAcceptor, TlsStream};
|
||||
|
@ -50,19 +52,19 @@ where
|
|||
type Request = T;
|
||||
type Response = TlsStream<T>;
|
||||
type Error = Error;
|
||||
type Service = NativeTlsAcceptorService<T>;
|
||||
|
||||
type Config = ();
|
||||
|
||||
type Service = NativeTlsAcceptorService<T>;
|
||||
type InitError = ();
|
||||
type Future = future::Ready<Result<Self::Service, Self::InitError>>;
|
||||
type Future = Ready<Result<Self::Service, Self::InitError>>;
|
||||
|
||||
fn new_service(&self, _: ()) -> Self::Future {
|
||||
MAX_CONN_COUNTER.with(|conns| {
|
||||
future::ok(NativeTlsAcceptorService {
|
||||
ready(Ok(NativeTlsAcceptorService {
|
||||
acceptor: self.acceptor.clone(),
|
||||
conns: conns.clone(),
|
||||
io: PhantomData,
|
||||
})
|
||||
}))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -83,6 +85,8 @@ impl<T> Clone for NativeTlsAcceptorService<T> {
|
|||
}
|
||||
}
|
||||
|
||||
type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
|
||||
|
||||
impl<T> Service for NativeTlsAcceptorService<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
|
@ -103,12 +107,10 @@ where
|
|||
fn call(&mut self, req: Self::Request) -> Self::Future {
|
||||
let guard = self.conns.get();
|
||||
let this = self.clone();
|
||||
async move { this.acceptor.accept(req).await }
|
||||
.map_ok(move |io| {
|
||||
// Required to preserve `CounterGuard` until `Self::Future` is completely resolved.
|
||||
let _ = guard;
|
||||
io
|
||||
})
|
||||
.boxed_local()
|
||||
Box::pin(async move {
|
||||
let res = this.acceptor.accept(req).await;
|
||||
drop(guard);
|
||||
res
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -84,9 +84,13 @@ impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Service for AcceptorService<T>
|
|||
}
|
||||
|
||||
fn call(&mut self, req: Self::Request) -> Self::Future {
|
||||
let guard = self.conns.get();
|
||||
let stream = self.ssl_stream(req);
|
||||
AcceptorServiceResponse::Init(Some(stream), Some(guard))
|
||||
match self.ssl_stream(req) {
|
||||
Ok(stream) => {
|
||||
let guard = self.conns.get();
|
||||
AcceptorServiceResponse::Accept(Some(stream), Some(guard))
|
||||
}
|
||||
Err(e) => AcceptorServiceResponse::Error(Some(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -105,32 +109,21 @@ pub enum AcceptorServiceResponse<T>
|
|||
where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
{
|
||||
Init(Option<Result<SslStream<T>, Error>>, Option<CounterGuard>),
|
||||
Accept(Option<SslStream<T>>, Option<CounterGuard>),
|
||||
Error(Option<Error>),
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin> Future for AcceptorServiceResponse<T> {
|
||||
type Output = Result<SslStream<T>, Error>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
loop {
|
||||
match self.as_mut().get_mut() {
|
||||
// Init branch only used to return the error in future
|
||||
// on success goes to Accept branch directly.
|
||||
AcceptorServiceResponse::Init(res, guard) => {
|
||||
let guard = guard.take();
|
||||
let stream = res.take().unwrap()?;
|
||||
let state = AcceptorServiceResponse::Accept(Some(stream), guard);
|
||||
self.as_mut().set(state);
|
||||
}
|
||||
AcceptorServiceResponse::Accept(stream, guard) => {
|
||||
ready!(Pin::new(stream.as_mut().unwrap()).poll_accept(cx))?;
|
||||
// drop counter guard a little early as the accept has finished
|
||||
guard.take();
|
||||
|
||||
let stream = stream.take().unwrap();
|
||||
return Poll::Ready(Ok(stream));
|
||||
}
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.get_mut() {
|
||||
AcceptorServiceResponse::Error(e) => Poll::Ready(Err(e.take().unwrap())),
|
||||
AcceptorServiceResponse::Accept(stream, guard) => {
|
||||
ready!(Pin::new(stream.as_mut().unwrap()).poll_accept(cx))?;
|
||||
// drop counter guard a little early as the accept has finished
|
||||
guard.take();
|
||||
Poll::Ready(Ok(stream.take().unwrap()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ path = "src/lib.rs"
|
|||
|
||||
[dependencies]
|
||||
actix-service = "1.0.4"
|
||||
futures-util = { version = "0.3.4", default-features = false }
|
||||
futures-util = { version = "0.3.7", default-features = false }
|
||||
tracing = "0.1"
|
||||
tracing-futures = "0.2"
|
||||
|
||||
|
|
|
@ -22,9 +22,9 @@ actix-service = "1.0.6"
|
|||
bitflags = "1.2.1"
|
||||
bytes = "1"
|
||||
either = "1.5.3"
|
||||
futures-channel = { version = "0.3.4", default-features = false }
|
||||
futures-sink = { version = "0.3.4", default-features = false }
|
||||
futures-util = { version = "0.3.4", default-features = false }
|
||||
futures-channel = { version = "0.3.7", default-features = false }
|
||||
futures-sink = { version = "0.3.7", default-features = false }
|
||||
futures-util = { version = "0.3.7", default-features = false }
|
||||
log = "0.4"
|
||||
pin-project = "1.0.0"
|
||||
slab = "0.4"
|
||||
|
|
Loading…
Reference in New Issue