From 0e4cc7ca7106f3c69f6561ce1f9dc52b54fb48f7 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Wed, 24 Feb 2021 15:26:56 +0800 Subject: [PATCH] add local_address bind for client builder --- Cargo.toml | 3 ++ actix-http/src/client/config.rs | 3 ++ actix-http/src/client/connector.rs | 48 ++++++++++++++++++++++-------- actix-http/src/h1/service.rs | 8 ++--- actix-http/src/h2/service.rs | 6 ++-- actix-http/src/service.rs | 10 +++---- awc/CHANGES.md | 2 ++ awc/src/builder.rs | 13 ++++++++ awc/tests/test_client.rs | 32 ++++++++++++++++++++ 9 files changed, 101 insertions(+), 24 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1a1b8645c..4c8e9c4ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -141,6 +141,9 @@ actix-multipart = { path = "actix-multipart" } actix-files = { path = "actix-files" } awc = { path = "awc" } +actix-tls = { git = "https://github.com/actix/actix-net.git" } +actix-rt = { git = "https://github.com/actix/actix-net.git" } + [[bench]] name = "server" harness = false diff --git a/actix-http/src/client/config.rs b/actix-http/src/client/config.rs index fad902d04..0d54e1b49 100644 --- a/actix-http/src/client/config.rs +++ b/actix-http/src/client/config.rs @@ -1,3 +1,4 @@ +use std::net::IpAddr; use std::time::Duration; const DEFAULT_H2_CONN_WINDOW: u32 = 1024 * 1024 * 2; // 2MB @@ -13,6 +14,7 @@ pub(crate) struct ConnectorConfig { pub(crate) limit: usize, pub(crate) conn_window_size: u32, pub(crate) stream_window_size: u32, + pub(crate) local_address: Option, } impl Default for ConnectorConfig { @@ -25,6 +27,7 @@ impl Default for ConnectorConfig { limit: 100, conn_window_size: DEFAULT_H2_CONN_WINDOW, stream_window_size: DEFAULT_H2_STREAM_WINDOW, + local_address: None, } } } diff --git a/actix-http/src/client/connector.rs b/actix-http/src/client/connector.rs index 8aa5b1319..1a926fd6c 100644 --- a/actix-http/src/client/connector.rs +++ b/actix-http/src/client/connector.rs @@ -1,9 +1,12 @@ -use std::fmt; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::time::Duration; +use std::{ + fmt, + future::Future, + marker::PhantomData, + net::IpAddr, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_rt::net::TcpStream; @@ -240,6 +243,12 @@ where self } + /// Set local IP Address the connector would use for establishing connection. + pub fn local_address(mut self, addr: IpAddr) -> Self { + self.config.local_address = Some(addr); + self + } + /// Finish configuration process and create connector service. /// The Connector builder always concludes by calling `finish()` last in /// its combinator chain. @@ -247,10 +256,19 @@ where self, ) -> impl Service + Clone { + let local_address = self.config.local_address; + let timeout = self.config.timeout; + let tcp_service = TimeoutService::new( - self.config.timeout, - apply_fn(self.connector.clone(), |msg: Connect, srv| { - srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) + timeout, + apply_fn(self.connector.clone(), move |msg: Connect, srv| { + let mut req = TcpConnect::new(msg.uri).set_addr(msg.addr); + + if let Some(local_addr) = local_address { + req = req.set_local_addr(local_addr); + } + + srv.call(req) }) .map_err(ConnectError::from) .map(|stream| (stream.into_parts().0, Protocol::Http1)), @@ -294,10 +312,16 @@ where use actix_tls::connect::ssl::rustls::{RustlsConnector, Session}; let ssl_service = TimeoutService::new( - self.config.timeout, + timeout, pipeline( - apply_fn(self.connector.clone(), |msg: Connect, srv| { - srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) + apply_fn(self.connector.clone(), move |msg: Connect, srv| { + let mut req = TcpConnect::new(msg.uri).set_addr(msg.addr); + + if let Some(local_addr) = local_address { + req = req.set_local_addr(local_addr); + } + + srv.call(req) }) .map_err(ConnectError::from), ) diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index b79453ebd..51303886b 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -94,10 +94,10 @@ mod openssl { use super::*; use actix_service::ServiceFactoryExt; - use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; + use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream}; use actix_tls::accept::TlsError; - impl H1Service, S, B, X, U> + impl H1Service, S, B, X, U> where S: ServiceFactory, S::Error: Into, @@ -108,7 +108,7 @@ mod openssl { X::Error: Into, X::InitError: fmt::Debug, U: ServiceFactory< - (Request, Framed, Codec>), + (Request, Framed, Codec>), Config = (), Response = (), >, @@ -131,7 +131,7 @@ mod openssl { .map_err(TlsError::Tls) .map_init_err(|_| panic!()), ) - .and_then(|io: SslStream| { + .and_then(|io: TlsStream| { let peer_addr = io.get_ref().peer_addr().ok(); ready(Ok((io, peer_addr))) }) diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index e00c8d968..0984b3f23 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -93,12 +93,12 @@ where #[cfg(feature = "openssl")] mod openssl { use actix_service::{fn_factory, fn_service, ServiceFactoryExt}; - use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; + use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream}; use actix_tls::accept::TlsError; use super::*; - impl H2Service, S, B> + impl H2Service, S, B> where S: ServiceFactory, S::Error: Into + 'static, @@ -123,7 +123,7 @@ mod openssl { .map_init_err(|_| panic!()), ) .and_then(fn_factory(|| { - ok::<_, S::InitError>(fn_service(|io: SslStream| { + ok::<_, S::InitError>(fn_service(|io: TlsStream| { let peer_addr = io.get_ref().peer_addr().ok(); ok((io, peer_addr)) })) diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index fee26dcc3..402affb7e 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -185,10 +185,10 @@ where mod openssl { use super::*; use actix_service::ServiceFactoryExt; - use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; + use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream}; use actix_tls::accept::TlsError; - impl HttpService, S, B, X, U> + impl HttpService, S, B, X, U> where S: ServiceFactory, S::Error: Into + 'static, @@ -201,13 +201,13 @@ mod openssl { X::InitError: fmt::Debug, >::Future: 'static, U: ServiceFactory< - (Request, Framed, h1::Codec>), + (Request, Framed, h1::Codec>), Config = (), Response = (), >, U::Error: fmt::Display + Into, U::InitError: fmt::Debug, - , h1::Codec>)>>::Future: 'static, + , h1::Codec>)>>::Future: 'static, { /// Create openssl based service pub fn openssl( @@ -225,7 +225,7 @@ mod openssl { .map_err(TlsError::Tls) .map_init_err(|_| panic!()), ) - .and_then(|io: SslStream| async { + .and_then(|io: TlsStream| async { let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() { if protos.windows(2).any(|window| window == b"h2") { Protocol::Http2 diff --git a/awc/CHANGES.md b/awc/CHANGES.md index e6ead2cc8..04cf4f3f3 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -3,6 +3,7 @@ ## Unreleased - 2021-xx-xx ### Added * `ClientResponse::timeout` for set the timeout of collecting response body. [#1931] +* `ClientBuilder::local_address` for bind to a local ip address for this client. ### Changed * Feature `cookies` is now optional and enabled by default. [#1981] @@ -16,6 +17,7 @@ [#1981]: https://github.com/actix/actix-web/pull/1981 [#2008]: https://github.com/actix/actix-web/pull/2008 + ## 3.0.0-beta.2 - 2021-02-10 ### Added * `ClientRequest::insert_header` method which allows using typed headers. [#1869] diff --git a/awc/src/builder.rs b/awc/src/builder.rs index 4495b39fd..b7cdefd40 100644 --- a/awc/src/builder.rs +++ b/awc/src/builder.rs @@ -1,5 +1,6 @@ use std::convert::TryFrom; use std::fmt; +use std::net::IpAddr; use std::rc::Rc; use std::time::Duration; @@ -25,6 +26,7 @@ pub struct ClientBuilder { conn_window_size: Option, headers: HeaderMap, timeout: Option, + local_address: Option, connector: Connector, } @@ -42,6 +44,7 @@ impl ClientBuilder { default_headers: true, headers: HeaderMap::new(), timeout: Some(Duration::from_secs(5)), + local_address: None, connector: Connector::new(), max_http_version: None, stream_window_size: None, @@ -72,6 +75,7 @@ where default_headers: self.default_headers, headers: self.headers, timeout: self.timeout, + local_address: None, connector, max_http_version: self.max_http_version, stream_window_size: self.stream_window_size, @@ -94,6 +98,12 @@ where self } + /// Set local IP Address the connector would use for establishing connection. + pub fn local_address(mut self, addr: IpAddr) -> Self { + self.local_address = Some(addr); + self + } + /// Maximum supported HTTP major version. /// /// Supported versions are HTTP/1.1 and HTTP/2. @@ -184,6 +194,9 @@ where if let Some(val) = self.stream_window_size { connector = connector.initial_window_size(val) }; + if let Some(val) = self.local_address { + connector = connector.local_address(val); + } let config = ClientConfig { headers: self.headers, diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs index a41a8dac3..c7fa82de8 100644 --- a/awc/tests/test_client.rs +++ b/awc/tests/test_client.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::io::{Read, Write}; +use std::net::{IpAddr, Ipv4Addr}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; @@ -871,3 +872,34 @@ async fn client_bearer_auth() { let response = request.send().await.unwrap(); assert!(response.status().is_success()); } + +#[actix_rt::test] +async fn test_local_address() { + let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); + + let srv = test::start(move || { + App::new().service(web::resource("/").route(web::to( + move |req: HttpRequest| async move { + assert_eq!(req.peer_addr().unwrap().ip(), ip); + Ok::<_, Error>(HttpResponse::Ok()) + }, + ))) + }); + let client = awc::Client::builder().local_address(ip).finish(); + + let res = client.get(srv.url("/")).send().await.unwrap(); + + assert_eq!(res.status(), 200); + + let client = awc::Client::builder() + .connector( + // connector local address setting should always be override by client builder. + awc::Connector::new().local_address(IpAddr::V4(Ipv4Addr::new(128, 0, 0, 1))), + ) + .local_address(ip) + .finish(); + + let res = client.get(srv.url("/")).send().await.unwrap(); + + assert_eq!(res.status(), 200); +}