merge master. revert to single thread tokio runtime

This commit is contained in:
fakeshadow 2021-03-28 21:06:35 +08:00
commit 899d22a98e
19 changed files with 322 additions and 295 deletions

View File

@ -170,13 +170,11 @@ impl Quoter {
idx += 1; idx += 1;
} }
if let Some(data) = cloned { cloned.map(|data| {
// Unsafe: we get data from http::Uri, which does utf-8 checks already // SAFETY: we get data from http::Uri, which does UTF-8 checks already
// this code only decodes valid pct encoded values // this code only decodes valid pct encoded values
Some(unsafe { String::from_utf8_unchecked(data) }) unsafe { String::from_utf8_unchecked(data) }
} else { })
None
}
} }
} }

View File

@ -0,0 +1,60 @@
//! An example on how to build a multi-thread tokio runtime for Actix System.
//! Then spawn async task that can make use of work stealing of tokio runtime.
use actix_rt::System;
fn main() {
System::with_tokio_rt(|| {
// build system with a multi-thread tokio runtime.
tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()
.unwrap()
})
.block_on(async_main());
}
// async main function that acts like #[actix_web::main] or #[tokio::main]
async fn async_main() {
let (tx, rx) = tokio::sync::oneshot::channel();
// get a handle to system arbiter and spawn async task on it
System::current().arbiter().spawn(async {
// use tokio::spawn to get inside the context of multi thread tokio runtime
let h1 = tokio::spawn(async {
println!("thread id is {:?}", std::thread::current().id());
std::thread::sleep(std::time::Duration::from_secs(2));
});
// work stealing occurs for this task spawn
let h2 = tokio::spawn(async {
println!("thread id is {:?}", std::thread::current().id());
});
h1.await.unwrap();
h2.await.unwrap();
let _ = tx.send(());
});
rx.await.unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let now = std::time::Instant::now();
// without additional tokio::spawn, all spawned tasks run on single thread
System::current().arbiter().spawn(async {
println!("thread id is {:?}", std::thread::current().id());
std::thread::sleep(std::time::Duration::from_secs(2));
let _ = tx.send(());
});
// previous spawn task has blocked the system arbiter thread
// so this task will wait for 2 seconds until it can be run
System::current().arbiter().spawn(async move {
println!("thread id is {:?}", std::thread::current().id());
assert!(now.elapsed() > std::time::Duration::from_secs(2));
});
rx.await.unwrap();
}

View File

@ -87,7 +87,7 @@ pub mod net {
pub use tokio::net::{UnixDatagram, UnixListener, UnixStream}; pub use tokio::net::{UnixDatagram, UnixListener, UnixStream};
/// Extension trait over async read+write types that can also signal readiness. /// Extension trait over async read+write types that can also signal readiness.
pub trait ActixStream: AsyncRead + AsyncWrite + Unpin + 'static { pub trait ActixStream: AsyncRead + AsyncWrite + Unpin {
/// Poll stream and check read readiness of Self. /// Poll stream and check read readiness of Self.
/// ///
/// See [tokio::net::TcpStream::poll_read_ready] for detail on intended use. /// See [tokio::net::TcpStream::poll_read_ready] for detail on intended use.
@ -127,6 +127,16 @@ pub mod net {
ready.poll(cx) ready.poll(cx)
} }
} }
impl<Io: ActixStream + ?Sized> ActixStream for Box<Io> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
(**self).poll_read_ready(cx)
}
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
(**self).poll_write_ready(cx)
}
}
} }
pub mod time { pub mod time {

View File

@ -1,6 +1,9 @@
# Changes # Changes
## Unreleased - 2021-xx-xx ## Unreleased - 2021-xx-xx
* Prevent panic when shutdown_timeout is very large. [f9262db]
[f9262db]: https://github.com/actix/actix-net/commit/f9262db
## 2.0.0-beta.3 - 2021-02-06 ## 2.0.0-beta.3 - 2021-02-06

View File

@ -14,12 +14,15 @@ use crate::Token;
const DUR_ON_ERR: Duration = Duration::from_millis(500); const DUR_ON_ERR: Duration = Duration::from_millis(500);
struct ServerSocketInfo { struct ServerSocketInfo {
// addr for socket. mainly used for logging. /// Address of socket. Mainly used for logging.
addr: SocketAddr, addr: SocketAddr,
// be ware this is the crate token for identify socket and should not be confused with
// mio::Token /// Beware this is the crate token for identify socket and should not be confused
/// with `mio::Token`.
token: Token, token: Token,
lst: MioListener, lst: MioListener,
// mark the deadline when this socket's listener should be registered again // mark the deadline when this socket's listener should be registered again
timeout_deadline: Option<Instant>, timeout_deadline: Option<Instant>,
} }
@ -192,10 +195,9 @@ impl Accept {
Some(WakerInterest::Stop) => { Some(WakerInterest::Stop) => {
return self.deregister_all(&mut sockets); return self.deregister_all(&mut sockets);
} }
// waker queue is drained. // waker queue is drained
None => { None => {
// Reset the WakerQueue before break so it does not grow // Reset the WakerQueue before break so it does not grow infinitely
// infinitely.
WakerQueue::reset(&mut guard); WakerQueue::reset(&mut guard);
break 'waker; break 'waker;
} }
@ -316,8 +318,8 @@ impl Accept {
} }
Err(tmp) => { Err(tmp) => {
// worker lost contact and could be gone. a message is sent to // worker lost contact and could be gone. a message is sent to
// `ServerBuilder` future to notify it a new worker should be made. // `ServerBuilder` future to notify it a new worker should be made
// after that remove the fault worker. // after that remove the fault worker
self.srv.worker_faulted(self.handles[self.next].idx); self.srv.worker_faulted(self.handles[self.next].idx);
msg = tmp; msg = tmp;
self.handles.swap_remove(self.next); self.handles.swap_remove(self.next);

View File

@ -5,7 +5,7 @@ use std::time::Duration;
use std::{io, mem}; use std::{io, mem};
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_rt::time::{sleep_until, Instant}; use actix_rt::time::sleep;
use actix_rt::System; use actix_rt::System;
use futures_core::future::BoxFuture; use futures_core::future::BoxFuture;
use log::{error, info}; use log::{error, info};
@ -115,13 +115,13 @@ impl ServerBuilder {
self self
} }
/// Stop actix system. /// Stop Actix system.
pub fn system_exit(mut self) -> Self { pub fn system_exit(mut self) -> Self {
self.exit = true; self.exit = true;
self self
} }
/// Disable signal handling /// Disable signal handling.
pub fn disable_signals(mut self) -> Self { pub fn disable_signals(mut self) -> Self {
self.no_signals = true; self.no_signals = true;
self self
@ -129,9 +129,8 @@ impl ServerBuilder {
/// Timeout for graceful workers shutdown in seconds. /// Timeout for graceful workers shutdown in seconds.
/// ///
/// After receiving a stop signal, workers have this much time to finish /// After receiving a stop signal, workers have this much time to finish serving requests.
/// serving requests. Workers still alive after the timeout are force /// Workers still alive after the timeout are force dropped.
/// dropped.
/// ///
/// By default shutdown timeout sets to 30 seconds. /// By default shutdown timeout sets to 30 seconds.
pub fn shutdown_timeout(mut self, sec: u64) -> Self { pub fn shutdown_timeout(mut self, sec: u64) -> Self {
@ -140,11 +139,10 @@ impl ServerBuilder {
self self
} }
/// Execute external configuration as part of the server building /// Execute external configuration as part of the server building process.
/// process.
/// ///
/// This function is useful for moving parts of configuration to a /// This function is useful for moving parts of configuration to a different module or
/// different module or even library. /// even library.
pub fn configure<F>(mut self, f: F) -> io::Result<ServerBuilder> pub fn configure<F>(mut self, f: F) -> io::Result<ServerBuilder>
where where
F: Fn(&mut ServiceConfig) -> io::Result<()>, F: Fn(&mut ServiceConfig) -> io::Result<()>,
@ -261,6 +259,7 @@ impl ServerBuilder {
self.sockets self.sockets
.push((token, name.as_ref().to_string(), MioListener::from(lst))); .push((token, name.as_ref().to_string(), MioListener::from(lst)));
Ok(self) Ok(self)
} }
@ -430,8 +429,8 @@ impl ServerFuture {
let _ = tx.send(()); let _ = tx.send(());
} }
if exit { if exit {
sleep_until(Instant::now() + Duration::from_millis(300)).await; sleep(Duration::from_millis(300)).await;
System::current().stop(); System::try_current().as_ref().map(System::stop);
} }
})) }))
} else { } else {
@ -440,8 +439,8 @@ impl ServerFuture {
// TODO: this async block can return io::Error. // TODO: this async block can return io::Error.
Some(Box::pin(async move { Some(Box::pin(async move {
if exit { if exit {
sleep_until(Instant::now() + Duration::from_millis(300)).await; sleep(Duration::from_millis(300)).await;
System::current().stop(); System::try_current().as_ref().map(System::stop);
} }
if let Some(tx) = completion { if let Some(tx) = completion {
let _ = tx.send(()); let _ = tx.send(());

View File

@ -103,10 +103,10 @@ impl TestServer {
let port = addr.port(); let port = addr.port();
TestServerRuntime { TestServerRuntime {
system,
addr, addr,
host, host,
port, port,
system,
} }
} }

View File

@ -5,7 +5,10 @@ use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use actix_rt::time::{sleep_until, Instant, Sleep}; use actix_rt::{
time::{sleep, Sleep},
System,
};
use actix_utils::counter::Counter; use actix_utils::counter::Counter;
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use log::{error, info, trace}; use log::{error, info, trace};
@ -199,12 +202,26 @@ impl ServerWorker {
availability.set(false); availability.set(false);
let handle = tokio::runtime::Handle::current(); // Try to get actix system when have one.
let system = System::try_current();
// every worker runs in it's own arbiter. // every worker runs in it's own thread.
// use a custom tokio runtime builder to change the settings of runtime. // use a custom tokio runtime builder to change the settings of runtime.
std::thread::spawn(move || { std::thread::spawn(move || {
handle.block_on(tokio::task::LocalSet::new().run_until(async move { // conditionally setup actix system.
if let Some(system) = system {
System::set_current(system);
}
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.max_blocking_threads(config.max_blocking_threads)
.build()
.unwrap();
let local = tokio::task::LocalSet::new();
rt.block_on(local.run_until(async move {
let mut wrk = MAX_CONNS_COUNTER.with(move |conns| ServerWorker { let mut wrk = MAX_CONNS_COUNTER.with(move |conns| ServerWorker {
rx, rx,
rx2, rx2,
@ -355,8 +372,8 @@ impl Future for ServerWorker {
if num != 0 { if num != 0 {
info!("Graceful worker shutdown, {} connections", num); info!("Graceful worker shutdown, {} connections", num);
self.state = WorkerState::Shutdown( self.state = WorkerState::Shutdown(
Box::pin(sleep_until(Instant::now() + Duration::from_secs(1))), Box::pin(sleep(Duration::from_secs(1))),
Box::pin(sleep_until(Instant::now() + self.config.shutdown_timeout)), Box::pin(sleep(self.config.shutdown_timeout)),
Some(result), Some(result),
); );
} else { } else {
@ -430,7 +447,7 @@ impl Future for ServerWorker {
// sleep for 1 second and then check again // sleep for 1 second and then check again
if t1.as_mut().poll(cx).is_ready() { if t1.as_mut().poll(cx).is_ready() {
*t1 = Box::pin(sleep_until(Instant::now() + Duration::from_secs(1))); *t1 = Box::pin(sleep(Duration::from_secs(1)));
let _ = t1.as_mut().poll(cx); let _ = t1.as_mut().poll(cx);
} }

View File

@ -180,7 +180,7 @@ where
F: Fn(A::Error) -> E, F: Fn(A::Error) -> E,
{ {
fn new(fut: A::Future, f: F) -> Self { fn new(fut: A::Future, f: F) -> Self {
MapErrServiceFuture { f, fut } MapErrServiceFuture { fut, f }
} }
} }

View File

@ -1,6 +1,19 @@
# Changes # Changes
## Unreleased - 2021-xx-xx ## Unreleased - 2021-xx-xx
* Changed `connect::ssl::rustls::RustlsConnectorService` to return error when `DNSNameRef`
generation failed instead of panic. [#296]
* Remove `connect::ssl::openssl::OpensslConnectServiceFactory`. [#297]
* Remove `connect::ssl::openssl::OpensslConnectService`. [#297]
* Add `connect::ssl::native_tls` module for native tls support. [#295]
* Rename `accept::{nativetls => native_tls}`. [#295]
* Remove `connect::TcpConnectService` type. service caller expect a `TcpStream` should use
`connect::ConnectService` instead and call `Connection<T, TcpStream>::into_parts`. [#299]
[#295]: https://github.com/actix/actix-net/pull/295
[#296]: https://github.com/actix/actix-net/pull/296
[#297]: https://github.com/actix/actix-net/pull/297
[#299]: https://github.com/actix/actix-net/pull/299
## 3.0.0-beta.4 - 2021-02-24 ## 3.0.0-beta.4 - 2021-02-24

View File

@ -16,7 +16,7 @@ pub mod openssl;
pub mod rustls; pub mod rustls;
#[cfg(feature = "native-tls")] #[cfg(feature = "native-tls")]
pub mod nativetls; pub mod native_tls;
pub(crate) static MAX_CONN: AtomicUsize = AtomicUsize::new(256); pub(crate) static MAX_CONN: AtomicUsize = AtomicUsize::new(256);

View File

@ -113,7 +113,7 @@ impl Clone for Acceptor {
} }
} }
impl<T: ActixStream> ServiceFactory<T> for Acceptor { impl<T: ActixStream + 'static> ServiceFactory<T> for Acceptor {
type Response = TlsStream<T>; type Response = TlsStream<T>;
type Error = Error; type Error = Error;
type Config = (); type Config = ();
@ -138,16 +138,7 @@ pub struct NativeTlsAcceptorService {
conns: Counter, conns: Counter,
} }
impl Clone for NativeTlsAcceptorService { impl<T: ActixStream + 'static> Service<T> for NativeTlsAcceptorService {
fn clone(&self) -> Self {
Self {
acceptor: self.acceptor.clone(),
conns: self.conns.clone(),
}
}
}
impl<T: ActixStream> Service<T> for NativeTlsAcceptorService {
type Response = TlsStream<T>; type Response = TlsStream<T>;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<TlsStream<T>, Error>>; type Future = LocalBoxFuture<'static, Result<TlsStream<T>, Error>>;
@ -162,9 +153,9 @@ impl<T: ActixStream> Service<T> for NativeTlsAcceptorService {
fn call(&self, io: T) -> Self::Future { fn call(&self, io: T) -> Self::Future {
let guard = self.conns.get(); let guard = self.conns.get();
let this = self.clone(); let acceptor = self.acceptor.clone();
Box::pin(async move { Box::pin(async move {
let io = this.acceptor.accept(io).await; let io = acceptor.accept(io).await;
drop(guard); drop(guard);
io.map(Into::into) io.map(Into::into)
}) })

View File

@ -72,7 +72,7 @@ pub enum TcpConnectorResponse<T> {
port: u16, port: u16,
local_addr: Option<IpAddr>, local_addr: Option<IpAddr>,
addrs: Option<VecDeque<SocketAddr>>, addrs: Option<VecDeque<SocketAddr>>,
stream: Option<ReusableBoxFuture<Result<TcpStream, io::Error>>>, stream: ReusableBoxFuture<Result<TcpStream, io::Error>>,
}, },
Error(Option<ConnectError>), Error(Option<ConnectError>),
} }
@ -103,18 +103,22 @@ impl<T: Address> TcpConnectorResponse<T> {
port, port,
local_addr, local_addr,
addrs: None, addrs: None,
stream: Some(ReusableBoxFuture::new(connect(addr, local_addr))), stream: ReusableBoxFuture::new(connect(addr, local_addr)),
}, },
// when resolver returns multiple socket addr for request they would be popped from // when resolver returns multiple socket addr for request they would be popped from
// front end of queue and returns with the first successful tcp connection. // front end of queue and returns with the first successful tcp connection.
ConnectAddrs::Multi(addrs) => TcpConnectorResponse::Response { ConnectAddrs::Multi(mut addrs) => {
req: Some(req), let addr = addrs.pop_front().unwrap();
port,
local_addr, TcpConnectorResponse::Response {
addrs: Some(addrs), req: Some(req),
stream: None, port,
}, local_addr,
addrs: Some(addrs),
stream: ReusableBoxFuture::new(connect(addr, local_addr)),
}
}
} }
} }
} }
@ -133,40 +137,31 @@ impl<T: Address> Future for TcpConnectorResponse<T> {
addrs, addrs,
stream, stream,
} => loop { } => loop {
if let Some(new) = stream.as_mut() { match ready!(stream.poll(cx)) {
match ready!(new.poll(cx)) { Ok(sock) => {
Ok(sock) => { let req = req.take().unwrap();
let req = req.take().unwrap(); trace!(
trace!( "TCP connector: successfully connected to {:?} - {:?}",
"TCP connector: successfully connected to {:?} - {:?}", req.hostname(),
req.hostname(), sock.peer_addr()
sock.peer_addr() );
); return Poll::Ready(Ok(Connection::new(sock, req)));
return Poll::Ready(Ok(Connection::new(sock, req))); }
}
Err(err) => { Err(err) => {
trace!( trace!(
"TCP connector: failed to connect to {:?} port: {}", "TCP connector: failed to connect to {:?} port: {}",
req.as_ref().unwrap().hostname(), req.as_ref().unwrap().hostname(),
port, port,
); );
if addrs.is_none() || addrs.as_ref().unwrap().is_empty() { if let Some(addr) = addrs.as_mut().and_then(|addrs| addrs.pop_front()) {
return Poll::Ready(Err(ConnectError::Io(err))); stream.set(connect(addr, *local_addr));
} } else {
return Poll::Ready(Err(ConnectError::Io(err)));
} }
} }
} }
// try to connect
let addr = addrs.as_mut().unwrap().pop_front().unwrap();
let fut = connect(addr, *local_addr);
match stream {
Some(rbf) => rbf.set(fut),
None => *stream = Some(ReusableBoxFuture::new(fut)),
}
}, },
} }
} }

View File

@ -26,20 +26,20 @@ pub mod ssl;
mod uri; mod uri;
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_service::{pipeline, pipeline_factory, Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
pub use self::connect::{Address, Connect, Connection}; pub use self::connect::{Address, Connect, Connection};
pub use self::connector::{TcpConnector, TcpConnectorFactory}; pub use self::connector::{TcpConnector, TcpConnectorFactory};
pub use self::error::ConnectError; pub use self::error::ConnectError;
pub use self::resolve::{Resolve, Resolver, ResolverFactory}; pub use self::resolve::{Resolve, Resolver, ResolverFactory};
pub use self::service::{ConnectService, ConnectServiceFactory, TcpConnectService}; pub use self::service::{ConnectService, ConnectServiceFactory};
/// Create TCP connector service. /// Create TCP connector service.
pub fn new_connector<T: Address + 'static>( pub fn new_connector<T: Address + 'static>(
resolver: Resolver, resolver: Resolver,
) -> impl Service<Connect<T>, Response = Connection<T, TcpStream>, Error = ConnectError> + Clone ) -> impl Service<Connect<T>, Response = Connection<T, TcpStream>, Error = ConnectError> + Clone
{ {
pipeline(resolver).and_then(TcpConnector) ConnectServiceFactory::new(resolver).service()
} }
/// Create TCP connector service factory. /// Create TCP connector service factory.
@ -52,7 +52,7 @@ pub fn new_connector_factory<T: Address + 'static>(
Error = ConnectError, Error = ConnectError,
InitError = (), InitError = (),
> + Clone { > + Clone {
pipeline_factory(ResolverFactory::new(resolver)).and_then(TcpConnectorFactory) ConnectServiceFactory::new(resolver)
} }
/// Create connector service with default parameters. /// Create connector service with default parameters.

View File

@ -34,14 +34,6 @@ impl ConnectServiceFactory {
resolver: self.resolver.service(), resolver: self.resolver.service(),
} }
} }
/// Construct new tcp stream service
pub fn tcp_service(&self) -> TcpConnectService {
TcpConnectService {
tcp: self.tcp.service(),
resolver: self.resolver.service(),
}
}
} }
impl Clone for ConnectServiceFactory { impl Clone for ConnectServiceFactory {
@ -63,7 +55,7 @@ impl<T: Address> ServiceFactory<Connect<T>> for ConnectServiceFactory {
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
let service = self.service(); let service = self.service();
Box::pin(async move { Ok(service) }) Box::pin(async { Ok(service) })
} }
} }
@ -135,44 +127,3 @@ impl<T: Address> Future for ConnectServiceResponse<T> {
} }
} }
} }
#[derive(Clone)]
pub struct TcpConnectService {
tcp: TcpConnector,
resolver: Resolver,
}
impl<T: Address> Service<Connect<T>> for TcpConnectService {
type Response = TcpStream;
type Error = ConnectError;
type Future = TcpConnectServiceResponse<T>;
actix_service::always_ready!();
fn call(&self, req: Connect<T>) -> Self::Future {
TcpConnectServiceResponse {
fut: ConnectFuture::Resolve(self.resolver.call(req)),
tcp: self.tcp,
}
}
}
pub struct TcpConnectServiceResponse<T: Address> {
fut: ConnectFuture<T>,
tcp: TcpConnector,
}
impl<T: Address> Future for TcpConnectServiceResponse<T> {
type Output = Result<TcpStream, ConnectError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match ready!(self.fut.poll_connect(cx))? {
ConnectOutput::Resolved(res) => {
self.fut = ConnectFuture::Connect(self.tcp.call(res));
}
ConnectOutput::Connected(conn) => return Poll::Ready(Ok(conn.into_parts().0)),
}
}
}
}

View File

@ -5,3 +5,6 @@ pub mod openssl;
#[cfg(feature = "rustls")] #[cfg(feature = "rustls")]
pub mod rustls; pub mod rustls;
#[cfg(feature = "native-tls")]
pub mod native_tls;

View File

@ -0,0 +1,88 @@
use std::io;
use actix_rt::net::ActixStream;
use actix_service::{Service, ServiceFactory};
use futures_core::future::LocalBoxFuture;
use log::trace;
use tokio_native_tls::{TlsConnector as TokioNativetlsConnector, TlsStream};
pub use tokio_native_tls::native_tls::TlsConnector;
use crate::connect::{Address, Connection};
/// Native-tls connector factory and service
pub struct NativetlsConnector {
connector: TokioNativetlsConnector,
}
impl NativetlsConnector {
pub fn new(connector: TlsConnector) -> Self {
Self {
connector: TokioNativetlsConnector::from(connector),
}
}
}
impl NativetlsConnector {
pub fn service(connector: TlsConnector) -> Self {
Self::new(connector)
}
}
impl Clone for NativetlsConnector {
fn clone(&self) -> Self {
Self {
connector: self.connector.clone(),
}
}
}
impl<T: Address, U> ServiceFactory<Connection<T, U>> for NativetlsConnector
where
U: ActixStream + 'static,
{
type Response = Connection<T, TlsStream<U>>;
type Error = io::Error;
type Config = ();
type Service = Self;
type InitError = ();
type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
let connector = self.clone();
Box::pin(async { Ok(connector) })
}
}
// NativetlsConnector is both it's ServiceFactory and Service impl type.
// As the factory and service share the same type and state.
impl<T, U> Service<Connection<T, U>> for NativetlsConnector
where
T: Address,
U: ActixStream + 'static,
{
type Response = Connection<T, TlsStream<U>>;
type Error = io::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
actix_service::always_ready!();
fn call(&self, stream: Connection<T, U>) -> Self::Future {
let (io, stream) = stream.replace_io(());
let connector = self.connector.clone();
Box::pin(async move {
trace!("SSL Handshake start for: {:?}", stream.host());
connector
.connect(stream.host(), io)
.await
.map(|res| {
trace!("SSL Handshake success: {:?}", stream.host());
stream.replace_io(res).1
})
.map_err(|e| {
trace!("SSL Handshake error: {:?}", e);
io::Error::new(io::ErrorKind::Other, format!("{}", e))
})
})
}
}

View File

@ -1,13 +1,11 @@
use std::{ use std::{
fmt,
future::Future, future::Future,
io, io,
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
}; };
use actix_codec::{AsyncRead, AsyncWrite}; use actix_rt::net::ActixStream;
use actix_rt::net::TcpStream;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_core::{future::LocalBoxFuture, ready}; use futures_core::{future::LocalBoxFuture, ready};
use log::trace; use log::trace;
@ -15,10 +13,7 @@ use log::trace;
pub use openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; pub use openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod};
pub use tokio_openssl::SslStream; pub use tokio_openssl::SslStream;
use crate::connect::resolve::Resolve; use crate::connect::{Address, Connection};
use crate::connect::{
Address, Connect, ConnectError, ConnectService, ConnectServiceFactory, Connection, Resolver,
};
/// OpenSSL connector factory /// OpenSSL connector factory
pub struct OpensslConnector { pub struct OpensslConnector {
@ -45,8 +40,8 @@ impl Clone for OpensslConnector {
impl<T, U> ServiceFactory<Connection<T, U>> for OpensslConnector impl<T, U> ServiceFactory<Connection<T, U>> for OpensslConnector
where where
T: Address + 'static, T: Address,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, U: ActixStream + 'static,
{ {
type Response = Connection<T, SslStream<U>>; type Response = Connection<T, SslStream<U>>;
type Error = io::Error; type Error = io::Error;
@ -75,8 +70,8 @@ impl Clone for OpensslConnectorService {
impl<T, U> Service<Connection<T, U>> for OpensslConnectorService impl<T, U> Service<Connection<T, U>> for OpensslConnectorService
where where
T: Address + 'static, T: Address,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, U: ActixStream,
{ {
type Response = Connection<T, SslStream<U>>; type Response = Connection<T, SslStream<U>>;
type Error = io::Error; type Error = io::Error;
@ -112,7 +107,8 @@ pub struct ConnectAsyncExt<T, U> {
impl<T: Address, U> Future for ConnectAsyncExt<T, U> impl<T: Address, U> Future for ConnectAsyncExt<T, U>
where where
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, T: Address,
U: ActixStream,
{ {
type Output = Result<Connection<T, SslStream<U>>, io::Error>; type Output = Result<Connection<T, SslStream<U>>, io::Error>;
@ -132,115 +128,3 @@ where
} }
} }
} }
pub struct OpensslConnectServiceFactory {
tcp: ConnectServiceFactory,
openssl: OpensslConnector,
}
impl OpensslConnectServiceFactory {
/// Construct new OpensslConnectService factory
pub fn new(connector: SslConnector) -> Self {
OpensslConnectServiceFactory {
tcp: ConnectServiceFactory::new(Resolver::Default),
openssl: OpensslConnector::new(connector),
}
}
/// Construct new connect service with custom DNS resolver
pub fn with_resolver(connector: SslConnector, resolver: impl Resolve + 'static) -> Self {
OpensslConnectServiceFactory {
tcp: ConnectServiceFactory::new(Resolver::new_custom(resolver)),
openssl: OpensslConnector::new(connector),
}
}
/// Construct OpenSSL connect service
pub fn service(&self) -> OpensslConnectService {
OpensslConnectService {
tcp: self.tcp.service(),
openssl: OpensslConnectorService {
connector: self.openssl.connector.clone(),
},
}
}
}
impl Clone for OpensslConnectServiceFactory {
fn clone(&self) -> Self {
OpensslConnectServiceFactory {
tcp: self.tcp.clone(),
openssl: self.openssl.clone(),
}
}
}
impl<T: Address + 'static> ServiceFactory<Connect<T>> for OpensslConnectServiceFactory {
type Response = SslStream<TcpStream>;
type Error = ConnectError;
type Config = ();
type Service = OpensslConnectService;
type InitError = ();
type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
let service = self.service();
Box::pin(async { Ok(service) })
}
}
#[derive(Clone)]
pub struct OpensslConnectService {
tcp: ConnectService,
openssl: OpensslConnectorService,
}
impl<T: Address + 'static> Service<Connect<T>> for OpensslConnectService {
type Response = SslStream<TcpStream>;
type Error = ConnectError;
type Future = OpensslConnectServiceResponse<T>;
actix_service::always_ready!();
fn call(&self, req: Connect<T>) -> Self::Future {
OpensslConnectServiceResponse {
fut1: Some(self.tcp.call(req)),
fut2: None,
openssl: self.openssl.clone(),
}
}
}
pub struct OpensslConnectServiceResponse<T: Address + 'static> {
fut1: Option<<ConnectService as Service<Connect<T>>>::Future>,
fut2: Option<<OpensslConnectorService as Service<Connection<T, TcpStream>>>::Future>,
openssl: OpensslConnectorService,
}
impl<T: Address> Future for OpensslConnectServiceResponse<T> {
type Output = Result<SslStream<TcpStream>, ConnectError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(ref mut fut) = self.fut1 {
match ready!(Pin::new(fut).poll(cx)) {
Ok(res) => {
let _ = self.fut1.take();
self.fut2 = Some(self.openssl.call(res));
}
Err(e) => return Poll::Ready(Err(e)),
}
}
if let Some(ref mut fut) = self.fut2 {
match ready!(Pin::new(fut).poll(cx)) {
Ok(connect) => Poll::Ready(Ok(connect.into_parts().0)),
Err(e) => Poll::Ready(Err(ConnectError::Io(io::Error::new(
io::ErrorKind::Other,
e,
)))),
}
} else {
Poll::Pending
}
}
}

View File

@ -1,6 +1,6 @@
use std::{ use std::{
fmt,
future::Future, future::Future,
io,
pin::Pin, pin::Pin,
sync::Arc, sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
@ -10,7 +10,7 @@ pub use tokio_rustls::rustls::Session;
pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig}; pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig};
pub use webpki_roots::TLS_SERVER_ROOTS; pub use webpki_roots::TLS_SERVER_ROOTS;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_rt::net::ActixStream;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_core::{future::LocalBoxFuture, ready}; use futures_core::{future::LocalBoxFuture, ready};
use log::trace; use log::trace;
@ -44,12 +44,13 @@ impl Clone for RustlsConnector {
} }
} }
impl<T: Address, U> ServiceFactory<Connection<T, U>> for RustlsConnector impl<T, U> ServiceFactory<Connection<T, U>> for RustlsConnector
where where
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, T: Address,
U: ActixStream + 'static,
{ {
type Response = Connection<T, TlsStream<U>>; type Response = Connection<T, TlsStream<U>>;
type Error = std::io::Error; type Error = io::Error;
type Config = (); type Config = ();
type Service = RustlsConnectorService; type Service = RustlsConnectorService;
type InitError = (); type InitError = ();
@ -76,43 +77,55 @@ impl Clone for RustlsConnectorService {
impl<T, U> Service<Connection<T, U>> for RustlsConnectorService impl<T, U> Service<Connection<T, U>> for RustlsConnectorService
where where
T: Address, T: Address,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, U: ActixStream,
{ {
type Response = Connection<T, TlsStream<U>>; type Response = Connection<T, TlsStream<U>>;
type Error = std::io::Error; type Error = io::Error;
type Future = ConnectAsyncExt<T, U>; type Future = RustlsConnectorServiceFuture<T, U>;
actix_service::always_ready!(); actix_service::always_ready!();
fn call(&self, stream: Connection<T, U>) -> Self::Future { fn call(&self, connection: Connection<T, U>) -> Self::Future {
trace!("SSL Handshake start for: {:?}", stream.host()); trace!("SSL Handshake start for: {:?}", connection.host());
let (io, stream) = stream.replace_io(()); let (stream, connection) = connection.replace_io(());
let host = DNSNameRef::try_from_ascii_str(stream.host())
.expect("rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54"); match DNSNameRef::try_from_ascii_str(connection.host()) {
ConnectAsyncExt { Ok(host) => RustlsConnectorServiceFuture::Future {
fut: TlsConnector::from(self.connector.clone()).connect(host, io), connect: TlsConnector::from(self.connector.clone()).connect(host, stream),
stream: Some(stream), connection: Some(connection),
},
Err(_) => RustlsConnectorServiceFuture::InvalidDns,
} }
} }
} }
pub struct ConnectAsyncExt<T, U> { pub enum RustlsConnectorServiceFuture<T, U> {
fut: Connect<U>, /// See issue https://github.com/briansmith/webpki/issues/54
stream: Option<Connection<T, ()>>, InvalidDns,
Future {
connect: Connect<U>,
connection: Option<Connection<T, ()>>,
},
} }
impl<T, U> Future for ConnectAsyncExt<T, U> impl<T, U> Future for RustlsConnectorServiceFuture<T, U>
where where
T: Address, T: Address,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, U: ActixStream,
{ {
type Output = Result<Connection<T, TlsStream<U>>, std::io::Error>; type Output = Result<Connection<T, TlsStream<U>>, io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut(); match self.get_mut() {
let stream = ready!(Pin::new(&mut this.fut).poll(cx))?; Self::InvalidDns => Poll::Ready(Err(
let s = this.stream.take().unwrap(); io::Error::new(io::ErrorKind::Other, "rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54")
trace!("SSL Handshake success: {:?}", s.host()); )),
Poll::Ready(Ok(s.replace_io(stream).1)) Self::Future { connect, connection } => {
let stream = ready!(Pin::new(connect).poll(cx))?;
let connection = connection.take().unwrap();
trace!("SSL Handshake success: {:?}", connection.host());
Poll::Ready(Ok(connection.replace_io(stream).1))
}
}
} }
} }