This commit is contained in:
Joel Wurtz 2025-03-26 10:25:52 +02:00 committed by GitHub
commit 6ca703beb7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 622 additions and 138 deletions

View File

@ -8,6 +8,7 @@
- Do not send `Host` header on HTTP/2 requests, as it is not required, and some web servers may reject it. - Do not send `Host` header on HTTP/2 requests, as it is not required, and some web servers may reject it.
- Update `brotli` dependency to `7`. - Update `brotli` dependency to `7`.
- Minimum supported Rust version (MSRV) is now 1.75. - Minimum supported Rust version (MSRV) is now 1.75.
- Allow to set a specific SNI hostname on the request for TLS connections.
## 3.5.1 ## 3.5.1

View File

@ -3,7 +3,6 @@ use std::{fmt, net::IpAddr, rc::Rc, time::Duration};
use actix_http::{ use actix_http::{
error::HttpError, error::HttpError,
header::{self, HeaderMap, HeaderName, TryIntoHeaderPair}, header::{self, HeaderMap, HeaderName, TryIntoHeaderPair},
Uri,
}; };
use actix_rt::net::{ActixStream, TcpStream}; use actix_rt::net::{ActixStream, TcpStream};
use actix_service::{boxed, Service}; use actix_service::{boxed, Service};
@ -11,7 +10,8 @@ use base64::prelude::*;
use crate::{ use crate::{
client::{ client::{
ClientConfig, ConnectInfo, Connector, ConnectorService, TcpConnectError, TcpConnection, ClientConfig, ConnectInfo, Connector, ConnectorService, HostnameWithSni, TcpConnectError,
TcpConnection,
}, },
connect::DefaultConnector, connect::DefaultConnector,
error::SendRequestError, error::SendRequestError,
@ -46,8 +46,8 @@ impl ClientBuilder {
#[allow(clippy::new_ret_no_self)] #[allow(clippy::new_ret_no_self)]
pub fn new() -> ClientBuilder< pub fn new() -> ClientBuilder<
impl Service< impl Service<
ConnectInfo<Uri>, ConnectInfo<HostnameWithSni>,
Response = TcpConnection<Uri, TcpStream>, Response = TcpConnection<HostnameWithSni, TcpStream>,
Error = TcpConnectError, Error = TcpConnectError,
> + Clone, > + Clone,
(), (),
@ -69,16 +69,22 @@ impl ClientBuilder {
impl<S, Io, M> ClientBuilder<S, M> impl<S, Io, M> ClientBuilder<S, M>
where where
S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io>, Error = TcpConnectError> S: Service<
+ Clone ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, Io>,
Error = TcpConnectError,
> + Clone
+ 'static, + 'static,
Io: ActixStream + fmt::Debug + 'static, Io: ActixStream + fmt::Debug + 'static,
{ {
/// Use custom connector service. /// Use custom connector service.
pub fn connector<S1, Io1>(self, connector: Connector<S1>) -> ClientBuilder<S1, M> pub fn connector<S1, Io1>(self, connector: Connector<S1>) -> ClientBuilder<S1, M>
where where
S1: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io1>, Error = TcpConnectError> S1: Service<
+ Clone ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, Io1>,
Error = TcpConnectError,
> + Clone
+ 'static, + 'static,
Io1: ActixStream + fmt::Debug + 'static, Io1: ActixStream + fmt::Debug + 'static,
{ {

View File

@ -3,29 +3,33 @@ use std::{net::IpAddr, time::Duration};
const DEFAULT_H2_CONN_WINDOW: u32 = 1024 * 1024 * 2; // 2MB const DEFAULT_H2_CONN_WINDOW: u32 = 1024 * 1024 * 2; // 2MB
const DEFAULT_H2_STREAM_WINDOW: u32 = 1024 * 1024; // 1MB const DEFAULT_H2_STREAM_WINDOW: u32 = 1024 * 1024; // 1MB
/// Connector configuration /// Connect configuration
#[derive(Clone)] #[derive(Clone, Hash, Eq, PartialEq)]
pub(crate) struct ConnectorConfig { pub struct ConnectConfig {
pub(crate) timeout: Duration, pub(crate) timeout: Duration,
pub(crate) handshake_timeout: Duration, pub(crate) handshake_timeout: Duration,
pub(crate) conn_lifetime: Duration, pub(crate) conn_lifetime: Duration,
pub(crate) conn_keep_alive: Duration, pub(crate) conn_keep_alive: Duration,
pub(crate) disconnect_timeout: Option<Duration>,
pub(crate) limit: usize,
pub(crate) conn_window_size: u32, pub(crate) conn_window_size: u32,
pub(crate) stream_window_size: u32, pub(crate) stream_window_size: u32,
pub(crate) local_address: Option<IpAddr>, pub(crate) local_address: Option<IpAddr>,
} }
impl Default for ConnectorConfig { /// Connector configuration
#[derive(Clone)]
pub struct ConnectorConfig {
pub(crate) default_connect_config: ConnectConfig,
pub(crate) disconnect_timeout: Option<Duration>,
pub(crate) limit: usize,
}
impl Default for ConnectConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
timeout: Duration::from_secs(5), timeout: Duration::from_secs(5),
handshake_timeout: Duration::from_secs(5), handshake_timeout: Duration::from_secs(5),
conn_lifetime: Duration::from_secs(75), conn_lifetime: Duration::from_secs(75),
conn_keep_alive: Duration::from_secs(15), conn_keep_alive: Duration::from_secs(15),
disconnect_timeout: Some(Duration::from_millis(3000)),
limit: 100,
conn_window_size: DEFAULT_H2_CONN_WINDOW, conn_window_size: DEFAULT_H2_CONN_WINDOW,
stream_window_size: DEFAULT_H2_STREAM_WINDOW, stream_window_size: DEFAULT_H2_STREAM_WINDOW,
local_address: None, local_address: None,
@ -33,10 +37,88 @@ impl Default for ConnectorConfig {
} }
} }
impl Default for ConnectorConfig {
fn default() -> Self {
Self {
default_connect_config: ConnectConfig::default(),
disconnect_timeout: Some(Duration::from_millis(3000)),
limit: 100,
}
}
}
impl ConnectorConfig { impl ConnectorConfig {
pub(crate) fn no_disconnect_timeout(&self) -> Self { pub fn no_disconnect_timeout(&self) -> Self {
let mut res = self.clone(); let mut res = self.clone();
res.disconnect_timeout = None; res.disconnect_timeout = None;
res res
} }
} }
impl ConnectConfig {
/// Sets TCP connection timeout.
///
/// This is the max time allowed to connect to remote host, including DNS name resolution.
///
/// By default, the timeout is 5 seconds.
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
/// Sets TLS handshake timeout.
///
/// This is the max time allowed to perform the TLS handshake with remote host after TCP
/// connection is established.
///
/// By default, the timeout is 5 seconds.
pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
self.handshake_timeout = timeout;
self
}
/// Sets the initial window size (in bytes) for HTTP/2 stream-level flow control for received
/// data.
///
/// The default value is 65,535 and is good for APIs, but not for big objects.
pub fn initial_window_size(mut self, size: u32) -> Self {
self.stream_window_size = size;
self
}
/// Sets the initial window size (in bytes) for HTTP/2 connection-level flow control for
/// received data.
///
/// The default value is 65,535 and is good for APIs, but not for big objects.
pub fn initial_connection_window_size(mut self, size: u32) -> Self {
self.conn_window_size = size;
self
}
/// Set keep-alive period for opened connection.
///
/// Keep-alive period is the period between connection usage. If
/// the delay between repeated usages of the same connection
/// exceeds this period, the connection is closed.
/// Default keep-alive period is 15 seconds.
pub fn conn_keep_alive(mut self, dur: Duration) -> Self {
self.conn_keep_alive = dur;
self
}
/// Set max lifetime period for connection.
///
/// Connection lifetime is max lifetime of any opened connection
/// until it is closed regardless of keep-alive period.
/// Default lifetime period is 75 seconds.
pub fn conn_lifetime(mut self, dur: Duration) -> Self {
self.conn_lifetime = dur;
self
}
/// Set local IP Address the connector would use for establishing connection.
pub fn local_address(mut self, addr: IpAddr) -> Self {
self.local_address = Some(addr);
self
}
}

View File

@ -16,10 +16,9 @@ use actix_rt::{
use actix_service::Service; use actix_service::Service;
use actix_tls::connect::{ use actix_tls::connect::{
ConnectError as TcpConnectError, ConnectInfo, Connection as TcpConnection, ConnectError as TcpConnectError, ConnectInfo, Connection as TcpConnection,
Connector as TcpConnector, Resolver, Connector as TcpConnector, Host, Resolver,
}; };
use futures_core::{future::LocalBoxFuture, ready}; use futures_core::{future::LocalBoxFuture, ready};
use http::Uri;
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use super::{ use super::{
@ -27,9 +26,41 @@ use super::{
connection::{Connection, ConnectionIo}, connection::{Connection, ConnectionIo},
error::ConnectError, error::ConnectError,
pool::ConnectionPool, pool::ConnectionPool,
Connect, Connect, ServerName,
}; };
pub enum HostnameWithSni {
ForTcp(String, u16, Option<ServerName>),
ForTls(String, u16, Option<ServerName>),
}
impl Host for HostnameWithSni {
fn hostname(&self) -> &str {
match self {
HostnameWithSni::ForTcp(hostname, _, _) => hostname,
HostnameWithSni::ForTls(hostname, _, sni) => sni.as_deref().unwrap_or(hostname),
}
}
fn port(&self) -> Option<u16> {
match self {
HostnameWithSni::ForTcp(_, port, _) => Some(*port),
HostnameWithSni::ForTls(_, port, _) => Some(*port),
}
}
}
impl HostnameWithSni {
pub fn to_tls(self) -> Self {
match self {
HostnameWithSni::ForTcp(hostname, port, sni) => {
HostnameWithSni::ForTls(hostname, port, sni)
}
HostnameWithSni::ForTls(_, _, _) => self,
}
}
}
enum OurTlsConnector { enum OurTlsConnector {
#[allow(dead_code)] // only dead when no TLS feature is enabled #[allow(dead_code)] // only dead when no TLS feature is enabled
None, None,
@ -95,8 +126,8 @@ impl Connector<()> {
#[allow(clippy::new_ret_no_self, clippy::let_unit_value)] #[allow(clippy::new_ret_no_self, clippy::let_unit_value)]
pub fn new() -> Connector< pub fn new() -> Connector<
impl Service< impl Service<
ConnectInfo<Uri>, ConnectInfo<HostnameWithSni>,
Response = TcpConnection<Uri, TcpStream>, Response = TcpConnection<HostnameWithSni, TcpStream>,
Error = actix_tls::connect::ConnectError, Error = actix_tls::connect::ConnectError,
> + Clone, > + Clone,
> { > {
@ -214,8 +245,11 @@ impl<S> Connector<S> {
pub fn connector<S1, Io1>(self, connector: S1) -> Connector<S1> pub fn connector<S1, Io1>(self, connector: S1) -> Connector<S1>
where where
Io1: ActixStream + fmt::Debug + 'static, Io1: ActixStream + fmt::Debug + 'static,
S1: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io1>, Error = TcpConnectError> S1: Service<
+ Clone, ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, Io1>,
Error = TcpConnectError,
> + Clone,
{ {
Connector { Connector {
connector, connector,
@ -235,8 +269,11 @@ where
// This remap is to hide ActixStream's trait methods. They are not meant to be called // This remap is to hide ActixStream's trait methods. They are not meant to be called
// from user code. // from user code.
IO: ActixStream + fmt::Debug + 'static, IO: ActixStream + fmt::Debug + 'static,
S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, IO>, Error = TcpConnectError> S: Service<
+ Clone ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, IO>,
Error = TcpConnectError,
> + Clone
+ 'static, + 'static,
{ {
/// Sets TCP connection timeout. /// Sets TCP connection timeout.
@ -245,7 +282,7 @@ where
/// ///
/// By default, the timeout is 5 seconds. /// By default, the timeout is 5 seconds.
pub fn timeout(mut self, timeout: Duration) -> Self { pub fn timeout(mut self, timeout: Duration) -> Self {
self.config.timeout = timeout; self.config.default_connect_config.timeout = timeout;
self self
} }
@ -256,7 +293,7 @@ where
/// ///
/// By default, the timeout is 5 seconds. /// By default, the timeout is 5 seconds.
pub fn handshake_timeout(mut self, timeout: Duration) -> Self { pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
self.config.handshake_timeout = timeout; self.config.default_connect_config.handshake_timeout = timeout;
self self
} }
@ -350,7 +387,7 @@ where
/// ///
/// The default value is 65,535 and is good for APIs, but not for big objects. /// The default value is 65,535 and is good for APIs, but not for big objects.
pub fn initial_window_size(mut self, size: u32) -> Self { pub fn initial_window_size(mut self, size: u32) -> Self {
self.config.stream_window_size = size; self.config.default_connect_config.stream_window_size = size;
self self
} }
@ -359,7 +396,7 @@ where
/// ///
/// The default value is 65,535 and is good for APIs, but not for big objects. /// The default value is 65,535 and is good for APIs, but not for big objects.
pub fn initial_connection_window_size(mut self, size: u32) -> Self { pub fn initial_connection_window_size(mut self, size: u32) -> Self {
self.config.conn_window_size = size; self.config.default_connect_config.conn_window_size = size;
self self
} }
@ -385,7 +422,7 @@ where
/// exceeds this period, the connection is closed. /// exceeds this period, the connection is closed.
/// Default keep-alive period is 15 seconds. /// Default keep-alive period is 15 seconds.
pub fn conn_keep_alive(mut self, dur: Duration) -> Self { pub fn conn_keep_alive(mut self, dur: Duration) -> Self {
self.config.conn_keep_alive = dur; self.config.default_connect_config.conn_keep_alive = dur;
self self
} }
@ -395,7 +432,7 @@ where
/// until it is closed regardless of keep-alive period. /// until it is closed regardless of keep-alive period.
/// Default lifetime period is 75 seconds. /// Default lifetime period is 75 seconds.
pub fn conn_lifetime(mut self, dur: Duration) -> Self { pub fn conn_lifetime(mut self, dur: Duration) -> Self {
self.config.conn_lifetime = dur; self.config.default_connect_config.conn_lifetime = dur;
self self
} }
@ -414,7 +451,7 @@ where
/// Set local IP Address the connector would use for establishing connection. /// Set local IP Address the connector would use for establishing connection.
pub fn local_address(mut self, addr: IpAddr) -> Self { pub fn local_address(mut self, addr: IpAddr) -> Self {
self.config.local_address = Some(addr); self.config.default_connect_config.local_address = Some(addr);
self self
} }
@ -422,8 +459,8 @@ where
/// ///
/// The `Connector` builder always concludes by calling `finish()` last in its combinator chain. /// The `Connector` builder always concludes by calling `finish()` last in its combinator chain.
pub fn finish(self) -> ConnectorService<S, IO> { pub fn finish(self) -> ConnectorService<S, IO> {
let local_address = self.config.local_address; let local_address = self.config.default_connect_config.local_address;
let timeout = self.config.timeout; let timeout = self.config.default_connect_config.timeout;
let tcp_service_inner = let tcp_service_inner =
TcpConnectorInnerService::new(self.connector, timeout, local_address); TcpConnectorInnerService::new(self.connector, timeout, local_address);
@ -454,7 +491,7 @@ where
use actix_utils::future::{ready, Ready}; use actix_utils::future::{ready, Ready};
#[allow(non_local_definitions)] #[allow(non_local_definitions)]
impl IntoConnectionIo for TcpConnection<Uri, Box<dyn ConnectionIo>> { impl IntoConnectionIo for TcpConnection<HostnameWithSni, Box<dyn ConnectionIo>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) { fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let io = self.into_parts().0; let io = self.into_parts().0;
(io, Protocol::Http2) (io, Protocol::Http2)
@ -486,7 +523,7 @@ where
} }
} }
let handshake_timeout = self.config.handshake_timeout; let handshake_timeout = self.config.default_connect_config.handshake_timeout;
let tls_service = TlsConnectorService { let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner, tcp_service: tcp_service_inner,
@ -505,7 +542,7 @@ where
use actix_tls::connect::openssl::{reexports::AsyncSslStream, TlsConnector}; use actix_tls::connect::openssl::{reexports::AsyncSslStream, TlsConnector};
#[allow(non_local_definitions)] #[allow(non_local_definitions)]
impl<IO: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncSslStream<IO>> { impl<IO: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncSslStream<IO>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) { fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0; let sock = self.into_parts().0;
let h2 = sock let h2 = sock
@ -521,7 +558,7 @@ where
} }
} }
let handshake_timeout = self.config.handshake_timeout; let handshake_timeout = self.config.default_connect_config.handshake_timeout;
let tls_service = TlsConnectorService { let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner, tcp_service: tcp_service_inner,
@ -544,7 +581,7 @@ where
use actix_tls::connect::rustls_0_20::{reexports::AsyncTlsStream, TlsConnector}; use actix_tls::connect::rustls_0_20::{reexports::AsyncTlsStream, TlsConnector};
#[allow(non_local_definitions)] #[allow(non_local_definitions)]
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> { impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) { fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0; let sock = self.into_parts().0;
let h2 = sock let h2 = sock
@ -561,7 +598,7 @@ where
} }
} }
let handshake_timeout = self.config.handshake_timeout; let handshake_timeout = self.config.default_connect_config.handshake_timeout;
let tls_service = TlsConnectorService { let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner, tcp_service: tcp_service_inner,
@ -579,7 +616,7 @@ where
use actix_tls::connect::rustls_0_21::{reexports::AsyncTlsStream, TlsConnector}; use actix_tls::connect::rustls_0_21::{reexports::AsyncTlsStream, TlsConnector};
#[allow(non_local_definitions)] #[allow(non_local_definitions)]
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> { impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) { fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0; let sock = self.into_parts().0;
let h2 = sock let h2 = sock
@ -596,7 +633,7 @@ where
} }
} }
let handshake_timeout = self.config.handshake_timeout; let handshake_timeout = self.config.default_connect_config.handshake_timeout;
let tls_service = TlsConnectorService { let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner, tcp_service: tcp_service_inner,
@ -617,7 +654,7 @@ where
use actix_tls::connect::rustls_0_22::{reexports::AsyncTlsStream, TlsConnector}; use actix_tls::connect::rustls_0_22::{reexports::AsyncTlsStream, TlsConnector};
#[allow(non_local_definitions)] #[allow(non_local_definitions)]
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> { impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) { fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0; let sock = self.into_parts().0;
let h2 = sock let h2 = sock
@ -634,7 +671,7 @@ where
} }
} }
let handshake_timeout = self.config.handshake_timeout; let handshake_timeout = self.config.default_connect_config.handshake_timeout;
let tls_service = TlsConnectorService { let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner, tcp_service: tcp_service_inner,
@ -652,7 +689,7 @@ where
use actix_tls::connect::rustls_0_23::{reexports::AsyncTlsStream, TlsConnector}; use actix_tls::connect::rustls_0_23::{reexports::AsyncTlsStream, TlsConnector};
#[allow(non_local_definitions)] #[allow(non_local_definitions)]
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> { impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<HostnameWithSni, AsyncTlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) { fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0; let sock = self.into_parts().0;
let h2 = sock let h2 = sock
@ -669,7 +706,7 @@ where
} }
} }
let handshake_timeout = self.config.handshake_timeout; let handshake_timeout = self.config.default_connect_config.handshake_timeout;
let tls_service = TlsConnectorService { let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner, tcp_service: tcp_service_inner,
@ -693,7 +730,7 @@ where
} }
} }
/// tcp service for map `TcpConnection<Uri, Io>` type to `(Io, Protocol)` /// tcp service for map `TcpConnection<HostnameWithSni, Io>` type to `(Io, Protocol)`
#[derive(Clone)] #[derive(Clone)]
pub struct TcpConnectorService<S: Clone> { pub struct TcpConnectorService<S: Clone> {
service: S, service: S,
@ -701,7 +738,9 @@ pub struct TcpConnectorService<S: Clone> {
impl<S, Io> Service<Connect> for TcpConnectorService<S> impl<S, Io> Service<Connect> for TcpConnectorService<S>
where where
S: Service<Connect, Response = TcpConnection<Uri, Io>, Error = ConnectError> + Clone + 'static, S: Service<Connect, Response = TcpConnection<HostnameWithSni, Io>, Error = ConnectError>
+ Clone
+ 'static,
{ {
type Response = (Io, Protocol); type Response = (Io, Protocol);
type Error = ConnectError; type Error = ConnectError;
@ -726,7 +765,7 @@ pin_project! {
impl<Fut, Io> Future for TcpConnectorFuture<Fut> impl<Fut, Io> Future for TcpConnectorFuture<Fut>
where where
Fut: Future<Output = Result<TcpConnection<Uri, Io>, ConnectError>>, Fut: Future<Output = Result<TcpConnection<HostnameWithSni, Io>, ConnectError>>,
{ {
type Output = Result<(Io, Protocol), ConnectError>; type Output = Result<(Io, Protocol), ConnectError>;
@ -772,9 +811,10 @@ struct TlsConnectorService<Tcp, Tls> {
))] ))]
impl<Tcp, Tls, IO> Service<Connect> for TlsConnectorService<Tcp, Tls> impl<Tcp, Tls, IO> Service<Connect> for TlsConnectorService<Tcp, Tls>
where where
Tcp: Tcp: Service<Connect, Response = TcpConnection<HostnameWithSni, IO>, Error = ConnectError>
Service<Connect, Response = TcpConnection<Uri, IO>, Error = ConnectError> + Clone + 'static, + Clone
Tls: Service<TcpConnection<Uri, IO>, Error = std::io::Error> + Clone + 'static, + 'static,
Tls: Service<TcpConnection<HostnameWithSni, IO>, Error = std::io::Error> + Clone + 'static,
Tls::Response: IntoConnectionIo, Tls::Response: IntoConnectionIo,
IO: ConnectionIo, IO: ConnectionIo,
{ {
@ -789,9 +829,13 @@ where
} }
fn call(&self, req: Connect) -> Self::Future { fn call(&self, req: Connect) -> Self::Future {
let timeout = req
.config
.clone()
.map(|c| c.handshake_timeout)
.unwrap_or(self.timeout);
let fut = self.tcp_service.call(req); let fut = self.tcp_service.call(req);
let tls_service = self.tls_service.clone(); let tls_service = self.tls_service.clone();
let timeout = self.timeout;
TlsConnectorFuture::TcpConnect { TlsConnectorFuture::TcpConnect {
fut, fut,
@ -827,9 +871,14 @@ trait IntoConnectionIo {
impl<S, Io, Fut1, Fut2, Res> Future for TlsConnectorFuture<S, Fut1, Fut2> impl<S, Io, Fut1, Fut2, Res> Future for TlsConnectorFuture<S, Fut1, Fut2>
where where
S: Service<TcpConnection<Uri, Io>, Response = Res, Error = std::io::Error, Future = Fut2>, S: Service<
TcpConnection<HostnameWithSni, Io>,
Response = Res,
Error = std::io::Error,
Future = Fut2,
>,
S::Response: IntoConnectionIo, S::Response: IntoConnectionIo,
Fut1: Future<Output = Result<TcpConnection<Uri, Io>, ConnectError>>, Fut1: Future<Output = Result<TcpConnection<HostnameWithSni, Io>, ConnectError>>,
Fut2: Future<Output = Result<S::Response, S::Error>>, Fut2: Future<Output = Result<S::Response, S::Error>>,
Io: ConnectionIo, Io: ConnectionIo,
{ {
@ -843,10 +892,11 @@ where
timeout, timeout,
} => { } => {
let res = ready!(fut.poll(cx))?; let res = ready!(fut.poll(cx))?;
let (io, hostname_with_sni) = res.into_parts();
let fut = tls_service let fut = tls_service
.take() .take()
.expect("TlsConnectorFuture polled after complete") .expect("TlsConnectorFuture polled after complete")
.call(res); .call(TcpConnection::new(hostname_with_sni.to_tls(), io));
let timeout = sleep(*timeout); let timeout = sleep(*timeout);
self.set(TlsConnectorFuture::TlsConnect { fut, timeout }); self.set(TlsConnectorFuture::TlsConnect { fut, timeout });
self.poll(cx) self.poll(cx)
@ -880,8 +930,11 @@ impl<S: Clone> TcpConnectorInnerService<S> {
impl<S, Io> Service<Connect> for TcpConnectorInnerService<S> impl<S, Io> Service<Connect> for TcpConnectorInnerService<S>
where where
S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io>, Error = TcpConnectError> S: Service<
+ Clone ConnectInfo<HostnameWithSni>,
Response = TcpConnection<HostnameWithSni, Io>,
Error = TcpConnectError,
> + Clone
+ 'static, + 'static,
{ {
type Response = S::Response; type Response = S::Response;
@ -891,7 +944,14 @@ where
actix_service::forward_ready!(service); actix_service::forward_ready!(service);
fn call(&self, req: Connect) -> Self::Future { fn call(&self, req: Connect) -> Self::Future {
let mut req = ConnectInfo::new(req.uri).set_addr(req.addr); let timeout = req.config.map(|c| c.timeout).unwrap_or(self.timeout);
let mut req = ConnectInfo::new(HostnameWithSni::ForTcp(
req.hostname,
req.port,
req.sni_host,
))
.set_addr(req.addr)
.set_port(req.port);
if let Some(local_addr) = self.local_address { if let Some(local_addr) = self.local_address {
req = req.set_local_addr(local_addr); req = req.set_local_addr(local_addr);
@ -899,7 +959,7 @@ where
TcpConnectorInnerFuture { TcpConnectorInnerFuture {
fut: self.service.call(req), fut: self.service.call(req),
timeout: sleep(self.timeout), timeout: sleep(timeout),
} }
} }
} }
@ -916,9 +976,9 @@ pin_project! {
impl<Fut, Io> Future for TcpConnectorInnerFuture<Fut> impl<Fut, Io> Future for TcpConnectorInnerFuture<Fut>
where where
Fut: Future<Output = Result<TcpConnection<Uri, Io>, TcpConnectError>>, Fut: Future<Output = Result<TcpConnection<HostnameWithSni, Io>, TcpConnectError>>,
{ {
type Output = Result<TcpConnection<Uri, Io>, ConnectError>; type Output = Result<TcpConnection<HostnameWithSni, Io>, ConnectError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project(); let this = self.project();
@ -978,16 +1038,17 @@ where
} }
fn call(&self, req: Connect) -> Self::Future { fn call(&self, req: Connect) -> Self::Future {
match req.uri.scheme_str() { if req.tls {
Some("https") | Some("wss") => match self.tls_pool { match &self.tls_pool {
None => ConnectorServiceFuture::SslIsNotSupported, None => ConnectorServiceFuture::SslIsNotSupported,
Some(ref pool) => ConnectorServiceFuture::Tls { Some(pool) => ConnectorServiceFuture::Tls {
fut: pool.call(req), fut: pool.call(req),
}, },
}, }
_ => ConnectorServiceFuture::Tcp { } else {
ConnectorServiceFuture::Tcp {
fut: self.tcp_pool.call(req), fut: self.tcp_pool.call(req),
}, }
} }
} }
} }

View File

@ -19,7 +19,6 @@ use http::{
use log::trace; use log::trace;
use super::{ use super::{
config::ConnectorConfig,
connection::{ConnectionIo, H2Connection}, connection::{ConnectionIo, H2Connection},
error::SendRequestError, error::SendRequestError,
}; };
@ -186,12 +185,13 @@ where
pub(crate) fn handshake<Io: ConnectionIo>( pub(crate) fn handshake<Io: ConnectionIo>(
io: Io, io: Io,
config: &ConnectorConfig, stream_window_size: u32,
conn_window_size: u32,
) -> impl Future<Output = Result<(SendRequest<Bytes>, Connection<Io, Bytes>), h2::Error>> { ) -> impl Future<Output = Result<(SendRequest<Bytes>, Connection<Io, Bytes>), h2::Error>> {
let mut builder = Builder::new(); let mut builder = Builder::new();
builder builder
.initial_window_size(config.stream_window_size) .initial_window_size(stream_window_size)
.initial_connection_window_size(config.conn_window_size) .initial_connection_window_size(conn_window_size)
.enable_push(false); .enable_push(false);
builder.handshake(io) builder.handshake(io)
} }

View File

@ -1,6 +1,6 @@
//! HTTP client. //! HTTP client.
use std::{rc::Rc, time::Duration}; use std::{ops::Deref, rc::Rc, time::Duration};
use actix_http::{error::HttpError, header::HeaderMap, Method, RequestHead, Uri}; use actix_http::{error::HttpError, header::HeaderMap, Method, RequestHead, Uri};
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
@ -20,15 +20,37 @@ mod h2proto;
mod pool; mod pool;
pub use self::{ pub use self::{
config::ConnectConfig,
connection::{Connection, ConnectionIo}, connection::{Connection, ConnectionIo},
connector::{Connector, ConnectorService}, connector::{Connector, ConnectorService, HostnameWithSni},
error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}, error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError},
}; };
#[derive(Clone)] #[derive(Clone, Hash, PartialEq, Eq)]
pub enum ServerName {
Owned(String),
Borrowed(Rc<String>),
}
impl Deref for ServerName {
type Target = str;
fn deref(&self) -> &str {
match self {
ServerName::Owned(ref s) => s,
ServerName::Borrowed(ref s) => s,
}
}
}
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct Connect { pub struct Connect {
pub uri: Uri, pub hostname: String,
pub sni_host: Option<ServerName>,
pub port: u16,
pub tls: bool,
pub addr: Option<std::net::SocketAddr>, pub addr: Option<std::net::SocketAddr>,
pub config: Option<Rc<ConnectConfig>>,
} }
/// An asynchronous HTTP and WebSocket client. /// An asynchronous HTTP and WebSocket client.
@ -79,8 +101,8 @@ impl Client {
/// This function is equivalent of `ClientBuilder::new()`. /// This function is equivalent of `ClientBuilder::new()`.
pub fn builder() -> ClientBuilder< pub fn builder() -> ClientBuilder<
impl Service< impl Service<
ConnectInfo<Uri>, ConnectInfo<HostnameWithSni>,
Response = TcpConnection<Uri, TcpStream>, Response = TcpConnection<HostnameWithSni, TcpStream>,
Error = TcpConnectError, Error = TcpConnectError,
> + Clone, > + Clone,
> { > {

View File

@ -4,6 +4,7 @@ use std::{
cell::RefCell, cell::RefCell,
collections::{HashMap, VecDeque}, collections::{HashMap, VecDeque},
future::Future, future::Future,
hash::Hash,
io, io,
ops::Deref, ops::Deref,
pin::Pin, pin::Pin,
@ -127,7 +128,7 @@ where
Io: AsyncWrite + Unpin + 'static, Io: AsyncWrite + Unpin + 'static,
{ {
config: ConnectorConfig, config: ConnectorConfig,
available: RefCell<HashMap<Key, VecDeque<PooledConnection<Io>>>>, available: RefCell<HashMap<Connect, VecDeque<PooledConnection<Io>>>>,
permits: Arc<Semaphore>, permits: Arc<Semaphore>,
} }
@ -168,12 +169,6 @@ where
let inner = self.inner.clone(); let inner = self.inner.clone();
Box::pin(async move { Box::pin(async move {
let key = if let Some(authority) = req.uri.authority() {
authority.clone().into()
} else {
return Err(ConnectError::Unresolved);
};
// acquire an owned permit and carry it with connection // acquire an owned permit and carry it with connection
let permit = Arc::clone(&inner.permits) let permit = Arc::clone(&inner.permits)
.acquire_owned() .acquire_owned()
@ -191,11 +186,15 @@ where
// check if there is idle connection for given key. // check if there is idle connection for given key.
let mut map = inner.available.borrow_mut(); let mut map = inner.available.borrow_mut();
if let Some(conns) = map.get_mut(&key) { if let Some(conns) = map.get_mut(&req) {
let now = Instant::now(); let now = Instant::now();
while let Some(mut c) = conns.pop_front() { while let Some(mut c) = conns.pop_front() {
let config = &inner.config; let config = req
.config
.as_ref()
.map(|c| c.as_ref())
.unwrap_or(&inner.config.default_connect_config);
let idle_dur = now - c.used; let idle_dur = now - c.used;
let age = now - c.created; let age = now - c.created;
let conn_ineligible = let conn_ineligible =
@ -230,9 +229,24 @@ where
conn conn
}; };
let stream_window_size = req
.config
.as_ref()
.map(|c| c.stream_window_size)
.unwrap_or(inner.config.default_connect_config.stream_window_size);
let conn_window_size = req
.config
.as_ref()
.map(|c| c.conn_window_size)
.unwrap_or(inner.config.default_connect_config.conn_window_size);
// construct acquired. It's used to put Io type back to pool/ close the Io type. // construct acquired. It's used to put Io type back to pool/ close the Io type.
// permit is carried with the whole lifecycle of Acquired. // permit is carried with the whole lifecycle of Acquired.
let acquired = Acquired { key, inner, permit }; let acquired = Acquired {
req: req.clone(),
inner,
permit,
};
// match the connection and spawn new one if did not get anything. // match the connection and spawn new one if did not get anything.
match conn { match conn {
@ -246,8 +260,8 @@ where
if proto == Protocol::Http1 { if proto == Protocol::Http1 {
Ok(ConnectionType::from_h1(io, Instant::now(), acquired)) Ok(ConnectionType::from_h1(io, Instant::now(), acquired))
} else { } else {
let config = &acquired.inner.config; let (sender, connection) =
let (sender, connection) = handshake(io, config).await?; handshake(io, stream_window_size, conn_window_size).await?;
let inner = H2ConnectionInner::new(sender, connection); let inner = H2ConnectionInner::new(sender, connection);
Ok(ConnectionType::from_h2(inner, Instant::now(), acquired)) Ok(ConnectionType::from_h2(inner, Instant::now(), acquired))
} }
@ -344,8 +358,8 @@ pub struct Acquired<Io>
where where
Io: AsyncWrite + Unpin + 'static, Io: AsyncWrite + Unpin + 'static,
{ {
/// authority key for identify connection. /// hash key for identify connection.
key: Key, req: Connect,
/// handle to connection pool. /// handle to connection pool.
inner: ConnectionPoolInner<Io>, inner: ConnectionPoolInner<Io>,
/// permit for limit concurrent in-flight connection for a Client object. /// permit for limit concurrent in-flight connection for a Client object.
@ -360,12 +374,12 @@ impl<Io: ConnectionIo> Acquired<Io> {
/// Release IO back into pool. /// Release IO back into pool.
pub(super) fn release(&self, conn: ConnectionInnerType<Io>, created: Instant) { pub(super) fn release(&self, conn: ConnectionInnerType<Io>, created: Instant) {
let Acquired { key, inner, .. } = self; let Acquired { req, inner, .. } = self;
inner inner
.available .available
.borrow_mut() .borrow_mut()
.entry(key.clone()) .entry(req.clone())
.or_insert_with(VecDeque::new) .or_insert_with(VecDeque::new)
.push_back(PooledConnection { .push_back(PooledConnection {
conn, conn,
@ -381,9 +395,8 @@ impl<Io: ConnectionIo> Acquired<Io> {
mod test { mod test {
use std::cell::Cell; use std::cell::Cell;
use http::Uri;
use super::*; use super::*;
use crate::client::ConnectConfig;
/// A stream type that always returns pending on async read. /// A stream type that always returns pending on async read.
/// ///
@ -467,8 +480,12 @@ mod test {
let pool = super::ConnectionPool::new(connector, config); let pool = super::ConnectionPool::new(connector, config);
let req = Connect { let req = Connect {
uri: Uri::from_static("http://localhost"), hostname: "localhost".to_string(),
port: 80,
tls: false,
sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
@ -500,15 +517,22 @@ mod test {
let connector = TestPoolConnector { generated }; let connector = TestPoolConnector { generated };
let config = ConnectorConfig { let config = ConnectorConfig {
conn_keep_alive: Duration::from_secs(1), default_connect_config: ConnectConfig {
conn_keep_alive: Duration::from_secs(1),
..Default::default()
},
..Default::default() ..Default::default()
}; };
let pool = super::ConnectionPool::new(connector, config); let pool = super::ConnectionPool::new(connector, config);
let req = Connect { let req = Connect {
uri: Uri::from_static("http://localhost"), hostname: "localhost".to_string(),
port: 80,
tls: false,
sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
@ -542,15 +566,22 @@ mod test {
let connector = TestPoolConnector { generated }; let connector = TestPoolConnector { generated };
let config = ConnectorConfig { let config = ConnectorConfig {
conn_lifetime: Duration::from_secs(1), default_connect_config: ConnectConfig {
conn_lifetime: Duration::from_secs(1),
..Default::default()
},
..Default::default() ..Default::default()
}; };
let pool = super::ConnectionPool::new(connector, config); let pool = super::ConnectionPool::new(connector, config);
let req = Connect { let req = Connect {
uri: Uri::from_static("http://localhost"), hostname: "localhost".to_string(),
port: 80,
tls: false,
sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
@ -588,8 +619,12 @@ mod test {
let pool = super::ConnectionPool::new(connector, config); let pool = super::ConnectionPool::new(connector, config);
let req = Connect { let req = Connect {
uri: Uri::from_static("https://crates.io"), hostname: "crates.io".to_string(),
port: 443,
tls: true,
sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
@ -601,8 +636,12 @@ mod test {
release(conn); release(conn);
let req = Connect { let req = Connect {
uri: Uri::from_static("https://google.com"), hostname: "google.com".to_string(),
port: 443,
tls: true,
sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
@ -625,8 +664,12 @@ mod test {
let pool = Rc::new(super::ConnectionPool::new(connector, config)); let pool = Rc::new(super::ConnectionPool::new(connector, config));
let req = Connect { let req = Connect {
uri: Uri::from_static("https://crates.io"), hostname: "crates.io".to_string(),
port: 443,
tls: true,
sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
@ -634,8 +677,12 @@ mod test {
release(conn); release(conn);
let req = Connect { let req = Connect {
uri: Uri::from_static("https://google.com"), hostname: "google.com".to_string(),
port: 443,
tls: true,
sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
assert_eq!(2, generated_clone.get()); assert_eq!(2, generated_clone.get());

View File

@ -13,7 +13,10 @@ use futures_core::{future::LocalBoxFuture, ready};
use crate::{ use crate::{
any_body::AnyBody, any_body::AnyBody,
client::{Connect as ClientConnect, ConnectError, Connection, ConnectionIo, SendRequestError}, client::{
Connect as ClientConnect, ConnectConfig, ConnectError, Connection, ConnectionIo,
SendRequestError, ServerName,
},
ClientResponse, ClientResponse,
}; };
@ -32,13 +35,24 @@ pub type BoxedSocket = Box<dyn ConnectionIo>;
pub enum ConnectRequest { pub enum ConnectRequest {
/// Standard HTTP request. /// Standard HTTP request.
/// ///
/// Contains the request head, body type, and optional pre-resolved socket address. /// Contains the request head, body type, optional pre-resolved socket address and optional sni host.
Client(RequestHeadType, AnyBody, Option<net::SocketAddr>), Client(
RequestHeadType,
AnyBody,
Option<net::SocketAddr>,
Option<ServerName>,
Option<Rc<ConnectConfig>>,
),
/// Tunnel used by WebSocket connection requests. /// Tunnel used by WebSocket connection requests.
/// ///
/// Contains the request head and optional pre-resolved socket address. /// Contains the request head, optional pre-resolved socket address and optional sni host.
Tunnel(RequestHead, Option<net::SocketAddr>), Tunnel(
RequestHead,
Option<net::SocketAddr>,
Option<ServerName>,
Option<Rc<ConnectConfig>>,
),
} }
/// Combined HTTP response & WebSocket tunnel type returned from connection service. /// Combined HTTP response & WebSocket tunnel type returned from connection service.
@ -103,17 +117,44 @@ where
fn call(&self, req: ConnectRequest) -> Self::Future { fn call(&self, req: ConnectRequest) -> Self::Future {
// connect to the host // connect to the host
let fut = match req { let (head, addr, sni_host, config) = match req {
ConnectRequest::Client(ref head, .., addr) => self.connector.call(ClientConnect { ConnectRequest::Client(ref head, .., addr, ref sni_host, ref config) => {
uri: head.as_ref().uri.clone(), (head.as_ref(), addr, sni_host.clone(), config.clone())
addr, }
}), ConnectRequest::Tunnel(ref head, addr, ref sni_host, ref config) => {
ConnectRequest::Tunnel(ref head, addr) => self.connector.call(ClientConnect { (head, addr, sni_host.clone(), config.clone())
uri: head.uri.clone(), }
addr,
}),
}; };
let authority = if let Some(authority) = head.uri.authority() {
authority
} else {
return ConnectRequestFuture::Error {
err: ConnectError::Unresolved,
};
};
let tls = match head.uri.scheme_str() {
Some("https") | Some("wss") => true,
_ => false,
};
let fut =
self.connector.call(ClientConnect {
hostname: authority.host().to_string(),
port: authority.port().map(|p| p.as_u16()).unwrap_or_else(|| {
if tls {
443
} else {
80
}
}),
tls,
sni_host,
addr,
config,
});
ConnectRequestFuture::Connection { ConnectRequestFuture::Connection {
fut, fut,
req: Some(req), req: Some(req),
@ -127,6 +168,9 @@ pin_project_lite::pin_project! {
where where
Io: ConnectionIo Io: ConnectionIo
{ {
Error {
err: ConnectError
},
Connection { Connection {
#[pin] #[pin]
fut: Fut, fut: Fut,
@ -192,6 +236,10 @@ where
let framed = framed.into_map_io(|io| Box::new(io) as _); let framed = framed.into_map_io(|io| Box::new(io) as _);
Poll::Ready(Ok(ConnectResponse::Tunnel(head, framed))) Poll::Ready(Ok(ConnectResponse::Tunnel(head, framed)))
} }
ConnectRequestProj::Error { .. } => {
Poll::Ready(Err(SendRequestError::Connect(ConnectError::Unresolved)))
}
} }
} }
} }

View File

@ -11,7 +11,7 @@ use futures_core::Stream;
use serde::Serialize; use serde::Serialize;
use crate::{ use crate::{
client::ClientConfig, client::{ClientConfig, ConnectConfig, ServerName},
sender::{RequestSender, SendClientRequest}, sender::{RequestSender, SendClientRequest},
BoxError, BoxError,
}; };
@ -26,6 +26,8 @@ pub struct FrozenClientRequest {
pub(crate) response_decompress: bool, pub(crate) response_decompress: bool,
pub(crate) timeout: Option<Duration>, pub(crate) timeout: Option<Duration>,
pub(crate) config: ClientConfig, pub(crate) config: ClientConfig,
pub(crate) sni_host: Option<ServerName>,
pub(crate) connect_config: Option<Rc<ConnectConfig>>,
} }
impl FrozenClientRequest { impl FrozenClientRequest {
@ -54,6 +56,8 @@ impl FrozenClientRequest {
self.response_decompress, self.response_decompress,
self.timeout, self.timeout,
&self.config, &self.config,
self.sni_host.clone(),
self.connect_config.clone(),
body, body,
) )
} }
@ -65,6 +69,8 @@ impl FrozenClientRequest {
self.response_decompress, self.response_decompress,
self.timeout, self.timeout,
&self.config, &self.config,
self.sni_host.clone(),
self.connect_config.clone(),
value, value,
) )
} }
@ -76,6 +82,8 @@ impl FrozenClientRequest {
self.response_decompress, self.response_decompress,
self.timeout, self.timeout,
&self.config, &self.config,
self.sni_host.clone(),
self.connect_config.clone(),
value, value,
) )
} }
@ -91,6 +99,8 @@ impl FrozenClientRequest {
self.response_decompress, self.response_decompress,
self.timeout, self.timeout,
&self.config, &self.config,
self.sni_host.clone(),
self.connect_config.clone(),
stream, stream,
) )
} }
@ -102,6 +112,8 @@ impl FrozenClientRequest {
self.response_decompress, self.response_decompress,
self.timeout, self.timeout,
&self.config, &self.config,
self.sni_host.clone(),
self.connect_config.clone(),
) )
} }
@ -156,6 +168,8 @@ impl FrozenSendBuilder {
self.req.response_decompress, self.req.response_decompress,
self.req.timeout, self.req.timeout,
&self.req.config, &self.req.config,
self.req.sni_host.clone(),
self.req.connect_config,
body, body,
) )
} }
@ -171,6 +185,8 @@ impl FrozenSendBuilder {
self.req.response_decompress, self.req.response_decompress,
self.req.timeout, self.req.timeout,
&self.req.config, &self.req.config,
self.req.sni_host.clone(),
self.req.connect_config,
value, value,
) )
} }
@ -186,6 +202,8 @@ impl FrozenSendBuilder {
self.req.response_decompress, self.req.response_decompress,
self.req.timeout, self.req.timeout,
&self.req.config, &self.req.config,
self.req.sni_host.clone(),
self.req.connect_config,
value, value,
) )
} }
@ -205,6 +223,8 @@ impl FrozenSendBuilder {
self.req.response_decompress, self.req.response_decompress,
self.req.timeout, self.req.timeout,
&self.req.config, &self.req.config,
self.req.sni_host.clone(),
self.req.connect_config,
stream, stream,
) )
} }
@ -220,6 +240,8 @@ impl FrozenSendBuilder {
self.req.response_decompress, self.req.response_decompress,
self.req.timeout, self.req.timeout,
&self.req.config, &self.req.config,
self.req.sni_host.clone(),
self.req.connect_config,
) )
} }
} }

View File

@ -73,11 +73,13 @@ where
fn call(&self, req: ConnectRequest) -> Self::Future { fn call(&self, req: ConnectRequest) -> Self::Future {
match req { match req {
ConnectRequest::Tunnel(head, addr) => { ConnectRequest::Tunnel(head, addr, sni_host, config) => {
let fut = self.connector.call(ConnectRequest::Tunnel(head, addr)); let fut = self
.connector
.call(ConnectRequest::Tunnel(head, addr, sni_host, config));
RedirectServiceFuture::Tunnel { fut } RedirectServiceFuture::Tunnel { fut }
} }
ConnectRequest::Client(head, body, addr) => { ConnectRequest::Client(head, body, addr, sni_host, config) => {
let connector = Rc::clone(&self.connector); let connector = Rc::clone(&self.connector);
let max_redirect_times = self.max_redirect_times; let max_redirect_times = self.max_redirect_times;
@ -96,7 +98,8 @@ where
_ => None, _ => None,
}; };
let fut = connector.call(ConnectRequest::Client(head, body, addr)); let fut =
connector.call(ConnectRequest::Client(head, body, addr, sni_host, config));
RedirectServiceFuture::Client { RedirectServiceFuture::Client {
fut, fut,
@ -221,7 +224,8 @@ where
let fut = connector let fut = connector
.as_ref() .as_ref()
.unwrap() .unwrap()
.call(ConnectRequest::Client(head, body_new, addr)); // @TODO find a way to get sni host and config
.call(ConnectRequest::Client(head, body_new, addr, None, None));
self.set(RedirectServiceFuture::Client { self.set(RedirectServiceFuture::Client {
fut, fut,

View File

@ -14,7 +14,7 @@ use serde::Serialize;
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
use crate::cookie::{Cookie, CookieJar}; use crate::cookie::{Cookie, CookieJar};
use crate::{ use crate::{
client::ClientConfig, client::{ClientConfig, ConnectConfig, ServerName},
error::{FreezeRequestError, InvalidUrl}, error::{FreezeRequestError, InvalidUrl},
frozen::FrozenClientRequest, frozen::FrozenClientRequest,
sender::{PrepForSendingError, RequestSender, SendClientRequest}, sender::{PrepForSendingError, RequestSender, SendClientRequest},
@ -48,6 +48,8 @@ pub struct ClientRequest {
response_decompress: bool, response_decompress: bool,
timeout: Option<Duration>, timeout: Option<Duration>,
config: ClientConfig, config: ClientConfig,
sni_host: Option<ServerName>,
connect_config: Option<ConnectConfig>,
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
cookies: Option<CookieJar>, cookies: Option<CookieJar>,
@ -69,6 +71,8 @@ impl ClientRequest {
cookies: None, cookies: None,
timeout: None, timeout: None,
response_decompress: true, response_decompress: true,
sni_host: None,
connect_config: None,
} }
.method(method) .method(method)
.uri(uri) .uri(uri)
@ -279,6 +283,15 @@ impl ClientRequest {
self self
} }
/// Set specific connector configuration for this request.
///
/// Not all config may be applied to the request, it depends on the connector and also
/// if there is already a connection established.
pub fn connect_config(mut self, config: ConnectConfig) -> Self {
self.connect_config = Some(config);
self
}
/// Set request timeout. Overrides client wide timeout setting. /// Set request timeout. Overrides client wide timeout setting.
/// ///
/// Request timeout is the total time before a response must be received. /// Request timeout is the total time before a response must be received.
@ -306,6 +319,12 @@ impl ClientRequest {
Ok(self) Ok(self)
} }
/// Set SNI (Server Name Indication) host for this request.
pub fn sni_host(mut self, host: impl Into<String>) -> Self {
self.sni_host = Some(ServerName::Owned(host.into()));
self
}
/// Freeze request builder and construct `FrozenClientRequest`, /// Freeze request builder and construct `FrozenClientRequest`,
/// which could be used for sending same request multiple times. /// which could be used for sending same request multiple times.
pub fn freeze(self) -> Result<FrozenClientRequest, FreezeRequestError> { pub fn freeze(self) -> Result<FrozenClientRequest, FreezeRequestError> {
@ -320,6 +339,11 @@ impl ClientRequest {
response_decompress: slf.response_decompress, response_decompress: slf.response_decompress,
timeout: slf.timeout, timeout: slf.timeout,
config: slf.config, config: slf.config,
sni_host: slf.sni_host.map(|v| match v {
ServerName::Borrowed(r) => ServerName::Borrowed(r),
ServerName::Owned(o) => ServerName::Borrowed(Rc::new(o)),
}),
connect_config: slf.connect_config.map(Rc::new),
}; };
Ok(request) Ok(request)
@ -340,6 +364,8 @@ impl ClientRequest {
slf.response_decompress, slf.response_decompress,
slf.timeout, slf.timeout,
&slf.config, &slf.config,
slf.sni_host,
slf.connect_config.map(Rc::new),
body, body,
) )
} }
@ -356,6 +382,8 @@ impl ClientRequest {
slf.response_decompress, slf.response_decompress,
slf.timeout, slf.timeout,
&slf.config, &slf.config,
slf.sni_host,
slf.connect_config.map(Rc::new),
value, value,
) )
} }
@ -374,6 +402,8 @@ impl ClientRequest {
slf.response_decompress, slf.response_decompress,
slf.timeout, slf.timeout,
&slf.config, &slf.config,
slf.sni_host,
slf.connect_config.map(Rc::new),
value, value,
) )
} }
@ -394,6 +424,8 @@ impl ClientRequest {
slf.response_decompress, slf.response_decompress,
slf.timeout, slf.timeout,
&slf.config, &slf.config,
slf.sni_host,
slf.connect_config.map(Rc::new),
stream, stream,
) )
} }
@ -410,6 +442,8 @@ impl ClientRequest {
slf.response_decompress, slf.response_decompress,
slf.timeout, slf.timeout,
&slf.config, &slf.config,
slf.sni_host,
slf.connect_config.map(Rc::new),
) )
} }

View File

@ -23,7 +23,7 @@ use serde::Serialize;
use crate::{ use crate::{
any_body::AnyBody, any_body::AnyBody,
client::ClientConfig, client::{ClientConfig, ConnectConfig, ServerName},
error::{FreezeRequestError, InvalidUrl, SendRequestError}, error::{FreezeRequestError, InvalidUrl, SendRequestError},
BoxError, ClientResponse, ConnectRequest, ConnectResponse, BoxError, ClientResponse, ConnectRequest, ConnectResponse,
}; };
@ -186,6 +186,8 @@ impl RequestSender {
response_decompress: bool, response_decompress: bool,
timeout: Option<Duration>, timeout: Option<Duration>,
config: &ClientConfig, config: &ClientConfig,
sni_host: Option<ServerName>,
connect_config: Option<Rc<ConnectConfig>>,
body: impl MessageBody + 'static, body: impl MessageBody + 'static,
) -> SendClientRequest { ) -> SendClientRequest {
let req = match self { let req = match self {
@ -193,11 +195,15 @@ impl RequestSender {
RequestHeadType::Owned(head), RequestHeadType::Owned(head),
AnyBody::from_message_body(body).into_boxed(), AnyBody::from_message_body(body).into_boxed(),
addr, addr,
sni_host,
connect_config,
), ),
RequestSender::Rc(head, extra_headers) => ConnectRequest::Client( RequestSender::Rc(head, extra_headers) => ConnectRequest::Client(
RequestHeadType::Rc(head, extra_headers), RequestHeadType::Rc(head, extra_headers),
AnyBody::from_message_body(body).into_boxed(), AnyBody::from_message_body(body).into_boxed(),
addr, addr,
sni_host,
connect_config,
), ),
}; };
@ -212,6 +218,8 @@ impl RequestSender {
response_decompress: bool, response_decompress: bool,
timeout: Option<Duration>, timeout: Option<Duration>,
config: &ClientConfig, config: &ClientConfig,
sni_host: Option<ServerName>,
connector_config: Option<Rc<ConnectConfig>>,
value: impl Serialize, value: impl Serialize,
) -> SendClientRequest { ) -> SendClientRequest {
let body = match serde_json::to_string(&value) { let body = match serde_json::to_string(&value) {
@ -223,7 +231,15 @@ impl RequestSender {
return err.into(); return err.into();
} }
self.send_body(addr, response_decompress, timeout, config, body) self.send_body(
addr,
response_decompress,
timeout,
config,
sni_host,
connector_config,
body,
)
} }
pub(crate) fn send_form( pub(crate) fn send_form(
@ -232,6 +248,8 @@ impl RequestSender {
response_decompress: bool, response_decompress: bool,
timeout: Option<Duration>, timeout: Option<Duration>,
config: &ClientConfig, config: &ClientConfig,
sni_host: Option<ServerName>,
connector_config: Option<Rc<ConnectConfig>>,
value: impl Serialize, value: impl Serialize,
) -> SendClientRequest { ) -> SendClientRequest {
let body = match serde_urlencoded::to_string(value) { let body = match serde_urlencoded::to_string(value) {
@ -246,7 +264,15 @@ impl RequestSender {
return err.into(); return err.into();
} }
self.send_body(addr, response_decompress, timeout, config, body) self.send_body(
addr,
response_decompress,
timeout,
config,
sni_host,
connector_config,
body,
)
} }
pub(crate) fn send_stream<S, E>( pub(crate) fn send_stream<S, E>(
@ -255,6 +281,8 @@ impl RequestSender {
response_decompress: bool, response_decompress: bool,
timeout: Option<Duration>, timeout: Option<Duration>,
config: &ClientConfig, config: &ClientConfig,
sni_host: Option<ServerName>,
connector_config: Option<Rc<ConnectConfig>>,
stream: S, stream: S,
) -> SendClientRequest ) -> SendClientRequest
where where
@ -266,6 +294,8 @@ impl RequestSender {
response_decompress, response_decompress,
timeout, timeout,
config, config,
sni_host,
connector_config,
BodyStream::new(stream), BodyStream::new(stream),
) )
} }
@ -276,8 +306,18 @@ impl RequestSender {
response_decompress: bool, response_decompress: bool,
timeout: Option<Duration>, timeout: Option<Duration>,
config: &ClientConfig, config: &ClientConfig,
sni_host: Option<ServerName>,
connector_config: Option<Rc<ConnectConfig>>,
) -> SendClientRequest { ) -> SendClientRequest {
self.send_body(addr, response_decompress, timeout, config, ()) self.send_body(
addr,
response_decompress,
timeout,
config,
sni_host,
connector_config,
(),
)
} }
fn set_header_if_none<V>(&mut self, key: HeaderName, value: V) -> Result<(), HttpError> fn set_header_if_none<V>(&mut self, key: HeaderName, value: V) -> Result<(), HttpError>

View File

@ -26,7 +26,7 @@
//! } //! }
//! ``` //! ```
use std::{fmt, net::SocketAddr, str}; use std::{fmt, net::SocketAddr, rc::Rc, str};
use actix_codec::Framed; use actix_codec::Framed;
pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message}; pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message};
@ -38,7 +38,7 @@ use base64::prelude::*;
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
use crate::cookie::{Cookie, CookieJar}; use crate::cookie::{Cookie, CookieJar};
use crate::{ use crate::{
client::ClientConfig, client::{ClientConfig, ConnectConfig, ServerName},
connect::{BoxedSocket, ConnectRequest}, connect::{BoxedSocket, ConnectRequest},
error::{HttpError, InvalidUrl, SendRequestError, WsClientError}, error::{HttpError, InvalidUrl, SendRequestError, WsClientError},
http::{ http::{
@ -58,6 +58,8 @@ pub struct WebsocketsRequest {
max_size: usize, max_size: usize,
server_mode: bool, server_mode: bool,
config: ClientConfig, config: ClientConfig,
sni_host: Option<ServerName>,
connect_config: Option<ConnectConfig>,
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
cookies: Option<CookieJar>, cookies: Option<CookieJar>,
@ -96,6 +98,8 @@ impl WebsocketsRequest {
server_mode: false, server_mode: false,
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
cookies: None, cookies: None,
sni_host: None,
connect_config: None,
} }
} }
@ -108,6 +112,15 @@ impl WebsocketsRequest {
self self
} }
/// Set specific connector configuration for this request.
///
/// Not all config may be applied to the request, it depends on the connector and also
/// if there is already a connection established.
pub fn connector_config(mut self, config: ConnectConfig) -> Self {
self.connect_config = Some(config);
self
}
/// Set supported WebSocket protocols /// Set supported WebSocket protocols
pub fn protocols<U, V>(mut self, protos: U) -> Self pub fn protocols<U, V>(mut self, protos: U) -> Self
where where
@ -249,6 +262,12 @@ impl WebsocketsRequest {
self.header(AUTHORIZATION, format!("Bearer {}", token)) self.header(AUTHORIZATION, format!("Bearer {}", token))
} }
/// Set SNI (Server Name Indication) host for this request.
pub fn sni_host(mut self, host: impl Into<String>) -> Self {
self.sni_host = Some(ServerName::Owned(host.into()));
self
}
/// Complete request construction and connect to a WebSocket server. /// Complete request construction and connect to a WebSocket server.
pub async fn connect( pub async fn connect(
mut self, mut self,
@ -338,7 +357,12 @@ impl WebsocketsRequest {
let max_size = self.max_size; let max_size = self.max_size;
let server_mode = self.server_mode; let server_mode = self.server_mode;
let req = ConnectRequest::Tunnel(head, self.addr); let req = ConnectRequest::Tunnel(
head,
self.addr,
self.sni_host,
self.connect_config.map(Rc::new),
);
let fut = self.config.connector.call(req); let fut = self.config.connector.call(req);

View File

@ -43,6 +43,8 @@ fn tls_config() -> ServerConfig {
} }
mod danger { mod danger {
use std::collections::HashSet;
use rustls::{ use rustls::{
client::danger::{ServerCertVerified, ServerCertVerifier}, client::danger::{ServerCertVerified, ServerCertVerifier},
pki_types::UnixTime, pki_types::UnixTime,
@ -50,8 +52,10 @@ mod danger {
use super::*; use super::*;
#[derive(Debug)] #[derive(Debug, Default)]
pub struct NoCertificateVerification; pub struct NoCertificateVerification {
pub trusted_hosts: HashSet<String>,
}
impl ServerCertVerifier for NoCertificateVerification { impl ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert( fn verify_server_cert(
@ -62,7 +66,15 @@ mod danger {
_ocsp_response: &[u8], _ocsp_response: &[u8],
_now: UnixTime, _now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> { ) -> Result<ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion()) if self.trusted_hosts.is_empty() {
return Ok(ServerCertVerified::assertion());
}
if self.trusted_hosts.contains(_server_name.to_str().as_ref()) {
return Ok(ServerCertVerified::assertion());
}
Err(rustls::Error::General("untrusted host".into()))
} }
fn verify_tls12_signature( fn verify_tls12_signature(
@ -124,7 +136,7 @@ async fn test_connection_reuse_h2() {
// disable TLS verification // disable TLS verification
config config
.dangerous() .dangerous()
.set_certificate_verifier(Arc::new(danger::NoCertificateVerification)); .set_certificate_verifier(Arc::new(danger::NoCertificateVerification::default()));
let client = awc::Client::builder() let client = awc::Client::builder()
.connector(awc::Connector::new().rustls_0_23(Arc::new(config))) .connector(awc::Connector::new().rustls_0_23(Arc::new(config)))
@ -144,3 +156,84 @@ async fn test_connection_reuse_h2() {
// one connection // one connection
assert_eq!(num.load(Ordering::Relaxed), 1); assert_eq!(num.load(Ordering::Relaxed), 1);
} }
#[actix_rt::test]
async fn test_connection_with_sni() {
let srv = test_server(move || {
HttpService::build()
.h2(map_config(
App::new().service(web::resource("/").route(web::to(HttpResponse::Ok))),
|_| AppConfig::default(),
))
.rustls_0_23(tls_config())
.map_err(|_| ())
})
.await;
let mut config = ClientConfig::builder()
.with_root_certificates(webpki_roots_cert_store())
.with_no_client_auth();
let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
config.alpn_protocols = protos;
// disable TLS verification
config
.dangerous()
.set_certificate_verifier(Arc::new(danger::NoCertificateVerification {
trusted_hosts: ["localhost".to_owned()].iter().cloned().collect(),
}));
let client = awc::Client::builder()
.connector(awc::Connector::new().rustls_0_23(Arc::new(config)))
.finish();
// req : standard request
let request = client.get(srv.surl("/")).send();
let response = request.await.unwrap();
assert!(response.status().is_success());
// req : test specific host with address, return trusted host
let request = client.get(srv.surl("/")).sni_host("localhost").send();
let response = request.await.unwrap();
assert!(response.status().is_success());
// req : test bad host, return untrusted host
let request = client.get(srv.surl("/")).sni_host("bad.host").send();
let response = request.await;
assert!(response.is_err());
assert_eq!(
response.unwrap_err().to_string(),
"Failed to connect to host: unexpected error: untrusted host"
);
// req : test specific host with address, return untrusted host
let addr = srv.addr();
let request = client.get("https://example.com:443/").address(addr).send();
let response = request.await;
assert!(response.is_err());
assert_eq!(
response.unwrap_err().to_string(),
"Failed to connect to host: unexpected error: untrusted host"
);
// req : test specify sni_host with address and other host (authority)
let request = client
.get("https://example.com:443/")
.address(addr)
.sni_host("localhost")
.send();
let response = request.await.unwrap();
assert!(response.status().is_success());
// req : test ip address with sni host
let request = client
.get("https://127.0.0.1:443/")
.address(addr)
.sni_host("localhost")
.send();
let response = request.await.unwrap();
assert!(response.status().is_success());
}