diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml new file mode 100644 index 000000000..f50ae2f05 --- /dev/null +++ b/.github/workflows/macos.yml @@ -0,0 +1,59 @@ +name: CI (macOS) + +on: [push, pull_request] + +jobs: + build_and_test: + strategy: + fail-fast: false + matrix: + version: + - stable + - nightly + + name: ${{ matrix.version }} - x86_64-apple-darwin + runs-on: macOS-latest + + steps: + - uses: actions/checkout@master + + - name: Install ${{ matrix.version }} + uses: actions-rs/toolchain@v1 + with: + toolchain: ${{ matrix.version }}-x86_64-apple-darwin + profile: minimal + override: true + + - name: Generate Cargo.lock + uses: actions-rs/cargo@v1 + with: + command: update + - name: Cache cargo registry + uses: actions/cache@v1 + with: + path: ~/.cargo/registry + key: ${{ matrix.version }}-x86_64-apple-darwin-cargo-registry-${{ hashFiles('**/Cargo.lock') }} + - name: Cache cargo index + uses: actions/cache@v1 + with: + path: ~/.cargo/git + key: ${{ matrix.version }}-x86_64-apple-darwin-cargo-index-${{ hashFiles('**/Cargo.lock') }} + - name: Cache cargo build + uses: actions/cache@v1 + with: + path: target + key: ${{ matrix.version }}-x86_64-apple-darwin-cargo-build-${{ hashFiles('**/Cargo.lock') }} + + - name: check build + uses: actions-rs/cargo@v1 + with: + command: check + args: --all --bins --examples --tests + + - name: tests + uses: actions-rs/cargo@v1 + with: + command: test + args: --all --all-features --no-fail-fast -- --nocapture + --skip=test_h2_content_length + --skip=test_reading_deflate_encoding_large_random_rustls diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml new file mode 100644 index 000000000..9aa3d3ba4 --- /dev/null +++ b/.github/workflows/windows.yml @@ -0,0 +1,79 @@ +name: CI (Windows) + +on: [push, pull_request] + +env: + VCPKGRS_DYNAMIC: 1 + +jobs: + build_and_test: + strategy: + fail-fast: false + matrix: + version: + - stable + - nightly + + name: ${{ matrix.version }} - x86_64-pc-windows-msvc + runs-on: windows-latest + + steps: + - uses: actions/checkout@master + + - name: Install ${{ matrix.version }} + uses: actions-rs/toolchain@v1 + with: + toolchain: ${{ matrix.version }}-x86_64-pc-windows-msvc + profile: minimal + override: true + + - name: Generate Cargo.lock + uses: actions-rs/cargo@v1 + with: + command: update + - name: Cache cargo registry + uses: actions/cache@v1 + with: + path: ~/.cargo/registry + key: ${{ matrix.version }}-x86_64-pc-windows-msvc-cargo-registry-${{ hashFiles('**/Cargo.lock') }} + - name: Cache cargo index + uses: actions/cache@v1 + with: + path: ~/.cargo/git + key: ${{ matrix.version }}-x86_64-pc-windows-msvc-cargo-index-${{ hashFiles('**/Cargo.lock') }} + - name: Cache cargo build + uses: actions/cache@v1 + with: + path: target + key: ${{ matrix.version }}-x86_64-pc-windows-msvc-cargo-build-${{ hashFiles('**/Cargo.lock') }} + - name: Cache vcpkg package + uses: actions/cache@v1 + id: cache-vcpkg + with: + path: C:\vcpkg + key: windows_x64-${{ matrix.version }}-vcpkg + + - name: Install OpenSSL + if: steps.cache-vcpkg.outputs.cache-hit != 'true' + run: | + vcpkg integrate install + vcpkg install openssl:x64-windows + + - name: check build + uses: actions-rs/cargo@v1 + with: + command: check + args: --all --bins --examples --tests + + - name: tests + uses: actions-rs/cargo@v1 + with: + command: test + args: --all --all-features --no-fail-fast -- --nocapture + --skip=test_h2_content_length + --skip=test_reading_deflate_encoding_large_random_rustls + --skip=test_params + --skip=test_simple + --skip=test_expect_continue + --skip=test_http10_keepalive + --skip=test_slow_request diff --git a/.travis.yml b/.travis.yml index 82db86a6e..f10f82a48 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,9 +10,9 @@ matrix: include: - rust: stable - rust: beta - - rust: nightly-2019-08-10 + - rust: nightly-2019-11-20 allow_failures: - - rust: nightly-2019-08-10 + - rust: nightly-2019-11-20 env: global: @@ -25,7 +25,7 @@ before_install: - sudo apt-get install -y openssl libssl-dev libelf-dev libdw-dev cmake gcc binutils-dev libiberty-dev before_cache: | - if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-08-10" ]]; then + if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install --version 0.6.11 cargo-tarpaulin fi @@ -36,9 +36,12 @@ before_script: script: - cargo update - cargo check --all --no-default-features - - cargo test --all-features --all -- --nocapture - - cd actix-http; cargo test --no-default-features --features="rust-tls" -- --nocapture; cd .. - - cd awc; cargo test --no-default-features --features="rust-tls" -- --nocapture; cd .. + - | + if [[ "$TRAVIS_RUST_VERSION" == "stable" || "$TRAVIS_RUST_VERSION" == "beta" ]]; then + cargo test --all-features --all -- --nocapture + cd actix-http; cargo test --no-default-features --features="rustls" -- --nocapture; cd .. + cd awc; cargo test --no-default-features --features="rustls" -- --nocapture; cd .. + fi # Upload docs after_success: @@ -51,7 +54,7 @@ after_success: echo "Uploaded documentation" fi - | - if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-08-10" ]]; then + if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then taskset -c 0 cargo tarpaulin --out Xml --all --all-features bash <(curl -s https://codecov.io/bash) echo "Uploaded code coverage" diff --git a/CHANGES.md b/CHANGES.md index f37f8b466..29f78e0b1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,13 +1,105 @@ # Changes -## not released yet -### Added +## [2.0.NEXT] - 2020-01-xx -* Add `middleware::Conditon` that conditionally enables another middleware +### Changed + +* Use `sha-1` crate instead of unmaintained `sha1` crate + +## [2.0.0] - 2019-12-25 + +### Changed + +* Rename `HttpServer::start()` to `HttpServer::run()` + +* Allow to gracefully stop test server via `TestServer::stop()` + +* Allow to specify multi-patterns for resources + +## [2.0.0-rc] - 2019-12-20 + +### Changed + +* Move `BodyEncoding` to `dev` module #1220 + +* Allow to set `peer_addr` for TestRequest #1074 + +* Make web::Data deref to Arc #1214 + +* Rename `App::register_data()` to `App::app_data()` + +* `HttpRequest::app_data()` returns `Option<&T>` instead of `Option<&Data>` ### Fixed -* h2 will use error response #1080 +* Fix `AppConfig::secure()` is always false. #1202 + + +## [2.0.0-alpha.6] - 2019-12-15 + +### Fixed + +* Fixed compilation with default features off + +## [2.0.0-alpha.5] - 2019-12-13 + +### Added + +* Add test server, `test::start()` and `test::start_with()` + +## [2.0.0-alpha.4] - 2019-12-08 + +### Deleted + +* Delete HttpServer::run(), it is not useful witht async/await + +## [2.0.0-alpha.3] - 2019-12-07 + +### Changed + +* Migrate to tokio 0.2 + + +## [2.0.0-alpha.1] - 2019-11-22 + +### Changed + +* Migrated to `std::future` + +* Remove implementation of `Responder` for `()`. (#1167) + + +## [1.0.9] - 2019-11-14 + +### Added + +* Add `Payload::into_inner` method and make stored `def::Payload` public. (#1110) + +### Changed + +* Support `Host` guards when the `Host` header is unset (e.g. HTTP/2 requests) (#1129) + + +## [1.0.8] - 2019-09-25 + +### Added + +* Add `Scope::register_data` and `Resource::register_data` methods, parallel to + `App::register_data`. + +* Add `middleware::Condition` that conditionally enables another middleware + +* Allow to re-construct `ServiceRequest` from `HttpRequest` and `Payload` + +* Add `HttpServer::listen_uds` for ability to listen on UDS FD rather than path, + which is useful for example with systemd. + +### Changed + +* Make UrlEncodedError::Overflow more informativve + +* Use actix-testing for testing utils + ## [1.0.7] - 2019-08-29 diff --git a/Cargo.toml b/Cargo.toml index c2d3b0d2b..9e1b559eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web" -version = "1.0.7" +version = "2.0.0" authors = ["Nikolay Kim "] description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust." readme = "README.md" @@ -12,11 +12,10 @@ categories = ["network-programming", "asynchronous", "web-programming::http-server", "web-programming::websocket"] license = "MIT/Apache-2.0" -exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] edition = "2018" [package.metadata.docs.rs] -features = ["ssl", "brotli", "flate2-zlib", "secure-cookies", "client", "rust-tls", "uds"] +features = ["openssl", "rustls", "compress", "secure-cookies"] [badges] travis-ci = { repository = "actix/actix-web", branch = "master" } @@ -43,77 +42,63 @@ members = [ ] [features] -default = ["brotli", "flate2-zlib", "client", "fail"] +default = ["compress", "failure"] -# http client -client = ["awc"] - -# brotli encoding, requires c compiler -brotli = ["actix-http/brotli"] - -# miniz-sys backend for flate2 crate -flate2-zlib = ["actix-http/flate2-zlib"] - -# rust backend for flate2 crate -flate2-rust = ["actix-http/flate2-rust"] +# content-encoding support +compress = ["actix-http/compress", "awc/compress"] # sessions feature, session require "ring" crate and c compiler secure-cookies = ["actix-http/secure-cookies"] -fail = ["actix-http/fail"] +failure = ["actix-http/failure"] # openssl -ssl = ["openssl", "actix-server/ssl", "awc/ssl"] +openssl = ["actix-tls/openssl", "awc/openssl", "open-ssl"] # rustls -rust-tls = ["rustls", "actix-server/rust-tls", "awc/rust-tls"] - -# unix domain sockets support -uds = ["actix-server/uds"] +rustls = ["actix-tls/rustls", "awc/rustls", "rust-tls"] [dependencies] -actix-codec = "0.1.2" -actix-service = "0.4.1" -actix-utils = "0.4.4" -actix-router = "0.1.5" -actix-rt = "0.2.4" -actix-web-codegen = "0.1.2" -actix-http = "0.2.9" -actix-server = "0.6.0" -actix-server-config = "0.1.2" -actix-threadpool = "0.1.1" -awc = { version = "0.2.4", optional = true } +actix-codec = "0.2.0" +actix-service = "1.0.2" +actix-utils = "1.0.6" +actix-router = "0.2.4" +actix-rt = "1.0.0" +actix-server = "1.0.0" +actix-testing = "1.0.0" +actix-macros = "0.1.0" +actix-threadpool = "0.3.1" +actix-tls = "1.0.0" -bytes = "0.4" -derive_more = "0.15.0" +actix-web-codegen = "0.2.0" +actix-http = "1.0.1" +awc = { version = "1.0.1", default-features = false } + +bytes = "0.5.3" +derive_more = "0.99.2" encoding_rs = "0.8" -futures = "0.1.25" -hashbrown = "0.5.0" +futures = "0.3.1" +fxhash = "0.2.1" log = "0.4" mime = "0.3" net2 = "0.2.33" -parking_lot = "0.9" -regex = "1.0" +pin-project = "0.4.6" +regex = "1.3" serde = { version = "1.0", features=["derive"] } serde_json = "1.0" serde_urlencoded = "0.6.1" time = "0.1.42" url = "2.1" - -# ssl support -openssl = { version="0.10", optional = true } -rustls = { version = "0.15", optional = true } +open-ssl = { version="0.10", package = "openssl", optional = true } +rust-tls = { version = "0.16.0", package = "rustls", optional = true } [dev-dependencies] -actix = "0.8.3" -actix-connect = "0.2.2" -actix-http-test = "0.2.4" +actix = "0.9.0" rand = "0.7" env_logger = "0.6" serde_derive = "1.0" -tokio-timer = "0.2.8" brotli2 = "0.3.2" -flate2 = "1.0.2" +flate2 = "1.0.13" [profile.release] lto = true @@ -125,7 +110,8 @@ actix-web = { path = "." } actix-http = { path = "actix-http" } actix-http-test = { path = "test-server" } actix-web-codegen = { path = "actix-web-codegen" } -actix-web-actors = { path = "actix-web-actors" } +actix-cors = { path = "actix-cors" } +actix-identity = { path = "actix-identity" } actix-session = { path = "actix-session" } actix-files = { path = "actix-files" } actix-multipart = { path = "actix-multipart" } diff --git a/MIGRATION.md b/MIGRATION.md index 2f0f369ad..529f9714d 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -1,3 +1,51 @@ +## 2.0.0 + +* `HttpServer::start()` renamed to `HttpServer::run()`. It also possible to + `.await` on `run` method result, in that case it awaits server exit. + +* `App::register_data()` renamed to `App::app_data()` and accepts any type `T: 'static`. + Stored data is available via `HttpRequest::app_data()` method at runtime. + +* Extractor configuration must be registered with `App::app_data()` instead of `App::data()` + +* Sync handlers has been removed. `.to_async()` method has been renamed to `.to()` + replace `fn` with `async fn` to convert sync handler to async + +* `actix_http_test::TestServer` moved to `actix_web::test` module. To start + test server use `test::start()` or `test_start_with_config()` methods + +* `ResponseError` trait has been reafctored. `ResponseError::error_response()` renders + http response. + +* Feature `rust-tls` renamed to `rustls` + + instead of + + ```rust + actix-web = { version = "2.0.0", features = ["rust-tls"] } + ``` + + use + + ```rust + actix-web = { version = "2.0.0", features = ["rustls"] } + ``` + +* Feature `ssl` renamed to `openssl` + + instead of + + ```rust + actix-web = { version = "2.0.0", features = ["ssl"] } + ``` + + use + + ```rust + actix-web = { version = "2.0.0", features = ["openssl"] } + ``` +* `Cors` builder now requires that you call `.finish()` to construct the middleware + ## 1.0.1 * Cors middleware has been moved to `actix-cors` crate @@ -34,52 +82,52 @@ * Extractor configuration. In version 1.0 this is handled with the new `Data` mechanism for both setting and retrieving the configuration instead of - + ```rust - + #[derive(Default)] struct ExtractorConfig { config: String, } - + impl FromRequest for YourExtractor { type Config = ExtractorConfig; type Result = Result; - + fn from_request(req: &HttpRequest, cfg: &Self::Config) -> Self::Result { println!("use the config: {:?}", cfg.config); ... } } - + App::new().resource("/route_with_config", |r| { r.post().with_config(handler_fn, |cfg| { cfg.0.config = "test".to_string(); }) }) - + ``` - + use the HttpRequest to get the configuration like any other `Data` with `req.app_data::()` and set it with the `data()` method on the `resource` - + ```rust #[derive(Default)] struct ExtractorConfig { config: String, } - + impl FromRequest for YourExtractor { type Error = Error; type Future = Result; type Config = ExtractorConfig; - + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { let cfg = req.app_data::(); println!("config data?: {:?}", cfg.unwrap().role); ... } } - + App::new().service( resource("/route_with_config") .data(ExtractorConfig { @@ -88,7 +136,7 @@ .route(post().to(handler_fn)), ) ``` - + * Resource registration. 1.0 version uses generalized resource registration via `.service()` method. @@ -379,9 +427,9 @@ * `HttpRequest` does not implement `Stream` anymore. If you need to read request payload use `HttpMessage::payload()` method. - + instead of - + ```rust fn index(req: HttpRequest) -> impl Responder { req @@ -407,8 +455,8 @@ trait uses `&HttpRequest` instead of `&mut HttpRequest`. * Removed `Route::with2()` and `Route::with3()` use tuple of extractors instead. - - instead of + + instead of ```rust fn index(query: Query<..>, info: Json impl Responder {} @@ -424,7 +472,7 @@ * `Handler::handle()` accepts reference to `HttpRequest<_>` instead of value -* Removed deprecated `HttpServer::threads()`, use +* Removed deprecated `HttpServer::threads()`, use [HttpServer::workers()](https://actix.rs/actix-web/actix_web/server/struct.HttpServer.html#method.workers) instead. * Renamed `client::ClientConnectorError::Connector` to @@ -433,7 +481,7 @@ * `Route::with()` does not return `ExtractorConfig`, to configure extractor use `Route::with_config()` - instead of + instead of ```rust fn main() { @@ -444,11 +492,11 @@ }); } ``` - - use - + + use + ```rust - + fn main() { let app = App::new().resource("/index.html", |r| { r.method(http::Method::GET) @@ -478,12 +526,12 @@ * `HttpRequest::extensions()` returns read only reference to the request's Extension `HttpRequest::extensions_mut()` returns mutable reference. -* Instead of +* Instead of `use actix_web::middleware::{ CookieSessionBackend, CookieSessionError, RequestSession, Session, SessionBackend, SessionImpl, SessionStorage};` - + use `actix_web::middleware::session` `use actix_web::middleware::session{CookieSessionBackend, CookieSessionError, diff --git a/README.md b/README.md index 99b7b1760..0b55c5bae 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,28 @@ -# Actix web [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-web)](https://crates.io/crates/actix-web) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +
+

Actix web

+

Actix web is a small, pragmatic, and extremely fast rust web framework

+

+ +[![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) +[![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) +[![crates.io](https://meritbadge.herokuapp.com/actix-web)](https://crates.io/crates/actix-web) +[![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Documentation](https://docs.rs/actix-web/badge.svg)](https://docs.rs/actix-web) +[![Download](https://img.shields.io/crates/d/actix-web.svg)](https://crates.io/crates/actix-web) +[![Version](https://img.shields.io/badge/rustc-1.39+-lightgray.svg)](https://blog.rust-lang.org/2019/11/07/Rust-1.39.0.html) +![License](https://img.shields.io/crates/l/actix-web.svg) + +

+ +

+ Website + | + Chat + | + Examples +

+
+
Actix web is a simple, pragmatic and extremely fast web framework for Rust. @@ -15,30 +39,32 @@ Actix web is a simple, pragmatic and extremely fast web framework for Rust. * Includes an asynchronous [HTTP client](https://actix.rs/actix-web/actix_web/client/index.html) * Supports [Actix actor framework](https://github.com/actix/actix) -## Documentation & community resources - -* [User Guide](https://actix.rs/docs/) -* [API Documentation (1.0)](https://docs.rs/actix-web/) -* [API Documentation (0.7)](https://docs.rs/actix-web/0.7.19/actix_web/) -* [Chat on gitter](https://gitter.im/actix/actix) -* Cargo package: [actix-web](https://crates.io/crates/actix-web) -* Minimum supported Rust version: 1.36 or later - ## Example -```rust -use actix_web::{web, App, HttpServer, Responder}; +Dependencies: -fn index(info: web::Path<(u32, String)>) -> impl Responder { +```toml +[dependencies] +actix-web = "2" +actix-rt = "1" +``` + +Code: + +```rust +use actix_web::{get, web, App, HttpServer, Responder}; + +#[get("/{id}/{name}/index.html")] +async fn index(info: web::Path<(u32, String)>) -> impl Responder { format!("Hello {}! id:{}", info.1, info.0) } -fn main() -> std::io::Result<()> { - HttpServer::new( - || App::new().service( - web::resource("/{id}/{name}/index.html").to(index))) +#[actix_rt::main] +async fn main() -> std::io::Result<()> { + HttpServer::new(|| App::new().service(index)) .bind("127.0.0.1:8080")? .run() + .await } ``` @@ -49,10 +75,11 @@ fn main() -> std::io::Result<()> { * [Multipart streams](https://github.com/actix/examples/tree/master/multipart/) * [Simple websocket](https://github.com/actix/examples/tree/master/websocket/) * [Tera](https://github.com/actix/examples/tree/master/template_tera/) / - [Askama](https://github.com/actix/examples/tree/master/template_askama/) templates +* [Askama](https://github.com/actix/examples/tree/master/template_askama/) templates * [Diesel integration](https://github.com/actix/examples/tree/master/diesel/) * [r2d2](https://github.com/actix/examples/tree/master/r2d2/) -* [SSL / HTTP/2.0](https://github.com/actix/examples/tree/master/tls/) +* [OpenSSL](https://github.com/actix/examples/tree/master/openssl/) +* [Rustls](https://github.com/actix/examples/tree/master/rustls/) * [Tcp/Websocket chat](https://github.com/actix/examples/tree/master/websocket-chat/) * [Json](https://github.com/actix/examples/tree/master/json/) diff --git a/actix-cors/CHANGES.md b/actix-cors/CHANGES.md index 10e408ede..8022ea4e8 100644 --- a/actix-cors/CHANGES.md +++ b/actix-cors/CHANGES.md @@ -1,8 +1,14 @@ # Changes -## [0.1.1] - unreleased +## [0.2.0] - 2019-12-20 -* Bump `derive_more` crate version to 0.15.0 +* Release + +## [0.2.0-alpha.3] - 2019-12-07 + +* Migrate to actix-web 2.0.0 + +* Bump `derive_more` crate version to 0.99.0 ## [0.1.0] - 2019-06-15 diff --git a/actix-cors/Cargo.toml b/actix-cors/Cargo.toml index 091c94044..3fcd92f4f 100644 --- a/actix-cors/Cargo.toml +++ b/actix-cors/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-cors" -version = "0.1.0" +version = "0.2.0" authors = ["Nikolay Kim "] description = "Cross-origin resource sharing (CORS) for Actix applications." readme = "README.md" @@ -10,14 +10,17 @@ repository = "https://github.com/actix/actix-web.git" documentation = "https://docs.rs/actix-cors/" license = "MIT/Apache-2.0" edition = "2018" -#workspace = ".." +workspace = ".." [lib] name = "actix_cors" path = "src/lib.rs" [dependencies] -actix-web = "1.0.0" -actix-service = "0.4.0" -derive_more = "0.15.0" -futures = "0.1.25" +actix-web = "2.0.0-rc" +actix-service = "1.0.1" +derive_more = "0.99.2" +futures = "0.3.1" + +[dev-dependencies] +actix-rt = "1.0.0" diff --git a/actix-cors/src/lib.rs b/actix-cors/src/lib.rs index c76bae925..429fe9eab 100644 --- a/actix-cors/src/lib.rs +++ b/actix-cors/src/lib.rs @@ -11,7 +11,7 @@ //! use actix_cors::Cors; //! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer}; //! -//! fn index(req: HttpRequest) -> &'static str { +//! async fn index(req: HttpRequest) -> &'static str { //! "Hello world" //! } //! @@ -23,7 +23,8 @@ //! .allowed_methods(vec!["GET", "POST"]) //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) //! .allowed_header(http::header::CONTENT_TYPE) -//! .max_age(3600)) +//! .max_age(3600) +//! .finish()) //! .service( //! web::resource("/index.html") //! .route(web::get().to(index)) @@ -39,18 +40,19 @@ //! //! Cors middleware automatically handle *OPTIONS* preflight request. use std::collections::HashSet; +use std::convert::TryFrom; use std::iter::FromIterator; use std::rc::Rc; +use std::task::{Context, Poll}; -use actix_service::{IntoTransform, Service, Transform}; +use actix_service::{Service, Transform}; use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse}; use actix_web::error::{Error, ResponseError, Result}; use actix_web::http::header::{self, HeaderName, HeaderValue}; -use actix_web::http::{self, HttpTryFrom, Method, StatusCode, Uri}; +use actix_web::http::{self, Error as HttpError, Method, StatusCode, Uri}; use actix_web::HttpResponse; use derive_more::Display; -use futures::future::{ok, Either, Future, FutureResult}; -use futures::Poll; +use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; /// A set of errors that can occur during processing CORS #[derive(Debug, Display)] @@ -92,6 +94,10 @@ pub enum CorsError { } impl ResponseError for CorsError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } + fn error_response(&self) -> HttpResponse { HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self).into()) } @@ -269,7 +275,8 @@ impl Cors { pub fn allowed_methods(mut self, methods: U) -> Cors where U: IntoIterator, - Method: HttpTryFrom, + Method: TryFrom, + >::Error: Into, { self.methods = true; if let Some(cors) = cors(&mut self.cors, &self.error) { @@ -291,7 +298,8 @@ impl Cors { /// Set an allowed header pub fn allowed_header(mut self, header: H) -> Cors where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, { if let Some(cors) = cors(&mut self.cors, &self.error) { match HeaderName::try_from(header) { @@ -323,7 +331,8 @@ impl Cors { pub fn allowed_headers(mut self, headers: U) -> Cors where U: IntoIterator, - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, { if let Some(cors) = cors(&mut self.cors, &self.error) { for h in headers { @@ -357,7 +366,8 @@ impl Cors { pub fn expose_headers(mut self, headers: U) -> Cors where U: IntoIterator, - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, { for h in headers { match HeaderName::try_from(h) { @@ -456,25 +466,9 @@ impl Cors { } self } -} -fn cors<'a>( - parts: &'a mut Option, - err: &Option, -) -> Option<&'a mut Inner> { - if err.is_some() { - return None; - } - parts.as_mut() -} - -impl IntoTransform for Cors -where - S: Service, Error = Error>, - S::Future: 'static, - B: 'static, -{ - fn into_transform(self) -> CorsFactory { + /// Construct cors middleware + pub fn finish(self) -> CorsFactory { let mut slf = if !self.methods { self.allowed_methods(vec![ Method::GET, @@ -521,6 +515,16 @@ where } } +fn cors<'a>( + parts: &'a mut Option, + err: &Option, +) -> Option<&'a mut Inner> { + if err.is_some() { + return None; + } + parts.as_mut() +} + /// `Middleware` for Cross-origin resource sharing support /// /// The Cors struct contains the settings for CORS requests to be validated and @@ -540,7 +544,7 @@ where type Error = Error; type InitError = (); type Transform = CorsMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(CorsMiddleware { @@ -682,12 +686,12 @@ where type Response = ServiceResponse; type Error = Error; type Future = Either< - FutureResult, - Either>>, + Ready>, + LocalBoxFuture<'static, Result>, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { @@ -698,7 +702,7 @@ where .and_then(|_| self.inner.validate_allowed_method(req.head())) .and_then(|_| self.inner.validate_allowed_headers(req.head())) { - return Either::A(ok(req.error_response(e))); + return Either::Left(ok(req.error_response(e))); } // allowed headers @@ -751,39 +755,50 @@ where .finish() .into_body(); - Either::A(ok(req.into_response(res))) - } else if req.headers().contains_key(&header::ORIGIN) { - // Only check requests with a origin header. - if let Err(e) = self.inner.validate_origin(req.head()) { - return Either::A(ok(req.error_response(e))); + Either::Left(ok(req.into_response(res))) + } else { + if req.headers().contains_key(&header::ORIGIN) { + // Only check requests with a origin header. + if let Err(e) = self.inner.validate_origin(req.head()) { + return Either::Left(ok(req.error_response(e))); + } } let inner = self.inner.clone(); + let has_origin = req.headers().contains_key(&header::ORIGIN); + let fut = self.service.call(req); - Either::B(Either::B(Box::new(self.service.call(req).and_then( - move |mut res| { - if let Some(origin) = - inner.access_control_allow_origin(res.request().head()) - { - res.headers_mut() - .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); - }; + Either::Right( + async move { + let res = fut.await; - if let Some(ref expose) = inner.expose_hdrs { - res.headers_mut().insert( - header::ACCESS_CONTROL_EXPOSE_HEADERS, - HeaderValue::try_from(expose.as_str()).unwrap(), - ); - } - if inner.supports_credentials { - res.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_CREDENTIALS, - HeaderValue::from_static("true"), - ); - } - if inner.vary_header { - let value = - if let Some(hdr) = res.headers_mut().get(&header::VARY) { + if has_origin { + let mut res = res?; + if let Some(origin) = + inner.access_control_allow_origin(res.request().head()) + { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + origin.clone(), + ); + }; + + if let Some(ref expose) = inner.expose_hdrs { + res.headers_mut().insert( + header::ACCESS_CONTROL_EXPOSE_HEADERS, + HeaderValue::try_from(expose.as_str()).unwrap(), + ); + } + if inner.supports_credentials { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + if inner.vary_header { + let value = if let Some(hdr) = + res.headers_mut().get(&header::VARY) + { let mut val: Vec = Vec::with_capacity(hdr.as_bytes().len() + 8); val.extend(hdr.as_bytes()); @@ -792,83 +807,71 @@ where } else { HeaderValue::from_static("Origin") }; - res.headers_mut().insert(header::VARY, value); + res.headers_mut().insert(header::VARY, value); + } + Ok(res) + } else { + res } - Ok(res) - }, - )))) - } else { - Either::B(Either::A(self.service.call(req))) + } + .boxed_local(), + ) } } } #[cfg(test)] mod tests { - use actix_service::{IntoService, Transform}; - use actix_web::test::{self, block_on, TestRequest}; + use actix_service::{fn_service, Transform}; + use actix_web::test::{self, TestRequest}; use super::*; - impl Cors { - fn finish(self, srv: F) -> CorsMiddleware - where - F: IntoService, - S: Service< - Request = ServiceRequest, - Response = ServiceResponse, - Error = Error, - > + 'static, - S::Future: 'static, - B: 'static, - { - block_on( - IntoTransform::::into_transform(self) - .new_transform(srv.into_service()), - ) - .unwrap() - } - } - - #[test] + #[actix_rt::test] #[should_panic(expected = "Credentials are allowed, but the Origin is set to")] - fn cors_validates_illegal_allow_credentials() { - let _cors = Cors::new() - .supports_credentials() - .send_wildcard() - .finish(test::ok_service()); + async fn cors_validates_illegal_allow_credentials() { + let _cors = Cors::new().supports_credentials().send_wildcard().finish(); } - #[test] - fn validate_origin_allows_all_origins() { - let mut cors = Cors::new().finish(test::ok_service()); + #[actix_rt::test] + async fn validate_origin_allows_all_origins() { + let mut cors = Cors::new() + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); let req = TestRequest::with_header("Origin", "https://www.example.com") .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn default() { - let mut cors = - block_on(Cors::default().new_transform(test::ok_service())).unwrap(); + #[actix_rt::test] + async fn default() { + let mut cors = Cors::default() + .new_transform(test::ok_service()) + .await + .unwrap(); let req = TestRequest::with_header("Origin", "https://www.example.com") .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn test_preflight() { + #[actix_rt::test] + async fn test_preflight() { let mut cors = Cors::new() .send_wildcard() .max_age(3600) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) .allowed_header(header::CONTENT_TYPE) - .finish(test::ok_service()); + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) @@ -877,7 +880,7 @@ mod tests { assert!(cors.inner.validate_allowed_method(req.head()).is_err()); assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let req = TestRequest::with_header("Origin", "https://www.example.com") @@ -897,7 +900,7 @@ mod tests { .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!( &b"*"[..], resp.headers() @@ -943,13 +946,13 @@ mod tests { .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!(resp.status(), StatusCode::OK); } - // #[test] + // #[actix_rt::test] // #[should_panic(expected = "MissingOrigin")] - // fn test_validate_missing_origin() { + // async fn test_validate_missing_origin() { // let cors = Cors::build() // .allowed_origin("https://www.example.com") // .finish(); @@ -957,12 +960,15 @@ mod tests { // cors.start(&req).unwrap(); // } - #[test] + #[actix_rt::test] #[should_panic(expected = "OriginNotAllowed")] - fn test_validate_not_allowed_origin() { + async fn test_validate_not_allowed_origin() { let cors = Cors::new() .allowed_origin("https://www.example.com") - .finish(test::ok_service()); + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); let req = TestRequest::with_header("Origin", "https://www.unknown.com") .method(Method::GET) @@ -972,26 +978,34 @@ mod tests { cors.inner.validate_allowed_headers(req.head()).unwrap(); } - #[test] - fn test_validate_origin() { + #[actix_rt::test] + async fn test_validate_origin() { let mut cors = Cors::new() .allowed_origin("https://www.example.com") - .finish(test::ok_service()); + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::GET) .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn test_no_origin_response() { - let mut cors = Cors::new().disable_preflight().finish(test::ok_service()); + #[actix_rt::test] + async fn test_no_origin_response() { + let mut cors = Cors::new() + .disable_preflight() + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); let req = TestRequest::default().method(Method::GET).to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert!(resp .headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) @@ -1000,7 +1014,7 @@ mod tests { let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!( &b"https://www.example.com"[..], resp.headers() @@ -1010,8 +1024,8 @@ mod tests { ); } - #[test] - fn test_response() { + #[actix_rt::test] + async fn test_response() { let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; let mut cors = Cors::new() .send_wildcard() @@ -1021,13 +1035,16 @@ mod tests { .allowed_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone()) .allowed_header(header::CONTENT_TYPE) - .finish(test::ok_service()); + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!( &b"*"[..], resp.headers() @@ -1065,15 +1082,18 @@ mod tests { .allowed_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone()) .allowed_header(header::CONTENT_TYPE) - .finish(|req: ServiceRequest| { - req.into_response( + .finish() + .new_transform(fn_service(|req: ServiceRequest| { + ok(req.into_response( HttpResponse::Ok().header(header::VARY, "Accept").finish(), - ) - }); + )) + })) + .await + .unwrap(); let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!( &b"Accept, Origin"[..], resp.headers().get(header::VARY).unwrap().as_bytes() @@ -1083,13 +1103,16 @@ mod tests { .disable_vary_header() .allowed_origin("https://www.example.com") .allowed_origin("https://www.google.com") - .finish(test::ok_service()); + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; let origins_str = resp .headers() @@ -1101,19 +1124,22 @@ mod tests { assert_eq!("https://www.example.com", origins_str); } - #[test] - fn test_multiple_origins() { + #[actix_rt::test] + async fn test_multiple_origins() { let mut cors = Cors::new() .allowed_origin("https://example.com") .allowed_origin("https://example.org") .allowed_methods(vec![Method::GET]) - .finish(test::ok_service()); + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); let req = TestRequest::with_header("Origin", "https://example.com") .method(Method::GET) .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!( &b"https://example.com"[..], resp.headers() @@ -1126,7 +1152,7 @@ mod tests { .method(Method::GET) .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!( &b"https://example.org"[..], resp.headers() @@ -1136,20 +1162,23 @@ mod tests { ); } - #[test] - fn test_multiple_origins_preflight() { + #[actix_rt::test] + async fn test_multiple_origins_preflight() { let mut cors = Cors::new() .allowed_origin("https://example.com") .allowed_origin("https://example.org") .allowed_methods(vec![Method::GET]) - .finish(test::ok_service()); + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); let req = TestRequest::with_header("Origin", "https://example.com") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!( &b"https://example.com"[..], resp.headers() @@ -1163,7 +1192,7 @@ mod tests { .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req); + let resp = test::call_service(&mut cors, req).await; assert_eq!( &b"https://example.org"[..], resp.headers() diff --git a/actix-files/CHANGES.md b/actix-files/CHANGES.md index 49ecdbffc..c4918b56d 100644 --- a/actix-files/CHANGES.md +++ b/actix-files/CHANGES.md @@ -1,21 +1,41 @@ # Changes -## [0.1.5] - unreleased +## [0.2.1] - 2019-12-22 + +* Use the same format for file URLs regardless of platforms + +## [0.2.0] - 2019-12-20 + +* Fix BodyEncoding trait import #1220 + +## [0.2.0-alpha.1] - 2019-12-07 + +* Migrate to `std::future` + +## [0.1.7] - 2019-11-06 + +* Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151) + +## [0.1.6] - 2019-10-14 + +* Add option to redirect to a slash-ended path `Files` #1132 + +## [0.1.5] - 2019-10-08 * Bump up `mime_guess` crate version to 2.0.1 * Bump up `percent-encoding` crate version to 2.1 +* Allow user defined request guards for `Files` #1113 + ## [0.1.4] - 2019-07-20 * Allow to disable `Content-Disposition` header #686 - ## [0.1.3] - 2019-06-28 * Do not set `Content-Length` header, let actix-http set it #930 - ## [0.1.2] - 2019-06-13 * Content-Length is 0 for NamedFile HEAD request #914 diff --git a/actix-files/Cargo.toml b/actix-files/Cargo.toml index 8f36cddc3..104eb3dfa 100644 --- a/actix-files/Cargo.toml +++ b/actix-files/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-files" -version = "0.1.4" +version = "0.2.1" authors = ["Nikolay Kim "] description = "Static files support for actix web." readme = "README.md" @@ -18,13 +18,13 @@ name = "actix_files" path = "src/lib.rs" [dependencies] -actix-web = { version = "1.0.2", default-features = false } -actix-http = "0.2.9" -actix-service = "0.4.1" +actix-web = { version = "2.0.0-rc", default-features = false } +actix-http = "1.0.1" +actix-service = "1.0.1" bitflags = "1" -bytes = "0.4" -futures = "0.1.25" -derive_more = "0.15.0" +bytes = "0.5.3" +futures = "0.3.1" +derive_more = "0.99.2" log = "0.4" mime = "0.3" mime_guess = "2.0.1" @@ -32,4 +32,5 @@ percent-encoding = "2.1" v_htmlescape = "0.4" [dev-dependencies] -actix-web = { version = "1.0.2", features=["ssl"] } +actix-rt = "1.0.0" +actix-web = { version = "2.0.0-rc", features=["openssl"] } diff --git a/actix-files/src/error.rs b/actix-files/src/error.rs index ca99fa813..49a46e58d 100644 --- a/actix-files/src/error.rs +++ b/actix-files/src/error.rs @@ -35,7 +35,7 @@ pub enum UriSegmentError { /// Return `BadRequest` for `UriSegmentError` impl ResponseError for UriSegmentError { - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::BAD_REQUEST) + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } diff --git a/actix-files/src/lib.rs b/actix-files/src/lib.rs index c99d3265f..d910b7d5f 100644 --- a/actix-files/src/lib.rs +++ b/actix-files/src/lib.rs @@ -4,23 +4,28 @@ use std::cell::RefCell; use std::fmt::Write; use std::fs::{DirEntry, File}; +use std::future::Future; use std::io::{Read, Seek}; use std::path::{Path, PathBuf}; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use std::{cmp, io}; -use actix_service::boxed::{self, BoxedNewService, BoxedService}; -use actix_service::{IntoNewService, NewService, Service}; +use actix_service::boxed::{self, BoxService, BoxServiceFactory}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use actix_web::dev::{ AppService, HttpServiceFactory, Payload, ResourceDef, ServiceRequest, ServiceResponse, }; use actix_web::error::{BlockingError, Error, ErrorInternalServerError}; -use actix_web::http::header::DispositionType; -use actix_web::{web, FromRequest, HttpRequest, HttpResponse, Responder}; +use actix_web::guard::Guard; +use actix_web::http::header::{self, DispositionType}; +use actix_web::http::Method; +use actix_web::{web, FromRequest, HttpRequest, HttpResponse}; use bytes::Bytes; -use futures::future::{ok, Either, FutureResult}; -use futures::{Async, Future, Poll, Stream}; +use futures::future::{ok, ready, Either, FutureExt, LocalBoxFuture, Ready}; +use futures::Stream; use mime; use mime_guess::from_ext; use percent_encoding::{utf8_percent_encode, CONTROLS}; @@ -34,8 +39,8 @@ use self::error::{FilesError, UriSegmentError}; pub use crate::named::NamedFile; pub use crate::range::HttpRange; -type HttpService = BoxedService; -type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; +type HttpService = BoxService; +type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; /// Return the MIME type associated with a filename extension (case-insensitive). /// If `ext` is empty or no associated type for the extension was found, returns @@ -45,6 +50,12 @@ pub fn file_extension_to_mime(ext: &str) -> mime::Mime { from_ext(ext).first_or_octet_stream() } +fn handle_error(err: BlockingError) -> Error { + match err { + BlockingError::Error(err) => err.into(), + BlockingError::Canceled => ErrorInternalServerError("Unexpected error"), + } +} #[doc(hidden)] /// A helper created from a `std::fs::File` which reads the file /// chunk-by-chunk on a `ThreadPool`. @@ -52,32 +63,29 @@ pub struct ChunkedReadFile { size: u64, offset: u64, file: Option, - fut: Option>>>, + fut: + Option>>>, counter: u64, } -fn handle_error(err: BlockingError) -> Error { - match err { - BlockingError::Error(err) => err.into(), - BlockingError::Canceled => ErrorInternalServerError("Unexpected error"), - } -} - impl Stream for ChunkedReadFile { - type Item = Bytes; - type Error = Error; + type Item = Result; - fn poll(&mut self) -> Poll, Error> { - if self.fut.is_some() { - return match self.fut.as_mut().unwrap().poll().map_err(handle_error)? { - Async::Ready((file, bytes)) => { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + if let Some(ref mut fut) = self.fut { + return match Pin::new(fut).poll(cx) { + Poll::Ready(Ok((file, bytes))) => { self.fut.take(); self.file = Some(file); self.offset += bytes.len() as u64; self.counter += bytes.len() as u64; - Ok(Async::Ready(Some(bytes))) + Poll::Ready(Some(Ok(bytes))) } - Async::NotReady => Ok(Async::NotReady), + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(handle_error(e)))), + Poll::Pending => Poll::Pending, }; } @@ -86,22 +94,25 @@ impl Stream for ChunkedReadFile { let counter = self.counter; if size == counter { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { let mut file = self.file.take().expect("Use after completion"); - self.fut = Some(Box::new(web::block(move || { - let max_bytes: usize; - max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; - let mut buf = Vec::with_capacity(max_bytes); - file.seek(io::SeekFrom::Start(offset))?; - let nbytes = - file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; - if nbytes == 0 { - return Err(io::ErrorKind::UnexpectedEof.into()); - } - Ok((file, Bytes::from(buf))) - }))); - self.poll() + self.fut = Some( + web::block(move || { + let max_bytes: usize; + max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; + let mut buf = Vec::with_capacity(max_bytes); + file.seek(io::SeekFrom::Start(offset))?; + let nbytes = + file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; + if nbytes == 0 { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + Ok((file, Bytes::from(buf))) + }) + .boxed_local(), + ); + self.poll_next(cx) } } } @@ -144,7 +155,7 @@ impl Directory { // show file url as relative to static path macro_rules! encode_file_url { ($path:ident) => { - utf8_percent_encode(&$path.to_string_lossy(), CONTROLS) + utf8_percent_encode(&$path, CONTROLS) }; } @@ -167,7 +178,10 @@ fn directory_listing( if dir.is_visible(&entry) { let entry = entry.unwrap(); let p = match entry.path().strip_prefix(&dir.path) { - Ok(p) => base.join(p), + Ok(p) if cfg!(windows) => { + base.join(p).to_string_lossy().replace("\\", "/") + } + Ok(p) => base.join(p).to_string_lossy().into_owned(), Err(_) => continue, }; @@ -231,10 +245,12 @@ pub struct Files { directory: PathBuf, index: Option, show_index: bool, + redirect_to_slash: bool, default: Rc>>>, renderer: Rc, mime_override: Option>, file_flags: named::Flags, + guards: Option>>, } impl Clone for Files { @@ -243,11 +259,13 @@ impl Clone for Files { directory: self.directory.clone(), index: self.index.clone(), show_index: self.show_index, + redirect_to_slash: self.redirect_to_slash, default: self.default.clone(), renderer: self.renderer.clone(), file_flags: self.file_flags, path: self.path.clone(), mime_override: self.mime_override.clone(), + guards: self.guards.clone(), } } } @@ -257,22 +275,28 @@ impl Files { /// /// `File` uses `ThreadPool` for blocking filesystem operations. /// By default pool with 5x threads of available cpus is used. - /// Pool size can be changed by setting ACTIX_CPU_POOL environment variable. + /// Pool size can be changed by setting ACTIX_THREADPOOL environment variable. pub fn new>(path: &str, dir: T) -> Files { - let dir = dir.into().canonicalize().unwrap_or_else(|_| PathBuf::new()); - if !dir.is_dir() { - log::error!("Specified path is not a directory: {:?}", dir); - } + let orig_dir = dir.into(); + let dir = match orig_dir.canonicalize() { + Ok(canon_dir) => canon_dir, + Err(_) => { + log::error!("Specified path is not a directory: {:?}", orig_dir); + PathBuf::new() + } + }; Files { path: path.to_string(), directory: dir, index: None, show_index: false, + redirect_to_slash: false, default: Rc::new(RefCell::new(None)), renderer: Rc::new(directory_listing), mime_override: None, file_flags: named::Flags::default(), + guards: None, } } @@ -284,12 +308,19 @@ impl Files { self } + /// Redirects to a slash-ended path when browsing a directory. + /// + /// By default never redirect. + pub fn redirect_to_slash_directory(mut self) -> Self { + self.redirect_to_slash = true; + self + } + /// Set custom directory renderer pub fn files_listing_renderer(mut self, f: F) -> Self where - for<'r, 's> F: - Fn(&'r Directory, &'s HttpRequest) -> Result - + 'static, + for<'r, 's> F: Fn(&'r Directory, &'s HttpRequest) -> Result + + 'static, { self.renderer = Rc::new(f); self @@ -331,6 +362,15 @@ impl Files { self } + /// Specifies custom guards to use for directory listings and files. + /// + /// Default behaviour allows GET and HEAD. + #[inline] + pub fn use_guards(mut self, guards: G) -> Self { + self.guards = Some(Rc::new(Box::new(guards))); + self + } + /// Disable `Content-Disposition` header. /// /// By default Content-Disposition` header is enabled. @@ -343,8 +383,8 @@ impl Files { /// Sets default handler which is used when no matched file could be found. pub fn default_handler(mut self, f: F) -> Self where - F: IntoNewService, - U: NewService< + F: IntoServiceFactory, + U: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -352,8 +392,8 @@ impl Files { > + 'static, { // create and configure default resource - self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::new_service( - f.into_new_service().map_init_err(|_| ()), + self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::factory( + f.into_factory().map_init_err(|_| ()), ))))); self @@ -374,38 +414,41 @@ impl HttpServiceFactory for Files { } } -impl NewService for Files { - type Config = (); +impl ServiceFactory for Files { type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; + type Config = (); type Service = FilesService; type InitError = (); - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; - fn new_service(&self, _: &()) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { let mut srv = FilesService { directory: self.directory.clone(), index: self.index.clone(), show_index: self.show_index, + redirect_to_slash: self.redirect_to_slash, default: None, renderer: self.renderer.clone(), mime_override: self.mime_override.clone(), file_flags: self.file_flags, + guards: self.guards.clone(), }; if let Some(ref default) = *self.default.borrow() { - Box::new( - default - .new_service(&()) - .map(move |default| { + default + .new_service(()) + .map(move |result| match result { + Ok(default) => { srv.default = Some(default); - srv - }) - .map_err(|_| ()), - ) + Ok(srv) + } + Err(_) => Err(()), + }) + .boxed_local() } else { - Box::new(ok(srv)) + ok(srv).boxed_local() } } } @@ -414,10 +457,12 @@ pub struct FilesService { directory: PathBuf, index: Option, show_index: bool, + redirect_to_slash: bool, default: Option, renderer: Rc, mime_override: Option>, file_flags: named::Flags, + guards: Option>>, } impl FilesService { @@ -426,14 +471,14 @@ impl FilesService { e: io::Error, req: ServiceRequest, ) -> Either< - FutureResult, - Box>, + Ready>, + LocalBoxFuture<'static, Result>, > { log::debug!("Files: Failed to handle {}: {}", req.path(), e); if let Some(ref mut default) = self.default { - default.call(req) + Either::Right(default.call(req)) } else { - Either::A(ok(req.error_response(e))) + Either::Left(ok(req.error_response(e))) } } } @@ -443,20 +488,37 @@ impl Service for FilesService { type Response = ServiceResponse; type Error = Error; type Future = Either< - FutureResult, - Box>, + Ready>, + LocalBoxFuture<'static, Result>, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, req: ServiceRequest) -> Self::Future { - // let (req, pl) = req.into_parts(); + let is_method_valid = if let Some(guard) = &self.guards { + // execute user defined guards + (**guard).check(req.head()) + } else { + // default behaviour + match *req.method() { + Method::HEAD | Method::GET => true, + _ => false, + } + }; + + if !is_method_valid { + return Either::Left(ok(req.into_response( + actix_web::HttpResponse::MethodNotAllowed() + .header(header::CONTENT_TYPE, "text/plain") + .body("Request did not meet this resource's requirements."), + ))); + } let real_path = match PathBufWrp::get_pathbuf(req.match_info().path()) { Ok(item) => item, - Err(e) => return Either::A(ok(req.error_response(e))), + Err(e) => return Either::Left(ok(req.error_response(e))), }; // full filepath @@ -467,6 +529,16 @@ impl Service for FilesService { if path.is_dir() { if let Some(ref redir_index) = self.index { + if self.redirect_to_slash && !req.path().ends_with('/') { + let redirect_to = format!("{}/", req.path()); + return Either::Left(ok(req.into_response( + HttpResponse::Found() + .header(header::LOCATION, redirect_to) + .body("") + .into_body(), + ))); + } + let path = path.join(redir_index); match NamedFile::open(path) { @@ -479,7 +551,7 @@ impl Service for FilesService { named_file.flags = self.file_flags; let (req, _) = req.into_parts(); - Either::A(ok(match named_file.respond_to(&req) { + Either::Left(ok(match named_file.into_response(&req) { Ok(item) => ServiceResponse::new(req, item), Err(e) => ServiceResponse::from_err(e, req), })) @@ -491,11 +563,11 @@ impl Service for FilesService { let (req, _) = req.into_parts(); let x = (self.renderer)(&dir, &req); match x { - Ok(resp) => Either::A(ok(resp)), - Err(e) => Either::A(ok(ServiceResponse::from_err(e, req))), + Ok(resp) => Either::Left(ok(resp)), + Err(e) => Either::Left(ok(ServiceResponse::from_err(e, req))), } } else { - Either::A(ok(ServiceResponse::from_err( + Either::Left(ok(ServiceResponse::from_err( FilesError::IsDirectory, req.into_parts().0, ))) @@ -511,11 +583,11 @@ impl Service for FilesService { named_file.flags = self.file_flags; let (req, _) = req.into_parts(); - match named_file.respond_to(&req) { + match named_file.into_response(&req) { Ok(item) => { - Either::A(ok(ServiceResponse::new(req.clone(), item))) + Either::Left(ok(ServiceResponse::new(req.clone(), item))) } - Err(e) => Either::A(ok(ServiceResponse::from_err(e, req))), + Err(e) => Either::Left(ok(ServiceResponse::from_err(e, req))), } } Err(e) => self.handle_err(e, req), @@ -558,11 +630,11 @@ impl PathBufWrp { impl FromRequest for PathBufWrp { type Error = UriSegmentError; - type Future = Result; + type Future = Ready>; type Config = (); fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { - PathBufWrp::get_pathbuf(req.match_info().path()) + ready(PathBufWrp::get_pathbuf(req.match_info().path())) } } @@ -573,19 +645,18 @@ mod tests { use std::ops::Add; use std::time::{Duration, SystemTime}; - use bytes::BytesMut; - use super::*; + use actix_web::guard; use actix_web::http::header::{ self, ContentDisposition, DispositionParam, DispositionType, }; use actix_web::http::{Method, StatusCode}; use actix_web::middleware::Compress; use actix_web::test::{self, TestRequest}; - use actix_web::App; + use actix_web::{App, Responder}; - #[test] - fn test_file_extension_to_mime() { + #[actix_rt::test] + async fn test_file_extension_to_mime() { let m = file_extension_to_mime("jpg"); assert_eq!(m, mime::IMAGE_JPEG); @@ -596,8 +667,8 @@ mod tests { assert_eq!(m, mime::APPLICATION_OCTET_STREAM); } - #[test] - fn test_if_modified_since_without_if_none_match() { + #[actix_rt::test] + async fn test_if_modified_since_without_if_none_match() { let file = NamedFile::open("Cargo.toml").unwrap(); let since = header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); @@ -605,12 +676,12 @@ mod tests { let req = TestRequest::default() .header(header::IF_MODIFIED_SINCE, since) .to_http_request(); - let resp = file.respond_to(&req).unwrap(); + let resp = file.respond_to(&req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_MODIFIED); } - #[test] - fn test_if_modified_since_with_if_none_match() { + #[actix_rt::test] + async fn test_if_modified_since_with_if_none_match() { let file = NamedFile::open("Cargo.toml").unwrap(); let since = header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); @@ -619,12 +690,12 @@ mod tests { .header(header::IF_NONE_MATCH, "miss_etag") .header(header::IF_MODIFIED_SINCE, since) .to_http_request(); - let resp = file.respond_to(&req).unwrap(); + let resp = file.respond_to(&req).await.unwrap(); assert_ne!(resp.status(), StatusCode::NOT_MODIFIED); } - #[test] - fn test_named_file_text() { + #[actix_rt::test] + async fn test_named_file_text() { assert!(NamedFile::open("test--").is_err()); let mut file = NamedFile::open("Cargo.toml").unwrap(); { @@ -636,7 +707,7 @@ mod tests { } let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); + let resp = file.respond_to(&req).await.unwrap(); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/x-toml" @@ -647,8 +718,8 @@ mod tests { ); } - #[test] - fn test_named_file_content_disposition() { + #[actix_rt::test] + async fn test_named_file_content_disposition() { assert!(NamedFile::open("test--").is_err()); let mut file = NamedFile::open("Cargo.toml").unwrap(); { @@ -660,7 +731,7 @@ mod tests { } let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); + let resp = file.respond_to(&req).await.unwrap(); assert_eq!( resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), "inline; filename=\"Cargo.toml\"" @@ -670,12 +741,37 @@ mod tests { .unwrap() .disable_content_disposition(); let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); + let resp = file.respond_to(&req).await.unwrap(); assert!(resp.headers().get(header::CONTENT_DISPOSITION).is_none()); } - #[test] - fn test_named_file_set_content_type() { + #[actix_rt::test] + async fn test_named_file_non_ascii_file_name() { + let mut file = + NamedFile::from_file(File::open("Cargo.toml").unwrap(), "貨物.toml") + .unwrap(); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } + + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/x-toml" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"貨物.toml\"; filename*=UTF-8''%E8%B2%A8%E7%89%A9.toml" + ); + } + + #[actix_rt::test] + async fn test_named_file_set_content_type() { let mut file = NamedFile::open("Cargo.toml") .unwrap() .set_content_type(mime::TEXT_XML); @@ -688,7 +784,7 @@ mod tests { } let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); + let resp = file.respond_to(&req).await.unwrap(); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/xml" @@ -699,8 +795,8 @@ mod tests { ); } - #[test] - fn test_named_file_image() { + #[actix_rt::test] + async fn test_named_file_image() { let mut file = NamedFile::open("tests/test.png").unwrap(); { file.file(); @@ -711,7 +807,7 @@ mod tests { } let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); + let resp = file.respond_to(&req).await.unwrap(); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), "image/png" @@ -722,8 +818,8 @@ mod tests { ); } - #[test] - fn test_named_file_image_attachment() { + #[actix_rt::test] + async fn test_named_file_image_attachment() { let cd = ContentDisposition { disposition: DispositionType::Attachment, parameters: vec![DispositionParam::Filename(String::from("test.png"))], @@ -740,7 +836,7 @@ mod tests { } let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); + let resp = file.respond_to(&req).await.unwrap(); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), "image/png" @@ -751,8 +847,8 @@ mod tests { ); } - #[test] - fn test_named_file_binary() { + #[actix_rt::test] + async fn test_named_file_binary() { let mut file = NamedFile::open("tests/test.binary").unwrap(); { file.file(); @@ -763,7 +859,7 @@ mod tests { } let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); + let resp = file.respond_to(&req).await.unwrap(); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), "application/octet-stream" @@ -774,8 +870,8 @@ mod tests { ); } - #[test] - fn test_named_file_status_code_text() { + #[actix_rt::test] + async fn test_named_file_status_code_text() { let mut file = NamedFile::open("Cargo.toml") .unwrap() .set_status_code(StatusCode::NOT_FOUND); @@ -788,7 +884,7 @@ mod tests { } let req = TestRequest::default().to_http_request(); - let resp = file.respond_to(&req).unwrap(); + let resp = file.respond_to(&req).await.unwrap(); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/x-toml" @@ -800,8 +896,8 @@ mod tests { assert_eq!(resp.status(), StatusCode::NOT_FOUND); } - #[test] - fn test_mime_override() { + #[actix_rt::test] + async fn test_mime_override() { fn all_attachment(_: &mime::Name) -> DispositionType { DispositionType::Attachment } @@ -812,10 +908,11 @@ mod tests { .mime_override(all_attachment) .index_file("Cargo.toml"), ), - ); + ) + .await; let request = TestRequest::get().uri("/").to_request(); - let response = test::call_service(&mut srv, request); + let response = test::call_service(&mut srv, request).await; assert_eq!(response.status(), StatusCode::OK); let content_disposition = response @@ -828,18 +925,19 @@ mod tests { assert_eq!(content_disposition, "attachment; filename=\"Cargo.toml\""); } - #[test] - fn test_named_file_ranges_status_code() { + #[actix_rt::test] + async fn test_named_file_ranges_status_code() { let mut srv = test::init_service( App::new().service(Files::new("/test", ".").index_file("Cargo.toml")), - ); + ) + .await; // Valid range header let request = TestRequest::get() .uri("/t%65st/Cargo.toml") .header(header::RANGE, "bytes=10-20") .to_request(); - let response = test::call_service(&mut srv, request); + let response = test::call_service(&mut srv, request).await; assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT); // Invalid range header @@ -847,16 +945,17 @@ mod tests { .uri("/t%65st/Cargo.toml") .header(header::RANGE, "bytes=1-0") .to_request(); - let response = test::call_service(&mut srv, request); + let response = test::call_service(&mut srv, request).await; assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE); } - #[test] - fn test_named_file_content_range_headers() { + #[actix_rt::test] + async fn test_named_file_content_range_headers() { let mut srv = test::init_service( App::new().service(Files::new("/test", ".").index_file("tests/test.binary")), - ); + ) + .await; // Valid range header let request = TestRequest::get() @@ -864,7 +963,7 @@ mod tests { .header(header::RANGE, "bytes=10-20") .to_request(); - let response = test::call_service(&mut srv, request); + let response = test::call_service(&mut srv, request).await; let contentrange = response .headers() .get(header::CONTENT_RANGE) @@ -879,7 +978,7 @@ mod tests { .uri("/t%65st/tests/test.binary") .header(header::RANGE, "bytes=10-5") .to_request(); - let response = test::call_service(&mut srv, request); + let response = test::call_service(&mut srv, request).await; let contentrange = response .headers() @@ -891,20 +990,21 @@ mod tests { assert_eq!(contentrange, "bytes */100"); } - #[test] - fn test_named_file_content_length_headers() { + #[actix_rt::test] + async fn test_named_file_content_length_headers() { // use actix_web::body::{MessageBody, ResponseBody}; let mut srv = test::init_service( App::new().service(Files::new("test", ".").index_file("tests/test.binary")), - ); + ) + .await; // Valid range header let request = TestRequest::get() .uri("/t%65st/tests/test.binary") .header(header::RANGE, "bytes=10-20") .to_request(); - let _response = test::call_service(&mut srv, request); + let _response = test::call_service(&mut srv, request).await; // let contentlength = response // .headers() @@ -919,7 +1019,7 @@ mod tests { .uri("/t%65st/tests/test.binary") .header(header::RANGE, "bytes=10-8") .to_request(); - let response = test::call_service(&mut srv, request); + let response = test::call_service(&mut srv, request).await; assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE); // Without range header @@ -927,7 +1027,7 @@ mod tests { .uri("/t%65st/tests/test.binary") // .no_default_headers() .to_request(); - let _response = test::call_service(&mut srv, request); + let _response = test::call_service(&mut srv, request).await; // let contentlength = response // .headers() @@ -941,7 +1041,7 @@ mod tests { let request = TestRequest::get() .uri("/t%65st/tests/test.binary") .to_request(); - let mut response = test::call_service(&mut srv, request); + let response = test::call_service(&mut srv, request).await; // with enabled compression // { @@ -954,28 +1054,24 @@ mod tests { // assert_eq!(te, "chunked"); // } - let bytes = - test::block_on(response.take_body().fold(BytesMut::new(), |mut b, c| { - b.extend(c); - Ok::<_, Error>(b) - })) - .unwrap(); + let bytes = test::read_body(response).await; let data = Bytes::from(fs::read("tests/test.binary").unwrap()); - assert_eq!(bytes.freeze(), data); + assert_eq!(bytes, data); } - #[test] - fn test_head_content_length_headers() { + #[actix_rt::test] + async fn test_head_content_length_headers() { let mut srv = test::init_service( App::new().service(Files::new("test", ".").index_file("tests/test.binary")), - ); + ) + .await; // Valid range header let request = TestRequest::default() .method(Method::HEAD) .uri("/t%65st/tests/test.binary") .to_request(); - let _response = test::call_service(&mut srv, request); + let _response = test::call_service(&mut srv, request).await; // TODO: fix check // let contentlength = response @@ -987,77 +1083,100 @@ mod tests { // assert_eq!(contentlength, "100"); } - #[test] - fn test_static_files_with_spaces() { + #[actix_rt::test] + async fn test_static_files_with_spaces() { let mut srv = test::init_service( App::new().service(Files::new("/", ".").index_file("Cargo.toml")), - ); + ) + .await; let request = TestRequest::get() .uri("/tests/test%20space.binary") .to_request(); - let mut response = test::call_service(&mut srv, request); + let response = test::call_service(&mut srv, request).await; assert_eq!(response.status(), StatusCode::OK); - let bytes = - test::block_on(response.take_body().fold(BytesMut::new(), |mut b, c| { - b.extend(c); - Ok::<_, Error>(b) - })) - .unwrap(); - + let bytes = test::read_body(response).await; let data = Bytes::from(fs::read("tests/test space.binary").unwrap()); - assert_eq!(bytes.freeze(), data); + assert_eq!(bytes, data); } - #[test] - fn test_named_file_not_allowed() { - let file = NamedFile::open("Cargo.toml").unwrap(); + #[actix_rt::test] + async fn test_files_not_allowed() { + let mut srv = test::init_service(App::new().service(Files::new("/", "."))).await; + let req = TestRequest::default() + .uri("/Cargo.toml") .method(Method::POST) - .to_http_request(); - let resp = file.respond_to(&req).unwrap(); + .to_request(); + + let resp = test::call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - let file = NamedFile::open("Cargo.toml").unwrap(); - let req = TestRequest::default().method(Method::PUT).to_http_request(); - let resp = file.respond_to(&req).unwrap(); + let mut srv = test::init_service(App::new().service(Files::new("/", "."))).await; + let req = TestRequest::default() + .method(Method::PUT) + .uri("/Cargo.toml") + .to_request(); + let resp = test::call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); } - #[test] - fn test_named_file_content_encoding() { + #[actix_rt::test] + async fn test_files_guards() { + let mut srv = test::init_service( + App::new().service(Files::new("/", ".").use_guards(guard::Post())), + ) + .await; + + let req = TestRequest::default() + .uri("/Cargo.toml") + .method(Method::POST) + .to_request(); + + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_named_file_content_encoding() { let mut srv = test::init_service(App::new().wrap(Compress::default()).service( web::resource("/").to(|| { - NamedFile::open("Cargo.toml") - .unwrap() - .set_content_encoding(header::ContentEncoding::Identity) + async { + NamedFile::open("Cargo.toml") + .unwrap() + .set_content_encoding(header::ContentEncoding::Identity) + } }), - )); + )) + .await; let request = TestRequest::get() .uri("/") .header(header::ACCEPT_ENCODING, "gzip") .to_request(); - let res = test::call_service(&mut srv, request); + let res = test::call_service(&mut srv, request).await; assert_eq!(res.status(), StatusCode::OK); assert!(!res.headers().contains_key(header::CONTENT_ENCODING)); } - #[test] - fn test_named_file_content_encoding_gzip() { + #[actix_rt::test] + async fn test_named_file_content_encoding_gzip() { let mut srv = test::init_service(App::new().wrap(Compress::default()).service( web::resource("/").to(|| { - NamedFile::open("Cargo.toml") - .unwrap() - .set_content_encoding(header::ContentEncoding::Gzip) + async { + NamedFile::open("Cargo.toml") + .unwrap() + .set_content_encoding(header::ContentEncoding::Gzip) + } }), - )); + )) + .await; let request = TestRequest::get() .uri("/") .header(header::ACCEPT_ENCODING, "gzip") .to_request(); - let res = test::call_service(&mut srv, request); + let res = test::call_service(&mut srv, request).await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.headers() @@ -1069,81 +1188,101 @@ mod tests { ); } - #[test] - fn test_named_file_allowed_method() { + #[actix_rt::test] + async fn test_named_file_allowed_method() { let req = TestRequest::default().method(Method::GET).to_http_request(); let file = NamedFile::open("Cargo.toml").unwrap(); - let resp = file.respond_to(&req).unwrap(); + let resp = file.respond_to(&req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn test_static_files() { + #[actix_rt::test] + async fn test_static_files() { let mut srv = test::init_service( App::new().service(Files::new("/", ".").show_files_listing()), - ); + ) + .await; let req = TestRequest::with_uri("/missing").to_request(); - let resp = test::call_service(&mut srv, req); + let resp = test::call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let mut srv = test::init_service(App::new().service(Files::new("/", "."))); + let mut srv = test::init_service(App::new().service(Files::new("/", "."))).await; let req = TestRequest::default().to_request(); - let resp = test::call_service(&mut srv, req); + let resp = test::call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::NOT_FOUND); let mut srv = test::init_service( App::new().service(Files::new("/", ".").show_files_listing()), - ); + ) + .await; let req = TestRequest::with_uri("/tests").to_request(); - let mut resp = test::call_service(&mut srv, req); + let resp = test::call_service(&mut srv, req).await; assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/html; charset=utf-8" ); - let bytes = - test::block_on(resp.take_body().fold(BytesMut::new(), |mut b, c| { - b.extend(c); - Ok::<_, Error>(b) - })) - .unwrap(); + let bytes = test::read_body(resp).await; assert!(format!("{:?}", bytes).contains("/tests/test.png")); } - #[test] - fn test_static_files_bad_directory() { + #[actix_rt::test] + async fn test_redirect_to_slash_directory() { + // should not redirect if no index + let mut srv = test::init_service( + App::new().service(Files::new("/", ".").redirect_to_slash_directory()), + ) + .await; + let req = TestRequest::with_uri("/tests").to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + // should redirect if index present + let mut srv = test::init_service( + App::new().service( + Files::new("/", ".") + .index_file("test.png") + .redirect_to_slash_directory(), + ), + ) + .await; + let req = TestRequest::with_uri("/tests").to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::FOUND); + + // should not redirect if the path is wrong + let req = TestRequest::with_uri("/not_existing").to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn test_static_files_bad_directory() { let _st: Files = Files::new("/", "missing"); let _st: Files = Files::new("/", "Cargo.toml"); } - #[test] - fn test_default_handler_file_missing() { - let mut st = test::block_on( - Files::new("/", ".") - .default_handler(|req: ServiceRequest| { - Ok(req.into_response(HttpResponse::Ok().body("default content"))) - }) - .new_service(&()), - ) - .unwrap(); + #[actix_rt::test] + async fn test_default_handler_file_missing() { + let mut st = Files::new("/", ".") + .default_handler(|req: ServiceRequest| { + ok(req.into_response(HttpResponse::Ok().body("default content"))) + }) + .new_service(()) + .await + .unwrap(); let req = TestRequest::with_uri("/missing").to_srv_request(); - let mut resp = test::call_service(&mut st, req); + let resp = test::call_service(&mut st, req).await; assert_eq!(resp.status(), StatusCode::OK); - let bytes = - test::block_on(resp.take_body().fold(BytesMut::new(), |mut b, c| { - b.extend(c); - Ok::<_, Error>(b) - })) - .unwrap(); - - assert_eq!(bytes.freeze(), Bytes::from_static(b"default content")); + let bytes = test::read_body(resp).await; + assert_eq!(bytes, Bytes::from_static(b"default content")); } - // #[test] - // fn test_serve_index() { + // #[actix_rt::test] + // async fn test_serve_index() { // let st = Files::new(".").index_file("test.binary"); // let req = TestRequest::default().uri("/tests").finish(); @@ -1188,8 +1327,8 @@ mod tests { // assert_eq!(resp.status(), StatusCode::NOT_FOUND); // } - // #[test] - // fn test_serve_index_nested() { + // #[actix_rt::test] + // async fn test_serve_index_nested() { // let st = Files::new(".").index_file("mod.rs"); // let req = TestRequest::default().uri("/src/client").finish(); // let resp = st.handle(&req).respond_to(&req).unwrap(); @@ -1205,7 +1344,7 @@ mod tests { // ); // } - // #[test] + // #[actix_rt::test] // fn integration_serve_index() { // let mut srv = test::TestServer::with_factory(|| { // App::new().handler( @@ -1238,7 +1377,7 @@ mod tests { // assert_eq!(response.status(), StatusCode::NOT_FOUND); // } - // #[test] + // #[actix_rt::test] // fn integration_percent_encoded() { // let mut srv = test::TestServer::with_factory(|| { // App::new().handler( @@ -1256,8 +1395,8 @@ mod tests { // assert_eq!(response.status(), StatusCode::OK); // } - #[test] - fn test_path_buf() { + #[actix_rt::test] + async fn test_path_buf() { assert_eq!( PathBufWrp::get_pathbuf("/test/.tt").map(|t| t.0), Err(UriSegmentError::BadStart('.')) diff --git a/actix-files/src/named.rs b/actix-files/src/named.rs index f548a7a1b..fdb055998 100644 --- a/actix-files/src/named.rs +++ b/actix-files/src/named.rs @@ -12,12 +12,13 @@ use mime; use mime_guess::from_path; use actix_http::body::SizedStream; +use actix_web::dev::BodyEncoding; use actix_web::http::header::{ - self, ContentDisposition, DispositionParam, DispositionType, + self, Charset, ContentDisposition, DispositionParam, DispositionType, ExtendedValue, }; -use actix_web::http::{ContentEncoding, Method, StatusCode}; -use actix_web::middleware::BodyEncoding; +use actix_web::http::{ContentEncoding, StatusCode}; use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, Responder}; +use futures::future::{ready, Ready}; use crate::range::HttpRange; use crate::ChunkedReadFile; @@ -93,9 +94,18 @@ impl NamedFile { mime::IMAGE | mime::TEXT | mime::VIDEO => DispositionType::Inline, _ => DispositionType::Attachment, }; + let mut parameters = + vec![DispositionParam::Filename(String::from(filename.as_ref()))]; + if !filename.is_ascii() { + parameters.push(DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Ext(String::from("UTF-8")), + language_tag: None, + value: filename.into_owned().into_bytes(), + })) + } let cd = ContentDisposition { disposition: disposition_type, - parameters: vec![DispositionParam::Filename(filename.into_owned())], + parameters: parameters, }; (ct, cd) }; @@ -246,62 +256,8 @@ impl NamedFile { pub(crate) fn last_modified(&self) -> Option { self.modified.map(|mtime| mtime.into()) } -} -impl Deref for NamedFile { - type Target = File; - - fn deref(&self) -> &File { - &self.file - } -} - -impl DerefMut for NamedFile { - fn deref_mut(&mut self) -> &mut File { - &mut self.file - } -} - -/// Returns true if `req` has no `If-Match` header or one which matches `etag`. -fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { - match req.get_header::() { - None | Some(header::IfMatch::Any) => true, - Some(header::IfMatch::Items(ref items)) => { - if let Some(some_etag) = etag { - for item in items { - if item.strong_eq(some_etag) { - return true; - } - } - } - false - } - } -} - -/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`. -fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { - match req.get_header::() { - Some(header::IfNoneMatch::Any) => false, - Some(header::IfNoneMatch::Items(ref items)) => { - if let Some(some_etag) = etag { - for item in items { - if item.weak_eq(some_etag) { - return false; - } - } - } - true - } - None => true, - } -} - -impl Responder for NamedFile { - type Error = Error; - type Future = Result; - - fn respond_to(self, req: &HttpRequest) -> Self::Future { + pub fn into_response(self, req: &HttpRequest) -> Result { if self.status_code != StatusCode::OK { let mut resp = HttpResponse::build(self.status_code); resp.set(header::ContentType(self.content_type.clone())) @@ -324,16 +280,6 @@ impl Responder for NamedFile { return Ok(resp.streaming(reader)); } - match *req.method() { - Method::HEAD | Method::GET => (), - _ => { - return Ok(HttpResponse::MethodNotAllowed() - .header(header::CONTENT_TYPE, "text/plain") - .header(header::ALLOW, "GET, HEAD") - .body("This resource only supports GET and HEAD.")); - } - } - let etag = if self.flags.contains(Flags::ETAG) { self.etag() } else { @@ -443,8 +389,67 @@ impl Responder for NamedFile { counter: 0, }; if offset != 0 || length != self.md.len() { - return Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader)); - }; - Ok(resp.body(SizedStream::new(length, reader))) + Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader)) + } else { + Ok(resp.body(SizedStream::new(length, reader))) + } + } +} + +impl Deref for NamedFile { + type Target = File; + + fn deref(&self) -> &File { + &self.file + } +} + +impl DerefMut for NamedFile { + fn deref_mut(&mut self) -> &mut File { + &mut self.file + } +} + +/// Returns true if `req` has no `If-Match` header or one which matches `etag`. +fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { + match req.get_header::() { + None | Some(header::IfMatch::Any) => true, + Some(header::IfMatch::Items(ref items)) => { + if let Some(some_etag) = etag { + for item in items { + if item.strong_eq(some_etag) { + return true; + } + } + } + false + } + } +} + +/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`. +fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { + match req.get_header::() { + Some(header::IfNoneMatch::Any) => false, + Some(header::IfNoneMatch::Items(ref items)) => { + if let Some(some_etag) = etag { + for item in items { + if item.weak_eq(some_etag) { + return false; + } + } + } + true + } + None => true, + } +} + +impl Responder for NamedFile { + type Error = Error; + type Future = Ready>; + + fn respond_to(self, req: &HttpRequest) -> Self::Future { + ready(self.into_response(req)) } } diff --git a/actix-framed/Cargo.toml b/actix-framed/Cargo.toml index 321041c7e..7e322e1d4 100644 --- a/actix-framed/Cargo.toml +++ b/actix-framed/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-framed" -version = "0.2.1" +version = "0.3.0" authors = ["Nikolay Kim "] description = "Actix framed app server" readme = "README.md" @@ -13,26 +13,25 @@ categories = ["network-programming", "asynchronous", "web-programming::websocket"] license = "MIT/Apache-2.0" edition = "2018" -workspace =".." [lib] name = "actix_framed" path = "src/lib.rs" [dependencies] -actix-codec = "0.1.2" -actix-service = "0.4.1" -actix-router = "0.1.2" -actix-rt = "0.2.2" -actix-http = "0.2.7" -actix-server-config = "0.1.2" +actix-codec = "0.2.0" +actix-service = "1.0.1" +actix-router = "0.2.1" +actix-rt = "1.0.0" +actix-http = "1.0.1" -bytes = "0.4" -futures = "0.1.25" +bytes = "0.5.3" +futures = "0.3.1" +pin-project = "0.4.6" log = "0.4" [dev-dependencies] -actix-server = { version = "0.6.0", features=["ssl"] } -actix-connect = { version = "0.2.0", features=["ssl"] } -actix-http-test = { version = "0.2.4", features=["ssl"] } -actix-utils = "0.4.4" +actix-server = "1.0.0" +actix-connect = { version = "1.0.0", features=["openssl"] } +actix-http-test = { version = "1.0.0", features=["openssl"] } +actix-utils = "1.0.3" diff --git a/actix-framed/changes.md b/actix-framed/changes.md index 6e67e00d8..41c7aed0e 100644 --- a/actix-framed/changes.md +++ b/actix-framed/changes.md @@ -1,5 +1,9 @@ # Changes +## [0.3.0] - 2019-12-25 + +* Migrate to actix-http 1.0 + ## [0.2.1] - 2019-07-20 * Remove unneeded actix-utils dependency diff --git a/actix-framed/src/app.rs b/actix-framed/src/app.rs index ad5b1ec26..e4b91e6c4 100644 --- a/actix-framed/src/app.rs +++ b/actix-framed/src/app.rs @@ -1,21 +1,23 @@ +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::h1::{Codec, SendResponse}; use actix_http::{Error, Request, Response}; use actix_router::{Path, Router, Url}; -use actix_server_config::ServerConfig; -use actix_service::{IntoNewService, NewService, Service}; -use futures::{Async, Future, Poll}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; +use futures::future::{ok, FutureExt, LocalBoxFuture}; use crate::helpers::{BoxedHttpNewService, BoxedHttpService, HttpNewService}; use crate::request::FramedRequest; use crate::state::State; -type BoxedResponse = Box>; +type BoxedResponse = LocalBoxFuture<'static, Result<(), Error>>; pub trait HttpServiceFactory { - type Factory: NewService; + type Factory: ServiceFactory; fn path(&self) -> &str; @@ -48,19 +50,19 @@ impl FramedApp { pub fn service(mut self, factory: U) -> Self where U: HttpServiceFactory, - U::Factory: NewService< + U::Factory: ServiceFactory< Config = (), Request = FramedRequest, Response = (), Error = Error, InitError = (), > + 'static, - ::Future: 'static, - ::Service: Service< + ::Future: 'static, + ::Service: Service< Request = FramedRequest, Response = (), Error = Error, - Future = Box>, + Future = LocalBoxFuture<'static, Result<(), Error>>, >, { let path = factory.path().to_string(); @@ -70,12 +72,12 @@ impl FramedApp { } } -impl IntoNewService> for FramedApp +impl IntoServiceFactory> for FramedApp where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, S: 'static, { - fn into_new_service(self) -> FramedAppFactory { + fn into_factory(self) -> FramedAppFactory { FramedAppFactory { state: self.state, services: Rc::new(self.services), @@ -89,12 +91,12 @@ pub struct FramedAppFactory { services: Rc>)>>, } -impl NewService for FramedAppFactory +impl ServiceFactory for FramedAppFactory where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, S: 'static, { - type Config = ServerConfig; + type Config = (); type Request = (Request, Framed); type Response = (); type Error = Error; @@ -102,7 +104,7 @@ where type Service = FramedAppService; type Future = CreateService; - fn new_service(&self, _: &ServerConfig) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { CreateService { fut: self .services @@ -110,7 +112,7 @@ where .map(|(path, service)| { CreateServiceItem::Future( Some(path.clone()), - service.new_service(&()), + service.new_service(()), ) }) .collect(), @@ -128,28 +130,30 @@ pub struct CreateService { enum CreateServiceItem { Future( Option, - Box>, Error = ()>>, + LocalBoxFuture<'static, Result>, ()>>, ), Service(String, BoxedHttpService>), } impl Future for CreateService where - T: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite + Unpin, { - type Item = FramedAppService; - type Error = (); + type Output = Result, ()>; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut done = true; // poll http services for item in &mut self.fut { let res = match item { CreateServiceItem::Future(ref mut path, ref mut fut) => { - match fut.poll()? { - Async::Ready(service) => Some((path.take().unwrap(), service)), - Async::NotReady => { + match Pin::new(fut).poll(cx) { + Poll::Ready(Ok(service)) => { + Some((path.take().unwrap(), service)) + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { done = false; None } @@ -176,12 +180,12 @@ where } router }); - Ok(Async::Ready(FramedAppService { + Poll::Ready(Ok(FramedAppService { router: router.finish(), state: self.state.clone(), })) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -193,15 +197,15 @@ pub struct FramedAppService { impl Service for FramedAppService where - T: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite + Unpin, { type Request = (Request, Framed); type Response = (); type Error = Error; type Future = BoxedResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, (req, framed): (Request, Framed)) -> Self::Future { @@ -210,8 +214,8 @@ where if let Some((srv, _info)) = self.router.recognize_mut(&mut path) { return srv.call(FramedRequest::new(req, framed, path, self.state.clone())); } - Box::new( - SendResponse::new(framed, Response::NotFound().finish()).then(|_| Ok(())), - ) + SendResponse::new(framed, Response::NotFound().finish()) + .then(|_| ok(())) + .boxed_local() } } diff --git a/actix-framed/src/helpers.rs b/actix-framed/src/helpers.rs index b343301f3..29492e45b 100644 --- a/actix-framed/src/helpers.rs +++ b/actix-framed/src/helpers.rs @@ -1,36 +1,38 @@ +use std::task::{Context, Poll}; + use actix_http::Error; -use actix_service::{NewService, Service}; -use futures::{Future, Poll}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{FutureExt, LocalBoxFuture}; pub(crate) type BoxedHttpService = Box< dyn Service< Request = Req, Response = (), Error = Error, - Future = Box>, + Future = LocalBoxFuture<'static, Result<(), Error>>, >, >; pub(crate) type BoxedHttpNewService = Box< - dyn NewService< + dyn ServiceFactory< Config = (), Request = Req, Response = (), Error = Error, InitError = (), Service = BoxedHttpService, - Future = Box, Error = ()>>, + Future = LocalBoxFuture<'static, Result, ()>>, >, >; -pub(crate) struct HttpNewService(T); +pub(crate) struct HttpNewService(T); impl HttpNewService where - T: NewService, + T: ServiceFactory, T::Response: 'static, T::Future: 'static, - T::Service: Service>> + 'static, + T::Service: Service>> + 'static, ::Future: 'static, { pub fn new(service: T) -> Self { @@ -38,12 +40,12 @@ where } } -impl NewService for HttpNewService +impl ServiceFactory for HttpNewService where - T: NewService, + T: ServiceFactory, T::Request: 'static, T::Future: 'static, - T::Service: Service>> + 'static, + T::Service: Service>> + 'static, ::Future: 'static, { type Config = (); @@ -52,13 +54,19 @@ where type Error = Error; type InitError = (); type Service = BoxedHttpService; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; - fn new_service(&self, _: &()) -> Self::Future { - Box::new(self.0.new_service(&()).map_err(|_| ()).and_then(|service| { - let service: BoxedHttpService<_> = Box::new(HttpServiceWrapper { service }); - Ok(service) - })) + fn new_service(&self, _: ()) -> Self::Future { + let fut = self.0.new_service(()); + + async move { + fut.await.map_err(|_| ()).map(|service| { + let service: BoxedHttpService<_> = + Box::new(HttpServiceWrapper { service }); + service + }) + } + .boxed_local() } } @@ -70,7 +78,7 @@ impl Service for HttpServiceWrapper where T: Service< Response = (), - Future = Box>, + Future = LocalBoxFuture<'static, Result<(), Error>>, Error = Error, >, T::Request: 'static, @@ -78,10 +86,10 @@ where type Request = T::Request; type Response = (); type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result<(), Error>>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: Self::Request) -> Self::Future { diff --git a/actix-framed/src/request.rs b/actix-framed/src/request.rs index bdcdd7026..1872dcc25 100644 --- a/actix-framed/src/request.rs +++ b/actix-framed/src/request.rs @@ -123,7 +123,9 @@ impl FramedRequest { #[cfg(test)] mod tests { - use actix_http::http::{HeaderName, HeaderValue, HttpTryFrom}; + use std::convert::TryFrom; + + use actix_http::http::{HeaderName, HeaderValue}; use actix_http::test::{TestBuffer, TestRequest}; use super::*; diff --git a/actix-framed/src/route.rs b/actix-framed/src/route.rs index 5beb24165..793f46273 100644 --- a/actix-framed/src/route.rs +++ b/actix-framed/src/route.rs @@ -1,11 +1,12 @@ use std::fmt; +use std::future::Future; use std::marker::PhantomData; +use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_http::{http::Method, Error}; -use actix_service::{NewService, Service}; -use futures::future::{ok, FutureResult}; -use futures::{Async, Future, IntoFuture, Poll}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use log::error; use crate::app::HttpServiceFactory; @@ -15,11 +16,11 @@ use crate::request::FramedRequest; /// /// Route uses builder-like pattern for configuration. /// If handler is not explicitly set, default *404 Not Found* handler is used. -pub struct FramedRoute { +pub struct FramedRoute { handler: F, pattern: String, methods: Vec, - state: PhantomData<(Io, S, R)>, + state: PhantomData<(Io, S, R, E)>, } impl FramedRoute { @@ -53,12 +54,12 @@ impl FramedRoute { self } - pub fn to(self, handler: F) -> FramedRoute + pub fn to(self, handler: F) -> FramedRoute where F: FnMut(FramedRequest) -> R, - R: IntoFuture, - R::Future: 'static, - R::Error: fmt::Debug, + R: Future> + 'static, + + E: fmt::Debug, { FramedRoute { handler, @@ -69,15 +70,14 @@ impl FramedRoute { } } -impl HttpServiceFactory for FramedRoute +impl HttpServiceFactory for FramedRoute where Io: AsyncRead + AsyncWrite + 'static, F: FnMut(FramedRequest) -> R + Clone, - R: IntoFuture, - R::Future: 'static, - R::Error: fmt::Display, + R: Future> + 'static, + E: fmt::Display, { - type Factory = FramedRouteFactory; + type Factory = FramedRouteFactory; fn path(&self) -> &str { &self.pattern @@ -92,29 +92,28 @@ where } } -pub struct FramedRouteFactory { +pub struct FramedRouteFactory { handler: F, methods: Vec, - _t: PhantomData<(Io, S, R)>, + _t: PhantomData<(Io, S, R, E)>, } -impl NewService for FramedRouteFactory +impl ServiceFactory for FramedRouteFactory where Io: AsyncRead + AsyncWrite + 'static, F: FnMut(FramedRequest) -> R + Clone, - R: IntoFuture, - R::Future: 'static, - R::Error: fmt::Display, + R: Future> + 'static, + E: fmt::Display, { type Config = (); type Request = FramedRequest; type Response = (); type Error = Error; type InitError = (); - type Service = FramedRouteService; - type Future = FutureResult; + type Service = FramedRouteService; + type Future = Ready>; - fn new_service(&self, _: &()) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { ok(FramedRouteService { handler: self.handler.clone(), methods: self.methods.clone(), @@ -123,35 +122,38 @@ where } } -pub struct FramedRouteService { +pub struct FramedRouteService { handler: F, methods: Vec, - _t: PhantomData<(Io, S, R)>, + _t: PhantomData<(Io, S, R, E)>, } -impl Service for FramedRouteService +impl Service for FramedRouteService where Io: AsyncRead + AsyncWrite + 'static, F: FnMut(FramedRequest) -> R + Clone, - R: IntoFuture, - R::Future: 'static, - R::Error: fmt::Display, + R: Future> + 'static, + E: fmt::Display, { type Request = FramedRequest; type Response = (); type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result<(), Error>>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, req: FramedRequest) -> Self::Future { - Box::new((self.handler)(req).into_future().then(|res| { + let fut = (self.handler)(req); + + async move { + let res = fut.await; if let Err(e) = res { error!("Error in request handler: {}", e); } Ok(()) - })) + } + .boxed_local() } } diff --git a/actix-framed/src/service.rs b/actix-framed/src/service.rs index fbbc9fbef..92393ca75 100644 --- a/actix-framed/src/service.rs +++ b/actix-framed/src/service.rs @@ -1,4 +1,6 @@ use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::body::BodySize; @@ -6,9 +8,9 @@ use actix_http::error::ResponseError; use actix_http::h1::{Codec, Message}; use actix_http::ws::{verify_handshake, HandshakeError}; use actix_http::{Request, Response}; -use actix_service::{NewService, Service}; -use futures::future::{ok, Either, FutureResult}; -use futures::{Async, Future, IntoFuture, Poll, Sink}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{err, ok, Either, Ready}; +use futures::Future; /// Service that verifies incoming request if it is valid websocket /// upgrade request. In case of error returns `HandshakeError` @@ -22,16 +24,16 @@ impl Default for VerifyWebSockets { } } -impl NewService for VerifyWebSockets { +impl ServiceFactory for VerifyWebSockets { type Config = C; type Request = (Request, Framed); type Response = (Request, Framed); type Error = (HandshakeError, Framed); type InitError = (); type Service = VerifyWebSockets; - type Future = FutureResult; + type Future = Ready>; - fn new_service(&self, _: &C) -> Self::Future { + fn new_service(&self, _: C) -> Self::Future { ok(VerifyWebSockets { _t: PhantomData }) } } @@ -40,16 +42,16 @@ impl Service for VerifyWebSockets { type Request = (Request, Framed); type Response = (Request, Framed); type Error = (HandshakeError, Framed); - type Future = FutureResult; + type Future = Ready>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, (req, framed): (Request, Framed)) -> Self::Future { match verify_handshake(req.head()) { - Err(e) => Err((e, framed)).into_future(), - Ok(_) => Ok((req, framed)).into_future(), + Err(e) => err((e, framed)), + Ok(_) => ok((req, framed)), } } } @@ -67,9 +69,9 @@ where } } -impl NewService for SendError +impl ServiceFactory for SendError where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, R: 'static, E: ResponseError + 'static, { @@ -79,34 +81,34 @@ where type Error = (E, Framed); type InitError = (); type Service = SendError; - type Future = FutureResult; + type Future = Ready>; - fn new_service(&self, _: &C) -> Self::Future { + fn new_service(&self, _: C) -> Self::Future { ok(SendError(PhantomData)) } } impl Service for SendError where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, R: 'static, E: ResponseError + 'static, { type Request = Result)>; type Response = R; type Error = (E, Framed); - type Future = Either)>, SendErrorFut>; + type Future = Either)>>, SendErrorFut>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, req: Result)>) -> Self::Future { match req { - Ok(r) => Either::A(ok(r)), + Ok(r) => Either::Left(ok(r)), Err((e, framed)) => { let res = e.error_response().drop_body(); - Either::B(SendErrorFut { + Either::Right(SendErrorFut { framed: Some(framed), res: Some((res, BodySize::Empty).into()), err: Some(e), @@ -117,6 +119,7 @@ where } } +#[pin_project::pin_project] pub struct SendErrorFut { res: Option, BodySize)>>, framed: Option>, @@ -127,23 +130,27 @@ pub struct SendErrorFut { impl Future for SendErrorFut where E: ResponseError, - T: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite + Unpin, { - type Item = R; - type Error = (E, Framed); + type Output = Result)>; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { if let Some(res) = self.res.take() { - if self.framed.as_mut().unwrap().force_send(res).is_err() { - return Err((self.err.take().unwrap(), self.framed.take().unwrap())); + if self.framed.as_mut().unwrap().write(res).is_err() { + return Poll::Ready(Err(( + self.err.take().unwrap(), + self.framed.take().unwrap(), + ))); } } - match self.framed.as_mut().unwrap().poll_complete() { - Ok(Async::Ready(_)) => { - Err((self.err.take().unwrap(), self.framed.take().unwrap())) + match self.framed.as_mut().unwrap().flush(cx) { + Poll::Ready(Ok(_)) => { + Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap()))) } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(_) => Err((self.err.take().unwrap(), self.framed.take().unwrap())), + Poll::Ready(Err(_)) => { + Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap()))) + } + Poll::Pending => Poll::Pending, } } } diff --git a/actix-framed/src/test.rs b/actix-framed/src/test.rs index 3bc828df4..b8029531e 100644 --- a/actix-framed/src/test.rs +++ b/actix-framed/src/test.rs @@ -1,12 +1,13 @@ //! Various helpers for Actix applications to use during testing. +use std::convert::TryFrom; +use std::future::Future; + use actix_codec::Framed; use actix_http::h1::Codec; use actix_http::http::header::{Header, HeaderName, IntoHeaderValue}; -use actix_http::http::{HttpTryFrom, Method, Uri, Version}; +use actix_http::http::{Error as HttpError, Method, Uri, Version}; use actix_http::test::{TestBuffer, TestRequest as HttpTestRequest}; use actix_router::{Path, Url}; -use actix_rt::Runtime; -use futures::IntoFuture; use crate::{FramedRequest, State}; @@ -41,7 +42,8 @@ impl TestRequest<()> { /// Create TestRequest and set header pub fn with_header(key: K, value: V) -> Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { Self::default().header(key, value) @@ -96,7 +98,8 @@ impl TestRequest { /// Set a header pub fn header(mut self, key: K, value: V) -> Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { self.req.header(key, value); @@ -118,13 +121,12 @@ impl TestRequest { } /// This method generates `FramedRequest` instance and executes async handler - pub fn run(self, f: F) -> Result + pub async fn run(self, f: F) -> Result where F: FnOnce(FramedRequest) -> R, - R: IntoFuture, + R: Future>, { - let mut rt = Runtime::new().unwrap(); - rt.block_on(f(self.finish()).into_future()) + f(self.finish()).await } } diff --git a/actix-framed/tests/test_server.rs b/actix-framed/tests/test_server.rs index 00f6a97d8..7d6fc08a6 100644 --- a/actix-framed/tests/test_server.rs +++ b/actix-framed/tests/test_server.rs @@ -1,141 +1,159 @@ use actix_codec::{AsyncRead, AsyncWrite}; use actix_http::{body, http::StatusCode, ws, Error, HttpService, Response}; -use actix_http_test::TestServer; -use actix_service::{IntoNewService, NewService}; -use actix_utils::framed::FramedTransport; -use bytes::{Bytes, BytesMut}; -use futures::future::{self, ok}; -use futures::{Future, Sink, Stream}; +use actix_http_test::test_server; +use actix_service::{pipeline_factory, IntoServiceFactory, ServiceFactory}; +use actix_utils::framed::Dispatcher; +use bytes::Bytes; +use futures::{future, SinkExt, StreamExt}; use actix_framed::{FramedApp, FramedRequest, FramedRoute, SendError, VerifyWebSockets}; -fn ws_service( +async fn ws_service( req: FramedRequest, -) -> impl Future { - let (req, framed, _) = req.into_parts(); +) -> Result<(), Error> { + let (req, mut framed, _) = req.into_parts(); let res = ws::handshake(req.head()).unwrap().message_body(()); framed .send((res, body::BodySize::None).into()) - .map_err(|_| panic!()) - .and_then(|framed| { - FramedTransport::new(framed.into_framed(ws::Codec::new()), service) - .map_err(|_| panic!()) - }) + .await + .unwrap(); + Dispatcher::new(framed.into_framed(ws::Codec::new()), service) + .await + .unwrap(); + + Ok(()) } -fn service(msg: ws::Frame) -> impl Future { +async fn service(msg: ws::Frame) -> Result { let msg = match msg { ws::Frame::Ping(msg) => ws::Message::Pong(msg), ws::Frame::Text(text) => { - ws::Message::Text(String::from_utf8_lossy(&text.unwrap()).to_string()) + ws::Message::Text(String::from_utf8_lossy(&text).to_string()) } - ws::Frame::Binary(bin) => ws::Message::Binary(bin.unwrap().freeze()), + ws::Frame::Binary(bin) => ws::Message::Binary(bin), ws::Frame::Close(reason) => ws::Message::Close(reason), _ => panic!(), }; - ok(msg) + Ok(msg) } -#[test] -fn test_simple() { - let mut srv = TestServer::new(|| { +#[actix_rt::test] +async fn test_simple() { + let mut srv = test_server(|| { HttpService::build() .upgrade( FramedApp::new().service(FramedRoute::get("/index.html").to(ws_service)), ) .finish(|_| future::ok::<_, Error>(Response::NotFound())) + .tcp() }); - assert!(srv.ws_at("/test").is_err()); + assert!(srv.ws_at("/test").await.is_err()); // client service - let framed = srv.ws_at("/index.html").unwrap(); - let framed = srv - .block_on(framed.send(ws::Message::Text("text".to_string()))) + let mut framed = srv.ws_at("/index.html").await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); - - let framed = srv - .block_on(framed.send(ws::Message::Binary("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + let (item, mut framed) = framed.into_future().await; assert_eq!( - item, - Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) + item.unwrap().unwrap(), + ws::Frame::Text(Bytes::from_static(b"text")) ); - let framed = srv - .block_on(framed.send(ws::Message::Ping("text".into()))) + framed + .send(ws::Message::Binary("text".into())) + .await .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); - - let framed = srv - .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) - .unwrap(); - - let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + let (item, mut framed) = framed.into_future().await; assert_eq!( - item, - Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) + item.unwrap().unwrap(), + ws::Frame::Binary(Bytes::from_static(b"text")) + ); + + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Pong("text".to_string().into()) + ); + + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); + + let (item, _) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Close(Some(ws::CloseCode::Normal.into())) ); } -#[test] -fn test_service() { - let mut srv = TestServer::new(|| { - actix_http::h1::OneRequest::new().map_err(|_| ()).and_then( - VerifyWebSockets::default() - .then(SendError::default()) - .map_err(|_| ()) - .and_then( - FramedApp::new() - .service(FramedRoute::get("/index.html").to(ws_service)) - .into_new_service() - .map_err(|_| ()), - ), +#[actix_rt::test] +async fn test_service() { + let mut srv = test_server(|| { + pipeline_factory(actix_http::h1::OneRequest::new().map_err(|_| ())).and_then( + pipeline_factory( + pipeline_factory(VerifyWebSockets::default()) + .then(SendError::default()) + .map_err(|_| ()), + ) + .and_then( + FramedApp::new() + .service(FramedRoute::get("/index.html").to(ws_service)) + .into_factory() + .map_err(|_| ()), + ), ) }); // non ws request - let res = srv.block_on(srv.get("/index.html").send()).unwrap(); + let res = srv.get("/index.html").send().await.unwrap(); assert_eq!(res.status(), StatusCode::BAD_REQUEST); // not found - assert!(srv.ws_at("/test").is_err()); + assert!(srv.ws_at("/test").await.is_err()); // client service - let framed = srv.ws_at("/index.html").unwrap(); - let framed = srv - .block_on(framed.send(ws::Message::Text("text".to_string()))) + let mut framed = srv.ws_at("/index.html").await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); - - let framed = srv - .block_on(framed.send(ws::Message::Binary("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + let (item, mut framed) = framed.into_future().await; assert_eq!( - item, - Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) + item.unwrap().unwrap(), + ws::Frame::Text(Bytes::from_static(b"text")) ); - let framed = srv - .block_on(framed.send(ws::Message::Ping("text".into()))) + framed + .send(ws::Message::Binary("text".into())) + .await .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); - - let framed = srv - .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) - .unwrap(); - - let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + let (item, mut framed) = framed.into_future().await; assert_eq!( - item, - Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) + item.unwrap().unwrap(), + ws::Frame::Binary(Bytes::from_static(b"text")) + ); + + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Pong("text".to_string().into()) + ); + + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); + + let (item, _) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Close(Some(ws::CloseCode::Normal.into())) ); } diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 80da691cd..1c8e4f053 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,16 +1,79 @@ # Changes -## +## [1.0.1] - 2019-12-20 + +### Fixed + +* Poll upgrade service's readiness from HTTP service handlers + +* Replace brotli with brotli2 #1224 + +## [1.0.0] - 2019-12-13 + +### Added + +* Add websockets continuation frame support + +### Changed + +* Replace `flate2-xxx` features with `compress` + +## [1.0.0-alpha.5] - 2019-12-09 + +### Fixed + +* Check `Upgrade` service readiness before calling it + +* Fix buffer remaining capacity calcualtion + +### Changed + +* Websockets: Ping and Pong should have binary data #1049 + +## [1.0.0-alpha.4] - 2019-12-08 + +### Added + +* Add impl ResponseBuilder for Error + +### Changed + +* Use rust based brotli compression library + +## [1.0.0-alpha.3] - 2019-12-07 + +### Changed + +* Migrate to tokio 0.2 + +* Migrate to `std::future` + + +## [0.2.11] - 2019-11-06 + +### Added + +* Add support for serde_json::Value to be passed as argument to ResponseBuilder.body() + +* Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151) + +* Allow to use `std::convert::Infallible` as `actix_http::error::Error` + +### Fixed + +* To be compatible with non-English error responses, `ResponseError` rendered with `text/plain; charset=utf-8` header #1118 + + +## [0.2.10] - 2019-09-11 ### Added * Add support for sending HTTP requests with `Rc` in addition to sending HTTP requests with `RequestHead` - -## [0.2.10] - 2019-09-xx - ### Fixed +* h2 will use error response #1080 + * on_connect result isn't added to request extensions for http2 requests #1009 diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index 290a8fbae..93aaa756e 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-http" -version = "0.2.9" +version = "1.0.1" authors = ["Nikolay Kim "] description = "Actix http primitives" readme = "README.md" @@ -13,10 +13,9 @@ categories = ["network-programming", "asynchronous", "web-programming::websocket"] license = "MIT/Apache-2.0" edition = "2018" -workspace = ".." [package.metadata.docs.rs] -features = ["ssl", "fail", "brotli", "flate2-zlib", "secure-cookies"] +features = ["openssl", "rustls", "failure", "compress", "secure-cookies"] [lib] name = "actix_http" @@ -26,85 +25,77 @@ path = "src/lib.rs" default = [] # openssl -ssl = ["openssl", "actix-connect/ssl"] +openssl = ["actix-tls/openssl", "actix-connect/openssl"] # rustls support -rust-tls = ["rustls", "webpki-roots", "actix-connect/rust-tls"] +rustls = ["actix-tls/rustls", "actix-connect/rustls"] -# brotli encoding, requires c compiler -brotli = ["brotli2"] - -# miniz-sys backend for flate2 crate -flate2-zlib = ["flate2/miniz-sys"] - -# rust backend for flate2 crate -flate2-rust = ["flate2/rust_backend"] +# enable compressison support +compress = ["flate2", "brotli2"] # failure integration. actix does not use failure anymore -fail = ["failure"] +failure = ["fail-ure"] # support for secure cookies secure-cookies = ["ring"] [dependencies] -actix-service = "0.4.1" -actix-codec = "0.1.2" -actix-connect = "0.2.4" -actix-utils = "0.4.4" -actix-server-config = "0.1.2" -actix-threadpool = "0.1.1" +actix-service = "1.0.1" +actix-codec = "0.2.0" +actix-connect = "1.0.1" +actix-utils = "1.0.3" +actix-rt = "1.0.0" +actix-threadpool = "0.3.1" +actix-tls = { version = "1.0.0", optional = true } -base64 = "0.10" -bitflags = "1.0" -bytes = "0.4" +base64 = "0.11" +bitflags = "1.2" +bytes = "0.5.3" copyless = "0.1.4" -derive_more = "0.15.0" -either = "1.5.2" +chrono = "0.4.6" +derive_more = "0.99.2" +either = "1.5.3" encoding_rs = "0.8" -futures = "0.1.25" -hashbrown = "0.5.0" -h2 = "0.1.16" -http = "0.1.17" +futures-core = "0.3.1" +futures-util = "0.3.1" +futures-channel = "0.3.1" +fxhash = "0.2.1" +h2 = "0.2.1" +http = "0.2.0" httparse = "1.3" -indexmap = "1.0" -lazy_static = "1.0" +indexmap = "1.3" +lazy_static = "1.4" language-tags = "0.2" log = "0.4" mime = "0.3" percent-encoding = "2.1" +pin-project = "0.4.6" rand = "0.7" -regex = "1.0" +regex = "1.3" serde = "1.0" serde_json = "1.0" -sha1 = "0.6" +sha-1 = "0.8" slab = "0.4" serde_urlencoded = "0.6.1" time = "0.1.42" -tokio-tcp = "0.1.3" -tokio-timer = "0.2.8" -tokio-current-thread = "0.1" -trust-dns-resolver = { version="0.11.1", default-features = false } # for secure cookie -ring = { version = "0.14.6", optional = true } +ring = { version = "0.16.9", optional = true } # compression brotli2 = { version="0.3.2", optional = true } -flate2 = { version="1.0.7", optional = true, default-features = false } +flate2 = { version = "1.0.13", optional = true } # optional deps -failure = { version = "0.1.5", optional = true } -openssl = { version="0.10", optional = true } -rustls = { version = "0.15.2", optional = true } -webpki-roots = { version = "0.16", optional = true } -chrono = "0.4.6" +fail-ure = { version = "0.1.5", package="failure", optional = true } [dev-dependencies] -actix-rt = "0.2.2" -actix-server = { version = "0.6.0", features=["ssl", "rust-tls"] } -actix-connect = { version = "0.2.0", features=["ssl"] } -actix-http-test = { version = "0.2.4", features=["ssl"] } +actix-server = "1.0.0" +actix-connect = { version = "1.0.0", features=["openssl"] } +actix-http-test = { version = "1.0.0", features=["openssl"] } +actix-tls = { version = "1.0.0", features=["openssl"] } +futures = "0.3.1" env_logger = "0.6" serde_derive = "1.0" -openssl = { version="0.10" } -tokio-tcp = "0.1" +open-ssl = { version="0.10", package = "openssl" } +rust-tls = { version="0.16", package = "rustls" } diff --git a/actix-http/examples/echo.rs b/actix-http/examples/echo.rs index c36292c44..3d57a472a 100644 --- a/actix-http/examples/echo.rs +++ b/actix-http/examples/echo.rs @@ -1,13 +1,14 @@ use std::{env, io}; -use actix_http::{error::PayloadError, HttpService, Request, Response}; +use actix_http::{Error, HttpService, Request, Response}; use actix_server::Server; use bytes::BytesMut; -use futures::{Future, Stream}; +use futures::StreamExt; use http::header::HeaderValue; use log::info; -fn main() -> io::Result<()> { +#[actix_rt::main] +async fn main() -> io::Result<()> { env::set_var("RUST_LOG", "echo=info"); env_logger::init(); @@ -17,21 +18,25 @@ fn main() -> io::Result<()> { .client_timeout(1000) .client_disconnect(1000) .finish(|mut req: Request| { - req.take_payload() - .fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, PayloadError>(body) - }) - .and_then(|bytes| { - info!("request body: {:?}", bytes); - let mut res = Response::Ok(); - res.header( - "x-head", - HeaderValue::from_static("dummy value!"), - ); - Ok(res.body(bytes)) - }) + async move { + let mut body = BytesMut::new(); + while let Some(item) = req.payload().next().await { + body.extend_from_slice(&item?); + } + + info!("request body: {:?}", body); + Ok::<_, Error>( + Response::Ok() + .header( + "x-head", + HeaderValue::from_static("dummy value!"), + ) + .body(body), + ) + } }) + .tcp() })? .run() + .await } diff --git a/actix-http/examples/echo2.rs b/actix-http/examples/echo2.rs index b239796b4..f89ea2dfb 100644 --- a/actix-http/examples/echo2.rs +++ b/actix-http/examples/echo2.rs @@ -1,34 +1,33 @@ use std::{env, io}; use actix_http::http::HeaderValue; -use actix_http::{error::PayloadError, Error, HttpService, Request, Response}; +use actix_http::{Error, HttpService, Request, Response}; use actix_server::Server; use bytes::BytesMut; -use futures::{Future, Stream}; +use futures::StreamExt; use log::info; -fn handle_request(mut req: Request) -> impl Future { - req.take_payload() - .fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, PayloadError>(body) - }) - .from_err() - .and_then(|bytes| { - info!("request body: {:?}", bytes); - let mut res = Response::Ok(); - res.header("x-head", HeaderValue::from_static("dummy value!")); - Ok(res.body(bytes)) - }) +async fn handle_request(mut req: Request) -> Result { + let mut body = BytesMut::new(); + while let Some(item) = req.payload().next().await { + body.extend_from_slice(&item?) + } + + info!("request body: {:?}", body); + Ok(Response::Ok() + .header("x-head", HeaderValue::from_static("dummy value!")) + .body(body)) } -fn main() -> io::Result<()> { +#[actix_rt::main] +async fn main() -> io::Result<()> { env::set_var("RUST_LOG", "echo=info"); env_logger::init(); Server::build() .bind("echo", "127.0.0.1:8080", || { - HttpService::build().finish(|_req: Request| handle_request(_req)) + HttpService::build().finish(handle_request).tcp() })? .run() + .await } diff --git a/actix-http/examples/hello-world.rs b/actix-http/examples/hello-world.rs index 6e3820390..4134ee734 100644 --- a/actix-http/examples/hello-world.rs +++ b/actix-http/examples/hello-world.rs @@ -6,7 +6,8 @@ use futures::future; use http::header::HeaderValue; use log::info; -fn main() -> io::Result<()> { +#[actix_rt::main] +async fn main() -> io::Result<()> { env::set_var("RUST_LOG", "hello_world=info"); env_logger::init(); @@ -21,6 +22,8 @@ fn main() -> io::Result<()> { res.header("x-head", HeaderValue::from_static("dummy value!")); future::ok::<_, ()>(res.body("Hello world!")) }) + .tcp() })? .run() + .await } diff --git a/actix-http/src/body.rs b/actix-http/src/body.rs index e728cdb98..850f97ee4 100644 --- a/actix-http/src/body.rs +++ b/actix-http/src/body.rs @@ -1,8 +1,11 @@ use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{fmt, mem}; use bytes::{Bytes, BytesMut}; -use futures::{Async, Poll, Stream}; +use futures_core::Stream; +use pin_project::{pin_project, project}; use crate::error::Error; @@ -32,7 +35,7 @@ impl BodySize { pub trait MessageBody { fn size(&self) -> BodySize; - fn poll_next(&mut self) -> Poll, Error>; + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>>; } impl MessageBody for () { @@ -40,8 +43,8 @@ impl MessageBody for () { BodySize::Empty } - fn poll_next(&mut self) -> Poll, Error> { - Ok(Async::Ready(None)) + fn poll_next(&mut self, _: &mut Context<'_>) -> Poll>> { + Poll::Ready(None) } } @@ -50,11 +53,12 @@ impl MessageBody for Box { self.as_ref().size() } - fn poll_next(&mut self) -> Poll, Error> { - self.as_mut().poll_next() + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { + self.as_mut().poll_next(cx) } } +#[pin_project] pub enum ResponseBody { Body(B), Other(Body), @@ -93,20 +97,27 @@ impl MessageBody for ResponseBody { } } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { match self { - ResponseBody::Body(ref mut body) => body.poll_next(), - ResponseBody::Other(ref mut body) => body.poll_next(), + ResponseBody::Body(ref mut body) => body.poll_next(cx), + ResponseBody::Other(ref mut body) => body.poll_next(cx), } } } impl Stream for ResponseBody { - type Item = Bytes; - type Error = Error; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - self.poll_next() + #[project] + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + #[project] + match self.project() { + ResponseBody::Body(ref mut body) => body.poll_next(cx), + ResponseBody::Other(ref mut body) => body.poll_next(cx), + } } } @@ -125,7 +136,7 @@ pub enum Body { impl Body { /// Create body from slice (copy) pub fn from_slice(s: &[u8]) -> Body { - Body::Bytes(Bytes::from(s)) + Body::Bytes(Bytes::copy_from_slice(s)) } /// Create body from generic message body. @@ -144,19 +155,19 @@ impl MessageBody for Body { } } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { match self { - Body::None => Ok(Async::Ready(None)), - Body::Empty => Ok(Async::Ready(None)), + Body::None => Poll::Ready(None), + Body::Empty => Poll::Ready(None), Body::Bytes(ref mut bin) => { let len = bin.len(); if len == 0 { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some(mem::replace(bin, Bytes::new())))) + Poll::Ready(Some(Ok(mem::replace(bin, Bytes::new())))) } } - Body::Message(ref mut body) => body.poll_next(), + Body::Message(ref mut body) => body.poll_next(cx), } } } @@ -182,7 +193,7 @@ impl PartialEq for Body { } impl fmt::Debug for Body { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { Body::None => write!(f, "Body::None"), Body::Empty => write!(f, "Body::Empty"), @@ -218,7 +229,7 @@ impl From for Body { impl<'a> From<&'a String> for Body { fn from(s: &'a String) -> Body { - Body::Bytes(Bytes::from(AsRef::<[u8]>::as_ref(&s))) + Body::Bytes(Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(&s))) } } @@ -234,9 +245,15 @@ impl From for Body { } } +impl From for Body { + fn from(v: serde_json::Value) -> Body { + Body::Bytes(v.to_string().into()) + } +} + impl From> for Body where - S: Stream + 'static, + S: Stream> + 'static, { fn from(s: SizedStream) -> Body { Body::from_message(s) @@ -245,7 +262,7 @@ where impl From> for Body where - S: Stream + 'static, + S: Stream> + 'static, E: Into + 'static, { fn from(s: BodyStream) -> Body { @@ -258,11 +275,11 @@ impl MessageBody for Bytes { BodySize::Sized(self.len()) } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, _: &mut Context<'_>) -> Poll>> { if self.is_empty() { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some(mem::replace(self, Bytes::new())))) + Poll::Ready(Some(Ok(mem::replace(self, Bytes::new())))) } } } @@ -272,13 +289,11 @@ impl MessageBody for BytesMut { BodySize::Sized(self.len()) } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, _: &mut Context<'_>) -> Poll>> { if self.is_empty() { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some( - mem::replace(self, BytesMut::new()).freeze(), - ))) + Poll::Ready(Some(Ok(mem::replace(self, BytesMut::new()).freeze()))) } } } @@ -288,11 +303,11 @@ impl MessageBody for &'static str { BodySize::Sized(self.len()) } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, _: &mut Context<'_>) -> Poll>> { if self.is_empty() { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some(Bytes::from_static( + Poll::Ready(Some(Ok(Bytes::from_static( mem::replace(self, "").as_ref(), )))) } @@ -304,13 +319,11 @@ impl MessageBody for &'static [u8] { BodySize::Sized(self.len()) } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, _: &mut Context<'_>) -> Poll>> { if self.is_empty() { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some(Bytes::from_static(mem::replace( - self, b"", - ))))) + Poll::Ready(Some(Ok(Bytes::from_static(mem::replace(self, b""))))) } } } @@ -320,14 +333,11 @@ impl MessageBody for Vec { BodySize::Sized(self.len()) } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, _: &mut Context<'_>) -> Poll>> { if self.is_empty() { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some(Bytes::from(mem::replace( - self, - Vec::new(), - ))))) + Poll::Ready(Some(Ok(Bytes::from(mem::replace(self, Vec::new()))))) } } } @@ -337,11 +347,11 @@ impl MessageBody for String { BodySize::Sized(self.len()) } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, _: &mut Context<'_>) -> Poll>> { if self.is_empty() { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { - Ok(Async::Ready(Some(Bytes::from( + Poll::Ready(Some(Ok(Bytes::from( mem::replace(self, String::new()).into_bytes(), )))) } @@ -350,14 +360,16 @@ impl MessageBody for String { /// Type represent streaming body. /// Response does not contain `content-length` header and appropriate transfer encoding is used. +#[pin_project] pub struct BodyStream { + #[pin] stream: S, _t: PhantomData, } impl BodyStream where - S: Stream, + S: Stream>, E: Into, { pub fn new(stream: S) -> Self { @@ -370,28 +382,34 @@ where impl MessageBody for BodyStream where - S: Stream, + S: Stream>, E: Into, { fn size(&self) -> BodySize { BodySize::Stream } - fn poll_next(&mut self) -> Poll, Error> { - self.stream.poll().map_err(std::convert::Into::into) + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { + unsafe { Pin::new_unchecked(self) } + .project() + .stream + .poll_next(cx) + .map(|res| res.map(|res| res.map_err(std::convert::Into::into))) } } /// Type represent streaming body. This body implementation should be used /// if total size of stream is known. Data get sent as is without using transfer encoding. +#[pin_project] pub struct SizedStream { size: u64, + #[pin] stream: S, } impl SizedStream where - S: Stream, + S: Stream>, { pub fn new(size: u64, stream: S) -> Self { SizedStream { size, stream } @@ -400,20 +418,24 @@ where impl MessageBody for SizedStream where - S: Stream, + S: Stream>, { fn size(&self) -> BodySize { BodySize::Sized64(self.size) } - fn poll_next(&mut self) -> Poll, Error> { - self.stream.poll() + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { + unsafe { Pin::new_unchecked(self) } + .project() + .stream + .poll_next(cx) } } #[cfg(test)] mod tests { use super::*; + use futures_util::future::poll_fn; impl Body { pub(crate) fn get_ref(&self) -> &[u8] { @@ -433,21 +455,21 @@ mod tests { } } - #[test] - fn test_static_str() { + #[actix_rt::test] + async fn test_static_str() { assert_eq!(Body::from("").size(), BodySize::Sized(0)); assert_eq!(Body::from("test").size(), BodySize::Sized(4)); assert_eq!(Body::from("test").get_ref(), b"test"); assert_eq!("test".size(), BodySize::Sized(4)); assert_eq!( - "test".poll_next().unwrap(), - Async::Ready(Some(Bytes::from("test"))) + poll_fn(|cx| "test".poll_next(cx)).await.unwrap().ok(), + Some(Bytes::from("test")) ); } - #[test] - fn test_static_bytes() { + #[actix_rt::test] + async fn test_static_bytes() { assert_eq!(Body::from(b"test".as_ref()).size(), BodySize::Sized(4)); assert_eq!(Body::from(b"test".as_ref()).get_ref(), b"test"); assert_eq!( @@ -458,51 +480,57 @@ mod tests { assert_eq!((&b"test"[..]).size(), BodySize::Sized(4)); assert_eq!( - (&b"test"[..]).poll_next().unwrap(), - Async::Ready(Some(Bytes::from("test"))) + poll_fn(|cx| (&b"test"[..]).poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("test")) ); } - #[test] - fn test_vec() { + #[actix_rt::test] + async fn test_vec() { assert_eq!(Body::from(Vec::from("test")).size(), BodySize::Sized(4)); assert_eq!(Body::from(Vec::from("test")).get_ref(), b"test"); assert_eq!(Vec::from("test").size(), BodySize::Sized(4)); assert_eq!( - Vec::from("test").poll_next().unwrap(), - Async::Ready(Some(Bytes::from("test"))) + poll_fn(|cx| Vec::from("test").poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("test")) ); } - #[test] - fn test_bytes() { + #[actix_rt::test] + async fn test_bytes() { let mut b = Bytes::from("test"); assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); assert_eq!(Body::from(b.clone()).get_ref(), b"test"); assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!( - b.poll_next().unwrap(), - Async::Ready(Some(Bytes::from("test"))) + poll_fn(|cx| b.poll_next(cx)).await.unwrap().ok(), + Some(Bytes::from("test")) ); } - #[test] - fn test_bytes_mut() { + #[actix_rt::test] + async fn test_bytes_mut() { let mut b = BytesMut::from("test"); assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); assert_eq!(Body::from(b.clone()).get_ref(), b"test"); assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!( - b.poll_next().unwrap(), - Async::Ready(Some(Bytes::from("test"))) + poll_fn(|cx| b.poll_next(cx)).await.unwrap().ok(), + Some(Bytes::from("test")) ); } - #[test] - fn test_string() { + #[actix_rt::test] + async fn test_string() { let mut b = "test".to_owned(); assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); assert_eq!(Body::from(b.clone()).get_ref(), b"test"); @@ -511,26 +539,26 @@ mod tests { assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!( - b.poll_next().unwrap(), - Async::Ready(Some(Bytes::from("test"))) + poll_fn(|cx| b.poll_next(cx)).await.unwrap().ok(), + Some(Bytes::from("test")) ); } - #[test] - fn test_unit() { + #[actix_rt::test] + async fn test_unit() { assert_eq!(().size(), BodySize::Empty); - assert_eq!(().poll_next().unwrap(), Async::Ready(None)); + assert!(poll_fn(|cx| ().poll_next(cx)).await.is_none()); } - #[test] - fn test_box() { + #[actix_rt::test] + async fn test_box() { let mut val = Box::new(()); assert_eq!(val.size(), BodySize::Empty); - assert_eq!(val.poll_next().unwrap(), Async::Ready(None)); + assert!(poll_fn(|cx| val.poll_next(cx)).await.is_none()); } - #[test] - fn test_body_eq() { + #[actix_rt::test] + async fn test_body_eq() { assert!(Body::None == Body::None); assert!(Body::None != Body::Empty); assert!(Body::Empty == Body::Empty); @@ -542,10 +570,23 @@ mod tests { assert!(Body::Bytes(Bytes::from_static(b"1")) != Body::None); } - #[test] - fn test_body_debug() { + #[actix_rt::test] + async fn test_body_debug() { assert!(format!("{:?}", Body::None).contains("Body::None")); assert!(format!("{:?}", Body::Empty).contains("Body::Empty")); assert!(format!("{:?}", Body::Bytes(Bytes::from_static(b"1"))).contains("1")); } + + #[actix_rt::test] + async fn test_serde_json() { + use serde_json::json; + assert_eq!( + Body::from(serde_json::Value::String("test".into())).size(), + BodySize::Sized(6) + ); + assert_eq!( + Body::from(json!({"test-key":"test-value"})).size(), + BodySize::Sized(25) + ); + } } diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index cd23b7265..271abd43f 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -1,10 +1,9 @@ -use std::fmt; use std::marker::PhantomData; use std::rc::Rc; +use std::{fmt, net}; use actix_codec::Framed; -use actix_server_config::ServerConfig as SrvConfig; -use actix_service::{IntoNewService, NewService, Service}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use crate::body::MessageBody; use crate::config::{KeepAlive, ServiceConfig}; @@ -24,6 +23,8 @@ pub struct HttpServiceBuilder> { keep_alive: KeepAlive, client_timeout: u64, client_disconnect: u64, + secure: bool, + local_addr: Option, expect: X, upgrade: Option, on_connect: Option Box>>, @@ -32,9 +33,10 @@ pub struct HttpServiceBuilder> { impl HttpServiceBuilder> where - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, + ::Future: 'static, { /// Create instance of `ServiceConfigBuilder` pub fn new() -> Self { @@ -42,6 +44,8 @@ where keep_alive: KeepAlive::Timeout(5), client_timeout: 5000, client_disconnect: 0, + secure: false, + local_addr: None, expect: ExpectHandler, upgrade: None, on_connect: None, @@ -52,19 +56,18 @@ where impl HttpServiceBuilder where - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - X: NewService, + ::Future: 'static, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: NewService< - Config = SrvConfig, - Request = (Request, Framed), - Response = (), - >, + ::Future: 'static, + U: ServiceFactory), Response = ()>, U::Error: fmt::Display, U::InitError: fmt::Debug, + ::Future: 'static, { /// Set server keep-alive setting. /// @@ -74,6 +77,18 @@ where self } + /// Set connection secure state + pub fn secure(mut self) -> Self { + self.secure = true; + self + } + + /// Set the local address that this service is bound to. + pub fn local_addr(mut self, addr: net::SocketAddr) -> Self { + self.local_addr = Some(addr); + self + } + /// Set server client timeout in milliseconds for first request. /// /// Defines a timeout for reading client request header. If a client does not transmit @@ -108,16 +123,19 @@ where /// request will be forwarded to main service. pub fn expect(self, expect: F) -> HttpServiceBuilder where - F: IntoNewService, - X1: NewService, + F: IntoServiceFactory, + X1: ServiceFactory, X1::Error: Into, X1::InitError: fmt::Debug, + ::Future: 'static, { HttpServiceBuilder { keep_alive: self.keep_alive, client_timeout: self.client_timeout, client_disconnect: self.client_disconnect, - expect: expect.into_new_service(), + secure: self.secure, + local_addr: self.local_addr, + expect: expect.into_factory(), upgrade: self.upgrade, on_connect: self.on_connect, _t: PhantomData, @@ -130,21 +148,24 @@ where /// and this service get called with original request and framed object. pub fn upgrade(self, upgrade: F) -> HttpServiceBuilder where - F: IntoNewService, - U1: NewService< - Config = SrvConfig, + F: IntoServiceFactory, + U1: ServiceFactory< + Config = (), Request = (Request, Framed), Response = (), >, U1::Error: fmt::Display, U1::InitError: fmt::Debug, + ::Future: 'static, { HttpServiceBuilder { keep_alive: self.keep_alive, client_timeout: self.client_timeout, client_disconnect: self.client_disconnect, + secure: self.secure, + local_addr: self.local_addr, expect: self.expect, - upgrade: Some(upgrade.into_new_service()), + upgrade: Some(upgrade.into_factory()), on_connect: self.on_connect, _t: PhantomData, } @@ -164,10 +185,10 @@ where } /// Finish service configuration and create *http service* for HTTP/1 protocol. - pub fn h1(self, service: F) -> H1Service + pub fn h1(self, service: F) -> H1Service where - B: MessageBody + 'static, - F: IntoNewService, + B: MessageBody, + F: IntoServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, @@ -176,48 +197,53 @@ where self.keep_alive, self.client_timeout, self.client_disconnect, + self.secure, + self.local_addr, ); - H1Service::with_config(cfg, service.into_new_service()) + H1Service::with_config(cfg, service.into_factory()) .expect(self.expect) .upgrade(self.upgrade) .on_connect(self.on_connect) } /// Finish service configuration and create *http service* for HTTP/2 protocol. - pub fn h2(self, service: F) -> H2Service + pub fn h2(self, service: F) -> H2Service where B: MessageBody + 'static, - F: IntoNewService, - S::Error: Into, + F: IntoServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, ::Future: 'static, { let cfg = ServiceConfig::new( self.keep_alive, self.client_timeout, self.client_disconnect, + self.secure, + self.local_addr, ); - H2Service::with_config(cfg, service.into_new_service()) - .on_connect(self.on_connect) + H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect) } /// Finish service configuration and create `HttpService` instance. - pub fn finish(self, service: F) -> HttpService + pub fn finish(self, service: F) -> HttpService where B: MessageBody + 'static, - F: IntoNewService, - S::Error: Into, + F: IntoServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, ::Future: 'static, { let cfg = ServiceConfig::new( self.keep_alive, self.client_timeout, self.client_disconnect, + self.secure, + self.local_addr, ); - HttpService::with_config(cfg, service.into_new_service()) + HttpService::with_config(cfg, service.into_factory()) .expect(self.expect) .upgrade(self.upgrade) .on_connect(self.on_connect) diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs index d2b94b3e5..0ca788b32 100644 --- a/actix-http/src/client/connection.rs +++ b/actix-http/src/client/connection.rs @@ -1,10 +1,12 @@ -use std::{fmt, io, time}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, io, mem, time}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use bytes::{Buf, Bytes}; -use futures::future::{err, Either, Future, FutureResult}; -use futures::Poll; +use futures_util::future::{err, Either, Future, FutureExt, LocalBoxFuture, Ready}; use h2::client::SendRequest; +use pin_project::{pin_project, project}; use crate::body::MessageBody; use crate::h1::ClientCodec; @@ -21,8 +23,8 @@ pub(crate) enum ConnectionType { } pub trait Connection { - type Io: AsyncRead + AsyncWrite; - type Future: Future; + type Io: AsyncRead + AsyncWrite + Unpin; + type Future: Future>; fn protocol(&self) -> Protocol; @@ -34,8 +36,7 @@ pub trait Connection { ) -> Self::Future; type TunnelFuture: Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, + Output = Result<(ResponseHead, Framed), SendRequestError>, >; /// Send request, returns Response and Framed @@ -62,7 +63,7 @@ impl fmt::Debug for IoConnection where T: fmt::Debug, { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.io { Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io), Some(ConnectionType::H2(_)) => write!(f, "H2Connection"), @@ -71,7 +72,7 @@ where } } -impl IoConnection { +impl IoConnection { pub(crate) fn new( io: ConnectionType, created: time::Instant, @@ -91,11 +92,11 @@ impl IoConnection { impl Connection for IoConnection where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, { type Io = T; type Future = - Box>; + LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; fn protocol(&self) -> Protocol { match self.io { @@ -111,38 +112,30 @@ where body: B, ) -> Self::Future { match self.io.take().unwrap() { - ConnectionType::H1(io) => Box::new(h1proto::send_request( - io, - head.into(), - body, - self.created, - self.pool, - )), - ConnectionType::H2(io) => Box::new(h2proto::send_request( - io, - head.into(), - body, - self.created, - self.pool, - )), + ConnectionType::H1(io) => { + h1proto::send_request(io, head.into(), body, self.created, self.pool) + .boxed_local() + } + ConnectionType::H2(io) => { + h2proto::send_request(io, head.into(), body, self.created, self.pool) + .boxed_local() + } } } type TunnelFuture = Either< - Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, - >, + LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, >, - FutureResult<(ResponseHead, Framed), SendRequestError>, + Ready), SendRequestError>>, >; /// Send request, returns Response and Framed fn open_tunnel>(mut self, head: H) -> Self::TunnelFuture { match self.io.take().unwrap() { ConnectionType::H1(io) => { - Either::A(Box::new(h1proto::open_tunnel(io, head.into()))) + Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local()) } ConnectionType::H2(io) => { if let Some(mut pool) = self.pool.take() { @@ -152,7 +145,7 @@ where None, )); } - Either::B(err(SendRequestError::TunnelNotSupported)) + Either::Right(err(SendRequestError::TunnelNotSupported)) } } } @@ -166,12 +159,12 @@ pub(crate) enum EitherConnection { impl Connection for EitherConnection where - A: AsyncRead + AsyncWrite + 'static, - B: AsyncRead + AsyncWrite + 'static, + A: AsyncRead + AsyncWrite + Unpin + 'static, + B: AsyncRead + AsyncWrite + Unpin + 'static, { type Io = EitherIo; type Future = - Box>; + LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; fn protocol(&self) -> Protocol { match self { @@ -191,44 +184,30 @@ where } } - type TunnelFuture = Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, - >, + type TunnelFuture = LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, >; /// Send request, returns Response and Framed fn open_tunnel>(self, head: H) -> Self::TunnelFuture { match self { - EitherConnection::A(con) => Box::new( - con.open_tunnel(head) - .map(|(head, framed)| (head, framed.map_io(EitherIo::A))), - ), - EitherConnection::B(con) => Box::new( - con.open_tunnel(head) - .map(|(head, framed)| (head, framed.map_io(EitherIo::B))), - ), + EitherConnection::A(con) => con + .open_tunnel(head) + .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::A)))) + .boxed_local(), + EitherConnection::B(con) => con + .open_tunnel(head) + .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::B)))) + .boxed_local(), } } } +#[pin_project] pub enum EitherIo { - A(A), - B(B), -} - -impl io::Read for EitherIo -where - A: io::Read, - B: io::Read, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self { - EitherIo::A(ref mut val) => val.read(buf), - EitherIo::B(ref mut val) => val.read(buf), - } - } + A(#[pin] A), + B(#[pin] B), } impl AsyncRead for EitherIo @@ -236,7 +215,23 @@ where A: AsyncRead, B: AsyncRead, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + #[project] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_read(cx, buf), + EitherIo::B(val) => val.poll_read(cx, buf), + } + } + + unsafe fn prepare_uninitialized_buffer( + &self, + buf: &mut [mem::MaybeUninit], + ) -> bool { match self { EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf), EitherIo::B(ref val) => val.prepare_uninitialized_buffer(buf), @@ -244,45 +239,58 @@ where } } -impl io::Write for EitherIo -where - A: io::Write, - B: io::Write, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - EitherIo::A(ref mut val) => val.write(buf), - EitherIo::B(ref mut val) => val.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - EitherIo::A(ref mut val) => val.flush(), - EitherIo::B(ref mut val) => val.flush(), - } - } -} - impl AsyncWrite for EitherIo where A: AsyncWrite, B: AsyncWrite, { - fn shutdown(&mut self) -> Poll<(), io::Error> { - match self { - EitherIo::A(ref mut val) => val.shutdown(), - EitherIo::B(ref mut val) => val.shutdown(), + #[project] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_write(cx, buf), + EitherIo::B(val) => val.poll_write(cx, buf), } } - fn write_buf(&mut self, buf: &mut U) -> Poll + #[project] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_flush(cx), + EitherIo::B(val) => val.poll_flush(cx), + } + } + + #[project] + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_shutdown(cx), + EitherIo::B(val) => val.poll_shutdown(cx), + } + } + + #[project] + fn poll_write_buf( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut U, + ) -> Poll> where Self: Sized, { - match self { - EitherIo::A(ref mut val) => val.write_buf(buf), - EitherIo::B(ref mut val) => val.write_buf(buf), + #[project] + match self.project() { + EitherIo::A(val) => val.poll_write_buf(cx, buf), + EitherIo::B(val) => val.poll_write_buf(cx, buf), } } } diff --git a/actix-http/src/client/connector.rs b/actix-http/src/client/connector.rs index 98e8618c3..055d4276d 100644 --- a/actix-http/src/client/connector.rs +++ b/actix-http/src/client/connector.rs @@ -6,32 +6,32 @@ use actix_codec::{AsyncRead, AsyncWrite}; use actix_connect::{ default_connector, Connect as TcpConnect, Connection as TcpConnection, }; -use actix_service::{apply_fn, Service, ServiceExt}; +use actix_rt::net::TcpStream; +use actix_service::{apply_fn, Service}; use actix_utils::timeout::{TimeoutError, TimeoutService}; use http::Uri; -use tokio_tcp::TcpStream; use super::connection::Connection; use super::error::ConnectError; use super::pool::{ConnectionPool, Protocol}; use super::Connect; -#[cfg(feature = "ssl")] -use openssl::ssl::SslConnector as OpensslConnector; +#[cfg(feature = "openssl")] +use actix_connect::ssl::openssl::SslConnector as OpensslConnector; -#[cfg(feature = "rust-tls")] -use rustls::ClientConfig; -#[cfg(feature = "rust-tls")] +#[cfg(feature = "rustls")] +use actix_connect::ssl::rustls::ClientConfig; +#[cfg(feature = "rustls")] use std::sync::Arc; -#[cfg(any(feature = "ssl", feature = "rust-tls"))] +#[cfg(any(feature = "openssl", feature = "rustls"))] enum SslConnector { - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] Openssl(OpensslConnector), - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] Rustls(Arc), } -#[cfg(not(any(feature = "ssl", feature = "rust-tls")))] +#[cfg(not(any(feature = "openssl", feature = "rustls")))] type SslConnector = (); /// Manages http client network connectivity @@ -58,11 +58,11 @@ pub struct Connector { _t: PhantomData, } -trait Io: AsyncRead + AsyncWrite {} -impl Io for T {} +trait Io: AsyncRead + AsyncWrite + Unpin {} +impl Io for T {} impl Connector<(), ()> { - #[allow(clippy::new_ret_no_self)] + #[allow(clippy::new_ret_no_self, clippy::let_unit_value)] pub fn new() -> Connector< impl Service< Request = TcpConnect, @@ -72,9 +72,9 @@ impl Connector<(), ()> { TcpStream, > { let ssl = { - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] { - use openssl::ssl::SslMethod; + use actix_connect::ssl::openssl::SslMethod; let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap(); let _ = ssl @@ -82,17 +82,17 @@ impl Connector<(), ()> { .map_err(|e| error!("Can not set alpn protocol: {:?}", e)); SslConnector::Openssl(ssl.build()) } - #[cfg(all(not(feature = "ssl"), feature = "rust-tls"))] + #[cfg(all(not(feature = "openssl"), feature = "rustls"))] { let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let mut config = ClientConfig::new(); config.set_protocols(&protos); config .root_store - .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + .add_server_trust_anchors(&actix_tls::rustls::TLS_SERVER_ROOTS); SslConnector::Rustls(Arc::new(config)) } - #[cfg(not(any(feature = "ssl", feature = "rust-tls")))] + #[cfg(not(any(feature = "openssl", feature = "rustls")))] {} }; @@ -113,7 +113,7 @@ impl Connector { /// Use custom connector. pub fn connector(self, connector: T1) -> Connector where - U1: AsyncRead + AsyncWrite + fmt::Debug, + U1: AsyncRead + AsyncWrite + Unpin + fmt::Debug, T1: Service< Request = TcpConnect, Response = TcpConnection, @@ -135,7 +135,7 @@ impl Connector { impl Connector where - U: AsyncRead + AsyncWrite + fmt::Debug + 'static, + U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, T: Service< Request = TcpConnect, Response = TcpConnection, @@ -150,14 +150,14 @@ where self } - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] /// Use custom `SslConnector` instance. pub fn ssl(mut self, connector: OpensslConnector) -> Self { self.ssl = SslConnector::Openssl(connector); self } - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] pub fn rustls(mut self, connector: Arc) -> Self { self.ssl = SslConnector::Rustls(connector); self @@ -212,8 +212,8 @@ where pub fn finish( self, ) -> impl Service - + Clone { - #[cfg(not(any(feature = "ssl", feature = "rust-tls")))] + + Clone { + #[cfg(not(any(feature = "openssl", feature = "rustls")))] { let connector = TimeoutService::new( self.timeout, @@ -238,32 +238,30 @@ where ), } } - #[cfg(any(feature = "ssl", feature = "rust-tls"))] + #[cfg(any(feature = "openssl", feature = "rustls"))] { const H2: &[u8] = b"h2"; - #[cfg(feature = "ssl")] - use actix_connect::ssl::OpensslConnector; - #[cfg(feature = "rust-tls")] - use actix_connect::ssl::RustlsConnector; - use actix_service::boxed::service; - #[cfg(feature = "rust-tls")] - use rustls::Session; + #[cfg(feature = "openssl")] + use actix_connect::ssl::openssl::OpensslConnector; + #[cfg(feature = "rustls")] + use actix_connect::ssl::rustls::{RustlsConnector, Session}; + use actix_service::{boxed::service, pipeline}; let ssl_service = TimeoutService::new( self.timeout, - apply_fn(self.connector.clone(), |msg: Connect, srv| { - srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) - }) - .map_err(ConnectError::from) + pipeline( + apply_fn(self.connector.clone(), |msg: Connect, srv| { + srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) + }) + .map_err(ConnectError::from), + ) .and_then(match self.ssl { - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] SslConnector::Openssl(ssl) => service( OpensslConnector::service(ssl) - .map_err(ConnectError::from) .map(|stream| { let sock = stream.into_parts().0; let h2 = sock - .get_ref() .ssl() .selected_alpn_protocol() .map(|protos| protos.windows(2).any(|w| w == H2)) @@ -273,9 +271,10 @@ where } else { (Box::new(sock) as Box, Protocol::Http1) } - }), + }) + .map_err(ConnectError::from), ), - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] SslConnector::Rustls(ssl) => service( RustlsConnector::service(ssl) .map_err(ConnectError::from) @@ -303,7 +302,7 @@ where let tcp_service = TimeoutService::new( self.timeout, - apply_fn(self.connector.clone(), |msg: Connect, srv| { + apply_fn(self.connector, |msg: Connect, srv| { srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) }) .map_err(ConnectError::from) @@ -334,19 +333,19 @@ where } } -#[cfg(not(any(feature = "ssl", feature = "rust-tls")))] +#[cfg(not(any(feature = "openssl", feature = "rustls")))] mod connect_impl { - use futures::future::{err, Either, FutureResult}; - use futures::Poll; + use std::task::{Context, Poll}; + + use futures_util::future::{err, Either, Ready}; use super::*; use crate::client::connection::IoConnection; pub(crate) struct InnerConnector where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { pub(crate) tcp_pool: ConnectionPool, @@ -354,9 +353,8 @@ mod connect_impl { impl Clone for InnerConnector where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { fn clone(&self) -> Self { @@ -368,9 +366,8 @@ mod connect_impl { impl Service for InnerConnector where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { type Request = Connect; @@ -378,38 +375,41 @@ mod connect_impl { type Error = ConnectError; type Future = Either< as Service>::Future, - FutureResult, ConnectError>, + Ready, ConnectError>>, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.tcp_pool.poll_ready() + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.tcp_pool.poll_ready(cx) } fn call(&mut self, req: Connect) -> Self::Future { match req.uri.scheme_str() { Some("https") | Some("wss") => { - Either::B(err(ConnectError::SslIsNotSupported)) + Either::Right(err(ConnectError::SslIsNotSupported)) } - _ => Either::A(self.tcp_pool.call(req)), + _ => Either::Left(self.tcp_pool.call(req)), } } } } -#[cfg(any(feature = "ssl", feature = "rust-tls"))] +#[cfg(any(feature = "openssl", feature = "rustls"))] mod connect_impl { + use std::future::Future; use std::marker::PhantomData; + use std::pin::Pin; + use std::task::{Context, Poll}; - use futures::future::{Either, FutureResult}; - use futures::{Async, Future, Poll}; + use futures_core::ready; + use futures_util::future::Either; use super::*; use crate::client::connection::EitherConnection; pub(crate) struct InnerConnector where - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, T1: Service, T2: Service, { @@ -419,13 +419,11 @@ mod connect_impl { impl Clone for InnerConnector where - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, T1: Service - + Clone + 'static, T2: Service - + Clone + 'static, { fn clone(&self) -> Self { @@ -438,53 +436,47 @@ mod connect_impl { impl Service for InnerConnector where - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, T1: Service - + Clone + 'static, T2: Service - + Clone + 'static, { type Request = Connect; type Response = EitherConnection; type Error = ConnectError; type Future = Either< - FutureResult, - Either< - InnerConnectorResponseA, - InnerConnectorResponseB, - >, + InnerConnectorResponseA, + InnerConnectorResponseB, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.tcp_pool.poll_ready() + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.tcp_pool.poll_ready(cx) } fn call(&mut self, req: Connect) -> Self::Future { match req.uri.scheme_str() { - Some("https") | Some("wss") => { - Either::B(Either::B(InnerConnectorResponseB { - fut: self.ssl_pool.call(req), - _t: PhantomData, - })) - } - _ => Either::B(Either::A(InnerConnectorResponseA { + Some("https") | Some("wss") => Either::Right(InnerConnectorResponseB { + fut: self.ssl_pool.call(req), + _t: PhantomData, + }), + _ => Either::Left(InnerConnectorResponseA { fut: self.tcp_pool.call(req), _t: PhantomData, - })), + }), } } } + #[pin_project::pin_project] pub(crate) struct InnerConnectorResponseA where - Io1: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { + #[pin] fut: as Service>::Future, _t: PhantomData, } @@ -492,29 +484,28 @@ mod connect_impl { impl Future for InnerConnectorResponseA where T: Service - + Clone + 'static, - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, { - type Item = EitherConnection; - type Error = ConnectError; + type Output = Result, ConnectError>; - fn poll(&mut self) -> Poll { - match self.fut.poll()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(res) => Ok(Async::Ready(EitherConnection::A(res))), - } + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Poll::Ready( + ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) + .map(EitherConnection::A), + ) } } + #[pin_project::pin_project] pub(crate) struct InnerConnectorResponseB where - Io2: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { + #[pin] fut: as Service>::Future, _t: PhantomData, } @@ -522,19 +513,17 @@ mod connect_impl { impl Future for InnerConnectorResponseB where T: Service - + Clone + 'static, - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, { - type Item = EitherConnection; - type Error = ConnectError; + type Output = Result, ConnectError>; - fn poll(&mut self) -> Poll { - match self.fut.poll()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(res) => Ok(Async::Ready(EitherConnection::B(res))), - } + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Poll::Ready( + ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) + .map(EitherConnection::B), + ) } } } diff --git a/actix-http/src/client/error.rs b/actix-http/src/client/error.rs index 40aef2cce..42ea47ee8 100644 --- a/actix-http/src/client/error.rs +++ b/actix-http/src/client/error.rs @@ -1,14 +1,13 @@ use std::io; +use actix_connect::resolver::ResolveError; use derive_more::{Display, From}; -use trust_dns_resolver::error::ResolveError; -#[cfg(feature = "ssl")] -use openssl::ssl::{Error as SslError, HandshakeError}; +#[cfg(feature = "openssl")] +use actix_connect::ssl::openssl::{HandshakeError, SslError}; use crate::error::{Error, ParseError, ResponseError}; -use crate::http::Error as HttpError; -use crate::response::Response; +use crate::http::{Error as HttpError, StatusCode}; /// A set of errors that can occur while connecting to an HTTP host #[derive(Debug, Display, From)] @@ -18,10 +17,15 @@ pub enum ConnectError { SslIsNotSupported, /// SSL error - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] #[display(fmt = "{}", _0)] SslError(SslError), + /// SSL Handshake error + #[cfg(feature = "openssl")] + #[display(fmt = "{}", _0)] + SslHandshakeError(String), + /// Failed to resolve the hostname #[display(fmt = "Failed resolving hostname: {}", _0)] Resolver(ResolveError), @@ -63,14 +67,10 @@ impl From for ConnectError { } } -#[cfg(feature = "ssl")] -impl From> for ConnectError { +#[cfg(feature = "openssl")] +impl From> for ConnectError { fn from(err: HandshakeError) -> ConnectError { - match err { - HandshakeError::SetupFailure(stack) => SslError::from(stack).into(), - HandshakeError::Failure(stream) => stream.into_error().into(), - HandshakeError::WouldBlock(stream) => stream.into_error().into(), - } + ConnectError::SslHandshakeError(format!("{:?}", err)) } } @@ -117,15 +117,14 @@ pub enum SendRequestError { /// Convert `SendRequestError` to a server `Response` impl ResponseError for SendRequestError { - fn error_response(&self) -> Response { + fn status_code(&self) -> StatusCode { match *self { SendRequestError::Connect(ConnectError::Timeout) => { - Response::GatewayTimeout() + StatusCode::GATEWAY_TIMEOUT } - SendRequestError::Connect(_) => Response::BadGateway(), - _ => Response::InternalServerError(), + SendRequestError::Connect(_) => StatusCode::BAD_REQUEST, + _ => StatusCode::INTERNAL_SERVER_ERROR, } - .into() } } @@ -147,4 +146,4 @@ impl From for SendRequestError { FreezeRequestError::Http(e) => e.into(), } } -} \ No newline at end of file +} diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs index fa920ab92..a0a20edf6 100644 --- a/actix-http/src/client/h1proto.rs +++ b/actix-http/src/client/h1proto.rs @@ -1,36 +1,42 @@ use std::io::Write; -use std::{io, time}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{io, mem, time}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use bytes::{BufMut, Bytes, BytesMut}; -use futures::future::{ok, Either}; -use futures::{Async, Future, Poll, Sink, Stream}; +use bytes::buf::BufMutExt; +use bytes::{Bytes, BytesMut}; +use futures_core::Stream; +use futures_util::future::poll_fn; +use futures_util::{SinkExt, StreamExt}; use crate::error::PayloadError; use crate::h1; +use crate::header::HeaderMap; use crate::http::header::{IntoHeaderValue, HOST}; use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::{Payload, PayloadStream}; -use crate::header::HeaderMap; use super::connection::{ConnectionLifetime, ConnectionType, IoConnection}; use super::error::{ConnectError, SendRequestError}; use super::pool::Acquired; use crate::body::{BodySize, MessageBody}; -pub(crate) fn send_request( +pub(crate) async fn send_request( io: T, mut head: RequestHeadType, body: B, created: time::Instant, pool: Option>, -) -> impl Future +) -> Result<(ResponseHead, Payload), SendRequestError> where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, B: MessageBody, { // set request host header - if !head.as_ref().headers.contains_key(HOST) && !head.extra_headers().iter().any(|h| h.contains_key(HOST)) { + if !head.as_ref().headers.contains_key(HOST) + && !head.extra_headers().iter().any(|h| h.contains_key(HOST)) + { if let Some(host) = head.as_ref().uri.host() { let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); @@ -39,21 +45,17 @@ where Some(port) => write!(wrt, "{}:{}", host, port), }; - match wrt.get_mut().take().freeze().try_into() { - Ok(value) => { - match head { - RequestHeadType::Owned(ref mut head) => { - head.headers.insert(HOST, value) - }, - RequestHeadType::Rc(_, ref mut extra_headers) => { - let headers = extra_headers.get_or_insert(HeaderMap::new()); - headers.insert(HOST, value) - }, + match wrt.get_mut().split().freeze().try_into() { + Ok(value) => match head { + RequestHeadType::Owned(ref mut head) => { + head.headers.insert(HOST, value) } - } - Err(e) => { - log::error!("Can not set HOST header {}", e) - } + RequestHeadType::Rc(_, ref mut extra_headers) => { + let headers = extra_headers.get_or_insert(HeaderMap::new()); + headers.insert(HOST, value) + } + }, + Err(e) => log::error!("Can not set HOST header {}", e), } } } @@ -64,68 +66,99 @@ where io: Some(io), }; - let len = body.size(); - // create Framed and send request - Framed::new(io, h1::ClientCodec::default()) - .send((head, len).into()) - .from_err() - // send request body - .and_then(move |framed| match body.size() { - BodySize::None | BodySize::Empty | BodySize::Sized(0) => { - Either::A(ok(framed)) - } - _ => Either::B(SendBody::new(body, framed)), - }) - // read response and init read body - .and_then(|framed| { - framed - .into_future() - .map_err(|(e, _)| SendRequestError::from(e)) - .and_then(|(item, framed)| { - if let Some(res) = item { - match framed.get_codec().message_type() { - h1::MessageType::None => { - let force_close = !framed.get_codec().keepalive(); - release_connection(framed, force_close); - Ok((res, Payload::None)) - } - _ => { - let pl: PayloadStream = Box::new(PlStream::new(framed)); - Ok((res, pl.into())) - } - } - } else { - Err(ConnectError::Disconnected.into()) - } - }) - }) + let mut framed = Framed::new(io, h1::ClientCodec::default()); + framed.send((head, body.size()).into()).await?; + + // send request body + match body.size() { + BodySize::None | BodySize::Empty | BodySize::Sized(0) => (), + _ => send_body(body, &mut framed).await?, + }; + + // read response and init read body + let res = framed.into_future().await; + let (head, framed) = if let (Some(result), framed) = res { + let item = result.map_err(SendRequestError::from)?; + (item, framed) + } else { + return Err(SendRequestError::from(ConnectError::Disconnected)); + }; + + match framed.get_codec().message_type() { + h1::MessageType::None => { + let force_close = !framed.get_codec().keepalive(); + release_connection(framed, force_close); + Ok((head, Payload::None)) + } + _ => { + let pl: PayloadStream = PlStream::new(framed).boxed_local(); + Ok((head, pl.into())) + } + } } -pub(crate) fn open_tunnel( +pub(crate) async fn open_tunnel( io: T, head: RequestHeadType, -) -> impl Future), Error = SendRequestError> +) -> Result<(ResponseHead, Framed), SendRequestError> where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, { // create Framed and send request - Framed::new(io, h1::ClientCodec::default()) - .send((head, BodySize::None).into()) - .from_err() - // read response - .and_then(|framed| { - framed - .into_future() - .map_err(|(e, _)| SendRequestError::from(e)) - .and_then(|(head, framed)| { - if let Some(head) = head { - Ok((head, framed)) + let mut framed = Framed::new(io, h1::ClientCodec::default()); + framed.send((head, BodySize::None).into()).await?; + + // read response + if let (Some(result), framed) = framed.into_future().await { + let head = result.map_err(SendRequestError::from)?; + Ok((head, framed)) + } else { + Err(SendRequestError::from(ConnectError::Disconnected)) + } +} + +/// send request body to the peer +pub(crate) async fn send_body( + mut body: B, + framed: &mut Framed, +) -> Result<(), SendRequestError> +where + I: ConnectionLifetime, + B: MessageBody, +{ + let mut eof = false; + while !eof { + while !eof && !framed.is_write_buf_full() { + match poll_fn(|cx| body.poll_next(cx)).await { + Some(result) => { + framed.write(h1::Message::Chunk(Some(result?)))?; + } + None => { + eof = true; + framed.write(h1::Message::Chunk(None))?; + } + } + } + + if !framed.is_write_buf_empty() { + poll_fn(|cx| match framed.flush(cx) { + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => { + if !framed.is_write_buf_full() { + Poll::Ready(Ok(())) } else { - Err(SendRequestError::from(ConnectError::Disconnected)) + Poll::Pending } - }) - }) + } + }) + .await?; + } + } + + SinkExt::flush(framed).await?; + Ok(()) } #[doc(hidden)] @@ -136,7 +169,10 @@ pub struct H1Connection { pool: Option>, } -impl ConnectionLifetime for H1Connection { +impl ConnectionLifetime for H1Connection +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ /// Close connection fn close(&mut self) { if let Some(mut pool) = self.pool.take() { @@ -164,98 +200,44 @@ impl ConnectionLifetime for H1Connection } } -impl io::Read for H1Connection { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.io.as_mut().unwrap().read(buf) +impl AsyncRead for H1Connection { + unsafe fn prepare_uninitialized_buffer( + &self, + buf: &mut [mem::MaybeUninit], + ) -> bool { + self.io.as_ref().unwrap().prepare_uninitialized_buffer(buf) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf) } } -impl AsyncRead for H1Connection {} - -impl io::Write for H1Connection { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.io.as_mut().unwrap().write(buf) +impl AsyncWrite for H1Connection { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.io.as_mut().unwrap()).poll_write(cx, buf) } - fn flush(&mut self) -> io::Result<()> { - self.io.as_mut().unwrap().flush() + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.io.as_mut().unwrap()).poll_flush(cx) } -} -impl AsyncWrite for H1Connection { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.io.as_mut().unwrap().shutdown() - } -} - -/// Future responsible for sending request body to the peer -pub(crate) struct SendBody { - body: Option, - framed: Option>, - flushed: bool, -} - -impl SendBody -where - I: AsyncRead + AsyncWrite + 'static, - B: MessageBody, -{ - pub(crate) fn new(body: B, framed: Framed) -> Self { - SendBody { - body: Some(body), - framed: Some(framed), - flushed: true, - } - } -} - -impl Future for SendBody -where - I: ConnectionLifetime, - B: MessageBody, -{ - type Item = Framed; - type Error = SendRequestError; - - fn poll(&mut self) -> Poll { - let mut body_ready = true; - loop { - while body_ready - && self.body.is_some() - && !self.framed.as_ref().unwrap().is_write_buf_full() - { - match self.body.as_mut().unwrap().poll_next()? { - Async::Ready(item) => { - // check if body is done - if item.is_none() { - let _ = self.body.take(); - } - self.flushed = false; - self.framed - .as_mut() - .unwrap() - .force_send(h1::Message::Chunk(item))?; - break; - } - Async::NotReady => body_ready = false, - } - } - - if !self.flushed { - match self.framed.as_mut().unwrap().poll_complete()? { - Async::Ready(_) => { - self.flushed = true; - continue; - } - Async::NotReady => return Ok(Async::NotReady), - } - } - - if self.body.is_none() { - return Ok(Async::Ready(self.framed.take().unwrap())); - } - return Ok(Async::NotReady); - } + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx) } } @@ -272,23 +254,27 @@ impl PlStream { } impl Stream for PlStream { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - match self.framed.as_mut().unwrap().poll()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(Some(chunk)) => { + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + + match this.framed.as_mut().unwrap().next_item(cx)? { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(chunk)) => { if let Some(chunk) = chunk { - Ok(Async::Ready(Some(chunk))) + Poll::Ready(Some(Ok(chunk))) } else { - let framed = self.framed.take().unwrap(); + let framed = this.framed.take().unwrap(); let force_close = !framed.get_codec().keepalive(); release_connection(framed, force_close); - Ok(Async::Ready(None)) + Poll::Ready(None) } } - Async::Ready(None) => Ok(Async::Ready(None)), + Poll::Ready(None) => Poll::Ready(None), } } } diff --git a/actix-http/src/client/h2proto.rs b/actix-http/src/client/h2proto.rs index 2993d89d8..eabf54e97 100644 --- a/actix-http/src/client/h2proto.rs +++ b/actix-http/src/client/h2proto.rs @@ -1,31 +1,31 @@ +use std::convert::TryFrom; use std::time; use actix_codec::{AsyncRead, AsyncWrite}; use bytes::Bytes; -use futures::future::{err, Either}; -use futures::{Async, Future, Poll}; +use futures_util::future::poll_fn; use h2::{client::SendRequest, SendStream}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; -use http::{request::Request, HttpTryFrom, Method, Version}; +use http::{request::Request, Method, Version}; use crate::body::{BodySize, MessageBody}; +use crate::header::HeaderMap; use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::Payload; -use crate::header::HeaderMap; use super::connection::{ConnectionType, IoConnection}; use super::error::SendRequestError; use super::pool::Acquired; -pub(crate) fn send_request( - io: SendRequest, +pub(crate) async fn send_request( + mut io: SendRequest, head: RequestHeadType, body: B, created: time::Instant, pool: Option>, -) -> impl Future +) -> Result<(ResponseHead, Payload), SendRequestError> where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, B: MessageBody, { trace!("Sending client request: {:?} {:?}", head, body.size()); @@ -36,152 +36,140 @@ where _ => false, }; - io.ready() - .map_err(SendRequestError::from) - .and_then(move |mut io| { - let mut req = Request::new(()); - *req.uri_mut() = head.as_ref().uri.clone(); - *req.method_mut() = head.as_ref().method.clone(); - *req.version_mut() = Version::HTTP_2; + let mut req = Request::new(()); + *req.uri_mut() = head.as_ref().uri.clone(); + *req.method_mut() = head.as_ref().method.clone(); + *req.version_mut() = Version::HTTP_2; - let mut skip_len = true; - // let mut has_date = false; + let mut skip_len = true; + // let mut has_date = false; - // Content length - let _ = match length { - BodySize::None => None, - BodySize::Stream => { - skip_len = false; - None - } - BodySize::Empty => req - .headers_mut() - .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), - BodySize::Sized(len) => req.headers_mut().insert( - CONTENT_LENGTH, - HeaderValue::try_from(format!("{}", len)).unwrap(), - ), - BodySize::Sized64(len) => req.headers_mut().insert( - CONTENT_LENGTH, - HeaderValue::try_from(format!("{}", len)).unwrap(), - ), - }; + // Content length + let _ = match length { + BodySize::None => None, + BodySize::Stream => { + skip_len = false; + None + } + BodySize::Empty => req + .headers_mut() + .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), + BodySize::Sized(len) => req.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + BodySize::Sized64(len) => req.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + }; - // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. - let (head, extra_headers) = match head { - RequestHeadType::Owned(head) => (RequestHeadType::Owned(head), HeaderMap::new()), - RequestHeadType::Rc(head, extra_headers) => (RequestHeadType::Rc(head, None), extra_headers.unwrap_or(HeaderMap::new())), - }; + // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. + let (head, extra_headers) = match head { + RequestHeadType::Owned(head) => (RequestHeadType::Owned(head), HeaderMap::new()), + RequestHeadType::Rc(head, extra_headers) => ( + RequestHeadType::Rc(head, None), + extra_headers.unwrap_or_else(HeaderMap::new), + ), + }; - // merging headers from head and extra headers. - let headers = head.as_ref().headers.iter() - .filter(|(name, _)| { - !extra_headers.contains_key(*name) - }) - .chain(extra_headers.iter()); + // merging headers from head and extra headers. + let headers = head + .as_ref() + .headers + .iter() + .filter(|(name, _)| !extra_headers.contains_key(*name)) + .chain(extra_headers.iter()); - // copy headers - for (key, value) in headers { - match *key { - CONNECTION | TRANSFER_ENCODING => continue, // http2 specific - CONTENT_LENGTH if skip_len => continue, - // DATE => has_date = true, - _ => (), - } - req.headers_mut().append(key, value.clone()); + // copy headers + for (key, value) in headers { + match *key { + CONNECTION | TRANSFER_ENCODING => continue, // http2 specific + CONTENT_LENGTH if skip_len => continue, + // DATE => has_date = true, + _ => (), + } + req.headers_mut().append(key, value.clone()); + } + + let res = poll_fn(|cx| io.poll_ready(cx)).await; + if let Err(e) = res { + release(io, pool, created, e.is_io()); + return Err(SendRequestError::from(e)); + } + + let resp = match io.send_request(req, eof) { + Ok((fut, send)) => { + release(io, pool, created, false); + + if !eof { + send_body(body, send).await?; } + fut.await.map_err(SendRequestError::from)? + } + Err(e) => { + release(io, pool, created, e.is_io()); + return Err(e.into()); + } + }; - match io.send_request(req, eof) { - Ok((res, send)) => { - release(io, pool, created, false); + let (parts, body) = resp.into_parts(); + let payload = if head_req { Payload::None } else { body.into() }; - if !eof { - Either::A(Either::B( - SendBody { - body, - send, - buf: None, - } - .and_then(move |_| res.map_err(SendRequestError::from)), - )) - } else { - Either::B(res.map_err(SendRequestError::from)) - } - } - Err(e) => { - release(io, pool, created, e.is_io()); - Either::A(Either::A(err(e.into()))) - } - } - }) - .and_then(move |resp| { - let (parts, body) = resp.into_parts(); - let payload = if head_req { Payload::None } else { body.into() }; - - let mut head = ResponseHead::new(parts.status); - head.version = parts.version; - head.headers = parts.headers.into(); - Ok((head, payload)) - }) - .from_err() + let mut head = ResponseHead::new(parts.status); + head.version = parts.version; + head.headers = parts.headers.into(); + Ok((head, payload)) } -struct SendBody { - body: B, - send: SendStream, - buf: Option, -} - -impl Future for SendBody { - type Item = (); - type Error = SendRequestError; - - fn poll(&mut self) -> Poll { - loop { - if self.buf.is_none() { - match self.body.poll_next() { - Ok(Async::Ready(Some(buf))) => { - self.send.reserve_capacity(buf.len()); - self.buf = Some(buf); - } - Ok(Async::Ready(None)) => { - if let Err(e) = self.send.send_data(Bytes::new(), true) { - return Err(e.into()); - } - self.send.reserve_capacity(0); - return Ok(Async::Ready(())); - } - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(e) => return Err(e.into()), +async fn send_body( + mut body: B, + mut send: SendStream, +) -> Result<(), SendRequestError> { + let mut buf = None; + loop { + if buf.is_none() { + match poll_fn(|cx| body.poll_next(cx)).await { + Some(Ok(b)) => { + send.reserve_capacity(b.len()); + buf = Some(b); } - } - - match self.send.poll_capacity() { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(None)) => return Ok(Async::Ready(())), - Ok(Async::Ready(Some(cap))) => { - let mut buf = self.buf.take().unwrap(); - let len = buf.len(); - let bytes = buf.split_to(std::cmp::min(cap, len)); - - if let Err(e) = self.send.send_data(bytes, false) { + Some(Err(e)) => return Err(e.into()), + None => { + if let Err(e) = send.send_data(Bytes::new(), true) { return Err(e.into()); - } else { - if !buf.is_empty() { - self.send.reserve_capacity(buf.len()); - self.buf = Some(buf); - } - continue; } + send.reserve_capacity(0); + return Ok(()); } - Err(e) => return Err(e.into()), } } + + match poll_fn(|cx| send.poll_capacity(cx)).await { + None => return Ok(()), + Some(Ok(cap)) => { + let b = buf.as_mut().unwrap(); + let len = b.len(); + let bytes = b.split_to(std::cmp::min(cap, len)); + + if let Err(e) = send.send_data(bytes, false) { + return Err(e.into()); + } else { + if !b.is_empty() { + send.reserve_capacity(b.len()); + } else { + buf = None; + } + continue; + } + } + Some(Err(e)) => return Err(e.into()), + } } } // release SendRequest object -fn release( +fn release( io: SendRequest, pool: Option>, created: time::Instant, diff --git a/actix-http/src/client/mod.rs b/actix-http/src/client/mod.rs index 04427ce42..a45aebcd5 100644 --- a/actix-http/src/client/mod.rs +++ b/actix-http/src/client/mod.rs @@ -10,7 +10,7 @@ mod pool; pub use self::connection::Connection; pub use self::connector::Connector; -pub use self::error::{ConnectError, InvalidUrl, SendRequestError, FreezeRequestError}; +pub use self::error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}; pub use self::pool::Protocol; #[derive(Clone)] diff --git a/actix-http/src/client/pool.rs b/actix-http/src/client/pool.rs index 24a187392..acf76559a 100644 --- a/actix-http/src/client/pool.rs +++ b/actix-http/src/client/pool.rs @@ -1,22 +1,22 @@ use std::cell::RefCell; use std::collections::VecDeque; -use std::io; +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use actix_codec::{AsyncRead, AsyncWrite}; +use actix_rt::time::{delay_for, Delay}; use actix_service::Service; +use actix_utils::{oneshot, task::LocalWaker}; use bytes::Bytes; -use futures::future::{err, ok, Either, FutureResult}; -use futures::task::AtomicTask; -use futures::unsync::oneshot; -use futures::{Async, Future, Poll}; -use h2::client::{handshake, Handshake}; -use hashbrown::HashMap; +use futures_util::future::{poll_fn, FutureExt, LocalBoxFuture}; +use fxhash::FxHashMap; +use h2::client::{handshake, Connection, SendRequest}; use http::uri::Authority; use indexmap::IndexSet; use slab::Slab; -use tokio_timer::{sleep, Delay}; use super::connection::{ConnectionType, IoConnection}; use super::error::ConnectError; @@ -41,16 +41,12 @@ impl From for Key { } /// Connections pool -pub(crate) struct ConnectionPool( - T, - Rc>>, -); +pub(crate) struct ConnectionPool(Rc>, Rc>>); impl ConnectionPool where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { pub(crate) fn new( @@ -61,7 +57,7 @@ where limit: usize, ) -> Self { ConnectionPool( - connector, + Rc::new(RefCell::new(connector)), Rc::new(RefCell::new(Inner { conn_lifetime, conn_keep_alive, @@ -70,8 +66,8 @@ where acquired: 0, waiters: Slab::new(), waiters_queue: IndexSet::new(), - available: HashMap::new(), - task: None, + available: FxHashMap::default(), + waker: LocalWaker::new(), })), ) } @@ -79,8 +75,7 @@ where impl Clone for ConnectionPool where - T: Clone, - Io: AsyncRead + AsyncWrite + 'static, + Io: 'static, { fn clone(&self) -> Self { ConnectionPool(self.0.clone(), self.1.clone()) @@ -89,86 +84,116 @@ where impl Service for ConnectionPool where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + 'static, { type Request = Connect; type Response = IoConnection; type Error = ConnectError; - type Future = Either< - FutureResult, - Either, OpenConnection>, - >; + type Future = LocalBoxFuture<'static, Result, ConnectError>>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.0.poll_ready() + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) } fn call(&mut self, req: Connect) -> Self::Future { - let key = if let Some(authority) = req.uri.authority_part() { - authority.clone().into() - } else { - return Either::A(err(ConnectError::Unresolverd)); + // start support future + actix_rt::spawn(ConnectorPoolSupport { + connector: self.0.clone(), + inner: self.1.clone(), + }); + + let mut connector = self.0.clone(); + let inner = self.1.clone(); + + let fut = async move { + let key = if let Some(authority) = req.uri.authority() { + authority.clone().into() + } else { + return Err(ConnectError::Unresolverd); + }; + + // acquire connection + match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await { + Acquire::Acquired(io, created) => { + // use existing connection + return Ok(IoConnection::new( + io, + created, + Some(Acquired(key, Some(inner))), + )); + } + Acquire::Available => { + // open tcp connection + let (io, proto) = connector.call(req).await?; + + let guard = OpenGuard::new(key, inner); + + if proto == Protocol::Http1 { + Ok(IoConnection::new( + ConnectionType::H1(io), + Instant::now(), + Some(guard.consume()), + )) + } else { + let (snd, connection) = handshake(io).await?; + actix_rt::spawn(connection.map(|_| ())); + Ok(IoConnection::new( + ConnectionType::H2(snd), + Instant::now(), + Some(guard.consume()), + )) + } + } + _ => { + // connection is not available, wait + let (rx, token) = inner.borrow_mut().wait_for(req); + + let guard = WaiterGuard::new(key, token, inner); + let res = match rx.await { + Err(_) => Err(ConnectError::Disconnected), + Ok(res) => res, + }; + guard.consume(); + res + } + } }; - // acquire connection - match self.1.as_ref().borrow_mut().acquire(&key) { - Acquire::Acquired(io, created) => { - // use existing connection - return Either::A(ok(IoConnection::new( - io, - created, - Some(Acquired(key, Some(self.1.clone()))), - ))); - } - Acquire::Available => { - // open new connection - return Either::B(Either::B(OpenConnection::new( - key, - self.1.clone(), - self.0.call(req), - ))); - } - _ => (), - } - - // connection is not available, wait - let (rx, token, support) = self.1.as_ref().borrow_mut().wait_for(req); - - // start support future - if !support { - self.1.as_ref().borrow_mut().task = Some(AtomicTask::new()); - tokio_current_thread::spawn(ConnectorPoolSupport { - connector: self.0.clone(), - inner: self.1.clone(), - }) - } - - Either::B(Either::A(WaitForConnection { - rx, - key, - token, - inner: Some(self.1.clone()), - })) + fut.boxed_local() } } -#[doc(hidden)] -pub struct WaitForConnection +struct WaiterGuard where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { key: Key, token: usize, - rx: oneshot::Receiver, ConnectError>>, inner: Option>>>, } -impl Drop for WaitForConnection +impl WaiterGuard where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn new(key: Key, token: usize, inner: Rc>>) -> Self { + Self { + key, + token, + inner: Some(inner), + } + } + + fn consume(mut self) { + let _ = self.inner.take(); + } +} + +impl Drop for WaiterGuard +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fn drop(&mut self) { if let Some(i) = self.inner.take() { @@ -179,113 +204,43 @@ where } } -impl Future for WaitForConnection +struct OpenGuard where - Io: AsyncRead + AsyncWrite, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { - type Item = IoConnection; - type Error = ConnectError; - - fn poll(&mut self) -> Poll { - match self.rx.poll() { - Ok(Async::Ready(item)) => match item { - Err(err) => Err(err), - Ok(conn) => { - let _ = self.inner.take(); - Ok(Async::Ready(conn)) - } - }, - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(_) => { - let _ = self.inner.take(); - Err(ConnectError::Disconnected) - } - } - } -} - -#[doc(hidden)] -pub struct OpenConnection -where - Io: AsyncRead + AsyncWrite + 'static, -{ - fut: F, key: Key, - h2: Option>, inner: Option>>>, } -impl OpenConnection +impl OpenGuard where - F: Future, - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { - fn new(key: Key, inner: Rc>>, fut: F) -> Self { - OpenConnection { + fn new(key: Key, inner: Rc>>) -> Self { + Self { key, - fut, inner: Some(inner), - h2: None, } } + + fn consume(mut self) -> Acquired { + Acquired(self.key.clone(), self.inner.take()) + } } -impl Drop for OpenConnection +impl Drop for OpenGuard where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fn drop(&mut self) { - if let Some(inner) = self.inner.take() { - let mut inner = inner.as_ref().borrow_mut(); + if let Some(i) = self.inner.take() { + let mut inner = i.as_ref().borrow_mut(); inner.release(); inner.check_availibility(); } } } -impl Future for OpenConnection -where - F: Future, - Io: AsyncRead + AsyncWrite, -{ - type Item = IoConnection; - type Error = ConnectError; - - fn poll(&mut self) -> Poll { - if let Some(ref mut h2) = self.h2 { - return match h2.poll() { - Ok(Async::Ready((snd, connection))) => { - tokio_current_thread::spawn(connection.map_err(|_| ())); - Ok(Async::Ready(IoConnection::new( - ConnectionType::H2(snd), - Instant::now(), - Some(Acquired(self.key.clone(), self.inner.take())), - ))) - } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => Err(e.into()), - }; - } - - match self.fut.poll() { - Err(err) => Err(err), - Ok(Async::Ready((io, proto))) => { - if proto == Protocol::Http1 { - Ok(Async::Ready(IoConnection::new( - ConnectionType::H1(io), - Instant::now(), - Some(Acquired(self.key.clone(), self.inner.take())), - ))) - } else { - self.h2 = Some(handshake(io)); - self.poll() - } - } - Ok(Async::NotReady) => Ok(Async::NotReady), - } - } -} - enum Acquire { Acquired(ConnectionType, Instant), Available, @@ -304,7 +259,7 @@ pub(crate) struct Inner { disconnect_timeout: Option, limit: usize, acquired: usize, - available: HashMap>>, + available: FxHashMap>>, waiters: Slab< Option<( Connect, @@ -312,7 +267,7 @@ pub(crate) struct Inner { )>, >, waiters_queue: IndexSet<(Key, usize)>, - task: Option, + waker: LocalWaker, } impl Inner { @@ -326,13 +281,13 @@ impl Inner { fn release_waiter(&mut self, key: &Key, token: usize) { self.waiters.remove(token); - self.waiters_queue.remove(&(key.clone(), token)); + let _ = self.waiters_queue.shift_remove(&(key.clone(), token)); } } impl Inner where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { /// connection is not available, wait fn wait_for( @@ -341,20 +296,19 @@ where ) -> ( oneshot::Receiver, ConnectError>>, usize, - bool, ) { let (tx, rx) = oneshot::channel(); - let key: Key = connect.uri.authority_part().unwrap().clone().into(); + let key: Key = connect.uri.authority().unwrap().clone().into(); let entry = self.waiters.vacant_entry(); let token = entry.key(); entry.insert(Some((connect, tx))); assert!(self.waiters_queue.insert((key, token))); - (rx, token, self.task.is_some()) + (rx, token) } - fn acquire(&mut self, key: &Key) -> Acquire { + fn acquire(&mut self, key: &Key, cx: &mut Context<'_>) -> Acquire { // check limits if self.limit > 0 && self.acquired >= self.limit { return Acquire::NotAvailable; @@ -373,28 +327,26 @@ where { if let Some(timeout) = self.disconnect_timeout { if let ConnectionType::H1(io) = conn.io { - tokio_current_thread::spawn(CloseConnection::new( - io, timeout, - )) + actix_rt::spawn(CloseConnection::new(io, timeout)) } } } else { let mut io = conn.io; let mut buf = [0; 2]; if let ConnectionType::H1(ref mut s) = io { - match s.read(&mut buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - Ok(n) if n > 0 => { + match Pin::new(s).poll_read(cx, &mut buf) { + Poll::Pending => (), + Poll::Ready(Ok(n)) if n > 0 => { if let Some(timeout) = self.disconnect_timeout { if let ConnectionType::H1(io) = io { - tokio_current_thread::spawn( - CloseConnection::new(io, timeout), - ) + actix_rt::spawn(CloseConnection::new( + io, timeout, + )) } } continue; } - Ok(_) | Err(_) => continue, + _ => continue, } } return Acquire::Acquired(io, conn.created); @@ -421,7 +373,7 @@ where self.acquired -= 1; if let Some(timeout) = self.disconnect_timeout { if let ConnectionType::H1(io) = io { - tokio_current_thread::spawn(CloseConnection::new(io, timeout)) + actix_rt::spawn(CloseConnection::new(io, timeout)) } } self.check_availibility(); @@ -429,9 +381,7 @@ where fn check_availibility(&self) { if !self.waiters_queue.is_empty() && self.acquired < self.limit { - if let Some(t) = self.task.as_ref() { - t.notify() - } + self.waker.wake(); } } } @@ -443,29 +393,30 @@ struct CloseConnection { impl CloseConnection where - T: AsyncWrite, + T: AsyncWrite + Unpin, { fn new(io: T, timeout: Duration) -> Self { CloseConnection { io, - timeout: sleep(timeout), + timeout: delay_for(timeout), } } } impl Future for CloseConnection where - T: AsyncWrite, + T: AsyncWrite + Unpin, { - type Item = (); - type Error = (); + type Output = (); - fn poll(&mut self) -> Poll<(), ()> { - match self.timeout.poll() { - Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), - Ok(Async::NotReady) => match self.io.shutdown() { - Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), - Ok(Async::NotReady) => Ok(Async::NotReady), + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let this = self.get_mut(); + + match Pin::new(&mut this.timeout).poll(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => match Pin::new(&mut this.io).poll_shutdown(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, }, } } @@ -473,7 +424,7 @@ where struct ConnectorPoolSupport where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { connector: T, inner: Rc>>, @@ -481,16 +432,17 @@ where impl Future for ConnectorPoolSupport where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service, T::Future: 'static, { - type Item = (); - type Error = (); + type Output = (); - fn poll(&mut self) -> Poll { - let mut inner = self.inner.as_ref().borrow_mut(); - inner.task.as_ref().unwrap().register(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = unsafe { self.get_unchecked_mut() }; + + let mut inner = this.inner.as_ref().borrow_mut(); + inner.waker.register(cx.waker()); // check waiters loop { @@ -505,14 +457,14 @@ where continue; } - match inner.acquire(&key) { + match inner.acquire(&key, cx) { Acquire::NotAvailable => break, Acquire::Acquired(io, created) => { let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1; if let Err(conn) = tx.send(Ok(IoConnection::new( io, created, - Some(Acquired(key.clone(), Some(self.inner.clone()))), + Some(Acquired(key.clone(), Some(this.inner.clone()))), ))) { let (io, created) = conn.unwrap().into_inner(); inner.release_conn(&key, io, created); @@ -524,33 +476,38 @@ where OpenWaitingConnection::spawn( key.clone(), tx, - self.inner.clone(), - self.connector.call(connect), + this.inner.clone(), + this.connector.call(connect), ); } } let _ = inner.waiters_queue.swap_remove_index(0); } - Ok(Async::NotReady) + Poll::Pending } } struct OpenWaitingConnection where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fut: F, key: Key, - h2: Option>, + h2: Option< + LocalBoxFuture< + 'static, + Result<(SendRequest, Connection), h2::Error>, + >, + >, rx: Option, ConnectError>>>, inner: Option>>>, } impl OpenWaitingConnection where - F: Future + 'static, - Io: AsyncRead + AsyncWrite + 'static, + F: Future> + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fn spawn( key: Key, @@ -558,7 +515,7 @@ where inner: Rc>>, fut: F, ) { - tokio_current_thread::spawn(OpenWaitingConnection { + actix_rt::spawn(OpenWaitingConnection { key, fut, h2: None, @@ -570,7 +527,7 @@ where impl Drop for OpenWaitingConnection where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fn drop(&mut self) { if let Some(inner) = self.inner.take() { @@ -583,59 +540,60 @@ where impl Future for OpenWaitingConnection where - F: Future, - Io: AsyncRead + AsyncWrite, + F: Future>, + Io: AsyncRead + AsyncWrite + Unpin, { - type Item = (); - type Error = (); + type Output = (); - fn poll(&mut self) -> Poll { - if let Some(ref mut h2) = self.h2 { - return match h2.poll() { - Ok(Async::Ready((snd, connection))) => { - tokio_current_thread::spawn(connection.map_err(|_| ())); - let rx = self.rx.take().unwrap(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = unsafe { self.get_unchecked_mut() }; + + if let Some(ref mut h2) = this.h2 { + return match Pin::new(h2).poll(cx) { + Poll::Ready(Ok((snd, connection))) => { + actix_rt::spawn(connection.map(|_| ())); + let rx = this.rx.take().unwrap(); let _ = rx.send(Ok(IoConnection::new( ConnectionType::H2(snd), Instant::now(), - Some(Acquired(self.key.clone(), self.inner.take())), + Some(Acquired(this.key.clone(), this.inner.take())), ))); - Ok(Async::Ready(())) + Poll::Ready(()) } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => { - let _ = self.inner.take(); - if let Some(rx) = self.rx.take() { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { + let _ = this.inner.take(); + if let Some(rx) = this.rx.take() { let _ = rx.send(Err(ConnectError::H2(err))); } - Err(()) + Poll::Ready(()) } }; } - match self.fut.poll() { - Err(err) => { - let _ = self.inner.take(); - if let Some(rx) = self.rx.take() { + match unsafe { Pin::new_unchecked(&mut this.fut) }.poll(cx) { + Poll::Ready(Err(err)) => { + let _ = this.inner.take(); + if let Some(rx) = this.rx.take() { let _ = rx.send(Err(err)); } - Err(()) + Poll::Ready(()) } - Ok(Async::Ready((io, proto))) => { + Poll::Ready(Ok((io, proto))) => { if proto == Protocol::Http1 { - let rx = self.rx.take().unwrap(); + let rx = this.rx.take().unwrap(); let _ = rx.send(Ok(IoConnection::new( ConnectionType::H1(io), Instant::now(), - Some(Acquired(self.key.clone(), self.inner.take())), + Some(Acquired(this.key.clone(), this.inner.take())), ))); - Ok(Async::Ready(())) + Poll::Ready(()) } else { - self.h2 = Some(handshake(io)); - self.poll() + this.h2 = Some(handshake(io).boxed_local()); + unsafe { Pin::new_unchecked(this) }.poll(cx) } } - Ok(Async::NotReady) => Ok(Async::NotReady), + Poll::Pending => Poll::Pending, } } } @@ -644,7 +602,7 @@ pub(crate) struct Acquired(Key, Option>>>); impl Acquired where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, { pub(crate) fn close(&mut self, conn: IoConnection) { if let Some(inner) = self.1.take() { diff --git a/actix-http/src/cloneable.rs b/actix-http/src/cloneable.rs index ffc1d0611..65c6bec21 100644 --- a/actix-http/src/cloneable.rs +++ b/actix-http/src/cloneable.rs @@ -1,39 +1,33 @@ use std::cell::UnsafeCell; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_service::Service; -use futures::Poll; #[doc(hidden)] /// Service that allows to turn non-clone service to a service with `Clone` impl -pub(crate) struct CloneableService(Rc>); +pub(crate) struct CloneableService(Rc>); -impl CloneableService { - pub(crate) fn new(service: T) -> Self - where - T: Service, - { +impl CloneableService { + pub(crate) fn new(service: T) -> Self { Self(Rc::new(UnsafeCell::new(service))) } } -impl Clone for CloneableService { +impl Clone for CloneableService { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl Service for CloneableService -where - T: Service, -{ +impl Service for CloneableService { type Request = T::Request; type Response = T::Response; type Error = T::Error; type Future = T::Future; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - unsafe { &mut *self.0.as_ref().get() }.poll_ready() + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + unsafe { &mut *self.0.as_ref().get() }.poll_ready(cx) } fn call(&mut self, req: T::Request) -> Self::Future { diff --git a/actix-http/src/config.rs b/actix-http/src/config.rs index bdfecef30..be949aaef 100644 --- a/actix-http/src/config.rs +++ b/actix-http/src/config.rs @@ -1,13 +1,13 @@ use std::cell::UnsafeCell; -use std::fmt; use std::fmt::Write; use std::rc::Rc; -use std::time::{Duration, Instant}; +use std::time::Duration; +use std::{fmt, net}; +use actix_rt::time::{delay_for, delay_until, Delay, Instant}; use bytes::BytesMut; -use futures::{future, Future}; +use futures_util::{future, FutureExt}; use time; -use tokio_timer::{sleep, Delay}; // "Sun, 06 Nov 1994 08:49:37 GMT".len() const DATE_VALUE_LENGTH: usize = 29; @@ -47,6 +47,8 @@ struct Inner { client_timeout: u64, client_disconnect: u64, ka_enabled: bool, + secure: bool, + local_addr: Option, timer: DateService, } @@ -58,7 +60,7 @@ impl Clone for ServiceConfig { impl Default for ServiceConfig { fn default() -> Self { - Self::new(KeepAlive::Timeout(5), 0, 0) + Self::new(KeepAlive::Timeout(5), 0, 0, false, None) } } @@ -68,6 +70,8 @@ impl ServiceConfig { keep_alive: KeepAlive, client_timeout: u64, client_disconnect: u64, + secure: bool, + local_addr: Option, ) -> ServiceConfig { let (keep_alive, ka_enabled) = match keep_alive { KeepAlive::Timeout(val) => (val as u64, true), @@ -85,10 +89,24 @@ impl ServiceConfig { ka_enabled, client_timeout, client_disconnect, + secure, + local_addr, timer: DateService::new(), })) } + #[inline] + /// Returns true if connection is secure(https) + pub fn secure(&self) -> bool { + self.0.secure + } + + #[inline] + /// Returns the local address that this server is bound to. + pub fn local_addr(&self) -> Option { + self.0.local_addr + } + #[inline] /// Keep alive duration if configured. pub fn keep_alive(&self) -> Option { @@ -104,10 +122,10 @@ impl ServiceConfig { #[inline] /// Client timeout for first request. pub fn client_timer(&self) -> Option { - let delay = self.0.client_timeout; - if delay != 0 { - Some(Delay::new( - self.0.timer.now() + Duration::from_millis(delay), + let delay_time = self.0.client_timeout; + if delay_time != 0 { + Some(delay_until( + self.0.timer.now() + Duration::from_millis(delay_time), )) } else { None @@ -138,7 +156,7 @@ impl ServiceConfig { /// Return keep-alive timer delay is configured. pub fn keep_alive_timer(&self) -> Option { if let Some(ka) = self.0.keep_alive { - Some(Delay::new(self.0.timer.now() + ka)) + Some(delay_until(self.0.timer.now() + ka)) } else { None } @@ -242,12 +260,10 @@ impl DateService { // periodic date update let s = self.clone(); - tokio_current_thread::spawn(sleep(Duration::from_millis(500)).then( - move |_| { - s.0.reset(); - future::ok(()) - }, - )); + actix_rt::spawn(delay_for(Duration::from_millis(500)).then(move |_| { + s.0.reset(); + future::ready(()) + })); } } @@ -265,26 +281,19 @@ impl DateService { #[cfg(test)] mod tests { use super::*; - use actix_rt::System; - use futures::future; #[test] fn test_date_len() { assert_eq!(DATE_VALUE_LENGTH, "Sun, 06 Nov 1994 08:49:37 GMT".len()); } - #[test] - fn test_date() { - let mut rt = System::new("test"); - - let _ = rt.block_on(future::lazy(|| { - let settings = ServiceConfig::new(KeepAlive::Os, 0, 0); - let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); - settings.set_date(&mut buf1); - let mut buf2 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); - settings.set_date(&mut buf2); - assert_eq!(buf1, buf2); - future::ok::<_, ()>(()) - })); + #[actix_rt::test] + async fn test_date() { + let settings = ServiceConfig::new(KeepAlive::Os, 0, 0, false, None); + let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); + settings.set_date(&mut buf1); + let mut buf2 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); + settings.set_date(&mut buf2); + assert_eq!(buf1, buf2); } } diff --git a/actix-http/src/cookie/builder.rs b/actix-http/src/cookie/builder.rs index efeddbb62..f99d02b02 100644 --- a/actix-http/src/cookie/builder.rs +++ b/actix-http/src/cookie/builder.rs @@ -18,7 +18,6 @@ use super::{Cookie, SameSite}; /// ```rust /// use actix_http::cookie::Cookie; /// -/// # fn main() { /// let cookie: Cookie = Cookie::build("name", "value") /// .domain("www.rust-lang.org") /// .path("/") @@ -26,7 +25,6 @@ use super::{Cookie, SameSite}; /// .http_only(true) /// .max_age(84600) /// .finish(); -/// # } /// ``` #[derive(Debug, Clone)] pub struct CookieBuilder { @@ -65,13 +63,11 @@ impl CookieBuilder { /// ```rust /// use actix_http::cookie::Cookie; /// - /// # fn main() { /// let c = Cookie::build("foo", "bar") /// .expires(time::now()) /// .finish(); /// /// assert!(c.expires().is_some()); - /// # } /// ``` #[inline] pub fn expires(mut self, when: Tm) -> CookieBuilder { @@ -86,13 +82,11 @@ impl CookieBuilder { /// ```rust /// use actix_http::cookie::Cookie; /// - /// # fn main() { /// let c = Cookie::build("foo", "bar") /// .max_age(1800) /// .finish(); /// /// assert_eq!(c.max_age(), Some(time::Duration::seconds(30 * 60))); - /// # } /// ``` #[inline] pub fn max_age(self, seconds: i64) -> CookieBuilder { @@ -106,13 +100,11 @@ impl CookieBuilder { /// ```rust /// use actix_http::cookie::Cookie; /// - /// # fn main() { /// let c = Cookie::build("foo", "bar") /// .max_age_time(time::Duration::minutes(30)) /// .finish(); /// /// assert_eq!(c.max_age(), Some(time::Duration::seconds(30 * 60))); - /// # } /// ``` #[inline] pub fn max_age_time(mut self, value: Duration) -> CookieBuilder { @@ -222,14 +214,12 @@ impl CookieBuilder { /// use actix_http::cookie::Cookie; /// use chrono::Duration; /// - /// # fn main() { /// let c = Cookie::build("foo", "bar") /// .permanent() /// .finish(); /// /// assert_eq!(c.max_age(), Some(Duration::days(365 * 20))); /// # assert!(c.expires().is_some()); - /// # } /// ``` #[inline] pub fn permanent(mut self) -> CookieBuilder { diff --git a/actix-http/src/cookie/draft.rs b/actix-http/src/cookie/draft.rs index 362133946..a2b039121 100644 --- a/actix-http/src/cookie/draft.rs +++ b/actix-http/src/cookie/draft.rs @@ -88,7 +88,7 @@ impl SameSite { } impl fmt::Display for SameSite { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { SameSite::Strict => write!(f, "Strict"), SameSite::Lax => write!(f, "Lax"), diff --git a/actix-http/src/cookie/jar.rs b/actix-http/src/cookie/jar.rs index cc67536c6..dc2de4dfe 100644 --- a/actix-http/src/cookie/jar.rs +++ b/actix-http/src/cookie/jar.rs @@ -190,7 +190,6 @@ impl CookieJar { /// use actix_http::cookie::{CookieJar, Cookie}; /// use chrono::Duration; /// - /// # fn main() { /// let mut jar = CookieJar::new(); /// /// // Assume this cookie originally had a path of "/" and domain of "a.b". @@ -204,7 +203,6 @@ impl CookieJar { /// assert_eq!(delta.len(), 1); /// assert_eq!(delta[0].name(), "name"); /// assert_eq!(delta[0].max_age(), Some(Duration::seconds(0))); - /// # } /// ``` /// /// Removing a new cookie does not result in a _removal_ cookie: @@ -243,7 +241,6 @@ impl CookieJar { /// use actix_http::cookie::{CookieJar, Cookie}; /// use chrono::Duration; /// - /// # fn main() { /// let mut jar = CookieJar::new(); /// /// // Add an original cookie and a new cookie. @@ -261,7 +258,6 @@ impl CookieJar { /// jar.force_remove(Cookie::new("key", "value")); /// assert_eq!(jar.delta().count(), 0); /// assert_eq!(jar.iter().count(), 0); - /// # } /// ``` pub fn force_remove<'a>(&mut self, cookie: Cookie<'a>) { self.original_cookies.remove(cookie.name()); @@ -307,7 +303,7 @@ impl CookieJar { /// // Delta contains two new cookies ("new", "yac") and a removal ("name"). /// assert_eq!(jar.delta().count(), 3); /// ``` - pub fn delta(&self) -> Delta { + pub fn delta(&self) -> Delta<'_> { Delta { iter: self.delta_cookies.iter(), } @@ -343,7 +339,7 @@ impl CookieJar { /// } /// } /// ``` - pub fn iter(&self) -> Iter { + pub fn iter(&self) -> Iter<'_> { Iter { delta_cookies: self .delta_cookies @@ -386,7 +382,7 @@ impl CookieJar { /// assert!(jar.get("private").is_some()); /// ``` #[cfg(feature = "secure-cookies")] - pub fn private(&mut self, key: &Key) -> PrivateJar { + pub fn private(&mut self, key: &Key) -> PrivateJar<'_> { PrivateJar::new(self, key) } @@ -424,7 +420,7 @@ impl CookieJar { /// assert!(jar.get("signed").is_some()); /// ``` #[cfg(feature = "secure-cookies")] - pub fn signed(&mut self, key: &Key) -> SignedJar { + pub fn signed(&mut self, key: &Key) -> SignedJar<'_> { SignedJar::new(self, key) } } diff --git a/actix-http/src/cookie/mod.rs b/actix-http/src/cookie/mod.rs index db8211427..13fd5cf4e 100644 --- a/actix-http/src/cookie/mod.rs +++ b/actix-http/src/cookie/mod.rs @@ -110,7 +110,7 @@ impl CookieStr { /// # Panics /// /// Panics if `self` is an indexed string and `string` is None. - fn to_str<'s>(&'s self, string: Option<&'s Cow>) -> &'s str { + fn to_str<'s>(&'s self, string: Option<&'s Cow<'_, str>>) -> &'s str { match *self { CookieStr::Indexed(i, j) => { let s = string.expect( @@ -647,13 +647,11 @@ impl<'c> Cookie<'c> { /// use actix_http::cookie::Cookie; /// use chrono::Duration; /// - /// # fn main() { /// let mut c = Cookie::new("name", "value"); /// assert_eq!(c.max_age(), None); /// /// c.set_max_age(Duration::hours(10)); /// assert_eq!(c.max_age(), Some(Duration::hours(10))); - /// # } /// ``` #[inline] pub fn set_max_age(&mut self, value: Duration) { @@ -701,7 +699,6 @@ impl<'c> Cookie<'c> { /// ```rust /// use actix_http::cookie::Cookie; /// - /// # fn main() { /// let mut c = Cookie::new("name", "value"); /// assert_eq!(c.expires(), None); /// @@ -710,7 +707,6 @@ impl<'c> Cookie<'c> { /// /// c.set_expires(now); /// assert!(c.expires().is_some()) - /// # } /// ``` #[inline] pub fn set_expires(&mut self, time: Tm) { @@ -726,7 +722,6 @@ impl<'c> Cookie<'c> { /// use actix_http::cookie::Cookie; /// use chrono::Duration; /// - /// # fn main() { /// let mut c = Cookie::new("foo", "bar"); /// assert!(c.expires().is_none()); /// assert!(c.max_age().is_none()); @@ -734,7 +729,6 @@ impl<'c> Cookie<'c> { /// c.make_permanent(); /// assert!(c.expires().is_some()); /// assert_eq!(c.max_age(), Some(Duration::days(365 * 20))); - /// # } /// ``` pub fn make_permanent(&mut self) { let twenty_years = Duration::days(365 * 20); @@ -742,7 +736,7 @@ impl<'c> Cookie<'c> { self.set_expires(time::now() + twenty_years); } - fn fmt_parameters(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt_parameters(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Some(true) = self.http_only() { write!(f, "; HttpOnly")?; } @@ -924,10 +918,10 @@ impl<'c> Cookie<'c> { /// let mut c = Cookie::new("my name", "this; value?"); /// assert_eq!(&c.encoded().to_string(), "my%20name=this%3B%20value%3F"); /// ``` -pub struct EncodedCookie<'a, 'c: 'a>(&'a Cookie<'c>); +pub struct EncodedCookie<'a, 'c>(&'a Cookie<'c>); impl<'a, 'c: 'a> fmt::Display for EncodedCookie<'a, 'c> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // Percent-encode the name and value. let name = percent_encode(self.0.name().as_bytes(), USERINFO); let value = percent_encode(self.0.value().as_bytes(), USERINFO); @@ -952,7 +946,7 @@ impl<'c> fmt::Display for Cookie<'c> { /// /// assert_eq!(&cookie.to_string(), "foo=bar; Path=/"); /// ``` - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}={}", self.name(), self.value())?; self.fmt_parameters(f) } diff --git a/actix-http/src/cookie/parse.rs b/actix-http/src/cookie/parse.rs index 42a2c1fcf..20aee9507 100644 --- a/actix-http/src/cookie/parse.rs +++ b/actix-http/src/cookie/parse.rs @@ -40,7 +40,7 @@ impl ParseError { } impl fmt::Display for ParseError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.as_str()) } } @@ -51,11 +51,7 @@ impl From for ParseError { } } -impl Error for ParseError { - fn description(&self) -> &str { - self.as_str() - } -} +impl Error for ParseError {} fn indexes_of(needle: &str, haystack: &str) -> Option<(usize, usize)> { let haystack_start = haystack.as_ptr() as usize; diff --git a/actix-http/src/cookie/secure/key.rs b/actix-http/src/cookie/secure/key.rs index 39575c93f..779c16b75 100644 --- a/actix-http/src/cookie/secure/key.rs +++ b/actix-http/src/cookie/secure/key.rs @@ -1,13 +1,11 @@ -use ring::digest::{Algorithm, SHA256}; -use ring::hkdf::expand; -use ring::hmac::SigningKey; +use ring::hkdf::{Algorithm, KeyType, Prk, HKDF_SHA256}; use ring::rand::{SecureRandom, SystemRandom}; use super::private::KEY_LEN as PRIVATE_KEY_LEN; use super::signed::KEY_LEN as SIGNED_KEY_LEN; -static HKDF_DIGEST: &Algorithm = &SHA256; -const KEYS_INFO: &str = "COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM"; +static HKDF_DIGEST: Algorithm = HKDF_SHA256; +const KEYS_INFO: &[&[u8]] = &[b"COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM"]; /// A cryptographic master key for use with `Signed` and/or `Private` jars. /// @@ -25,6 +23,13 @@ pub struct Key { encryption_key: [u8; PRIVATE_KEY_LEN], } +impl KeyType for &Key { + #[inline] + fn len(&self) -> usize { + SIGNED_KEY_LEN + PRIVATE_KEY_LEN + } +} + impl Key { /// Derives new signing/encryption keys from a master key. /// @@ -56,21 +61,26 @@ impl Key { ); } - // Expand the user's key into two. - let prk = SigningKey::new(HKDF_DIGEST, key); + // An empty `Key` structure; will be filled in with HKDF derived keys. + let mut output_key = Key { + signing_key: [0; SIGNED_KEY_LEN], + encryption_key: [0; PRIVATE_KEY_LEN], + }; + + // Expand the master key into two HKDF generated keys. let mut both_keys = [0; SIGNED_KEY_LEN + PRIVATE_KEY_LEN]; - expand(&prk, KEYS_INFO.as_bytes(), &mut both_keys); + let prk = Prk::new_less_safe(HKDF_DIGEST, key); + let okm = prk.expand(KEYS_INFO, &output_key).expect("okm expand"); + okm.fill(&mut both_keys).expect("fill keys"); - // Copy the keys into their respective arrays. - let mut signing_key = [0; SIGNED_KEY_LEN]; - let mut encryption_key = [0; PRIVATE_KEY_LEN]; - signing_key.copy_from_slice(&both_keys[..SIGNED_KEY_LEN]); - encryption_key.copy_from_slice(&both_keys[SIGNED_KEY_LEN..]); - - Key { - signing_key, - encryption_key, - } + // Copy the key parts into their respective fields. + output_key + .signing_key + .copy_from_slice(&both_keys[..SIGNED_KEY_LEN]); + output_key + .encryption_key + .copy_from_slice(&both_keys[SIGNED_KEY_LEN..]); + output_key } /// Generates signing/encryption keys from a secure, random source. Keys are diff --git a/actix-http/src/cookie/secure/private.rs b/actix-http/src/cookie/secure/private.rs index eb8e9beb1..f05e23800 100644 --- a/actix-http/src/cookie/secure/private.rs +++ b/actix-http/src/cookie/secure/private.rs @@ -1,8 +1,8 @@ use std::str; use log::warn; -use ring::aead::{open_in_place, seal_in_place, Aad, Algorithm, Nonce, AES_256_GCM}; -use ring::aead::{OpeningKey, SealingKey}; +use ring::aead::{Aad, Algorithm, Nonce, AES_256_GCM}; +use ring::aead::{LessSafeKey, UnboundKey}; use ring::rand::{SecureRandom, SystemRandom}; use super::Key; @@ -53,11 +53,14 @@ impl<'a> PrivateJar<'a> { } let ad = Aad::from(name.as_bytes()); - let key = OpeningKey::new(ALGO, &self.key).expect("opening key"); - let (nonce, sealed) = data.split_at_mut(NONCE_LEN); + let key = LessSafeKey::new( + UnboundKey::new(&ALGO, &self.key).expect("matching key length"), + ); + let (nonce, mut sealed) = data.split_at_mut(NONCE_LEN); let nonce = Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`"); - let unsealed = open_in_place(&key, nonce, ad, 0, sealed) + let unsealed = key + .open_in_place(nonce, ad, &mut sealed) .map_err(|_| "invalid key/nonce/value: bad seal")?; if let Ok(unsealed_utf8) = str::from_utf8(unsealed) { @@ -156,7 +159,7 @@ Please change it as soon as possible." /// Encrypts the cookie's value with /// authenticated encryption assuring confidentiality, integrity, and authenticity. - fn encrypt_cookie(&self, cookie: &mut Cookie) { + fn encrypt_cookie(&self, cookie: &mut Cookie<'_>) { let name = cookie.name().as_bytes(); let value = cookie.value().as_bytes(); let data = encrypt_name_value(name, value, &self.key); @@ -196,30 +199,33 @@ Please change it as soon as possible." fn encrypt_name_value(name: &[u8], value: &[u8], key: &[u8]) -> Vec { // Create the `SealingKey` structure. - let key = SealingKey::new(ALGO, key).expect("sealing key creation"); + let unbound = UnboundKey::new(&ALGO, key).expect("matching key length"); + let key = LessSafeKey::new(unbound); // Create a vec to hold the [nonce | cookie value | overhead]. - let overhead = ALGO.tag_len(); - let mut data = vec![0; NONCE_LEN + value.len() + overhead]; + let mut data = vec![0; NONCE_LEN + value.len() + ALGO.tag_len()]; // Randomly generate the nonce, then copy the cookie value as input. let (nonce, in_out) = data.split_at_mut(NONCE_LEN); + let (in_out, tag) = in_out.split_at_mut(value.len()); + in_out.copy_from_slice(value); + + // Randomly generate the nonce into the nonce piece. SystemRandom::new() .fill(nonce) .expect("couldn't random fill nonce"); - in_out[..value.len()].copy_from_slice(value); - let nonce = - Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`"); + let nonce = Nonce::try_assume_unique_for_key(nonce).expect("invalid `nonce` length"); // Use cookie's name as associated data to prevent value swapping. let ad = Aad::from(name); + let ad_tag = key + .seal_in_place_separate_tag(nonce, ad, in_out) + .expect("in-place seal"); - // Perform the actual sealing operation and get the output length. - let output_len = - seal_in_place(&key, nonce, ad, in_out, overhead).expect("in-place seal"); + // Copy the tag into the tag piece. + tag.copy_from_slice(ad_tag.as_ref()); // Remove the overhead and return the sealed content. - data.truncate(NONCE_LEN + output_len); data } diff --git a/actix-http/src/cookie/secure/signed.rs b/actix-http/src/cookie/secure/signed.rs index 36a277cd5..64e8d5dda 100644 --- a/actix-http/src/cookie/secure/signed.rs +++ b/actix-http/src/cookie/secure/signed.rs @@ -1,12 +1,11 @@ -use ring::digest::{Algorithm, SHA256}; -use ring::hmac::{sign, verify_with_own_key as verify, SigningKey}; +use ring::hmac::{self, sign, verify}; use super::Key; use crate::cookie::{Cookie, CookieJar}; // Keep these in sync, and keep the key len synced with the `signed` docs as // well as the `KEYS_INFO` const in secure::Key. -static HMAC_DIGEST: &Algorithm = &SHA256; +static HMAC_DIGEST: hmac::Algorithm = hmac::HMAC_SHA256; const BASE64_DIGEST_LEN: usize = 44; pub const KEY_LEN: usize = 32; @@ -21,7 +20,7 @@ pub const KEY_LEN: usize = 32; /// This type is only available when the `secure` feature is enabled. pub struct SignedJar<'a> { parent: &'a mut CookieJar, - key: SigningKey, + key: hmac::Key, } impl<'a> SignedJar<'a> { @@ -32,7 +31,7 @@ impl<'a> SignedJar<'a> { pub fn new(parent: &'a mut CookieJar, key: &Key) -> SignedJar<'a> { SignedJar { parent, - key: SigningKey::new(HMAC_DIGEST, key.signing()), + key: hmac::Key::new(HMAC_DIGEST, key.signing()), } } @@ -130,7 +129,7 @@ impl<'a> SignedJar<'a> { } /// Signs the cookie's value assuring integrity and authenticity. - fn sign_cookie(&self, cookie: &mut Cookie) { + fn sign_cookie(&self, cookie: &mut Cookie<'_>) { let digest = sign(&self.key, cookie.value().as_bytes()); let mut new_value = base64::encode(digest.as_ref()); new_value.push_str(cookie.value()); diff --git a/actix-http/src/encoding/decoder.rs b/actix-http/src/encoding/decoder.rs index 4b56a1b62..b60435859 100644 --- a/actix-http/src/encoding/decoder.rs +++ b/actix-http/src/encoding/decoder.rs @@ -1,12 +1,13 @@ +use std::future::Future; use std::io::{self, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_threadpool::{run, CpuFuture}; -#[cfg(feature = "brotli")] use brotli2::write::BrotliDecoder; use bytes::Bytes; -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] use flate2::write::{GzDecoder, ZlibDecoder}; -use futures::{try_ready, Async, Future, Poll, Stream}; +use futures_core::{ready, Stream}; use super::Writer; use crate::error::PayloadError; @@ -23,21 +24,18 @@ pub struct Decoder { impl Decoder where - S: Stream, + S: Stream>, { /// Construct a decoder. #[inline] pub fn new(stream: S, encoding: ContentEncoding) -> Decoder { let decoder = match encoding { - #[cfg(feature = "brotli")] ContentEncoding::Br => Some(ContentDecoder::Br(Box::new( BrotliDecoder::new(Writer::new()), ))), - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentEncoding::Deflate => Some(ContentDecoder::Deflate(Box::new( ZlibDecoder::new(Writer::new()), ))), - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentEncoding::Gzip => Some(ContentDecoder::Gzip(Box::new( GzDecoder::new(Writer::new()), ))), @@ -71,34 +69,40 @@ where impl Stream for Decoder where - S: Stream, + S: Stream> + Unpin, { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { loop { if let Some(ref mut fut) = self.fut { - let (chunk, decoder) = try_ready!(fut.poll()); + let (chunk, decoder) = match ready!(Pin::new(fut).poll(cx)) { + Ok(item) => item, + Err(e) => return Poll::Ready(Some(Err(e.into()))), + }; self.decoder = Some(decoder); self.fut.take(); if let Some(chunk) = chunk { - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } if self.eof { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } - match self.stream.poll()? { - Async::Ready(Some(chunk)) => { + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))), + Poll::Ready(Some(Ok(chunk))) => { if let Some(mut decoder) = self.decoder.take() { if chunk.len() < INPLACE { let chunk = decoder.feed_data(chunk)?; self.decoder = Some(decoder); if let Some(chunk) = chunk { - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } else { self.fut = Some(run(move || { @@ -108,41 +112,40 @@ where } continue; } else { - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } - Async::Ready(None) => { + Poll::Ready(None) => { self.eof = true; return if let Some(mut decoder) = self.decoder.take() { - Ok(Async::Ready(decoder.feed_eof()?)) + match decoder.feed_eof() { + Ok(Some(res)) => Poll::Ready(Some(Ok(res))), + Ok(None) => Poll::Ready(None), + Err(err) => Poll::Ready(Some(Err(err.into()))), + } } else { - Ok(Async::Ready(None)) + Poll::Ready(None) }; } - Async::NotReady => break, + Poll::Pending => break, } } - Ok(Async::NotReady) + Poll::Pending } } enum ContentDecoder { - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] Deflate(Box>), - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] Gzip(Box>), - #[cfg(feature = "brotli")] Br(Box>), } impl ContentDecoder { - #[allow(unreachable_patterns)] fn feed_eof(&mut self) -> io::Result> { match self { - #[cfg(feature = "brotli")] - ContentDecoder::Br(ref mut decoder) => match decoder.finish() { - Ok(mut writer) => { - let b = writer.take(); + ContentDecoder::Br(ref mut decoder) => match decoder.flush() { + Ok(()) => { + let b = decoder.get_mut().take(); if !b.is_empty() { Ok(Some(b)) } else { @@ -151,7 +154,6 @@ impl ContentDecoder { } Err(e) => Err(e), }, - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentDecoder::Gzip(ref mut decoder) => match decoder.try_finish() { Ok(_) => { let b = decoder.get_mut().take(); @@ -163,7 +165,6 @@ impl ContentDecoder { } Err(e) => Err(e), }, - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentDecoder::Deflate(ref mut decoder) => match decoder.try_finish() { Ok(_) => { let b = decoder.get_mut().take(); @@ -175,14 +176,11 @@ impl ContentDecoder { } Err(e) => Err(e), }, - _ => Ok(None), } } - #[allow(unreachable_patterns)] fn feed_data(&mut self, data: Bytes) -> io::Result> { match self { - #[cfg(feature = "brotli")] ContentDecoder::Br(ref mut decoder) => match decoder.write_all(&data) { Ok(_) => { decoder.flush()?; @@ -195,7 +193,6 @@ impl ContentDecoder { } Err(e) => Err(e), }, - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentDecoder::Gzip(ref mut decoder) => match decoder.write_all(&data) { Ok(_) => { decoder.flush()?; @@ -208,7 +205,6 @@ impl ContentDecoder { } Err(e) => Err(e), }, - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentDecoder::Deflate(ref mut decoder) => match decoder.write_all(&data) { Ok(_) => { decoder.flush()?; @@ -221,7 +217,6 @@ impl ContentDecoder { } Err(e) => Err(e), }, - _ => Ok(Some(data)), } } } diff --git a/actix-http/src/encoding/encoder.rs b/actix-http/src/encoding/encoder.rs index 58d8a2d9e..ca04845ab 100644 --- a/actix-http/src/encoding/encoder.rs +++ b/actix-http/src/encoding/encoder.rs @@ -1,22 +1,23 @@ //! Stream encoder +use std::future::Future; use std::io::{self, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_threadpool::{run, CpuFuture}; -#[cfg(feature = "brotli")] use brotli2::write::BrotliEncoder; use bytes::Bytes; -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] use flate2::write::{GzEncoder, ZlibEncoder}; -use futures::{Async, Future, Poll}; +use futures_core::ready; use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::http::header::{ContentEncoding, CONTENT_ENCODING}; -use crate::http::{HeaderValue, HttpTryFrom, StatusCode}; +use crate::http::{HeaderValue, StatusCode}; use crate::{Error, ResponseHead}; use super::Writer; -const INPLACE: usize = 2049; +const INPLACE: usize = 1024; pub struct Encoder { eof: bool, @@ -94,43 +95,45 @@ impl MessageBody for Encoder { } } - fn poll_next(&mut self) -> Poll, Error> { + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { loop { if self.eof { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } if let Some(ref mut fut) = self.fut { - let mut encoder = futures::try_ready!(fut.poll()); + let mut encoder = match ready!(Pin::new(fut).poll(cx)) { + Ok(item) => item, + Err(e) => return Poll::Ready(Some(Err(e.into()))), + }; let chunk = encoder.take(); self.encoder = Some(encoder); self.fut.take(); if !chunk.is_empty() { - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } let result = match self.body { EncoderBody::Bytes(ref mut b) => { if b.is_empty() { - Async::Ready(None) + Poll::Ready(None) } else { - Async::Ready(Some(std::mem::replace(b, Bytes::new()))) + Poll::Ready(Some(Ok(std::mem::replace(b, Bytes::new())))) } } - EncoderBody::Stream(ref mut b) => b.poll_next()?, - EncoderBody::BoxedStream(ref mut b) => b.poll_next()?, + EncoderBody::Stream(ref mut b) => b.poll_next(cx), + EncoderBody::BoxedStream(ref mut b) => b.poll_next(cx), }; match result { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(Some(chunk)) => { + Poll::Ready(Some(Ok(chunk))) => { if let Some(mut encoder) = self.encoder.take() { if chunk.len() < INPLACE { encoder.write(&chunk)?; let chunk = encoder.take(); self.encoder = Some(encoder); if !chunk.is_empty() { - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } else { self.fut = Some(run(move || { @@ -139,22 +142,23 @@ impl MessageBody for Encoder { })); } } else { - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } - Async::Ready(None) => { + Poll::Ready(None) => { if let Some(encoder) = self.encoder.take() { let chunk = encoder.finish()?; if chunk.is_empty() { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } else { self.eof = true; - return Ok(Async::Ready(Some(chunk))); + return Poll::Ready(Some(Ok(chunk))); } } else { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } } + val => return val, } } } @@ -163,33 +167,27 @@ impl MessageBody for Encoder { fn update_head(encoding: ContentEncoding, head: &mut ResponseHead) { head.headers_mut().insert( CONTENT_ENCODING, - HeaderValue::try_from(Bytes::from_static(encoding.as_str().as_bytes())).unwrap(), + HeaderValue::from_static(encoding.as_str()), ); } enum ContentEncoder { - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] Deflate(ZlibEncoder), - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] Gzip(GzEncoder), - #[cfg(feature = "brotli")] Br(BrotliEncoder), } impl ContentEncoder { fn encoder(encoding: ContentEncoding) -> Option { match encoding { - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentEncoding::Deflate => Some(ContentEncoder::Deflate(ZlibEncoder::new( Writer::new(), flate2::Compression::fast(), ))), - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentEncoding::Gzip => Some(ContentEncoder::Gzip(GzEncoder::new( Writer::new(), flate2::Compression::fast(), ))), - #[cfg(feature = "brotli")] ContentEncoding::Br => { Some(ContentEncoder::Br(BrotliEncoder::new(Writer::new(), 3))) } @@ -200,28 +198,22 @@ impl ContentEncoder { #[inline] pub(crate) fn take(&mut self) -> Bytes { match *self { - #[cfg(feature = "brotli")] ContentEncoder::Br(ref mut encoder) => encoder.get_mut().take(), - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().take(), - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().take(), } } fn finish(self) -> Result { match self { - #[cfg(feature = "brotli")] ContentEncoder::Br(encoder) => match encoder.finish() { Ok(writer) => Ok(writer.buf.freeze()), Err(err) => Err(err), }, - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentEncoder::Gzip(encoder) => match encoder.finish() { Ok(writer) => Ok(writer.buf.freeze()), Err(err) => Err(err), }, - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentEncoder::Deflate(encoder) => match encoder.finish() { Ok(writer) => Ok(writer.buf.freeze()), Err(err) => Err(err), @@ -231,7 +223,6 @@ impl ContentEncoder { fn write(&mut self, data: &[u8]) -> Result<(), io::Error> { match *self { - #[cfg(feature = "brotli")] ContentEncoder::Br(ref mut encoder) => match encoder.write_all(data) { Ok(_) => Ok(()), Err(err) => { @@ -239,7 +230,6 @@ impl ContentEncoder { Err(err) } }, - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentEncoder::Gzip(ref mut encoder) => match encoder.write_all(data) { Ok(_) => Ok(()), Err(err) => { @@ -247,7 +237,6 @@ impl ContentEncoder { Err(err) } }, - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] ContentEncoder::Deflate(ref mut encoder) => match encoder.write_all(data) { Ok(_) => Ok(()), Err(err) => { diff --git a/actix-http/src/encoding/mod.rs b/actix-http/src/encoding/mod.rs index b55a43a7c..9eaf4104e 100644 --- a/actix-http/src/encoding/mod.rs +++ b/actix-http/src/encoding/mod.rs @@ -19,8 +19,9 @@ impl Writer { buf: BytesMut::with_capacity(8192), } } + fn take(&mut self) -> Bytes { - self.buf.take().freeze() + self.buf.split().freeze() } } @@ -29,6 +30,7 @@ impl io::Write for Writer { self.buf.extend_from_slice(buf); Ok(buf.len()) } + fn flush(&mut self) -> io::Result<()> { Ok(()) } diff --git a/actix-http/src/error.rs b/actix-http/src/error.rs index 2c01c86db..fd0fe927f 100644 --- a/actix-http/src/error.rs +++ b/actix-http/src/error.rs @@ -6,24 +6,25 @@ use std::str::Utf8Error; use std::string::FromUtf8Error; use std::{fmt, io, result}; +use actix_codec::{Decoder, Encoder}; pub use actix_threadpool::BlockingError; +use actix_utils::framed::DispatcherError as FramedDispatcherError; use actix_utils::timeout::TimeoutError; use bytes::BytesMut; use derive_more::{Display, From}; -use futures::Canceled; +pub use futures_channel::oneshot::Canceled; use http::uri::InvalidUri; use http::{header, Error as HttpError, StatusCode}; use httparse; use serde::de::value::Error as DeError; use serde_json::error::Error as JsonError; use serde_urlencoded::ser::Error as FormError; -use tokio_timer::Error as TimerError; // re-export for convinience use crate::body::Body; pub use crate::cookie::ParseError as CookieParseError; use crate::helpers::Writer; -use crate::response::Response; +use crate::response::{Response, ResponseBuilder}; /// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html) /// for actix web operations @@ -61,21 +62,23 @@ impl Error { /// Error that can be converted to `Response` pub trait ResponseError: fmt::Debug + fmt::Display { + /// Response's status code + /// + /// Internal server error is generated by default. + fn status_code(&self) -> StatusCode { + StatusCode::INTERNAL_SERVER_ERROR + } + /// Create response for error /// /// Internal server error is generated by default. fn error_response(&self) -> Response { - Response::new(StatusCode::INTERNAL_SERVER_ERROR) - } - - /// Constructs an error response - fn render_response(&self) -> Response { - let mut resp = self.error_response(); + let mut resp = Response::new(self.status_code()); let mut buf = BytesMut::new(); let _ = write!(Writer(&mut buf), "{}", self); resp.headers_mut().insert( header::CONTENT_TYPE, - header::HeaderValue::from_static("text/plain"), + header::HeaderValue::from_static("text/plain; charset=utf-8"), ); resp.set_body(Body::from(buf)) } @@ -101,14 +104,24 @@ impl dyn ResponseError + 'static { } impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.cause, f) } } impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - writeln!(f, "{:?}", &self.cause) + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", &self.cause) + } +} + +impl std::error::Error for Error { + fn cause(&self) -> Option<&dyn std::error::Error> { + None + } + + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None } } @@ -118,17 +131,11 @@ impl From<()> for Error { } } -impl std::error::Error for Error { - fn description(&self) -> &str { - "actix-http::Error" - } - - fn cause(&self) -> Option<&dyn std::error::Error> { - None - } - - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - None +impl From for Error { + fn from(_: std::convert::Infallible) -> Self { + // `std::convert::Infallible` indicates an error + // that will never happen + unreachable!() } } @@ -148,12 +155,26 @@ impl From for Error { } } +/// Convert Response to a Error +impl From for Error { + fn from(res: Response) -> Error { + InternalError::from_response("", res).into() + } +} + +/// Convert ResponseBuilder to a Error +impl From for Error { + fn from(mut res: ResponseBuilder) -> Error { + InternalError::from_response("", res.finish()).into() + } +} + /// Return `GATEWAY_TIMEOUT` for `TimeoutError` impl ResponseError for TimeoutError { - fn error_response(&self) -> Response { + fn status_code(&self) -> StatusCode { match self { - TimeoutError::Service(e) => e.error_response(), - TimeoutError::Timeout => Response::new(StatusCode::GATEWAY_TIMEOUT), + TimeoutError::Service(e) => e.status_code(), + TimeoutError::Timeout => StatusCode::GATEWAY_TIMEOUT, } } } @@ -171,31 +192,31 @@ impl ResponseError for JsonError {} /// `InternalServerError` for `FormError` impl ResponseError for FormError {} -/// `InternalServerError` for `TimerError` -impl ResponseError for TimerError {} - -#[cfg(feature = "ssl")] +#[cfg(feature = "openssl")] /// `InternalServerError` for `openssl::ssl::Error` -impl ResponseError for openssl::ssl::Error {} +impl ResponseError for actix_connect::ssl::openssl::SslError {} -#[cfg(feature = "ssl")] +#[cfg(feature = "openssl")] /// `InternalServerError` for `openssl::ssl::HandshakeError` -impl ResponseError for openssl::ssl::HandshakeError {} +impl ResponseError for actix_tls::openssl::HandshakeError {} /// Return `BAD_REQUEST` for `de::value::Error` impl ResponseError for DeError { - fn error_response(&self) -> Response { - Response::new(StatusCode::BAD_REQUEST) + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } +/// `InternalServerError` for `Canceled` +impl ResponseError for Canceled {} + /// `InternalServerError` for `BlockingError` impl ResponseError for BlockingError {} /// Return `BAD_REQUEST` for `Utf8Error` impl ResponseError for Utf8Error { - fn error_response(&self) -> Response { - Response::new(StatusCode::BAD_REQUEST) + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } @@ -205,32 +226,22 @@ impl ResponseError for HttpError {} /// Return `InternalServerError` for `io::Error` impl ResponseError for io::Error { - fn error_response(&self) -> Response { + fn status_code(&self) -> StatusCode { match self.kind() { - io::ErrorKind::NotFound => Response::new(StatusCode::NOT_FOUND), - io::ErrorKind::PermissionDenied => Response::new(StatusCode::FORBIDDEN), - _ => Response::new(StatusCode::INTERNAL_SERVER_ERROR), + io::ErrorKind::NotFound => StatusCode::NOT_FOUND, + io::ErrorKind::PermissionDenied => StatusCode::FORBIDDEN, + _ => StatusCode::INTERNAL_SERVER_ERROR, } } } /// `BadRequest` for `InvalidHeaderValue` impl ResponseError for header::InvalidHeaderValue { - fn error_response(&self) -> Response { - Response::new(StatusCode::BAD_REQUEST) + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } -/// `BadRequest` for `InvalidHeaderValue` -impl ResponseError for header::InvalidHeaderValueBytes { - fn error_response(&self) -> Response { - Response::new(StatusCode::BAD_REQUEST) - } -} - -/// `InternalServerError` for `futures::Canceled` -impl ResponseError for Canceled {} - /// A set of errors that can occur during parsing HTTP streams #[derive(Debug, Display)] pub enum ParseError { @@ -270,8 +281,8 @@ pub enum ParseError { /// Return `BadRequest` for `ParseError` impl ResponseError for ParseError { - fn error_response(&self) -> Response { - Response::new(StatusCode::BAD_REQUEST) + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } @@ -363,7 +374,7 @@ impl From> for PayloadError { BlockingError::Error(e) => PayloadError::Io(e), BlockingError::Canceled => PayloadError::Io(io::Error::new( io::ErrorKind::Other, - "Thread pool is gone", + "Operation is canceled", )), } } @@ -374,18 +385,18 @@ impl From> for PayloadError { /// - `Overflow` returns `PayloadTooLarge` /// - Other errors returns `BadRequest` impl ResponseError for PayloadError { - fn error_response(&self) -> Response { + fn status_code(&self) -> StatusCode { match *self { - PayloadError::Overflow => Response::new(StatusCode::PAYLOAD_TOO_LARGE), - _ => Response::new(StatusCode::BAD_REQUEST), + PayloadError::Overflow => StatusCode::PAYLOAD_TOO_LARGE, + _ => StatusCode::BAD_REQUEST, } } } /// Return `BadRequest` for `cookie::ParseError` impl ResponseError for crate::cookie::ParseError { - fn error_response(&self) -> Response { - Response::new(StatusCode::BAD_REQUEST) + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } @@ -449,11 +460,19 @@ pub enum ContentTypeError { /// Return `BadRequest` for `ContentTypeError` impl ResponseError for ContentTypeError { - fn error_response(&self) -> Response { - Response::new(StatusCode::BAD_REQUEST) + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } +impl ResponseError for FramedDispatcherError +where + E: fmt::Debug + fmt::Display, + ::Error: fmt::Debug, + ::Error: fmt::Debug, +{ +} + /// Helper type that can wrap any error and generate custom response. /// /// In following example any `io::Error` will be converted into "BAD REQUEST" @@ -461,14 +480,12 @@ impl ResponseError for ContentTypeError { /// default. /// /// ```rust -/// # extern crate actix_http; /// # use std::io; /// # use actix_http::*; /// /// fn index(req: Request) -> Result<&'static str> { /// Err(error::ErrorBadRequest(io::Error::new(io::ErrorKind::Other, "error"))) /// } -/// # fn main() {} /// ``` pub struct InternalError { cause: T, @@ -502,7 +519,7 @@ impl fmt::Debug for InternalError where T: fmt::Debug + 'static, { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(&self.cause, f) } } @@ -511,7 +528,7 @@ impl fmt::Display for InternalError where T: fmt::Display + 'static, { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.cause, f) } } @@ -520,6 +537,19 @@ impl ResponseError for InternalError where T: fmt::Debug + fmt::Display + 'static, { + fn status_code(&self) -> StatusCode { + match self.status { + InternalErrorType::Status(st) => st, + InternalErrorType::Response(ref resp) => { + if let Some(resp) = resp.borrow().as_ref() { + resp.head().status + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + } + } + } + fn error_response(&self) -> Response { match self.status { InternalErrorType::Status(st) => { @@ -528,7 +558,7 @@ where let _ = write!(Writer(&mut buf), "{}", self); res.headers_mut().insert( header::CONTENT_TYPE, - header::HeaderValue::from_static("text/plain"), + header::HeaderValue::from_static("text/plain; charset=utf-8"), ); res.set_body(Body::from(buf)) } @@ -541,18 +571,6 @@ where } } } - - /// Constructs an error response - fn render_response(&self) -> Response { - self.error_response() - } -} - -/// Convert Response to a Error -impl From for Error { - fn from(res: Response) -> Error { - InternalError::from_response("", res).into() - } } /// Helper function that creates wrapper of any error and generate *BAD @@ -945,24 +963,15 @@ where InternalError::new(err, StatusCode::NETWORK_AUTHENTICATION_REQUIRED).into() } -#[cfg(feature = "fail")] -mod failure_integration { - use super::*; - - /// Compatibility for `failure::Error` - impl ResponseError for failure::Error { - fn error_response(&self) -> Response { - Response::new(StatusCode::INTERNAL_SERVER_ERROR) - } - } -} +#[cfg(feature = "failure")] +/// Compatibility for `failure::Error` +impl ResponseError for fail_ure::Error {} #[cfg(test)] mod tests { use super::*; use http::{Error as HttpError, StatusCode}; use httparse; - use std::error::Error as StdError; use std::io; #[test] @@ -991,7 +1000,7 @@ mod tests { #[test] fn test_error_cause() { let orig = io::Error::new(io::ErrorKind::Other, "other"); - let desc = orig.description().to_owned(); + let desc = orig.to_string(); let e = Error::from(orig); assert_eq!(format!("{}", e.as_response_error()), desc); } @@ -999,7 +1008,7 @@ mod tests { #[test] fn test_error_display() { let orig = io::Error::new(io::ErrorKind::Other, "other"); - let desc = orig.description().to_owned(); + let desc = orig.to_string(); let e = Error::from(orig); assert_eq!(format!("{}", e), desc); } @@ -1041,7 +1050,7 @@ mod tests { match ParseError::from($from) { e @ $error => { let desc = format!("{}", e); - assert_eq!(desc, format!("IO error: {}", $from.description())); + assert_eq!(desc, format!("IO error: {}", $from)); } _ => unreachable!("{:?}", $from), } diff --git a/actix-http/src/extensions.rs b/actix-http/src/extensions.rs index c6266f56e..d85ca184d 100644 --- a/actix-http/src/extensions.rs +++ b/actix-http/src/extensions.rs @@ -1,12 +1,12 @@ use std::any::{Any, TypeId}; use std::fmt; -use hashbrown::HashMap; +use fxhash::FxHashMap; #[derive(Default)] /// A type map of request extensions. pub struct Extensions { - map: HashMap>, + map: FxHashMap>, } impl Extensions { @@ -14,7 +14,7 @@ impl Extensions { #[inline] pub fn new() -> Extensions { Extensions { - map: HashMap::default(), + map: FxHashMap::default(), } } @@ -65,7 +65,7 @@ impl Extensions { } impl fmt::Debug for Extensions { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Extensions").finish() } } diff --git a/actix-http/src/h1/client.rs b/actix-http/src/h1/client.rs index c0bbcc694..bcfc18cde 100644 --- a/actix-http/src/h1/client.rs +++ b/actix-http/src/h1/client.rs @@ -1,13 +1,8 @@ -#![allow(unused_imports, unused_variables, dead_code)] -use std::io::{self, Write}; -use std::rc::Rc; +use std::io; use actix_codec::{Decoder, Encoder}; use bitflags::bitflags; -use bytes::{BufMut, Bytes, BytesMut}; -use http::header::{ - HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE, -}; +use bytes::{Bytes, BytesMut}; use http::{Method, Version}; use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; @@ -16,9 +11,7 @@ use super::{Message, MessageType}; use crate::body::BodySize; use crate::config::ServiceConfig; use crate::error::{ParseError, PayloadError}; -use crate::helpers; -use crate::message::{ConnectionType, Head, MessagePool, RequestHead, RequestHeadType, ResponseHead}; -use crate::header::HeaderMap; +use crate::message::{ConnectionType, RequestHeadType, ResponseHead}; bitflags! { struct Flags: u8 { @@ -28,8 +21,6 @@ bitflags! { } } -const AVERAGE_HEADER_SIZE: usize = 30; - /// HTTP/1 Codec pub struct ClientCodec { inner: ClientCodecInner, @@ -49,7 +40,6 @@ struct ClientCodecInner { // encoder part flags: Flags, - headers_size: u32, encoder: encoder::MessageEncoder, } @@ -78,7 +68,6 @@ impl ClientCodec { ctype: ConnectionType::Close, flags, - headers_size: 0, encoder: encoder::MessageEncoder::default(), }, } @@ -197,7 +186,9 @@ impl Encoder for ClientCodec { Message::Item((mut head, length)) => { let inner = &mut self.inner; inner.version = head.as_ref().version; - inner.flags.set(Flags::HEAD, head.as_ref().method == Method::HEAD); + inner + .flags + .set(Flags::HEAD, head.as_ref().method == Method::HEAD); // connection status inner.ctype = match head.as_ref().connection_type() { diff --git a/actix-http/src/h1/codec.rs b/actix-http/src/h1/codec.rs index 22c7ed232..de2af9ee7 100644 --- a/actix-http/src/h1/codec.rs +++ b/actix-http/src/h1/codec.rs @@ -1,12 +1,9 @@ -#![allow(unused_imports, unused_variables, dead_code)] -use std::io::Write; -use std::{fmt, io, net}; +use std::{fmt, io}; use actix_codec::{Decoder, Encoder}; use bitflags::bitflags; -use bytes::{BufMut, Bytes, BytesMut}; -use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; -use http::{Method, StatusCode, Version}; +use bytes::BytesMut; +use http::{Method, Version}; use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; use super::{decoder, encoder}; @@ -14,8 +11,7 @@ use super::{Message, MessageType}; use crate::body::BodySize; use crate::config::ServiceConfig; use crate::error::ParseError; -use crate::helpers; -use crate::message::{ConnectionType, Head, ResponseHead}; +use crate::message::ConnectionType; use crate::request::Request; use crate::response::Response; @@ -27,8 +23,6 @@ bitflags! { } } -const AVERAGE_HEADER_SIZE: usize = 30; - /// HTTP/1 Codec pub struct Codec { config: ServiceConfig, @@ -49,7 +43,7 @@ impl Default for Codec { } impl fmt::Debug for Codec { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "h1::Codec({:?})", self.flags) } } @@ -176,7 +170,6 @@ impl Encoder for Codec { }; // encode message - let len = dst.len(); self.encoder.encode( dst, &mut res, @@ -202,17 +195,11 @@ impl Encoder for Codec { #[cfg(test)] mod tests { - use std::{cmp, io}; - - use actix_codec::{AsyncRead, AsyncWrite}; - use bytes::{Buf, Bytes, BytesMut}; - use http::{Method, Version}; + use bytes::BytesMut; + use http::Method; use super::*; - use crate::error::ParseError; - use crate::h1::Message; use crate::httpmessage::HttpMessage; - use crate::request::Request; #[test] fn test_http_request_chunked_payload_and_next_message() { diff --git a/actix-http/src/h1/decoder.rs b/actix-http/src/h1/decoder.rs index c7ef065a5..e113fd52d 100644 --- a/actix-http/src/h1/decoder.rs +++ b/actix-http/src/h1/decoder.rs @@ -1,11 +1,13 @@ +use std::convert::TryFrom; +use std::io; use std::marker::PhantomData; -use std::{io, mem}; +use std::mem::MaybeUninit; +use std::task::Poll; use actix_codec::Decoder; -use bytes::{Bytes, BytesMut}; -use futures::{Async, Poll}; +use bytes::{Buf, Bytes, BytesMut}; use http::header::{HeaderName, HeaderValue}; -use http::{header, HttpTryFrom, Method, StatusCode, Uri, Version}; +use http::{header, Method, StatusCode, Uri, Version}; use httparse; use log::{debug, error, trace}; @@ -78,8 +80,8 @@ pub(crate) trait MessageType: Sized { // Unsafe: httparse check header value for valid utf-8 let value = unsafe { - HeaderValue::from_shared_unchecked( - slice.slice(idx.value.0, idx.value.1), + HeaderValue::from_maybe_shared_unchecked( + slice.slice(idx.value.0..idx.value.1), ) }; match name { @@ -183,14 +185,16 @@ impl MessageType for Request { &mut self.head_mut().headers } + #[allow(clippy::uninit_assumed_init)] fn decode(src: &mut BytesMut) -> Result, ParseError> { // Unsafe: we read only this data only after httparse parses headers into. // performance bump for pipeline benchmarks. - let mut headers: [HeaderIndex; MAX_HEADERS] = unsafe { mem::uninitialized() }; + let mut headers: [HeaderIndex; MAX_HEADERS] = + unsafe { MaybeUninit::uninit().assume_init() }; let (len, method, uri, ver, h_len) = { - let mut parsed: [httparse::Header; MAX_HEADERS] = - unsafe { mem::uninitialized() }; + let mut parsed: [httparse::Header<'_>; MAX_HEADERS] = + unsafe { MaybeUninit::uninit().assume_init() }; let mut req = httparse::Request::new(&mut parsed); match req.parse(src)? { @@ -257,14 +261,16 @@ impl MessageType for ResponseHead { &mut self.headers } + #[allow(clippy::uninit_assumed_init)] fn decode(src: &mut BytesMut) -> Result, ParseError> { // Unsafe: we read only this data only after httparse parses headers into. // performance bump for pipeline benchmarks. - let mut headers: [HeaderIndex; MAX_HEADERS] = unsafe { mem::uninitialized() }; + let mut headers: [HeaderIndex; MAX_HEADERS] = + unsafe { MaybeUninit::uninit().assume_init() }; let (len, ver, status, h_len) = { - let mut parsed: [httparse::Header; MAX_HEADERS] = - unsafe { mem::uninitialized() }; + let mut parsed: [httparse::Header<'_>; MAX_HEADERS] = + unsafe { MaybeUninit::uninit().assume_init() }; let mut res = httparse::Response::new(&mut parsed); match res.parse(src)? { @@ -322,7 +328,7 @@ pub(crate) struct HeaderIndex { impl HeaderIndex { pub(crate) fn record( bytes: &[u8], - headers: &[httparse::Header], + headers: &[httparse::Header<'_>], indices: &mut [HeaderIndex], ) { let bytes_ptr = bytes.as_ptr() as usize; @@ -425,7 +431,7 @@ impl Decoder for PayloadDecoder { let len = src.len() as u64; let buf; if *remaining > len { - buf = src.take().freeze(); + buf = src.split().freeze(); *remaining -= len; } else { buf = src.split_to(*remaining as usize).freeze(); @@ -439,9 +445,10 @@ impl Decoder for PayloadDecoder { loop { let mut buf = None; // advances the chunked state - *state = match state.step(src, size, &mut buf)? { - Async::NotReady => return Ok(None), - Async::Ready(state) => state, + *state = match state.step(src, size, &mut buf) { + Poll::Pending => return Ok(None), + Poll::Ready(Ok(state)) => state, + Poll::Ready(Err(e)) => return Err(e), }; if *state == ChunkedState::End { trace!("End of chunked stream"); @@ -459,7 +466,7 @@ impl Decoder for PayloadDecoder { if src.is_empty() { Ok(None) } else { - Ok(Some(PayloadItem::Chunk(src.take().freeze()))) + Ok(Some(PayloadItem::Chunk(src.split().freeze()))) } } } @@ -470,10 +477,10 @@ macro_rules! byte ( ($rdr:ident) => ({ if $rdr.len() > 0 { let b = $rdr[0]; - $rdr.split_to(1); + $rdr.advance(1); b } else { - return Ok(Async::NotReady) + return Poll::Pending } }) ); @@ -484,7 +491,7 @@ impl ChunkedState { body: &mut BytesMut, size: &mut u64, buf: &mut Option, - ) -> Poll { + ) -> Poll> { use self::ChunkedState::*; match *self { Size => ChunkedState::read_size(body, size), @@ -496,10 +503,14 @@ impl ChunkedState { BodyLf => ChunkedState::read_body_lf(body), EndCr => ChunkedState::read_end_cr(body), EndLf => ChunkedState::read_end_lf(body), - End => Ok(Async::Ready(ChunkedState::End)), + End => Poll::Ready(Ok(ChunkedState::End)), } } - fn read_size(rdr: &mut BytesMut, size: &mut u64) -> Poll { + + fn read_size( + rdr: &mut BytesMut, + size: &mut u64, + ) -> Poll> { let radix = 16; match byte!(rdr) { b @ b'0'..=b'9' => { @@ -514,48 +525,49 @@ impl ChunkedState { *size *= radix; *size += u64::from(b + 10 - b'A'); } - b'\t' | b' ' => return Ok(Async::Ready(ChunkedState::SizeLws)), - b';' => return Ok(Async::Ready(ChunkedState::Extension)), - b'\r' => return Ok(Async::Ready(ChunkedState::SizeLf)), + b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => return Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), _ => { - return Err(io::Error::new( + return Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk size line: Invalid Size", - )); + ))); } } - Ok(Async::Ready(ChunkedState::Size)) + Poll::Ready(Ok(ChunkedState::Size)) } - fn read_size_lws(rdr: &mut BytesMut) -> Poll { + + fn read_size_lws(rdr: &mut BytesMut) -> Poll> { trace!("read_size_lws"); match byte!(rdr) { // LWS can follow the chunk size, but no more digits can come - b'\t' | b' ' => Ok(Async::Ready(ChunkedState::SizeLws)), - b';' => Ok(Async::Ready(ChunkedState::Extension)), - b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)), - _ => Err(io::Error::new( + b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk size linear white space", - )), + ))), } } - fn read_extension(rdr: &mut BytesMut) -> Poll { + fn read_extension(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { - b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)), - _ => Ok(Async::Ready(ChunkedState::Extension)), // no supported extensions + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions } } fn read_size_lf( rdr: &mut BytesMut, size: &mut u64, - ) -> Poll { + ) -> Poll> { match byte!(rdr) { - b'\n' if *size > 0 => Ok(Async::Ready(ChunkedState::Body)), - b'\n' if *size == 0 => Ok(Async::Ready(ChunkedState::EndCr)), - _ => Err(io::Error::new( + b'\n' if *size > 0 => Poll::Ready(Ok(ChunkedState::Body)), + b'\n' if *size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)), + _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk size LF", - )), + ))), } } @@ -563,16 +575,16 @@ impl ChunkedState { rdr: &mut BytesMut, rem: &mut u64, buf: &mut Option, - ) -> Poll { + ) -> Poll> { trace!("Chunked read, remaining={:?}", rem); let len = rdr.len() as u64; if len == 0 { - Ok(Async::Ready(ChunkedState::Body)) + Poll::Ready(Ok(ChunkedState::Body)) } else { let slice; if *rem > len { - slice = rdr.take().freeze(); + slice = rdr.split().freeze(); *rem -= len; } else { slice = rdr.split_to(*rem as usize).freeze(); @@ -580,47 +592,47 @@ impl ChunkedState { } *buf = Some(slice); if *rem > 0 { - Ok(Async::Ready(ChunkedState::Body)) + Poll::Ready(Ok(ChunkedState::Body)) } else { - Ok(Async::Ready(ChunkedState::BodyCr)) + Poll::Ready(Ok(ChunkedState::BodyCr)) } } } - fn read_body_cr(rdr: &mut BytesMut) -> Poll { + fn read_body_cr(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { - b'\r' => Ok(Async::Ready(ChunkedState::BodyLf)), - _ => Err(io::Error::new( + b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), + _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk body CR", - )), + ))), } } - fn read_body_lf(rdr: &mut BytesMut) -> Poll { + fn read_body_lf(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { - b'\n' => Ok(Async::Ready(ChunkedState::Size)), - _ => Err(io::Error::new( + b'\n' => Poll::Ready(Ok(ChunkedState::Size)), + _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk body LF", - )), + ))), } } - fn read_end_cr(rdr: &mut BytesMut) -> Poll { + fn read_end_cr(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { - b'\r' => Ok(Async::Ready(ChunkedState::EndLf)), - _ => Err(io::Error::new( + b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), + _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk end CR", - )), + ))), } } - fn read_end_lf(rdr: &mut BytesMut) -> Poll { + fn read_end_lf(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { - b'\n' => Ok(Async::Ready(ChunkedState::End)), - _ => Err(io::Error::new( + b'\n' => Poll::Ready(Ok(ChunkedState::End)), + _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk end LF", - )), + ))), } } } diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index c82eb4ac8..6f4c09915 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -1,15 +1,15 @@ use std::collections::VecDeque; -use std::time::Instant; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{fmt, io, net}; -use actix_codec::{Decoder, Encoder, Framed, FramedParts}; -use actix_server_config::IoStream; +use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; +use actix_rt::time::{delay_until, Delay, Instant}; use actix_service::Service; use bitflags::bitflags; -use bytes::{BufMut, BytesMut}; -use futures::{Async, Future, Poll}; +use bytes::{Buf, BytesMut}; use log::{error, trace}; -use tokio_timer::Delay; use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::cloneable::CloneableService; @@ -166,7 +166,7 @@ impl PartialEq for PollResponse { impl Dispatcher where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, S: Service, S::Error: Into, S::Response: Into>, @@ -184,6 +184,7 @@ where expect: CloneableService, upgrade: Option>, on_connect: Option>, + peer_addr: Option, ) -> Self { Dispatcher::with_timeout( stream, @@ -195,6 +196,7 @@ where expect, upgrade, on_connect, + peer_addr, ) } @@ -209,6 +211,7 @@ where expect: CloneableService, upgrade: Option>, on_connect: Option>, + peer_addr: Option, ) -> Self { let keepalive = config.keep_alive_enabled(); let flags = if keepalive { @@ -232,7 +235,6 @@ where payload: None, state: State::None, error: None, - peer_addr: io.peer_addr(), messages: VecDeque::new(), io, codec, @@ -242,6 +244,7 @@ where upgrade, on_connect, flags, + peer_addr, ka_expire, ka_timer, }), @@ -251,7 +254,7 @@ where impl InnerDispatcher where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, S: Service, S::Error: Into, S::Response: Into>, @@ -261,14 +264,14 @@ where U: Service), Response = ()>, U::Error: fmt::Display, { - fn can_read(&self) -> bool { + fn can_read(&self, cx: &mut Context<'_>) -> bool { if self .flags .intersects(Flags::READ_DISCONNECT | Flags::UPGRADE) { false } else if let Some(ref info) = self.payload { - info.need_read() == PayloadStatus::Read + info.need_read(cx) == PayloadStatus::Read } else { true } @@ -287,7 +290,7 @@ where /// /// true - got whouldblock /// false - didnt get whouldblock - fn poll_flush(&mut self) -> Result { + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result { if self.write_buf.is_empty() { return Ok(false); } @@ -295,31 +298,31 @@ where let len = self.write_buf.len(); let mut written = 0; while written < len { - match self.io.write(&self.write_buf[written..]) { - Ok(0) => { + match unsafe { Pin::new_unchecked(&mut self.io) } + .poll_write(cx, &self.write_buf[written..]) + { + Poll::Ready(Ok(0)) => { return Err(DispatchError::Io(io::Error::new( io::ErrorKind::WriteZero, "", ))); } - Ok(n) => { + Poll::Ready(Ok(n)) => { written += n; } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + Poll::Pending => { if written > 0 { - let _ = self.write_buf.split_to(written); + self.write_buf.advance(written); } return Ok(true); } - Err(err) => return Err(DispatchError::Io(err)), + Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), } } - if written > 0 { - if written == self.write_buf.len() { - unsafe { self.write_buf.set_len(0) } - } else { - let _ = self.write_buf.split_to(written); - } + if written == self.write_buf.len() { + unsafe { self.write_buf.set_len(0) } + } else { + self.write_buf.advance(written); } Ok(false) } @@ -350,12 +353,15 @@ where .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); } - fn poll_response(&mut self) -> Result { + fn poll_response( + &mut self, + cx: &mut Context<'_>, + ) -> Result { loop { let state = match self.state { State::None => match self.messages.pop_front() { Some(DispatcherMessage::Item(req)) => { - Some(self.handle_request(req)?) + Some(self.handle_request(req, cx)?) } Some(DispatcherMessage::Error(res)) => { Some(self.send_response(res, ResponseBody::Other(Body::Empty))?) @@ -365,54 +371,58 @@ where } None => None, }, - State::ExpectCall(ref mut fut) => match fut.poll() { - Ok(Async::Ready(req)) => { - self.send_continue(); - self.state = State::ServiceCall(self.service.call(req)); - continue; + State::ExpectCall(ref mut fut) => { + match unsafe { Pin::new_unchecked(fut) }.poll(cx) { + Poll::Ready(Ok(req)) => { + self.send_continue(); + self.state = State::ServiceCall(self.service.call(req)); + continue; + } + Poll::Ready(Err(e)) => { + let res: Response = e.into().into(); + let (res, body) = res.replace_body(()); + Some(self.send_response(res, body.into_body())?) + } + Poll::Pending => None, } - Ok(Async::NotReady) => None, - Err(e) => { - let res: Response = e.into().into(); - let (res, body) = res.replace_body(()); - Some(self.send_response(res, body.into_body())?) + } + State::ServiceCall(ref mut fut) => { + match unsafe { Pin::new_unchecked(fut) }.poll(cx) { + Poll::Ready(Ok(res)) => { + let (res, body) = res.into().replace_body(()); + self.state = self.send_response(res, body)?; + continue; + } + Poll::Ready(Err(e)) => { + let res: Response = e.into().into(); + let (res, body) = res.replace_body(()); + Some(self.send_response(res, body.into_body())?) + } + Poll::Pending => None, } - }, - State::ServiceCall(ref mut fut) => match fut.poll() { - Ok(Async::Ready(res)) => { - let (res, body) = res.into().replace_body(()); - self.state = self.send_response(res, body)?; - continue; - } - Ok(Async::NotReady) => None, - Err(e) => { - let res: Response = e.into().into(); - let (res, body) = res.replace_body(()); - Some(self.send_response(res, body.into_body())?) - } - }, + } State::SendPayload(ref mut stream) => { loop { if self.write_buf.len() < HW_BUFFER_SIZE { - match stream - .poll_next() - .map_err(|_| DispatchError::Unknown)? - { - Async::Ready(Some(item)) => { + match stream.poll_next(cx) { + Poll::Ready(Some(Ok(item))) => { self.codec.encode( Message::Chunk(Some(item)), &mut self.write_buf, )?; continue; } - Async::Ready(None) => { + Poll::Ready(None) => { self.codec.encode( Message::Chunk(None), &mut self.write_buf, )?; self.state = State::None; } - Async::NotReady => return Ok(PollResponse::DoNothing), + Poll::Ready(Some(Err(_))) => { + return Err(DispatchError::Unknown) + } + Poll::Pending => return Ok(PollResponse::DoNothing), } } else { return Ok(PollResponse::DrainWriteBuf); @@ -433,7 +443,7 @@ where // if read-backpressure is enabled and we consumed some data. // we may read more data and retry if self.state.is_call() { - if self.poll_request()? { + if self.poll_request(cx)? { continue; } } else if !self.messages.is_empty() { @@ -446,17 +456,21 @@ where Ok(PollResponse::DoNothing) } - fn handle_request(&mut self, req: Request) -> Result, DispatchError> { + fn handle_request( + &mut self, + req: Request, + cx: &mut Context<'_>, + ) -> Result, DispatchError> { // Handle `EXPECT: 100-Continue` header let req = if req.head().expect() { let mut task = self.expect.call(req); - match task.poll() { - Ok(Async::Ready(req)) => { + match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { + Poll::Ready(Ok(req)) => { self.send_continue(); req } - Ok(Async::NotReady) => return Ok(State::ExpectCall(task)), - Err(e) => { + Poll::Pending => return Ok(State::ExpectCall(task)), + Poll::Ready(Err(e)) => { let e = e.into(); let res: Response = e.into(); let (res, body) = res.replace_body(()); @@ -469,13 +483,13 @@ where // Call service let mut task = self.service.call(req); - match task.poll() { - Ok(Async::Ready(res)) => { + match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { + Poll::Ready(Ok(res)) => { let (res, body) = res.into().replace_body(()); self.send_response(res, body) } - Ok(Async::NotReady) => Ok(State::ServiceCall(task)), - Err(e) => { + Poll::Pending => Ok(State::ServiceCall(task)), + Poll::Ready(Err(e)) => { let res: Response = e.into().into(); let (res, body) = res.replace_body(()); self.send_response(res, body.into_body()) @@ -484,9 +498,12 @@ where } /// Process one incoming requests - pub(self) fn poll_request(&mut self) -> Result { + pub(self) fn poll_request( + &mut self, + cx: &mut Context<'_>, + ) -> Result { // limit a mount of non processed requests - if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read() { + if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read(cx) { return Ok(false); } @@ -521,7 +538,7 @@ where // handle request early if self.state.is_empty() { - self.state = self.handle_request(req)?; + self.state = self.handle_request(req, cx)?; } else { self.messages.push_back(DispatcherMessage::Item(req)); } @@ -587,12 +604,12 @@ where } /// keep-alive timer - fn poll_keepalive(&mut self) -> Result<(), DispatchError> { + fn poll_keepalive(&mut self, cx: &mut Context<'_>) -> Result<(), DispatchError> { if self.ka_timer.is_none() { // shutdown timeout if self.flags.contains(Flags::SHUTDOWN) { if let Some(interval) = self.codec.config().client_disconnect_timer() { - self.ka_timer = Some(Delay::new(interval)); + self.ka_timer = Some(delay_until(interval)); } else { self.flags.insert(Flags::READ_DISCONNECT); if let Some(mut payload) = self.payload.take() { @@ -605,11 +622,8 @@ where } } - match self.ka_timer.as_mut().unwrap().poll().map_err(|e| { - error!("Timer error {:?}", e); - DispatchError::Unknown - })? { - Async::Ready(_) => { + match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) { + Poll::Ready(()) => { // if we get timeout during shutdown, drop connection if self.flags.contains(Flags::SHUTDOWN) { return Err(DispatchError::DisconnectTimeout); @@ -624,9 +638,9 @@ where if let Some(deadline) = self.codec.config().client_disconnect_timer() { - if let Some(timer) = self.ka_timer.as_mut() { + if let Some(mut timer) = self.ka_timer.as_mut() { timer.reset(deadline); - let _ = timer.poll(); + let _ = Pin::new(&mut timer).poll(cx); } } else { // no shutdown timeout, drop socket @@ -650,26 +664,26 @@ where } else if let Some(deadline) = self.codec.config().keep_alive_expire() { - if let Some(timer) = self.ka_timer.as_mut() { + if let Some(mut timer) = self.ka_timer.as_mut() { timer.reset(deadline); - let _ = timer.poll(); + let _ = Pin::new(&mut timer).poll(cx); } } - } else if let Some(timer) = self.ka_timer.as_mut() { + } else if let Some(mut timer) = self.ka_timer.as_mut() { timer.reset(self.ka_expire); - let _ = timer.poll(); + let _ = Pin::new(&mut timer).poll(cx); } } - Async::NotReady => (), + Poll::Pending => (), } Ok(()) } } -impl Future for Dispatcher +impl Unpin for Dispatcher where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, S: Service, S::Error: Into, S::Response: Into>, @@ -679,27 +693,42 @@ where U: Service), Response = ()>, U::Error: fmt::Display, { - type Item = (); - type Error = DispatchError; +} + +impl Future for Dispatcher +where + T: AsyncRead + AsyncWrite + Unpin, + S: Service, + S::Error: Into, + S::Response: Into>, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + type Output = Result<(), DispatchError>; #[inline] - fn poll(&mut self) -> Poll { - match self.inner { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().inner { DispatcherState::Normal(ref mut inner) => { - inner.poll_keepalive()?; + inner.poll_keepalive(cx)?; if inner.flags.contains(Flags::SHUTDOWN) { if inner.flags.contains(Flags::WRITE_DISCONNECT) { - Ok(Async::Ready(())) + Poll::Ready(Ok(())) } else { // flush buffer - inner.poll_flush()?; + inner.poll_flush(cx)?; if !inner.write_buf.is_empty() { - Ok(Async::NotReady) + Poll::Pending } else { - match inner.io.shutdown()? { - Async::Ready(_) => Ok(Async::Ready(())), - Async::NotReady => Ok(Async::NotReady), + match Pin::new(&mut inner.io).poll_shutdown(cx) { + Poll::Ready(res) => { + Poll::Ready(res.map_err(DispatchError::from)) + } + Poll::Pending => Poll::Pending, } } } @@ -707,12 +736,12 @@ where // read socket into a buf let should_disconnect = if !inner.flags.contains(Flags::READ_DISCONNECT) { - read_available(&mut inner.io, &mut inner.read_buf)? + read_available(cx, &mut inner.io, &mut inner.read_buf)? } else { None }; - inner.poll_request()?; + inner.poll_request(cx)?; if let Some(true) = should_disconnect { inner.flags.insert(Flags::READ_DISCONNECT); if let Some(mut payload) = inner.payload.take() { @@ -721,10 +750,12 @@ where }; loop { - if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE { - inner.write_buf.reserve(HW_BUFFER_SIZE); + let remaining = + inner.write_buf.capacity() - inner.write_buf.len(); + if remaining < LW_BUFFER_SIZE { + inner.write_buf.reserve(HW_BUFFER_SIZE - remaining); } - let result = inner.poll_response()?; + let result = inner.poll_response(cx)?; let drain = result == PollResponse::DrainWriteBuf; // switch to upgrade handler @@ -742,7 +773,7 @@ where self.inner = DispatcherState::Upgrade( inner.upgrade.unwrap().call((req, framed)), ); - return self.poll(); + return self.poll(cx); } else { panic!() } @@ -751,14 +782,14 @@ where // we didnt get WouldBlock from write operation, // so data get written to kernel completely (OSX) // and we have to write again otherwise response can get stuck - if inner.poll_flush()? || !drain { + if inner.poll_flush(cx)? || !drain { break; } } // client is gone if inner.flags.contains(Flags::WRITE_DISCONNECT) { - return Ok(Async::Ready(())); + return Poll::Ready(Ok(())); } let is_empty = inner.state.is_empty(); @@ -771,58 +802,64 @@ where // keep-alive and stream errors if is_empty && inner.write_buf.is_empty() { if let Some(err) = inner.error.take() { - Err(err) + Poll::Ready(Err(err)) } // disconnect if keep-alive is not enabled else if inner.flags.contains(Flags::STARTED) && !inner.flags.intersects(Flags::KEEPALIVE) { inner.flags.insert(Flags::SHUTDOWN); - self.poll() + self.poll(cx) } // disconnect if shutdown else if inner.flags.contains(Flags::SHUTDOWN) { - self.poll() + self.poll(cx) } else { - Ok(Async::NotReady) + Poll::Pending } } else { - Ok(Async::NotReady) + Poll::Pending } } } - DispatcherState::Upgrade(ref mut fut) => fut.poll().map_err(|e| { - error!("Upgrade handler error: {}", e); - DispatchError::Upgrade - }), + DispatcherState::Upgrade(ref mut fut) => { + unsafe { Pin::new_unchecked(fut) }.poll(cx).map_err(|e| { + error!("Upgrade handler error: {}", e); + DispatchError::Upgrade + }) + } DispatcherState::None => panic!(), } } } -fn read_available(io: &mut T, buf: &mut BytesMut) -> Result, io::Error> +fn read_available( + cx: &mut Context<'_>, + io: &mut T, + buf: &mut BytesMut, +) -> Result, io::Error> where - T: io::Read, + T: AsyncRead + Unpin, { let mut read_some = false; loop { - if buf.remaining_mut() < LW_BUFFER_SIZE { - buf.reserve(HW_BUFFER_SIZE); + let remaining = buf.capacity() - buf.len(); + if remaining < LW_BUFFER_SIZE { + buf.reserve(HW_BUFFER_SIZE - remaining); } - let read = unsafe { io.read(buf.bytes_mut()) }; - match read { - Ok(n) => { + match read(cx, io, buf) { + Poll::Pending => { + return if read_some { Ok(Some(false)) } else { Ok(None) }; + } + Poll::Ready(Ok(n)) => { if n == 0 { return Ok(Some(true)); } else { read_some = true; - unsafe { - buf.advance_mut(n); - } } } - Err(e) => { + Poll::Ready(Err(e)) => { return if e.kind() == io::ErrorKind::WouldBlock { if read_some { Ok(Some(false)) @@ -833,26 +870,36 @@ where Ok(Some(true)) } else { Err(e) - }; + } } } } } +fn read( + cx: &mut Context<'_>, + io: &mut T, + buf: &mut BytesMut, +) -> Poll> +where + T: AsyncRead + Unpin, +{ + Pin::new(io).poll_read_buf(cx, buf) +} + #[cfg(test)] mod tests { use actix_service::IntoService; - use futures::future::{lazy, ok}; + use futures_util::future::{lazy, ok}; use super::*; use crate::error::Error; use crate::h1::{ExpectHandler, UpgradeHandler}; use crate::test::TestBuffer; - #[test] - fn test_req_parse_err() { - let mut sys = actix_rt::System::new("test"); - let _ = sys.block_on(lazy(|| { + #[actix_rt::test] + async fn test_req_parse_err() { + lazy(|cx| { let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n"); let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( @@ -864,14 +911,18 @@ mod tests { CloneableService::new(ExpectHandler), None, None, + None, ); - assert!(h1.poll().is_err()); + match Pin::new(&mut h1).poll(cx) { + Poll::Pending => panic!(), + Poll::Ready(res) => assert!(res.is_err()), + } if let DispatcherState::Normal(ref inner) = h1.inner { assert!(inner.flags.contains(Flags::READ_DISCONNECT)); assert_eq!(&inner.io.write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n"); } - ok::<_, ()>(()) - })); + }) + .await; } } diff --git a/actix-http/src/h1/encoder.rs b/actix-http/src/h1/encoder.rs index 380dfe328..4689906b4 100644 --- a/actix-http/src/h1/encoder.rs +++ b/actix-http/src/h1/encoder.rs @@ -1,23 +1,18 @@ -#![allow(unused_imports, unused_variables, dead_code)] -use std::fmt::Write as FmtWrite; use std::io::Write; use std::marker::PhantomData; -use std::str::FromStr; -use std::{cmp, fmt, io, mem}; -use std::rc::Rc; +use std::ptr::copy_nonoverlapping; +use std::slice::from_raw_parts_mut; +use std::{cmp, io}; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{buf::BufMutExt, BufMut, BytesMut}; use crate::body::BodySize; use crate::config::ServiceConfig; -use crate::header::{map, ContentEncoding}; +use crate::header::map; use crate::helpers; -use crate::http::header::{ - HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, -}; -use crate::http::{HeaderMap, Method, StatusCode, Version}; -use crate::message::{ConnectionType, Head, RequestHead, ResponseHead, RequestHeadType}; -use crate::request::Request; +use crate::http::header::{CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; +use crate::http::{HeaderMap, StatusCode, Version}; +use crate::message::{ConnectionType, RequestHeadType}; use crate::response::Response; const AVERAGE_HEADER_SIZE: usize = 30; @@ -106,6 +101,7 @@ pub(crate) trait MessageType: Sized { } else { dst.put_slice(b"\r\ncontent-length: "); } + #[allow(clippy::write_with_newline)] write!(dst.writer(), "{}\r\n", len)?; } BodySize::None => dst.put_slice(b"\r\n"), @@ -134,17 +130,18 @@ pub(crate) trait MessageType: Sized { // merging headers from head and extra headers. HeaderMap::new() does not allocate. let empty_headers = HeaderMap::new(); let extra_headers = self.extra_headers().unwrap_or(&empty_headers); - let headers = self.headers().inner.iter() - .filter(|(name, _)| { - !extra_headers.contains_key(*name) - }) + let headers = self + .headers() + .inner + .iter() + .filter(|(name, _)| !extra_headers.contains_key(*name)) .chain(extra_headers.inner.iter()); // write headers let mut pos = 0; let mut has_date = false; - let mut remaining = dst.remaining_mut(); - let mut buf = unsafe { &mut *(dst.bytes_mut() as *mut [u8]) }; + let mut remaining = dst.capacity() - dst.len(); + let mut buf = dst.bytes_mut().as_mut_ptr() as *mut u8; for (key, value) in headers { match *key { CONNECTION => continue, @@ -158,61 +155,67 @@ pub(crate) trait MessageType: Sized { match value { map::Value::One(ref val) => { let v = val.as_ref(); - let len = k.len() + v.len() + 4; + let v_len = v.len(); + let k_len = k.len(); + let len = k_len + v_len + 4; if len > remaining { unsafe { dst.advance_mut(pos); } pos = 0; dst.reserve(len * 2); - remaining = dst.remaining_mut(); - unsafe { - buf = &mut *(dst.bytes_mut() as *mut _); - } + remaining = dst.capacity() - dst.len(); + buf = dst.bytes_mut().as_mut_ptr() as *mut u8; } // use upper Camel-Case - if camel_case { - write_camel_case(k, &mut buf[pos..pos + k.len()]); - } else { - buf[pos..pos + k.len()].copy_from_slice(k); + unsafe { + if camel_case { + write_camel_case(k, from_raw_parts_mut(buf, k_len)) + } else { + write_data(k, buf, k_len) + } + buf = buf.add(k_len); + write_data(b": ", buf, 2); + buf = buf.add(2); + write_data(v, buf, v_len); + buf = buf.add(v_len); + write_data(b"\r\n", buf, 2); + buf = buf.add(2); + pos += len; + remaining -= len; } - pos += k.len(); - buf[pos..pos + 2].copy_from_slice(b": "); - pos += 2; - buf[pos..pos + v.len()].copy_from_slice(v); - pos += v.len(); - buf[pos..pos + 2].copy_from_slice(b"\r\n"); - pos += 2; - remaining -= len; } map::Value::Multi(ref vec) => { for val in vec { let v = val.as_ref(); - let len = k.len() + v.len() + 4; + let v_len = v.len(); + let k_len = k.len(); + let len = k_len + v_len + 4; if len > remaining { unsafe { dst.advance_mut(pos); } pos = 0; dst.reserve(len * 2); - remaining = dst.remaining_mut(); - unsafe { - buf = &mut *(dst.bytes_mut() as *mut _); - } + remaining = dst.capacity() - dst.len(); + buf = dst.bytes_mut().as_mut_ptr() as *mut u8; } // use upper Camel-Case - if camel_case { - write_camel_case(k, &mut buf[pos..pos + k.len()]); - } else { - buf[pos..pos + k.len()].copy_from_slice(k); - } - pos += k.len(); - buf[pos..pos + 2].copy_from_slice(b": "); - pos += 2; - buf[pos..pos + v.len()].copy_from_slice(v); - pos += v.len(); - buf[pos..pos + 2].copy_from_slice(b"\r\n"); - pos += 2; + unsafe { + if camel_case { + write_camel_case(k, from_raw_parts_mut(buf, k_len)); + } else { + write_data(k, buf, k_len); + } + buf = buf.add(k_len); + write_data(b": ", buf, 2); + buf = buf.add(2); + write_data(v, buf, v_len); + buf = buf.add(v_len); + write_data(b"\r\n", buf, 2); + buf = buf.add(2); + }; + pos += len; remaining -= len; } } @@ -297,6 +300,12 @@ impl MessageType for RequestHeadType { Version::HTTP_10 => "HTTP/1.0", Version::HTTP_11 => "HTTP/1.1", Version::HTTP_2 => "HTTP/2.0", + Version::HTTP_3 => "HTTP/3.0", + _ => + return Err(io::Error::new( + io::ErrorKind::Other, + "unsupported version" + )), } ) .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) @@ -478,6 +487,10 @@ impl<'a> io::Write for Writer<'a> { } } +unsafe fn write_data(value: &[u8], buf: *mut u8, len: usize) { + copy_nonoverlapping(value.as_ptr(), buf, len); +} + fn write_camel_case(value: &[u8], buffer: &mut [u8]) { let mut index = 0; let key = value; @@ -508,12 +521,14 @@ fn write_camel_case(value: &[u8], buffer: &mut [u8]) { #[cfg(test)] mod tests { + use std::rc::Rc; + use bytes::Bytes; - //use std::rc::Rc; + use http::header::AUTHORIZATION; use super::*; use crate::http::header::{HeaderValue, CONTENT_TYPE}; - use http::header::AUTHORIZATION; + use crate::RequestHead; #[test] fn test_chunked_te() { @@ -524,7 +539,7 @@ mod tests { assert!(enc.encode(b"", &mut bytes).ok().unwrap()); } assert_eq!( - bytes.take().freeze(), + bytes.split().freeze(), Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n") ); } @@ -547,10 +562,12 @@ mod tests { ConnectionType::Close, &ServiceConfig::default(), ); - assert_eq!( - bytes.take().freeze(), - Bytes::from_static(b"\r\nContent-Length: 0\r\nConnection: close\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n") - ); + let data = + String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + assert!(data.contains("Content-Length: 0\r\n")); + assert!(data.contains("Connection: close\r\n")); + assert!(data.contains("Content-Type: plain/text\r\n")); + assert!(data.contains("Date: date\r\n")); let _ = head.encode_headers( &mut bytes, @@ -559,10 +576,11 @@ mod tests { ConnectionType::KeepAlive, &ServiceConfig::default(), ); - assert_eq!( - bytes.take().freeze(), - Bytes::from_static(b"\r\nTransfer-Encoding: chunked\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n") - ); + let data = + String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + assert!(data.contains("Transfer-Encoding: chunked\r\n")); + assert!(data.contains("Content-Type: plain/text\r\n")); + assert!(data.contains("Date: date\r\n")); let _ = head.encode_headers( &mut bytes, @@ -571,10 +589,11 @@ mod tests { ConnectionType::KeepAlive, &ServiceConfig::default(), ); - assert_eq!( - bytes.take().freeze(), - Bytes::from_static(b"\r\nContent-Length: 100\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n") - ); + let data = + String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + assert!(data.contains("Content-Length: 100\r\n")); + assert!(data.contains("Content-Type: plain/text\r\n")); + assert!(data.contains("Date: date\r\n")); let mut head = RequestHead::default(); head.set_camel_case_headers(false); @@ -585,7 +604,6 @@ mod tests { .append(CONTENT_TYPE, HeaderValue::from_static("xml")); let mut head = RequestHeadType::Owned(head); - let _ = head.encode_headers( &mut bytes, Version::HTTP_11, @@ -593,10 +611,12 @@ mod tests { ConnectionType::KeepAlive, &ServiceConfig::default(), ); - assert_eq!( - bytes.take().freeze(), - Bytes::from_static(b"\r\ntransfer-encoding: chunked\r\ndate: date\r\ncontent-type: xml\r\ncontent-type: plain/text\r\n\r\n") - ); + let data = + String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + assert!(data.contains("transfer-encoding: chunked\r\n")); + assert!(data.contains("content-type: xml\r\n")); + assert!(data.contains("content-type: plain/text\r\n")); + assert!(data.contains("date: date\r\n")); } #[test] @@ -604,10 +624,16 @@ mod tests { let mut bytes = BytesMut::with_capacity(2048); let mut head = RequestHead::default(); - head.headers.insert(AUTHORIZATION, HeaderValue::from_static("some authorization")); + head.headers.insert( + AUTHORIZATION, + HeaderValue::from_static("some authorization"), + ); let mut extra_headers = HeaderMap::new(); - extra_headers.insert(AUTHORIZATION,HeaderValue::from_static("another authorization")); + extra_headers.insert( + AUTHORIZATION, + HeaderValue::from_static("another authorization"), + ); extra_headers.insert(DATE, HeaderValue::from_static("date")); let mut head = RequestHeadType::Rc(Rc::new(head), Some(extra_headers)); @@ -619,9 +645,11 @@ mod tests { ConnectionType::Close, &ServiceConfig::default(), ); - assert_eq!( - bytes.take().freeze(), - Bytes::from_static(b"\r\ncontent-length: 0\r\nconnection: close\r\nauthorization: another authorization\r\ndate: date\r\n\r\n") - ); + let data = + String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + assert!(data.contains("content-length: 0\r\n")); + assert!(data.contains("connection: close\r\n")); + assert!(data.contains("authorization: another authorization\r\n")); + assert!(data.contains("date: date\r\n")); } } diff --git a/actix-http/src/h1/expect.rs b/actix-http/src/h1/expect.rs index 32b6bd9c4..6c08df08e 100644 --- a/actix-http/src/h1/expect.rs +++ b/actix-http/src/h1/expect.rs @@ -1,23 +1,23 @@ -use actix_server_config::ServerConfig; -use actix_service::{NewService, Service}; -use futures::future::{ok, FutureResult}; -use futures::{Async, Poll}; +use std::task::{Context, Poll}; + +use actix_service::{Service, ServiceFactory}; +use futures_util::future::{ok, Ready}; use crate::error::Error; use crate::request::Request; pub struct ExpectHandler; -impl NewService for ExpectHandler { - type Config = ServerConfig; +impl ServiceFactory for ExpectHandler { + type Config = (); type Request = Request; type Response = Request; type Error = Error; type Service = ExpectHandler; type InitError = Error; - type Future = FutureResult; + type Future = Ready>; - fn new_service(&self, _: &ServerConfig) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { ok(ExpectHandler) } } @@ -26,10 +26,10 @@ impl Service for ExpectHandler { type Request = Request; type Response = Request; type Error = Error; - type Future = FutureResult; + type Future = Ready>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, req: Request) -> Self::Future { diff --git a/actix-http/src/h1/payload.rs b/actix-http/src/h1/payload.rs index 28acb64bb..6a348810c 100644 --- a/actix-http/src/h1/payload.rs +++ b/actix-http/src/h1/payload.rs @@ -1,12 +1,13 @@ //! Payload stream use std::cell::RefCell; use std::collections::VecDeque; +use std::pin::Pin; use std::rc::{Rc, Weak}; +use std::task::{Context, Poll}; +use actix_utils::task::LocalWaker; use bytes::Bytes; -use futures::task::current as current_task; -use futures::task::Task; -use futures::{Async, Poll, Stream}; +use futures_core::Stream; use crate::error::PayloadError; @@ -77,15 +78,24 @@ impl Payload { pub fn unread_data(&mut self, data: Bytes) { self.inner.borrow_mut().unread_data(data); } + + #[inline] + pub fn readany( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + self.inner.borrow_mut().readany(cx) + } } impl Stream for Payload { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; - #[inline] - fn poll(&mut self) -> Poll, PayloadError> { - self.inner.borrow_mut().readany() + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + self.inner.borrow_mut().readany(cx) } } @@ -117,19 +127,14 @@ impl PayloadSender { } #[inline] - pub fn need_read(&self) -> PayloadStatus { + pub fn need_read(&self, cx: &mut Context<'_>) -> PayloadStatus { // we check need_read only if Payload (other side) is alive, // otherwise always return true (consume payload) if let Some(shared) = self.inner.upgrade() { if shared.borrow().need_read { PayloadStatus::Read } else { - #[cfg(not(test))] - { - if shared.borrow_mut().io_task.is_none() { - shared.borrow_mut().io_task = Some(current_task()); - } - } + shared.borrow_mut().io_task.register(cx.waker()); PayloadStatus::Pause } } else { @@ -145,8 +150,8 @@ struct Inner { err: Option, need_read: bool, items: VecDeque, - task: Option, - io_task: Option, + task: LocalWaker, + io_task: LocalWaker, } impl Inner { @@ -157,8 +162,8 @@ impl Inner { err: None, items: VecDeque::new(), need_read: true, - task: None, - io_task: None, + task: LocalWaker::new(), + io_task: LocalWaker::new(), } } @@ -178,7 +183,7 @@ impl Inner { self.items.push_back(data); self.need_read = self.len < MAX_BUFFER_SIZE; if let Some(task) = self.task.take() { - task.notify() + task.wake() } } @@ -187,34 +192,28 @@ impl Inner { self.len } - fn readany(&mut self) -> Poll, PayloadError> { + fn readany( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { if let Some(data) = self.items.pop_front() { self.len -= data.len(); self.need_read = self.len < MAX_BUFFER_SIZE; - if self.need_read && self.task.is_none() && !self.eof { - self.task = Some(current_task()); + if self.need_read && !self.eof { + self.task.register(cx.waker()); } - if let Some(task) = self.io_task.take() { - task.notify() - } - Ok(Async::Ready(Some(data))) + self.io_task.wake(); + Poll::Ready(Some(Ok(data))) } else if let Some(err) = self.err.take() { - Err(err) + Poll::Ready(Some(Err(err))) } else if self.eof { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { self.need_read = true; - #[cfg(not(test))] - { - if self.task.is_none() { - self.task = Some(current_task()); - } - if let Some(task) = self.io_task.take() { - task.notify() - } - } - Ok(Async::NotReady) + self.task.register(cx.waker()); + self.io_task.wake(); + Poll::Pending } } @@ -227,28 +226,19 @@ impl Inner { #[cfg(test)] mod tests { use super::*; - use actix_rt::Runtime; - use futures::future::{lazy, result}; + use futures_util::future::poll_fn; - #[test] - fn test_unread_data() { - Runtime::new() - .unwrap() - .block_on(lazy(|| { - let (_, mut payload) = Payload::create(false); + #[actix_rt::test] + async fn test_unread_data() { + let (_, mut payload) = Payload::create(false); - payload.unread_data(Bytes::from("data")); - assert!(!payload.is_empty()); - assert_eq!(payload.len(), 4); + payload.unread_data(Bytes::from("data")); + assert!(!payload.is_empty()); + assert_eq!(payload.len(), 4); - assert_eq!( - Async::Ready(Some(Bytes::from("data"))), - payload.poll().ok().unwrap() - ); - - let res: Result<(), ()> = Ok(()); - result(res) - })) - .unwrap(); + assert_eq!( + Bytes::from("data"), + poll_fn(|cx| payload.readany(cx)).await.unwrap().unwrap() + ); } } diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index 89bf08e9b..4d1a1dc1b 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -1,16 +1,19 @@ -use std::fmt; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; +use std::{fmt, net}; -use actix_codec::Framed; -use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig}; -use actix_service::{IntoNewService, NewService, Service}; -use futures::future::{ok, FutureResult}; -use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream}; +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_rt::net::TcpStream; +use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; +use futures_core::ready; +use futures_util::future::{ok, Ready}; use crate::body::MessageBody; use crate::cloneable::CloneableService; -use crate::config::{KeepAlive, ServiceConfig}; +use crate::config::ServiceConfig; use crate::error::{DispatchError, Error, ParseError}; use crate::helpers::DataFactory; use crate::request::Request; @@ -20,43 +23,32 @@ use super::codec::Codec; use super::dispatcher::Dispatcher; use super::{ExpectHandler, Message, UpgradeHandler}; -/// `NewService` implementation for HTTP1 transport -pub struct H1Service> { +/// `ServiceFactory` implementation for HTTP1 transport +pub struct H1Service> { srv: S, cfg: ServiceConfig, expect: X, upgrade: Option, on_connect: Option Box>>, - _t: PhantomData<(T, P, B)>, + _t: PhantomData<(T, B)>, } -impl H1Service +impl H1Service where - S: NewService, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, { - /// Create new `HttpService` instance with default config. - pub fn new>(service: F) -> Self { - let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); - - H1Service { - cfg, - srv: service.into_new_service(), - expect: ExpectHandler, - upgrade: None, - on_connect: None, - _t: PhantomData, - } - } - /// Create new `HttpService` instance with config. - pub fn with_config>(cfg: ServiceConfig, service: F) -> Self { + pub(crate) fn with_config>( + cfg: ServiceConfig, + service: F, + ) -> Self { H1Service { cfg, - srv: service.into_new_service(), + srv: service.into_factory(), expect: ExpectHandler, upgrade: None, on_connect: None, @@ -65,17 +57,153 @@ where } } -impl H1Service +impl H1Service where - S: NewService, + S: ServiceFactory, + S::Error: Into, + S::InitError: fmt::Debug, + S::Response: Into>, + B: MessageBody, + X: ServiceFactory, + X::Error: Into, + X::InitError: fmt::Debug, + U: ServiceFactory< + Config = (), + Request = (Request, Framed), + Response = (), + >, + U::Error: fmt::Display + Into, + U::InitError: fmt::Debug, +{ + /// Create simple tcp stream service + pub fn tcp( + self, + ) -> impl ServiceFactory< + Config = (), + Request = TcpStream, + Response = (), + Error = DispatchError, + InitError = (), + > { + pipeline_factory(|io: TcpStream| { + let peer_addr = io.peer_addr().ok(); + ok((io, peer_addr)) + }) + .and_then(self) + } +} + +#[cfg(feature = "openssl")] +mod openssl { + use super::*; + + use actix_tls::openssl::{Acceptor, SslAcceptor, SslStream}; + use actix_tls::{openssl::HandshakeError, SslError}; + + impl H1Service, S, B, X, U> + where + S: ServiceFactory, + S::Error: Into, + S::InitError: fmt::Debug, + S::Response: Into>, + B: MessageBody, + X: ServiceFactory, + X::Error: Into, + X::InitError: fmt::Debug, + U: ServiceFactory< + Config = (), + Request = (Request, Framed, Codec>), + Response = (), + >, + U::Error: fmt::Display + Into, + U::InitError: fmt::Debug, + { + /// Create openssl based service + pub fn openssl( + self, + acceptor: SslAcceptor, + ) -> impl ServiceFactory< + Config = (), + Request = TcpStream, + Response = (), + Error = SslError, DispatchError>, + InitError = (), + > { + pipeline_factory( + Acceptor::new(acceptor) + .map_err(SslError::Ssl) + .map_init_err(|_| panic!()), + ) + .and_then(|io: SslStream| { + let peer_addr = io.get_ref().peer_addr().ok(); + ok((io, peer_addr)) + }) + .and_then(self.map_err(SslError::Service)) + } + } +} + +#[cfg(feature = "rustls")] +mod rustls { + use super::*; + use actix_tls::rustls::{Acceptor, ServerConfig, TlsStream}; + use actix_tls::SslError; + use std::{fmt, io}; + + impl H1Service, S, B, X, U> + where + S: ServiceFactory, + S::Error: Into, + S::InitError: fmt::Debug, + S::Response: Into>, + B: MessageBody, + X: ServiceFactory, + X::Error: Into, + X::InitError: fmt::Debug, + U: ServiceFactory< + Config = (), + Request = (Request, Framed, Codec>), + Response = (), + >, + U::Error: fmt::Display + Into, + U::InitError: fmt::Debug, + { + /// Create rustls based service + pub fn rustls( + self, + config: ServerConfig, + ) -> impl ServiceFactory< + Config = (), + Request = TcpStream, + Response = (), + Error = SslError, + InitError = (), + > { + pipeline_factory( + Acceptor::new(config) + .map_err(SslError::Ssl) + .map_init_err(|_| panic!()), + ) + .and_then(|io: TlsStream| { + let peer_addr = io.get_ref().0.peer_addr().ok(); + ok((io, peer_addr)) + }) + .and_then(self.map_err(SslError::Service)) + } + } +} + +impl H1Service +where + S: ServiceFactory, S::Error: Into, S::Response: Into>, S::InitError: fmt::Debug, B: MessageBody, { - pub fn expect(self, expect: X1) -> H1Service + pub fn expect(self, expect: X1) -> H1Service where - X1: NewService, + X1: ServiceFactory, X1::Error: Into, X1::InitError: fmt::Debug, { @@ -89,9 +217,9 @@ where } } - pub fn upgrade(self, upgrade: Option) -> H1Service + pub fn upgrade(self, upgrade: Option) -> H1Service where - U1: NewService), Response = ()>, + U1: ServiceFactory), Response = ()>, U1::Error: fmt::Display, U1::InitError: fmt::Debug, { @@ -115,38 +243,34 @@ where } } -impl NewService for H1Service +impl ServiceFactory for H1Service where - T: IoStream, - S: NewService, + T: AsyncRead + AsyncWrite + Unpin, + S: ServiceFactory, S::Error: Into, S::Response: Into>, S::InitError: fmt::Debug, B: MessageBody, - X: NewService, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: NewService< - Config = SrvConfig, - Request = (Request, Framed), - Response = (), - >, - U::Error: fmt::Display, + U: ServiceFactory), Response = ()>, + U::Error: fmt::Display + Into, U::InitError: fmt::Debug, { - type Config = SrvConfig; - type Request = Io; + type Config = (); + type Request = (T, Option); type Response = (); type Error = DispatchError; type InitError = (); - type Service = H1ServiceHandler; - type Future = H1ServiceResponse; + type Service = H1ServiceHandler; + type Future = H1ServiceResponse; - fn new_service(&self, cfg: &SrvConfig) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { H1ServiceResponse { - fut: self.srv.new_service(cfg).into_future(), - fut_ex: Some(self.expect.new_service(cfg)), - fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)), + fut: self.srv.new_service(()), + fut_ex: Some(self.expect.new_service(())), + fut_upg: self.upgrade.as_ref().map(|f| f.new_service(())), expect: None, upgrade: None, on_connect: self.on_connect.clone(), @@ -157,88 +281,99 @@ where } #[doc(hidden)] -pub struct H1ServiceResponse +#[pin_project::pin_project] +pub struct H1ServiceResponse where - S: NewService, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, - X: NewService, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: NewService), Response = ()>, + U: ServiceFactory), Response = ()>, U::Error: fmt::Display, U::InitError: fmt::Debug, { + #[pin] fut: S::Future, + #[pin] fut_ex: Option, + #[pin] fut_upg: Option, expect: Option, upgrade: Option, on_connect: Option Box>>, cfg: Option, - _t: PhantomData<(T, P, B)>, + _t: PhantomData<(T, B)>, } -impl Future for H1ServiceResponse +impl Future for H1ServiceResponse where - T: IoStream, - S: NewService, + T: AsyncRead + AsyncWrite + Unpin, + S: ServiceFactory, S::Error: Into, S::Response: Into>, S::InitError: fmt::Debug, B: MessageBody, - X: NewService, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: NewService), Response = ()>, + U: ServiceFactory), Response = ()>, U::Error: fmt::Display, U::InitError: fmt::Debug, { - type Item = H1ServiceHandler; - type Error = (); + type Output = Result, ()>; - fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut_ex { - let expect = try_ready!(fut - .poll() - .map_err(|e| log::error!("Init http service error: {:?}", e))); - self.expect = Some(expect); - self.fut_ex.take(); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.as_mut().project(); + + if let Some(fut) = this.fut_ex.as_pin_mut() { + let expect = ready!(fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e)))?; + this = self.as_mut().project(); + *this.expect = Some(expect); + this.fut_ex.set(None); } - if let Some(ref mut fut) = self.fut_upg { - let upgrade = try_ready!(fut - .poll() - .map_err(|e| log::error!("Init http service error: {:?}", e))); - self.upgrade = Some(upgrade); - self.fut_ex.take(); + if let Some(fut) = this.fut_upg.as_pin_mut() { + let upgrade = ready!(fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e)))?; + this = self.as_mut().project(); + *this.upgrade = Some(upgrade); + this.fut_ex.set(None); } - let service = try_ready!(self + let result = ready!(this .fut - .poll() + .poll(cx) .map_err(|e| log::error!("Init http service error: {:?}", e))); - Ok(Async::Ready(H1ServiceHandler::new( - self.cfg.take().unwrap(), - service, - self.expect.take().unwrap(), - self.upgrade.take(), - self.on_connect.clone(), - ))) + + Poll::Ready(result.map(|service| { + let this = self.as_mut().project(); + H1ServiceHandler::new( + this.cfg.take().unwrap(), + service, + this.expect.take().unwrap(), + this.upgrade.take(), + this.on_connect.clone(), + ) + })) } } /// `Service` implementation for HTTP1 transport -pub struct H1ServiceHandler { +pub struct H1ServiceHandler { srv: CloneableService, expect: CloneableService, upgrade: Option>, on_connect: Option Box>>, cfg: ServiceConfig, - _t: PhantomData<(T, P, B)>, + _t: PhantomData<(T, B)>, } -impl H1ServiceHandler +impl H1ServiceHandler where S: Service, S::Error: Into, @@ -255,7 +390,7 @@ where expect: X, upgrade: Option, on_connect: Option Box>>, - ) -> H1ServiceHandler { + ) -> H1ServiceHandler { H1ServiceHandler { srv: CloneableService::new(srv), expect: CloneableService::new(expect), @@ -267,9 +402,9 @@ where } } -impl Service for H1ServiceHandler +impl Service for H1ServiceHandler where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, S: Service, S::Error: Into, S::Response: Into>, @@ -277,17 +412,17 @@ where X: Service, X::Error: Into, U: Service), Response = ()>, - U::Error: fmt::Display, + U::Error: fmt::Display + Into, { - type Request = Io; + type Request = (T, Option); type Response = (); type Error = DispatchError; type Future = Dispatcher; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { let ready = self .expect - .poll_ready() + .poll_ready(cx) .map_err(|e| { let e = e.into(); log::error!("Http service readiness error: {:?}", e); @@ -297,7 +432,7 @@ where let ready = self .srv - .poll_ready() + .poll_ready(cx) .map_err(|e| { let e = e.into(); log::error!("Http service readiness error: {:?}", e); @@ -306,16 +441,27 @@ where .is_ready() && ready; - if ready { - Ok(Async::Ready(())) + let ready = if let Some(ref mut upg) = self.upgrade { + upg.poll_ready(cx) + .map_err(|e| { + let e = e.into(); + log::error!("Http service readiness error: {:?}", e); + DispatchError::Service(e) + })? + .is_ready() + && ready } else { - Ok(Async::NotReady) + ready + }; + + if ready { + Poll::Ready(Ok(())) + } else { + Poll::Pending } } - fn call(&mut self, req: Self::Request) -> Self::Future { - let io = req.into_parts().0; - + fn call(&mut self, (io, addr): Self::Request) -> Self::Future { let on_connect = if let Some(ref on_connect) = self.on_connect { Some(on_connect(&io)) } else { @@ -329,20 +475,21 @@ where self.expect.clone(), self.upgrade.clone(), on_connect, + addr, ) } } -/// `NewService` implementation for `OneRequestService` service +/// `ServiceFactory` implementation for `OneRequestService` service #[derive(Default)] -pub struct OneRequest { +pub struct OneRequest { config: ServiceConfig, - _t: PhantomData<(T, P)>, + _t: PhantomData, } -impl OneRequest +impl OneRequest where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, { /// Create new `H1SimpleService` instance. pub fn new() -> Self { @@ -353,52 +500,49 @@ where } } -impl NewService for OneRequest +impl ServiceFactory for OneRequest where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, { - type Config = SrvConfig; - type Request = Io; + type Config = (); + type Request = T; type Response = (Request, Framed); type Error = ParseError; type InitError = (); - type Service = OneRequestService; - type Future = FutureResult; + type Service = OneRequestService; + type Future = Ready>; - fn new_service(&self, _: &SrvConfig) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { ok(OneRequestService { - config: self.config.clone(), _t: PhantomData, + config: self.config.clone(), }) } } /// `Service` implementation for HTTP1 transport. Reads one request and returns /// request and framed object. -pub struct OneRequestService { +pub struct OneRequestService { + _t: PhantomData, config: ServiceConfig, - _t: PhantomData<(T, P)>, } -impl Service for OneRequestService +impl Service for OneRequestService where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, { - type Request = Io; + type Request = T; type Response = (Request, Framed); type Error = ParseError; type Future = OneRequestServiceResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, req: Self::Request) -> Self::Future { OneRequestServiceResponse { - framed: Some(Framed::new( - req.into_parts().0, - Codec::new(self.config.clone()), - )), + framed: Some(Framed::new(req, Codec::new(self.config.clone()))), } } } @@ -406,28 +550,28 @@ where #[doc(hidden)] pub struct OneRequestServiceResponse where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, { framed: Option>, } impl Future for OneRequestServiceResponse where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, { - type Item = (Request, Framed); - type Error = ParseError; + type Output = Result<(Request, Framed), ParseError>; - fn poll(&mut self) -> Poll { - match self.framed.as_mut().unwrap().poll()? { - Async::Ready(Some(req)) => match req { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.framed.as_mut().unwrap().next_item(cx) { + Poll::Ready(Some(Ok(req))) => match req { Message::Item(req) => { - Ok(Async::Ready((req, self.framed.take().unwrap()))) + Poll::Ready(Ok((req, self.framed.take().unwrap()))) } Message::Chunk(_) => unreachable!("Something is wrong"), }, - Async::Ready(None) => Err(ParseError::Incomplete), - Async::NotReady => Ok(Async::NotReady), + Poll::Ready(Some(Err(err))) => Poll::Ready(Err(err)), + Poll::Ready(None) => Poll::Ready(Err(ParseError::Incomplete)), + Poll::Pending => Poll::Pending, } } } diff --git a/actix-http/src/h1/upgrade.rs b/actix-http/src/h1/upgrade.rs index 0278f23e5..22ba99e26 100644 --- a/actix-http/src/h1/upgrade.rs +++ b/actix-http/src/h1/upgrade.rs @@ -1,10 +1,9 @@ use std::marker::PhantomData; +use std::task::{Context, Poll}; use actix_codec::Framed; -use actix_server_config::ServerConfig; -use actix_service::{NewService, Service}; -use futures::future::FutureResult; -use futures::{Async, Poll}; +use actix_service::{Service, ServiceFactory}; +use futures_util::future::Ready; use crate::error::Error; use crate::h1::Codec; @@ -12,16 +11,16 @@ use crate::request::Request; pub struct UpgradeHandler(PhantomData); -impl NewService for UpgradeHandler { - type Config = ServerConfig; +impl ServiceFactory for UpgradeHandler { + type Config = (); type Request = (Request, Framed); type Response = (); type Error = Error; type Service = UpgradeHandler; type InitError = Error; - type Future = FutureResult; + type Future = Ready>; - fn new_service(&self, _: &ServerConfig) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { unimplemented!() } } @@ -30,10 +29,10 @@ impl Service for UpgradeHandler { type Request = (Request, Framed); type Response = (); type Error = Error; - type Future = FutureResult; + type Future = Ready>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, _: Self::Request) -> Self::Future { diff --git a/actix-http/src/h1/utils.rs b/actix-http/src/h1/utils.rs index fdc4cf0bc..9ba4aa053 100644 --- a/actix-http/src/h1/utils.rs +++ b/actix-http/src/h1/utils.rs @@ -1,5 +1,8 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use futures::{Async, Future, Poll, Sink}; use crate::body::{BodySize, MessageBody, ResponseBody}; use crate::error::Error; @@ -7,6 +10,7 @@ use crate::h1::{Codec, Message}; use crate::response::Response; /// Send http/1 response +#[pin_project::pin_project] pub struct SendResponse { res: Option, BodySize)>>, body: Option>, @@ -33,60 +37,61 @@ where T: AsyncRead + AsyncWrite, B: MessageBody, { - type Item = Framed; - type Error = Error; + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); - fn poll(&mut self) -> Poll { loop { - let mut body_ready = self.body.is_some(); - let framed = self.framed.as_mut().unwrap(); + let mut body_ready = this.body.is_some(); + let framed = this.framed.as_mut().unwrap(); // send body - if self.res.is_none() && self.body.is_some() { - while body_ready && self.body.is_some() && !framed.is_write_buf_full() { - match self.body.as_mut().unwrap().poll_next()? { - Async::Ready(item) => { + if this.res.is_none() && this.body.is_some() { + while body_ready && this.body.is_some() && !framed.is_write_buf_full() { + match this.body.as_mut().unwrap().poll_next(cx)? { + Poll::Ready(item) => { // body is done if item.is_none() { - let _ = self.body.take(); + let _ = this.body.take(); } - framed.force_send(Message::Chunk(item))?; + framed.write(Message::Chunk(item))?; } - Async::NotReady => body_ready = false, + Poll::Pending => body_ready = false, } } } // flush write buffer if !framed.is_write_buf_empty() { - match framed.poll_complete()? { - Async::Ready(_) => { + match framed.flush(cx)? { + Poll::Ready(_) => { if body_ready { continue; } else { - return Ok(Async::NotReady); + return Poll::Pending; } } - Async::NotReady => return Ok(Async::NotReady), + Poll::Pending => return Poll::Pending, } } // send response - if let Some(res) = self.res.take() { - framed.force_send(res)?; + if let Some(res) = this.res.take() { + framed.write(res)?; continue; } - if self.body.is_some() { + if this.body.is_some() { if body_ready { continue; } else { - return Ok(Async::NotReady); + return Poll::Pending; } } else { break; } } - Ok(Async::Ready(self.framed.take().unwrap())) + Poll::Ready(Ok(this.framed.take().unwrap())) } } diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs index 888f9065e..a4ec15fab 100644 --- a/actix-http/src/h2/dispatcher.rs +++ b/actix-http/src/h2/dispatcher.rs @@ -1,27 +1,23 @@ -use std::collections::VecDeque; +use std::convert::TryFrom; +use std::future::Future; use std::marker::PhantomData; -use std::time::Instant; -use std::{fmt, mem, net}; +use std::net; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite}; -use actix_server_config::IoStream; +use actix_rt::time::{Delay, Instant}; use actix_service::Service; -use bitflags::bitflags; use bytes::{Bytes, BytesMut}; -use futures::{try_ready, Async, Future, Poll, Sink, Stream}; use h2::server::{Connection, SendResponse}; -use h2::{RecvStream, SendStream}; -use http::header::{ - HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, -}; -use http::HttpTryFrom; -use log::{debug, error, trace}; -use tokio_timer::Delay; +use h2::SendStream; +use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; +use log::{error, trace}; -use crate::body::{Body, BodySize, MessageBody, ResponseBody}; +use crate::body::{BodySize, MessageBody, ResponseBody}; use crate::cloneable::CloneableService; use crate::config::ServiceConfig; -use crate::error::{DispatchError, Error, ParseError, PayloadError, ResponseError}; +use crate::error::{DispatchError, Error}; use crate::helpers::DataFactory; use crate::httpmessage::HttpMessage; use crate::message::ResponseHead; @@ -32,7 +28,11 @@ use crate::response::Response; const CHUNK_SIZE: usize = 16_384; /// Dispatcher for HTTP/2 protocol -pub struct Dispatcher, B: MessageBody> { +#[pin_project::pin_project] +pub struct Dispatcher, B: MessageBody> +where + T: AsyncRead + AsyncWrite + Unpin, +{ service: CloneableService, connection: Connection, on_connect: Option>, @@ -45,12 +45,12 @@ pub struct Dispatcher, B: MessageBody impl Dispatcher where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, S: Service, S::Error: Into, - S::Future: 'static, + // S::Future: 'static, S::Response: Into>, - B: MessageBody + 'static, + B: MessageBody, { pub(crate) fn new( service: CloneableService, @@ -91,63 +91,77 @@ where impl Future for Dispatcher where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, { - type Item = (); - type Error = DispatchError; + type Output = Result<(), DispatchError>; #[inline] - fn poll(&mut self) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + loop { - match self.connection.poll()? { - Async::Ready(None) => return Ok(Async::Ready(())), - Async::Ready(Some((req, res))) => { + match Pin::new(&mut this.connection).poll_accept(cx) { + Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), + Poll::Ready(Some(Ok((req, res)))) => { // update keep-alive expire - if self.ka_timer.is_some() { - if let Some(expire) = self.config.keep_alive_expire() { - self.ka_expire = expire; + if this.ka_timer.is_some() { + if let Some(expire) = this.config.keep_alive_expire() { + this.ka_expire = expire; } } let (parts, body) = req.into_parts(); - let mut req = Request::with_payload(body.into()); + let mut req = Request::with_payload(Payload::< + crate::payload::PayloadStream, + >::H2( + crate::h2::Payload::new(body) + )); let head = &mut req.head_mut(); head.uri = parts.uri; head.method = parts.method; head.version = parts.version; head.headers = parts.headers.into(); - head.peer_addr = self.peer_addr; + head.peer_addr = this.peer_addr; // set on_connect data - if let Some(ref on_connect) = self.on_connect { + if let Some(ref on_connect) = this.on_connect { on_connect.set(&mut req.extensions_mut()); } - tokio_current_thread::spawn(ServiceResponse:: { + actix_rt::spawn(ServiceResponse::< + S::Future, + S::Response, + S::Error, + B, + > { state: ServiceResponseState::ServiceCall( - self.service.call(req), + this.service.call(req), Some(res), ), - config: self.config.clone(), + config: this.config.clone(), buffer: None, - }) + _t: PhantomData, + }); } - Async::NotReady => return Ok(Async::NotReady), + Poll::Pending => return Poll::Pending, } } } } -struct ServiceResponse { +#[pin_project::pin_project] +struct ServiceResponse { state: ServiceResponseState, config: ServiceConfig, buffer: Option, + _t: PhantomData<(I, E)>, } enum ServiceResponseState { @@ -155,12 +169,12 @@ enum ServiceResponseState { SendPayload(SendStream, ResponseBody), } -impl ServiceResponse +impl ServiceResponse where - F: Future, - F::Error: Into, - F::Item: Into>, - B: MessageBody + 'static, + F: Future>, + E: Into, + I: Into>, + B: MessageBody, { fn prepare_response( &self, @@ -215,117 +229,130 @@ where if !has_date { let mut bytes = BytesMut::with_capacity(29); self.config.set_date_header(&mut bytes); - res.headers_mut() - .insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap()); + res.headers_mut().insert(DATE, unsafe { + HeaderValue::from_maybe_shared_unchecked(bytes.freeze()) + }); } res } } -impl Future for ServiceResponse +impl Future for ServiceResponse where - F: Future, - F::Error: Into, - F::Item: Into>, - B: MessageBody + 'static, + F: Future>, + E: Into, + I: Into>, + B: MessageBody, { - type Item = (); - type Error = (); + type Output = (); - fn poll(&mut self) -> Poll { - match self.state { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.as_mut().project(); + + match this.state { ServiceResponseState::ServiceCall(ref mut call, ref mut send) => { - match call.poll() { - Ok(Async::Ready(res)) => { + match unsafe { Pin::new_unchecked(call) }.poll(cx) { + Poll::Ready(Ok(res)) => { let (res, body) = res.into().replace_body(()); let mut send = send.take().unwrap(); let mut size = body.size(); - let h2_res = self.prepare_response(res.head(), &mut size); + let h2_res = + self.as_mut().prepare_response(res.head(), &mut size); + this = self.as_mut().project(); - let stream = - send.send_response(h2_res, size.is_eof()).map_err(|e| { + let stream = match send.send_response(h2_res, size.is_eof()) { + Err(e) => { trace!("Error sending h2 response: {:?}", e); - })?; + return Poll::Ready(()); + } + Ok(stream) => stream, + }; if size.is_eof() { - Ok(Async::Ready(())) + Poll::Ready(()) } else { - self.state = ServiceResponseState::SendPayload(stream, body); - self.poll() + *this.state = + ServiceResponseState::SendPayload(stream, body); + self.poll(cx) } } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => { let res: Response = e.into().into(); let (res, body) = res.replace_body(()); let mut send = send.take().unwrap(); let mut size = body.size(); - let h2_res = self.prepare_response(res.head(), &mut size); + let h2_res = + self.as_mut().prepare_response(res.head(), &mut size); + this = self.as_mut().project(); - let stream = - send.send_response(h2_res, size.is_eof()).map_err(|e| { + let stream = match send.send_response(h2_res, size.is_eof()) { + Err(e) => { trace!("Error sending h2 response: {:?}", e); - })?; + return Poll::Ready(()); + } + Ok(stream) => stream, + }; if size.is_eof() { - Ok(Async::Ready(())) + Poll::Ready(()) } else { - self.state = ServiceResponseState::SendPayload( + *this.state = ServiceResponseState::SendPayload( stream, body.into_body(), ); - self.poll() + self.poll(cx) } } } } ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop { loop { - if let Some(ref mut buffer) = self.buffer { - match stream.poll_capacity().map_err(|e| warn!("{:?}", e))? { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(None) => return Ok(Async::Ready(())), - Async::Ready(Some(cap)) => { + if let Some(ref mut buffer) = this.buffer { + match stream.poll_capacity(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(()), + Poll::Ready(Some(Ok(cap))) => { let len = buffer.len(); let bytes = buffer.split_to(std::cmp::min(cap, len)); if let Err(e) = stream.send_data(bytes, false) { warn!("{:?}", e); - return Err(()); + return Poll::Ready(()); } else if !buffer.is_empty() { let cap = std::cmp::min(buffer.len(), CHUNK_SIZE); stream.reserve_capacity(cap); } else { - self.buffer.take(); + this.buffer.take(); } } + Poll::Ready(Some(Err(e))) => { + warn!("{:?}", e); + return Poll::Ready(()); + } } } else { - match body.poll_next() { - Ok(Async::NotReady) => { - return Ok(Async::NotReady); - } - Ok(Async::Ready(None)) => { + match body.poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => { if let Err(e) = stream.send_data(Bytes::new(), true) { warn!("{:?}", e); - return Err(()); - } else { - return Ok(Async::Ready(())); } + return Poll::Ready(()); } - Ok(Async::Ready(Some(chunk))) => { + Poll::Ready(Some(Ok(chunk))) => { stream.reserve_capacity(std::cmp::min( chunk.len(), CHUNK_SIZE, )); - self.buffer = Some(chunk); + *this.buffer = Some(chunk); } - Err(e) => { + Poll::Ready(Some(Err(e))) => { error!("Response payload stream error: {:?}", e); - return Err(()); + return Poll::Ready(()); } } } diff --git a/actix-http/src/h2/mod.rs b/actix-http/src/h2/mod.rs index c5972123f..b00969227 100644 --- a/actix-http/src/h2/mod.rs +++ b/actix-http/src/h2/mod.rs @@ -1,9 +1,9 @@ -#![allow(dead_code, unused_imports)] - -use std::fmt; +//! HTTP/2 implementation +use std::pin::Pin; +use std::task::{Context, Poll}; use bytes::Bytes; -use futures::{Async, Poll, Stream}; +use futures_core::Stream; use h2::RecvStream; mod dispatcher; @@ -25,22 +25,26 @@ impl Payload { } impl Stream for Payload { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - match self.pl.poll() { - Ok(Async::Ready(Some(chunk))) => { + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + + match Pin::new(&mut this.pl).poll_data(cx) { + Poll::Ready(Some(Ok(chunk))) => { let len = chunk.len(); - if let Err(err) = self.pl.release_capacity().release_capacity(len) { - Err(err.into()) + if let Err(err) = this.pl.flow_control().release_capacity(len) { + Poll::Ready(Some(Err(err.into()))) } else { - Ok(Async::Ready(Some(chunk))) + Poll::Ready(Some(Ok(chunk))) } } - Ok(Async::Ready(None)) => Ok(Async::Ready(None)), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => Err(err.into()), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))), + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), } } } diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index e894cf660..ff3f69faf 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -1,62 +1,56 @@ -use std::fmt::Debug; +use std::future::Future; use std::marker::PhantomData; -use std::{io, net, rc}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{net, rc}; -use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig}; -use actix_service::{IntoNewService, NewService, Service}; +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_rt::net::TcpStream; +use actix_service::{ + fn_factory, fn_service, pipeline_factory, IntoServiceFactory, Service, + ServiceFactory, +}; use bytes::Bytes; -use futures::future::{ok, FutureResult}; -use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream}; -use h2::server::{self, Connection, Handshake}; -use h2::RecvStream; +use futures_core::ready; +use futures_util::future::ok; +use h2::server::{self, Handshake}; use log::error; use crate::body::MessageBody; use crate::cloneable::CloneableService; -use crate::config::{KeepAlive, ServiceConfig}; -use crate::error::{DispatchError, Error, ParseError, ResponseError}; +use crate::config::ServiceConfig; +use crate::error::{DispatchError, Error}; use crate::helpers::DataFactory; -use crate::payload::Payload; use crate::request::Request; use crate::response::Response; use super::dispatcher::Dispatcher; -/// `NewService` implementation for HTTP2 transport -pub struct H2Service { +/// `ServiceFactory` implementation for HTTP2 transport +pub struct H2Service { srv: S, cfg: ServiceConfig, on_connect: Option Box>>, - _t: PhantomData<(T, P, B)>, + _t: PhantomData<(T, B)>, } -impl H2Service +impl H2Service where - S: NewService, - S::Error: Into, - S::Response: Into>, + S: ServiceFactory, + S::Error: Into + 'static, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, { - /// Create new `HttpService` instance. - pub fn new>(service: F) -> Self { - let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); - - H2Service { - cfg, - on_connect: None, - srv: service.into_new_service(), - _t: PhantomData, - } - } - /// Create new `HttpService` instance with config. - pub fn with_config>(cfg: ServiceConfig, service: F) -> Self { + pub(crate) fn with_config>( + cfg: ServiceConfig, + service: F, + ) -> Self { H2Service { cfg, on_connect: None, - srv: service.into_new_service(), + srv: service.into_factory(), _t: PhantomData, } } @@ -71,26 +65,144 @@ where } } -impl NewService for H2Service +impl H2Service where - T: IoStream, - S: NewService, - S::Error: Into, - S::Response: Into>, + S: ServiceFactory, + S::Error: Into + 'static, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, { - type Config = SrvConfig; - type Request = Io; + /// Create simple tcp based service + pub fn tcp( + self, + ) -> impl ServiceFactory< + Config = (), + Request = TcpStream, + Response = (), + Error = DispatchError, + InitError = S::InitError, + > { + pipeline_factory(fn_factory(|| { + async { + Ok::<_, S::InitError>(fn_service(|io: TcpStream| { + let peer_addr = io.peer_addr().ok(); + ok::<_, DispatchError>((io, peer_addr)) + })) + } + })) + .and_then(self) + } +} + +#[cfg(feature = "openssl")] +mod openssl { + use actix_service::{fn_factory, fn_service}; + use actix_tls::openssl::{Acceptor, SslAcceptor, SslStream}; + use actix_tls::{openssl::HandshakeError, SslError}; + + use super::*; + + impl H2Service, S, B> + where + S: ServiceFactory, + S::Error: Into + 'static, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, + { + /// Create ssl based service + pub fn openssl( + self, + acceptor: SslAcceptor, + ) -> impl ServiceFactory< + Config = (), + Request = TcpStream, + Response = (), + Error = SslError, DispatchError>, + InitError = S::InitError, + > { + pipeline_factory( + Acceptor::new(acceptor) + .map_err(SslError::Ssl) + .map_init_err(|_| panic!()), + ) + .and_then(fn_factory(|| { + ok::<_, S::InitError>(fn_service(|io: SslStream| { + let peer_addr = io.get_ref().peer_addr().ok(); + ok((io, peer_addr)) + })) + })) + .and_then(self.map_err(SslError::Service)) + } + } +} + +#[cfg(feature = "rustls")] +mod rustls { + use super::*; + use actix_tls::rustls::{Acceptor, ServerConfig, TlsStream}; + use actix_tls::SslError; + use std::io; + + impl H2Service, S, B> + where + S: ServiceFactory, + S::Error: Into + 'static, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, + { + /// Create openssl based service + pub fn rustls( + self, + mut config: ServerConfig, + ) -> impl ServiceFactory< + Config = (), + Request = TcpStream, + Response = (), + Error = SslError, + InitError = S::InitError, + > { + let protos = vec!["h2".to_string().into()]; + config.set_protocols(&protos); + + pipeline_factory( + Acceptor::new(config) + .map_err(SslError::Ssl) + .map_init_err(|_| panic!()), + ) + .and_then(fn_factory(|| { + ok::<_, S::InitError>(fn_service(|io: TlsStream| { + let peer_addr = io.get_ref().0.peer_addr().ok(); + ok((io, peer_addr)) + })) + })) + .and_then(self.map_err(SslError::Service)) + } + } +} + +impl ServiceFactory for H2Service +where + T: AsyncRead + AsyncWrite + Unpin, + S: ServiceFactory, + S::Error: Into + 'static, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, +{ + type Config = (); + type Request = (T, Option); type Response = (); type Error = DispatchError; type InitError = S::InitError; - type Service = H2ServiceHandler; - type Future = H2ServiceResponse; + type Service = H2ServiceHandler; + type Future = H2ServiceResponse; - fn new_service(&self, cfg: &SrvConfig) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { H2ServiceResponse { - fut: self.srv.new_service(cfg).into_future(), + fut: self.srv.new_service(()), cfg: Some(self.cfg.clone()), on_connect: self.on_connect.clone(), _t: PhantomData, @@ -99,56 +211,61 @@ where } #[doc(hidden)] -pub struct H2ServiceResponse { - fut: ::Future, +#[pin_project::pin_project] +pub struct H2ServiceResponse { + #[pin] + fut: S::Future, cfg: Option, on_connect: Option Box>>, - _t: PhantomData<(T, P, B)>, + _t: PhantomData<(T, B)>, } -impl Future for H2ServiceResponse +impl Future for H2ServiceResponse where - T: IoStream, - S: NewService, - S::Error: Into, - S::Response: Into>, + T: AsyncRead + AsyncWrite + Unpin, + S: ServiceFactory, + S::Error: Into + 'static, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, { - type Item = H2ServiceHandler; - type Error = S::InitError; + type Output = Result, S::InitError>; - fn poll(&mut self) -> Poll { - let service = try_ready!(self.fut.poll()); - Ok(Async::Ready(H2ServiceHandler::new( - self.cfg.take().unwrap(), - self.on_connect.clone(), - service, - ))) + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut().project(); + + Poll::Ready(ready!(this.fut.poll(cx)).map(|service| { + let this = self.as_mut().project(); + H2ServiceHandler::new( + this.cfg.take().unwrap(), + this.on_connect.clone(), + service, + ) + })) } } /// `Service` implementation for http/2 transport -pub struct H2ServiceHandler { +pub struct H2ServiceHandler { srv: CloneableService, cfg: ServiceConfig, on_connect: Option Box>>, - _t: PhantomData<(T, P, B)>, + _t: PhantomData<(T, B)>, } -impl H2ServiceHandler +impl H2ServiceHandler where S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, { fn new( cfg: ServiceConfig, on_connect: Option Box>>, srv: S, - ) -> H2ServiceHandler { + ) -> H2ServiceHandler { H2ServiceHandler { cfg, on_connect, @@ -158,31 +275,29 @@ where } } -impl Service for H2ServiceHandler +impl Service for H2ServiceHandler where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, { - type Request = Io; + type Request = (T, Option); type Response = (); type Error = DispatchError; type Future = H2ServiceHandlerResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.srv.poll_ready().map_err(|e| { + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.srv.poll_ready(cx).map_err(|e| { let e = e.into(); error!("Service readiness error: {:?}", e); DispatchError::Service(e) }) } - fn call(&mut self, req: Self::Request) -> Self::Future { - let io = req.into_parts().0; - let peer_addr = io.peer_addr(); + fn call(&mut self, (io, addr): Self::Request) -> Self::Future { let on_connect = if let Some(ref on_connect) = self.on_connect { Some(on_connect(&io)) } else { @@ -193,7 +308,7 @@ where state: State::Handshake( Some(self.srv.clone()), Some(self.cfg.clone()), - peer_addr, + addr, on_connect, server::handshake(io), ), @@ -201,8 +316,9 @@ where } } -enum State, B: MessageBody> +enum State, B: MessageBody> where + T: AsyncRead + AsyncWrite + Unpin, S::Future: 'static, { Incoming(Dispatcher), @@ -217,11 +333,11 @@ where pub struct H2ServiceHandlerResponse where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, { state: State, @@ -229,27 +345,26 @@ where impl Future for H2ServiceHandlerResponse where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody, { - type Item = (); - type Error = DispatchError; + type Output = Result<(), DispatchError>; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.state { - State::Incoming(ref mut disp) => disp.poll(), + State::Incoming(ref mut disp) => Pin::new(disp).poll(cx), State::Handshake( ref mut srv, ref mut config, ref peer_addr, ref mut on_connect, ref mut handshake, - ) => match handshake.poll() { - Ok(Async::Ready(conn)) => { + ) => match Pin::new(handshake).poll(cx) { + Poll::Ready(Ok(conn)) => { self.state = State::Incoming(Dispatcher::new( srv.take().unwrap(), conn, @@ -258,13 +373,13 @@ where None, *peer_addr, )); - self.poll() + self.poll(cx) } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => { + Poll::Ready(Err(err)) => { trace!("H2 handshake error: {}", err); - Err(err.into()) + Poll::Ready(Err(err.into())) } + Poll::Pending => Poll::Pending, }, } } diff --git a/actix-http/src/header/common/cache_control.rs b/actix-http/src/header/common/cache_control.rs index 55774619b..ec94ce4a9 100644 --- a/actix-http/src/header/common/cache_control.rs +++ b/actix-http/src/header/common/cache_control.rs @@ -74,18 +74,18 @@ impl Header for CacheControl { } impl fmt::Display for CacheControl { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt_comma_delimited(f, &self[..]) } } impl IntoHeaderValue for CacheControl { - type Error = header::InvalidHeaderValueBytes; + type Error = header::InvalidHeaderValue; fn try_into(self) -> Result { let mut writer = Writer::new(); let _ = write!(&mut writer, "{}", self); - header::HeaderValue::from_shared(writer.take()) + header::HeaderValue::from_maybe_shared(writer.take()) } } @@ -126,7 +126,7 @@ pub enum CacheDirective { } impl fmt::Display for CacheDirective { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use self::CacheDirective::*; fmt::Display::fmt( match *self { diff --git a/actix-http/src/header/common/content_disposition.rs b/actix-http/src/header/common/content_disposition.rs index 14fcc3517..d0d5af765 100644 --- a/actix-http/src/header/common/content_disposition.rs +++ b/actix-http/src/header/common/content_disposition.rs @@ -76,6 +76,11 @@ pub enum DispositionParam { /// the form. Name(String), /// A plain file name. + /// + /// It is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any + /// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where + /// [`FilenameExt`](DispositionParam::FilenameExt) with charset UTF-8 may be used instead + /// in case there are Unicode characters in file names. Filename(String), /// An extended file name. It must not exist for `ContentType::Formdata` according to /// [RFC7578 Section 4.2](https://tools.ietf.org/html/rfc7578#section-4.2). @@ -220,7 +225,16 @@ impl DispositionParam { /// ext-token = /// ``` /// -/// **Note**: filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within +/// # Note +/// +/// filename is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any +/// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where +/// filename* with charset UTF-8 may be used instead in case there are Unicode characters in file +/// names. +/// filename is [acceptable](https://tools.ietf.org/html/rfc7578#section-4.2) to be UTF-8 encoded +/// directly in a *Content-Disposition* header for *multipart/form-data*, though. +/// +/// filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within /// *multipart/form-data*. /// /// # Example @@ -251,6 +265,22 @@ impl DispositionParam { /// }; /// assert_eq!(cd2.get_name(), Some("file")); // field name /// assert_eq!(cd2.get_filename(), Some("bill.odt")); +/// +/// // HTTP response header with Unicode characters in file names +/// let cd3 = ContentDisposition { +/// disposition: DispositionType::Attachment, +/// parameters: vec![ +/// DispositionParam::FilenameExt(ExtendedValue { +/// charset: Charset::Ext(String::from("UTF-8")), +/// language_tag: None, +/// value: String::from("\u{1f600}.svg").into_bytes(), +/// }), +/// // fallback for better compatibility +/// DispositionParam::Filename(String::from("Grinning-Face-Emoji.svg")) +/// ], +/// }; +/// assert_eq!(cd3.get_filename_ext().map(|ev| ev.value.as_ref()), +/// Some("\u{1f600}.svg".as_bytes())); /// ``` /// /// # WARN @@ -333,15 +363,17 @@ impl ContentDisposition { // token: won't contains semicolon according to RFC 2616 Section 2.2 let (token, new_left) = split_once_and_trim(left, ';'); left = new_left; + if token.is_empty() { + // quoted-string can be empty, but token cannot be empty + return Err(crate::error::ParseError::Header); + } token.to_owned() }; - if value.is_empty() { - return Err(crate::error::ParseError::Header); - } let param = if param_name.eq_ignore_ascii_case("name") { DispositionParam::Name(value) } else if param_name.eq_ignore_ascii_case("filename") { + // See also comments in test_from_raw_uncessary_percent_decode. DispositionParam::Filename(value) } else { DispositionParam::Unknown(param_name.to_owned(), value) @@ -430,12 +462,12 @@ impl ContentDisposition { } impl IntoHeaderValue for ContentDisposition { - type Error = header::InvalidHeaderValueBytes; + type Error = header::InvalidHeaderValue; fn try_into(self) -> Result { let mut writer = Writer::new(); let _ = write!(&mut writer, "{}", self); - header::HeaderValue::from_shared(writer.take()) + header::HeaderValue::from_maybe_shared(writer.take()) } } @@ -454,7 +486,7 @@ impl Header for ContentDisposition { } impl fmt::Display for DispositionType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { DispositionType::Inline => write!(f, "inline"), DispositionType::Attachment => write!(f, "attachment"), @@ -465,12 +497,41 @@ impl fmt::Display for DispositionType { } impl fmt::Display for DispositionParam { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // All ASCII control charaters (0-30, 127) excepting horizontal tab, double quote, and + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // All ASCII control characters (0-30, 127) including horizontal tab, double quote, and // backslash should be escaped in quoted-string (i.e. "foobar"). - // Ref: RFC6266 S4.1 -> RFC2616 S2.2; RFC 7578 S4.2 -> RFC2183 S2 -> ... . + // Ref: RFC6266 S4.1 -> RFC2616 S3.6 + // filename-parm = "filename" "=" value + // value = token | quoted-string + // quoted-string = ( <"> *(qdtext | quoted-pair ) <"> ) + // qdtext = > + // quoted-pair = "\" CHAR + // TEXT = + // LWS = [CRLF] 1*( SP | HT ) + // OCTET = + // CHAR = + // CTL = + // + // Ref: RFC7578 S4.2 -> RFC2183 S2 -> RFC2045 S5.1 + // parameter := attribute "=" value + // attribute := token + // ; Matching of attributes + // ; is ALWAYS case-insensitive. + // value := token / quoted-string + // token := 1* + // tspecials := "(" / ")" / "<" / ">" / "@" / + // "," / ";" / ":" / "\" / <"> + // "/" / "[" / "]" / "?" / "=" + // ; Must be in quoted-string, + // ; to use within parameter values + // + // + // See also comments in test_from_raw_uncessary_percent_decode. lazy_static! { - static ref RE: Regex = Regex::new("[\x01-\x08\x10\x1F\x7F\"\\\\]").unwrap(); + static ref RE: Regex = Regex::new("[\x00-\x08\x10-\x1F\x7F\"\\\\]").unwrap(); } match self { DispositionParam::Name(ref value) => write!(f, "name={}", value), @@ -494,7 +555,7 @@ impl fmt::Display for DispositionParam { } impl fmt::Display for ContentDisposition { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.disposition)?; self.parameters .iter() @@ -707,9 +768,8 @@ mod tests { Mainstream browsers like Firefox (gecko) and Chrome use UTF-8 directly as above. (And now, only UTF-8 is handled by this implementation.) */ - let a = - HeaderValue::from_str("form-data; name=upload; filename=\"文件.webp\"") - .unwrap(); + let a = HeaderValue::from_str("form-data; name=upload; filename=\"文件.webp\"") + .unwrap(); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); let b = ContentDisposition { disposition: DispositionType::FormData, @@ -774,8 +834,18 @@ mod tests { #[test] fn test_from_raw_uncessary_percent_decode() { + // In fact, RFC7578 (multipart/form-data) Section 2 and 4.2 suggests that filename with + // non-ASCII characters MAY be percent-encoded. + // On the contrary, RFC6266 or other RFCs related to Content-Disposition response header + // do not mention such percent-encoding. + // So, it appears to be undecidable whether to percent-decode or not without + // knowing the usage scenario (multipart/form-data v.s. HTTP response header) and + // inevitable to unnecessarily percent-decode filename with %XX in the former scenario. + // Fortunately, it seems that almost all mainstream browsers just send UTF-8 encoded file + // names in quoted-string format (tested on Edge, IE11, Chrome and Firefox) without + // percent-encoding. So we do not bother to attempt to percent-decode. let a = HeaderValue::from_static( - "form-data; name=photo; filename=\"%74%65%73%74%2e%70%6e%67\"", // Should not be decoded! + "form-data; name=photo; filename=\"%74%65%73%74%2e%70%6e%67\"", ); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); let b = ContentDisposition { @@ -811,6 +881,13 @@ mod tests { let a = HeaderValue::from_static("inline; filename= "); assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("inline; filename=\"\""); + assert!(ContentDisposition::from_raw(&a) + .expect("parse cd") + .get_filename() + .expect("filename") + .is_empty()); } #[test] diff --git a/actix-http/src/header/common/content_range.rs b/actix-http/src/header/common/content_range.rs index cc7f27548..9a604c641 100644 --- a/actix-http/src/header/common/content_range.rs +++ b/actix-http/src/header/common/content_range.rs @@ -3,7 +3,7 @@ use std::str::FromStr; use crate::error::ParseError; use crate::header::{ - HeaderValue, IntoHeaderValue, InvalidHeaderValueBytes, Writer, CONTENT_RANGE, + HeaderValue, IntoHeaderValue, InvalidHeaderValue, Writer, CONTENT_RANGE, }; header! { @@ -166,7 +166,7 @@ impl FromStr for ContentRangeSpec { } impl Display for ContentRangeSpec { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { ContentRangeSpec::Bytes { range, @@ -198,11 +198,11 @@ impl Display for ContentRangeSpec { } impl IntoHeaderValue for ContentRangeSpec { - type Error = InvalidHeaderValueBytes; + type Error = InvalidHeaderValue; fn try_into(self) -> Result { let mut writer = Writer::new(); let _ = write!(&mut writer, "{}", self); - HeaderValue::from_shared(writer.take()) + HeaderValue::from_maybe_shared(writer.take()) } } diff --git a/actix-http/src/header/common/if_range.rs b/actix-http/src/header/common/if_range.rs index e910ebd96..b14ad0391 100644 --- a/actix-http/src/header/common/if_range.rs +++ b/actix-http/src/header/common/if_range.rs @@ -3,7 +3,7 @@ use std::fmt::{self, Display, Write}; use crate::error::ParseError; use crate::header::{ self, from_one_raw_str, EntityTag, Header, HeaderName, HeaderValue, HttpDate, - IntoHeaderValue, InvalidHeaderValueBytes, Writer, + IntoHeaderValue, InvalidHeaderValue, Writer, }; use crate::httpmessage::HttpMessage; @@ -87,7 +87,7 @@ impl Header for IfRange { } impl Display for IfRange { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { IfRange::EntityTag(ref x) => Display::fmt(x, f), IfRange::Date(ref x) => Display::fmt(x, f), @@ -96,12 +96,12 @@ impl Display for IfRange { } impl IntoHeaderValue for IfRange { - type Error = InvalidHeaderValueBytes; + type Error = InvalidHeaderValue; fn try_into(self) -> Result { let mut writer = Writer::new(); let _ = write!(&mut writer, "{}", self); - HeaderValue::from_shared(writer.take()) + HeaderValue::from_maybe_shared(writer.take()) } } diff --git a/actix-http/src/header/common/mod.rs b/actix-http/src/header/common/mod.rs index 30dfcaa6d..08950ea8b 100644 --- a/actix-http/src/header/common/mod.rs +++ b/actix-http/src/header/common/mod.rs @@ -159,18 +159,18 @@ macro_rules! header { } impl std::fmt::Display for $id { #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter) -> ::std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> ::std::fmt::Result { $crate::http::header::fmt_comma_delimited(f, &self.0[..]) } } impl $crate::http::header::IntoHeaderValue for $id { - type Error = $crate::http::header::InvalidHeaderValueBytes; + type Error = $crate::http::header::InvalidHeaderValue; fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { use std::fmt::Write; let mut writer = $crate::http::header::Writer::new(); let _ = write!(&mut writer, "{}", self); - $crate::http::header::HeaderValue::from_shared(writer.take()) + $crate::http::header::HeaderValue::from_maybe_shared(writer.take()) } } }; @@ -195,18 +195,18 @@ macro_rules! header { } impl std::fmt::Display for $id { #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { $crate::http::header::fmt_comma_delimited(f, &self.0[..]) } } impl $crate::http::header::IntoHeaderValue for $id { - type Error = $crate::http::header::InvalidHeaderValueBytes; + type Error = $crate::http::header::InvalidHeaderValue; fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { use std::fmt::Write; let mut writer = $crate::http::header::Writer::new(); let _ = write!(&mut writer, "{}", self); - $crate::http::header::HeaderValue::from_shared(writer.take()) + $crate::http::header::HeaderValue::from_maybe_shared(writer.take()) } } }; @@ -231,12 +231,12 @@ macro_rules! header { } impl std::fmt::Display for $id { #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(&self.0, f) } } impl $crate::http::header::IntoHeaderValue for $id { - type Error = $crate::http::header::InvalidHeaderValueBytes; + type Error = $crate::http::header::InvalidHeaderValue; fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { self.0.try_into() @@ -276,7 +276,7 @@ macro_rules! header { } impl std::fmt::Display for $id { #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match *self { $id::Any => f.write_str("*"), $id::Items(ref fields) => $crate::http::header::fmt_comma_delimited( @@ -285,13 +285,13 @@ macro_rules! header { } } impl $crate::http::header::IntoHeaderValue for $id { - type Error = $crate::http::header::InvalidHeaderValueBytes; + type Error = $crate::http::header::InvalidHeaderValue; fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { use std::fmt::Write; let mut writer = $crate::http::header::Writer::new(); let _ = write!(&mut writer, "{}", self); - $crate::http::header::HeaderValue::from_shared(writer.take()) + $crate::http::header::HeaderValue::from_maybe_shared(writer.take()) } } }; diff --git a/actix-http/src/header/map.rs b/actix-http/src/header/map.rs index f2f1ba51c..132087b9e 100644 --- a/actix-http/src/header/map.rs +++ b/actix-http/src/header/map.rs @@ -1,8 +1,9 @@ +use std::collections::hash_map::{self, Entry}; +use std::convert::TryFrom; + use either::Either; -use hashbrown::hash_map::{self, Entry}; -use hashbrown::HashMap; +use fxhash::FxHashMap; use http::header::{HeaderName, HeaderValue}; -use http::HttpTryFrom; /// A set of HTTP headers /// @@ -11,7 +12,7 @@ use http::HttpTryFrom; /// [`HeaderName`]: struct.HeaderName.html #[derive(Debug, Clone)] pub struct HeaderMap { - pub(crate) inner: HashMap, + pub(crate) inner: FxHashMap, } #[derive(Debug, Clone)] @@ -56,7 +57,7 @@ impl HeaderMap { /// allocate. pub fn new() -> Self { HeaderMap { - inner: HashMap::new(), + inner: FxHashMap::default(), } } @@ -70,7 +71,7 @@ impl HeaderMap { /// More capacity than requested may be allocated. pub fn with_capacity(capacity: usize) -> HeaderMap { HeaderMap { - inner: HashMap::with_capacity(capacity), + inner: FxHashMap::with_capacity_and_hasher(capacity, Default::default()), } } @@ -142,7 +143,7 @@ impl HeaderMap { /// Returns `None` if there are no values associated with the key. /// /// [`GetAll`]: struct.GetAll.html - pub fn get_all(&self, name: N) -> GetAll { + pub fn get_all(&self, name: N) -> GetAll<'_> { GetAll { idx: 0, item: self.get2(name), @@ -186,7 +187,7 @@ impl HeaderMap { /// The iteration order is arbitrary, but consistent across platforms for /// the same crate version. Each key will be yielded once per associated /// value. So, if a key has 3 associated values, it will be yielded 3 times. - pub fn iter(&self) -> Iter { + pub fn iter(&self) -> Iter<'_> { Iter::new(self.inner.iter()) } @@ -195,7 +196,7 @@ impl HeaderMap { /// The iteration order is arbitrary, but consistent across platforms for /// the same crate version. Each key will be yielded only once even if it /// has multiple associated values. - pub fn keys(&self) -> Keys { + pub fn keys(&self) -> Keys<'_> { Keys(self.inner.keys()) } diff --git a/actix-http/src/header/mod.rs b/actix-http/src/header/mod.rs index 37cf94508..0db26ceb0 100644 --- a/actix-http/src/header/mod.rs +++ b/actix-http/src/header/mod.rs @@ -1,6 +1,7 @@ //! Various http headers // This is mostly copy of [hyper](https://github.com/hyperium/hyper/tree/master/src/header) +use std::convert::TryFrom; use std::{fmt, str::FromStr}; use bytes::{Bytes, BytesMut}; @@ -16,7 +17,6 @@ use crate::httpmessage::HttpMessage; mod common; pub(crate) mod map; mod shared; -#[doc(hidden)] pub use self::common::*; #[doc(hidden)] pub use self::shared::*; @@ -74,58 +74,58 @@ impl<'a> IntoHeaderValue for &'a [u8] { } impl IntoHeaderValue for Bytes { - type Error = InvalidHeaderValueBytes; + type Error = InvalidHeaderValue; #[inline] fn try_into(self) -> Result { - HeaderValue::from_shared(self) + HeaderValue::from_maybe_shared(self) } } impl IntoHeaderValue for Vec { - type Error = InvalidHeaderValueBytes; + type Error = InvalidHeaderValue; #[inline] fn try_into(self) -> Result { - HeaderValue::from_shared(Bytes::from(self)) + HeaderValue::try_from(self) } } impl IntoHeaderValue for String { - type Error = InvalidHeaderValueBytes; + type Error = InvalidHeaderValue; #[inline] fn try_into(self) -> Result { - HeaderValue::from_shared(Bytes::from(self)) + HeaderValue::try_from(self) } } impl IntoHeaderValue for usize { - type Error = InvalidHeaderValueBytes; + type Error = InvalidHeaderValue; #[inline] fn try_into(self) -> Result { let s = format!("{}", self); - HeaderValue::from_shared(Bytes::from(s)) + HeaderValue::try_from(s) } } impl IntoHeaderValue for u64 { - type Error = InvalidHeaderValueBytes; + type Error = InvalidHeaderValue; #[inline] fn try_into(self) -> Result { let s = format!("{}", self); - HeaderValue::from_shared(Bytes::from(s)) + HeaderValue::try_from(s) } } impl IntoHeaderValue for Mime { - type Error = InvalidHeaderValueBytes; + type Error = InvalidHeaderValue; #[inline] fn try_into(self) -> Result { - HeaderValue::from_shared(Bytes::from(format!("{}", self))) + HeaderValue::try_from(format!("{}", self)) } } @@ -205,7 +205,7 @@ impl Writer { } } fn take(&mut self) -> Bytes { - self.buf.take().freeze() + self.buf.split().freeze() } } @@ -217,7 +217,7 @@ impl fmt::Write for Writer { } #[inline] - fn write_fmt(&mut self, args: fmt::Arguments) -> fmt::Result { + fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result { fmt::write(self, args) } } @@ -259,7 +259,7 @@ pub fn from_one_raw_str(val: Option<&HeaderValue>) -> Result(f: &mut fmt::Formatter, parts: &[T]) -> fmt::Result +pub fn fmt_comma_delimited(f: &mut fmt::Formatter<'_>, parts: &[T]) -> fmt::Result where T: fmt::Display, { @@ -361,7 +361,7 @@ pub fn parse_extended_value( } impl fmt::Display for ExtendedValue { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let encoded_value = percent_encoding::percent_encode(&self.value[..], HTTP_VALUE); if let Some(ref lang) = self.language_tag { @@ -376,7 +376,7 @@ impl fmt::Display for ExtendedValue { /// [https://tools.ietf.org/html/rfc5987#section-3.2][url] /// /// [url]: https://tools.ietf.org/html/rfc5987#section-3.2 -pub fn http_percent_encode(f: &mut fmt::Formatter, bytes: &[u8]) -> fmt::Result { +pub fn http_percent_encode(f: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result { let encoded = percent_encoding::percent_encode(bytes, HTTP_VALUE); fmt::Display::fmt(&encoded, f) } diff --git a/actix-http/src/header/shared/charset.rs b/actix-http/src/header/shared/charset.rs index ec3fe3854..6ddfa03ea 100644 --- a/actix-http/src/header/shared/charset.rs +++ b/actix-http/src/header/shared/charset.rs @@ -98,7 +98,7 @@ impl Charset { } impl Display for Charset { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(self.label()) } } diff --git a/actix-http/src/header/shared/encoding.rs b/actix-http/src/header/shared/encoding.rs index af7404828..aa49dea45 100644 --- a/actix-http/src/header/shared/encoding.rs +++ b/actix-http/src/header/shared/encoding.rs @@ -27,7 +27,7 @@ pub enum Encoding { } impl fmt::Display for Encoding { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match *self { Chunked => "chunked", Brotli => "br", diff --git a/actix-http/src/header/shared/entity.rs b/actix-http/src/header/shared/entity.rs index da02dc193..3525a19c6 100644 --- a/actix-http/src/header/shared/entity.rs +++ b/actix-http/src/header/shared/entity.rs @@ -1,7 +1,7 @@ use std::fmt::{self, Display, Write}; use std::str::FromStr; -use crate::header::{HeaderValue, IntoHeaderValue, InvalidHeaderValueBytes, Writer}; +use crate::header::{HeaderValue, IntoHeaderValue, InvalidHeaderValue, Writer}; /// check that each char in the slice is either: /// 1. `%x21`, or @@ -113,7 +113,7 @@ impl EntityTag { } impl Display for EntityTag { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.weak { write!(f, "W/\"{}\"", self.tag) } else { @@ -157,12 +157,12 @@ impl FromStr for EntityTag { } impl IntoHeaderValue for EntityTag { - type Error = InvalidHeaderValueBytes; + type Error = InvalidHeaderValue; fn try_into(self) -> Result { let mut wrt = Writer::new(); write!(wrt, "{}", self).unwrap(); - HeaderValue::from_shared(wrt.take()) + HeaderValue::from_maybe_shared(wrt.take()) } } diff --git a/actix-http/src/header/shared/httpdate.rs b/actix-http/src/header/shared/httpdate.rs index 350f77bbe..28d6a25ec 100644 --- a/actix-http/src/header/shared/httpdate.rs +++ b/actix-http/src/header/shared/httpdate.rs @@ -3,8 +3,8 @@ use std::io::Write; use std::str::FromStr; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use bytes::{BufMut, BytesMut}; -use http::header::{HeaderValue, InvalidHeaderValueBytes}; +use bytes::{buf::BufMutExt, BytesMut}; +use http::header::{HeaderValue, InvalidHeaderValue}; use crate::error::ParseError; use crate::header::IntoHeaderValue; @@ -28,7 +28,7 @@ impl FromStr for HttpDate { } impl Display for HttpDate { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.0.to_utc().rfc822(), f) } } @@ -58,12 +58,12 @@ impl From for HttpDate { } impl IntoHeaderValue for HttpDate { - type Error = InvalidHeaderValueBytes; + type Error = InvalidHeaderValue; fn try_into(self) -> Result { let mut wrt = BytesMut::with_capacity(29).writer(); write!(wrt, "{}", self.0.rfc822()).unwrap(); - HeaderValue::from_shared(wrt.get_mut().take().freeze()) + HeaderValue::from_maybe_shared(wrt.get_mut().split().freeze()) } } diff --git a/actix-http/src/header/shared/quality_item.rs b/actix-http/src/header/shared/quality_item.rs index fc3930c5e..98230dec1 100644 --- a/actix-http/src/header/shared/quality_item.rs +++ b/actix-http/src/header/shared/quality_item.rs @@ -53,7 +53,7 @@ impl cmp::PartialOrd for QualityItem { } impl fmt::Display for QualityItem { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.item, f)?; match self.quality.0 { 1000 => Ok(()), diff --git a/actix-http/src/helpers.rs b/actix-http/src/helpers.rs index e4583ee37..58ebff61f 100644 --- a/actix-http/src/helpers.rs +++ b/actix-http/src/helpers.rs @@ -60,7 +60,7 @@ pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesM bytes.put_slice(&buf); if four { - bytes.put(b' '); + bytes.put_u8(b' '); } } @@ -115,7 +115,7 @@ pub fn write_content_length(mut n: usize, bytes: &mut BytesMut) { pub(crate) fn convert_usize(mut n: usize, bytes: &mut BytesMut) { let mut curr: isize = 39; - let mut buf: [u8; 41] = unsafe { mem::uninitialized() }; + let mut buf: [u8; 41] = unsafe { mem::MaybeUninit::uninit().assume_init() }; buf[39] = b'\r'; buf[40] = b'\n'; let buf_ptr = buf.as_mut_ptr(); @@ -203,33 +203,33 @@ mod tests { let mut bytes = BytesMut::new(); bytes.reserve(50); write_content_length(0, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 0\r\n"[..]); + assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 0\r\n"[..]); bytes.reserve(50); write_content_length(9, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 9\r\n"[..]); + assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 9\r\n"[..]); bytes.reserve(50); write_content_length(10, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 10\r\n"[..]); + assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 10\r\n"[..]); bytes.reserve(50); write_content_length(99, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 99\r\n"[..]); + assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 99\r\n"[..]); bytes.reserve(50); write_content_length(100, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 100\r\n"[..]); + assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 100\r\n"[..]); bytes.reserve(50); write_content_length(101, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 101\r\n"[..]); + assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 101\r\n"[..]); bytes.reserve(50); write_content_length(998, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 998\r\n"[..]); + assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 998\r\n"[..]); bytes.reserve(50); write_content_length(1000, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1000\r\n"[..]); + assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 1000\r\n"[..]); bytes.reserve(50); write_content_length(1001, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1001\r\n"[..]); + assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 1001\r\n"[..]); bytes.reserve(50); write_content_length(5909, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 5909\r\n"[..]); + assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 5909\r\n"[..]); } } diff --git a/actix-http/src/httpcodes.rs b/actix-http/src/httpcodes.rs index 3cac35eb7..0c7f23fc8 100644 --- a/actix-http/src/httpcodes.rs +++ b/actix-http/src/httpcodes.rs @@ -29,7 +29,6 @@ impl Response { STATIC_RESP!(AlreadyReported, StatusCode::ALREADY_REPORTED); STATIC_RESP!(MultipleChoices, StatusCode::MULTIPLE_CHOICES); - STATIC_RESP!(MovedPermanenty, StatusCode::MOVED_PERMANENTLY); STATIC_RESP!(MovedPermanently, StatusCode::MOVED_PERMANENTLY); STATIC_RESP!(Found, StatusCode::FOUND); STATIC_RESP!(SeeOther, StatusCode::SEE_OTHER); diff --git a/actix-http/src/httpmessage.rs b/actix-http/src/httpmessage.rs index 05d668c10..e1c4136b0 100644 --- a/actix-http/src/httpmessage.rs +++ b/actix-http/src/httpmessage.rs @@ -25,10 +25,10 @@ pub trait HttpMessage: Sized { fn take_payload(&mut self) -> Payload; /// Request's extensions container - fn extensions(&self) -> Ref; + fn extensions(&self) -> Ref<'_, Extensions>; /// Mutable reference to a the request's extensions container - fn extensions_mut(&self) -> RefMut; + fn extensions_mut(&self) -> RefMut<'_, Extensions>; #[doc(hidden)] /// Get a header @@ -105,7 +105,7 @@ pub trait HttpMessage: Sized { /// Load request cookies. #[inline] - fn cookies(&self) -> Result>>, CookieParseError> { + fn cookies(&self) -> Result>>, CookieParseError> { if self.extensions().get::().is_none() { let mut cookies = Vec::new(); for hdr in self.headers().get_all(header::COOKIE) { @@ -153,12 +153,12 @@ where } /// Request's extensions container - fn extensions(&self) -> Ref { + fn extensions(&self) -> Ref<'_, Extensions> { (**self).extensions() } /// Mutable reference to a the request's extensions container - fn extensions_mut(&self) -> RefMut { + fn extensions_mut(&self) -> RefMut<'_, Extensions> { (**self).extensions_mut() } } diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index b57fdddce..7a47012f8 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -1,10 +1,10 @@ //! Basic http primitives for actix-net framework. +#![deny(rust_2018_idioms, warnings)] #![allow( clippy::type_complexity, clippy::too_many_arguments, clippy::new_without_default, - clippy::borrow_interior_mutable_const, - clippy::write_with_newline + clippy::borrow_interior_mutable_const )] #[macro_use] @@ -15,6 +15,7 @@ mod builder; pub mod client; mod cloneable; mod config; +#[cfg(feature = "compress")] pub mod encoding; mod extensions; mod header; @@ -51,7 +52,7 @@ pub mod http { // re-exports pub use http::header::{HeaderName, HeaderValue}; pub use http::uri::PathAndQuery; - pub use http::{uri, Error, HttpTryFrom, Uri}; + pub use http::{uri, Error, Uri}; pub use http::{Method, StatusCode, Version}; pub use crate::cookie::{Cookie, CookieBuilder}; @@ -64,3 +65,10 @@ pub mod http { pub use crate::header::ContentEncoding; pub use crate::message::ConnectionType; } + +/// Http protocol +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum Protocol { + Http1, + Http2, +} diff --git a/actix-http/src/message.rs b/actix-http/src/message.rs index 316df2611..d005ad04a 100644 --- a/actix-http/src/message.rs +++ b/actix-http/src/message.rs @@ -78,13 +78,13 @@ impl Head for RequestHead { impl RequestHead { /// Message extensions #[inline] - pub fn extensions(&self) -> Ref { + pub fn extensions(&self) -> Ref<'_, Extensions> { self.extensions.borrow() } /// Mutable reference to a the message's extensions #[inline] - pub fn extensions_mut(&self) -> RefMut { + pub fn extensions_mut(&self) -> RefMut<'_, Extensions> { self.extensions.borrow_mut() } @@ -237,13 +237,13 @@ impl ResponseHead { /// Message extensions #[inline] - pub fn extensions(&self) -> Ref { + pub fn extensions(&self) -> Ref<'_, Extensions> { self.extensions.borrow() } /// Mutable reference to a the message's extensions #[inline] - pub fn extensions_mut(&self) -> RefMut { + pub fn extensions_mut(&self) -> RefMut<'_, Extensions> { self.extensions.borrow_mut() } @@ -388,6 +388,12 @@ impl BoxedResponseHead { pub fn new(status: StatusCode) -> Self { RESPONSE_POOL.with(|p| p.get_message(status)) } + + pub(crate) fn take(&mut self) -> Self { + BoxedResponseHead { + head: self.head.take(), + } + } } impl std::ops::Deref for BoxedResponseHead { @@ -406,7 +412,9 @@ impl std::ops::DerefMut for BoxedResponseHead { impl Drop for BoxedResponseHead { fn drop(&mut self) { - RESPONSE_POOL.with(|p| p.release(self.head.take().unwrap())) + if let Some(head) = self.head.take() { + RESPONSE_POOL.with(move |p| p.release(head)) + } } } diff --git a/actix-http/src/payload.rs b/actix-http/src/payload.rs index 0ce209705..54de6ed93 100644 --- a/actix-http/src/payload.rs +++ b/actix-http/src/payload.rs @@ -1,11 +1,14 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + use bytes::Bytes; -use futures::{Async, Poll, Stream}; +use futures_core::Stream; use h2::RecvStream; use crate::error::PayloadError; /// Type represent boxed payload -pub type PayloadStream = Box>; +pub type PayloadStream = Pin>>>; /// Type represent streaming payload pub enum Payload { @@ -48,18 +51,20 @@ impl Payload { impl Stream for Payload where - S: Stream, + S: Stream> + Unpin, { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; #[inline] - fn poll(&mut self) -> Poll, Self::Error> { - match self { - Payload::None => Ok(Async::Ready(None)), - Payload::H1(ref mut pl) => pl.poll(), - Payload::H2(ref mut pl) => pl.poll(), - Payload::Stream(ref mut pl) => pl.poll(), + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.get_mut() { + Payload::None => Poll::Ready(None), + Payload::H1(ref mut pl) => pl.readany(cx), + Payload::H2(ref mut pl) => Pin::new(pl).poll_next(cx), + Payload::Stream(ref mut pl) => Pin::new(pl).poll_next(cx), } } } diff --git a/actix-http/src/request.rs b/actix-http/src/request.rs index e9252a829..64e302441 100644 --- a/actix-http/src/request.rs +++ b/actix-http/src/request.rs @@ -25,13 +25,13 @@ impl

HttpMessage for Request

{ /// Request extensions #[inline] - fn extensions(&self) -> Ref { + fn extensions(&self) -> Ref<'_, Extensions> { self.head.extensions() } /// Mutable reference to a the request's extensions #[inline] - fn extensions_mut(&self) -> RefMut { + fn extensions_mut(&self) -> RefMut<'_, Extensions> { self.head.extensions_mut() } @@ -80,6 +80,11 @@ impl

Request

{ ) } + /// Get request's payload + pub fn payload(&mut self) -> &mut Payload

{ + &mut self.payload + } + /// Get request's payload pub fn take_payload(&mut self) -> Payload

{ std::mem::replace(&mut self.payload, Payload::None) @@ -160,7 +165,7 @@ impl

Request

{ } impl

fmt::Debug for Request

{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, "\nRequest {:?} {}:{}", @@ -182,7 +187,7 @@ impl

fmt::Debug for Request

{ #[cfg(test)] mod tests { use super::*; - use http::HttpTryFrom; + use std::convert::TryFrom; #[test] fn test_basics() { @@ -199,7 +204,6 @@ mod tests { assert_eq!(req.uri().query(), Some("q=1")); let s = format!("{:?}", req); - println!("T: {:?}", s); assert!(s.contains("Request HTTP/1.1 GET:/index.html")); } } diff --git a/actix-http/src/response.rs b/actix-http/src/response.rs index 124bf5f92..fcdcd7cdf 100644 --- a/actix-http/src/response.rs +++ b/actix-http/src/response.rs @@ -1,11 +1,13 @@ //! Http response use std::cell::{Ref, RefMut}; -use std::io::Write; +use std::convert::TryFrom; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{fmt, str}; -use bytes::{BufMut, Bytes, BytesMut}; -use futures::future::{ok, FutureResult, IntoFuture}; -use futures::Stream; +use bytes::{Bytes, BytesMut}; +use futures_core::Stream; use serde::Serialize; use serde_json; @@ -15,7 +17,7 @@ use crate::error::Error; use crate::extensions::Extensions; use crate::header::{Header, IntoHeaderValue}; use crate::http::header::{self, HeaderName, HeaderValue}; -use crate::http::{Error as HttpError, HeaderMap, HttpTryFrom, StatusCode}; +use crate::http::{Error as HttpError, HeaderMap, StatusCode}; use crate::message::{BoxedResponseHead, ConnectionType, ResponseHead}; /// An HTTP Response @@ -51,7 +53,7 @@ impl Response { /// Constructs an error response #[inline] pub fn from_error(error: Error) -> Response { - let mut resp = error.as_response_error().render_response(); + let mut resp = error.as_response_error().error_response(); if resp.head.status == StatusCode::INTERNAL_SERVER_ERROR { error!("Internal Server Error: {:?}", error); } @@ -128,7 +130,7 @@ impl Response { /// Get an iterator for the cookies set by this response #[inline] - pub fn cookies(&self) -> CookieIter { + pub fn cookies(&self) -> CookieIter<'_> { CookieIter { iter: self.head.headers.get_all(header::SET_COOKIE), } @@ -136,7 +138,7 @@ impl Response { /// Add a cookie to this response #[inline] - pub fn add_cookie(&mut self, cookie: &Cookie) -> Result<(), HttpError> { + pub fn add_cookie(&mut self, cookie: &Cookie<'_>) -> Result<(), HttpError> { let h = &mut self.head.headers; HeaderValue::from_str(&cookie.to_string()) .map(|c| { @@ -184,17 +186,17 @@ impl Response { /// Responses extensions #[inline] - pub fn extensions(&self) -> Ref { + pub fn extensions(&self) -> Ref<'_, Extensions> { self.head.extensions.borrow() } /// Mutable reference to a the response's extensions #[inline] - pub fn extensions_mut(&mut self) -> RefMut { + pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { self.head.extensions.borrow_mut() } - /// Get body os this response + /// Get body of this response #[inline] pub fn body(&self) -> &ResponseBody { &self.body @@ -263,7 +265,7 @@ impl Response { } impl fmt::Debug for Response { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let res = writeln!( f, "\nResponse {:?} {}{}", @@ -280,13 +282,15 @@ impl fmt::Debug for Response { } } -impl IntoFuture for Response { - type Item = Response; - type Error = Error; - type Future = FutureResult; +impl Future for Response { + type Output = Result; - fn into_future(self) -> Self::Future { - ok(self) + fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + Poll::Ready(Ok(Response { + head: self.head.take(), + body: self.body.take_body(), + error: self.error.take(), + })) } } @@ -350,7 +354,6 @@ impl ResponseBuilder { /// )) /// .finish()) /// } - /// fn main() {} /// ``` #[doc(hidden)] pub fn set(&mut self, hdr: H) -> &mut Self { @@ -376,11 +379,11 @@ impl ResponseBuilder { /// .header(http::header::CONTENT_TYPE, "application/json") /// .finish() /// } - /// fn main() {} /// ``` pub fn header(&mut self, key: K, value: V) -> &mut Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { if let Some(parts) = parts(&mut self.head, &self.err) { @@ -408,11 +411,11 @@ impl ResponseBuilder { /// .set_header(http::header::CONTENT_TYPE, "application/json") /// .finish() /// } - /// fn main() {} /// ``` pub fn set_header(&mut self, key: K, value: V) -> &mut Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { if let Some(parts) = parts(&mut self.head, &self.err) { @@ -481,7 +484,8 @@ impl ResponseBuilder { #[inline] pub fn content_type(&mut self, value: V) -> &mut Self where - HeaderValue: HttpTryFrom, + HeaderValue: TryFrom, + >::Error: Into, { if let Some(parts) = parts(&mut self.head, &self.err) { match HeaderValue::try_from(value) { @@ -497,9 +501,7 @@ impl ResponseBuilder { /// Set content length #[inline] pub fn content_length(&mut self, len: u64) -> &mut Self { - let mut wrt = BytesMut::new().writer(); - let _ = write!(wrt, "{}", len); - self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze()) + self.header(header::CONTENT_LENGTH, len) } /// Set a cookie @@ -583,14 +585,14 @@ impl ResponseBuilder { /// Responses extensions #[inline] - pub fn extensions(&self) -> Ref { + pub fn extensions(&self) -> Ref<'_, Extensions> { let head = self.head.as_ref().expect("cannot reuse response builder"); head.extensions.borrow() } /// Mutable reference to a the response's extensions #[inline] - pub fn extensions_mut(&mut self) -> RefMut { + pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { let head = self.head.as_ref().expect("cannot reuse response builder"); head.extensions.borrow_mut() } @@ -635,7 +637,7 @@ impl ResponseBuilder { /// `ResponseBuilder` can not be used after this call. pub fn streaming(&mut self, stream: S) -> Response where - S: Stream + 'static, + S: Stream> + 'static, E: Into + 'static, { self.body(Body::from_message(BodyStream::new(stream))) @@ -757,18 +759,16 @@ impl<'a> From<&'a ResponseHead> for ResponseBuilder { } } -impl IntoFuture for ResponseBuilder { - type Item = Response; - type Error = Error; - type Future = FutureResult; +impl Future for ResponseBuilder { + type Output = Result; - fn into_future(mut self) -> Self::Future { - ok(self.finish()) + fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + Poll::Ready(Ok(self.finish())) } } impl fmt::Debug for ResponseBuilder { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let head = self.head.as_ref().unwrap(); let res = writeln!( @@ -992,6 +992,14 @@ mod tests { assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); } + #[test] + fn test_serde_json_in_body() { + use serde_json::json; + let resp = + Response::build(StatusCode::OK).body(json!({"test-key":"test-value"})); + assert_eq!(resp.body().get_ref(), br#"{"test-key":"test-value"}"#); + } + #[test] fn test_into_response() { let resp: Response = "test".into(); diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index 09b8077b3..51de95135 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -1,14 +1,16 @@ use std::marker::PhantomData; -use std::{fmt, io, net, rc}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, net, rc}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use actix_server_config::{ - Io as ServerIo, IoStream, Protocol, ServerConfig as SrvConfig, -}; -use actix_service::{IntoNewService, NewService, Service}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use futures::{try_ready, Async, Future, IntoFuture, Poll}; +use actix_rt::net::TcpStream; +use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; +use bytes::Bytes; +use futures_core::{ready, Future}; +use futures_util::future::ok; use h2::server::{self, Handshake}; +use pin_project::{pin_project, project}; use crate::body::MessageBody; use crate::builder::HttpServiceBuilder; @@ -18,24 +20,24 @@ use crate::error::{DispatchError, Error}; use crate::helpers::DataFactory; use crate::request::Request; use crate::response::Response; -use crate::{h1, h2::Dispatcher}; +use crate::{h1, h2::Dispatcher, Protocol}; -/// `NewService` HTTP1.1/HTTP2 transport implementation -pub struct HttpService> { +/// `ServiceFactory` HTTP1.1/HTTP2 transport implementation +pub struct HttpService> { srv: S, cfg: ServiceConfig, expect: X, upgrade: Option, on_connect: Option Box>>, - _t: PhantomData<(T, P, B)>, + _t: PhantomData<(T, B)>, } -impl HttpService +impl HttpService where - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, { @@ -45,22 +47,22 @@ where } } -impl HttpService +impl HttpService where - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, { /// Create new `HttpService` instance. - pub fn new>(service: F) -> Self { - let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); + pub fn new>(service: F) -> Self { + let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0, false, None); HttpService { cfg, - srv: service.into_new_service(), + srv: service.into_factory(), expect: h1::ExpectHandler, upgrade: None, on_connect: None, @@ -69,13 +71,13 @@ where } /// Create new `HttpService` instance with config. - pub(crate) fn with_config>( + pub(crate) fn with_config>( cfg: ServiceConfig, service: F, ) -> Self { HttpService { cfg, - srv: service.into_new_service(), + srv: service.into_factory(), expect: h1::ExpectHandler, upgrade: None, on_connect: None, @@ -84,12 +86,13 @@ where } } -impl HttpService +impl HttpService where - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, + ::Future: 'static, B: MessageBody, { /// Provide service for `EXPECT: 100-Continue` support. @@ -97,11 +100,12 @@ where /// Service get called with request that contains `EXPECT` header. /// Service must return request in case of success, in that case /// request will be forwarded to main service. - pub fn expect(self, expect: X1) -> HttpService + pub fn expect(self, expect: X1) -> HttpService where - X1: NewService, + X1: ServiceFactory, X1::Error: Into, X1::InitError: fmt::Debug, + ::Future: 'static, { HttpService { expect, @@ -117,15 +121,16 @@ where /// /// If service is provided then normal requests handling get halted /// and this service get called with original request and framed object. - pub fn upgrade(self, upgrade: Option) -> HttpService + pub fn upgrade(self, upgrade: Option) -> HttpService where - U1: NewService< - Config = SrvConfig, + U1: ServiceFactory< + Config = (), Request = (Request, Framed), Response = (), >, U1::Error: fmt::Display, U1::InitError: fmt::Debug, + ::Future: 'static, { HttpService { upgrade, @@ -147,126 +152,312 @@ where } } -impl NewService for HttpService +impl HttpService where - T: IoStream, - S: NewService, - S::Error: Into, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, - X: NewService, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: NewService< - Config = SrvConfig, + ::Future: 'static, + U: ServiceFactory< + Config = (), + Request = (Request, Framed), + Response = (), + >, + U::Error: fmt::Display + Into, + U::InitError: fmt::Debug, + ::Future: 'static, +{ + /// Create simple tcp stream service + pub fn tcp( + self, + ) -> impl ServiceFactory< + Config = (), + Request = TcpStream, + Response = (), + Error = DispatchError, + InitError = (), + > { + pipeline_factory(|io: TcpStream| { + let peer_addr = io.peer_addr().ok(); + ok((io, Protocol::Http1, peer_addr)) + }) + .and_then(self) + } +} + +#[cfg(feature = "openssl")] +mod openssl { + use super::*; + use actix_tls::openssl::{Acceptor, SslAcceptor, SslStream}; + use actix_tls::{openssl::HandshakeError, SslError}; + + impl HttpService, S, B, X, U> + where + S: ServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, + X: ServiceFactory, + X::Error: Into, + X::InitError: fmt::Debug, + ::Future: 'static, + U: ServiceFactory< + Config = (), + Request = (Request, Framed, h1::Codec>), + Response = (), + >, + U::Error: fmt::Display + Into, + U::InitError: fmt::Debug, + ::Future: 'static, + { + /// Create openssl based service + pub fn openssl( + self, + acceptor: SslAcceptor, + ) -> impl ServiceFactory< + Config = (), + Request = TcpStream, + Response = (), + Error = SslError, DispatchError>, + InitError = (), + > { + pipeline_factory( + Acceptor::new(acceptor) + .map_err(SslError::Ssl) + .map_init_err(|_| panic!()), + ) + .and_then(|io: SslStream| { + let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() { + if protos.windows(2).any(|window| window == b"h2") { + Protocol::Http2 + } else { + Protocol::Http1 + } + } else { + Protocol::Http1 + }; + let peer_addr = io.get_ref().peer_addr().ok(); + ok((io, proto, peer_addr)) + }) + .and_then(self.map_err(SslError::Service)) + } + } +} + +#[cfg(feature = "rustls")] +mod rustls { + use super::*; + use actix_tls::rustls::{Acceptor, ServerConfig, Session, TlsStream}; + use actix_tls::SslError; + use std::io; + + impl HttpService, S, B, X, U> + where + S: ServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, + X: ServiceFactory, + X::Error: Into, + X::InitError: fmt::Debug, + ::Future: 'static, + U: ServiceFactory< + Config = (), + Request = (Request, Framed, h1::Codec>), + Response = (), + >, + U::Error: fmt::Display + Into, + U::InitError: fmt::Debug, + ::Future: 'static, + { + /// Create openssl based service + pub fn rustls( + self, + mut config: ServerConfig, + ) -> impl ServiceFactory< + Config = (), + Request = TcpStream, + Response = (), + Error = SslError, + InitError = (), + > { + let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()]; + config.set_protocols(&protos); + + pipeline_factory( + Acceptor::new(config) + .map_err(SslError::Ssl) + .map_init_err(|_| panic!()), + ) + .and_then(|io: TlsStream| { + let proto = if let Some(protos) = io.get_ref().1.get_alpn_protocol() { + if protos.windows(2).any(|window| window == b"h2") { + Protocol::Http2 + } else { + Protocol::Http1 + } + } else { + Protocol::Http1 + }; + let peer_addr = io.get_ref().0.peer_addr().ok(); + ok((io, proto, peer_addr)) + }) + .and_then(self.map_err(SslError::Service)) + } + } +} + +impl ServiceFactory for HttpService +where + T: AsyncRead + AsyncWrite + Unpin, + S: ServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, + X: ServiceFactory, + X::Error: Into, + X::InitError: fmt::Debug, + ::Future: 'static, + U: ServiceFactory< + Config = (), Request = (Request, Framed), Response = (), >, - U::Error: fmt::Display, + U::Error: fmt::Display + Into, U::InitError: fmt::Debug, + ::Future: 'static, { - type Config = SrvConfig; - type Request = ServerIo; + type Config = (); + type Request = (T, Protocol, Option); type Response = (); type Error = DispatchError; type InitError = (); - type Service = HttpServiceHandler; - type Future = HttpServiceResponse; + type Service = HttpServiceHandler; + type Future = HttpServiceResponse; - fn new_service(&self, cfg: &SrvConfig) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { HttpServiceResponse { - fut: self.srv.new_service(cfg).into_future(), - fut_ex: Some(self.expect.new_service(cfg)), - fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)), + fut: self.srv.new_service(()), + fut_ex: Some(self.expect.new_service(())), + fut_upg: self.upgrade.as_ref().map(|f| f.new_service(())), expect: None, upgrade: None, on_connect: self.on_connect.clone(), - cfg: Some(self.cfg.clone()), + cfg: self.cfg.clone(), _t: PhantomData, } } } #[doc(hidden)] -pub struct HttpServiceResponse { +#[pin_project] +pub struct HttpServiceResponse< + T, + S: ServiceFactory, + B, + X: ServiceFactory, + U: ServiceFactory, +> { + #[pin] fut: S::Future, + #[pin] fut_ex: Option, + #[pin] fut_upg: Option, expect: Option, upgrade: Option, on_connect: Option Box>>, - cfg: Option, - _t: PhantomData<(T, P, B)>, + cfg: ServiceConfig, + _t: PhantomData<(T, B)>, } -impl Future for HttpServiceResponse +impl Future for HttpServiceResponse where - T: IoStream, - S: NewService, - S::Error: Into, + T: AsyncRead + AsyncWrite + Unpin, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, + S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, - X: NewService, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: NewService), Response = ()>, + ::Future: 'static, + U: ServiceFactory), Response = ()>, U::Error: fmt::Display, U::InitError: fmt::Debug, + ::Future: 'static, { - type Item = HttpServiceHandler; - type Error = (); + type Output = + Result, ()>; - fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut_ex { - let expect = try_ready!(fut - .poll() - .map_err(|e| log::error!("Init http service error: {:?}", e))); - self.expect = Some(expect); - self.fut_ex.take(); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.as_mut().project(); + + if let Some(fut) = this.fut_ex.as_pin_mut() { + let expect = ready!(fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e)))?; + this = self.as_mut().project(); + *this.expect = Some(expect); + this.fut_ex.set(None); } - if let Some(ref mut fut) = self.fut_upg { - let upgrade = try_ready!(fut - .poll() - .map_err(|e| log::error!("Init http service error: {:?}", e))); - self.upgrade = Some(upgrade); - self.fut_ex.take(); + if let Some(fut) = this.fut_upg.as_pin_mut() { + let upgrade = ready!(fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e)))?; + this = self.as_mut().project(); + *this.upgrade = Some(upgrade); + this.fut_ex.set(None); } - let service = try_ready!(self + let result = ready!(this .fut - .poll() + .poll(cx) .map_err(|e| log::error!("Init http service error: {:?}", e))); - Ok(Async::Ready(HttpServiceHandler::new( - self.cfg.take().unwrap(), - service, - self.expect.take().unwrap(), - self.upgrade.take(), - self.on_connect.clone(), - ))) + Poll::Ready(result.map(|service| { + let this = self.as_mut().project(); + HttpServiceHandler::new( + this.cfg.clone(), + service, + this.expect.take().unwrap(), + this.upgrade.take(), + this.on_connect.clone(), + ) + })) } } /// `Service` implementation for http transport -pub struct HttpServiceHandler { +pub struct HttpServiceHandler { srv: CloneableService, expect: CloneableService, upgrade: Option>, cfg: ServiceConfig, on_connect: Option Box>>, - _t: PhantomData<(T, P, B, X)>, + _t: PhantomData<(T, B, X)>, } -impl HttpServiceHandler +impl HttpServiceHandler where S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, X: Service, X::Error: Into, @@ -279,7 +470,7 @@ where expect: X, upgrade: Option, on_connect: Option Box>>, - ) -> HttpServiceHandler { + ) -> HttpServiceHandler { HttpServiceHandler { cfg, on_connect, @@ -291,28 +482,28 @@ where } } -impl Service for HttpServiceHandler +impl Service for HttpServiceHandler where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, X: Service, X::Error: Into, U: Service), Response = ()>, - U::Error: fmt::Display, + U::Error: fmt::Display + Into, { - type Request = ServerIo; + type Request = (T, Protocol, Option); type Response = (); type Error = DispatchError; type Future = HttpServiceHandlerResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { let ready = self .expect - .poll_ready() + .poll_ready(cx) .map_err(|e| { let e = e.into(); log::error!("Http service readiness error: {:?}", e); @@ -322,7 +513,7 @@ where let ready = self .srv - .poll_ready() + .poll_ready(cx) .map_err(|e| { let e = e.into(); log::error!("Http service readiness error: {:?}", e); @@ -331,16 +522,27 @@ where .is_ready() && ready; - if ready { - Ok(Async::Ready(())) + let ready = if let Some(ref mut upg) = self.upgrade { + upg.poll_ready(cx) + .map_err(|e| { + let e = e.into(); + log::error!("Http service readiness error: {:?}", e); + DispatchError::Service(e) + })? + .is_ready() + && ready } else { - Ok(Async::NotReady) + ready + }; + + if ready { + Poll::Ready(Ok(())) + } else { + Poll::Pending } } - fn call(&mut self, req: Self::Request) -> Self::Future { - let (io, _, proto) = req.into_parts(); - + fn call(&mut self, (io, proto, peer_addr): Self::Request) -> Self::Future { let on_connect = if let Some(ref on_connect) = self.on_connect { Some(on_connect(&io)) } else { @@ -348,23 +550,16 @@ where }; match proto { - Protocol::Http2 => { - let peer_addr = io.peer_addr(); - let io = Io { - inner: io, - unread: None, - }; - HttpServiceHandlerResponse { - state: State::Handshake(Some(( - server::handshake(io), - self.cfg.clone(), - self.srv.clone(), - peer_addr, - on_connect, - ))), - } - } - Protocol::Http10 | Protocol::Http11 => HttpServiceHandlerResponse { + Protocol::Http2 => HttpServiceHandlerResponse { + state: State::H2Handshake(Some(( + server::handshake(io), + self.cfg.clone(), + self.srv.clone(), + on_connect, + peer_addr, + ))), + }, + Protocol::Http1 => HttpServiceHandlerResponse { state: State::H1(h1::Dispatcher::new( io, self.cfg.clone(), @@ -372,234 +567,117 @@ where self.expect.clone(), self.upgrade.clone(), on_connect, + peer_addr, )), }, - _ => HttpServiceHandlerResponse { - state: State::Unknown(Some(( - io, - BytesMut::with_capacity(14), - self.cfg.clone(), - self.srv.clone(), - self.expect.clone(), - self.upgrade.clone(), - on_connect, - ))), - }, } } } +#[pin_project] enum State where S: Service, S::Future: 'static, S::Error: Into, - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, B: MessageBody, X: Service, X::Error: Into, U: Service), Response = ()>, U::Error: fmt::Display, { - H1(h1::Dispatcher), - H2(Dispatcher, S, B>), - Unknown( + H1(#[pin] h1::Dispatcher), + H2(#[pin] Dispatcher), + H2Handshake( Option<( - T, - BytesMut, + Handshake, ServiceConfig, CloneableService, - CloneableService, - Option>, Option>, - )>, - ), - Handshake( - Option<( - Handshake, Bytes>, - ServiceConfig, - CloneableService, Option, - Option>, )>, ), } +#[pin_project] pub struct HttpServiceHandlerResponse where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody + 'static, X: Service, X::Error: Into, U: Service), Response = ()>, U::Error: fmt::Display, { + #[pin] state: State, } -const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0"; - impl Future for HttpServiceHandlerResponse where - T: IoStream, + T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into, + S::Error: Into + 'static, S::Future: 'static, - S::Response: Into>, + S::Response: Into> + 'static, B: MessageBody, X: Service, X::Error: Into, U: Service), Response = ()>, U::Error: fmt::Display, { - type Item = (); - type Error = DispatchError; + type Output = Result<(), DispatchError>; - fn poll(&mut self) -> Poll { - match self.state { - State::H1(ref mut disp) => disp.poll(), - State::H2(ref mut disp) => disp.poll(), - State::Unknown(ref mut data) => { - if let Some(ref mut item) = data { - loop { - // Safety - we only write to the returned slice. - let b = unsafe { item.1.bytes_mut() }; - let n = try_ready!(item.0.poll_read(b)); - if n == 0 { - return Ok(Async::Ready(())); - } - // Safety - we know that 'n' bytes have - // been initialized via the contract of - // 'poll_read' - unsafe { item.1.advance_mut(n) }; - if item.1.len() >= HTTP2_PREFACE.len() { - break; - } - } - } else { - panic!() - } - let (io, buf, cfg, srv, expect, upgrade, on_connect) = - data.take().unwrap(); - if buf[..14] == HTTP2_PREFACE[..] { - let peer_addr = io.peer_addr(); - let io = Io { - inner: io, - unread: Some(buf), - }; - self.state = State::Handshake(Some(( - server::handshake(io), - cfg, - srv, - peer_addr, - on_connect, - ))); - } else { - self.state = State::H1(h1::Dispatcher::with_timeout( - io, - h1::Codec::new(cfg.clone()), - cfg, - buf, - None, - srv, - expect, - upgrade, - on_connect, - )) - } - self.poll() - } - State::Handshake(ref mut data) => { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().state.poll(cx) + } +} + +impl State +where + T: AsyncRead + AsyncWrite + Unpin, + S: Service, + S::Error: Into + 'static, + S::Response: Into> + 'static, + B: MessageBody + 'static, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + #[project] + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + #[project] + match self.as_mut().project() { + State::H1(disp) => disp.poll(cx), + State::H2(disp) => disp.poll(cx), + State::H2Handshake(ref mut data) => { let conn = if let Some(ref mut item) = data { - match item.0.poll() { - Ok(Async::Ready(conn)) => conn, - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => { + match Pin::new(&mut item.0).poll(cx) { + Poll::Ready(Ok(conn)) => conn, + Poll::Ready(Err(err)) => { trace!("H2 handshake error: {}", err); - return Err(err.into()); + return Poll::Ready(Err(err.into())); } + Poll::Pending => return Poll::Pending, } } else { panic!() }; - let (_, cfg, srv, peer_addr, on_connect) = data.take().unwrap(); - self.state = State::H2(Dispatcher::new( + let (_, cfg, srv, on_connect, peer_addr) = data.take().unwrap(); + self.set(State::H2(Dispatcher::new( srv, conn, on_connect, cfg, None, peer_addr, - )); - self.poll() + ))); + self.poll(cx) } } } } - -/// Wrapper for `AsyncRead + AsyncWrite` types -struct Io { - unread: Option, - inner: T, -} - -impl io::Read for Io { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - if let Some(mut bytes) = self.unread.take() { - let size = std::cmp::min(buf.len(), bytes.len()); - buf[..size].copy_from_slice(&bytes[..size]); - if bytes.len() > size { - bytes.split_to(size); - self.unread = Some(bytes); - } - Ok(size) - } else { - self.inner.read(buf) - } - } -} - -impl io::Write for Io { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.write(buf) - } - fn flush(&mut self) -> io::Result<()> { - self.inner.flush() - } -} - -impl AsyncRead for Io { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } -} - -impl AsyncWrite for Io { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.inner.shutdown() - } - fn write_buf(&mut self, buf: &mut B) -> Poll { - self.inner.write_buf(buf) - } -} - -impl IoStream for Io { - #[inline] - fn peer_addr(&self) -> Option { - self.inner.peer_addr() - } - - #[inline] - fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - self.inner.set_nodelay(nodelay) - } - - #[inline] - fn set_linger(&mut self, dur: Option) -> io::Result<()> { - self.inner.set_linger(dur) - } - - #[inline] - fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { - self.inner.set_keepalive(dur) - } -} diff --git a/actix-http/src/test.rs b/actix-http/src/test.rs index ce81a54d5..061ba610f 100644 --- a/actix-http/src/test.rs +++ b/actix-http/src/test.rs @@ -1,14 +1,15 @@ //! Test Various helpers for Actix applications to use during testing. +use std::convert::TryFrom; use std::fmt::Write as FmtWrite; -use std::io; +use std::io::{self, Read, Write}; +use std::pin::Pin; use std::str::FromStr; +use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite}; -use actix_server_config::IoStream; -use bytes::{Buf, Bytes, BytesMut}; -use futures::{Async, Poll}; +use bytes::{Bytes, BytesMut}; use http::header::{self, HeaderName, HeaderValue}; -use http::{HttpTryFrom, Method, Uri, Version}; +use http::{Error as HttpError, Method, Uri, Version}; use percent_encoding::percent_encode; use crate::cookie::{Cookie, CookieJar, USERINFO}; @@ -20,8 +21,6 @@ use crate::Request; /// Test `Request` builder /// /// ```rust,ignore -/// # extern crate http; -/// # extern crate actix_web; /// # use http::{header, StatusCode}; /// # use actix_web::*; /// use actix_web::test::TestRequest; @@ -34,15 +33,13 @@ use crate::Request; /// } /// } /// -/// fn main() { -/// let resp = TestRequest::with_header("content-type", "text/plain") -/// .run(&index) -/// .unwrap(); -/// assert_eq!(resp.status(), StatusCode::OK); +/// let resp = TestRequest::with_header("content-type", "text/plain") +/// .run(&index) +/// .unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); /// -/// let resp = TestRequest::default().run(&index).unwrap(); -/// assert_eq!(resp.status(), StatusCode::BAD_REQUEST); -/// } +/// let resp = TestRequest::default().run(&index).unwrap(); +/// assert_eq!(resp.status(), StatusCode::BAD_REQUEST); /// ``` pub struct TestRequest(Option); @@ -82,7 +79,8 @@ impl TestRequest { /// Create TestRequest and set header pub fn with_header(key: K, value: V) -> TestRequest where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { TestRequest::default().header(key, value).take() @@ -118,7 +116,8 @@ impl TestRequest { /// Set a header pub fn header(&mut self, key: K, value: V) -> &mut Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { if let Ok(key) = HeaderName::try_from(key) { @@ -150,7 +149,7 @@ impl TestRequest { /// Complete request creation and generate `Request` instance pub fn finish(&mut self) -> Request { - let inner = self.0.take().expect("cannot reuse test request builder");; + let inner = self.0.take().expect("cannot reuse test request builder"); let mut req = if let Some(pl) = inner.payload { Request::with_payload(pl) @@ -244,27 +243,30 @@ impl io::Write for TestBuffer { } } -impl AsyncRead for TestBuffer {} +impl AsyncRead for TestBuffer { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Poll::Ready(self.get_mut().read(buf)) + } +} impl AsyncWrite for TestBuffer { - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(Async::Ready(())) + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Poll::Ready(self.get_mut().write(buf)) } - fn write_buf(&mut self, _: &mut B) -> Poll { - Ok(Async::NotReady) - } -} - -impl IoStream for TestBuffer { - fn set_nodelay(&mut self, _nodelay: bool) -> io::Result<()> { - Ok(()) - } - - fn set_linger(&mut self, _dur: Option) -> io::Result<()> { - Ok(()) - } - - fn set_keepalive(&mut self, _dur: Option) -> io::Result<()> { - Ok(()) + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } } diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index 9891bfa6e..a37208a2b 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -12,10 +12,12 @@ pub enum Message { Text(String), /// Binary message Binary(Bytes), + /// Continuation + Continuation(Item), /// Ping message - Ping(String), + Ping(Bytes), /// Pong message - Pong(String), + Pong(Bytes), /// Close message with optional reason Close(Option), /// No-op. Useful for actix-net services @@ -26,22 +28,41 @@ pub enum Message { #[derive(Debug, PartialEq)] pub enum Frame { /// Text frame, codec does not verify utf8 encoding - Text(Option), + Text(Bytes), /// Binary frame - Binary(Option), + Binary(Bytes), + /// Continuation + Continuation(Item), /// Ping message - Ping(String), + Ping(Bytes), /// Pong message - Pong(String), + Pong(Bytes), /// Close message with optional reason Close(Option), } +/// `WebSocket` continuation item +#[derive(Debug, PartialEq)] +pub enum Item { + FirstText(Bytes), + FirstBinary(Bytes), + Continue(Bytes), + Last(Bytes), +} + #[derive(Debug, Copy, Clone)] /// WebSockets protocol codec pub struct Codec { + flags: Flags, max_size: usize, - server: bool, +} + +bitflags::bitflags! { + struct Flags: u8 { + const SERVER = 0b0000_0001; + const CONTINUATION = 0b0000_0010; + const W_CONTINUATION = 0b0000_0100; + } } impl Codec { @@ -49,7 +70,7 @@ impl Codec { pub fn new() -> Codec { Codec { max_size: 65_536, - server: true, + flags: Flags::SERVER, } } @@ -65,7 +86,7 @@ impl Codec { /// /// By default decoder works in server mode. pub fn client_mode(mut self) -> Self { - self.server = false; + self.flags.remove(Flags::SERVER); self } } @@ -76,19 +97,94 @@ impl Encoder for Codec { fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { - Message::Text(txt) => { - Parser::write_message(dst, txt, OpCode::Text, true, !self.server) + Message::Text(txt) => Parser::write_message( + dst, + txt, + OpCode::Text, + true, + !self.flags.contains(Flags::SERVER), + ), + Message::Binary(bin) => Parser::write_message( + dst, + bin, + OpCode::Binary, + true, + !self.flags.contains(Flags::SERVER), + ), + Message::Ping(txt) => Parser::write_message( + dst, + txt, + OpCode::Ping, + true, + !self.flags.contains(Flags::SERVER), + ), + Message::Pong(txt) => Parser::write_message( + dst, + txt, + OpCode::Pong, + true, + !self.flags.contains(Flags::SERVER), + ), + Message::Close(reason) => { + Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER)) } - Message::Binary(bin) => { - Parser::write_message(dst, bin, OpCode::Binary, true, !self.server) - } - Message::Ping(txt) => { - Parser::write_message(dst, txt, OpCode::Ping, true, !self.server) - } - Message::Pong(txt) => { - Parser::write_message(dst, txt, OpCode::Pong, true, !self.server) - } - Message::Close(reason) => Parser::write_close(dst, reason, !self.server), + Message::Continuation(cont) => match cont { + Item::FirstText(data) => { + if self.flags.contains(Flags::W_CONTINUATION) { + return Err(ProtocolError::ContinuationStarted); + } else { + self.flags.insert(Flags::W_CONTINUATION); + Parser::write_message( + dst, + &data[..], + OpCode::Binary, + false, + !self.flags.contains(Flags::SERVER), + ) + } + } + Item::FirstBinary(data) => { + if self.flags.contains(Flags::W_CONTINUATION) { + return Err(ProtocolError::ContinuationStarted); + } else { + self.flags.insert(Flags::W_CONTINUATION); + Parser::write_message( + dst, + &data[..], + OpCode::Text, + false, + !self.flags.contains(Flags::SERVER), + ) + } + } + Item::Continue(data) => { + if self.flags.contains(Flags::W_CONTINUATION) { + Parser::write_message( + dst, + &data[..], + OpCode::Continue, + false, + !self.flags.contains(Flags::SERVER), + ) + } else { + return Err(ProtocolError::ContinuationNotStarted); + } + } + Item::Last(data) => { + if self.flags.contains(Flags::W_CONTINUATION) { + self.flags.remove(Flags::W_CONTINUATION); + Parser::write_message( + dst, + &data[..], + OpCode::Continue, + true, + !self.flags.contains(Flags::SERVER), + ) + } else { + return Err(ProtocolError::ContinuationNotStarted); + } + } + }, Message::Nop => (), } Ok(()) @@ -100,15 +196,64 @@ impl Decoder for Codec { type Error = ProtocolError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - match Parser::parse(src, self.server, self.max_size) { + match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) { Ok(Some((finished, opcode, payload))) => { // continuation is not supported if !finished { - return Err(ProtocolError::NoContinuation); + return match opcode { + OpCode::Continue => { + if self.flags.contains(Flags::CONTINUATION) { + Ok(Some(Frame::Continuation(Item::Continue( + payload + .map(|pl| pl.freeze()) + .unwrap_or_else(Bytes::new), + )))) + } else { + Err(ProtocolError::ContinuationNotStarted) + } + } + OpCode::Binary => { + if !self.flags.contains(Flags::CONTINUATION) { + self.flags.insert(Flags::CONTINUATION); + Ok(Some(Frame::Continuation(Item::FirstBinary( + payload + .map(|pl| pl.freeze()) + .unwrap_or_else(Bytes::new), + )))) + } else { + Err(ProtocolError::ContinuationStarted) + } + } + OpCode::Text => { + if !self.flags.contains(Flags::CONTINUATION) { + self.flags.insert(Flags::CONTINUATION); + Ok(Some(Frame::Continuation(Item::FirstText( + payload + .map(|pl| pl.freeze()) + .unwrap_or_else(Bytes::new), + )))) + } else { + Err(ProtocolError::ContinuationStarted) + } + } + _ => { + error!("Unfinished fragment {:?}", opcode); + Err(ProtocolError::ContinuationFragment(opcode)) + } + }; } match opcode { - OpCode::Continue => Err(ProtocolError::NoContinuation), + OpCode::Continue => { + if self.flags.contains(Flags::CONTINUATION) { + self.flags.remove(Flags::CONTINUATION); + Ok(Some(Frame::Continuation(Item::Last( + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + )))) + } else { + Err(ProtocolError::ContinuationNotStarted) + } + } OpCode::Bad => Err(ProtocolError::BadOpCode), OpCode::Close => { if let Some(ref pl) = payload { @@ -118,29 +263,18 @@ impl Decoder for Codec { Ok(Some(Frame::Close(None))) } } - OpCode::Ping => { - if let Some(ref pl) = payload { - Ok(Some(Frame::Ping(String::from_utf8_lossy(pl).into()))) - } else { - Ok(Some(Frame::Ping(String::new()))) - } - } - OpCode::Pong => { - if let Some(ref pl) = payload { - Ok(Some(Frame::Pong(String::from_utf8_lossy(pl).into()))) - } else { - Ok(Some(Frame::Pong(String::new()))) - } - } - OpCode::Binary => Ok(Some(Frame::Binary(payload))), - OpCode::Text => { - Ok(Some(Frame::Text(payload))) - //let tmp = Vec::from(payload.as_ref()); - //match String::from_utf8(tmp) { - // Ok(s) => Ok(Some(Message::Text(s))), - // Err(_) => Err(ProtocolError::BadEncoding), - //} - } + OpCode::Ping => Ok(Some(Frame::Ping( + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + ))), + OpCode::Pong => Ok(Some(Frame::Pong( + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + ))), + OpCode::Binary => Ok(Some(Frame::Binary( + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + ))), + OpCode::Text => Ok(Some(Frame::Text( + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + ))), } } Ok(None) => Ok(None), diff --git a/actix-http/src/ws/transport.rs b/actix-http/src/ws/dispatcher.rs similarity index 52% rename from actix-http/src/ws/transport.rs rename to actix-http/src/ws/dispatcher.rs index da7782be5..7a6b11b18 100644 --- a/actix-http/src/ws/transport.rs +++ b/actix-http/src/ws/dispatcher.rs @@ -1,19 +1,22 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_service::{IntoService, Service}; -use actix_utils::framed::{FramedTransport, FramedTransportError}; -use futures::{Future, Poll}; +use actix_utils::framed; use super::{Codec, Frame, Message}; -pub struct Transport +pub struct Dispatcher where S: Service + 'static, T: AsyncRead + AsyncWrite, { - inner: FramedTransport, + inner: framed::Dispatcher, } -impl Transport +impl Dispatcher where T: AsyncRead + AsyncWrite, S: Service, @@ -21,29 +24,28 @@ where S::Error: 'static, { pub fn new>(io: T, service: F) -> Self { - Transport { - inner: FramedTransport::new(Framed::new(io, Codec::new()), service), + Dispatcher { + inner: framed::Dispatcher::new(Framed::new(io, Codec::new()), service), } } pub fn with>(framed: Framed, service: F) -> Self { - Transport { - inner: FramedTransport::new(framed, service), + Dispatcher { + inner: framed::Dispatcher::new(framed, service), } } } -impl Future for Transport +impl Future for Dispatcher where T: AsyncRead + AsyncWrite, S: Service, S::Future: 'static, S::Error: 'static, { - type Item = (); - type Error = FramedTransportError; + type Output = Result<(), framed::DispatcherError>; - fn poll(&mut self) -> Poll { - self.inner.poll() + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.inner).poll(cx) } } diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs index 46e9f36db..3c70eb2bd 100644 --- a/actix-http/src/ws/frame.rs +++ b/actix-http/src/ws/frame.rs @@ -1,6 +1,6 @@ use std::convert::TryFrom; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use log::debug; use rand; @@ -108,7 +108,7 @@ impl Parser { } // remove prefix - src.split_to(idx); + src.advance(idx); // no need for body if length == 0 { @@ -154,14 +154,14 @@ impl Parser { } /// Generate binary representation - pub fn write_message>( + pub fn write_message>( dst: &mut BytesMut, pl: B, op: OpCode, fin: bool, mask: bool, ) { - let payload = pl.into(); + let payload = pl.as_ref(); let one: u8 = if fin { 0x80 | Into::::into(op) } else { @@ -180,11 +180,11 @@ impl Parser { } else if payload_len <= 65_535 { dst.reserve(p_len + 4 + if mask { 4 } else { 0 }); dst.put_slice(&[one, two | 126]); - dst.put_u16_be(payload_len as u16); + dst.put_u16(payload_len as u16); } else { dst.reserve(p_len + 10 + if mask { 4 } else { 0 }); dst.put_slice(&[one, two | 127]); - dst.put_u64_be(payload_len as u64); + dst.put_u64(payload_len as u64); }; if mask { diff --git a/actix-http/src/ws/mask.rs b/actix-http/src/ws/mask.rs index 9f7304039..7eb5d148f 100644 --- a/actix-http/src/ws/mask.rs +++ b/actix-http/src/ws/mask.rs @@ -51,7 +51,7 @@ pub(crate) fn apply_mask(buf: &mut [u8], mask_u32: u32) { // inefficient, it could be done better. The compiler does not understand that // a `ShortSlice` must be smaller than a u64. #[allow(clippy::needless_pass_by_value)] -fn xor_short(buf: ShortSlice, mask: u64) { +fn xor_short(buf: ShortSlice<'_>, mask: u64) { // Unsafe: we know that a `ShortSlice` fits in a u64 unsafe { let (ptr, len) = (buf.0.as_mut_ptr(), buf.0.len()); @@ -77,7 +77,7 @@ unsafe fn cast_slice(buf: &mut [u8]) -> &mut [u64] { #[inline] // Splits a slice into three parts: an unaligned short head and tail, plus an aligned // u64 mid section. -fn align_buf(buf: &mut [u8]) -> (ShortSlice, &mut [u64], ShortSlice) { +fn align_buf(buf: &mut [u8]) -> (ShortSlice<'_>, &mut [u64], ShortSlice<'_>) { let start_ptr = buf.as_ptr() as usize; let end_ptr = start_ptr + buf.len(); diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 891d5110d..ffa397979 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -13,15 +13,15 @@ use crate::message::RequestHead; use crate::response::{Response, ResponseBuilder}; mod codec; +mod dispatcher; mod frame; mod mask; mod proto; -mod transport; -pub use self::codec::{Codec, Frame, Message}; +pub use self::codec::{Codec, Frame, Item, Message}; +pub use self::dispatcher::Dispatcher; pub use self::frame::Parser; pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; -pub use self::transport::Transport; /// Websocket protocol errors #[derive(Debug, Display, From)] @@ -44,12 +44,15 @@ pub enum ProtocolError { /// A payload reached size limit. #[display(fmt = "A payload reached size limit.")] Overflow, - /// Continuation is not supported - #[display(fmt = "Continuation is not supported.")] - NoContinuation, - /// Bad utf-8 encoding - #[display(fmt = "Bad utf-8 encoding.")] - BadEncoding, + /// Continuation is not started + #[display(fmt = "Continuation is not started.")] + ContinuationNotStarted, + /// Received new continuation but it is already started + #[display(fmt = "Received new continuation but it is already started")] + ContinuationStarted, + /// Unknown continuation fragment + #[display(fmt = "Unknown continuation fragment.")] + ContinuationFragment(OpCode), /// Io error #[display(fmt = "io error: {}", _0)] Io(io::Error), diff --git a/actix-http/src/ws/proto.rs b/actix-http/src/ws/proto.rs index e14651a56..60af6f08b 100644 --- a/actix-http/src/ws/proto.rs +++ b/actix-http/src/ws/proto.rs @@ -24,7 +24,7 @@ pub enum OpCode { } impl fmt::Display for OpCode { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { Continue => write!(f, "CONTINUE"), Text => write!(f, "TEXT"), @@ -95,7 +95,7 @@ pub enum CloseCode { Abnormal, /// Indicates that an endpoint is terminating the connection /// because it has received data within a message that was not - /// consistent with the type of the message (e.g., non-UTF-8 [RFC3629] + /// consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\] /// data within a text message). Invalid, /// Indicates that an endpoint is terminating the connection @@ -207,12 +207,13 @@ static WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; // TODO: hash is always same size, we dont need String pub fn hash_key(key: &[u8]) -> String { + use sha1::Digest; let mut hasher = sha1::Sha1::new(); - hasher.update(key); - hasher.update(WS_GUID.as_bytes()); + hasher.input(key); + hasher.input(WS_GUID.as_bytes()); - base64::encode(&hasher.digest().bytes()) + base64::encode(hasher.result().as_ref()) } #[cfg(test)] @@ -277,6 +278,12 @@ mod test { assert_eq!(format!("{}", OpCode::Bad), "BAD"); } + #[test] + fn test_hash_key() { + let hash = hash_key(b"hello actix-web"); + assert_eq!(&hash, "cR1dlyUUJKp0s/Bel25u5TgvC3E="); + } + #[test] fn closecode_from_u16() { assert_eq!(CloseCode::from(1000u16), CloseCode::Normal); diff --git a/actix-http/tests/test_client.rs b/actix-http/tests/test_client.rs index a4f1569cc..9da3b04a2 100644 --- a/actix-http/tests/test_client.rs +++ b/actix-http/tests/test_client.rs @@ -1,9 +1,9 @@ -use actix_service::NewService; +use actix_service::ServiceFactory; use bytes::Bytes; use futures::future::{self, ok}; use actix_http::{http, HttpService, Request, Response}; -use actix_http_test::TestServer; +use actix_http_test::test_server; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ @@ -27,45 +27,49 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World"; -#[test] -fn test_h1_v2() { - env_logger::init(); - let mut srv = TestServer::new(move || { - HttpService::build().finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) +#[actix_rt::test] +async fn test_h1_v2() { + let srv = test_server(move || { + HttpService::build() + .finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .tcp() }); - let response = srv.block_on(srv.get("/").send()).unwrap(); + + let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); let request = srv.get("/").header("x-test", "111").send(); - let response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.load_body(response).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); - let response = srv.block_on(srv.post("/").send()).unwrap(); + let mut response = srv.post("/").send().await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.load_body(response).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -fn test_connection_close() { - let mut srv = TestServer::new(move || { +#[actix_rt::test] +async fn test_connection_close() { + let srv = test_server(move || { HttpService::build() .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .tcp() .map(|_| ()) }); - let response = srv.block_on(srv.get("/").force_close().send()).unwrap(); + + let response = srv.get("/").force_close().send().await.unwrap(); assert!(response.status().is_success()); } -#[test] -fn test_with_query_parameter() { - let mut srv = TestServer::new(move || { +#[actix_rt::test] +async fn test_with_query_parameter() { + let srv = test_server(move || { HttpService::build() .finish(|req: Request| { if req.uri().query().unwrap().contains("qp=") { @@ -74,10 +78,11 @@ fn test_with_query_parameter() { ok::<_, ()>(Response::BadRequest().finish()) } }) + .tcp() .map(|_| ()) }); - let request = srv.request(http::Method::GET, srv.url("/?qp=5")).send(); - let response = srv.block_on(request).unwrap(); + let request = srv.request(http::Method::GET, srv.url("/?qp=5")); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); } diff --git a/actix-http/tests/test_openssl.rs b/actix-http/tests/test_openssl.rs new file mode 100644 index 000000000..b25f05272 --- /dev/null +++ b/actix-http/tests/test_openssl.rs @@ -0,0 +1,416 @@ +#![cfg(feature = "openssl")] +use std::io; + +use actix_http_test::test_server; +use actix_service::{fn_service, ServiceFactory}; + +use bytes::{Bytes, BytesMut}; +use futures::future::{err, ok, ready}; +use futures::stream::{once, Stream, StreamExt}; +use open_ssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod}; + +use actix_http::error::{ErrorBadRequest, PayloadError}; +use actix_http::http::header::{self, HeaderName, HeaderValue}; +use actix_http::http::{Method, StatusCode, Version}; +use actix_http::httpmessage::HttpMessage; +use actix_http::{body, Error, HttpService, Request, Response}; + +async fn load_body(stream: S) -> Result +where + S: Stream>, +{ + let body = stream + .map(|res| match res { + Ok(chunk) => chunk, + Err(_) => panic!(), + }) + .fold(BytesMut::new(), move |mut body, chunk| { + body.extend_from_slice(&chunk); + ready(body) + }) + .await; + + Ok(body) +} + +fn ssl_acceptor() -> SslAcceptor { + // load ssl keys + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder + .set_private_key_file("../tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("../tests/cert.pem") + .unwrap(); + builder.set_alpn_select_callback(|_, protos| { + const H2: &[u8] = b"\x02h2"; + const H11: &[u8] = b"\x08http/1.1"; + if protos.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else if protos.windows(9).any(|window| window == H11) { + Ok(b"http/1.1") + } else { + Err(AlpnError::NOACK) + } + }); + builder + .set_alpn_protos(b"\x08http/1.1\x02h2") + .expect("Can not contrust SslAcceptor"); + + builder.build() +} + +#[actix_rt::test] +async fn test_h2() -> io::Result<()> { + let srv = test_server(move || { + HttpService::build() + .h2(|_| ok::<_, Error>(Response::Ok().finish())) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2_1() -> io::Result<()> { + let srv = test_server(move || { + HttpService::build() + .finish(|req: Request| { + assert!(req.peer_addr().is_some()); + assert_eq!(req.version(), Version::HTTP_2); + ok::<_, Error>(Response::Ok().finish()) + }) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2_body() -> io::Result<()> { + let data = "HELLOWORLD".to_owned().repeat(64 * 1024); + let mut srv = test_server(move || { + HttpService::build() + .h2(|mut req: Request<_>| { + async move { + let body = load_body(req.take_payload()).await?; + Ok::<_, Error>(Response::Ok().body(body)) + } + }) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.sget("/").send_body(data.clone()).await.unwrap(); + assert!(response.status().is_success()); + + let body = srv.load_body(response).await.unwrap(); + assert_eq!(&body, data.as_bytes()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2_content_length() { + let srv = test_server(move || { + HttpService::build() + .h2(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + ok::<_, ()>(Response::new(statuses[indx])) + }) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let header = HeaderName::from_static("content-length"); + let value = HeaderValue::from_static("0"); + + { + for i in 0..4 { + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), None); + + let req = srv + .request(Method::HEAD, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), None); + } + + for i in 4..6 { + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), Some(&value)); + } + } +} + +#[actix_rt::test] +async fn test_h2_headers() { + let data = STR.repeat(10); + let data2 = data.clone(); + + let mut srv = test_server(move || { + let data = data.clone(); + HttpService::build().h2(move |_| { + let mut builder = Response::Ok(); + for idx in 0..90 { + builder.header( + format!("X-TEST-{}", idx).as_str(), + "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", + ); + } + ok::<_, ()>(builder.body(data.clone())) + }) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from(data2)); +} + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +#[actix_rt::test] +async fn test_h2_body2() { + let mut srv = test_server(move || { + HttpService::build() + .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h2_head_empty() { + let mut srv = test_server(move || { + HttpService::build() + .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), Version::HTTP_2); + + { + let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); +} + +#[actix_rt::test] +async fn test_h2_head_binary() { + let mut srv = test_server(move || { + HttpService::build() + .h2(|_| { + ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR)) + }) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); +} + +#[actix_rt::test] +async fn test_h2_head_binary2() { + let srv = test_server(move || { + HttpService::build() + .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } +} + +#[actix_rt::test] +async fn test_h2_body_length() { + let mut srv = test_server(move || { + HttpService::build() + .h2(|_| { + let body = once(ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), + ) + }) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h2_body_chunked_explicit() { + let mut srv = test_server(move || { + HttpService::build() + .h2(|_| { + let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert!(!response.headers().contains_key(header::TRANSFER_ENCODING)); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + + // decode + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h2_response_http_error_handling() { + let mut srv = test_server(move || { + HttpService::build() + .h2(fn_service(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, ()>( + Response::Ok() + .header(header::CONTENT_TYPE, broken_header) + .body(STR), + ) + })) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); +} + +#[actix_rt::test] +async fn test_h2_service_error() { + let mut srv = test_server(move || { + HttpService::build() + .h2(|_| err::(ErrorBadRequest("error"))) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"error")); +} + +#[actix_rt::test] +async fn test_h2_on_connect() { + let srv = test_server(move || { + HttpService::build() + .on_connect(|_| 10usize) + .h2(|req: Request| { + assert!(req.extensions().contains::()); + ok::<_, ()>(Response::Ok().finish()) + }) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); +} diff --git a/actix-http/tests/test_rustls.rs b/actix-http/tests/test_rustls.rs new file mode 100644 index 000000000..bc0c91cc3 --- /dev/null +++ b/actix-http/tests/test_rustls.rs @@ -0,0 +1,421 @@ +#![cfg(feature = "rustls")] +use actix_http::error::PayloadError; +use actix_http::http::header::{self, HeaderName, HeaderValue}; +use actix_http::http::{Method, StatusCode, Version}; +use actix_http::{body, error, Error, HttpService, Request, Response}; +use actix_http_test::test_server; +use actix_service::{fn_factory_with_config, fn_service}; + +use bytes::{Bytes, BytesMut}; +use futures::future::{self, err, ok}; +use futures::stream::{once, Stream, StreamExt}; +use rust_tls::{ + internal::pemfile::{certs, pkcs8_private_keys}, + NoClientAuth, ServerConfig as RustlsServerConfig, +}; + +use std::fs::File; +use std::io::{self, BufReader}; + +async fn load_body(mut stream: S) -> Result +where + S: Stream> + Unpin, +{ + let mut body = BytesMut::new(); + while let Some(item) = stream.next().await { + body.extend_from_slice(&item?) + } + Ok(body) +} + +fn ssl_acceptor() -> RustlsServerConfig { + // load ssl keys + let mut config = RustlsServerConfig::new(NoClientAuth::new()); + let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap()); + let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap()); + let cert_chain = certs(cert_file).unwrap(); + let mut keys = pkcs8_private_keys(key_file).unwrap(); + config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + config +} + +#[actix_rt::test] +async fn test_h1() -> io::Result<()> { + let srv = test_server(move || { + HttpService::build() + .h1(|_| future::ok::<_, Error>(Response::Ok().finish())) + .rustls(ssl_acceptor()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2() -> io::Result<()> { + let srv = test_server(move || { + HttpService::build() + .h2(|_| future::ok::<_, Error>(Response::Ok().finish())) + .rustls(ssl_acceptor()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h1_1() -> io::Result<()> { + let srv = test_server(move || { + HttpService::build() + .h1(|req: Request| { + assert!(req.peer_addr().is_some()); + assert_eq!(req.version(), Version::HTTP_11); + future::ok::<_, Error>(Response::Ok().finish()) + }) + .rustls(ssl_acceptor()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2_1() -> io::Result<()> { + let srv = test_server(move || { + HttpService::build() + .finish(|req: Request| { + assert!(req.peer_addr().is_some()); + assert_eq!(req.version(), Version::HTTP_2); + future::ok::<_, Error>(Response::Ok().finish()) + }) + .rustls(ssl_acceptor()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2_body1() -> io::Result<()> { + let data = "HELLOWORLD".to_owned().repeat(64 * 1024); + let mut srv = test_server(move || { + HttpService::build() + .h2(|mut req: Request<_>| { + async move { + let body = load_body(req.take_payload()).await?; + Ok::<_, Error>(Response::Ok().body(body)) + } + }) + .rustls(ssl_acceptor()) + }); + + let response = srv.sget("/").send_body(data.clone()).await.unwrap(); + assert!(response.status().is_success()); + + let body = srv.load_body(response).await.unwrap(); + assert_eq!(&body, data.as_bytes()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2_content_length() { + let srv = test_server(move || { + HttpService::build() + .h2(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + future::ok::<_, ()>(Response::new(statuses[indx])) + }) + .rustls(ssl_acceptor()) + }); + + let header = HeaderName::from_static("content-length"); + let value = HeaderValue::from_static("0"); + { + for i in 0..4 { + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), None); + + let req = srv + .request(Method::HEAD, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), None); + } + + for i in 4..6 { + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), Some(&value)); + } + } +} + +#[actix_rt::test] +async fn test_h2_headers() { + let data = STR.repeat(10); + let data2 = data.clone(); + + let mut srv = test_server(move || { + let data = data.clone(); + HttpService::build().h2(move |_| { + let mut config = Response::Ok(); + for idx in 0..90 { + config.header( + format!("X-TEST-{}", idx).as_str(), + "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", + ); + } + future::ok::<_, ()>(config.body(data.clone())) + }) + .rustls(ssl_acceptor()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from(data2)); +} + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +#[actix_rt::test] +async fn test_h2_body2() { + let mut srv = test_server(move || { + HttpService::build() + .h2(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .rustls(ssl_acceptor()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h2_head_empty() { + let mut srv = test_server(move || { + HttpService::build() + .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .rustls(ssl_acceptor()) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), Version::HTTP_2); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); +} + +#[actix_rt::test] +async fn test_h2_head_binary() { + let mut srv = test_server(move || { + HttpService::build() + .h2(|_| { + ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR)) + }) + .rustls(ssl_acceptor()) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); +} + +#[actix_rt::test] +async fn test_h2_head_binary2() { + let srv = test_server(move || { + HttpService::build() + .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .rustls(ssl_acceptor()) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } +} + +#[actix_rt::test] +async fn test_h2_body_length() { + let mut srv = test_server(move || { + HttpService::build() + .h2(|_| { + let body = once(ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), + ) + }) + .rustls(ssl_acceptor()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h2_body_chunked_explicit() { + let mut srv = test_server(move || { + HttpService::build() + .h2(|_| { + let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) + .rustls(ssl_acceptor()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert!(!response.headers().contains_key(header::TRANSFER_ENCODING)); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + + // decode + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h2_response_http_error_handling() { + let mut srv = test_server(move || { + HttpService::build() + .h2(fn_factory_with_config(|_: ()| { + ok::<_, ()>(fn_service(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, ()>( + Response::Ok() + .header(http::header::CONTENT_TYPE, broken_header) + .body(STR), + ) + })) + })) + .rustls(ssl_acceptor()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); +} + +#[actix_rt::test] +async fn test_h2_service_error() { + let mut srv = test_server(move || { + HttpService::build() + .h2(|_| err::(error::ErrorBadRequest("error"))) + .rustls(ssl_acceptor()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"error")); +} + +#[actix_rt::test] +async fn test_h1_service_error() { + let mut srv = test_server(move || { + HttpService::build() + .h1(|_| err::(error::ErrorBadRequest("error"))) + .rustls(ssl_acceptor()) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"error")); +} diff --git a/actix-http/tests/test_rustls_server.rs b/actix-http/tests/test_rustls_server.rs deleted file mode 100644 index b74fd07bf..000000000 --- a/actix-http/tests/test_rustls_server.rs +++ /dev/null @@ -1,462 +0,0 @@ -#![cfg(feature = "rust-tls")] -use actix_codec::{AsyncRead, AsyncWrite}; -use actix_http::error::PayloadError; -use actix_http::http::header::{self, HeaderName, HeaderValue}; -use actix_http::http::{Method, StatusCode, Version}; -use actix_http::{body, error, Error, HttpService, Request, Response}; -use actix_http_test::TestServer; -use actix_server::ssl::RustlsAcceptor; -use actix_server_config::ServerConfig; -use actix_service::{new_service_cfg, NewService}; - -use bytes::{Bytes, BytesMut}; -use futures::future::{self, ok, Future}; -use futures::stream::{once, Stream}; -use rustls::{ - internal::pemfile::{certs, pkcs8_private_keys}, - NoClientAuth, ServerConfig as RustlsServerConfig, -}; - -use std::fs::File; -use std::io::{BufReader, Result}; - -fn load_body(stream: S) -> impl Future -where - S: Stream, -{ - stream.fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, PayloadError>(body) - }) -} - -fn ssl_acceptor() -> Result> { - // load ssl keys - let mut config = RustlsServerConfig::new(NoClientAuth::new()); - let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap()); - let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap()); - let cert_chain = certs(cert_file).unwrap(); - let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - - let protos = vec![b"h2".to_vec()]; - config.set_protocols(&protos); - Ok(RustlsAcceptor::new(config)) -} - -#[test] -fn test_h2() -> Result<()> { - let rustls = ssl_acceptor()?; - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| future::ok::<_, Error>(Response::Ok().finish())) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - Ok(()) -} - -#[test] -fn test_h2_1() -> Result<()> { - let rustls = ssl_acceptor()?; - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .finish(|req: Request| { - assert!(req.peer_addr().is_some()); - assert_eq!(req.version(), Version::HTTP_2); - future::ok::<_, Error>(Response::Ok().finish()) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - Ok(()) -} - -#[test] -fn test_h2_body() -> Result<()> { - let data = "HELLOWORLD".to_owned().repeat(64 * 1024); - let rustls = ssl_acceptor()?; - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|mut req: Request<_>| { - load_body(req.take_payload()) - .and_then(|body| Ok(Response::Ok().body(body))) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send_body(data.clone())).unwrap(); - assert!(response.status().is_success()); - - let body = srv.load_body(response).unwrap(); - assert_eq!(&body, data.as_bytes()); - Ok(()) -} - -#[test] -fn test_h2_content_length() { - let rustls = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|req: Request| { - let indx: usize = req.uri().path()[1..].parse().unwrap(); - let statuses = [ - StatusCode::NO_CONTENT, - StatusCode::CONTINUE, - StatusCode::SWITCHING_PROTOCOLS, - StatusCode::PROCESSING, - StatusCode::OK, - StatusCode::NOT_FOUND, - ]; - future::ok::<_, ()>(Response::new(statuses[indx])) - }) - .map_err(|_| ()), - ) - }); - - let header = HeaderName::from_static("content-length"); - let value = HeaderValue::from_static("0"); - - { - for i in 0..4 { - let req = srv - .request(Method::GET, srv.surl(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), None); - - let req = srv - .request(Method::HEAD, srv.surl(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), None); - } - - for i in 4..6 { - let req = srv - .request(Method::GET, srv.surl(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), Some(&value)); - } - } -} - -#[test] -fn test_h2_headers() { - let data = STR.repeat(10); - let data2 = data.clone(); - let rustls = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - let data = data.clone(); - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build().h2(move |_| { - let mut config = Response::Ok(); - for idx in 0..90 { - config.header( - format!("X-TEST-{}", idx).as_str(), - "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", - ); - } - future::ok::<_, ()>(config.body(data.clone())) - }).map_err(|_| ())) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from(data2)); -} - -const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World"; - -#[test] -fn test_h2_body2() { - let rustls = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| future::ok::<_, ()>(Response::Ok().body(STR))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_h2_head_empty() { - let rustls = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.shead("/").send()).unwrap(); - assert!(response.status().is_success()); - assert_eq!(response.version(), Version::HTTP_2); - - { - let len = response - .headers() - .get(http::header::CONTENT_LENGTH) - .unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } - - // read response - let bytes = srv.load_body(response).unwrap(); - assert!(bytes.is_empty()); -} - -#[test] -fn test_h2_head_binary() { - let rustls = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| { - ok::<_, ()>( - Response::Ok().content_length(STR.len() as u64).body(STR), - ) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.shead("/").send()).unwrap(); - assert!(response.status().is_success()); - - { - let len = response - .headers() - .get(http::header::CONTENT_LENGTH) - .unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } - - // read response - let bytes = srv.load_body(response).unwrap(); - assert!(bytes.is_empty()); -} - -#[test] -fn test_h2_head_binary2() { - let rustls = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.shead("/").send()).unwrap(); - assert!(response.status().is_success()); - - { - let len = response - .headers() - .get(http::header::CONTENT_LENGTH) - .unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } -} - -#[test] -fn test_h2_body_length() { - let rustls = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| { - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() - .body(body::SizedStream::new(STR.len() as u64, body)), - ) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_h2_body_chunked_explicit() { - let rustls = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| { - let body = - once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() - .header(header::TRANSFER_ENCODING, "chunked") - .streaming(body), - ) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - assert!(!response.headers().contains_key(header::TRANSFER_ENCODING)); - - // read response - let bytes = srv.load_body(response).unwrap(); - - // decode - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_h2_response_http_error_handling() { - let rustls = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(new_service_cfg(|_: &ServerConfig| { - Ok::<_, ()>(|_| { - let broken_header = Bytes::from_static(b"\0\0\0"); - ok::<_, ()>( - Response::Ok() - .header(http::header::CONTENT_TYPE, broken_header) - .body(STR), - ) - }) - })) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); -} - -#[test] -fn test_h2_service_error() { - let rustls = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - rustls - .clone() - .map_err(|e| println!("Rustls error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| Err::(error::ErrorBadRequest("error"))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"error")); -} diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs index a31e4ac89..a84692f9d 100644 --- a/actix-http/tests/test_server.rs +++ b/actix-http/tests/test_server.rs @@ -2,23 +2,22 @@ use std::io::{Read, Write}; use std::time::Duration; use std::{net, thread}; -use actix_http_test::TestServer; -use actix_server_config::ServerConfig; -use actix_service::{new_service_cfg, service_fn, NewService}; +use actix_http_test::test_server; +use actix_rt::time::delay_for; +use actix_service::fn_service; use bytes::Bytes; -use futures::future::{self, ok, Future}; -use futures::stream::{once, Stream}; +use futures::future::{self, err, ok, ready, FutureExt}; +use futures::stream::{once, StreamExt}; use regex::Regex; -use tokio_timer::sleep; use actix_http::httpmessage::HttpMessage; use actix_http::{ body, error, http, http::header, Error, HttpService, KeepAlive, Request, Response, }; -#[test] -fn test_h1() { - let mut srv = TestServer::new(|| { +#[actix_rt::test] +async fn test_h1() { + let srv = test_server(|| { HttpService::build() .keep_alive(KeepAlive::Disabled) .client_timeout(1000) @@ -27,15 +26,16 @@ fn test_h1() { assert!(req.peer_addr().is_some()); future::ok::<_, ()>(Response::Ok().finish()) }) + .tcp() }); - let response = srv.block_on(srv.get("/").send()).unwrap(); + let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); } -#[test] -fn test_h1_2() { - let mut srv = TestServer::new(|| { +#[actix_rt::test] +async fn test_h1_2() { + let srv = test_server(|| { HttpService::build() .keep_alive(KeepAlive::Disabled) .client_timeout(1000) @@ -45,25 +45,26 @@ fn test_h1_2() { assert_eq!(req.version(), http::Version::HTTP_11); future::ok::<_, ()>(Response::Ok().finish()) }) - .map(|_| ()) + .tcp() }); - let response = srv.block_on(srv.get("/").send()).unwrap(); + let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); } -#[test] -fn test_expect_continue() { - let srv = TestServer::new(|| { +#[actix_rt::test] +async fn test_expect_continue() { + let srv = test_server(|| { HttpService::build() - .expect(service_fn(|req: Request| { + .expect(fn_service(|req: Request| { if req.head().uri.query() == Some("yes=") { - Ok(req) + ok(req) } else { - Err(error::ErrorPreconditionFailed("error")) + err(error::ErrorPreconditionFailed("error")) } })) .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -79,20 +80,21 @@ fn test_expect_continue() { assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); } -#[test] -fn test_expect_continue_h1() { - let srv = TestServer::new(|| { +#[actix_rt::test] +async fn test_expect_continue_h1() { + let srv = test_server(|| { HttpService::build() - .expect(service_fn(|req: Request| { - sleep(Duration::from_millis(20)).then(move |_| { + .expect(fn_service(|req: Request| { + delay_for(Duration::from_millis(20)).then(move |_| { if req.head().uri.query() == Some("yes=") { - Ok(req) + ok(req) } else { - Err(error::ErrorPreconditionFailed("error")) + err(error::ErrorPreconditionFailed("error")) } }) })) - .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .h1(fn_service(|_| future::ok::<_, ()>(Response::Ok().finish()))) + .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -108,19 +110,26 @@ fn test_expect_continue_h1() { assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); } -#[test] -fn test_chunked_payload() { +#[actix_rt::test] +async fn test_chunked_payload() { let chunk_sizes = vec![32768, 32, 32768]; let total_size: usize = chunk_sizes.iter().sum(); - let srv = TestServer::new(|| { - HttpService::build().h1(|mut request: Request| { - request - .take_payload() - .map_err(|e| panic!(format!("Error reading payload: {}", e))) - .fold(0usize, |acc, chunk| future::ok::<_, ()>(acc + chunk.len())) - .map(|req_size| Response::Ok().body(format!("size={}", req_size))) - }) + let srv = test_server(|| { + HttpService::build() + .h1(fn_service(|mut request: Request| { + request + .take_payload() + .map(|res| match res { + Ok(pl) => pl, + Err(e) => panic!(format!("Error reading payload: {}", e)), + }) + .fold(0usize, |acc, chunk| ready(acc + chunk.len())) + .map(|req_size| { + Ok::<_, Error>(Response::Ok().body(format!("size={}", req_size))) + }) + })) + .tcp() }); let returned_size = { @@ -156,12 +165,13 @@ fn test_chunked_payload() { assert_eq!(returned_size, total_size); } -#[test] -fn test_slow_request() { - let srv = TestServer::new(|| { +#[actix_rt::test] +async fn test_slow_request() { + let srv = test_server(|| { HttpService::build() .client_timeout(100) .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -171,10 +181,12 @@ fn test_slow_request() { assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); } -#[test] -fn test_http1_malformed_request() { - let srv = TestServer::new(|| { - HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) +#[actix_rt::test] +async fn test_http1_malformed_request() { + let srv = test_server(|| { + HttpService::build() + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -184,10 +196,12 @@ fn test_http1_malformed_request() { assert!(data.starts_with("HTTP/1.1 400 Bad Request")); } -#[test] -fn test_http1_keepalive() { - let srv = TestServer::new(|| { - HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) +#[actix_rt::test] +async fn test_http1_keepalive() { + let srv = test_server(|| { + HttpService::build() + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -202,12 +216,13 @@ fn test_http1_keepalive() { assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); } -#[test] -fn test_http1_keepalive_timeout() { - let srv = TestServer::new(|| { +#[actix_rt::test] +async fn test_http1_keepalive_timeout() { + let srv = test_server(|| { HttpService::build() .keep_alive(1) .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -222,10 +237,12 @@ fn test_http1_keepalive_timeout() { assert_eq!(res, 0); } -#[test] -fn test_http1_keepalive_close() { - let srv = TestServer::new(|| { - HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) +#[actix_rt::test] +async fn test_http1_keepalive_close() { + let srv = test_server(|| { + HttpService::build() + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -240,10 +257,12 @@ fn test_http1_keepalive_close() { assert_eq!(res, 0); } -#[test] -fn test_http10_keepalive_default_close() { - let srv = TestServer::new(|| { - HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) +#[actix_rt::test] +async fn test_http10_keepalive_default_close() { + let srv = test_server(|| { + HttpService::build() + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -257,10 +276,12 @@ fn test_http10_keepalive_default_close() { assert_eq!(res, 0); } -#[test] -fn test_http10_keepalive() { - let srv = TestServer::new(|| { - HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) +#[actix_rt::test] +async fn test_http10_keepalive() { + let srv = test_server(|| { + HttpService::build() + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -281,12 +302,13 @@ fn test_http10_keepalive() { assert_eq!(res, 0); } -#[test] -fn test_http1_keepalive_disabled() { - let srv = TestServer::new(|| { +#[actix_rt::test] +async fn test_http1_keepalive_disabled() { + let srv = test_server(|| { HttpService::build() .keep_alive(KeepAlive::Disabled) .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -300,26 +322,28 @@ fn test_http1_keepalive_disabled() { assert_eq!(res, 0); } -#[test] -fn test_content_length() { +#[actix_rt::test] +async fn test_content_length() { use actix_http::http::{ header::{HeaderName, HeaderValue}, StatusCode, }; - let mut srv = TestServer::new(|| { - HttpService::build().h1(|req: Request| { - let indx: usize = req.uri().path()[1..].parse().unwrap(); - let statuses = [ - StatusCode::NO_CONTENT, - StatusCode::CONTINUE, - StatusCode::SWITCHING_PROTOCOLS, - StatusCode::PROCESSING, - StatusCode::OK, - StatusCode::NOT_FOUND, - ]; - future::ok::<_, ()>(Response::new(statuses[indx])) - }) + let srv = test_server(|| { + HttpService::build() + .h1(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + future::ok::<_, ()>(Response::new(statuses[indx])) + }) + .tcp() }); let header = HeaderName::from_static("content-length"); @@ -327,35 +351,29 @@ fn test_content_length() { { for i in 0..4 { - let req = srv - .request(http::Method::GET, srv.url(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); + let req = srv.request(http::Method::GET, srv.url(&format!("/{}", i))); + let response = req.send().await.unwrap(); assert_eq!(response.headers().get(&header), None); - let req = srv - .request(http::Method::HEAD, srv.url(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); + let req = srv.request(http::Method::HEAD, srv.url(&format!("/{}", i))); + let response = req.send().await.unwrap(); assert_eq!(response.headers().get(&header), None); } for i in 4..6 { - let req = srv - .request(http::Method::GET, srv.url(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); + let req = srv.request(http::Method::GET, srv.url(&format!("/{}", i))); + let response = req.send().await.unwrap(); assert_eq!(response.headers().get(&header), Some(&value)); } } } -#[test] -fn test_h1_headers() { +#[actix_rt::test] +async fn test_h1_headers() { let data = STR.repeat(10); let data2 = data.clone(); - let mut srv = TestServer::new(move || { + let mut srv = test_server(move || { let data = data.clone(); HttpService::build().h1(move |_| { let mut builder = Response::Ok(); @@ -378,14 +396,14 @@ fn test_h1_headers() { ); } future::ok::<_, ()>(builder.body(data.clone())) - }) + }).tcp() }); - let response = srv.block_on(srv.get("/").send()).unwrap(); + let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.load_body(response).unwrap(); + let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from(data2)); } @@ -411,27 +429,31 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World"; -#[test] -fn test_h1_body() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().body(STR))) +#[actix_rt::test] +async fn test_h1_body() { + let mut srv = test_server(|| { + HttpService::build() + .h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + .tcp() }); - let response = srv.block_on(srv.get("/").send()).unwrap(); + let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.load_body(response).unwrap(); + let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -fn test_h1_head_empty() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) +#[actix_rt::test] +async fn test_h1_head_empty() { + let mut srv = test_server(|| { + HttpService::build() + .h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + .tcp() }); - let response = srv.block_on(srv.head("/").send()).unwrap(); + let response = srv.head("/").send().await.unwrap(); assert!(response.status().is_success()); { @@ -443,19 +465,21 @@ fn test_h1_head_empty() { } // read response - let bytes = srv.load_body(response).unwrap(); + let bytes = srv.load_body(response).await.unwrap(); assert!(bytes.is_empty()); } -#[test] -fn test_h1_head_binary() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| { - ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR)) - }) +#[actix_rt::test] +async fn test_h1_head_binary() { + let mut srv = test_server(|| { + HttpService::build() + .h1(|_| { + ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR)) + }) + .tcp() }); - let response = srv.block_on(srv.head("/").send()).unwrap(); + let response = srv.head("/").send().await.unwrap(); assert!(response.status().is_success()); { @@ -467,17 +491,19 @@ fn test_h1_head_binary() { } // read response - let bytes = srv.load_body(response).unwrap(); + let bytes = srv.load_body(response).await.unwrap(); assert!(bytes.is_empty()); } -#[test] -fn test_h1_head_binary2() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) +#[actix_rt::test] +async fn test_h1_head_binary2() { + let srv = test_server(|| { + HttpService::build() + .h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + .tcp() }); - let response = srv.block_on(srv.head("/").send()).unwrap(); + let response = srv.head("/").send().await.unwrap(); assert!(response.status().is_success()); { @@ -489,39 +515,43 @@ fn test_h1_head_binary2() { } } -#[test] -fn test_h1_body_length() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| { - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), - ) - }) +#[actix_rt::test] +async fn test_h1_body_length() { + let mut srv = test_server(|| { + HttpService::build() + .h1(|_| { + let body = once(ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), + ) + }) + .tcp() }); - let response = srv.block_on(srv.get("/").send()).unwrap(); + let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.load_body(response).unwrap(); + let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -fn test_h1_body_chunked_explicit() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| { - let body = once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() - .header(header::TRANSFER_ENCODING, "chunked") - .streaming(body), - ) - }) +#[actix_rt::test] +async fn test_h1_body_chunked_explicit() { + let mut srv = test_server(|| { + HttpService::build() + .h1(|_| { + let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) + .tcp() }); - let response = srv.block_on(srv.get("/").send()).unwrap(); + let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); assert_eq!( response @@ -534,22 +564,24 @@ fn test_h1_body_chunked_explicit() { ); // read response - let bytes = srv.load_body(response).unwrap(); + let bytes = srv.load_body(response).await.unwrap(); // decode assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -fn test_h1_body_chunked_implicit() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(|_| { - let body = once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>(Response::Ok().streaming(body)) - }) +#[actix_rt::test] +async fn test_h1_body_chunked_implicit() { + let mut srv = test_server(|| { + HttpService::build() + .h1(|_| { + let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>(Response::Ok().streaming(body)) + }) + .tcp() }); - let response = srv.block_on(srv.get("/").send()).unwrap(); + let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); assert_eq!( response @@ -562,59 +594,61 @@ fn test_h1_body_chunked_implicit() { ); // read response - let bytes = srv.load_body(response).unwrap(); + let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -fn test_h1_response_http_error_handling() { - let mut srv = TestServer::new(|| { - HttpService::build().h1(new_service_cfg(|_: &ServerConfig| { - Ok::<_, ()>(|_| { +#[actix_rt::test] +async fn test_h1_response_http_error_handling() { + let mut srv = test_server(|| { + HttpService::build() + .h1(fn_service(|_| { let broken_header = Bytes::from_static(b"\0\0\0"); ok::<_, ()>( Response::Ok() .header(http::header::CONTENT_TYPE, broken_header) .body(STR), ) - }) - })) + })) + .tcp() }); - let response = srv.block_on(srv.get("/").send()).unwrap(); + let response = srv.get("/").send().await.unwrap(); assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); // read response - let bytes = srv.load_body(response).unwrap(); + let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); } -#[test] -fn test_h1_service_error() { - let mut srv = TestServer::new(|| { +#[actix_rt::test] +async fn test_h1_service_error() { + let mut srv = test_server(|| { HttpService::build() - .h1(|_| Err::(error::ErrorBadRequest("error"))) + .h1(|_| future::err::(error::ErrorBadRequest("error"))) + .tcp() }); - let response = srv.block_on(srv.get("/").send()).unwrap(); + let response = srv.get("/").send().await.unwrap(); assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); // read response - let bytes = srv.load_body(response).unwrap(); + let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(b"error")); } -#[test] -fn test_h1_on_connect() { - let mut srv = TestServer::new(|| { +#[actix_rt::test] +async fn test_h1_on_connect() { + let srv = test_server(|| { HttpService::build() .on_connect(|_| 10usize) .h1(|req: Request| { assert!(req.extensions().contains::()); future::ok::<_, ()>(Response::Ok().finish()) }) + .tcp() }); - let response = srv.block_on(srv.get("/").send()).unwrap(); + let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); } diff --git a/actix-http/tests/test_ssl_server.rs b/actix-http/tests/test_ssl_server.rs deleted file mode 100644 index 897d92b37..000000000 --- a/actix-http/tests/test_ssl_server.rs +++ /dev/null @@ -1,480 +0,0 @@ -#![cfg(feature = "ssl")] -use actix_codec::{AsyncRead, AsyncWrite}; -use actix_http_test::TestServer; -use actix_server::ssl::OpensslAcceptor; -use actix_server_config::ServerConfig; -use actix_service::{new_service_cfg, NewService}; - -use bytes::{Bytes, BytesMut}; -use futures::future::{ok, Future}; -use futures::stream::{once, Stream}; -use openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod}; -use std::io::Result; - -use actix_http::error::{ErrorBadRequest, PayloadError}; -use actix_http::http::header::{self, HeaderName, HeaderValue}; -use actix_http::http::{Method, StatusCode, Version}; -use actix_http::httpmessage::HttpMessage; -use actix_http::{body, Error, HttpService, Request, Response}; - -fn load_body(stream: S) -> impl Future -where - S: Stream, -{ - stream.fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, PayloadError>(body) - }) -} - -fn ssl_acceptor() -> Result> { - // load ssl keys - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder - .set_private_key_file("../tests/key.pem", SslFiletype::PEM) - .unwrap(); - builder - .set_certificate_chain_file("../tests/cert.pem") - .unwrap(); - builder.set_alpn_select_callback(|_, protos| { - const H2: &[u8] = b"\x02h2"; - if protos.windows(3).any(|window| window == H2) { - Ok(b"h2") - } else { - Err(AlpnError::NOACK) - } - }); - builder.set_alpn_protos(b"\x02h2")?; - Ok(OpensslAcceptor::new(builder.build())) -} - -#[test] -fn test_h2() -> Result<()> { - let openssl = ssl_acceptor()?; - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| ok::<_, Error>(Response::Ok().finish())) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - Ok(()) -} - -#[test] -fn test_h2_1() -> Result<()> { - let openssl = ssl_acceptor()?; - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .finish(|req: Request| { - assert!(req.peer_addr().is_some()); - assert_eq!(req.version(), Version::HTTP_2); - ok::<_, Error>(Response::Ok().finish()) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - Ok(()) -} - -#[test] -fn test_h2_body() -> Result<()> { - let data = "HELLOWORLD".to_owned().repeat(64 * 1024); - let openssl = ssl_acceptor()?; - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|mut req: Request<_>| { - load_body(req.take_payload()) - .and_then(|body| Ok(Response::Ok().body(body))) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send_body(data.clone())).unwrap(); - assert!(response.status().is_success()); - - let body = srv.load_body(response).unwrap(); - assert_eq!(&body, data.as_bytes()); - Ok(()) -} - -#[test] -fn test_h2_content_length() { - let openssl = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|req: Request| { - let indx: usize = req.uri().path()[1..].parse().unwrap(); - let statuses = [ - StatusCode::NO_CONTENT, - StatusCode::CONTINUE, - StatusCode::SWITCHING_PROTOCOLS, - StatusCode::PROCESSING, - StatusCode::OK, - StatusCode::NOT_FOUND, - ]; - ok::<_, ()>(Response::new(statuses[indx])) - }) - .map_err(|_| ()), - ) - }); - - let header = HeaderName::from_static("content-length"); - let value = HeaderValue::from_static("0"); - - { - for i in 0..4 { - let req = srv - .request(Method::GET, srv.surl(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), None); - - let req = srv - .request(Method::HEAD, srv.surl(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), None); - } - - for i in 4..6 { - let req = srv - .request(Method::GET, srv.surl(&format!("/{}", i))) - .send(); - let response = srv.block_on(req).unwrap(); - assert_eq!(response.headers().get(&header), Some(&value)); - } - } -} - -#[test] -fn test_h2_headers() { - let data = STR.repeat(10); - let data2 = data.clone(); - let openssl = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - let data = data.clone(); - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build().h2(move |_| { - let mut builder = Response::Ok(); - for idx in 0..90 { - builder.header( - format!("X-TEST-{}", idx).as_str(), - "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", - ); - } - ok::<_, ()>(builder.body(data.clone())) - }).map_err(|_| ())) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from(data2)); -} - -const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World"; - -#[test] -fn test_h2_body2() { - let openssl = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_h2_head_empty() { - let openssl = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.shead("/").send()).unwrap(); - assert!(response.status().is_success()); - assert_eq!(response.version(), Version::HTTP_2); - - { - let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } - - // read response - let bytes = srv.load_body(response).unwrap(); - assert!(bytes.is_empty()); -} - -#[test] -fn test_h2_head_binary() { - let openssl = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| { - ok::<_, ()>( - Response::Ok().content_length(STR.len() as u64).body(STR), - ) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.shead("/").send()).unwrap(); - assert!(response.status().is_success()); - - { - let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } - - // read response - let bytes = srv.load_body(response).unwrap(); - assert!(bytes.is_empty()); -} - -#[test] -fn test_h2_head_binary2() { - let openssl = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.shead("/").send()).unwrap(); - assert!(response.status().is_success()); - - { - let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } -} - -#[test] -fn test_h2_body_length() { - let openssl = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| { - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() - .body(body::SizedStream::new(STR.len() as u64, body)), - ) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_h2_body_chunked_explicit() { - let openssl = ssl_acceptor().unwrap(); - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| { - let body = - once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() - .header(header::TRANSFER_ENCODING, "chunked") - .streaming(body), - ) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); - assert!(!response.headers().contains_key(header::TRANSFER_ENCODING)); - - // read response - let bytes = srv.load_body(response).unwrap(); - - // decode - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_h2_response_http_error_handling() { - let openssl = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(new_service_cfg(|_: &ServerConfig| { - Ok::<_, ()>(|_| { - let broken_header = Bytes::from_static(b"\0\0\0"); - ok::<_, ()>( - Response::Ok() - .header(header::CONTENT_TYPE, broken_header) - .body(STR), - ) - }) - })) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); -} - -#[test] -fn test_h2_service_error() { - let openssl = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .h2(|_| Err::(ErrorBadRequest("error"))) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); - - // read response - let bytes = srv.load_body(response).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"error")); -} - -#[test] -fn test_h2_on_connect() { - let openssl = ssl_acceptor().unwrap(); - - let mut srv = TestServer::new(move || { - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)) - .and_then( - HttpService::build() - .on_connect(|_| 10usize) - .h2(|req: Request| { - assert!(req.extensions().contains::()); - ok::<_, ()>(Response::Ok().finish()) - }) - .map_err(|_| ()), - ) - }); - - let response = srv.block_on(srv.sget("/").send()).unwrap(); - assert!(response.status().is_success()); -} diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs index 65a4d094d..7152fee48 100644 --- a/actix-http/tests/test_ws.rs +++ b/actix-http/tests/test_ws.rs @@ -1,76 +1,194 @@ +use std::cell::Cell; +use std::marker::PhantomData; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; + use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::{body, h1, ws, Error, HttpService, Request, Response}; -use actix_http_test::TestServer; -use actix_utils::framed::FramedTransport; -use bytes::{Bytes, BytesMut}; -use futures::future::{self, ok}; -use futures::{Future, Sink, Stream}; +use actix_http_test::test_server; +use actix_service::{fn_factory, Service}; +use actix_utils::framed::Dispatcher; +use bytes::Bytes; +use futures::future; +use futures::task::{Context, Poll}; +use futures::{Future, SinkExt, StreamExt}; -fn ws_service( - (req, framed): (Request, Framed), -) -> impl Future { - let res = ws::handshake(req.head()).unwrap().message_body(()); +struct WsService(Arc, Cell)>>); - framed - .send((res, body::BodySize::None).into()) - .map_err(|_| panic!()) - .and_then(|framed| { - FramedTransport::new(framed.into_framed(ws::Codec::new()), service) - .map_err(|_| panic!()) - }) +impl WsService { + fn new() -> Self { + WsService(Arc::new(Mutex::new((PhantomData, Cell::new(false))))) + } + + fn set_polled(&mut self) { + *self.0.lock().unwrap().1.get_mut() = true; + } + + fn was_polled(&self) -> bool { + self.0.lock().unwrap().1.get() + } } -fn service(msg: ws::Frame) -> impl Future { +impl Clone for WsService { + fn clone(&self) -> Self { + WsService(self.0.clone()) + } +} + +impl Service for WsService +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Request = (Request, Framed); + type Response = (); + type Error = Error; + type Future = Pin>>>; + + fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll> { + self.set_polled(); + Poll::Ready(Ok(())) + } + + fn call(&mut self, (req, mut framed): Self::Request) -> Self::Future { + let fut = async move { + let res = ws::handshake(req.head()).unwrap().message_body(()); + + framed + .send((res, body::BodySize::None).into()) + .await + .unwrap(); + + Dispatcher::new(framed.into_framed(ws::Codec::new()), service) + .await + .map_err(|_| panic!()) + }; + + Box::pin(fut) + } +} + +async fn service(msg: ws::Frame) -> Result { let msg = match msg { ws::Frame::Ping(msg) => ws::Message::Pong(msg), ws::Frame::Text(text) => { - ws::Message::Text(String::from_utf8_lossy(&text.unwrap()).to_string()) + ws::Message::Text(String::from_utf8_lossy(&text).to_string()) } - ws::Frame::Binary(bin) => ws::Message::Binary(bin.unwrap().freeze()), + ws::Frame::Binary(bin) => ws::Message::Binary(bin), + ws::Frame::Continuation(item) => ws::Message::Continuation(item), ws::Frame::Close(reason) => ws::Message::Close(reason), _ => panic!(), }; - ok(msg) + Ok(msg) } -#[test] -fn test_simple() { - let mut srv = TestServer::new(|| { - HttpService::build() - .upgrade(ws_service) - .finish(|_| future::ok::<_, ()>(Response::NotFound())) +#[actix_rt::test] +async fn test_simple() { + let ws_service = WsService::new(); + let mut srv = test_server({ + let ws_service = ws_service.clone(); + move || { + let ws_service = ws_service.clone(); + HttpService::build() + .upgrade(fn_factory(move || future::ok::<_, ()>(ws_service.clone()))) + .finish(|_| future::ok::<_, ()>(Response::NotFound())) + .tcp() + } }); // client service - let framed = srv.ws().unwrap(); - let framed = srv - .block_on(framed.send(ws::Message::Text("text".to_string()))) + let mut framed = srv.ws().await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); - - let framed = srv - .block_on(framed.send(ws::Message::Binary("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + let (item, mut framed) = framed.into_future().await; assert_eq!( - item, - Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) + item.unwrap().unwrap(), + ws::Frame::Text(Bytes::from_static(b"text")) ); - let framed = srv - .block_on(framed.send(ws::Message::Ping("text".into()))) + framed + .send(ws::Message::Binary("text".into())) + .await .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); - - let framed = srv - .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) - .unwrap(); - - let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + let (item, mut framed) = framed.into_future().await; assert_eq!( - item, - Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) + item.unwrap().unwrap(), + ws::Frame::Binary(Bytes::from_static(&b"text"[..])) ); + + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Pong("text".to_string().into()) + ); + + framed + .send(ws::Message::Continuation(ws::Item::FirstText( + "text".into(), + ))) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Continuation(ws::Item::FirstText(Bytes::from_static(b"text"))) + ); + + assert!(framed + .send(ws::Message::Continuation(ws::Item::FirstText( + "text".into() + ))) + .await + .is_err()); + assert!(framed + .send(ws::Message::Continuation(ws::Item::FirstBinary( + "text".into() + ))) + .await + .is_err()); + + framed + .send(ws::Message::Continuation(ws::Item::Continue("text".into()))) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Continuation(ws::Item::Continue(Bytes::from_static(b"text"))) + ); + + framed + .send(ws::Message::Continuation(ws::Item::Last("text".into()))) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Continuation(ws::Item::Last(Bytes::from_static(b"text"))) + ); + + assert!(framed + .send(ws::Message::Continuation(ws::Item::Continue("text".into()))) + .await + .is_err()); + + assert!(framed + .send(ws::Message::Continuation(ws::Item::Last("text".into()))) + .await + .is_err()); + + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); + + let (item, _framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Close(Some(ws::CloseCode::Normal.into())) + ); + + assert!(ws_service.was_polled()); } diff --git a/actix-identity/CHANGES.md b/actix-identity/CHANGES.md index 74a204055..594c21388 100644 --- a/actix-identity/CHANGES.md +++ b/actix-identity/CHANGES.md @@ -1,5 +1,13 @@ # Changes +## [0.2.1] - 2020-01-10 + +* Fix panic with already borrowed: BorrowMutError #1263 + +## [0.2.0] - 2019-12-20 + +* Use actix-web 2.0 + ## [0.1.0] - 2019-06-xx * Move identity middleware to separate crate diff --git a/actix-identity/Cargo.toml b/actix-identity/Cargo.toml index e645275a2..8cd6b1271 100644 --- a/actix-identity/Cargo.toml +++ b/actix-identity/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-identity" -version = "0.1.0" +version = "0.2.1" authors = ["Nikolay Kim "] description = "Identity service for actix web framework." readme = "README.md" @@ -10,21 +10,20 @@ repository = "https://github.com/actix/actix-web.git" documentation = "https://docs.rs/actix-identity/" license = "MIT/Apache-2.0" edition = "2018" -workspace = ".." [lib] name = "actix_identity" path = "src/lib.rs" [dependencies] -actix-web = { version = "1.0.0", default-features = false, features = ["secure-cookies"] } -actix-service = "0.4.0" -futures = "0.1.25" +actix-web = { version = "2.0.0", default-features = false, features = ["secure-cookies"] } +actix-service = "1.0.2" +futures = "0.3.1" serde = "1.0" serde_json = "1.0" time = "0.1.42" [dev-dependencies] -actix-rt = "0.2.2" -actix-http = "0.2.3" -bytes = "0.4" \ No newline at end of file +actix-rt = "1.0.0" +actix-http = "1.0.1" +bytes = "0.5.3" \ No newline at end of file diff --git a/actix-identity/src/lib.rs b/actix-identity/src/lib.rs index 7216104eb..3b9626991 100644 --- a/actix-identity/src/lib.rs +++ b/actix-identity/src/lib.rs @@ -16,7 +16,7 @@ //! use actix_web::*; //! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService}; //! -//! fn index(id: Identity) -> String { +//! async fn index(id: Identity) -> String { //! // access request identity //! if let Some(id) = id.identity() { //! format!("Welcome! {}", id) @@ -25,12 +25,12 @@ //! } //! } //! -//! fn login(id: Identity) -> HttpResponse { +//! async fn login(id: Identity) -> HttpResponse { //! id.remember("User1".to_owned()); // <- remember identity //! HttpResponse::Ok().finish() //! } //! -//! fn logout(id: Identity) -> HttpResponse { +//! async fn logout(id: Identity) -> HttpResponse { //! id.forget(); // <- remove identity //! HttpResponse::Ok().finish() //! } @@ -47,12 +47,13 @@ //! } //! ``` use std::cell::RefCell; +use std::future::Future; use std::rc::Rc; +use std::task::{Context, Poll}; use std::time::SystemTime; use actix_service::{Service, Transform}; -use futures::future::{ok, Either, FutureResult}; -use futures::{Future, IntoFuture, Poll}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use serde::{Deserialize, Serialize}; use time::Duration; @@ -165,21 +166,21 @@ where impl FromRequest for Identity { type Config = (); type Error = Error; - type Future = Result; + type Future = Ready>; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { - Ok(Identity(req.clone())) + ok(Identity(req.clone())) } } /// Identity policy definition. pub trait IdentityPolicy: Sized + 'static { /// The return type of the middleware - type Future: IntoFuture, Error = Error>; + type Future: Future, Error>>; /// The return type of the middleware - type ResponseFuture: IntoFuture; + type ResponseFuture: Future>; /// Parse the session from request and load data from a service identity. fn from_request(&self, request: &mut ServiceRequest) -> Self::Future; @@ -234,7 +235,7 @@ where type Error = Error; type InitError = (); type Transform = IdentityServiceMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(IdentityServiceMiddleware { @@ -250,6 +251,15 @@ pub struct IdentityServiceMiddleware { service: Rc>, } +impl Clone for IdentityServiceMiddleware { + fn clone(&self) -> Self { + Self { + backend: self.backend.clone(), + service: self.service.clone(), + } + } +} + impl Service for IdentityServiceMiddleware where B: 'static, @@ -261,46 +271,41 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.borrow_mut().poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.borrow_mut().poll_ready(cx) } fn call(&mut self, mut req: ServiceRequest) -> Self::Future { let srv = self.service.clone(); let backend = self.backend.clone(); + let fut = self.backend.from_request(&mut req); - Box::new( - self.backend.from_request(&mut req).into_future().then( - move |res| match res { - Ok(id) => { - req.extensions_mut() - .insert(IdentityItem { id, changed: false }); + async move { + match fut.await { + Ok(id) => { + req.extensions_mut() + .insert(IdentityItem { id, changed: false }); - Either::A(srv.borrow_mut().call(req).and_then(move |mut res| { - let id = - res.request().extensions_mut().remove::(); + // https://github.com/actix/actix-web/issues/1263 + let fut = { srv.borrow_mut().call(req) }; + let mut res = fut.await?; + let id = res.request().extensions_mut().remove::(); - if let Some(id) = id { - Either::A( - backend - .to_response(id.id, id.changed, &mut res) - .into_future() - .then(move |t| match t { - Ok(_) => Ok(res), - Err(e) => Ok(res.error_response(e)), - }), - ) - } else { - Either::B(ok(res)) - } - })) + if let Some(id) = id { + match backend.to_response(id.id, id.changed, &mut res).await { + Ok(_) => Ok(res), + Err(e) => Ok(res.error_response(e)), + } + } else { + Ok(res) } - Err(err) => Either::B(ok(req.error_response(err))), - }, - ), - ) + } + Err(err) => Ok(req.error_response(err)), + } + } + .boxed_local() } } @@ -547,11 +552,11 @@ impl CookieIdentityPolicy { } impl IdentityPolicy for CookieIdentityPolicy { - type Future = Result, Error>; - type ResponseFuture = Result<(), Error>; + type Future = Ready, Error>>; + type ResponseFuture = Ready>; fn from_request(&self, req: &mut ServiceRequest) -> Self::Future { - Ok(self.0.load(req).map( + ok(self.0.load(req).map( |CookieValue { identity, login_timestamp, @@ -603,7 +608,7 @@ impl IdentityPolicy for CookieIdentityPolicy { } else { Ok(()) }; - Ok(()) + ok(()) } } @@ -612,16 +617,17 @@ mod tests { use std::borrow::Borrow; use super::*; + use actix_service::into_service; use actix_web::http::StatusCode; use actix_web::test::{self, TestRequest}; - use actix_web::{web, App, Error, HttpResponse}; + use actix_web::{error, web, App, Error, HttpResponse}; const COOKIE_KEY_MASTER: [u8; 32] = [0; 32]; const COOKIE_NAME: &'static str = "actix_auth"; const COOKIE_LOGIN: &'static str = "test"; - #[test] - fn test_identity() { + #[actix_rt::test] + async fn test_identity() { let mut srv = test::init_service( App::new() .wrap(IdentityService::new( @@ -650,13 +656,16 @@ mod tests { HttpResponse::BadRequest() } })), - ); + ) + .await; let resp = - test::call_service(&mut srv, TestRequest::with_uri("/index").to_request()); + test::call_service(&mut srv, TestRequest::with_uri("/index").to_request()) + .await; assert_eq!(resp.status(), StatusCode::OK); let resp = - test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()); + test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) + .await; assert_eq!(resp.status(), StatusCode::OK); let c = resp.response().cookies().next().unwrap().to_owned(); @@ -665,7 +674,8 @@ mod tests { TestRequest::with_uri("/index") .cookie(c.clone()) .to_request(), - ); + ) + .await; assert_eq!(resp.status(), StatusCode::CREATED); let resp = test::call_service( @@ -673,13 +683,14 @@ mod tests { TestRequest::with_uri("/logout") .cookie(c.clone()) .to_request(), - ); + ) + .await; assert_eq!(resp.status(), StatusCode::OK); assert!(resp.headers().contains_key(header::SET_COOKIE)) } - #[test] - fn test_identity_max_age_time() { + #[actix_rt::test] + async fn test_identity_max_age_time() { let duration = Duration::days(1); let mut srv = test::init_service( App::new() @@ -695,17 +706,19 @@ mod tests { id.remember("test".to_string()); HttpResponse::Ok() })), - ); + ) + .await; let resp = - test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()); + test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) + .await; assert_eq!(resp.status(), StatusCode::OK); assert!(resp.headers().contains_key(header::SET_COOKIE)); let c = resp.response().cookies().next().unwrap().to_owned(); assert_eq!(duration, c.max_age().unwrap()); } - #[test] - fn test_identity_max_age() { + #[actix_rt::test] + async fn test_identity_max_age() { let seconds = 60; let mut srv = test::init_service( App::new() @@ -721,16 +734,18 @@ mod tests { id.remember("test".to_string()); HttpResponse::Ok() })), - ); + ) + .await; let resp = - test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()); + test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) + .await; assert_eq!(resp.status(), StatusCode::OK); assert!(resp.headers().contains_key(header::SET_COOKIE)); let c = resp.response().cookies().next().unwrap().to_owned(); assert_eq!(Duration::seconds(seconds as i64), c.max_age().unwrap()); } - fn create_identity_server< + async fn create_identity_server< F: Fn(CookieIdentityPolicy) -> CookieIdentityPolicy + Sync + Send + Clone + 'static, >( f: F, @@ -747,13 +762,16 @@ mod tests { .secure(false) .name(COOKIE_NAME)))) .service(web::resource("/").to(|id: Identity| { - let identity = id.identity(); - if identity.is_none() { - id.remember(COOKIE_LOGIN.to_string()) + async move { + let identity = id.identity(); + if identity.is_none() { + id.remember(COOKIE_LOGIN.to_string()) + } + web::Json(identity) } - web::Json(identity) })), ) + .await } fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> { @@ -786,15 +804,8 @@ mod tests { jar.get(COOKIE_NAME).unwrap().clone() } - fn assert_logged_in(response: &mut ServiceResponse, identity: Option<&str>) { - use bytes::BytesMut; - use futures::Stream; - let bytes = - test::block_on(response.take_body().fold(BytesMut::new(), |mut b, c| { - b.extend(c); - Ok::<_, Error>(b) - })) - .unwrap(); + async fn assert_logged_in(response: ServiceResponse, identity: Option<&str>) { + let bytes = test::read_body(response).await; let resp: Option = serde_json::from_slice(&bytes[..]).unwrap(); assert_eq!(resp.as_ref().map(|s| s.borrow()), identity); } @@ -872,108 +883,118 @@ mod tests { assert!(cookies.get(COOKIE_NAME).is_none()); } - #[test] - fn test_identity_legacy_cookie_is_set() { - let mut srv = create_identity_server(|c| c); + #[actix_rt::test] + async fn test_identity_legacy_cookie_is_set() { + let mut srv = create_identity_server(|c| c).await; let mut resp = - test::call_service(&mut srv, TestRequest::with_uri("/").to_request()); - assert_logged_in(&mut resp, None); + test::call_service(&mut srv, TestRequest::with_uri("/").to_request()).await; assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN); + assert_logged_in(resp, None).await; } - #[test] - fn test_identity_legacy_cookie_works() { - let mut srv = create_identity_server(|c| c); + #[actix_rt::test] + async fn test_identity_legacy_cookie_works() { + let mut srv = create_identity_server(|c| c).await; let cookie = legacy_login_cookie(COOKIE_LOGIN); let mut resp = test::call_service( &mut srv, TestRequest::with_uri("/") .cookie(cookie.clone()) .to_request(), - ); - assert_logged_in(&mut resp, Some(COOKIE_LOGIN)); + ) + .await; assert_no_login_cookie(&mut resp); + assert_logged_in(resp, Some(COOKIE_LOGIN)).await; } - #[test] - fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() { - let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90))); + #[actix_rt::test] + async fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() { + let mut srv = + create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; let cookie = legacy_login_cookie(COOKIE_LOGIN); let mut resp = test::call_service( &mut srv, TestRequest::with_uri("/") .cookie(cookie.clone()) .to_request(), - ); - assert_logged_in(&mut resp, None); + ) + .await; assert_login_cookie( &mut resp, COOKIE_LOGIN, LoginTimestampCheck::NoTimestamp, VisitTimeStampCheck::NewTimestamp, ); + assert_logged_in(resp, None).await; } - #[test] - fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() { - let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90))); + #[actix_rt::test] + async fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; let cookie = legacy_login_cookie(COOKIE_LOGIN); let mut resp = test::call_service( &mut srv, TestRequest::with_uri("/") .cookie(cookie.clone()) .to_request(), - ); - assert_logged_in(&mut resp, None); + ) + .await; assert_login_cookie( &mut resp, COOKIE_LOGIN, LoginTimestampCheck::NewTimestamp, VisitTimeStampCheck::NoTimestamp, ); + assert_logged_in(resp, None).await; } - #[test] - fn test_identity_cookie_rejected_if_login_timestamp_needed() { - let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90))); + #[actix_rt::test] + async fn test_identity_cookie_rejected_if_login_timestamp_needed() { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now())); let mut resp = test::call_service( &mut srv, TestRequest::with_uri("/") .cookie(cookie.clone()) .to_request(), - ); - assert_logged_in(&mut resp, None); + ) + .await; assert_login_cookie( &mut resp, COOKIE_LOGIN, LoginTimestampCheck::NewTimestamp, VisitTimeStampCheck::NoTimestamp, ); + assert_logged_in(resp, None).await; } - #[test] - fn test_identity_cookie_rejected_if_visit_timestamp_needed() { - let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90))); + #[actix_rt::test] + async fn test_identity_cookie_rejected_if_visit_timestamp_needed() { + let mut srv = + create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); let mut resp = test::call_service( &mut srv, TestRequest::with_uri("/") .cookie(cookie.clone()) .to_request(), - ); - assert_logged_in(&mut resp, None); + ) + .await; assert_login_cookie( &mut resp, COOKIE_LOGIN, LoginTimestampCheck::NoTimestamp, VisitTimeStampCheck::NewTimestamp, ); + assert_logged_in(resp, None).await; } - #[test] - fn test_identity_cookie_rejected_if_login_timestamp_too_old() { - let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90))); + #[actix_rt::test] + async fn test_identity_cookie_rejected_if_login_timestamp_too_old() { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; let cookie = login_cookie( COOKIE_LOGIN, Some(SystemTime::now() - Duration::days(180).to_std().unwrap()), @@ -984,19 +1005,21 @@ mod tests { TestRequest::with_uri("/") .cookie(cookie.clone()) .to_request(), - ); - assert_logged_in(&mut resp, None); + ) + .await; assert_login_cookie( &mut resp, COOKIE_LOGIN, LoginTimestampCheck::NewTimestamp, VisitTimeStampCheck::NoTimestamp, ); + assert_logged_in(resp, None).await; } - #[test] - fn test_identity_cookie_rejected_if_visit_timestamp_too_old() { - let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90))); + #[actix_rt::test] + async fn test_identity_cookie_rejected_if_visit_timestamp_too_old() { + let mut srv = + create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; let cookie = login_cookie( COOKIE_LOGIN, None, @@ -1007,36 +1030,41 @@ mod tests { TestRequest::with_uri("/") .cookie(cookie.clone()) .to_request(), - ); - assert_logged_in(&mut resp, None); + ) + .await; assert_login_cookie( &mut resp, COOKIE_LOGIN, LoginTimestampCheck::NoTimestamp, VisitTimeStampCheck::NewTimestamp, ); + assert_logged_in(resp, None).await; } - #[test] - fn test_identity_cookie_not_updated_on_login_deadline() { - let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90))); + #[actix_rt::test] + async fn test_identity_cookie_not_updated_on_login_deadline() { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); let mut resp = test::call_service( &mut srv, TestRequest::with_uri("/") .cookie(cookie.clone()) .to_request(), - ); - assert_logged_in(&mut resp, Some(COOKIE_LOGIN)); + ) + .await; assert_no_login_cookie(&mut resp); + assert_logged_in(resp, Some(COOKIE_LOGIN)).await; } - #[test] - fn test_identity_cookie_updated_on_visit_deadline() { + // https://github.com/actix/actix-web/issues/1263 + #[actix_rt::test] + async fn test_identity_cookie_updated_on_visit_deadline() { let mut srv = create_identity_server(|c| { c.visit_deadline(Duration::days(90)) .login_deadline(Duration::days(90)) - }); + }) + .await; let timestamp = SystemTime::now() - Duration::days(1).to_std().unwrap(); let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp)); let mut resp = test::call_service( @@ -1044,13 +1072,57 @@ mod tests { TestRequest::with_uri("/") .cookie(cookie.clone()) .to_request(), - ); - assert_logged_in(&mut resp, Some(COOKIE_LOGIN)); + ) + .await; assert_login_cookie( &mut resp, COOKIE_LOGIN, LoginTimestampCheck::OldTimestamp(timestamp), VisitTimeStampCheck::NewTimestamp, ); + assert_logged_in(resp, Some(COOKIE_LOGIN)).await; + } + + #[actix_rt::test] + async fn test_borrowed_mut_error() { + use futures::future::{lazy, ok, Ready}; + + struct Ident; + impl IdentityPolicy for Ident { + type Future = Ready, Error>>; + type ResponseFuture = Ready>; + + fn from_request(&self, _: &mut ServiceRequest) -> Self::Future { + ok(Some("test".to_string())) + } + + fn to_response( + &self, + _: Option, + _: bool, + _: &mut ServiceResponse, + ) -> Self::ResponseFuture { + ok(()) + } + } + + let mut srv = IdentityServiceMiddleware { + backend: Rc::new(Ident), + service: Rc::new(RefCell::new(into_service(|_: ServiceRequest| { + async move { + actix_rt::time::delay_for(std::time::Duration::from_secs(100)).await; + Err::(error::ErrorBadRequest("error")) + } + }))), + }; + + let mut srv2 = srv.clone(); + let req = TestRequest::default().to_srv_request(); + actix_rt::spawn(async move { + let _ = srv2.call(req).await; + }); + actix_rt::time::delay_for(std::time::Duration::from_millis(50)).await; + + let _ = lazy(|cx| srv.poll_ready(cx)).await; } } diff --git a/actix-multipart/CHANGES.md b/actix-multipart/CHANGES.md index 365dca286..31f326d05 100644 --- a/actix-multipart/CHANGES.md +++ b/actix-multipart/CHANGES.md @@ -1,5 +1,18 @@ # Changes -## [0.1.4] - 2019-xx-xx + +## [0.2.0] - 2019-12-20 + +* Release + +## [0.2.0-alpha.4] - 2019-12-xx + +* Multipart handling now handles Pending during read of boundary #1205 + +## [0.2.0-alpha.2] - 2019-12-03 + +* Migrate to `std::future` + +## [0.1.4] - 2019-09-12 * Multipart handling now parses requests which do not end in CRLF #1038 diff --git a/actix-multipart/Cargo.toml b/actix-multipart/Cargo.toml index b26681e25..6c683cb1a 100644 --- a/actix-multipart/Cargo.toml +++ b/actix-multipart/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-multipart" -version = "0.1.3" +version = "0.2.0" authors = ["Nikolay Kim "] description = "Multipart support for actix web framework." readme = "README.md" @@ -9,8 +9,6 @@ homepage = "https://actix.rs" repository = "https://github.com/actix/actix-web.git" documentation = "https://docs.rs/actix-multipart/" license = "MIT/Apache-2.0" -exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] -workspace = ".." edition = "2018" [lib] @@ -18,17 +16,18 @@ name = "actix_multipart" path = "src/lib.rs" [dependencies] -actix-web = { version = "1.0.0", default-features = false } -actix-service = "0.4.1" -bytes = "0.4" -derive_more = "0.15.0" +actix-web = { version = "2.0.0-rc", default-features = false } +actix-service = "1.0.1" +actix-utils = "1.0.3" +bytes = "0.5.3" +derive_more = "0.99.2" httparse = "1.3" -futures = "0.1.25" +futures = "0.3.1" log = "0.4" mime = "0.3" time = "0.1" twoway = "0.2" [dev-dependencies] -actix-rt = "0.2.2" -actix-http = "0.2.4" \ No newline at end of file +actix-rt = "1.0.0" +actix-http = "1.0.0" \ No newline at end of file diff --git a/actix-multipart/README.md b/actix-multipart/README.md index ac0d05640..a453f489e 100644 --- a/actix-multipart/README.md +++ b/actix-multipart/README.md @@ -5,4 +5,4 @@ * [API Documentation](https://docs.rs/actix-multipart/) * [Chat on gitter](https://gitter.im/actix/actix) * Cargo package: [actix-multipart](https://crates.io/crates/actix-multipart) -* Minimum supported Rust version: 1.33 or later +* Minimum supported Rust version: 1.39 or later diff --git a/actix-multipart/src/error.rs b/actix-multipart/src/error.rs index 32c740a1a..6677f69c7 100644 --- a/actix-multipart/src/error.rs +++ b/actix-multipart/src/error.rs @@ -1,7 +1,7 @@ //! Error and Result module use actix_web::error::{ParseError, PayloadError}; use actix_web::http::StatusCode; -use actix_web::{HttpResponse, ResponseError}; +use actix_web::ResponseError; use derive_more::{Display, From}; /// A set of errors that can occur during parsing multipart streams @@ -35,14 +35,15 @@ pub enum MultipartError { /// Return `BadRequest` for `MultipartError` impl ResponseError for MultipartError { - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::BAD_REQUEST) + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } #[cfg(test)] mod tests { use super::*; + use actix_web::HttpResponse; #[test] fn test_multipart_error() { diff --git a/actix-multipart/src/extractor.rs b/actix-multipart/src/extractor.rs index 7274ed092..71c815227 100644 --- a/actix-multipart/src/extractor.rs +++ b/actix-multipart/src/extractor.rs @@ -1,5 +1,6 @@ //! Multipart payload support use actix_web::{dev::Payload, Error, FromRequest, HttpRequest}; +use futures::future::{ok, Ready}; use crate::server::Multipart; @@ -10,33 +11,31 @@ use crate::server::Multipart; /// ## Server example /// /// ```rust -/// # use futures::{Future, Stream}; -/// # use futures::future::{ok, result, Either}; +/// use futures::{Stream, StreamExt}; /// use actix_web::{web, HttpResponse, Error}; /// use actix_multipart as mp; /// -/// fn index(payload: mp::Multipart) -> impl Future { -/// payload.from_err() // <- get multipart stream for current request -/// .and_then(|field| { // <- iterate over multipart items +/// async fn index(mut payload: mp::Multipart) -> Result { +/// // iterate over multipart stream +/// while let Some(item) = payload.next().await { +/// let mut field = item?; +/// /// // Field in turn is stream of *Bytes* object -/// field.from_err() -/// .fold((), |_, chunk| { -/// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk)); -/// Ok::<_, Error>(()) -/// }) -/// }) -/// .fold((), |_, _| Ok::<_, Error>(())) -/// .map(|_| HttpResponse::Ok().into()) +/// while let Some(chunk) = field.next().await { +/// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk?)); +/// } +/// } +/// Ok(HttpResponse::Ok().into()) /// } /// # fn main() {} /// ``` impl FromRequest for Multipart { type Error = Error; - type Future = Result; + type Future = Ready>; type Config = (); #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - Ok(Multipart::new(req.headers(), payload.take())) + ok(Multipart::new(req.headers(), payload.take())) } } diff --git a/actix-multipart/src/server.rs b/actix-multipart/src/server.rs index 3312a580a..2555cb7a3 100644 --- a/actix-multipart/src/server.rs +++ b/actix-multipart/src/server.rs @@ -1,20 +1,22 @@ //! Multipart payload support use std::cell::{Cell, RefCell, RefMut}; +use std::convert::TryFrom; use std::marker::PhantomData; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use std::{cmp, fmt}; use bytes::{Bytes, BytesMut}; -use futures::task::{current as current_task, Task}; -use futures::{Async, Poll, Stream}; +use futures::stream::{LocalBoxStream, Stream, StreamExt}; use httparse; use mime; +use actix_utils::task::LocalWaker; use actix_web::error::{ParseError, PayloadError}; use actix_web::http::header::{ self, ContentDisposition, HeaderMap, HeaderName, HeaderValue, }; -use actix_web::http::HttpTryFrom; use crate::error::MultipartError; @@ -60,7 +62,7 @@ impl Multipart { /// Create multipart instance for boundary. pub fn new(headers: &HeaderMap, stream: S) -> Multipart where - S: Stream + 'static, + S: Stream> + Unpin + 'static, { match Self::boundary(headers) { Ok(boundary) => Multipart { @@ -104,22 +106,25 @@ impl Multipart { } impl Stream for Multipart { - type Item = Field; - type Error = MultipartError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { if let Some(err) = self.error.take() { - Err(err) + Poll::Ready(Some(Err(err))) } else if self.safety.current() { - let mut inner = self.inner.as_mut().unwrap().borrow_mut(); - if let Some(mut payload) = inner.payload.get_mut(&self.safety) { - payload.poll_stream()?; + let this = self.get_mut(); + let mut inner = this.inner.as_mut().unwrap().borrow_mut(); + if let Some(mut payload) = inner.payload.get_mut(&this.safety) { + payload.poll_stream(cx)?; } - inner.poll(&self.safety) + inner.poll(&this.safety, cx) } else if !self.safety.is_clean() { - Err(MultipartError::NotConsumed) + Poll::Ready(Some(Err(MultipartError::NotConsumed))) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -178,13 +183,15 @@ impl InnerMultipart { Some(chunk) => { if chunk.len() < boundary.len() + 4 || &chunk[..2] != b"--" - || &chunk[2..boundary.len() + 2] != boundary.as_bytes() { + || &chunk[2..boundary.len() + 2] != boundary.as_bytes() + { Err(MultipartError::Boundary) } else if &chunk[boundary.len() + 2..] == b"\r\n" { Ok(Some(false)) } else if &chunk[boundary.len() + 2..boundary.len() + 4] == b"--" && (chunk.len() == boundary.len() + 4 - || &chunk[boundary.len() + 4..] == b"\r\n") { + || &chunk[boundary.len() + 4..] == b"\r\n") + { Ok(Some(true)) } else { Err(MultipartError::Boundary) @@ -236,9 +243,13 @@ impl InnerMultipart { Ok(Some(eof)) } - fn poll(&mut self, safety: &Safety) -> Poll, MultipartError> { + fn poll( + &mut self, + safety: &Safety, + cx: &mut Context, + ) -> Poll>> { if self.state == InnerState::Eof { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { // release field loop { @@ -247,10 +258,13 @@ impl InnerMultipart { if safety.current() { let stop = match self.item { InnerMultipartItem::Field(ref mut field) => { - match field.borrow_mut().poll(safety)? { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(Some(_)) => continue, - Async::Ready(None) => true, + match field.borrow_mut().poll(safety) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok(_))) => continue, + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(e))) + } + Poll::Ready(None) => true, } } InnerMultipartItem::None => false, @@ -275,12 +289,12 @@ impl InnerMultipart { Some(eof) => { if eof { self.state = InnerState::Eof; - return Ok(Async::Ready(None)); + return Poll::Ready(None); } else { self.state = InnerState::Headers; } } - None => return Ok(Async::NotReady), + None => return Poll::Pending, } } // read boundary @@ -289,11 +303,11 @@ impl InnerMultipart { &mut *payload, &self.boundary, )? { - None => return Ok(Async::NotReady), + None => return Poll::Pending, Some(eof) => { if eof { self.state = InnerState::Eof; - return Ok(Async::Ready(None)); + return Poll::Ready(None); } else { self.state = InnerState::Headers; } @@ -309,14 +323,14 @@ impl InnerMultipart { self.state = InnerState::Boundary; headers } else { - return Ok(Async::NotReady); + return Poll::Pending; } } else { unreachable!() } } else { log::debug!("NotReady: field is in flight"); - return Ok(Async::NotReady); + return Poll::Pending; }; // content type @@ -333,7 +347,7 @@ impl InnerMultipart { // nested multipart stream if mt.type_() == mime::MULTIPART { - Err(MultipartError::Nested) + Poll::Ready(Some(Err(MultipartError::Nested))) } else { let field = Rc::new(RefCell::new(InnerField::new( self.payload.clone(), @@ -342,12 +356,7 @@ impl InnerMultipart { )?)); self.item = InnerMultipartItem::Field(Rc::clone(&field)); - Ok(Async::Ready(Some(Field::new( - safety.clone(), - headers, - mt, - field, - )))) + Poll::Ready(Some(Ok(Field::new(safety.clone(cx), headers, mt, field)))) } } } @@ -407,23 +416,21 @@ impl Field { } impl Stream for Field { - type Item = Bytes; - type Error = MultipartError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if self.safety.current() { let mut inner = self.inner.borrow_mut(); if let Some(mut payload) = inner.payload.as_ref().unwrap().get_mut(&self.safety) { - payload.poll_stream()?; + payload.poll_stream(cx)?; } - inner.poll(&self.safety) } else if !self.safety.is_clean() { - Err(MultipartError::NotConsumed) + Poll::Ready(Some(Err(MultipartError::NotConsumed))) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -480,9 +487,9 @@ impl InnerField { fn read_len( payload: &mut PayloadBuffer, size: &mut u64, - ) -> Poll, MultipartError> { + ) -> Poll>> { if *size == 0 { - Ok(Async::Ready(None)) + Poll::Ready(None) } else { match payload.read_max(*size)? { Some(mut chunk) => { @@ -492,13 +499,13 @@ impl InnerField { if !chunk.is_empty() { payload.unprocessed(chunk); } - Ok(Async::Ready(Some(ch))) + Poll::Ready(Some(Ok(ch))) } None => { if payload.eof && (*size != 0) { - Err(MultipartError::Incomplete) + Poll::Ready(Some(Err(MultipartError::Incomplete))) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -510,15 +517,15 @@ impl InnerField { fn read_stream( payload: &mut PayloadBuffer, boundary: &str, - ) -> Poll, MultipartError> { + ) -> Poll>> { let mut pos = 0; let len = payload.buf.len(); if len == 0 { return if payload.eof { - Err(MultipartError::Incomplete) + Poll::Ready(Some(Err(MultipartError::Incomplete))) } else { - Ok(Async::NotReady) + Poll::Pending }; } @@ -535,10 +542,10 @@ impl InnerField { if let Some(b_len) = b_len { let b_size = boundary.len() + b_len; if len < b_size { - return Ok(Async::NotReady); + return Poll::Pending; } else if &payload.buf[b_len..b_size] == boundary.as_bytes() { // found boundary - return Ok(Async::Ready(None)); + return Poll::Ready(None); } } } @@ -550,9 +557,9 @@ impl InnerField { // check if we have enough data for boundary detection if cur + 4 > len { if cur > 0 { - Ok(Async::Ready(Some(payload.buf.split_to(cur).freeze()))) + Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze()))) } else { - Ok(Async::NotReady) + Poll::Pending } } else { // check boundary @@ -563,7 +570,7 @@ impl InnerField { { if cur != 0 { // return buffer - Ok(Async::Ready(Some(payload.buf.split_to(cur).freeze()))) + Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze()))) } else { pos = cur + 1; continue; @@ -575,49 +582,51 @@ impl InnerField { } } } else { - Ok(Async::Ready(Some(payload.buf.take().freeze()))) + Poll::Ready(Some(Ok(payload.buf.split().freeze()))) }; } } - fn poll(&mut self, s: &Safety) -> Poll, MultipartError> { + fn poll(&mut self, s: &Safety) -> Poll>> { if self.payload.is_none() { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s) { if !self.eof { let res = if let Some(ref mut len) = self.length { - InnerField::read_len(&mut *payload, len)? + InnerField::read_len(&mut *payload, len) } else { - InnerField::read_stream(&mut *payload, &self.boundary)? + InnerField::read_stream(&mut *payload, &self.boundary) }; match res { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(Some(bytes)) => return Ok(Async::Ready(Some(bytes))), - Async::Ready(None) => self.eof = true, + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))), + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => self.eof = true, } } - match payload.readline()? { - None => Async::Ready(None), - Some(line) => { + match payload.readline() { + Ok(None) => Poll::Pending, + Ok(Some(line)) => { if line.as_ref() != b"\r\n" { log::warn!("multipart field did not read all the data or it is malformed"); } - Async::Ready(None) + Poll::Ready(None) } + Err(e) => Poll::Ready(Some(Err(e))), } } else { - Async::NotReady + Poll::Pending }; - if Async::Ready(None) == result { + if let Poll::Ready(None) = result { self.payload.take(); } - Ok(result) + result } } @@ -657,7 +666,7 @@ impl Clone for PayloadRef { /// most task. #[derive(Debug)] struct Safety { - task: Option, + task: LocalWaker, level: usize, payload: Rc>, clean: Rc>, @@ -667,7 +676,7 @@ impl Safety { fn new() -> Safety { let payload = Rc::new(PhantomData); Safety { - task: None, + task: LocalWaker::new(), level: Rc::strong_count(&payload), clean: Rc::new(Cell::new(true)), payload, @@ -681,17 +690,17 @@ impl Safety { fn is_clean(&self) -> bool { self.clean.get() } -} -impl Clone for Safety { - fn clone(&self) -> Safety { + fn clone(&self, cx: &mut Context) -> Safety { let payload = Rc::clone(&self.payload); - Safety { - task: Some(current_task()), + let s = Safety { + task: LocalWaker::new(), level: Rc::strong_count(&payload), clean: self.clean.clone(), payload, - } + }; + s.task.register(cx.waker()); + s } } @@ -702,7 +711,7 @@ impl Drop for Safety { self.clean.set(true); } if let Some(task) = self.task.take() { - task.notify() + task.wake() } } } @@ -711,31 +720,32 @@ impl Drop for Safety { struct PayloadBuffer { eof: bool, buf: BytesMut, - stream: Box>, + stream: LocalBoxStream<'static, Result>, } impl PayloadBuffer { /// Create new `PayloadBuffer` instance fn new(stream: S) -> Self where - S: Stream + 'static, + S: Stream> + 'static, { PayloadBuffer { eof: false, buf: BytesMut::new(), - stream: Box::new(stream), + stream: stream.boxed_local(), } } - fn poll_stream(&mut self) -> Result<(), PayloadError> { + fn poll_stream(&mut self, cx: &mut Context) -> Result<(), PayloadError> { loop { - match self.stream.poll()? { - Async::Ready(Some(data)) => self.buf.extend_from_slice(&data), - Async::Ready(None) => { + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data), + Poll::Ready(Some(Err(e))) => return Err(e), + Poll::Ready(None) => { self.eof = true; return Ok(()); } - Async::NotReady => return Ok(()), + Poll::Pending => return Ok(()), } } } @@ -781,14 +791,16 @@ impl PayloadBuffer { /// Read bytes until new line delimiter or eof pub fn readline_or_eof(&mut self) -> Result, MultipartError> { match self.readline() { - Err(MultipartError::Incomplete) if self.eof => Ok(Some(self.buf.take().freeze())), - line => line + Err(MultipartError::Incomplete) if self.eof => { + Ok(Some(self.buf.split().freeze())) + } + line => line, } } /// Put unprocessed data back to the buffer pub fn unprocessed(&mut self, data: Bytes) { - let buf = BytesMut::from(data); + let buf = BytesMut::from(data.as_ref()); let buf = std::mem::replace(&mut self.buf, buf); self.buf.extend_from_slice(&buf); } @@ -796,16 +808,16 @@ impl PayloadBuffer { #[cfg(test)] mod tests { - use actix_http::h1::Payload; - use bytes::Bytes; - use futures::unsync::mpsc; - use super::*; - use actix_web::http::header::{DispositionParam, DispositionType}; - use actix_web::test::run_on; - #[test] - fn test_boundary() { + use actix_http::h1::Payload; + use actix_utils::mpsc; + use actix_web::http::header::{DispositionParam, DispositionType}; + use bytes::Bytes; + use futures::future::lazy; + + #[actix_rt::test] + async fn test_boundary() { let headers = HeaderMap::new(); match Multipart::boundary(&headers) { Err(MultipartError::NoContentType) => (), @@ -848,25 +860,64 @@ mod tests { } fn create_stream() -> ( - mpsc::UnboundedSender>, - impl Stream, + mpsc::Sender>, + impl Stream>, ) { - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::channel(); - (tx, rx.map_err(|_| panic!()).and_then(|res| res)) + (tx, rx.map(|res| res.map_err(|_| panic!()))) + } + // Stream that returns from a Bytes, one char at a time and Pending every other poll() + struct SlowStream { + bytes: Bytes, + pos: usize, + ready: bool, + } + + impl SlowStream { + fn new(bytes: Bytes) -> SlowStream { + return SlowStream { + bytes: bytes, + pos: 0, + ready: false, + }; + } + } + + impl Stream for SlowStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + let this = self.get_mut(); + if !this.ready { + this.ready = true; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + if this.pos == this.bytes.len() { + return Poll::Ready(None); + } + let res = Poll::Ready(Some(Ok(this.bytes.slice(this.pos..(this.pos + 1))))); + this.pos += 1; + this.ready = false; + res + } } fn create_simple_request_with_header() -> (Bytes, HeaderMap) { let bytes = Bytes::from( "testasdadsad\r\n\ - --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ - Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\ - Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ - test\r\n\ - --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ - Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ - data\r\n\ - --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n" + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\ + Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ + test\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ + data\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n", ); let mut headers = HeaderMap::new(); headers.insert( @@ -878,246 +929,224 @@ mod tests { (bytes, headers) } - #[test] - fn test_multipart_no_end_crlf() { - run_on(|| { - let (sender, payload) = create_stream(); - let (bytes, headers) = create_simple_request_with_header(); - let bytes_stripped = bytes.slice_to(bytes.len()); // strip crlf + #[actix_rt::test] + async fn test_multipart_no_end_crlf() { + let (sender, payload) = create_stream(); + let (mut bytes, headers) = create_simple_request_with_header(); + let bytes_stripped = bytes.split_to(bytes.len()); // strip crlf - sender.unbounded_send(Ok(bytes_stripped)).unwrap(); - drop(sender); // eof + sender.send(Ok(bytes_stripped)).unwrap(); + drop(sender); // eof - let mut multipart = Multipart::new(&headers, payload); + let mut multipart = Multipart::new(&headers, payload); - match multipart.poll().unwrap() { - Async::Ready(Some(_)) => (), - _ => unreachable!(), - } + match multipart.next().await.unwrap() { + Ok(_) => (), + _ => unreachable!(), + } - match multipart.poll().unwrap() { - Async::Ready(Some(_)) => (), - _ => unreachable!(), - } + match multipart.next().await.unwrap() { + Ok(_) => (), + _ => unreachable!(), + } - match multipart.poll().unwrap() { - Async::Ready(None) => (), - _ => unreachable!(), - } - }) + match multipart.next().await { + None => (), + _ => unreachable!(), + } } - #[test] - fn test_multipart() { - run_on(|| { - let (sender, payload) = create_stream(); - let (bytes, headers) = create_simple_request_with_header(); + #[actix_rt::test] + async fn test_multipart() { + let (sender, payload) = create_stream(); + let (bytes, headers) = create_simple_request_with_header(); - sender.unbounded_send(Ok(bytes)).unwrap(); + sender.send(Ok(bytes)).unwrap(); - let mut multipart = Multipart::new(&headers, payload); - match multipart.poll().unwrap() { - Async::Ready(Some(mut field)) => { - let cd = field.content_disposition().unwrap(); - assert_eq!(cd.disposition, DispositionType::FormData); - assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); + let mut multipart = Multipart::new(&headers, payload); + match multipart.next().await { + Some(Ok(mut field)) => { + let cd = field.content_disposition().unwrap(); + assert_eq!(cd.disposition, DispositionType::FormData); + assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); - assert_eq!(field.content_type().type_(), mime::TEXT); - assert_eq!(field.content_type().subtype(), mime::PLAIN); + assert_eq!(field.content_type().type_(), mime::TEXT); + assert_eq!(field.content_type().subtype(), mime::PLAIN); - match field.poll().unwrap() { - Async::Ready(Some(chunk)) => assert_eq!(chunk, "test"), - _ => unreachable!(), - } - match field.poll().unwrap() { - Async::Ready(None) => (), - _ => unreachable!(), - } + match field.next().await.unwrap() { + Ok(chunk) => assert_eq!(chunk, "test"), + _ => unreachable!(), } - _ => unreachable!(), - } - - match multipart.poll().unwrap() { - Async::Ready(Some(mut field)) => { - assert_eq!(field.content_type().type_(), mime::TEXT); - assert_eq!(field.content_type().subtype(), mime::PLAIN); - - match field.poll() { - Ok(Async::Ready(Some(chunk))) => assert_eq!(chunk, "data"), - _ => unreachable!(), - } - match field.poll() { - Ok(Async::Ready(None)) => (), - _ => unreachable!(), - } + match field.next().await { + None => (), + _ => unreachable!(), } - _ => unreachable!(), } + _ => unreachable!(), + } - match multipart.poll().unwrap() { - Async::Ready(None) => (), - _ => unreachable!(), - } - }); - } + match multipart.next().await.unwrap() { + Ok(mut field) => { + assert_eq!(field.content_type().type_(), mime::TEXT); + assert_eq!(field.content_type().subtype(), mime::PLAIN); - #[test] - fn test_stream() { - run_on(|| { - let (sender, payload) = create_stream(); - let (bytes, headers) = create_simple_request_with_header(); - - sender.unbounded_send(Ok(bytes)).unwrap(); - - let mut multipart = Multipart::new(&headers, payload); - match multipart.poll().unwrap() { - Async::Ready(Some(mut field)) => { - let cd = field.content_disposition().unwrap(); - assert_eq!(cd.disposition, DispositionType::FormData); - assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); - - assert_eq!(field.content_type().type_(), mime::TEXT); - assert_eq!(field.content_type().subtype(), mime::PLAIN); - - match field.poll().unwrap() { - Async::Ready(Some(chunk)) => assert_eq!(chunk, "test"), - _ => unreachable!(), - } - match field.poll().unwrap() { - Async::Ready(None) => (), - _ => unreachable!(), - } + match field.next().await { + Some(Ok(chunk)) => assert_eq!(chunk, "data"), + _ => unreachable!(), } - _ => unreachable!(), - } - - match multipart.poll().unwrap() { - Async::Ready(Some(mut field)) => { - assert_eq!(field.content_type().type_(), mime::TEXT); - assert_eq!(field.content_type().subtype(), mime::PLAIN); - - match field.poll() { - Ok(Async::Ready(Some(chunk))) => assert_eq!(chunk, "data"), - _ => unreachable!(), - } - match field.poll() { - Ok(Async::Ready(None)) => (), - _ => unreachable!(), - } + match field.next().await { + None => (), + _ => unreachable!(), } + } + _ => unreachable!(), + } + + match multipart.next().await { + None => (), + _ => unreachable!(), + } + } + + // Loops, collecting all bytes until end-of-field + async fn get_whole_field(field: &mut Field) -> BytesMut { + let mut b = BytesMut::new(); + loop { + match field.next().await { + Some(Ok(chunk)) => b.extend_from_slice(&chunk), + None => return b, _ => unreachable!(), } + } + } - match multipart.poll().unwrap() { - Async::Ready(None) => (), - _ => unreachable!(), + #[actix_rt::test] + async fn test_stream() { + let (bytes, headers) = create_simple_request_with_header(); + let payload = SlowStream::new(bytes); + + let mut multipart = Multipart::new(&headers, payload); + match multipart.next().await.unwrap() { + Ok(mut field) => { + let cd = field.content_disposition().unwrap(); + assert_eq!(cd.disposition, DispositionType::FormData); + assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); + + assert_eq!(field.content_type().type_(), mime::TEXT); + assert_eq!(field.content_type().subtype(), mime::PLAIN); + + assert_eq!(get_whole_field(&mut field).await, "test"); } - }); + _ => unreachable!(), + } + + match multipart.next().await { + Some(Ok(mut field)) => { + assert_eq!(field.content_type().type_(), mime::TEXT); + assert_eq!(field.content_type().subtype(), mime::PLAIN); + + assert_eq!(get_whole_field(&mut field).await, "data"); + } + _ => unreachable!(), + } + + match multipart.next().await { + None => (), + _ => unreachable!(), + } } - #[test] - fn test_basic() { - run_on(|| { - let (_, payload) = Payload::create(false); - let mut payload = PayloadBuffer::new(payload); + #[actix_rt::test] + async fn test_basic() { + let (_, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); - assert_eq!(payload.buf.len(), 0); - payload.poll_stream().unwrap(); - assert_eq!(None, payload.read_max(1).unwrap()); - }) + assert_eq!(payload.buf.len(), 0); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); + assert_eq!(None, payload.read_max(1).unwrap()); } - #[test] - fn test_eof() { - run_on(|| { - let (mut sender, payload) = Payload::create(false); - let mut payload = PayloadBuffer::new(payload); + #[actix_rt::test] + async fn test_eof() { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); - assert_eq!(None, payload.read_max(4).unwrap()); - sender.feed_data(Bytes::from("data")); - sender.feed_eof(); - payload.poll_stream().unwrap(); + assert_eq!(None, payload.read_max(4).unwrap()); + sender.feed_data(Bytes::from("data")); + sender.feed_eof(); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); - assert_eq!(Some(Bytes::from("data")), payload.read_max(4).unwrap()); - assert_eq!(payload.buf.len(), 0); - assert!(payload.read_max(1).is_err()); - assert!(payload.eof); - }) + assert_eq!(Some(Bytes::from("data")), payload.read_max(4).unwrap()); + assert_eq!(payload.buf.len(), 0); + assert!(payload.read_max(1).is_err()); + assert!(payload.eof); } - #[test] - fn test_err() { - run_on(|| { - let (mut sender, payload) = Payload::create(false); - let mut payload = PayloadBuffer::new(payload); - assert_eq!(None, payload.read_max(1).unwrap()); - sender.set_error(PayloadError::Incomplete(None)); - payload.poll_stream().err().unwrap(); - }) + #[actix_rt::test] + async fn test_err() { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + assert_eq!(None, payload.read_max(1).unwrap()); + sender.set_error(PayloadError::Incomplete(None)); + lazy(|cx| payload.poll_stream(cx)).await.err().unwrap(); } - #[test] - fn test_readmax() { - run_on(|| { - let (mut sender, payload) = Payload::create(false); - let mut payload = PayloadBuffer::new(payload); + #[actix_rt::test] + async fn test_readmax() { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); - sender.feed_data(Bytes::from("line1")); - sender.feed_data(Bytes::from("line2")); - payload.poll_stream().unwrap(); - assert_eq!(payload.buf.len(), 10); + sender.feed_data(Bytes::from("line1")); + sender.feed_data(Bytes::from("line2")); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); + assert_eq!(payload.buf.len(), 10); - assert_eq!(Some(Bytes::from("line1")), payload.read_max(5).unwrap()); - assert_eq!(payload.buf.len(), 5); + assert_eq!(Some(Bytes::from("line1")), payload.read_max(5).unwrap()); + assert_eq!(payload.buf.len(), 5); - assert_eq!(Some(Bytes::from("line2")), payload.read_max(5).unwrap()); - assert_eq!(payload.buf.len(), 0); - }) + assert_eq!(Some(Bytes::from("line2")), payload.read_max(5).unwrap()); + assert_eq!(payload.buf.len(), 0); } - #[test] - fn test_readexactly() { - run_on(|| { - let (mut sender, payload) = Payload::create(false); - let mut payload = PayloadBuffer::new(payload); + #[actix_rt::test] + async fn test_readexactly() { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); - assert_eq!(None, payload.read_exact(2)); + assert_eq!(None, payload.read_exact(2)); - sender.feed_data(Bytes::from("line1")); - sender.feed_data(Bytes::from("line2")); - payload.poll_stream().unwrap(); + sender.feed_data(Bytes::from("line1")); + sender.feed_data(Bytes::from("line2")); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); - assert_eq!(Some(Bytes::from_static(b"li")), payload.read_exact(2)); - assert_eq!(payload.buf.len(), 8); + assert_eq!(Some(Bytes::from_static(b"li")), payload.read_exact(2)); + assert_eq!(payload.buf.len(), 8); - assert_eq!(Some(Bytes::from_static(b"ne1l")), payload.read_exact(4)); - assert_eq!(payload.buf.len(), 4); - }) + assert_eq!(Some(Bytes::from_static(b"ne1l")), payload.read_exact(4)); + assert_eq!(payload.buf.len(), 4); } - #[test] - fn test_readuntil() { - run_on(|| { - let (mut sender, payload) = Payload::create(false); - let mut payload = PayloadBuffer::new(payload); + #[actix_rt::test] + async fn test_readuntil() { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); - assert_eq!(None, payload.read_until(b"ne").unwrap()); + assert_eq!(None, payload.read_until(b"ne").unwrap()); - sender.feed_data(Bytes::from("line1")); - sender.feed_data(Bytes::from("line2")); - payload.poll_stream().unwrap(); + sender.feed_data(Bytes::from("line1")); + sender.feed_data(Bytes::from("line2")); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); - assert_eq!( - Some(Bytes::from("line")), - payload.read_until(b"ne").unwrap() - ); - assert_eq!(payload.buf.len(), 6); + assert_eq!( + Some(Bytes::from("line")), + payload.read_until(b"ne").unwrap() + ); + assert_eq!(payload.buf.len(), 6); - assert_eq!( - Some(Bytes::from("1line2")), - payload.read_until(b"2").unwrap() - ); - assert_eq!(payload.buf.len(), 0); - }) + assert_eq!( + Some(Bytes::from("1line2")), + payload.read_until(b"2").unwrap() + ); + assert_eq!(payload.buf.len(), 0); } } diff --git a/actix-session/CHANGES.md b/actix-session/CHANGES.md index d85f6d5f1..e4306fa9d 100644 --- a/actix-session/CHANGES.md +++ b/actix-session/CHANGES.md @@ -1,11 +1,27 @@ # Changes +## [0.3.0] - 2019-12-20 + +* Release + +## [0.3.0-alpha.4] - 2019-12-xx + +* Allow access to sessions also from not mutable references to the request + +## [0.3.0-alpha.3] - 2019-12-xx + +* Add access to the session from RequestHead for use of session from guard methods + +* Migrate to `std::future` + +* Migrate to `actix-web` 2.0 + ## [0.2.0] - 2019-07-08 -* Enhanced ``actix-session`` to facilitate state changes. Use ``Session.renew()`` - at successful login to cycle a session (new key/cookie but keeps state). - Use ``Session.purge()`` at logout to invalid a session cookie (and remove - from redis cache, if applicable). +* Enhanced ``actix-session`` to facilitate state changes. Use ``Session.renew()`` + at successful login to cycle a session (new key/cookie but keeps state). + Use ``Session.purge()`` at logout to invalid a session cookie (and remove + from redis cache, if applicable). ## [0.1.1] - 2019-06-03 diff --git a/actix-session/Cargo.toml b/actix-session/Cargo.toml index d973661ef..5989cc0d6 100644 --- a/actix-session/Cargo.toml +++ b/actix-session/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-session" -version = "0.2.0" +version = "0.3.0" authors = ["Nikolay Kim "] description = "Session for actix web framework." readme = "README.md" @@ -9,8 +9,6 @@ homepage = "https://actix.rs" repository = "https://github.com/actix/actix-web.git" documentation = "https://docs.rs/actix-session/" license = "MIT/Apache-2.0" -exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] -workspace = ".." edition = "2018" [lib] @@ -24,15 +22,14 @@ default = ["cookie-session"] cookie-session = ["actix-web/secure-cookies"] [dependencies] -actix-web = "1.0.0" -actix-service = "0.4.1" -bytes = "0.4" -derive_more = "0.15.0" -futures = "0.1.25" -hashbrown = "0.5.0" +actix-web = "2.0.0-rc" +actix-service = "1.0.1" +bytes = "0.5.3" +derive_more = "0.99.2" +futures = "0.3.1" serde = "1.0" serde_json = "1.0" time = "0.1.42" [dev-dependencies] -actix-rt = "0.2.2" +actix-rt = "1.0.0" diff --git a/actix-session/src/cookie.rs b/actix-session/src/cookie.rs index 192737780..75eef0c01 100644 --- a/actix-session/src/cookie.rs +++ b/actix-session/src/cookie.rs @@ -17,6 +17,7 @@ use std::collections::HashMap; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_service::{Service, Transform}; use actix_web::cookie::{Cookie, CookieJar, Key, SameSite}; @@ -24,8 +25,7 @@ use actix_web::dev::{ServiceRequest, ServiceResponse}; use actix_web::http::{header::SET_COOKIE, HeaderValue}; use actix_web::{Error, HttpMessage, ResponseError}; use derive_more::{Display, From}; -use futures::future::{ok, Future, FutureResult}; -use futures::Poll; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use serde_json::error::Error as JsonError; use crate::{Session, SessionStatus}; @@ -284,7 +284,7 @@ where type Error = S::Error; type InitError = (); type Transform = CookieSessionMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(CookieSessionMiddleware { @@ -309,10 +309,10 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = S::Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } /// On first request, a new session cookie is returned in response, regardless @@ -325,29 +325,36 @@ where let (is_new, state) = self.inner.load(&req); Session::set_session(state.into_iter(), &mut req); - Box::new(self.service.call(req).map(move |mut res| { - match Session::get_changes(&mut res) { - (SessionStatus::Changed, Some(state)) - | (SessionStatus::Renewed, Some(state)) => { - res.checked_expr(|res| inner.set_cookie(res, state)) - } - (SessionStatus::Unchanged, _) => - // set a new session cookie upon first request (new client) - { - if is_new { - let state: HashMap = HashMap::new(); - res.checked_expr(|res| inner.set_cookie(res, state.into_iter())) - } else { + let fut = self.service.call(req); + + async move { + fut.await.map(|mut res| { + match Session::get_changes(&mut res) { + (SessionStatus::Changed, Some(state)) + | (SessionStatus::Renewed, Some(state)) => { + res.checked_expr(|res| inner.set_cookie(res, state)) + } + (SessionStatus::Unchanged, _) => + // set a new session cookie upon first request (new client) + { + if is_new { + let state: HashMap = HashMap::new(); + res.checked_expr(|res| { + inner.set_cookie(res, state.into_iter()) + }) + } else { + res + } + } + (SessionStatus::Purged, _) => { + let _ = inner.remove_cookie(&mut res); res } + _ => res, } - (SessionStatus::Purged, _) => { - let _ = inner.remove_cookie(&mut res); - res - } - _ => res, - } - })) + }) + } + .boxed_local() } } @@ -357,19 +364,22 @@ mod tests { use actix_web::{test, web, App}; use bytes::Bytes; - #[test] - fn cookie_session() { + #[actix_rt::test] + async fn cookie_session() { let mut app = test::init_service( App::new() .wrap(CookieSession::signed(&[0; 32]).secure(false)) .service(web::resource("/").to(|ses: Session| { - let _ = ses.set("counter", 100); - "test" + async move { + let _ = ses.set("counter", 100); + "test" + } })), - ); + ) + .await; let request = test::TestRequest::get().to_request(); - let response = test::block_on(app.call(request)).unwrap(); + let response = app.call(request).await.unwrap(); assert!(response .response() .cookies() @@ -377,19 +387,22 @@ mod tests { .is_some()); } - #[test] - fn private_cookie() { + #[actix_rt::test] + async fn private_cookie() { let mut app = test::init_service( App::new() .wrap(CookieSession::private(&[0; 32]).secure(false)) .service(web::resource("/").to(|ses: Session| { - let _ = ses.set("counter", 100); - "test" + async move { + let _ = ses.set("counter", 100); + "test" + } })), - ); + ) + .await; let request = test::TestRequest::get().to_request(); - let response = test::block_on(app.call(request)).unwrap(); + let response = app.call(request).await.unwrap(); assert!(response .response() .cookies() @@ -397,19 +410,22 @@ mod tests { .is_some()); } - #[test] - fn cookie_session_extractor() { + #[actix_rt::test] + async fn cookie_session_extractor() { let mut app = test::init_service( App::new() .wrap(CookieSession::signed(&[0; 32]).secure(false)) .service(web::resource("/").to(|ses: Session| { - let _ = ses.set("counter", 100); - "test" + async move { + let _ = ses.set("counter", 100); + "test" + } })), - ); + ) + .await; let request = test::TestRequest::get().to_request(); - let response = test::block_on(app.call(request)).unwrap(); + let response = app.call(request).await.unwrap(); assert!(response .response() .cookies() @@ -417,8 +433,8 @@ mod tests { .is_some()); } - #[test] - fn basics() { + #[actix_rt::test] + async fn basics() { let mut app = test::init_service( App::new() .wrap( @@ -431,17 +447,22 @@ mod tests { .max_age(100), ) .service(web::resource("/").to(|ses: Session| { - let _ = ses.set("counter", 100); - "test" + async move { + let _ = ses.set("counter", 100); + "test" + } })) .service(web::resource("/test/").to(|ses: Session| { - let val: usize = ses.get("counter").unwrap().unwrap(); - format!("counter: {}", val) + async move { + let val: usize = ses.get("counter").unwrap().unwrap(); + format!("counter: {}", val) + } })), - ); + ) + .await; let request = test::TestRequest::get().to_request(); - let response = test::block_on(app.call(request)).unwrap(); + let response = app.call(request).await.unwrap(); let cookie = response .response() .cookies() @@ -453,7 +474,7 @@ mod tests { let request = test::TestRequest::with_uri("/test/") .cookie(cookie) .to_request(); - let body = test::read_response(&mut app, request); + let body = test::read_response(&mut app, request).await; assert_eq!(body, Bytes::from_static(b"counter: 100")); } } diff --git a/actix-session/src/lib.rs b/actix-session/src/lib.rs index 2e9e51714..b6e5dd331 100644 --- a/actix-session/src/lib.rs +++ b/actix-session/src/lib.rs @@ -12,7 +12,7 @@ //! [*Session*](struct.Session.html) extractor must be used. Session //! extractor allows us to get or set session data. //! -//! ```rust +//! ```rust,no_run //! use actix_web::{web, App, HttpServer, HttpResponse, Error}; //! use actix_session::{Session, CookieSession}; //! @@ -28,8 +28,8 @@ //! Ok("Welcome!") //! } //! -//! fn main() -> std::io::Result<()> { -//! # std::thread::spawn(|| +//! #[actix_rt::main] +//! async fn main() -> std::io::Result<()> { //! HttpServer::new( //! || App::new().wrap( //! CookieSession::signed(&[0; 32]) // <- create cookie based session middleware @@ -38,16 +38,18 @@ //! .service(web::resource("/").to(|| HttpResponse::Ok()))) //! .bind("127.0.0.1:59880")? //! .run() -//! # ); -//! # Ok(()) +//! .await //! } //! ``` use std::cell::RefCell; +use std::collections::HashMap; use std::rc::Rc; -use actix_web::dev::{Extensions, Payload, ServiceRequest, ServiceResponse}; +use actix_web::dev::{ + Extensions, Payload, RequestHead, ServiceRequest, ServiceResponse, +}; use actix_web::{Error, FromRequest, HttpMessage, HttpRequest}; -use hashbrown::HashMap; +use futures::future::{ok, Ready}; use serde::de::DeserializeOwned; use serde::Serialize; use serde_json; @@ -83,17 +85,23 @@ pub struct Session(Rc>); /// Helper trait that allows to get session pub trait UserSession { - fn get_session(&mut self) -> Session; + fn get_session(&self) -> Session; } impl UserSession for HttpRequest { - fn get_session(&mut self) -> Session { + fn get_session(&self) -> Session { Session::get_session(&mut *self.extensions_mut()) } } impl UserSession for ServiceRequest { - fn get_session(&mut self) -> Session { + fn get_session(&self) -> Session { + Session::get_session(&mut *self.extensions_mut()) + } +} + +impl UserSession for RequestHead { + fn get_session(&self) -> Session { Session::get_session(&mut *self.extensions_mut()) } } @@ -230,12 +238,12 @@ impl Session { /// ``` impl FromRequest for Session { type Error = Error; - type Future = Result; + type Future = Ready>; type Config = (); #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { - Ok(Session::get_session(&mut *req.extensions_mut())) + ok(Session::get_session(&mut *req.extensions_mut())) } } @@ -280,6 +288,20 @@ mod tests { assert_eq!(res, Some("value".to_string())); } + #[test] + fn get_session_from_request_head() { + let mut req = test::TestRequest::default().to_srv_request(); + + Session::set_session( + vec![("key".to_string(), "\"value\"".to_string())].into_iter(), + &mut req, + ); + + let session = req.head_mut().get_session(); + let res = session.get::("key").unwrap(); + assert_eq!(res, Some("value".to_string())); + } + #[test] fn purge_session() { let req = test::TestRequest::default().to_srv_request(); diff --git a/actix-web-actors/CHANGES.md b/actix-web-actors/CHANGES.md index 0d1df7e55..66ff7ed6c 100644 --- a/actix-web-actors/CHANGES.md +++ b/actix-web-actors/CHANGES.md @@ -1,5 +1,21 @@ # Changes +## [2.0.0] - 2019-12-20 + +* Release + +## [2.0.0-alpha.1] - 2019-12-15 + +* Migrate to actix-web 2.0.0 + +## [1.0.4] - 2019-12-07 + +* Allow comma-separated websocket subprotocols without spaces (#1172) + +## [1.0.3] - 2019-11-14 + +* Update actix-web and actix-http dependencies + ## [1.0.2] - 2019-07-20 * Add `ws::start_with_addr()`, returning the address of the created actor, along diff --git a/actix-web-actors/Cargo.toml b/actix-web-actors/Cargo.toml index 356109da5..6f573e442 100644 --- a/actix-web-actors/Cargo.toml +++ b/actix-web-actors/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web-actors" -version = "1.0.2" +version = "2.0.0" authors = ["Nikolay Kim "] description = "Actix actors support for actix web framework." readme = "README.md" @@ -9,8 +9,6 @@ homepage = "https://actix.rs" repository = "https://github.com/actix/actix-web.git" documentation = "https://docs.rs/actix-web-actors/" license = "MIT/Apache-2.0" -exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] -workspace = ".." edition = "2018" [lib] @@ -18,13 +16,14 @@ name = "actix_web_actors" path = "src/lib.rs" [dependencies] -actix = "0.8.3" -actix-web = "1.0.3" -actix-http = "0.2.5" -actix-codec = "0.1.2" -bytes = "0.4" -futures = "0.1.25" +actix = "0.9.0" +actix-web = "2.0.0-rc" +actix-http = "1.0.1" +actix-codec = "0.2.0" +bytes = "0.5.2" +futures = "0.3.1" +pin-project = "0.4.6" [dev-dependencies] +actix-rt = "1.0.0" env_logger = "0.6" -actix-http-test = { version = "0.2.4", features=["ssl"] } diff --git a/actix-web-actors/src/context.rs b/actix-web-actors/src/context.rs index 31b29500a..6a403de12 100644 --- a/actix-web-actors/src/context.rs +++ b/actix-web-actors/src/context.rs @@ -1,4 +1,6 @@ use std::collections::VecDeque; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix::dev::{ AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, ToEnvelope, @@ -7,10 +9,10 @@ use actix::fut::ActorFuture; use actix::{ Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message, SpawnHandle, }; -use actix_web::error::{Error, ErrorInternalServerError}; +use actix_web::error::Error; use bytes::Bytes; -use futures::sync::oneshot::Sender; -use futures::{Async, Future, Poll, Stream}; +use futures::channel::oneshot::Sender; +use futures::{Future, Stream}; /// Execution context for http actors pub struct HttpContext @@ -43,7 +45,7 @@ where #[inline] fn spawn(&mut self, fut: F) -> SpawnHandle where - F: ActorFuture + 'static, + F: ActorFuture + 'static, { self.inner.spawn(fut) } @@ -51,7 +53,7 @@ where #[inline] fn wait(&mut self, fut: F) where - F: ActorFuture + 'static, + F: ActorFuture + 'static, { self.inner.wait(fut) } @@ -81,7 +83,7 @@ where { #[inline] /// Create a new HTTP Context from a request and an actor - pub fn create(actor: A) -> impl Stream { + pub fn create(actor: A) -> impl Stream> { let mb = Mailbox::default(); let ctx = HttpContext { inner: ContextParts::new(mb.sender_producer()), @@ -91,7 +93,7 @@ where } /// Create a new HTTP Context - pub fn with_factory(f: F) -> impl Stream + pub fn with_factory(f: F) -> impl Stream> where F: FnOnce(&mut Self) -> A + 'static, { @@ -160,24 +162,23 @@ impl Stream for HttpContextFut where A: Actor>, { - type Item = Bytes; - type Error = Error; + type Item = Result; - fn poll(&mut self) -> Poll, Error> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { if self.fut.alive() { - match self.fut.poll() { - Ok(Async::NotReady) | Ok(Async::Ready(())) => (), - Err(_) => return Err(ErrorInternalServerError("error")), - } + let _ = Pin::new(&mut self.fut).poll(cx); } // frames if let Some(data) = self.fut.ctx().stream.pop_front() { - Ok(Async::Ready(data)) + Poll::Ready(data.map(|b| Ok(b))) } else if self.fut.alive() { - Ok(Async::NotReady) + Poll::Pending } else { - Ok(Async::Ready(None)) + Poll::Ready(None) } } } @@ -199,9 +200,9 @@ mod tests { use actix::Actor; use actix_web::http::StatusCode; - use actix_web::test::{block_on, call_service, init_service, TestRequest}; + use actix_web::test::{call_service, init_service, read_body, TestRequest}; use actix_web::{web, App, HttpResponse}; - use bytes::{Bytes, BytesMut}; + use bytes::Bytes; use super::*; @@ -223,31 +224,25 @@ mod tests { if self.count > 3 { ctx.write_eof() } else { - ctx.write(Bytes::from(format!("LINE-{}", self.count).as_bytes())); + ctx.write(Bytes::from(format!("LINE-{}", self.count))); ctx.run_later(Duration::from_millis(100), |slf, ctx| slf.write(ctx)); } } } - #[test] - fn test_default_resource() { + #[actix_rt::test] + async fn test_default_resource() { let mut srv = init_service(App::new().service(web::resource("/test").to(|| { HttpResponse::Ok().streaming(HttpContext::create(MyActor { count: 0 })) - }))); + }))) + .await; let req = TestRequest::with_uri("/test").to_request(); - let mut resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); - let body = block_on(resp.take_body().fold( - BytesMut::new(), - move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, Error>(body) - }, - )) - .unwrap(); - assert_eq!(body.freeze(), Bytes::from_static(b"LINE-1LINE-2LINE-3")); + let body = read_body(resp).await; + assert_eq!(body, Bytes::from_static(b"LINE-1LINE-2LINE-3")); } } diff --git a/actix-web-actors/src/ws.rs b/actix-web-actors/src/ws.rs index e25a7e6e4..b28aeade4 100644 --- a/actix-web-actors/src/ws.rs +++ b/actix-web-actors/src/ws.rs @@ -1,6 +1,8 @@ //! Websocket integration use std::collections::VecDeque; use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix::dev::{ AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, StreamHandler, @@ -16,20 +18,20 @@ use actix_http::ws::{hash_key, Codec}; pub use actix_http::ws::{ CloseCode, CloseReason, Frame, HandshakeError, Message, ProtocolError, }; - use actix_web::dev::HttpResponseBuilder; -use actix_web::error::{Error, ErrorInternalServerError, PayloadError}; +use actix_web::error::{Error, PayloadError}; use actix_web::http::{header, Method, StatusCode}; use actix_web::{HttpRequest, HttpResponse}; use bytes::{Bytes, BytesMut}; -use futures::sync::oneshot::Sender; -use futures::{Async, Future, Poll, Stream}; +use futures::channel::oneshot::Sender; +use futures::{Future, Stream}; /// Do websocket handshake and start ws actor. pub fn start(actor: A, req: &HttpRequest, stream: T) -> Result where - A: Actor> + StreamHandler, - T: Stream + 'static, + A: Actor> + + StreamHandler>, + T: Stream> + 'static, { let mut res = handshake(req)?; Ok(res.streaming(WebsocketContext::create(actor, stream))) @@ -52,8 +54,9 @@ pub fn start_with_addr( stream: T, ) -> Result<(Addr, HttpResponse), Error> where - A: Actor> + StreamHandler, - T: Stream + 'static, + A: Actor> + + StreamHandler>, + T: Stream> + 'static, { let mut res = handshake(req)?; let (addr, out_stream) = WebsocketContext::create_with_addr(actor, stream); @@ -70,8 +73,9 @@ pub fn start_with_protocols( stream: T, ) -> Result where - A: Actor> + StreamHandler, - T: Stream + 'static, + A: Actor> + + StreamHandler>, + T: Stream> + 'static, { let mut res = handshake_with_protocols(req, protocols)?; Ok(res.streaming(WebsocketContext::create(actor, stream))) @@ -152,7 +156,8 @@ pub fn handshake_with_protocols( .and_then(|req_protocols| { let req_protocols = req_protocols.to_str().ok()?; req_protocols - .split(", ") + .split(',') + .map(|req_p| req_p.trim()) .find(|req_p| protocols.iter().any(|p| p == req_p)) }); @@ -201,14 +206,14 @@ where { fn spawn(&mut self, fut: F) -> SpawnHandle where - F: ActorFuture + 'static, + F: ActorFuture + 'static, { self.inner.spawn(fut) } fn wait(&mut self, fut: F) where - F: ActorFuture + 'static, + F: ActorFuture + 'static, { self.inner.wait(fut) } @@ -237,10 +242,10 @@ where { #[inline] /// Create a new Websocket context from a request and an actor - pub fn create(actor: A, stream: S) -> impl Stream + pub fn create(actor: A, stream: S) -> impl Stream> where - A: StreamHandler, - S: Stream + 'static, + A: StreamHandler>, + S: Stream> + 'static, { let (_, stream) = WebsocketContext::create_with_addr(actor, stream); stream @@ -255,10 +260,10 @@ where pub fn create_with_addr( actor: A, stream: S, - ) -> (Addr, impl Stream) + ) -> (Addr, impl Stream>) where - A: StreamHandler, - S: Stream + 'static, + A: StreamHandler>, + S: Stream> + 'static, { let mb = Mailbox::default(); let mut ctx = WebsocketContext { @@ -278,10 +283,10 @@ where actor: A, stream: S, codec: Codec, - ) -> impl Stream + ) -> impl Stream> where - A: StreamHandler, - S: Stream + 'static, + A: StreamHandler>, + S: Stream> + 'static, { let mb = Mailbox::default(); let mut ctx = WebsocketContext { @@ -297,11 +302,11 @@ where pub fn with_factory( stream: S, f: F, - ) -> impl Stream + ) -> impl Stream> where F: FnOnce(&mut Self) -> A + 'static, - A: StreamHandler, - S: Stream + 'static, + A: StreamHandler>, + S: Stream> + 'static, { let mb = Mailbox::default(); let mut ctx = WebsocketContext { @@ -345,14 +350,14 @@ where /// Send ping frame #[inline] - pub fn ping(&mut self, message: &str) { - self.write_raw(Message::Ping(message.to_string())); + pub fn ping(&mut self, message: &[u8]) { + self.write_raw(Message::Ping(Bytes::copy_from_slice(message))); } /// Send pong frame #[inline] - pub fn pong(&mut self, message: &str) { - self.write_raw(Message::Pong(message.to_string())); + pub fn pong(&mut self, message: &[u8]) { + self.write_raw(Message::Pong(Bytes::copy_from_slice(message))); } /// Send close frame @@ -414,30 +419,34 @@ impl Stream for WebsocketContextFut where A: Actor>, { - type Item = Bytes; - type Error = Error; + type Item = Result; - fn poll(&mut self) -> Poll, Error> { - if self.fut.alive() && self.fut.poll().is_err() { - return Err(ErrorInternalServerError("error")); + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + + if this.fut.alive() { + let _ = Pin::new(&mut this.fut).poll(cx); } // encode messages - while let Some(item) = self.fut.ctx().messages.pop_front() { + while let Some(item) = this.fut.ctx().messages.pop_front() { if let Some(msg) = item { - self.encoder.encode(msg, &mut self.buf)?; + this.encoder.encode(msg, &mut this.buf)?; } else { - self.closed = true; + this.closed = true; break; } } - if !self.buf.is_empty() { - Ok(Async::Ready(Some(self.buf.take().freeze()))) - } else if self.fut.alive() && !self.closed { - Ok(Async::NotReady) + if !this.buf.is_empty() { + Poll::Ready(Some(Ok(this.buf.split().freeze()))) + } else if this.fut.alive() && !this.closed { + Poll::Pending } else { - Ok(Async::Ready(None)) + Poll::Ready(None) } } } @@ -453,7 +462,9 @@ where } } +#[pin_project::pin_project] struct WsStream { + #[pin] stream: S, decoder: Codec, buf: BytesMut, @@ -462,7 +473,7 @@ struct WsStream { impl WsStream where - S: Stream, + S: Stream>, { fn new(stream: S, codec: Codec) -> Self { Self { @@ -476,62 +487,64 @@ where impl Stream for WsStream where - S: Stream, + S: Stream>, { - type Item = Message; - type Error = ProtocolError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - if !self.closed { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let mut this = self.as_mut().project(); + + if !*this.closed { loop { - match self.stream.poll() { - Ok(Async::Ready(Some(chunk))) => { - self.buf.extend_from_slice(&chunk[..]); + this = self.as_mut().project(); + match Pin::new(&mut this.stream).poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + this.buf.extend_from_slice(&chunk[..]); } - Ok(Async::Ready(None)) => { - self.closed = true; + Poll::Ready(None) => { + *this.closed = true; break; } - Ok(Async::NotReady) => break, - Err(e) => { - return Err(ProtocolError::Io(io::Error::new( - io::ErrorKind::Other, - format!("{}", e), - ))); + Poll::Pending => break, + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(ProtocolError::Io( + io::Error::new(io::ErrorKind::Other, format!("{}", e)), + )))); } } } } - match self.decoder.decode(&mut self.buf)? { + match this.decoder.decode(this.buf)? { None => { - if self.closed { - Ok(Async::Ready(None)) + if *this.closed { + Poll::Ready(None) } else { - Ok(Async::NotReady) + Poll::Pending } } Some(frm) => { let msg = match frm { - Frame::Text(data) => { - if let Some(data) = data { - Message::Text( - std::str::from_utf8(&data) - .map_err(|_| ProtocolError::BadEncoding)? - .to_string(), - ) - } else { - Message::Text(String::new()) - } - } - Frame::Binary(data) => Message::Binary( - data.map(|b| b.freeze()).unwrap_or_else(Bytes::new), + Frame::Text(data) => Message::Text( + std::str::from_utf8(&data) + .map_err(|e| { + ProtocolError::Io(io::Error::new( + io::ErrorKind::Other, + format!("{}", e), + )) + })? + .to_string(), ), + Frame::Binary(data) => Message::Binary(data), Frame::Ping(s) => Message::Ping(s), Frame::Pong(s) => Message::Pong(s), Frame::Close(reason) => Message::Close(reason), + Frame::Continuation(item) => Message::Continuation(item), }; - Ok(Async::Ready(Some(msg))) + Poll::Ready(Some(Ok(msg))) } } } @@ -736,5 +749,46 @@ mod tests { .headers() .get(&header::SEC_WEBSOCKET_PROTOCOL) ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .header( + header::SEC_WEBSOCKET_VERSION, + header::HeaderValue::from_static("13"), + ) + .header( + header::SEC_WEBSOCKET_KEY, + header::HeaderValue::from_static("13"), + ) + .header( + header::SEC_WEBSOCKET_PROTOCOL, + header::HeaderValue::from_static("p1,p2,p3"), + ) + .to_http_request(); + + let protocols = vec!["p3", "p2"]; + + assert_eq!( + StatusCode::SWITCHING_PROTOCOLS, + handshake_with_protocols(&req, &protocols) + .unwrap() + .finish() + .status() + ); + assert_eq!( + Some(&header::HeaderValue::from_static("p2")), + handshake_with_protocols(&req, &protocols) + .unwrap() + .finish() + .headers() + .get(&header::SEC_WEBSOCKET_PROTOCOL) + ); } } diff --git a/actix-web-actors/tests/test_ws.rs b/actix-web-actors/tests/test_ws.rs index 687cf4314..076e375d3 100644 --- a/actix-web-actors/tests/test_ws.rs +++ b/actix-web-actors/tests/test_ws.rs @@ -1,10 +1,8 @@ use actix::prelude::*; -use actix_http::HttpService; -use actix_http_test::TestServer; -use actix_web::{web, App, HttpRequest}; +use actix_web::{test, web, App, HttpRequest}; use actix_web_actors::*; -use bytes::{Bytes, BytesMut}; -use futures::{Sink, Stream}; +use bytes::Bytes; +use futures::{SinkExt, StreamExt}; struct Ws; @@ -12,9 +10,13 @@ impl Actor for Ws { type Context = ws::WebsocketContext; } -impl StreamHandler for Ws { - fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { - match msg { +impl StreamHandler> for Ws { + fn handle( + &mut self, + msg: Result, + ctx: &mut Self::Context, + ) { + match msg.unwrap() { ws::Message::Ping(msg) => ctx.pong(&msg), ws::Message::Text(text) => ctx.text(text), ws::Message::Binary(bin) => ctx.binary(bin), @@ -24,45 +26,42 @@ impl StreamHandler for Ws { } } -#[test] -fn test_simple() { - let mut srv = - TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").to( - |req: HttpRequest, stream: web::Payload| ws::start(Ws, &req, stream), - ))) - }); +#[actix_rt::test] +async fn test_simple() { + let mut srv = test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| { + async move { ws::start(Ws, &req, stream) } + }, + )) + }); // client service - let framed = srv.ws().unwrap(); - let framed = srv - .block_on(framed.send(ws::Message::Text("text".to_string()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); - - let framed = srv - .block_on(framed.send(ws::Message::Binary("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!( - item, - Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) - ); - - let framed = srv - .block_on(framed.send(ws::Message::Ping("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); - - let framed = srv - .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) + let mut framed = srv.ws().await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await .unwrap(); - let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!( - item, - Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) - ); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); + + framed + .send(ws::Message::Binary("text".into())) + .await + .unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text").into())); + + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Pong(Bytes::copy_from_slice(b"text"))); + + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); + + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into()))); } diff --git a/actix-web-codegen/CHANGES.md b/actix-web-codegen/CHANGES.md index 7cc0c164f..de676f688 100644 --- a/actix-web-codegen/CHANGES.md +++ b/actix-web-codegen/CHANGES.md @@ -1,5 +1,15 @@ # Changes +## [0.2.0] - 2019-12-13 + +* Generate code for actix-web 2.0 + +## [0.1.3] - 2019-10-14 + +* Bump up `syn` & `quote` to 1.0 + +* Provide better error message + ## [0.1.2] - 2019-06-04 * Add macros for head, options, trace, connect and patch http methods diff --git a/actix-web-codegen/Cargo.toml b/actix-web-codegen/Cargo.toml index 29abb4897..3fe561deb 100644 --- a/actix-web-codegen/Cargo.toml +++ b/actix-web-codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web-codegen" -version = "0.1.2" +version = "0.2.0" description = "Actix web proc macros" readme = "README.md" authors = ["Nikolay Kim "] @@ -12,11 +12,11 @@ workspace = ".." proc-macro = true [dependencies] -quote = "0.6.12" -syn = { version = "0.15.34", features = ["full", "parsing", "extra-traits"] } +quote = "^1" +syn = { version = "^1", features = ["full", "parsing"] } +proc-macro2 = "^1" [dev-dependencies] -actix-web = { version = "1.0.0" } -actix-http = { version = "0.2.4", features=["ssl"] } -actix-http-test = { version = "0.2.0", features=["ssl"] } -futures = { version = "0.1" } +actix-rt = { version = "1.0.0" } +actix-web = { version = "2.0.0-rc" } +futures = { version = "0.3.1" } diff --git a/actix-web-codegen/src/lib.rs b/actix-web-codegen/src/lib.rs index b3ae7dd9b..0a727ed69 100644 --- a/actix-web-codegen/src/lib.rs +++ b/actix-web-codegen/src/lib.rs @@ -35,8 +35,8 @@ //! use futures::{future, Future}; //! //! #[get("/test")] -//! fn async_test() -> impl Future { -//! future::ok(HttpResponse::Ok().finish()) +//! async fn async_test() -> Result { +//! Ok(HttpResponse::Ok().finish()) //! } //! ``` @@ -58,7 +58,10 @@ use syn::parse_macro_input; #[proc_macro_attribute] pub fn get(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Get); + let gen = match route::Route::new(args, input, route::GuardType::Get) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -70,7 +73,10 @@ pub fn get(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn post(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Post); + let gen = match route::Route::new(args, input, route::GuardType::Post) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -82,7 +88,10 @@ pub fn post(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn put(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Put); + let gen = match route::Route::new(args, input, route::GuardType::Put) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -94,7 +103,10 @@ pub fn put(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn delete(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Delete); + let gen = match route::Route::new(args, input, route::GuardType::Delete) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -106,7 +118,10 @@ pub fn delete(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn head(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Head); + let gen = match route::Route::new(args, input, route::GuardType::Head) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -118,7 +133,10 @@ pub fn head(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn connect(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Connect); + let gen = match route::Route::new(args, input, route::GuardType::Connect) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -130,7 +148,10 @@ pub fn connect(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn options(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Options); + let gen = match route::Route::new(args, input, route::GuardType::Options) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -142,7 +163,10 @@ pub fn options(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn trace(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Trace); + let gen = match route::Route::new(args, input, route::GuardType::Trace) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } @@ -154,6 +178,9 @@ pub fn trace(args: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn patch(args: TokenStream, input: TokenStream) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - let gen = route::Args::new(&args, input, route::GuardType::Patch); + let gen = match route::Route::new(args, input, route::GuardType::Patch) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; gen.generate() } diff --git a/actix-web-codegen/src/route.rs b/actix-web-codegen/src/route.rs index 5215f60c8..16d3e8157 100644 --- a/actix-web-codegen/src/route.rs +++ b/actix-web-codegen/src/route.rs @@ -1,21 +1,23 @@ extern crate proc_macro; -use std::fmt; - use proc_macro::TokenStream; -use quote::quote; +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::{quote, ToTokens, TokenStreamExt}; +use syn::{AttributeArgs, Ident, NestedMeta}; enum ResourceType { Async, Sync, } -impl fmt::Display for ResourceType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - ResourceType::Async => write!(f, "to_async"), - ResourceType::Sync => write!(f, "to"), - } +impl ToTokens for ResourceType { + fn to_tokens(&self, stream: &mut TokenStream2) { + let ident = match self { + ResourceType::Async => "to", + ResourceType::Sync => "to", + }; + let ident = Ident::new(ident, Span::call_site()); + stream.append(ident); } } @@ -32,63 +34,89 @@ pub enum GuardType { Patch, } -impl fmt::Display for GuardType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - GuardType::Get => write!(f, "Get"), - GuardType::Post => write!(f, "Post"), - GuardType::Put => write!(f, "Put"), - GuardType::Delete => write!(f, "Delete"), - GuardType::Head => write!(f, "Head"), - GuardType::Connect => write!(f, "Connect"), - GuardType::Options => write!(f, "Options"), - GuardType::Trace => write!(f, "Trace"), - GuardType::Patch => write!(f, "Patch"), +impl GuardType { + fn as_str(&self) -> &'static str { + match self { + GuardType::Get => "Get", + GuardType::Post => "Post", + GuardType::Put => "Put", + GuardType::Delete => "Delete", + GuardType::Head => "Head", + GuardType::Connect => "Connect", + GuardType::Options => "Options", + GuardType::Trace => "Trace", + GuardType::Patch => "Patch", } } } -pub struct Args { - name: syn::Ident, - path: String, - ast: syn::ItemFn, - resource_type: ResourceType, - pub guard: GuardType, - pub extra_guards: Vec, +impl ToTokens for GuardType { + fn to_tokens(&self, stream: &mut TokenStream2) { + let ident = self.as_str(); + let ident = Ident::new(ident, Span::call_site()); + stream.append(ident); + } } -impl fmt::Display for Args { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let ast = &self.ast; - let guards = format!(".guard(actix_web::guard::{}())", self.guard); - let guards = self.extra_guards.iter().fold(guards, |acc, val| { - format!("{}.guard(actix_web::guard::fn_guard({}))", acc, val) - }); +struct Args { + path: syn::LitStr, + guards: Vec, +} - write!( - f, - " -#[allow(non_camel_case_types)] -pub struct {name}; - -impl actix_web::dev::HttpServiceFactory for {name} {{ - fn register(self, config: &mut actix_web::dev::AppService) {{ - {ast} - - let resource = actix_web::Resource::new(\"{path}\"){guards}.{to}({name}); - - actix_web::dev::HttpServiceFactory::register(resource, config) - }} -}}", - name = self.name, - ast = quote!(#ast), - path = self.path, - guards = guards, - to = self.resource_type - ) +impl Args { + fn new(args: AttributeArgs) -> syn::Result { + let mut path = None; + let mut guards = Vec::new(); + for arg in args { + match arg { + NestedMeta::Lit(syn::Lit::Str(lit)) => match path { + None => { + path = Some(lit); + } + _ => { + return Err(syn::Error::new_spanned( + lit, + "Multiple paths specified! Should be only one!", + )); + } + }, + NestedMeta::Meta(syn::Meta::NameValue(nv)) => { + if nv.path.is_ident("guard") { + if let syn::Lit::Str(lit) = nv.lit { + guards.push(Ident::new(&lit.value(), Span::call_site())); + } else { + return Err(syn::Error::new_spanned( + nv.lit, + "Attribute guard expects literal string!", + )); + } + } else { + return Err(syn::Error::new_spanned( + nv.path, + "Unknown attribute key is specified. Allowed: guard", + )); + } + } + arg => { + return Err(syn::Error::new_spanned(arg, "Unknown attribute")); + } + } + } + Ok(Args { + path: path.unwrap(), + guards, + }) } } +pub struct Route { + name: syn::Ident, + args: Args, + ast: syn::ItemFn, + resource_type: ResourceType, + guard: GuardType, +} + fn guess_resource_type(typ: &syn::Type) -> ResourceType { let mut guess = ResourceType::Sync; @@ -111,75 +139,74 @@ fn guess_resource_type(typ: &syn::Type) -> ResourceType { guess } -impl Args { - pub fn new(args: &[syn::NestedMeta], input: TokenStream, guard: GuardType) -> Self { +impl Route { + pub fn new( + args: AttributeArgs, + input: TokenStream, + guard: GuardType, + ) -> syn::Result { if args.is_empty() { - panic!( - "invalid server definition, expected: #[{}(\"some path\")]", - guard - ); + return Err(syn::Error::new( + Span::call_site(), + format!( + r#"invalid server definition, expected #[{}("")]"#, + guard.as_str().to_ascii_lowercase() + ), + )); } + let ast: syn::ItemFn = syn::parse(input)?; + let name = ast.sig.ident.clone(); - let ast: syn::ItemFn = syn::parse(input).expect("Parse input as function"); - let name = ast.ident.clone(); + let args = Args::new(args)?; - let mut extra_guards = Vec::new(); - let mut path = None; - for arg in args { - match arg { - syn::NestedMeta::Literal(syn::Lit::Str(ref fname)) => { - if path.is_some() { - panic!("Multiple paths specified! Should be only one!") - } - let fname = quote!(#fname).to_string(); - path = Some(fname.as_str()[1..fname.len() - 1].to_owned()) - } - syn::NestedMeta::Meta(syn::Meta::NameValue(ident)) => { - match ident.ident.to_string().to_lowercase().as_str() { - "guard" => match ident.lit { - syn::Lit::Str(ref text) => extra_guards.push(text.value()), - _ => panic!("Attribute guard expects literal string!"), - }, - attr => panic!( - "Unknown attribute key is specified: {}. Allowed: guard", - attr - ), - } - } - attr => panic!("Unknown attribute{:?}", attr), - } - } - - let resource_type = if ast.asyncness.is_some() { + let resource_type = if ast.sig.asyncness.is_some() { ResourceType::Async } else { - match ast.decl.output { - syn::ReturnType::Default => panic!( - "Function {} has no return type. Cannot be used as handler", - name - ), + match ast.sig.output { + syn::ReturnType::Default => { + return Err(syn::Error::new_spanned( + ast, + "Function has no return type. Cannot be used as handler", + )); + } syn::ReturnType::Type(_, ref typ) => guess_resource_type(typ.as_ref()), } }; - let path = path.unwrap(); - - Self { + Ok(Self { name, - path, + args, ast, resource_type, guard, - extra_guards, - } + }) } pub fn generate(&self) -> TokenStream { - let text = self.to_string(); + let name = &self.name; + let resource_name = name.to_string(); + let guard = &self.guard; + let ast = &self.ast; + let path = &self.args.path; + let extra_guards = &self.args.guards; + let resource_type = &self.resource_type; + let stream = quote! { + #[allow(non_camel_case_types)] + pub struct #name; - match text.parse() { - Ok(res) => res, - Err(error) => panic!("Error: {:?}\nGenerated code: {}", error, text), - } + impl actix_web::dev::HttpServiceFactory for #name { + fn register(self, config: &mut actix_web::dev::AppService) { + #ast + let resource = actix_web::Resource::new(#path) + .name(#resource_name) + .guard(actix_web::guard::#guard()) + #(.guard(actix_web::guard::fn_guard(#extra_guards)))* + .#resource_type(#name); + + actix_web::dev::HttpServiceFactory::register(resource, config) + } + } + }; + stream.into() } } diff --git a/actix-web-codegen/tests/test_macro.rs b/actix-web-codegen/tests/test_macro.rs index 8ecc81dc1..4ac1a8023 100644 --- a/actix-web-codegen/tests/test_macro.rs +++ b/actix-web-codegen/tests/test_macro.rs @@ -1,157 +1,151 @@ -use actix_http::HttpService; -use actix_http_test::TestServer; -use actix_web::{http, web::Path, App, HttpResponse, Responder}; +use actix_web::{http, test, web::Path, App, HttpResponse, Responder}; use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, trace}; use futures::{future, Future}; #[get("/test")] -fn test() -> impl Responder { +async fn test_handler() -> impl Responder { HttpResponse::Ok() } #[put("/test")] -fn put_test() -> impl Responder { +async fn put_test() -> impl Responder { HttpResponse::Created() } #[patch("/test")] -fn patch_test() -> impl Responder { +async fn patch_test() -> impl Responder { HttpResponse::Ok() } #[post("/test")] -fn post_test() -> impl Responder { +async fn post_test() -> impl Responder { HttpResponse::NoContent() } #[head("/test")] -fn head_test() -> impl Responder { +async fn head_test() -> impl Responder { HttpResponse::Ok() } #[connect("/test")] -fn connect_test() -> impl Responder { +async fn connect_test() -> impl Responder { HttpResponse::Ok() } #[options("/test")] -fn options_test() -> impl Responder { +async fn options_test() -> impl Responder { HttpResponse::Ok() } #[trace("/test")] -fn trace_test() -> impl Responder { +async fn trace_test() -> impl Responder { HttpResponse::Ok() } #[get("/test")] -fn auto_async() -> impl Future { +fn auto_async() -> impl Future> { future::ok(HttpResponse::Ok().finish()) } #[get("/test")] -fn auto_sync() -> impl Future { +fn auto_sync() -> impl Future> { future::ok(HttpResponse::Ok().finish()) } #[put("/test/{param}")] -fn put_param_test(_: Path) -> impl Responder { +async fn put_param_test(_: Path) -> impl Responder { HttpResponse::Created() } #[delete("/test/{param}")] -fn delete_param_test(_: Path) -> impl Responder { +async fn delete_param_test(_: Path) -> impl Responder { HttpResponse::NoContent() } #[get("/test/{param}")] -fn get_param_test(_: Path) -> impl Responder { +async fn get_param_test(_: Path) -> impl Responder { HttpResponse::Ok() } -#[test] -fn test_params() { - let mut srv = TestServer::new(|| { - HttpService::new( - App::new() - .service(get_param_test) - .service(put_param_test) - .service(delete_param_test), - ) +#[actix_rt::test] +async fn test_params() { + let srv = test::start(|| { + App::new() + .service(get_param_test) + .service(put_param_test) + .service(delete_param_test) }); let request = srv.request(http::Method::GET, srv.url("/test/it")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert_eq!(response.status(), http::StatusCode::OK); let request = srv.request(http::Method::PUT, srv.url("/test/it")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert_eq!(response.status(), http::StatusCode::CREATED); let request = srv.request(http::Method::DELETE, srv.url("/test/it")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert_eq!(response.status(), http::StatusCode::NO_CONTENT); } -#[test] -fn test_body() { - let mut srv = TestServer::new(|| { - HttpService::new( - App::new() - .service(post_test) - .service(put_test) - .service(head_test) - .service(connect_test) - .service(options_test) - .service(trace_test) - .service(patch_test) - .service(test), - ) +#[actix_rt::test] +async fn test_body() { + let srv = test::start(|| { + App::new() + .service(post_test) + .service(put_test) + .service(head_test) + .service(connect_test) + .service(options_test) + .service(trace_test) + .service(patch_test) + .service(test_handler) }); let request = srv.request(http::Method::GET, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); let request = srv.request(http::Method::HEAD, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); let request = srv.request(http::Method::CONNECT, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); let request = srv.request(http::Method::OPTIONS, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); let request = srv.request(http::Method::TRACE, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); let request = srv.request(http::Method::PATCH, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); let request = srv.request(http::Method::PUT, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); assert_eq!(response.status(), http::StatusCode::CREATED); let request = srv.request(http::Method::POST, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); assert_eq!(response.status(), http::StatusCode::NO_CONTENT); let request = srv.request(http::Method::GET, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); } -#[test] -fn test_auto_async() { - let mut srv = TestServer::new(|| HttpService::new(App::new().service(auto_async))); +#[actix_rt::test] +async fn test_auto_async() { + let srv = test::start(|| App::new().service(auto_async)); let request = srv.request(http::Method::GET, srv.url("/test")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); } diff --git a/awc/CHANGES.md b/awc/CHANGES.md index 27962a6f5..d9b26e453 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -1,14 +1,44 @@ # Changes -## +## [1.0.1] - 2019-12-15 + +* Fix compilation with default features off + + +## [1.0.0] - 2019-12-13 + +* Release + +## [1.0.0-alpha.3] + +* Migrate to `std::future` + + +## [0.2.8] - 2019-11-06 + +* Add support for setting query from Serialize type for client request. + + +## [0.2.7] - 2019-09-25 + +### Added + +* Remaining getter methods for `ClientRequest`'s private `head` field #1101 + + +## [0.2.6] - 2019-09-12 + +### Added + +* Export frozen request related types. + + +## [0.2.5] - 2019-09-11 ### Added * Add `FrozenClientRequest` to support retries for sending HTTP requests - -## [0.2.5] - 2019-09-06 - ### Changed * Ensure that the `Host` header is set when initiating a WebSocket client connection. diff --git a/awc/Cargo.toml b/awc/Cargo.toml index 7f42501e9..67e0a3ee4 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "awc" -version = "0.2.4" +version = "1.0.1" authors = ["Nikolay Kim "] description = "Actix http client." readme = "README.md" @@ -12,43 +12,37 @@ categories = ["network-programming", "asynchronous", "web-programming::http-client", "web-programming::websocket"] license = "MIT/Apache-2.0" -exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] edition = "2018" -workspace = ".." [lib] name = "awc" path = "src/lib.rs" [package.metadata.docs.rs] -features = ["ssl", "brotli", "flate2-zlib"] +features = ["openssl", "rustls", "compress"] [features] -default = ["brotli", "flate2-zlib"] +default = ["compress"] # openssl -ssl = ["openssl", "actix-http/ssl"] +openssl = ["open-ssl", "actix-http/openssl"] # rustls -rust-tls = ["rustls", "actix-http/rust-tls"] +rustls = ["rust-tls", "actix-http/rustls"] -# brotli encoding, requires c compiler -brotli = ["actix-http/brotli"] - -# miniz-sys backend for flate2 crate -flate2-zlib = ["actix-http/flate2-zlib"] - -# rust backend for flate2 crate -flate2-rust = ["actix-http/flate2-rust"] +# content-encoding support +compress = ["actix-http/compress"] [dependencies] -actix-codec = "0.1.2" -actix-service = "0.4.1" -actix-http = "0.2.9" -base64 = "0.10.1" -bytes = "0.4" -derive_more = "0.15.0" -futures = "0.1.25" +actix-codec = "0.2.0" +actix-service = "1.0.1" +actix-http = "1.0.0" +actix-rt = "1.0.0" + +base64 = "0.11" +bytes = "0.5.3" +derive_more = "0.99.2" +futures-core = "0.3.1" log =" 0.4" mime = "0.3" percent-encoding = "2.1" @@ -56,21 +50,19 @@ rand = "0.7" serde = "1.0" serde_json = "1.0" serde_urlencoded = "0.6.1" -tokio-timer = "0.2.8" -openssl = { version="0.10", optional = true } -rustls = { version = "0.15.2", optional = true } +open-ssl = { version="0.10", package="openssl", optional = true } +rust-tls = { version = "0.16.0", package="rustls", optional = true, features = ["dangerous_configuration"] } [dev-dependencies] -actix-rt = "0.2.2" -actix-web = { version = "1.0.0", features=["ssl"] } -actix-http = { version = "0.2.4", features=["ssl"] } -actix-http-test = { version = "0.2.0", features=["ssl"] } -actix-utils = "0.4.1" -actix-server = { version = "0.6.0", features=["ssl", "rust-tls"] } -brotli2 = { version="0.3.2" } -flate2 = { version="1.0.2" } +actix-connect = { version = "1.0.1", features=["openssl"] } +actix-web = { version = "2.0.0-rc", features=["openssl"] } +actix-http = { version = "1.0.1", features=["openssl"] } +actix-http-test = { version = "1.0.0", features=["openssl"] } +actix-utils = "1.0.3" +actix-server = "1.0.0" +actix-tls = { version = "1.0.0", features=["openssl", "rustls"] } +brotli2 = "0.3.2" +flate2 = "1.0.13" +futures = "0.3.1" env_logger = "0.6" -rand = "0.7" -tokio-tcp = "0.1" -webpki = "0.19" -rustls = { version = "0.15.2", features = ["dangerous_configuration"] } +webpki = "0.21" diff --git a/awc/src/builder.rs b/awc/src/builder.rs index 463f40303..7bd0171ec 100644 --- a/awc/src/builder.rs +++ b/awc/src/builder.rs @@ -1,10 +1,11 @@ use std::cell::RefCell; +use std::convert::TryFrom; use std::fmt; use std::rc::Rc; use std::time::Duration; use actix_http::client::{Connect, ConnectError, Connection, Connector}; -use actix_http::http::{header, HeaderMap, HeaderName, HttpTryFrom}; +use actix_http::http::{header, Error as HttpError, HeaderMap, HeaderName}; use actix_service::Service; use crate::connect::ConnectorWrapper; @@ -97,8 +98,8 @@ impl ClientBuilder { /// get added to every request. pub fn header(mut self, key: K, value: V) -> Self where - HeaderName: HttpTryFrom, - >::Error: fmt::Debug, + HeaderName: TryFrom, + >::Error: fmt::Debug + Into, V: header::IntoHeaderValue, V::Error: fmt::Debug, { diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 82fd6a759..618d653f5 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -1,5 +1,8 @@ -use std::{fmt, io, net}; +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; +use std::{fmt, io, mem, net}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::body::Body; @@ -7,10 +10,9 @@ use actix_http::client::{ Connect as ClientConnect, ConnectError, Connection, SendRequestError, }; use actix_http::h1::ClientCodec; -use actix_http::{RequestHead, RequestHeadType, ResponseHead}; use actix_http::http::HeaderMap; +use actix_http::{RequestHead, RequestHeadType, ResponseHead}; use actix_service::Service; -use futures::{Future, Poll}; use crate::response::ClientResponse; @@ -22,7 +24,7 @@ pub(crate) trait Connect { head: RequestHead, body: Body, addr: Option, - ) -> Box>; + ) -> Pin>>>; fn send_request_extra( &mut self, @@ -30,17 +32,21 @@ pub(crate) trait Connect { extra_headers: Option, body: Body, addr: Option, - ) -> Box>; + ) -> Pin>>>; /// Send request, returns Response and Framed fn open_tunnel( &mut self, head: RequestHead, addr: Option, - ) -> Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, + ) -> Pin< + Box< + dyn Future< + Output = Result< + (ResponseHead, Framed), + SendRequestError, + >, + >, >, >; @@ -50,10 +56,14 @@ pub(crate) trait Connect { head: Rc, extra_headers: Option, addr: Option, - ) -> Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, + ) -> Pin< + Box< + dyn Future< + Output = Result< + (ResponseHead, Framed), + SendRequestError, + >, + >, >, >; } @@ -72,19 +82,22 @@ where head: RequestHead, body: Body, addr: Option, - ) -> Box> { - Box::new( - self.0 - // connect to the host - .call(ClientConnect { - uri: head.uri.clone(), - addr, - }) - .from_err() - // send request - .and_then(move |connection| connection.send_request(RequestHeadType::from(head), body)) - .map(|(head, payload)| ClientResponse::new(head, payload)), - ) + ) -> Pin>>> { + // connect to the host + let fut = self.0.call(ClientConnect { + uri: head.uri.clone(), + addr, + }); + + Box::pin(async move { + let connection = fut.await?; + + // send request + connection + .send_request(RequestHeadType::from(head), body) + .await + .map(|(head, payload)| ClientResponse::new(head, payload)) + }) } fn send_request_extra( @@ -93,46 +106,55 @@ where extra_headers: Option, body: Body, addr: Option, - ) -> Box> { - Box::new( - self.0 - // connect to the host - .call(ClientConnect { - uri: head.uri.clone(), - addr, - }) - .from_err() - // send request - .and_then(move |connection| connection.send_request(RequestHeadType::Rc(head, extra_headers), body)) - .map(|(head, payload)| ClientResponse::new(head, payload)), - ) + ) -> Pin>>> { + // connect to the host + let fut = self.0.call(ClientConnect { + uri: head.uri.clone(), + addr, + }); + + Box::pin(async move { + let connection = fut.await?; + + // send request + let (head, payload) = connection + .send_request(RequestHeadType::Rc(head, extra_headers), body) + .await?; + + Ok(ClientResponse::new(head, payload)) + }) } fn open_tunnel( &mut self, head: RequestHead, addr: Option, - ) -> Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, + ) -> Pin< + Box< + dyn Future< + Output = Result< + (ResponseHead, Framed), + SendRequestError, + >, + >, >, > { - Box::new( - self.0 - // connect to the host - .call(ClientConnect { - uri: head.uri.clone(), - addr, - }) - .from_err() - // send request - .and_then(move |connection| connection.open_tunnel(RequestHeadType::from(head))) - .map(|(head, framed)| { - let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); - (head, framed) - }), - ) + // connect to the host + let fut = self.0.call(ClientConnect { + uri: head.uri.clone(), + addr, + }); + + Box::pin(async move { + let connection = fut.await?; + + // send request + let (head, framed) = + connection.open_tunnel(RequestHeadType::from(head)).await?; + + let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); + Ok((head, framed)) + }) } fn open_tunnel_extra( @@ -140,46 +162,52 @@ where head: Rc, extra_headers: Option, addr: Option, - ) -> Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, + ) -> Pin< + Box< + dyn Future< + Output = Result< + (ResponseHead, Framed), + SendRequestError, + >, + >, >, > { - Box::new( - self.0 - // connect to the host - .call(ClientConnect { - uri: head.uri.clone(), - addr, - }) - .from_err() - // send request - .and_then(move |connection| connection.open_tunnel(RequestHeadType::Rc(head, extra_headers))) - .map(|(head, framed)| { - let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); - (head, framed) - }), - ) + // connect to the host + let fut = self.0.call(ClientConnect { + uri: head.uri.clone(), + addr, + }); + + Box::pin(async move { + let connection = fut.await?; + + // send request + let (head, framed) = connection + .open_tunnel(RequestHeadType::Rc(head, extra_headers)) + .await?; + + let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); + Ok((head, framed)) + }) } } trait AsyncSocket { - fn as_read(&self) -> &dyn AsyncRead; - fn as_read_mut(&mut self) -> &mut dyn AsyncRead; - fn as_write(&mut self) -> &mut dyn AsyncWrite; + fn as_read(&self) -> &(dyn AsyncRead + Unpin); + fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin); + fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin); } -struct Socket(T); +struct Socket(T); -impl AsyncSocket for Socket { - fn as_read(&self) -> &dyn AsyncRead { +impl AsyncSocket for Socket { + fn as_read(&self) -> &(dyn AsyncRead + Unpin) { &self.0 } - fn as_read_mut(&mut self) -> &mut dyn AsyncRead { + fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin) { &mut self.0 } - fn as_write(&mut self) -> &mut dyn AsyncWrite { + fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin) { &mut self.0 } } @@ -187,35 +215,45 @@ impl AsyncSocket for Socket { pub struct BoxedSocket(Box); impl fmt::Debug for BoxedSocket { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "BoxedSocket") } } -impl io::Read for BoxedSocket { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.as_read_mut().read(buf) - } -} - impl AsyncRead for BoxedSocket { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + unsafe fn prepare_uninitialized_buffer( + &self, + buf: &mut [mem::MaybeUninit], + ) -> bool { self.0.as_read().prepare_uninitialized_buffer(buf) } -} -impl io::Write for BoxedSocket { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.as_write().write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.0.as_write().flush() + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf) } } impl AsyncWrite for BoxedSocket { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.0.as_write().shutdown() + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(self.get_mut().0.as_write()).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(self.get_mut().0.as_write()).poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.get_mut().0.as_write()).poll_shutdown(cx) } } diff --git a/awc/src/error.rs b/awc/src/error.rs index 4eb929007..7fece74ee 100644 --- a/awc/src/error.rs +++ b/awc/src/error.rs @@ -1,13 +1,16 @@ //! Http client errors -pub use actix_http::client::{ConnectError, InvalidUrl, SendRequestError, FreezeRequestError}; +pub use actix_http::client::{ + ConnectError, FreezeRequestError, InvalidUrl, SendRequestError, +}; pub use actix_http::error::PayloadError; +pub use actix_http::http::Error as HttpError; pub use actix_http::ws::HandshakeError as WsHandshakeError; pub use actix_http::ws::ProtocolError as WsProtocolError; -use actix_http::{Response, ResponseError}; +use actix_http::ResponseError; use serde_json::error::Error as JsonError; -use actix_http::http::{header::HeaderValue, Error as HttpError, StatusCode}; +use actix_http::http::{header::HeaderValue, StatusCode}; use derive_more::{Display, From}; /// Websocket client error @@ -66,8 +69,4 @@ pub enum JsonPayloadError { } /// Return `InternalServerError` for `JsonPayloadError` -impl ResponseError for JsonPayloadError { - fn error_response(&self) -> Response { - Response::new(StatusCode::INTERNAL_SERVER_ERROR) - } -} +impl ResponseError for JsonPayloadError {} diff --git a/awc/src/frozen.rs b/awc/src/frozen.rs new file mode 100644 index 000000000..f7098863c --- /dev/null +++ b/awc/src/frozen.rs @@ -0,0 +1,236 @@ +use std::convert::TryFrom; +use std::net; +use std::rc::Rc; +use std::time::Duration; + +use bytes::Bytes; +use futures_core::Stream; +use serde::Serialize; + +use actix_http::body::Body; +use actix_http::http::header::IntoHeaderValue; +use actix_http::http::{Error as HttpError, HeaderMap, HeaderName, Method, Uri}; +use actix_http::{Error, RequestHead}; + +use crate::sender::{RequestSender, SendClientRequest}; +use crate::ClientConfig; + +/// `FrozenClientRequest` struct represents clonable client request. +/// It could be used to send same request multiple times. +#[derive(Clone)] +pub struct FrozenClientRequest { + pub(crate) head: Rc, + pub(crate) addr: Option, + pub(crate) response_decompress: bool, + pub(crate) timeout: Option, + pub(crate) config: Rc, +} + +impl FrozenClientRequest { + /// Get HTTP URI of request + pub fn get_uri(&self) -> &Uri { + &self.head.uri + } + + /// Get HTTP method of this request + pub fn get_method(&self) -> &Method { + &self.head.method + } + + /// Returns request's headers. + pub fn headers(&self) -> &HeaderMap { + &self.head.headers + } + + /// Send a body. + pub fn send_body(&self, body: B) -> SendClientRequest + where + B: Into, + { + RequestSender::Rc(self.head.clone(), None).send_body( + self.addr, + self.response_decompress, + self.timeout, + self.config.as_ref(), + body, + ) + } + + /// Send a json body. + pub fn send_json(&self, value: &T) -> SendClientRequest { + RequestSender::Rc(self.head.clone(), None).send_json( + self.addr, + self.response_decompress, + self.timeout, + self.config.as_ref(), + value, + ) + } + + /// Send an urlencoded body. + pub fn send_form(&self, value: &T) -> SendClientRequest { + RequestSender::Rc(self.head.clone(), None).send_form( + self.addr, + self.response_decompress, + self.timeout, + self.config.as_ref(), + value, + ) + } + + /// Send a streaming body. + pub fn send_stream(&self, stream: S) -> SendClientRequest + where + S: Stream> + Unpin + 'static, + E: Into + 'static, + { + RequestSender::Rc(self.head.clone(), None).send_stream( + self.addr, + self.response_decompress, + self.timeout, + self.config.as_ref(), + stream, + ) + } + + /// Send an empty body. + pub fn send(&self) -> SendClientRequest { + RequestSender::Rc(self.head.clone(), None).send( + self.addr, + self.response_decompress, + self.timeout, + self.config.as_ref(), + ) + } + + /// Create a `FrozenSendBuilder` with extra headers + pub fn extra_headers(&self, extra_headers: HeaderMap) -> FrozenSendBuilder { + FrozenSendBuilder::new(self.clone(), extra_headers) + } + + /// Create a `FrozenSendBuilder` with an extra header + pub fn extra_header(&self, key: K, value: V) -> FrozenSendBuilder + where + HeaderName: TryFrom, + >::Error: Into, + V: IntoHeaderValue, + { + self.extra_headers(HeaderMap::new()) + .extra_header(key, value) + } +} + +/// Builder that allows to modify extra headers. +pub struct FrozenSendBuilder { + req: FrozenClientRequest, + extra_headers: HeaderMap, + err: Option, +} + +impl FrozenSendBuilder { + pub(crate) fn new(req: FrozenClientRequest, extra_headers: HeaderMap) -> Self { + Self { + req, + extra_headers, + err: None, + } + } + + /// Insert a header, it overrides existing header in `FrozenClientRequest`. + pub fn extra_header(mut self, key: K, value: V) -> Self + where + HeaderName: TryFrom, + >::Error: Into, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => self.extra_headers.insert(key, value), + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Complete request construction and send a body. + pub fn send_body(self, body: B) -> SendClientRequest + where + B: Into, + { + if let Some(e) = self.err { + return e.into(); + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)).send_body( + self.req.addr, + self.req.response_decompress, + self.req.timeout, + self.req.config.as_ref(), + body, + ) + } + + /// Complete request construction and send a json body. + pub fn send_json(self, value: &T) -> SendClientRequest { + if let Some(e) = self.err { + return e.into(); + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)).send_json( + self.req.addr, + self.req.response_decompress, + self.req.timeout, + self.req.config.as_ref(), + value, + ) + } + + /// Complete request construction and send an urlencoded body. + pub fn send_form(self, value: &T) -> SendClientRequest { + if let Some(e) = self.err { + return e.into(); + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)).send_form( + self.req.addr, + self.req.response_decompress, + self.req.timeout, + self.req.config.as_ref(), + value, + ) + } + + /// Complete request construction and send a streaming body. + pub fn send_stream(self, stream: S) -> SendClientRequest + where + S: Stream> + Unpin + 'static, + E: Into + 'static, + { + if let Some(e) = self.err { + return e.into(); + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)).send_stream( + self.req.addr, + self.req.response_decompress, + self.req.timeout, + self.req.config.as_ref(), + stream, + ) + } + + /// Complete request construction and send an empty body. + pub fn send(self) -> SendClientRequest { + if let Some(e) = self.err { + return e.into(); + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)).send( + self.req.addr, + self.req.response_decompress, + self.req.timeout, + self.req.config.as_ref(), + ) + } +} diff --git a/awc/src/lib.rs b/awc/src/lib.rs index da63bbd93..8944fe229 100644 --- a/awc/src/lib.rs +++ b/awc/src/lib.rs @@ -1,4 +1,9 @@ -#![allow(clippy::borrow_interior_mutable_const)] +#![deny(rust_2018_idioms, warnings)] +#![allow( + clippy::type_complexity, + clippy::borrow_interior_mutable_const, + clippy::needless_doctest_main +)] //! An HTTP Client //! //! ```rust @@ -6,65 +11,62 @@ //! use actix_rt::System; //! use awc::Client; //! -//! fn main() { -//! System::new("test").block_on(lazy(|| { -//! let mut client = Client::default(); +//! #[actix_rt::main] +//! async fn main() { +//! let mut client = Client::default(); //! -//! client.get("http://www.rust-lang.org") // <- Create request builder -//! .header("User-Agent", "Actix-web") -//! .send() // <- Send http request -//! .map_err(|_| ()) -//! .and_then(|response| { // <- server http response -//! println!("Response: {:?}", response); -//! Ok(()) -//! }) -//! })); +//! let response = client.get("http://www.rust-lang.org") // <- Create request builder +//! .header("User-Agent", "Actix-web") +//! .send() // <- Send http request +//! .await; +//! +//! println!("Response: {:?}", response); //! } //! ``` use std::cell::RefCell; +use std::convert::TryFrom; use std::rc::Rc; use std::time::Duration; pub use actix_http::{client::Connector, cookie, http}; -use actix_http::http::{HeaderMap, HttpTryFrom, Method, Uri}; +use actix_http::http::{Error as HttpError, HeaderMap, Method, Uri}; use actix_http::RequestHead; mod builder; mod connect; pub mod error; +mod frozen; mod request; mod response; +mod sender; pub mod test; pub mod ws; pub use self::builder::ClientBuilder; pub use self::connect::BoxedSocket; +pub use self::frozen::{FrozenClientRequest, FrozenSendBuilder}; pub use self::request::ClientRequest; pub use self::response::{ClientResponse, JsonBody, MessageBody}; +pub use self::sender::SendClientRequest; use self::connect::{Connect, ConnectorWrapper}; /// An HTTP Client /// /// ```rust -/// # use futures::future::{Future, lazy}; -/// use actix_rt::System; /// use awc::Client; /// -/// fn main() { -/// System::new("test").block_on(lazy(|| { -/// let mut client = Client::default(); +/// #[actix_rt::main] +/// async fn main() { +/// let mut client = Client::default(); /// -/// client.get("http://www.rust-lang.org") // <- Create request builder -/// .header("User-Agent", "Actix-web") -/// .send() // <- Send http request -/// .map_err(|_| ()) -/// .and_then(|response| { // <- server http response -/// println!("Response: {:?}", response); -/// Ok(()) -/// }) -/// })); +/// let res = client.get("http://www.rust-lang.org") // <- Create request builder +/// .header("User-Agent", "Actix-web") +/// .send() // <- Send http request +/// .await; // <- send request and wait for response +/// +/// println!("Response: {:?}", res); /// } /// ``` #[derive(Clone)] @@ -102,7 +104,8 @@ impl Client { /// Construct HTTP request. pub fn request(&self, method: Method, url: U) -> ClientRequest where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { let mut req = ClientRequest::new(method, url, self.0.clone()); @@ -118,7 +121,8 @@ impl Client { /// copies all headers and the method. pub fn request_from(&self, url: U, head: &RequestHead) -> ClientRequest where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { let mut req = self.request(head.method.clone(), url); for (key, value) in head.headers.iter() { @@ -130,7 +134,8 @@ impl Client { /// Construct HTTP *GET* request. pub fn get(&self, url: U) -> ClientRequest where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { self.request(Method::GET, url) } @@ -138,7 +143,8 @@ impl Client { /// Construct HTTP *HEAD* request. pub fn head(&self, url: U) -> ClientRequest where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { self.request(Method::HEAD, url) } @@ -146,7 +152,8 @@ impl Client { /// Construct HTTP *PUT* request. pub fn put(&self, url: U) -> ClientRequest where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { self.request(Method::PUT, url) } @@ -154,7 +161,8 @@ impl Client { /// Construct HTTP *POST* request. pub fn post(&self, url: U) -> ClientRequest where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { self.request(Method::POST, url) } @@ -162,7 +170,8 @@ impl Client { /// Construct HTTP *PATCH* request. pub fn patch(&self, url: U) -> ClientRequest where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { self.request(Method::PATCH, url) } @@ -170,7 +179,8 @@ impl Client { /// Construct HTTP *DELETE* request. pub fn delete(&self, url: U) -> ClientRequest where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { self.request(Method::DELETE, url) } @@ -178,7 +188,8 @@ impl Client { /// Construct HTTP *OPTIONS* request. pub fn options(&self, url: U) -> ClientRequest where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { self.request(Method::OPTIONS, url) } @@ -186,7 +197,8 @@ impl Client { /// Construct WebSockets request. pub fn ws(&self, url: U) -> ws::WebsocketsRequest where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { let mut req = ws::WebsocketsRequest::new(url, self.0.clone()); for (key, value) in self.0.headers.iter() { diff --git a/awc/src/request.rs b/awc/src/request.rs index 4dd07c5d8..67b063a8e 100644 --- a/awc/src/request.rs +++ b/awc/src/request.rs @@ -1,38 +1,32 @@ +use std::convert::TryFrom; use std::fmt::Write as FmtWrite; -use std::io::Write; use std::rc::Rc; -use std::time::{Duration, Instant}; +use std::time::Duration; use std::{fmt, net}; -use bytes::{BufMut, Bytes, BytesMut}; -use futures::{Async, Future, Poll, Stream, try_ready}; +use bytes::Bytes; +use futures_core::Stream; use percent_encoding::percent_encode; use serde::Serialize; -use serde_json; -use tokio_timer::Delay; -use derive_more::From; -use actix_http::body::{Body, BodyStream}; +use actix_http::body::Body; use actix_http::cookie::{Cookie, CookieJar, USERINFO}; -use actix_http::encoding::Decoder; -use actix_http::http::header::{self, ContentEncoding, Header, IntoHeaderValue}; +use actix_http::http::header::{self, Header, IntoHeaderValue}; use actix_http::http::{ - uri, ConnectionType, Error as HttpError, HeaderMap, HeaderName, HeaderValue, - HttpTryFrom, Method, Uri, Version, + uri, ConnectionType, Error as HttpError, HeaderMap, HeaderName, HeaderValue, Method, + Uri, Version, }; -use actix_http::{Error, Payload, PayloadStream, RequestHead}; +use actix_http::{Error, RequestHead}; -use crate::error::{InvalidUrl, SendRequestError, FreezeRequestError}; -use crate::response::ClientResponse; +use crate::error::{FreezeRequestError, InvalidUrl}; +use crate::frozen::FrozenClientRequest; +use crate::sender::{PrepForSendingError, RequestSender, SendClientRequest}; use crate::ClientConfig; -#[cfg(any(feature = "brotli", feature = "flate2-zlib", feature = "flate2-rust"))] +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] const HTTPS_ENCODING: &str = "br, gzip, deflate"; -#[cfg(all( - any(feature = "flate2-zlib", feature = "flate2-rust"), - not(feature = "brotli") -))] -const HTTPS_ENCODING: &str = "gzip, deflate"; +#[cfg(not(any(feature = "flate2-zlib", feature = "flate2-rust")))] +const HTTPS_ENCODING: &str = "br"; /// An HTTP Client request builder /// @@ -40,21 +34,20 @@ const HTTPS_ENCODING: &str = "gzip, deflate"; /// builder-like pattern. /// /// ```rust -/// use futures::future::{Future, lazy}; /// use actix_rt::System; /// -/// fn main() { -/// System::new("test").block_on(lazy(|| { -/// awc::Client::new() -/// .get("http://www.rust-lang.org") // <- Create request builder -/// .header("User-Agent", "Actix-web") -/// .send() // <- Send http request -/// .map_err(|_| ()) -/// .and_then(|response| { // <- server http response -/// println!("Response: {:?}", response); -/// Ok(()) -/// }) -/// })); +/// #[actix_rt::main] +/// async fn main() { +/// let response = awc::Client::new() +/// .get("http://www.rust-lang.org") // <- Create request builder +/// .header("User-Agent", "Actix-web") +/// .send() // <- Send http request +/// .await; +/// +/// response.and_then(|response| { // <- server http response +/// println!("Response: {:?}", response); +/// Ok(()) +/// }); /// } /// ``` pub struct ClientRequest { @@ -71,7 +64,8 @@ impl ClientRequest { /// Create new client request builder. pub(crate) fn new(method: Method, uri: U, config: Rc) -> Self where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { ClientRequest { config, @@ -90,7 +84,8 @@ impl ClientRequest { #[inline] pub fn uri(mut self, uri: U) -> Self where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { match Uri::try_from(uri) { Ok(uri) => self.head.uri = uri, @@ -99,7 +94,7 @@ impl ClientRequest { self } - /// Get HTTP URI of request + /// Get HTTP URI of request. pub fn get_uri(&self) -> &Uri { &self.head.uri } @@ -135,6 +130,16 @@ impl ClientRequest { self } + /// Get HTTP version of this request. + pub fn get_version(&self) -> &Version { + &self.head.version + } + + /// Get peer address of this request. + pub fn get_peer_addr(&self) -> &Option { + &self.head.peer_addr + } + #[inline] /// Returns request's headers. pub fn headers(&self) -> &HeaderMap { @@ -151,7 +156,7 @@ impl ClientRequest { /// /// ```rust /// fn main() { - /// # actix_rt::System::new("test").block_on(futures::future::lazy(|| { + /// # actix_rt::System::new("test").block_on(futures::future::lazy(|_| { /// let req = awc::Client::new() /// .get("http://www.rust-lang.org") /// .set(awc::http::header::Date::now()) @@ -179,18 +184,19 @@ impl ClientRequest { /// use awc::{http, Client}; /// /// fn main() { - /// # actix_rt::System::new("test").block_on(futures::future::lazy(|| { + /// # actix_rt::System::new("test").block_on(async { /// let req = Client::new() /// .get("http://www.rust-lang.org") /// .header("X-TEST", "value") /// .header(http::header::CONTENT_TYPE, "application/json"); /// # Ok::<_, ()>(()) - /// # })); + /// # }); /// } /// ``` pub fn header(mut self, key: K, value: V) -> Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { match HeaderName::try_from(key) { @@ -206,7 +212,8 @@ impl ClientRequest { /// Insert a header, replaces existing header. pub fn set_header(mut self, key: K, value: V) -> Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { match HeaderName::try_from(key) { @@ -222,7 +229,8 @@ impl ClientRequest { /// Insert a header only if it is not yet set. pub fn set_header_if_none(mut self, key: K, value: V) -> Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { match HeaderName::try_from(key) { @@ -258,7 +266,8 @@ impl ClientRequest { #[inline] pub fn content_type(mut self, value: V) -> Self where - HeaderValue: HttpTryFrom, + HeaderValue: TryFrom, + >::Error: Into, { match HeaderValue::try_from(value) { Ok(value) => self.head.headers.insert(header::CONTENT_TYPE, value), @@ -270,9 +279,7 @@ impl ClientRequest { /// Set content length #[inline] pub fn content_length(self, len: u64) -> Self { - let mut wrt = BytesMut::new().writer(); - let _ = write!(wrt, "{}", len); - self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze()) + self.header(header::CONTENT_LENGTH, len) } /// Set HTTP basic authorization header @@ -301,26 +308,21 @@ impl ClientRequest { /// Set a cookie /// /// ```rust - /// # use actix_rt::System; - /// # use futures::future::{lazy, Future}; - /// fn main() { - /// System::new("test").block_on(lazy(|| { - /// awc::Client::new().get("https://www.rust-lang.org") - /// .cookie( - /// awc::http::Cookie::build("name", "value") - /// .domain("www.rust-lang.org") - /// .path("/") - /// .secure(true) - /// .http_only(true) - /// .finish(), - /// ) - /// .send() - /// .map_err(|_| ()) - /// .and_then(|response| { - /// println!("Response: {:?}", response); - /// Ok(()) - /// }) - /// })); + /// #[actix_rt::main] + /// async fn main() { + /// let resp = awc::Client::new().get("https://www.rust-lang.org") + /// .cookie( + /// awc::http::Cookie::build("name", "value") + /// .domain("www.rust-lang.org") + /// .path("/") + /// .secure(true) + /// .http_only(true) + /// .finish(), + /// ) + /// .send() + /// .await; + /// + /// println!("Response: {:?}", resp); /// } /// ``` pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { @@ -375,6 +377,29 @@ impl ClientRequest { } } + /// Sets the query part of the request + pub fn query( + mut self, + query: &T, + ) -> Result { + let mut parts = self.head.uri.clone().into_parts(); + + if let Some(path_and_query) = parts.path_and_query { + let query = serde_urlencoded::to_string(query)?; + let path = path_and_query.path(); + parts.path_and_query = format!("{}?{}", path, query).parse().ok(); + + match Uri::from_parts(parts) { + Ok(uri) => self.head.uri = uri, + Err(e) => self.err = Some(e.into()), + } + } + + Ok(self) + } + + /// Freeze request builder and construct `FrozenClientRequest`, + /// which could be used for sending same request multiple times. pub fn freeze(self) -> Result { let slf = match self.prep_for_sending() { Ok(slf) => slf, @@ -393,10 +418,7 @@ impl ClientRequest { } /// Complete request construction and send body. - pub fn send_body( - self, - body: B, - ) -> SendBody + pub fn send_body(self, body: B) -> SendClientRequest where B: Into, { @@ -405,49 +427,53 @@ impl ClientRequest { Err(e) => return e.into(), }; - RequestSender::Owned(slf.head) - .send_body(slf.addr, slf.response_decompress, slf.timeout, slf.config.as_ref(), body) + RequestSender::Owned(slf.head).send_body( + slf.addr, + slf.response_decompress, + slf.timeout, + slf.config.as_ref(), + body, + ) } /// Set a JSON body and generate `ClientRequest` - pub fn send_json( - self, - value: &T, - ) -> SendBody - { + pub fn send_json(self, value: &T) -> SendClientRequest { let slf = match self.prep_for_sending() { Ok(slf) => slf, Err(e) => return e.into(), }; - RequestSender::Owned(slf.head) - .send_json(slf.addr, slf.response_decompress, slf.timeout, slf.config.as_ref(), value) + RequestSender::Owned(slf.head).send_json( + slf.addr, + slf.response_decompress, + slf.timeout, + slf.config.as_ref(), + value, + ) } /// Set a urlencoded body and generate `ClientRequest` /// /// `ClientRequestBuilder` can not be used after this call. - pub fn send_form( - self, - value: &T, - ) -> SendBody - { + pub fn send_form(self, value: &T) -> SendClientRequest { let slf = match self.prep_for_sending() { Ok(slf) => slf, Err(e) => return e.into(), }; - RequestSender::Owned(slf.head) - .send_form(slf.addr, slf.response_decompress, slf.timeout, slf.config.as_ref(), value) + RequestSender::Owned(slf.head).send_form( + slf.addr, + slf.response_decompress, + slf.timeout, + slf.config.as_ref(), + value, + ) } /// Set an streaming body and generate `ClientRequest`. - pub fn send_stream( - self, - stream: S, - ) -> SendBody + pub fn send_stream(self, stream: S) -> SendClientRequest where - S: Stream + 'static, + S: Stream> + Unpin + 'static, E: Into + 'static, { let slf = match self.prep_for_sending() { @@ -455,22 +481,28 @@ impl ClientRequest { Err(e) => return e.into(), }; - RequestSender::Owned(slf.head) - .send_stream(slf.addr, slf.response_decompress, slf.timeout, slf.config.as_ref(), stream) + RequestSender::Owned(slf.head).send_stream( + slf.addr, + slf.response_decompress, + slf.timeout, + slf.config.as_ref(), + stream, + ) } /// Set an empty body and generate `ClientRequest`. - pub fn send( - self, - ) -> SendBody - { + pub fn send(self) -> SendClientRequest { let slf = match self.prep_for_sending() { Ok(slf) => slf, Err(e) => return e.into(), }; - RequestSender::Owned(slf.head) - .send(slf.addr, slf.response_decompress, slf.timeout, slf.config.as_ref()) + RequestSender::Owned(slf.head).send( + slf.addr, + slf.response_decompress, + slf.timeout, + slf.config.as_ref(), + ) } fn prep_for_sending(mut self) -> Result { @@ -482,9 +514,9 @@ impl ClientRequest { let uri = &self.head.uri; if uri.host().is_none() { return Err(InvalidUrl::MissingHost.into()); - } else if uri.scheme_part().is_none() { + } else if uri.scheme().is_none() { return Err(InvalidUrl::MissingScheme.into()); - } else if let Some(scheme) = uri.scheme_part() { + } else if let Some(scheme) = uri.scheme() { match scheme.as_str() { "http" | "ws" | "https" | "wss" => (), _ => return Err(InvalidUrl::UnknownScheme.into()), @@ -509,31 +541,23 @@ impl ClientRequest { let mut slf = self; - // enable br only for https - #[cfg(any( - feature = "brotli", - feature = "flate2-zlib", - feature = "flate2-rust" - ))] - { - if slf.response_decompress { - let https = slf - .head - .uri - .scheme_part() - .map(|s| s == &uri::Scheme::HTTPS) - .unwrap_or(true); + if slf.response_decompress { + let https = slf + .head + .uri + .scheme() + .map(|s| s == &uri::Scheme::HTTPS) + .unwrap_or(true); - if https { - slf = slf.set_header_if_none(header::ACCEPT_ENCODING, HTTPS_ENCODING) - } else { - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] - { - slf = slf - .set_header_if_none(header::ACCEPT_ENCODING, "gzip, deflate") - } - }; - } + if https { + slf = slf.set_header_if_none(header::ACCEPT_ENCODING, HTTPS_ENCODING) + } else { + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + { + slf = + slf.set_header_if_none(header::ACCEPT_ENCODING, "gzip, deflate") + } + }; } Ok(slf) @@ -541,7 +565,7 @@ impl ClientRequest { } impl fmt::Debug for ClientRequest { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, "\nClientRequest {:?} {}:{}", @@ -555,441 +579,6 @@ impl fmt::Debug for ClientRequest { } } -#[derive(Clone)] -pub struct FrozenClientRequest { - pub(crate) head: Rc, - pub(crate) addr: Option, - pub(crate) response_decompress: bool, - pub(crate) timeout: Option, - pub(crate) config: Rc, -} - -impl FrozenClientRequest { - /// Get HTTP URI of request - pub fn get_uri(&self) -> &Uri { - &self.head.uri - } - - /// Get HTTP method of this request - pub fn get_method(&self) -> &Method { - &self.head.method - } - - /// Returns request's headers. - pub fn headers(&self) -> &HeaderMap { - &self.head.headers - } - - /// Send a body. - pub fn send_body( - &self, - body: B, - ) -> SendBody - where - B: Into, - { - RequestSender::Rc(self.head.clone(), None) - .send_body(self.addr, self.response_decompress, self.timeout, self.config.as_ref(), body) - } - - /// Send a json body. - pub fn send_json( - &self, - value: &T, - ) -> SendBody - { - RequestSender::Rc(self.head.clone(), None) - .send_json(self.addr, self.response_decompress, self.timeout, self.config.as_ref(), value) - } - - /// Send an urlencoded body. - pub fn send_form( - &self, - value: &T, - ) -> SendBody - { - RequestSender::Rc(self.head.clone(), None) - .send_form(self.addr, self.response_decompress, self.timeout, self.config.as_ref(), value) - } - - /// Send a streaming body. - pub fn send_stream( - &self, - stream: S, - ) -> SendBody - where - S: Stream + 'static, - E: Into + 'static, - { - RequestSender::Rc(self.head.clone(), None) - .send_stream(self.addr, self.response_decompress, self.timeout, self.config.as_ref(), stream) - } - - /// Send an empty body. - pub fn send( - &self, - ) -> SendBody - { - RequestSender::Rc(self.head.clone(), None) - .send(self.addr, self.response_decompress, self.timeout, self.config.as_ref()) - } - - /// Create a `FrozenSendBuilder` with extra headers - pub fn extra_headers(&self, extra_headers: HeaderMap) -> FrozenSendBuilder { - FrozenSendBuilder::new(self.clone(), extra_headers) - } - - /// Create a `FrozenSendBuilder` with an extra header - pub fn extra_header(&self, key: K, value: V) -> FrozenSendBuilder - where - HeaderName: HttpTryFrom, - V: IntoHeaderValue, - { - self.extra_headers(HeaderMap::new()).extra_header(key, value) - } -} - -pub struct FrozenSendBuilder { - req: FrozenClientRequest, - extra_headers: HeaderMap, - err: Option, -} - -impl FrozenSendBuilder { - pub(crate) fn new(req: FrozenClientRequest, extra_headers: HeaderMap) -> Self { - Self { - req, - extra_headers, - err: None, - } - } - - /// Insert a header, it overrides existing header in `FrozenClientRequest`. - pub fn extra_header(mut self, key: K, value: V) -> Self - where - HeaderName: HttpTryFrom, - V: IntoHeaderValue, - { - match HeaderName::try_from(key) { - Ok(key) => match value.try_into() { - Ok(value) => self.extra_headers.insert(key, value), - Err(e) => self.err = Some(e.into()), - }, - Err(e) => self.err = Some(e.into()), - } - self - } - - /// Complete request construction and send a body. - pub fn send_body( - self, - body: B, - ) -> SendBody - where - B: Into, - { - if let Some(e) = self.err { - return e.into() - } - - RequestSender::Rc(self.req.head, Some(self.extra_headers)) - .send_body(self.req.addr, self.req.response_decompress, self.req.timeout, self.req.config.as_ref(), body) - } - - /// Complete request construction and send a json body. - pub fn send_json( - self, - value: &T, - ) -> SendBody - { - if let Some(e) = self.err { - return e.into() - } - - RequestSender::Rc(self.req.head, Some(self.extra_headers)) - .send_json(self.req.addr, self.req.response_decompress, self.req.timeout, self.req.config.as_ref(), value) - } - - /// Complete request construction and send an urlencoded body. - pub fn send_form( - self, - value: &T, - ) -> SendBody - { - if let Some(e) = self.err { - return e.into() - } - - RequestSender::Rc(self.req.head, Some(self.extra_headers)) - .send_form(self.req.addr, self.req.response_decompress, self.req.timeout, self.req.config.as_ref(), value) - } - - /// Complete request construction and send a streaming body. - pub fn send_stream( - self, - stream: S, - ) -> SendBody - where - S: Stream + 'static, - E: Into + 'static, - { - if let Some(e) = self.err { - return e.into() - } - - RequestSender::Rc(self.req.head, Some(self.extra_headers)) - .send_stream(self.req.addr, self.req.response_decompress, self.req.timeout, self.req.config.as_ref(), stream) - } - - /// Complete request construction and send an empty body. - pub fn send( - self, - ) -> SendBody - { - if let Some(e) = self.err { - return e.into() - } - - RequestSender::Rc(self.req.head, Some(self.extra_headers)) - .send(self.req.addr, self.req.response_decompress, self.req.timeout, self.req.config.as_ref()) - } -} - -#[derive(Debug, From)] -enum PrepForSendingError { - Url(InvalidUrl), - Http(HttpError), -} - -impl Into for PrepForSendingError { - fn into(self) -> FreezeRequestError { - match self { - PrepForSendingError::Url(e) => FreezeRequestError::Url(e), - PrepForSendingError::Http(e) => FreezeRequestError::Http(e), - } - } -} - -impl Into for PrepForSendingError { - fn into(self) -> SendRequestError { - match self { - PrepForSendingError::Url(e) => SendRequestError::Url(e), - PrepForSendingError::Http(e) => SendRequestError::Http(e), - } - } -} - -pub enum SendBody -{ - Fut(Box>, Option, bool), - Err(Option), -} - -impl SendBody -{ - pub fn new( - send: Box>, - response_decompress: bool, - timeout: Option, - ) -> SendBody - { - let delay = timeout.map(|t| Delay::new(Instant::now() + t)); - SendBody::Fut(send, delay, response_decompress) - } -} - -impl Future for SendBody -{ - type Item = ClientResponse>>; - type Error = SendRequestError; - - fn poll(&mut self) -> Poll { - match self { - SendBody::Fut(send, delay, response_decompress) => { - if delay.is_some() { - match delay.poll() { - Ok(Async::NotReady) => (), - _ => return Err(SendRequestError::Timeout), - } - } - - let res = try_ready!(send.poll()) - .map_body(|head, payload| { - if *response_decompress { - Payload::Stream(Decoder::from_headers(payload, &head.headers)) - } else { - Payload::Stream(Decoder::new(payload, ContentEncoding::Identity)) - } - }); - - Ok(Async::Ready(res)) - }, - SendBody::Err(ref mut e) => { - match e.take() { - Some(e) => Err(e.into()), - None => panic!("Attempting to call completed future"), - } - } - } - } -} - - -impl From for SendBody -{ - fn from(e: SendRequestError) -> Self { - SendBody::Err(Some(e)) - } -} - -impl From for SendBody -{ - fn from(e: Error) -> Self { - SendBody::Err(Some(e.into())) - } -} - -impl From for SendBody -{ - fn from(e: HttpError) -> Self { - SendBody::Err(Some(e.into())) - } -} - -impl From for SendBody -{ - fn from(e: PrepForSendingError) -> Self { - SendBody::Err(Some(e.into())) - } -} - -#[derive(Debug)] -enum RequestSender { - Owned(RequestHead), - Rc(Rc, Option), -} - -impl RequestSender { - pub fn send_body( - self, - addr: Option, - response_decompress: bool, - timeout: Option, - config: &ClientConfig, - body: B, - ) -> SendBody - where - B: Into, - { - let mut connector = config.connector.borrow_mut(); - - let fut = match self { - RequestSender::Owned(head) => connector.send_request(head, body.into(), addr), - RequestSender::Rc(head, extra_headers) => connector.send_request_extra(head, extra_headers, body.into(), addr), - }; - - SendBody::new(fut, response_decompress, timeout.or_else(|| config.timeout.clone())) - } - - pub fn send_json( - mut self, - addr: Option, - response_decompress: bool, - timeout: Option, - config: &ClientConfig, - value: &T, - ) -> SendBody - { - let body = match serde_json::to_string(value) { - Ok(body) => body, - Err(e) => return Error::from(e).into(), - }; - - if let Err(e) = self.set_header_if_none(header::CONTENT_TYPE, "application/json") { - return e.into(); - } - - self.send_body(addr, response_decompress, timeout, config, Body::Bytes(Bytes::from(body))) - } - - pub fn send_form( - mut self, - addr: Option, - response_decompress: bool, - timeout: Option, - config: &ClientConfig, - value: &T, - ) -> SendBody - { - let body = match serde_urlencoded::to_string(value) { - Ok(body) => body, - Err(e) => return Error::from(e).into(), - }; - - // set content-type - if let Err(e) = self.set_header_if_none(header::CONTENT_TYPE, "application/x-www-form-urlencoded") { - return e.into(); - } - - self.send_body(addr, response_decompress, timeout, config, Body::Bytes(Bytes::from(body))) - } - - pub fn send_stream( - self, - addr: Option, - response_decompress: bool, - timeout: Option, - config: &ClientConfig, - stream: S, - ) -> SendBody - where - S: Stream + 'static, - E: Into + 'static, - { - self.send_body(addr, response_decompress, timeout, config, Body::from_message(BodyStream::new(stream))) - } - - pub fn send( - self, - addr: Option, - response_decompress: bool, - timeout: Option, - config: &ClientConfig, - ) -> SendBody - { - self.send_body(addr, response_decompress, timeout, config, Body::Empty) - } - - fn set_header_if_none(&mut self, key: HeaderName, value: V) -> Result<(), HttpError> - where - V: IntoHeaderValue, - { - match self { - RequestSender::Owned(head) => { - if !head.headers.contains_key(&key) { - match value.try_into() { - Ok(value) => head.headers.insert(key, value), - Err(e) => return Err(e.into()), - } - } - }, - RequestSender::Rc(head, extra_headers) => { - if !head.headers.contains_key(&key) && !extra_headers.iter().any(|h| h.contains_key(&key)) { - match value.try_into(){ - Ok(v) => { - let h = extra_headers.get_or_insert(HeaderMap::new()); - h.insert(key, v) - }, - Err(e) => return Err(e.into()), - }; - } - } - } - - Ok(()) - } -} - #[cfg(test)] mod tests { use std::time::SystemTime; @@ -1109,4 +698,13 @@ mod tests { "Bearer someS3cr3tAutht0k3n" ); } + + #[test] + fn client_query() { + let req = Client::new() + .get("/") + .query(&[("key1", "val1"), ("key2", "val2")]) + .unwrap(); + assert_eq!(req.get_uri().query().unwrap(), "key1=val1&key2=val2"); + } } diff --git a/awc/src/response.rs b/awc/src/response.rs index d186526de..20093c72d 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -1,9 +1,11 @@ use std::cell::{Ref, RefMut}; use std::fmt; use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use bytes::{Bytes, BytesMut}; -use futures::{Async, Future, Poll, Stream}; +use futures_core::{ready, Future, Stream}; use actix_http::cookie::Cookie; use actix_http::error::{CookieParseError, PayloadError}; @@ -27,11 +29,11 @@ impl HttpMessage for ClientResponse { &self.head.headers } - fn extensions(&self) -> Ref { + fn extensions(&self) -> Ref<'_, Extensions> { self.head.extensions() } - fn extensions_mut(&self) -> RefMut { + fn extensions_mut(&self) -> RefMut<'_, Extensions> { self.head.extensions_mut() } @@ -41,7 +43,7 @@ impl HttpMessage for ClientResponse { /// Load request cookies. #[inline] - fn cookies(&self) -> Result>>, CookieParseError> { + fn cookies(&self) -> Result>>, CookieParseError> { struct Cookies(Vec>); if self.extensions().get::().is_none() { @@ -104,7 +106,7 @@ impl ClientResponse { impl ClientResponse where - S: Stream, + S: Stream>, { /// Loads http response's body. pub fn body(&mut self) -> MessageBody { @@ -125,18 +127,20 @@ where impl Stream for ClientResponse where - S: Stream, + S: Stream> + Unpin, { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - self.payload.poll() + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().payload).poll_next(cx) } } impl fmt::Debug for ClientResponse { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?; writeln!(f, " headers:")?; for (key, val) in self.headers().iter() { @@ -155,7 +159,7 @@ pub struct MessageBody { impl MessageBody where - S: Stream, + S: Stream>, { /// Create `MessageBody` for request. pub fn new(res: &mut ClientResponse) -> MessageBody { @@ -198,23 +202,24 @@ where impl Future for MessageBody where - S: Stream, + S: Stream> + Unpin, { - type Item = Bytes; - type Error = PayloadError; + type Output = Result; - fn poll(&mut self) -> Poll { - if let Some(err) = self.err.take() { - return Err(err); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + if let Some(err) = this.err.take() { + return Poll::Ready(Err(err)); } - if let Some(len) = self.length.take() { - if len > self.fut.as_ref().unwrap().limit { - return Err(PayloadError::Overflow); + if let Some(len) = this.length.take() { + if len > this.fut.as_ref().unwrap().limit { + return Poll::Ready(Err(PayloadError::Overflow)); } } - self.fut.as_mut().unwrap().poll() + Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) } } @@ -233,7 +238,7 @@ pub struct JsonBody { impl JsonBody where - S: Stream, + S: Stream>, U: DeserializeOwned, { /// Create `JsonBody` for request. @@ -279,27 +284,35 @@ where } } -impl Future for JsonBody +impl Unpin for JsonBody where - T: Stream, + T: Stream> + Unpin, U: DeserializeOwned, { - type Item = U; - type Error = JsonPayloadError; +} - fn poll(&mut self) -> Poll { +impl Future for JsonBody +where + T: Stream> + Unpin, + U: DeserializeOwned, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if let Some(err) = self.err.take() { - return Err(err); + return Poll::Ready(Err(err)); } if let Some(len) = self.length.take() { if len > self.fut.as_ref().unwrap().limit { - return Err(JsonPayloadError::Payload(PayloadError::Overflow)); + return Poll::Ready(Err(JsonPayloadError::Payload( + PayloadError::Overflow, + ))); } } - let body = futures::try_ready!(self.fut.as_mut().unwrap().poll()); - Ok(Async::Ready(serde_json::from_slice::(&body)?)) + let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; + Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) } } @@ -321,24 +334,25 @@ impl ReadBody { impl Future for ReadBody where - S: Stream, + S: Stream> + Unpin, { - type Item = Bytes; - type Error = PayloadError; + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); - fn poll(&mut self) -> Poll { loop { - return match self.stream.poll()? { - Async::Ready(Some(chunk)) => { - if (self.buf.len() + chunk.len()) > self.limit { - Err(PayloadError::Overflow) + return match Pin::new(&mut this.stream).poll_next(cx)? { + Poll::Ready(Some(chunk)) => { + if (this.buf.len() + chunk.len()) > this.limit { + Poll::Ready(Err(PayloadError::Overflow)) } else { - self.buf.extend_from_slice(&chunk); + this.buf.extend_from_slice(&chunk); continue; } } - Async::Ready(None) => Ok(Async::Ready(self.buf.take().freeze())), - Async::NotReady => Ok(Async::NotReady), + Poll::Ready(None) => Poll::Ready(Ok(this.buf.split().freeze())), + Poll::Pending => Poll::Pending, }; } } @@ -347,23 +361,21 @@ where #[cfg(test)] mod tests { use super::*; - use actix_http_test::block_on; - use futures::Async; use serde::{Deserialize, Serialize}; use crate::{http::header, test::TestResponse}; - #[test] - fn test_body() { + #[actix_rt::test] + async fn test_body() { let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish(); - match req.body().poll().err().unwrap() { + match req.body().await.err().unwrap() { PayloadError::UnknownLength => (), _ => unreachable!("error"), } let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish(); - match req.body().poll().err().unwrap() { + match req.body().await.err().unwrap() { PayloadError::Overflow => (), _ => unreachable!("error"), } @@ -371,15 +383,12 @@ mod tests { let mut req = TestResponse::default() .set_payload(Bytes::from_static(b"test")) .finish(); - match req.body().poll().ok().unwrap() { - Async::Ready(bytes) => assert_eq!(bytes, Bytes::from_static(b"test")), - _ => unreachable!("error"), - } + assert_eq!(req.body().await.ok().unwrap(), Bytes::from_static(b"test")); let mut req = TestResponse::default() .set_payload(Bytes::from_static(b"11111111111111")) .finish(); - match req.body().limit(5).poll().err().unwrap() { + match req.body().limit(5).await.err().unwrap() { PayloadError::Overflow => (), _ => unreachable!("error"), } @@ -404,10 +413,10 @@ mod tests { } } - #[test] - fn test_json_body() { + #[actix_rt::test] + async fn test_json_body() { let mut req = TestResponse::default().finish(); - let json = block_on(JsonBody::<_, MyObject>::new(&mut req)); + let json = JsonBody::<_, MyObject>::new(&mut req).await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let mut req = TestResponse::default() @@ -416,7 +425,7 @@ mod tests { header::HeaderValue::from_static("application/text"), ) .finish(); - let json = block_on(JsonBody::<_, MyObject>::new(&mut req)); + let json = JsonBody::<_, MyObject>::new(&mut req).await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let mut req = TestResponse::default() @@ -430,7 +439,7 @@ mod tests { ) .finish(); - let json = block_on(JsonBody::<_, MyObject>::new(&mut req).limit(100)); + let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await; assert!(json_eq( json.err().unwrap(), JsonPayloadError::Payload(PayloadError::Overflow) @@ -448,7 +457,7 @@ mod tests { .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .finish(); - let json = block_on(JsonBody::<_, MyObject>::new(&mut req)); + let json = JsonBody::<_, MyObject>::new(&mut req).await; assert_eq!( json.ok().unwrap(), MyObject { diff --git a/awc/src/sender.rs b/awc/src/sender.rs new file mode 100644 index 000000000..ec18f12e3 --- /dev/null +++ b/awc/src/sender.rs @@ -0,0 +1,325 @@ +use std::net; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; +use std::time::Duration; + +use actix_rt::time::{delay_for, Delay}; +use bytes::Bytes; +use derive_more::From; +use futures_core::{Future, Stream}; +use serde::Serialize; +use serde_json; + +use actix_http::body::{Body, BodyStream}; +use actix_http::http::header::{self, IntoHeaderValue}; +use actix_http::http::{Error as HttpError, HeaderMap, HeaderName}; +use actix_http::{Error, RequestHead}; + +#[cfg(feature = "compress")] +use actix_http::encoding::Decoder; +#[cfg(feature = "compress")] +use actix_http::http::header::ContentEncoding; +#[cfg(feature = "compress")] +use actix_http::{Payload, PayloadStream}; + +use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError}; +use crate::response::ClientResponse; +use crate::ClientConfig; + +#[derive(Debug, From)] +pub(crate) enum PrepForSendingError { + Url(InvalidUrl), + Http(HttpError), +} + +impl Into for PrepForSendingError { + fn into(self) -> FreezeRequestError { + match self { + PrepForSendingError::Url(e) => FreezeRequestError::Url(e), + PrepForSendingError::Http(e) => FreezeRequestError::Http(e), + } + } +} + +impl Into for PrepForSendingError { + fn into(self) -> SendRequestError { + match self { + PrepForSendingError::Url(e) => SendRequestError::Url(e), + PrepForSendingError::Http(e) => SendRequestError::Http(e), + } + } +} + +/// Future that sends request's payload and resolves to a server response. +#[must_use = "futures do nothing unless polled"] +pub enum SendClientRequest { + Fut( + Pin>>>, + Option, + bool, + ), + Err(Option), +} + +impl SendClientRequest { + pub(crate) fn new( + send: Pin>>>, + response_decompress: bool, + timeout: Option, + ) -> SendClientRequest { + let delay = timeout.map(delay_for); + SendClientRequest::Fut(send, delay, response_decompress) + } +} + +#[cfg(feature = "compress")] +impl Future for SendClientRequest { + type Output = + Result>>, SendRequestError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + match this { + SendClientRequest::Fut(send, delay, response_decompress) => { + if delay.is_some() { + match Pin::new(delay.as_mut().unwrap()).poll(cx) { + Poll::Pending => (), + _ => return Poll::Ready(Err(SendRequestError::Timeout)), + } + } + + let res = futures_core::ready!(Pin::new(send).poll(cx)).map(|res| { + res.map_body(|head, payload| { + if *response_decompress { + Payload::Stream(Decoder::from_headers( + payload, + &head.headers, + )) + } else { + Payload::Stream(Decoder::new( + payload, + ContentEncoding::Identity, + )) + } + }) + }); + + Poll::Ready(res) + } + SendClientRequest::Err(ref mut e) => match e.take() { + Some(e) => Poll::Ready(Err(e)), + None => panic!("Attempting to call completed future"), + }, + } + } +} + +#[cfg(not(feature = "compress"))] +impl Future for SendClientRequest { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + match this { + SendClientRequest::Fut(send, delay, _) => { + if delay.is_some() { + match Pin::new(delay.as_mut().unwrap()).poll(cx) { + Poll::Pending => (), + _ => return Poll::Ready(Err(SendRequestError::Timeout)), + } + } + Pin::new(send).poll(cx) + } + SendClientRequest::Err(ref mut e) => match e.take() { + Some(e) => Poll::Ready(Err(e)), + None => panic!("Attempting to call completed future"), + }, + } + } +} + +impl From for SendClientRequest { + fn from(e: SendRequestError) -> Self { + SendClientRequest::Err(Some(e)) + } +} + +impl From for SendClientRequest { + fn from(e: Error) -> Self { + SendClientRequest::Err(Some(e.into())) + } +} + +impl From for SendClientRequest { + fn from(e: HttpError) -> Self { + SendClientRequest::Err(Some(e.into())) + } +} + +impl From for SendClientRequest { + fn from(e: PrepForSendingError) -> Self { + SendClientRequest::Err(Some(e.into())) + } +} + +#[derive(Debug)] +pub(crate) enum RequestSender { + Owned(RequestHead), + Rc(Rc, Option), +} + +impl RequestSender { + pub(crate) fn send_body( + self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + body: B, + ) -> SendClientRequest + where + B: Into, + { + let mut connector = config.connector.borrow_mut(); + + let fut = match self { + RequestSender::Owned(head) => { + connector.send_request(head, body.into(), addr) + } + RequestSender::Rc(head, extra_headers) => { + connector.send_request_extra(head, extra_headers, body.into(), addr) + } + }; + + SendClientRequest::new( + fut, + response_decompress, + timeout.or_else(|| config.timeout), + ) + } + + pub(crate) fn send_json( + mut self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + value: &T, + ) -> SendClientRequest { + let body = match serde_json::to_string(value) { + Ok(body) => body, + Err(e) => return Error::from(e).into(), + }; + + if let Err(e) = self.set_header_if_none(header::CONTENT_TYPE, "application/json") + { + return e.into(); + } + + self.send_body( + addr, + response_decompress, + timeout, + config, + Body::Bytes(Bytes::from(body)), + ) + } + + pub(crate) fn send_form( + mut self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + value: &T, + ) -> SendClientRequest { + let body = match serde_urlencoded::to_string(value) { + Ok(body) => body, + Err(e) => return Error::from(e).into(), + }; + + // set content-type + if let Err(e) = self.set_header_if_none( + header::CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) { + return e.into(); + } + + self.send_body( + addr, + response_decompress, + timeout, + config, + Body::Bytes(Bytes::from(body)), + ) + } + + pub(crate) fn send_stream( + self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + stream: S, + ) -> SendClientRequest + where + S: Stream> + Unpin + 'static, + E: Into + 'static, + { + self.send_body( + addr, + response_decompress, + timeout, + config, + Body::from_message(BodyStream::new(stream)), + ) + } + + pub(crate) fn send( + self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + ) -> SendClientRequest { + self.send_body(addr, response_decompress, timeout, config, Body::Empty) + } + + fn set_header_if_none( + &mut self, + key: HeaderName, + value: V, + ) -> Result<(), HttpError> + where + V: IntoHeaderValue, + { + match self { + RequestSender::Owned(head) => { + if !head.headers.contains_key(&key) { + match value.try_into() { + Ok(value) => head.headers.insert(key, value), + Err(e) => return Err(e.into()), + } + } + } + RequestSender::Rc(head, extra_headers) => { + if !head.headers.contains_key(&key) + && !extra_headers.iter().any(|h| h.contains_key(&key)) + { + match value.try_into() { + Ok(v) => { + let h = extra_headers.get_or_insert(HeaderMap::new()); + h.insert(key, v) + } + Err(e) => return Err(e.into()), + }; + } + } + } + + Ok(()) + } +} diff --git a/awc/src/test.rs b/awc/src/test.rs index 641ecaa88..a6cbd03e6 100644 --- a/awc/src/test.rs +++ b/awc/src/test.rs @@ -1,9 +1,10 @@ //! Test helpers for actix http client to use during testing. +use std::convert::TryFrom; use std::fmt::Write as FmtWrite; use actix_http::cookie::{Cookie, CookieJar, USERINFO}; use actix_http::http::header::{self, Header, HeaderValue, IntoHeaderValue}; -use actix_http::http::{HeaderName, HttpTryFrom, StatusCode, Version}; +use actix_http::http::{Error as HttpError, HeaderName, StatusCode, Version}; use actix_http::{h1, Payload, ResponseHead}; use bytes::Bytes; use percent_encoding::percent_encode; @@ -31,7 +32,8 @@ impl TestResponse { /// Create TestResponse and set header pub fn with_header(key: K, value: V) -> Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { Self::default().header(key, value) @@ -55,7 +57,8 @@ impl TestResponse { /// Append a header pub fn header(mut self, key: K, value: V) -> Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { if let Ok(key) = HeaderName::try_from(key) { diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 67be9e9d8..89ca50b59 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -1,4 +1,5 @@ //! Websockets client +use std::convert::TryFrom; use std::fmt::Write as FmtWrite; use std::net::SocketAddr; use std::rc::Rc; @@ -7,9 +8,8 @@ use std::{fmt, str}; use actix_codec::Framed; use actix_http::cookie::{Cookie, CookieJar}; use actix_http::{ws, Payload, RequestHead}; -use futures::future::{err, Either, Future}; +use actix_rt::time::timeout; use percent_encoding::percent_encode; -use tokio_timer::Timeout; use actix_http::cookie::USERINFO; pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message}; @@ -20,7 +20,7 @@ use crate::http::header::{ self, HeaderName, HeaderValue, IntoHeaderValue, AUTHORIZATION, }; use crate::http::{ - ConnectionType, Error as HttpError, HttpTryFrom, Method, StatusCode, Uri, Version, + ConnectionType, Error as HttpError, Method, StatusCode, Uri, Version, }; use crate::response::ClientResponse; use crate::ClientConfig; @@ -42,7 +42,8 @@ impl WebsocketsRequest { /// Create new websocket connection pub(crate) fn new(uri: U, config: Rc) -> Self where - Uri: HttpTryFrom, + Uri: TryFrom, + >::Error: Into, { let mut err = None; let mut head = RequestHead::default(); @@ -103,9 +104,10 @@ impl WebsocketsRequest { } /// Set request Origin - pub fn origin(mut self, origin: V) -> Self + pub fn origin(mut self, origin: V) -> Self where - HeaderValue: HttpTryFrom, + HeaderValue: TryFrom, + HttpError: From, { match HeaderValue::try_from(origin) { Ok(value) => self.origin = Some(value), @@ -134,7 +136,8 @@ impl WebsocketsRequest { /// To override header use `set_header()` method. pub fn header(mut self, key: K, value: V) -> Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { match HeaderName::try_from(key) { @@ -152,7 +155,8 @@ impl WebsocketsRequest { /// Insert a header, replaces existing header. pub fn set_header(mut self, key: K, value: V) -> Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { match HeaderName::try_from(key) { @@ -170,7 +174,8 @@ impl WebsocketsRequest { /// Insert a header only if it is not yet set. pub fn set_header_if_none(mut self, key: K, value: V) -> Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { match HeaderName::try_from(key) { @@ -210,31 +215,33 @@ impl WebsocketsRequest { } /// Complete request construction and connect to a websockets server. - pub fn connect( + pub async fn connect( mut self, - ) -> impl Future), Error = WsClientError> - { + ) -> Result<(ClientResponse, Framed), WsClientError> { if let Some(e) = self.err.take() { - return Either::A(err(e.into())); + return Err(e.into()); } // validate uri let uri = &self.head.uri; if uri.host().is_none() { - return Either::A(err(InvalidUrl::MissingHost.into())); - } else if uri.scheme_part().is_none() { - return Either::A(err(InvalidUrl::MissingScheme.into())); - } else if let Some(scheme) = uri.scheme_part() { + return Err(InvalidUrl::MissingHost.into()); + } else if uri.scheme().is_none() { + return Err(InvalidUrl::MissingScheme.into()); + } else if let Some(scheme) = uri.scheme() { match scheme.as_str() { "http" | "ws" | "https" | "wss" => (), - _ => return Either::A(err(InvalidUrl::UnknownScheme.into())), + _ => return Err(InvalidUrl::UnknownScheme.into()), } } else { - return Either::A(err(InvalidUrl::UnknownScheme.into())); + return Err(InvalidUrl::UnknownScheme.into()); } if !self.head.headers.contains_key(header::HOST) { - self.head.headers.insert(header::HOST, HeaderValue::from_str(uri.host().unwrap()).unwrap()); + self.head.headers.insert( + header::HOST, + HeaderValue::from_str(uri.host().unwrap()).unwrap(), + ); } // set cookies @@ -291,95 +298,88 @@ impl WebsocketsRequest { .config .connector .borrow_mut() - .open_tunnel(head, self.addr) - .from_err() - .and_then(move |(head, framed)| { - // verify response - if head.status != StatusCode::SWITCHING_PROTOCOLS { - return Err(WsClientError::InvalidResponseStatus(head.status)); - } - // Check for "UPGRADE" to websocket header - let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) { - if let Ok(s) = hdr.to_str() { - s.to_ascii_lowercase().contains("websocket") - } else { - false - } - } else { - false - }; - if !has_hdr { - log::trace!("Invalid upgrade header"); - return Err(WsClientError::InvalidUpgradeHeader); - } - // Check for "CONNECTION" header - if let Some(conn) = head.headers.get(&header::CONNECTION) { - if let Ok(s) = conn.to_str() { - if !s.to_ascii_lowercase().contains("upgrade") { - log::trace!("Invalid connection header: {}", s); - return Err(WsClientError::InvalidConnectionHeader( - conn.clone(), - )); - } - } else { - log::trace!("Invalid connection header: {:?}", conn); - return Err(WsClientError::InvalidConnectionHeader( - conn.clone(), - )); - } - } else { - log::trace!("Missing connection header"); - return Err(WsClientError::MissingConnectionHeader); - } - - if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) { - let encoded = ws::hash_key(key.as_ref()); - if hdr_key.as_bytes() != encoded.as_bytes() { - log::trace!( - "Invalid challenge response: expected: {} received: {:?}", - encoded, - key - ); - return Err(WsClientError::InvalidChallengeResponse( - encoded, - hdr_key.clone(), - )); - } - } else { - log::trace!("Missing SEC-WEBSOCKET-ACCEPT header"); - return Err(WsClientError::MissingWebSocketAcceptHeader); - }; - - // response and ws framed - Ok(( - ClientResponse::new(head, Payload::None), - framed.map_codec(|_| { - if server_mode { - ws::Codec::new().max_size(max_size) - } else { - ws::Codec::new().max_size(max_size).client_mode() - } - }), - )) - }); + .open_tunnel(head, self.addr); // set request timeout - if let Some(timeout) = self.config.timeout { - Either::B(Either::A(Timeout::new(fut, timeout).map_err(|e| { - if let Some(e) = e.into_inner() { - e - } else { - SendRequestError::Timeout.into() - } - }))) + let (head, framed) = if let Some(to) = self.config.timeout { + timeout(to, fut) + .await + .map_err(|_| SendRequestError::Timeout) + .and_then(|res| res)? } else { - Either::B(Either::B(fut)) + fut.await? + }; + + // verify response + if head.status != StatusCode::SWITCHING_PROTOCOLS { + return Err(WsClientError::InvalidResponseStatus(head.status)); } + + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + log::trace!("Invalid upgrade header"); + return Err(WsClientError::InvalidUpgradeHeader); + } + + // Check for "CONNECTION" header + if let Some(conn) = head.headers.get(&header::CONNECTION) { + if let Ok(s) = conn.to_str() { + if !s.to_ascii_lowercase().contains("upgrade") { + log::trace!("Invalid connection header: {}", s); + return Err(WsClientError::InvalidConnectionHeader(conn.clone())); + } + } else { + log::trace!("Invalid connection header: {:?}", conn); + return Err(WsClientError::InvalidConnectionHeader(conn.clone())); + } + } else { + log::trace!("Missing connection header"); + return Err(WsClientError::MissingConnectionHeader); + } + + if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) { + let encoded = ws::hash_key(key.as_ref()); + if hdr_key.as_bytes() != encoded.as_bytes() { + log::trace!( + "Invalid challenge response: expected: {} received: {:?}", + encoded, + key + ); + return Err(WsClientError::InvalidChallengeResponse( + encoded, + hdr_key.clone(), + )); + } + } else { + log::trace!("Missing SEC-WEBSOCKET-ACCEPT header"); + return Err(WsClientError::MissingWebSocketAcceptHeader); + }; + + // response and ws framed + Ok(( + ClientResponse::new(head, Payload::None), + framed.map_codec(|_| { + if server_mode { + ws::Codec::new().max_size(max_size) + } else { + ws::Codec::new().max_size(max_size).client_mode() + } + }), + )) } } impl fmt::Debug for WebsocketsRequest { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, "\nWebsocketsRequest {}:{}", @@ -398,16 +398,16 @@ mod tests { use super::*; use crate::Client; - #[test] - fn test_debug() { + #[actix_rt::test] + async fn test_debug() { let request = Client::new().ws("/").header("x-test", "111"); let repr = format!("{:?}", request); assert!(repr.contains("WebsocketsRequest")); assert!(repr.contains("x-test")); } - #[test] - fn test_header_override() { + #[actix_rt::test] + async fn test_header_override() { let req = Client::build() .header(header::CONTENT_TYPE, "111") .finish() @@ -425,8 +425,8 @@ mod tests { ); } - #[test] - fn basic_auth() { + #[actix_rt::test] + async fn basic_auth() { let req = Client::new() .ws("/") .basic_auth("username", Some("password")); @@ -452,8 +452,8 @@ mod tests { ); } - #[test] - fn bearer_auth() { + #[actix_rt::test] + async fn bearer_auth() { let req = Client::new().ws("/").bearer_auth("someS3cr3tAutht0k3n"); assert_eq!( req.head @@ -467,8 +467,8 @@ mod tests { let _ = req.connect(); } - #[test] - fn basics() { + #[actix_rt::test] + async fn basics() { let req = Client::new() .ws("http://localhost/") .origin("test-origin") @@ -490,14 +490,10 @@ mod tests { header::HeaderValue::from_static("json") ); - let _ = actix_http_test::block_fn(move || req.connect()); + let _ = req.connect().await; - assert!(Client::new().ws("/").connect().poll().is_err()); - assert!(Client::new().ws("http:///test").connect().poll().is_err()); - assert!(Client::new() - .ws("hmm://test.com/") - .connect() - .poll() - .is_err()); + assert!(Client::new().ws("/").connect().await.is_err()); + assert!(Client::new().ws("http:///test").connect().await.is_err()); + assert!(Client::new().ws("hmm://test.com/").connect().await.is_err()); } } diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs index cb38c7315..8fb04b005 100644 --- a/awc/tests/test_client.rs +++ b/awc/tests/test_client.rs @@ -9,15 +9,18 @@ use bytes::Bytes; use flate2::read::GzDecoder; use flate2::write::GzEncoder; use flate2::Compression; -use futures::Future; +use futures::future::ok; use rand::Rng; use actix_http::HttpService; -use actix_http_test::TestServer; -use actix_service::{service_fn, NewService}; +use actix_http_test::test_server; +use actix_service::{map_config, pipeline_factory}; +use actix_web::dev::{AppConfig, BodyEncoding}; use actix_web::http::Cookie; -use actix_web::middleware::{BodyEncoding, Compress}; -use actix_web::{http::header, web, App, Error, HttpMessage, HttpRequest, HttpResponse}; +use actix_web::middleware::Compress; +use actix_web::{ + http::header, test, web, App, Error, HttpMessage, HttpRequest, HttpResponse, +}; use awc::error::SendRequestError; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ @@ -42,99 +45,104 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World"; -#[test] -fn test_simple() { - let mut srv = - TestServer::new(|| { - HttpService::new(App::new().service( - web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR))), - )) - }); +#[actix_rt::test] +async fn test_simple() { + let srv = test::start(|| { + App::new() + .service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) + }); let request = srv.get("/").header("x-test", "111").send(); - let mut response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); - let mut response = srv.block_on(srv.post("/").send()).unwrap(); + let mut response = srv.post("/").send().await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); // camel case - let response = srv.block_on(srv.post("/").camel_case().send()).unwrap(); + let response = srv.post("/").camel_case().send().await.unwrap(); assert!(response.status().is_success()); } -#[test] -fn test_json() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service( +#[actix_rt::test] +async fn test_json() { + let srv = test::start(|| { + App::new().service( web::resource("/").route(web::to(|_: web::Json| HttpResponse::Ok())), - )) + ) }); let request = srv .get("/") .header("x-test", "111") .send_json(&"TEST".to_string()); - let response = srv.block_on(request).unwrap(); + let response = request.await.unwrap(); assert!(response.status().is_success()); } -#[test] -fn test_form() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to( +#[actix_rt::test] +async fn test_form() { + let srv = test::start(|| { + App::new().service(web::resource("/").route(web::to( |_: web::Form>| HttpResponse::Ok(), - )))) + ))) }); let mut data = HashMap::new(); let _ = data.insert("key".to_string(), "TEST".to_string()); let request = srv.get("/").header("x-test", "111").send_form(&data); - let response = srv.block_on(request).unwrap(); + let response = request.await.unwrap(); assert!(response.status().is_success()); } -#[test] -fn test_timeout() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to_async( - || { - tokio_timer::sleep(Duration::from_millis(200)) - .then(|_| Ok::<_, Error>(HttpResponse::Ok().body(STR))) - }, - )))) +#[actix_rt::test] +async fn test_timeout() { + let srv = test::start(|| { + App::new().service(web::resource("/").route(web::to(|| { + async { + actix_rt::time::delay_for(Duration::from_millis(200)).await; + Ok::<_, Error>(HttpResponse::Ok().body(STR)) + } + }))) }); - let client = srv.execute(|| { - awc::Client::build() - .timeout(Duration::from_millis(50)) - .finish() - }); + let connector = awc::Connector::new() + .connector(actix_connect::new_connector( + actix_connect::start_default_resolver(), + )) + .timeout(Duration::from_secs(15)) + .finish(); + + let client = awc::Client::build() + .connector(connector) + .timeout(Duration::from_millis(50)) + .finish(); + let request = client.get(srv.url("/")).send(); - match srv.block_on(request) { + match request.await { Err(SendRequestError::Timeout) => (), _ => panic!(), } } -#[test] -fn test_timeout_override() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to_async( - || { - tokio_timer::sleep(Duration::from_millis(200)) - .then(|_| Ok::<_, Error>(HttpResponse::Ok().body(STR))) - }, - )))) +#[actix_rt::test] +async fn test_timeout_override() { + let srv = test::start(|| { + App::new().service(web::resource("/").route(web::to(|| { + async { + actix_rt::time::delay_for(Duration::from_millis(200)).await; + Ok::<_, Error>(HttpResponse::Ok().body(STR)) + } + }))) }); let client = awc::Client::build() @@ -144,125 +152,145 @@ fn test_timeout_override() { .get(srv.url("/")) .timeout(Duration::from_millis(50)) .send(); - match srv.block_on(request) { + match request.await { Err(SendRequestError::Timeout) => (), _ => panic!(), } } -#[test] -fn test_connection_reuse() { +#[actix_rt::test] +async fn test_connection_reuse() { let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone(); - let mut srv = TestServer::new(move || { + let srv = test_server(move || { let num2 = num2.clone(); - service_fn(move |io| { + pipeline_factory(move |io| { num2.fetch_add(1, Ordering::Relaxed); - Ok(io) + ok(io) }) - .and_then(HttpService::new( - App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok()))), - )) + .and_then( + HttpService::new(map_config( + App::new() + .service(web::resource("/").route(web::to(|| HttpResponse::Ok()))), + |_| AppConfig::default(), + )) + .tcp(), + ) }); let client = awc::Client::default(); // req 1 let request = client.get(srv.url("/")).send(); - let response = srv.block_on(request).unwrap(); + let response = request.await.unwrap(); assert!(response.status().is_success()); // req 2 let req = client.post(srv.url("/")); - let response = srv.block_on_fn(move || req.send()).unwrap(); + let response = req.send().await.unwrap(); assert!(response.status().is_success()); // one connection assert_eq!(num.load(Ordering::Relaxed), 1); } -#[test] -fn test_connection_force_close() { +#[actix_rt::test] +async fn test_connection_force_close() { let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone(); - let mut srv = TestServer::new(move || { + let srv = test_server(move || { let num2 = num2.clone(); - service_fn(move |io| { + pipeline_factory(move |io| { num2.fetch_add(1, Ordering::Relaxed); - Ok(io) + ok(io) }) - .and_then(HttpService::new( - App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok()))), - )) + .and_then( + HttpService::new(map_config( + App::new() + .service(web::resource("/").route(web::to(|| HttpResponse::Ok()))), + |_| AppConfig::default(), + )) + .tcp(), + ) }); let client = awc::Client::default(); // req 1 let request = client.get(srv.url("/")).force_close().send(); - let response = srv.block_on(request).unwrap(); + let response = request.await.unwrap(); assert!(response.status().is_success()); // req 2 let req = client.post(srv.url("/")).force_close(); - let response = srv.block_on_fn(move || req.send()).unwrap(); + let response = req.send().await.unwrap(); assert!(response.status().is_success()); // two connection assert_eq!(num.load(Ordering::Relaxed), 2); } -#[test] -fn test_connection_server_close() { +#[actix_rt::test] +async fn test_connection_server_close() { let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone(); - let mut srv = TestServer::new(move || { + let srv = test_server(move || { let num2 = num2.clone(); - service_fn(move |io| { + pipeline_factory(move |io| { num2.fetch_add(1, Ordering::Relaxed); - Ok(io) + ok(io) }) - .and_then(HttpService::new( - App::new().service( - web::resource("/") - .route(web::to(|| HttpResponse::Ok().force_close().finish())), - ), - )) + .and_then( + HttpService::new(map_config( + App::new().service( + web::resource("/") + .route(web::to(|| HttpResponse::Ok().force_close().finish())), + ), + |_| AppConfig::default(), + )) + .tcp(), + ) }); let client = awc::Client::default(); // req 1 let request = client.get(srv.url("/")).send(); - let response = srv.block_on(request).unwrap(); + let response = request.await.unwrap(); assert!(response.status().is_success()); // req 2 let req = client.post(srv.url("/")); - let response = srv.block_on_fn(move || req.send()).unwrap(); + let response = req.send().await.unwrap(); assert!(response.status().is_success()); // two connection assert_eq!(num.load(Ordering::Relaxed), 2); } -#[test] -fn test_connection_wait_queue() { +#[actix_rt::test] +async fn test_connection_wait_queue() { let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone(); - let mut srv = TestServer::new(move || { + let srv = test_server(move || { let num2 = num2.clone(); - service_fn(move |io| { + pipeline_factory(move |io| { num2.fetch_add(1, Ordering::Relaxed); - Ok(io) + ok(io) }) - .and_then(HttpService::new(App::new().service( - web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR))), - ))) + .and_then( + HttpService::new(map_config( + App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR))), + ), + |_| AppConfig::default(), + )) + .tcp(), + ) }); let client = awc::Client::build() @@ -271,46 +299,46 @@ fn test_connection_wait_queue() { // req 1 let request = client.get(srv.url("/")).send(); - let mut response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // req 2 let req2 = client.post(srv.url("/")); - let req2_fut = srv.execute(move || { - let mut fut = req2.send(); - assert!(fut.poll().unwrap().is_not_ready()); - fut - }); + let req2_fut = req2.send(); // read response 1 - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); // req 2 - let response = srv.block_on(req2_fut).unwrap(); + let response = req2_fut.await.unwrap(); assert!(response.status().is_success()); // two connection assert_eq!(num.load(Ordering::Relaxed), 1); } -#[test] -fn test_connection_wait_queue_force_close() { +#[actix_rt::test] +async fn test_connection_wait_queue_force_close() { let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone(); - let mut srv = TestServer::new(move || { + let srv = test_server(move || { let num2 = num2.clone(); - service_fn(move |io| { + pipeline_factory(move |io| { num2.fetch_add(1, Ordering::Relaxed); - Ok(io) + ok(io) }) - .and_then(HttpService::new( - App::new().service( - web::resource("/") - .route(web::to(|| HttpResponse::Ok().force_close().body(STR))), - ), - )) + .and_then( + HttpService::new(map_config( + App::new().service( + web::resource("/") + .route(web::to(|| HttpResponse::Ok().force_close().body(STR))), + ), + |_| AppConfig::default(), + )) + .tcp(), + ) }); let client = awc::Client::build() @@ -319,68 +347,67 @@ fn test_connection_wait_queue_force_close() { // req 1 let request = client.get(srv.url("/")).send(); - let mut response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // req 2 let req2 = client.post(srv.url("/")); - let req2_fut = srv.execute(move || { - let mut fut = req2.send(); - assert!(fut.poll().unwrap().is_not_ready()); - fut - }); + let req2_fut = req2.send(); // read response 1 - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); // req 2 - let response = srv.block_on(req2_fut).unwrap(); + let response = req2_fut.await.unwrap(); assert!(response.status().is_success()); // two connection assert_eq!(num.load(Ordering::Relaxed), 2); } -#[test] -fn test_with_query_parameter() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").to( - |req: HttpRequest| { - if req.query_string().contains("qp") { - HttpResponse::Ok() - } else { - HttpResponse::BadRequest() - } - }, - ))) +#[actix_rt::test] +async fn test_with_query_parameter() { + let srv = test::start(|| { + App::new().service(web::resource("/").to(|req: HttpRequest| { + if req.query_string().contains("qp") { + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + })) }); - let res = srv - .block_on(awc::Client::new().get(srv.url("/?qp=5")).send()) + let res = awc::Client::new() + .get(srv.url("/?qp=5")) + .send() + .await .unwrap(); assert!(res.status().is_success()); } -#[test] -fn test_no_decompress() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().wrap(Compress::default()).service( - web::resource("/").route(web::to(|| { +#[actix_rt::test] +async fn test_no_decompress() { + let srv = test::start(|| { + App::new() + .wrap(Compress::default()) + .service(web::resource("/").route(web::to(|| { let mut res = HttpResponse::Ok().body(STR); res.encoding(header::ContentEncoding::Gzip); res - })), - )) + }))) }); - let mut res = srv - .block_on(awc::Client::new().get(srv.url("/")).no_decompress().send()) + let mut res = awc::Client::new() + .get(srv.url("/")) + .no_decompress() + .send() + .await .unwrap(); assert!(res.status().is_success()); // read response - let bytes = srv.block_on(res.body()).unwrap(); + let bytes = res.body().await.unwrap(); let mut e = GzDecoder::new(&bytes[..]); let mut dec = Vec::new(); @@ -388,22 +415,25 @@ fn test_no_decompress() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); // POST - let mut res = srv - .block_on(awc::Client::new().post(srv.url("/")).no_decompress().send()) + let mut res = awc::Client::new() + .post(srv.url("/")) + .no_decompress() + .send() + .await .unwrap(); assert!(res.status().is_success()); - let bytes = srv.block_on(res.body()).unwrap(); + let bytes = res.body().await.unwrap(); let mut e = GzDecoder::new(&bytes[..]); let mut dec = Vec::new(); e.read_to_end(&mut dec).unwrap(); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[test] -fn test_client_gzip_encoding() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to(|| { +#[actix_rt::test] +async fn test_client_gzip_encoding() { + let srv = test::start(|| { + App::new().service(web::resource("/").route(web::to(|| { let mut e = GzEncoder::new(Vec::new(), Compression::default()); e.write_all(STR.as_ref()).unwrap(); let data = e.finish().unwrap(); @@ -411,22 +441,22 @@ fn test_client_gzip_encoding() { HttpResponse::Ok() .header("content-encoding", "gzip") .body(data) - })))) + }))) }); // client request - let mut response = srv.block_on(srv.post("/").send()).unwrap(); + let mut response = srv.post("/").send().await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -fn test_client_gzip_encoding_large() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to(|| { +#[actix_rt::test] +async fn test_client_gzip_encoding_large() { + let srv = test::start(|| { + App::new().service(web::resource("/").route(web::to(|| { let mut e = GzEncoder::new(Vec::new(), Compression::default()); e.write_all(STR.repeat(10).as_ref()).unwrap(); let data = e.finish().unwrap(); @@ -434,109 +464,98 @@ fn test_client_gzip_encoding_large() { HttpResponse::Ok() .header("content-encoding", "gzip") .body(data) - })))) + }))) }); // client request - let mut response = srv.block_on(srv.post("/").send()).unwrap(); + let mut response = srv.post("/").send().await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from(STR.repeat(10))); } -#[test] -fn test_client_gzip_encoding_large_random() { +#[actix_rt::test] +async fn test_client_gzip_encoding_large_random() { let data = rand::thread_rng() .sample_iter(&rand::distributions::Alphanumeric) .take(100_000) .collect::(); - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to( - |data: Bytes| { - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(&data).unwrap(); - let data = e.finish().unwrap(); - HttpResponse::Ok() - .header("content-encoding", "gzip") - .body(data) - }, - )))) + let srv = test::start(|| { + App::new().service(web::resource("/").route(web::to(|data: Bytes| { + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(&data).unwrap(); + let data = e.finish().unwrap(); + HttpResponse::Ok() + .header("content-encoding", "gzip") + .body(data) + }))) }); // client request - let mut response = srv.block_on(srv.post("/").send_body(data.clone())).unwrap(); + let mut response = srv.post("/").send_body(data.clone()).await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from(data)); } -#[test] -fn test_client_brotli_encoding() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().service(web::resource("/").route(web::to( - |data: Bytes| { - let mut e = BrotliEncoder::new(Vec::new(), 5); - e.write_all(&data).unwrap(); - let data = e.finish().unwrap(); - HttpResponse::Ok() - .header("content-encoding", "br") - .body(data) - }, - )))) +#[actix_rt::test] +async fn test_client_brotli_encoding() { + let srv = test::start(|| { + App::new().service(web::resource("/").route(web::to(|data: Bytes| { + let mut e = BrotliEncoder::new(Vec::new(), 5); + e.write_all(&data).unwrap(); + let data = e.finish().unwrap(); + HttpResponse::Ok() + .header("content-encoding", "br") + .body(data) + }))) }); // client request - let mut response = srv.block_on(srv.post("/").send_body(STR)).unwrap(); + let mut response = srv.post("/").send_body(STR).await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -// #[test] -// fn test_client_brotli_encoding_large_random() { -// let data = rand::thread_rng() -// .sample_iter(&rand::distributions::Alphanumeric) -// .take(70_000) -// .collect::(); +#[actix_rt::test] +async fn test_client_brotli_encoding_large_random() { + let data = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(70_000) + .collect::(); -// let mut srv = test::TestServer::new(|app| { -// app.handler(|req: &HttpRequest| { -// req.body() -// .and_then(move |bytes: Bytes| { -// Ok(HttpResponse::Ok() -// .content_encoding(http::ContentEncoding::Gzip) -// .body(bytes)) -// }) -// .responder() -// }) -// }); + let srv = test::start(|| { + App::new().service(web::resource("/").route(web::to(|data: Bytes| { + let mut e = BrotliEncoder::new(Vec::new(), 5); + e.write_all(&data).unwrap(); + let data = e.finish().unwrap(); + HttpResponse::Ok() + .header("content-encoding", "br") + .body(data) + }))) + }); -// // client request -// let request = srv -// .client(http::Method::POST, "/") -// .content_encoding(http::ContentEncoding::Br) -// .body(data.clone()) -// .unwrap(); -// let response = srv.execute(request.send()).unwrap(); -// assert!(response.status().is_success()); + // client request + let mut response = srv.post("/").send_body(data.clone()).await.unwrap(); + assert!(response.status().is_success()); -// // read response -// let bytes = srv.execute(response.body()).unwrap(); -// assert_eq!(bytes.len(), data.len()); -// assert_eq!(bytes, Bytes::from(data)); -// } + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes.len(), data.len()); + assert_eq!(bytes, Bytes::from(data)); +} -// #[cfg(feature = "brotli")] -// #[test] -// fn test_client_deflate_encoding() { -// let mut srv = test::TestServer::new(|app| { +// #[actix_rt::test] +// async fn test_client_deflate_encoding() { +// let srv = test::TestServer::start(|app| { // app.handler(|req: &HttpRequest| { // req.body() // .and_then(|bytes: Bytes| { @@ -562,14 +581,14 @@ fn test_client_brotli_encoding() { // assert_eq!(bytes, Bytes::from_static(STR.as_ref())); // } -// #[test] -// fn test_client_deflate_encoding_large_random() { +// #[actix_rt::test] +// async fn test_client_deflate_encoding_large_random() { // let data = rand::thread_rng() // .sample_iter(&rand::distributions::Alphanumeric) // .take(70_000) // .collect::(); -// let mut srv = test::TestServer::new(|app| { +// let srv = test::TestServer::start(|app| { // app.handler(|req: &HttpRequest| { // req.body() // .and_then(|bytes: Bytes| { @@ -595,9 +614,9 @@ fn test_client_brotli_encoding() { // assert_eq!(bytes, Bytes::from(data)); // } -// #[test] -// fn test_client_streaming_explicit() { -// let mut srv = test::TestServer::new(|app| { +// #[actix_rt::test] +// async fn test_client_streaming_explicit() { +// let srv = test::TestServer::start(|app| { // app.handler(|req: &HttpRequest| { // req.body() // .map_err(Error::from) @@ -622,9 +641,9 @@ fn test_client_brotli_encoding() { // assert_eq!(bytes, Bytes::from_static(STR.as_ref())); // } -// #[test] -// fn test_body_streaming_implicit() { -// let mut srv = test::TestServer::new(|app| { +// #[actix_rt::test] +// async fn test_body_streaming_implicit() { +// let srv = test::TestServer::start(|app| { // app.handler(|_| { // let body = once(Ok(Bytes::from_static(STR.as_ref()))); // HttpResponse::Ok() @@ -642,13 +661,10 @@ fn test_client_brotli_encoding() { // assert_eq!(bytes, Bytes::from_static(STR.as_ref())); // } -#[test] -fn test_client_cookie_handling() { - fn err() -> Error { - use std::io::{Error as IoError, ErrorKind}; - // stub some generic error - Error::from(IoError::from(ErrorKind::NotFound)) - } +#[actix_rt::test] +async fn test_client_cookie_handling() { + use std::io::{Error as IoError, ErrorKind}; + let cookie1 = Cookie::build("cookie1", "value1").finish(); let cookie2 = Cookie::build("cookie2", "value2") .domain("www.example.org") @@ -660,44 +676,53 @@ fn test_client_cookie_handling() { let cookie1b = cookie1.clone(); let cookie2b = cookie2.clone(); - let mut srv = TestServer::new(move || { + let srv = test::start(move || { let cookie1 = cookie1b.clone(); let cookie2 = cookie2b.clone(); - HttpService::new(App::new().route( + App::new().route( "/", web::to(move |req: HttpRequest| { - // Check cookies were sent correctly - req.cookie("cookie1") - .ok_or_else(err) - .and_then(|c1| { - if c1.value() == "value1" { - Ok(()) - } else { - Err(err()) - } - }) - .and_then(|()| req.cookie("cookie2").ok_or_else(err)) - .and_then(|c2| { - if c2.value() == "value2" { - Ok(()) - } else { - Err(err()) - } - }) - // Send some cookies back - .map(|_| { - HttpResponse::Ok() - .cookie(cookie1.clone()) - .cookie(cookie2.clone()) - .finish() - }) + let cookie1 = cookie1.clone(); + let cookie2 = cookie2.clone(); + + async move { + // Check cookies were sent correctly + let res: Result<(), Error> = req + .cookie("cookie1") + .ok_or(()) + .and_then(|c1| { + if c1.value() == "value1" { + Ok(()) + } else { + Err(()) + } + }) + .and_then(|()| req.cookie("cookie2").ok_or(())) + .and_then(|c2| { + if c2.value() == "value2" { + Ok(()) + } else { + Err(()) + } + }) + .map_err(|_| Error::from(IoError::from(ErrorKind::NotFound))); + + if let Err(e) = res { + Err(e) + } else { + // Send some cookies back + Ok::<_, Error>( + HttpResponse::Ok().cookie(cookie1).cookie(cookie2).finish(), + ) + } + } }), - )) + ) }); let request = srv.get("/").cookie(cookie1.clone()).cookie(cookie2.clone()); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); let c1 = response.cookie("cookie1").expect("Missing cookie1"); assert_eq!(c1, cookie1); @@ -705,7 +730,7 @@ fn test_client_cookie_handling() { assert_eq!(c2, cookie2); } -// #[test] +// #[actix_rt::test] // fn client_read_until_eof() { // let addr = test::TestServer::unused_addr(); @@ -727,18 +752,18 @@ fn test_client_cookie_handling() { // let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()) // .finish() // .unwrap(); -// let response = sys.block_on(req.send()).unwrap(); +// let response = req.send().await.unwrap(); // assert!(response.status().is_success()); // // read response -// let bytes = sys.block_on(response.body()).unwrap(); +// let bytes = response.body().await.unwrap(); // assert_eq!(bytes, Bytes::from_static(b"welcome!")); // } -#[test] -fn client_basic_auth() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().route( +#[actix_rt::test] +async fn client_basic_auth() { + let srv = test::start(|| { + App::new().route( "/", web::to(|req: HttpRequest| { if req @@ -754,19 +779,19 @@ fn client_basic_auth() { HttpResponse::BadRequest() } }), - )) + ) }); // set authorization header to Basic let request = srv.get("/").basic_auth("username", Some("password")); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); } -#[test] -fn client_bearer_auth() { - let mut srv = TestServer::new(|| { - HttpService::new(App::new().route( +#[actix_rt::test] +async fn client_bearer_auth() { + let srv = test::start(|| { + App::new().route( "/", web::to(|req: HttpRequest| { if req @@ -782,11 +807,11 @@ fn client_bearer_auth() { HttpResponse::BadRequest() } }), - )) + ) }); // set authorization header to Bearer let request = srv.get("/").bearer_auth("someS3cr3tAutht0k3n"); - let response = srv.block_on(request.send()).unwrap(); + let response = request.send().await.unwrap(); assert!(response.status().is_success()); } diff --git a/awc/tests/test_rustls_client.rs b/awc/tests/test_rustls_client.rs index e65e4e874..1d7eb7bc5 100644 --- a/awc/tests/test_rustls_client.rs +++ b/awc/tests/test_rustls_client.rs @@ -1,69 +1,74 @@ -#![cfg(feature = "rust-tls")] -use rustls::{ - internal::pemfile::{certs, pkcs8_private_keys}, - ClientConfig, NoClientAuth, -}; - -use std::fs::File; -use std::io::{BufReader, Result}; +#![cfg(feature = "rustls")] use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use actix_codec::{AsyncRead, AsyncWrite}; use actix_http::HttpService; -use actix_http_test::TestServer; -use actix_server::ssl::RustlsAcceptor; -use actix_service::{service_fn, NewService}; +use actix_http_test::test_server; +use actix_service::{map_config, pipeline_factory, ServiceFactory}; use actix_web::http::Version; -use actix_web::{web, App, HttpResponse}; +use actix_web::{dev::AppConfig, web, App, HttpResponse}; +use futures::future::ok; +use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslVerifyMode}; +use rust_tls::ClientConfig; -fn ssl_acceptor() -> Result> { - use rustls::ServerConfig; +fn ssl_acceptor() -> SslAcceptor { // load ssl keys - let mut config = ServerConfig::new(NoClientAuth::new()); - let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap()); - let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap()); - let cert_chain = certs(cert_file).unwrap(); - let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - let protos = vec![b"h2".to_vec()]; - config.set_protocols(&protos); - Ok(RustlsAcceptor::new(config)) + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder.set_verify_callback(SslVerifyMode::NONE, |_, _| true); + builder + .set_private_key_file("../tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("../tests/cert.pem") + .unwrap(); + builder.set_alpn_select_callback(|_, protos| { + const H2: &[u8] = b"\x02h2"; + if protos.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else { + Err(open_ssl::ssl::AlpnError::NOACK) + } + }); + builder.set_alpn_protos(b"\x02h2").unwrap(); + builder.build() } mod danger { pub struct NoCertificateVerification {} - impl rustls::ServerCertVerifier for NoCertificateVerification { + impl rust_tls::ServerCertVerifier for NoCertificateVerification { fn verify_server_cert( &self, - _roots: &rustls::RootCertStore, - _presented_certs: &[rustls::Certificate], + _roots: &rust_tls::RootCertStore, + _presented_certs: &[rust_tls::Certificate], _dns_name: webpki::DNSNameRef<'_>, _ocsp: &[u8], - ) -> Result { - Ok(rustls::ServerCertVerified::assertion()) + ) -> Result { + Ok(rust_tls::ServerCertVerified::assertion()) } } } -#[test] -fn test_connection_reuse_h2() { - let rustls = ssl_acceptor().unwrap(); +// #[actix_rt::test] +async fn _test_connection_reuse_h2() { let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone(); - let mut srv = TestServer::new(move || { + let srv = test_server(move || { let num2 = num2.clone(); - service_fn(move |io| { + pipeline_factory(move |io| { num2.fetch_add(1, Ordering::Relaxed); - Ok(io) + ok(io) }) - .and_then(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) .and_then( HttpService::build() - .h2(App::new() - .service(web::resource("/").route(web::to(|| HttpResponse::Ok())))) + .h2(map_config( + App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok())), + ), + |_| AppConfig::default(), + )) + .openssl(ssl_acceptor()) .map_err(|_| ()), ) }); @@ -82,12 +87,12 @@ fn test_connection_reuse_h2() { // req 1 let request = client.get(srv.surl("/")).send(); - let response = srv.block_on(request).unwrap(); + let response = request.await.unwrap(); assert!(response.status().is_success()); // req 2 let req = client.post(srv.surl("/")); - let response = srv.block_on_fn(move || req.send()).unwrap(); + let response = req.send().await.unwrap(); assert!(response.status().is_success()); assert_eq!(response.version(), Version::HTTP_2); diff --git a/awc/tests/test_ssl_client.rs b/awc/tests/test_ssl_client.rs index e6b0101b2..d3995b4be 100644 --- a/awc/tests/test_ssl_client.rs +++ b/awc/tests/test_ssl_client.rs @@ -1,19 +1,16 @@ -#![cfg(feature = "ssl")] -use openssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode}; - -use std::io::Result; +#![cfg(feature = "openssl")] use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use actix_codec::{AsyncRead, AsyncWrite}; use actix_http::HttpService; -use actix_http_test::TestServer; -use actix_server::ssl::OpensslAcceptor; -use actix_service::{service_fn, NewService}; +use actix_http_test::test_server; +use actix_service::{map_config, pipeline_factory, ServiceFactory}; use actix_web::http::Version; -use actix_web::{web, App, HttpResponse}; +use actix_web::{dev::AppConfig, web, App, HttpResponse}; +use futures::future::ok; +use open_ssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode}; -fn ssl_acceptor() -> Result> { +fn ssl_acceptor() -> SslAcceptor { // load ssl keys let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); builder @@ -27,34 +24,33 @@ fn ssl_acceptor() -> Result> { if protos.windows(3).any(|window| window == H2) { Ok(b"h2") } else { - Err(openssl::ssl::AlpnError::NOACK) + Err(open_ssl::ssl::AlpnError::NOACK) } }); - builder.set_alpn_protos(b"\x02h2")?; - Ok(actix_server::ssl::OpensslAcceptor::new(builder.build())) + builder.set_alpn_protos(b"\x02h2").unwrap(); + builder.build() } -#[test] -fn test_connection_reuse_h2() { - let openssl = ssl_acceptor().unwrap(); +#[actix_rt::test] +async fn test_connection_reuse_h2() { let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone(); - let mut srv = TestServer::new(move || { + let srv = test_server(move || { let num2 = num2.clone(); - service_fn(move |io| { + pipeline_factory(move |io| { num2.fetch_add(1, Ordering::Relaxed); - Ok(io) + ok(io) }) - .and_then( - openssl - .clone() - .map_err(|e| println!("Openssl error: {}", e)), - ) .and_then( HttpService::build() - .h2(App::new() - .service(web::resource("/").route(web::to(|| HttpResponse::Ok())))) + .h2(map_config( + App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok())), + ), + |_| AppConfig::default(), + )) + .openssl(ssl_acceptor()) .map_err(|_| ()), ) }); @@ -72,12 +68,12 @@ fn test_connection_reuse_h2() { // req 1 let request = client.get(srv.surl("/")).send(); - let response = srv.block_on(request).unwrap(); + let response = request.await.unwrap(); assert!(response.status().is_success()); // req 2 let req = client.post(srv.surl("/")); - let response = srv.block_on_fn(move || req.send()).unwrap(); + let response = req.send().await.unwrap(); assert!(response.status().is_success()); assert_eq!(response.version(), Version::HTTP_2); diff --git a/awc/tests/test_ws.rs b/awc/tests/test_ws.rs index 5abf96355..ee937e43e 100644 --- a/awc/tests/test_ws.rs +++ b/awc/tests/test_ws.rs @@ -2,81 +2,69 @@ use std::io; use actix_codec::Framed; use actix_http::{body::BodySize, h1, ws, Error, HttpService, Request, Response}; -use actix_http_test::TestServer; -use bytes::{Bytes, BytesMut}; +use actix_http_test::test_server; +use bytes::Bytes; use futures::future::ok; -use futures::{Future, Sink, Stream}; +use futures::{SinkExt, StreamExt}; -fn ws_service(req: ws::Frame) -> impl Future { +async fn ws_service(req: ws::Frame) -> Result { match req { - ws::Frame::Ping(msg) => ok(ws::Message::Pong(msg)), - ws::Frame::Text(text) => { - let text = if let Some(pl) = text { - String::from_utf8(Vec::from(pl.as_ref())).unwrap() - } else { - String::new() - }; - ok(ws::Message::Text(text)) - } - ws::Frame::Binary(bin) => ok(ws::Message::Binary( - bin.map(|e| e.freeze()) - .unwrap_or_else(|| Bytes::from("")) - .into(), + ws::Frame::Ping(msg) => Ok(ws::Message::Pong(msg)), + ws::Frame::Text(text) => Ok(ws::Message::Text( + String::from_utf8(Vec::from(text.as_ref())).unwrap(), )), - ws::Frame::Close(reason) => ok(ws::Message::Close(reason)), - _ => ok(ws::Message::Close(None)), + ws::Frame::Binary(bin) => Ok(ws::Message::Binary(bin)), + ws::Frame::Close(reason) => Ok(ws::Message::Close(reason)), + _ => Ok(ws::Message::Close(None)), } } -#[test] -fn test_simple() { - let mut srv = TestServer::new(|| { +#[actix_rt::test] +async fn test_simple() { + let mut srv = test_server(|| { HttpService::build() - .upgrade(|(req, framed): (Request, Framed<_, _>)| { - let res = ws::handshake_response(req.head()).finish(); - // send handshake response - framed - .send(h1::Message::Item((res.drop_body(), BodySize::None))) - .map_err(|e: io::Error| e.into()) - .and_then(|framed| { - // start websocket service - let framed = framed.into_framed(ws::Codec::new()); - ws::Transport::with(framed, ws_service) - }) + .upgrade(|(req, mut framed): (Request, Framed<_, _>)| { + async move { + let res = ws::handshake_response(req.head()).finish(); + // send handshake response + framed + .send(h1::Message::Item((res.drop_body(), BodySize::None))) + .await?; + + // start websocket service + let framed = framed.into_framed(ws::Codec::new()); + ws::Dispatcher::with(framed, ws_service).await + } }) .finish(|_| ok::<_, Error>(Response::NotFound())) + .tcp() }); // client service - let framed = srv.ws().unwrap(); - let framed = srv - .block_on(framed.send(ws::Message::Text("text".to_string()))) + let mut framed = srv.ws().await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); - let framed = srv - .block_on(framed.send(ws::Message::Binary("text".into()))) + framed + .send(ws::Message::Binary("text".into())) + .await .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!( - item, - Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) - ); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text"))); - let framed = srv - .block_on(framed.send(ws::Message::Ping("text".into()))) - .unwrap(); - let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Pong("text".to_string().into())); - let framed = srv - .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await .unwrap(); - let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); - assert_eq!( - item, - Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) - ); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into()))); } diff --git a/examples/basic.rs b/examples/basic.rs index 46440d706..bd6f8146f 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -1,26 +1,23 @@ -use futures::IntoFuture; - -use actix_web::{ - get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer, -}; +use actix_web::{get, middleware, web, App, HttpRequest, HttpResponse, HttpServer}; #[get("/resource1/{name}/index.html")] -fn index(req: HttpRequest, name: web::Path) -> String { +async fn index(req: HttpRequest, name: web::Path) -> String { println!("REQ: {:?}", req); format!("Hello: {}!\r\n", name) } -fn index_async(req: HttpRequest) -> impl IntoFuture { +async fn index_async(req: HttpRequest) -> &'static str { println!("REQ: {:?}", req); - Ok("Hello world!\r\n") -} - -#[get("/")] -fn no_params() -> &'static str { "Hello world!\r\n" } -fn main() -> std::io::Result<()> { +#[get("/")] +async fn no_params() -> &'static str { + "Hello world!\r\n" +} + +#[actix_rt::main] +async fn main() -> std::io::Result<()> { std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info"); env_logger::init(); @@ -39,11 +36,12 @@ fn main() -> std::io::Result<()> { .default_service( web::route().to(|| HttpResponse::MethodNotAllowed()), ) - .route(web::get().to_async(index_async)), + .route(web::get().to(index_async)), ) - .service(web::resource("/test1.html").to(|| "Test\r\n")) + .service(web::resource("/test1.html").to(|| async { "Test\r\n" })) }) .bind("127.0.0.1:8080")? .workers(1) .run() + .await } diff --git a/examples/client.rs b/examples/client.rs index 8a75fd306..874e08e1b 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,26 +1,25 @@ use actix_http::Error; -use actix_rt::System; -use futures::{future::lazy, Future}; -fn main() -> Result<(), Error> { +#[actix_rt::main] +async fn main() -> Result<(), Error> { std::env::set_var("RUST_LOG", "actix_http=trace"); env_logger::init(); - System::new("test").block_on(lazy(|| { - awc::Client::new() - .get("https://www.rust-lang.org/") // <- Create request builder - .header("User-Agent", "Actix-web") - .send() // <- Send http request - .from_err() - .and_then(|mut response| { - // <- server http response - println!("Response: {:?}", response); + let client = awc::Client::new(); - // read response body - response - .body() - .from_err() - .map(|body| println!("Downloaded: {:?} bytes", body.len())) - }) - })) + // Create request builder, configure request and send + let mut response = client + .get("https://www.rust-lang.org/") + .header("User-Agent", "Actix-web") + .send() + .await?; + + // server http response + println!("Response: {:?}", response); + + // read response body + let body = response.body().await?; + println!("Downloaded: {:?} bytes", body.len()); + + Ok(()) } diff --git a/examples/uds.rs b/examples/uds.rs index 9dc82903f..77f245d99 100644 --- a/examples/uds.rs +++ b/examples/uds.rs @@ -1,27 +1,26 @@ -use futures::IntoFuture; - use actix_web::{ get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer, }; #[get("/resource1/{name}/index.html")] -fn index(req: HttpRequest, name: web::Path) -> String { +async fn index(req: HttpRequest, name: web::Path) -> String { println!("REQ: {:?}", req); format!("Hello: {}!\r\n", name) } -fn index_async(req: HttpRequest) -> impl IntoFuture { +async fn index_async(req: HttpRequest) -> Result<&'static str, Error> { println!("REQ: {:?}", req); Ok("Hello world!\r\n") } #[get("/")] -fn no_params() -> &'static str { +async fn no_params() -> &'static str { "Hello world!\r\n" } -#[cfg(feature = "uds")] -fn main() -> std::io::Result<()> { +#[cfg(unix)] +#[actix_rt::main] +async fn main() -> std::io::Result<()> { std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info"); env_logger::init(); @@ -40,14 +39,15 @@ fn main() -> std::io::Result<()> { .default_service( web::route().to(|| HttpResponse::MethodNotAllowed()), ) - .route(web::get().to_async(index_async)), + .route(web::get().to(index_async)), ) - .service(web::resource("/test1.html").to(|| "Test\r\n")) + .service(web::resource("/test1.html").to(|| async { "Test\r\n" })) }) .bind_uds("/Users/fafhrd91/uds-test")? .workers(1) .run() + .await } -#[cfg(not(feature = "uds"))] +#[cfg(not(unix))] fn main() {} diff --git a/src/app.rs b/src/app.rs index f93859c7e..a060eb53e 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,43 +1,45 @@ use std::cell::RefCell; use std::fmt; +use std::future::Future; use std::marker::PhantomData; use std::rc::Rc; use actix_http::body::{Body, MessageBody}; -use actix_service::boxed::{self, BoxedNewService}; +use actix_http::Extensions; +use actix_service::boxed::{self, BoxServiceFactory}; use actix_service::{ - apply_transform, IntoNewService, IntoTransform, NewService, Transform, + apply, apply_fn_factory, IntoServiceFactory, ServiceFactory, Transform, }; -use futures::{Future, IntoFuture}; +use futures::future::{FutureExt, LocalBoxFuture}; use crate::app_service::{AppEntry, AppInit, AppRoutingFactory}; -use crate::config::{AppConfig, AppConfigInner, ServiceConfig}; +use crate::config::ServiceConfig; use crate::data::{Data, DataFactory}; use crate::dev::ResourceDef; use crate::error::Error; use crate::resource::Resource; use crate::route::Route; use crate::service::{ - HttpServiceFactory, ServiceFactory, ServiceFactoryWrapper, ServiceRequest, + AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest, ServiceResponse, }; -type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; +type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; type FnDataFactory = - Box Box, Error = ()>>>; + Box LocalBoxFuture<'static, Result, ()>>>; /// Application builder - structure that follows the builder pattern /// for building application instances. pub struct App { endpoint: T, - services: Vec>, + services: Vec>, default: Option>, factory_ref: Rc>>, data: Vec>, data_factories: Vec, - config: AppConfigInner, external: Vec, - _t: PhantomData<(B)>, + extensions: Extensions, + _t: PhantomData, } impl App { @@ -51,8 +53,8 @@ impl App { services: Vec::new(), default: None, factory_ref: fref, - config: AppConfigInner::default(), external: Vec::new(), + extensions: Extensions::new(), _t: PhantomData, } } @@ -61,7 +63,7 @@ impl App { impl App where B: MessageBody, - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -76,28 +78,28 @@ where /// an application instance. Http server constructs an application /// instance for each thread, thus application data must be constructed /// multiple times. If you want to share data between different - /// threads, a shared object should be used, e.g. `Arc`. Application - /// data does not need to be `Send` or `Sync`. + /// threads, a shared object should be used, e.g. `Arc`. Internally `Data` type + /// uses `Arc` so data could be created outside of app factory and clones could + /// be stored via `App::app_data()` method. /// /// ```rust /// use std::cell::Cell; - /// use actix_web::{web, App}; + /// use actix_web::{web, App, HttpResponse, Responder}; /// /// struct MyData { /// counter: Cell, /// } /// - /// fn index(data: web::Data) { + /// async fn index(data: web::Data) -> impl Responder { /// data.counter.set(data.counter.get() + 1); + /// HttpResponse::Ok() /// } /// - /// fn main() { - /// let app = App::new() - /// .data(MyData{ counter: Cell::new(0) }) - /// .service( - /// web::resource("/index.html").route( - /// web::get().to(index))); - /// } + /// let app = App::new() + /// .data(MyData{ counter: Cell::new(0) }) + /// .service( + /// web::resource("/index.html").route( + /// web::get().to(index))); /// ``` pub fn data(mut self, data: U) -> Self { self.data.push(Box::new(Data::new(data))); @@ -107,32 +109,43 @@ where /// Set application data factory. This function is /// similar to `.data()` but it accepts data factory. Data object get /// constructed asynchronously during application initialization. - pub fn data_factory(mut self, data: F) -> Self + pub fn data_factory(mut self, data: F) -> Self where F: Fn() -> Out + 'static, - Out: IntoFuture + 'static, - Out::Error: std::fmt::Debug, + Out: Future> + 'static, + D: 'static, + E: std::fmt::Debug, { self.data_factories.push(Box::new(move || { - Box::new( - data() - .into_future() - .map_err(|e| { - log::error!("Can not construct data instance: {:?}", e); - }) - .map(|data| { - let data: Box = Box::new(Data::new(data)); - data - }), - ) + { + let fut = data(); + async move { + match fut.await { + Err(e) => { + log::error!("Can not construct data instance: {:?}", e); + Err(()) + } + Ok(data) => { + let data: Box = Box::new(Data::new(data)); + Ok(data) + } + } + } + } + .boxed_local() })); self } - /// Set application data. Application data could be accessed - /// by using `Data` extractor where `T` is data type. - pub fn register_data(mut self, data: Data) -> Self { - self.data.push(Box::new(data)); + /// Set application level arbitrary data item. + /// + /// Application data stored with `App::app_data()` method is available + /// via `HttpRequest::app_data()` method at runtime. + /// + /// This method could be used for storing `Data` as well, in that case + /// data could be accessed by using `Data` extractor. + pub fn app_data(mut self, ext: U) -> Self { + self.extensions.insert(ext); self } @@ -183,7 +196,7 @@ where /// ```rust /// use actix_web::{web, App, HttpResponse}; /// - /// fn index(data: web::Path<(String, String)>) -> &'static str { + /// async fn index(data: web::Path<(String, String)>) -> &'static str { /// "Welcome!" /// } /// @@ -219,18 +232,6 @@ where self } - /// Set server host name. - /// - /// Host name is used by application router as a hostname for url - /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. - /// html#method.host) documentation for more information. - /// - /// By default host name is set to a "localhost" value. - pub fn hostname(mut self, val: &str) -> Self { - self.config.host = val.to_owned(); - self - } - /// Default service to be used if no matching resource could be found. /// /// It is possible to use services like `Resource`, `Route`. @@ -238,7 +239,7 @@ where /// ```rust /// use actix_web::{web, App, HttpResponse}; /// - /// fn index() -> &'static str { + /// async fn index() -> &'static str { /// "Welcome!" /// } /// @@ -267,8 +268,8 @@ where /// ``` pub fn default_service(mut self, f: F) -> Self where - F: IntoNewService, - U: NewService< + F: IntoServiceFactory, + U: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -277,11 +278,9 @@ where U::InitError: fmt::Debug, { // create and configure default resource - self.default = Some(Rc::new(boxed::new_service( - f.into_new_service().map_init_err(|e| { - log::error!("Can not construct default service: {:?}", e) - }), - ))); + self.default = Some(Rc::new(boxed::factory(f.into_factory().map_init_err( + |e| log::error!("Can not construct default service: {:?}", e), + )))); self } @@ -295,7 +294,7 @@ where /// ```rust /// use actix_web::{web, App, HttpRequest, HttpResponse, Result}; /// - /// fn index(req: HttpRequest) -> Result { + /// async fn index(req: HttpRequest) -> Result { /// let url = req.url_for("youtube", &["asdlkjqme"])?; /// assert_eq!(url.as_str(), "https://youtube.com/watch/asdlkjqme"); /// Ok(HttpResponse::Ok().into()) @@ -336,11 +335,10 @@ where /// /// ```rust /// use actix_service::Service; - /// # use futures::Future; /// use actix_web::{middleware, web, App}; /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; /// - /// fn index() -> &'static str { + /// async fn index() -> &'static str { /// "Welcome!" /// } /// @@ -350,11 +348,11 @@ where /// .route("/index.html", web::get().to(index)); /// } /// ``` - pub fn wrap( + pub fn wrap( self, - mw: F, + mw: M, ) -> App< - impl NewService< + impl ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -372,18 +370,16 @@ where InitError = (), >, B1: MessageBody, - F: IntoTransform, { - let endpoint = apply_transform(mw, self.endpoint); App { - endpoint, + endpoint: apply(mw, self.endpoint), data: self.data, data_factories: self.data_factories, services: self.services, default: self.default, factory_ref: self.factory_ref, - config: self.config, external: self.external, + extensions: self.extensions, _t: PhantomData, } } @@ -397,23 +393,25 @@ where /// /// ```rust /// use actix_service::Service; - /// # use futures::Future; /// use actix_web::{web, App}; /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; /// - /// fn index() -> &'static str { + /// async fn index() -> &'static str { /// "Welcome!" /// } /// /// fn main() { /// let app = App::new() - /// .wrap_fn(|req, srv| - /// srv.call(req).map(|mut res| { + /// .wrap_fn(|req, srv| { + /// let fut = srv.call(req); + /// async { + /// let mut res = fut.await?; /// res.headers_mut().insert( /// CONTENT_TYPE, HeaderValue::from_static("text/plain"), /// ); - /// res - /// })) + /// Ok(res) + /// } + /// }) /// .route("/index.html", web::get().to(index)); /// } /// ``` @@ -421,7 +419,7 @@ where self, mw: F, ) -> App< - impl NewService< + impl ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -433,16 +431,26 @@ where where B1: MessageBody, F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, - R: IntoFuture, Error = Error>, + R: Future, Error>>, { - self.wrap(mw) + App { + endpoint: apply_fn_factory(self.endpoint, mw), + data: self.data, + data_factories: self.data_factories, + services: self.services, + default: self.default, + factory_ref: self.factory_ref, + external: self.external, + extensions: self.extensions, + _t: PhantomData, + } } } -impl IntoNewService> for App +impl IntoServiceFactory> for App where B: MessageBody, - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -450,7 +458,7 @@ where InitError = (), >, { - fn into_new_service(self) -> AppInit { + fn into_factory(self) -> AppInit { AppInit { data: Rc::new(self.data), data_factories: Rc::new(self.data_factories), @@ -459,7 +467,7 @@ where external: RefCell::new(self.external), default: self.default, factory_ref: self.factory_ref, - config: RefCell::new(AppConfig(Rc::new(self.config))), + extensions: RefCell::new(Some(self.extensions)), } } } @@ -468,27 +476,27 @@ where mod tests { use actix_service::Service; use bytes::Bytes; - use futures::{Future, IntoFuture}; + use futures::future::ok; use super::*; use crate::http::{header, HeaderValue, Method, StatusCode}; - use crate::service::{ServiceRequest, ServiceResponse}; - use crate::test::{ - block_fn, block_on, call_service, init_service, read_body, TestRequest, - }; - use crate::{web, Error, HttpRequest, HttpResponse}; + use crate::middleware::DefaultHeaders; + use crate::service::ServiceRequest; + use crate::test::{call_service, init_service, read_body, TestRequest}; + use crate::{web, HttpRequest, HttpResponse}; - #[test] - fn test_default_resource() { + #[actix_rt::test] + async fn test_default_resource() { let mut srv = init_service( App::new().service(web::resource("/test").to(|| HttpResponse::Ok())), - ); + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = block_fn(|| srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/blah").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); let mut srv = init_service( @@ -497,76 +505,79 @@ mod tests { .service( web::resource("/test2") .default_service(|r: ServiceRequest| { - r.into_response(HttpResponse::Created()) + ok(r.into_response(HttpResponse::Created())) }) .route(web::get().to(|| HttpResponse::Ok())), ) .default_service(|r: ServiceRequest| { - r.into_response(HttpResponse::MethodNotAllowed()) + ok(r.into_response(HttpResponse::MethodNotAllowed())) }), - ); + ) + .await; let req = TestRequest::with_uri("/blah").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); let req = TestRequest::with_uri("/test2").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test2") .method(Method::POST) .to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::CREATED); } - #[test] - fn test_data_factory() { + #[actix_rt::test] + async fn test_data_factory() { let mut srv = - init_service(App::new().data_factory(|| Ok::<_, ()>(10usize)).service( + init_service(App::new().data_factory(|| ok::<_, ()>(10usize)).service( web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); + )) + .await; let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let mut srv = - init_service(App::new().data_factory(|| Ok::<_, ()>(10u32)).service( + init_service(App::new().data_factory(|| ok::<_, ()>(10u32)).service( web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); + )) + .await; let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } - fn md( - req: ServiceRequest, - srv: &mut S, - ) -> impl IntoFuture, Error = Error> - where - S: Service< - Request = ServiceRequest, - Response = ServiceResponse, - Error = Error, - >, - { - srv.call(req).map(|mut res| { - res.headers_mut() - .insert(header::CONTENT_TYPE, HeaderValue::from_static("0001")); - res - }) + #[actix_rt::test] + async fn test_extension() { + let mut srv = init_service(App::new().app_data(10usize).service( + web::resource("/").to(|req: HttpRequest| { + assert_eq!(*req.app_data::().unwrap(), 10); + HttpResponse::Ok() + }), + )) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn test_wrap() { + #[actix_rt::test] + async fn test_wrap() { let mut srv = init_service( App::new() - .wrap(md) + .wrap( + DefaultHeaders::new() + .header(header::CONTENT_TYPE, HeaderValue::from_static("0001")), + ) .route("/test", web::get().to(|| HttpResponse::Ok())), - ); + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -574,15 +585,19 @@ mod tests { ); } - #[test] - fn test_router_wrap() { + #[actix_rt::test] + async fn test_router_wrap() { let mut srv = init_service( App::new() .route("/test", web::get().to(|| HttpResponse::Ok())) - .wrap(md), - ); + .wrap( + DefaultHeaders::new() + .header(header::CONTENT_TYPE, HeaderValue::from_static("0001")), + ), + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -590,23 +605,26 @@ mod tests { ); } - #[test] - fn test_wrap_fn() { + #[actix_rt::test] + async fn test_wrap_fn() { let mut srv = init_service( App::new() .wrap_fn(|req, srv| { - srv.call(req).map(|mut res| { + let fut = srv.call(req); + async move { + let mut res = fut.await?; res.headers_mut().insert( header::CONTENT_TYPE, HeaderValue::from_static("0001"), ); - res - }) + Ok(res) + } }) .service(web::resource("/test").to(|| HttpResponse::Ok())), - ); + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -614,23 +632,26 @@ mod tests { ); } - #[test] - fn test_router_wrap_fn() { + #[actix_rt::test] + async fn test_router_wrap_fn() { let mut srv = init_service( App::new() .route("/test", web::get().to(|| HttpResponse::Ok())) .wrap_fn(|req, srv| { - srv.call(req).map(|mut res| { + let fut = srv.call(req); + async { + let mut res = fut.await?; res.headers_mut().insert( header::CONTENT_TYPE, HeaderValue::from_static("0001"), ); - res - }) + Ok(res) + } }), - ); + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -638,8 +659,8 @@ mod tests { ); } - #[test] - fn test_external_resource() { + #[actix_rt::test] + async fn test_external_resource() { let mut srv = init_service( App::new() .external_resource("youtube", "https://youtube.com/watch/{video_id}") @@ -652,11 +673,12 @@ mod tests { )) }), ), - ); + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); - let body = read_body(resp); + let body = read_body(resp).await; assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); } } diff --git a/src/app_service.rs b/src/app_service.rs index 736c35010..ccfefbc68 100644 --- a/src/app_service.rs +++ b/src/app_service.rs @@ -1,14 +1,15 @@ use std::cell::RefCell; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_http::{Extensions, Request, Response}; use actix_router::{Path, ResourceDef, ResourceInfo, Router, Url}; -use actix_server_config::ServerConfig; -use actix_service::boxed::{self, BoxedNewService, BoxedService}; -use actix_service::{service_fn, NewService, Service}; -use futures::future::{ok, Either, FutureResult}; -use futures::{Async, Future, Poll}; +use actix_service::boxed::{self, BoxService, BoxServiceFactory}; +use actix_service::{fn_service, Service, ServiceFactory}; +use futures::future::{ok, FutureExt, LocalBoxFuture}; use crate::config::{AppConfig, AppService}; use crate::data::DataFactory; @@ -16,23 +17,20 @@ use crate::error::Error; use crate::guard::Guard; use crate::request::{HttpRequest, HttpRequestPool}; use crate::rmap::ResourceMap; -use crate::service::{ServiceFactory, ServiceRequest, ServiceResponse}; +use crate::service::{AppServiceFactory, ServiceRequest, ServiceResponse}; type Guards = Vec>; -type HttpService = BoxedService; -type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; -type BoxedResponse = Either< - FutureResult, - Box>, ->; +type HttpService = BoxService; +type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; +type BoxResponse = LocalBoxFuture<'static, Result>; type FnDataFactory = - Box Box, Error = ()>>>; + Box LocalBoxFuture<'static, Result, ()>>>; /// Service factory to convert `Request` to a `ServiceRequest`. /// It also executes data factories. pub struct AppInit where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -41,18 +39,18 @@ where >, { pub(crate) endpoint: T, + pub(crate) extensions: RefCell>, pub(crate) data: Rc>>, pub(crate) data_factories: Rc>, - pub(crate) config: RefCell, - pub(crate) services: Rc>>>, + pub(crate) services: Rc>>>, pub(crate) default: Option>, pub(crate) factory_ref: Rc>>, pub(crate) external: RefCell>, } -impl NewService for AppInit +impl ServiceFactory for AppInit where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -60,7 +58,7 @@ where InitError = (), >, { - type Config = ServerConfig; + type Config = AppConfig; type Request = Request; type Response = ServiceResponse; type Error = T::Error; @@ -68,27 +66,16 @@ where type Service = AppInitService; type Future = AppInitResult; - fn new_service(&self, cfg: &ServerConfig) -> Self::Future { + fn new_service(&self, config: AppConfig) -> Self::Future { // update resource default service let default = self.default.clone().unwrap_or_else(|| { - Rc::new(boxed::new_service(service_fn(|req: ServiceRequest| { - Ok(req.into_response(Response::NotFound().finish())) + Rc::new(boxed::factory(fn_service(|req: ServiceRequest| { + ok(req.into_response(Response::NotFound().finish())) }))) }); // App config - { - let mut c = self.config.borrow_mut(); - let loc_cfg = Rc::get_mut(&mut c.0).unwrap(); - loc_cfg.secure = cfg.secure(); - loc_cfg.addr = cfg.local_addr(); - } - - let mut config = AppService::new( - self.config.borrow().clone(), - default.clone(), - self.data.clone(), - ); + let mut config = AppService::new(config, default.clone(), self.data.clone()); // register services std::mem::replace(&mut *self.services.borrow_mut(), Vec::new()) @@ -124,10 +111,16 @@ where AppInitResult { endpoint: None, - endpoint_fut: self.endpoint.new_service(&()), + endpoint_fut: self.endpoint.new_service(()), data: self.data.clone(), data_factories: Vec::new(), data_factories_fut: self.data_factories.iter().map(|f| f()).collect(), + extensions: Some( + self.extensions + .borrow_mut() + .take() + .unwrap_or_else(Extensions::new), + ), config, rmap, _t: PhantomData, @@ -135,23 +128,26 @@ where } } +#[pin_project::pin_project] pub struct AppInitResult where - T: NewService, + T: ServiceFactory, { endpoint: Option, + #[pin] endpoint_fut: T::Future, rmap: Rc, config: AppConfig, data: Rc>>, data_factories: Vec>, - data_factories_fut: Vec, Error = ()>>>, + data_factories_fut: Vec, ()>>>, + extensions: Option, _t: PhantomData, } impl Future for AppInitResult where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -159,48 +155,49 @@ where InitError = (), >, { - type Item = AppInitService; - type Error = (); + type Output = Result, ()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); - fn poll(&mut self) -> Poll { // async data factories let mut idx = 0; - while idx < self.data_factories_fut.len() { - match self.data_factories_fut[idx].poll()? { - Async::Ready(f) => { - self.data_factories.push(f); - self.data_factories_fut.remove(idx); + while idx < this.data_factories_fut.len() { + match Pin::new(&mut this.data_factories_fut[idx]).poll(cx)? { + Poll::Ready(f) => { + this.data_factories.push(f); + let _ = this.data_factories_fut.remove(idx); } - Async::NotReady => idx += 1, + Poll::Pending => idx += 1, } } - if self.endpoint.is_none() { - if let Async::Ready(srv) = self.endpoint_fut.poll()? { - self.endpoint = Some(srv); + if this.endpoint.is_none() { + if let Poll::Ready(srv) = this.endpoint_fut.poll(cx)? { + *this.endpoint = Some(srv); } } - if self.endpoint.is_some() && self.data_factories_fut.is_empty() { + if this.endpoint.is_some() && this.data_factories_fut.is_empty() { // create app data container - let mut data = Extensions::new(); - for f in self.data.iter() { + let mut data = this.extensions.take().unwrap(); + for f in this.data.iter() { f.create(&mut data); } - for f in &self.data_factories { + for f in this.data_factories.iter() { f.create(&mut data); } - Ok(Async::Ready(AppInitService { - service: self.endpoint.take().unwrap(), - rmap: self.rmap.clone(), - config: self.config.clone(), + Poll::Ready(Ok(AppInitService { + service: this.endpoint.take().unwrap(), + rmap: this.rmap.clone(), + config: this.config.clone(), data: Rc::new(data), pool: HttpRequestPool::create(), })) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -226,8 +223,8 @@ where type Error = T::Error; type Future = T::Future; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { @@ -270,7 +267,7 @@ pub struct AppRoutingFactory { default: Rc, } -impl NewService for AppRoutingFactory { +impl ServiceFactory for AppRoutingFactory { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -279,7 +276,7 @@ impl NewService for AppRoutingFactory { type Service = AppRouting; type Future = AppRoutingFactoryResponse; - fn new_service(&self, _: &()) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { AppRoutingFactoryResponse { fut: self .services @@ -288,24 +285,24 @@ impl NewService for AppRoutingFactory { CreateAppRoutingItem::Future( Some(path.clone()), guards.borrow_mut().take(), - service.new_service(&()), + service.new_service(()).boxed_local(), ) }) .collect(), default: None, - default_fut: Some(self.default.new_service(&())), + default_fut: Some(self.default.new_service(())), } } } -type HttpServiceFut = Box>; +type HttpServiceFut = LocalBoxFuture<'static, Result>; /// Create app service #[doc(hidden)] pub struct AppRoutingFactoryResponse { fut: Vec, default: Option, - default_fut: Option>>, + default_fut: Option>>, } enum CreateAppRoutingItem { @@ -314,16 +311,15 @@ enum CreateAppRoutingItem { } impl Future for AppRoutingFactoryResponse { - type Item = AppRouting; - type Error = (); + type Output = Result; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut done = true; if let Some(ref mut fut) = self.default_fut { - match fut.poll()? { - Async::Ready(default) => self.default = Some(default), - Async::NotReady => done = false, + match Pin::new(fut).poll(cx)? { + Poll::Ready(default) => self.default = Some(default), + Poll::Pending => done = false, } } @@ -334,11 +330,12 @@ impl Future for AppRoutingFactoryResponse { ref mut path, ref mut guards, ref mut fut, - ) => match fut.poll()? { - Async::Ready(service) => { + ) => match Pin::new(fut).poll(cx) { + Poll::Ready(Ok(service)) => { Some((path.take().unwrap(), guards.take(), service)) } - Async::NotReady => { + Poll::Ready(Err(_)) => return Poll::Ready(Err(())), + Poll::Pending => { done = false; None } @@ -364,13 +361,13 @@ impl Future for AppRoutingFactoryResponse { } router }); - Ok(Async::Ready(AppRouting { + Poll::Ready(Ok(AppRouting { ready: None, router: router.finish(), default: self.default.take(), })) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -385,13 +382,13 @@ impl Service for AppRouting { type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = BoxedResponse; + type Future = BoxResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { if self.ready.is_none() { - Ok(Async::Ready(())) + Poll::Ready(Ok(())) } else { - Ok(Async::NotReady) + Poll::Pending } } @@ -413,7 +410,7 @@ impl Service for AppRouting { default.call(req) } else { let req = req.into_parts().0; - Either::A(ok(ServiceResponse::new(req, Response::NotFound().finish()))) + ok(ServiceResponse::new(req, Response::NotFound().finish())).boxed_local() } } } @@ -429,7 +426,7 @@ impl AppEntry { } } -impl NewService for AppEntry { +impl ServiceFactory for AppEntry { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -438,20 +435,19 @@ impl NewService for AppEntry { type Service = AppRouting; type Future = AppRoutingFactoryResponse; - fn new_service(&self, _: &()) -> Self::Future { - self.factory.borrow_mut().as_mut().unwrap().new_service(&()) + fn new_service(&self, _: ()) -> Self::Future { + self.factory.borrow_mut().as_mut().unwrap().new_service(()) } } #[cfg(test)] mod tests { - use actix_service::Service; - use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; - use crate::{test, web, App, HttpResponse}; + use crate::test::{init_service, TestRequest}; + use crate::{web, App, HttpResponse}; + use actix_service::Service; struct DropData(Arc); @@ -461,17 +457,19 @@ mod tests { } } - #[test] - fn drop_data() { + #[actix_rt::test] + async fn test_drop_data() { let data = Arc::new(AtomicBool::new(false)); + { - let mut app = test::init_service( + let mut app = init_service( App::new() .data(DropData(data.clone())) .service(web::resource("/test").to(|| HttpResponse::Ok())), - ); - let req = test::TestRequest::with_uri("/test").to_request(); - let _ = test::block_on(app.call(req)).unwrap(); + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let _ = app.call(req).await.unwrap(); } assert!(data.load(Ordering::Relaxed)); } diff --git a/src/config.rs b/src/config.rs index 63fd31d27..6ce96767d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use actix_http::Extensions; use actix_router::ResourceDef; -use actix_service::{boxed, IntoNewService, NewService}; +use actix_service::{boxed, IntoServiceFactory, ServiceFactory}; use crate::data::{Data, DataFactory}; use crate::error::Error; @@ -12,13 +12,13 @@ use crate::resource::Resource; use crate::rmap::ResourceMap; use crate::route::Route; use crate::service::{ - HttpServiceFactory, ServiceFactory, ServiceFactoryWrapper, ServiceRequest, + AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest, ServiceResponse, }; type Guards = Vec>; type HttpNewService = - boxed::BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; + boxed::BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; /// Application configuration pub struct AppService { @@ -102,11 +102,11 @@ impl AppService { &mut self, rdef: ResourceDef, guards: Option>>, - service: F, + factory: F, nested: Option>, ) where - F: IntoNewService, - S: NewService< + F: IntoServiceFactory, + S: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -116,7 +116,7 @@ impl AppService { { self.services.push(( rdef, - boxed::new_service(service.into_new_service()), + boxed::factory(factory.into_factory()), guards, nested, )); @@ -124,18 +124,24 @@ impl AppService { } #[derive(Clone)] -pub struct AppConfig(pub(crate) Rc); +pub struct AppConfig(Rc); + +struct AppConfigInner { + secure: bool, + host: String, + addr: SocketAddr, +} impl AppConfig { - pub(crate) fn new(inner: AppConfigInner) -> Self { - AppConfig(Rc::new(inner)) + pub(crate) fn new(secure: bool, addr: SocketAddr, host: String) -> Self { + AppConfig(Rc::new(AppConfigInner { secure, addr, host })) } - /// Set server host name. + /// Server host name. /// - /// Host name is used by application router as a hostname for url - /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. - /// html#method.host) documentation for more information. + /// Host name is used by application router as a hostname for url generation. + /// Check [ConnectionInfo](./struct.ConnectionInfo.html#method.host) + /// documentation for more information. /// /// By default host name is set to a "localhost" value. pub fn host(&self) -> &str { @@ -153,19 +159,13 @@ impl AppConfig { } } -pub(crate) struct AppConfigInner { - pub(crate) secure: bool, - pub(crate) host: String, - pub(crate) addr: SocketAddr, -} - -impl Default for AppConfigInner { - fn default() -> AppConfigInner { - AppConfigInner { - secure: false, - addr: "127.0.0.1:8080".parse().unwrap(), - host: "localhost:8080".to_owned(), - } +impl Default for AppConfig { + fn default() -> Self { + AppConfig::new( + false, + "127.0.0.1:8080".parse().unwrap(), + "localhost:8080".to_owned(), + ) } } @@ -174,7 +174,7 @@ impl Default for AppConfigInner { /// to set of external methods. This could help with /// modularization of big application configuration. pub struct ServiceConfig { - pub(crate) services: Vec>, + pub(crate) services: Vec>, pub(crate) data: Vec>, pub(crate) external: Vec, } @@ -246,11 +246,11 @@ mod tests { use super::*; use crate::http::{Method, StatusCode}; - use crate::test::{block_on, call_service, init_service, read_body, TestRequest}; + use crate::test::{call_service, init_service, read_body, TestRequest}; use crate::{web, App, HttpRequest, HttpResponse}; - #[test] - fn test_data() { + #[actix_rt::test] + async fn test_data() { let cfg = |cfg: &mut ServiceConfig| { cfg.data(10usize); }; @@ -258,14 +258,15 @@ mod tests { let mut srv = init_service(App::new().configure(cfg).service( web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); + )) + .await; let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } - // #[test] - // fn test_data_factory() { + // #[actix_rt::test] + // async fn test_data_factory() { // let cfg = |cfg: &mut ServiceConfig| { // cfg.data_factory(|| { // sleep(std::time::Duration::from_millis(50)).then(|_| { @@ -280,7 +281,7 @@ mod tests { // web::resource("/").to(|_: web::Data| HttpResponse::Ok()), // )); // let req = TestRequest::default().to_request(); - // let resp = block_on(srv.call(req)).unwrap(); + // let resp = srv.call(req).await.unwrap(); // assert_eq!(resp.status(), StatusCode::OK); // let cfg2 = |cfg: &mut ServiceConfig| { @@ -292,12 +293,12 @@ mod tests { // .configure(cfg2), // ); // let req = TestRequest::default().to_request(); - // let resp = block_on(srv.call(req)).unwrap(); + // let resp = srv.call(req).await.unwrap(); // assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); // } - #[test] - fn test_external_resource() { + #[actix_rt::test] + async fn test_external_resource() { let mut srv = init_service( App::new() .configure(|cfg| { @@ -315,33 +316,35 @@ mod tests { )) }), ), - ); + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); - let body = read_body(resp); + let body = read_body(resp).await; assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); } - #[test] - fn test_service() { + #[actix_rt::test] + async fn test_service() { let mut srv = init_service(App::new().configure(|cfg| { cfg.service( web::resource("/test").route(web::get().to(|| HttpResponse::Created())), ) .route("/index.html", web::get().to(|| HttpResponse::Ok())); - })); + })) + .await; let req = TestRequest::with_uri("/test") .method(Method::GET) .to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::CREATED); let req = TestRequest::with_uri("/index.html") .method(Method::GET) .to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } } diff --git a/src/data.rs b/src/data.rs index 14e293bc2..c36418942 100644 --- a/src/data.rs +++ b/src/data.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use actix_http::error::{Error, ErrorInternalServerError}; use actix_http::Extensions; +use futures::future::{err, ok, Ready}; use crate::dev::Payload; use crate::extract::FromRequest; @@ -37,16 +38,17 @@ pub(crate) trait DataFactory { /// /// ```rust /// use std::sync::Mutex; -/// use actix_web::{web, App}; +/// use actix_web::{web, App, HttpResponse, Responder}; /// /// struct MyData { /// counter: usize, /// } /// /// /// Use `Data` extractor to access data in handler. -/// fn index(data: web::Data>) { +/// async fn index(data: web::Data>) -> impl Responder { /// let mut data = data.lock().unwrap(); /// data.counter += 1; +/// HttpResponse::Ok() /// } /// /// fn main() { @@ -54,7 +56,7 @@ pub(crate) trait DataFactory { /// /// let app = App::new() /// // Store `MyData` in application storage. -/// .register_data(data.clone()) +/// .app_data(data.clone()) /// .service( /// web::resource("/index.html").route( /// web::get().to(index))); @@ -85,10 +87,10 @@ impl Data { } impl Deref for Data { - type Target = T; + type Target = Arc; - fn deref(&self) -> &T { - self.0.as_ref() + fn deref(&self) -> &Arc { + &self.0 } } @@ -101,19 +103,19 @@ impl Clone for Data { impl FromRequest for Data { type Config = (); type Error = Error; - type Future = Result; + type Future = Ready>; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { - if let Some(st) = req.get_app_data::() { - Ok(st) + if let Some(st) = req.app_data::>() { + ok(st.clone()) } else { log::debug!( "Failed to construct App-level Data extractor. \ Request path: {:?}", req.path() ); - Err(ErrorInternalServerError( + err(ErrorInternalServerError( "App data is not configured, to configure use App::data()", )) } @@ -134,64 +136,72 @@ impl DataFactory for Data { #[cfg(test)] mod tests { use actix_service::Service; + use std::sync::atomic::{AtomicUsize, Ordering}; use super::*; use crate::http::StatusCode; - use crate::test::{block_on, init_service, TestRequest}; + use crate::test::{self, init_service, TestRequest}; use crate::{web, App, HttpResponse}; - #[test] - fn test_data_extractor() { - let mut srv = - init_service(App::new().data(10usize).service( - web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); + #[actix_rt::test] + async fn test_data_extractor() { + let mut srv = init_service(App::new().data("TEST".to_string()).service( + web::resource("/").to(|data: web::Data| { + assert_eq!(data.to_lowercase(), "test"); + HttpResponse::Ok() + }), + )) + .await; let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let mut srv = init_service(App::new().data(10u32).service( web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); + )) + .await; let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } - #[test] - fn test_register_data_extractor() { + #[actix_rt::test] + async fn test_app_data_extractor() { let mut srv = - init_service(App::new().register_data(Data::new(10usize)).service( + init_service(App::new().app_data(Data::new(10usize)).service( web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); + )) + .await; let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let mut srv = - init_service(App::new().register_data(Data::new(10u32)).service( + init_service(App::new().app_data(Data::new(10u32)).service( web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )); + )) + .await; let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } - #[test] - fn test_route_data_extractor() { + #[actix_rt::test] + async fn test_route_data_extractor() { let mut srv = init_service(App::new().service(web::resource("/").data(10usize).route( web::get().to(|data: web::Data| { let _ = data.clone(); HttpResponse::Ok() }), - ))); + ))) + .await; let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); // different type @@ -201,26 +211,71 @@ mod tests { .data(10u32) .route(web::get().to(|_: web::Data| HttpResponse::Ok())), ), - ); + ) + .await; let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } - #[test] - fn test_override_data() { + #[actix_rt::test] + async fn test_override_data() { let mut srv = init_service(App::new().data(1usize).service( web::resource("/").data(10usize).route(web::get().to( |data: web::Data| { - assert_eq!(*data, 10); + assert_eq!(**data, 10); let _ = data.clone(); HttpResponse::Ok() }, )), - )); + )) + .await; let req = TestRequest::default().to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } + + #[actix_rt::test] + async fn test_data_drop() { + struct TestData(Arc); + + impl TestData { + fn new(inner: Arc) -> Self { + let _ = inner.fetch_add(1, Ordering::SeqCst); + Self(inner) + } + } + + impl Clone for TestData { + fn clone(&self) -> Self { + let inner = self.0.clone(); + let _ = inner.fetch_add(1, Ordering::SeqCst); + Self(inner) + } + } + + impl Drop for TestData { + fn drop(&mut self) { + let _ = self.0.fetch_sub(1, Ordering::SeqCst); + } + } + + let num = Arc::new(AtomicUsize::new(0)); + let data = TestData::new(num.clone()); + assert_eq!(num.load(Ordering::SeqCst), 1); + + let srv = test::start(move || { + let data = data.clone(); + + App::new() + .data(data) + .service(web::resource("/").to(|_data: Data| async { "ok" })) + }); + + assert!(srv.get("/").send().await.unwrap().status().is_success()); + srv.stop().await; + + assert_eq!(num.load(Ordering::SeqCst), 0); + } } diff --git a/src/error.rs b/src/error.rs index 9f31582ed..31f6b9c5b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,7 +6,6 @@ use url::ParseError as UrlParseError; use crate::http::StatusCode; use crate::HttpResponse; -use serde_urlencoded::de; /// Errors which can occur when attempting to generate resource uri. #[derive(Debug, PartialEq, Display, From)] @@ -32,8 +31,12 @@ pub enum UrlencodedError { #[display(fmt = "Can not decode chunked transfer encoding")] Chunked, /// Payload size is bigger than allowed. (default: 256kB) - #[display(fmt = "Urlencoded payload size is bigger than allowed (default: 256kB)")] - Overflow, + #[display( + fmt = "Urlencoded payload size is bigger ({} bytes) than allowed (default: {} bytes)", + size, + limit + )] + Overflow { size: usize, limit: usize }, /// Payload size is now known #[display(fmt = "Payload size is now known")] UnknownLength, @@ -50,15 +53,11 @@ pub enum UrlencodedError { /// Return `BadRequest` for `UrlencodedError` impl ResponseError for UrlencodedError { - fn error_response(&self) -> HttpResponse { + fn status_code(&self) -> StatusCode { match *self { - UrlencodedError::Overflow => { - HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE) - } - UrlencodedError::UnknownLength => { - HttpResponse::new(StatusCode::LENGTH_REQUIRED) - } - _ => HttpResponse::new(StatusCode::BAD_REQUEST), + UrlencodedError::Overflow { .. } => StatusCode::PAYLOAD_TOO_LARGE, + UrlencodedError::UnknownLength => StatusCode::LENGTH_REQUIRED, + _ => StatusCode::BAD_REQUEST, } } } @@ -97,15 +96,13 @@ impl ResponseError for JsonPayloadError { pub enum PathError { /// Deserialize error #[display(fmt = "Path deserialize error: {}", _0)] - Deserialize(de::Error), + Deserialize(serde::de::value::Error), } /// Return `BadRequest` for `PathError` impl ResponseError for PathError { - fn error_response(&self) -> HttpResponse { - match *self { - PathError::Deserialize(_) => HttpResponse::new(StatusCode::BAD_REQUEST), - } + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } @@ -114,17 +111,13 @@ impl ResponseError for PathError { pub enum QueryPayloadError { /// Deserialize error #[display(fmt = "Query deserialize error: {}", _0)] - Deserialize(de::Error), + Deserialize(serde::de::value::Error), } /// Return `BadRequest` for `QueryPayloadError` impl ResponseError for QueryPayloadError { - fn error_response(&self) -> HttpResponse { - match *self { - QueryPayloadError::Deserialize(_) => { - HttpResponse::new(StatusCode::BAD_REQUEST) - } - } + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } @@ -148,12 +141,10 @@ pub enum ReadlinesError { /// Return `BadRequest` for `ReadlinesError` impl ResponseError for ReadlinesError { - fn error_response(&self) -> HttpResponse { + fn status_code(&self) -> StatusCode { match *self { - ReadlinesError::LimitOverflow => { - HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE) - } - _ => HttpResponse::new(StatusCode::BAD_REQUEST), + ReadlinesError::LimitOverflow => StatusCode::PAYLOAD_TOO_LARGE, + _ => StatusCode::BAD_REQUEST, } } } @@ -164,7 +155,8 @@ mod tests { #[test] fn test_urlencoded_error() { - let resp: HttpResponse = UrlencodedError::Overflow.error_response(); + let resp: HttpResponse = + UrlencodedError::Overflow { size: 0, limit: 0 }.error_response(); assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE); let resp: HttpResponse = UrlencodedError::UnknownLength.error_response(); assert_eq!(resp.status(), StatusCode::LENGTH_REQUIRED); diff --git a/src/extract.rs b/src/extract.rs index 1687973ac..c189bbf97 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -1,8 +1,10 @@ //! Request extractors +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_http::error::Error; -use futures::future::ok; -use futures::{future, Async, Future, IntoFuture, Poll}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use crate::dev::Payload; use crate::request::HttpRequest; @@ -15,7 +17,7 @@ pub trait FromRequest: Sized { type Error: Into; /// Future that resolves to a Self - type Future: IntoFuture; + type Future: Future>; /// Configuration for this extractor type Config: Default + 'static; @@ -46,9 +48,10 @@ pub trait FromRequest: Sized { /// ## Example /// /// ```rust -/// # #[macro_use] extern crate serde_derive; /// use actix_web::{web, dev, App, Error, HttpRequest, FromRequest}; /// use actix_web::error::ErrorBadRequest; +/// use futures::future::{ok, err, Ready}; +/// use serde_derive::Deserialize; /// use rand; /// /// #[derive(Debug, Deserialize)] @@ -58,21 +61,21 @@ pub trait FromRequest: Sized { /// /// impl FromRequest for Thing { /// type Error = Error; -/// type Future = Result; +/// type Future = Ready>; /// type Config = (); /// /// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { /// if rand::random() { -/// Ok(Thing { name: "thingy".into() }) +/// ok(Thing { name: "thingy".into() }) /// } else { -/// Err(ErrorBadRequest("no luck")) +/// err(ErrorBadRequest("no luck")) /// } /// /// } /// } /// /// /// extract `Thing` from request -/// fn index(supplied_thing: Option) -> String { +/// async fn index(supplied_thing: Option) -> String { /// match supplied_thing { /// // Puns not intended /// Some(thing) => format!("Got something: {:?}", thing), @@ -94,21 +97,19 @@ where { type Config = T::Config; type Error = Error; - type Future = Box, Error = Error>>; + type Future = LocalBoxFuture<'static, Result, Error>>; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - Box::new( - T::from_request(req, payload) - .into_future() - .then(|r| match r { - Ok(v) => future::ok(Some(v)), - Err(e) => { - log::debug!("Error for Option extractor: {}", e.into()); - future::ok(None) - } - }), - ) + T::from_request(req, payload) + .then(|r| match r { + Ok(v) => ok(Some(v)), + Err(e) => { + log::debug!("Error for Option extractor: {}", e.into()); + ok(None) + } + }) + .boxed_local() } } @@ -119,9 +120,10 @@ where /// ## Example /// /// ```rust -/// # #[macro_use] extern crate serde_derive; /// use actix_web::{web, dev, App, Result, Error, HttpRequest, FromRequest}; /// use actix_web::error::ErrorBadRequest; +/// use futures::future::{ok, err, Ready}; +/// use serde_derive::Deserialize; /// use rand; /// /// #[derive(Debug, Deserialize)] @@ -131,20 +133,20 @@ where /// /// impl FromRequest for Thing { /// type Error = Error; -/// type Future = Result; +/// type Future = Ready>; /// type Config = (); /// /// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { /// if rand::random() { -/// Ok(Thing { name: "thingy".into() }) +/// ok(Thing { name: "thingy".into() }) /// } else { -/// Err(ErrorBadRequest("no luck")) +/// err(ErrorBadRequest("no luck")) /// } /// } /// } /// /// /// extract `Thing` from request -/// fn index(supplied_thing: Result) -> String { +/// async fn index(supplied_thing: Result) -> String { /// match supplied_thing { /// Ok(thing) => format!("Got thing: {:?}", thing), /// Err(e) => format!("Error extracting thing: {}", e) @@ -157,26 +159,24 @@ where /// ); /// } /// ``` -impl FromRequest for Result +impl FromRequest for Result where - T: FromRequest, - T::Future: 'static, + T: FromRequest + 'static, T::Error: 'static, + T::Future: 'static, { type Config = T::Config; type Error = Error; - type Future = Box, Error = Error>>; + type Future = LocalBoxFuture<'static, Result, Error>>; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - Box::new( - T::from_request(req, payload) - .into_future() - .then(|res| match res { - Ok(v) => ok(Ok(v)), - Err(e) => ok(Err(e)), - }), - ) + T::from_request(req, payload) + .then(|res| match res { + Ok(v) => ok(Ok(v)), + Err(e) => ok(Err(e)), + }) + .boxed_local() } } @@ -184,10 +184,10 @@ where impl FromRequest for () { type Config = (); type Error = Error; - type Future = Result<(), Error>; + type Future = Ready>; fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future { - Ok(()) + ok(()) } } @@ -195,6 +195,7 @@ macro_rules! tuple_from_req ({$fut_type:ident, $(($n:tt, $T:ident)),+} => { /// FromRequest implementation for tuple #[doc(hidden)] + #[allow(unused_parens)] impl<$($T: FromRequest + 'static),+> FromRequest for ($($T,)+) { type Error = Error; @@ -204,43 +205,44 @@ macro_rules! tuple_from_req ({$fut_type:ident, $(($n:tt, $T:ident)),+} => { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { $fut_type { items: <($(Option<$T>,)+)>::default(), - futs: ($($T::from_request(req, payload).into_future(),)+), + futs: ($($T::from_request(req, payload),)+), } } } #[doc(hidden)] + #[pin_project::pin_project] pub struct $fut_type<$($T: FromRequest),+> { items: ($(Option<$T>,)+), - futs: ($(<$T::Future as futures::IntoFuture>::Future,)+), + futs: ($($T::Future,)+), } impl<$($T: FromRequest),+> Future for $fut_type<$($T),+> { - type Item = ($($T,)+); - type Error = Error; + type Output = Result<($($T,)+), Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); - fn poll(&mut self) -> Poll { let mut ready = true; - $( - if self.items.$n.is_none() { - match self.futs.$n.poll() { - Ok(Async::Ready(item)) => { - self.items.$n = Some(item); + if this.items.$n.is_none() { + match unsafe { Pin::new_unchecked(&mut this.futs.$n) }.poll(cx) { + Poll::Ready(Ok(item)) => { + this.items.$n = Some(item); } - Ok(Async::NotReady) => ready = false, - Err(e) => return Err(e.into()), + Poll::Pending => ready = false, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), } } )+ if ready { - Ok(Async::Ready( - ($(self.items.$n.take().unwrap(),)+) + Poll::Ready(Ok( + ($(this.items.$n.take().unwrap(),)+) )) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -269,7 +271,7 @@ mod tests { use serde_derive::Deserialize; use super::*; - use crate::test::{block_on, TestRequest}; + use crate::test::TestRequest; use crate::types::{Form, FormConfig}; #[derive(Deserialize, Debug, PartialEq)] @@ -277,8 +279,8 @@ mod tests { hello: String, } - #[test] - fn test_option() { + #[actix_rt::test] + async fn test_option() { let (req, mut pl) = TestRequest::with_header( header::CONTENT_TYPE, "application/x-www-form-urlencoded", @@ -286,7 +288,9 @@ mod tests { .data(FormConfig::default().limit(4096)) .to_http_parts(); - let r = block_on(Option::>::from_request(&req, &mut pl)).unwrap(); + let r = Option::>::from_request(&req, &mut pl) + .await + .unwrap(); assert_eq!(r, None); let (req, mut pl) = TestRequest::with_header( @@ -297,7 +301,9 @@ mod tests { .set_payload(Bytes::from_static(b"hello=world")) .to_http_parts(); - let r = block_on(Option::>::from_request(&req, &mut pl)).unwrap(); + let r = Option::>::from_request(&req, &mut pl) + .await + .unwrap(); assert_eq!( r, Some(Form(Info { @@ -313,12 +319,14 @@ mod tests { .set_payload(Bytes::from_static(b"bye=world")) .to_http_parts(); - let r = block_on(Option::>::from_request(&req, &mut pl)).unwrap(); + let r = Option::>::from_request(&req, &mut pl) + .await + .unwrap(); assert_eq!(r, None); } - #[test] - fn test_result() { + #[actix_rt::test] + async fn test_result() { let (req, mut pl) = TestRequest::with_header( header::CONTENT_TYPE, "application/x-www-form-urlencoded", @@ -327,7 +335,8 @@ mod tests { .set_payload(Bytes::from_static(b"hello=world")) .to_http_parts(); - let r = block_on(Result::, Error>::from_request(&req, &mut pl)) + let r = Result::, Error>::from_request(&req, &mut pl) + .await .unwrap() .unwrap(); assert_eq!( @@ -345,8 +354,9 @@ mod tests { .set_payload(Bytes::from_static(b"bye=world")) .to_http_parts(); - let r = - block_on(Result::, Error>::from_request(&req, &mut pl)).unwrap(); + let r = Result::, Error>::from_request(&req, &mut pl) + .await + .unwrap(); assert!(r.is_err()); } } diff --git a/src/guard.rs b/src/guard.rs index e0b4055ba..e6303e9c7 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -1,16 +1,16 @@ //! Route match guards. //! -//! Guards are one of the way how actix-web router chooses -//! handler service. In essence it just function that accepts -//! reference to a `RequestHead` instance and returns boolean. +//! Guards are one of the ways how actix-web router chooses a +//! handler service. In essence it is just a function that accepts a +//! reference to a `RequestHead` instance and returns a boolean. //! It is possible to add guards to *scopes*, *resources* //! and *routes*. Actix provide several guards by default, like various //! http methods, header, etc. To become a guard, type must implement `Guard` //! trait. Simple functions coulds guards as well. //! -//! Guard can not modify request object. But it is possible to -//! to store extra attributes on a request by using `Extensions` container. -//! Extensions container available via `RequestHead::extensions()` method. +//! Guards can not modify the request object. But it is possible +//! to store extra attributes on a request by using the `Extensions` container. +//! Extensions containers are available via the `RequestHead::extensions()` method. //! //! ```rust //! use actix_web::{web, http, dev, guard, App, HttpResponse}; @@ -24,16 +24,17 @@ //! ); //! } //! ``` - #![allow(non_snake_case)] -use actix_http::http::{self, header, uri::Uri, HttpTryFrom}; +use std::convert::TryFrom; + +use actix_http::http::{self, header, uri::Uri}; use actix_http::RequestHead; -/// Trait defines resource guards. Guards are used for routes selection. +/// Trait defines resource guards. Guards are used for route selection. /// -/// Guard can not modify request object. But it is possible to -/// to store extra attributes on request by using `Extensions` container, -/// Extensions container available via `RequestHead::extensions()` method. +/// Guards can not modify the request object. But it is possible +/// to store extra attributes on a request by using the `Extensions` container. +/// Extensions containers are available via the `RequestHead::extensions()` method. pub trait Guard { /// Check if request matches predicate fn check(&self, request: &RequestHead) -> bool; @@ -258,16 +259,15 @@ impl Guard for HeaderGuard { /// Return predicate that matches if request contains specified Host name. /// -/// ```rust,ignore -/// # extern crate actix_web; -/// use actix_web::{guard::Host, App, HttpResponse}; +/// ```rust +/// use actix_web::{web, guard::Host, App, HttpResponse}; /// /// fn main() { -/// App::new().resource("/index.html", |r| { -/// r.route() +/// App::new().service( +/// web::resource("/index.html") /// .guard(Host("www.rust-lang.org")) -/// .f(|_| HttpResponse::MethodNotAllowed()) -/// }); +/// .to(|| HttpResponse::MethodNotAllowed()) +/// ); /// } /// ``` pub fn Host>(host: H) -> HostGuard { @@ -276,10 +276,12 @@ pub fn Host>(host: H) -> HostGuard { fn get_host_uri(req: &RequestHead) -> Option { use core::str::FromStr; - let host_value = req.headers.get(header::HOST)?; - let host = host_value.to_str().ok()?; - let uri = Uri::from_str(host).ok()?; - Some(uri) + req.headers + .get(header::HOST) + .and_then(|host_value| host_value.to_str().ok()) + .or_else(|| req.uri.host()) + .map(|host: &str| Uri::from_str(host).ok()) + .and_then(|host_success| host_success) } #[doc(hidden)] @@ -400,6 +402,31 @@ mod tests { assert!(!pred.check(req.head())); } + #[test] + fn test_host_without_header() { + let req = TestRequest::default() + .uri("www.rust-lang.org") + .to_http_request(); + + let pred = Host("www.rust-lang.org"); + assert!(pred.check(req.head())); + + let pred = Host("www.rust-lang.org").scheme("https"); + assert!(pred.check(req.head())); + + let pred = Host("blog.rust-lang.org"); + assert!(!pred.check(req.head())); + + let pred = Host("blog.rust-lang.org").scheme("https"); + assert!(!pred.check(req.head())); + + let pred = Host("crates.io"); + assert!(!pred.check(req.head())); + + let pred = Host("localhost"); + assert!(!pred.check(req.head())); + } + #[test] fn test_methods() { let req = TestRequest::default().to_http_request(); diff --git a/src/handler.rs b/src/handler.rs index 078abbf1d..33cd2408d 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,28 +1,34 @@ use std::convert::Infallible; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_http::{Error, Response}; -use actix_service::{NewService, Service}; -use futures::future::{ok, FutureResult}; -use futures::{try_ready, Async, Future, IntoFuture, Poll}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{ok, Ready}; +use futures::ready; +use pin_project::pin_project; use crate::extract::FromRequest; use crate::request::HttpRequest; use crate::responder::Responder; use crate::service::{ServiceRequest, ServiceResponse}; -/// Handler converter factory -pub trait Factory: Clone +/// Async handler converter factory +pub trait Factory: Clone + 'static where - R: Responder, + R: Future, + O: Responder, { fn call(&self, param: T) -> R; } -impl Factory<(), R> for F +impl Factory<(), R, O> for F where - F: Fn() -> R + Clone, - R: Responder, + F: Fn() -> R + Clone + 'static, + R: Future, + O: Responder, { fn call(&self, _: ()) -> R { (self)() @@ -30,19 +36,21 @@ where } #[doc(hidden)] -pub struct Handler +pub struct Handler where - F: Factory, - R: Responder, + F: Factory, + R: Future, + O: Responder, { hnd: F, - _t: PhantomData<(T, R)>, + _t: PhantomData<(T, R, O)>, } -impl Handler +impl Handler where - F: Factory, - R: Responder, + F: Factory, + R: Future, + O: Responder, { pub fn new(hnd: F) -> Self { Handler { @@ -52,156 +60,38 @@ where } } -impl Clone for Handler +impl Clone for Handler where - F: Factory, - R: Responder, + F: Factory, + R: Future, + O: Responder, { fn clone(&self) -> Self { - Self { + Handler { hnd: self.hnd.clone(), _t: PhantomData, } } } -impl Service for Handler +impl Service for Handler where - F: Factory, - R: Responder, + F: Factory, + R: Future, + O: Responder, { type Request = (T, HttpRequest); type Response = ServiceResponse; type Error = Infallible; - type Future = HandlerServiceResponse<::Future>; + type Future = HandlerServiceResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future { - let fut = self.hnd.call(param).respond_to(&req).into_future(); HandlerServiceResponse { - fut, - req: Some(req), - } - } -} - -pub struct HandlerServiceResponse { - fut: T, - req: Option, -} - -impl Future for HandlerServiceResponse -where - T: Future, - T::Error: Into, -{ - type Item = ServiceResponse; - type Error = Infallible; - - fn poll(&mut self) -> Poll { - match self.fut.poll() { - Ok(Async::Ready(res)) => Ok(Async::Ready(ServiceResponse::new( - self.req.take().unwrap(), - res, - ))), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { - let res: Response = e.into().into(); - Ok(Async::Ready(ServiceResponse::new( - self.req.take().unwrap(), - res, - ))) - } - } - } -} - -/// Async handler converter factory -pub trait AsyncFactory: Clone + 'static -where - R: IntoFuture, - R::Item: Responder, - R::Error: Into, -{ - fn call(&self, param: T) -> R; -} - -impl AsyncFactory<(), R> for F -where - F: Fn() -> R + Clone + 'static, - R: IntoFuture, - R::Item: Responder, - R::Error: Into, -{ - fn call(&self, _: ()) -> R { - (self)() - } -} - -#[doc(hidden)] -pub struct AsyncHandler -where - F: AsyncFactory, - R: IntoFuture, - R::Item: Responder, - R::Error: Into, -{ - hnd: F, - _t: PhantomData<(T, R)>, -} - -impl AsyncHandler -where - F: AsyncFactory, - R: IntoFuture, - R::Item: Responder, - R::Error: Into, -{ - pub fn new(hnd: F) -> Self { - AsyncHandler { - hnd, - _t: PhantomData, - } - } -} - -impl Clone for AsyncHandler -where - F: AsyncFactory, - R: IntoFuture, - R::Item: Responder, - R::Error: Into, -{ - fn clone(&self) -> Self { - AsyncHandler { - hnd: self.hnd.clone(), - _t: PhantomData, - } - } -} - -impl Service for AsyncHandler -where - F: AsyncFactory, - R: IntoFuture, - R::Item: Responder, - R::Error: Into, -{ - type Request = (T, HttpRequest); - type Response = ServiceResponse; - type Error = Infallible; - type Future = AsyncHandlerServiceResponse; - - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) - } - - fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future { - AsyncHandlerServiceResponse { - fut: self.hnd.call(param).into_future(), + fut: self.hnd.call(param), fut2: None, req: Some(req), } @@ -209,57 +99,49 @@ where } #[doc(hidden)] -pub struct AsyncHandlerServiceResponse +#[pin_project] +pub struct HandlerServiceResponse where - T: Future, - T::Item: Responder, + T: Future, + R: Responder, { + #[pin] fut: T, - fut2: Option<<::Future as IntoFuture>::Future>, + #[pin] + fut2: Option, req: Option, } -impl Future for AsyncHandlerServiceResponse +impl Future for HandlerServiceResponse where - T: Future, - T::Item: Responder, - T::Error: Into, + T: Future, + R: Responder, { - type Item = ServiceResponse; - type Error = Infallible; + type Output = Result; - fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut2 { - return match fut.poll() { - Ok(Async::Ready(res)) => Ok(Async::Ready(ServiceResponse::new( - self.req.take().unwrap(), - res, - ))), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut().project(); + + if let Some(fut) = this.fut2.as_pin_mut() { + return match fut.poll(cx) { + Poll::Ready(Ok(res)) => { + Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) + } + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => { let res: Response = e.into().into(); - Ok(Async::Ready(ServiceResponse::new( - self.req.take().unwrap(), - res, - ))) + Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) } }; } - match self.fut.poll() { - Ok(Async::Ready(res)) => { - self.fut2 = - Some(res.respond_to(self.req.as_ref().unwrap()).into_future()); - self.poll() - } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { - let res: Response = e.into().into(); - Ok(Async::Ready(ServiceResponse::new( - self.req.take().unwrap(), - res, - ))) + match this.fut.poll(cx) { + Poll::Ready(res) => { + let fut = res.respond_to(this.req.as_ref().unwrap()); + self.as_mut().project().fut2.set(Some(fut)); + self.poll(cx) } + Poll::Pending => Poll::Pending, } } } @@ -279,7 +161,7 @@ impl Extract { } } -impl NewService for Extract +impl ServiceFactory for Extract where S: Service< Request = (T, HttpRequest), @@ -293,9 +175,9 @@ where type Error = (Error, ServiceRequest); type InitError = (); type Service = ExtractService; - type Future = FutureResult; + type Future = Ready>; - fn new_service(&self, _: &()) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { ok(ExtractService { _t: PhantomData, service: self.service.clone(), @@ -321,13 +203,13 @@ where type Error = (Error, ServiceRequest); type Future = ExtractResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, req: ServiceRequest) -> Self::Future { let (req, mut payload) = req.into_parts(); - let fut = T::from_request(&req, &mut payload).into_future(); + let fut = T::from_request(&req, &mut payload); ExtractResponse { fut, @@ -338,10 +220,13 @@ where } } +#[pin_project] pub struct ExtractResponse { req: HttpRequest, service: S, - fut: ::Future, + #[pin] + fut: T::Future, + #[pin] fut_s: Option, } @@ -353,40 +238,35 @@ where Error = Infallible, >, { - type Item = ServiceResponse; - type Error = (Error, ServiceRequest); + type Output = Result; - fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut_s { - return fut.poll().map_err(|_| panic!()); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut().project(); + + if let Some(fut) = this.fut_s.as_pin_mut() { + return fut.poll(cx).map_err(|_| panic!()); } - let item = try_ready!(self.fut.poll().map_err(|e| { - let req = ServiceRequest::new(self.req.clone()); - (e.into(), req) - })); - - self.fut_s = Some(self.service.call((item, self.req.clone()))); - self.poll() + match ready!(this.fut.poll(cx)) { + Err(e) => { + let req = ServiceRequest::new(this.req.clone()); + Poll::Ready(Err((e.into(), req))) + } + Ok(item) => { + let fut = Some(this.service.call((item, this.req.clone()))); + self.as_mut().project().fut_s.set(fut); + self.poll(cx) + } + } } } /// FromRequest trait impl for tuples macro_rules! factory_tuple ({ $(($n:tt, $T:ident)),+} => { - impl Factory<($($T,)+), Res> for Func - where Func: Fn($($T,)+) -> Res + Clone, - Res: Responder, - { - fn call(&self, param: ($($T,)+)) -> Res { - (self)($(param.$n,)+) - } - } - - impl AsyncFactory<($($T,)+), Res> for Func + impl Factory<($($T,)+), Res, O> for Func where Func: Fn($($T,)+) -> Res + Clone + 'static, - Res: IntoFuture, - Res::Item: Responder, - Res::Error: Into, + Res: Future, + O: Responder, { fn call(&self, param: ($($T,)+)) -> Res { (self)($(param.$n,)+) diff --git a/src/info.rs b/src/info.rs index ba59605de..c9a642b36 100644 --- a/src/info.rs +++ b/src/info.rs @@ -76,7 +76,7 @@ impl ConnectionInfo { } } if scheme.is_none() { - scheme = req.uri.scheme_part().map(|a| a.as_str()); + scheme = req.uri.scheme().map(|a| a.as_str()); if scheme.is_none() && cfg.secure() { scheme = Some("https") } @@ -98,7 +98,7 @@ impl ConnectionInfo { host = h.to_str().ok(); } if host.is_none() { - host = req.uri.authority_part().map(|a| a.as_str()); + host = req.uri.authority().map(|a| a.as_str()); if host.is_none() { host = Some(cfg.host()); } @@ -155,13 +155,19 @@ impl ConnectionInfo { &self.host } - /// Remote IP of client initiated HTTP request. + /// Remote socket addr of client initiated HTTP request. /// - /// The IP is resolved through the following headers, in this order: + /// The addr is resolved through the following headers, in this order: /// /// - Forwarded /// - X-Forwarded-For /// - peer name of opened socket + /// + /// # Security + /// Do not use this function for security purposes, unless you can ensure the Forwarded and + /// X-Forwarded-For headers cannot be spoofed by the client. If you want the client's socket + /// address explicitly, use + /// [`HttpRequest::peer_addr()`](../web/struct.HttpRequest.html#method.peer_addr) instead. #[inline] pub fn remote(&self) -> Option<&str> { if let Some(ref r) = self.remote { diff --git a/src/lib.rs b/src/lib.rs index 60c34489e..d51005cfe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,24 +1,27 @@ -#![allow(clippy::borrow_interior_mutable_const)] +#![deny(rust_2018_idioms, warnings)] +#![allow( + clippy::needless_doctest_main, + clippy::type_complexity, + clippy::borrow_interior_mutable_const +)] //! Actix web is a small, pragmatic, and extremely fast web framework //! for Rust. //! -//! ```rust +//! ```rust,no_run //! use actix_web::{web, App, Responder, HttpServer}; -//! # use std::thread; //! -//! fn index(info: web::Path<(String, u32)>) -> impl Responder { +//! async fn index(info: web::Path<(String, u32)>) -> impl Responder { //! format!("Hello {}! id:{}", info.0, info.1) //! } //! -//! fn main() -> std::io::Result<()> { -//! # thread::spawn(|| { +//! #[actix_rt::main] +//! async fn main() -> std::io::Result<()> { //! HttpServer::new(|| App::new().service( //! web::resource("/{name}/{id}/index.html").to(index)) //! ) //! .bind("127.0.0.1:8080")? //! .run() -//! # }); -//! # Ok(()) +//! .await //! } //! ``` //! @@ -44,7 +47,7 @@ //! configure servers. //! //! * [web](web/index.html): This module -//! provide essentials helper functions and types for application registration. +//! provides essential helper functions and types for application registration. //! //! * [HttpRequest](struct.HttpRequest.html) and //! [HttpResponse](struct.HttpResponse.html): These structs @@ -63,23 +66,16 @@ //! * SSL support with OpenSSL or `native-tls` //! * Middlewares (`Logger`, `Session`, `CORS`, `DefaultHeaders`) //! * Supports [Actix actor framework](https://github.com/actix/actix) -//! * Supported Rust version: 1.36 or later +//! * Supported Rust version: 1.39 or later //! //! ## Package feature //! //! * `client` - enables http client (default enabled) -//! * `ssl` - enables ssl support via `openssl` crate, supports `http/2` -//! * `rust-tls` - enables ssl support via `rustls` crate, supports `http/2` +//! * `compress` - enables content encoding compression support (default enabled) +//! * `openssl` - enables ssl support via `openssl` crate, supports `http/2` +//! * `rustls` - enables ssl support via `rustls` crate, supports `http/2` //! * `secure-cookies` - enables secure cookies support, includes `ring` crate as //! dependency -//! * `brotli` - enables `brotli` compression support, requires `c` -//! compiler (default enabled) -//! * `flate2-zlib` - enables `gzip`, `deflate` compression support, requires -//! `c` compiler (default enabled) -//! * `flate2-rust` - experimental rust based implementation for -//! `gzip`, `deflate` compression. -//! * `uds` - Unix domain support, enables `HttpServer::bind_uds()` method. -//! #![allow(clippy::type_complexity, clippy::new_without_default)] mod app; @@ -104,10 +100,6 @@ pub mod test; mod types; pub mod web; -#[allow(unused_imports)] -#[macro_use] -extern crate actix_web_codegen; - #[doc(hidden)] pub use actix_web_codegen::*; @@ -137,17 +129,19 @@ pub mod dev { pub use crate::config::{AppConfig, AppService}; #[doc(hidden)] - pub use crate::handler::{AsyncFactory, Factory}; + pub use crate::handler::Factory; pub use crate::info::ConnectionInfo; pub use crate::rmap::ResourceMap; pub use crate::service::{ HttpServiceFactory, ServiceRequest, ServiceResponse, WebService, }; + pub use crate::types::form::UrlEncoded; pub use crate::types::json::JsonBody; pub use crate::types::readlines::Readlines; pub use actix_http::body::{Body, BodySize, MessageBody, ResponseBody, SizedStream}; + #[cfg(feature = "compress")] pub use actix_http::encoding::Decoder as Decompress; pub use actix_http::ResponseBuilder as HttpResponseBuilder; pub use actix_http::{ @@ -157,37 +151,76 @@ pub mod dev { pub use actix_server::Server; pub use actix_service::{Service, Transform}; - pub(crate) fn insert_slash(path: &str) -> String { - let mut path = path.to_owned(); - if !path.is_empty() && !path.starts_with('/') { - path.insert(0, '/'); - }; - path + pub(crate) fn insert_slash(mut patterns: Vec) -> Vec { + for path in &mut patterns { + if !path.is_empty() && !path.starts_with('/') { + path.insert(0, '/'); + }; + } + patterns + } + + use crate::http::header::ContentEncoding; + use actix_http::{Response, ResponseBuilder}; + + struct Enc(ContentEncoding); + + /// Helper trait that allows to set specific encoding for response. + pub trait BodyEncoding { + /// Get content encoding + fn get_encoding(&self) -> Option; + + /// Set content encoding + fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self; + } + + impl BodyEncoding for ResponseBuilder { + fn get_encoding(&self) -> Option { + if let Some(ref enc) = self.extensions().get::() { + Some(enc.0) + } else { + None + } + } + + fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self { + self.extensions_mut().insert(Enc(encoding)); + self + } + } + + impl BodyEncoding for Response { + fn get_encoding(&self) -> Option { + if let Some(ref enc) = self.extensions().get::() { + Some(enc.0) + } else { + None + } + } + + fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self { + self.extensions_mut().insert(Enc(encoding)); + self + } } } -#[cfg(feature = "client")] pub mod client { //! An HTTP Client //! //! ```rust - //! # use futures::future::{Future, lazy}; - //! use actix_rt::System; //! use actix_web::client::Client; //! - //! fn main() { - //! System::new("test").block_on(lazy(|| { - //! let mut client = Client::default(); + //! #[actix_rt::main] + //! async fn main() { + //! let mut client = Client::default(); //! - //! client.get("http://www.rust-lang.org") // <- Create request builder - //! .header("User-Agent", "Actix-web") - //! .send() // <- Send http request - //! .map_err(|_| ()) - //! .and_then(|response| { // <- server http response - //! println!("Response: {:?}", response); - //! Ok(()) - //! }) - //! })); + //! // Create request builder and send request + //! let response = client.get("http://www.rust-lang.org") + //! .header("User-Agent", "Actix-web") + //! .send().await; // <- Send http request + //! + //! println!("Response: {:?}", response); //! } //! ``` pub use awc::error::{ diff --git a/src/middleware/compress.rs b/src/middleware/compress.rs index 86665d824..70006ab3c 100644 --- a/src/middleware/compress.rs +++ b/src/middleware/compress.rs @@ -1,39 +1,22 @@ //! `Middleware` for compressing response body. use std::cmp; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; use std::str::FromStr; +use std::task::{Context, Poll}; use actix_http::body::MessageBody; use actix_http::encoding::Encoder; use actix_http::http::header::{ContentEncoding, ACCEPT_ENCODING}; -use actix_http::{Error, Response, ResponseBuilder}; +use actix_http::Error; use actix_service::{Service, Transform}; -use futures::future::{ok, FutureResult}; -use futures::{Async, Future, Poll}; +use futures::future::{ok, Ready}; +use pin_project::pin_project; +use crate::dev::BodyEncoding; use crate::service::{ServiceRequest, ServiceResponse}; -struct Enc(ContentEncoding); - -/// Helper trait that allows to set specific encoding for response. -pub trait BodyEncoding { - fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self; -} - -impl BodyEncoding for ResponseBuilder { - fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self { - self.extensions_mut().insert(Enc(encoding)); - self - } -} - -impl BodyEncoding for Response { - fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self { - self.extensions_mut().insert(Enc(encoding)); - self - } -} - #[derive(Debug, Clone)] /// `Middleware` for compressing response body. /// @@ -78,7 +61,7 @@ where type Error = Error; type InitError = (); type Transform = CompressMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(CompressMiddleware { @@ -103,8 +86,8 @@ where type Error = Error; type Future = CompressResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { @@ -128,14 +111,16 @@ where } #[doc(hidden)] +#[pin_project] pub struct CompressResponse where S: Service, B: MessageBody, { + #[pin] fut: S::Future, encoding: ContentEncoding, - _t: PhantomData<(B)>, + _t: PhantomData, } impl Future for CompressResponse @@ -143,21 +128,25 @@ where B: MessageBody, S: Service, Error = Error>, { - type Item = ServiceResponse>; - type Error = Error; + type Output = Result>, Error>; - fn poll(&mut self) -> Poll { - let resp = futures::try_ready!(self.fut.poll()); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); - let enc = if let Some(enc) = resp.response().extensions().get::() { - enc.0 - } else { - self.encoding - }; + match futures::ready!(this.fut.poll(cx)) { + Ok(resp) => { + let enc = if let Some(enc) = resp.response().get_encoding() { + enc + } else { + *this.encoding + }; - Ok(Async::Ready(resp.map_body(move |head, body| { - Encoder::response(enc, head, body) - }))) + Poll::Ready(Ok( + resp.map_body(move |head, body| Encoder::response(enc, head, body)) + )) + } + Err(e) => Poll::Ready(Err(e)), + } } } @@ -169,6 +158,7 @@ struct AcceptEncoding { impl Eq for AcceptEncoding {} impl Ord for AcceptEncoding { + #[allow(clippy::comparison_chain)] fn cmp(&self, other: &AcceptEncoding) -> cmp::Ordering { if self.quality > other.quality { cmp::Ordering::Less diff --git a/src/middleware/condition.rs b/src/middleware/condition.rs index ddc5fdd42..68d06837e 100644 --- a/src/middleware/condition.rs +++ b/src/middleware/condition.rs @@ -1,7 +1,8 @@ //! `Middleware` for conditionally enables another middleware. +use std::task::{Context, Poll}; + use actix_service::{Service, Transform}; -use futures::future::{ok, Either, FutureResult, Map}; -use futures::{Future, Poll}; +use futures::future::{ok, Either, FutureExt, LocalBoxFuture}; /// `Middleware` for conditionally enables another middleware. /// The controled middleware must not change the `Service` interfaces. @@ -13,11 +14,11 @@ use futures::{Future, Poll}; /// use actix_web::middleware::{Condition, NormalizePath}; /// use actix_web::App; /// -/// fn main() { -/// let enable_normalize = std::env::var("NORMALIZE_PATH") == Ok("true".into()); -/// let app = App::new() -/// .wrap(Condition::new(enable_normalize, NormalizePath)); -/// } +/// # fn main() { +/// let enable_normalize = std::env::var("NORMALIZE_PATH") == Ok("true".into()); +/// let app = App::new() +/// .wrap(Condition::new(enable_normalize, NormalizePath)); +/// # } /// ``` pub struct Condition { trans: T, @@ -32,29 +33,31 @@ impl Condition { impl Transform for Condition where - S: Service, + S: Service + 'static, T: Transform, + T::Future: 'static, + T::InitError: 'static, + T::Transform: 'static, { type Request = S::Request; type Response = S::Response; type Error = S::Error; type InitError = T::InitError; type Transform = ConditionMiddleware; - type Future = Either< - Map Self::Transform>, - FutureResult, - >; + type Future = LocalBoxFuture<'static, Result>; fn new_transform(&self, service: S) -> Self::Future { if self.enable { - let f = self - .trans - .new_transform(service) - .map(ConditionMiddleware::Enable as fn(T::Transform) -> Self::Transform); - Either::A(f) + let f = self.trans.new_transform(service).map(|res| { + res.map( + ConditionMiddleware::Enable as fn(T::Transform) -> Self::Transform, + ) + }); + Either::Left(f) } else { - Either::B(ok(ConditionMiddleware::Disable(service))) + Either::Right(ok(ConditionMiddleware::Disable(service))) } + .boxed_local() } } @@ -73,19 +76,19 @@ where type Error = E::Error; type Future = Either; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { use ConditionMiddleware::*; match self { - Enable(service) => service.poll_ready(), - Disable(service) => service.poll_ready(), + Enable(service) => service.poll_ready(cx), + Disable(service) => service.poll_ready(cx), } } fn call(&mut self, req: E::Request) -> Self::Future { use ConditionMiddleware::*; match self { - Enable(service) => Either::A(service.call(req)), - Disable(service) => Either::B(service.call(req)), + Enable(service) => Either::Left(service.call(req)), + Disable(service) => Either::Right(service.call(req)), } } } @@ -109,35 +112,40 @@ mod tests { Ok(ErrorHandlerResponse::Response(res)) } - #[test] - fn test_handler_enabled() { + #[actix_rt::test] + async fn test_handler_enabled() { let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::InternalServerError().finish()) + ok(req.into_response(HttpResponse::InternalServerError().finish())) }; let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); - let mut mw = - test::block_on(Condition::new(true, mw).new_transform(srv.into_service())) - .unwrap(); - let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); + let mut mw = Condition::new(true, mw) + .new_transform(srv.into_service()) + .await + .unwrap(); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } - #[test] - fn test_handler_disabled() { + + #[actix_rt::test] + async fn test_handler_disabled() { let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::InternalServerError().finish()) + ok(req.into_response(HttpResponse::InternalServerError().finish())) }; let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); - let mut mw = - test::block_on(Condition::new(false, mw).new_transform(srv.into_service())) - .unwrap(); + let mut mw = Condition::new(false, mw) + .new_transform(srv.into_service()) + .await + .unwrap(); - let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE), None); } } diff --git a/src/middleware/defaultheaders.rs b/src/middleware/defaultheaders.rs index ab2d36c2c..ba001c77b 100644 --- a/src/middleware/defaultheaders.rs +++ b/src/middleware/defaultheaders.rs @@ -1,12 +1,13 @@ //! Middleware for setting default response headers +use std::convert::TryFrom; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_service::{Service, Transform}; -use futures::future::{ok, FutureResult}; -use futures::{Future, Poll}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; -use crate::http::{HeaderMap, HttpTryFrom}; +use crate::http::{Error as HttpError, HeaderMap}; use crate::service::{ServiceRequest, ServiceResponse}; use crate::Error; @@ -58,8 +59,10 @@ impl DefaultHeaders { #[inline] pub fn header(mut self, key: K, value: V) -> Self where - HeaderName: HttpTryFrom, - HeaderValue: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, + HeaderValue: TryFrom, + >::Error: Into, { #[allow(clippy::match_wild_err_arm)] match HeaderName::try_from(key) { @@ -96,7 +99,7 @@ where type Error = Error; type InitError = (); type Transform = DefaultHeadersMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(DefaultHeadersMiddleware { @@ -119,16 +122,19 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { let inner = self.inner.clone(); + let fut = self.service.call(req); + + async move { + let mut res = fut.await?; - Box::new(self.service.call(req).map(move |mut res| { // set response headers for (key, value) in inner.headers.iter() { if !res.headers().contains_key(key) { @@ -142,61 +148,61 @@ where HeaderValue::from_static("application/octet-stream"), ); } - - res - })) + Ok(res) + } + .boxed_local() } } #[cfg(test)] mod tests { use actix_service::IntoService; + use futures::future::ok; use super::*; use crate::dev::ServiceRequest; use crate::http::header::CONTENT_TYPE; - use crate::test::{block_on, ok_service, TestRequest}; + use crate::test::{ok_service, TestRequest}; use crate::HttpResponse; - #[test] - fn test_default_headers() { - let mut mw = block_on( - DefaultHeaders::new() - .header(CONTENT_TYPE, "0001") - .new_transform(ok_service()), - ) - .unwrap(); + #[actix_rt::test] + async fn test_default_headers() { + let mut mw = DefaultHeaders::new() + .header(CONTENT_TYPE, "0001") + .new_transform(ok_service()) + .await + .unwrap(); let req = TestRequest::default().to_srv_request(); - let resp = block_on(mw.call(req)).unwrap(); + let resp = mw.call(req).await.unwrap(); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); let req = TestRequest::default().to_srv_request(); let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish()) + ok(req + .into_response(HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish())) }; - let mut mw = block_on( - DefaultHeaders::new() - .header(CONTENT_TYPE, "0001") - .new_transform(srv.into_service()), - ) - .unwrap(); - let resp = block_on(mw.call(req)).unwrap(); + let mut mw = DefaultHeaders::new() + .header(CONTENT_TYPE, "0001") + .new_transform(srv.into_service()) + .await + .unwrap(); + let resp = mw.call(req).await.unwrap(); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002"); } - #[test] - fn test_content_type() { - let srv = |req: ServiceRequest| req.into_response(HttpResponse::Ok().finish()); - let mut mw = block_on( - DefaultHeaders::new() - .content_type() - .new_transform(srv.into_service()), - ) - .unwrap(); + #[actix_rt::test] + async fn test_content_type() { + let srv = + |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())); + let mut mw = DefaultHeaders::new() + .content_type() + .new_transform(srv.into_service()) + .await + .unwrap(); let req = TestRequest::default().to_srv_request(); - let resp = block_on(mw.call(req)).unwrap(); + let resp = mw.call(req).await.unwrap(); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), "application/octet-stream" diff --git a/src/middleware/errhandlers.rs b/src/middleware/errhandlers.rs index 5f73d4d7e..71886af0b 100644 --- a/src/middleware/errhandlers.rs +++ b/src/middleware/errhandlers.rs @@ -1,10 +1,10 @@ //! Custom handlers service for responses. use std::rc::Rc; +use std::task::{Context, Poll}; use actix_service::{Service, Transform}; -use futures::future::{err, ok, Either, Future, FutureResult}; -use futures::Poll; -use hashbrown::HashMap; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; +use fxhash::FxHashMap; use crate::dev::{ServiceRequest, ServiceResponse}; use crate::error::{Error, Result}; @@ -15,7 +15,7 @@ pub enum ErrorHandlerResponse { /// New http response got generated Response(ServiceResponse), /// Result is a future that resolves to a new http response - Future(Box, Error = Error>>), + Future(LocalBoxFuture<'static, Result, Error>>), } type ErrorHandler = dyn Fn(ServiceResponse) -> Result>; @@ -39,26 +39,26 @@ type ErrorHandler = dyn Fn(ServiceResponse) -> Result { - handlers: Rc>>>, + handlers: Rc>>>, } impl Default for ErrorHandlers { fn default() -> Self { ErrorHandlers { - handlers: Rc::new(HashMap::new()), + handlers: Rc::new(FxHashMap::default()), } } } @@ -92,7 +92,7 @@ where type Error = Error; type InitError = (); type Transform = ErrorHandlersMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(ErrorHandlersMiddleware { @@ -105,7 +105,7 @@ where #[doc(hidden)] pub struct ErrorHandlersMiddleware { service: S, - handlers: Rc>>>, + handlers: Rc>>>, } impl Service for ErrorHandlersMiddleware @@ -117,26 +117,30 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { let handlers = self.handlers.clone(); + let fut = self.service.call(req); + + async move { + let res = fut.await?; - Box::new(self.service.call(req).and_then(move |res| { if let Some(handler) = handlers.get(&res.status()) { match handler(res) { - Ok(ErrorHandlerResponse::Response(res)) => Either::A(ok(res)), - Ok(ErrorHandlerResponse::Future(fut)) => Either::B(fut), - Err(e) => Either::A(err(e)), + Ok(ErrorHandlerResponse::Response(res)) => Ok(res), + Ok(ErrorHandlerResponse::Future(fut)) => fut.await, + Err(e) => Err(e), } } else { - Either::A(ok(res)) + Ok(res) } - })) + } + .boxed_local() } } @@ -157,20 +161,20 @@ mod tests { Ok(ErrorHandlerResponse::Response(res)) } - #[test] - fn test_handler() { + #[actix_rt::test] + async fn test_handler() { let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::InternalServerError().finish()) + ok(req.into_response(HttpResponse::InternalServerError().finish())) }; - let mut mw = test::block_on( - ErrorHandlers::new() - .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500) - .new_transform(srv.into_service()), - ) - .unwrap(); + let mut mw = ErrorHandlers::new() + .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500) + .new_transform(srv.into_service()) + .await + .unwrap(); - let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } @@ -180,23 +184,23 @@ mod tests { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); - Ok(ErrorHandlerResponse::Future(Box::new(ok(res)))) + Ok(ErrorHandlerResponse::Future(ok(res).boxed_local())) } - #[test] - fn test_handler_async() { + #[actix_rt::test] + async fn test_handler_async() { let srv = |req: ServiceRequest| { - req.into_response(HttpResponse::InternalServerError().finish()) + ok(req.into_response(HttpResponse::InternalServerError().finish())) }; - let mut mw = test::block_on( - ErrorHandlers::new() - .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async) - .new_transform(srv.into_service()), - ) - .unwrap(); + let mut mw = ErrorHandlers::new() + .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async) + .new_transform(srv.into_service()) + .await + .unwrap(); - let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request()); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } } diff --git a/src/middleware/logger.rs b/src/middleware/logger.rs index f450f0481..97fa7463f 100644 --- a/src/middleware/logger.rs +++ b/src/middleware/logger.rs @@ -1,21 +1,24 @@ //! Request logging middleware use std::collections::HashSet; +use std::convert::TryFrom; use std::env; use std::fmt::{self, Display, Formatter}; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_service::{Service, Transform}; use bytes::Bytes; -use futures::future::{ok, FutureResult}; -use futures::{Async, Future, Poll}; +use futures::future::{ok, Ready}; use log::debug; use regex::Regex; use time; use crate::dev::{BodySize, MessageBody, ResponseBody}; use crate::error::{Error, Result}; -use crate::http::{HeaderName, HttpTryFrom, StatusCode}; +use crate::http::{HeaderName, StatusCode}; use crate::service::{ServiceRequest, ServiceResponse}; use crate::HttpResponse; @@ -125,7 +128,7 @@ where type Error = Error; type InitError = (); type Transform = LoggerMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(LoggerMiddleware { @@ -151,8 +154,8 @@ where type Error = Error; type Future = LoggerResponse; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { @@ -181,11 +184,13 @@ where } #[doc(hidden)] +#[pin_project::pin_project] pub struct LoggerResponse where B: MessageBody, S: Service, { + #[pin] fut: S::Future, time: time::Tm, format: Option, @@ -197,11 +202,15 @@ where B: MessageBody, S: Service, Error = Error>, { - type Item = ServiceResponse>; - type Error = Error; + type Output = Result>, Error>; - fn poll(&mut self) -> Poll { - let res = futures::try_ready!(self.fut.poll()); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + let res = match futures::ready!(this.fut.poll(cx)) { + Ok(res) => res, + Err(e) => return Poll::Ready(Err(e)), + }; if let Some(error) = res.response().error() { if res.response().head().status != StatusCode::INTERNAL_SERVER_ERROR { @@ -209,18 +218,21 @@ where } } - if let Some(ref mut format) = self.format { + if let Some(ref mut format) = this.format { for unit in &mut format.0 { unit.render_response(res.response()); } } - Ok(Async::Ready(res.map_body(move |_, body| { + let time = *this.time; + let format = this.format.take(); + + Poll::Ready(Ok(res.map_body(move |_, body| { ResponseBody::Body(StreamLog { body, + time, + format, size: 0, - time: self.time, - format: self.format.take(), }) }))) } @@ -236,7 +248,7 @@ pub struct StreamLog { impl Drop for StreamLog { fn drop(&mut self) { if let Some(ref format) = self.format { - let render = |fmt: &mut Formatter| { + let render = |fmt: &mut Formatter<'_>| { for unit in &format.0 { unit.render(fmt, self.size, self.time)?; } @@ -252,13 +264,13 @@ impl MessageBody for StreamLog { self.body.size() } - fn poll_next(&mut self) -> Poll, Error> { - match self.body.poll_next()? { - Async::Ready(Some(chunk)) => { + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { + match self.body.poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { self.size += chunk.len(); - Ok(Async::Ready(Some(chunk))) + Poll::Ready(Some(Ok(chunk))) } - val => Ok(val), + val => val, } } } @@ -352,7 +364,7 @@ pub enum FormatText { impl FormatText { fn render( &self, - fmt: &mut Formatter, + fmt: &mut Formatter<'_>, size: usize, entry_time: time::Tm, ) -> Result<(), fmt::Error> { @@ -452,11 +464,11 @@ impl FormatText { } pub(crate) struct FormatDisplay<'a>( - &'a dyn Fn(&mut Formatter) -> Result<(), fmt::Error>, + &'a dyn Fn(&mut Formatter<'_>) -> Result<(), fmt::Error>, ); impl<'a> fmt::Display for FormatDisplay<'a> { - fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { + fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), fmt::Error> { (self.0)(fmt) } } @@ -464,34 +476,35 @@ impl<'a> fmt::Display for FormatDisplay<'a> { #[cfg(test)] mod tests { use actix_service::{IntoService, Service, Transform}; + use futures::future::ok; use super::*; use crate::http::{header, StatusCode}; - use crate::test::{block_on, TestRequest}; + use crate::test::TestRequest; - #[test] - fn test_logger() { + #[actix_rt::test] + async fn test_logger() { let srv = |req: ServiceRequest| { - req.into_response( + ok(req.into_response( HttpResponse::build(StatusCode::OK) .header("X-Test", "ttt") .finish(), - ) + )) }; let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test"); - let mut srv = block_on(logger.new_transform(srv.into_service())).unwrap(); + let mut srv = logger.new_transform(srv.into_service()).await.unwrap(); let req = TestRequest::with_header( header::USER_AGENT, header::HeaderValue::from_static("ACTIX-WEB"), ) .to_srv_request(); - let _res = block_on(srv.call(req)); + let _res = srv.call(req).await; } - #[test] - fn test_url_path() { + #[actix_rt::test] + async fn test_url_path() { let mut format = Format::new("%T %U"); let req = TestRequest::with_header( header::USER_AGENT, @@ -510,7 +523,7 @@ mod tests { unit.render_response(&resp); } - let render = |fmt: &mut Formatter| { + let render = |fmt: &mut Formatter<'_>| { for unit in &format.0 { unit.render(fmt, 1024, now)?; } @@ -521,8 +534,8 @@ mod tests { assert!(s.contains("/test/route/yeah")); } - #[test] - fn test_default_format() { + #[actix_rt::test] + async fn test_default_format() { let mut format = Format::default(); let req = TestRequest::with_header( @@ -542,7 +555,7 @@ mod tests { } let entry_time = time::now(); - let render = |fmt: &mut Formatter| { + let render = |fmt: &mut Formatter<'_>| { for unit in &format.0 { unit.render(fmt, 1024, entry_time)?; } @@ -554,8 +567,8 @@ mod tests { assert!(s.contains("ACTIX-WEB")); } - #[test] - fn test_request_time_format() { + #[actix_rt::test] + async fn test_request_time_format() { let mut format = Format::new("%t"); let req = TestRequest::default().to_srv_request(); @@ -569,7 +582,7 @@ mod tests { unit.render_response(&resp); } - let render = |fmt: &mut Formatter| { + let render = |fmt: &mut Formatter<'_>| { for unit in &format.0 { unit.render(fmt, 1024, now)?; } diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 311d0ee99..f0d42cc2a 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -1,14 +1,17 @@ //! Middlewares -mod compress; -pub use self::compress::{BodyEncoding, Compress}; +#[cfg(feature = "compress")] +mod compress; +#[cfg(feature = "compress")] +pub use self::compress::Compress; + +mod condition; mod defaultheaders; pub mod errhandlers; mod logger; mod normalize; -mod condition; +pub use self::condition::Condition; pub use self::defaultheaders::DefaultHeaders; pub use self::logger::Logger; pub use self::normalize::NormalizePath; -pub use self::condition::Condition; diff --git a/src/middleware/normalize.rs b/src/middleware/normalize.rs index 9cfbefb30..f6b834bfe 100644 --- a/src/middleware/normalize.rs +++ b/src/middleware/normalize.rs @@ -1,9 +1,10 @@ //! `Middleware` to normalize request's URI +use std::task::{Context, Poll}; -use actix_http::http::{HttpTryFrom, PathAndQuery, Uri}; +use actix_http::http::{PathAndQuery, Uri}; use actix_service::{Service, Transform}; use bytes::Bytes; -use futures::future::{self, FutureResult}; +use futures::future::{ok, Ready}; use regex::Regex; use crate::service::{ServiceRequest, ServiceResponse}; @@ -19,15 +20,15 @@ use crate::Error; /// ```rust /// use actix_web::{web, http, middleware, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new() -/// .wrap(middleware::NormalizePath) -/// .service( -/// web::resource("/test") -/// .route(web::get().to(|| HttpResponse::Ok())) -/// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) -/// ); -/// } +/// # fn main() { +/// let app = App::new() +/// .wrap(middleware::NormalizePath) +/// .service( +/// web::resource("/test") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) +/// ); +/// # } /// ``` pub struct NormalizePath; @@ -42,10 +43,10 @@ where type Error = Error; type InitError = (); type Transform = NormalizePathNormalization; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - future::ok(NormalizePathNormalization { + ok(NormalizePathNormalization { service, merge_slash: Regex::new("//+").unwrap(), }) @@ -67,13 +68,12 @@ where type Error = Error; type Future = S::Future; - fn poll_ready(&mut self) -> futures::Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, mut req: ServiceRequest) -> Self::Future { let head = req.head_mut(); - let path = head.uri.path(); let original_len = path.len(); let path = self.merge_slash.replace_all(path, "/"); @@ -85,9 +85,9 @@ where let path = if let Some(q) = pq.query() { Bytes::from(format!("{}?{}", path, q)) } else { - Bytes::from(path.as_ref()) + Bytes::copy_from_slice(path.as_bytes()) }; - parts.path_and_query = Some(PathAndQuery::try_from(path).unwrap()); + parts.path_and_query = Some(PathAndQuery::from_maybe_shared(path).unwrap()); let uri = Uri::from_parts(parts).unwrap(); req.match_info_mut().get_mut().update(&uri); @@ -104,51 +104,56 @@ mod tests { use super::*; use crate::dev::ServiceRequest; - use crate::test::{block_on, call_service, init_service, TestRequest}; + use crate::test::{call_service, init_service, TestRequest}; use crate::{web, App, HttpResponse}; - #[test] - fn test_wrap() { + #[actix_rt::test] + async fn test_wrap() { let mut app = init_service( App::new() .wrap(NormalizePath::default()) .service(web::resource("/v1/something/").to(|| HttpResponse::Ok())), - ); + ) + .await; let req = TestRequest::with_uri("/v1//something////").to_request(); - let res = call_service(&mut app, req); + let res = call_service(&mut app, req).await; assert!(res.status().is_success()); } - #[test] - fn test_in_place_normalization() { + #[actix_rt::test] + async fn test_in_place_normalization() { let srv = |req: ServiceRequest| { assert_eq!("/v1/something/", req.path()); - req.into_response(HttpResponse::Ok().finish()) + ok(req.into_response(HttpResponse::Ok().finish())) }; - let mut normalize = - block_on(NormalizePath.new_transform(srv.into_service())).unwrap(); + let mut normalize = NormalizePath + .new_transform(srv.into_service()) + .await + .unwrap(); let req = TestRequest::with_uri("/v1//something////").to_srv_request(); - let res = block_on(normalize.call(req)).unwrap(); + let res = normalize.call(req).await.unwrap(); assert!(res.status().is_success()); } - #[test] - fn should_normalize_nothing() { + #[actix_rt::test] + async fn should_normalize_nothing() { const URI: &str = "/v1/something/"; let srv = |req: ServiceRequest| { assert_eq!(URI, req.path()); - req.into_response(HttpResponse::Ok().finish()) + ok(req.into_response(HttpResponse::Ok().finish())) }; - let mut normalize = - block_on(NormalizePath.new_transform(srv.into_service())).unwrap(); + let mut normalize = NormalizePath + .new_transform(srv.into_service()) + .await + .unwrap(); let req = TestRequest::with_uri(URI).to_srv_request(); - let res = block_on(normalize.call(req)).unwrap(); + let res = normalize.call(req).await.unwrap(); assert!(res.status().is_success()); } } diff --git a/src/request.rs b/src/request.rs index 6d9d26e8c..cd9c72313 100644 --- a/src/request.rs +++ b/src/request.rs @@ -5,9 +5,9 @@ use std::{fmt, net}; use actix_http::http::{HeaderMap, Method, Uri, Version}; use actix_http::{Error, Extensions, HttpMessage, Message, Payload, RequestHead}; use actix_router::{Path, Url}; +use futures::future::{ok, Ready}; use crate::config::AppConfig; -use crate::data::Data; use crate::error::UrlGenerationError; use crate::extract::FromRequest; use crate::info::ConnectionInfo; @@ -124,13 +124,13 @@ impl HttpRequest { /// Request extensions #[inline] - pub fn extensions(&self) -> Ref { + pub fn extensions(&self) -> Ref<'_, Extensions> { self.head().extensions() } /// Mutable reference to a the request's extensions #[inline] - pub fn extensions_mut(&self) -> RefMut { + pub fn extensions_mut(&self) -> RefMut<'_, Extensions> { self.head().extensions_mut() } @@ -196,7 +196,7 @@ impl HttpRequest { /// This method panics if request's extensions container is already /// borrowed. #[inline] - pub fn connection_info(&self) -> Ref { + pub fn connection_info(&self) -> Ref<'_, ConnectionInfo> { ConnectionInfo::get(self.head(), &*self.app_config()) } @@ -206,25 +206,21 @@ impl HttpRequest { &self.0.config } - /// Get an application data stored with `App::data()` method during - /// application configuration. + /// Get an application data object stored with `App::data` or `App::app_data` + /// methods during application configuration. + /// + /// If `App::data` was used to store object, use `Data`: + /// + /// ```rust,ignore + /// let opt_t = req.app_data::>(); + /// ``` pub fn app_data(&self) -> Option<&T> { - if let Some(st) = self.0.app_data.get::>() { + if let Some(st) = self.0.app_data.get::() { Some(&st) } else { None } } - - /// Get an application data stored with `App::data()` method during - /// application configuration. - pub fn get_app_data(&self) -> Option> { - if let Some(st) = self.0.app_data.get::>() { - Some(st.clone()) - } else { - None - } - } } impl HttpMessage for HttpRequest { @@ -238,13 +234,13 @@ impl HttpMessage for HttpRequest { /// Request extensions #[inline] - fn extensions(&self) -> Ref { + fn extensions(&self) -> Ref<'_, Extensions> { self.0.head.extensions() } /// Mutable reference to a the request's extensions #[inline] - fn extensions_mut(&self) -> RefMut { + fn extensions_mut(&self) -> RefMut<'_, Extensions> { self.0.head.extensions_mut() } @@ -271,11 +267,11 @@ impl Drop for HttpRequest { /// ## Example /// /// ```rust -/// # #[macro_use] extern crate serde_derive; /// use actix_web::{web, App, HttpRequest}; +/// use serde_derive::Deserialize; /// /// /// extract `Thing` from request -/// fn index(req: HttpRequest) -> String { +/// async fn index(req: HttpRequest) -> String { /// format!("Got thing: {:?}", req) /// } /// @@ -289,16 +285,16 @@ impl Drop for HttpRequest { impl FromRequest for HttpRequest { type Config = (); type Error = Error; - type Future = Result; + type Future = Ready>; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { - Ok(req.clone()) + ok(req.clone()) } } impl fmt::Debug for HttpRequest { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, "\nHttpRequest {:?} {}:{}", @@ -465,9 +461,9 @@ mod tests { ); } - #[test] - fn test_app_data() { - let mut srv = init_service(App::new().data(10usize).service( + #[actix_rt::test] + async fn test_data() { + let mut srv = init_service(App::new().app_data(10usize).service( web::resource("/").to(|req: HttpRequest| { if req.app_data::().is_some() { HttpResponse::Ok() @@ -475,13 +471,14 @@ mod tests { HttpResponse::BadRequest() } }), - )); + )) + .await; let req = TestRequest::default().to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); - let mut srv = init_service(App::new().data(10u32).service( + let mut srv = init_service(App::new().app_data(10u32).service( web::resource("/").to(|req: HttpRequest| { if req.app_data::().is_some() { HttpResponse::Ok() @@ -489,15 +486,16 @@ mod tests { HttpResponse::BadRequest() } }), - )); + )) + .await; let req = TestRequest::default().to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } - #[test] - fn test_extensions_dropped() { + #[actix_rt::test] + async fn test_extensions_dropped() { struct Tracker { pub dropped: bool, } @@ -520,10 +518,11 @@ mod tests { }); HttpResponse::Ok() }), - )); + )) + .await; let req = TestRequest::default().to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } diff --git a/src/resource.rs b/src/resource.rs index b872049d0..d03024a07 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -1,26 +1,29 @@ use std::cell::RefCell; use std::fmt; +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_http::{Error, Extensions, Response}; -use actix_service::boxed::{self, BoxedNewService, BoxedService}; +use actix_router::IntoPattern; +use actix_service::boxed::{self, BoxService, BoxServiceFactory}; use actix_service::{ - apply_transform, IntoNewService, IntoTransform, NewService, Service, Transform, + apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, Transform, }; -use futures::future::{ok, Either, FutureResult}; -use futures::{Async, Future, IntoFuture, Poll}; +use futures::future::{ok, Either, LocalBoxFuture, Ready}; use crate::data::Data; use crate::dev::{insert_slash, AppService, HttpServiceFactory, ResourceDef}; use crate::extract::FromRequest; use crate::guard::Guard; -use crate::handler::{AsyncFactory, Factory}; +use crate::handler::Factory; use crate::responder::Responder; use crate::route::{CreateRouteService, Route, RouteService}; use crate::service::{ServiceRequest, ServiceResponse}; -type HttpService = BoxedService; -type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; +type HttpService = BoxService; +type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; /// *Resource* is an entry in resources table which corresponds to requested URL. /// @@ -46,7 +49,7 @@ type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error /// Default behavior could be overriden with `default_resource()` method. pub struct Resource { endpoint: T, - rdef: String, + rdef: Vec, name: Option, routes: Vec, data: Option, @@ -56,12 +59,12 @@ pub struct Resource { } impl Resource { - pub fn new(path: &str) -> Resource { + pub fn new(path: T) -> Resource { let fref = Rc::new(RefCell::new(None)); Resource { routes: Vec::new(), - rdef: path.to_string(), + rdef: path.patterns(), name: None, endpoint: ResourceEndpoint::new(fref.clone()), factory_ref: fref, @@ -74,7 +77,7 @@ impl Resource { impl Resource where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -95,7 +98,7 @@ where /// ```rust /// use actix_web::{web, guard, App, HttpResponse}; /// - /// fn index(data: web::Path<(String, String)>) -> &'static str { + /// async fn index(data: web::Path<(String, String)>) -> &'static str { /// "Welcome!" /// } /// @@ -143,7 +146,7 @@ where /// match guards for route selection. /// /// ```rust - /// use actix_web::{web, guard, App, HttpResponse}; + /// use actix_web::{web, guard, App}; /// /// fn main() { /// let app = App::new().service( @@ -153,9 +156,9 @@ where /// .route(web::delete().to(delete_handler)) /// ); /// } - /// # fn get_handler() {} - /// # fn post_handler() {} - /// # fn delete_handler() {} + /// # async fn get_handler() -> impl actix_web::Responder { actix_web::HttpResponse::Ok() } + /// # async fn post_handler() -> impl actix_web::Responder { actix_web::HttpResponse::Ok() } + /// # async fn delete_handler() -> impl actix_web::Responder { actix_web::HttpResponse::Ok() } /// ``` pub fn route(mut self, route: Route) -> Self { self.routes.push(route); @@ -171,7 +174,7 @@ where /// use actix_web::{web, App, FromRequest}; /// /// /// extract text data from request - /// fn index(body: String) -> String { + /// async fn index(body: String) -> String { /// format!("Body {}!", body) /// } /// @@ -189,11 +192,18 @@ where /// )); /// } /// ``` - pub fn data(mut self, data: U) -> Self { + pub fn data(self, data: U) -> Self { + self.app_data(Data::new(data)) + } + + /// Set or override application data. + /// + /// This method overrides data stored with [`App::app_data()`](#method.app_data) + pub fn app_data(mut self, data: U) -> Self { if self.data.is_none() { self.data = Some(Extensions::new()); } - self.data.as_mut().unwrap().insert(Data::new(data)); + self.data.as_mut().unwrap().insert(data); self } @@ -217,52 +227,17 @@ where /// # fn index(req: HttpRequest) -> HttpResponse { unimplemented!() } /// App::new().service(web::resource("/").route(web::route().to(index))); /// ``` - pub fn to(mut self, handler: F) -> Self + pub fn to(mut self, handler: F) -> Self where - F: Factory + 'static, + F: Factory, I: FromRequest + 'static, - R: Responder + 'static, + R: Future + 'static, + U: Responder + 'static, { self.routes.push(Route::new().to(handler)); self } - /// Register a new route and add async handler. - /// - /// ```rust - /// use actix_web::*; - /// use futures::future::{ok, Future}; - /// - /// fn index(req: HttpRequest) -> impl Future { - /// ok(HttpResponse::Ok().finish()) - /// } - /// - /// App::new().service(web::resource("/").to_async(index)); - /// ``` - /// - /// This is shortcut for: - /// - /// ```rust - /// # use actix_web::*; - /// # use futures::future::Future; - /// # fn index(req: HttpRequest) -> Box> { - /// # unimplemented!() - /// # } - /// App::new().service(web::resource("/").route(web::route().to_async(index))); - /// ``` - #[allow(clippy::wrong_self_convention)] - pub fn to_async(mut self, handler: F) -> Self - where - F: AsyncFactory, - I: FromRequest + 'static, - R: IntoFuture + 'static, - R::Item: Responder, - R::Error: Into, - { - self.routes.push(Route::new().to_async(handler)); - self - } - /// Register a resource middleware. /// /// This is similar to `App's` middlewares, but middleware get invoked on resource level. @@ -270,11 +245,11 @@ where /// type (i.e modify response's body). /// /// **Note**: middlewares get called in opposite order of middlewares registration. - pub fn wrap( + pub fn wrap( self, - mw: F, + mw: M, ) -> Resource< - impl NewService< + impl ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -290,11 +265,9 @@ where Error = Error, InitError = (), >, - F: IntoTransform, { - let endpoint = apply_transform(mw, self.endpoint); Resource { - endpoint, + endpoint: apply(mw, self.endpoint), rdef: self.rdef, name: self.name, guards: self.guards, @@ -316,24 +289,26 @@ where /// /// ```rust /// use actix_service::Service; - /// # use futures::Future; /// use actix_web::{web, App}; /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; /// - /// fn index() -> &'static str { + /// async fn index() -> &'static str { /// "Welcome!" /// } /// /// fn main() { /// let app = App::new().service( /// web::resource("/index.html") - /// .wrap_fn(|req, srv| - /// srv.call(req).map(|mut res| { + /// .wrap_fn(|req, srv| { + /// let fut = srv.call(req); + /// async { + /// let mut res = fut.await?; /// res.headers_mut().insert( /// CONTENT_TYPE, HeaderValue::from_static("text/plain"), /// ); - /// res - /// })) + /// Ok(res) + /// } + /// }) /// .route(web::get().to(index))); /// } /// ``` @@ -341,7 +316,7 @@ where self, mw: F, ) -> Resource< - impl NewService< + impl ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -351,9 +326,18 @@ where > where F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, - R: IntoFuture, + R: Future>, { - self.wrap(mw) + Resource { + endpoint: apply_fn_factory(self.endpoint, mw), + rdef: self.rdef, + name: self.name, + guards: self.guards, + routes: self.routes, + default: self.default, + data: self.data, + factory_ref: self.factory_ref, + } } /// Default service to be used if no matching route could be found. @@ -361,8 +345,8 @@ where /// default handler from `App` or `Scope`. pub fn default_service(mut self, f: F) -> Self where - F: IntoNewService, - U: NewService< + F: IntoServiceFactory, + U: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -371,8 +355,8 @@ where U::InitError: fmt::Debug, { // create and configure default resource - self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::new_service( - f.into_new_service().map_init_err(|e| { + self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::factory( + f.into_factory().map_init_err(|e| { log::error!("Can not construct default service: {:?}", e) }), ))))); @@ -383,7 +367,7 @@ where impl HttpServiceFactory for Resource where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -398,9 +382,9 @@ where Some(std::mem::replace(&mut self.guards, Vec::new())) }; let mut rdef = if config.is_root() || !self.rdef.is_empty() { - ResourceDef::new(&insert_slash(&self.rdef)) + ResourceDef::new(insert_slash(self.rdef.clone())) } else { - ResourceDef::new(&self.rdef) + ResourceDef::new(self.rdef.clone()) }; if let Some(ref name) = self.name { *rdef.name_mut() = name.clone(); @@ -413,9 +397,9 @@ where } } -impl IntoNewService for Resource +impl IntoServiceFactory for Resource where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -423,7 +407,7 @@ where InitError = (), >, { - fn into_new_service(self) -> T { + fn into_factory(self) -> T { *self.factory_ref.borrow_mut() = Some(ResourceFactory { routes: self.routes, data: self.data.map(Rc::new), @@ -440,7 +424,7 @@ pub struct ResourceFactory { default: Rc>>>, } -impl NewService for ResourceFactory { +impl ServiceFactory for ResourceFactory { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -449,9 +433,9 @@ impl NewService for ResourceFactory { type Service = ResourceService; type Future = CreateResourceService; - fn new_service(&self, _: &()) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { let default_fut = if let Some(ref default) = *self.default.borrow() { - Some(default.new_service(&())) + Some(default.new_service(())) } else { None }; @@ -460,7 +444,7 @@ impl NewService for ResourceFactory { fut: self .routes .iter() - .map(|route| CreateRouteServiceItem::Future(route.new_service(&()))) + .map(|route| CreateRouteServiceItem::Future(route.new_service(()))) .collect(), data: self.data.clone(), default: None, @@ -478,31 +462,30 @@ pub struct CreateResourceService { fut: Vec, data: Option>, default: Option, - default_fut: Option>>, + default_fut: Option>>, } impl Future for CreateResourceService { - type Item = ResourceService; - type Error = (); + type Output = Result; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut done = true; if let Some(ref mut fut) = self.default_fut { - match fut.poll()? { - Async::Ready(default) => self.default = Some(default), - Async::NotReady => done = false, + match Pin::new(fut).poll(cx)? { + Poll::Ready(default) => self.default = Some(default), + Poll::Pending => done = false, } } // poll http services for item in &mut self.fut { match item { - CreateRouteServiceItem::Future(ref mut fut) => match fut.poll()? { - Async::Ready(route) => { - *item = CreateRouteServiceItem::Service(route) - } - Async::NotReady => { + CreateRouteServiceItem::Future(ref mut fut) => match Pin::new(fut) + .poll(cx)? + { + Poll::Ready(route) => *item = CreateRouteServiceItem::Service(route), + Poll::Pending => { done = false; } }, @@ -519,13 +502,13 @@ impl Future for CreateResourceService { CreateRouteServiceItem::Future(_) => unreachable!(), }) .collect(); - Ok(Async::Ready(ResourceService { + Poll::Ready(Ok(ResourceService { routes, data: self.data.clone(), default: self.default.take(), })) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -541,12 +524,12 @@ impl Service for ResourceService { type Response = ServiceResponse; type Error = Error; type Future = Either< - FutureResult, - Box>, + Ready>, + LocalBoxFuture<'static, Result>, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, mut req: ServiceRequest) -> Self::Future { @@ -555,14 +538,14 @@ impl Service for ResourceService { if let Some(ref data) = self.data { req.set_data_container(data.clone()); } - return route.call(req); + return Either::Right(route.call(req)); } } if let Some(ref mut default) = self.default { - default.call(req) + Either::Right(default.call(req)) } else { let req = req.into_parts().0; - Either::A(ok(ServiceResponse::new( + Either::Left(ok(ServiceResponse::new( req, Response::MethodNotAllowed().finish(), ))) @@ -581,7 +564,7 @@ impl ResourceEndpoint { } } -impl NewService for ResourceEndpoint { +impl ServiceFactory for ResourceEndpoint { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -590,8 +573,8 @@ impl NewService for ResourceEndpoint { type Service = ResourceService; type Future = CreateResourceService; - fn new_service(&self, _: &()) -> Self::Future { - self.factory.borrow_mut().as_mut().unwrap().new_service(&()) + fn new_service(&self, _: ()) -> Self::Future { + self.factory.borrow_mut().as_mut().unwrap().new_service(()) } } @@ -599,45 +582,33 @@ impl NewService for ResourceEndpoint { mod tests { use std::time::Duration; + use actix_rt::time::delay_for; use actix_service::Service; - use futures::{Future, IntoFuture}; - use tokio_timer::sleep; + use futures::future::ok; use crate::http::{header, HeaderValue, Method, StatusCode}; - use crate::service::{ServiceRequest, ServiceResponse}; + use crate::middleware::DefaultHeaders; + use crate::service::ServiceRequest; use crate::test::{call_service, init_service, TestRequest}; use crate::{guard, web, App, Error, HttpResponse}; - fn md( - req: ServiceRequest, - srv: &mut S, - ) -> impl IntoFuture, Error = Error> - where - S: Service< - Request = ServiceRequest, - Response = ServiceResponse, - Error = Error, - >, - { - srv.call(req).map(|mut res| { - res.headers_mut() - .insert(header::CONTENT_TYPE, HeaderValue::from_static("0001")); - res - }) - } - - #[test] - fn test_middleware() { - let mut srv = init_service( - App::new().service( - web::resource("/test") - .name("test") - .wrap(md) - .route(web::get().to(|| HttpResponse::Ok())), - ), - ); + #[actix_rt::test] + async fn test_middleware() { + let mut srv = + init_service( + App::new().service( + web::resource("/test") + .name("test") + .wrap(DefaultHeaders::new().header( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + )) + .route(web::get().to(|| HttpResponse::Ok())), + ), + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -645,25 +616,29 @@ mod tests { ); } - #[test] - fn test_middleware_fn() { + #[actix_rt::test] + async fn test_middleware_fn() { let mut srv = init_service( App::new().service( web::resource("/test") .wrap_fn(|req, srv| { - srv.call(req).map(|mut res| { - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("0001"), - ); - res - }) + let fut = srv.call(req); + async { + fut.await.map(|mut res| { + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + ); + res + }) + } }) .route(web::get().to(|| HttpResponse::Ok())), ), - ); + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -671,36 +646,58 @@ mod tests { ); } - #[test] - fn test_to_async() { + #[actix_rt::test] + async fn test_to() { let mut srv = - init_service(App::new().service(web::resource("/test").to_async(|| { - sleep(Duration::from_millis(100)).then(|_| HttpResponse::Ok()) - }))); + init_service(App::new().service(web::resource("/test").to(|| { + async { + delay_for(Duration::from_millis(100)).await; + Ok::<_, Error>(HttpResponse::Ok()) + } + }))) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn test_default_resource() { + #[actix_rt::test] + async fn test_pattern() { + let mut srv = init_service( + App::new().service( + web::resource(["/test", "/test2"]) + .to(|| async { Ok::<_, Error>(HttpResponse::Ok()) }), + ), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + let req = TestRequest::with_uri("/test2").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_default_resource() { let mut srv = init_service( App::new() .service( web::resource("/test").route(web::get().to(|| HttpResponse::Ok())), ) .default_service(|r: ServiceRequest| { - r.into_response(HttpResponse::BadRequest()) + ok(r.into_response(HttpResponse::BadRequest())) }), - ); + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test") .method(Method::POST) .to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); let mut srv = init_service( @@ -708,24 +705,25 @@ mod tests { web::resource("/test") .route(web::get().to(|| HttpResponse::Ok())) .default_service(|r: ServiceRequest| { - r.into_response(HttpResponse::BadRequest()) + ok(r.into_response(HttpResponse::BadRequest())) }), ), - ); + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test") .method(Method::POST) .to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } - #[test] - fn test_resource_guards() { + #[actix_rt::test] + async fn test_resource_guards() { let mut srv = init_service( App::new() .service( @@ -743,24 +741,56 @@ mod tests { .guard(guard::Delete()) .to(|| HttpResponse::NoContent()), ), - ); + ) + .await; let req = TestRequest::with_uri("/test/it") .method(Method::GET) .to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test/it") .method(Method::PUT) .to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::CREATED); let req = TestRequest::with_uri("/test/it") .method(Method::DELETE) .to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::NO_CONTENT); } + + #[actix_rt::test] + async fn test_data() { + let mut srv = init_service( + App::new() + .data(1.0f64) + .data(1usize) + .app_data(web::Data::new('-')) + .service( + web::resource("/test") + .data(10usize) + .app_data(web::Data::new('*')) + .guard(guard::Get()) + .to( + |data1: web::Data, + data2: web::Data, + data3: web::Data| { + assert_eq!(**data1, 10); + assert_eq!(**data2, '*'); + assert_eq!(**data3, 1.0); + HttpResponse::Ok() + }, + ), + ), + ) + .await; + + let req = TestRequest::get().uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } } diff --git a/src/responder.rs b/src/responder.rs index 4988ad5bc..7189eecf1 100644 --- a/src/responder.rs +++ b/src/responder.rs @@ -1,12 +1,18 @@ +use std::convert::TryFrom; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + use actix_http::error::InternalError; use actix_http::http::{ - header::IntoHeaderValue, Error as HttpError, HeaderMap, HeaderName, HttpTryFrom, - StatusCode, + header::IntoHeaderValue, Error as HttpError, HeaderMap, HeaderName, StatusCode, }; use actix_http::{Error, Response, ResponseBuilder}; use bytes::{Bytes, BytesMut}; -use futures::future::{err, ok, Either as EitherFuture, FutureResult}; -use futures::{try_ready, Async, Future, IntoFuture, Poll}; +use futures::future::{err, ok, Either as EitherFuture, Ready}; +use futures::ready; +use pin_project::{pin_project, project}; use crate::request::HttpRequest; @@ -18,7 +24,7 @@ pub trait Responder { type Error: Into; /// The future response value. - type Future: IntoFuture; + type Future: Future>; /// Convert itself to `AsyncResult` or `Error`. fn respond_to(self, req: &HttpRequest) -> Self::Future; @@ -62,7 +68,8 @@ pub trait Responder { fn with_header(self, key: K, value: V) -> CustomResponder where Self: Sized, - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { CustomResponder::new(self).with_header(key, value) @@ -71,7 +78,7 @@ pub trait Responder { impl Responder for Response { type Error = Error; - type Future = FutureResult; + type Future = Ready>; #[inline] fn respond_to(self, _: &HttpRequest) -> Self::Future { @@ -84,15 +91,14 @@ where T: Responder, { type Error = T::Error; - type Future = EitherFuture< - ::Future, - FutureResult, - >; + type Future = EitherFuture>>; fn respond_to(self, req: &HttpRequest) -> Self::Future { match self { - Some(t) => EitherFuture::A(t.respond_to(req).into_future()), - None => EitherFuture::B(ok(Response::build(StatusCode::NOT_FOUND).finish())), + Some(t) => EitherFuture::Left(t.respond_to(req)), + None => { + EitherFuture::Right(ok(Response::build(StatusCode::NOT_FOUND).finish())) + } } } } @@ -104,23 +110,21 @@ where { type Error = Error; type Future = EitherFuture< - ResponseFuture<::Future>, - FutureResult, + ResponseFuture, + Ready>, >; fn respond_to(self, req: &HttpRequest) -> Self::Future { match self { - Ok(val) => { - EitherFuture::A(ResponseFuture::new(val.respond_to(req).into_future())) - } - Err(e) => EitherFuture::B(err(e.into())), + Ok(val) => EitherFuture::Left(ResponseFuture::new(val.respond_to(req))), + Err(e) => EitherFuture::Right(err(e.into())), } } } impl Responder for ResponseBuilder { type Error = Error; - type Future = FutureResult; + type Future = Ready>; #[inline] fn respond_to(mut self, _: &HttpRequest) -> Self::Future { @@ -128,15 +132,6 @@ impl Responder for ResponseBuilder { } } -impl Responder for () { - type Error = Error; - type Future = FutureResult; - - fn respond_to(self, _: &HttpRequest) -> Self::Future { - ok(Response::build(StatusCode::OK).finish()) - } -} - impl Responder for (T, StatusCode) where T: Responder, @@ -146,7 +141,7 @@ where fn respond_to(self, req: &HttpRequest) -> Self::Future { CustomResponderFut { - fut: self.0.respond_to(req).into_future(), + fut: self.0.respond_to(req), status: Some(self.1), headers: None, } @@ -155,7 +150,7 @@ where impl Responder for &'static str { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK) @@ -166,7 +161,7 @@ impl Responder for &'static str { impl Responder for &'static [u8] { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK) @@ -177,7 +172,7 @@ impl Responder for &'static [u8] { impl Responder for String { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK) @@ -188,7 +183,7 @@ impl Responder for String { impl<'a> Responder for &'a String { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK) @@ -199,7 +194,7 @@ impl<'a> Responder for &'a String { impl Responder for Bytes { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK) @@ -210,7 +205,7 @@ impl Responder for Bytes { impl Responder for BytesMut { type Error = Error; - type Future = FutureResult; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { ok(Response::build(StatusCode::OK) @@ -273,7 +268,8 @@ impl CustomResponder { /// ``` pub fn with_header(mut self, key: K, value: V) -> Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { if self.headers.is_none() { @@ -299,45 +295,49 @@ impl Responder for CustomResponder { fn respond_to(self, req: &HttpRequest) -> Self::Future { CustomResponderFut { - fut: self.responder.respond_to(req).into_future(), + fut: self.responder.respond_to(req), status: self.status, headers: self.headers, } } } +#[pin_project] pub struct CustomResponderFut { - fut: ::Future, + #[pin] + fut: T::Future, status: Option, headers: Option, } impl Future for CustomResponderFut { - type Item = Response; - type Error = T::Error; + type Output = Result; - fn poll(&mut self) -> Poll { - let mut res = try_ready!(self.fut.poll()); - if let Some(status) = self.status { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + let mut res = match ready!(this.fut.poll(cx)) { + Ok(res) => res, + Err(e) => return Poll::Ready(Err(e)), + }; + if let Some(status) = this.status.take() { *res.status_mut() = status; } - if let Some(ref headers) = self.headers { + if let Some(ref headers) = this.headers { for (k, v) in headers { res.headers_mut().insert(k.clone(), v.clone()); } } - Ok(Async::Ready(res)) + Poll::Ready(Ok(res)) } } /// Combines two different responder types into a single type /// /// ```rust -/// # use futures::future::{ok, Future}; /// use actix_web::{Either, Error, HttpResponse}; /// -/// type RegisterResult = -/// Either>>; +/// type RegisterResult = Either>; /// /// fn index() -> RegisterResult { /// if is_a_variant() { @@ -346,9 +346,9 @@ impl Future for CustomResponderFut { /// } else { /// Either::B( /// // <- Right variant -/// Box::new(ok(HttpResponse::Ok() +/// Ok(HttpResponse::Ok() /// .content_type("text/html") -/// .body("Hello!"))) +/// .body("Hello!")) /// ) /// } /// } @@ -369,97 +369,85 @@ where B: Responder, { type Error = Error; - type Future = EitherResponder< - ::Future, - ::Future, - >; + type Future = EitherResponder; fn respond_to(self, req: &HttpRequest) -> Self::Future { match self { - Either::A(a) => EitherResponder::A(a.respond_to(req).into_future()), - Either::B(b) => EitherResponder::B(b.respond_to(req).into_future()), + Either::A(a) => EitherResponder::A(a.respond_to(req)), + Either::B(b) => EitherResponder::B(b.respond_to(req)), } } } +#[pin_project] pub enum EitherResponder where - A: Future, - A::Error: Into, - B: Future, - B::Error: Into, + A: Responder, + B: Responder, { - A(A), - B(B), + A(#[pin] A::Future), + B(#[pin] B::Future), } impl Future for EitherResponder where - A: Future, - A::Error: Into, - B: Future, - B::Error: Into, + A: Responder, + B: Responder, { - type Item = Response; - type Error = Error; + type Output = Result; - fn poll(&mut self) -> Poll { - match self { - EitherResponder::A(ref mut fut) => Ok(fut.poll().map_err(|e| e.into())?), - EitherResponder::B(ref mut fut) => Ok(fut.poll().map_err(|e| e.into())?), + #[project] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + #[project] + match self.project() { + EitherResponder::A(fut) => { + Poll::Ready(ready!(fut.poll(cx)).map_err(|e| e.into())) + } + EitherResponder::B(fut) => { + Poll::Ready(ready!(fut.poll(cx).map_err(|e| e.into()))) + } } } } -impl Responder for Box> -where - I: Responder + 'static, - E: Into + 'static, -{ - type Error = Error; - type Future = Box>; - - #[inline] - fn respond_to(self, req: &HttpRequest) -> Self::Future { - let req = req.clone(); - Box::new( - self.map_err(|e| e.into()) - .and_then(move |r| ResponseFuture(r.respond_to(&req).into_future())), - ) - } -} - impl Responder for InternalError where T: std::fmt::Debug + std::fmt::Display + 'static, { type Error = Error; - type Future = Result; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { let err: Error = self.into(); - Ok(err.into()) + ok(err.into()) } } -pub struct ResponseFuture(T); +#[pin_project] +pub struct ResponseFuture { + #[pin] + fut: T, + _t: PhantomData, +} -impl ResponseFuture { +impl ResponseFuture { pub fn new(fut: T) -> Self { - ResponseFuture(fut) + ResponseFuture { + fut, + _t: PhantomData, + } } } -impl Future for ResponseFuture +impl Future for ResponseFuture where - T: Future, - T::Error: Into, + T: Future>, + E: Into, { - type Item = Response; - type Error = Error; + type Output = Result; - fn poll(&mut self) -> Poll { - Ok(self.0.poll().map_err(|e| e.into())?) + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Poll::Ready(ready!(self.project().fut.poll(cx)).map_err(|e| e.into())) } } @@ -471,23 +459,26 @@ pub(crate) mod tests { use super::*; use crate::dev::{Body, ResponseBody}; use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; - use crate::test::{block_on, init_service, TestRequest}; + use crate::test::{init_service, TestRequest}; use crate::{error, web, App, HttpResponse}; - #[test] - fn test_option_responder() { + #[actix_rt::test] + async fn test_option_responder() { let mut srv = init_service( App::new() - .service(web::resource("/none").to(|| -> Option<&'static str> { None })) - .service(web::resource("/some").to(|| Some("some"))), - ); + .service( + web::resource("/none").to(|| async { Option::<&'static str>::None }), + ) + .service(web::resource("/some").to(|| async { Some("some") })), + ) + .await; let req = TestRequest::with_uri("/none").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); let req = TestRequest::with_uri("/some").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); match resp.response().body() { ResponseBody::Body(Body::Bytes(ref b)) => { @@ -524,15 +515,11 @@ pub(crate) mod tests { } } - #[test] - fn test_responder() { + #[actix_rt::test] + async fn test_responder() { let req = TestRequest::default().to_http_request(); - let resp: HttpResponse = block_on(().respond_to(&req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(*resp.body().body(), Body::Empty); - - let resp: HttpResponse = block_on("test".respond_to(&req)).unwrap(); + let resp: HttpResponse = "test".respond_to(&req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -540,7 +527,7 @@ pub(crate) mod tests { HeaderValue::from_static("text/plain; charset=utf-8") ); - let resp: HttpResponse = block_on(b"test".respond_to(&req)).unwrap(); + let resp: HttpResponse = b"test".respond_to(&req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -548,7 +535,15 @@ pub(crate) mod tests { HeaderValue::from_static("application/octet-stream") ); - let resp: HttpResponse = block_on("test".to_string().respond_to(&req)).unwrap(); + let resp: HttpResponse = "test".to_string().respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + + let resp: HttpResponse = (&"test".to_string()).respond_to(&req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -557,16 +552,7 @@ pub(crate) mod tests { ); let resp: HttpResponse = - block_on((&"test".to_string()).respond_to(&req)).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") - ); - - let resp: HttpResponse = - block_on(Bytes::from_static(b"test").respond_to(&req)).unwrap(); + Bytes::from_static(b"test").respond_to(&req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -574,8 +560,10 @@ pub(crate) mod tests { HeaderValue::from_static("application/octet-stream") ); - let resp: HttpResponse = - block_on(BytesMut::from(b"test".as_ref()).respond_to(&req)).unwrap(); + let resp: HttpResponse = BytesMut::from(b"test".as_ref()) + .respond_to(&req) + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -587,17 +575,20 @@ pub(crate) mod tests { let resp: HttpResponse = error::InternalError::new("err", StatusCode::BAD_REQUEST) .respond_to(&req) + .await .unwrap(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } - #[test] - fn test_result_responder() { + #[actix_rt::test] + async fn test_result_responder() { let req = TestRequest::default().to_http_request(); // Result - let resp: HttpResponse = - block_on(Ok::<_, Error>("test".to_string()).respond_to(&req)).unwrap(); + let resp: HttpResponse = Ok::<_, Error>("test".to_string()) + .respond_to(&req) + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -605,33 +596,31 @@ pub(crate) mod tests { HeaderValue::from_static("text/plain; charset=utf-8") ); - let res = block_on( + let res = Err::(error::InternalError::new("err", StatusCode::BAD_REQUEST)) - .respond_to(&req), - ); + .respond_to(&req) + .await; assert!(res.is_err()); } - #[test] - fn test_custom_responder() { + #[actix_rt::test] + async fn test_custom_responder() { let req = TestRequest::default().to_http_request(); - let res = block_on( - "test" - .to_string() - .with_status(StatusCode::BAD_REQUEST) - .respond_to(&req), - ) - .unwrap(); + let res = "test" + .to_string() + .with_status(StatusCode::BAD_REQUEST) + .respond_to(&req) + .await + .unwrap(); assert_eq!(res.status(), StatusCode::BAD_REQUEST); assert_eq!(res.body().bin_ref(), b"test"); - let res = block_on( - "test" - .to_string() - .with_header("content-type", "json") - .respond_to(&req), - ) - .unwrap(); + let res = "test" + .to_string() + .with_header("content-type", "json") + .respond_to(&req) + .await + .unwrap(); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.body().bin_ref(), b"test"); @@ -641,22 +630,22 @@ pub(crate) mod tests { ); } - #[test] - fn test_tuple_responder_with_status_code() { + #[actix_rt::test] + async fn test_tuple_responder_with_status_code() { let req = TestRequest::default().to_http_request(); - let res = - block_on(("test".to_string(), StatusCode::BAD_REQUEST).respond_to(&req)) - .unwrap(); + let res = ("test".to_string(), StatusCode::BAD_REQUEST) + .respond_to(&req) + .await + .unwrap(); assert_eq!(res.status(), StatusCode::BAD_REQUEST); assert_eq!(res.body().bin_ref(), b"test"); let req = TestRequest::default().to_http_request(); - let res = block_on( - ("test".to_string(), StatusCode::OK) - .with_header("content-type", "json") - .respond_to(&req), - ) - .unwrap(); + let res = ("test".to_string(), StatusCode::OK) + .with_header("content-type", "json") + .respond_to(&req) + .await + .unwrap(); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.body().bin_ref(), b"test"); assert_eq!( diff --git a/src/rmap.rs b/src/rmap.rs index 42ddb1349..47092608c 100644 --- a/src/rmap.rs +++ b/src/rmap.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use std::rc::Rc; use actix_router::ResourceDef; -use hashbrown::HashMap; +use fxhash::FxHashMap; use url::Url; use crate::error::UrlGenerationError; @@ -12,7 +12,7 @@ use crate::request::HttpRequest; pub struct ResourceMap { root: ResourceDef, parent: RefCell>>, - named: HashMap, + named: FxHashMap, patterns: Vec<(ResourceDef, Option>)>, } @@ -21,7 +21,7 @@ impl ResourceMap { ResourceMap { root, parent: RefCell::new(None), - named: HashMap::new(), + named: FxHashMap::default(), patterns: Vec::new(), } } diff --git a/src/route.rs b/src/route.rs index f4d303632..f7e391746 100644 --- a/src/route.rs +++ b/src/route.rs @@ -1,13 +1,15 @@ +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_http::{http::Method, Error}; -use actix_service::{NewService, Service}; -use futures::future::{ok, Either, FutureResult}; -use futures::{Async, Future, IntoFuture, Poll}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{ready, FutureExt, LocalBoxFuture}; use crate::extract::FromRequest; use crate::guard::{self, Guard}; -use crate::handler::{AsyncFactory, AsyncHandler, Extract, Factory, Handler}; +use crate::handler::{Extract, Factory, Handler}; use crate::responder::Responder; use crate::service::{ServiceRequest, ServiceResponse}; use crate::HttpResponse; @@ -17,22 +19,19 @@ type BoxedRouteService = Box< Request = Req, Response = Res, Error = Error, - Future = Either< - FutureResult, - Box>, - >, + Future = LocalBoxFuture<'static, Result>, >, >; type BoxedRouteNewService = Box< - dyn NewService< + dyn ServiceFactory< Config = (), Request = Req, Response = Res, Error = Error, InitError = (), Service = BoxedRouteService, - Future = Box, Error = ()>>, + Future = LocalBoxFuture<'static, Result, ()>>, >, >; @@ -50,7 +49,7 @@ impl Route { pub fn new() -> Route { Route { service: Box::new(RouteNewService::new(Extract::new(Handler::new(|| { - HttpResponse::NotFound() + ready(HttpResponse::NotFound()) })))), guards: Rc::new(Vec::new()), } @@ -61,7 +60,7 @@ impl Route { } } -impl NewService for Route { +impl ServiceFactory for Route { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -70,34 +69,38 @@ impl NewService for Route { type Service = RouteService; type Future = CreateRouteService; - fn new_service(&self, _: &()) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { CreateRouteService { - fut: self.service.new_service(&()), + fut: self.service.new_service(()), guards: self.guards.clone(), } } } -type RouteFuture = Box< - dyn Future, Error = ()>, +type RouteFuture = LocalBoxFuture< + 'static, + Result, ()>, >; +#[pin_project::pin_project] pub struct CreateRouteService { + #[pin] fut: RouteFuture, guards: Rc>>, } impl Future for CreateRouteService { - type Item = RouteService; - type Error = (); + type Output = Result; - fn poll(&mut self) -> Poll { - match self.fut.poll()? { - Async::Ready(service) => Ok(Async::Ready(RouteService { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + match this.fut.poll(cx)? { + Poll::Ready(service) => Poll::Ready(Ok(RouteService { service, - guards: self.guards.clone(), + guards: this.guards.clone(), })), - Async::NotReady => Ok(Async::NotReady), + Poll::Pending => Poll::Pending, } } } @@ -122,17 +125,14 @@ impl Service for RouteService { type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Either< - FutureResult, - Box>, - >; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { - self.service.call(req) + self.service.call(req).boxed_local() } } @@ -178,8 +178,8 @@ impl Route { /// Set handler function, use request extractors for parameters. /// /// ```rust - /// #[macro_use] extern crate serde_derive; /// use actix_web::{web, http, App}; + /// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -187,7 +187,7 @@ impl Route { /// } /// /// /// extract path info using serde - /// fn index(info: web::Path) -> String { + /// async fn index(info: web::Path) -> String { /// format!("Welcome {}!", info.username) /// } /// @@ -212,7 +212,7 @@ impl Route { /// } /// /// /// extract path info using serde - /// fn index(path: web::Path, query: web::Query>, body: web::Json) -> String { + /// async fn index(path: web::Path, query: web::Query>, body: web::Json) -> String { /// format!("Welcome {}!", path.username) /// } /// @@ -223,69 +223,29 @@ impl Route { /// ); /// } /// ``` - pub fn to(mut self, handler: F) -> Route + pub fn to(mut self, handler: F) -> Self where - F: Factory + 'static, + F: Factory, T: FromRequest + 'static, - R: Responder + 'static, + R: Future + 'static, + U: Responder + 'static, { self.service = Box::new(RouteNewService::new(Extract::new(Handler::new(handler)))); self } - - /// Set async handler function, use request extractors for parameters. - /// This method has to be used if your handler function returns `impl Future<>` - /// - /// ```rust - /// # use futures::future::ok; - /// #[macro_use] extern crate serde_derive; - /// use actix_web::{web, App, Error}; - /// use futures::Future; - /// - /// #[derive(Deserialize)] - /// struct Info { - /// username: String, - /// } - /// - /// /// extract path info using serde - /// fn index(info: web::Path) -> impl Future { - /// ok("Hello World!") - /// } - /// - /// fn main() { - /// let app = App::new().service( - /// web::resource("/{username}/index.html") // <- define path parameters - /// .route(web::get().to_async(index)) // <- register async handler - /// ); - /// } - /// ``` - #[allow(clippy::wrong_self_convention)] - pub fn to_async(mut self, handler: F) -> Self - where - F: AsyncFactory, - T: FromRequest + 'static, - R: IntoFuture + 'static, - R::Item: Responder, - R::Error: Into, - { - self.service = Box::new(RouteNewService::new(Extract::new(AsyncHandler::new( - handler, - )))); - self - } } struct RouteNewService where - T: NewService, + T: ServiceFactory, { service: T, } impl RouteNewService where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -300,9 +260,9 @@ where } } -impl NewService for RouteNewService +impl ServiceFactory for RouteNewService where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -318,19 +278,20 @@ where type Error = Error; type InitError = (); type Service = BoxedRouteService; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; - fn new_service(&self, _: &()) -> Self::Future { - Box::new( - self.service - .new_service(&()) - .map_err(|_| ()) - .and_then(|service| { + fn new_service(&self, _: ()) -> Self::Future { + self.service + .new_service(()) + .map(|result| match result { + Ok(service) => { let service: BoxedRouteService<_, _> = Box::new(RouteServiceWrapper { service }); Ok(service) - }), - ) + } + Err(_) => Err(()), + }) + .boxed_local() } } @@ -350,25 +311,30 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Either< - FutureResult, - Box>, - >; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready().map_err(|(e, _)| e) + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx).map_err(|(e, _)| e) } fn call(&mut self, req: ServiceRequest) -> Self::Future { - let mut fut = self.service.call(req); - match fut.poll() { - Ok(Async::Ready(res)) => Either::A(ok(res)), - Err((e, req)) => Either::A(ok(req.error_response(e))), - Ok(Async::NotReady) => Either::B(Box::new(fut.then(|res| match res { + // let mut fut = self.service.call(req); + self.service + .call(req) + .map(|res| match res { Ok(res) => Ok(res), Err((err, req)) => Ok(req.error_response(err)), - }))), - } + }) + .boxed_local() + + // match fut.poll() { + // Poll::Ready(Ok(res)) => Either::Left(ok(res)), + // Poll::Ready(Err((e, req))) => Either::Left(ok(req.error_response(e))), + // Poll::Pending => Either::Right(Box::new(fut.then(|res| match res { + // Ok(res) => Ok(res), + // Err((err, req)) => Ok(req.error_response(err)), + // }))), + // } } } @@ -376,10 +342,9 @@ where mod tests { use std::time::Duration; + use actix_rt::time::delay_for; use bytes::Bytes; - use futures::Future; use serde_derive::Serialize; - use tokio_timer::sleep; use crate::http::{Method, StatusCode}; use crate::test::{call_service, init_service, read_body, TestRequest}; @@ -390,70 +355,77 @@ mod tests { name: String, } - #[test] - fn test_route() { + #[actix_rt::test] + async fn test_route() { let mut srv = init_service( App::new() .service( web::resource("/test") .route(web::get().to(|| HttpResponse::Ok())) .route(web::put().to(|| { - Err::(error::ErrorBadRequest("err")) - })) - .route(web::post().to_async(|| { - sleep(Duration::from_millis(100)) - .then(|_| HttpResponse::Created()) - })) - .route(web::delete().to_async(|| { - sleep(Duration::from_millis(100)).then(|_| { + async { Err::(error::ErrorBadRequest("err")) - }) + } + })) + .route(web::post().to(|| { + async { + delay_for(Duration::from_millis(100)).await; + HttpResponse::Created() + } + })) + .route(web::delete().to(|| { + async { + delay_for(Duration::from_millis(100)).await; + Err::(error::ErrorBadRequest("err")) + } })), ) - .service(web::resource("/json").route(web::get().to_async(|| { - sleep(Duration::from_millis(25)).then(|_| { - Ok::<_, crate::Error>(web::Json(MyObject { + .service(web::resource("/json").route(web::get().to(|| { + async { + delay_for(Duration::from_millis(25)).await; + web::Json(MyObject { name: "test".to_string(), - })) - }) + }) + } }))), - ); + ) + .await; let req = TestRequest::with_uri("/test") .method(Method::GET) .to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test") .method(Method::POST) .to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::CREATED); let req = TestRequest::with_uri("/test") .method(Method::PUT) .to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let req = TestRequest::with_uri("/test") .method(Method::DELETE) .to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let req = TestRequest::with_uri("/test") .method(Method::HEAD) .to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); let req = TestRequest::with_uri("/json").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); - let body = read_body(resp); + let body = read_body(resp).await; assert_eq!(body, Bytes::from_static(b"{\"name\":\"test\"}")); } } diff --git a/src/scope.rs b/src/scope.rs index 760fee478..18e775e61 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -1,15 +1,16 @@ use std::cell::RefCell; use std::fmt; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use actix_http::{Extensions, Response}; use actix_router::{ResourceDef, ResourceInfo, Router}; -use actix_service::boxed::{self, BoxedNewService, BoxedService}; +use actix_service::boxed::{self, BoxService, BoxServiceFactory}; use actix_service::{ - apply_transform, IntoNewService, IntoTransform, NewService, Service, Transform, + apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, Transform, }; -use futures::future::{ok, Either, Future, FutureResult}; -use futures::{Async, IntoFuture, Poll}; +use futures::future::{ok, Either, Future, LocalBoxFuture, Ready}; use crate::config::ServiceConfig; use crate::data::Data; @@ -20,16 +21,13 @@ use crate::resource::Resource; use crate::rmap::ResourceMap; use crate::route::Route; use crate::service::{ - ServiceFactory, ServiceFactoryWrapper, ServiceRequest, ServiceResponse, + AppServiceFactory, ServiceFactoryWrapper, ServiceRequest, ServiceResponse, }; type Guards = Vec>; -type HttpService = BoxedService; -type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; -type BoxedResponse = Either< - FutureResult, - Box>, ->; +type HttpService = BoxService; +type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; +type BoxedResponse = LocalBoxFuture<'static, Result>; /// Resources scope. /// @@ -48,7 +46,7 @@ type BoxedResponse = Either< /// fn main() { /// let app = App::new().service( /// web::scope("/{project_id}/") -/// .service(web::resource("/path1").to(|| HttpResponse::Ok())) +/// .service(web::resource("/path1").to(|| async { HttpResponse::Ok() })) /// .service(web::resource("/path2").route(web::get().to(|| HttpResponse::Ok()))) /// .service(web::resource("/path3").route(web::head().to(|| HttpResponse::MethodNotAllowed()))) /// ); @@ -64,7 +62,7 @@ pub struct Scope { endpoint: T, rdef: String, data: Option, - services: Vec>, + services: Vec>, guards: Vec>, default: Rc>>>, external: Vec, @@ -90,7 +88,7 @@ impl Scope { impl Scope where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -103,7 +101,7 @@ where /// ```rust /// use actix_web::{web, guard, App, HttpRequest, HttpResponse}; /// - /// fn index(data: web::Path<(String, String)>) -> &'static str { + /// async fn index(data: web::Path<(String, String)>) -> &'static str { /// "Welcome!" /// } /// @@ -128,14 +126,15 @@ where /// /// ```rust /// use std::cell::Cell; - /// use actix_web::{web, App}; + /// use actix_web::{web, App, HttpResponse, Responder}; /// /// struct MyData { /// counter: Cell, /// } /// - /// fn index(data: web::Data) { + /// async fn index(data: web::Data) -> impl Responder { /// data.counter.set(data.counter.get() + 1); + /// HttpResponse::Ok() /// } /// /// fn main() { @@ -148,11 +147,18 @@ where /// ); /// } /// ``` - pub fn data(mut self, data: U) -> Self { + pub fn data(self, data: U) -> Self { + self.app_data(Data::new(data)) + } + + /// Set or override application data. + /// + /// This method overrides data stored with [`App::app_data()`](#method.app_data) + pub fn app_data(mut self, data: U) -> Self { if self.data.is_none() { self.data = Some(Extensions::new()); } - self.data.as_mut().unwrap().insert(Data::new(data)); + self.data.as_mut().unwrap().insert(data); self } @@ -221,7 +227,7 @@ where /// /// struct AppState; /// - /// fn index(req: HttpRequest) -> &'static str { + /// async fn index(req: HttpRequest) -> &'static str { /// "Welcome!" /// } /// @@ -251,7 +257,7 @@ where /// ```rust /// use actix_web::{web, App, HttpResponse}; /// - /// fn index(data: web::Path<(String, String)>) -> &'static str { + /// async fn index(data: web::Path<(String, String)>) -> &'static str { /// "Welcome!" /// } /// @@ -276,8 +282,8 @@ where /// If default resource is not registered, app's default resource is being used. pub fn default_service(mut self, f: F) -> Self where - F: IntoNewService, - U: NewService< + F: IntoServiceFactory, + U: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -286,8 +292,8 @@ where U::InitError: fmt::Debug, { // create and configure default resource - self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::new_service( - f.into_new_service().map_init_err(|e| { + self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::factory( + f.into_factory().map_init_err(|e| { log::error!("Can not construct default service: {:?}", e) }), ))))); @@ -304,11 +310,11 @@ where /// ServiceResponse. /// /// Use middleware when you need to read or modify *every* request in some way. - pub fn wrap( + pub fn wrap( self, - mw: F, + mw: M, ) -> Scope< - impl NewService< + impl ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -324,11 +330,9 @@ where Error = Error, InitError = (), >, - F: IntoTransform, { - let endpoint = apply_transform(mw, self.endpoint); Scope { - endpoint, + endpoint: apply(mw, self.endpoint), rdef: self.rdef, data: self.data, guards: self.guards, @@ -348,24 +352,26 @@ where /// /// ```rust /// use actix_service::Service; - /// # use futures::Future; /// use actix_web::{web, App}; /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; /// - /// fn index() -> &'static str { + /// async fn index() -> &'static str { /// "Welcome!" /// } /// /// fn main() { /// let app = App::new().service( /// web::scope("/app") - /// .wrap_fn(|req, srv| - /// srv.call(req).map(|mut res| { + /// .wrap_fn(|req, srv| { + /// let fut = srv.call(req); + /// async { + /// let mut res = fut.await?; /// res.headers_mut().insert( /// CONTENT_TYPE, HeaderValue::from_static("text/plain"), /// ); - /// res - /// })) + /// Ok(res) + /// } + /// }) /// .route("/index.html", web::get().to(index))); /// } /// ``` @@ -373,7 +379,7 @@ where self, mw: F, ) -> Scope< - impl NewService< + impl ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -383,15 +389,24 @@ where > where F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, - R: IntoFuture, + R: Future>, { - self.wrap(mw) + Scope { + endpoint: apply_fn_factory(self.endpoint, mw), + rdef: self.rdef, + data: self.data, + guards: self.guards, + services: self.services, + default: self.default, + external: self.external, + factory_ref: self.factory_ref, + } } } impl HttpServiceFactory for Scope where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -462,7 +477,7 @@ pub struct ScopeFactory { default: Rc>>>, } -impl NewService for ScopeFactory { +impl ServiceFactory for ScopeFactory { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -471,9 +486,9 @@ impl NewService for ScopeFactory { type Service = ScopeService; type Future = ScopeFactoryResponse; - fn new_service(&self, _: &()) -> Self::Future { + fn new_service(&self, _: ()) -> Self::Future { let default_fut = if let Some(ref default) = *self.default.borrow() { - Some(default.new_service(&())) + Some(default.new_service(())) } else { None }; @@ -486,7 +501,7 @@ impl NewService for ScopeFactory { CreateScopeServiceItem::Future( Some(path.clone()), guards.borrow_mut().take(), - service.new_service(&()), + service.new_service(()), ) }) .collect(), @@ -499,14 +514,15 @@ impl NewService for ScopeFactory { /// Create scope service #[doc(hidden)] +#[pin_project::pin_project] pub struct ScopeFactoryResponse { fut: Vec, data: Option>, default: Option, - default_fut: Option>>, + default_fut: Option>>, } -type HttpServiceFut = Box>; +type HttpServiceFut = LocalBoxFuture<'static, Result>; enum CreateScopeServiceItem { Future(Option, Option, HttpServiceFut), @@ -514,16 +530,15 @@ enum CreateScopeServiceItem { } impl Future for ScopeFactoryResponse { - type Item = ScopeService; - type Error = (); + type Output = Result; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut done = true; if let Some(ref mut fut) = self.default_fut { - match fut.poll()? { - Async::Ready(default) => self.default = Some(default), - Async::NotReady => done = false, + match Pin::new(fut).poll(cx)? { + Poll::Ready(default) => self.default = Some(default), + Poll::Pending => done = false, } } @@ -534,11 +549,11 @@ impl Future for ScopeFactoryResponse { ref mut path, ref mut guards, ref mut fut, - ) => match fut.poll()? { - Async::Ready(service) => { + ) => match Pin::new(fut).poll(cx)? { + Poll::Ready(service) => { Some((path.take().unwrap(), guards.take(), service)) } - Async::NotReady => { + Poll::Pending => { done = false; None } @@ -564,14 +579,14 @@ impl Future for ScopeFactoryResponse { } router }); - Ok(Async::Ready(ScopeService { + Poll::Ready(Ok(ScopeService { data: self.data.clone(), router: router.finish(), default: self.default.take(), _ready: None, })) } else { - Ok(Async::NotReady) + Poll::Pending } } } @@ -587,10 +602,10 @@ impl Service for ScopeService { type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = Either>; + type Future = Either>>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - Ok(Async::Ready(())) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn call(&mut self, mut req: ServiceRequest) -> Self::Future { @@ -609,12 +624,12 @@ impl Service for ScopeService { if let Some(ref data) = self.data { req.set_data_container(data.clone()); } - Either::A(srv.call(req)) + Either::Left(srv.call(req)) } else if let Some(ref mut default) = self.default { - Either::A(default.call(req)) + Either::Left(default.call(req)) } else { let req = req.into_parts().0; - Either::B(ok(ServiceResponse::new(req, Response::NotFound().finish()))) + Either::Right(ok(ServiceResponse::new(req, Response::NotFound().finish()))) } } } @@ -630,7 +645,7 @@ impl ScopeEndpoint { } } -impl NewService for ScopeEndpoint { +impl ServiceFactory for ScopeEndpoint { type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; @@ -639,8 +654,8 @@ impl NewService for ScopeEndpoint { type Service = ScopeService; type Future = ScopeFactoryResponse; - fn new_service(&self, _: &()) -> Self::Future { - self.factory.borrow_mut().as_mut().unwrap().new_service(&()) + fn new_service(&self, _: ()) -> Self::Future { + self.factory.borrow_mut().as_mut().unwrap().new_service(()) } } @@ -648,106 +663,112 @@ impl NewService for ScopeEndpoint { mod tests { use actix_service::Service; use bytes::Bytes; - use futures::{Future, IntoFuture}; + use futures::future::ok; use crate::dev::{Body, ResponseBody}; use crate::http::{header, HeaderValue, Method, StatusCode}; - use crate::service::{ServiceRequest, ServiceResponse}; - use crate::test::{block_on, call_service, init_service, read_body, TestRequest}; - use crate::{guard, web, App, Error, HttpRequest, HttpResponse}; + use crate::middleware::DefaultHeaders; + use crate::service::ServiceRequest; + use crate::test::{call_service, init_service, read_body, TestRequest}; + use crate::{guard, web, App, HttpRequest, HttpResponse}; - #[test] - fn test_scope() { + #[actix_rt::test] + async fn test_scope() { let mut srv = init_service( App::new().service( web::scope("/app") .service(web::resource("/path1").to(|| HttpResponse::Ok())), ), - ); + ) + .await; let req = TestRequest::with_uri("/app/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn test_scope_root() { + #[actix_rt::test] + async fn test_scope_root() { let mut srv = init_service( App::new().service( web::scope("/app") .service(web::resource("").to(|| HttpResponse::Ok())) .service(web::resource("/").to(|| HttpResponse::Created())), ), - ); + ) + .await; let req = TestRequest::with_uri("/app").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/app/").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::CREATED); } - #[test] - fn test_scope_root2() { + #[actix_rt::test] + async fn test_scope_root2() { let mut srv = init_service(App::new().service( web::scope("/app/").service(web::resource("").to(|| HttpResponse::Ok())), - )); + )) + .await; let req = TestRequest::with_uri("/app").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); let req = TestRequest::with_uri("/app/").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn test_scope_root3() { + #[actix_rt::test] + async fn test_scope_root3() { let mut srv = init_service(App::new().service( web::scope("/app/").service(web::resource("/").to(|| HttpResponse::Ok())), - )); + )) + .await; let req = TestRequest::with_uri("/app").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); let req = TestRequest::with_uri("/app/").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); } - #[test] - fn test_scope_route() { + #[actix_rt::test] + async fn test_scope_route() { let mut srv = init_service( App::new().service( web::scope("app") .route("/path1", web::get().to(|| HttpResponse::Ok())) .route("/path1", web::delete().to(|| HttpResponse::Ok())), ), - ); + ) + .await; let req = TestRequest::with_uri("/app/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/app/path1") .method(Method::DELETE) .to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/app/path1") .method(Method::POST) .to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); } - #[test] - fn test_scope_route_without_leading_slash() { + #[actix_rt::test] + async fn test_scope_route_without_leading_slash() { let mut srv = init_service( App::new().service( web::scope("app").service( @@ -756,60 +777,65 @@ mod tests { .route(web::delete().to(|| HttpResponse::Ok())), ), ), - ); + ) + .await; let req = TestRequest::with_uri("/app/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/app/path1") .method(Method::DELETE) .to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/app/path1") .method(Method::POST) .to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); } - #[test] - fn test_scope_guard() { + #[actix_rt::test] + async fn test_scope_guard() { let mut srv = init_service( App::new().service( web::scope("/app") .guard(guard::Get()) .service(web::resource("/path1").to(|| HttpResponse::Ok())), ), - ); + ) + .await; let req = TestRequest::with_uri("/app/path1") .method(Method::POST) .to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); let req = TestRequest::with_uri("/app/path1") .method(Method::GET) .to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn test_scope_variable_segment() { + #[actix_rt::test] + async fn test_scope_variable_segment() { let mut srv = init_service(App::new().service(web::scope("/ab-{project}").service( web::resource("/path1").to(|r: HttpRequest| { - HttpResponse::Ok() - .body(format!("project: {}", &r.match_info()["project"])) + async move { + HttpResponse::Ok() + .body(format!("project: {}", &r.match_info()["project"])) + } }), - ))); + ))) + .await; let req = TestRequest::with_uri("/ab-project1/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); match resp.response().body() { @@ -821,12 +847,12 @@ mod tests { } let req = TestRequest::with_uri("/aa-project1/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); } - #[test] - fn test_nested_scope() { + #[actix_rt::test] + async fn test_nested_scope() { let mut srv = init_service( App::new().service( web::scope("/app") @@ -834,15 +860,16 @@ mod tests { web::resource("/path1").to(|| HttpResponse::Created()), )), ), - ); + ) + .await; let req = TestRequest::with_uri("/app/t1/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::CREATED); } - #[test] - fn test_nested_scope_no_slash() { + #[actix_rt::test] + async fn test_nested_scope_no_slash() { let mut srv = init_service( App::new().service( web::scope("/app") @@ -850,15 +877,16 @@ mod tests { web::resource("/path1").to(|| HttpResponse::Created()), )), ), - ); + ) + .await; let req = TestRequest::with_uri("/app/t1/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::CREATED); } - #[test] - fn test_nested_scope_root() { + #[actix_rt::test] + async fn test_nested_scope_root() { let mut srv = init_service( App::new().service( web::scope("/app").service( @@ -867,19 +895,20 @@ mod tests { .service(web::resource("/").to(|| HttpResponse::Created())), ), ), - ); + ) + .await; let req = TestRequest::with_uri("/app/t1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/app/t1/").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::CREATED); } - #[test] - fn test_nested_scope_filter() { + #[actix_rt::test] + async fn test_nested_scope_filter() { let mut srv = init_service( App::new().service( web::scope("/app").service( @@ -888,34 +917,38 @@ mod tests { .service(web::resource("/path1").to(|| HttpResponse::Ok())), ), ), - ); + ) + .await; let req = TestRequest::with_uri("/app/t1/path1") .method(Method::POST) .to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); let req = TestRequest::with_uri("/app/t1/path1") .method(Method::GET) .to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn test_nested_scope_with_variable_segment() { + #[actix_rt::test] + async fn test_nested_scope_with_variable_segment() { let mut srv = init_service(App::new().service(web::scope("/app").service( web::scope("/{project_id}").service(web::resource("/path1").to( |r: HttpRequest| { - HttpResponse::Created() - .body(format!("project: {}", &r.match_info()["project_id"])) + async move { + HttpResponse::Created() + .body(format!("project: {}", &r.match_info()["project_id"])) + } }, )), - ))); + ))) + .await; let req = TestRequest::with_uri("/app/project_1/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::CREATED); match resp.response().body() { @@ -927,22 +960,25 @@ mod tests { } } - #[test] - fn test_nested2_scope_with_variable_segment() { + #[actix_rt::test] + async fn test_nested2_scope_with_variable_segment() { let mut srv = init_service(App::new().service(web::scope("/app").service( web::scope("/{project}").service(web::scope("/{id}").service( web::resource("/path1").to(|r: HttpRequest| { - HttpResponse::Created().body(format!( - "project: {} - {}", - &r.match_info()["project"], - &r.match_info()["id"], - )) + async move { + HttpResponse::Created().body(format!( + "project: {} - {}", + &r.match_info()["project"], + &r.match_info()["id"], + )) + } }), )), - ))); + ))) + .await; let req = TestRequest::with_uri("/app/test/1/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::CREATED); match resp.response().body() { @@ -954,33 +990,34 @@ mod tests { } let req = TestRequest::with_uri("/app/test/1/path2").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); } - #[test] - fn test_default_resource() { + #[actix_rt::test] + async fn test_default_resource() { let mut srv = init_service( App::new().service( web::scope("/app") .service(web::resource("/path1").to(|| HttpResponse::Ok())) .default_service(|r: ServiceRequest| { - r.into_response(HttpResponse::BadRequest()) + ok(r.into_response(HttpResponse::BadRequest())) }), ), - ); + ) + .await; let req = TestRequest::with_uri("/app/path2").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let req = TestRequest::with_uri("/path2").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); } - #[test] - fn test_default_resource_propagation() { + #[actix_rt::test] + async fn test_default_resource_propagation() { let mut srv = init_service( App::new() .service(web::scope("/app1").default_service( @@ -988,49 +1025,44 @@ mod tests { )) .service(web::scope("/app2")) .default_service(|r: ServiceRequest| { - r.into_response(HttpResponse::MethodNotAllowed()) + ok(r.into_response(HttpResponse::MethodNotAllowed())) }), - ); + ) + .await; let req = TestRequest::with_uri("/non-exist").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); let req = TestRequest::with_uri("/app1/non-exist").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let req = TestRequest::with_uri("/app2/non-exist").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); } - fn md( - req: ServiceRequest, - srv: &mut S, - ) -> impl IntoFuture, Error = Error> - where - S: Service< - Request = ServiceRequest, - Response = ServiceResponse, - Error = Error, - >, - { - srv.call(req).map(|mut res| { - res.headers_mut() - .insert(header::CONTENT_TYPE, HeaderValue::from_static("0001")); - res - }) - } - - #[test] - fn test_middleware() { + #[actix_rt::test] + async fn test_middleware() { let mut srv = - init_service(App::new().service(web::scope("app").wrap(md).service( - web::resource("/test").route(web::get().to(|| HttpResponse::Ok())), - ))); + init_service( + App::new().service( + web::scope("app") + .wrap(DefaultHeaders::new().header( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + )) + .service( + web::resource("/test") + .route(web::get().to(|| HttpResponse::Ok())), + ), + ), + ) + .await; + let req = TestRequest::with_uri("/app/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -1038,25 +1070,29 @@ mod tests { ); } - #[test] - fn test_middleware_fn() { + #[actix_rt::test] + async fn test_middleware_fn() { let mut srv = init_service( App::new().service( web::scope("app") .wrap_fn(|req, srv| { - srv.call(req).map(|mut res| { + let fut = srv.call(req); + async move { + let mut res = fut.await?; res.headers_mut().insert( header::CONTENT_TYPE, HeaderValue::from_static("0001"), ); - res - }) + Ok(res) + } }) .route("/test", web::get().to(|| HttpResponse::Ok())), ), - ); + ) + .await; + let req = TestRequest::with_uri("/app/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -1064,52 +1100,74 @@ mod tests { ); } - #[test] - fn test_override_data() { + #[actix_rt::test] + async fn test_override_data() { let mut srv = init_service(App::new().data(1usize).service( web::scope("app").data(10usize).route( "/t", web::get().to(|data: web::Data| { - assert_eq!(*data, 10); + assert_eq!(**data, 10); let _ = data.clone(); HttpResponse::Ok() }), ), - )); + )) + .await; let req = TestRequest::with_uri("/app/t").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn test_scope_config() { + #[actix_rt::test] + async fn test_override_app_data() { + let mut srv = init_service(App::new().app_data(web::Data::new(1usize)).service( + web::scope("app").app_data(web::Data::new(10usize)).route( + "/t", + web::get().to(|data: web::Data| { + assert_eq!(**data, 10); + let _ = data.clone(); + HttpResponse::Ok() + }), + ), + )) + .await; + + let req = TestRequest::with_uri("/app/t").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_scope_config() { let mut srv = init_service(App::new().service(web::scope("/app").configure(|s| { s.route("/path1", web::get().to(|| HttpResponse::Ok())); - }))); + }))) + .await; let req = TestRequest::with_uri("/app/path1").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn test_scope_config_2() { + #[actix_rt::test] + async fn test_scope_config_2() { let mut srv = init_service(App::new().service(web::scope("/app").configure(|s| { s.service(web::scope("/v1").configure(|s| { s.route("/", web::get().to(|| HttpResponse::Ok())); })); - }))); + }))) + .await; let req = TestRequest::with_uri("/app/v1/").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } - #[test] - fn test_url_for_external() { + #[actix_rt::test] + async fn test_url_for_external() { let mut srv = init_service(App::new().service(web::scope("/app").configure(|s| { s.service(web::scope("/v1").configure(|s| { @@ -1120,36 +1178,45 @@ mod tests { s.route( "/", web::get().to(|req: HttpRequest| { - HttpResponse::Ok().body(format!( - "{}", - req.url_for("youtube", &["xxxxxx"]).unwrap().as_str() - )) + async move { + HttpResponse::Ok().body(format!( + "{}", + req.url_for("youtube", &["xxxxxx"]) + .unwrap() + .as_str() + )) + } }), ); })); - }))); + }))) + .await; let req = TestRequest::with_uri("/app/v1/").to_request(); - let resp = block_on(srv.call(req)).unwrap(); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); - let body = read_body(resp); + let body = read_body(resp).await; assert_eq!(body, &b"https://youtube.com/watch/xxxxxx"[..]); } - #[test] - fn test_url_for_nested() { + #[actix_rt::test] + async fn test_url_for_nested() { let mut srv = init_service(App::new().service(web::scope("/a").service( web::scope("/b").service(web::resource("/c/{stuff}").name("c").route( web::get().to(|req: HttpRequest| { - HttpResponse::Ok() - .body(format!("{}", req.url_for("c", &["12345"]).unwrap())) + async move { + HttpResponse::Ok() + .body(format!("{}", req.url_for("c", &["12345"]).unwrap())) + } }), )), - ))); + ))) + .await; + let req = TestRequest::with_uri("/a/b/c/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = call_service(&mut srv, req).await; assert_eq!(resp.status(), StatusCode::OK); - let body = read_body(resp); + let body = read_body(resp).await; assert_eq!( body, Bytes::from_static(b"http://localhost:8080/a/b/c/12345") diff --git a/src/server.rs b/src/server.rs index d1a019a19..11cfbb6bc 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,20 +1,26 @@ use std::marker::PhantomData; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::{fmt, io, net}; use actix_http::{body::MessageBody, Error, HttpService, KeepAlive, Request, Response}; -use actix_rt::System; use actix_server::{Server, ServerBuilder}; -use actix_server_config::ServerConfig; -use actix_service::{IntoNewService, NewService}; -use parking_lot::Mutex; +use actix_service::{map_config, IntoServiceFactory, Service, ServiceFactory}; use net2::TcpBuilder; -#[cfg(feature = "ssl")] -use openssl::ssl::{SslAcceptor, SslAcceptorBuilder}; -#[cfg(feature = "rust-tls")] -use rustls::ServerConfig as RustlsServerConfig; +#[cfg(unix)] +use actix_http::Protocol; +#[cfg(unix)] +use actix_service::pipeline_factory; +#[cfg(unix)] +use futures::future::ok; + +#[cfg(feature = "openssl")] +use actix_tls::openssl::{AlpnError, SslAcceptor, SslAcceptorBuilder}; +#[cfg(feature = "rustls")] +use actix_tls::rustls::ServerConfig as RustlsServerConfig; + +use crate::config::AppConfig; struct Socket { scheme: &'static str, @@ -22,6 +28,7 @@ struct Socket { } struct Config { + host: Option, keep_alive: KeepAlive, client_timeout: u64, client_shutdown: u64, @@ -31,36 +38,30 @@ struct Config { /// /// Create new http server with application factory. /// -/// ```rust -/// use std::io; +/// ```rust,no_run /// use actix_web::{web, App, HttpResponse, HttpServer}; /// -/// fn main() -> io::Result<()> { -/// let sys = actix_rt::System::new("example"); // <- create Actix runtime -/// +/// #[actix_rt::main] +/// async fn main() -> std::io::Result<()> { /// HttpServer::new( /// || App::new() /// .service(web::resource("/").to(|| HttpResponse::Ok()))) /// .bind("127.0.0.1:59090")? -/// .start(); -/// -/// # actix_rt::System::current().stop(); -/// sys.run() +/// .run() +/// .await /// } /// ``` pub struct HttpServer where F: Fn() -> I + Send + Clone + 'static, - I: IntoNewService, - S: NewService, + I: IntoServiceFactory, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, - S::Service: 'static, B: MessageBody, { pub(super) factory: F, - pub(super) host: Option, config: Arc>, backlog: i32, sockets: Vec, @@ -71,20 +72,20 @@ where impl HttpServer where F: Fn() -> I + Send + Clone + 'static, - I: IntoNewService, - S: NewService, - S::Error: Into, + I: IntoServiceFactory, + S: ServiceFactory, + S::Error: Into + 'static, S::InitError: fmt::Debug, - S::Response: Into>, - S::Service: 'static, + S::Response: Into> + 'static, + ::Future: 'static, B: MessageBody + 'static, { /// Create new http server with application factory pub fn new(factory: F) -> Self { HttpServer { factory, - host: None, config: Arc::new(Mutex::new(Config { + host: None, keep_alive: KeepAlive::Timeout(5), client_timeout: 5000, client_shutdown: 5000, @@ -138,8 +139,8 @@ where /// can be used to limit the global SSL CPU usage. /// /// By default max connections is set to a 256. - pub fn maxconnrate(mut self, num: usize) -> Self { - self.builder = self.builder.maxconnrate(num); + pub fn maxconnrate(self, num: usize) -> Self { + actix_tls::max_concurrent_ssl_connect(num); self } @@ -147,7 +148,7 @@ where /// /// By default keep alive is set to a 5 seconds. pub fn keep_alive>(self, val: T) -> Self { - self.config.lock().keep_alive = val.into(); + self.config.lock().unwrap().keep_alive = val.into(); self } @@ -161,7 +162,7 @@ where /// /// By default client timeout is set to 5000 milliseconds. pub fn client_timeout(self, val: u64) -> Self { - self.config.lock().client_timeout = val; + self.config.lock().unwrap().client_timeout = val; self } @@ -174,17 +175,19 @@ where /// /// By default client timeout is set to 5000 milliseconds. pub fn client_shutdown(self, val: u64) -> Self { - self.config.lock().client_shutdown = val; + self.config.lock().unwrap().client_shutdown = val; self } /// Set server host name. /// - /// Host name is used by application router as a hostname for url - /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. - /// html#method.host) documentation for more information. - pub fn server_hostname>(mut self, val: T) -> Self { - self.host = Some(val.as_ref().to_owned()); + /// Host name is used by application router as a hostname for url generation. + /// Check [ConnectionInfo](./dev/struct.ConnectionInfo.html#method.host) + /// documentation for more information. + /// + /// By default host name is set to a "localhost" value. + pub fn server_hostname>(self, val: T) -> Self { + self.config.lock().unwrap().host = Some(val.as_ref().to_owned()); self } @@ -244,21 +247,29 @@ where format!("actix-web-service-{}", addr), lst, move || { - let c = cfg.lock(); + let c = cfg.lock().unwrap(); + let cfg = AppConfig::new( + false, + addr, + c.host.clone().unwrap_or_else(|| format!("{}", addr)), + ); + HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) - .finish(factory()) + .local_addr(addr) + .finish(map_config(factory(), move |_| cfg.clone())) + .tcp() }, )?; Ok(self) } - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] /// Use listener for accepting incoming tls connection requests /// /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn listen_ssl( + pub fn listen_openssl( self, lst: net::TcpListener, builder: SslAcceptorBuilder, @@ -266,15 +277,12 @@ where self.listen_ssl_inner(lst, openssl_acceptor(builder)?) } - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] fn listen_ssl_inner( mut self, lst: net::TcpListener, acceptor: SslAcceptor, ) -> io::Result { - use actix_server::ssl::{OpensslAcceptor, SslError}; - - let acceptor = OpensslAcceptor::new(acceptor); let factory = self.factory.clone(); let cfg = self.config.clone(); let addr = lst.local_addr().unwrap(); @@ -287,22 +295,24 @@ where format!("actix-web-service-{}", addr), lst, move || { - let c = cfg.lock(); - acceptor.clone().map_err(SslError::Ssl).and_then( - HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .client_disconnect(c.client_shutdown) - .finish(factory()) - .map_err(SslError::Service) - .map_init_err(|_| ()), - ) + let c = cfg.lock().unwrap(); + let cfg = AppConfig::new( + true, + addr, + c.host.clone().unwrap_or_else(|| format!("{}", addr)), + ); + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .client_disconnect(c.client_shutdown) + .finish(map_config(factory(), move |_| cfg.clone())) + .openssl(acceptor.clone()) }, )?; Ok(self) } - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] /// Use listener for accepting incoming tls connection requests /// /// This method sets alpn protocols to "h2" and "http/1.1" @@ -314,18 +324,12 @@ where self.listen_rustls_inner(lst, config) } - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] fn listen_rustls_inner( mut self, lst: net::TcpListener, - mut config: RustlsServerConfig, + config: RustlsServerConfig, ) -> io::Result { - use actix_server::ssl::{RustlsAcceptor, SslError}; - - let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()]; - config.set_protocols(&protos); - - let acceptor = RustlsAcceptor::new(config); let factory = self.factory.clone(); let cfg = self.config.clone(); let addr = lst.local_addr().unwrap(); @@ -338,16 +342,18 @@ where format!("actix-web-service-{}", addr), lst, move || { - let c = cfg.lock(); - acceptor.clone().map_err(SslError::Ssl).and_then( - HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .client_disconnect(c.client_shutdown) - .finish(factory()) - .map_err(SslError::Service) - .map_init_err(|_| ()), - ) + let c = cfg.lock().unwrap(); + let cfg = AppConfig::new( + true, + addr, + c.host.clone().unwrap_or_else(|| format!("{}", addr)), + ); + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .client_disconnect(c.client_shutdown) + .finish(map_config(factory(), move |_| cfg.clone())) + .rustls(config.clone()) }, )?; Ok(self) @@ -397,11 +403,11 @@ where } } - #[cfg(feature = "ssl")] + #[cfg(feature = "openssl")] /// Start listening for incoming tls connections. /// /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn bind_ssl( + pub fn bind_openssl( mut self, addr: A, builder: SslAcceptorBuilder, @@ -419,7 +425,7 @@ where Ok(self) } - #[cfg(feature = "rust-tls")] + #[cfg(feature = "rustls")] /// Start listening for incoming tls connections. /// /// This method sets alpn protocols to "h2" and "http/1.1" @@ -435,7 +441,47 @@ where Ok(self) } - #[cfg(feature = "uds")] + #[cfg(unix)] + /// Start listening for unix domain connections on existing listener. + /// + /// This method is available with `uds` feature. + pub fn listen_uds( + mut self, + lst: std::os::unix::net::UnixListener, + ) -> io::Result { + use actix_rt::net::UnixStream; + + let cfg = self.config.clone(); + let factory = self.factory.clone(); + let socket_addr = net::SocketAddr::new( + net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), + 8080, + ); + self.sockets.push(Socket { + scheme: "http", + addr: socket_addr, + }); + + let addr = format!("actix-web-service-{:?}", lst.local_addr()?); + + self.builder = self.builder.listen_uds(addr, lst, move || { + let c = cfg.lock().unwrap(); + let config = AppConfig::new( + false, + socket_addr, + c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), + ); + pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then( + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .finish(map_config(factory(), move |_| config.clone())), + ) + })?; + Ok(self) + } + + #[cfg(unix)] /// Start listening for incoming unix domain connections. /// /// This method is available with `uds` feature. @@ -443,25 +489,36 @@ where where A: AsRef, { + use actix_rt::net::UnixStream; + let cfg = self.config.clone(); let factory = self.factory.clone(); + let socket_addr = net::SocketAddr::new( + net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), + 8080, + ); self.sockets.push(Socket { scheme: "http", - addr: net::SocketAddr::new( - net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), - 8080, - ), + addr: socket_addr, }); self.builder = self.builder.bind_uds( format!("actix-web-service-{:?}", addr.as_ref()), addr, move || { - let c = cfg.lock(); - HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .finish(factory()) + let c = cfg.lock().unwrap(); + let config = AppConfig::new( + false, + socket_addr, + c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), + ); + pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))) + .and_then( + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .finish(map_config(factory(), move |_| config.clone())), + ) }, )?; Ok(self) @@ -471,8 +528,8 @@ where impl HttpServer where F: Fn() -> I + Send + Clone + 'static, - I: IntoNewService, - S: NewService, + I: IntoServiceFactory, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, @@ -488,48 +545,20 @@ where /// This methods panics if no socket address can be bound or an `Actix` system is not yet /// configured. /// - /// ```rust + /// ```rust,no_run /// use std::io; /// use actix_web::{web, App, HttpResponse, HttpServer}; /// - /// fn main() -> io::Result<()> { - /// let sys = actix_rt::System::new("example"); // <- create Actix system - /// - /// HttpServer::new(|| App::new().service(web::resource("/").to(|| HttpResponse::Ok()))) - /// .bind("127.0.0.1:0")? - /// .start(); - /// # actix_rt::System::current().stop(); - /// sys.run() // <- Run actix system, this method starts all async processes - /// } - /// ``` - pub fn start(self) -> Server { - self.builder.start() - } - - /// Spawn new thread and start listening for incoming connections. - /// - /// This method spawns new thread and starts new actix system. Other than - /// that it is similar to `start()` method. This method blocks. - /// - /// This methods panics if no socket addresses get bound. - /// - /// ```rust - /// use std::io; - /// use actix_web::{web, App, HttpResponse, HttpServer}; - /// - /// fn main() -> io::Result<()> { - /// # std::thread::spawn(|| { + /// #[actix_rt::main] + /// async fn main() -> io::Result<()> { /// HttpServer::new(|| App::new().service(web::resource("/").to(|| HttpResponse::Ok()))) /// .bind("127.0.0.1:0")? /// .run() - /// # }); - /// # Ok(()) + /// .await /// } /// ``` - pub fn run(self) -> io::Result<()> { - let sys = System::new("http-server"); - self.start(); - sys.run() + pub fn run(self) -> Server { + self.builder.start() } } @@ -546,15 +575,16 @@ fn create_tcp_listener( Ok(builder.listen(backlog)?) } -#[cfg(feature = "ssl")] +#[cfg(feature = "openssl")] /// Configure `SslAcceptorBuilder` with custom server flags. fn openssl_acceptor(mut builder: SslAcceptorBuilder) -> io::Result { - use openssl::ssl::AlpnError; - builder.set_alpn_select_callback(|_, protos| { const H2: &[u8] = b"\x02h2"; + const H11: &[u8] = b"\x08http/1.1"; if protos.windows(3).any(|window| window == H2) { Ok(b"h2") + } else if protos.windows(9).any(|window| window == H11) { + Ok(b"http/1.1") } else { Err(AlpnError::NOACK) } diff --git a/src/service.rs b/src/service.rs index 1d475cf15..e51be9964 100644 --- a/src/service.rs +++ b/src/service.rs @@ -8,9 +8,8 @@ use actix_http::{ Error, Extensions, HttpMessage, Payload, PayloadStream, RequestHead, Response, ResponseHead, }; -use actix_router::{Path, Resource, ResourceDef, Url}; -use actix_service::{IntoNewService, NewService}; -use futures::future::{ok, FutureResult, IntoFuture}; +use actix_router::{IntoPattern, Path, Resource, ResourceDef, Url}; +use actix_service::{IntoServiceFactory, ServiceFactory}; use crate::config::{AppConfig, AppService}; use crate::data::Data; @@ -24,7 +23,7 @@ pub trait HttpServiceFactory { fn register(self, config: &mut AppService); } -pub(crate) trait ServiceFactory { +pub(crate) trait AppServiceFactory { fn register(&mut self, config: &mut AppService); } @@ -40,7 +39,7 @@ impl ServiceFactoryWrapper { } } -impl ServiceFactory for ServiceFactoryWrapper +impl AppServiceFactory for ServiceFactoryWrapper where T: HttpServiceFactory, { @@ -68,6 +67,34 @@ impl ServiceRequest { (self.0, pl) } + /// Construct request from parts. + /// + /// `ServiceRequest` can be re-constructed only if `req` hasnt been cloned. + pub fn from_parts( + mut req: HttpRequest, + pl: Payload, + ) -> Result { + if Rc::strong_count(&req.0) == 1 && Rc::weak_count(&req.0) == 0 { + Rc::get_mut(&mut req.0).unwrap().payload = pl; + Ok(ServiceRequest(req)) + } else { + Err((req, pl)) + } + } + + /// Construct request from request. + /// + /// `HttpRequest` implements `Clone` trait via `Rc` type. `ServiceRequest` + /// can be re-constructed only if rc's strong pointers count eq 1 and + /// weak pointers count is 0. + pub fn from_request(req: HttpRequest) -> Result { + if Rc::strong_count(&req.0) == 1 && Rc::weak_count(&req.0) == 0 { + Ok(ServiceRequest(req)) + } else { + Err(req) + } + } + /// Create service response #[inline] pub fn into_response>>(self, res: R) -> ServiceResponse { @@ -154,7 +181,7 @@ impl ServiceRequest { /// Get *ConnectionInfo* for the current request. #[inline] - pub fn connection_info(&self) -> Ref { + pub fn connection_info(&self) -> Ref<'_, ConnectionInfo> { ConnectionInfo::get(self.head(), &*self.app_config()) } @@ -226,13 +253,13 @@ impl HttpMessage for ServiceRequest { /// Request extensions #[inline] - fn extensions(&self) -> Ref { + fn extensions(&self) -> Ref<'_, Extensions> { self.0.extensions() } /// Mutable reference to a the request's extensions #[inline] - fn extensions_mut(&self) -> RefMut { + fn extensions_mut(&self) -> RefMut<'_, Extensions> { self.0.extensions_mut() } @@ -243,7 +270,7 @@ impl HttpMessage for ServiceRequest { } impl fmt::Debug for ServiceRequest { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, "\nServiceRequest {:?} {}:{}", @@ -376,18 +403,8 @@ impl Into> for ServiceResponse { } } -impl IntoFuture for ServiceResponse { - type Item = ServiceResponse; - type Error = Error; - type Future = FutureResult, Error>; - - fn into_future(self) -> Self::Future { - ok(self) - } -} - impl fmt::Debug for ServiceResponse { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let res = writeln!( f, "\nServiceResponse {:?} {}{}", @@ -405,16 +422,16 @@ impl fmt::Debug for ServiceResponse { } pub struct WebService { - rdef: String, + rdef: Vec, name: Option, guards: Vec>, } impl WebService { /// Create new `WebService` instance. - pub fn new(path: &str) -> Self { + pub fn new(path: T) -> Self { WebService { - rdef: path.to_string(), + rdef: path.patterns(), name: None, guards: Vec::new(), } @@ -431,10 +448,10 @@ impl WebService { /// Add match guard to a web service. /// /// ```rust - /// use actix_web::{web, guard, dev, App, HttpResponse}; + /// use actix_web::{web, guard, dev, App, Error, HttpResponse}; /// - /// fn index(req: dev::ServiceRequest) -> dev::ServiceResponse { - /// req.into_response(HttpResponse::Ok().finish()) + /// async fn index(req: dev::ServiceRequest) -> Result { + /// Ok(req.into_response(HttpResponse::Ok().finish())) /// } /// /// fn main() { @@ -454,8 +471,8 @@ impl WebService { /// Set a service factory implementation and generate web service. pub fn finish(self, service: F) -> impl HttpServiceFactory where - F: IntoNewService, - T: NewService< + F: IntoServiceFactory, + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -464,7 +481,7 @@ impl WebService { > + 'static, { WebServiceImpl { - srv: service.into_new_service(), + srv: service.into_factory(), rdef: self.rdef, name: self.name, guards: self.guards, @@ -474,14 +491,14 @@ impl WebService { struct WebServiceImpl { srv: T, - rdef: String, + rdef: Vec, name: Option, guards: Vec>, } impl HttpServiceFactory for WebServiceImpl where - T: NewService< + T: ServiceFactory< Config = (), Request = ServiceRequest, Response = ServiceResponse, @@ -497,9 +514,9 @@ where }; let mut rdef = if config.is_root() || !self.rdef.is_empty() { - ResourceDef::new(&insert_slash(&self.rdef)) + ResourceDef::new(insert_slash(self.rdef)) } else { - ResourceDef::new(&self.rdef) + ResourceDef::new(self.rdef) }; if let Some(ref name) = self.name { *rdef.name_mut() = name.clone(); @@ -511,29 +528,54 @@ where #[cfg(test)] mod tests { use super::*; - use crate::test::{call_service, init_service, TestRequest}; + use crate::test::{init_service, TestRequest}; use crate::{guard, http, web, App, HttpResponse}; + use actix_service::Service; + use futures::future::ok; #[test] - fn test_service() { + fn test_service_request() { + let req = TestRequest::default().to_srv_request(); + let (r, pl) = req.into_parts(); + assert!(ServiceRequest::from_parts(r, pl).is_ok()); + + let req = TestRequest::default().to_srv_request(); + let (r, pl) = req.into_parts(); + let _r2 = r.clone(); + assert!(ServiceRequest::from_parts(r, pl).is_err()); + + let req = TestRequest::default().to_srv_request(); + let (r, _pl) = req.into_parts(); + assert!(ServiceRequest::from_request(r).is_ok()); + + let req = TestRequest::default().to_srv_request(); + let (r, _pl) = req.into_parts(); + let _r2 = r.clone(); + assert!(ServiceRequest::from_request(r).is_err()); + } + + #[actix_rt::test] + async fn test_service() { let mut srv = init_service( App::new().service(web::service("/test").name("test").finish( - |req: ServiceRequest| req.into_response(HttpResponse::Ok().finish()), + |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())), )), - ); + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), http::StatusCode::OK); let mut srv = init_service( App::new().service(web::service("/test").guard(guard::Get()).finish( - |req: ServiceRequest| req.into_response(HttpResponse::Ok().finish()), + |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())), )), - ); + ) + .await; let req = TestRequest::with_uri("/test") .method(http::Method::PUT) .to_request(); - let resp = call_service(&mut srv, req); + let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), http::StatusCode::NOT_FOUND); } diff --git a/src/test.rs b/src/test.rs index 903679cad..bb0d05cf7 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,104 +1,40 @@ //! Various helpers for Actix applications to use during testing. -use std::cell::RefCell; +use std::convert::TryFrom; +use std::net::SocketAddr; use std::rc::Rc; +use std::sync::mpsc; +use std::{fmt, net, thread, time}; +use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::http::header::{ContentType, Header, HeaderName, IntoHeaderValue}; -use actix_http::http::{HttpTryFrom, Method, StatusCode, Uri, Version}; +use actix_http::http::{Error as HttpError, Method, StatusCode, Uri, Version}; use actix_http::test::TestRequest as HttpTestRequest; -use actix_http::{cookie::Cookie, Extensions, Request}; +use actix_http::{cookie::Cookie, ws, Extensions, HttpService, Request}; use actix_router::{Path, ResourceDef, Url}; -use actix_rt::{System, SystemRunner}; -use actix_server_config::ServerConfig; -use actix_service::{IntoNewService, IntoService, NewService, Service}; +use actix_rt::{time::delay_for, System}; +use actix_service::{ + map_config, IntoService, IntoServiceFactory, Service, ServiceFactory, +}; +use awc::error::PayloadError; +use awc::{Client, ClientRequest, ClientResponse, Connector}; use bytes::{Bytes, BytesMut}; -use futures::future::{lazy, ok, Future, IntoFuture}; -use futures::Stream; +use futures::future::ok; +use futures::stream::{Stream, StreamExt}; +use net2::TcpBuilder; use serde::de::DeserializeOwned; use serde::Serialize; use serde_json; pub use actix_http::test::TestBuffer; -use crate::config::{AppConfig, AppConfigInner}; +use crate::config::AppConfig; use crate::data::Data; -use crate::dev::{Body, MessageBody, Payload}; +use crate::dev::{Body, MessageBody, Payload, Server}; use crate::request::HttpRequestPool; use crate::rmap::ResourceMap; use crate::service::{ServiceRequest, ServiceResponse}; use crate::{Error, HttpRequest, HttpResponse}; -thread_local! { - static RT: RefCell = { - RefCell::new(Inner(Some(System::builder().build()))) - }; -} - -struct Inner(Option); - -impl Inner { - fn get_mut(&mut self) -> &mut SystemRunner { - self.0.as_mut().unwrap() - } -} - -impl Drop for Inner { - fn drop(&mut self) { - std::mem::forget(self.0.take().unwrap()) - } -} - -/// Runs the provided future, blocking the current thread until the future -/// completes. -/// -/// This function can be used to synchronously block the current thread -/// until the provided `future` has resolved either successfully or with an -/// error. The result of the future is then returned from this function -/// call. -/// -/// Note that this function is intended to be used only for testing purpose. -/// This function panics on nested call. -pub fn block_on(f: F) -> Result -where - F: IntoFuture, -{ - RT.with(move |rt| rt.borrow_mut().get_mut().block_on(f.into_future())) -} - -/// Runs the provided function, blocking the current thread until the result -/// future completes. -/// -/// This function can be used to synchronously block the current thread -/// until the provided `future` has resolved either successfully or with an -/// error. The result of the future is then returned from this function -/// call. -/// -/// Note that this function is intended to be used only for testing purpose. -/// This function panics on nested call. -pub fn block_fn(f: F) -> Result -where - F: FnOnce() -> R, - R: IntoFuture, -{ - RT.with(move |rt| rt.borrow_mut().get_mut().block_on(lazy(f))) -} - -#[doc(hidden)] -/// Runs the provided function, with runtime enabled. -/// -/// Note that this function is intended to be used only for testing purpose. -/// This function panics on nested call. -pub fn run_on(f: F) -> R -where - F: FnOnce() -> R, -{ - RT.with(move |rt| { - rt.borrow_mut() - .get_mut() - .block_on(lazy(|| Ok::<_, ()>(f()))) - }) - .unwrap() -} - /// Create service that always responds with `HttpResponse::Ok()` pub fn ok_service( ) -> impl Service, Error = Error> @@ -112,7 +48,7 @@ pub fn default_service( ) -> impl Service, Error = Error> { (move |req: ServiceRequest| { - req.into_response(HttpResponse::build(status_code).finish()) + ok(req.into_response(HttpResponse::build(status_code).finish())) }) .into_service() } @@ -124,38 +60,36 @@ pub fn default_service( /// use actix_service::Service; /// use actix_web::{test, web, App, HttpResponse, http::StatusCode}; /// -/// #[test] -/// fn test_init_service() { +/// #[actix_rt::test] +/// async fn test_init_service() { /// let mut app = test::init_service( /// App::new() -/// .service(web::resource("/test").to(|| HttpResponse::Ok())) -/// ); +/// .service(web::resource("/test").to(|| async { HttpResponse::Ok() })) +/// ).await; /// /// // Create request object /// let req = test::TestRequest::with_uri("/test").to_request(); /// /// // Execute application -/// let resp = test::block_on(app.call(req)).unwrap(); +/// let resp = app.call(req).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// } /// ``` -pub fn init_service( +pub async fn init_service( app: R, ) -> impl Service, Error = E> where - R: IntoNewService, - S: NewService< - Config = ServerConfig, + R: IntoServiceFactory, + S: ServiceFactory< + Config = AppConfig, Request = Request, Response = ServiceResponse, Error = E, >, S::InitError: std::fmt::Debug, { - let cfg = ServerConfig::new("127.0.0.1:8080".parse().unwrap()); - let srv = app.into_new_service(); - let fut = run_on(move || srv.new_service(&cfg)); - block_on(fut).unwrap() + let srv = app.into_factory(); + srv.new_service(AppConfig::default()).await.unwrap() } /// Calls service and waits for response future completion. @@ -168,109 +102,120 @@ where /// fn test_response() { /// let mut app = test::init_service( /// App::new() -/// .service(web::resource("/test").to(|| HttpResponse::Ok())) -/// ); +/// .service(web::resource("/test").to(|| async { +/// HttpResponse::Ok() +/// })) +/// ).await; /// /// // Create request object /// let req = test::TestRequest::with_uri("/test").to_request(); /// /// // Call application -/// let resp = test::call_service(&mut app, req); +/// let resp = test::call_service(&mut app, req).await; /// assert_eq!(resp.status(), StatusCode::OK); /// } /// ``` -pub fn call_service(app: &mut S, req: R) -> S::Response +pub async fn call_service(app: &mut S, req: R) -> S::Response where S: Service, Error = E>, E: std::fmt::Debug, { - block_on(run_on(move || app.call(req))).unwrap() + app.call(req).await.unwrap() } /// Helper function that returns a response body of a TestRequest -/// This function blocks the current thread until futures complete. /// /// ```rust /// use actix_web::{test, web, App, HttpResponse, http::header}; /// use bytes::Bytes; /// -/// #[test] -/// fn test_index() { +/// #[actix_rt::test] +/// async fn test_index() { /// let mut app = test::init_service( /// App::new().service( /// web::resource("/index.html") -/// .route(web::post().to( -/// || HttpResponse::Ok().body("welcome!"))))); +/// .route(web::post().to(|| async { +/// HttpResponse::Ok().body("welcome!") +/// }))) +/// ).await; /// /// let req = test::TestRequest::post() /// .uri("/index.html") /// .header(header::CONTENT_TYPE, "application/json") /// .to_request(); /// -/// let result = test::read_response(&mut app, req); +/// let result = test::read_response(&mut app, req).await; /// assert_eq!(result, Bytes::from_static(b"welcome!")); /// } /// ``` -pub fn read_response(app: &mut S, req: Request) -> Bytes +pub async fn read_response(app: &mut S, req: Request) -> Bytes where S: Service, Error = Error>, B: MessageBody, { - block_on(run_on(move || { - app.call(req).and_then(|mut resp: ServiceResponse| { - resp.take_body() - .fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, Error>(body) - }) - .map(|body: BytesMut| body.freeze()) - }) - })) - .unwrap_or_else(|_| panic!("read_response failed at block_on unwrap")) + let mut resp = app + .call(req) + .await + .unwrap_or_else(|_| panic!("read_response failed at application call")); + + let mut body = resp.take_body(); + let mut bytes = BytesMut::new(); + while let Some(item) = body.next().await { + bytes.extend_from_slice(&item.unwrap()); + } + bytes.freeze() } /// Helper function that returns a response body of a ServiceResponse. -/// This function blocks the current thread until futures complete. /// /// ```rust /// use actix_web::{test, web, App, HttpResponse, http::header}; /// use bytes::Bytes; /// -/// #[test] -/// fn test_index() { +/// #[actix_rt::test] +/// async fn test_index() { /// let mut app = test::init_service( /// App::new().service( /// web::resource("/index.html") -/// .route(web::post().to( -/// || HttpResponse::Ok().body("welcome!"))))); +/// .route(web::post().to(|| async { +/// HttpResponse::Ok().body("welcome!") +/// }))) +/// ).await; /// /// let req = test::TestRequest::post() /// .uri("/index.html") /// .header(header::CONTENT_TYPE, "application/json") /// .to_request(); /// -/// let resp = test::call_service(&mut app, req); +/// let resp = test::call_service(&mut app, req).await; /// let result = test::read_body(resp); /// assert_eq!(result, Bytes::from_static(b"welcome!")); /// } /// ``` -pub fn read_body(mut res: ServiceResponse) -> Bytes +pub async fn read_body(mut res: ServiceResponse) -> Bytes where B: MessageBody, { - block_on(run_on(move || { - res.take_body() - .fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, Error>(body) - }) - .map(|body: BytesMut| body.freeze()) - })) - .unwrap_or_else(|_| panic!("read_response failed at block_on unwrap")) + let mut body = res.take_body(); + let mut bytes = BytesMut::new(); + while let Some(item) = body.next().await { + bytes.extend_from_slice(&item.unwrap()); + } + bytes.freeze() +} + +pub async fn load_stream(mut stream: S) -> Result +where + S: Stream> + Unpin, +{ + let mut data = BytesMut::new(); + while let Some(item) = stream.next().await { + data.extend_from_slice(&item?); + } + Ok(data.freeze()) } /// Helper function that returns a deserialized response body of a TestRequest -/// This function blocks the current thread until futures complete. /// /// ```rust /// use actix_web::{App, test, web, HttpResponse, http::header}; @@ -282,15 +227,16 @@ where /// name: String /// } /// -/// #[test] -/// fn test_add_person() { +/// #[actix_rt::test] +/// async fn test_add_person() { /// let mut app = test::init_service( /// App::new().service( /// web::resource("/people") -/// .route(web::post().to(|person: web::Json| { +/// .route(web::post().to(|person: web::Json| async { /// HttpResponse::Ok() /// .json(person.into_inner())}) -/// ))); +/// )) +/// ).await; /// /// let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); /// @@ -300,30 +246,19 @@ where /// .set_payload(payload) /// .to_request(); /// -/// let result: Person = test::read_response_json(&mut app, req); +/// let result: Person = test::read_response_json(&mut app, req).await; /// } /// ``` -pub fn read_response_json(app: &mut S, req: Request) -> T +pub async fn read_response_json(app: &mut S, req: Request) -> T where S: Service, Error = Error>, B: MessageBody, T: DeserializeOwned, { - block_on(run_on(move || { - app.call(req).and_then(|mut resp: ServiceResponse| { - resp.take_body() - .fold(BytesMut::new(), move |mut body, chunk| { - body.extend_from_slice(&chunk); - Ok::<_, Error>(body) - }) - .and_then(|body: BytesMut| { - ok(serde_json::from_slice(&body).unwrap_or_else(|_| { - panic!("read_response_json failed during deserialization") - })) - }) - }) - })) - .unwrap_or_else(|_| panic!("read_response_json failed at block_on unwrap")) + let body = read_response(app, req).await; + + serde_json::from_slice(&body) + .unwrap_or_else(|_| panic!("read_response_json failed during deserialization")) } /// Test `Request` builder. @@ -339,7 +274,7 @@ where /// use actix_web::{test, HttpRequest, HttpResponse, HttpMessage}; /// use actix_web::http::{header, StatusCode}; /// -/// fn index(req: HttpRequest) -> HttpResponse { +/// async fn index(req: HttpRequest) -> HttpResponse { /// if let Some(hdr) = req.headers().get(header::CONTENT_TYPE) { /// HttpResponse::Ok().into() /// } else { @@ -352,19 +287,20 @@ where /// let req = test::TestRequest::with_header("content-type", "text/plain") /// .to_http_request(); /// -/// let resp = test::block_on(index(req)).unwrap(); +/// let resp = index(req).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// /// let req = test::TestRequest::default().to_http_request(); -/// let resp = test::block_on(index(req)).unwrap(); +/// let resp = index(req).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::BAD_REQUEST); /// } /// ``` pub struct TestRequest { req: HttpTestRequest, rmap: ResourceMap, - config: AppConfigInner, + config: AppConfig, path: Path, + peer_addr: Option, app_data: Extensions, } @@ -373,8 +309,9 @@ impl Default for TestRequest { TestRequest { req: HttpTestRequest::default(), rmap: ResourceMap::new(ResourceDef::new("")), - config: AppConfigInner::default(), + config: AppConfig::default(), path: Path::new(Url::new(Uri::default())), + peer_addr: None, app_data: Extensions::new(), } } @@ -395,7 +332,8 @@ impl TestRequest { /// Create TestRequest and set header pub fn with_header(key: K, value: V) -> TestRequest where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { TestRequest::default().header(key, value) @@ -453,7 +391,8 @@ impl TestRequest { /// Set a header pub fn header(mut self, key: K, value: V) -> Self where - HeaderName: HttpTryFrom, + HeaderName: TryFrom, + >::Error: Into, V: IntoHeaderValue, { self.req.header(key, value); @@ -461,7 +400,7 @@ impl TestRequest { } /// Set cookie for this request - pub fn cookie(mut self, cookie: Cookie) -> Self { + pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { self.req.cookie(cookie); self } @@ -472,6 +411,12 @@ impl TestRequest { self } + /// Set peer addr + pub fn peer_addr(mut self, addr: SocketAddr) -> Self { + self.peer_addr = Some(addr); + self + } + /// Set request payload pub fn set_payload>(mut self, data: B) -> Self { self.req.set_payload(data); @@ -505,6 +450,13 @@ impl TestRequest { self } + /// Set application data. This is equivalent of `App::app_data()` method + /// for testing purpose. + pub fn app_data(mut self, data: T) -> Self { + self.app_data.insert(data); + self + } + #[cfg(test)] /// Set request config pub(crate) fn rmap(mut self, rmap: ResourceMap) -> Self { @@ -514,12 +466,15 @@ impl TestRequest { /// Complete request creation and generate `Request` instance pub fn to_request(mut self) -> Request { - self.req.finish() + let mut req = self.req.finish(); + req.head_mut().peer_addr = self.peer_addr; + req } /// Complete request creation and generate `ServiceRequest` instance pub fn to_srv_request(mut self) -> ServiceRequest { - let (head, payload) = self.req.finish().into_parts(); + let (mut head, payload) = self.req.finish().into_parts(); + head.peer_addr = self.peer_addr; self.path.get_mut().update(&head.uri); ServiceRequest::new(HttpRequest::new( @@ -527,7 +482,7 @@ impl TestRequest { head, payload, Rc::new(self.rmap), - AppConfig::new(self.config), + self.config.clone(), Rc::new(self.app_data), HttpRequestPool::create(), )) @@ -540,7 +495,8 @@ impl TestRequest { /// Complete request creation and generate `HttpRequest` instance pub fn to_http_request(mut self) -> HttpRequest { - let (head, payload) = self.req.finish().into_parts(); + let (mut head, payload) = self.req.finish().into_parts(); + head.peer_addr = self.peer_addr; self.path.get_mut().update(&head.uri); HttpRequest::new( @@ -548,7 +504,7 @@ impl TestRequest { head, payload, Rc::new(self.rmap), - AppConfig::new(self.config), + self.config.clone(), Rc::new(self.app_data), HttpRequestPool::create(), ) @@ -556,7 +512,8 @@ impl TestRequest { /// Complete request creation and generate `HttpRequest` and `Payload` instances pub fn to_http_parts(mut self) -> (HttpRequest, Payload) { - let (head, payload) = self.req.finish().into_parts(); + let (mut head, payload) = self.req.finish().into_parts(); + head.peer_addr = self.peer_addr; self.path.get_mut().update(&head.uri); let req = HttpRequest::new( @@ -564,7 +521,7 @@ impl TestRequest { head, Payload::None, Rc::new(self.rmap), - AppConfig::new(self.config), + self.config.clone(), Rc::new(self.app_data), HttpRequestPool::create(), ); @@ -573,54 +530,487 @@ impl TestRequest { } } +/// Start test server with default configuration +/// +/// Test server is very simple server that simplify process of writing +/// integration tests cases for actix web applications. +/// +/// # Examples +/// +/// ```rust +/// use actix_web::{web, test, App, HttpResponse, Error}; +/// +/// async fn my_handler() -> Result { +/// Ok(HttpResponse::Ok().into()) +/// } +/// +/// #[actix_rt::test] +/// async fn test_example() { +/// let mut srv = test::start( +/// || App::new().service( +/// web::resource("/").to(my_handler)) +/// ); +/// +/// let req = srv.get("/"); +/// let response = req.send().await.unwrap(); +/// assert!(response.status().is_success()); +/// } +/// ``` +pub fn start(factory: F) -> TestServer +where + F: Fn() -> I + Send + Clone + 'static, + I: IntoServiceFactory, + S: ServiceFactory + 'static, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, +{ + start_with(TestServerConfig::default(), factory) +} + +/// Start test server with custom configuration +/// +/// Test server could be configured in different ways, for details check +/// `TestServerConfig` docs. +/// +/// # Examples +/// +/// ```rust +/// use actix_web::{web, test, App, HttpResponse, Error}; +/// +/// async fn my_handler() -> Result { +/// Ok(HttpResponse::Ok().into()) +/// } +/// +/// #[actix_rt::test] +/// async fn test_example() { +/// let mut srv = test::start_with(test::config().h1(), || +/// App::new().service(web::resource("/").to(my_handler)) +/// ); +/// +/// let req = srv.get("/"); +/// let response = req.send().await.unwrap(); +/// assert!(response.status().is_success()); +/// } +/// ``` +pub fn start_with(cfg: TestServerConfig, factory: F) -> TestServer +where + F: Fn() -> I + Send + Clone + 'static, + I: IntoServiceFactory, + S: ServiceFactory + 'static, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, +{ + let (tx, rx) = mpsc::channel(); + + let ssl = match cfg.stream { + StreamType::Tcp => false, + #[cfg(feature = "openssl")] + StreamType::Openssl(_) => true, + #[cfg(feature = "rustls")] + StreamType::Rustls(_) => true, + }; + + // run server in separate thread + thread::spawn(move || { + let sys = System::new("actix-test-server"); + let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); + let local_addr = tcp.local_addr().unwrap(); + let factory = factory.clone(); + let cfg = cfg.clone(); + let ctimeout = cfg.client_timeout; + let builder = Server::build().workers(1).disable_signals(); + + let srv = match cfg.stream { + StreamType::Tcp => match cfg.tp { + HttpVer::Http1 => builder.listen("test", tcp, move || { + let cfg = + AppConfig::new(false, local_addr, format!("{}", local_addr)); + HttpService::build() + .client_timeout(ctimeout) + .h1(map_config(factory(), move |_| cfg.clone())) + .tcp() + }), + HttpVer::Http2 => builder.listen("test", tcp, move || { + let cfg = + AppConfig::new(false, local_addr, format!("{}", local_addr)); + HttpService::build() + .client_timeout(ctimeout) + .h2(map_config(factory(), move |_| cfg.clone())) + .tcp() + }), + HttpVer::Both => builder.listen("test", tcp, move || { + let cfg = + AppConfig::new(false, local_addr, format!("{}", local_addr)); + HttpService::build() + .client_timeout(ctimeout) + .finish(map_config(factory(), move |_| cfg.clone())) + .tcp() + }), + }, + #[cfg(feature = "openssl")] + StreamType::Openssl(acceptor) => match cfg.tp { + HttpVer::Http1 => builder.listen("test", tcp, move || { + let cfg = + AppConfig::new(true, local_addr, format!("{}", local_addr)); + HttpService::build() + .client_timeout(ctimeout) + .h1(map_config(factory(), move |_| cfg.clone())) + .openssl(acceptor.clone()) + }), + HttpVer::Http2 => builder.listen("test", tcp, move || { + let cfg = + AppConfig::new(true, local_addr, format!("{}", local_addr)); + HttpService::build() + .client_timeout(ctimeout) + .h2(map_config(factory(), move |_| cfg.clone())) + .openssl(acceptor.clone()) + }), + HttpVer::Both => builder.listen("test", tcp, move || { + let cfg = + AppConfig::new(true, local_addr, format!("{}", local_addr)); + HttpService::build() + .client_timeout(ctimeout) + .finish(map_config(factory(), move |_| cfg.clone())) + .openssl(acceptor.clone()) + }), + }, + #[cfg(feature = "rustls")] + StreamType::Rustls(config) => match cfg.tp { + HttpVer::Http1 => builder.listen("test", tcp, move || { + let cfg = + AppConfig::new(true, local_addr, format!("{}", local_addr)); + HttpService::build() + .client_timeout(ctimeout) + .h1(map_config(factory(), move |_| cfg.clone())) + .rustls(config.clone()) + }), + HttpVer::Http2 => builder.listen("test", tcp, move || { + let cfg = + AppConfig::new(true, local_addr, format!("{}", local_addr)); + HttpService::build() + .client_timeout(ctimeout) + .h2(map_config(factory(), move |_| cfg.clone())) + .rustls(config.clone()) + }), + HttpVer::Both => builder.listen("test", tcp, move || { + let cfg = + AppConfig::new(true, local_addr, format!("{}", local_addr)); + HttpService::build() + .client_timeout(ctimeout) + .finish(map_config(factory(), move |_| cfg.clone())) + .rustls(config.clone()) + }), + }, + } + .unwrap() + .start(); + + tx.send((System::current(), srv, local_addr)).unwrap(); + sys.run() + }); + + let (system, server, addr) = rx.recv().unwrap(); + + let client = { + let connector = { + #[cfg(feature = "openssl")] + { + use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder + .set_alpn_protos(b"\x02h2\x08http/1.1") + .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); + Connector::new() + .conn_lifetime(time::Duration::from_secs(0)) + .timeout(time::Duration::from_millis(30000)) + .ssl(builder.build()) + .finish() + } + #[cfg(not(feature = "openssl"))] + { + Connector::new() + .conn_lifetime(time::Duration::from_secs(0)) + .timeout(time::Duration::from_millis(30000)) + .finish() + } + }; + + Client::build().connector(connector).finish() + }; + + TestServer { + ssl, + addr, + client, + system, + server, + } +} + +#[derive(Clone)] +pub struct TestServerConfig { + tp: HttpVer, + stream: StreamType, + client_timeout: u64, +} + +#[derive(Clone)] +enum HttpVer { + Http1, + Http2, + Both, +} + +#[derive(Clone)] +enum StreamType { + Tcp, + #[cfg(feature = "openssl")] + Openssl(open_ssl::ssl::SslAcceptor), + #[cfg(feature = "rustls")] + Rustls(rust_tls::ServerConfig), +} + +impl Default for TestServerConfig { + fn default() -> Self { + TestServerConfig::new() + } +} + +/// Create default test server config +pub fn config() -> TestServerConfig { + TestServerConfig::new() +} + +impl TestServerConfig { + /// Create default server configuration + pub(crate) fn new() -> TestServerConfig { + TestServerConfig { + tp: HttpVer::Both, + stream: StreamType::Tcp, + client_timeout: 5000, + } + } + + /// Start http/1.1 server only + pub fn h1(mut self) -> Self { + self.tp = HttpVer::Http1; + self + } + + /// Start http/2 server only + pub fn h2(mut self) -> Self { + self.tp = HttpVer::Http2; + self + } + + /// Start openssl server + #[cfg(feature = "openssl")] + pub fn openssl(mut self, acceptor: open_ssl::ssl::SslAcceptor) -> Self { + self.stream = StreamType::Openssl(acceptor); + self + } + + /// Start rustls server + #[cfg(feature = "rustls")] + pub fn rustls(mut self, config: rust_tls::ServerConfig) -> Self { + self.stream = StreamType::Rustls(config); + self + } + + /// Set server client timeout in milliseconds for first request. + pub fn client_timeout(mut self, val: u64) -> Self { + self.client_timeout = val; + self + } +} + +/// Get first available unused address +pub fn unused_addr() -> net::SocketAddr { + let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + let socket = TcpBuilder::new_v4().unwrap(); + socket.bind(&addr).unwrap(); + socket.reuse_address(true).unwrap(); + let tcp = socket.to_tcp_listener().unwrap(); + tcp.local_addr().unwrap() +} + +/// Test server controller +pub struct TestServer { + addr: net::SocketAddr, + client: awc::Client, + system: actix_rt::System, + ssl: bool, + server: Server, +} + +impl TestServer { + /// Construct test server url + pub fn addr(&self) -> net::SocketAddr { + self.addr + } + + /// Construct test server url + pub fn url(&self, uri: &str) -> String { + let scheme = if self.ssl { "https" } else { "http" }; + + if uri.starts_with('/') { + format!("{}://localhost:{}{}", scheme, self.addr.port(), uri) + } else { + format!("{}://localhost:{}/{}", scheme, self.addr.port(), uri) + } + } + + /// Create `GET` request + pub fn get>(&self, path: S) -> ClientRequest { + self.client.get(self.url(path.as_ref()).as_str()) + } + + /// Create `POST` request + pub fn post>(&self, path: S) -> ClientRequest { + self.client.post(self.url(path.as_ref()).as_str()) + } + + /// Create `HEAD` request + pub fn head>(&self, path: S) -> ClientRequest { + self.client.head(self.url(path.as_ref()).as_str()) + } + + /// Create `PUT` request + pub fn put>(&self, path: S) -> ClientRequest { + self.client.put(self.url(path.as_ref()).as_str()) + } + + /// Create `PATCH` request + pub fn patch>(&self, path: S) -> ClientRequest { + self.client.patch(self.url(path.as_ref()).as_str()) + } + + /// Create `DELETE` request + pub fn delete>(&self, path: S) -> ClientRequest { + self.client.delete(self.url(path.as_ref()).as_str()) + } + + /// Create `OPTIONS` request + pub fn options>(&self, path: S) -> ClientRequest { + self.client.options(self.url(path.as_ref()).as_str()) + } + + /// Connect to test http server + pub fn request>(&self, method: Method, path: S) -> ClientRequest { + self.client.request(method, path.as_ref()) + } + + pub async fn load_body( + &mut self, + mut response: ClientResponse, + ) -> Result + where + S: Stream> + Unpin + 'static, + { + response.body().limit(10_485_760).await + } + + /// Connect to websocket server at a given path + pub async fn ws_at( + &mut self, + path: &str, + ) -> Result, awc::error::WsClientError> + { + let url = self.url(path); + let connect = self.client.ws(url).connect(); + connect.await.map(|(_, framed)| framed) + } + + /// Connect to a websocket server + pub async fn ws( + &mut self, + ) -> Result, awc::error::WsClientError> + { + self.ws_at("/").await + } + + /// Gracefully stop http server + pub async fn stop(self) { + self.server.stop(true).await; + self.system.stop(); + delay_for(time::Duration::from_millis(100)).await; + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.system.stop() + } +} + #[cfg(test)] mod tests { use actix_http::httpmessage::HttpMessage; + use futures::FutureExt; use serde::{Deserialize, Serialize}; use std::time::SystemTime; use super::*; - use crate::{http::header, web, App, HttpResponse}; + use crate::{http::header, web, App, HttpResponse, Responder}; - #[test] - fn test_basics() { + #[actix_rt::test] + async fn test_basics() { let req = TestRequest::with_hdr(header::ContentType::json()) .version(Version::HTTP_2) .set(header::Date(SystemTime::now().into())) .param("test", "123") .data(10u32) + .app_data(20u64) + .peer_addr("127.0.0.1:8081".parse().unwrap()) .to_http_request(); assert!(req.headers().contains_key(header::CONTENT_TYPE)); assert!(req.headers().contains_key(header::DATE)); + assert_eq!( + req.head().peer_addr, + Some("127.0.0.1:8081".parse().unwrap()) + ); assert_eq!(&req.match_info()["test"], "123"); assert_eq!(req.version(), Version::HTTP_2); - let data = req.get_app_data::().unwrap(); - assert!(req.get_app_data::().is_none()); - assert_eq!(*data, 10); + let data = req.app_data::>().unwrap(); + assert!(req.app_data::>().is_none()); assert_eq!(*data.get_ref(), 10); - assert!(req.app_data::().is_none()); - let data = req.app_data::().unwrap(); - assert_eq!(*data, 10); + assert!(req.app_data::().is_none()); + let data = req.app_data::().unwrap(); + assert_eq!(*data, 20); } - #[test] - fn test_request_methods() { + #[actix_rt::test] + async fn test_request_methods() { let mut app = init_service( App::new().service( web::resource("/index.html") - .route(web::put().to(|| HttpResponse::Ok().body("put!"))) - .route(web::patch().to(|| HttpResponse::Ok().body("patch!"))) - .route(web::delete().to(|| HttpResponse::Ok().body("delete!"))), + .route(web::put().to(|| async { HttpResponse::Ok().body("put!") })) + .route( + web::patch().to(|| async { HttpResponse::Ok().body("patch!") }), + ) + .route( + web::delete() + .to(|| async { HttpResponse::Ok().body("delete!") }), + ), ), - ); + ) + .await; let put_req = TestRequest::put() .uri("/index.html") .header(header::CONTENT_TYPE, "application/json") .to_request(); - let result = read_response(&mut app, put_req); + let result = read_response(&mut app, put_req).await; assert_eq!(result, Bytes::from_static(b"put!")); let patch_req = TestRequest::patch() @@ -628,29 +1018,28 @@ mod tests { .header(header::CONTENT_TYPE, "application/json") .to_request(); - let result = read_response(&mut app, patch_req); + let result = read_response(&mut app, patch_req).await; assert_eq!(result, Bytes::from_static(b"patch!")); let delete_req = TestRequest::delete().uri("/index.html").to_request(); - let result = read_response(&mut app, delete_req); + let result = read_response(&mut app, delete_req).await; assert_eq!(result, Bytes::from_static(b"delete!")); } - #[test] - fn test_response() { - let mut app = init_service( - App::new().service( - web::resource("/index.html") - .route(web::post().to(|| HttpResponse::Ok().body("welcome!"))), - ), - ); + #[actix_rt::test] + async fn test_response() { + let mut app = + init_service(App::new().service(web::resource("/index.html").route( + web::post().to(|| async { HttpResponse::Ok().body("welcome!") }), + ))) + .await; let req = TestRequest::post() .uri("/index.html") .header(header::CONTENT_TYPE, "application/json") .to_request(); - let result = read_response(&mut app, req); + let result = read_response(&mut app, req).await; assert_eq!(result, Bytes::from_static(b"welcome!")); } @@ -660,13 +1049,14 @@ mod tests { name: String, } - #[test] - fn test_response_json() { + #[actix_rt::test] + async fn test_response_json() { let mut app = init_service(App::new().service(web::resource("/people").route( web::post().to(|person: web::Json| { - HttpResponse::Ok().json(person.into_inner()) + async { HttpResponse::Ok().json(person.into_inner()) } }), - ))); + ))) + .await; let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); @@ -676,17 +1066,18 @@ mod tests { .set_payload(payload) .to_request(); - let result: Person = read_response_json(&mut app, req); + let result: Person = read_response_json(&mut app, req).await; assert_eq!(&result.id, "12345"); } - #[test] - fn test_request_response_form() { + #[actix_rt::test] + async fn test_request_response_form() { let mut app = init_service(App::new().service(web::resource("/people").route( web::post().to(|person: web::Form| { - HttpResponse::Ok().json(person.into_inner()) + async { HttpResponse::Ok().json(person.into_inner()) } }), - ))); + ))) + .await; let payload = Person { id: "12345".to_string(), @@ -700,18 +1091,19 @@ mod tests { assert_eq!(req.content_type(), "application/x-www-form-urlencoded"); - let result: Person = read_response_json(&mut app, req); + let result: Person = read_response_json(&mut app, req).await; assert_eq!(&result.id, "12345"); assert_eq!(&result.name, "User name"); } - #[test] - fn test_request_response_json() { + #[actix_rt::test] + async fn test_request_response_json() { let mut app = init_service(App::new().service(web::resource("/people").route( web::post().to(|person: web::Json| { - HttpResponse::Ok().json(person.into_inner()) + async { HttpResponse::Ok().json(person.into_inner()) } }), - ))); + ))) + .await; let payload = Person { id: "12345".to_string(), @@ -725,33 +1117,55 @@ mod tests { assert_eq!(req.content_type(), "application/json"); - let result: Person = read_response_json(&mut app, req); + let result: Person = read_response_json(&mut app, req).await; assert_eq!(&result.id, "12345"); assert_eq!(&result.name, "User name"); } - #[test] - fn test_async_with_block() { - fn async_with_block() -> impl Future { - web::block(move || Some(4).ok_or("wrong")).then(|res| match res { - Ok(value) => HttpResponse::Ok() + #[actix_rt::test] + async fn test_async_with_block() { + async fn async_with_block() -> Result { + let res = web::block(move || Some(4usize).ok_or("wrong")).await; + + match res { + Ok(value) => Ok(HttpResponse::Ok() .content_type("text/plain") - .body(format!("Async with block value: {}", value)), + .body(format!("Async with block value: {}", value))), Err(_) => panic!("Unexpected"), - }) + } } let mut app = init_service( - App::new().service(web::resource("/index.html").to_async(async_with_block)), - ); + App::new().service(web::resource("/index.html").to(async_with_block)), + ) + .await; let req = TestRequest::post().uri("/index.html").to_request(); - let res = block_fn(|| app.call(req)).unwrap(); + let res = app.call(req).await.unwrap(); assert!(res.status().is_success()); } - #[test] - fn test_actor() { + #[actix_rt::test] + async fn test_server_data() { + async fn handler(data: web::Data) -> impl Responder { + assert_eq!(**data, 10); + HttpResponse::Ok() + } + + let mut app = init_service( + App::new() + .data(10usize) + .service(web::resource("/index.html").to(handler)), + ) + .await; + + let req = TestRequest::post().uri("/index.html").to_request(); + let res = app.call(req).await.unwrap(); + assert!(res.status().is_success()); + } + + #[actix_rt::test] + async fn test_actor() { use actix::Actor; struct MyActor; @@ -770,21 +1184,26 @@ mod tests { } } - let addr = run_on(|| MyActor.start()); - let mut app = init_service(App::new().service( - web::resource("/index.html").to_async(move || { - addr.send(Num(1)).from_err().and_then(|res| { - if res == 1 { - HttpResponse::Ok() - } else { - HttpResponse::BadRequest() + let addr = MyActor.start(); + + let mut app = init_service(App::new().service(web::resource("/index.html").to( + move || { + addr.send(Num(1)).map(|res| match res { + Ok(res) => { + if res == 1 { + Ok(HttpResponse::Ok()) + } else { + Ok(HttpResponse::BadRequest()) + } } + Err(err) => Err(err), }) - }), - )); + }, + ))) + .await; let req = TestRequest::post().uri("/index.html").to_request(); - let res = block_fn(|| app.call(req)).unwrap(); + let res = app.call(req).await.unwrap(); assert!(res.status().is_success()); } } diff --git a/src/types/form.rs b/src/types/form.rs index 9ab98b17b..d917345e1 100644 --- a/src/types/form.rs +++ b/src/types/form.rs @@ -1,15 +1,20 @@ //! Form extractor +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; +use std::task::{Context, Poll}; use std::{fmt, ops}; use actix_http::{Error, HttpMessage, Payload, Response}; use bytes::BytesMut; use encoding_rs::{Encoding, UTF_8}; -use futures::{Future, Poll, Stream}; +use futures::future::{err, ok, FutureExt, LocalBoxFuture, Ready}; +use futures::StreamExt; use serde::de::DeserializeOwned; use serde::Serialize; +#[cfg(feature = "compress")] use crate::dev::Decompress; use crate::error::UrlencodedError; use crate::extract::FromRequest; @@ -35,9 +40,8 @@ use crate::responder::Responder; /// /// ### Example /// ```rust -/// # extern crate actix_web; -/// #[macro_use] extern crate serde_derive; -/// use actix_web::{web, App}; +/// use actix_web::web; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct FormData { @@ -61,9 +65,9 @@ use crate::responder::Responder; /// /// ### Example /// ```rust -/// # #[macro_use] extern crate serde_derive; -/// # use actix_web::*; -/// # +/// use actix_web::*; +/// use serde_derive::Serialize; +/// /// #[derive(Serialize)] /// struct SomeForm { /// name: String, @@ -111,7 +115,7 @@ where { type Config = FormConfig; type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { @@ -121,44 +125,45 @@ where .map(|c| (c.limit, c.ehandler.clone())) .unwrap_or((16384, None)); - Box::new( - UrlEncoded::new(req, payload) - .limit(limit) - .map_err(move |e| { + UrlEncoded::new(req, payload) + .limit(limit) + .map(move |res| match res { + Err(e) => { if let Some(err) = err { - (*err)(e, &req2) + Err((*err)(e, &req2)) } else { - e.into() + Err(e.into()) } - }) - .map(Form), - ) + } + Ok(item) => Ok(Form(item)), + }) + .boxed_local() } } impl fmt::Debug for Form { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } impl fmt::Display for Form { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } impl Responder for Form { type Error = Error; - type Future = Result; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { let body = match serde_urlencoded::to_string(&self.0) { Ok(body) => body, - Err(e) => return Err(e.into()), + Err(e) => return err(e.into()), }; - Ok(Response::build(StatusCode::OK) + ok(Response::build(StatusCode::OK) .set(ContentType::form_url_encoded()) .body(body)) } @@ -167,8 +172,8 @@ impl Responder for Form { /// Form extractor configuration /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App, FromRequest, Result}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct FormData { @@ -177,7 +182,7 @@ impl Responder for Form { /// /// /// Extract form data using serde. /// /// Custom configuration is used for this handler, max payload size is 4k -/// fn index(form: web::Form) -> Result { +/// async fn index(form: web::Form) -> Result { /// Ok(format!("Welcome {}!", form.username)) /// } /// @@ -185,7 +190,7 @@ impl Responder for Form { /// let app = App::new().service( /// web::resource("/index.html") /// // change `Form` extractor configuration -/// .data( +/// .app_data( /// web::Form::::configure(|cfg| cfg.limit(4097)) /// ) /// .route(web::get().to(index)) @@ -236,12 +241,15 @@ impl Default for FormConfig { /// * content-length is greater than 32k /// pub struct UrlEncoded { + #[cfg(feature = "compress")] stream: Option>, + #[cfg(not(feature = "compress"))] + stream: Option, limit: usize, length: Option, encoding: &'static Encoding, err: Option, - fut: Option>>, + fut: Option>>, } impl UrlEncoded { @@ -269,7 +277,11 @@ impl UrlEncoded { } }; + #[cfg(feature = "compress")] let payload = Decompress::from_headers(payload.take(), req.headers()); + #[cfg(not(feature = "compress"))] + let payload = payload.take(); + UrlEncoded { encoding, stream: Some(payload), @@ -302,42 +314,45 @@ impl Future for UrlEncoded where U: DeserializeOwned + 'static, { - type Item = U; - type Error = UrlencodedError; + type Output = Result; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if let Some(ref mut fut) = self.fut { - return fut.poll(); + return Pin::new(fut).poll(cx); } if let Some(err) = self.err.take() { - return Err(err); + return Poll::Ready(Err(err)); } // payload size let limit = self.limit; if let Some(len) = self.length.take() { if len > limit { - return Err(UrlencodedError::Overflow); + return Poll::Ready(Err(UrlencodedError::Overflow { size: len, limit })); } } // future let encoding = self.encoding; - let fut = self - .stream - .take() - .unwrap() - .from_err() - .fold(BytesMut::with_capacity(8192), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(UrlencodedError::Overflow) - } else { - body.extend_from_slice(&chunk); - Ok(body) + let mut stream = self.stream.take().unwrap(); + + self.fut = Some( + async move { + let mut body = BytesMut::with_capacity(8192); + + while let Some(item) = stream.next().await { + let chunk = item?; + if (body.len() + chunk.len()) > limit { + return Err(UrlencodedError::Overflow { + size: body.len() + chunk.len(), + limit, + }); + } else { + body.extend_from_slice(&chunk); + } } - }) - .and_then(move |body| { + if encoding == UTF_8 { serde_urlencoded::from_bytes::(&body) .map_err(|_| UrlencodedError::Parse) @@ -349,9 +364,10 @@ where serde_urlencoded::from_str::(&body) .map_err(|_| UrlencodedError::Parse) } - }); - self.fut = Some(Box::new(fut)); - self.poll() + } + .boxed_local(), + ); + self.poll(cx) } } @@ -362,7 +378,7 @@ mod tests { use super::*; use crate::http::header::{HeaderValue, CONTENT_TYPE}; - use crate::test::{block_on, TestRequest}; + use crate::test::TestRequest; #[derive(Deserialize, Serialize, Debug, PartialEq)] struct Info { @@ -370,15 +386,15 @@ mod tests { counter: i64, } - #[test] - fn test_form() { + #[actix_rt::test] + async fn test_form() { let (req, mut pl) = TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") .header(CONTENT_LENGTH, "11") .set_payload(Bytes::from_static(b"hello=world&counter=123")) .to_http_parts(); - let Form(s) = block_on(Form::::from_request(&req, &mut pl)).unwrap(); + let Form(s) = Form::::from_request(&req, &mut pl).await.unwrap(); assert_eq!( s, Info { @@ -390,8 +406,8 @@ mod tests { fn eq(err: UrlencodedError, other: UrlencodedError) -> bool { match err { - UrlencodedError::Overflow => match other { - UrlencodedError::Overflow => true, + UrlencodedError::Overflow { .. } => match other { + UrlencodedError::Overflow { .. } => true, _ => false, }, UrlencodedError::UnknownLength => match other { @@ -406,38 +422,41 @@ mod tests { } } - #[test] - fn test_urlencoded_error() { + #[actix_rt::test] + async fn test_urlencoded_error() { let (req, mut pl) = TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") .header(CONTENT_LENGTH, "xxxx") .to_http_parts(); - let info = block_on(UrlEncoded::::new(&req, &mut pl)); + let info = UrlEncoded::::new(&req, &mut pl).await; assert!(eq(info.err().unwrap(), UrlencodedError::UnknownLength)); let (req, mut pl) = TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") .header(CONTENT_LENGTH, "1000000") .to_http_parts(); - let info = block_on(UrlEncoded::::new(&req, &mut pl)); - assert!(eq(info.err().unwrap(), UrlencodedError::Overflow)); + let info = UrlEncoded::::new(&req, &mut pl).await; + assert!(eq( + info.err().unwrap(), + UrlencodedError::Overflow { size: 0, limit: 0 } + )); let (req, mut pl) = TestRequest::with_header(CONTENT_TYPE, "text/plain") .header(CONTENT_LENGTH, "10") .to_http_parts(); - let info = block_on(UrlEncoded::::new(&req, &mut pl)); + let info = UrlEncoded::::new(&req, &mut pl).await; assert!(eq(info.err().unwrap(), UrlencodedError::ContentType)); } - #[test] - fn test_urlencoded() { + #[actix_rt::test] + async fn test_urlencoded() { let (req, mut pl) = TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") .header(CONTENT_LENGTH, "11") .set_payload(Bytes::from_static(b"hello=world&counter=123")) .to_http_parts(); - let info = block_on(UrlEncoded::::new(&req, &mut pl)).unwrap(); + let info = UrlEncoded::::new(&req, &mut pl).await.unwrap(); assert_eq!( info, Info { @@ -454,7 +473,7 @@ mod tests { .set_payload(Bytes::from_static(b"hello=world&counter=123")) .to_http_parts(); - let info = block_on(UrlEncoded::::new(&req, &mut pl)).unwrap(); + let info = UrlEncoded::::new(&req, &mut pl).await.unwrap(); assert_eq!( info, Info { @@ -464,15 +483,15 @@ mod tests { ); } - #[test] - fn test_responder() { + #[actix_rt::test] + async fn test_responder() { let req = TestRequest::default().to_http_request(); let form = Form(Info { hello: "world".to_string(), counter: 123, }); - let resp = form.respond_to(&req).unwrap(); + let resp = form.respond_to(&req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), diff --git a/src/types/json.rs b/src/types/json.rs index f309a3c5a..fb00bf7a6 100644 --- a/src/types/json.rs +++ b/src/types/json.rs @@ -1,10 +1,14 @@ //! Json extractor/responder +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use std::{fmt, ops}; use bytes::BytesMut; -use futures::{Future, Poll, Stream}; +use futures::future::{err, ok, FutureExt, LocalBoxFuture, Ready}; +use futures::StreamExt; use serde::de::DeserializeOwned; use serde::Serialize; use serde_json; @@ -12,6 +16,7 @@ use serde_json; use actix_http::http::{header::CONTENT_LENGTH, StatusCode}; use actix_http::{HttpMessage, Payload, Response}; +#[cfg(feature = "compress")] use crate::dev::Decompress; use crate::error::{Error, JsonPayloadError}; use crate::extract::FromRequest; @@ -33,8 +38,8 @@ use crate::responder::Responder; /// ## Example /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -42,7 +47,7 @@ use crate::responder::Responder; /// } /// /// /// deserialize `Info` from request's body -/// fn index(info: web::Json) -> String { +/// async fn index(info: web::Json) -> String { /// format!("Welcome {}!", info.username) /// } /// @@ -60,9 +65,9 @@ use crate::responder::Responder; /// trait from *serde*. /// /// ```rust -/// # #[macro_use] extern crate serde_derive; -/// # use actix_web::*; -/// # +/// use actix_web::*; +/// use serde_derive::Serialize; +/// /// #[derive(Serialize)] /// struct MyObj { /// name: String, @@ -102,7 +107,7 @@ impl fmt::Debug for Json where T: fmt::Debug, { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Json: {:?}", self.0) } } @@ -111,22 +116,22 @@ impl fmt::Display for Json where T: fmt::Display, { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.0, f) } } impl Responder for Json { type Error = Error; - type Future = Result; + type Future = Ready>; fn respond_to(self, _: &HttpRequest) -> Self::Future { let body = match serde_json::to_string(&self.0) { Ok(body) => body, - Err(e) => return Err(e.into()), + Err(e) => return err(e.into()), }; - Ok(Response::build(StatusCode::OK) + ok(Response::build(StatusCode::OK) .content_type("application/json") .body(body)) } @@ -144,8 +149,8 @@ impl Responder for Json { /// ## Example /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -153,7 +158,7 @@ impl Responder for Json { /// } /// /// /// deserialize `Info` from request's body -/// fn index(info: web::Json) -> String { +/// async fn index(info: web::Json) -> String { /// format!("Welcome {}!", info.username) /// } /// @@ -169,7 +174,7 @@ where T: DeserializeOwned + 'static, { type Error = Error; - type Future = Box>; + type Future = LocalBoxFuture<'static, Result>; type Config = JsonConfig; #[inline] @@ -180,31 +185,32 @@ where .map(|c| (c.limit, c.ehandler.clone(), c.content_type.clone())) .unwrap_or((32768, None, None)); - Box::new( - JsonBody::new(req, payload, ctype) - .limit(limit) - .map_err(move |e| { + JsonBody::new(req, payload, ctype) + .limit(limit) + .map(move |res| match res { + Err(e) => { log::debug!( "Failed to deserialize Json from payload. \ Request path: {}", req2.path() ); if let Some(err) = err { - (*err)(e, &req2) + Err((*err)(e, &req2)) } else { - e.into() + Err(e.into()) } - }) - .map(Json), - ) + } + Ok(data) => Ok(Json(data)), + }) + .boxed_local() } } /// Json extractor configuration /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{error, web, App, FromRequest, HttpResponse}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -212,23 +218,24 @@ where /// } /// /// /// deserialize `Info` from request's body, max payload size is 4kb -/// fn index(info: web::Json) -> String { +/// async fn index(info: web::Json) -> String { /// format!("Welcome {}!", info.username) /// } /// /// fn main() { /// let app = App::new().service( -/// web::resource("/index.html").data( -/// // change json extractor configuration -/// web::Json::::configure(|cfg| { -/// cfg.limit(4096) -/// .content_type(|mime| { // <- accept text/plain content type -/// mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN -/// }) -/// .error_handler(|err, req| { // <- create custom error response -/// error::InternalError::from_response( -/// err, HttpResponse::Conflict().finish()).into() -/// }) +/// web::resource("/index.html") +/// .app_data( +/// // change json extractor configuration +/// web::Json::::configure(|cfg| { +/// cfg.limit(4096) +/// .content_type(|mime| { // <- accept text/plain content type +/// mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN +/// }) +/// .error_handler(|err, req| { // <- create custom error response +/// error::InternalError::from_response( +/// err, HttpResponse::Conflict().finish()).into() +/// }) /// })) /// .route(web::post().to(index)) /// ); @@ -288,9 +295,12 @@ impl Default for JsonConfig { pub struct JsonBody { limit: usize, length: Option, + #[cfg(feature = "compress")] stream: Option>, + #[cfg(not(feature = "compress"))] + stream: Option, err: Option, - fut: Option>>, + fut: Option>>, } impl JsonBody @@ -327,7 +337,11 @@ where .get(&CONTENT_LENGTH) .and_then(|l| l.to_str().ok()) .and_then(|s| s.parse::().ok()); + + #[cfg(feature = "compress")] let payload = Decompress::from_headers(payload.take(), req.headers()); + #[cfg(not(feature = "compress"))] + let payload = payload.take(); JsonBody { limit: 262_144, @@ -349,41 +363,43 @@ impl Future for JsonBody where U: DeserializeOwned + 'static, { - type Item = U; - type Error = JsonPayloadError; + type Output = Result; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if let Some(ref mut fut) = self.fut { - return fut.poll(); + return Pin::new(fut).poll(cx); } if let Some(err) = self.err.take() { - return Err(err); + return Poll::Ready(Err(err)); } let limit = self.limit; if let Some(len) = self.length.take() { if len > limit { - return Err(JsonPayloadError::Overflow); + return Poll::Ready(Err(JsonPayloadError::Overflow)); } } + let mut stream = self.stream.take().unwrap(); - let fut = self - .stream - .take() - .unwrap() - .from_err() - .fold(BytesMut::with_capacity(8192), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(JsonPayloadError::Overflow) - } else { - body.extend_from_slice(&chunk); - Ok(body) + self.fut = Some( + async move { + let mut body = BytesMut::with_capacity(8192); + + while let Some(item) = stream.next().await { + let chunk = item?; + if (body.len() + chunk.len()) > limit { + return Err(JsonPayloadError::Overflow); + } else { + body.extend_from_slice(&chunk); + } } - }) - .and_then(|body| Ok(serde_json::from_slice::(&body)?)); - self.fut = Some(Box::new(fut)); - self.poll() + Ok(serde_json::from_slice::(&body)?) + } + .boxed_local(), + ); + + self.poll(cx) } } @@ -395,7 +411,7 @@ mod tests { use super::*; use crate::error::InternalError; use crate::http::header; - use crate::test::{block_on, TestRequest}; + use crate::test::{load_stream, TestRequest}; use crate::HttpResponse; #[derive(Serialize, Deserialize, PartialEq, Debug)] @@ -417,14 +433,14 @@ mod tests { } } - #[test] - fn test_responder() { + #[actix_rt::test] + async fn test_responder() { let req = TestRequest::default().to_http_request(); let j = Json(MyObject { name: "test".to_string(), }); - let resp = j.respond_to(&req).unwrap(); + let resp = j.respond_to(&req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -435,8 +451,8 @@ mod tests { assert_eq!(resp.body().bin_ref(), b"{\"name\":\"test\"}"); } - #[test] - fn test_custom_error_responder() { + #[actix_rt::test] + async fn test_custom_error_responder() { let (req, mut pl) = TestRequest::default() .header( header::CONTENT_TYPE, @@ -447,7 +463,7 @@ mod tests { header::HeaderValue::from_static("16"), ) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .data(JsonConfig::default().limit(10).error_handler(|err, _| { + .app_data(JsonConfig::default().limit(10).error_handler(|err, _| { let msg = MyObject { name: "invalid request".to_string(), }; @@ -457,17 +473,17 @@ mod tests { })) .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)); + let s = Json::::from_request(&req, &mut pl).await; let mut resp = Response::from_error(s.err().unwrap().into()); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let body = block_on(resp.take_body().concat2()).unwrap(); + let body = load_stream(resp.take_body()).await.unwrap(); let msg: MyObject = serde_json::from_slice(&body).unwrap(); assert_eq!(msg.name, "invalid request"); } - #[test] - fn test_extract() { + #[actix_rt::test] + async fn test_extract() { let (req, mut pl) = TestRequest::default() .header( header::CONTENT_TYPE, @@ -480,7 +496,7 @@ mod tests { .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)).unwrap(); + let s = Json::::from_request(&req, &mut pl).await.unwrap(); assert_eq!(s.name, "test"); assert_eq!( s.into_inner(), @@ -499,10 +515,10 @@ mod tests { header::HeaderValue::from_static("16"), ) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .data(JsonConfig::default().limit(10)) + .app_data(JsonConfig::default().limit(10)) .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)); + let s = Json::::from_request(&req, &mut pl).await; assert!(format!("{}", s.err().unwrap()) .contains("Json payload size is bigger than allowed")); @@ -516,20 +532,20 @@ mod tests { header::HeaderValue::from_static("16"), ) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .data( + .app_data( JsonConfig::default() .limit(10) .error_handler(|_, _| JsonPayloadError::ContentType.into()), ) .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)); + let s = Json::::from_request(&req, &mut pl).await; assert!(format!("{}", s.err().unwrap()).contains("Content type error")); } - #[test] - fn test_json_body() { + #[actix_rt::test] + async fn test_json_body() { let (req, mut pl) = TestRequest::default().to_http_parts(); - let json = block_on(JsonBody::::new(&req, &mut pl, None)); + let json = JsonBody::::new(&req, &mut pl, None).await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let (req, mut pl) = TestRequest::default() @@ -538,7 +554,7 @@ mod tests { header::HeaderValue::from_static("application/text"), ) .to_http_parts(); - let json = block_on(JsonBody::::new(&req, &mut pl, None)); + let json = JsonBody::::new(&req, &mut pl, None).await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let (req, mut pl) = TestRequest::default() @@ -552,7 +568,9 @@ mod tests { ) .to_http_parts(); - let json = block_on(JsonBody::::new(&req, &mut pl, None).limit(100)); + let json = JsonBody::::new(&req, &mut pl, None) + .limit(100) + .await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::Overflow)); let (req, mut pl) = TestRequest::default() @@ -567,7 +585,7 @@ mod tests { .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .to_http_parts(); - let json = block_on(JsonBody::::new(&req, &mut pl, None)); + let json = JsonBody::::new(&req, &mut pl, None).await; assert_eq!( json.ok().unwrap(), MyObject { @@ -576,8 +594,8 @@ mod tests { ); } - #[test] - fn test_with_json_and_bad_content_type() { + #[actix_rt::test] + async fn test_with_json_and_bad_content_type() { let (req, mut pl) = TestRequest::with_header( header::CONTENT_TYPE, header::HeaderValue::from_static("text/plain"), @@ -587,15 +605,15 @@ mod tests { header::HeaderValue::from_static("16"), ) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .data(JsonConfig::default().limit(4096)) + .app_data(JsonConfig::default().limit(4096)) .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)); + let s = Json::::from_request(&req, &mut pl).await; assert!(s.is_err()) } - #[test] - fn test_with_json_and_good_custom_content_type() { + #[actix_rt::test] + async fn test_with_json_and_good_custom_content_type() { let (req, mut pl) = TestRequest::with_header( header::CONTENT_TYPE, header::HeaderValue::from_static("text/plain"), @@ -605,17 +623,17 @@ mod tests { header::HeaderValue::from_static("16"), ) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .data(JsonConfig::default().content_type(|mime: mime::Mime| { + .app_data(JsonConfig::default().content_type(|mime: mime::Mime| { mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN })) .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)); + let s = Json::::from_request(&req, &mut pl).await; assert!(s.is_ok()) } - #[test] - fn test_with_json_and_bad_custom_content_type() { + #[actix_rt::test] + async fn test_with_json_and_bad_custom_content_type() { let (req, mut pl) = TestRequest::with_header( header::CONTENT_TYPE, header::HeaderValue::from_static("text/html"), @@ -625,12 +643,12 @@ mod tests { header::HeaderValue::from_static("16"), ) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .data(JsonConfig::default().content_type(|mime: mime::Mime| { + .app_data(JsonConfig::default().content_type(|mime: mime::Mime| { mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN })) .to_http_parts(); - let s = block_on(Json::::from_request(&req, &mut pl)); + let s = Json::::from_request(&req, &mut pl).await; assert!(s.is_err()) } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 43a189e2c..b32711e2a 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -12,3 +12,4 @@ pub use self::json::{Json, JsonConfig}; pub use self::path::{Path, PathConfig}; pub use self::payload::{Payload, PayloadConfig}; pub use self::query::{Query, QueryConfig}; +pub use self::readlines::Readlines; diff --git a/src/types/path.rs b/src/types/path.rs index a46575764..a37cb8f12 100644 --- a/src/types/path.rs +++ b/src/types/path.rs @@ -1,10 +1,10 @@ //! Path extractor - use std::sync::Arc; use std::{fmt, ops}; use actix_http::error::{Error, ErrorNotFound}; use actix_router::PathDeserializer; +use futures::future::{ready, Ready}; use serde::de; use crate::dev::Payload; @@ -15,6 +15,8 @@ use crate::FromRequest; #[derive(PartialEq, Eq, PartialOrd, Ord)] /// Extract typed information from the request's path. /// +/// [**PathConfig**](struct.PathConfig.html) allows to configure extraction process. +/// /// ## Example /// /// ```rust @@ -23,7 +25,7 @@ use crate::FromRequest; /// /// extract path info from "/{username}/{count}/index.html" url /// /// {username} - deserializes to a String /// /// {count} - - deserializes to a u32 -/// fn index(info: web::Path<(String, u32)>) -> String { +/// async fn index(info: web::Path<(String, u32)>) -> String { /// format!("Welcome {}! {}", info.0, info.1) /// } /// @@ -39,8 +41,8 @@ use crate::FromRequest; /// implements `Deserialize` trait from *serde*. /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App, Error}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -48,7 +50,7 @@ use crate::FromRequest; /// } /// /// /// extract `Info` from a path using serde -/// fn index(info: web::Path) -> Result { +/// async fn index(info: web::Path) -> Result { /// Ok(format!("Welcome {}!", info.username)) /// } /// @@ -97,13 +99,13 @@ impl From for Path { } impl fmt::Debug for Path { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.inner.fmt(f) } } impl fmt::Display for Path { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.inner.fmt(f) } } @@ -118,7 +120,7 @@ impl fmt::Display for Path { /// /// extract path info from "/{username}/{count}/index.html" url /// /// {username} - deserializes to a String /// /// {count} - - deserializes to a u32 -/// fn index(info: web::Path<(String, u32)>) -> String { +/// async fn index(info: web::Path<(String, u32)>) -> String { /// format!("Welcome {}! {}", info.0, info.1) /// } /// @@ -134,8 +136,8 @@ impl fmt::Display for Path { /// implements `Deserialize` trait from *serde*. /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App, Error}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -143,7 +145,7 @@ impl fmt::Display for Path { /// } /// /// /// extract `Info` from a path using serde -/// fn index(info: web::Path) -> Result { +/// async fn index(info: web::Path) -> Result { /// Ok(format!("Welcome {}!", info.username)) /// } /// @@ -159,7 +161,7 @@ where T: de::DeserializeOwned, { type Error = Error; - type Future = Result; + type Future = Ready>; type Config = PathConfig; #[inline] @@ -169,31 +171,32 @@ where .map(|c| c.ehandler.clone()) .unwrap_or(None); - de::Deserialize::deserialize(PathDeserializer::new(req.match_info())) - .map(|inner| Path { inner }) - .map_err(move |e| { - log::debug!( - "Failed during Path extractor deserialization. \ - Request path: {:?}", - req.path() - ); - if let Some(error_handler) = error_handler { - let e = PathError::Deserialize(e); - (error_handler)(e, req) - } else { - ErrorNotFound(e) - } - }) + ready( + de::Deserialize::deserialize(PathDeserializer::new(req.match_info())) + .map(|inner| Path { inner }) + .map_err(move |e| { + log::debug!( + "Failed during Path extractor deserialization. \ + Request path: {:?}", + req.path() + ); + if let Some(error_handler) = error_handler { + let e = PathError::Deserialize(e); + (error_handler)(e, req) + } else { + ErrorNotFound(e) + } + }), + ) } } /// Path extractor configuration /// /// ```rust -/// # #[macro_use] -/// # extern crate serde_derive; /// use actix_web::web::PathConfig; /// use actix_web::{error, web, App, FromRequest, HttpResponse}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize, Debug)] /// enum Folder { @@ -204,14 +207,14 @@ where /// } /// /// // deserialize `Info` from request's path -/// fn index(folder: web::Path) -> String { +/// async fn index(folder: web::Path) -> String { /// format!("Selected folder: {:?}!", folder) /// } /// /// fn main() { /// let app = App::new().service( /// web::resource("/messages/{folder}") -/// .data(PathConfig::default().error_handler(|err, req| { +/// .app_data(PathConfig::default().error_handler(|err, req| { /// error::InternalError::from_response( /// err, /// HttpResponse::Conflict().finish(), @@ -251,7 +254,7 @@ mod tests { use serde_derive::Deserialize; use super::*; - use crate::test::{block_on, TestRequest}; + use crate::test::TestRequest; use crate::{error, http, HttpResponse}; #[derive(Deserialize, Debug, Display)] @@ -267,54 +270,54 @@ mod tests { value: u32, } - #[test] - fn test_extract_path_single() { + #[actix_rt::test] + async fn test_extract_path_single() { let resource = ResourceDef::new("/{value}/"); let mut req = TestRequest::with_uri("/32/").to_srv_request(); resource.match_path(req.match_info_mut()); let (req, mut pl) = req.into_parts(); - assert_eq!(*Path::::from_request(&req, &mut pl).unwrap(), 32); - assert!(Path::::from_request(&req, &mut pl).is_err()); + assert_eq!(*Path::::from_request(&req, &mut pl).await.unwrap(), 32); + assert!(Path::::from_request(&req, &mut pl).await.is_err()); } - #[test] - fn test_tuple_extract() { + #[actix_rt::test] + async fn test_tuple_extract() { let resource = ResourceDef::new("/{key}/{value}/"); let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); resource.match_path(req.match_info_mut()); let (req, mut pl) = req.into_parts(); - let res = - block_on(<(Path<(String, String)>,)>::from_request(&req, &mut pl)).unwrap(); + let res = <(Path<(String, String)>,)>::from_request(&req, &mut pl) + .await + .unwrap(); assert_eq!((res.0).0, "name"); assert_eq!((res.0).1, "user1"); - let res = block_on( - <(Path<(String, String)>, Path<(String, String)>)>::from_request( - &req, &mut pl, - ), + let res = <(Path<(String, String)>, Path<(String, String)>)>::from_request( + &req, &mut pl, ) + .await .unwrap(); assert_eq!((res.0).0, "name"); assert_eq!((res.0).1, "user1"); assert_eq!((res.1).0, "name"); assert_eq!((res.1).1, "user1"); - let () = <()>::from_request(&req, &mut pl).unwrap(); + let () = <()>::from_request(&req, &mut pl).await.unwrap(); } - #[test] - fn test_request_extract() { + #[actix_rt::test] + async fn test_request_extract() { let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); let resource = ResourceDef::new("/{key}/{value}/"); resource.match_path(req.match_info_mut()); let (req, mut pl) = req.into_parts(); - let mut s = Path::::from_request(&req, &mut pl).unwrap(); + let mut s = Path::::from_request(&req, &mut pl).await.unwrap(); assert_eq!(s.key, "name"); assert_eq!(s.value, "user1"); s.value = "user2".to_string(); @@ -326,7 +329,9 @@ mod tests { let s = s.into_inner(); assert_eq!(s.value, "user2"); - let s = Path::<(String, String)>::from_request(&req, &mut pl).unwrap(); + let s = Path::<(String, String)>::from_request(&req, &mut pl) + .await + .unwrap(); assert_eq!(s.0, "name"); assert_eq!(s.1, "user1"); @@ -335,23 +340,27 @@ mod tests { resource.match_path(req.match_info_mut()); let (req, mut pl) = req.into_parts(); - let s = Path::::from_request(&req, &mut pl).unwrap(); + let s = Path::::from_request(&req, &mut pl).await.unwrap(); assert_eq!(s.as_ref().key, "name"); assert_eq!(s.value, 32); - let s = Path::<(String, u8)>::from_request(&req, &mut pl).unwrap(); + let s = Path::<(String, u8)>::from_request(&req, &mut pl) + .await + .unwrap(); assert_eq!(s.0, "name"); assert_eq!(s.1, 32); - let res = Path::>::from_request(&req, &mut pl).unwrap(); + let res = Path::>::from_request(&req, &mut pl) + .await + .unwrap(); assert_eq!(res[0], "name".to_owned()); assert_eq!(res[1], "32".to_owned()); } - #[test] - fn test_custom_err_handler() { + #[actix_rt::test] + async fn test_custom_err_handler() { let (req, mut pl) = TestRequest::with_uri("/name/user1/") - .data(PathConfig::default().error_handler(|err, _| { + .app_data(PathConfig::default().error_handler(|err, _| { error::InternalError::from_response( err, HttpResponse::Conflict().finish(), @@ -360,7 +369,9 @@ mod tests { })) .to_http_parts(); - let s = block_on(Path::<(usize,)>::from_request(&req, &mut pl)).unwrap_err(); + let s = Path::<(usize,)>::from_request(&req, &mut pl) + .await + .unwrap_err(); let res: HttpResponse = s.into(); assert_eq!(res.status(), http::StatusCode::CONFLICT); diff --git a/src/types/payload.rs b/src/types/payload.rs index f33e2e5f1..449e6c5b0 100644 --- a/src/types/payload.rs +++ b/src/types/payload.rs @@ -1,12 +1,15 @@ //! Payload/Bytes/String extractors +use std::future::Future; +use std::pin::Pin; use std::str; +use std::task::{Context, Poll}; use actix_http::error::{Error, ErrorBadRequest, PayloadError}; use actix_http::HttpMessage; use bytes::{Bytes, BytesMut}; use encoding_rs::UTF_8; -use futures::future::{err, Either, FutureResult}; -use futures::{Future, Poll, Stream}; +use futures::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready}; +use futures::{Stream, StreamExt}; use mime::Mime; use crate::dev; @@ -19,39 +22,46 @@ use crate::request::HttpRequest; /// ## Example /// /// ```rust -/// use futures::{Future, Stream}; +/// use futures::{Future, Stream, StreamExt}; /// use actix_web::{web, error, App, Error, HttpResponse}; /// /// /// extract binary data from request -/// fn index(body: web::Payload) -> impl Future +/// async fn index(mut body: web::Payload) -> Result /// { -/// body.map_err(Error::from) -/// .fold(web::BytesMut::new(), move |mut body, chunk| { -/// body.extend_from_slice(&chunk); -/// Ok::<_, Error>(body) -/// }) -/// .and_then(|body| { -/// format!("Body {:?}!", body); -/// Ok(HttpResponse::Ok().finish()) -/// }) +/// let mut bytes = web::BytesMut::new(); +/// while let Some(item) = body.next().await { +/// bytes.extend_from_slice(&item?); +/// } +/// +/// format!("Body {:?}!", bytes); +/// Ok(HttpResponse::Ok().finish()) /// } /// /// fn main() { /// let app = App::new().service( /// web::resource("/index.html").route( -/// web::get().to_async(index)) +/// web::get().to(index)) /// ); /// } /// ``` -pub struct Payload(crate::dev::Payload); +pub struct Payload(pub crate::dev::Payload); + +impl Payload { + /// Deconstruct to a inner value + pub fn into_inner(self) -> crate::dev::Payload { + self.0 + } +} impl Stream for Payload { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; #[inline] - fn poll(&mut self) -> Poll, PayloadError> { - self.0.poll() + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_next(cx) } } @@ -60,38 +70,36 @@ impl Stream for Payload { /// ## Example /// /// ```rust -/// use futures::{Future, Stream}; +/// use futures::{Future, Stream, StreamExt}; /// use actix_web::{web, error, App, Error, HttpResponse}; /// /// /// extract binary data from request -/// fn index(body: web::Payload) -> impl Future +/// async fn index(mut body: web::Payload) -> Result /// { -/// body.map_err(Error::from) -/// .fold(web::BytesMut::new(), move |mut body, chunk| { -/// body.extend_from_slice(&chunk); -/// Ok::<_, Error>(body) -/// }) -/// .and_then(|body| { -/// format!("Body {:?}!", body); -/// Ok(HttpResponse::Ok().finish()) -/// }) +/// let mut bytes = web::BytesMut::new(); +/// while let Some(item) = body.next().await { +/// bytes.extend_from_slice(&item?); +/// } +/// +/// format!("Body {:?}!", bytes); +/// Ok(HttpResponse::Ok().finish()) /// } /// /// fn main() { /// let app = App::new().service( /// web::resource("/index.html").route( -/// web::get().to_async(index)) +/// web::get().to(index)) /// ); /// } /// ``` impl FromRequest for Payload { type Config = PayloadConfig; type Error = Error; - type Future = Result; + type Future = Ready>; #[inline] fn from_request(_: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { - Ok(Payload(payload.take())) + ok(Payload(payload.take())) } } @@ -109,7 +117,7 @@ impl FromRequest for Payload { /// use actix_web::{web, App}; /// /// /// extract binary data from request -/// fn index(body: Bytes) -> String { +/// async fn index(body: Bytes) -> String { /// format!("Body {:?}!", body) /// } /// @@ -123,8 +131,10 @@ impl FromRequest for Payload { impl FromRequest for Bytes { type Config = PayloadConfig; type Error = Error; - type Future = - Either>, FutureResult>; + type Future = Either< + LocalBoxFuture<'static, Result>, + Ready>, + >; #[inline] fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { @@ -137,13 +147,12 @@ impl FromRequest for Bytes { }; if let Err(e) = cfg.check_mimetype(req) { - return Either::B(err(e)); + return Either::Right(err(e)); } let limit = cfg.limit; - Either::A(Box::new( - HttpMessageBody::new(req, payload).limit(limit).from_err(), - )) + let fut = HttpMessageBody::new(req, payload).limit(limit); + Either::Left(async move { Ok(fut.await?) }.boxed_local()) } } @@ -160,14 +169,14 @@ impl FromRequest for Bytes { /// use actix_web::{web, App, FromRequest}; /// /// /// extract text data from request -/// fn index(text: String) -> String { +/// async fn index(text: String) -> String { /// format!("Body {}!", text) /// } /// /// fn main() { /// let app = App::new().service( /// web::resource("/index.html") -/// .data(String::configure(|cfg| { // <- limit size of the payload +/// .app_data(String::configure(|cfg| { // <- limit size of the payload /// cfg.limit(4096) /// })) /// .route(web::get().to(index)) // <- register handler with extractor params @@ -178,8 +187,8 @@ impl FromRequest for String { type Config = PayloadConfig; type Error = Error; type Future = Either< - Box>, - FutureResult, + LocalBoxFuture<'static, Result>, + Ready>, >; #[inline] @@ -194,33 +203,34 @@ impl FromRequest for String { // check content-type if let Err(e) = cfg.check_mimetype(req) { - return Either::B(err(e)); + return Either::Right(err(e)); } // check charset let encoding = match req.encoding() { Ok(enc) => enc, - Err(e) => return Either::B(err(e.into())), + Err(e) => return Either::Right(err(e.into())), }; let limit = cfg.limit; + let fut = HttpMessageBody::new(req, payload).limit(limit); - Either::A(Box::new( - HttpMessageBody::new(req, payload) - .limit(limit) - .from_err() - .and_then(move |body| { - if encoding == UTF_8 { - Ok(str::from_utf8(body.as_ref()) - .map_err(|_| ErrorBadRequest("Can not decode body"))? - .to_owned()) - } else { - Ok(encoding - .decode_without_bom_handling_and_without_replacement(&body) - .map(|s| s.into_owned()) - .ok_or_else(|| ErrorBadRequest("Can not decode body"))?) - } - }), - )) + Either::Left( + async move { + let body = fut.await?; + + if encoding == UTF_8 { + Ok(str::from_utf8(body.as_ref()) + .map_err(|_| ErrorBadRequest("Can not decode body"))? + .to_owned()) + } else { + Ok(encoding + .decode_without_bom_handling_and_without_replacement(&body) + .map(|s| s.into_owned()) + .ok_or_else(|| ErrorBadRequest("Can not decode body"))?) + } + } + .boxed_local(), + ) } } /// Payload configuration for request's payload. @@ -291,9 +301,12 @@ impl Default for PayloadConfig { pub struct HttpMessageBody { limit: usize, length: Option, + #[cfg(feature = "compress")] stream: Option>, + #[cfg(not(feature = "compress"))] + stream: Option, err: Option, - fut: Option>>, + fut: Option>>, } impl HttpMessageBody { @@ -312,8 +325,13 @@ impl HttpMessageBody { } } + #[cfg(feature = "compress")] + let stream = Some(dev::Decompress::from_headers(payload.take(), req.headers())); + #[cfg(not(feature = "compress"))] + let stream = Some(payload.take()); + HttpMessageBody { - stream: Some(dev::Decompress::from_headers(payload.take(), req.headers())), + stream, limit: 262_144, length: len, fut: None, @@ -339,42 +357,43 @@ impl HttpMessageBody { } impl Future for HttpMessageBody { - type Item = Bytes; - type Error = PayloadError; + type Output = Result; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if let Some(ref mut fut) = self.fut { - return fut.poll(); + return Pin::new(fut).poll(cx); } if let Some(err) = self.err.take() { - return Err(err); + return Poll::Ready(Err(err)); } if let Some(len) = self.length.take() { if len > self.limit { - return Err(PayloadError::Overflow); + return Poll::Ready(Err(PayloadError::Overflow)); } } // future let limit = self.limit; - self.fut = Some(Box::new( - self.stream - .take() - .unwrap() - .from_err() - .fold(BytesMut::with_capacity(8192), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(PayloadError::Overflow) + let mut stream = self.stream.take().unwrap(); + self.fut = Some( + async move { + let mut body = BytesMut::with_capacity(8192); + + while let Some(item) = stream.next().await { + let chunk = item?; + if body.len() + chunk.len() > limit { + return Err(PayloadError::Overflow); } else { body.extend_from_slice(&chunk); - Ok(body) } - }) - .map(|body| body.freeze()), - )); - self.poll() + } + Ok(body.freeze()) + } + .boxed_local(), + ); + self.poll(cx) } } @@ -384,10 +403,10 @@ mod tests { use super::*; use crate::http::header; - use crate::test::{block_on, TestRequest}; + use crate::test::TestRequest; - #[test] - fn test_payload_config() { + #[actix_rt::test] + async fn test_payload_config() { let req = TestRequest::default().to_http_request(); let cfg = PayloadConfig::default().mimetype(mime::APPLICATION_JSON); assert!(cfg.check_mimetype(&req).is_err()); @@ -404,32 +423,32 @@ mod tests { assert!(cfg.check_mimetype(&req).is_ok()); } - #[test] - fn test_bytes() { + #[actix_rt::test] + async fn test_bytes() { let (req, mut pl) = TestRequest::with_header(header::CONTENT_LENGTH, "11") .set_payload(Bytes::from_static(b"hello=world")) .to_http_parts(); - let s = block_on(Bytes::from_request(&req, &mut pl)).unwrap(); + let s = Bytes::from_request(&req, &mut pl).await.unwrap(); assert_eq!(s, Bytes::from_static(b"hello=world")); } - #[test] - fn test_string() { + #[actix_rt::test] + async fn test_string() { let (req, mut pl) = TestRequest::with_header(header::CONTENT_LENGTH, "11") .set_payload(Bytes::from_static(b"hello=world")) .to_http_parts(); - let s = block_on(String::from_request(&req, &mut pl)).unwrap(); + let s = String::from_request(&req, &mut pl).await.unwrap(); assert_eq!(s, "hello=world"); } - #[test] - fn test_message_body() { + #[actix_rt::test] + async fn test_message_body() { let (req, mut pl) = TestRequest::with_header(header::CONTENT_LENGTH, "xxxx") .to_srv_request() .into_parts(); - let res = block_on(HttpMessageBody::new(&req, &mut pl)); + let res = HttpMessageBody::new(&req, &mut pl).await; match res.err().unwrap() { PayloadError::UnknownLength => (), _ => unreachable!("error"), @@ -438,7 +457,7 @@ mod tests { let (req, mut pl) = TestRequest::with_header(header::CONTENT_LENGTH, "1000000") .to_srv_request() .into_parts(); - let res = block_on(HttpMessageBody::new(&req, &mut pl)); + let res = HttpMessageBody::new(&req, &mut pl).await; match res.err().unwrap() { PayloadError::Overflow => (), _ => unreachable!("error"), @@ -447,13 +466,13 @@ mod tests { let (req, mut pl) = TestRequest::default() .set_payload(Bytes::from_static(b"test")) .to_http_parts(); - let res = block_on(HttpMessageBody::new(&req, &mut pl)); + let res = HttpMessageBody::new(&req, &mut pl).await; assert_eq!(res.ok().unwrap(), Bytes::from_static(b"test")); let (req, mut pl) = TestRequest::default() .set_payload(Bytes::from_static(b"11111111111111")) .to_http_parts(); - let res = block_on(HttpMessageBody::new(&req, &mut pl).limit(5)); + let res = HttpMessageBody::new(&req, &mut pl).limit(5).await; match res.err().unwrap() { PayloadError::Overflow => (), _ => unreachable!("error"), diff --git a/src/types/query.rs b/src/types/query.rs index 60b07085d..a6c18d9be 100644 --- a/src/types/query.rs +++ b/src/types/query.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use std::{fmt, ops}; use actix_http::error::Error; +use futures::future::{err, ok, Ready}; use serde::de; use serde_urlencoded; @@ -18,11 +19,13 @@ use crate::request::HttpRequest; /// be decoded into any type which depends upon data ordering e.g. tuples or tuple-structs. /// Attempts to do so will *fail at runtime*. /// +/// [**QueryConfig**](struct.QueryConfig.html) allows to configure extraction process. +/// /// ## Example /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App}; +/// use serde_derive::Deserialize; /// /// #[derive(Debug, Deserialize)] /// pub enum ResponseType { @@ -39,7 +42,7 @@ use crate::request::HttpRequest; /// // Use `Query` extractor for query information (and destructure it within the signature). /// // This handler gets called only if the request's query string contains a `username` field. /// // The correct request for this handler would be `/index.html?id=64&response_type=Code"`. -/// fn index(web::Query(info): web::Query) -> String { +/// async fn index(web::Query(info): web::Query) -> String { /// format!("Authorization request for client with id={} and type={:?}!", info.id, info.response_type) /// } /// @@ -83,13 +86,13 @@ impl ops::DerefMut for Query { } impl fmt::Debug for Query { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } impl fmt::Display for Query { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } @@ -99,8 +102,8 @@ impl fmt::Display for Query { /// ## Example /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{web, App}; +/// use serde_derive::Deserialize; /// /// #[derive(Debug, Deserialize)] /// pub enum ResponseType { @@ -117,7 +120,7 @@ impl fmt::Display for Query { /// // Use `Query` extractor for query information. /// // This handler get called only if request's query contains `username` field /// // The correct request for this handler would be `/index.html?id=64&response_type=Code"` -/// fn index(info: web::Query) -> String { +/// async fn index(info: web::Query) -> String { /// format!("Authorization request for client with id={} and type={:?}!", info.id, info.response_type) /// } /// @@ -132,7 +135,7 @@ where T: de::DeserializeOwned, { type Error = Error; - type Future = Result; + type Future = Ready>; type Config = QueryConfig; #[inline] @@ -143,7 +146,7 @@ where .unwrap_or(None); serde_urlencoded::from_str::(req.query_string()) - .map(|val| Ok(Query(val))) + .map(|val| ok(Query(val))) .unwrap_or_else(move |e| { let e = QueryPayloadError::Deserialize(e); @@ -159,7 +162,7 @@ where e.into() }; - Err(e) + err(e) }) } } @@ -169,8 +172,8 @@ where /// ## Example /// /// ```rust -/// #[macro_use] extern crate serde_derive; /// use actix_web::{error, web, App, FromRequest, HttpResponse}; +/// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -178,13 +181,13 @@ where /// } /// /// /// deserialize `Info` from request's querystring -/// fn index(info: web::Query) -> String { +/// async fn index(info: web::Query) -> String { /// format!("Welcome {}!", info.username) /// } /// /// fn main() { /// let app = App::new().service( -/// web::resource("/index.html").data( +/// web::resource("/index.html").app_data( /// // change query extractor configuration /// web::Query::::configure(|cfg| { /// cfg.error_handler(|err, req| { // <- create custom error response @@ -235,8 +238,8 @@ mod tests { id: String, } - #[test] - fn test_service_request_extract() { + #[actix_rt::test] + async fn test_service_request_extract() { let req = TestRequest::with_uri("/name/user1/").to_srv_request(); assert!(Query::::from_query(&req.query_string()).is_err()); @@ -251,16 +254,16 @@ mod tests { assert_eq!(s.id, "test1"); } - #[test] - fn test_request_extract() { + #[actix_rt::test] + async fn test_request_extract() { let req = TestRequest::with_uri("/name/user1/").to_srv_request(); let (req, mut pl) = req.into_parts(); - assert!(Query::::from_request(&req, &mut pl).is_err()); + assert!(Query::::from_request(&req, &mut pl).await.is_err()); let req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); let (req, mut pl) = req.into_parts(); - let mut s = Query::::from_request(&req, &mut pl).unwrap(); + let mut s = Query::::from_request(&req, &mut pl).await.unwrap(); assert_eq!(s.id, "test"); assert_eq!(format!("{}, {:?}", s, s), "test, Id { id: \"test\" }"); @@ -269,17 +272,17 @@ mod tests { assert_eq!(s.id, "test1"); } - #[test] - fn test_custom_error_responder() { + #[actix_rt::test] + async fn test_custom_error_responder() { let req = TestRequest::with_uri("/name/user1/") - .data(QueryConfig::default().error_handler(|e, _| { + .app_data(QueryConfig::default().error_handler(|e, _| { let resp = HttpResponse::UnprocessableEntity().finish(); InternalError::from_response(e, resp).into() })) .to_srv_request(); let (req, mut pl) = req.into_parts(); - let query = Query::::from_request(&req, &mut pl); + let query = Query::::from_request(&req, &mut pl).await; assert!(query.is_err()); assert_eq!( diff --git a/src/types/readlines.rs b/src/types/readlines.rs index cea63e43b..82853381b 100644 --- a/src/types/readlines.rs +++ b/src/types/readlines.rs @@ -1,9 +1,11 @@ use std::borrow::Cow; +use std::pin::Pin; use std::str; +use std::task::{Context, Poll}; use bytes::{Bytes, BytesMut}; use encoding_rs::{Encoding, UTF_8}; -use futures::{Async, Poll, Stream}; +use futures::Stream; use crate::dev::Payload; use crate::error::{PayloadError, ReadlinesError}; @@ -22,7 +24,7 @@ pub struct Readlines { impl Readlines where T: HttpMessage, - T::Stream: Stream, + T::Stream: Stream> + Unpin, { /// Create a new stream to read request line by line. pub fn new(req: &mut T) -> Self { @@ -62,20 +64,24 @@ where impl Stream for Readlines where T: HttpMessage, - T::Stream: Stream, + T::Stream: Stream> + Unpin, { - type Item = String; - type Error = ReadlinesError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - if let Some(err) = self.err.take() { - return Err(err); + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + + if let Some(err) = this.err.take() { + return Poll::Ready(Some(Err(err))); } // check if there is a newline in the buffer - if !self.checked_buff { + if !this.checked_buff { let mut found: Option = None; - for (ind, b) in self.buff.iter().enumerate() { + for (ind, b) in this.buff.iter().enumerate() { if *b == b'\n' { found = Some(ind); break; @@ -83,28 +89,28 @@ where } if let Some(ind) = found { // check if line is longer than limit - if ind + 1 > self.limit { - return Err(ReadlinesError::LimitOverflow); + if ind + 1 > this.limit { + return Poll::Ready(Some(Err(ReadlinesError::LimitOverflow))); } - let line = if self.encoding == UTF_8 { - str::from_utf8(&self.buff.split_to(ind + 1)) + let line = if this.encoding == UTF_8 { + str::from_utf8(&this.buff.split_to(ind + 1)) .map_err(|_| ReadlinesError::EncodingError)? .to_owned() } else { - self.encoding + this.encoding .decode_without_bom_handling_and_without_replacement( - &self.buff.split_to(ind + 1), + &this.buff.split_to(ind + 1), ) .map(Cow::into_owned) .ok_or(ReadlinesError::EncodingError)? }; - return Ok(Async::Ready(Some(line))); + return Poll::Ready(Some(Ok(line))); } - self.checked_buff = true; + this.checked_buff = true; } // poll req for more bytes - match self.stream.poll() { - Ok(Async::Ready(Some(mut bytes))) => { + match Pin::new(&mut this.stream).poll_next(cx) { + Poll::Ready(Some(Ok(mut bytes))) => { // check if there is a newline in bytes let mut found: Option = None; for (ind, b) in bytes.iter().enumerate() { @@ -115,15 +121,15 @@ where } if let Some(ind) = found { // check if line is longer than limit - if ind + 1 > self.limit { - return Err(ReadlinesError::LimitOverflow); + if ind + 1 > this.limit { + return Poll::Ready(Some(Err(ReadlinesError::LimitOverflow))); } - let line = if self.encoding == UTF_8 { + let line = if this.encoding == UTF_8 { str::from_utf8(&bytes.split_to(ind + 1)) .map_err(|_| ReadlinesError::EncodingError)? .to_owned() } else { - self.encoding + this.encoding .decode_without_bom_handling_and_without_replacement( &bytes.split_to(ind + 1), ) @@ -131,46 +137,48 @@ where .ok_or(ReadlinesError::EncodingError)? }; // extend buffer with rest of the bytes; - self.buff.extend_from_slice(&bytes); - self.checked_buff = false; - return Ok(Async::Ready(Some(line))); + this.buff.extend_from_slice(&bytes); + this.checked_buff = false; + return Poll::Ready(Some(Ok(line))); } - self.buff.extend_from_slice(&bytes); - Ok(Async::NotReady) + this.buff.extend_from_slice(&bytes); + Poll::Pending } - Ok(Async::NotReady) => Ok(Async::NotReady), - Ok(Async::Ready(None)) => { - if self.buff.is_empty() { - return Ok(Async::Ready(None)); + Poll::Pending => Poll::Pending, + Poll::Ready(None) => { + if this.buff.is_empty() { + return Poll::Ready(None); } - if self.buff.len() > self.limit { - return Err(ReadlinesError::LimitOverflow); + if this.buff.len() > this.limit { + return Poll::Ready(Some(Err(ReadlinesError::LimitOverflow))); } - let line = if self.encoding == UTF_8 { - str::from_utf8(&self.buff) + let line = if this.encoding == UTF_8 { + str::from_utf8(&this.buff) .map_err(|_| ReadlinesError::EncodingError)? .to_owned() } else { - self.encoding - .decode_without_bom_handling_and_without_replacement(&self.buff) + this.encoding + .decode_without_bom_handling_and_without_replacement(&this.buff) .map(Cow::into_owned) .ok_or(ReadlinesError::EncodingError)? }; - self.buff.clear(); - Ok(Async::Ready(Some(line))) + this.buff.clear(); + Poll::Ready(Some(Ok(line))) } - Err(e) => Err(ReadlinesError::from(e)), + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(ReadlinesError::from(e)))), } } } #[cfg(test)] mod tests { - use super::*; - use crate::test::{block_on, TestRequest}; + use futures::stream::StreamExt; - #[test] - fn test_readlines() { + use super::*; + use crate::test::TestRequest; + + #[actix_rt::test] + async fn test_readlines() { let mut req = TestRequest::default() .set_payload(Bytes::from_static( b"Lorem Ipsum is simply dummy text of the printing and typesetting\n\ @@ -178,36 +186,21 @@ mod tests { Contrary to popular belief, Lorem Ipsum is not simply random text.", )) .to_request(); - let stream = match block_on(Readlines::new(&mut req).into_future()) { - Ok((Some(s), stream)) => { - assert_eq!( - s, - "Lorem Ipsum is simply dummy text of the printing and typesetting\n" - ); - stream - } - _ => unreachable!("error"), - }; - let stream = match block_on(stream.into_future()) { - Ok((Some(s), stream)) => { - assert_eq!( - s, - "industry. Lorem Ipsum has been the industry's standard dummy\n" - ); - stream - } - _ => unreachable!("error"), - }; + let mut stream = Readlines::new(&mut req); + assert_eq!( + stream.next().await.unwrap().unwrap(), + "Lorem Ipsum is simply dummy text of the printing and typesetting\n" + ); - match block_on(stream.into_future()) { - Ok((Some(s), _)) => { - assert_eq!( - s, - "Contrary to popular belief, Lorem Ipsum is not simply random text." - ); - } - _ => unreachable!("error"), - } + assert_eq!( + stream.next().await.unwrap().unwrap(), + "industry. Lorem Ipsum has been the industry's standard dummy\n" + ); + + assert_eq!( + stream.next().await.unwrap().unwrap(), + "Contrary to popular belief, Lorem Ipsum is not simply random text." + ); } } diff --git a/src/web.rs b/src/web.rs index 5669a1e86..50d99479a 100644 --- a/src/web.rs +++ b/src/web.rs @@ -1,13 +1,15 @@ //! Essentials helper functions and types for application registration. use actix_http::http::Method; -use futures::{Future, IntoFuture}; +use actix_router::IntoPattern; +use futures::Future; pub use actix_http::Response as HttpResponse; pub use bytes::{Bytes, BytesMut}; +pub use futures::channel::oneshot::Canceled; -use crate::error::{BlockingError, Error}; +use crate::error::BlockingError; use crate::extract::FromRequest; -use crate::handler::{AsyncFactory, Factory}; +use crate::handler::Factory; use crate::resource::Resource; use crate::responder::Responder; use crate::route::Route; @@ -43,15 +45,13 @@ pub use crate::types::*; /// # extern crate actix_web; /// use actix_web::{web, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/users/{userid}/{friend}") -/// .route(web::get().to(|| HttpResponse::Ok())) -/// .route(web::head().to(|| HttpResponse::MethodNotAllowed())) -/// ); -/// } +/// let app = App::new().service( +/// web::resource("/users/{userid}/{friend}") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// .route(web::head().to(|| HttpResponse::MethodNotAllowed())) +/// ); /// ``` -pub fn resource(path: &str) -> Resource { +pub fn resource(path: T) -> Resource { Resource::new(path) } @@ -63,14 +63,12 @@ pub fn resource(path: &str) -> Resource { /// ```rust /// use actix_web::{web, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new().service( -/// web::scope("/{project_id}") -/// .service(web::resource("/path1").to(|| HttpResponse::Ok())) -/// .service(web::resource("/path2").to(|| HttpResponse::Ok())) -/// .service(web::resource("/path3").to(|| HttpResponse::MethodNotAllowed())) -/// ); -/// } +/// let app = App::new().service( +/// web::scope("/{project_id}") +/// .service(web::resource("/path1").to(|| HttpResponse::Ok())) +/// .service(web::resource("/path2").to(|| HttpResponse::Ok())) +/// .service(web::resource("/path3").to(|| HttpResponse::MethodNotAllowed())) +/// ); /// ``` /// /// In the above example, three routes get added: @@ -92,12 +90,10 @@ pub fn route() -> Route { /// ```rust /// use actix_web::{web, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::get().to(|| HttpResponse::Ok())) -/// ); -/// } +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// ); /// ``` /// /// In the above example, one `GET` route get added: @@ -112,12 +108,10 @@ pub fn get() -> Route { /// ```rust /// use actix_web::{web, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::post().to(|| HttpResponse::Ok())) -/// ); -/// } +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::post().to(|| HttpResponse::Ok())) +/// ); /// ``` /// /// In the above example, one `POST` route get added: @@ -132,12 +126,10 @@ pub fn post() -> Route { /// ```rust /// use actix_web::{web, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::put().to(|| HttpResponse::Ok())) -/// ); -/// } +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::put().to(|| HttpResponse::Ok())) +/// ); /// ``` /// /// In the above example, one `PUT` route get added: @@ -152,12 +144,10 @@ pub fn put() -> Route { /// ```rust /// use actix_web::{web, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::patch().to(|| HttpResponse::Ok())) -/// ); -/// } +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::patch().to(|| HttpResponse::Ok())) +/// ); /// ``` /// /// In the above example, one `PATCH` route get added: @@ -172,12 +162,10 @@ pub fn patch() -> Route { /// ```rust /// use actix_web::{web, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::delete().to(|| HttpResponse::Ok())) -/// ); -/// } +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::delete().to(|| HttpResponse::Ok())) +/// ); /// ``` /// /// In the above example, one `DELETE` route get added: @@ -192,12 +180,10 @@ pub fn delete() -> Route { /// ```rust /// use actix_web::{web, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::head().to(|| HttpResponse::Ok())) -/// ); -/// } +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::head().to(|| HttpResponse::Ok())) +/// ); /// ``` /// /// In the above example, one `HEAD` route get added: @@ -212,12 +198,10 @@ pub fn head() -> Route { /// ```rust /// use actix_web::{web, http, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::method(http::Method::GET).to(|| HttpResponse::Ok())) -/// ); -/// } +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::method(http::Method::GET).to(|| HttpResponse::Ok())) +/// ); /// ``` /// /// In the above example, one `GET` route get added: @@ -230,10 +214,10 @@ pub fn method(method: Method) -> Route { /// Create a new route and add handler. /// /// ```rust -/// use actix_web::{web, App, HttpResponse}; +/// use actix_web::{web, App, HttpResponse, Responder}; /// -/// fn index() -> HttpResponse { -/// unimplemented!() +/// async fn index() -> impl Responder { +/// HttpResponse::Ok() /// } /// /// App::new().service( @@ -241,69 +225,42 @@ pub fn method(method: Method) -> Route { /// web::to(index)) /// ); /// ``` -pub fn to(handler: F) -> Route +pub fn to(handler: F) -> Route where - F: Factory + 'static, + F: Factory, I: FromRequest + 'static, - R: Responder + 'static, + R: Future + 'static, + U: Responder + 'static, { Route::new().to(handler) } -/// Create a new route and add async handler. -/// -/// ```rust -/// # use futures::future::{ok, Future}; -/// use actix_web::{web, App, HttpResponse, Error}; -/// -/// fn index() -> impl Future { -/// ok(HttpResponse::Ok().finish()) -/// } -/// -/// App::new().service(web::resource("/").route( -/// web::to_async(index)) -/// ); -/// ``` -pub fn to_async(handler: F) -> Route -where - F: AsyncFactory, - I: FromRequest + 'static, - R: IntoFuture + 'static, - R::Item: Responder, - R::Error: Into, -{ - Route::new().to_async(handler) -} - /// Create raw service for a specific path. /// /// ```rust -/// # extern crate actix_web; -/// use actix_web::{dev, web, guard, App, HttpResponse}; +/// use actix_web::{dev, web, guard, App, Error, HttpResponse}; /// -/// fn my_service(req: dev::ServiceRequest) -> dev::ServiceResponse { -/// req.into_response(HttpResponse::Ok().finish()) +/// async fn my_service(req: dev::ServiceRequest) -> Result { +/// Ok(req.into_response(HttpResponse::Ok().finish())) /// } /// -/// fn main() { -/// let app = App::new().service( -/// web::service("/users/*") -/// .guard(guard::Header("content-type", "text/plain")) -/// .finish(my_service) -/// ); -/// } +/// let app = App::new().service( +/// web::service("/users/*") +/// .guard(guard::Header("content-type", "text/plain")) +/// .finish(my_service) +/// ); /// ``` -pub fn service(path: &str) -> WebService { +pub fn service(path: T) -> WebService { WebService::new(path) } /// Execute blocking function on a thread pool, returns future that resolves /// to result of the function execution. -pub fn block(f: F) -> impl Future> +pub async fn block(f: F) -> Result> where F: FnOnce() -> Result + Send + 'static, I: Send + 'static, E: Send + std::fmt::Debug + 'static, { - actix_threadpool::run(f).from_err() + actix_threadpool::run(f).await } diff --git a/test-server/CHANGES.md b/test-server/CHANGES.md index c3fe5b285..5690afc64 100644 --- a/test-server/CHANGES.md +++ b/test-server/CHANGES.md @@ -1,9 +1,29 @@ # Changes +## [1.0.0] - 2019-12-13 + +### Changed + +* Replaced `TestServer::start()` with `test_server()` + + +## [1.0.0-alpha.3] - 2019-12-07 + +### Changed + +* Migrate to `std::future` + + +## [0.2.5] - 2019-09-17 ### Changed * Update serde_urlencoded to "0.6.1" +* Increase TestServerRuntime timeouts from 500ms to 3000ms + +### Fixed + +* Do not override current `System` ## [0.2.4] - 2019-07-18 diff --git a/test-server/Cargo.toml b/test-server/Cargo.toml index 22809c060..52a2da8da 100644 --- a/test-server/Cargo.toml +++ b/test-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-http-test" -version = "0.2.4" +version = "1.0.0" authors = ["Nikolay Kim "] description = "Actix http test server" readme = "README.md" @@ -27,21 +27,22 @@ path = "src/lib.rs" default = [] # openssl -ssl = ["openssl", "actix-server/ssl", "awc/ssl"] +openssl = ["open-ssl", "awc/openssl"] [dependencies] -actix-codec = "0.1.2" -actix-rt = "0.2.2" -actix-service = "0.4.1" -actix-server = "0.6.0" -actix-utils = "0.4.1" -awc = "0.2.2" -actix-connect = "0.2.2" +actix-service = "1.0.1" +actix-codec = "0.2.0" +actix-connect = "1.0.0" +actix-utils = "1.0.3" +actix-rt = "1.0.0" +actix-server = "1.0.0" +actix-testing = "1.0.0" +awc = "1.0.0" -base64 = "0.10" -bytes = "0.4" -futures = "0.1" -http = "0.1.8" +base64 = "0.11" +bytes = "0.5.3" +futures = "0.3.1" +http = "0.2.0" log = "0.4" env_logger = "0.6" net2 = "0.2" @@ -51,10 +52,8 @@ sha1 = "0.6" slab = "0.4" serde_urlencoded = "0.6.1" time = "0.1" -tokio-tcp = "0.1" -tokio-timer = "0.2" -openssl = { version="0.10", optional = true } +open-ssl = { version="0.10", package="openssl", optional = true } [dev-dependencies] -actix-web = "1.0.0" -actix-http = "0.2.4" +actix-web = "2.0.0-rc" +actix-http = "1.0.1" diff --git a/test-server/src/lib.rs b/test-server/src/lib.rs index aba53980c..27326c67a 100644 --- a/test-server/src/lib.rs +++ b/test-server/src/lib.rs @@ -1,75 +1,19 @@ //! Various helpers for Actix applications to use during testing. -use std::cell::RefCell; use std::sync::mpsc; use std::{net, thread, time}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use actix_rt::{Runtime, System}; -use actix_server::{Server, StreamServiceFactory}; +use actix_rt::{net::TcpStream, System}; +use actix_server::{Server, ServiceFactory}; use awc::{error::PayloadError, ws, Client, ClientRequest, ClientResponse, Connector}; use bytes::Bytes; -use futures::future::lazy; -use futures::{Future, IntoFuture, Stream}; +use futures::Stream; use http::Method; use net2::TcpBuilder; -use tokio_tcp::TcpStream; -thread_local! { - static RT: RefCell = { - RefCell::new(Inner(Some(Runtime::new().unwrap()))) - }; -} +pub use actix_testing::*; -struct Inner(Option); - -impl Inner { - fn get_mut(&mut self) -> &mut Runtime { - self.0.as_mut().unwrap() - } -} - -impl Drop for Inner { - fn drop(&mut self) { - std::mem::forget(self.0.take().unwrap()) - } -} - -/// Runs the provided future, blocking the current thread until the future -/// completes. -/// -/// This function can be used to synchronously block the current thread -/// until the provided `future` has resolved either successfully or with an -/// error. The result of the future is then returned from this function -/// call. -/// -/// Note that this function is intended to be used only for testing purpose. -/// This function panics on nested call. -pub fn block_on(f: F) -> Result -where - F: IntoFuture, -{ - RT.with(move |rt| rt.borrow_mut().get_mut().block_on(f.into_future())) -} - -/// Runs the provided function, blocking the current thread until the resul -/// future completes. -/// -/// This function can be used to synchronously block the current thread -/// until the provided `future` has resolved either successfully or with an -/// error. The result of the future is then returned from this function -/// call. -/// -/// Note that this function is intended to be used only for testing purpose. -/// This function panics on nested call. -pub fn block_fn(f: F) -> Result -where - F: FnOnce() -> R, - R: IntoFuture, -{ - RT.with(move |rt| rt.borrow_mut().get_mut().block_on(lazy(f))) -} - -/// The `TestServer` type. +/// Start test server /// /// `TestServer` is very simple test server that simplify process of writing /// integration tests cases for actix web applications. @@ -79,14 +23,15 @@ where /// ```rust /// use actix_http::HttpService; /// use actix_http_test::TestServer; -/// use actix_web::{web, App, HttpResponse}; +/// use actix_web::{web, App, HttpResponse, Error}; /// -/// fn my_handler() -> HttpResponse { -/// HttpResponse::Ok().into() +/// async fn my_handler() -> Result { +/// Ok(HttpResponse::Ok().into()) /// } /// -/// fn main() { -/// let mut srv = TestServer::new( +/// #[actix_rt::test] +/// async fn test_example() { +/// let mut srv = TestServer::start( /// || HttpService::new( /// App::new().service( /// web::resource("/").to(my_handler)) @@ -94,120 +39,86 @@ where /// ); /// /// let req = srv.get("/"); -/// let response = srv.block_on(req.send()).unwrap(); +/// let response = req.send().await.unwrap(); /// assert!(response.status().is_success()); /// } /// ``` -pub struct TestServer; +pub fn test_server>(factory: F) -> TestServer { + let (tx, rx) = mpsc::channel(); + + // run server in separate thread + thread::spawn(move || { + let sys = System::new("actix-test-server"); + let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); + let local_addr = tcp.local_addr().unwrap(); + + Server::build() + .listen("test", tcp, factory)? + .workers(1) + .disable_signals() + .start(); + + tx.send((System::current(), local_addr)).unwrap(); + sys.run() + }); + + let (system, addr) = rx.recv().unwrap(); + + let client = { + let connector = { + #[cfg(feature = "openssl")] + { + use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder + .set_alpn_protos(b"\x02h2\x08http/1.1") + .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); + Connector::new() + .conn_lifetime(time::Duration::from_secs(0)) + .timeout(time::Duration::from_millis(30000)) + .ssl(builder.build()) + .finish() + } + #[cfg(not(feature = "openssl"))] + { + Connector::new() + .conn_lifetime(time::Duration::from_secs(0)) + .timeout(time::Duration::from_millis(30000)) + .finish() + } + }; + + Client::build().connector(connector).finish() + }; + actix_connect::start_default_resolver(); + + TestServer { + addr, + client, + system, + } +} + +/// Get first available unused address +pub fn unused_addr() -> net::SocketAddr { + let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + let socket = TcpBuilder::new_v4().unwrap(); + socket.bind(&addr).unwrap(); + socket.reuse_address(true).unwrap(); + let tcp = socket.to_tcp_listener().unwrap(); + tcp.local_addr().unwrap() +} /// Test server controller -pub struct TestServerRuntime { +pub struct TestServer { addr: net::SocketAddr, - rt: Runtime, client: Client, + system: System, } impl TestServer { - #[allow(clippy::new_ret_no_self)] - /// Start new test server with application factory - pub fn new>(factory: F) -> TestServerRuntime { - let (tx, rx) = mpsc::channel(); - - // run server in separate thread - thread::spawn(move || { - let sys = System::new("actix-test-server"); - let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); - let local_addr = tcp.local_addr().unwrap(); - - Server::build() - .listen("test", tcp, factory)? - .workers(1) - .disable_signals() - .start(); - - tx.send((System::current(), local_addr)).unwrap(); - sys.run() - }); - - let (system, addr) = rx.recv().unwrap(); - let mut rt = Runtime::new().unwrap(); - - let client = rt - .block_on(lazy(move || { - let connector = { - #[cfg(feature = "ssl")] - { - use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; - - let mut builder = - SslConnector::builder(SslMethod::tls()).unwrap(); - builder.set_verify(SslVerifyMode::NONE); - let _ = builder.set_alpn_protos(b"\x02h2\x08http/1.1").map_err( - |e| log::error!("Can not set alpn protocol: {:?}", e), - ); - Connector::new() - .conn_lifetime(time::Duration::from_secs(0)) - .timeout(time::Duration::from_millis(500)) - .ssl(builder.build()) - .finish() - } - #[cfg(not(feature = "ssl"))] - { - Connector::new() - .conn_lifetime(time::Duration::from_secs(0)) - .timeout(time::Duration::from_millis(500)) - .finish() - } - }; - - Ok::(Client::build().connector(connector).finish()) - })) - .unwrap(); - rt.block_on(lazy( - || Ok::<_, ()>(actix_connect::start_default_resolver()), - )) - .unwrap(); - System::set_current(system); - TestServerRuntime { addr, rt, client } - } - - /// Get first available unused address - pub fn unused_addr() -> net::SocketAddr { - let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); - let socket = TcpBuilder::new_v4().unwrap(); - socket.bind(&addr).unwrap(); - socket.reuse_address(true).unwrap(); - let tcp = socket.to_tcp_listener().unwrap(); - tcp.local_addr().unwrap() - } -} - -impl TestServerRuntime { - /// Execute future on current core - pub fn block_on(&mut self, fut: F) -> Result - where - F: Future, - { - self.rt.block_on(fut) - } - - /// Execute future on current core - pub fn block_on_fn(&mut self, f: F) -> Result - where - F: FnOnce() -> R, - R: Future, - { - self.rt.block_on(lazy(f)) - } - - /// Execute function on current core - pub fn execute(&mut self, fut: F) -> R - where - F: FnOnce() -> R, - { - self.rt.block_on(lazy(|| Ok::<_, ()>(fut()))).unwrap() - } - /// Construct test server url pub fn addr(&self) -> net::SocketAddr { self.addr @@ -306,43 +217,42 @@ impl TestServerRuntime { self.client.request(method, path.as_ref()) } - pub fn load_body( + pub async fn load_body( &mut self, mut response: ClientResponse, ) -> Result where - S: Stream + 'static, + S: Stream> + Unpin + 'static, { - self.block_on(response.body().limit(10_485_760)) + response.body().limit(10_485_760).await } /// Connect to websocket server at a given path - pub fn ws_at( + pub async fn ws_at( &mut self, path: &str, ) -> Result, awc::error::WsClientError> { let url = self.url(path); let connect = self.client.ws(url).connect(); - self.rt - .block_on(lazy(move || connect.map(|(_, framed)| framed))) + connect.await.map(|(_, framed)| framed) } /// Connect to a websocket server - pub fn ws( + pub async fn ws( &mut self, ) -> Result, awc::error::WsClientError> { - self.ws_at("/") + self.ws_at("/").await } /// Stop http server fn stop(&mut self) { - System::current().stop(); + self.system.stop(); } } -impl Drop for TestServerRuntime { +impl Drop for TestServer { fn drop(&mut self) { self.stop() } diff --git a/tests/cert.pem b/tests/cert.pem index f9bb05081..0eeb6721d 100644 --- a/tests/cert.pem +++ b/tests/cert.pem @@ -1,32 +1,19 @@ -----BEGIN CERTIFICATE----- -MIIFfjCCA2agAwIBAgIJAOIBvp/w68KrMA0GCSqGSIb3DQEBCwUAMGsxCzAJBgNV -BAYTAlJVMRkwFwYDVQQIDBBTYWludC1QZXRlcnNidXJnMRkwFwYDVQQHDBBTYWlu -dC1QZXRlcnNidXJnMRIwEAYDVQQKDAlLdXBpYmlsZXQxEjAQBgNVBAMMCWxvY2Fs -aG9zdDAgFw0xOTA3MjcxODIzMTJaGA8zMDE5MDcyNzE4MjMxMlowazELMAkGA1UE -BhMCUlUxGTAXBgNVBAgMEFNhaW50LVBldGVyc2J1cmcxGTAXBgNVBAcMEFNhaW50 -LVBldGVyc2J1cmcxEjAQBgNVBAoMCUt1cGliaWxldDESMBAGA1UEAwwJbG9jYWxo -b3N0MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAuiQZzTO3gRRPr6ZH -wcmKqkoXig9taCCqx72Qvb9tvCLhQLE1dDPZV8I/r8bx+mM4Yz3r0Hm5LxTIhCM9 -p3/abuiJAZENC/VkxgFzBGg7KGLSFmzU+A8Ft+2mrKmj5MpIPBCxDeVg80TCQOJy -hj+NU3PpBo9nxTgxWNWO6X+ZovZohdp78fYLLtns8rxjug3FVzdPrrLnBvihkGlq -gfImkh+vZxMTj1OgtxyCOhdbO4Ol4jCbn7a5yIw+iixHOEgBQfTQopRP7z1PEUV2 -WIy2VEGzvQDlj2OyzH86T1IOFV5rz5MjdZuW0qNzeS0w3Jzgp/olSbIZLhGAaIk0 -gN7y9XvSHqs7rO0wW+467ico7+uP1ScGgPgJA5fGu7ahp7F7G3ZSoAqAcS60wYsX -kStoA3RWAuqste6aChv1tlgTt+Rhk8qjGhuh0ng2qVnTGyo2T3OCHB/c47Bcsp6L -xiyTCnQIPH3fh2iO/SC7gPw3jihPMCAQZYlkC3MhMk974rn2zs9cKtJ8ubnG2m5F -VFVYmduRqS/YQS/O802jVCFdc8KDmoeAYNuHzgRZjQv9018UUeW3jtWKnopJnWs5 -ae9pbtmYeOtc7OovOxT7J2AaVfUkHRhmlqWZMzEJTcZws0fRPTZDifFJ5LFWbZsC -zW4tCKBKvYM9eAnbb+abiHXlY1MCAwEAAaMjMCEwHwYDVR0RBBgwFoIJbG9jYWxo -b3N0ggkxMjcuMC4wLjEwDQYJKoZIhvcNAQELBQADggIBAC1EU4SVCfUKc7JbhYRf -P87F+8e13bBTLxevJdnTCH3Xw2AN8UPmwQ2vv9Mv2FMulMBQ7yLnQLGtgGUua2OE -XO+EdBBEKnglo9cwXGzU6qHhaiCeXZDM8s53qOOrD42XsDsY0nOoFYqDLW3WixP9 -f1fWbcEf6+ktlvqi/1/3R6QtQR+6LS43enbsYHq8aAP60NrpXxdXxEoUwW6Z/sje -XAQluH8jzledwJcY8bXRskAHZlE4kGlOVuGgnyI3BXyLiwB4g9smFzYIs98iAGmV -7ZBaR5IIiRCtoKBG+SngM7Log0bHphvFPjDDvgqWYiWaOHboYM60Y2Z/gRbcjuMU -WZX64jw29fa8UPFdtGTupt+iuO7iXnHnm0lBBK36rVdOvsZup76p6L4BXmFsRmFK -qJ2Zd8uWNPDq80Am0mYaAqENuIANHHJXX38SesC+QO+G2JZt6vCwkGk/Qune4GIg -1GwhvsDRfTQopSxg1rdPwPM7HWeTfUGHZ34B5p/iILA3o6PfYQU8fNAWIsCDkRX2 -MrgDgCnLZxKb6pjR4DYNAdPwkxyMFACZ2T46z6WvLWFlnkK5nbZoqsOsp+GJHole -llafhrelXEzt3zFR0q4zGcqheJDI+Wy+fBy3XawgAc4eN0T2UCzL/jKxKgzlzSU3 -+xh1SDNjFLRd6sGzZHPMgXN0 +MIIDEDCCAfgCCQCQdmIZc/Ib/jANBgkqhkiG9w0BAQsFADBKMQswCQYDVQQGEwJ1 +czELMAkGA1UECAwCY2ExCzAJBgNVBAcMAnNmMSEwHwYJKoZIhvcNAQkBFhJmYWZo +cmQ5MUBnbWFpbC5jb20wHhcNMTkxMTE5MTEwNjU1WhcNMjkxMTE2MTEwNjU1WjBK +MQswCQYDVQQGEwJ1czELMAkGA1UECAwCY2ExCzAJBgNVBAcMAnNmMSEwHwYJKoZI +hvcNAQkBFhJmYWZocmQ5MUBnbWFpbC5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IB +DwAwggEKAoIBAQDcnaz12CKzUL7248V7Axhms/O9UQXfAdw0yolEfC3P5jADa/1C ++kLWKjAc2coqDSbGsrsR6KiH2g06Kunx+tSGqUO+Sct7HEehmxndiSwx/hfMWezy +XRe/olcHFTeCk/Tllz4xGEplhPua6GLhJygLOhAMiV8cwCYrgyPqsDduExLDFCqc +K2xntIPreumXpiE3QY4+MWyteiJko4IWDFf/UwwsdCY5MlFfw1F/Uv9vz7FfOfvu +GccHd/ex8cOwotUqd6emZb+0bVE24Sv8U+yLnHIVx/tOkxgMAnJEpAnf2G3Wp3zU +b2GJosbmfGaf+xTfnGGhTLLL7kCtva+NvZr5AgMBAAEwDQYJKoZIhvcNAQELBQAD +ggEBANftoL8zDGrjCwWvct8kOOqset2ukK8vjIGwfm88CKsy0IfSochNz2qeIu9R +ZuO7c0pfjmRkir9ZQdq9vXgG3ccL9UstFsferPH9W3YJ83kgXg3fa0EmCiN/0hwz +6Ij1ZBiN1j3+d6+PJPgyYFNu2nGwox5mJ9+aRAGe0/9c63PEOY8P2TI4HsiPmYSl +fFR8k/03vr6e+rTKW85BgctjvYKe/TnFxeCQ7dZ+na7vlEtch4tNmy6O/vEk2kCt +5jW0DUxhmRsv2wGmfFRI0+LotHjoXQQZi6nN5aGL3odaGF3gYwIVlZNd3AdkwDQz +BzG0ZwXuDDV9bSs3MfWEWcy4xuU= -----END CERTIFICATE----- diff --git a/tests/key.pem b/tests/key.pem index 70153c8ae..a6d308168 100644 --- a/tests/key.pem +++ b/tests/key.pem @@ -1,52 +1,28 @@ -----BEGIN PRIVATE KEY----- -MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQC6JBnNM7eBFE+v -pkfByYqqSheKD21oIKrHvZC9v228IuFAsTV0M9lXwj+vxvH6YzhjPevQebkvFMiE -Iz2nf9pu6IkBkQ0L9WTGAXMEaDsoYtIWbNT4DwW37aasqaPkykg8ELEN5WDzRMJA -4nKGP41Tc+kGj2fFODFY1Y7pf5mi9miF2nvx9gsu2ezyvGO6DcVXN0+usucG+KGQ -aWqB8iaSH69nExOPU6C3HII6F1s7g6XiMJuftrnIjD6KLEc4SAFB9NCilE/vPU8R -RXZYjLZUQbO9AOWPY7LMfzpPUg4VXmvPkyN1m5bSo3N5LTDcnOCn+iVJshkuEYBo -iTSA3vL1e9Ieqzus7TBb7jruJyjv64/VJwaA+AkDl8a7tqGnsXsbdlKgCoBxLrTB -ixeRK2gDdFYC6qy17poKG/W2WBO35GGTyqMaG6HSeDapWdMbKjZPc4IcH9zjsFyy -novGLJMKdAg8fd+HaI79ILuA/DeOKE8wIBBliWQLcyEyT3viufbOz1wq0ny5ucba -bkVUVViZ25GpL9hBL87zTaNUIV1zwoOah4Bg24fOBFmNC/3TXxRR5beO1Yqeikmd -azlp72lu2Zh461zs6i87FPsnYBpV9SQdGGaWpZkzMQlNxnCzR9E9NkOJ8UnksVZt -mwLNbi0IoEq9gz14Cdtv5puIdeVjUwIDAQABAoICAQCZVVezw+BsAjFKPi1qIv2J -HZOadO7pEc/czflHdUON8SWgxtmDqZpmQmt3/ugiHE284qs4hqzXbcVnpCgLrLRh -HEiP887NhQ3IVjVK8hmZQR5SvsAIv0c0ph3gqbWKqF8sq4tOKR/eBUwHawJwODXR -AvB4KPWQbqOny/P3wNbseRLNAJeNT+MSaw5XPnzgLKvdFoEbJeBNy847Sbsk5DaF -tHgm7n30WS1Q6bkU5VyP//hMBUKNJFaSL4TtCWB5qkbu8B5VbtsR9m0FizTb6L3h -VmYbUXvIzJXjAwMjiDJ1w9wHl+tj3BE33tEmhuVzNf+SH+tLc9xuKJighDWt2vpD -eTpZ1qest26ANLOmNXWVCVTGpcWvOu5yhG/P7En10EzjFruMfHAFdwLm1gMx1rlR -9fyNAk/0ROJ+5BUtuWgDiyytS5f2T9KGiOHni7UbBIkv0CV2H6VL39Twxf+3OHnx -JJ7OWZ8DRuLM/EJfN3C1+3eDsXOvcdvbo2TFBmCCl4Pa2pm4k3g2NBfxy/zSYWIh -ccGPZorFKNMUi29U0Ool6fxeVflbll570pWVBLAB31HdkLSESv9h+2j/IiEJcJXj -nzl2RtYB0Uxzk6SjO0z4WXjz/SXg5tQQkm/dx8kM8TvHICFq68AEnw8t9Hagsdxs -v5jNyOEeI1I5gPgZmKuboQKCAQEA7Hw6s8Xc3UtNaywMcK1Eb1O+kwRdztgtm0uE -uqsHWmGqbBxXN4jiVLh3dILIXFuVbvDSsSZtPLhOj1wqxgsTHg93b4BtvowyNBgo -X4tErMu7/6NRael/hfOEdtpfv2gV+0eQa+8KKqYJPbqpMz/r5L/3RaxS3iXkj3QM -6oC4+cRuwy/flPMIpxhDriH5yjfiMOdDsi3ZfMTJu/58DTrKV7WkJxQZmha4EoZn -IiXeRhzo+2ukMDWrr3GGPyDfjd/NB7rmY8QBdmhB5NSk+6B66JCNTIbKka/pichS -36bwSYFNji4NaHUUlYDUjfKoTNuQMEZknMGhc/433ADO7s17iQKCAQEAyYBYVG7/ -LE2IkvQy9Nwly5tRCNlSvSkloz7PUwRbzG5uF5mweWEa8YECJe9/vrFXvyBW+NR8 -XABFn4eG0POTR9nyb4n2nUlqiGugDIPgkrKCkJws5InifITZ/+Viocd4YZL5UwCU -R1/kMf0UjK2iJjWEeTPS6RmwRI2Iu7kym9BzphDyNYBQSbUE/f+4hNP6nUT/h09c -VH4/sUhubSgVKeK4onOci8bKitAkwVBYCYSyhuBCeCu8fTk2hVRWviRaJPVq2PMB -LHw1FCcfJLIPJG6MZpFAPkMQxpiewdggXIgi46ZlZcsNXEJ81ocT4GU2j+ArQXCf -lgEycyD3mx4k+wKCAQBGneohmKoVYtEheavVUcgnvkggOqOQirlDsE9YNo4hjRyI -4AWjTbrYNaVmI0+VVLvQvxULVUA1a4v5/zm+nbv9s/ykTSN4TQEI0VXtAfdl6gif -k7NR/ynXZBpgK2GAFKLLwFj+Agl1JtOHnV+9MA9O5Yv/QDAWqhYQSEU7GWkjHGc+ -3eLT5ablzrcXHooqunlOxSBP6qURPupGuv1sLewSOOllyfjDLJmW3o+ZgNlY8nUX -7tK+mqhD4ZCG9VgMU5I0BrmZfQQ6yXMz09PYV9mb7N5kxbNjwbXpMOqeYolKSdRQ -6quST7Pv2OKf6KAdI0txPvP4Y1HFA1rG1W71nGKRAoIBAHlDU+T8J3Rx9I77hu70 -zYoKnmnE35YW/R+Q3RQIu3X7vyVUyG9DkQNlr/VEfIw2Dahnve9hcLWtNDkdRnTZ -IPlMoCmfzVo6pHIU0uy1MKEX7Js6YYnnsPVevhLR6NmTQU73NDRPVOzfOGUc+RDw -LXTxIBgQqAy/+ORIiNDwUxSSDgcSi7DG14qD9c0l59WH/HpI276Cc/4lPA9kl4/5 -X0MlvheFm+BCcgG34Wa1A0Y3JXkl3NqU94oktDro1or3NYioaPTGyR4MYaUPJh7f -SV2TacsP/ql5ks7xahkeB9un0ddOfBcWa6PqH1a7U6rnPj63mVB4hpGvhrziSiB/ -s6ECggEAOp2P4Yd9Vm9/CptxS50HFF4adyLscAtsDd3S2hIAXhDovcPbvRek4pLQ -idPhHlRAfqrEztnhaVAmCK9HlhgthtiQGQX62YI4CS4QL2IhzDFo3M1a2snjFEdl -QuFk3XI7kQ0Yp8BLLG7T436JUrUkCXc4gQX2uRNut+ff34RIR2CjcQQjChxuHVeG -sP/3xFFj8OSs7ZoSPbmDBLrMOl64YHwezQUNAZiRYiaGbFiY0QUV6dHq8qX/qE1h -a/0Rq+gTqObDST0TqhMzI8V/i7R8SwVcD5ODHaZp5I2N2P/hV5OWY7ghQXhh89WM -o21xtGh0nP2Fq1TC6jFO+9cpbK8jNA== +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDcnaz12CKzUL72 +48V7Axhms/O9UQXfAdw0yolEfC3P5jADa/1C+kLWKjAc2coqDSbGsrsR6KiH2g06 +Kunx+tSGqUO+Sct7HEehmxndiSwx/hfMWezyXRe/olcHFTeCk/Tllz4xGEplhPua +6GLhJygLOhAMiV8cwCYrgyPqsDduExLDFCqcK2xntIPreumXpiE3QY4+MWyteiJk +o4IWDFf/UwwsdCY5MlFfw1F/Uv9vz7FfOfvuGccHd/ex8cOwotUqd6emZb+0bVE2 +4Sv8U+yLnHIVx/tOkxgMAnJEpAnf2G3Wp3zUb2GJosbmfGaf+xTfnGGhTLLL7kCt +va+NvZr5AgMBAAECggEBAKoU0UwzVgVCQgca8Jt2dnBvWYDhnxIfYAI/BvaKedMm +1ms87OKfB7oOiksjyI0E2JklH72dzZf2jm4CuZt5UjGC+xwPzlTaJ4s6hQVbBHyC +NRyxU1BCXtW5tThbrhD4OjxqjmLRJEIB9OunLtwAEQoeuFLB8Va7+HFhR+Zd9k3f +7aVA93pC5A50NRbZlke4miJ3Q8n7ZF0+UmxkBfm3fbqLk7aMWkoEKwLLTadjRlu1 +bBp0YDStX66I/p1kujqBOdh6VpPvxFOa1sV9pq0jeiGc9YfSkzRSKzIn8GoyviFB +fHeszQdNlcnrSDSNnMABAw+ZpxUO7SCaftjwejEmKZUCgYEA+TY43VpmV95eY7eo +WKwGepiHE0fwQLuKGELmZdZI80tFi73oZMuiB5WzwmkaKGcJmm7KGE9KEvHQCo9j +xvmktBR0VEZH8pmVfun+4h6+0H7m/NKMBBeOyv/IK8jBgHjkkB6e6nmeR7CqTxCw +tf9tbajl1QN8gNzXZSjBDT/lanMCgYEA4qANOKOSiEARtgwyXQeeSJcM2uPv6zF3 +ffM7vjSedtuEOHUSVeyBP/W8KDt7zyPppO/WNbURHS+HV0maS9yyj6zpVS2HGmbs +3fetswsQ+zYVdokW89x4oc2z4XOGHd1LcSlyhRwPt0u2g1E9L0irwTQLWU0npFmG +PRf7sN9+LeMCgYAGkDUDL2ROoB6gRa/7Vdx90hKMoXJkYgwLA4gJ2pDlR3A3c/Lw +5KQJyxmG3zm/IqeQF6be6QesZA30mT4peV2rGHbP2WH/s6fKReNelSy1VQJEWk8x +tGUgV4gwDwN5nLV4TjYlOrq+bJqvpmLhCC8bmj0jVQosYqSRl3cuICasnQKBgGlV +VO/Xb1su1EyWPK5qxRIeSxZOTYw2sMB01nbgxCqge0M2fvA6/hQ5ZlwY0cIEgits +YlcSMsMq/TAAANxz1vbaupUhlSMbZcsBvNV0Nk9c4vr2Wxm7hsJF9u66IEMvQUp2 +pkjiMxfR9CHzF4orr9EcHI5EQ0Grbq5kwFKEfoRbAoGAcWoFPILeJOlp2yW/Ds3E +g2fQdI9BAamtEZEaslJmZMmsDTg5ACPcDkOSFEQIaJ7wLPXeZy74FVk/NrY5F8Gz +bjX9OD/xzwp852yW5L9r62vYJakAlXef5jI6CFdYKDDCcarU0S7W5k6kq9n+wrBR +i1NklYmUAMr2q59uJA5zsic= -----END PRIVATE KEY----- diff --git a/tests/test_httpserver.rs b/tests/test_httpserver.rs index c0d2e81c4..ecd5c9ffb 100644 --- a/tests/test_httpserver.rs +++ b/tests/test_httpserver.rs @@ -2,11 +2,10 @@ use net2::TcpBuilder; use std::sync::mpsc; use std::{net, thread, time::Duration}; -#[cfg(feature = "ssl")] -use openssl::ssl::SslAcceptorBuilder; +#[cfg(feature = "openssl")] +use open_ssl::ssl::SslAcceptorBuilder; -use actix_http::Response; -use actix_web::{test, web, App, HttpServer}; +use actix_web::{web, App, HttpResponse, HttpServer}; fn unused_addr() -> net::SocketAddr { let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); @@ -17,9 +16,9 @@ fn unused_addr() -> net::SocketAddr { tcp.local_addr().unwrap() } -#[test] #[cfg(unix)] -fn test_start() { +#[actix_rt::test] +async fn test_start() { let addr = unused_addr(); let (tx, rx) = mpsc::channel(); @@ -28,7 +27,7 @@ fn test_start() { let srv = HttpServer::new(|| { App::new().service( - web::resource("/").route(web::to(|| Response::Ok().body("test"))), + web::resource("/").route(web::to(|| HttpResponse::Ok().body("test"))), ) }) .workers(1) @@ -43,7 +42,7 @@ fn test_start() { .disable_signals() .bind(format!("{}", addr)) .unwrap() - .start(); + .run(); let _ = tx.send((srv, actix_rt::System::current())); let _ = sys.run(); @@ -53,23 +52,17 @@ fn test_start() { #[cfg(feature = "client")] { use actix_http::client; - use actix_web::test; - let client = test::run_on(|| { - Ok::<_, ()>( - awc::Client::build() - .connector( - client::Connector::new() - .timeout(Duration::from_millis(100)) - .finish(), - ) + let client = awc::Client::build() + .connector( + client::Connector::new() + .timeout(Duration::from_millis(100)) .finish(), ) - }) - .unwrap(); - let host = format!("http://{}", addr); + .finish(); - let response = test::block_on(client.get(host.clone()).send()).unwrap(); + let host = format!("http://{}", addr); + let response = client.get(host.clone()).send().await.unwrap(); assert!(response.status().is_success()); } @@ -80,9 +73,9 @@ fn test_start() { let _ = sys.stop(); } -#[cfg(feature = "ssl")] +#[cfg(feature = "openssl")] fn ssl_acceptor() -> std::io::Result { - use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; + use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; // load ssl keys let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); builder @@ -94,9 +87,11 @@ fn ssl_acceptor() -> std::io::Result { Ok(builder) } -#[test] -#[cfg(feature = "ssl")] -fn test_start_ssl() { +#[actix_rt::test] +#[cfg(feature = "openssl")] +async fn test_start_ssl() { + use actix_web::HttpRequest; + let addr = unused_addr(); let (tx, rx) = mpsc::channel(); @@ -105,46 +100,42 @@ fn test_start_ssl() { let builder = ssl_acceptor().unwrap(); let srv = HttpServer::new(|| { - App::new().service( - web::resource("/").route(web::to(|| Response::Ok().body("test"))), - ) + App::new().service(web::resource("/").route(web::to(|req: HttpRequest| { + assert!(req.app_config().secure()); + HttpResponse::Ok().body("test") + }))) }) .workers(1) .shutdown_timeout(1) .system_exit() .disable_signals() - .bind_ssl(format!("{}", addr), builder) + .bind_openssl(format!("{}", addr), builder) .unwrap() - .start(); + .run(); let _ = tx.send((srv, actix_rt::System::current())); let _ = sys.run(); }); let (srv, sys) = rx.recv().unwrap(); - let client = test::run_on(|| { - use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; - let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - builder.set_verify(SslVerifyMode::NONE); - let _ = builder - .set_alpn_protos(b"\x02h2\x08http/1.1") - .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); + use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder + .set_alpn_protos(b"\x02h2\x08http/1.1") + .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); - Ok::<_, ()>( - awc::Client::build() - .connector( - awc::Connector::new() - .ssl(builder.build()) - .timeout(Duration::from_millis(100)) - .finish(), - ) + let client = awc::Client::build() + .connector( + awc::Connector::new() + .ssl(builder.build()) + .timeout(Duration::from_millis(100)) .finish(), ) - }) - .unwrap(); - let host = format!("https://{}", addr); + .finish(); - let response = test::block_on(client.get(host.clone()).send()).unwrap(); + let host = format!("https://{}", addr); + let response = client.get(host.clone()).send().await.unwrap(); assert!(response.status().is_success()); // stop diff --git a/tests/test_server.rs b/tests/test_server.rs index 1623d2ef3..1916b372c 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -1,24 +1,22 @@ use std::io::{Read, Write}; -use std::sync::mpsc; -use std::thread; +use std::pin::Pin; +use std::task::{Context, Poll}; use actix_http::http::header::{ ContentEncoding, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING, }; -use actix_http::{h1, Error, HttpService, Response}; -use actix_http_test::TestServer; use brotli2::write::{BrotliDecoder, BrotliEncoder}; use bytes::Bytes; use flate2::read::GzDecoder; use flate2::write::{GzEncoder, ZlibDecoder, ZlibEncoder}; use flate2::Compression; -use futures::stream::once; +use futures::{ready, Future}; use rand::{distributions::Alphanumeric, Rng}; -use actix_connect::start_default_resolver; -use actix_web::middleware::{BodyEncoding, Compress}; -use actix_web::{dev, http, test, web, App, HttpResponse, HttpServer}; +use actix_web::dev::BodyEncoding; +use actix_web::middleware::Compress; +use actix_web::{dev, test, web, App, Error, HttpResponse}; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ @@ -42,46 +40,76 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World"; -#[test] -fn test_body() { - let mut srv = TestServer::new(|| { - h1::H1Service::new( - App::new() - .service(web::resource("/").route(web::to(|| Response::Ok().body(STR)))), - ) +struct TestBody { + data: Bytes, + chunk_size: usize, + delay: actix_rt::time::Delay, +} + +impl TestBody { + fn new(data: Bytes, chunk_size: usize) -> Self { + TestBody { + data, + chunk_size, + delay: actix_rt::time::delay_for(std::time::Duration::from_millis(10)), + } + } +} + +impl futures::Stream for TestBody { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + ready!(Pin::new(&mut self.delay).poll(cx)); + + self.delay = actix_rt::time::delay_for(std::time::Duration::from_millis(10)); + let chunk_size = std::cmp::min(self.chunk_size, self.data.len()); + let chunk = self.data.split_to(chunk_size); + if chunk.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(chunk))) + } + } +} + +#[actix_rt::test] +async fn test_body() { + let srv = test::start(|| { + App::new() + .service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) }); - let mut response = srv.block_on(srv.get("/").send()).unwrap(); + let mut response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -#[test] -fn test_body_gzip() { - let mut srv = TestServer::new(|| { - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service(web::resource("/").route(web::to(|| Response::Ok().body(STR)))), - ) +#[actix_rt::test] +async fn test_body_gzip() { + let srv = test::start_with(test::config().h1(), || { + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) }); let mut response = srv - .block_on( - srv.get("/") - .no_decompress() - .header(ACCEPT_ENCODING, "gzip") - .send(), - ) + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode let mut e = GzDecoder::new(&bytes[..]); @@ -90,31 +118,27 @@ fn test_body_gzip() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -#[test] -fn test_body_gzip2() { - let mut srv = TestServer::new(|| { - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service(web::resource("/").route(web::to(|| { - Response::Ok().body(STR).into_body::() - }))), - ) +#[actix_rt::test] +async fn test_body_gzip2() { + let srv = test::start_with(test::config().h1(), || { + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service(web::resource("/").route(web::to(|| { + HttpResponse::Ok().body(STR).into_body::() + }))) }); let mut response = srv - .block_on( - srv.get("/") - .no_decompress() - .header(ACCEPT_ENCODING, "gzip") - .send(), - ) + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode let mut e = GzDecoder::new(&bytes[..]); @@ -123,41 +147,39 @@ fn test_body_gzip2() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -#[test] -fn test_body_encoding_override() { - let mut srv = TestServer::new(|| { - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service(web::resource("/").route(web::to(|| { - Response::Ok().encoding(ContentEncoding::Deflate).body(STR) - }))) - .service(web::resource("/raw").route(web::to(|| { - let body = actix_web::dev::Body::Bytes(STR.into()); - let mut response = - Response::with_body(actix_web::http::StatusCode::OK, body); +#[actix_rt::test] +async fn test_body_encoding_override() { + let srv = test::start_with(test::config().h1(), || { + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service(web::resource("/").route(web::to(|| { + HttpResponse::Ok() + .encoding(ContentEncoding::Deflate) + .body(STR) + }))) + .service(web::resource("/raw").route(web::to(|| { + let body = actix_web::dev::Body::Bytes(STR.into()); + let mut response = + HttpResponse::with_body(actix_web::http::StatusCode::OK, body); - response.encoding(ContentEncoding::Deflate); + response.encoding(ContentEncoding::Deflate); - response - }))), - ) + response + }))) }); // Builder let mut response = srv - .block_on( - srv.get("/") - .no_decompress() - .header(ACCEPT_ENCODING, "deflate") - .send(), - ) + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "deflate") + .send() + .await .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode let mut e = ZlibDecoder::new(Vec::new()); @@ -167,17 +189,16 @@ fn test_body_encoding_override() { // Raw Response let mut response = srv - .block_on( - srv.request(actix_web::http::Method::GET, srv.url("/raw")) - .no_decompress() - .header(ACCEPT_ENCODING, "deflate") - .send(), - ) + .request(actix_web::http::Method::GET, srv.url("/raw")) + .no_decompress() + .header(ACCEPT_ENCODING, "deflate") + .send() + .await .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode let mut e = ZlibDecoder::new(Vec::new()); @@ -186,36 +207,32 @@ fn test_body_encoding_override() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -#[test] -fn test_body_gzip_large() { +#[actix_rt::test] +async fn test_body_gzip_large() { let data = STR.repeat(10); let srv_data = data.clone(); - let mut srv = TestServer::new(move || { + let srv = test::start_with(test::config().h1(), move || { let data = srv_data.clone(); - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service( - web::resource("/") - .route(web::to(move || Response::Ok().body(data.clone()))), - ), - ) + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service( + web::resource("/") + .route(web::to(move || HttpResponse::Ok().body(data.clone()))), + ) }); let mut response = srv - .block_on( - srv.get("/") - .no_decompress() - .header(ACCEPT_ENCODING, "gzip") - .send(), - ) + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode let mut e = GzDecoder::new(&bytes[..]); @@ -224,39 +241,35 @@ fn test_body_gzip_large() { assert_eq!(Bytes::from(dec), Bytes::from(data)); } -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -#[test] -fn test_body_gzip_large_random() { +#[actix_rt::test] +async fn test_body_gzip_large_random() { let data = rand::thread_rng() .sample_iter(&Alphanumeric) .take(70_000) .collect::(); let srv_data = data.clone(); - let mut srv = TestServer::new(move || { + let srv = test::start_with(test::config().h1(), move || { let data = srv_data.clone(); - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service( - web::resource("/") - .route(web::to(move || Response::Ok().body(data.clone()))), - ), - ) + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service( + web::resource("/") + .route(web::to(move || HttpResponse::Ok().body(data.clone()))), + ) }); let mut response = srv - .block_on( - srv.get("/") - .no_decompress() - .header(ACCEPT_ENCODING, "gzip") - .send(), - ) + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode let mut e = GzDecoder::new(&bytes[..]); @@ -266,28 +279,23 @@ fn test_body_gzip_large_random() { assert_eq!(Bytes::from(dec), Bytes::from(data)); } -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -#[test] -fn test_body_chunked_implicit() { - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service(web::resource("/").route(web::get().to(move || { - Response::Ok().streaming(once(Ok::<_, Error>(Bytes::from_static( - STR.as_ref(), - )))) - }))), - ) +#[actix_rt::test] +async fn test_body_chunked_implicit() { + let srv = test::start_with(test::config().h1(), || { + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service(web::resource("/").route(web::get().to(move || { + HttpResponse::Ok() + .streaming(TestBody::new(Bytes::from_static(STR.as_ref()), 24)) + }))) }); let mut response = srv - .block_on( - srv.get("/") - .no_decompress() - .header(ACCEPT_ENCODING, "gzip") - .send(), - ) + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await .unwrap(); assert!(response.status().is_success()); assert_eq!( @@ -296,7 +304,7 @@ fn test_body_chunked_implicit() { ); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode let mut e = GzDecoder::new(&bytes[..]); @@ -305,47 +313,47 @@ fn test_body_chunked_implicit() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[test] -#[cfg(feature = "brotli")] -fn test_body_br_streaming() { - let mut srv = TestServer::new(move || { - h1::H1Service::new(App::new().wrap(Compress::new(ContentEncoding::Br)).service( +#[actix_rt::test] +async fn test_body_br_streaming() { + let srv = test::start_with(test::config().h1(), || { + App::new().wrap(Compress::new(ContentEncoding::Br)).service( web::resource("/").route(web::to(move || { - Response::Ok() - .streaming(once(Ok::<_, Error>(Bytes::from_static(STR.as_ref())))) + HttpResponse::Ok() + .streaming(TestBody::new(Bytes::from_static(STR.as_ref()), 24)) })), - )) + ) }); let mut response = srv - .block_on( - srv.get("/") - .header(ACCEPT_ENCODING, "br") - .no_decompress() - .send(), - ) + .get("/") + .header(ACCEPT_ENCODING, "br") + .no_decompress() + .send() + .await .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); + println!("TEST: {:?}", bytes.len()); // decode br let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); e.write_all(bytes.as_ref()).unwrap(); let dec = e.finish().unwrap(); + println!("T: {:?}", Bytes::copy_from_slice(&dec)); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[test] -fn test_head_binary() { - let mut srv = TestServer::new(move || { - h1::H1Service::new(App::new().service(web::resource("/").route( - web::head().to(move || Response::Ok().content_length(100).body(STR)), - ))) +#[actix_rt::test] +async fn test_head_binary() { + let srv = test::start_with(test::config().h1(), || { + App::new().service(web::resource("/").route( + web::head().to(move || HttpResponse::Ok().content_length(100).body(STR)), + )) }); - let mut response = srv.block_on(srv.head("/").send()).unwrap(); + let mut response = srv.head("/").send().await.unwrap(); assert!(response.status().is_success()); { @@ -354,58 +362,52 @@ fn test_head_binary() { } // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert!(bytes.is_empty()); } -#[test] -fn test_no_chunking() { - let mut srv = TestServer::new(move || { - h1::H1Service::new(App::new().service(web::resource("/").route(web::to( - move || { - Response::Ok() - .no_chunking() - .content_length(STR.len() as u64) - .streaming(once(Ok::<_, Error>(Bytes::from_static(STR.as_ref())))) - }, - )))) +#[actix_rt::test] +async fn test_no_chunking() { + let srv = test::start_with(test::config().h1(), || { + App::new().service(web::resource("/").route(web::to(move || { + HttpResponse::Ok() + .no_chunking() + .content_length(STR.len() as u64) + .streaming(TestBody::new(Bytes::from_static(STR.as_ref()), 24)) + }))) }); - let mut response = srv.block_on(srv.get("/").send()).unwrap(); + let mut response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); assert!(!response.headers().contains_key(TRANSFER_ENCODING)); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -fn test_body_deflate() { - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new() - .wrap(Compress::new(ContentEncoding::Deflate)) - .service( - web::resource("/").route(web::to(move || Response::Ok().body(STR))), - ), - ) +#[actix_rt::test] +async fn test_body_deflate() { + let srv = test::start_with(test::config().h1(), || { + App::new() + .wrap(Compress::new(ContentEncoding::Deflate)) + .service( + web::resource("/").route(web::to(move || HttpResponse::Ok().body(STR))), + ) }); // client request let mut response = srv - .block_on( - srv.get("/") - .header(ACCEPT_ENCODING, "deflate") - .no_decompress() - .send(), - ) + .get("/") + .header(ACCEPT_ENCODING, "deflate") + .no_decompress() + .send() + .await .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); let mut e = ZlibDecoder::new(Vec::new()); e.write_all(bytes.as_ref()).unwrap(); @@ -413,28 +415,26 @@ fn test_body_deflate() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[test] -#[cfg(any(feature = "brotli"))] -fn test_body_brotli() { - let mut srv = TestServer::new(move || { - h1::H1Service::new(App::new().wrap(Compress::new(ContentEncoding::Br)).service( - web::resource("/").route(web::to(move || Response::Ok().body(STR))), - )) +#[actix_rt::test] +async fn test_body_brotli() { + let srv = test::start_with(test::config().h1(), || { + App::new().wrap(Compress::new(ContentEncoding::Br)).service( + web::resource("/").route(web::to(move || HttpResponse::Ok().body(STR))), + ) }); // client request let mut response = srv - .block_on( - srv.get("/") - .header(ACCEPT_ENCODING, "br") - .no_decompress() - .send(), - ) + .get("/") + .header(ACCEPT_ENCODING, "br") + .no_decompress() + .send() + .await .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode brotli let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); @@ -443,15 +443,12 @@ fn test_body_brotli() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[test] -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -fn test_encoding() { - let mut srv = TestServer::new(move || { - HttpService::new( - App::new().wrap(Compress::default()).service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), +#[actix_rt::test] +async fn test_encoding() { + let srv = test::start_with(test::config().h1(), || { + App::new().wrap(Compress::default()).service( + web::resource("/") + .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -464,23 +461,20 @@ fn test_encoding() { .post("/") .header(CONTENT_ENCODING, "gzip") .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -fn test_gzip_encoding() { - let mut srv = TestServer::new(move || { - HttpService::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), +#[actix_rt::test] +async fn test_gzip_encoding() { + let srv = test::start_with(test::config().h1(), || { + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -493,24 +487,21 @@ fn test_gzip_encoding() { .post("/") .header(CONTENT_ENCODING, "gzip") .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -fn test_gzip_encoding_large() { +#[actix_rt::test] +async fn test_gzip_encoding_large() { let data = STR.repeat(10); - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), + let srv = test::start_with(test::config().h1(), || { + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -523,28 +514,25 @@ fn test_gzip_encoding_large() { .post("/") .header(CONTENT_ENCODING, "gzip") .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from(data)); } -#[test] -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -fn test_reading_gzip_encoding_large_random() { +#[actix_rt::test] +async fn test_reading_gzip_encoding_large_random() { let data = rand::thread_rng() .sample_iter(&Alphanumeric) .take(60_000) .collect::(); - let mut srv = TestServer::new(move || { - HttpService::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), + let srv = test::start_with(test::config().h1(), || { + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -557,24 +545,21 @@ fn test_reading_gzip_encoding_large_random() { .post("/") .header(CONTENT_ENCODING, "gzip") .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes.len(), data.len()); assert_eq!(bytes, Bytes::from(data)); } -#[test] -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -fn test_reading_deflate_encoding() { - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), +#[actix_rt::test] +async fn test_reading_deflate_encoding() { + let srv = test::start_with(test::config().h1(), || { + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -587,24 +572,21 @@ fn test_reading_deflate_encoding() { .post("/") .header(CONTENT_ENCODING, "deflate") .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -fn test_reading_deflate_encoding_large() { +#[actix_rt::test] +async fn test_reading_deflate_encoding_large() { let data = STR.repeat(10); - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), + let srv = test::start_with(test::config().h1(), || { + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -617,28 +599,25 @@ fn test_reading_deflate_encoding_large() { .post("/") .header(CONTENT_ENCODING, "deflate") .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from(data)); } -#[test] -#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] -fn test_reading_deflate_encoding_large_random() { +#[actix_rt::test] +async fn test_reading_deflate_encoding_large_random() { let data = rand::thread_rng() .sample_iter(&Alphanumeric) .take(160_000) .collect::(); - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), + let srv = test::start_with(test::config().h1(), || { + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -651,24 +630,21 @@ fn test_reading_deflate_encoding_large_random() { .post("/") .header(CONTENT_ENCODING, "deflate") .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes.len(), data.len()); assert_eq!(bytes, Bytes::from(data)); } -#[test] -#[cfg(feature = "brotli")] -fn test_brotli_encoding() { - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), +#[actix_rt::test] +async fn test_brotli_encoding() { + let srv = test::start_with(test::config().h1(), || { + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -681,24 +657,28 @@ fn test_brotli_encoding() { .post("/") .header(CONTENT_ENCODING, "br") .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[cfg(feature = "brotli")] -#[test] -fn test_brotli_encoding_large() { - let data = STR.repeat(10); - let mut srv = TestServer::new(move || { - h1::H1Service::new( - App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| Response::Ok().body(body))), - ), +#[actix_rt::test] +async fn test_brotli_encoding_large() { + let data = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(320_000) + .collect::(); + + let srv = test::start_with(test::config().h1(), || { + App::new().service( + web::resource("/") + .app_data(web::PayloadConfig::new(320_000)) + .route(web::to(move |body: Bytes| { + HttpResponse::Ok().streaming(TestBody::new(body, 10240)) + })), ) }); @@ -711,133 +691,82 @@ fn test_brotli_encoding_large() { .post("/") .header(CONTENT_ENCODING, "br") .send_body(enc.clone()); - let mut response = srv.block_on(request).unwrap(); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.block_on(response.body()).unwrap(); + let bytes = response.body().limit(320_000).await.unwrap(); assert_eq!(bytes, Bytes::from(data)); } -// #[cfg(all(feature = "brotli", feature = "ssl"))] -// #[test] -// fn test_brotli_encoding_large_ssl() { -// use actix::{Actor, System}; -// use openssl::ssl::{ -// SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode, -// }; -// // load ssl keys -// let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); -// builder -// .set_private_key_file("tests/key.pem", SslFiletype::PEM) -// .unwrap(); -// builder -// .set_certificate_chain_file("tests/cert.pem") -// .unwrap(); +#[cfg(feature = "openssl")] +#[actix_rt::test] +async fn test_brotli_encoding_large_openssl() { + // load ssl keys + use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder + .set_private_key_file("tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("tests/cert.pem") + .unwrap(); -// let data = STR.repeat(10); -// let srv = test::TestServer::build().ssl(builder).start(|app| { -// app.handler(|req: &HttpRequest| { -// req.body() -// .and_then(|bytes: Bytes| { -// Ok(HttpResponse::Ok() -// .content_encoding(http::ContentEncoding::Identity) -// .body(bytes)) -// }) -// .responder() -// }) -// }); -// let mut rt = System::new("test"); + let data = STR.repeat(10); + let srv = test::start_with(test::config().openssl(builder.build()), move || { + App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { + HttpResponse::Ok() + .encoding(actix_web::http::ContentEncoding::Identity) + .body(bytes) + }))) + }); -// // client connector -// let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); -// builder.set_verify(SslVerifyMode::NONE); -// let conn = client::ClientConnector::with_connector(builder.build()).start(); + // body + let mut e = BrotliEncoder::new(Vec::new(), 3); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); -// // body -// let mut e = BrotliEncoder::new(Vec::new(), 5); -// e.write_all(data.as_ref()).unwrap(); -// let enc = e.finish().unwrap(); + // client request + let mut response = srv + .post("/") + .header(actix_web::http::header::CONTENT_ENCODING, "br") + .send_body(enc) + .await + .unwrap(); + assert!(response.status().is_success()); -// // client request -// let request = client::ClientRequest::build() -// .uri(srv.url("/")) -// .method(http::Method::POST) -// .header(http::header::CONTENT_ENCODING, "br") -// .with_connector(conn) -// .body(enc) -// .unwrap(); -// let response = rt.block_on(request.send()).unwrap(); -// assert!(response.status().is_success()); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from(data)); +} -// // read response -// let bytes = rt.block_on(response.body()).unwrap(); -// assert_eq!(bytes, Bytes::from(data)); -// } - -#[cfg(all( - feature = "rust-tls", - feature = "ssl", - any(feature = "flate2-zlib", feature = "flate2-rust") -))] -#[test] -fn test_reading_deflate_encoding_large_random_ssl() { - use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; - use rustls::internal::pemfile::{certs, pkcs8_private_keys}; - use rustls::{NoClientAuth, ServerConfig}; +#[cfg(all(feature = "rustls", feature = "openssl"))] +#[actix_rt::test] +async fn test_reading_deflate_encoding_large_random_rustls() { + use rust_tls::internal::pemfile::{certs, pkcs8_private_keys}; + use rust_tls::{NoClientAuth, ServerConfig}; use std::fs::File; use std::io::BufReader; - let addr = TestServer::unused_addr(); - let (tx, rx) = mpsc::channel(); - let data = rand::thread_rng() .sample_iter(&Alphanumeric) .take(160_000) .collect::(); - thread::spawn(move || { - let sys = actix_rt::System::new("test"); + // load ssl keys + let mut config = ServerConfig::new(NoClientAuth::new()); + let cert_file = &mut BufReader::new(File::open("tests/cert.pem").unwrap()); + let key_file = &mut BufReader::new(File::open("tests/key.pem").unwrap()); + let cert_chain = certs(cert_file).unwrap(); + let mut keys = pkcs8_private_keys(key_file).unwrap(); + config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - // load ssl keys - let mut config = ServerConfig::new(NoClientAuth::new()); - let cert_file = &mut BufReader::new(File::open("tests/cert.pem").unwrap()); - let key_file = &mut BufReader::new(File::open("tests/key.pem").unwrap()); - let cert_chain = certs(cert_file).unwrap(); - let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - - let srv = HttpServer::new(|| { - App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { - Ok::<_, Error>( - HttpResponse::Ok() - .encoding(http::ContentEncoding::Identity) - .body(bytes), - ) - }))) - }) - .bind_rustls(addr, config) - .unwrap() - .start(); - - let _ = tx.send((srv, actix_rt::System::current())); - let _ = sys.run(); - }); - let (srv, _sys) = rx.recv().unwrap(); - test::block_on(futures::lazy(|| Ok::<_, ()>(start_default_resolver()))).unwrap(); - let client = test::run_on(|| { - let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - builder.set_verify(SslVerifyMode::NONE); - let _ = builder.set_alpn_protos(b"\x02h2\x08http/1.1").unwrap(); - - awc::Client::build() - .connector( - awc::Connector::new() - .timeout(std::time::Duration::from_millis(500)) - .ssl(builder.build()) - .finish(), - ) - .finish() + let srv = test::start_with(test::config().rustls(config), || { + App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { + HttpResponse::Ok() + .encoding(actix_web::http::ContentEncoding::Identity) + .body(bytes) + }))) }); // encode data @@ -846,115 +775,25 @@ fn test_reading_deflate_encoding_large_random_ssl() { let enc = e.finish().unwrap(); // client request - let req = client - .post(format!("https://localhost:{}/", addr.port())) - .header(http::header::CONTENT_ENCODING, "deflate") - .send_body(enc); + let req = srv + .post("/") + .header(actix_web::http::header::CONTENT_ENCODING, "deflate") + .send_stream(TestBody::new(Bytes::from(enc), 1024)); - let mut response = test::block_on(req).unwrap(); + let mut response = req.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = test::block_on(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes.len(), data.len()); assert_eq!(bytes, Bytes::from(data)); - - // stop - let _ = srv.stop(false); } -// #[cfg(all(feature = "tls", feature = "ssl"))] -// #[test] -// fn test_reading_deflate_encoding_large_random_tls() { -// use native_tls::{Identity, TlsAcceptor}; -// use openssl::ssl::{ -// SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode, -// }; -// use std::fs::File; -// use std::sync::mpsc; - -// use actix::{Actor, System}; -// let (tx, rx) = mpsc::channel(); - -// // load ssl keys -// let mut file = File::open("tests/identity.pfx").unwrap(); -// let mut identity = vec![]; -// file.read_to_end(&mut identity).unwrap(); -// let identity = Identity::from_pkcs12(&identity, "1").unwrap(); -// let acceptor = TlsAcceptor::new(identity).unwrap(); - -// // load ssl keys -// let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); -// builder -// .set_private_key_file("tests/key.pem", SslFiletype::PEM) -// .unwrap(); -// builder -// .set_certificate_chain_file("tests/cert.pem") -// .unwrap(); - -// let data = rand::thread_rng() -// .sample_iter(&Alphanumeric) -// .take(160_000) -// .collect::(); - -// let addr = test::TestServer::unused_addr(); -// thread::spawn(move || { -// System::run(move || { -// server::new(|| { -// App::new().handler("/", |req: &HttpRequest| { -// req.body() -// .and_then(|bytes: Bytes| { -// Ok(HttpResponse::Ok() -// .content_encoding(http::ContentEncoding::Identity) -// .body(bytes)) -// }) -// .responder() -// }) -// }) -// .bind_tls(addr, acceptor) -// .unwrap() -// .start(); -// let _ = tx.send(System::current()); -// }); -// }); -// let sys = rx.recv().unwrap(); - -// let mut rt = System::new("test"); - -// // client connector -// let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); -// builder.set_verify(SslVerifyMode::NONE); -// let conn = client::ClientConnector::with_connector(builder.build()).start(); - -// // encode data -// let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); -// e.write_all(data.as_ref()).unwrap(); -// let enc = e.finish().unwrap(); - -// // client request -// let request = client::ClientRequest::build() -// .uri(format!("https://{}/", addr)) -// .method(http::Method::POST) -// .header(http::header::CONTENT_ENCODING, "deflate") -// .with_connector(conn) -// .body(enc) -// .unwrap(); -// let response = rt.block_on(request.send()).unwrap(); -// assert!(response.status().is_success()); - -// // read response -// let bytes = rt.block_on(response.body()).unwrap(); -// assert_eq!(bytes.len(), data.len()); -// assert_eq!(bytes, Bytes::from(data)); - -// let _ = sys.stop(); -// } - // #[test] // fn test_server_cookies() { // use actix_web::http; -// let mut srv = test::TestServer::with_factory(|| { +// let srv = test::TestServer::with_factory(|| { // App::new().resource("/", |r| { // r.f(|_| { // HttpResponse::Ok() @@ -1006,55 +845,31 @@ fn test_reading_deflate_encoding_large_random_ssl() { // } // } -// #[test] -// fn test_slow_request() { -// use actix::System; +#[actix_rt::test] +async fn test_slow_request() { + use std::net; + + let srv = test::start_with(test::config().client_timeout(200), || { + App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok()))) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); +} + +// #[cfg(feature = "openssl")] +// #[actix_rt::test] +// async fn test_ssl_handshake_timeout() { +// use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; // use std::net; -// use std::sync::mpsc; -// let (tx, rx) = mpsc::channel(); - -// let addr = test::TestServer::unused_addr(); -// thread::spawn(move || { -// System::run(move || { -// let srv = server::new(|| { -// vec![App::new().resource("/", |r| { -// r.method(http::Method::GET).f(|_| HttpResponse::Ok()) -// })] -// }); - -// let srv = srv.bind(addr).unwrap(); -// srv.client_timeout(200).start(); -// let _ = tx.send(System::current()); -// }); -// }); -// let sys = rx.recv().unwrap(); - -// thread::sleep(time::Duration::from_millis(200)); - -// let mut stream = net::TcpStream::connect(addr).unwrap(); -// let mut data = String::new(); -// let _ = stream.read_to_string(&mut data); -// assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); - -// let mut stream = net::TcpStream::connect(addr).unwrap(); -// let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n"); -// let mut data = String::new(); -// let _ = stream.read_to_string(&mut data); -// assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); - -// sys.stop(); -// } - -// #[test] -// #[cfg(feature = "ssl")] -// fn test_ssl_handshake_timeout() { -// use actix::System; -// use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; -// use std::net; -// use std::sync::mpsc; - -// let (tx, rx) = mpsc::channel(); -// let addr = test::TestServer::unused_addr(); // // load ssl keys // let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); @@ -1065,28 +880,12 @@ fn test_reading_deflate_encoding_large_random_ssl() { // .set_certificate_chain_file("tests/cert.pem") // .unwrap(); -// thread::spawn(move || { -// System::run(move || { -// let srv = server::new(|| { -// App::new().resource("/", |r| { -// r.method(http::Method::GET).f(|_| HttpResponse::Ok()) -// }) -// }); - -// srv.bind_ssl(addr, builder) -// .unwrap() -// .workers(1) -// .client_timeout(200) -// .start(); -// let _ = tx.send(System::current()); -// }); +// let srv = test::start_with(test::config().openssl(builder.build()), || { +// App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok()))) // }); -// let sys = rx.recv().unwrap(); -// let mut stream = net::TcpStream::connect(addr).unwrap(); +// let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); // let mut data = String::new(); // let _ = stream.read_to_string(&mut data); // assert!(data.is_empty()); - -// let _ = sys.stop(); // }