diff --git a/actix-files/README.md b/actix-files/README.md index 463f20224..a4f0445aa 100644 --- a/actix-files/README.md +++ b/actix-files/README.md @@ -14,6 +14,6 @@ ## Documentation & Resources - [API Documentation](https://docs.rs/actix-files/) -- [Example Project](https://github.com/actix/examples/tree/master/static_index) +- [Example Project](https://github.com/actix/examples/tree/master/basics/static_index) - [Chat on Gitter](https://gitter.im/actix/actix-web) - Minimum supported Rust version: 1.46 or later diff --git a/actix-files/src/lib.rs b/actix-files/src/lib.rs index 04dd9f07f..3c34c0403 100644 --- a/actix-files/src/lib.rs +++ b/actix-files/src/lib.rs @@ -662,8 +662,12 @@ mod tests { #[actix_rt::test] async fn test_static_files_bad_directory() { - let _st: Files = Files::new("/", "missing"); - let _st: Files = Files::new("/", "Cargo.toml"); + let service = Files::new("/", "./missing").new_service(()).await.unwrap(); + + let req = TestRequest::with_uri("/").to_srv_request(); + let resp = test::call_service(&service, req).await; + + assert_eq!(resp.status(), StatusCode::NOT_FOUND); } #[actix_rt::test] @@ -676,75 +680,34 @@ mod tests { .await .unwrap(); let req = TestRequest::with_uri("/missing").to_srv_request(); - let resp = test::call_service(&st, req).await; + assert_eq!(resp.status(), StatusCode::OK); let bytes = test::read_body(resp).await; assert_eq!(bytes, web::Bytes::from_static(b"default content")); } - // #[actix_rt::test] - // async fn test_serve_index() { - // let st = Files::new(".").index_file("test.binary"); - // let req = TestRequest::default().uri("/tests").finish(); + #[actix_rt::test] + async fn test_serve_index_nested() { + let service = Files::new(".", ".") + .index_file("lib.rs") + .new_service(()) + .await + .unwrap(); - // let resp = st.handle(&req).respond_to(&req).unwrap(); - // let resp = resp.as_msg(); - // assert_eq!(resp.status(), StatusCode::OK); - // assert_eq!( - // resp.headers() - // .get(header::CONTENT_TYPE) - // .expect("content type"), - // "application/octet-stream" - // ); - // assert_eq!( - // resp.headers() - // .get(header::CONTENT_DISPOSITION) - // .expect("content disposition"), - // "attachment; filename=\"test.binary\"" - // ); + let req = TestRequest::default().uri("/src").to_srv_request(); + let resp = test::call_service(&service, req).await; - // let req = TestRequest::default().uri("/tests/").finish(); - // let resp = st.handle(&req).respond_to(&req).unwrap(); - // let resp = resp.as_msg(); - // assert_eq!(resp.status(), StatusCode::OK); - // assert_eq!( - // resp.headers().get(header::CONTENT_TYPE).unwrap(), - // "application/octet-stream" - // ); - // assert_eq!( - // resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), - // "attachment; filename=\"test.binary\"" - // ); - - // // nonexistent index file - // let req = TestRequest::default().uri("/tests/unknown").finish(); - // let resp = st.handle(&req).respond_to(&req).unwrap(); - // let resp = resp.as_msg(); - // assert_eq!(resp.status(), StatusCode::NOT_FOUND); - - // let req = TestRequest::default().uri("/tests/unknown/").finish(); - // let resp = st.handle(&req).respond_to(&req).unwrap(); - // let resp = resp.as_msg(); - // assert_eq!(resp.status(), StatusCode::NOT_FOUND); - // } - - // #[actix_rt::test] - // async fn test_serve_index_nested() { - // let st = Files::new(".").index_file("mod.rs"); - // let req = TestRequest::default().uri("/src/client").finish(); - // let resp = st.handle(&req).respond_to(&req).unwrap(); - // let resp = resp.as_msg(); - // assert_eq!(resp.status(), StatusCode::OK); - // assert_eq!( - // resp.headers().get(header::CONTENT_TYPE).unwrap(), - // "text/x-rust" - // ); - // assert_eq!( - // resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), - // "inline; filename=\"mod.rs\"" - // ); - // } + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/x-rust" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"lib.rs\"" + ); + } #[actix_rt::test] async fn integration_serve_index() { diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 6ba111eb3..165b004a6 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -3,6 +3,7 @@ ## Unreleased - 2021-xx-xx ### Changed * Feature `cookies` is now optional and disabled by default. [#1981] +* `ws::hash_key` now returns array. [#2035] ### Removed * re-export of `futures_channel::oneshot::Canceled` is removed from `error` mod. [#1994] @@ -10,6 +11,7 @@ [#1981]: https://github.com/actix/actix-web/pull/1981 [#1994]: https://github.com/actix/actix-web/pull/1994 +[#2035]: https://github.com/actix/actix-web/pull/2035 ## 3.0.0-beta.3 - 2021-02-10 diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index f78901697..c79ad11b2 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -103,6 +103,10 @@ version = "0.10.9" package = "openssl" features = ["vendored"] +[[example]] +name = "ws" +required-features = ["rustls"] + [[bench]] name = "write-camel-case" harness = false diff --git a/actix-http/examples/ws.rs b/actix-http/examples/ws.rs new file mode 100644 index 000000000..4e03aa8ab --- /dev/null +++ b/actix-http/examples/ws.rs @@ -0,0 +1,107 @@ +//! Sets up a WebSocket server over TCP and TLS. +//! Sends a heartbeat message every 4 seconds but does not respond to any incoming frames. + +extern crate tls_rustls as rustls; + +use std::{ + env, io, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use actix_codec::Encoder; +use actix_http::{error::Error, ws, HttpService, Request, Response}; +use actix_rt::time::{interval, Interval}; +use actix_server::Server; +use bytes::{Bytes, BytesMut}; +use bytestring::ByteString; +use futures_core::{ready, Stream}; + +#[actix_rt::main] +async fn main() -> io::Result<()> { + env::set_var("RUST_LOG", "actix=info,h2_ws=info"); + env_logger::init(); + + Server::build() + .bind("tcp", ("127.0.0.1", 8080), || { + HttpService::build().h1(handler).tcp() + })? + .bind("tls", ("127.0.0.1", 8443), || { + HttpService::build().finish(handler).rustls(tls_config()) + })? + .run() + .await +} + +async fn handler(req: Request) -> Result { + log::info!("handshaking"); + let mut res = ws::handshake(req.head())?; + + // handshake will always fail under HTTP/2 + + log::info!("responding"); + Ok(res.streaming(Heartbeat::new(ws::Codec::new()))) +} + +struct Heartbeat { + codec: ws::Codec, + interval: Interval, +} + +impl Heartbeat { + fn new(codec: ws::Codec) -> Self { + Self { + codec, + interval: interval(Duration::from_secs(4)), + } + } +} + +impl Stream for Heartbeat { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + log::trace!("poll"); + + ready!(self.as_mut().interval.poll_tick(cx)); + + let mut buffer = BytesMut::new(); + + self.as_mut() + .codec + .encode( + ws::Message::Text(ByteString::from_static("hello world")), + &mut buffer, + ) + .unwrap(); + + Poll::Ready(Some(Ok(buffer.freeze()))) + } +} + +fn tls_config() -> rustls::ServerConfig { + use std::io::BufReader; + + use rustls::{ + internal::pemfile::{certs, pkcs8_private_keys}, + NoClientAuth, ServerConfig, + }; + + let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); + let cert_file = cert.serialize_pem().unwrap(); + let key_file = cert.serialize_private_key_pem(); + + let mut config = ServerConfig::new(NoClientAuth::new()); + let cert_file = &mut BufReader::new(cert_file.as_bytes()); + let key_file = &mut BufReader::new(key_file.as_bytes()); + + let cert_chain = certs(cert_file).unwrap(); + let mut keys = pkcs8_private_keys(key_file).unwrap(); + config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + + config +} diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 0490163d5..cec73db96 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -1,4 +1,4 @@ -//! WebSocket protocol. +//! WebSocket protocol implementation. //! //! To setup a WebSocket, first perform the WebSocket handshake then on success convert `Payload` into a //! `WsStream` stream and then use `WsWriter` to communicate with the peer. @@ -8,9 +8,12 @@ use std::io; use derive_more::{Display, Error, From}; use http::{header, Method, StatusCode}; -use crate::error::ResponseError; -use crate::message::RequestHead; -use crate::response::{Response, ResponseBuilder}; +use crate::{ + error::ResponseError, + header::HeaderValue, + message::RequestHead, + response::{Response, ResponseBuilder}, +}; mod codec; mod dispatcher; @@ -89,7 +92,7 @@ pub enum HandshakeError { NoVersionHeader, /// Unsupported WebSocket version. - #[display(fmt = "Unsupported version.")] + #[display(fmt = "Unsupported WebSocket version.")] UnsupportedVersion, /// WebSocket key is not set or wrong. @@ -105,19 +108,19 @@ impl ResponseError for HandshakeError { .finish(), HandshakeError::NoWebsocketUpgrade => Response::BadRequest() - .reason("No WebSocket UPGRADE header found") + .reason("No WebSocket Upgrade header found") .finish(), HandshakeError::NoConnectionUpgrade => Response::BadRequest() - .reason("No CONNECTION upgrade") + .reason("No Connection upgrade") .finish(), HandshakeError::NoVersionHeader => Response::BadRequest() - .reason("Websocket version header is required") + .reason("WebSocket version header is required") .finish(), HandshakeError::UnsupportedVersion => Response::BadRequest() - .reason("Unsupported version") + .reason("Unsupported WebSocket version") .finish(), HandshakeError::BadWebsocketKey => { @@ -193,7 +196,11 @@ pub fn handshake_response(req: &RequestHead) -> ResponseBuilder { Response::build(StatusCode::SWITCHING_PROTOCOLS) .upgrade("websocket") .insert_header((header::TRANSFER_ENCODING, "chunked")) - .insert_header((header::SEC_WEBSOCKET_ACCEPT, key)) + .insert_header(( + header::SEC_WEBSOCKET_ACCEPT, + // key is known to be header value safe ascii + HeaderValue::from_bytes(&key).unwrap(), + )) .take() } diff --git a/actix-http/src/ws/proto.rs b/actix-http/src/ws/proto.rs index 1e8bf7af3..fdcde5eac 100644 --- a/actix-http/src/ws/proto.rs +++ b/actix-http/src/ws/proto.rs @@ -1,5 +1,7 @@ -use std::convert::{From, Into}; -use std::fmt; +use std::{ + convert::{From, Into}, + fmt, +}; /// Operation codes as part of RFC6455. #[derive(Debug, Eq, PartialEq, Clone, Copy)] @@ -28,8 +30,9 @@ pub enum OpCode { impl fmt::Display for OpCode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use self::OpCode::*; - match *self { + use OpCode::*; + + match self { Continue => write!(f, "CONTINUE"), Text => write!(f, "TEXT"), Binary => write!(f, "BINARY"), @@ -44,6 +47,7 @@ impl fmt::Display for OpCode { impl From for u8 { fn from(op: OpCode) -> u8 { use self::OpCode::*; + match op { Continue => 0, Text => 1, @@ -62,6 +66,7 @@ impl From for u8 { impl From for OpCode { fn from(byte: u8) -> OpCode { use self::OpCode::*; + match byte { 0 => Continue, 1 => Text, @@ -77,63 +82,66 @@ impl From for OpCode { /// Status code used to indicate why an endpoint is closing the WebSocket connection. #[derive(Debug, Eq, PartialEq, Clone, Copy)] pub enum CloseCode { - /// Indicates a normal closure, meaning that the purpose for - /// which the connection was established has been fulfilled. + /// Indicates a normal closure, meaning that the purpose for which the connection was + /// established has been fulfilled. Normal, - /// Indicates that an endpoint is "going away", such as a server - /// going down or a browser having navigated away from a page. + + /// Indicates that an endpoint is "going away", such as a server going down or a browser having + /// navigated away from a page. Away, - /// Indicates that an endpoint is terminating the connection due - /// to a protocol error. + + /// Indicates that an endpoint is terminating the connection due to a protocol error. Protocol, - /// Indicates that an endpoint is terminating the connection - /// because it has received a type of data it cannot accept (e.g., an - /// endpoint that understands only text data MAY send this if it + + /// Indicates that an endpoint is terminating the connection because it has received a type of + /// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if it /// receives a binary message). Unsupported, - /// Indicates an abnormal closure. If the abnormal closure was due to an - /// error, this close code will not be used. Instead, the `on_error` method - /// of the handler will be called with the error. However, if the connection - /// is simply dropped, without an error, this close code will be sent to the - /// handler. + + /// Indicates an abnormal closure. If the abnormal closure was due to an error, this close code + /// will not be used. Instead, the `on_error` method of the handler will be called with + /// the error. However, if the connection is simply dropped, without an error, this close code + /// will be sent to the handler. Abnormal, - /// Indicates that an endpoint is terminating the connection - /// because it has received data within a message that was not - /// consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\] + + /// Indicates that an endpoint is terminating the connection because it has received data within + /// a message that was not consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\] /// data within a text message). Invalid, - /// Indicates that an endpoint is terminating the connection - /// because it has received a message that violates its policy. This - /// is a generic status code that can be returned when there is no - /// other more suitable status code (e.g., Unsupported or Size) or if there - /// is a need to hide specific details about the policy. + + /// Indicates that an endpoint is terminating the connection because it has received a message + /// that violates its policy. This is a generic status code that can be returned when there is + /// no other more suitable status code (e.g., Unsupported or Size) or if there is a need to hide + /// specific details about the policy. Policy, - /// Indicates that an endpoint is terminating the connection - /// because it has received a message that is too big for it to - /// process. + + /// Indicates that an endpoint is terminating the connection because it has received a message + /// that is too big for it to process. Size, - /// Indicates that an endpoint (client) is terminating the - /// connection because it has expected the server to negotiate one or - /// more extension, but the server didn't return them in the response - /// message of the WebSocket handshake. The list of extensions that - /// are needed should be given as the reason for closing. - /// Note that this status code is not used by the server, because it - /// can fail the WebSocket handshake instead. + + /// Indicates that an endpoint (client) is terminating the connection because it has expected + /// the server to negotiate one or more extension, but the server didn't return them in the + /// response message of the WebSocket handshake. The list of extensions that are needed should + /// be given as the reason for closing. Note that this status code is not used by the server, + /// because it can fail the WebSocket handshake instead. Extension, - /// Indicates that a server is terminating the connection because - /// it encountered an unexpected condition that prevented it from - /// fulfilling the request. + + /// Indicates that a server is terminating the connection because it encountered an unexpected + /// condition that prevented it from fulfilling the request. Error, - /// Indicates that the server is restarting. A client may choose to - /// reconnect, and if it does, it should use a randomized delay of 5-30 - /// seconds between attempts. + + /// Indicates that the server is restarting. A client may choose to reconnect, and if it does, + /// it should use a randomized delay of 5-30 seconds between attempts. Restart, - /// Indicates that the server is overloaded and the client should either - /// connect to a different IP (when multiple targets exist), or - /// reconnect to the same IP when a user has performed an action. + + /// Indicates that the server is overloaded and the client should either connect to a different + /// IP (when multiple targets exist), or reconnect to the same IP when a user has performed + /// an action. Again, + #[doc(hidden)] Tls, + #[doc(hidden)] Other(u16), } @@ -141,6 +149,7 @@ pub enum CloseCode { impl From for u16 { fn from(code: CloseCode) -> u16 { use self::CloseCode::*; + match code { Normal => 1000, Away => 1001, @@ -163,6 +172,7 @@ impl From for u16 { impl From for CloseCode { fn from(code: u16) -> CloseCode { use self::CloseCode::*; + match code { 1000 => Normal, 1001 => Away, @@ -210,17 +220,29 @@ impl> From<(CloseCode, T)> for CloseReason { } } -static WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +/// The WebSocket GUID as stated in the spec. See https://tools.ietf.org/html/rfc6455#section-1.3. +static WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; -// TODO: hash is always same size, we don't need String -pub fn hash_key(key: &[u8]) -> String { - use sha1::Digest; - let mut hasher = sha1::Sha1::new(); +/// Hashes the `Sec-WebSocket-Key` header according to the WebSocket spec. +/// +/// Result is a Base64 encoded byte array. `base64(sha1(input))` is always 28 bytes. +pub fn hash_key(key: &[u8]) -> [u8; 28] { + let hash = { + use sha1::Digest as _; - hasher.update(key); - hasher.update(WS_GUID.as_bytes()); + let mut hasher = sha1::Sha1::new(); - base64::encode(&hasher.finalize()) + hasher.update(key); + hasher.update(WS_GUID); + + hasher.finalize() + }; + + let mut hash_b64 = [0; 28]; + let n = base64::encode_config_slice(&hash, base64::STANDARD, &mut hash_b64); + assert_eq!(n, 28); + + hash_b64 } #[cfg(test)] @@ -288,11 +310,11 @@ mod test { #[test] fn test_hash_key() { let hash = hash_key(b"hello actix-web"); - assert_eq!(&hash, "cR1dlyUUJKp0s/Bel25u5TgvC3E="); + assert_eq!(&hash, b"cR1dlyUUJKp0s/Bel25u5TgvC3E="); } #[test] - fn closecode_from_u16() { + fn close_code_from_u16() { assert_eq!(CloseCode::from(1000u16), CloseCode::Normal); assert_eq!(CloseCode::from(1001u16), CloseCode::Away); assert_eq!(CloseCode::from(1002u16), CloseCode::Protocol); @@ -310,7 +332,7 @@ mod test { } #[test] - fn closecode_into_u16() { + fn close_code_into_u16() { assert_eq!(1000u16, Into::::into(CloseCode::Normal)); assert_eq!(1001u16, Into::::into(CloseCode::Away)); assert_eq!(1002u16, Into::::into(CloseCode::Protocol)); diff --git a/actix-web-actors/src/ws.rs b/actix-web-actors/src/ws.rs index 1ab4cfce5..de2802d21 100644 --- a/actix-web-actors/src/ws.rs +++ b/actix-web-actors/src/ws.rs @@ -15,10 +15,13 @@ use actix::{ SpawnHandle, }; use actix_codec::{Decoder, Encoder}; -use actix_http::ws::{hash_key, Codec}; pub use actix_http::ws::{ CloseCode, CloseReason, Frame, HandshakeError, Message, ProtocolError, }; +use actix_http::{ + http::HeaderValue, + ws::{hash_key, Codec}, +}; use actix_web::dev::HttpResponseBuilder; use actix_web::error::{Error, PayloadError}; use actix_web::http::{header, Method, StatusCode}; @@ -162,7 +165,11 @@ pub fn handshake_with_protocols( let mut response = HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS) .upgrade("websocket") - .insert_header((header::SEC_WEBSOCKET_ACCEPT, key)) + .insert_header(( + header::SEC_WEBSOCKET_ACCEPT, + // key is known to be header value safe ascii + HeaderValue::from_bytes(&key).unwrap(), + )) .take(); if let Some(protocol) = protocol { diff --git a/awc/Cargo.toml b/awc/Cargo.toml index 8cbba432c..ca345d3cb 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -47,7 +47,7 @@ trust-dns = ["actix-http/trust-dns"] actix-codec = "0.4.0-beta.1" actix-service = "2.0.0-beta.4" actix-http = "3.0.0-beta.3" -actix-rt = "2.1" +actix-rt = { version = "2.1", default-features = false } base64 = "0.13" bytes = "1" @@ -57,6 +57,7 @@ futures-core = { version = "0.3.7", default-features = false } log =" 0.4" mime = "0.3" percent-encoding = "2.1" +pin-project-lite = "0.2" rand = "0.8" serde = "1.0" serde_json = "1.0" diff --git a/awc/README.md b/awc/README.md index 043ae6a41..1f6e3b8fb 100644 --- a/awc/README.md +++ b/awc/README.md @@ -11,11 +11,12 @@ ## Documentation & Resources - [API Documentation](https://docs.rs/awc) -- [Example Project](https://github.com/actix/examples/tree/HEAD/awc_https) +- [Example Project](https://github.com/actix/examples/tree/HEAD/security/awc_https) - [Chat on Gitter](https://gitter.im/actix/actix-web) - Minimum Supported Rust Version (MSRV): 1.46.0 ## Example + ```rust use actix_rt::System; use awc::Client; diff --git a/awc/src/builder.rs b/awc/src/builder.rs index b7cdefd40..363056c02 100644 --- a/awc/src/builder.rs +++ b/awc/src/builder.rs @@ -10,24 +10,27 @@ use actix_http::{ http::{self, header, Error as HttpError, HeaderMap, HeaderName, Uri}, }; use actix_rt::net::TcpStream; -use actix_service::Service; +use actix_service::{boxed, Service}; -use crate::connect::ConnectorWrapper; -use crate::{Client, ClientConfig}; +use crate::connect::DefaultConnector; +use crate::error::SendRequestError; +use crate::middleware::{NestTransform, Transform}; +use crate::{Client, ClientConfig, ConnectRequest, ConnectResponse, ConnectorService}; /// An HTTP Client builder /// /// This type can be used to construct an instance of `Client` through a /// builder-like pattern. -pub struct ClientBuilder { +pub struct ClientBuilder { default_headers: bool, max_http_version: Option, stream_window_size: Option, conn_window_size: Option, headers: HeaderMap, timeout: Option, + connector: Connector, + middleware: M, local_address: Option, - connector: Connector, } impl ClientBuilder { @@ -39,8 +42,10 @@ impl ClientBuilder { Error = TcpConnectError, > + Clone, TcpStream, + (), > { ClientBuilder { + middleware: (), default_headers: true, headers: HeaderMap::new(), timeout: Some(Duration::from_secs(5)), @@ -53,7 +58,7 @@ impl ClientBuilder { } } -impl ClientBuilder +impl ClientBuilder where S: Service, Response = TcpConnection, Error = TcpConnectError> + Clone @@ -61,7 +66,7 @@ where Io: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, { /// Use custom connector service. - pub fn connector(self, connector: Connector) -> ClientBuilder + pub fn connector(self, connector: Connector) -> ClientBuilder where S1: Service< TcpConnect, @@ -72,10 +77,11 @@ where Io1: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, { ClientBuilder { + middleware: self.middleware, default_headers: self.default_headers, headers: self.headers, timeout: self.timeout, - local_address: None, + local_address: self.local_address, connector, max_http_version: self.max_http_version, stream_window_size: self.stream_window_size, @@ -181,8 +187,38 @@ where self.header(header::AUTHORIZATION, format!("Bearer {}", token)) } + /// Registers middleware, in the form of a middleware component (type), + /// that runs during inbound and/or outbound processing in the request + /// life-cycle (request -> response), modifying request/response as + /// necessary, across all requests managed by the Client. + pub fn wrap( + self, + mw: M1, + ) -> ClientBuilder> + where + M: Transform, + M1: Transform, + { + ClientBuilder { + middleware: NestTransform::new(self.middleware, mw), + default_headers: self.default_headers, + max_http_version: self.max_http_version, + stream_window_size: self.stream_window_size, + conn_window_size: self.conn_window_size, + headers: self.headers, + timeout: self.timeout, + connector: self.connector, + local_address: self.local_address, + } + } + /// Finish build process and create `Client` instance. - pub fn finish(self) -> Client { + pub fn finish(self) -> Client + where + M: Transform + 'static, + M::Transform: + Service, + { let mut connector = self.connector; if let Some(val) = self.max_http_version { @@ -198,10 +234,13 @@ where connector = connector.local_address(val); } + let connector = boxed::service(DefaultConnector::new(connector.finish())); + let connector = boxed::service(self.middleware.new_transform(connector)); + let config = ClientConfig { headers: self.headers, timeout: self.timeout, - connector: Box::new(ConnectorWrapper::new(connector.finish())) as _, + connector, }; Client(Rc::new(config)) diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 97af2d1cc..a4abbc46b 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -1,5 +1,7 @@ use std::{ - fmt, io, net, + fmt, + future::Future, + io, net, pin::Pin, task::{Context, Poll}, }; @@ -9,24 +11,14 @@ use actix_http::{ body::Body, client::{Connect as ClientConnect, ConnectError, Connection, SendRequestError}, h1::ClientCodec, - RequestHead, RequestHeadType, ResponseHead, + Payload, RequestHead, RequestHeadType, ResponseHead, }; use actix_service::Service; -use futures_core::future::LocalBoxFuture; +use futures_core::{future::LocalBoxFuture, ready}; use crate::response::ClientResponse; -pub(crate) struct ConnectorWrapper { - connector: T, -} - -impl ConnectorWrapper { - pub(crate) fn new(connector: T) -> Self { - Self { connector } - } -} - -pub type ConnectService = Box< +pub type ConnectorService = Box< dyn Service< ConnectRequest, Response = ConnectResponse, @@ -65,16 +57,25 @@ impl ConnectResponse { } } -impl Service for ConnectorWrapper +pub(crate) struct DefaultConnector { + connector: S, +} + +impl DefaultConnector { + pub(crate) fn new(connector: S) -> Self { + Self { connector } + } +} + +impl Service for DefaultConnector where - T: Service, - T::Response: Connection, - ::Io: 'static, - T::Future: 'static, + S: Service, + S::Response: Connection, + ::Io: 'static, { type Response = ConnectResponse; type Error = SendRequestError; - type Future = LocalBoxFuture<'static, Result>; + type Future = ConnectRequestFuture::Io>; actix_service::forward_ready!(connector); @@ -91,26 +92,76 @@ where }), }; - Box::pin(async move { - let connection = fut.await?; + ConnectRequestFuture::Connection { + fut, + req: Some(req), + } + } +} - match req { - ConnectRequest::Client(head, body, ..) => { - // send request - let (head, payload) = connection.send_request(head, body).await?; +pin_project_lite::pin_project! { + #[project = ConnectRequestProj] + pub(crate) enum ConnectRequestFuture { + Connection { + #[pin] + fut: Fut, + req: Option + }, + Client { + fut: LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> + }, + Tunnel { + fut: LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, + >, + } + } +} - Ok(ConnectResponse::Client(ClientResponse::new(head, payload))) - } - ConnectRequest::Tunnel(head, ..) => { - // send request - let (head, framed) = - connection.open_tunnel(RequestHeadType::from(head)).await?; - - let framed = framed.into_map_io(|io| BoxedSocket(Box::new(Socket(io)))); - Ok(ConnectResponse::Tunnel(head, framed)) +impl Future for ConnectRequestFuture +where + Fut: Future>, + C: Connection, + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project() { + ConnectRequestProj::Connection { fut, req } => { + let connection = ready!(fut.poll(cx))?; + let req = req.take().unwrap(); + match req { + ConnectRequest::Client(head, body, ..) => { + // send request + let fut = ConnectRequestFuture::Client { + fut: connection.send_request(head, body), + }; + self.as_mut().set(fut); + } + ConnectRequest::Tunnel(head, ..) => { + // send request + let fut = ConnectRequestFuture::Tunnel { + fut: connection.open_tunnel(RequestHeadType::from(head)), + }; + self.as_mut().set(fut); + } } + self.poll(cx) } - }) + ConnectRequestProj::Client { fut } => { + let (head, payload) = ready!(fut.as_mut().poll(cx))?; + Poll::Ready(Ok(ConnectResponse::Client(ClientResponse::new( + head, payload, + )))) + } + ConnectRequestProj::Tunnel { fut } => { + let (head, framed) = ready!(fut.as_mut().poll(cx))?; + let framed = framed.into_map_io(|io| BoxedSocket(Box::new(Socket(io)))); + Poll::Ready(Ok(ConnectResponse::Tunnel(head, framed))) + } + } } } diff --git a/awc/src/error.rs b/awc/src/error.rs index f86224e62..b715f6213 100644 --- a/awc/src/error.rs +++ b/awc/src/error.rs @@ -18,24 +18,31 @@ pub enum WsClientError { /// Invalid response status #[display(fmt = "Invalid response status")] InvalidResponseStatus(StatusCode), + /// Invalid upgrade header #[display(fmt = "Invalid upgrade header")] InvalidUpgradeHeader, + /// Invalid connection header #[display(fmt = "Invalid connection header")] InvalidConnectionHeader(HeaderValue), - /// Missing CONNECTION header - #[display(fmt = "Missing CONNECTION header")] + + /// Missing Connection header + #[display(fmt = "Missing Connection header")] MissingConnectionHeader, - /// Missing SEC-WEBSOCKET-ACCEPT header - #[display(fmt = "Missing SEC-WEBSOCKET-ACCEPT header")] + + /// Missing Sec-Websocket-Accept header + #[display(fmt = "Missing Sec-Websocket-Accept header")] MissingWebSocketAcceptHeader, + /// Invalid challenge response #[display(fmt = "Invalid challenge response")] - InvalidChallengeResponse(String, HeaderValue), + InvalidChallengeResponse([u8; 28], HeaderValue), + /// Protocol error #[display(fmt = "{}", _0)] Protocol(WsProtocolError), + /// Send request error #[display(fmt = "{}", _0)] SendRequest(SendRequestError), diff --git a/awc/src/lib.rs b/awc/src/lib.rs index 66ff55402..2f48dca79 100644 --- a/awc/src/lib.rs +++ b/awc/src/lib.rs @@ -107,12 +107,13 @@ use actix_http::{ RequestHead, }; use actix_rt::net::TcpStream; -use actix_service::Service; +use actix_service::{boxed, Service}; mod builder; mod connect; pub mod error; mod frozen; +pub mod middleware; mod request; mod response; mod sender; @@ -120,14 +121,12 @@ pub mod test; pub mod ws; pub use self::builder::ClientBuilder; -pub use self::connect::{BoxedSocket, ConnectRequest, ConnectResponse, ConnectService}; +pub use self::connect::{BoxedSocket, ConnectRequest, ConnectResponse, ConnectorService}; pub use self::frozen::{FrozenClientRequest, FrozenSendBuilder}; pub use self::request::ClientRequest; pub use self::response::{ClientResponse, JsonBody, MessageBody}; pub use self::sender::SendClientRequest; -use self::connect::ConnectorWrapper; - /// An asynchronous HTTP and WebSocket client. /// /// ## Examples @@ -151,7 +150,7 @@ use self::connect::ConnectorWrapper; pub struct Client(Rc); pub(crate) struct ClientConfig { - pub(crate) connector: ConnectService, + pub(crate) connector: ConnectorService, pub(crate) headers: HeaderMap, pub(crate) timeout: Option, } @@ -159,7 +158,9 @@ pub(crate) struct ClientConfig { impl Default for Client { fn default() -> Self { Client(Rc::new(ClientConfig { - connector: Box::new(ConnectorWrapper::new(Connector::new().finish())), + connector: boxed::service(self::connect::DefaultConnector::new( + Connector::new().finish(), + )), headers: HeaderMap::new(), timeout: Some(Duration::from_secs(5)), })) diff --git a/awc/src/middleware/mod.rs b/awc/src/middleware/mod.rs new file mode 100644 index 000000000..330e3b7fe --- /dev/null +++ b/awc/src/middleware/mod.rs @@ -0,0 +1,71 @@ +mod redirect; + +pub use self::redirect::Redirect; + +use std::marker::PhantomData; + +use actix_service::Service; + +/// Trait for transform a type to another one. +/// Both the input and output type should impl [actix_service::Service] trait. +pub trait Transform { + type Transform: Service; + + /// Creates and returns a new Transform component. + fn new_transform(self, service: S) -> Self::Transform; +} + +#[doc(hidden)] +/// Helper struct for constructing Nested types that would call `Transform::new_transform` +/// in a chain. +/// +/// The child field would be called first and the output `Service` type is +/// passed to parent as input type. +pub struct NestTransform +where + T1: Transform, + T2: Transform, +{ + child: T1, + parent: T2, + _service: PhantomData<(S, Req)>, +} + +impl NestTransform +where + T1: Transform, + T2: Transform, +{ + pub(crate) fn new(child: T1, parent: T2) -> Self { + NestTransform { + child, + parent, + _service: PhantomData, + } + } +} + +impl Transform for NestTransform +where + T1: Transform, + T2: Transform, +{ + type Transform = T2::Transform; + + fn new_transform(self, service: S) -> Self::Transform { + let service = self.child.new_transform(service); + self.parent.new_transform(service) + } +} + +/// Dummy impl for kick start `NestTransform` type in `ClientBuilder` type +impl Transform for () +where + S: Service, +{ + type Transform = S; + + fn new_transform(self, service: S) -> Self::Transform { + service + } +} diff --git a/awc/src/middleware/redirect.rs b/awc/src/middleware/redirect.rs new file mode 100644 index 000000000..1d0ace166 --- /dev/null +++ b/awc/src/middleware/redirect.rs @@ -0,0 +1,350 @@ +use std::{ + convert::TryFrom, + future::Future, + net::SocketAddr, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; + +use actix_http::{ + body::Body, + client::{InvalidUrl, SendRequestError}, + http::{header, Method, StatusCode, Uri}, + RequestHead, RequestHeadType, +}; +use actix_service::Service; +use bytes::Bytes; +use futures_core::ready; + +use super::Transform; + +use crate::connect::{ConnectRequest, ConnectResponse}; +use crate::ClientResponse; + +pub struct Redirect { + max_redirect_times: u8, +} + +impl Default for Redirect { + fn default() -> Self { + Self::new() + } +} + +impl Redirect { + pub fn new() -> Self { + Self { + max_redirect_times: 10, + } + } + + pub fn max_redirect_times(mut self, times: u8) -> Self { + self.max_redirect_times = times; + self + } +} + +impl Transform for Redirect +where + S: Service + 'static, +{ + type Transform = RedirectService; + + fn new_transform(self, service: S) -> Self::Transform { + RedirectService { + max_redirect_times: self.max_redirect_times, + connector: Rc::new(service), + } + } +} + +pub struct RedirectService { + max_redirect_times: u8, + connector: Rc, +} + +impl Service for RedirectService +where + S: Service + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = RedirectServiceFuture; + + actix_service::forward_ready!(connector); + + fn call(&self, req: ConnectRequest) -> Self::Future { + match req { + ConnectRequest::Tunnel(head, addr) => { + let fut = self.connector.call(ConnectRequest::Tunnel(head, addr)); + RedirectServiceFuture::Tunnel { fut } + } + ConnectRequest::Client(head, body, addr) => { + let connector = self.connector.clone(); + let max_redirect_times = self.max_redirect_times; + + // backup the uri and method for reuse schema and authority. + let (uri, method) = match head { + RequestHeadType::Owned(ref head) => (head.uri.clone(), head.method.clone()), + RequestHeadType::Rc(ref head, ..) => { + (head.uri.clone(), head.method.clone()) + } + }; + + let body_opt = match body { + Body::Bytes(ref b) => Some(b.clone()), + _ => None, + }; + + let fut = connector.call(ConnectRequest::Client(head, body, addr)); + + RedirectServiceFuture::Client { + fut, + max_redirect_times, + uri: Some(uri), + method: Some(method), + body: body_opt, + addr, + connector: Some(connector), + } + } + } + } +} + +pin_project_lite::pin_project! { + #[project = RedirectServiceProj] + pub enum RedirectServiceFuture + where + S: Service, + S: 'static + { + Tunnel { #[pin] fut: S::Future }, + Client { + #[pin] + fut: S::Future, + max_redirect_times: u8, + uri: Option, + method: Option, + body: Option, + addr: Option, + connector: Option> + } + } +} + +impl Future for RedirectServiceFuture +where + S: Service + 'static, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project() { + RedirectServiceProj::Tunnel { fut } => fut.poll(cx), + RedirectServiceProj::Client { + fut, + max_redirect_times, + uri, + method, + body, + addr, + connector, + } => match ready!(fut.poll(cx))? { + ConnectResponse::Client(res) => match res.head().status { + StatusCode::MOVED_PERMANENTLY + | StatusCode::FOUND + | StatusCode::SEE_OTHER + if *max_redirect_times > 0 => + { + let org_uri = uri.take().unwrap(); + // rebuild uri from the location header value. + let uri = rebuild_uri(&res, org_uri)?; + + // reset method + let method = method.take().unwrap(); + let method = match method { + Method::GET | Method::HEAD => method, + _ => Method::GET, + }; + + // take ownership of states that could be reused + let addr = addr.take(); + let connector = connector.take(); + let mut max_redirect_times = *max_redirect_times; + + // use a new request head. + let mut head = RequestHead::default(); + head.uri = uri.clone(); + head.method = method.clone(); + + let head = RequestHeadType::Owned(head); + + max_redirect_times -= 1; + + let fut = connector + .as_ref() + .unwrap() + // remove body + .call(ConnectRequest::Client(head, Body::None, addr)); + + self.as_mut().set(RedirectServiceFuture::Client { + fut, + max_redirect_times, + uri: Some(uri), + method: Some(method), + // body is dropped on 301,302,303 + body: None, + addr, + connector, + }); + + self.poll(cx) + } + StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT + if *max_redirect_times > 0 => + { + let org_uri = uri.take().unwrap(); + // rebuild uri from the location header value. + let uri = rebuild_uri(&res, org_uri)?; + + // try to reuse body + let body = body.take(); + let body_new = match body { + Some(ref bytes) => Body::Bytes(bytes.clone()), + // TODO: should this be Body::Empty or Body::None. + _ => Body::Empty, + }; + + let addr = addr.take(); + let method = method.take().unwrap(); + let connector = connector.take(); + let mut max_redirect_times = *max_redirect_times; + + // use a new request head. + let mut head = RequestHead::default(); + head.uri = uri.clone(); + head.method = method.clone(); + + let head = RequestHeadType::Owned(head); + + max_redirect_times -= 1; + + let fut = connector + .as_ref() + .unwrap() + .call(ConnectRequest::Client(head, body_new, addr)); + + self.as_mut().set(RedirectServiceFuture::Client { + fut, + max_redirect_times, + uri: Some(uri), + method: Some(method), + body, + addr, + connector, + }); + + self.poll(cx) + } + _ => Poll::Ready(Ok(ConnectResponse::Client(res))), + }, + _ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"), + }, + } + } +} + +fn rebuild_uri(res: &ClientResponse, org_uri: Uri) -> Result { + let uri = res + .headers() + .get(header::LOCATION) + .map(|value| { + // try to parse the location to a full uri + let uri = Uri::try_from(value.as_bytes()) + .map_err(|e| SendRequestError::Url(InvalidUrl::HttpError(e.into())))?; + if uri.scheme().is_none() || uri.authority().is_none() { + let uri = Uri::builder() + .scheme(org_uri.scheme().cloned().unwrap()) + .authority(org_uri.authority().cloned().unwrap()) + .path_and_query(value.as_bytes()) + .build()?; + Ok::<_, SendRequestError>(uri) + } else { + Ok(uri) + } + }) + // TODO: this error type is wrong. + .ok_or(SendRequestError::Url(InvalidUrl::MissingScheme))??; + + Ok(uri) +} + +#[cfg(test)] +mod tests { + use actix_web::{test::start, web, App, Error, HttpResponse}; + + use super::*; + + use crate::ClientBuilder; + + #[actix_rt::test] + async fn test_basic_redirect() { + let client = ClientBuilder::new() + .connector(crate::Connector::new()) + .wrap(Redirect::new().max_redirect_times(10)) + .finish(); + + let srv = start(|| { + App::new() + .service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>(HttpResponse::BadRequest()) + }))) + .service(web::resource("/").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Found() + .append_header(("location", "/test")) + .finish(), + ) + }))) + }); + + let res = client.get(srv.url("/")).send().await.unwrap(); + + assert_eq!(res.status().as_u16(), 400); + } + + #[actix_rt::test] + async fn test_redirect_limit() { + let client = ClientBuilder::new() + .wrap(Redirect::new().max_redirect_times(1)) + .connector(crate::Connector::new()) + .finish(); + + let srv = start(|| { + App::new() + .service(web::resource("/").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Found() + .append_header(("location", "/test")) + .finish(), + ) + }))) + .service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Found() + .append_header(("location", "/test2")) + .finish(), + ) + }))) + .service(web::resource("/test2").route(web::to(|| async { + Ok::<_, Error>(HttpResponse::BadRequest()) + }))) + }); + + let res = client.get(srv.url("/")).send().await.unwrap(); + + assert_eq!(res.status().as_u16(), 302); + } +} diff --git a/awc/src/response.rs b/awc/src/response.rs index 514b8a90b..40de3dc17 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -492,9 +492,7 @@ mod tests { JsonPayloadError::Payload(PayloadError::Overflow) => { matches!(other, JsonPayloadError::Payload(PayloadError::Overflow)) } - JsonPayloadError::ContentType => { - matches!(other, JsonPayloadError::ContentType) - } + JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType), _ => false, } } diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 5f4570963..1aa426ac7 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -381,12 +381,14 @@ impl WebsocketsRequest { if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) { let encoded = ws::hash_key(key.as_ref()); - if hdr_key.as_bytes() != encoded.as_bytes() { + + if hdr_key.as_bytes() != &encoded { log::trace!( - "Invalid challenge response: expected: {} received: {:?}", - encoded, + "Invalid challenge response: expected: {:?} received: {:?}", + &encoded, key ); + return Err(WsClientError::InvalidChallengeResponse( encoded, hdr_key.clone(), diff --git a/codecov.yml b/codecov.yml index e6bc40203..e45672bfc 100644 --- a/codecov.yml +++ b/codecov.yml @@ -4,10 +4,10 @@ coverage: status: project: default: - threshold: 10% # make CI green + threshold: 100% # make CI green patch: default: - threshold: 10% # make CI green + threshold: 100% # make CI green ignore: # ignore code coverage on following paths - "**/tests" diff --git a/examples/on_connect.rs b/examples/on_connect.rs index ba5a18f3f..24ac86c6b 100644 --- a/examples/on_connect.rs +++ b/examples/on_connect.rs @@ -2,7 +2,7 @@ //! properties and pass them to a handler through request-local data. //! //! For an example of extracting a client TLS certificate, see: -//! +//! use std::{any::Any, io, net::SocketAddr}; diff --git a/src/types/form.rs b/src/types/form.rs index 0b5c3c1b4..4f3ecbe7c 100644 --- a/src/types/form.rs +++ b/src/types/form.rs @@ -400,12 +400,8 @@ mod tests { UrlencodedError::Overflow { .. } => { matches!(other, UrlencodedError::Overflow { .. }) } - UrlencodedError::UnknownLength => { - matches!(other, UrlencodedError::UnknownLength) - } - UrlencodedError::ContentType => { - matches!(other, UrlencodedError::ContentType) - } + UrlencodedError::UnknownLength => matches!(other, UrlencodedError::UnknownLength), + UrlencodedError::ContentType => matches!(other, UrlencodedError::ContentType), _ => false, } } diff --git a/src/types/json.rs b/src/types/json.rs index 31ff680f4..866d835f2 100644 --- a/src/types/json.rs +++ b/src/types/json.rs @@ -441,9 +441,7 @@ mod tests { fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool { match err { JsonPayloadError::Overflow => matches!(other, JsonPayloadError::Overflow), - JsonPayloadError::ContentType => { - matches!(other, JsonPayloadError::ContentType) - } + JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType), _ => false, } }