From 0915879267b4a18afef5a5afdef12113b5d4e567 Mon Sep 17 00:00:00 2001 From: Joel Wurtz Date: Mon, 9 Dec 2024 11:16:54 +0100 Subject: [PATCH 1/2] feat(awc): allow to set a specific sni host on the request --- awc/CHANGES.md | 1 + awc/src/builder.rs | 22 ++++-- awc/src/client/connector.rs | 126 +++++++++++++++++++++++--------- awc/src/client/mod.rs | 32 ++++++-- awc/src/client/pool.rs | 62 ++++++++++------ awc/src/connect.rs | 67 +++++++++++++---- awc/src/frozen.rs | 13 +++- awc/src/middleware/redirect.rs | 13 ++-- awc/src/request.rs | 19 ++++- awc/src/sender.rs | 16 +++- awc/src/ws.rs | 12 ++- awc/tests/test_rustls_client.rs | 101 ++++++++++++++++++++++++- 12 files changed, 382 insertions(+), 102 deletions(-) diff --git a/awc/CHANGES.md b/awc/CHANGES.md index 8a2a1ec4..2a1b4462 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -5,6 +5,7 @@ - Update `brotli` dependency to `7`. - Prevent panics on connection pool drop when Tokio runtime is shutdown early. - Minimum supported Rust version (MSRV) is now 1.75. +- Allow to set a specific SNI hostname on the request for TLS connections. ## 3.5.1 diff --git a/awc/src/builder.rs b/awc/src/builder.rs index 5aae394f..0dfcd547 100644 --- a/awc/src/builder.rs +++ b/awc/src/builder.rs @@ -3,7 +3,6 @@ use std::{fmt, net::IpAddr, rc::Rc, time::Duration}; use actix_http::{ error::HttpError, header::{self, HeaderMap, HeaderName, TryIntoHeaderPair}, - Uri, }; use actix_rt::net::{ActixStream, TcpStream}; use actix_service::{boxed, Service}; @@ -11,7 +10,8 @@ use base64::prelude::*; use crate::{ client::{ - ClientConfig, ConnectInfo, Connector, ConnectorService, TcpConnectError, TcpConnection, + ClientConfig, ConnectInfo, Connector, ConnectorService, HostnameWithSni, TcpConnectError, + TcpConnection, }, connect::DefaultConnector, error::SendRequestError, @@ -46,8 +46,8 @@ impl ClientBuilder { #[allow(clippy::new_ret_no_self)] pub fn new() -> ClientBuilder< impl Service< - ConnectInfo, - Response = TcpConnection, + ConnectInfo, + Response = TcpConnection, Error = TcpConnectError, > + Clone, (), @@ -69,16 +69,22 @@ impl ClientBuilder { impl ClientBuilder where - S: Service, Response = TcpConnection, Error = TcpConnectError> - + Clone + S: Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone + 'static, Io: ActixStream + fmt::Debug + 'static, { /// Use custom connector service. pub fn connector(self, connector: Connector) -> ClientBuilder where - S1: Service, Response = TcpConnection, Error = TcpConnectError> - + Clone + S1: Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone + 'static, Io1: ActixStream + fmt::Debug + 'static, { diff --git a/awc/src/client/connector.rs b/awc/src/client/connector.rs index f3d44307..2e3f977f 100644 --- a/awc/src/client/connector.rs +++ b/awc/src/client/connector.rs @@ -16,10 +16,9 @@ use actix_rt::{ use actix_service::Service; use actix_tls::connect::{ ConnectError as TcpConnectError, ConnectInfo, Connection as TcpConnection, - Connector as TcpConnector, Resolver, + Connector as TcpConnector, Host, Resolver, }; use futures_core::{future::LocalBoxFuture, ready}; -use http::Uri; use pin_project_lite::pin_project; use super::{ @@ -27,9 +26,41 @@ use super::{ connection::{Connection, ConnectionIo}, error::ConnectError, pool::ConnectionPool, - Connect, + Connect, ServerName, }; +pub enum HostnameWithSni { + ForTcp(String, u16, Option), + ForTls(String, u16, Option), +} + +impl Host for HostnameWithSni { + fn hostname(&self) -> &str { + match self { + HostnameWithSni::ForTcp(hostname, _, _) => hostname, + HostnameWithSni::ForTls(hostname, _, sni) => sni.as_deref().unwrap_or(hostname), + } + } + + fn port(&self) -> Option { + match self { + HostnameWithSni::ForTcp(_, port, _) => Some(*port), + HostnameWithSni::ForTls(_, port, _) => Some(*port), + } + } +} + +impl HostnameWithSni { + pub fn to_tls(self) -> Self { + match self { + HostnameWithSni::ForTcp(hostname, port, sni) => { + HostnameWithSni::ForTls(hostname, port, sni) + } + HostnameWithSni::ForTls(_, _, _) => self, + } + } +} + enum OurTlsConnector { #[allow(dead_code)] // only dead when no TLS feature is enabled None, @@ -95,8 +126,8 @@ impl Connector<()> { #[allow(clippy::new_ret_no_self, clippy::let_unit_value)] pub fn new() -> Connector< impl Service< - ConnectInfo, - Response = TcpConnection, + ConnectInfo, + Response = TcpConnection, Error = actix_tls::connect::ConnectError, > + Clone, > { @@ -214,8 +245,11 @@ impl Connector { pub fn connector(self, connector: S1) -> Connector where Io1: ActixStream + fmt::Debug + 'static, - S1: Service, Response = TcpConnection, Error = TcpConnectError> - + Clone, + S1: Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone, { Connector { connector, @@ -235,8 +269,11 @@ where // This remap is to hide ActixStream's trait methods. They are not meant to be called // from user code. IO: ActixStream + fmt::Debug + 'static, - S: Service, Response = TcpConnection, Error = TcpConnectError> - + Clone + S: Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone + 'static, { /// Sets TCP connection timeout. @@ -454,7 +491,7 @@ where use actix_utils::future::{ready, Ready}; #[allow(non_local_definitions)] - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let io = self.into_parts().0; (io, Protocol::Http2) @@ -505,7 +542,7 @@ where use actix_tls::connect::openssl::{reexports::AsyncSslStream, TlsConnector}; #[allow(non_local_definitions)] - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let sock = self.into_parts().0; let h2 = sock @@ -543,7 +580,7 @@ where use actix_tls::connect::rustls_0_20::{reexports::AsyncTlsStream, TlsConnector}; #[allow(non_local_definitions)] - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let sock = self.into_parts().0; let h2 = sock @@ -577,7 +614,7 @@ where use actix_tls::connect::rustls_0_21::{reexports::AsyncTlsStream, TlsConnector}; #[allow(non_local_definitions)] - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let sock = self.into_parts().0; let h2 = sock @@ -614,7 +651,7 @@ where use actix_tls::connect::rustls_0_22::{reexports::AsyncTlsStream, TlsConnector}; #[allow(non_local_definitions)] - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let sock = self.into_parts().0; let h2 = sock @@ -648,7 +685,7 @@ where use actix_tls::connect::rustls_0_23::{reexports::AsyncTlsStream, TlsConnector}; #[allow(non_local_definitions)] - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let sock = self.into_parts().0; let h2 = sock @@ -688,7 +725,7 @@ where } } -/// tcp service for map `TcpConnection` type to `(Io, Protocol)` +/// tcp service for map `TcpConnection` type to `(Io, Protocol)` #[derive(Clone)] pub struct TcpConnectorService { service: S, @@ -696,7 +733,9 @@ pub struct TcpConnectorService { impl Service for TcpConnectorService where - S: Service, Error = ConnectError> + Clone + 'static, + S: Service, Error = ConnectError> + + Clone + + 'static, { type Response = (Io, Protocol); type Error = ConnectError; @@ -721,7 +760,7 @@ pin_project! { impl Future for TcpConnectorFuture where - Fut: Future, ConnectError>>, + Fut: Future, ConnectError>>, { type Output = Result<(Io, Protocol), ConnectError>; @@ -767,9 +806,10 @@ struct TlsConnectorService { ))] impl Service for TlsConnectorService where - Tcp: - Service, Error = ConnectError> + Clone + 'static, - Tls: Service, Error = std::io::Error> + Clone + 'static, + Tcp: Service, Error = ConnectError> + + Clone + + 'static, + Tls: Service, Error = std::io::Error> + Clone + 'static, Tls::Response: IntoConnectionIo, IO: ConnectionIo, { @@ -822,9 +862,14 @@ trait IntoConnectionIo { impl Future for TlsConnectorFuture where - S: Service, Response = Res, Error = std::io::Error, Future = Fut2>, + S: Service< + TcpConnection, + Response = Res, + Error = std::io::Error, + Future = Fut2, + >, S::Response: IntoConnectionIo, - Fut1: Future, ConnectError>>, + Fut1: Future, ConnectError>>, Fut2: Future>, Io: ConnectionIo, { @@ -838,10 +883,11 @@ where timeout, } => { let res = ready!(fut.poll(cx))?; + let (io, hostname_with_sni) = res.into_parts(); let fut = tls_service .take() .expect("TlsConnectorFuture polled after complete") - .call(res); + .call(TcpConnection::new(hostname_with_sni.to_tls(), io)); let timeout = sleep(*timeout); self.set(TlsConnectorFuture::TlsConnect { fut, timeout }); self.poll(cx) @@ -875,8 +921,11 @@ impl TcpConnectorInnerService { impl Service for TcpConnectorInnerService where - S: Service, Response = TcpConnection, Error = TcpConnectError> - + Clone + S: Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone + 'static, { type Response = S::Response; @@ -886,7 +935,13 @@ where actix_service::forward_ready!(service); fn call(&self, req: Connect) -> Self::Future { - let mut req = ConnectInfo::new(req.uri).set_addr(req.addr); + let mut req = ConnectInfo::new(HostnameWithSni::ForTcp( + req.hostname, + req.port, + req.sni_host, + )) + .set_addr(req.addr) + .set_port(req.port); if let Some(local_addr) = self.local_address { req = req.set_local_addr(local_addr); @@ -911,9 +966,9 @@ pin_project! { impl Future for TcpConnectorInnerFuture where - Fut: Future, TcpConnectError>>, + Fut: Future, TcpConnectError>>, { - type Output = Result, ConnectError>; + type Output = Result, ConnectError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -973,16 +1028,17 @@ where } fn call(&self, req: Connect) -> Self::Future { - match req.uri.scheme_str() { - Some("https") | Some("wss") => match self.tls_pool { + if req.tls { + match &self.tls_pool { None => ConnectorServiceFuture::SslIsNotSupported, - Some(ref pool) => ConnectorServiceFuture::Tls { + Some(pool) => ConnectorServiceFuture::Tls { fut: pool.call(req), }, - }, - _ => ConnectorServiceFuture::Tcp { + } + } else { + ConnectorServiceFuture::Tcp { fut: self.tcp_pool.call(req), - }, + } } } } diff --git a/awc/src/client/mod.rs b/awc/src/client/mod.rs index c9fa3725..5ac8650d 100644 --- a/awc/src/client/mod.rs +++ b/awc/src/client/mod.rs @@ -1,6 +1,6 @@ //! HTTP client. -use std::{rc::Rc, time::Duration}; +use std::{ops::Deref, rc::Rc, time::Duration}; use actix_http::{error::HttpError, header::HeaderMap, Method, RequestHead, Uri}; use actix_rt::net::TcpStream; @@ -21,13 +21,33 @@ mod pool; pub use self::{ connection::{Connection, ConnectionIo}, - connector::{Connector, ConnectorService}, + connector::{Connector, ConnectorService, HostnameWithSni}, error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}, }; -#[derive(Clone)] +#[derive(Clone, Hash, PartialEq, Eq)] +pub enum ServerName { + Owned(String), + Borrowed(Rc), +} + +impl Deref for ServerName { + type Target = str; + + fn deref(&self) -> &str { + match self { + ServerName::Owned(ref s) => s, + ServerName::Borrowed(ref s) => s, + } + } +} + +#[derive(Clone, Hash, PartialEq, Eq)] pub struct Connect { - pub uri: Uri, + pub hostname: String, + pub sni_host: Option, + pub port: u16, + pub tls: bool, pub addr: Option, } @@ -79,8 +99,8 @@ impl Client { /// This function is equivalent of `ClientBuilder::new()`. pub fn builder() -> ClientBuilder< impl Service< - ConnectInfo, - Response = TcpConnection, + ConnectInfo, + Response = TcpConnection, Error = TcpConnectError, > + Clone, > { diff --git a/awc/src/client/pool.rs b/awc/src/client/pool.rs index 5d764f72..9b1058d8 100644 --- a/awc/src/client/pool.rs +++ b/awc/src/client/pool.rs @@ -4,6 +4,7 @@ use std::{ cell::RefCell, collections::{HashMap, VecDeque}, future::Future, + hash::Hash, io, ops::Deref, pin::Pin, @@ -127,7 +128,7 @@ where Io: AsyncWrite + Unpin + 'static, { config: ConnectorConfig, - available: RefCell>>>, + available: RefCell>>>, permits: Arc, } @@ -168,12 +169,6 @@ where let inner = self.inner.clone(); Box::pin(async move { - let key = if let Some(authority) = req.uri.authority() { - authority.clone().into() - } else { - return Err(ConnectError::Unresolved); - }; - // acquire an owned permit and carry it with connection let permit = Arc::clone(&inner.permits) .acquire_owned() @@ -191,7 +186,7 @@ where // check if there is idle connection for given key. let mut map = inner.available.borrow_mut(); - if let Some(conns) = map.get_mut(&key) { + if let Some(conns) = map.get_mut(&req) { let now = Instant::now(); while let Some(mut c) = conns.pop_front() { @@ -232,7 +227,11 @@ where // construct acquired. It's used to put Io type back to pool/ close the Io type. // permit is carried with the whole lifecycle of Acquired. - let acquired = Acquired { key, inner, permit }; + let acquired = Acquired { + req: req.clone(), + inner, + permit, + }; // match the connection and spawn new one if did not get anything. match conn { @@ -344,8 +343,8 @@ pub struct Acquired where Io: AsyncWrite + Unpin + 'static, { - /// authority key for identify connection. - key: Key, + /// hash key for identify connection. + req: Connect, /// handle to connection pool. inner: ConnectionPoolInner, /// permit for limit concurrent in-flight connection for a Client object. @@ -360,12 +359,12 @@ impl Acquired { /// Release IO back into pool. pub(super) fn release(&self, conn: ConnectionInnerType, created: Instant) { - let Acquired { key, inner, .. } = self; + let Acquired { req, inner, .. } = self; inner .available .borrow_mut() - .entry(key.clone()) + .entry(req.clone()) .or_insert_with(VecDeque::new) .push_back(PooledConnection { conn, @@ -381,8 +380,6 @@ impl Acquired { mod test { use std::cell::Cell; - use http::Uri; - use super::*; /// A stream type that always returns pending on async read. @@ -467,7 +464,10 @@ mod test { let pool = super::ConnectionPool::new(connector, config); let req = Connect { - uri: Uri::from_static("http://localhost"), + hostname: "localhost".to_string(), + port: 80, + tls: false, + sni_host: None, addr: None, }; @@ -507,7 +507,10 @@ mod test { let pool = super::ConnectionPool::new(connector, config); let req = Connect { - uri: Uri::from_static("http://localhost"), + hostname: "localhost".to_string(), + port: 80, + tls: false, + sni_host: None, addr: None, }; @@ -549,7 +552,10 @@ mod test { let pool = super::ConnectionPool::new(connector, config); let req = Connect { - uri: Uri::from_static("http://localhost"), + hostname: "localhost".to_string(), + port: 80, + tls: false, + sni_host: None, addr: None, }; @@ -588,7 +594,10 @@ mod test { let pool = super::ConnectionPool::new(connector, config); let req = Connect { - uri: Uri::from_static("https://crates.io"), + hostname: "crates.io".to_string(), + port: 443, + tls: true, + sni_host: None, addr: None, }; @@ -601,7 +610,10 @@ mod test { release(conn); let req = Connect { - uri: Uri::from_static("https://google.com"), + hostname: "google.com".to_string(), + port: 443, + tls: true, + sni_host: None, addr: None, }; @@ -625,7 +637,10 @@ mod test { let pool = Rc::new(super::ConnectionPool::new(connector, config)); let req = Connect { - uri: Uri::from_static("https://crates.io"), + hostname: "crates.io".to_string(), + port: 443, + tls: true, + sni_host: None, addr: None, }; @@ -634,7 +649,10 @@ mod test { release(conn); let req = Connect { - uri: Uri::from_static("https://google.com"), + hostname: "google.com".to_string(), + port: 443, + tls: true, + sni_host: None, addr: None, }; let conn = pool.call(req.clone()).await.unwrap(); diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 14ed9e95..a7bbd7b2 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -13,7 +13,10 @@ use futures_core::{future::LocalBoxFuture, ready}; use crate::{ any_body::AnyBody, - client::{Connect as ClientConnect, ConnectError, Connection, ConnectionIo, SendRequestError}, + client::{ + Connect as ClientConnect, ConnectError, Connection, ConnectionIo, SendRequestError, + ServerName, + }, ClientResponse, }; @@ -32,13 +35,18 @@ pub type BoxedSocket = Box; pub enum ConnectRequest { /// Standard HTTP request. /// - /// Contains the request head, body type, and optional pre-resolved socket address. - Client(RequestHeadType, AnyBody, Option), + /// Contains the request head, body type, optional pre-resolved socket address and optional sni host. + Client( + RequestHeadType, + AnyBody, + Option, + Option, + ), /// Tunnel used by WebSocket connection requests. /// - /// Contains the request head and optional pre-resolved socket address. - Tunnel(RequestHead, Option), + /// Contains the request head, optional pre-resolved socket address and optional sni host. + Tunnel(RequestHead, Option, Option), } /// Combined HTTP response & WebSocket tunnel type returned from connection service. @@ -103,17 +111,41 @@ where fn call(&self, req: ConnectRequest) -> Self::Future { // connect to the host - let fut = match req { - ConnectRequest::Client(ref head, .., addr) => self.connector.call(ClientConnect { - uri: head.as_ref().uri.clone(), - addr, - }), - ConnectRequest::Tunnel(ref head, addr) => self.connector.call(ClientConnect { - uri: head.uri.clone(), - addr, - }), + let (head, addr, sni_host) = match req { + ConnectRequest::Client(ref head, .., addr, ref sni_host) => { + (head.as_ref(), addr, sni_host.clone()) + } + ConnectRequest::Tunnel(ref head, addr, ref sni_host) => (head, addr, sni_host.clone()), }; + let authority = if let Some(authority) = head.uri.authority() { + authority + } else { + return ConnectRequestFuture::Error { + err: ConnectError::Unresolved, + }; + }; + + let tls = match head.uri.scheme_str() { + Some("https") | Some("wss") => true, + _ => false, + }; + + let fut = + self.connector.call(ClientConnect { + hostname: authority.host().to_string(), + port: authority.port().map(|p| p.as_u16()).unwrap_or_else(|| { + if tls { + 443 + } else { + 80 + } + }), + tls, + sni_host, + addr, + }); + ConnectRequestFuture::Connection { fut, req: Some(req), @@ -127,6 +159,9 @@ pin_project_lite::pin_project! { where Io: ConnectionIo { + Error { + err: ConnectError + }, Connection { #[pin] fut: Fut, @@ -192,6 +227,10 @@ where let framed = framed.into_map_io(|io| Box::new(io) as _); Poll::Ready(Ok(ConnectResponse::Tunnel(head, framed))) } + + ConnectRequestProj::Error { .. } => { + Poll::Ready(Err(SendRequestError::Connect(ConnectError::Unresolved))) + } } } } diff --git a/awc/src/frozen.rs b/awc/src/frozen.rs index 86240523..d622f8ec 100644 --- a/awc/src/frozen.rs +++ b/awc/src/frozen.rs @@ -11,7 +11,7 @@ use futures_core::Stream; use serde::Serialize; use crate::{ - client::ClientConfig, + client::{ClientConfig, ServerName}, sender::{RequestSender, SendClientRequest}, BoxError, }; @@ -26,6 +26,7 @@ pub struct FrozenClientRequest { pub(crate) response_decompress: bool, pub(crate) timeout: Option, pub(crate) config: ClientConfig, + pub(crate) sni_host: Option, } impl FrozenClientRequest { @@ -54,6 +55,7 @@ impl FrozenClientRequest { self.response_decompress, self.timeout, &self.config, + self.sni_host.clone(), body, ) } @@ -65,6 +67,7 @@ impl FrozenClientRequest { self.response_decompress, self.timeout, &self.config, + self.sni_host.clone(), value, ) } @@ -76,6 +79,7 @@ impl FrozenClientRequest { self.response_decompress, self.timeout, &self.config, + self.sni_host.clone(), value, ) } @@ -91,6 +95,7 @@ impl FrozenClientRequest { self.response_decompress, self.timeout, &self.config, + self.sni_host.clone(), stream, ) } @@ -102,6 +107,7 @@ impl FrozenClientRequest { self.response_decompress, self.timeout, &self.config, + self.sni_host.clone(), ) } @@ -156,6 +162,7 @@ impl FrozenSendBuilder { self.req.response_decompress, self.req.timeout, &self.req.config, + self.req.sni_host.clone(), body, ) } @@ -171,6 +178,7 @@ impl FrozenSendBuilder { self.req.response_decompress, self.req.timeout, &self.req.config, + self.req.sni_host.clone(), value, ) } @@ -186,6 +194,7 @@ impl FrozenSendBuilder { self.req.response_decompress, self.req.timeout, &self.req.config, + self.req.sni_host.clone(), value, ) } @@ -205,6 +214,7 @@ impl FrozenSendBuilder { self.req.response_decompress, self.req.timeout, &self.req.config, + self.req.sni_host.clone(), stream, ) } @@ -220,6 +230,7 @@ impl FrozenSendBuilder { self.req.response_decompress, self.req.timeout, &self.req.config, + self.req.sni_host.clone(), ) } } diff --git a/awc/src/middleware/redirect.rs b/awc/src/middleware/redirect.rs index b2cf9c45..81f4d799 100644 --- a/awc/src/middleware/redirect.rs +++ b/awc/src/middleware/redirect.rs @@ -73,11 +73,13 @@ where fn call(&self, req: ConnectRequest) -> Self::Future { match req { - ConnectRequest::Tunnel(head, addr) => { - let fut = self.connector.call(ConnectRequest::Tunnel(head, addr)); + ConnectRequest::Tunnel(head, addr, sni_host) => { + let fut = self + .connector + .call(ConnectRequest::Tunnel(head, addr, sni_host)); RedirectServiceFuture::Tunnel { fut } } - ConnectRequest::Client(head, body, addr) => { + ConnectRequest::Client(head, body, addr, sni_host) => { let connector = Rc::clone(&self.connector); let max_redirect_times = self.max_redirect_times; @@ -96,7 +98,7 @@ where _ => None, }; - let fut = connector.call(ConnectRequest::Client(head, body, addr)); + let fut = connector.call(ConnectRequest::Client(head, body, addr, sni_host)); RedirectServiceFuture::Client { fut, @@ -221,7 +223,8 @@ where let fut = connector .as_ref() .unwrap() - .call(ConnectRequest::Client(head, body_new, addr)); + // @TODO find a way to get sni host + .call(ConnectRequest::Client(head, body_new, addr, None)); self.set(RedirectServiceFuture::Client { fut, diff --git a/awc/src/request.rs b/awc/src/request.rs index 5f42f67e..b0f995a6 100644 --- a/awc/src/request.rs +++ b/awc/src/request.rs @@ -14,7 +14,7 @@ use serde::Serialize; #[cfg(feature = "cookies")] use crate::cookie::{Cookie, CookieJar}; use crate::{ - client::ClientConfig, + client::{ClientConfig, ServerName}, error::{FreezeRequestError, InvalidUrl}, frozen::FrozenClientRequest, sender::{PrepForSendingError, RequestSender, SendClientRequest}, @@ -48,6 +48,7 @@ pub struct ClientRequest { response_decompress: bool, timeout: Option, config: ClientConfig, + sni_host: Option, #[cfg(feature = "cookies")] cookies: Option, @@ -69,6 +70,7 @@ impl ClientRequest { cookies: None, timeout: None, response_decompress: true, + sni_host: None, } .method(method) .uri(uri) @@ -306,6 +308,12 @@ impl ClientRequest { Ok(self) } + /// Set SNI (Server Name Indication) host for this request. + pub fn sni_host(mut self, host: impl Into) -> Self { + self.sni_host = Some(ServerName::Owned(host.into())); + self + } + /// Freeze request builder and construct `FrozenClientRequest`, /// which could be used for sending same request multiple times. pub fn freeze(self) -> Result { @@ -320,6 +328,10 @@ impl ClientRequest { response_decompress: slf.response_decompress, timeout: slf.timeout, config: slf.config, + sni_host: slf.sni_host.map(|v| match v { + ServerName::Borrowed(r) => ServerName::Borrowed(r), + ServerName::Owned(o) => ServerName::Borrowed(Rc::new(o)), + }), }; Ok(request) @@ -340,6 +352,7 @@ impl ClientRequest { slf.response_decompress, slf.timeout, &slf.config, + slf.sni_host, body, ) } @@ -356,6 +369,7 @@ impl ClientRequest { slf.response_decompress, slf.timeout, &slf.config, + slf.sni_host, value, ) } @@ -374,6 +388,7 @@ impl ClientRequest { slf.response_decompress, slf.timeout, &slf.config, + slf.sni_host, value, ) } @@ -394,6 +409,7 @@ impl ClientRequest { slf.response_decompress, slf.timeout, &slf.config, + slf.sni_host, stream, ) } @@ -410,6 +426,7 @@ impl ClientRequest { slf.response_decompress, slf.timeout, &slf.config, + slf.sni_host, ) } diff --git a/awc/src/sender.rs b/awc/src/sender.rs index b676ebf2..ab3ca596 100644 --- a/awc/src/sender.rs +++ b/awc/src/sender.rs @@ -23,7 +23,7 @@ use serde::Serialize; use crate::{ any_body::AnyBody, - client::ClientConfig, + client::{ClientConfig, ServerName}, error::{FreezeRequestError, InvalidUrl, SendRequestError}, BoxError, ClientResponse, ConnectRequest, ConnectResponse, }; @@ -186,6 +186,7 @@ impl RequestSender { response_decompress: bool, timeout: Option, config: &ClientConfig, + sni_host: Option, body: impl MessageBody + 'static, ) -> SendClientRequest { let req = match self { @@ -193,11 +194,13 @@ impl RequestSender { RequestHeadType::Owned(head), AnyBody::from_message_body(body).into_boxed(), addr, + sni_host, ), RequestSender::Rc(head, extra_headers) => ConnectRequest::Client( RequestHeadType::Rc(head, extra_headers), AnyBody::from_message_body(body).into_boxed(), addr, + sni_host, ), }; @@ -212,6 +215,7 @@ impl RequestSender { response_decompress: bool, timeout: Option, config: &ClientConfig, + sni_host: Option, value: impl Serialize, ) -> SendClientRequest { let body = match serde_json::to_string(&value) { @@ -223,7 +227,7 @@ impl RequestSender { return err.into(); } - self.send_body(addr, response_decompress, timeout, config, body) + self.send_body(addr, response_decompress, timeout, config, sni_host, body) } pub(crate) fn send_form( @@ -232,6 +236,7 @@ impl RequestSender { response_decompress: bool, timeout: Option, config: &ClientConfig, + sni_host: Option, value: impl Serialize, ) -> SendClientRequest { let body = match serde_urlencoded::to_string(value) { @@ -246,7 +251,7 @@ impl RequestSender { return err.into(); } - self.send_body(addr, response_decompress, timeout, config, body) + self.send_body(addr, response_decompress, timeout, config, sni_host, body) } pub(crate) fn send_stream( @@ -255,6 +260,7 @@ impl RequestSender { response_decompress: bool, timeout: Option, config: &ClientConfig, + sni_host: Option, stream: S, ) -> SendClientRequest where @@ -266,6 +272,7 @@ impl RequestSender { response_decompress, timeout, config, + sni_host, BodyStream::new(stream), ) } @@ -276,8 +283,9 @@ impl RequestSender { response_decompress: bool, timeout: Option, config: &ClientConfig, + sni_host: Option, ) -> SendClientRequest { - self.send_body(addr, response_decompress, timeout, config, ()) + self.send_body(addr, response_decompress, timeout, config, sni_host, ()) } fn set_header_if_none(&mut self, key: HeaderName, value: V) -> Result<(), HttpError> diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 760331e9..ef5cb715 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -38,7 +38,7 @@ use base64::prelude::*; #[cfg(feature = "cookies")] use crate::cookie::{Cookie, CookieJar}; use crate::{ - client::ClientConfig, + client::{ClientConfig, ServerName}, connect::{BoxedSocket, ConnectRequest}, error::{HttpError, InvalidUrl, SendRequestError, WsClientError}, http::{ @@ -58,6 +58,7 @@ pub struct WebsocketsRequest { max_size: usize, server_mode: bool, config: ClientConfig, + sni_host: Option, #[cfg(feature = "cookies")] cookies: Option, @@ -96,6 +97,7 @@ impl WebsocketsRequest { server_mode: false, #[cfg(feature = "cookies")] cookies: None, + sni_host: None, } } @@ -249,6 +251,12 @@ impl WebsocketsRequest { self.header(AUTHORIZATION, format!("Bearer {}", token)) } + /// Set SNI (Server Name Indication) host for this request. + pub fn sni_host(mut self, host: impl Into) -> Self { + self.sni_host = Some(ServerName::Owned(host.into())); + self + } + /// Complete request construction and connect to a WebSocket server. pub async fn connect( mut self, @@ -338,7 +346,7 @@ impl WebsocketsRequest { let max_size = self.max_size; let server_mode = self.server_mode; - let req = ConnectRequest::Tunnel(head, self.addr); + let req = ConnectRequest::Tunnel(head, self.addr, self.sni_host); let fut = self.config.connector.call(req); diff --git a/awc/tests/test_rustls_client.rs b/awc/tests/test_rustls_client.rs index 7e832f67..afe21f9e 100644 --- a/awc/tests/test_rustls_client.rs +++ b/awc/tests/test_rustls_client.rs @@ -43,6 +43,8 @@ fn tls_config() -> ServerConfig { } mod danger { + use std::collections::HashSet; + use rustls::{ client::danger::{ServerCertVerified, ServerCertVerifier}, pki_types::UnixTime, @@ -50,8 +52,10 @@ mod danger { use super::*; - #[derive(Debug)] - pub struct NoCertificateVerification; + #[derive(Debug, Default)] + pub struct NoCertificateVerification { + pub trusted_hosts: HashSet, + } impl ServerCertVerifier for NoCertificateVerification { fn verify_server_cert( @@ -62,7 +66,15 @@ mod danger { _ocsp_response: &[u8], _now: UnixTime, ) -> Result { - Ok(rustls::client::danger::ServerCertVerified::assertion()) + if self.trusted_hosts.is_empty() { + return Ok(ServerCertVerified::assertion()); + } + + if self.trusted_hosts.contains(_server_name.to_str().as_ref()) { + return Ok(ServerCertVerified::assertion()); + } + + Err(rustls::Error::General("untrusted host".into())) } fn verify_tls12_signature( @@ -124,7 +136,7 @@ async fn test_connection_reuse_h2() { // disable TLS verification config .dangerous() - .set_certificate_verifier(Arc::new(danger::NoCertificateVerification)); + .set_certificate_verifier(Arc::new(danger::NoCertificateVerification::default())); let client = awc::Client::builder() .connector(awc::Connector::new().rustls_0_23(Arc::new(config))) @@ -144,3 +156,84 @@ async fn test_connection_reuse_h2() { // one connection assert_eq!(num.load(Ordering::Relaxed), 1); } + +#[actix_rt::test] +async fn test_connection_with_sni() { + let srv = test_server(move || { + HttpService::build() + .h2(map_config( + App::new().service(web::resource("/").route(web::to(HttpResponse::Ok))), + |_| AppConfig::default(), + )) + .rustls_0_23(tls_config()) + .map_err(|_| ()) + }) + .await; + + let mut config = ClientConfig::builder() + .with_root_certificates(webpki_roots_cert_store()) + .with_no_client_auth(); + + let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + config.alpn_protocols = protos; + + // disable TLS verification + config + .dangerous() + .set_certificate_verifier(Arc::new(danger::NoCertificateVerification { + trusted_hosts: ["localhost".to_owned()].iter().cloned().collect(), + })); + + let client = awc::Client::builder() + .connector(awc::Connector::new().rustls_0_23(Arc::new(config))) + .finish(); + + // req : standard request + let request = client.get(srv.surl("/")).send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req : test specific host with address, return trusted host + let request = client.get(srv.surl("/")).sni_host("localhost").send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req : test bad host, return untrusted host + let request = client.get(srv.surl("/")).sni_host("bad.host").send(); + let response = request.await; + + assert!(response.is_err()); + assert_eq!( + response.unwrap_err().to_string(), + "Failed to connect to host: unexpected error: untrusted host" + ); + + // req : test specific host with address, return untrusted host + let addr = srv.addr(); + let request = client.get("https://example.com:443/").address(addr).send(); + let response = request.await; + + assert!(response.is_err()); + assert_eq!( + response.unwrap_err().to_string(), + "Failed to connect to host: unexpected error: untrusted host" + ); + + // req : test specify sni_host with address and other host (authority) + let request = client + .get("https://example.com:443/") + .address(addr) + .sni_host("localhost") + .send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req : test ip address with sni host + let request = client + .get("https://127.0.0.1:443/") + .address(addr) + .sni_host("localhost") + .send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); +} From 610dd616ef3206e781a04d33ed6ed2e2ed4bd50f Mon Sep 17 00:00:00 2001 From: Joel Wurtz Date: Fri, 6 Dec 2024 09:43:26 +0100 Subject: [PATCH 2/2] feat(awc): split connector config with connect config, allow to configure connect config per request --- awc/src/client/config.rs | 100 ++++++++++++++++++++++++++++++--- awc/src/client/connector.rs | 39 +++++++------ awc/src/client/h2proto.rs | 8 +-- awc/src/client/mod.rs | 2 + awc/src/client/pool.rs | 39 +++++++++++-- awc/src/connect.rs | 23 +++++--- awc/src/frozen.rs | 13 ++++- awc/src/middleware/redirect.rs | 13 +++-- awc/src/request.rs | 19 ++++++- awc/src/sender.rs | 40 +++++++++++-- awc/src/ws.rs | 22 +++++++- 11 files changed, 261 insertions(+), 57 deletions(-) diff --git a/awc/src/client/config.rs b/awc/src/client/config.rs index 530c1e03..bd3da234 100644 --- a/awc/src/client/config.rs +++ b/awc/src/client/config.rs @@ -3,29 +3,33 @@ use std::{net::IpAddr, time::Duration}; const DEFAULT_H2_CONN_WINDOW: u32 = 1024 * 1024 * 2; // 2MB const DEFAULT_H2_STREAM_WINDOW: u32 = 1024 * 1024; // 1MB -/// Connector configuration -#[derive(Clone)] -pub(crate) struct ConnectorConfig { +/// Connect configuration +#[derive(Clone, Hash, Eq, PartialEq)] +pub struct ConnectConfig { pub(crate) timeout: Duration, pub(crate) handshake_timeout: Duration, pub(crate) conn_lifetime: Duration, pub(crate) conn_keep_alive: Duration, - pub(crate) disconnect_timeout: Option, - pub(crate) limit: usize, pub(crate) conn_window_size: u32, pub(crate) stream_window_size: u32, pub(crate) local_address: Option, } -impl Default for ConnectorConfig { +/// Connector configuration +#[derive(Clone)] +pub struct ConnectorConfig { + pub(crate) default_connect_config: ConnectConfig, + pub(crate) disconnect_timeout: Option, + pub(crate) limit: usize, +} + +impl Default for ConnectConfig { fn default() -> Self { Self { timeout: Duration::from_secs(5), handshake_timeout: Duration::from_secs(5), conn_lifetime: Duration::from_secs(75), conn_keep_alive: Duration::from_secs(15), - disconnect_timeout: Some(Duration::from_millis(3000)), - limit: 100, conn_window_size: DEFAULT_H2_CONN_WINDOW, stream_window_size: DEFAULT_H2_STREAM_WINDOW, local_address: None, @@ -33,10 +37,88 @@ impl Default for ConnectorConfig { } } +impl Default for ConnectorConfig { + fn default() -> Self { + Self { + default_connect_config: ConnectConfig::default(), + disconnect_timeout: Some(Duration::from_millis(3000)), + limit: 100, + } + } +} + impl ConnectorConfig { - pub(crate) fn no_disconnect_timeout(&self) -> Self { + pub fn no_disconnect_timeout(&self) -> Self { let mut res = self.clone(); res.disconnect_timeout = None; res } } + +impl ConnectConfig { + /// Sets TCP connection timeout. + /// + /// This is the max time allowed to connect to remote host, including DNS name resolution. + /// + /// By default, the timeout is 5 seconds. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + /// Sets TLS handshake timeout. + /// + /// This is the max time allowed to perform the TLS handshake with remote host after TCP + /// connection is established. + /// + /// By default, the timeout is 5 seconds. + pub fn handshake_timeout(mut self, timeout: Duration) -> Self { + self.handshake_timeout = timeout; + self + } + + /// Sets the initial window size (in bytes) for HTTP/2 stream-level flow control for received + /// data. + /// + /// The default value is 65,535 and is good for APIs, but not for big objects. + pub fn initial_window_size(mut self, size: u32) -> Self { + self.stream_window_size = size; + self + } + + /// Sets the initial window size (in bytes) for HTTP/2 connection-level flow control for + /// received data. + /// + /// The default value is 65,535 and is good for APIs, but not for big objects. + pub fn initial_connection_window_size(mut self, size: u32) -> Self { + self.conn_window_size = size; + self + } + + /// Set keep-alive period for opened connection. + /// + /// Keep-alive period is the period between connection usage. If + /// the delay between repeated usages of the same connection + /// exceeds this period, the connection is closed. + /// Default keep-alive period is 15 seconds. + pub fn conn_keep_alive(mut self, dur: Duration) -> Self { + self.conn_keep_alive = dur; + self + } + + /// Set max lifetime period for connection. + /// + /// Connection lifetime is max lifetime of any opened connection + /// until it is closed regardless of keep-alive period. + /// Default lifetime period is 75 seconds. + pub fn conn_lifetime(mut self, dur: Duration) -> Self { + self.conn_lifetime = dur; + self + } + + /// Set local IP Address the connector would use for establishing connection. + pub fn local_address(mut self, addr: IpAddr) -> Self { + self.local_address = Some(addr); + self + } +} diff --git a/awc/src/client/connector.rs b/awc/src/client/connector.rs index 2e3f977f..d0118d3c 100644 --- a/awc/src/client/connector.rs +++ b/awc/src/client/connector.rs @@ -282,7 +282,7 @@ where /// /// By default, the timeout is 5 seconds. pub fn timeout(mut self, timeout: Duration) -> Self { - self.config.timeout = timeout; + self.config.default_connect_config.timeout = timeout; self } @@ -293,7 +293,7 @@ where /// /// By default, the timeout is 5 seconds. pub fn handshake_timeout(mut self, timeout: Duration) -> Self { - self.config.handshake_timeout = timeout; + self.config.default_connect_config.handshake_timeout = timeout; self } @@ -387,7 +387,7 @@ where /// /// The default value is 65,535 and is good for APIs, but not for big objects. pub fn initial_window_size(mut self, size: u32) -> Self { - self.config.stream_window_size = size; + self.config.default_connect_config.stream_window_size = size; self } @@ -396,7 +396,7 @@ where /// /// The default value is 65,535 and is good for APIs, but not for big objects. pub fn initial_connection_window_size(mut self, size: u32) -> Self { - self.config.conn_window_size = size; + self.config.default_connect_config.conn_window_size = size; self } @@ -422,7 +422,7 @@ where /// exceeds this period, the connection is closed. /// Default keep-alive period is 15 seconds. pub fn conn_keep_alive(mut self, dur: Duration) -> Self { - self.config.conn_keep_alive = dur; + self.config.default_connect_config.conn_keep_alive = dur; self } @@ -432,7 +432,7 @@ where /// until it is closed regardless of keep-alive period. /// Default lifetime period is 75 seconds. pub fn conn_lifetime(mut self, dur: Duration) -> Self { - self.config.conn_lifetime = dur; + self.config.default_connect_config.conn_lifetime = dur; self } @@ -451,7 +451,7 @@ where /// Set local IP Address the connector would use for establishing connection. pub fn local_address(mut self, addr: IpAddr) -> Self { - self.config.local_address = Some(addr); + self.config.default_connect_config.local_address = Some(addr); self } @@ -459,8 +459,8 @@ where /// /// The `Connector` builder always concludes by calling `finish()` last in its combinator chain. pub fn finish(self) -> ConnectorService { - let local_address = self.config.local_address; - let timeout = self.config.timeout; + let local_address = self.config.default_connect_config.local_address; + let timeout = self.config.default_connect_config.timeout; let tcp_service_inner = TcpConnectorInnerService::new(self.connector, timeout, local_address); @@ -523,7 +523,7 @@ where } } - let handshake_timeout = self.config.handshake_timeout; + let handshake_timeout = self.config.default_connect_config.handshake_timeout; let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, @@ -557,7 +557,7 @@ where } } - let handshake_timeout = self.config.handshake_timeout; + let handshake_timeout = self.config.default_connect_config.handshake_timeout; let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, @@ -596,7 +596,7 @@ where } } - let handshake_timeout = self.config.handshake_timeout; + let handshake_timeout = self.config.default_connect_config.handshake_timeout; let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, @@ -630,7 +630,7 @@ where } } - let handshake_timeout = self.config.handshake_timeout; + let handshake_timeout = self.config.default_connect_config.handshake_timeout; let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, @@ -667,7 +667,7 @@ where } } - let handshake_timeout = self.config.handshake_timeout; + let handshake_timeout = self.config.default_connect_config.handshake_timeout; let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, @@ -701,7 +701,7 @@ where } } - let handshake_timeout = self.config.handshake_timeout; + let handshake_timeout = self.config.default_connect_config.handshake_timeout; let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, @@ -824,9 +824,13 @@ where } fn call(&self, req: Connect) -> Self::Future { + let timeout = req + .config + .clone() + .map(|c| c.handshake_timeout) + .unwrap_or(self.timeout); let fut = self.tcp_service.call(req); let tls_service = self.tls_service.clone(); - let timeout = self.timeout; TlsConnectorFuture::TcpConnect { fut, @@ -935,6 +939,7 @@ where actix_service::forward_ready!(service); fn call(&self, req: Connect) -> Self::Future { + let timeout = req.config.map(|c| c.timeout).unwrap_or(self.timeout); let mut req = ConnectInfo::new(HostnameWithSni::ForTcp( req.hostname, req.port, @@ -949,7 +954,7 @@ where TcpConnectorInnerFuture { fut: self.service.call(req), - timeout: sleep(self.timeout), + timeout: sleep(timeout), } } } diff --git a/awc/src/client/h2proto.rs b/awc/src/client/h2proto.rs index c3f801f2..738a9f12 100644 --- a/awc/src/client/h2proto.rs +++ b/awc/src/client/h2proto.rs @@ -19,7 +19,6 @@ use http::{ use log::trace; use super::{ - config::ConnectorConfig, connection::{ConnectionIo, H2Connection}, error::SendRequestError, }; @@ -186,12 +185,13 @@ where pub(crate) fn handshake( io: Io, - config: &ConnectorConfig, + stream_window_size: u32, + conn_window_size: u32, ) -> impl Future, Connection), h2::Error>> { let mut builder = Builder::new(); builder - .initial_window_size(config.stream_window_size) - .initial_connection_window_size(config.conn_window_size) + .initial_window_size(stream_window_size) + .initial_connection_window_size(conn_window_size) .enable_push(false); builder.handshake(io) } diff --git a/awc/src/client/mod.rs b/awc/src/client/mod.rs index 5ac8650d..ca324b3d 100644 --- a/awc/src/client/mod.rs +++ b/awc/src/client/mod.rs @@ -20,6 +20,7 @@ mod h2proto; mod pool; pub use self::{ + config::ConnectConfig, connection::{Connection, ConnectionIo}, connector::{Connector, ConnectorService, HostnameWithSni}, error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}, @@ -49,6 +50,7 @@ pub struct Connect { pub port: u16, pub tls: bool, pub addr: Option, + pub config: Option>, } /// An asynchronous HTTP and WebSocket client. diff --git a/awc/src/client/pool.rs b/awc/src/client/pool.rs index 9b1058d8..c0429e04 100644 --- a/awc/src/client/pool.rs +++ b/awc/src/client/pool.rs @@ -190,7 +190,11 @@ where let now = Instant::now(); while let Some(mut c) = conns.pop_front() { - let config = &inner.config; + let config = req + .config + .as_ref() + .map(|c| c.as_ref()) + .unwrap_or(&inner.config.default_connect_config); let idle_dur = now - c.used; let age = now - c.created; let conn_ineligible = @@ -225,6 +229,17 @@ where conn }; + let stream_window_size = req + .config + .as_ref() + .map(|c| c.stream_window_size) + .unwrap_or(inner.config.default_connect_config.stream_window_size); + let conn_window_size = req + .config + .as_ref() + .map(|c| c.conn_window_size) + .unwrap_or(inner.config.default_connect_config.conn_window_size); + // construct acquired. It's used to put Io type back to pool/ close the Io type. // permit is carried with the whole lifecycle of Acquired. let acquired = Acquired { @@ -245,8 +260,8 @@ where if proto == Protocol::Http1 { Ok(ConnectionType::from_h1(io, Instant::now(), acquired)) } else { - let config = &acquired.inner.config; - let (sender, connection) = handshake(io, config).await?; + let (sender, connection) = + handshake(io, stream_window_size, conn_window_size).await?; let inner = H2ConnectionInner::new(sender, connection); Ok(ConnectionType::from_h2(inner, Instant::now(), acquired)) } @@ -381,6 +396,7 @@ mod test { use std::cell::Cell; use super::*; + use crate::client::ConnectConfig; /// A stream type that always returns pending on async read. /// @@ -469,6 +485,7 @@ mod test { tls: false, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); @@ -500,7 +517,10 @@ mod test { let connector = TestPoolConnector { generated }; let config = ConnectorConfig { - conn_keep_alive: Duration::from_secs(1), + default_connect_config: ConnectConfig { + conn_keep_alive: Duration::from_secs(1), + ..Default::default() + }, ..Default::default() }; @@ -512,6 +532,7 @@ mod test { tls: false, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); @@ -545,7 +566,10 @@ mod test { let connector = TestPoolConnector { generated }; let config = ConnectorConfig { - conn_lifetime: Duration::from_secs(1), + default_connect_config: ConnectConfig { + conn_lifetime: Duration::from_secs(1), + ..Default::default() + }, ..Default::default() }; @@ -557,6 +581,7 @@ mod test { tls: false, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); @@ -599,6 +624,7 @@ mod test { tls: true, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); @@ -615,6 +641,7 @@ mod test { tls: true, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); @@ -642,6 +669,7 @@ mod test { tls: true, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); @@ -654,6 +682,7 @@ mod test { tls: true, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); assert_eq!(2, generated_clone.get()); diff --git a/awc/src/connect.rs b/awc/src/connect.rs index a7bbd7b2..f5b79d0d 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -14,8 +14,8 @@ use futures_core::{future::LocalBoxFuture, ready}; use crate::{ any_body::AnyBody, client::{ - Connect as ClientConnect, ConnectError, Connection, ConnectionIo, SendRequestError, - ServerName, + Connect as ClientConnect, ConnectConfig, ConnectError, Connection, ConnectionIo, + SendRequestError, ServerName, }, ClientResponse, }; @@ -41,12 +41,18 @@ pub enum ConnectRequest { AnyBody, Option, Option, + Option>, ), /// Tunnel used by WebSocket connection requests. /// /// Contains the request head, optional pre-resolved socket address and optional sni host. - Tunnel(RequestHead, Option, Option), + Tunnel( + RequestHead, + Option, + Option, + Option>, + ), } /// Combined HTTP response & WebSocket tunnel type returned from connection service. @@ -111,11 +117,13 @@ where fn call(&self, req: ConnectRequest) -> Self::Future { // connect to the host - let (head, addr, sni_host) = match req { - ConnectRequest::Client(ref head, .., addr, ref sni_host) => { - (head.as_ref(), addr, sni_host.clone()) + let (head, addr, sni_host, config) = match req { + ConnectRequest::Client(ref head, .., addr, ref sni_host, ref config) => { + (head.as_ref(), addr, sni_host.clone(), config.clone()) + } + ConnectRequest::Tunnel(ref head, addr, ref sni_host, ref config) => { + (head, addr, sni_host.clone(), config.clone()) } - ConnectRequest::Tunnel(ref head, addr, ref sni_host) => (head, addr, sni_host.clone()), }; let authority = if let Some(authority) = head.uri.authority() { @@ -144,6 +152,7 @@ where tls, sni_host, addr, + config, }); ConnectRequestFuture::Connection { diff --git a/awc/src/frozen.rs b/awc/src/frozen.rs index d622f8ec..eba5e1a4 100644 --- a/awc/src/frozen.rs +++ b/awc/src/frozen.rs @@ -11,7 +11,7 @@ use futures_core::Stream; use serde::Serialize; use crate::{ - client::{ClientConfig, ServerName}, + client::{ClientConfig, ConnectConfig, ServerName}, sender::{RequestSender, SendClientRequest}, BoxError, }; @@ -27,6 +27,7 @@ pub struct FrozenClientRequest { pub(crate) timeout: Option, pub(crate) config: ClientConfig, pub(crate) sni_host: Option, + pub(crate) connect_config: Option>, } impl FrozenClientRequest { @@ -56,6 +57,7 @@ impl FrozenClientRequest { self.timeout, &self.config, self.sni_host.clone(), + self.connect_config.clone(), body, ) } @@ -68,6 +70,7 @@ impl FrozenClientRequest { self.timeout, &self.config, self.sni_host.clone(), + self.connect_config.clone(), value, ) } @@ -80,6 +83,7 @@ impl FrozenClientRequest { self.timeout, &self.config, self.sni_host.clone(), + self.connect_config.clone(), value, ) } @@ -96,6 +100,7 @@ impl FrozenClientRequest { self.timeout, &self.config, self.sni_host.clone(), + self.connect_config.clone(), stream, ) } @@ -108,6 +113,7 @@ impl FrozenClientRequest { self.timeout, &self.config, self.sni_host.clone(), + self.connect_config.clone(), ) } @@ -163,6 +169,7 @@ impl FrozenSendBuilder { self.req.timeout, &self.req.config, self.req.sni_host.clone(), + self.req.connect_config, body, ) } @@ -179,6 +186,7 @@ impl FrozenSendBuilder { self.req.timeout, &self.req.config, self.req.sni_host.clone(), + self.req.connect_config, value, ) } @@ -195,6 +203,7 @@ impl FrozenSendBuilder { self.req.timeout, &self.req.config, self.req.sni_host.clone(), + self.req.connect_config, value, ) } @@ -215,6 +224,7 @@ impl FrozenSendBuilder { self.req.timeout, &self.req.config, self.req.sni_host.clone(), + self.req.connect_config, stream, ) } @@ -231,6 +241,7 @@ impl FrozenSendBuilder { self.req.timeout, &self.req.config, self.req.sni_host.clone(), + self.req.connect_config, ) } } diff --git a/awc/src/middleware/redirect.rs b/awc/src/middleware/redirect.rs index 81f4d799..d927328b 100644 --- a/awc/src/middleware/redirect.rs +++ b/awc/src/middleware/redirect.rs @@ -73,13 +73,13 @@ where fn call(&self, req: ConnectRequest) -> Self::Future { match req { - ConnectRequest::Tunnel(head, addr, sni_host) => { + ConnectRequest::Tunnel(head, addr, sni_host, config) => { let fut = self .connector - .call(ConnectRequest::Tunnel(head, addr, sni_host)); + .call(ConnectRequest::Tunnel(head, addr, sni_host, config)); RedirectServiceFuture::Tunnel { fut } } - ConnectRequest::Client(head, body, addr, sni_host) => { + ConnectRequest::Client(head, body, addr, sni_host, config) => { let connector = Rc::clone(&self.connector); let max_redirect_times = self.max_redirect_times; @@ -98,7 +98,8 @@ where _ => None, }; - let fut = connector.call(ConnectRequest::Client(head, body, addr, sni_host)); + let fut = + connector.call(ConnectRequest::Client(head, body, addr, sni_host, config)); RedirectServiceFuture::Client { fut, @@ -223,8 +224,8 @@ where let fut = connector .as_ref() .unwrap() - // @TODO find a way to get sni host - .call(ConnectRequest::Client(head, body_new, addr, None)); + // @TODO find a way to get sni host and config + .call(ConnectRequest::Client(head, body_new, addr, None, None)); self.set(RedirectServiceFuture::Client { fut, diff --git a/awc/src/request.rs b/awc/src/request.rs index b0f995a6..e24f19c7 100644 --- a/awc/src/request.rs +++ b/awc/src/request.rs @@ -14,7 +14,7 @@ use serde::Serialize; #[cfg(feature = "cookies")] use crate::cookie::{Cookie, CookieJar}; use crate::{ - client::{ClientConfig, ServerName}, + client::{ClientConfig, ConnectConfig, ServerName}, error::{FreezeRequestError, InvalidUrl}, frozen::FrozenClientRequest, sender::{PrepForSendingError, RequestSender, SendClientRequest}, @@ -49,6 +49,7 @@ pub struct ClientRequest { timeout: Option, config: ClientConfig, sni_host: Option, + connect_config: Option, #[cfg(feature = "cookies")] cookies: Option, @@ -71,6 +72,7 @@ impl ClientRequest { timeout: None, response_decompress: true, sni_host: None, + connect_config: None, } .method(method) .uri(uri) @@ -281,6 +283,15 @@ impl ClientRequest { self } + /// Set specific connector configuration for this request. + /// + /// Not all config may be applied to the request, it depends on the connector and also + /// if there is already a connection established. + pub fn connect_config(mut self, config: ConnectConfig) -> Self { + self.connect_config = Some(config); + self + } + /// Set request timeout. Overrides client wide timeout setting. /// /// Request timeout is the total time before a response must be received. @@ -332,6 +343,7 @@ impl ClientRequest { ServerName::Borrowed(r) => ServerName::Borrowed(r), ServerName::Owned(o) => ServerName::Borrowed(Rc::new(o)), }), + connect_config: slf.connect_config.map(Rc::new), }; Ok(request) @@ -353,6 +365,7 @@ impl ClientRequest { slf.timeout, &slf.config, slf.sni_host, + slf.connect_config.map(Rc::new), body, ) } @@ -370,6 +383,7 @@ impl ClientRequest { slf.timeout, &slf.config, slf.sni_host, + slf.connect_config.map(Rc::new), value, ) } @@ -389,6 +403,7 @@ impl ClientRequest { slf.timeout, &slf.config, slf.sni_host, + slf.connect_config.map(Rc::new), value, ) } @@ -410,6 +425,7 @@ impl ClientRequest { slf.timeout, &slf.config, slf.sni_host, + slf.connect_config.map(Rc::new), stream, ) } @@ -427,6 +443,7 @@ impl ClientRequest { slf.timeout, &slf.config, slf.sni_host, + slf.connect_config.map(Rc::new), ) } diff --git a/awc/src/sender.rs b/awc/src/sender.rs index ab3ca596..8347ceb4 100644 --- a/awc/src/sender.rs +++ b/awc/src/sender.rs @@ -23,7 +23,7 @@ use serde::Serialize; use crate::{ any_body::AnyBody, - client::{ClientConfig, ServerName}, + client::{ClientConfig, ConnectConfig, ServerName}, error::{FreezeRequestError, InvalidUrl, SendRequestError}, BoxError, ClientResponse, ConnectRequest, ConnectResponse, }; @@ -187,6 +187,7 @@ impl RequestSender { timeout: Option, config: &ClientConfig, sni_host: Option, + connect_config: Option>, body: impl MessageBody + 'static, ) -> SendClientRequest { let req = match self { @@ -195,12 +196,14 @@ impl RequestSender { AnyBody::from_message_body(body).into_boxed(), addr, sni_host, + connect_config, ), RequestSender::Rc(head, extra_headers) => ConnectRequest::Client( RequestHeadType::Rc(head, extra_headers), AnyBody::from_message_body(body).into_boxed(), addr, sni_host, + connect_config, ), }; @@ -216,6 +219,7 @@ impl RequestSender { timeout: Option, config: &ClientConfig, sni_host: Option, + connector_config: Option>, value: impl Serialize, ) -> SendClientRequest { let body = match serde_json::to_string(&value) { @@ -227,7 +231,15 @@ impl RequestSender { return err.into(); } - self.send_body(addr, response_decompress, timeout, config, sni_host, body) + self.send_body( + addr, + response_decompress, + timeout, + config, + sni_host, + connector_config, + body, + ) } pub(crate) fn send_form( @@ -237,6 +249,7 @@ impl RequestSender { timeout: Option, config: &ClientConfig, sni_host: Option, + connector_config: Option>, value: impl Serialize, ) -> SendClientRequest { let body = match serde_urlencoded::to_string(value) { @@ -251,7 +264,15 @@ impl RequestSender { return err.into(); } - self.send_body(addr, response_decompress, timeout, config, sni_host, body) + self.send_body( + addr, + response_decompress, + timeout, + config, + sni_host, + connector_config, + body, + ) } pub(crate) fn send_stream( @@ -261,6 +282,7 @@ impl RequestSender { timeout: Option, config: &ClientConfig, sni_host: Option, + connector_config: Option>, stream: S, ) -> SendClientRequest where @@ -273,6 +295,7 @@ impl RequestSender { timeout, config, sni_host, + connector_config, BodyStream::new(stream), ) } @@ -284,8 +307,17 @@ impl RequestSender { timeout: Option, config: &ClientConfig, sni_host: Option, + connector_config: Option>, ) -> SendClientRequest { - self.send_body(addr, response_decompress, timeout, config, sni_host, ()) + self.send_body( + addr, + response_decompress, + timeout, + config, + sni_host, + connector_config, + (), + ) } fn set_header_if_none(&mut self, key: HeaderName, value: V) -> Result<(), HttpError> diff --git a/awc/src/ws.rs b/awc/src/ws.rs index ef5cb715..77b00fce 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -26,7 +26,7 @@ //! } //! ``` -use std::{fmt, net::SocketAddr, str}; +use std::{fmt, net::SocketAddr, rc::Rc, str}; use actix_codec::Framed; pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message}; @@ -38,7 +38,7 @@ use base64::prelude::*; #[cfg(feature = "cookies")] use crate::cookie::{Cookie, CookieJar}; use crate::{ - client::{ClientConfig, ServerName}, + client::{ClientConfig, ConnectConfig, ServerName}, connect::{BoxedSocket, ConnectRequest}, error::{HttpError, InvalidUrl, SendRequestError, WsClientError}, http::{ @@ -59,6 +59,7 @@ pub struct WebsocketsRequest { server_mode: bool, config: ClientConfig, sni_host: Option, + connect_config: Option, #[cfg(feature = "cookies")] cookies: Option, @@ -98,6 +99,7 @@ impl WebsocketsRequest { #[cfg(feature = "cookies")] cookies: None, sni_host: None, + connect_config: None, } } @@ -110,6 +112,15 @@ impl WebsocketsRequest { self } + /// Set specific connector configuration for this request. + /// + /// Not all config may be applied to the request, it depends on the connector and also + /// if there is already a connection established. + pub fn connector_config(mut self, config: ConnectConfig) -> Self { + self.connect_config = Some(config); + self + } + /// Set supported WebSocket protocols pub fn protocols(mut self, protos: U) -> Self where @@ -346,7 +357,12 @@ impl WebsocketsRequest { let max_size = self.max_size; let server_mode = self.server_mode; - let req = ConnectRequest::Tunnel(head, self.addr, self.sni_host); + let req = ConnectRequest::Tunnel( + head, + self.addr, + self.sni_host, + self.connect_config.map(Rc::new), + ); let fut = self.config.connector.call(req);