add `ClientBuilder::add_default_header`

This commit is contained in:
Rob Ede 2021-12-13 06:27:30 +00:00
parent 44005e216e
commit 22d6b8156a
No known key found for this signature in database
GPG Key ID: 97C636207D3EF933
6 changed files with 65 additions and 43 deletions

View File

@ -1,6 +1,9 @@
# Changes # Changes
## Unreleased - 2021-xx-xx ## Unreleased - 2021-xx-xx
* Add `ClientBuilder::add_default_header` and deprecate `ClientBuilder::header`. [#2510]
[#2510]: https://github.com/actix/actix-web/pull/2510
## 3.0.0-beta.13 - 2021-12-11 ## 3.0.0-beta.13 - 2021-12-11

View File

@ -2,7 +2,7 @@ use std::{convert::TryFrom, fmt, net::IpAddr, rc::Rc, time::Duration};
use actix_http::{ use actix_http::{
error::HttpError, error::HttpError,
header::{self, HeaderMap, HeaderName}, header::{self, HeaderMap, HeaderName, TryIntoHeaderPair},
Uri, Uri,
}; };
use actix_rt::net::{ActixStream, TcpStream}; use actix_rt::net::{ActixStream, TcpStream};
@ -21,11 +21,11 @@ use crate::{
/// This type can be used to construct an instance of `Client` through a /// This type can be used to construct an instance of `Client` through a
/// builder-like pattern. /// builder-like pattern.
pub struct ClientBuilder<S = (), M = ()> { pub struct ClientBuilder<S = (), M = ()> {
default_headers: bool,
max_http_version: Option<http::Version>, max_http_version: Option<http::Version>,
stream_window_size: Option<u32>, stream_window_size: Option<u32>,
conn_window_size: Option<u32>, conn_window_size: Option<u32>,
headers: HeaderMap, fundamental_headers: bool,
default_headers: HeaderMap,
timeout: Option<Duration>, timeout: Option<Duration>,
connector: Connector<S>, connector: Connector<S>,
middleware: M, middleware: M,
@ -44,15 +44,15 @@ impl ClientBuilder {
(), (),
> { > {
ClientBuilder { ClientBuilder {
middleware: (),
default_headers: true,
headers: HeaderMap::new(),
timeout: Some(Duration::from_secs(5)),
local_address: None,
connector: Connector::new(),
max_http_version: None, max_http_version: None,
stream_window_size: None, stream_window_size: None,
conn_window_size: None, conn_window_size: None,
fundamental_headers: true,
default_headers: HeaderMap::new(),
timeout: Some(Duration::from_secs(5)),
connector: Connector::new(),
middleware: (),
local_address: None,
max_redirects: 10, max_redirects: 10,
} }
} }
@ -78,8 +78,8 @@ where
{ {
ClientBuilder { ClientBuilder {
middleware: self.middleware, middleware: self.middleware,
fundamental_headers: self.fundamental_headers,
default_headers: self.default_headers, default_headers: self.default_headers,
headers: self.headers,
timeout: self.timeout, timeout: self.timeout,
local_address: self.local_address, local_address: self.local_address,
connector, connector,
@ -153,15 +153,31 @@ where
self self
} }
/// Do not add default request headers. /// Do not add fundamental default request headers.
///
/// By default `Date` and `User-Agent` headers are set. /// By default `Date` and `User-Agent` headers are set.
pub fn no_default_headers(mut self) -> Self { pub fn no_default_headers(mut self) -> Self {
self.default_headers = false; self.fundamental_headers = false;
self self
} }
/// Add default header. Headers added by this method /// Add default header.
/// get added to every request. ///
/// Headers added by this method get added to every request unless overriden by .
///
/// # Panics
/// Panics if header name or value is invalid.
pub fn add_default_header(mut self, header: impl TryIntoHeaderPair) -> Self {
match header.try_into_header_pair() {
Ok((key, value)) => self.default_headers.append(key, value),
Err(err) => panic!("Header error: {:?}", err.into()),
}
self
}
#[doc(hidden)]
#[deprecated(since = "3.0.0", note = "Prefer `add_default_header((key, value))`.")]
pub fn header<K, V>(mut self, key: K, value: V) -> Self pub fn header<K, V>(mut self, key: K, value: V) -> Self
where where
HeaderName: TryFrom<K>, HeaderName: TryFrom<K>,
@ -172,11 +188,11 @@ where
match HeaderName::try_from(key) { match HeaderName::try_from(key) {
Ok(key) => match value.try_into_value() { Ok(key) => match value.try_into_value() {
Ok(value) => { Ok(value) => {
self.headers.append(key, value); self.default_headers.append(key, value);
} }
Err(e) => log::error!("Header value error: {:?}", e), Err(err) => log::error!("Header value error: {:?}", err),
}, },
Err(e) => log::error!("Header name error: {:?}", e), Err(err) => log::error!("Header name error: {:?}", err),
} }
self self
} }
@ -190,10 +206,10 @@ where
Some(password) => format!("{}:{}", username, password), Some(password) => format!("{}:{}", username, password),
None => format!("{}:", username), None => format!("{}:", username),
}; };
self.header( self.add_default_header((
header::AUTHORIZATION, header::AUTHORIZATION,
format!("Basic {}", base64::encode(&auth)), format!("Basic {}", base64::encode(&auth)),
) ))
} }
/// Set client wide HTTP bearer authentication header /// Set client wide HTTP bearer authentication header
@ -201,13 +217,12 @@ where
where where
T: fmt::Display, T: fmt::Display,
{ {
self.header(header::AUTHORIZATION, format!("Bearer {}", token)) self.add_default_header((header::AUTHORIZATION, format!("Bearer {}", token)))
} }
/// Registers middleware, in the form of a middleware component (type), /// Registers middleware, in the form of a middleware component (type), that runs during inbound
/// that runs during inbound and/or outbound processing in the request /// and/or outbound processing in the request life-cycle (request -> response),
/// life-cycle (request -> response), modifying request/response as /// modifying request/response as necessary, across all requests managed by the `Client`.
/// necessary, across all requests managed by the Client.
pub fn wrap<S1, M1>( pub fn wrap<S1, M1>(
self, self,
mw: M1, mw: M1,
@ -218,11 +233,11 @@ where
{ {
ClientBuilder { ClientBuilder {
middleware: NestTransform::new(self.middleware, mw), middleware: NestTransform::new(self.middleware, mw),
default_headers: self.default_headers, fundamental_headers: self.fundamental_headers,
max_http_version: self.max_http_version, max_http_version: self.max_http_version,
stream_window_size: self.stream_window_size, stream_window_size: self.stream_window_size,
conn_window_size: self.conn_window_size, conn_window_size: self.conn_window_size,
headers: self.headers, default_headers: self.default_headers,
timeout: self.timeout, timeout: self.timeout,
connector: self.connector, connector: self.connector,
local_address: self.local_address, local_address: self.local_address,
@ -237,10 +252,10 @@ where
M::Transform: M::Transform:
Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError>, Service<ConnectRequest, Response = ConnectResponse, Error = SendRequestError>,
{ {
let redirect_time = self.max_redirects; let max_redirects = self.max_redirects;
if redirect_time > 0 { if max_redirects > 0 {
self.wrap(Redirect::new().max_redirect_times(redirect_time)) self.wrap(Redirect::new().max_redirect_times(max_redirects))
._finish() ._finish()
} else { } else {
self._finish() self._finish()
@ -272,7 +287,7 @@ where
let connector = boxed::rc_service(self.middleware.new_transform(connector)); let connector = boxed::rc_service(self.middleware.new_transform(connector));
Client(ClientConfig { Client(ClientConfig {
headers: Rc::new(self.headers), default_headers: Rc::new(self.default_headers),
timeout: self.timeout, timeout: self.timeout,
connector, connector,
}) })
@ -288,7 +303,7 @@ mod tests {
let client = ClientBuilder::new().basic_auth("username", Some("password")); let client = ClientBuilder::new().basic_auth("username", Some("password"));
assert_eq!( assert_eq!(
client client
.headers .default_headers
.get(header::AUTHORIZATION) .get(header::AUTHORIZATION)
.unwrap() .unwrap()
.to_str() .to_str()
@ -299,7 +314,7 @@ mod tests {
let client = ClientBuilder::new().basic_auth("username", None); let client = ClientBuilder::new().basic_auth("username", None);
assert_eq!( assert_eq!(
client client
.headers .default_headers
.get(header::AUTHORIZATION) .get(header::AUTHORIZATION)
.unwrap() .unwrap()
.to_str() .to_str()
@ -313,7 +328,7 @@ mod tests {
let client = ClientBuilder::new().bearer_auth("someS3cr3tAutht0k3n"); let client = ClientBuilder::new().bearer_auth("someS3cr3tAutht0k3n");
assert_eq!( assert_eq!(
client client
.headers .default_headers
.get(header::AUTHORIZATION) .get(header::AUTHORIZATION)
.unwrap() .unwrap()
.to_str() .to_str()

View File

@ -168,7 +168,7 @@ pub struct Client(ClientConfig);
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct ClientConfig { pub(crate) struct ClientConfig {
pub(crate) connector: BoxConnectorService, pub(crate) connector: BoxConnectorService,
pub(crate) headers: Rc<HeaderMap>, pub(crate) default_headers: Rc<HeaderMap>,
pub(crate) timeout: Option<Duration>, pub(crate) timeout: Option<Duration>,
} }
@ -204,7 +204,9 @@ impl Client {
{ {
let mut req = ClientRequest::new(method, url, self.0.clone()); let mut req = ClientRequest::new(method, url, self.0.clone());
for header in self.0.headers.iter() { for header in self.0.default_headers.iter() {
// header map is empty
// TODO: probably append instead
req = req.insert_header_if_none(header); req = req.insert_header_if_none(header);
} }
req req
@ -297,7 +299,7 @@ impl Client {
<Uri as TryFrom<U>>::Error: Into<HttpError>, <Uri as TryFrom<U>>::Error: Into<HttpError>,
{ {
let mut req = ws::WebsocketsRequest::new(url, self.0.clone()); let mut req = ws::WebsocketsRequest::new(url, self.0.clone());
for (key, value) in self.0.headers.iter() { for (key, value) in self.0.default_headers.iter() {
req.head.headers.insert(key.clone(), value.clone()); req.head.headers.insert(key.clone(), value.clone());
} }
req req
@ -308,6 +310,6 @@ impl Client {
/// Returns Some(&mut HeaderMap) when Client object is unique /// Returns Some(&mut HeaderMap) when Client object is unique
/// (No other clone of client exists at the same time). /// (No other clone of client exists at the same time).
pub fn headers(&mut self) -> Option<&mut HeaderMap> { pub fn headers(&mut self) -> Option<&mut HeaderMap> {
Rc::get_mut(&mut self.0.headers) Rc::get_mut(&mut self.0.default_headers)
} }
} }

View File

@ -442,13 +442,15 @@ mod tests {
}); });
let client = ClientBuilder::new() let client = ClientBuilder::new()
.header("custom", "value") .add_default_header(("custom", "value"))
.disable_redirects() .disable_redirects()
.finish(); .finish();
let res = client.get(srv.url("/")).send().await.unwrap(); let res = client.get(srv.url("/")).send().await.unwrap();
assert_eq!(res.status().as_u16(), 302); assert_eq!(res.status().as_u16(), 302);
let client = ClientBuilder::new().header("custom", "value").finish(); let client = ClientBuilder::new()
.add_default_header(("custom", "value"))
.finish();
let res = client.get(srv.url("/")).send().await.unwrap(); let res = client.get(srv.url("/")).send().await.unwrap();
assert_eq!(res.status().as_u16(), 200); assert_eq!(res.status().as_u16(), 200);
@ -520,7 +522,7 @@ mod tests {
// send a request to different origins, http://srv1/ then http://srv2/. So it should remove the header // send a request to different origins, http://srv1/ then http://srv2/. So it should remove the header
let client = ClientBuilder::new() let client = ClientBuilder::new()
.header(header::AUTHORIZATION, "auth_key_value") .add_default_header((header::AUTHORIZATION, "auth_key_value"))
.finish(); .finish();
let res = client.get(srv1.url("/")).send().await.unwrap(); let res = client.get(srv1.url("/")).send().await.unwrap();
assert_eq!(res.status().as_u16(), 200); assert_eq!(res.status().as_u16(), 200);

View File

@ -579,7 +579,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_client_header() { async fn test_client_header() {
let req = Client::builder() let req = Client::builder()
.header(header::CONTENT_TYPE, "111") .add_default_header((header::CONTENT_TYPE, "111"))
.finish() .finish()
.get("/"); .get("/");
@ -597,7 +597,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_client_header_override() { async fn test_client_header_override() {
let req = Client::builder() let req = Client::builder()
.header(header::CONTENT_TYPE, "111") .add_default_header((header::CONTENT_TYPE, "111"))
.finish() .finish()
.get("/") .get("/")
.insert_header((header::CONTENT_TYPE, "222")); .insert_header((header::CONTENT_TYPE, "222"));

View File

@ -445,7 +445,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_header_override() { async fn test_header_override() {
let req = Client::builder() let req = Client::builder()
.header(header::CONTENT_TYPE, "111") .add_default_header((header::CONTENT_TYPE, "111"))
.finish() .finish()
.ws("/") .ws("/")
.set_header(header::CONTENT_TYPE, "222"); .set_header(header::CONTENT_TYPE, "222");