mirror of https://github.com/fafhrd91/actix-net
Switch NativeTlsAcceptor to use 'tokio-tls' crate
This commit is contained in:
parent
f2502c8a6c
commit
cb052679ab
|
@ -23,8 +23,8 @@ path = "src/lib.rs"
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
|
|
||||||
# tls
|
# native-tls
|
||||||
tls = ["native-tls"]
|
tls = ["native-tls", "tokio-tls"]
|
||||||
|
|
||||||
# openssl
|
# openssl
|
||||||
ssl = ["openssl", "tokio-openssl", "actix-server-config/ssl"]
|
ssl = ["openssl", "tokio-openssl", "actix-server-config/ssl"]
|
||||||
|
@ -54,15 +54,16 @@ tokio-net = { version = "0.2.0-alpha.6", features = ["signal"] }
|
||||||
tokio-timer = "0.3.0-alpha.6"
|
tokio-timer = "0.3.0-alpha.6"
|
||||||
|
|
||||||
# unix domain sockets
|
# unix domain sockets
|
||||||
mio-uds = { version="0.6.7", optional = true }
|
mio-uds = { version = "0.6.7", optional = true }
|
||||||
#tokio-uds = { version="0.2.5", optional = true }
|
#tokio-uds = { version="0.2.5", optional = true }
|
||||||
|
|
||||||
# native-tls
|
# native-tls
|
||||||
native-tls = { version="0.2", optional = true }
|
native-tls = { version = "0.2", optional = true }
|
||||||
|
tokio-tls = { version = "0.3.0-alpha.6", optional = true }
|
||||||
|
|
||||||
# openssl
|
# openssl
|
||||||
openssl = { version="0.10", optional = true }
|
openssl = { version = "0.10", optional = true }
|
||||||
tokio-openssl = { version="0.4.0-alpha.6", optional = true }
|
tokio-openssl = { version = "0.4.0-alpha.6", optional = true }
|
||||||
|
|
||||||
# rustls
|
# rustls
|
||||||
rustls = { version = "0.16.0", optional = true }
|
rustls = { version = "0.16.0", optional = true }
|
||||||
|
|
|
@ -11,7 +11,7 @@ pub use self::openssl::OpensslAcceptor;
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
mod nativetls;
|
mod nativetls;
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
pub use self::nativetls::{NativeTlsAcceptor, TlsStream};
|
pub use self::nativetls::NativeTlsAcceptor;
|
||||||
|
|
||||||
#[cfg(feature = "rust-tls")]
|
#[cfg(feature = "rust-tls")]
|
||||||
mod rustls;
|
mod rustls;
|
||||||
|
|
|
@ -1,18 +1,19 @@
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::future::Future;
|
|
||||||
use std::io;
|
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::pin::Pin;
|
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
use actix_service::{Service, ServiceFactory};
|
use actix_service::{Service, ServiceFactory};
|
||||||
use futures::future;
|
use futures::{
|
||||||
use native_tls::{Error, HandshakeError, TlsAcceptor, TlsStream};
|
future::{self, LocalBoxFuture},
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
FutureExt as _, TryFutureExt as _,
|
||||||
|
};
|
||||||
|
use native_tls::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use tokio_tls::{TlsAcceptor, TlsStream};
|
||||||
|
|
||||||
use crate::counter::{Counter, CounterGuard};
|
use crate::counter::Counter;
|
||||||
use crate::ssl::MAX_CONN_COUNTER;
|
use crate::ssl::MAX_CONN_COUNTER;
|
||||||
use crate::{Io, Protocol, ServerConfig};
|
use crate::{Io, ServerConfig};
|
||||||
|
|
||||||
/// Support `SSL` connections via native-tls package
|
/// Support `SSL` connections via native-tls package
|
||||||
///
|
///
|
||||||
|
@ -22,8 +23,12 @@ pub struct NativeTlsAcceptor<T, P = ()> {
|
||||||
io: PhantomData<(T, P)>,
|
io: PhantomData<(T, P)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: AsyncRead + AsyncWrite, P> NativeTlsAcceptor<T, P> {
|
impl<T, P> NativeTlsAcceptor<T, P>
|
||||||
|
where
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
/// Create `NativeTlsAcceptor` instance
|
/// Create `NativeTlsAcceptor` instance
|
||||||
|
#[inline]
|
||||||
pub fn new(acceptor: TlsAcceptor) -> Self {
|
pub fn new(acceptor: TlsAcceptor) -> Self {
|
||||||
NativeTlsAcceptor {
|
NativeTlsAcceptor {
|
||||||
acceptor,
|
acceptor,
|
||||||
|
@ -33,6 +38,7 @@ impl<T: AsyncRead + AsyncWrite, P> NativeTlsAcceptor<T, P> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, P> Clone for NativeTlsAcceptor<T, P> {
|
impl<T, P> Clone for NativeTlsAcceptor<T, P> {
|
||||||
|
#[inline]
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
acceptor: self.acceptor.clone(),
|
acceptor: self.acceptor.clone(),
|
||||||
|
@ -41,9 +47,13 @@ impl<T, P> Clone for NativeTlsAcceptor<T, P> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: AsyncRead + AsyncWrite, P> ServiceFactory for NativeTlsAcceptor<T, P> {
|
impl<T, P> ServiceFactory for NativeTlsAcceptor<T, P>
|
||||||
|
where
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||||
|
P: 'static,
|
||||||
|
{
|
||||||
type Request = Io<T, P>;
|
type Request = Io<T, P>;
|
||||||
type Response = Io<NativeTlsStream<T>, P>;
|
type Response = Io<TlsStream<T>, P>;
|
||||||
type Error = Error;
|
type Error = Error;
|
||||||
|
|
||||||
type Config = ServerConfig;
|
type Config = ServerConfig;
|
||||||
|
@ -70,141 +80,46 @@ pub struct NativeTlsAcceptorService<T, P> {
|
||||||
conns: Counter,
|
conns: Counter,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: AsyncRead + AsyncWrite, P> Service for NativeTlsAcceptorService<T, P> {
|
impl<T, P> Clone for NativeTlsAcceptorService<T, P> {
|
||||||
type Request = Io<T, P>;
|
fn clone(&self) -> Self {
|
||||||
type Response = Io<NativeTlsStream<T>, P>;
|
Self {
|
||||||
type Error = Error;
|
acceptor: self.acceptor.clone(),
|
||||||
type Future = Accept<T, P>;
|
io: PhantomData,
|
||||||
|
conns: self.conns.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
impl<T, P> Service for NativeTlsAcceptorService<T, P>
|
||||||
if self.conns.available(ctx) {
|
where
|
||||||
Ok(Poll::Ready(Ok(())))
|
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||||
|
P: 'static,
|
||||||
|
{
|
||||||
|
type Request = Io<T, P>;
|
||||||
|
type Response = Io<TlsStream<T>, P>;
|
||||||
|
type Error = Error;
|
||||||
|
type Future = LocalBoxFuture<'static, Result<Io<TlsStream<T>, P>, Error>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
if self.conns.available(cx) {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
} else {
|
} else {
|
||||||
Ok(Poll::Pending)
|
Poll::Pending
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, req: Self::Request) -> Self::Future {
|
fn call(&mut self, req: Self::Request) -> Self::Future {
|
||||||
let (io, params, _) = req.into_parts();
|
let guard = self.conns.get();
|
||||||
Accept {
|
let this = self.clone();
|
||||||
_guard: self.conns.get(),
|
let (io, params, proto) = req.into_parts();
|
||||||
inner: Some(self.acceptor.accept(io)),
|
async move { this.acceptor.accept(io).await }
|
||||||
params: Some(params),
|
.map_ok(move |stream| Io::from_parts(stream, params, proto))
|
||||||
}
|
.map_ok(move |io| {
|
||||||
}
|
// Required to preserve `CounterGuard` until `Self::Future`
|
||||||
}
|
// is completely resolved.
|
||||||
|
let _ = guard;
|
||||||
/// Future returned from `NativeTlsAcceptor::accept` which will resolve
|
io
|
||||||
/// once the accept handshake has finished.
|
})
|
||||||
pub struct Accept<S, P> {
|
.boxed_local()
|
||||||
inner: Option<Result<TlsStream<S>, HandshakeError<S>>>,
|
|
||||||
params: Option<P>,
|
|
||||||
_guard: CounterGuard,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: AsyncRead + AsyncWrite, P> Future for Accept<T, P> {
|
|
||||||
type Output = Result<Io<NativeTlsStream<T>, P>, Error>;
|
|
||||||
|
|
||||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
||||||
let my = self.get_mut();
|
|
||||||
match my.inner.take().expect("cannot poll MidHandshake twice") {
|
|
||||||
Ok(stream) => Poll::Ready(Ok(Io::from_parts(
|
|
||||||
NativeTlsStream { inner: stream },
|
|
||||||
my.params.take().unwrap(),
|
|
||||||
Protocol::Unknown,
|
|
||||||
))),
|
|
||||||
Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
|
|
||||||
Err(HandshakeError::WouldBlock(s)) => match s.handshake() {
|
|
||||||
Ok(stream) => Poll::Ready(Ok(Io::from_parts(
|
|
||||||
NativeTlsStream { inner: stream },
|
|
||||||
my.params.take().unwrap(),
|
|
||||||
Protocol::Unknown,
|
|
||||||
))),
|
|
||||||
Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
|
|
||||||
Err(HandshakeError::WouldBlock(s)) => {
|
|
||||||
my.inner = Some(Err(HandshakeError::WouldBlock(s)));
|
|
||||||
// TODO: should we use Waker somehow?
|
|
||||||
Poll::Pending
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
|
||||||
/// protocol.
|
|
||||||
///
|
|
||||||
/// A `NativeTlsStream<S>` represents a handshake that has been completed successfully
|
|
||||||
/// and both the server and the client are ready for receiving and sending
|
|
||||||
/// data. Bytes read from a `NativeTlsStream` are decrypted from `S` and bytes written
|
|
||||||
/// to a `NativeTlsStream` are encrypted when passing through to `S`.
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct NativeTlsStream<S> {
|
|
||||||
inner: TlsStream<S>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> AsRef<TlsStream<S>> for NativeTlsStream<S> {
|
|
||||||
fn as_ref(&self) -> &TlsStream<S> {
|
|
||||||
&self.inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> AsMut<TlsStream<S>> for NativeTlsStream<S> {
|
|
||||||
fn as_mut(&mut self) -> &mut TlsStream<S> {
|
|
||||||
&mut self.inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: io::Read + io::Write> io::Read for NativeTlsStream<S> {
|
|
||||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
|
||||||
self.inner.read(buf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: io::Read + io::Write> io::Write for NativeTlsStream<S> {
|
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
|
||||||
self.inner.write(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
|
||||||
self.inner.flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: AsyncRead + AsyncWrite> AsyncRead for NativeTlsStream<S> {
|
|
||||||
fn poll_read(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut [u8],
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
// TODO: wha?
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: AsyncRead + AsyncWrite> AsyncWrite for NativeTlsStream<S> {
|
|
||||||
fn poll_write(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> Poll<Result<usize, io::Error>> {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_shutdown(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
) -> Poll<Result<(), io::Error>> {
|
|
||||||
let inner = &mut Pin::get_mut(self).inner;
|
|
||||||
match inner.shutdown() {
|
|
||||||
Ok(_) => (),
|
|
||||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (),
|
|
||||||
Err(e) => return Poll::Ready(Err(e)),
|
|
||||||
}
|
|
||||||
inner.get_mut().poll_shutdown()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue