mirror of https://github.com/fafhrd91/actix-net
Refactor 'nativetls' module
This commit is contained in:
parent
0be859d440
commit
e2efd927b7
|
@ -1,16 +1,18 @@
|
||||||
|
use std::convert::Infallible;
|
||||||
|
use std::future::Future;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
use actix_service::{NewService, Service};
|
use actix_service::{Service, ServiceFactory};
|
||||||
use futures::{future::ok, future::FutureResult, Async, Future, Poll};
|
use futures::future;
|
||||||
use native_tls::{self, Error, HandshakeError, TlsAcceptor};
|
use native_tls::{Error, HandshakeError, TlsAcceptor, TlsStream};
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
use tokio_io::{AsyncRead, AsyncWrite};
|
||||||
|
|
||||||
use crate::counter::{Counter, CounterGuard};
|
use crate::counter::{Counter, CounterGuard};
|
||||||
use crate::ssl::MAX_CONN_COUNTER;
|
use crate::ssl::MAX_CONN_COUNTER;
|
||||||
use crate::{Io, Protocol, ServerConfig};
|
use crate::{Io, Protocol, ServerConfig};
|
||||||
use std::pin::Pin;
|
|
||||||
use std::task::Context;
|
|
||||||
|
|
||||||
/// Support `SSL` connections via native-tls package
|
/// Support `SSL` connections via native-tls package
|
||||||
///
|
///
|
||||||
|
@ -30,7 +32,7 @@ impl<T: AsyncRead + AsyncWrite, P> NativeTlsAcceptor<T, P> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: AsyncRead + AsyncWrite, P> Clone for NativeTlsAcceptor<T, P> {
|
impl<T, P> Clone for NativeTlsAcceptor<T, P> {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
acceptor: self.acceptor.clone(),
|
acceptor: self.acceptor.clone(),
|
||||||
|
@ -39,21 +41,21 @@ impl<T: AsyncRead + AsyncWrite, P> Clone for NativeTlsAcceptor<T, P> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: AsyncRead + AsyncWrite, P> NewService for NativeTlsAcceptor<T, P> {
|
impl<T: AsyncRead + AsyncWrite, P> ServiceFactory for NativeTlsAcceptor<T, P> {
|
||||||
type Request = Io<T, P>;
|
type Request = Io<T, P>;
|
||||||
type Response = Io<TlsStream<T>, P>;
|
type Response = Io<NativeTlsStream<T>, P>;
|
||||||
type Error = Error;
|
type Error = Error;
|
||||||
|
|
||||||
type Config = ServerConfig;
|
type Config = ServerConfig;
|
||||||
type Service = NativeTlsAcceptorService<T, P>;
|
type Service = NativeTlsAcceptorService<T, P>;
|
||||||
type InitError = ();
|
type InitError = Infallible;
|
||||||
type Future = FutureResult<Self::Service, Self::InitError>;
|
type Future = future::Ready<Result<Self::Service, Self::InitError>>;
|
||||||
|
|
||||||
fn new_service(&self, cfg: &ServerConfig) -> Self::Future {
|
fn new_service(&self, cfg: &ServerConfig) -> Self::Future {
|
||||||
cfg.set_secure();
|
cfg.set_secure();
|
||||||
|
|
||||||
MAX_CONN_COUNTER.with(|conns| {
|
MAX_CONN_COUNTER.with(|conns| {
|
||||||
ok(NativeTlsAcceptorService {
|
future::ok(NativeTlsAcceptorService {
|
||||||
acceptor: self.acceptor.clone(),
|
acceptor: self.acceptor.clone(),
|
||||||
conns: conns.clone(),
|
conns: conns.clone(),
|
||||||
io: PhantomData,
|
io: PhantomData,
|
||||||
|
@ -70,31 +72,18 @@ pub struct NativeTlsAcceptorService<T, P> {
|
||||||
|
|
||||||
impl<T: AsyncRead + AsyncWrite, P> Service for NativeTlsAcceptorService<T, P> {
|
impl<T: AsyncRead + AsyncWrite, P> Service for NativeTlsAcceptorService<T, P> {
|
||||||
type Request = Io<T, P>;
|
type Request = Io<T, P>;
|
||||||
type Response = Io<TlsStream<T>, P>;
|
type Response = Io<NativeTlsStream<T>, P>;
|
||||||
type Error = Error;
|
type Error = Error;
|
||||||
type Future = Accept<T, P>;
|
type Future = Accept<T, P>;
|
||||||
|
|
||||||
fn poll_ready(
|
fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
self: Pin<&mut Self>,
|
|
||||||
ctx: &mut Context<'_>,
|
|
||||||
) -> Poll<Result<(), Self::Error>> {
|
|
||||||
if self.conns.available(ctx) {
|
if self.conns.available(ctx) {
|
||||||
Ok(Async::Ready(()))
|
Ok(Poll::Ready(Ok(())))
|
||||||
} else {
|
} else {
|
||||||
Ok(Async::NotReady)
|
Ok(Poll::Pending)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
|
|
||||||
if self.conns.available() {
|
|
||||||
Ok(Async::Ready(()))
|
|
||||||
} else {
|
|
||||||
Ok(Async::NotReady)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
fn call(&mut self, req: Self::Request) -> Self::Future {
|
fn call(&mut self, req: Self::Request) -> Self::Future {
|
||||||
let (io, params, _) = req.into_parts();
|
let (io, params, _) = req.into_parts();
|
||||||
Accept {
|
Accept {
|
||||||
|
@ -105,75 +94,74 @@ impl<T: AsyncRead + AsyncWrite, P> Service for NativeTlsAcceptorService<T, P> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
|
||||||
/// protocol.
|
|
||||||
///
|
|
||||||
/// A `TlsStream<S>` represents a handshake that has been completed successfully
|
|
||||||
/// and both the server and the client are ready for receiving and sending
|
|
||||||
/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
|
|
||||||
/// to a `TlsStream` are encrypted when passing through to `S`.
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct TlsStream<S> {
|
|
||||||
inner: native_tls::TlsStream<S>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Future returned from `NativeTlsAcceptor::accept` which will resolve
|
/// Future returned from `NativeTlsAcceptor::accept` which will resolve
|
||||||
/// once the accept handshake has finished.
|
/// once the accept handshake has finished.
|
||||||
pub struct Accept<S, P> {
|
pub struct Accept<S, P> {
|
||||||
inner: Option<Result<native_tls::TlsStream<S>, HandshakeError<S>>>,
|
inner: Option<Result<TlsStream<S>, HandshakeError<S>>>,
|
||||||
params: Option<P>,
|
params: Option<P>,
|
||||||
_guard: CounterGuard,
|
_guard: CounterGuard,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: AsyncRead + AsyncWrite, P> Future for Accept<T, P> {
|
impl<T: AsyncRead + AsyncWrite, P> Future for Accept<T, P> {
|
||||||
type Item = Io<TlsStream<T>, P>;
|
type Output = Result<Io<NativeTlsStream<T>, P>, Error>;
|
||||||
type Error = Error;
|
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
match self.inner.take().expect("cannot poll MidHandshake twice") {
|
let my = self.get_mut();
|
||||||
Ok(stream) => Ok(Async::Ready(Io::from_parts(
|
match my.inner.take().expect("cannot poll MidHandshake twice") {
|
||||||
TlsStream { inner: stream },
|
Ok(stream) => Poll::Ready(Ok(Io::from_parts(
|
||||||
self.params.take().unwrap(),
|
NativeTlsStream { inner: stream },
|
||||||
|
my.params.take().unwrap(),
|
||||||
Protocol::Unknown,
|
Protocol::Unknown,
|
||||||
))),
|
))),
|
||||||
Err(HandshakeError::Failure(e)) => Err(e),
|
Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
|
||||||
Err(HandshakeError::WouldBlock(s)) => match s.handshake() {
|
Err(HandshakeError::WouldBlock(s)) => match s.handshake() {
|
||||||
Ok(stream) => Ok(Async::Ready(Io::from_parts(
|
Ok(stream) => Poll::Ready(Ok(Io::from_parts(
|
||||||
TlsStream { inner: stream },
|
NativeTlsStream { inner: stream },
|
||||||
self.params.take().unwrap(),
|
my.params.take().unwrap(),
|
||||||
Protocol::Unknown,
|
Protocol::Unknown,
|
||||||
))),
|
))),
|
||||||
Err(HandshakeError::Failure(e)) => Err(e),
|
Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
|
||||||
Err(HandshakeError::WouldBlock(s)) => {
|
Err(HandshakeError::WouldBlock(s)) => {
|
||||||
self.inner = Some(Err(HandshakeError::WouldBlock(s)));
|
my.inner = Some(Err(HandshakeError::WouldBlock(s)));
|
||||||
Ok(Async::NotReady)
|
// TODO: should we use Waker somehow?
|
||||||
|
Poll::Pending
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> TlsStream<S> {
|
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||||
/// Get access to the internal `native_tls::TlsStream` stream which also
|
/// protocol.
|
||||||
/// transitively allows access to `S`.
|
///
|
||||||
pub fn get_ref(&self) -> &native_tls::TlsStream<S> {
|
/// A `NativeTlsStream<S>` represents a handshake that has been completed successfully
|
||||||
|
/// and both the server and the client are ready for receiving and sending
|
||||||
|
/// data. Bytes read from a `NativeTlsStream` are decrypted from `S` and bytes written
|
||||||
|
/// to a `NativeTlsStream` are encrypted when passing through to `S`.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct NativeTlsStream<S> {
|
||||||
|
inner: TlsStream<S>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> AsRef<TlsStream<S>> for NativeTlsStream<S> {
|
||||||
|
fn as_ref(&self) -> &TlsStream<S> {
|
||||||
&self.inner
|
&self.inner
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get mutable access to the internal `native_tls::TlsStream` stream which
|
impl<S> AsMut<TlsStream<S>> for NativeTlsStream<S> {
|
||||||
/// also transitively allows mutable access to `S`.
|
fn as_mut(&mut self) -> &mut TlsStream<S> {
|
||||||
pub fn get_mut(&mut self) -> &mut native_tls::TlsStream<S> {
|
|
||||||
&mut self.inner
|
&mut self.inner
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
|
impl<S: io::Read + io::Write> io::Read for NativeTlsStream<S> {
|
||||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||||
self.inner.read(buf)
|
self.inner.read(buf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
|
impl<S: io::Read + io::Write> io::Write for NativeTlsStream<S> {
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||||
self.inner.write(buf)
|
self.inner.write(buf)
|
||||||
}
|
}
|
||||||
|
@ -183,15 +171,40 @@ impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsyncRead + AsyncWrite> AsyncRead for TlsStream<S> {}
|
impl<S: AsyncRead + AsyncWrite> AsyncRead for NativeTlsStream<S> {
|
||||||
|
fn poll_read(
|
||||||
impl<S: AsyncRead + AsyncWrite> AsyncWrite for TlsStream<S> {
|
self: Pin<&mut Self>,
|
||||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
cx: &mut Context<'_>,
|
||||||
match self.inner.shutdown() {
|
buf: &mut [u8],
|
||||||
Ok(_) => (),
|
) -> Poll<io::Result<usize>> {
|
||||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (),
|
// TODO: wha?
|
||||||
Err(e) => return Err(e),
|
unimplemented!()
|
||||||
}
|
}
|
||||||
self.inner.get_mut().shutdown()
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncRead + AsyncWrite> AsyncWrite for NativeTlsStream<S> {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<Result<usize, io::Error>> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Result<(), io::Error>> {
|
||||||
|
let inner = &mut Pin::get_mut(self).inner;
|
||||||
|
match inner.shutdown() {
|
||||||
|
Ok(_) => (),
|
||||||
|
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (),
|
||||||
|
Err(e) => return Poll::Ready(Err(e)),
|
||||||
|
}
|
||||||
|
inner.get_mut().poll_shutdown()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue