diff --git a/actix-router/src/url.rs b/actix-router/src/url.rs index d2dd7a19..f669da99 100644 --- a/actix-router/src/url.rs +++ b/actix-router/src/url.rs @@ -170,13 +170,11 @@ impl Quoter { idx += 1; } - if let Some(data) = cloned { - // Unsafe: we get data from http::Uri, which does utf-8 checks already + cloned.map(|data| { + // SAFETY: we get data from http::Uri, which does UTF-8 checks already // this code only decodes valid pct encoded values - Some(unsafe { String::from_utf8_unchecked(data) }) - } else { - None - } + unsafe { String::from_utf8_unchecked(data) } + }) } } diff --git a/actix-rt/examples/multi_thread_system.rs b/actix-rt/examples/multi_thread_system.rs new file mode 100644 index 00000000..0ecd1ef1 --- /dev/null +++ b/actix-rt/examples/multi_thread_system.rs @@ -0,0 +1,60 @@ +//! An example on how to build a multi-thread tokio runtime for Actix System. +//! Then spawn async task that can make use of work stealing of tokio runtime. + +use actix_rt::System; + +fn main() { + System::with_tokio_rt(|| { + // build system with a multi-thread tokio runtime. + tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build() + .unwrap() + }) + .block_on(async_main()); +} + +// async main function that acts like #[actix_web::main] or #[tokio::main] +async fn async_main() { + let (tx, rx) = tokio::sync::oneshot::channel(); + + // get a handle to system arbiter and spawn async task on it + System::current().arbiter().spawn(async { + // use tokio::spawn to get inside the context of multi thread tokio runtime + let h1 = tokio::spawn(async { + println!("thread id is {:?}", std::thread::current().id()); + std::thread::sleep(std::time::Duration::from_secs(2)); + }); + + // work stealing occurs for this task spawn + let h2 = tokio::spawn(async { + println!("thread id is {:?}", std::thread::current().id()); + }); + + h1.await.unwrap(); + h2.await.unwrap(); + let _ = tx.send(()); + }); + + rx.await.unwrap(); + + let (tx, rx) = tokio::sync::oneshot::channel(); + let now = std::time::Instant::now(); + + // without additional tokio::spawn, all spawned tasks run on single thread + System::current().arbiter().spawn(async { + println!("thread id is {:?}", std::thread::current().id()); + std::thread::sleep(std::time::Duration::from_secs(2)); + let _ = tx.send(()); + }); + + // previous spawn task has blocked the system arbiter thread + // so this task will wait for 2 seconds until it can be run + System::current().arbiter().spawn(async move { + println!("thread id is {:?}", std::thread::current().id()); + assert!(now.elapsed() > std::time::Duration::from_secs(2)); + }); + + rx.await.unwrap(); +} diff --git a/actix-rt/src/lib.rs b/actix-rt/src/lib.rs index bd2e165d..afbe8642 100644 --- a/actix-rt/src/lib.rs +++ b/actix-rt/src/lib.rs @@ -87,7 +87,7 @@ pub mod net { pub use tokio::net::{UnixDatagram, UnixListener, UnixStream}; /// Extension trait over async read+write types that can also signal readiness. - pub trait ActixStream: AsyncRead + AsyncWrite + Unpin + 'static { + pub trait ActixStream: AsyncRead + AsyncWrite + Unpin { /// Poll stream and check read readiness of Self. /// /// See [tokio::net::TcpStream::poll_read_ready] for detail on intended use. @@ -127,6 +127,16 @@ pub mod net { ready.poll(cx) } } + + impl ActixStream for Box { + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + (**self).poll_read_ready(cx) + } + + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { + (**self).poll_write_ready(cx) + } + } } pub mod time { diff --git a/actix-server/CHANGES.md b/actix-server/CHANGES.md index 5eca1f91..aaa38911 100644 --- a/actix-server/CHANGES.md +++ b/actix-server/CHANGES.md @@ -1,6 +1,9 @@ # Changes ## Unreleased - 2021-xx-xx +* Prevent panic when shutdown_timeout is very large. [f9262db] + +[f9262db]: https://github.com/actix/actix-net/commit/f9262db ## 2.0.0-beta.3 - 2021-02-06 diff --git a/actix-server/src/accept.rs b/actix-server/src/accept.rs index 2b5bd764..e11c1f32 100644 --- a/actix-server/src/accept.rs +++ b/actix-server/src/accept.rs @@ -14,12 +14,15 @@ use crate::Token; const DUR_ON_ERR: Duration = Duration::from_millis(500); struct ServerSocketInfo { - // addr for socket. mainly used for logging. + /// Address of socket. Mainly used for logging. addr: SocketAddr, - // be ware this is the crate token for identify socket and should not be confused with - // mio::Token + + /// Beware this is the crate token for identify socket and should not be confused + /// with `mio::Token`. token: Token, + lst: MioListener, + // mark the deadline when this socket's listener should be registered again timeout_deadline: Option, } @@ -192,10 +195,9 @@ impl Accept { Some(WakerInterest::Stop) => { return self.deregister_all(&mut sockets); } - // waker queue is drained. + // waker queue is drained None => { - // Reset the WakerQueue before break so it does not grow - // infinitely. + // Reset the WakerQueue before break so it does not grow infinitely WakerQueue::reset(&mut guard); break 'waker; } @@ -316,8 +318,8 @@ impl Accept { } Err(tmp) => { // worker lost contact and could be gone. a message is sent to - // `ServerBuilder` future to notify it a new worker should be made. - // after that remove the fault worker. + // `ServerBuilder` future to notify it a new worker should be made + // after that remove the fault worker self.srv.worker_faulted(self.handles[self.next].idx); msg = tmp; self.handles.swap_remove(self.next); diff --git a/actix-server/src/builder.rs b/actix-server/src/builder.rs index 8fd130fc..f9ba7fc9 100644 --- a/actix-server/src/builder.rs +++ b/actix-server/src/builder.rs @@ -5,7 +5,7 @@ use std::time::Duration; use std::{io, mem}; use actix_rt::net::TcpStream; -use actix_rt::time::{sleep_until, Instant}; +use actix_rt::time::sleep; use actix_rt::System; use futures_core::future::BoxFuture; use log::{error, info}; @@ -115,13 +115,13 @@ impl ServerBuilder { self } - /// Stop actix system. + /// Stop Actix system. pub fn system_exit(mut self) -> Self { self.exit = true; self } - /// Disable signal handling + /// Disable signal handling. pub fn disable_signals(mut self) -> Self { self.no_signals = true; self @@ -129,9 +129,8 @@ impl ServerBuilder { /// Timeout for graceful workers shutdown in seconds. /// - /// After receiving a stop signal, workers have this much time to finish - /// serving requests. Workers still alive after the timeout are force - /// dropped. + /// After receiving a stop signal, workers have this much time to finish serving requests. + /// Workers still alive after the timeout are force dropped. /// /// By default shutdown timeout sets to 30 seconds. pub fn shutdown_timeout(mut self, sec: u64) -> Self { @@ -140,11 +139,10 @@ impl ServerBuilder { self } - /// Execute external configuration as part of the server building - /// process. + /// Execute external configuration as part of the server building process. /// - /// This function is useful for moving parts of configuration to a - /// different module or even library. + /// This function is useful for moving parts of configuration to a different module or + /// even library. pub fn configure(mut self, f: F) -> io::Result where F: Fn(&mut ServiceConfig) -> io::Result<()>, @@ -261,6 +259,7 @@ impl ServerBuilder { self.sockets .push((token, name.as_ref().to_string(), MioListener::from(lst))); + Ok(self) } @@ -430,8 +429,8 @@ impl ServerFuture { let _ = tx.send(()); } if exit { - sleep_until(Instant::now() + Duration::from_millis(300)).await; - System::current().stop(); + sleep(Duration::from_millis(300)).await; + System::try_current().as_ref().map(System::stop); } })) } else { @@ -440,8 +439,8 @@ impl ServerFuture { // TODO: this async block can return io::Error. Some(Box::pin(async move { if exit { - sleep_until(Instant::now() + Duration::from_millis(300)).await; - System::current().stop(); + sleep(Duration::from_millis(300)).await; + System::try_current().as_ref().map(System::stop); } if let Some(tx) = completion { let _ = tx.send(()); diff --git a/actix-server/src/test_server.rs b/actix-server/src/test_server.rs index c8e5941c..c2478f92 100644 --- a/actix-server/src/test_server.rs +++ b/actix-server/src/test_server.rs @@ -103,10 +103,10 @@ impl TestServer { let port = addr.port(); TestServerRuntime { - system, addr, host, port, + system, } } diff --git a/actix-server/src/worker.rs b/actix-server/src/worker.rs index a6527fc0..899e5fde 100644 --- a/actix-server/src/worker.rs +++ b/actix-server/src/worker.rs @@ -5,7 +5,10 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; -use actix_rt::time::{sleep_until, Instant, Sleep}; +use actix_rt::{ + time::{sleep, Sleep}, + System, +}; use actix_utils::counter::Counter; use futures_core::future::LocalBoxFuture; use log::{error, info, trace}; @@ -199,12 +202,26 @@ impl ServerWorker { availability.set(false); - let handle = tokio::runtime::Handle::current(); + // Try to get actix system when have one. + let system = System::try_current(); - // every worker runs in it's own arbiter. + // every worker runs in it's own thread. // use a custom tokio runtime builder to change the settings of runtime. std::thread::spawn(move || { - handle.block_on(tokio::task::LocalSet::new().run_until(async move { + // conditionally setup actix system. + if let Some(system) = system { + System::set_current(system); + } + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .max_blocking_threads(config.max_blocking_threads) + .build() + .unwrap(); + + let local = tokio::task::LocalSet::new(); + + rt.block_on(local.run_until(async move { let mut wrk = MAX_CONNS_COUNTER.with(move |conns| ServerWorker { rx, rx2, @@ -355,8 +372,8 @@ impl Future for ServerWorker { if num != 0 { info!("Graceful worker shutdown, {} connections", num); self.state = WorkerState::Shutdown( - Box::pin(sleep_until(Instant::now() + Duration::from_secs(1))), - Box::pin(sleep_until(Instant::now() + self.config.shutdown_timeout)), + Box::pin(sleep(Duration::from_secs(1))), + Box::pin(sleep(self.config.shutdown_timeout)), Some(result), ); } else { @@ -430,7 +447,7 @@ impl Future for ServerWorker { // sleep for 1 second and then check again if t1.as_mut().poll(cx).is_ready() { - *t1 = Box::pin(sleep_until(Instant::now() + Duration::from_secs(1))); + *t1 = Box::pin(sleep(Duration::from_secs(1))); let _ = t1.as_mut().poll(cx); } diff --git a/actix-service/src/map_err.rs b/actix-service/src/map_err.rs index ff25c4f7..7b1ac2ab 100644 --- a/actix-service/src/map_err.rs +++ b/actix-service/src/map_err.rs @@ -180,7 +180,7 @@ where F: Fn(A::Error) -> E, { fn new(fut: A::Future, f: F) -> Self { - MapErrServiceFuture { f, fut } + MapErrServiceFuture { fut, f } } } diff --git a/actix-tls/CHANGES.md b/actix-tls/CHANGES.md index 824663b0..067c4fe8 100644 --- a/actix-tls/CHANGES.md +++ b/actix-tls/CHANGES.md @@ -1,6 +1,19 @@ # Changes ## Unreleased - 2021-xx-xx +* Changed `connect::ssl::rustls::RustlsConnectorService` to return error when `DNSNameRef` + generation failed instead of panic. [#296] +* Remove `connect::ssl::openssl::OpensslConnectServiceFactory`. [#297] +* Remove `connect::ssl::openssl::OpensslConnectService`. [#297] +* Add `connect::ssl::native_tls` module for native tls support. [#295] +* Rename `accept::{nativetls => native_tls}`. [#295] +* Remove `connect::TcpConnectService` type. service caller expect a `TcpStream` should use + `connect::ConnectService` instead and call `Connection::into_parts`. [#299] + +[#295]: https://github.com/actix/actix-net/pull/295 +[#296]: https://github.com/actix/actix-net/pull/296 +[#297]: https://github.com/actix/actix-net/pull/297 +[#299]: https://github.com/actix/actix-net/pull/299 ## 3.0.0-beta.4 - 2021-02-24 diff --git a/actix-tls/src/accept/mod.rs b/actix-tls/src/accept/mod.rs index 8b1fe47c..dd939e4a 100644 --- a/actix-tls/src/accept/mod.rs +++ b/actix-tls/src/accept/mod.rs @@ -16,7 +16,7 @@ pub mod openssl; pub mod rustls; #[cfg(feature = "native-tls")] -pub mod nativetls; +pub mod native_tls; pub(crate) static MAX_CONN: AtomicUsize = AtomicUsize::new(256); diff --git a/actix-tls/src/accept/nativetls.rs b/actix-tls/src/accept/native_tls.rs similarity index 91% rename from actix-tls/src/accept/nativetls.rs rename to actix-tls/src/accept/native_tls.rs index 614bdad3..53294384 100644 --- a/actix-tls/src/accept/nativetls.rs +++ b/actix-tls/src/accept/native_tls.rs @@ -113,7 +113,7 @@ impl Clone for Acceptor { } } -impl ServiceFactory for Acceptor { +impl ServiceFactory for Acceptor { type Response = TlsStream; type Error = Error; type Config = (); @@ -138,16 +138,7 @@ pub struct NativeTlsAcceptorService { conns: Counter, } -impl Clone for NativeTlsAcceptorService { - fn clone(&self) -> Self { - Self { - acceptor: self.acceptor.clone(), - conns: self.conns.clone(), - } - } -} - -impl Service for NativeTlsAcceptorService { +impl Service for NativeTlsAcceptorService { type Response = TlsStream; type Error = Error; type Future = LocalBoxFuture<'static, Result, Error>>; @@ -162,9 +153,9 @@ impl Service for NativeTlsAcceptorService { fn call(&self, io: T) -> Self::Future { let guard = self.conns.get(); - let this = self.clone(); + let acceptor = self.acceptor.clone(); Box::pin(async move { - let io = this.acceptor.accept(io).await; + let io = acceptor.accept(io).await; drop(guard); io.map(Into::into) }) diff --git a/actix-tls/src/connect/connector.rs b/actix-tls/src/connect/connector.rs index 8f32270f..9438404e 100755 --- a/actix-tls/src/connect/connector.rs +++ b/actix-tls/src/connect/connector.rs @@ -72,7 +72,7 @@ pub enum TcpConnectorResponse { port: u16, local_addr: Option, addrs: Option>, - stream: Option>>, + stream: ReusableBoxFuture>, }, Error(Option), } @@ -103,18 +103,22 @@ impl TcpConnectorResponse { port, local_addr, addrs: None, - stream: Some(ReusableBoxFuture::new(connect(addr, local_addr))), + stream: ReusableBoxFuture::new(connect(addr, local_addr)), }, // when resolver returns multiple socket addr for request they would be popped from // front end of queue and returns with the first successful tcp connection. - ConnectAddrs::Multi(addrs) => TcpConnectorResponse::Response { - req: Some(req), - port, - local_addr, - addrs: Some(addrs), - stream: None, - }, + ConnectAddrs::Multi(mut addrs) => { + let addr = addrs.pop_front().unwrap(); + + TcpConnectorResponse::Response { + req: Some(req), + port, + local_addr, + addrs: Some(addrs), + stream: ReusableBoxFuture::new(connect(addr, local_addr)), + } + } } } } @@ -133,40 +137,31 @@ impl Future for TcpConnectorResponse { addrs, stream, } => loop { - if let Some(new) = stream.as_mut() { - match ready!(new.poll(cx)) { - Ok(sock) => { - let req = req.take().unwrap(); - trace!( - "TCP connector: successfully connected to {:?} - {:?}", - req.hostname(), - sock.peer_addr() - ); - return Poll::Ready(Ok(Connection::new(sock, req))); - } + match ready!(stream.poll(cx)) { + Ok(sock) => { + let req = req.take().unwrap(); + trace!( + "TCP connector: successfully connected to {:?} - {:?}", + req.hostname(), + sock.peer_addr() + ); + return Poll::Ready(Ok(Connection::new(sock, req))); + } - Err(err) => { - trace!( - "TCP connector: failed to connect to {:?} port: {}", - req.as_ref().unwrap().hostname(), - port, - ); + Err(err) => { + trace!( + "TCP connector: failed to connect to {:?} port: {}", + req.as_ref().unwrap().hostname(), + port, + ); - if addrs.is_none() || addrs.as_ref().unwrap().is_empty() { - return Poll::Ready(Err(ConnectError::Io(err))); - } + if let Some(addr) = addrs.as_mut().and_then(|addrs| addrs.pop_front()) { + stream.set(connect(addr, *local_addr)); + } else { + return Poll::Ready(Err(ConnectError::Io(err))); } } } - - // try to connect - let addr = addrs.as_mut().unwrap().pop_front().unwrap(); - - let fut = connect(addr, *local_addr); - match stream { - Some(rbf) => rbf.set(fut), - None => *stream = Some(ReusableBoxFuture::new(fut)), - } }, } } diff --git a/actix-tls/src/connect/mod.rs b/actix-tls/src/connect/mod.rs index 4010e3cb..ad4f40a3 100644 --- a/actix-tls/src/connect/mod.rs +++ b/actix-tls/src/connect/mod.rs @@ -26,20 +26,20 @@ pub mod ssl; mod uri; use actix_rt::net::TcpStream; -use actix_service::{pipeline, pipeline_factory, Service, ServiceFactory}; +use actix_service::{Service, ServiceFactory}; pub use self::connect::{Address, Connect, Connection}; pub use self::connector::{TcpConnector, TcpConnectorFactory}; pub use self::error::ConnectError; pub use self::resolve::{Resolve, Resolver, ResolverFactory}; -pub use self::service::{ConnectService, ConnectServiceFactory, TcpConnectService}; +pub use self::service::{ConnectService, ConnectServiceFactory}; /// Create TCP connector service. pub fn new_connector( resolver: Resolver, ) -> impl Service, Response = Connection, Error = ConnectError> + Clone { - pipeline(resolver).and_then(TcpConnector) + ConnectServiceFactory::new(resolver).service() } /// Create TCP connector service factory. @@ -52,7 +52,7 @@ pub fn new_connector_factory( Error = ConnectError, InitError = (), > + Clone { - pipeline_factory(ResolverFactory::new(resolver)).and_then(TcpConnectorFactory) + ConnectServiceFactory::new(resolver) } /// Create connector service with default parameters. diff --git a/actix-tls/src/connect/service.rs b/actix-tls/src/connect/service.rs index 98765ca1..9961498e 100755 --- a/actix-tls/src/connect/service.rs +++ b/actix-tls/src/connect/service.rs @@ -34,14 +34,6 @@ impl ConnectServiceFactory { resolver: self.resolver.service(), } } - - /// Construct new tcp stream service - pub fn tcp_service(&self) -> TcpConnectService { - TcpConnectService { - tcp: self.tcp.service(), - resolver: self.resolver.service(), - } - } } impl Clone for ConnectServiceFactory { @@ -63,7 +55,7 @@ impl ServiceFactory> for ConnectServiceFactory { fn new_service(&self, _: ()) -> Self::Future { let service = self.service(); - Box::pin(async move { Ok(service) }) + Box::pin(async { Ok(service) }) } } @@ -135,44 +127,3 @@ impl Future for ConnectServiceResponse { } } } - -#[derive(Clone)] -pub struct TcpConnectService { - tcp: TcpConnector, - resolver: Resolver, -} - -impl Service> for TcpConnectService { - type Response = TcpStream; - type Error = ConnectError; - type Future = TcpConnectServiceResponse; - - actix_service::always_ready!(); - - fn call(&self, req: Connect) -> Self::Future { - TcpConnectServiceResponse { - fut: ConnectFuture::Resolve(self.resolver.call(req)), - tcp: self.tcp, - } - } -} - -pub struct TcpConnectServiceResponse { - fut: ConnectFuture, - tcp: TcpConnector, -} - -impl Future for TcpConnectServiceResponse { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - match ready!(self.fut.poll_connect(cx))? { - ConnectOutput::Resolved(res) => { - self.fut = ConnectFuture::Connect(self.tcp.call(res)); - } - ConnectOutput::Connected(conn) => return Poll::Ready(Ok(conn.into_parts().0)), - } - } - } -} diff --git a/actix-tls/src/connect/ssl/mod.rs b/actix-tls/src/connect/ssl/mod.rs index 8ace5ef1..6e0e8aac 100644 --- a/actix-tls/src/connect/ssl/mod.rs +++ b/actix-tls/src/connect/ssl/mod.rs @@ -5,3 +5,6 @@ pub mod openssl; #[cfg(feature = "rustls")] pub mod rustls; + +#[cfg(feature = "native-tls")] +pub mod native_tls; diff --git a/actix-tls/src/connect/ssl/native_tls.rs b/actix-tls/src/connect/ssl/native_tls.rs new file mode 100644 index 00000000..de08ea2a --- /dev/null +++ b/actix-tls/src/connect/ssl/native_tls.rs @@ -0,0 +1,88 @@ +use std::io; + +use actix_rt::net::ActixStream; +use actix_service::{Service, ServiceFactory}; +use futures_core::future::LocalBoxFuture; +use log::trace; +use tokio_native_tls::{TlsConnector as TokioNativetlsConnector, TlsStream}; + +pub use tokio_native_tls::native_tls::TlsConnector; + +use crate::connect::{Address, Connection}; + +/// Native-tls connector factory and service +pub struct NativetlsConnector { + connector: TokioNativetlsConnector, +} + +impl NativetlsConnector { + pub fn new(connector: TlsConnector) -> Self { + Self { + connector: TokioNativetlsConnector::from(connector), + } + } +} + +impl NativetlsConnector { + pub fn service(connector: TlsConnector) -> Self { + Self::new(connector) + } +} + +impl Clone for NativetlsConnector { + fn clone(&self) -> Self { + Self { + connector: self.connector.clone(), + } + } +} + +impl ServiceFactory> for NativetlsConnector +where + U: ActixStream + 'static, +{ + type Response = Connection>; + type Error = io::Error; + type Config = (); + type Service = Self; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; + + fn new_service(&self, _: ()) -> Self::Future { + let connector = self.clone(); + Box::pin(async { Ok(connector) }) + } +} + +// NativetlsConnector is both it's ServiceFactory and Service impl type. +// As the factory and service share the same type and state. +impl Service> for NativetlsConnector +where + T: Address, + U: ActixStream + 'static, +{ + type Response = Connection>; + type Error = io::Error; + type Future = LocalBoxFuture<'static, Result>; + + actix_service::always_ready!(); + + fn call(&self, stream: Connection) -> Self::Future { + let (io, stream) = stream.replace_io(()); + let connector = self.connector.clone(); + Box::pin(async move { + trace!("SSL Handshake start for: {:?}", stream.host()); + connector + .connect(stream.host(), io) + .await + .map(|res| { + trace!("SSL Handshake success: {:?}", stream.host()); + stream.replace_io(res).1 + }) + .map_err(|e| { + trace!("SSL Handshake error: {:?}", e); + io::Error::new(io::ErrorKind::Other, format!("{}", e)) + }) + }) + } +} diff --git a/actix-tls/src/connect/ssl/openssl.rs b/actix-tls/src/connect/ssl/openssl.rs index b1c53f56..b4298fed 100755 --- a/actix-tls/src/connect/ssl/openssl.rs +++ b/actix-tls/src/connect/ssl/openssl.rs @@ -1,13 +1,11 @@ use std::{ - fmt, future::Future, io, pin::Pin, task::{Context, Poll}, }; -use actix_codec::{AsyncRead, AsyncWrite}; -use actix_rt::net::TcpStream; +use actix_rt::net::ActixStream; use actix_service::{Service, ServiceFactory}; use futures_core::{future::LocalBoxFuture, ready}; use log::trace; @@ -15,10 +13,7 @@ use log::trace; pub use openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; pub use tokio_openssl::SslStream; -use crate::connect::resolve::Resolve; -use crate::connect::{ - Address, Connect, ConnectError, ConnectService, ConnectServiceFactory, Connection, Resolver, -}; +use crate::connect::{Address, Connection}; /// OpenSSL connector factory pub struct OpensslConnector { @@ -45,8 +40,8 @@ impl Clone for OpensslConnector { impl ServiceFactory> for OpensslConnector where - T: Address + 'static, - U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, + T: Address, + U: ActixStream + 'static, { type Response = Connection>; type Error = io::Error; @@ -75,8 +70,8 @@ impl Clone for OpensslConnectorService { impl Service> for OpensslConnectorService where - T: Address + 'static, - U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, + T: Address, + U: ActixStream, { type Response = Connection>; type Error = io::Error; @@ -112,7 +107,8 @@ pub struct ConnectAsyncExt { impl Future for ConnectAsyncExt where - U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, + T: Address, + U: ActixStream, { type Output = Result>, io::Error>; @@ -132,115 +128,3 @@ where } } } - -pub struct OpensslConnectServiceFactory { - tcp: ConnectServiceFactory, - openssl: OpensslConnector, -} - -impl OpensslConnectServiceFactory { - /// Construct new OpensslConnectService factory - pub fn new(connector: SslConnector) -> Self { - OpensslConnectServiceFactory { - tcp: ConnectServiceFactory::new(Resolver::Default), - openssl: OpensslConnector::new(connector), - } - } - - /// Construct new connect service with custom DNS resolver - pub fn with_resolver(connector: SslConnector, resolver: impl Resolve + 'static) -> Self { - OpensslConnectServiceFactory { - tcp: ConnectServiceFactory::new(Resolver::new_custom(resolver)), - openssl: OpensslConnector::new(connector), - } - } - - /// Construct OpenSSL connect service - pub fn service(&self) -> OpensslConnectService { - OpensslConnectService { - tcp: self.tcp.service(), - openssl: OpensslConnectorService { - connector: self.openssl.connector.clone(), - }, - } - } -} - -impl Clone for OpensslConnectServiceFactory { - fn clone(&self) -> Self { - OpensslConnectServiceFactory { - tcp: self.tcp.clone(), - openssl: self.openssl.clone(), - } - } -} - -impl ServiceFactory> for OpensslConnectServiceFactory { - type Response = SslStream; - type Error = ConnectError; - type Config = (); - type Service = OpensslConnectService; - type InitError = (); - type Future = LocalBoxFuture<'static, Result>; - - fn new_service(&self, _: ()) -> Self::Future { - let service = self.service(); - Box::pin(async { Ok(service) }) - } -} - -#[derive(Clone)] -pub struct OpensslConnectService { - tcp: ConnectService, - openssl: OpensslConnectorService, -} - -impl Service> for OpensslConnectService { - type Response = SslStream; - type Error = ConnectError; - type Future = OpensslConnectServiceResponse; - - actix_service::always_ready!(); - - fn call(&self, req: Connect) -> Self::Future { - OpensslConnectServiceResponse { - fut1: Some(self.tcp.call(req)), - fut2: None, - openssl: self.openssl.clone(), - } - } -} - -pub struct OpensslConnectServiceResponse { - fut1: Option<>>::Future>, - fut2: Option<>>::Future>, - openssl: OpensslConnectorService, -} - -impl Future for OpensslConnectServiceResponse { - type Output = Result, ConnectError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if let Some(ref mut fut) = self.fut1 { - match ready!(Pin::new(fut).poll(cx)) { - Ok(res) => { - let _ = self.fut1.take(); - self.fut2 = Some(self.openssl.call(res)); - } - Err(e) => return Poll::Ready(Err(e)), - } - } - - if let Some(ref mut fut) = self.fut2 { - match ready!(Pin::new(fut).poll(cx)) { - Ok(connect) => Poll::Ready(Ok(connect.into_parts().0)), - Err(e) => Poll::Ready(Err(ConnectError::Io(io::Error::new( - io::ErrorKind::Other, - e, - )))), - } - } else { - Poll::Pending - } - } -} diff --git a/actix-tls/src/connect/ssl/rustls.rs b/actix-tls/src/connect/ssl/rustls.rs index 46b4b11d..ee8ad02d 100755 --- a/actix-tls/src/connect/ssl/rustls.rs +++ b/actix-tls/src/connect/ssl/rustls.rs @@ -1,6 +1,6 @@ use std::{ - fmt, future::Future, + io, pin::Pin, sync::Arc, task::{Context, Poll}, @@ -10,7 +10,7 @@ pub use tokio_rustls::rustls::Session; pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig}; pub use webpki_roots::TLS_SERVER_ROOTS; -use actix_codec::{AsyncRead, AsyncWrite}; +use actix_rt::net::ActixStream; use actix_service::{Service, ServiceFactory}; use futures_core::{future::LocalBoxFuture, ready}; use log::trace; @@ -44,12 +44,13 @@ impl Clone for RustlsConnector { } } -impl ServiceFactory> for RustlsConnector +impl ServiceFactory> for RustlsConnector where - U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, + T: Address, + U: ActixStream + 'static, { type Response = Connection>; - type Error = std::io::Error; + type Error = io::Error; type Config = (); type Service = RustlsConnectorService; type InitError = (); @@ -76,43 +77,55 @@ impl Clone for RustlsConnectorService { impl Service> for RustlsConnectorService where T: Address, - U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, + U: ActixStream, { type Response = Connection>; - type Error = std::io::Error; - type Future = ConnectAsyncExt; + type Error = io::Error; + type Future = RustlsConnectorServiceFuture; actix_service::always_ready!(); - fn call(&self, stream: Connection) -> Self::Future { - trace!("SSL Handshake start for: {:?}", stream.host()); - let (io, stream) = stream.replace_io(()); - let host = DNSNameRef::try_from_ascii_str(stream.host()) - .expect("rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54"); - ConnectAsyncExt { - fut: TlsConnector::from(self.connector.clone()).connect(host, io), - stream: Some(stream), + fn call(&self, connection: Connection) -> Self::Future { + trace!("SSL Handshake start for: {:?}", connection.host()); + let (stream, connection) = connection.replace_io(()); + + match DNSNameRef::try_from_ascii_str(connection.host()) { + Ok(host) => RustlsConnectorServiceFuture::Future { + connect: TlsConnector::from(self.connector.clone()).connect(host, stream), + connection: Some(connection), + }, + Err(_) => RustlsConnectorServiceFuture::InvalidDns, } } } -pub struct ConnectAsyncExt { - fut: Connect, - stream: Option>, +pub enum RustlsConnectorServiceFuture { + /// See issue https://github.com/briansmith/webpki/issues/54 + InvalidDns, + Future { + connect: Connect, + connection: Option>, + }, } -impl Future for ConnectAsyncExt +impl Future for RustlsConnectorServiceFuture where T: Address, - U: AsyncRead + AsyncWrite + Unpin + fmt::Debug, + U: ActixStream, { - type Output = Result>, std::io::Error>; + type Output = Result>, io::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - let stream = ready!(Pin::new(&mut this.fut).poll(cx))?; - let s = this.stream.take().unwrap(); - trace!("SSL Handshake success: {:?}", s.host()); - Poll::Ready(Ok(s.replace_io(stream).1)) + match self.get_mut() { + Self::InvalidDns => Poll::Ready(Err( + io::Error::new(io::ErrorKind::Other, "rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54") + )), + Self::Future { connect, connection } => { + let stream = ready!(Pin::new(connect).poll(cx))?; + let connection = connection.take().unwrap(); + trace!("SSL Handshake success: {:?}", connection.host()); + Poll::Ready(Ok(connection.replace_io(stream).1)) + } + } } }