Refactor 'nativetls' module

This commit is contained in:
tyranron 2019-11-12 19:00:37 +02:00
parent 0be859d440
commit e2efd927b7
No known key found for this signature in database
GPG Key ID: 762E144FB230A4F0
1 changed files with 87 additions and 74 deletions

View File

@ -1,16 +1,18 @@
use std::convert::Infallible;
use std::future::Future;
use std::io; use std::io;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_service::{NewService, Service}; use actix_service::{Service, ServiceFactory};
use futures::{future::ok, future::FutureResult, Async, Future, Poll}; use futures::future;
use native_tls::{self, Error, HandshakeError, TlsAcceptor}; use native_tls::{Error, HandshakeError, TlsAcceptor, TlsStream};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use crate::counter::{Counter, CounterGuard}; use crate::counter::{Counter, CounterGuard};
use crate::ssl::MAX_CONN_COUNTER; use crate::ssl::MAX_CONN_COUNTER;
use crate::{Io, Protocol, ServerConfig}; use crate::{Io, Protocol, ServerConfig};
use std::pin::Pin;
use std::task::Context;
/// Support `SSL` connections via native-tls package /// Support `SSL` connections via native-tls package
/// ///
@ -30,7 +32,7 @@ impl<T: AsyncRead + AsyncWrite, P> NativeTlsAcceptor<T, P> {
} }
} }
impl<T: AsyncRead + AsyncWrite, P> Clone for NativeTlsAcceptor<T, P> { impl<T, P> Clone for NativeTlsAcceptor<T, P> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
@ -39,21 +41,21 @@ impl<T: AsyncRead + AsyncWrite, P> Clone for NativeTlsAcceptor<T, P> {
} }
} }
impl<T: AsyncRead + AsyncWrite, P> NewService for NativeTlsAcceptor<T, P> { impl<T: AsyncRead + AsyncWrite, P> ServiceFactory for NativeTlsAcceptor<T, P> {
type Request = Io<T, P>; type Request = Io<T, P>;
type Response = Io<TlsStream<T>, P>; type Response = Io<NativeTlsStream<T>, P>;
type Error = Error; type Error = Error;
type Config = ServerConfig; type Config = ServerConfig;
type Service = NativeTlsAcceptorService<T, P>; type Service = NativeTlsAcceptorService<T, P>;
type InitError = (); type InitError = Infallible;
type Future = FutureResult<Self::Service, Self::InitError>; type Future = future::Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, cfg: &ServerConfig) -> Self::Future { fn new_service(&self, cfg: &ServerConfig) -> Self::Future {
cfg.set_secure(); cfg.set_secure();
MAX_CONN_COUNTER.with(|conns| { MAX_CONN_COUNTER.with(|conns| {
ok(NativeTlsAcceptorService { future::ok(NativeTlsAcceptorService {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
conns: conns.clone(), conns: conns.clone(),
io: PhantomData, io: PhantomData,
@ -70,31 +72,18 @@ pub struct NativeTlsAcceptorService<T, P> {
impl<T: AsyncRead + AsyncWrite, P> Service for NativeTlsAcceptorService<T, P> { impl<T: AsyncRead + AsyncWrite, P> Service for NativeTlsAcceptorService<T, P> {
type Request = Io<T, P>; type Request = Io<T, P>;
type Response = Io<TlsStream<T>, P>; type Response = Io<NativeTlsStream<T>, P>;
type Error = Error; type Error = Error;
type Future = Accept<T, P>; type Future = Accept<T, P>;
fn poll_ready( fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
if self.conns.available(ctx) { if self.conns.available(ctx) {
Ok(Async::Ready(())) Ok(Poll::Ready(Ok(())))
} else { } else {
Ok(Async::NotReady) Ok(Poll::Pending)
} }
} }
/*
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
if self.conns.available() {
Ok(Async::Ready(()))
} else {
Ok(Async::NotReady)
}
}
*/
fn call(&mut self, req: Self::Request) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {
let (io, params, _) = req.into_parts(); let (io, params, _) = req.into_parts();
Accept { Accept {
@ -105,75 +94,74 @@ impl<T: AsyncRead + AsyncWrite, P> Service for NativeTlsAcceptorService<T, P> {
} }
} }
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
///
/// A `TlsStream<S>` represents a handshake that has been completed successfully
/// and both the server and the client are ready for receiving and sending
/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
/// to a `TlsStream` are encrypted when passing through to `S`.
#[derive(Debug)]
pub struct TlsStream<S> {
inner: native_tls::TlsStream<S>,
}
/// Future returned from `NativeTlsAcceptor::accept` which will resolve /// Future returned from `NativeTlsAcceptor::accept` which will resolve
/// once the accept handshake has finished. /// once the accept handshake has finished.
pub struct Accept<S, P> { pub struct Accept<S, P> {
inner: Option<Result<native_tls::TlsStream<S>, HandshakeError<S>>>, inner: Option<Result<TlsStream<S>, HandshakeError<S>>>,
params: Option<P>, params: Option<P>,
_guard: CounterGuard, _guard: CounterGuard,
} }
impl<T: AsyncRead + AsyncWrite, P> Future for Accept<T, P> { impl<T: AsyncRead + AsyncWrite, P> Future for Accept<T, P> {
type Item = Io<TlsStream<T>, P>; type Output = Result<Io<NativeTlsStream<T>, P>, Error>;
type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.inner.take().expect("cannot poll MidHandshake twice") { let my = self.get_mut();
Ok(stream) => Ok(Async::Ready(Io::from_parts( match my.inner.take().expect("cannot poll MidHandshake twice") {
TlsStream { inner: stream }, Ok(stream) => Poll::Ready(Ok(Io::from_parts(
self.params.take().unwrap(), NativeTlsStream { inner: stream },
my.params.take().unwrap(),
Protocol::Unknown, Protocol::Unknown,
))), ))),
Err(HandshakeError::Failure(e)) => Err(e), Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
Err(HandshakeError::WouldBlock(s)) => match s.handshake() { Err(HandshakeError::WouldBlock(s)) => match s.handshake() {
Ok(stream) => Ok(Async::Ready(Io::from_parts( Ok(stream) => Poll::Ready(Ok(Io::from_parts(
TlsStream { inner: stream }, NativeTlsStream { inner: stream },
self.params.take().unwrap(), my.params.take().unwrap(),
Protocol::Unknown, Protocol::Unknown,
))), ))),
Err(HandshakeError::Failure(e)) => Err(e), Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
Err(HandshakeError::WouldBlock(s)) => { Err(HandshakeError::WouldBlock(s)) => {
self.inner = Some(Err(HandshakeError::WouldBlock(s))); my.inner = Some(Err(HandshakeError::WouldBlock(s)));
Ok(Async::NotReady) // TODO: should we use Waker somehow?
Poll::Pending
} }
}, },
} }
} }
} }
impl<S> TlsStream<S> { /// A wrapper around an underlying raw stream which implements the TLS or SSL
/// Get access to the internal `native_tls::TlsStream` stream which also /// protocol.
/// transitively allows access to `S`. ///
pub fn get_ref(&self) -> &native_tls::TlsStream<S> { /// A `NativeTlsStream<S>` represents a handshake that has been completed successfully
/// and both the server and the client are ready for receiving and sending
/// data. Bytes read from a `NativeTlsStream` are decrypted from `S` and bytes written
/// to a `NativeTlsStream` are encrypted when passing through to `S`.
#[derive(Debug)]
pub struct NativeTlsStream<S> {
inner: TlsStream<S>,
}
impl<S> AsRef<TlsStream<S>> for NativeTlsStream<S> {
fn as_ref(&self) -> &TlsStream<S> {
&self.inner &self.inner
} }
}
/// Get mutable access to the internal `native_tls::TlsStream` stream which impl<S> AsMut<TlsStream<S>> for NativeTlsStream<S> {
/// also transitively allows mutable access to `S`. fn as_mut(&mut self) -> &mut TlsStream<S> {
pub fn get_mut(&mut self) -> &mut native_tls::TlsStream<S> {
&mut self.inner &mut self.inner
} }
} }
impl<S: io::Read + io::Write> io::Read for TlsStream<S> { impl<S: io::Read + io::Write> io::Read for NativeTlsStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf) self.inner.read(buf)
} }
} }
impl<S: io::Read + io::Write> io::Write for TlsStream<S> { impl<S: io::Read + io::Write> io::Write for NativeTlsStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.write(buf) self.inner.write(buf)
} }
@ -183,15 +171,40 @@ impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
} }
} }
impl<S: AsyncRead + AsyncWrite> AsyncRead for TlsStream<S> {} impl<S: AsyncRead + AsyncWrite> AsyncRead for NativeTlsStream<S> {
fn poll_read(
impl<S: AsyncRead + AsyncWrite> AsyncWrite for TlsStream<S> { self: Pin<&mut Self>,
fn shutdown(&mut self) -> Poll<(), io::Error> { cx: &mut Context<'_>,
match self.inner.shutdown() { buf: &mut [u8],
Ok(_) => (), ) -> Poll<io::Result<usize>> {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), // TODO: wha?
Err(e) => return Err(e), unimplemented!()
} }
self.inner.get_mut().shutdown() }
impl<S: AsyncRead + AsyncWrite> AsyncWrite for NativeTlsStream<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
unimplemented!()
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
unimplemented!()
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
let inner = &mut Pin::get_mut(self).inner;
match inner.shutdown() {
Ok(_) => (),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (),
Err(e) => return Poll::Ready(Err(e)),
}
inner.get_mut().poll_shutdown()
} }
} }