diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs index 9354fca4a..047319470 100644 --- a/actix-http/src/client/connection.rs +++ b/actix-http/src/client/connection.rs @@ -1,4 +1,5 @@ use std::{fmt, io, time}; +use std::rc::Rc; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use bytes::{Buf, Bytes}; @@ -10,6 +11,7 @@ use crate::body::MessageBody; use crate::h1::ClientCodec; use crate::message::{RequestHead, ResponseHead}; use crate::payload::Payload; +use crate::header::HeaderMap; use super::error::SendRequestError; use super::pool::{Acquired, Protocol}; @@ -29,7 +31,8 @@ pub trait Connection { /// Send request and body fn send_request( self, - head: RequestHead, + head: Rc, + additional_headers: Option, body: B, ) -> Self::Future; @@ -39,7 +42,10 @@ pub trait Connection { >; /// Send request, returns Response and Framed - fn open_tunnel(self, head: RequestHead) -> Self::TunnelFuture; + fn open_tunnel(self, + head: Rc, + additional_headers: Option, + ) -> Self::TunnelFuture; } pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { @@ -106,13 +112,15 @@ where fn send_request( mut self, - head: RequestHead, + head: Rc, + additional_headers: Option, body: B, ) -> Self::Future { match self.io.take().unwrap() { ConnectionType::H1(io) => Box::new(h1proto::send_request( io, head, + additional_headers, body, self.created, self.pool, @@ -120,6 +128,7 @@ where ConnectionType::H2(io) => Box::new(h2proto::send_request( io, head, + additional_headers, body, self.created, self.pool, @@ -138,10 +147,10 @@ where >; /// Send request, returns Response and Framed - fn open_tunnel(mut self, head: RequestHead) -> Self::TunnelFuture { + fn open_tunnel(mut self, head: Rc, additional_headers: Option) -> Self::TunnelFuture { match self.io.take().unwrap() { ConnectionType::H1(io) => { - Either::A(Box::new(h1proto::open_tunnel(io, head))) + Either::A(Box::new(h1proto::open_tunnel(io, head, additional_headers))) } ConnectionType::H2(io) => { if let Some(mut pool) = self.pool.take() { @@ -180,12 +189,13 @@ where fn send_request( self, - head: RequestHead, + head: Rc, + additional_headers: Option, body: RB, ) -> Self::Future { match self { - EitherConnection::A(con) => con.send_request(head, body), - EitherConnection::B(con) => con.send_request(head, body), + EitherConnection::A(con) => con.send_request(head, additional_headers, body), + EitherConnection::B(con) => con.send_request(head, additional_headers, body), } } @@ -197,14 +207,14 @@ where >; /// Send request, returns Response and Framed - fn open_tunnel(self, head: RequestHead) -> Self::TunnelFuture { + fn open_tunnel(self, head: Rc, additional_headers: Option) -> Self::TunnelFuture { match self { EitherConnection::A(con) => Box::new( - con.open_tunnel(head) + con.open_tunnel(head, additional_headers) .map(|(head, framed)| (head, framed.map_io(EitherIo::A))), ), EitherConnection::B(con) => Box::new( - con.open_tunnel(head) + con.open_tunnel(head, additional_headers) .map(|(head, framed)| (head, framed.map_io(EitherIo::B))), ), } diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs index 97ed3bbc7..e178fe4ac 100644 --- a/actix-http/src/client/h1proto.rs +++ b/actix-http/src/client/h1proto.rs @@ -1,5 +1,6 @@ use std::io::Write; use std::{io, time}; +use std::rc::Rc; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use bytes::{BufMut, Bytes, BytesMut}; @@ -11,6 +12,7 @@ use crate::h1; use crate::http::header::{IntoHeaderValue, HOST}; use crate::message::{RequestHead, ResponseHead}; use crate::payload::{Payload, PayloadStream}; +use crate::header::HeaderMap; use super::connection::{ConnectionLifetime, ConnectionType, IoConnection}; use super::error::{ConnectError, SendRequestError}; @@ -19,7 +21,8 @@ use crate::body::{BodySize, MessageBody}; pub(crate) fn send_request( io: T, - mut head: RequestHead, + head: Rc, + additional_headers: Option, body: B, created: time::Instant, pool: Option>, @@ -29,7 +32,7 @@ where B: MessageBody, { // set request host header - if !head.headers.contains_key(HOST) { + let additional_headers = if !head.headers.contains_key(HOST) && !additional_headers.iter().any(|h| h.contains_key(HOST)) { if let Some(host) = head.uri.host() { let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); @@ -40,14 +43,23 @@ where match wrt.get_mut().take().freeze().try_into() { Ok(value) => { - head.headers.insert(HOST, value); + let mut headers = additional_headers.unwrap_or(HeaderMap::new()); + headers.insert(HOST, value); + Some(headers) } Err(e) => { log::error!("Can not set HOST header {}", e); + additional_headers } } } + else { + additional_headers + } } + else { + additional_headers + }; let io = H1Connection { created, @@ -59,7 +71,7 @@ where // create Framed and send reqest Framed::new(io, h1::ClientCodec::default()) - .send((head, len).into()) + .send((head, additional_headers, len).into()) .from_err() // send request body .and_then(move |framed| match body.size() { @@ -95,14 +107,15 @@ where pub(crate) fn open_tunnel( io: T, - head: RequestHead, + head: Rc, + additional_headers: Option, ) -> impl Future), Error = SendRequestError> where T: AsyncRead + AsyncWrite + 'static, { // create Framed and send reqest Framed::new(io, h1::ClientCodec::default()) - .send((head, BodySize::None).into()) + .send((head, additional_headers, BodySize::None).into()) .from_err() // read response .and_then(|framed| { diff --git a/actix-http/src/client/h2proto.rs b/actix-http/src/client/h2proto.rs index 91240268e..98ff973cc 100644 --- a/actix-http/src/client/h2proto.rs +++ b/actix-http/src/client/h2proto.rs @@ -1,4 +1,5 @@ use std::time; +use std::rc::Rc; use actix_codec::{AsyncRead, AsyncWrite}; use bytes::Bytes; @@ -11,6 +12,7 @@ use http::{request::Request, HttpTryFrom, Method, Version}; use crate::body::{BodySize, MessageBody}; use crate::message::{RequestHead, ResponseHead}; use crate::payload::Payload; +use crate::header::HeaderMap; use super::connection::{ConnectionType, IoConnection}; use super::error::SendRequestError; @@ -18,7 +20,8 @@ use super::pool::Acquired; pub(crate) fn send_request( io: SendRequest, - head: RequestHead, + head: Rc, + additional_headers: Option, body: B, created: time::Instant, pool: Option>, @@ -39,8 +42,8 @@ where .map_err(SendRequestError::from) .and_then(move |mut io| { let mut req = Request::new(()); - *req.uri_mut() = head.uri; - *req.method_mut() = head.method; + *req.uri_mut() = head.uri.clone(); + *req.method_mut() = head.method.clone(); *req.version_mut() = Version::HTTP_2; let mut skip_len = true; @@ -66,8 +69,16 @@ where ), }; + // merging headers from head and additional headers. HeaderMap::new() does not allocate. + let additional_headers = additional_headers.unwrap_or(HeaderMap::new()); + let headers = head.headers.iter() + .filter(|(name, _)| { + !additional_headers.contains_key(*name) + }) + .chain(additional_headers.iter()); + // copy headers - for (key, value) in head.headers.iter() { + for (key, value) in headers { match *key { CONNECTION | TRANSFER_ENCODING => continue, // http2 specific CONTENT_LENGTH if skip_len => continue, diff --git a/actix-http/src/h1/client.rs b/actix-http/src/h1/client.rs index f93bc496a..59f85e49c 100644 --- a/actix-http/src/h1/client.rs +++ b/actix-http/src/h1/client.rs @@ -1,5 +1,6 @@ #![allow(unused_imports, unused_variables, dead_code)] use std::io::{self, Write}; +use std::rc::Rc; use actix_codec::{Decoder, Encoder}; use bitflags::bitflags; @@ -17,6 +18,7 @@ use crate::config::ServiceConfig; use crate::error::{ParseError, PayloadError}; use crate::helpers; use crate::message::{ConnectionType, Head, MessagePool, RequestHead, ResponseHead}; +use crate::header::HeaderMap; bitflags! { struct Flags: u8 { @@ -48,7 +50,7 @@ struct ClientCodecInner { // encoder part flags: Flags, headers_size: u32, - encoder: encoder::MessageEncoder, + encoder: encoder::MessageEncoder<(Rc, Option)>, } impl Default for ClientCodec { @@ -183,7 +185,7 @@ impl Decoder for ClientPayloadCodec { } impl Encoder for ClientCodec { - type Item = Message<(RequestHead, BodySize)>; + type Item = Message<(Rc, Option, BodySize)>; type Error = io::Error; fn encode( @@ -192,13 +194,13 @@ impl Encoder for ClientCodec { dst: &mut BytesMut, ) -> Result<(), Self::Error> { match item { - Message::Item((mut msg, length)) => { + Message::Item((head, additional_headers, length)) => { let inner = &mut self.inner; - inner.version = msg.version; - inner.flags.set(Flags::HEAD, msg.method == Method::HEAD); + inner.version = head.version; + inner.flags.set(Flags::HEAD, head.method == Method::HEAD); // connection status - inner.ctype = match msg.connection_type() { + inner.ctype = match head.connection_type() { ConnectionType::KeepAlive => { if inner.flags.contains(Flags::KEEPALIVE_ENABLED) { ConnectionType::KeepAlive @@ -212,7 +214,7 @@ impl Encoder for ClientCodec { inner.encoder.encode( dst, - &mut msg, + &mut (head, additional_headers), false, false, inner.version, diff --git a/actix-http/src/h1/encoder.rs b/actix-http/src/h1/encoder.rs index 61ca48b1d..147fa0c5e 100644 --- a/actix-http/src/h1/encoder.rs +++ b/actix-http/src/h1/encoder.rs @@ -4,6 +4,7 @@ use std::io::Write; use std::marker::PhantomData; use std::str::FromStr; use std::{cmp, fmt, io, mem}; +use std::rc::Rc; use bytes::{BufMut, Bytes, BytesMut}; @@ -247,31 +248,32 @@ impl MessageType for Response<()> { } } -impl MessageType for RequestHead { +impl MessageType for (Rc, Option) { fn status(&self) -> Option { None } fn chunked(&self) -> bool { - self.chunked() + self.0.chunked() } fn camel_case(&self) -> bool { - RequestHead::camel_case_headers(self) + RequestHead::camel_case_headers(&self.0) } fn headers(&self) -> &HeaderMap { - &self.headers + &self.0.headers } fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> { - dst.reserve(256 + self.headers.len() * AVERAGE_HEADER_SIZE); + let head = &self.0; + dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE); write!( Writer(dst), "{} {} {}", - self.method, - self.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), - match self.version { + head.method, + head.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), + match head.version { Version::HTTP_09 => "HTTP/0.9", Version::HTTP_10 => "HTTP/1.0", Version::HTTP_11 => "HTTP/1.1", diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 4b564d777..a981b18f9 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -1,4 +1,5 @@ use std::{fmt, io, net}; +use std::rc::Rc; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::body::Body; @@ -7,6 +8,7 @@ use actix_http::client::{ }; use actix_http::h1::ClientCodec; use actix_http::{RequestHead, ResponseHead}; +use actix_http::http::HeaderMap; use actix_service::Service; use futures::{Future, Poll}; @@ -17,7 +19,8 @@ pub(crate) struct ConnectorWrapper(pub T); pub(crate) trait Connect { fn send_request( &mut self, - head: RequestHead, + head: Rc, + additional_headers: Option, body: Body, addr: Option, ) -> Box>; @@ -25,7 +28,8 @@ pub(crate) trait Connect { /// Send request, returns Response and Framed fn open_tunnel( &mut self, - head: RequestHead, + head: Rc, + additional_headers: Option, addr: Option, ) -> Box< Future< @@ -46,7 +50,8 @@ where { fn send_request( &mut self, - head: RequestHead, + head: Rc, + additional_headers: Option, body: Body, addr: Option, ) -> Box> { @@ -59,14 +64,15 @@ where }) .from_err() // send request - .and_then(move |connection| connection.send_request(head, body)) + .and_then(move |connection| connection.send_request(head, additional_headers, body)) .map(|(head, payload)| ClientResponse::new(head, payload)), ) } fn open_tunnel( &mut self, - head: RequestHead, + head: Rc, + additional_headers: Option, addr: Option, ) -> Box< Future< @@ -83,7 +89,7 @@ where }) .from_err() // send request - .and_then(move |connection| connection.open_tunnel(head)) + .and_then(move |connection| connection.open_tunnel(head, additional_headers)) .map(|(head, framed)| { let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); (head, framed) diff --git a/awc/src/request.rs b/awc/src/request.rs index 36cd6fcf3..c6dc9c256 100644 --- a/awc/src/request.rs +++ b/awc/src/request.rs @@ -453,7 +453,7 @@ impl ClientRequest { let fut = config .connector .borrow_mut() - .send_request(head, body.into(), slf.addr) + .send_request(Rc::new(head), None, body.into(), slf.addr) .map(move |res| { res.map_body(|head, payload| { if response_decompress { diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 95bf6ef70..0277110c1 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -286,7 +286,7 @@ impl WebsocketsRequest { .config .connector .borrow_mut() - .open_tunnel(head, self.addr) + .open_tunnel(Rc::new(head), None, self.addr) .from_err() .and_then(move |(head, framed)| { // verify response