diff --git a/.travis.yml b/.travis.yml index 82db86a6e..683f77cc5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,9 +10,9 @@ matrix: include: - rust: stable - rust: beta - - rust: nightly-2019-08-10 + - rust: nightly-2019-11-20 allow_failures: - - rust: nightly-2019-08-10 + - rust: nightly-2019-11-20 env: global: @@ -25,7 +25,7 @@ before_install: - sudo apt-get install -y openssl libssl-dev libelf-dev libdw-dev cmake gcc binutils-dev libiberty-dev before_cache: | - if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-08-10" ]]; then + if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install --version 0.6.11 cargo-tarpaulin fi @@ -37,8 +37,8 @@ script: - cargo update - cargo check --all --no-default-features - cargo test --all-features --all -- --nocapture - - cd actix-http; cargo test --no-default-features --features="rust-tls" -- --nocapture; cd .. - - cd awc; cargo test --no-default-features --features="rust-tls" -- --nocapture; cd .. + # - cd actix-http; cargo test --no-default-features --features="rustls" -- --nocapture; cd .. + # - cd awc; cargo test --no-default-features --features="rustls" -- --nocapture; cd .. # Upload docs after_success: @@ -51,7 +51,7 @@ after_success: echo "Uploaded documentation" fi - | - if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-08-10" ]]; then + if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then taskset -c 0 cargo tarpaulin --out Xml --all --all-features bash <(curl -s https://codecov.io/bash) echo "Uploaded code coverage" diff --git a/CHANGES.md b/CHANGES.md index 4ff7d1e66..bb17a7efc 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,11 +1,16 @@ # Changes -## [1.0.9] - 2019-xx-xx +## [1.0.9] - 2019-11-14 ### Added * Add `Payload::into_inner` method and make stored `def::Payload` public. (#1110) +### Changed + +* Support `Host` guards when the `Host` header is unset (e.g. HTTP/2 requests) (#1129) + + ## [1.0.8] - 2019-09-25 ### Added diff --git a/Cargo.toml b/Cargo.toml index 35ca28b2c..7524bff73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web" -version = "1.0.8" +version = "2.0.0-alpha.1" authors = ["Nikolay Kim "] description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust." readme = "README.md" @@ -16,7 +16,7 @@ exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] edition = "2018" [package.metadata.docs.rs] -features = ["ssl", "brotli", "flate2-zlib", "secure-cookies", "client", "rust-tls", "uds"] +features = ["openssl", "rustls", "brotli", "flate2-zlib", "secure-cookies", "client"] [badges] travis-ci = { repository = "actix/actix-web", branch = "master" } @@ -63,37 +63,35 @@ secure-cookies = ["actix-http/secure-cookies"] fail = ["actix-http/fail"] # openssl -ssl = ["openssl", "actix-server/ssl", "awc/ssl"] +openssl = ["open-ssl", "actix-server/openssl", "awc/openssl"] # rustls -rust-tls = ["rustls", "actix-server/rust-tls", "awc/rust-tls"] - -# unix domain sockets support -uds = ["actix-server/uds"] +# rustls = ["rust-tls", "actix-server/rustls", "awc/rustls"] [dependencies] -actix-codec = "0.1.2" -actix-service = "0.4.1" -actix-utils = "0.4.4" +actix-codec = "0.2.0-alpha.1" +actix-service = "1.0.0-alpha.1" +actix-utils = "0.5.0-alpha.1" actix-router = "0.1.5" -actix-rt = "0.2.4" -actix-web-codegen = "0.1.2" -actix-http = "0.2.9" -actix-server = "0.6.1" -actix-server-config = "0.1.2" -actix-testing = "0.1.0" -actix-threadpool = "0.1.1" -awc = { version = "0.2.7", optional = true } +actix-rt = "1.0.0-alpha.1" +actix-web-codegen = "0.2.0-alpha.1" +actix-http = "0.3.0-alpha.1" +actix-server = "0.8.0-alpha.1" +actix-server-config = "0.3.0-alpha.1" +actix-testing = "0.3.0-alpha.1" +actix-threadpool = "0.2.0-alpha.1" +awc = { version = "0.3.0-alpha.1", optional = true } bytes = "0.4" derive_more = "0.15.0" encoding_rs = "0.8" -futures = "0.1.25" -hashbrown = "0.5.0" +futures = "0.3.1" +hashbrown = "0.6.3" log = "0.4" mime = "0.3" net2 = "0.2.33" parking_lot = "0.9" +pin-project = "0.4.5" regex = "1.0" serde = { version = "1.0", features=["derive"] } serde_json = "1.0" @@ -102,17 +100,17 @@ time = "0.1.42" url = "2.1" # ssl support -openssl = { version="0.10", optional = true } -rustls = { version = "0.15", optional = true } +open-ssl = { version="0.10", package="openssl", optional = true } +rust-tls = { version = "0.16", package="rustls", optional = true } [dev-dependencies] -actix = "0.8.3" -actix-connect = "0.2.2" -actix-http-test = "0.2.4" +# actix = "0.8.3" +actix-connect = "0.3.0-alpha.1" +actix-http-test = "0.3.0-alpha.1" rand = "0.7" env_logger = "0.6" serde_derive = "1.0" -tokio-timer = "0.2.8" +tokio-timer = "0.3.0-alpha.6" brotli2 = "0.3.2" flate2 = "1.0.2" @@ -126,8 +124,28 @@ actix-web = { path = "." } actix-http = { path = "actix-http" } actix-http-test = { path = "test-server" } actix-web-codegen = { path = "actix-web-codegen" } -actix-web-actors = { path = "actix-web-actors" } +# actix-web-actors = { path = "actix-web-actors" } actix-session = { path = "actix-session" } actix-files = { path = "actix-files" } actix-multipart = { path = "actix-multipart" } awc = { path = "awc" } + +actix-codec = { git = "https://github.com/actix/actix-net.git" } +actix-connect = { git = "https://github.com/actix/actix-net.git" } +actix-rt = { git = "https://github.com/actix/actix-net.git" } +actix-server = { git = "https://github.com/actix/actix-net.git" } +actix-server-config = { git = "https://github.com/actix/actix-net.git" } +actix-service = { git = "https://github.com/actix/actix-net.git" } +actix-testing = { git = "https://github.com/actix/actix-net.git" } +actix-threadpool = { git = "https://github.com/actix/actix-net.git" } +actix-utils = { git = "https://github.com/actix/actix-net.git" } + +# actix-codec = { path = "../actix-net/actix-codec" } +# actix-connect = { path = "../actix-net/actix-connect" } +# actix-rt = { path = "../actix-net/actix-rt" } +# actix-server = { path = "../actix-net/actix-server" } +# actix-server-config = { path = "../actix-net/actix-server-config" } +# actix-service = { path = "../actix-net/actix-service" } +# actix-testing = { path = "../actix-net/actix-testing" } +# actix-threadpool = { path = "../actix-net/actix-threadpool" } +# actix-utils = { path = "../actix-net/actix-utils" } diff --git a/MIGRATION.md b/MIGRATION.md index 2f0f369ad..675dc61ed 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -1,3 +1,10 @@ +## 2.0.0 + +* Sync handlers has been removed. `.to_async()` method has been renamed to `.to()` + + replace `fn` with `async fn` to convert sync handler to async + + ## 1.0.1 * Cors middleware has been moved to `actix-cors` crate diff --git a/README.md b/README.md index 99b7b1760..00bb3ec4e 100644 --- a/README.md +++ b/README.md @@ -19,17 +19,16 @@ Actix web is a simple, pragmatic and extremely fast web framework for Rust. * [User Guide](https://actix.rs/docs/) * [API Documentation (1.0)](https://docs.rs/actix-web/) -* [API Documentation (0.7)](https://docs.rs/actix-web/0.7.19/actix_web/) * [Chat on gitter](https://gitter.im/actix/actix) * Cargo package: [actix-web](https://crates.io/crates/actix-web) -* Minimum supported Rust version: 1.36 or later +* Minimum supported Rust version: 1.39 or later ## Example ```rust use actix_web::{web, App, HttpServer, Responder}; -fn index(info: web::Path<(u32, String)>) -> impl Responder { +async fn index(info: web::Path<(u32, String)>) -> impl Responder { format!("Hello {}! id:{}", info.1, info.0) } diff --git a/actix-cors/Cargo.toml b/actix-cors/Cargo.toml index 091c94044..57aa5833a 100644 --- a/actix-cors/Cargo.toml +++ b/actix-cors/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-cors" -version = "0.1.0" +version = "0.2.0-alpha.1" authors = ["Nikolay Kim "] description = "Cross-origin resource sharing (CORS) for Actix applications." readme = "README.md" @@ -10,14 +10,14 @@ repository = "https://github.com/actix/actix-web.git" documentation = "https://docs.rs/actix-cors/" license = "MIT/Apache-2.0" edition = "2018" -#workspace = ".." +workspace = ".." [lib] name = "actix_cors" path = "src/lib.rs" [dependencies] -actix-web = "1.0.0" -actix-service = "0.4.0" +actix-web = "2.0.0-alpha.1" +actix-service = "1.0.0-alpha.1" derive_more = "0.15.0" -futures = "0.1.25" +futures = "0.3.1" diff --git a/actix-cors/src/lib.rs b/actix-cors/src/lib.rs index c76bae925..db7e4cc4f 100644 --- a/actix-cors/src/lib.rs +++ b/actix-cors/src/lib.rs @@ -11,7 +11,7 @@ //! use actix_cors::Cors; //! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer}; //! -//! fn index(req: HttpRequest) -> &'static str { +//! async fn index(req: HttpRequest) -> &'static str { //! "Hello world" //! } //! @@ -23,7 +23,8 @@ //! .allowed_methods(vec!["GET", "POST"]) //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) //! .allowed_header(http::header::CONTENT_TYPE) -//! .max_age(3600)) +//! .max_age(3600) +//! .finish()) //! .service( //! web::resource("/index.html") //! .route(web::get().to(index)) @@ -41,16 +42,16 @@ use std::collections::HashSet; use std::iter::FromIterator; use std::rc::Rc; +use std::task::{Context, Poll}; -use actix_service::{IntoTransform, Service, Transform}; +use actix_service::{Service, Transform}; use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse}; use actix_web::error::{Error, ResponseError, Result}; use actix_web::http::header::{self, HeaderName, HeaderValue}; use actix_web::http::{self, HttpTryFrom, Method, StatusCode, Uri}; use actix_web::HttpResponse; use derive_more::Display; -use futures::future::{ok, Either, Future, FutureResult}; -use futures::Poll; +use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; /// A set of errors that can occur during processing CORS #[derive(Debug, Display)] @@ -456,25 +457,9 @@ impl Cors { } self } -} -fn cors<'a>( - parts: &'a mut Option, - err: &Option, -) -> Option<&'a mut Inner> { - if err.is_some() { - return None; - } - parts.as_mut() -} - -impl IntoTransform for Cors -where - S: Service, Error = Error>, - S::Future: 'static, - B: 'static, -{ - fn into_transform(self) -> CorsFactory { + /// Construct cors middleware + pub fn finish(self) -> CorsFactory { let mut slf = if !self.methods { self.allowed_methods(vec![ Method::GET, @@ -521,6 +506,16 @@ where } } +fn cors<'a>( + parts: &'a mut Option, + err: &Option, +) -> Option<&'a mut Inner> { + if err.is_some() { + return None; + } + parts.as_mut() +} + /// `Middleware` for Cross-origin resource sharing support /// /// The Cors struct contains the settings for CORS requests to be validated and @@ -540,7 +535,7 @@ where type Error = Error; type InitError = (); type Transform = CorsMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(CorsMiddleware { @@ -682,12 +677,12 @@ where type Response = ServiceResponse; type Error = Error; type Future = Either< - FutureResult, - Either>>, + Ready>, + LocalBoxFuture<'static, Result>, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { @@ -698,7 +693,7 @@ where .and_then(|_| self.inner.validate_allowed_method(req.head())) .and_then(|_| self.inner.validate_allowed_headers(req.head())) { - return Either::A(ok(req.error_response(e))); + return Either::Left(ok(req.error_response(e))); } // allowed headers @@ -751,39 +746,50 @@ where .finish() .into_body(); - Either::A(ok(req.into_response(res))) - } else if req.headers().contains_key(&header::ORIGIN) { - // Only check requests with a origin header. - if let Err(e) = self.inner.validate_origin(req.head()) { - return Either::A(ok(req.error_response(e))); + Either::Left(ok(req.into_response(res))) + } else { + if req.headers().contains_key(&header::ORIGIN) { + // Only check requests with a origin header. + if let Err(e) = self.inner.validate_origin(req.head()) { + return Either::Left(ok(req.error_response(e))); + } } let inner = self.inner.clone(); + let has_origin = req.headers().contains_key(&header::ORIGIN); + let fut = self.service.call(req); - Either::B(Either::B(Box::new(self.service.call(req).and_then( - move |mut res| { - if let Some(origin) = - inner.access_control_allow_origin(res.request().head()) - { - res.headers_mut() - .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); - }; + Either::Right( + async move { + let res = fut.await; - if let Some(ref expose) = inner.expose_hdrs { - res.headers_mut().insert( - header::ACCESS_CONTROL_EXPOSE_HEADERS, - HeaderValue::try_from(expose.as_str()).unwrap(), - ); - } - if inner.supports_credentials { - res.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_CREDENTIALS, - HeaderValue::from_static("true"), - ); - } - if inner.vary_header { - let value = - if let Some(hdr) = res.headers_mut().get(&header::VARY) { + if has_origin { + let mut res = res?; + if let Some(origin) = + inner.access_control_allow_origin(res.request().head()) + { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + origin.clone(), + ); + }; + + if let Some(ref expose) = inner.expose_hdrs { + res.headers_mut().insert( + header::ACCESS_CONTROL_EXPOSE_HEADERS, + HeaderValue::try_from(expose.as_str()).unwrap(), + ); + } + if inner.supports_credentials { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + if inner.vary_header { + let value = if let Some(hdr) = + res.headers_mut().get(&header::VARY) + { let mut val: Vec = Vec::with_capacity(hdr.as_bytes().len() + 8); val.extend(hdr.as_bytes()); @@ -792,159 +798,153 @@ where } else { HeaderValue::from_static("Origin") }; - res.headers_mut().insert(header::VARY, value); + res.headers_mut().insert(header::VARY, value); + } + Ok(res) + } else { + res } - Ok(res) - }, - )))) - } else { - Either::B(Either::A(self.service.call(req))) + } + .boxed_local(), + ) } } } #[cfg(test)] mod tests { - use actix_service::{IntoService, Transform}; + use actix_service::{service_fn2, Transform}; use actix_web::test::{self, block_on, TestRequest}; use super::*; - impl Cors { - fn finish(self, srv: F) -> CorsMiddleware - where - F: IntoService, - S: Service< - Request = ServiceRequest, - Response = ServiceResponse, - Error = Error, - > + 'static, - S::Future: 'static, - B: 'static, - { - block_on( - IntoTransform::::into_transform(self) - .new_transform(srv.into_service()), - ) - .unwrap() - } - } - #[test] #[should_panic(expected = "Credentials are allowed, but the Origin is set to")] fn cors_validates_illegal_allow_credentials() { - let _cors = Cors::new() - .supports_credentials() - .send_wildcard() - .finish(test::ok_service()); + let _cors = Cors::new().supports_credentials().send_wildcard().finish(); } #[test] fn validate_origin_allows_all_origins() { - let mut cors = Cors::new().finish(test::ok_service()); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .to_srv_request(); + block_on(async { + let mut cors = Cors::new() + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!(resp.status(), StatusCode::OK); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn default() { - let mut cors = - block_on(Cors::default().new_transform(test::ok_service())).unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .to_srv_request(); + block_on(async { + let mut cors = Cors::default() + .new_transform(test::ok_service()) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!(resp.status(), StatusCode::OK); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_preflight() { - let mut cors = Cors::new() - .send_wildcard() - .max_age(3600) - .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) - .allowed_header(header::CONTENT_TYPE) - .finish(test::ok_service()); + block_on(async { + let mut cors = Cors::new() + .send_wildcard() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed") - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed") + .to_srv_request(); - assert!(cors.inner.validate_allowed_method(req.head()).is_err()); - assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); - let resp = test::call_service(&mut cors, req); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + assert!(cors.inner.validate_allowed_method(req.head()).is_err()); + assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") - .method(Method::OPTIONS) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") + .method(Method::OPTIONS) + .to_srv_request(); - assert!(cors.inner.validate_allowed_method(req.head()).is_err()); - assert!(cors.inner.validate_allowed_headers(req.head()).is_ok()); + assert!(cors.inner.validate_allowed_method(req.head()).is_err()); + assert!(cors.inner.validate_allowed_headers(req.head()).is_ok()); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") - .header( - header::ACCESS_CONTROL_REQUEST_HEADERS, - "AUTHORIZATION,ACCEPT", - ) - .method(Method::OPTIONS) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header( + header::ACCESS_CONTROL_REQUEST_HEADERS, + "AUTHORIZATION,ACCEPT", + ) + .method(Method::OPTIONS) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"*"[..], - resp.headers() - .get(&header::ACCESS_CONTROL_ALLOW_ORIGIN) + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"*"[..], + resp.headers() + .get(&header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + assert_eq!( + &b"3600"[..], + resp.headers() + .get(&header::ACCESS_CONTROL_MAX_AGE) + .unwrap() + .as_bytes() + ); + let hdr = resp + .headers() + .get(&header::ACCESS_CONTROL_ALLOW_HEADERS) .unwrap() - .as_bytes() - ); - assert_eq!( - &b"3600"[..], - resp.headers() - .get(&header::ACCESS_CONTROL_MAX_AGE) + .to_str() + .unwrap(); + assert!(hdr.contains("authorization")); + assert!(hdr.contains("accept")); + assert!(hdr.contains("content-type")); + + let methods = resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_METHODS) .unwrap() - .as_bytes() - ); - let hdr = resp - .headers() - .get(&header::ACCESS_CONTROL_ALLOW_HEADERS) - .unwrap() - .to_str() - .unwrap(); - assert!(hdr.contains("authorization")); - assert!(hdr.contains("accept")); - assert!(hdr.contains("content-type")); + .to_str() + .unwrap(); + assert!(methods.contains("POST")); + assert!(methods.contains("GET")); + assert!(methods.contains("OPTIONS")); - let methods = resp - .headers() - .get(header::ACCESS_CONTROL_ALLOW_METHODS) - .unwrap() - .to_str() - .unwrap(); - assert!(methods.contains("POST")); - assert!(methods.contains("GET")); - assert!(methods.contains("OPTIONS")); + Rc::get_mut(&mut cors.inner).unwrap().preflight = false; - Rc::get_mut(&mut cors.inner).unwrap().preflight = false; + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header( + header::ACCESS_CONTROL_REQUEST_HEADERS, + "AUTHORIZATION,ACCEPT", + ) + .method(Method::OPTIONS) + .to_srv_request(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") - .header( - header::ACCESS_CONTROL_REQUEST_HEADERS, - "AUTHORIZATION,ACCEPT", - ) - .method(Method::OPTIONS) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req); - assert_eq!(resp.status(), StatusCode::OK); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } // #[test] @@ -960,216 +960,254 @@ mod tests { #[test] #[should_panic(expected = "OriginNotAllowed")] fn test_validate_not_allowed_origin() { - let cors = Cors::new() - .allowed_origin("https://www.example.com") - .finish(test::ok_service()); + block_on(async { + let cors = Cors::new() + .allowed_origin("https://www.example.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.unknown.com") - .method(Method::GET) - .to_srv_request(); - cors.inner.validate_origin(req.head()).unwrap(); - cors.inner.validate_allowed_method(req.head()).unwrap(); - cors.inner.validate_allowed_headers(req.head()).unwrap(); + let req = TestRequest::with_header("Origin", "https://www.unknown.com") + .method(Method::GET) + .to_srv_request(); + cors.inner.validate_origin(req.head()).unwrap(); + cors.inner.validate_allowed_method(req.head()).unwrap(); + cors.inner.validate_allowed_headers(req.head()).unwrap(); + }) } #[test] fn test_validate_origin() { - let mut cors = Cors::new() - .allowed_origin("https://www.example.com") - .finish(test::ok_service()); + block_on(async { + let mut cors = Cors::new() + .allowed_origin("https://www.example.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::GET) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::GET) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!(resp.status(), StatusCode::OK); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_no_origin_response() { - let mut cors = Cors::new().disable_preflight().finish(test::ok_service()); + block_on(async { + let mut cors = Cors::new() + .disable_preflight() + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::default().method(Method::GET).to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert!(resp - .headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .is_none()); - - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"https://www.example.com"[..], - resp.headers() + let req = TestRequest::default().method(Method::GET).to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert!(resp + .headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); + .is_none()); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://www.example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + }) } #[test] fn test_response() { - let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; - let mut cors = Cors::new() - .send_wildcard() - .disable_preflight() - .max_age(3600) - .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(exposed_headers.clone()) - .expose_headers(exposed_headers.clone()) - .allowed_header(header::CONTENT_TYPE) - .finish(test::ok_service()); + block_on(async { + let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; + let mut cors = Cors::new() + .send_wildcard() + .disable_preflight() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"*"[..], - resp.headers() + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"*"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + assert_eq!( + &b"Origin"[..], + resp.headers().get(header::VARY).unwrap().as_bytes() + ); + + { + let headers = resp + .headers() + .get(header::ACCESS_CONTROL_EXPOSE_HEADERS) + .unwrap() + .to_str() + .unwrap() + .split(',') + .map(|s| s.trim()) + .collect::>(); + + for h in exposed_headers { + assert!(headers.contains(&h.as_str())); + } + } + + let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; + let mut cors = Cors::new() + .send_wildcard() + .disable_preflight() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(service_fn2(|req: ServiceRequest| { + ok(req.into_response( + HttpResponse::Ok().header(header::VARY, "Accept").finish(), + )) + })) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"Accept, Origin"[..], + resp.headers().get(header::VARY).unwrap().as_bytes() + ); + + let mut cors = Cors::new() + .disable_vary_header() + .allowed_origin("https://www.example.com") + .allowed_origin("https://www.google.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + + let origins_str = resp + .headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .unwrap() - .as_bytes() - ); - assert_eq!( - &b"Origin"[..], - resp.headers().get(header::VARY).unwrap().as_bytes() - ); - - { - let headers = resp - .headers() - .get(header::ACCESS_CONTROL_EXPOSE_HEADERS) - .unwrap() .to_str() - .unwrap() - .split(',') - .map(|s| s.trim()) - .collect::>(); + .unwrap(); - for h in exposed_headers { - assert!(headers.contains(&h.as_str())); - } - } - - let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; - let mut cors = Cors::new() - .send_wildcard() - .disable_preflight() - .max_age(3600) - .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(exposed_headers.clone()) - .expose_headers(exposed_headers.clone()) - .allowed_header(header::CONTENT_TYPE) - .finish(|req: ServiceRequest| { - req.into_response( - HttpResponse::Ok().header(header::VARY, "Accept").finish(), - ) - }); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"Accept, Origin"[..], - resp.headers().get(header::VARY).unwrap().as_bytes() - ); - - let mut cors = Cors::new() - .disable_vary_header() - .allowed_origin("https://www.example.com") - .allowed_origin("https://www.google.com") - .finish(test::ok_service()); - - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") - .to_srv_request(); - let resp = test::call_service(&mut cors, req); - - let origins_str = resp - .headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .to_str() - .unwrap(); - - assert_eq!("https://www.example.com", origins_str); + assert_eq!("https://www.example.com", origins_str); + }) } #[test] fn test_multiple_origins() { - let mut cors = Cors::new() - .allowed_origin("https://example.com") - .allowed_origin("https://example.org") - .allowed_methods(vec![Method::GET]) - .finish(test::ok_service()); + block_on(async { + let mut cors = Cors::new() + .allowed_origin("https://example.com") + .allowed_origin("https://example.org") + .allowed_methods(vec![Method::GET]) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::with_header("Origin", "https://example.com") - .method(Method::GET) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://example.com") + .method(Method::GET) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"https://example.com"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); - let req = TestRequest::with_header("Origin", "https://example.org") - .method(Method::GET) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://example.org") + .method(Method::GET) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"https://example.org"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.org"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + }) } #[test] fn test_multiple_origins_preflight() { - let mut cors = Cors::new() - .allowed_origin("https://example.com") - .allowed_origin("https://example.org") - .allowed_methods(vec![Method::GET]) - .finish(test::ok_service()); + block_on(async { + let mut cors = Cors::new() + .allowed_origin("https://example.com") + .allowed_origin("https://example.org") + .allowed_methods(vec![Method::GET]) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::with_header("Origin", "https://example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") - .method(Method::OPTIONS) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .method(Method::OPTIONS) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"https://example.com"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); - let req = TestRequest::with_header("Origin", "https://example.org") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") - .method(Method::OPTIONS) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://example.org") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .method(Method::OPTIONS) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"https://example.org"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.org"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + }) } } diff --git a/actix-files/CHANGES.md b/actix-files/CHANGES.md index 49ecdbffc..5ec56593c 100644 --- a/actix-files/CHANGES.md +++ b/actix-files/CHANGES.md @@ -1,21 +1,29 @@ # Changes -## [0.1.5] - unreleased +## [0.1.7] - 2019-11-06 + +* Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151) + +## [0.1.6] - 2019-10-14 + +* Add option to redirect to a slash-ended path `Files` #1132 + +## [0.1.5] - 2019-10-08 * Bump up `mime_guess` crate version to 2.0.1 * Bump up `percent-encoding` crate version to 2.1 +* Allow user defined request guards for `Files` #1113 + ## [0.1.4] - 2019-07-20 * Allow to disable `Content-Disposition` header #686 - ## [0.1.3] - 2019-06-28 * Do not set `Content-Length` header, let actix-http set it #930 - ## [0.1.2] - 2019-06-13 * Content-Length is 0 for NamedFile HEAD request #914 diff --git a/actix-files/Cargo.toml b/actix-files/Cargo.toml index 8f36cddc3..6e33bb412 100644 --- a/actix-files/Cargo.toml +++ b/actix-files/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-files" -version = "0.1.4" +version = "0.2.0-alpha.1" authors = ["Nikolay Kim "] description = "Static files support for actix web." readme = "README.md" @@ -18,12 +18,12 @@ name = "actix_files" path = "src/lib.rs" [dependencies] -actix-web = { version = "1.0.2", default-features = false } -actix-http = "0.2.9" -actix-service = "0.4.1" +actix-web = { version = "2.0.0-alpha.1", default-features = false } +actix-http = "0.3.0-alpha.1" +actix-service = "1.0.0-alpha.1" bitflags = "1" bytes = "0.4" -futures = "0.1.25" +futures = "0.3.1" derive_more = "0.15.0" log = "0.4" mime = "0.3" @@ -32,4 +32,4 @@ percent-encoding = "2.1" v_htmlescape = "0.4" [dev-dependencies] -actix-web = { version = "1.0.2", features=["ssl"] } +actix-web = { version = "2.0.0-alpha.1", features=["openssl"] } diff --git a/actix-files/src/lib.rs b/actix-files/src/lib.rs index c99d3265f..2f8a5c49c 100644 --- a/actix-files/src/lib.rs +++ b/actix-files/src/lib.rs @@ -4,23 +4,28 @@ use std::cell::RefCell; use std::fmt::Write; use std::fs::{DirEntry, File}; +use std::future::Future; use std::io::{Read, Seek}; use std::path::{Path, PathBuf}; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use std::{cmp, io}; use actix_service::boxed::{self, BoxedNewService, BoxedService}; -use actix_service::{IntoNewService, NewService, Service}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use actix_web::dev::{ AppService, HttpServiceFactory, Payload, ResourceDef, ServiceRequest, ServiceResponse, }; -use actix_web::error::{BlockingError, Error, ErrorInternalServerError}; -use actix_web::http::header::DispositionType; -use actix_web::{web, FromRequest, HttpRequest, HttpResponse, Responder}; +use actix_web::error::{Canceled, Error, ErrorInternalServerError}; +use actix_web::guard::Guard; +use actix_web::http::header::{self, DispositionType}; +use actix_web::http::Method; +use actix_web::{web, FromRequest, HttpRequest, HttpResponse}; use bytes::Bytes; -use futures::future::{ok, Either, FutureResult}; -use futures::{Async, Future, Poll, Stream}; +use futures::future::{ok, ready, Either, FutureExt, LocalBoxFuture, Ready}; +use futures::Stream; use mime; use mime_guess::from_ext; use percent_encoding::{utf8_percent_encode, CONTROLS}; @@ -52,32 +57,34 @@ pub struct ChunkedReadFile { size: u64, offset: u64, file: Option, - fut: Option>>>, + fut: Option< + LocalBoxFuture<'static, Result, Canceled>>, + >, counter: u64, } -fn handle_error(err: BlockingError) -> Error { - match err { - BlockingError::Error(err) => err.into(), - BlockingError::Canceled => ErrorInternalServerError("Unexpected error"), - } -} - impl Stream for ChunkedReadFile { - type Item = Bytes; - type Error = Error; + type Item = Result; - fn poll(&mut self) -> Poll, Error> { - if self.fut.is_some() { - return match self.fut.as_mut().unwrap().poll().map_err(handle_error)? { - Async::Ready((file, bytes)) => { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + if let Some(ref mut fut) = self.fut { + return match Pin::new(fut).poll(cx) { + Poll::Ready(Err(_)) => Poll::Ready(Some(Err(ErrorInternalServerError( + "Unexpected error", + ) + .into()))), + Poll::Ready(Ok(Ok((file, bytes)))) => { self.fut.take(); self.file = Some(file); self.offset += bytes.len() as u64; self.counter += bytes.len() as u64; - Ok(Async::Ready(Some(bytes))) + Poll::Ready(Some(Ok(bytes))) } - Async::NotReady => Ok(Async::NotReady), + Poll::Ready(Ok(Err(e))) => Poll::Ready(Some(Err(e.into()))), + Poll::Pending => Poll::Pending, }; } @@ -86,22 +93,25 @@ impl Stream for ChunkedReadFile { let counter = self.counter; if size == counter { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { let mut file = self.file.take().expect("Use after completion"); - self.fut = Some(Box::new(web::block(move || { - let max_bytes: usize; - max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; - let mut buf = Vec::with_capacity(max_bytes); - file.seek(io::SeekFrom::Start(offset))?; - let nbytes = - file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; - if nbytes == 0 { - return Err(io::ErrorKind::UnexpectedEof.into()); - } - Ok((file, Bytes::from(buf))) - }))); - self.poll() + self.fut = Some( + web::block(move || { + let max_bytes: usize; + max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; + let mut buf = Vec::with_capacity(max_bytes); + file.seek(io::SeekFrom::Start(offset))?; + let nbytes = + file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; + if nbytes == 0 { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + Ok((file, Bytes::from(buf))) + }) + .boxed_local(), + ); + self.poll_next(cx) } } } @@ -231,10 +241,12 @@ pub struct Files { directory: PathBuf, index: Option, show_index: bool, + redirect_to_slash: bool, default: Rc>>>, renderer: Rc, mime_override: Option>, file_flags: named::Flags, + guards: Option>>, } impl Clone for Files { @@ -243,11 +255,13 @@ impl Clone for Files { directory: self.directory.clone(), index: self.index.clone(), show_index: self.show_index, + redirect_to_slash: self.redirect_to_slash, default: self.default.clone(), renderer: self.renderer.clone(), file_flags: self.file_flags, path: self.path.clone(), mime_override: self.mime_override.clone(), + guards: self.guards.clone(), } } } @@ -269,10 +283,12 @@ impl Files { directory: dir, index: None, show_index: false, + redirect_to_slash: false, default: Rc::new(RefCell::new(None)), renderer: Rc::new(directory_listing), mime_override: None, file_flags: named::Flags::default(), + guards: None, } } @@ -284,12 +300,19 @@ impl Files { self } + /// Redirects to a slash-ended path when browsing a directory. + /// + /// By default never redirect. + pub fn redirect_to_slash_directory(mut self) -> Self { + self.redirect_to_slash = true; + self + } + /// Set custom directory renderer pub fn files_listing_renderer(mut self, f: F) -> Self where - for<'r, 's> F: - Fn(&'r Directory, &'s HttpRequest) -> Result - + 'static, + for<'r, 's> F: Fn(&'r Directory, &'s HttpRequest) -> Result + + 'static, { self.renderer = Rc::new(f); self @@ -331,6 +354,15 @@ impl Files { self } + /// Specifies custom guards to use for directory listings and files. + /// + /// Default behaviour allows GET and HEAD. + #[inline] + pub fn use_guards(mut self, guards: G) -> Self { + self.guards = Some(Rc::new(Box::new(guards))); + self + } + /// Disable `Content-Disposition` header. /// /// By default Content-Disposition` header is enabled. @@ -343,8 +375,8 @@ impl Files { /// Sets default handler which is used when no matched file could be found. pub fn default_handler(mut self, f: F) -> Self where - F: IntoNewService, - U: NewService< + F: IntoServiceFactory, + U: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -352,8 +384,8 @@ impl Files { > + 'static, { // create and configure default resource - self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::new_service( - f.into_new_service().map_init_err(|_| ()), + self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::factory( + f.into_factory().map_init_err(|_| ()), ))))); self @@ -374,38 +406,41 @@ impl HttpServiceFactory for Files { } } -impl NewService for Files { - type Config = (); +impl ServiceFactory for Files { type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; + type Config = (); type Service = FilesService; type InitError = (); - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: &()) -> Self::Future { let mut srv = FilesService { directory: self.directory.clone(), index: self.index.clone(), show_index: self.show_index, + redirect_to_slash: self.redirect_to_slash, default: None, renderer: self.renderer.clone(), mime_override: self.mime_override.clone(), file_flags: self.file_flags, + guards: self.guards.clone(), }; if let Some(ref default) = *self.default.borrow() { - Box::new( - default - .new_service(&()) - .map(move |default| { + default + .new_service(&()) + .map(move |result| match result { + Ok(default) => { srv.default = Some(default); - srv - }) - .map_err(|_| ()), - ) + Ok(srv) + } + Err(_) => Err(()), + }) + .boxed_local() } else { - Box::new(ok(srv)) + ok(srv).boxed_local() } } } @@ -414,10 +449,12 @@ pub struct FilesService { directory: PathBuf, index: Option, show_index: bool, + redirect_to_slash: bool, default: Option, renderer: Rc, mime_override: Option>, file_flags: named::Flags, + guards: Option>>, } impl FilesService { @@ -426,14 +463,14 @@ impl FilesService { e: io::Error, req: ServiceRequest, ) -> Either< - FutureResult, - Box>, + Ready>, + LocalBoxFuture<'static, Result>, > { log::debug!("Files: Failed to handle {}: {}", req.path(), e); if let Some(ref mut default) = self.default { - default.call(req) + Either::Right(default.call(req)) } else { - Either::A(ok(req.error_response(e))) + Either::Left(ok(req.error_response(e))) } } } @@ -443,20 +480,37 @@ impl Service for FilesService { type Response = ServiceResponse; type Error = Error; type Future = Either< - FutureResult, - Box>, + Ready>, + LocalBoxFuture<'static, Result>, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, req: ServiceRequest) -> Self::Future { - // let (req, pl) = req.into_parts(); + let is_method_valid = if let Some(guard) = &self.guards { + // execute user defined guards + (**guard).check(req.head()) + } else { + // default behaviour + match *req.method() { + Method::HEAD | Method::GET => true, + _ => false, + } + }; + + if !is_method_valid { + return Either::Left(ok(req.into_response( + actix_web::HttpResponse::MethodNotAllowed() + .header(header::CONTENT_TYPE, "text/plain") + .body("Request did not meet this resource's requirements."), + ))); + } let real_path = match PathBufWrp::get_pathbuf(req.match_info().path()) { Ok(item) => item, - Err(e) => return Either::A(ok(req.error_response(e))), + Err(e) => return Either::Left(ok(req.error_response(e))), }; // full filepath @@ -467,6 +521,16 @@ impl Service for FilesService { if path.is_dir() { if let Some(ref redir_index) = self.index { + if self.redirect_to_slash && !req.path().ends_with('/') { + let redirect_to = format!("{}/", req.path()); + return Either::Left(ok(req.into_response( + HttpResponse::Found() + .header(header::LOCATION, redirect_to) + .body("") + .into_body(), + ))); + } + let path = path.join(redir_index); match NamedFile::open(path) { @@ -479,7 +543,7 @@ impl Service for FilesService { named_file.flags = self.file_flags; let (req, _) = req.into_parts(); - Either::A(ok(match named_file.respond_to(&req) { + Either::Left(ok(match named_file.into_response(&req) { Ok(item) => ServiceResponse::new(req, item), Err(e) => ServiceResponse::from_err(e, req), })) @@ -491,11 +555,11 @@ impl Service for FilesService { let (req, _) = req.into_parts(); let x = (self.renderer)(&dir, &req); match x { - Ok(resp) => Either::A(ok(resp)), - Err(e) => Either::A(ok(ServiceResponse::from_err(e, req))), + Ok(resp) => Either::Left(ok(resp)), + Err(e) => Either::Left(ok(ServiceResponse::from_err(e, req))), } } else { - Either::A(ok(ServiceResponse::from_err( + Either::Left(ok(ServiceResponse::from_err( FilesError::IsDirectory, req.into_parts().0, ))) @@ -511,11 +575,11 @@ impl Service for FilesService { named_file.flags = self.file_flags; let (req, _) = req.into_parts(); - match named_file.respond_to(&req) { + match named_file.into_response(&req) { Ok(item) => { - Either::A(ok(ServiceResponse::new(req.clone(), item))) + Either::Left(ok(ServiceResponse::new(req.clone(), item))) } - Err(e) => Either::A(ok(ServiceResponse::from_err(e, req))), + Err(e) => Either::Left(ok(ServiceResponse::from_err(e, req))), } } Err(e) => self.handle_err(e, req), @@ -558,11 +622,11 @@ impl PathBufWrp { impl FromRequest for PathBufWrp { type Error = UriSegmentError; - type Future = Result; + type Future = Ready>; type Config = (); fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { - PathBufWrp::get_pathbuf(req.match_info().path()) + ready(PathBufWrp::get_pathbuf(req.match_info().path())) } } @@ -573,16 +637,15 @@ mod tests { use std::ops::Add; use std::time::{Duration, SystemTime}; - use bytes::BytesMut; - use super::*; + use actix_web::guard; use actix_web::http::header::{ self, ContentDisposition, DispositionParam, DispositionType, }; use actix_web::http::{Method, StatusCode}; use actix_web::middleware::Compress; - use actix_web::test::{self, TestRequest}; - use actix_web::App; + use actix_web::test::{self, block_on, TestRequest}; + use actix_web::{App, Responder}; #[test] fn test_file_extension_to_mime() { @@ -598,518 +661,647 @@ mod tests { #[test] fn test_if_modified_since_without_if_none_match() { - let file = NamedFile::open("Cargo.toml").unwrap(); - let since = - header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); + block_on(async { + let file = NamedFile::open("Cargo.toml").unwrap(); + let since = + header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); - let req = TestRequest::default() - .header(header::IF_MODIFIED_SINCE, since) - .to_http_request(); - let resp = file.respond_to(&req).unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_MODIFIED); + let req = TestRequest::default() + .header(header::IF_MODIFIED_SINCE, since) + .to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_MODIFIED); + }) } #[test] fn test_if_modified_since_with_if_none_match() { - let file = NamedFile::open("Cargo.toml").unwrap(); - let since = - header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); + block_on(async { + let file = NamedFile::open("Cargo.toml").unwrap(); + let since = + header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); - let req = TestRequest::default() - .header(header::IF_NONE_MATCH, "miss_etag") - .header(header::IF_MODIFIED_SINCE, since) - .to_http_request(); - let resp = file.respond_to(&req).unwrap(); - assert_ne!(resp.status(), StatusCode::NOT_MODIFIED); + let req = TestRequest::default() + .header(header::IF_NONE_MATCH, "miss_etag") + .header(header::IF_MODIFIED_SINCE, since) + .to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_ne!(resp.status(), StatusCode::NOT_MODIFIED); + }) } #[test] fn test_named_file_text() { - assert!(NamedFile::open("test--").is_err()); - let mut file = NamedFile::open("Cargo.toml").unwrap(); - { - file.file(); - let _f: &File = &file; - } - { - let _f: &mut File = &mut file; - } + block_on(async { + assert!(NamedFile::open("test--").is_err()); + let mut file = NamedFile::open("Cargo.toml").unwrap(); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } - let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - "text/x-toml" - ); - assert_eq!( - resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), - "inline; filename=\"Cargo.toml\"" - ); + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/x-toml" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"Cargo.toml\"" + ); + }) } #[test] fn test_named_file_content_disposition() { - assert!(NamedFile::open("test--").is_err()); - let mut file = NamedFile::open("Cargo.toml").unwrap(); - { - file.file(); - let _f: &File = &file; - } - { - let _f: &mut File = &mut file; - } + block_on(async { + assert!(NamedFile::open("test--").is_err()); + let mut file = NamedFile::open("Cargo.toml").unwrap(); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } - let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); - assert_eq!( + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"Cargo.toml\"" + ); + + let file = NamedFile::open("Cargo.toml") + .unwrap() + .disable_content_disposition(); + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert!(resp.headers().get(header::CONTENT_DISPOSITION).is_none()); + }) + } + + #[test] + fn test_named_file_non_ascii_file_name() { + block_on(async { + let mut file = + NamedFile::from_file(File::open("Cargo.toml").unwrap(), "貨物.toml") + .unwrap(); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } + + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/x-toml" + ); + assert_eq!( resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), - "inline; filename=\"Cargo.toml\"" + "inline; filename=\"貨物.toml\"; filename*=UTF-8''%E8%B2%A8%E7%89%A9.toml" ); - - let file = NamedFile::open("Cargo.toml") - .unwrap() - .disable_content_disposition(); - let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); - assert!(resp.headers().get(header::CONTENT_DISPOSITION).is_none()); + }) } #[test] fn test_named_file_set_content_type() { - let mut file = NamedFile::open("Cargo.toml") - .unwrap() - .set_content_type(mime::TEXT_XML); - { - file.file(); - let _f: &File = &file; - } - { - let _f: &mut File = &mut file; - } + block_on(async { + let mut file = NamedFile::open("Cargo.toml") + .unwrap() + .set_content_type(mime::TEXT_XML); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } - let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - "text/xml" - ); - assert_eq!( - resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), - "inline; filename=\"Cargo.toml\"" - ); + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/xml" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"Cargo.toml\"" + ); + }) } #[test] fn test_named_file_image() { - let mut file = NamedFile::open("tests/test.png").unwrap(); - { - file.file(); - let _f: &File = &file; - } - { - let _f: &mut File = &mut file; - } + block_on(async { + let mut file = NamedFile::open("tests/test.png").unwrap(); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } - let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - "image/png" - ); - assert_eq!( - resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), - "inline; filename=\"test.png\"" - ); + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "image/png" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"test.png\"" + ); + }) } #[test] fn test_named_file_image_attachment() { - let cd = ContentDisposition { - disposition: DispositionType::Attachment, - parameters: vec![DispositionParam::Filename(String::from("test.png"))], - }; - let mut file = NamedFile::open("tests/test.png") - .unwrap() - .set_content_disposition(cd); - { - file.file(); - let _f: &File = &file; - } - { - let _f: &mut File = &mut file; - } + block_on(async { + let cd = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![DispositionParam::Filename(String::from("test.png"))], + }; + let mut file = NamedFile::open("tests/test.png") + .unwrap() + .set_content_disposition(cd); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } - let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - "image/png" - ); - assert_eq!( - resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), - "attachment; filename=\"test.png\"" - ); + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "image/png" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "attachment; filename=\"test.png\"" + ); + }) } #[test] fn test_named_file_binary() { - let mut file = NamedFile::open("tests/test.binary").unwrap(); - { - file.file(); - let _f: &File = &file; - } - { - let _f: &mut File = &mut file; - } + block_on(async { + let mut file = NamedFile::open("tests/test.binary").unwrap(); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } - let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - "application/octet-stream" - ); - assert_eq!( - resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), - "attachment; filename=\"test.binary\"" - ); + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "application/octet-stream" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "attachment; filename=\"test.binary\"" + ); + }) } #[test] fn test_named_file_status_code_text() { - let mut file = NamedFile::open("Cargo.toml") - .unwrap() - .set_status_code(StatusCode::NOT_FOUND); - { - file.file(); - let _f: &File = &file; - } - { - let _f: &mut File = &mut file; - } + block_on(async { + let mut file = NamedFile::open("Cargo.toml") + .unwrap() + .set_status_code(StatusCode::NOT_FOUND); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } - let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - "text/x-toml" - ); - assert_eq!( - resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), - "inline; filename=\"Cargo.toml\"" - ); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/x-toml" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"Cargo.toml\"" + ); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + }) } #[test] fn test_mime_override() { - fn all_attachment(_: &mime::Name) -> DispositionType { - DispositionType::Attachment - } + block_on(async { + fn all_attachment(_: &mime::Name) -> DispositionType { + DispositionType::Attachment + } - let mut srv = test::init_service( - App::new().service( - Files::new("/", ".") - .mime_override(all_attachment) - .index_file("Cargo.toml"), - ), - ); + let mut srv = test::init_service( + App::new().service( + Files::new("/", ".") + .mime_override(all_attachment) + .index_file("Cargo.toml"), + ), + ) + .await; - let request = TestRequest::get().uri("/").to_request(); - let response = test::call_service(&mut srv, request); - assert_eq!(response.status(), StatusCode::OK); + let request = TestRequest::get().uri("/").to_request(); + let response = test::call_service(&mut srv, request).await; + assert_eq!(response.status(), StatusCode::OK); - let content_disposition = response - .headers() - .get(header::CONTENT_DISPOSITION) - .expect("To have CONTENT_DISPOSITION"); - let content_disposition = content_disposition - .to_str() - .expect("Convert CONTENT_DISPOSITION to str"); - assert_eq!(content_disposition, "attachment; filename=\"Cargo.toml\""); + let content_disposition = response + .headers() + .get(header::CONTENT_DISPOSITION) + .expect("To have CONTENT_DISPOSITION"); + let content_disposition = content_disposition + .to_str() + .expect("Convert CONTENT_DISPOSITION to str"); + assert_eq!(content_disposition, "attachment; filename=\"Cargo.toml\""); + }) } #[test] fn test_named_file_ranges_status_code() { - let mut srv = test::init_service( - App::new().service(Files::new("/test", ".").index_file("Cargo.toml")), - ); + block_on(async { + let mut srv = test::init_service( + App::new().service(Files::new("/test", ".").index_file("Cargo.toml")), + ) + .await; - // Valid range header - let request = TestRequest::get() - .uri("/t%65st/Cargo.toml") - .header(header::RANGE, "bytes=10-20") - .to_request(); - let response = test::call_service(&mut srv, request); - assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT); + // Valid range header + let request = TestRequest::get() + .uri("/t%65st/Cargo.toml") + .header(header::RANGE, "bytes=10-20") + .to_request(); + let response = test::call_service(&mut srv, request).await; + assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT); - // Invalid range header - let request = TestRequest::get() - .uri("/t%65st/Cargo.toml") - .header(header::RANGE, "bytes=1-0") - .to_request(); - let response = test::call_service(&mut srv, request); + // Invalid range header + let request = TestRequest::get() + .uri("/t%65st/Cargo.toml") + .header(header::RANGE, "bytes=1-0") + .to_request(); + let response = test::call_service(&mut srv, request).await; - assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE); + assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE); + }) } #[test] fn test_named_file_content_range_headers() { - let mut srv = test::init_service( - App::new().service(Files::new("/test", ".").index_file("tests/test.binary")), - ); + block_on(async { + let mut srv = test::init_service( + App::new() + .service(Files::new("/test", ".").index_file("tests/test.binary")), + ) + .await; - // Valid range header - let request = TestRequest::get() - .uri("/t%65st/tests/test.binary") - .header(header::RANGE, "bytes=10-20") - .to_request(); + // Valid range header + let request = TestRequest::get() + .uri("/t%65st/tests/test.binary") + .header(header::RANGE, "bytes=10-20") + .to_request(); - let response = test::call_service(&mut srv, request); - let contentrange = response - .headers() - .get(header::CONTENT_RANGE) - .unwrap() - .to_str() - .unwrap(); + let response = test::call_service(&mut srv, request).await; + let contentrange = response + .headers() + .get(header::CONTENT_RANGE) + .unwrap() + .to_str() + .unwrap(); - assert_eq!(contentrange, "bytes 10-20/100"); + assert_eq!(contentrange, "bytes 10-20/100"); - // Invalid range header - let request = TestRequest::get() - .uri("/t%65st/tests/test.binary") - .header(header::RANGE, "bytes=10-5") - .to_request(); - let response = test::call_service(&mut srv, request); + // Invalid range header + let request = TestRequest::get() + .uri("/t%65st/tests/test.binary") + .header(header::RANGE, "bytes=10-5") + .to_request(); + let response = test::call_service(&mut srv, request).await; - let contentrange = response - .headers() - .get(header::CONTENT_RANGE) - .unwrap() - .to_str() - .unwrap(); + let contentrange = response + .headers() + .get(header::CONTENT_RANGE) + .unwrap() + .to_str() + .unwrap(); - assert_eq!(contentrange, "bytes */100"); + assert_eq!(contentrange, "bytes */100"); + }) } #[test] fn test_named_file_content_length_headers() { - // use actix_web::body::{MessageBody, ResponseBody}; + block_on(async { + // use actix_web::body::{MessageBody, ResponseBody}; - let mut srv = test::init_service( - App::new().service(Files::new("test", ".").index_file("tests/test.binary")), - ); + let mut srv = test::init_service( + App::new() + .service(Files::new("test", ".").index_file("tests/test.binary")), + ) + .await; - // Valid range header - let request = TestRequest::get() - .uri("/t%65st/tests/test.binary") - .header(header::RANGE, "bytes=10-20") - .to_request(); - let _response = test::call_service(&mut srv, request); + // Valid range header + let request = TestRequest::get() + .uri("/t%65st/tests/test.binary") + .header(header::RANGE, "bytes=10-20") + .to_request(); + let _response = test::call_service(&mut srv, request).await; - // let contentlength = response - // .headers() - // .get(header::CONTENT_LENGTH) - // .unwrap() - // .to_str() - // .unwrap(); - // assert_eq!(contentlength, "11"); + // let contentlength = response + // .headers() + // .get(header::CONTENT_LENGTH) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(contentlength, "11"); - // Invalid range header - let request = TestRequest::get() - .uri("/t%65st/tests/test.binary") - .header(header::RANGE, "bytes=10-8") - .to_request(); - let response = test::call_service(&mut srv, request); - assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE); + // Invalid range header + let request = TestRequest::get() + .uri("/t%65st/tests/test.binary") + .header(header::RANGE, "bytes=10-8") + .to_request(); + let response = test::call_service(&mut srv, request).await; + assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE); - // Without range header - let request = TestRequest::get() - .uri("/t%65st/tests/test.binary") - // .no_default_headers() - .to_request(); - let _response = test::call_service(&mut srv, request); + // Without range header + let request = TestRequest::get() + .uri("/t%65st/tests/test.binary") + // .no_default_headers() + .to_request(); + let _response = test::call_service(&mut srv, request).await; - // let contentlength = response - // .headers() - // .get(header::CONTENT_LENGTH) - // .unwrap() - // .to_str() - // .unwrap(); - // assert_eq!(contentlength, "100"); + // let contentlength = response + // .headers() + // .get(header::CONTENT_LENGTH) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(contentlength, "100"); - // chunked - let request = TestRequest::get() - .uri("/t%65st/tests/test.binary") - .to_request(); - let mut response = test::call_service(&mut srv, request); + // chunked + let request = TestRequest::get() + .uri("/t%65st/tests/test.binary") + .to_request(); + let response = test::call_service(&mut srv, request).await; - // with enabled compression - // { - // let te = response - // .headers() - // .get(header::TRANSFER_ENCODING) - // .unwrap() - // .to_str() - // .unwrap(); - // assert_eq!(te, "chunked"); - // } + // with enabled compression + // { + // let te = response + // .headers() + // .get(header::TRANSFER_ENCODING) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(te, "chunked"); + // } - let bytes = - test::block_on(response.take_body().fold(BytesMut::new(), |mut b, c| { - b.extend(c); - Ok::<_, Error>(b) - })) - .unwrap(); - let data = Bytes::from(fs::read("tests/test.binary").unwrap()); - assert_eq!(bytes.freeze(), data); + let bytes = test::read_body(response).await; + let data = Bytes::from(fs::read("tests/test.binary").unwrap()); + assert_eq!(bytes, data); + }) } #[test] fn test_head_content_length_headers() { - let mut srv = test::init_service( - App::new().service(Files::new("test", ".").index_file("tests/test.binary")), - ); + block_on(async { + let mut srv = test::init_service( + App::new() + .service(Files::new("test", ".").index_file("tests/test.binary")), + ) + .await; - // Valid range header - let request = TestRequest::default() - .method(Method::HEAD) - .uri("/t%65st/tests/test.binary") - .to_request(); - let _response = test::call_service(&mut srv, request); + // Valid range header + let request = TestRequest::default() + .method(Method::HEAD) + .uri("/t%65st/tests/test.binary") + .to_request(); + let _response = test::call_service(&mut srv, request).await; - // TODO: fix check - // let contentlength = response - // .headers() - // .get(header::CONTENT_LENGTH) - // .unwrap() - // .to_str() - // .unwrap(); - // assert_eq!(contentlength, "100"); + // TODO: fix check + // let contentlength = response + // .headers() + // .get(header::CONTENT_LENGTH) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(contentlength, "100"); + }) } #[test] fn test_static_files_with_spaces() { - let mut srv = test::init_service( - App::new().service(Files::new("/", ".").index_file("Cargo.toml")), - ); - let request = TestRequest::get() - .uri("/tests/test%20space.binary") - .to_request(); - let mut response = test::call_service(&mut srv, request); - assert_eq!(response.status(), StatusCode::OK); + block_on(async { + let mut srv = test::init_service( + App::new().service(Files::new("/", ".").index_file("Cargo.toml")), + ) + .await; + let request = TestRequest::get() + .uri("/tests/test%20space.binary") + .to_request(); + let response = test::call_service(&mut srv, request).await; + assert_eq!(response.status(), StatusCode::OK); - let bytes = - test::block_on(response.take_body().fold(BytesMut::new(), |mut b, c| { - b.extend(c); - Ok::<_, Error>(b) - })) - .unwrap(); - - let data = Bytes::from(fs::read("tests/test space.binary").unwrap()); - assert_eq!(bytes.freeze(), data); + let bytes = test::read_body(response).await; + let data = Bytes::from(fs::read("tests/test space.binary").unwrap()); + assert_eq!(bytes, data); + }) } #[test] - fn test_named_file_not_allowed() { - let file = NamedFile::open("Cargo.toml").unwrap(); - let req = TestRequest::default() - .method(Method::POST) - .to_http_request(); - let resp = file.respond_to(&req).unwrap(); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + fn test_files_not_allowed() { + block_on(async { + let mut srv = + test::init_service(App::new().service(Files::new("/", "."))).await; - let file = NamedFile::open("Cargo.toml").unwrap(); - let req = TestRequest::default().method(Method::PUT).to_http_request(); - let resp = file.respond_to(&req).unwrap(); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + let req = TestRequest::default() + .uri("/Cargo.toml") + .method(Method::POST) + .to_request(); + + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + + let mut srv = + test::init_service(App::new().service(Files::new("/", "."))).await; + let req = TestRequest::default() + .method(Method::PUT) + .uri("/Cargo.toml") + .to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + }) + } + + #[test] + fn test_files_guards() { + block_on(async { + let mut srv = test::init_service( + App::new().service(Files::new("/", ".").use_guards(guard::Post())), + ) + .await; + + let req = TestRequest::default() + .uri("/Cargo.toml") + .method(Method::POST) + .to_request(); + + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_named_file_content_encoding() { - let mut srv = test::init_service(App::new().wrap(Compress::default()).service( - web::resource("/").to(|| { - NamedFile::open("Cargo.toml") - .unwrap() - .set_content_encoding(header::ContentEncoding::Identity) - }), - )); + block_on(async { + let mut srv = + test::init_service(App::new().wrap(Compress::default()).service( + web::resource("/").to(|| { + async { + NamedFile::open("Cargo.toml") + .unwrap() + .set_content_encoding(header::ContentEncoding::Identity) + } + }), + )) + .await; - let request = TestRequest::get() - .uri("/") - .header(header::ACCEPT_ENCODING, "gzip") - .to_request(); - let res = test::call_service(&mut srv, request); - assert_eq!(res.status(), StatusCode::OK); - assert!(!res.headers().contains_key(header::CONTENT_ENCODING)); + let request = TestRequest::get() + .uri("/") + .header(header::ACCEPT_ENCODING, "gzip") + .to_request(); + let res = test::call_service(&mut srv, request).await; + assert_eq!(res.status(), StatusCode::OK); + assert!(!res.headers().contains_key(header::CONTENT_ENCODING)); + }) } #[test] fn test_named_file_content_encoding_gzip() { - let mut srv = test::init_service(App::new().wrap(Compress::default()).service( - web::resource("/").to(|| { - NamedFile::open("Cargo.toml") - .unwrap() - .set_content_encoding(header::ContentEncoding::Gzip) - }), - )); + block_on(async { + let mut srv = + test::init_service(App::new().wrap(Compress::default()).service( + web::resource("/").to(|| { + async { + NamedFile::open("Cargo.toml") + .unwrap() + .set_content_encoding(header::ContentEncoding::Gzip) + } + }), + )) + .await; - let request = TestRequest::get() - .uri("/") - .header(header::ACCEPT_ENCODING, "gzip") - .to_request(); - let res = test::call_service(&mut srv, request); - assert_eq!(res.status(), StatusCode::OK); - assert_eq!( - res.headers() - .get(header::CONTENT_ENCODING) - .unwrap() - .to_str() - .unwrap(), - "gzip" - ); + let request = TestRequest::get() + .uri("/") + .header(header::ACCEPT_ENCODING, "gzip") + .to_request(); + let res = test::call_service(&mut srv, request).await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers() + .get(header::CONTENT_ENCODING) + .unwrap() + .to_str() + .unwrap(), + "gzip" + ); + }) } #[test] fn test_named_file_allowed_method() { - let req = TestRequest::default().method(Method::GET).to_http_request(); - let file = NamedFile::open("Cargo.toml").unwrap(); - let resp = file.respond_to(&req).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + block_on(async { + let req = TestRequest::default().method(Method::GET).to_http_request(); + let file = NamedFile::open("Cargo.toml").unwrap(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_static_files() { - let mut srv = test::init_service( - App::new().service(Files::new("/", ".").show_files_listing()), - ); - let req = TestRequest::with_uri("/missing").to_request(); + block_on(async { + let mut srv = test::init_service( + App::new().service(Files::new("/", ".").show_files_listing()), + ) + .await; + let req = TestRequest::with_uri("/missing").to_request(); - let resp = test::call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let mut srv = test::init_service(App::new().service(Files::new("/", "."))); + let mut srv = + test::init_service(App::new().service(Files::new("/", "."))).await; - let req = TestRequest::default().to_request(); - let resp = test::call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::default().to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let mut srv = test::init_service( - App::new().service(Files::new("/", ".").show_files_listing()), - ); - let req = TestRequest::with_uri("/tests").to_request(); - let mut resp = test::call_service(&mut srv, req); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - "text/html; charset=utf-8" - ); + let mut srv = test::init_service( + App::new().service(Files::new("/", ".").show_files_listing()), + ) + .await; + let req = TestRequest::with_uri("/tests").to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/html; charset=utf-8" + ); - let bytes = - test::block_on(resp.take_body().fold(BytesMut::new(), |mut b, c| { - b.extend(c); - Ok::<_, Error>(b) - })) - .unwrap(); - assert!(format!("{:?}", bytes).contains("/tests/test.png")); + let bytes = test::read_body(resp).await; + assert!(format!("{:?}", bytes).contains("/tests/test.png")); + }) + } + + #[test] + fn test_redirect_to_slash_directory() { + block_on(async { + // should not redirect if no index + let mut srv = test::init_service( + App::new().service(Files::new("/", ".").redirect_to_slash_directory()), + ) + .await; + let req = TestRequest::with_uri("/tests").to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + // should redirect if index present + let mut srv = test::init_service( + App::new().service( + Files::new("/", ".") + .index_file("test.png") + .redirect_to_slash_directory(), + ), + ) + .await; + let req = TestRequest::with_uri("/tests").to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::FOUND); + + // should not redirect if the path is wrong + let req = TestRequest::with_uri("/not_existing").to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + }) } #[test] @@ -1120,26 +1312,21 @@ mod tests { #[test] fn test_default_handler_file_missing() { - let mut st = test::block_on( - Files::new("/", ".") + block_on(async { + let mut st = Files::new("/", ".") .default_handler(|req: ServiceRequest| { - Ok(req.into_response(HttpResponse::Ok().body("default content"))) + ok(req.into_response(HttpResponse::Ok().body("default content"))) }) - .new_service(&()), - ) - .unwrap(); - let req = TestRequest::with_uri("/missing").to_srv_request(); + .new_service(&()) + .await + .unwrap(); + let req = TestRequest::with_uri("/missing").to_srv_request(); - let mut resp = test::call_service(&mut st, req); - assert_eq!(resp.status(), StatusCode::OK); - let bytes = - test::block_on(resp.take_body().fold(BytesMut::new(), |mut b, c| { - b.extend(c); - Ok::<_, Error>(b) - })) - .unwrap(); - - assert_eq!(bytes.freeze(), Bytes::from_static(b"default content")); + let resp = test::call_service(&mut st, req).await; + assert_eq!(resp.status(), StatusCode::OK); + let bytes = test::read_body(resp).await; + assert_eq!(bytes, Bytes::from_static(b"default content")); + }) } // #[test] diff --git a/actix-files/src/named.rs b/actix-files/src/named.rs index f548a7a1b..0dcbd93b8 100644 --- a/actix-files/src/named.rs +++ b/actix-files/src/named.rs @@ -13,11 +13,12 @@ use mime_guess::from_path; use actix_http::body::SizedStream; use actix_web::http::header::{ - self, ContentDisposition, DispositionParam, DispositionType, + self, Charset, ContentDisposition, DispositionParam, DispositionType, ExtendedValue, }; -use actix_web::http::{ContentEncoding, Method, StatusCode}; +use actix_web::http::{ContentEncoding, StatusCode}; use actix_web::middleware::BodyEncoding; use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, Responder}; +use futures::future::{ready, Ready}; use crate::range::HttpRange; use crate::ChunkedReadFile; @@ -93,9 +94,18 @@ impl NamedFile { mime::IMAGE | mime::TEXT | mime::VIDEO => DispositionType::Inline, _ => DispositionType::Attachment, }; + let mut parameters = + vec![DispositionParam::Filename(String::from(filename.as_ref()))]; + if !filename.is_ascii() { + parameters.push(DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Ext(String::from("UTF-8")), + language_tag: None, + value: filename.into_owned().into_bytes(), + })) + } let cd = ContentDisposition { disposition: disposition_type, - parameters: vec![DispositionParam::Filename(filename.into_owned())], + parameters: parameters, }; (ct, cd) }; @@ -246,62 +256,8 @@ impl NamedFile { pub(crate) fn last_modified(&self) -> Option { self.modified.map(|mtime| mtime.into()) } -} -impl Deref for NamedFile { - type Target = File; - - fn deref(&self) -> &File { - &self.file - } -} - -impl DerefMut for NamedFile { - fn deref_mut(&mut self) -> &mut File { - &mut self.file - } -} - -/// Returns true if `req` has no `If-Match` header or one which matches `etag`. -fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { - match req.get_header::() { - None | Some(header::IfMatch::Any) => true, - Some(header::IfMatch::Items(ref items)) => { - if let Some(some_etag) = etag { - for item in items { - if item.strong_eq(some_etag) { - return true; - } - } - } - false - } - } -} - -/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`. -fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { - match req.get_header::() { - Some(header::IfNoneMatch::Any) => false, - Some(header::IfNoneMatch::Items(ref items)) => { - if let Some(some_etag) = etag { - for item in items { - if item.weak_eq(some_etag) { - return false; - } - } - } - true - } - None => true, - } -} - -impl Responder for NamedFile { - type Error = Error; - type Future = Result; - - fn respond_to(self, req: &HttpRequest) -> Self::Future { + pub fn into_response(self, req: &HttpRequest) -> Result { if self.status_code != StatusCode::OK { let mut resp = HttpResponse::build(self.status_code); resp.set(header::ContentType(self.content_type.clone())) @@ -324,16 +280,6 @@ impl Responder for NamedFile { return Ok(resp.streaming(reader)); } - match *req.method() { - Method::HEAD | Method::GET => (), - _ => { - return Ok(HttpResponse::MethodNotAllowed() - .header(header::CONTENT_TYPE, "text/plain") - .header(header::ALLOW, "GET, HEAD") - .body("This resource only supports GET and HEAD.")); - } - } - let etag = if self.flags.contains(Flags::ETAG) { self.etag() } else { @@ -443,8 +389,67 @@ impl Responder for NamedFile { counter: 0, }; if offset != 0 || length != self.md.len() { - return Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader)); - }; - Ok(resp.body(SizedStream::new(length, reader))) + Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader)) + } else { + Ok(resp.body(SizedStream::new(length, reader))) + } + } +} + +impl Deref for NamedFile { + type Target = File; + + fn deref(&self) -> &File { + &self.file + } +} + +impl DerefMut for NamedFile { + fn deref_mut(&mut self) -> &mut File { + &mut self.file + } +} + +/// Returns true if `req` has no `If-Match` header or one which matches `etag`. +fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { + match req.get_header::() { + None | Some(header::IfMatch::Any) => true, + Some(header::IfMatch::Items(ref items)) => { + if let Some(some_etag) = etag { + for item in items { + if item.strong_eq(some_etag) { + return true; + } + } + } + false + } + } +} + +/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`. +fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { + match req.get_header::() { + Some(header::IfNoneMatch::Any) => false, + Some(header::IfNoneMatch::Items(ref items)) => { + if let Some(some_etag) = etag { + for item in items { + if item.weak_eq(some_etag) { + return false; + } + } + } + true + } + None => true, + } +} + +impl Responder for NamedFile { + type Error = Error; + type Future = Ready>; + + fn respond_to(self, req: &HttpRequest) -> Self::Future { + ready(self.into_response(req)) } } diff --git a/actix-framed/Cargo.toml b/actix-framed/Cargo.toml index 321041c7e..4783daefd 100644 --- a/actix-framed/Cargo.toml +++ b/actix-framed/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-framed" -version = "0.2.1" +version = "0.3.0-alpha.1" authors = ["Nikolay Kim "] description = "Actix framed app server" readme = "README.md" @@ -20,19 +20,20 @@ name = "actix_framed" path = "src/lib.rs" [dependencies] -actix-codec = "0.1.2" -actix-service = "0.4.1" +actix-codec = "0.2.0-alpha.1" +actix-service = "1.0.0-alpha.1" actix-router = "0.1.2" -actix-rt = "0.2.2" -actix-http = "0.2.7" -actix-server-config = "0.1.2" +actix-rt = "1.0.0-alpha.1" +actix-http = "0.3.0-alpha.1" +actix-server-config = "0.3.0-alpha.1" bytes = "0.4" -futures = "0.1.25" +futures = "0.3.1" +pin-project = "0.4.6" log = "0.4" [dev-dependencies] -actix-server = { version = "0.6.0", features=["ssl"] } -actix-connect = { version = "0.2.0", features=["ssl"] } -actix-http-test = { version = "0.2.4", features=["ssl"] } -actix-utils = "0.4.4" +actix-server = { version = "0.8.0-alpha.1", features=["openssl"] } +actix-connect = { version = "0.3.0-alpha.1", features=["openssl"] } +actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] } +actix-utils = "0.5.0-alpha.1" diff --git a/actix-framed/src/app.rs b/actix-framed/src/app.rs index ad5b1ec26..f3e746e9f 100644 --- a/actix-framed/src/app.rs +++ b/actix-framed/src/app.rs @@ -1,21 +1,24 @@ +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::h1::{Codec, SendResponse}; use actix_http::{Error, Request, Response}; use actix_router::{Path, Router, Url}; use actix_server_config::ServerConfig; -use actix_service::{IntoNewService, NewService, Service}; -use futures::{Async, Future, Poll}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; +use futures::future::{ok, FutureExt, LocalBoxFuture}; use crate::helpers::{BoxedHttpNewService, BoxedHttpService, HttpNewService}; use crate::request::FramedRequest; use crate::state::State; -type BoxedResponse = Box>; +type BoxedResponse = LocalBoxFuture<'static, Result<(), Error>>; pub trait HttpServiceFactory { - type Factory: NewService; + type Factory: ServiceFactory; fn path(&self) -> &str; @@ -48,19 +51,19 @@ impl FramedApp { pub fn service(mut self, factory: U) -> Self where U: HttpServiceFactory, - U::Factory: NewService< + U::Factory: ServiceFactory< Config = (), Request = FramedRequest, Response = (), Error = Error, InitError = (), > + 'static, - ::Future: 'static, - ::Service: Service< + ::Future: 'static, + ::Service: Service< Request = FramedRequest, Response = (), Error = Error, - Future = Box>, + Future = LocalBoxFuture<'static, Result<(), Error>>, >, { let path = factory.path().to_string(); @@ -70,12 +73,12 @@ impl FramedApp { } } -impl IntoNewService> for FramedApp +impl IntoServiceFactory> for FramedApp where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, S: 'static, { - fn into_new_service(self) -> FramedAppFactory { + fn into_factory(self) -> FramedAppFactory { FramedAppFactory { state: self.state, services: Rc::new(self.services), @@ -89,9 +92,9 @@ pub struct FramedAppFactory { services: Rc>)>>, } -impl NewService for FramedAppFactory +impl ServiceFactory for FramedAppFactory where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, S: 'static, { type Config = ServerConfig; @@ -128,28 +131,30 @@ pub struct CreateService { enum CreateServiceItem { Future( Option, - Box>, Error = ()>>, + LocalBoxFuture<'static, Result>, ()>>, ), Service(String, BoxedHttpService>), } impl Future for CreateService where - T: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite + Unpin, { - type Item = FramedAppService; - type Error = (); + type Output = Result, ()>; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut done = true; // poll http services for item in &mut self.fut { let res = match item { CreateServiceItem::Future(ref mut path, ref mut fut) => { - match fut.poll()? { - Async::Ready(service) => Some((path.take().unwrap(), service)), - Async::NotReady => { + match Pin::new(fut).poll(cx) { + Poll::Ready(Ok(service)) => { + Some((path.take().unwrap(), service)) + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { done = false; None } @@ -176,12 +181,12 @@ where } router }); - Ok(Async::Ready(FramedAppService { + Poll::Ready(Ok(FramedAppService { router: router.finish(), state: self.state.clone(), })) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -193,15 +198,15 @@ pub struct FramedAppService { impl Service for FramedAppService where - T: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite + Unpin, { type Request = (Request, Framed); type Response = (); type Error = Error; type Future = BoxedResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, (req, framed): (Request, Framed)) -> Self::Future { @@ -210,8 +215,8 @@ where if let Some((srv, _info)) = self.router.recognize_mut(&mut path) { return srv.call(FramedRequest::new(req, framed, path, self.state.clone())); } - Box::new( - SendResponse::new(framed, Response::NotFound().finish()).then(|_| Ok(())), - ) + SendResponse::new(framed, Response::NotFound().finish()) + .then(|_| ok(())) + .boxed_local() } } diff --git a/actix-framed/src/helpers.rs b/actix-framed/src/helpers.rs index b343301f3..b654f9cd7 100644 --- a/actix-framed/src/helpers.rs +++ b/actix-framed/src/helpers.rs @@ -1,36 +1,38 @@ +use std::task::{Context, Poll}; + use actix_http::Error; -use actix_service::{NewService, Service}; -use futures::{Future, Poll}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{FutureExt, LocalBoxFuture}; pub(crate) type BoxedHttpService = Box< dyn Service< Request = Req, Response = (), Error = Error, - Future = Box>, + Future = LocalBoxFuture<'static, Result<(), Error>>, >, >; pub(crate) type BoxedHttpNewService = Box< - dyn NewService< + dyn ServiceFactory< Config = (), Request = Req, Response = (), Error = Error, InitError = (), Service = BoxedHttpService, - Future = Box, Error = ()>>, + Future = LocalBoxFuture<'static, Result, ()>>, >, >; -pub(crate) struct HttpNewService(T); +pub(crate) struct HttpNewService(T); impl HttpNewService where - T: NewService, + T: ServiceFactory, T::Response: 'static, T::Future: 'static, - T::Service: Service>> + 'static, + T::Service: Service>> + 'static, ::Future: 'static, { pub fn new(service: T) -> Self { @@ -38,12 +40,12 @@ where } } -impl NewService for HttpNewService +impl ServiceFactory for HttpNewService where - T: NewService, + T: ServiceFactory, T::Request: 'static, T::Future: 'static, - T::Service: Service>> + 'static, + T::Service: Service>> + 'static, ::Future: 'static, { type Config = (); @@ -52,13 +54,19 @@ where type Error = Error; type InitError = (); type Service = BoxedHttpService; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: &()) -> Self::Future { - Box::new(self.0.new_service(&()).map_err(|_| ()).and_then(|service| { - let service: BoxedHttpService<_> = Box::new(HttpServiceWrapper { service }); - Ok(service) - })) + let fut = self.0.new_service(&()); + + async move { + fut.await.map_err(|_| ()).map(|service| { + let service: BoxedHttpService<_> = + Box::new(HttpServiceWrapper { service }); + service + }) + } + .boxed_local() } } @@ -70,7 +78,7 @@ impl Service for HttpServiceWrapper where T: Service< Response = (), - Future = Box>, + Future = LocalBoxFuture<'static, Result<(), Error>>, Error = Error, >, T::Request: 'static, @@ -78,10 +86,10 @@ where type Request = T::Request; type Response = (); type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result<(), Error>>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: Self::Request) -> Self::Future { diff --git a/actix-framed/src/route.rs b/actix-framed/src/route.rs index 5beb24165..783039684 100644 --- a/actix-framed/src/route.rs +++ b/actix-framed/src/route.rs @@ -1,11 +1,12 @@ use std::fmt; +use std::future::Future; use std::marker::PhantomData; +use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_http::{http::Method, Error}; -use actix_service::{NewService, Service}; -use futures::future::{ok, FutureResult}; -use futures::{Async, Future, IntoFuture, Poll}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use log::error; use crate::app::HttpServiceFactory; @@ -15,11 +16,11 @@ use crate::request::FramedRequest; /// /// Route uses builder-like pattern for configuration. /// If handler is not explicitly set, default *404 Not Found* handler is used. -pub struct FramedRoute { +pub struct FramedRoute { handler: F, pattern: String, methods: Vec, - state: PhantomData<(Io, S, R)>, + state: PhantomData<(Io, S, R, E)>, } impl FramedRoute { @@ -53,12 +54,12 @@ impl FramedRoute { self } - pub fn to(self, handler: F) -> FramedRoute + pub fn to(self, handler: F) -> FramedRoute where F: FnMut(FramedRequest) -> R, - R: IntoFuture, - R::Future: 'static, - R::Error: fmt::Debug, + R: Future> + 'static, + + E: fmt::Debug, { FramedRoute { handler, @@ -69,15 +70,14 @@ impl FramedRoute { } } -impl HttpServiceFactory for FramedRoute +impl HttpServiceFactory for FramedRoute where Io: AsyncRead + AsyncWrite + 'static, F: FnMut(FramedRequest) -> R + Clone, - R: IntoFuture, - R::Future: 'static, - R::Error: fmt::Display, + R: Future> + 'static, + E: fmt::Display, { - type Factory = FramedRouteFactory; + type Factory = FramedRouteFactory; fn path(&self) -> &str { &self.pattern @@ -92,27 +92,26 @@ where } } -pub struct FramedRouteFactory { +pub struct FramedRouteFactory { handler: F, methods: Vec, - _t: PhantomData<(Io, S, R)>, + _t: PhantomData<(Io, S, R, E)>, } -impl NewService for FramedRouteFactory +impl ServiceFactory for FramedRouteFactory where Io: AsyncRead + AsyncWrite + 'static, F: FnMut(FramedRequest) -> R + Clone, - R: IntoFuture, - R::Future: 'static, - R::Error: fmt::Display, + R: Future> + 'static, + E: fmt::Display, { type Config = (); type Request = FramedRequest; type Response = (); type Error = Error; type InitError = (); - type Service = FramedRouteService; - type Future = FutureResult; + type Service = FramedRouteService; + type Future = Ready>; fn new_service(&self, _: &()) -> Self::Future { ok(FramedRouteService { @@ -123,35 +122,38 @@ where } } -pub struct FramedRouteService { +pub struct FramedRouteService { handler: F, methods: Vec, - _t: PhantomData<(Io, S, R)>, + _t: PhantomData<(Io, S, R, E)>, } -impl Service for FramedRouteService +impl Service for FramedRouteService where Io: AsyncRead + AsyncWrite + 'static, F: FnMut(FramedRequest) -> R + Clone, - R: IntoFuture, - R::Future: 'static, - R::Error: fmt::Display, + R: Future> + 'static, + E: fmt::Display, { type Request = FramedRequest; type Response = (); type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result<(), Error>>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, req: FramedRequest) -> Self::Future { - Box::new((self.handler)(req).into_future().then(|res| { + let fut = (self.handler)(req); + + async move { + let res = fut.await; if let Err(e) = res { error!("Error in request handler: {}", e); } Ok(()) - })) + } + .boxed_local() } } diff --git a/actix-framed/src/service.rs b/actix-framed/src/service.rs index fbbc9fbef..ed3a75ff5 100644 --- a/actix-framed/src/service.rs +++ b/actix-framed/src/service.rs @@ -1,4 +1,6 @@ use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::body::BodySize; @@ -6,9 +8,9 @@ use actix_http::error::ResponseError; use actix_http::h1::{Codec, Message}; use actix_http::ws::{verify_handshake, HandshakeError}; use actix_http::{Request, Response}; -use actix_service::{NewService, Service}; -use futures::future::{ok, Either, FutureResult}; -use futures::{Async, Future, IntoFuture, Poll, Sink}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{err, ok, Either, Ready}; +use futures::Future; /// Service that verifies incoming request if it is valid websocket /// upgrade request. In case of error returns `HandshakeError` @@ -22,14 +24,14 @@ impl Default for VerifyWebSockets { } } -impl NewService for VerifyWebSockets { +impl ServiceFactory for VerifyWebSockets { type Config = C; type Request = (Request, Framed); type Response = (Request, Framed); type Error = (HandshakeError, Framed); type InitError = (); type Service = VerifyWebSockets; - type Future = FutureResult; + type Future = Ready>; fn new_service(&self, _: &C) -> Self::Future { ok(VerifyWebSockets { _t: PhantomData }) @@ -40,16 +42,16 @@ impl Service for VerifyWebSockets { type Request = (Request, Framed); type Response = (Request, Framed); type Error = (HandshakeError, Framed); - type Future = FutureResult; + type Future = Ready>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, (req, framed): (Request, Framed)) -> Self::Future { match verify_handshake(req.head()) { - Err(e) => Err((e, framed)).into_future(), - Ok(_) => Ok((req, framed)).into_future(), + Err(e) => err((e, framed)), + Ok(_) => ok((req, framed)), } } } @@ -67,9 +69,9 @@ where } } -impl NewService for SendError +impl ServiceFactory for SendError where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, R: 'static, E: ResponseError + 'static, { @@ -79,7 +81,7 @@ where type Error = (E, Framed); type InitError = (); type Service = SendError; - type Future = FutureResult; + type Future = Ready>; fn new_service(&self, _: &C) -> Self::Future { ok(SendError(PhantomData)) @@ -88,25 +90,25 @@ where impl Service for SendError where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, R: 'static, E: ResponseError + 'static, { type Request = Result)>; type Response = R; type Error = (E, Framed); - type Future = Either)>, SendErrorFut>; + type Future = Either)>>, SendErrorFut>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, req: Result)>) -> Self::Future { match req { - Ok(r) => Either::A(ok(r)), + Ok(r) => Either::Left(ok(r)), Err((e, framed)) => { let res = e.error_response().drop_body(); - Either::B(SendErrorFut { + Either::Right(SendErrorFut { framed: Some(framed), res: Some((res, BodySize::Empty).into()), err: Some(e), @@ -117,6 +119,7 @@ where } } +#[pin_project::pin_project] pub struct SendErrorFut { res: Option, BodySize)>>, framed: Option>, @@ -127,23 +130,27 @@ pub struct SendErrorFut { impl Future for SendErrorFut where E: ResponseError, - T: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite + Unpin, { - type Item = R; - type Error = (E, Framed); + type Output = Result)>; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { if let Some(res) = self.res.take() { - if self.framed.as_mut().unwrap().force_send(res).is_err() { - return Err((self.err.take().unwrap(), self.framed.take().unwrap())); + if self.framed.as_mut().unwrap().write(res).is_err() { + return Poll::Ready(Err(( + self.err.take().unwrap(), + self.framed.take().unwrap(), + ))); } } - match self.framed.as_mut().unwrap().poll_complete() { - Ok(Async::Ready(_)) => { - Err((self.err.take().unwrap(), self.framed.take().unwrap())) + match self.framed.as_mut().unwrap().flush(cx) { + Poll::Ready(Ok(_)) => { + Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap()))) } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(_) => Err((self.err.take().unwrap(), self.framed.take().unwrap())), + Poll::Ready(Err(_)) => { + Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap()))) + } + Poll::Pending => Poll::Pending, } } } diff --git a/actix-framed/src/test.rs b/actix-framed/src/test.rs index 3bc828df4..b90a493dc 100644 --- a/actix-framed/src/test.rs +++ b/actix-framed/src/test.rs @@ -1,4 +1,6 @@ //! Various helpers for Actix applications to use during testing. +use std::future::Future; + use actix_codec::Framed; use actix_http::h1::Codec; use actix_http::http::header::{Header, HeaderName, IntoHeaderValue}; @@ -6,7 +8,6 @@ use actix_http::http::{HttpTryFrom, Method, Uri, Version}; use actix_http::test::{TestBuffer, TestRequest as HttpTestRequest}; use actix_router::{Path, Url}; use actix_rt::Runtime; -use futures::IntoFuture; use crate::{FramedRequest, State}; @@ -121,10 +122,10 @@ impl TestRequest { pub fn run(self, f: F) -> Result where F: FnOnce(FramedRequest) -> R, - R: IntoFuture, + R: Future>, { let mut rt = Runtime::new().unwrap(); - rt.block_on(f(self.finish()).into_future()) + rt.block_on(f(self.finish())) } } diff --git a/actix-framed/tests/test_server.rs b/actix-framed/tests/test_server.rs index 00f6a97d8..6e4bb6ada 100644 --- a/actix-framed/tests/test_server.rs +++ b/actix-framed/tests/test_server.rs @@ -1,30 +1,31 @@ use actix_codec::{AsyncRead, AsyncWrite}; use actix_http::{body, http::StatusCode, ws, Error, HttpService, Response}; -use actix_http_test::TestServer; -use actix_service::{IntoNewService, NewService}; +use actix_http_test::{block_on, TestServer}; +use actix_service::{pipeline_factory, IntoServiceFactory, ServiceFactory}; use actix_utils::framed::FramedTransport; use bytes::{Bytes, BytesMut}; -use futures::future::{self, ok}; -use futures::{Future, Sink, Stream}; +use futures::{future, SinkExt, StreamExt}; use actix_framed::{FramedApp, FramedRequest, FramedRoute, SendError, VerifyWebSockets}; -fn ws_service( +async fn ws_service( req: FramedRequest, -) -> impl Future { - let (req, framed, _) = req.into_parts(); +) -> Result<(), Error> { + let (req, mut framed, _) = req.into_parts(); let res = ws::handshake(req.head()).unwrap().message_body(()); framed .send((res, body::BodySize::None).into()) - .map_err(|_| panic!()) - .and_then(|framed| { - FramedTransport::new(framed.into_framed(ws::Codec::new()), service) - .map_err(|_| panic!()) - }) + .await + .unwrap(); + FramedTransport::new(framed.into_framed(ws::Codec::new()), service) + .await + .unwrap(); + + Ok(()) } -fn service(msg: ws::Frame) -> impl Future { +async fn service(msg: ws::Frame) -> Result { let msg = match msg { ws::Frame::Ping(msg) => ws::Message::Pong(msg), ws::Frame::Text(text) => { @@ -34,108 +35,129 @@ fn service(msg: ws::Frame) -> impl Future { ws::Frame::Close(reason) => ws::Message::Close(reason), _ => panic!(), }; - ok(msg) + Ok(msg) } #[test] fn test_simple() { - let mut srv = TestServer::new(|| { - HttpService::build() - .upgrade( - FramedApp::new().service(FramedRoute::get("/index.html").to(ws_service)), - ) - .finish(|_| future::ok::<_, Error>(Response::NotFound())) - }); + block_on(async { + let mut srv = TestServer::start(|| { + HttpService::build() + .upgrade( + FramedApp::new() + .service(FramedRoute::get("/index.html").to(ws_service)), + ) + .finish(|_| future::ok::<_, Error>(Response::NotFound())) + }); - assert!(srv.ws_at("/test").is_err()); + assert!(srv.ws_at("/test").await.is_err()); - // client service - let framed = srv.ws_at("/index.html").unwrap(); - let framed = srv - .block_on(framed.send(ws::Message::Text("text".to_string()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); + // client service + let mut framed = srv.ws_at("/index.html").await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Text(Some(BytesMut::from("text"))) + ); - let framed = srv - .block_on(framed.send(ws::Message::Binary("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!( - item, - Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) - ); + framed + .send(ws::Message::Binary("text".into())) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Binary(Some(Bytes::from_static(b"text").into())) + ); - let framed = srv - .block_on(framed.send(ws::Message::Ping("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Pong("text".to_string().into()) + ); - let framed = srv - .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) - .unwrap(); + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); - let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!( - item, - Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) - ); + let (item, _) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Close(Some(ws::CloseCode::Normal.into())) + ); + }) } #[test] fn test_service() { - let mut srv = TestServer::new(|| { - actix_http::h1::OneRequest::new().map_err(|_| ()).and_then( - VerifyWebSockets::default() - .then(SendError::default()) - .map_err(|_| ()) + block_on(async { + let mut srv = TestServer::start(|| { + pipeline_factory(actix_http::h1::OneRequest::new().map_err(|_| ())).and_then( + pipeline_factory( + pipeline_factory(VerifyWebSockets::default()) + .then(SendError::default()) + .map_err(|_| ()), + ) .and_then( FramedApp::new() .service(FramedRoute::get("/index.html").to(ws_service)) - .into_new_service() + .into_factory() .map_err(|_| ()), ), - ) - }); + ) + }); - // non ws request - let res = srv.block_on(srv.get("/index.html").send()).unwrap(); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); + // non ws request + let res = srv.get("/index.html").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); - // not found - assert!(srv.ws_at("/test").is_err()); + // not found + assert!(srv.ws_at("/test").await.is_err()); - // client service - let framed = srv.ws_at("/index.html").unwrap(); - let framed = srv - .block_on(framed.send(ws::Message::Text("text".to_string()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); + // client service + let mut framed = srv.ws_at("/index.html").await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Text(Some(BytesMut::from("text"))) + ); - let framed = srv - .block_on(framed.send(ws::Message::Binary("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!( - item, - Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) - ); + framed + .send(ws::Message::Binary("text".into())) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Binary(Some(Bytes::from_static(b"text").into())) + ); - let framed = srv - .block_on(framed.send(ws::Message::Ping("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Pong("text".to_string().into()) + ); - let framed = srv - .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) - .unwrap(); + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); - let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!( - item, - Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) - ); + let (item, _) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Close(Some(ws::CloseCode::Normal.into())) + ); + }) } diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 06756033f..4cb5644c3 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,11 +1,19 @@ # Changes -## Not released yet +## [0.2.11] - 2019-11-06 ### Added * Add support for serde_json::Value to be passed as argument to ResponseBuilder.body() +* Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151) + +* Allow to use `std::convert::Infallible` as `actix_http::error::Error` + +### Fixed + +* To be compatible with non-English error responses, `ResponseError` rendered with `text/plain; charset=utf-8` header #1118 + ## [0.2.10] - 2019-09-11 @@ -13,9 +21,6 @@ * Add support for sending HTTP requests with `Rc` in addition to sending HTTP requests with `RequestHead` -* Allow to use `std::convert::Infallible` as `actix_http::error::Error` - - ### Fixed * h2 will use error response #1080 diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index 72c78303f..a375fa42c 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-http" -version = "0.2.10" +version = "0.3.0-alpha.1" authors = ["Nikolay Kim "] description = "Actix http primitives" readme = "README.md" @@ -16,7 +16,7 @@ edition = "2018" workspace = ".." [package.metadata.docs.rs] -features = ["ssl", "fail", "brotli", "flate2-zlib", "secure-cookies"] +features = ["openssl", "fail", "brotli", "flate2-zlib", "secure-cookies"] [lib] name = "actix_http" @@ -27,10 +27,10 @@ default = ["invalid-ping-payload"] invalid-ping-payload = [] # openssl -ssl = ["openssl", "actix-connect/ssl"] +openssl = ["open-ssl", "actix-connect/openssl", "tokio-openssl"] # rustls support -rust-tls = ["rustls", "webpki-roots", "actix-connect/rust-tls"] +# rustls = ["rust-tls", "webpki-roots", "actix-connect/rustls"] # brotli encoding, requires c compiler brotli = ["brotli2"] @@ -48,23 +48,24 @@ fail = ["failure"] secure-cookies = ["ring"] [dependencies] -actix-service = "0.4.1" -actix-codec = "0.1.2" -actix-connect = "0.2.4" -actix-utils = "0.4.4" -actix-server-config = "0.1.2" -actix-threadpool = "0.1.1" +actix-service = "1.0.0-alpha.1" +actix-codec = "0.2.0-alpha.1" +actix-connect = "1.0.0-alpha.1" +actix-utils = "0.5.0-alpha.1" +actix-server-config = "0.3.0-alpha.1" +actix-threadpool = "0.2.0-alpha.1" base64 = "0.10" bitflags = "1.0" bytes = "0.4" copyless = "0.1.4" +chrono = "0.4.6" derive_more = "0.15.0" either = "1.5.2" encoding_rs = "0.8" -futures = "0.1.25" -hashbrown = "0.5.0" -h2 = "0.1.16" +futures = "0.3.1" +hashbrown = "0.6.3" +h2 = "0.2.0-alpha.3" http = "0.1.17" httparse = "1.3" indexmap = "1.2" @@ -73,6 +74,7 @@ language-tags = "0.2" log = "0.4" mime = "0.3" percent-encoding = "2.1" +pin-project = "0.4.5" rand = "0.7" regex = "1.0" serde = "1.0" @@ -81,13 +83,16 @@ sha1 = "0.6" slab = "0.4" serde_urlencoded = "0.6.1" time = "0.1.42" -tokio-tcp = "0.1.3" -tokio-timer = "0.2.8" -tokio-current-thread = "0.1" -trust-dns-resolver = { version="0.11.1", default-features = false } + +tokio = "=0.2.0-alpha.6" +tokio-io = "=0.2.0-alpha.6" +tokio-net = "=0.2.0-alpha.6" +tokio-timer = "0.3.0-alpha.6" +tokio-executor = "=0.2.0-alpha.6" +trust-dns-resolver = { version="0.18.0-alpha.1", default-features = false } # for secure cookie -ring = { version = "0.14.6", optional = true } +ring = { version = "0.16.9", optional = true } # compression brotli2 = { version="0.3.2", optional = true } @@ -95,17 +100,17 @@ flate2 = { version="1.0.7", optional = true, default-features = false } # optional deps failure = { version = "0.1.5", optional = true } -openssl = { version="0.10", optional = true } -rustls = { version = "0.15.2", optional = true } -webpki-roots = { version = "0.16", optional = true } -chrono = "0.4.6" +open-ssl = { version="0.10", package="openssl", optional = true } +tokio-openssl = { version = "0.4.0-alpha.6", optional = true } + +rust-tls = { version = "0.16.0", package="rustls", optional = true } +webpki-roots = { version = "0.18", optional = true } [dev-dependencies] -actix-rt = "0.2.2" -actix-server = { version = "0.6.0", features=["ssl", "rust-tls"] } -actix-connect = { version = "0.2.0", features=["ssl"] } -actix-http-test = { version = "0.2.4", features=["ssl"] } +actix-rt = "1.0.0-alpha.1" +actix-server = { version = "0.8.0-alpha.1", features=["openssl"] } +actix-connect = { version = "1.0.0-alpha.1", features=["openssl"] } +actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] } env_logger = "0.6" serde_derive = "1.0" -openssl = { version="0.10" } -tokio-tcp = "0.1" +open-ssl = { version="0.10", package="openssl" } diff --git a/actix-http/examples/echo.rs b/actix-http/examples/echo.rs index c36292c44..ba81020ca 100644 --- a/actix-http/examples/echo.rs +++ b/actix-http/examples/echo.rs @@ -1,9 +1,9 @@ use std::{env, io}; -use actix_http::{error::PayloadError, HttpService, Request, Response}; +use actix_http::{Error, HttpService, Request, Response}; use actix_server::Server; use bytes::BytesMut; -use futures::{Future, Stream}; +use futures::StreamExt; use http::header::HeaderValue; use log::info; @@ -17,20 +17,22 @@ fn main() -> io::Result<()> { .client_timeout(1000) .client_disconnect(1000) .finish(|mut req: Request| { - req.take_payload() - .fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, PayloadError>(body) - }) - .and_then(|bytes| { - info!("request body: {:?}", bytes); - let mut res = Response::Ok(); - res.header( - "x-head", - HeaderValue::from_static("dummy value!"), - ); - Ok(res.body(bytes)) - }) + async move { + let mut body = BytesMut::new(); + while let Some(item) = req.payload().next().await { + body.extend_from_slice(&item?); + } + + info!("request body: {:?}", body); + Ok::<_, Error>( + Response::Ok() + .header( + "x-head", + HeaderValue::from_static("dummy value!"), + ) + .body(body), + ) + } }) })? .run() diff --git a/actix-http/examples/echo2.rs b/actix-http/examples/echo2.rs index b239796b4..3776c7d58 100644 --- a/actix-http/examples/echo2.rs +++ b/actix-http/examples/echo2.rs @@ -1,25 +1,22 @@ use std::{env, io}; use actix_http::http::HeaderValue; -use actix_http::{error::PayloadError, Error, HttpService, Request, Response}; +use actix_http::{Error, HttpService, Request, Response}; use actix_server::Server; use bytes::BytesMut; -use futures::{Future, Stream}; +use futures::StreamExt; use log::info; -fn handle_request(mut req: Request) -> impl Future { - req.take_payload() - .fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, PayloadError>(body) - }) - .from_err() - .and_then(|bytes| { - info!("request body: {:?}", bytes); - let mut res = Response::Ok(); - res.header("x-head", HeaderValue::from_static("dummy value!")); - Ok(res.body(bytes)) - }) +async fn handle_request(mut req: Request) -> Result { + let mut body = BytesMut::new(); + while let Some(item) = req.payload().next().await { + body.extend_from_slice(&item?) + } + + info!("request body: {:?}", body); + Ok(Response::Ok() + .header("x-head", HeaderValue::from_static("dummy value!")) + .body(body)) } fn main() -> io::Result<()> { @@ -28,7 +25,7 @@ fn main() -> io::Result<()> { Server::build() .bind("echo", "127.0.0.1:8080", || { - HttpService::build().finish(|_req: Request| handle_request(_req)) + HttpService::build().finish(handle_request) })? .run() } diff --git a/actix-http/src/body.rs b/actix-http/src/body.rs index b761738e1..1d3a43fe5 100644 --- a/actix-http/src/body.rs +++ b/actix-http/src/body.rs @@ -1,8 +1,11 @@ use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{fmt, mem}; use bytes::{Bytes, BytesMut}; -use futures::{Async, Poll, Stream}; +use futures::Stream; +use pin_project::{pin_project, project}; use crate::error::Error; @@ -32,7 +35,7 @@ impl BodySize { pub trait MessageBody { fn size(&self) -> BodySize; - fn poll_next(&mut self) -> Poll, Error>; + fn poll_next(&mut self, cx: &mut Context) -> Poll>>; } impl MessageBody for () { @@ -40,8 +43,8 @@ impl MessageBody for () { BodySize::Empty } - fn poll_next(&mut self) -> Poll, Error> { - Ok(Async::Ready(None)) + fn poll_next(&mut self, _: &mut Context) -> Poll>> { + Poll::Ready(None) } } @@ -50,11 +53,12 @@ impl MessageBody for Box { self.as_ref().size() } - fn poll_next(&mut self) -> Poll, Error> { - self.as_mut().poll_next() + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { + self.as_mut().poll_next(cx) } } +#[pin_project] pub enum ResponseBody { Body(B), Other(Body), @@ -93,20 +97,24 @@ impl MessageBody for ResponseBody { } } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { match self { - ResponseBody::Body(ref mut body) => body.poll_next(), - ResponseBody::Other(ref mut body) => body.poll_next(), + ResponseBody::Body(ref mut body) => body.poll_next(cx), + ResponseBody::Other(ref mut body) => body.poll_next(cx), } } } impl Stream for ResponseBody { - type Item = Bytes; - type Error = Error; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - self.poll_next() + #[project] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + #[project] + match self.project() { + ResponseBody::Body(ref mut body) => body.poll_next(cx), + ResponseBody::Other(ref mut body) => body.poll_next(cx), + } } } @@ -144,19 +152,19 @@ impl MessageBody for Body { } } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { match self { - Body::None => Ok(Async::Ready(None)), - Body::Empty => Ok(Async::Ready(None)), + Body::None => Poll::Ready(None), + Body::Empty => Poll::Ready(None), Body::Bytes(ref mut bin) => { let len = bin.len(); if len == 0 { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some(mem::replace(bin, Bytes::new())))) + Poll::Ready(Some(Ok(mem::replace(bin, Bytes::new())))) } } - Body::Message(ref mut body) => body.poll_next(), + Body::Message(ref mut body) => body.poll_next(cx), } } } @@ -242,7 +250,7 @@ impl From for Body { impl From> for Body where - S: Stream + 'static, + S: Stream> + 'static, { fn from(s: SizedStream) -> Body { Body::from_message(s) @@ -251,7 +259,7 @@ where impl From> for Body where - S: Stream + 'static, + S: Stream> + 'static, E: Into + 'static, { fn from(s: BodyStream) -> Body { @@ -264,11 +272,11 @@ impl MessageBody for Bytes { BodySize::Sized(self.len()) } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, _: &mut Context) -> Poll>> { if self.is_empty() { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some(mem::replace(self, Bytes::new())))) + Poll::Ready(Some(Ok(mem::replace(self, Bytes::new())))) } } } @@ -278,13 +286,11 @@ impl MessageBody for BytesMut { BodySize::Sized(self.len()) } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, _: &mut Context) -> Poll>> { if self.is_empty() { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some( - mem::replace(self, BytesMut::new()).freeze(), - ))) + Poll::Ready(Some(Ok(mem::replace(self, BytesMut::new()).freeze()))) } } } @@ -294,11 +300,11 @@ impl MessageBody for &'static str { BodySize::Sized(self.len()) } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, _: &mut Context) -> Poll>> { if self.is_empty() { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some(Bytes::from_static( + Poll::Ready(Some(Ok(Bytes::from_static( mem::replace(self, "").as_ref(), )))) } @@ -310,13 +316,11 @@ impl MessageBody for &'static [u8] { BodySize::Sized(self.len()) } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, _: &mut Context) -> Poll>> { if self.is_empty() { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some(Bytes::from_static(mem::replace( - self, b"", - ))))) + Poll::Ready(Some(Ok(Bytes::from_static(mem::replace(self, b""))))) } } } @@ -326,14 +330,11 @@ impl MessageBody for Vec { BodySize::Sized(self.len()) } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, _: &mut Context) -> Poll>> { if self.is_empty() { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some(Bytes::from(mem::replace( - self, - Vec::new(), - ))))) + Poll::Ready(Some(Ok(Bytes::from(mem::replace(self, Vec::new()))))) } } } @@ -343,11 +344,11 @@ impl MessageBody for String { BodySize::Sized(self.len()) } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, _: &mut Context) -> Poll>> { if self.is_empty() { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some(Bytes::from( + Poll::Ready(Some(Ok(Bytes::from( mem::replace(self, String::new()).into_bytes(), )))) } @@ -356,14 +357,16 @@ impl MessageBody for String { /// Type represent streaming body. /// Response does not contain `content-length` header and appropriate transfer encoding is used. +#[pin_project] pub struct BodyStream { + #[pin] stream: S, _t: PhantomData, } impl BodyStream where - S: Stream, + S: Stream>, E: Into, { pub fn new(stream: S) -> Self { @@ -376,28 +379,34 @@ where impl MessageBody for BodyStream where - S: Stream, + S: Stream>, E: Into, { fn size(&self) -> BodySize { BodySize::Stream } - fn poll_next(&mut self) -> Poll, Error> { - self.stream.poll().map_err(std::convert::Into::into) + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { + unsafe { Pin::new_unchecked(self) } + .project() + .stream + .poll_next(cx) + .map(|res| res.map(|res| res.map_err(std::convert::Into::into))) } } /// Type represent streaming body. This body implementation should be used /// if total size of stream is known. Data get sent as is without using transfer encoding. +#[pin_project] pub struct SizedStream { size: u64, + #[pin] stream: S, } impl SizedStream where - S: Stream, + S: Stream>, { pub fn new(size: u64, stream: S) -> Self { SizedStream { size, stream } @@ -406,20 +415,25 @@ where impl MessageBody for SizedStream where - S: Stream, + S: Stream>, { fn size(&self) -> BodySize { BodySize::Sized64(self.size) } - fn poll_next(&mut self) -> Poll, Error> { - self.stream.poll() + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { + unsafe { Pin::new_unchecked(self) } + .project() + .stream + .poll_next(cx) } } #[cfg(test)] mod tests { use super::*; + use actix_http_test::block_on; + use futures::future::{lazy, poll_fn}; impl Body { pub(crate) fn get_ref(&self) -> &[u8] { @@ -447,8 +461,8 @@ mod tests { assert_eq!("test".size(), BodySize::Sized(4)); assert_eq!( - "test".poll_next().unwrap(), - Async::Ready(Some(Bytes::from("test"))) + block_on(poll_fn(|cx| "test".poll_next(cx))).unwrap().ok(), + Some(Bytes::from("test")) ); } @@ -464,8 +478,10 @@ mod tests { assert_eq!((&b"test"[..]).size(), BodySize::Sized(4)); assert_eq!( - (&b"test"[..]).poll_next().unwrap(), - Async::Ready(Some(Bytes::from("test"))) + block_on(poll_fn(|cx| (&b"test"[..]).poll_next(cx))) + .unwrap() + .ok(), + Some(Bytes::from("test")) ); } @@ -476,8 +492,10 @@ mod tests { assert_eq!(Vec::from("test").size(), BodySize::Sized(4)); assert_eq!( - Vec::from("test").poll_next().unwrap(), - Async::Ready(Some(Bytes::from("test"))) + block_on(poll_fn(|cx| Vec::from("test").poll_next(cx))) + .unwrap() + .ok(), + Some(Bytes::from("test")) ); } @@ -489,8 +507,8 @@ mod tests { assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!( - b.poll_next().unwrap(), - Async::Ready(Some(Bytes::from("test"))) + block_on(poll_fn(|cx| b.poll_next(cx))).unwrap().ok(), + Some(Bytes::from("test")) ); } @@ -502,8 +520,8 @@ mod tests { assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!( - b.poll_next().unwrap(), - Async::Ready(Some(Bytes::from("test"))) + block_on(poll_fn(|cx| b.poll_next(cx))).unwrap().ok(), + Some(Bytes::from("test")) ); } @@ -517,22 +535,22 @@ mod tests { assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!( - b.poll_next().unwrap(), - Async::Ready(Some(Bytes::from("test"))) + block_on(poll_fn(|cx| b.poll_next(cx))).unwrap().ok(), + Some(Bytes::from("test")) ); } #[test] fn test_unit() { assert_eq!(().size(), BodySize::Empty); - assert_eq!(().poll_next().unwrap(), Async::Ready(None)); + assert!(block_on(poll_fn(|cx| ().poll_next(cx))).is_none()); } #[test] fn test_box() { let mut val = Box::new(()); assert_eq!(val.size(), BodySize::Empty); - assert_eq!(val.poll_next().unwrap(), Async::Ready(None)); + assert!(block_on(poll_fn(|cx| val.poll_next(cx))).is_none()); } #[test] diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index cd23b7265..7e1dae58f 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -4,7 +4,7 @@ use std::rc::Rc; use actix_codec::Framed; use actix_server_config::ServerConfig as SrvConfig; -use actix_service::{IntoNewService, NewService, Service}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use crate::body::MessageBody; use crate::config::{KeepAlive, ServiceConfig}; @@ -32,9 +32,10 @@ pub struct HttpServiceBuilder> { impl HttpServiceBuilder> where - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, + ::Future: 'static, { /// Create instance of `ServiceConfigBuilder` pub fn new() -> Self { @@ -52,19 +53,22 @@ where impl HttpServiceBuilder where - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - X: NewService, + ::Future: 'static, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: NewService< + ::Future: 'static, + U: ServiceFactory< Config = SrvConfig, Request = (Request, Framed), Response = (), >, U::Error: fmt::Display, U::InitError: fmt::Debug, + ::Future: 'static, { /// Set server keep-alive setting. /// @@ -108,16 +112,17 @@ where /// request will be forwarded to main service. pub fn expect(self, expect: F) -> HttpServiceBuilder where - F: IntoNewService, - X1: NewService, + F: IntoServiceFactory, + X1: ServiceFactory, X1::Error: Into, X1::InitError: fmt::Debug, + ::Future: 'static, { HttpServiceBuilder { keep_alive: self.keep_alive, client_timeout: self.client_timeout, client_disconnect: self.client_disconnect, - expect: expect.into_new_service(), + expect: expect.into_factory(), upgrade: self.upgrade, on_connect: self.on_connect, _t: PhantomData, @@ -130,21 +135,22 @@ where /// and this service get called with original request and framed object. pub fn upgrade(self, upgrade: F) -> HttpServiceBuilder where - F: IntoNewService, - U1: NewService< + F: IntoServiceFactory, + U1: ServiceFactory< Config = SrvConfig, Request = (Request, Framed), Response = (), >, U1::Error: fmt::Display, U1::InitError: fmt::Debug, + ::Future: 'static, { HttpServiceBuilder { keep_alive: self.keep_alive, client_timeout: self.client_timeout, client_disconnect: self.client_disconnect, expect: self.expect, - upgrade: Some(upgrade.into_new_service()), + upgrade: Some(upgrade.into_factory()), on_connect: self.on_connect, _t: PhantomData, } @@ -166,8 +172,8 @@ where /// Finish service configuration and create *http service* for HTTP/1 protocol. pub fn h1(self, service: F) -> H1Service where - B: MessageBody + 'static, - F: IntoNewService, + B: MessageBody, + F: IntoServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, @@ -177,7 +183,7 @@ where self.client_timeout, self.client_disconnect, ); - H1Service::with_config(cfg, service.into_new_service()) + H1Service::with_config(cfg, service.into_factory()) .expect(self.expect) .upgrade(self.upgrade) .on_connect(self.on_connect) @@ -187,10 +193,10 @@ where pub fn h2(self, service: F) -> H2Service where B: MessageBody + 'static, - F: IntoNewService, - S::Error: Into, + F: IntoServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, ::Future: 'static, { let cfg = ServiceConfig::new( @@ -198,18 +204,17 @@ where self.client_timeout, self.client_disconnect, ); - H2Service::with_config(cfg, service.into_new_service()) - .on_connect(self.on_connect) + H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect) } /// Finish service configuration and create `HttpService` instance. pub fn finish(self, service: F) -> HttpService where B: MessageBody + 'static, - F: IntoNewService, - S::Error: Into, + F: IntoServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, ::Future: 'static, { let cfg = ServiceConfig::new( @@ -217,7 +222,7 @@ where self.client_timeout, self.client_disconnect, ); - HttpService::with_config(cfg, service.into_new_service()) + HttpService::with_config(cfg, service.into_factory()) .expect(self.expect) .upgrade(self.upgrade) .on_connect(self.on_connect) diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs index d2b94b3e5..75d393b1b 100644 --- a/actix-http/src/client/connection.rs +++ b/actix-http/src/client/connection.rs @@ -1,10 +1,12 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{fmt, io, time}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use bytes::{Buf, Bytes}; -use futures::future::{err, Either, Future, FutureResult}; -use futures::Poll; +use futures::future::{err, Either, Future, FutureExt, LocalBoxFuture, Ready}; use h2::client::SendRequest; +use pin_project::{pin_project, project}; use crate::body::MessageBody; use crate::h1::ClientCodec; @@ -21,8 +23,8 @@ pub(crate) enum ConnectionType { } pub trait Connection { - type Io: AsyncRead + AsyncWrite; - type Future: Future; + type Io: AsyncRead + AsyncWrite + Unpin; + type Future: Future>; fn protocol(&self) -> Protocol; @@ -34,8 +36,7 @@ pub trait Connection { ) -> Self::Future; type TunnelFuture: Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, + Output = Result<(ResponseHead, Framed), SendRequestError>, >; /// Send request, returns Response and Framed @@ -71,7 +72,7 @@ where } } -impl IoConnection { +impl IoConnection { pub(crate) fn new( io: ConnectionType, created: time::Instant, @@ -91,11 +92,11 @@ impl IoConnection { impl Connection for IoConnection where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, { type Io = T; type Future = - Box>; + LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; fn protocol(&self) -> Protocol { match self.io { @@ -111,38 +112,30 @@ where body: B, ) -> Self::Future { match self.io.take().unwrap() { - ConnectionType::H1(io) => Box::new(h1proto::send_request( - io, - head.into(), - body, - self.created, - self.pool, - )), - ConnectionType::H2(io) => Box::new(h2proto::send_request( - io, - head.into(), - body, - self.created, - self.pool, - )), + ConnectionType::H1(io) => { + h1proto::send_request(io, head.into(), body, self.created, self.pool) + .boxed_local() + } + ConnectionType::H2(io) => { + h2proto::send_request(io, head.into(), body, self.created, self.pool) + .boxed_local() + } } } type TunnelFuture = Either< - Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, - >, + LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, >, - FutureResult<(ResponseHead, Framed), SendRequestError>, + Ready), SendRequestError>>, >; /// Send request, returns Response and Framed fn open_tunnel>(mut self, head: H) -> Self::TunnelFuture { match self.io.take().unwrap() { ConnectionType::H1(io) => { - Either::A(Box::new(h1proto::open_tunnel(io, head.into()))) + Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local()) } ConnectionType::H2(io) => { if let Some(mut pool) = self.pool.take() { @@ -152,7 +145,7 @@ where None, )); } - Either::B(err(SendRequestError::TunnelNotSupported)) + Either::Right(err(SendRequestError::TunnelNotSupported)) } } } @@ -166,12 +159,12 @@ pub(crate) enum EitherConnection { impl Connection for EitherConnection where - A: AsyncRead + AsyncWrite + 'static, - B: AsyncRead + AsyncWrite + 'static, + A: AsyncRead + AsyncWrite + Unpin + 'static, + B: AsyncRead + AsyncWrite + Unpin + 'static, { type Io = EitherIo; type Future = - Box>; + LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; fn protocol(&self) -> Protocol { match self { @@ -191,44 +184,30 @@ where } } - type TunnelFuture = Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, - >, + type TunnelFuture = LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, >; /// Send request, returns Response and Framed fn open_tunnel>(self, head: H) -> Self::TunnelFuture { match self { - EitherConnection::A(con) => Box::new( - con.open_tunnel(head) - .map(|(head, framed)| (head, framed.map_io(EitherIo::A))), - ), - EitherConnection::B(con) => Box::new( - con.open_tunnel(head) - .map(|(head, framed)| (head, framed.map_io(EitherIo::B))), - ), + EitherConnection::A(con) => con + .open_tunnel(head) + .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::A)))) + .boxed_local(), + EitherConnection::B(con) => con + .open_tunnel(head) + .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::B)))) + .boxed_local(), } } } +#[pin_project] pub enum EitherIo { - A(A), - B(B), -} - -impl io::Read for EitherIo -where - A: io::Read, - B: io::Read, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self { - EitherIo::A(ref mut val) => val.read(buf), - EitherIo::B(ref mut val) => val.read(buf), - } - } + A(#[pin] A), + B(#[pin] B), } impl AsyncRead for EitherIo @@ -236,6 +215,19 @@ where A: AsyncRead, B: AsyncRead, { + #[project] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_read(cx, buf), + EitherIo::B(val) => val.poll_read(cx, buf), + } + } + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { match self { EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf), @@ -244,45 +236,58 @@ where } } -impl io::Write for EitherIo -where - A: io::Write, - B: io::Write, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - EitherIo::A(ref mut val) => val.write(buf), - EitherIo::B(ref mut val) => val.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - EitherIo::A(ref mut val) => val.flush(), - EitherIo::B(ref mut val) => val.flush(), - } - } -} - impl AsyncWrite for EitherIo where A: AsyncWrite, B: AsyncWrite, { - fn shutdown(&mut self) -> Poll<(), io::Error> { - match self { - EitherIo::A(ref mut val) => val.shutdown(), - EitherIo::B(ref mut val) => val.shutdown(), + #[project] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_write(cx, buf), + EitherIo::B(val) => val.poll_write(cx, buf), } } - fn write_buf(&mut self, buf: &mut U) -> Poll + #[project] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_flush(cx), + EitherIo::B(val) => val.poll_flush(cx), + } + } + + #[project] + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_shutdown(cx), + EitherIo::B(val) => val.poll_shutdown(cx), + } + } + + #[project] + fn poll_write_buf( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut U, + ) -> Poll> where Self: Sized, { - match self { - EitherIo::A(ref mut val) => val.write_buf(buf), - EitherIo::B(ref mut val) => val.write_buf(buf), + #[project] + match self.project() { + EitherIo::A(val) => val.poll_write_buf(cx, buf), + EitherIo::B(val) => val.poll_write_buf(cx, buf), } } } diff --git a/actix-http/src/client/connector.rs b/actix-http/src/client/connector.rs index 98e8618c3..1895f5306 100644 --- a/actix-http/src/client/connector.rs +++ b/actix-http/src/client/connector.rs @@ -1,37 +1,41 @@ use std::fmt; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; use actix_codec::{AsyncRead, AsyncWrite}; use actix_connect::{ default_connector, Connect as TcpConnect, Connection as TcpConnection, }; -use actix_service::{apply_fn, Service, ServiceExt}; +use actix_service::{apply_fn, Service}; use actix_utils::timeout::{TimeoutError, TimeoutService}; +use futures::future::Ready; use http::Uri; -use tokio_tcp::TcpStream; +use tokio_net::tcp::TcpStream; use super::connection::Connection; use super::error::ConnectError; use super::pool::{ConnectionPool, Protocol}; use super::Connect; -#[cfg(feature = "ssl")] -use openssl::ssl::SslConnector as OpensslConnector; +#[cfg(feature = "openssl")] +use open_ssl::ssl::SslConnector as OpensslConnector; -#[cfg(feature = "rust-tls")] -use rustls::ClientConfig; -#[cfg(feature = "rust-tls")] +#[cfg(feature = "rustls")] +use rust_tls::ClientConfig; +#[cfg(feature = "rustls")] use std::sync::Arc; -#[cfg(any(feature = "ssl", feature = "rust-tls"))] +#[cfg(any(feature = "openssl", feature = "rustls"))] enum SslConnector { - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] Openssl(OpensslConnector), - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] Rustls(Arc), } -#[cfg(not(any(feature = "ssl", feature = "rust-tls")))] +#[cfg(not(any(feature = "openssl", feature = "rustls")))] type SslConnector = (); /// Manages http client network connectivity @@ -58,8 +62,8 @@ pub struct Connector { _t: PhantomData, } -trait Io: AsyncRead + AsyncWrite {} -impl Io for T {} +trait Io: AsyncRead + AsyncWrite + Unpin {} +impl Io for T {} impl Connector<(), ()> { #[allow(clippy::new_ret_no_self)] @@ -72,9 +76,9 @@ impl Connector<(), ()> { TcpStream, > { let ssl = { - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] { - use openssl::ssl::SslMethod; + use open_ssl::ssl::SslMethod; let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap(); let _ = ssl @@ -82,7 +86,7 @@ impl Connector<(), ()> { .map_err(|e| error!("Can not set alpn protocol: {:?}", e)); SslConnector::Openssl(ssl.build()) } - #[cfg(all(not(feature = "ssl"), feature = "rust-tls"))] + #[cfg(all(not(feature = "openssl"), feature = "rustls"))] { let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let mut config = ClientConfig::new(); @@ -92,7 +96,7 @@ impl Connector<(), ()> { .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); SslConnector::Rustls(Arc::new(config)) } - #[cfg(not(any(feature = "ssl", feature = "rust-tls")))] + #[cfg(not(any(feature = "openssl", feature = "rustls")))] {} }; @@ -113,7 +117,7 @@ impl Connector { /// Use custom connector. pub fn connector(self, connector: T1) -> Connector where - U1: AsyncRead + AsyncWrite + fmt::Debug, + U1: AsyncRead + AsyncWrite + Unpin + fmt::Debug, T1: Service< Request = TcpConnect, Response = TcpConnection, @@ -135,7 +139,7 @@ impl Connector { impl Connector where - U: AsyncRead + AsyncWrite + fmt::Debug + 'static, + U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, T: Service< Request = TcpConnect, Response = TcpConnection, @@ -150,14 +154,14 @@ where self } - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] /// Use custom `SslConnector` instance. pub fn ssl(mut self, connector: OpensslConnector) -> Self { self.ssl = SslConnector::Openssl(connector); self } - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] pub fn rustls(mut self, connector: Arc) -> Self { self.ssl = SslConnector::Rustls(connector); self @@ -212,8 +216,8 @@ where pub fn finish( self, ) -> impl Service - + Clone { - #[cfg(not(any(feature = "ssl", feature = "rust-tls")))] + + Clone { + #[cfg(not(any(feature = "openssl", feature = "rustls")))] { let connector = TimeoutService::new( self.timeout, @@ -238,32 +242,32 @@ where ), } } - #[cfg(any(feature = "ssl", feature = "rust-tls"))] + #[cfg(any(feature = "openssl", feature = "rustls"))] { const H2: &[u8] = b"h2"; - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] use actix_connect::ssl::OpensslConnector; - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] use actix_connect::ssl::RustlsConnector; - use actix_service::boxed::service; - #[cfg(feature = "rust-tls")] - use rustls::Session; + use actix_service::{boxed::service, pipeline}; + #[cfg(feature = "rustls")] + use rust_tls::Session; let ssl_service = TimeoutService::new( self.timeout, - apply_fn(self.connector.clone(), |msg: Connect, srv| { - srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) - }) - .map_err(ConnectError::from) + pipeline( + apply_fn(self.connector.clone(), |msg: Connect, srv| { + srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) + }) + .map_err(ConnectError::from), + ) .and_then(match self.ssl { - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] SslConnector::Openssl(ssl) => service( OpensslConnector::service(ssl) - .map_err(ConnectError::from) .map(|stream| { let sock = stream.into_parts().0; let h2 = sock - .get_ref() .ssl() .selected_alpn_protocol() .map(|protos| protos.windows(2).any(|w| w == H2)) @@ -273,9 +277,10 @@ where } else { (Box::new(sock) as Box, Protocol::Http1) } - }), + }) + .map_err(ConnectError::from), ), - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] SslConnector::Rustls(ssl) => service( RustlsConnector::service(ssl) .map_err(ConnectError::from) @@ -303,7 +308,7 @@ where let tcp_service = TimeoutService::new( self.timeout, - apply_fn(self.connector.clone(), |msg: Connect, srv| { + apply_fn(self.connector, |msg: Connect, srv| { srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) }) .map_err(ConnectError::from) @@ -334,19 +339,20 @@ where } } -#[cfg(not(any(feature = "ssl", feature = "rust-tls")))] +#[cfg(not(any(feature = "openssl", feature = "rustls")))] mod connect_impl { - use futures::future::{err, Either, FutureResult}; - use futures::Poll; + use std::task::{Context, Poll}; + + use futures::future::{err, Either, Ready}; + use futures::ready; use super::*; use crate::client::connection::IoConnection; pub(crate) struct InnerConnector where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { pub(crate) tcp_pool: ConnectionPool, @@ -354,9 +360,8 @@ mod connect_impl { impl Clone for InnerConnector where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { fn clone(&self) -> Self { @@ -368,9 +373,8 @@ mod connect_impl { impl Service for InnerConnector where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { type Request = Connect; @@ -378,38 +382,38 @@ mod connect_impl { type Error = ConnectError; type Future = Either< as Service>::Future, - FutureResult, ConnectError>, + Ready, ConnectError>>, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.tcp_pool.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.tcp_pool.poll_ready(cx) } fn call(&mut self, req: Connect) -> Self::Future { match req.uri.scheme_str() { Some("https") | Some("wss") => { - Either::B(err(ConnectError::SslIsNotSupported)) + Either::Right(err(ConnectError::SslIsNotSupported)) } - _ => Either::A(self.tcp_pool.call(req)), + _ => Either::Left(self.tcp_pool.call(req)), } } } } -#[cfg(any(feature = "ssl", feature = "rust-tls"))] +#[cfg(any(feature = "openssl", feature = "rustls"))] mod connect_impl { use std::marker::PhantomData; - use futures::future::{Either, FutureResult}; - use futures::{Async, Future, Poll}; + use futures::future::Either; + use futures::ready; use super::*; use crate::client::connection::EitherConnection; pub(crate) struct InnerConnector where - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, T1: Service, T2: Service, { @@ -419,13 +423,11 @@ mod connect_impl { impl Clone for InnerConnector where - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, T1: Service - + Clone + 'static, T2: Service - + Clone + 'static, { fn clone(&self) -> Self { @@ -438,53 +440,47 @@ mod connect_impl { impl Service for InnerConnector where - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, T1: Service - + Clone + 'static, T2: Service - + Clone + 'static, { type Request = Connect; type Response = EitherConnection; type Error = ConnectError; type Future = Either< - FutureResult, - Either< - InnerConnectorResponseA, - InnerConnectorResponseB, - >, + InnerConnectorResponseA, + InnerConnectorResponseB, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.tcp_pool.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.tcp_pool.poll_ready(cx) } fn call(&mut self, req: Connect) -> Self::Future { match req.uri.scheme_str() { - Some("https") | Some("wss") => { - Either::B(Either::B(InnerConnectorResponseB { - fut: self.ssl_pool.call(req), - _t: PhantomData, - })) - } - _ => Either::B(Either::A(InnerConnectorResponseA { + Some("https") | Some("wss") => Either::Right(InnerConnectorResponseB { + fut: self.ssl_pool.call(req), + _t: PhantomData, + }), + _ => Either::Left(InnerConnectorResponseA { fut: self.tcp_pool.call(req), _t: PhantomData, - })), + }), } } } + #[pin_project::pin_project] pub(crate) struct InnerConnectorResponseA where - Io1: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { + #[pin] fut: as Service>::Future, _t: PhantomData, } @@ -492,29 +488,28 @@ mod connect_impl { impl Future for InnerConnectorResponseA where T: Service - + Clone + 'static, - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, { - type Item = EitherConnection; - type Error = ConnectError; + type Output = Result, ConnectError>; - fn poll(&mut self) -> Poll { - match self.fut.poll()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(res) => Ok(Async::Ready(EitherConnection::A(res))), - } + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Poll::Ready( + ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) + .map(|res| EitherConnection::A(res)), + ) } } + #[pin_project::pin_project] pub(crate) struct InnerConnectorResponseB where - Io2: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { + #[pin] fut: as Service>::Future, _t: PhantomData, } @@ -522,19 +517,17 @@ mod connect_impl { impl Future for InnerConnectorResponseB where T: Service - + Clone + 'static, - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, { - type Item = EitherConnection; - type Error = ConnectError; + type Output = Result, ConnectError>; - fn poll(&mut self) -> Poll { - match self.fut.poll()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(res) => Ok(Async::Ready(EitherConnection::B(res))), - } + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Poll::Ready( + ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) + .map(|res| EitherConnection::B(res)), + ) } } } diff --git a/actix-http/src/client/error.rs b/actix-http/src/client/error.rs index 0ac5f30f5..75f7935f2 100644 --- a/actix-http/src/client/error.rs +++ b/actix-http/src/client/error.rs @@ -3,8 +3,8 @@ use std::io; use derive_more::{Display, From}; use trust_dns_resolver::error::ResolveError; -#[cfg(feature = "ssl")] -use openssl::ssl::{Error as SslError, HandshakeError}; +#[cfg(feature = "openssl")] +use open_ssl::ssl::{Error as SslError, HandshakeError}; use crate::error::{Error, ParseError, ResponseError}; use crate::http::Error as HttpError; @@ -18,7 +18,7 @@ pub enum ConnectError { SslIsNotSupported, /// SSL error - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] #[display(fmt = "{}", _0)] SslError(SslError), @@ -63,7 +63,7 @@ impl From for ConnectError { } } -#[cfg(feature = "ssl")] +#[cfg(feature = "openssl")] impl From> for ConnectError { fn from(err: HandshakeError) -> ConnectError { match err { diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs index b078c6a67..041a36856 100644 --- a/actix-http/src/client/h1proto.rs +++ b/actix-http/src/client/h1proto.rs @@ -1,10 +1,13 @@ +use std::future::Future; use std::io::Write; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{io, time}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use bytes::{BufMut, Bytes, BytesMut}; -use futures::future::{ok, Either}; -use futures::{Async, Future, Poll, Sink, Stream}; +use futures::future::{ok, poll_fn, Either}; +use futures::{Sink, SinkExt, Stream, StreamExt}; use crate::error::PayloadError; use crate::h1; @@ -18,15 +21,15 @@ use super::error::{ConnectError, SendRequestError}; use super::pool::Acquired; use crate::body::{BodySize, MessageBody}; -pub(crate) fn send_request( +pub(crate) async fn send_request( io: T, mut head: RequestHeadType, body: B, created: time::Instant, pool: Option>, -) -> impl Future +) -> Result<(ResponseHead, Payload), SendRequestError> where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, B: MessageBody, { // set request host header @@ -62,68 +65,99 @@ where io: Some(io), }; - let len = body.size(); - // create Framed and send request - Framed::new(io, h1::ClientCodec::default()) - .send((head, len).into()) - .from_err() - // send request body - .and_then(move |framed| match body.size() { - BodySize::None | BodySize::Empty | BodySize::Sized(0) => { - Either::A(ok(framed)) - } - _ => Either::B(SendBody::new(body, framed)), - }) - // read response and init read body - .and_then(|framed| { - framed - .into_future() - .map_err(|(e, _)| SendRequestError::from(e)) - .and_then(|(item, framed)| { - if let Some(res) = item { - match framed.get_codec().message_type() { - h1::MessageType::None => { - let force_close = !framed.get_codec().keepalive(); - release_connection(framed, force_close); - Ok((res, Payload::None)) - } - _ => { - let pl: PayloadStream = Box::new(PlStream::new(framed)); - Ok((res, pl.into())) - } - } - } else { - Err(ConnectError::Disconnected.into()) - } - }) - }) + let mut framed = Framed::new(io, h1::ClientCodec::default()); + framed.send((head, body.size()).into()).await?; + + // send request body + match body.size() { + BodySize::None | BodySize::Empty | BodySize::Sized(0) => (), + _ => send_body(body, &mut framed).await?, + }; + + // read response and init read body + let res = framed.into_future().await; + let (head, framed) = if let (Some(result), framed) = res { + let item = result.map_err(SendRequestError::from)?; + (item, framed) + } else { + return Err(SendRequestError::from(ConnectError::Disconnected)); + }; + + match framed.get_codec().message_type() { + h1::MessageType::None => { + let force_close = !framed.get_codec().keepalive(); + release_connection(framed, force_close); + Ok((head, Payload::None)) + } + _ => { + let pl: PayloadStream = PlStream::new(framed).boxed_local(); + Ok((head, pl.into())) + } + } } -pub(crate) fn open_tunnel( +pub(crate) async fn open_tunnel( io: T, head: RequestHeadType, -) -> impl Future), Error = SendRequestError> +) -> Result<(ResponseHead, Framed), SendRequestError> where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, { // create Framed and send request - Framed::new(io, h1::ClientCodec::default()) - .send((head, BodySize::None).into()) - .from_err() - // read response - .and_then(|framed| { - framed - .into_future() - .map_err(|(e, _)| SendRequestError::from(e)) - .and_then(|(head, framed)| { - if let Some(head) = head { - Ok((head, framed)) + let mut framed = Framed::new(io, h1::ClientCodec::default()); + framed.send((head, BodySize::None).into()).await?; + + // read response + if let (Some(result), framed) = framed.into_future().await { + let head = result.map_err(SendRequestError::from)?; + Ok((head, framed)) + } else { + Err(SendRequestError::from(ConnectError::Disconnected)) + } +} + +/// send request body to the peer +pub(crate) async fn send_body( + mut body: B, + framed: &mut Framed, +) -> Result<(), SendRequestError> +where + I: ConnectionLifetime, + B: MessageBody, +{ + let mut eof = false; + while !eof { + while !eof && !framed.is_write_buf_full() { + match poll_fn(|cx| body.poll_next(cx)).await { + Some(result) => { + framed.write(h1::Message::Chunk(Some(result?)))?; + } + None => { + eof = true; + framed.write(h1::Message::Chunk(None))?; + } + } + } + + if !framed.is_write_buf_empty() { + poll_fn(|cx| match framed.flush(cx) { + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => { + if !framed.is_write_buf_full() { + Poll::Ready(Ok(())) } else { - Err(SendRequestError::from(ConnectError::Disconnected)) + Poll::Pending } - }) - }) + } + }) + .await?; + } + } + + SinkExt::flush(framed).await?; + Ok(()) } #[doc(hidden)] @@ -134,7 +168,10 @@ pub struct H1Connection { pool: Option>, } -impl ConnectionLifetime for H1Connection { +impl ConnectionLifetime for H1Connection +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ /// Close connection fn close(&mut self) { if let Some(mut pool) = self.pool.take() { @@ -162,98 +199,41 @@ impl ConnectionLifetime for H1Connection } } -impl io::Read for H1Connection { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.io.as_mut().unwrap().read(buf) +impl AsyncRead for H1Connection { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.io.as_ref().unwrap().prepare_uninitialized_buffer(buf) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf) } } -impl AsyncRead for H1Connection {} - -impl io::Write for H1Connection { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.io.as_mut().unwrap().write(buf) +impl AsyncWrite for H1Connection { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.io.as_mut().unwrap()).poll_write(cx, buf) } - fn flush(&mut self) -> io::Result<()> { - self.io.as_mut().unwrap().flush() + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.io.as_mut().unwrap()).poll_flush(cx) } -} -impl AsyncWrite for H1Connection { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.io.as_mut().unwrap().shutdown() - } -} - -/// Future responsible for sending request body to the peer -pub(crate) struct SendBody { - body: Option, - framed: Option>, - flushed: bool, -} - -impl SendBody -where - I: AsyncRead + AsyncWrite + 'static, - B: MessageBody, -{ - pub(crate) fn new(body: B, framed: Framed) -> Self { - SendBody { - body: Some(body), - framed: Some(framed), - flushed: true, - } - } -} - -impl Future for SendBody -where - I: ConnectionLifetime, - B: MessageBody, -{ - type Item = Framed; - type Error = SendRequestError; - - fn poll(&mut self) -> Poll { - let mut body_ready = true; - loop { - while body_ready - && self.body.is_some() - && !self.framed.as_ref().unwrap().is_write_buf_full() - { - match self.body.as_mut().unwrap().poll_next()? { - Async::Ready(item) => { - // check if body is done - if item.is_none() { - let _ = self.body.take(); - } - self.flushed = false; - self.framed - .as_mut() - .unwrap() - .force_send(h1::Message::Chunk(item))?; - break; - } - Async::NotReady => body_ready = false, - } - } - - if !self.flushed { - match self.framed.as_mut().unwrap().poll_complete()? { - Async::Ready(_) => { - self.flushed = true; - continue; - } - Async::NotReady => return Ok(Async::NotReady), - } - } - - if self.body.is_none() { - return Ok(Async::Ready(self.framed.take().unwrap())); - } - return Ok(Async::NotReady); - } + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx) } } @@ -270,23 +250,24 @@ impl PlStream { } impl Stream for PlStream { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - match self.framed.as_mut().unwrap().poll()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(Some(chunk)) => { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + + match this.framed.as_mut().unwrap().next_item(cx)? { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(chunk)) => { if let Some(chunk) = chunk { - Ok(Async::Ready(Some(chunk))) + Poll::Ready(Some(Ok(chunk))) } else { - let framed = self.framed.take().unwrap(); + let framed = this.framed.take().unwrap(); let force_close = !framed.get_codec().keepalive(); release_connection(framed, force_close); - Ok(Async::Ready(None)) + Poll::Ready(None) } } - Async::Ready(None) => Ok(Async::Ready(None)), + Poll::Ready(None) => Poll::Ready(None), } } } diff --git a/actix-http/src/client/h2proto.rs b/actix-http/src/client/h2proto.rs index 5744a1547..1647abf81 100644 --- a/actix-http/src/client/h2proto.rs +++ b/actix-http/src/client/h2proto.rs @@ -1,9 +1,11 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time; use actix_codec::{AsyncRead, AsyncWrite}; use bytes::Bytes; -use futures::future::{err, Either}; -use futures::{Async, Future, Poll}; +use futures::future::{err, poll_fn, Either}; use h2::{client::SendRequest, SendStream}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; use http::{request::Request, HttpTryFrom, Method, Version}; @@ -17,15 +19,15 @@ use super::connection::{ConnectionType, IoConnection}; use super::error::SendRequestError; use super::pool::Acquired; -pub(crate) fn send_request( - io: SendRequest, +pub(crate) async fn send_request( + mut io: SendRequest, head: RequestHeadType, body: B, created: time::Instant, pool: Option>, -) -> impl Future +) -> Result<(ResponseHead, Payload), SendRequestError> where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, B: MessageBody, { trace!("Sending client request: {:?} {:?}", head, body.size()); @@ -36,158 +38,140 @@ where _ => false, }; - io.ready() - .map_err(SendRequestError::from) - .and_then(move |mut io| { - let mut req = Request::new(()); - *req.uri_mut() = head.as_ref().uri.clone(); - *req.method_mut() = head.as_ref().method.clone(); - *req.version_mut() = Version::HTTP_2; + let mut req = Request::new(()); + *req.uri_mut() = head.as_ref().uri.clone(); + *req.method_mut() = head.as_ref().method.clone(); + *req.version_mut() = Version::HTTP_2; - let mut skip_len = true; - // let mut has_date = false; + let mut skip_len = true; + // let mut has_date = false; - // Content length - let _ = match length { - BodySize::None => None, - BodySize::Stream => { - skip_len = false; - None - } - BodySize::Empty => req - .headers_mut() - .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), - BodySize::Sized(len) => req.headers_mut().insert( - CONTENT_LENGTH, - HeaderValue::try_from(format!("{}", len)).unwrap(), - ), - BodySize::Sized64(len) => req.headers_mut().insert( - CONTENT_LENGTH, - HeaderValue::try_from(format!("{}", len)).unwrap(), - ), - }; + // Content length + let _ = match length { + BodySize::None => None, + BodySize::Stream => { + skip_len = false; + None + } + BodySize::Empty => req + .headers_mut() + .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), + BodySize::Sized(len) => req.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + BodySize::Sized64(len) => req.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + }; - // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. - let (head, extra_headers) = match head { - RequestHeadType::Owned(head) => { - (RequestHeadType::Owned(head), HeaderMap::new()) - } - RequestHeadType::Rc(head, extra_headers) => ( - RequestHeadType::Rc(head, None), - extra_headers.unwrap_or_else(HeaderMap::new), - ), - }; + // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. + let (head, extra_headers) = match head { + RequestHeadType::Owned(head) => (RequestHeadType::Owned(head), HeaderMap::new()), + RequestHeadType::Rc(head, extra_headers) => ( + RequestHeadType::Rc(head, None), + extra_headers.unwrap_or_else(HeaderMap::new), + ), + }; - // merging headers from head and extra headers. - let headers = head - .as_ref() - .headers - .iter() - .filter(|(name, _)| !extra_headers.contains_key(*name)) - .chain(extra_headers.iter()); + // merging headers from head and extra headers. + let headers = head + .as_ref() + .headers + .iter() + .filter(|(name, _)| !extra_headers.contains_key(*name)) + .chain(extra_headers.iter()); - // copy headers - for (key, value) in headers { - match *key { - CONNECTION | TRANSFER_ENCODING => continue, // http2 specific - CONTENT_LENGTH if skip_len => continue, - // DATE => has_date = true, - _ => (), - } - req.headers_mut().append(key, value.clone()); + // copy headers + for (key, value) in headers { + match *key { + CONNECTION | TRANSFER_ENCODING => continue, // http2 specific + CONTENT_LENGTH if skip_len => continue, + // DATE => has_date = true, + _ => (), + } + req.headers_mut().append(key, value.clone()); + } + + let res = poll_fn(|cx| io.poll_ready(cx)).await; + if let Err(e) = res { + release(io, pool, created, e.is_io()); + return Err(SendRequestError::from(e)); + } + + let resp = match io.send_request(req, eof) { + Ok((fut, send)) => { + release(io, pool, created, false); + + if !eof { + send_body(body, send).await?; } + fut.await.map_err(SendRequestError::from)? + } + Err(e) => { + release(io, pool, created, e.is_io()); + return Err(e.into()); + } + }; - match io.send_request(req, eof) { - Ok((res, send)) => { - release(io, pool, created, false); + let (parts, body) = resp.into_parts(); + let payload = if head_req { Payload::None } else { body.into() }; - if !eof { - Either::A(Either::B( - SendBody { - body, - send, - buf: None, - } - .and_then(move |_| res.map_err(SendRequestError::from)), - )) - } else { - Either::B(res.map_err(SendRequestError::from)) - } - } - Err(e) => { - release(io, pool, created, e.is_io()); - Either::A(Either::A(err(e.into()))) - } - } - }) - .and_then(move |resp| { - let (parts, body) = resp.into_parts(); - let payload = if head_req { Payload::None } else { body.into() }; - - let mut head = ResponseHead::new(parts.status); - head.version = parts.version; - head.headers = parts.headers.into(); - Ok((head, payload)) - }) - .from_err() + let mut head = ResponseHead::new(parts.status); + head.version = parts.version; + head.headers = parts.headers.into(); + Ok((head, payload)) } -struct SendBody { - body: B, - send: SendStream, - buf: Option, -} - -impl Future for SendBody { - type Item = (); - type Error = SendRequestError; - - fn poll(&mut self) -> Poll { - loop { - if self.buf.is_none() { - match self.body.poll_next() { - Ok(Async::Ready(Some(buf))) => { - self.send.reserve_capacity(buf.len()); - self.buf = Some(buf); - } - Ok(Async::Ready(None)) => { - if let Err(e) = self.send.send_data(Bytes::new(), true) { - return Err(e.into()); - } - self.send.reserve_capacity(0); - return Ok(Async::Ready(())); - } - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(e) => return Err(e.into()), +async fn send_body( + mut body: B, + mut send: SendStream, +) -> Result<(), SendRequestError> { + let mut buf = None; + loop { + if buf.is_none() { + match poll_fn(|cx| body.poll_next(cx)).await { + Some(Ok(b)) => { + send.reserve_capacity(b.len()); + buf = Some(b); } - } - - match self.send.poll_capacity() { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(None)) => return Ok(Async::Ready(())), - Ok(Async::Ready(Some(cap))) => { - let mut buf = self.buf.take().unwrap(); - let len = buf.len(); - let bytes = buf.split_to(std::cmp::min(cap, len)); - - if let Err(e) = self.send.send_data(bytes, false) { + Some(Err(e)) => return Err(e.into()), + None => { + if let Err(e) = send.send_data(Bytes::new(), true) { return Err(e.into()); - } else { - if !buf.is_empty() { - self.send.reserve_capacity(buf.len()); - self.buf = Some(buf); - } - continue; } + send.reserve_capacity(0); + return Ok(()); } - Err(e) => return Err(e.into()), } } + + match poll_fn(|cx| send.poll_capacity(cx)).await { + None => return Ok(()), + Some(Ok(cap)) => { + let b = buf.as_mut().unwrap(); + let len = b.len(); + let bytes = b.split_to(std::cmp::min(cap, len)); + + if let Err(e) = send.send_data(bytes, false) { + return Err(e.into()); + } else { + if !b.is_empty() { + send.reserve_capacity(b.len()); + } else { + buf = None; + } + continue; + } + } + Some(Err(e)) => return Err(e.into()), + } } } // release SendRequest object -fn release( +fn release( io: SendRequest, pool: Option>, created: time::Instant, diff --git a/actix-http/src/client/pool.rs b/actix-http/src/client/pool.rs index a3522ff8a..1952dca59 100644 --- a/actix-http/src/client/pool.rs +++ b/actix-http/src/client/pool.rs @@ -1,22 +1,23 @@ use std::cell::RefCell; use std::collections::VecDeque; +use std::future::Future; use std::io; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_service::Service; +use actix_utils::{oneshot, task::LocalWaker}; use bytes::Bytes; -use futures::future::{err, ok, Either, FutureResult}; -use futures::task::AtomicTask; -use futures::unsync::oneshot; -use futures::{Async, Future, Poll}; -use h2::client::{handshake, Handshake}; +use futures::future::{err, ok, poll_fn, Either, FutureExt, LocalBoxFuture, Ready}; +use h2::client::{handshake, Connection, SendRequest}; use hashbrown::HashMap; use http::uri::Authority; use indexmap::IndexSet; use slab::Slab; -use tokio_timer::{sleep, Delay}; +use tokio_timer::{delay_for, Delay}; use super::connection::{ConnectionType, IoConnection}; use super::error::ConnectError; @@ -41,16 +42,12 @@ impl From for Key { } /// Connections pool -pub(crate) struct ConnectionPool( - T, - Rc>>, -); +pub(crate) struct ConnectionPool(Rc>, Rc>>); impl ConnectionPool where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { pub(crate) fn new( @@ -61,7 +58,7 @@ where limit: usize, ) -> Self { ConnectionPool( - connector, + Rc::new(RefCell::new(connector)), Rc::new(RefCell::new(Inner { conn_lifetime, conn_keep_alive, @@ -71,7 +68,7 @@ where waiters: Slab::new(), waiters_queue: IndexSet::new(), available: HashMap::new(), - task: None, + waker: LocalWaker::new(), })), ) } @@ -79,8 +76,7 @@ where impl Clone for ConnectionPool where - T: Clone, - Io: AsyncRead + AsyncWrite + 'static, + Io: 'static, { fn clone(&self) -> Self { ConnectionPool(self.0.clone(), self.1.clone()) @@ -89,86 +85,116 @@ where impl Service for ConnectionPool where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { type Request = Connect; type Response = IoConnection; type Error = ConnectError; - type Future = Either< - FutureResult, - Either, OpenConnection>, - >; + type Future = LocalBoxFuture<'static, Result, ConnectError>>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.0.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.0.poll_ready(cx) } fn call(&mut self, req: Connect) -> Self::Future { - let key = if let Some(authority) = req.uri.authority_part() { - authority.clone().into() - } else { - return Either::A(err(ConnectError::Unresolverd)); + // start support future + tokio_executor::current_thread::spawn(ConnectorPoolSupport { + connector: self.0.clone(), + inner: self.1.clone(), + }); + + let mut connector = self.0.clone(); + let inner = self.1.clone(); + + let fut = async move { + let key = if let Some(authority) = req.uri.authority_part() { + authority.clone().into() + } else { + return Err(ConnectError::Unresolverd); + }; + + // acquire connection + match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await { + Acquire::Acquired(io, created) => { + // use existing connection + return Ok(IoConnection::new( + io, + created, + Some(Acquired(key, Some(inner))), + )); + } + Acquire::Available => { + // open tcp connection + let (io, proto) = connector.call(req).await?; + + let guard = OpenGuard::new(key, inner); + + if proto == Protocol::Http1 { + Ok(IoConnection::new( + ConnectionType::H1(io), + Instant::now(), + Some(guard.consume()), + )) + } else { + let (snd, connection) = handshake(io).await?; + tokio_executor::current_thread::spawn(connection.map(|_| ())); + Ok(IoConnection::new( + ConnectionType::H2(snd), + Instant::now(), + Some(guard.consume()), + )) + } + } + _ => { + // connection is not available, wait + let (rx, token) = inner.borrow_mut().wait_for(req); + + let guard = WaiterGuard::new(key, token, inner); + let res = match rx.await { + Err(_) => Err(ConnectError::Disconnected), + Ok(res) => res, + }; + guard.consume(); + res + } + } }; - // acquire connection - match self.1.as_ref().borrow_mut().acquire(&key) { - Acquire::Acquired(io, created) => { - // use existing connection - return Either::A(ok(IoConnection::new( - io, - created, - Some(Acquired(key, Some(self.1.clone()))), - ))); - } - Acquire::Available => { - // open new connection - return Either::B(Either::B(OpenConnection::new( - key, - self.1.clone(), - self.0.call(req), - ))); - } - _ => (), - } - - // connection is not available, wait - let (rx, token, support) = self.1.as_ref().borrow_mut().wait_for(req); - - // start support future - if !support { - self.1.as_ref().borrow_mut().task = Some(AtomicTask::new()); - tokio_current_thread::spawn(ConnectorPoolSupport { - connector: self.0.clone(), - inner: self.1.clone(), - }) - } - - Either::B(Either::A(WaitForConnection { - rx, - key, - token, - inner: Some(self.1.clone()), - })) + fut.boxed_local() } } -#[doc(hidden)] -pub struct WaitForConnection +struct WaiterGuard where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { key: Key, token: usize, - rx: oneshot::Receiver, ConnectError>>, inner: Option>>>, } -impl Drop for WaitForConnection +impl WaiterGuard where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn new(key: Key, token: usize, inner: Rc>>) -> Self { + Self { + key, + token, + inner: Some(inner), + } + } + + fn consume(mut self) { + let _ = self.inner.take(); + } +} + +impl Drop for WaiterGuard +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fn drop(&mut self) { if let Some(i) = self.inner.take() { @@ -179,113 +205,43 @@ where } } -impl Future for WaitForConnection +struct OpenGuard where - Io: AsyncRead + AsyncWrite, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { - type Item = IoConnection; - type Error = ConnectError; - - fn poll(&mut self) -> Poll { - match self.rx.poll() { - Ok(Async::Ready(item)) => match item { - Err(err) => Err(err), - Ok(conn) => { - let _ = self.inner.take(); - Ok(Async::Ready(conn)) - } - }, - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(_) => { - let _ = self.inner.take(); - Err(ConnectError::Disconnected) - } - } - } -} - -#[doc(hidden)] -pub struct OpenConnection -where - Io: AsyncRead + AsyncWrite + 'static, -{ - fut: F, key: Key, - h2: Option>, inner: Option>>>, } -impl OpenConnection +impl OpenGuard where - F: Future, - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { - fn new(key: Key, inner: Rc>>, fut: F) -> Self { - OpenConnection { + fn new(key: Key, inner: Rc>>) -> Self { + Self { key, - fut, inner: Some(inner), - h2: None, } } + + fn consume(mut self) -> Acquired { + Acquired(self.key.clone(), self.inner.take()) + } } -impl Drop for OpenConnection +impl Drop for OpenGuard where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fn drop(&mut self) { - if let Some(inner) = self.inner.take() { - let mut inner = inner.as_ref().borrow_mut(); + if let Some(i) = self.inner.take() { + let mut inner = i.as_ref().borrow_mut(); inner.release(); inner.check_availibility(); } } } -impl Future for OpenConnection -where - F: Future, - Io: AsyncRead + AsyncWrite, -{ - type Item = IoConnection; - type Error = ConnectError; - - fn poll(&mut self) -> Poll { - if let Some(ref mut h2) = self.h2 { - return match h2.poll() { - Ok(Async::Ready((snd, connection))) => { - tokio_current_thread::spawn(connection.map_err(|_| ())); - Ok(Async::Ready(IoConnection::new( - ConnectionType::H2(snd), - Instant::now(), - Some(Acquired(self.key.clone(), self.inner.take())), - ))) - } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => Err(e.into()), - }; - } - - match self.fut.poll() { - Err(err) => Err(err), - Ok(Async::Ready((io, proto))) => { - if proto == Protocol::Http1 { - Ok(Async::Ready(IoConnection::new( - ConnectionType::H1(io), - Instant::now(), - Some(Acquired(self.key.clone(), self.inner.take())), - ))) - } else { - self.h2 = Some(handshake(io)); - self.poll() - } - } - Ok(Async::NotReady) => Ok(Async::NotReady), - } - } -} - enum Acquire { Acquired(ConnectionType, Instant), Available, @@ -312,7 +268,7 @@ pub(crate) struct Inner { )>, >, waiters_queue: IndexSet<(Key, usize)>, - task: Option, + waker: LocalWaker, } impl Inner { @@ -332,7 +288,7 @@ impl Inner { impl Inner where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { /// connection is not available, wait fn wait_for( @@ -341,7 +297,6 @@ where ) -> ( oneshot::Receiver, ConnectError>>, usize, - bool, ) { let (tx, rx) = oneshot::channel(); @@ -351,10 +306,10 @@ where entry.insert(Some((connect, tx))); assert!(self.waiters_queue.insert((key, token))); - (rx, token, self.task.is_some()) + (rx, token) } - fn acquire(&mut self, key: &Key) -> Acquire { + fn acquire(&mut self, key: &Key, cx: &mut Context) -> Acquire { // check limits if self.limit > 0 && self.acquired >= self.limit { return Acquire::NotAvailable; @@ -373,7 +328,7 @@ where { if let Some(timeout) = self.disconnect_timeout { if let ConnectionType::H1(io) = conn.io { - tokio_current_thread::spawn(CloseConnection::new( + tokio_executor::current_thread::spawn(CloseConnection::new( io, timeout, )) } @@ -382,19 +337,19 @@ where let mut io = conn.io; let mut buf = [0; 2]; if let ConnectionType::H1(ref mut s) = io { - match s.read(&mut buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - Ok(n) if n > 0 => { + match Pin::new(s).poll_read(cx, &mut buf) { + Poll::Pending => (), + Poll::Ready(Ok(n)) if n > 0 => { if let Some(timeout) = self.disconnect_timeout { if let ConnectionType::H1(io) = io { - tokio_current_thread::spawn( + tokio_executor::current_thread::spawn( CloseConnection::new(io, timeout), ) } } continue; } - Ok(_) | Err(_) => continue, + _ => continue, } } return Acquire::Acquired(io, conn.created); @@ -421,7 +376,7 @@ where self.acquired -= 1; if let Some(timeout) = self.disconnect_timeout { if let ConnectionType::H1(io) = io { - tokio_current_thread::spawn(CloseConnection::new(io, timeout)) + tokio_executor::current_thread::spawn(CloseConnection::new(io, timeout)) } } self.check_availibility(); @@ -429,9 +384,7 @@ where fn check_availibility(&self) { if !self.waiters_queue.is_empty() && self.acquired < self.limit { - if let Some(t) = self.task.as_ref() { - t.notify() - } + self.waker.wake(); } } } @@ -443,29 +396,30 @@ struct CloseConnection { impl CloseConnection where - T: AsyncWrite, + T: AsyncWrite + Unpin, { fn new(io: T, timeout: Duration) -> Self { CloseConnection { io, - timeout: sleep(timeout), + timeout: delay_for(timeout), } } } impl Future for CloseConnection where - T: AsyncWrite, + T: AsyncWrite + Unpin, { - type Item = (); - type Error = (); + type Output = (); - fn poll(&mut self) -> Poll<(), ()> { - match self.timeout.poll() { - Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), - Ok(Async::NotReady) => match self.io.shutdown() { - Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), - Ok(Async::NotReady) => Ok(Async::NotReady), + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> { + let this = self.get_mut(); + + match Pin::new(&mut this.timeout).poll(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => match Pin::new(&mut this.io).poll_shutdown(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, }, } } @@ -473,7 +427,7 @@ where struct ConnectorPoolSupport where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { connector: T, inner: Rc>>, @@ -481,16 +435,17 @@ where impl Future for ConnectorPoolSupport where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service, T::Future: 'static, { - type Item = (); - type Error = (); + type Output = (); - fn poll(&mut self) -> Poll { - let mut inner = self.inner.as_ref().borrow_mut(); - inner.task.as_ref().unwrap().register(); + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = unsafe { self.get_unchecked_mut() }; + + let mut inner = this.inner.as_ref().borrow_mut(); + inner.waker.register(cx.waker()); // check waiters loop { @@ -505,14 +460,14 @@ where continue; } - match inner.acquire(&key) { + match inner.acquire(&key, cx) { Acquire::NotAvailable => break, Acquire::Acquired(io, created) => { let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1; if let Err(conn) = tx.send(Ok(IoConnection::new( io, created, - Some(Acquired(key.clone(), Some(self.inner.clone()))), + Some(Acquired(key.clone(), Some(this.inner.clone()))), ))) { let (io, created) = conn.unwrap().into_inner(); inner.release_conn(&key, io, created); @@ -524,33 +479,38 @@ where OpenWaitingConnection::spawn( key.clone(), tx, - self.inner.clone(), - self.connector.call(connect), + this.inner.clone(), + this.connector.call(connect), ); } } let _ = inner.waiters_queue.swap_remove_index(0); } - Ok(Async::NotReady) + Poll::Pending } } struct OpenWaitingConnection where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fut: F, key: Key, - h2: Option>, + h2: Option< + LocalBoxFuture< + 'static, + Result<(SendRequest, Connection), h2::Error>, + >, + >, rx: Option, ConnectError>>>, inner: Option>>>, } impl OpenWaitingConnection where - F: Future + 'static, - Io: AsyncRead + AsyncWrite + 'static, + F: Future> + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fn spawn( key: Key, @@ -558,7 +518,7 @@ where inner: Rc>>, fut: F, ) { - tokio_current_thread::spawn(OpenWaitingConnection { + tokio_executor::current_thread::spawn(OpenWaitingConnection { key, fut, h2: None, @@ -570,7 +530,7 @@ where impl Drop for OpenWaitingConnection where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fn drop(&mut self) { if let Some(inner) = self.inner.take() { @@ -583,59 +543,60 @@ where impl Future for OpenWaitingConnection where - F: Future, - Io: AsyncRead + AsyncWrite, + F: Future>, + Io: AsyncRead + AsyncWrite + Unpin, { - type Item = (); - type Error = (); + type Output = (); - fn poll(&mut self) -> Poll { - if let Some(ref mut h2) = self.h2 { - return match h2.poll() { - Ok(Async::Ready((snd, connection))) => { - tokio_current_thread::spawn(connection.map_err(|_| ())); - let rx = self.rx.take().unwrap(); + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = unsafe { self.get_unchecked_mut() }; + + if let Some(ref mut h2) = this.h2 { + return match Pin::new(h2).poll(cx) { + Poll::Ready(Ok((snd, connection))) => { + tokio_executor::current_thread::spawn(connection.map(|_| ())); + let rx = this.rx.take().unwrap(); let _ = rx.send(Ok(IoConnection::new( ConnectionType::H2(snd), Instant::now(), - Some(Acquired(self.key.clone(), self.inner.take())), + Some(Acquired(this.key.clone(), this.inner.take())), ))); - Ok(Async::Ready(())) + Poll::Ready(()) } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => { - let _ = self.inner.take(); - if let Some(rx) = self.rx.take() { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { + let _ = this.inner.take(); + if let Some(rx) = this.rx.take() { let _ = rx.send(Err(ConnectError::H2(err))); } - Err(()) + Poll::Ready(()) } }; } - match self.fut.poll() { - Err(err) => { - let _ = self.inner.take(); - if let Some(rx) = self.rx.take() { + match unsafe { Pin::new_unchecked(&mut this.fut) }.poll(cx) { + Poll::Ready(Err(err)) => { + let _ = this.inner.take(); + if let Some(rx) = this.rx.take() { let _ = rx.send(Err(err)); } - Err(()) + Poll::Ready(()) } - Ok(Async::Ready((io, proto))) => { + Poll::Ready(Ok((io, proto))) => { if proto == Protocol::Http1 { - let rx = self.rx.take().unwrap(); + let rx = this.rx.take().unwrap(); let _ = rx.send(Ok(IoConnection::new( ConnectionType::H1(io), Instant::now(), - Some(Acquired(self.key.clone(), self.inner.take())), + Some(Acquired(this.key.clone(), this.inner.take())), ))); - Ok(Async::Ready(())) + Poll::Ready(()) } else { - self.h2 = Some(handshake(io)); - self.poll() + this.h2 = Some(handshake(io).boxed_local()); + unsafe { Pin::new_unchecked(this) }.poll(cx) } } - Ok(Async::NotReady) => Ok(Async::NotReady), + Poll::Pending => Poll::Pending, } } } @@ -644,7 +605,7 @@ pub(crate) struct Acquired(Key, Option>>>); impl Acquired where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, { pub(crate) fn close(&mut self, conn: IoConnection) { if let Some(inner) = self.1.take() { diff --git a/actix-http/src/cloneable.rs b/actix-http/src/cloneable.rs index ffc1d0611..18869c66d 100644 --- a/actix-http/src/cloneable.rs +++ b/actix-http/src/cloneable.rs @@ -1,8 +1,8 @@ use std::cell::UnsafeCell; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_service::Service; -use futures::Poll; #[doc(hidden)] /// Service that allows to turn non-clone service to a service with `Clone` impl @@ -32,8 +32,8 @@ where type Error = T::Error; type Future = T::Future; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - unsafe { &mut *self.0.as_ref().get() }.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + unsafe { &mut *self.0.as_ref().get() }.poll_ready(cx) } fn call(&mut self, req: T::Request) -> Self::Future { diff --git a/actix-http/src/config.rs b/actix-http/src/config.rs index bdfecef30..488e4d98a 100644 --- a/actix-http/src/config.rs +++ b/actix-http/src/config.rs @@ -5,9 +5,9 @@ use std::rc::Rc; use std::time::{Duration, Instant}; use bytes::BytesMut; -use futures::{future, Future}; +use futures::{future, Future, FutureExt}; use time; -use tokio_timer::{sleep, Delay}; +use tokio_timer::{delay, delay_for, Delay}; // "Sun, 06 Nov 1994 08:49:37 GMT".len() const DATE_VALUE_LENGTH: usize = 29; @@ -104,10 +104,10 @@ impl ServiceConfig { #[inline] /// Client timeout for first request. pub fn client_timer(&self) -> Option { - let delay = self.0.client_timeout; - if delay != 0 { - Some(Delay::new( - self.0.timer.now() + Duration::from_millis(delay), + let delay_time = self.0.client_timeout; + if delay_time != 0 { + Some(delay( + self.0.timer.now() + Duration::from_millis(delay_time), )) } else { None @@ -138,7 +138,7 @@ impl ServiceConfig { /// Return keep-alive timer delay is configured. pub fn keep_alive_timer(&self) -> Option { if let Some(ka) = self.0.keep_alive { - Some(Delay::new(self.0.timer.now() + ka)) + Some(delay(self.0.timer.now() + ka)) } else { None } @@ -242,12 +242,12 @@ impl DateService { // periodic date update let s = self.clone(); - tokio_current_thread::spawn(sleep(Duration::from_millis(500)).then( - move |_| { + tokio_executor::current_thread::spawn( + delay_for(Duration::from_millis(500)).then(move |_| { s.0.reset(); - future::ok(()) - }, - )); + future::ready(()) + }), + ); } } @@ -277,7 +277,7 @@ mod tests { fn test_date() { let mut rt = System::new("test"); - let _ = rt.block_on(future::lazy(|| { + let _ = rt.block_on(future::lazy(|_| { let settings = ServiceConfig::new(KeepAlive::Os, 0, 0); let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); settings.set_date(&mut buf1); diff --git a/actix-http/src/cookie/secure/key.rs b/actix-http/src/cookie/secure/key.rs index 39575c93f..95058ed81 100644 --- a/actix-http/src/cookie/secure/key.rs +++ b/actix-http/src/cookie/secure/key.rs @@ -1,13 +1,12 @@ -use ring::digest::{Algorithm, SHA256}; -use ring::hkdf::expand; -use ring::hmac::SigningKey; +use ring::hkdf::{Algorithm, KeyType, Prk, HKDF_SHA256}; +use ring::hmac; use ring::rand::{SecureRandom, SystemRandom}; use super::private::KEY_LEN as PRIVATE_KEY_LEN; use super::signed::KEY_LEN as SIGNED_KEY_LEN; -static HKDF_DIGEST: &Algorithm = &SHA256; -const KEYS_INFO: &str = "COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM"; +static HKDF_DIGEST: Algorithm = HKDF_SHA256; +const KEYS_INFO: &[&[u8]] = &[b"COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM"]; /// A cryptographic master key for use with `Signed` and/or `Private` jars. /// @@ -25,6 +24,13 @@ pub struct Key { encryption_key: [u8; PRIVATE_KEY_LEN], } +impl KeyType for &Key { + #[inline] + fn len(&self) -> usize { + SIGNED_KEY_LEN + PRIVATE_KEY_LEN + } +} + impl Key { /// Derives new signing/encryption keys from a master key. /// @@ -56,21 +62,26 @@ impl Key { ); } - // Expand the user's key into two. - let prk = SigningKey::new(HKDF_DIGEST, key); + // An empty `Key` structure; will be filled in with HKDF derived keys. + let mut output_key = Key { + signing_key: [0; SIGNED_KEY_LEN], + encryption_key: [0; PRIVATE_KEY_LEN], + }; + + // Expand the master key into two HKDF generated keys. let mut both_keys = [0; SIGNED_KEY_LEN + PRIVATE_KEY_LEN]; - expand(&prk, KEYS_INFO.as_bytes(), &mut both_keys); + let prk = Prk::new_less_safe(HKDF_DIGEST, key); + let okm = prk.expand(KEYS_INFO, &output_key).expect("okm expand"); + okm.fill(&mut both_keys).expect("fill keys"); - // Copy the keys into their respective arrays. - let mut signing_key = [0; SIGNED_KEY_LEN]; - let mut encryption_key = [0; PRIVATE_KEY_LEN]; - signing_key.copy_from_slice(&both_keys[..SIGNED_KEY_LEN]); - encryption_key.copy_from_slice(&both_keys[SIGNED_KEY_LEN..]); - - Key { - signing_key, - encryption_key, - } + // Copy the key parts into their respective fields. + output_key + .signing_key + .copy_from_slice(&both_keys[..SIGNED_KEY_LEN]); + output_key + .encryption_key + .copy_from_slice(&both_keys[SIGNED_KEY_LEN..]); + output_key } /// Generates signing/encryption keys from a secure, random source. Keys are diff --git a/actix-http/src/cookie/secure/private.rs b/actix-http/src/cookie/secure/private.rs index eb8e9beb1..6c16e94e8 100644 --- a/actix-http/src/cookie/secure/private.rs +++ b/actix-http/src/cookie/secure/private.rs @@ -1,8 +1,8 @@ use std::str; use log::warn; -use ring::aead::{open_in_place, seal_in_place, Aad, Algorithm, Nonce, AES_256_GCM}; -use ring::aead::{OpeningKey, SealingKey}; +use ring::aead::{Aad, Algorithm, Nonce, AES_256_GCM}; +use ring::aead::{LessSafeKey, UnboundKey}; use ring::rand::{SecureRandom, SystemRandom}; use super::Key; @@ -10,7 +10,7 @@ use crate::cookie::{Cookie, CookieJar}; // Keep these in sync, and keep the key len synced with the `private` docs as // well as the `KEYS_INFO` const in secure::Key. -static ALGO: &Algorithm = &AES_256_GCM; +static ALGO: &'static Algorithm = &AES_256_GCM; const NONCE_LEN: usize = 12; pub const KEY_LEN: usize = 32; @@ -53,11 +53,14 @@ impl<'a> PrivateJar<'a> { } let ad = Aad::from(name.as_bytes()); - let key = OpeningKey::new(ALGO, &self.key).expect("opening key"); - let (nonce, sealed) = data.split_at_mut(NONCE_LEN); + let key = LessSafeKey::new( + UnboundKey::new(&ALGO, &self.key).expect("matching key length"), + ); + let (nonce, mut sealed) = data.split_at_mut(NONCE_LEN); let nonce = Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`"); - let unsealed = open_in_place(&key, nonce, ad, 0, sealed) + let unsealed = key + .open_in_place(nonce, ad, &mut sealed) .map_err(|_| "invalid key/nonce/value: bad seal")?; if let Ok(unsealed_utf8) = str::from_utf8(unsealed) { @@ -196,30 +199,33 @@ Please change it as soon as possible." fn encrypt_name_value(name: &[u8], value: &[u8], key: &[u8]) -> Vec { // Create the `SealingKey` structure. - let key = SealingKey::new(ALGO, key).expect("sealing key creation"); + let unbound = UnboundKey::new(&ALGO, key).expect("matching key length"); + let key = LessSafeKey::new(unbound); // Create a vec to hold the [nonce | cookie value | overhead]. - let overhead = ALGO.tag_len(); - let mut data = vec![0; NONCE_LEN + value.len() + overhead]; + let mut data = vec![0; NONCE_LEN + value.len() + ALGO.tag_len()]; // Randomly generate the nonce, then copy the cookie value as input. let (nonce, in_out) = data.split_at_mut(NONCE_LEN); + let (in_out, tag) = in_out.split_at_mut(value.len()); + in_out.copy_from_slice(value); + + // Randomly generate the nonce into the nonce piece. SystemRandom::new() .fill(nonce) .expect("couldn't random fill nonce"); - in_out[..value.len()].copy_from_slice(value); - let nonce = - Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`"); + let nonce = Nonce::try_assume_unique_for_key(nonce).expect("invalid `nonce` length"); // Use cookie's name as associated data to prevent value swapping. let ad = Aad::from(name); + let ad_tag = key + .seal_in_place_separate_tag(nonce, ad, in_out) + .expect("in-place seal"); - // Perform the actual sealing operation and get the output length. - let output_len = - seal_in_place(&key, nonce, ad, in_out, overhead).expect("in-place seal"); + // Copy the tag into the tag piece. + tag.copy_from_slice(ad_tag.as_ref()); // Remove the overhead and return the sealed content. - data.truncate(NONCE_LEN + output_len); data } diff --git a/actix-http/src/cookie/secure/signed.rs b/actix-http/src/cookie/secure/signed.rs index 36a277cd5..3fcd2cd84 100644 --- a/actix-http/src/cookie/secure/signed.rs +++ b/actix-http/src/cookie/secure/signed.rs @@ -1,12 +1,11 @@ -use ring::digest::{Algorithm, SHA256}; -use ring::hmac::{sign, verify_with_own_key as verify, SigningKey}; +use ring::hmac::{self, sign, verify}; use super::Key; use crate::cookie::{Cookie, CookieJar}; // Keep these in sync, and keep the key len synced with the `signed` docs as // well as the `KEYS_INFO` const in secure::Key. -static HMAC_DIGEST: &Algorithm = &SHA256; +static HMAC_DIGEST: hmac::Algorithm = hmac::HMAC_SHA256; const BASE64_DIGEST_LEN: usize = 44; pub const KEY_LEN: usize = 32; @@ -21,7 +20,7 @@ pub const KEY_LEN: usize = 32; /// This type is only available when the `secure` feature is enabled. pub struct SignedJar<'a> { parent: &'a mut CookieJar, - key: SigningKey, + key: hmac::Key, } impl<'a> SignedJar<'a> { @@ -32,7 +31,7 @@ impl<'a> SignedJar<'a> { pub fn new(parent: &'a mut CookieJar, key: &Key) -> SignedJar<'a> { SignedJar { parent, - key: SigningKey::new(HMAC_DIGEST, key.signing()), + key: hmac::Key::new(HMAC_DIGEST, key.signing()), } } diff --git a/actix-http/src/encoding/decoder.rs b/actix-http/src/encoding/decoder.rs index 4b56a1b62..1e51e8b56 100644 --- a/actix-http/src/encoding/decoder.rs +++ b/actix-http/src/encoding/decoder.rs @@ -1,4 +1,7 @@ +use std::future::Future; use std::io::{self, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_threadpool::{run, CpuFuture}; #[cfg(feature = "brotli")] @@ -6,7 +9,7 @@ use brotli2::write::BrotliDecoder; use bytes::Bytes; #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] use flate2::write::{GzDecoder, ZlibDecoder}; -use futures::{try_ready, Async, Future, Poll, Stream}; +use futures::{ready, Stream}; use super::Writer; use crate::error::PayloadError; @@ -18,12 +21,12 @@ pub struct Decoder { decoder: Option, stream: S, eof: bool, - fut: Option, ContentDecoder), io::Error>>, + fut: Option, ContentDecoder), io::Error>>>, } impl Decoder where - S: Stream, + S: Stream>, { /// Construct a decoder. #[inline] @@ -71,34 +74,41 @@ where impl Stream for Decoder where - S: Stream, + S: Stream> + Unpin, { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { loop { if let Some(ref mut fut) = self.fut { - let (chunk, decoder) = try_ready!(fut.poll()); + let (chunk, decoder) = match ready!(Pin::new(fut).poll(cx)) { + Ok(Ok(item)) => item, + Ok(Err(e)) => return Poll::Ready(Some(Err(e.into()))), + Err(e) => return Poll::Ready(Some(Err(e.into()))), + }; self.decoder = Some(decoder); self.fut.take(); if let Some(chunk) = chunk { - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } if self.eof { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } - match self.stream.poll()? { - Async::Ready(Some(chunk)) => { + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))), + Poll::Ready(Some(Ok(chunk))) => { if let Some(mut decoder) = self.decoder.take() { if chunk.len() < INPLACE { let chunk = decoder.feed_data(chunk)?; self.decoder = Some(decoder); if let Some(chunk) = chunk { - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } else { self.fut = Some(run(move || { @@ -108,21 +118,25 @@ where } continue; } else { - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } - Async::Ready(None) => { + Poll::Ready(None) => { self.eof = true; return if let Some(mut decoder) = self.decoder.take() { - Ok(Async::Ready(decoder.feed_eof()?)) + match decoder.feed_eof() { + Ok(Some(res)) => Poll::Ready(Some(Ok(res))), + Ok(None) => Poll::Ready(None), + Err(err) => Poll::Ready(Some(Err(err.into()))), + } } else { - Ok(Async::Ready(None)) + Poll::Ready(None) }; } - Async::NotReady => break, + Poll::Pending => break, } } - Ok(Async::NotReady) + Poll::Pending } } diff --git a/actix-http/src/encoding/encoder.rs b/actix-http/src/encoding/encoder.rs index 58d8a2d9e..295d99a2a 100644 --- a/actix-http/src/encoding/encoder.rs +++ b/actix-http/src/encoding/encoder.rs @@ -1,5 +1,8 @@ //! Stream encoder +use std::future::Future; use std::io::{self, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_threadpool::{run, CpuFuture}; #[cfg(feature = "brotli")] @@ -7,7 +10,6 @@ use brotli2::write::BrotliEncoder; use bytes::Bytes; #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] use flate2::write::{GzEncoder, ZlibEncoder}; -use futures::{Async, Future, Poll}; use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::http::header::{ContentEncoding, CONTENT_ENCODING}; @@ -22,7 +24,7 @@ pub struct Encoder { eof: bool, body: EncoderBody, encoder: Option, - fut: Option>, + fut: Option>>, } impl Encoder { @@ -94,43 +96,46 @@ impl MessageBody for Encoder { } } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { loop { if self.eof { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } if let Some(ref mut fut) = self.fut { - let mut encoder = futures::try_ready!(fut.poll()); + let mut encoder = match futures::ready!(Pin::new(fut).poll(cx)) { + Ok(Ok(item)) => item, + Ok(Err(e)) => return Poll::Ready(Some(Err(e.into()))), + Err(e) => return Poll::Ready(Some(Err(e.into()))), + }; let chunk = encoder.take(); self.encoder = Some(encoder); self.fut.take(); if !chunk.is_empty() { - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } let result = match self.body { EncoderBody::Bytes(ref mut b) => { if b.is_empty() { - Async::Ready(None) + Poll::Ready(None) } else { - Async::Ready(Some(std::mem::replace(b, Bytes::new()))) + Poll::Ready(Some(Ok(std::mem::replace(b, Bytes::new())))) } } - EncoderBody::Stream(ref mut b) => b.poll_next()?, - EncoderBody::BoxedStream(ref mut b) => b.poll_next()?, + EncoderBody::Stream(ref mut b) => b.poll_next(cx), + EncoderBody::BoxedStream(ref mut b) => b.poll_next(cx), }; match result { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(Some(chunk)) => { + Poll::Ready(Some(Ok(chunk))) => { if let Some(mut encoder) = self.encoder.take() { if chunk.len() < INPLACE { encoder.write(&chunk)?; let chunk = encoder.take(); self.encoder = Some(encoder); if !chunk.is_empty() { - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } else { self.fut = Some(run(move || { @@ -139,22 +144,23 @@ impl MessageBody for Encoder { })); } } else { - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } - Async::Ready(None) => { + Poll::Ready(None) => { if let Some(encoder) = self.encoder.take() { let chunk = encoder.finish()?; if chunk.is_empty() { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } else { self.eof = true; - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } else { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } } + val => return val, } } } diff --git a/actix-http/src/error.rs b/actix-http/src/error.rs index 90c35e486..a725789a2 100644 --- a/actix-http/src/error.rs +++ b/actix-http/src/error.rs @@ -6,11 +6,10 @@ use std::str::Utf8Error; use std::string::FromUtf8Error; use std::{fmt, io, result}; -pub use actix_threadpool::BlockingError; use actix_utils::timeout::TimeoutError; use bytes::BytesMut; use derive_more::{Display, From}; -use futures::Canceled; +pub use futures::channel::oneshot::Canceled; use http::uri::InvalidUri; use http::{header, Error as HttpError, StatusCode}; use httparse; @@ -75,7 +74,7 @@ pub trait ResponseError: fmt::Debug + fmt::Display { let _ = write!(Writer(&mut buf), "{}", self); resp.headers_mut().insert( header::CONTENT_TYPE, - header::HeaderValue::from_static("text/plain"), + header::HeaderValue::from_static("text/plain; charset=utf-8"), ); resp.set_body(Body::from(buf)) } @@ -182,13 +181,13 @@ impl ResponseError for FormError {} /// `InternalServerError` for `TimerError` impl ResponseError for TimerError {} -#[cfg(feature = "ssl")] +#[cfg(feature = "openssl")] /// `InternalServerError` for `openssl::ssl::Error` -impl ResponseError for openssl::ssl::Error {} +impl ResponseError for open_ssl::ssl::Error {} -#[cfg(feature = "ssl")] +#[cfg(feature = "openssl")] /// `InternalServerError` for `openssl::ssl::HandshakeError` -impl ResponseError for openssl::ssl::HandshakeError {} +impl ResponseError for open_ssl::ssl::HandshakeError {} /// Return `BAD_REQUEST` for `de::value::Error` impl ResponseError for DeError { @@ -197,8 +196,8 @@ impl ResponseError for DeError { } } -/// `InternalServerError` for `BlockingError` -impl ResponseError for BlockingError {} +/// `InternalServerError` for `Canceled` +impl ResponseError for Canceled {} /// Return `BAD_REQUEST` for `Utf8Error` impl ResponseError for Utf8Error { @@ -236,9 +235,6 @@ impl ResponseError for header::InvalidHeaderValueBytes { } } -/// `InternalServerError` for `futures::Canceled` -impl ResponseError for Canceled {} - /// A set of errors that can occur during parsing HTTP streams #[derive(Debug, Display)] pub enum ParseError { @@ -365,15 +361,12 @@ impl From for PayloadError { } } -impl From> for PayloadError { - fn from(err: BlockingError) -> Self { - match err { - BlockingError::Error(e) => PayloadError::Io(e), - BlockingError::Canceled => PayloadError::Io(io::Error::new( - io::ErrorKind::Other, - "Thread pool is gone", - )), - } +impl From for PayloadError { + fn from(_: Canceled) -> Self { + PayloadError::Io(io::Error::new( + io::ErrorKind::Other, + "Operation is canceled", + )) } } @@ -536,7 +529,7 @@ where let _ = write!(Writer(&mut buf), "{}", self); res.headers_mut().insert( header::CONTENT_TYPE, - header::HeaderValue::from_static("text/plain"), + header::HeaderValue::from_static("text/plain; charset=utf-8"), ); res.set_body(Body::from(buf)) } diff --git a/actix-http/src/h1/decoder.rs b/actix-http/src/h1/decoder.rs index ce113a145..272270ca1 100644 --- a/actix-http/src/h1/decoder.rs +++ b/actix-http/src/h1/decoder.rs @@ -1,10 +1,12 @@ +use std::future::Future; use std::io; use std::marker::PhantomData; use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_codec::Decoder; use bytes::{Bytes, BytesMut}; -use futures::{Async, Poll}; use http::header::{HeaderName, HeaderValue}; use http::{header, HttpTryFrom, Method, StatusCode, Uri, Version}; use httparse; @@ -442,9 +444,10 @@ impl Decoder for PayloadDecoder { loop { let mut buf = None; // advances the chunked state - *state = match state.step(src, size, &mut buf)? { - Async::NotReady => return Ok(None), - Async::Ready(state) => state, + *state = match state.step(src, size, &mut buf) { + Poll::Pending => return Ok(None), + Poll::Ready(Ok(state)) => state, + Poll::Ready(Err(e)) => return Err(e), }; if *state == ChunkedState::End { trace!("End of chunked stream"); @@ -476,7 +479,7 @@ macro_rules! byte ( $rdr.split_to(1); b } else { - return Ok(Async::NotReady) + return Poll::Pending } }) ); @@ -487,7 +490,7 @@ impl ChunkedState { body: &mut BytesMut, size: &mut u64, buf: &mut Option, - ) -> Poll { + ) -> Poll> { use self::ChunkedState::*; match *self { Size => ChunkedState::read_size(body, size), @@ -499,10 +502,14 @@ impl ChunkedState { BodyLf => ChunkedState::read_body_lf(body), EndCr => ChunkedState::read_end_cr(body), EndLf => ChunkedState::read_end_lf(body), - End => Ok(Async::Ready(ChunkedState::End)), + End => Poll::Ready(Ok(ChunkedState::End)), } } - fn read_size(rdr: &mut BytesMut, size: &mut u64) -> Poll { + + fn read_size( + rdr: &mut BytesMut, + size: &mut u64, + ) -> Poll> { let radix = 16; match byte!(rdr) { b @ b'0'..=b'9' => { @@ -517,48 +524,49 @@ impl ChunkedState { *size *= radix; *size += u64::from(b + 10 - b'A'); } - b'\t' | b' ' => return Ok(Async::Ready(ChunkedState::SizeLws)), - b';' => return Ok(Async::Ready(ChunkedState::Extension)), - b'\r' => return Ok(Async::Ready(ChunkedState::SizeLf)), + b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => return Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), _ => { - return Err(io::Error::new( + return Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk size line: Invalid Size", - )); + ))); } } - Ok(Async::Ready(ChunkedState::Size)) + Poll::Ready(Ok(ChunkedState::Size)) } - fn read_size_lws(rdr: &mut BytesMut) -> Poll { + + fn read_size_lws(rdr: &mut BytesMut) -> Poll> { trace!("read_size_lws"); match byte!(rdr) { // LWS can follow the chunk size, but no more digits can come - b'\t' | b' ' => Ok(Async::Ready(ChunkedState::SizeLws)), - b';' => Ok(Async::Ready(ChunkedState::Extension)), - b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)), - _ => Err(io::Error::new( + b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk size linear white space", - )), + ))), } } - fn read_extension(rdr: &mut BytesMut) -> Poll { + fn read_extension(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { - b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)), - _ => Ok(Async::Ready(ChunkedState::Extension)), // no supported extensions + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions } } fn read_size_lf( rdr: &mut BytesMut, size: &mut u64, - ) -> Poll { + ) -> Poll> { match byte!(rdr) { - b'\n' if *size > 0 => Ok(Async::Ready(ChunkedState::Body)), - b'\n' if *size == 0 => Ok(Async::Ready(ChunkedState::EndCr)), - _ => Err(io::Error::new( + b'\n' if *size > 0 => Poll::Ready(Ok(ChunkedState::Body)), + b'\n' if *size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)), + _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk size LF", - )), + ))), } } @@ -566,12 +574,12 @@ impl ChunkedState { rdr: &mut BytesMut, rem: &mut u64, buf: &mut Option, - ) -> Poll { + ) -> Poll> { trace!("Chunked read, remaining={:?}", rem); let len = rdr.len() as u64; if len == 0 { - Ok(Async::Ready(ChunkedState::Body)) + Poll::Ready(Ok(ChunkedState::Body)) } else { let slice; if *rem > len { @@ -583,47 +591,47 @@ impl ChunkedState { } *buf = Some(slice); if *rem > 0 { - Ok(Async::Ready(ChunkedState::Body)) + Poll::Ready(Ok(ChunkedState::Body)) } else { - Ok(Async::Ready(ChunkedState::BodyCr)) + Poll::Ready(Ok(ChunkedState::BodyCr)) } } } - fn read_body_cr(rdr: &mut BytesMut) -> Poll { + fn read_body_cr(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { - b'\r' => Ok(Async::Ready(ChunkedState::BodyLf)), - _ => Err(io::Error::new( + b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), + _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk body CR", - )), + ))), } } - fn read_body_lf(rdr: &mut BytesMut) -> Poll { + fn read_body_lf(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { - b'\n' => Ok(Async::Ready(ChunkedState::Size)), - _ => Err(io::Error::new( + b'\n' => Poll::Ready(Ok(ChunkedState::Size)), + _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk body LF", - )), + ))), } } - fn read_end_cr(rdr: &mut BytesMut) -> Poll { + fn read_end_cr(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { - b'\r' => Ok(Async::Ready(ChunkedState::EndLf)), - _ => Err(io::Error::new( + b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), + _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk end CR", - )), + ))), } } - fn read_end_lf(rdr: &mut BytesMut) -> Poll { + fn read_end_lf(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { - b'\n' => Ok(Async::Ready(ChunkedState::End)), - _ => Err(io::Error::new( + b'\n' => Poll::Ready(Ok(ChunkedState::End)), + _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk end LF", - )), + ))), } } } diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index c82eb4ac8..8c0896029 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -1,15 +1,17 @@ use std::collections::VecDeque; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Instant; -use std::{fmt, io, net}; +use std::{fmt, io, io::Write, net}; -use actix_codec::{Decoder, Encoder, Framed, FramedParts}; +use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; use actix_server_config::IoStream; use actix_service::Service; use bitflags::bitflags; use bytes::{BufMut, BytesMut}; -use futures::{Async, Future, Poll}; use log::{error, trace}; -use tokio_timer::Delay; +use tokio_timer::{delay, Delay}; use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::cloneable::CloneableService; @@ -261,14 +263,14 @@ where U: Service), Response = ()>, U::Error: fmt::Display, { - fn can_read(&self) -> bool { + fn can_read(&self, cx: &mut Context) -> bool { if self .flags .intersects(Flags::READ_DISCONNECT | Flags::UPGRADE) { false } else if let Some(ref info) = self.payload { - info.need_read() == PayloadStatus::Read + info.need_read(cx) == PayloadStatus::Read } else { true } @@ -287,7 +289,7 @@ where /// /// true - got whouldblock /// false - didnt get whouldblock - fn poll_flush(&mut self) -> Result { + fn poll_flush(&mut self, cx: &mut Context) -> Result { if self.write_buf.is_empty() { return Ok(false); } @@ -295,31 +297,31 @@ where let len = self.write_buf.len(); let mut written = 0; while written < len { - match self.io.write(&self.write_buf[written..]) { - Ok(0) => { + match unsafe { Pin::new_unchecked(&mut self.io) } + .poll_write(cx, &self.write_buf[written..]) + { + Poll::Ready(Ok(0)) => { return Err(DispatchError::Io(io::Error::new( io::ErrorKind::WriteZero, "", ))); } - Ok(n) => { + Poll::Ready(Ok(n)) => { written += n; } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + Poll::Pending => { if written > 0 { let _ = self.write_buf.split_to(written); } return Ok(true); } - Err(err) => return Err(DispatchError::Io(err)), + Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), } } - if written > 0 { - if written == self.write_buf.len() { - unsafe { self.write_buf.set_len(0) } - } else { - let _ = self.write_buf.split_to(written); - } + if written == self.write_buf.len() { + unsafe { self.write_buf.set_len(0) } + } else { + let _ = self.write_buf.split_to(written); } Ok(false) } @@ -350,12 +352,15 @@ where .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); } - fn poll_response(&mut self) -> Result { + fn poll_response( + &mut self, + cx: &mut Context, + ) -> Result { loop { let state = match self.state { State::None => match self.messages.pop_front() { Some(DispatcherMessage::Item(req)) => { - Some(self.handle_request(req)?) + Some(self.handle_request(req, cx)?) } Some(DispatcherMessage::Error(res)) => { Some(self.send_response(res, ResponseBody::Other(Body::Empty))?) @@ -365,54 +370,58 @@ where } None => None, }, - State::ExpectCall(ref mut fut) => match fut.poll() { - Ok(Async::Ready(req)) => { - self.send_continue(); - self.state = State::ServiceCall(self.service.call(req)); - continue; + State::ExpectCall(ref mut fut) => { + match unsafe { Pin::new_unchecked(fut) }.poll(cx) { + Poll::Ready(Ok(req)) => { + self.send_continue(); + self.state = State::ServiceCall(self.service.call(req)); + continue; + } + Poll::Ready(Err(e)) => { + let res: Response = e.into().into(); + let (res, body) = res.replace_body(()); + Some(self.send_response(res, body.into_body())?) + } + Poll::Pending => None, } - Ok(Async::NotReady) => None, - Err(e) => { - let res: Response = e.into().into(); - let (res, body) = res.replace_body(()); - Some(self.send_response(res, body.into_body())?) + } + State::ServiceCall(ref mut fut) => { + match unsafe { Pin::new_unchecked(fut) }.poll(cx) { + Poll::Ready(Ok(res)) => { + let (res, body) = res.into().replace_body(()); + self.state = self.send_response(res, body)?; + continue; + } + Poll::Ready(Err(e)) => { + let res: Response = e.into().into(); + let (res, body) = res.replace_body(()); + Some(self.send_response(res, body.into_body())?) + } + Poll::Pending => None, } - }, - State::ServiceCall(ref mut fut) => match fut.poll() { - Ok(Async::Ready(res)) => { - let (res, body) = res.into().replace_body(()); - self.state = self.send_response(res, body)?; - continue; - } - Ok(Async::NotReady) => None, - Err(e) => { - let res: Response = e.into().into(); - let (res, body) = res.replace_body(()); - Some(self.send_response(res, body.into_body())?) - } - }, + } State::SendPayload(ref mut stream) => { loop { if self.write_buf.len() < HW_BUFFER_SIZE { - match stream - .poll_next() - .map_err(|_| DispatchError::Unknown)? - { - Async::Ready(Some(item)) => { + match stream.poll_next(cx) { + Poll::Ready(Some(Ok(item))) => { self.codec.encode( Message::Chunk(Some(item)), &mut self.write_buf, )?; continue; } - Async::Ready(None) => { + Poll::Ready(None) => { self.codec.encode( Message::Chunk(None), &mut self.write_buf, )?; self.state = State::None; } - Async::NotReady => return Ok(PollResponse::DoNothing), + Poll::Ready(Some(Err(_))) => { + return Err(DispatchError::Unknown) + } + Poll::Pending => return Ok(PollResponse::DoNothing), } } else { return Ok(PollResponse::DrainWriteBuf); @@ -433,7 +442,7 @@ where // if read-backpressure is enabled and we consumed some data. // we may read more data and retry if self.state.is_call() { - if self.poll_request()? { + if self.poll_request(cx)? { continue; } } else if !self.messages.is_empty() { @@ -446,17 +455,21 @@ where Ok(PollResponse::DoNothing) } - fn handle_request(&mut self, req: Request) -> Result, DispatchError> { + fn handle_request( + &mut self, + req: Request, + cx: &mut Context, + ) -> Result, DispatchError> { // Handle `EXPECT: 100-Continue` header let req = if req.head().expect() { let mut task = self.expect.call(req); - match task.poll() { - Ok(Async::Ready(req)) => { + match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { + Poll::Ready(Ok(req)) => { self.send_continue(); req } - Ok(Async::NotReady) => return Ok(State::ExpectCall(task)), - Err(e) => { + Poll::Pending => return Ok(State::ExpectCall(task)), + Poll::Ready(Err(e)) => { let e = e.into(); let res: Response = e.into(); let (res, body) = res.replace_body(()); @@ -469,13 +482,13 @@ where // Call service let mut task = self.service.call(req); - match task.poll() { - Ok(Async::Ready(res)) => { + match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { + Poll::Ready(Ok(res)) => { let (res, body) = res.into().replace_body(()); self.send_response(res, body) } - Ok(Async::NotReady) => Ok(State::ServiceCall(task)), - Err(e) => { + Poll::Pending => Ok(State::ServiceCall(task)), + Poll::Ready(Err(e)) => { let res: Response = e.into().into(); let (res, body) = res.replace_body(()); self.send_response(res, body.into_body()) @@ -484,9 +497,12 @@ where } /// Process one incoming requests - pub(self) fn poll_request(&mut self) -> Result { + pub(self) fn poll_request( + &mut self, + cx: &mut Context, + ) -> Result { // limit a mount of non processed requests - if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read() { + if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read(cx) { return Ok(false); } @@ -521,7 +537,7 @@ where // handle request early if self.state.is_empty() { - self.state = self.handle_request(req)?; + self.state = self.handle_request(req, cx)?; } else { self.messages.push_back(DispatcherMessage::Item(req)); } @@ -587,12 +603,12 @@ where } /// keep-alive timer - fn poll_keepalive(&mut self) -> Result<(), DispatchError> { + fn poll_keepalive(&mut self, cx: &mut Context) -> Result<(), DispatchError> { if self.ka_timer.is_none() { // shutdown timeout if self.flags.contains(Flags::SHUTDOWN) { if let Some(interval) = self.codec.config().client_disconnect_timer() { - self.ka_timer = Some(Delay::new(interval)); + self.ka_timer = Some(delay(interval)); } else { self.flags.insert(Flags::READ_DISCONNECT); if let Some(mut payload) = self.payload.take() { @@ -605,11 +621,8 @@ where } } - match self.ka_timer.as_mut().unwrap().poll().map_err(|e| { - error!("Timer error {:?}", e); - DispatchError::Unknown - })? { - Async::Ready(_) => { + match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) { + Poll::Ready(()) => { // if we get timeout during shutdown, drop connection if self.flags.contains(Flags::SHUTDOWN) { return Err(DispatchError::DisconnectTimeout); @@ -624,9 +637,9 @@ where if let Some(deadline) = self.codec.config().client_disconnect_timer() { - if let Some(timer) = self.ka_timer.as_mut() { + if let Some(mut timer) = self.ka_timer.as_mut() { timer.reset(deadline); - let _ = timer.poll(); + let _ = Pin::new(&mut timer).poll(cx); } } else { // no shutdown timeout, drop socket @@ -650,23 +663,37 @@ where } else if let Some(deadline) = self.codec.config().keep_alive_expire() { - if let Some(timer) = self.ka_timer.as_mut() { + if let Some(mut timer) = self.ka_timer.as_mut() { timer.reset(deadline); - let _ = timer.poll(); + let _ = Pin::new(&mut timer).poll(cx); } } - } else if let Some(timer) = self.ka_timer.as_mut() { + } else if let Some(mut timer) = self.ka_timer.as_mut() { timer.reset(self.ka_expire); - let _ = timer.poll(); + let _ = Pin::new(&mut timer).poll(cx); } } - Async::NotReady => (), + Poll::Pending => (), } Ok(()) } } +impl Unpin for Dispatcher +where + T: IoStream, + S: Service, + S::Error: Into, + S::Response: Into>, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ +} + impl Future for Dispatcher where T: IoStream, @@ -679,27 +706,28 @@ where U: Service), Response = ()>, U::Error: fmt::Display, { - type Item = (); - type Error = DispatchError; + type Output = Result<(), DispatchError>; #[inline] - fn poll(&mut self) -> Poll { - match self.inner { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match self.as_mut().inner { DispatcherState::Normal(ref mut inner) => { - inner.poll_keepalive()?; + inner.poll_keepalive(cx)?; if inner.flags.contains(Flags::SHUTDOWN) { if inner.flags.contains(Flags::WRITE_DISCONNECT) { - Ok(Async::Ready(())) + Poll::Ready(Ok(())) } else { // flush buffer - inner.poll_flush()?; + inner.poll_flush(cx)?; if !inner.write_buf.is_empty() { - Ok(Async::NotReady) + Poll::Pending } else { - match inner.io.shutdown()? { - Async::Ready(_) => Ok(Async::Ready(())), - Async::NotReady => Ok(Async::NotReady), + match Pin::new(&mut inner.io).poll_shutdown(cx) { + Poll::Ready(res) => { + Poll::Ready(res.map_err(DispatchError::from)) + } + Poll::Pending => Poll::Pending, } } } @@ -707,12 +735,12 @@ where // read socket into a buf let should_disconnect = if !inner.flags.contains(Flags::READ_DISCONNECT) { - read_available(&mut inner.io, &mut inner.read_buf)? + read_available(cx, &mut inner.io, &mut inner.read_buf)? } else { None }; - inner.poll_request()?; + inner.poll_request(cx)?; if let Some(true) = should_disconnect { inner.flags.insert(Flags::READ_DISCONNECT); if let Some(mut payload) = inner.payload.take() { @@ -724,7 +752,7 @@ where if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE { inner.write_buf.reserve(HW_BUFFER_SIZE); } - let result = inner.poll_response()?; + let result = inner.poll_response(cx)?; let drain = result == PollResponse::DrainWriteBuf; // switch to upgrade handler @@ -742,7 +770,7 @@ where self.inner = DispatcherState::Upgrade( inner.upgrade.unwrap().call((req, framed)), ); - return self.poll(); + return self.poll(cx); } else { panic!() } @@ -751,14 +779,14 @@ where // we didnt get WouldBlock from write operation, // so data get written to kernel completely (OSX) // and we have to write again otherwise response can get stuck - if inner.poll_flush()? || !drain { + if inner.poll_flush(cx)? || !drain { break; } } // client is gone if inner.flags.contains(Flags::WRITE_DISCONNECT) { - return Ok(Async::Ready(())); + return Poll::Ready(Ok(())); } let is_empty = inner.state.is_empty(); @@ -771,38 +799,44 @@ where // keep-alive and stream errors if is_empty && inner.write_buf.is_empty() { if let Some(err) = inner.error.take() { - Err(err) + Poll::Ready(Err(err)) } // disconnect if keep-alive is not enabled else if inner.flags.contains(Flags::STARTED) && !inner.flags.intersects(Flags::KEEPALIVE) { inner.flags.insert(Flags::SHUTDOWN); - self.poll() + self.poll(cx) } // disconnect if shutdown else if inner.flags.contains(Flags::SHUTDOWN) { - self.poll() + self.poll(cx) } else { - Ok(Async::NotReady) + Poll::Pending } } else { - Ok(Async::NotReady) + Poll::Pending } } } - DispatcherState::Upgrade(ref mut fut) => fut.poll().map_err(|e| { - error!("Upgrade handler error: {}", e); - DispatchError::Upgrade - }), + DispatcherState::Upgrade(ref mut fut) => { + unsafe { Pin::new_unchecked(fut) }.poll(cx).map_err(|e| { + error!("Upgrade handler error: {}", e); + DispatchError::Upgrade + }) + } DispatcherState::None => panic!(), } } } -fn read_available(io: &mut T, buf: &mut BytesMut) -> Result, io::Error> +fn read_available( + cx: &mut Context, + io: &mut T, + buf: &mut BytesMut, +) -> Result, io::Error> where - T: io::Read, + T: AsyncRead + Unpin, { let mut read_some = false; loop { @@ -810,19 +844,18 @@ where buf.reserve(HW_BUFFER_SIZE); } - let read = unsafe { io.read(buf.bytes_mut()) }; - match read { - Ok(n) => { + match read(cx, io, buf) { + Poll::Pending => { + return if read_some { Ok(Some(false)) } else { Ok(None) }; + } + Poll::Ready(Ok(n)) => { if n == 0 { return Ok(Some(true)); } else { read_some = true; - unsafe { - buf.advance_mut(n); - } } } - Err(e) => { + Poll::Ready(Err(e)) => { return if e.kind() == io::ErrorKind::WouldBlock { if read_some { Ok(Some(false)) @@ -833,12 +866,23 @@ where Ok(Some(true)) } else { Err(e) - }; + } } } } } +fn read( + cx: &mut Context, + io: &mut T, + buf: &mut BytesMut, +) -> Poll> +where + T: AsyncRead + Unpin, +{ + Pin::new(io).poll_read_buf(cx, buf) +} + #[cfg(test)] mod tests { use actix_service::IntoService; @@ -852,7 +896,7 @@ mod tests { #[test] fn test_req_parse_err() { let mut sys = actix_rt::System::new("test"); - let _ = sys.block_on(lazy(|| { + let _ = sys.block_on(lazy(|cx| { let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n"); let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( @@ -865,7 +909,10 @@ mod tests { None, None, ); - assert!(h1.poll().is_err()); + match Pin::new(&mut h1).poll(cx) { + Poll::Pending => panic!(), + Poll::Ready(res) => assert!(res.is_err()), + } if let DispatcherState::Normal(ref inner) = h1.inner { assert!(inner.flags.contains(Flags::READ_DISCONNECT)); diff --git a/actix-http/src/h1/encoder.rs b/actix-http/src/h1/encoder.rs index 51ea497e0..6396f3b55 100644 --- a/actix-http/src/h1/encoder.rs +++ b/actix-http/src/h1/encoder.rs @@ -548,10 +548,11 @@ mod tests { ConnectionType::Close, &ServiceConfig::default(), ); - assert_eq!( - bytes.take().freeze(), - Bytes::from_static(b"\r\nContent-Length: 0\r\nConnection: close\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n") - ); + let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap(); + assert!(data.contains("Content-Length: 0\r\n")); + assert!(data.contains("Connection: close\r\n")); + assert!(data.contains("Content-Type: plain/text\r\n")); + assert!(data.contains("Date: date\r\n")); let _ = head.encode_headers( &mut bytes, @@ -560,10 +561,10 @@ mod tests { ConnectionType::KeepAlive, &ServiceConfig::default(), ); - assert_eq!( - bytes.take().freeze(), - Bytes::from_static(b"\r\nTransfer-Encoding: chunked\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n") - ); + let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap(); + assert!(data.contains("Transfer-Encoding: chunked\r\n")); + assert!(data.contains("Content-Type: plain/text\r\n")); + assert!(data.contains("Date: date\r\n")); let _ = head.encode_headers( &mut bytes, @@ -572,10 +573,10 @@ mod tests { ConnectionType::KeepAlive, &ServiceConfig::default(), ); - assert_eq!( - bytes.take().freeze(), - Bytes::from_static(b"\r\nContent-Length: 100\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n") - ); + let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap(); + assert!(data.contains("Content-Length: 100\r\n")); + assert!(data.contains("Content-Type: plain/text\r\n")); + assert!(data.contains("Date: date\r\n")); let mut head = RequestHead::default(); head.set_camel_case_headers(false); @@ -586,7 +587,6 @@ mod tests { .append(CONTENT_TYPE, HeaderValue::from_static("xml")); let mut head = RequestHeadType::Owned(head); - let _ = head.encode_headers( &mut bytes, Version::HTTP_11, @@ -594,10 +594,11 @@ mod tests { ConnectionType::KeepAlive, &ServiceConfig::default(), ); - assert_eq!( - bytes.take().freeze(), - Bytes::from_static(b"\r\ntransfer-encoding: chunked\r\ndate: date\r\ncontent-type: xml\r\ncontent-type: plain/text\r\n\r\n") - ); + let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap(); + assert!(data.contains("transfer-encoding: chunked\r\n")); + assert!(data.contains("content-type: xml\r\n")); + assert!(data.contains("content-type: plain/text\r\n")); + assert!(data.contains("date: date\r\n")); } #[test] @@ -626,9 +627,10 @@ mod tests { ConnectionType::Close, &ServiceConfig::default(), ); - assert_eq!( - bytes.take().freeze(), - Bytes::from_static(b"\r\ncontent-length: 0\r\nconnection: close\r\nauthorization: another authorization\r\ndate: date\r\n\r\n") - ); + let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap(); + assert!(data.contains("content-length: 0\r\n")); + assert!(data.contains("connection: close\r\n")); + assert!(data.contains("authorization: another authorization\r\n")); + assert!(data.contains("date: date\r\n")); } } diff --git a/actix-http/src/h1/expect.rs b/actix-http/src/h1/expect.rs index 32b6bd9c4..79831eae1 100644 --- a/actix-http/src/h1/expect.rs +++ b/actix-http/src/h1/expect.rs @@ -1,21 +1,24 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + use actix_server_config::ServerConfig; -use actix_service::{NewService, Service}; -use futures::future::{ok, FutureResult}; -use futures::{Async, Poll}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{ok, Ready}; use crate::error::Error; use crate::request::Request; pub struct ExpectHandler; -impl NewService for ExpectHandler { +impl ServiceFactory for ExpectHandler { type Config = ServerConfig; type Request = Request; type Response = Request; type Error = Error; type Service = ExpectHandler; type InitError = Error; - type Future = FutureResult; + type Future = Ready>; fn new_service(&self, _: &ServerConfig) -> Self::Future { ok(ExpectHandler) @@ -26,10 +29,10 @@ impl Service for ExpectHandler { type Request = Request; type Response = Request; type Error = Error; - type Future = FutureResult; + type Future = Ready>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, req: Request) -> Self::Future { diff --git a/actix-http/src/h1/payload.rs b/actix-http/src/h1/payload.rs index 28acb64bb..2b52cfd86 100644 --- a/actix-http/src/h1/payload.rs +++ b/actix-http/src/h1/payload.rs @@ -1,12 +1,14 @@ //! Payload stream use std::cell::RefCell; use std::collections::VecDeque; +use std::future::Future; +use std::pin::Pin; use std::rc::{Rc, Weak}; +use std::task::{Context, Poll}; +use actix_utils::task::LocalWaker; use bytes::Bytes; -use futures::task::current as current_task; -use futures::task::Task; -use futures::{Async, Poll, Stream}; +use futures::Stream; use crate::error::PayloadError; @@ -77,15 +79,24 @@ impl Payload { pub fn unread_data(&mut self, data: Bytes) { self.inner.borrow_mut().unread_data(data); } + + #[inline] + pub fn readany( + &mut self, + cx: &mut Context, + ) -> Poll>> { + self.inner.borrow_mut().readany(cx) + } } impl Stream for Payload { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; - #[inline] - fn poll(&mut self) -> Poll, PayloadError> { - self.inner.borrow_mut().readany() + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll>> { + self.inner.borrow_mut().readany(cx) } } @@ -117,19 +128,14 @@ impl PayloadSender { } #[inline] - pub fn need_read(&self) -> PayloadStatus { + pub fn need_read(&self, cx: &mut Context) -> PayloadStatus { // we check need_read only if Payload (other side) is alive, // otherwise always return true (consume payload) if let Some(shared) = self.inner.upgrade() { if shared.borrow().need_read { PayloadStatus::Read } else { - #[cfg(not(test))] - { - if shared.borrow_mut().io_task.is_none() { - shared.borrow_mut().io_task = Some(current_task()); - } - } + shared.borrow_mut().io_task.register(cx.waker()); PayloadStatus::Pause } } else { @@ -145,8 +151,8 @@ struct Inner { err: Option, need_read: bool, items: VecDeque, - task: Option, - io_task: Option, + task: LocalWaker, + io_task: LocalWaker, } impl Inner { @@ -157,8 +163,8 @@ impl Inner { err: None, items: VecDeque::new(), need_read: true, - task: None, - io_task: None, + task: LocalWaker::new(), + io_task: LocalWaker::new(), } } @@ -178,7 +184,7 @@ impl Inner { self.items.push_back(data); self.need_read = self.len < MAX_BUFFER_SIZE; if let Some(task) = self.task.take() { - task.notify() + task.wake() } } @@ -187,34 +193,28 @@ impl Inner { self.len } - fn readany(&mut self) -> Poll, PayloadError> { + fn readany( + &mut self, + cx: &mut Context, + ) -> Poll>> { if let Some(data) = self.items.pop_front() { self.len -= data.len(); self.need_read = self.len < MAX_BUFFER_SIZE; - if self.need_read && self.task.is_none() && !self.eof { - self.task = Some(current_task()); + if self.need_read && !self.eof { + self.task.register(cx.waker()); } - if let Some(task) = self.io_task.take() { - task.notify() - } - Ok(Async::Ready(Some(data))) + self.io_task.wake(); + Poll::Ready(Some(Ok(data))) } else if let Some(err) = self.err.take() { - Err(err) + Poll::Ready(Some(Err(err))) } else if self.eof { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { self.need_read = true; - #[cfg(not(test))] - { - if self.task.is_none() { - self.task = Some(current_task()); - } - if let Some(task) = self.io_task.take() { - task.notify() - } - } - Ok(Async::NotReady) + self.task.register(cx.waker()); + self.io_task.wake(); + Poll::Pending } } @@ -228,27 +228,23 @@ impl Inner { mod tests { use super::*; use actix_rt::Runtime; - use futures::future::{lazy, result}; + use futures::future::{poll_fn, ready}; #[test] fn test_unread_data() { - Runtime::new() - .unwrap() - .block_on(lazy(|| { - let (_, mut payload) = Payload::create(false); + Runtime::new().unwrap().block_on(async { + let (_, mut payload) = Payload::create(false); - payload.unread_data(Bytes::from("data")); - assert!(!payload.is_empty()); - assert_eq!(payload.len(), 4); + payload.unread_data(Bytes::from("data")); + assert!(!payload.is_empty()); + assert_eq!(payload.len(), 4); - assert_eq!( - Async::Ready(Some(Bytes::from("data"))), - payload.poll().ok().unwrap() - ); + assert_eq!( + Bytes::from("data"), + poll_fn(|cx| payload.readany(cx)).await.unwrap().unwrap() + ); - let res: Result<(), ()> = Ok(()); - result(res) - })) - .unwrap(); + ready(()) + }); } } diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index 89bf08e9b..ce8ff6626 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -1,12 +1,15 @@ use std::fmt; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_codec::Framed; use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig}; -use actix_service::{IntoNewService, NewService, Service}; -use futures::future::{ok, FutureResult}; -use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; +use futures::future::{ok, Ready}; +use futures::{ready, Stream}; use crate::body::MessageBody; use crate::cloneable::CloneableService; @@ -20,7 +23,7 @@ use super::codec::Codec; use super::dispatcher::Dispatcher; use super::{ExpectHandler, Message, UpgradeHandler}; -/// `NewService` implementation for HTTP1 transport +/// `ServiceFactory` implementation for HTTP1 transport pub struct H1Service> { srv: S, cfg: ServiceConfig, @@ -32,19 +35,19 @@ pub struct H1Service> { impl H1Service where - S: NewService, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, { /// Create new `HttpService` instance with default config. - pub fn new>(service: F) -> Self { + pub fn new>(service: F) -> Self { let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); H1Service { cfg, - srv: service.into_new_service(), + srv: service.into_factory(), expect: ExpectHandler, upgrade: None, on_connect: None, @@ -53,10 +56,13 @@ where } /// Create new `HttpService` instance with config. - pub fn with_config>(cfg: ServiceConfig, service: F) -> Self { + pub fn with_config>( + cfg: ServiceConfig, + service: F, + ) -> Self { H1Service { cfg, - srv: service.into_new_service(), + srv: service.into_factory(), expect: ExpectHandler, upgrade: None, on_connect: None, @@ -67,7 +73,7 @@ where impl H1Service where - S: NewService, + S: ServiceFactory, S::Error: Into, S::Response: Into>, S::InitError: fmt::Debug, @@ -75,7 +81,7 @@ where { pub fn expect(self, expect: X1) -> H1Service where - X1: NewService, + X1: ServiceFactory, X1::Error: Into, X1::InitError: fmt::Debug, { @@ -91,7 +97,7 @@ where pub fn upgrade(self, upgrade: Option) -> H1Service where - U1: NewService), Response = ()>, + U1: ServiceFactory), Response = ()>, U1::Error: fmt::Display, U1::InitError: fmt::Debug, { @@ -115,18 +121,18 @@ where } } -impl NewService for H1Service +impl ServiceFactory for H1Service where T: IoStream, - S: NewService, + S: ServiceFactory, S::Error: Into, S::Response: Into>, S::InitError: fmt::Debug, B: MessageBody, - X: NewService, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: NewService< + U: ServiceFactory< Config = SrvConfig, Request = (Request, Framed), Response = (), @@ -144,7 +150,7 @@ where fn new_service(&self, cfg: &SrvConfig) -> Self::Future { H1ServiceResponse { - fut: self.srv.new_service(cfg).into_future(), + fut: self.srv.new_service(cfg), fut_ex: Some(self.expect.new_service(cfg)), fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)), expect: None, @@ -157,20 +163,24 @@ where } #[doc(hidden)] +#[pin_project::pin_project] pub struct H1ServiceResponse where - S: NewService, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, - X: NewService, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: NewService), Response = ()>, + U: ServiceFactory), Response = ()>, U::Error: fmt::Display, U::InitError: fmt::Debug, { + #[pin] fut: S::Future, + #[pin] fut_ex: Option, + #[pin] fut_upg: Option, expect: Option, upgrade: Option, @@ -182,49 +192,57 @@ where impl Future for H1ServiceResponse where T: IoStream, - S: NewService, + S: ServiceFactory, S::Error: Into, S::Response: Into>, S::InitError: fmt::Debug, B: MessageBody, - X: NewService, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: NewService), Response = ()>, + U: ServiceFactory), Response = ()>, U::Error: fmt::Display, U::InitError: fmt::Debug, { - type Item = H1ServiceHandler; - type Error = (); + type Output = + Result, ()>; - fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut_ex { - let expect = try_ready!(fut - .poll() - .map_err(|e| log::error!("Init http service error: {:?}", e))); - self.expect = Some(expect); - self.fut_ex.take(); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut this = self.as_mut().project(); + + if let Some(fut) = this.fut_ex.as_pin_mut() { + let expect = ready!(fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e)))?; + this = self.as_mut().project(); + *this.expect = Some(expect); + this.fut_ex.set(None); } - if let Some(ref mut fut) = self.fut_upg { - let upgrade = try_ready!(fut - .poll() - .map_err(|e| log::error!("Init http service error: {:?}", e))); - self.upgrade = Some(upgrade); - self.fut_ex.take(); + if let Some(fut) = this.fut_upg.as_pin_mut() { + let upgrade = ready!(fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e)))?; + this = self.as_mut().project(); + *this.upgrade = Some(upgrade); + this.fut_ex.set(None); } - let service = try_ready!(self + let result = ready!(this .fut - .poll() + .poll(cx) .map_err(|e| log::error!("Init http service error: {:?}", e))); - Ok(Async::Ready(H1ServiceHandler::new( - self.cfg.take().unwrap(), - service, - self.expect.take().unwrap(), - self.upgrade.take(), - self.on_connect.clone(), - ))) + + Poll::Ready(result.map(|service| { + let this = self.as_mut().project(); + H1ServiceHandler::new( + this.cfg.take().unwrap(), + service, + this.expect.take().unwrap(), + this.upgrade.take(), + this.on_connect.clone(), + ) + })) } } @@ -284,10 +302,10 @@ where type Error = DispatchError; type Future = Dispatcher; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { let ready = self .expect - .poll_ready() + .poll_ready(cx) .map_err(|e| { let e = e.into(); log::error!("Http service readiness error: {:?}", e); @@ -297,7 +315,7 @@ where let ready = self .srv - .poll_ready() + .poll_ready(cx) .map_err(|e| { let e = e.into(); log::error!("Http service readiness error: {:?}", e); @@ -307,9 +325,9 @@ where && ready; if ready { - Ok(Async::Ready(())) + Poll::Ready(Ok(())) } else { - Ok(Async::NotReady) + Poll::Pending } } @@ -333,7 +351,7 @@ where } } -/// `NewService` implementation for `OneRequestService` service +/// `ServiceFactory` implementation for `OneRequestService` service #[derive(Default)] pub struct OneRequest { config: ServiceConfig, @@ -353,7 +371,7 @@ where } } -impl NewService for OneRequest +impl ServiceFactory for OneRequest where T: IoStream, { @@ -363,7 +381,7 @@ where type Error = ParseError; type InitError = (); type Service = OneRequestService; - type Future = FutureResult; + type Future = Ready>; fn new_service(&self, _: &SrvConfig) -> Self::Future { ok(OneRequestService { @@ -389,8 +407,8 @@ where type Error = ParseError; type Future = OneRequestServiceResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, req: Self::Request) -> Self::Future { @@ -415,19 +433,19 @@ impl Future for OneRequestServiceResponse where T: IoStream, { - type Item = (Request, Framed); - type Error = ParseError; + type Output = Result<(Request, Framed), ParseError>; - fn poll(&mut self) -> Poll { - match self.framed.as_mut().unwrap().poll()? { - Async::Ready(Some(req)) => match req { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match self.framed.as_mut().unwrap().next_item(cx) { + Poll::Ready(Some(Ok(req))) => match req { Message::Item(req) => { - Ok(Async::Ready((req, self.framed.take().unwrap()))) + Poll::Ready(Ok((req, self.framed.take().unwrap()))) } Message::Chunk(_) => unreachable!("Something is wrong"), }, - Async::Ready(None) => Err(ParseError::Incomplete), - Async::NotReady => Ok(Async::NotReady), + Poll::Ready(Some(Err(err))) => Poll::Ready(Err(err)), + Poll::Ready(None) => Poll::Ready(Err(ParseError::Incomplete)), + Poll::Pending => Poll::Pending, } } } diff --git a/actix-http/src/h1/upgrade.rs b/actix-http/src/h1/upgrade.rs index 0278f23e5..43ab53d01 100644 --- a/actix-http/src/h1/upgrade.rs +++ b/actix-http/src/h1/upgrade.rs @@ -1,10 +1,12 @@ +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_codec::Framed; use actix_server_config::ServerConfig; -use actix_service::{NewService, Service}; -use futures::future::FutureResult; -use futures::{Async, Poll}; +use actix_service::{Service, ServiceFactory}; +use futures::future::Ready; use crate::error::Error; use crate::h1::Codec; @@ -12,14 +14,14 @@ use crate::request::Request; pub struct UpgradeHandler(PhantomData); -impl NewService for UpgradeHandler { +impl ServiceFactory for UpgradeHandler { type Config = ServerConfig; type Request = (Request, Framed); type Response = (); type Error = Error; type Service = UpgradeHandler; type InitError = Error; - type Future = FutureResult; + type Future = Ready>; fn new_service(&self, _: &ServerConfig) -> Self::Future { unimplemented!() @@ -30,10 +32,10 @@ impl Service for UpgradeHandler { type Request = (Request, Framed); type Response = (); type Error = Error; - type Future = FutureResult; + type Future = Ready>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, _: Self::Request) -> Self::Future { diff --git a/actix-http/src/h1/utils.rs b/actix-http/src/h1/utils.rs index fdc4cf0bc..7057bf1ca 100644 --- a/actix-http/src/h1/utils.rs +++ b/actix-http/src/h1/utils.rs @@ -1,5 +1,9 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use futures::{Async, Future, Poll, Sink}; +use futures::Sink; use crate::body::{BodySize, MessageBody, ResponseBody}; use crate::error::Error; @@ -7,6 +11,7 @@ use crate::h1::{Codec, Message}; use crate::response::Response; /// Send http/1 response +#[pin_project::pin_project] pub struct SendResponse { res: Option, BodySize)>>, body: Option>, @@ -33,60 +38,61 @@ where T: AsyncRead + AsyncWrite, B: MessageBody, { - type Item = Framed; - type Error = Error; + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); - fn poll(&mut self) -> Poll { loop { - let mut body_ready = self.body.is_some(); - let framed = self.framed.as_mut().unwrap(); + let mut body_ready = this.body.is_some(); + let framed = this.framed.as_mut().unwrap(); // send body - if self.res.is_none() && self.body.is_some() { - while body_ready && self.body.is_some() && !framed.is_write_buf_full() { - match self.body.as_mut().unwrap().poll_next()? { - Async::Ready(item) => { + if this.res.is_none() && this.body.is_some() { + while body_ready && this.body.is_some() && !framed.is_write_buf_full() { + match this.body.as_mut().unwrap().poll_next(cx)? { + Poll::Ready(item) => { // body is done if item.is_none() { - let _ = self.body.take(); + let _ = this.body.take(); } - framed.force_send(Message::Chunk(item))?; + framed.write(Message::Chunk(item))?; } - Async::NotReady => body_ready = false, + Poll::Pending => body_ready = false, } } } // flush write buffer if !framed.is_write_buf_empty() { - match framed.poll_complete()? { - Async::Ready(_) => { + match framed.flush(cx)? { + Poll::Ready(_) => { if body_ready { continue; } else { - return Ok(Async::NotReady); + return Poll::Pending; } } - Async::NotReady => return Ok(Async::NotReady), + Poll::Pending => return Poll::Pending, } } // send response - if let Some(res) = self.res.take() { - framed.force_send(res)?; + if let Some(res) = this.res.take() { + framed.write(res)?; continue; } - if self.body.is_some() { + if this.body.is_some() { if body_ready { continue; } else { - return Ok(Async::NotReady); + return Poll::Pending; } } else { break; } } - Ok(Async::Ready(self.framed.take().unwrap())) + Poll::Ready(Ok(this.framed.take().unwrap())) } } diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs index 888f9065e..1a52a60f2 100644 --- a/actix-http/src/h2/dispatcher.rs +++ b/actix-http/src/h2/dispatcher.rs @@ -1,5 +1,8 @@ use std::collections::VecDeque; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Instant; use std::{fmt, mem, net}; @@ -8,7 +11,7 @@ use actix_server_config::IoStream; use actix_service::Service; use bitflags::bitflags; use bytes::{Bytes, BytesMut}; -use futures::{try_ready, Async, Future, Poll, Sink, Stream}; +use futures::{ready, Sink, Stream}; use h2::server::{Connection, SendResponse}; use h2::{RecvStream, SendStream}; use http::header::{ @@ -32,6 +35,7 @@ use crate::response::Response; const CHUNK_SIZE: usize = 16_384; /// Dispatcher for HTTP/2 protocol +#[pin_project::pin_project] pub struct Dispatcher, B: MessageBody> { service: CloneableService, connection: Connection, @@ -48,9 +52,9 @@ where T: IoStream, S: Service, S::Error: Into, - S::Future: 'static, + // S::Future: 'static, S::Response: Into>, - B: MessageBody + 'static, + B: MessageBody, { pub(crate) fn new( service: CloneableService, @@ -93,61 +97,75 @@ impl Future for Dispatcher where T: IoStream, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, { - type Item = (); - type Error = DispatchError; + type Output = Result<(), DispatchError>; #[inline] - fn poll(&mut self) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + loop { - match self.connection.poll()? { - Async::Ready(None) => return Ok(Async::Ready(())), - Async::Ready(Some((req, res))) => { + match Pin::new(&mut this.connection).poll_accept(cx) { + Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), + Poll::Ready(Some(Ok((req, res)))) => { // update keep-alive expire - if self.ka_timer.is_some() { - if let Some(expire) = self.config.keep_alive_expire() { - self.ka_expire = expire; + if this.ka_timer.is_some() { + if let Some(expire) = this.config.keep_alive_expire() { + this.ka_expire = expire; } } let (parts, body) = req.into_parts(); - let mut req = Request::with_payload(body.into()); + let mut req = Request::with_payload(Payload::< + crate::payload::PayloadStream, + >::H2( + crate::h2::Payload::new(body) + )); let head = &mut req.head_mut(); head.uri = parts.uri; head.method = parts.method; head.version = parts.version; head.headers = parts.headers.into(); - head.peer_addr = self.peer_addr; + head.peer_addr = this.peer_addr; // set on_connect data - if let Some(ref on_connect) = self.on_connect { + if let Some(ref on_connect) = this.on_connect { on_connect.set(&mut req.extensions_mut()); } - tokio_current_thread::spawn(ServiceResponse:: { + tokio_executor::current_thread::spawn(ServiceResponse::< + S::Future, + S::Response, + S::Error, + B, + > { state: ServiceResponseState::ServiceCall( - self.service.call(req), + this.service.call(req), Some(res), ), - config: self.config.clone(), + config: this.config.clone(), buffer: None, - }) + _t: PhantomData, + }); } - Async::NotReady => return Ok(Async::NotReady), + Poll::Pending => return Poll::Pending, } } } } -struct ServiceResponse { +#[pin_project::pin_project] +struct ServiceResponse { state: ServiceResponseState, config: ServiceConfig, buffer: Option, + _t: PhantomData<(I, E)>, } enum ServiceResponseState { @@ -155,12 +173,12 @@ enum ServiceResponseState { SendPayload(SendStream, ResponseBody), } -impl ServiceResponse +impl ServiceResponse where - F: Future, - F::Error: Into, - F::Item: Into>, - B: MessageBody + 'static, + F: Future>, + E: Into, + I: Into>, + B: MessageBody, { fn prepare_response( &self, @@ -223,109 +241,121 @@ where } } -impl Future for ServiceResponse +impl Future for ServiceResponse where - F: Future, - F::Error: Into, - F::Item: Into>, - B: MessageBody + 'static, + F: Future>, + E: Into, + I: Into>, + B: MessageBody, { - type Item = (); - type Error = (); + type Output = (); - fn poll(&mut self) -> Poll { - match self.state { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut this = self.as_mut().project(); + + match this.state { ServiceResponseState::ServiceCall(ref mut call, ref mut send) => { - match call.poll() { - Ok(Async::Ready(res)) => { + match unsafe { Pin::new_unchecked(call) }.poll(cx) { + Poll::Ready(Ok(res)) => { let (res, body) = res.into().replace_body(()); let mut send = send.take().unwrap(); let mut size = body.size(); - let h2_res = self.prepare_response(res.head(), &mut size); + let h2_res = + self.as_mut().prepare_response(res.head(), &mut size); + this = self.as_mut().project(); - let stream = - send.send_response(h2_res, size.is_eof()).map_err(|e| { + let stream = match send.send_response(h2_res, size.is_eof()) { + Err(e) => { trace!("Error sending h2 response: {:?}", e); - })?; + return Poll::Ready(()); + } + Ok(stream) => stream, + }; if size.is_eof() { - Ok(Async::Ready(())) + Poll::Ready(()) } else { - self.state = ServiceResponseState::SendPayload(stream, body); - self.poll() + *this.state = + ServiceResponseState::SendPayload(stream, body); + self.poll(cx) } } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => { let res: Response = e.into().into(); let (res, body) = res.replace_body(()); let mut send = send.take().unwrap(); let mut size = body.size(); - let h2_res = self.prepare_response(res.head(), &mut size); + let h2_res = + self.as_mut().prepare_response(res.head(), &mut size); + this = self.as_mut().project(); - let stream = - send.send_response(h2_res, size.is_eof()).map_err(|e| { + let stream = match send.send_response(h2_res, size.is_eof()) { + Err(e) => { trace!("Error sending h2 response: {:?}", e); - })?; + return Poll::Ready(()); + } + Ok(stream) => stream, + }; if size.is_eof() { - Ok(Async::Ready(())) + Poll::Ready(()) } else { - self.state = ServiceResponseState::SendPayload( + *this.state = ServiceResponseState::SendPayload( stream, body.into_body(), ); - self.poll() + self.poll(cx) } } } } ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop { loop { - if let Some(ref mut buffer) = self.buffer { - match stream.poll_capacity().map_err(|e| warn!("{:?}", e))? { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(None) => return Ok(Async::Ready(())), - Async::Ready(Some(cap)) => { + if let Some(ref mut buffer) = this.buffer { + match stream.poll_capacity(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(()), + Poll::Ready(Some(Ok(cap))) => { let len = buffer.len(); let bytes = buffer.split_to(std::cmp::min(cap, len)); if let Err(e) = stream.send_data(bytes, false) { warn!("{:?}", e); - return Err(()); + return Poll::Ready(()); } else if !buffer.is_empty() { let cap = std::cmp::min(buffer.len(), CHUNK_SIZE); stream.reserve_capacity(cap); } else { - self.buffer.take(); + this.buffer.take(); } } + Poll::Ready(Some(Err(e))) => { + warn!("{:?}", e); + return Poll::Ready(()); + } } } else { - match body.poll_next() { - Ok(Async::NotReady) => { - return Ok(Async::NotReady); - } - Ok(Async::Ready(None)) => { + match body.poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => { if let Err(e) = stream.send_data(Bytes::new(), true) { warn!("{:?}", e); - return Err(()); - } else { - return Ok(Async::Ready(())); } + return Poll::Ready(()); } - Ok(Async::Ready(Some(chunk))) => { + Poll::Ready(Some(Ok(chunk))) => { stream.reserve_capacity(std::cmp::min( chunk.len(), CHUNK_SIZE, )); - self.buffer = Some(chunk); + *this.buffer = Some(chunk); } - Err(e) => { + Poll::Ready(Some(Err(e))) => { error!("Response payload stream error: {:?}", e); - return Err(()); + return Poll::Ready(()); } } } diff --git a/actix-http/src/h2/mod.rs b/actix-http/src/h2/mod.rs index c5972123f..9c902f18c 100644 --- a/actix-http/src/h2/mod.rs +++ b/actix-http/src/h2/mod.rs @@ -1,9 +1,11 @@ #![allow(dead_code, unused_imports)] - use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use bytes::Bytes; -use futures::{Async, Poll, Stream}; +use futures::Stream; use h2::RecvStream; mod dispatcher; @@ -25,22 +27,23 @@ impl Payload { } impl Stream for Payload { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - match self.pl.poll() { - Ok(Async::Ready(Some(chunk))) => { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + + match Pin::new(&mut this.pl).poll_data(cx) { + Poll::Ready(Some(Ok(chunk))) => { let len = chunk.len(); - if let Err(err) = self.pl.release_capacity().release_capacity(len) { - Err(err.into()) + if let Err(err) = this.pl.release_capacity().release_capacity(len) { + Poll::Ready(Some(Err(err.into()))) } else { - Ok(Async::Ready(Some(chunk))) + Poll::Ready(Some(Ok(chunk))) } } - Ok(Async::Ready(None)) => Ok(Async::Ready(None)), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => Err(err.into()), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))), + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), } } } diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index e894cf660..860a61f73 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -1,13 +1,16 @@ use std::fmt::Debug; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{io, net, rc}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig}; -use actix_service::{IntoNewService, NewService, Service}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use bytes::Bytes; -use futures::future::{ok, FutureResult}; -use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream}; +use futures::future::{ok, Ready}; +use futures::{ready, Stream}; use h2::server::{self, Connection, Handshake}; use h2::RecvStream; use log::error; @@ -23,7 +26,7 @@ use crate::response::Response; use super::dispatcher::Dispatcher; -/// `NewService` implementation for HTTP2 transport +/// `ServiceFactory` implementation for HTTP2 transport pub struct H2Service { srv: S, cfg: ServiceConfig, @@ -33,30 +36,33 @@ pub struct H2Service { impl H2Service where - S: NewService, - S::Error: Into, - S::Response: Into>, + S: ServiceFactory, + S::Error: Into + 'static, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, { /// Create new `HttpService` instance. - pub fn new>(service: F) -> Self { + pub fn new>(service: F) -> Self { let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); H2Service { cfg, on_connect: None, - srv: service.into_new_service(), + srv: service.into_factory(), _t: PhantomData, } } /// Create new `HttpService` instance with config. - pub fn with_config>(cfg: ServiceConfig, service: F) -> Self { + pub fn with_config>( + cfg: ServiceConfig, + service: F, + ) -> Self { H2Service { cfg, on_connect: None, - srv: service.into_new_service(), + srv: service.into_factory(), _t: PhantomData, } } @@ -71,12 +77,12 @@ where } } -impl NewService for H2Service +impl ServiceFactory for H2Service where T: IoStream, - S: NewService, - S::Error: Into, - S::Response: Into>, + S: ServiceFactory, + S::Error: Into + 'static, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, { @@ -90,7 +96,7 @@ where fn new_service(&self, cfg: &SrvConfig) -> Self::Future { H2ServiceResponse { - fut: self.srv.new_service(cfg).into_future(), + fut: self.srv.new_service(cfg), cfg: Some(self.cfg.clone()), on_connect: self.on_connect.clone(), _t: PhantomData, @@ -99,8 +105,10 @@ where } #[doc(hidden)] -pub struct H2ServiceResponse { - fut: ::Future, +#[pin_project::pin_project] +pub struct H2ServiceResponse { + #[pin] + fut: S::Future, cfg: Option, on_connect: Option Box>>, _t: PhantomData<(T, P, B)>, @@ -109,22 +117,25 @@ pub struct H2ServiceResponse { impl Future for H2ServiceResponse where T: IoStream, - S: NewService, - S::Error: Into, - S::Response: Into>, + S: ServiceFactory, + S::Error: Into + 'static, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, { - type Item = H2ServiceHandler; - type Error = S::InitError; + type Output = Result, S::InitError>; - fn poll(&mut self) -> Poll { - let service = try_ready!(self.fut.poll()); - Ok(Async::Ready(H2ServiceHandler::new( - self.cfg.take().unwrap(), - self.on_connect.clone(), - service, - ))) + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.as_mut().project(); + + Poll::Ready(ready!(this.fut.poll(cx)).map(|service| { + let this = self.as_mut().project(); + H2ServiceHandler::new( + this.cfg.take().unwrap(), + this.on_connect.clone(), + service, + ) + })) } } @@ -139,9 +150,9 @@ pub struct H2ServiceHandler { impl H2ServiceHandler where S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, { fn new( @@ -162,9 +173,9 @@ impl Service for H2ServiceHandler where T: IoStream, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, { type Request = Io; @@ -172,8 +183,8 @@ where type Error = DispatchError; type Future = H2ServiceHandlerResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.srv.poll_ready().map_err(|e| { + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.srv.poll_ready(cx).map_err(|e| { let e = e.into(); error!("Service readiness error: {:?}", e); DispatchError::Service(e) @@ -219,9 +230,9 @@ pub struct H2ServiceHandlerResponse where T: IoStream, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, { state: State, @@ -231,25 +242,24 @@ impl Future for H2ServiceHandlerResponse where T: IoStream, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody, { - type Item = (); - type Error = DispatchError; + type Output = Result<(), DispatchError>; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { match self.state { - State::Incoming(ref mut disp) => disp.poll(), + State::Incoming(ref mut disp) => Pin::new(disp).poll(cx), State::Handshake( ref mut srv, ref mut config, ref peer_addr, ref mut on_connect, ref mut handshake, - ) => match handshake.poll() { - Ok(Async::Ready(conn)) => { + ) => match Pin::new(handshake).poll(cx) { + Poll::Ready(Ok(conn)) => { self.state = State::Incoming(Dispatcher::new( srv.take().unwrap(), conn, @@ -258,13 +268,13 @@ where None, *peer_addr, )); - self.poll() + self.poll(cx) } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => { + Poll::Ready(Err(err)) => { trace!("H2 handshake error: {}", err); - Err(err.into()) + Poll::Ready(Err(err.into())) } + Poll::Pending => Poll::Pending, }, } } diff --git a/actix-http/src/header/common/content_disposition.rs b/actix-http/src/header/common/content_disposition.rs index 14fcc3517..b2b6f34d7 100644 --- a/actix-http/src/header/common/content_disposition.rs +++ b/actix-http/src/header/common/content_disposition.rs @@ -76,6 +76,11 @@ pub enum DispositionParam { /// the form. Name(String), /// A plain file name. + /// + /// It is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any + /// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where + /// [`FilenameExt`](DispositionParam::FilenameExt) with charset UTF-8 may be used instead + /// in case there are Unicode characters in file names. Filename(String), /// An extended file name. It must not exist for `ContentType::Formdata` according to /// [RFC7578 Section 4.2](https://tools.ietf.org/html/rfc7578#section-4.2). @@ -220,7 +225,16 @@ impl DispositionParam { /// ext-token = /// ``` /// -/// **Note**: filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within +/// # Note +/// +/// filename is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any +/// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where +/// filename* with charset UTF-8 may be used instead in case there are Unicode characters in file +/// names. +/// filename is [acceptable](https://tools.ietf.org/html/rfc7578#section-4.2) to be UTF-8 encoded +/// directly in a *Content-Disposition* header for *multipart/form-data*, though. +/// +/// filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within /// *multipart/form-data*. /// /// # Example @@ -251,6 +265,22 @@ impl DispositionParam { /// }; /// assert_eq!(cd2.get_name(), Some("file")); // field name /// assert_eq!(cd2.get_filename(), Some("bill.odt")); +/// +/// // HTTP response header with Unicode characters in file names +/// let cd3 = ContentDisposition { +/// disposition: DispositionType::Attachment, +/// parameters: vec![ +/// DispositionParam::FilenameExt(ExtendedValue { +/// charset: Charset::Ext(String::from("UTF-8")), +/// language_tag: None, +/// value: String::from("\u{1f600}.svg").into_bytes(), +/// }), +/// // fallback for better compatibility +/// DispositionParam::Filename(String::from("Grinning-Face-Emoji.svg")) +/// ], +/// }; +/// assert_eq!(cd3.get_filename_ext().map(|ev| ev.value.as_ref()), +/// Some("\u{1f600}.svg".as_bytes())); /// ``` /// /// # WARN @@ -333,15 +363,17 @@ impl ContentDisposition { // token: won't contains semicolon according to RFC 2616 Section 2.2 let (token, new_left) = split_once_and_trim(left, ';'); left = new_left; + if token.is_empty() { + // quoted-string can be empty, but token cannot be empty + return Err(crate::error::ParseError::Header); + } token.to_owned() }; - if value.is_empty() { - return Err(crate::error::ParseError::Header); - } let param = if param_name.eq_ignore_ascii_case("name") { DispositionParam::Name(value) } else if param_name.eq_ignore_ascii_case("filename") { + // See also comments in test_from_raw_uncessary_percent_decode. DispositionParam::Filename(value) } else { DispositionParam::Unknown(param_name.to_owned(), value) @@ -466,11 +498,40 @@ impl fmt::Display for DispositionType { impl fmt::Display for DispositionParam { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // All ASCII control charaters (0-30, 127) excepting horizontal tab, double quote, and + // All ASCII control characters (0-30, 127) including horizontal tab, double quote, and // backslash should be escaped in quoted-string (i.e. "foobar"). - // Ref: RFC6266 S4.1 -> RFC2616 S2.2; RFC 7578 S4.2 -> RFC2183 S2 -> ... . + // Ref: RFC6266 S4.1 -> RFC2616 S3.6 + // filename-parm = "filename" "=" value + // value = token | quoted-string + // quoted-string = ( <"> *(qdtext | quoted-pair ) <"> ) + // qdtext = > + // quoted-pair = "\" CHAR + // TEXT = + // LWS = [CRLF] 1*( SP | HT ) + // OCTET = + // CHAR = + // CTL = + // + // Ref: RFC7578 S4.2 -> RFC2183 S2 -> RFC2045 S5.1 + // parameter := attribute "=" value + // attribute := token + // ; Matching of attributes + // ; is ALWAYS case-insensitive. + // value := token / quoted-string + // token := 1* + // tspecials := "(" / ")" / "<" / ">" / "@" / + // "," / ";" / ":" / "\" / <"> + // "/" / "[" / "]" / "?" / "=" + // ; Must be in quoted-string, + // ; to use within parameter values + // + // + // See also comments in test_from_raw_uncessary_percent_decode. lazy_static! { - static ref RE: Regex = Regex::new("[\x01-\x08\x10\x1F\x7F\"\\\\]").unwrap(); + static ref RE: Regex = Regex::new("[\x00-\x08\x10-\x1F\x7F\"\\\\]").unwrap(); } match self { DispositionParam::Name(ref value) => write!(f, "name={}", value), @@ -774,8 +835,18 @@ mod tests { #[test] fn test_from_raw_uncessary_percent_decode() { + // In fact, RFC7578 (multipart/form-data) Section 2 and 4.2 suggests that filename with + // non-ASCII characters MAY be percent-encoded. + // On the contrary, RFC6266 or other RFCs related to Content-Disposition response header + // do not mention such percent-encoding. + // So, it appears to be undecidable whether to percent-decode or not without + // knowing the usage scenario (multipart/form-data v.s. HTTP response header) and + // inevitable to unnecessarily percent-decode filename with %XX in the former scenario. + // Fortunately, it seems that almost all mainstream browsers just send UTF-8 encoded file + // names in quoted-string format (tested on Edge, IE11, Chrome and Firefox) without + // percent-encoding. So we do not bother to attempt to percent-decode. let a = HeaderValue::from_static( - "form-data; name=photo; filename=\"%74%65%73%74%2e%70%6e%67\"", // Should not be decoded! + "form-data; name=photo; filename=\"%74%65%73%74%2e%70%6e%67\"", ); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); let b = ContentDisposition { @@ -811,6 +882,9 @@ mod tests { let a = HeaderValue::from_static("inline; filename= "); assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("inline; filename=\"\""); + assert!(ContentDisposition::from_raw(&a).expect("parse cd").get_filename().expect("filename").is_empty()); } #[test] diff --git a/actix-http/src/httpcodes.rs b/actix-http/src/httpcodes.rs index 3cac35eb7..0c7f23fc8 100644 --- a/actix-http/src/httpcodes.rs +++ b/actix-http/src/httpcodes.rs @@ -29,7 +29,6 @@ impl Response { STATIC_RESP!(AlreadyReported, StatusCode::ALREADY_REPORTED); STATIC_RESP!(MultipleChoices, StatusCode::MULTIPLE_CHOICES); - STATIC_RESP!(MovedPermanenty, StatusCode::MOVED_PERMANENTLY); STATIC_RESP!(MovedPermanently, StatusCode::MOVED_PERMANENTLY); STATIC_RESP!(Found, StatusCode::FOUND); STATIC_RESP!(SeeOther, StatusCode::SEE_OTHER); diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index b57fdddce..4d17347db 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -4,7 +4,8 @@ clippy::too_many_arguments, clippy::new_without_default, clippy::borrow_interior_mutable_const, - clippy::write_with_newline + clippy::write_with_newline, + unused_imports )] #[macro_use] diff --git a/actix-http/src/message.rs b/actix-http/src/message.rs index 316df2611..5994ed39e 100644 --- a/actix-http/src/message.rs +++ b/actix-http/src/message.rs @@ -388,6 +388,12 @@ impl BoxedResponseHead { pub fn new(status: StatusCode) -> Self { RESPONSE_POOL.with(|p| p.get_message(status)) } + + pub(crate) fn take(&mut self) -> Self { + BoxedResponseHead { + head: self.head.take(), + } + } } impl std::ops::Deref for BoxedResponseHead { @@ -406,7 +412,9 @@ impl std::ops::DerefMut for BoxedResponseHead { impl Drop for BoxedResponseHead { fn drop(&mut self) { - RESPONSE_POOL.with(|p| p.release(self.head.take().unwrap())) + if let Some(head) = self.head.take() { + RESPONSE_POOL.with(move |p| p.release(head)) + } } } diff --git a/actix-http/src/payload.rs b/actix-http/src/payload.rs index 0ce209705..f2cc6414f 100644 --- a/actix-http/src/payload.rs +++ b/actix-http/src/payload.rs @@ -1,11 +1,15 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + use bytes::Bytes; -use futures::{Async, Poll, Stream}; +use futures::Stream; use h2::RecvStream; use crate::error::PayloadError; /// Type represent boxed payload -pub type PayloadStream = Box>; +pub type PayloadStream = Pin>>>; /// Type represent streaming payload pub enum Payload { @@ -48,18 +52,17 @@ impl Payload { impl Stream for Payload where - S: Stream, + S: Stream> + Unpin, { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; #[inline] - fn poll(&mut self) -> Poll, Self::Error> { - match self { - Payload::None => Ok(Async::Ready(None)), - Payload::H1(ref mut pl) => pl.poll(), - Payload::H2(ref mut pl) => pl.poll(), - Payload::Stream(ref mut pl) => pl.poll(), + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match self.get_mut() { + Payload::None => Poll::Ready(None), + Payload::H1(ref mut pl) => pl.readany(cx), + Payload::H2(ref mut pl) => Pin::new(pl).poll_next(cx), + Payload::Stream(ref mut pl) => Pin::new(pl).poll_next(cx), } } } diff --git a/actix-http/src/request.rs b/actix-http/src/request.rs index e9252a829..77ece01c5 100644 --- a/actix-http/src/request.rs +++ b/actix-http/src/request.rs @@ -80,6 +80,11 @@ impl

Request

{ ) } + /// Get request's payload + pub fn payload(&mut self) -> &mut Payload

{ + &mut self.payload + } + /// Get request's payload pub fn take_payload(&mut self) -> Payload

{ std::mem::replace(&mut self.payload, Payload::None) @@ -199,7 +204,6 @@ mod tests { assert_eq!(req.uri().query(), Some("q=1")); let s = format!("{:?}", req); - println!("T: {:?}", s); assert!(s.contains("Request HTTP/1.1 GET:/index.html")); } } diff --git a/actix-http/src/response.rs b/actix-http/src/response.rs index 5b0b3bc87..a5f18cc79 100644 --- a/actix-http/src/response.rs +++ b/actix-http/src/response.rs @@ -1,11 +1,14 @@ //! Http response use std::cell::{Ref, RefMut}; +use std::future::Future; use std::io::Write; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{fmt, str}; use bytes::{BufMut, Bytes, BytesMut}; -use futures::future::{ok, FutureResult, IntoFuture}; -use futures::Stream; +use futures::future::{ok, Ready}; +use futures::stream::Stream; use serde::Serialize; use serde_json; @@ -194,7 +197,7 @@ impl Response { self.head.extensions.borrow_mut() } - /// Get body os this response + /// Get body of this response #[inline] pub fn body(&self) -> &ResponseBody { &self.body @@ -280,13 +283,15 @@ impl fmt::Debug for Response { } } -impl IntoFuture for Response { - type Item = Response; - type Error = Error; - type Future = FutureResult; +impl Future for Response { + type Output = Result; - fn into_future(self) -> Self::Future { - ok(self) + fn poll(mut self: Pin<&mut Self>, _: &mut Context) -> Poll { + Poll::Ready(Ok(Response { + head: self.head.take(), + body: self.body.take_body(), + error: self.error.take(), + })) } } @@ -635,7 +640,7 @@ impl ResponseBuilder { /// `ResponseBuilder` can not be used after this call. pub fn streaming(&mut self, stream: S) -> Response where - S: Stream + 'static, + S: Stream> + 'static, E: Into + 'static, { self.body(Body::from_message(BodyStream::new(stream))) @@ -757,13 +762,11 @@ impl<'a> From<&'a ResponseHead> for ResponseBuilder { } } -impl IntoFuture for ResponseBuilder { - type Item = Response; - type Error = Error; - type Future = FutureResult; +impl Future for ResponseBuilder { + type Output = Result; - fn into_future(mut self) -> Self::Future { - ok(self.finish()) + fn poll(mut self: Pin<&mut Self>, _: &mut Context) -> Poll { + Poll::Ready(Ok(self.finish())) } } diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index 09b8077b3..e18b10130 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -1,14 +1,17 @@ use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{fmt, io, net, rc}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_server_config::{ Io as ServerIo, IoStream, Protocol, ServerConfig as SrvConfig, }; -use actix_service::{IntoNewService, NewService, Service}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use futures::{try_ready, Async, Future, IntoFuture, Poll}; +use futures::{ready, Future}; use h2::server::{self, Handshake}; +use pin_project::{pin_project, project}; use crate::body::MessageBody; use crate::builder::HttpServiceBuilder; @@ -20,7 +23,7 @@ use crate::request::Request; use crate::response::Response; use crate::{h1, h2::Dispatcher}; -/// `NewService` HTTP1.1/HTTP2 transport implementation +/// `ServiceFactory` HTTP1.1/HTTP2 transport implementation pub struct HttpService> { srv: S, cfg: ServiceConfig, @@ -32,10 +35,10 @@ pub struct HttpService HttpService where - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, { @@ -47,20 +50,20 @@ where impl HttpService where - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, { /// Create new `HttpService` instance. - pub fn new>(service: F) -> Self { + pub fn new>(service: F) -> Self { let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); HttpService { cfg, - srv: service.into_new_service(), + srv: service.into_factory(), expect: h1::ExpectHandler, upgrade: None, on_connect: None, @@ -69,13 +72,13 @@ where } /// Create new `HttpService` instance with config. - pub(crate) fn with_config>( + pub(crate) fn with_config>( cfg: ServiceConfig, service: F, ) -> Self { HttpService { cfg, - srv: service.into_new_service(), + srv: service.into_factory(), expect: h1::ExpectHandler, upgrade: None, on_connect: None, @@ -86,10 +89,11 @@ where impl HttpService where - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, + ::Future: 'static, B: MessageBody, { /// Provide service for `EXPECT: 100-Continue` support. @@ -99,9 +103,10 @@ where /// request will be forwarded to main service. pub fn expect(self, expect: X1) -> HttpService where - X1: NewService, + X1: ServiceFactory, X1::Error: Into, X1::InitError: fmt::Debug, + ::Future: 'static, { HttpService { expect, @@ -119,13 +124,14 @@ where /// and this service get called with original request and framed object. pub fn upgrade(self, upgrade: Option) -> HttpService where - U1: NewService< + U1: ServiceFactory< Config = SrvConfig, Request = (Request, Framed), Response = (), >, U1::Error: fmt::Display, U1::InitError: fmt::Debug, + ::Future: 'static, { HttpService { upgrade, @@ -147,25 +153,27 @@ where } } -impl NewService for HttpService +impl ServiceFactory for HttpService where T: IoStream, - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, - X: NewService, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: NewService< + ::Future: 'static, + U: ServiceFactory< Config = SrvConfig, Request = (Request, Framed), Response = (), >, U::Error: fmt::Display, U::InitError: fmt::Debug, + ::Future: 'static, { type Config = SrvConfig; type Request = ServerIo; @@ -177,7 +185,7 @@ where fn new_service(&self, cfg: &SrvConfig) -> Self::Future { HttpServiceResponse { - fut: self.srv.new_service(cfg).into_future(), + fut: self.srv.new_service(cfg), fut_ex: Some(self.expect.new_service(cfg)), fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)), expect: None, @@ -190,9 +198,20 @@ where } #[doc(hidden)] -pub struct HttpServiceResponse { +#[pin_project] +pub struct HttpServiceResponse< + T, + P, + S: ServiceFactory, + B, + X: ServiceFactory, + U: ServiceFactory, +> { + #[pin] fut: S::Future, + #[pin] fut_ex: Option, + #[pin] fut_upg: Option, expect: Option, upgrade: Option, @@ -204,50 +223,59 @@ pub struct HttpServiceResponse Future for HttpServiceResponse where T: IoStream, - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, - X: NewService, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: NewService), Response = ()>, + ::Future: 'static, + U: ServiceFactory), Response = ()>, U::Error: fmt::Display, U::InitError: fmt::Debug, + ::Future: 'static, { - type Item = HttpServiceHandler; - type Error = (); + type Output = + Result, ()>; - fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut_ex { - let expect = try_ready!(fut - .poll() - .map_err(|e| log::error!("Init http service error: {:?}", e))); - self.expect = Some(expect); - self.fut_ex.take(); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut this = self.as_mut().project(); + + if let Some(fut) = this.fut_ex.as_pin_mut() { + let expect = ready!(fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e)))?; + this = self.as_mut().project(); + *this.expect = Some(expect); + this.fut_ex.set(None); } - if let Some(ref mut fut) = self.fut_upg { - let upgrade = try_ready!(fut - .poll() - .map_err(|e| log::error!("Init http service error: {:?}", e))); - self.upgrade = Some(upgrade); - self.fut_ex.take(); + if let Some(fut) = this.fut_upg.as_pin_mut() { + let upgrade = ready!(fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e)))?; + this = self.as_mut().project(); + *this.upgrade = Some(upgrade); + this.fut_ex.set(None); } - let service = try_ready!(self + let result = ready!(this .fut - .poll() + .poll(cx) .map_err(|e| log::error!("Init http service error: {:?}", e))); - Ok(Async::Ready(HttpServiceHandler::new( - self.cfg.take().unwrap(), - service, - self.expect.take().unwrap(), - self.upgrade.take(), - self.on_connect.clone(), - ))) + Poll::Ready(result.map(|service| { + let this = self.as_mut().project(); + HttpServiceHandler::new( + this.cfg.take().unwrap(), + service, + this.expect.take().unwrap(), + this.upgrade.take(), + this.on_connect.clone(), + ) + })) } } @@ -264,9 +292,9 @@ pub struct HttpServiceHandler { impl HttpServiceHandler where S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, X: Service, X::Error: Into, @@ -295,9 +323,9 @@ impl Service for HttpServiceHandler where T: IoStream, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, X: Service, X::Error: Into, @@ -309,10 +337,10 @@ where type Error = DispatchError; type Future = HttpServiceHandlerResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { let ready = self .expect - .poll_ready() + .poll_ready(cx) .map_err(|e| { let e = e.into(); log::error!("Http service readiness error: {:?}", e); @@ -322,7 +350,7 @@ where let ready = self .srv - .poll_ready() + .poll_ready(cx) .map_err(|e| { let e = e.into(); log::error!("Http service readiness error: {:?}", e); @@ -332,9 +360,9 @@ where && ready; if ready { - Ok(Async::Ready(())) + Poll::Ready(Ok(())) } else { - Ok(Async::NotReady) + Poll::Pending } } @@ -389,6 +417,7 @@ where } } +#[pin_project] enum State where S: Service, @@ -401,8 +430,8 @@ where U: Service), Response = ()>, U::Error: fmt::Display, { - H1(h1::Dispatcher), - H2(Dispatcher, S, B>), + H1(#[pin] h1::Dispatcher), + H2(#[pin] Dispatcher, S, B>), Unknown( Option<( T, @@ -425,19 +454,21 @@ where ), } +#[pin_project] pub struct HttpServiceHandlerResponse where T: IoStream, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, X: Service, X::Error: Into, U: Service), Response = ()>, U::Error: fmt::Display, { + #[pin] state: State, } @@ -447,30 +478,51 @@ impl Future for HttpServiceHandlerResponse where T: IoStream, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody, X: Service, X::Error: Into, U: Service), Response = ()>, U::Error: fmt::Display, { - type Item = (); - type Error = DispatchError; + type Output = Result<(), DispatchError>; - fn poll(&mut self) -> Poll { - match self.state { - State::H1(ref mut disp) => disp.poll(), - State::H2(ref mut disp) => disp.poll(), + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + self.project().state.poll(cx) + } +} + +impl State +where + T: IoStream, + S: Service, + S::Error: Into + 'static, + S::Response: Into> + 'static, + B: MessageBody + 'static, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + #[project] + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + #[project] + match self.as_mut().project() { + State::H1(disp) => disp.poll(cx), + State::H2(disp) => disp.poll(cx), State::Unknown(ref mut data) => { if let Some(ref mut item) = data { loop { // Safety - we only write to the returned slice. let b = unsafe { item.1.bytes_mut() }; - let n = try_ready!(item.0.poll_read(b)); + let n = ready!(Pin::new(&mut item.0).poll_read(cx, b))?; if n == 0 { - return Ok(Async::Ready(())); + return Poll::Ready(Ok(())); } // Safety - we know that 'n' bytes have // been initialized via the contract of @@ -491,15 +543,15 @@ where inner: io, unread: Some(buf), }; - self.state = State::Handshake(Some(( + self.set(State::Handshake(Some(( server::handshake(io), cfg, srv, peer_addr, on_connect, - ))); + )))); } else { - self.state = State::H1(h1::Dispatcher::with_timeout( + self.set(State::H1(h1::Dispatcher::with_timeout( io, h1::Codec::new(cfg.clone()), cfg, @@ -509,36 +561,38 @@ where expect, upgrade, on_connect, - )) + ))) } - self.poll() + self.poll(cx) } State::Handshake(ref mut data) => { let conn = if let Some(ref mut item) = data { - match item.0.poll() { - Ok(Async::Ready(conn)) => conn, - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => { + match Pin::new(&mut item.0).poll(cx) { + Poll::Ready(Ok(conn)) => conn, + Poll::Ready(Err(err)) => { trace!("H2 handshake error: {}", err); - return Err(err.into()); + return Poll::Ready(Err(err.into())); } + Poll::Pending => return Poll::Pending, } } else { panic!() }; let (_, cfg, srv, peer_addr, on_connect) = data.take().unwrap(); - self.state = State::H2(Dispatcher::new( + self.set(State::H2(Dispatcher::new( srv, conn, on_connect, cfg, None, peer_addr, - )); - self.poll() + ))); + self.poll(cx) } } } } /// Wrapper for `AsyncRead + AsyncWrite` types +#[pin_project::pin_project] struct Io { unread: Option, + #[pin] inner: T, } @@ -568,21 +622,65 @@ impl io::Write for Io { } impl AsyncRead for Io { + // unsafe fn initializer(&self) -> io::Initializer { + // self.get_mut().inner.initializer() + // } + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { self.inner.prepare_uninitialized_buffer(buf) } + + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let this = self.project(); + + if let Some(mut bytes) = this.unread.take() { + let size = std::cmp::min(buf.len(), bytes.len()); + buf[..size].copy_from_slice(&bytes[..size]); + if bytes.len() > size { + bytes.split_to(size); + *this.unread = Some(bytes); + } + Poll::Ready(Ok(size)) + } else { + this.inner.poll_read(cx, buf) + } + } + + // fn poll_read_vectored( + // self: Pin<&mut Self>, + // cx: &mut Context<'_>, + // bufs: &mut [io::IoSliceMut<'_>], + // ) -> Poll> { + // self.get_mut().inner.poll_read_vectored(cx, bufs) + // } } -impl AsyncWrite for Io { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.inner.shutdown() +impl tokio_io::AsyncWrite for Io { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().inner.poll_write(cx, buf) } - fn write_buf(&mut self, buf: &mut B) -> Poll { - self.inner.write_buf(buf) + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().inner.poll_shutdown(cx) } } -impl IoStream for Io { +impl actix_server_config::IoStream for Io { #[inline] fn peer_addr(&self) -> Option { self.inner.peer_addr() diff --git a/actix-http/src/test.rs b/actix-http/src/test.rs index ed5b81a35..26f2c223f 100644 --- a/actix-http/src/test.rs +++ b/actix-http/src/test.rs @@ -1,12 +1,13 @@ //! Test Various helpers for Actix applications to use during testing. use std::fmt::Write as FmtWrite; -use std::io; +use std::io::{self, Read, Write}; +use std::pin::Pin; use std::str::FromStr; +use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_server_config::IoStream; use bytes::{Buf, Bytes, BytesMut}; -use futures::{Async, Poll}; use http::header::{self, HeaderName, HeaderValue}; use http::{HttpTryFrom, Method, Uri, Version}; use percent_encoding::percent_encode; @@ -244,14 +245,31 @@ impl io::Write for TestBuffer { } } -impl AsyncRead for TestBuffer {} +impl AsyncRead for TestBuffer { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Poll::Ready(self.get_mut().read(buf)) + } +} impl AsyncWrite for TestBuffer { - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(Async::Ready(())) + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Poll::Ready(self.get_mut().write(buf)) } - fn write_buf(&mut self, _: &mut B) -> Poll { - Ok(Async::NotReady) + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } } diff --git a/actix-http/src/ws/transport.rs b/actix-http/src/ws/transport.rs index da7782be5..58ba3160f 100644 --- a/actix-http/src/ws/transport.rs +++ b/actix-http/src/ws/transport.rs @@ -1,7 +1,10 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_service::{IntoService, Service}; use actix_utils::framed::{FramedTransport, FramedTransportError}; -use futures::{Future, Poll}; use super::{Codec, Frame, Message}; @@ -40,10 +43,9 @@ where S::Future: 'static, S::Error: 'static, { - type Item = (); - type Error = FramedTransportError; + type Output = Result<(), FramedTransportError>; - fn poll(&mut self) -> Poll { - self.inner.poll() + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Pin::new(&mut self.inner).poll(cx) } } diff --git a/actix-http/tests/test_client.rs b/actix-http/tests/test_client.rs index a4f1569cc..05248966a 100644 --- a/actix-http/tests/test_client.rs +++ b/actix-http/tests/test_client.rs @@ -1,9 +1,9 @@ -use actix_service::NewService; +use actix_service::ServiceFactory; use bytes::Bytes; use futures::future::{self, ok}; use actix_http::{http, HttpService, Request, Response}; -use actix_http_test::TestServer; +use actix_http_test::{block_on, TestServer}; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ @@ -29,55 +29,63 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ #[test] fn test_h1_v2() { - env_logger::init(); - let mut srv = TestServer::new(move || { - HttpService::build().finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) - }); - let response = srv.block_on(srv.get("/").send()).unwrap(); - assert!(response.status().is_success()); + block_on(async { + let srv = TestServer::start(move || { + HttpService::build() + .finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + }); - let request = srv.get("/").header("x-test", "111").send(); - let response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + let request = srv.get("/").header("x-test", "111").send(); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - let response = srv.block_on(srv.post("/").send()).unwrap(); - assert!(response.status().is_success()); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + let mut response = srv.post("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } #[test] fn test_connection_close() { - let mut srv = TestServer::new(move || { - HttpService::build() - .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) - .map(|_| ()) - }); - let response = srv.block_on(srv.get("/").force_close().send()).unwrap(); - assert!(response.status().is_success()); + block_on(async { + let srv = TestServer::start(move || { + HttpService::build() + .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map(|_| ()) + }); + + let response = srv.get("/").force_close().send().await.unwrap(); + assert!(response.status().is_success()); + }) } #[test] fn test_with_query_parameter() { - let mut srv = TestServer::new(move || { - HttpService::build() - .finish(|req: Request| { - if req.uri().query().unwrap().contains("qp=") { - ok::<_, ()>(Response::Ok().finish()) - } else { - ok::<_, ()>(Response::BadRequest().finish()) - } - }) - .map(|_| ()) - }); + block_on(async { + let srv = TestServer::start(move || { + HttpService::build() + .finish(|req: Request| { + if req.uri().query().unwrap().contains("qp=") { + ok::<_, ()>(Response::Ok().finish()) + } else { + ok::<_, ()>(Response::BadRequest().finish()) + } + }) + .map(|_| ()) + }); - let request = srv.request(http::Method::GET, srv.url("/?qp=5")).send(); - let response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + let request = srv.request(http::Method::GET, srv.url("/?qp=5")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + }) } diff --git a/actix-http/tests/test_openssl.rs b/actix-http/tests/test_openssl.rs new file mode 100644 index 000000000..7eaa8e2a2 --- /dev/null +++ b/actix-http/tests/test_openssl.rs @@ -0,0 +1,545 @@ +#![cfg(feature = "openssl")] +use std::io; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_http_test::{block_on, TestServer}; +use actix_server::ssl::OpensslAcceptor; +use actix_server_config::ServerConfig; +use actix_service::{factory_fn_cfg, pipeline_factory, service_fn2, ServiceFactory}; + +use bytes::{Bytes, BytesMut}; +use futures::future::{err, ok, ready}; +use futures::stream::{once, Stream, StreamExt}; +use open_ssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod}; + +use actix_http::error::{ErrorBadRequest, PayloadError}; +use actix_http::http::header::{self, HeaderName, HeaderValue}; +use actix_http::http::{Method, StatusCode, Version}; +use actix_http::httpmessage::HttpMessage; +use actix_http::{body, Error, HttpService, Request, Response}; + +async fn load_body(stream: S) -> Result +where + S: Stream>, +{ + let body = stream + .map(|res| match res { + Ok(chunk) => chunk, + Err(_) => panic!(), + }) + .fold(BytesMut::new(), move |mut body, chunk| { + body.extend_from_slice(&chunk); + ready(body) + }) + .await; + + Ok(body) +} + +fn ssl_acceptor() -> io::Result> { + // load ssl keys + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder + .set_private_key_file("../tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("../tests/cert.pem") + .unwrap(); + builder.set_alpn_select_callback(|_, protos| { + const H2: &[u8] = b"\x02h2"; + if protos.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else { + Err(AlpnError::NOACK) + } + }); + builder.set_alpn_protos(b"\x02h2")?; + Ok(OpensslAcceptor::new(builder.build())) +} + +#[test] +fn test_h2() -> io::Result<()> { + block_on(async { + let openssl = ssl_acceptor()?; + let srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| ok::<_, Error>(Response::Ok().finish())) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) + }) +} + +#[test] +fn test_h2_1() -> io::Result<()> { + block_on(async { + let openssl = ssl_acceptor()?; + let srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .finish(|req: Request| { + assert!(req.peer_addr().is_some()); + assert_eq!(req.version(), Version::HTTP_2); + ok::<_, Error>(Response::Ok().finish()) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) + }) +} + +#[test] +fn test_h2_body() -> io::Result<()> { + block_on(async { + let data = "HELLOWORLD".to_owned().repeat(64 * 1024); + let openssl = ssl_acceptor()?; + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|mut req: Request<_>| { + async move { + let body = load_body(req.take_payload()).await?; + Ok::<_, Error>(Response::Ok().body(body)) + } + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send_body(data.clone()).await.unwrap(); + assert!(response.status().is_success()); + + let body = srv.load_body(response).await.unwrap(); + assert_eq!(&body, data.as_bytes()); + Ok(()) + }) +} + +#[test] +fn test_h2_content_length() { + block_on(async { + let openssl = ssl_acceptor().unwrap(); + + let srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + ok::<_, ()>(Response::new(statuses[indx])) + }) + .map_err(|_| ()), + ) + }); + + let header = HeaderName::from_static("content-length"); + let value = HeaderValue::from_static("0"); + + { + for i in 0..4 { + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), None); + + let req = srv + .request(Method::HEAD, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), None); + } + + for i in 4..6 { + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), Some(&value)); + } + } + }) +} + +#[test] +fn test_h2_headers() { + block_on(async { + let data = STR.repeat(10); + let data2 = data.clone(); + let openssl = ssl_acceptor().unwrap(); + + let mut srv = TestServer::start(move || { + let data = data.clone(); + pipeline_factory(openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e))) + .and_then( + HttpService::build().h2(move |_| { + let mut builder = Response::Ok(); + for idx in 0..90 { + builder.header( + format!("X-TEST-{}", idx).as_str(), + "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", + ); + } + ok::<_, ()>(builder.body(data.clone())) + }).map_err(|_| ())) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from(data2)); + }) +} + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +#[test] +fn test_h2_body2() { + block_on(async { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) +} + +#[test] +fn test_h2_head_empty() { + block_on(async { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), Version::HTTP_2); + + { + let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); + }) +} + +#[test] +fn test_h2_head_binary() { + block_on(async { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| { + ok::<_, ()>( + Response::Ok().content_length(STR.len() as u64).body(STR), + ) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); + }) +} + +#[test] +fn test_h2_head_binary2() { + block_on(async { + let openssl = ssl_acceptor().unwrap(); + let srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + }) +} + +#[test] +fn test_h2_body_length() { + block_on(async { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| { + let body = once(ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .body(body::SizedStream::new(STR.len() as u64, body)), + ) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) +} + +#[test] +fn test_h2_body_chunked_explicit() { + block_on(async { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| { + let body = + once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert!(!response.headers().contains_key(header::TRANSFER_ENCODING)); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + + // decode + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) +} + +#[test] +fn test_h2_response_http_error_handling() { + block_on(async { + let openssl = ssl_acceptor().unwrap(); + + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(factory_fn_cfg(|_: &ServerConfig| { + ok::<_, ()>(service_fn2(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, ()>( + Response::Ok() + .header(header::CONTENT_TYPE, broken_header) + .body(STR), + ) + })) + })) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); + }) +} + +#[test] +fn test_h2_service_error() { + block_on(async { + let openssl = ssl_acceptor().unwrap(); + + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| err::(ErrorBadRequest("error"))) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"error")); + }) +} + +#[test] +fn test_h2_on_connect() { + block_on(async { + let openssl = ssl_acceptor().unwrap(); + + let srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .on_connect(|_| 10usize) + .h2(|req: Request| { + assert!(req.extensions().contains::()); + ok::<_, ()>(Response::Ok().finish()) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + }) +} diff --git a/actix-http/tests/test_rustls.rs b/actix-http/tests/test_rustls.rs new file mode 100644 index 000000000..c36d05794 --- /dev/null +++ b/actix-http/tests/test_rustls.rs @@ -0,0 +1,474 @@ +#![cfg(feature = "rustls")] +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_http::error::PayloadError; +use actix_http::http::header::{self, HeaderName, HeaderValue}; +use actix_http::http::{Method, StatusCode, Version}; +use actix_http::{body, error, Error, HttpService, Request, Response}; +use actix_http_test::{block_on, TestServer}; +use actix_server::ssl::RustlsAcceptor; +use actix_server_config::ServerConfig; +use actix_service::{factory_fn_cfg, pipeline_factory, service_fn2, ServiceFactory}; + +use bytes::{Bytes, BytesMut}; +use futures::future::{self, err, ok}; +use futures::stream::{once, Stream, StreamExt}; +use rust_tls::{ + internal::pemfile::{certs, pkcs8_private_keys}, + NoClientAuth, ServerConfig as RustlsServerConfig, +}; + +use std::fs::File; +use std::io::{self, BufReader}; + +async fn load_body(mut stream: S) -> Result +where + S: Stream> + Unpin, +{ + let mut body = BytesMut::new(); + while let Some(item) = stream.next().await { + body.extend_from_slice(&item?) + } + Ok(body) +} + +fn ssl_acceptor() -> io::Result> { + // load ssl keys + let mut config = RustlsServerConfig::new(NoClientAuth::new()); + let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap()); + let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap()); + let cert_chain = certs(cert_file).unwrap(); + let mut keys = pkcs8_private_keys(key_file).unwrap(); + config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + + let protos = vec![b"h2".to_vec()]; + config.set_protocols(&protos); + Ok(RustlsAcceptor::new(config)) +} + +#[test] +fn test_h2() -> io::Result<()> { + block_on(async { + let rustls = ssl_acceptor()?; + let srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| future::ok::<_, Error>(Response::Ok().finish())) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) + }) +} + +#[test] +fn test_h2_1() -> io::Result<()> { + block_on(async { + let rustls = ssl_acceptor()?; + let srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .finish(|req: Request| { + assert!(req.peer_addr().is_some()); + assert_eq!(req.version(), Version::HTTP_2); + future::ok::<_, Error>(Response::Ok().finish()) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) + }) +} + +#[test] +fn test_h2_body1() -> io::Result<()> { + block_on(async { + let data = "HELLOWORLD".to_owned().repeat(64 * 1024); + let rustls = ssl_acceptor()?; + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|mut req: Request<_>| { + async move { + let body = load_body(req.take_payload()).await?; + Ok::<_, Error>(Response::Ok().body(body)) + } + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send_body(data.clone()).await.unwrap(); + assert!(response.status().is_success()); + + let body = srv.load_body(response).await.unwrap(); + assert_eq!(&body, data.as_bytes()); + Ok(()) + }) +} + +#[test] +fn test_h2_content_length() { + block_on(async { + let rustls = ssl_acceptor().unwrap(); + + let srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + future::ok::<_, ()>(Response::new(statuses[indx])) + }) + .map_err(|_| ()), + ) + }); + + let header = HeaderName::from_static("content-length"); + let value = HeaderValue::from_static("0"); + + { + for i in 0..4 { + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), None); + + let req = srv + .request(Method::HEAD, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), None); + } + + for i in 4..6 { + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), Some(&value)); + } + } + }) +} + +#[test] +fn test_h2_headers() { + block_on(async { + let data = STR.repeat(10); + let data2 = data.clone(); + let rustls = ssl_acceptor().unwrap(); + + let mut srv = TestServer::start(move || { + let data = data.clone(); + pipeline_factory(rustls + .clone() + .map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build().h2(move |_| { + let mut config = Response::Ok(); + for idx in 0..90 { + config.header( + format!("X-TEST-{}", idx).as_str(), + "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", + ); + } + future::ok::<_, ()>(config.body(data.clone())) + }).map_err(|_| ())) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from(data2)); + }) +} + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +#[test] +fn test_h2_body2() { + block_on(async { + let rustls = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) +} + +#[test] +fn test_h2_head_empty() { + block_on(async { + let rustls = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), Version::HTTP_2); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); + }) +} + +#[test] +fn test_h2_head_binary() { + block_on(async { + let rustls = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| { + ok::<_, ()>( + Response::Ok() + .content_length(STR.len() as u64) + .body(STR), + ) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); + }) +} + +#[test] +fn test_h2_head_binary2() { + block_on(async { + let rustls = ssl_acceptor().unwrap(); + let srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + }) +} + +#[test] +fn test_h2_body_length() { + block_on(async { + let rustls = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| { + let body = once(ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok().body(body::SizedStream::new( + STR.len() as u64, + body, + )), + ) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) +} + +#[test] +fn test_h2_body_chunked_explicit() { + block_on(async { + let rustls = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| { + let body = + once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert!(!response.headers().contains_key(header::TRANSFER_ENCODING)); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + + // decode + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) +} + +#[test] +fn test_h2_response_http_error_handling() { + block_on(async { + let rustls = ssl_acceptor().unwrap(); + + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(factory_fn_cfg(|_: &ServerConfig| { + ok::<_, ()>(service_fn2(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, ()>( + Response::Ok() + .header( + http::header::CONTENT_TYPE, + broken_header, + ) + .body(STR), + ) + })) + })) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); + }) +} + +#[test] +fn test_h2_service_error() { + block_on(async { + let rustls = ssl_acceptor().unwrap(); + + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| err::(error::ErrorBadRequest("error"))) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"error")); + }) +} diff --git a/actix-http/tests/test_rustls_server.rs b/actix-http/tests/test_rustls_server.rs deleted file mode 100644 index b74fd07bf..000000000 --- a/actix-http/tests/test_rustls_server.rs +++ /dev/null @@ -1,462 +0,0 @@ -#![cfg(feature = "rust-tls")] -use actix_codec::{AsyncRead, AsyncWrite}; -use actix_http::error::PayloadError; -use actix_http::http::header::{self, HeaderName, HeaderValue}; -use actix_http::http::{Method, StatusCode, Version}; -use actix_http::{body, error, Error, HttpService, Request, Response}; -use actix_http_test::TestServer; -use actix_server::ssl::RustlsAcceptor; -use actix_server_config::ServerConfig; -use actix_service::{new_service_cfg, NewService}; - -use bytes::{Bytes, BytesMut}; -use futures::future::{self, ok, Future}; -use futures::stream::{once, Stream}; -use rustls::{ - internal::pemfile::{certs, pkcs8_private_keys}, - NoClientAuth, ServerConfig as RustlsServerConfig, -}; - -use std::fs::File; -use std::io::{BufReader, Result}; - -fn load_body(stream: S) -> impl Future -where - S: Stream, -{ - stream.fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, PayloadError>(body) - }) -} - -fn ssl_acceptor() -> Result> { - // load ssl keys - let mut config = RustlsServerConfig::new(NoClientAuth::new()); - let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap()); - let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap()); - let cert_chain = certs(cert_file).unwrap(); - let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - - let protos = vec![b"h2".to_vec()]; - config.set_protocols(&protos); - Ok(RustlsAcceptor::new(config)) -} - -#[test] -fn test_h2() -> Result<()> { - let rustls = ssl_acceptor()?; - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| future::ok::<_, Error>(Response::Ok().finish())) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - Ok(()) -} - -#[test] -fn test_h2_1() -> Result<()> { - let rustls = ssl_acceptor()?; - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .finish(|req: Request| { - assert!(req.peer_addr().is_some()); - assert_eq!(req.version(), Version::HTTP_2); - future::ok::<_, Error>(Response::Ok().finish()) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - Ok(()) -} - -#[test] -fn test_h2_body() -> Result<()> { - let data = "HELLOWORLD".to_owned().repeat(64 * 1024); - let rustls = ssl_acceptor()?; - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|mut req: Request<_>| { - load_body(req.take_payload()) - .and_then(|body| Ok(Response::Ok().body(body))) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send_body(data.clone())).unwrap(); - assert!(response.status().is_success()); - - let body = srv.load_body(response).unwrap(); - assert_eq!(&body, data.as_bytes()); - Ok(()) -} - -#[test] -fn test_h2_content_length() { - let rustls = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|req: Request| { - let indx: usize = req.uri().path()[1..].parse().unwrap(); - let statuses = [ - StatusCode::NO_CONTENT, - StatusCode::CONTINUE, - StatusCode::SWITCHING_PROTOCOLS, - StatusCode::PROCESSING, - StatusCode::OK, - StatusCode::NOT_FOUND, - ]; - future::ok::<_, ()>(Response::new(statuses[indx])) - }) - .map_err(|_| ()), - ) - }); - - let header = HeaderName::from_static("content-length"); - let value = HeaderValue::from_static("0"); - - { - for i in 0..4 { - let req = srv - .request(Method::GET, srv.surl(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), None); - - let req = srv - .request(Method::HEAD, srv.surl(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), None); - } - - for i in 4..6 { - let req = srv - .request(Method::GET, srv.surl(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), Some(&value)); - } - } -} - -#[test] -fn test_h2_headers() { - let data = STR.repeat(10); - let data2 = data.clone(); - let rustls = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - let data = data.clone(); - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build().h2(move |_| { - let mut config = Response::Ok(); - for idx in 0..90 { - config.header( - format!("X-TEST-{}", idx).as_str(), - "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", - ); - } - future::ok::<_, ()>(config.body(data.clone())) - }).map_err(|_| ())) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from(data2)); -} - -const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World"; - -#[test] -fn test_h2_body2() { - let rustls = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| future::ok::<_, ()>(Response::Ok().body(STR))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_h2_head_empty() { - let rustls = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.shead("/").send()).unwrap(); - assert!(response.status().is_success()); - assert_eq!(response.version(), Version::HTTP_2); - - { - let len = response - .headers() - .get(http::header::CONTENT_LENGTH) - .unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } - - // read response - let bytes = srv.load_body(response).unwrap(); - assert!(bytes.is_empty()); -} - -#[test] -fn test_h2_head_binary() { - let rustls = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| { - ok::<_, ()>( - Response::Ok().content_length(STR.len() as u64).body(STR), - ) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.shead("/").send()).unwrap(); - assert!(response.status().is_success()); - - { - let len = response - .headers() - .get(http::header::CONTENT_LENGTH) - .unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } - - // read response - let bytes = srv.load_body(response).unwrap(); - assert!(bytes.is_empty()); -} - -#[test] -fn test_h2_head_binary2() { - let rustls = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.shead("/").send()).unwrap(); - assert!(response.status().is_success()); - - { - let len = response - .headers() - .get(http::header::CONTENT_LENGTH) - .unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } -} - -#[test] -fn test_h2_body_length() { - let rustls = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| { - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() - .body(body::SizedStream::new(STR.len() as u64, body)), - ) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_h2_body_chunked_explicit() { - let rustls = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| { - let body = - once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() - .header(header::TRANSFER_ENCODING, "chunked") - .streaming(body), - ) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - assert!(!response.headers().contains_key(header::TRANSFER_ENCODING)); - - // read response - let bytes = srv.load_body(response).unwrap(); - - // decode - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_h2_response_http_error_handling() { - let rustls = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(new_service_cfg(|_: &ServerConfig| { - Ok::<_, ()>(|_| { - let broken_header = Bytes::from_static(b"\0\0\0"); - ok::<_, ()>( - Response::Ok() - .header(http::header::CONTENT_TYPE, broken_header) - .body(STR), - ) - }) - })) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); -} - -#[test] -fn test_h2_service_error() { - let rustls = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| Err::(error::ErrorBadRequest("error"))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"error")); -} diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs index a31e4ac89..c37e8fad1 100644 --- a/actix-http/tests/test_server.rs +++ b/actix-http/tests/test_server.rs @@ -2,14 +2,14 @@ use std::io::{Read, Write}; use std::time::Duration; use std::{net, thread}; -use actix_http_test::TestServer; +use actix_http_test::{block_on, TestServer}; use actix_server_config::ServerConfig; -use actix_service::{new_service_cfg, service_fn, NewService}; +use actix_service::{factory_fn_cfg, pipeline, service_fn, ServiceFactory}; use bytes::Bytes; -use futures::future::{self, ok, Future}; -use futures::stream::{once, Stream}; +use futures::future::{self, err, ok, ready, FutureExt}; +use futures::stream::{once, StreamExt}; use regex::Regex; -use tokio_timer::sleep; +use tokio_timer::delay_for; use actix_http::httpmessage::HttpMessage; use actix_http::{ @@ -18,346 +18,379 @@ use actix_http::{ #[test] fn test_h1() { - let mut srv = TestServer::new(|| { - HttpService::build() - .keep_alive(KeepAlive::Disabled) - .client_timeout(1000) - .client_disconnect(1000) - .h1(|req: Request| { - assert!(req.peer_addr().is_some()); - future::ok::<_, ()>(Response::Ok().finish()) - }) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build() + .keep_alive(KeepAlive::Disabled) + .client_timeout(1000) + .client_disconnect(1000) + .h1(|req: Request| { + assert!(req.peer_addr().is_some()); + future::ok::<_, ()>(Response::Ok().finish()) + }) + }); - let response = srv.block_on(srv.get("/").send()).unwrap(); - assert!(response.status().is_success()); + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); + }) } #[test] fn test_h1_2() { - let mut srv = TestServer::new(|| { - HttpService::build() - .keep_alive(KeepAlive::Disabled) - .client_timeout(1000) - .client_disconnect(1000) - .finish(|req: Request| { - assert!(req.peer_addr().is_some()); - assert_eq!(req.version(), http::Version::HTTP_11); - future::ok::<_, ()>(Response::Ok().finish()) - }) - .map(|_| ()) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build() + .keep_alive(KeepAlive::Disabled) + .client_timeout(1000) + .client_disconnect(1000) + .finish(|req: Request| { + assert!(req.peer_addr().is_some()); + assert_eq!(req.version(), http::Version::HTTP_11); + future::ok::<_, ()>(Response::Ok().finish()) + }) + .map(|_| ()) + }); - let response = srv.block_on(srv.get("/").send()).unwrap(); - assert!(response.status().is_success()); + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); + }) } #[test] fn test_expect_continue() { - let srv = TestServer::new(|| { - HttpService::build() - .expect(service_fn(|req: Request| { - if req.head().uri.query() == Some("yes=") { - Ok(req) - } else { - Err(error::ErrorPreconditionFailed("error")) - } - })) - .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build() + .expect(service_fn(|req: Request| { + if req.head().uri.query() == Some("yes=") { + ok(req) + } else { + err(error::ErrorPreconditionFailed("error")) + } + })) + .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream.write_all(b"GET /test HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); - let mut data = String::new(); - let _ = stream.read_to_string(&mut data); - assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length")); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length")); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); - let mut data = String::new(); - let _ = stream.read_to_string(&mut data); - assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = + stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); + }) } #[test] fn test_expect_continue_h1() { - let srv = TestServer::new(|| { - HttpService::build() - .expect(service_fn(|req: Request| { - sleep(Duration::from_millis(20)).then(move |_| { - if req.head().uri.query() == Some("yes=") { - Ok(req) - } else { - Err(error::ErrorPreconditionFailed("error")) - } - }) - })) - .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build() + .expect(service_fn(|req: Request| { + delay_for(Duration::from_millis(20)).then(move |_| { + if req.head().uri.query() == Some("yes=") { + ok(req) + } else { + err(error::ErrorPreconditionFailed("error")) + } + }) + })) + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream.write_all(b"GET /test HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); - let mut data = String::new(); - let _ = stream.read_to_string(&mut data); - assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length")); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length")); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); - let mut data = String::new(); - let _ = stream.read_to_string(&mut data); - assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = + stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); + }) } #[test] fn test_chunked_payload() { - let chunk_sizes = vec![32768, 32, 32768]; - let total_size: usize = chunk_sizes.iter().sum(); + block_on(async { + let chunk_sizes = vec![32768, 32, 32768]; + let total_size: usize = chunk_sizes.iter().sum(); - let srv = TestServer::new(|| { - HttpService::build().h1(|mut request: Request| { - request - .take_payload() - .map_err(|e| panic!(format!("Error reading payload: {}", e))) - .fold(0usize, |acc, chunk| future::ok::<_, ()>(acc + chunk.len())) - .map(|req_size| Response::Ok().body(format!("size={}", req_size))) - }) - }); + let srv = TestServer::start(|| { + HttpService::build().h1(service_fn(|mut request: Request| { + request + .take_payload() + .map(|res| match res { + Ok(pl) => pl, + Err(e) => panic!(format!("Error reading payload: {}", e)), + }) + .fold(0usize, |acc, chunk| ready(acc + chunk.len())) + .map(|req_size| { + Ok::<_, Error>(Response::Ok().body(format!("size={}", req_size))) + }) + })) + }); - let returned_size = { - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream - .write_all(b"POST /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"); + let returned_size = { + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream + .write_all(b"POST /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"); - for chunk_size in chunk_sizes.iter() { - let mut bytes = Vec::new(); - let random_bytes: Vec = - (0..*chunk_size).map(|_| rand::random::()).collect(); + for chunk_size in chunk_sizes.iter() { + let mut bytes = Vec::new(); + let random_bytes: Vec = + (0..*chunk_size).map(|_| rand::random::()).collect(); - bytes.extend(format!("{:X}\r\n", chunk_size).as_bytes()); - bytes.extend(&random_bytes[..]); - bytes.extend(b"\r\n"); - let _ = stream.write_all(&bytes); - } + bytes.extend(format!("{:X}\r\n", chunk_size).as_bytes()); + bytes.extend(&random_bytes[..]); + bytes.extend(b"\r\n"); + let _ = stream.write_all(&bytes); + } - let _ = stream.write_all(b"0\r\n\r\n"); - stream.shutdown(net::Shutdown::Write).unwrap(); + let _ = stream.write_all(b"0\r\n\r\n"); + stream.shutdown(net::Shutdown::Write).unwrap(); - let mut data = String::new(); - let _ = stream.read_to_string(&mut data); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); - let re = Regex::new(r"size=(\d+)").unwrap(); - let size: usize = match re.captures(&data) { - Some(caps) => caps.get(1).unwrap().as_str().parse().unwrap(), - None => panic!(format!("Failed to find size in HTTP Response: {}", data)), + let re = Regex::new(r"size=(\d+)").unwrap(); + let size: usize = match re.captures(&data) { + Some(caps) => caps.get(1).unwrap().as_str().parse().unwrap(), + None => { + panic!(format!("Failed to find size in HTTP Response: {}", data)) + } + }; + size }; - size - }; - assert_eq!(returned_size, total_size); + assert_eq!(returned_size, total_size); + }) } #[test] fn test_slow_request() { - let srv = TestServer::new(|| { - HttpService::build() - .client_timeout(100) - .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build() + .client_timeout(100) + .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n"); - let mut data = String::new(); - let _ = stream.read_to_string(&mut data); - assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); + }) } #[test] fn test_http1_malformed_request() { - let srv = TestServer::new(|| { - HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream.write_all(b"GET /test/tests/test HTTP1.1\r\n"); - let mut data = String::new(); - let _ = stream.read_to_string(&mut data); - assert!(data.starts_with("HTTP/1.1 400 Bad Request")); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP1.1\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 400 Bad Request")); + }) } #[test] fn test_http1_keepalive() { - let srv = TestServer::new(|| { - HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); - let mut data = vec![0; 1024]; - let _ = stream.read(&mut data); - assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); - let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); - let mut data = vec![0; 1024]; - let _ = stream.read(&mut data); - assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + }) } #[test] fn test_http1_keepalive_timeout() { - let srv = TestServer::new(|| { - HttpService::build() - .keep_alive(1) - .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build() + .keep_alive(1) + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); - let mut data = vec![0; 1024]; - let _ = stream.read(&mut data); - assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); - thread::sleep(Duration::from_millis(1100)); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + thread::sleep(Duration::from_millis(1100)); - let mut data = vec![0; 1024]; - let res = stream.read(&mut data).unwrap(); - assert_eq!(res, 0); + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); + }) } #[test] fn test_http1_keepalive_close() { - let srv = TestServer::new(|| { - HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = - stream.write_all(b"GET /test/tests/test HTTP/1.1\r\nconnection: close\r\n\r\n"); - let mut data = vec![0; 1024]; - let _ = stream.read(&mut data); - assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream + .write_all(b"GET /test/tests/test HTTP/1.1\r\nconnection: close\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); - let mut data = vec![0; 1024]; - let res = stream.read(&mut data).unwrap(); - assert_eq!(res, 0); + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); + }) } #[test] fn test_http10_keepalive_default_close() { - let srv = TestServer::new(|| { - HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream.write_all(b"GET /test/tests/test HTTP/1.0\r\n\r\n"); - let mut data = vec![0; 1024]; - let _ = stream.read(&mut data); - assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.0\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); - let mut data = vec![0; 1024]; - let res = stream.read(&mut data).unwrap(); - assert_eq!(res, 0); + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); + }) } #[test] fn test_http10_keepalive() { - let srv = TestServer::new(|| { - HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream - .write_all(b"GET /test/tests/test HTTP/1.0\r\nconnection: keep-alive\r\n\r\n"); - let mut data = vec![0; 1024]; - let _ = stream.read(&mut data); - assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all( + b"GET /test/tests/test HTTP/1.0\r\nconnection: keep-alive\r\n\r\n", + ); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream.write_all(b"GET /test/tests/test HTTP/1.0\r\n\r\n"); - let mut data = vec![0; 1024]; - let _ = stream.read(&mut data); - assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.0\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); - let mut data = vec![0; 1024]; - let res = stream.read(&mut data).unwrap(); - assert_eq!(res, 0); + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); + }) } #[test] fn test_http1_keepalive_disabled() { - let srv = TestServer::new(|| { - HttpService::build() - .keep_alive(KeepAlive::Disabled) - .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build() + .keep_alive(KeepAlive::Disabled) + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); - let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); - let mut data = vec![0; 1024]; - let _ = stream.read(&mut data); - assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); - let mut data = vec![0; 1024]; - let res = stream.read(&mut data).unwrap(); - assert_eq!(res, 0); + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); + }) } #[test] fn test_content_length() { - use actix_http::http::{ - header::{HeaderName, HeaderValue}, - StatusCode, - }; + block_on(async { + use actix_http::http::{ + header::{HeaderName, HeaderValue}, + StatusCode, + }; - let mut srv = TestServer::new(|| { - HttpService::build().h1(|req: Request| { - let indx: usize = req.uri().path()[1..].parse().unwrap(); - let statuses = [ - StatusCode::NO_CONTENT, - StatusCode::CONTINUE, - StatusCode::SWITCHING_PROTOCOLS, - StatusCode::PROCESSING, - StatusCode::OK, - StatusCode::NOT_FOUND, - ]; - future::ok::<_, ()>(Response::new(statuses[indx])) - }) - }); + let srv = TestServer::start(|| { + HttpService::build().h1(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + future::ok::<_, ()>(Response::new(statuses[indx])) + }) + }); - let header = HeaderName::from_static("content-length"); - let value = HeaderValue::from_static("0"); + let header = HeaderName::from_static("content-length"); + let value = HeaderValue::from_static("0"); - { - for i in 0..4 { - let req = srv - .request(http::Method::GET, srv.url(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), None); + { + for i in 0..4 { + let req = srv.request(http::Method::GET, srv.url(&format!("/{}", i))); + let response = req.send().await.unwrap(); + assert_eq!(response.headers().get(&header), None); - let req = srv - .request(http::Method::HEAD, srv.url(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), None); + let req = srv.request(http::Method::HEAD, srv.url(&format!("/{}", i))); + let response = req.send().await.unwrap(); + assert_eq!(response.headers().get(&header), None); + } + + for i in 4..6 { + let req = srv.request(http::Method::GET, srv.url(&format!("/{}", i))); + let response = req.send().await.unwrap(); + assert_eq!(response.headers().get(&header), Some(&value)); + } } - - for i in 4..6 { - let req = srv - .request(http::Method::GET, srv.url(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), Some(&value)); - } - } + }) } #[test] fn test_h1_headers() { - let data = STR.repeat(10); - let data2 = data.clone(); + block_on(async { + let data = STR.repeat(10); + let data2 = data.clone(); - let mut srv = TestServer::new(move || { - let data = data.clone(); - HttpService::build().h1(move |_| { + let mut srv = TestServer::start(move || { + let data = data.clone(); + HttpService::build().h1(move |_| { let mut builder = Response::Ok(); for idx in 0..90 { builder.header( @@ -379,14 +412,15 @@ fn test_h1_headers() { } future::ok::<_, ()>(builder.body(data.clone())) }) - }); + }); - let response = srv.block_on(srv.get("/").send()).unwrap(); - assert!(response.status().is_success()); + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from(data2)); + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from(data2)); + }) } const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ @@ -413,208 +447,228 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ #[test] fn test_h1_body() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().body(STR))) - }); + block_on(async { + let mut srv = TestServer::start(|| { + HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + }); - let response = srv.block_on(srv.get("/").send()).unwrap(); - assert!(response.status().is_success()); + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } #[test] fn test_h1_head_empty() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) - }); + block_on(async { + let mut srv = TestServer::start(|| { + HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + }); - let response = srv.block_on(srv.head("/").send()).unwrap(); - assert!(response.status().is_success()); + let response = srv.head("/").send().await.unwrap(); + assert!(response.status().is_success()); - { - let len = response - .headers() - .get(http::header::CONTENT_LENGTH) - .unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } - // read response - let bytes = srv.load_body(response).unwrap(); - assert!(bytes.is_empty()); + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); + }) } #[test] fn test_h1_head_binary() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| { - ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR)) - }) - }); + block_on(async { + let mut srv = TestServer::start(|| { + HttpService::build().h1(|_| { + ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR)) + }) + }); - let response = srv.block_on(srv.head("/").send()).unwrap(); - assert!(response.status().is_success()); + let response = srv.head("/").send().await.unwrap(); + assert!(response.status().is_success()); - { - let len = response - .headers() - .get(http::header::CONTENT_LENGTH) - .unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } - // read response - let bytes = srv.load_body(response).unwrap(); - assert!(bytes.is_empty()); + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); + }) } #[test] fn test_h1_head_binary2() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + }); - let response = srv.block_on(srv.head("/").send()).unwrap(); - assert!(response.status().is_success()); + let response = srv.head("/").send().await.unwrap(); + assert!(response.status().is_success()); - { - let len = response - .headers() - .get(http::header::CONTENT_LENGTH) - .unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + }) } #[test] fn test_h1_body_length() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| { - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), - ) - }) - }); + block_on(async { + let mut srv = TestServer::start(|| { + HttpService::build().h1(|_| { + let body = once(ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), + ) + }) + }); - let response = srv.block_on(srv.get("/").send()).unwrap(); - assert!(response.status().is_success()); + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } #[test] fn test_h1_body_chunked_explicit() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| { - let body = once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() - .header(header::TRANSFER_ENCODING, "chunked") - .streaming(body), - ) - }) - }); + block_on(async { + let mut srv = TestServer::start(|| { + HttpService::build().h1(|_| { + let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) + }); - let response = srv.block_on(srv.get("/").send()).unwrap(); - assert!(response.status().is_success()); - assert_eq!( - response - .headers() - .get(header::TRANSFER_ENCODING) - .unwrap() - .to_str() - .unwrap(), - "chunked" - ); + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!( + response + .headers() + .get(header::TRANSFER_ENCODING) + .unwrap() + .to_str() + .unwrap(), + "chunked" + ); - // read response - let bytes = srv.load_body(response).unwrap(); + // read response + let bytes = srv.load_body(response).await.unwrap(); - // decode - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // decode + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } #[test] fn test_h1_body_chunked_implicit() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| { - let body = once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>(Response::Ok().streaming(body)) - }) - }); + block_on(async { + let mut srv = TestServer::start(|| { + HttpService::build().h1(|_| { + let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>(Response::Ok().streaming(body)) + }) + }); - let response = srv.block_on(srv.get("/").send()).unwrap(); - assert!(response.status().is_success()); - assert_eq!( - response - .headers() - .get(header::TRANSFER_ENCODING) - .unwrap() - .to_str() - .unwrap(), - "chunked" - ); + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!( + response + .headers() + .get(header::TRANSFER_ENCODING) + .unwrap() + .to_str() + .unwrap(), + "chunked" + ); - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } #[test] fn test_h1_response_http_error_handling() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(new_service_cfg(|_: &ServerConfig| { - Ok::<_, ()>(|_| { - let broken_header = Bytes::from_static(b"\0\0\0"); - ok::<_, ()>( - Response::Ok() - .header(http::header::CONTENT_TYPE, broken_header) - .body(STR), - ) - }) - })) - }); + block_on(async { + let mut srv = TestServer::start(|| { + HttpService::build().h1(factory_fn_cfg(|_: &ServerConfig| { + ok::<_, ()>(pipeline(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, ()>( + Response::Ok() + .header(http::header::CONTENT_TYPE, broken_header) + .body(STR), + ) + })) + })) + }); - let response = srv.block_on(srv.get("/").send()).unwrap(); - assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + let response = srv.get("/").send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); + }) } #[test] fn test_h1_service_error() { - let mut srv = TestServer::new(|| { - HttpService::build() - .h1(|_| Err::(error::ErrorBadRequest("error"))) - }); + block_on(async { + let mut srv = TestServer::start(|| { + HttpService::build() + .h1(|_| future::err::(error::ErrorBadRequest("error"))) + }); - let response = srv.block_on(srv.get("/").send()).unwrap(); - assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); + let response = srv.get("/").send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"error")); + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"error")); + }) } #[test] fn test_h1_on_connect() { - let mut srv = TestServer::new(|| { - HttpService::build() - .on_connect(|_| 10usize) - .h1(|req: Request| { - assert!(req.extensions().contains::()); - future::ok::<_, ()>(Response::Ok().finish()) - }) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::build() + .on_connect(|_| 10usize) + .h1(|req: Request| { + assert!(req.extensions().contains::()); + future::ok::<_, ()>(Response::Ok().finish()) + }) + }); - let response = srv.block_on(srv.get("/").send()).unwrap(); - assert!(response.status().is_success()); + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); + }) } diff --git a/actix-http/tests/test_ssl_server.rs b/actix-http/tests/test_ssl_server.rs deleted file mode 100644 index 897d92b37..000000000 --- a/actix-http/tests/test_ssl_server.rs +++ /dev/null @@ -1,480 +0,0 @@ -#![cfg(feature = "ssl")] -use actix_codec::{AsyncRead, AsyncWrite}; -use actix_http_test::TestServer; -use actix_server::ssl::OpensslAcceptor; -use actix_server_config::ServerConfig; -use actix_service::{new_service_cfg, NewService}; - -use bytes::{Bytes, BytesMut}; -use futures::future::{ok, Future}; -use futures::stream::{once, Stream}; -use openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod}; -use std::io::Result; - -use actix_http::error::{ErrorBadRequest, PayloadError}; -use actix_http::http::header::{self, HeaderName, HeaderValue}; -use actix_http::http::{Method, StatusCode, Version}; -use actix_http::httpmessage::HttpMessage; -use actix_http::{body, Error, HttpService, Request, Response}; - -fn load_body(stream: S) -> impl Future -where - S: Stream, -{ - stream.fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, PayloadError>(body) - }) -} - -fn ssl_acceptor() -> Result> { - // load ssl keys - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder - .set_private_key_file("../tests/key.pem", SslFiletype::PEM) - .unwrap(); - builder - .set_certificate_chain_file("../tests/cert.pem") - .unwrap(); - builder.set_alpn_select_callback(|_, protos| { - const H2: &[u8] = b"\x02h2"; - if protos.windows(3).any(|window| window == H2) { - Ok(b"h2") - } else { - Err(AlpnError::NOACK) - } - }); - builder.set_alpn_protos(b"\x02h2")?; - Ok(OpensslAcceptor::new(builder.build())) -} - -#[test] -fn test_h2() -> Result<()> { - let openssl = ssl_acceptor()?; - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| ok::<_, Error>(Response::Ok().finish())) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - Ok(()) -} - -#[test] -fn test_h2_1() -> Result<()> { - let openssl = ssl_acceptor()?; - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .finish(|req: Request| { - assert!(req.peer_addr().is_some()); - assert_eq!(req.version(), Version::HTTP_2); - ok::<_, Error>(Response::Ok().finish()) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - Ok(()) -} - -#[test] -fn test_h2_body() -> Result<()> { - let data = "HELLOWORLD".to_owned().repeat(64 * 1024); - let openssl = ssl_acceptor()?; - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|mut req: Request<_>| { - load_body(req.take_payload()) - .and_then(|body| Ok(Response::Ok().body(body))) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send_body(data.clone())).unwrap(); - assert!(response.status().is_success()); - - let body = srv.load_body(response).unwrap(); - assert_eq!(&body, data.as_bytes()); - Ok(()) -} - -#[test] -fn test_h2_content_length() { - let openssl = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|req: Request| { - let indx: usize = req.uri().path()[1..].parse().unwrap(); - let statuses = [ - StatusCode::NO_CONTENT, - StatusCode::CONTINUE, - StatusCode::SWITCHING_PROTOCOLS, - StatusCode::PROCESSING, - StatusCode::OK, - StatusCode::NOT_FOUND, - ]; - ok::<_, ()>(Response::new(statuses[indx])) - }) - .map_err(|_| ()), - ) - }); - - let header = HeaderName::from_static("content-length"); - let value = HeaderValue::from_static("0"); - - { - for i in 0..4 { - let req = srv - .request(Method::GET, srv.surl(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), None); - - let req = srv - .request(Method::HEAD, srv.surl(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), None); - } - - for i in 4..6 { - let req = srv - .request(Method::GET, srv.surl(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), Some(&value)); - } - } -} - -#[test] -fn test_h2_headers() { - let data = STR.repeat(10); - let data2 = data.clone(); - let openssl = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - let data = data.clone(); - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build().h2(move |_| { - let mut builder = Response::Ok(); - for idx in 0..90 { - builder.header( - format!("X-TEST-{}", idx).as_str(), - "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", - ); - } - ok::<_, ()>(builder.body(data.clone())) - }).map_err(|_| ())) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from(data2)); -} - -const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World"; - -#[test] -fn test_h2_body2() { - let openssl = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_h2_head_empty() { - let openssl = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.shead("/").send()).unwrap(); - assert!(response.status().is_success()); - assert_eq!(response.version(), Version::HTTP_2); - - { - let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } - - // read response - let bytes = srv.load_body(response).unwrap(); - assert!(bytes.is_empty()); -} - -#[test] -fn test_h2_head_binary() { - let openssl = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| { - ok::<_, ()>( - Response::Ok().content_length(STR.len() as u64).body(STR), - ) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.shead("/").send()).unwrap(); - assert!(response.status().is_success()); - - { - let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } - - // read response - let bytes = srv.load_body(response).unwrap(); - assert!(bytes.is_empty()); -} - -#[test] -fn test_h2_head_binary2() { - let openssl = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.shead("/").send()).unwrap(); - assert!(response.status().is_success()); - - { - let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } -} - -#[test] -fn test_h2_body_length() { - let openssl = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| { - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() - .body(body::SizedStream::new(STR.len() as u64, body)), - ) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_h2_body_chunked_explicit() { - let openssl = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| { - let body = - once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() - .header(header::TRANSFER_ENCODING, "chunked") - .streaming(body), - ) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - assert!(!response.headers().contains_key(header::TRANSFER_ENCODING)); - - // read response - let bytes = srv.load_body(response).unwrap(); - - // decode - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_h2_response_http_error_handling() { - let openssl = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(new_service_cfg(|_: &ServerConfig| { - Ok::<_, ()>(|_| { - let broken_header = Bytes::from_static(b"\0\0\0"); - ok::<_, ()>( - Response::Ok() - .header(header::CONTENT_TYPE, broken_header) - .body(STR), - ) - }) - })) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); -} - -#[test] -fn test_h2_service_error() { - let openssl = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| Err::(ErrorBadRequest("error"))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"error")); -} - -#[test] -fn test_h2_on_connect() { - let openssl = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .on_connect(|_| 10usize) - .h2(|req: Request| { - assert!(req.extensions().contains::()); - ok::<_, ()>(Response::Ok().finish()) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); -} diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs index 65a4d094d..74c1cb405 100644 --- a/actix-http/tests/test_ws.rs +++ b/actix-http/tests/test_ws.rs @@ -1,26 +1,27 @@ use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::{body, h1, ws, Error, HttpService, Request, Response}; -use actix_http_test::TestServer; +use actix_http_test::{block_on, TestServer}; use actix_utils::framed::FramedTransport; use bytes::{Bytes, BytesMut}; -use futures::future::{self, ok}; -use futures::{Future, Sink, Stream}; +use futures::future; +use futures::{SinkExt, StreamExt}; -fn ws_service( - (req, framed): (Request, Framed), -) -> impl Future { +async fn ws_service( + (req, mut framed): (Request, Framed), +) -> Result<(), Error> { let res = ws::handshake(req.head()).unwrap().message_body(()); framed .send((res, body::BodySize::None).into()) + .await + .unwrap(); + + FramedTransport::new(framed.into_framed(ws::Codec::new()), service) + .await .map_err(|_| panic!()) - .and_then(|framed| { - FramedTransport::new(framed.into_framed(ws::Codec::new()), service) - .map_err(|_| panic!()) - }) } -fn service(msg: ws::Frame) -> impl Future { +async fn service(msg: ws::Frame) -> Result { let msg = match msg { ws::Frame::Ping(msg) => ws::Message::Pong(msg), ws::Frame::Text(text) => { @@ -30,47 +31,56 @@ fn service(msg: ws::Frame) -> impl Future { ws::Frame::Close(reason) => ws::Message::Close(reason), _ => panic!(), }; - ok(msg) + Ok(msg) } #[test] fn test_simple() { - let mut srv = TestServer::new(|| { - HttpService::build() - .upgrade(ws_service) - .finish(|_| future::ok::<_, ()>(Response::NotFound())) - }); + block_on(async { + let mut srv = TestServer::start(|| { + HttpService::build() + .upgrade(actix_service::service_fn(ws_service)) + .finish(|_| future::ok::<_, ()>(Response::NotFound())) + }); - // client service - let framed = srv.ws().unwrap(); - let framed = srv - .block_on(framed.send(ws::Message::Text("text".to_string()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); + // client service + let mut framed = srv.ws().await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Text(Some(BytesMut::from("text"))) + ); - let framed = srv - .block_on(framed.send(ws::Message::Binary("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!( - item, - Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) - ); + framed + .send(ws::Message::Binary("text".into())) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Binary(Some(Bytes::from_static(b"text").into())) + ); - let framed = srv - .block_on(framed.send(ws::Message::Ping("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Pong("text".to_string().into()) + ); - let framed = srv - .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) - .unwrap(); + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); - let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!( - item, - Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) - ); + let (item, _framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Close(Some(ws::CloseCode::Normal.into())) + ); + }) } diff --git a/actix-identity/Cargo.toml b/actix-identity/Cargo.toml index e645275a2..d05b37685 100644 --- a/actix-identity/Cargo.toml +++ b/actix-identity/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-identity" -version = "0.1.0" +version = "0.2.0-alpha.1" authors = ["Nikolay Kim "] description = "Identity service for actix web framework." readme = "README.md" @@ -17,14 +17,14 @@ name = "actix_identity" path = "src/lib.rs" [dependencies] -actix-web = { version = "1.0.0", default-features = false, features = ["secure-cookies"] } -actix-service = "0.4.0" -futures = "0.1.25" +actix-web = { version = "2.0.0-alpha.1", default-features = false, features = ["secure-cookies"] } +actix-service = "1.0.0-alpha.1" +futures = "0.3.1" serde = "1.0" serde_json = "1.0" time = "0.1.42" [dev-dependencies] -actix-rt = "0.2.2" -actix-http = "0.2.3" +actix-rt = "1.0.0-alpha.1" +actix-http = "0.3.0-alpha.1" bytes = "0.4" \ No newline at end of file diff --git a/actix-identity/src/lib.rs b/actix-identity/src/lib.rs index 7216104eb..2980a7753 100644 --- a/actix-identity/src/lib.rs +++ b/actix-identity/src/lib.rs @@ -16,7 +16,7 @@ //! use actix_web::*; //! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService}; //! -//! fn index(id: Identity) -> String { +//! async fn index(id: Identity) -> String { //! // access request identity //! if let Some(id) = id.identity() { //! format!("Welcome! {}", id) @@ -25,12 +25,12 @@ //! } //! } //! -//! fn login(id: Identity) -> HttpResponse { +//! async fn login(id: Identity) -> HttpResponse { //! id.remember("User1".to_owned()); // <- remember identity //! HttpResponse::Ok().finish() //! } //! -//! fn logout(id: Identity) -> HttpResponse { +//! async fn logout(id: Identity) -> HttpResponse { //! id.forget(); // <- remove identity //! HttpResponse::Ok().finish() //! } @@ -47,12 +47,13 @@ //! } //! ``` use std::cell::RefCell; +use std::future::Future; use std::rc::Rc; +use std::task::{Context, Poll}; use std::time::SystemTime; use actix_service::{Service, Transform}; -use futures::future::{ok, Either, FutureResult}; -use futures::{Future, IntoFuture, Poll}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use serde::{Deserialize, Serialize}; use time::Duration; @@ -165,21 +166,21 @@ where impl FromRequest for Identity { type Config = (); type Error = Error; - type Future = Result; + type Future = Ready>; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { - Ok(Identity(req.clone())) + ok(Identity(req.clone())) } } /// Identity policy definition. pub trait IdentityPolicy: Sized + 'static { /// The return type of the middleware - type Future: IntoFuture, Error = Error>; + type Future: Future, Error>>; /// The return type of the middleware - type ResponseFuture: IntoFuture; + type ResponseFuture: Future>; /// Parse the session from request and load data from a service identity. fn from_request(&self, request: &mut ServiceRequest) -> Self::Future; @@ -234,7 +235,7 @@ where type Error = Error; type InitError = (); type Transform = IdentityServiceMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(IdentityServiceMiddleware { @@ -261,46 +262,39 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.borrow_mut().poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.borrow_mut().poll_ready(cx) } fn call(&mut self, mut req: ServiceRequest) -> Self::Future { let srv = self.service.clone(); let backend = self.backend.clone(); + let fut = self.backend.from_request(&mut req); - Box::new( - self.backend.from_request(&mut req).into_future().then( - move |res| match res { - Ok(id) => { - req.extensions_mut() - .insert(IdentityItem { id, changed: false }); + async move { + match fut.await { + Ok(id) => { + req.extensions_mut() + .insert(IdentityItem { id, changed: false }); - Either::A(srv.borrow_mut().call(req).and_then(move |mut res| { - let id = - res.request().extensions_mut().remove::(); + let mut res = srv.borrow_mut().call(req).await?; + let id = res.request().extensions_mut().remove::(); - if let Some(id) = id { - Either::A( - backend - .to_response(id.id, id.changed, &mut res) - .into_future() - .then(move |t| match t { - Ok(_) => Ok(res), - Err(e) => Ok(res.error_response(e)), - }), - ) - } else { - Either::B(ok(res)) - } - })) + if let Some(id) = id { + match backend.to_response(id.id, id.changed, &mut res).await { + Ok(_) => Ok(res), + Err(e) => Ok(res.error_response(e)), + } + } else { + Ok(res) } - Err(err) => Either::B(ok(req.error_response(err))), - }, - ), - ) + } + Err(err) => Ok(req.error_response(err)), + } + } + .boxed_local() } } @@ -547,11 +541,11 @@ impl CookieIdentityPolicy { } impl IdentityPolicy for CookieIdentityPolicy { - type Future = Result, Error>; - type ResponseFuture = Result<(), Error>; + type Future = Ready, Error>>; + type ResponseFuture = Ready>; fn from_request(&self, req: &mut ServiceRequest) -> Self::Future { - Ok(self.0.load(req).map( + ok(self.0.load(req).map( |CookieValue { identity, login_timestamp, @@ -603,7 +597,7 @@ impl IdentityPolicy for CookieIdentityPolicy { } else { Ok(()) }; - Ok(()) + ok(()) } } @@ -613,7 +607,7 @@ mod tests { use super::*; use actix_web::http::StatusCode; - use actix_web::test::{self, TestRequest}; + use actix_web::test::{self, block_on, TestRequest}; use actix_web::{web, App, Error, HttpResponse}; const COOKIE_KEY_MASTER: [u8; 32] = [0; 32]; @@ -622,115 +616,138 @@ mod tests { #[test] fn test_identity() { - let mut srv = test::init_service( - App::new() - .wrap(IdentityService::new( - CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) - .domain("www.rust-lang.org") - .name(COOKIE_NAME) - .path("/") - .secure(true), - )) - .service(web::resource("/index").to(|id: Identity| { - if id.identity().is_some() { - HttpResponse::Created() - } else { + block_on(async { + let mut srv = test::init_service( + App::new() + .wrap(IdentityService::new( + CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) + .domain("www.rust-lang.org") + .name(COOKIE_NAME) + .path("/") + .secure(true), + )) + .service(web::resource("/index").to(|id: Identity| { + if id.identity().is_some() { + HttpResponse::Created() + } else { + HttpResponse::Ok() + } + })) + .service(web::resource("/login").to(|id: Identity| { + id.remember(COOKIE_LOGIN.to_string()); HttpResponse::Ok() - } - })) - .service(web::resource("/login").to(|id: Identity| { - id.remember(COOKIE_LOGIN.to_string()); - HttpResponse::Ok() - })) - .service(web::resource("/logout").to(|id: Identity| { - if id.identity().is_some() { - id.forget(); - HttpResponse::Ok() - } else { - HttpResponse::BadRequest() - } - })), - ); - let resp = - test::call_service(&mut srv, TestRequest::with_uri("/index").to_request()); - assert_eq!(resp.status(), StatusCode::OK); + })) + .service(web::resource("/logout").to(|id: Identity| { + if id.identity().is_some() { + id.forget(); + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + })), + ) + .await; + let resp = test::call_service( + &mut srv, + TestRequest::with_uri("/index").to_request(), + ) + .await; + assert_eq!(resp.status(), StatusCode::OK); - let resp = - test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()); - assert_eq!(resp.status(), StatusCode::OK); - let c = resp.response().cookies().next().unwrap().to_owned(); + let resp = test::call_service( + &mut srv, + TestRequest::with_uri("/login").to_request(), + ) + .await; + assert_eq!(resp.status(), StatusCode::OK); + let c = resp.response().cookies().next().unwrap().to_owned(); - let resp = test::call_service( - &mut srv, - TestRequest::with_uri("/index") - .cookie(c.clone()) - .to_request(), - ); - assert_eq!(resp.status(), StatusCode::CREATED); + let resp = test::call_service( + &mut srv, + TestRequest::with_uri("/index") + .cookie(c.clone()) + .to_request(), + ) + .await; + assert_eq!(resp.status(), StatusCode::CREATED); - let resp = test::call_service( - &mut srv, - TestRequest::with_uri("/logout") - .cookie(c.clone()) - .to_request(), - ); - assert_eq!(resp.status(), StatusCode::OK); - assert!(resp.headers().contains_key(header::SET_COOKIE)) + let resp = test::call_service( + &mut srv, + TestRequest::with_uri("/logout") + .cookie(c.clone()) + .to_request(), + ) + .await; + assert_eq!(resp.status(), StatusCode::OK); + assert!(resp.headers().contains_key(header::SET_COOKIE)) + }) } #[test] fn test_identity_max_age_time() { - let duration = Duration::days(1); - let mut srv = test::init_service( - App::new() - .wrap(IdentityService::new( - CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) - .domain("www.rust-lang.org") - .name(COOKIE_NAME) - .path("/") - .max_age_time(duration) - .secure(true), - )) - .service(web::resource("/login").to(|id: Identity| { - id.remember("test".to_string()); - HttpResponse::Ok() - })), - ); - let resp = - test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()); - assert_eq!(resp.status(), StatusCode::OK); - assert!(resp.headers().contains_key(header::SET_COOKIE)); - let c = resp.response().cookies().next().unwrap().to_owned(); - assert_eq!(duration, c.max_age().unwrap()); + block_on(async { + let duration = Duration::days(1); + let mut srv = test::init_service( + App::new() + .wrap(IdentityService::new( + CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) + .domain("www.rust-lang.org") + .name(COOKIE_NAME) + .path("/") + .max_age_time(duration) + .secure(true), + )) + .service(web::resource("/login").to(|id: Identity| { + id.remember("test".to_string()); + HttpResponse::Ok() + })), + ) + .await; + let resp = test::call_service( + &mut srv, + TestRequest::with_uri("/login").to_request(), + ) + .await; + assert_eq!(resp.status(), StatusCode::OK); + assert!(resp.headers().contains_key(header::SET_COOKIE)); + let c = resp.response().cookies().next().unwrap().to_owned(); + assert_eq!(duration, c.max_age().unwrap()); + }) } #[test] fn test_identity_max_age() { - let seconds = 60; - let mut srv = test::init_service( - App::new() - .wrap(IdentityService::new( - CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) - .domain("www.rust-lang.org") - .name(COOKIE_NAME) - .path("/") - .max_age(seconds) - .secure(true), - )) - .service(web::resource("/login").to(|id: Identity| { - id.remember("test".to_string()); - HttpResponse::Ok() - })), - ); - let resp = - test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()); - assert_eq!(resp.status(), StatusCode::OK); - assert!(resp.headers().contains_key(header::SET_COOKIE)); - let c = resp.response().cookies().next().unwrap().to_owned(); - assert_eq!(Duration::seconds(seconds as i64), c.max_age().unwrap()); + block_on(async { + let seconds = 60; + let mut srv = test::init_service( + App::new() + .wrap(IdentityService::new( + CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) + .domain("www.rust-lang.org") + .name(COOKIE_NAME) + .path("/") + .max_age(seconds) + .secure(true), + )) + .service(web::resource("/login").to(|id: Identity| { + id.remember("test".to_string()); + HttpResponse::Ok() + })), + ) + .await; + let resp = test::call_service( + &mut srv, + TestRequest::with_uri("/login").to_request(), + ) + .await; + assert_eq!(resp.status(), StatusCode::OK); + assert!(resp.headers().contains_key(header::SET_COOKIE)); + let c = resp.response().cookies().next().unwrap().to_owned(); + assert_eq!(Duration::seconds(seconds as i64), c.max_age().unwrap()); + }) } - fn create_identity_server< + async fn create_identity_server< F: Fn(CookieIdentityPolicy) -> CookieIdentityPolicy + Sync + Send + Clone + 'static, >( f: F, @@ -747,13 +764,16 @@ mod tests { .secure(false) .name(COOKIE_NAME)))) .service(web::resource("/").to(|id: Identity| { - let identity = id.identity(); - if identity.is_none() { - id.remember(COOKIE_LOGIN.to_string()) + async move { + let identity = id.identity(); + if identity.is_none() { + id.remember(COOKIE_LOGIN.to_string()) + } + web::Json(identity) } - web::Json(identity) })), ) + .await } fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> { @@ -786,15 +806,8 @@ mod tests { jar.get(COOKIE_NAME).unwrap().clone() } - fn assert_logged_in(response: &mut ServiceResponse, identity: Option<&str>) { - use bytes::BytesMut; - use futures::Stream; - let bytes = - test::block_on(response.take_body().fold(BytesMut::new(), |mut b, c| { - b.extend(c); - Ok::<_, Error>(b) - })) - .unwrap(); + async fn assert_logged_in(response: ServiceResponse, identity: Option<&str>) { + let bytes = test::read_body(response).await; let resp: Option = serde_json::from_slice(&bytes[..]).unwrap(); assert_eq!(resp.as_ref().map(|s| s.borrow()), identity); } @@ -874,183 +887,221 @@ mod tests { #[test] fn test_identity_legacy_cookie_is_set() { - let mut srv = create_identity_server(|c| c); - let mut resp = - test::call_service(&mut srv, TestRequest::with_uri("/").to_request()); - assert_logged_in(&mut resp, None); - assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN); + block_on(async { + let mut srv = create_identity_server(|c| c).await; + let mut resp = + test::call_service(&mut srv, TestRequest::with_uri("/").to_request()) + .await; + assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN); + assert_logged_in(resp, None).await; + }) } #[test] fn test_identity_legacy_cookie_works() { - let mut srv = create_identity_server(|c| c); - let cookie = legacy_login_cookie(COOKIE_LOGIN); - let mut resp = test::call_service( - &mut srv, - TestRequest::with_uri("/") - .cookie(cookie.clone()) - .to_request(), - ); - assert_logged_in(&mut resp, Some(COOKIE_LOGIN)); - assert_no_login_cookie(&mut resp); + block_on(async { + let mut srv = create_identity_server(|c| c).await; + let cookie = legacy_login_cookie(COOKIE_LOGIN); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_no_login_cookie(&mut resp); + assert_logged_in(resp, Some(COOKIE_LOGIN)).await; + }) } #[test] fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() { - let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90))); - let cookie = legacy_login_cookie(COOKIE_LOGIN); - let mut resp = test::call_service( - &mut srv, - TestRequest::with_uri("/") - .cookie(cookie.clone()) - .to_request(), - ); - assert_logged_in(&mut resp, None); - assert_login_cookie( - &mut resp, - COOKIE_LOGIN, - LoginTimestampCheck::NoTimestamp, - VisitTimeStampCheck::NewTimestamp, - ); + block_on(async { + let mut srv = + create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; + let cookie = legacy_login_cookie(COOKIE_LOGIN); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NoTimestamp, + VisitTimeStampCheck::NewTimestamp, + ); + assert_logged_in(resp, None).await; + }) } #[test] fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() { - let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90))); - let cookie = legacy_login_cookie(COOKIE_LOGIN); - let mut resp = test::call_service( - &mut srv, - TestRequest::with_uri("/") - .cookie(cookie.clone()) - .to_request(), - ); - assert_logged_in(&mut resp, None); - assert_login_cookie( - &mut resp, - COOKIE_LOGIN, - LoginTimestampCheck::NewTimestamp, - VisitTimeStampCheck::NoTimestamp, - ); + block_on(async { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; + let cookie = legacy_login_cookie(COOKIE_LOGIN); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NewTimestamp, + VisitTimeStampCheck::NoTimestamp, + ); + assert_logged_in(resp, None).await; + }) } #[test] fn test_identity_cookie_rejected_if_login_timestamp_needed() { - let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90))); - let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now())); - let mut resp = test::call_service( - &mut srv, - TestRequest::with_uri("/") - .cookie(cookie.clone()) - .to_request(), - ); - assert_logged_in(&mut resp, None); - assert_login_cookie( - &mut resp, - COOKIE_LOGIN, - LoginTimestampCheck::NewTimestamp, - VisitTimeStampCheck::NoTimestamp, - ); + block_on(async { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; + let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now())); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NewTimestamp, + VisitTimeStampCheck::NoTimestamp, + ); + assert_logged_in(resp, None).await; + }) } #[test] fn test_identity_cookie_rejected_if_visit_timestamp_needed() { - let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90))); - let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); - let mut resp = test::call_service( - &mut srv, - TestRequest::with_uri("/") - .cookie(cookie.clone()) - .to_request(), - ); - assert_logged_in(&mut resp, None); - assert_login_cookie( - &mut resp, - COOKIE_LOGIN, - LoginTimestampCheck::NoTimestamp, - VisitTimeStampCheck::NewTimestamp, - ); + block_on(async { + let mut srv = + create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; + let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NoTimestamp, + VisitTimeStampCheck::NewTimestamp, + ); + assert_logged_in(resp, None).await; + }) } #[test] fn test_identity_cookie_rejected_if_login_timestamp_too_old() { - let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90))); - let cookie = login_cookie( - COOKIE_LOGIN, - Some(SystemTime::now() - Duration::days(180).to_std().unwrap()), - None, - ); - let mut resp = test::call_service( - &mut srv, - TestRequest::with_uri("/") - .cookie(cookie.clone()) - .to_request(), - ); - assert_logged_in(&mut resp, None); - assert_login_cookie( - &mut resp, - COOKIE_LOGIN, - LoginTimestampCheck::NewTimestamp, - VisitTimeStampCheck::NoTimestamp, - ); + block_on(async { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; + let cookie = login_cookie( + COOKIE_LOGIN, + Some(SystemTime::now() - Duration::days(180).to_std().unwrap()), + None, + ); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NewTimestamp, + VisitTimeStampCheck::NoTimestamp, + ); + assert_logged_in(resp, None).await; + }) } #[test] fn test_identity_cookie_rejected_if_visit_timestamp_too_old() { - let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90))); - let cookie = login_cookie( - COOKIE_LOGIN, - None, - Some(SystemTime::now() - Duration::days(180).to_std().unwrap()), - ); - let mut resp = test::call_service( - &mut srv, - TestRequest::with_uri("/") - .cookie(cookie.clone()) - .to_request(), - ); - assert_logged_in(&mut resp, None); - assert_login_cookie( - &mut resp, - COOKIE_LOGIN, - LoginTimestampCheck::NoTimestamp, - VisitTimeStampCheck::NewTimestamp, - ); + block_on(async { + let mut srv = + create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; + let cookie = login_cookie( + COOKIE_LOGIN, + None, + Some(SystemTime::now() - Duration::days(180).to_std().unwrap()), + ); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NoTimestamp, + VisitTimeStampCheck::NewTimestamp, + ); + assert_logged_in(resp, None).await; + }) } #[test] fn test_identity_cookie_not_updated_on_login_deadline() { - let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90))); - let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); - let mut resp = test::call_service( - &mut srv, - TestRequest::with_uri("/") - .cookie(cookie.clone()) - .to_request(), - ); - assert_logged_in(&mut resp, Some(COOKIE_LOGIN)); - assert_no_login_cookie(&mut resp); + block_on(async { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; + let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_no_login_cookie(&mut resp); + assert_logged_in(resp, Some(COOKIE_LOGIN)).await; + }) } #[test] fn test_identity_cookie_updated_on_visit_deadline() { - let mut srv = create_identity_server(|c| { - c.visit_deadline(Duration::days(90)) - .login_deadline(Duration::days(90)) - }); - let timestamp = SystemTime::now() - Duration::days(1).to_std().unwrap(); - let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp)); - let mut resp = test::call_service( - &mut srv, - TestRequest::with_uri("/") - .cookie(cookie.clone()) - .to_request(), - ); - assert_logged_in(&mut resp, Some(COOKIE_LOGIN)); - assert_login_cookie( - &mut resp, - COOKIE_LOGIN, - LoginTimestampCheck::OldTimestamp(timestamp), - VisitTimeStampCheck::NewTimestamp, - ); + block_on(async { + let mut srv = create_identity_server(|c| { + c.visit_deadline(Duration::days(90)) + .login_deadline(Duration::days(90)) + }) + .await; + let timestamp = SystemTime::now() - Duration::days(1).to_std().unwrap(); + let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp)); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::OldTimestamp(timestamp), + VisitTimeStampCheck::NewTimestamp, + ); + assert_logged_in(resp, Some(COOKIE_LOGIN)).await; + }) } } diff --git a/actix-multipart/Cargo.toml b/actix-multipart/Cargo.toml index 2168c259a..f5cdc8afd 100644 --- a/actix-multipart/Cargo.toml +++ b/actix-multipart/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-multipart" -version = "0.1.4" +version = "0.2.0-alpha.1" authors = ["Nikolay Kim "] description = "Multipart support for actix web framework." readme = "README.md" @@ -18,17 +18,18 @@ name = "actix_multipart" path = "src/lib.rs" [dependencies] -actix-web = { version = "1.0.0", default-features = false } -actix-service = "0.4.1" +actix-web = { version = "2.0.0-alpha.1", default-features = false } +actix-service = "1.0.0-alpha.1" +actix-utils = "0.5.0-alpha.1" bytes = "0.4" derive_more = "0.15.0" httparse = "1.3" -futures = "0.1.25" +futures = "0.3.1" log = "0.4" mime = "0.3" time = "0.1" twoway = "0.2" [dev-dependencies] -actix-rt = "0.2.2" -actix-http = "0.2.4" \ No newline at end of file +actix-rt = "1.0.0-alpha.1" +actix-http = "0.3.0-alpha.1" \ No newline at end of file diff --git a/actix-multipart/src/extractor.rs b/actix-multipart/src/extractor.rs index 7274ed092..71c815227 100644 --- a/actix-multipart/src/extractor.rs +++ b/actix-multipart/src/extractor.rs @@ -1,5 +1,6 @@ //! Multipart payload support use actix_web::{dev::Payload, Error, FromRequest, HttpRequest}; +use futures::future::{ok, Ready}; use crate::server::Multipart; @@ -10,33 +11,31 @@ use crate::server::Multipart; /// ## Server example /// /// ```rust -/// # use futures::{Future, Stream}; -/// # use futures::future::{ok, result, Either}; +/// use futures::{Stream, StreamExt}; /// use actix_web::{web, HttpResponse, Error}; /// use actix_multipart as mp; /// -/// fn index(payload: mp::Multipart) -> impl Future { -/// payload.from_err() // <- get multipart stream for current request -/// .and_then(|field| { // <- iterate over multipart items +/// async fn index(mut payload: mp::Multipart) -> Result { +/// // iterate over multipart stream +/// while let Some(item) = payload.next().await { +/// let mut field = item?; +/// /// // Field in turn is stream of *Bytes* object -/// field.from_err() -/// .fold((), |_, chunk| { -/// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk)); -/// Ok::<_, Error>(()) -/// }) -/// }) -/// .fold((), |_, _| Ok::<_, Error>(())) -/// .map(|_| HttpResponse::Ok().into()) +/// while let Some(chunk) = field.next().await { +/// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk?)); +/// } +/// } +/// Ok(HttpResponse::Ok().into()) /// } /// # fn main() {} /// ``` impl FromRequest for Multipart { type Error = Error; - type Future = Result; + type Future = Ready>; type Config = (); #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - Ok(Multipart::new(req.headers(), payload.take())) + ok(Multipart::new(req.headers(), payload.take())) } } diff --git a/actix-multipart/src/server.rs b/actix-multipart/src/server.rs index a7c787f46..dd7852c8e 100644 --- a/actix-multipart/src/server.rs +++ b/actix-multipart/src/server.rs @@ -1,15 +1,17 @@ //! Multipart payload support use std::cell::{Cell, RefCell, RefMut}; use std::marker::PhantomData; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use std::{cmp, fmt}; use bytes::{Bytes, BytesMut}; -use futures::task::{current as current_task, Task}; -use futures::{Async, Poll, Stream}; +use futures::stream::{LocalBoxStream, Stream, StreamExt}; use httparse; use mime; +use actix_utils::task::LocalWaker; use actix_web::error::{ParseError, PayloadError}; use actix_web::http::header::{ self, ContentDisposition, HeaderMap, HeaderName, HeaderValue, @@ -60,7 +62,7 @@ impl Multipart { /// Create multipart instance for boundary. pub fn new(headers: &HeaderMap, stream: S) -> Multipart where - S: Stream + 'static, + S: Stream> + Unpin + 'static, { match Self::boundary(headers) { Ok(boundary) => Multipart { @@ -104,22 +106,25 @@ impl Multipart { } impl Stream for Multipart { - type Item = Field; - type Error = MultipartError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { if let Some(err) = self.error.take() { - Err(err) + Poll::Ready(Some(Err(err))) } else if self.safety.current() { - let mut inner = self.inner.as_mut().unwrap().borrow_mut(); - if let Some(mut payload) = inner.payload.get_mut(&self.safety) { - payload.poll_stream()?; + let this = self.get_mut(); + let mut inner = this.inner.as_mut().unwrap().borrow_mut(); + if let Some(mut payload) = inner.payload.get_mut(&this.safety) { + payload.poll_stream(cx)?; } - inner.poll(&self.safety) + inner.poll(&this.safety, cx) } else if !self.safety.is_clean() { - Err(MultipartError::NotConsumed) + Poll::Ready(Some(Err(MultipartError::NotConsumed))) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -238,9 +243,13 @@ impl InnerMultipart { Ok(Some(eof)) } - fn poll(&mut self, safety: &Safety) -> Poll, MultipartError> { + fn poll( + &mut self, + safety: &Safety, + cx: &mut Context, + ) -> Poll>> { if self.state == InnerState::Eof { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { // release field loop { @@ -249,10 +258,13 @@ impl InnerMultipart { if safety.current() { let stop = match self.item { InnerMultipartItem::Field(ref mut field) => { - match field.borrow_mut().poll(safety)? { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(Some(_)) => continue, - Async::Ready(None) => true, + match field.borrow_mut().poll(safety) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok(_))) => continue, + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(e))) + } + Poll::Ready(None) => true, } } InnerMultipartItem::None => false, @@ -277,12 +289,12 @@ impl InnerMultipart { Some(eof) => { if eof { self.state = InnerState::Eof; - return Ok(Async::Ready(None)); + return Poll::Ready(None); } else { self.state = InnerState::Headers; } } - None => return Ok(Async::NotReady), + None => return Poll::Pending, } } // read boundary @@ -291,11 +303,11 @@ impl InnerMultipart { &mut *payload, &self.boundary, )? { - None => return Ok(Async::NotReady), + None => return Poll::Pending, Some(eof) => { if eof { self.state = InnerState::Eof; - return Ok(Async::Ready(None)); + return Poll::Ready(None); } else { self.state = InnerState::Headers; } @@ -311,14 +323,14 @@ impl InnerMultipart { self.state = InnerState::Boundary; headers } else { - return Ok(Async::NotReady); + return Poll::Pending; } } else { unreachable!() } } else { log::debug!("NotReady: field is in flight"); - return Ok(Async::NotReady); + return Poll::Pending; }; // content type @@ -335,7 +347,7 @@ impl InnerMultipart { // nested multipart stream if mt.type_() == mime::MULTIPART { - Err(MultipartError::Nested) + Poll::Ready(Some(Err(MultipartError::Nested))) } else { let field = Rc::new(RefCell::new(InnerField::new( self.payload.clone(), @@ -344,12 +356,7 @@ impl InnerMultipart { )?)); self.item = InnerMultipartItem::Field(Rc::clone(&field)); - Ok(Async::Ready(Some(Field::new( - safety.clone(), - headers, - mt, - field, - )))) + Poll::Ready(Some(Ok(Field::new(safety.clone(cx), headers, mt, field)))) } } } @@ -409,23 +416,21 @@ impl Field { } impl Stream for Field { - type Item = Bytes; - type Error = MultipartError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if self.safety.current() { let mut inner = self.inner.borrow_mut(); if let Some(mut payload) = inner.payload.as_ref().unwrap().get_mut(&self.safety) { - payload.poll_stream()?; + payload.poll_stream(cx)?; } - inner.poll(&self.safety) } else if !self.safety.is_clean() { - Err(MultipartError::NotConsumed) + Poll::Ready(Some(Err(MultipartError::NotConsumed))) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -482,9 +487,9 @@ impl InnerField { fn read_len( payload: &mut PayloadBuffer, size: &mut u64, - ) -> Poll, MultipartError> { + ) -> Poll>> { if *size == 0 { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { match payload.read_max(*size)? { Some(mut chunk) => { @@ -494,13 +499,13 @@ impl InnerField { if !chunk.is_empty() { payload.unprocessed(chunk); } - Ok(Async::Ready(Some(ch))) + Poll::Ready(Some(Ok(ch))) } None => { if payload.eof && (*size != 0) { - Err(MultipartError::Incomplete) + Poll::Ready(Some(Err(MultipartError::Incomplete))) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -512,15 +517,15 @@ impl InnerField { fn read_stream( payload: &mut PayloadBuffer, boundary: &str, - ) -> Poll, MultipartError> { + ) -> Poll>> { let mut pos = 0; let len = payload.buf.len(); if len == 0 { return if payload.eof { - Err(MultipartError::Incomplete) + Poll::Ready(Some(Err(MultipartError::Incomplete))) } else { - Ok(Async::NotReady) + Poll::Pending }; } @@ -537,10 +542,10 @@ impl InnerField { if let Some(b_len) = b_len { let b_size = boundary.len() + b_len; if len < b_size { - return Ok(Async::NotReady); + return Poll::Pending; } else if &payload.buf[b_len..b_size] == boundary.as_bytes() { // found boundary - return Ok(Async::Ready(None)); + return Poll::Ready(None); } } } @@ -552,9 +557,9 @@ impl InnerField { // check if we have enough data for boundary detection if cur + 4 > len { if cur > 0 { - Ok(Async::Ready(Some(payload.buf.split_to(cur).freeze()))) + Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze()))) } else { - Ok(Async::NotReady) + Poll::Pending } } else { // check boundary @@ -565,7 +570,7 @@ impl InnerField { { if cur != 0 { // return buffer - Ok(Async::Ready(Some(payload.buf.split_to(cur).freeze()))) + Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze()))) } else { pos = cur + 1; continue; @@ -577,49 +582,51 @@ impl InnerField { } } } else { - Ok(Async::Ready(Some(payload.buf.take().freeze()))) + Poll::Ready(Some(Ok(payload.buf.take().freeze()))) }; } } - fn poll(&mut self, s: &Safety) -> Poll, MultipartError> { + fn poll(&mut self, s: &Safety) -> Poll>> { if self.payload.is_none() { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s) { if !self.eof { let res = if let Some(ref mut len) = self.length { - InnerField::read_len(&mut *payload, len)? + InnerField::read_len(&mut *payload, len) } else { - InnerField::read_stream(&mut *payload, &self.boundary)? + InnerField::read_stream(&mut *payload, &self.boundary) }; match res { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(Some(bytes)) => return Ok(Async::Ready(Some(bytes))), - Async::Ready(None) => self.eof = true, + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))), + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => self.eof = true, } } - match payload.readline()? { - None => Async::Ready(None), - Some(line) => { + match payload.readline() { + Ok(None) => Poll::Ready(None), + Ok(Some(line)) => { if line.as_ref() != b"\r\n" { log::warn!("multipart field did not read all the data or it is malformed"); } - Async::Ready(None) + Poll::Ready(None) } + Err(e) => Poll::Ready(Some(Err(e))), } } else { - Async::NotReady + Poll::Pending }; - if Async::Ready(None) == result { + if let Poll::Ready(None) = result { self.payload.take(); } - Ok(result) + result } } @@ -659,7 +666,7 @@ impl Clone for PayloadRef { /// most task. #[derive(Debug)] struct Safety { - task: Option, + task: LocalWaker, level: usize, payload: Rc>, clean: Rc>, @@ -669,7 +676,7 @@ impl Safety { fn new() -> Safety { let payload = Rc::new(PhantomData); Safety { - task: None, + task: LocalWaker::new(), level: Rc::strong_count(&payload), clean: Rc::new(Cell::new(true)), payload, @@ -683,17 +690,17 @@ impl Safety { fn is_clean(&self) -> bool { self.clean.get() } -} -impl Clone for Safety { - fn clone(&self) -> Safety { + fn clone(&self, cx: &mut Context) -> Safety { let payload = Rc::clone(&self.payload); - Safety { - task: Some(current_task()), + let s = Safety { + task: LocalWaker::new(), level: Rc::strong_count(&payload), clean: self.clean.clone(), payload, - } + }; + s.task.register(cx.waker()); + s } } @@ -704,7 +711,7 @@ impl Drop for Safety { self.clean.set(true); } if let Some(task) = self.task.take() { - task.notify() + task.wake() } } } @@ -713,31 +720,32 @@ impl Drop for Safety { struct PayloadBuffer { eof: bool, buf: BytesMut, - stream: Box>, + stream: LocalBoxStream<'static, Result>, } impl PayloadBuffer { /// Create new `PayloadBuffer` instance fn new(stream: S) -> Self where - S: Stream + 'static, + S: Stream> + 'static, { PayloadBuffer { eof: false, buf: BytesMut::new(), - stream: Box::new(stream), + stream: stream.boxed_local(), } } - fn poll_stream(&mut self) -> Result<(), PayloadError> { + fn poll_stream(&mut self, cx: &mut Context) -> Result<(), PayloadError> { loop { - match self.stream.poll()? { - Async::Ready(Some(data)) => self.buf.extend_from_slice(&data), - Async::Ready(None) => { + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data), + Poll::Ready(Some(Err(e))) => return Err(e), + Poll::Ready(None) => { self.eof = true; return Ok(()); } - Async::NotReady => return Ok(()), + Poll::Pending => return Ok(()), } } } @@ -800,13 +808,14 @@ impl PayloadBuffer { #[cfg(test)] mod tests { - use actix_http::h1::Payload; - use bytes::Bytes; - use futures::unsync::mpsc; - use super::*; + + use actix_http::h1::Payload; + use actix_utils::mpsc; use actix_web::http::header::{DispositionParam, DispositionType}; - use actix_web::test::run_on; + use actix_web::test::block_on; + use bytes::Bytes; + use futures::future::lazy; #[test] fn test_boundary() { @@ -852,12 +861,12 @@ mod tests { } fn create_stream() -> ( - mpsc::UnboundedSender>, - impl Stream, + mpsc::Sender>, + impl Stream>, ) { - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::channel(); - (tx, rx.map_err(|_| panic!()).and_then(|res| res)) + (tx, rx.map(|res| res.map_err(|_| panic!()))) } fn create_simple_request_with_header() -> (Bytes, HeaderMap) { @@ -884,28 +893,28 @@ mod tests { #[test] fn test_multipart_no_end_crlf() { - run_on(|| { + block_on(async { let (sender, payload) = create_stream(); let (bytes, headers) = create_simple_request_with_header(); let bytes_stripped = bytes.slice_to(bytes.len()); // strip crlf - sender.unbounded_send(Ok(bytes_stripped)).unwrap(); + sender.send(Ok(bytes_stripped)).unwrap(); drop(sender); // eof let mut multipart = Multipart::new(&headers, payload); - match multipart.poll().unwrap() { - Async::Ready(Some(_)) => (), + match multipart.next().await.unwrap() { + Ok(_) => (), _ => unreachable!(), } - match multipart.poll().unwrap() { - Async::Ready(Some(_)) => (), + match multipart.next().await.unwrap() { + Ok(_) => (), _ => unreachable!(), } - match multipart.poll().unwrap() { - Async::Ready(None) => (), + match multipart.next().await { + None => (), _ => unreachable!(), } }) @@ -913,15 +922,15 @@ mod tests { #[test] fn test_multipart() { - run_on(|| { + block_on(async { let (sender, payload) = create_stream(); let (bytes, headers) = create_simple_request_with_header(); - sender.unbounded_send(Ok(bytes)).unwrap(); + sender.send(Ok(bytes)).unwrap(); let mut multipart = Multipart::new(&headers, payload); - match multipart.poll().unwrap() { - Async::Ready(Some(mut field)) => { + match multipart.next().await { + Some(Ok(mut field)) => { let cd = field.content_disposition().unwrap(); assert_eq!(cd.disposition, DispositionType::FormData); assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); @@ -929,37 +938,37 @@ mod tests { assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().subtype(), mime::PLAIN); - match field.poll().unwrap() { - Async::Ready(Some(chunk)) => assert_eq!(chunk, "test"), + match field.next().await.unwrap() { + Ok(chunk) => assert_eq!(chunk, "test"), _ => unreachable!(), } - match field.poll().unwrap() { - Async::Ready(None) => (), + match field.next().await { + None => (), _ => unreachable!(), } } _ => unreachable!(), } - match multipart.poll().unwrap() { - Async::Ready(Some(mut field)) => { + match multipart.next().await.unwrap() { + Ok(mut field) => { assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().subtype(), mime::PLAIN); - match field.poll() { - Ok(Async::Ready(Some(chunk))) => assert_eq!(chunk, "data"), + match field.next().await { + Some(Ok(chunk)) => assert_eq!(chunk, "data"), _ => unreachable!(), } - match field.poll() { - Ok(Async::Ready(None)) => (), + match field.next().await { + None => (), _ => unreachable!(), } } _ => unreachable!(), } - match multipart.poll().unwrap() { - Async::Ready(None) => (), + match multipart.next().await { + None => (), _ => unreachable!(), } }); @@ -967,15 +976,15 @@ mod tests { #[test] fn test_stream() { - run_on(|| { + block_on(async { let (sender, payload) = create_stream(); let (bytes, headers) = create_simple_request_with_header(); - sender.unbounded_send(Ok(bytes)).unwrap(); + sender.send(Ok(bytes)).unwrap(); let mut multipart = Multipart::new(&headers, payload); - match multipart.poll().unwrap() { - Async::Ready(Some(mut field)) => { + match multipart.next().await.unwrap() { + Ok(mut field) => { let cd = field.content_disposition().unwrap(); assert_eq!(cd.disposition, DispositionType::FormData); assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); @@ -983,37 +992,37 @@ mod tests { assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().subtype(), mime::PLAIN); - match field.poll().unwrap() { - Async::Ready(Some(chunk)) => assert_eq!(chunk, "test"), + match field.next().await.unwrap() { + Ok(chunk) => assert_eq!(chunk, "test"), _ => unreachable!(), } - match field.poll().unwrap() { - Async::Ready(None) => (), + match field.next().await { + None => (), _ => unreachable!(), } } _ => unreachable!(), } - match multipart.poll().unwrap() { - Async::Ready(Some(mut field)) => { + match multipart.next().await { + Some(Ok(mut field)) => { assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().subtype(), mime::PLAIN); - match field.poll() { - Ok(Async::Ready(Some(chunk))) => assert_eq!(chunk, "data"), + match field.next().await { + Some(Ok(chunk)) => assert_eq!(chunk, "data"), _ => unreachable!(), } - match field.poll() { - Ok(Async::Ready(None)) => (), + match field.next().await { + None => (), _ => unreachable!(), } } _ => unreachable!(), } - match multipart.poll().unwrap() { - Async::Ready(None) => (), + match multipart.next().await { + None => (), _ => unreachable!(), } }); @@ -1021,26 +1030,26 @@ mod tests { #[test] fn test_basic() { - run_on(|| { + block_on(async { let (_, payload) = Payload::create(false); let mut payload = PayloadBuffer::new(payload); assert_eq!(payload.buf.len(), 0); - payload.poll_stream().unwrap(); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); assert_eq!(None, payload.read_max(1).unwrap()); }) } #[test] fn test_eof() { - run_on(|| { + block_on(async { let (mut sender, payload) = Payload::create(false); let mut payload = PayloadBuffer::new(payload); assert_eq!(None, payload.read_max(4).unwrap()); sender.feed_data(Bytes::from("data")); sender.feed_eof(); - payload.poll_stream().unwrap(); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); assert_eq!(Some(Bytes::from("data")), payload.read_max(4).unwrap()); assert_eq!(payload.buf.len(), 0); @@ -1051,24 +1060,24 @@ mod tests { #[test] fn test_err() { - run_on(|| { + block_on(async { let (mut sender, payload) = Payload::create(false); let mut payload = PayloadBuffer::new(payload); assert_eq!(None, payload.read_max(1).unwrap()); sender.set_error(PayloadError::Incomplete(None)); - payload.poll_stream().err().unwrap(); + lazy(|cx| payload.poll_stream(cx)).await.err().unwrap(); }) } #[test] fn test_readmax() { - run_on(|| { + block_on(async { let (mut sender, payload) = Payload::create(false); let mut payload = PayloadBuffer::new(payload); sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line2")); - payload.poll_stream().unwrap(); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); assert_eq!(payload.buf.len(), 10); assert_eq!(Some(Bytes::from("line1")), payload.read_max(5).unwrap()); @@ -1081,7 +1090,7 @@ mod tests { #[test] fn test_readexactly() { - run_on(|| { + block_on(async { let (mut sender, payload) = Payload::create(false); let mut payload = PayloadBuffer::new(payload); @@ -1089,7 +1098,7 @@ mod tests { sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line2")); - payload.poll_stream().unwrap(); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); assert_eq!(Some(Bytes::from_static(b"li")), payload.read_exact(2)); assert_eq!(payload.buf.len(), 8); @@ -1101,7 +1110,7 @@ mod tests { #[test] fn test_readuntil() { - run_on(|| { + block_on(async { let (mut sender, payload) = Payload::create(false); let mut payload = PayloadBuffer::new(payload); @@ -1109,7 +1118,7 @@ mod tests { sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line2")); - payload.poll_stream().unwrap(); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); assert_eq!( Some(Bytes::from("line")), diff --git a/actix-session/Cargo.toml b/actix-session/Cargo.toml index d973661ef..3ce2a8b40 100644 --- a/actix-session/Cargo.toml +++ b/actix-session/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-session" -version = "0.2.0" +version = "0.3.0-alpha.1" authors = ["Nikolay Kim "] description = "Session for actix web framework." readme = "README.md" @@ -24,15 +24,15 @@ default = ["cookie-session"] cookie-session = ["actix-web/secure-cookies"] [dependencies] -actix-web = "1.0.0" -actix-service = "0.4.1" +actix-web = "2.0.0-alpha.1" +actix-service = "1.0.0-alpha.1" bytes = "0.4" derive_more = "0.15.0" -futures = "0.1.25" -hashbrown = "0.5.0" +futures = "0.3.1" +hashbrown = "0.6.3" serde = "1.0" serde_json = "1.0" time = "0.1.42" [dev-dependencies] -actix-rt = "0.2.2" +actix-rt = "1.0.0-alpha.1" diff --git a/actix-session/src/cookie.rs b/actix-session/src/cookie.rs index 192737780..bb5fba97b 100644 --- a/actix-session/src/cookie.rs +++ b/actix-session/src/cookie.rs @@ -17,6 +17,7 @@ use std::collections::HashMap; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_service::{Service, Transform}; use actix_web::cookie::{Cookie, CookieJar, Key, SameSite}; @@ -24,8 +25,7 @@ use actix_web::dev::{ServiceRequest, ServiceResponse}; use actix_web::http::{header::SET_COOKIE, HeaderValue}; use actix_web::{Error, HttpMessage, ResponseError}; use derive_more::{Display, From}; -use futures::future::{ok, Future, FutureResult}; -use futures::Poll; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use serde_json::error::Error as JsonError; use crate::{Session, SessionStatus}; @@ -284,7 +284,7 @@ where type Error = S::Error; type InitError = (); type Transform = CookieSessionMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(CookieSessionMiddleware { @@ -309,10 +309,10 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = S::Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } /// On first request, a new session cookie is returned in response, regardless @@ -325,29 +325,36 @@ where let (is_new, state) = self.inner.load(&req); Session::set_session(state.into_iter(), &mut req); - Box::new(self.service.call(req).map(move |mut res| { - match Session::get_changes(&mut res) { - (SessionStatus::Changed, Some(state)) - | (SessionStatus::Renewed, Some(state)) => { - res.checked_expr(|res| inner.set_cookie(res, state)) - } - (SessionStatus::Unchanged, _) => - // set a new session cookie upon first request (new client) - { - if is_new { - let state: HashMap = HashMap::new(); - res.checked_expr(|res| inner.set_cookie(res, state.into_iter())) - } else { + let fut = self.service.call(req); + + async move { + fut.await.map(|mut res| { + match Session::get_changes(&mut res) { + (SessionStatus::Changed, Some(state)) + | (SessionStatus::Renewed, Some(state)) => { + res.checked_expr(|res| inner.set_cookie(res, state)) + } + (SessionStatus::Unchanged, _) => + // set a new session cookie upon first request (new client) + { + if is_new { + let state: HashMap = HashMap::new(); + res.checked_expr(|res| { + inner.set_cookie(res, state.into_iter()) + }) + } else { + res + } + } + (SessionStatus::Purged, _) => { + let _ = inner.remove_cookie(&mut res); res } + _ => res, } - (SessionStatus::Purged, _) => { - let _ = inner.remove_cookie(&mut res); - res - } - _ => res, - } - })) + }) + } + .boxed_local() } } @@ -359,101 +366,123 @@ mod tests { #[test] fn cookie_session() { - let mut app = test::init_service( - App::new() - .wrap(CookieSession::signed(&[0; 32]).secure(false)) - .service(web::resource("/").to(|ses: Session| { - let _ = ses.set("counter", 100); - "test" - })), - ); + test::block_on(async { + let mut app = test::init_service( + App::new() + .wrap(CookieSession::signed(&[0; 32]).secure(false)) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })), + ) + .await; - let request = test::TestRequest::get().to_request(); - let response = test::block_on(app.call(request)).unwrap(); - assert!(response - .response() - .cookies() - .find(|c| c.name() == "actix-session") - .is_some()); + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + assert!(response + .response() + .cookies() + .find(|c| c.name() == "actix-session") + .is_some()); + }) } #[test] fn private_cookie() { - let mut app = test::init_service( - App::new() - .wrap(CookieSession::private(&[0; 32]).secure(false)) - .service(web::resource("/").to(|ses: Session| { - let _ = ses.set("counter", 100); - "test" - })), - ); + test::block_on(async { + let mut app = test::init_service( + App::new() + .wrap(CookieSession::private(&[0; 32]).secure(false)) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })), + ) + .await; - let request = test::TestRequest::get().to_request(); - let response = test::block_on(app.call(request)).unwrap(); - assert!(response - .response() - .cookies() - .find(|c| c.name() == "actix-session") - .is_some()); + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + assert!(response + .response() + .cookies() + .find(|c| c.name() == "actix-session") + .is_some()); + }) } #[test] fn cookie_session_extractor() { - let mut app = test::init_service( - App::new() - .wrap(CookieSession::signed(&[0; 32]).secure(false)) - .service(web::resource("/").to(|ses: Session| { - let _ = ses.set("counter", 100); - "test" - })), - ); + test::block_on(async { + let mut app = test::init_service( + App::new() + .wrap(CookieSession::signed(&[0; 32]).secure(false)) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })), + ) + .await; - let request = test::TestRequest::get().to_request(); - let response = test::block_on(app.call(request)).unwrap(); - assert!(response - .response() - .cookies() - .find(|c| c.name() == "actix-session") - .is_some()); + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + assert!(response + .response() + .cookies() + .find(|c| c.name() == "actix-session") + .is_some()); + }) } #[test] fn basics() { - let mut app = test::init_service( - App::new() - .wrap( - CookieSession::signed(&[0; 32]) - .path("/test/") - .name("actix-test") - .domain("localhost") - .http_only(true) - .same_site(SameSite::Lax) - .max_age(100), - ) - .service(web::resource("/").to(|ses: Session| { - let _ = ses.set("counter", 100); - "test" - })) - .service(web::resource("/test/").to(|ses: Session| { - let val: usize = ses.get("counter").unwrap().unwrap(); - format!("counter: {}", val) - })), - ); + test::block_on(async { + let mut app = test::init_service( + App::new() + .wrap( + CookieSession::signed(&[0; 32]) + .path("/test/") + .name("actix-test") + .domain("localhost") + .http_only(true) + .same_site(SameSite::Lax) + .max_age(100), + ) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })) + .service(web::resource("/test/").to(|ses: Session| { + async move { + let val: usize = ses.get("counter").unwrap().unwrap(); + format!("counter: {}", val) + } + })), + ) + .await; - let request = test::TestRequest::get().to_request(); - let response = test::block_on(app.call(request)).unwrap(); - let cookie = response - .response() - .cookies() - .find(|c| c.name() == "actix-test") - .unwrap() - .clone(); - assert_eq!(cookie.path().unwrap(), "/test/"); + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + let cookie = response + .response() + .cookies() + .find(|c| c.name() == "actix-test") + .unwrap() + .clone(); + assert_eq!(cookie.path().unwrap(), "/test/"); - let request = test::TestRequest::with_uri("/test/") - .cookie(cookie) - .to_request(); - let body = test::read_response(&mut app, request); - assert_eq!(body, Bytes::from_static(b"counter: 100")); + let request = test::TestRequest::with_uri("/test/") + .cookie(cookie) + .to_request(); + let body = test::read_response(&mut app, request).await; + assert_eq!(body, Bytes::from_static(b"counter: 100")); + }) } } diff --git a/actix-session/src/lib.rs b/actix-session/src/lib.rs index 2e9e51714..def35a1e9 100644 --- a/actix-session/src/lib.rs +++ b/actix-session/src/lib.rs @@ -47,6 +47,7 @@ use std::rc::Rc; use actix_web::dev::{Extensions, Payload, ServiceRequest, ServiceResponse}; use actix_web::{Error, FromRequest, HttpMessage, HttpRequest}; +use futures::future::{ok, Ready}; use hashbrown::HashMap; use serde::de::DeserializeOwned; use serde::Serialize; @@ -230,12 +231,12 @@ impl Session { /// ``` impl FromRequest for Session { type Error = Error; - type Future = Result; + type Future = Ready>; type Config = (); #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { - Ok(Session::get_session(&mut *req.extensions_mut())) + ok(Session::get_session(&mut *req.extensions_mut())) } } diff --git a/actix-web-actors/CHANGES.md b/actix-web-actors/CHANGES.md index 0d1df7e55..c1417c9c4 100644 --- a/actix-web-actors/CHANGES.md +++ b/actix-web-actors/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [1.0.3] - 2019-11-14 + +* Update actix-web and actix-http dependencies + ## [1.0.2] - 2019-07-20 * Add `ws::start_with_addr()`, returning the address of the created actor, along diff --git a/actix-web-actors/Cargo.toml b/actix-web-actors/Cargo.toml index 6b1d23791..a483d0da9 100644 --- a/actix-web-actors/Cargo.toml +++ b/actix-web-actors/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web-actors" -version = "1.0.2" +version = "1.0.3" authors = ["Nikolay Kim "] description = "Actix actors support for actix web framework." readme = "README.md" @@ -23,8 +23,8 @@ invalid-ping-payload = ["actix-http/invalid-ping-payload"] [dependencies] actix = "0.8.3" -actix-web = "1.0.3" -actix-http = { path = "../actix-http" } # "0.2.5" +actix-web = "1.0.9" +actix-http = "0.2.11" actix-codec = "0.1.2" bytes = "0.4" futures = "0.1.25" diff --git a/actix-web-codegen/CHANGES.md b/actix-web-codegen/CHANGES.md index 7cc0c164f..2beea62cf 100644 --- a/actix-web-codegen/CHANGES.md +++ b/actix-web-codegen/CHANGES.md @@ -1,5 +1,11 @@ # Changes +## [0.1.3] - 2019-10-14 + +* Bump up `syn` & `quote` to 1.0 + +* Provide better error message + ## [0.1.2] - 2019-06-04 * Add macros for head, options, trace, connect and patch http methods diff --git a/actix-web-codegen/Cargo.toml b/actix-web-codegen/Cargo.toml index 29abb4897..5336f60b9 100644 --- a/actix-web-codegen/Cargo.toml +++ b/actix-web-codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web-codegen" -version = "0.1.2" +version = "0.2.0-alpha.1" description = "Actix web proc macros" readme = "README.md" authors = ["Nikolay Kim "] @@ -12,11 +12,12 @@ workspace = ".." proc-macro = true [dependencies] -quote = "0.6.12" -syn = { version = "0.15.34", features = ["full", "parsing", "extra-traits"] } +quote = "1" +syn = { version = "1", features = ["full", "parsing"] } +proc-macro2 = "1" [dev-dependencies] -actix-web = { version = "1.0.0" } -actix-http = { version = "0.2.4", features=["ssl"] } -actix-http-test = { version = "0.2.0", features=["ssl"] } -futures = { version = "0.1" } +actix-web = { version = "2.0.0-alpha.1" } +actix-http = { version = "0.3.0-alpha.1", features=["openssl"] } +actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] } +futures = { version = "0.3.1" } diff --git a/actix-web-codegen/src/lib.rs b/actix-web-codegen/src/lib.rs index b3ae7dd9b..0a727ed69 100644 --- a/actix-web-codegen/src/lib.rs +++ b/actix-web-codegen/src/lib.rs @@ -35,8 +35,8 @@ //! use futures::{future, Future}; //! //! #[get("/test")] -//! fn async_test() -> impl Future { -//! future::ok(HttpResponse::Ok().finish()) +//! async fn async_test() -> Result { +//! Ok(HttpResponse::Ok().finish()) //! } //! ``` @@ -58,7 +58,10 @@ use syn::parse_macro_input; #[proc_macro_attribute] pub fn get(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Get); + let gen = match route::Route::new(args, input, route::GuardType::Get) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -70,7 +73,10 @@ pub fn get(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn post(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Post); + let gen = match route::Route::new(args, input, route::GuardType::Post) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -82,7 +88,10 @@ pub fn post(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn put(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Put); + let gen = match route::Route::new(args, input, route::GuardType::Put) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -94,7 +103,10 @@ pub fn put(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn delete(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Delete); + let gen = match route::Route::new(args, input, route::GuardType::Delete) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -106,7 +118,10 @@ pub fn delete(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn head(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Head); + let gen = match route::Route::new(args, input, route::GuardType::Head) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -118,7 +133,10 @@ pub fn head(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn connect(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Connect); + let gen = match route::Route::new(args, input, route::GuardType::Connect) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -130,7 +148,10 @@ pub fn connect(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn options(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Options); + let gen = match route::Route::new(args, input, route::GuardType::Options) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -142,7 +163,10 @@ pub fn options(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn trace(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Trace); + let gen = match route::Route::new(args, input, route::GuardType::Trace) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -154,6 +178,9 @@ pub fn trace(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn patch(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Patch); + let gen = match route::Route::new(args, input, route::GuardType::Patch) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } diff --git a/actix-web-codegen/src/route.rs b/actix-web-codegen/src/route.rs index 5215f60c8..f8e2496c4 100644 --- a/actix-web-codegen/src/route.rs +++ b/actix-web-codegen/src/route.rs @@ -1,21 +1,23 @@ extern crate proc_macro; -use std::fmt; - use proc_macro::TokenStream; -use quote::quote; +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::{quote, ToTokens, TokenStreamExt}; +use syn::{AttributeArgs, Ident, NestedMeta}; enum ResourceType { Async, Sync, } -impl fmt::Display for ResourceType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - ResourceType::Async => write!(f, "to_async"), - ResourceType::Sync => write!(f, "to"), - } +impl ToTokens for ResourceType { + fn to_tokens(&self, stream: &mut TokenStream2) { + let ident = match self { + ResourceType::Async => "to", + ResourceType::Sync => "to", + }; + let ident = Ident::new(ident, Span::call_site()); + stream.append(ident); } } @@ -32,63 +34,89 @@ pub enum GuardType { Patch, } -impl fmt::Display for GuardType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - GuardType::Get => write!(f, "Get"), - GuardType::Post => write!(f, "Post"), - GuardType::Put => write!(f, "Put"), - GuardType::Delete => write!(f, "Delete"), - GuardType::Head => write!(f, "Head"), - GuardType::Connect => write!(f, "Connect"), - GuardType::Options => write!(f, "Options"), - GuardType::Trace => write!(f, "Trace"), - GuardType::Patch => write!(f, "Patch"), +impl GuardType { + fn as_str(&self) -> &'static str { + match self { + GuardType::Get => "Get", + GuardType::Post => "Post", + GuardType::Put => "Put", + GuardType::Delete => "Delete", + GuardType::Head => "Head", + GuardType::Connect => "Connect", + GuardType::Options => "Options", + GuardType::Trace => "Trace", + GuardType::Patch => "Patch", } } } -pub struct Args { - name: syn::Ident, - path: String, - ast: syn::ItemFn, - resource_type: ResourceType, - pub guard: GuardType, - pub extra_guards: Vec, +impl ToTokens for GuardType { + fn to_tokens(&self, stream: &mut TokenStream2) { + let ident = self.as_str(); + let ident = Ident::new(ident, Span::call_site()); + stream.append(ident); + } } -impl fmt::Display for Args { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let ast = &self.ast; - let guards = format!(".guard(actix_web::guard::{}())", self.guard); - let guards = self.extra_guards.iter().fold(guards, |acc, val| { - format!("{}.guard(actix_web::guard::fn_guard({}))", acc, val) - }); +struct Args { + path: syn::LitStr, + guards: Vec, +} - write!( - f, - " -#[allow(non_camel_case_types)] -pub struct {name}; - -impl actix_web::dev::HttpServiceFactory for {name} {{ - fn register(self, config: &mut actix_web::dev::AppService) {{ - {ast} - - let resource = actix_web::Resource::new(\"{path}\"){guards}.{to}({name}); - - actix_web::dev::HttpServiceFactory::register(resource, config) - }} -}}", - name = self.name, - ast = quote!(#ast), - path = self.path, - guards = guards, - to = self.resource_type - ) +impl Args { + fn new(args: AttributeArgs) -> syn::Result { + let mut path = None; + let mut guards = Vec::new(); + for arg in args { + match arg { + NestedMeta::Lit(syn::Lit::Str(lit)) => match path { + None => { + path = Some(lit); + } + _ => { + return Err(syn::Error::new_spanned( + lit, + "Multiple paths specified! Should be only one!", + )); + } + }, + NestedMeta::Meta(syn::Meta::NameValue(nv)) => { + if nv.path.is_ident("guard") { + if let syn::Lit::Str(lit) = nv.lit { + guards.push(Ident::new(&lit.value(), Span::call_site())); + } else { + return Err(syn::Error::new_spanned( + nv.lit, + "Attribute guard expects literal string!", + )); + } + } else { + return Err(syn::Error::new_spanned( + nv.path, + "Unknown attribute key is specified. Allowed: guard", + )); + } + } + arg => { + return Err(syn::Error::new_spanned(arg, "Unknown attribute")); + } + } + } + Ok(Args { + path: path.unwrap(), + guards, + }) } } +pub struct Route { + name: syn::Ident, + args: Args, + ast: syn::ItemFn, + resource_type: ResourceType, + guard: GuardType, +} + fn guess_resource_type(typ: &syn::Type) -> ResourceType { let mut guess = ResourceType::Sync; @@ -111,75 +139,73 @@ fn guess_resource_type(typ: &syn::Type) -> ResourceType { guess } -impl Args { - pub fn new(args: &[syn::NestedMeta], input: TokenStream, guard: GuardType) -> Self { +impl Route { + pub fn new( + args: AttributeArgs, + input: TokenStream, + guard: GuardType, + ) -> syn::Result { if args.is_empty() { - panic!( - "invalid server definition, expected: #[{}(\"some path\")]", - guard - ); + return Err(syn::Error::new( + Span::call_site(), + format!( + r#"invalid server definition, expected #[{}("")]"#, + guard.as_str().to_ascii_lowercase() + ), + )); } + let ast: syn::ItemFn = syn::parse(input)?; + let name = ast.sig.ident.clone(); - let ast: syn::ItemFn = syn::parse(input).expect("Parse input as function"); - let name = ast.ident.clone(); + let args = Args::new(args)?; - let mut extra_guards = Vec::new(); - let mut path = None; - for arg in args { - match arg { - syn::NestedMeta::Literal(syn::Lit::Str(ref fname)) => { - if path.is_some() { - panic!("Multiple paths specified! Should be only one!") - } - let fname = quote!(#fname).to_string(); - path = Some(fname.as_str()[1..fname.len() - 1].to_owned()) - } - syn::NestedMeta::Meta(syn::Meta::NameValue(ident)) => { - match ident.ident.to_string().to_lowercase().as_str() { - "guard" => match ident.lit { - syn::Lit::Str(ref text) => extra_guards.push(text.value()), - _ => panic!("Attribute guard expects literal string!"), - }, - attr => panic!( - "Unknown attribute key is specified: {}. Allowed: guard", - attr - ), - } - } - attr => panic!("Unknown attribute{:?}", attr), - } - } - - let resource_type = if ast.asyncness.is_some() { + let resource_type = if ast.sig.asyncness.is_some() { ResourceType::Async } else { - match ast.decl.output { - syn::ReturnType::Default => panic!( - "Function {} has no return type. Cannot be used as handler", - name - ), + match ast.sig.output { + syn::ReturnType::Default => { + return Err(syn::Error::new_spanned( + ast, + "Function has no return type. Cannot be used as handler", + )); + } syn::ReturnType::Type(_, ref typ) => guess_resource_type(typ.as_ref()), } }; - let path = path.unwrap(); - - Self { + Ok(Self { name, - path, + args, ast, resource_type, guard, - extra_guards, - } + }) } pub fn generate(&self) -> TokenStream { - let text = self.to_string(); + let name = &self.name; + let guard = &self.guard; + let ast = &self.ast; + let path = &self.args.path; + let extra_guards = &self.args.guards; + let resource_type = &self.resource_type; + let stream = quote! { + #[allow(non_camel_case_types)] + pub struct #name; - match text.parse() { - Ok(res) => res, - Err(error) => panic!("Error: {:?}\nGenerated code: {}", error, text), - } + impl actix_web::dev::HttpServiceFactory for #name { + fn register(self, config: &mut actix_web::dev::AppService) { + #ast + + let resource = actix_web::Resource::new(#path) + .guard(actix_web::guard::#guard()) + #(.guard(actix_web::guard::fn_guard(#extra_guards)))* + .#resource_type(#name); + + actix_web::dev::HttpServiceFactory::register(resource, config) + } + } + }; + stream.into() } } diff --git a/actix-web-codegen/tests/test_macro.rs b/actix-web-codegen/tests/test_macro.rs index 8ecc81dc1..18c01f374 100644 --- a/actix-web-codegen/tests/test_macro.rs +++ b/actix-web-codegen/tests/test_macro.rs @@ -1,157 +1,163 @@ use actix_http::HttpService; -use actix_http_test::TestServer; +use actix_http_test::{block_on, TestServer}; use actix_web::{http, web::Path, App, HttpResponse, Responder}; use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, trace}; use futures::{future, Future}; #[get("/test")] -fn test() -> impl Responder { +async fn test() -> impl Responder { HttpResponse::Ok() } #[put("/test")] -fn put_test() -> impl Responder { +async fn put_test() -> impl Responder { HttpResponse::Created() } #[patch("/test")] -fn patch_test() -> impl Responder { +async fn patch_test() -> impl Responder { HttpResponse::Ok() } #[post("/test")] -fn post_test() -> impl Responder { +async fn post_test() -> impl Responder { HttpResponse::NoContent() } #[head("/test")] -fn head_test() -> impl Responder { +async fn head_test() -> impl Responder { HttpResponse::Ok() } #[connect("/test")] -fn connect_test() -> impl Responder { +async fn connect_test() -> impl Responder { HttpResponse::Ok() } #[options("/test")] -fn options_test() -> impl Responder { +async fn options_test() -> impl Responder { HttpResponse::Ok() } #[trace("/test")] -fn trace_test() -> impl Responder { +async fn trace_test() -> impl Responder { HttpResponse::Ok() } #[get("/test")] -fn auto_async() -> impl Future { +fn auto_async() -> impl Future> { future::ok(HttpResponse::Ok().finish()) } #[get("/test")] -fn auto_sync() -> impl Future { +fn auto_sync() -> impl Future> { future::ok(HttpResponse::Ok().finish()) } #[put("/test/{param}")] -fn put_param_test(_: Path) -> impl Responder { +async fn put_param_test(_: Path) -> impl Responder { HttpResponse::Created() } #[delete("/test/{param}")] -fn delete_param_test(_: Path) -> impl Responder { +async fn delete_param_test(_: Path) -> impl Responder { HttpResponse::NoContent() } #[get("/test/{param}")] -fn get_param_test(_: Path) -> impl Responder { +async fn get_param_test(_: Path) -> impl Responder { HttpResponse::Ok() } #[test] fn test_params() { - let mut srv = TestServer::new(|| { - HttpService::new( - App::new() - .service(get_param_test) - .service(put_param_test) - .service(delete_param_test), - ) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new( + App::new() + .service(get_param_test) + .service(put_param_test) + .service(delete_param_test), + ) + }); - let request = srv.request(http::Method::GET, srv.url("/test/it")); - let response = srv.block_on(request.send()).unwrap(); - assert_eq!(response.status(), http::StatusCode::OK); + let request = srv.request(http::Method::GET, srv.url("/test/it")); + let response = request.send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::OK); - let request = srv.request(http::Method::PUT, srv.url("/test/it")); - let response = srv.block_on(request.send()).unwrap(); - assert_eq!(response.status(), http::StatusCode::CREATED); + let request = srv.request(http::Method::PUT, srv.url("/test/it")); + let response = request.send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::CREATED); - let request = srv.request(http::Method::DELETE, srv.url("/test/it")); - let response = srv.block_on(request.send()).unwrap(); - assert_eq!(response.status(), http::StatusCode::NO_CONTENT); + let request = srv.request(http::Method::DELETE, srv.url("/test/it")); + let response = request.send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::NO_CONTENT); + }) } #[test] fn test_body() { - let mut srv = TestServer::new(|| { - HttpService::new( - App::new() - .service(post_test) - .service(put_test) - .service(head_test) - .service(connect_test) - .service(options_test) - .service(trace_test) - .service(patch_test) - .service(test), - ) - }); - let request = srv.request(http::Method::GET, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new( + App::new() + .service(post_test) + .service(put_test) + .service(head_test) + .service(connect_test) + .service(options_test) + .service(trace_test) + .service(patch_test) + .service(test), + ) + }); + let request = srv.request(http::Method::GET, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); - let request = srv.request(http::Method::HEAD, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); + let request = srv.request(http::Method::HEAD, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); - let request = srv.request(http::Method::CONNECT, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); + let request = srv.request(http::Method::CONNECT, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); - let request = srv.request(http::Method::OPTIONS, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); + let request = srv.request(http::Method::OPTIONS, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); - let request = srv.request(http::Method::TRACE, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); + let request = srv.request(http::Method::TRACE, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); - let request = srv.request(http::Method::PATCH, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); + let request = srv.request(http::Method::PATCH, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); - let request = srv.request(http::Method::PUT, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); - assert_eq!(response.status(), http::StatusCode::CREATED); + let request = srv.request(http::Method::PUT, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.status(), http::StatusCode::CREATED); - let request = srv.request(http::Method::POST, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); - assert_eq!(response.status(), http::StatusCode::NO_CONTENT); + let request = srv.request(http::Method::POST, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.status(), http::StatusCode::NO_CONTENT); - let request = srv.request(http::Method::GET, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); + let request = srv.request(http::Method::GET, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + }) } #[test] fn test_auto_async() { - let mut srv = TestServer::new(|| HttpService::new(App::new().service(auto_async))); + block_on(async { + let srv = TestServer::start(|| HttpService::new(App::new().service(auto_async))); - let request = srv.request(http::Method::GET, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); + let request = srv.request(http::Method::GET, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + }) } diff --git a/awc/CHANGES.md b/awc/CHANGES.md index 6f8fe2dbd..89423f80e 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.2.8] - 2019-11-06 + +* Add support for setting query from Serialize type for client request. + ## [0.2.7] - 2019-09-25 diff --git a/awc/Cargo.toml b/awc/Cargo.toml index 6f0f63f92..eed7eabe7 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "awc" -version = "0.2.7" +version = "0.3.0-alpha.1" authors = ["Nikolay Kim "] description = "Actix http client." readme = "README.md" @@ -21,16 +21,16 @@ name = "awc" path = "src/lib.rs" [package.metadata.docs.rs] -features = ["ssl", "brotli", "flate2-zlib"] +features = ["openssl", "brotli", "flate2-zlib"] [features] default = ["brotli", "flate2-zlib"] # openssl -ssl = ["openssl", "actix-http/ssl"] +openssl = ["open-ssl", "actix-http/openssl"] # rustls -rust-tls = ["rustls", "actix-http/rust-tls"] +# rustls = ["rust-tls", "actix-http/rustls"] # brotli encoding, requires c compiler brotli = ["actix-http/brotli"] @@ -42,13 +42,14 @@ flate2-zlib = ["actix-http/flate2-zlib"] flate2-rust = ["actix-http/flate2-rust"] [dependencies] -actix-codec = "0.1.2" -actix-service = "0.4.1" -actix-http = "0.2.10" +actix-codec = "0.2.0-alpha.1" +actix-service = "1.0.0-alpha.1" +actix-http = "0.3.0-alpha.1" + base64 = "0.10.1" bytes = "0.4" derive_more = "0.15.0" -futures = "0.1.25" +futures = "0.3.1" log =" 0.4" mime = "0.3" percent-encoding = "2.1" @@ -56,21 +57,21 @@ rand = "0.7" serde = "1.0" serde_json = "1.0" serde_urlencoded = "0.6.1" -tokio-timer = "0.2.8" -openssl = { version="0.10", optional = true } -rustls = { version = "0.15.2", optional = true } +tokio-timer = "0.3.0-alpha.6" +open-ssl = { version="0.10", package="openssl", optional = true } +rust-tls = { version = "0.16.0", package="rustls", optional = true, features = ["dangerous_configuration"] } [dev-dependencies] -actix-rt = "0.2.2" -actix-web = { version = "1.0.0", features=["ssl"] } -actix-http = { version = "0.2.10", features=["ssl"] } -actix-http-test = { version = "0.2.0", features=["ssl"] } -actix-utils = "0.4.1" -actix-server = { version = "0.6.0", features=["ssl", "rust-tls"] } +actix-rt = "1.0.0-alpha.1" +actix-connect = { version = "1.0.0-alpha.1", features=["openssl"] } +actix-web = { version = "2.0.0-alpha.1", features=["openssl"] } +actix-http = { version = "0.3.0-alpha.1", features=["openssl"] } +actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] } +actix-utils = "0.5.0-alpha.1" +actix-server = { version = "0.8.0-alpha.1", features=["openssl"] } brotli2 = { version="0.3.2" } flate2 = { version="1.0.2" } env_logger = "0.6" rand = "0.7" tokio-tcp = "0.1" -webpki = "0.19" -rustls = { version = "0.15.2", features = ["dangerous_configuration"] } +webpki = { version = "0.21" } diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 97864d300..cc92fdbb6 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -1,4 +1,6 @@ +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use std::{fmt, io, net}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; @@ -10,7 +12,7 @@ use actix_http::h1::ClientCodec; use actix_http::http::HeaderMap; use actix_http::{RequestHead, RequestHeadType, ResponseHead}; use actix_service::Service; -use futures::{Future, Poll}; +use futures::future::{FutureExt, LocalBoxFuture}; use crate::response::ClientResponse; @@ -22,7 +24,7 @@ pub(crate) trait Connect { head: RequestHead, body: Body, addr: Option, - ) -> Box>; + ) -> LocalBoxFuture<'static, Result>; fn send_request_extra( &mut self, @@ -30,18 +32,16 @@ pub(crate) trait Connect { extra_headers: Option, body: Body, addr: Option, - ) -> Box>; + ) -> LocalBoxFuture<'static, Result>; /// Send request, returns Response and Framed fn open_tunnel( &mut self, head: RequestHead, addr: Option, - ) -> Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, - >, + ) -> LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, >; /// Send request and extra headers, returns Response and Framed @@ -50,11 +50,9 @@ pub(crate) trait Connect { head: Rc, extra_headers: Option, addr: Option, - ) -> Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, - >, + ) -> LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, >; } @@ -72,21 +70,23 @@ where head: RequestHead, body: Body, addr: Option, - ) -> Box> { - Box::new( - self.0 - // connect to the host - .call(ClientConnect { - uri: head.uri.clone(), - addr, - }) - .from_err() - // send request - .and_then(move |connection| { - connection.send_request(RequestHeadType::from(head), body) - }) - .map(|(head, payload)| ClientResponse::new(head, payload)), - ) + ) -> LocalBoxFuture<'static, Result> { + // connect to the host + let fut = self.0.call(ClientConnect { + uri: head.uri.clone(), + addr, + }); + + async move { + let connection = fut.await?; + + // send request + connection + .send_request(RequestHeadType::from(head), body) + .await + .map(|(head, payload)| ClientResponse::new(head, payload)) + } + .boxed_local() } fn send_request_extra( @@ -95,51 +95,51 @@ where extra_headers: Option, body: Body, addr: Option, - ) -> Box> { - Box::new( - self.0 - // connect to the host - .call(ClientConnect { - uri: head.uri.clone(), - addr, - }) - .from_err() - // send request - .and_then(move |connection| { - connection - .send_request(RequestHeadType::Rc(head, extra_headers), body) - }) - .map(|(head, payload)| ClientResponse::new(head, payload)), - ) + ) -> LocalBoxFuture<'static, Result> { + // connect to the host + let fut = self.0.call(ClientConnect { + uri: head.uri.clone(), + addr, + }); + + async move { + let connection = fut.await?; + + // send request + let (head, payload) = connection + .send_request(RequestHeadType::Rc(head, extra_headers), body) + .await?; + + Ok(ClientResponse::new(head, payload)) + } + .boxed_local() } fn open_tunnel( &mut self, head: RequestHead, addr: Option, - ) -> Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, - >, + ) -> LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, > { - Box::new( - self.0 - // connect to the host - .call(ClientConnect { - uri: head.uri.clone(), - addr, - }) - .from_err() - // send request - .and_then(move |connection| { - connection.open_tunnel(RequestHeadType::from(head)) - }) - .map(|(head, framed)| { - let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); - (head, framed) - }), - ) + // connect to the host + let fut = self.0.call(ClientConnect { + uri: head.uri.clone(), + addr, + }); + + async move { + let connection = fut.await?; + + // send request + let (head, framed) = + connection.open_tunnel(RequestHeadType::from(head)).await?; + + let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); + Ok((head, framed)) + } + .boxed_local() } fn open_tunnel_extra( @@ -147,48 +147,47 @@ where head: Rc, extra_headers: Option, addr: Option, - ) -> Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, - >, + ) -> LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, > { - Box::new( - self.0 - // connect to the host - .call(ClientConnect { - uri: head.uri.clone(), - addr, - }) - .from_err() - // send request - .and_then(move |connection| { - connection.open_tunnel(RequestHeadType::Rc(head, extra_headers)) - }) - .map(|(head, framed)| { - let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); - (head, framed) - }), - ) + // connect to the host + let fut = self.0.call(ClientConnect { + uri: head.uri.clone(), + addr, + }); + + async move { + let connection = fut.await?; + + // send request + let (head, framed) = connection + .open_tunnel(RequestHeadType::Rc(head, extra_headers)) + .await?; + + let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); + Ok((head, framed)) + } + .boxed_local() } } trait AsyncSocket { - fn as_read(&self) -> &dyn AsyncRead; - fn as_read_mut(&mut self) -> &mut dyn AsyncRead; - fn as_write(&mut self) -> &mut dyn AsyncWrite; + fn as_read(&self) -> &(dyn AsyncRead + Unpin); + fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin); + fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin); } -struct Socket(T); +struct Socket(T); -impl AsyncSocket for Socket { - fn as_read(&self) -> &dyn AsyncRead { +impl AsyncSocket for Socket { + fn as_read(&self) -> &(dyn AsyncRead + Unpin) { &self.0 } - fn as_read_mut(&mut self) -> &mut dyn AsyncRead { + fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin) { &mut self.0 } - fn as_write(&mut self) -> &mut dyn AsyncWrite { + fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin) { &mut self.0 } } @@ -201,30 +200,37 @@ impl fmt::Debug for BoxedSocket { } } -impl io::Read for BoxedSocket { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.as_read_mut().read(buf) - } -} - impl AsyncRead for BoxedSocket { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { self.0.as_read().prepare_uninitialized_buffer(buf) } -} -impl io::Write for BoxedSocket { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.as_write().write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.0.as_write().flush() + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf) } } impl AsyncWrite for BoxedSocket { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.0.as_write().shutdown() + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(self.get_mut().0.as_write()).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(self.get_mut().0.as_write()).poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.get_mut().0.as_write()).poll_shutdown(cx) } } diff --git a/awc/src/frozen.rs b/awc/src/frozen.rs index d9f65d431..61ba87aad 100644 --- a/awc/src/frozen.rs +++ b/awc/src/frozen.rs @@ -82,7 +82,7 @@ impl FrozenClientRequest { /// Send a streaming body. pub fn send_stream(&self, stream: S) -> SendClientRequest where - S: Stream + 'static, + S: Stream> + Unpin + 'static, E: Into + 'static, { RequestSender::Rc(self.head.clone(), None).send_stream( @@ -203,7 +203,7 @@ impl FrozenSendBuilder { /// Complete request construction and send a streaming body. pub fn send_stream(self, stream: S) -> SendClientRequest where - S: Stream + 'static, + S: Stream> + Unpin + 'static, E: Into + 'static, { if let Some(e) = self.err { diff --git a/awc/src/lib.rs b/awc/src/lib.rs index 58c9056b2..d6cea6ded 100644 --- a/awc/src/lib.rs +++ b/awc/src/lib.rs @@ -7,18 +7,18 @@ //! use awc::Client; //! //! fn main() { -//! System::new("test").block_on(lazy(|| { +//! System::new("test").block_on(async { //! let mut client = Client::default(); //! //! client.get("http://www.rust-lang.org") // <- Create request builder //! .header("User-Agent", "Actix-web") //! .send() // <- Send http request -//! .map_err(|_| ()) +//! .await //! .and_then(|response| { // <- server http response //! println!("Response: {:?}", response); //! Ok(()) //! }) -//! })); +//! }); //! } //! ``` use std::cell::RefCell; @@ -52,23 +52,22 @@ use self::connect::{Connect, ConnectorWrapper}; /// An HTTP Client /// /// ```rust -/// # use futures::future::{Future, lazy}; /// use actix_rt::System; /// use awc::Client; /// /// fn main() { -/// System::new("test").block_on(lazy(|| { +/// System::new("test").block_on(async { /// let mut client = Client::default(); /// /// client.get("http://www.rust-lang.org") // <- Create request builder /// .header("User-Agent", "Actix-web") /// .send() // <- Send http request -/// .map_err(|_| ()) +/// .await /// .and_then(|response| { // <- server http response /// println!("Response: {:?}", response); /// Ok(()) /// }) -/// })); +/// }); /// } /// ``` #[derive(Clone)] diff --git a/awc/src/request.rs b/awc/src/request.rs index a90cf60b2..c6b09e95c 100644 --- a/awc/src/request.rs +++ b/awc/src/request.rs @@ -37,21 +37,21 @@ const HTTPS_ENCODING: &str = "gzip, deflate"; /// builder-like pattern. /// /// ```rust -/// use futures::future::{Future, lazy}; /// use actix_rt::System; /// /// fn main() { -/// System::new("test").block_on(lazy(|| { -/// awc::Client::new() +/// System::new("test").block_on(async { +/// let response = awc::Client::new() /// .get("http://www.rust-lang.org") // <- Create request builder /// .header("User-Agent", "Actix-web") /// .send() // <- Send http request -/// .map_err(|_| ()) -/// .and_then(|response| { // <- server http response -/// println!("Response: {:?}", response); -/// Ok(()) +/// .await; +/// +/// response.and_then(|response| { // <- server http response +/// println!("Response: {:?}", response); +/// Ok(()) /// }) -/// })); +/// }); /// } /// ``` pub struct ClientRequest { @@ -158,7 +158,7 @@ impl ClientRequest { /// /// ```rust /// fn main() { - /// # actix_rt::System::new("test").block_on(futures::future::lazy(|| { + /// # actix_rt::System::new("test").block_on(futures::future::lazy(|_| { /// let req = awc::Client::new() /// .get("http://www.rust-lang.org") /// .set(awc::http::header::Date::now()) @@ -186,13 +186,13 @@ impl ClientRequest { /// use awc::{http, Client}; /// /// fn main() { - /// # actix_rt::System::new("test").block_on(futures::future::lazy(|| { + /// # actix_rt::System::new("test").block_on(async { /// let req = Client::new() /// .get("http://www.rust-lang.org") /// .header("X-TEST", "value") /// .header(http::header::CONTENT_TYPE, "application/json"); /// # Ok::<_, ()>(()) - /// # })); + /// # }); /// } /// ``` pub fn header(mut self, key: K, value: V) -> Self @@ -309,9 +309,8 @@ impl ClientRequest { /// /// ```rust /// # use actix_rt::System; - /// # use futures::future::{lazy, Future}; /// fn main() { - /// System::new("test").block_on(lazy(|| { + /// System::new("test").block_on(async { /// awc::Client::new().get("https://www.rust-lang.org") /// .cookie( /// awc::http::Cookie::build("name", "value") @@ -322,12 +321,12 @@ impl ClientRequest { /// .finish(), /// ) /// .send() - /// .map_err(|_| ()) + /// .await /// .and_then(|response| { /// println!("Response: {:?}", response); /// Ok(()) /// }) - /// })); + /// }); /// } /// ``` pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { @@ -382,6 +381,27 @@ impl ClientRequest { } } + /// Sets the query part of the request + pub fn query( + mut self, + query: &T, + ) -> Result { + let mut parts = self.head.uri.clone().into_parts(); + + if let Some(path_and_query) = parts.path_and_query { + let query = serde_urlencoded::to_string(query)?; + let path = path_and_query.path(); + parts.path_and_query = format!("{}?{}", path, query).parse().ok(); + + match Uri::from_parts(parts) { + Ok(uri) => self.head.uri = uri, + Err(e) => self.err = Some(e.into()), + } + } + + Ok(self) + } + /// Freeze request builder and construct `FrozenClientRequest`, /// which could be used for sending same request multiple times. pub fn freeze(self) -> Result { @@ -457,7 +477,7 @@ impl ClientRequest { /// Set an streaming body and generate `ClientRequest`. pub fn send_stream(self, stream: S) -> SendClientRequest where - S: Stream + 'static, + S: Stream> + Unpin + 'static, E: Into + 'static, { let slf = match self.prep_for_sending() { @@ -690,4 +710,13 @@ mod tests { "Bearer someS3cr3tAutht0k3n" ); } + + #[test] + fn client_query() { + let req = Client::new() + .get("/") + .query(&[("key1", "val1"), ("key2", "val2")]) + .unwrap(); + assert_eq!(req.get_uri().query().unwrap(), "key1=val1&key2=val2"); + } } diff --git a/awc/src/response.rs b/awc/src/response.rs index d186526de..5ef8e18b5 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -1,9 +1,11 @@ use std::cell::{Ref, RefMut}; use std::fmt; use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use bytes::{Bytes, BytesMut}; -use futures::{Async, Future, Poll, Stream}; +use futures::{ready, Future, Stream}; use actix_http::cookie::Cookie; use actix_http::error::{CookieParseError, PayloadError}; @@ -104,7 +106,7 @@ impl ClientResponse { impl ClientResponse where - S: Stream, + S: Stream>, { /// Loads http response's body. pub fn body(&mut self) -> MessageBody { @@ -125,13 +127,12 @@ where impl Stream for ClientResponse where - S: Stream, + S: Stream> + Unpin, { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - self.payload.poll() + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.get_mut().payload).poll_next(cx) } } @@ -155,7 +156,7 @@ pub struct MessageBody { impl MessageBody where - S: Stream, + S: Stream>, { /// Create `MessageBody` for request. pub fn new(res: &mut ClientResponse) -> MessageBody { @@ -198,23 +199,24 @@ where impl Future for MessageBody where - S: Stream, + S: Stream> + Unpin, { - type Item = Bytes; - type Error = PayloadError; + type Output = Result; - fn poll(&mut self) -> Poll { - if let Some(err) = self.err.take() { - return Err(err); + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + if let Some(err) = this.err.take() { + return Poll::Ready(Err(err)); } - if let Some(len) = self.length.take() { - if len > self.fut.as_ref().unwrap().limit { - return Err(PayloadError::Overflow); + if let Some(len) = this.length.take() { + if len > this.fut.as_ref().unwrap().limit { + return Poll::Ready(Err(PayloadError::Overflow)); } } - self.fut.as_mut().unwrap().poll() + Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) } } @@ -233,7 +235,7 @@ pub struct JsonBody { impl JsonBody where - S: Stream, + S: Stream>, U: DeserializeOwned, { /// Create `JsonBody` for request. @@ -279,27 +281,35 @@ where } } -impl Future for JsonBody +impl Unpin for JsonBody where - T: Stream, + T: Stream> + Unpin, U: DeserializeOwned, { - type Item = U; - type Error = JsonPayloadError; +} - fn poll(&mut self) -> Poll { +impl Future for JsonBody +where + T: Stream> + Unpin, + U: DeserializeOwned, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { if let Some(err) = self.err.take() { - return Err(err); + return Poll::Ready(Err(err)); } if let Some(len) = self.length.take() { if len > self.fut.as_ref().unwrap().limit { - return Err(JsonPayloadError::Payload(PayloadError::Overflow)); + return Poll::Ready(Err(JsonPayloadError::Payload( + PayloadError::Overflow, + ))); } } - let body = futures::try_ready!(self.fut.as_mut().unwrap().poll()); - Ok(Async::Ready(serde_json::from_slice::(&body)?)) + let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; + Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) } } @@ -321,24 +331,25 @@ impl ReadBody { impl Future for ReadBody where - S: Stream, + S: Stream> + Unpin, { - type Item = Bytes; - type Error = PayloadError; + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); - fn poll(&mut self) -> Poll { loop { - return match self.stream.poll()? { - Async::Ready(Some(chunk)) => { - if (self.buf.len() + chunk.len()) > self.limit { - Err(PayloadError::Overflow) + return match Pin::new(&mut this.stream).poll_next(cx)? { + Poll::Ready(Some(chunk)) => { + if (this.buf.len() + chunk.len()) > this.limit { + Poll::Ready(Err(PayloadError::Overflow)) } else { - self.buf.extend_from_slice(&chunk); + this.buf.extend_from_slice(&chunk); continue; } } - Async::Ready(None) => Ok(Async::Ready(self.buf.take().freeze())), - Async::NotReady => Ok(Async::NotReady), + Poll::Ready(None) => Poll::Ready(Ok(this.buf.take().freeze())), + Poll::Pending => Poll::Pending, }; } } @@ -348,41 +359,40 @@ where mod tests { use super::*; use actix_http_test::block_on; - use futures::Async; use serde::{Deserialize, Serialize}; use crate::{http::header, test::TestResponse}; #[test] fn test_body() { - let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish(); - match req.body().poll().err().unwrap() { - PayloadError::UnknownLength => (), - _ => unreachable!("error"), - } + block_on(async { + let mut req = + TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish(); + match req.body().await.err().unwrap() { + PayloadError::UnknownLength => (), + _ => unreachable!("error"), + } - let mut req = - TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish(); - match req.body().poll().err().unwrap() { - PayloadError::Overflow => (), - _ => unreachable!("error"), - } + let mut req = + TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish(); + match req.body().await.err().unwrap() { + PayloadError::Overflow => (), + _ => unreachable!("error"), + } - let mut req = TestResponse::default() - .set_payload(Bytes::from_static(b"test")) - .finish(); - match req.body().poll().ok().unwrap() { - Async::Ready(bytes) => assert_eq!(bytes, Bytes::from_static(b"test")), - _ => unreachable!("error"), - } + let mut req = TestResponse::default() + .set_payload(Bytes::from_static(b"test")) + .finish(); + assert_eq!(req.body().await.ok().unwrap(), Bytes::from_static(b"test")); - let mut req = TestResponse::default() - .set_payload(Bytes::from_static(b"11111111111111")) - .finish(); - match req.body().limit(5).poll().err().unwrap() { - PayloadError::Overflow => (), - _ => unreachable!("error"), - } + let mut req = TestResponse::default() + .set_payload(Bytes::from_static(b"11111111111111")) + .finish(); + match req.body().limit(5).await.err().unwrap() { + PayloadError::Overflow => (), + _ => unreachable!("error"), + } + }) } #[derive(Serialize, Deserialize, PartialEq, Debug)] @@ -406,54 +416,56 @@ mod tests { #[test] fn test_json_body() { - let mut req = TestResponse::default().finish(); - let json = block_on(JsonBody::<_, MyObject>::new(&mut req)); - assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + block_on(async { + let mut req = TestResponse::default().finish(); + let json = JsonBody::<_, MyObject>::new(&mut req).await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); - let mut req = TestResponse::default() - .header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/text"), - ) - .finish(); - let json = block_on(JsonBody::<_, MyObject>::new(&mut req)); - assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + let mut req = TestResponse::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/text"), + ) + .finish(); + let json = JsonBody::<_, MyObject>::new(&mut req).await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); - let mut req = TestResponse::default() - .header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("10000"), - ) - .finish(); + let mut req = TestResponse::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("10000"), + ) + .finish(); - let json = block_on(JsonBody::<_, MyObject>::new(&mut req).limit(100)); - assert!(json_eq( - json.err().unwrap(), - JsonPayloadError::Payload(PayloadError::Overflow) - )); + let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await; + assert!(json_eq( + json.err().unwrap(), + JsonPayloadError::Payload(PayloadError::Overflow) + )); - let mut req = TestResponse::default() - .header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .finish(); + let mut req = TestResponse::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .finish(); - let json = block_on(JsonBody::<_, MyObject>::new(&mut req)); - assert_eq!( - json.ok().unwrap(), - MyObject { - name: "test".to_owned() - } - ); + let json = JsonBody::<_, MyObject>::new(&mut req).await; + assert_eq!( + json.ok().unwrap(), + MyObject { + name: "test".to_owned() + } + ); + }) } } diff --git a/awc/src/sender.rs b/awc/src/sender.rs index 95109b92d..f7461113a 100644 --- a/awc/src/sender.rs +++ b/awc/src/sender.rs @@ -1,13 +1,15 @@ use std::net; +use std::pin::Pin; use std::rc::Rc; -use std::time::{Duration, Instant}; +use std::task::{Context, Poll}; +use std::time::Duration; use bytes::Bytes; use derive_more::From; -use futures::{try_ready, Async, Future, Poll, Stream}; +use futures::{future::LocalBoxFuture, ready, Future, Stream}; use serde::Serialize; use serde_json; -use tokio_timer::Delay; +use tokio_timer::{delay_for, Delay}; use actix_http::body::{Body, BodyStream}; use actix_http::encoding::Decoder; @@ -47,7 +49,7 @@ impl Into for PrepForSendingError { #[must_use = "futures do nothing unless polled"] pub enum SendClientRequest { Fut( - Box>, + LocalBoxFuture<'static, Result>, Option, bool, ), @@ -56,41 +58,51 @@ pub enum SendClientRequest { impl SendClientRequest { pub(crate) fn new( - send: Box>, + send: LocalBoxFuture<'static, Result>, response_decompress: bool, timeout: Option, ) -> SendClientRequest { - let delay = timeout.map(|t| Delay::new(Instant::now() + t)); + let delay = timeout.map(|t| delay_for(t)); SendClientRequest::Fut(send, delay, response_decompress) } } impl Future for SendClientRequest { - type Item = ClientResponse>>; - type Error = SendRequestError; + type Output = + Result>>, SendRequestError>; - fn poll(&mut self) -> Poll { - match self { + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + match this { SendClientRequest::Fut(send, delay, response_decompress) => { if delay.is_some() { - match delay.poll() { - Ok(Async::NotReady) => (), - _ => return Err(SendRequestError::Timeout), + match Pin::new(delay.as_mut().unwrap()).poll(cx) { + Poll::Pending => (), + _ => return Poll::Ready(Err(SendRequestError::Timeout)), } } - let res = try_ready!(send.poll()).map_body(|head, payload| { - if *response_decompress { - Payload::Stream(Decoder::from_headers(payload, &head.headers)) - } else { - Payload::Stream(Decoder::new(payload, ContentEncoding::Identity)) - } + let res = ready!(Pin::new(send).poll(cx)).map(|res| { + res.map_body(|head, payload| { + if *response_decompress { + Payload::Stream(Decoder::from_headers( + payload, + &head.headers, + )) + } else { + Payload::Stream(Decoder::new( + payload, + ContentEncoding::Identity, + )) + } + }) }); - Ok(Async::Ready(res)) + Poll::Ready(res) } SendClientRequest::Err(ref mut e) => match e.take() { - Some(e) => Err(e), + Some(e) => Poll::Ready(Err(e)), None => panic!("Attempting to call completed future"), }, } @@ -223,7 +235,7 @@ impl RequestSender { stream: S, ) -> SendClientRequest where - S: Stream + 'static, + S: Stream> + Unpin + 'static, E: Into + 'static, { self.send_body( diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 39111ca69..ccfa8ebd9 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -7,7 +7,6 @@ use std::{fmt, str}; use actix_codec::Framed; use actix_http::cookie::{Cookie, CookieJar}; use actix_http::{ws, Payload, RequestHead}; -use futures::future::{err, Either, Future}; use percent_encoding::percent_encode; use tokio_timer::Timeout; @@ -210,27 +209,26 @@ impl WebsocketsRequest { } /// Complete request construction and connect to a websockets server. - pub fn connect( + pub async fn connect( mut self, - ) -> impl Future), Error = WsClientError> - { + ) -> Result<(ClientResponse, Framed), WsClientError> { if let Some(e) = self.err.take() { - return Either::A(err(e.into())); + return Err(e.into()); } // validate uri let uri = &self.head.uri; if uri.host().is_none() { - return Either::A(err(InvalidUrl::MissingHost.into())); + return Err(InvalidUrl::MissingHost.into()); } else if uri.scheme_part().is_none() { - return Either::A(err(InvalidUrl::MissingScheme.into())); + return Err(InvalidUrl::MissingScheme.into()); } else if let Some(scheme) = uri.scheme_part() { match scheme.as_str() { "http" | "ws" | "https" | "wss" => (), - _ => return Either::A(err(InvalidUrl::UnknownScheme.into())), + _ => return Err(InvalidUrl::UnknownScheme.into()), } } else { - return Either::A(err(InvalidUrl::UnknownScheme.into())); + return Err(InvalidUrl::UnknownScheme.into()); } if !self.head.headers.contains_key(header::HOST) { @@ -315,90 +313,83 @@ impl WebsocketsRequest { .config .connector .borrow_mut() - .open_tunnel(head, self.addr) - .from_err() - .and_then(move |(head, framed)| { - // verify response - if head.status != StatusCode::SWITCHING_PROTOCOLS { - return Err(WsClientError::InvalidResponseStatus(head.status)); - } - // Check for "UPGRADE" to websocket header - let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) { - if let Ok(s) = hdr.to_str() { - s.to_ascii_lowercase().contains("websocket") - } else { - false - } - } else { - false - }; - if !has_hdr { - log::trace!("Invalid upgrade header"); - return Err(WsClientError::InvalidUpgradeHeader); - } - // Check for "CONNECTION" header - if let Some(conn) = head.headers.get(&header::CONNECTION) { - if let Ok(s) = conn.to_str() { - if !s.to_ascii_lowercase().contains("upgrade") { - log::trace!("Invalid connection header: {}", s); - return Err(WsClientError::InvalidConnectionHeader( - conn.clone(), - )); - } - } else { - log::trace!("Invalid connection header: {:?}", conn); - return Err(WsClientError::InvalidConnectionHeader( - conn.clone(), - )); - } - } else { - log::trace!("Missing connection header"); - return Err(WsClientError::MissingConnectionHeader); - } - - if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) { - let encoded = ws::hash_key(key.as_ref()); - if hdr_key.as_bytes() != encoded.as_bytes() { - log::trace!( - "Invalid challenge response: expected: {} received: {:?}", - encoded, - key - ); - return Err(WsClientError::InvalidChallengeResponse( - encoded, - hdr_key.clone(), - )); - } - } else { - log::trace!("Missing SEC-WEBSOCKET-ACCEPT header"); - return Err(WsClientError::MissingWebSocketAcceptHeader); - }; - - // response and ws framed - Ok(( - ClientResponse::new(head, Payload::None), - framed.map_codec(|_| { - if server_mode { - ws::Codec::new().max_size(max_size) - } else { - ws::Codec::new().max_size(max_size).client_mode() - } - }), - )) - }); + .open_tunnel(head, self.addr); // set request timeout - if let Some(timeout) = self.config.timeout { - Either::B(Either::A(Timeout::new(fut, timeout).map_err(|e| { - if let Some(e) = e.into_inner() { - e - } else { - SendRequestError::Timeout.into() - } - }))) + let (head, framed) = if let Some(timeout) = self.config.timeout { + Timeout::new(fut, timeout) + .await + .map_err(|_| SendRequestError::Timeout.into()) + .and_then(|res| res)? } else { - Either::B(Either::B(fut)) + fut.await? + }; + + // verify response + if head.status != StatusCode::SWITCHING_PROTOCOLS { + return Err(WsClientError::InvalidResponseStatus(head.status)); } + + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + log::trace!("Invalid upgrade header"); + return Err(WsClientError::InvalidUpgradeHeader); + } + + // Check for "CONNECTION" header + if let Some(conn) = head.headers.get(&header::CONNECTION) { + if let Ok(s) = conn.to_str() { + if !s.to_ascii_lowercase().contains("upgrade") { + log::trace!("Invalid connection header: {}", s); + return Err(WsClientError::InvalidConnectionHeader(conn.clone())); + } + } else { + log::trace!("Invalid connection header: {:?}", conn); + return Err(WsClientError::InvalidConnectionHeader(conn.clone())); + } + } else { + log::trace!("Missing connection header"); + return Err(WsClientError::MissingConnectionHeader); + } + + if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) { + let encoded = ws::hash_key(key.as_ref()); + if hdr_key.as_bytes() != encoded.as_bytes() { + log::trace!( + "Invalid challenge response: expected: {} received: {:?}", + encoded, + key + ); + return Err(WsClientError::InvalidChallengeResponse( + encoded, + hdr_key.clone(), + )); + } + } else { + log::trace!("Missing SEC-WEBSOCKET-ACCEPT header"); + return Err(WsClientError::MissingWebSocketAcceptHeader); + }; + + // response and ws framed + Ok(( + ClientResponse::new(head, Payload::None), + framed.map_codec(|_| { + if server_mode { + ws::Codec::new().max_size(max_size) + } else { + ws::Codec::new().max_size(max_size).client_mode() + } + }), + )) } } @@ -419,6 +410,8 @@ impl fmt::Debug for WebsocketsRequest { #[cfg(test)] mod tests { + use actix_web::test::block_on; + use super::*; use crate::Client; @@ -493,35 +486,33 @@ mod tests { #[test] fn basics() { - let req = Client::new() - .ws("http://localhost/") - .origin("test-origin") - .max_frame_size(100) - .server_mode() - .protocols(&["v1", "v2"]) - .set_header_if_none(header::CONTENT_TYPE, "json") - .set_header_if_none(header::CONTENT_TYPE, "text") - .cookie(Cookie::build("cookie1", "value1").finish()); - assert_eq!( - req.origin.as_ref().unwrap().to_str().unwrap(), - "test-origin" - ); - assert_eq!(req.max_size, 100); - assert_eq!(req.server_mode, true); - assert_eq!(req.protocols, Some("v1,v2".to_string())); - assert_eq!( - req.head.headers.get(header::CONTENT_TYPE).unwrap(), - header::HeaderValue::from_static("json") - ); + block_on(async { + let req = Client::new() + .ws("http://localhost/") + .origin("test-origin") + .max_frame_size(100) + .server_mode() + .protocols(&["v1", "v2"]) + .set_header_if_none(header::CONTENT_TYPE, "json") + .set_header_if_none(header::CONTENT_TYPE, "text") + .cookie(Cookie::build("cookie1", "value1").finish()); + assert_eq!( + req.origin.as_ref().unwrap().to_str().unwrap(), + "test-origin" + ); + assert_eq!(req.max_size, 100); + assert_eq!(req.server_mode, true); + assert_eq!(req.protocols, Some("v1,v2".to_string())); + assert_eq!( + req.head.headers.get(header::CONTENT_TYPE).unwrap(), + header::HeaderValue::from_static("json") + ); - let _ = actix_http_test::block_fn(move || req.connect()); + let _ = req.connect().await; - assert!(Client::new().ws("/").connect().poll().is_err()); - assert!(Client::new().ws("http:///test").connect().poll().is_err()); - assert!(Client::new() - .ws("hmm://test.com/") - .connect() - .poll() - .is_err()); + assert!(Client::new().ws("/").connect().await.is_err()); + assert!(Client::new().ws("http:///test").connect().await.is_err()); + assert!(Client::new().ws("hmm://test.com/").connect().await.is_err()); + }) } } diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs index cb38c7315..9e1948f79 100644 --- a/awc/tests/test_client.rs +++ b/awc/tests/test_client.rs @@ -9,12 +9,12 @@ use bytes::Bytes; use flate2::read::GzDecoder; use flate2::write::GzEncoder; use flate2::Compression; -use futures::Future; +use futures::future::ok; use rand::Rng; use actix_http::HttpService; -use actix_http_test::TestServer; -use actix_service::{service_fn, NewService}; +use actix_http_test::{block_on, TestServer}; +use actix_service::pipeline_factory; use actix_web::http::Cookie; use actix_web::middleware::{BodyEncoding, Compress}; use actix_web::{http::header, web, App, Error, HttpMessage, HttpRequest, HttpResponse}; @@ -44,459 +44,508 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ #[test] fn test_simple() { - let mut srv = - TestServer::new(|| { + block_on(async { + let srv = TestServer::start(|| { HttpService::new(App::new().service( web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR))), )) }); - let request = srv.get("/").header("x-test", "111").send(); - let mut response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + let request = srv.get("/").header("x-test", "111").send(); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); - let mut response = srv.block_on(srv.post("/").send()).unwrap(); - assert!(response.status().is_success()); + let mut response = srv.post("/").send().await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); - // camel case - let response = srv.block_on(srv.post("/").camel_case().send()).unwrap(); - assert!(response.status().is_success()); + // camel case + let response = srv.post("/").camel_case().send().await.unwrap(); + assert!(response.status().is_success()); + }) } #[test] fn test_json() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service( - web::resource("/").route(web::to(|_: web::Json| HttpResponse::Ok())), - )) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new( + App::new().service( + web::resource("/") + .route(web::to(|_: web::Json| HttpResponse::Ok())), + ), + ) + }); - let request = srv - .get("/") - .header("x-test", "111") - .send_json(&"TEST".to_string()); - let response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + let request = srv + .get("/") + .header("x-test", "111") + .send_json(&"TEST".to_string()); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + }) } #[test] fn test_form() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to( - |_: web::Form>| HttpResponse::Ok(), - )))) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to( + |_: web::Form>| HttpResponse::Ok(), + )))) + }); - let mut data = HashMap::new(); - let _ = data.insert("key".to_string(), "TEST".to_string()); + let mut data = HashMap::new(); + let _ = data.insert("key".to_string(), "TEST".to_string()); - let request = srv.get("/").header("x-test", "111").send_form(&data); - let response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + let request = srv.get("/").header("x-test", "111").send_form(&data); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + }) } #[test] fn test_timeout() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to_async( - || { - tokio_timer::sleep(Duration::from_millis(200)) - .then(|_| Ok::<_, Error>(HttpResponse::Ok().body(STR))) - }, - )))) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to( + || { + async { + tokio_timer::delay_for(Duration::from_millis(200)).await; + Ok::<_, Error>(HttpResponse::Ok().body(STR)) + } + }, + )))) + }); - let client = srv.execute(|| { - awc::Client::build() + let connector = awc::Connector::new() + .connector(actix_connect::new_connector( + actix_connect::start_default_resolver(), + )) + .timeout(Duration::from_secs(15)) + .finish(); + + let client = awc::Client::build() + .connector(connector) .timeout(Duration::from_millis(50)) - .finish() - }); - let request = client.get(srv.url("/")).send(); - match srv.block_on(request) { - Err(SendRequestError::Timeout) => (), - _ => panic!(), - } + .finish(); + + let request = client.get(srv.url("/")).send(); + match request.await { + Err(SendRequestError::Timeout) => (), + _ => panic!(), + } + }) } #[test] fn test_timeout_override() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to_async( - || { - tokio_timer::sleep(Duration::from_millis(200)) - .then(|_| Ok::<_, Error>(HttpResponse::Ok().body(STR))) - }, - )))) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to( + || { + async { + tokio_timer::delay_for(Duration::from_millis(200)).await; + Ok::<_, Error>(HttpResponse::Ok().body(STR)) + } + }, + )))) + }); - let client = awc::Client::build() - .timeout(Duration::from_millis(50000)) - .finish(); - let request = client - .get(srv.url("/")) - .timeout(Duration::from_millis(50)) - .send(); - match srv.block_on(request) { - Err(SendRequestError::Timeout) => (), - _ => panic!(), - } + let client = awc::Client::build() + .timeout(Duration::from_millis(50000)) + .finish(); + let request = client + .get(srv.url("/")) + .timeout(Duration::from_millis(50)) + .send(); + match request.await { + Err(SendRequestError::Timeout) => (), + _ => panic!(), + } + }) } #[test] fn test_connection_reuse() { - let num = Arc::new(AtomicUsize::new(0)); - let num2 = num.clone(); + block_on(async { + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); - let mut srv = TestServer::new(move || { - let num2 = num2.clone(); - service_fn(move |io| { - num2.fetch_add(1, Ordering::Relaxed); - Ok(io) - }) - .and_then(HttpService::new( - App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok()))), - )) - }); + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then(HttpService::new(App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok())), + ))) + }); - let client = awc::Client::default(); + let client = awc::Client::default(); - // req 1 - let request = client.get(srv.url("/")).send(); - let response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + // req 1 + let request = client.get(srv.url("/")).send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); - // req 2 - let req = client.post(srv.url("/")); - let response = srv.block_on_fn(move || req.send()).unwrap(); - assert!(response.status().is_success()); + // req 2 + let req = client.post(srv.url("/")); + let response = req.send().await.unwrap(); + assert!(response.status().is_success()); - // one connection - assert_eq!(num.load(Ordering::Relaxed), 1); + // one connection + assert_eq!(num.load(Ordering::Relaxed), 1); + }) } #[test] fn test_connection_force_close() { - let num = Arc::new(AtomicUsize::new(0)); - let num2 = num.clone(); + block_on(async { + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); - let mut srv = TestServer::new(move || { - let num2 = num2.clone(); - service_fn(move |io| { - num2.fetch_add(1, Ordering::Relaxed); - Ok(io) - }) - .and_then(HttpService::new( - App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok()))), - )) - }); + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then(HttpService::new(App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok())), + ))) + }); - let client = awc::Client::default(); + let client = awc::Client::default(); - // req 1 - let request = client.get(srv.url("/")).force_close().send(); - let response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + // req 1 + let request = client.get(srv.url("/")).force_close().send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); - // req 2 - let req = client.post(srv.url("/")).force_close(); - let response = srv.block_on_fn(move || req.send()).unwrap(); - assert!(response.status().is_success()); + // req 2 + let req = client.post(srv.url("/")).force_close(); + let response = req.send().await.unwrap(); + assert!(response.status().is_success()); - // two connection - assert_eq!(num.load(Ordering::Relaxed), 2); + // two connection + assert_eq!(num.load(Ordering::Relaxed), 2); + }) } #[test] fn test_connection_server_close() { - let num = Arc::new(AtomicUsize::new(0)); - let num2 = num.clone(); + block_on(async { + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); - let mut srv = TestServer::new(move || { - let num2 = num2.clone(); - service_fn(move |io| { - num2.fetch_add(1, Ordering::Relaxed); - Ok(io) - }) - .and_then(HttpService::new( - App::new().service( - web::resource("/") - .route(web::to(|| HttpResponse::Ok().force_close().finish())), - ), - )) - }); + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then(HttpService::new( + App::new().service( + web::resource("/") + .route(web::to(|| HttpResponse::Ok().force_close().finish())), + ), + )) + }); - let client = awc::Client::default(); + let client = awc::Client::default(); - // req 1 - let request = client.get(srv.url("/")).send(); - let response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + // req 1 + let request = client.get(srv.url("/")).send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); - // req 2 - let req = client.post(srv.url("/")); - let response = srv.block_on_fn(move || req.send()).unwrap(); - assert!(response.status().is_success()); + // req 2 + let req = client.post(srv.url("/")); + let response = req.send().await.unwrap(); + assert!(response.status().is_success()); - // two connection - assert_eq!(num.load(Ordering::Relaxed), 2); + // two connection + assert_eq!(num.load(Ordering::Relaxed), 2); + }) } #[test] fn test_connection_wait_queue() { - let num = Arc::new(AtomicUsize::new(0)); - let num2 = num.clone(); + block_on(async { + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); - let mut srv = TestServer::new(move || { - let num2 = num2.clone(); - service_fn(move |io| { - num2.fetch_add(1, Ordering::Relaxed); - Ok(io) - }) - .and_then(HttpService::new(App::new().service( - web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR))), - ))) - }); + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then(HttpService::new(App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR))), + ))) + }); - let client = awc::Client::build() - .connector(awc::Connector::new().limit(1).finish()) - .finish(); + let client = awc::Client::build() + .connector(awc::Connector::new().limit(1).finish()) + .finish(); - // req 1 - let request = client.get(srv.url("/")).send(); - let mut response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + // req 1 + let request = client.get(srv.url("/")).send(); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - // req 2 - let req2 = client.post(srv.url("/")); - let req2_fut = srv.execute(move || { - let mut fut = req2.send(); - assert!(fut.poll().unwrap().is_not_ready()); - fut - }); + // req 2 + let req2 = client.post(srv.url("/")); + let req2_fut = req2.send(); - // read response 1 - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response 1 + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); - // req 2 - let response = srv.block_on(req2_fut).unwrap(); - assert!(response.status().is_success()); + // req 2 + let response = req2_fut.await.unwrap(); + assert!(response.status().is_success()); - // two connection - assert_eq!(num.load(Ordering::Relaxed), 1); + // two connection + assert_eq!(num.load(Ordering::Relaxed), 1); + }) } #[test] fn test_connection_wait_queue_force_close() { - let num = Arc::new(AtomicUsize::new(0)); - let num2 = num.clone(); + block_on(async { + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); - let mut srv = TestServer::new(move || { - let num2 = num2.clone(); - service_fn(move |io| { - num2.fetch_add(1, Ordering::Relaxed); - Ok(io) - }) - .and_then(HttpService::new( - App::new().service( - web::resource("/") - .route(web::to(|| HttpResponse::Ok().force_close().body(STR))), - ), - )) - }); + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then(HttpService::new( + App::new().service( + web::resource("/") + .route(web::to(|| HttpResponse::Ok().force_close().body(STR))), + ), + )) + }); - let client = awc::Client::build() - .connector(awc::Connector::new().limit(1).finish()) - .finish(); + let client = awc::Client::build() + .connector(awc::Connector::new().limit(1).finish()) + .finish(); - // req 1 - let request = client.get(srv.url("/")).send(); - let mut response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + // req 1 + let request = client.get(srv.url("/")).send(); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - // req 2 - let req2 = client.post(srv.url("/")); - let req2_fut = srv.execute(move || { - let mut fut = req2.send(); - assert!(fut.poll().unwrap().is_not_ready()); - fut - }); + // req 2 + let req2 = client.post(srv.url("/")); + let req2_fut = req2.send(); - // read response 1 - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response 1 + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); - // req 2 - let response = srv.block_on(req2_fut).unwrap(); - assert!(response.status().is_success()); + // req 2 + let response = req2_fut.await.unwrap(); + assert!(response.status().is_success()); - // two connection - assert_eq!(num.load(Ordering::Relaxed), 2); + // two connection + assert_eq!(num.load(Ordering::Relaxed), 2); + }) } #[test] fn test_with_query_parameter() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").to( - |req: HttpRequest| { - if req.query_string().contains("qp") { - HttpResponse::Ok() - } else { - HttpResponse::BadRequest() - } - }, - ))) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").to( + |req: HttpRequest| { + if req.query_string().contains("qp") { + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + }, + ))) + }); - let res = srv - .block_on(awc::Client::new().get(srv.url("/?qp=5")).send()) - .unwrap(); - assert!(res.status().is_success()); + let res = awc::Client::new() + .get(srv.url("/?qp=5")) + .send() + .await + .unwrap(); + assert!(res.status().is_success()); + }) } #[test] fn test_no_decompress() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().wrap(Compress::default()).service( - web::resource("/").route(web::to(|| { - let mut res = HttpResponse::Ok().body(STR); - res.encoding(header::ContentEncoding::Gzip); - res - })), - )) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new(App::new().wrap(Compress::default()).service( + web::resource("/").route(web::to(|| { + let mut res = HttpResponse::Ok().body(STR); + res.encoding(header::ContentEncoding::Gzip); + res + })), + )) + }); - let mut res = srv - .block_on(awc::Client::new().get(srv.url("/")).no_decompress().send()) - .unwrap(); - assert!(res.status().is_success()); + let mut res = awc::Client::new() + .get(srv.url("/")) + .no_decompress() + .send() + .await + .unwrap(); + assert!(res.status().is_success()); - // read response - let bytes = srv.block_on(res.body()).unwrap(); + // read response + let bytes = res.body().await.unwrap(); - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); - // POST - let mut res = srv - .block_on(awc::Client::new().post(srv.url("/")).no_decompress().send()) - .unwrap(); - assert!(res.status().is_success()); + // POST + let mut res = awc::Client::new() + .post(srv.url("/")) + .no_decompress() + .send() + .await + .unwrap(); + assert!(res.status().is_success()); - let bytes = srv.block_on(res.body()).unwrap(); - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + let bytes = res.body().await.unwrap(); + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + }) } #[test] fn test_client_gzip_encoding() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to(|| { - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(STR.as_ref()).unwrap(); - let data = e.finish().unwrap(); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to( + || { + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(STR.as_ref()).unwrap(); + let data = e.finish().unwrap(); - HttpResponse::Ok() - .header("content-encoding", "gzip") - .body(data) - })))) - }); + HttpResponse::Ok() + .header("content-encoding", "gzip") + .body(data) + }, + )))) + }); - // client request - let mut response = srv.block_on(srv.post("/").send()).unwrap(); - assert!(response.status().is_success()); + // client request + let mut response = srv.post("/").send().await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } #[test] fn test_client_gzip_encoding_large() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to(|| { - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(STR.repeat(10).as_ref()).unwrap(); - let data = e.finish().unwrap(); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to( + || { + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(STR.repeat(10).as_ref()).unwrap(); + let data = e.finish().unwrap(); - HttpResponse::Ok() - .header("content-encoding", "gzip") - .body(data) - })))) - }); + HttpResponse::Ok() + .header("content-encoding", "gzip") + .body(data) + }, + )))) + }); - // client request - let mut response = srv.block_on(srv.post("/").send()).unwrap(); - assert!(response.status().is_success()); + // client request + let mut response = srv.post("/").send().await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from(STR.repeat(10))); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from(STR.repeat(10))); + }) } #[test] fn test_client_gzip_encoding_large_random() { - let data = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(100_000) - .collect::(); + block_on(async { + let data = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(100_000) + .collect::(); - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to( - |data: Bytes| { - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(&data).unwrap(); - let data = e.finish().unwrap(); - HttpResponse::Ok() - .header("content-encoding", "gzip") - .body(data) - }, - )))) - }); + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to( + |data: Bytes| { + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(&data).unwrap(); + let data = e.finish().unwrap(); + HttpResponse::Ok() + .header("content-encoding", "gzip") + .body(data) + }, + )))) + }); - // client request - let mut response = srv.block_on(srv.post("/").send_body(data.clone())).unwrap(); - assert!(response.status().is_success()); + // client request + let mut response = srv.post("/").send_body(data.clone()).await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from(data)); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from(data)); + }) } #[test] fn test_client_brotli_encoding() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to( - |data: Bytes| { - let mut e = BrotliEncoder::new(Vec::new(), 5); - e.write_all(&data).unwrap(); - let data = e.finish().unwrap(); - HttpResponse::Ok() - .header("content-encoding", "br") - .body(data) - }, - )))) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to( + |data: Bytes| { + let mut e = BrotliEncoder::new(Vec::new(), 5); + e.write_all(&data).unwrap(); + let data = e.finish().unwrap(); + HttpResponse::Ok() + .header("content-encoding", "br") + .body(data) + }, + )))) + }); - // client request - let mut response = srv.block_on(srv.post("/").send_body(STR)).unwrap(); - assert!(response.status().is_success()); + // client request + let mut response = srv.post("/").send_body(STR).await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } // #[test] @@ -506,7 +555,7 @@ fn test_client_brotli_encoding() { // .take(70_000) // .collect::(); -// let mut srv = test::TestServer::new(|app| { +// let srv = test::TestServer::start(|app| { // app.handler(|req: &HttpRequest| { // req.body() // .and_then(move |bytes: Bytes| { @@ -524,11 +573,11 @@ fn test_client_brotli_encoding() { // .content_encoding(http::ContentEncoding::Br) // .body(data.clone()) // .unwrap(); -// let response = srv.execute(request.send()).unwrap(); +// let response = request.send().await.unwrap(); // assert!(response.status().is_success()); // // read response -// let bytes = srv.execute(response.body()).unwrap(); +// let bytes = response.body().await.unwrap(); // assert_eq!(bytes.len(), data.len()); // assert_eq!(bytes, Bytes::from(data)); // } @@ -536,7 +585,7 @@ fn test_client_brotli_encoding() { // #[cfg(feature = "brotli")] // #[test] // fn test_client_deflate_encoding() { -// let mut srv = test::TestServer::new(|app| { +// let srv = test::TestServer::start(|app| { // app.handler(|req: &HttpRequest| { // req.body() // .and_then(|bytes: Bytes| { @@ -569,7 +618,7 @@ fn test_client_brotli_encoding() { // .take(70_000) // .collect::(); -// let mut srv = test::TestServer::new(|app| { +// let srv = test::TestServer::start(|app| { // app.handler(|req: &HttpRequest| { // req.body() // .and_then(|bytes: Bytes| { @@ -597,7 +646,7 @@ fn test_client_brotli_encoding() { // #[test] // fn test_client_streaming_explicit() { -// let mut srv = test::TestServer::new(|app| { +// let srv = test::TestServer::start(|app| { // app.handler(|req: &HttpRequest| { // req.body() // .map_err(Error::from) @@ -624,7 +673,7 @@ fn test_client_brotli_encoding() { // #[test] // fn test_body_streaming_implicit() { -// let mut srv = test::TestServer::new(|app| { +// let srv = test::TestServer::start(|app| { // app.handler(|_| { // let body = once(Ok(Bytes::from_static(STR.as_ref()))); // HttpResponse::Ok() @@ -644,65 +693,78 @@ fn test_client_brotli_encoding() { #[test] fn test_client_cookie_handling() { - fn err() -> Error { - use std::io::{Error as IoError, ErrorKind}; - // stub some generic error - Error::from(IoError::from(ErrorKind::NotFound)) - } - let cookie1 = Cookie::build("cookie1", "value1").finish(); - let cookie2 = Cookie::build("cookie2", "value2") - .domain("www.example.org") - .path("/") - .secure(true) - .http_only(true) - .finish(); - // Q: are all these clones really necessary? A: Yes, possibly - let cookie1b = cookie1.clone(); - let cookie2b = cookie2.clone(); + use std::io::{Error as IoError, ErrorKind}; - let mut srv = TestServer::new(move || { - let cookie1 = cookie1b.clone(); - let cookie2 = cookie2b.clone(); + block_on(async { + let cookie1 = Cookie::build("cookie1", "value1").finish(); + let cookie2 = Cookie::build("cookie2", "value2") + .domain("www.example.org") + .path("/") + .secure(true) + .http_only(true) + .finish(); + // Q: are all these clones really necessary? A: Yes, possibly + let cookie1b = cookie1.clone(); + let cookie2b = cookie2.clone(); - HttpService::new(App::new().route( - "/", - web::to(move |req: HttpRequest| { - // Check cookies were sent correctly - req.cookie("cookie1") - .ok_or_else(err) - .and_then(|c1| { - if c1.value() == "value1" { - Ok(()) + let srv = TestServer::start(move || { + let cookie1 = cookie1b.clone(); + let cookie2 = cookie2b.clone(); + + HttpService::new(App::new().route( + "/", + web::to(move |req: HttpRequest| { + let cookie1 = cookie1.clone(); + let cookie2 = cookie2.clone(); + + async move { + // Check cookies were sent correctly + let res: Result<(), Error> = req + .cookie("cookie1") + .ok_or(()) + .and_then(|c1| { + if c1.value() == "value1" { + Ok(()) + } else { + Err(()) + } + }) + .and_then(|()| req.cookie("cookie2").ok_or(())) + .and_then(|c2| { + if c2.value() == "value2" { + Ok(()) + } else { + Err(()) + } + }) + .map_err(|_| { + Error::from(IoError::from(ErrorKind::NotFound)) + }); + + if let Err(e) = res { + Err(e) } else { - Err(err()) + // Send some cookies back + Ok::<_, Error>( + HttpResponse::Ok() + .cookie(cookie1) + .cookie(cookie2) + .finish(), + ) } - }) - .and_then(|()| req.cookie("cookie2").ok_or_else(err)) - .and_then(|c2| { - if c2.value() == "value2" { - Ok(()) - } else { - Err(err()) - } - }) - // Send some cookies back - .map(|_| { - HttpResponse::Ok() - .cookie(cookie1.clone()) - .cookie(cookie2.clone()) - .finish() - }) - }), - )) - }); + } + }), + )) + }); - let request = srv.get("/").cookie(cookie1.clone()).cookie(cookie2.clone()); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); - let c1 = response.cookie("cookie1").expect("Missing cookie1"); - assert_eq!(c1, cookie1); - let c2 = response.cookie("cookie2").expect("Missing cookie2"); - assert_eq!(c2, cookie2); + let request = srv.get("/").cookie(cookie1.clone()).cookie(cookie2.clone()); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + let c1 = response.cookie("cookie1").expect("Missing cookie1"); + assert_eq!(c1, cookie1); + let c2 = response.cookie("cookie2").expect("Missing cookie2"); + assert_eq!(c2, cookie2); + }) } // #[test] @@ -727,66 +789,70 @@ fn test_client_cookie_handling() { // let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()) // .finish() // .unwrap(); -// let response = sys.block_on(req.send()).unwrap(); +// let response = req.send().await.unwrap(); // assert!(response.status().is_success()); // // read response -// let bytes = sys.block_on(response.body()).unwrap(); +// let bytes = response.body().await.unwrap(); // assert_eq!(bytes, Bytes::from_static(b"welcome!")); // } #[test] fn client_basic_auth() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().route( - "/", - web::to(|req: HttpRequest| { - if req - .headers() - .get(header::AUTHORIZATION) - .unwrap() - .to_str() - .unwrap() - == "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" - { - HttpResponse::Ok() - } else { - HttpResponse::BadRequest() - } - }), - )) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new(App::new().route( + "/", + web::to(|req: HttpRequest| { + if req + .headers() + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap() + == "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" + { + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + }), + )) + }); - // set authorization header to Basic - let request = srv.get("/").basic_auth("username", Some("password")); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); + // set authorization header to Basic + let request = srv.get("/").basic_auth("username", Some("password")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + }) } #[test] fn client_bearer_auth() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().route( - "/", - web::to(|req: HttpRequest| { - if req - .headers() - .get(header::AUTHORIZATION) - .unwrap() - .to_str() - .unwrap() - == "Bearer someS3cr3tAutht0k3n" - { - HttpResponse::Ok() - } else { - HttpResponse::BadRequest() - } - }), - )) - }); + block_on(async { + let srv = TestServer::start(|| { + HttpService::new(App::new().route( + "/", + web::to(|req: HttpRequest| { + if req + .headers() + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap() + == "Bearer someS3cr3tAutht0k3n" + { + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + }), + )) + }); - // set authorization header to Bearer - let request = srv.get("/").bearer_auth("someS3cr3tAutht0k3n"); - let response = srv.block_on(request.send()).unwrap(); - assert!(response.status().is_success()); + // set authorization header to Bearer + let request = srv.get("/").bearer_auth("someS3cr3tAutht0k3n"); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + }) } diff --git a/awc/tests/test_rustls_client.rs b/awc/tests/test_rustls_client.rs index e65e4e874..bdfd21031 100644 --- a/awc/tests/test_rustls_client.rs +++ b/awc/tests/test_rustls_client.rs @@ -1,96 +1,109 @@ -#![cfg(feature = "rust-tls")] -use rustls::{ - internal::pemfile::{certs, pkcs8_private_keys}, - ClientConfig, NoClientAuth, -}; +#![cfg(feature = "rustls")] +use rust_tls::ClientConfig; -use std::fs::File; -use std::io::{BufReader, Result}; +use std::io::Result; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use actix_codec::{AsyncRead, AsyncWrite}; use actix_http::HttpService; -use actix_http_test::TestServer; -use actix_server::ssl::RustlsAcceptor; -use actix_service::{service_fn, NewService}; +use actix_http_test::{block_on, TestServer}; +use actix_server::ssl::OpensslAcceptor; +use actix_service::{pipeline_factory, ServiceFactory}; use actix_web::http::Version; use actix_web::{web, App, HttpResponse}; +use futures::future::ok; +use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslVerifyMode}; -fn ssl_acceptor() -> Result> { - use rustls::ServerConfig; +fn ssl_acceptor() -> Result> { // load ssl keys - let mut config = ServerConfig::new(NoClientAuth::new()); - let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap()); - let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap()); - let cert_chain = certs(cert_file).unwrap(); - let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - let protos = vec![b"h2".to_vec()]; - config.set_protocols(&protos); - Ok(RustlsAcceptor::new(config)) + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder.set_verify_callback(SslVerifyMode::NONE, |_, _| true); + builder + .set_private_key_file("../tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("../tests/cert.pem") + .unwrap(); + builder.set_alpn_select_callback(|_, protos| { + const H2: &[u8] = b"\x02h2"; + if protos.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else { + Err(open_ssl::ssl::AlpnError::NOACK) + } + }); + builder.set_alpn_protos(b"\x02h2")?; + Ok(actix_server::ssl::OpensslAcceptor::new(builder.build())) } mod danger { pub struct NoCertificateVerification {} - impl rustls::ServerCertVerifier for NoCertificateVerification { + impl rust_tls::ServerCertVerifier for NoCertificateVerification { fn verify_server_cert( &self, - _roots: &rustls::RootCertStore, - _presented_certs: &[rustls::Certificate], + _roots: &rust_tls::RootCertStore, + _presented_certs: &[rust_tls::Certificate], _dns_name: webpki::DNSNameRef<'_>, _ocsp: &[u8], - ) -> Result { - Ok(rustls::ServerCertVerified::assertion()) + ) -> Result { + Ok(rust_tls::ServerCertVerified::assertion()) } } } -#[test] -fn test_connection_reuse_h2() { - let rustls = ssl_acceptor().unwrap(); - let num = Arc::new(AtomicUsize::new(0)); - let num2 = num.clone(); +// #[test] +fn _test_connection_reuse_h2() { + block_on(async { + let openssl = ssl_acceptor().unwrap(); + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); - let mut srv = TestServer::new(move || { - let num2 = num2.clone(); - service_fn(move |io| { - num2.fetch_add(1, Ordering::Relaxed); - Ok(io) - }) - .and_then(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) - .and_then( - HttpService::build() - .h2(App::new() - .service(web::resource("/").route(web::to(|| HttpResponse::Ok())))) - .map_err(|_| ()), - ) - }); + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok())), + )) + .map_err(|_| ()), + ) + }); - // disable ssl verification - let mut config = ClientConfig::new(); - let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - config.set_protocols(&protos); - config - .dangerous() - .set_certificate_verifier(Arc::new(danger::NoCertificateVerification {})); + // disable ssl verification + let mut config = ClientConfig::new(); + let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + config.set_protocols(&protos); + config + .dangerous() + .set_certificate_verifier(Arc::new(danger::NoCertificateVerification {})); - let client = awc::Client::build() - .connector(awc::Connector::new().rustls(Arc::new(config)).finish()) - .finish(); + let client = awc::Client::build() + .connector(awc::Connector::new().rustls(Arc::new(config)).finish()) + .finish(); - // req 1 - let request = client.get(srv.surl("/")).send(); - let response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + // req 1 + let request = client.get(srv.surl("/")).send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); - // req 2 - let req = client.post(srv.surl("/")); - let response = srv.block_on_fn(move || req.send()).unwrap(); - assert!(response.status().is_success()); - assert_eq!(response.version(), Version::HTTP_2); + // req 2 + let req = client.post(srv.surl("/")); + let response = req.send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), Version::HTTP_2); - // one connection - assert_eq!(num.load(Ordering::Relaxed), 1); + // one connection + assert_eq!(num.load(Ordering::Relaxed), 1); + }) } diff --git a/awc/tests/test_ssl_client.rs b/awc/tests/test_ssl_client.rs index e6b0101b2..d37dba291 100644 --- a/awc/tests/test_ssl_client.rs +++ b/awc/tests/test_ssl_client.rs @@ -1,5 +1,5 @@ -#![cfg(feature = "ssl")] -use openssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode}; +#![cfg(feature = "openssl")] +use open_ssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode}; use std::io::Result; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -7,11 +7,12 @@ use std::sync::Arc; use actix_codec::{AsyncRead, AsyncWrite}; use actix_http::HttpService; -use actix_http_test::TestServer; +use actix_http_test::{block_on, TestServer}; use actix_server::ssl::OpensslAcceptor; -use actix_service::{service_fn, NewService}; +use actix_service::{pipeline_factory, ServiceFactory}; use actix_web::http::Version; use actix_web::{web, App, HttpResponse}; +use futures::future::ok; fn ssl_acceptor() -> Result> { // load ssl keys @@ -27,7 +28,7 @@ fn ssl_acceptor() -> Result> { if protos.windows(3).any(|window| window == H2) { Ok(b"h2") } else { - Err(openssl::ssl::AlpnError::NOACK) + Err(open_ssl::ssl::AlpnError::NOACK) } }); builder.set_alpn_protos(b"\x02h2")?; @@ -36,51 +37,54 @@ fn ssl_acceptor() -> Result> { #[test] fn test_connection_reuse_h2() { - let openssl = ssl_acceptor().unwrap(); - let num = Arc::new(AtomicUsize::new(0)); - let num2 = num.clone(); + block_on(async { + let openssl = ssl_acceptor().unwrap(); + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); - let mut srv = TestServer::new(move || { - let num2 = num2.clone(); - service_fn(move |io| { - num2.fetch_add(1, Ordering::Relaxed); - Ok(io) - }) - .and_then( - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)), - ) - .and_then( - HttpService::build() - .h2(App::new() - .service(web::resource("/").route(web::to(|| HttpResponse::Ok())))) - .map_err(|_| ()), - ) - }); + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok())), + )) + .map_err(|_| ()), + ) + }); - // disable ssl verification - let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - builder.set_verify(SslVerifyMode::NONE); - let _ = builder - .set_alpn_protos(b"\x02h2\x08http/1.1") - .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); + // disable ssl verification + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder + .set_alpn_protos(b"\x02h2\x08http/1.1") + .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); - let client = awc::Client::build() - .connector(awc::Connector::new().ssl(builder.build()).finish()) - .finish(); + let client = awc::Client::build() + .connector(awc::Connector::new().ssl(builder.build()).finish()) + .finish(); - // req 1 - let request = client.get(srv.surl("/")).send(); - let response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + // req 1 + let request = client.get(srv.surl("/")).send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); - // req 2 - let req = client.post(srv.surl("/")); - let response = srv.block_on_fn(move || req.send()).unwrap(); - assert!(response.status().is_success()); - assert_eq!(response.version(), Version::HTTP_2); + // req 2 + let req = client.post(srv.surl("/")); + let response = req.send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), Version::HTTP_2); - // one connection - assert_eq!(num.load(Ordering::Relaxed), 1); + // one connection + assert_eq!(num.load(Ordering::Relaxed), 1); + }) } diff --git a/awc/tests/test_ws.rs b/awc/tests/test_ws.rs index 5abf96355..633e8db51 100644 --- a/awc/tests/test_ws.rs +++ b/awc/tests/test_ws.rs @@ -2,81 +2,82 @@ use std::io; use actix_codec::Framed; use actix_http::{body::BodySize, h1, ws, Error, HttpService, Request, Response}; -use actix_http_test::TestServer; +use actix_http_test::{block_on, TestServer}; use bytes::{Bytes, BytesMut}; use futures::future::ok; -use futures::{Future, Sink, Stream}; +use futures::{SinkExt, StreamExt}; -fn ws_service(req: ws::Frame) -> impl Future { +async fn ws_service(req: ws::Frame) -> Result { match req { - ws::Frame::Ping(msg) => ok(ws::Message::Pong(msg)), + ws::Frame::Ping(msg) => Ok(ws::Message::Pong(msg)), ws::Frame::Text(text) => { let text = if let Some(pl) = text { String::from_utf8(Vec::from(pl.as_ref())).unwrap() } else { String::new() }; - ok(ws::Message::Text(text)) + Ok(ws::Message::Text(text)) } - ws::Frame::Binary(bin) => ok(ws::Message::Binary( + ws::Frame::Binary(bin) => Ok(ws::Message::Binary( bin.map(|e| e.freeze()) .unwrap_or_else(|| Bytes::from("")) .into(), )), - ws::Frame::Close(reason) => ok(ws::Message::Close(reason)), - _ => ok(ws::Message::Close(None)), + ws::Frame::Close(reason) => Ok(ws::Message::Close(reason)), + _ => Ok(ws::Message::Close(None)), } } #[test] fn test_simple() { - let mut srv = TestServer::new(|| { - HttpService::build() - .upgrade(|(req, framed): (Request, Framed<_, _>)| { - let res = ws::handshake_response(req.head()).finish(); - // send handshake response - framed - .send(h1::Message::Item((res.drop_body(), BodySize::None))) - .map_err(|e: io::Error| e.into()) - .and_then(|framed| { + block_on(async { + let mut srv = TestServer::start(|| { + HttpService::build() + .upgrade(|(req, mut framed): (Request, Framed<_, _>)| { + async move { + let res = ws::handshake_response(req.head()).finish(); + // send handshake response + framed + .send(h1::Message::Item((res.drop_body(), BodySize::None))) + .await?; + // start websocket service let framed = framed.into_framed(ws::Codec::new()); - ws::Transport::with(framed, ws_service) - }) - }) - .finish(|_| ok::<_, Error>(Response::NotFound())) - }); + ws::Transport::with(framed, ws_service).await + } + }) + .finish(|_| ok::<_, Error>(Response::NotFound())) + }); - // client service - let framed = srv.ws().unwrap(); - let framed = srv - .block_on(framed.send(ws::Message::Text("text".to_string()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); + // client service + let mut framed = srv.ws().await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await + .unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Text(Some(BytesMut::from("text")))); - let framed = srv - .block_on(framed.send(ws::Message::Binary("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!( - item, - Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) - ); + framed + .send(ws::Message::Binary("text".into())) + .await + .unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!( + item, + ws::Frame::Binary(Some(Bytes::from_static(b"text").into())) + ); - let framed = srv - .block_on(framed.send(ws::Message::Ping("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Pong("text".to_string().into())); - let framed = srv - .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) - .unwrap(); + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); - let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!( - item, - Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) - ); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into()))); + }) } diff --git a/examples/basic.rs b/examples/basic.rs index 46440d706..6d9a4dcd8 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -1,22 +1,18 @@ -use futures::IntoFuture; - -use actix_web::{ - get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer, -}; +use actix_web::{get, middleware, web, App, HttpRequest, HttpResponse, HttpServer}; #[get("/resource1/{name}/index.html")] -fn index(req: HttpRequest, name: web::Path) -> String { +async fn index(req: HttpRequest, name: web::Path) -> String { println!("REQ: {:?}", req); format!("Hello: {}!\r\n", name) } -fn index_async(req: HttpRequest) -> impl IntoFuture { +async fn index_async(req: HttpRequest) -> &'static str { println!("REQ: {:?}", req); - Ok("Hello world!\r\n") + "Hello world!\r\n" } #[get("/")] -fn no_params() -> &'static str { +async fn no_params() -> &'static str { "Hello world!\r\n" } @@ -39,9 +35,9 @@ fn main() -> std::io::Result<()> { .default_service( web::route().to(|| HttpResponse::MethodNotAllowed()), ) - .route(web::get().to_async(index_async)), + .route(web::get().to(index_async)), ) - .service(web::resource("/test1.html").to(|| "Test\r\n")) + .service(web::resource("/test1.html").to(|| async { "Test\r\n" })) }) .bind("127.0.0.1:8080")? .workers(1) diff --git a/examples/client.rs b/examples/client.rs index 8a75fd306..90a362fe3 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,26 +1,27 @@ use actix_http::Error; use actix_rt::System; -use futures::{future::lazy, Future}; fn main() -> Result<(), Error> { std::env::set_var("RUST_LOG", "actix_http=trace"); env_logger::init(); - System::new("test").block_on(lazy(|| { - awc::Client::new() - .get("https://www.rust-lang.org/") // <- Create request builder - .header("User-Agent", "Actix-web") - .send() // <- Send http request - .from_err() - .and_then(|mut response| { - // <- server http response - println!("Response: {:?}", response); + System::new("test").block_on(async { + let client = awc::Client::new(); - // read response body - response - .body() - .from_err() - .map(|body| println!("Downloaded: {:?} bytes", body.len())) - }) - })) + // Create request builder, configure request and send + let mut response = client + .get("https://www.rust-lang.org/") + .header("User-Agent", "Actix-web") + .send() + .await?; + + // server http response + println!("Response: {:?}", response); + + // read response body + let body = response.body().await?; + println!("Downloaded: {:?} bytes", body.len()); + + Ok(()) + }) } diff --git a/examples/uds.rs b/examples/uds.rs index 9dc82903f..fc6a58de1 100644 --- a/examples/uds.rs +++ b/examples/uds.rs @@ -1,26 +1,24 @@ -use futures::IntoFuture; - use actix_web::{ get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer, }; #[get("/resource1/{name}/index.html")] -fn index(req: HttpRequest, name: web::Path) -> String { +async fn index(req: HttpRequest, name: web::Path) -> String { println!("REQ: {:?}", req); format!("Hello: {}!\r\n", name) } -fn index_async(req: HttpRequest) -> impl IntoFuture { +async fn index_async(req: HttpRequest) -> Result<&'static str, Error> { println!("REQ: {:?}", req); Ok("Hello world!\r\n") } #[get("/")] -fn no_params() -> &'static str { +async fn no_params() -> &'static str { "Hello world!\r\n" } -#[cfg(feature = "uds")] +#[cfg(unix)] fn main() -> std::io::Result<()> { std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info"); env_logger::init(); @@ -40,14 +38,14 @@ fn main() -> std::io::Result<()> { .default_service( web::route().to(|| HttpResponse::MethodNotAllowed()), ) - .route(web::get().to_async(index_async)), + .route(web::get().to(index_async)), ) - .service(web::resource("/test1.html").to(|| "Test\r\n")) + .service(web::resource("/test1.html").to(|| async { "Test\r\n" })) }) .bind_uds("/Users/fafhrd91/uds-test")? .workers(1) .run() } -#[cfg(not(feature = "uds"))] +#[cfg(not(unix))] fn main() {} diff --git a/src/app.rs b/src/app.rs index f93859c7e..d9ac8c09d 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,14 +1,17 @@ use std::cell::RefCell; use std::fmt; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_http::body::{Body, MessageBody}; use actix_service::boxed::{self, BoxedNewService}; use actix_service::{ - apply_transform, IntoNewService, IntoTransform, NewService, Transform, + apply, apply_fn_factory, IntoServiceFactory, ServiceFactory, Transform, }; -use futures::{Future, IntoFuture}; +use futures::future::{FutureExt, LocalBoxFuture}; use crate::app_service::{AppEntry, AppInit, AppRoutingFactory}; use crate::config::{AppConfig, AppConfigInner, ServiceConfig}; @@ -18,19 +21,19 @@ use crate::error::Error; use crate::resource::Resource; use crate::route::Route; use crate::service::{ - HttpServiceFactory, ServiceFactory, ServiceFactoryWrapper, ServiceRequest, + AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest, ServiceResponse, }; type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; type FnDataFactory = - Box Box, Error = ()>>>; + Box LocalBoxFuture<'static, Result, ()>>>; /// Application builder - structure that follows the builder pattern /// for building application instances. pub struct App { endpoint: T, - services: Vec>, + services: Vec>, default: Option>, factory_ref: Rc>>, data: Vec>, @@ -61,7 +64,7 @@ impl App { impl App where B: MessageBody, - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -87,7 +90,7 @@ where /// counter: Cell, /// } /// - /// fn index(data: web::Data) { + /// async fn index(data: web::Data) { /// data.counter.set(data.counter.get() + 1); /// } /// @@ -107,24 +110,30 @@ where /// Set application data factory. This function is /// similar to `.data()` but it accepts data factory. Data object get /// constructed asynchronously during application initialization. - pub fn data_factory(mut self, data: F) -> Self + pub fn data_factory(mut self, data: F) -> Self where F: Fn() -> Out + 'static, - Out: IntoFuture + 'static, - Out::Error: std::fmt::Debug, + Out: Future> + 'static, + D: 'static, + E: std::fmt::Debug, { self.data_factories.push(Box::new(move || { - Box::new( - data() - .into_future() - .map_err(|e| { - log::error!("Can not construct data instance: {:?}", e); - }) - .map(|data| { - let data: Box = Box::new(Data::new(data)); - data - }), - ) + { + let fut = data(); + async move { + match fut.await { + Err(e) => { + log::error!("Can not construct data instance: {:?}", e); + Err(()) + } + Ok(data) => { + let data: Box = Box::new(Data::new(data)); + Ok(data) + } + } + } + } + .boxed_local() })); self } @@ -183,7 +192,7 @@ where /// ```rust /// use actix_web::{web, App, HttpResponse}; /// - /// fn index(data: web::Path<(String, String)>) -> &'static str { + /// async fn index(data: web::Path<(String, String)>) -> &'static str { /// "Welcome!" /// } /// @@ -238,7 +247,7 @@ where /// ```rust /// use actix_web::{web, App, HttpResponse}; /// - /// fn index() -> &'static str { + /// async fn index() -> &'static str { /// "Welcome!" /// } /// @@ -267,8 +276,8 @@ where /// ``` pub fn default_service(mut self, f: F) -> Self where - F: IntoNewService, - U: NewService< + F: IntoServiceFactory, + U: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -277,11 +286,9 @@ where U::InitError: fmt::Debug, { // create and configure default resource - self.default = Some(Rc::new(boxed::new_service( - f.into_new_service().map_init_err(|e| { - log::error!("Can not construct default service: {:?}", e) - }), - ))); + self.default = Some(Rc::new(boxed::factory(f.into_factory().map_init_err( + |e| log::error!("Can not construct default service: {:?}", e), + )))); self } @@ -295,7 +302,7 @@ where /// ```rust /// use actix_web::{web, App, HttpRequest, HttpResponse, Result}; /// - /// fn index(req: HttpRequest) -> Result { + /// async fn index(req: HttpRequest) -> Result { /// let url = req.url_for("youtube", &["asdlkjqme"])?; /// assert_eq!(url.as_str(), "https://youtube.com/watch/asdlkjqme"); /// Ok(HttpResponse::Ok().into()) @@ -336,11 +343,10 @@ where /// /// ```rust /// use actix_service::Service; - /// # use futures::Future; /// use actix_web::{middleware, web, App}; /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; /// - /// fn index() -> &'static str { + /// async fn index() -> &'static str { /// "Welcome!" /// } /// @@ -350,11 +356,11 @@ where /// .route("/index.html", web::get().to(index)); /// } /// ``` - pub fn wrap( + pub fn wrap( self, - mw: F, + mw: M, ) -> App< - impl NewService< + impl ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -372,11 +378,9 @@ where InitError = (), >, B1: MessageBody, - F: IntoTransform, { - let endpoint = apply_transform(mw, self.endpoint); App { - endpoint, + endpoint: apply(mw, self.endpoint), data: self.data, data_factories: self.data_factories, services: self.services, @@ -397,23 +401,25 @@ where /// /// ```rust /// use actix_service::Service; - /// # use futures::Future; /// use actix_web::{web, App}; /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; /// - /// fn index() -> &'static str { + /// async fn index() -> &'static str { /// "Welcome!" /// } /// /// fn main() { /// let app = App::new() - /// .wrap_fn(|req, srv| - /// srv.call(req).map(|mut res| { + /// .wrap_fn(|req, srv| { + /// let fut = srv.call(req); + /// async { + /// let mut res = fut.await?; /// res.headers_mut().insert( /// CONTENT_TYPE, HeaderValue::from_static("text/plain"), /// ); - /// res - /// })) + /// Ok(res) + /// } + /// }) /// .route("/index.html", web::get().to(index)); /// } /// ``` @@ -421,7 +427,7 @@ where self, mw: F, ) -> App< - impl NewService< + impl ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -433,16 +439,26 @@ where where B1: MessageBody, F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, - R: IntoFuture, Error = Error>, + R: Future, Error>>, { - self.wrap(mw) + App { + endpoint: apply_fn_factory(self.endpoint, mw), + data: self.data, + data_factories: self.data_factories, + services: self.services, + default: self.default, + factory_ref: self.factory_ref, + config: self.config, + external: self.external, + _t: PhantomData, + } } } -impl IntoNewService> for App +impl IntoServiceFactory> for App where B: MessageBody, - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -450,7 +466,7 @@ where InitError = (), >, { - fn into_new_service(self) -> AppInit { + fn into_factory(self) -> AppInit { AppInit { data: Rc::new(self.data), data_factories: Rc::new(self.data_factories), @@ -468,82 +484,89 @@ where mod tests { use actix_service::Service; use bytes::Bytes; - use futures::{Future, IntoFuture}; + use futures::future::{ok, Future}; use super::*; use crate::http::{header, HeaderValue, Method, StatusCode}; + use crate::middleware::DefaultHeaders; use crate::service::{ServiceRequest, ServiceResponse}; - use crate::test::{ - block_fn, block_on, call_service, init_service, read_body, TestRequest, - }; + use crate::test::{block_on, call_service, init_service, read_body, TestRequest}; use crate::{web, Error, HttpRequest, HttpResponse}; #[test] fn test_default_resource() { - let mut srv = init_service( - App::new().service(web::resource("/test").to(|| HttpResponse::Ok())), - ); - let req = TestRequest::with_uri("/test").to_request(); - let resp = block_fn(|| srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + block_on(async { + let mut srv = init_service( + App::new().service(web::resource("/test").to(|| HttpResponse::Ok())), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - let req = TestRequest::with_uri("/blah").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::with_uri("/blah").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let mut srv = init_service( - App::new() - .service(web::resource("/test").to(|| HttpResponse::Ok())) - .service( - web::resource("/test2") - .default_service(|r: ServiceRequest| { - r.into_response(HttpResponse::Created()) - }) - .route(web::get().to(|| HttpResponse::Ok())), - ) - .default_service(|r: ServiceRequest| { - r.into_response(HttpResponse::MethodNotAllowed()) - }), - ); + let mut srv = init_service( + App::new() + .service(web::resource("/test").to(|| HttpResponse::Ok())) + .service( + web::resource("/test2") + .default_service(|r: ServiceRequest| { + ok(r.into_response(HttpResponse::Created())) + }) + .route(web::get().to(|| HttpResponse::Ok())), + ) + .default_service(|r: ServiceRequest| { + ok(r.into_response(HttpResponse::MethodNotAllowed())) + }), + ) + .await; - let req = TestRequest::with_uri("/blah").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + let req = TestRequest::with_uri("/blah").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - let req = TestRequest::with_uri("/test2").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/test2").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - let req = TestRequest::with_uri("/test2") - .method(Method::POST) - .to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); + let req = TestRequest::with_uri("/test2") + .method(Method::POST) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + }) } #[test] fn test_data_factory() { - let mut srv = - init_service(App::new().data_factory(|| Ok::<_, ()>(10usize)).service( - web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); - let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + block_on(async { + let mut srv = + init_service(App::new().data_factory(|| ok::<_, ()>(10usize)).service( + web::resource("/").to(|_: web::Data| HttpResponse::Ok()), + )) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - let mut srv = - init_service(App::new().data_factory(|| Ok::<_, ()>(10u32)).service( - web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); - let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + let mut srv = + init_service(App::new().data_factory(|| ok::<_, ()>(10u32)).service( + web::resource("/").to(|_: web::Data| HttpResponse::Ok()), + )) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + }) } fn md( req: ServiceRequest, srv: &mut S, - ) -> impl IntoFuture, Error = Error> + ) -> impl Future, Error>> where S: Service< Request = ServiceRequest, @@ -551,112 +574,141 @@ mod tests { Error = Error, >, { - srv.call(req).map(|mut res| { + let fut = srv.call(req); + async move { + let mut res = fut.await?; res.headers_mut() .insert(header::CONTENT_TYPE, HeaderValue::from_static("0001")); - res - }) + Ok(res) + } } #[test] fn test_wrap() { - let mut srv = init_service( - App::new() - .wrap(md) - .route("/test", web::get().to(|| HttpResponse::Ok())), - ); - let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - HeaderValue::from_static("0001") - ); + block_on(async { + let mut srv = + init_service( + App::new() + .wrap(DefaultHeaders::new().header( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + )) + .route("/test", web::get().to(|| HttpResponse::Ok())), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + }) } #[test] fn test_router_wrap() { - let mut srv = init_service( - App::new() - .route("/test", web::get().to(|| HttpResponse::Ok())) - .wrap(md), - ); - let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - HeaderValue::from_static("0001") - ); + block_on(async { + let mut srv = + init_service( + App::new() + .route("/test", web::get().to(|| HttpResponse::Ok())) + .wrap(DefaultHeaders::new().header( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + )), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + }) } #[test] fn test_wrap_fn() { - let mut srv = init_service( - App::new() - .wrap_fn(|req, srv| { - srv.call(req).map(|mut res| { - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("0001"), - ); - res + block_on(async { + let mut srv = init_service( + App::new() + .wrap_fn(|req, srv| { + let fut = srv.call(req); + async move { + let mut res = fut.await?; + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + ); + Ok(res) + } }) - }) - .service(web::resource("/test").to(|| HttpResponse::Ok())), - ); - let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - HeaderValue::from_static("0001") - ); + .service(web::resource("/test").to(|| HttpResponse::Ok())), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + }) } #[test] fn test_router_wrap_fn() { - let mut srv = init_service( - App::new() - .route("/test", web::get().to(|| HttpResponse::Ok())) - .wrap_fn(|req, srv| { - srv.call(req).map(|mut res| { - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("0001"), - ); - res - }) - }), - ); - let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - HeaderValue::from_static("0001") - ); + block_on(async { + let mut srv = init_service( + App::new() + .route("/test", web::get().to(|| HttpResponse::Ok())) + .wrap_fn(|req, srv| { + let fut = srv.call(req); + async { + let mut res = fut.await?; + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + ); + Ok(res) + } + }), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + }) } #[test] fn test_external_resource() { - let mut srv = init_service( - App::new() - .external_resource("youtube", "https://youtube.com/watch/{video_id}") - .route( - "/test", - web::get().to(|req: HttpRequest| { - HttpResponse::Ok().body(format!( - "{}", - req.url_for("youtube", &["12345"]).unwrap() - )) - }), - ), - ); - let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - let body = read_body(resp); - assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); + block_on(async { + let mut srv = init_service( + App::new() + .external_resource("youtube", "https://youtube.com/watch/{video_id}") + .route( + "/test", + web::get().to(|req: HttpRequest| { + HttpResponse::Ok().body(format!( + "{}", + req.url_for("youtube", &["12345"]).unwrap() + )) + }), + ), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + let body = read_body(resp).await; + assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); + }) } } diff --git a/src/app_service.rs b/src/app_service.rs index 513b4aa4b..7407ee2fb 100644 --- a/src/app_service.rs +++ b/src/app_service.rs @@ -1,14 +1,16 @@ use std::cell::RefCell; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_http::{Extensions, Request, Response}; use actix_router::{Path, ResourceDef, ResourceInfo, Router, Url}; use actix_server_config::ServerConfig; use actix_service::boxed::{self, BoxedNewService, BoxedService}; -use actix_service::{service_fn, NewService, Service}; -use futures::future::{ok, Either, FutureResult}; -use futures::{Async, Future, Poll}; +use actix_service::{service_fn, Service, ServiceFactory}; +use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; use crate::config::{AppConfig, AppService}; use crate::data::DataFactory; @@ -16,23 +18,20 @@ use crate::error::Error; use crate::guard::Guard; use crate::request::{HttpRequest, HttpRequestPool}; use crate::rmap::ResourceMap; -use crate::service::{ServiceFactory, ServiceRequest, ServiceResponse}; +use crate::service::{AppServiceFactory, ServiceRequest, ServiceResponse}; type Guards = Vec>; type HttpService = BoxedService; type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; -type BoxedResponse = Either< - FutureResult, - Box>, ->; +type BoxedResponse = LocalBoxFuture<'static, Result>; type FnDataFactory = - Box Box, Error = ()>>>; + Box LocalBoxFuture<'static, Result, ()>>>; /// Service factory to convert `Request` to a `ServiceRequest`. /// It also executes data factories. pub struct AppInit where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -44,15 +43,15 @@ where pub(crate) data: Rc>>, pub(crate) data_factories: Rc>, pub(crate) config: RefCell, - pub(crate) services: Rc>>>, + pub(crate) services: Rc>>>, pub(crate) default: Option>, pub(crate) factory_ref: Rc>>, pub(crate) external: RefCell>, } -impl NewService for AppInit +impl ServiceFactory for AppInit where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -71,8 +70,8 @@ where fn new_service(&self, cfg: &ServerConfig) -> Self::Future { // update resource default service let default = self.default.clone().unwrap_or_else(|| { - Rc::new(boxed::new_service(service_fn(|req: ServiceRequest| { - Ok(req.into_response(Response::NotFound().finish())) + Rc::new(boxed::factory(service_fn(|req: ServiceRequest| { + ok(req.into_response(Response::NotFound().finish())) }))) }); @@ -135,23 +134,25 @@ where } } +#[pin_project::pin_project] pub struct AppInitResult where - T: NewService, + T: ServiceFactory, { endpoint: Option, + #[pin] endpoint_fut: T::Future, rmap: Rc, config: AppConfig, data: Rc>>, data_factories: Vec>, - data_factories_fut: Vec, Error = ()>>>, + data_factories_fut: Vec, ()>>>, _t: PhantomData, } impl Future for AppInitResult where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -159,48 +160,49 @@ where InitError = (), >, { - type Item = AppInitService; - type Error = (); + type Output = Result, ()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); - fn poll(&mut self) -> Poll { // async data factories let mut idx = 0; - while idx < self.data_factories_fut.len() { - match self.data_factories_fut[idx].poll()? { - Async::Ready(f) => { - self.data_factories.push(f); - let _ = self.data_factories_fut.remove(idx); + while idx < this.data_factories_fut.len() { + match Pin::new(&mut this.data_factories_fut[idx]).poll(cx)? { + Poll::Ready(f) => { + this.data_factories.push(f); + let _ = this.data_factories_fut.remove(idx); } - Async::NotReady => idx += 1, + Poll::Pending => idx += 1, } } - if self.endpoint.is_none() { - if let Async::Ready(srv) = self.endpoint_fut.poll()? { - self.endpoint = Some(srv); + if this.endpoint.is_none() { + if let Poll::Ready(srv) = this.endpoint_fut.poll(cx)? { + *this.endpoint = Some(srv); } } - if self.endpoint.is_some() && self.data_factories_fut.is_empty() { + if this.endpoint.is_some() && this.data_factories_fut.is_empty() { // create app data container let mut data = Extensions::new(); - for f in self.data.iter() { + for f in this.data.iter() { f.create(&mut data); } - for f in &self.data_factories { + for f in this.data_factories.iter() { f.create(&mut data); } - Ok(Async::Ready(AppInitService { - service: self.endpoint.take().unwrap(), - rmap: self.rmap.clone(), - config: self.config.clone(), + Poll::Ready(Ok(AppInitService { + service: this.endpoint.take().unwrap(), + rmap: this.rmap.clone(), + config: this.config.clone(), data: Rc::new(data), pool: HttpRequestPool::create(), })) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -226,8 +228,8 @@ where type Error = T::Error; type Future = T::Future; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { @@ -270,7 +272,7 @@ pub struct AppRoutingFactory { default: Rc, } -impl NewService for AppRoutingFactory { +impl ServiceFactory for AppRoutingFactory { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -288,7 +290,7 @@ impl NewService for AppRoutingFactory { CreateAppRoutingItem::Future( Some(path.clone()), guards.borrow_mut().take(), - service.new_service(&()), + service.new_service(&()).boxed_local(), ) }) .collect(), @@ -298,14 +300,14 @@ impl NewService for AppRoutingFactory { } } -type HttpServiceFut = Box>; +type HttpServiceFut = LocalBoxFuture<'static, Result>; /// Create app service #[doc(hidden)] pub struct AppRoutingFactoryResponse { fut: Vec, default: Option, - default_fut: Option>>, + default_fut: Option>>, } enum CreateAppRoutingItem { @@ -314,16 +316,15 @@ enum CreateAppRoutingItem { } impl Future for AppRoutingFactoryResponse { - type Item = AppRouting; - type Error = (); + type Output = Result; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut done = true; if let Some(ref mut fut) = self.default_fut { - match fut.poll()? { - Async::Ready(default) => self.default = Some(default), - Async::NotReady => done = false, + match Pin::new(fut).poll(cx)? { + Poll::Ready(default) => self.default = Some(default), + Poll::Pending => done = false, } } @@ -334,11 +335,12 @@ impl Future for AppRoutingFactoryResponse { ref mut path, ref mut guards, ref mut fut, - ) => match fut.poll()? { - Async::Ready(service) => { + ) => match Pin::new(fut).poll(cx) { + Poll::Ready(Ok(service)) => { Some((path.take().unwrap(), guards.take(), service)) } - Async::NotReady => { + Poll::Ready(Err(_)) => return Poll::Ready(Err(())), + Poll::Pending => { done = false; None } @@ -364,13 +366,13 @@ impl Future for AppRoutingFactoryResponse { } router }); - Ok(Async::Ready(AppRouting { + Poll::Ready(Ok(AppRouting { ready: None, router: router.finish(), default: self.default.take(), })) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -387,11 +389,11 @@ impl Service for AppRouting { type Error = Error; type Future = BoxedResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { + fn poll_ready(&mut self, _: &mut Context) -> Poll> { if self.ready.is_none() { - Ok(Async::Ready(())) + Poll::Ready(Ok(())) } else { - Ok(Async::NotReady) + Poll::Pending } } @@ -413,7 +415,7 @@ impl Service for AppRouting { default.call(req) } else { let req = req.into_parts().0; - Either::A(ok(ServiceResponse::new(req, Response::NotFound().finish()))) + ok(ServiceResponse::new(req, Response::NotFound().finish())).boxed_local() } } } @@ -429,7 +431,7 @@ impl AppEntry { } } -impl NewService for AppEntry { +impl ServiceFactory for AppEntry { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -464,15 +466,16 @@ mod tests { #[test] fn drop_data() { let data = Arc::new(AtomicBool::new(false)); - { + test::block_on(async { let mut app = test::init_service( App::new() .data(DropData(data.clone())) .service(web::resource("/test").to(|| HttpResponse::Ok())), - ); + ) + .await; let req = test::TestRequest::with_uri("/test").to_request(); - let _ = test::block_on(app.call(req)).unwrap(); - } + let _ = app.call(req).await.unwrap(); + }); assert!(data.load(Ordering::Relaxed)); } } diff --git a/src/config.rs b/src/config.rs index 63fd31d27..3ce18f98b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use actix_http::Extensions; use actix_router::ResourceDef; -use actix_service::{boxed, IntoNewService, NewService}; +use actix_service::{boxed, IntoServiceFactory, ServiceFactory}; use crate::data::{Data, DataFactory}; use crate::error::Error; @@ -12,7 +12,7 @@ use crate::resource::Resource; use crate::rmap::ResourceMap; use crate::route::Route; use crate::service::{ - HttpServiceFactory, ServiceFactory, ServiceFactoryWrapper, ServiceRequest, + AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest, ServiceResponse, }; @@ -102,11 +102,11 @@ impl AppService { &mut self, rdef: ResourceDef, guards: Option>>, - service: F, + factory: F, nested: Option>, ) where - F: IntoNewService, - S: NewService< + F: IntoServiceFactory, + S: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -116,7 +116,7 @@ impl AppService { { self.services.push(( rdef, - boxed::new_service(service.into_new_service()), + boxed::factory(factory.into_factory()), guards, nested, )); @@ -174,7 +174,7 @@ impl Default for AppConfigInner { /// to set of external methods. This could help with /// modularization of big application configuration. pub struct ServiceConfig { - pub(crate) services: Vec>, + pub(crate) services: Vec>, pub(crate) data: Vec>, pub(crate) external: Vec, } @@ -251,17 +251,19 @@ mod tests { #[test] fn test_data() { - let cfg = |cfg: &mut ServiceConfig| { - cfg.data(10usize); - }; + block_on(async { + let cfg = |cfg: &mut ServiceConfig| { + cfg.data(10usize); + }; - let mut srv = - init_service(App::new().configure(cfg).service( + let mut srv = init_service(App::new().configure(cfg).service( web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); - let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + )) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + }) } // #[test] @@ -298,50 +300,57 @@ mod tests { #[test] fn test_external_resource() { - let mut srv = init_service( - App::new() - .configure(|cfg| { - cfg.external_resource( - "youtube", - "https://youtube.com/watch/{video_id}", - ); - }) - .route( - "/test", - web::get().to(|req: HttpRequest| { - HttpResponse::Ok().body(format!( - "{}", - req.url_for("youtube", &["12345"]).unwrap() - )) - }), - ), - ); - let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - let body = read_body(resp); - assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); + block_on(async { + let mut srv = init_service( + App::new() + .configure(|cfg| { + cfg.external_resource( + "youtube", + "https://youtube.com/watch/{video_id}", + ); + }) + .route( + "/test", + web::get().to(|req: HttpRequest| { + HttpResponse::Ok().body(format!( + "{}", + req.url_for("youtube", &["12345"]).unwrap() + )) + }), + ), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + let body = read_body(resp).await; + assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); + }) } #[test] fn test_service() { - let mut srv = init_service(App::new().configure(|cfg| { - cfg.service( - web::resource("/test").route(web::get().to(|| HttpResponse::Created())), - ) - .route("/index.html", web::get().to(|| HttpResponse::Ok())); - })); + block_on(async { + let mut srv = init_service(App::new().configure(|cfg| { + cfg.service( + web::resource("/test") + .route(web::get().to(|| HttpResponse::Created())), + ) + .route("/index.html", web::get().to(|| HttpResponse::Ok())); + })) + .await; - let req = TestRequest::with_uri("/test") - .method(Method::GET) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::CREATED); + let req = TestRequest::with_uri("/test") + .method(Method::GET) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::CREATED); - let req = TestRequest::with_uri("/index.html") - .method(Method::GET) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/index.html") + .method(Method::GET) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } } diff --git a/src/data.rs b/src/data.rs index 14e293bc2..a026946aa 100644 --- a/src/data.rs +++ b/src/data.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use actix_http::error::{Error, ErrorInternalServerError}; use actix_http::Extensions; +use futures::future::{err, ok, Ready}; use crate::dev::Payload; use crate::extract::FromRequest; @@ -44,7 +45,7 @@ pub(crate) trait DataFactory { /// } /// /// /// Use `Data` extractor to access data in handler. -/// fn index(data: web::Data>) { +/// async fn index(data: web::Data>) { /// let mut data = data.lock().unwrap(); /// data.counter += 1; /// } @@ -101,19 +102,19 @@ impl Clone for Data { impl FromRequest for Data { type Config = (); type Error = Error; - type Future = Result; + type Future = Ready>; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { if let Some(st) = req.get_app_data::() { - Ok(st) + ok(st) } else { log::debug!( "Failed to construct App-level Data extractor. \ Request path: {:?}", req.path() ); - Err(ErrorInternalServerError( + err(ErrorInternalServerError( "App data is not configured, to configure use App::data()", )) } @@ -142,85 +143,99 @@ mod tests { #[test] fn test_data_extractor() { - let mut srv = - init_service(App::new().data(10usize).service( + block_on(async { + let mut srv = init_service(App::new().data(10usize).service( web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); + )) + .await; - let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - let mut srv = - init_service(App::new().data(10u32).service( + let mut srv = init_service(App::new().data(10u32).service( web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); - let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + )) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + }) } #[test] fn test_register_data_extractor() { - let mut srv = - init_service(App::new().register_data(Data::new(10usize)).service( - web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); + block_on(async { + let mut srv = + init_service(App::new().register_data(Data::new(10usize)).service( + web::resource("/").to(|_: web::Data| HttpResponse::Ok()), + )) + .await; - let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - let mut srv = - init_service(App::new().register_data(Data::new(10u32)).service( - web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); - let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + let mut srv = + init_service(App::new().register_data(Data::new(10u32)).service( + web::resource("/").to(|_: web::Data| HttpResponse::Ok()), + )) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + }) } #[test] fn test_route_data_extractor() { - let mut srv = - init_service(App::new().service(web::resource("/").data(10usize).route( - web::get().to(|data: web::Data| { - let _ = data.clone(); - HttpResponse::Ok() - }), - ))); + block_on(async { + let mut srv = init_service(App::new().service( + web::resource("/").data(10usize).route(web::get().to( + |data: web::Data| { + let _ = data.clone(); + HttpResponse::Ok() + }, + )), + )) + .await; - let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - // different type - let mut srv = init_service( - App::new().service( - web::resource("/") - .data(10u32) - .route(web::get().to(|_: web::Data| HttpResponse::Ok())), - ), - ); - let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + // different type + let mut srv = init_service( + App::new().service( + web::resource("/") + .data(10u32) + .route(web::get().to(|_: web::Data| HttpResponse::Ok())), + ), + ) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + }) } #[test] fn test_override_data() { - let mut srv = init_service(App::new().data(1usize).service( - web::resource("/").data(10usize).route(web::get().to( - |data: web::Data| { - assert_eq!(*data, 10); - let _ = data.clone(); - HttpResponse::Ok() - }, - )), - )); + block_on(async { + let mut srv = init_service(App::new().data(1usize).service( + web::resource("/").data(10usize).route(web::get().to( + |data: web::Data| { + assert_eq!(*data, 10); + let _ = data.clone(); + HttpResponse::Ok() + }, + )), + )) + .await; - let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + }) } } diff --git a/src/extract.rs b/src/extract.rs index 1687973ac..9c8633368 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -1,8 +1,10 @@ //! Request extractors +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_http::error::Error; -use futures::future::ok; -use futures::{future, Async, Future, IntoFuture, Poll}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use crate::dev::Payload; use crate::request::HttpRequest; @@ -15,7 +17,7 @@ pub trait FromRequest: Sized { type Error: Into; /// Future that resolves to a Self - type Future: IntoFuture; + type Future: Future>; /// Configuration for this extractor type Config: Default + 'static; @@ -46,9 +48,10 @@ pub trait FromRequest: Sized { /// ## Example /// /// ```rust -/// # #[macro_use] extern crate serde_derive; /// use actix_web::{web, dev, App, Error, HttpRequest, FromRequest}; /// use actix_web::error::ErrorBadRequest; +/// use futures::future::{ok, err, Ready}; +/// use serde_derive::Deserialize; /// use rand; /// /// #[derive(Debug, Deserialize)] @@ -58,21 +61,21 @@ pub trait FromRequest: Sized { /// /// impl FromRequest for Thing { /// type Error = Error; -/// type Future = Result; +/// type Future = Ready>; /// type Config = (); /// /// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { /// if rand::random() { -/// Ok(Thing { name: "thingy".into() }) +/// ok(Thing { name: "thingy".into() }) /// } else { -/// Err(ErrorBadRequest("no luck")) +/// err(ErrorBadRequest("no luck")) /// } /// /// } /// } /// /// /// extract `Thing` from request -/// fn index(supplied_thing: Option) -> String { +/// async fn index(supplied_thing: Option) -> String { /// match supplied_thing { /// // Puns not intended /// Some(thing) => format!("Got something: {:?}", thing), @@ -94,21 +97,19 @@ where { type Config = T::Config; type Error = Error; - type Future = Box, Error = Error>>; + type Future = LocalBoxFuture<'static, Result, Error>>; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - Box::new( - T::from_request(req, payload) - .into_future() - .then(|r| match r { - Ok(v) => future::ok(Some(v)), - Err(e) => { - log::debug!("Error for Option extractor: {}", e.into()); - future::ok(None) - } - }), - ) + T::from_request(req, payload) + .then(|r| match r { + Ok(v) => ok(Some(v)), + Err(e) => { + log::debug!("Error for Option extractor: {}", e.into()); + ok(None) + } + }) + .boxed_local() } } @@ -119,9 +120,10 @@ where /// ## Example /// /// ```rust -/// # #[macro_use] extern crate serde_derive; /// use actix_web::{web, dev, App, Result, Error, HttpRequest, FromRequest}; /// use actix_web::error::ErrorBadRequest; +/// use futures::future::{ok, err, Ready}; +/// use serde_derive::Deserialize; /// use rand; /// /// #[derive(Debug, Deserialize)] @@ -131,20 +133,20 @@ where /// /// impl FromRequest for Thing { /// type Error = Error; -/// type Future = Result; +/// type Future = Ready>; /// type Config = (); /// /// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { /// if rand::random() { -/// Ok(Thing { name: "thingy".into() }) +/// ok(Thing { name: "thingy".into() }) /// } else { -/// Err(ErrorBadRequest("no luck")) +/// err(ErrorBadRequest("no luck")) /// } /// } /// } /// /// /// extract `Thing` from request -/// fn index(supplied_thing: Result) -> String { +/// async fn index(supplied_thing: Result) -> String { /// match supplied_thing { /// Ok(thing) => format!("Got thing: {:?}", thing), /// Err(e) => format!("Error extracting thing: {}", e) @@ -157,26 +159,24 @@ where /// ); /// } /// ``` -impl FromRequest for Result +impl FromRequest for Result where - T: FromRequest, - T::Future: 'static, + T: FromRequest + 'static, T::Error: 'static, + T::Future: 'static, { type Config = T::Config; type Error = Error; - type Future = Box, Error = Error>>; + type Future = LocalBoxFuture<'static, Result, Error>>; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - Box::new( - T::from_request(req, payload) - .into_future() - .then(|res| match res { - Ok(v) => ok(Ok(v)), - Err(e) => ok(Err(e)), - }), - ) + T::from_request(req, payload) + .then(|res| match res { + Ok(v) => ok(Ok(v)), + Err(e) => ok(Err(e)), + }) + .boxed_local() } } @@ -184,10 +184,10 @@ where impl FromRequest for () { type Config = (); type Error = Error; - type Future = Result<(), Error>; + type Future = Ready>; fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future { - Ok(()) + ok(()) } } @@ -204,43 +204,44 @@ macro_rules! tuple_from_req ({$fut_type:ident, $(($n:tt, $T:ident)),+} => { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { $fut_type { items: <($(Option<$T>,)+)>::default(), - futs: ($($T::from_request(req, payload).into_future(),)+), + futs: ($($T::from_request(req, payload),)+), } } } #[doc(hidden)] + #[pin_project::pin_project] pub struct $fut_type<$($T: FromRequest),+> { items: ($(Option<$T>,)+), - futs: ($(<$T::Future as futures::IntoFuture>::Future,)+), + futs: ($($T::Future,)+), } impl<$($T: FromRequest),+> Future for $fut_type<$($T),+> { - type Item = ($($T,)+); - type Error = Error; + type Output = Result<($($T,)+), Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); - fn poll(&mut self) -> Poll { let mut ready = true; - $( - if self.items.$n.is_none() { - match self.futs.$n.poll() { - Ok(Async::Ready(item)) => { - self.items.$n = Some(item); + if this.items.$n.is_none() { + match unsafe { Pin::new_unchecked(&mut this.futs.$n) }.poll(cx) { + Poll::Ready(Ok(item)) => { + this.items.$n = Some(item); } - Ok(Async::NotReady) => ready = false, - Err(e) => return Err(e.into()), + Poll::Pending => ready = false, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), } } )+ if ready { - Ok(Async::Ready( - ($(self.items.$n.take().unwrap(),)+) + Poll::Ready(Ok( + ($(this.items.$n.take().unwrap(),)+) )) } else { - Ok(Async::NotReady) + Poll::Pending } } } diff --git a/src/guard.rs b/src/guard.rs index e0b4055ba..3db525f9a 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -1,16 +1,16 @@ //! Route match guards. //! -//! Guards are one of the way how actix-web router chooses -//! handler service. In essence it just function that accepts -//! reference to a `RequestHead` instance and returns boolean. +//! Guards are one of the ways how actix-web router chooses a +//! handler service. In essence it is just a function that accepts a +//! reference to a `RequestHead` instance and returns a boolean. //! It is possible to add guards to *scopes*, *resources* //! and *routes*. Actix provide several guards by default, like various //! http methods, header, etc. To become a guard, type must implement `Guard` //! trait. Simple functions coulds guards as well. //! -//! Guard can not modify request object. But it is possible to -//! to store extra attributes on a request by using `Extensions` container. -//! Extensions container available via `RequestHead::extensions()` method. +//! Guards can not modify the request object. But it is possible +//! to store extra attributes on a request by using the `Extensions` container. +//! Extensions containers are available via the `RequestHead::extensions()` method. //! //! ```rust //! use actix_web::{web, http, dev, guard, App, HttpResponse}; @@ -29,11 +29,11 @@ use actix_http::http::{self, header, uri::Uri, HttpTryFrom}; use actix_http::RequestHead; -/// Trait defines resource guards. Guards are used for routes selection. +/// Trait defines resource guards. Guards are used for route selection. /// -/// Guard can not modify request object. But it is possible to -/// to store extra attributes on request by using `Extensions` container, -/// Extensions container available via `RequestHead::extensions()` method. +/// Guards can not modify the request object. But it is possible +/// to store extra attributes on a request by using the `Extensions` container. +/// Extensions containers are available via the `RequestHead::extensions()` method. pub trait Guard { /// Check if request matches predicate fn check(&self, request: &RequestHead) -> bool; @@ -276,10 +276,12 @@ pub fn Host>(host: H) -> HostGuard { fn get_host_uri(req: &RequestHead) -> Option { use core::str::FromStr; - let host_value = req.headers.get(header::HOST)?; - let host = host_value.to_str().ok()?; - let uri = Uri::from_str(host).ok()?; - Some(uri) + req.headers + .get(header::HOST) + .and_then(|host_value| host_value.to_str().ok()) + .or_else(|| req.uri.host()) + .map(|host: &str| Uri::from_str(host).ok()) + .and_then(|host_success| host_success) } #[doc(hidden)] @@ -400,6 +402,31 @@ mod tests { assert!(!pred.check(req.head())); } + #[test] + fn test_host_without_header() { + let req = TestRequest::default() + .uri("www.rust-lang.org") + .to_http_request(); + + let pred = Host("www.rust-lang.org"); + assert!(pred.check(req.head())); + + let pred = Host("www.rust-lang.org").scheme("https"); + assert!(pred.check(req.head())); + + let pred = Host("blog.rust-lang.org"); + assert!(!pred.check(req.head())); + + let pred = Host("blog.rust-lang.org").scheme("https"); + assert!(!pred.check(req.head())); + + let pred = Host("crates.io"); + assert!(!pred.check(req.head())); + + let pred = Host("localhost"); + assert!(!pred.check(req.head())); + } + #[test] fn test_methods() { let req = TestRequest::default().to_http_request(); diff --git a/src/handler.rs b/src/handler.rs index 078abbf1d..a7023422b 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,28 +1,34 @@ use std::convert::Infallible; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_http::{Error, Response}; -use actix_service::{NewService, Service}; -use futures::future::{ok, FutureResult}; -use futures::{try_ready, Async, Future, IntoFuture, Poll}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{ok, Ready}; +use futures::ready; +use pin_project::pin_project; use crate::extract::FromRequest; use crate::request::HttpRequest; use crate::responder::Responder; use crate::service::{ServiceRequest, ServiceResponse}; -/// Handler converter factory -pub trait Factory: Clone +/// Async handler converter factory +pub trait Factory: Clone + 'static where - R: Responder, + R: Future, + O: Responder, { fn call(&self, param: T) -> R; } -impl Factory<(), R> for F +impl Factory<(), R, O> for F where - F: Fn() -> R + Clone, - R: Responder, + F: Fn() -> R + Clone + 'static, + R: Future, + O: Responder, { fn call(&self, _: ()) -> R { (self)() @@ -30,19 +36,21 @@ where } #[doc(hidden)] -pub struct Handler +pub struct Handler where - F: Factory, - R: Responder, + F: Factory, + R: Future, + O: Responder, { hnd: F, - _t: PhantomData<(T, R)>, + _t: PhantomData<(T, R, O)>, } -impl Handler +impl Handler where - F: Factory, - R: Responder, + F: Factory, + R: Future, + O: Responder, { pub fn new(hnd: F) -> Self { Handler { @@ -52,156 +60,38 @@ where } } -impl Clone for Handler +impl Clone for Handler where - F: Factory, - R: Responder, + F: Factory, + R: Future, + O: Responder, { fn clone(&self) -> Self { - Self { + Handler { hnd: self.hnd.clone(), _t: PhantomData, } } } -impl Service for Handler +impl Service for Handler where - F: Factory, - R: Responder, + F: Factory, + R: Future, + O: Responder, { type Request = (T, HttpRequest); type Response = ServiceResponse; type Error = Infallible; - type Future = HandlerServiceResponse<::Future>; + type Future = HandlerServiceResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future { - let fut = self.hnd.call(param).respond_to(&req).into_future(); HandlerServiceResponse { - fut, - req: Some(req), - } - } -} - -pub struct HandlerServiceResponse { - fut: T, - req: Option, -} - -impl Future for HandlerServiceResponse -where - T: Future, - T::Error: Into, -{ - type Item = ServiceResponse; - type Error = Infallible; - - fn poll(&mut self) -> Poll { - match self.fut.poll() { - Ok(Async::Ready(res)) => Ok(Async::Ready(ServiceResponse::new( - self.req.take().unwrap(), - res, - ))), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { - let res: Response = e.into().into(); - Ok(Async::Ready(ServiceResponse::new( - self.req.take().unwrap(), - res, - ))) - } - } - } -} - -/// Async handler converter factory -pub trait AsyncFactory: Clone + 'static -where - R: IntoFuture, - R::Item: Responder, - R::Error: Into, -{ - fn call(&self, param: T) -> R; -} - -impl AsyncFactory<(), R> for F -where - F: Fn() -> R + Clone + 'static, - R: IntoFuture, - R::Item: Responder, - R::Error: Into, -{ - fn call(&self, _: ()) -> R { - (self)() - } -} - -#[doc(hidden)] -pub struct AsyncHandler -where - F: AsyncFactory, - R: IntoFuture, - R::Item: Responder, - R::Error: Into, -{ - hnd: F, - _t: PhantomData<(T, R)>, -} - -impl AsyncHandler -where - F: AsyncFactory, - R: IntoFuture, - R::Item: Responder, - R::Error: Into, -{ - pub fn new(hnd: F) -> Self { - AsyncHandler { - hnd, - _t: PhantomData, - } - } -} - -impl Clone for AsyncHandler -where - F: AsyncFactory, - R: IntoFuture, - R::Item: Responder, - R::Error: Into, -{ - fn clone(&self) -> Self { - AsyncHandler { - hnd: self.hnd.clone(), - _t: PhantomData, - } - } -} - -impl Service for AsyncHandler -where - F: AsyncFactory, - R: IntoFuture, - R::Item: Responder, - R::Error: Into, -{ - type Request = (T, HttpRequest); - type Response = ServiceResponse; - type Error = Infallible; - type Future = AsyncHandlerServiceResponse; - - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) - } - - fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future { - AsyncHandlerServiceResponse { - fut: self.hnd.call(param).into_future(), + fut: self.hnd.call(param), fut2: None, req: Some(req), } @@ -209,57 +99,49 @@ where } #[doc(hidden)] -pub struct AsyncHandlerServiceResponse +#[pin_project] +pub struct HandlerServiceResponse where - T: Future, - T::Item: Responder, + T: Future, + R: Responder, { + #[pin] fut: T, - fut2: Option<<::Future as IntoFuture>::Future>, + #[pin] + fut2: Option, req: Option, } -impl Future for AsyncHandlerServiceResponse +impl Future for HandlerServiceResponse where - T: Future, - T::Item: Responder, - T::Error: Into, + T: Future, + R: Responder, { - type Item = ServiceResponse; - type Error = Infallible; + type Output = Result; - fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut2 { - return match fut.poll() { - Ok(Async::Ready(res)) => Ok(Async::Ready(ServiceResponse::new( - self.req.take().unwrap(), - res, - ))), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.as_mut().project(); + + if let Some(fut) = this.fut2.as_pin_mut() { + return match fut.poll(cx) { + Poll::Ready(Ok(res)) => { + Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) + } + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => { let res: Response = e.into().into(); - Ok(Async::Ready(ServiceResponse::new( - self.req.take().unwrap(), - res, - ))) + Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) } }; } - match self.fut.poll() { - Ok(Async::Ready(res)) => { - self.fut2 = - Some(res.respond_to(self.req.as_ref().unwrap()).into_future()); - self.poll() - } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { - let res: Response = e.into().into(); - Ok(Async::Ready(ServiceResponse::new( - self.req.take().unwrap(), - res, - ))) + match this.fut.poll(cx) { + Poll::Ready(res) => { + let fut = res.respond_to(this.req.as_ref().unwrap()); + self.as_mut().project().fut2.set(Some(fut)); + self.poll(cx) } + Poll::Pending => Poll::Pending, } } } @@ -279,7 +161,7 @@ impl Extract { } } -impl NewService for Extract +impl ServiceFactory for Extract where S: Service< Request = (T, HttpRequest), @@ -293,7 +175,7 @@ where type Error = (Error, ServiceRequest); type InitError = (); type Service = ExtractService; - type Future = FutureResult; + type Future = Ready>; fn new_service(&self, _: &()) -> Self::Future { ok(ExtractService { @@ -321,13 +203,13 @@ where type Error = (Error, ServiceRequest); type Future = ExtractResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, req: ServiceRequest) -> Self::Future { let (req, mut payload) = req.into_parts(); - let fut = T::from_request(&req, &mut payload).into_future(); + let fut = T::from_request(&req, &mut payload); ExtractResponse { fut, @@ -338,10 +220,13 @@ where } } +#[pin_project] pub struct ExtractResponse { req: HttpRequest, service: S, - fut: ::Future, + #[pin] + fut: T::Future, + #[pin] fut_s: Option, } @@ -353,40 +238,35 @@ where Error = Infallible, >, { - type Item = ServiceResponse; - type Error = (Error, ServiceRequest); + type Output = Result; - fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut_s { - return fut.poll().map_err(|_| panic!()); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.as_mut().project(); + + if let Some(fut) = this.fut_s.as_pin_mut() { + return fut.poll(cx).map_err(|_| panic!()); } - let item = try_ready!(self.fut.poll().map_err(|e| { - let req = ServiceRequest::new(self.req.clone()); - (e.into(), req) - })); - - self.fut_s = Some(self.service.call((item, self.req.clone()))); - self.poll() + match ready!(this.fut.poll(cx)) { + Err(e) => { + let req = ServiceRequest::new(this.req.clone()); + Poll::Ready(Err((e.into(), req))) + } + Ok(item) => { + let fut = Some(this.service.call((item, this.req.clone()))); + self.as_mut().project().fut_s.set(fut); + self.poll(cx) + } + } } } /// FromRequest trait impl for tuples macro_rules! factory_tuple ({ $(($n:tt, $T:ident)),+} => { - impl Factory<($($T,)+), Res> for Func - where Func: Fn($($T,)+) -> Res + Clone, - Res: Responder, - { - fn call(&self, param: ($($T,)+)) -> Res { - (self)($(param.$n,)+) - } - } - - impl AsyncFactory<($($T,)+), Res> for Func + impl Factory<($($T,)+), Res, O> for Func where Func: Fn($($T,)+) -> Res + Clone + 'static, - Res: IntoFuture, - Res::Item: Responder, - Res::Error: Into, + Res: Future, + O: Responder, { fn call(&self, param: ($($T,)+)) -> Res { (self)($(param.$n,)+) diff --git a/src/info.rs b/src/info.rs index 61914516e..a9c3e4eeb 100644 --- a/src/info.rs +++ b/src/info.rs @@ -162,6 +162,12 @@ impl ConnectionInfo { /// - Forwarded /// - X-Forwarded-For /// - peer name of opened socket + /// + /// # Security + /// Do not use this function for security purposes, unless you can ensure the Forwarded and + /// X-Forwarded-For headers cannot be spoofed by the client. If you want the client's socket + /// address explicitly, use + /// [`HttpRequest::peer_addr()`](../web/struct.HttpRequest.html#method.peer_addr) instead. #[inline] pub fn remote(&self) -> Option<&str> { if let Some(ref r) = self.remote { diff --git a/src/lib.rs b/src/lib.rs index 60c34489e..8063d0d35 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -#![allow(clippy::borrow_interior_mutable_const)] +#![allow(clippy::borrow_interior_mutable_const, unused_imports, dead_code)] //! Actix web is a small, pragmatic, and extremely fast web framework //! for Rust. //! @@ -6,7 +6,7 @@ //! use actix_web::{web, App, Responder, HttpServer}; //! # use std::thread; //! -//! fn index(info: web::Path<(String, u32)>) -> impl Responder { +//! async fn index(info: web::Path<(String, u32)>) -> impl Responder { //! format!("Hello {}! id:{}", info.0, info.1) //! } //! @@ -68,8 +68,8 @@ //! ## Package feature //! //! * `client` - enables http client (default enabled) -//! * `ssl` - enables ssl support via `openssl` crate, supports `http/2` -//! * `rust-tls` - enables ssl support via `rustls` crate, supports `http/2` +//! * `openssl` - enables ssl support via `openssl` crate, supports `http/2` +//! * `rustls` - enables ssl support via `rustls` crate, supports `http/2` //! * `secure-cookies` - enables secure cookies support, includes `ring` crate as //! dependency //! * `brotli` - enables `brotli` compression support, requires `c` @@ -78,7 +78,6 @@ //! `c` compiler (default enabled) //! * `flate2-rust` - experimental rust based implementation for //! `gzip`, `deflate` compression. -//! * `uds` - Unix domain support, enables `HttpServer::bind_uds()` method. //! #![allow(clippy::type_complexity, clippy::new_without_default)] @@ -137,15 +136,16 @@ pub mod dev { pub use crate::config::{AppConfig, AppService}; #[doc(hidden)] - pub use crate::handler::{AsyncFactory, Factory}; + pub use crate::handler::Factory; pub use crate::info::ConnectionInfo; pub use crate::rmap::ResourceMap; pub use crate::service::{ HttpServiceFactory, ServiceRequest, ServiceResponse, WebService, }; - pub use crate::types::form::UrlEncoded; - pub use crate::types::json::JsonBody; - pub use crate::types::readlines::Readlines; + + //pub use crate::types::form::UrlEncoded; + //pub use crate::types::json::JsonBody; + //pub use crate::types::readlines::Readlines; pub use actix_http::body::{Body, BodySize, MessageBody, ResponseBody, SizedStream}; pub use actix_http::encoding::Decoder as Decompress; @@ -171,23 +171,20 @@ pub mod client { //! An HTTP Client //! //! ```rust - //! # use futures::future::{Future, lazy}; //! use actix_rt::System; //! use actix_web::client::Client; //! //! fn main() { - //! System::new("test").block_on(lazy(|| { + //! System::new("test").block_on(async { //! let mut client = Client::default(); //! - //! client.get("http://www.rust-lang.org") // <- Create request builder + //! // Create request builder and send request + //! let response = client.get("http://www.rust-lang.org") //! .header("User-Agent", "Actix-web") - //! .send() // <- Send http request - //! .map_err(|_| ()) - //! .and_then(|response| { // <- server http response - //! println!("Response: {:?}", response); - //! Ok(()) - //! }) - //! })); + //! .send().await; // <- Send http request + //! + //! println!("Response: {:?}", response); + //! }); //! } //! ``` pub use awc::error::{ diff --git a/src/middleware/compress.rs b/src/middleware/compress.rs index 86665d824..a697deaec 100644 --- a/src/middleware/compress.rs +++ b/src/middleware/compress.rs @@ -1,15 +1,18 @@ //! `Middleware` for compressing response body. use std::cmp; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; use std::str::FromStr; +use std::task::{Context, Poll}; use actix_http::body::MessageBody; use actix_http::encoding::Encoder; use actix_http::http::header::{ContentEncoding, ACCEPT_ENCODING}; use actix_http::{Error, Response, ResponseBuilder}; use actix_service::{Service, Transform}; -use futures::future::{ok, FutureResult}; -use futures::{Async, Future, Poll}; +use futures::future::{ok, Ready}; +use pin_project::pin_project; use crate::service::{ServiceRequest, ServiceResponse}; @@ -78,7 +81,7 @@ where type Error = Error; type InitError = (); type Transform = CompressMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(CompressMiddleware { @@ -103,8 +106,8 @@ where type Error = Error; type Future = CompressResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { @@ -128,11 +131,13 @@ where } #[doc(hidden)] +#[pin_project] pub struct CompressResponse where S: Service, B: MessageBody, { + #[pin] fut: S::Future, encoding: ContentEncoding, _t: PhantomData<(B)>, @@ -143,21 +148,25 @@ where B: MessageBody, S: Service, Error = Error>, { - type Item = ServiceResponse>; - type Error = Error; + type Output = Result>, Error>; - fn poll(&mut self) -> Poll { - let resp = futures::try_ready!(self.fut.poll()); + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); - let enc = if let Some(enc) = resp.response().extensions().get::() { - enc.0 - } else { - self.encoding - }; + match futures::ready!(this.fut.poll(cx)) { + Ok(resp) => { + let enc = if let Some(enc) = resp.response().extensions().get::() { + enc.0 + } else { + *this.encoding + }; - Ok(Async::Ready(resp.map_body(move |head, body| { - Encoder::response(enc, head, body) - }))) + Poll::Ready(Ok( + resp.map_body(move |head, body| Encoder::response(enc, head, body)) + )) + } + Err(e) => Poll::Ready(Err(e)), + } } } diff --git a/src/middleware/condition.rs b/src/middleware/condition.rs index ddc5fdd42..6603fc001 100644 --- a/src/middleware/condition.rs +++ b/src/middleware/condition.rs @@ -1,7 +1,8 @@ //! `Middleware` for conditionally enables another middleware. +use std::task::{Context, Poll}; + use actix_service::{Service, Transform}; -use futures::future::{ok, Either, FutureResult, Map}; -use futures::{Future, Poll}; +use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; /// `Middleware` for conditionally enables another middleware. /// The controled middleware must not change the `Service` interfaces. @@ -13,11 +14,11 @@ use futures::{Future, Poll}; /// use actix_web::middleware::{Condition, NormalizePath}; /// use actix_web::App; /// -/// fn main() { -/// let enable_normalize = std::env::var("NORMALIZE_PATH") == Ok("true".into()); -/// let app = App::new() -/// .wrap(Condition::new(enable_normalize, NormalizePath)); -/// } +/// # fn main() { +/// let enable_normalize = std::env::var("NORMALIZE_PATH") == Ok("true".into()); +/// let app = App::new() +/// .wrap(Condition::new(enable_normalize, NormalizePath)); +/// # } /// ``` pub struct Condition { trans: T, @@ -32,29 +33,31 @@ impl Condition { impl Transform for Condition where - S: Service, + S: Service + 'static, T: Transform, + T::Future: 'static, + T::InitError: 'static, + T::Transform: 'static, { type Request = S::Request; type Response = S::Response; type Error = S::Error; type InitError = T::InitError; type Transform = ConditionMiddleware; - type Future = Either< - Map Self::Transform>, - FutureResult, - >; + type Future = LocalBoxFuture<'static, Result>; fn new_transform(&self, service: S) -> Self::Future { if self.enable { - let f = self - .trans - .new_transform(service) - .map(ConditionMiddleware::Enable as fn(T::Transform) -> Self::Transform); - Either::A(f) + let f = self.trans.new_transform(service).map(|res| { + res.map( + ConditionMiddleware::Enable as fn(T::Transform) -> Self::Transform, + ) + }); + Either::Left(f) } else { - Either::B(ok(ConditionMiddleware::Disable(service))) + Either::Right(ok(ConditionMiddleware::Disable(service))) } + .boxed_local() } } @@ -73,19 +76,19 @@ where type Error = E::Error; type Future = Either; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { use ConditionMiddleware::*; match self { - Enable(service) => service.poll_ready(), - Disable(service) => service.poll_ready(), + Enable(service) => service.poll_ready(cx), + Disable(service) => service.poll_ready(cx), } } fn call(&mut self, req: E::Request) -> Self::Future { use ConditionMiddleware::*; match self { - Enable(service) => Either::A(service.call(req)), - Disable(service) => Either::B(service.call(req)), + Enable(service) => Either::Left(service.call(req)), + Disable(service) => Either::Right(service.call(req)), } } } @@ -99,7 +102,7 @@ mod tests { use crate::error::Result; use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; use crate::middleware::errhandlers::*; - use crate::test::{self, TestRequest}; + use crate::test::{self, block_on, TestRequest}; use crate::HttpResponse; fn render_500(mut res: ServiceResponse) -> Result> { @@ -111,33 +114,44 @@ mod tests { #[test] fn test_handler_enabled() { - let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::InternalServerError().finish()) - }; + block_on(async { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; - let mw = - ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + let mw = ErrorHandlers::new() + .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); - let mut mw = - test::block_on(Condition::new(true, mw).new_transform(srv.into_service())) + let mut mw = Condition::new(true, mw) + .new_transform(srv.into_service()) + .await .unwrap(); - let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()) + .await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + }) } + #[test] fn test_handler_disabled() { - let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::InternalServerError().finish()) - }; + block_on(async { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; - let mw = - ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + let mw = ErrorHandlers::new() + .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); - let mut mw = - test::block_on(Condition::new(false, mw).new_transform(srv.into_service())) + let mut mw = Condition::new(false, mw) + .new_transform(srv.into_service()) + .await .unwrap(); - let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); - assert_eq!(resp.headers().get(CONTENT_TYPE), None); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()) + .await; + assert_eq!(resp.headers().get(CONTENT_TYPE), None); + }) } } diff --git a/src/middleware/defaultheaders.rs b/src/middleware/defaultheaders.rs index ab2d36c2c..5c995503a 100644 --- a/src/middleware/defaultheaders.rs +++ b/src/middleware/defaultheaders.rs @@ -1,9 +1,11 @@ //! Middleware for setting default response headers +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_service::{Service, Transform}; -use futures::future::{ok, FutureResult}; -use futures::{Future, Poll}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; use crate::http::{HeaderMap, HttpTryFrom}; @@ -96,7 +98,7 @@ where type Error = Error; type InitError = (); type Transform = DefaultHeadersMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(DefaultHeadersMiddleware { @@ -119,16 +121,19 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { let inner = self.inner.clone(); + let fut = self.service.call(req); + + async move { + let mut res = fut.await?; - Box::new(self.service.call(req).map(move |mut res| { // set response headers for (key, value) in inner.headers.iter() { if !res.headers().contains_key(key) { @@ -142,15 +147,16 @@ where HeaderValue::from_static("application/octet-stream"), ); } - - res - })) + Ok(res) + } + .boxed_local() } } #[cfg(test)] mod tests { use actix_service::IntoService; + use futures::future::ok; use super::*; use crate::dev::ServiceRequest; @@ -160,46 +166,50 @@ mod tests { #[test] fn test_default_headers() { - let mut mw = block_on( - DefaultHeaders::new() + block_on(async { + let mut mw = DefaultHeaders::new() .header(CONTENT_TYPE, "0001") - .new_transform(ok_service()), - ) - .unwrap(); + .new_transform(ok_service()) + .await + .unwrap(); - let req = TestRequest::default().to_srv_request(); - let resp = block_on(mw.call(req)).unwrap(); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + let req = TestRequest::default().to_srv_request(); + let resp = mw.call(req).await.unwrap(); + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); - let req = TestRequest::default().to_srv_request(); - let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish()) - }; - let mut mw = block_on( - DefaultHeaders::new() + let req = TestRequest::default().to_srv_request(); + let srv = |req: ServiceRequest| { + ok(req.into_response( + HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish(), + )) + }; + let mut mw = DefaultHeaders::new() .header(CONTENT_TYPE, "0001") - .new_transform(srv.into_service()), - ) - .unwrap(); - let resp = block_on(mw.call(req)).unwrap(); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002"); + .new_transform(srv.into_service()) + .await + .unwrap(); + let resp = mw.call(req).await.unwrap(); + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002"); + }) } #[test] fn test_content_type() { - let srv = |req: ServiceRequest| req.into_response(HttpResponse::Ok().finish()); - let mut mw = block_on( - DefaultHeaders::new() + block_on(async { + let srv = + |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())); + let mut mw = DefaultHeaders::new() .content_type() - .new_transform(srv.into_service()), - ) - .unwrap(); + .new_transform(srv.into_service()) + .await + .unwrap(); - let req = TestRequest::default().to_srv_request(); - let resp = block_on(mw.call(req)).unwrap(); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - "application/octet-stream" - ); + let req = TestRequest::default().to_srv_request(); + let resp = mw.call(req).await.unwrap(); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + "application/octet-stream" + ); + }) } } diff --git a/src/middleware/errhandlers.rs b/src/middleware/errhandlers.rs index 5f73d4d7e..c8a702857 100644 --- a/src/middleware/errhandlers.rs +++ b/src/middleware/errhandlers.rs @@ -1,9 +1,9 @@ //! Custom handlers service for responses. use std::rc::Rc; +use std::task::{Context, Poll}; use actix_service::{Service, Transform}; -use futures::future::{err, ok, Either, Future, FutureResult}; -use futures::Poll; +use futures::future::{err, ok, Either, Future, FutureExt, LocalBoxFuture, Ready}; use hashbrown::HashMap; use crate::dev::{ServiceRequest, ServiceResponse}; @@ -15,7 +15,7 @@ pub enum ErrorHandlerResponse { /// New http response got generated Response(ServiceResponse), /// Result is a future that resolves to a new http response - Future(Box, Error = Error>>), + Future(LocalBoxFuture<'static, Result, Error>>), } type ErrorHandler = dyn Fn(ServiceResponse) -> Result>; @@ -39,17 +39,17 @@ type ErrorHandler = dyn Fn(ServiceResponse) -> Result { handlers: Rc>>>, @@ -92,7 +92,7 @@ where type Error = Error; type InitError = (); type Transform = ErrorHandlersMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(ErrorHandlersMiddleware { @@ -117,26 +117,30 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { let handlers = self.handlers.clone(); + let fut = self.service.call(req); + + async move { + let res = fut.await?; - Box::new(self.service.call(req).and_then(move |res| { if let Some(handler) = handlers.get(&res.status()) { match handler(res) { - Ok(ErrorHandlerResponse::Response(res)) => Either::A(ok(res)), - Ok(ErrorHandlerResponse::Future(fut)) => Either::B(fut), - Err(e) => Either::A(err(e)), + Ok(ErrorHandlerResponse::Response(res)) => Ok(res), + Ok(ErrorHandlerResponse::Future(fut)) => fut.await, + Err(e) => Err(e), } } else { - Either::A(ok(res)) + Ok(res) } - })) + } + .boxed_local() } } @@ -147,7 +151,7 @@ mod tests { use super::*; use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; - use crate::test::{self, TestRequest}; + use crate::test::{self, block_on, TestRequest}; use crate::HttpResponse; fn render_500(mut res: ServiceResponse) -> Result> { @@ -159,19 +163,22 @@ mod tests { #[test] fn test_handler() { - let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::InternalServerError().finish()) - }; + block_on(async { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; - let mut mw = test::block_on( - ErrorHandlers::new() + let mut mw = ErrorHandlers::new() .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500) - .new_transform(srv.into_service()), - ) - .unwrap(); + .new_transform(srv.into_service()) + .await + .unwrap(); - let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()) + .await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + }) } fn render_500_async( @@ -180,23 +187,26 @@ mod tests { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); - Ok(ErrorHandlerResponse::Future(Box::new(ok(res)))) + Ok(ErrorHandlerResponse::Future(ok(res).boxed_local())) } #[test] fn test_handler_async() { - let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::InternalServerError().finish()) - }; + block_on(async { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; - let mut mw = test::block_on( - ErrorHandlers::new() + let mut mw = ErrorHandlers::new() .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async) - .new_transform(srv.into_service()), - ) - .unwrap(); + .new_transform(srv.into_service()) + .await + .unwrap(); - let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()) + .await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + }) } } diff --git a/src/middleware/logger.rs b/src/middleware/logger.rs index f450f0481..45df4bf34 100644 --- a/src/middleware/logger.rs +++ b/src/middleware/logger.rs @@ -2,13 +2,15 @@ use std::collections::HashSet; use std::env; use std::fmt::{self, Display, Formatter}; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_service::{Service, Transform}; use bytes::Bytes; -use futures::future::{ok, FutureResult}; -use futures::{Async, Future, Poll}; +use futures::future::{ok, Ready}; use log::debug; use regex::Regex; use time; @@ -125,7 +127,7 @@ where type Error = Error; type InitError = (); type Transform = LoggerMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(LoggerMiddleware { @@ -151,8 +153,8 @@ where type Error = Error; type Future = LoggerResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { @@ -181,11 +183,13 @@ where } #[doc(hidden)] +#[pin_project::pin_project] pub struct LoggerResponse where B: MessageBody, S: Service, { + #[pin] fut: S::Future, time: time::Tm, format: Option, @@ -197,11 +201,15 @@ where B: MessageBody, S: Service, Error = Error>, { - type Item = ServiceResponse>; - type Error = Error; + type Output = Result>, Error>; - fn poll(&mut self) -> Poll { - let res = futures::try_ready!(self.fut.poll()); + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + + let res = match futures::ready!(this.fut.poll(cx)) { + Ok(res) => res, + Err(e) => return Poll::Ready(Err(e)), + }; if let Some(error) = res.response().error() { if res.response().head().status != StatusCode::INTERNAL_SERVER_ERROR { @@ -209,18 +217,21 @@ where } } - if let Some(ref mut format) = self.format { + if let Some(ref mut format) = this.format { for unit in &mut format.0 { unit.render_response(res.response()); } } - Ok(Async::Ready(res.map_body(move |_, body| { + let time = *this.time; + let format = this.format.take(); + + Poll::Ready(Ok(res.map_body(move |_, body| { ResponseBody::Body(StreamLog { body, + time, + format, size: 0, - time: self.time, - format: self.format.take(), }) }))) } @@ -252,13 +263,13 @@ impl MessageBody for StreamLog { self.body.size() } - fn poll_next(&mut self) -> Poll, Error> { - match self.body.poll_next()? { - Async::Ready(Some(chunk)) => { + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { + match self.body.poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { self.size += chunk.len(); - Ok(Async::Ready(Some(chunk))) + Poll::Ready(Some(Ok(chunk))) } - val => Ok(val), + val => val, } } } @@ -464,6 +475,7 @@ impl<'a> fmt::Display for FormatDisplay<'a> { #[cfg(test)] mod tests { use actix_service::{IntoService, Service, Transform}; + use futures::future::ok; use super::*; use crate::http::{header, StatusCode}; @@ -472,11 +484,11 @@ mod tests { #[test] fn test_logger() { let srv = |req: ServiceRequest| { - req.into_response( + ok(req.into_response( HttpResponse::build(StatusCode::OK) .header("X-Test", "ttt") .finish(), - ) + )) }; let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test"); diff --git a/src/middleware/normalize.rs b/src/middleware/normalize.rs index 9cfbefb30..b7eb1384a 100644 --- a/src/middleware/normalize.rs +++ b/src/middleware/normalize.rs @@ -1,9 +1,10 @@ //! `Middleware` to normalize request's URI +use std::task::{Context, Poll}; use actix_http::http::{HttpTryFrom, PathAndQuery, Uri}; use actix_service::{Service, Transform}; use bytes::Bytes; -use futures::future::{self, FutureResult}; +use futures::future::{ok, Ready}; use regex::Regex; use crate::service::{ServiceRequest, ServiceResponse}; @@ -19,15 +20,15 @@ use crate::Error; /// ```rust /// use actix_web::{web, http, middleware, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new() -/// .wrap(middleware::NormalizePath) -/// .service( -/// web::resource("/test") -/// .route(web::get().to(|| HttpResponse::Ok())) -/// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) -/// ); -/// } +/// # fn main() { +/// let app = App::new() +/// .wrap(middleware::NormalizePath) +/// .service( +/// web::resource("/test") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) +/// ); +/// # } /// ``` pub struct NormalizePath; @@ -42,10 +43,10 @@ where type Error = Error; type InitError = (); type Transform = NormalizePathNormalization; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - future::ok(NormalizePathNormalization { + ok(NormalizePathNormalization { service, merge_slash: Regex::new("//+").unwrap(), }) @@ -67,8 +68,8 @@ where type Error = Error; type Future = S::Future; - fn poll_ready(&mut self) -> futures::Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, mut req: ServiceRequest) -> Self::Future { @@ -109,46 +110,57 @@ mod tests { #[test] fn test_wrap() { - let mut app = init_service( - App::new() - .wrap(NormalizePath::default()) - .service(web::resource("/v1/something/").to(|| HttpResponse::Ok())), - ); + block_on(async { + let mut app = init_service( + App::new() + .wrap(NormalizePath::default()) + .service(web::resource("/v1/something/").to(|| HttpResponse::Ok())), + ) + .await; - let req = TestRequest::with_uri("/v1//something////").to_request(); - let res = call_service(&mut app, req); - assert!(res.status().is_success()); + let req = TestRequest::with_uri("/v1//something////").to_request(); + let res = call_service(&mut app, req).await; + assert!(res.status().is_success()); + }) } #[test] fn test_in_place_normalization() { - let srv = |req: ServiceRequest| { - assert_eq!("/v1/something/", req.path()); - req.into_response(HttpResponse::Ok().finish()) - }; + block_on(async { + let srv = |req: ServiceRequest| { + assert_eq!("/v1/something/", req.path()); + ok(req.into_response(HttpResponse::Ok().finish())) + }; - let mut normalize = - block_on(NormalizePath.new_transform(srv.into_service())).unwrap(); + let mut normalize = NormalizePath + .new_transform(srv.into_service()) + .await + .unwrap(); - let req = TestRequest::with_uri("/v1//something////").to_srv_request(); - let res = block_on(normalize.call(req)).unwrap(); - assert!(res.status().is_success()); + let req = TestRequest::with_uri("/v1//something////").to_srv_request(); + let res = normalize.call(req).await.unwrap(); + assert!(res.status().is_success()); + }) } #[test] fn should_normalize_nothing() { - const URI: &str = "/v1/something/"; + block_on(async { + const URI: &str = "/v1/something/"; - let srv = |req: ServiceRequest| { - assert_eq!(URI, req.path()); - req.into_response(HttpResponse::Ok().finish()) - }; + let srv = |req: ServiceRequest| { + assert_eq!(URI, req.path()); + ok(req.into_response(HttpResponse::Ok().finish())) + }; - let mut normalize = - block_on(NormalizePath.new_transform(srv.into_service())).unwrap(); + let mut normalize = NormalizePath + .new_transform(srv.into_service()) + .await + .unwrap(); - let req = TestRequest::with_uri(URI).to_srv_request(); - let res = block_on(normalize.call(req)).unwrap(); - assert!(res.status().is_success()); + let req = TestRequest::with_uri(URI).to_srv_request(); + let res = normalize.call(req).await.unwrap(); + assert!(res.status().is_success()); + }) } } diff --git a/src/request.rs b/src/request.rs index 6d9d26e8c..19072fcb1 100644 --- a/src/request.rs +++ b/src/request.rs @@ -5,6 +5,7 @@ use std::{fmt, net}; use actix_http::http::{HeaderMap, Method, Uri, Version}; use actix_http::{Error, Extensions, HttpMessage, Message, Payload, RequestHead}; use actix_router::{Path, Url}; +use futures::future::{ok, Ready}; use crate::config::AppConfig; use crate::data::Data; @@ -271,11 +272,11 @@ impl Drop for HttpRequest { /// ## Example /// /// ```rust -/// # #[macro_use] extern crate serde_derive; /// use actix_web::{web, App, HttpRequest}; +/// use serde_derive::Deserialize; /// /// /// extract `Thing` from request -/// fn index(req: HttpRequest) -> String { +/// async fn index(req: HttpRequest) -> String { /// format!("Got thing: {:?}", req) /// } /// @@ -289,11 +290,11 @@ impl Drop for HttpRequest { impl FromRequest for HttpRequest { type Config = (); type Error = Error; - type Future = Result; + type Future = Ready>; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { - Ok(req.clone()) + ok(req.clone()) } } @@ -349,7 +350,7 @@ mod tests { use super::*; use crate::dev::{ResourceDef, ResourceMap}; use crate::http::{header, StatusCode}; - use crate::test::{call_service, init_service, TestRequest}; + use crate::test::{block_on, call_service, init_service, TestRequest}; use crate::{web, App, HttpResponse}; #[test] @@ -467,66 +468,73 @@ mod tests { #[test] fn test_app_data() { - let mut srv = init_service(App::new().data(10usize).service( - web::resource("/").to(|req: HttpRequest| { - if req.app_data::().is_some() { - HttpResponse::Ok() - } else { - HttpResponse::BadRequest() - } - }), - )); + block_on(async { + let mut srv = init_service(App::new().data(10usize).service( + web::resource("/").to(|req: HttpRequest| { + if req.app_data::().is_some() { + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + }), + )) + .await; - let req = TestRequest::default().to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::default().to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); - let mut srv = init_service(App::new().data(10u32).service( - web::resource("/").to(|req: HttpRequest| { - if req.app_data::().is_some() { - HttpResponse::Ok() - } else { - HttpResponse::BadRequest() - } - }), - )); + let mut srv = init_service(App::new().data(10u32).service( + web::resource("/").to(|req: HttpRequest| { + if req.app_data::().is_some() { + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + }), + )) + .await; - let req = TestRequest::default().to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let req = TestRequest::default().to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + }) } #[test] fn test_extensions_dropped() { - struct Tracker { - pub dropped: bool, - } - struct Foo { - tracker: Rc>, - } - impl Drop for Foo { - fn drop(&mut self) { - self.tracker.borrow_mut().dropped = true; + block_on(async { + struct Tracker { + pub dropped: bool, + } + struct Foo { + tracker: Rc>, + } + impl Drop for Foo { + fn drop(&mut self) { + self.tracker.borrow_mut().dropped = true; + } } - } - let tracker = Rc::new(RefCell::new(Tracker { dropped: false })); - { - let tracker2 = Rc::clone(&tracker); - let mut srv = init_service(App::new().data(10u32).service( - web::resource("/").to(move |req: HttpRequest| { - req.extensions_mut().insert(Foo { - tracker: Rc::clone(&tracker2), - }); - HttpResponse::Ok() - }), - )); + let tracker = Rc::new(RefCell::new(Tracker { dropped: false })); + { + let tracker2 = Rc::clone(&tracker); + let mut srv = init_service(App::new().data(10u32).service( + web::resource("/").to(move |req: HttpRequest| { + req.extensions_mut().insert(Foo { + tracker: Rc::clone(&tracker2), + }); + HttpResponse::Ok() + }), + )) + .await; - let req = TestRequest::default().to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - } + let req = TestRequest::default().to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } - assert!(tracker.borrow().dropped); + assert!(tracker.borrow().dropped); + }) } } diff --git a/src/resource.rs b/src/resource.rs index 3ee0167a0..a06530d48 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -1,20 +1,23 @@ use std::cell::RefCell; use std::fmt; +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_http::{Error, Extensions, Response}; use actix_service::boxed::{self, BoxedNewService, BoxedService}; use actix_service::{ - apply_transform, IntoNewService, IntoTransform, NewService, Service, Transform, + apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, Transform, }; -use futures::future::{ok, Either, FutureResult}; -use futures::{Async, Future, IntoFuture, Poll}; +use futures::future::{ok, Either, LocalBoxFuture, Ready}; +use pin_project::pin_project; use crate::data::Data; use crate::dev::{insert_slash, AppService, HttpServiceFactory, ResourceDef}; use crate::extract::FromRequest; use crate::guard::Guard; -use crate::handler::{AsyncFactory, Factory}; +use crate::handler::Factory; use crate::responder::Responder; use crate::route::{CreateRouteService, Route, RouteService}; use crate::service::{ServiceRequest, ServiceResponse}; @@ -74,7 +77,7 @@ impl Resource { impl Resource where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -95,7 +98,7 @@ where /// ```rust /// use actix_web::{web, guard, App, HttpResponse}; /// - /// fn index(data: web::Path<(String, String)>) -> &'static str { + /// async fn index(data: web::Path<(String, String)>) -> &'static str { /// "Welcome!" /// } /// @@ -153,9 +156,9 @@ where /// .route(web::delete().to(delete_handler)) /// ); /// } - /// # fn get_handler() {} - /// # fn post_handler() {} - /// # fn delete_handler() {} + /// # async fn get_handler() -> impl actix_web::Responder { HttpResponse::Ok() } + /// # async fn post_handler() -> impl actix_web::Responder { HttpResponse::Ok() } + /// # async fn delete_handler() -> impl actix_web::Responder { HttpResponse::Ok() } /// ``` pub fn route(mut self, route: Route) -> Self { self.routes.push(route); @@ -171,7 +174,7 @@ where /// use actix_web::{web, App, FromRequest}; /// /// /// extract text data from request - /// fn index(body: String) -> String { + /// async fn index(body: String) -> String { /// format!("Body {}!", body) /// } /// @@ -227,52 +230,17 @@ where /// # fn index(req: HttpRequest) -> HttpResponse { unimplemented!() } /// App::new().service(web::resource("/").route(web::route().to(index))); /// ``` - pub fn to(mut self, handler: F) -> Self + pub fn to(mut self, handler: F) -> Self where - F: Factory + 'static, + F: Factory, I: FromRequest + 'static, - R: Responder + 'static, + R: Future + 'static, + U: Responder + 'static, { self.routes.push(Route::new().to(handler)); self } - /// Register a new route and add async handler. - /// - /// ```rust - /// use actix_web::*; - /// use futures::future::{ok, Future}; - /// - /// fn index(req: HttpRequest) -> impl Future { - /// ok(HttpResponse::Ok().finish()) - /// } - /// - /// App::new().service(web::resource("/").to_async(index)); - /// ``` - /// - /// This is shortcut for: - /// - /// ```rust - /// # use actix_web::*; - /// # use futures::future::Future; - /// # fn index(req: HttpRequest) -> Box> { - /// # unimplemented!() - /// # } - /// App::new().service(web::resource("/").route(web::route().to_async(index))); - /// ``` - #[allow(clippy::wrong_self_convention)] - pub fn to_async(mut self, handler: F) -> Self - where - F: AsyncFactory, - I: FromRequest + 'static, - R: IntoFuture + 'static, - R::Item: Responder, - R::Error: Into, - { - self.routes.push(Route::new().to_async(handler)); - self - } - /// Register a resource middleware. /// /// This is similar to `App's` middlewares, but middleware get invoked on resource level. @@ -280,11 +248,11 @@ where /// type (i.e modify response's body). /// /// **Note**: middlewares get called in opposite order of middlewares registration. - pub fn wrap( + pub fn wrap( self, - mw: F, + mw: M, ) -> Resource< - impl NewService< + impl ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -300,11 +268,9 @@ where Error = Error, InitError = (), >, - F: IntoTransform, { - let endpoint = apply_transform(mw, self.endpoint); Resource { - endpoint, + endpoint: apply(mw, self.endpoint), rdef: self.rdef, name: self.name, guards: self.guards, @@ -326,24 +292,26 @@ where /// /// ```rust /// use actix_service::Service; - /// # use futures::Future; /// use actix_web::{web, App}; /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; /// - /// fn index() -> &'static str { + /// async fn index() -> &'static str { /// "Welcome!" /// } /// /// fn main() { /// let app = App::new().service( /// web::resource("/index.html") - /// .wrap_fn(|req, srv| - /// srv.call(req).map(|mut res| { + /// .wrap_fn(|req, srv| { + /// let fut = srv.call(req); + /// async { + /// let mut res = fut.await?; /// res.headers_mut().insert( /// CONTENT_TYPE, HeaderValue::from_static("text/plain"), /// ); - /// res - /// })) + /// Ok(res) + /// } + /// }) /// .route(web::get().to(index))); /// } /// ``` @@ -351,7 +319,7 @@ where self, mw: F, ) -> Resource< - impl NewService< + impl ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -361,9 +329,18 @@ where > where F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, - R: IntoFuture, + R: Future>, { - self.wrap(mw) + Resource { + endpoint: apply_fn_factory(self.endpoint, mw), + rdef: self.rdef, + name: self.name, + guards: self.guards, + routes: self.routes, + default: self.default, + data: self.data, + factory_ref: self.factory_ref, + } } /// Default service to be used if no matching route could be found. @@ -371,8 +348,8 @@ where /// default handler from `App` or `Scope`. pub fn default_service(mut self, f: F) -> Self where - F: IntoNewService, - U: NewService< + F: IntoServiceFactory, + U: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -381,8 +358,8 @@ where U::InitError: fmt::Debug, { // create and configure default resource - self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::new_service( - f.into_new_service().map_init_err(|e| { + self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::factory( + f.into_factory().map_init_err(|e| { log::error!("Can not construct default service: {:?}", e) }), ))))); @@ -393,7 +370,7 @@ where impl HttpServiceFactory for Resource where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -423,9 +400,9 @@ where } } -impl IntoNewService for Resource +impl IntoServiceFactory for Resource where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -433,7 +410,7 @@ where InitError = (), >, { - fn into_new_service(self) -> T { + fn into_factory(self) -> T { *self.factory_ref.borrow_mut() = Some(ResourceFactory { routes: self.routes, data: self.data.map(Rc::new), @@ -450,7 +427,7 @@ pub struct ResourceFactory { default: Rc>>>, } -impl NewService for ResourceFactory { +impl ServiceFactory for ResourceFactory { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -488,31 +465,30 @@ pub struct CreateResourceService { fut: Vec, data: Option>, default: Option, - default_fut: Option>>, + default_fut: Option>>, } impl Future for CreateResourceService { - type Item = ResourceService; - type Error = (); + type Output = Result; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut done = true; if let Some(ref mut fut) = self.default_fut { - match fut.poll()? { - Async::Ready(default) => self.default = Some(default), - Async::NotReady => done = false, + match Pin::new(fut).poll(cx)? { + Poll::Ready(default) => self.default = Some(default), + Poll::Pending => done = false, } } // poll http services for item in &mut self.fut { match item { - CreateRouteServiceItem::Future(ref mut fut) => match fut.poll()? { - Async::Ready(route) => { - *item = CreateRouteServiceItem::Service(route) - } - Async::NotReady => { + CreateRouteServiceItem::Future(ref mut fut) => match Pin::new(fut) + .poll(cx)? + { + Poll::Ready(route) => *item = CreateRouteServiceItem::Service(route), + Poll::Pending => { done = false; } }, @@ -529,13 +505,13 @@ impl Future for CreateResourceService { CreateRouteServiceItem::Future(_) => unreachable!(), }) .collect(); - Ok(Async::Ready(ResourceService { + Poll::Ready(Ok(ResourceService { routes, data: self.data.clone(), default: self.default.take(), })) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -551,12 +527,12 @@ impl Service for ResourceService { type Response = ServiceResponse; type Error = Error; type Future = Either< - FutureResult, - Box>, + Ready>, + LocalBoxFuture<'static, Result>, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, mut req: ServiceRequest) -> Self::Future { @@ -565,14 +541,14 @@ impl Service for ResourceService { if let Some(ref data) = self.data { req.set_data_container(data.clone()); } - return route.call(req); + return Either::Right(route.call(req)); } } if let Some(ref mut default) = self.default { - default.call(req) + Either::Right(default.call(req)) } else { let req = req.into_parts().0; - Either::A(ok(ServiceResponse::new( + Either::Left(ok(ServiceResponse::new( req, Response::MethodNotAllowed().finish(), ))) @@ -591,7 +567,7 @@ impl ResourceEndpoint { } } -impl NewService for ResourceEndpoint { +impl ServiceFactory for ResourceEndpoint { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -610,18 +586,19 @@ mod tests { use std::time::Duration; use actix_service::Service; - use futures::{Future, IntoFuture}; - use tokio_timer::sleep; + use futures::future::{ok, Future}; + use tokio_timer::delay_for; use crate::http::{header, HeaderValue, Method, StatusCode}; + use crate::middleware::DefaultHeaders; use crate::service::{ServiceRequest, ServiceResponse}; - use crate::test::{call_service, init_service, TestRequest}; + use crate::test::{block_on, call_service, init_service, TestRequest}; use crate::{guard, web, App, Error, HttpResponse}; fn md( req: ServiceRequest, srv: &mut S, - ) -> impl IntoFuture, Error = Error> + ) -> impl Future, Error>> where S: Service< Request = ServiceRequest, @@ -629,178 +606,209 @@ mod tests { Error = Error, >, { - srv.call(req).map(|mut res| { + let fut = srv.call(req); + async move { + let mut res = fut.await?; res.headers_mut() .insert(header::CONTENT_TYPE, HeaderValue::from_static("0001")); - res - }) + Ok(res) + } } #[test] fn test_middleware() { - let mut srv = init_service( - App::new().service( - web::resource("/test") - .name("test") - .wrap(md) - .route(web::get().to(|| HttpResponse::Ok())), - ), - ); - let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - HeaderValue::from_static("0001") - ); + block_on(async { + let mut srv = init_service( + App::new().service( + web::resource("/test") + .name("test") + .wrap(DefaultHeaders::new().header( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + )) + .route(web::get().to(|| HttpResponse::Ok())), + ), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + }) } #[test] fn test_middleware_fn() { - let mut srv = init_service( - App::new().service( - web::resource("/test") - .wrap_fn(|req, srv| { - srv.call(req).map(|mut res| { - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("0001"), - ); - res + block_on(async { + let mut srv = init_service( + App::new().service( + web::resource("/test") + .wrap_fn(|req, srv| { + let fut = srv.call(req); + async { + fut.await.map(|mut res| { + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + ); + res + }) + } }) - }) - .route(web::get().to(|| HttpResponse::Ok())), - ), - ); - let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - HeaderValue::from_static("0001") - ); + .route(web::get().to(|| HttpResponse::Ok())), + ), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + }) } #[test] - fn test_to_async() { - let mut srv = - init_service(App::new().service(web::resource("/test").to_async(|| { - sleep(Duration::from_millis(100)).then(|_| HttpResponse::Ok()) - }))); - let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); + fn test_to() { + block_on(async { + let mut srv = + init_service(App::new().service(web::resource("/test").to(|| { + async { + delay_for(Duration::from_millis(100)).await; + Ok::<_, Error>(HttpResponse::Ok()) + } + }))) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_default_resource() { - let mut srv = init_service( - App::new() - .service( - web::resource("/test").route(web::get().to(|| HttpResponse::Ok())), - ) - .default_service(|r: ServiceRequest| { - r.into_response(HttpResponse::BadRequest()) - }), - ); - let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - - let req = TestRequest::with_uri("/test") - .method(Method::POST) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - - let mut srv = init_service( - App::new().service( - web::resource("/test") - .route(web::get().to(|| HttpResponse::Ok())) + block_on(async { + let mut srv = init_service( + App::new() + .service( + web::resource("/test") + .route(web::get().to(|| HttpResponse::Ok())), + ) .default_service(|r: ServiceRequest| { - r.into_response(HttpResponse::BadRequest()) + ok(r.into_response(HttpResponse::BadRequest())) }), - ), - ); + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); - let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/test") + .method(Method::POST) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - let req = TestRequest::with_uri("/test") - .method(Method::POST) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let mut srv = init_service( + App::new().service( + web::resource("/test") + .route(web::get().to(|| HttpResponse::Ok())) + .default_service(|r: ServiceRequest| { + ok(r.into_response(HttpResponse::BadRequest())) + }), + ), + ) + .await; + + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/test") + .method(Method::POST) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + }) } #[test] fn test_resource_guards() { - let mut srv = init_service( - App::new() - .service( - web::resource("/test/{p}") - .guard(guard::Get()) - .to(|| HttpResponse::Ok()), - ) - .service( - web::resource("/test/{p}") - .guard(guard::Put()) - .to(|| HttpResponse::Created()), - ) - .service( - web::resource("/test/{p}") - .guard(guard::Delete()) - .to(|| HttpResponse::NoContent()), - ), - ); + block_on(async { + let mut srv = init_service( + App::new() + .service( + web::resource("/test/{p}") + .guard(guard::Get()) + .to(|| HttpResponse::Ok()), + ) + .service( + web::resource("/test/{p}") + .guard(guard::Put()) + .to(|| HttpResponse::Created()), + ) + .service( + web::resource("/test/{p}") + .guard(guard::Delete()) + .to(|| HttpResponse::NoContent()), + ), + ) + .await; - let req = TestRequest::with_uri("/test/it") - .method(Method::GET) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/test/it") + .method(Method::GET) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); - let req = TestRequest::with_uri("/test/it") - .method(Method::PUT) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::CREATED); + let req = TestRequest::with_uri("/test/it") + .method(Method::PUT) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::CREATED); - let req = TestRequest::with_uri("/test/it") - .method(Method::DELETE) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::NO_CONTENT); + let req = TestRequest::with_uri("/test/it") + .method(Method::DELETE) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::NO_CONTENT); + }) } #[test] fn test_data() { - let mut srv = init_service( - App::new() - .data(1.0f64) - .data(1usize) - .register_data(web::Data::new('-')) - .service( - web::resource("/test") - .data(10usize) - .register_data(web::Data::new('*')) - .guard(guard::Get()) - .to( - |data1: web::Data, - data2: web::Data, - data3: web::Data| { - assert_eq!(*data1, 10); - assert_eq!(*data2, '*'); - assert_eq!(*data3, 1.0); - HttpResponse::Ok() - }, - ), - ), - ); + block_on(async { + let mut srv = init_service( + App::new() + .data(1.0f64) + .data(1usize) + .register_data(web::Data::new('-')) + .service( + web::resource("/test") + .data(10usize) + .register_data(web::Data::new('*')) + .guard(guard::Get()) + .to( + |data1: web::Data, + data2: web::Data, + data3: web::Data| { + assert_eq!(*data1, 10); + assert_eq!(*data2, '*'); + assert_eq!(*data3, 1.0); + HttpResponse::Ok() + }, + ), + ), + ) + .await; - let req = TestRequest::get().uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::get().uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } } diff --git a/src/responder.rs b/src/responder.rs index 4988ad5bc..b254567de 100644 --- a/src/responder.rs +++ b/src/responder.rs @@ -1,3 +1,8 @@ +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + use actix_http::error::InternalError; use actix_http::http::{ header::IntoHeaderValue, Error as HttpError, HeaderMap, HeaderName, HttpTryFrom, @@ -5,8 +10,9 @@ use actix_http::http::{ }; use actix_http::{Error, Response, ResponseBuilder}; use bytes::{Bytes, BytesMut}; -use futures::future::{err, ok, Either as EitherFuture, FutureResult}; -use futures::{try_ready, Async, Future, IntoFuture, Poll}; +use futures::future::{err, ok, Either as EitherFuture, LocalBoxFuture, Ready}; +use futures::ready; +use pin_project::{pin_project, project}; use crate::request::HttpRequest; @@ -18,7 +24,7 @@ pub trait Responder { type Error: Into; /// The future response value. - type Future: IntoFuture; + type Future: Future>; /// Convert itself to `AsyncResult` or `Error`. fn respond_to(self, req: &HttpRequest) -> Self::Future; @@ -71,7 +77,7 @@ pub trait Responder { impl Responder for Response { type Error = Error; - type Future = FutureResult; + type Future = Ready>; #[inline] fn respond_to(self, _: &HttpRequest) -> Self::Future { @@ -84,15 +90,14 @@ where T: Responder, { type Error = T::Error; - type Future = EitherFuture< - ::Future, - FutureResult, - >; + type Future = EitherFuture>>; fn respond_to(self, req: &HttpRequest) -> Self::Future { match self { - Some(t) => EitherFuture::A(t.respond_to(req).into_future()), - None => EitherFuture::B(ok(Response::build(StatusCode::NOT_FOUND).finish())), + Some(t) => EitherFuture::Left(t.respond_to(req)), + None => { + EitherFuture::Right(ok(Response::build(StatusCode::NOT_FOUND).finish())) + } } } } @@ -104,23 +109,21 @@ where { type Error = Error; type Future = EitherFuture< - ResponseFuture<::Future>, - FutureResult, + ResponseFuture, + Ready>, >; fn respond_to(self, req: &HttpRequest) -> Self::Future { match self { - Ok(val) => { - EitherFuture::A(ResponseFuture::new(val.respond_to(req).into_future())) - } - Err(e) => EitherFuture::B(err(e.into())), + Ok(val) => EitherFuture::Left(ResponseFuture::new(val.respond_to(req))), + Err(e) => EitherFuture::Right(err(e.into())), } } } impl Responder for ResponseBuilder { type Error = Error; - type Future = FutureResult; + type Future = Ready>; #[inline] fn respond_to(mut self, _: &HttpRequest) -> Self::Future { @@ -130,7 +133,7 @@ impl Responder for ResponseBuilder { impl Responder for () { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK).finish()) @@ -146,7 +149,7 @@ where fn respond_to(self, req: &HttpRequest) -> Self::Future { CustomResponderFut { - fut: self.0.respond_to(req).into_future(), + fut: self.0.respond_to(req), status: Some(self.1), headers: None, } @@ -155,7 +158,7 @@ where impl Responder for &'static str { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK) @@ -166,7 +169,7 @@ impl Responder for &'static str { impl Responder for &'static [u8] { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK) @@ -177,7 +180,7 @@ impl Responder for &'static [u8] { impl Responder for String { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK) @@ -188,7 +191,7 @@ impl Responder for String { impl<'a> Responder for &'a String { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK) @@ -199,7 +202,7 @@ impl<'a> Responder for &'a String { impl Responder for Bytes { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK) @@ -210,7 +213,7 @@ impl Responder for Bytes { impl Responder for BytesMut { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK) @@ -299,45 +302,49 @@ impl Responder for CustomResponder { fn respond_to(self, req: &HttpRequest) -> Self::Future { CustomResponderFut { - fut: self.responder.respond_to(req).into_future(), + fut: self.responder.respond_to(req), status: self.status, headers: self.headers, } } } +#[pin_project] pub struct CustomResponderFut { - fut: ::Future, + #[pin] + fut: T::Future, status: Option, headers: Option, } impl Future for CustomResponderFut { - type Item = Response; - type Error = T::Error; + type Output = Result; - fn poll(&mut self) -> Poll { - let mut res = try_ready!(self.fut.poll()); - if let Some(status) = self.status { + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + + let mut res = match ready!(this.fut.poll(cx)) { + Ok(res) => res, + Err(e) => return Poll::Ready(Err(e)), + }; + if let Some(status) = this.status.take() { *res.status_mut() = status; } - if let Some(ref headers) = self.headers { + if let Some(ref headers) = this.headers { for (k, v) in headers { res.headers_mut().insert(k.clone(), v.clone()); } } - Ok(Async::Ready(res)) + Poll::Ready(Ok(res)) } } /// Combines two different responder types into a single type /// /// ```rust -/// # use futures::future::{ok, Future}; /// use actix_web::{Either, Error, HttpResponse}; /// -/// type RegisterResult = -/// Either>>; +/// type RegisterResult = Either>; /// /// fn index() -> RegisterResult { /// if is_a_variant() { @@ -346,9 +353,9 @@ impl Future for CustomResponderFut { /// } else { /// Either::B( /// // <- Right variant -/// Box::new(ok(HttpResponse::Ok() +/// Ok(HttpResponse::Ok() /// .content_type("text/html") -/// .body("Hello!"))) +/// .body("Hello!")) /// ) /// } /// } @@ -369,97 +376,85 @@ where B: Responder, { type Error = Error; - type Future = EitherResponder< - ::Future, - ::Future, - >; + type Future = EitherResponder; fn respond_to(self, req: &HttpRequest) -> Self::Future { match self { - Either::A(a) => EitherResponder::A(a.respond_to(req).into_future()), - Either::B(b) => EitherResponder::B(b.respond_to(req).into_future()), + Either::A(a) => EitherResponder::A(a.respond_to(req)), + Either::B(b) => EitherResponder::B(b.respond_to(req)), } } } +#[pin_project] pub enum EitherResponder where - A: Future, - A::Error: Into, - B: Future, - B::Error: Into, + A: Responder, + B: Responder, { - A(A), - B(B), + A(#[pin] A::Future), + B(#[pin] B::Future), } impl Future for EitherResponder where - A: Future, - A::Error: Into, - B: Future, - B::Error: Into, + A: Responder, + B: Responder, { - type Item = Response; - type Error = Error; + type Output = Result; - fn poll(&mut self) -> Poll { - match self { - EitherResponder::A(ref mut fut) => Ok(fut.poll().map_err(|e| e.into())?), - EitherResponder::B(ref mut fut) => Ok(fut.poll().map_err(|e| e.into())?), + #[project] + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + #[project] + match self.project() { + EitherResponder::A(fut) => { + Poll::Ready(ready!(fut.poll(cx)).map_err(|e| e.into())) + } + EitherResponder::B(fut) => { + Poll::Ready(ready!(fut.poll(cx).map_err(|e| e.into()))) + } } } } -impl Responder for Box> -where - I: Responder + 'static, - E: Into + 'static, -{ - type Error = Error; - type Future = Box>; - - #[inline] - fn respond_to(self, req: &HttpRequest) -> Self::Future { - let req = req.clone(); - Box::new( - self.map_err(|e| e.into()) - .and_then(move |r| ResponseFuture(r.respond_to(&req).into_future())), - ) - } -} - impl Responder for InternalError where T: std::fmt::Debug + std::fmt::Display + 'static, { type Error = Error; - type Future = Result; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { let err: Error = self.into(); - Ok(err.into()) + ok(err.into()) } } -pub struct ResponseFuture(T); +#[pin_project] +pub struct ResponseFuture { + #[pin] + fut: T, + _t: PhantomData, +} -impl ResponseFuture { +impl ResponseFuture { pub fn new(fut: T) -> Self { - ResponseFuture(fut) + ResponseFuture { + fut, + _t: PhantomData, + } } } -impl Future for ResponseFuture +impl Future for ResponseFuture where - T: Future, - T::Error: Into, + T: Future>, + E: Into, { - type Item = Response; - type Error = Error; + type Output = Result; - fn poll(&mut self) -> Poll { - Ok(self.0.poll().map_err(|e| e.into())?) + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Poll::Ready(ready!(self.project().fut.poll(cx)).map_err(|e| e.into())) } } @@ -476,26 +471,32 @@ pub(crate) mod tests { #[test] fn test_option_responder() { - let mut srv = init_service( - App::new() - .service(web::resource("/none").to(|| -> Option<&'static str> { None })) - .service(web::resource("/some").to(|| Some("some"))), - ); + block_on(async { + let mut srv = init_service( + App::new() + .service( + web::resource("/none") + .to(|| async { Option::<&'static str>::None }), + ) + .service(web::resource("/some").to(|| async { Some("some") })), + ) + .await; - let req = TestRequest::with_uri("/none").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::with_uri("/none").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let req = TestRequest::with_uri("/some").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - match resp.response().body() { - ResponseBody::Body(Body::Bytes(ref b)) => { - let bytes: Bytes = b.clone().into(); - assert_eq!(bytes, Bytes::from_static(b"some")); + let req = TestRequest::with_uri("/some").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + match resp.response().body() { + ResponseBody::Body(Body::Bytes(ref b)) => { + let bytes: Bytes = b.clone().into(); + assert_eq!(bytes, Bytes::from_static(b"some")); + } + _ => panic!(), } - _ => panic!(), - } + }) } pub(crate) trait BodyTest { @@ -526,142 +527,155 @@ pub(crate) mod tests { #[test] fn test_responder() { - let req = TestRequest::default().to_http_request(); + block_on(async { + let req = TestRequest::default().to_http_request(); - let resp: HttpResponse = block_on(().respond_to(&req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(*resp.body().body(), Body::Empty); + let resp: HttpResponse = ().respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(*resp.body().body(), Body::Empty); - let resp: HttpResponse = block_on("test".respond_to(&req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") - ); + let resp: HttpResponse = "test".respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); - let resp: HttpResponse = block_on(b"test".respond_to(&req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream") - ); + let resp: HttpResponse = b"test".respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); - let resp: HttpResponse = block_on("test".to_string().respond_to(&req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") - ); + let resp: HttpResponse = "test".to_string().respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); - let resp: HttpResponse = - block_on((&"test".to_string()).respond_to(&req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") - ); + let resp: HttpResponse = + (&"test".to_string()).respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); - let resp: HttpResponse = - block_on(Bytes::from_static(b"test").respond_to(&req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream") - ); + let resp: HttpResponse = + Bytes::from_static(b"test").respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); - let resp: HttpResponse = - block_on(BytesMut::from(b"test".as_ref()).respond_to(&req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream") - ); - - // InternalError - let resp: HttpResponse = - error::InternalError::new("err", StatusCode::BAD_REQUEST) + let resp: HttpResponse = BytesMut::from(b"test".as_ref()) .respond_to(&req) + .await .unwrap(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + + // InternalError + let resp: HttpResponse = + error::InternalError::new("err", StatusCode::BAD_REQUEST) + .respond_to(&req) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + }) } #[test] fn test_result_responder() { - let req = TestRequest::default().to_http_request(); + block_on(async { + let req = TestRequest::default().to_http_request(); - // Result - let resp: HttpResponse = - block_on(Ok::<_, Error>("test".to_string()).respond_to(&req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") - ); + // Result + let resp: HttpResponse = Ok::<_, Error>("test".to_string()) + .respond_to(&req) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); - let res = block_on( - Err::(error::InternalError::new("err", StatusCode::BAD_REQUEST)) - .respond_to(&req), - ); - assert!(res.is_err()); + let res = Err::(error::InternalError::new( + "err", + StatusCode::BAD_REQUEST, + )) + .respond_to(&req) + .await; + assert!(res.is_err()); + }) } #[test] fn test_custom_responder() { - let req = TestRequest::default().to_http_request(); - let res = block_on( - "test" + block_on(async { + let req = TestRequest::default().to_http_request(); + let res = "test" .to_string() .with_status(StatusCode::BAD_REQUEST) - .respond_to(&req), - ) - .unwrap(); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - assert_eq!(res.body().bin_ref(), b"test"); + .respond_to(&req) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!(res.body().bin_ref(), b"test"); - let res = block_on( - "test" + let res = "test" .to_string() .with_header("content-type", "json") - .respond_to(&req), - ) - .unwrap(); + .respond_to(&req) + .await + .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.body().bin_ref(), b"test"); - assert_eq!( - res.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("json") - ); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.body().bin_ref(), b"test"); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("json") + ); + }) } #[test] fn test_tuple_responder_with_status_code() { - let req = TestRequest::default().to_http_request(); - let res = - block_on(("test".to_string(), StatusCode::BAD_REQUEST).respond_to(&req)) + block_on(async { + let req = TestRequest::default().to_http_request(); + let res = ("test".to_string(), StatusCode::BAD_REQUEST) + .respond_to(&req) + .await .unwrap(); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - assert_eq!(res.body().bin_ref(), b"test"); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!(res.body().bin_ref(), b"test"); - let req = TestRequest::default().to_http_request(); - let res = block_on( - ("test".to_string(), StatusCode::OK) + let req = TestRequest::default().to_http_request(); + let res = ("test".to_string(), StatusCode::OK) .with_header("content-type", "json") - .respond_to(&req), - ) - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.body().bin_ref(), b"test"); - assert_eq!( - res.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("json") - ); + .respond_to(&req) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.body().bin_ref(), b"test"); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("json") + ); + }) } } diff --git a/src/route.rs b/src/route.rs index f4d303632..3ebfc3f52 100644 --- a/src/route.rs +++ b/src/route.rs @@ -1,13 +1,15 @@ +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_http::{http::Method, Error}; -use actix_service::{NewService, Service}; -use futures::future::{ok, Either, FutureResult}; -use futures::{Async, Future, IntoFuture, Poll}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{ok, ready, Either, FutureExt, LocalBoxFuture, Ready}; use crate::extract::FromRequest; use crate::guard::{self, Guard}; -use crate::handler::{AsyncFactory, AsyncHandler, Extract, Factory, Handler}; +use crate::handler::{Extract, Factory, Handler}; use crate::responder::Responder; use crate::service::{ServiceRequest, ServiceResponse}; use crate::HttpResponse; @@ -17,22 +19,19 @@ type BoxedRouteService = Box< Request = Req, Response = Res, Error = Error, - Future = Either< - FutureResult, - Box>, - >, + Future = LocalBoxFuture<'static, Result>, >, >; type BoxedRouteNewService = Box< - dyn NewService< + dyn ServiceFactory< Config = (), Request = Req, Response = Res, Error = Error, InitError = (), Service = BoxedRouteService, - Future = Box, Error = ()>>, + Future = LocalBoxFuture<'static, Result, ()>>, >, >; @@ -50,7 +49,7 @@ impl Route { pub fn new() -> Route { Route { service: Box::new(RouteNewService::new(Extract::new(Handler::new(|| { - HttpResponse::NotFound() + ready(HttpResponse::NotFound()) })))), guards: Rc::new(Vec::new()), } @@ -61,7 +60,7 @@ impl Route { } } -impl NewService for Route { +impl ServiceFactory for Route { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -78,26 +77,30 @@ impl NewService for Route { } } -type RouteFuture = Box< - dyn Future, Error = ()>, +type RouteFuture = LocalBoxFuture< + 'static, + Result, ()>, >; +#[pin_project::pin_project] pub struct CreateRouteService { + #[pin] fut: RouteFuture, guards: Rc>>, } impl Future for CreateRouteService { - type Item = RouteService; - type Error = (); + type Output = Result; - fn poll(&mut self) -> Poll { - match self.fut.poll()? { - Async::Ready(service) => Ok(Async::Ready(RouteService { + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + + match this.fut.poll(cx)? { + Poll::Ready(service) => Poll::Ready(Ok(RouteService { service, - guards: self.guards.clone(), + guards: this.guards.clone(), })), - Async::NotReady => Ok(Async::NotReady), + Poll::Pending => Poll::Pending, } } } @@ -122,17 +125,14 @@ impl Service for RouteService { type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Either< - FutureResult, - Box>, - >; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { - self.service.call(req) + self.service.call(req).boxed_local() } } @@ -178,8 +178,8 @@ impl Route { /// Set handler function, use request extractors for parameters. /// /// ```rust - /// #[macro_use] extern crate serde_derive; /// use actix_web::{web, http, App}; + /// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -187,7 +187,7 @@ impl Route { /// } /// /// /// extract path info using serde - /// fn index(info: web::Path) -> String { + /// async fn index(info: web::Path) -> String { /// format!("Welcome {}!", info.username) /// } /// @@ -212,7 +212,7 @@ impl Route { /// } /// /// /// extract path info using serde - /// fn index(path: web::Path, query: web::Query>, body: web::Json) -> String { + /// async fn index(path: web::Path, query: web::Query>, body: web::Json) -> String { /// format!("Welcome {}!", path.username) /// } /// @@ -223,69 +223,29 @@ impl Route { /// ); /// } /// ``` - pub fn to(mut self, handler: F) -> Route + pub fn to(mut self, handler: F) -> Self where - F: Factory + 'static, + F: Factory, T: FromRequest + 'static, - R: Responder + 'static, + R: Future + 'static, + U: Responder + 'static, { self.service = Box::new(RouteNewService::new(Extract::new(Handler::new(handler)))); self } - - /// Set async handler function, use request extractors for parameters. - /// This method has to be used if your handler function returns `impl Future<>` - /// - /// ```rust - /// # use futures::future::ok; - /// #[macro_use] extern crate serde_derive; - /// use actix_web::{web, App, Error}; - /// use futures::Future; - /// - /// #[derive(Deserialize)] - /// struct Info { - /// username: String, - /// } - /// - /// /// extract path info using serde - /// fn index(info: web::Path) -> impl Future { - /// ok("Hello World!") - /// } - /// - /// fn main() { - /// let app = App::new().service( - /// web::resource("/{username}/index.html") // <- define path parameters - /// .route(web::get().to_async(index)) // <- register async handler - /// ); - /// } - /// ``` - #[allow(clippy::wrong_self_convention)] - pub fn to_async(mut self, handler: F) -> Self - where - F: AsyncFactory, - T: FromRequest + 'static, - R: IntoFuture + 'static, - R::Item: Responder, - R::Error: Into, - { - self.service = Box::new(RouteNewService::new(Extract::new(AsyncHandler::new( - handler, - )))); - self - } } struct RouteNewService where - T: NewService, + T: ServiceFactory, { service: T, } impl RouteNewService where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -300,9 +260,9 @@ where } } -impl NewService for RouteNewService +impl ServiceFactory for RouteNewService where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -318,19 +278,20 @@ where type Error = Error; type InitError = (); type Service = BoxedRouteService; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: &()) -> Self::Future { - Box::new( - self.service - .new_service(&()) - .map_err(|_| ()) - .and_then(|service| { + self.service + .new_service(&()) + .map(|result| match result { + Ok(service) => { let service: BoxedRouteService<_, _> = Box::new(RouteServiceWrapper { service }); Ok(service) - }), - ) + } + Err(_) => Err(()), + }) + .boxed_local() } } @@ -350,25 +311,30 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Either< - FutureResult, - Box>, - >; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready().map_err(|(e, _)| e) + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx).map_err(|(e, _)| e) } fn call(&mut self, req: ServiceRequest) -> Self::Future { - let mut fut = self.service.call(req); - match fut.poll() { - Ok(Async::Ready(res)) => Either::A(ok(res)), - Err((e, req)) => Either::A(ok(req.error_response(e))), - Ok(Async::NotReady) => Either::B(Box::new(fut.then(|res| match res { + // let mut fut = self.service.call(req); + self.service + .call(req) + .map(|res| match res { Ok(res) => Ok(res), Err((err, req)) => Ok(req.error_response(err)), - }))), - } + }) + .boxed_local() + + // match fut.poll() { + // Poll::Ready(Ok(res)) => Either::Left(ok(res)), + // Poll::Ready(Err((e, req))) => Either::Left(ok(req.error_response(e))), + // Poll::Pending => Either::Right(Box::new(fut.then(|res| match res { + // Ok(res) => Ok(res), + // Err((err, req)) => Ok(req.error_response(err)), + // }))), + // } } } @@ -379,11 +345,11 @@ mod tests { use bytes::Bytes; use futures::Future; use serde_derive::Serialize; - use tokio_timer::sleep; + use tokio_timer::delay_for; use crate::http::{Method, StatusCode}; - use crate::test::{call_service, init_service, read_body, TestRequest}; - use crate::{error, web, App, HttpResponse}; + use crate::test::{block_on, call_service, init_service, read_body, TestRequest}; + use crate::{error, web, App, Error, HttpResponse}; #[derive(Serialize, PartialEq, Debug)] struct MyObject { @@ -392,68 +358,77 @@ mod tests { #[test] fn test_route() { - let mut srv = init_service( - App::new() - .service( - web::resource("/test") - .route(web::get().to(|| HttpResponse::Ok())) - .route(web::put().to(|| { - Err::(error::ErrorBadRequest("err")) - })) - .route(web::post().to_async(|| { - sleep(Duration::from_millis(100)) - .then(|_| HttpResponse::Created()) - })) - .route(web::delete().to_async(|| { - sleep(Duration::from_millis(100)).then(|_| { - Err::(error::ErrorBadRequest("err")) + block_on(async { + let mut srv = init_service( + App::new() + .service( + web::resource("/test") + .route(web::get().to(|| HttpResponse::Ok())) + .route(web::put().to(|| { + async { + Err::(error::ErrorBadRequest("err")) + } + })) + .route(web::post().to(|| { + async { + delay_for(Duration::from_millis(100)).await; + HttpResponse::Created() + } + })) + .route(web::delete().to(|| { + async { + delay_for(Duration::from_millis(100)).await; + Err::(error::ErrorBadRequest("err")) + } + })), + ) + .service(web::resource("/json").route(web::get().to(|| { + async { + delay_for(Duration::from_millis(25)).await; + web::Json(MyObject { + name: "test".to_string(), }) - })), - ) - .service(web::resource("/json").route(web::get().to_async(|| { - sleep(Duration::from_millis(25)).then(|_| { - Ok::<_, crate::Error>(web::Json(MyObject { - name: "test".to_string(), - })) - }) - }))), - ); + } + }))), + ) + .await; - let req = TestRequest::with_uri("/test") - .method(Method::GET) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/test") + .method(Method::GET) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); - let req = TestRequest::with_uri("/test") - .method(Method::POST) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::CREATED); + let req = TestRequest::with_uri("/test") + .method(Method::POST) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::CREATED); - let req = TestRequest::with_uri("/test") - .method(Method::PUT) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let req = TestRequest::with_uri("/test") + .method(Method::PUT) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let req = TestRequest::with_uri("/test") - .method(Method::DELETE) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let req = TestRequest::with_uri("/test") + .method(Method::DELETE) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let req = TestRequest::with_uri("/test") - .method(Method::HEAD) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + let req = TestRequest::with_uri("/test") + .method(Method::HEAD) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - let req = TestRequest::with_uri("/json").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/json").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); - let body = read_body(resp); - assert_eq!(body, Bytes::from_static(b"{\"name\":\"test\"}")); + let body = read_body(resp).await; + assert_eq!(body, Bytes::from_static(b"{\"name\":\"test\"}")); + }) } } diff --git a/src/scope.rs b/src/scope.rs index c152bc334..e5c04d71e 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -1,15 +1,16 @@ use std::cell::RefCell; use std::fmt; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_http::{Extensions, Response}; use actix_router::{ResourceDef, ResourceInfo, Router}; use actix_service::boxed::{self, BoxedNewService, BoxedService}; use actix_service::{ - apply_transform, IntoNewService, IntoTransform, NewService, Service, Transform, + apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, Transform, }; -use futures::future::{ok, Either, Future, FutureResult}; -use futures::{Async, IntoFuture, Poll}; +use futures::future::{ok, Either, Future, LocalBoxFuture, Ready}; use crate::config::ServiceConfig; use crate::data::Data; @@ -20,16 +21,13 @@ use crate::resource::Resource; use crate::rmap::ResourceMap; use crate::route::Route; use crate::service::{ - ServiceFactory, ServiceFactoryWrapper, ServiceRequest, ServiceResponse, + AppServiceFactory, ServiceFactoryWrapper, ServiceRequest, ServiceResponse, }; type Guards = Vec>; type HttpService = BoxedService; type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; -type BoxedResponse = Either< - FutureResult, - Box>, ->; +type BoxedResponse = LocalBoxFuture<'static, Result>; /// Resources scope. /// @@ -48,7 +46,7 @@ type BoxedResponse = Either< /// fn main() { /// let app = App::new().service( /// web::scope("/{project_id}/") -/// .service(web::resource("/path1").to(|| HttpResponse::Ok())) +/// .service(web::resource("/path1").to(|| async { HttpResponse::Ok() })) /// .service(web::resource("/path2").route(web::get().to(|| HttpResponse::Ok()))) /// .service(web::resource("/path3").route(web::head().to(|| HttpResponse::MethodNotAllowed()))) /// ); @@ -64,7 +62,7 @@ pub struct Scope { endpoint: T, rdef: String, data: Option, - services: Vec>, + services: Vec>, guards: Vec>, default: Rc>>>, external: Vec, @@ -90,7 +88,7 @@ impl Scope { impl Scope where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -103,7 +101,7 @@ where /// ```rust /// use actix_web::{web, guard, App, HttpRequest, HttpResponse}; /// - /// fn index(data: web::Path<(String, String)>) -> &'static str { + /// async fn index(data: web::Path<(String, String)>) -> &'static str { /// "Welcome!" /// } /// @@ -134,7 +132,7 @@ where /// counter: Cell, /// } /// - /// fn index(data: web::Data) { + /// async fn index(data: web::Data) { /// data.counter.set(data.counter.get() + 1); /// } /// @@ -230,7 +228,7 @@ where /// /// struct AppState; /// - /// fn index(req: HttpRequest) -> &'static str { + /// async fn index(req: HttpRequest) -> &'static str { /// "Welcome!" /// } /// @@ -260,7 +258,7 @@ where /// ```rust /// use actix_web::{web, App, HttpResponse}; /// - /// fn index(data: web::Path<(String, String)>) -> &'static str { + /// async fn index(data: web::Path<(String, String)>) -> &'static str { /// "Welcome!" /// } /// @@ -285,8 +283,8 @@ where /// If default resource is not registered, app's default resource is being used. pub fn default_service(mut self, f: F) -> Self where - F: IntoNewService, - U: NewService< + F: IntoServiceFactory, + U: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -295,8 +293,8 @@ where U::InitError: fmt::Debug, { // create and configure default resource - self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::new_service( - f.into_new_service().map_init_err(|e| { + self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::factory( + f.into_factory().map_init_err(|e| { log::error!("Can not construct default service: {:?}", e) }), ))))); @@ -313,11 +311,11 @@ where /// ServiceResponse. /// /// Use middleware when you need to read or modify *every* request in some way. - pub fn wrap( + pub fn wrap( self, - mw: F, + mw: M, ) -> Scope< - impl NewService< + impl ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -333,11 +331,9 @@ where Error = Error, InitError = (), >, - F: IntoTransform, { - let endpoint = apply_transform(mw, self.endpoint); Scope { - endpoint, + endpoint: apply(mw, self.endpoint), rdef: self.rdef, data: self.data, guards: self.guards, @@ -357,24 +353,26 @@ where /// /// ```rust /// use actix_service::Service; - /// # use futures::Future; /// use actix_web::{web, App}; /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; /// - /// fn index() -> &'static str { + /// async fn index() -> &'static str { /// "Welcome!" /// } /// /// fn main() { /// let app = App::new().service( /// web::scope("/app") - /// .wrap_fn(|req, srv| - /// srv.call(req).map(|mut res| { + /// .wrap_fn(|req, srv| { + /// let fut = srv.call(req); + /// async { + /// let mut res = fut.await?; /// res.headers_mut().insert( /// CONTENT_TYPE, HeaderValue::from_static("text/plain"), /// ); - /// res - /// })) + /// Ok(res) + /// } + /// }) /// .route("/index.html", web::get().to(index))); /// } /// ``` @@ -382,7 +380,7 @@ where self, mw: F, ) -> Scope< - impl NewService< + impl ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -392,15 +390,24 @@ where > where F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, - R: IntoFuture, + R: Future>, { - self.wrap(mw) + Scope { + endpoint: apply_fn_factory(self.endpoint, mw), + rdef: self.rdef, + data: self.data, + guards: self.guards, + services: self.services, + default: self.default, + external: self.external, + factory_ref: self.factory_ref, + } } } impl HttpServiceFactory for Scope where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -471,7 +478,7 @@ pub struct ScopeFactory { default: Rc>>>, } -impl NewService for ScopeFactory { +impl ServiceFactory for ScopeFactory { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -508,14 +515,15 @@ impl NewService for ScopeFactory { /// Create scope service #[doc(hidden)] +#[pin_project::pin_project] pub struct ScopeFactoryResponse { fut: Vec, data: Option>, default: Option, - default_fut: Option>>, + default_fut: Option>>, } -type HttpServiceFut = Box>; +type HttpServiceFut = LocalBoxFuture<'static, Result>; enum CreateScopeServiceItem { Future(Option, Option, HttpServiceFut), @@ -523,16 +531,15 @@ enum CreateScopeServiceItem { } impl Future for ScopeFactoryResponse { - type Item = ScopeService; - type Error = (); + type Output = Result; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut done = true; if let Some(ref mut fut) = self.default_fut { - match fut.poll()? { - Async::Ready(default) => self.default = Some(default), - Async::NotReady => done = false, + match Pin::new(fut).poll(cx)? { + Poll::Ready(default) => self.default = Some(default), + Poll::Pending => done = false, } } @@ -543,11 +550,11 @@ impl Future for ScopeFactoryResponse { ref mut path, ref mut guards, ref mut fut, - ) => match fut.poll()? { - Async::Ready(service) => { + ) => match Pin::new(fut).poll(cx)? { + Poll::Ready(service) => { Some((path.take().unwrap(), guards.take(), service)) } - Async::NotReady => { + Poll::Pending => { done = false; None } @@ -573,14 +580,14 @@ impl Future for ScopeFactoryResponse { } router }); - Ok(Async::Ready(ScopeService { + Poll::Ready(Ok(ScopeService { data: self.data.clone(), router: router.finish(), default: self.default.take(), _ready: None, })) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -596,10 +603,10 @@ impl Service for ScopeService { type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Either>; + type Future = Either>>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, mut req: ServiceRequest) -> Self::Future { @@ -618,12 +625,12 @@ impl Service for ScopeService { if let Some(ref data) = self.data { req.set_data_container(data.clone()); } - Either::A(srv.call(req)) + Either::Left(srv.call(req)) } else if let Some(ref mut default) = self.default { - Either::A(default.call(req)) + Either::Left(default.call(req)) } else { let req = req.into_parts().0; - Either::B(ok(ServiceResponse::new(req, Response::NotFound().finish()))) + Either::Right(ok(ServiceResponse::new(req, Response::NotFound().finish()))) } } } @@ -639,7 +646,7 @@ impl ScopeEndpoint { } } -impl NewService for ScopeEndpoint { +impl ServiceFactory for ScopeEndpoint { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -657,367 +664,424 @@ impl NewService for ScopeEndpoint { mod tests { use actix_service::Service; use bytes::Bytes; - use futures::{Future, IntoFuture}; + use futures::future::ok; + use futures::Future; use crate::dev::{Body, ResponseBody}; use crate::http::{header, HeaderValue, Method, StatusCode}; + use crate::middleware::DefaultHeaders; use crate::service::{ServiceRequest, ServiceResponse}; use crate::test::{block_on, call_service, init_service, read_body, TestRequest}; use crate::{guard, web, App, Error, HttpRequest, HttpResponse}; #[test] fn test_scope() { - let mut srv = init_service( - App::new().service( - web::scope("/app") - .service(web::resource("/path1").to(|| HttpResponse::Ok())), - ), - ); + block_on(async { + let mut srv = init_service( + App::new().service( + web::scope("/app") + .service(web::resource("/path1").to(|| HttpResponse::Ok())), + ), + ) + .await; - let req = TestRequest::with_uri("/app/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_scope_root() { - let mut srv = init_service( - App::new().service( - web::scope("/app") - .service(web::resource("").to(|| HttpResponse::Ok())) - .service(web::resource("/").to(|| HttpResponse::Created())), - ), - ); + block_on(async { + let mut srv = init_service( + App::new().service( + web::scope("/app") + .service(web::resource("").to(|| HttpResponse::Ok())) + .service(web::resource("/").to(|| HttpResponse::Created())), + ), + ) + .await; - let req = TestRequest::with_uri("/app").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - let req = TestRequest::with_uri("/app/").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); + let req = TestRequest::with_uri("/app/").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + }) } #[test] fn test_scope_root2() { - let mut srv = init_service(App::new().service( - web::scope("/app/").service(web::resource("").to(|| HttpResponse::Ok())), - )); + block_on(async { + let mut srv = init_service(App::new().service( + web::scope("/app/").service(web::resource("").to(|| HttpResponse::Ok())), + )) + .await; - let req = TestRequest::with_uri("/app").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::with_uri("/app").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let req = TestRequest::with_uri("/app/").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_scope_root3() { - let mut srv = init_service(App::new().service( - web::scope("/app/").service(web::resource("/").to(|| HttpResponse::Ok())), - )); + block_on(async { + let mut srv = init_service( + App::new().service( + web::scope("/app/") + .service(web::resource("/").to(|| HttpResponse::Ok())), + ), + ) + .await; - let req = TestRequest::with_uri("/app").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::with_uri("/app").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let req = TestRequest::with_uri("/app/").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::with_uri("/app/").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + }) } #[test] fn test_scope_route() { - let mut srv = init_service( - App::new().service( - web::scope("app") - .route("/path1", web::get().to(|| HttpResponse::Ok())) - .route("/path1", web::delete().to(|| HttpResponse::Ok())), - ), - ); + block_on(async { + let mut srv = init_service( + App::new().service( + web::scope("app") + .route("/path1", web::get().to(|| HttpResponse::Ok())) + .route("/path1", web::delete().to(|| HttpResponse::Ok())), + ), + ) + .await; - let req = TestRequest::with_uri("/app/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - let req = TestRequest::with_uri("/app/path1") - .method(Method::DELETE) - .to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/path1") + .method(Method::DELETE) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - let req = TestRequest::with_uri("/app/path1") - .method(Method::POST) - .to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::with_uri("/app/path1") + .method(Method::POST) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + }) } #[test] fn test_scope_route_without_leading_slash() { - let mut srv = init_service( - App::new().service( - web::scope("app").service( - web::resource("path1") - .route(web::get().to(|| HttpResponse::Ok())) - .route(web::delete().to(|| HttpResponse::Ok())), + block_on(async { + let mut srv = init_service( + App::new().service( + web::scope("app").service( + web::resource("path1") + .route(web::get().to(|| HttpResponse::Ok())) + .route(web::delete().to(|| HttpResponse::Ok())), + ), ), - ), - ); + ) + .await; - let req = TestRequest::with_uri("/app/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - let req = TestRequest::with_uri("/app/path1") - .method(Method::DELETE) - .to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/path1") + .method(Method::DELETE) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - let req = TestRequest::with_uri("/app/path1") - .method(Method::POST) - .to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + let req = TestRequest::with_uri("/app/path1") + .method(Method::POST) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + }) } #[test] fn test_scope_guard() { - let mut srv = init_service( - App::new().service( - web::scope("/app") - .guard(guard::Get()) - .service(web::resource("/path1").to(|| HttpResponse::Ok())), - ), - ); + block_on(async { + let mut srv = init_service( + App::new().service( + web::scope("/app") + .guard(guard::Get()) + .service(web::resource("/path1").to(|| HttpResponse::Ok())), + ), + ) + .await; - let req = TestRequest::with_uri("/app/path1") - .method(Method::POST) - .to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::with_uri("/app/path1") + .method(Method::POST) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let req = TestRequest::with_uri("/app/path1") - .method(Method::GET) - .to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/path1") + .method(Method::GET) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_scope_variable_segment() { - let mut srv = - init_service(App::new().service(web::scope("/ab-{project}").service( - web::resource("/path1").to(|r: HttpRequest| { - HttpResponse::Ok() - .body(format!("project: {}", &r.match_info()["project"])) - }), - ))); + block_on(async { + let mut srv = + init_service(App::new().service(web::scope("/ab-{project}").service( + web::resource("/path1").to(|r: HttpRequest| { + async move { + HttpResponse::Ok() + .body(format!("project: {}", &r.match_info()["project"])) + } + }), + ))) + .await; - let req = TestRequest::with_uri("/ab-project1/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/ab-project1/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - match resp.response().body() { - ResponseBody::Body(Body::Bytes(ref b)) => { - let bytes: Bytes = b.clone().into(); - assert_eq!(bytes, Bytes::from_static(b"project: project1")); + match resp.response().body() { + ResponseBody::Body(Body::Bytes(ref b)) => { + let bytes: Bytes = b.clone().into(); + assert_eq!(bytes, Bytes::from_static(b"project: project1")); + } + _ => panic!(), } - _ => panic!(), - } - let req = TestRequest::with_uri("/aa-project1/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::with_uri("/aa-project1/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + }) } #[test] fn test_nested_scope() { - let mut srv = init_service( - App::new().service( - web::scope("/app") - .service(web::scope("/t1").service( + block_on(async { + let mut srv = + init_service(App::new().service( + web::scope("/app").service(web::scope("/t1").service( web::resource("/path1").to(|| HttpResponse::Created()), )), - ), - ); + )) + .await; - let req = TestRequest::with_uri("/app/t1/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); + let req = TestRequest::with_uri("/app/t1/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + }) } #[test] fn test_nested_scope_no_slash() { - let mut srv = init_service( - App::new().service( - web::scope("/app") - .service(web::scope("t1").service( + block_on(async { + let mut srv = + init_service(App::new().service( + web::scope("/app").service(web::scope("t1").service( web::resource("/path1").to(|| HttpResponse::Created()), )), - ), - ); + )) + .await; - let req = TestRequest::with_uri("/app/t1/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); + let req = TestRequest::with_uri("/app/t1/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + }) } #[test] fn test_nested_scope_root() { - let mut srv = init_service( - App::new().service( - web::scope("/app").service( - web::scope("/t1") - .service(web::resource("").to(|| HttpResponse::Ok())) - .service(web::resource("/").to(|| HttpResponse::Created())), + block_on(async { + let mut srv = init_service( + App::new().service( + web::scope("/app").service( + web::scope("/t1") + .service(web::resource("").to(|| HttpResponse::Ok())) + .service(web::resource("/").to(|| HttpResponse::Created())), + ), ), - ), - ); + ) + .await; - let req = TestRequest::with_uri("/app/t1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/t1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); - let req = TestRequest::with_uri("/app/t1/").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); + let req = TestRequest::with_uri("/app/t1/").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + }) } #[test] fn test_nested_scope_filter() { - let mut srv = init_service( - App::new().service( - web::scope("/app").service( - web::scope("/t1") - .guard(guard::Get()) - .service(web::resource("/path1").to(|| HttpResponse::Ok())), + block_on(async { + let mut srv = init_service( + App::new().service( + web::scope("/app").service( + web::scope("/t1") + .guard(guard::Get()) + .service(web::resource("/path1").to(|| HttpResponse::Ok())), + ), ), - ), - ); + ) + .await; - let req = TestRequest::with_uri("/app/t1/path1") - .method(Method::POST) - .to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::with_uri("/app/t1/path1") + .method(Method::POST) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let req = TestRequest::with_uri("/app/t1/path1") - .method(Method::GET) - .to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/t1/path1") + .method(Method::GET) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_nested_scope_with_variable_segment() { - let mut srv = init_service(App::new().service(web::scope("/app").service( - web::scope("/{project_id}").service(web::resource("/path1").to( - |r: HttpRequest| { - HttpResponse::Created() - .body(format!("project: {}", &r.match_info()["project_id"])) - }, - )), - ))); + block_on(async { + let mut srv = init_service(App::new().service(web::scope("/app").service( + web::scope("/{project_id}").service(web::resource("/path1").to( + |r: HttpRequest| { + async move { + HttpResponse::Created().body(format!( + "project: {}", + &r.match_info()["project_id"] + )) + } + }, + )), + ))) + .await; - let req = TestRequest::with_uri("/app/project_1/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); + let req = TestRequest::with_uri("/app/project_1/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); - match resp.response().body() { - ResponseBody::Body(Body::Bytes(ref b)) => { - let bytes: Bytes = b.clone().into(); - assert_eq!(bytes, Bytes::from_static(b"project: project_1")); + match resp.response().body() { + ResponseBody::Body(Body::Bytes(ref b)) => { + let bytes: Bytes = b.clone().into(); + assert_eq!(bytes, Bytes::from_static(b"project: project_1")); + } + _ => panic!(), } - _ => panic!(), - } + }) } #[test] fn test_nested2_scope_with_variable_segment() { - let mut srv = init_service(App::new().service(web::scope("/app").service( - web::scope("/{project}").service(web::scope("/{id}").service( - web::resource("/path1").to(|r: HttpRequest| { - HttpResponse::Created().body(format!( - "project: {} - {}", - &r.match_info()["project"], - &r.match_info()["id"], - )) - }), - )), - ))); + block_on(async { + let mut srv = init_service(App::new().service(web::scope("/app").service( + web::scope("/{project}").service(web::scope("/{id}").service( + web::resource("/path1").to(|r: HttpRequest| { + async move { + HttpResponse::Created().body(format!( + "project: {} - {}", + &r.match_info()["project"], + &r.match_info()["id"], + )) + } + }), + )), + ))) + .await; - let req = TestRequest::with_uri("/app/test/1/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); + let req = TestRequest::with_uri("/app/test/1/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); - match resp.response().body() { - ResponseBody::Body(Body::Bytes(ref b)) => { - let bytes: Bytes = b.clone().into(); - assert_eq!(bytes, Bytes::from_static(b"project: test - 1")); + match resp.response().body() { + ResponseBody::Body(Body::Bytes(ref b)) => { + let bytes: Bytes = b.clone().into(); + assert_eq!(bytes, Bytes::from_static(b"project: test - 1")); + } + _ => panic!(), } - _ => panic!(), - } - let req = TestRequest::with_uri("/app/test/1/path2").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::with_uri("/app/test/1/path2").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + }) } #[test] fn test_default_resource() { - let mut srv = init_service( - App::new().service( - web::scope("/app") - .service(web::resource("/path1").to(|| HttpResponse::Ok())) - .default_service(|r: ServiceRequest| { - r.into_response(HttpResponse::BadRequest()) - }), - ), - ); + block_on(async { + let mut srv = init_service( + App::new().service( + web::scope("/app") + .service(web::resource("/path1").to(|| HttpResponse::Ok())) + .default_service(|r: ServiceRequest| { + ok(r.into_response(HttpResponse::BadRequest())) + }), + ), + ) + .await; - let req = TestRequest::with_uri("/app/path2").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let req = TestRequest::with_uri("/app/path2").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let req = TestRequest::with_uri("/path2").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let req = TestRequest::with_uri("/path2").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + }) } #[test] fn test_default_resource_propagation() { - let mut srv = init_service( - App::new() - .service(web::scope("/app1").default_service( - web::resource("").to(|| HttpResponse::BadRequest()), - )) - .service(web::scope("/app2")) - .default_service(|r: ServiceRequest| { - r.into_response(HttpResponse::MethodNotAllowed()) - }), - ); + block_on(async { + let mut srv = init_service( + App::new() + .service(web::scope("/app1").default_service( + web::resource("").to(|| HttpResponse::BadRequest()), + )) + .service(web::scope("/app2")) + .default_service(|r: ServiceRequest| { + ok(r.into_response(HttpResponse::MethodNotAllowed())) + }), + ) + .await; - let req = TestRequest::with_uri("/non-exist").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + let req = TestRequest::with_uri("/non-exist").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - let req = TestRequest::with_uri("/app1/non-exist").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let req = TestRequest::with_uri("/app1/non-exist").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let req = TestRequest::with_uri("/app2/non-exist").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + let req = TestRequest::with_uri("/app2/non-exist").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + }) } fn md( req: ServiceRequest, srv: &mut S, - ) -> impl IntoFuture, Error = Error> + ) -> impl Future, Error>> where S: Service< Request = ServiceRequest, @@ -1025,165 +1089,213 @@ mod tests { Error = Error, >, { - srv.call(req).map(|mut res| { + let fut = srv.call(req); + async move { + let mut res = fut.await?; res.headers_mut() .insert(header::CONTENT_TYPE, HeaderValue::from_static("0001")); - res - }) + Ok(res) + } } #[test] fn test_middleware() { - let mut srv = - init_service(App::new().service(web::scope("app").wrap(md).service( - web::resource("/test").route(web::get().to(|| HttpResponse::Ok())), - ))); - let req = TestRequest::with_uri("/app/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - HeaderValue::from_static("0001") - ); + block_on(async { + let mut srv = init_service( + App::new().service( + web::scope("app") + .wrap(DefaultHeaders::new().header( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + )) + .service( + web::resource("/test") + .route(web::get().to(|| HttpResponse::Ok())), + ), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + }) } #[test] fn test_middleware_fn() { - let mut srv = init_service( - App::new().service( - web::scope("app") - .wrap_fn(|req, srv| { - srv.call(req).map(|mut res| { - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("0001"), - ); - res + block_on(async { + let mut srv = init_service( + App::new().service( + web::scope("app") + .wrap_fn(|req, srv| { + let fut = srv.call(req); + async move { + let mut res = fut.await?; + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + ); + Ok(res) + } }) - }) - .route("/test", web::get().to(|| HttpResponse::Ok())), - ), - ); - let req = TestRequest::with_uri("/app/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - HeaderValue::from_static("0001") - ); + .route("/test", web::get().to(|| HttpResponse::Ok())), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + }) } #[test] fn test_override_data() { - let mut srv = init_service(App::new().data(1usize).service( - web::scope("app").data(10usize).route( - "/t", - web::get().to(|data: web::Data| { - assert_eq!(*data, 10); - let _ = data.clone(); - HttpResponse::Ok() - }), - ), - )); + block_on(async { + let mut srv = init_service(App::new().data(1usize).service( + web::scope("app").data(10usize).route( + "/t", + web::get().to(|data: web::Data| { + assert_eq!(*data, 10); + let _ = data.clone(); + HttpResponse::Ok() + }), + ), + )) + .await; - let req = TestRequest::with_uri("/app/t").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/t").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_override_register_data() { - let mut srv = init_service( - App::new().register_data(web::Data::new(1usize)).service( - web::scope("app") - .register_data(web::Data::new(10usize)) - .route( - "/t", - web::get().to(|data: web::Data| { - assert_eq!(*data, 10); - let _ = data.clone(); - HttpResponse::Ok() - }), - ), - ), - ); + block_on(async { + let mut srv = init_service( + App::new().register_data(web::Data::new(1usize)).service( + web::scope("app") + .register_data(web::Data::new(10usize)) + .route( + "/t", + web::get().to(|data: web::Data| { + assert_eq!(*data, 10); + let _ = data.clone(); + HttpResponse::Ok() + }), + ), + ), + ) + .await; - let req = TestRequest::with_uri("/app/t").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/t").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_scope_config() { - let mut srv = - init_service(App::new().service(web::scope("/app").configure(|s| { - s.route("/path1", web::get().to(|| HttpResponse::Ok())); - }))); + block_on(async { + let mut srv = + init_service(App::new().service(web::scope("/app").configure(|s| { + s.route("/path1", web::get().to(|| HttpResponse::Ok())); + }))) + .await; - let req = TestRequest::with_uri("/app/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_scope_config_2() { - let mut srv = - init_service(App::new().service(web::scope("/app").configure(|s| { - s.service(web::scope("/v1").configure(|s| { - s.route("/", web::get().to(|| HttpResponse::Ok())); - })); - }))); + block_on(async { + let mut srv = + init_service(App::new().service(web::scope("/app").configure(|s| { + s.service(web::scope("/v1").configure(|s| { + s.route("/", web::get().to(|| HttpResponse::Ok())); + })); + }))) + .await; - let req = TestRequest::with_uri("/app/v1/").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/app/v1/").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_url_for_external() { - let mut srv = - init_service(App::new().service(web::scope("/app").configure(|s| { - s.service(web::scope("/v1").configure(|s| { - s.external_resource( - "youtube", - "https://youtube.com/watch/{video_id}", - ); - s.route( - "/", - web::get().to(|req: HttpRequest| { - HttpResponse::Ok().body(format!( - "{}", - req.url_for("youtube", &["xxxxxx"]).unwrap().as_str() - )) - }), - ); - })); - }))); + block_on(async { + let mut srv = + init_service(App::new().service(web::scope("/app").configure(|s| { + s.service(web::scope("/v1").configure(|s| { + s.external_resource( + "youtube", + "https://youtube.com/watch/{video_id}", + ); + s.route( + "/", + web::get().to(|req: HttpRequest| { + async move { + HttpResponse::Ok().body(format!( + "{}", + req.url_for("youtube", &["xxxxxx"]) + .unwrap() + .as_str() + )) + } + }), + ); + })); + }))) + .await; - let req = TestRequest::with_uri("/app/v1/").to_request(); - let resp = block_on(srv.call(req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - let body = read_body(resp); - assert_eq!(body, &b"https://youtube.com/watch/xxxxxx"[..]); + let req = TestRequest::with_uri("/app/v1/").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body = read_body(resp).await; + assert_eq!(body, &b"https://youtube.com/watch/xxxxxx"[..]); + }) } #[test] fn test_url_for_nested() { - let mut srv = init_service(App::new().service(web::scope("/a").service( - web::scope("/b").service(web::resource("/c/{stuff}").name("c").route( - web::get().to(|req: HttpRequest| { - HttpResponse::Ok() - .body(format!("{}", req.url_for("c", &["12345"]).unwrap())) - }), - )), - ))); - let req = TestRequest::with_uri("/a/b/c/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), StatusCode::OK); - let body = read_body(resp); - assert_eq!( - body, - Bytes::from_static(b"http://localhost:8080/a/b/c/12345") - ); + block_on(async { + let mut srv = init_service(App::new().service(web::scope("/a").service( + web::scope("/b").service(web::resource("/c/{stuff}").name("c").route( + web::get().to(|req: HttpRequest| { + async move { + HttpResponse::Ok().body(format!( + "{}", + req.url_for("c", &["12345"]).unwrap() + )) + } + }), + )), + ))) + .await; + + let req = TestRequest::with_uri("/a/b/c/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + let body = read_body(resp).await; + assert_eq!( + body, + Bytes::from_static(b"http://localhost:8080/a/b/c/12345") + ); + }) } } diff --git a/src/server.rs b/src/server.rs index 51492eb01..a98d06275 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,15 +6,15 @@ use actix_http::{body::MessageBody, Error, HttpService, KeepAlive, Request, Resp use actix_rt::System; use actix_server::{Server, ServerBuilder}; use actix_server_config::ServerConfig; -use actix_service::{IntoNewService, NewService}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use parking_lot::Mutex; use net2::TcpBuilder; -#[cfg(feature = "ssl")] -use openssl::ssl::{SslAcceptor, SslAcceptorBuilder}; -#[cfg(feature = "rust-tls")] -use rustls::ServerConfig as RustlsServerConfig; +#[cfg(feature = "openssl")] +use open_ssl::ssl::{SslAcceptor, SslAcceptorBuilder}; +#[cfg(feature = "rustls")] +use rust_tls::ServerConfig as RustlsServerConfig; struct Socket { scheme: &'static str, @@ -51,12 +51,11 @@ struct Config { pub struct HttpServer where F: Fn() -> I + Send + Clone + 'static, - I: IntoNewService, - S: NewService, + I: IntoServiceFactory, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, - S::Service: 'static, B: MessageBody, { pub(super) factory: F, @@ -71,12 +70,12 @@ where impl HttpServer where F: Fn() -> I + Send + Clone + 'static, - I: IntoNewService, - S: NewService, - S::Error: Into, + I: IntoServiceFactory, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, - S::Service: 'static, + S::Response: Into> + 'static, + ::Future: 'static, B: MessageBody + 'static, { /// Create new http server with application factory @@ -254,11 +253,11 @@ where Ok(self) } - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] /// Use listener for accepting incoming tls connection requests /// /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn listen_ssl( + pub fn listen_openssl( self, lst: net::TcpListener, builder: SslAcceptorBuilder, @@ -266,13 +265,14 @@ where self.listen_ssl_inner(lst, openssl_acceptor(builder)?) } - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] fn listen_ssl_inner( mut self, lst: net::TcpListener, acceptor: SslAcceptor, ) -> io::Result { use actix_server::ssl::{OpensslAcceptor, SslError}; + use actix_service::pipeline_factory; let acceptor = OpensslAcceptor::new(acceptor); let factory = self.factory.clone(); @@ -288,7 +288,7 @@ where lst, move || { let c = cfg.lock(); - acceptor.clone().map_err(SslError::Ssl).and_then( + pipeline_factory(acceptor.clone().map_err(SslError::Ssl)).and_then( HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) @@ -302,7 +302,7 @@ where Ok(self) } - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] /// Use listener for accepting incoming tls connection requests /// /// This method sets alpn protocols to "h2" and "http/1.1" @@ -314,13 +314,14 @@ where self.listen_rustls_inner(lst, config) } - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] fn listen_rustls_inner( mut self, lst: net::TcpListener, mut config: RustlsServerConfig, ) -> io::Result { use actix_server::ssl::{RustlsAcceptor, SslError}; + use actix_service::pipeline_factory; let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()]; config.set_protocols(&protos); @@ -339,7 +340,7 @@ where lst, move || { let c = cfg.lock(); - acceptor.clone().map_err(SslError::Ssl).and_then( + pipeline_factory(acceptor.clone().map_err(SslError::Ssl)).and_then( HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) @@ -397,11 +398,11 @@ where } } - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] /// Start listening for incoming tls connections. /// /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn bind_ssl( + pub fn bind_openssl( mut self, addr: A, builder: SslAcceptorBuilder, @@ -419,7 +420,7 @@ where Ok(self) } - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] /// Start listening for incoming tls connections. /// /// This method sets alpn protocols to "h2" and "http/1.1" @@ -435,7 +436,7 @@ where Ok(self) } - #[cfg(feature = "uds")] + #[cfg(unix)] /// Start listening for unix domain connections on existing listener. /// /// This method is available with `uds` feature. @@ -466,7 +467,7 @@ where Ok(self) } - #[cfg(feature = "uds")] + #[cfg(unix)] /// Start listening for incoming unix domain connections. /// /// This method is available with `uds` feature. @@ -502,8 +503,8 @@ where impl HttpServer where F: Fn() -> I + Send + Clone + 'static, - I: IntoNewService, - S: NewService, + I: IntoServiceFactory, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, @@ -577,10 +578,10 @@ fn create_tcp_listener( Ok(builder.listen(backlog)?) } -#[cfg(feature = "ssl")] +#[cfg(feature = "openssl")] /// Configure `SslAcceptorBuilder` with custom server flags. fn openssl_acceptor(mut builder: SslAcceptorBuilder) -> io::Result { - use openssl::ssl::AlpnError; + use open_ssl::ssl::AlpnError; builder.set_alpn_select_callback(|_, protos| { const H2: &[u8] = b"\x02h2"; diff --git a/src/service.rs b/src/service.rs index 8b94dd284..39540b067 100644 --- a/src/service.rs +++ b/src/service.rs @@ -9,8 +9,8 @@ use actix_http::{ ResponseHead, }; use actix_router::{Path, Resource, ResourceDef, Url}; -use actix_service::{IntoNewService, NewService}; -use futures::future::{ok, FutureResult, IntoFuture}; +use actix_service::{IntoServiceFactory, ServiceFactory}; +use futures::future::{ok, Ready}; use crate::config::{AppConfig, AppService}; use crate::data::Data; @@ -24,7 +24,7 @@ pub trait HttpServiceFactory { fn register(self, config: &mut AppService); } -pub(crate) trait ServiceFactory { +pub(crate) trait AppServiceFactory { fn register(&mut self, config: &mut AppService); } @@ -40,7 +40,7 @@ impl ServiceFactoryWrapper { } } -impl ServiceFactory for ServiceFactoryWrapper +impl AppServiceFactory for ServiceFactoryWrapper where T: HttpServiceFactory, { @@ -404,16 +404,6 @@ impl Into> for ServiceResponse { } } -impl IntoFuture for ServiceResponse { - type Item = ServiceResponse; - type Error = Error; - type Future = FutureResult, Error>; - - fn into_future(self) -> Self::Future { - ok(self) - } -} - impl fmt::Debug for ServiceResponse { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let res = writeln!( @@ -459,10 +449,10 @@ impl WebService { /// Add match guard to a web service. /// /// ```rust - /// use actix_web::{web, guard, dev, App, HttpResponse}; + /// use actix_web::{web, guard, dev, App, Error, HttpResponse}; /// - /// fn index(req: dev::ServiceRequest) -> dev::ServiceResponse { - /// req.into_response(HttpResponse::Ok().finish()) + /// async fn index(req: dev::ServiceRequest) -> Result { + /// Ok(req.into_response(HttpResponse::Ok().finish())) /// } /// /// fn main() { @@ -482,8 +472,8 @@ impl WebService { /// Set a service factory implementation and generate web service. pub fn finish(self, service: F) -> impl HttpServiceFactory where - F: IntoNewService, - T: NewService< + F: IntoServiceFactory, + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -492,7 +482,7 @@ impl WebService { > + 'static, { WebServiceImpl { - srv: service.into_new_service(), + srv: service.into_factory(), rdef: self.rdef, name: self.name, guards: self.guards, @@ -509,7 +499,7 @@ struct WebServiceImpl { impl HttpServiceFactory for WebServiceImpl where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -539,8 +529,9 @@ where #[cfg(test)] mod tests { use super::*; - use crate::test::{call_service, init_service, TestRequest}; + use crate::test::{block_on, init_service, TestRequest}; use crate::{guard, http, web, App, HttpResponse}; + use actix_service::Service; #[test] fn test_service_request() { @@ -565,25 +556,33 @@ mod tests { #[test] fn test_service() { - let mut srv = init_service( - App::new().service(web::service("/test").name("test").finish( - |req: ServiceRequest| req.into_response(HttpResponse::Ok().finish()), - )), - ); - let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), http::StatusCode::OK); + block_on(async { + let mut srv = init_service( + App::new().service(web::service("/test").name("test").finish( + |req: ServiceRequest| { + ok(req.into_response(HttpResponse::Ok().finish())) + }, + )), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::OK); - let mut srv = init_service( - App::new().service(web::service("/test").guard(guard::Get()).finish( - |req: ServiceRequest| req.into_response(HttpResponse::Ok().finish()), - )), - ); - let req = TestRequest::with_uri("/test") - .method(http::Method::PUT) - .to_request(); - let resp = call_service(&mut srv, req); - assert_eq!(resp.status(), http::StatusCode::NOT_FOUND); + let mut srv = init_service(App::new().service( + web::service("/test").guard(guard::Get()).finish( + |req: ServiceRequest| { + ok(req.into_response(HttpResponse::Ok().finish())) + }, + ), + )) + .await; + let req = TestRequest::with_uri("/test") + .method(http::Method::PUT) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::NOT_FOUND); + }) } #[test] diff --git a/src/test.rs b/src/test.rs index 6563253cc..0776b0f15 100644 --- a/src/test.rs +++ b/src/test.rs @@ -7,10 +7,10 @@ use actix_http::test::TestRequest as HttpTestRequest; use actix_http::{cookie::Cookie, Extensions, Request}; use actix_router::{Path, ResourceDef, Url}; use actix_server_config::ServerConfig; -use actix_service::{IntoNewService, IntoService, NewService, Service}; +use actix_service::{IntoService, IntoServiceFactory, Service, ServiceFactory}; use bytes::{Bytes, BytesMut}; -use futures::future::{ok, Future}; -use futures::Stream; +use futures::future::{ok, Future, FutureExt}; +use futures::stream::{Stream, StreamExt}; use serde::de::DeserializeOwned; use serde::Serialize; use serde_json; @@ -39,7 +39,7 @@ pub fn default_service( ) -> impl Service, Error = Error> { (move |req: ServiceRequest| { - req.into_response(HttpResponse::build(status_code).finish()) + ok(req.into_response(HttpResponse::build(status_code).finish())) }) .into_service() } @@ -55,7 +55,7 @@ pub fn default_service( /// fn test_init_service() { /// let mut app = test::init_service( /// App::new() -/// .service(web::resource("/test").to(|| HttpResponse::Ok())) +/// .service(web::resource("/test").to(|| async { HttpResponse::Ok() })) /// ); /// /// // Create request object @@ -66,12 +66,12 @@ pub fn default_service( /// assert_eq!(resp.status(), StatusCode::OK); /// } /// ``` -pub fn init_service( +pub async fn init_service( app: R, ) -> impl Service, Error = E> where - R: IntoNewService, - S: NewService< + R: IntoServiceFactory, + S: ServiceFactory< Config = ServerConfig, Request = Request, Response = ServiceResponse, @@ -80,9 +80,8 @@ where S::InitError: std::fmt::Debug, { let cfg = ServerConfig::new("127.0.0.1:8080".parse().unwrap()); - let srv = app.into_new_service(); - let fut = run_on(move || srv.new_service(&cfg)); - block_on(fut).unwrap() + let srv = app.into_factory(); + srv.new_service(&cfg).await.unwrap() } /// Calls service and waits for response future completion. @@ -95,23 +94,25 @@ where /// fn test_response() { /// let mut app = test::init_service( /// App::new() -/// .service(web::resource("/test").to(|| HttpResponse::Ok())) -/// ); +/// .service(web::resource("/test").to(|| async { +/// HttpResponse::Ok() +/// })) +/// ).await; /// /// // Create request object /// let req = test::TestRequest::with_uri("/test").to_request(); /// /// // Call application -/// let resp = test::call_service(&mut app, req); +/// let resp = test::call_service(&mut app, req).await; /// assert_eq!(resp.status(), StatusCode::OK); /// } /// ``` -pub fn call_service(app: &mut S, req: R) -> S::Response +pub async fn call_service(app: &mut S, req: R) -> S::Response where S: Service, Error = E>, E: std::fmt::Debug, { - block_on(run_on(move || app.call(req))).unwrap() + app.call(req).await.unwrap() } /// Helper function that returns a response body of a TestRequest @@ -126,34 +127,36 @@ where /// let mut app = test::init_service( /// App::new().service( /// web::resource("/index.html") -/// .route(web::post().to( -/// || HttpResponse::Ok().body("welcome!"))))); +/// .route(web::post().to(|| async { +/// HttpResponse::Ok().body("welcome!") +/// }))) +/// ).await; /// /// let req = test::TestRequest::post() /// .uri("/index.html") /// .header(header::CONTENT_TYPE, "application/json") /// .to_request(); /// -/// let result = test::read_response(&mut app, req); +/// let result = test::read_response(&mut app, req).await; /// assert_eq!(result, Bytes::from_static(b"welcome!")); /// } /// ``` -pub fn read_response(app: &mut S, req: Request) -> Bytes +pub async fn read_response(app: &mut S, req: Request) -> Bytes where S: Service, Error = Error>, B: MessageBody, { - block_on(run_on(move || { - app.call(req).and_then(|mut resp: ServiceResponse| { - resp.take_body() - .fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, Error>(body) - }) - .map(|body: BytesMut| body.freeze()) - }) - })) - .unwrap_or_else(|_| panic!("read_response failed at block_on unwrap")) + let mut resp = app + .call(req) + .await + .unwrap_or_else(|_| panic!("read_response failed at block_on unwrap")); + + let mut body = resp.take_body(); + let mut bytes = BytesMut::new(); + while let Some(item) = body.next().await { + bytes.extend_from_slice(&item.unwrap()); + } + bytes.freeze() } /// Helper function that returns a response body of a ServiceResponse. @@ -168,32 +171,42 @@ where /// let mut app = test::init_service( /// App::new().service( /// web::resource("/index.html") -/// .route(web::post().to( -/// || HttpResponse::Ok().body("welcome!"))))); +/// .route(web::post().to(|| async { +/// HttpResponse::Ok().body("welcome!") +/// }))) +/// ).await; /// /// let req = test::TestRequest::post() /// .uri("/index.html") /// .header(header::CONTENT_TYPE, "application/json") /// .to_request(); /// -/// let resp = test::call_service(&mut app, req); +/// let resp = test::call_service(&mut app, req).await; /// let result = test::read_body(resp); /// assert_eq!(result, Bytes::from_static(b"welcome!")); /// } /// ``` -pub fn read_body(mut res: ServiceResponse) -> Bytes +pub async fn read_body(mut res: ServiceResponse) -> Bytes where B: MessageBody, { - block_on(run_on(move || { - res.take_body() - .fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, Error>(body) - }) - .map(|body: BytesMut| body.freeze()) - })) - .unwrap_or_else(|_| panic!("read_response failed at block_on unwrap")) + let mut body = res.take_body(); + let mut bytes = BytesMut::new(); + while let Some(item) = body.next().await { + bytes.extend_from_slice(&item.unwrap()); + } + bytes.freeze() +} + +pub async fn load_stream(mut stream: S) -> Result +where + S: Stream> + Unpin, +{ + let mut data = BytesMut::new(); + while let Some(item) = stream.next().await { + data.extend_from_slice(&item?); + } + Ok(data.freeze()) } /// Helper function that returns a deserialized response body of a TestRequest @@ -214,10 +227,11 @@ where /// let mut app = test::init_service( /// App::new().service( /// web::resource("/people") -/// .route(web::post().to(|person: web::Json| { +/// .route(web::post().to(|person: web::Json| async { /// HttpResponse::Ok() /// .json(person.into_inner())}) -/// ))); +/// )) +/// ).await; /// /// let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); /// @@ -227,30 +241,19 @@ where /// .set_payload(payload) /// .to_request(); /// -/// let result: Person = test::read_response_json(&mut app, req); +/// let result: Person = test::read_response_json(&mut app, req).await; /// } /// ``` -pub fn read_response_json(app: &mut S, req: Request) -> T +pub async fn read_response_json(app: &mut S, req: Request) -> T where S: Service, Error = Error>, B: MessageBody, T: DeserializeOwned, { - block_on(run_on(move || { - app.call(req).and_then(|mut resp: ServiceResponse| { - resp.take_body() - .fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, Error>(body) - }) - .and_then(|body: BytesMut| { - ok(serde_json::from_slice(&body).unwrap_or_else(|_| { - panic!("read_response_json failed during deserialization") - })) - }) - }) - })) - .unwrap_or_else(|_| panic!("read_response_json failed at block_on unwrap")) + let body = read_response(app, req).await; + + serde_json::from_slice(&body) + .unwrap_or_else(|_| panic!("read_response_json failed during deserialization")) } /// Test `Request` builder. @@ -266,7 +269,7 @@ where /// use actix_web::{test, HttpRequest, HttpResponse, HttpMessage}; /// use actix_web::http::{header, StatusCode}; /// -/// fn index(req: HttpRequest) -> HttpResponse { +/// async fn index(req: HttpRequest) -> HttpResponse { /// if let Some(hdr) = req.headers().get(header::CONTENT_TYPE) { /// HttpResponse::Ok().into() /// } else { @@ -279,11 +282,11 @@ where /// let req = test::TestRequest::with_header("content-type", "text/plain") /// .to_http_request(); /// -/// let resp = test::block_on(index(req)).unwrap(); +/// let resp = index(req).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// /// let req = test::TestRequest::default().to_http_request(); -/// let resp = test::block_on(index(req)).unwrap(); +/// let resp = index(req).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::BAD_REQUEST); /// } /// ``` @@ -511,74 +514,88 @@ mod tests { #[test] fn test_basics() { - let req = TestRequest::with_hdr(header::ContentType::json()) - .version(Version::HTTP_2) - .set(header::Date(SystemTime::now().into())) - .param("test", "123") - .data(10u32) - .to_http_request(); - assert!(req.headers().contains_key(header::CONTENT_TYPE)); - assert!(req.headers().contains_key(header::DATE)); - assert_eq!(&req.match_info()["test"], "123"); - assert_eq!(req.version(), Version::HTTP_2); - let data = req.get_app_data::().unwrap(); - assert!(req.get_app_data::().is_none()); - assert_eq!(*data, 10); - assert_eq!(*data.get_ref(), 10); + block_on(async { + let req = TestRequest::with_hdr(header::ContentType::json()) + .version(Version::HTTP_2) + .set(header::Date(SystemTime::now().into())) + .param("test", "123") + .data(10u32) + .to_http_request(); + assert!(req.headers().contains_key(header::CONTENT_TYPE)); + assert!(req.headers().contains_key(header::DATE)); + assert_eq!(&req.match_info()["test"], "123"); + assert_eq!(req.version(), Version::HTTP_2); + let data = req.get_app_data::().unwrap(); + assert!(req.get_app_data::().is_none()); + assert_eq!(*data, 10); + assert_eq!(*data.get_ref(), 10); - assert!(req.app_data::().is_none()); - let data = req.app_data::().unwrap(); - assert_eq!(*data, 10); + assert!(req.app_data::().is_none()); + let data = req.app_data::().unwrap(); + assert_eq!(*data, 10); + }) } #[test] fn test_request_methods() { - let mut app = init_service( - App::new().service( - web::resource("/index.html") - .route(web::put().to(|| HttpResponse::Ok().body("put!"))) - .route(web::patch().to(|| HttpResponse::Ok().body("patch!"))) - .route(web::delete().to(|| HttpResponse::Ok().body("delete!"))), - ), - ); + block_on(async { + let mut app = init_service( + App::new().service( + web::resource("/index.html") + .route( + web::put().to(|| async { HttpResponse::Ok().body("put!") }), + ) + .route( + web::patch() + .to(|| async { HttpResponse::Ok().body("patch!") }), + ) + .route( + web::delete() + .to(|| async { HttpResponse::Ok().body("delete!") }), + ), + ), + ) + .await; - let put_req = TestRequest::put() - .uri("/index.html") - .header(header::CONTENT_TYPE, "application/json") - .to_request(); + let put_req = TestRequest::put() + .uri("/index.html") + .header(header::CONTENT_TYPE, "application/json") + .to_request(); - let result = read_response(&mut app, put_req); - assert_eq!(result, Bytes::from_static(b"put!")); + let result = read_response(&mut app, put_req).await; + assert_eq!(result, Bytes::from_static(b"put!")); - let patch_req = TestRequest::patch() - .uri("/index.html") - .header(header::CONTENT_TYPE, "application/json") - .to_request(); + let patch_req = TestRequest::patch() + .uri("/index.html") + .header(header::CONTENT_TYPE, "application/json") + .to_request(); - let result = read_response(&mut app, patch_req); - assert_eq!(result, Bytes::from_static(b"patch!")); + let result = read_response(&mut app, patch_req).await; + assert_eq!(result, Bytes::from_static(b"patch!")); - let delete_req = TestRequest::delete().uri("/index.html").to_request(); - let result = read_response(&mut app, delete_req); - assert_eq!(result, Bytes::from_static(b"delete!")); + let delete_req = TestRequest::delete().uri("/index.html").to_request(); + let result = read_response(&mut app, delete_req).await; + assert_eq!(result, Bytes::from_static(b"delete!")); + }) } #[test] fn test_response() { - let mut app = init_service( - App::new().service( - web::resource("/index.html") - .route(web::post().to(|| HttpResponse::Ok().body("welcome!"))), - ), - ); + block_on(async { + let mut app = + init_service(App::new().service(web::resource("/index.html").route( + web::post().to(|| async { HttpResponse::Ok().body("welcome!") }), + ))) + .await; - let req = TestRequest::post() - .uri("/index.html") - .header(header::CONTENT_TYPE, "application/json") - .to_request(); + let req = TestRequest::post() + .uri("/index.html") + .header(header::CONTENT_TYPE, "application/json") + .to_request(); - let result = read_response(&mut app, req); - assert_eq!(result, Bytes::from_static(b"welcome!")); + let result = read_response(&mut app, req).await; + assert_eq!(result, Bytes::from_static(b"welcome!")); + }) } #[derive(Serialize, Deserialize)] @@ -589,129 +606,146 @@ mod tests { #[test] fn test_response_json() { - let mut app = init_service(App::new().service(web::resource("/people").route( - web::post().to(|person: web::Json| { - HttpResponse::Ok().json(person.into_inner()) - }), - ))); + block_on(async { + let mut app = + init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Json| { + async { HttpResponse::Ok().json(person.into_inner()) } + }), + ))) + .await; - let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); + let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); - let req = TestRequest::post() - .uri("/people") - .header(header::CONTENT_TYPE, "application/json") - .set_payload(payload) - .to_request(); + let req = TestRequest::post() + .uri("/people") + .header(header::CONTENT_TYPE, "application/json") + .set_payload(payload) + .to_request(); - let result: Person = read_response_json(&mut app, req); - assert_eq!(&result.id, "12345"); + let result: Person = read_response_json(&mut app, req).await; + assert_eq!(&result.id, "12345"); + }) } #[test] fn test_request_response_form() { - let mut app = init_service(App::new().service(web::resource("/people").route( - web::post().to(|person: web::Form| { - HttpResponse::Ok().json(person.into_inner()) - }), - ))); + block_on(async { + let mut app = + init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Form| { + async { HttpResponse::Ok().json(person.into_inner()) } + }), + ))) + .await; - let payload = Person { - id: "12345".to_string(), - name: "User name".to_string(), - }; + let payload = Person { + id: "12345".to_string(), + name: "User name".to_string(), + }; - let req = TestRequest::post() - .uri("/people") - .set_form(&payload) - .to_request(); + let req = TestRequest::post() + .uri("/people") + .set_form(&payload) + .to_request(); - assert_eq!(req.content_type(), "application/x-www-form-urlencoded"); + assert_eq!(req.content_type(), "application/x-www-form-urlencoded"); - let result: Person = read_response_json(&mut app, req); - assert_eq!(&result.id, "12345"); - assert_eq!(&result.name, "User name"); + let result: Person = read_response_json(&mut app, req).await; + assert_eq!(&result.id, "12345"); + assert_eq!(&result.name, "User name"); + }) } #[test] fn test_request_response_json() { - let mut app = init_service(App::new().service(web::resource("/people").route( - web::post().to(|person: web::Json| { - HttpResponse::Ok().json(person.into_inner()) - }), - ))); + block_on(async { + let mut app = + init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Json| { + async { HttpResponse::Ok().json(person.into_inner()) } + }), + ))) + .await; - let payload = Person { - id: "12345".to_string(), - name: "User name".to_string(), - }; + let payload = Person { + id: "12345".to_string(), + name: "User name".to_string(), + }; - let req = TestRequest::post() - .uri("/people") - .set_json(&payload) - .to_request(); + let req = TestRequest::post() + .uri("/people") + .set_json(&payload) + .to_request(); - assert_eq!(req.content_type(), "application/json"); + assert_eq!(req.content_type(), "application/json"); - let result: Person = read_response_json(&mut app, req); - assert_eq!(&result.id, "12345"); - assert_eq!(&result.name, "User name"); + let result: Person = read_response_json(&mut app, req).await; + assert_eq!(&result.id, "12345"); + assert_eq!(&result.name, "User name"); + }) } #[test] fn test_async_with_block() { - fn async_with_block() -> impl Future { - web::block(move || Some(4).ok_or("wrong")).then(|res| match res { - Ok(value) => HttpResponse::Ok() - .content_type("text/plain") - .body(format!("Async with block value: {}", value)), - Err(_) => panic!("Unexpected"), - }) - } + block_on(async { + async fn async_with_block() -> Result { + let res = web::block(move || Some(4usize).ok_or("wrong")).await; - let mut app = init_service( - App::new().service(web::resource("/index.html").to_async(async_with_block)), - ); - - let req = TestRequest::post().uri("/index.html").to_request(); - let res = block_fn(|| app.call(req)).unwrap(); - assert!(res.status().is_success()); - } - - #[test] - fn test_actor() { - use actix::Actor; - - struct MyActor; - - struct Num(usize); - impl actix::Message for Num { - type Result = usize; - } - impl actix::Actor for MyActor { - type Context = actix::Context; - } - impl actix::Handler for MyActor { - type Result = usize; - fn handle(&mut self, msg: Num, _: &mut Self::Context) -> Self::Result { - msg.0 + match res? { + Ok(value) => Ok(HttpResponse::Ok() + .content_type("text/plain") + .body(format!("Async with block value: {}", value))), + Err(_) => panic!("Unexpected"), + } } - } - let addr = run_on(|| MyActor.start()); - let mut app = init_service(App::new().service( - web::resource("/index.html").to_async(move || { - addr.send(Num(1)).from_err().and_then(|res| { - if res == 1 { - HttpResponse::Ok() - } else { - HttpResponse::BadRequest() - } - }) - }), - )); + let mut app = init_service( + App::new().service(web::resource("/index.html").to(async_with_block)), + ) + .await; - let req = TestRequest::post().uri("/index.html").to_request(); - let res = block_fn(|| app.call(req)).unwrap(); - assert!(res.status().is_success()); + let req = TestRequest::post().uri("/index.html").to_request(); + let res = app.call(req).await.unwrap(); + assert!(res.status().is_success()); + }) } + + // #[test] + // fn test_actor() { + // use actix::Actor; + + // struct MyActor; + + // struct Num(usize); + // impl actix::Message for Num { + // type Result = usize; + // } + // impl actix::Actor for MyActor { + // type Context = actix::Context; + // } + // impl actix::Handler for MyActor { + // type Result = usize; + // fn handle(&mut self, msg: Num, _: &mut Self::Context) -> Self::Result { + // msg.0 + // } + // } + + // let addr = run_on(|| MyActor.start()); + // let mut app = init_service(App::new().service( + // web::resource("/index.html").to(move || { + // addr.send(Num(1)).from_err().and_then(|res| { + // if res == 1 { + // HttpResponse::Ok() + // } else { + // HttpResponse::BadRequest() + // } + // }) + // }), + // )); + + // let req = TestRequest::post().uri("/index.html").to_request(); + // let res = block_fn(|| app.call(req)).unwrap(); + // assert!(res.status().is_success()); + // } } diff --git a/src/types/form.rs b/src/types/form.rs index 3bc067ab5..c20dc7a05 100644 --- a/src/types/form.rs +++ b/src/types/form.rs @@ -1,12 +1,16 @@ //! Form extractor +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use std::{fmt, ops}; use actix_http::{Error, HttpMessage, Payload, Response}; use bytes::BytesMut; use encoding_rs::{Encoding, UTF_8}; -use futures::{Future, Poll, Stream}; +use futures::future::{err, ok, FutureExt, LocalBoxFuture, Ready}; +use futures::{Stream, StreamExt}; use serde::de::DeserializeOwned; use serde::Serialize; @@ -35,9 +39,8 @@ use crate::responder::Responder; /// /// ### Example /// ```rust -/// # extern crate actix_web; -/// #[macro_use] extern crate serde_derive; -/// use actix_web::{web, App}; +/// use actix_web::web; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct FormData { @@ -61,9 +64,9 @@ use crate::responder::Responder; /// /// ### Example /// ```rust -/// # #[macro_use] extern crate serde_derive; -/// # use actix_web::*; -/// # +/// use actix_web::*; +/// use serde_derive::Serialize; +/// /// #[derive(Serialize)] /// struct SomeForm { /// name: String, @@ -111,7 +114,7 @@ where { type Config = FormConfig; type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { @@ -121,18 +124,19 @@ where .map(|c| (c.limit, c.ehandler.clone())) .unwrap_or((16384, None)); - Box::new( - UrlEncoded::new(req, payload) - .limit(limit) - .map_err(move |e| { + UrlEncoded::new(req, payload) + .limit(limit) + .map(move |res| match res { + Err(e) => { if let Some(err) = err { - (*err)(e, &req2) + Err((*err)(e, &req2)) } else { - e.into() + Err(e.into()) } - }) - .map(Form), - ) + } + Ok(item) => Ok(Form(item)), + }) + .boxed_local() } } @@ -150,15 +154,15 @@ impl fmt::Display for Form { impl Responder for Form { type Error = Error; - type Future = Result; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { let body = match serde_urlencoded::to_string(&self.0) { Ok(body) => body, - Err(e) => return Err(e.into()), + Err(e) => return err(e.into()), }; - Ok(Response::build(StatusCode::OK) + ok(Response::build(StatusCode::OK) .set(ContentType::form_url_encoded()) .body(body)) } @@ -167,8 +171,8 @@ impl Responder for Form { /// Form extractor configuration /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App, FromRequest, Result}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct FormData { @@ -177,7 +181,7 @@ impl Responder for Form { /// /// /// Extract form data using serde. /// /// Custom configuration is used for this handler, max payload size is 4k -/// fn index(form: web::Form) -> Result { +/// async fn index(form: web::Form) -> Result { /// Ok(format!("Welcome {}!", form.username)) /// } /// @@ -241,7 +245,7 @@ pub struct UrlEncoded { length: Option, encoding: &'static Encoding, err: Option, - fut: Option>>, + fut: Option>>, } impl UrlEncoded { @@ -302,45 +306,45 @@ impl Future for UrlEncoded where U: DeserializeOwned + 'static, { - type Item = U; - type Error = UrlencodedError; + type Output = Result; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { if let Some(ref mut fut) = self.fut { - return fut.poll(); + return Pin::new(fut).poll(cx); } if let Some(err) = self.err.take() { - return Err(err); + return Poll::Ready(Err(err)); } // payload size let limit = self.limit; if let Some(len) = self.length.take() { if len > limit { - return Err(UrlencodedError::Overflow { size: len, limit }); + return Poll::Ready(Err(UrlencodedError::Overflow { size: len, limit })); } } // future let encoding = self.encoding; - let fut = self - .stream - .take() - .unwrap() - .from_err() - .fold(BytesMut::with_capacity(8192), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(UrlencodedError::Overflow { - size: body.len() + chunk.len(), - limit, - }) - } else { - body.extend_from_slice(&chunk); - Ok(body) + let mut stream = self.stream.take().unwrap(); + + self.fut = Some( + async move { + let mut body = BytesMut::with_capacity(8192); + + while let Some(item) = stream.next().await { + let chunk = item?; + if (body.len() + chunk.len()) > limit { + return Err(UrlencodedError::Overflow { + size: body.len() + chunk.len(), + limit, + }); + } else { + body.extend_from_slice(&chunk); + } } - }) - .and_then(move |body| { + if encoding == UTF_8 { serde_urlencoded::from_bytes::(&body) .map_err(|_| UrlencodedError::Parse) @@ -352,9 +356,10 @@ where serde_urlencoded::from_str::(&body) .map_err(|_| UrlencodedError::Parse) } - }); - self.fut = Some(Box::new(fut)); - self.poll() + } + .boxed_local(), + ); + self.poll(cx) } } @@ -375,20 +380,24 @@ mod tests { #[test] fn test_form() { - let (req, mut pl) = - TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(CONTENT_LENGTH, "11") - .set_payload(Bytes::from_static(b"hello=world&counter=123")) - .to_http_parts(); + block_on(async { + let (req, mut pl) = TestRequest::with_header( + CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) + .header(CONTENT_LENGTH, "11") + .set_payload(Bytes::from_static(b"hello=world&counter=123")) + .to_http_parts(); - let Form(s) = block_on(Form::::from_request(&req, &mut pl)).unwrap(); - assert_eq!( - s, - Info { - hello: "world".into(), - counter: 123 - } - ); + let Form(s) = Form::::from_request(&req, &mut pl).await.unwrap(); + assert_eq!( + s, + Info { + hello: "world".into(), + counter: 123 + } + ); + }) } fn eq(err: UrlencodedError, other: UrlencodedError) -> bool { @@ -411,81 +420,93 @@ mod tests { #[test] fn test_urlencoded_error() { - let (req, mut pl) = - TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(CONTENT_LENGTH, "xxxx") - .to_http_parts(); - let info = block_on(UrlEncoded::::new(&req, &mut pl)); - assert!(eq(info.err().unwrap(), UrlencodedError::UnknownLength)); - - let (req, mut pl) = - TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(CONTENT_LENGTH, "1000000") - .to_http_parts(); - let info = block_on(UrlEncoded::::new(&req, &mut pl)); - assert!(eq( - info.err().unwrap(), - UrlencodedError::Overflow { size: 0, limit: 0 } - )); - - let (req, mut pl) = TestRequest::with_header(CONTENT_TYPE, "text/plain") - .header(CONTENT_LENGTH, "10") + block_on(async { + let (req, mut pl) = TestRequest::with_header( + CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) + .header(CONTENT_LENGTH, "xxxx") .to_http_parts(); - let info = block_on(UrlEncoded::::new(&req, &mut pl)); - assert!(eq(info.err().unwrap(), UrlencodedError::ContentType)); + let info = UrlEncoded::::new(&req, &mut pl).await; + assert!(eq(info.err().unwrap(), UrlencodedError::UnknownLength)); + + let (req, mut pl) = TestRequest::with_header( + CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) + .header(CONTENT_LENGTH, "1000000") + .to_http_parts(); + let info = UrlEncoded::::new(&req, &mut pl).await; + assert!(eq( + info.err().unwrap(), + UrlencodedError::Overflow { size: 0, limit: 0 } + )); + + let (req, mut pl) = TestRequest::with_header(CONTENT_TYPE, "text/plain") + .header(CONTENT_LENGTH, "10") + .to_http_parts(); + let info = UrlEncoded::::new(&req, &mut pl).await; + assert!(eq(info.err().unwrap(), UrlencodedError::ContentType)); + }) } #[test] fn test_urlencoded() { - let (req, mut pl) = - TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(CONTENT_LENGTH, "11") - .set_payload(Bytes::from_static(b"hello=world&counter=123")) - .to_http_parts(); + block_on(async { + let (req, mut pl) = TestRequest::with_header( + CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) + .header(CONTENT_LENGTH, "11") + .set_payload(Bytes::from_static(b"hello=world&counter=123")) + .to_http_parts(); - let info = block_on(UrlEncoded::::new(&req, &mut pl)).unwrap(); - assert_eq!( - info, - Info { - hello: "world".to_owned(), - counter: 123 - } - ); + let info = UrlEncoded::::new(&req, &mut pl).await.unwrap(); + assert_eq!( + info, + Info { + hello: "world".to_owned(), + counter: 123 + } + ); - let (req, mut pl) = TestRequest::with_header( - CONTENT_TYPE, - "application/x-www-form-urlencoded; charset=utf-8", - ) - .header(CONTENT_LENGTH, "11") - .set_payload(Bytes::from_static(b"hello=world&counter=123")) - .to_http_parts(); + let (req, mut pl) = TestRequest::with_header( + CONTENT_TYPE, + "application/x-www-form-urlencoded; charset=utf-8", + ) + .header(CONTENT_LENGTH, "11") + .set_payload(Bytes::from_static(b"hello=world&counter=123")) + .to_http_parts(); - let info = block_on(UrlEncoded::::new(&req, &mut pl)).unwrap(); - assert_eq!( - info, - Info { - hello: "world".to_owned(), - counter: 123 - } - ); + let info = UrlEncoded::::new(&req, &mut pl).await.unwrap(); + assert_eq!( + info, + Info { + hello: "world".to_owned(), + counter: 123 + } + ); + }) } #[test] fn test_responder() { - let req = TestRequest::default().to_http_request(); + block_on(async { + let req = TestRequest::default().to_http_request(); - let form = Form(Info { - hello: "world".to_string(), - counter: 123, - }); - let resp = form.respond_to(&req).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/x-www-form-urlencoded") - ); + let form = Form(Info { + hello: "world".to_string(), + counter: 123, + }); + let resp = form.respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/x-www-form-urlencoded") + ); - use crate::responder::tests::BodyTest; - assert_eq!(resp.body().bin_ref(), b"hello=world&counter=123"); + use crate::responder::tests::BodyTest; + assert_eq!(resp.body().bin_ref(), b"hello=world&counter=123"); + }) } } diff --git a/src/types/json.rs b/src/types/json.rs index f309a3c5a..206a4e425 100644 --- a/src/types/json.rs +++ b/src/types/json.rs @@ -1,10 +1,14 @@ //! Json extractor/responder +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use std::{fmt, ops}; use bytes::BytesMut; -use futures::{Future, Poll, Stream}; +use futures::future::{err, ok, FutureExt, LocalBoxFuture, Ready}; +use futures::{Stream, StreamExt}; use serde::de::DeserializeOwned; use serde::Serialize; use serde_json; @@ -33,8 +37,8 @@ use crate::responder::Responder; /// ## Example /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -42,7 +46,7 @@ use crate::responder::Responder; /// } /// /// /// deserialize `Info` from request's body -/// fn index(info: web::Json) -> String { +/// async fn index(info: web::Json) -> String { /// format!("Welcome {}!", info.username) /// } /// @@ -60,9 +64,9 @@ use crate::responder::Responder; /// trait from *serde*. /// /// ```rust -/// # #[macro_use] extern crate serde_derive; -/// # use actix_web::*; -/// # +/// use actix_web::*; +/// use serde_derive::Serialize; +/// /// #[derive(Serialize)] /// struct MyObj { /// name: String, @@ -118,15 +122,15 @@ where impl Responder for Json { type Error = Error; - type Future = Result; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { let body = match serde_json::to_string(&self.0) { Ok(body) => body, - Err(e) => return Err(e.into()), + Err(e) => return err(e.into()), }; - Ok(Response::build(StatusCode::OK) + ok(Response::build(StatusCode::OK) .content_type("application/json") .body(body)) } @@ -144,8 +148,8 @@ impl Responder for Json { /// ## Example /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -153,7 +157,7 @@ impl Responder for Json { /// } /// /// /// deserialize `Info` from request's body -/// fn index(info: web::Json) -> String { +/// async fn index(info: web::Json) -> String { /// format!("Welcome {}!", info.username) /// } /// @@ -169,7 +173,7 @@ where T: DeserializeOwned + 'static, { type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; type Config = JsonConfig; #[inline] @@ -180,31 +184,32 @@ where .map(|c| (c.limit, c.ehandler.clone(), c.content_type.clone())) .unwrap_or((32768, None, None)); - Box::new( - JsonBody::new(req, payload, ctype) - .limit(limit) - .map_err(move |e| { + JsonBody::new(req, payload, ctype) + .limit(limit) + .map(move |res| match res { + Err(e) => { log::debug!( "Failed to deserialize Json from payload. \ Request path: {}", req2.path() ); if let Some(err) = err { - (*err)(e, &req2) + Err((*err)(e, &req2)) } else { - e.into() + Err(e.into()) } - }) - .map(Json), - ) + } + Ok(data) => Ok(Json(data)), + }) + .boxed_local() } } /// Json extractor configuration /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{error, web, App, FromRequest, HttpResponse}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -212,7 +217,7 @@ where /// } /// /// /// deserialize `Info` from request's body, max payload size is 4kb -/// fn index(info: web::Json) -> String { +/// async fn index(info: web::Json) -> String { /// format!("Welcome {}!", info.username) /// } /// @@ -290,7 +295,7 @@ pub struct JsonBody { length: Option, stream: Option>, err: Option, - fut: Option>>, + fut: Option>>, } impl JsonBody @@ -349,41 +354,43 @@ impl Future for JsonBody where U: DeserializeOwned + 'static, { - type Item = U; - type Error = JsonPayloadError; + type Output = Result; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { if let Some(ref mut fut) = self.fut { - return fut.poll(); + return Pin::new(fut).poll(cx); } if let Some(err) = self.err.take() { - return Err(err); + return Poll::Ready(Err(err)); } let limit = self.limit; if let Some(len) = self.length.take() { if len > limit { - return Err(JsonPayloadError::Overflow); + return Poll::Ready(Err(JsonPayloadError::Overflow)); } } + let mut stream = self.stream.take().unwrap(); - let fut = self - .stream - .take() - .unwrap() - .from_err() - .fold(BytesMut::with_capacity(8192), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(JsonPayloadError::Overflow) - } else { - body.extend_from_slice(&chunk); - Ok(body) + self.fut = Some( + async move { + let mut body = BytesMut::with_capacity(8192); + + while let Some(item) = stream.next().await { + let chunk = item?; + if (body.len() + chunk.len()) > limit { + return Err(JsonPayloadError::Overflow); + } else { + body.extend_from_slice(&chunk); + } } - }) - .and_then(|body| Ok(serde_json::from_slice::(&body)?)); - self.fut = Some(Box::new(fut)); - self.poll() + Ok(serde_json::from_slice::(&body)?) + } + .boxed_local(), + ); + + self.poll(cx) } } @@ -395,7 +402,7 @@ mod tests { use super::*; use crate::error::InternalError; use crate::http::header; - use crate::test::{block_on, TestRequest}; + use crate::test::{block_on, load_stream, TestRequest}; use crate::HttpResponse; #[derive(Serialize, Deserialize, PartialEq, Debug)] @@ -419,218 +426,234 @@ mod tests { #[test] fn test_responder() { - let req = TestRequest::default().to_http_request(); + block_on(async { + let req = TestRequest::default().to_http_request(); - let j = Json(MyObject { - name: "test".to_string(), - }); - let resp = j.respond_to(&req).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), - header::HeaderValue::from_static("application/json") - ); + let j = Json(MyObject { + name: "test".to_string(), + }); + let resp = j.respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + header::HeaderValue::from_static("application/json") + ); - use crate::responder::tests::BodyTest; - assert_eq!(resp.body().bin_ref(), b"{\"name\":\"test\"}"); + use crate::responder::tests::BodyTest; + assert_eq!(resp.body().bin_ref(), b"{\"name\":\"test\"}"); + }) } #[test] fn test_custom_error_responder() { - let (req, mut pl) = TestRequest::default() - .header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .data(JsonConfig::default().limit(10).error_handler(|err, _| { - let msg = MyObject { - name: "invalid request".to_string(), - }; - let resp = HttpResponse::BadRequest() - .body(serde_json::to_string(&msg).unwrap()); - InternalError::from_response(err, resp).into() - })) - .to_http_parts(); + block_on(async { + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .data(JsonConfig::default().limit(10).error_handler(|err, _| { + let msg = MyObject { + name: "invalid request".to_string(), + }; + let resp = HttpResponse::BadRequest() + .body(serde_json::to_string(&msg).unwrap()); + InternalError::from_response(err, resp).into() + })) + .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)); - let mut resp = Response::from_error(s.err().unwrap().into()); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let s = Json::::from_request(&req, &mut pl).await; + let mut resp = Response::from_error(s.err().unwrap().into()); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let body = block_on(resp.take_body().concat2()).unwrap(); - let msg: MyObject = serde_json::from_slice(&body).unwrap(); - assert_eq!(msg.name, "invalid request"); + let body = load_stream(resp.take_body()).await.unwrap(); + let msg: MyObject = serde_json::from_slice(&body).unwrap(); + assert_eq!(msg.name, "invalid request"); + }) } #[test] fn test_extract() { - let (req, mut pl) = TestRequest::default() - .header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .to_http_parts(); + block_on(async { + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)).unwrap(); - assert_eq!(s.name, "test"); - assert_eq!( - s.into_inner(), - MyObject { - name: "test".to_string() - } - ); + let s = Json::::from_request(&req, &mut pl).await.unwrap(); + assert_eq!(s.name, "test"); + assert_eq!( + s.into_inner(), + MyObject { + name: "test".to_string() + } + ); - let (req, mut pl) = TestRequest::default() - .header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .data(JsonConfig::default().limit(10)) - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .data(JsonConfig::default().limit(10)) + .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)); - assert!(format!("{}", s.err().unwrap()) - .contains("Json payload size is bigger than allowed")); + let s = Json::::from_request(&req, &mut pl).await; + assert!(format!("{}", s.err().unwrap()) + .contains("Json payload size is bigger than allowed")); - let (req, mut pl) = TestRequest::default() - .header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .data( - JsonConfig::default() - .limit(10) - .error_handler(|_, _| JsonPayloadError::ContentType.into()), - ) - .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)); - assert!(format!("{}", s.err().unwrap()).contains("Content type error")); + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .data( + JsonConfig::default() + .limit(10) + .error_handler(|_, _| JsonPayloadError::ContentType.into()), + ) + .to_http_parts(); + let s = Json::::from_request(&req, &mut pl).await; + assert!(format!("{}", s.err().unwrap()).contains("Content type error")); + }) } #[test] fn test_json_body() { - let (req, mut pl) = TestRequest::default().to_http_parts(); - let json = block_on(JsonBody::::new(&req, &mut pl, None)); - assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + block_on(async { + let (req, mut pl) = TestRequest::default().to_http_parts(); + let json = JsonBody::::new(&req, &mut pl, None).await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); - let (req, mut pl) = TestRequest::default() - .header( + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/text"), + ) + .to_http_parts(); + let json = JsonBody::::new(&req, &mut pl, None).await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("10000"), + ) + .to_http_parts(); + + let json = JsonBody::::new(&req, &mut pl, None) + .limit(100) + .await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::Overflow)); + + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .to_http_parts(); + + let json = JsonBody::::new(&req, &mut pl, None).await; + assert_eq!( + json.ok().unwrap(), + MyObject { + name: "test".to_owned() + } + ); + }) + } + + #[test] + fn test_with_json_and_bad_content_type() { + block_on(async { + let (req, mut pl) = TestRequest::with_header( header::CONTENT_TYPE, - header::HeaderValue::from_static("application/text"), - ) - .to_http_parts(); - let json = block_on(JsonBody::::new(&req, &mut pl, None)); - assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); - - let (req, mut pl) = TestRequest::default() - .header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("10000"), - ) - .to_http_parts(); - - let json = block_on(JsonBody::::new(&req, &mut pl, None).limit(100)); - assert!(json_eq(json.err().unwrap(), JsonPayloadError::Overflow)); - - let (req, mut pl) = TestRequest::default() - .header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), + header::HeaderValue::from_static("text/plain"), ) .header( header::CONTENT_LENGTH, header::HeaderValue::from_static("16"), ) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .data(JsonConfig::default().limit(4096)) .to_http_parts(); - let json = block_on(JsonBody::::new(&req, &mut pl, None)); - assert_eq!( - json.ok().unwrap(), - MyObject { - name: "test".to_owned() - } - ); - } - - #[test] - fn test_with_json_and_bad_content_type() { - let (req, mut pl) = TestRequest::with_header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("text/plain"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .data(JsonConfig::default().limit(4096)) - .to_http_parts(); - - let s = block_on(Json::::from_request(&req, &mut pl)); - assert!(s.is_err()) + let s = Json::::from_request(&req, &mut pl).await; + assert!(s.is_err()) + }) } #[test] fn test_with_json_and_good_custom_content_type() { - let (req, mut pl) = TestRequest::with_header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("text/plain"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .data(JsonConfig::default().content_type(|mime: mime::Mime| { - mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN - })) - .to_http_parts(); + block_on(async { + let (req, mut pl) = TestRequest::with_header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("text/plain"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .data(JsonConfig::default().content_type(|mime: mime::Mime| { + mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN + })) + .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)); - assert!(s.is_ok()) + let s = Json::::from_request(&req, &mut pl).await; + assert!(s.is_ok()) + }) } #[test] fn test_with_json_and_bad_custom_content_type() { - let (req, mut pl) = TestRequest::with_header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("text/html"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .data(JsonConfig::default().content_type(|mime: mime::Mime| { - mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN - })) - .to_http_parts(); + block_on(async { + let (req, mut pl) = TestRequest::with_header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("text/html"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .data(JsonConfig::default().content_type(|mime: mime::Mime| { + mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN + })) + .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)); - assert!(s.is_err()) + let s = Json::::from_request(&req, &mut pl).await; + assert!(s.is_err()) + }) } } diff --git a/src/types/path.rs b/src/types/path.rs index a46575764..29a574feb 100644 --- a/src/types/path.rs +++ b/src/types/path.rs @@ -5,6 +5,7 @@ use std::{fmt, ops}; use actix_http::error::{Error, ErrorNotFound}; use actix_router::PathDeserializer; +use futures::future::{ready, Ready}; use serde::de; use crate::dev::Payload; @@ -23,7 +24,7 @@ use crate::FromRequest; /// /// extract path info from "/{username}/{count}/index.html" url /// /// {username} - deserializes to a String /// /// {count} - - deserializes to a u32 -/// fn index(info: web::Path<(String, u32)>) -> String { +/// async fn index(info: web::Path<(String, u32)>) -> String { /// format!("Welcome {}! {}", info.0, info.1) /// } /// @@ -39,8 +40,8 @@ use crate::FromRequest; /// implements `Deserialize` trait from *serde*. /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App, Error}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -48,7 +49,7 @@ use crate::FromRequest; /// } /// /// /// extract `Info` from a path using serde -/// fn index(info: web::Path) -> Result { +/// async fn index(info: web::Path) -> Result { /// Ok(format!("Welcome {}!", info.username)) /// } /// @@ -118,7 +119,7 @@ impl fmt::Display for Path { /// /// extract path info from "/{username}/{count}/index.html" url /// /// {username} - deserializes to a String /// /// {count} - - deserializes to a u32 -/// fn index(info: web::Path<(String, u32)>) -> String { +/// async fn index(info: web::Path<(String, u32)>) -> String { /// format!("Welcome {}! {}", info.0, info.1) /// } /// @@ -134,8 +135,8 @@ impl fmt::Display for Path { /// implements `Deserialize` trait from *serde*. /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App, Error}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -143,7 +144,7 @@ impl fmt::Display for Path { /// } /// /// /// extract `Info` from a path using serde -/// fn index(info: web::Path) -> Result { +/// async fn index(info: web::Path) -> Result { /// Ok(format!("Welcome {}!", info.username)) /// } /// @@ -159,7 +160,7 @@ where T: de::DeserializeOwned, { type Error = Error; - type Future = Result; + type Future = Ready>; type Config = PathConfig; #[inline] @@ -169,31 +170,32 @@ where .map(|c| c.ehandler.clone()) .unwrap_or(None); - de::Deserialize::deserialize(PathDeserializer::new(req.match_info())) - .map(|inner| Path { inner }) - .map_err(move |e| { - log::debug!( - "Failed during Path extractor deserialization. \ - Request path: {:?}", - req.path() - ); - if let Some(error_handler) = error_handler { - let e = PathError::Deserialize(e); - (error_handler)(e, req) - } else { - ErrorNotFound(e) - } - }) + ready( + de::Deserialize::deserialize(PathDeserializer::new(req.match_info())) + .map(|inner| Path { inner }) + .map_err(move |e| { + log::debug!( + "Failed during Path extractor deserialization. \ + Request path: {:?}", + req.path() + ); + if let Some(error_handler) = error_handler { + let e = PathError::Deserialize(e); + (error_handler)(e, req) + } else { + ErrorNotFound(e) + } + }), + ) } } /// Path extractor configuration /// /// ```rust -/// # #[macro_use] -/// # extern crate serde_derive; /// use actix_web::web::PathConfig; /// use actix_web::{error, web, App, FromRequest, HttpResponse}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize, Debug)] /// enum Folder { @@ -204,7 +206,7 @@ where /// } /// /// // deserialize `Info` from request's path -/// fn index(folder: web::Path) -> String { +/// async fn index(folder: web::Path) -> String { /// format!("Selected folder: {:?}!", folder) /// } /// @@ -269,100 +271,116 @@ mod tests { #[test] fn test_extract_path_single() { - let resource = ResourceDef::new("/{value}/"); + block_on(async { + let resource = ResourceDef::new("/{value}/"); - let mut req = TestRequest::with_uri("/32/").to_srv_request(); - resource.match_path(req.match_info_mut()); + let mut req = TestRequest::with_uri("/32/").to_srv_request(); + resource.match_path(req.match_info_mut()); - let (req, mut pl) = req.into_parts(); - assert_eq!(*Path::::from_request(&req, &mut pl).unwrap(), 32); - assert!(Path::::from_request(&req, &mut pl).is_err()); + let (req, mut pl) = req.into_parts(); + assert_eq!(*Path::::from_request(&req, &mut pl).await.unwrap(), 32); + assert!(Path::::from_request(&req, &mut pl).await.is_err()); + }) } #[test] fn test_tuple_extract() { - let resource = ResourceDef::new("/{key}/{value}/"); + block_on(async { + let resource = ResourceDef::new("/{key}/{value}/"); - let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); - resource.match_path(req.match_info_mut()); + let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); + resource.match_path(req.match_info_mut()); - let (req, mut pl) = req.into_parts(); - let res = - block_on(<(Path<(String, String)>,)>::from_request(&req, &mut pl)).unwrap(); - assert_eq!((res.0).0, "name"); - assert_eq!((res.0).1, "user1"); + let (req, mut pl) = req.into_parts(); + let res = <(Path<(String, String)>,)>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!((res.0).0, "name"); + assert_eq!((res.0).1, "user1"); - let res = block_on( - <(Path<(String, String)>, Path<(String, String)>)>::from_request( + let res = <(Path<(String, String)>, Path<(String, String)>)>::from_request( &req, &mut pl, - ), - ) - .unwrap(); - assert_eq!((res.0).0, "name"); - assert_eq!((res.0).1, "user1"); - assert_eq!((res.1).0, "name"); - assert_eq!((res.1).1, "user1"); + ) + .await + .unwrap(); + assert_eq!((res.0).0, "name"); + assert_eq!((res.0).1, "user1"); + assert_eq!((res.1).0, "name"); + assert_eq!((res.1).1, "user1"); - let () = <()>::from_request(&req, &mut pl).unwrap(); + let () = <()>::from_request(&req, &mut pl).await.unwrap(); + }) } #[test] fn test_request_extract() { - let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); + block_on(async { + let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); - let resource = ResourceDef::new("/{key}/{value}/"); - resource.match_path(req.match_info_mut()); + let resource = ResourceDef::new("/{key}/{value}/"); + resource.match_path(req.match_info_mut()); - let (req, mut pl) = req.into_parts(); - let mut s = Path::::from_request(&req, &mut pl).unwrap(); - assert_eq!(s.key, "name"); - assert_eq!(s.value, "user1"); - s.value = "user2".to_string(); - assert_eq!(s.value, "user2"); - assert_eq!( - format!("{}, {:?}", s, s), - "MyStruct(name, user2), MyStruct { key: \"name\", value: \"user2\" }" - ); - let s = s.into_inner(); - assert_eq!(s.value, "user2"); + let (req, mut pl) = req.into_parts(); + let mut s = Path::::from_request(&req, &mut pl).await.unwrap(); + assert_eq!(s.key, "name"); + assert_eq!(s.value, "user1"); + s.value = "user2".to_string(); + assert_eq!(s.value, "user2"); + assert_eq!( + format!("{}, {:?}", s, s), + "MyStruct(name, user2), MyStruct { key: \"name\", value: \"user2\" }" + ); + let s = s.into_inner(); + assert_eq!(s.value, "user2"); - let s = Path::<(String, String)>::from_request(&req, &mut pl).unwrap(); - assert_eq!(s.0, "name"); - assert_eq!(s.1, "user1"); + let s = Path::<(String, String)>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!(s.0, "name"); + assert_eq!(s.1, "user1"); - let mut req = TestRequest::with_uri("/name/32/").to_srv_request(); - let resource = ResourceDef::new("/{key}/{value}/"); - resource.match_path(req.match_info_mut()); + let mut req = TestRequest::with_uri("/name/32/").to_srv_request(); + let resource = ResourceDef::new("/{key}/{value}/"); + resource.match_path(req.match_info_mut()); - let (req, mut pl) = req.into_parts(); - let s = Path::::from_request(&req, &mut pl).unwrap(); - assert_eq!(s.as_ref().key, "name"); - assert_eq!(s.value, 32); + let (req, mut pl) = req.into_parts(); + let s = Path::::from_request(&req, &mut pl).await.unwrap(); + assert_eq!(s.as_ref().key, "name"); + assert_eq!(s.value, 32); - let s = Path::<(String, u8)>::from_request(&req, &mut pl).unwrap(); - assert_eq!(s.0, "name"); - assert_eq!(s.1, 32); + let s = Path::<(String, u8)>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!(s.0, "name"); + assert_eq!(s.1, 32); - let res = Path::>::from_request(&req, &mut pl).unwrap(); - assert_eq!(res[0], "name".to_owned()); - assert_eq!(res[1], "32".to_owned()); + let res = Path::>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!(res[0], "name".to_owned()); + assert_eq!(res[1], "32".to_owned()); + }) } #[test] fn test_custom_err_handler() { - let (req, mut pl) = TestRequest::with_uri("/name/user1/") - .data(PathConfig::default().error_handler(|err, _| { - error::InternalError::from_response( - err, - HttpResponse::Conflict().finish(), - ) - .into() - })) - .to_http_parts(); + block_on(async { + let (req, mut pl) = TestRequest::with_uri("/name/user1/") + .data(PathConfig::default().error_handler(|err, _| { + error::InternalError::from_response( + err, + HttpResponse::Conflict().finish(), + ) + .into() + })) + .to_http_parts(); - let s = block_on(Path::<(usize,)>::from_request(&req, &mut pl)).unwrap_err(); - let res: HttpResponse = s.into(); + let s = Path::<(usize,)>::from_request(&req, &mut pl) + .await + .unwrap_err(); + let res: HttpResponse = s.into(); - assert_eq!(res.status(), http::StatusCode::CONFLICT); + assert_eq!(res.status(), http::StatusCode::CONFLICT); + }) } } diff --git a/src/types/payload.rs b/src/types/payload.rs index 8fc5f093e..ee7e11667 100644 --- a/src/types/payload.rs +++ b/src/types/payload.rs @@ -1,12 +1,15 @@ //! Payload/Bytes/String extractors +use std::future::Future; +use std::pin::Pin; use std::str; +use std::task::{Context, Poll}; use actix_http::error::{Error, ErrorBadRequest, PayloadError}; use actix_http::HttpMessage; use bytes::{Bytes, BytesMut}; use encoding_rs::UTF_8; -use futures::future::{err, Either, FutureResult}; -use futures::{Future, Poll, Stream}; +use futures::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready}; +use futures::{Stream, StreamExt}; use mime::Mime; use crate::dev; @@ -19,27 +22,25 @@ use crate::request::HttpRequest; /// ## Example /// /// ```rust -/// use futures::{Future, Stream}; +/// use futures::{Future, Stream, StreamExt}; /// use actix_web::{web, error, App, Error, HttpResponse}; /// /// /// extract binary data from request -/// fn index(body: web::Payload) -> impl Future +/// async fn index(mut body: web::Payload) -> Result /// { -/// body.map_err(Error::from) -/// .fold(web::BytesMut::new(), move |mut body, chunk| { -/// body.extend_from_slice(&chunk); -/// Ok::<_, Error>(body) -/// }) -/// .and_then(|body| { -/// format!("Body {:?}!", body); -/// Ok(HttpResponse::Ok().finish()) -/// }) +/// let mut bytes = web::BytesMut::new(); +/// while let Some(item) = body.next().await { +/// bytes.extend_from_slice(&item?); +/// } +/// +/// format!("Body {:?}!", bytes); +/// Ok(HttpResponse::Ok().finish()) /// } /// /// fn main() { /// let app = App::new().service( /// web::resource("/index.html").route( -/// web::get().to_async(index)) +/// web::get().to(index)) /// ); /// } /// ``` @@ -53,12 +54,14 @@ impl Payload { } impl Stream for Payload { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; #[inline] - fn poll(&mut self) -> Poll, PayloadError> { - self.0.poll() + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + Pin::new(&mut self.0).poll_next(cx) } } @@ -67,38 +70,36 @@ impl Stream for Payload { /// ## Example /// /// ```rust -/// use futures::{Future, Stream}; +/// use futures::{Future, Stream, StreamExt}; /// use actix_web::{web, error, App, Error, HttpResponse}; /// /// /// extract binary data from request -/// fn index(body: web::Payload) -> impl Future +/// async fn index(mut body: web::Payload) -> Result /// { -/// body.map_err(Error::from) -/// .fold(web::BytesMut::new(), move |mut body, chunk| { -/// body.extend_from_slice(&chunk); -/// Ok::<_, Error>(body) -/// }) -/// .and_then(|body| { -/// format!("Body {:?}!", body); -/// Ok(HttpResponse::Ok().finish()) -/// }) +/// let mut bytes = web::BytesMut::new(); +/// while let Some(item) = body.next().await { +/// bytes.extend_from_slice(&item?); +/// } +/// +/// format!("Body {:?}!", bytes); +/// Ok(HttpResponse::Ok().finish()) /// } /// /// fn main() { /// let app = App::new().service( /// web::resource("/index.html").route( -/// web::get().to_async(index)) +/// web::get().to(index)) /// ); /// } /// ``` impl FromRequest for Payload { type Config = PayloadConfig; type Error = Error; - type Future = Result; + type Future = Ready>; #[inline] fn from_request(_: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { - Ok(Payload(payload.take())) + ok(Payload(payload.take())) } } @@ -116,7 +117,7 @@ impl FromRequest for Payload { /// use actix_web::{web, App}; /// /// /// extract binary data from request -/// fn index(body: Bytes) -> String { +/// async fn index(body: Bytes) -> String { /// format!("Body {:?}!", body) /// } /// @@ -130,8 +131,10 @@ impl FromRequest for Payload { impl FromRequest for Bytes { type Config = PayloadConfig; type Error = Error; - type Future = - Either>, FutureResult>; + type Future = Either< + LocalBoxFuture<'static, Result>, + Ready>, + >; #[inline] fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { @@ -144,13 +147,12 @@ impl FromRequest for Bytes { }; if let Err(e) = cfg.check_mimetype(req) { - return Either::B(err(e)); + return Either::Right(err(e)); } let limit = cfg.limit; - Either::A(Box::new( - HttpMessageBody::new(req, payload).limit(limit).from_err(), - )) + let fut = HttpMessageBody::new(req, payload).limit(limit); + Either::Left(async move { Ok(fut.await?) }.boxed_local()) } } @@ -167,7 +169,7 @@ impl FromRequest for Bytes { /// use actix_web::{web, App, FromRequest}; /// /// /// extract text data from request -/// fn index(text: String) -> String { +/// async fn index(text: String) -> String { /// format!("Body {}!", text) /// } /// @@ -185,8 +187,8 @@ impl FromRequest for String { type Config = PayloadConfig; type Error = Error; type Future = Either< - Box>, - FutureResult, + LocalBoxFuture<'static, Result>, + Ready>, >; #[inline] @@ -201,33 +203,34 @@ impl FromRequest for String { // check content-type if let Err(e) = cfg.check_mimetype(req) { - return Either::B(err(e)); + return Either::Right(err(e)); } // check charset let encoding = match req.encoding() { Ok(enc) => enc, - Err(e) => return Either::B(err(e.into())), + Err(e) => return Either::Right(err(e.into())), }; let limit = cfg.limit; + let fut = HttpMessageBody::new(req, payload).limit(limit); - Either::A(Box::new( - HttpMessageBody::new(req, payload) - .limit(limit) - .from_err() - .and_then(move |body| { - if encoding == UTF_8 { - Ok(str::from_utf8(body.as_ref()) - .map_err(|_| ErrorBadRequest("Can not decode body"))? - .to_owned()) - } else { - Ok(encoding - .decode_without_bom_handling_and_without_replacement(&body) - .map(|s| s.into_owned()) - .ok_or_else(|| ErrorBadRequest("Can not decode body"))?) - } - }), - )) + Either::Left( + async move { + let body = fut.await?; + + if encoding == UTF_8 { + Ok(str::from_utf8(body.as_ref()) + .map_err(|_| ErrorBadRequest("Can not decode body"))? + .to_owned()) + } else { + Ok(encoding + .decode_without_bom_handling_and_without_replacement(&body) + .map(|s| s.into_owned()) + .ok_or_else(|| ErrorBadRequest("Can not decode body"))?) + } + } + .boxed_local(), + ) } } /// Payload configuration for request's payload. @@ -300,7 +303,7 @@ pub struct HttpMessageBody { length: Option, stream: Option>, err: Option, - fut: Option>>, + fut: Option>>, } impl HttpMessageBody { @@ -346,42 +349,43 @@ impl HttpMessageBody { } impl Future for HttpMessageBody { - type Item = Bytes; - type Error = PayloadError; + type Output = Result; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { if let Some(ref mut fut) = self.fut { - return fut.poll(); + return Pin::new(fut).poll(cx); } if let Some(err) = self.err.take() { - return Err(err); + return Poll::Ready(Err(err)); } if let Some(len) = self.length.take() { if len > self.limit { - return Err(PayloadError::Overflow); + return Poll::Ready(Err(PayloadError::Overflow)); } } // future let limit = self.limit; - self.fut = Some(Box::new( - self.stream - .take() - .unwrap() - .from_err() - .fold(BytesMut::with_capacity(8192), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(PayloadError::Overflow) + let mut stream = self.stream.take().unwrap(); + self.fut = Some( + async move { + let mut body = BytesMut::with_capacity(8192); + + while let Some(item) = stream.next().await { + let chunk = item?; + if body.len() + chunk.len() > limit { + return Err(PayloadError::Overflow); } else { body.extend_from_slice(&chunk); - Ok(body) } - }) - .map(|body| body.freeze()), - )); - self.poll() + } + Ok(body.freeze()) + } + .boxed_local(), + ); + self.poll(cx) } } diff --git a/src/types/query.rs b/src/types/query.rs index 60b07085d..e442f1c31 100644 --- a/src/types/query.rs +++ b/src/types/query.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use std::{fmt, ops}; use actix_http::error::Error; +use futures::future::{err, ok, Ready}; use serde::de; use serde_urlencoded; @@ -21,8 +22,8 @@ use crate::request::HttpRequest; /// ## Example /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App}; +/// use serde_derive::Deserialize; /// /// #[derive(Debug, Deserialize)] /// pub enum ResponseType { @@ -39,7 +40,7 @@ use crate::request::HttpRequest; /// // Use `Query` extractor for query information (and destructure it within the signature). /// // This handler gets called only if the request's query string contains a `username` field. /// // The correct request for this handler would be `/index.html?id=64&response_type=Code"`. -/// fn index(web::Query(info): web::Query) -> String { +/// async fn index(web::Query(info): web::Query) -> String { /// format!("Authorization request for client with id={} and type={:?}!", info.id, info.response_type) /// } /// @@ -99,8 +100,8 @@ impl fmt::Display for Query { /// ## Example /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App}; +/// use serde_derive::Deserialize; /// /// #[derive(Debug, Deserialize)] /// pub enum ResponseType { @@ -117,7 +118,7 @@ impl fmt::Display for Query { /// // Use `Query` extractor for query information. /// // This handler get called only if request's query contains `username` field /// // The correct request for this handler would be `/index.html?id=64&response_type=Code"` -/// fn index(info: web::Query) -> String { +/// async fn index(info: web::Query) -> String { /// format!("Authorization request for client with id={} and type={:?}!", info.id, info.response_type) /// } /// @@ -132,7 +133,7 @@ where T: de::DeserializeOwned, { type Error = Error; - type Future = Result; + type Future = Ready>; type Config = QueryConfig; #[inline] @@ -143,7 +144,7 @@ where .unwrap_or(None); serde_urlencoded::from_str::(req.query_string()) - .map(|val| Ok(Query(val))) + .map(|val| ok(Query(val))) .unwrap_or_else(move |e| { let e = QueryPayloadError::Deserialize(e); @@ -159,7 +160,7 @@ where e.into() }; - Err(e) + err(e) }) } } @@ -169,8 +170,8 @@ where /// ## Example /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{error, web, App, FromRequest, HttpResponse}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -178,7 +179,7 @@ where /// } /// /// /// deserialize `Info` from request's querystring -/// fn index(info: web::Query) -> String { +/// async fn index(info: web::Query) -> String { /// format!("Welcome {}!", info.username) /// } /// @@ -227,7 +228,7 @@ mod tests { use super::*; use crate::error::InternalError; - use crate::test::TestRequest; + use crate::test::{block_on, TestRequest}; use crate::HttpResponse; #[derive(Deserialize, Debug, Display)] @@ -253,42 +254,46 @@ mod tests { #[test] fn test_request_extract() { - let req = TestRequest::with_uri("/name/user1/").to_srv_request(); - let (req, mut pl) = req.into_parts(); - assert!(Query::::from_request(&req, &mut pl).is_err()); + block_on(async { + let req = TestRequest::with_uri("/name/user1/").to_srv_request(); + let (req, mut pl) = req.into_parts(); + assert!(Query::::from_request(&req, &mut pl).await.is_err()); - let req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); - let (req, mut pl) = req.into_parts(); + let req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); + let (req, mut pl) = req.into_parts(); - let mut s = Query::::from_request(&req, &mut pl).unwrap(); - assert_eq!(s.id, "test"); - assert_eq!(format!("{}, {:?}", s, s), "test, Id { id: \"test\" }"); + let mut s = Query::::from_request(&req, &mut pl).await.unwrap(); + assert_eq!(s.id, "test"); + assert_eq!(format!("{}, {:?}", s, s), "test, Id { id: \"test\" }"); - s.id = "test1".to_string(); - let s = s.into_inner(); - assert_eq!(s.id, "test1"); + s.id = "test1".to_string(); + let s = s.into_inner(); + assert_eq!(s.id, "test1"); + }) } #[test] fn test_custom_error_responder() { - let req = TestRequest::with_uri("/name/user1/") - .data(QueryConfig::default().error_handler(|e, _| { - let resp = HttpResponse::UnprocessableEntity().finish(); - InternalError::from_response(e, resp).into() - })) - .to_srv_request(); + block_on(async { + let req = TestRequest::with_uri("/name/user1/") + .data(QueryConfig::default().error_handler(|e, _| { + let resp = HttpResponse::UnprocessableEntity().finish(); + InternalError::from_response(e, resp).into() + })) + .to_srv_request(); - let (req, mut pl) = req.into_parts(); - let query = Query::::from_request(&req, &mut pl); + let (req, mut pl) = req.into_parts(); + let query = Query::::from_request(&req, &mut pl).await; - assert!(query.is_err()); - assert_eq!( - query - .unwrap_err() - .as_response_error() - .error_response() - .status(), - StatusCode::UNPROCESSABLE_ENTITY - ); + assert!(query.is_err()); + assert_eq!( + query + .unwrap_err() + .as_response_error() + .error_response() + .status(), + StatusCode::UNPROCESSABLE_ENTITY + ); + }) } } diff --git a/src/types/readlines.rs b/src/types/readlines.rs index cea63e43b..e2b3f9c1d 100644 --- a/src/types/readlines.rs +++ b/src/types/readlines.rs @@ -1,9 +1,13 @@ use std::borrow::Cow; +use std::future::Future; +use std::pin::Pin; use std::str; +use std::task::{Context, Poll}; use bytes::{Bytes, BytesMut}; use encoding_rs::{Encoding, UTF_8}; -use futures::{Async, Poll, Stream}; +use futures::Stream; +use pin_project::pin_project; use crate::dev::Payload; use crate::error::{PayloadError, ReadlinesError}; @@ -22,7 +26,7 @@ pub struct Readlines { impl Readlines where T: HttpMessage, - T::Stream: Stream, + T::Stream: Stream> + Unpin, { /// Create a new stream to read request line by line. pub fn new(req: &mut T) -> Self { @@ -62,20 +66,21 @@ where impl Stream for Readlines where T: HttpMessage, - T::Stream: Stream, + T::Stream: Stream> + Unpin, { - type Item = String; - type Error = ReadlinesError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - if let Some(err) = self.err.take() { - return Err(err); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + + if let Some(err) = this.err.take() { + return Poll::Ready(Some(Err(err))); } // check if there is a newline in the buffer - if !self.checked_buff { + if !this.checked_buff { let mut found: Option = None; - for (ind, b) in self.buff.iter().enumerate() { + for (ind, b) in this.buff.iter().enumerate() { if *b == b'\n' { found = Some(ind); break; @@ -83,28 +88,28 @@ where } if let Some(ind) = found { // check if line is longer than limit - if ind + 1 > self.limit { - return Err(ReadlinesError::LimitOverflow); + if ind + 1 > this.limit { + return Poll::Ready(Some(Err(ReadlinesError::LimitOverflow))); } - let line = if self.encoding == UTF_8 { - str::from_utf8(&self.buff.split_to(ind + 1)) + let line = if this.encoding == UTF_8 { + str::from_utf8(&this.buff.split_to(ind + 1)) .map_err(|_| ReadlinesError::EncodingError)? .to_owned() } else { - self.encoding + this.encoding .decode_without_bom_handling_and_without_replacement( - &self.buff.split_to(ind + 1), + &this.buff.split_to(ind + 1), ) .map(Cow::into_owned) .ok_or(ReadlinesError::EncodingError)? }; - return Ok(Async::Ready(Some(line))); + return Poll::Ready(Some(Ok(line))); } - self.checked_buff = true; + this.checked_buff = true; } // poll req for more bytes - match self.stream.poll() { - Ok(Async::Ready(Some(mut bytes))) => { + match Pin::new(&mut this.stream).poll_next(cx) { + Poll::Ready(Some(Ok(mut bytes))) => { // check if there is a newline in bytes let mut found: Option = None; for (ind, b) in bytes.iter().enumerate() { @@ -115,15 +120,15 @@ where } if let Some(ind) = found { // check if line is longer than limit - if ind + 1 > self.limit { - return Err(ReadlinesError::LimitOverflow); + if ind + 1 > this.limit { + return Poll::Ready(Some(Err(ReadlinesError::LimitOverflow))); } - let line = if self.encoding == UTF_8 { + let line = if this.encoding == UTF_8 { str::from_utf8(&bytes.split_to(ind + 1)) .map_err(|_| ReadlinesError::EncodingError)? .to_owned() } else { - self.encoding + this.encoding .decode_without_bom_handling_and_without_replacement( &bytes.split_to(ind + 1), ) @@ -131,83 +136,72 @@ where .ok_or(ReadlinesError::EncodingError)? }; // extend buffer with rest of the bytes; - self.buff.extend_from_slice(&bytes); - self.checked_buff = false; - return Ok(Async::Ready(Some(line))); + this.buff.extend_from_slice(&bytes); + this.checked_buff = false; + return Poll::Ready(Some(Ok(line))); } - self.buff.extend_from_slice(&bytes); - Ok(Async::NotReady) + this.buff.extend_from_slice(&bytes); + Poll::Pending } - Ok(Async::NotReady) => Ok(Async::NotReady), - Ok(Async::Ready(None)) => { - if self.buff.is_empty() { - return Ok(Async::Ready(None)); + Poll::Pending => Poll::Pending, + Poll::Ready(None) => { + if this.buff.is_empty() { + return Poll::Ready(None); } - if self.buff.len() > self.limit { - return Err(ReadlinesError::LimitOverflow); + if this.buff.len() > this.limit { + return Poll::Ready(Some(Err(ReadlinesError::LimitOverflow))); } - let line = if self.encoding == UTF_8 { - str::from_utf8(&self.buff) + let line = if this.encoding == UTF_8 { + str::from_utf8(&this.buff) .map_err(|_| ReadlinesError::EncodingError)? .to_owned() } else { - self.encoding - .decode_without_bom_handling_and_without_replacement(&self.buff) + this.encoding + .decode_without_bom_handling_and_without_replacement(&this.buff) .map(Cow::into_owned) .ok_or(ReadlinesError::EncodingError)? }; - self.buff.clear(); - Ok(Async::Ready(Some(line))) + this.buff.clear(); + Poll::Ready(Some(Ok(line))) } - Err(e) => Err(ReadlinesError::from(e)), + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(ReadlinesError::from(e)))), } } } #[cfg(test)] mod tests { + use futures::stream::StreamExt; + use super::*; use crate::test::{block_on, TestRequest}; #[test] fn test_readlines() { - let mut req = TestRequest::default() + block_on(async { + let mut req = TestRequest::default() .set_payload(Bytes::from_static( b"Lorem Ipsum is simply dummy text of the printing and typesetting\n\ industry. Lorem Ipsum has been the industry's standard dummy\n\ Contrary to popular belief, Lorem Ipsum is not simply random text.", )) .to_request(); - let stream = match block_on(Readlines::new(&mut req).into_future()) { - Ok((Some(s), stream)) => { - assert_eq!( - s, - "Lorem Ipsum is simply dummy text of the printing and typesetting\n" - ); - stream - } - _ => unreachable!("error"), - }; - let stream = match block_on(stream.into_future()) { - Ok((Some(s), stream)) => { - assert_eq!( - s, - "industry. Lorem Ipsum has been the industry's standard dummy\n" - ); - stream - } - _ => unreachable!("error"), - }; + let mut stream = Readlines::new(&mut req); + assert_eq!( + stream.next().await.unwrap().unwrap(), + "Lorem Ipsum is simply dummy text of the printing and typesetting\n" + ); - match block_on(stream.into_future()) { - Ok((Some(s), _)) => { - assert_eq!( - s, - "Contrary to popular belief, Lorem Ipsum is not simply random text." - ); - } - _ => unreachable!("error"), - } + assert_eq!( + stream.next().await.unwrap().unwrap(), + "industry. Lorem Ipsum has been the industry's standard dummy\n" + ); + + assert_eq!( + stream.next().await.unwrap().unwrap(), + "Contrary to popular belief, Lorem Ipsum is not simply random text." + ); + }) } } diff --git a/src/web.rs b/src/web.rs index 5669a1e86..22630ae81 100644 --- a/src/web.rs +++ b/src/web.rs @@ -1,13 +1,14 @@ //! Essentials helper functions and types for application registration. use actix_http::http::Method; -use futures::{Future, IntoFuture}; +use futures::Future; pub use actix_http::Response as HttpResponse; pub use bytes::{Bytes, BytesMut}; +pub use futures::channel::oneshot::Canceled; -use crate::error::{BlockingError, Error}; +use crate::error::Error; use crate::extract::FromRequest; -use crate::handler::{AsyncFactory, Factory}; +use crate::handler::Factory; use crate::resource::Resource; use crate::responder::Responder; use crate::route::Route; @@ -230,10 +231,10 @@ pub fn method(method: Method) -> Route { /// Create a new route and add handler. /// /// ```rust -/// use actix_web::{web, App, HttpResponse}; +/// use actix_web::{web, App, HttpResponse, Responder}; /// -/// fn index() -> HttpResponse { -/// unimplemented!() +/// async fn index() -> impl Responder { +/// HttpResponse::Ok() /// } /// /// App::new().service( @@ -241,48 +242,23 @@ pub fn method(method: Method) -> Route { /// web::to(index)) /// ); /// ``` -pub fn to(handler: F) -> Route +pub fn to(handler: F) -> Route where - F: Factory + 'static, + F: Factory, I: FromRequest + 'static, - R: Responder + 'static, + R: Future + 'static, + U: Responder + 'static, { Route::new().to(handler) } -/// Create a new route and add async handler. -/// -/// ```rust -/// # use futures::future::{ok, Future}; -/// use actix_web::{web, App, HttpResponse, Error}; -/// -/// fn index() -> impl Future { -/// ok(HttpResponse::Ok().finish()) -/// } -/// -/// App::new().service(web::resource("/").route( -/// web::to_async(index)) -/// ); -/// ``` -pub fn to_async(handler: F) -> Route -where - F: AsyncFactory, - I: FromRequest + 'static, - R: IntoFuture + 'static, - R::Item: Responder, - R::Error: Into, -{ - Route::new().to_async(handler) -} - /// Create raw service for a specific path. /// /// ```rust -/// # extern crate actix_web; -/// use actix_web::{dev, web, guard, App, HttpResponse}; +/// use actix_web::{dev, web, guard, App, Error, HttpResponse}; /// -/// fn my_service(req: dev::ServiceRequest) -> dev::ServiceResponse { -/// req.into_response(HttpResponse::Ok().finish()) +/// async fn my_service(req: dev::ServiceRequest) -> Result { +/// Ok(req.into_response(HttpResponse::Ok().finish())) /// } /// /// fn main() { @@ -299,11 +275,10 @@ pub fn service(path: &str) -> WebService { /// Execute blocking function on a thread pool, returns future that resolves /// to result of the function execution. -pub fn block(f: F) -> impl Future> +pub fn block(f: F) -> impl Future> where - F: FnOnce() -> Result + Send + 'static, - I: Send + 'static, - E: Send + std::fmt::Debug + 'static, + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, { - actix_threadpool::run(f).from_err() + actix_threadpool::run(f) } diff --git a/test-server/CHANGES.md b/test-server/CHANGES.md index 798dbf506..57068fe95 100644 --- a/test-server/CHANGES.md +++ b/test-server/CHANGES.md @@ -5,6 +5,7 @@ ### Changed * Update serde_urlencoded to "0.6.1" +* Increase TestServerRuntime timeouts from 500ms to 3000ms ### Fixed diff --git a/test-server/Cargo.toml b/test-server/Cargo.toml index 77301b0a9..a2b03ffc2 100644 --- a/test-server/Cargo.toml +++ b/test-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-http-test" -version = "0.2.5" +version = "0.3.0-alpha.1" authors = ["Nikolay Kim "] description = "Actix http test server" readme = "README.md" @@ -27,20 +27,22 @@ path = "src/lib.rs" default = [] # openssl -ssl = ["openssl", "actix-server/ssl", "awc/ssl"] +openssl = ["open-ssl", "actix-server/openssl", "awc/openssl"] [dependencies] -actix-codec = "0.1.2" -actix-rt = "0.2.2" -actix-service = "0.4.1" -actix-server = "0.6.0" -actix-utils = "0.4.1" -awc = "0.2.6" -actix-connect = "0.2.2" +actix-service = "1.0.0-alpha.1" +actix-codec = "0.2.0-alpha.1" +actix-connect = "1.0.0-alpha.1" +actix-utils = "0.5.0-alpha.1" +actix-rt = "1.0.0-alpha.1" +actix-server = "0.8.0-alpha.1" +actix-server-config = "0.3.0-alpha.1" +actix-testing = "0.3.0-alpha.1" +awc = "0.3.0-alpha.1" base64 = "0.10" bytes = "0.4" -futures = "0.1" +futures = "0.3.1" http = "0.1.8" log = "0.4" env_logger = "0.6" @@ -51,10 +53,11 @@ sha1 = "0.6" slab = "0.4" serde_urlencoded = "0.6.1" time = "0.1" -tokio-tcp = "0.1" -tokio-timer = "0.2" -openssl = { version="0.10", optional = true } +tokio-net = "0.2.0-alpha.6" +tokio-timer = "0.3.0-alpha.6" + +open-ssl = { version="0.10", package="openssl", optional = true } [dev-dependencies] -actix-web = "1.0.7" -actix-http = "0.2.9" +actix-web = "2.0.0-alpha.1" +actix-http = "0.3.0-alpha.1" diff --git a/test-server/src/lib.rs b/test-server/src/lib.rs index a2366bf48..1911c75d6 100644 --- a/test-server/src/lib.rs +++ b/test-server/src/lib.rs @@ -1,73 +1,18 @@ //! Various helpers for Actix applications to use during testing. -use std::cell::RefCell; use std::sync::mpsc; use std::{net, thread, time}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use actix_rt::{Runtime, System}; -use actix_server::{Server, StreamServiceFactory}; +use actix_rt::System; +use actix_server::{Server, ServiceFactory}; use awc::{error::PayloadError, ws, Client, ClientRequest, ClientResponse, Connector}; use bytes::Bytes; -use futures::future::lazy; -use futures::{Future, IntoFuture, Stream}; +use futures::Stream; use http::Method; use net2::TcpBuilder; -use tokio_tcp::TcpStream; +use tokio_net::tcp::TcpStream; -thread_local! { - static RT: RefCell = { - RefCell::new(Inner(Some(Runtime::new().unwrap()))) - }; -} - -struct Inner(Option); - -impl Inner { - fn get_mut(&mut self) -> &mut Runtime { - self.0.as_mut().unwrap() - } -} - -impl Drop for Inner { - fn drop(&mut self) { - std::mem::forget(self.0.take().unwrap()) - } -} - -/// Runs the provided future, blocking the current thread until the future -/// completes. -/// -/// This function can be used to synchronously block the current thread -/// until the provided `future` has resolved either successfully or with an -/// error. The result of the future is then returned from this function -/// call. -/// -/// Note that this function is intended to be used only for testing purpose. -/// This function panics on nested call. -pub fn block_on(f: F) -> Result -where - F: IntoFuture, -{ - RT.with(move |rt| rt.borrow_mut().get_mut().block_on(f.into_future())) -} - -/// Runs the provided function, blocking the current thread until the resul -/// future completes. -/// -/// This function can be used to synchronously block the current thread -/// until the provided `future` has resolved either successfully or with an -/// error. The result of the future is then returned from this function -/// call. -/// -/// Note that this function is intended to be used only for testing purpose. -/// This function panics on nested call. -pub fn block_fn(f: F) -> Result -where - F: FnOnce() -> R, - R: IntoFuture, -{ - RT.with(move |rt| rt.borrow_mut().get_mut().block_on(lazy(f))) -} +pub use actix_testing::*; /// The `TestServer` type. /// @@ -78,24 +23,26 @@ where /// /// ```rust /// use actix_http::HttpService; -/// use actix_http_test::TestServer; -/// use actix_web::{web, App, HttpResponse}; +/// use actix_http_test::{block_on, TestServer}; +/// use actix_web::{web, App, HttpResponse, Error}; /// -/// fn my_handler() -> HttpResponse { -/// HttpResponse::Ok().into() +/// async fn my_handler() -> Result { +/// Ok(HttpResponse::Ok().into()) /// } /// /// fn main() { -/// let mut srv = TestServer::new( -/// || HttpService::new( -/// App::new().service( -/// web::resource("/").to(my_handler)) -/// ) -/// ); +/// block_on( async { +/// let mut srv = TestServer::start( +/// || HttpService::new( +/// App::new().service( +/// web::resource("/").to(my_handler)) +/// ) +/// ); /// -/// let req = srv.get("/"); -/// let response = srv.block_on(req.send()).unwrap(); -/// assert!(response.status().is_success()); +/// let req = srv.get("/"); +/// let response = req.send().await.unwrap(); +/// assert!(response.status().is_success()); +/// }) /// } /// ``` pub struct TestServer; @@ -110,7 +57,7 @@ pub struct TestServerRuntime { impl TestServer { #[allow(clippy::new_ret_no_self)] /// Start new test server with application factory - pub fn new>(factory: F) -> TestServerRuntime { + pub fn start>(factory: F) -> TestServerRuntime { let (tx, rx) = mpsc::channel(); // run server in separate thread @@ -131,11 +78,11 @@ impl TestServer { let (system, addr) = rx.recv().unwrap(); - let client = block_on(lazy(move || { + let client = { let connector = { - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] { - use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); builder.set_verify(SslVerifyMode::NONE); @@ -144,27 +91,22 @@ impl TestServer { .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); Connector::new() .conn_lifetime(time::Duration::from_secs(0)) - .timeout(time::Duration::from_millis(500)) + .timeout(time::Duration::from_millis(3000)) .ssl(builder.build()) .finish() } - #[cfg(not(feature = "ssl"))] + #[cfg(not(feature = "openssl"))] { Connector::new() .conn_lifetime(time::Duration::from_secs(0)) - .timeout(time::Duration::from_millis(500)) + .timeout(time::Duration::from_millis(3000)) .finish() } }; - Ok::(Client::build().connector(connector).finish()) - })) - .unwrap(); - - block_on(lazy( - || Ok::<_, ()>(actix_connect::start_default_resolver()), - )) - .unwrap(); + Client::build().connector(connector).finish() + }; + actix_connect::start_default_resolver(); TestServerRuntime { addr, @@ -185,31 +127,6 @@ impl TestServer { } impl TestServerRuntime { - /// Execute future on current core - pub fn block_on(&mut self, fut: F) -> Result - where - F: Future, - { - block_on(fut) - } - - /// Execute future on current core - pub fn block_on_fn(&mut self, f: F) -> Result - where - F: FnOnce() -> R, - R: Future, - { - block_on(lazy(f)) - } - - /// Execute function on current core - pub fn execute(&mut self, fut: F) -> R - where - F: FnOnce() -> R, - { - block_on(lazy(|| Ok::<_, ()>(fut()))).unwrap() - } - /// Construct test server url pub fn addr(&self) -> net::SocketAddr { self.addr @@ -308,33 +225,33 @@ impl TestServerRuntime { self.client.request(method, path.as_ref()) } - pub fn load_body( + pub async fn load_body( &mut self, mut response: ClientResponse, ) -> Result where - S: Stream + 'static, + S: Stream> + Unpin + 'static, { - self.block_on(response.body().limit(10_485_760)) + response.body().limit(10_485_760).await } /// Connect to websocket server at a given path - pub fn ws_at( + pub async fn ws_at( &mut self, path: &str, ) -> Result, awc::error::WsClientError> { let url = self.url(path); let connect = self.client.ws(url).connect(); - block_on(lazy(move || connect.map(|(_, framed)| framed))) + connect.await.map(|(_, framed)| framed) } /// Connect to a websocket server - pub fn ws( + pub async fn ws( &mut self, ) -> Result, awc::error::WsClientError> { - self.ws_at("/") + self.ws_at("/").await } /// Stop http server diff --git a/tests/cert.pem b/tests/cert.pem index f9bb05081..0eeb6721d 100644 --- a/tests/cert.pem +++ b/tests/cert.pem @@ -1,32 +1,19 @@ -----BEGIN CERTIFICATE----- -MIIFfjCCA2agAwIBAgIJAOIBvp/w68KrMA0GCSqGSIb3DQEBCwUAMGsxCzAJBgNV -BAYTAlJVMRkwFwYDVQQIDBBTYWludC1QZXRlcnNidXJnMRkwFwYDVQQHDBBTYWlu -dC1QZXRlcnNidXJnMRIwEAYDVQQKDAlLdXBpYmlsZXQxEjAQBgNVBAMMCWxvY2Fs -aG9zdDAgFw0xOTA3MjcxODIzMTJaGA8zMDE5MDcyNzE4MjMxMlowazELMAkGA1UE -BhMCUlUxGTAXBgNVBAgMEFNhaW50LVBldGVyc2J1cmcxGTAXBgNVBAcMEFNhaW50 -LVBldGVyc2J1cmcxEjAQBgNVBAoMCUt1cGliaWxldDESMBAGA1UEAwwJbG9jYWxo -b3N0MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAuiQZzTO3gRRPr6ZH -wcmKqkoXig9taCCqx72Qvb9tvCLhQLE1dDPZV8I/r8bx+mM4Yz3r0Hm5LxTIhCM9 -p3/abuiJAZENC/VkxgFzBGg7KGLSFmzU+A8Ft+2mrKmj5MpIPBCxDeVg80TCQOJy -hj+NU3PpBo9nxTgxWNWO6X+ZovZohdp78fYLLtns8rxjug3FVzdPrrLnBvihkGlq -gfImkh+vZxMTj1OgtxyCOhdbO4Ol4jCbn7a5yIw+iixHOEgBQfTQopRP7z1PEUV2 -WIy2VEGzvQDlj2OyzH86T1IOFV5rz5MjdZuW0qNzeS0w3Jzgp/olSbIZLhGAaIk0 -gN7y9XvSHqs7rO0wW+467ico7+uP1ScGgPgJA5fGu7ahp7F7G3ZSoAqAcS60wYsX -kStoA3RWAuqste6aChv1tlgTt+Rhk8qjGhuh0ng2qVnTGyo2T3OCHB/c47Bcsp6L -xiyTCnQIPH3fh2iO/SC7gPw3jihPMCAQZYlkC3MhMk974rn2zs9cKtJ8ubnG2m5F -VFVYmduRqS/YQS/O802jVCFdc8KDmoeAYNuHzgRZjQv9018UUeW3jtWKnopJnWs5 -ae9pbtmYeOtc7OovOxT7J2AaVfUkHRhmlqWZMzEJTcZws0fRPTZDifFJ5LFWbZsC -zW4tCKBKvYM9eAnbb+abiHXlY1MCAwEAAaMjMCEwHwYDVR0RBBgwFoIJbG9jYWxo -b3N0ggkxMjcuMC4wLjEwDQYJKoZIhvcNAQELBQADggIBAC1EU4SVCfUKc7JbhYRf -P87F+8e13bBTLxevJdnTCH3Xw2AN8UPmwQ2vv9Mv2FMulMBQ7yLnQLGtgGUua2OE -XO+EdBBEKnglo9cwXGzU6qHhaiCeXZDM8s53qOOrD42XsDsY0nOoFYqDLW3WixP9 -f1fWbcEf6+ktlvqi/1/3R6QtQR+6LS43enbsYHq8aAP60NrpXxdXxEoUwW6Z/sje -XAQluH8jzledwJcY8bXRskAHZlE4kGlOVuGgnyI3BXyLiwB4g9smFzYIs98iAGmV -7ZBaR5IIiRCtoKBG+SngM7Log0bHphvFPjDDvgqWYiWaOHboYM60Y2Z/gRbcjuMU -WZX64jw29fa8UPFdtGTupt+iuO7iXnHnm0lBBK36rVdOvsZup76p6L4BXmFsRmFK -qJ2Zd8uWNPDq80Am0mYaAqENuIANHHJXX38SesC+QO+G2JZt6vCwkGk/Qune4GIg -1GwhvsDRfTQopSxg1rdPwPM7HWeTfUGHZ34B5p/iILA3o6PfYQU8fNAWIsCDkRX2 -MrgDgCnLZxKb6pjR4DYNAdPwkxyMFACZ2T46z6WvLWFlnkK5nbZoqsOsp+GJHole -llafhrelXEzt3zFR0q4zGcqheJDI+Wy+fBy3XawgAc4eN0T2UCzL/jKxKgzlzSU3 -+xh1SDNjFLRd6sGzZHPMgXN0 +MIIDEDCCAfgCCQCQdmIZc/Ib/jANBgkqhkiG9w0BAQsFADBKMQswCQYDVQQGEwJ1 +czELMAkGA1UECAwCY2ExCzAJBgNVBAcMAnNmMSEwHwYJKoZIhvcNAQkBFhJmYWZo +cmQ5MUBnbWFpbC5jb20wHhcNMTkxMTE5MTEwNjU1WhcNMjkxMTE2MTEwNjU1WjBK +MQswCQYDVQQGEwJ1czELMAkGA1UECAwCY2ExCzAJBgNVBAcMAnNmMSEwHwYJKoZI +hvcNAQkBFhJmYWZocmQ5MUBnbWFpbC5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IB +DwAwggEKAoIBAQDcnaz12CKzUL7248V7Axhms/O9UQXfAdw0yolEfC3P5jADa/1C ++kLWKjAc2coqDSbGsrsR6KiH2g06Kunx+tSGqUO+Sct7HEehmxndiSwx/hfMWezy +XRe/olcHFTeCk/Tllz4xGEplhPua6GLhJygLOhAMiV8cwCYrgyPqsDduExLDFCqc +K2xntIPreumXpiE3QY4+MWyteiJko4IWDFf/UwwsdCY5MlFfw1F/Uv9vz7FfOfvu +GccHd/ex8cOwotUqd6emZb+0bVE24Sv8U+yLnHIVx/tOkxgMAnJEpAnf2G3Wp3zU +b2GJosbmfGaf+xTfnGGhTLLL7kCtva+NvZr5AgMBAAEwDQYJKoZIhvcNAQELBQAD +ggEBANftoL8zDGrjCwWvct8kOOqset2ukK8vjIGwfm88CKsy0IfSochNz2qeIu9R +ZuO7c0pfjmRkir9ZQdq9vXgG3ccL9UstFsferPH9W3YJ83kgXg3fa0EmCiN/0hwz +6Ij1ZBiN1j3+d6+PJPgyYFNu2nGwox5mJ9+aRAGe0/9c63PEOY8P2TI4HsiPmYSl +fFR8k/03vr6e+rTKW85BgctjvYKe/TnFxeCQ7dZ+na7vlEtch4tNmy6O/vEk2kCt +5jW0DUxhmRsv2wGmfFRI0+LotHjoXQQZi6nN5aGL3odaGF3gYwIVlZNd3AdkwDQz +BzG0ZwXuDDV9bSs3MfWEWcy4xuU= -----END CERTIFICATE----- diff --git a/tests/key.pem b/tests/key.pem index 70153c8ae..a6d308168 100644 --- a/tests/key.pem +++ b/tests/key.pem @@ -1,52 +1,28 @@ -----BEGIN PRIVATE KEY----- -MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQC6JBnNM7eBFE+v -pkfByYqqSheKD21oIKrHvZC9v228IuFAsTV0M9lXwj+vxvH6YzhjPevQebkvFMiE -Iz2nf9pu6IkBkQ0L9WTGAXMEaDsoYtIWbNT4DwW37aasqaPkykg8ELEN5WDzRMJA -4nKGP41Tc+kGj2fFODFY1Y7pf5mi9miF2nvx9gsu2ezyvGO6DcVXN0+usucG+KGQ -aWqB8iaSH69nExOPU6C3HII6F1s7g6XiMJuftrnIjD6KLEc4SAFB9NCilE/vPU8R -RXZYjLZUQbO9AOWPY7LMfzpPUg4VXmvPkyN1m5bSo3N5LTDcnOCn+iVJshkuEYBo -iTSA3vL1e9Ieqzus7TBb7jruJyjv64/VJwaA+AkDl8a7tqGnsXsbdlKgCoBxLrTB -ixeRK2gDdFYC6qy17poKG/W2WBO35GGTyqMaG6HSeDapWdMbKjZPc4IcH9zjsFyy -novGLJMKdAg8fd+HaI79ILuA/DeOKE8wIBBliWQLcyEyT3viufbOz1wq0ny5ucba -bkVUVViZ25GpL9hBL87zTaNUIV1zwoOah4Bg24fOBFmNC/3TXxRR5beO1Yqeikmd -azlp72lu2Zh461zs6i87FPsnYBpV9SQdGGaWpZkzMQlNxnCzR9E9NkOJ8UnksVZt -mwLNbi0IoEq9gz14Cdtv5puIdeVjUwIDAQABAoICAQCZVVezw+BsAjFKPi1qIv2J -HZOadO7pEc/czflHdUON8SWgxtmDqZpmQmt3/ugiHE284qs4hqzXbcVnpCgLrLRh -HEiP887NhQ3IVjVK8hmZQR5SvsAIv0c0ph3gqbWKqF8sq4tOKR/eBUwHawJwODXR -AvB4KPWQbqOny/P3wNbseRLNAJeNT+MSaw5XPnzgLKvdFoEbJeBNy847Sbsk5DaF -tHgm7n30WS1Q6bkU5VyP//hMBUKNJFaSL4TtCWB5qkbu8B5VbtsR9m0FizTb6L3h -VmYbUXvIzJXjAwMjiDJ1w9wHl+tj3BE33tEmhuVzNf+SH+tLc9xuKJighDWt2vpD -eTpZ1qest26ANLOmNXWVCVTGpcWvOu5yhG/P7En10EzjFruMfHAFdwLm1gMx1rlR -9fyNAk/0ROJ+5BUtuWgDiyytS5f2T9KGiOHni7UbBIkv0CV2H6VL39Twxf+3OHnx -JJ7OWZ8DRuLM/EJfN3C1+3eDsXOvcdvbo2TFBmCCl4Pa2pm4k3g2NBfxy/zSYWIh -ccGPZorFKNMUi29U0Ool6fxeVflbll570pWVBLAB31HdkLSESv9h+2j/IiEJcJXj -nzl2RtYB0Uxzk6SjO0z4WXjz/SXg5tQQkm/dx8kM8TvHICFq68AEnw8t9Hagsdxs -v5jNyOEeI1I5gPgZmKuboQKCAQEA7Hw6s8Xc3UtNaywMcK1Eb1O+kwRdztgtm0uE -uqsHWmGqbBxXN4jiVLh3dILIXFuVbvDSsSZtPLhOj1wqxgsTHg93b4BtvowyNBgo -X4tErMu7/6NRael/hfOEdtpfv2gV+0eQa+8KKqYJPbqpMz/r5L/3RaxS3iXkj3QM -6oC4+cRuwy/flPMIpxhDriH5yjfiMOdDsi3ZfMTJu/58DTrKV7WkJxQZmha4EoZn -IiXeRhzo+2ukMDWrr3GGPyDfjd/NB7rmY8QBdmhB5NSk+6B66JCNTIbKka/pichS -36bwSYFNji4NaHUUlYDUjfKoTNuQMEZknMGhc/433ADO7s17iQKCAQEAyYBYVG7/ -LE2IkvQy9Nwly5tRCNlSvSkloz7PUwRbzG5uF5mweWEa8YECJe9/vrFXvyBW+NR8 -XABFn4eG0POTR9nyb4n2nUlqiGugDIPgkrKCkJws5InifITZ/+Viocd4YZL5UwCU -R1/kMf0UjK2iJjWEeTPS6RmwRI2Iu7kym9BzphDyNYBQSbUE/f+4hNP6nUT/h09c -VH4/sUhubSgVKeK4onOci8bKitAkwVBYCYSyhuBCeCu8fTk2hVRWviRaJPVq2PMB -LHw1FCcfJLIPJG6MZpFAPkMQxpiewdggXIgi46ZlZcsNXEJ81ocT4GU2j+ArQXCf -lgEycyD3mx4k+wKCAQBGneohmKoVYtEheavVUcgnvkggOqOQirlDsE9YNo4hjRyI -4AWjTbrYNaVmI0+VVLvQvxULVUA1a4v5/zm+nbv9s/ykTSN4TQEI0VXtAfdl6gif -k7NR/ynXZBpgK2GAFKLLwFj+Agl1JtOHnV+9MA9O5Yv/QDAWqhYQSEU7GWkjHGc+ -3eLT5ablzrcXHooqunlOxSBP6qURPupGuv1sLewSOOllyfjDLJmW3o+ZgNlY8nUX -7tK+mqhD4ZCG9VgMU5I0BrmZfQQ6yXMz09PYV9mb7N5kxbNjwbXpMOqeYolKSdRQ -6quST7Pv2OKf6KAdI0txPvP4Y1HFA1rG1W71nGKRAoIBAHlDU+T8J3Rx9I77hu70 -zYoKnmnE35YW/R+Q3RQIu3X7vyVUyG9DkQNlr/VEfIw2Dahnve9hcLWtNDkdRnTZ -IPlMoCmfzVo6pHIU0uy1MKEX7Js6YYnnsPVevhLR6NmTQU73NDRPVOzfOGUc+RDw -LXTxIBgQqAy/+ORIiNDwUxSSDgcSi7DG14qD9c0l59WH/HpI276Cc/4lPA9kl4/5 -X0MlvheFm+BCcgG34Wa1A0Y3JXkl3NqU94oktDro1or3NYioaPTGyR4MYaUPJh7f -SV2TacsP/ql5ks7xahkeB9un0ddOfBcWa6PqH1a7U6rnPj63mVB4hpGvhrziSiB/ -s6ECggEAOp2P4Yd9Vm9/CptxS50HFF4adyLscAtsDd3S2hIAXhDovcPbvRek4pLQ -idPhHlRAfqrEztnhaVAmCK9HlhgthtiQGQX62YI4CS4QL2IhzDFo3M1a2snjFEdl -QuFk3XI7kQ0Yp8BLLG7T436JUrUkCXc4gQX2uRNut+ff34RIR2CjcQQjChxuHVeG -sP/3xFFj8OSs7ZoSPbmDBLrMOl64YHwezQUNAZiRYiaGbFiY0QUV6dHq8qX/qE1h -a/0Rq+gTqObDST0TqhMzI8V/i7R8SwVcD5ODHaZp5I2N2P/hV5OWY7ghQXhh89WM -o21xtGh0nP2Fq1TC6jFO+9cpbK8jNA== +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDcnaz12CKzUL72 +48V7Axhms/O9UQXfAdw0yolEfC3P5jADa/1C+kLWKjAc2coqDSbGsrsR6KiH2g06 +Kunx+tSGqUO+Sct7HEehmxndiSwx/hfMWezyXRe/olcHFTeCk/Tllz4xGEplhPua +6GLhJygLOhAMiV8cwCYrgyPqsDduExLDFCqcK2xntIPreumXpiE3QY4+MWyteiJk +o4IWDFf/UwwsdCY5MlFfw1F/Uv9vz7FfOfvuGccHd/ex8cOwotUqd6emZb+0bVE2 +4Sv8U+yLnHIVx/tOkxgMAnJEpAnf2G3Wp3zUb2GJosbmfGaf+xTfnGGhTLLL7kCt +va+NvZr5AgMBAAECggEBAKoU0UwzVgVCQgca8Jt2dnBvWYDhnxIfYAI/BvaKedMm +1ms87OKfB7oOiksjyI0E2JklH72dzZf2jm4CuZt5UjGC+xwPzlTaJ4s6hQVbBHyC +NRyxU1BCXtW5tThbrhD4OjxqjmLRJEIB9OunLtwAEQoeuFLB8Va7+HFhR+Zd9k3f +7aVA93pC5A50NRbZlke4miJ3Q8n7ZF0+UmxkBfm3fbqLk7aMWkoEKwLLTadjRlu1 +bBp0YDStX66I/p1kujqBOdh6VpPvxFOa1sV9pq0jeiGc9YfSkzRSKzIn8GoyviFB +fHeszQdNlcnrSDSNnMABAw+ZpxUO7SCaftjwejEmKZUCgYEA+TY43VpmV95eY7eo +WKwGepiHE0fwQLuKGELmZdZI80tFi73oZMuiB5WzwmkaKGcJmm7KGE9KEvHQCo9j +xvmktBR0VEZH8pmVfun+4h6+0H7m/NKMBBeOyv/IK8jBgHjkkB6e6nmeR7CqTxCw +tf9tbajl1QN8gNzXZSjBDT/lanMCgYEA4qANOKOSiEARtgwyXQeeSJcM2uPv6zF3 +ffM7vjSedtuEOHUSVeyBP/W8KDt7zyPppO/WNbURHS+HV0maS9yyj6zpVS2HGmbs +3fetswsQ+zYVdokW89x4oc2z4XOGHd1LcSlyhRwPt0u2g1E9L0irwTQLWU0npFmG +PRf7sN9+LeMCgYAGkDUDL2ROoB6gRa/7Vdx90hKMoXJkYgwLA4gJ2pDlR3A3c/Lw +5KQJyxmG3zm/IqeQF6be6QesZA30mT4peV2rGHbP2WH/s6fKReNelSy1VQJEWk8x +tGUgV4gwDwN5nLV4TjYlOrq+bJqvpmLhCC8bmj0jVQosYqSRl3cuICasnQKBgGlV +VO/Xb1su1EyWPK5qxRIeSxZOTYw2sMB01nbgxCqge0M2fvA6/hQ5ZlwY0cIEgits +YlcSMsMq/TAAANxz1vbaupUhlSMbZcsBvNV0Nk9c4vr2Wxm7hsJF9u66IEMvQUp2 +pkjiMxfR9CHzF4orr9EcHI5EQ0Grbq5kwFKEfoRbAoGAcWoFPILeJOlp2yW/Ds3E +g2fQdI9BAamtEZEaslJmZMmsDTg5ACPcDkOSFEQIaJ7wLPXeZy74FVk/NrY5F8Gz +bjX9OD/xzwp852yW5L9r62vYJakAlXef5jI6CFdYKDDCcarU0S7W5k6kq9n+wrBR +i1NklYmUAMr2q59uJA5zsic= -----END PRIVATE KEY----- diff --git a/tests/test_httpserver.rs b/tests/test_httpserver.rs index c0d2e81c4..122f79baf 100644 --- a/tests/test_httpserver.rs +++ b/tests/test_httpserver.rs @@ -2,8 +2,8 @@ use net2::TcpBuilder; use std::sync::mpsc; use std::{net, thread, time::Duration}; -#[cfg(feature = "ssl")] -use openssl::ssl::SslAcceptorBuilder; +#[cfg(feature = "openssl")] +use open_ssl::ssl::SslAcceptorBuilder; use actix_http::Response; use actix_web::{test, web, App, HttpServer}; @@ -55,22 +55,19 @@ fn test_start() { use actix_http::client; use actix_web::test; - let client = test::run_on(|| { - Ok::<_, ()>( - awc::Client::build() - .connector( - client::Connector::new() - .timeout(Duration::from_millis(100)) - .finish(), - ) - .finish(), - ) - }) - .unwrap(); - let host = format!("http://{}", addr); + test::block_on(async { + let client = awc::Client::build() + .connector( + client::Connector::new() + .timeout(Duration::from_millis(100)) + .finish(), + ) + .finish(); - let response = test::block_on(client.get(host.clone()).send()).unwrap(); - assert!(response.status().is_success()); + let host = format!("http://{}", addr); + let response = client.get(host.clone()).send().await.unwrap(); + assert!(response.status().is_success()); + }); } // stop @@ -80,9 +77,9 @@ fn test_start() { let _ = sys.stop(); } -#[cfg(feature = "ssl")] +#[cfg(feature = "openssl")] fn ssl_acceptor() -> std::io::Result { - use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; + use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; // load ssl keys let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); builder @@ -95,7 +92,7 @@ fn ssl_acceptor() -> std::io::Result { } #[test] -#[cfg(feature = "ssl")] +#[cfg(feature = "openssl")] fn test_start_ssl() { let addr = unused_addr(); let (tx, rx) = mpsc::channel(); @@ -113,7 +110,7 @@ fn test_start_ssl() { .shutdown_timeout(1) .system_exit() .disable_signals() - .bind_ssl(format!("{}", addr), builder) + .bind_openssl(format!("{}", addr), builder) .unwrap() .start(); @@ -122,30 +119,27 @@ fn test_start_ssl() { }); let (srv, sys) = rx.recv().unwrap(); - let client = test::run_on(|| { - use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + test::block_on(async move { + use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); builder.set_verify(SslVerifyMode::NONE); let _ = builder .set_alpn_protos(b"\x02h2\x08http/1.1") .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); - Ok::<_, ()>( - awc::Client::build() - .connector( - awc::Connector::new() - .ssl(builder.build()) - .timeout(Duration::from_millis(100)) - .finish(), - ) - .finish(), - ) - }) - .unwrap(); - let host = format!("https://{}", addr); + let client = awc::Client::build() + .connector( + awc::Connector::new() + .ssl(builder.build()) + .timeout(Duration::from_millis(100)) + .finish(), + ) + .finish(); - let response = test::block_on(client.get(host.clone()).send()).unwrap(); - assert!(response.status().is_success()); + let host = format!("https://{}", addr); + let response = client.get(host.clone()).send().await.unwrap(); + assert!(response.status().is_success()); + }); // stop let _ = srv.stop(false); diff --git a/tests/test_server.rs b/tests/test_server.rs index 1623d2ef3..0114b21ff 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -1,24 +1,21 @@ use std::io::{Read, Write}; -use std::sync::mpsc; -use std::thread; use actix_http::http::header::{ ContentEncoding, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING, }; use actix_http::{h1, Error, HttpService, Response}; -use actix_http_test::TestServer; +use actix_http_test::{block_on, TestServer}; use brotli2::write::{BrotliDecoder, BrotliEncoder}; use bytes::Bytes; use flate2::read::GzDecoder; use flate2::write::{GzEncoder, ZlibDecoder, ZlibEncoder}; use flate2::Compression; -use futures::stream::once; +use futures::{future::ok, stream::once}; use rand::{distributions::Alphanumeric, Rng}; -use actix_connect::start_default_resolver; use actix_web::middleware::{BodyEncoding, Compress}; -use actix_web::{dev, http, test, web, App, HttpResponse, HttpServer}; +use actix_web::{dev, http, web, App, HttpResponse, HttpServer}; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ @@ -44,679 +41,721 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ #[test] fn test_body() { - let mut srv = TestServer::new(|| { - h1::H1Service::new( - App::new() - .service(web::resource("/").route(web::to(|| Response::Ok().body(STR)))), - ) - }); + block_on(async { + let srv = + TestServer::start(|| { + h1::H1Service::new(App::new().service( + web::resource("/").route(web::to(|| Response::Ok().body(STR))), + )) + }); - let mut response = srv.block_on(srv.get("/").send()).unwrap(); - assert!(response.status().is_success()); + let mut response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] #[test] fn test_body_gzip() { - let mut srv = TestServer::new(|| { - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service(web::resource("/").route(web::to(|| Response::Ok().body(STR)))), - ) - }); + block_on(async { + let srv = TestServer::start(|| { + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service( + web::resource("/").route(web::to(|| Response::Ok().body(STR))), + ), + ) + }); - let mut response = srv - .block_on( - srv.get("/") - .no_decompress() - .header(ACCEPT_ENCODING, "gzip") - .send(), - ) - .unwrap(); - assert!(response.status().is_success()); + let mut response = srv + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await + .unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); + // read response + let bytes = response.body().await.unwrap(); - // decode - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + // decode + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + }) } #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] #[test] fn test_body_gzip2() { - let mut srv = TestServer::new(|| { - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service(web::resource("/").route(web::to(|| { - Response::Ok().body(STR).into_body::() - }))), - ) - }); + block_on(async { + let srv = TestServer::start(|| { + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service(web::resource("/").route(web::to(|| { + Response::Ok().body(STR).into_body::() + }))), + ) + }); - let mut response = srv - .block_on( - srv.get("/") - .no_decompress() - .header(ACCEPT_ENCODING, "gzip") - .send(), - ) - .unwrap(); - assert!(response.status().is_success()); + let mut response = srv + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await + .unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); + // read response + let bytes = response.body().await.unwrap(); - // decode - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + // decode + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + }) } #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] #[test] fn test_body_encoding_override() { - let mut srv = TestServer::new(|| { - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service(web::resource("/").route(web::to(|| { - Response::Ok().encoding(ContentEncoding::Deflate).body(STR) - }))) - .service(web::resource("/raw").route(web::to(|| { - let body = actix_web::dev::Body::Bytes(STR.into()); - let mut response = - Response::with_body(actix_web::http::StatusCode::OK, body); + block_on(async { + let srv = TestServer::start(|| { + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service(web::resource("/").route(web::to(|| { + Response::Ok().encoding(ContentEncoding::Deflate).body(STR) + }))) + .service(web::resource("/raw").route(web::to(|| { + let body = actix_web::dev::Body::Bytes(STR.into()); + let mut response = + Response::with_body(actix_web::http::StatusCode::OK, body); - response.encoding(ContentEncoding::Deflate); + response.encoding(ContentEncoding::Deflate); - response - }))), - ) - }); + response + }))), + ) + }); - // Builder - let mut response = srv - .block_on( - srv.get("/") - .no_decompress() - .header(ACCEPT_ENCODING, "deflate") - .send(), - ) - .unwrap(); - assert!(response.status().is_success()); + // Builder + let mut response = srv + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "deflate") + .send() + .await + .unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); + // read response + let bytes = response.body().await.unwrap(); - // decode - let mut e = ZlibDecoder::new(Vec::new()); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + // decode + let mut e = ZlibDecoder::new(Vec::new()); + e.write_all(bytes.as_ref()).unwrap(); + let dec = e.finish().unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); - // Raw Response - let mut response = srv - .block_on( - srv.request(actix_web::http::Method::GET, srv.url("/raw")) - .no_decompress() - .header(ACCEPT_ENCODING, "deflate") - .send(), - ) - .unwrap(); - assert!(response.status().is_success()); + // Raw Response + let mut response = srv + .request(actix_web::http::Method::GET, srv.url("/raw")) + .no_decompress() + .header(ACCEPT_ENCODING, "deflate") + .send() + .await + .unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); + // read response + let bytes = response.body().await.unwrap(); - // decode - let mut e = ZlibDecoder::new(Vec::new()); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + // decode + let mut e = ZlibDecoder::new(Vec::new()); + e.write_all(bytes.as_ref()).unwrap(); + let dec = e.finish().unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + }) } #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] #[test] fn test_body_gzip_large() { - let data = STR.repeat(10); - let srv_data = data.clone(); + block_on(async { + let data = STR.repeat(10); + let srv_data = data.clone(); - let mut srv = TestServer::new(move || { - let data = srv_data.clone(); - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service( - web::resource("/") - .route(web::to(move || Response::Ok().body(data.clone()))), - ), - ) - }); + let srv = TestServer::start(move || { + let data = srv_data.clone(); + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service( + web::resource("/") + .route(web::to(move || Response::Ok().body(data.clone()))), + ), + ) + }); - let mut response = srv - .block_on( - srv.get("/") - .no_decompress() - .header(ACCEPT_ENCODING, "gzip") - .send(), - ) - .unwrap(); - assert!(response.status().is_success()); + let mut response = srv + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await + .unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); + // read response + let bytes = response.body().await.unwrap(); - // decode - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from(data)); + // decode + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from(data)); + }) } #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] #[test] fn test_body_gzip_large_random() { - let data = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(70_000) - .collect::(); - let srv_data = data.clone(); + block_on(async { + let data = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(70_000) + .collect::(); + let srv_data = data.clone(); - let mut srv = TestServer::new(move || { - let data = srv_data.clone(); - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service( - web::resource("/") - .route(web::to(move || Response::Ok().body(data.clone()))), - ), - ) - }); + let srv = TestServer::start(move || { + let data = srv_data.clone(); + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service( + web::resource("/") + .route(web::to(move || Response::Ok().body(data.clone()))), + ), + ) + }); - let mut response = srv - .block_on( - srv.get("/") - .no_decompress() - .header(ACCEPT_ENCODING, "gzip") - .send(), - ) - .unwrap(); - assert!(response.status().is_success()); + let mut response = srv + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await + .unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); + // read response + let bytes = response.body().await.unwrap(); - // decode - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(dec.len(), data.len()); - assert_eq!(Bytes::from(dec), Bytes::from(data)); + // decode + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(dec.len(), data.len()); + assert_eq!(Bytes::from(dec), Bytes::from(data)); + }) } #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] #[test] fn test_body_chunked_implicit() { - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service(web::resource("/").route(web::get().to(move || { - Response::Ok().streaming(once(Ok::<_, Error>(Bytes::from_static( - STR.as_ref(), - )))) - }))), - ) - }); + block_on(async { + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service(web::resource("/").route(web::get().to(move || { + Response::Ok().streaming(once(ok::<_, Error>( + Bytes::from_static(STR.as_ref()), + ))) + }))), + ) + }); - let mut response = srv - .block_on( - srv.get("/") - .no_decompress() - .header(ACCEPT_ENCODING, "gzip") - .send(), - ) - .unwrap(); - assert!(response.status().is_success()); - assert_eq!( - response.headers().get(TRANSFER_ENCODING).unwrap(), - &b"chunked"[..] - ); + let mut response = srv + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await + .unwrap(); + assert!(response.status().is_success()); + assert_eq!( + response.headers().get(TRANSFER_ENCODING).unwrap(), + &b"chunked"[..] + ); - // read response - let bytes = srv.block_on(response.body()).unwrap(); + // read response + let bytes = response.body().await.unwrap(); - // decode - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + // decode + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + }) } #[test] #[cfg(feature = "brotli")] fn test_body_br_streaming() { - let mut srv = TestServer::new(move || { - h1::H1Service::new(App::new().wrap(Compress::new(ContentEncoding::Br)).service( - web::resource("/").route(web::to(move || { - Response::Ok() - .streaming(once(Ok::<_, Error>(Bytes::from_static(STR.as_ref())))) - })), - )) - }); + block_on(async { + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().wrap(Compress::new(ContentEncoding::Br)).service( + web::resource("/").route(web::to(move || { + Response::Ok().streaming(once(ok::<_, Error>( + Bytes::from_static(STR.as_ref()), + ))) + })), + ), + ) + }); - let mut response = srv - .block_on( - srv.get("/") - .header(ACCEPT_ENCODING, "br") - .no_decompress() - .send(), - ) - .unwrap(); - assert!(response.status().is_success()); + let mut response = srv + .get("/") + .header(ACCEPT_ENCODING, "br") + .no_decompress() + .send() + .await + .unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); + // read response + let bytes = response.body().await.unwrap(); - // decode br - let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + // decode br + let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); + e.write_all(bytes.as_ref()).unwrap(); + let dec = e.finish().unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + }) } #[test] fn test_head_binary() { - let mut srv = TestServer::new(move || { - h1::H1Service::new(App::new().service(web::resource("/").route( - web::head().to(move || Response::Ok().content_length(100).body(STR)), - ))) - }); + block_on(async { + let srv = TestServer::start(move || { + h1::H1Service::new(App::new().service(web::resource("/").route( + web::head().to(move || Response::Ok().content_length(100).body(STR)), + ))) + }); - let mut response = srv.block_on(srv.head("/").send()).unwrap(); - assert!(response.status().is_success()); + let mut response = srv.head("/").send().await.unwrap(); + assert!(response.status().is_success()); - { - let len = response.headers().get(CONTENT_LENGTH).unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } + { + let len = response.headers().get(CONTENT_LENGTH).unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert!(bytes.is_empty()); + // read response + let bytes = response.body().await.unwrap(); + assert!(bytes.is_empty()); + }) } #[test] fn test_no_chunking() { - let mut srv = TestServer::new(move || { - h1::H1Service::new(App::new().service(web::resource("/").route(web::to( - move || { - Response::Ok() - .no_chunking() - .content_length(STR.len() as u64) - .streaming(once(Ok::<_, Error>(Bytes::from_static(STR.as_ref())))) - }, - )))) - }); + block_on(async { + let srv = TestServer::start(move || { + h1::H1Service::new(App::new().service(web::resource("/").route(web::to( + move || { + Response::Ok() + .no_chunking() + .content_length(STR.len() as u64) + .streaming(once(ok::<_, Error>(Bytes::from_static( + STR.as_ref(), + )))) + }, + )))) + }); - let mut response = srv.block_on(srv.get("/").send()).unwrap(); - assert!(response.status().is_success()); - assert!(!response.headers().contains_key(TRANSFER_ENCODING)); + let mut response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert!(!response.headers().contains_key(TRANSFER_ENCODING)); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } #[test] #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] fn test_body_deflate() { - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Deflate)) - .service( - web::resource("/").route(web::to(move || Response::Ok().body(STR))), - ), - ) - }); + block_on(async { + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Deflate)) + .service( + web::resource("/") + .route(web::to(move || Response::Ok().body(STR))), + ), + ) + }); - // client request - let mut response = srv - .block_on( - srv.get("/") - .header(ACCEPT_ENCODING, "deflate") - .no_decompress() - .send(), - ) - .unwrap(); - assert!(response.status().is_success()); + // client request + let mut response = srv + .get("/") + .header(ACCEPT_ENCODING, "deflate") + .no_decompress() + .send() + .await + .unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); + // read response + let bytes = response.body().await.unwrap(); - let mut e = ZlibDecoder::new(Vec::new()); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + let mut e = ZlibDecoder::new(Vec::new()); + e.write_all(bytes.as_ref()).unwrap(); + let dec = e.finish().unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + }) } #[test] #[cfg(any(feature = "brotli"))] fn test_body_brotli() { - let mut srv = TestServer::new(move || { - h1::H1Service::new(App::new().wrap(Compress::new(ContentEncoding::Br)).service( - web::resource("/").route(web::to(move || Response::Ok().body(STR))), - )) - }); + block_on(async { + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().wrap(Compress::new(ContentEncoding::Br)).service( + web::resource("/").route(web::to(move || Response::Ok().body(STR))), + ), + ) + }); - // client request - let mut response = srv - .block_on( - srv.get("/") - .header(ACCEPT_ENCODING, "br") - .no_decompress() - .send(), - ) - .unwrap(); - assert!(response.status().is_success()); + // client request + let mut response = srv + .get("/") + .header(ACCEPT_ENCODING, "br") + .no_decompress() + .send() + .await + .unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); + // read response + let bytes = response.body().await.unwrap(); - // decode brotli - let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + // decode brotli + let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); + e.write_all(bytes.as_ref()).unwrap(); + let dec = e.finish().unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + }) } #[test] #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] fn test_encoding() { - let mut srv = TestServer::new(move || { - HttpService::new( - App::new().wrap(Compress::default()).service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), - ) - }); + block_on(async { + let srv = TestServer::start(move || { + HttpService::new( + App::new().wrap(Compress::default()).service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); - // client request - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(STR.as_ref()).unwrap(); - let enc = e.finish().unwrap(); + // client request + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(STR.as_ref()).unwrap(); + let enc = e.finish().unwrap(); - let request = srv - .post("/") - .header(CONTENT_ENCODING, "gzip") - .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + let request = srv + .post("/") + .header(CONTENT_ENCODING, "gzip") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } #[test] #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] fn test_gzip_encoding() { - let mut srv = TestServer::new(move || { - HttpService::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), - ) - }); + block_on(async { + let srv = TestServer::start(move || { + HttpService::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); - // client request - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(STR.as_ref()).unwrap(); - let enc = e.finish().unwrap(); + // client request + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(STR.as_ref()).unwrap(); + let enc = e.finish().unwrap(); - let request = srv - .post("/") - .header(CONTENT_ENCODING, "gzip") - .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + let request = srv + .post("/") + .header(CONTENT_ENCODING, "gzip") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } #[test] #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] fn test_gzip_encoding_large() { - let data = STR.repeat(10); - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), - ) - }); + block_on(async { + let data = STR.repeat(10); + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); - // client request - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); + // client request + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); - let request = srv - .post("/") - .header(CONTENT_ENCODING, "gzip") - .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + let request = srv + .post("/") + .header(CONTENT_ENCODING, "gzip") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from(data)); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from(data)); + }) } #[test] #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] fn test_reading_gzip_encoding_large_random() { - let data = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(60_000) - .collect::(); + block_on(async { + let data = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(60_000) + .collect::(); - let mut srv = TestServer::new(move || { - HttpService::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), - ) - }); + let srv = TestServer::start(move || { + HttpService::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); - // client request - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); + // client request + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); - let request = srv - .post("/") - .header(CONTENT_ENCODING, "gzip") - .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + let request = srv + .post("/") + .header(CONTENT_ENCODING, "gzip") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data)); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes.len(), data.len()); + assert_eq!(bytes, Bytes::from(data)); + }) } #[test] #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] fn test_reading_deflate_encoding() { - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), - ) - }); + block_on(async { + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); - let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); - e.write_all(STR.as_ref()).unwrap(); - let enc = e.finish().unwrap(); + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); + e.write_all(STR.as_ref()).unwrap(); + let enc = e.finish().unwrap(); - // client request - let request = srv - .post("/") - .header(CONTENT_ENCODING, "deflate") - .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + // client request + let request = srv + .post("/") + .header(CONTENT_ENCODING, "deflate") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } #[test] #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] fn test_reading_deflate_encoding_large() { - let data = STR.repeat(10); - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), - ) - }); + block_on(async { + let data = STR.repeat(10); + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); - let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); - // client request - let request = srv - .post("/") - .header(CONTENT_ENCODING, "deflate") - .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + // client request + let request = srv + .post("/") + .header(CONTENT_ENCODING, "deflate") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from(data)); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from(data)); + }) } #[test] #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] fn test_reading_deflate_encoding_large_random() { - let data = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(160_000) - .collect::(); + block_on(async { + let data = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(160_000) + .collect::(); - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), - ) - }); + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); - let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); - // client request - let request = srv - .post("/") - .header(CONTENT_ENCODING, "deflate") - .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + // client request + let request = srv + .post("/") + .header(CONTENT_ENCODING, "deflate") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data)); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes.len(), data.len()); + assert_eq!(bytes, Bytes::from(data)); + }) } #[test] #[cfg(feature = "brotli")] fn test_brotli_encoding() { - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), - ) - }); + block_on(async { + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); - let mut e = BrotliEncoder::new(Vec::new(), 5); - e.write_all(STR.as_ref()).unwrap(); - let enc = e.finish().unwrap(); + let mut e = BrotliEncoder::new(Vec::new(), 5); + e.write_all(STR.as_ref()).unwrap(); + let enc = e.finish().unwrap(); - // client request - let request = srv - .post("/") - .header(CONTENT_ENCODING, "br") - .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + // client request + let request = srv + .post("/") + .header(CONTENT_ENCODING, "br") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + }) } #[cfg(feature = "brotli")] #[test] fn test_brotli_encoding_large() { - let data = STR.repeat(10); - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), - ) - }); + block_on(async { + let data = STR.repeat(10); + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); - let mut e = BrotliEncoder::new(Vec::new(), 5); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); + let mut e = BrotliEncoder::new(Vec::new(), 5); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); - // client request - let request = srv - .post("/") - .header(CONTENT_ENCODING, "br") - .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); - assert!(response.status().is_success()); + // client request + let request = srv + .post("/") + .header(CONTENT_ENCODING, "br") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); - // read response - let bytes = srv.block_on(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from(data)); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from(data)); + }) } // #[cfg(all(feature = "brotli", feature = "ssl"))] @@ -776,91 +815,95 @@ fn test_brotli_encoding_large() { // } #[cfg(all( - feature = "rust-tls", - feature = "ssl", + feature = "rustls", + feature = "openssl", any(feature = "flate2-zlib", feature = "flate2-rust") ))] #[test] fn test_reading_deflate_encoding_large_random_ssl() { - use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; - use rustls::internal::pemfile::{certs, pkcs8_private_keys}; - use rustls::{NoClientAuth, ServerConfig}; - use std::fs::File; - use std::io::BufReader; + block_on(async { + use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + use rust_tls::internal::pemfile::{certs, pkcs8_private_keys}; + use rust_tls::{NoClientAuth, ServerConfig}; + use std::fs::File; + use std::io::BufReader; + use std::sync::mpsc; - let addr = TestServer::unused_addr(); - let (tx, rx) = mpsc::channel(); + let addr = TestServer::unused_addr(); + let (tx, rx) = mpsc::channel(); - let data = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(160_000) - .collect::(); + let data = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(160_000) + .collect::(); - thread::spawn(move || { - let sys = actix_rt::System::new("test"); + std::thread::spawn(move || { + let sys = actix_rt::System::new("test"); - // load ssl keys - let mut config = ServerConfig::new(NoClientAuth::new()); - let cert_file = &mut BufReader::new(File::open("tests/cert.pem").unwrap()); - let key_file = &mut BufReader::new(File::open("tests/key.pem").unwrap()); - let cert_chain = certs(cert_file).unwrap(); - let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + // load ssl keys + let mut config = ServerConfig::new(NoClientAuth::new()); + let cert_file = &mut BufReader::new(File::open("tests/cert.pem").unwrap()); + let key_file = &mut BufReader::new(File::open("tests/key.pem").unwrap()); + let cert_chain = certs(cert_file).unwrap(); + let mut keys = pkcs8_private_keys(key_file).unwrap(); + config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - let srv = HttpServer::new(|| { - App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { - Ok::<_, Error>( - HttpResponse::Ok() - .encoding(http::ContentEncoding::Identity) - .body(bytes), + let srv = HttpServer::new(|| { + App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { + async move { + Ok::<_, Error>( + HttpResponse::Ok() + .encoding(http::ContentEncoding::Identity) + .body(bytes), + ) + } + }))) + }) + .bind_rustls(addr, config) + .unwrap() + .start(); + + let _ = tx.send((srv, actix_rt::System::current())); + let _ = sys.run(); + }); + let (srv, _sys) = rx.recv().unwrap(); + let client = { + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder.set_alpn_protos(b"\x02h2\x08http/1.1").unwrap(); + + awc::Client::build() + .connector( + awc::Connector::new() + .timeout(std::time::Duration::from_millis(500)) + .ssl(builder.build()) + .finish(), ) - }))) - }) - .bind_rustls(addr, config) - .unwrap() - .start(); + .finish() + }; - let _ = tx.send((srv, actix_rt::System::current())); - let _ = sys.run(); - }); - let (srv, _sys) = rx.recv().unwrap(); - test::block_on(futures::lazy(|| Ok::<_, ()>(start_default_resolver()))).unwrap(); - let client = test::run_on(|| { - let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - builder.set_verify(SslVerifyMode::NONE); - let _ = builder.set_alpn_protos(b"\x02h2\x08http/1.1").unwrap(); + // encode data + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); - awc::Client::build() - .connector( - awc::Connector::new() - .timeout(std::time::Duration::from_millis(500)) - .ssl(builder.build()) - .finish(), - ) - .finish() - }); + // client request + let req = client + .post(format!("https://localhost:{}/", addr.port())) + .header(http::header::CONTENT_ENCODING, "deflate") + .send_body(enc); - // encode data - let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); + let mut response = req.await.unwrap(); + assert!(response.status().is_success()); - // client request - let req = client - .post(format!("https://localhost:{}/", addr.port())) - .header(http::header::CONTENT_ENCODING, "deflate") - .send_body(enc); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes.len(), data.len()); + assert_eq!(bytes, Bytes::from(data)); - let mut response = test::block_on(req).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = test::block_on(response.body()).unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data)); - - // stop - let _ = srv.stop(false); + // stop + let _ = srv.stop(false); + }) } // #[cfg(all(feature = "tls", feature = "ssl"))] @@ -954,7 +997,7 @@ fn test_reading_deflate_encoding_large_random_ssl() { // fn test_server_cookies() { // use actix_web::http; -// let mut srv = test::TestServer::with_factory(|| { +// let srv = test::TestServer::with_factory(|| { // App::new().resource("/", |r| { // r.f(|_| { // HttpResponse::Ok()