diff --git a/actix-server/src/ssl/nativetls.rs b/actix-server/src/ssl/nativetls.rs index 60099bef..2c13ed5b 100644 --- a/actix-server/src/ssl/nativetls.rs +++ b/actix-server/src/ssl/nativetls.rs @@ -1,16 +1,18 @@ +use std::convert::Infallible; +use std::future::Future; use std::io; use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; -use actix_service::{NewService, Service}; -use futures::{future::ok, future::FutureResult, Async, Future, Poll}; -use native_tls::{self, Error, HandshakeError, TlsAcceptor}; +use actix_service::{Service, ServiceFactory}; +use futures::future; +use native_tls::{Error, HandshakeError, TlsAcceptor, TlsStream}; use tokio_io::{AsyncRead, AsyncWrite}; use crate::counter::{Counter, CounterGuard}; use crate::ssl::MAX_CONN_COUNTER; use crate::{Io, Protocol, ServerConfig}; -use std::pin::Pin; -use std::task::Context; /// Support `SSL` connections via native-tls package /// @@ -30,7 +32,7 @@ impl NativeTlsAcceptor { } } -impl Clone for NativeTlsAcceptor { +impl Clone for NativeTlsAcceptor { fn clone(&self) -> Self { Self { acceptor: self.acceptor.clone(), @@ -39,21 +41,21 @@ impl Clone for NativeTlsAcceptor { } } -impl NewService for NativeTlsAcceptor { +impl ServiceFactory for NativeTlsAcceptor { type Request = Io; - type Response = Io, P>; + type Response = Io, P>; type Error = Error; type Config = ServerConfig; type Service = NativeTlsAcceptorService; - type InitError = (); - type Future = FutureResult; + type InitError = Infallible; + type Future = future::Ready>; fn new_service(&self, cfg: &ServerConfig) -> Self::Future { cfg.set_secure(); MAX_CONN_COUNTER.with(|conns| { - ok(NativeTlsAcceptorService { + future::ok(NativeTlsAcceptorService { acceptor: self.acceptor.clone(), conns: conns.clone(), io: PhantomData, @@ -70,31 +72,18 @@ pub struct NativeTlsAcceptorService { impl Service for NativeTlsAcceptorService { type Request = Io; - type Response = Io, P>; + type Response = Io, P>; type Error = Error; type Future = Accept; - fn poll_ready( - self: Pin<&mut Self>, - ctx: &mut Context<'_>, - ) -> Poll> { + fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll> { if self.conns.available(ctx) { - Ok(Async::Ready(())) + Ok(Poll::Ready(Ok(()))) } else { - Ok(Async::NotReady) + Ok(Poll::Pending) } } - /* - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - if self.conns.available() { - Ok(Async::Ready(())) - } else { - Ok(Async::NotReady) - } - } - */ - fn call(&mut self, req: Self::Request) -> Self::Future { let (io, params, _) = req.into_parts(); Accept { @@ -105,75 +94,74 @@ impl Service for NativeTlsAcceptorService { } } -/// A wrapper around an underlying raw stream which implements the TLS or SSL -/// protocol. -/// -/// A `TlsStream` represents a handshake that has been completed successfully -/// and both the server and the client are ready for receiving and sending -/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written -/// to a `TlsStream` are encrypted when passing through to `S`. -#[derive(Debug)] -pub struct TlsStream { - inner: native_tls::TlsStream, -} - /// Future returned from `NativeTlsAcceptor::accept` which will resolve /// once the accept handshake has finished. pub struct Accept { - inner: Option, HandshakeError>>, + inner: Option, HandshakeError>>, params: Option

, _guard: CounterGuard, } impl Future for Accept { - type Item = Io, P>; - type Error = Error; + type Output = Result, P>, Error>; - fn poll(&mut self) -> Poll { - match self.inner.take().expect("cannot poll MidHandshake twice") { - Ok(stream) => Ok(Async::Ready(Io::from_parts( - TlsStream { inner: stream }, - self.params.take().unwrap(), + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let my = self.get_mut(); + match my.inner.take().expect("cannot poll MidHandshake twice") { + Ok(stream) => Poll::Ready(Ok(Io::from_parts( + NativeTlsStream { inner: stream }, + my.params.take().unwrap(), Protocol::Unknown, ))), - Err(HandshakeError::Failure(e)) => Err(e), + Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), Err(HandshakeError::WouldBlock(s)) => match s.handshake() { - Ok(stream) => Ok(Async::Ready(Io::from_parts( - TlsStream { inner: stream }, - self.params.take().unwrap(), + Ok(stream) => Poll::Ready(Ok(Io::from_parts( + NativeTlsStream { inner: stream }, + my.params.take().unwrap(), Protocol::Unknown, ))), - Err(HandshakeError::Failure(e)) => Err(e), + Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), Err(HandshakeError::WouldBlock(s)) => { - self.inner = Some(Err(HandshakeError::WouldBlock(s))); - Ok(Async::NotReady) + my.inner = Some(Err(HandshakeError::WouldBlock(s))); + // TODO: should we use Waker somehow? + Poll::Pending } }, } } } -impl TlsStream { - /// Get access to the internal `native_tls::TlsStream` stream which also - /// transitively allows access to `S`. - pub fn get_ref(&self) -> &native_tls::TlsStream { +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +/// +/// A `NativeTlsStream` represents a handshake that has been completed successfully +/// and both the server and the client are ready for receiving and sending +/// data. Bytes read from a `NativeTlsStream` are decrypted from `S` and bytes written +/// to a `NativeTlsStream` are encrypted when passing through to `S`. +#[derive(Debug)] +pub struct NativeTlsStream { + inner: TlsStream, +} + +impl AsRef> for NativeTlsStream { + fn as_ref(&self) -> &TlsStream { &self.inner } +} - /// Get mutable access to the internal `native_tls::TlsStream` stream which - /// also transitively allows mutable access to `S`. - pub fn get_mut(&mut self) -> &mut native_tls::TlsStream { +impl AsMut> for NativeTlsStream { + fn as_mut(&mut self) -> &mut TlsStream { &mut self.inner } } -impl io::Read for TlsStream { +impl io::Read for NativeTlsStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.inner.read(buf) } } -impl io::Write for TlsStream { +impl io::Write for NativeTlsStream { fn write(&mut self, buf: &[u8]) -> io::Result { self.inner.write(buf) } @@ -183,15 +171,40 @@ impl io::Write for TlsStream { } } -impl AsyncRead for TlsStream {} - -impl AsyncWrite for TlsStream { - fn shutdown(&mut self) -> Poll<(), io::Error> { - match self.inner.shutdown() { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - Err(e) => return Err(e), - } - self.inner.get_mut().shutdown() +impl AsyncRead for NativeTlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + // TODO: wha? + unimplemented!() + } +} + +impl AsyncWrite for NativeTlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + unimplemented!() + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + unimplemented!() + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let inner = &mut Pin::get_mut(self).inner; + match inner.shutdown() { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Err(e) => return Poll::Ready(Err(e)), + } + inner.get_mut().poll_shutdown() } }