Switch NativeTlsAcceptor to use 'tokio-tls' crate

This commit is contained in:
tyranron 2019-11-12 22:36:30 +02:00
parent f2502c8a6c
commit cb052679ab
No known key found for this signature in database
GPG Key ID: 762E144FB230A4F0
3 changed files with 64 additions and 148 deletions

View File

@ -23,8 +23,8 @@ path = "src/lib.rs"
[features]
default = []
# tls
tls = ["native-tls"]
# native-tls
tls = ["native-tls", "tokio-tls"]
# openssl
ssl = ["openssl", "tokio-openssl", "actix-server-config/ssl"]
@ -54,15 +54,16 @@ tokio-net = { version = "0.2.0-alpha.6", features = ["signal"] }
tokio-timer = "0.3.0-alpha.6"
# unix domain sockets
mio-uds = { version="0.6.7", optional = true }
mio-uds = { version = "0.6.7", optional = true }
#tokio-uds = { version="0.2.5", optional = true }
# native-tls
native-tls = { version="0.2", optional = true }
native-tls = { version = "0.2", optional = true }
tokio-tls = { version = "0.3.0-alpha.6", optional = true }
# openssl
openssl = { version="0.10", optional = true }
tokio-openssl = { version="0.4.0-alpha.6", optional = true }
openssl = { version = "0.10", optional = true }
tokio-openssl = { version = "0.4.0-alpha.6", optional = true }
# rustls
rustls = { version = "0.16.0", optional = true }

View File

@ -11,7 +11,7 @@ pub use self::openssl::OpensslAcceptor;
#[cfg(feature = "tls")]
mod nativetls;
#[cfg(feature = "tls")]
pub use self::nativetls::{NativeTlsAcceptor, TlsStream};
pub use self::nativetls::NativeTlsAcceptor;
#[cfg(feature = "rust-tls")]
mod rustls;

View File

@ -1,18 +1,19 @@
use std::convert::Infallible;
use std::future::Future;
use std::io;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_service::{Service, ServiceFactory};
use futures::future;
use native_tls::{Error, HandshakeError, TlsAcceptor, TlsStream};
use tokio_io::{AsyncRead, AsyncWrite};
use futures::{
future::{self, LocalBoxFuture},
FutureExt as _, TryFutureExt as _,
};
use native_tls::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tls::{TlsAcceptor, TlsStream};
use crate::counter::{Counter, CounterGuard};
use crate::counter::Counter;
use crate::ssl::MAX_CONN_COUNTER;
use crate::{Io, Protocol, ServerConfig};
use crate::{Io, ServerConfig};
/// Support `SSL` connections via native-tls package
///
@ -22,8 +23,12 @@ pub struct NativeTlsAcceptor<T, P = ()> {
io: PhantomData<(T, P)>,
}
impl<T: AsyncRead + AsyncWrite, P> NativeTlsAcceptor<T, P> {
impl<T, P> NativeTlsAcceptor<T, P>
where
T: AsyncRead + AsyncWrite + Unpin,
{
/// Create `NativeTlsAcceptor` instance
#[inline]
pub fn new(acceptor: TlsAcceptor) -> Self {
NativeTlsAcceptor {
acceptor,
@ -33,6 +38,7 @@ impl<T: AsyncRead + AsyncWrite, P> NativeTlsAcceptor<T, P> {
}
impl<T, P> Clone for NativeTlsAcceptor<T, P> {
#[inline]
fn clone(&self) -> Self {
Self {
acceptor: self.acceptor.clone(),
@ -41,9 +47,13 @@ impl<T, P> Clone for NativeTlsAcceptor<T, P> {
}
}
impl<T: AsyncRead + AsyncWrite, P> ServiceFactory for NativeTlsAcceptor<T, P> {
impl<T, P> ServiceFactory for NativeTlsAcceptor<T, P>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
P: 'static,
{
type Request = Io<T, P>;
type Response = Io<NativeTlsStream<T>, P>;
type Response = Io<TlsStream<T>, P>;
type Error = Error;
type Config = ServerConfig;
@ -70,141 +80,46 @@ pub struct NativeTlsAcceptorService<T, P> {
conns: Counter,
}
impl<T: AsyncRead + AsyncWrite, P> Service for NativeTlsAcceptorService<T, P> {
type Request = Io<T, P>;
type Response = Io<NativeTlsStream<T>, P>;
type Error = Error;
type Future = Accept<T, P>;
impl<T, P> Clone for NativeTlsAcceptorService<T, P> {
fn clone(&self) -> Self {
Self {
acceptor: self.acceptor.clone(),
io: PhantomData,
conns: self.conns.clone(),
}
}
}
fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.conns.available(ctx) {
Ok(Poll::Ready(Ok(())))
impl<T, P> Service for NativeTlsAcceptorService<T, P>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
P: 'static,
{
type Request = Io<T, P>;
type Response = Io<TlsStream<T>, P>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Io<TlsStream<T>, P>, Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.conns.available(cx) {
Poll::Ready(Ok(()))
} else {
Ok(Poll::Pending)
Poll::Pending
}
}
fn call(&mut self, req: Self::Request) -> Self::Future {
let (io, params, _) = req.into_parts();
Accept {
_guard: self.conns.get(),
inner: Some(self.acceptor.accept(io)),
params: Some(params),
}
}
}
/// Future returned from `NativeTlsAcceptor::accept` which will resolve
/// once the accept handshake has finished.
pub struct Accept<S, P> {
inner: Option<Result<TlsStream<S>, HandshakeError<S>>>,
params: Option<P>,
_guard: CounterGuard,
}
impl<T: AsyncRead + AsyncWrite, P> Future for Accept<T, P> {
type Output = Result<Io<NativeTlsStream<T>, P>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let my = self.get_mut();
match my.inner.take().expect("cannot poll MidHandshake twice") {
Ok(stream) => Poll::Ready(Ok(Io::from_parts(
NativeTlsStream { inner: stream },
my.params.take().unwrap(),
Protocol::Unknown,
))),
Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
Err(HandshakeError::WouldBlock(s)) => match s.handshake() {
Ok(stream) => Poll::Ready(Ok(Io::from_parts(
NativeTlsStream { inner: stream },
my.params.take().unwrap(),
Protocol::Unknown,
))),
Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
Err(HandshakeError::WouldBlock(s)) => {
my.inner = Some(Err(HandshakeError::WouldBlock(s)));
// TODO: should we use Waker somehow?
Poll::Pending
}
},
}
}
}
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
///
/// A `NativeTlsStream<S>` represents a handshake that has been completed successfully
/// and both the server and the client are ready for receiving and sending
/// data. Bytes read from a `NativeTlsStream` are decrypted from `S` and bytes written
/// to a `NativeTlsStream` are encrypted when passing through to `S`.
#[derive(Debug)]
pub struct NativeTlsStream<S> {
inner: TlsStream<S>,
}
impl<S> AsRef<TlsStream<S>> for NativeTlsStream<S> {
fn as_ref(&self) -> &TlsStream<S> {
&self.inner
}
}
impl<S> AsMut<TlsStream<S>> for NativeTlsStream<S> {
fn as_mut(&mut self) -> &mut TlsStream<S> {
&mut self.inner
}
}
impl<S: io::Read + io::Write> io::Read for NativeTlsStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf)
}
}
impl<S: io::Read + io::Write> io::Write for NativeTlsStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.inner.flush()
}
}
impl<S: AsyncRead + AsyncWrite> AsyncRead for NativeTlsStream<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
// TODO: wha?
unimplemented!()
}
}
impl<S: AsyncRead + AsyncWrite> AsyncWrite for NativeTlsStream<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
unimplemented!()
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
unimplemented!()
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
let inner = &mut Pin::get_mut(self).inner;
match inner.shutdown() {
Ok(_) => (),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (),
Err(e) => return Poll::Ready(Err(e)),
}
inner.get_mut().poll_shutdown()
let guard = self.conns.get();
let this = self.clone();
let (io, params, proto) = req.into_parts();
async move { this.acceptor.accept(io).await }
.map_ok(move |stream| Io::from_parts(stream, params, proto))
.map_ok(move |io| {
// Required to preserve `CounterGuard` until `Self::Future`
// is completely resolved.
let _ = guard;
io
})
.boxed_local()
}
}