diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index b779b33fa..42deadf5a 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,21 +1,21 @@ - + ## PR Type -INSERT_PR_TYPE +PR_TYPE ## PR Checklist -Check your PR fulfills the following: - + - [ ] Tests for the changes have been added / updated. - [ ] Documentation comments have been added / updated. - [ ] A changelog entry has been made for the appropriate packages. -- [ ] Format code with the latest stable rustfmt +- [ ] Format code with the latest stable rustfmt. +- [ ] (Team) Label with affected crates and semver status. ## Overview diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml index d19471a16..828d62561 100644 --- a/.github/workflows/bench.yml +++ b/.github/workflows/bench.yml @@ -1,4 +1,4 @@ -name: Benchmark (Linux) +name: Benchmark on: pull_request: diff --git a/.github/workflows/linux.yml b/.github/workflows/ci.yml similarity index 52% rename from .github/workflows/linux.yml rename to .github/workflows/ci.yml index 53f22df63..7d0520d52 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/ci.yml @@ -1,24 +1,27 @@ -name: CI (Linux) +name: CI on: pull_request: types: [opened, synchronize, reopened] push: - branches: - - master + branches: [master] jobs: build_and_test: strategy: fail-fast: false matrix: + target: + - { name: Linux, os: ubuntu-latest, triple: x86_64-unknown-linux-gnu } + - { name: macOS, os: macos-latest, triple: x86_64-apple-darwin } + - { name: Windows, os: windows-latest, triple: x86_64-pc-windows-msvc } version: - 1.46.0 # MSRV - stable - nightly - name: ${{ matrix.version }} - x86_64-unknown-linux-gnu - runs-on: ubuntu-latest + name: ${{ matrix.target.name }} / ${{ matrix.version }} + runs-on: ${{ matrix.target.os }} steps: - uses: actions/checkout@v2 @@ -26,7 +29,7 @@ jobs: - name: Install ${{ matrix.version }} uses: actions-rs/toolchain@v1 with: - toolchain: ${{ matrix.version }}-x86_64-unknown-linux-gnu + toolchain: ${{ matrix.version }}-${{ matrix.target.triple }} profile: minimal override: true @@ -35,20 +38,33 @@ jobs: with: command: generate-lockfile - name: Cache Dependencies - uses: Swatinem/rust-cache@v1.0.1 + uses: Swatinem/rust-cache@v1.2.0 - - name: check build + - name: Install cargo-hack + uses: actions-rs/cargo@v1 + with: + command: install + args: cargo-hack + + - name: check minimal + uses: actions-rs/cargo@v1 + with: + command: hack + args: --clean-per-run check --workspace --no-default-features --tests + + - name: check full uses: actions-rs/cargo@v1 with: command: check - args: --all --bins --examples --tests + args: --workspace --bins --examples --tests - name: tests uses: actions-rs/cargo@v1 - timeout-minutes: 40 with: command: test - args: --all --all-features --no-fail-fast -- --nocapture + args: -v --workspace --all-features --no-fail-fast -- --nocapture + --skip=test_h2_content_length + --skip=test_reading_deflate_encoding_large_random_rustls - name: tests (actix-http) uses: actions-rs/cargo@v1 @@ -65,12 +81,18 @@ jobs: args: --package=awc --no-default-features --features=rustls -- --nocapture - name: Generate coverage file - if: matrix.version == 'stable' && github.ref == 'refs/heads/master' + if: > + matrix.target.os == 'ubuntu-latest' + && matrix.version == 'stable' + && github.ref == 'refs/heads/master' run: | cargo install cargo-tarpaulin --vers "^0.13" - cargo tarpaulin --out Xml + cargo tarpaulin --out Xml --verbose - name: Upload to Codecov - if: matrix.version == 'stable' && github.ref == 'refs/heads/master' + if: > + matrix.target.os == 'ubuntu-latest' + && matrix.version == 'stable' + && github.ref == 'refs/heads/master' uses: codecov/codecov-action@v1 with: file: cobertura.xml diff --git a/.github/workflows/clippy-fmt.yml b/.github/workflows/clippy-fmt.yml index fb1ed7f32..e966fa4ab 100644 --- a/.github/workflows/clippy-fmt.yml +++ b/.github/workflows/clippy-fmt.yml @@ -1,32 +1,39 @@ +name: Lint + on: pull_request: types: [opened, synchronize, reopened] -name: Clippy and rustfmt Check jobs: - clippy_check: + fmt: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 + - name: Install Rust + uses: actions-rs/toolchain@v1 with: toolchain: stable components: rustfmt - override: true - name: Check with rustfmt uses: actions-rs/cargo@v1 with: command: fmt args: --all -- --check - - uses: actions-rs/toolchain@v1 + clippy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Install Rust + uses: actions-rs/toolchain@v1 with: - toolchain: nightly + toolchain: stable components: clippy override: true - name: Check with Clippy uses: actions-rs/clippy-check@v1 with: token: ${{ secrets.GITHUB_TOKEN }} - args: --all-features --all --tests + args: --workspace --tests --all-features diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml deleted file mode 100644 index 6b5366faf..000000000 --- a/.github/workflows/macos.yml +++ /dev/null @@ -1,56 +0,0 @@ -name: CI (macOS) - -on: - pull_request: - types: [opened, synchronize, reopened] - push: - branches: - - master - -jobs: - build_and_test: - strategy: - fail-fast: false - matrix: - version: - - stable - - nightly - - name: ${{ matrix.version }} - x86_64-apple-darwin - runs-on: macOS-latest - - steps: - - uses: actions/checkout@v2 - - - name: Install ${{ matrix.version }} - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ matrix.version }}-x86_64-apple-darwin - profile: minimal - override: true - - - name: Generate Cargo.lock - uses: actions-rs/cargo@v1 - with: - command: generate-lockfile - - name: Cache Dependencies - uses: Swatinem/rust-cache@v1.0.1 - - - name: check build - uses: actions-rs/cargo@v1 - with: - command: check - args: --all --bins --examples --tests - - - name: tests - uses: actions-rs/cargo@v1 - with: - command: test - args: --all --all-features --no-fail-fast -- --nocapture - --skip=test_h2_content_length - --skip=test_reading_deflate_encoding_large_random_rustls - - - name: Clear the cargo caches - run: | - cargo install cargo-cache --no-default-features --features ci-autoclean - cargo-cache diff --git a/.github/workflows/upload-doc.yml b/.github/workflows/upload-doc.yml index ba87a5637..c080dd8c3 100644 --- a/.github/workflows/upload-doc.yml +++ b/.github/workflows/upload-doc.yml @@ -24,7 +24,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: doc - args: --no-deps --workspace --all-features + args: --workspace --all-features --no-deps - name: Tweak HTML run: echo "" > target/doc/index.html diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml deleted file mode 100644 index d3de72a61..000000000 --- a/.github/workflows/windows.yml +++ /dev/null @@ -1,76 +0,0 @@ -name: CI (Windows) - -on: - pull_request: - types: [opened, synchronize, reopened] - push: - branches: - - master - -env: - VCPKGRS_DYNAMIC: 1 - -jobs: - build_and_test: - strategy: - fail-fast: false - matrix: - version: - - stable - - nightly - - name: ${{ matrix.version }} - x86_64-pc-windows-msvc - runs-on: windows-latest - - steps: - - uses: actions/checkout@v2 - - - name: Install ${{ matrix.version }} - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ matrix.version }}-x86_64-pc-windows-msvc - profile: minimal - override: true - - - name: Install OpenSSL - run: | - vcpkg integrate install - vcpkg install openssl:x64-windows - Copy-Item C:\vcpkg\installed\x64-windows\bin\libcrypto-1_1-x64.dll C:\vcpkg\installed\x64-windows\bin\libcrypto.dll - Copy-Item C:\vcpkg\installed\x64-windows\bin\libssl-1_1-x64.dll C:\vcpkg\installed\x64-windows\bin\libssl.dll - Get-ChildItem C:\vcpkg\installed\x64-windows\bin - Get-ChildItem C:\vcpkg\installed\x64-windows\lib - - - name: Generate Cargo.lock - uses: actions-rs/cargo@v1 - with: - command: generate-lockfile - - name: Cache Dependencies - uses: Swatinem/rust-cache@v1.0.1 - - - name: check build - uses: actions-rs/cargo@v1 - with: - command: check - args: --all --bins --examples --tests - - - name: tests - uses: actions-rs/cargo@v1 - with: - command: test - args: --all --all-features --no-fail-fast -- --nocapture - --skip=test_h2_content_length - --skip=test_reading_deflate_encoding_large_random_rustls - --skip=test_params - --skip=test_simple - --skip=test_expect_continue - --skip=test_http10_keepalive - --skip=test_slow_request - --skip=test_connection_force_close - --skip=test_connection_server_close - --skip=test_connection_wait_queue_force_close - - - name: Clear the cargo caches - run: | - cargo install cargo-cache --no-default-features --features ci-autoclean - cargo-cache diff --git a/CHANGES.md b/CHANGES.md index fa56acc17..7cb03c30c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,16 +1,89 @@ # Changes -## Unreleased - 2020-xx-xx +## Unreleased - 2021-xx-xx ### Changed -* Bumped `rand` to `0.8` +* Feature `cookies` is now optional and enabled by default. [#1981] +* `JsonBody::new` returns a default limit of 32kB to be consistent with `JsonConfig` and the + default behaviour of the `web::Json` extractor. [#2010] + +[#1981]: https://github.com/actix/actix-web/pull/1981 +[#2010]: https://github.com/actix/actix-web/pull/2010 + + +## 4.0.0-beta.3 - 2021-02-10 +* Update `actix-web-codegen` to `0.5.0-beta.1`. + + +## 4.0.0-beta.2 - 2021-02-10 +### Added +* The method `Either, web::Form>::into_inner()` which returns the inner type for + whichever variant was created. Also works for `Either, web::Json>`. [#1894] +* Add `services!` macro for helping register multiple services to `App`. [#1933] +* Enable registering a vec of services of the same type to `App` [#1933] + +### Changed +* Rework `Responder` trait to be sync and returns `Response`/`HttpResponse` directly. + Making it simpler and more performant. [#1891] +* `ServiceRequest::into_parts` and `ServiceRequest::from_parts` can no longer fail. [#1893] +* `ServiceRequest::from_request` can no longer fail. [#1893] +* Our `Either` type now uses `Left`/`Right` variants (instead of `A`/`B`) [#1894] +* `test::{call_service, read_response, read_response_json, send_request}` take `&Service` + in argument [#1905] +* `App::wrap_fn`, `Resource::wrap_fn` and `Scope::wrap_fn` provide `&Service` in closure + argument. [#1905] +* `web::block` no longer requires the output is a Result. [#1957] + +### Fixed +* Multiple calls to `App::data` with the same type now keeps the latest call's data. [#1906] + +### Removed +* Public field of `web::Path` has been made private. [#1894] +* Public field of `web::Query` has been made private. [#1894] +* `TestRequest::with_header`; use `TestRequest::default().insert_header()`. [#1869] +* `AppService::set_service_data`; for custom HTTP service factories adding application data, use the + layered data model by calling `ServiceRequest::add_data_container` when handling + requests instead. [#1906] + +[#1891]: https://github.com/actix/actix-web/pull/1891 +[#1893]: https://github.com/actix/actix-web/pull/1893 +[#1894]: https://github.com/actix/actix-web/pull/1894 +[#1869]: https://github.com/actix/actix-web/pull/1869 +[#1905]: https://github.com/actix/actix-web/pull/1905 +[#1906]: https://github.com/actix/actix-web/pull/1906 +[#1933]: https://github.com/actix/actix-web/pull/1933 +[#1957]: https://github.com/actix/actix-web/pull/1957 + + +## 4.0.0-beta.1 - 2021-01-07 +### Added +* `Compat` middleware enabling generic response body/error type of middlewares like `Logger` and + `Compress` to be used in `middleware::Condition` and `Resource`, `Scope` services. [#1865] + +### Changed +* Update `actix-*` dependencies to tokio `1.0` based versions. [#1813] +* Bumped `rand` to `0.8`. +* Update `rust-tls` to `0.19`. [#1813] * Rename `Handler` to `HandlerService` and rename `Factory` to `Handler`. [#1852] +* The default `TrailingSlash` is now `Trim`, in line with existing documentation. See migration + guide for implications. [#1875] +* Rename `DefaultHeaders::{content_type => add_content_type}`. [#1875] * MSRV is now 1.46.0. ### Fixed -* added the actual parsing error to `test::read_body_json` [#1812] +* Added the underlying parse error to `test::read_body_json`'s panic message. [#1812] + +### Removed +* Public modules `middleware::{normalize, err_handlers}`. All necessary middleware structs are now + exposed directly by the `middleware` module. +* Remove `actix-threadpool` as dependency. `actix_threadpool::BlockingError` error type can be imported + from `actix_web::error` module. [#1878] [#1812]: https://github.com/actix/actix-web/pull/1812 +[#1813]: https://github.com/actix/actix-web/pull/1813 [#1852]: https://github.com/actix/actix-web/pull/1852 +[#1865]: https://github.com/actix/actix-web/pull/1865 +[#1875]: https://github.com/actix/actix-web/pull/1875 +[#1878]: https://github.com/actix/actix-web/pull/1878 ## 3.3.2 - 2020-12-01 ### Fixed diff --git a/Cargo.toml b/Cargo.toml index 6ed327f56..1a1b8645c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web" -version = "3.3.2" +version = "4.0.0-beta.3" authors = ["Nikolay Kim "] description = "Actix Web is a powerful, pragmatic, and extremely fast web framework for Rust" readme = "README.md" @@ -15,6 +15,7 @@ license = "MIT OR Apache-2.0" edition = "2018" [package.metadata.docs.rs] +# features that docs.rs will build with features = ["openssl", "rustls", "compress", "secure-cookies"] [badges] @@ -38,19 +39,22 @@ members = [ ] [features] -default = ["compress"] +default = ["compress", "cookies"] # content-encoding support compress = ["actix-http/compress", "awc/compress"] -# sessions feature +# support for cookies +cookies = ["actix-http/cookies", "awc/cookies"] + +# secure cookies feature secure-cookies = ["actix-http/secure-cookies"] # openssl -openssl = ["actix-tls/openssl", "awc/openssl", "open-ssl"] +openssl = ["tls-openssl", "actix-tls/accept", "actix-tls/openssl", "awc/openssl"] # rustls -rustls = ["actix-tls/rustls", "awc/rustls", "rust-tls"] +rustls = ["tls-rustls", "actix-tls/accept", "actix-tls/rustls", "awc/rustls"] [[example]] name = "basic" @@ -73,51 +77,54 @@ name = "client" required-features = ["rustls"] [dependencies] -actix-codec = "0.3.0" -actix-service = "1.0.6" -actix-utils = "2.0.0" -actix-router = "0.2.4" -actix-rt = "1.1.1" -actix-server = "1.0.0" -actix-testing = "1.0.0" -actix-macros = "0.1.0" -actix-threadpool = "0.3.1" -actix-tls = "2.0.0" +actix-codec = "0.4.0-beta.1" +actix-macros = "0.2.0" +actix-router = "0.2.7" +actix-rt = "2" +actix-server = "2.0.0-beta.3" +actix-service = "2.0.0-beta.4" +actix-utils = "3.0.0-beta.2" +actix-tls = { version = "3.0.0-beta.3", default-features = false, optional = true } -actix-web-codegen = "0.4.0" -actix-http = "2.2.0" -awc = { version = "2.0.3", default-features = false } +actix-web-codegen = "0.5.0-beta.1" +actix-http = "3.0.0-beta.3" +awc = { version = "3.0.0-beta.2", default-features = false } -bytes = "0.5.3" +ahash = "0.7" +bytes = "1" derive_more = "0.99.5" +either = "1.5.3" encoding_rs = "0.8" -futures-channel = { version = "0.3.5", default-features = false } -futures-core = { version = "0.3.5", default-features = false } -futures-util = { version = "0.3.5", default-features = false } -fxhash = "0.2.1" +futures-core = { version = "0.3.7", default-features = false } +futures-util = { version = "0.3.7", default-features = false } log = "0.4" mime = "0.3" -socket2 = "0.3.16" pin-project = "1.0.0" regex = "1.4" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" serde_urlencoded = "0.7" -time = { version = "0.2.7", default-features = false, features = ["std"] } +smallvec = "1.6" +socket2 = "0.3.16" +time = { version = "0.2.23", default-features = false, features = ["std"] } +tls-openssl = { package = "openssl", version = "0.10.9", optional = true } +tls-rustls = { package = "rustls", version = "0.19.0", optional = true } url = "2.1" -open-ssl = { package = "openssl", version = "0.10", optional = true } -rust-tls = { package = "rustls", version = "0.18.0", optional = true } -tinyvec = { version = "1", features = ["alloc"] } + +[target.'cfg(windows)'.dependencies.tls-openssl] +version = "0.10.9" +package = "openssl" +features = ["vendored"] +optional = true [dev-dependencies] -actix = "0.10.0" -actix-http = { version = "2.1.0", features = ["actors"] } -rand = "0.8" -env_logger = "0.8" -serde_derive = "1.0" brotli2 = "0.3.2" -flate2 = "1.0.13" criterion = "0.3" +env_logger = "0.8" +flate2 = "1.0.13" +rand = "0.8" +rcgen = "0.8" +serde_derive = "1.0" [profile.release] lto = true @@ -128,6 +135,7 @@ codegen-units = 1 actix-web = { path = "." } actix-http = { path = "actix-http" } actix-http-test = { path = "actix-http-test" } +actix-web-actors = { path = "actix-web-actors" } actix-web-codegen = { path = "actix-web-codegen" } actix-multipart = { path = "actix-multipart" } actix-files = { path = "actix-files" } @@ -140,3 +148,7 @@ harness = false [[bench]] name = "service" harness = false + +[[bench]] +name = "responder" +harness = false diff --git a/MIGRATION.md b/MIGRATION.md index 5c4650194..e01702868 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -1,5 +1,15 @@ ## Unreleased +* The default `NormalizePath` behavior now strips trailing slashes by default. This was + previously documented to be the case in v3 but the behavior now matches. The effect is that + routes defined with trailing slashes will become inaccessible when + using `NormalizePath::default()`. + + Before: `#[get("/test/")` + After: `#[get("/test")` + + Alternatively, explicitly require trailing slashes: `NormalizePath::new(TrailingSlash::Always)`. + ## 3.0.0 diff --git a/README.md b/README.md index 62ee50243..bf68e7961 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,19 @@
-

Actix web

+

Actix Web

Actix Web is a powerful, pragmatic, and extremely fast web framework for Rust

[![crates.io](https://img.shields.io/crates/v/actix-web?label=latest)](https://crates.io/crates/actix-web) -[![Documentation](https://docs.rs/actix-web/badge.svg?version=3.3.2)](https://docs.rs/actix-web/3.3.2) +[![Documentation](https://docs.rs/actix-web/badge.svg?version=4.0.0-beta.2)](https://docs.rs/actix-web/4.0.0-beta.2) [![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) -![License](https://img.shields.io/crates/l/actix-web.svg) -[![Dependency Status](https://deps.rs/crate/actix-web/3.3.2/status.svg)](https://deps.rs/crate/actix-web/3.3.2) +![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-web.svg) +[![Dependency Status](https://deps.rs/crate/actix-web/4.0.0-beta.2/status.svg)](https://deps.rs/crate/actix-web/4.0.0-beta.2)
-[![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) +[![build status](https://github.com/actix/actix-web/workflows/CI%20%28Linux%29/badge.svg?branch=master&event=push)](https://github.com/actix/actix-web/actions) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) -[![Download](https://img.shields.io/crates/d/actix-web.svg)](https://crates.io/crates/actix-web) -[![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +![downloads](https://img.shields.io/crates/d/actix-web.svg) [![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x)

@@ -32,8 +31,7 @@ * Static assets * SSL support using OpenSSL or Rustls * Middlewares ([Logger, Session, CORS, etc](https://actix.rs/docs/middleware/)) -* Includes an async [HTTP client](https://actix.rs/actix-web/actix_web/client/index.html) -* Supports [Actix actor framework](https://github.com/actix/actix) +* Includes an async [HTTP client](https://docs.rs/actix-web/latest/actix_web/client/index.html) * Runs on stable Rust 1.46+ ## Documentation @@ -99,13 +97,13 @@ One of the fastest web frameworks available according to the This project is licensed under either of * Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or - [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)) + [http://www.apache.org/licenses/LICENSE-2.0]) * MIT license ([LICENSE-MIT](LICENSE-MIT) or - [http://opensource.org/licenses/MIT](http://opensource.org/licenses/MIT)) + [http://opensource.org/licenses/MIT]) at your option. ## Code of Conduct -Contribution to the actix-web crate is organized under the terms of the Contributor Covenant, the -maintainers of Actix web, promises to intervene to uphold that code of conduct. +Contribution to the actix-web repo is organized under the terms of the Contributor Covenant. +The Actix team promises to intervene to uphold that code of conduct. diff --git a/actix-files/CHANGES.md b/actix-files/CHANGES.md index 49768419b..6e2a241ac 100644 --- a/actix-files/CHANGES.md +++ b/actix-files/CHANGES.md @@ -1,7 +1,21 @@ # Changes -## Unreleased - 2020-xx-xx +## Unreleased - 2021-xx-xx + + +## 0.6.0-beta.2 - 2021-02-10 +* Fix If-Modified-Since and If-Unmodified-Since to not compare using sub-second timestamps. [#1887] +* Replace `v_htmlescape` with `askama_escape`. [#1953] + +[#1887]: https://github.com/actix/actix-web/pull/1887 +[#1953]: https://github.com/actix/actix-web/pull/1953 + + +## 0.6.0-beta.1 - 2021-01-07 * `HttpRange::parse` now has its own error type. +* Update `bytes` to `1.0`. [#1813] + +[#1813]: https://github.com/actix/actix-web/pull/1813 ## 0.5.0 - 2020-12-26 diff --git a/actix-files/Cargo.toml b/actix-files/Cargo.toml index b67abb1c1..08b7b36fc 100644 --- a/actix-files/Cargo.toml +++ b/actix-files/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-files" -version = "0.5.0" +version = "0.6.0-beta.2" authors = ["Nikolay Kim "] description = "Static file serving for Actix Web" readme = "README.md" @@ -17,19 +17,21 @@ name = "actix_files" path = "src/lib.rs" [dependencies] -actix-web = { version = "3.0.0", default-features = false } -actix-service = "1.0.6" +actix-web = { version = "4.0.0-beta.3", default-features = false } +actix-service = "2.0.0-beta.4" + +askama_escape = "0.10" bitflags = "1" -bytes = "0.5.3" +bytes = "1" futures-core = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false } +http-range = "0.1.4" derive_more = "0.99.5" log = "0.4" mime = "0.3" mime_guess = "2.0.1" percent-encoding = "2.1" -v_htmlescape = "0.12" [dev-dependencies] -actix-rt = "1.0.0" -actix-web = "3.0.0" +actix-rt = "2" +actix-web = "4.0.0-beta.3" diff --git a/actix-files/src/chunked.rs b/actix-files/src/chunked.rs index 580b06787..f639848c9 100644 --- a/actix-files/src/chunked.rs +++ b/actix-files/src/chunked.rs @@ -9,26 +9,35 @@ use std::{ use actix_web::{ error::{BlockingError, Error}, - web, + rt::task::{spawn_blocking, JoinHandle}, }; use bytes::Bytes; use futures_core::{ready, Stream}; -use futures_util::future::{FutureExt, LocalBoxFuture}; - -use crate::handle_error; - -type ChunkedBoxFuture = - LocalBoxFuture<'static, Result<(File, Bytes), BlockingError>>; #[doc(hidden)] /// A helper created from a `std::fs::File` which reads the file /// chunk-by-chunk on a `ThreadPool`. pub struct ChunkedReadFile { - pub(crate) size: u64, - pub(crate) offset: u64, - pub(crate) file: Option, - pub(crate) fut: Option, - pub(crate) counter: u64, + size: u64, + offset: u64, + state: ChunkedReadFileState, + counter: u64, +} + +enum ChunkedReadFileState { + File(Option), + Future(JoinHandle>), +} + +impl ChunkedReadFile { + pub(crate) fn new(size: u64, offset: u64, file: File) -> Self { + Self { + size, + offset, + state: ChunkedReadFileState::File(Some(file)), + counter: 0, + } + } } impl fmt::Debug for ChunkedReadFile { @@ -40,55 +49,50 @@ impl fmt::Debug for ChunkedReadFile { impl Stream for ChunkedReadFile { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - if let Some(ref mut fut) = self.fut { - return match ready!(Pin::new(fut).poll(cx)) { - Ok((file, bytes)) => { - self.fut.take(); - self.file = Some(file); + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.as_mut().get_mut(); + match this.state { + ChunkedReadFileState::File(ref mut file) => { + let size = this.size; + let offset = this.offset; + let counter = this.counter; - self.offset += bytes.len() as u64; - self.counter += bytes.len() as u64; + if size == counter { + Poll::Ready(None) + } else { + let mut file = file + .take() + .expect("ChunkedReadFile polled after completion"); - Poll::Ready(Some(Ok(bytes))) + let fut = spawn_blocking(move || { + let max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; + + let mut buf = Vec::with_capacity(max_bytes); + file.seek(io::SeekFrom::Start(offset))?; + + let n_bytes = + file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; + + if n_bytes == 0 { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + + Ok((file, Bytes::from(buf))) + }); + this.state = ChunkedReadFileState::Future(fut); + self.poll_next(cx) } - Err(e) => Poll::Ready(Some(Err(handle_error(e)))), - }; - } + } + ChunkedReadFileState::Future(ref mut fut) => { + let (file, bytes) = + ready!(Pin::new(fut).poll(cx)).map_err(|_| BlockingError)??; + this.state = ChunkedReadFileState::File(Some(file)); - let size = self.size; - let offset = self.offset; - let counter = self.counter; + this.offset += bytes.len() as u64; + this.counter += bytes.len() as u64; - if size == counter { - Poll::Ready(None) - } else { - let mut file = self.file.take().expect("Use after completion"); - - self.fut = Some( - web::block(move || { - let max_bytes = - cmp::min(size.saturating_sub(counter), 65_536) as usize; - - let mut buf = Vec::with_capacity(max_bytes); - file.seek(io::SeekFrom::Start(offset))?; - - let n_bytes = - file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; - - if n_bytes == 0 { - return Err(io::ErrorKind::UnexpectedEof.into()); - } - - Ok((file, Bytes::from(buf))) - }) - .boxed_local(), - ); - - self.poll_next(cx) + Poll::Ready(Some(Ok(bytes))) + } } } } diff --git a/actix-files/src/directory.rs b/actix-files/src/directory.rs index 3717985d3..80e0c98d0 100644 --- a/actix-files/src/directory.rs +++ b/actix-files/src/directory.rs @@ -1,8 +1,8 @@ use std::{fmt::Write, fs::DirEntry, io, path::Path, path::PathBuf}; use actix_web::{dev::ServiceResponse, HttpRequest, HttpResponse}; +use askama_escape::{escape as escape_html_entity, Html}; use percent_encoding::{utf8_percent_encode, CONTROLS}; -use v_htmlescape::escape as escape_html_entity; /// A directory; responds with the generated directory listing. #[derive(Debug)] @@ -50,7 +50,7 @@ macro_rules! encode_file_url { // " -- " & -- & ' -- ' < -- < > -- > / -- / macro_rules! encode_file_name { ($entry:ident) => { - escape_html_entity(&$entry.file_name().to_string_lossy()) + escape_html_entity(&$entry.file_name().to_string_lossy(), Html) }; } @@ -66,9 +66,7 @@ pub(crate) fn directory_listing( if dir.is_visible(&entry) { let entry = entry.unwrap(); let p = match entry.path().strip_prefix(&dir.path) { - Ok(p) if cfg!(windows) => { - base.join(p).to_string_lossy().replace("\\", "/") - } + Ok(p) if cfg!(windows) => base.join(p).to_string_lossy().replace("\\", "/"), Ok(p) => base.join(p).to_string_lossy().into_owned(), Err(_) => continue, }; diff --git a/actix-files/src/files.rs b/actix-files/src/files.rs index d0cac6aa4..6f8b28bbf 100644 --- a/actix-files/src/files.rs +++ b/actix-files/src/files.rs @@ -1,10 +1,8 @@ use std::{cell::RefCell, fmt, io, path::PathBuf, rc::Rc}; -use actix_service::{boxed, IntoServiceFactory, ServiceFactory}; +use actix_service::{boxed, IntoServiceFactory, ServiceFactory, ServiceFactoryExt}; use actix_web::{ - dev::{ - AppService, HttpServiceFactory, ResourceDef, ServiceRequest, ServiceResponse, - }, + dev::{AppService, HttpServiceFactory, ResourceDef, ServiceRequest, ServiceResponse}, error::Error, guard::Guard, http::header::DispositionType, @@ -13,8 +11,8 @@ use actix_web::{ use futures_util::future::{ok, FutureExt, LocalBoxFuture}; use crate::{ - directory_listing, named, Directory, DirectoryRenderer, FilesService, - HttpNewService, MimeOverride, + directory_listing, named, Directory, DirectoryRenderer, FilesService, HttpNewService, + MimeOverride, }; /// Static files handling service. @@ -82,8 +80,9 @@ impl Files { /// be inaccessible. Register more specific handlers and services first. /// /// `Files` uses a threadpool for blocking filesystem operations. By default, the pool uses a - /// number of threads equal to 5x the number of available logical CPUs. Pool size can be changed - /// by setting ACTIX_THREADPOOL environment variable. + /// max number of threads equal to `512 * HttpServer::worker`. Real time thread count are + /// adjusted with work load. More threads would spawn when need and threads goes idle for a + /// period of time would be de-spawned. pub fn new>(mount_path: &str, serve_from: T) -> Files { let orig_dir = serve_from.into(); let dir = match orig_dir.canonicalize() { @@ -128,8 +127,8 @@ impl Files { /// Set custom directory renderer pub fn files_listing_renderer(mut self, f: F) -> Self where - for<'r, 's> F: Fn(&'r Directory, &'s HttpRequest) -> Result - + 'static, + for<'r, 's> F: + Fn(&'r Directory, &'s HttpRequest) -> Result + 'static, { self.renderer = Rc::new(f); self @@ -201,10 +200,10 @@ impl Files { /// Sets default handler which is used when no matched file could be found. pub fn default_handler(mut self, f: F) -> Self where - F: IntoServiceFactory, + F: IntoServiceFactory, U: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, > + 'static, @@ -241,8 +240,7 @@ impl HttpServiceFactory for Files { } } -impl ServiceFactory for Files { - type Request = ServiceRequest; +impl ServiceFactory for Files { type Response = ServiceResponse; type Error = Error; type Config = (); diff --git a/actix-files/src/lib.rs b/actix-files/src/lib.rs index 6769d6d38..d26e55899 100644 --- a/actix-files/src/lib.rs +++ b/actix-files/src/lib.rs @@ -15,12 +15,10 @@ #![warn(missing_docs, missing_debug_implementations)] #![clippy::msrv = "1.46"] -use std::io; - use actix_service::boxed::{BoxService, BoxServiceFactory}; use actix_web::{ dev::{ServiceRequest, ServiceResponse}, - error::{BlockingError, Error, ErrorInternalServerError}, + error::Error, http::header::DispositionType, }; use mime_guess::from_ext; @@ -57,13 +55,6 @@ pub fn file_extension_to_mime(ext: &str) -> mime::Mime { from_ext(ext).first_or_octet_stream() } -pub(crate) fn handle_error(err: BlockingError) -> Error { - match err { - BlockingError::Error(err) => err.into(), - BlockingError::Canceled => ErrorInternalServerError("Unexpected error"), - } -} - type MimeOverride = dyn Fn(&mime::Name<'_>) -> DispositionType; #[cfg(test)] @@ -83,7 +74,8 @@ mod tests { }, middleware::Compress, test::{self, TestRequest}, - web, App, HttpResponse, Responder, + web::{self, Bytes}, + App, HttpResponse, Responder, }; use futures_util::future::ok; @@ -107,11 +99,22 @@ mod tests { #[actix_rt::test] async fn test_if_modified_since_without_if_none_match() { let file = NamedFile::open("Cargo.toml").unwrap(); - let since = - header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); + let since = header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); let req = TestRequest::default() - .header(header::IF_MODIFIED_SINCE, since) + .insert_header((header::IF_MODIFIED_SINCE, since)) + .to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_MODIFIED); + } + + #[actix_rt::test] + async fn test_if_modified_since_without_if_none_match_same() { + let file = NamedFile::open("Cargo.toml").unwrap(); + let since = file.last_modified().unwrap(); + + let req = TestRequest::default() + .insert_header((header::IF_MODIFIED_SINCE, since)) .to_http_request(); let resp = file.respond_to(&req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_MODIFIED); @@ -120,17 +123,40 @@ mod tests { #[actix_rt::test] async fn test_if_modified_since_with_if_none_match() { let file = NamedFile::open("Cargo.toml").unwrap(); - let since = - header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); + let since = header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); let req = TestRequest::default() - .header(header::IF_NONE_MATCH, "miss_etag") - .header(header::IF_MODIFIED_SINCE, since) + .insert_header((header::IF_NONE_MATCH, "miss_etag")) + .insert_header((header::IF_MODIFIED_SINCE, since)) .to_http_request(); let resp = file.respond_to(&req).await.unwrap(); assert_ne!(resp.status(), StatusCode::NOT_MODIFIED); } + #[actix_rt::test] + async fn test_if_unmodified_since() { + let file = NamedFile::open("Cargo.toml").unwrap(); + let since = file.last_modified().unwrap(); + + let req = TestRequest::default() + .insert_header((header::IF_UNMODIFIED_SINCE, since)) + .to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_if_unmodified_since_failed() { + let file = NamedFile::open("Cargo.toml").unwrap(); + let since = header::HttpDate::from(SystemTime::UNIX_EPOCH); + + let req = TestRequest::default() + .insert_header((header::IF_UNMODIFIED_SINCE, since)) + .to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::PRECONDITION_FAILED); + } + #[actix_rt::test] async fn test_named_file_text() { assert!(NamedFile::open("test--").is_err()); @@ -185,8 +211,7 @@ mod tests { #[actix_rt::test] async fn test_named_file_non_ascii_file_name() { let mut file = - NamedFile::from_file(File::open("Cargo.toml").unwrap(), "貨物.toml") - .unwrap(); + NamedFile::from_file(File::open("Cargo.toml").unwrap(), "貨物.toml").unwrap(); { file.file(); let _f: &File = &file; @@ -339,7 +364,7 @@ mod tests { DispositionType::Attachment } - let mut srv = test::init_service( + let srv = test::init_service( App::new().service( Files::new("/", ".") .mime_override(all_attachment) @@ -349,7 +374,7 @@ mod tests { .await; let request = TestRequest::get().uri("/").to_request(); - let response = test::call_service(&mut srv, request).await; + let response = test::call_service(&srv, request).await; assert_eq!(response.status(), StatusCode::OK); let content_disposition = response @@ -364,7 +389,7 @@ mod tests { #[actix_rt::test] async fn test_named_file_ranges_status_code() { - let mut srv = test::init_service( + let srv = test::init_service( App::new().service(Files::new("/test", ".").index_file("Cargo.toml")), ) .await; @@ -372,17 +397,17 @@ mod tests { // Valid range header let request = TestRequest::get() .uri("/t%65st/Cargo.toml") - .header(header::RANGE, "bytes=10-20") + .insert_header((header::RANGE, "bytes=10-20")) .to_request(); - let response = test::call_service(&mut srv, request).await; + let response = test::call_service(&srv, request).await; assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT); // Invalid range header let request = TestRequest::get() .uri("/t%65st/Cargo.toml") - .header(header::RANGE, "bytes=1-0") + .insert_header((header::RANGE, "bytes=1-0")) .to_request(); - let response = test::call_service(&mut srv, request).await; + let response = test::call_service(&srv, request).await; assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE); } @@ -394,7 +419,7 @@ mod tests { // Valid range header let response = srv .get("/tests/test.binary") - .header(header::RANGE, "bytes=10-20") + .insert_header((header::RANGE, "bytes=10-20")) .send() .await .unwrap(); @@ -404,7 +429,7 @@ mod tests { // Invalid range header let response = srv .get("/tests/test.binary") - .header(header::RANGE, "bytes=10-5") + .insert_header((header::RANGE, "bytes=10-5")) .send() .await .unwrap(); @@ -419,7 +444,7 @@ mod tests { // Valid range header let response = srv .get("/tests/test.binary") - .header(header::RANGE, "bytes=10-20") + .insert_header((header::RANGE, "bytes=10-20")) .send() .await .unwrap(); @@ -429,7 +454,7 @@ mod tests { // Valid range header, starting from 0 let response = srv .get("/tests/test.binary") - .header(header::RANGE, "bytes=0-20") + .insert_header((header::RANGE, "bytes=0-20")) .send() .await .unwrap(); @@ -469,14 +494,14 @@ mod tests { #[actix_rt::test] async fn test_static_files_with_spaces() { - let mut srv = test::init_service( + let srv = test::init_service( App::new().service(Files::new("/", ".").index_file("Cargo.toml")), ) .await; let request = TestRequest::get() .uri("/tests/test%20space.binary") .to_request(); - let response = test::call_service(&mut srv, request).await; + let response = test::call_service(&srv, request).await; assert_eq!(response.status(), StatusCode::OK); let bytes = test::read_body(response).await; @@ -486,28 +511,28 @@ mod tests { #[actix_rt::test] async fn test_files_not_allowed() { - let mut srv = test::init_service(App::new().service(Files::new("/", "."))).await; + let srv = test::init_service(App::new().service(Files::new("/", "."))).await; let req = TestRequest::default() .uri("/Cargo.toml") .method(Method::POST) .to_request(); - let resp = test::call_service(&mut srv, req).await; + let resp = test::call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - let mut srv = test::init_service(App::new().service(Files::new("/", "."))).await; + let srv = test::init_service(App::new().service(Files::new("/", "."))).await; let req = TestRequest::default() .method(Method::PUT) .uri("/Cargo.toml") .to_request(); - let resp = test::call_service(&mut srv, req).await; + let resp = test::call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); } #[actix_rt::test] async fn test_files_guards() { - let mut srv = test::init_service( + let srv = test::init_service( App::new().service(Files::new("/", ".").use_guards(guard::Post())), ) .await; @@ -517,13 +542,13 @@ mod tests { .method(Method::POST) .to_request(); - let resp = test::call_service(&mut srv, req).await; + let resp = test::call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_rt::test] async fn test_named_file_content_encoding() { - let mut srv = test::init_service(App::new().wrap(Compress::default()).service( + let srv = test::init_service(App::new().wrap(Compress::default()).service( web::resource("/").to(|| async { NamedFile::open("Cargo.toml") .unwrap() @@ -534,16 +559,16 @@ mod tests { let request = TestRequest::get() .uri("/") - .header(header::ACCEPT_ENCODING, "gzip") + .insert_header((header::ACCEPT_ENCODING, "gzip")) .to_request(); - let res = test::call_service(&mut srv, request).await; + let res = test::call_service(&srv, request).await; assert_eq!(res.status(), StatusCode::OK); assert!(!res.headers().contains_key(header::CONTENT_ENCODING)); } #[actix_rt::test] async fn test_named_file_content_encoding_gzip() { - let mut srv = test::init_service(App::new().wrap(Compress::default()).service( + let srv = test::init_service(App::new().wrap(Compress::default()).service( web::resource("/").to(|| async { NamedFile::open("Cargo.toml") .unwrap() @@ -554,9 +579,9 @@ mod tests { let request = TestRequest::get() .uri("/") - .header(header::ACCEPT_ENCODING, "gzip") + .insert_header((header::ACCEPT_ENCODING, "gzip")) .to_request(); - let res = test::call_service(&mut srv, request).await; + let res = test::call_service(&srv, request).await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.headers() @@ -578,27 +603,25 @@ mod tests { #[actix_rt::test] async fn test_static_files() { - let mut srv = test::init_service( - App::new().service(Files::new("/", ".").show_files_listing()), - ) - .await; + let srv = + test::init_service(App::new().service(Files::new("/", ".").show_files_listing())) + .await; let req = TestRequest::with_uri("/missing").to_request(); - let resp = test::call_service(&mut srv, req).await; + let resp = test::call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let mut srv = test::init_service(App::new().service(Files::new("/", "."))).await; + let srv = test::init_service(App::new().service(Files::new("/", "."))).await; let req = TestRequest::default().to_request(); - let resp = test::call_service(&mut srv, req).await; + let resp = test::call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let mut srv = test::init_service( - App::new().service(Files::new("/", ".").show_files_listing()), - ) - .await; + let srv = + test::init_service(App::new().service(Files::new("/", ".").show_files_listing())) + .await; let req = TestRequest::with_uri("/tests").to_request(); - let resp = test::call_service(&mut srv, req).await; + let resp = test::call_service(&srv, req).await; assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/html; charset=utf-8" @@ -611,16 +634,16 @@ mod tests { #[actix_rt::test] async fn test_redirect_to_slash_directory() { // should not redirect if no index - let mut srv = test::init_service( + let srv = test::init_service( App::new().service(Files::new("/", ".").redirect_to_slash_directory()), ) .await; let req = TestRequest::with_uri("/tests").to_request(); - let resp = test::call_service(&mut srv, req).await; + let resp = test::call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::NOT_FOUND); // should redirect if index present - let mut srv = test::init_service( + let srv = test::init_service( App::new().service( Files::new("/", ".") .index_file("test.png") @@ -629,12 +652,12 @@ mod tests { ) .await; let req = TestRequest::with_uri("/tests").to_request(); - let resp = test::call_service(&mut srv, req).await; + let resp = test::call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::FOUND); // should not redirect if the path is wrong let req = TestRequest::with_uri("/not_existing").to_request(); - let resp = test::call_service(&mut srv, req).await; + let resp = test::call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::NOT_FOUND); } @@ -646,7 +669,7 @@ mod tests { #[actix_rt::test] async fn test_default_handler_file_missing() { - let mut st = Files::new("/", ".") + let st = Files::new("/", ".") .default_handler(|req: ServiceRequest| { ok(req.into_response(HttpResponse::Ok().body("default content"))) }) @@ -655,7 +678,7 @@ mod tests { .unwrap(); let req = TestRequest::with_uri("/missing").to_srv_request(); - let resp = test::call_service(&mut st, req).await; + let resp = test::call_service(&st, req).await; assert_eq!(resp.status(), StatusCode::OK); let bytes = test::read_body(resp).await; assert_eq!(bytes, web::Bytes::from_static(b"default content")); @@ -724,54 +747,49 @@ mod tests { // ); // } - // #[actix_rt::test] - // fn integration_serve_index() { - // let mut srv = test::TestServer::with_factory(|| { - // App::new().handler( - // "test", - // Files::new(".").index_file("Cargo.toml"), - // ) - // }); + #[actix_rt::test] + async fn integration_serve_index() { + let srv = test::init_service( + App::new().service(Files::new("test", ".").index_file("Cargo.toml")), + ) + .await; - // let request = srv.get().uri(srv.url("/test")).finish().unwrap(); - // let response = srv.execute(request.send()).unwrap(); - // assert_eq!(response.status(), StatusCode::OK); - // let bytes = srv.execute(response.body()).unwrap(); - // let data = Bytes::from(fs::read("Cargo.toml").unwrap()); - // assert_eq!(bytes, data); + let req = TestRequest::get().uri("/test").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::OK); - // let request = srv.get().uri(srv.url("/test/")).finish().unwrap(); - // let response = srv.execute(request.send()).unwrap(); - // assert_eq!(response.status(), StatusCode::OK); - // let bytes = srv.execute(response.body()).unwrap(); - // let data = Bytes::from(fs::read("Cargo.toml").unwrap()); - // assert_eq!(bytes, data); + let bytes = test::read_body(res).await; - // // nonexistent index file - // let request = srv.get().uri(srv.url("/test/unknown")).finish().unwrap(); - // let response = srv.execute(request.send()).unwrap(); - // assert_eq!(response.status(), StatusCode::NOT_FOUND); + let data = Bytes::from(fs::read("Cargo.toml").unwrap()); + assert_eq!(bytes, data); - // let request = srv.get().uri(srv.url("/test/unknown/")).finish().unwrap(); - // let response = srv.execute(request.send()).unwrap(); - // assert_eq!(response.status(), StatusCode::NOT_FOUND); - // } + let req = TestRequest::get().uri("/test/").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::OK); - // #[actix_rt::test] - // fn integration_percent_encoded() { - // let mut srv = test::TestServer::with_factory(|| { - // App::new().handler( - // "test", - // Files::new(".").index_file("Cargo.toml"), - // ) - // }); + let bytes = test::read_body(res).await; + let data = Bytes::from(fs::read("Cargo.toml").unwrap()); + assert_eq!(bytes, data); - // let request = srv - // .get() - // .uri(srv.url("/test/%43argo.toml")) - // .finish() - // .unwrap(); - // let response = srv.execute(request.send()).unwrap(); - // assert_eq!(response.status(), StatusCode::OK); - // } + // nonexistent index file + let req = TestRequest::get().uri("/test/unknown").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let req = TestRequest::get().uri("/test/unknown/").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn integration_percent_encoded() { + let srv = test::init_service( + App::new().service(Files::new("test", ".").index_file("Cargo.toml")), + ) + .await; + + let req = TestRequest::get().uri("/test/%43argo.toml").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::OK); + } } diff --git a/actix-files/src/named.rs b/actix-files/src/named.rs index a9b95bad1..a688b2e6c 100644 --- a/actix-files/src/named.rs +++ b/actix-files/src/named.rs @@ -11,15 +11,13 @@ use actix_web::{ dev::{BodyEncoding, SizedStream}, http::{ header::{ - self, Charset, ContentDisposition, DispositionParam, DispositionType, - ExtendedValue, + self, Charset, ContentDisposition, DispositionParam, DispositionType, ExtendedValue, }, ContentEncoding, StatusCode, }, - Error, HttpMessage, HttpRequest, HttpResponse, Responder, + HttpMessage, HttpRequest, HttpResponse, Responder, }; use bitflags::bitflags; -use futures_util::future::{ready, Ready}; use mime_guess::from_path; use crate::ChunkedReadFile; @@ -277,37 +275,31 @@ impl NamedFile { } /// Creates an `HttpResponse` with file as a streaming body. - pub fn into_response(self, req: &HttpRequest) -> Result { + pub fn into_response(self, req: &HttpRequest) -> HttpResponse { if self.status_code != StatusCode::OK { let mut res = HttpResponse::build(self.status_code); if self.flags.contains(Flags::PREFER_UTF8) { let ct = equiv_utf8_text(self.content_type.clone()); - res.header(header::CONTENT_TYPE, ct.to_string()); + res.insert_header((header::CONTENT_TYPE, ct.to_string())); } else { - res.header(header::CONTENT_TYPE, self.content_type.to_string()); + res.insert_header((header::CONTENT_TYPE, self.content_type.to_string())); } if self.flags.contains(Flags::CONTENT_DISPOSITION) { - res.header( + res.insert_header(( header::CONTENT_DISPOSITION, self.content_disposition.to_string(), - ); + )); } if let Some(current_encoding) = self.encoding { res.encoding(current_encoding); } - let reader = ChunkedReadFile { - size: self.md.len(), - offset: 0, - file: Some(self.file), - fut: None, - counter: 0, - }; + let reader = ChunkedReadFile::new(self.md.len(), 0, self.file); - return Ok(res.streaming(reader)); + return res.streaming(reader); } let etag = if self.flags.contains(Flags::ETAG) { @@ -332,7 +324,7 @@ impl NamedFile { let t2: SystemTime = since.clone().into(); match (t1.duration_since(UNIX_EPOCH), t2.duration_since(UNIX_EPOCH)) { - (Ok(t1), Ok(t2)) => t1 > t2, + (Ok(t1), Ok(t2)) => t1.as_secs() > t2.as_secs(), _ => false, } } else { @@ -351,7 +343,7 @@ impl NamedFile { let t2: SystemTime = since.clone().into(); match (t1.duration_since(UNIX_EPOCH), t2.duration_since(UNIX_EPOCH)) { - (Ok(t1), Ok(t2)) => t1 <= t2, + (Ok(t1), Ok(t2)) => t1.as_secs() <= t2.as_secs(), _ => false, } } else { @@ -362,16 +354,16 @@ impl NamedFile { if self.flags.contains(Flags::PREFER_UTF8) { let ct = equiv_utf8_text(self.content_type.clone()); - resp.header(header::CONTENT_TYPE, ct.to_string()); + resp.insert_header((header::CONTENT_TYPE, ct.to_string())); } else { - resp.header(header::CONTENT_TYPE, self.content_type.to_string()); + resp.insert_header((header::CONTENT_TYPE, self.content_type.to_string())); } if self.flags.contains(Flags::CONTENT_DISPOSITION) { - resp.header( + resp.insert_header(( header::CONTENT_DISPOSITION, self.content_disposition.to_string(), - ); + )); } // default compressing @@ -380,14 +372,14 @@ impl NamedFile { } if let Some(lm) = last_modified { - resp.header(header::LAST_MODIFIED, lm.to_string()); + resp.insert_header((header::LAST_MODIFIED, lm.to_string())); } if let Some(etag) = etag { - resp.header(header::ETAG, etag.to_string()); + resp.insert_header((header::ETAG, etag.to_string())); } - resp.header(header::ACCEPT_RANGES, "bytes"); + resp.insert_header((header::ACCEPT_RANGES, "bytes")); let mut length = self.md.len(); let mut offset = 0; @@ -400,43 +392,32 @@ impl NamedFile { offset = ranges[0].start; resp.encoding(ContentEncoding::Identity); - resp.header( + resp.insert_header(( header::CONTENT_RANGE, - format!( - "bytes {}-{}/{}", - offset, - offset + length - 1, - self.md.len() - ), - ); + format!("bytes {}-{}/{}", offset, offset + length - 1, self.md.len()), + )); } else { - resp.header(header::CONTENT_RANGE, format!("bytes */{}", length)); - return Ok(resp.status(StatusCode::RANGE_NOT_SATISFIABLE).finish()); + resp.insert_header((header::CONTENT_RANGE, format!("bytes */{}", length))); + return resp.status(StatusCode::RANGE_NOT_SATISFIABLE).finish(); }; } else { - return Ok(resp.status(StatusCode::BAD_REQUEST).finish()); + return resp.status(StatusCode::BAD_REQUEST).finish(); }; }; if precondition_failed { - return Ok(resp.status(StatusCode::PRECONDITION_FAILED).finish()); + return resp.status(StatusCode::PRECONDITION_FAILED).finish(); } else if not_modified { - return Ok(resp.status(StatusCode::NOT_MODIFIED).finish()); + return resp.status(StatusCode::NOT_MODIFIED).finish(); } - let reader = ChunkedReadFile { - offset, - size: length, - file: Some(self.file), - fut: None, - counter: 0, - }; + let reader = ChunkedReadFile::new(length, offset, self.file); if offset != 0 || length != self.md.len() { resp.status(StatusCode::PARTIAL_CONTENT); } - Ok(resp.body(SizedStream::new(length, reader))) + resp.body(SizedStream::new(length, reader)) } } @@ -495,10 +476,7 @@ fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { } impl Responder for NamedFile { - type Error = Error; - type Future = Ready>; - - fn respond_to(self, req: &HttpRequest) -> Self::Future { - ready(self.into_response(req)) + fn respond_to(self, req: &HttpRequest) -> HttpResponse { + self.into_response(req) } } diff --git a/actix-files/src/range.rs b/actix-files/src/range.rs index e420ce414..8d9fe9445 100644 --- a/actix-files/src/range.rs +++ b/actix-files/src/range.rs @@ -10,9 +10,6 @@ pub struct HttpRange { pub length: u64, } -const PREFIX: &str = "bytes="; -const PREFIX_LEN: usize = 6; - #[derive(Debug, Clone, Display, Error)] #[display(fmt = "Parse HTTP Range failed")] pub struct ParseRangeErr(#[error(not(source))] ()); @@ -23,84 +20,16 @@ impl HttpRange { /// `header` is HTTP Range header (e.g. `bytes=bytes=0-9`). /// `size` is full size of response (file). pub fn parse(header: &str, size: u64) -> Result, ParseRangeErr> { - if header.is_empty() { - return Ok(Vec::new()); + match http_range::HttpRange::parse(header, size) { + Ok(ranges) => Ok(ranges + .iter() + .map(|range| HttpRange { + start: range.start, + length: range.length, + }) + .collect()), + Err(_) => Err(ParseRangeErr(())), } - if !header.starts_with(PREFIX) { - return Err(ParseRangeErr(())); - } - - let size_sig = size as i64; - let mut no_overlap = false; - - let all_ranges: Vec> = header[PREFIX_LEN..] - .split(',') - .map(|x| x.trim()) - .filter(|x| !x.is_empty()) - .map(|ra| { - let mut start_end_iter = ra.split('-'); - - let start_str = start_end_iter.next().ok_or(ParseRangeErr(()))?.trim(); - let end_str = start_end_iter.next().ok_or(ParseRangeErr(()))?.trim(); - - if start_str.is_empty() { - // If no start is specified, end specifies the - // range start relative to the end of the file. - let mut length: i64 = - end_str.parse().map_err(|_| ParseRangeErr(()))?; - - if length > size_sig { - length = size_sig; - } - - Ok(Some(HttpRange { - start: (size_sig - length) as u64, - length: length as u64, - })) - } else { - let start: i64 = start_str.parse().map_err(|_| ParseRangeErr(()))?; - - if start < 0 { - return Err(ParseRangeErr(())); - } - if start >= size_sig { - no_overlap = true; - return Ok(None); - } - - let length = if end_str.is_empty() { - // If no end is specified, range extends to end of the file. - size_sig - start - } else { - let mut end: i64 = - end_str.parse().map_err(|_| ParseRangeErr(()))?; - - if start > end { - return Err(ParseRangeErr(())); - } - - if end >= size_sig { - end = size_sig - 1; - } - - end - start + 1 - }; - - Ok(Some(HttpRange { - start: start as u64, - length: length as u64, - })) - } - }) - .collect::>()?; - - let ranges: Vec = all_ranges.into_iter().filter_map(|x| x).collect(); - - if no_overlap && ranges.is_empty() { - return Err(ParseRangeErr(())); - } - - Ok(ranges) } } diff --git a/actix-files/src/service.rs b/actix-files/src/service.rs index dc4f2bd2c..14eea6ebc 100644 --- a/actix-files/src/service.rs +++ b/actix-files/src/service.rs @@ -1,9 +1,4 @@ -use std::{ - fmt, io, - path::PathBuf, - rc::Rc, - task::{Context, Poll}, -}; +use std::{fmt, io, path::PathBuf, rc::Rc, task::Poll}; use actix_service::Service; use actix_web::{ @@ -16,8 +11,8 @@ use actix_web::{ use futures_util::future::{ok, Either, LocalBoxFuture, Ready}; use crate::{ - named, Directory, DirectoryRenderer, FilesError, HttpService, MimeOverride, - NamedFile, PathBufWrap, + named, Directory, DirectoryRenderer, FilesError, HttpService, MimeOverride, NamedFile, + PathBufWrap, }; /// Assembled file serving service. @@ -40,10 +35,10 @@ type FilesServiceFuture = Either< >; impl FilesService { - fn handle_err(&mut self, e: io::Error, req: ServiceRequest) -> FilesServiceFuture { + fn handle_err(&self, e: io::Error, req: ServiceRequest) -> FilesServiceFuture { log::debug!("Failed to handle {}: {}", req.path(), e); - if let Some(ref mut default) = self.default { + if let Some(ref default) = self.default { Either::Right(default.call(req)) } else { Either::Left(ok(req.error_response(e))) @@ -57,17 +52,14 @@ impl fmt::Debug for FilesService { } } -impl Service for FilesService { - type Request = ServiceRequest; +impl Service for FilesService { type Response = ServiceResponse; type Error = Error; type Future = FilesServiceFuture; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } + actix_service::always_ready!(); - fn call(&mut self, req: ServiceRequest) -> Self::Future { + fn call(&self, req: ServiceRequest) -> Self::Future { let is_method_valid = if let Some(guard) = &self.guards { // execute user defined guards (**guard).check(req.head()) @@ -79,7 +71,7 @@ impl Service for FilesService { if !is_method_valid { return Either::Left(ok(req.into_response( actix_web::HttpResponse::MethodNotAllowed() - .header(header::CONTENT_TYPE, "text/plain") + .insert_header(header::ContentType(mime::TEXT_PLAIN_UTF_8)) .body("Request did not meet this resource's requirements."), ))); } @@ -103,7 +95,7 @@ impl Service for FilesService { return Either::Left(ok(req.into_response( HttpResponse::Found() - .header(header::LOCATION, redirect_to) + .insert_header((header::LOCATION, redirect_to)) .body("") .into_body(), ))); @@ -121,10 +113,8 @@ impl Service for FilesService { named_file.flags = self.file_flags; let (req, _) = req.into_parts(); - Either::Left(ok(match named_file.into_response(&req) { - Ok(item) => ServiceResponse::new(req, item), - Err(e) => ServiceResponse::from_err(e, req), - })) + let res = named_file.into_response(&req); + Either::Left(ok(ServiceResponse::new(req, res))) } Err(e) => self.handle_err(e, req), } @@ -148,19 +138,14 @@ impl Service for FilesService { match NamedFile::open(path) { Ok(mut named_file) => { if let Some(ref mime_override) = self.mime_override { - let new_disposition = - mime_override(&named_file.content_type.type_()); + let new_disposition = mime_override(&named_file.content_type.type_()); named_file.content_disposition.disposition = new_disposition; } named_file.flags = self.file_flags; let (req, _) = req.into_parts(); - match named_file.into_response(&req) { - Ok(item) => { - Either::Left(ok(ServiceResponse::new(req.clone(), item))) - } - Err(e) => Either::Left(ok(ServiceResponse::from_err(e, req))), - } + let res = named_file.into_response(&req); + Either::Left(ok(ServiceResponse::new(req, res))) } Err(e) => self.handle_err(e, req), } diff --git a/actix-files/tests/encoding.rs b/actix-files/tests/encoding.rs index d7e01b305..d21d4f8fd 100644 --- a/actix-files/tests/encoding.rs +++ b/actix-files/tests/encoding.rs @@ -11,11 +11,10 @@ use actix_web::{ #[actix_rt::test] async fn test_utf8_file_contents() { // use default ISO-8859-1 encoding - let mut srv = - test::init_service(App::new().service(Files::new("/", "./tests"))).await; + let srv = test::init_service(App::new().service(Files::new("/", "./tests"))).await; let req = TestRequest::with_uri("/utf8.txt").to_request(); - let res = test::call_service(&mut srv, req).await; + let res = test::call_service(&srv, req).await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( @@ -24,13 +23,12 @@ async fn test_utf8_file_contents() { ); // prefer UTF-8 encoding - let mut srv = test::init_service( - App::new().service(Files::new("/", "./tests").prefer_utf8(true)), - ) - .await; + let srv = + test::init_service(App::new().service(Files::new("/", "./tests").prefer_utf8(true))) + .await; let req = TestRequest::with_uri("/utf8.txt").to_request(); - let res = test::call_service(&mut srv, req).await; + let res = test::call_service(&srv, req).await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( diff --git a/actix-http-test/CHANGES.md b/actix-http-test/CHANGES.md index 835b75ddc..2f47d700d 100644 --- a/actix-http-test/CHANGES.md +++ b/actix-http-test/CHANGES.md @@ -1,6 +1,16 @@ # Changes -## Unreleased - 2020-xx-xx +## Unreleased - 2021-xx-xx + + +## 3.0.0-beta.2 - 2021-02-10 +* No notable changes. + + +## 3.0.0-beta.1 - 2021-01-07 +* Update `bytes` to `1.0`. [#1813] + +[#1813]: https://github.com/actix/actix-web/pull/1813 ## 2.1.0 - 2020-11-25 diff --git a/actix-http-test/Cargo.toml b/actix-http-test/Cargo.toml index 8b23bef1c..8ec073661 100644 --- a/actix-http-test/Cargo.toml +++ b/actix-http-test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-http-test" -version = "2.1.0" +version = "3.0.0-beta.2" authors = ["Nikolay Kim "] description = "Various helpers for Actix applications to use during testing" readme = "README.md" @@ -26,31 +26,36 @@ path = "src/lib.rs" default = [] # openssl -openssl = ["open-ssl", "awc/openssl"] +openssl = ["tls-openssl", "awc/openssl"] [dependencies] -actix-service = "1.0.6" -actix-codec = "0.3.0" -actix-connect = "2.0.0" -actix-utils = "2.0.0" -actix-rt = "1.1.1" -actix-server = "1.0.0" -actix-testing = "1.0.0" -awc = "2.0.0" +actix-service = "2.0.0-beta.4" +actix-codec = "0.4.0-beta.1" +actix-tls = "3.0.0-beta.3" +actix-utils = "3.0.0-beta.2" +actix-rt = "2" +actix-server = "2.0.0-beta.3" +awc = { version = "3.0.0-beta.2", default-features = false } base64 = "0.13" -bytes = "0.5.3" -futures-core = { version = "0.3.5", default-features = false } -http = "0.2.0" +bytes = "1" +futures-core = { version = "0.3.7", default-features = false } +http = "0.2.2" log = "0.4" socket2 = "0.3" serde = "1.0" serde_json = "1.0" slab = "0.4" serde_urlencoded = "0.7" -time = { version = "0.2.7", default-features = false, features = ["std"] } -open-ssl = { version = "0.10", package = "openssl", optional = true } +time = { version = "0.2.23", default-features = false, features = ["std"] } +tls-openssl = { version = "0.10.9", package = "openssl", optional = true } + +[target.'cfg(windows)'.dependencies.tls-openssl] +version = "0.10.9" +package = "openssl" +features = ["vendored"] +optional = true [dev-dependencies] -actix-web = "3.0.0" -actix-http = "2.0.0" +actix-web = { version = "4.0.0-beta.3", default-features = false, features = ["cookies"] } +actix-http = "3.0.0-beta.3" diff --git a/actix-http-test/README.md b/actix-http-test/README.md index bca9a7976..66f15979d 100644 --- a/actix-http-test/README.md +++ b/actix-http-test/README.md @@ -4,7 +4,7 @@ [![crates.io](https://img.shields.io/crates/v/actix-http-test?label=latest)](https://crates.io/crates/actix-http-test) [![Documentation](https://docs.rs/actix-http-test/badge.svg?version=2.1.0)](https://docs.rs/actix-http-test/2.1.0) -![Apache 2.0 or MIT licensed](https://img.shields.io/crates/l/actix-http-test) +![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-http-test) [![Dependency Status](https://deps.rs/crate/actix-http-test/2.1.0/status.svg)](https://deps.rs/crate/actix-http-test/2.1.0) [![Join the chat at https://gitter.im/actix/actix-web](https://badges.gitter.im/actix/actix-web.svg)](https://gitter.im/actix/actix-web?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) diff --git a/actix-http-test/src/lib.rs b/actix-http-test/src/lib.rs index 4d4bd6f2d..336dfe312 100644 --- a/actix-http-test/src/lib.rs +++ b/actix-http-test/src/lib.rs @@ -5,6 +5,9 @@ #![doc(html_favicon_url = "https://actix.rs/favicon.ico")] #![clippy::msrv = "1.46"] +#[cfg(feature = "openssl")] +extern crate tls_openssl as openssl; + use std::sync::mpsc; use std::{net, thread, time}; @@ -17,8 +20,6 @@ use futures_core::stream::Stream; use http::Method; use socket2::{Domain, Protocol, Socket, Type}; -pub use actix_testing::*; - /// Start test server /// /// `TestServer` is very simple test server that simplify process of writing @@ -63,16 +64,19 @@ pub async fn test_server_with_addr>( // run server in separate thread thread::spawn(move || { - let sys = System::new("actix-test-server"); + let sys = System::new(); let local_addr = tcp.local_addr().unwrap(); - Server::build() + let srv = Server::build() .listen("test", tcp, factory)? .workers(1) - .disable_signals() - .start(); + .disable_signals(); + + sys.block_on(async { + srv.run(); + tx.send((System::current(), local_addr)).unwrap(); + }); - tx.send((System::current(), local_addr)).unwrap(); sys.run() }); @@ -82,7 +86,7 @@ pub async fn test_server_with_addr>( let connector = { #[cfg(feature = "openssl")] { - use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); builder.set_verify(SslVerifyMode::NONE); @@ -93,20 +97,17 @@ pub async fn test_server_with_addr>( .conn_lifetime(time::Duration::from_secs(0)) .timeout(time::Duration::from_millis(30000)) .ssl(builder.build()) - .finish() } #[cfg(not(feature = "openssl"))] { Connector::new() .conn_lifetime(time::Duration::from_secs(0)) .timeout(time::Duration::from_millis(30000)) - .finish() } }; Client::builder().connector(connector).finish() }; - actix_connect::start_default_resolver().await.unwrap(); TestServer { addr, @@ -118,8 +119,7 @@ pub async fn test_server_with_addr>( /// Get first available unused address pub fn unused_addr() -> net::SocketAddr { let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); - let socket = - Socket::new(Domain::ipv4(), Type::stream(), Some(Protocol::tcp())).unwrap(); + let socket = Socket::new(Domain::ipv4(), Type::stream(), Some(Protocol::tcp())).unwrap(); socket.bind(&addr.into()).unwrap(); socket.set_reuse_address(true).unwrap(); let tcp = socket.into_tcp_listener(); @@ -148,7 +148,7 @@ impl TestServer { } } - /// Construct test https server url + /// Construct test HTTPS server URL. pub fn surl(&self, uri: &str) -> String { if uri.starts_with('/') { format!("https://localhost:{}{}", self.addr.port(), uri) @@ -162,7 +162,7 @@ impl TestServer { self.client.get(self.url(path.as_ref()).as_str()) } - /// Create https `GET` request + /// Create HTTPS `GET` request pub fn sget>(&self, path: S) -> ClientRequest { self.client.get(self.surl(path.as_ref()).as_str()) } @@ -172,7 +172,7 @@ impl TestServer { self.client.post(self.url(path.as_ref()).as_str()) } - /// Create https `POST` request + /// Create HTTPS `POST` request pub fn spost>(&self, path: S) -> ClientRequest { self.client.post(self.surl(path.as_ref()).as_str()) } @@ -182,7 +182,7 @@ impl TestServer { self.client.head(self.url(path.as_ref()).as_str()) } - /// Create https `HEAD` request + /// Create HTTPS `HEAD` request pub fn shead>(&self, path: S) -> ClientRequest { self.client.head(self.surl(path.as_ref()).as_str()) } @@ -192,7 +192,7 @@ impl TestServer { self.client.put(self.url(path.as_ref()).as_str()) } - /// Create https `PUT` request + /// Create HTTPS `PUT` request pub fn sput>(&self, path: S) -> ClientRequest { self.client.put(self.surl(path.as_ref()).as_str()) } @@ -202,7 +202,7 @@ impl TestServer { self.client.patch(self.url(path.as_ref()).as_str()) } - /// Create https `PATCH` request + /// Create HTTPS `PATCH` request pub fn spatch>(&self, path: S) -> ClientRequest { self.client.patch(self.surl(path.as_ref()).as_str()) } @@ -212,7 +212,7 @@ impl TestServer { self.client.delete(self.url(path.as_ref()).as_str()) } - /// Create https `DELETE` request + /// Create HTTPS `DELETE` request pub fn sdelete>(&self, path: S) -> ClientRequest { self.client.delete(self.surl(path.as_ref()).as_str()) } @@ -222,12 +222,12 @@ impl TestServer { self.client.options(self.url(path.as_ref()).as_str()) } - /// Create https `OPTIONS` request + /// Create HTTPS `OPTIONS` request pub fn soptions>(&self, path: S) -> ClientRequest { self.client.options(self.surl(path.as_ref()).as_str()) } - /// Connect to test http server + /// Connect to test HTTP server pub fn request>(&self, method: Method, path: S) -> ClientRequest { self.client.request(method, path.as_ref()) } @@ -242,26 +242,24 @@ impl TestServer { response.body().limit(10_485_760).await } - /// Connect to websocket server at a given path + /// Connect to WebSocket server at a given path. pub async fn ws_at( &mut self, path: &str, - ) -> Result, awc::error::WsClientError> - { + ) -> Result, awc::error::WsClientError> { let url = self.url(path); let connect = self.client.ws(url).connect(); connect.await.map(|(_, framed)| framed) } - /// Connect to a websocket server + /// Connect to a WebSocket server. pub async fn ws( &mut self, - ) -> Result, awc::error::WsClientError> - { + ) -> Result, awc::error::WsClientError> { self.ws_at("/").await } - /// Stop http server + /// Stop HTTP server fn stop(&mut self) { self.system.stop(); } diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index c43246bb0..6ba111eb3 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,14 +1,99 @@ # Changes -## Unreleased - 2020-xx-xx +## Unreleased - 2021-xx-xx ### Changed -* Bumped `rand` to `0.8` +* Feature `cookies` is now optional and disabled by default. [#1981] + +### Removed +* re-export of `futures_channel::oneshot::Canceled` is removed from `error` mod. [#1994] +* `ResponseError` impl for `futures_channel::oneshot::Canceled` is removed. [#1994] + +[#1981]: https://github.com/actix/actix-web/pull/1981 +[#1994]: https://github.com/actix/actix-web/pull/1994 + + +## 3.0.0-beta.3 - 2021-02-10 +* No notable changes. + + +## 3.0.0-beta.2 - 2021-02-10 +### Added +* `IntoHeaderPair` trait that allows using typed and untyped headers in the same methods. [#1869] +* `ResponseBuilder::insert_header` method which allows using typed headers. [#1869] +* `ResponseBuilder::append_header` method which allows using typed headers. [#1869] +* `TestRequest::insert_header` method which allows using typed headers. [#1869] +* `ContentEncoding` implements all necessary header traits. [#1912] +* `HeaderMap::len_keys` has the behavior of the old `len` method. [#1964] +* `HeaderMap::drain` as an efficient draining iterator. [#1964] +* Implement `IntoIterator` for owned `HeaderMap`. [#1964] +* `trust-dns` optional feature to enable `trust-dns-resolver` as client dns resolver. [#1969] + +### Changed +* `ResponseBuilder::content_type` now takes an `impl IntoHeaderValue` to support using typed + `mime` types. [#1894] +* Renamed `IntoHeaderValue::{try_into => try_into_value}` to avoid ambiguity with std + `TryInto` trait. [#1894] +* `Extensions::insert` returns Option of replaced item. [#1904] +* Remove `HttpResponseBuilder::json2()`. [#1903] +* Enable `HttpResponseBuilder::json()` to receive data by value and reference. [#1903] +* `client::error::ConnectError` Resolver variant contains `Box` type. [#1905] +* `client::ConnectorConfig` default timeout changed to 5 seconds. [#1905] +* Simplify `BlockingError` type to a unit struct. It's now only triggered when blocking thread pool + is dead. [#1957] +* `HeaderMap::len` now returns number of values instead of number of keys. [#1964] +* `HeaderMap::insert` now returns iterator of removed values. [#1964] +* `HeaderMap::remove` now returns iterator of removed values. [#1964] + +### Removed +* `ResponseBuilder::set`; use `ResponseBuilder::insert_header`. [#1869] +* `ResponseBuilder::set_header`; use `ResponseBuilder::insert_header`. [#1869] +* `ResponseBuilder::header`; use `ResponseBuilder::append_header`. [#1869] +* `TestRequest::with_hdr`; use `TestRequest::default().insert_header()`. [#1869] +* `TestRequest::with_header`; use `TestRequest::default().insert_header()`. [#1869] +* `actors` optional feature. [#1969] +* `ResponseError` impl for `actix::MailboxError`. [#1969] + +### Documentation +* Vastly improve docs and add examples for `HeaderMap`. [#1964] + +[#1869]: https://github.com/actix/actix-web/pull/1869 +[#1894]: https://github.com/actix/actix-web/pull/1894 +[#1903]: https://github.com/actix/actix-web/pull/1903 +[#1904]: https://github.com/actix/actix-web/pull/1904 +[#1905]: https://github.com/actix/actix-web/pull/1905 +[#1912]: https://github.com/actix/actix-web/pull/1912 +[#1957]: https://github.com/actix/actix-web/pull/1957 +[#1964]: https://github.com/actix/actix-web/pull/1964 +[#1969]: https://github.com/actix/actix-web/pull/1969 + + +## 3.0.0-beta.1 - 2021-01-07 +### Added +* Add `Http3` to `Protocol` enum for future compatibility and also mark `#[non_exhaustive]`. + +### Changed +* Update `actix-*` dependencies to tokio `1.0` based versions. [#1813] +* Bumped `rand` to `0.8`. +* Update `bytes` to `1.0`. [#1813] +* Update `h2` to `0.3`. [#1813] +* The `ws::Message::Text` enum variant now contains a `bytestring::ByteString`. [#1864] ### Removed * Deprecated `on_connect` methods have been removed. Prefer the new `on_connect_ext` technique. [#1857] +* Remove `ResponseError` impl for `actix::actors::resolver::ResolverError` + due to deprecate of resolver actor. [#1813] +* Remove `ConnectError::SslHandshakeError` and re-export of `HandshakeError`. + due to the removal of this type from `tokio-openssl` crate. openssl handshake + error would return as `ConnectError::SslError`. [#1813] +* Remove `actix-threadpool` dependency. Use `actix_rt::task::spawn_blocking`. + Due to this change `actix_threadpool::BlockingError` type is moved into + `actix_http::error` module. [#1878] +[#1813]: https://github.com/actix/actix-web/pull/1813 [#1857]: https://github.com/actix/actix-web/pull/1857 +[#1864]: https://github.com/actix/actix-web/pull/1864 +[#1878]: https://github.com/actix/actix-web/pull/1878 ## 2.2.0 - 2020-11-25 @@ -55,15 +140,14 @@ * Update actix-connect and actix-tls dependencies. -## [2.0.0-beta.3] - 2020-08-14 - +## 2.0.0-beta.3 - 2020-08-14 ### Fixed * Memory leak of `client::pool::ConnectorPoolSupport`. [#1626] [#1626]: https://github.com/actix/actix-web/pull/1626 -## [2.0.0-beta.2] - 2020-07-21 +## 2.0.0-beta.2 - 2020-07-21 ### Fixed * Potential UB in h1 decoder using uninitialized memory. [#1614] @@ -74,10 +158,8 @@ [#1615]: https://github.com/actix/actix-web/pull/1615 -## [2.0.0-beta.1] - 2020-07-11 - +## 2.0.0-beta.1 - 2020-07-11 ### Changed - * Migrate cookie handling to `cookie` crate. [#1558] * Update `sha-1` to 0.9. [#1586] * Fix leak in client pool. [#1580] @@ -87,33 +169,30 @@ [#1586]: https://github.com/actix/actix-web/pull/1586 [#1580]: https://github.com/actix/actix-web/pull/1580 -## [2.0.0-alpha.4] - 2020-05-21 +## 2.0.0-alpha.4 - 2020-05-21 ### Changed - * Bump minimum supported Rust version to 1.40 -* content_length function is removed, and you can set Content-Length by calling no_chunking function [#1439] +* content_length function is removed, and you can set Content-Length by calling + no_chunking function [#1439] * `BodySize::Sized64` variant has been removed. `BodySize::Sized` now receives a `u64` instead of a `usize`. * Update `base64` dependency to 0.12 ### Fixed - * Support parsing of `SameSite=None` [#1503] [#1439]: https://github.com/actix/actix-web/pull/1439 [#1503]: https://github.com/actix/actix-web/pull/1503 -## [2.0.0-alpha.3] - 2020-05-08 +## 2.0.0-alpha.3 - 2020-05-08 ### Fixed - * Correct spelling of ConnectError::Unresolved [#1487] * Fix a mistake in the encoding of websocket continuation messages wherein Item::FirstText and Item::FirstBinary are each encoded as the other. ### Changed - * Implement `std::error::Error` for our custom errors [#1422] * Remove `failure` support for `ResponseError` since that crate will be deprecated in the near future. @@ -121,338 +200,247 @@ [#1422]: https://github.com/actix/actix-web/pull/1422 [#1487]: https://github.com/actix/actix-web/pull/1487 -## [2.0.0-alpha.2] - 2020-03-07 +## 2.0.0-alpha.2 - 2020-03-07 ### Changed - * Update `actix-connect` and `actix-tls` dependency to 2.0.0-alpha.1. [#1395] - -* Change default initial window size and connection window size for HTTP2 to 2MB and 1MB respectively - to improve download speed for awc when downloading large objects. [#1394] - -* client::Connector accepts initial_window_size and initial_connection_window_size HTTP2 configuration. [#1394] - +* Change default initial window size and connection window size for HTTP2 to 2MB and 1MB + respectively to improve download speed for awc when downloading large objects. [#1394] +* client::Connector accepts initial_window_size and initial_connection_window_size + HTTP2 configuration. [#1394] * client::Connector allowing to set max_http_version to limit HTTP version to be used. [#1394] [#1394]: https://github.com/actix/actix-web/pull/1394 [#1395]: https://github.com/actix/actix-web/pull/1395 -## [2.0.0-alpha.1] - 2020-02-27 +## 2.0.0-alpha.1 - 2020-02-27 ### Changed - * Update the `time` dependency to 0.2.7. * Moved actors messages support from actix crate, enabled with feature `actors`. -* Breaking change: trait MessageBody requires Unpin and accepting Pin<&mut Self> instead of &mut self in the poll_next(). +* Breaking change: trait MessageBody requires Unpin and accepting `Pin<&mut Self>` instead of + `&mut self` in the poll_next(). * MessageBody is not implemented for &'static [u8] anymore. ### Fixed - * Allow `SameSite=None` cookies to be sent in a response. -## [1.0.1] - 2019-12-20 +## 1.0.1 - 2019-12-20 ### Fixed - * Poll upgrade service's readiness from HTTP service handlers - * Replace brotli with brotli2 #1224 -## [1.0.0] - 2019-12-13 +## 1.0.0 - 2019-12-13 ### Added - * Add websockets continuation frame support ### Changed - * Replace `flate2-xxx` features with `compress` -## [1.0.0-alpha.5] - 2019-12-09 +## 1.0.0-alpha.5 - 2019-12-09 ### Fixed - * Check `Upgrade` service readiness before calling it - -* Fix buffer remaining capacity calcualtion +* Fix buffer remaining capacity calculation ### Changed - * Websockets: Ping and Pong should have binary data #1049 -## [1.0.0-alpha.4] - 2019-12-08 +## 1.0.0-alpha.4 - 2019-12-08 ### Added - * Add impl ResponseBuilder for Error ### Changed - * Use rust based brotli compression library -## [1.0.0-alpha.3] - 2019-12-07 - +## 1.0.0-alpha.3 - 2019-12-07 ### Changed - * Migrate to tokio 0.2 - * Migrate to `std::future` -## [0.2.11] - 2019-11-06 - +## 0.2.11 - 2019-11-06 ### Added - * Add support for serde_json::Value to be passed as argument to ResponseBuilder.body() - -* Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151) - +* Add an additional `filename*` param in the `Content-Disposition` header of + `actix_files::NamedFile` to be more compatible. (#1151) * Allow to use `std::convert::Infallible` as `actix_http::error::Error` ### Fixed +* To be compatible with non-English error responses, `ResponseError` rendered with `text/plain; + charset=utf-8` header [#1118] -* To be compatible with non-English error responses, `ResponseError` rendered with `text/plain; charset=utf-8` header #1118 +[#1878]: https://github.com/actix/actix-web/pull/1878 -## [0.2.10] - 2019-09-11 - +## 0.2.10 - 2019-09-11 ### Added - -* Add support for sending HTTP requests with `Rc` in addition to sending HTTP requests with `RequestHead` +* Add support for sending HTTP requests with `Rc` in addition to sending HTTP requests + with `RequestHead` ### Fixed - * h2 will use error response #1080 - * on_connect result isn't added to request extensions for http2 requests #1009 -## [0.2.9] - 2019-08-13 - +## 0.2.9 - 2019-08-13 ### Changed - * Dropped the `byteorder`-dependency in favor of `stdlib`-implementation - * Update percent-encoding to 2.1 - * Update serde_urlencoded to 0.6.1 ### Fixed - * Fixed a panic in the HTTP2 handshake in client HTTP requests (#1031) -## [0.2.8] - 2019-08-01 - +## 0.2.8 - 2019-08-01 ### Added - * Add `rustls` support - * Add `Clone` impl for `HeaderMap` ### Fixed - * awc client panic #1016 - -* Invalid response with compression middleware enabled, but compression-related features disabled #997 +* Invalid response with compression middleware enabled, but compression-related features + disabled #997 -## [0.2.7] - 2019-07-18 - +## 0.2.7 - 2019-07-18 ### Added - * Add support for downcasting response errors #986 -## [0.2.6] - 2019-07-17 - +## 0.2.6 - 2019-07-17 ### Changed - * Replace `ClonableService` with local copy - * Upgrade `rand` dependency version to 0.7 -## [0.2.5] - 2019-06-28 - +## 0.2.5 - 2019-06-28 ### Added - * Add `on-connect` callback, `HttpServiceBuilder::on_connect()` #946 ### Changed - * Use `encoding_rs` crate instead of unmaintained `encoding` crate - * Add `Copy` and `Clone` impls for `ws::Codec` -## [0.2.4] - 2019-06-16 - +## 0.2.4 - 2019-06-16 ### Fixed - * Do not compress NoContent (204) responses #918 -## [0.2.3] - 2019-06-02 - +## 0.2.3 - 2019-06-02 ### Added - * Debug impl for ResponseBuilder - * From SizedStream and BodyStream for Body ### Changed - * SizedStream uses u64 -## [0.2.2] - 2019-05-29 - +## 0.2.2 - 2019-05-29 ### Fixed - * Parse incoming stream before closing stream on disconnect #868 -## [0.2.1] - 2019-05-25 - +## 0.2.1 - 2019-05-25 ### Fixed - * Handle socket read disconnect -## [0.2.0] - 2019-05-12 - +## 0.2.0 - 2019-05-12 ### Changed - * Update actix-service to 0.4 - * Expect and upgrade services accept `ServerConfig` config. ### Deleted - * `OneRequest` service -## [0.1.5] - 2019-05-04 - +## 0.1.5 - 2019-05-04 ### Fixed - * Clean up response extensions in response pool #817 -## [0.1.4] - 2019-04-24 - +## 0.1.4 - 2019-04-24 ### Added - * Allow to render h1 request headers in `Camel-Case` ### Fixed - * Read until eof for http/1.0 responses #771 -## [0.1.3] - 2019-04-23 - +## 0.1.3 - 2019-04-23 ### Fixed - * Fix http client pool management - * Fix http client wait queue management #794 -## [0.1.2] - 2019-04-23 - +## 0.1.2 - 2019-04-23 ### Fixed - * Fix BorrowMutError panic in client connector #793 -## [0.1.1] - 2019-04-19 - +## 0.1.1 - 2019-04-19 ### Changed - * Cookie::max_age() accepts value in seconds - * Cookie::max_age_time() accepts value in time::Duration - * Allow to specify server address for client connector -## [0.1.0] - 2019-04-16 - +## 0.1.0 - 2019-04-16 ### Added - * Expose peer addr via `Request::peer_addr()` and `RequestHead::peer_addr` ### Changed - * `actix_http::encoding` always available - * use trust-dns-resolver 0.11.0 -## [0.1.0-alpha.5] - 2019-04-12 - +## 0.1.0-alpha.5 - 2019-04-12 ### Added - * Allow to use custom service for upgrade requests - * Added `h1::SendResponse` future. ### Changed - * MessageBody::length() renamed to MessageBody::size() for consistency - * ws handshake verification functions take RequestHead instead of Request -## [0.1.0-alpha.4] - 2019-04-08 - +## 0.1.0-alpha.4 - 2019-04-08 ### Added - * Allow to use custom `Expect` handler - * Add minimal `std::error::Error` impl for `Error` ### Changed - * Export IntoHeaderValue - * Render error and return as response body - -* Use thread pool for response body comression +* Use thread pool for response body compression ### Deleted - * Removed PayloadBuffer -## [0.1.0-alpha.3] - 2019-04-02 - +## 0.1.0-alpha.3 - 2019-04-02 ### Added - * Warn when an unsealed private cookie isn't valid UTF-8 ### Fixed - * Rust 1.31.0 compatibility - * Preallocate read buffer for h1 codec - * Detect socket disconnection during protocol selection -## [0.1.0-alpha.2] - 2019-03-29 - +## 0.1.0-alpha.2 - 2019-03-29 ### Added - * Added ws::Message::Nop, no-op websockets message ### Changed - -* Do not use thread pool for decomression if chunk size is smaller than 2048. +* Do not use thread pool for decompression if chunk size is smaller than 2048. -## [0.1.0-alpha.1] - 2019-03-28 - +## 0.1.0-alpha.1 - 2019-03-28 * Initial impl diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index 7cf344487..78fb55079 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-http" -version = "2.2.0" +version = "3.0.0-beta.3" authors = ["Nikolay Kim "] description = "HTTP primitives for the Actix ecosystem" readme = "README.md" @@ -15,7 +15,8 @@ license = "MIT OR Apache-2.0" edition = "2018" [package.metadata.docs.rs] -features = ["openssl", "rustls", "compress", "secure-cookies", "actors"] +# features that docs.rs will build with +features = ["openssl", "rustls", "compress", "cookies", "secure-cookies"] [lib] name = "actix_http" @@ -25,49 +26,47 @@ path = "src/lib.rs" default = [] # openssl -openssl = ["actix-tls/openssl", "actix-connect/openssl"] +openssl = ["actix-tls/openssl"] # rustls support -rustls = ["actix-tls/rustls", "actix-connect/rustls"] +rustls = ["actix-tls/rustls"] -# enable compressison support +# enable compression support compress = ["flate2", "brotli2"] -# support for secure cookies -secure-cookies = ["cookie/secure"] +# support for cookies +cookies = ["cookie"] -# support for actix Actor messages -actors = ["actix"] +# support for secure cookies +secure-cookies = ["cookies", "cookie/secure"] + +# trust-dns as client dns resolver +trust-dns = ["trust-dns-resolver"] [dependencies] -actix-service = "1.0.6" -actix-codec = "0.3.0" -actix-connect = "2.0.0" -actix-utils = "2.0.0" -actix-rt = "1.0.0" -actix-threadpool = "0.3.1" -actix-tls = { version = "2.0.0", optional = true } -actix = { version = "0.10.0", optional = true } +actix-service = "2.0.0-beta.4" +actix-codec = "0.4.0-beta.1" +actix-utils = "3.0.0-beta.2" +actix-rt = "2" +actix-tls = "3.0.0-beta.2" +ahash = "0.7" base64 = "0.13" bitflags = "1.2" -bytes = "0.5.3" -cookie = { version = "0.14.1", features = ["percent-encode"] } -copyless = "0.1.4" -derive_more = "0.99.2" -either = "1.5.3" +bytes = "1" +bytestring = "1" +cfg-if = "1" +cookie = { version = "0.14.1", features = ["percent-encode"], optional = true } +derive_more = "0.99.5" encoding_rs = "0.8" -futures-channel = { version = "0.3.5", default-features = false } -futures-core = { version = "0.3.5", default-features = false } -futures-util = { version = "0.3.5", default-features = false } -fxhash = "0.2.1" -h2 = "0.2.1" -http = "0.2.0" +futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } +futures-util = { version = "0.3.7", default-features = false, features = ["alloc", "sink"] } +h2 = "0.3.0" +http = "0.2.2" httparse = "1.3" -indexmap = "1.3" itoa = "0.4" -lazy_static = "1.4" language-tags = "0.2" +lazy_static = "1.4" log = "0.4" mime = "0.3" percent-encoding = "2.1" @@ -76,28 +75,36 @@ rand = "0.8" regex = "1.3" serde = "1.0" serde_json = "1.0" -sha-1 = "0.9" -slab = "0.4" serde_urlencoded = "0.7" -time = { version = "0.2.7", default-features = false, features = ["std"] } +sha-1 = "0.9" +smallvec = "1.6" +time = { version = "0.2.23", default-features = false, features = ["std"] } +tokio = { version = "1.2", features = ["sync"] } # compression brotli2 = { version="0.3.2", optional = true } flate2 = { version = "1.0.13", optional = true } +trust-dns-resolver = { version = "0.20.0", optional = true } + [dev-dependencies] -actix-server = "1.0.1" -actix-connect = { version = "2.0.0", features = ["openssl"] } -actix-http-test = { version = "2.0.0", features = ["openssl"] } -actix-tls = { version = "2.0.0", features = ["openssl"] } +actix-server = "2.0.0-beta.3" +actix-http-test = { version = "3.0.0-beta.2", features = ["openssl"] } +actix-tls = { version = "3.0.0-beta.2", features = ["openssl"] } criterion = "0.3" -env_logger = "0.7" +env_logger = "0.8" +rcgen = "0.8" serde_derive = "1.0" -open-ssl = { version="0.10", package = "openssl" } -rust-tls = { version="0.18", package = "rustls" } +tls-openssl = { version = "0.10", package = "openssl" } +tls-rustls = { version = "0.19", package = "rustls" } + +[target.'cfg(windows)'.dev-dependencies.tls-openssl] +version = "0.10.9" +package = "openssl" +features = ["vendored"] [[bench]] -name = "content-length" +name = "write-camel-case" harness = false [[bench]] diff --git a/actix-http/README.md b/actix-http/README.md index 9dfb85e24..881fbc8c5 100644 --- a/actix-http/README.md +++ b/actix-http/README.md @@ -3,10 +3,13 @@ > HTTP primitives for the Actix ecosystem. [![crates.io](https://img.shields.io/crates/v/actix-http?label=latest)](https://crates.io/crates/actix-http) -[![Documentation](https://docs.rs/actix-http/badge.svg?version=2.2.0)](https://docs.rs/actix-http/2.2.0) -![Apache 2.0 or MIT licensed](https://img.shields.io/crates/l/actix-http) -[![Dependency Status](https://deps.rs/crate/actix-http/2.2.0/status.svg)](https://deps.rs/crate/actix-http/2.2.0) -[![Join the chat at https://gitter.im/actix/actix-web](https://badges.gitter.im/actix/actix-web.svg)](https://gitter.im/actix/actix-web?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Documentation](https://docs.rs/actix-http/badge.svg?version=3.0.0-beta.3)](https://docs.rs/actix-http/3.0.0-beta.3) +[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-http.svg) +
+[![dependency status](https://deps.rs/crate/actix-http/3.0.0-beta.3/status.svg)](https://deps.rs/crate/actix-http/3.0.0-beta.3) +[![Download](https://img.shields.io/crates/d/actix-http.svg)](https://crates.io/crates/actix-http) +[![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) ## Documentation & Resources diff --git a/actix-http/benches/content-length.rs b/actix-http/benches/content-length.rs deleted file mode 100644 index 18a55a33b..000000000 --- a/actix-http/benches/content-length.rs +++ /dev/null @@ -1,291 +0,0 @@ -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; - -use bytes::BytesMut; - -// benchmark sending all requests at the same time -fn bench_write_content_length(c: &mut Criterion) { - let mut group = c.benchmark_group("write_content_length"); - - let sizes = [ - 0, 1, 11, 83, 101, 653, 1001, 6323, 10001, 56329, 100001, 123456, 98724245, - 4294967202, - ]; - - for i in sizes.iter() { - group.bench_with_input(BenchmarkId::new("Original (unsafe)", i), i, |b, &i| { - b.iter(|| { - let mut b = BytesMut::with_capacity(35); - _original::write_content_length(i, &mut b) - }) - }); - - group.bench_with_input(BenchmarkId::new("New (safe)", i), i, |b, &i| { - b.iter(|| { - let mut b = BytesMut::with_capacity(35); - _new::write_content_length(i, &mut b) - }) - }); - - group.bench_with_input(BenchmarkId::new("itoa", i), i, |b, &i| { - b.iter(|| { - let mut b = BytesMut::with_capacity(35); - _itoa::write_content_length(i, &mut b) - }) - }); - } - - group.finish(); -} - -criterion_group!(benches, bench_write_content_length); -criterion_main!(benches); - -mod _itoa { - use bytes::{BufMut, BytesMut}; - - pub fn write_content_length(n: usize, bytes: &mut BytesMut) { - if n == 0 { - bytes.put_slice(b"\r\ncontent-length: 0\r\n"); - return; - } - - let mut buf = itoa::Buffer::new(); - - bytes.put_slice(b"\r\ncontent-length: "); - bytes.put_slice(buf.format(n).as_bytes()); - bytes.put_slice(b"\r\n"); - } -} - -mod _new { - use bytes::{BufMut, BytesMut}; - - const DIGITS_START: u8 = b'0'; - - /// NOTE: bytes object has to contain enough space - pub fn write_content_length(n: usize, bytes: &mut BytesMut) { - if n == 0 { - bytes.put_slice(b"\r\ncontent-length: 0\r\n"); - return; - } - - bytes.put_slice(b"\r\ncontent-length: "); - - if n < 10 { - bytes.put_u8(DIGITS_START + (n as u8)); - } else if n < 100 { - let n = n as u8; - - let d10 = n / 10; - let d1 = n % 10; - - bytes.put_u8(DIGITS_START + d10); - bytes.put_u8(DIGITS_START + d1); - } else if n < 1000 { - let n = n as u16; - - let d100 = (n / 100) as u8; - let d10 = ((n / 10) % 10) as u8; - let d1 = (n % 10) as u8; - - bytes.put_u8(DIGITS_START + d100); - bytes.put_u8(DIGITS_START + d10); - bytes.put_u8(DIGITS_START + d1); - } else if n < 10_000 { - let n = n as u16; - - let d1000 = (n / 1000) as u8; - let d100 = ((n / 100) % 10) as u8; - let d10 = ((n / 10) % 10) as u8; - let d1 = (n % 10) as u8; - - bytes.put_u8(DIGITS_START + d1000); - bytes.put_u8(DIGITS_START + d100); - bytes.put_u8(DIGITS_START + d10); - bytes.put_u8(DIGITS_START + d1); - } else if n < 100_000 { - let n = n as u32; - - let d10000 = (n / 10000) as u8; - let d1000 = ((n / 1000) % 10) as u8; - let d100 = ((n / 100) % 10) as u8; - let d10 = ((n / 10) % 10) as u8; - let d1 = (n % 10) as u8; - - bytes.put_u8(DIGITS_START + d10000); - bytes.put_u8(DIGITS_START + d1000); - bytes.put_u8(DIGITS_START + d100); - bytes.put_u8(DIGITS_START + d10); - bytes.put_u8(DIGITS_START + d1); - } else if n < 1_000_000 { - let n = n as u32; - - let d100000 = (n / 100000) as u8; - let d10000 = ((n / 10000) % 10) as u8; - let d1000 = ((n / 1000) % 10) as u8; - let d100 = ((n / 100) % 10) as u8; - let d10 = ((n / 10) % 10) as u8; - let d1 = (n % 10) as u8; - - bytes.put_u8(DIGITS_START + d100000); - bytes.put_u8(DIGITS_START + d10000); - bytes.put_u8(DIGITS_START + d1000); - bytes.put_u8(DIGITS_START + d100); - bytes.put_u8(DIGITS_START + d10); - bytes.put_u8(DIGITS_START + d1); - } else { - write_usize(n, bytes); - } - - bytes.put_slice(b"\r\n"); - } - - fn write_usize(n: usize, bytes: &mut BytesMut) { - let mut n = n; - - // 20 chars is max length of a usize (2^64) - // digits will be added to the buffer from lsd to msd - let mut buf = BytesMut::with_capacity(20); - - while n > 9 { - // "pop" the least-significant digit - let lsd = (n % 10) as u8; - - // remove the lsd from n - n = n / 10; - - buf.put_u8(DIGITS_START + lsd); - } - - // put msd to result buffer - bytes.put_u8(DIGITS_START + (n as u8)); - - // put, in reverse (msd to lsd), remaining digits to buffer - for i in (0..buf.len()).rev() { - bytes.put_u8(buf[i]); - } - } -} - -mod _original { - use std::{mem, ptr, slice}; - - use bytes::{BufMut, BytesMut}; - - const DEC_DIGITS_LUT: &[u8] = b"0001020304050607080910111213141516171819\ - 2021222324252627282930313233343536373839\ - 4041424344454647484950515253545556575859\ - 6061626364656667686970717273747576777879\ - 8081828384858687888990919293949596979899"; - - /// NOTE: bytes object has to contain enough space - pub fn write_content_length(mut n: usize, bytes: &mut BytesMut) { - if n < 10 { - let mut buf: [u8; 21] = [ - b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', - b'e', b'n', b'g', b't', b'h', b':', b' ', b'0', b'\r', b'\n', - ]; - buf[18] = (n as u8) + b'0'; - bytes.put_slice(&buf); - } else if n < 100 { - let mut buf: [u8; 22] = [ - b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', - b'e', b'n', b'g', b't', b'h', b':', b' ', b'0', b'0', b'\r', b'\n', - ]; - let d1 = n << 1; - unsafe { - ptr::copy_nonoverlapping( - DEC_DIGITS_LUT.as_ptr().add(d1), - buf.as_mut_ptr().offset(18), - 2, - ); - } - bytes.put_slice(&buf); - } else if n < 1000 { - let mut buf: [u8; 23] = [ - b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', - b'e', b'n', b'g', b't', b'h', b':', b' ', b'0', b'0', b'0', b'\r', - b'\n', - ]; - // decode 2 more chars, if > 2 chars - let d1 = (n % 100) << 1; - n /= 100; - unsafe { - ptr::copy_nonoverlapping( - DEC_DIGITS_LUT.as_ptr().add(d1), - buf.as_mut_ptr().offset(19), - 2, - ) - }; - - // decode last 1 - buf[18] = (n as u8) + b'0'; - - bytes.put_slice(&buf); - } else { - bytes.put_slice(b"\r\ncontent-length: "); - convert_usize(n, bytes); - } - } - - pub(crate) fn convert_usize(mut n: usize, bytes: &mut BytesMut) { - let mut curr: isize = 39; - let mut buf: [u8; 41] = unsafe { mem::MaybeUninit::uninit().assume_init() }; - buf[39] = b'\r'; - buf[40] = b'\n'; - let buf_ptr = buf.as_mut_ptr(); - let lut_ptr = DEC_DIGITS_LUT.as_ptr(); - - // eagerly decode 4 characters at a time - while n >= 10_000 { - let rem = (n % 10_000) as isize; - n /= 10_000; - - let d1 = (rem / 100) << 1; - let d2 = (rem % 100) << 1; - curr -= 4; - unsafe { - ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); - ptr::copy_nonoverlapping( - lut_ptr.offset(d2), - buf_ptr.offset(curr + 2), - 2, - ); - } - } - - // if we reach here numbers are <= 9999, so at most 4 chars long - let mut n = n as isize; // possibly reduce 64bit math - - // decode 2 more chars, if > 2 chars - if n >= 100 { - let d1 = (n % 100) << 1; - n /= 100; - curr -= 2; - unsafe { - ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); - } - } - - // decode last 1 or 2 chars - if n < 10 { - curr -= 1; - unsafe { - *buf_ptr.offset(curr) = (n as u8) + b'0'; - } - } else { - let d1 = n << 1; - curr -= 2; - unsafe { - ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); - } - } - - unsafe { - bytes.extend_from_slice(slice::from_raw_parts( - buf_ptr.offset(curr), - 41 - curr as usize, - )); - } - } -} diff --git a/actix-http/benches/status-line.rs b/actix-http/benches/status-line.rs index 51f840f89..f62d18ed8 100644 --- a/actix-http/benches/status-line.rs +++ b/actix-http/benches/status-line.rs @@ -176,7 +176,7 @@ mod _original { buf[5] = b'0'; buf[7] = b'9'; } - _ => (), + _ => {} } let mut curr: isize = 12; diff --git a/actix-http/benches/write-camel-case.rs b/actix-http/benches/write-camel-case.rs new file mode 100644 index 000000000..fa4930eb9 --- /dev/null +++ b/actix-http/benches/write-camel-case.rs @@ -0,0 +1,89 @@ +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; + +fn bench_write_camel_case(c: &mut Criterion) { + let mut group = c.benchmark_group("write_camel_case"); + + let names = ["connection", "Transfer-Encoding", "transfer-encoding"]; + + for &i in &names { + let bts = i.as_bytes(); + + group.bench_with_input(BenchmarkId::new("Original", i), bts, |b, bts| { + b.iter(|| { + let mut buf = black_box([0; 24]); + _original::write_camel_case(black_box(bts), &mut buf) + }); + }); + + group.bench_with_input(BenchmarkId::new("New", i), bts, |b, bts| { + b.iter(|| { + let mut buf = black_box([0; 24]); + _new::write_camel_case(black_box(bts), &mut buf) + }); + }); + } + + group.finish(); +} + +criterion_group!(benches, bench_write_camel_case); +criterion_main!(benches); + +mod _new { + pub fn write_camel_case(value: &[u8], buffer: &mut [u8]) { + // first copy entire (potentially wrong) slice to output + buffer[..value.len()].copy_from_slice(value); + + let mut iter = value.iter(); + + // first character should be uppercase + if let Some(c @ b'a'..=b'z') = iter.next() { + buffer[0] = c & 0b1101_1111; + } + + // track 1 ahead of the current position since that's the location being assigned to + let mut index = 2; + + // remaining characters after hyphens should also be uppercase + while let Some(&c) = iter.next() { + if c == b'-' { + // advance iter by one and uppercase if needed + if let Some(c @ b'a'..=b'z') = iter.next() { + buffer[index] = c & 0b1101_1111; + } + } + + index += 1; + } + } +} + +mod _original { + pub fn write_camel_case(value: &[u8], buffer: &mut [u8]) { + let mut index = 0; + let key = value; + let mut key_iter = key.iter(); + + if let Some(c) = key_iter.next() { + if *c >= b'a' && *c <= b'z' { + buffer[index] = *c ^ b' '; + index += 1; + } + } else { + return; + } + + while let Some(c) = key_iter.next() { + buffer[index] = *c; + index += 1; + if *c == b'-' { + if let Some(c) = key_iter.next() { + if *c >= b'a' && *c <= b'z' { + buffer[index] = *c ^ b' '; + index += 1; + } + } + } + } + } +} diff --git a/actix-http/examples/echo.rs b/actix-http/examples/echo.rs index beb1cce2c..90d768cbe 100644 --- a/actix-http/examples/echo.rs +++ b/actix-http/examples/echo.rs @@ -26,7 +26,10 @@ async fn main() -> io::Result<()> { info!("request body: {:?}", body); Ok::<_, Error>( Response::Ok() - .header("x-head", HeaderValue::from_static("dummy value!")) + .insert_header(( + "x-head", + HeaderValue::from_static("dummy value!"), + )) .body(body), ) }) diff --git a/actix-http/examples/echo2.rs b/actix-http/examples/echo2.rs index 5b7e504d3..bc932ce8f 100644 --- a/actix-http/examples/echo2.rs +++ b/actix-http/examples/echo2.rs @@ -15,7 +15,7 @@ async fn handle_request(mut req: Request) -> Result { info!("request body: {:?}", body); Ok(Response::Ok() - .header("x-head", HeaderValue::from_static("dummy value!")) + .insert_header(("x-head", HeaderValue::from_static("dummy value!"))) .body(body)) } diff --git a/actix-http/examples/hello-world.rs b/actix-http/examples/hello-world.rs index d6477b152..a84e9aac6 100644 --- a/actix-http/examples/hello-world.rs +++ b/actix-http/examples/hello-world.rs @@ -19,7 +19,10 @@ async fn main() -> io::Result<()> { .finish(|_req| { info!("{:?}", _req); let mut res = Response::Ok(); - res.header("x-head", HeaderValue::from_static("dummy value!")); + res.insert_header(( + "x-head", + HeaderValue::from_static("dummy value!"), + )); future::ok::<_, ()>(res.body("Hello world!")) }) .tcp() diff --git a/actix-http/src/body.rs b/actix-http/src/body.rs deleted file mode 100644 index c5d831c45..000000000 --- a/actix-http/src/body.rs +++ /dev/null @@ -1,723 +0,0 @@ -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{fmt, mem}; - -use bytes::{Bytes, BytesMut}; -use futures_core::Stream; -use futures_util::ready; -use pin_project::pin_project; - -use crate::error::Error; - -#[derive(Debug, PartialEq, Copy, Clone)] -/// Body size hint -pub enum BodySize { - None, - Empty, - Sized(u64), - Stream, -} - -impl BodySize { - pub fn is_eof(&self) -> bool { - matches!(self, BodySize::None | BodySize::Empty | BodySize::Sized(0)) - } -} - -/// Type that provides this trait can be streamed to a peer. -pub trait MessageBody { - fn size(&self) -> BodySize; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>>; - - downcast_get_type_id!(); -} - -downcast!(MessageBody); - -impl MessageBody for () { - fn size(&self) -> BodySize { - BodySize::Empty - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - Poll::Ready(None) - } -} - -impl MessageBody for Box { - fn size(&self) -> BodySize { - self.as_ref().size() - } - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - Pin::new(self.get_mut().as_mut()).poll_next(cx) - } -} - -#[pin_project(project = ResponseBodyProj)] -pub enum ResponseBody { - Body(#[pin] B), - Other(#[pin] Body), -} - -impl ResponseBody { - pub fn into_body(self) -> ResponseBody { - match self { - ResponseBody::Body(b) => ResponseBody::Other(b), - ResponseBody::Other(b) => ResponseBody::Other(b), - } - } -} - -impl ResponseBody { - pub fn take_body(&mut self) -> ResponseBody { - std::mem::replace(self, ResponseBody::Other(Body::None)) - } -} - -impl ResponseBody { - pub fn as_ref(&self) -> Option<&B> { - if let ResponseBody::Body(ref b) = self { - Some(b) - } else { - None - } - } -} - -impl MessageBody for ResponseBody { - fn size(&self) -> BodySize { - match self { - ResponseBody::Body(ref body) => body.size(), - ResponseBody::Other(ref body) => body.size(), - } - } - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - match self.project() { - ResponseBodyProj::Body(body) => body.poll_next(cx), - ResponseBodyProj::Other(body) => body.poll_next(cx), - } - } -} - -impl Stream for ResponseBody { - type Item = Result; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.project() { - ResponseBodyProj::Body(body) => body.poll_next(cx), - ResponseBodyProj::Other(body) => body.poll_next(cx), - } - } -} - -#[pin_project(project = BodyProj)] -/// Represents various types of http message body. -pub enum Body { - /// Empty response. `Content-Length` header is not set. - None, - /// Zero sized response body. `Content-Length` header is set to `0`. - Empty, - /// Specific response body. - Bytes(Bytes), - /// Generic message body. - Message(Box), -} - -impl Body { - /// Create body from slice (copy) - pub fn from_slice(s: &[u8]) -> Body { - Body::Bytes(Bytes::copy_from_slice(s)) - } - - /// Create body from generic message body. - pub fn from_message(body: B) -> Body { - Body::Message(Box::new(body)) - } -} - -impl MessageBody for Body { - fn size(&self) -> BodySize { - match self { - Body::None => BodySize::None, - Body::Empty => BodySize::Empty, - Body::Bytes(ref bin) => BodySize::Sized(bin.len() as u64), - Body::Message(ref body) => body.size(), - } - } - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - match self.project() { - BodyProj::None => Poll::Ready(None), - BodyProj::Empty => Poll::Ready(None), - BodyProj::Bytes(ref mut bin) => { - let len = bin.len(); - if len == 0 { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(mem::take(bin)))) - } - } - BodyProj::Message(ref mut body) => Pin::new(body.as_mut()).poll_next(cx), - } - } -} - -impl PartialEq for Body { - fn eq(&self, other: &Body) -> bool { - match *self { - Body::None => matches!(*other, Body::None), - Body::Empty => matches!(*other, Body::Empty), - Body::Bytes(ref b) => match *other { - Body::Bytes(ref b2) => b == b2, - _ => false, - }, - Body::Message(_) => false, - } - } -} - -impl fmt::Debug for Body { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - Body::None => write!(f, "Body::None"), - Body::Empty => write!(f, "Body::Empty"), - Body::Bytes(ref b) => write!(f, "Body::Bytes({:?})", b), - Body::Message(_) => write!(f, "Body::Message(_)"), - } - } -} - -impl From<&'static str> for Body { - fn from(s: &'static str) -> Body { - Body::Bytes(Bytes::from_static(s.as_ref())) - } -} - -impl From<&'static [u8]> for Body { - fn from(s: &'static [u8]) -> Body { - Body::Bytes(Bytes::from_static(s)) - } -} - -impl From> for Body { - fn from(vec: Vec) -> Body { - Body::Bytes(Bytes::from(vec)) - } -} - -impl From for Body { - fn from(s: String) -> Body { - s.into_bytes().into() - } -} - -impl<'a> From<&'a String> for Body { - fn from(s: &'a String) -> Body { - Body::Bytes(Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(&s))) - } -} - -impl From for Body { - fn from(s: Bytes) -> Body { - Body::Bytes(s) - } -} - -impl From for Body { - fn from(s: BytesMut) -> Body { - Body::Bytes(s.freeze()) - } -} - -impl From for Body { - fn from(v: serde_json::Value) -> Body { - Body::Bytes(v.to_string().into()) - } -} - -impl From> for Body -where - S: Stream> + Unpin + 'static, -{ - fn from(s: SizedStream) -> Body { - Body::from_message(s) - } -} - -impl From> for Body -where - S: Stream> + Unpin + 'static, - E: Into + 'static, -{ - fn from(s: BodyStream) -> Body { - Body::from_message(s) - } -} - -impl MessageBody for Bytes { - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(mem::take(self.get_mut())))) - } - } -} - -impl MessageBody for BytesMut { - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(mem::take(self.get_mut()).freeze()))) - } - } -} - -impl MessageBody for &'static str { - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(Bytes::from_static( - mem::take(self.get_mut()).as_ref(), - )))) - } - } -} - -impl MessageBody for Vec { - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(Bytes::from(mem::take(self.get_mut()))))) - } - } -} - -impl MessageBody for String { - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(Bytes::from( - mem::take(self.get_mut()).into_bytes(), - )))) - } - } -} - -/// Type represent streaming body. -/// Response does not contain `content-length` header and appropriate transfer encoding is used. -#[pin_project] -pub struct BodyStream { - #[pin] - stream: S, - _t: PhantomData, -} - -impl BodyStream -where - S: Stream> + Unpin, - E: Into, -{ - pub fn new(stream: S) -> Self { - BodyStream { - stream, - _t: PhantomData, - } - } -} - -impl MessageBody for BodyStream -where - S: Stream> + Unpin, - E: Into, -{ - fn size(&self) -> BodySize { - BodySize::Stream - } - - /// Attempts to pull out the next value of the underlying [`Stream`]. - /// - /// Empty values are skipped to prevent [`BodyStream`]'s transmission being - /// ended on a zero-length chunk, but rather proceed until the underlying - /// [`Stream`] ends. - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - let mut stream = self.project().stream; - loop { - let stream = stream.as_mut(); - return Poll::Ready(match ready!(stream.poll_next(cx)) { - Some(Ok(ref bytes)) if bytes.is_empty() => continue, - opt => opt.map(|res| res.map_err(Into::into)), - }); - } - } -} - -/// Type represent streaming body. This body implementation should be used -/// if total size of stream is known. Data get sent as is without using transfer encoding. -#[pin_project] -pub struct SizedStream { - size: u64, - #[pin] - stream: S, -} - -impl SizedStream -where - S: Stream> + Unpin, -{ - pub fn new(size: u64, stream: S) -> Self { - SizedStream { size, stream } - } -} - -impl MessageBody for SizedStream -where - S: Stream> + Unpin, -{ - fn size(&self) -> BodySize { - BodySize::Sized(self.size as u64) - } - - /// Attempts to pull out the next value of the underlying [`Stream`]. - /// - /// Empty values are skipped to prevent [`SizedStream`]'s transmission being - /// ended on a zero-length chunk, but rather proceed until the underlying - /// [`Stream`] ends. - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - let mut stream: Pin<&mut S> = self.project().stream; - loop { - let stream = stream.as_mut(); - return Poll::Ready(match ready!(stream.poll_next(cx)) { - Some(Ok(ref bytes)) if bytes.is_empty() => continue, - val => val, - }); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use futures_util::future::poll_fn; - use futures_util::pin_mut; - use futures_util::stream; - - impl Body { - pub(crate) fn get_ref(&self) -> &[u8] { - match *self { - Body::Bytes(ref bin) => &bin, - _ => panic!(), - } - } - } - - impl ResponseBody { - pub(crate) fn get_ref(&self) -> &[u8] { - match *self { - ResponseBody::Body(ref b) => b.get_ref(), - ResponseBody::Other(ref b) => b.get_ref(), - } - } - } - - #[actix_rt::test] - async fn test_static_str() { - assert_eq!(Body::from("").size(), BodySize::Sized(0)); - assert_eq!(Body::from("test").size(), BodySize::Sized(4)); - assert_eq!(Body::from("test").get_ref(), b"test"); - - assert_eq!("test".size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| Pin::new(&mut "test").poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_static_bytes() { - assert_eq!(Body::from(b"test".as_ref()).size(), BodySize::Sized(4)); - assert_eq!(Body::from(b"test".as_ref()).get_ref(), b"test"); - assert_eq!( - Body::from_slice(b"test".as_ref()).size(), - BodySize::Sized(4) - ); - assert_eq!(Body::from_slice(b"test".as_ref()).get_ref(), b"test"); - let sb = Bytes::from(&b"test"[..]); - pin_mut!(sb); - - assert_eq!(sb.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| sb.as_mut().poll_next(cx)).await.unwrap().ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_vec() { - assert_eq!(Body::from(Vec::from("test")).size(), BodySize::Sized(4)); - assert_eq!(Body::from(Vec::from("test")).get_ref(), b"test"); - let test_vec = Vec::from("test"); - pin_mut!(test_vec); - - assert_eq!(test_vec.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| test_vec.as_mut().poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_bytes() { - let b = Bytes::from("test"); - assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); - assert_eq!(Body::from(b.clone()).get_ref(), b"test"); - pin_mut!(b); - - assert_eq!(b.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_bytes_mut() { - let b = BytesMut::from("test"); - assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); - assert_eq!(Body::from(b.clone()).get_ref(), b"test"); - pin_mut!(b); - - assert_eq!(b.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_string() { - let b = "test".to_owned(); - assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); - assert_eq!(Body::from(b.clone()).get_ref(), b"test"); - assert_eq!(Body::from(&b).size(), BodySize::Sized(4)); - assert_eq!(Body::from(&b).get_ref(), b"test"); - pin_mut!(b); - - assert_eq!(b.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_unit() { - assert_eq!(().size(), BodySize::Empty); - assert!(poll_fn(|cx| Pin::new(&mut ()).poll_next(cx)) - .await - .is_none()); - } - - #[actix_rt::test] - async fn test_box() { - let val = Box::new(()); - pin_mut!(val); - assert_eq!(val.size(), BodySize::Empty); - assert!(poll_fn(|cx| val.as_mut().poll_next(cx)).await.is_none()); - } - - #[actix_rt::test] - async fn test_body_eq() { - assert!( - Body::Bytes(Bytes::from_static(b"1")) - == Body::Bytes(Bytes::from_static(b"1")) - ); - assert!(Body::Bytes(Bytes::from_static(b"1")) != Body::None); - } - - #[actix_rt::test] - async fn test_body_debug() { - assert!(format!("{:?}", Body::None).contains("Body::None")); - assert!(format!("{:?}", Body::Empty).contains("Body::Empty")); - assert!(format!("{:?}", Body::Bytes(Bytes::from_static(b"1"))).contains('1')); - } - - #[actix_rt::test] - async fn test_serde_json() { - use serde_json::json; - assert_eq!( - Body::from(serde_json::Value::String("test".into())).size(), - BodySize::Sized(6) - ); - assert_eq!( - Body::from(json!({"test-key":"test-value"})).size(), - BodySize::Sized(25) - ); - } - - mod body_stream { - use super::*; - //use futures::task::noop_waker; - //use futures::stream::once; - - #[actix_rt::test] - async fn skips_empty_chunks() { - let body = BodyStream::new(stream::iter( - ["1", "", "2"] - .iter() - .map(|&v| Ok(Bytes::from(v)) as Result), - )); - pin_mut!(body); - - assert_eq!( - poll_fn(|cx| body.as_mut().poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("1")), - ); - assert_eq!( - poll_fn(|cx| body.as_mut().poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("2")), - ); - } - - /* Now it does not compile as it should - #[actix_rt::test] - async fn move_pinned_pointer() { - let (sender, receiver) = futures::channel::oneshot::channel(); - let mut body_stream = Ok(BodyStream::new(once(async { - let x = Box::new(0i32); - let y = &x; - receiver.await.unwrap(); - let _z = **y; - Ok::<_, ()>(Bytes::new()) - }))); - - let waker = noop_waker(); - let mut context = Context::from_waker(&waker); - pin_mut!(body_stream); - - let _ = body_stream.as_mut().unwrap().poll_next(&mut context); - sender.send(()).unwrap(); - let _ = std::mem::replace(&mut body_stream, Err([0; 32])).unwrap().poll_next(&mut context); - }*/ - } - - mod sized_stream { - use super::*; - - #[actix_rt::test] - async fn skips_empty_chunks() { - let body = SizedStream::new( - 2, - stream::iter(["1", "", "2"].iter().map(|&v| Ok(Bytes::from(v)))), - ); - pin_mut!(body); - assert_eq!( - poll_fn(|cx| body.as_mut().poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("1")), - ); - assert_eq!( - poll_fn(|cx| body.as_mut().poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("2")), - ); - } - } - - #[actix_rt::test] - async fn test_body_casting() { - let mut body = String::from("hello cast"); - let resp_body: &mut dyn MessageBody = &mut body; - let body = resp_body.downcast_ref::().unwrap(); - assert_eq!(body, "hello cast"); - let body = &mut resp_body.downcast_mut::().unwrap(); - body.push('!'); - let body = resp_body.downcast_ref::().unwrap(); - assert_eq!(body, "hello cast!"); - let not_body = resp_body.downcast_ref::<()>(); - assert!(not_body.is_none()); - } -} diff --git a/actix-http/src/body/body.rs b/actix-http/src/body/body.rs new file mode 100644 index 000000000..a3fd7d41c --- /dev/null +++ b/actix-http/src/body/body.rs @@ -0,0 +1,158 @@ +use std::{ + fmt, mem, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::{Bytes, BytesMut}; +use futures_core::Stream; + +use crate::error::Error; + +use super::{BodySize, BodyStream, MessageBody, SizedStream}; + +/// Represents various types of HTTP message body. +pub enum Body { + /// Empty response. `Content-Length` header is not set. + None, + /// Zero sized response body. `Content-Length` header is set to `0`. + Empty, + /// Specific response body. + Bytes(Bytes), + /// Generic message body. + Message(Box), +} + +impl Body { + /// Create body from slice (copy) + pub fn from_slice(s: &[u8]) -> Body { + Body::Bytes(Bytes::copy_from_slice(s)) + } + + /// Create body from generic message body. + pub fn from_message(body: B) -> Body { + Body::Message(Box::new(body)) + } +} + +impl MessageBody for Body { + fn size(&self) -> BodySize { + match self { + Body::None => BodySize::None, + Body::Empty => BodySize::Empty, + Body::Bytes(ref bin) => BodySize::Sized(bin.len() as u64), + Body::Message(ref body) => body.size(), + } + } + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.get_mut() { + Body::None => Poll::Ready(None), + Body::Empty => Poll::Ready(None), + Body::Bytes(ref mut bin) => { + let len = bin.len(); + if len == 0 { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(mem::take(bin)))) + } + } + Body::Message(body) => Pin::new(&mut **body).poll_next(cx), + } + } +} + +impl PartialEq for Body { + fn eq(&self, other: &Body) -> bool { + match *self { + Body::None => matches!(*other, Body::None), + Body::Empty => matches!(*other, Body::Empty), + Body::Bytes(ref b) => match *other { + Body::Bytes(ref b2) => b == b2, + _ => false, + }, + Body::Message(_) => false, + } + } +} + +impl fmt::Debug for Body { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Body::None => write!(f, "Body::None"), + Body::Empty => write!(f, "Body::Empty"), + Body::Bytes(ref b) => write!(f, "Body::Bytes({:?})", b), + Body::Message(_) => write!(f, "Body::Message(_)"), + } + } +} + +impl From<&'static str> for Body { + fn from(s: &'static str) -> Body { + Body::Bytes(Bytes::from_static(s.as_ref())) + } +} + +impl From<&'static [u8]> for Body { + fn from(s: &'static [u8]) -> Body { + Body::Bytes(Bytes::from_static(s)) + } +} + +impl From> for Body { + fn from(vec: Vec) -> Body { + Body::Bytes(Bytes::from(vec)) + } +} + +impl From for Body { + fn from(s: String) -> Body { + s.into_bytes().into() + } +} + +impl<'a> From<&'a String> for Body { + fn from(s: &'a String) -> Body { + Body::Bytes(Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(&s))) + } +} + +impl From for Body { + fn from(s: Bytes) -> Body { + Body::Bytes(s) + } +} + +impl From for Body { + fn from(s: BytesMut) -> Body { + Body::Bytes(s.freeze()) + } +} + +impl From for Body { + fn from(v: serde_json::Value) -> Body { + Body::Bytes(v.to_string().into()) + } +} + +impl From> for Body +where + S: Stream> + Unpin + 'static, +{ + fn from(s: SizedStream) -> Body { + Body::from_message(s) + } +} + +impl From> for Body +where + S: Stream> + Unpin + 'static, + E: Into + 'static, +{ + fn from(s: BodyStream) -> Body { + Body::from_message(s) + } +} diff --git a/actix-http/src/body/body_stream.rs b/actix-http/src/body/body_stream.rs new file mode 100644 index 000000000..60e33b161 --- /dev/null +++ b/actix-http/src/body/body_stream.rs @@ -0,0 +1,59 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures_core::{ready, Stream}; + +use crate::error::Error; + +use super::{BodySize, MessageBody}; + +/// Streaming response wrapper. +/// +/// Response does not contain `Content-Length` header and appropriate transfer encoding is used. +pub struct BodyStream { + stream: S, +} + +impl BodyStream +where + S: Stream> + Unpin, + E: Into, +{ + pub fn new(stream: S) -> Self { + BodyStream { stream } + } +} + +impl MessageBody for BodyStream +where + S: Stream> + Unpin, + E: Into, +{ + fn size(&self) -> BodySize { + BodySize::Stream + } + + /// Attempts to pull out the next value of the underlying [`Stream`]. + /// + /// Empty values are skipped to prevent [`BodyStream`]'s transmission being + /// ended on a zero-length chunk, but rather proceed until the underlying + /// [`Stream`] ends. + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + let stream = &mut self.as_mut().stream; + + let chunk = match ready!(Pin::new(stream).poll_next(cx)) { + Some(Ok(ref bytes)) if bytes.is_empty() => continue, + opt => opt.map(|res| res.map_err(Into::into)), + }; + + return Poll::Ready(chunk); + } + } +} diff --git a/actix-http/src/body/message_body.rs b/actix-http/src/body/message_body.rs new file mode 100644 index 000000000..012329146 --- /dev/null +++ b/actix-http/src/body/message_body.rs @@ -0,0 +1,142 @@ +//! [`MessageBody`] trait and foreign implementations. + +use std::{ + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::{Bytes, BytesMut}; + +use crate::error::Error; + +use super::BodySize; + +/// Type that implement this trait can be streamed to a peer. +pub trait MessageBody { + fn size(&self) -> BodySize; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>>; + + downcast_get_type_id!(); +} + +downcast!(MessageBody); + +impl MessageBody for () { + fn size(&self) -> BodySize { + BodySize::Empty + } + + fn poll_next( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll>> { + Poll::Ready(None) + } +} + +impl MessageBody for Box { + fn size(&self) -> BodySize { + self.as_ref().size() + } + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + Pin::new(self.get_mut().as_mut()).poll_next(cx) + } +} + +impl MessageBody for Bytes { + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + fn poll_next( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(mem::take(self.get_mut())))) + } + } +} + +impl MessageBody for BytesMut { + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + fn poll_next( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(mem::take(self.get_mut()).freeze()))) + } + } +} + +impl MessageBody for &'static str { + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + fn poll_next( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(Bytes::from_static( + mem::take(self.get_mut()).as_ref(), + )))) + } + } +} + +impl MessageBody for Vec { + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + fn poll_next( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(Bytes::from(mem::take(self.get_mut()))))) + } + } +} + +impl MessageBody for String { + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + fn poll_next( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(Bytes::from( + mem::take(self.get_mut()).into_bytes(), + )))) + } + } +} diff --git a/actix-http/src/body/mod.rs b/actix-http/src/body/mod.rs new file mode 100644 index 000000000..a4d6ba2b6 --- /dev/null +++ b/actix-http/src/body/mod.rs @@ -0,0 +1,252 @@ +//! Traits and structures to aid consuming and writing HTTP payloads. + +#[allow(clippy::module_inception)] +mod body; +mod body_stream; +mod message_body; +mod response_body; +mod size; +mod sized_stream; + +pub use self::body::Body; +pub use self::body_stream::BodyStream; +pub use self::message_body::MessageBody; +pub use self::response_body::ResponseBody; +pub use self::size::BodySize; +pub use self::sized_stream::SizedStream; + +#[cfg(test)] +mod tests { + use std::pin::Pin; + + use actix_rt::pin; + use bytes::{Bytes, BytesMut}; + use futures_util::{future::poll_fn, stream}; + + use super::*; + + impl Body { + pub(crate) fn get_ref(&self) -> &[u8] { + match *self { + Body::Bytes(ref bin) => &bin, + _ => panic!(), + } + } + } + + impl ResponseBody { + pub(crate) fn get_ref(&self) -> &[u8] { + match *self { + ResponseBody::Body(ref b) => b.get_ref(), + ResponseBody::Other(ref b) => b.get_ref(), + } + } + } + + #[actix_rt::test] + async fn test_static_str() { + assert_eq!(Body::from("").size(), BodySize::Sized(0)); + assert_eq!(Body::from("test").size(), BodySize::Sized(4)); + assert_eq!(Body::from("test").get_ref(), b"test"); + + assert_eq!("test".size(), BodySize::Sized(4)); + assert_eq!( + poll_fn(|cx| Pin::new(&mut "test").poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("test")) + ); + } + + #[actix_rt::test] + async fn test_static_bytes() { + assert_eq!(Body::from(b"test".as_ref()).size(), BodySize::Sized(4)); + assert_eq!(Body::from(b"test".as_ref()).get_ref(), b"test"); + assert_eq!( + Body::from_slice(b"test".as_ref()).size(), + BodySize::Sized(4) + ); + assert_eq!(Body::from_slice(b"test".as_ref()).get_ref(), b"test"); + let sb = Bytes::from(&b"test"[..]); + pin!(sb); + + assert_eq!(sb.size(), BodySize::Sized(4)); + assert_eq!( + poll_fn(|cx| sb.as_mut().poll_next(cx)).await.unwrap().ok(), + Some(Bytes::from("test")) + ); + } + + #[actix_rt::test] + async fn test_vec() { + assert_eq!(Body::from(Vec::from("test")).size(), BodySize::Sized(4)); + assert_eq!(Body::from(Vec::from("test")).get_ref(), b"test"); + let test_vec = Vec::from("test"); + pin!(test_vec); + + assert_eq!(test_vec.size(), BodySize::Sized(4)); + assert_eq!( + poll_fn(|cx| test_vec.as_mut().poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("test")) + ); + } + + #[actix_rt::test] + async fn test_bytes() { + let b = Bytes::from("test"); + assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); + assert_eq!(Body::from(b.clone()).get_ref(), b"test"); + pin!(b); + + assert_eq!(b.size(), BodySize::Sized(4)); + assert_eq!( + poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(), + Some(Bytes::from("test")) + ); + } + + #[actix_rt::test] + async fn test_bytes_mut() { + let b = BytesMut::from("test"); + assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); + assert_eq!(Body::from(b.clone()).get_ref(), b"test"); + pin!(b); + + assert_eq!(b.size(), BodySize::Sized(4)); + assert_eq!( + poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(), + Some(Bytes::from("test")) + ); + } + + #[actix_rt::test] + async fn test_string() { + let b = "test".to_owned(); + assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); + assert_eq!(Body::from(b.clone()).get_ref(), b"test"); + assert_eq!(Body::from(&b).size(), BodySize::Sized(4)); + assert_eq!(Body::from(&b).get_ref(), b"test"); + pin!(b); + + assert_eq!(b.size(), BodySize::Sized(4)); + assert_eq!( + poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(), + Some(Bytes::from("test")) + ); + } + + #[actix_rt::test] + async fn test_unit() { + assert_eq!(().size(), BodySize::Empty); + assert!(poll_fn(|cx| Pin::new(&mut ()).poll_next(cx)) + .await + .is_none()); + } + + #[actix_rt::test] + async fn test_box() { + let val = Box::new(()); + pin!(val); + assert_eq!(val.size(), BodySize::Empty); + assert!(poll_fn(|cx| val.as_mut().poll_next(cx)).await.is_none()); + } + + #[actix_rt::test] + async fn test_body_eq() { + assert!( + Body::Bytes(Bytes::from_static(b"1")) + == Body::Bytes(Bytes::from_static(b"1")) + ); + assert!(Body::Bytes(Bytes::from_static(b"1")) != Body::None); + } + + #[actix_rt::test] + async fn test_body_debug() { + assert!(format!("{:?}", Body::None).contains("Body::None")); + assert!(format!("{:?}", Body::Empty).contains("Body::Empty")); + assert!(format!("{:?}", Body::Bytes(Bytes::from_static(b"1"))).contains('1')); + } + + #[actix_rt::test] + async fn test_serde_json() { + use serde_json::json; + assert_eq!( + Body::from(serde_json::Value::String("test".into())).size(), + BodySize::Sized(6) + ); + assert_eq!( + Body::from(json!({"test-key":"test-value"})).size(), + BodySize::Sized(25) + ); + } + + #[actix_rt::test] + async fn body_stream_skips_empty_chunks() { + let body = BodyStream::new(stream::iter( + ["1", "", "2"] + .iter() + .map(|&v| Ok(Bytes::from(v)) as Result), + )); + pin!(body); + + assert_eq!( + poll_fn(|cx| body.as_mut().poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("1")), + ); + assert_eq!( + poll_fn(|cx| body.as_mut().poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("2")), + ); + } + + mod sized_stream { + use super::*; + + #[actix_rt::test] + async fn skips_empty_chunks() { + let body = SizedStream::new( + 2, + stream::iter(["1", "", "2"].iter().map(|&v| Ok(Bytes::from(v)))), + ); + pin!(body); + assert_eq!( + poll_fn(|cx| body.as_mut().poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("1")), + ); + assert_eq!( + poll_fn(|cx| body.as_mut().poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("2")), + ); + } + } + + #[actix_rt::test] + async fn test_body_casting() { + let mut body = String::from("hello cast"); + let resp_body: &mut dyn MessageBody = &mut body; + let body = resp_body.downcast_ref::().unwrap(); + assert_eq!(body, "hello cast"); + let body = &mut resp_body.downcast_mut::().unwrap(); + body.push('!'); + let body = resp_body.downcast_ref::().unwrap(); + assert_eq!(body, "hello cast!"); + let not_body = resp_body.downcast_ref::<()>(); + assert!(not_body.is_none()); + } +} diff --git a/actix-http/src/body/response_body.rs b/actix-http/src/body/response_body.rs new file mode 100644 index 000000000..97141e11e --- /dev/null +++ b/actix-http/src/body/response_body.rs @@ -0,0 +1,77 @@ +use std::{ + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures_core::Stream; +use pin_project::pin_project; + +use crate::error::Error; + +use super::{Body, BodySize, MessageBody}; + +#[pin_project(project = ResponseBodyProj)] +pub enum ResponseBody { + Body(#[pin] B), + Other(Body), +} + +impl ResponseBody { + pub fn into_body(self) -> ResponseBody { + match self { + ResponseBody::Body(b) => ResponseBody::Other(b), + ResponseBody::Other(b) => ResponseBody::Other(b), + } + } +} + +impl ResponseBody { + pub fn take_body(&mut self) -> ResponseBody { + mem::replace(self, ResponseBody::Other(Body::None)) + } +} + +impl ResponseBody { + pub fn as_ref(&self) -> Option<&B> { + if let ResponseBody::Body(ref b) = self { + Some(b) + } else { + None + } + } +} + +impl MessageBody for ResponseBody { + fn size(&self) -> BodySize { + match self { + ResponseBody::Body(ref body) => body.size(), + ResponseBody::Other(ref body) => body.size(), + } + } + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.project() { + ResponseBodyProj::Body(body) => body.poll_next(cx), + ResponseBodyProj::Other(body) => Pin::new(body).poll_next(cx), + } + } +} + +impl Stream for ResponseBody { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.project() { + ResponseBodyProj::Body(body) => body.poll_next(cx), + ResponseBodyProj::Other(body) => Pin::new(body).poll_next(cx), + } + } +} diff --git a/actix-http/src/body/size.rs b/actix-http/src/body/size.rs new file mode 100644 index 000000000..775d5b8f1 --- /dev/null +++ b/actix-http/src/body/size.rs @@ -0,0 +1,40 @@ +/// Body size hint. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum BodySize { + /// Absence of body can be assumed from method or status code. + /// + /// Will skip writing Content-Length header. + None, + + /// Zero size body. + /// + /// Will write `Content-Length: 0` header. + Empty, + + /// Known size body. + /// + /// Will write `Content-Length: N` header. `Sized(0)` is treated the same as `Empty`. + Sized(u64), + + /// Unknown size body. + /// + /// Will not write Content-Length header. Can be used with chunked Transfer-Encoding. + Stream, +} + +impl BodySize { + /// Returns true if size hint indicates no or empty body. + /// + /// ``` + /// # use actix_http::body::BodySize; + /// assert!(BodySize::None.is_eof()); + /// assert!(BodySize::Empty.is_eof()); + /// assert!(BodySize::Sized(0).is_eof()); + /// + /// assert!(!BodySize::Sized(64).is_eof()); + /// assert!(!BodySize::Stream.is_eof()); + /// ``` + pub fn is_eof(&self) -> bool { + matches!(self, BodySize::None | BodySize::Empty | BodySize::Sized(0)) + } +} diff --git a/actix-http/src/body/sized_stream.rs b/actix-http/src/body/sized_stream.rs new file mode 100644 index 000000000..af995a0fb --- /dev/null +++ b/actix-http/src/body/sized_stream.rs @@ -0,0 +1,59 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures_core::{ready, Stream}; + +use crate::error::Error; + +use super::{BodySize, MessageBody}; + +/// Known sized streaming response wrapper. +/// +/// This body implementation should be used if total size of stream is known. Data get sent as is +/// without using transfer encoding. +pub struct SizedStream { + size: u64, + stream: S, +} + +impl SizedStream +where + S: Stream> + Unpin, +{ + pub fn new(size: u64, stream: S) -> Self { + SizedStream { size, stream } + } +} + +impl MessageBody for SizedStream +where + S: Stream> + Unpin, +{ + fn size(&self) -> BodySize { + BodySize::Sized(self.size as u64) + } + + /// Attempts to pull out the next value of the underlying [`Stream`]. + /// + /// Empty values are skipped to prevent [`SizedStream`]'s transmission being + /// ended on a zero-length chunk, but rather proceed until the underlying + /// [`Stream`] ends. + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + let stream = &mut self.as_mut().stream; + + let chunk = match ready!(Pin::new(stream).poll_next(cx)) { + Some(Ok(ref bytes)) if bytes.is_empty() => continue, + val => val, + }; + + return Poll::Ready(chunk); + } + } +} diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index df1b332c1..fa430c4fe 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -19,7 +19,7 @@ use crate::{ConnectCallback, Extensions}; /// /// This type can be used to construct an instance of [`HttpService`] through a /// builder-like pattern. -pub struct HttpServiceBuilder> { +pub struct HttpServiceBuilder { keep_alive: KeepAlive, client_timeout: u64, client_disconnect: u64, @@ -28,15 +28,15 @@ pub struct HttpServiceBuilder> { expect: X, upgrade: Option, on_connect_ext: Option>>, - _t: PhantomData<(T, S)>, + _phantom: PhantomData, } -impl HttpServiceBuilder> +impl HttpServiceBuilder where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, - ::Future: 'static, + >::Future: 'static, { /// Create instance of `ServiceConfigBuilder` pub fn new() -> Self { @@ -49,25 +49,25 @@ where expect: ExpectHandler, upgrade: None, on_connect_ext: None, - _t: PhantomData, + _phantom: PhantomData, } } } impl HttpServiceBuilder where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, - ::Future: 'static, - X: ServiceFactory, + >::Future: 'static, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - ::Future: 'static, - U: ServiceFactory), Response = ()>, + >::Future: 'static, + U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U::Error: fmt::Display, U::InitError: fmt::Debug, - ::Future: 'static, + )>>::Future: 'static, { /// Set server keep-alive setting. /// @@ -123,11 +123,11 @@ where /// request will be forwarded to main service. pub fn expect(self, expect: F) -> HttpServiceBuilder where - F: IntoServiceFactory, - X1: ServiceFactory, + F: IntoServiceFactory, + X1: ServiceFactory, X1::Error: Into, X1::InitError: fmt::Debug, - ::Future: 'static, + >::Future: 'static, { HttpServiceBuilder { keep_alive: self.keep_alive, @@ -138,7 +138,7 @@ where expect: expect.into_factory(), upgrade: self.upgrade, on_connect_ext: self.on_connect_ext, - _t: PhantomData, + _phantom: PhantomData, } } @@ -148,15 +148,11 @@ where /// and this service get called with original request and framed object. pub fn upgrade(self, upgrade: F) -> HttpServiceBuilder where - F: IntoServiceFactory, - U1: ServiceFactory< - Config = (), - Request = (Request, Framed), - Response = (), - >, + F: IntoServiceFactory)>, + U1: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U1::Error: fmt::Display, U1::InitError: fmt::Debug, - ::Future: 'static, + )>>::Future: 'static, { HttpServiceBuilder { keep_alive: self.keep_alive, @@ -167,7 +163,7 @@ where expect: self.expect, upgrade: Some(upgrade.into_factory()), on_connect_ext: self.on_connect_ext, - _t: PhantomData, + _phantom: PhantomData, } } @@ -188,7 +184,7 @@ where pub fn h1(self, service: F) -> H1Service where B: MessageBody, - F: IntoServiceFactory, + F: IntoServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, @@ -211,11 +207,11 @@ where pub fn h2(self, service: F) -> H2Service where B: MessageBody + 'static, - F: IntoServiceFactory, + F: IntoServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, { let cfg = ServiceConfig::new( self.keep_alive, @@ -233,11 +229,11 @@ where pub fn finish(self, service: F) -> HttpService where B: MessageBody + 'static, - F: IntoServiceFactory, + F: IntoServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, { let cfg = ServiceConfig::new( self.keep_alive, diff --git a/actix-http/src/client/config.rs b/actix-http/src/client/config.rs index c86c697a2..fad902d04 100644 --- a/actix-http/src/client/config.rs +++ b/actix-http/src/client/config.rs @@ -1,8 +1,7 @@ use std::time::Duration; -// These values are taken from hyper/src/proto/h2/client.rs -const DEFAULT_H2_CONN_WINDOW: u32 = 1024 * 1024 * 2; // 2mb -const DEFAULT_H2_STREAM_WINDOW: u32 = 1024 * 1024; // 1mb +const DEFAULT_H2_CONN_WINDOW: u32 = 1024 * 1024 * 2; // 2MB +const DEFAULT_H2_STREAM_WINDOW: u32 = 1024 * 1024; // 1MB /// Connector configuration #[derive(Clone)] @@ -19,7 +18,7 @@ pub(crate) struct ConnectorConfig { impl Default for ConnectorConfig { fn default() -> Self { Self { - timeout: Duration::from_secs(1), + timeout: Duration::from_secs(5), conn_lifetime: Duration::from_secs(75), conn_keep_alive: Duration::from_secs(15), disconnect_timeout: Some(Duration::from_millis(3000)), diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs index ec86dabb0..97ecd0515 100644 --- a/actix-http/src/client/connection.rs +++ b/actix-http/src/client/connection.rs @@ -1,11 +1,12 @@ -use std::future::Future; +use std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{fmt, io, mem, time}; +use std::{fmt, io, time}; -use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use bytes::{Buf, Bytes}; -use futures_util::future::{err, Either, FutureExt, LocalBoxFuture, Ready}; +use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; +use actix_rt::task::JoinHandle; +use bytes::Bytes; +use futures_core::future::LocalBoxFuture; use h2::client::SendRequest; use pin_project::pin_project; @@ -15,33 +16,82 @@ use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::Payload; use super::error::SendRequestError; -use super::pool::{Acquired, Protocol}; +use super::pool::Acquired; use super::{h1proto, h2proto}; pub(crate) enum ConnectionType { H1(Io), - H2(SendRequest), + H2(H2Connection), +} + +/// `H2Connection` has two parts: `SendRequest` and `Connection`. +/// +/// `Connection` is spawned as an async task on runtime and `H2Connection` holds a handle for +/// this task. Therefore, it can wake up and quit the task when SendRequest is dropped. +pub(crate) struct H2Connection { + handle: JoinHandle<()>, + sender: SendRequest, +} + +impl H2Connection { + pub(crate) fn new( + sender: SendRequest, + connection: h2::client::Connection, + ) -> Self + where + Io: AsyncRead + AsyncWrite + Unpin + 'static, + { + let handle = actix_rt::spawn(async move { + let _ = connection.await; + }); + + Self { handle, sender } + } +} + +// cancel spawned connection task on drop. +impl Drop for H2Connection { + fn drop(&mut self) { + self.handle.abort(); + } +} + +// only expose sender type to public. +impl Deref for H2Connection { + type Target = SendRequest; + + fn deref(&self) -> &Self::Target { + &self.sender + } +} + +impl DerefMut for H2Connection { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.sender + } } pub trait Connection { type Io: AsyncRead + AsyncWrite + Unpin; - type Future: Future>; - - fn protocol(&self) -> Protocol; /// Send request and body - fn send_request>( + fn send_request( self, head: H, body: B, - ) -> Self::Future; - - type TunnelFuture: Future< - Output = Result<(ResponseHead, Framed), SendRequestError>, - >; + ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> + where + B: MessageBody + 'static, + H: Into + 'static; /// Send request, returns Response and Framed - fn open_tunnel>(self, head: H) -> Self::TunnelFuture; + fn open_tunnel + 'static>( + self, + head: H, + ) -> LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, + >; } pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { @@ -54,7 +104,10 @@ pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { #[doc(hidden)] /// HTTP client connection -pub struct IoConnection { +pub struct IoConnection +where + T: AsyncWrite + Unpin + 'static, +{ io: Option>, created: time::Instant, pool: Option>, @@ -62,7 +115,7 @@ pub struct IoConnection { impl fmt::Debug for IoConnection where - T: fmt::Debug, + T: AsyncWrite + Unpin + fmt::Debug + 'static, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.io { @@ -89,55 +142,36 @@ impl IoConnection { pub(crate) fn into_inner(self) -> (ConnectionType, time::Instant) { (self.io.unwrap(), self.created) } -} -impl Connection for IoConnection -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ - type Io = T; - type Future = - LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; - - fn protocol(&self) -> Protocol { - match self.io { - Some(ConnectionType::H1(_)) => Protocol::Http1, - Some(ConnectionType::H2(_)) => Protocol::Http2, - None => Protocol::Http1, - } + #[cfg(test)] + pub(crate) fn into_parts(self) -> (ConnectionType, time::Instant, Acquired) { + (self.io.unwrap(), self.created, self.pool.unwrap()) } - fn send_request>( + async fn send_request>( mut self, head: H, body: B, - ) -> Self::Future { + ) -> Result<(ResponseHead, Payload), SendRequestError> { match self.io.take().unwrap() { ConnectionType::H1(io) => { h1proto::send_request(io, head.into(), body, self.created, self.pool) - .boxed_local() + .await } ConnectionType::H2(io) => { h2proto::send_request(io, head.into(), body, self.created, self.pool) - .boxed_local() + .await } } } - type TunnelFuture = Either< - LocalBoxFuture< - 'static, - Result<(ResponseHead, Framed), SendRequestError>, - >, - Ready), SendRequestError>>, - >; - /// Send request, returns Response and Framed - fn open_tunnel>(mut self, head: H) -> Self::TunnelFuture { + async fn open_tunnel>( + mut self, + head: H, + ) -> Result<(ResponseHead, Framed), SendRequestError> { match self.io.take().unwrap() { - ConnectionType::H1(io) => { - Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local()) - } + ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await, ConnectionType::H2(io) => { if let Some(mut pool) = self.pool.take() { pool.release(IoConnection::new( @@ -146,65 +180,61 @@ where None, )); } - Either::Right(err(SendRequestError::TunnelNotSupported)) + Err(SendRequestError::TunnelNotSupported) } } } } #[allow(dead_code)] -pub(crate) enum EitherConnection { +pub(crate) enum EitherIoConnection +where + A: AsyncRead + AsyncWrite + Unpin + 'static, + B: AsyncRead + AsyncWrite + Unpin + 'static, +{ A(IoConnection), B(IoConnection), } -impl Connection for EitherConnection +impl Connection for EitherIoConnection where A: AsyncRead + AsyncWrite + Unpin + 'static, B: AsyncRead + AsyncWrite + Unpin + 'static, { type Io = EitherIo; - type Future = - LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; - fn protocol(&self) -> Protocol { - match self { - EitherConnection::A(con) => con.protocol(), - EitherConnection::B(con) => con.protocol(), - } - } - - fn send_request>( + fn send_request( self, head: H, body: RB, - ) -> Self::Future { + ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> + where + RB: MessageBody + 'static, + H: Into + 'static, + { match self { - EitherConnection::A(con) => con.send_request(head, body), - EitherConnection::B(con) => con.send_request(head, body), + EitherIoConnection::A(con) => Box::pin(con.send_request(head, body)), + EitherIoConnection::B(con) => Box::pin(con.send_request(head, body)), } } - type TunnelFuture = LocalBoxFuture< + /// Send request, returns Response and Framed + fn open_tunnel + 'static>( + self, + head: H, + ) -> LocalBoxFuture< 'static, Result<(ResponseHead, Framed), SendRequestError>, - >; - - /// Send request, returns Response and Framed - fn open_tunnel>(self, head: H) -> Self::TunnelFuture { + > { match self { - EitherConnection::A(con) => con - .open_tunnel(head) - .map(|res| { - res.map(|(head, framed)| (head, framed.into_map_io(EitherIo::A))) - }) - .boxed_local(), - EitherConnection::B(con) => con - .open_tunnel(head) - .map(|res| { - res.map(|(head, framed)| (head, framed.into_map_io(EitherIo::B))) - }) - .boxed_local(), + EitherIoConnection::A(con) => Box::pin(async { + let (head, framed) = con.open_tunnel(head).await?; + Ok((head, framed.into_map_io(EitherIo::A))) + }), + EitherIoConnection::B(con) => Box::pin(async { + let (head, framed) = con.open_tunnel(head).await?; + Ok((head, framed.into_map_io(EitherIo::B))) + }), } } } @@ -223,23 +253,13 @@ where fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { match self.project() { EitherIoProj::A(val) => val.poll_read(cx, buf), EitherIoProj::B(val) => val.poll_read(cx, buf), } } - - unsafe fn prepare_uninitialized_buffer( - &self, - buf: &mut [mem::MaybeUninit], - ) -> bool { - match self { - EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf), - EitherIo::B(ref val) => val.prepare_uninitialized_buffer(buf), - } - } } impl AsyncWrite for EitherIo @@ -274,18 +294,36 @@ where EitherIoProj::B(val) => val.poll_shutdown(cx), } } +} - fn poll_write_buf( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut U, - ) -> Poll> - where - Self: Sized, - { - match self.project() { - EitherIoProj::A(val) => val.poll_write_buf(cx, buf), - EitherIoProj::B(val) => val.poll_write_buf(cx, buf), - } +#[cfg(test)] +mod test { + use std::net; + + use actix_rt::net::TcpStream; + + use super::*; + + #[actix_rt::test] + async fn test_h2_connection_drop() { + let addr = "127.0.0.1:0".parse::().unwrap(); + let listener = net::TcpListener::bind(addr).unwrap(); + let local = listener.local_addr().unwrap(); + + std::thread::spawn(move || while listener.accept().is_ok() {}); + + let tcp = TcpStream::connect(local).await.unwrap(); + let (sender, connection) = h2::client::handshake(tcp).await.unwrap(); + let conn = H2Connection::new(sender.clone(), connection); + + assert!(sender.clone().ready().await.is_ok()); + assert!(h2::client::SendRequest::clone(&*conn).ready().await.is_ok()); + + drop(conn); + + match sender.ready().await { + Ok(_) => panic!("connection should be gone and can not be ready"), + Err(e) => assert!(e.is_io()), + }; } } diff --git a/actix-http/src/client/connector.rs b/actix-http/src/client/connector.rs index e1aed6382..8aa5b1319 100644 --- a/actix-http/src/client/connector.rs +++ b/actix-http/src/client/connector.rs @@ -1,27 +1,29 @@ use std::fmt; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; use actix_codec::{AsyncRead, AsyncWrite}; -use actix_connect::{ - default_connector, Connect as TcpConnect, Connection as TcpConnection, -}; use actix_rt::net::TcpStream; -use actix_service::{apply_fn, Service}; +use actix_service::{apply_fn, Service, ServiceExt}; +use actix_tls::connect::{ + new_connector, Connect as TcpConnect, Connection as TcpConnection, Resolver, +}; use actix_utils::timeout::{TimeoutError, TimeoutService}; use http::Uri; use super::config::ConnectorConfig; -use super::connection::Connection; +use super::connection::{Connection, EitherIoConnection}; use super::error::ConnectError; use super::pool::{ConnectionPool, Protocol}; use super::Connect; #[cfg(feature = "openssl")] -use actix_connect::ssl::openssl::SslConnector as OpensslConnector; - +use actix_tls::connect::ssl::openssl::SslConnector as OpensslConnector; #[cfg(feature = "rustls")] -use actix_connect::ssl::rustls::ClientConfig; +use actix_tls::connect::ssl::rustls::ClientConfig; #[cfg(feature = "rustls")] use std::sync::Arc; @@ -35,7 +37,8 @@ enum SslConnector { #[cfg(not(any(feature = "openssl", feature = "rustls")))] type SslConnector = (); -/// Manages http client network connectivity +/// Manages HTTP client network connectivity. +/// /// The `Connector` type uses a builder-like combinator pattern for service /// construction that finishes by calling the `.finish()` method. /// @@ -52,34 +55,34 @@ pub struct Connector { config: ConnectorConfig, #[allow(dead_code)] ssl: SslConnector, - _t: PhantomData, + _phantom: PhantomData, } -trait Io: AsyncRead + AsyncWrite + Unpin {} +pub trait Io: AsyncRead + AsyncWrite + Unpin {} impl Io for T {} impl Connector<(), ()> { #[allow(clippy::new_ret_no_self, clippy::let_unit_value)] pub fn new() -> Connector< impl Service< - Request = TcpConnect, + TcpConnect, Response = TcpConnection, - Error = actix_connect::ConnectError, + Error = actix_tls::connect::ConnectError, > + Clone, TcpStream, > { Connector { ssl: Self::build_ssl(vec![b"h2".to_vec(), b"http/1.1".to_vec()]), - connector: default_connector(), + connector: new_connector(resolver::resolver()), config: ConnectorConfig::default(), - _t: PhantomData, + _phantom: PhantomData, } } // Build Ssl connector with openssl, based on supplied alpn protocols #[cfg(feature = "openssl")] fn build_ssl(protocols: Vec>) -> SslConnector { - use actix_connect::ssl::openssl::SslMethod; + use actix_tls::connect::ssl::openssl::SslMethod; use bytes::{BufMut, BytesMut}; let mut alpn = BytesMut::with_capacity(20); @@ -100,9 +103,9 @@ impl Connector<(), ()> { fn build_ssl(protocols: Vec>) -> SslConnector { let mut config = ClientConfig::new(); config.set_protocols(&protocols); - config - .root_store - .add_server_trust_anchors(&actix_tls::rustls::TLS_SERVER_ROOTS); + config.root_store.add_server_trust_anchors( + &actix_tls::connect::ssl::rustls::TLS_SERVER_ROOTS, + ); SslConnector::Rustls(Arc::new(config)) } @@ -117,16 +120,16 @@ impl Connector { where U1: AsyncRead + AsyncWrite + Unpin + fmt::Debug, T1: Service< - Request = TcpConnect, + TcpConnect, Response = TcpConnection, - Error = actix_connect::ConnectError, + Error = actix_tls::connect::ConnectError, > + Clone, { Connector { connector, config: self.config, ssl: self.ssl, - _t: PhantomData, + _phantom: PhantomData, } } } @@ -135,9 +138,9 @@ impl Connector where U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, T: Service< - Request = TcpConnect, + TcpConnect, Response = TcpConnection, - Error = actix_connect::ConnectError, + Error = actix_tls::connect::ConnectError, > + Clone + 'static, { @@ -161,8 +164,9 @@ where self } - /// Maximum supported http major version - /// Supported versions http/1.1, http/2 + /// Maximum supported HTTP major version. + /// + /// Supported versions are HTTP/1.1 and HTTP/2. pub fn max_http_version(mut self, val: http::Version) -> Self { let versions = match val { http::Version::HTTP_11 => vec![b"http/1.1".to_vec()], @@ -241,38 +245,53 @@ where /// its combinator chain. pub fn finish( self, - ) -> impl Service - + Clone { + ) -> impl Service + Clone + { + let tcp_service = TimeoutService::new( + self.config.timeout, + apply_fn(self.connector.clone(), |msg: Connect, srv| { + srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) + }) + .map_err(ConnectError::from) + .map(|stream| (stream.into_parts().0, Protocol::Http1)), + ) + .map_err(|e| match e { + TimeoutError::Service(e) => e, + TimeoutError::Timeout => ConnectError::Timeout, + }); + #[cfg(not(any(feature = "openssl", feature = "rustls")))] { - let connector = TimeoutService::new( - self.config.timeout, - apply_fn(self.connector, |msg: Connect, srv| { - srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) - }) - .map_err(ConnectError::from) - .map(|stream| (stream.into_parts().0, Protocol::Http1)), - ) - .map_err(|e| match e { - TimeoutError::Service(e) => e, - TimeoutError::Timeout => ConnectError::Timeout, - }); + // A dummy service for annotate tls pool's type signature. + pub type DummyService = Box< + dyn Service< + Connect, + Response = (Box, Protocol), + Error = ConnectError, + Future = futures_core::future::LocalBoxFuture< + 'static, + Result<(Box, Protocol), ConnectError>, + >, + >, + >; - connect_impl::InnerConnector { + InnerConnector::<_, DummyService, _, Box> { tcp_pool: ConnectionPool::new( - connector, + tcp_service, self.config.no_disconnect_timeout(), ), + tls_pool: None, } } + #[cfg(any(feature = "openssl", feature = "rustls"))] { const H2: &[u8] = b"h2"; - #[cfg(feature = "openssl")] - use actix_connect::ssl::openssl::OpensslConnector; - #[cfg(feature = "rustls")] - use actix_connect::ssl::rustls::{RustlsConnector, Session}; use actix_service::{boxed::service, pipeline}; + #[cfg(feature = "openssl")] + use actix_tls::connect::ssl::openssl::OpensslConnector; + #[cfg(feature = "rustls")] + use actix_tls::connect::ssl::rustls::{RustlsConnector, Session}; let ssl_service = TimeoutService::new( self.config.timeout, @@ -327,221 +346,177 @@ where TimeoutError::Timeout => ConnectError::Timeout, }); - let tcp_service = TimeoutService::new( - self.config.timeout, - apply_fn(self.connector, |msg: Connect, srv| { - srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) - }) - .map_err(ConnectError::from) - .map(|stream| (stream.into_parts().0, Protocol::Http1)), - ) - .map_err(|e| match e { - TimeoutError::Service(e) => e, - TimeoutError::Timeout => ConnectError::Timeout, - }); - - connect_impl::InnerConnector { + InnerConnector { tcp_pool: ConnectionPool::new( tcp_service, self.config.no_disconnect_timeout(), ), - ssl_pool: ConnectionPool::new(ssl_service, self.config), + tls_pool: Some(ConnectionPool::new(ssl_service, self.config)), } } } } -#[cfg(not(any(feature = "openssl", feature = "rustls")))] -mod connect_impl { - use std::task::{Context, Poll}; +struct InnerConnector +where + S1: Service + 'static, + S2: Service + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, +{ + tcp_pool: ConnectionPool, + tls_pool: Option>, +} - use futures_util::future::{err, Either, Ready}; +impl Clone for InnerConnector +where + S1: Service + 'static, + S2: Service + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn clone(&self) -> Self { + InnerConnector { + tcp_pool: self.tcp_pool.clone(), + tls_pool: self.tls_pool.as_ref().cloned(), + } + } +} - use super::*; - use crate::client::connection::IoConnection; +impl Service for InnerConnector +where + S1: Service + 'static, + S2: Service + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Response = EitherIoConnection; + type Error = ConnectError; + type Future = InnerConnectorResponse; - pub(crate) struct InnerConnector - where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service - + 'static, - { - pub(crate) tcp_pool: ConnectionPool, + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.tcp_pool.poll_ready(cx) } - impl Clone for InnerConnector - where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service - + 'static, - { - fn clone(&self) -> Self { - InnerConnector { - tcp_pool: self.tcp_pool.clone(), + fn call(&self, req: Connect) -> Self::Future { + match req.uri.scheme_str() { + Some("https") | Some("wss") => match self.tls_pool { + None => InnerConnectorResponse::SslIsNotSupported, + Some(ref pool) => InnerConnectorResponse::Io2(pool.call(req)), + }, + _ => InnerConnectorResponse::Io1(self.tcp_pool.call(req)), + } + } +} + +#[pin_project::pin_project(project = InnerConnectorProj)] +enum InnerConnectorResponse +where + S1: Service + 'static, + S2: Service + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, +{ + Io1(#[pin] as Service>::Future), + Io2(#[pin] as Service>::Future), + SslIsNotSupported, +} + +impl Future for InnerConnectorResponse +where + S1: Service + 'static, + S2: Service + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Output = Result, ConnectError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project() { + InnerConnectorProj::Io1(fut) => fut.poll(cx).map_ok(EitherIoConnection::A), + InnerConnectorProj::Io2(fut) => fut.poll(cx).map_ok(EitherIoConnection::B), + InnerConnectorProj::SslIsNotSupported => { + Poll::Ready(Err(ConnectError::SslIsNotSupported)) } } } +} - impl Service for InnerConnector - where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service - + 'static, - { - type Request = Connect; - type Response = IoConnection; - type Error = ConnectError; - type Future = Either< - as Service>::Future, - Ready, ConnectError>>, - >; +#[cfg(not(feature = "trust-dns"))] +mod resolver { + use super::*; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.tcp_pool.poll_ready(cx) + pub(super) fn resolver() -> Resolver { + Resolver::Default + } +} + +#[cfg(feature = "trust-dns")] +mod resolver { + use std::{cell::RefCell, net::SocketAddr}; + + use actix_tls::connect::Resolve; + use futures_core::future::LocalBoxFuture; + use trust_dns_resolver::{ + config::{ResolverConfig, ResolverOpts}, + system_conf::read_system_conf, + TokioAsyncResolver, + }; + + use super::*; + + pub(super) fn resolver() -> Resolver { + // new type for impl Resolve trait for TokioAsyncResolver. + struct TrustDnsResolver(TokioAsyncResolver); + + impl Resolve for TrustDnsResolver { + fn lookup<'a>( + &'a self, + host: &'a str, + port: u16, + ) -> LocalBoxFuture<'a, Result, Box>> + { + Box::pin(async move { + let res = self + .0 + .lookup_ip(host) + .await? + .iter() + .map(|ip| SocketAddr::new(ip, port)) + .collect(); + Ok(res) + }) + } } - fn call(&mut self, req: Connect) -> Self::Future { - match req.uri.scheme_str() { - Some("https") | Some("wss") => { - Either::Right(err(ConnectError::SslIsNotSupported)) + // dns struct is cached in thread local. + // so new client constructor can reuse the existing dns resolver. + thread_local! { + static TRUST_DNS_RESOLVER: RefCell> = RefCell::new(None); + } + + // get from thread local or construct a new trust-dns resolver. + TRUST_DNS_RESOLVER.with(|local| { + let resolver = local.borrow().as_ref().map(Clone::clone); + match resolver { + Some(resolver) => resolver, + None => { + let (cfg, opts) = match read_system_conf() { + Ok((cfg, opts)) => (cfg, opts), + Err(e) => { + log::error!("TRust-DNS can not load system config: {}", e); + (ResolverConfig::default(), ResolverOpts::default()) + } + }; + + let resolver = TokioAsyncResolver::tokio(cfg, opts).unwrap(); + + // box trust dns resolver and put it in thread local. + let resolver = Resolver::new_custom(TrustDnsResolver(resolver)); + *local.borrow_mut() = Some(resolver.clone()); + resolver } - _ => Either::Left(self.tcp_pool.call(req)), } - } - } -} - -#[cfg(any(feature = "openssl", feature = "rustls"))] -mod connect_impl { - use std::future::Future; - use std::marker::PhantomData; - use std::pin::Pin; - use std::task::{Context, Poll}; - - use futures_core::ready; - use futures_util::future::Either; - - use super::*; - use crate::client::connection::EitherConnection; - - pub(crate) struct InnerConnector - where - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - T1: Service, - T2: Service, - { - pub(crate) tcp_pool: ConnectionPool, - pub(crate) ssl_pool: ConnectionPool, - } - - impl Clone for InnerConnector - where - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - T1: Service - + 'static, - T2: Service - + 'static, - { - fn clone(&self) -> Self { - InnerConnector { - tcp_pool: self.tcp_pool.clone(), - ssl_pool: self.ssl_pool.clone(), - } - } - } - - impl Service for InnerConnector - where - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - T1: Service - + 'static, - T2: Service - + 'static, - { - type Request = Connect; - type Response = EitherConnection; - type Error = ConnectError; - type Future = Either< - InnerConnectorResponseA, - InnerConnectorResponseB, - >; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.tcp_pool.poll_ready(cx) - } - - fn call(&mut self, req: Connect) -> Self::Future { - match req.uri.scheme_str() { - Some("https") | Some("wss") => Either::Right(InnerConnectorResponseB { - fut: self.ssl_pool.call(req), - _t: PhantomData, - }), - _ => Either::Left(InnerConnectorResponseA { - fut: self.tcp_pool.call(req), - _t: PhantomData, - }), - } - } - } - - #[pin_project::pin_project] - pub(crate) struct InnerConnectorResponseA - where - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service - + 'static, - { - #[pin] - fut: as Service>::Future, - _t: PhantomData, - } - - impl Future for InnerConnectorResponseA - where - T: Service - + 'static, - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - { - type Output = Result, ConnectError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Poll::Ready( - ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) - .map(EitherConnection::A), - ) - } - } - - #[pin_project::pin_project] - pub(crate) struct InnerConnectorResponseB - where - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service - + 'static, - { - #[pin] - fut: as Service>::Future, - _t: PhantomData, - } - - impl Future for InnerConnectorResponseB - where - T: Service - + 'static, - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - { - type Output = Result, ConnectError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Poll::Ready( - ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) - .map(EitherConnection::B), - ) - } + }) } } diff --git a/actix-http/src/client/error.rs b/actix-http/src/client/error.rs index ba697bca4..d27363456 100644 --- a/actix-http/src/client/error.rs +++ b/actix-http/src/client/error.rs @@ -1,10 +1,9 @@ use std::io; -use actix_connect::resolver::ResolveError; use derive_more::{Display, From}; #[cfg(feature = "openssl")] -use actix_connect::ssl::openssl::{HandshakeError, SslError}; +use actix_tls::accept::openssl::SslError; use crate::error::{Error, ParseError, ResponseError}; use crate::http::{Error as HttpError, StatusCode}; @@ -21,17 +20,12 @@ pub enum ConnectError { #[display(fmt = "{}", _0)] SslError(SslError), - /// SSL Handshake error - #[cfg(feature = "openssl")] - #[display(fmt = "{}", _0)] - SslHandshakeError(String), - /// Failed to resolve the hostname #[display(fmt = "Failed resolving hostname: {}", _0)] - Resolver(ResolveError), + Resolver(Box), /// No dns records - #[display(fmt = "No dns records found for the input")] + #[display(fmt = "No DNS records found for the input")] NoRecords, /// Http2 error @@ -57,34 +51,30 @@ pub enum ConnectError { impl std::error::Error for ConnectError {} -impl From for ConnectError { - fn from(err: actix_connect::ConnectError) -> ConnectError { +impl From for ConnectError { + fn from(err: actix_tls::connect::ConnectError) -> ConnectError { match err { - actix_connect::ConnectError::Resolver(e) => ConnectError::Resolver(e), - actix_connect::ConnectError::NoRecords => ConnectError::NoRecords, - actix_connect::ConnectError::InvalidInput => panic!(), - actix_connect::ConnectError::Unresolved => ConnectError::Unresolved, - actix_connect::ConnectError::Io(e) => ConnectError::Io(e), + actix_tls::connect::ConnectError::Resolver(e) => ConnectError::Resolver(e), + actix_tls::connect::ConnectError::NoRecords => ConnectError::NoRecords, + actix_tls::connect::ConnectError::InvalidInput => panic!(), + actix_tls::connect::ConnectError::Unresolved => ConnectError::Unresolved, + actix_tls::connect::ConnectError::Io(e) => ConnectError::Io(e), } } } -#[cfg(feature = "openssl")] -impl From> for ConnectError { - fn from(err: HandshakeError) -> ConnectError { - ConnectError::SslHandshakeError(format!("{:?}", err)) - } -} - #[derive(Debug, Display, From)] pub enum InvalidUrl { - #[display(fmt = "Missing url scheme")] + #[display(fmt = "Missing URL scheme")] MissingScheme, - #[display(fmt = "Unknown url scheme")] + + #[display(fmt = "Unknown URL scheme")] UnknownScheme, + #[display(fmt = "Missing host name")] MissingHost, - #[display(fmt = "Url parse error: {}", _0)] + + #[display(fmt = "URL parse error: {}", _0)] HttpError(http::Error), } @@ -96,25 +86,33 @@ pub enum SendRequestError { /// Invalid URL #[display(fmt = "Invalid URL: {}", _0)] Url(InvalidUrl), + /// Failed to connect to host #[display(fmt = "Failed to connect to host: {}", _0)] Connect(ConnectError), + /// Error sending request Send(io::Error), + /// Error parsing response Response(ParseError), + /// Http error #[display(fmt = "{}", _0)] Http(HttpError), + /// Http2 error #[display(fmt = "{}", _0)] H2(h2::Error), + /// Response took too long #[display(fmt = "Timeout while waiting for response")] Timeout, - /// Tunnels are not supported for http2 connection + + /// Tunnels are not supported for HTTP/2 connection #[display(fmt = "Tunnels are not supported for http2 connection")] TunnelNotSupported, + /// Error sending request body Body(Error), } @@ -140,7 +138,8 @@ pub enum FreezeRequestError { /// Invalid URL #[display(fmt = "Invalid URL: {}", _0)] Url(InvalidUrl), - /// Http error + + /// HTTP error #[display(fmt = "{}", _0)] Http(HttpError), } diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs index 06cc05404..082c4b8e2 100644 --- a/actix-http/src/client/h1proto.rs +++ b/actix-http/src/client/h1proto.rs @@ -1,14 +1,14 @@ use std::io::Write; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{io, mem, time}; +use std::{io, time}; -use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use bytes::buf::BufMutExt; +use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; +use bytes::buf::BufMut; use bytes::{Bytes, BytesMut}; use futures_core::Stream; use futures_util::future::poll_fn; -use futures_util::{pin_mut, SinkExt, StreamExt}; +use futures_util::{SinkExt, StreamExt}; use crate::error::PayloadError; use crate::h1; @@ -45,14 +45,14 @@ where Some(port) => write!(wrt, "{}:{}", host, port), }; - match wrt.get_mut().split().freeze().try_into() { + match wrt.get_mut().split().freeze().try_into_value() { Ok(value) => match head { RequestHeadType::Owned(ref mut head) => { - head.headers.insert(HOST, value) + head.headers.insert(HOST, value); } RequestHeadType::Rc(_, ref mut extra_headers) => { let headers = extra_headers.get_or_insert(HeaderMap::new()); - headers.insert(HOST, value) + headers.insert(HOST, value); } }, Err(e) => log::error!("Can not set HOST header {}", e), @@ -72,7 +72,7 @@ where // send request body match body.size() { - BodySize::None | BodySize::Empty | BodySize::Sized(0) => (), + BodySize::None | BodySize::Empty | BodySize::Sized(0) => {} _ => send_body(body, Pin::new(&mut framed_inner)).await?, }; @@ -127,7 +127,7 @@ where T: ConnectionLifetime + Unpin, B: MessageBody, { - pin_mut!(body); + actix_rt::pin!(body); let mut eof = false; while !eof { @@ -165,7 +165,10 @@ where #[doc(hidden)] /// HTTP client connection -pub struct H1Connection { +pub struct H1Connection +where + T: AsyncWrite + Unpin + 'static, +{ /// T should be `Unpin` io: Option, created: time::Instant, @@ -204,18 +207,11 @@ where } impl AsyncRead for H1Connection { - unsafe fn prepare_uninitialized_buffer( - &self, - buf: &mut [mem::MaybeUninit], - ) -> bool { - self.io.as_ref().unwrap().prepare_uninitialized_buffer(buf) - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf) } } diff --git a/actix-http/src/client/h2proto.rs b/actix-http/src/client/h2proto.rs index 3f9a981f4..0deb5c014 100644 --- a/actix-http/src/client/h2proto.rs +++ b/actix-http/src/client/h2proto.rs @@ -5,7 +5,6 @@ use std::time; use actix_codec::{AsyncRead, AsyncWrite}; use bytes::Bytes; use futures_util::future::poll_fn; -use futures_util::pin_mut; use h2::{ client::{Builder, Connection, SendRequest}, SendStream, @@ -22,9 +21,10 @@ use super::config::ConnectorConfig; use super::connection::{ConnectionType, IoConnection}; use super::error::SendRequestError; use super::pool::Acquired; +use crate::client::connection::H2Connection; pub(crate) async fn send_request( - mut io: SendRequest, + mut io: H2Connection, head: RequestHeadType, body: B, created: time::Instant, @@ -35,6 +35,7 @@ where B: MessageBody, { trace!("Sending client request: {:?} {:?}", head, body.size()); + let head_req = head.as_ref().method == Method::HEAD; let length = body.size(); let eof = matches!( @@ -89,7 +90,7 @@ where CONNECTION | TRANSFER_ENCODING => continue, // http2 specific CONTENT_LENGTH if skip_len => continue, // DATE => has_date = true, - _ => (), + _ => {} } req.headers_mut().append(key, value.clone()); } @@ -129,7 +130,7 @@ async fn send_body( mut send: SendStream, ) -> Result<(), SendRequestError> { let mut buf = None; - pin_mut!(body); + actix_rt::pin!(body); loop { if buf.is_none() { match poll_fn(|cx| body.as_mut().poll_next(cx)).await { @@ -171,9 +172,9 @@ async fn send_body( } } -// release SendRequest object +/// release SendRequest object fn release( - io: SendRequest, + io: H2Connection, pool: Option>, created: time::Instant, close: bool, diff --git a/actix-http/src/client/mod.rs b/actix-http/src/client/mod.rs index dd1e9b25a..5f5e57edb 100644 --- a/actix-http/src/client/mod.rs +++ b/actix-http/src/client/mod.rs @@ -1,4 +1,5 @@ -//! Http client api +//! HTTP client. + use http::Uri; mod config; @@ -9,6 +10,10 @@ mod h1proto; mod h2proto; mod pool; +pub use actix_tls::connect::{ + Connect as TcpConnect, ConnectError as TcpConnectError, Connection as TcpConnection, +}; + pub use self::connection::Connection; pub use self::connector::Connector; pub use self::error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}; diff --git a/actix-http/src/client/pool.rs b/actix-http/src/client/pool.rs index a8687dbeb..3800696fa 100644 --- a/actix-http/src/client/pool.rs +++ b/actix-http/src/client/pool.rs @@ -1,27 +1,27 @@ -use std::cell::RefCell; +//! Client connection pooling keyed on the authority part of the connection URI. + use std::collections::VecDeque; use std::future::Future; +use std::ops::Deref; use std::pin::Pin; use std::rc::Rc; +use std::sync::Arc; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; +use std::{cell::RefCell, io}; use actix_codec::{AsyncRead, AsyncWrite}; -use actix_rt::time::{delay_for, Delay}; +use actix_rt::time::{sleep, Sleep}; use actix_service::Service; -use actix_utils::task::LocalWaker; -use bytes::Bytes; -use futures_channel::oneshot; -use futures_util::future::{poll_fn, FutureExt, LocalBoxFuture}; -use fxhash::FxHashMap; -use h2::client::{Connection, SendRequest}; +use ahash::AHashMap; +use futures_core::future::LocalBoxFuture; use http::uri::Authority; -use indexmap::IndexSet; use pin_project::pin_project; -use slab::Slab; +use tokio::io::ReadBuf; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use super::config::ConnectorConfig; -use super::connection::{ConnectionType, IoConnection}; +use super::connection::{ConnectionType, H2Connection, IoConnection}; use super::error::ConnectError; use super::h2proto::handshake; use super::Connect; @@ -44,602 +44,635 @@ impl From for Key { } } -/// Connections pool -pub(crate) struct ConnectionPool(Rc>, Rc>>); - -impl ConnectionPool +/// Connections pool for reuse Io type for certain [`http::uri::Authority`] as key. +pub(crate) struct ConnectionPool where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service - + 'static, + Io: AsyncWrite + Unpin + 'static, { - pub(crate) fn new(connector: T, config: ConnectorConfig) -> Self { - let connector_rc = Rc::new(RefCell::new(connector)); - let inner_rc = Rc::new(RefCell::new(Inner { - config, - acquired: 0, - waiters: Slab::new(), - waiters_queue: IndexSet::new(), - available: FxHashMap::default(), - waker: LocalWaker::new(), - })); + connector: Rc, + inner: ConnectionPoolInner, +} - // start support future - actix_rt::spawn(ConnectorPoolSupport { - connector: Rc::clone(&connector_rc), - inner: Rc::clone(&inner_rc), - }); +/// wrapper type for check the ref count of Rc. +struct ConnectionPoolInner(Rc>) +where + Io: AsyncWrite + Unpin + 'static; - ConnectionPool(connector_rc, inner_rc) +impl ConnectionPoolInner +where + Io: AsyncWrite + Unpin + 'static, +{ + /// spawn a async for graceful shutdown h1 Io type with a timeout. + fn close(&self, conn: ConnectionType) { + if let Some(timeout) = self.config.disconnect_timeout { + if let ConnectionType::H1(io) = conn { + actix_rt::spawn(CloseConnection::new(io, timeout)); + } + } } } -impl Clone for ConnectionPool +impl Clone for ConnectionPoolInner where - Io: 'static, + Io: AsyncWrite + Unpin + 'static, { fn clone(&self) -> Self { - ConnectionPool(self.0.clone(), self.1.clone()) + Self(Rc::clone(&self.0)) } } -impl Drop for ConnectionPool { - fn drop(&mut self) { - // wake up the ConnectorPoolSupport when dropping so it can exit properly. - self.1.borrow().waker.wake(); - } -} - -impl Service for ConnectionPool +impl Deref for ConnectionPoolInner where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service - + 'static, + Io: AsyncWrite + Unpin + 'static, +{ + type Target = ConnectionPoolInnerPriv; + + fn deref(&self) -> &Self::Target { + &*self.0 + } +} + +impl Drop for ConnectionPoolInner +where + Io: AsyncWrite + Unpin + 'static, +{ + fn drop(&mut self) { + // When strong count is one it means the pool is dropped + // remove and drop all Io types. + if Rc::strong_count(&self.0) == 1 { + self.permits.close(); + std::mem::take(&mut *self.available.borrow_mut()) + .into_iter() + .for_each(|(_, conns)| { + conns.into_iter().for_each(|pooled| self.close(pooled.conn)) + }); + } + } +} + +struct ConnectionPoolInnerPriv +where + Io: AsyncWrite + Unpin + 'static, +{ + config: ConnectorConfig, + available: RefCell>>>, + permits: Arc, +} + +impl ConnectionPool +where + Io: AsyncWrite + Unpin + 'static, +{ + /// Construct a new connection pool. + /// + /// [`super::config::ConnectorConfig`]'s `limit` is used as the max permits allowed for + /// in-flight connections. + /// + /// The pool can only have equal to `limit` amount of requests spawning/using Io type + /// concurrently. + /// + /// Any requests beyond limit would be wait in fifo order and get notified in async manner + /// by [`tokio::sync::Semaphore`] + pub(crate) fn new(connector: S, config: ConnectorConfig) -> Self { + let permits = Arc::new(Semaphore::new(config.limit)); + let available = RefCell::new(AHashMap::default()); + let connector = Rc::new(connector); + + let inner = ConnectionPoolInner(Rc::new(ConnectionPoolInnerPriv { + config, + available, + permits, + })); + + Self { connector, inner } + } +} + +impl Clone for ConnectionPool +where + Io: AsyncWrite + Unpin + 'static, +{ + fn clone(&self) -> Self { + Self { + connector: self.connector.clone(), + inner: self.inner.clone(), + } + } +} + +impl Service for ConnectionPool +where + S: Service + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { - type Request = Connect; type Response = IoConnection; type Error = ConnectError; type Future = LocalBoxFuture<'static, Result, ConnectError>>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.0.poll_ready(cx) - } + actix_service::forward_ready!(connector); - fn call(&mut self, req: Connect) -> Self::Future { - let mut connector = self.0.clone(); - let inner = self.1.clone(); + fn call(&self, req: Connect) -> Self::Future { + let connector = self.connector.clone(); + let inner = self.inner.clone(); - let fut = async move { + Box::pin(async move { let key = if let Some(authority) = req.uri.authority() { authority.clone().into() } else { return Err(ConnectError::Unresolved); }; - // acquire connection - match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await { - Acquire::Acquired(io, created) => { - // use existing connection - Ok(IoConnection::new( - io, - created, - Some(Acquired(key, Some(inner))), - )) - } - Acquire::Available => { - // open tcp connection + // acquire an owned permit and carry it with connection + let permit = inner.permits.clone().acquire_owned().await.map_err(|_| { + ConnectError::Io(io::Error::new( + io::ErrorKind::Other, + "failed to acquire semaphore on client connection pool", + )) + })?; + + let conn = { + let mut conn = None; + + // check if there is idle connection for given key. + let mut map = inner.available.borrow_mut(); + + if let Some(conns) = map.get_mut(&key) { + let now = Instant::now(); + + while let Some(mut c) = conns.pop_front() { + let config = &inner.config; + let idle_dur = now - c.used; + let age = now - c.created; + let conn_ineligible = idle_dur > config.conn_keep_alive + || age > config.conn_lifetime; + + if conn_ineligible { + // drop connections that are too old + inner.close(c.conn); + } else { + // check if the connection is still usable + if let ConnectionType::H1(ref mut io) = c.conn { + let check = ConnectionCheckFuture { io }; + match check.await { + ConnectionState::Tainted => { + inner.close(c.conn); + continue; + } + ConnectionState::Skip => continue, + ConnectionState::Live => conn = Some(c), + } + } else { + conn = Some(c); + } + + break; + } + } + }; + + conn + }; + + // construct acquired. It's used to put Io type back to pool/ close the Io type. + // permit is carried with the whole lifecycle of Acquired. + let acquired = Some(Acquired { key, inner, permit }); + + // match the connection and spawn new one if did not get anything. + match conn { + Some(conn) => Ok(IoConnection::new(conn.conn, conn.created, acquired)), + None => { let (io, proto) = connector.call(req).await?; - let config = inner.borrow().config.clone(); - - let guard = OpenGuard::new(key, inner); - if proto == Protocol::Http1 { Ok(IoConnection::new( ConnectionType::H1(io), Instant::now(), - Some(guard.consume()), + acquired, )) } else { - let (snd, connection) = handshake(io, &config).await?; - actix_rt::spawn(connection.map(|_| ())); + let config = &acquired.as_ref().unwrap().inner.config; + let (sender, connection) = handshake(io, config).await?; Ok(IoConnection::new( - ConnectionType::H2(snd), + ConnectionType::H2(H2Connection::new(sender, connection)), Instant::now(), - Some(guard.consume()), + acquired, )) } } - _ => { - // connection is not available, wait - let (rx, token) = inner.borrow_mut().wait_for(req); - - let guard = WaiterGuard::new(key, token, inner); - let res = match rx.await { - Err(_) => Err(ConnectError::Disconnected), - Ok(res) => res, - }; - guard.consume(); - res - } } - }; - - fut.boxed_local() - } -} - -struct WaiterGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - key: Key, - token: usize, - inner: Option>>>, -} - -impl WaiterGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn new(key: Key, token: usize, inner: Rc>>) -> Self { - Self { - key, - token, - inner: Some(inner), - } - } - - fn consume(mut self) { - let _ = self.inner.take(); - } -} - -impl Drop for WaiterGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn drop(&mut self) { - if let Some(i) = self.inner.take() { - let mut inner = i.as_ref().borrow_mut(); - inner.release_waiter(&self.key, self.token); - inner.check_availability(); - } - } -} - -struct OpenGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - key: Key, - inner: Option>>>, -} - -impl OpenGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn new(key: Key, inner: Rc>>) -> Self { - Self { - key, - inner: Some(inner), - } - } - - fn consume(mut self) -> Acquired { - Acquired(self.key.clone(), self.inner.take()) - } -} - -impl Drop for OpenGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn drop(&mut self) { - if let Some(i) = self.inner.take() { - let mut inner = i.as_ref().borrow_mut(); - inner.release(); - inner.check_availability(); - } - } -} - -enum Acquire { - Acquired(ConnectionType, Instant), - Available, - NotAvailable, -} - -struct AvailableConnection { - io: ConnectionType, - used: Instant, - created: Instant, -} - -pub(crate) struct Inner { - config: ConnectorConfig, - acquired: usize, - available: FxHashMap>>, - waiters: Slab< - Option<( - Connect, - oneshot::Sender, ConnectError>>, - )>, - >, - waiters_queue: IndexSet<(Key, usize)>, - waker: LocalWaker, -} - -impl Inner { - fn reserve(&mut self) { - self.acquired += 1; - } - - fn release(&mut self) { - self.acquired -= 1; - } - - fn release_waiter(&mut self, key: &Key, token: usize) { - self.waiters.remove(token); - let _ = self.waiters_queue.shift_remove(&(key.clone(), token)); - } -} - -impl Inner -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - /// connection is not available, wait - fn wait_for( - &mut self, - connect: Connect, - ) -> ( - oneshot::Receiver, ConnectError>>, - usize, - ) { - let (tx, rx) = oneshot::channel(); - - let key: Key = connect.uri.authority().unwrap().clone().into(); - let entry = self.waiters.vacant_entry(); - let token = entry.key(); - entry.insert(Some((connect, tx))); - assert!(self.waiters_queue.insert((key, token))); - - (rx, token) - } - - fn acquire(&mut self, key: &Key, cx: &mut Context<'_>) -> Acquire { - // check limits - if self.config.limit > 0 && self.acquired >= self.config.limit { - return Acquire::NotAvailable; - } - - self.reserve(); - - // check if open connection is available - // cleanup stale connections at the same time - if let Some(ref mut connections) = self.available.get_mut(key) { - let now = Instant::now(); - while let Some(conn) = connections.pop_back() { - // check if it still usable - if (now - conn.used) > self.config.conn_keep_alive - || (now - conn.created) > self.config.conn_lifetime - { - if let Some(timeout) = self.config.disconnect_timeout { - if let ConnectionType::H1(io) = conn.io { - actix_rt::spawn(CloseConnection::new(io, timeout)) - } - } - } else { - let mut io = conn.io; - let mut buf = [0; 2]; - if let ConnectionType::H1(ref mut s) = io { - match Pin::new(s).poll_read(cx, &mut buf) { - Poll::Pending => (), - Poll::Ready(Ok(n)) if n > 0 => { - if let Some(timeout) = self.config.disconnect_timeout { - if let ConnectionType::H1(io) = io { - actix_rt::spawn(CloseConnection::new( - io, timeout, - )) - } - } - continue; - } - _ => continue, - } - } - return Acquire::Acquired(io, conn.created); - } - } - } - Acquire::Available - } - - fn release_conn(&mut self, key: &Key, io: ConnectionType, created: Instant) { - self.acquired -= 1; - self.available - .entry(key.clone()) - .or_insert_with(VecDeque::new) - .push_back(AvailableConnection { - io, - created, - used: Instant::now(), - }); - self.check_availability(); - } - - fn release_close(&mut self, io: ConnectionType) { - self.acquired -= 1; - if let Some(timeout) = self.config.disconnect_timeout { - if let ConnectionType::H1(io) = io { - actix_rt::spawn(CloseConnection::new(io, timeout)) - } - } - self.check_availability(); - } - - fn check_availability(&self) { - if !self.waiters_queue.is_empty() && self.acquired < self.config.limit { - self.waker.wake(); - } - } -} - -struct CloseConnection { - io: T, - timeout: Delay, -} - -impl CloseConnection -where - T: AsyncWrite + Unpin, -{ - fn new(io: T, timeout: Duration) -> Self { - CloseConnection { - io, - timeout: delay_for(timeout), - } - } -} - -impl Future for CloseConnection -where - T: AsyncWrite + Unpin, -{ - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - let this = self.get_mut(); - - match Pin::new(&mut this.timeout).poll(cx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => match Pin::new(&mut this.io).poll_shutdown(cx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => Poll::Pending, - }, - } - } -} - -#[pin_project] -struct ConnectorPoolSupport -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - connector: T, - inner: Rc>>, -} - -impl Future for ConnectorPoolSupport -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service, - T::Future: 'static, -{ - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - if Rc::strong_count(this.inner) == 1 { - // If we are last copy of Inner it means the ConnectionPool is already gone - // and we are safe to exit. - return Poll::Ready(()); - } - - let mut inner = this.inner.borrow_mut(); - inner.waker.register(cx.waker()); - - // check waiters - loop { - let (key, token) = { - if let Some((key, token)) = inner.waiters_queue.get_index(0) { - (key.clone(), *token) - } else { - break; - } - }; - if inner.waiters.get(token).unwrap().is_none() { - continue; - } - - match inner.acquire(&key, cx) { - Acquire::NotAvailable => break, - Acquire::Acquired(io, created) => { - let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1; - if let Err(conn) = tx.send(Ok(IoConnection::new( - io, - created, - Some(Acquired(key.clone(), Some(this.inner.clone()))), - ))) { - let (io, created) = conn.unwrap().into_inner(); - inner.release_conn(&key, io, created); - } - } - Acquire::Available => { - let (connect, tx) = - inner.waiters.get_mut(token).unwrap().take().unwrap(); - OpenWaitingConnection::spawn( - key.clone(), - tx, - this.inner.clone(), - this.connector.call(connect), - inner.config.clone(), - ); - } - } - let _ = inner.waiters_queue.swap_remove_index(0); - } - - Poll::Pending - } -} - -#[pin_project::pin_project(PinnedDrop)] -struct OpenWaitingConnection -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - #[pin] - fut: F, - key: Key, - h2: Option< - LocalBoxFuture< - 'static, - Result<(SendRequest, Connection), h2::Error>, - >, - >, - rx: Option, ConnectError>>>, - inner: Option>>>, - config: ConnectorConfig, -} - -impl OpenWaitingConnection -where - F: Future> + 'static, - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn spawn( - key: Key, - rx: oneshot::Sender, ConnectError>>, - inner: Rc>>, - fut: F, - config: ConnectorConfig, - ) { - actix_rt::spawn(OpenWaitingConnection { - key, - fut, - h2: None, - rx: Some(rx), - inner: Some(inner), - config, }) } } -#[pin_project::pinned_drop] -impl PinnedDrop for OpenWaitingConnection +/// Type for check the connection and determine if it's usable. +struct ConnectionCheckFuture<'a, Io> { + io: &'a mut Io, +} + +enum ConnectionState { + /// IO is pending and a new request would wake it. + Live, + + /// IO unexpectedly has unread data and should be dropped. + Tainted, + + /// IO should be skipped but not dropped. + Skip, +} + +impl Future for ConnectionCheckFuture<'_, Io> where - Io: AsyncRead + AsyncWrite + Unpin + 'static, + Io: AsyncRead + Unpin, { - fn drop(self: Pin<&mut Self>) { - if let Some(inner) = self.project().inner.take() { - let mut inner = inner.as_ref().borrow_mut(); - inner.release(); - inner.check_availability(); + type Output = ConnectionState; + + // this future is only used to get access to Context. + // It should never return Poll::Pending. + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + let mut buf = [0; 2]; + let mut read_buf = ReadBuf::new(&mut buf); + + let state = match Pin::new(&mut this.io).poll_read(cx, &mut read_buf) { + Poll::Ready(Ok(())) if !read_buf.filled().is_empty() => { + ConnectionState::Tainted + } + + Poll::Pending => ConnectionState::Live, + _ => ConnectionState::Skip, + }; + + Poll::Ready(state) + } +} + +struct PooledConnection { + conn: ConnectionType, + used: Instant, + created: Instant, +} + +#[pin_project] +struct CloseConnection { + io: Io, + #[pin] + timeout: Sleep, +} + +impl CloseConnection +where + Io: AsyncWrite + Unpin, +{ + fn new(io: Io, timeout: Duration) -> Self { + CloseConnection { + io, + timeout: sleep(timeout), } } } -impl Future for OpenWaitingConnection +impl Future for CloseConnection where - F: Future>, - Io: AsyncRead + AsyncWrite + Unpin, + Io: AsyncWrite + Unpin, { type Output = (); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().project(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let this = self.project(); - if let Some(ref mut h2) = this.h2 { - return match Pin::new(h2).poll(cx) { - Poll::Ready(Ok((snd, connection))) => { - actix_rt::spawn(connection.map(|_| ())); - let rx = this.rx.take().unwrap(); - let _ = rx.send(Ok(IoConnection::new( - ConnectionType::H2(snd), - Instant::now(), - Some(Acquired(this.key.clone(), this.inner.take())), - ))); - Poll::Ready(()) - } - Poll::Pending => Poll::Pending, - Poll::Ready(Err(err)) => { - let _ = this.inner.take(); - if let Some(rx) = this.rx.take() { - let _ = rx.send(Err(ConnectError::H2(err))); - } - Poll::Ready(()) - } - }; - } - - match this.fut.poll(cx) { - Poll::Ready(Err(err)) => { - let _ = this.inner.take(); - if let Some(rx) = this.rx.take() { - let _ = rx.send(Err(err)); - } - Poll::Ready(()) - } - Poll::Ready(Ok((io, proto))) => { - if proto == Protocol::Http1 { - let rx = this.rx.take().unwrap(); - let _ = rx.send(Ok(IoConnection::new( - ConnectionType::H1(io), - Instant::now(), - Some(Acquired(this.key.clone(), this.inner.take())), - ))); - Poll::Ready(()) - } else { - *this.h2 = Some(handshake(io, this.config).boxed_local()); - self.poll(cx) - } - } - Poll::Pending => Poll::Pending, + match this.timeout.poll(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Pin::new(this.io).poll_shutdown(cx).map(|_| ()), } } } -pub(crate) struct Acquired(Key, Option>>>); - -impl Acquired +pub(crate) struct Acquired where - T: AsyncRead + AsyncWrite + Unpin + 'static, + Io: AsyncWrite + Unpin + 'static, { - pub(crate) fn close(&mut self, conn: IoConnection) { - if let Some(inner) = self.1.take() { - let (io, _) = conn.into_inner(); - inner.as_ref().borrow_mut().release_close(io); - } + key: Key, + inner: ConnectionPoolInner, + permit: OwnedSemaphorePermit, +} + +impl Acquired +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + /// Close the IO. + pub(crate) fn close(&mut self, conn: IoConnection) { + let (conn, _) = conn.into_inner(); + self.inner.close(conn); } - pub(crate) fn release(&mut self, conn: IoConnection) { - if let Some(inner) = self.1.take() { - let (io, created) = conn.into_inner(); - inner - .as_ref() - .borrow_mut() - .release_conn(&self.0, io, created); - } + + /// Release IO back into pool. + pub(crate) fn release(&mut self, conn: IoConnection) { + let (io, created) = conn.into_inner(); + let Acquired { key, inner, .. } = self; + + inner + .available + .borrow_mut() + .entry(key.clone()) + .or_insert_with(VecDeque::new) + .push_back(PooledConnection { + conn: io, + created, + used: Instant::now(), + }); + + let _ = &mut self.permit; } } -impl Drop for Acquired { - fn drop(&mut self) { - if let Some(inner) = self.1.take() { - inner.as_ref().borrow_mut().release(); +#[cfg(test)] +mod test { + use std::{cell::Cell, io}; + + use http::Uri; + + use super::*; + use crate::client::connection::IoConnection; + + /// A stream type that always returns pending on async read. + /// + /// Mocks an idle TCP stream that is ready to be used for client connections. + struct TestStream(Rc>); + + impl Drop for TestStream { + fn drop(&mut self) { + self.0.set(self.0.get() - 1); } } + + impl AsyncRead for TestStream { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Pending + } + } + + impl AsyncWrite for TestStream { + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &[u8], + ) -> Poll> { + unimplemented!() + } + + fn poll_flush( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + unimplemented!() + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + struct TestPoolConnector { + generated: Rc>, + } + + impl Service for TestPoolConnector { + type Response = (TestStream, Protocol); + type Error = ConnectError; + type Future = LocalBoxFuture<'static, Result>; + + actix_service::always_ready!(); + + fn call(&self, _: Connect) -> Self::Future { + self.generated.set(self.generated.get() + 1); + let generated = self.generated.clone(); + Box::pin(async { Ok((TestStream(generated), Protocol::Http1)) }) + } + } + + fn release(conn: IoConnection) + where + T: AsyncRead + AsyncWrite + Unpin + 'static, + { + let (conn, created, mut acquired) = conn.into_parts(); + acquired.release(IoConnection::new(conn, created, None)); + } + + #[actix_rt::test] + async fn test_pool_limit() { + let connector = TestPoolConnector { + generated: Rc::new(Cell::new(0)), + }; + + let config = ConnectorConfig { + limit: 1, + ..Default::default() + }; + + let pool = super::ConnectionPool::new(connector, config); + + let req = Connect { + uri: Uri::from_static("http://localhost"), + addr: None, + }; + + let conn = pool.call(req.clone()).await.unwrap(); + + let waiting = Rc::new(Cell::new(true)); + + let waiting_clone = waiting.clone(); + actix_rt::spawn(async move { + actix_rt::time::sleep(Duration::from_millis(100)).await; + waiting_clone.set(false); + drop(conn); + }); + + assert!(waiting.get()); + + let now = Instant::now(); + let conn = pool.call(req).await.unwrap(); + + release(conn); + assert!(!waiting.get()); + assert!(now.elapsed() >= Duration::from_millis(100)); + } + + #[actix_rt::test] + async fn test_pool_keep_alive() { + let generated = Rc::new(Cell::new(0)); + let generated_clone = generated.clone(); + + let connector = TestPoolConnector { generated }; + + let config = ConnectorConfig { + conn_keep_alive: Duration::from_secs(1), + ..Default::default() + }; + + let pool = super::ConnectionPool::new(connector, config); + + let req = Connect { + uri: Uri::from_static("http://localhost"), + addr: None, + }; + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + actix_rt::time::sleep(Duration::from_millis(1500)).await; + actix_rt::task::yield_now().await; + + let conn = pool.call(req).await.unwrap(); + // Note: spawned recycle connection is not ran yet. + // This is tokio current thread runtime specific behavior. + assert_eq!(2, generated_clone.get()); + + // yield task so the old connection is properly dropped. + actix_rt::task::yield_now().await; + assert_eq!(1, generated_clone.get()); + + release(conn); + } + + #[actix_rt::test] + async fn test_pool_lifetime() { + let generated = Rc::new(Cell::new(0)); + let generated_clone = generated.clone(); + + let connector = TestPoolConnector { generated }; + + let config = ConnectorConfig { + conn_lifetime: Duration::from_secs(1), + ..Default::default() + }; + + let pool = super::ConnectionPool::new(connector, config); + + let req = Connect { + uri: Uri::from_static("http://localhost"), + addr: None, + }; + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + actix_rt::time::sleep(Duration::from_millis(1500)).await; + actix_rt::task::yield_now().await; + + let conn = pool.call(req).await.unwrap(); + // Note: spawned recycle connection is not ran yet. + // This is tokio current thread runtime specific behavior. + assert_eq!(2, generated_clone.get()); + + // yield task so the old connection is properly dropped. + actix_rt::task::yield_now().await; + assert_eq!(1, generated_clone.get()); + + release(conn); + } + + #[actix_rt::test] + async fn test_pool_authority_key() { + let generated = Rc::new(Cell::new(0)); + let generated_clone = generated.clone(); + + let connector = TestPoolConnector { generated }; + + let config = ConnectorConfig::default(); + + let pool = super::ConnectionPool::new(connector, config); + + let req = Connect { + uri: Uri::from_static("https://crates.io"), + addr: None, + }; + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + let conn = pool.call(req).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + let req = Connect { + uri: Uri::from_static("https://google.com"), + addr: None, + }; + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(2, generated_clone.get()); + release(conn); + let conn = pool.call(req).await.unwrap(); + assert_eq!(2, generated_clone.get()); + release(conn); + } + + #[actix_rt::test] + async fn test_pool_drop() { + let generated = Rc::new(Cell::new(0)); + let generated_clone = generated.clone(); + + let connector = TestPoolConnector { generated }; + + let config = ConnectorConfig::default(); + + let pool = Rc::new(super::ConnectionPool::new(connector, config)); + + let req = Connect { + uri: Uri::from_static("https://crates.io"), + addr: None, + }; + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + let req = Connect { + uri: Uri::from_static("https://google.com"), + addr: None, + }; + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(2, generated_clone.get()); + release(conn); + + let clone1 = pool.clone(); + let clone2 = clone1.clone(); + + drop(clone2); + for _ in 0..2 { + actix_rt::task::yield_now().await; + } + assert_eq!(2, generated_clone.get()); + + drop(clone1); + for _ in 0..2 { + actix_rt::task::yield_now().await; + } + assert_eq!(2, generated_clone.get()); + + drop(pool); + for _ in 0..2 { + actix_rt::task::yield_now().await; + } + assert_eq!(0, generated_clone.get()); + } } diff --git a/actix-http/src/clinu/mod.rs b/actix-http/src/clinu/mod.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/actix-http/src/cloneable.rs b/actix-http/src/cloneable.rs deleted file mode 100644 index 0e77c455c..000000000 --- a/actix-http/src/cloneable.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::cell::RefCell; -use std::rc::Rc; -use std::task::{Context, Poll}; - -use actix_service::Service; - -/// Service that allows to turn non-clone service to a service with `Clone` impl -/// -/// # Panics -/// CloneableService might panic with some creative use of thread local storage. -/// See https://github.com/actix/actix-web/issues/1295 for example -#[doc(hidden)] -pub(crate) struct CloneableService(Rc>); - -impl CloneableService { - pub(crate) fn new(service: T) -> Self { - Self(Rc::new(RefCell::new(service))) - } -} - -impl Clone for CloneableService { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -impl Service for CloneableService { - type Request = T::Request; - type Response = T::Response; - type Error = T::Error; - type Future = T::Future; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.0.borrow_mut().poll_ready(cx) - } - - fn call(&mut self, req: T::Request) -> Self::Future { - self.0.borrow_mut().call(req) - } -} diff --git a/actix-http/src/config.rs b/actix-http/src/config.rs index b314d4c99..9f84b8694 100644 --- a/actix-http/src/config.rs +++ b/actix-http/src/config.rs @@ -4,12 +4,14 @@ use std::rc::Rc; use std::time::Duration; use std::{fmt, net}; -use actix_rt::time::{delay_for, delay_until, Delay, Instant}; +use actix_rt::{ + task::JoinHandle, + time::{interval, sleep_until, Instant, Sleep}, +}; use bytes::BytesMut; -use futures_util::{future, FutureExt}; use time::OffsetDateTime; -// "Sun, 06 Nov 1994 08:49:37 GMT".len() +/// "Sun, 06 Nov 1994 08:49:37 GMT".len() const DATE_VALUE_LENGTH: usize = 29; #[derive(Debug, PartialEq, Clone, Copy)] @@ -49,7 +51,7 @@ struct Inner { ka_enabled: bool, secure: bool, local_addr: Option, - timer: DateService, + date_service: DateService, } impl Clone for ServiceConfig { @@ -91,41 +93,41 @@ impl ServiceConfig { client_disconnect, secure, local_addr, - timer: DateService::new(), + date_service: DateService::new(), })) } + /// Returns true if connection is secure (HTTPS) #[inline] - /// Returns true if connection is secure(https) pub fn secure(&self) -> bool { self.0.secure } - #[inline] /// Returns the local address that this server is bound to. + #[inline] pub fn local_addr(&self) -> Option { self.0.local_addr } - #[inline] /// Keep alive duration if configured. + #[inline] pub fn keep_alive(&self) -> Option { self.0.keep_alive } - #[inline] /// Return state of connection keep-alive functionality + #[inline] pub fn keep_alive_enabled(&self) -> bool { self.0.ka_enabled } - #[inline] /// Client timeout for first request. - pub fn client_timer(&self) -> Option { + #[inline] + pub fn client_timer(&self) -> Option { let delay_time = self.0.client_timeout; if delay_time != 0 { - Some(delay_until( - self.0.timer.now() + Duration::from_millis(delay_time), + Some(sleep_until( + self.0.date_service.now() + Duration::from_millis(delay_time), )) } else { None @@ -136,7 +138,7 @@ impl ServiceConfig { pub fn client_timer_expire(&self) -> Option { let delay = self.0.client_timeout; if delay != 0 { - Some(self.0.timer.now() + Duration::from_millis(delay)) + Some(self.0.date_service.now() + Duration::from_millis(delay)) } else { None } @@ -146,7 +148,7 @@ impl ServiceConfig { pub fn client_disconnect_timer(&self) -> Option { let delay = self.0.client_disconnect; if delay != 0 { - Some(self.0.timer.now() + Duration::from_millis(delay)) + Some(self.0.date_service.now() + Duration::from_millis(delay)) } else { None } @@ -154,9 +156,9 @@ impl ServiceConfig { #[inline] /// Return keep-alive timer delay is configured. - pub fn keep_alive_timer(&self) -> Option { + pub fn keep_alive_timer(&self) -> Option { if let Some(ka) = self.0.keep_alive { - Some(delay_until(self.0.timer.now() + ka)) + Some(sleep_until(self.0.date_service.now() + ka)) } else { None } @@ -165,7 +167,7 @@ impl ServiceConfig { /// Keep-alive expire time pub fn keep_alive_expire(&self) -> Option { if let Some(ka) = self.0.keep_alive { - Some(self.0.timer.now() + ka) + Some(self.0.date_service.now() + ka) } else { None } @@ -173,7 +175,7 @@ impl ServiceConfig { #[inline] pub(crate) fn now(&self) -> Instant { - self.0.timer.now() + self.0.date_service.now() } #[doc(hidden)] @@ -181,7 +183,7 @@ impl ServiceConfig { let mut buf: [u8; 39] = [0; 39]; buf[..6].copy_from_slice(b"date: "); self.0 - .timer + .date_service .set_date(|date| buf[6..35].copy_from_slice(&date.bytes)); buf[35..].copy_from_slice(b"\r\n\r\n"); dst.extend_from_slice(&buf); @@ -189,7 +191,7 @@ impl ServiceConfig { pub(crate) fn set_date_header(&self, dst: &mut BytesMut) { self.0 - .timer + .date_service .set_date(|date| dst.extend_from_slice(&date.bytes)); } } @@ -230,57 +232,103 @@ impl fmt::Write for Date { } } -#[derive(Clone)] -struct DateService(Rc); - -struct DateServiceInner { - current: Cell>, +/// Service for update Date and Instant periodically at 500 millis interval. +struct DateService { + current: Rc>, + handle: JoinHandle<()>, } -impl DateServiceInner { - fn new() -> Self { - DateServiceInner { - current: Cell::new(None), - } - } - - fn reset(&self) { - self.current.take(); - } - - fn update(&self) { - let now = Instant::now(); - let date = Date::new(); - self.current.set(Some((date, now))); +impl Drop for DateService { + fn drop(&mut self) { + // stop the timer update async task on drop. + self.handle.abort(); } } impl DateService { fn new() -> Self { - DateService(Rc::new(DateServiceInner::new())) - } + // shared date and timer for DateService and update async task. + let current = Rc::new(Cell::new((Date::new(), Instant::now()))); + let current_clone = Rc::clone(¤t); + // spawn an async task sleep for 500 milli and update current date/timer in a loop. + // handle is used to stop the task on DateService drop. + let handle = actix_rt::spawn(async move { + #[cfg(test)] + let _notify = notify_on_drop::NotifyOnDrop::new(); - fn check_date(&self) { - if self.0.current.get().is_none() { - self.0.update(); + let mut interval = interval(Duration::from_millis(500)); + loop { + let now = interval.tick().await; + let date = Date::new(); + current_clone.set((date, now)); + } + }); - // periodic date update - let s = self.clone(); - actix_rt::spawn(delay_for(Duration::from_millis(500)).then(move |_| { - s.0.reset(); - future::ready(()) - })); - } + DateService { current, handle } } fn now(&self) -> Instant { - self.check_date(); - self.0.current.get().unwrap().1 + self.current.get().1 } fn set_date(&self, mut f: F) { - self.check_date(); - f(&self.0.current.get().unwrap().0); + f(&self.current.get().0); + } +} + +// TODO: move to a util module for testing all spawn handle drop style tasks. +#[cfg(test)] +/// Test Module for checking the drop state of certain async tasks that are spawned +/// with `actix_rt::spawn` +/// +/// The target task must explicitly generate `NotifyOnDrop` when spawn the task +mod notify_on_drop { + use std::cell::RefCell; + + thread_local! { + static NOTIFY_DROPPED: RefCell> = RefCell::new(None); + } + + /// Check if the spawned task is dropped. + /// + /// # Panic: + /// + /// When there was no `NotifyOnDrop` instance on current thread + pub(crate) fn is_dropped() -> bool { + NOTIFY_DROPPED.with(|bool| { + bool.borrow() + .expect("No NotifyOnDrop existed on current thread") + }) + } + + pub(crate) struct NotifyOnDrop; + + impl NotifyOnDrop { + /// # Panic: + /// + /// When construct multiple instances on any given thread. + pub(crate) fn new() -> Self { + NOTIFY_DROPPED.with(|bool| { + let mut bool = bool.borrow_mut(); + if bool.is_some() { + panic!("NotifyOnDrop existed on current thread"); + } else { + *bool = Some(false); + } + }); + + NotifyOnDrop + } + } + + impl Drop for NotifyOnDrop { + fn drop(&mut self) { + NOTIFY_DROPPED.with(|bool| { + if let Some(b) = bool.borrow_mut().as_mut() { + *b = true; + } + }); + } } } @@ -288,14 +336,53 @@ impl DateService { mod tests { use super::*; - // Test modifying the date from within the closure - // passed to `set_date` - #[test] - fn test_evil_date() { - let service = DateService::new(); - // Make sure that `check_date` doesn't try to spawn a task - service.0.update(); - service.set_date(|_| service.0.reset()); + use actix_rt::task::yield_now; + + #[actix_rt::test] + async fn test_date_service_update() { + let settings = ServiceConfig::new(KeepAlive::Os, 0, 0, false, None); + + yield_now().await; + + let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); + settings.set_date(&mut buf1); + let now1 = settings.now(); + + sleep_until(Instant::now() + Duration::from_secs(2)).await; + yield_now().await; + + let now2 = settings.now(); + let mut buf2 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); + settings.set_date(&mut buf2); + + assert_ne!(now1, now2); + + assert_ne!(buf1, buf2); + + drop(settings); + assert!(notify_on_drop::is_dropped()); + } + + #[actix_rt::test] + async fn test_date_service_drop() { + let service = Rc::new(DateService::new()); + + // yield so date service have a chance to register the spawned timer update task. + yield_now().await; + + let clone1 = service.clone(); + let clone2 = service.clone(); + let clone3 = service.clone(); + + drop(clone1); + assert_eq!(false, notify_on_drop::is_dropped()); + drop(clone2); + assert_eq!(false, notify_on_drop::is_dropped()); + drop(clone3); + assert_eq!(false, notify_on_drop::is_dropped()); + + drop(service); + assert!(notify_on_drop::is_dropped()); } #[test] diff --git a/actix-http/src/encoding/decoder.rs b/actix-http/src/encoding/decoder.rs index b60435859..f0abae865 100644 --- a/actix-http/src/encoding/decoder.rs +++ b/actix-http/src/encoding/decoder.rs @@ -1,25 +1,31 @@ -use std::future::Future; -use std::io::{self, Write}; -use std::pin::Pin; -use std::task::{Context, Poll}; +//! Stream decoders. -use actix_threadpool::{run, CpuFuture}; +use std::{ + future::Future, + io::{self, Write as _}, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_rt::task::{spawn_blocking, JoinHandle}; use brotli2::write::BrotliDecoder; use bytes::Bytes; use flate2::write::{GzDecoder, ZlibDecoder}; use futures_core::{ready, Stream}; -use super::Writer; -use crate::error::PayloadError; -use crate::http::header::{ContentEncoding, HeaderMap, CONTENT_ENCODING}; +use crate::{ + encoding::Writer, + error::{BlockingError, PayloadError}, + http::header::{ContentEncoding, HeaderMap, CONTENT_ENCODING}, +}; -const INPLACE: usize = 2049; +const MAX_CHUNK_SIZE_DECODE_IN_PLACE: usize = 2049; pub struct Decoder { decoder: Option, stream: S, eof: bool, - fut: Option, ContentDecoder), io::Error>>, + fut: Option, ContentDecoder), io::Error>>>, } impl Decoder @@ -41,6 +47,7 @@ where ))), _ => None, }; + Decoder { decoder, stream, @@ -53,15 +60,11 @@ where #[inline] pub fn from_headers(stream: S, headers: &HeaderMap) -> Decoder { // check content-encoding - let encoding = if let Some(enc) = headers.get(&CONTENT_ENCODING) { - if let Ok(enc) = enc.to_str() { - ContentEncoding::from(enc) - } else { - ContentEncoding::Identity - } - } else { - ContentEncoding::Identity - }; + let encoding = headers + .get(&CONTENT_ENCODING) + .and_then(|val| val.to_str().ok()) + .map(ContentEncoding::from) + .unwrap_or(ContentEncoding::Identity); Self::new(stream, encoding) } @@ -79,12 +82,12 @@ where ) -> Poll> { loop { if let Some(ref mut fut) = self.fut { - let (chunk, decoder) = match ready!(Pin::new(fut).poll(cx)) { - Ok(item) => item, - Err(e) => return Poll::Ready(Some(Err(e.into()))), - }; + let (chunk, decoder) = + ready!(Pin::new(fut).poll(cx)).map_err(|_| BlockingError)??; + self.decoder = Some(decoder); self.fut.take(); + if let Some(chunk) = chunk { return Poll::Ready(Some(Ok(chunk))); } @@ -94,29 +97,34 @@ where return Poll::Ready(None); } - match Pin::new(&mut self.stream).poll_next(cx) { - Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))), - Poll::Ready(Some(Ok(chunk))) => { + match ready!(Pin::new(&mut self.stream).poll_next(cx)) { + Some(Err(err)) => return Poll::Ready(Some(Err(err))), + + Some(Ok(chunk)) => { if let Some(mut decoder) = self.decoder.take() { - if chunk.len() < INPLACE { + if chunk.len() < MAX_CHUNK_SIZE_DECODE_IN_PLACE { let chunk = decoder.feed_data(chunk)?; self.decoder = Some(decoder); + if let Some(chunk) = chunk { return Poll::Ready(Some(Ok(chunk))); } } else { - self.fut = Some(run(move || { + self.fut = Some(spawn_blocking(move || { let chunk = decoder.feed_data(chunk)?; Ok((chunk, decoder)) })); } + continue; } else { return Poll::Ready(Some(Ok(chunk))); } } - Poll::Ready(None) => { + + None => { self.eof = true; + return if let Some(mut decoder) = self.decoder.take() { match decoder.feed_eof() { Ok(Some(res)) => Poll::Ready(Some(Ok(res))), @@ -127,10 +135,8 @@ where Poll::Ready(None) }; } - Poll::Pending => break, } } - Poll::Pending } } @@ -146,6 +152,7 @@ impl ContentDecoder { ContentDecoder::Br(ref mut decoder) => match decoder.flush() { Ok(()) => { let b = decoder.get_mut().take(); + if !b.is_empty() { Ok(Some(b)) } else { @@ -154,9 +161,11 @@ impl ContentDecoder { } Err(e) => Err(e), }, + ContentDecoder::Gzip(ref mut decoder) => match decoder.try_finish() { Ok(_) => { let b = decoder.get_mut().take(); + if !b.is_empty() { Ok(Some(b)) } else { @@ -165,6 +174,7 @@ impl ContentDecoder { } Err(e) => Err(e), }, + ContentDecoder::Deflate(ref mut decoder) => match decoder.try_finish() { Ok(_) => { let b = decoder.get_mut().take(); @@ -185,6 +195,7 @@ impl ContentDecoder { Ok(_) => { decoder.flush()?; let b = decoder.get_mut().take(); + if !b.is_empty() { Ok(Some(b)) } else { @@ -193,10 +204,12 @@ impl ContentDecoder { } Err(e) => Err(e), }, + ContentDecoder::Gzip(ref mut decoder) => match decoder.write_all(&data) { Ok(_) => { decoder.flush()?; let b = decoder.get_mut().take(); + if !b.is_empty() { Ok(Some(b)) } else { @@ -205,9 +218,11 @@ impl ContentDecoder { } Err(e) => Err(e), }, + ContentDecoder::Deflate(ref mut decoder) => match decoder.write_all(&data) { Ok(_) => { decoder.flush()?; + let b = decoder.get_mut().take(); if !b.is_empty() { Ok(Some(b)) diff --git a/actix-http/src/encoding/encoder.rs b/actix-http/src/encoding/encoder.rs index eb1821285..ee0587fbd 100644 --- a/actix-http/src/encoding/encoder.rs +++ b/actix-http/src/encoding/encoder.rs @@ -1,24 +1,32 @@ -//! Stream encoder -use std::future::Future; -use std::io::{self, Write}; -use std::pin::Pin; -use std::task::{Context, Poll}; +//! Stream encoders. -use actix_threadpool::{run, CpuFuture}; +use std::{ + future::Future, + io::{self, Write as _}, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_rt::task::{spawn_blocking, JoinHandle}; use brotli2::write::BrotliEncoder; use bytes::Bytes; use flate2::write::{GzEncoder, ZlibEncoder}; use futures_core::ready; use pin_project::pin_project; -use crate::body::{Body, BodySize, MessageBody, ResponseBody}; -use crate::http::header::{ContentEncoding, CONTENT_ENCODING}; -use crate::http::{HeaderValue, StatusCode}; -use crate::{Error, ResponseHead}; +use crate::{ + body::{Body, BodySize, MessageBody, ResponseBody}, + http::{ + header::{ContentEncoding, CONTENT_ENCODING}, + HeaderValue, StatusCode, + }, + Error, ResponseHead, +}; use super::Writer; +use crate::error::BlockingError; -const INPLACE: usize = 1024; +const MAX_CHUNK_SIZE_ENCODE_IN_PLACE: usize = 1024; #[pin_project] pub struct Encoder { @@ -26,7 +34,7 @@ pub struct Encoder { #[pin] body: EncoderBody, encoder: Option, - fut: Option>, + fut: Option>>, } impl Encoder { @@ -70,6 +78,7 @@ impl Encoder { }); } } + ResponseBody::Body(Encoder { body, eof: false, @@ -135,32 +144,35 @@ impl MessageBody for Encoder { } if let Some(ref mut fut) = this.fut { - let mut encoder = match ready!(Pin::new(fut).poll(cx)) { - Ok(item) => item, - Err(e) => return Poll::Ready(Some(Err(e.into()))), - }; + let mut encoder = + ready!(Pin::new(fut).poll(cx)).map_err(|_| BlockingError)??; + let chunk = encoder.take(); *this.encoder = Some(encoder); this.fut.take(); + if !chunk.is_empty() { return Poll::Ready(Some(Ok(chunk))); } } - let result = this.body.as_mut().poll_next(cx); + let result = ready!(this.body.as_mut().poll_next(cx)); match result { - Poll::Ready(Some(Ok(chunk))) => { + Some(Err(err)) => return Poll::Ready(Some(Err(err))), + + Some(Ok(chunk)) => { if let Some(mut encoder) = this.encoder.take() { - if chunk.len() < INPLACE { + if chunk.len() < MAX_CHUNK_SIZE_ENCODE_IN_PLACE { encoder.write(&chunk)?; let chunk = encoder.take(); *this.encoder = Some(encoder); + if !chunk.is_empty() { return Poll::Ready(Some(Ok(chunk))); } } else { - *this.fut = Some(run(move || { + *this.fut = Some(spawn_blocking(move || { encoder.write(&chunk)?; Ok(encoder) })); @@ -169,7 +181,8 @@ impl MessageBody for Encoder { return Poll::Ready(Some(Ok(chunk))); } } - Poll::Ready(None) => { + + None => { if let Some(encoder) = this.encoder.take() { let chunk = encoder.finish()?; if chunk.is_empty() { @@ -182,7 +195,6 @@ impl MessageBody for Encoder { return Poll::Ready(None); } } - val => return val, } } } diff --git a/actix-http/src/encoding/mod.rs b/actix-http/src/encoding/mod.rs index 9eaf4104e..cb271c638 100644 --- a/actix-http/src/encoding/mod.rs +++ b/actix-http/src/encoding/mod.rs @@ -1,4 +1,5 @@ -//! Content-Encoding support +//! Content-Encoding support. + use std::io; use bytes::{Bytes, BytesMut}; diff --git a/actix-http/src/error.rs b/actix-http/src/error.rs index 0ebd4c05c..d3095e68d 100644 --- a/actix-http/src/error.rs +++ b/actix-http/src/error.rs @@ -7,24 +7,23 @@ use std::string::FromUtf8Error; use std::{fmt, io, result}; use actix_codec::{Decoder, Encoder}; -pub use actix_threadpool::BlockingError; use actix_utils::dispatcher::DispatcherError as FramedDispatcherError; use actix_utils::timeout::TimeoutError; use bytes::BytesMut; use derive_more::{Display, From}; -pub use futures_channel::oneshot::Canceled; use http::uri::InvalidUri; use http::{header, Error as HttpError, StatusCode}; use serde::de::value::Error as DeError; use serde_json::error::Error as JsonError; use serde_urlencoded::ser::Error as FormError; -// re-export for convenience use crate::body::Body; -pub use crate::cookie::ParseError as CookieParseError; use crate::helpers::Writer; use crate::response::{Response, ResponseBuilder}; +#[cfg(feature = "cookies")] +pub use crate::cookie::ParseError as CookieParseError; + /// A specialized [`std::result::Result`] /// for actix web operations /// @@ -40,7 +39,7 @@ pub type Result = result::Result; /// converting errors with `into()`. /// /// Whenever it is created from an external object a response error is created -/// for it that can be used to create an http response from it this means that +/// for it that can be used to create an HTTP response from it this means that /// if you have access to an actix `Error` you can always get a /// `ResponseError` reference from it. pub struct Error { @@ -100,10 +99,6 @@ impl fmt::Debug for Error { } impl std::error::Error for Error { - fn cause(&self) -> Option<&dyn std::error::Error> { - None - } - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { None } @@ -153,7 +148,10 @@ impl From for Error { } } -/// Return `GATEWAY_TIMEOUT` for `TimeoutError` +/// Inspects the underlying enum and returns an appropriate status code. +/// +/// If the variant is [`TimeoutError::Service`], the error code of the service is returned. +/// Otherwise, [`StatusCode::GATEWAY_TIMEOUT`] is returned. impl ResponseError for TimeoutError { fn status_code(&self) -> StatusCode { match self { @@ -167,48 +165,41 @@ impl ResponseError for TimeoutError { #[display(fmt = "UnknownError")] struct UnitError; -/// `InternalServerError` for `UnitError` +/// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`UnitError`]. impl ResponseError for UnitError {} -/// `InternalServerError` for `JsonError` +/// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`JsonError`]. impl ResponseError for JsonError {} -/// `InternalServerError` for `FormError` +/// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`FormError`]. impl ResponseError for FormError {} #[cfg(feature = "openssl")] -/// `InternalServerError` for `openssl::ssl::Error` -impl ResponseError for actix_connect::ssl::openssl::SslError {} +/// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`actix_tls::accept::openssl::SslError`]. +impl ResponseError for actix_tls::accept::openssl::SslError {} -#[cfg(feature = "openssl")] -/// `InternalServerError` for `openssl::ssl::HandshakeError` -impl ResponseError for actix_tls::openssl::HandshakeError {} - -/// Return `BAD_REQUEST` for `de::value::Error` +/// Returns [`StatusCode::BAD_REQUEST`] for [`DeError`]. impl ResponseError for DeError { fn status_code(&self) -> StatusCode { StatusCode::BAD_REQUEST } } -/// `InternalServerError` for `Canceled` -impl ResponseError for Canceled {} - -/// `InternalServerError` for `BlockingError` -impl ResponseError for BlockingError {} - -/// Return `BAD_REQUEST` for `Utf8Error` +/// Returns [`StatusCode::BAD_REQUEST`] for [`Utf8Error`]. impl ResponseError for Utf8Error { fn status_code(&self) -> StatusCode { StatusCode::BAD_REQUEST } } -/// Return `InternalServerError` for `HttpError`, -/// Response generation can return `HttpError`, so it is internal error +/// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`HttpError`]. impl ResponseError for HttpError {} -/// Return `InternalServerError` for `io::Error` +/// Inspects the underlying [`io::ErrorKind`] and returns an appropriate status code. +/// +/// If the error is [`io::ErrorKind::NotFound`], [`StatusCode::NOT_FOUND`] is returned. If the +/// error is [`io::ErrorKind::PermissionDenied`], [`StatusCode::FORBIDDEN`] is returned. Otherwise, +/// [`StatusCode::INTERNAL_SERVER_ERROR`] is returned. impl ResponseError for io::Error { fn status_code(&self) -> StatusCode { match self.kind() { @@ -219,7 +210,7 @@ impl ResponseError for io::Error { } } -/// `BadRequest` for `InvalidHeaderValue` +/// Returns [`StatusCode::BAD_REQUEST`] for [`header::InvalidHeaderValue`]. impl ResponseError for header::InvalidHeaderValue { fn status_code(&self) -> StatusCode { StatusCode::BAD_REQUEST @@ -308,33 +299,60 @@ impl From for ParseError { } } +/// A set of errors that can occur running blocking tasks in thread pool. +#[derive(Debug, Display)] +#[display(fmt = "Blocking thread pool is gone")] +pub struct BlockingError; + +impl std::error::Error for BlockingError {} + +/// `InternalServerError` for `BlockingError` +impl ResponseError for BlockingError {} + #[derive(Display, Debug)] /// A set of errors that can occur during payload parsing pub enum PayloadError { /// A payload reached EOF, but is not complete. #[display( - fmt = "A payload reached EOF, but is not complete. With error: {:?}", + fmt = "A payload reached EOF, but is not complete. Inner error: {:?}", _0 )] Incomplete(Option), - /// Content encoding stream corruption + + /// Content encoding stream corruption. #[display(fmt = "Can not decode content-encoding.")] EncodingCorrupted, - /// A payload reached size limit. - #[display(fmt = "A payload reached size limit.")] + + /// Payload reached size limit. + #[display(fmt = "Payload reached size limit.")] Overflow, - /// A payload length is unknown. - #[display(fmt = "A payload length is unknown.")] + + /// Payload length is unknown. + #[display(fmt = "Payload length is unknown.")] UnknownLength, - /// Http2 payload error + + /// HTTP/2 payload error. #[display(fmt = "{}", _0)] Http2Payload(h2::Error), - /// Io error + + /// Generic I/O error. #[display(fmt = "{}", _0)] Io(io::Error), } -impl std::error::Error for PayloadError {} +impl std::error::Error for PayloadError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + PayloadError::Incomplete(None) => None, + PayloadError::Incomplete(Some(err)) => Some(err as &dyn std::error::Error), + PayloadError::EncodingCorrupted => None, + PayloadError::Overflow => None, + PayloadError::UnknownLength => None, + PayloadError::Http2Payload(err) => Some(err as &dyn std::error::Error), + PayloadError::Io(err) => Some(err as &dyn std::error::Error), + } + } +} impl From for PayloadError { fn from(err: h2::Error) -> Self { @@ -354,15 +372,12 @@ impl From for PayloadError { } } -impl From> for PayloadError { - fn from(err: BlockingError) -> Self { - match err { - BlockingError::Error(e) => PayloadError::Io(e), - BlockingError::Canceled => PayloadError::Io(io::Error::new( - io::ErrorKind::Other, - "Operation is canceled", - )), - } +impl From for PayloadError { + fn from(_: BlockingError) -> Self { + PayloadError::Io(io::Error::new( + io::ErrorKind::Other, + "Operation is canceled", + )) } } @@ -380,6 +395,7 @@ impl ResponseError for PayloadError { } /// Return `BadRequest` for `cookie::ParseError` +#[cfg(feature = "cookies")] impl ResponseError for crate::cookie::ParseError { fn status_code(&self) -> StatusCode { StatusCode::BAD_REQUEST @@ -387,7 +403,7 @@ impl ResponseError for crate::cookie::ParseError { } #[derive(Debug, Display, From)] -/// A set of errors that can occur during dispatching http requests +/// A set of errors that can occur during dispatching HTTP requests pub enum DispatchError { /// Service error Service(Error), @@ -951,16 +967,6 @@ where InternalError::new(err, StatusCode::NETWORK_AUTHENTICATION_REQUIRED).into() } -#[cfg(feature = "actors")] -/// `InternalServerError` for `actix::MailboxError` -/// This is supported on feature=`actors` only -impl ResponseError for actix::MailboxError {} - -#[cfg(feature = "actors")] -/// `InternalServerError` for `actix::ResolverError` -/// This is supported on feature=`actors` only -impl ResponseError for actix::actors::resolver::ResolverError {} - #[cfg(test)] mod tests { use super::*; @@ -977,6 +983,7 @@ mod tests { assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } + #[cfg(feature = "cookies")] #[test] fn test_cookie_parse() { let resp: Response = CookieParseError::EmptyName.error_response(); @@ -1018,22 +1025,22 @@ mod tests { fn test_payload_error() { let err: PayloadError = io::Error::new(io::ErrorKind::Other, "ParseError").into(); - assert!(format!("{}", err).contains("ParseError")); + assert!(err.to_string().contains("ParseError")); let err = PayloadError::Incomplete(None); assert_eq!( - format!("{}", err), - "A payload reached EOF, but is not complete. With error: None" + err.to_string(), + "A payload reached EOF, but is not complete. Inner error: None" ); } macro_rules! from { ($from:expr => $error:pat) => { match ParseError::from($from) { - e @ $error => { - assert!(format!("{}", e).len() >= 5); + err @ $error => { + assert!(err.to_string().len() >= 5); } - e => unreachable!("{:?}", e), + err => unreachable!("{:?}", err), } }; } @@ -1076,7 +1083,7 @@ mod tests { let err = PayloadError::Overflow; let resp_err: &dyn ResponseError = &err; let err = resp_err.downcast_ref::().unwrap(); - assert_eq!(err.to_string(), "A payload reached size limit."); + assert_eq!(err.to_string(), "Payload reached size limit."); let not_err = resp_err.downcast_ref::(); assert!(not_err.is_none()); } diff --git a/actix-http/src/extensions.rs b/actix-http/src/extensions.rs index b20dfe11d..5fdcefd6d 100644 --- a/actix-http/src/extensions.rs +++ b/actix-http/src/extensions.rs @@ -1,62 +1,119 @@ -use std::any::{Any, TypeId}; -use std::{fmt, mem}; +use std::{ + any::{Any, TypeId}, + fmt, mem, +}; -use fxhash::FxHashMap; +use ahash::AHashMap; -/// A type map of request extensions. +/// A type map for request extensions. +/// +/// All entries into this map must be owned types (or static references). #[derive(Default)] pub struct Extensions { /// Use FxHasher with a std HashMap with for faster /// lookups on the small `TypeId` (u64 equivalent) keys. - map: FxHashMap>, + map: AHashMap>, } impl Extensions { - /// Create an empty `Extensions`. + /// Creates an empty `Extensions`. #[inline] pub fn new() -> Extensions { Extensions { - map: FxHashMap::default(), + map: AHashMap::default(), } } - /// Insert a type into this `Extensions`. + /// Insert an item into the map. /// - /// If a extension of this type already existed, it will - /// be returned. - pub fn insert(&mut self, val: T) { - self.map.insert(TypeId::of::(), Box::new(val)); + /// If an item of this type was already stored, it will be replaced and returned. + /// + /// ``` + /// # use actix_http::Extensions; + /// let mut map = Extensions::new(); + /// assert_eq!(map.insert(""), None); + /// assert_eq!(map.insert(1u32), None); + /// assert_eq!(map.insert(2u32), Some(1u32)); + /// assert_eq!(*map.get::().unwrap(), 2u32); + /// ``` + pub fn insert(&mut self, val: T) -> Option { + self.map + .insert(TypeId::of::(), Box::new(val)) + .and_then(downcast_owned) } - /// Check if container contains entry + /// Check if map contains an item of a given type. + /// + /// ``` + /// # use actix_http::Extensions; + /// let mut map = Extensions::new(); + /// assert!(!map.contains::()); + /// + /// assert_eq!(map.insert(1u32), None); + /// assert!(map.contains::()); + /// ``` pub fn contains(&self) -> bool { self.map.contains_key(&TypeId::of::()) } - /// Get a reference to a type previously inserted on this `Extensions`. + /// Get a reference to an item of a given type. + /// + /// ``` + /// # use actix_http::Extensions; + /// let mut map = Extensions::new(); + /// map.insert(1u32); + /// assert_eq!(map.get::(), Some(&1u32)); + /// ``` pub fn get(&self) -> Option<&T> { self.map .get(&TypeId::of::()) .and_then(|boxed| boxed.downcast_ref()) } - /// Get a mutable reference to a type previously inserted on this `Extensions`. + /// Get a mutable reference to an item of a given type. + /// + /// ``` + /// # use actix_http::Extensions; + /// let mut map = Extensions::new(); + /// map.insert(1u32); + /// assert_eq!(map.get_mut::(), Some(&mut 1u32)); + /// ``` pub fn get_mut(&mut self) -> Option<&mut T> { self.map .get_mut(&TypeId::of::()) .and_then(|boxed| boxed.downcast_mut()) } - /// Remove a type from this `Extensions`. + /// Remove an item from the map of a given type. /// - /// If a extension of this type existed, it will be returned. + /// If an item of this type was already stored, it will be returned. + /// + /// ``` + /// # use actix_http::Extensions; + /// let mut map = Extensions::new(); + /// + /// map.insert(1u32); + /// assert_eq!(map.get::(), Some(&1u32)); + /// + /// assert_eq!(map.remove::(), Some(1u32)); + /// assert!(!map.contains::()); + /// ``` pub fn remove(&mut self) -> Option { - self.map - .remove(&TypeId::of::()) - .and_then(|boxed| boxed.downcast().ok().map(|boxed| *boxed)) + self.map.remove(&TypeId::of::()).and_then(downcast_owned) } /// Clear the `Extensions` of all inserted extensions. + /// + /// ``` + /// # use actix_http::Extensions; + /// let mut map = Extensions::new(); + /// + /// map.insert(1u32); + /// assert!(map.contains::()); + /// + /// map.clear(); + /// assert!(!map.contains::()); + /// ``` #[inline] pub fn clear(&mut self) { self.map.clear(); @@ -79,6 +136,10 @@ impl fmt::Debug for Extensions { } } +fn downcast_owned(boxed: Box) -> Option { + boxed.downcast().ok().map(|boxed| *boxed) +} + #[cfg(test)] mod tests { use super::*; diff --git a/actix-http/src/h1/client.rs b/actix-http/src/h1/client.rs index 2e0103409..4a6104688 100644 --- a/actix-http/src/h1/client.rs +++ b/actix-http/src/h1/client.rs @@ -223,15 +223,3 @@ impl Encoder> for ClientCodec { Ok(()) } } - -pub struct Writer<'a>(pub &'a mut BytesMut); - -impl<'a> io::Write for Writer<'a> { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.extend_from_slice(buf); - Ok(buf.len()) - } - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} diff --git a/actix-http/src/h1/codec.rs b/actix-http/src/h1/codec.rs index d5035df26..634ca25e8 100644 --- a/actix-http/src/h1/codec.rs +++ b/actix-http/src/h1/codec.rs @@ -199,10 +199,10 @@ mod tests { use http::Method; use super::*; - use crate::httpmessage::HttpMessage; + use crate::HttpMessage; - #[test] - fn test_http_request_chunked_payload_and_next_message() { + #[actix_rt::test] + async fn test_http_request_chunked_payload_and_next_message() { let mut codec = Codec::default(); let mut buf = BytesMut::from( diff --git a/actix-http/src/h1/decoder.rs b/actix-http/src/h1/decoder.rs index 8e891dc5c..93a4b13d2 100644 --- a/actix-http/src/h1/decoder.rs +++ b/actix-http/src/h1/decoder.rs @@ -14,7 +14,7 @@ use crate::header::HeaderMap; use crate::message::{ConnectionType, ResponseHead}; use crate::request::Request; -const MAX_BUFFER_SIZE: usize = 131_072; +pub(crate) const MAX_BUFFER_SIZE: usize = 131_072; const MAX_HEADERS: usize = 96; /// Incoming message decoder @@ -137,7 +137,7 @@ pub(crate) trait MessageType: Sized { expect = true; } } - _ => (), + _ => {} } headers.append(name, value); @@ -203,7 +203,15 @@ impl MessageType for Request { (len, method, uri, version, req.headers.len()) } - httparse::Status::Partial => return Ok(None), + httparse::Status::Partial => { + return if src.len() >= MAX_BUFFER_SIZE { + trace!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + Err(ParseError::TooLarge) + } else { + // Return None to notify more read are needed for parsing request + Ok(None) + }; + } } }; @@ -216,15 +224,12 @@ impl MessageType for Request { let decoder = match length { PayloadLength::Payload(pl) => pl, PayloadLength::UpgradeWebSocket => { - // upgrade(websocket) + // upgrade (WebSocket) PayloadType::Stream(PayloadDecoder::eof()) } PayloadLength::None => { if method == Method::CONNECT { PayloadType::Stream(PayloadDecoder::eof()) - } else if src.len() >= MAX_BUFFER_SIZE { - trace!("MAX_BUFFER_SIZE unprocessed data reached, closing"); - return Err(ParseError::TooLarge); } else { PayloadType::None } @@ -273,7 +278,14 @@ impl MessageType for ResponseHead { (len, version, status, res.headers.len()) } - httparse::Status::Partial => return Ok(None), + httparse::Status::Partial => { + return if src.len() >= MAX_BUFFER_SIZE { + error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + Err(ParseError::TooLarge) + } else { + Ok(None) + } + } } }; @@ -289,9 +301,6 @@ impl MessageType for ResponseHead { } else if status == StatusCode::SWITCHING_PROTOCOLS { // switching protocol or connect PayloadType::Stream(PayloadDecoder::eof()) - } else if src.len() >= MAX_BUFFER_SIZE { - error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); - return Err(ParseError::TooLarge); } else { // for HTTP/1.0 read to eof and close connection if msg.version == Version::HTTP_10 { @@ -643,7 +652,7 @@ mod tests { use super::*; use crate::error::ParseError; use crate::http::header::{HeaderName, SET_COOKIE}; - use crate::httpmessage::HttpMessage; + use crate::HttpMessage; impl PayloadType { fn unwrap(self) -> PayloadDecoder { @@ -685,7 +694,7 @@ mod tests { match MessageDecoder::::default().decode($e) { Err(err) => match err { ParseError::Io(_) => unreachable!("Parse error expected"), - _ => (), + _ => {} }, _ => unreachable!("Error expected"), } @@ -821,8 +830,8 @@ mod tests { .get_all(SET_COOKIE) .map(|v| v.to_str().unwrap().to_owned()) .collect(); - assert_eq!(val[1], "c1=cookie1"); - assert_eq!(val[0], "c2=cookie2"); + assert_eq!(val[0], "c1=cookie1"); + assert_eq!(val[1], "c2=cookie2"); } #[test] diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 91e208aac..6df579c0a 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -4,59 +4,56 @@ use std::{ future::Future, io, mem, net, pin::Pin, + rc::Rc, task::{Context, Poll}, }; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; -use actix_rt::time::{delay_until, Delay, Instant}; +use actix_rt::time::{sleep_until, Instant, Sleep}; use actix_service::Service; use bitflags::bitflags; use bytes::{Buf, BytesMut}; +use futures_core::ready; use log::{error, trace}; use pin_project::pin_project; -use crate::cloneable::CloneableService; +use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error}; use crate::error::{ParseError, PayloadError}; -use crate::httpmessage::HttpMessage; use crate::request::Request; use crate::response::Response; -use crate::{ - body::{Body, BodySize, MessageBody, ResponseBody}, - Extensions, -}; +use crate::service::HttpFlow; +use crate::OnConnectData; use super::codec::Codec; use super::payload::{Payload, PayloadSender, PayloadStatus}; use super::{Message, MessageType}; -const LW_BUFFER_SIZE: usize = 4096; -const HW_BUFFER_SIZE: usize = 32_768; +const LW_BUFFER_SIZE: usize = 1024; +const HW_BUFFER_SIZE: usize = 1024 * 8; const MAX_PIPELINED_MESSAGES: usize = 16; bitflags! { pub struct Flags: u8 { const STARTED = 0b0000_0001; const KEEPALIVE = 0b0000_0010; - const POLLED = 0b0000_0100; - const SHUTDOWN = 0b0000_1000; - const READ_DISCONNECT = 0b0001_0000; - const WRITE_DISCONNECT = 0b0010_0000; - const UPGRADE = 0b0100_0000; + const SHUTDOWN = 0b0000_0100; + const READ_DISCONNECT = 0b0000_1000; + const WRITE_DISCONNECT = 0b0001_0000; } } -#[pin_project::pin_project] +#[pin_project] /// Dispatcher for HTTP/1.1 protocol pub struct Dispatcher where - S: Service, + S: Service, S::Error: Into, B: MessageBody, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { #[pin] @@ -69,33 +66,31 @@ where #[pin_project(project = DispatcherStateProj)] enum DispatcherState where - S: Service, + S: Service, S::Error: Into, B: MessageBody, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { Normal(#[pin] InnerDispatcher), - Upgrade(Pin>), + Upgrade(#[pin] U::Future), } #[pin_project(project = InnerDispatcherProj)] struct InnerDispatcher where - S: Service, + S: Service, S::Error: Into, B: MessageBody, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { - service: CloneableService, - expect: CloneableService, - upgrade: Option>, - on_connect_data: Extensions, + flow: Rc>, + on_connect_data: OnConnectData, flags: Flags, peer_addr: Option, error: Option, @@ -106,7 +101,8 @@ where messages: VecDeque, ka_expire: Instant, - ka_timer: Option, + #[pin] + ka_timer: Option, io: Option, read_buf: BytesMut, @@ -123,8 +119,8 @@ enum DispatcherMessage { #[pin_project(project = StateProj)] enum State where - S: Service, - X: Service, + S: Service, + X: Service, B: MessageBody, { None, @@ -135,112 +131,64 @@ where impl State where - S: Service, - X: Service, + S: Service, + X: Service, B: MessageBody, { fn is_empty(&self) -> bool { matches!(self, State::None) } - - fn is_call(&self) -> bool { - matches!(self, State::ServiceCall(_)) - } } + enum PollResponse { Upgrade(Request), DoNothing, DrainWriteBuf, } -impl PartialEq for PollResponse { - fn eq(&self, other: &PollResponse) -> bool { - match self { - PollResponse::DrainWriteBuf => matches!(other, PollResponse::DrainWriteBuf), - PollResponse::DoNothing => matches!(other, PollResponse::DoNothing), - _ => false, - } - } -} - impl Dispatcher where T: AsyncRead + AsyncWrite + Unpin, - S: Service, + S: Service, S::Error: Into, S::Response: Into>, B: MessageBody, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { /// Create HTTP/1 dispatcher. pub(crate) fn new( - stream: T, - config: ServiceConfig, - service: CloneableService, - expect: CloneableService, - upgrade: Option>, - on_connect_data: Extensions, - peer_addr: Option, - ) -> Self { - Dispatcher::with_timeout( - stream, - Codec::new(config.clone()), - config, - BytesMut::with_capacity(HW_BUFFER_SIZE), - None, - service, - expect, - upgrade, - on_connect_data, - peer_addr, - ) - } - - /// Create http/1 dispatcher with slow request timeout. - pub(crate) fn with_timeout( io: T, - codec: Codec, config: ServiceConfig, - read_buf: BytesMut, - timeout: Option, - service: CloneableService, - expect: CloneableService, - upgrade: Option>, - on_connect_data: Extensions, + flow: Rc>, + on_connect_data: OnConnectData, peer_addr: Option, ) -> Self { - let keepalive = config.keep_alive_enabled(); - let flags = if keepalive { + let flags = if config.keep_alive_enabled() { Flags::KEEPALIVE } else { Flags::empty() }; // keep-alive timer - let (ka_expire, ka_timer) = if let Some(delay) = timeout { - (delay.deadline(), Some(delay)) - } else if let Some(delay) = config.keep_alive_timer() { - (delay.deadline(), Some(delay)) - } else { - (config.now(), None) + let (ka_expire, ka_timer) = match config.keep_alive_timer() { + Some(delay) => (delay.deadline(), Some(delay)), + None => (config.now(), None), }; Dispatcher { inner: DispatcherState::Normal(InnerDispatcher { + read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE), write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE), payload: None, state: State::None, error: None, messages: VecDeque::new(), io: Some(io), - codec, - read_buf, - service, - expect, - upgrade, + codec: Codec::new(config), + flow, on_connect_data, flags, peer_addr, @@ -257,20 +205,17 @@ where impl InnerDispatcher where T: AsyncRead + AsyncWrite + Unpin, - S: Service, + S: Service, S::Error: Into, S::Response: Into>, B: MessageBody, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { fn can_read(&self, cx: &mut Context<'_>) -> bool { - if self - .flags - .intersects(Flags::READ_DISCONNECT | Flags::UPGRADE) - { + if self.flags.contains(Flags::READ_DISCONNECT) { false } else if let Some(ref info) = self.payload { info.need_read(cx) == PayloadStatus::Read @@ -289,52 +234,37 @@ where } } - /// Flush stream - /// - /// true - got WouldBlock - /// false - didn't get WouldBlock fn poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Result { - if self.write_buf.is_empty() { - return Ok(false); - } - - let len = self.write_buf.len(); - let mut written = 0; + ) -> Poll> { let InnerDispatcherProj { io, write_buf, .. } = self.project(); let mut io = Pin::new(io.as_mut().unwrap()); + + let len = write_buf.len(); + let mut written = 0; + while written < len { - match io.as_mut().poll_write(cx, &write_buf[written..]) { - Poll::Ready(Ok(0)) => { - return Err(DispatchError::Io(io::Error::new( + match io.as_mut().poll_write(cx, &write_buf[written..])? { + Poll::Ready(0) => { + return Poll::Ready(Err(io::Error::new( io::ErrorKind::WriteZero, "", - ))); - } - Poll::Ready(Ok(n)) => { - written += n; + ))) } + Poll::Ready(n) => written += n, Poll::Pending => { - if written > 0 { - write_buf.advance(written); - } - return Ok(true); + write_buf.advance(written); + return Poll::Pending; } - Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), } } - if written == write_buf.len() { - // SAFETY: setting length to 0 is safe - // skips one length check vs truncate - unsafe { write_buf.set_len(0) } - } else { - write_buf.advance(written); - } + // everything has written to io. clear buffer. + write_buf.clear(); - Ok(false) + // flush the io and check if get blocked. + io.poll_flush(cx) } fn send_response( @@ -342,9 +272,10 @@ where message: Response<()>, body: ResponseBody, ) -> Result<(), DispatchError> { + let size = body.size(); let mut this = self.project(); this.codec - .encode(Message::Item((message, body.size())), &mut this.write_buf) + .encode(Message::Item((message, size)), &mut this.write_buf) .map_err(|err| { if let Some(mut payload) = this.payload.take() { payload.set_error(PayloadError::Incomplete(None)); @@ -353,7 +284,7 @@ where })?; this.flags.set(Flags::KEEPALIVE, this.codec.keepalive()); - match body.size() { + match size { BodySize::None | BodySize::Empty => this.state.set(State::None), _ => this.state.set(State::SendPayload(body)), }; @@ -370,108 +301,121 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Result { - loop { + 'res: loop { let mut this = self.as_mut().project(); - // state is not changed on Poll::Pending. - // other variant and conditions always trigger a state change(or an error). - let state_change = match this.state.project() { + match this.state.as_mut().project() { + // no future is in InnerDispatcher state. pop next message. StateProj::None => match this.messages.pop_front() { + // handle request message. Some(DispatcherMessage::Item(req)) => { - self.as_mut().handle_request(req, cx)?; - true + // Handle `EXPECT: 100-Continue` header + if req.head().expect() { + // set InnerDispatcher state and continue loop to poll it. + let task = this.flow.expect.call(req); + this.state.set(State::ExpectCall(task)); + } else { + // the same as expect call. + let task = this.flow.service.call(req); + this.state.set(State::ServiceCall(task)); + }; } + + // handle error message. Some(DispatcherMessage::Error(res)) => { + // send_response would update InnerDispatcher state to SendPayload or + // None(If response body is empty). + // continue loop to poll it. self.as_mut() .send_response(res, ResponseBody::Other(Body::Empty))?; - true } + + // return with upgrade request and poll it exclusively. Some(DispatcherMessage::Upgrade(req)) => { return Ok(PollResponse::Upgrade(req)); } - None => false, - }, - StateProj::ExpectCall(fut) => match fut.poll(cx) { - Poll::Ready(Ok(req)) => { - self.as_mut().send_continue(); - this = self.as_mut().project(); - this.state.set(State::ServiceCall(this.service.call(req))); - continue; - } - Poll::Ready(Err(e)) => { - let res: Response = e.into().into(); - let (res, body) = res.replace_body(()); - self.as_mut().send_response(res, body.into_body())?; - true - } - Poll::Pending => false, + + // all messages are dealt with. + None => return Ok(PollResponse::DoNothing), }, StateProj::ServiceCall(fut) => match fut.poll(cx) { + // service call resolved. send response. Poll::Ready(Ok(res)) => { let (res, body) = res.into().replace_body(()); self.as_mut().send_response(res, body)?; - continue; } - Poll::Ready(Err(e)) => { - let res: Response = e.into().into(); + + // send service call error as response + Poll::Ready(Err(err)) => { + let res: Response = err.into().into(); let (res, body) = res.replace_body(()); self.as_mut().send_response(res, body.into_body())?; - true } - Poll::Pending => false, - }, - StateProj::SendPayload(mut stream) => { - loop { - if this.write_buf.len() < HW_BUFFER_SIZE { - match stream.as_mut().poll_next(cx) { - Poll::Ready(Some(Ok(item))) => { - this.codec.encode( - Message::Chunk(Some(item)), - &mut this.write_buf, - )?; - continue; - } - Poll::Ready(None) => { - this.codec.encode( - Message::Chunk(None), - &mut this.write_buf, - )?; - this = self.as_mut().project(); - this.state.set(State::None); - } - Poll::Ready(Some(Err(_))) => { - return Err(DispatchError::Unknown) - } - Poll::Pending => return Ok(PollResponse::DoNothing), - } - } else { - return Ok(PollResponse::DrainWriteBuf); + + // service call pending and could be waiting for more chunk messages. + // (pipeline message limit and/or payload can_read limit) + Poll::Pending => { + // no new message is decoded and no new payload is feed. + // nothing to do except waiting for new incoming data from client. + if !self.as_mut().poll_request(cx)? { + return Ok(PollResponse::DoNothing); } - break; + // otherwise keep loop. } - continue; - } - }; + }, - // state is changed and continue when the state is not Empty - if state_change { - if !self.state.is_empty() { - continue; - } - } else { - // if read-backpressure is enabled and we consumed some data. - // we may read more data and retry - if self.state.is_call() { - if self.as_mut().poll_request(cx)? { - continue; + StateProj::SendPayload(mut stream) => { + // keep populate writer buffer until buffer size limit hit, + // get blocked or finished. + while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE { + match stream.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(item))) => { + this.codec.encode( + Message::Chunk(Some(item)), + &mut this.write_buf, + )?; + } + + Poll::Ready(None) => { + this.codec + .encode(Message::Chunk(None), &mut this.write_buf)?; + // payload stream finished. + // set state to None and handle next message + this.state.set(State::None); + continue 'res; + } + + Poll::Ready(Some(Err(err))) => { + return Err(DispatchError::Service(err)) + } + + Poll::Pending => return Ok(PollResponse::DoNothing), + } } - } else if !self.messages.is_empty() { - continue; + // buffer is beyond max size. + // return and try to write the whole buffer to io stream. + return Ok(PollResponse::DrainWriteBuf); } + + StateProj::ExpectCall(fut) => match fut.poll(cx) { + // expect resolved. write continue to buffer and set InnerDispatcher state + // to service call. + Poll::Ready(Ok(req)) => { + this.write_buf + .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); + let fut = this.flow.service.call(req); + this.state.set(State::ServiceCall(fut)); + } + // send expect error as response + Poll::Ready(Err(err)) => { + let res: Response = err.into().into(); + let (res, body) = res.replace_body(()); + self.as_mut().send_response(res, body.into_body())?; + } + // expect must be solved before progress can be made. + Poll::Pending => return Ok(PollResponse::DoNothing), + }, } - break; } - - Ok(PollResponse::DoNothing) } fn handle_request( @@ -482,12 +426,14 @@ where // Handle `EXPECT: 100-Continue` header if req.head().expect() { // set dispatcher state so the future is pinned. - let task = self.as_mut().project().expect.call(req); - self.as_mut().project().state.set(State::ExpectCall(task)); + let mut this = self.as_mut().project(); + let task = this.flow.expect.call(req); + this.state.set(State::ExpectCall(task)); } else { // the same as above. - let task = self.as_mut().project().service.call(req); - self.as_mut().project().state.set(State::ServiceCall(task)); + let mut this = self.as_mut().project(); + let task = this.flow.service.call(req); + this.state.set(State::ServiceCall(task)); }; // eagerly poll the future for once(or twice if expect is resolved immediately). @@ -498,8 +444,9 @@ where // expect is resolved. continue loop and poll the service call branch. Poll::Ready(Ok(req)) => { self.as_mut().send_continue(); - let task = self.as_mut().project().service.call(req); - self.as_mut().project().state.set(State::ServiceCall(task)); + let mut this = self.as_mut().project(); + let task = this.flow.service.call(req); + this.state.set(State::ServiceCall(task)); continue; } // future is pending. return Ok(()) to notify that a new state is @@ -508,9 +455,9 @@ where // future is error. send response and return a result. On success // to notify the dispatcher a new state is set and the outer loop // should be continue. - Poll::Ready(Err(e)) => { - let e = e.into(); - let res: Response = e.into(); + Poll::Ready(Err(err)) => { + let err = err.into(); + let res: Response = err.into(); let (res, body) = res.replace_body(()); return self.send_response(res, body.into_body()); } @@ -528,9 +475,9 @@ where } // see the comment on ExpectCall state branch's Pending. Poll::Pending => Ok(()), - // see the comment on ExpectCall state branch's Ready(Err(e)). - Poll::Ready(Err(e)) => { - let res: Response = e.into().into(); + // see the comment on ExpectCall state branch's Ready(Err(err)). + Poll::Ready(Err(err)) => { + let res: Response = err.into().into(); let (res, body) = res.replace_body(()); self.send_response(res, body.into_body()) } @@ -544,7 +491,7 @@ where } /// Process one incoming request. - pub(self) fn poll_request( + fn poll_request( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Result { @@ -563,25 +510,43 @@ where match msg { Message::Item(mut req) => { - let pl = this.codec.message_type(); req.head_mut().peer_addr = *this.peer_addr; // merge on_connect_ext data into request extensions - req.extensions_mut().drain_from(this.on_connect_data); + this.on_connect_data.merge_into(&mut req); - if pl == MessageType::Stream && this.upgrade.is_some() { - this.messages.push_back(DispatcherMessage::Upgrade(req)); - break; - } - if pl == MessageType::Payload || pl == MessageType::Stream { - let (ps, pl) = Payload::create(false); - let (req1, _) = - req.replace_payload(crate::Payload::H1(pl)); - req = req1; - *this.payload = Some(ps); + match this.codec.message_type() { + // Request is upgradable. add upgrade message and break. + // everything remain in read buffer would be handed to + // upgraded Request. + MessageType::Stream if this.flow.upgrade.is_some() => { + this.messages + .push_back(DispatcherMessage::Upgrade(req)); + break; + } + + // Request is not upgradable. + MessageType::Payload | MessageType::Stream => { + /* + PayloadSender and Payload are smart pointers share the + same state. + PayloadSender is attached to dispatcher and used to sink + new chunked request data to state. + Payload is attached to Request and passed to Service::call + where the state can be collected and consumed. + */ + let (ps, pl) = Payload::create(false); + let (req1, _) = + req.replace_payload(crate::Payload::H1(pl)); + req = req1; + *this.payload = Some(ps); + } + + // Request has no payload. + MessageType::None => {} } - // handle request early + // handle request early when no future in InnerDispatcher state. if this.state.is_empty() { self.as_mut().handle_request(req, cx)?; this = self.as_mut().project(); @@ -619,14 +584,28 @@ where } } } + // decode is partial and buffer is not full yet. + // break and wait for more read. Ok(None) => break, - Err(ParseError::Io(e)) => { + Err(ParseError::Io(err)) => { self.as_mut().client_disconnected(); this = self.as_mut().project(); - *this.error = Some(DispatchError::Io(e)); + *this.error = Some(DispatchError::Io(err)); break; } - Err(e) => { + Err(ParseError::TooLarge) => { + if let Some(mut payload) = this.payload.take() { + payload.set_error(PayloadError::Overflow); + } + // Requests overflow buffer size should be responded with 431 + this.messages.push_back(DispatcherMessage::Error( + Response::RequestHeaderFieldsTooLarge().finish().drop_body(), + )); + this.flags.insert(Flags::READ_DISCONNECT); + *this.error = Some(ParseError::TooLarge.into()); + break; + } + Err(err) => { if let Some(mut payload) = this.payload.take() { payload.set_error(PayloadError::EncodingCorrupted); } @@ -636,7 +615,7 @@ where Response::BadRequest().finish().drop_body(), )); this.flags.insert(Flags::READ_DISCONNECT); - *this.error = Some(e.into()); + *this.error = Some(err.into()); break; } } @@ -656,93 +635,203 @@ where cx: &mut Context<'_>, ) -> Result<(), DispatchError> { let mut this = self.as_mut().project(); - if this.ka_timer.is_none() { - // shutdown timeout - if this.flags.contains(Flags::SHUTDOWN) { - if let Some(interval) = this.codec.config().client_disconnect_timer() { - *this.ka_timer = Some(delay_until(interval)); - } else { - this.flags.insert(Flags::READ_DISCONNECT); - if let Some(mut payload) = this.payload.take() { - payload.set_error(PayloadError::Incomplete(None)); - } - return Ok(()); - } - } else { - return Ok(()); - } - } - match Pin::new(&mut this.ka_timer.as_mut().unwrap()).poll(cx) { - Poll::Ready(()) => { - // if we get timeout during shutdown, drop connection + // when a branch is not explicit return early it's meant to fall through + // and return as Ok(()) + match this.ka_timer.as_mut().as_pin_mut() { + None => { + // conditionally go into shutdown timeout if this.flags.contains(Flags::SHUTDOWN) { - return Err(DispatchError::DisconnectTimeout); - } else if this.ka_timer.as_mut().unwrap().deadline() >= *this.ka_expire { - // check for any outstanding tasks - if this.state.is_empty() && this.write_buf.is_empty() { - if this.flags.contains(Flags::STARTED) { - trace!("Keep-alive timeout, close connection"); - this.flags.insert(Flags::SHUTDOWN); + if let Some(deadline) = this.codec.config().client_disconnect_timer() + { + // write client disconnect time out and poll again to + // go into Some> branch + this.ka_timer.set(Some(sleep_until(deadline))); + return self.poll_keepalive(cx); + } else { + this.flags.insert(Flags::READ_DISCONNECT); + if let Some(mut payload) = this.payload.take() { + payload.set_error(PayloadError::Incomplete(None)); + } + } + } + } + Some(mut timer) => { + // only operate when keep-alive timer is resolved. + if timer.as_mut().poll(cx).is_ready() { + // got timeout during shutdown, drop connection + if this.flags.contains(Flags::SHUTDOWN) { + return Err(DispatchError::DisconnectTimeout); + // exceed deadline. check for any outstanding tasks + } else if timer.deadline() >= *this.ka_expire { + // have no task at hand. + if this.state.is_empty() && this.write_buf.is_empty() { + if this.flags.contains(Flags::STARTED) { + trace!("Keep-alive timeout, close connection"); + this.flags.insert(Flags::SHUTDOWN); - // start shutdown timer - if let Some(deadline) = - this.codec.config().client_disconnect_timer() - { - if let Some(mut timer) = this.ka_timer.as_mut() { - timer.reset(deadline); - let _ = Pin::new(&mut timer).poll(cx); + // start shutdown timeout + if let Some(deadline) = + this.codec.config().client_disconnect_timer() + { + timer.as_mut().reset(deadline); + let _ = timer.poll(cx); + } else { + // no shutdown timeout, drop socket + this.flags.insert(Flags::WRITE_DISCONNECT); } } else { - // no shutdown timeout, drop socket - this.flags.insert(Flags::WRITE_DISCONNECT); - return Ok(()); + // timeout on first request (slow request) return 408 + if !this.flags.contains(Flags::STARTED) { + trace!("Slow request timeout"); + let _ = self.as_mut().send_response( + Response::RequestTimeout().finish().drop_body(), + ResponseBody::Other(Body::Empty), + ); + this = self.project(); + } else { + trace!("Keep-alive connection timeout"); + } + this.flags.insert(Flags::STARTED | Flags::SHUTDOWN); + this.state.set(State::None); } - } else { - // timeout on first request (slow request) return 408 - if !this.flags.contains(Flags::STARTED) { - trace!("Slow request timeout"); - let _ = self.as_mut().send_response( - Response::RequestTimeout().finish().drop_body(), - ResponseBody::Other(Body::Empty), - ); - this = self.as_mut().project(); - } else { - trace!("Keep-alive connection timeout"); - } - this.flags.insert(Flags::STARTED | Flags::SHUTDOWN); - this.state.set(State::None); - } - } else if let Some(deadline) = - this.codec.config().keep_alive_expire() - { - if let Some(mut timer) = this.ka_timer.as_mut() { - timer.reset(deadline); - let _ = Pin::new(&mut timer).poll(cx); + // still have unfinished task. try to reset and register keep-alive. + } else if let Some(deadline) = + this.codec.config().keep_alive_expire() + { + timer.as_mut().reset(deadline); + let _ = timer.poll(cx); } + // timer resolved but still have not met the keep-alive expire deadline. + // reset and register for later wakeup. + } else { + timer.as_mut().reset(*this.ka_expire); + let _ = timer.poll(cx); } - } else if let Some(mut timer) = this.ka_timer.as_mut() { - timer.reset(*this.ka_expire); - let _ = Pin::new(&mut timer).poll(cx); } } - Poll::Pending => (), } - Ok(()) } + + /// Returns true when io stream can be disconnected after write to it. + /// + /// It covers these conditions: + /// + /// - `std::io::ErrorKind::ConnectionReset` after partial read. + /// - all data read done. + #[inline(always)] + fn read_available( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Result { + let this = self.project(); + + if this.flags.contains(Flags::READ_DISCONNECT) { + return Ok(false); + }; + + let mut io = Pin::new(this.io.as_mut().unwrap()); + + let mut read_some = false; + + loop { + // Return early when read buf exceed decoder's max buffer size. + if this.read_buf.len() >= super::decoder::MAX_BUFFER_SIZE { + /* + At this point it's not known IO stream is still scheduled + to be waked up. so force wake up dispatcher just in case. + + Reason: + AsyncRead mostly would only have guarantee wake up + when the poll_read return Poll::Pending. + + Case: + When read_buf is beyond max buffer size the early return + could be successfully be parsed as a new Request. + This case would not generate ParseError::TooLarge + and at this point IO stream is not fully read to Pending + and would result in dispatcher stuck until timeout (KA) + + Note: + This is a perf choice to reduce branch on + ::decode. + + A Request head too large to parse is only checked on + httparse::Status::Partial condition. + */ + if this.payload.is_none() { + /* + When dispatcher has a payload the responsibility of + wake up it would be shift to h1::payload::Payload. + + Reason: + Self wake up when there is payload would waste poll + and/or result in over read. + + Case: + When payload is (partial) dropped by user there is + no need to do read anymore. + At this case read_buf could always remain beyond + MAX_BUFFER_SIZE and self wake up would be busy poll + dispatcher and waste resource. + + */ + cx.waker().wake_by_ref(); + } + + return Ok(false); + } + + // grow buffer if necessary. + let remaining = this.read_buf.capacity() - this.read_buf.len(); + if remaining < LW_BUFFER_SIZE { + this.read_buf.reserve(HW_BUFFER_SIZE - remaining); + } + + match actix_codec::poll_read_buf(io.as_mut(), cx, this.read_buf) { + Poll::Ready(Ok(n)) => { + if n == 0 { + return Ok(true); + } + read_some = true; + } + Poll::Pending => return Ok(false), + Poll::Ready(Err(err)) => { + return match err.kind() { + io::ErrorKind::WouldBlock => Ok(false), + io::ErrorKind::ConnectionReset if read_some => Ok(true), + _ => Err(DispatchError::Io(err)), + } + } + } + } + } + + /// call upgrade service with request. + fn upgrade(self: Pin<&mut Self>, req: Request) -> U::Future { + let this = self.project(); + let mut parts = FramedParts::with_read_buf( + this.io.take().unwrap(), + mem::take(this.codec), + mem::take(this.read_buf), + ); + parts.write_buf = mem::take(this.write_buf); + let framed = Framed::from_parts(parts); + this.flow.upgrade.as_ref().unwrap().call((req, framed)) + } } impl Future for Dispatcher where T: AsyncRead + AsyncWrite + Unpin, - S: Service, + S: Service, S::Error: Into, S::Response: Into>, B: MessageBody, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { type Output = Result<(), DispatchError>; @@ -764,74 +853,43 @@ where if inner.flags.contains(Flags::WRITE_DISCONNECT) { Poll::Ready(Ok(())) } else { - // flush buffer - inner.as_mut().poll_flush(cx)?; - if !inner.write_buf.is_empty() || inner.io.is_none() { - Poll::Pending - } else { - match Pin::new(inner.project().io) - .as_pin_mut() - .unwrap() - .poll_shutdown(cx) - { - Poll::Ready(res) => { - Poll::Ready(res.map_err(DispatchError::from)) - } - Poll::Pending => Poll::Pending, - } - } + // flush buffer and wait on blocked. + ready!(inner.as_mut().poll_flush(cx))?; + Pin::new(inner.project().io.as_mut().unwrap()) + .poll_shutdown(cx) + .map_err(DispatchError::from) } } else { - // read socket into a buf - let should_disconnect = - if !inner.flags.contains(Flags::READ_DISCONNECT) { - let mut inner_p = inner.as_mut().project(); - read_available( - cx, - inner_p.io.as_mut().unwrap(), - &mut inner_p.read_buf, - )? - } else { - None - }; + // read from io stream and fill read buffer. + let should_disconnect = inner.as_mut().read_available(cx)?; inner.as_mut().poll_request(cx)?; - if let Some(true) = should_disconnect { - let inner_p = inner.as_mut().project(); - inner_p.flags.insert(Flags::READ_DISCONNECT); - if let Some(mut payload) = inner_p.payload.take() { + + // io stream should to be closed. + if should_disconnect { + let inner = inner.as_mut().project(); + inner.flags.insert(Flags::READ_DISCONNECT); + if let Some(mut payload) = inner.payload.take() { payload.feed_eof(); } }; loop { - let inner_p = inner.as_mut().project(); - let remaining = - inner_p.write_buf.capacity() - inner_p.write_buf.len(); - if remaining < LW_BUFFER_SIZE { - inner_p.write_buf.reserve(HW_BUFFER_SIZE - remaining); - } - let result = inner.as_mut().poll_response(cx)?; - let drain = result == PollResponse::DrainWriteBuf; - - // switch to upgrade handler - if let PollResponse::Upgrade(req) = result { - let inner_p = inner.as_mut().project(); - let mut parts = FramedParts::with_read_buf( - inner_p.io.take().unwrap(), - mem::take(inner_p.codec), - mem::take(inner_p.read_buf), - ); - parts.write_buf = mem::take(inner_p.write_buf); - let framed = Framed::from_parts(parts); - let upgrade = - inner_p.upgrade.take().unwrap().call((req, framed)); - self.as_mut() - .project() - .inner - .set(DispatcherState::Upgrade(Box::pin(upgrade))); - return self.poll(cx); - } + // poll_response and populate write buffer. + // drain indicate if write buffer should be emptied before next run. + let drain = match inner.as_mut().poll_response(cx)? { + PollResponse::DrainWriteBuf => true, + PollResponse::DoNothing => false, + // upgrade request and goes Upgrade variant of DispatcherState. + PollResponse::Upgrade(req) => { + let upgrade = inner.upgrade(req); + self.as_mut() + .project() + .inner + .set(DispatcherState::Upgrade(upgrade)); + return self.poll(cx); + } + }; // we didn't get WouldBlock from write operation, // so data get written to kernel completely (macOS) @@ -839,7 +897,7 @@ where // // TODO: what? is WouldBlock good or bad? // want to find a reference for this macOS behavior - if inner.as_mut().poll_flush(cx)? || !drain { + if inner.as_mut().poll_flush(cx)?.is_pending() || !drain { break; } } @@ -880,7 +938,7 @@ where } } } - DispatcherStateProj::Upgrade(fut) => fut.as_mut().poll(cx).map_err(|e| { + DispatcherStateProj::Upgrade(fut) => fut.poll(cx).map_err(|e| { error!("Upgrade handler error: {}", e); DispatchError::Upgrade }), @@ -888,75 +946,9 @@ where } } -/// Returns either: -/// - `Ok(Some(true))` - data was read and done reading all data. -/// - `Ok(Some(false))` - data was read but there should be more to read. -/// - `Ok(None)` - no data was read but there should be more to read later. -/// - Unhandled Errors -fn read_available( - cx: &mut Context<'_>, - io: &mut T, - buf: &mut BytesMut, -) -> Result, io::Error> -where - T: AsyncRead + Unpin, -{ - let mut read_some = false; - - loop { - // If buf is full return but do not disconnect since - // there is more reading to be done - if buf.len() >= HW_BUFFER_SIZE { - return Ok(Some(false)); - } - - let remaining = buf.capacity() - buf.len(); - if remaining < LW_BUFFER_SIZE { - buf.reserve(HW_BUFFER_SIZE - remaining); - } - - match read(cx, io, buf) { - Poll::Pending => { - return if read_some { Ok(Some(false)) } else { Ok(None) }; - } - Poll::Ready(Ok(n)) => { - if n == 0 { - return Ok(Some(true)); - } else { - read_some = true; - } - } - Poll::Ready(Err(err)) => { - return if err.kind() == io::ErrorKind::WouldBlock { - if read_some { - Ok(Some(false)) - } else { - Ok(None) - } - } else if err.kind() == io::ErrorKind::ConnectionReset && read_some { - Ok(Some(true)) - } else { - Err(err) - } - } - } - } -} - -fn read( - cx: &mut Context<'_>, - io: &mut T, - buf: &mut BytesMut, -) -> Poll> -where - T: AsyncRead + Unpin, -{ - Pin::new(io).poll_read_buf(cx, buf) -} - #[cfg(test)] mod tests { - use std::{marker::PhantomData, str}; + use std::str; use actix_service::fn_service; use futures_util::future::{lazy, ready}; @@ -985,21 +977,19 @@ mod tests { } } - fn ok_service() -> impl Service - { + fn ok_service() -> impl Service { fn_service(|_req: Request| ready(Ok::<_, Error>(Response::Ok().finish()))) } - fn echo_path_service( - ) -> impl Service { + fn echo_path_service() -> impl Service { fn_service(|req: Request| { let path = req.path().as_bytes(); ready(Ok::<_, Error>(Response::Ok().body(Body::from_slice(path)))) }) } - fn echo_payload_service( - ) -> impl Service { + fn echo_payload_service() -> impl Service + { fn_service(|mut req: Request| { Box::pin(async move { use futures_util::stream::StreamExt as _; @@ -1007,7 +997,7 @@ mod tests { let mut pl = req.take_payload(); let mut body = BytesMut::new(); while let Some(chunk) = pl.next().await { - body.extend_from_slice(chunk.unwrap().bytes()) + body.extend_from_slice(chunk.unwrap().chunk()) } Ok::<_, Error>(Response::Ok().body(body)) @@ -1020,17 +1010,17 @@ mod tests { lazy(|cx| { let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n"); - let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + let services = HttpFlow::new(ok_service(), ExpectHandler, None); + + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, ServiceConfig::default(), - CloneableService::new(ok_service()), - CloneableService::new(ExpectHandler), - None, - Extensions::new(), + services, + OnConnectData::default(), None, ); - futures_util::pin_mut!(h1); + actix_rt::pin!(h1); match h1.as_mut().poll(cx) { Poll::Pending => panic!(), @@ -1060,17 +1050,17 @@ mod tests { let cfg = ServiceConfig::new(KeepAlive::Disabled, 1, 1, false, None); - let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + let services = HttpFlow::new(echo_path_service(), ExpectHandler, None); + + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, cfg, - CloneableService::new(echo_path_service()), - CloneableService::new(ExpectHandler), - None, - Extensions::new(), + services, + OnConnectData::default(), None, ); - futures_util::pin_mut!(h1); + actix_rt::pin!(h1); assert!(matches!(&h1.inner, DispatcherState::Normal(_))); @@ -1114,17 +1104,17 @@ mod tests { let cfg = ServiceConfig::new(KeepAlive::Disabled, 1, 1, false, None); - let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + let services = HttpFlow::new(echo_path_service(), ExpectHandler, None); + + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, cfg, - CloneableService::new(echo_path_service()), - CloneableService::new(ExpectHandler), - None, - Extensions::new(), + services, + OnConnectData::default(), None, ); - futures_util::pin_mut!(h1); + actix_rt::pin!(h1); assert!(matches!(&h1.inner, DispatcherState::Normal(_))); @@ -1163,13 +1153,14 @@ mod tests { lazy(|cx| { let mut buf = TestSeqBuffer::empty(); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); - let h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new( + + let services = HttpFlow::new(echo_payload_service(), ExpectHandler, None); + + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf.clone(), cfg, - CloneableService::new(echo_payload_service()), - CloneableService::new(ExpectHandler), - None, - Extensions::new(), + services, + OnConnectData::default(), None, ); @@ -1182,7 +1173,7 @@ mod tests { ", ); - futures_util::pin_mut!(h1); + actix_rt::pin!(h1); assert!(h1.as_mut().poll(cx).is_pending()); assert!(matches!(&h1.inner, DispatcherState::Normal(_))); @@ -1234,13 +1225,14 @@ mod tests { lazy(|cx| { let mut buf = TestSeqBuffer::empty(); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); - let h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new( + + let services = HttpFlow::new(echo_path_service(), ExpectHandler, None); + + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf.clone(), cfg, - CloneableService::new(echo_path_service()), - CloneableService::new(ExpectHandler), - None, - Extensions::new(), + services, + OnConnectData::default(), None, ); @@ -1253,7 +1245,7 @@ mod tests { ", ); - futures_util::pin_mut!(h1); + actix_rt::pin!(h1); assert!(h1.as_mut().poll(cx).is_ready()); assert!(matches!(&h1.inner, DispatcherState::Normal(_))); @@ -1293,13 +1285,15 @@ mod tests { lazy(|cx| { let mut buf = TestSeqBuffer::empty(); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); - let h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new( + + let services = + HttpFlow::new(ok_service(), ExpectHandler, Some(UpgradeHandler)); + + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf.clone(), cfg, - CloneableService::new(ok_service()), - CloneableService::new(ExpectHandler), - Some(CloneableService::new(UpgradeHandler(PhantomData))), - Extensions::new(), + services, + OnConnectData::default(), None, ); @@ -1312,7 +1306,7 @@ mod tests { ", ); - futures_util::pin_mut!(h1); + actix_rt::pin!(h1); assert!(h1.as_mut().poll(cx).is_ready()); assert!(matches!(&h1.inner, DispatcherState::Upgrade(_))); diff --git a/actix-http/src/h1/encoder.rs b/actix-http/src/h1/encoder.rs index b7648eff1..97916e7db 100644 --- a/actix-http/src/h1/encoder.rs +++ b/actix-http/src/h1/encoder.rs @@ -8,7 +8,7 @@ use bytes::{BufMut, BytesMut}; use crate::body::BodySize; use crate::config::ServiceConfig; -use crate::header::map; +use crate::header::{map::Value, HeaderName}; use crate::helpers; use crate::http::header::{CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use crate::http::{HeaderMap, StatusCode, Version}; @@ -21,7 +21,7 @@ const AVERAGE_HEADER_SIZE: usize = 30; pub(crate) struct MessageEncoder { pub length: BodySize, pub te: TransferEncoding, - _t: PhantomData, + _phantom: PhantomData, } impl Default for MessageEncoder { @@ -29,7 +29,7 @@ impl Default for MessageEncoder { MessageEncoder { length: BodySize::None, te: TransferEncoding::empty(), - _t: PhantomData, + _phantom: PhantomData, } } } @@ -118,24 +118,14 @@ pub(crate) trait MessageType: Sized { dst.put_slice(b"connection: close\r\n") } } - _ => (), + _ => {} } - // merging headers from head and extra headers. HeaderMap::new() does not allocate. - let empty_headers = HeaderMap::new(); - let extra_headers = self.extra_headers().unwrap_or(&empty_headers); - let headers = self - .headers() - .inner - .iter() - .filter(|(name, _)| !extra_headers.contains_key(*name)) - .chain(extra_headers.inner.iter()); - // write headers let mut has_date = false; - let mut buf = dst.bytes_mut().as_mut_ptr() as *mut u8; + let mut buf = dst.chunk_mut().as_mut_ptr(); let mut remaining = dst.capacity() - dst.len(); // tracks bytes written since last buffer resize @@ -143,117 +133,67 @@ pub(crate) trait MessageType: Sized { // container's knowledge, this is used to sync the containers cursor after data is written let mut pos = 0; - for (key, value) in headers { + self.write_headers(|key, value| { match *key { - CONNECTION => continue, - TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => continue, + CONNECTION => return, + TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => return, DATE => has_date = true, - _ => (), + _ => {} } let k = key.as_str().as_bytes(); let k_len = k.len(); - match value { - map::Value::One(ref val) => { - let v = val.as_ref(); - let v_len = v.len(); + // TODO: drain? + for val in value.iter() { + let v = val.as_ref(); + let v_len = v.len(); - // key length + value length + colon + space + \r\n - let len = k_len + v_len + 4; + // key length + value length + colon + space + \r\n + let len = k_len + v_len + 4; - if len > remaining { - // not enough room in buffer for this header; reserve more space - - // SAFETY: all the bytes written up to position "pos" are initialized - // the written byte count and pointer advancement are kept in sync - unsafe { - dst.advance_mut(pos); - } - - pos = 0; - dst.reserve(len * 2); - remaining = dst.capacity() - dst.len(); - - // re-assign buf raw pointer since it's possible that the buffer was - // reallocated and/or resized - buf = dst.bytes_mut().as_mut_ptr() as *mut u8; - } - - // SAFETY: on each write, it is enough to ensure that the advancement of the - // cursor matches the number of bytes written + if len > remaining { + // SAFETY: all the bytes written up to position "pos" are initialized + // the written byte count and pointer advancement are kept in sync unsafe { - // use upper Camel-Case - if camel_case { - write_camel_case(k, from_raw_parts_mut(buf, k_len)) - } else { - write_data(k, buf, k_len) - } - - buf = buf.add(k_len); - - write_data(b": ", buf, 2); - buf = buf.add(2); - - write_data(v, buf, v_len); - buf = buf.add(v_len); - - write_data(b"\r\n", buf, 2); - buf = buf.add(2); + dst.advance_mut(pos); } - pos += len; - remaining -= len; + pos = 0; + dst.reserve(len * 2); + remaining = dst.capacity() - dst.len(); + + // re-assign buf raw pointer since it's possible that the buffer was + // reallocated and/or resized + buf = dst.chunk_mut().as_mut_ptr(); } - map::Value::Multi(ref vec) => { - for val in vec { - let v = val.as_ref(); - let v_len = v.len(); - let len = k_len + v_len + 4; - - if len > remaining { - // SAFETY: all the bytes written up to position "pos" are initialized - // the written byte count and pointer advancement are kept in sync - unsafe { - dst.advance_mut(pos); - } - pos = 0; - dst.reserve(len * 2); - remaining = dst.capacity() - dst.len(); - - // re-assign buf raw pointer since it's possible that the buffer was - // reallocated and/or resized - buf = dst.bytes_mut().as_mut_ptr() as *mut u8; - } - - // SAFETY: on each write, it is enough to ensure that the advancement of - // the cursor matches the number of bytes written - unsafe { - if camel_case { - write_camel_case(k, from_raw_parts_mut(buf, k_len)); - } else { - write_data(k, buf, k_len); - } - - buf = buf.add(k_len); - - write_data(b": ", buf, 2); - buf = buf.add(2); - - write_data(v, buf, v_len); - buf = buf.add(v_len); - - write_data(b"\r\n", buf, 2); - buf = buf.add(2); - }; - - pos += len; - remaining -= len; + // SAFETY: on each write, it is enough to ensure that the advancement of + // the cursor matches the number of bytes written + unsafe { + if camel_case { + // use Camel-Case headers + write_camel_case(k, from_raw_parts_mut(buf, k_len)); + } else { + write_data(k, buf, k_len); } - } + + buf = buf.add(k_len); + + write_data(b": ", buf, 2); + buf = buf.add(2); + + write_data(v, buf, v_len); + buf = buf.add(v_len); + + write_data(b"\r\n", buf, 2); + buf = buf.add(2); + }; + + pos += len; + remaining -= len; } - } + }); // final cursor synchronization with the bytes container // @@ -273,6 +213,24 @@ pub(crate) trait MessageType: Sized { Ok(()) } + + fn write_headers(&mut self, mut f: F) + where + F: FnMut(&HeaderName, &Value), + { + match self.extra_headers() { + Some(headers) => { + // merging headers from head and extra headers. + self.headers() + .inner + .iter() + .filter(|(name, _)| !headers.contains_key(*name)) + .chain(headers.inner.iter()) + .for_each(|(k, v)| f(k, v)) + } + None => self.headers().inner.iter().for_each(|(k, v)| f(k, v)), + } + } } impl MessageType for Response<()> { @@ -329,7 +287,7 @@ impl MessageType for RequestHeadType { let head = self.as_ref(); dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE); write!( - Writer(dst), + helpers::Writer(dst), "{} {} {}", head.method, head.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), @@ -462,7 +420,7 @@ impl TransferEncoding { *eof = true; buf.extend_from_slice(b"0\r\n\r\n"); } else { - writeln!(Writer(buf), "{:X}\r", msg.len()) + writeln!(helpers::Writer(buf), "{:X}\r", msg.len()) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; buf.reserve(msg.len() + 2); @@ -512,18 +470,6 @@ impl TransferEncoding { } } -struct Writer<'a>(pub &'a mut BytesMut); - -impl<'a> io::Write for Writer<'a> { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.extend_from_slice(buf); - Ok(buf.len()) - } - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - /// # Safety /// Callers must ensure that the given length matches given value length. unsafe fn write_data(value: &[u8], buf: *mut u8, len: usize) { @@ -532,30 +478,29 @@ unsafe fn write_data(value: &[u8], buf: *mut u8, len: usize) { } fn write_camel_case(value: &[u8], buffer: &mut [u8]) { - let mut index = 0; - let key = value; - let mut key_iter = key.iter(); + // first copy entire (potentially wrong) slice to output + buffer[..value.len()].copy_from_slice(value); - if let Some(c) = key_iter.next() { - if *c >= b'a' && *c <= b'z' { - buffer[index] = *c ^ b' '; - index += 1; - } - } else { - return; + let mut iter = value.iter(); + + // first character should be uppercase + if let Some(c @ b'a'..=b'z') = iter.next() { + buffer[0] = c & 0b1101_1111; } - while let Some(c) = key_iter.next() { - buffer[index] = *c; - index += 1; - if *c == b'-' { - if let Some(c) = key_iter.next() { - if *c >= b'a' && *c <= b'z' { - buffer[index] = *c ^ b' '; - index += 1; - } + // track 1 ahead of the current position since that's the location being assigned to + let mut index = 2; + + // remaining characters after hyphens should also be uppercase + while let Some(&c) = iter.next() { + if c == b'-' { + // advance iter by one and uppercase if needed + if let Some(c @ b'a'..=b'z') = iter.next() { + buffer[index] = c & 0b1101_1111; } } + + index += 1; } } @@ -584,8 +529,8 @@ mod tests { ); } - #[test] - fn test_camel_case() { + #[actix_rt::test] + async fn test_camel_case() { let mut bytes = BytesMut::with_capacity(2048); let mut head = RequestHead::default(); head.set_camel_case_headers(true); @@ -604,6 +549,7 @@ mod tests { ); let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + assert!(data.contains("Content-Length: 0\r\n")); assert!(data.contains("Connection: close\r\n")); assert!(data.contains("Content-Type: plain/text\r\n")); @@ -646,8 +592,8 @@ mod tests { assert!(data.contains("date: date\r\n")); } - #[test] - fn test_extra_headers() { + #[actix_rt::test] + async fn test_extra_headers() { let mut bytes = BytesMut::with_capacity(2048); let mut head = RequestHead::default(); @@ -680,8 +626,8 @@ mod tests { assert!(data.contains("date: date\r\n")); } - #[test] - fn test_no_content_length() { + #[actix_rt::test] + async fn test_no_content_length() { let mut bytes = BytesMut::with_capacity(2048); let mut res: Response<()> = diff --git a/actix-http/src/h1/expect.rs b/actix-http/src/h1/expect.rs index b89c7ff74..65856edf6 100644 --- a/actix-http/src/h1/expect.rs +++ b/actix-http/src/h1/expect.rs @@ -1,4 +1,4 @@ -use std::task::{Context, Poll}; +use std::task::Poll; use actix_service::{Service, ServiceFactory}; use futures_util::future::{ready, Ready}; @@ -8,11 +8,10 @@ use crate::request::Request; pub struct ExpectHandler; -impl ServiceFactory for ExpectHandler { - type Config = (); - type Request = Request; +impl ServiceFactory for ExpectHandler { type Response = Request; type Error = Error; + type Config = (); type Service = ExpectHandler; type InitError = Error; type Future = Ready>; @@ -22,17 +21,14 @@ impl ServiceFactory for ExpectHandler { } } -impl Service for ExpectHandler { - type Request = Request; +impl Service for ExpectHandler { type Response = Request; type Error = Error; type Future = Ready>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } + actix_service::always_ready!(); - fn call(&mut self, req: Request) -> Self::Future { + fn call(&self, req: Request) -> Self::Future { ready(Ok(req)) // TODO: add some way to trigger error // Err(error::ErrorExpectationFailed("test")) diff --git a/actix-http/src/h1/mod.rs b/actix-http/src/h1/mod.rs index 3d5dea5d6..7e6df6ceb 100644 --- a/actix-http/src/h1/mod.rs +++ b/actix-http/src/h1/mod.rs @@ -1,4 +1,4 @@ -//! HTTP/1 implementation +//! HTTP/1 protocol implementation. use bytes::{Bytes, BytesMut}; mod client; diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index 919a5d932..b79453ebd 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -12,37 +12,37 @@ use futures_core::ready; use futures_util::future::ready; use crate::body::MessageBody; -use crate::cloneable::CloneableService; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error}; use crate::request::Request; use crate::response::Response; -use crate::{ConnectCallback, Extensions}; +use crate::service::HttpFlow; +use crate::{ConnectCallback, OnConnectData}; use super::codec::Codec; use super::dispatcher::Dispatcher; use super::{ExpectHandler, UpgradeHandler}; /// `ServiceFactory` implementation for HTTP1 transport -pub struct H1Service> { +pub struct H1Service { srv: S, cfg: ServiceConfig, expect: X, upgrade: Option, on_connect_ext: Option>>, - _t: PhantomData<(T, B)>, + _phantom: PhantomData, } impl H1Service where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, { /// Create new `HttpService` instance with config. - pub(crate) fn with_config>( + pub(crate) fn with_config>( cfg: ServiceConfig, service: F, ) -> Self { @@ -52,26 +52,22 @@ where expect: ExpectHandler, upgrade: None, on_connect_ext: None, - _t: PhantomData, + _phantom: PhantomData, } } } impl H1Service where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, - X: ServiceFactory, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: ServiceFactory< - Config = (), - Request = (Request, Framed), - Response = (), - >, + U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U::Error: fmt::Display + Into, U::InitError: fmt::Debug, { @@ -79,8 +75,8 @@ where pub fn tcp( self, ) -> impl ServiceFactory< + TcpStream, Config = (), - Request = TcpStream, Response = (), Error = DispatchError, InitError = (), @@ -97,22 +93,23 @@ where mod openssl { use super::*; - use actix_tls::openssl::{Acceptor, SslAcceptor, SslStream}; - use actix_tls::{openssl::HandshakeError, TlsError}; + use actix_service::ServiceFactoryExt; + use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; + use actix_tls::accept::TlsError; impl H1Service, S, B, X, U> where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, - X: ServiceFactory, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, U: ServiceFactory< + (Request, Framed, Codec>), Config = (), - Request = (Request, Framed, Codec>), Response = (), >, U::Error: fmt::Display + Into, @@ -123,10 +120,10 @@ mod openssl { self, acceptor: SslAcceptor, ) -> impl ServiceFactory< + TcpStream, Config = (), - Request = TcpStream, Response = (), - Error = TlsError, DispatchError>, + Error = TlsError, InitError = (), > { pipeline_factory( @@ -146,23 +143,24 @@ mod openssl { #[cfg(feature = "rustls")] mod rustls { use super::*; - use actix_tls::rustls::{Acceptor, ServerConfig, TlsStream}; - use actix_tls::TlsError; + use actix_service::ServiceFactoryExt; + use actix_tls::accept::rustls::{Acceptor, ServerConfig, TlsStream}; + use actix_tls::accept::TlsError; use std::{fmt, io}; impl H1Service, S, B, X, U> where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, - X: ServiceFactory, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, U: ServiceFactory< + (Request, Framed, Codec>), Config = (), - Request = (Request, Framed, Codec>), Response = (), >, U::Error: fmt::Display + Into, @@ -173,8 +171,8 @@ mod rustls { self, config: ServerConfig, ) -> impl ServiceFactory< + TcpStream, Config = (), - Request = TcpStream, Response = (), Error = TlsError, InitError = (), @@ -195,7 +193,7 @@ mod rustls { impl H1Service where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into, S::Response: Into>, S::InitError: fmt::Debug, @@ -203,7 +201,7 @@ where { pub fn expect(self, expect: X1) -> H1Service where - X1: ServiceFactory, + X1: ServiceFactory, X1::Error: Into, X1::InitError: fmt::Debug, { @@ -213,13 +211,13 @@ where srv: self.srv, upgrade: self.upgrade, on_connect_ext: self.on_connect_ext, - _t: PhantomData, + _phantom: PhantomData, } } pub fn upgrade(self, upgrade: Option) -> H1Service where - U1: ServiceFactory), Response = ()>, + U1: ServiceFactory<(Request, Framed), Response = ()>, U1::Error: fmt::Display, U1::InitError: fmt::Debug, { @@ -229,7 +227,7 @@ where srv: self.srv, expect: self.expect, on_connect_ext: self.on_connect_ext, - _t: PhantomData, + _phantom: PhantomData, } } @@ -240,27 +238,27 @@ where } } -impl ServiceFactory for H1Service +impl ServiceFactory<(T, Option)> + for H1Service where T: AsyncRead + AsyncWrite + Unpin, - S: ServiceFactory, + S: ServiceFactory, S::Error: Into, S::Response: Into>, S::InitError: fmt::Debug, B: MessageBody, - X: ServiceFactory, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: ServiceFactory), Response = ()>, + U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U::Error: fmt::Display + Into, U::InitError: fmt::Debug, { - type Config = (); - type Request = (T, Option); type Response = (); type Error = DispatchError; - type InitError = (); + type Config = (); type Service = H1ServiceHandler; + type InitError = (); type Future = H1ServiceResponse; fn new_service(&self, _: ()) -> Self::Future { @@ -272,7 +270,7 @@ where upgrade: None, on_connect_ext: self.on_connect_ext.clone(), cfg: Some(self.cfg.clone()), - _t: PhantomData, + _phantom: PhantomData, } } } @@ -281,13 +279,13 @@ where #[pin_project::pin_project] pub struct H1ServiceResponse where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, - X: ServiceFactory, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: ServiceFactory), Response = ()>, + U: ServiceFactory<(Request, Framed), Response = ()>, U::Error: fmt::Display, U::InitError: fmt::Debug, { @@ -301,21 +299,21 @@ where upgrade: Option, on_connect_ext: Option>>, cfg: Option, - _t: PhantomData<(T, B)>, + _phantom: PhantomData, } impl Future for H1ServiceResponse where T: AsyncRead + AsyncWrite + Unpin, - S: ServiceFactory, + S: ServiceFactory, S::Error: Into, S::Response: Into>, S::InitError: fmt::Debug, B: MessageBody, - X: ServiceFactory, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - U: ServiceFactory), Response = ()>, + U: ServiceFactory<(Request, Framed), Response = ()>, U::Error: fmt::Display, U::InitError: fmt::Debug, { @@ -339,7 +337,7 @@ where .map_err(|e| log::error!("Init http service error: {:?}", e)))?; this = self.as_mut().project(); *this.upgrade = Some(upgrade); - this.fut_ex.set(None); + this.fut_upg.set(None); } let result = ready!(this @@ -362,63 +360,65 @@ where } /// `Service` implementation for HTTP/1 transport -pub struct H1ServiceHandler { - srv: CloneableService, - expect: CloneableService, - upgrade: Option>, +pub struct H1ServiceHandler +where + S: Service, + X: Service, + U: Service<(Request, Framed)>, +{ + flow: Rc>, on_connect_ext: Option>>, cfg: ServiceConfig, - _t: PhantomData<(T, B)>, + _phantom: PhantomData, } impl H1ServiceHandler where - S: Service, + S: Service, S::Error: Into, S::Response: Into>, B: MessageBody, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { fn new( cfg: ServiceConfig, - srv: S, + service: S, expect: X, upgrade: Option, on_connect_ext: Option>>, ) -> H1ServiceHandler { H1ServiceHandler { - srv: CloneableService::new(srv), - expect: CloneableService::new(expect), - upgrade: upgrade.map(CloneableService::new), + flow: HttpFlow::new(service, expect, upgrade), cfg, on_connect_ext, - _t: PhantomData, + _phantom: PhantomData, } } } -impl Service for H1ServiceHandler +impl Service<(T, Option)> + for H1ServiceHandler where T: AsyncRead + AsyncWrite + Unpin, - S: Service, + S: Service, S::Error: Into, S::Response: Into>, B: MessageBody, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display + Into, { - type Request = (T, Option); type Response = (); type Error = DispatchError; type Future = Dispatcher; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { let ready = self + .flow .expect .poll_ready(cx) .map_err(|e| { @@ -429,7 +429,8 @@ where .is_ready(); let ready = self - .srv + .flow + .service .poll_ready(cx) .map_err(|e| { let e = e.into(); @@ -439,7 +440,7 @@ where .is_ready() && ready; - let ready = if let Some(ref mut upg) = self.upgrade { + let ready = if let Some(ref upg) = self.flow.upgrade { upg.poll_ready(cx) .map_err(|e| { let e = e.into(); @@ -459,20 +460,15 @@ where } } - fn call(&mut self, (io, addr): Self::Request) -> Self::Future { - let mut connect_extensions = Extensions::new(); - if let Some(ref handler) = self.on_connect_ext { - // run on_connect_ext callback, populating connect extensions - handler(&io, &mut connect_extensions); - } + fn call(&self, (io, addr): (T, Option)) -> Self::Future { + let on_connect_data = + OnConnectData::from_io(&io, self.on_connect_ext.as_deref()); Dispatcher::new( io, self.cfg.clone(), - self.srv.clone(), - self.expect.clone(), - self.upgrade.clone(), - connect_extensions, + self.flow.clone(), + on_connect_data, addr, ) } diff --git a/actix-http/src/h1/upgrade.rs b/actix-http/src/h1/upgrade.rs index 8615f27a8..5e24d84e3 100644 --- a/actix-http/src/h1/upgrade.rs +++ b/actix-http/src/h1/upgrade.rs @@ -1,5 +1,4 @@ -use std::marker::PhantomData; -use std::task::{Context, Poll}; +use std::task::Poll; use actix_codec::Framed; use actix_service::{Service, ServiceFactory}; @@ -9,14 +8,13 @@ use crate::error::Error; use crate::h1::Codec; use crate::request::Request; -pub struct UpgradeHandler(pub(crate) PhantomData); +pub struct UpgradeHandler; -impl ServiceFactory for UpgradeHandler { - type Config = (); - type Request = (Request, Framed); +impl ServiceFactory<(Request, Framed)> for UpgradeHandler { type Response = (); type Error = Error; - type Service = UpgradeHandler; + type Config = (); + type Service = UpgradeHandler; type InitError = Error; type Future = Ready>; @@ -25,17 +23,14 @@ impl ServiceFactory for UpgradeHandler { } } -impl Service for UpgradeHandler { - type Request = (Request, Framed); +impl Service<(Request, Framed)> for UpgradeHandler { type Response = (); type Error = Error; type Future = Ready>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } + actix_service::always_ready!(); - fn call(&mut self, _: Self::Request) -> Self::Future { + fn call(&self, _: (Request, Framed)) -> Self::Future { ready(Ok(())) } } diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs index 7a0be9492..5ccd2a9d1 100644 --- a/actix-http/src/h2/dispatcher.rs +++ b/actix-http/src/h2/dispatcher.rs @@ -1,63 +1,65 @@ -use std::convert::TryFrom; use std::future::Future; use std::marker::PhantomData; use std::net; use std::pin::Pin; +use std::rc::Rc; use std::task::{Context, Poll}; +use std::{cmp, convert::TryFrom}; use actix_codec::{AsyncRead, AsyncWrite}; -use actix_rt::time::{Delay, Instant}; +use actix_rt::time::{Instant, Sleep}; use actix_service::Service; use bytes::{Bytes, BytesMut}; +use futures_core::ready; use h2::server::{Connection, SendResponse}; use h2::SendStream; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use log::{error, trace}; use crate::body::{BodySize, MessageBody, ResponseBody}; -use crate::cloneable::CloneableService; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error}; -use crate::httpmessage::HttpMessage; use crate::message::ResponseHead; use crate::payload::Payload; use crate::request::Request; use crate::response::Response; -use crate::Extensions; +use crate::service::HttpFlow; +use crate::OnConnectData; const CHUNK_SIZE: usize = 16_384; -/// Dispatcher for HTTP/2 protocol +/// Dispatcher for HTTP/2 protocol. #[pin_project::pin_project] -pub struct Dispatcher, B: MessageBody> +pub struct Dispatcher where T: AsyncRead + AsyncWrite + Unpin, + S: Service, + B: MessageBody, { - service: CloneableService, + flow: Rc>, connection: Connection, - on_connect_data: Extensions, + on_connect_data: OnConnectData, config: ServiceConfig, peer_addr: Option, ka_expire: Instant, - ka_timer: Option, - _t: PhantomData, + ka_timer: Option, + _phantom: PhantomData, } -impl Dispatcher +impl Dispatcher where T: AsyncRead + AsyncWrite + Unpin, - S: Service, + S: Service, S::Error: Into, - // S::Future: 'static, S::Response: Into>, B: MessageBody, { pub(crate) fn new( - service: CloneableService, + flow: Rc>, connection: Connection, - on_connect_data: Extensions, + on_connect_data: OnConnectData, config: ServiceConfig, - timeout: Option, + timeout: Option, peer_addr: Option, ) -> Self { // let keepalive = config.keep_alive_enabled(); @@ -77,22 +79,22 @@ where }; Dispatcher { - service, + flow, config, peer_addr, connection, on_connect_data, ka_expire, ka_timer, - _t: PhantomData, + _phantom: PhantomData, } } } -impl Future for Dispatcher +impl Future for Dispatcher where T: AsyncRead + AsyncWrite + Unpin, - S: Service, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, @@ -105,10 +107,12 @@ where let this = self.get_mut(); loop { - match Pin::new(&mut this.connection).poll_accept(cx) { - Poll::Ready(None) => return Poll::Ready(Ok(())), - Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), - Poll::Ready(Some(Ok((req, res)))) => { + match ready!(Pin::new(&mut this.connection).poll_accept(cx)) { + None => return Poll::Ready(Ok(())), + + Some(Err(err)) => return Poll::Ready(Err(err.into())), + + Some(Ok((req, res))) => { // update keep-alive expire if this.ka_timer.is_some() { if let Some(expire) = this.config.keep_alive_expire() { @@ -117,11 +121,9 @@ where } let (parts, body) = req.into_parts(); - let mut req = Request::with_payload(Payload::< - crate::payload::PayloadStream, - >::H2( - crate::h2::Payload::new(body) - )); + let pl = crate::h2::Payload::new(body); + let pl = Payload::::H2(pl); + let mut req = Request::with_payload(pl); let head = &mut req.head_mut(); head.uri = parts.uri; @@ -131,24 +133,20 @@ where head.peer_addr = this.peer_addr; // merge on_connect_ext data into request extensions - req.extensions_mut().drain_from(&mut this.on_connect_data); + this.on_connect_data.merge_into(&mut req); - actix_rt::spawn(ServiceResponse::< - S::Future, - S::Response, - S::Error, - B, - > { + let svc = ServiceResponse:: { state: ServiceResponseState::ServiceCall( - this.service.call(req), + this.flow.service.call(req), Some(res), ), config: this.config.clone(), buffer: None, - _t: PhantomData, - }); + _phantom: PhantomData, + }; + + actix_rt::spawn(svc); } - Poll::Pending => return Poll::Pending, } } } @@ -160,7 +158,7 @@ struct ServiceResponse { state: ServiceResponseState, config: ServiceConfig, buffer: Option, - _t: PhantomData<(I, E)>, + _phantom: PhantomData<(I, E)>, } #[pin_project::pin_project(project = ServiceResponseStateProj)] @@ -197,8 +195,9 @@ where skip_len = true; *size = BodySize::Stream; } - _ => (), + _ => {} } + let _ = match size { BodySize::None | BodySize::Stream => None, BodySize::Empty => res @@ -213,11 +212,13 @@ where // copy headers for (key, value) in head.headers.iter() { match *key { - CONNECTION | TRANSFER_ENCODING => continue, // http2 specific + // omit HTTP/1 only headers + CONNECTION | TRANSFER_ENCODING => continue, CONTENT_LENGTH if skip_len => continue, DATE => has_date = true, - _ => (), + _ => {} } + res.headers_mut().append(key, value.clone()); } @@ -249,109 +250,117 @@ where let mut this = self.as_mut().project(); match this.state.project() { - ServiceResponseStateProj::ServiceCall(call, send) => match call.poll(cx) { - Poll::Ready(Ok(res)) => { - let (res, body) = res.into().replace_body(()); + ServiceResponseStateProj::ServiceCall(call, send) => { + match ready!(call.poll(cx)) { + Ok(res) => { + let (res, body) = res.into().replace_body(()); - let mut send = send.take().unwrap(); - let mut size = body.size(); - let h2_res = self.as_mut().prepare_response(res.head(), &mut size); - this = self.as_mut().project(); + let mut send = send.take().unwrap(); + let mut size = body.size(); + let h2_res = + self.as_mut().prepare_response(res.head(), &mut size); + this = self.as_mut().project(); - let stream = match send.send_response(h2_res, size.is_eof()) { - Err(e) => { - trace!("Error sending h2 response: {:?}", e); - return Poll::Ready(()); + let stream = match send.send_response(h2_res, size.is_eof()) { + Err(e) => { + trace!("Error sending HTTP/2 response: {:?}", e); + return Poll::Ready(()); + } + Ok(stream) => stream, + }; + + if size.is_eof() { + Poll::Ready(()) + } else { + this.state + .set(ServiceResponseState::SendPayload(stream, body)); + self.poll(cx) } - Ok(stream) => stream, - }; + } - if size.is_eof() { - Poll::Ready(()) - } else { - this.state - .set(ServiceResponseState::SendPayload(stream, body)); - self.poll(cx) + Err(e) => { + let res: Response = e.into().into(); + let (res, body) = res.replace_body(()); + + let mut send = send.take().unwrap(); + let mut size = body.size(); + let h2_res = + self.as_mut().prepare_response(res.head(), &mut size); + this = self.as_mut().project(); + + let stream = match send.send_response(h2_res, size.is_eof()) { + Err(e) => { + trace!("Error sending HTTP/2 response: {:?}", e); + return Poll::Ready(()); + } + Ok(stream) => stream, + }; + + if size.is_eof() { + Poll::Ready(()) + } else { + this.state.set(ServiceResponseState::SendPayload( + stream, + body.into_body(), + )); + self.poll(cx) + } } } - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - let res: Response = e.into().into(); - let (res, body) = res.replace_body(()); + } - let mut send = send.take().unwrap(); - let mut size = body.size(); - let h2_res = self.as_mut().prepare_response(res.head(), &mut size); - this = self.as_mut().project(); - - let stream = match send.send_response(h2_res, size.is_eof()) { - Err(e) => { - trace!("Error sending h2 response: {:?}", e); - return Poll::Ready(()); - } - Ok(stream) => stream, - }; - - if size.is_eof() { - Poll::Ready(()) - } else { - this.state.set(ServiceResponseState::SendPayload( - stream, - body.into_body(), - )); - self.poll(cx) - } - } - }, ServiceResponseStateProj::SendPayload(ref mut stream, ref mut body) => { loop { loop { - if let Some(ref mut buffer) = this.buffer { - match stream.poll_capacity(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(None) => return Poll::Ready(()), - Poll::Ready(Some(Ok(cap))) => { - let len = buffer.len(); - let bytes = buffer.split_to(std::cmp::min(cap, len)); + match this.buffer { + Some(ref mut buffer) => { + match ready!(stream.poll_capacity(cx)) { + None => return Poll::Ready(()), - if let Err(e) = stream.send_data(bytes, false) { + Some(Ok(cap)) => { + let len = buffer.len(); + let bytes = buffer.split_to(cmp::min(cap, len)); + + if let Err(e) = stream.send_data(bytes, false) { + warn!("{:?}", e); + return Poll::Ready(()); + } else if !buffer.is_empty() { + let cap = cmp::min(buffer.len(), CHUNK_SIZE); + stream.reserve_capacity(cap); + } else { + this.buffer.take(); + } + } + + Some(Err(e)) => { warn!("{:?}", e); return Poll::Ready(()); - } else if !buffer.is_empty() { - let cap = - std::cmp::min(buffer.len(), CHUNK_SIZE); - stream.reserve_capacity(cap); - } else { - this.buffer.take(); } } - Poll::Ready(Some(Err(e))) => { - warn!("{:?}", e); - return Poll::Ready(()); - } } - } else { - match body.as_mut().poll_next(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(None) => { + + None => match ready!(body.as_mut().poll_next(cx)) { + None => { if let Err(e) = stream.send_data(Bytes::new(), true) { warn!("{:?}", e); } return Poll::Ready(()); } - Poll::Ready(Some(Ok(chunk))) => { - stream.reserve_capacity(std::cmp::min( + + Some(Ok(chunk)) => { + stream.reserve_capacity(cmp::min( chunk.len(), CHUNK_SIZE, )); *this.buffer = Some(chunk); } - Poll::Ready(Some(Err(e))) => { + + Some(Err(e)) => { error!("Response payload stream error: {:?}", e); return Poll::Ready(()); } - } + }, } } } diff --git a/actix-http/src/h2/mod.rs b/actix-http/src/h2/mod.rs index b00969227..7eff44ac1 100644 --- a/actix-http/src/h2/mod.rs +++ b/actix-http/src/h2/mod.rs @@ -1,9 +1,12 @@ -//! HTTP/2 implementation -use std::pin::Pin; -use std::task::{Context, Poll}; +//! HTTP/2 protocol. + +use std::{ + pin::Pin, + task::{Context, Poll}, +}; use bytes::Bytes; -use futures_core::Stream; +use futures_core::{ready, Stream}; use h2::RecvStream; mod dispatcher; @@ -13,14 +16,14 @@ pub use self::dispatcher::Dispatcher; pub use self::service::H2Service; use crate::error::PayloadError; -/// H2 receive stream +/// HTTP/2 peer stream. pub struct Payload { - pl: RecvStream, + stream: RecvStream, } impl Payload { - pub(crate) fn new(pl: RecvStream) -> Self { - Self { pl } + pub(crate) fn new(stream: RecvStream) -> Self { + Self { stream } } } @@ -33,18 +36,17 @@ impl Stream for Payload { ) -> Poll> { let this = self.get_mut(); - match Pin::new(&mut this.pl).poll_data(cx) { - Poll::Ready(Some(Ok(chunk))) => { + match ready!(Pin::new(&mut this.stream).poll_data(cx)) { + Some(Ok(chunk)) => { let len = chunk.len(); - if let Err(err) = this.pl.flow_control().release_capacity(len) { - Poll::Ready(Some(Err(err.into()))) - } else { - Poll::Ready(Some(Ok(chunk))) + + match this.stream.flow_control().release_capacity(len) { + Ok(()) => Poll::Ready(Some(Ok(chunk))), + Err(err) => Poll::Ready(Some(Err(err.into()))), } } - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))), - Poll::Pending => Poll::Pending, - Poll::Ready(None) => Poll::Ready(None), + Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), + None => Poll::Ready(None), } } } diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index b1fb9a634..e00c8d968 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -17,33 +17,33 @@ use h2::server::{self, Handshake}; use log::error; use crate::body::MessageBody; -use crate::cloneable::CloneableService; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error}; use crate::request::Request; use crate::response::Response; -use crate::{ConnectCallback, Extensions}; +use crate::service::HttpFlow; +use crate::{ConnectCallback, OnConnectData}; use super::dispatcher::Dispatcher; -/// `ServiceFactory` implementation for HTTP2 transport +/// `ServiceFactory` implementation for HTTP/2 transport pub struct H2Service { srv: S, cfg: ServiceConfig, on_connect_ext: Option>>, - _t: PhantomData<(T, B)>, + _phantom: PhantomData<(T, B)>, } impl H2Service where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, { - /// Create new `HttpService` instance with config. - pub(crate) fn with_config>( + /// Create new `H2Service` instance with config. + pub(crate) fn with_config>( cfg: ServiceConfig, service: F, ) -> Self { @@ -51,7 +51,7 @@ where cfg, on_connect_ext: None, srv: service.into_factory(), - _t: PhantomData, + _phantom: PhantomData, } } @@ -64,18 +64,18 @@ where impl H2Service where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, { - /// Create simple tcp based service + /// Create plain TCP based service pub fn tcp( self, ) -> impl ServiceFactory< + TcpStream, Config = (), - Request = TcpStream, Response = (), Error = DispatchError, InitError = S::InitError, @@ -92,29 +92,29 @@ where #[cfg(feature = "openssl")] mod openssl { - use actix_service::{fn_factory, fn_service}; - use actix_tls::openssl::{Acceptor, SslAcceptor, SslStream}; - use actix_tls::{openssl::HandshakeError, TlsError}; + use actix_service::{fn_factory, fn_service, ServiceFactoryExt}; + use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; + use actix_tls::accept::TlsError; use super::*; impl H2Service, S, B> where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, { - /// Create ssl based service + /// Create OpenSSL based service pub fn openssl( self, acceptor: SslAcceptor, ) -> impl ServiceFactory< + TcpStream, Config = (), - Request = TcpStream, Response = (), - Error = TlsError, DispatchError>, + Error = TlsError, InitError = S::InitError, > { pipeline_factory( @@ -136,25 +136,26 @@ mod openssl { #[cfg(feature = "rustls")] mod rustls { use super::*; - use actix_tls::rustls::{Acceptor, ServerConfig, TlsStream}; - use actix_tls::TlsError; + use actix_service::ServiceFactoryExt; + use actix_tls::accept::rustls::{Acceptor, ServerConfig, TlsStream}; + use actix_tls::accept::TlsError; use std::io; impl H2Service, S, B> where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, { - /// Create openssl based service + /// Create Rustls based service pub fn rustls( self, mut config: ServerConfig, ) -> impl ServiceFactory< + TcpStream, Config = (), - Request = TcpStream, Response = (), Error = TlsError, InitError = S::InitError, @@ -178,21 +179,20 @@ mod rustls { } } -impl ServiceFactory for H2Service +impl ServiceFactory<(T, Option)> for H2Service where T: AsyncRead + AsyncWrite + Unpin, - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, { - type Config = (); - type Request = (T, Option); type Response = (); type Error = DispatchError; - type InitError = S::InitError; + type Config = (); type Service = H2ServiceHandler; + type InitError = S::InitError; type Future = H2ServiceResponse; fn new_service(&self, _: ()) -> Self::Future { @@ -200,28 +200,31 @@ where fut: self.srv.new_service(()), cfg: Some(self.cfg.clone()), on_connect_ext: self.on_connect_ext.clone(), - _t: PhantomData, + _phantom: PhantomData, } } } #[doc(hidden)] #[pin_project::pin_project] -pub struct H2ServiceResponse { +pub struct H2ServiceResponse +where + S: ServiceFactory, +{ #[pin] fut: S::Future, cfg: Option, on_connect_ext: Option>>, - _t: PhantomData<(T, B)>, + _phantom: PhantomData, } impl Future for H2ServiceResponse where T: AsyncRead + AsyncWrite + Unpin, - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, { type Output = Result, S::InitError>; @@ -229,28 +232,31 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_mut().project(); - Poll::Ready(ready!(this.fut.poll(cx)).map(|service| { + this.fut.poll(cx).map_ok(|service| { let this = self.as_mut().project(); H2ServiceHandler::new( this.cfg.take().unwrap(), this.on_connect_ext.clone(), service, ) - })) + }) } } -/// `Service` implementation for http/2 transport -pub struct H2ServiceHandler { - srv: CloneableService, +/// `Service` implementation for HTTP/2 transport +pub struct H2ServiceHandler +where + S: Service, +{ + flow: Rc>, cfg: ServiceConfig, on_connect_ext: Option>>, - _t: PhantomData<(T, B)>, + _phantom: PhantomData, } impl H2ServiceHandler where - S: Service, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, @@ -259,69 +265,65 @@ where fn new( cfg: ServiceConfig, on_connect_ext: Option>>, - srv: S, + service: S, ) -> H2ServiceHandler { H2ServiceHandler { + flow: HttpFlow::new(service, (), None), cfg, on_connect_ext, - srv: CloneableService::new(srv), - _t: PhantomData, + _phantom: PhantomData, } } } -impl Service for H2ServiceHandler +impl Service<(T, Option)> for H2ServiceHandler where T: AsyncRead + AsyncWrite + Unpin, - S: Service, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, { - type Request = (T, Option); type Response = (); type Error = DispatchError; type Future = H2ServiceHandlerResponse; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.srv.poll_ready(cx).map_err(|e| { + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.flow.service.poll_ready(cx).map_err(|e| { let e = e.into(); error!("Service readiness error: {:?}", e); DispatchError::Service(e) }) } - fn call(&mut self, (io, addr): Self::Request) -> Self::Future { - let mut connect_extensions = Extensions::new(); - if let Some(ref handler) = self.on_connect_ext { - // run on_connect_ext callback, populating connect extensions - handler(&io, &mut connect_extensions); - } + fn call(&self, (io, addr): (T, Option)) -> Self::Future { + let on_connect_data = + OnConnectData::from_io(&io, self.on_connect_ext.as_deref()); H2ServiceHandlerResponse { state: State::Handshake( - Some(self.srv.clone()), + Some(self.flow.clone()), Some(self.cfg.clone()), addr, - Some(connect_extensions), + on_connect_data, server::handshake(io), ), } } } -enum State, B: MessageBody> +enum State, B: MessageBody> where T: AsyncRead + AsyncWrite + Unpin, S::Future: 'static, { - Incoming(Dispatcher), + Incoming(Dispatcher), Handshake( - Option>, + Option>>, Option, Option, - Option, + OnConnectData, Handshake, ), } @@ -329,7 +331,7 @@ where pub struct H2ServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, - S: Service, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, @@ -341,7 +343,7 @@ where impl Future for H2ServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, - S: Service, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, @@ -358,23 +360,23 @@ where ref peer_addr, ref mut on_connect_data, ref mut handshake, - ) => match Pin::new(handshake).poll(cx) { - Poll::Ready(Ok(conn)) => { + ) => match ready!(Pin::new(handshake).poll(cx)) { + Ok(conn) => { + let on_connect_data = std::mem::take(on_connect_data); self.state = State::Incoming(Dispatcher::new( srv.take().unwrap(), conn, - on_connect_data.take().unwrap(), + on_connect_data, config.take().unwrap(), None, *peer_addr, )); self.poll(cx) } - Poll::Ready(Err(err)) => { + Err(err) => { trace!("H2 handshake error: {}", err); Poll::Ready(Err(err.into())) } - Poll::Pending => Poll::Pending, }, } } diff --git a/actix-http/src/header/as_name.rs b/actix-http/src/header/as_name.rs new file mode 100644 index 000000000..af81ff7f2 --- /dev/null +++ b/actix-http/src/header/as_name.rs @@ -0,0 +1,48 @@ +//! Helper trait for types that can be effectively borrowed as a [HeaderValue]. +//! +//! [HeaderValue]: crate::http::HeaderValue + +use std::{borrow::Cow, str::FromStr}; + +use http::header::{HeaderName, InvalidHeaderName}; + +pub trait AsHeaderName: Sealed {} + +pub trait Sealed { + fn try_as_name(&self) -> Result, InvalidHeaderName>; +} + +impl Sealed for HeaderName { + fn try_as_name(&self) -> Result, InvalidHeaderName> { + Ok(Cow::Borrowed(self)) + } +} +impl AsHeaderName for HeaderName {} + +impl Sealed for &HeaderName { + fn try_as_name(&self) -> Result, InvalidHeaderName> { + Ok(Cow::Borrowed(*self)) + } +} +impl AsHeaderName for &HeaderName {} + +impl Sealed for &str { + fn try_as_name(&self) -> Result, InvalidHeaderName> { + HeaderName::from_str(self).map(Cow::Owned) + } +} +impl AsHeaderName for &str {} + +impl Sealed for String { + fn try_as_name(&self) -> Result, InvalidHeaderName> { + HeaderName::from_str(self).map(Cow::Owned) + } +} +impl AsHeaderName for String {} + +impl Sealed for &String { + fn try_as_name(&self) -> Result, InvalidHeaderName> { + HeaderName::from_str(self).map(Cow::Owned) + } +} +impl AsHeaderName for &String {} diff --git a/actix-http/src/header/common/accept.rs b/actix-http/src/header/common/accept.rs index da26b0261..775da3394 100644 --- a/actix-http/src/header/common/accept.rs +++ b/actix-http/src/header/common/accept.rs @@ -32,50 +32,36 @@ header! { /// * `text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c` /// /// # Examples - /// ```rust - /// # extern crate actix_http; - /// extern crate mime; + /// ``` /// use actix_http::Response; /// use actix_http::http::header::{Accept, qitem}; /// - /// # fn main() { /// let mut builder = Response::Ok(); - /// - /// builder.set( + /// builder.insert_header( /// Accept(vec![ /// qitem(mime::TEXT_HTML), /// ]) /// ); - /// # } /// ``` /// - /// ```rust - /// # extern crate actix_http; - /// extern crate mime; + /// ``` /// use actix_http::Response; /// use actix_http::http::header::{Accept, qitem}; /// - /// # fn main() { /// let mut builder = Response::Ok(); - /// - /// builder.set( + /// builder.insert_header( /// Accept(vec![ /// qitem(mime::APPLICATION_JSON), /// ]) /// ); - /// # } /// ``` /// - /// ```rust - /// # extern crate actix_http; - /// extern crate mime; + /// ``` /// use actix_http::Response; /// use actix_http::http::header::{Accept, QualityItem, q, qitem}; /// - /// # fn main() { /// let mut builder = Response::Ok(); - /// - /// builder.set( + /// builder.insert_header( /// Accept(vec![ /// qitem(mime::TEXT_HTML), /// qitem("application/xhtml+xml".parse().unwrap()), @@ -90,7 +76,6 @@ header! { /// ), /// ]) /// ); - /// # } /// ``` (Accept, header::ACCEPT) => (QualityItem)+ @@ -132,7 +117,7 @@ header! { #[test] fn test_fuzzing1() { use crate::test::TestRequest; - let req = TestRequest::with_header(crate::header::ACCEPT, "chunk#;e").finish(); + let req = TestRequest::default().insert_header((crate::header::ACCEPT, "chunk#;e")).finish(); let header = Accept::parse(&req); assert!(header.is_ok()); } diff --git a/actix-http/src/header/common/accept_charset.rs b/actix-http/src/header/common/accept_charset.rs index 291ca53b6..db530a8bc 100644 --- a/actix-http/src/header/common/accept_charset.rs +++ b/actix-http/src/header/common/accept_charset.rs @@ -21,44 +21,37 @@ header! { /// * `iso-8859-5, unicode-1-1;q=0.8` /// /// # Examples - /// ```rust - /// # extern crate actix_http; + /// ``` /// use actix_http::Response; /// use actix_http::http::header::{AcceptCharset, Charset, qitem}; /// - /// # fn main() { /// let mut builder = Response::Ok(); - /// builder.set( + /// builder.insert_header( /// AcceptCharset(vec![qitem(Charset::Us_Ascii)]) /// ); - /// # } /// ``` - /// ```rust - /// # extern crate actix_http; + /// + /// ``` /// use actix_http::Response; /// use actix_http::http::header::{AcceptCharset, Charset, q, QualityItem}; /// - /// # fn main() { /// let mut builder = Response::Ok(); - /// builder.set( + /// builder.insert_header( /// AcceptCharset(vec![ /// QualityItem::new(Charset::Us_Ascii, q(900)), /// QualityItem::new(Charset::Iso_8859_10, q(200)), /// ]) /// ); - /// # } /// ``` - /// ```rust - /// # extern crate actix_http; + /// + /// ``` /// use actix_http::Response; /// use actix_http::http::header::{AcceptCharset, Charset, qitem}; /// - /// # fn main() { /// let mut builder = Response::Ok(); - /// builder.set( + /// builder.insert_header( /// AcceptCharset(vec![qitem(Charset::Ext("utf-8".to_owned()))]) /// ); - /// # } /// ``` (AcceptCharset, ACCEPT_CHARSET) => (QualityItem)+ diff --git a/actix-http/src/header/common/accept_language.rs b/actix-http/src/header/common/accept_language.rs index 55879b57f..a7ad00863 100644 --- a/actix-http/src/header/common/accept_language.rs +++ b/actix-http/src/header/common/accept_language.rs @@ -22,41 +22,35 @@ header! { /// /// # Examples /// - /// ```rust - /// # extern crate actix_http; - /// # extern crate language_tags; + /// ``` + /// use language_tags::langtag; /// use actix_http::Response; /// use actix_http::http::header::{AcceptLanguage, LanguageTag, qitem}; /// - /// # fn main() { /// let mut builder = Response::Ok(); /// let mut langtag: LanguageTag = Default::default(); /// langtag.language = Some("en".to_owned()); /// langtag.region = Some("US".to_owned()); - /// builder.set( + /// builder.insert_header( /// AcceptLanguage(vec![ /// qitem(langtag), /// ]) /// ); - /// # } /// ``` /// - /// ```rust - /// # extern crate actix_http; - /// # #[macro_use] extern crate language_tags; + /// ``` + /// use language_tags::langtag; /// use actix_http::Response; /// use actix_http::http::header::{AcceptLanguage, QualityItem, q, qitem}; - /// # - /// # fn main() { + /// /// let mut builder = Response::Ok(); - /// builder.set( + /// builder.insert_header( /// AcceptLanguage(vec![ /// qitem(langtag!(da)), /// QualityItem::new(langtag!(en;;;GB), q(800)), /// QualityItem::new(langtag!(en), q(700)), /// ]) /// ); - /// # } /// ``` (AcceptLanguage, ACCEPT_LANGUAGE) => (QualityItem)+ diff --git a/actix-http/src/header/common/allow.rs b/actix-http/src/header/common/allow.rs index 88c21763c..06b1efedc 100644 --- a/actix-http/src/header/common/allow.rs +++ b/actix-http/src/header/common/allow.rs @@ -22,38 +22,28 @@ header! { /// /// # Examples /// - /// ```rust - /// # extern crate http; - /// # extern crate actix_http; + /// ``` /// use actix_http::Response; - /// use actix_http::http::header::Allow; - /// use http::Method; + /// use actix_http::http::{header::Allow, Method}; /// - /// # fn main() { /// let mut builder = Response::Ok(); - /// builder.set( + /// builder.insert_header( /// Allow(vec![Method::GET]) /// ); - /// # } /// ``` /// - /// ```rust - /// # extern crate http; - /// # extern crate actix_http; + /// ``` /// use actix_http::Response; - /// use actix_http::http::header::Allow; - /// use http::Method; + /// use actix_http::http::{header::Allow, Method}; /// - /// # fn main() { /// let mut builder = Response::Ok(); - /// builder.set( + /// builder.insert_header( /// Allow(vec![ /// Method::GET, /// Method::POST, /// Method::PATCH, /// ]) /// ); - /// # } /// ``` (Allow, header::ALLOW) => (Method)* diff --git a/actix-http/src/header/common/cache_control.rs b/actix-http/src/header/common/cache_control.rs index ec94ce4a9..94ce9a750 100644 --- a/actix-http/src/header/common/cache_control.rs +++ b/actix-http/src/header/common/cache_control.rs @@ -28,12 +28,12 @@ use crate::header::{ /// * `max-age=30` /// /// # Examples -/// ```rust +/// ``` /// use actix_http::Response; /// use actix_http::http::header::{CacheControl, CacheDirective}; /// /// let mut builder = Response::Ok(); -/// builder.set(CacheControl(vec![CacheDirective::MaxAge(86400u32)])); +/// builder.insert_header(CacheControl(vec![CacheDirective::MaxAge(86400u32)])); /// ``` /// /// ```rust @@ -41,7 +41,7 @@ use crate::header::{ /// use actix_http::http::header::{CacheControl, CacheDirective}; /// /// let mut builder = Response::Ok(); -/// builder.set(CacheControl(vec![ +/// builder.insert_header(CacheControl(vec![ /// CacheDirective::NoCache, /// CacheDirective::Private, /// CacheDirective::MaxAge(360u32), @@ -82,7 +82,7 @@ impl fmt::Display for CacheControl { impl IntoHeaderValue for CacheControl { type Error = header::InvalidHeaderValue; - fn try_into(self) -> Result { + fn try_into_value(self) -> Result { let mut writer = Writer::new(); let _ = write!(&mut writer, "{}", self); header::HeaderValue::from_maybe_shared(writer.take()) @@ -196,7 +196,8 @@ mod tests { #[test] fn test_parse_multiple_headers() { - let req = TestRequest::with_header(header::CACHE_CONTROL, "no-cache, private") + let req = TestRequest::default() + .insert_header((header::CACHE_CONTROL, "no-cache, private")) .finish(); let cache = Header::parse(&req); assert_eq!( @@ -210,9 +211,9 @@ mod tests { #[test] fn test_parse_argument() { - let req = - TestRequest::with_header(header::CACHE_CONTROL, "max-age=100, private") - .finish(); + let req = TestRequest::default() + .insert_header((header::CACHE_CONTROL, "max-age=100, private")) + .finish(); let cache = Header::parse(&req); assert_eq!( cache.ok(), @@ -225,8 +226,9 @@ mod tests { #[test] fn test_parse_quote_form() { - let req = - TestRequest::with_header(header::CACHE_CONTROL, "max-age=\"200\"").finish(); + let req = TestRequest::default() + .insert_header((header::CACHE_CONTROL, "max-age=\"200\"")) + .finish(); let cache = Header::parse(&req); assert_eq!( cache.ok(), @@ -236,8 +238,9 @@ mod tests { #[test] fn test_parse_extension() { - let req = - TestRequest::with_header(header::CACHE_CONTROL, "foo, bar=baz").finish(); + let req = TestRequest::default() + .insert_header((header::CACHE_CONTROL, "foo, bar=baz")) + .finish(); let cache = Header::parse(&req); assert_eq!( cache.ok(), @@ -250,7 +253,9 @@ mod tests { #[test] fn test_parse_bad_syntax() { - let req = TestRequest::with_header(header::CACHE_CONTROL, "foo=").finish(); + let req = TestRequest::default() + .insert_header((header::CACHE_CONTROL, "foo=")) + .finish(); let cache: Result = Header::parse(&req); assert_eq!(cache.ok(), None) } diff --git a/actix-http/src/header/common/content_disposition.rs b/actix-http/src/header/common/content_disposition.rs index 4c512acbe..ecc59aba3 100644 --- a/actix-http/src/header/common/content_disposition.rs +++ b/actix-http/src/header/common/content_disposition.rs @@ -1,10 +1,10 @@ -// # References -// -// "The Content-Disposition Header Field" https://www.ietf.org/rfc/rfc2183.txt -// "The Content-Disposition Header Field in the Hypertext Transfer Protocol (HTTP)" https://www.ietf.org/rfc/rfc6266.txt -// "Returning Values from Forms: multipart/form-data" https://www.ietf.org/rfc/rfc7578.txt -// Browser conformance tests at: http://greenbytes.de/tech/tc2231/ -// IANA assignment: http://www.iana.org/assignments/cont-disp/cont-disp.xhtml +//! # References +//! +//! "The Content-Disposition Header Field" https://www.ietf.org/rfc/rfc2183.txt +//! "The Content-Disposition Header Field in the Hypertext Transfer Protocol (HTTP)" https://www.ietf.org/rfc/rfc6266.txt +//! "Returning Values from Forms: multipart/form-data" https://www.ietf.org/rfc/rfc7578.txt +//! Browser conformance tests at: http://greenbytes.de/tech/tc2231/ +//! IANA assignment: http://www.iana.org/assignments/cont-disp/cont-disp.xhtml use lazy_static::lazy_static; use regex::Regex; @@ -454,7 +454,7 @@ impl ContentDisposition { impl IntoHeaderValue for ContentDisposition { type Error = header::InvalidHeaderValue; - fn try_into(self) -> Result { + fn try_into_value(self) -> Result { let mut writer = Writer::new(); let _ = write!(&mut writer, "{}", self); header::HeaderValue::from_maybe_shared(writer.take()) diff --git a/actix-http/src/header/common/content_encoding.rs b/actix-http/src/header/common/content_encoding.rs new file mode 100644 index 000000000..b93d66101 --- /dev/null +++ b/actix-http/src/header/common/content_encoding.rs @@ -0,0 +1,106 @@ +use std::{convert::Infallible, str::FromStr}; + +use http::header::InvalidHeaderValue; + +use crate::{ + error::ParseError, + header::{self, from_one_raw_str, Header, HeaderName, HeaderValue, IntoHeaderValue}, + HttpMessage, +}; + +/// Represents a supported content encoding. +#[derive(Copy, Clone, PartialEq, Debug)] +pub enum ContentEncoding { + /// Automatically select encoding based on encoding negotiation. + Auto, + + /// A format using the Brotli algorithm. + Br, + + /// A format using the zlib structure with deflate algorithm. + Deflate, + + /// Gzip algorithm. + Gzip, + + /// Indicates the identity function (i.e. no compression, nor modification). + Identity, +} + +impl ContentEncoding { + /// Is the content compressed? + #[inline] + pub fn is_compression(self) -> bool { + matches!(self, ContentEncoding::Identity | ContentEncoding::Auto) + } + + /// Convert content encoding to string + #[inline] + pub fn as_str(self) -> &'static str { + match self { + ContentEncoding::Br => "br", + ContentEncoding::Gzip => "gzip", + ContentEncoding::Deflate => "deflate", + ContentEncoding::Identity | ContentEncoding::Auto => "identity", + } + } + + /// Default Q-factor (quality) value. + #[inline] + pub fn quality(self) -> f64 { + match self { + ContentEncoding::Br => 1.1, + ContentEncoding::Gzip => 1.0, + ContentEncoding::Deflate => 0.9, + ContentEncoding::Identity | ContentEncoding::Auto => 0.1, + } + } +} + +impl Default for ContentEncoding { + fn default() -> Self { + Self::Identity + } +} + +impl FromStr for ContentEncoding { + type Err = Infallible; + + fn from_str(val: &str) -> Result { + Ok(Self::from(val)) + } +} + +impl From<&str> for ContentEncoding { + fn from(val: &str) -> ContentEncoding { + let val = val.trim(); + + if val.eq_ignore_ascii_case("br") { + ContentEncoding::Br + } else if val.eq_ignore_ascii_case("gzip") { + ContentEncoding::Gzip + } else if val.eq_ignore_ascii_case("deflate") { + ContentEncoding::Deflate + } else { + ContentEncoding::default() + } + } +} + +impl IntoHeaderValue for ContentEncoding { + type Error = InvalidHeaderValue; + + fn try_into_value(self) -> Result { + Ok(HeaderValue::from_static(self.as_str())) + } +} + +impl Header for ContentEncoding { + fn name() -> HeaderName { + header::CONTENT_ENCODING + } + + fn parse(msg: &T) -> Result { + from_one_raw_str(msg.headers().get(Self::name())) + } +} diff --git a/actix-http/src/header/common/content_language.rs b/actix-http/src/header/common/content_language.rs index 838981a39..e9be67a1b 100644 --- a/actix-http/src/header/common/content_language.rs +++ b/actix-http/src/header/common/content_language.rs @@ -23,38 +23,31 @@ header! { /// /// # Examples /// - /// ```rust - /// # extern crate actix_http; - /// # #[macro_use] extern crate language_tags; + /// ``` + /// use language_tags::langtag; /// use actix_http::Response; - /// # use actix_http::http::header::{ContentLanguage, qitem}; - /// # - /// # fn main() { + /// use actix_http::http::header::{ContentLanguage, qitem}; + /// /// let mut builder = Response::Ok(); - /// builder.set( + /// builder.insert_header( /// ContentLanguage(vec![ /// qitem(langtag!(en)), /// ]) /// ); - /// # } /// ``` /// - /// ```rust - /// # extern crate actix_http; - /// # #[macro_use] extern crate language_tags; + /// ``` + /// use language_tags::langtag; /// use actix_http::Response; - /// # use actix_http::http::header::{ContentLanguage, qitem}; - /// # - /// # fn main() { + /// use actix_http::http::header::{ContentLanguage, qitem}; /// /// let mut builder = Response::Ok(); - /// builder.set( + /// builder.insert_header( /// ContentLanguage(vec![ /// qitem(langtag!(da)), /// qitem(langtag!(en;;;GB)), /// ]) /// ); - /// # } /// ``` (ContentLanguage, CONTENT_LANGUAGE) => (QualityItem)+ diff --git a/actix-http/src/header/common/content_range.rs b/actix-http/src/header/common/content_range.rs index 9a604c641..8b7552377 100644 --- a/actix-http/src/header/common/content_range.rs +++ b/actix-http/src/header/common/content_range.rs @@ -200,7 +200,7 @@ impl Display for ContentRangeSpec { impl IntoHeaderValue for ContentRangeSpec { type Error = InvalidHeaderValue; - fn try_into(self) -> Result { + fn try_into_value(self) -> Result { let mut writer = Writer::new(); let _ = write!(&mut writer, "{}", self); HeaderValue::from_maybe_shared(writer.take()) diff --git a/actix-http/src/header/common/content_type.rs b/actix-http/src/header/common/content_type.rs index a0baa5637..ac5c7e5b8 100644 --- a/actix-http/src/header/common/content_type.rs +++ b/actix-http/src/header/common/content_type.rs @@ -30,31 +30,24 @@ header! { /// /// # Examples /// - /// ```rust + /// ``` /// use actix_http::Response; /// use actix_http::http::header::ContentType; /// - /// # fn main() { /// let mut builder = Response::Ok(); - /// builder.set( + /// builder.insert_header( /// ContentType::json() /// ); - /// # } /// ``` /// - /// ```rust - /// # extern crate mime; - /// # extern crate actix_http; - /// use mime::TEXT_HTML; + /// ``` /// use actix_http::Response; /// use actix_http::http::header::ContentType; /// - /// # fn main() { /// let mut builder = Response::Ok(); - /// builder.set( - /// ContentType(TEXT_HTML) + /// builder.insert_header( + /// ContentType(mime::TEXT_HTML) /// ); - /// # } /// ``` (ContentType, CONTENT_TYPE) => [Mime] @@ -99,6 +92,7 @@ impl ContentType { pub fn form_url_encoded() -> ContentType { ContentType(mime::APPLICATION_WWW_FORM_URLENCODED) } + /// A constructor to easily create a `Content-Type: image/jpeg` header. #[inline] pub fn jpeg() -> ContentType { diff --git a/actix-http/src/header/common/date.rs b/actix-http/src/header/common/date.rs index 784100e8d..e5ace95e6 100644 --- a/actix-http/src/header/common/date.rs +++ b/actix-http/src/header/common/date.rs @@ -19,13 +19,15 @@ header! { /// /// # Example /// - /// ```rust + /// ``` + /// use std::time::SystemTime; /// use actix_http::Response; /// use actix_http::http::header::Date; - /// use std::time::SystemTime; /// /// let mut builder = Response::Ok(); - /// builder.set(Date(SystemTime::now().into())); + /// builder.insert_header( + /// Date(SystemTime::now().into()) + /// ); /// ``` (Date, DATE) => [HttpDate] diff --git a/actix-http/src/header/common/etag.rs b/actix-http/src/header/common/etag.rs index 325b91cbf..4c1e8d262 100644 --- a/actix-http/src/header/common/etag.rs +++ b/actix-http/src/header/common/etag.rs @@ -27,20 +27,24 @@ header! { /// /// # Examples /// - /// ```rust + /// ``` /// use actix_http::Response; /// use actix_http::http::header::{ETag, EntityTag}; /// /// let mut builder = Response::Ok(); - /// builder.set(ETag(EntityTag::new(false, "xyzzy".to_owned()))); + /// builder.insert_header( + /// ETag(EntityTag::new(false, "xyzzy".to_owned())) + /// ); /// ``` /// - /// ```rust + /// ``` /// use actix_http::Response; /// use actix_http::http::header::{ETag, EntityTag}; /// /// let mut builder = Response::Ok(); - /// builder.set(ETag(EntityTag::new(true, "xyzzy".to_owned()))); + /// builder.insert_header( + /// ETag(EntityTag::new(true, "xyzzy".to_owned())) + /// ); /// ``` (ETag, ETAG) => [EntityTag] diff --git a/actix-http/src/header/common/expires.rs b/actix-http/src/header/common/expires.rs index 3b9a7873d..79563955d 100644 --- a/actix-http/src/header/common/expires.rs +++ b/actix-http/src/header/common/expires.rs @@ -21,14 +21,16 @@ header! { /// /// # Example /// - /// ```rust + /// ``` + /// use std::time::{SystemTime, Duration}; /// use actix_http::Response; /// use actix_http::http::header::Expires; - /// use std::time::{SystemTime, Duration}; /// /// let mut builder = Response::Ok(); /// let expiration = SystemTime::now() + Duration::from_secs(60 * 60 * 24); - /// builder.set(Expires(expiration.into())); + /// builder.insert_header( + /// Expires(expiration.into()) + /// ); /// ``` (Expires, EXPIRES) => [HttpDate] diff --git a/actix-http/src/header/common/if_match.rs b/actix-http/src/header/common/if_match.rs index 7e0e9a7e0..db255e91a 100644 --- a/actix-http/src/header/common/if_match.rs +++ b/actix-http/src/header/common/if_match.rs @@ -29,20 +29,20 @@ header! { /// /// # Examples /// - /// ```rust + /// ``` /// use actix_http::Response; /// use actix_http::http::header::IfMatch; /// /// let mut builder = Response::Ok(); - /// builder.set(IfMatch::Any); + /// builder.insert_header(IfMatch::Any); /// ``` /// - /// ```rust + /// ``` /// use actix_http::Response; /// use actix_http::http::header::{IfMatch, EntityTag}; /// /// let mut builder = Response::Ok(); - /// builder.set( + /// builder.insert_header( /// IfMatch::Items(vec![ /// EntityTag::new(false, "xyzzy".to_owned()), /// EntityTag::new(false, "foobar".to_owned()), diff --git a/actix-http/src/header/common/if_modified_since.rs b/actix-http/src/header/common/if_modified_since.rs index 39aca595d..99c7e441d 100644 --- a/actix-http/src/header/common/if_modified_since.rs +++ b/actix-http/src/header/common/if_modified_since.rs @@ -21,14 +21,16 @@ header! { /// /// # Example /// - /// ```rust + /// ``` + /// use std::time::{SystemTime, Duration}; /// use actix_http::Response; /// use actix_http::http::header::IfModifiedSince; - /// use std::time::{SystemTime, Duration}; /// /// let mut builder = Response::Ok(); /// let modified = SystemTime::now() - Duration::from_secs(60 * 60 * 24); - /// builder.set(IfModifiedSince(modified.into())); + /// builder.insert_header( + /// IfModifiedSince(modified.into()) + /// ); /// ``` (IfModifiedSince, IF_MODIFIED_SINCE) => [HttpDate] diff --git a/actix-http/src/header/common/if_none_match.rs b/actix-http/src/header/common/if_none_match.rs index 7f6ccb137..464caf1ae 100644 --- a/actix-http/src/header/common/if_none_match.rs +++ b/actix-http/src/header/common/if_none_match.rs @@ -31,20 +31,20 @@ header! { /// /// # Examples /// - /// ```rust + /// ``` /// use actix_http::Response; /// use actix_http::http::header::IfNoneMatch; /// /// let mut builder = Response::Ok(); - /// builder.set(IfNoneMatch::Any); + /// builder.insert_header(IfNoneMatch::Any); /// ``` /// - /// ```rust + /// ``` /// use actix_http::Response; /// use actix_http::http::header::{IfNoneMatch, EntityTag}; /// /// let mut builder = Response::Ok(); - /// builder.set( + /// builder.insert_header( /// IfNoneMatch::Items(vec![ /// EntityTag::new(false, "xyzzy".to_owned()), /// EntityTag::new(false, "foobar".to_owned()), @@ -73,13 +73,15 @@ mod tests { fn test_if_none_match() { let mut if_none_match: Result; - let req = TestRequest::with_header(IF_NONE_MATCH, "*").finish(); + let req = TestRequest::default() + .insert_header((IF_NONE_MATCH, "*")) + .finish(); if_none_match = Header::parse(&req); assert_eq!(if_none_match.ok(), Some(IfNoneMatch::Any)); - let req = - TestRequest::with_header(IF_NONE_MATCH, &b"\"foobar\", W/\"weak-etag\""[..]) - .finish(); + let req = TestRequest::default() + .insert_header((IF_NONE_MATCH, &b"\"foobar\", W/\"weak-etag\""[..])) + .finish(); if_none_match = Header::parse(&req); let mut entities: Vec = Vec::new(); diff --git a/actix-http/src/header/common/if_range.rs b/actix-http/src/header/common/if_range.rs index b14ad0391..0a5749505 100644 --- a/actix-http/src/header/common/if_range.rs +++ b/actix-http/src/header/common/if_range.rs @@ -5,7 +5,7 @@ use crate::header::{ self, from_one_raw_str, EntityTag, Header, HeaderName, HeaderValue, HttpDate, IntoHeaderValue, InvalidHeaderValue, Writer, }; -use crate::httpmessage::HttpMessage; +use crate::HttpMessage; /// `If-Range` header, defined in [RFC7233](http://tools.ietf.org/html/rfc7233#section-3.2) /// @@ -35,31 +35,34 @@ use crate::httpmessage::HttpMessage; /// /// # Examples /// -/// ```rust +/// ``` /// use actix_http::Response; /// use actix_http::http::header::{EntityTag, IfRange}; /// /// let mut builder = Response::Ok(); -/// builder.set(IfRange::EntityTag(EntityTag::new( -/// false, -/// "xyzzy".to_owned(), -/// ))); +/// builder.insert_header( +/// IfRange::EntityTag( +/// EntityTag::new(false, "abc".to_owned()) +/// ) +/// ); /// ``` /// -/// ```rust -/// use actix_http::Response; -/// use actix_http::http::header::IfRange; +/// ``` /// use std::time::{Duration, SystemTime}; +/// use actix_http::{http::header::IfRange, Response}; /// /// let mut builder = Response::Ok(); /// let fetched = SystemTime::now() - Duration::from_secs(60 * 60 * 24); -/// builder.set(IfRange::Date(fetched.into())); +/// builder.insert_header( +/// IfRange::Date(fetched.into()) +/// ); /// ``` #[derive(Clone, Debug, PartialEq)] pub enum IfRange { - /// The entity-tag the client has of the resource + /// The entity-tag the client has of the resource. EntityTag(EntityTag), - /// The date when the client retrieved the resource + + /// The date when the client retrieved the resource. Date(HttpDate), } @@ -98,7 +101,7 @@ impl Display for IfRange { impl IntoHeaderValue for IfRange { type Error = InvalidHeaderValue; - fn try_into(self) -> Result { + fn try_into_value(self) -> Result { let mut writer = Writer::new(); let _ = write!(&mut writer, "{}", self); HeaderValue::from_maybe_shared(writer.take()) @@ -110,7 +113,8 @@ mod test_if_range { use super::IfRange as HeaderField; use crate::header::*; use std::str; + test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); - test_header!(test2, vec![b"\"xyzzy\""]); + test_header!(test2, vec![b"\"abc\""]); test_header!(test3, vec![b"this-is-invalid"], None::); } diff --git a/actix-http/src/header/common/if_unmodified_since.rs b/actix-http/src/header/common/if_unmodified_since.rs index d6c099e64..1c2b4af78 100644 --- a/actix-http/src/header/common/if_unmodified_since.rs +++ b/actix-http/src/header/common/if_unmodified_since.rs @@ -22,14 +22,16 @@ header! { /// /// # Example /// - /// ```rust + /// ``` + /// use std::time::{SystemTime, Duration}; /// use actix_http::Response; /// use actix_http::http::header::IfUnmodifiedSince; - /// use std::time::{SystemTime, Duration}; /// /// let mut builder = Response::Ok(); /// let modified = SystemTime::now() - Duration::from_secs(60 * 60 * 24); - /// builder.set(IfUnmodifiedSince(modified.into())); + /// builder.insert_header( + /// IfUnmodifiedSince(modified.into()) + /// ); /// ``` (IfUnmodifiedSince, IF_UNMODIFIED_SINCE) => [HttpDate] diff --git a/actix-http/src/header/common/last_modified.rs b/actix-http/src/header/common/last_modified.rs index cc888ccb0..65608d846 100644 --- a/actix-http/src/header/common/last_modified.rs +++ b/actix-http/src/header/common/last_modified.rs @@ -21,14 +21,16 @@ header! { /// /// # Example /// - /// ```rust + /// ``` + /// use std::time::{SystemTime, Duration}; /// use actix_http::Response; /// use actix_http::http::header::LastModified; - /// use std::time::{SystemTime, Duration}; /// /// let mut builder = Response::Ok(); /// let modified = SystemTime::now() - Duration::from_secs(60 * 60 * 24); - /// builder.set(LastModified(modified.into())); + /// builder.insert_header( + /// LastModified(modified.into()) + /// ); /// ``` (LastModified, LAST_MODIFIED) => [HttpDate] diff --git a/actix-http/src/header/common/mod.rs b/actix-http/src/header/common/mod.rs index c3d18613c..90e0a855e 100644 --- a/actix-http/src/header/common/mod.rs +++ b/actix-http/src/header/common/mod.rs @@ -18,6 +18,7 @@ pub use self::content_disposition::{ }; pub use self::content_language::ContentLanguage; pub use self::content_range::{ContentRange, ContentRangeSpec}; +pub use self::content_encoding::{ContentEncoding}; pub use self::content_type::ContentType; pub use self::date::Date; pub use self::etag::ETag; @@ -83,7 +84,7 @@ macro_rules! test_header { let a: Vec> = raw.iter().map(|x| x.to_vec()).collect(); let mut req = test::TestRequest::default(); for item in a { - req = req.header(HeaderField::name(), item).take(); + req = req.insert_header((HeaderField::name(), item)).take(); } let req = req.finish(); let value = HeaderField::parse(&req); @@ -110,7 +111,7 @@ macro_rules! test_header { let a: Vec> = $raw.iter().map(|x| x.to_vec()).collect(); let mut req = test::TestRequest::default(); for item in a { - req.header(HeaderField::name(), item); + req.insert_header((HeaderField::name(), item)); } let req = req.finish(); let val = HeaderField::parse(&req); @@ -168,7 +169,7 @@ macro_rules! header { impl $crate::http::header::IntoHeaderValue for $id { type Error = $crate::http::header::InvalidHeaderValue; - fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { + fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { use std::fmt::Write; let mut writer = $crate::http::header::Writer::new(); let _ = write!(&mut writer, "{}", self); @@ -204,7 +205,7 @@ macro_rules! header { impl $crate::http::header::IntoHeaderValue for $id { type Error = $crate::http::header::InvalidHeaderValue; - fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { + fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { use std::fmt::Write; let mut writer = $crate::http::header::Writer::new(); let _ = write!(&mut writer, "{}", self); @@ -240,8 +241,8 @@ macro_rules! header { impl $crate::http::header::IntoHeaderValue for $id { type Error = $crate::http::header::InvalidHeaderValue; - fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { - self.0.try_into() + fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { + self.0.try_into_value() } } }; @@ -289,7 +290,7 @@ macro_rules! header { impl $crate::http::header::IntoHeaderValue for $id { type Error = $crate::http::header::InvalidHeaderValue; - fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { + fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { use std::fmt::Write; let mut writer = $crate::http::header::Writer::new(); let _ = write!(&mut writer, "{}", self); @@ -333,13 +334,14 @@ macro_rules! header { } mod accept_charset; -//mod accept_encoding; +// mod accept_encoding; mod accept; mod accept_language; mod allow; mod cache_control; mod content_disposition; mod content_language; +mod content_encoding; mod content_range; mod content_type; mod date; diff --git a/actix-http/src/header/into_pair.rs b/actix-http/src/header/into_pair.rs new file mode 100644 index 000000000..d0d6e7324 --- /dev/null +++ b/actix-http/src/header/into_pair.rs @@ -0,0 +1,117 @@ +use std::convert::TryFrom; + +use http::{ + header::{HeaderName, InvalidHeaderName, InvalidHeaderValue}, + Error as HttpError, HeaderValue, +}; + +use super::{Header, IntoHeaderValue}; + +/// Transforms structures into header K/V pairs for inserting into `HeaderMap`s. +pub trait IntoHeaderPair: Sized { + type Error: Into; + + fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error>; +} + +#[derive(Debug)] +pub enum InvalidHeaderPart { + Name(InvalidHeaderName), + Value(InvalidHeaderValue), +} + +impl From for HttpError { + fn from(part_err: InvalidHeaderPart) -> Self { + match part_err { + InvalidHeaderPart::Name(err) => err.into(), + InvalidHeaderPart::Value(err) => err.into(), + } + } +} + +impl IntoHeaderPair for (HeaderName, V) +where + V: IntoHeaderValue, + V::Error: Into, +{ + type Error = InvalidHeaderPart; + + fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + let (name, value) = self; + let value = value + .try_into_value() + .map_err(|err| InvalidHeaderPart::Value(err.into()))?; + Ok((name, value)) + } +} + +impl IntoHeaderPair for (&HeaderName, V) +where + V: IntoHeaderValue, + V::Error: Into, +{ + type Error = InvalidHeaderPart; + + fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + let (name, value) = self; + let value = value + .try_into_value() + .map_err(|err| InvalidHeaderPart::Value(err.into()))?; + Ok((name.clone(), value)) + } +} + +impl IntoHeaderPair for (&[u8], V) +where + V: IntoHeaderValue, + V::Error: Into, +{ + type Error = InvalidHeaderPart; + + fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + let (name, value) = self; + let name = HeaderName::try_from(name).map_err(InvalidHeaderPart::Name)?; + let value = value + .try_into_value() + .map_err(|err| InvalidHeaderPart::Value(err.into()))?; + Ok((name, value)) + } +} + +impl IntoHeaderPair for (&str, V) +where + V: IntoHeaderValue, + V::Error: Into, +{ + type Error = InvalidHeaderPart; + + fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + let (name, value) = self; + let name = HeaderName::try_from(name).map_err(InvalidHeaderPart::Name)?; + let value = value + .try_into_value() + .map_err(|err| InvalidHeaderPart::Value(err.into()))?; + Ok((name, value)) + } +} + +impl IntoHeaderPair for (String, V) +where + V: IntoHeaderValue, + V::Error: Into, +{ + type Error = InvalidHeaderPart; + + fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + let (name, value) = self; + (name.as_str(), value).try_into_header_pair() + } +} + +impl IntoHeaderPair for T { + type Error = ::Error; + + fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + Ok((T::name(), self.try_into_value()?)) + } +} diff --git a/actix-http/src/header/into_value.rs b/actix-http/src/header/into_value.rs new file mode 100644 index 000000000..4ba58e726 --- /dev/null +++ b/actix-http/src/header/into_value.rs @@ -0,0 +1,131 @@ +use std::convert::TryFrom; + +use bytes::Bytes; +use http::{header::InvalidHeaderValue, Error as HttpError, HeaderValue}; +use mime::Mime; + +/// A trait for any object that can be Converted to a `HeaderValue` +pub trait IntoHeaderValue: Sized { + /// The type returned in the event of a conversion error. + type Error: Into; + + /// Try to convert value to a HeaderValue. + fn try_into_value(self) -> Result; +} + +impl IntoHeaderValue for HeaderValue { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + Ok(self) + } +} + +impl IntoHeaderValue for &HeaderValue { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + Ok(self.clone()) + } +} + +impl IntoHeaderValue for &str { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + self.parse() + } +} + +impl IntoHeaderValue for &[u8] { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + HeaderValue::from_bytes(self) + } +} + +impl IntoHeaderValue for Bytes { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + HeaderValue::from_maybe_shared(self) + } +} + +impl IntoHeaderValue for Vec { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + HeaderValue::try_from(self) + } +} + +impl IntoHeaderValue for String { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + HeaderValue::try_from(self) + } +} + +impl IntoHeaderValue for usize { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + HeaderValue::try_from(self.to_string()) + } +} + +impl IntoHeaderValue for i64 { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + HeaderValue::try_from(self.to_string()) + } +} + +impl IntoHeaderValue for u64 { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + HeaderValue::try_from(self.to_string()) + } +} + +impl IntoHeaderValue for i32 { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + HeaderValue::try_from(self.to_string()) + } +} + +impl IntoHeaderValue for u32 { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + HeaderValue::try_from(self.to_string()) + } +} + +impl IntoHeaderValue for Mime { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result { + HeaderValue::from_str(self.as_ref()) + } +} diff --git a/actix-http/src/header/map.rs b/actix-http/src/header/map.rs index 6ab3509f7..106e44edb 100644 --- a/actix-http/src/header/map.rs +++ b/actix-http/src/header/map.rs @@ -1,328 +1,563 @@ -use std::collections::hash_map::{self, Entry}; -use std::convert::TryFrom; +//! A multi-value [`HeaderMap`] and its iterators. -use either::Either; -use fxhash::FxHashMap; +use std::{borrow::Cow, collections::hash_map, ops}; + +use ahash::AHashMap; use http::header::{HeaderName, HeaderValue}; +use smallvec::{smallvec, SmallVec}; -/// A set of HTTP headers +use crate::header::AsHeaderName; + +/// A multi-map of HTTP headers. /// -/// `HeaderMap` is an multi-map of [`HeaderName`] to values. -#[derive(Debug, Clone)] +/// `HeaderMap` is a "multi-map" of [`HeaderName`] to one or more [`HeaderValue`]s. +/// +/// # Examples +/// ``` +/// use actix_http::http::{header, HeaderMap, HeaderValue}; +/// +/// let mut map = HeaderMap::new(); +/// +/// map.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/plain")); +/// map.insert(header::ORIGIN, HeaderValue::from_static("example.com")); +/// +/// assert!(map.contains_key(header::CONTENT_TYPE)); +/// assert!(map.contains_key(header::ORIGIN)); +/// +/// let mut removed = map.remove(header::ORIGIN); +/// assert_eq!(removed.next().unwrap(), "example.com"); +/// +/// assert!(!map.contains_key(header::ORIGIN)); +/// ``` +#[derive(Debug, Clone, Default)] pub struct HeaderMap { - pub(crate) inner: FxHashMap, + pub(crate) inner: AHashMap, } +/// A bespoke non-empty list for HeaderMap values. #[derive(Debug, Clone)] -pub(crate) enum Value { - One(HeaderValue), - Multi(Vec), +pub(crate) struct Value { + inner: SmallVec<[HeaderValue; 4]>, } impl Value { - fn get(&self) -> &HeaderValue { - match self { - Value::One(ref val) => val, - Value::Multi(ref val) => &val[0], + fn one(val: HeaderValue) -> Self { + Self { + inner: smallvec![val], } } - fn get_mut(&mut self) -> &mut HeaderValue { - match self { - Value::One(ref mut val) => val, - Value::Multi(ref mut val) => &mut val[0], - } + fn first(&self) -> &HeaderValue { + &self.inner[0] } - fn append(&mut self, val: HeaderValue) { - match self { - Value::One(_) => { - let data = std::mem::replace(self, Value::Multi(vec![val])); - match data { - Value::One(val) => self.append(val), - Value::Multi(_) => unreachable!(), - } - } - Value::Multi(ref mut vec) => vec.push(val), - } + fn first_mut(&mut self) -> &mut HeaderValue { + &mut self.inner[0] + } + + fn append(&mut self, new_val: HeaderValue) { + self.inner.push(new_val) + } +} + +impl ops::Deref for Value { + type Target = SmallVec<[HeaderValue; 4]>; + + fn deref(&self) -> &Self::Target { + &self.inner } } impl HeaderMap { /// Create an empty `HeaderMap`. /// - /// The map will be created without any capacity. This function will not - /// allocate. + /// The map will be created without any capacity; this function will not allocate. + /// + /// # Examples + /// ``` + /// # use actix_http::http::HeaderMap; + /// let map = HeaderMap::new(); + /// + /// assert!(map.is_empty()); + /// assert_eq!(0, map.capacity()); + /// ``` pub fn new() -> Self { - HeaderMap { - inner: FxHashMap::default(), - } + HeaderMap::default() } /// Create an empty `HeaderMap` with the specified capacity. /// - /// The returned map will allocate internal storage in order to hold about - /// `capacity` elements without reallocating. However, this is a "best - /// effort" as there are usage patterns that could cause additional - /// allocations before `capacity` headers are stored in the map. + /// The map will be able to hold at least `capacity` elements without needing to reallocate. + /// If `capacity` is 0, the map will be created without allocating. /// - /// More capacity than requested may be allocated. - pub fn with_capacity(capacity: usize) -> HeaderMap { + /// # Examples + /// ``` + /// # use actix_http::http::HeaderMap; + /// let map = HeaderMap::with_capacity(16); + /// + /// assert!(map.is_empty()); + /// assert!(map.capacity() >= 16); + /// ``` + pub fn with_capacity(capacity: usize) -> Self { HeaderMap { - inner: FxHashMap::with_capacity_and_hasher(capacity, Default::default()), + inner: AHashMap::with_capacity(capacity), } } - /// Returns the number of keys stored in the map. + /// Create new `HeaderMap` from a `http::HeaderMap`-like drain. + pub(crate) fn from_drain(mut drain: I) -> Self + where + I: Iterator, HeaderValue)>, + { + let (first_name, first_value) = match drain.next() { + None => return HeaderMap::new(), + Some((name, val)) => { + let name = name.expect("drained first item had no name"); + (name, val) + } + }; + + let (lb, ub) = drain.size_hint(); + let capacity = ub.unwrap_or(lb); + + let mut map = HeaderMap::with_capacity(capacity); + map.append(first_name.clone(), first_value); + + let (map, _) = + drain.fold((map, first_name), |(mut map, prev_name), (name, value)| { + let name = name.unwrap_or(prev_name); + map.append(name.clone(), value); + (map, name) + }); + + map + } + + /// Returns the number of values stored in the map. /// - /// This number could be be less than or equal to actual headers stored in - /// the map. + /// See also: [`len_keys`](Self::len_keys). + /// + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// assert_eq!(map.len(), 0); + /// + /// map.insert(header::ACCEPT, HeaderValue::from_static("text/plain")); + /// map.insert(header::SET_COOKIE, HeaderValue::from_static("one=1")); + /// assert_eq!(map.len(), 2); + /// + /// map.append(header::SET_COOKIE, HeaderValue::from_static("two=2")); + /// assert_eq!(map.len(), 3); + /// ``` pub fn len(&self) -> usize { + self.inner + .iter() + .fold(0, |acc, (_, values)| acc + values.len()) + } + + /// Returns the number of _keys_ stored in the map. + /// + /// The number of values stored will be at least this number. See also: [`Self::len`]. + /// + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// assert_eq!(map.len_keys(), 0); + /// + /// map.insert(header::ACCEPT, HeaderValue::from_static("text/plain")); + /// map.insert(header::SET_COOKIE, HeaderValue::from_static("one=1")); + /// assert_eq!(map.len_keys(), 2); + /// + /// map.append(header::SET_COOKIE, HeaderValue::from_static("two=2")); + /// assert_eq!(map.len_keys(), 2); + /// ``` + pub fn len_keys(&self) -> usize { self.inner.len() } /// Returns true if the map contains no elements. + /// + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// assert!(map.is_empty()); + /// + /// map.insert(header::ACCEPT, HeaderValue::from_static("text/plain")); + /// assert!(!map.is_empty()); + /// ``` pub fn is_empty(&self) -> bool { self.inner.len() == 0 } - /// Clears the map, removing all key-value pairs. Keeps the allocated memory - /// for reuse. + /// Clears the map, removing all name-value pairs. + /// + /// Keeps the allocated memory for reuse. + /// + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// + /// map.insert(header::ACCEPT, HeaderValue::from_static("text/plain")); + /// map.insert(header::SET_COOKIE, HeaderValue::from_static("one=1")); + /// assert_eq!(map.len(), 2); + /// + /// map.clear(); + /// assert!(map.is_empty()); + /// ``` pub fn clear(&mut self) { self.inner.clear(); } - /// Returns the number of headers the map can hold without reallocating. + fn get_value(&self, key: impl AsHeaderName) -> Option<&Value> { + match key.try_as_name().ok()? { + Cow::Borrowed(name) => self.inner.get(name), + Cow::Owned(name) => self.inner.get(&name), + } + } + + /// Returns a reference to the _first_ value associated with a header name. /// - /// This number is an approximation as certain usage patterns could cause - /// additional allocations before the returned capacity is filled. + /// Returns `None` if there is no value associated with the key. + /// + /// Even when multiple values are associated with the key, the "first" one is returned but is + /// not guaranteed to be chosen with any particular order; though, the returned item will be + /// consistent for each call to `get` if the map has not changed. + /// + /// See also: [`get_all`](Self::get_all). + /// + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// + /// map.insert(header::SET_COOKIE, HeaderValue::from_static("one=1")); + /// + /// let cookie = map.get(header::SET_COOKIE).unwrap(); + /// assert_eq!(cookie, "one=1"); + /// + /// map.append(header::SET_COOKIE, HeaderValue::from_static("two=2")); + /// assert_eq!(map.get(header::SET_COOKIE).unwrap(), "one=1"); + /// + /// assert_eq!(map.get(header::SET_COOKIE), map.get("set-cookie")); + /// assert_eq!(map.get(header::SET_COOKIE), map.get("Set-Cookie")); + /// + /// assert!(map.get(header::HOST).is_none()); + /// assert!(map.get("INVALID HEADER NAME").is_none()); + /// ``` + pub fn get(&self, key: impl AsHeaderName) -> Option<&HeaderValue> { + self.get_value(key).map(|val| val.first()) + } + + /// Returns a mutable reference to the _first_ value associated a header name. + /// + /// Returns `None` if there is no value associated with the key. + /// + /// Even when multiple values are associated with the key, the "first" one is returned but is + /// not guaranteed to be chosen with any particular order; though, the returned item will be + /// consistent for each call to `get_mut` if the map has not changed. + /// + /// See also: [`get_all`](Self::get_all). + /// + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// + /// map.insert(header::SET_COOKIE, HeaderValue::from_static("one=1")); + /// + /// let mut cookie = map.get_mut(header::SET_COOKIE).unwrap(); + /// assert_eq!(cookie, "one=1"); + /// + /// *cookie = HeaderValue::from_static("three=3"); + /// assert_eq!(map.get(header::SET_COOKIE).unwrap(), "three=3"); + /// + /// assert!(map.get(header::HOST).is_none()); + /// assert!(map.get("INVALID HEADER NAME").is_none()); + /// ``` + pub fn get_mut(&mut self, key: impl AsHeaderName) -> Option<&mut HeaderValue> { + match key.try_as_name().ok()? { + Cow::Borrowed(name) => self.inner.get_mut(name).map(|v| v.first_mut()), + Cow::Owned(name) => self.inner.get_mut(&name).map(|v| v.first_mut()), + } + } + + /// Returns an iterator over all values associated with a header name. + /// + /// The returned iterator does not incur any allocations and will yield no items if there are no + /// values associated with the key. Iteration order is **not** guaranteed to be the same as + /// insertion order. + /// + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// + /// let mut none_iter = map.get_all(header::ORIGIN); + /// assert!(none_iter.next().is_none()); + /// + /// map.insert(header::SET_COOKIE, HeaderValue::from_static("one=1")); + /// map.append(header::SET_COOKIE, HeaderValue::from_static("two=2")); + /// + /// let mut set_cookies_iter = map.get_all(header::SET_COOKIE); + /// assert_eq!(set_cookies_iter.next().unwrap(), "one=1"); + /// assert_eq!(set_cookies_iter.next().unwrap(), "two=2"); + /// assert!(set_cookies_iter.next().is_none()); + /// ``` + pub fn get_all(&self, key: impl AsHeaderName) -> GetAll<'_> { + GetAll::new(self.get_value(key)) + } + + // TODO: get_all_mut ? + + /// Returns `true` if the map contains a value for the specified key. + /// + /// Invalid header names will simply return false. + /// + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// assert!(!map.contains_key(header::ACCEPT)); + /// + /// map.insert(header::ACCEPT, HeaderValue::from_static("text/plain")); + /// assert!(map.contains_key(header::ACCEPT)); + /// ``` + pub fn contains_key(&self, key: impl AsHeaderName) -> bool { + match key.try_as_name() { + Ok(Cow::Borrowed(name)) => self.inner.contains_key(name), + Ok(Cow::Owned(name)) => self.inner.contains_key(&name), + Err(_) => false, + } + } + + /// Inserts a name-value pair into the map. + /// + /// If the map already contained this key, the new value is associated with the key and all + /// previous values are removed and returned as a `Removed` iterator. The key is not updated; + /// this matters for types that can be `==` without being identical. + /// + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// + /// map.insert(header::ACCEPT, HeaderValue::from_static("text/plain")); + /// assert!(map.contains_key(header::ACCEPT)); + /// assert_eq!(map.len(), 1); + /// + /// let mut removed = map.insert(header::ACCEPT, HeaderValue::from_static("text/csv")); + /// assert_eq!(removed.next().unwrap(), "text/plain"); + /// assert!(removed.next().is_none()); + /// + /// assert_eq!(map.len(), 1); + /// ``` + pub fn insert(&mut self, key: HeaderName, val: HeaderValue) -> Removed { + let value = self.inner.insert(key, Value::one(val)); + Removed::new(value) + } + + /// Inserts a name-value pair into the map. + /// + /// If the map already contained this key, the new value is added to the list of values + /// currently associated with the key. The key is not updated; this matters for types that can + /// be `==` without being identical. + /// + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// + /// map.append(header::HOST, HeaderValue::from_static("example.com")); + /// assert_eq!(map.len(), 1); + /// + /// map.append(header::ACCEPT, HeaderValue::from_static("text/csv")); + /// assert_eq!(map.len(), 2); + /// + /// map.append(header::ACCEPT, HeaderValue::from_static("text/html")); + /// assert_eq!(map.len(), 3); + /// ``` + pub fn append(&mut self, key: HeaderName, value: HeaderValue) { + match self.inner.entry(key) { + hash_map::Entry::Occupied(mut entry) => { + entry.get_mut().append(value); + } + hash_map::Entry::Vacant(entry) => { + entry.insert(Value::one(value)); + } + }; + } + + /// Removes all headers for a particular header name from the map. + /// + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// + /// map.append(header::SET_COOKIE, HeaderValue::from_static("one=1")); + /// map.append(header::SET_COOKIE, HeaderValue::from_static("one=2")); + /// + /// assert_eq!(map.len(), 2); + /// + /// let mut removed = map.remove(header::SET_COOKIE); + /// assert_eq!(removed.next().unwrap(), "one=1"); + /// assert_eq!(removed.next().unwrap(), "one=2"); + /// assert!(removed.next().is_none()); + /// + /// assert!(map.is_empty()); + pub fn remove(&mut self, key: impl AsHeaderName) -> Removed { + let value = match key.try_as_name() { + Ok(Cow::Borrowed(name)) => self.inner.remove(name), + Ok(Cow::Owned(name)) => self.inner.remove(&name), + Err(_) => None, + }; + + Removed::new(value) + } + + /// Returns the number of single-value headers the map can hold without needing to reallocate. + /// + /// Since this is a multi-value map, the actual capacity is much larger when considering + /// each header name can be associated with an arbitrary number of values. The effect is that + /// the size of `len` may be greater than `capacity` since it counts all the values. + /// Conversely, [`len_keys`](Self::len_keys) will never be larger than capacity. + /// + /// # Examples + /// ``` + /// # use actix_http::http::HeaderMap; + /// let map = HeaderMap::with_capacity(16); + /// + /// assert!(map.is_empty()); + /// assert!(map.capacity() >= 16); + /// ``` pub fn capacity(&self) -> usize { self.inner.capacity() } - /// Reserves capacity for at least `additional` more headers to be inserted - /// into the `HeaderMap`. + /// Reserves capacity for at least `additional` more headers to be inserted in the map. /// - /// The header map may reserve more space to avoid frequent reallocations. - /// Like with `with_capacity`, this will be a "best effort" to avoid - /// allocations until `additional` more headers are inserted. Certain usage - /// patterns could cause additional allocations before the number is - /// reached. + /// The header map may reserve more space to avoid frequent reallocations. Additional capacity + /// only considers single-value headers. + /// + /// # Panics + /// Panics if the new allocation size overflows usize. + /// + /// # Examples + /// ``` + /// # use actix_http::http::HeaderMap; + /// let mut map = HeaderMap::with_capacity(2); + /// assert!(map.capacity() >= 2); + /// + /// map.reserve(100); + /// assert!(map.capacity() >= 102); + /// + /// assert!(map.is_empty()); + /// ``` pub fn reserve(&mut self, additional: usize) { self.inner.reserve(additional) } - /// Returns a reference to the value associated with the key. + /// An iterator over all name-value pairs. /// - /// If there are multiple values associated with the key, then the first one - /// is returned. Use `get_all` to get all values associated with a given - /// key. Returns `None` if there are no values associated with the key. - pub fn get(&self, name: N) -> Option<&HeaderValue> { - self.get2(name).map(|v| v.get()) - } - - fn get2(&self, name: N) -> Option<&Value> { - match name.as_name() { - Either::Left(name) => self.inner.get(name), - Either::Right(s) => { - if let Ok(name) = HeaderName::try_from(s) { - self.inner.get(&name) - } else { - None - } - } - } - } - - /// Returns a view of all values associated with a key. + /// Names will be yielded for each associated value. So, if a key has 3 associated values, it + /// will be yielded 3 times. The iteration order should be considered arbitrary. /// - /// The returned view does not incur any allocations and allows iterating - /// the values associated with the key. See [`GetAll`] for more details. - /// Returns `None` if there are no values associated with the key. - pub fn get_all(&self, name: N) -> GetAll<'_> { - GetAll { - idx: 0, - item: self.get2(name), - } - } - - /// Returns a mutable reference to the value associated with the key. + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); /// - /// If there are multiple values associated with the key, then the first one - /// is returned. Use `entry` to get all values associated with a given - /// key. Returns `None` if there are no values associated with the key. - pub fn get_mut(&mut self, name: N) -> Option<&mut HeaderValue> { - match name.as_name() { - Either::Left(name) => self.inner.get_mut(name).map(|v| v.get_mut()), - Either::Right(s) => { - if let Ok(name) = HeaderName::try_from(s) { - self.inner.get_mut(&name).map(|v| v.get_mut()) - } else { - None - } - } - } - } - - /// Returns true if the map contains a value for the specified key. - pub fn contains_key(&self, key: N) -> bool { - match key.as_name() { - Either::Left(name) => self.inner.contains_key(name), - Either::Right(s) => { - if let Ok(name) = HeaderName::try_from(s) { - self.inner.contains_key(&name) - } else { - false - } - } - } - } - - /// An iterator visiting all key-value pairs. + /// let mut iter = map.iter(); + /// assert!(iter.next().is_none()); /// - /// The iteration order is arbitrary, but consistent across platforms for - /// the same crate version. Each key will be yielded once per associated - /// value. So, if a key has 3 associated values, it will be yielded 3 times. + /// map.append(header::HOST, HeaderValue::from_static("duck.com")); + /// map.append(header::SET_COOKIE, HeaderValue::from_static("one=1")); + /// map.append(header::SET_COOKIE, HeaderValue::from_static("two=2")); + /// + /// let mut iter = map.iter(); + /// assert!(iter.next().is_some()); + /// assert!(iter.next().is_some()); + /// assert!(iter.next().is_some()); + /// assert!(iter.next().is_none()); + /// + /// let pairs = map.iter().collect::>(); + /// assert!(pairs.contains(&(&header::HOST, &HeaderValue::from_static("duck.com")))); + /// assert!(pairs.contains(&(&header::SET_COOKIE, &HeaderValue::from_static("one=1")))); + /// assert!(pairs.contains(&(&header::SET_COOKIE, &HeaderValue::from_static("two=2")))); + /// ``` pub fn iter(&self) -> Iter<'_> { Iter::new(self.inner.iter()) } - /// An iterator visiting all keys. + /// An iterator over all contained header names. /// - /// The iteration order is arbitrary, but consistent across platforms for - /// the same crate version. Each key will be yielded only once even if it - /// has multiple associated values. + /// Each name will only be yielded once even if it has multiple associated values. The iteration + /// order should be considered arbitrary. + /// + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// + /// let mut iter = map.keys(); + /// assert!(iter.next().is_none()); + /// + /// map.append(header::HOST, HeaderValue::from_static("duck.com")); + /// map.append(header::SET_COOKIE, HeaderValue::from_static("one=1")); + /// map.append(header::SET_COOKIE, HeaderValue::from_static("two=2")); + /// + /// let keys = map.keys().cloned().collect::>(); + /// assert_eq!(keys.len(), 2); + /// assert!(keys.contains(&header::HOST)); + /// assert!(keys.contains(&header::SET_COOKIE)); + /// ``` pub fn keys(&self) -> Keys<'_> { Keys(self.inner.keys()) } - /// Inserts a key-value pair into the map. + /// Clears the map, returning all name-value sets as an iterator. /// - /// If the map did not previously have this key present, then `None` is - /// returned. + /// Header names will only be yielded for the first value in each set. All items that are + /// yielded without a name and after an item with a name are associated with that same name. + /// The first item will always contain a name. /// - /// If the map did have this key present, the new value is associated with - /// the key and all previous values are removed. **Note** that only a single - /// one of the previous values is returned. If there are multiple values - /// that have been previously associated with the key, then the first one is - /// returned. See `insert_mult` on `OccupiedEntry` for an API that returns - /// all values. + /// Keeps the allocated memory for reuse. + /// # Examples + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); /// - /// The key is not updated, though; this matters for types that can be `==` - /// without being identical. - pub fn insert(&mut self, key: HeaderName, val: HeaderValue) { - let _ = self.inner.insert(key, Value::One(val)); - } - - /// Inserts a key-value pair into the map. + /// let mut iter = map.drain(); + /// assert!(iter.next().is_none()); + /// drop(iter); /// - /// If the map did not previously have this key present, then `false` is - /// returned. + /// map.append(header::SET_COOKIE, HeaderValue::from_static("one=1")); + /// map.append(header::SET_COOKIE, HeaderValue::from_static("two=2")); /// - /// If the map did have this key present, the new value is pushed to the end - /// of the list of values currently associated with the key. The key is not - /// updated, though; this matters for types that can be `==` without being - /// identical. - pub fn append(&mut self, key: HeaderName, value: HeaderValue) { - match self.inner.entry(key) { - Entry::Occupied(mut entry) => entry.get_mut().append(value), - Entry::Vacant(entry) => { - entry.insert(Value::One(value)); - } - } - } - - /// Removes all headers for a particular header name from the map. - pub fn remove(&mut self, key: N) { - match key.as_name() { - Either::Left(name) => { - let _ = self.inner.remove(name); - } - Either::Right(s) => { - if let Ok(name) = HeaderName::try_from(s) { - let _ = self.inner.remove(&name); - } - } - } + /// let mut iter = map.drain(); + /// assert_eq!(iter.next().unwrap(), (Some(header::SET_COOKIE), HeaderValue::from_static("one=1"))); + /// assert_eq!(iter.next().unwrap(), (None, HeaderValue::from_static("two=2"))); + /// drop(iter); + /// + /// assert!(map.is_empty()); + /// ``` + pub fn drain(&mut self) -> Drain<'_> { + Drain::new(self.inner.drain()) } } -#[doc(hidden)] -pub trait AsName { - fn as_name(&self) -> Either<&HeaderName, &str>; -} - -impl AsName for HeaderName { - fn as_name(&self) -> Either<&HeaderName, &str> { - Either::Left(self) - } -} - -impl<'a> AsName for &'a HeaderName { - fn as_name(&self) -> Either<&HeaderName, &str> { - Either::Left(self) - } -} - -impl<'a> AsName for &'a str { - fn as_name(&self) -> Either<&HeaderName, &str> { - Either::Right(self) - } -} - -impl AsName for String { - fn as_name(&self) -> Either<&HeaderName, &str> { - Either::Right(self.as_str()) - } -} - -impl<'a> AsName for &'a String { - fn as_name(&self) -> Either<&HeaderName, &str> { - Either::Right(self.as_str()) - } -} - -pub struct GetAll<'a> { - idx: usize, - item: Option<&'a Value>, -} - -impl<'a> Iterator for GetAll<'a> { - type Item = &'a HeaderValue; +/// Note that this implementation will clone a [HeaderName] for each value. +impl IntoIterator for HeaderMap { + type Item = (HeaderName, HeaderValue); + type IntoIter = IntoIter; #[inline] - fn next(&mut self) -> Option<&'a HeaderValue> { - if let Some(ref val) = self.item { - match val { - Value::One(ref val) => { - self.item.take(); - Some(val) - } - Value::Multi(ref vec) => { - if self.idx < vec.len() { - let item = Some(&vec[self.idx]); - self.idx += 1; - item - } else { - self.item.take(); - None - } - } - } - } else { - None - } - } -} - -pub struct Keys<'a>(hash_map::Keys<'a, HeaderName, Value>); - -impl<'a> Iterator for Keys<'a> { - type Item = &'a HeaderName; - - #[inline] - fn next(&mut self) -> Option<&'a HeaderName> { - self.0.next() + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self.inner.into_iter()) } } @@ -330,23 +565,116 @@ impl<'a> IntoIterator for &'a HeaderMap { type Item = (&'a HeaderName, &'a HeaderValue); type IntoIter = Iter<'a>; + #[inline] fn into_iter(self) -> Self::IntoIter { - self.iter() + Iter::new(self.inner.iter()) } } -pub struct Iter<'a> { +/// Iterator for all values with the same header name. +/// +/// See [`HeaderMap::get_all`]. +#[derive(Debug)] +pub struct GetAll<'a> { idx: usize, - current: Option<(&'a HeaderName, &'a Vec)>, - iter: hash_map::Iter<'a, HeaderName, Value>, + value: Option<&'a Value>, +} + +impl<'a> GetAll<'a> { + fn new(value: Option<&'a Value>) -> Self { + Self { idx: 0, value } + } +} + +impl<'a> Iterator for GetAll<'a> { + type Item = &'a HeaderValue; + + fn next(&mut self) -> Option { + let val = self.value?; + + match val.get(self.idx) { + Some(val) => { + self.idx += 1; + Some(val) + } + None => { + // current index is none; remove value to fast-path future next calls + self.value = None; + None + } + } + } + + fn size_hint(&self) -> (usize, Option) { + match self.value { + Some(val) => (val.len(), Some(val.len())), + None => (0, Some(0)), + } + } +} + +/// Iterator for owned [`HeaderValue`]s with the same associated [`HeaderName`] returned from methods +/// on [`HeaderMap`] that remove or replace items. +#[derive(Debug)] +pub struct Removed { + inner: Option>, +} + +impl<'a> Removed { + fn new(value: Option) -> Self { + let inner = value.map(|value| value.inner.into_iter()); + Self { inner } + } +} + +impl Iterator for Removed { + type Item = HeaderValue; + + #[inline] + fn next(&mut self) -> Option { + self.inner.as_mut()?.next() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + match self.inner { + Some(ref iter) => iter.size_hint(), + None => (0, None), + } + } +} + +/// Iterator over all [`HeaderName`]s in the map. +#[derive(Debug)] +pub struct Keys<'a>(hash_map::Keys<'a, HeaderName, Value>); + +impl<'a> Iterator for Keys<'a> { + type Item = &'a HeaderName; + + #[inline] + fn next(&mut self) -> Option { + self.0.next() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + +#[derive(Debug)] +pub struct Iter<'a> { + inner: hash_map::Iter<'a, HeaderName, Value>, + multi_inner: Option<(&'a HeaderName, &'a SmallVec<[HeaderValue; 4]>)>, + multi_idx: usize, } impl<'a> Iter<'a> { fn new(iter: hash_map::Iter<'a, HeaderName, Value>) -> Self { Self { - iter, - idx: 0, - current: None, + inner: iter, + multi_idx: 0, + multi_inner: None, } } } @@ -354,28 +682,272 @@ impl<'a> Iter<'a> { impl<'a> Iterator for Iter<'a> { type Item = (&'a HeaderName, &'a HeaderValue); - #[inline] - fn next(&mut self) -> Option<(&'a HeaderName, &'a HeaderValue)> { - if let Some(ref mut item) = self.current { - if self.idx < item.1.len() { - let item = (item.0, &item.1[self.idx]); - self.idx += 1; - return Some(item); - } else { - self.idx = 0; - self.current.take(); - } - } - if let Some(item) = self.iter.next() { - match item.1 { - Value::One(ref value) => Some((item.0, value)), - Value::Multi(ref vec) => { - self.current = Some((item.0, vec)); - self.next() + fn next(&mut self) -> Option { + // handle in-progress multi value lists first + if let Some((ref name, ref mut vals)) = self.multi_inner { + match vals.get(self.multi_idx) { + Some(val) => { + self.multi_idx += 1; + return Some((name, val)); + } + None => { + // no more items in value list; reset state + self.multi_idx = 0; + self.multi_inner = None; } } - } else { - None + } + + let (name, value) = self.inner.next()?; + + // set up new inner iter and recurse into it + self.multi_inner = Some((name, &value.inner)); + self.next() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + // take inner lower bound + // make no attempt at an upper bound + (self.inner.size_hint().0, None) + } +} + +/// Iterator over drained name-value pairs. +/// +/// Iterator items are `(Option, HeaderValue)` to avoid cloning. +#[derive(Debug)] +pub struct Drain<'a> { + inner: hash_map::Drain<'a, HeaderName, Value>, + multi_inner: Option<(Option, SmallVec<[HeaderValue; 4]>)>, + multi_idx: usize, +} + +impl<'a> Drain<'a> { + fn new(iter: hash_map::Drain<'a, HeaderName, Value>) -> Self { + Self { + inner: iter, + multi_inner: None, + multi_idx: 0, } } } + +impl<'a> Iterator for Drain<'a> { + type Item = (Option, HeaderValue); + + fn next(&mut self) -> Option { + // handle in-progress multi value iterators first + if let Some((ref mut name, ref mut vals)) = self.multi_inner { + if !vals.is_empty() { + // OPTIMIZE: array removals + return Some((name.take(), vals.remove(0))); + } else { + // no more items in value iterator; reset state + self.multi_inner = None; + self.multi_idx = 0; + } + } + + let (name, value) = self.inner.next()?; + + // set up new inner iter and recurse into it + self.multi_inner = Some((Some(name), value.inner)); + self.next() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + // take inner lower bound + // make no attempt at an upper bound + (self.inner.size_hint().0, None) + } +} + +/// Iterator over owned name-value pairs. +/// +/// Implementation necessarily clones header names for each value. +#[derive(Debug)] +pub struct IntoIter { + inner: hash_map::IntoIter, + multi_inner: Option<(HeaderName, smallvec::IntoIter<[HeaderValue; 4]>)>, +} + +impl IntoIter { + fn new(inner: hash_map::IntoIter) -> Self { + Self { + inner, + multi_inner: None, + } + } +} + +impl Iterator for IntoIter { + type Item = (HeaderName, HeaderValue); + + fn next(&mut self) -> Option { + // handle in-progress multi value iterators first + if let Some((ref name, ref mut vals)) = self.multi_inner { + match vals.next() { + Some(val) => { + return Some((name.clone(), val)); + } + None => { + // no more items in value iterator; reset state + self.multi_inner = None; + } + } + } + + let (name, value) = self.inner.next()?; + + // set up new inner iter and recurse into it + self.multi_inner = Some((name, value.inner.into_iter())); + self.next() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + // take inner lower bound + // make no attempt at an upper bound + (self.inner.size_hint().0, None) + } +} + +#[cfg(test)] +mod tests { + use http::header; + + use super::*; + + #[test] + fn create() { + let map = HeaderMap::new(); + assert_eq!(map.len(), 0); + assert_eq!(map.capacity(), 0); + + let map = HeaderMap::with_capacity(16); + assert_eq!(map.len(), 0); + assert!(map.capacity() >= 16); + } + + #[test] + fn insert() { + let mut map = HeaderMap::new(); + + map.insert(header::LOCATION, HeaderValue::from_static("/test")); + assert_eq!(map.len(), 1); + } + + #[test] + fn contains() { + let mut map = HeaderMap::new(); + assert!(!map.contains_key(header::LOCATION)); + + map.insert(header::LOCATION, HeaderValue::from_static("/test")); + assert!(map.contains_key(header::LOCATION)); + assert!(map.contains_key("Location")); + assert!(map.contains_key("Location".to_owned())); + assert!(map.contains_key("location")); + } + + #[test] + fn entries_iter() { + let mut map = HeaderMap::new(); + + map.append(header::HOST, HeaderValue::from_static("duck.com")); + map.append(header::COOKIE, HeaderValue::from_static("one=1")); + map.append(header::COOKIE, HeaderValue::from_static("two=2")); + + let mut iter = map.iter(); + assert!(iter.next().is_some()); + assert!(iter.next().is_some()); + assert!(iter.next().is_some()); + assert!(iter.next().is_none()); + + let pairs = map.iter().collect::>(); + assert!(pairs.contains(&(&header::HOST, &HeaderValue::from_static("duck.com")))); + assert!(pairs.contains(&(&header::COOKIE, &HeaderValue::from_static("one=1")))); + assert!(pairs.contains(&(&header::COOKIE, &HeaderValue::from_static("two=2")))); + } + + #[test] + fn drain_iter() { + let mut map = HeaderMap::new(); + + map.append(header::COOKIE, HeaderValue::from_static("one=1")); + map.append(header::COOKIE, HeaderValue::from_static("two=2")); + + let mut vals = vec![]; + let mut iter = map.drain(); + + let (name, val) = iter.next().unwrap(); + assert_eq!(name, Some(header::COOKIE)); + vals.push(val); + + let (name, val) = iter.next().unwrap(); + assert!(name.is_none()); + vals.push(val); + + assert!(vals.contains(&HeaderValue::from_static("one=1"))); + assert!(vals.contains(&HeaderValue::from_static("two=2"))); + + assert!(iter.next().is_none()); + drop(iter); + + assert!(map.is_empty()); + } + + #[test] + fn entries_into_iter() { + let mut map = HeaderMap::new(); + + map.append(header::HOST, HeaderValue::from_static("duck.com")); + map.append(header::COOKIE, HeaderValue::from_static("one=1")); + map.append(header::COOKIE, HeaderValue::from_static("two=2")); + + let mut iter = map.into_iter(); + assert!(iter.next().is_some()); + assert!(iter.next().is_some()); + assert!(iter.next().is_some()); + assert!(iter.next().is_none()); + } + + #[test] + fn iter_and_into_iter_same_order() { + let mut map = HeaderMap::new(); + + map.append(header::HOST, HeaderValue::from_static("duck.com")); + map.append(header::COOKIE, HeaderValue::from_static("one=1")); + map.append(header::COOKIE, HeaderValue::from_static("two=2")); + + let mut iter = map.iter(); + let mut into_iter = map.clone().into_iter(); + + assert_eq!(iter.next().map(owned_pair), into_iter.next()); + assert_eq!(iter.next().map(owned_pair), into_iter.next()); + assert_eq!(iter.next().map(owned_pair), into_iter.next()); + assert_eq!(iter.next().map(owned_pair), into_iter.next()); + } + + #[test] + fn get_all_and_remove_same_order() { + let mut map = HeaderMap::new(); + + map.append(header::COOKIE, HeaderValue::from_static("one=1")); + map.append(header::COOKIE, HeaderValue::from_static("two=2")); + + let mut vals = map.get_all(header::COOKIE); + let mut removed = map.clone().remove(header::COOKIE); + + assert_eq!(vals.next(), removed.next().as_ref()); + assert_eq!(vals.next(), removed.next().as_ref()); + assert_eq!(vals.next(), removed.next().as_ref()); + } + + fn owned_pair<'a>( + (name, val): (&'a HeaderName, &'a HeaderValue), + ) -> (HeaderName, HeaderValue) { + (name.clone(), val.clone()) + } +} diff --git a/actix-http/src/header/mod.rs b/actix-http/src/header/mod.rs index 0f87516eb..1100a959d 100644 --- a/actix-http/src/header/mod.rs +++ b/actix-http/src/header/mod.rs @@ -1,35 +1,39 @@ -//! Various http headers -// This is mostly copy of [hyper](https://github.com/hyperium/hyper/tree/master/src/header) +//! Typed HTTP headers, pre-defined `HeaderName`s, traits for parsing and conversion, and other +//! header utility methods. -use std::convert::TryFrom; -use std::{fmt, str::FromStr}; +use std::fmt; use bytes::{Bytes, BytesMut}; -use http::Error as HttpError; -use mime::Mime; use percent_encoding::{AsciiSet, CONTROLS}; pub use http::header::*; use crate::error::ParseError; -use crate::httpmessage::HttpMessage; +use crate::HttpMessage; + +mod as_name; +mod into_pair; +mod into_value; +mod utils; mod common; pub(crate) mod map; mod shared; + pub use self::common::*; #[doc(hidden)] pub use self::shared::*; +pub use self::as_name::AsHeaderName; +pub use self::into_pair::IntoHeaderPair; +pub use self::into_value::IntoHeaderValue; #[doc(hidden)] pub use self::map::GetAll; pub use self::map::HeaderMap; +pub use self::utils::*; -/// A trait for any object that will represent a header field and value. -pub trait Header -where - Self: IntoHeaderValue, -{ +/// A trait for any object that already represents a valid header field and value. +pub trait Header: IntoHeaderValue { /// Returns the name of the header field fn name() -> HeaderName; @@ -37,170 +41,16 @@ where fn parse(msg: &T) -> Result; } -/// A trait for any object that can be Converted to a `HeaderValue` -pub trait IntoHeaderValue: Sized { - /// The type returned in the event of a conversion error. - type Error: Into; - - /// Try to convert value to a Header value. - fn try_into(self) -> Result; -} - -impl IntoHeaderValue for HeaderValue { - type Error = InvalidHeaderValue; - - #[inline] - fn try_into(self) -> Result { - Ok(self) - } -} - -impl<'a> IntoHeaderValue for &'a str { - type Error = InvalidHeaderValue; - - #[inline] - fn try_into(self) -> Result { - self.parse() - } -} - -impl<'a> IntoHeaderValue for &'a [u8] { - type Error = InvalidHeaderValue; - - #[inline] - fn try_into(self) -> Result { - HeaderValue::from_bytes(self) - } -} - -impl IntoHeaderValue for Bytes { - type Error = InvalidHeaderValue; - - #[inline] - fn try_into(self) -> Result { - HeaderValue::from_maybe_shared(self) - } -} - -impl IntoHeaderValue for Vec { - type Error = InvalidHeaderValue; - - #[inline] - fn try_into(self) -> Result { - HeaderValue::try_from(self) - } -} - -impl IntoHeaderValue for String { - type Error = InvalidHeaderValue; - - #[inline] - fn try_into(self) -> Result { - HeaderValue::try_from(self) - } -} - -impl IntoHeaderValue for usize { - type Error = InvalidHeaderValue; - - #[inline] - fn try_into(self) -> Result { - let s = format!("{}", self); - HeaderValue::try_from(s) - } -} - -impl IntoHeaderValue for u64 { - type Error = InvalidHeaderValue; - - #[inline] - fn try_into(self) -> Result { - let s = format!("{}", self); - HeaderValue::try_from(s) - } -} - -impl IntoHeaderValue for Mime { - type Error = InvalidHeaderValue; - - #[inline] - fn try_into(self) -> Result { - HeaderValue::try_from(format!("{}", self)) - } -} - -/// Represents supported types of content encodings -#[derive(Copy, Clone, PartialEq, Debug)] -pub enum ContentEncoding { - /// Automatically select encoding based on encoding negotiation - Auto, - /// A format using the Brotli algorithm - Br, - /// A format using the zlib structure with deflate algorithm - Deflate, - /// Gzip algorithm - Gzip, - /// Indicates the identity function (i.e. no compression, nor modification) - Identity, -} - -impl ContentEncoding { - #[inline] - /// Is the content compressed? - pub fn is_compression(self) -> bool { - matches!(self, ContentEncoding::Identity | ContentEncoding::Auto) - } - - #[inline] - /// Convert content encoding to string - pub fn as_str(self) -> &'static str { - match self { - ContentEncoding::Br => "br", - ContentEncoding::Gzip => "gzip", - ContentEncoding::Deflate => "deflate", - ContentEncoding::Identity | ContentEncoding::Auto => "identity", - } - } - - #[inline] - /// default quality value - pub fn quality(self) -> f64 { - match self { - ContentEncoding::Br => 1.1, - ContentEncoding::Gzip => 1.0, - ContentEncoding::Deflate => 0.9, - ContentEncoding::Identity | ContentEncoding::Auto => 0.1, - } - } -} - -impl<'a> From<&'a str> for ContentEncoding { - fn from(s: &'a str) -> ContentEncoding { - let s = s.trim(); - - if s.eq_ignore_ascii_case("br") { - ContentEncoding::Br - } else if s.eq_ignore_ascii_case("gzip") { - ContentEncoding::Gzip - } else if s.eq_ignore_ascii_case("deflate") { - ContentEncoding::Deflate - } else { - ContentEncoding::Identity - } - } -} - -#[doc(hidden)] +#[derive(Debug, Default)] pub(crate) struct Writer { buf: BytesMut, } impl Writer { fn new() -> Writer { - Writer { - buf: BytesMut::new(), - } + Writer::default() } + fn take(&mut self) -> Bytes { self.buf.split().freeze() } @@ -219,176 +69,15 @@ impl fmt::Write for Writer { } } -#[inline] -#[doc(hidden)] -/// Reads a comma-delimited raw header into a Vec. -pub fn from_comma_delimited<'a, I: Iterator + 'a, T: FromStr>( - all: I, -) -> Result, ParseError> { - let mut result = Vec::new(); - for h in all { - let s = h.to_str().map_err(|_| ParseError::Header)?; - result.extend( - s.split(',') - .filter_map(|x| match x.trim() { - "" => None, - y => Some(y), - }) - .filter_map(|x| x.trim().parse().ok()), - ) - } - Ok(result) -} - -#[inline] -#[doc(hidden)] -/// Reads a single string when parsing a header. -pub fn from_one_raw_str(val: Option<&HeaderValue>) -> Result { - if let Some(line) = val { - let line = line.to_str().map_err(|_| ParseError::Header)?; - if !line.is_empty() { - return T::from_str(line).or(Err(ParseError::Header)); - } - } - Err(ParseError::Header) -} - -#[inline] -#[doc(hidden)] -/// Format an array into a comma-delimited string. -pub fn fmt_comma_delimited(f: &mut fmt::Formatter<'_>, parts: &[T]) -> fmt::Result -where - T: fmt::Display, -{ - let mut iter = parts.iter(); - if let Some(part) = iter.next() { - fmt::Display::fmt(part, f)?; - } - for part in iter { - f.write_str(", ")?; - fmt::Display::fmt(part, f)?; - } - Ok(()) -} - -// From hyper v0.11.27 src/header/parsing.rs - -/// The value part of an extended parameter consisting of three parts: -/// the REQUIRED character set name (`charset`), the OPTIONAL language information (`language_tag`), -/// and a character sequence representing the actual value (`value`), separated by single quote -/// characters. It is defined in [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). -#[derive(Clone, Debug, PartialEq)] -pub struct ExtendedValue { - /// The character set that is used to encode the `value` to a string. - pub charset: Charset, - /// The human language details of the `value`, if available. - pub language_tag: Option, - /// The parameter value, as expressed in octets. - pub value: Vec, -} - -/// Parses extended header parameter values (`ext-value`), as defined in -/// [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). -/// -/// Extended values are denoted by parameter names that end with `*`. -/// -/// ## ABNF -/// -/// ```text -/// ext-value = charset "'" [ language ] "'" value-chars -/// ; like RFC 2231's -/// ; (see [RFC2231], Section 7) -/// -/// charset = "UTF-8" / "ISO-8859-1" / mime-charset -/// -/// mime-charset = 1*mime-charsetc -/// mime-charsetc = ALPHA / DIGIT -/// / "!" / "#" / "$" / "%" / "&" -/// / "+" / "-" / "^" / "_" / "`" -/// / "{" / "}" / "~" -/// ; as in Section 2.3 of [RFC2978] -/// ; except that the single quote is not included -/// ; SHOULD be registered in the IANA charset registry -/// -/// language = -/// -/// value-chars = *( pct-encoded / attr-char ) -/// -/// pct-encoded = "%" HEXDIG HEXDIG -/// ; see [RFC3986], Section 2.1 -/// -/// attr-char = ALPHA / DIGIT -/// / "!" / "#" / "$" / "&" / "+" / "-" / "." -/// / "^" / "_" / "`" / "|" / "~" -/// ; token except ( "*" / "'" / "%" ) -/// ``` -pub fn parse_extended_value( - val: &str, -) -> Result { - // Break into three pieces separated by the single-quote character - let mut parts = val.splitn(3, '\''); - - // Interpret the first piece as a Charset - let charset: Charset = match parts.next() { - None => return Err(crate::error::ParseError::Header), - Some(n) => FromStr::from_str(n).map_err(|_| crate::error::ParseError::Header)?, - }; - - // Interpret the second piece as a language tag - let language_tag: Option = match parts.next() { - None => return Err(crate::error::ParseError::Header), - Some("") => None, - Some(s) => match s.parse() { - Ok(lt) => Some(lt), - Err(_) => return Err(crate::error::ParseError::Header), - }, - }; - - // Interpret the third piece as a sequence of value characters - let value: Vec = match parts.next() { - None => return Err(crate::error::ParseError::Header), - Some(v) => percent_encoding::percent_decode(v.as_bytes()).collect(), - }; - - Ok(ExtendedValue { - value, - charset, - language_tag, - }) -} - -impl fmt::Display for ExtendedValue { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let encoded_value = - percent_encoding::percent_encode(&self.value[..], HTTP_VALUE); - if let Some(ref lang) = self.language_tag { - write!(f, "{}'{}'{}", self.charset, lang, encoded_value) - } else { - write!(f, "{}''{}", self.charset, encoded_value) - } - } -} - -/// Percent encode a sequence of bytes with a character set defined in -/// -pub fn http_percent_encode(f: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result { - let encoded = percent_encoding::percent_encode(bytes, HTTP_VALUE); - fmt::Display::fmt(&encoded, f) -} - -/// Convert http::HeaderMap to a HeaderMap +/// Convert `http::HeaderMap` to our `HeaderMap`. impl From for HeaderMap { - fn from(map: http::HeaderMap) -> HeaderMap { - let mut new_map = HeaderMap::with_capacity(map.capacity()); - for (h, v) in map.iter() { - new_map.append(h.clone(), v.clone()); - } - new_map + fn from(mut map: http::HeaderMap) -> HeaderMap { + HeaderMap::from_drain(map.drain()) } } -// This encode set is used for HTTP header values and is defined at -// https://tools.ietf.org/html/rfc5987#section-3.2 +/// This encode set is used for HTTP header values and is defined at +/// https://tools.ietf.org/html/rfc5987#section-3.2. pub(crate) const HTTP_VALUE: &AsciiSet = &CONTROLS .add(b' ') .add(b'"') @@ -410,91 +99,3 @@ pub(crate) const HTTP_VALUE: &AsciiSet = &CONTROLS .add(b']') .add(b'{') .add(b'}'); - -#[cfg(test)] -mod tests { - use super::shared::Charset; - use super::{parse_extended_value, ExtendedValue}; - use language_tags::LanguageTag; - - #[test] - fn test_parse_extended_value_with_encoding_and_language_tag() { - let expected_language_tag = "en".parse::().unwrap(); - // RFC 5987, Section 3.2.2 - // Extended notation, using the Unicode character U+00A3 (POUND SIGN) - let result = parse_extended_value("iso-8859-1'en'%A3%20rates"); - assert!(result.is_ok()); - let extended_value = result.unwrap(); - assert_eq!(Charset::Iso_8859_1, extended_value.charset); - assert!(extended_value.language_tag.is_some()); - assert_eq!(expected_language_tag, extended_value.language_tag.unwrap()); - assert_eq!( - vec![163, b' ', b'r', b'a', b't', b'e', b's'], - extended_value.value - ); - } - - #[test] - fn test_parse_extended_value_with_encoding() { - // RFC 5987, Section 3.2.2 - // Extended notation, using the Unicode characters U+00A3 (POUND SIGN) - // and U+20AC (EURO SIGN) - let result = parse_extended_value("UTF-8''%c2%a3%20and%20%e2%82%ac%20rates"); - assert!(result.is_ok()); - let extended_value = result.unwrap(); - assert_eq!(Charset::Ext("UTF-8".to_string()), extended_value.charset); - assert!(extended_value.language_tag.is_none()); - assert_eq!( - vec![ - 194, 163, b' ', b'a', b'n', b'd', b' ', 226, 130, 172, b' ', b'r', b'a', - b't', b'e', b's', - ], - extended_value.value - ); - } - - #[test] - fn test_parse_extended_value_missing_language_tag_and_encoding() { - // From: https://greenbytes.de/tech/tc2231/#attwithfn2231quot2 - let result = parse_extended_value("foo%20bar.html"); - assert!(result.is_err()); - } - - #[test] - fn test_parse_extended_value_partially_formatted() { - let result = parse_extended_value("UTF-8'missing third part"); - assert!(result.is_err()); - } - - #[test] - fn test_parse_extended_value_partially_formatted_blank() { - let result = parse_extended_value("blank second part'"); - assert!(result.is_err()); - } - - #[test] - fn test_fmt_extended_value_with_encoding_and_language_tag() { - let extended_value = ExtendedValue { - charset: Charset::Iso_8859_1, - language_tag: Some("en".parse().expect("Could not parse language tag")), - value: vec![163, b' ', b'r', b'a', b't', b'e', b's'], - }; - assert_eq!("ISO-8859-1'en'%A3%20rates", format!("{}", extended_value)); - } - - #[test] - fn test_fmt_extended_value_with_encoding() { - let extended_value = ExtendedValue { - charset: Charset::Ext("UTF-8".to_string()), - language_tag: None, - value: vec![ - 194, 163, b' ', b'a', b'n', b'd', b' ', 226, 130, 172, b' ', b'r', b'a', - b't', b'e', b's', - ], - }; - assert_eq!( - "UTF-8''%C2%A3%20and%20%E2%82%AC%20rates", - format!("{}", extended_value) - ); - } -} diff --git a/actix-http/src/header/shared/entity.rs b/actix-http/src/header/shared/entity.rs index 344cfb864..eb383cd6f 100644 --- a/actix-http/src/header/shared/entity.rs +++ b/actix-http/src/header/shared/entity.rs @@ -161,7 +161,7 @@ impl FromStr for EntityTag { impl IntoHeaderValue for EntityTag { type Error = InvalidHeaderValue; - fn try_into(self) -> Result { + fn try_into_value(self) -> Result { let mut wrt = Writer::new(); write!(wrt, "{}", self).unwrap(); HeaderValue::from_maybe_shared(wrt.take()) diff --git a/actix-http/src/header/shared/extended.rs b/actix-http/src/header/shared/extended.rs new file mode 100644 index 000000000..6bdcb7922 --- /dev/null +++ b/actix-http/src/header/shared/extended.rs @@ -0,0 +1,193 @@ +use std::{fmt, str::FromStr}; + +use language_tags::LanguageTag; + +use crate::header::{Charset, HTTP_VALUE}; + +// From hyper v0.11.27 src/header/parsing.rs + +/// The value part of an extended parameter consisting of three parts: +/// - The REQUIRED character set name (`charset`). +/// - The OPTIONAL language information (`language_tag`). +/// - A character sequence representing the actual value (`value`), separated by single quotes. +/// +/// It is defined in [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). +#[derive(Clone, Debug, PartialEq)] +pub struct ExtendedValue { + /// The character set that is used to encode the `value` to a string. + pub charset: Charset, + + /// The human language details of the `value`, if available. + pub language_tag: Option, + + /// The parameter value, as expressed in octets. + pub value: Vec, +} + +/// Parses extended header parameter values (`ext-value`), as defined in +/// [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). +/// +/// Extended values are denoted by parameter names that end with `*`. +/// +/// ## ABNF +/// +/// ```text +/// ext-value = charset "'" [ language ] "'" value-chars +/// ; like RFC 2231's +/// ; (see [RFC2231], Section 7) +/// +/// charset = "UTF-8" / "ISO-8859-1" / mime-charset +/// +/// mime-charset = 1*mime-charsetc +/// mime-charsetc = ALPHA / DIGIT +/// / "!" / "#" / "$" / "%" / "&" +/// / "+" / "-" / "^" / "_" / "`" +/// / "{" / "}" / "~" +/// ; as in Section 2.3 of [RFC2978] +/// ; except that the single quote is not included +/// ; SHOULD be registered in the IANA charset registry +/// +/// language = +/// +/// value-chars = *( pct-encoded / attr-char ) +/// +/// pct-encoded = "%" HEXDIG HEXDIG +/// ; see [RFC3986], Section 2.1 +/// +/// attr-char = ALPHA / DIGIT +/// / "!" / "#" / "$" / "&" / "+" / "-" / "." +/// / "^" / "_" / "`" / "|" / "~" +/// ; token except ( "*" / "'" / "%" ) +/// ``` +pub fn parse_extended_value( + val: &str, +) -> Result { + // Break into three pieces separated by the single-quote character + let mut parts = val.splitn(3, '\''); + + // Interpret the first piece as a Charset + let charset: Charset = match parts.next() { + None => return Err(crate::error::ParseError::Header), + Some(n) => FromStr::from_str(n).map_err(|_| crate::error::ParseError::Header)?, + }; + + // Interpret the second piece as a language tag + let language_tag: Option = match parts.next() { + None => return Err(crate::error::ParseError::Header), + Some("") => None, + Some(s) => match s.parse() { + Ok(lt) => Some(lt), + Err(_) => return Err(crate::error::ParseError::Header), + }, + }; + + // Interpret the third piece as a sequence of value characters + let value: Vec = match parts.next() { + None => return Err(crate::error::ParseError::Header), + Some(v) => percent_encoding::percent_decode(v.as_bytes()).collect(), + }; + + Ok(ExtendedValue { + value, + charset, + language_tag, + }) +} + +impl fmt::Display for ExtendedValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let encoded_value = + percent_encoding::percent_encode(&self.value[..], HTTP_VALUE); + if let Some(ref lang) = self.language_tag { + write!(f, "{}'{}'{}", self.charset, lang, encoded_value) + } else { + write!(f, "{}''{}", self.charset, encoded_value) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_extended_value_with_encoding_and_language_tag() { + let expected_language_tag = "en".parse::().unwrap(); + // RFC 5987, Section 3.2.2 + // Extended notation, using the Unicode character U+00A3 (POUND SIGN) + let result = parse_extended_value("iso-8859-1'en'%A3%20rates"); + assert!(result.is_ok()); + let extended_value = result.unwrap(); + assert_eq!(Charset::Iso_8859_1, extended_value.charset); + assert!(extended_value.language_tag.is_some()); + assert_eq!(expected_language_tag, extended_value.language_tag.unwrap()); + assert_eq!( + vec![163, b' ', b'r', b'a', b't', b'e', b's'], + extended_value.value + ); + } + + #[test] + fn test_parse_extended_value_with_encoding() { + // RFC 5987, Section 3.2.2 + // Extended notation, using the Unicode characters U+00A3 (POUND SIGN) + // and U+20AC (EURO SIGN) + let result = parse_extended_value("UTF-8''%c2%a3%20and%20%e2%82%ac%20rates"); + assert!(result.is_ok()); + let extended_value = result.unwrap(); + assert_eq!(Charset::Ext("UTF-8".to_string()), extended_value.charset); + assert!(extended_value.language_tag.is_none()); + assert_eq!( + vec![ + 194, 163, b' ', b'a', b'n', b'd', b' ', 226, 130, 172, b' ', b'r', b'a', + b't', b'e', b's', + ], + extended_value.value + ); + } + + #[test] + fn test_parse_extended_value_missing_language_tag_and_encoding() { + // From: https://greenbytes.de/tech/tc2231/#attwithfn2231quot2 + let result = parse_extended_value("foo%20bar.html"); + assert!(result.is_err()); + } + + #[test] + fn test_parse_extended_value_partially_formatted() { + let result = parse_extended_value("UTF-8'missing third part"); + assert!(result.is_err()); + } + + #[test] + fn test_parse_extended_value_partially_formatted_blank() { + let result = parse_extended_value("blank second part'"); + assert!(result.is_err()); + } + + #[test] + fn test_fmt_extended_value_with_encoding_and_language_tag() { + let extended_value = ExtendedValue { + charset: Charset::Iso_8859_1, + language_tag: Some("en".parse().expect("Could not parse language tag")), + value: vec![163, b' ', b'r', b'a', b't', b'e', b's'], + }; + assert_eq!("ISO-8859-1'en'%A3%20rates", format!("{}", extended_value)); + } + + #[test] + fn test_fmt_extended_value_with_encoding() { + let extended_value = ExtendedValue { + charset: Charset::Ext("UTF-8".to_string()), + language_tag: None, + value: vec![ + 194, 163, b' ', b'a', b'n', b'd', b' ', 226, 130, 172, b' ', b'r', b'a', + b't', b'e', b's', + ], + }; + assert_eq!( + "UTF-8''%C2%A3%20and%20%E2%82%AC%20rates", + format!("{}", extended_value) + ); + } +} diff --git a/actix-http/src/header/shared/httpdate.rs b/actix-http/src/header/shared/httpdate.rs index 81caf6d53..72a225589 100644 --- a/actix-http/src/header/shared/httpdate.rs +++ b/actix-http/src/header/shared/httpdate.rs @@ -3,7 +3,8 @@ use std::io::Write; use std::str::FromStr; use std::time::{SystemTime, UNIX_EPOCH}; -use bytes::{buf::BufMutExt, BytesMut}; +use bytes::buf::BufMut; +use bytes::BytesMut; use http::header::{HeaderValue, InvalidHeaderValue}; use time::{offset, OffsetDateTime, PrimitiveDateTime}; @@ -47,7 +48,7 @@ impl From for HttpDate { impl IntoHeaderValue for HttpDate { type Error = InvalidHeaderValue; - fn try_into(self) -> Result { + fn try_into_value(self) -> Result { let mut wrt = BytesMut::with_capacity(29).writer(); write!( wrt, diff --git a/actix-http/src/header/shared/mod.rs b/actix-http/src/header/shared/mod.rs index f2bc91634..72161e46b 100644 --- a/actix-http/src/header/shared/mod.rs +++ b/actix-http/src/header/shared/mod.rs @@ -1,14 +1,16 @@ -//! Copied for `hyper::header::shared`; - -pub use self::charset::Charset; -pub use self::encoding::Encoding; -pub use self::entity::EntityTag; -pub use self::httpdate::HttpDate; -pub use self::quality_item::{q, qitem, Quality, QualityItem}; -pub use language_tags::LanguageTag; +//! Originally taken from `hyper::header::shared`. mod charset; mod encoding; mod entity; +mod extended; mod httpdate; mod quality_item; + +pub use self::charset::Charset; +pub use self::encoding::Encoding; +pub use self::entity::EntityTag; +pub use self::extended::{parse_extended_value, ExtendedValue}; +pub use self::httpdate::HttpDate; +pub use self::quality_item::{q, qitem, Quality, QualityItem}; +pub use language_tags::LanguageTag; diff --git a/actix-http/src/header/utils.rs b/actix-http/src/header/utils.rs new file mode 100644 index 000000000..e232d462f --- /dev/null +++ b/actix-http/src/header/utils.rs @@ -0,0 +1,63 @@ +use std::{fmt, str::FromStr}; + +use http::HeaderValue; + +use crate::{error::ParseError, header::HTTP_VALUE}; + +/// Reads a comma-delimited raw header into a Vec. +#[inline] +pub fn from_comma_delimited<'a, I, T>(all: I) -> Result, ParseError> +where + I: Iterator + 'a, + T: FromStr, +{ + let mut result = Vec::new(); + for h in all { + let s = h.to_str().map_err(|_| ParseError::Header)?; + result.extend( + s.split(',') + .filter_map(|x| match x.trim() { + "" => None, + y => Some(y), + }) + .filter_map(|x| x.trim().parse().ok()), + ) + } + Ok(result) +} + +/// Reads a single string when parsing a header. +#[inline] +pub fn from_one_raw_str(val: Option<&HeaderValue>) -> Result { + if let Some(line) = val { + let line = line.to_str().map_err(|_| ParseError::Header)?; + if !line.is_empty() { + return T::from_str(line).or(Err(ParseError::Header)); + } + } + Err(ParseError::Header) +} + +/// Format an array into a comma-delimited string. +#[inline] +pub fn fmt_comma_delimited(f: &mut fmt::Formatter<'_>, parts: &[T]) -> fmt::Result +where + T: fmt::Display, +{ + let mut iter = parts.iter(); + if let Some(part) = iter.next() { + fmt::Display::fmt(part, f)?; + } + for part in iter { + f.write_str(", ")?; + fmt::Display::fmt(part, f)?; + } + Ok(()) +} + +/// Percent encode a sequence of bytes with a character set defined in +/// +pub fn http_percent_encode(f: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result { + let encoded = percent_encoding::percent_encode(bytes, HTTP_VALUE); + fmt::Display::fmt(&encoded, f) +} diff --git a/actix-http/src/httpcodes.rs b/actix-http/src/http_codes.rs similarity index 94% rename from actix-http/src/httpcodes.rs rename to actix-http/src/http_codes.rs index f421d8e23..688a08be5 100644 --- a/actix-http/src/httpcodes.rs +++ b/actix-http/src/http_codes.rs @@ -67,6 +67,14 @@ impl Response { static_resp!(ExpectationFailed, StatusCode::EXPECTATION_FAILED); static_resp!(UnprocessableEntity, StatusCode::UNPROCESSABLE_ENTITY); static_resp!(TooManyRequests, StatusCode::TOO_MANY_REQUESTS); + static_resp!( + RequestHeaderFieldsTooLarge, + StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE + ); + static_resp!( + UnavailableForLegalReasons, + StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS + ); static_resp!(InternalServerError, StatusCode::INTERNAL_SERVER_ERROR); static_resp!(NotImplemented, StatusCode::NOT_IMPLEMENTED); diff --git a/actix-http/src/httpmessage.rs b/actix-http/src/http_message.rs similarity index 77% rename from actix-http/src/httpmessage.rs rename to actix-http/src/http_message.rs index 471fbbcdc..b1f04e50d 100644 --- a/actix-http/src/httpmessage.rs +++ b/actix-http/src/http_message.rs @@ -5,15 +5,17 @@ use encoding_rs::{Encoding, UTF_8}; use http::header; use mime::Mime; -use crate::cookie::Cookie; -use crate::error::{ContentTypeError, CookieParseError, ParseError}; +use crate::error::{ContentTypeError, ParseError}; use crate::extensions::Extensions; use crate::header::{Header, HeaderMap}; use crate::payload::Payload; +#[cfg(feature = "cookies")] +use crate::{cookie::Cookie, error::CookieParseError}; +#[cfg(feature = "cookies")] struct Cookies(Vec>); -/// Trait that implements general purpose operations on http messages +/// Trait that implements general purpose operations on HTTP messages. pub trait HttpMessage: Sized { /// Type of message payload stream type Stream; @@ -30,8 +32,8 @@ pub trait HttpMessage: Sized { /// Mutable reference to a the request's extensions container fn extensions_mut(&self) -> RefMut<'_, Extensions>; + /// Get a header. #[doc(hidden)] - /// Get a header fn get_header(&self) -> Option where Self: Sized, @@ -43,8 +45,8 @@ pub trait HttpMessage: Sized { } } - /// Read the request content type. If request does not contain - /// *Content-Type* header, empty str get returned. + /// Read the request content type. If request did not contain a *Content-Type* header, an empty + /// string is returned. fn content_type(&self) -> &str { if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { if let Ok(content_type) = content_type.to_str() { @@ -90,7 +92,7 @@ pub trait HttpMessage: Sized { Ok(None) } - /// Check if request has chunked transfer encoding + /// Check if request has chunked transfer encoding. fn chunked(&self) -> Result { if let Some(encodings) = self.headers().get(header::TRANSFER_ENCODING) { if let Ok(s) = encodings.to_str() { @@ -104,7 +106,7 @@ pub trait HttpMessage: Sized { } /// Load request cookies. - #[inline] + #[cfg(feature = "cookies")] fn cookies(&self) -> Result>>, CookieParseError> { if self.extensions().get::().is_none() { let mut cookies = Vec::new(); @@ -119,12 +121,14 @@ pub trait HttpMessage: Sized { } self.extensions_mut().insert(Cookies(cookies)); } + Ok(Ref::map(self.extensions(), |ext| { &ext.get::().unwrap().0 })) } /// Return request cookie. + #[cfg(feature = "cookies")] fn cookie(&self, name: &str) -> Option> { if let Ok(cookies) = self.cookies() { for cookie in cookies.iter() { @@ -173,11 +177,13 @@ mod tests { #[test] fn test_content_type() { - let req = TestRequest::with_header("content-type", "text/plain").finish(); + let req = TestRequest::default() + .insert_header(("content-type", "text/plain")) + .finish(); assert_eq!(req.content_type(), "text/plain"); - let req = - TestRequest::with_header("content-type", "application/json; charset=utf=8") - .finish(); + let req = TestRequest::default() + .insert_header(("content-type", "application/json; charset=utf=8")) + .finish(); assert_eq!(req.content_type(), "application/json"); let req = TestRequest::default().finish(); assert_eq!(req.content_type(), ""); @@ -185,13 +191,15 @@ mod tests { #[test] fn test_mime_type() { - let req = TestRequest::with_header("content-type", "application/json").finish(); + let req = TestRequest::default() + .insert_header(("content-type", "application/json")) + .finish(); assert_eq!(req.mime_type().unwrap(), Some(mime::APPLICATION_JSON)); let req = TestRequest::default().finish(); assert_eq!(req.mime_type().unwrap(), None); - let req = - TestRequest::with_header("content-type", "application/json; charset=utf-8") - .finish(); + let req = TestRequest::default() + .insert_header(("content-type", "application/json; charset=utf-8")) + .finish(); let mt = req.mime_type().unwrap().unwrap(); assert_eq!(mt.get_param(mime::CHARSET), Some(mime::UTF_8)); assert_eq!(mt.type_(), mime::APPLICATION); @@ -200,11 +208,9 @@ mod tests { #[test] fn test_mime_type_error() { - let req = TestRequest::with_header( - "content-type", - "applicationadfadsfasdflknadsfklnadsfjson", - ) - .finish(); + let req = TestRequest::default() + .insert_header(("content-type", "applicationadfadsfasdflknadsfklnadsfjson")) + .finish(); assert_eq!(Err(ContentTypeError::ParseError), req.mime_type()); } @@ -213,27 +219,27 @@ mod tests { let req = TestRequest::default().finish(); assert_eq!(UTF_8.name(), req.encoding().unwrap().name()); - let req = TestRequest::with_header("content-type", "application/json").finish(); + let req = TestRequest::default() + .insert_header(("content-type", "application/json")) + .finish(); assert_eq!(UTF_8.name(), req.encoding().unwrap().name()); - let req = TestRequest::with_header( - "content-type", - "application/json; charset=ISO-8859-2", - ) - .finish(); + let req = TestRequest::default() + .insert_header(("content-type", "application/json; charset=ISO-8859-2")) + .finish(); assert_eq!(ISO_8859_2, req.encoding().unwrap()); } #[test] fn test_encoding_error() { - let req = TestRequest::with_header("content-type", "applicatjson").finish(); + let req = TestRequest::default() + .insert_header(("content-type", "applicatjson")) + .finish(); assert_eq!(Some(ContentTypeError::ParseError), req.encoding().err()); - let req = TestRequest::with_header( - "content-type", - "application/json; charset=kkkttktk", - ) - .finish(); + let req = TestRequest::default() + .insert_header(("content-type", "application/json; charset=kkkttktk")) + .finish(); assert_eq!( Some(ContentTypeError::UnknownEncoding), req.encoding().err() @@ -245,15 +251,16 @@ mod tests { let req = TestRequest::default().finish(); assert!(!req.chunked().unwrap()); - let req = - TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish(); + let req = TestRequest::default() + .insert_header((header::TRANSFER_ENCODING, "chunked")) + .finish(); assert!(req.chunked().unwrap()); let req = TestRequest::default() - .header( + .insert_header(( header::TRANSFER_ENCODING, Bytes::from_static(b"some va\xadscc\xacas0xsdasdlue"), - ) + )) .finish(); assert!(req.chunked().is_err()); } diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index edf2b98e9..a5454917a 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -1,6 +1,21 @@ //! HTTP primitives for the Actix ecosystem. +//! +//! ## Crate Features +//! | Feature | Functionality | +//! | ---------------- | ----------------------------------------------------- | +//! | `openssl` | TLS support via [OpenSSL]. | +//! | `rustls` | TLS support via [rustls]. | +//! | `compress` | Payload compression support. (Deflate, Gzip & Brotli) | +//! | `cookies` | Support for cookies backed by the [cookie] crate. | +//! | `secure-cookies` | Adds for secure cookies. Enables `cookies` feature. | +//! | `trust-dns` | Use [trust-dns] as the client DNS resolver. | +//! +//! [OpenSSL]: https://crates.io/crates/openssl +//! [rustls]: https://crates.io/crates/rustls +//! [cookie]: https://crates.io/crates/cookie +//! [trust-dns]: https://crates.io/crates/trust-dns -#![deny(rust_2018_idioms)] +#![deny(rust_2018_idioms, nonstandard_style)] #![allow( clippy::type_complexity, clippy::too_many_arguments, @@ -20,15 +35,14 @@ mod macros; pub mod body; mod builder; pub mod client; -mod cloneable; mod config; #[cfg(feature = "compress")] pub mod encoding; mod extensions; mod header; mod helpers; -mod httpcodes; -pub mod httpmessage; +mod http_codes; +mod http_message; mod message; mod payload; mod request; @@ -36,18 +50,20 @@ mod response; mod service; mod time_parser; -pub use cookie; pub mod error; pub mod h1; pub mod h2; pub mod test; pub mod ws; +#[cfg(feature = "cookies")] +pub use cookie; + pub use self::builder::HttpServiceBuilder; pub use self::config::{KeepAlive, ServiceConfig}; pub use self::error::{Error, ResponseError, Result}; pub use self::extensions::Extensions; -pub use self::httpmessage::HttpMessage; +pub use self::http_message::HttpMessage; pub use self::message::{Message, RequestHead, RequestHeadType, ResponseHead}; pub use self::payload::{Payload, PayloadStream}; pub use self::request::Request; @@ -55,7 +71,7 @@ pub use self::response::{Response, ResponseBuilder}; pub use self::service::HttpService; pub mod http { - //! Various HTTP related types + //! Various HTTP related types. // re-exports pub use http::header::{HeaderName, HeaderValue}; @@ -63,10 +79,11 @@ pub mod http { pub use http::{uri, Error, Uri}; pub use http::{Method, StatusCode, Version}; + #[cfg(feature = "cookies")] pub use crate::cookie::{Cookie, CookieBuilder}; pub use crate::header::HeaderMap; - /// Various http headers + /// A collection of HTTP headers and helpers. pub mod header { pub use crate::header::*; } @@ -74,11 +91,49 @@ pub mod http { pub use crate::message::ConnectionType; } -/// Http protocol +/// A major HTTP protocol version. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[non_exhaustive] pub enum Protocol { Http1, Http2, + Http3, } type ConnectCallback = dyn Fn(&IO, &mut Extensions); + +/// Container for data that extract with ConnectCallback. +/// +/// # Implementation Details +/// Uses Option to reduce necessary allocations when merging with request extensions. +pub(crate) struct OnConnectData(Option); + +impl Default for OnConnectData { + fn default() -> Self { + Self(None) + } +} + +impl OnConnectData { + /// Construct by calling the on-connect callback with the underlying transport I/O. + pub(crate) fn from_io( + io: &T, + on_connect_ext: Option<&ConnectCallback>, + ) -> Self { + let ext = on_connect_ext.map(|handler| { + let mut extensions = Extensions::new(); + handler(io, &mut extensions); + extensions + }); + + Self(ext) + } + + /// Merge self into given request's extensions. + #[inline] + pub(crate) fn merge_into(&mut self, req: &mut Request) { + if let Some(ref mut ext) = self.0 { + req.head.extensions.get_mut().drain_from(ext); + } + } +} diff --git a/actix-http/src/message.rs b/actix-http/src/message.rs index 5e53f73b6..6438ccba0 100644 --- a/actix-http/src/message.rs +++ b/actix-http/src/message.rs @@ -3,7 +3,6 @@ use std::net; use std::rc::Rc; use bitflags::bitflags; -use copyless::BoxHelper; use crate::extensions::Extensions; use crate::header::HeaderMap; @@ -14,8 +13,10 @@ use crate::http::{header, Method, StatusCode, Uri, Version}; pub enum ConnectionType { /// Close connection after response Close, + /// Keep connection alive after response KeepAlive, + /// Connection is upgraded to different type Upgrade, } @@ -35,7 +36,9 @@ bitflags! { pub trait Head: Default + 'static { fn clear(&mut self); - fn pool() -> &'static MessagePool; + fn with_pool(f: F) -> R + where + F: FnOnce(&MessagePool) -> R; } #[derive(Debug)] @@ -67,11 +70,14 @@ impl Head for RequestHead { fn clear(&mut self) { self.flags = Flags::empty(); self.headers.clear(); - self.extensions.borrow_mut().clear(); + self.extensions.get_mut().clear(); } - fn pool() -> &'static MessagePool { - REQUEST_POOL.with(|p| *p) + fn with_pool(f: F) -> R + where + F: FnOnce(&MessagePool) -> R, + { + REQUEST_POOL.with(|p| f(p)) } } @@ -339,21 +345,15 @@ impl ResponseHead { } pub struct Message { + // Rc here should not be cloned by anyone. + // It's used to reuse allocation of T and no shared ownership is allowed. head: Rc, } impl Message { /// Get new message from the pool of objects pub fn new() -> Self { - T::pool().get_message() - } -} - -impl Clone for Message { - fn clone(&self) -> Self { - Message { - head: self.head.clone(), - } + T::with_pool(|p| p.get_message()) } } @@ -373,9 +373,7 @@ impl std::ops::DerefMut for Message { impl Drop for Message { fn drop(&mut self) { - if Rc::strong_count(&self.head) == 1 { - T::pool().release(self.head.clone()); - } + T::with_pool(|p| p.release(self.head.clone())) } } @@ -427,22 +425,23 @@ pub struct MessagePool(RefCell>>); /// Request's objects pool pub struct BoxedResponsePool(RefCell>>); -thread_local!(static REQUEST_POOL: &'static MessagePool = MessagePool::::create()); -thread_local!(static RESPONSE_POOL: &'static BoxedResponsePool = BoxedResponsePool::create()); +thread_local!(static REQUEST_POOL: MessagePool = MessagePool::::create()); +thread_local!(static RESPONSE_POOL: BoxedResponsePool = BoxedResponsePool::create()); impl MessagePool { - fn create() -> &'static MessagePool { - let pool = MessagePool(RefCell::new(Vec::with_capacity(128))); - Box::leak(Box::new(pool)) + fn create() -> MessagePool { + MessagePool(RefCell::new(Vec::with_capacity(128))) } /// Get message from the pool #[inline] - fn get_message(&'static self) -> Message { + fn get_message(&self) -> Message { if let Some(mut msg) = self.0.borrow_mut().pop() { - if let Some(r) = Rc::get_mut(&mut msg) { - r.clear(); - } + // Message is put in pool only when it's the last copy. + // which means it's guaranteed to be unique when popped out. + Rc::get_mut(&mut msg) + .expect("Multiple copies exist") + .clear(); Message { head: msg } } else { Message { @@ -462,14 +461,13 @@ impl MessagePool { } impl BoxedResponsePool { - fn create() -> &'static BoxedResponsePool { - let pool = BoxedResponsePool(RefCell::new(Vec::with_capacity(128))); - Box::leak(Box::new(pool)) + fn create() -> BoxedResponsePool { + BoxedResponsePool(RefCell::new(Vec::with_capacity(128))) } /// Get message from the pool #[inline] - fn get_message(&'static self, status: StatusCode) -> BoxedResponseHead { + fn get_message(&self, status: StatusCode) -> BoxedResponseHead { if let Some(mut head) = self.0.borrow_mut().pop() { head.reason = None; head.status = status; @@ -478,17 +476,17 @@ impl BoxedResponsePool { BoxedResponseHead { head: Some(head) } } else { BoxedResponseHead { - head: Some(Box::alloc().init(ResponseHead::new(status))), + head: Some(Box::new(ResponseHead::new(status))), } } } #[inline] /// Release request instance - fn release(&self, msg: Box) { + fn release(&self, mut msg: Box) { let v = &mut self.0.borrow_mut(); if v.len() < 128 { - msg.extensions.borrow_mut().clear(); + msg.extensions.get_mut().clear(); v.push(msg); } } diff --git a/actix-http/src/request.rs b/actix-http/src/request.rs index 64e302441..197ec11c6 100644 --- a/actix-http/src/request.rs +++ b/actix-http/src/request.rs @@ -1,13 +1,17 @@ -use std::cell::{Ref, RefMut}; -use std::{fmt, net}; +//! HTTP requests. + +use std::{ + cell::{Ref, RefMut}, + fmt, net, +}; use http::{header, Method, Uri, Version}; use crate::extensions::Extensions; use crate::header::HeaderMap; -use crate::httpmessage::HttpMessage; use crate::message::{Message, RequestHead}; use crate::payload::{Payload, PayloadStream}; +use crate::HttpMessage; /// Request pub struct Request

{ @@ -23,6 +27,10 @@ impl

HttpMessage for Request

{ &self.head().headers } + fn take_payload(&mut self) -> Payload

{ + std::mem::replace(&mut self.payload, Payload::None) + } + /// Request extensions #[inline] fn extensions(&self) -> Ref<'_, Extensions> { @@ -34,10 +42,6 @@ impl

HttpMessage for Request

{ fn extensions_mut(&self) -> RefMut<'_, Extensions> { self.head.extensions_mut() } - - fn take_payload(&mut self) -> Payload

{ - std::mem::replace(&mut self.payload, Payload::None) - } } impl From> for Request { @@ -103,7 +107,7 @@ impl

Request

{ #[inline] #[doc(hidden)] - /// Mutable reference to a http message part of the request + /// Mutable reference to a HTTP message part of the request pub fn head_mut(&mut self) -> &mut RequestHead { &mut *self.head } @@ -154,10 +158,12 @@ impl

Request

{ self.head().method == Method::CONNECT } - /// Peer socket address + /// Peer socket address. /// - /// Peer address is actual socket address, if proxy is used in front of - /// actix http server, then peer address would be address of this proxy. + /// Peer address is the directly connected peer's socket address. If a proxy is used in front of + /// the Actix Web server, then it would be address of this proxy. + /// + /// Will only return None when called in unit tests. #[inline] pub fn peer_addr(&self) -> Option { self.head().peer_addr @@ -173,13 +179,17 @@ impl

fmt::Debug for Request

Hello!

")) /// ) /// } /// } -/// # fn is_a_variant() -> bool { true } -/// # fn main() {} /// ``` #[derive(Debug, PartialEq)] -pub enum Either { - /// First branch of the type - A(A), - /// Second branch of the type - B(B), +pub enum Either { + /// A value of type `L`. + Left(L), + + /// A value of type `R`. + Right(R), +} + +impl Either, Json> { + pub fn into_inner(self) -> T { + match self { + Either::Left(form) => form.into_inner(), + Either::Right(form) => form.into_inner(), + } + } +} + +impl Either, Form> { + pub fn into_inner(self) -> T { + match self { + Either::Left(form) => form.into_inner(), + Either::Right(form) => form.into_inner(), + } + } +} + +impl From> for Either { + fn from(val: either::Either) -> Self { + match val { + either::Either::Left(l) => Either::Left(l), + either::Either::Right(r) => Either::Right(r), + } + } +} + +impl From> for either::Either { + fn from(val: Either) -> Self { + match val { + Either::Left(l) => either::Either::Left(l), + Either::Right(r) => either::Either::Right(r), + } + } } #[cfg(test)] -impl Either { - pub(self) fn unwrap_left(self) -> A { +impl Either { + pub(self) fn unwrap_left(self) -> L { match self { - Either::A(data) => data, - Either::B(_) => { - panic!("Cannot unwrap left branch. Either contains a right branch.") + Either::Left(data) => data, + Either::Right(_) => { + panic!("Cannot unwrap Left branch. Either contains an `R` type.") } } } - pub(self) fn unwrap_right(self) -> B { + pub(self) fn unwrap_right(self) -> R { match self { - Either::A(_) => { - panic!("Cannot unwrap right branch. Either contains a left branch.") + Either::Left(_) => { + panic!("Cannot unwrap Right branch. Either contains an `L` type.") } - Either::B(data) => data, + Either::Right(data) => data, } } } -impl Responder for Either +/// See [here](#responder) for example of usage as a handler return type. +impl Responder for Either where - A: Responder, - B: Responder, + L: Responder, + R: Responder, { - type Error = Error; - type Future = EitherResponder; - - fn respond_to(self, req: &HttpRequest) -> Self::Future { + fn respond_to(self, req: &HttpRequest) -> HttpResponse { match self { - Either::A(a) => EitherResponder::A(a.respond_to(req)), - Either::B(b) => EitherResponder::B(b.respond_to(req)), + Either::Left(a) => a.respond_to(req), + Either::Right(b) => b.respond_to(req), } } } -#[pin_project(project = EitherResponderProj)] -pub enum EitherResponder -where - A: Responder, - B: Responder, -{ - A(#[pin] A::Future), - B(#[pin] B::Future), -} - -impl Future for EitherResponder -where - A: Responder, - B: Responder, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.project() { - EitherResponderProj::A(fut) => { - Poll::Ready(ready!(fut.poll(cx)).map_err(|e| e.into())) - } - EitherResponderProj::B(fut) => { - Poll::Ready(ready!(fut.poll(cx).map_err(|e| e.into()))) - } - } - } -} - -/// A composite error resulting from failure to extract an `Either`. +/// A composite error resulting from failure to extract an `Either`. /// /// The implementation of `Into` will return the payload buffering error or the /// error from the primary extractor. To access the fallback error, use a match clause. #[derive(Debug)] -pub enum EitherExtractError { +pub enum EitherExtractError { /// Error from payload buffering, such as exceeding payload max size limit. Bytes(Error), /// Error from primary extractor. - Extract(A, B), + Extract(L, R), } -impl Into for EitherExtractError +impl From> for Error where - A: Into, - B: Into, + L: Into, + R: Into, { - fn into(self) -> Error { - match self { + fn from(err: EitherExtractError) -> Error { + match err { EitherExtractError::Bytes(err) => err, EitherExtractError::Extract(a_err, _b_err) => a_err.into(), } } } -/// Provides a mechanism for trying two extractors, a primary and a fallback. Useful for -/// "polymorphic payloads" where, for example, a form might be JSON or URL encoded. -/// -/// It is important to note that this extractor, by necessity, buffers the entire request payload -/// as part of its implementation. Though, it does respect a `PayloadConfig`'s maximum size limit. -impl FromRequest for Either +/// See [here](#extractor) for example of usage as an extractor. +impl FromRequest for Either where - A: FromRequest + 'static, - B: FromRequest + 'static, + L: FromRequest + 'static, + R: FromRequest + 'static, { - type Error = EitherExtractError; + type Error = EitherExtractError; type Future = LocalBoxFuture<'static, Result>; type Config = (); @@ -153,32 +187,32 @@ where Bytes::from_request(req, payload) .map_err(EitherExtractError::Bytes) - .and_then(|bytes| bytes_to_a_or_b(req2, bytes)) + .and_then(|bytes| bytes_to_l_or_r(req2, bytes)) .boxed_local() } } -async fn bytes_to_a_or_b( +async fn bytes_to_l_or_r( req: HttpRequest, bytes: Bytes, -) -> Result, EitherExtractError> +) -> Result, EitherExtractError> where - A: FromRequest + 'static, - B: FromRequest + 'static, + L: FromRequest + 'static, + R: FromRequest + 'static, { let fallback = bytes.clone(); let a_err; let mut pl = payload_from_bytes(bytes); - match A::from_request(&req, &mut pl).await { - Ok(a_data) => return Ok(Either::A(a_data)), + match L::from_request(&req, &mut pl).await { + Ok(a_data) => return Ok(Either::Left(a_data)), // store A's error for returning if B also fails Err(err) => a_err = err, }; let mut pl = payload_from_bytes(fallback); - match B::from_request(&req, &mut pl).await { - Ok(b_data) => return Ok(Either::B(b_data)), + match R::from_request(&req, &mut pl).await { + Ok(b_data) => return Ok(Either::Right(b_data)), Err(b_err) => Err(EitherExtractError::Extract(a_err, b_err)), } } @@ -242,13 +276,12 @@ mod tests { .set_payload(Bytes::from_static(b"!@$%^&*()")) .to_http_parts(); - let payload = - Either::, Json>, Bytes>::from_request( - &req, &mut pl, - ) - .await - .unwrap() - .unwrap_right(); + let payload = Either::, Json>, Bytes>::from_request( + &req, &mut pl, + ) + .await + .unwrap() + .unwrap_right(); assert_eq!(&payload.as_ref(), &b"!@$%^&*()"); } @@ -260,15 +293,14 @@ mod tests { }) .to_http_parts(); - let form = - Either::, Json>, Bytes>::from_request( - &req, &mut pl, - ) - .await - .unwrap() - .unwrap_left() - .unwrap_right() - .into_inner(); + let form = Either::, Json>, Bytes>::from_request( + &req, &mut pl, + ) + .await + .unwrap() + .unwrap_left() + .unwrap_right() + .into_inner(); assert_eq!(&form.hello, "world"); } } diff --git a/src/types/form.rs b/src/types/form.rs index 82ea73216..0b5c3c1b4 100644 --- a/src/types/form.rs +++ b/src/types/form.rs @@ -1,72 +1,68 @@ -//! Form extractor +//! For URL encoded form helper documentation, see [`Form`]. -use std::future::Future; -use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll}; -use std::{fmt, ops}; +use std::{ + fmt, + future::Future, + ops, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; -use actix_http::{Error, HttpMessage, Payload, Response}; +use actix_http::Payload; use bytes::BytesMut; use encoding_rs::{Encoding, UTF_8}; -use futures_util::future::{err, ok, FutureExt, LocalBoxFuture, Ready}; -use futures_util::StreamExt; -use serde::de::DeserializeOwned; -use serde::Serialize; +use futures_util::{ + future::{FutureExt, LocalBoxFuture}, + StreamExt, +}; +use serde::{de::DeserializeOwned, Serialize}; #[cfg(feature = "compress")] use crate::dev::Decompress; -use crate::error::UrlencodedError; -use crate::extract::FromRequest; -use crate::http::{ - header::{ContentType, CONTENT_LENGTH}, - StatusCode, +use crate::{ + error::UrlencodedError, extract::FromRequest, http::header::CONTENT_LENGTH, web, Error, + HttpMessage, HttpRequest, HttpResponse, Responder, }; -use crate::request::HttpRequest; -use crate::{responder::Responder, web}; -/// Form data helper (`application/x-www-form-urlencoded`) +/// URL encoded payload extractor and responder. /// -/// Can be use to extract url-encoded data from the request body, -/// or send url-encoded data as the response. +/// `Form` has two uses: URL encoded responses, and extracting typed data from URL request payloads. /// -/// ## Extract +/// # Extractor +/// To extract typed data from a request body, the inner type `T` must implement the +/// [`serde::Deserialize`] trait. /// -/// To extract typed information from request's body, the type `T` must -/// implement the `Deserialize` trait from *serde*. +/// Use [`FormConfig`] to configure extraction process. /// -/// [**FormConfig**](FormConfig) allows to configure extraction -/// process. -/// -/// ### Example -/// ```rust -/// use actix_web::web; -/// use serde_derive::Deserialize; +/// ``` +/// use actix_web::{post, web}; +/// use serde::Deserialize; /// /// #[derive(Deserialize)] -/// struct FormData { -/// username: String, +/// struct Info { +/// name: String, /// } /// -/// /// Extract form data using serde. -/// /// This handler get called only if content type is *x-www-form-urlencoded* -/// /// and content of the request could be deserialized to a `FormData` struct -/// fn index(form: web::Form) -> String { -/// format!("Welcome {}!", form.username) +/// // This handler is only called if: +/// // - request headers declare the content type as `application/x-www-form-urlencoded` +/// // - request payload is deserialized into a `Info` struct from the URL encoded format +/// #[post("/")] +/// async fn index(form: web::Form) -> String { +/// format!("Welcome {}!", form.name) /// } -/// # fn main() {} /// ``` /// -/// ## Respond +/// # Responder +/// The `Form` type also allows you to create URL encoded responses: +/// simply return a value of type Form where T is the type to be URL encoded. +/// The type must implement [`serde::Serialize`]. /// -/// The `Form` type also allows you to respond with well-formed url-encoded data: -/// simply return a value of type Form where T is the type to be url-encoded. -/// The type must implement `serde::Serialize`; +/// Responses use /// -/// ### Example -/// ```rust -/// use actix_web::*; -/// use serde_derive::Serialize; +/// ``` +/// use actix_web::{get, web}; +/// use serde::Serialize; /// /// #[derive(Serialize)] /// struct SomeForm { @@ -74,22 +70,23 @@ use crate::{responder::Responder, web}; /// age: u8 /// } /// -/// // Will return a 200 response with header -/// // `Content-Type: application/x-www-form-urlencoded` -/// // and body "name=actix&age=123" -/// fn index() -> web::Form { +/// // Response will have: +/// // - status: 200 OK +/// // - header: `Content-Type: application/x-www-form-urlencoded` +/// // - body: `name=actix&age=123` +/// #[get("/")] +/// async fn index() -> web::Form { /// web::Form(SomeForm { /// name: "actix".into(), /// age: 123 /// }) /// } -/// # fn main() {} /// ``` #[derive(PartialEq, Eq, PartialOrd, Ord)] pub struct Form(pub T); impl Form { - /// Deconstruct to an inner value + /// Unwrap into inner `T` value. pub fn into_inner(self) -> T { self.0 } @@ -109,6 +106,7 @@ impl ops::DerefMut for Form { } } +/// See [here](#extractor) for example of usage as an extractor. impl FromRequest for Form where T: DeserializeOwned + 'static, @@ -120,7 +118,7 @@ where #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { let req2 = req.clone(); - let (limit, err) = req + let (limit, err_handler) = req .app_data::() .or_else(|| { req.app_data::>() @@ -132,13 +130,10 @@ where UrlEncoded::new(req, payload) .limit(limit) .map(move |res| match res { - Err(e) => { - if let Some(err) = err { - Err((*err)(e, &req2)) - } else { - Err(e.into()) - } - } + Err(err) => match err_handler { + Some(err_handler) => Err((err_handler)(err, &req2)), + None => Err(err.into()), + }, Ok(item) => Ok(Form(item)), }) .boxed_local() @@ -157,49 +152,39 @@ impl fmt::Display for Form { } } +/// See [here](#responder) for example of usage as a handler return type. impl Responder for Form { - type Error = Error; - type Future = Ready>; - - fn respond_to(self, _: &HttpRequest) -> Self::Future { - let body = match serde_urlencoded::to_string(&self.0) { - Ok(body) => body, - Err(e) => return err(e.into()), - }; - - ok(Response::build(StatusCode::OK) - .set(ContentType::form_url_encoded()) - .body(body)) + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + match serde_urlencoded::to_string(&self.0) { + Ok(body) => HttpResponse::Ok() + .content_type(mime::APPLICATION_WWW_FORM_URLENCODED) + .body(body), + Err(err) => HttpResponse::from_error(err.into()), + } } } -/// Form extractor configuration +/// [`Form`] extractor configuration. /// -/// ```rust -/// use actix_web::{web, App, FromRequest, Result}; -/// use serde_derive::Deserialize; +/// ``` +/// use actix_web::{post, web, App, FromRequest, Result}; +/// use serde::Deserialize; /// /// #[derive(Deserialize)] -/// struct FormData { +/// struct Info { /// username: String, /// } /// -/// /// Extract form data using serde. -/// /// Custom configuration is used for this handler, max payload size is 4k -/// async fn index(form: web::Form) -> Result { +/// // Custom `FormConfig` is applied to App. +/// // Max payload size for URL encoded forms is set to 4kB. +/// #[post("/")] +/// async fn index(form: web::Form) -> Result { /// Ok(format!("Welcome {}!", form.username)) /// } /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/index.html") -/// // change `Form` extractor configuration -/// .app_data( -/// web::FormConfig::default().limit(4097) -/// ) -/// .route(web::get().to(index)) -/// ); -/// } +/// App::new() +/// .app_data(web::FormConfig::default().limit(4096)) +/// .service(index); /// ``` #[derive(Clone)] pub struct FormConfig { @@ -208,7 +193,7 @@ pub struct FormConfig { } impl FormConfig { - /// Change max size of payload. By default max size is 16Kb + /// Set maximum accepted payload size. By default this limit is 16kB. pub fn limit(mut self, limit: usize) -> Self { self.limit = limit; self @@ -233,33 +218,30 @@ impl Default for FormConfig { } } -/// Future that resolves to a parsed urlencoded values. +/// Future that resolves to some `T` when parsed from a URL encoded payload. /// -/// Parse `application/x-www-form-urlencoded` encoded request's body. -/// Return `UrlEncoded` future. Form can be deserialized to any type that -/// implements `Deserialize` trait from *serde*. +/// Form can be deserialized from any type `T` that implements [`serde::Deserialize`]. /// -/// Returns error: -/// -/// * content type is not `application/x-www-form-urlencoded` -/// * content-length is greater than 32k -/// -pub struct UrlEncoded { +/// Returns error if: +/// - content type is not `application/x-www-form-urlencoded` +/// - content length is greater than [limit](UrlEncoded::limit()) +pub struct UrlEncoded { #[cfg(feature = "compress")] stream: Option>, #[cfg(not(feature = "compress"))] stream: Option, + limit: usize, length: Option, encoding: &'static Encoding, err: Option, - fut: Option>>, + fut: Option>>, } #[allow(clippy::borrow_interior_mutable_const)] -impl UrlEncoded { - /// Create a new future to URL encode a request - pub fn new(req: &HttpRequest, payload: &mut Payload) -> UrlEncoded { +impl UrlEncoded { + /// Create a new future to decode a URL encoded request payload. + pub fn new(req: &HttpRequest, payload: &mut Payload) -> Self { // check content type if req.content_type().to_lowercase() != "application/x-www-form-urlencoded" { return Self::err(UrlencodedError::ContentType); @@ -297,29 +279,29 @@ impl UrlEncoded { } } - fn err(e: UrlencodedError) -> Self { + fn err(err: UrlencodedError) -> Self { UrlEncoded { stream: None, limit: 32_768, fut: None, - err: Some(e), + err: Some(err), length: None, encoding: UTF_8, } } - /// Change max size of payload. By default max size is 256Kb + /// Set maximum accepted payload size. The default limit is 256kB. pub fn limit(mut self, limit: usize) -> Self { self.limit = limit; self } } -impl Future for UrlEncoded +impl Future for UrlEncoded where - U: DeserializeOwned + 'static, + T: DeserializeOwned + 'static, { - type Output = Result; + type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if let Some(ref mut fut) = self.fut { @@ -348,6 +330,7 @@ where while let Some(item) = stream.next().await { let chunk = item?; + if (body.len() + chunk.len()) > limit { return Err(UrlencodedError::Overflow { size: body.len() + chunk.len(), @@ -359,19 +342,19 @@ where } if encoding == UTF_8 { - serde_urlencoded::from_bytes::(&body) - .map_err(|_| UrlencodedError::Parse) + serde_urlencoded::from_bytes::(&body).map_err(|_| UrlencodedError::Parse) } else { let body = encoding .decode_without_bom_handling_and_without_replacement(&body) .map(|s| s.into_owned()) .ok_or(UrlencodedError::Parse)?; - serde_urlencoded::from_str::(&body) - .map_err(|_| UrlencodedError::Parse) + + serde_urlencoded::from_str::(&body).map_err(|_| UrlencodedError::Parse) } } .boxed_local(), ); + self.poll(cx) } } @@ -382,7 +365,10 @@ mod tests { use serde::{Deserialize, Serialize}; use super::*; - use crate::http::header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE}; + use crate::http::{ + header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE}, + StatusCode, + }; use crate::test::TestRequest; #[derive(Deserialize, Serialize, Debug, PartialEq)] @@ -393,11 +379,11 @@ mod tests { #[actix_rt::test] async fn test_form() { - let (req, mut pl) = - TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(CONTENT_LENGTH, "11") - .set_payload(Bytes::from_static(b"hello=world&counter=123")) - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header((CONTENT_TYPE, "application/x-www-form-urlencoded")) + .insert_header((CONTENT_LENGTH, 11)) + .set_payload(Bytes::from_static(b"hello=world&counter=123")) + .to_http_parts(); let Form(s) = Form::::from_request(&req, &mut pl).await.unwrap(); assert_eq!( @@ -426,25 +412,26 @@ mod tests { #[actix_rt::test] async fn test_urlencoded_error() { - let (req, mut pl) = - TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(CONTENT_LENGTH, "xxxx") - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header((CONTENT_TYPE, "application/x-www-form-urlencoded")) + .insert_header((CONTENT_LENGTH, "xxxx")) + .to_http_parts(); let info = UrlEncoded::::new(&req, &mut pl).await; assert!(eq(info.err().unwrap(), UrlencodedError::UnknownLength)); - let (req, mut pl) = - TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(CONTENT_LENGTH, "1000000") - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header((CONTENT_TYPE, "application/x-www-form-urlencoded")) + .insert_header((CONTENT_LENGTH, "1000000")) + .to_http_parts(); let info = UrlEncoded::::new(&req, &mut pl).await; assert!(eq( info.err().unwrap(), UrlencodedError::Overflow { size: 0, limit: 0 } )); - let (req, mut pl) = TestRequest::with_header(CONTENT_TYPE, "text/plain") - .header(CONTENT_LENGTH, "10") + let (req, mut pl) = TestRequest::default() + .insert_header((CONTENT_TYPE, "text/plain")) + .insert_header((CONTENT_LENGTH, 10)) .to_http_parts(); let info = UrlEncoded::::new(&req, &mut pl).await; assert!(eq(info.err().unwrap(), UrlencodedError::ContentType)); @@ -452,11 +439,11 @@ mod tests { #[actix_rt::test] async fn test_urlencoded() { - let (req, mut pl) = - TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(CONTENT_LENGTH, "11") - .set_payload(Bytes::from_static(b"hello=world&counter=123")) - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header((CONTENT_TYPE, "application/x-www-form-urlencoded")) + .insert_header((CONTENT_LENGTH, 11)) + .set_payload(Bytes::from_static(b"hello=world&counter=123")) + .to_http_parts(); let info = UrlEncoded::::new(&req, &mut pl).await.unwrap(); assert_eq!( @@ -467,13 +454,14 @@ mod tests { } ); - let (req, mut pl) = TestRequest::with_header( - CONTENT_TYPE, - "application/x-www-form-urlencoded; charset=utf-8", - ) - .header(CONTENT_LENGTH, "11") - .set_payload(Bytes::from_static(b"hello=world&counter=123")) - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header(( + CONTENT_TYPE, + "application/x-www-form-urlencoded; charset=utf-8", + )) + .insert_header((CONTENT_LENGTH, 11)) + .set_payload(Bytes::from_static(b"hello=world&counter=123")) + .to_http_parts(); let info = UrlEncoded::::new(&req, &mut pl).await.unwrap(); assert_eq!( @@ -493,7 +481,7 @@ mod tests { hello: "world".to_string(), counter: 123, }); - let resp = form.respond_to(&req).await.unwrap(); + let resp = form.respond_to(&req); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), @@ -509,8 +497,8 @@ mod tests { let ctype = HeaderValue::from_static("application/x-www-form-urlencoded"); let (req, mut pl) = TestRequest::default() - .header(CONTENT_TYPE, ctype) - .header(CONTENT_LENGTH, HeaderValue::from_static("20")) + .insert_header((CONTENT_TYPE, ctype)) + .insert_header((CONTENT_LENGTH, HeaderValue::from_static("20"))) .set_payload(Bytes::from_static(b"hello=test&counter=4")) .app_data(web::Data::new(FormConfig::default().limit(10))) .to_http_parts(); @@ -519,6 +507,6 @@ mod tests { assert!(s.is_err()); let err_str = s.err().unwrap().to_string(); - assert!(err_str.contains("Urlencoded payload size is bigger")); + assert!(err_str.starts_with("URL encoded payload is larger")); } } diff --git a/src/types/json.rs b/src/types/json.rs index dc0870a6e..31ff680f4 100644 --- a/src/types/json.rs +++ b/src/types/json.rs @@ -1,46 +1,44 @@ -//! Json extractor/responder +//! For JSON helper documentation, see [`Json`]. -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::{fmt, ops}; +use std::{ + fmt, + future::Future, + marker::PhantomData, + ops, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; use bytes::BytesMut; -use futures_util::future::{ready, Ready}; -use futures_util::ready; -use futures_util::stream::Stream; -use serde::de::DeserializeOwned; -use serde::Serialize; +use futures_util::{ready, stream::Stream}; +use serde::{de::DeserializeOwned, Serialize}; -use actix_http::http::{header::CONTENT_LENGTH, StatusCode}; -use actix_http::{HttpMessage, Payload, Response}; +use actix_http::Payload; #[cfg(feature = "compress")] use crate::dev::Decompress; -use crate::error::{Error, JsonPayloadError}; -use crate::extract::FromRequest; -use crate::request::HttpRequest; -use crate::{responder::Responder, web}; +use crate::{ + error::{Error, JsonPayloadError}, + extract::FromRequest, + http::header::CONTENT_LENGTH, + request::HttpRequest, + web, HttpMessage, HttpResponse, Responder, +}; -/// Json helper +/// JSON extractor and responder. /// -/// Json can be used for two different purpose. First is for json response -/// generation and second is for extracting typed information from request's -/// payload. +/// `Json` has two uses: JSON responses, and extracting typed data from JSON request payloads. /// -/// To extract typed information from request's body, the type `T` must -/// implement the `Deserialize` trait from *serde*. +/// # Extractor +/// To extract typed data from a request body, the inner type `T` must implement the +/// [`serde::Deserialize`] trait. /// -/// [**JsonConfig**](JsonConfig) allows to configure extraction -/// process. +/// Use [`JsonConfig`] to configure extraction process. /// -/// ## Example -/// -/// ```rust -/// use actix_web::{web, App}; -/// use serde_derive::Deserialize; +/// ``` +/// use actix_web::{post, web, App}; +/// use serde::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -48,43 +46,37 @@ use crate::{responder::Responder, web}; /// } /// /// /// deserialize `Info` from request's body +/// #[post("/")] /// async fn index(info: web::Json) -> String { /// format!("Welcome {}!", info.username) /// } -/// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/index.html").route( -/// web::post().to(index)) -/// ); -/// } /// ``` /// -/// The `Json` type allows you to respond with well-formed JSON data: simply -/// return a value of type Json where T is the type of a structure -/// to serialize into *JSON*. The type `T` must implement the `Serialize` -/// trait from *serde*. +/// # Responder +/// The `Json` type JSON formatted responses. A handler may return a value of type +/// `Json` where `T` is the type of a structure to serialize into JSON. The type `T` must +/// implement [`serde::Serialize`]. /// -/// ```rust -/// use actix_web::*; -/// use serde_derive::Serialize; +/// ``` +/// use actix_web::{post, web, HttpRequest}; +/// use serde::Serialize; /// /// #[derive(Serialize)] -/// struct MyObj { +/// struct Info { /// name: String, /// } /// -/// fn index(req: HttpRequest) -> Result> { -/// Ok(web::Json(MyObj { -/// name: req.match_info().get("name").unwrap().to_string(), -/// })) +/// #[post("/{name}")] +/// async fn index(req: HttpRequest) -> web::Json { +/// web::Json(Info { +/// name: req.match_info().get("name").unwrap().to_owned(), +/// }) /// } -/// # fn main() {} /// ``` pub struct Json(pub T); impl Json { - /// Deconstruct to an inner value + /// Unwrap into inner `T` value. pub fn into_inner(self) -> T { self.0 } @@ -122,54 +114,21 @@ where } } +/// Creates response with OK status code, correct content type header, and serialized JSON payload. +/// +/// If serialization failed impl Responder for Json { - type Error = Error; - type Future = Ready>; - - fn respond_to(self, _: &HttpRequest) -> Self::Future { - let body = match serde_json::to_string(&self.0) { - Ok(body) => body, - Err(e) => return ready(Err(e.into())), - }; - - ready(Ok(Response::build(StatusCode::OK) - .content_type("application/json") - .body(body))) + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + match serde_json::to_string(&self.0) { + Ok(body) => HttpResponse::Ok() + .content_type(mime::APPLICATION_JSON) + .body(body), + Err(err) => HttpResponse::from_error(err.into()), + } } } -/// Json extractor. Allow to extract typed information from request's -/// payload. -/// -/// To extract typed information from request's body, the type `T` must -/// implement the `Deserialize` trait from *serde*. -/// -/// [**JsonConfig**](JsonConfig) allows to configure extraction -/// process. -/// -/// ## Example -/// -/// ```rust -/// use actix_web::{web, App}; -/// use serde_derive::Deserialize; -/// -/// #[derive(Deserialize)] -/// struct Info { -/// username: String, -/// } -/// -/// /// deserialize `Info` from request's body -/// async fn index(info: web::Json) -> String { -/// format!("Welcome {}!", info.username) -/// } -/// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/index.html").route( -/// web::post().to(index)) -/// ); -/// } -/// ``` +/// See [here](#extractor) for example of usage as an extractor. impl FromRequest for Json where T: DeserializeOwned + 'static, @@ -215,7 +174,7 @@ where let res = ready!(Pin::new(&mut this.fut).poll(cx)); let res = match res { - Err(e) => { + Err(err) => { let req = this.req.take().unwrap(); log::debug!( "Failed to deserialize Json from payload. \ @@ -223,10 +182,10 @@ where req.path() ); - if let Some(err) = this.err_handler.as_ref() { - Err((*err)(e, &req)) + if let Some(err_handler) = this.err_handler.as_ref() { + Err((*err_handler)(err, &req)) } else { - Err(e.into()) + Err(err.into()) } } Ok(data) => Ok(Json(data)), @@ -236,44 +195,39 @@ where } } -/// Json extractor configuration +/// `Json` extractor configuration. /// -/// # Example -/// -/// ```rust -/// use actix_web::{error, web, App, FromRequest, HttpResponse}; -/// use serde_derive::Deserialize; +/// # Examples +/// ``` +/// use actix_web::{error, post, web, App, FromRequest, HttpResponse}; +/// use serde::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { -/// username: String, +/// name: String, /// } /// -/// /// deserialize `Info` from request's body, max payload size is 4kb +/// // `Json` extraction is bound by custom `JsonConfig` applied to App. +/// #[post("/")] /// async fn index(info: web::Json) -> String { -/// format!("Welcome {}!", info.username) +/// format!("Welcome {}!", info.name) /// } /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/index.html") -/// .app_data( -/// // Json extractor configuration for this resource. -/// web::JsonConfig::default() -/// .limit(4096) // Limit request payload size -/// .content_type(|mime| { // <- accept text/plain content type -/// mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN -/// }) -/// .error_handler(|err, req| { // <- create custom error response -/// error::InternalError::from_response( -/// err, HttpResponse::Conflict().finish()).into() -/// }) -/// ) -/// .route(web::post().to(index)) -/// ); -/// } +/// // custom `Json` extractor configuration +/// let json_cfg = web::JsonConfig::default() +/// // limit request payload size +/// .limit(4096) +/// // only accept text/plain content type +/// .content_type(|mime| mime == mime::TEXT_PLAIN) +/// // use custom error handler +/// .error_handler(|err, req| { +/// error::InternalError::from_response(err, HttpResponse::Conflict().finish()).into() +/// }); +/// +/// App::new() +/// .app_data(json_cfg) +/// .service(index); /// ``` -/// #[derive(Clone)] pub struct JsonConfig { limit: usize, @@ -282,13 +236,13 @@ pub struct JsonConfig { } impl JsonConfig { - /// Change max size of payload. By default max size is 32Kb + /// Set maximum accepted payload size. By default this limit is 32kB. pub fn limit(mut self, limit: usize) -> Self { self.limit = limit; self } - /// Set custom error handler + /// Set custom error handler. pub fn error_handler(mut self, f: F) -> Self where F: Fn(JsonPayloadError, &HttpRequest) -> Error + Send + Sync + 'static, @@ -297,7 +251,7 @@ impl JsonConfig { self } - /// Set predicate for allowed content types + /// Set predicate for allowed content types. pub fn content_type(mut self, predicate: F) -> Self where F: Fn(mime::Mime) -> bool + Send + Sync + 'static, @@ -315,7 +269,7 @@ impl JsonConfig { } } -// Allow shared refs to default. +/// Allow shared refs used as default. const DEFAULT_CONFIG: JsonConfig = JsonConfig { limit: 32_768, // 2^15 bytes, (~32kB) err_handler: None, @@ -328,15 +282,14 @@ impl Default for JsonConfig { } } -/// Request's payload json parser, it resolves to a deserialized `T` value. -/// This future could be used with `ServiceRequest` and `ServiceFromRequest`. +/// Future that resolves to some `T` when parsed from a JSON payload. /// -/// Returns error: +/// Form can be deserialized from any type `T` that implements [`serde::Deserialize`]. /// -/// * content type is not `application/json` -/// (unless specified in [`JsonConfig`]) -/// * content length is greater than 256k -pub enum JsonBody { +/// Returns error if: +/// - content type is not `application/json` +/// - content length is greater than [limit](JsonBody::limit()) +pub enum JsonBody { Error(Option), Body { limit: usize, @@ -346,17 +299,17 @@ pub enum JsonBody { #[cfg(not(feature = "compress"))] payload: Payload, buf: BytesMut, - _res: PhantomData, + _res: PhantomData, }, } -impl Unpin for JsonBody {} +impl Unpin for JsonBody {} -impl JsonBody +impl JsonBody where - U: DeserializeOwned + 'static, + T: DeserializeOwned + 'static, { - /// Create `JsonBody` for request. + /// Create a new future to decode a JSON request payload. #[allow(clippy::borrow_interior_mutable_const)] pub fn new( req: &HttpRequest, @@ -367,7 +320,7 @@ where let json = if let Ok(Some(mime)) = req.mime_type() { mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) - || ctype.as_ref().map_or(false, |predicate| predicate(mime)) + || ctype.map_or(false, |predicate| predicate(mime)) } else { false }; @@ -392,7 +345,7 @@ where let payload = payload.take(); JsonBody::Body { - limit: 262_144, + limit: 32_768, length, payload, buf: BytesMut::with_capacity(8192), @@ -400,7 +353,7 @@ where } } - /// Change max size of payload. By default max size is 256Kb + /// Set maximum accepted payload size. The default limit is 32kB. pub fn limit(self, limit: usize) -> Self { match self { JsonBody::Body { @@ -428,11 +381,11 @@ where } } -impl Future for JsonBody +impl Future for JsonBody where - U: DeserializeOwned + 'static, + T: DeserializeOwned + 'static, { - type Output = Result; + type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); @@ -455,7 +408,7 @@ where } } None => { - let json = serde_json::from_slice::(&buf)?; + let json = serde_json::from_slice::(&buf)?; return Poll::Ready(Ok(json)); } } @@ -468,13 +421,17 @@ where #[cfg(test)] mod tests { use bytes::Bytes; - use serde_derive::{Deserialize, Serialize}; + use serde::{Deserialize, Serialize}; use super::*; - use crate::error::InternalError; - use crate::http::header::{self, HeaderValue, CONTENT_LENGTH, CONTENT_TYPE}; - use crate::test::{load_stream, TestRequest}; - use crate::HttpResponse; + use crate::{ + error::InternalError, + http::{ + header::{self, CONTENT_LENGTH, CONTENT_TYPE}, + StatusCode, + }, + test::{load_stream, TestRequest}, + }; #[derive(Serialize, Deserialize, PartialEq, Debug)] struct MyObject { @@ -498,7 +455,7 @@ mod tests { let j = Json(MyObject { name: "test".to_string(), }); - let resp = j.respond_to(&req).await.unwrap(); + let resp = j.respond_to(&req); assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -512,27 +469,27 @@ mod tests { #[actix_rt::test] async fn test_custom_error_responder() { let (req, mut pl) = TestRequest::default() - .header( + .insert_header(( header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"), - ) - .header( + )) + .insert_header(( header::CONTENT_LENGTH, header::HeaderValue::from_static("16"), - ) + )) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .app_data(JsonConfig::default().limit(10).error_handler(|err, _| { let msg = MyObject { name: "invalid request".to_string(), }; - let resp = HttpResponse::BadRequest() - .body(serde_json::to_string(&msg).unwrap()); + let resp = + HttpResponse::BadRequest().body(serde_json::to_string(&msg).unwrap()); InternalError::from_response(err, resp).into() })) .to_http_parts(); let s = Json::::from_request(&req, &mut pl).await; - let mut resp = Response::from_error(s.err().unwrap()); + let mut resp = HttpResponse::from_error(s.err().unwrap()); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let body = load_stream(resp.take_body()).await.unwrap(); @@ -543,14 +500,14 @@ mod tests { #[actix_rt::test] async fn test_extract() { let (req, mut pl) = TestRequest::default() - .header( + .insert_header(( header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"), - ) - .header( + )) + .insert_header(( header::CONTENT_LENGTH, header::HeaderValue::from_static("16"), - ) + )) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .to_http_parts(); @@ -564,14 +521,14 @@ mod tests { ); let (req, mut pl) = TestRequest::default() - .header( + .insert_header(( header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"), - ) - .header( + )) + .insert_header(( header::CONTENT_LENGTH, header::HeaderValue::from_static("16"), - ) + )) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .app_data(JsonConfig::default().limit(10)) .to_http_parts(); @@ -581,14 +538,14 @@ mod tests { .contains("Json payload size is bigger than allowed")); let (req, mut pl) = TestRequest::default() - .header( + .insert_header(( header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"), - ) - .header( + )) + .insert_header(( header::CONTENT_LENGTH, header::HeaderValue::from_static("16"), - ) + )) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .app_data( JsonConfig::default() @@ -607,23 +564,23 @@ mod tests { assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let (req, mut pl) = TestRequest::default() - .header( + .insert_header(( header::CONTENT_TYPE, header::HeaderValue::from_static("application/text"), - ) + )) .to_http_parts(); let json = JsonBody::::new(&req, &mut pl, None).await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let (req, mut pl) = TestRequest::default() - .header( + .insert_header(( header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"), - ) - .header( + )) + .insert_header(( header::CONTENT_LENGTH, header::HeaderValue::from_static("10000"), - ) + )) .to_http_parts(); let json = JsonBody::::new(&req, &mut pl, None) @@ -632,14 +589,14 @@ mod tests { assert!(json_eq(json.err().unwrap(), JsonPayloadError::Overflow)); let (req, mut pl) = TestRequest::default() - .header( + .insert_header(( header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"), - ) - .header( + )) + .insert_header(( header::CONTENT_LENGTH, header::HeaderValue::from_static("16"), - ) + )) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .to_http_parts(); @@ -654,17 +611,18 @@ mod tests { #[actix_rt::test] async fn test_with_json_and_bad_content_type() { - let (req, mut pl) = TestRequest::with_header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("text/plain"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .app_data(JsonConfig::default().limit(4096)) - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header(( + header::CONTENT_TYPE, + header::HeaderValue::from_static("text/plain"), + )) + .insert_header(( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + )) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .app_data(JsonConfig::default().limit(4096)) + .to_http_parts(); let s = Json::::from_request(&req, &mut pl).await; assert!(s.is_err()) @@ -672,19 +630,20 @@ mod tests { #[actix_rt::test] async fn test_with_json_and_good_custom_content_type() { - let (req, mut pl) = TestRequest::with_header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("text/plain"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .app_data(JsonConfig::default().content_type(|mime: mime::Mime| { - mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN - })) - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header(( + header::CONTENT_TYPE, + header::HeaderValue::from_static("text/plain"), + )) + .insert_header(( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + )) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .app_data(JsonConfig::default().content_type(|mime: mime::Mime| { + mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN + })) + .to_http_parts(); let s = Json::::from_request(&req, &mut pl).await; assert!(s.is_ok()) @@ -692,19 +651,20 @@ mod tests { #[actix_rt::test] async fn test_with_json_and_bad_custom_content_type() { - let (req, mut pl) = TestRequest::with_header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("text/html"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .app_data(JsonConfig::default().content_type(|mime: mime::Mime| { - mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN - })) - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header(( + header::CONTENT_TYPE, + header::HeaderValue::from_static("text/html"), + )) + .insert_header(( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + )) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .app_data(JsonConfig::default().content_type(|mime: mime::Mime| { + mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN + })) + .to_http_parts(); let s = Json::::from_request(&req, &mut pl).await; assert!(s.is_err()) @@ -713,8 +673,8 @@ mod tests { #[actix_rt::test] async fn test_with_config_in_data_wrapper() { let (req, mut pl) = TestRequest::default() - .header(CONTENT_TYPE, HeaderValue::from_static("application/json")) - .header(CONTENT_LENGTH, HeaderValue::from_static("16")) + .insert_header((CONTENT_TYPE, mime::APPLICATION_JSON)) + .insert_header((CONTENT_LENGTH, 16)) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .app_data(web::Data::new(JsonConfig::default().limit(10))) .to_http_parts(); diff --git a/src/types/mod.rs b/src/types/mod.rs index cedf86dd2..a062c351e 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,5 +1,6 @@ -//! Helper types +//! Common extractors and responders. +// TODO: review visibility mod either; pub(crate) mod form; pub(crate) mod json; diff --git a/src/types/path.rs b/src/types/path.rs index 640ff4346..4ab124d53 100644 --- a/src/types/path.rs +++ b/src/types/path.rs @@ -1,70 +1,55 @@ -//! Path extractor -use std::sync::Arc; -use std::{fmt, ops}; +//! For path segment extractor documentation, see [`Path`]. + +use std::{fmt, ops, sync::Arc}; use actix_http::error::{Error, ErrorNotFound}; use actix_router::PathDeserializer; use futures_util::future::{ready, Ready}; use serde::de; -use crate::dev::Payload; -use crate::error::PathError; -use crate::request::HttpRequest; -use crate::FromRequest; +use crate::{dev::Payload, error::PathError, FromRequest, HttpRequest}; -#[derive(PartialEq, Eq, PartialOrd, Ord)] -/// Extract typed information from the request's path. +/// Extract typed data from request path segments. /// -/// [**PathConfig**](PathConfig) allows to configure extraction process. +/// Use [`PathConfig`] to configure extraction process. /// -/// ## Example +/// # Examples +/// ``` +/// use actix_web::{get, web}; /// -/// ```rust -/// use actix_web::{web, App}; -/// -/// /// extract path info from "/{username}/{count}/index.html" url -/// /// {username} - deserializes to a String -/// /// {count} - - deserializes to a u32 -/// async fn index(web::Path((username, count)): web::Path<(String, u32)>) -> String { -/// format!("Welcome {}! {}", username, count) -/// } -/// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/{username}/{count}/index.html") // <- define path parameters -/// .route(web::get().to(index)) // <- register handler with `Path` extractor -/// ); +/// // extract path info from "/{name}/{count}/index.html" into tuple +/// // {name} - deserialize a String +/// // {count} - deserialize a u32 +/// #[get("/")] +/// async fn index(path: web::Path<(String, u32)>) -> String { +/// let (name, count) = path.into_inner(); +/// format!("Welcome {}! {}", name, count) /// } /// ``` /// -/// It is possible to extract path information to a specific type that -/// implements `Deserialize` trait from *serde*. +/// Path segments also can be deserialized into any type that implements [`serde::Deserialize`]. +/// Path segment labels will be matched with struct field names. /// -/// ```rust -/// use actix_web::{web, App, Error}; -/// use serde_derive::Deserialize; +/// ``` +/// use actix_web::{get, web}; +/// use serde::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { -/// username: String, +/// name: String, /// } /// -/// /// extract `Info` from a path using serde -/// async fn index(info: web::Path) -> Result { -/// Ok(format!("Welcome {}!", info.username)) -/// } -/// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/{username}/index.html") // <- define path parameters -/// .route(web::get().to(index)) // <- use handler with Path` extractor -/// ); +/// // extract `Info` from a path using serde +/// #[get("/")] +/// async fn index(info: web::Path) -> String { +/// format!("Welcome {}!", info.name) /// } /// ``` -pub struct Path(pub T); +#[derive(PartialEq, Eq, PartialOrd, Ord)] +pub struct Path(T); impl Path { - /// Deconstruct to an inner value + /// Unwrap into inner `T` value. pub fn into_inner(self) -> T { self.0 } @@ -108,52 +93,7 @@ impl fmt::Display for Path { } } -/// Extract typed information from the request's path. -/// -/// ## Example -/// -/// ```rust -/// use actix_web::{web, App}; -/// -/// /// extract path info from "/{username}/{count}/index.html" url -/// /// {username} - deserializes to a String -/// /// {count} - - deserializes to a u32 -/// async fn index(web::Path((username, count)): web::Path<(String, u32)>) -> String { -/// format!("Welcome {}! {}", username, count) -/// } -/// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/{username}/{count}/index.html") // <- define path parameters -/// .route(web::get().to(index)) // <- register handler with `Path` extractor -/// ); -/// } -/// ``` -/// -/// It is possible to extract path information to a specific type that -/// implements `Deserialize` trait from *serde*. -/// -/// ```rust -/// use actix_web::{web, App, Error}; -/// use serde_derive::Deserialize; -/// -/// #[derive(Deserialize)] -/// struct Info { -/// username: String, -/// } -/// -/// /// extract `Info` from a path using serde -/// async fn index(info: web::Path) -> Result { -/// Ok(format!("Welcome {}!", info.username)) -/// } -/// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/{username}/index.html") // <- define path parameters -/// .route(web::get().to(index)) // <- use handler with Path` extractor -/// ); -/// } -/// ``` +/// See [here](#usage) for example of usage as an extractor. impl FromRequest for Path where T: de::DeserializeOwned, @@ -191,10 +131,10 @@ where /// Path extractor configuration /// -/// ```rust +/// ``` /// use actix_web::web::PathConfig; /// use actix_web::{error, web, App, FromRequest, HttpResponse}; -/// use serde_derive::Deserialize; +/// use serde::Deserialize; /// /// #[derive(Deserialize, Debug)] /// enum Folder { @@ -249,7 +189,7 @@ impl Default for PathConfig { mod tests { use actix_router::ResourceDef; use derive_more::Display; - use serde_derive::Deserialize; + use serde::Deserialize; use super::*; use crate::test::TestRequest; @@ -295,11 +235,9 @@ mod tests { assert_eq!(res.1, "user1"); let (Path(a), Path(b)) = - <(Path<(String, String)>, Path<(String, String)>)>::from_request( - &req, &mut pl, - ) - .await - .unwrap(); + <(Path<(String, String)>, Path<(String, String)>)>::from_request(&req, &mut pl) + .await + .unwrap(); assert_eq!(a.0, "name"); assert_eq!(a.1, "user1"); assert_eq!(b.0, "name"); @@ -360,11 +298,8 @@ mod tests { async fn test_custom_err_handler() { let (req, mut pl) = TestRequest::with_uri("/name/user1/") .app_data(PathConfig::default().error_handler(|err, _| { - error::InternalError::from_response( - err, - HttpResponse::Conflict().finish(), - ) - .into() + error::InternalError::from_response(err, HttpResponse::Conflict().finish()) + .into() })) .to_http_parts(); diff --git a/src/types/payload.rs b/src/types/payload.rs index 9228b37aa..781347b84 100644 --- a/src/types/payload.rs +++ b/src/types/payload.rs @@ -1,57 +1,51 @@ -//! Payload/Bytes/String extractors -use std::future::Future; -use std::pin::Pin; -use std::str; -use std::task::{Context, Poll}; +//! Basic binary and string payload extractors. -use actix_http::error::{Error, ErrorBadRequest, PayloadError}; -use actix_http::HttpMessage; +use std::{ + future::Future, + pin::Pin, + str, + task::{Context, Poll}, +}; + +use actix_http::error::{ErrorBadRequest, PayloadError}; use bytes::{Bytes, BytesMut}; use encoding_rs::{Encoding, UTF_8}; use futures_core::stream::Stream; use futures_util::{ - future::{err, ok, Either, ErrInto, Ready, TryFutureExt as _}, + future::{ready, Either, ErrInto, Ready, TryFutureExt as _}, ready, }; use mime::Mime; -use crate::extract::FromRequest; -use crate::http::header; -use crate::request::HttpRequest; -use crate::{dev, web}; +use crate::{dev, http::header, web, Error, FromRequest, HttpMessage, HttpRequest}; -/// Payload extractor returns request 's payload stream. +/// Extract a request's raw payload stream. /// -/// ## Example +/// See [`PayloadConfig`] for important notes when using this advanced extractor. /// -/// ```rust -/// use actix_web::{web, error, App, Error, HttpResponse}; +/// # Examples +/// ``` /// use std::future::Future; -/// use futures_core::stream::Stream; -/// use futures_util::StreamExt; -/// /// extract binary data from request -/// async fn index(mut body: web::Payload) -> Result -/// { +/// use futures_util::stream::{Stream, StreamExt}; +/// use actix_web::{post, web}; +/// +/// // `body: web::Payload` parameter extracts raw payload stream from request +/// #[post("/")] +/// async fn index(mut body: web::Payload) -> actix_web::Result { +/// // for demonstration only; in a normal case use the `Bytes` extractor +/// // collect payload stream into a bytes object /// let mut bytes = web::BytesMut::new(); /// while let Some(item) = body.next().await { /// bytes.extend_from_slice(&item?); /// } /// -/// format!("Body {:?}!", bytes); -/// Ok(HttpResponse::Ok().finish()) -/// } -/// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/index.html").route( -/// web::get().to(index)) -/// ); +/// Ok(format!("Request Body Bytes:\n{:?}", bytes)) /// } /// ``` pub struct Payload(pub crate::dev::Payload); impl Payload { - /// Deconstruct to a inner value + /// Unwrap to inner Payload type. pub fn into_inner(self) -> crate::dev::Payload { self.0 } @@ -61,43 +55,12 @@ impl Stream for Payload { type Item = Result; #[inline] - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_next(cx) } } -/// Get request's payload stream -/// -/// ## Example -/// -/// ```rust -/// use actix_web::{web, error, App, Error, HttpResponse}; -/// use std::future::Future; -/// use futures_core::stream::Stream; -/// use futures_util::StreamExt; -/// -/// /// extract binary data from request -/// async fn index(mut body: web::Payload) -> Result -/// { -/// let mut bytes = web::BytesMut::new(); -/// while let Some(item) = body.next().await { -/// bytes.extend_from_slice(&item?); -/// } -/// -/// format!("Body {:?}!", bytes); -/// Ok(HttpResponse::Ok().finish()) -/// } -/// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/index.html").route( -/// web::get().to(index)) -/// ); -/// } -/// ``` +/// See [here](#usage) for example of usage as an extractor. impl FromRequest for Payload { type Config = PayloadConfig; type Error = Error; @@ -105,34 +68,25 @@ impl FromRequest for Payload { #[inline] fn from_request(_: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { - ok(Payload(payload.take())) + ready(Ok(Payload(payload.take()))) } } -/// Request binary data from a request's payload. +/// Extract binary data from a request's payload. /// -/// Loads request's payload and construct Bytes instance. +/// Collects request payload stream into a [Bytes] instance. /// -/// [**PayloadConfig**](PayloadConfig) allows to configure -/// extraction process. +/// Use [`PayloadConfig`] to configure extraction process. /// -/// ## Example -/// -/// ```rust -/// use bytes::Bytes; -/// use actix_web::{web, App}; +/// # Examples +/// ``` +/// use actix_web::{post, web}; /// /// /// extract binary data from request -/// async fn index(body: Bytes) -> String { +/// #[post("/")] +/// async fn index(body: web::Bytes) -> String { /// format!("Body {:?}!", body) /// } -/// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/index.html").route( -/// web::get().to(index)) -/// ); -/// } /// ``` impl FromRequest for Bytes { type Config = PayloadConfig; @@ -144,8 +98,8 @@ impl FromRequest for Bytes { // allow both Config and Data let cfg = PayloadConfig::from_req(req); - if let Err(e) = cfg.check_mimetype(req) { - return Either::Right(err(e)); + if let Err(err) = cfg.check_mimetype(req) { + return Either::Right(ready(Err(err))); } let limit = cfg.limit; @@ -161,26 +115,15 @@ impl FromRequest for Bytes { /// [**PayloadConfig**](PayloadConfig) allows to configure /// extraction process. /// -/// ## Example +/// # Examples +/// ``` +/// use actix_web::{post, web, FromRequest}; /// -/// ```rust -/// use actix_web::{web, App, FromRequest}; -/// -/// /// extract text data from request +/// // extract text data from request +/// #[post("/")] /// async fn index(text: String) -> String { /// format!("Body {}!", text) /// } -/// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/index.html") -/// .app_data(String::configure(|cfg| { // <- limit size of the payload -/// cfg.limit(4096) -/// })) -/// .route(web::get().to(index)) // <- register handler with extractor params -/// ); -/// } -/// ``` impl FromRequest for String { type Config = PayloadConfig; type Error = Error; @@ -191,14 +134,14 @@ impl FromRequest for String { let cfg = PayloadConfig::from_req(req); // check content-type - if let Err(e) = cfg.check_mimetype(req) { - return Either::Right(err(e)); + if let Err(err) = cfg.check_mimetype(req) { + return Either::Right(ready(Err(err))); } // check charset let encoding = match req.encoding() { Ok(enc) => enc, - Err(e) => return Either::Right(err(e.into())), + Err(err) => return Either::Right(ready(Err(err.into()))), }; let limit = cfg.limit; let body_fut = HttpMessageBody::new(req, payload).limit(limit); @@ -238,11 +181,16 @@ fn bytes_to_string(body: Bytes, encoding: &'static Encoding) -> Result Self { Self { limit, @@ -258,14 +206,13 @@ impl PayloadConfig { } } - /// Change max size of payload. By default max size is 256Kb + /// Set maximum accepted payload size in bytes. The default limit is 256kB. pub fn limit(mut self, limit: usize) -> Self { self.limit = limit; self } - /// Set required mime-type of the request. By default mime type is not - /// enforced. + /// Set required mime type of the request. By default mime type is not enforced. pub fn mimetype(mut self, mt: Mime) -> Self { self.mimetype = Some(mt); self @@ -292,7 +239,7 @@ impl PayloadConfig { } /// Extract payload config from app data. Check both `T` and `Data`, in that order, and fall - /// back to the default payload config. + /// back to the default payload config if neither is found. fn from_req(req: &HttpRequest) -> &Self { req.app_data::() .or_else(|| req.app_data::>().map(|d| d.as_ref())) @@ -300,7 +247,7 @@ impl PayloadConfig { } } -// Allow shared refs to default. +/// Allow shared refs used as defaults. const DEFAULT_CONFIG: PayloadConfig = PayloadConfig { limit: DEFAULT_CONFIG_LIMIT, mimetype: None, @@ -314,13 +261,10 @@ impl Default for PayloadConfig { } } -/// Future that resolves to a complete http message body. +/// Future that resolves to a complete HTTP body payload. /// -/// Load http message body. -/// -/// By default only 256Kb payload reads to a memory, then -/// `PayloadError::Overflow` get returned. Use `MessageBody::limit()` -/// method to change upper limit. +/// By default only 256kB payload is accepted before `PayloadError::Overflow` is returned. +/// Use `MessageBody::limit()` method to change upper limit. pub struct HttpMessageBody { limit: usize, length: Option, @@ -342,10 +286,12 @@ impl HttpMessageBody { if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) { match l.to_str() { Ok(s) => match s.parse::() { - Ok(l) if l > DEFAULT_CONFIG_LIMIT => { - err = Some(PayloadError::Overflow) + Ok(l) => { + if l > DEFAULT_CONFIG_LIMIT { + err = Some(PayloadError::Overflow); + } + length = Some(l) } - Ok(l) => length = Some(l), Err(_) => err = Some(PayloadError::UnknownLength), }, Err(_) => err = Some(PayloadError::UnknownLength), @@ -366,12 +312,14 @@ impl HttpMessageBody { } } - /// Change max size of payload. By default max size is 256Kb + /// Change max size of payload. By default max size is 256kB pub fn limit(mut self, limit: usize) -> Self { if let Some(l) = self.length { - if l > limit { - self.err = Some(PayloadError::Overflow); - } + self.err = if l > limit { + Some(PayloadError::Overflow) + } else { + None + }; } self.limit = limit; self @@ -384,8 +332,8 @@ impl Future for HttpMessageBody { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); - if let Some(e) = this.err.take() { - return Poll::Ready(Err(e)); + if let Some(err) = this.err.take() { + return Poll::Ready(Err(err)); } loop { @@ -420,14 +368,13 @@ mod tests { let cfg = PayloadConfig::default().mimetype(mime::APPLICATION_JSON); assert!(cfg.check_mimetype(&req).is_err()); - let req = TestRequest::with_header( - header::CONTENT_TYPE, - "application/x-www-form-urlencoded", - ) - .to_http_request(); + let req = TestRequest::default() + .insert_header((header::CONTENT_TYPE, "application/x-www-form-urlencoded")) + .to_http_request(); assert!(cfg.check_mimetype(&req).is_err()); - let req = TestRequest::with_header(header::CONTENT_TYPE, "application/json") + let req = TestRequest::default() + .insert_header((header::CONTENT_TYPE, "application/json")) .to_http_request(); assert!(cfg.check_mimetype(&req).is_ok()); } @@ -442,13 +389,11 @@ mod tests { "payload is probably json string" } - let mut srv = init_service( + let srv = init_service( App::new() .service( web::resource("/bytes-app-data") - .app_data( - PayloadConfig::default().mimetype(mime::APPLICATION_JSON), - ) + .app_data(PayloadConfig::default().mimetype(mime::APPLICATION_JSON)) .route(web::get().to(bytes_handler)), ) .service( @@ -458,9 +403,7 @@ mod tests { ) .service( web::resource("/string-app-data") - .app_data( - PayloadConfig::default().mimetype(mime::APPLICATION_JSON), - ) + .app_data(PayloadConfig::default().mimetype(mime::APPLICATION_JSON)) .route(web::get().to(string_handler)), ) .service( @@ -472,49 +415,50 @@ mod tests { .await; let req = TestRequest::with_uri("/bytes-app-data").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let req = TestRequest::with_uri("/bytes-data").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let req = TestRequest::with_uri("/string-app-data").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let req = TestRequest::with_uri("/string-data").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let req = TestRequest::with_uri("/bytes-app-data") - .header(header::CONTENT_TYPE, mime::APPLICATION_JSON) + .insert_header(header::ContentType(mime::APPLICATION_JSON)) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/bytes-data") - .header(header::CONTENT_TYPE, mime::APPLICATION_JSON) + .insert_header(header::ContentType(mime::APPLICATION_JSON)) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/string-app-data") - .header(header::CONTENT_TYPE, mime::APPLICATION_JSON) + .insert_header(header::ContentType(mime::APPLICATION_JSON)) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/string-data") - .header(header::CONTENT_TYPE, mime::APPLICATION_JSON) + .insert_header(header::ContentType(mime::APPLICATION_JSON)) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_rt::test] async fn test_bytes() { - let (req, mut pl) = TestRequest::with_header(header::CONTENT_LENGTH, "11") + let (req, mut pl) = TestRequest::default() + .insert_header((header::CONTENT_LENGTH, "11")) .set_payload(Bytes::from_static(b"hello=world")) .to_http_parts(); @@ -524,7 +468,8 @@ mod tests { #[actix_rt::test] async fn test_string() { - let (req, mut pl) = TestRequest::with_header(header::CONTENT_LENGTH, "11") + let (req, mut pl) = TestRequest::default() + .insert_header((header::CONTENT_LENGTH, "11")) .set_payload(Bytes::from_static(b"hello=world")) .to_http_parts(); @@ -534,21 +479,23 @@ mod tests { #[actix_rt::test] async fn test_message_body() { - let (req, mut pl) = TestRequest::with_header(header::CONTENT_LENGTH, "xxxx") + let (req, mut pl) = TestRequest::default() + .insert_header((header::CONTENT_LENGTH, "xxxx")) .to_srv_request() .into_parts(); let res = HttpMessageBody::new(&req, &mut pl).await; match res.err().unwrap() { - PayloadError::UnknownLength => (), + PayloadError::UnknownLength => {} _ => unreachable!("error"), } - let (req, mut pl) = TestRequest::with_header(header::CONTENT_LENGTH, "1000000") + let (req, mut pl) = TestRequest::default() + .insert_header((header::CONTENT_LENGTH, "1000000")) .to_srv_request() .into_parts(); let res = HttpMessageBody::new(&req, &mut pl).await; match res.err().unwrap() { - PayloadError::Overflow => (), + PayloadError::Overflow => {} _ => unreachable!("error"), } @@ -563,7 +510,7 @@ mod tests { .to_http_parts(); let res = HttpMessageBody::new(&req, &mut pl).limit(5).await; match res.err().unwrap() { - PayloadError::Overflow => (), + PayloadError::Overflow => {} _ => unreachable!("error"), } } diff --git a/src/types/query.rs b/src/types/query.rs index 27df220fc..79af32581 100644 --- a/src/types/query.rs +++ b/src/types/query.rs @@ -1,30 +1,27 @@ -//! Query extractor +//! For query parameter extractor documentation, see [`Query`]. -use std::sync::Arc; -use std::{fmt, ops}; +use std::{fmt, ops, sync::Arc}; -use actix_http::error::Error; use futures_util::future::{err, ok, Ready}; use serde::de; -use crate::dev::Payload; -use crate::error::QueryPayloadError; -use crate::extract::FromRequest; -use crate::request::HttpRequest; +use crate::{dev::Payload, error::QueryPayloadError, Error, FromRequest, HttpRequest}; /// Extract typed information from the request's query. /// -/// **Note**: A query string consists of unordered `key=value` pairs, therefore it cannot -/// be decoded into any type which depends upon data ordering e.g. tuples or tuple-structs. -/// Attempts to do so will *fail at runtime*. +/// To extract typed data from the URL query string, the inner type `T` must implement the +/// [`serde::Deserialize`] trait. /// -/// [**QueryConfig**](QueryConfig) allows to configure extraction process. +/// Use [`QueryConfig`] to configure extraction process. /// -/// ## Example +/// # Panics +/// A query string consists of unordered `key=value` pairs, therefore it cannot be decoded into any +/// type which depends upon data ordering (eg. tuples). Trying to do so will result in a panic. /// -/// ```rust -/// use actix_web::{web, App}; -/// use serde_derive::Deserialize; +/// # Examples +/// ``` +/// use actix_web::{get, web}; +/// use serde::Deserialize; /// /// #[derive(Debug, Deserialize)] /// pub enum ResponseType { @@ -32,41 +29,60 @@ use crate::request::HttpRequest; /// Code /// } /// -/// #[derive(Deserialize)] +/// #[derive(Debug, Deserialize)] /// pub struct AuthRequest { /// id: u64, /// response_type: ResponseType, /// } /// -/// // Use `Query` extractor for query information (and destructure it within the signature). -/// // This handler gets called only if the request's query string contains `id` and `response_type` fields. -/// // The correct request for this handler would be `/index.html?id=64&response_type=Code"`. -/// async fn index(web::Query(info): web::Query) -> String { -/// format!("Authorization request for client with id={} and type={:?}!", info.id, info.response_type) +/// // Deserialize `AuthRequest` struct from query string. +/// // This handler gets called only if the request's query parameters contain both fields. +/// // A valid request path for this handler would be `/?id=64&response_type=Code"`. +/// #[get("/")] +/// async fn index(info: web::Query) -> String { +/// format!("Authorization request for id={} and type={:?}!", info.id, info.response_type) /// } /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/index.html").route(web::get().to(index))); // <- use `Query` extractor +/// // To access the entire underlying query struct, use `.into_inner()`. +/// #[get("/debug1")] +/// async fn debug1(info: web::Query) -> String { +/// dbg!("Authorization object={:?}", info.into_inner()); +/// "OK".to_string() +/// } +/// +/// // Or use `.0`, which is equivalent to `.into_inner()`. +/// #[get("/debug2")] +/// async fn debug2(info: web::Query) -> String { +/// dbg!("Authorization object={:?}", info.0); +/// "OK".to_string() /// } /// ``` -#[derive(PartialEq, Eq, PartialOrd, Ord)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct Query(pub T); impl Query { - /// Deconstruct to a inner value + /// Unwrap into inner `T` value. pub fn into_inner(self) -> T { self.0 } - /// Get query parameters from the path + /// Deserialize `T` from a URL encoded query parameter string. + /// + /// ``` + /// # use std::collections::HashMap; + /// # use actix_web::web::Query; + /// let numbers = Query::>::from_query("one=1&two=2").unwrap(); + /// assert_eq!(numbers.get("one"), Some(&1)); + /// assert_eq!(numbers.get("two"), Some(&2)); + /// assert!(numbers.get("three").is_none()); + /// ``` pub fn from_query(query_str: &str) -> Result where T: de::DeserializeOwned, { serde_urlencoded::from_str::(query_str) - .map(|val| Ok(Query(val))) - .unwrap_or_else(move |e| Err(QueryPayloadError::Deserialize(e))) + .map(Self) + .map_err(QueryPayloadError::Deserialize) } } @@ -96,39 +112,7 @@ impl fmt::Display for Query { } } -/// Extract typed information from the request's query. -/// -/// ## Example -/// -/// ```rust -/// use actix_web::{web, App}; -/// use serde_derive::Deserialize; -/// -/// #[derive(Debug, Deserialize)] -/// pub enum ResponseType { -/// Token, -/// Code -/// } -/// -/// #[derive(Deserialize)] -/// pub struct AuthRequest { -/// id: u64, -/// response_type: ResponseType, -/// } -/// -/// // Use `Query` extractor for query information. -/// // This handler get called only if request's query contains `id` and `response_type` fields. -/// // The correct request for this handler would be `/index.html?id=64&response_type=Code"` -/// async fn index(info: web::Query) -> String { -/// format!("Authorization request for client with id={} and type={:?}!", info.id, info.response_type) -/// } -/// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/index.html") -/// .route(web::get().to(index))); // <- use `Query` extractor -/// } -/// ``` +/// See [here](#usage) for example of usage as an extractor. impl FromRequest for Query where T: de::DeserializeOwned, @@ -141,7 +125,7 @@ where fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { let error_handler = req .app_data::() - .map(|c| c.ehandler.clone()) + .map(|c| c.err_handler.clone()) .unwrap_or(None); serde_urlencoded::from_str::(req.query_string()) @@ -166,13 +150,12 @@ where } } -/// Query extractor configuration +/// Query extractor configuration. /// -/// ## Example -/// -/// ```rust -/// use actix_web::{error, web, App, FromRequest, HttpResponse}; -/// use serde_derive::Deserialize; +/// # Examples +/// ``` +/// use actix_web::{error, get, web, App, FromRequest, HttpResponse}; +/// use serde::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -180,28 +163,25 @@ where /// } /// /// /// deserialize `Info` from request's querystring +/// #[get("/")] /// async fn index(info: web::Query) -> String { /// format!("Welcome {}!", info.username) /// } /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/index.html").app_data( -/// // change query extractor configuration -/// web::QueryConfig::default() -/// .error_handler(|err, req| { // <- create custom error response -/// error::InternalError::from_response( -/// err, HttpResponse::Conflict().finish()).into() -/// }) -/// ) -/// .route(web::post().to(index)) -/// ); -/// } +/// // custom `Query` extractor configuration +/// let query_cfg = web::QueryConfig::default() +/// // use custom error handler +/// .error_handler(|err, req| { +/// error::InternalError::from_response(err, HttpResponse::Conflict().finish()).into() +/// }); +/// +/// App::new() +/// .app_data(query_cfg) +/// .service(index); /// ``` #[derive(Clone)] pub struct QueryConfig { - ehandler: - Option Error + Send + Sync>>, + err_handler: Option Error + Send + Sync>>, } impl QueryConfig { @@ -210,14 +190,14 @@ impl QueryConfig { where F: Fn(QueryPayloadError, &HttpRequest) -> Error + Send + Sync + 'static, { - self.ehandler = Some(Arc::new(f)); + self.err_handler = Some(Arc::new(f)); self } } impl Default for QueryConfig { fn default() -> Self { - QueryConfig { ehandler: None } + QueryConfig { err_handler: None } } } @@ -225,7 +205,7 @@ impl Default for QueryConfig { mod tests { use actix_http::http::StatusCode; use derive_more::Display; - use serde_derive::Deserialize; + use serde::Deserialize; use super::*; use crate::error::InternalError; @@ -271,6 +251,17 @@ mod tests { assert_eq!(s.id, "test1"); } + #[actix_rt::test] + #[should_panic] + async fn test_tuple_panic() { + let req = TestRequest::with_uri("/?one=1&two=2").to_srv_request(); + let (req, mut pl) = req.into_parts(); + + Query::<(u32, u32)>::from_request(&req, &mut pl) + .await + .unwrap(); + } + #[actix_rt::test] async fn test_custom_error_responder() { let req = TestRequest::with_uri("/name/user1/") diff --git a/src/types/readlines.rs b/src/types/readlines.rs index f03235377..b8bdcc504 100644 --- a/src/types/readlines.rs +++ b/src/types/readlines.rs @@ -1,17 +1,23 @@ -use std::borrow::Cow; -use std::pin::Pin; -use std::str; -use std::task::{Context, Poll}; +//! For request line reader documentation, see [`Readlines`]. + +use std::{ + borrow::Cow, + pin::Pin, + str, + task::{Context, Poll}, +}; use bytes::{Bytes, BytesMut}; use encoding_rs::{Encoding, UTF_8}; -use futures_util::stream::Stream; +use futures_core::{ready, stream::Stream}; -use crate::dev::Payload; -use crate::error::{PayloadError, ReadlinesError}; -use crate::HttpMessage; +use crate::{ + dev::Payload, + error::{PayloadError, ReadlinesError}, + HttpMessage, +}; -/// Stream to read request line by line. +/// Stream that reads request line by line. pub struct Readlines { stream: Payload, buff: BytesMut, @@ -43,7 +49,7 @@ where } } - /// Change max line size. By default max size is 256Kb + /// Set maximum accepted payload size. The default limit is 256kB. pub fn limit(mut self, limit: usize) -> Self { self.limit = limit; self @@ -68,10 +74,7 @@ where { type Item = Result; - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); if let Some(err) = this.err.take() { @@ -108,9 +111,10 @@ where } this.checked_buff = true; } + // poll req for more bytes - match Pin::new(&mut this.stream).poll_next(cx) { - Poll::Ready(Some(Ok(mut bytes))) => { + match ready!(Pin::new(&mut this.stream).poll_next(cx)) { + Some(Ok(mut bytes)) => { // check if there is a newline in bytes let mut found: Option = None; for (ind, b) in bytes.iter().enumerate() { @@ -144,8 +148,8 @@ where this.buff.extend_from_slice(&bytes); Poll::Pending } - Poll::Pending => Poll::Pending, - Poll::Ready(None) => { + + None => { if this.buff.is_empty() { return Poll::Ready(None); } @@ -165,7 +169,8 @@ where this.buff.clear(); Poll::Ready(Some(Ok(line))) } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(ReadlinesError::from(e)))), + + Some(Err(err)) => Poll::Ready(Some(Err(ReadlinesError::from(err)))), } } } diff --git a/src/web.rs b/src/web.rs index 85e5f2e7b..1cef37109 100644 --- a/src/web.rs +++ b/src/web.rs @@ -1,11 +1,11 @@ //! Essentials helper functions and types for application registration. + use actix_http::http::Method; use actix_router::IntoPattern; use std::future::Future; pub use actix_http::Response as HttpResponse; pub use bytes::{Buf, BufMut, Bytes, BytesMut}; -pub use futures_channel::oneshot::Canceled; use crate::error::BlockingError; use crate::extract::FromRequest; @@ -275,11 +275,11 @@ pub fn service(path: T) -> WebService { /// Execute blocking function on a thread pool, returns future that resolves /// to result of the function execution. -pub async fn block(f: F) -> Result> +pub fn block(f: F) -> impl Future> where - F: FnOnce() -> Result + Send + 'static, - I: Send + 'static, - E: Send + std::fmt::Debug + 'static, + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, { - actix_threadpool::run(f).await + let fut = actix_rt::task::spawn_blocking(f); + async { fut.await.map_err(|_| BlockingError) } } diff --git a/tests/cert.pem b/tests/cert.pem deleted file mode 100644 index 0eeb6721d..000000000 --- a/tests/cert.pem +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDEDCCAfgCCQCQdmIZc/Ib/jANBgkqhkiG9w0BAQsFADBKMQswCQYDVQQGEwJ1 -czELMAkGA1UECAwCY2ExCzAJBgNVBAcMAnNmMSEwHwYJKoZIhvcNAQkBFhJmYWZo -cmQ5MUBnbWFpbC5jb20wHhcNMTkxMTE5MTEwNjU1WhcNMjkxMTE2MTEwNjU1WjBK -MQswCQYDVQQGEwJ1czELMAkGA1UECAwCY2ExCzAJBgNVBAcMAnNmMSEwHwYJKoZI -hvcNAQkBFhJmYWZocmQ5MUBnbWFpbC5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IB -DwAwggEKAoIBAQDcnaz12CKzUL7248V7Axhms/O9UQXfAdw0yolEfC3P5jADa/1C -+kLWKjAc2coqDSbGsrsR6KiH2g06Kunx+tSGqUO+Sct7HEehmxndiSwx/hfMWezy -XRe/olcHFTeCk/Tllz4xGEplhPua6GLhJygLOhAMiV8cwCYrgyPqsDduExLDFCqc -K2xntIPreumXpiE3QY4+MWyteiJko4IWDFf/UwwsdCY5MlFfw1F/Uv9vz7FfOfvu -GccHd/ex8cOwotUqd6emZb+0bVE24Sv8U+yLnHIVx/tOkxgMAnJEpAnf2G3Wp3zU -b2GJosbmfGaf+xTfnGGhTLLL7kCtva+NvZr5AgMBAAEwDQYJKoZIhvcNAQELBQAD -ggEBANftoL8zDGrjCwWvct8kOOqset2ukK8vjIGwfm88CKsy0IfSochNz2qeIu9R -ZuO7c0pfjmRkir9ZQdq9vXgG3ccL9UstFsferPH9W3YJ83kgXg3fa0EmCiN/0hwz -6Ij1ZBiN1j3+d6+PJPgyYFNu2nGwox5mJ9+aRAGe0/9c63PEOY8P2TI4HsiPmYSl -fFR8k/03vr6e+rTKW85BgctjvYKe/TnFxeCQ7dZ+na7vlEtch4tNmy6O/vEk2kCt -5jW0DUxhmRsv2wGmfFRI0+LotHjoXQQZi6nN5aGL3odaGF3gYwIVlZNd3AdkwDQz -BzG0ZwXuDDV9bSs3MfWEWcy4xuU= ------END CERTIFICATE----- diff --git a/tests/key.pem b/tests/key.pem deleted file mode 100644 index a6d308168..000000000 --- a/tests/key.pem +++ /dev/null @@ -1,28 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDcnaz12CKzUL72 -48V7Axhms/O9UQXfAdw0yolEfC3P5jADa/1C+kLWKjAc2coqDSbGsrsR6KiH2g06 -Kunx+tSGqUO+Sct7HEehmxndiSwx/hfMWezyXRe/olcHFTeCk/Tllz4xGEplhPua -6GLhJygLOhAMiV8cwCYrgyPqsDduExLDFCqcK2xntIPreumXpiE3QY4+MWyteiJk -o4IWDFf/UwwsdCY5MlFfw1F/Uv9vz7FfOfvuGccHd/ex8cOwotUqd6emZb+0bVE2 -4Sv8U+yLnHIVx/tOkxgMAnJEpAnf2G3Wp3zUb2GJosbmfGaf+xTfnGGhTLLL7kCt -va+NvZr5AgMBAAECggEBAKoU0UwzVgVCQgca8Jt2dnBvWYDhnxIfYAI/BvaKedMm -1ms87OKfB7oOiksjyI0E2JklH72dzZf2jm4CuZt5UjGC+xwPzlTaJ4s6hQVbBHyC -NRyxU1BCXtW5tThbrhD4OjxqjmLRJEIB9OunLtwAEQoeuFLB8Va7+HFhR+Zd9k3f -7aVA93pC5A50NRbZlke4miJ3Q8n7ZF0+UmxkBfm3fbqLk7aMWkoEKwLLTadjRlu1 -bBp0YDStX66I/p1kujqBOdh6VpPvxFOa1sV9pq0jeiGc9YfSkzRSKzIn8GoyviFB -fHeszQdNlcnrSDSNnMABAw+ZpxUO7SCaftjwejEmKZUCgYEA+TY43VpmV95eY7eo -WKwGepiHE0fwQLuKGELmZdZI80tFi73oZMuiB5WzwmkaKGcJmm7KGE9KEvHQCo9j -xvmktBR0VEZH8pmVfun+4h6+0H7m/NKMBBeOyv/IK8jBgHjkkB6e6nmeR7CqTxCw -tf9tbajl1QN8gNzXZSjBDT/lanMCgYEA4qANOKOSiEARtgwyXQeeSJcM2uPv6zF3 -ffM7vjSedtuEOHUSVeyBP/W8KDt7zyPppO/WNbURHS+HV0maS9yyj6zpVS2HGmbs -3fetswsQ+zYVdokW89x4oc2z4XOGHd1LcSlyhRwPt0u2g1E9L0irwTQLWU0npFmG -PRf7sN9+LeMCgYAGkDUDL2ROoB6gRa/7Vdx90hKMoXJkYgwLA4gJ2pDlR3A3c/Lw -5KQJyxmG3zm/IqeQF6be6QesZA30mT4peV2rGHbP2WH/s6fKReNelSy1VQJEWk8x -tGUgV4gwDwN5nLV4TjYlOrq+bJqvpmLhCC8bmj0jVQosYqSRl3cuICasnQKBgGlV -VO/Xb1su1EyWPK5qxRIeSxZOTYw2sMB01nbgxCqge0M2fvA6/hQ5ZlwY0cIEgits -YlcSMsMq/TAAANxz1vbaupUhlSMbZcsBvNV0Nk9c4vr2Wxm7hsJF9u66IEMvQUp2 -pkjiMxfR9CHzF4orr9EcHI5EQ0Grbq5kwFKEfoRbAoGAcWoFPILeJOlp2yW/Ds3E -g2fQdI9BAamtEZEaslJmZMmsDTg5ACPcDkOSFEQIaJ7wLPXeZy74FVk/NrY5F8Gz -bjX9OD/xzwp852yW5L9r62vYJakAlXef5jI6CFdYKDDCcarU0S7W5k6kq9n+wrBR -i1NklYmUAMr2q59uJA5zsic= ------END PRIVATE KEY----- diff --git a/tests/test_httpserver.rs b/tests/test_httpserver.rs index d164f4445..043159376 100644 --- a/tests/test_httpserver.rs +++ b/tests/test_httpserver.rs @@ -2,7 +2,12 @@ use std::sync::mpsc; use std::{thread, time::Duration}; #[cfg(feature = "openssl")] -use open_ssl::ssl::SslAcceptorBuilder; +extern crate tls_openssl as openssl; +#[cfg(feature = "rustls")] +extern crate tls_rustls as rustls; + +#[cfg(feature = "openssl")] +use openssl::ssl::SslAcceptorBuilder; use actix_web::{test, web, App, HttpResponse, HttpServer}; @@ -13,28 +18,31 @@ async fn test_start() { let (tx, rx) = mpsc::channel(); thread::spawn(move || { - let sys = actix_rt::System::new("test"); + let sys = actix_rt::System::new(); - let srv = HttpServer::new(|| { - App::new().service( - web::resource("/").route(web::to(|| HttpResponse::Ok().body("test"))), - ) - }) - .workers(1) - .backlog(1) - .max_connections(10) - .max_connection_rate(10) - .keep_alive(10) - .client_timeout(5000) - .client_shutdown(0) - .server_hostname("localhost") - .system_exit() - .disable_signals() - .bind(format!("{}", addr)) - .unwrap() - .run(); + sys.block_on(async { + let srv = HttpServer::new(|| { + App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok().body("test"))), + ) + }) + .workers(1) + .backlog(1) + .max_connections(10) + .max_connection_rate(10) + .keep_alive(10) + .client_timeout(5000) + .client_shutdown(0) + .server_hostname("localhost") + .system_exit() + .disable_signals() + .bind(format!("{}", addr)) + .unwrap() + .run(); + + let _ = tx.send((srv, actix_rt::System::current())); + }); - let _ = tx.send((srv, actix_rt::System::current())); let _ = sys.run(); }); let (srv, sys) = rx.recv().unwrap(); @@ -63,18 +71,24 @@ async fn test_start() { let _ = sys.stop(); } -#[allow(clippy::unnecessary_wraps)] #[cfg(feature = "openssl")] fn ssl_acceptor() -> std::io::Result { - use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; - // load ssl keys + use openssl::{ + pkey::PKey, + ssl::{SslAcceptor, SslMethod}, + x509::X509, + }; + + let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); + let cert_file = cert.serialize_pem().unwrap(); + let key_file = cert.serialize_private_key_pem(); + let cert = X509::from_pem(cert_file.as_bytes()).unwrap(); + let key = PKey::private_key_from_pem(key_file.as_bytes()).unwrap(); + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder - .set_private_key_file("tests/key.pem", SslFiletype::PEM) - .unwrap(); - builder - .set_certificate_chain_file("tests/cert.pem") - .unwrap(); + builder.set_certificate(&cert).unwrap(); + builder.set_private_key(&key).unwrap(); + Ok(builder) } @@ -87,7 +101,7 @@ async fn test_start_ssl() { let (tx, rx) = mpsc::channel(); thread::spawn(move || { - let sys = actix_rt::System::new("test"); + let sys = actix_rt::System::new(); let builder = ssl_acceptor().unwrap(); let srv = HttpServer::new(|| { @@ -101,15 +115,18 @@ async fn test_start_ssl() { .system_exit() .disable_signals() .bind_openssl(format!("{}", addr), builder) - .unwrap() - .run(); + .unwrap(); + + sys.block_on(async { + let srv = srv.run(); + let _ = tx.send((srv, actix_rt::System::current())); + }); - let _ = tx.send((srv, actix_rt::System::current())); let _ = sys.run(); }); let (srv, sys) = rx.recv().unwrap(); - use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); builder.set_verify(SslVerifyMode::NONE); let _ = builder @@ -120,8 +137,7 @@ async fn test_start_ssl() { .connector( awc::Connector::new() .ssl(builder.build()) - .timeout(Duration::from_millis(100)) - .finish(), + .timeout(Duration::from_millis(100)), ) .finish(); diff --git a/tests/test_server.rs b/tests/test_server.rs index c6c316f0d..2466730f9 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -1,23 +1,35 @@ -use std::future::Future; -use std::io::{Read, Write}; -use std::pin::Pin; -use std::task::{Context, Poll}; +#[cfg(feature = "openssl")] +extern crate tls_openssl as openssl; +#[cfg(feature = "rustls")] +extern crate tls_rustls as rustls; + +use std::{ + future::Future, + io::{Read, Write}, + pin::Pin, + task::{Context, Poll}, +}; use actix_http::http::header::{ - ContentEncoding, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, - TRANSFER_ENCODING, + ContentEncoding, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING, }; use brotli2::write::{BrotliDecoder, BrotliEncoder}; use bytes::Bytes; -use flate2::read::GzDecoder; -use flate2::write::{GzEncoder, ZlibDecoder, ZlibEncoder}; -use flate2::Compression; +use flate2::{ + read::GzDecoder, + write::{GzEncoder, ZlibDecoder, ZlibEncoder}, + Compression, +}; use futures_util::ready; +use openssl::{ + pkey::PKey, + ssl::{SslAcceptor, SslMethod}, + x509::X509, +}; use rand::{distributions::Alphanumeric, Rng}; use actix_web::dev::BodyEncoding; -use actix_web::middleware::normalize::TrailingSlash; -use actix_web::middleware::{Compress, NormalizePath}; +use actix_web::middleware::{Compress, NormalizePath, TrailingSlash}; use actix_web::{dev, test, web, App, Error, HttpResponse}; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ @@ -42,10 +54,34 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World"; +fn openssl_config() -> SslAcceptor { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); + let cert_file = cert.serialize_pem().unwrap(); + let key_file = cert.serialize_private_key_pem(); + let cert = X509::from_pem(cert_file.as_bytes()).unwrap(); + let key = PKey::private_key_from_pem(key_file.as_bytes()).unwrap(); + + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder.set_certificate(&cert).unwrap(); + builder.set_private_key(&key).unwrap(); + + builder.set_alpn_select_callback(|_, protos| { + const H2: &[u8] = b"\x02h2"; + if protos.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else { + Err(openssl::ssl::AlpnError::NOACK) + } + }); + builder.set_alpn_protos(b"\x02h2").unwrap(); + + builder.build() +} + struct TestBody { data: Bytes, chunk_size: usize, - delay: actix_rt::time::Delay, + delay: Pin>, } impl TestBody { @@ -53,7 +89,7 @@ impl TestBody { TestBody { data, chunk_size, - delay: actix_rt::time::delay_for(std::time::Duration::from_millis(10)), + delay: Box::pin(actix_rt::time::sleep(std::time::Duration::from_millis(10))), } } } @@ -61,13 +97,10 @@ impl TestBody { impl futures_core::stream::Stream for TestBody { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(Pin::new(&mut self.delay).poll(cx)); - self.delay = actix_rt::time::delay_for(std::time::Duration::from_millis(10)); + self.delay = Box::pin(actix_rt::time::sleep(std::time::Duration::from_millis(10))); let chunk_size = std::cmp::min(self.chunk_size, self.data.len()); let chunk = self.data.split_to(chunk_size); if chunk.is_empty() { @@ -81,8 +114,7 @@ impl futures_core::stream::Stream for TestBody { #[actix_rt::test] async fn test_body() { let srv = test::start(|| { - App::new() - .service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) + App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) }); let mut response = srv.get("/").send().await.unwrap(); @@ -104,7 +136,7 @@ async fn test_body_gzip() { let mut response = srv .get("/") .no_decompress() - .header(ACCEPT_ENCODING, "gzip") + .append_header((ACCEPT_ENCODING, "gzip")) .send() .await .unwrap(); @@ -133,7 +165,7 @@ async fn test_body_gzip2() { let mut response = srv .get("/") .no_decompress() - .header(ACCEPT_ENCODING, "gzip") + .append_header((ACCEPT_ENCODING, "gzip")) .send() .await .unwrap(); @@ -174,7 +206,7 @@ async fn test_body_encoding_override() { let mut response = srv .get("/") .no_decompress() - .header(ACCEPT_ENCODING, "deflate") + .append_header((ACCEPT_ENCODING, "deflate")) .send() .await .unwrap(); @@ -193,7 +225,7 @@ async fn test_body_encoding_override() { let mut response = srv .request(actix_web::http::Method::GET, srv.url("/raw")) .no_decompress() - .header(ACCEPT_ENCODING, "deflate") + .append_header((ACCEPT_ENCODING, "deflate")) .send() .await .unwrap(); @@ -227,7 +259,7 @@ async fn test_body_gzip_large() { let mut response = srv .get("/") .no_decompress() - .header(ACCEPT_ENCODING, "gzip") + .append_header((ACCEPT_ENCODING, "gzip")) .send() .await .unwrap(); @@ -265,7 +297,7 @@ async fn test_body_gzip_large_random() { let mut response = srv .get("/") .no_decompress() - .header(ACCEPT_ENCODING, "gzip") + .append_header((ACCEPT_ENCODING, "gzip")) .send() .await .unwrap(); @@ -296,7 +328,7 @@ async fn test_body_chunked_implicit() { let mut response = srv .get("/") .no_decompress() - .header(ACCEPT_ENCODING, "gzip") + .append_header((ACCEPT_ENCODING, "gzip")) .send() .await .unwrap(); @@ -319,17 +351,17 @@ async fn test_body_chunked_implicit() { #[actix_rt::test] async fn test_body_br_streaming() { let srv = test::start_with(test::config().h1(), || { - App::new().wrap(Compress::new(ContentEncoding::Br)).service( - web::resource("/").route(web::to(move || { + App::new() + .wrap(Compress::new(ContentEncoding::Br)) + .service(web::resource("/").route(web::to(move || { HttpResponse::Ok() .streaming(TestBody::new(Bytes::from_static(STR.as_ref()), 24)) - })), - ) + }))) }); let mut response = srv .get("/") - .header(ACCEPT_ENCODING, "br") + .append_header((ACCEPT_ENCODING, "br")) .no_decompress() .send() .await @@ -352,8 +384,7 @@ async fn test_body_br_streaming() { async fn test_head_binary() { let srv = test::start_with(test::config().h1(), || { App::new().service( - web::resource("/") - .route(web::head().to(move || HttpResponse::Ok().body(STR))), + web::resource("/").route(web::head().to(move || HttpResponse::Ok().body(STR))), ) }); @@ -394,15 +425,13 @@ async fn test_body_deflate() { let srv = test::start_with(test::config().h1(), || { App::new() .wrap(Compress::new(ContentEncoding::Deflate)) - .service( - web::resource("/").route(web::to(move || HttpResponse::Ok().body(STR))), - ) + .service(web::resource("/").route(web::to(move || HttpResponse::Ok().body(STR)))) }); // client request let mut response = srv .get("/") - .header(ACCEPT_ENCODING, "deflate") + .append_header((ACCEPT_ENCODING, "deflate")) .no_decompress() .send() .await @@ -421,15 +450,15 @@ async fn test_body_deflate() { #[actix_rt::test] async fn test_body_brotli() { let srv = test::start_with(test::config().h1(), || { - App::new().wrap(Compress::new(ContentEncoding::Br)).service( - web::resource("/").route(web::to(move || HttpResponse::Ok().body(STR))), - ) + App::new() + .wrap(Compress::new(ContentEncoding::Br)) + .service(web::resource("/").route(web::to(move || HttpResponse::Ok().body(STR)))) }); // client request let mut response = srv .get("/") - .header(ACCEPT_ENCODING, "br") + .append_header((ACCEPT_ENCODING, "br")) .no_decompress() .send() .await @@ -450,8 +479,7 @@ async fn test_body_brotli() { async fn test_encoding() { let srv = test::start_with(test::config().h1(), || { App::new().wrap(Compress::default()).service( - web::resource("/") - .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), + web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -462,7 +490,7 @@ async fn test_encoding() { let request = srv .post("/") - .header(CONTENT_ENCODING, "gzip") + .insert_header((CONTENT_ENCODING, "gzip")) .send_body(enc.clone()); let mut response = request.await.unwrap(); assert!(response.status().is_success()); @@ -476,8 +504,7 @@ async fn test_encoding() { async fn test_gzip_encoding() { let srv = test::start_with(test::config().h1(), || { App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), + web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -488,7 +515,7 @@ async fn test_gzip_encoding() { let request = srv .post("/") - .header(CONTENT_ENCODING, "gzip") + .append_header((CONTENT_ENCODING, "gzip")) .send_body(enc.clone()); let mut response = request.await.unwrap(); assert!(response.status().is_success()); @@ -503,8 +530,7 @@ async fn test_gzip_encoding_large() { let data = STR.repeat(10); let srv = test::start_with(test::config().h1(), || { App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), + web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -515,7 +541,7 @@ async fn test_gzip_encoding_large() { let request = srv .post("/") - .header(CONTENT_ENCODING, "gzip") + .append_header((CONTENT_ENCODING, "gzip")) .send_body(enc.clone()); let mut response = request.await.unwrap(); assert!(response.status().is_success()); @@ -535,8 +561,7 @@ async fn test_reading_gzip_encoding_large_random() { let srv = test::start_with(test::config().h1(), || { App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), + web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -547,7 +572,7 @@ async fn test_reading_gzip_encoding_large_random() { let request = srv .post("/") - .header(CONTENT_ENCODING, "gzip") + .append_header((CONTENT_ENCODING, "gzip")) .send_body(enc.clone()); let mut response = request.await.unwrap(); assert!(response.status().is_success()); @@ -562,8 +587,7 @@ async fn test_reading_gzip_encoding_large_random() { async fn test_reading_deflate_encoding() { let srv = test::start_with(test::config().h1(), || { App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), + web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -574,7 +598,7 @@ async fn test_reading_deflate_encoding() { // client request let request = srv .post("/") - .header(CONTENT_ENCODING, "deflate") + .append_header((CONTENT_ENCODING, "deflate")) .send_body(enc.clone()); let mut response = request.await.unwrap(); assert!(response.status().is_success()); @@ -589,8 +613,7 @@ async fn test_reading_deflate_encoding_large() { let data = STR.repeat(10); let srv = test::start_with(test::config().h1(), || { App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), + web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -601,7 +624,7 @@ async fn test_reading_deflate_encoding_large() { // client request let request = srv .post("/") - .header(CONTENT_ENCODING, "deflate") + .append_header((CONTENT_ENCODING, "deflate")) .send_body(enc.clone()); let mut response = request.await.unwrap(); assert!(response.status().is_success()); @@ -621,8 +644,7 @@ async fn test_reading_deflate_encoding_large_random() { let srv = test::start_with(test::config().h1(), || { App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), + web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -633,7 +655,7 @@ async fn test_reading_deflate_encoding_large_random() { // client request let request = srv .post("/") - .header(CONTENT_ENCODING, "deflate") + .append_header((CONTENT_ENCODING, "deflate")) .send_body(enc.clone()); let mut response = request.await.unwrap(); assert!(response.status().is_success()); @@ -648,8 +670,7 @@ async fn test_reading_deflate_encoding_large_random() { async fn test_brotli_encoding() { let srv = test::start_with(test::config().h1(), || { App::new().service( - web::resource("/") - .route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), + web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); @@ -660,7 +681,7 @@ async fn test_brotli_encoding() { // client request let request = srv .post("/") - .header(CONTENT_ENCODING, "br") + .append_header((CONTENT_ENCODING, "br")) .send_body(enc.clone()); let mut response = request.await.unwrap(); assert!(response.status().is_success()); @@ -695,7 +716,7 @@ async fn test_brotli_encoding_large() { // client request let request = srv .post("/") - .header(CONTENT_ENCODING, "br") + .append_header((CONTENT_ENCODING, "br")) .send_body(enc.clone()); let mut response = request.await.unwrap(); assert!(response.status().is_success()); @@ -708,18 +729,8 @@ async fn test_brotli_encoding_large() { #[cfg(feature = "openssl")] #[actix_rt::test] async fn test_brotli_encoding_large_openssl() { - // load ssl keys - use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder - .set_private_key_file("tests/key.pem", SslFiletype::PEM) - .unwrap(); - builder - .set_certificate_chain_file("tests/cert.pem") - .unwrap(); - let data = STR.repeat(10); - let srv = test::start_with(test::config().openssl(builder.build()), move || { + let srv = test::start_with(test::config().openssl(openssl_config()), move || { App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { HttpResponse::Ok() .encoding(actix_web::http::ContentEncoding::Identity) @@ -735,7 +746,7 @@ async fn test_brotli_encoding_large_openssl() { // client request let mut response = srv .post("/") - .header(actix_web::http::header::CONTENT_ENCODING, "br") + .append_header((actix_web::http::header::CONTENT_ENCODING, "br")) .send_body(enc) .await .unwrap(); @@ -747,110 +758,126 @@ async fn test_brotli_encoding_large_openssl() { } #[cfg(all(feature = "rustls", feature = "openssl"))] -#[actix_rt::test] -async fn test_reading_deflate_encoding_large_random_rustls() { - use rust_tls::internal::pemfile::{certs, pkcs8_private_keys}; - use rust_tls::{NoClientAuth, ServerConfig}; - use std::fs::File; +mod plus_rustls { use std::io::BufReader; - let data = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(160_000) - .map(char::from) - .collect::(); + use rustls::{ + internal::pemfile::{certs, pkcs8_private_keys}, + NoClientAuth, ServerConfig as RustlsServerConfig, + }; - // load ssl keys - let mut config = ServerConfig::new(NoClientAuth::new()); - let cert_file = &mut BufReader::new(File::open("tests/cert.pem").unwrap()); - let key_file = &mut BufReader::new(File::open("tests/key.pem").unwrap()); - let cert_chain = certs(cert_file).unwrap(); - let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + use super::*; - let srv = test::start_with(test::config().rustls(config), || { - App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { - HttpResponse::Ok() - .encoding(actix_web::http::ContentEncoding::Identity) - .body(bytes) - }))) - }); + fn rustls_config() -> RustlsServerConfig { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); + let cert_file = cert.serialize_pem().unwrap(); + let key_file = cert.serialize_private_key_pem(); - // encode data - let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); + let mut config = RustlsServerConfig::new(NoClientAuth::new()); + let cert_file = &mut BufReader::new(cert_file.as_bytes()); + let key_file = &mut BufReader::new(key_file.as_bytes()); - // client request - let req = srv - .post("/") - .header(actix_web::http::header::CONTENT_ENCODING, "deflate") - .send_stream(TestBody::new(Bytes::from(enc), 1024)); + let cert_chain = certs(cert_file).unwrap(); + let mut keys = pkcs8_private_keys(key_file).unwrap(); + config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - let mut response = req.await.unwrap(); - assert!(response.status().is_success()); + config + } - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data)); + #[actix_rt::test] + async fn test_reading_deflate_encoding_large_random_rustls() { + use rustls::internal::pemfile::{certs, pkcs8_private_keys}; + use rustls::{NoClientAuth, ServerConfig}; + use std::fs::File; + use std::io::BufReader; + + let data = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(160_000) + .map(char::from) + .collect::(); + + let srv = test::start_with(test::config().rustls(rustls_config()), || { + App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { + HttpResponse::Ok() + .encoding(actix_web::http::ContentEncoding::Identity) + .body(bytes) + }))) + }); + + // encode data + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); + + // client request + let req = srv + .post("/") + .insert_header((actix_web::http::header::CONTENT_ENCODING, "deflate")) + .send_stream(TestBody::new(Bytes::from(enc), 1024)); + + let mut response = req.await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes.len(), data.len()); + assert_eq!(bytes, Bytes::from(data)); + } } -// #[test] -// fn test_server_cookies() { -// use actix_web::http; +#[actix_rt::test] +async fn test_server_cookies() { + use actix_web::{http, HttpMessage}; -// let srv = test::TestServer::with_factory(|| { -// App::new().resource("/", |r| { -// r.f(|_| { -// HttpResponse::Ok() -// .cookie( -// http::CookieBuilder::new("first", "first_value") -// .http_only(true) -// .finish(), -// ) -// .cookie(http::Cookie::new("second", "first_value")) -// .cookie(http::Cookie::new("second", "second_value")) -// .finish() -// }) -// }) -// }); + let srv = test::start(|| { + App::new().default_service(web::to(|| { + HttpResponse::Ok() + .cookie( + http::CookieBuilder::new("first", "first_value") + .http_only(true) + .finish(), + ) + .cookie(http::Cookie::new("second", "first_value")) + .cookie(http::Cookie::new("second", "second_value")) + .finish() + })) + }); -// let first_cookie = http::CookieBuilder::new("first", "first_value") -// .http_only(true) -// .finish(); -// let second_cookie = http::Cookie::new("second", "second_value"); + let req = srv.get("/"); + let res = req.send().await.unwrap(); + assert!(res.status().is_success()); -// let request = srv.get("/").finish().unwrap(); -// let response = srv.execute(request.send()).unwrap(); -// assert!(response.status().is_success()); + let first_cookie = http::CookieBuilder::new("first", "first_value") + .http_only(true) + .finish(); + let second_cookie = http::Cookie::new("second", "second_value"); -// let cookies = response.cookies().expect("To have cookies"); -// assert_eq!(cookies.len(), 2); -// if cookies[0] == first_cookie { -// assert_eq!(cookies[1], second_cookie); -// } else { -// assert_eq!(cookies[0], second_cookie); -// assert_eq!(cookies[1], first_cookie); -// } + let cookies = res.cookies().expect("To have cookies"); + assert_eq!(cookies.len(), 2); + if cookies[0] == first_cookie { + assert_eq!(cookies[1], second_cookie); + } else { + assert_eq!(cookies[0], second_cookie); + assert_eq!(cookies[1], first_cookie); + } -// let first_cookie = first_cookie.to_string(); -// let second_cookie = second_cookie.to_string(); -// //Check that we have exactly two instances of raw cookie headers -// let cookies = response -// .headers() -// .get_all(http::header::SET_COOKIE) -// .iter() -// .map(|header| header.to_str().expect("To str").to_string()) -// .collect::>(); -// assert_eq!(cookies.len(), 2); -// if cookies[0] == first_cookie { -// assert_eq!(cookies[1], second_cookie); -// } else { -// assert_eq!(cookies[0], second_cookie); -// assert_eq!(cookies[1], first_cookie); -// } -// } + let first_cookie = first_cookie.to_string(); + let second_cookie = second_cookie.to_string(); + // Check that we have exactly two instances of raw cookie headers + let cookies = res + .headers() + .get_all(http::header::SET_COOKIE) + .map(|header| header.to_str().expect("To str").to_string()) + .collect::>(); + assert_eq!(cookies.len(), 2); + if cookies[0] == first_cookie { + assert_eq!(cookies[1], second_cookie); + } else { + assert_eq!(cookies[0], second_cookie); + assert_eq!(cookies[1], first_cookie); + } +} #[actix_rt::test] async fn test_slow_request() { @@ -877,36 +904,9 @@ async fn test_normalize() { let srv = test::start_with(test::config().h1(), || { App::new() .wrap(NormalizePath::new(TrailingSlash::Trim)) - .service( - web::resource("/one").route(web::to(|| HttpResponse::Ok().finish())), - ) + .service(web::resource("/one").route(web::to(|| HttpResponse::Ok().finish()))) }); let response = srv.get("/one/").send().await.unwrap(); assert!(response.status().is_success()); } - -// #[cfg(feature = "openssl")] -// #[actix_rt::test] -// async fn test_ssl_handshake_timeout() { -// use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; -// use std::net; - -// // load ssl keys -// let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); -// builder -// .set_private_key_file("tests/key.pem", SslFiletype::PEM) -// .unwrap(); -// builder -// .set_certificate_chain_file("tests/cert.pem") -// .unwrap(); - -// let srv = test::start_with(test::config().openssl(builder.build()), || { -// App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok()))) -// }); - -// let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); -// let mut data = String::new(); -// let _ = stream.read_to_string(&mut data); -// assert!(data.is_empty()); -// } diff --git a/tests/test_weird_poll.rs b/tests/test_weird_poll.rs index 7e4300901..5844ea2c2 100644 --- a/tests/test_weird_poll.rs +++ b/tests/test_weird_poll.rs @@ -1,30 +1,30 @@ -// Regression test for #/1321 +//! Regression test for https://github.com/actix/actix-web/issues/1321 -/* -use futures::task::{noop_waker, Context}; -use futures::stream::once; -use actix_http::body::{MessageBody, BodyStream}; -use bytes::Bytes; +// use actix_http::body::{BodyStream, MessageBody}; +// use bytes::Bytes; +// use futures_channel::oneshot; +// use futures_util::{ +// stream::once, +// task::{noop_waker, Context}, +// }; -Disable weird poll until actix-web is based on actix-http 2.0.0 +// #[test] +// fn weird_poll() { +// let (sender, receiver) = oneshot::channel(); +// let mut body_stream = Ok(BodyStream::new(once(async { +// let x = Box::new(0); +// let y = &x; +// receiver.await.unwrap(); +// let _z = **y; +// Ok::<_, ()>(Bytes::new()) +// }))); -#[test] -fn weird_poll() { - let (sender, receiver) = futures::channel::oneshot::channel(); - let mut body_stream = Ok(BodyStream::new(once(async { - let x = Box::new(0); - let y = &x; - receiver.await.unwrap(); - let _z = **y; - Ok::<_, ()>(Bytes::new()) - }))); +// let waker = noop_waker(); +// let mut cx = Context::from_waker(&waker); - let waker = noop_waker(); - let mut context = Context::from_waker(&waker); - - let _ = body_stream.as_mut().unwrap().poll_next(&mut context); - sender.send(()).unwrap(); - let _ = std::mem::replace(&mut body_stream, Err([0; 32])).unwrap().poll_next(&mut context); -} - -*/ +// let _ = body_stream.as_mut().unwrap().poll_next(&mut cx); +// sender.send(()).unwrap(); +// let _ = std::mem::replace(&mut body_stream, Err([0; 32])) +// .unwrap() +// .poll_next(&mut cx); +// }

{ self.method(), self.path() )?; + if let Some(q) = self.uri().query().as_ref() { writeln!(f, " query: ?{:?}", q)?; } + writeln!(f, " headers:")?; - for (key, val) in self.headers() { + + for (key, val) in self.headers().iter() { writeln!(f, " {:?}: {:?}", key, val)?; } + Ok(()) } } diff --git a/actix-http/src/response.rs b/actix-http/src/response.rs index df2f5be50..f96f8f9b6 100644 --- a/actix-http/src/response.rs +++ b/actix-http/src/response.rs @@ -1,23 +1,32 @@ -//! Http response -use std::cell::{Ref, RefMut}; -use std::convert::TryFrom; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{fmt, str}; +//! HTTP responses. + +use std::{ + cell::{Ref, RefMut}, + convert::TryInto, + fmt, + future::Future, + ops, + pin::Pin, + str, + task::{Context, Poll}, +}; use bytes::{Bytes, BytesMut}; use futures_core::Stream; use serde::Serialize; use crate::body::{Body, BodyStream, MessageBody, ResponseBody}; -use crate::cookie::{Cookie, CookieJar}; use crate::error::Error; use crate::extensions::Extensions; -use crate::header::{Header, IntoHeaderValue}; -use crate::http::header::{self, HeaderName, HeaderValue}; +use crate::header::{IntoHeaderPair, IntoHeaderValue}; +use crate::http::header::{self, HeaderName}; use crate::http::{Error as HttpError, HeaderMap, StatusCode}; use crate::message::{BoxedResponseHead, ConnectionType, ResponseHead}; +#[cfg(feature = "cookies")] +use crate::{ + cookie::{Cookie, CookieJar}, + http::header::HeaderValue, +}; /// An HTTP Response pub struct Response { @@ -27,13 +36,13 @@ pub struct Response { } impl Response { - /// Create http response builder with specific status. + /// Create HTTP response builder with specific status. #[inline] pub fn build(status: StatusCode) -> ResponseBuilder { ResponseBuilder::new(status) } - /// Create http response builder + /// Create HTTP response builder #[inline] pub fn build_from>(source: T) -> ResponseBuilder { source.into() @@ -92,7 +101,7 @@ impl Response { } #[inline] - /// Mutable reference to a http message part of the response + /// Mutable reference to a HTTP message part of the response pub fn head_mut(&mut self) -> &mut ResponseHead { &mut *self.head } @@ -128,6 +137,7 @@ impl Response { } /// Get an iterator for the cookies set by this response + #[cfg(feature = "cookies")] #[inline] pub fn cookies(&self) -> CookieIter<'_> { CookieIter { @@ -136,6 +146,7 @@ impl Response { } /// Add a cookie to this response + #[cfg(feature = "cookies")] #[inline] pub fn add_cookie(&mut self, cookie: &Cookie<'_>) -> Result<(), HttpError> { let h = &mut self.head.headers; @@ -148,6 +159,7 @@ impl Response { /// Remove all cookies with the given name from this response. Returns /// the number of cookies removed. + #[cfg(feature = "cookies")] #[inline] pub fn del_cookie(&mut self, name: &str) -> usize { let h = &mut self.head.headers; @@ -293,10 +305,12 @@ impl Future for Response { } } +#[cfg(feature = "cookies")] pub struct CookieIter<'a> { iter: header::GetAll<'a>, } +#[cfg(feature = "cookies")] impl<'a> Iterator for CookieIter<'a> { type Item = Cookie<'a>; @@ -311,13 +325,13 @@ impl<'a> Iterator for CookieIter<'a> { } } -/// An HTTP response builder +/// An HTTP response builder. /// -/// This type can be used to construct an instance of `Response` through a -/// builder-like pattern. +/// This type can be used to construct an instance of `Response` through a builder-like pattern. pub struct ResponseBuilder { head: Option, err: Option, + #[cfg(feature = "cookies")] cookies: Option, } @@ -328,6 +342,7 @@ impl ResponseBuilder { ResponseBuilder { head: Some(BoxedResponseHead::new(status)), err: None, + #[cfg(feature = "cookies")] cookies: None, } } @@ -341,93 +356,98 @@ impl ResponseBuilder { self } - /// Set a header. + /// Insert a header, replacing any that were set with an equivalent field name. /// /// ```rust - /// use actix_http::{http, Request, Response, Result}; + /// # use actix_http::Response; + /// use actix_http::http::header::ContentType; /// - /// fn index(req: Request) -> Result { - /// Ok(Response::Ok() - /// .set(http::header::IfModifiedSince( - /// "Sun, 07 Nov 1994 08:48:37 GMT".parse()?, - /// )) - /// .finish()) - /// } + /// Response::Ok() + /// .insert_header(ContentType(mime::APPLICATION_JSON)) + /// .insert_header(("X-TEST", "value")) + /// .finish(); /// ``` - #[doc(hidden)] - pub fn set(&mut self, hdr: H) -> &mut Self { + pub fn insert_header(&mut self, header: H) -> &mut Self + where + H: IntoHeaderPair, + { if let Some(parts) = parts(&mut self.head, &self.err) { - match hdr.try_into() { - Ok(value) => { - parts.headers.append(H::name(), value); + match header.try_into_header_pair() { + Ok((key, value)) => { + parts.headers.insert(key, value); } Err(e) => self.err = Some(e.into()), - } + }; } + self } - /// Append a header to existing headers. + /// Append a header, keeping any that were set with an equivalent field name. /// /// ```rust - /// use actix_http::{http, Request, Response}; + /// # use actix_http::Response; + /// use actix_http::http::header::ContentType; /// - /// fn index(req: Request) -> Response { - /// Response::Ok() - /// .header("X-TEST", "value") - /// .header(http::header::CONTENT_TYPE, "application/json") - /// .finish() - /// } + /// Response::Ok() + /// .append_header(ContentType(mime::APPLICATION_JSON)) + /// .append_header(("X-TEST", "value1")) + /// .append_header(("X-TEST", "value2")) + /// .finish(); /// ``` - pub fn header(&mut self, key: K, value: V) -> &mut Self + pub fn append_header(&mut self, header: H) -> &mut Self where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, + H: IntoHeaderPair, { if let Some(parts) = parts(&mut self.head, &self.err) { - match HeaderName::try_from(key) { - Ok(key) => match value.try_into() { - Ok(value) => { - parts.headers.append(key, value); - } - Err(e) => self.err = Some(e.into()), - }, + match header.try_into_header_pair() { + Ok((key, value)) => parts.headers.append(key, value), Err(e) => self.err = Some(e.into()), }; } + self } - /// Set a header. - /// - /// ```rust - /// use actix_http::{http, Request, Response}; - /// - /// fn index(req: Request) -> Response { - /// Response::Ok() - /// .set_header("X-TEST", "value") - /// .set_header(http::header::CONTENT_TYPE, "application/json") - /// .finish() - /// } - /// ``` + /// Replaced with [`Self::insert_header()`]. + #[deprecated = "Replaced with `insert_header((key, value))`."] pub fn set_header(&mut self, key: K, value: V) -> &mut Self where - HeaderName: TryFrom, - >::Error: Into, + K: TryInto, + K::Error: Into, V: IntoHeaderValue, { - if let Some(parts) = parts(&mut self.head, &self.err) { - match HeaderName::try_from(key) { - Ok(key) => match value.try_into() { - Ok(value) => { - parts.headers.insert(key, value); - } - Err(e) => self.err = Some(e.into()), - }, - Err(e) => self.err = Some(e.into()), - }; + if self.err.is_some() { + return self; } + + match (key.try_into(), value.try_into_value()) { + (Ok(name), Ok(value)) => return self.insert_header((name, value)), + (Err(err), _) => self.err = Some(err.into()), + (_, Err(err)) => self.err = Some(err.into()), + } + + self + } + + /// Replaced with [`Self::append_header()`]. + #[deprecated = "Replaced with `append_header((key, value))`."] + pub fn header(&mut self, key: K, value: V) -> &mut Self + where + K: TryInto, + K::Error: Into, + V: IntoHeaderValue, + { + if self.err.is_some() { + return self; + } + + match (key.try_into(), value.try_into_value()) { + (Ok(name), Ok(value)) => return self.append_header((name, value)), + (Err(err), _) => self.err = Some(err.into()), + (_, Err(err)) => self.err = Some(err.into()), + } + self } @@ -458,7 +478,12 @@ impl ResponseBuilder { if let Some(parts) = parts(&mut self.head, &self.err) { parts.set_connection_type(ConnectionType::Upgrade); } - self.set_header(header::UPGRADE, value) + + if let Ok(value) = value.try_into_value() { + self.insert_header((header::UPGRADE, value)); + } + + self } /// Force close connection, even if it is marked as keep-alive @@ -473,7 +498,7 @@ impl ResponseBuilder { /// Disable chunked transfer encoding for HTTP/1.1 streaming responses. #[inline] pub fn no_chunking(&mut self, len: u64) -> &mut Self { - self.header(header::CONTENT_LENGTH, len); + self.insert_header((header::CONTENT_LENGTH, len)); if let Some(parts) = parts(&mut self.head, &self.err) { parts.no_chunking(true); @@ -481,15 +506,14 @@ impl ResponseBuilder { self } - /// Set response content type + /// Set response content type. #[inline] pub fn content_type(&mut self, value: V) -> &mut Self where - HeaderValue: TryFrom, - >::Error: Into, + V: IntoHeaderValue, { if let Some(parts) = parts(&mut self.head, &self.err) { - match HeaderValue::try_from(value) { + match value.try_into_value() { Ok(value) => { parts.headers.insert(header::CONTENT_TYPE, value); } @@ -517,6 +541,7 @@ impl ResponseBuilder { /// .finish() /// } /// ``` + #[cfg(feature = "cookies")] pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { if self.cookies.is_none() { let mut jar = CookieJar::new(); @@ -543,6 +568,7 @@ impl ResponseBuilder { /// builder.finish() /// } /// ``` + #[cfg(feature = "cookies")] pub fn del_cookie<'a>(&mut self, cookie: &Cookie<'a>) -> &mut Self { if self.cookies.is_none() { self.cookies = Some(CookieJar::new()) @@ -610,8 +636,11 @@ impl ResponseBuilder { return Response::from(Error::from(e)).into_body(); } + // allow unused mut when cookies feature is disabled + #[allow(unused_mut)] let mut response = self.head.take().expect("cannot reuse response builder"); + #[cfg(feature = "cookies")] if let Some(ref jar) = self.cookies { for cookie in jar.delta() { match HeaderValue::from_str(&cookie.to_string()) { @@ -640,27 +669,24 @@ impl ResponseBuilder { self.body(Body::from_message(BodyStream::new(stream))) } - #[inline] /// Set a json body and generate `Response` /// /// `ResponseBuilder` can not be used after this call. - pub fn json(&mut self, value: T) -> Response { - self.json2(&value) - } - - /// Set a json body and generate `Response` - /// - /// `ResponseBuilder` can not be used after this call. - pub fn json2(&mut self, value: &T) -> Response { - match serde_json::to_string(value) { + pub fn json(&mut self, value: T) -> Response + where + T: ops::Deref, + T::Target: Serialize, + { + match serde_json::to_string(&*value) { Ok(body) => { let contains = if let Some(parts) = parts(&mut self.head, &self.err) { parts.headers.contains_key(header::CONTENT_TYPE) } else { true }; + if !contains { - self.header(header::CONTENT_TYPE, "application/json"); + self.insert_header(header::ContentType(mime::APPLICATION_JSON)); } self.body(Body::from(body)) @@ -682,6 +708,7 @@ impl ResponseBuilder { ResponseBuilder { head: self.head.take(), err: self.err.take(), + #[cfg(feature = "cookies")] cookies: self.cookies.take(), } } @@ -701,21 +728,28 @@ fn parts<'a>( /// Convert `Response` to a `ResponseBuilder`. Body get dropped. impl From> for ResponseBuilder { fn from(res: Response) -> ResponseBuilder { - // If this response has cookies, load them into a jar - let mut jar: Option = None; - for c in res.cookies() { - if let Some(ref mut j) = jar { - j.add_original(c.into_owned()); - } else { - let mut j = CookieJar::new(); - j.add_original(c.into_owned()); - jar = Some(j); + #[cfg(feature = "cookies")] + let jar = { + // If this response has cookies, load them into a jar + let mut jar: Option = None; + + for c in res.cookies() { + if let Some(ref mut j) = jar { + j.add_original(c.into_owned()); + } else { + let mut j = CookieJar::new(); + j.add_original(c.into_owned()); + jar = Some(j); + } } - } + + jar + }; ResponseBuilder { head: Some(res.head), err: None, + #[cfg(feature = "cookies")] cookies: jar, } } @@ -724,33 +758,42 @@ impl From> for ResponseBuilder { /// Convert `ResponseHead` to a `ResponseBuilder` impl<'a> From<&'a ResponseHead> for ResponseBuilder { fn from(head: &'a ResponseHead) -> ResponseBuilder { - // If this response has cookies, load them into a jar - let mut jar: Option = None; - - let cookies = CookieIter { - iter: head.headers.get_all(header::SET_COOKIE), - }; - for c in cookies { - if let Some(ref mut j) = jar { - j.add_original(c.into_owned()); - } else { - let mut j = CookieJar::new(); - j.add_original(c.into_owned()); - jar = Some(j); - } - } - let mut msg = BoxedResponseHead::new(head.status); msg.version = head.version; msg.reason = head.reason; - for (k, v) in &head.headers { + + for (k, v) in head.headers.iter() { msg.headers.append(k.clone(), v.clone()); } + msg.no_chunking(!head.chunked()); + #[cfg(feature = "cookies")] + let jar = { + // If this response has cookies, load them into a jar + let mut jar: Option = None; + + let cookies = CookieIter { + iter: head.headers.get_all(header::SET_COOKIE), + }; + + for c in cookies { + if let Some(ref mut j) = jar { + j.add_original(c.into_owned()); + } else { + let mut j = CookieJar::new(); + j.add_original(c.into_owned()); + jar = Some(j); + } + } + + jar + }; + ResponseBuilder { head: Some(msg), err: None, + #[cfg(feature = "cookies")] cookies: jar, } } @@ -802,7 +845,7 @@ impl From for Response { impl From<&'static str> for Response { fn from(val: &'static str) -> Self { Response::Ok() - .content_type("text/plain; charset=utf-8") + .content_type(mime::TEXT_PLAIN_UTF_8) .body(val) } } @@ -810,7 +853,7 @@ impl From<&'static str> for Response { impl From<&'static [u8]> for Response { fn from(val: &'static [u8]) -> Self { Response::Ok() - .content_type("application/octet-stream") + .content_type(mime::APPLICATION_OCTET_STREAM) .body(val) } } @@ -818,7 +861,7 @@ impl From<&'static [u8]> for Response { impl From for Response { fn from(val: String) -> Self { Response::Ok() - .content_type("text/plain; charset=utf-8") + .content_type(mime::TEXT_PLAIN_UTF_8) .body(val) } } @@ -826,7 +869,7 @@ impl From for Response { impl<'a> From<&'a String> for Response { fn from(val: &'a String) -> Self { Response::Ok() - .content_type("text/plain; charset=utf-8") + .content_type(mime::TEXT_PLAIN_UTF_8) .body(val) } } @@ -834,7 +877,7 @@ impl<'a> From<&'a String> for Response { impl From for Response { fn from(val: Bytes) -> Self { Response::Ok() - .content_type("application/octet-stream") + .content_type(mime::APPLICATION_OCTET_STREAM) .body(val) } } @@ -842,34 +885,37 @@ impl From for Response { impl From for Response { fn from(val: BytesMut) -> Self { Response::Ok() - .content_type("application/octet-stream") + .content_type(mime::APPLICATION_OCTET_STREAM) .body(val) } } #[cfg(test)] mod tests { + use serde_json::json; + use super::*; use crate::body::Body; - use crate::http::header::{HeaderValue, CONTENT_TYPE, COOKIE, SET_COOKIE}; + use crate::http::header::{HeaderValue, CONTENT_TYPE, COOKIE}; + #[cfg(feature = "cookies")] + use crate::{http::header::SET_COOKIE, HttpMessage}; #[test] fn test_debug() { let resp = Response::Ok() - .header(COOKIE, HeaderValue::from_static("cookie1=value1; ")) - .header(COOKIE, HeaderValue::from_static("cookie2=value2; ")) + .append_header((COOKIE, HeaderValue::from_static("cookie1=value1; "))) + .append_header((COOKIE, HeaderValue::from_static("cookie2=value2; "))) .finish(); let dbg = format!("{:?}", resp); assert!(dbg.contains("Response")); } + #[cfg(feature = "cookies")] #[test] fn test_response_cookies() { - use crate::httpmessage::HttpMessage; - let req = crate::test::TestRequest::default() - .header(COOKIE, "cookie1=value1") - .header(COOKIE, "cookie2=value2") + .append_header((COOKIE, "cookie1=value1")) + .append_header((COOKIE, "cookie2=value2")) .finish(); let cookies = req.cookies().unwrap(); @@ -882,22 +928,27 @@ mod tests { .max_age(time::Duration::days(1)) .finish(), ) - .del_cookie(&cookies[1]) + .del_cookie(&cookies[0]) .finish(); - let mut val: Vec<_> = resp + let mut val = resp .headers() .get_all(SET_COOKIE) .map(|v| v.to_str().unwrap().to_owned()) - .collect(); + .collect::>(); val.sort(); + + // the .del_cookie call assert!(val[0].starts_with("cookie1=; Max-Age=0;")); + + // the .cookie call assert_eq!( val[1], "name=value; HttpOnly; Path=/test; Domain=www.rust-lang.org; Max-Age=86400" ); } + #[cfg(feature = "cookies")] #[test] fn test_update_response_cookies() { let mut r = Response::Ok() @@ -916,14 +967,14 @@ mod tests { let mut iter = r.cookies(); let v = iter.next().unwrap(); - assert_eq!((v.name(), v.value()), ("cookie3", "val300")); - let v = iter.next().unwrap(); assert_eq!((v.name(), v.value()), ("original", "val100")); + let v = iter.next().unwrap(); + assert_eq!((v.name(), v.value()), ("cookie3", "val300")); } #[test] fn test_basic_builder() { - let resp = Response::Ok().header("X-TEST", "value").finish(); + let resp = Response::Ok().insert_header(("X-TEST", "value")).finish(); assert_eq!(resp.status(), StatusCode::OK); } @@ -964,26 +1015,8 @@ mod tests { #[test] fn test_json_ct() { let resp = Response::build(StatusCode::OK) - .header(CONTENT_TYPE, "text/json") - .json(vec!["v1", "v2", "v3"]); - let ct = resp.headers().get(CONTENT_TYPE).unwrap(); - assert_eq!(ct, HeaderValue::from_static("text/json")); - assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); - } - - #[test] - fn test_json2() { - let resp = Response::build(StatusCode::OK).json2(&vec!["v1", "v2", "v3"]); - let ct = resp.headers().get(CONTENT_TYPE).unwrap(); - assert_eq!(ct, HeaderValue::from_static("application/json")); - assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); - } - - #[test] - fn test_json2_ct() { - let resp = Response::build(StatusCode::OK) - .header(CONTENT_TYPE, "text/json") - .json2(&vec!["v1", "v2", "v3"]); + .insert_header((CONTENT_TYPE, "text/json")) + .json(&vec!["v1", "v2", "v3"]); let ct = resp.headers().get(CONTENT_TYPE).unwrap(); assert_eq!(ct, HeaderValue::from_static("text/json")); assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); @@ -1067,6 +1100,7 @@ mod tests { assert_eq!(resp.body().get_ref(), b"test"); } + #[cfg(feature = "cookies")] #[test] fn test_into_builder() { let mut resp: Response = "test".into(); @@ -1082,4 +1116,54 @@ mod tests { let cookie = resp.cookies().next().unwrap(); assert_eq!((cookie.name(), cookie.value()), ("cookie1", "val100")); } + + #[test] + fn response_builder_header_insert_kv() { + let mut res = Response::Ok(); + res.insert_header(("Content-Type", "application/octet-stream")); + let res = res.finish(); + + assert_eq!( + res.headers().get("Content-Type"), + Some(&HeaderValue::from_static("application/octet-stream")) + ); + } + + #[test] + fn response_builder_header_insert_typed() { + let mut res = Response::Ok(); + res.insert_header(header::ContentType(mime::APPLICATION_OCTET_STREAM)); + let res = res.finish(); + + assert_eq!( + res.headers().get("Content-Type"), + Some(&HeaderValue::from_static("application/octet-stream")) + ); + } + + #[test] + fn response_builder_header_append_kv() { + let mut res = Response::Ok(); + res.append_header(("Content-Type", "application/octet-stream")); + res.append_header(("Content-Type", "application/json")); + let res = res.finish(); + + let headers: Vec<_> = res.headers().get_all("Content-Type").cloned().collect(); + assert_eq!(headers.len(), 2); + assert!(headers.contains(&HeaderValue::from_static("application/octet-stream"))); + assert!(headers.contains(&HeaderValue::from_static("application/json"))); + } + + #[test] + fn response_builder_header_append_typed() { + let mut res = Response::Ok(); + res.append_header(header::ContentType(mime::APPLICATION_OCTET_STREAM)); + res.append_header(header::ContentType(mime::APPLICATION_JSON)); + let res = res.finish(); + + let headers: Vec<_> = res.headers().get_all("Content-Type").cloned().collect(); + assert_eq!(headers.len(), 2); + assert!(headers.contains(&HeaderValue::from_static("application/octet-stream"))); + assert!(headers.contains(&HeaderValue::from_static("application/json"))); + } } diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index 527ed3833..fee26dcc3 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -8,36 +8,34 @@ use actix_rt::net::TcpStream; use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; use bytes::Bytes; use futures_core::{ready, Future}; -use futures_util::future::ok; use h2::server::{self, Handshake}; use pin_project::pin_project; use crate::body::MessageBody; use crate::builder::HttpServiceBuilder; -use crate::cloneable::CloneableService; use crate::config::{KeepAlive, ServiceConfig}; use crate::error::{DispatchError, Error}; use crate::request::Request; use crate::response::Response; -use crate::{h1, h2::Dispatcher, ConnectCallback, Extensions, Protocol}; +use crate::{h1, h2::Dispatcher, ConnectCallback, OnConnectData, Protocol}; /// A `ServiceFactory` for HTTP/1.1 or HTTP/2 protocol. -pub struct HttpService> { +pub struct HttpService { srv: S, cfg: ServiceConfig, expect: X, upgrade: Option, on_connect_ext: Option>>, - _t: PhantomData<(T, B)>, + _phantom: PhantomData, } impl HttpService where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, { /// Create builder for `HttpService` instance. @@ -48,15 +46,15 @@ where impl HttpService where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, { /// Create new `HttpService` instance. - pub fn new>(service: F) -> Self { + pub fn new>(service: F) -> Self { let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0, false, None); HttpService { @@ -65,12 +63,12 @@ where expect: h1::ExpectHandler, upgrade: None, on_connect_ext: None, - _t: PhantomData, + _phantom: PhantomData, } } /// Create new `HttpService` instance with config. - pub(crate) fn with_config>( + pub(crate) fn with_config>( cfg: ServiceConfig, service: F, ) -> Self { @@ -80,18 +78,18 @@ where expect: h1::ExpectHandler, upgrade: None, on_connect_ext: None, - _t: PhantomData, + _phantom: PhantomData, } } } impl HttpService where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody, { /// Provide service for `EXPECT: 100-Continue` support. @@ -101,10 +99,10 @@ where /// request will be forwarded to main service. pub fn expect(self, expect: X1) -> HttpService where - X1: ServiceFactory, + X1: ServiceFactory, X1::Error: Into, X1::InitError: fmt::Debug, - ::Future: 'static, + >::Future: 'static, { HttpService { expect, @@ -112,7 +110,7 @@ where srv: self.srv, upgrade: self.upgrade, on_connect_ext: self.on_connect_ext, - _t: PhantomData, + _phantom: PhantomData, } } @@ -122,14 +120,10 @@ where /// and this service get called with original request and framed object. pub fn upgrade(self, upgrade: Option) -> HttpService where - U1: ServiceFactory< - Config = (), - Request = (Request, Framed), - Response = (), - >, + U1: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U1::Error: fmt::Display, U1::InitError: fmt::Debug, - ::Future: 'static, + )>>::Future: 'static, { HttpService { upgrade, @@ -137,7 +131,7 @@ where srv: self.srv, expect: self.expect, on_connect_ext: self.on_connect_ext, - _t: PhantomData, + _phantom: PhantomData, } } @@ -150,38 +144,38 @@ where impl HttpService where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, - X: ServiceFactory, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - ::Future: 'static, + >::Future: 'static, U: ServiceFactory< + (Request, Framed), Config = (), - Request = (Request, Framed), Response = (), >, U::Error: fmt::Display + Into, U::InitError: fmt::Debug, - ::Future: 'static, + )>>::Future: 'static, { /// Create simple tcp stream service pub fn tcp( self, ) -> impl ServiceFactory< + TcpStream, Config = (), - Request = TcpStream, Response = (), Error = DispatchError, InitError = (), > { - pipeline_factory(|io: TcpStream| { + pipeline_factory(|io: TcpStream| async { let peer_addr = io.peer_addr().ok(); - ok((io, Protocol::Http1, peer_addr)) + Ok((io, Protocol::Http1, peer_addr)) }) .and_then(self) } @@ -190,39 +184,40 @@ where #[cfg(feature = "openssl")] mod openssl { use super::*; - use actix_tls::openssl::{Acceptor, SslAcceptor, SslStream}; - use actix_tls::{openssl::HandshakeError, TlsError}; + use actix_service::ServiceFactoryExt; + use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; + use actix_tls::accept::TlsError; impl HttpService, S, B, X, U> where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, - X: ServiceFactory, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - ::Future: 'static, + >::Future: 'static, U: ServiceFactory< + (Request, Framed, h1::Codec>), Config = (), - Request = (Request, Framed, h1::Codec>), Response = (), >, U::Error: fmt::Display + Into, U::InitError: fmt::Debug, - ::Future: 'static, + , h1::Codec>)>>::Future: 'static, { /// Create openssl based service pub fn openssl( self, acceptor: SslAcceptor, ) -> impl ServiceFactory< + TcpStream, Config = (), - Request = TcpStream, Response = (), - Error = TlsError, DispatchError>, + Error = TlsError, InitError = (), > { pipeline_factory( @@ -230,7 +225,7 @@ mod openssl { .map_err(TlsError::Tls) .map_init_err(|_| panic!()), ) - .and_then(|io: SslStream| { + .and_then(|io: SslStream| async { let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() { if protos.windows(2).any(|window| window == b"h2") { Protocol::Http2 @@ -241,7 +236,7 @@ mod openssl { Protocol::Http1 }; let peer_addr = io.get_ref().peer_addr().ok(); - ok((io, proto, peer_addr)) + Ok((io, proto, peer_addr)) }) .and_then(self.map_err(TlsError::Service)) } @@ -250,39 +245,42 @@ mod openssl { #[cfg(feature = "rustls")] mod rustls { - use super::*; - use actix_tls::rustls::{Acceptor, ServerConfig, Session, TlsStream}; - use actix_tls::TlsError; use std::io; + use actix_tls::accept::rustls::{Acceptor, ServerConfig, Session, TlsStream}; + use actix_tls::accept::TlsError; + + use super::*; + use actix_service::ServiceFactoryExt; + impl HttpService, S, B, X, U> where - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, - X: ServiceFactory, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - ::Future: 'static, + >::Future: 'static, U: ServiceFactory< + (Request, Framed, h1::Codec>), Config = (), - Request = (Request, Framed, h1::Codec>), Response = (), >, U::Error: fmt::Display + Into, U::InitError: fmt::Debug, - ::Future: 'static, + , h1::Codec>)>>::Future: 'static, { /// Create openssl based service pub fn rustls( self, mut config: ServerConfig, ) -> impl ServiceFactory< + TcpStream, Config = (), - Request = TcpStream, Response = (), Error = TlsError, InitError = (), @@ -295,7 +293,7 @@ mod rustls { .map_err(TlsError::Tls) .map_init_err(|_| panic!()), ) - .and_then(|io: TlsStream| { + .and_then(|io: TlsStream| async { let proto = if let Some(protos) = io.get_ref().1.get_alpn_protocol() { if protos.windows(2).any(|window| window == b"h2") { Protocol::Http2 @@ -306,41 +304,37 @@ mod rustls { Protocol::Http1 }; let peer_addr = io.get_ref().0.peer_addr().ok(); - ok((io, proto, peer_addr)) + Ok((io, proto, peer_addr)) }) .and_then(self.map_err(TlsError::Service)) } } } -impl ServiceFactory for HttpService +impl ServiceFactory<(T, Protocol, Option)> + for HttpService where T: AsyncRead + AsyncWrite + Unpin, - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, - X: ServiceFactory, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - ::Future: 'static, - U: ServiceFactory< - Config = (), - Request = (Request, Framed), - Response = (), - >, + >::Future: 'static, + U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U::Error: fmt::Display + Into, U::InitError: fmt::Debug, - ::Future: 'static, + )>>::Future: 'static, { - type Config = (); - type Request = (T, Protocol, Option); type Response = (); type Error = DispatchError; - type InitError = (); + type Config = (); type Service = HttpServiceHandler; + type InitError = (); type Future = HttpServiceResponse; fn new_service(&self, _: ()) -> Self::Future { @@ -352,20 +346,19 @@ where upgrade: None, on_connect_ext: self.on_connect_ext.clone(), cfg: self.cfg.clone(), - _t: PhantomData, + _phantom: PhantomData, } } } #[doc(hidden)] #[pin_project] -pub struct HttpServiceResponse< - T, - S: ServiceFactory, - B, - X: ServiceFactory, - U: ServiceFactory, -> { +pub struct HttpServiceResponse +where + S: ServiceFactory, + X: ServiceFactory, + U: ServiceFactory<(Request, Framed)>, +{ #[pin] fut: S::Future, #[pin] @@ -376,26 +369,26 @@ pub struct HttpServiceResponse< upgrade: Option, on_connect_ext: Option>>, cfg: ServiceConfig, - _t: PhantomData<(T, B)>, + _phantom: PhantomData, } impl Future for HttpServiceResponse where T: AsyncRead + AsyncWrite + Unpin, - S: ServiceFactory, + S: ServiceFactory, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, - X: ServiceFactory, + X: ServiceFactory, X::Error: Into, X::InitError: fmt::Debug, - ::Future: 'static, - U: ServiceFactory), Response = ()>, + >::Future: 'static, + U: ServiceFactory<(Request, Framed), Response = ()>, U::Error: fmt::Display, U::InitError: fmt::Debug, - ::Future: 'static, + )>>::Future: 'static, { type Output = Result, ()>; @@ -418,7 +411,7 @@ where .map_err(|e| log::error!("Init http service error: {:?}", e)))?; this = self.as_mut().project(); *this.upgrade = Some(upgrade); - this.fut_ex.set(None); + this.fut_upg.set(None); } let result = ready!(this @@ -439,31 +432,51 @@ where } } -/// `Service` implementation for http transport -pub struct HttpServiceHandler { - srv: CloneableService, - expect: CloneableService, - upgrade: Option>, +/// `Service` implementation for HTTP transport +pub struct HttpServiceHandler +where + S: Service, + X: Service, + U: Service<(Request, Framed)>, +{ + flow: Rc>, cfg: ServiceConfig, on_connect_ext: Option>>, - _t: PhantomData<(T, B, X)>, + _phantom: PhantomData, +} + +/// A collection of services that describe an HTTP request flow. +pub(super) struct HttpFlow { + pub(super) service: S, + pub(super) expect: X, + pub(super) upgrade: Option, +} + +impl HttpFlow { + pub(super) fn new(service: S, expect: X, upgrade: Option) -> Rc { + Rc::new(Self { + service, + expect, + upgrade, + }) + } } impl HttpServiceHandler where - S: Service, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { fn new( cfg: ServiceConfig, - srv: S, + service: S, expect: X, upgrade: Option, on_connect_ext: Option>>, @@ -471,34 +484,33 @@ where HttpServiceHandler { cfg, on_connect_ext, - srv: CloneableService::new(srv), - expect: CloneableService::new(expect), - upgrade: upgrade.map(CloneableService::new), - _t: PhantomData, + flow: HttpFlow::new(service, expect, upgrade), + _phantom: PhantomData, } } } -impl Service for HttpServiceHandler +impl Service<(T, Protocol, Option)> + for HttpServiceHandler where T: AsyncRead + AsyncWrite + Unpin, - S: Service, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display + Into, { - type Request = (T, Protocol, Option); type Response = (); type Error = DispatchError; type Future = HttpServiceHandlerResponse; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { let ready = self + .flow .expect .poll_ready(cx) .map_err(|e| { @@ -509,7 +521,8 @@ where .is_ready(); let ready = self - .srv + .flow + .service .poll_ready(cx) .map_err(|e| { let e = e.into(); @@ -519,7 +532,7 @@ where .is_ready() && ready; - let ready = if let Some(ref mut upg) = self.upgrade { + let ready = if let Some(ref upg) = self.flow.upgrade { upg.poll_ready(cx) .map_err(|e| { let e = e.into(); @@ -539,20 +552,20 @@ where } } - fn call(&mut self, (io, proto, peer_addr): Self::Request) -> Self::Future { - let mut connect_extensions = Extensions::new(); - - if let Some(ref handler) = self.on_connect_ext { - handler(&io, &mut connect_extensions); - } + fn call( + &self, + (io, proto, peer_addr): (T, Protocol, Option), + ) -> Self::Future { + let on_connect_data = + OnConnectData::from_io(&io, self.on_connect_ext.as_deref()); match proto { Protocol::Http2 => HttpServiceHandlerResponse { state: State::H2Handshake(Some(( server::handshake(io), self.cfg.clone(), - self.srv.clone(), - connect_extensions, + self.flow.clone(), + on_connect_data, peer_addr, ))), }, @@ -561,13 +574,13 @@ where state: State::H1(h1::Dispatcher::new( io, self.cfg.clone(), - self.srv.clone(), - self.expect.clone(), - self.upgrade.clone(), - connect_extensions, + self.flow.clone(), + on_connect_data, peer_addr, )), }, + + proto => unimplemented!("Unsupported HTTP version: {:?}.", proto), } } } @@ -575,24 +588,24 @@ where #[pin_project(project = StateProj)] enum State where - S: Service, + S: Service, S::Future: 'static, S::Error: Into, T: AsyncRead + AsyncWrite + Unpin, B: MessageBody, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { H1(#[pin] h1::Dispatcher), - H2(#[pin] Dispatcher), + H2(#[pin] Dispatcher), H2Handshake( Option<( Handshake, ServiceConfig, - CloneableService, - Extensions, + Rc>, + OnConnectData, Option, )>, ), @@ -602,14 +615,14 @@ where pub struct HttpServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, - S: Service, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { #[pin] @@ -619,65 +632,42 @@ where impl Future for HttpServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, - S: Service, + S: Service, S::Error: Into + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody, - X: Service, + X: Service, X::Error: Into, - U: Service), Response = ()>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { type Output = Result<(), DispatchError>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().state.poll(cx) - } -} - -impl State -where - T: AsyncRead + AsyncWrite + Unpin, - S: Service, - S::Error: Into + 'static, - S::Response: Into> + 'static, - B: MessageBody + 'static, - X: Service, - X::Error: Into, - U: Service), Response = ()>, - U::Error: fmt::Display, -{ - fn poll( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.as_mut().project() { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project().state.project() { StateProj::H1(disp) => disp.poll(cx), StateProj::H2(disp) => disp.poll(cx), - StateProj::H2Handshake(ref mut data) => { - let conn = if let Some(ref mut item) = data { - match Pin::new(&mut item.0).poll(cx) { - Poll::Ready(Ok(conn)) => conn, - Poll::Ready(Err(err)) => { - trace!("H2 handshake error: {}", err); - return Poll::Ready(Err(err.into())); - } - Poll::Pending => return Poll::Pending, + StateProj::H2Handshake(data) => { + match ready!(Pin::new(&mut data.as_mut().unwrap().0).poll(cx)) { + Ok(conn) => { + let (_, cfg, srv, on_connect_data, peer_addr) = + data.take().unwrap(); + self.as_mut().project().state.set(State::H2(Dispatcher::new( + srv, + conn, + on_connect_data, + cfg, + None, + peer_addr, + ))); + self.poll(cx) } - } else { - panic!() - }; - let (_, cfg, srv, on_connect_data, peer_addr) = data.take().unwrap(); - self.set(State::H2(Dispatcher::new( - srv, - conn, - on_connect_data, - cfg, - None, - peer_addr, - ))); - self.poll(cx) + Err(err) => { + trace!("H2 handshake error: {}", err); + Poll::Ready(Err(err.into())) + } + } } } } diff --git a/actix-http/src/test.rs b/actix-http/src/test.rs index 4512e72c2..870a656df 100644 --- a/actix-http/src/test.rs +++ b/actix-http/src/test.rs @@ -2,7 +2,6 @@ use std::{ cell::{Ref, RefCell}, - convert::TryFrom, io::{self, Read, Write}, pin::Pin, rc::Rc, @@ -10,16 +9,20 @@ use std::{ task::{Context, Poll}, }; -use actix_codec::{AsyncRead, AsyncWrite}; +use actix_codec::{AsyncRead, AsyncWrite, ReadBuf}; use bytes::{Bytes, BytesMut}; -use http::header::{self, HeaderName, HeaderValue}; -use http::{Error as HttpError, Method, Uri, Version}; +use http::{Method, Uri, Version}; -use crate::cookie::{Cookie, CookieJar}; -use crate::header::HeaderMap; -use crate::header::{Header, IntoHeaderValue}; -use crate::payload::Payload; -use crate::Request; +#[cfg(feature = "cookies")] +use crate::{ + cookie::{Cookie, CookieJar}, + header::{self, HeaderValue}, +}; +use crate::{ + header::{HeaderMap, IntoHeaderPair}, + payload::Payload, + Request, +}; /// Test `Request` builder /// @@ -36,7 +39,7 @@ use crate::Request; /// } /// } /// -/// let resp = TestRequest::with_header("content-type", "text/plain") +/// let resp = TestRequest::default().insert_header("content-type", "text/plain") /// .run(&index) /// .unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); @@ -51,6 +54,7 @@ struct Inner { method: Method, uri: Uri, headers: HeaderMap, + #[cfg(feature = "cookies")] cookies: CookieJar, payload: Option, } @@ -62,6 +66,7 @@ impl Default for TestRequest { uri: Uri::from_str("/").unwrap(), version: Version::HTTP_11, headers: HeaderMap::new(), + #[cfg(feature = "cookies")] cookies: CookieJar::new(), payload: None, })) @@ -69,76 +74,74 @@ impl Default for TestRequest { } impl TestRequest { - /// Create TestRequest and set request uri + /// Create a default TestRequest and then set its URI. pub fn with_uri(path: &str) -> TestRequest { TestRequest::default().uri(path).take() } - /// Create TestRequest and set header - pub fn with_hdr(hdr: H) -> TestRequest { - TestRequest::default().set(hdr).take() - } - - /// Create TestRequest and set header - pub fn with_header(key: K, value: V) -> TestRequest - where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, - { - TestRequest::default().header(key, value).take() - } - - /// Set HTTP version of this request + /// Set HTTP version of this request. pub fn version(&mut self, ver: Version) -> &mut Self { parts(&mut self.0).version = ver; self } - /// Set HTTP method of this request + /// Set HTTP method of this request. pub fn method(&mut self, meth: Method) -> &mut Self { parts(&mut self.0).method = meth; self } - /// Set HTTP Uri of this request + /// Set URI of this request. + /// + /// # Panics + /// If provided URI is invalid. pub fn uri(&mut self, path: &str) -> &mut Self { parts(&mut self.0).uri = Uri::from_str(path).unwrap(); self } - /// Set a header - pub fn set(&mut self, hdr: H) -> &mut Self { - if let Ok(value) = hdr.try_into() { - parts(&mut self.0).headers.append(H::name(), value); - return self; - } - panic!("Can not set header"); - } - - /// Set a header - pub fn header(&mut self, key: K, value: V) -> &mut Self + /// Insert a header, replacing any that were set with an equivalent field name. + pub fn insert_header(&mut self, header: H) -> &mut Self where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, + H: IntoHeaderPair, { - if let Ok(key) = HeaderName::try_from(key) { - if let Ok(value) = value.try_into() { - parts(&mut self.0).headers.append(key, value); - return self; + match header.try_into_header_pair() { + Ok((key, value)) => { + parts(&mut self.0).headers.insert(key, value); + } + Err(err) => { + panic!("Error inserting test header: {}.", err.into()); } } - panic!("Can not create header"); + + self } - /// Set cookie for this request + /// Append a header, keeping any that were set with an equivalent field name. + pub fn append_header(&mut self, header: H) -> &mut Self + where + H: IntoHeaderPair, + { + match header.try_into_header_pair() { + Ok((key, value)) => { + parts(&mut self.0).headers.append(key, value); + } + Err(err) => { + panic!("Error inserting test header: {}.", err.into()); + } + } + + self + } + + /// Set cookie for this request. + #[cfg(feature = "cookies")] pub fn cookie<'a>(&mut self, cookie: Cookie<'a>) -> &mut Self { parts(&mut self.0).cookies.add(cookie.into_owned()); self } - /// Set request payload + /// Set request payload. pub fn set_payload>(&mut self, data: B) -> &mut Self { let mut payload = crate::h1::Payload::empty(); payload.unread_data(data.into()); @@ -150,7 +153,7 @@ impl TestRequest { TestRequest(self.0.take()) } - /// Complete request creation and generate `Request` instance + /// Complete request creation and generate `Request` instance. pub fn finish(&mut self) -> Request { let inner = self.0.take().expect("cannot reuse test request builder"); @@ -166,17 +169,20 @@ impl TestRequest { head.version = inner.version; head.headers = inner.headers; - let cookie: String = inner - .cookies - .delta() - // ensure only name=value is written to cookie header - .map(|c| Cookie::new(c.name(), c.value()).encoded().to_string()) - .collect::>() - .join("; "); + #[cfg(feature = "cookies")] + { + let cookie: String = inner + .cookies + .delta() + // ensure only name=value is written to cookie header + .map(|c| Cookie::new(c.name(), c.value()).encoded().to_string()) + .collect::>() + .join("; "); - if !cookie.is_empty() { - head.headers - .insert(header::COOKIE, HeaderValue::from_str(&cookie).unwrap()); + if !cookie.is_empty() { + head.headers + .insert(header::COOKIE, HeaderValue::from_str(&cookie).unwrap()); + } } req @@ -251,9 +257,11 @@ impl AsyncRead for TestBuffer { fn poll_read( self: Pin<&mut Self>, _: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - Poll::Ready(self.get_mut().read(buf)) + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let dst = buf.initialize_unfilled(); + let res = self.get_mut().read(dst).map(|n| buf.advance(n)); + Poll::Ready(res) } } @@ -356,11 +364,15 @@ impl AsyncRead for TestSeqBuffer { fn poll_read( self: Pin<&mut Self>, _: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - let r = self.get_mut().read(buf); + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let dst = buf.initialize_unfilled(); + let r = self.get_mut().read(dst); match r { - Ok(n) => Poll::Ready(Ok(n)), + Ok(n) => { + buf.advance(n); + Poll::Ready(Ok(())) + } Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(err) => Poll::Ready(Err(err)), } diff --git a/actix-http/src/time_parser.rs b/actix-http/src/time_parser.rs index 0d06a5867..46bf73037 100644 --- a/actix-http/src/time_parser.rs +++ b/actix-http/src/time_parser.rs @@ -8,35 +8,65 @@ pub fn parse_http_date(time: &str) -> Option { } /// Attempt to parse a `time` string as a RFC 1123 formatted date time string. +/// +/// Eg: `Fri, 12 Feb 2021 00:14:29 GMT` fn try_parse_rfc_1123(time: &str) -> Option { time::parse(time, "%a, %d %b %Y %H:%M:%S").ok() } /// Attempt to parse a `time` string as a RFC 850 formatted date time string. +/// +/// Eg: `Wednesday, 11-Jan-21 13:37:41 UTC` fn try_parse_rfc_850(time: &str) -> Option { - match PrimitiveDateTime::parse(time, "%A, %d-%b-%y %H:%M:%S") { - Ok(dt) => { - // If the `time` string contains a two-digit year, then as per RFC 2616 § 19.3, - // we consider the year as part of this century if it's within the next 50 years, - // otherwise we consider as part of the previous century. - let now = OffsetDateTime::now_utc(); - let century_start_year = (now.year() / 100) * 100; - let mut expanded_year = century_start_year + dt.year(); + let dt = PrimitiveDateTime::parse(time, "%A, %d-%b-%y %H:%M:%S").ok()?; - if expanded_year > now.year() + 50 { - expanded_year -= 100; - } + // If the `time` string contains a two-digit year, then as per RFC 2616 § 19.3, + // we consider the year as part of this century if it's within the next 50 years, + // otherwise we consider as part of the previous century. - match Date::try_from_ymd(expanded_year, dt.month(), dt.day()) { - Ok(date) => Some(PrimitiveDateTime::new(date, dt.time())), - Err(_) => None, - } - } - Err(_) => None, + let now = OffsetDateTime::now_utc(); + let century_start_year = (now.year() / 100) * 100; + let mut expanded_year = century_start_year + dt.year(); + + if expanded_year > now.year() + 50 { + expanded_year -= 100; } + + let date = Date::try_from_ymd(expanded_year, dt.month(), dt.day()).ok()?; + Some(PrimitiveDateTime::new(date, dt.time())) } /// Attempt to parse a `time` string using ANSI C's `asctime` format. +/// +/// Eg: `Wed Feb 13 15:46:11 2013` fn try_parse_asctime(time: &str) -> Option { time::parse(time, "%a %b %_d %H:%M:%S %Y").ok() } + +#[cfg(test)] +mod tests { + use time::{date, time}; + + use super::*; + + #[test] + fn test_rfc_850_year_shift() { + let date = try_parse_rfc_850("Friday, 19-Nov-82 16:14:55 EST").unwrap(); + assert_eq!(date, date!(1982 - 11 - 19).with_time(time!(16:14:55))); + + let date = try_parse_rfc_850("Wednesday, 11-Jan-62 13:37:41 EST").unwrap(); + assert_eq!(date, date!(2062 - 01 - 11).with_time(time!(13:37:41))); + + let date = try_parse_rfc_850("Wednesday, 11-Jan-21 13:37:41 EST").unwrap(); + assert_eq!(date, date!(2021 - 01 - 11).with_time(time!(13:37:41))); + + let date = try_parse_rfc_850("Wednesday, 11-Jan-23 13:37:41 EST").unwrap(); + assert_eq!(date, date!(2023 - 01 - 11).with_time(time!(13:37:41))); + + let date = try_parse_rfc_850("Wednesday, 11-Jan-99 13:37:41 EST").unwrap(); + assert_eq!(date, date!(1999 - 01 - 11).with_time(time!(13:37:41))); + + let date = try_parse_rfc_850("Wednesday, 11-Jan-00 13:37:41 EST").unwrap(); + assert_eq!(date, date!(2000 - 01 - 11).with_time(time!(13:37:41))); + } +} diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index 7c9628b1a..54d850854 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -1,47 +1,60 @@ use actix_codec::{Decoder, Encoder}; +use bitflags::bitflags; use bytes::{Bytes, BytesMut}; +use bytestring::ByteString; use super::frame::Parser; use super::proto::{CloseReason, OpCode}; use super::ProtocolError; -/// `WebSocket` Message +/// A WebSocket message. #[derive(Debug, PartialEq)] pub enum Message { - /// Text message - Text(String), - /// Binary message + /// Text message. + Text(ByteString), + + /// Binary message. Binary(Bytes), - /// Continuation + + /// Continuation. Continuation(Item), - /// Ping message + + /// Ping message. Ping(Bytes), - /// Pong message + + /// Pong message. Pong(Bytes), - /// Close message with optional reason + + /// Close message with optional reason. Close(Option), - /// No-op. Useful for actix-net services + + /// No-op. Useful for low-level services. Nop, } -/// `WebSocket` frame +/// A WebSocket frame. #[derive(Debug, PartialEq)] pub enum Frame { - /// Text frame, codec does not verify utf8 encoding + /// Text frame. Note that the codec does not validate UTF-8 encoding. Text(Bytes), - /// Binary frame + + /// Binary frame. Binary(Bytes), - /// Continuation + + /// Continuation. Continuation(Item), - /// Ping message + + /// Ping message. Ping(Bytes), - /// Pong message + + /// Pong message. Pong(Bytes), - /// Close message with optional reason + + /// Close message with optional reason. Close(Option), } -/// `WebSocket` continuation item +/// A WebSocket continuation item. #[derive(Debug, PartialEq)] pub enum Item { FirstText(Bytes), @@ -51,13 +64,13 @@ pub enum Item { } #[derive(Debug, Copy, Clone)] -/// WebSockets protocol codec +/// WebSocket protocol codec. pub struct Codec { flags: Flags, max_size: usize, } -bitflags::bitflags! { +bitflags! { struct Flags: u8 { const SERVER = 0b0000_0001; const CONTINUATION = 0b0000_0010; @@ -66,7 +79,7 @@ bitflags::bitflags! { } impl Codec { - /// Create new websocket frames decoder + /// Create new WebSocket frames decoder. pub fn new() -> Codec { Codec { max_size: 65_536, @@ -74,9 +87,9 @@ impl Codec { } } - /// Set max frame size + /// Set max frame size. /// - /// By default max size is set to 64kb + /// By default max size is set to 64kB. pub fn max_size(mut self, size: usize) -> Self { self.max_size = size; self @@ -184,7 +197,7 @@ impl Encoder for Codec { } } }, - Message::Nop => (), + Message::Nop => {} } Ok(()) } diff --git a/actix-http/src/ws/dispatcher.rs b/actix-http/src/ws/dispatcher.rs index b114217a0..7be7cf637 100644 --- a/actix-http/src/ws/dispatcher.rs +++ b/actix-http/src/ws/dispatcher.rs @@ -11,7 +11,7 @@ use super::{Codec, Frame, Message}; #[pin_project::pin_project] pub struct Dispatcher where - S: Service + 'static, + S: Service + 'static, T: AsyncRead + AsyncWrite, { #[pin] @@ -21,17 +21,17 @@ where impl Dispatcher where T: AsyncRead + AsyncWrite, - S: Service, + S: Service, S::Future: 'static, S::Error: 'static, { - pub fn new>(io: T, service: F) -> Self { + pub fn new>(io: T, service: F) -> Self { Dispatcher { inner: InnerDispatcher::new(Framed::new(io, Codec::new()), service), } } - pub fn with>(framed: Framed, service: F) -> Self { + pub fn with>(framed: Framed, service: F) -> Self { Dispatcher { inner: InnerDispatcher::new(framed, service), } @@ -41,7 +41,7 @@ where impl Future for Dispatcher where T: AsyncRead + AsyncWrite, - S: Service, + S: Service, S::Future: 'static, S::Error: 'static, { diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs index 0598a9b4e..46edf5d85 100644 --- a/actix-http/src/ws/frame.rs +++ b/actix-http/src/ws/frame.rs @@ -7,7 +7,7 @@ use crate::ws::mask::apply_mask; use crate::ws::proto::{CloseCode, CloseReason, OpCode}; use crate::ws::ProtocolError; -/// A struct representing a `WebSocket` frame. +/// A struct representing a WebSocket frame. #[derive(Debug)] pub struct Parser; @@ -16,7 +16,8 @@ impl Parser { src: &[u8], server: bool, max_size: usize, - ) -> Result)>, ProtocolError> { + ) -> Result)>, ProtocolError> + { let chunk_len = src.len(); let mut idx = 2; @@ -77,9 +78,10 @@ impl Parser { return Ok(None); } - let mask = - u32::from_le_bytes(TryFrom::try_from(&src[idx..idx + 4]).unwrap()); + let mask = TryFrom::try_from(&src[idx..idx + 4]).unwrap(); + idx += 4; + Some(mask) } else { None @@ -125,7 +127,7 @@ impl Parser { debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); return Ok(Some((true, OpCode::Close, None))); } - _ => (), + _ => {} } // unmask @@ -187,8 +189,8 @@ impl Parser { }; if mask { - let mask = rand::random::(); - dst.put_u32_le(mask); + let mask = rand::random::<[u8; 4]>(); + dst.put_slice(mask.as_ref()); dst.put_slice(payload.as_ref()); let pos = dst.len() - payload_len; apply_mask(&mut dst[pos..], mask); diff --git a/actix-http/src/ws/mask.rs b/actix-http/src/ws/mask.rs index d37d57eb1..276ca4a85 100644 --- a/actix-http/src/ws/mask.rs +++ b/actix-http/src/ws/mask.rs @@ -1,134 +1,57 @@ //! This is code from [Tungstenite project](https://github.com/snapview/tungstenite-rs) -#![allow(clippy::cast_ptr_alignment)] -use std::ptr::copy_nonoverlapping; -use std::slice; -// Holds a slice guaranteed to be shorter than 8 bytes -struct ShortSlice<'a> { - inner: &'a mut [u8], -} - -impl<'a> ShortSlice<'a> { - /// # Safety - /// Given slice must be shorter than 8 bytes. - unsafe fn new(slice: &'a mut [u8]) -> Self { - // Sanity check for debug builds - debug_assert!(slice.len() < 8); - ShortSlice { inner: slice } - } - - fn len(&self) -> usize { - self.inner.len() - } -} - -/// Faster version of `apply_mask()` which operates on 8-byte blocks. +/// Mask/unmask a frame. #[inline] -#[allow(clippy::cast_lossless)] -pub(crate) fn apply_mask(buf: &mut [u8], mask_u32: u32) { - // Extend the mask to 64 bits - let mut mask_u64 = ((mask_u32 as u64) << 32) | (mask_u32 as u64); - // Split the buffer into three segments - let (head, mid, tail) = align_buf(buf); +pub fn apply_mask(buf: &mut [u8], mask: [u8; 4]) { + apply_mask_fast32(buf, mask) +} - // Initial unaligned segment - let head_len = head.len(); - if head_len > 0 { - xor_short(head, mask_u64); +/// A safe unoptimized mask application. +#[inline] +fn apply_mask_fallback(buf: &mut [u8], mask: [u8; 4]) { + for (i, byte) in buf.iter_mut().enumerate() { + *byte ^= mask[i & 3]; + } +} + +/// Faster version of `apply_mask()` which operates on 4-byte blocks. +#[inline] +pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { + let mask_u32 = u32::from_ne_bytes(mask); + + // SAFETY: + // + // buf is a valid slice borrowed mutably from bytes::BytesMut. + // + // un aligned prefix and suffix would be mask/unmask per byte. + // proper aligned middle slice goes into fast path and operates on 4-byte blocks. + let (mut prefix, words, mut suffix) = unsafe { buf.align_to_mut::() }; + apply_mask_fallback(&mut prefix, mask); + let head = prefix.len() & 3; + let mask_u32 = if head > 0 { if cfg!(target_endian = "big") { - mask_u64 = mask_u64.rotate_left(8 * head_len as u32); + mask_u32.rotate_left(8 * head as u32) } else { - mask_u64 = mask_u64.rotate_right(8 * head_len as u32); - } - } - // Aligned segment - for v in mid { - *v ^= mask_u64; - } - // Final unaligned segment - if tail.len() > 0 { - xor_short(tail, mask_u64); - } -} - -// TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so -// inefficient, it could be done better. The compiler does not understand that -// a `ShortSlice` must be smaller than a u64. -#[inline] -#[allow(clippy::needless_pass_by_value)] -fn xor_short(buf: ShortSlice<'_>, mask: u64) { - // SAFETY: we know that a `ShortSlice` fits in a u64 - unsafe { - let (ptr, len) = (buf.inner.as_mut_ptr(), buf.len()); - let mut b: u64 = 0; - #[allow(trivial_casts)] - copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len); - b ^= mask; - #[allow(trivial_casts)] - copy_nonoverlapping(&b as *const _ as *const u8, ptr, len); - } -} - -/// # Safety -/// Caller must ensure the buffer has the correct size and alignment. -#[inline] -unsafe fn cast_slice(buf: &mut [u8]) -> &mut [u64] { - // Assert correct size and alignment in debug builds - debug_assert!(buf.len().trailing_zeros() >= 3); - debug_assert!((buf.as_ptr() as usize).trailing_zeros() >= 3); - - slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut u64, buf.len() >> 3) -} - -// Splits a slice into three parts: an unaligned short head and tail, plus an aligned -// u64 mid section. -#[inline] -fn align_buf(buf: &mut [u8]) -> (ShortSlice<'_>, &mut [u64], ShortSlice<'_>) { - let start_ptr = buf.as_ptr() as usize; - let end_ptr = start_ptr + buf.len(); - - // Round *up* to next aligned boundary for start - let start_aligned = (start_ptr + 7) & !0x7; - // Round *down* to last aligned boundary for end - let end_aligned = end_ptr & !0x7; - - if end_aligned >= start_aligned { - // We have our three segments (head, mid, tail) - let (tmp, tail) = buf.split_at_mut(end_aligned - start_ptr); - let (head, mid) = tmp.split_at_mut(start_aligned - start_ptr); - - // SAFETY: we know the middle section is correctly aligned, and the outer - // sections are smaller than 8 bytes - unsafe { - ( - ShortSlice::new(head), - cast_slice(mid), - ShortSlice::new(tail), - ) + mask_u32.rotate_right(8 * head as u32) } } else { - // We didn't cross even one aligned boundary! - - // SAFETY: The outer sections are smaller than 8 bytes - unsafe { (ShortSlice::new(buf), &mut [], ShortSlice::new(&mut [])) } + mask_u32 + }; + for word in words.iter_mut() { + *word ^= mask_u32; } + apply_mask_fallback(&mut suffix, mask_u32.to_ne_bytes()); } #[cfg(test)] mod tests { - use super::apply_mask; - - /// A safe unoptimized mask application. - fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) { - for (i, byte) in buf.iter_mut().enumerate() { - *byte ^= mask[i & 3]; - } - } + use super::*; + // legacy test from old apply mask test. kept for now for back compat test. + // TODO: remove it and favor the other test. #[test] - fn test_apply_mask() { + fn test_apply_mask_legacy() { let mask = [0x6d, 0xb6, 0xb2, 0x80]; - let mask_u32 = u32::from_le_bytes(mask); let unmasked = vec![ 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, @@ -138,10 +61,10 @@ mod tests { // Check masking with proper alignment. { let mut masked = unmasked.clone(); - apply_mask_fallback(&mut masked, &mask); + apply_mask_fallback(&mut masked, mask); let mut masked_fast = unmasked.clone(); - apply_mask(&mut masked_fast, mask_u32); + apply_mask(&mut masked_fast, mask); assert_eq!(masked, masked_fast); } @@ -149,12 +72,38 @@ mod tests { // Check masking without alignment. { let mut masked = unmasked.clone(); - apply_mask_fallback(&mut masked[1..], &mask); + apply_mask_fallback(&mut masked[1..], mask); let mut masked_fast = unmasked; - apply_mask(&mut masked_fast[1..], mask_u32); + apply_mask(&mut masked_fast[1..], mask); assert_eq!(masked, masked_fast); } } + + #[test] + fn test_apply_mask() { + let mask = [0x6d, 0xb6, 0xb2, 0x80]; + let unmasked = vec![ + 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, + 0x74, 0xf9, 0x12, 0x03, + ]; + + for data_len in 0..=unmasked.len() { + let unmasked = &unmasked[0..data_len]; + // Check masking with different alignment. + for off in 0..=3 { + if unmasked.len() < off { + continue; + } + let mut masked = unmasked.to_vec(); + apply_mask_fallback(&mut masked[off..], mask); + + let mut masked_fast = unmasked.to_vec(); + apply_mask_fast32(&mut masked_fast[off..], mask); + + assert_eq!(masked, masked_fast); + } + } + } } diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index cd212fb7e..0490163d5 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -1,11 +1,11 @@ -//! WebSocket protocol support. +//! WebSocket protocol. //! -//! To setup a `WebSocket`, first do web socket handshake then on success -//! convert `Payload` into a `WsStream` stream and then use `WsWriter` to -//! communicate with the peer. +//! To setup a WebSocket, first perform the WebSocket handshake then on success convert `Payload` into a +//! `WsStream` stream and then use `WsWriter` to communicate with the peer. + use std::io; -use derive_more::{Display, From}; +use derive_more::{Display, Error, From}; use http::{header, Method, StatusCode}; use crate::error::ResponseError; @@ -23,86 +23,103 @@ pub use self::dispatcher::Dispatcher; pub use self::frame::Parser; pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; -/// Websocket protocol errors -#[derive(Debug, Display, From)] +/// WebSocket protocol errors. +#[derive(Debug, Display, From, Error)] pub enum ProtocolError { - /// Received an unmasked frame from client - #[display(fmt = "Received an unmasked frame from client")] + /// Received an unmasked frame from client. + #[display(fmt = "Received an unmasked frame from client.")] UnmaskedFrame, - /// Received a masked frame from server - #[display(fmt = "Received a masked frame from server")] + + /// Received a masked frame from server. + #[display(fmt = "Received a masked frame from server.")] MaskedFrame, - /// Encountered invalid opcode - #[display(fmt = "Invalid opcode: {}", _0)] - InvalidOpcode(u8), + + /// Encountered invalid opcode. + #[display(fmt = "Invalid opcode: {}.", _0)] + InvalidOpcode(#[error(not(source))] u8), + /// Invalid control frame length - #[display(fmt = "Invalid control frame length: {}", _0)] - InvalidLength(usize), - /// Bad web socket op code - #[display(fmt = "Bad web socket op code")] + #[display(fmt = "Invalid control frame length: {}.", _0)] + InvalidLength(#[error(not(source))] usize), + + /// Bad opcode. + #[display(fmt = "Bad opcode.")] BadOpCode, + /// A payload reached size limit. #[display(fmt = "A payload reached size limit.")] Overflow, - /// Continuation is not started + + /// Continuation is not started. #[display(fmt = "Continuation is not started.")] ContinuationNotStarted, - /// Received new continuation but it is already started - #[display(fmt = "Received new continuation but it is already started")] + + /// Received new continuation but it is already started. + #[display(fmt = "Received new continuation but it is already started.")] ContinuationStarted, - /// Unknown continuation fragment - #[display(fmt = "Unknown continuation fragment.")] - ContinuationFragment(OpCode), - /// Io error - #[display(fmt = "io error: {}", _0)] + + /// Unknown continuation fragment. + #[display(fmt = "Unknown continuation fragment: {}.", _0)] + ContinuationFragment(#[error(not(source))] OpCode), + + /// I/O error. + #[display(fmt = "I/O error: {}", _0)] Io(io::Error), } -impl std::error::Error for ProtocolError {} - impl ResponseError for ProtocolError {} -/// Websocket handshake errors +/// WebSocket handshake errors #[derive(PartialEq, Debug, Display)] pub enum HandshakeError { - /// Only get method is allowed - #[display(fmt = "Method not allowed")] + /// Only get method is allowed. + #[display(fmt = "Method not allowed.")] GetMethodRequired, - /// Upgrade header if not set to websocket - #[display(fmt = "Websocket upgrade is expected")] + + /// Upgrade header if not set to WebSocket. + #[display(fmt = "WebSocket upgrade is expected.")] NoWebsocketUpgrade, - /// Connection header is not set to upgrade - #[display(fmt = "Connection upgrade is expected")] + + /// Connection header is not set to upgrade. + #[display(fmt = "Connection upgrade is expected.")] NoConnectionUpgrade, - /// Websocket version header is not set - #[display(fmt = "Websocket version header is required")] + + /// WebSocket version header is not set. + #[display(fmt = "WebSocket version header is required.")] NoVersionHeader, - /// Unsupported websocket version - #[display(fmt = "Unsupported version")] + + /// Unsupported WebSocket version. + #[display(fmt = "Unsupported version.")] UnsupportedVersion, - /// Websocket key is not set or wrong - #[display(fmt = "Unknown websocket key")] + + /// WebSocket key is not set or wrong. + #[display(fmt = "Unknown websocket key.")] BadWebsocketKey, } impl ResponseError for HandshakeError { fn error_response(&self) -> Response { - match *self { + match self { HandshakeError::GetMethodRequired => Response::MethodNotAllowed() - .header(header::ALLOW, "GET") + .insert_header((header::ALLOW, "GET")) .finish(), + HandshakeError::NoWebsocketUpgrade => Response::BadRequest() .reason("No WebSocket UPGRADE header found") .finish(), + HandshakeError::NoConnectionUpgrade => Response::BadRequest() .reason("No CONNECTION upgrade") .finish(), + HandshakeError::NoVersionHeader => Response::BadRequest() .reason("Websocket version header is required") .finish(), + HandshakeError::UnsupportedVersion => Response::BadRequest() .reason("Unsupported version") .finish(), + HandshakeError::BadWebsocketKey => { Response::BadRequest().reason("Handshake error").finish() } @@ -110,26 +127,20 @@ impl ResponseError for HandshakeError { } } -/// Verify `WebSocket` handshake request and create handshake response. -// /// `protocols` is a sequence of known protocols. On successful handshake, -// /// the returned response headers contain the first protocol in this list -// /// which the server also knows. +/// Verify WebSocket handshake request and create handshake response. pub fn handshake(req: &RequestHead) -> Result { verify_handshake(req)?; Ok(handshake_response(req)) } -/// Verify `WebSocket` handshake request. -// /// `protocols` is a sequence of known protocols. On successful handshake, -// /// the returned response headers contain the first protocol in this list -// /// which the server also knows. +/// Verify WebSocket handshake request. pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> { // WebSocket accepts only GET if req.method != Method::GET { return Err(HandshakeError::GetMethodRequired); } - // Check for "UPGRADE" to websocket header + // Check for "UPGRADE" to WebSocket header let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) { if let Ok(s) = hdr.to_str() { s.to_ascii_lowercase().contains("websocket") @@ -170,7 +181,7 @@ pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> { Ok(()) } -/// Create websocket handshake response +/// Create WebSocket handshake response. /// /// This function returns handshake `Response`, ready to send to peer. pub fn handshake_response(req: &RequestHead) -> ResponseBuilder { @@ -181,8 +192,8 @@ pub fn handshake_response(req: &RequestHead) -> ResponseBuilder { Response::build(StatusCode::SWITCHING_PROTOCOLS) .upgrade("websocket") - .header(header::TRANSFER_ENCODING, "chunked") - .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) + .insert_header((header::TRANSFER_ENCODING, "chunked")) + .insert_header((header::SEC_WEBSOCKET_ACCEPT, key)) .take() } @@ -207,7 +218,7 @@ mod tests { ); let req = TestRequest::default() - .header(header::UPGRADE, header::HeaderValue::from_static("test")) + .insert_header((header::UPGRADE, header::HeaderValue::from_static("test"))) .finish(); assert_eq!( HandshakeError::NoWebsocketUpgrade, @@ -215,10 +226,10 @@ mod tests { ); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) + )) .finish(); assert_eq!( HandshakeError::NoConnectionUpgrade, @@ -226,14 +237,14 @@ mod tests { ); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + )) + .insert_header(( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) + )) .finish(); assert_eq!( HandshakeError::NoVersionHeader, @@ -241,18 +252,18 @@ mod tests { ); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + )) + .insert_header(( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("5"), - ) + )) .finish(); assert_eq!( HandshakeError::UnsupportedVersion, @@ -260,18 +271,18 @@ mod tests { ); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + )) + .insert_header(( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), - ) + )) .finish(); assert_eq!( HandshakeError::BadWebsocketKey, @@ -279,22 +290,22 @@ mod tests { ); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + )) + .insert_header(( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("13"), - ) + )) .finish(); assert_eq!( StatusCode::SWITCHING_PROTOCOLS, diff --git a/actix-http/src/ws/proto.rs b/actix-http/src/ws/proto.rs index fc271a8f5..1e8bf7af3 100644 --- a/actix-http/src/ws/proto.rs +++ b/actix-http/src/ws/proto.rs @@ -1,28 +1,34 @@ use std::convert::{From, Into}; use std::fmt; -use self::OpCode::*; -/// Operation codes as part of rfc6455. +/// Operation codes as part of RFC6455. #[derive(Debug, Eq, PartialEq, Clone, Copy)] pub enum OpCode { /// Indicates a continuation frame of a fragmented message. Continue, + /// Indicates a text data frame. Text, + /// Indicates a binary data frame. Binary, + /// Indicates a close control frame. Close, + /// Indicates a ping control frame. Ping, + /// Indicates a pong control frame. Pong, + /// Indicates an invalid opcode was received. Bad, } impl fmt::Display for OpCode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use self::OpCode::*; match *self { Continue => write!(f, "CONTINUE"), Text => write!(f, "TEXT"), @@ -35,9 +41,10 @@ impl fmt::Display for OpCode { } } -impl Into for OpCode { - fn into(self) -> u8 { - match self { +impl From for u8 { + fn from(op: OpCode) -> u8 { + use self::OpCode::*; + match op { Continue => 0, Text => 1, Binary => 2, @@ -54,6 +61,7 @@ impl Into for OpCode { impl From for OpCode { fn from(byte: u8) -> OpCode { + use self::OpCode::*; match byte { 0 => Continue, 1 => Text, @@ -66,9 +74,7 @@ impl From for OpCode { } } -use self::CloseCode::*; -/// Status code used to indicate why an endpoint is closing the `WebSocket` -/// connection. +/// Status code used to indicate why an endpoint is closing the WebSocket connection. #[derive(Debug, Eq, PartialEq, Clone, Copy)] pub enum CloseCode { /// Indicates a normal closure, meaning that the purpose for @@ -132,9 +138,10 @@ pub enum CloseCode { Other(u16), } -impl Into for CloseCode { - fn into(self) -> u16 { - match self { +impl From for u16 { + fn from(code: CloseCode) -> u16 { + use self::CloseCode::*; + match code { Normal => 1000, Away => 1001, Protocol => 1002, @@ -155,6 +162,7 @@ impl Into for CloseCode { impl From for CloseCode { fn from(code: u16) -> CloseCode { + use self::CloseCode::*; match code { 1000 => Normal, 1001 => Away, @@ -179,6 +187,7 @@ impl From for CloseCode { pub struct CloseReason { /// Exit code pub code: CloseCode, + /// Optional description of the exit code pub description: Option, } @@ -222,7 +231,7 @@ mod test { macro_rules! opcode_into { ($from:expr => $opcode:pat) => { match OpCode::from($from) { - e @ $opcode => (), + e @ $opcode => {} e => unreachable!("{:?}", e), } }; @@ -232,7 +241,7 @@ mod test { ($from:expr => $opcode:pat) => { let res: u8 = $from.into(); match res { - e @ $opcode => (), + e @ $opcode => {} e => unreachable!("{:?}", e), } }; diff --git a/actix-http/tests/test_client.rs b/actix-http/tests/test_client.rs index 07104decc..91b2412f4 100644 --- a/actix-http/tests/test_client.rs +++ b/actix-http/tests/test_client.rs @@ -1,9 +1,8 @@ -use actix_service::ServiceFactory; -use bytes::Bytes; -use futures_util::future::{self, ok}; - use actix_http::{http, HttpService, Request, Response}; use actix_http_test::test_server; +use actix_service::ServiceFactoryExt; +use bytes::Bytes; +use futures_util::future::{self, ok}; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ @@ -39,7 +38,7 @@ async fn test_h1_v2() { let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); - let request = srv.get("/").header("x-test", "111").send(); + let request = srv.get("/").insert_header(("x-test", "111")).send(); let mut response = request.await.unwrap(); assert!(response.status().is_success()); diff --git a/actix-http/tests/test_openssl.rs b/actix-http/tests/test_openssl.rs index 6b80bad0a..f44968baa 100644 --- a/actix-http/tests/test_openssl.rs +++ b/actix-http/tests/test_openssl.rs @@ -1,19 +1,24 @@ #![cfg(feature = "openssl")] + +extern crate tls_openssl as openssl; + use std::io; -use actix_http_test::test_server; -use actix_service::{fn_service, ServiceFactory}; - -use bytes::{Bytes, BytesMut}; -use futures_util::future::{err, ok, ready}; -use futures_util::stream::{once, Stream, StreamExt}; -use open_ssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod}; - use actix_http::error::{ErrorBadRequest, PayloadError}; use actix_http::http::header::{self, HeaderName, HeaderValue}; use actix_http::http::{Method, StatusCode, Version}; -use actix_http::httpmessage::HttpMessage; +use actix_http::HttpMessage; use actix_http::{body, Error, HttpService, Request, Response}; +use actix_http_test::test_server; +use actix_service::{fn_service, ServiceFactoryExt}; +use bytes::{Bytes, BytesMut}; +use futures_util::future::{err, ok, ready}; +use futures_util::stream::{once, Stream, StreamExt}; +use openssl::{ + pkey::PKey, + ssl::{SslAcceptor, SslMethod}, + x509::X509, +}; async fn load_body(stream: S) -> Result where @@ -33,29 +38,26 @@ where Ok(body) } -fn ssl_acceptor() -> SslAcceptor { - // load ssl keys +fn tls_config() -> SslAcceptor { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); + let cert_file = cert.serialize_pem().unwrap(); + let key_file = cert.serialize_private_key_pem(); + let cert = X509::from_pem(cert_file.as_bytes()).unwrap(); + let key = PKey::private_key_from_pem(key_file.as_bytes()).unwrap(); + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder - .set_private_key_file("../tests/key.pem", SslFiletype::PEM) - .unwrap(); - builder - .set_certificate_chain_file("../tests/cert.pem") - .unwrap(); + builder.set_certificate(&cert).unwrap(); + builder.set_private_key(&key).unwrap(); + builder.set_alpn_select_callback(|_, protos| { const H2: &[u8] = b"\x02h2"; - const H11: &[u8] = b"\x08http/1.1"; if protos.windows(3).any(|window| window == H2) { Ok(b"h2") - } else if protos.windows(9).any(|window| window == H11) { - Ok(b"http/1.1") } else { - Err(AlpnError::NOACK) + Err(openssl::ssl::AlpnError::NOACK) } }); - builder - .set_alpn_protos(b"\x08http/1.1\x02h2") - .expect("Can not contrust SslAcceptor"); + builder.set_alpn_protos(b"\x02h2").unwrap(); builder.build() } @@ -65,7 +67,7 @@ async fn test_h2() -> io::Result<()> { let srv = test_server(move || { HttpService::build() .h2(|_| ok::<_, Error>(Response::Ok().finish())) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -84,7 +86,7 @@ async fn test_h2_1() -> io::Result<()> { assert_eq!(req.version(), Version::HTTP_2); ok::<_, Error>(Response::Ok().finish()) }) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -103,7 +105,7 @@ async fn test_h2_body() -> io::Result<()> { let body = load_body(req.take_payload()).await?; Ok::<_, Error>(Response::Ok().body(body)) }) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -132,7 +134,7 @@ async fn test_h2_content_length() { ]; ok::<_, ()>(Response::new(statuses[indx])) }) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -175,8 +177,8 @@ async fn test_h2_headers() { HttpService::build().h2(move |_| { let mut builder = Response::Ok(); for idx in 0..90 { - builder.header( - format!("X-TEST-{}", idx).as_str(), + builder.insert_header( + (format!("X-TEST-{}", idx).as_str(), "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ @@ -190,11 +192,11 @@ async fn test_h2_headers() { TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", - ); + )); } ok::<_, ()>(builder.body(data.clone())) }) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }).await; @@ -233,7 +235,7 @@ async fn test_h2_body2() { let mut srv = test_server(move || { HttpService::build() .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -251,7 +253,7 @@ async fn test_h2_head_empty() { let mut srv = test_server(move || { HttpService::build() .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -275,7 +277,7 @@ async fn test_h2_head_binary() { let mut srv = test_server(move || { HttpService::build() .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -298,7 +300,7 @@ async fn test_h2_head_binary2() { let srv = test_server(move || { HttpService::build() .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -322,7 +324,7 @@ async fn test_h2_body_length() { Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), ) }) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -343,11 +345,11 @@ async fn test_h2_body_chunked_explicit() { let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); ok::<_, ()>( Response::Ok() - .header(header::TRANSFER_ENCODING, "chunked") + .insert_header((header::TRANSFER_ENCODING, "chunked")) .streaming(body), ) }) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -371,11 +373,11 @@ async fn test_h2_response_http_error_handling() { let broken_header = Bytes::from_static(b"\0\0\0"); ok::<_, ()>( Response::Ok() - .header(header::CONTENT_TYPE, broken_header) + .insert_header((header::CONTENT_TYPE, broken_header)) .body(STR), ) })) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -393,7 +395,7 @@ async fn test_h2_service_error() { let mut srv = test_server(move || { HttpService::build() .h2(|_| err::(ErrorBadRequest("error"))) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -410,12 +412,14 @@ async fn test_h2_service_error() { async fn test_h2_on_connect() { let srv = test_server(move || { HttpService::build() - .on_connect_ext(|_, data| data.insert(20isize)) + .on_connect_ext(|_, data| { + data.insert(20isize); + }) .h2(|req: Request| { assert!(req.extensions().contains::()); ok::<_, ()>(Response::Ok().finish()) }) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; diff --git a/actix-http/tests/test_rustls.rs b/actix-http/tests/test_rustls.rs index beae359d9..a36400910 100644 --- a/actix-http/tests/test_rustls.rs +++ b/actix-http/tests/test_rustls.rs @@ -1,4 +1,7 @@ #![cfg(feature = "rustls")] + +extern crate tls_rustls as rustls; + use actix_http::error::PayloadError; use actix_http::http::header::{self, HeaderName, HeaderValue}; use actix_http::http::{Method, StatusCode, Version}; @@ -9,7 +12,7 @@ use actix_service::{fn_factory_with_config, fn_service}; use bytes::{Bytes, BytesMut}; use futures_util::future::{self, err, ok}; use futures_util::stream::{once, Stream, StreamExt}; -use rust_tls::{ +use rustls::{ internal::pemfile::{certs, pkcs8_private_keys}, NoClientAuth, ServerConfig as RustlsServerConfig, }; @@ -28,14 +31,19 @@ where Ok(body) } -fn ssl_acceptor() -> RustlsServerConfig { - // load ssl keys +fn tls_config() -> RustlsServerConfig { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); + let cert_file = cert.serialize_pem().unwrap(); + let key_file = cert.serialize_private_key_pem(); + let mut config = RustlsServerConfig::new(NoClientAuth::new()); - let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap()); - let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap()); + let cert_file = &mut BufReader::new(cert_file.as_bytes()); + let key_file = &mut BufReader::new(key_file.as_bytes()); + let cert_chain = certs(cert_file).unwrap(); let mut keys = pkcs8_private_keys(key_file).unwrap(); config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + config } @@ -44,7 +52,7 @@ async fn test_h1() -> io::Result<()> { let srv = test_server(move || { HttpService::build() .h1(|_| future::ok::<_, Error>(Response::Ok().finish())) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -58,7 +66,7 @@ async fn test_h2() -> io::Result<()> { let srv = test_server(move || { HttpService::build() .h2(|_| future::ok::<_, Error>(Response::Ok().finish())) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -76,7 +84,7 @@ async fn test_h1_1() -> io::Result<()> { assert_eq!(req.version(), Version::HTTP_11); future::ok::<_, Error>(Response::Ok().finish()) }) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -94,7 +102,7 @@ async fn test_h2_1() -> io::Result<()> { assert_eq!(req.version(), Version::HTTP_2); future::ok::<_, Error>(Response::Ok().finish()) }) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -112,7 +120,7 @@ async fn test_h2_body1() -> io::Result<()> { let body = load_body(req.take_payload()).await?; Ok::<_, Error>(Response::Ok().body(body)) }) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -140,7 +148,7 @@ async fn test_h2_content_length() { ]; future::ok::<_, ()>(Response::new(statuses[indx])) }) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -181,7 +189,7 @@ async fn test_h2_headers() { HttpService::build().h2(move |_| { let mut config = Response::Ok(); for idx in 0..90 { - config.header( + config.insert_header(( format!("X-TEST-{}", idx).as_str(), "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ @@ -196,11 +204,11 @@ async fn test_h2_headers() { TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", - ); + )); } future::ok::<_, ()>(config.body(data.clone())) }) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }).await; let response = srv.sget("/").send().await.unwrap(); @@ -238,7 +246,7 @@ async fn test_h2_body2() { let mut srv = test_server(move || { HttpService::build() .h2(|_| future::ok::<_, ()>(Response::Ok().body(STR))) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -255,7 +263,7 @@ async fn test_h2_head_empty() { let mut srv = test_server(move || { HttpService::build() .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -281,7 +289,7 @@ async fn test_h2_head_binary() { let mut srv = test_server(move || { HttpService::build() .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -306,7 +314,7 @@ async fn test_h2_head_binary2() { let srv = test_server(move || { HttpService::build() .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -332,7 +340,7 @@ async fn test_h2_body_length() { Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), ) }) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -352,11 +360,11 @@ async fn test_h2_body_chunked_explicit() { let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); ok::<_, ()>( Response::Ok() - .header(header::TRANSFER_ENCODING, "chunked") + .insert_header((header::TRANSFER_ENCODING, "chunked")) .streaming(body), ) }) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -380,12 +388,12 @@ async fn test_h2_response_http_error_handling() { let broken_header = Bytes::from_static(b"\0\0\0"); ok::<_, ()>( Response::Ok() - .header(http::header::CONTENT_TYPE, broken_header) + .insert_header((http::header::CONTENT_TYPE, broken_header)) .body(STR), ) })) })) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -402,7 +410,7 @@ async fn test_h2_service_error() { let mut srv = test_server(move || { HttpService::build() .h2(|_| err::(error::ErrorBadRequest("error"))) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; @@ -419,7 +427,7 @@ async fn test_h1_service_error() { let mut srv = test_server(move || { HttpService::build() .h1(|_| err::(error::ErrorBadRequest("error"))) - .rustls(ssl_acceptor()) + .rustls(tls_config()) }) .await; diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs index 44794e199..910fa81f2 100644 --- a/actix-http/tests/test_server.rs +++ b/actix-http/tests/test_server.rs @@ -3,14 +3,14 @@ use std::time::Duration; use std::{net, thread}; use actix_http_test::test_server; -use actix_rt::time::delay_for; +use actix_rt::time::sleep; use actix_service::fn_service; use bytes::Bytes; use futures_util::future::{self, err, ok, ready, FutureExt}; use futures_util::stream::{once, StreamExt}; use regex::Regex; -use actix_http::httpmessage::HttpMessage; +use actix_http::HttpMessage; use actix_http::{ body, error, http, http::header, Error, HttpService, KeepAlive, Request, Response, }; @@ -88,7 +88,7 @@ async fn test_expect_continue_h1() { let srv = test_server(|| { HttpService::build() .expect(fn_service(|req: Request| { - delay_for(Duration::from_millis(20)).then(move |_| { + sleep(Duration::from_millis(20)).then(move |_| { if req.head().uri.query() == Some("yes=") { ok(req) } else { @@ -392,7 +392,7 @@ async fn test_h1_headers() { HttpService::build().h1(move |_| { let mut builder = Response::Ok(); for idx in 0..90 { - builder.header( + builder.insert_header(( format!("X-TEST-{}", idx).as_str(), "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ @@ -407,7 +407,7 @@ async fn test_h1_headers() { TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", - ); + )); } future::ok::<_, ()>(builder.body(data.clone())) }).tcp() @@ -561,7 +561,7 @@ async fn test_h1_body_chunked_explicit() { let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); ok::<_, ()>( Response::Ok() - .header(header::TRANSFER_ENCODING, "chunked") + .insert_header((header::TRANSFER_ENCODING, "chunked")) .streaming(body), ) }) @@ -625,7 +625,7 @@ async fn test_h1_response_http_error_handling() { let broken_header = Bytes::from_static(b"\0\0\0"); ok::<_, ()>( Response::Ok() - .header(http::header::CONTENT_TYPE, broken_header) + .insert_header((http::header::CONTENT_TYPE, broken_header)) .body(STR), ) })) @@ -662,7 +662,9 @@ async fn test_h1_service_error() { async fn test_h1_on_connect() { let srv = test_server(|| { HttpService::build() - .on_connect_ext(|_, data| data.insert(20isize)) + .on_connect_ext(|_, data| { + data.insert(20isize); + }) .h1(|req: Request| { assert!(req.extensions().contains::()); future::ok::<_, ()>(Response::Ok().finish()) diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs index 5d86605f4..7ed9b0df1 100644 --- a/actix-http/tests/test_ws.rs +++ b/actix-http/tests/test_ws.rs @@ -21,7 +21,7 @@ impl WsService { WsService(Arc::new(Mutex::new((PhantomData, Cell::new(false))))) } - fn set_polled(&mut self) { + fn set_polled(&self) { *self.0.lock().unwrap().1.get_mut() = true; } @@ -36,21 +36,20 @@ impl Clone for WsService { } } -impl Service for WsService +impl Service<(Request, Framed)> for WsService where T: AsyncRead + AsyncWrite + Unpin + 'static, { - type Request = (Request, Framed); type Response = (); type Error = Error; type Future = Pin>>>; - fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll> { + fn poll_ready(&self, _ctx: &mut Context<'_>) -> Poll> { self.set_polled(); Poll::Ready(Ok(())) } - fn call(&mut self, (req, mut framed): Self::Request) -> Self::Future { + fn call(&self, (req, mut framed): (Request, Framed)) -> Self::Future { let fut = async move { let res = ws::handshake(req.head()).unwrap().message_body(()); @@ -72,7 +71,7 @@ async fn service(msg: ws::Frame) -> Result { let msg = match msg { ws::Frame::Ping(msg) => ws::Message::Pong(msg), ws::Frame::Text(text) => { - ws::Message::Text(String::from_utf8_lossy(&text).to_string()) + ws::Message::Text(String::from_utf8_lossy(&text).into_owned().into()) } ws::Frame::Binary(bin) => ws::Message::Binary(bin), ws::Frame::Continuation(item) => ws::Message::Continuation(item), @@ -99,10 +98,7 @@ async fn test_simple() { // client service let mut framed = srv.ws().await.unwrap(); - framed - .send(ws::Message::Text("text".to_string())) - .await - .unwrap(); + framed.send(ws::Message::Text("text".into())).await.unwrap(); let (item, mut framed) = framed.into_future().await; assert_eq!( item.unwrap().unwrap(), diff --git a/actix-multipart/CHANGES.md b/actix-multipart/CHANGES.md index 446ca5ad2..2142ebf4b 100644 --- a/actix-multipart/CHANGES.md +++ b/actix-multipart/CHANGES.md @@ -1,14 +1,25 @@ # Changes -## Unreleased - 2020-xx-xx -* Fix multipart consuming payload before header checks #1513 +## Unreleased - 2021-xx-xx -## 3.0.0 - 2020-09-11 -* No significant changes from `3.0.0-beta.2`. +## 0.4.0-beta.2 - 2021-02-10 +* No notable changes. -## 3.0.0-beta.2 - 2020-09-10 +## 0.4.0-beta.1 - 2021-01-07 +* Fix multipart consuming payload before header checks. [#1513] +* Update `bytes` to `1.0`. [#1813] + +[#1813]: https://github.com/actix/actix-web/pull/1813 +[#1513]: https://github.com/actix/actix-web/pull/1513 + + +## 0.3.0 - 2020-09-11 +* No significant changes from `0.3.0-beta.2`. + + +## 0.3.0-beta.2 - 2020-09-10 * Update `actix-*` dependencies to latest versions. diff --git a/actix-multipart/Cargo.toml b/actix-multipart/Cargo.toml index e2e9dbf14..67b2698bd 100644 --- a/actix-multipart/Cargo.toml +++ b/actix-multipart/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "actix-multipart" -version = "0.3.0" +version = "0.4.0-beta.2" authors = ["Nikolay Kim "] -description = "Multipart support for actix web framework." +description = "Multipart form support for Actix Web" readme = "README.md" keywords = ["http", "web", "framework", "async", "futures"] homepage = "https://actix.rs" @@ -16,17 +16,17 @@ name = "actix_multipart" path = "src/lib.rs" [dependencies] -actix-web = { version = "3.0.0", default-features = false } -actix-service = "1.0.6" -actix-utils = "2.0.0" -bytes = "0.5.3" -derive_more = "0.99.2" +actix-web = { version = "4.0.0-beta.3", default-features = false } +actix-utils = "3.0.0-beta.2" + +bytes = "1" +derive_more = "0.99.5" httparse = "1.3" -futures-util = { version = "0.3.5", default-features = false } +futures-util = { version = "0.3.7", default-features = false } log = "0.4" mime = "0.3" twoway = "0.2" [dev-dependencies] -actix-rt = "1.0.0" -actix-http = "2.0.0" +actix-rt = "2" +actix-http = "3.0.0-beta.3" diff --git a/actix-multipart/README.md b/actix-multipart/README.md index edb2e0020..defcf7828 100644 --- a/actix-multipart/README.md +++ b/actix-multipart/README.md @@ -1,8 +1,18 @@ -# Multipart support for actix web framework [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-multipart)](https://crates.io/crates/actix-multipart) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +# actix-multipart -## Documentation & community resources +> Multipart form support for Actix Web. -* [API Documentation](https://docs.rs/actix-multipart/) -* [Chat on gitter](https://gitter.im/actix/actix) -* Cargo package: [actix-multipart](https://crates.io/crates/actix-multipart) -* Minimum supported Rust version: 1.40 or later +[![crates.io](https://img.shields.io/crates/v/actix-multipart?label=latest)](https://crates.io/crates/actix-multipart) +[![Documentation](https://docs.rs/actix-multipart/badge.svg?version=0.4.0-beta.2)](https://docs.rs/actix-multipart/0.4.0-beta.2) +[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-multipart.svg) +
+[![dependency status](https://deps.rs/crate/actix-multipart/0.4.0-beta.2/status.svg)](https://deps.rs/crate/actix-multipart/0.4.0-beta.2) +[![Download](https://img.shields.io/crates/d/actix-multipart.svg)](https://crates.io/crates/actix-multipart) +[![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & Resources + +- [API Documentation](https://docs.rs/actix-multipart) +- [Chat on Gitter](https://gitter.im/actix/actix-web) +- Minimum Supported Rust Version (MSRV): 1.46.0 diff --git a/actix-multipart/src/lib.rs b/actix-multipart/src/lib.rs index a3bc36e21..0cf54e70e 100644 --- a/actix-multipart/src/lib.rs +++ b/actix-multipart/src/lib.rs @@ -1,4 +1,4 @@ -//! Multipart form support for Actix web. +//! Multipart form support for Actix Web. #![deny(rust_2018_idioms)] #![allow(clippy::borrow_interior_mutable_const)] diff --git a/actix-multipart/src/server.rs b/actix-multipart/src/server.rs index b476f1791..8cd1c8e0c 100644 --- a/actix-multipart/src/server.rs +++ b/actix-multipart/src/server.rs @@ -13,9 +13,7 @@ use futures_util::stream::{LocalBoxStream, Stream, StreamExt}; use actix_utils::task::LocalWaker; use actix_web::error::{ParseError, PayloadError}; -use actix_web::http::header::{ - self, ContentDisposition, HeaderMap, HeaderName, HeaderValue, -}; +use actix_web::http::header::{self, ContentDisposition, HeaderMap, HeaderName, HeaderValue}; use crate::error::MultipartError; @@ -120,10 +118,7 @@ impl Multipart { impl Stream for Multipart { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if let Some(err) = self.error.take() { Poll::Ready(Some(Err(err))) } else if self.safety.current() { @@ -142,9 +137,7 @@ impl Stream for Multipart { } impl InnerMultipart { - fn read_headers( - payload: &mut PayloadBuffer, - ) -> Result, MultipartError> { + fn read_headers(payload: &mut PayloadBuffer) -> Result, MultipartError> { match payload.read_until(b"\r\n\r\n")? { None => { if payload.eof { @@ -226,8 +219,7 @@ impl InnerMultipart { if chunk.len() < boundary.len() { continue; } - if &chunk[..2] == b"--" - && &chunk[2..chunk.len() - 2] == boundary.as_bytes() + if &chunk[..2] == b"--" && &chunk[2..chunk.len() - 2] == boundary.as_bytes() { break; } else { @@ -273,9 +265,7 @@ impl InnerMultipart { match field.borrow_mut().poll(safety) { Poll::Pending => return Poll::Pending, Poll::Ready(Some(Ok(_))) => continue, - Poll::Ready(Some(Err(e))) => { - return Poll::Ready(Some(Err(e))) - } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), Poll::Ready(None) => true, } } @@ -311,10 +301,7 @@ impl InnerMultipart { } // read boundary InnerState::Boundary => { - match InnerMultipart::read_boundary( - &mut *payload, - &self.boundary, - )? { + match InnerMultipart::read_boundary(&mut *payload, &self.boundary)? { None => return Poll::Pending, Some(eof) => { if eof { @@ -326,7 +313,7 @@ impl InnerMultipart { } } } - _ => (), + _ => {} } // read field headers for next field @@ -418,8 +405,7 @@ impl Field { pub fn content_disposition(&self) -> Option { // RFC 7578: 'Each part MUST contain a Content-Disposition header field // where the disposition type is "form-data".' - if let Some(content_disposition) = self.headers.get(&header::CONTENT_DISPOSITION) - { + if let Some(content_disposition) = self.headers.get(&header::CONTENT_DISPOSITION) { ContentDisposition::from_raw(content_disposition).ok() } else { None @@ -430,15 +416,10 @@ impl Field { impl Stream for Field { type Item = Result; - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.safety.current() { let mut inner = self.inner.borrow_mut(); - if let Some(mut payload) = - inner.payload.as_ref().unwrap().get_mut(&self.safety) - { + if let Some(mut payload) = inner.payload.as_ref().unwrap().get_mut(&self.safety) { payload.poll_stream(cx)?; } inner.poll(&self.safety) @@ -607,8 +588,7 @@ impl InnerField { return Poll::Ready(None); } - let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s) - { + let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s) { if !self.eof { let res = if let Some(ref mut len) = self.length { InnerField::read_len(&mut *payload, len) @@ -628,7 +608,9 @@ impl InnerField { Ok(None) => Poll::Pending, Ok(Some(line)) => { if line.as_ref() != b"\r\n" { - log::warn!("multipart field did not read all the data or it is malformed"); + log::warn!( + "multipart field did not read all the data or it is malformed" + ); } Poll::Ready(None) } @@ -804,9 +786,7 @@ impl PayloadBuffer { /// Read bytes until new line delimiter or eof pub fn readline_or_eof(&mut self) -> Result, MultipartError> { match self.readline() { - Err(MultipartError::Incomplete) if self.eof => { - Ok(Some(self.buf.split().freeze())) - } + Err(MultipartError::Incomplete) if self.eof => Ok(Some(self.buf.split().freeze())), line => line, } } @@ -835,7 +815,7 @@ mod tests { async fn test_boundary() { let headers = HeaderMap::new(); match Multipart::boundary(&headers) { - Err(MultipartError::NoContentType) => (), + Err(MultipartError::NoContentType) => {} _ => unreachable!("should not happen"), } @@ -846,7 +826,7 @@ mod tests { ); match Multipart::boundary(&headers) { - Err(MultipartError::ParseContentType) => (), + Err(MultipartError::ParseContentType) => {} _ => unreachable!("should not happen"), } @@ -856,7 +836,7 @@ mod tests { header::HeaderValue::from_static("multipart/mixed"), ); match Multipart::boundary(&headers) { - Err(MultipartError::Boundary) => (), + Err(MultipartError::Boundary) => {} _ => unreachable!("should not happen"), } @@ -902,10 +882,7 @@ mod tests { impl Stream for SlowStream { type Item = Result; - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); if !this.ready { this.ready = true; @@ -956,17 +933,17 @@ mod tests { let mut multipart = Multipart::new(&headers, payload); match multipart.next().await.unwrap() { - Ok(_) => (), + Ok(_) => {} _ => unreachable!(), } match multipart.next().await.unwrap() { - Ok(_) => (), + Ok(_) => {} _ => unreachable!(), } match multipart.next().await { - None => (), + None => {} _ => unreachable!(), } } @@ -993,7 +970,7 @@ mod tests { _ => unreachable!(), } match field.next().await { - None => (), + None => {} _ => unreachable!(), } } @@ -1010,7 +987,7 @@ mod tests { _ => unreachable!(), } match field.next().await { - None => (), + None => {} _ => unreachable!(), } } @@ -1018,7 +995,7 @@ mod tests { } match multipart.next().await { - None => (), + None => {} _ => unreachable!(), } } @@ -1066,7 +1043,7 @@ mod tests { } match multipart.next().await { - None => (), + None => {} _ => unreachable!(), } } diff --git a/actix-web-actors/CHANGES.md b/actix-web-actors/CHANGES.md index 9df0df159..acd9ceada 100644 --- a/actix-web-actors/CHANGES.md +++ b/actix-web-actors/CHANGES.md @@ -1,7 +1,20 @@ # Changes -## Unreleased - 2020-xx-xx -* Upgrade `pin-project` to `1.0`. +## Unreleased - 2021-xx-xx + + +## 4.0.0-beta.2 - 2021-02-10 +* No notable changes. + + +## 4.0.0-beta.1 - 2021-01-07 +* Update `pin-project` to `1.0`. +* Update `bytes` to `1.0`. [#1813] +* `WebsocketContext::text` now takes an `Into`. [#1864] + +[#1813]: https://github.com/actix/actix-web/pull/1813 +[#1864]: https://github.com/actix/actix-web/pull/1864 + ## 3.0.0 - 2020-09-11 * No significant changes from `3.0.0-beta.2`. diff --git a/actix-web-actors/Cargo.toml b/actix-web-actors/Cargo.toml index 920940c40..698ff9420 100644 --- a/actix-web-actors/Cargo.toml +++ b/actix-web-actors/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "actix-web-actors" -version = "3.0.0" +version = "4.0.0-beta.2" authors = ["Nikolay Kim "] -description = "Actix actors support for actix web framework." +description = "Actix actors support for Actix Web" readme = "README.md" keywords = ["actix", "http", "web", "framework", "async"] homepage = "https://actix.rs" @@ -16,16 +16,18 @@ name = "actix_web_actors" path = "src/lib.rs" [dependencies] -actix = "0.10.0" -actix-web = { version = "3.0.0", default-features = false } -actix-http = "2.0.0" -actix-codec = "0.3.0" -bytes = "0.5.2" -futures-channel = { version = "0.3.5", default-features = false } -futures-core = { version = "0.3.5", default-features = false } +actix = { version = "0.11.0-beta.2", default-features = false } +actix-codec = "0.4.0-beta.1" +actix-http = "3.0.0-beta.3" +actix-web = { version = "4.0.0-beta.3", default-features = false } + +bytes = "1" +bytestring = "1" +futures-core = { version = "0.3.7", default-features = false } pin-project = "1.0.0" +tokio = { version = "1", features = ["sync"] } [dev-dependencies] -actix-rt = "1.1.1" -env_logger = "0.7" -futures-util = { version = "0.3.5", default-features = false } +actix-rt = "2" +env_logger = "0.8" +futures-util = { version = "0.3.7", default-features = false } diff --git a/actix-web-actors/README.md b/actix-web-actors/README.md index fb8c3a621..c9b588153 100644 --- a/actix-web-actors/README.md +++ b/actix-web-actors/README.md @@ -1,8 +1,18 @@ -Actix actors support for actix web framework [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-web-actors)](https://crates.io/crates/actix-web-actors) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +# actix-web-actors -## Documentation & community resources +> Actix actors support for Actix Web. -* [API Documentation](https://docs.rs/actix-web-actors/) -* [Chat on gitter](https://gitter.im/actix/actix) -* Cargo package: [actix-web-actors](https://crates.io/crates/actix-web-actors) -* Minimum supported Rust version: 1.40 or later +[![crates.io](https://img.shields.io/crates/v/actix-web-actors?label=latest)](https://crates.io/crates/actix-web-actors) +[![Documentation](https://docs.rs/actix-web-actors/badge.svg?version=0.5.0)](https://docs.rs/actix-web-actors/0.5.0) +[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +![License](https://img.shields.io/crates/l/actix-web-actors.svg) +
+[![dependency status](https://deps.rs/crate/actix-web-actors/0.5.0/status.svg)](https://deps.rs/crate/actix-web-actors/0.5.0) +[![Download](https://img.shields.io/crates/d/actix-web-actors.svg)](https://crates.io/crates/actix-web-actors) +[![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & Resources + +- [API Documentation](https://docs.rs/actix-web-actors) +- [Chat on Gitter](https://gitter.im/actix/actix-web) +- Minimum supported Rust version: 1.46 or later diff --git a/actix-web-actors/src/context.rs b/actix-web-actors/src/context.rs index 0839a4288..bd6635cba 100644 --- a/actix-web-actors/src/context.rs +++ b/actix-web-actors/src/context.rs @@ -3,19 +3,17 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use actix::dev::{ - AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, ToEnvelope, -}; +use actix::dev::{AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, ToEnvelope}; use actix::fut::ActorFuture; use actix::{ Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message, SpawnHandle, }; use actix_web::error::Error; use bytes::Bytes; -use futures_channel::oneshot::Sender; use futures_core::Stream; +use tokio::sync::oneshot::Sender; -/// Execution context for http actors +/// Execution context for HTTP actors pub struct HttpContext
where A: Actor>, @@ -165,10 +163,7 @@ where { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.fut.alive() { let _ = Pin::new(&mut self.fut).poll(cx); } @@ -233,14 +228,14 @@ mod tests { #[actix_rt::test] async fn test_default_resource() { - let mut srv = + let srv = init_service(App::new().service(web::resource("/test").to(|| { HttpResponse::Ok().streaming(HttpContext::create(MyActor { count: 0 })) }))) .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let body = read_body(resp).await; diff --git a/actix-web-actors/src/lib.rs b/actix-web-actors/src/lib.rs index 1fdd266dc..4310d784e 100644 --- a/actix-web-actors/src/lib.rs +++ b/actix-web-actors/src/lib.rs @@ -1,4 +1,4 @@ -//! Actix actors integration for Actix web framework +//! Actix actors support for Actix Web. #![deny(rust_2018_idioms)] #![allow(clippy::borrow_interior_mutable_const)] diff --git a/actix-web-actors/src/ws.rs b/actix-web-actors/src/ws.rs index 8fd03f6a1..1ab4cfce5 100644 --- a/actix-web-actors/src/ws.rs +++ b/actix-web-actors/src/ws.rs @@ -1,18 +1,18 @@ -//! Websocket integration -use std::collections::VecDeque; +//! Websocket integration. + use std::future::Future; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; +use std::{collections::VecDeque, convert::TryFrom}; use actix::dev::{ - AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, StreamHandler, - ToEnvelope, + AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, StreamHandler, ToEnvelope, }; use actix::fut::ActorFuture; use actix::{ - Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, - Message as ActixMessage, SpawnHandle, + Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message as ActixMessage, + SpawnHandle, }; use actix_codec::{Decoder, Encoder}; use actix_http::ws::{hash_key, Codec}; @@ -24,21 +24,21 @@ use actix_web::error::{Error, PayloadError}; use actix_web::http::{header, Method, StatusCode}; use actix_web::{HttpRequest, HttpResponse}; use bytes::{Bytes, BytesMut}; -use futures_channel::oneshot::Sender; +use bytestring::ByteString; use futures_core::Stream; +use tokio::sync::oneshot::Sender; -/// Do websocket handshake and start ws actor. +/// Perform WebSocket handshake and start actor. pub fn start(actor: A, req: &HttpRequest, stream: T) -> Result where - A: Actor> - + StreamHandler>, + A: Actor> + StreamHandler>, T: Stream> + 'static, { let mut res = handshake(req)?; Ok(res.streaming(WebsocketContext::create(actor, stream))) } -/// Do websocket handshake and start ws actor. +/// Perform WebSocket handshake and start actor. /// /// `req` is an HTTP Request that should be requesting a websocket protocol /// change. `stream` should be a `Bytes` stream (such as @@ -48,15 +48,14 @@ where /// /// If successful, returns a pair where the first item is an address for the /// created actor and the second item is the response that should be returned -/// from the websocket request. +/// from the WebSocket request. pub fn start_with_addr( actor: A, req: &HttpRequest, stream: T, ) -> Result<(Addr, HttpResponse), Error> where - A: Actor> - + StreamHandler>, + A: Actor> + StreamHandler>, T: Stream> + 'static, { let mut res = handshake(req)?; @@ -64,7 +63,7 @@ where Ok((addr, res.streaming(out_stream))) } -/// Do websocket handshake and start ws actor. +/// Do WebSocket handshake and start ws actor. /// /// `protocols` is a sequence of known protocols. pub fn start_with_protocols( @@ -74,15 +73,14 @@ pub fn start_with_protocols( stream: T, ) -> Result where - A: Actor> - + StreamHandler>, + A: Actor> + StreamHandler>, T: Stream> + 'static, { let mut res = handshake_with_protocols(req, protocols)?; Ok(res.streaming(WebsocketContext::create(actor, stream))) } -/// Prepare `WebSocket` handshake response. +/// Prepare WebSocket handshake response. /// /// This function returns handshake `HttpResponse`, ready to send to peer. /// It does not perform any IO. @@ -90,7 +88,7 @@ pub fn handshake(req: &HttpRequest) -> Result( - stream: S, - f: F, - ) -> impl Stream> + pub fn with_factory(stream: S, f: F) -> impl Stream> where F: FnOnce(&mut Self) -> A + 'static, A: StreamHandler>, @@ -338,13 +333,13 @@ where /// Send text frame #[inline] - pub fn text>(&mut self, text: T) { + pub fn text(&mut self, text: impl Into) { self.write_raw(Message::Text(text.into())); } /// Send binary frame #[inline] - pub fn binary>(&mut self, data: B) { + pub fn binary(&mut self, data: impl Into) { self.write_raw(Message::Binary(data.into())); } @@ -421,10 +416,7 @@ where { type Item = Result; - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); if this.fut.alive() { @@ -491,10 +483,7 @@ where { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.as_mut().project(); if !*this.closed { @@ -510,9 +499,10 @@ where } Poll::Pending => break, Poll::Ready(Some(Err(e))) => { - return Poll::Ready(Some(Err(ProtocolError::Io( - io::Error::new(io::ErrorKind::Other, format!("{}", e)), - )))); + return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::new( + io::ErrorKind::Other, + format!("{}", e), + ))))); } } } @@ -528,16 +518,14 @@ where } Some(frm) => { let msg = match frm { - Frame::Text(data) => Message::Text( - std::str::from_utf8(&data) - .map_err(|e| { - ProtocolError::Io(io::Error::new( - io::ErrorKind::Other, - format!("{}", e), - )) - })? - .to_string(), - ), + Frame::Text(data) => { + Message::Text(ByteString::try_from(data).map_err(|e| { + ProtocolError::Io(io::Error::new( + io::ErrorKind::Other, + format!("{}", e), + )) + })?) + } Frame::Binary(data) => Message::Binary(data), Frame::Ping(s) => Message::Ping(s), Frame::Pong(s) => Message::Pong(s), @@ -573,7 +561,7 @@ mod tests { ); let req = TestRequest::default() - .header(header::UPGRADE, header::HeaderValue::from_static("test")) + .insert_header((header::UPGRADE, header::HeaderValue::from_static("test"))) .to_http_request(); assert_eq!( HandshakeError::NoWebsocketUpgrade, @@ -581,10 +569,10 @@ mod tests { ); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) + )) .to_http_request(); assert_eq!( HandshakeError::NoConnectionUpgrade, @@ -592,14 +580,14 @@ mod tests { ); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + )) + .insert_header(( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) + )) .to_http_request(); assert_eq!( HandshakeError::NoVersionHeader, @@ -607,18 +595,18 @@ mod tests { ); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + )) + .insert_header(( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("5"), - ) + )) .to_http_request(); assert_eq!( HandshakeError::UnsupportedVersion, @@ -626,18 +614,18 @@ mod tests { ); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + )) + .insert_header(( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), - ) + )) .to_http_request(); assert_eq!( HandshakeError::BadWebsocketKey, @@ -645,22 +633,22 @@ mod tests { ); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + )) + .insert_header(( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("13"), - ) + )) .to_http_request(); let resp = handshake(&req).unwrap().finish(); @@ -669,26 +657,26 @@ mod tests { assert_eq!(None, resp.headers().get(&header::TRANSFER_ENCODING)); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + )) + .insert_header(( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("13"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_PROTOCOL, header::HeaderValue::from_static("graphql"), - ) + )) .to_http_request(); let protocols = ["graphql"]; @@ -710,26 +698,26 @@ mod tests { ); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + )) + .insert_header(( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("13"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_PROTOCOL, header::HeaderValue::from_static("p1, p2, p3"), - ) + )) .to_http_request(); let protocols = vec!["p3", "p2"]; @@ -751,26 +739,26 @@ mod tests { ); let req = TestRequest::default() - .header( + .insert_header(( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + )) + .insert_header(( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("13"), - ) - .header( + )) + .insert_header(( header::SEC_WEBSOCKET_PROTOCOL, header::HeaderValue::from_static("p1,p2,p3"), - ) + )) .to_http_request(); let protocols = vec!["p3", "p2"]; diff --git a/actix-web-actors/tests/test_ws.rs b/actix-web-actors/tests/test_ws.rs index dda9f6f0b..912480ae4 100644 --- a/actix-web-actors/tests/test_ws.rs +++ b/actix-web-actors/tests/test_ws.rs @@ -11,17 +11,13 @@ impl Actor for Ws { } impl StreamHandler> for Ws { - fn handle( - &mut self, - msg: Result, - ctx: &mut Self::Context, - ) { + fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { match msg.unwrap() { ws::Message::Ping(msg) => ctx.pong(&msg), ws::Message::Text(text) => ctx.text(text), ws::Message::Binary(bin) => ctx.binary(bin), ws::Message::Close(reason) => ctx.close(reason), - _ => (), + _ => {} } } } @@ -30,18 +26,13 @@ impl StreamHandler> for Ws { async fn test_simple() { let mut srv = test::start(|| { App::new().service(web::resource("/").to( - |req: HttpRequest, stream: web::Payload| async move { - ws::start(Ws, &req, stream) - }, + |req: HttpRequest, stream: web::Payload| async move { ws::start(Ws, &req, stream) }, )) }); // client service let mut framed = srv.ws().await.unwrap(); - framed - .send(ws::Message::Text("text".to_string())) - .await - .unwrap(); + framed.send(ws::Message::Text("text".into())).await.unwrap(); let item = framed.next().await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); diff --git a/actix-web-codegen/CHANGES.md b/actix-web-codegen/CHANGES.md index 1ab51f924..2ce728aad 100644 --- a/actix-web-codegen/CHANGES.md +++ b/actix-web-codegen/CHANGES.md @@ -1,6 +1,10 @@ # Changes -## Unreleased - 2020-xx-xx +## Unreleased - 2021-xx-xx + + +## 0.5.0-beta.1 - 2021-02-10 +* Use new call signature for `System::new`. ## 0.4.0 - 2020-09-20 diff --git a/actix-web-codegen/Cargo.toml b/actix-web-codegen/Cargo.toml index fd99a8376..886d9ac3e 100644 --- a/actix-web-codegen/Cargo.toml +++ b/actix-web-codegen/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "actix-web-codegen" -version = "0.4.0" -description = "Actix web proc macros" +version = "0.5.0-beta.1" +description = "Routing and runtime macros for Actix Web" readme = "README.md" homepage = "https://actix.rs" repository = "https://github.com/actix/actix-web" @@ -19,8 +19,8 @@ syn = { version = "1", features = ["full", "parsing"] } proc-macro2 = "1" [dev-dependencies] -actix-rt = "1.1.1" -actix-web = "3.0.0" -futures-util = { version = "0.3.5", default-features = false } +actix-rt = "2" +actix-web = "4.0.0-beta.3" +futures-util = { version = "0.3.7", default-features = false } trybuild = "1" rustversion = "1" diff --git a/actix-web-codegen/README.md b/actix-web-codegen/README.md index 887502075..5820bb443 100644 --- a/actix-web-codegen/README.md +++ b/actix-web-codegen/README.md @@ -1,22 +1,24 @@ # actix-web-codegen -> Helper and convenience macros for Actix Web +> Routing and runtime macros for Actix Web. -[![crates.io](https://meritbadge.herokuapp.com/actix-web-codegen)](https://crates.io/crates/actix-web-codegen) -[![Documentation](https://docs.rs/actix-web-codegen/badge.svg)](https://docs.rs/actix-web-codegen/0.4.0/actix_web_codegen/) +[![crates.io](https://img.shields.io/crates/v/actix-web-codegen?label=latest)](https://crates.io/crates/actix-web-codegen) +[![Documentation](https://docs.rs/actix-web-codegen/badge.svg?version=0.5.0-beta.1)](https://docs.rs/actix-web-codegen/0.5.0-beta.1) [![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) -[![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) -[![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) +![License](https://img.shields.io/crates/l/actix-web-codegen.svg) +
+[![dependency status](https://deps.rs/crate/actix-web-codegen/0.5.0-beta.1/status.svg)](https://deps.rs/crate/actix-web-codegen/0.5.0-beta.1) +[![Download](https://img.shields.io/crates/d/actix-web-codegen.svg)](https://crates.io/crates/actix-web-codegen) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) ## Documentation & Resources - [API Documentation](https://docs.rs/actix-web-codegen) - [Chat on Gitter](https://gitter.im/actix/actix-web) -- Cargo package: [actix-web-codegen](https://crates.io/crates/actix-web-codegen) - Minimum supported Rust version: 1.46 or later. ## Compile Testing + Uses the [`trybuild`] crate. All compile fail tests should include a stderr file generated by `trybuild`. See the [workflow section](https://github.com/dtolnay/trybuild#workflow) of the trybuild docs for info on how to do this. [`trybuild`]: https://github.com/dtolnay/trybuild diff --git a/actix-web-codegen/src/lib.rs b/actix-web-codegen/src/lib.rs index f5bec9773..b6f7b2b59 100644 --- a/actix-web-codegen/src/lib.rs +++ b/actix-web-codegen/src/lib.rs @@ -1,6 +1,6 @@ -//! Macros for reducing boilerplate code in Actix Web applications. +//! Routing and runtime macros for Actix Web. //! -//! ## Actix Web Re-exports +//! # Actix Web Re-exports //! Actix Web re-exports a version of this crate in it's entirety so you usually don't have to //! specify a dependency on this crate explicitly. Sometimes, however, updates are made to this //! crate before the actix-web dependency is updated. Therefore, code examples here will show @@ -10,7 +10,7 @@ //! # Runtime Setup //! Used for setting up the actix async runtime. See [macro@main] macro docs. //! -//! ```rust +//! ``` //! #[actix_web_codegen::main] // or `#[actix_web::main]` in Actix Web apps //! async fn main() { //! async { println!("Hello world"); }.await @@ -23,7 +23,7 @@ //! //! See docs for: [GET], [POST], [PATCH], [PUT], [DELETE], [HEAD], [CONNECT], [OPTIONS], [TRACE] //! -//! ```rust +//! ``` //! # use actix_web::HttpResponse; //! # use actix_web_codegen::get; //! #[get("/test")] @@ -36,7 +36,7 @@ //! Similar to the single method handler macro but takes one or more arguments for the HTTP methods //! it should respond to. See [macro@route] macro docs. //! -//! ```rust +//! ``` //! # use actix_web::HttpResponse; //! # use actix_web_codegen::route; //! #[route("/test", method="GET", method="HEAD")] @@ -160,7 +160,7 @@ method_macro! { /// # Actix Web Re-export /// This macro can be applied with `#[actix_web::main]` when used in Actix Web applications. /// -/// # Usage +/// # Examples /// ```rust /// #[actix_web_codegen::main] /// async fn main() { @@ -176,7 +176,6 @@ pub fn main(_: TokenStream, item: TokenStream) -> TokenStream { let vis = &input.vis; let sig = &mut input.sig; let body = &input.block; - let name = &sig.ident; if sig.asyncness.is_none() { return syn::Error::new_spanned(sig.fn_token, "only async fn is supported") @@ -189,7 +188,7 @@ pub fn main(_: TokenStream, item: TokenStream) -> TokenStream { (quote! { #(#attrs)* #vis #sig { - actix_web::rt::System::new(stringify!(#name)) + actix_web::rt::System::new() .block_on(async move { #body }) } }) diff --git a/actix-web-codegen/tests/test_macro.rs b/actix-web-codegen/tests/test_macro.rs index dd2bccd7f..34cdad649 100644 --- a/actix-web-codegen/tests/test_macro.rs +++ b/actix-web-codegen/tests/test_macro.rs @@ -1,14 +1,11 @@ use std::future::Future; -use std::pin::Pin; use std::task::{Context, Poll}; use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform}; use actix_web::http::header::{HeaderName, HeaderValue}; use actix_web::{http, test, web::Path, App, Error, HttpResponse, Responder}; -use actix_web_codegen::{ - connect, delete, get, head, options, patch, post, put, route, trace, -}; -use futures_util::future; +use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, route, trace}; +use futures_util::future::{self, LocalBoxFuture}; // Make sure that we can name function as 'config' #[get("/config")] @@ -88,17 +85,16 @@ async fn route_test() -> impl Responder { pub struct ChangeStatusCode; -impl Transform for ChangeStatusCode +impl Transform for ChangeStatusCode where - S: Service, Error = Error>, + S: Service, Error = Error>, S::Future: 'static, B: 'static, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type InitError = (); type Transform = ChangeStatusCodeMiddleware; + type InitError = (); type Future = future::Ready>; fn new_transform(&self, service: S) -> Self::Future { @@ -110,23 +106,21 @@ pub struct ChangeStatusCodeMiddleware { service: S, } -impl Service for ChangeStatusCodeMiddleware +impl Service for ChangeStatusCodeMiddleware where - S: Service, Error = Error>, + S: Service, Error = Error>, S::Future: 'static, B: 'static, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - #[allow(clippy::type_complexity)] - type Future = Pin>>>; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } - fn call(&mut self, req: ServiceRequest) -> Self::Future { + fn call(&self, req: ServiceRequest) -> Self::Future { let fut = self.service.call(req); Box::pin(async move { diff --git a/awc/CHANGES.md b/awc/CHANGES.md index e4f801bbe..e6ead2cc8 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -1,8 +1,48 @@ # Changes -## Unreleased - 2020-xx-xx +## Unreleased - 2021-xx-xx +### Added +* `ClientResponse::timeout` for set the timeout of collecting response body. [#1931] + ### Changed -* Bumped `rand` to `0.8` +* Feature `cookies` is now optional and enabled by default. [#1981] +* `ClientBuilder::connector` method would take `actix_http::client::Connector` type. [#2008] + +### Removed +* `ClientBuilder::default` function [#2008] +* `ClientBuilder::disable_redirects` and `ClientBuilder::max_redirects` method [#2008] + +[#1931]: https://github.com/actix/actix-web/pull/1931 +[#1981]: https://github.com/actix/actix-web/pull/1981 +[#2008]: https://github.com/actix/actix-web/pull/2008 + +## 3.0.0-beta.2 - 2021-02-10 +### Added +* `ClientRequest::insert_header` method which allows using typed headers. [#1869] +* `ClientRequest::append_header` method which allows using typed headers. [#1869] +* `trust-dns` optional feature to enable `trust-dns-resolver` as client dns resolver. [#1969] + +### Changed +* Relax default timeout for `Connector` to 5 seconds(original 1 second). [#1905] + +### Removed +* `ClientRequest::set`; use `ClientRequest::insert_header`. [#1869] +* `ClientRequest::set_header`; use `ClientRequest::insert_header`. [#1869] +* `ClientRequest::set_header_if_none`; use `ClientRequest::insert_header_if_none`. [#1869] +* `ClientRequest::header`; use `ClientRequest::append_header`. [#1869] + +[#1869]: https://github.com/actix/actix-web/pull/1869 +[#1905]: https://github.com/actix/actix-web/pull/1905 +[#1969]: https://github.com/actix/actix-web/pull/1969 + + +## 3.0.0-beta.1 - 2021-01-07 +### Changed +* Update `rand` to `0.8` +* Update `bytes` to `1.0`. [#1813] +* Update `rust-tls` to `0.19`. [#1813] + +[#1813]: https://github.com/actix/actix-web/pull/1813 ## 2.0.3 - 2020-11-29 diff --git a/awc/Cargo.toml b/awc/Cargo.toml index 2e92526d2..9beecc6d4 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "awc" -version = "2.0.3" +version = "3.0.0-beta.2" authors = ["Nikolay Kim "] description = "Async HTTP and WebSocket client library built on the Actix ecosystem" readme = "README.md" @@ -22,31 +22,38 @@ name = "awc" path = "src/lib.rs" [package.metadata.docs.rs] -features = ["openssl", "rustls", "compress"] +# features that docs.rs will build with +features = ["openssl", "rustls", "compress", "cookies"] [features] -default = ["compress"] +default = ["compress", "cookies"] # openssl -openssl = ["open-ssl", "actix-http/openssl"] +openssl = ["tls-openssl", "actix-http/openssl"] # rustls -rustls = ["rust-tls", "actix-http/rustls"] +rustls = ["tls-rustls", "actix-http/rustls"] # content-encoding support compress = ["actix-http/compress"] +# cookie parsing and cookie jar +cookies = ["actix-http/cookies"] + +# trust-dns as dns resolver +trust-dns = ["actix-http/trust-dns"] + [dependencies] -actix-codec = "0.3.0" -actix-service = "1.0.6" -actix-http = "2.2.0" -actix-rt = "1.0.0" +actix-codec = "0.4.0-beta.1" +actix-service = "2.0.0-beta.4" +actix-http = "3.0.0-beta.3" +actix-rt = "2" base64 = "0.13" -bytes = "0.5.3" +bytes = "1" cfg-if = "1.0" -derive_more = "0.99.2" -futures-core = { version = "0.3.5", default-features = false } +derive_more = "0.99.5" +futures-core = { version = "0.3.7", default-features = false } log =" 0.4" mime = "0.3" percent-encoding = "2.1" @@ -54,19 +61,26 @@ rand = "0.8" serde = "1.0" serde_json = "1.0" serde_urlencoded = "0.7" -open-ssl = { version = "0.10", package = "openssl", optional = true } -rust-tls = { version = "0.18.0", package = "rustls", optional = true, features = ["dangerous_configuration"] } +tls-openssl = { version = "0.10.9", package = "openssl", optional = true } +tls-rustls = { version = "0.19.0", package = "rustls", optional = true, features = ["dangerous_configuration"] } + +[target.'cfg(windows)'.dependencies.tls-openssl] +version = "0.10.9" +package = "openssl" +features = ["vendored"] +optional = true [dev-dependencies] -actix-connect = { version = "2.0.0", features = ["openssl"] } -actix-web = { version = "3.0.0", features = ["openssl"] } -actix-http = { version = "2.0.0", features = ["openssl"] } -actix-http-test = { version = "2.0.0", features = ["openssl"] } -actix-utils = "2.0.0" -actix-server = "1.0.0" -actix-tls = { version = "2.0.0", features = ["openssl", "rustls"] } +actix-web = { version = "4.0.0-beta.3", features = ["openssl"] } +actix-http = { version = "3.0.0-beta.3", features = ["openssl"] } +actix-http-test = { version = "3.0.0-beta.2", features = ["openssl"] } +actix-utils = "3.0.0-beta.1" +actix-server = "2.0.0-beta.3" +actix-tls = { version = "3.0.0-beta.3", features = ["openssl", "rustls"] } + brotli2 = "0.3.2" flate2 = "1.0.13" -futures-util = { version = "0.3.5", default-features = false } -env_logger = "0.7" +futures-util = { version = "0.3.7", default-features = false } +env_logger = "0.8" +rcgen = "0.8" webpki = "0.21" diff --git a/awc/README.md b/awc/README.md index 972a80140..043ae6a41 100644 --- a/awc/README.md +++ b/awc/README.md @@ -3,9 +3,9 @@ > Async HTTP and WebSocket client library. [![crates.io](https://img.shields.io/crates/v/awc?label=latest)](https://crates.io/crates/awc) -[![Documentation](https://docs.rs/awc/badge.svg?version=2.0.3)](https://docs.rs/awc/2.0.3) -![Apache 2.0 or MIT licensed](https://img.shields.io/crates/l/awc) -[![Dependency Status](https://deps.rs/crate/awc/2.0.3/status.svg)](https://deps.rs/crate/awc/2.0.3) +[![Documentation](https://docs.rs/awc/badge.svg?version=3.0.0-beta.2)](https://docs.rs/awc/3.0.0-beta.2) +![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/awc) +[![Dependency Status](https://deps.rs/crate/awc/3.0.0-beta.2/status.svg)](https://deps.rs/crate/awc/3.0.0-beta.2) [![Join the chat at https://gitter.im/actix/actix-web](https://badges.gitter.im/actix/actix-web.svg)](https://gitter.im/actix/actix-web?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) ## Documentation & Resources @@ -21,7 +21,7 @@ use actix_rt::System; use awc::Client; fn main() { - System::new("test").block_on(async { + System::new().block_on(async { let client = Client::default(); let res = client diff --git a/awc/src/builder.rs b/awc/src/builder.rs index 7cd659c38..4495b39fd 100644 --- a/awc/src/builder.rs +++ b/awc/src/builder.rs @@ -1,63 +1,82 @@ -use std::cell::RefCell; use std::convert::TryFrom; use std::fmt; use std::rc::Rc; use std::time::Duration; -use actix_http::client::{Connect as HttpConnect, ConnectError, Connection, Connector}; -use actix_http::http::{self, header, Error as HttpError, HeaderMap, HeaderName}; +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_http::{ + client::{Connector, TcpConnect, TcpConnectError, TcpConnection}, + http::{self, header, Error as HttpError, HeaderMap, HeaderName, Uri}, +}; +use actix_rt::net::TcpStream; use actix_service::Service; -use crate::connect::{Connect, ConnectorWrapper}; +use crate::connect::ConnectorWrapper; use crate::{Client, ClientConfig}; /// An HTTP Client builder /// /// This type can be used to construct an instance of `Client` through a /// builder-like pattern. -pub struct ClientBuilder { +pub struct ClientBuilder { default_headers: bool, - allow_redirects: bool, - max_redirects: usize, max_http_version: Option, stream_window_size: Option, conn_window_size: Option, headers: HeaderMap, timeout: Option, - connector: Option>>, -} - -impl Default for ClientBuilder { - fn default() -> Self { - Self::new() - } + connector: Connector, } impl ClientBuilder { - pub fn new() -> Self { + #[allow(clippy::new_ret_no_self)] + pub fn new() -> ClientBuilder< + impl Service< + TcpConnect, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone, + TcpStream, + > { ClientBuilder { default_headers: true, - allow_redirects: true, - max_redirects: 10, headers: HeaderMap::new(), timeout: Some(Duration::from_secs(5)), - connector: None, + connector: Connector::new(), max_http_version: None, stream_window_size: None, conn_window_size: None, } } +} +impl ClientBuilder +where + S: Service, Response = TcpConnection, Error = TcpConnectError> + + Clone + + 'static, + Io: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, +{ /// Use custom connector service. - pub fn connector(mut self, connector: T) -> Self + pub fn connector(self, connector: Connector) -> ClientBuilder where - T: Service + 'static, - T::Response: Connection, - ::Future: 'static, - T::Future: 'static, + S1: Service< + TcpConnect, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone + + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, { - self.connector = Some(RefCell::new(Box::new(ConnectorWrapper(connector)))); - self + ClientBuilder { + default_headers: self.default_headers, + headers: self.headers, + timeout: self.timeout, + connector, + max_http_version: self.max_http_version, + stream_window_size: self.stream_window_size, + conn_window_size: self.conn_window_size, + } } /// Set request timeout @@ -75,16 +94,9 @@ impl ClientBuilder { self } - /// Do not follow redirects. + /// Maximum supported HTTP major version. /// - /// Redirects are allowed by default. - pub fn disable_redirects(mut self) -> Self { - self.allow_redirects = false; - self - } - - /// Maximum supported http major version - /// Supported versions http/1.1, http/2 + /// Supported versions are HTTP/1.1 and HTTP/2. pub fn max_http_version(mut self, val: http::Version) -> Self { self.max_http_version = Some(val); self @@ -108,14 +120,6 @@ impl ClientBuilder { self } - /// Set max number of redirects. - /// - /// Max redirects is set to 10 by default. - pub fn max_redirects(mut self, num: usize) -> Self { - self.max_redirects = num; - self - } - /// Do not add default request headers. /// By default `Date` and `User-Agent` headers are set. pub fn no_default_headers(mut self) -> Self { @@ -133,7 +137,7 @@ impl ClientBuilder { V::Error: fmt::Debug, { match HeaderName::try_from(key) { - Ok(key) => match value.try_into() { + Ok(key) => match value.try_into_value() { Ok(value) => { self.headers.append(key, value); } @@ -145,9 +149,9 @@ impl ClientBuilder { } /// Set client wide HTTP basic authorization header - pub fn basic_auth(self, username: U, password: Option<&str>) -> Self + pub fn basic_auth(self, username: N, password: Option<&str>) -> Self where - U: fmt::Display, + N: fmt::Display, { let auth = match password { Some(password) => format!("{}:{}", username, password), @@ -169,28 +173,24 @@ impl ClientBuilder { /// Finish build process and create `Client` instance. pub fn finish(self) -> Client { - let connector = if let Some(connector) = self.connector { - connector - } else { - let mut connector = Connector::new(); - if let Some(val) = self.max_http_version { - connector = connector.max_http_version(val) - }; - if let Some(val) = self.conn_window_size { - connector = connector.initial_connection_window_size(val) - }; - if let Some(val) = self.stream_window_size { - connector = connector.initial_window_size(val) - }; - RefCell::new( - Box::new(ConnectorWrapper(connector.finish())) as Box - ) + let mut connector = self.connector; + + if let Some(val) = self.max_http_version { + connector = connector.max_http_version(val); }; + if let Some(val) = self.conn_window_size { + connector = connector.initial_connection_window_size(val) + }; + if let Some(val) = self.stream_window_size { + connector = connector.initial_window_size(val) + }; + let config = ClientConfig { headers: self.headers, timeout: self.timeout, - connector, + connector: Box::new(ConnectorWrapper::new(connector.finish())) as _, }; + Client(Rc::new(config)) } } diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 7fbe1543a..97af2d1cc 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -1,193 +1,115 @@ -use std::future::Future; -use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll}; -use std::{fmt, io, mem, net}; - -use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use actix_http::body::Body; -use actix_http::client::{ - Connect as ClientConnect, ConnectError, Connection, SendRequestError, +use std::{ + fmt, io, net, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; +use actix_http::{ + body::Body, + client::{Connect as ClientConnect, ConnectError, Connection, SendRequestError}, + h1::ClientCodec, + RequestHead, RequestHeadType, ResponseHead, }; -use actix_http::h1::ClientCodec; -use actix_http::http::HeaderMap; -use actix_http::{RequestHead, RequestHeadType, ResponseHead}; use actix_service::Service; +use futures_core::future::LocalBoxFuture; use crate::response::ClientResponse; -pub(crate) struct ConnectorWrapper(pub T); - -pub(crate) trait Connect { - fn send_request( - &mut self, - head: RequestHead, - body: Body, - addr: Option, - ) -> Pin>>>; - - fn send_request_extra( - &mut self, - head: Rc, - extra_headers: Option, - body: Body, - addr: Option, - ) -> Pin>>>; - - /// Send request, returns Response and Framed - fn open_tunnel( - &mut self, - head: RequestHead, - addr: Option, - ) -> Pin< - Box< - dyn Future< - Output = Result< - (ResponseHead, Framed), - SendRequestError, - >, - >, - >, - >; - - /// Send request and extra headers, returns Response and Framed - fn open_tunnel_extra( - &mut self, - head: Rc, - extra_headers: Option, - addr: Option, - ) -> Pin< - Box< - dyn Future< - Output = Result< - (ResponseHead, Framed), - SendRequestError, - >, - >, - >, - >; +pub(crate) struct ConnectorWrapper { + connector: T, } -impl Connect for ConnectorWrapper +impl ConnectorWrapper { + pub(crate) fn new(connector: T) -> Self { + Self { connector } + } +} + +pub type ConnectService = Box< + dyn Service< + ConnectRequest, + Response = ConnectResponse, + Error = SendRequestError, + Future = LocalBoxFuture<'static, Result>, + >, +>; + +pub enum ConnectRequest { + Client(RequestHeadType, Body, Option), + Tunnel(RequestHead, Option), +} + +pub enum ConnectResponse { + Client(ClientResponse), + Tunnel(ResponseHead, Framed), +} + +impl ConnectResponse { + pub fn into_client_response(self) -> ClientResponse { + match self { + ConnectResponse::Client(res) => res, + _ => panic!( + "ClientResponse only reachable with ConnectResponse::ClientResponse variant" + ), + } + } + + pub fn into_tunnel_response(self) -> (ResponseHead, Framed) { + match self { + ConnectResponse::Tunnel(head, framed) => (head, framed), + _ => panic!( + "TunnelResponse only reachable with ConnectResponse::TunnelResponse variant" + ), + } + } +} + +impl Service for ConnectorWrapper where - T: Service, + T: Service, T::Response: Connection, ::Io: 'static, - ::Future: 'static, - ::TunnelFuture: 'static, T::Future: 'static, { - fn send_request( - &mut self, - head: RequestHead, - body: Body, - addr: Option, - ) -> Pin>>> { + type Response = ConnectResponse; + type Error = SendRequestError; + type Future = LocalBoxFuture<'static, Result>; + + actix_service::forward_ready!(connector); + + fn call(&self, req: ConnectRequest) -> Self::Future { // connect to the host - let fut = self.0.call(ClientConnect { - uri: head.uri.clone(), - addr, - }); + let fut = match req { + ConnectRequest::Client(ref head, .., addr) => self.connector.call(ClientConnect { + uri: head.as_ref().uri.clone(), + addr, + }), + ConnectRequest::Tunnel(ref head, addr) => self.connector.call(ClientConnect { + uri: head.uri.clone(), + addr, + }), + }; Box::pin(async move { let connection = fut.await?; - // send request - connection - .send_request(RequestHeadType::from(head), body) - .await - .map(|(head, payload)| ClientResponse::new(head, payload)) - }) - } + match req { + ConnectRequest::Client(head, body, ..) => { + // send request + let (head, payload) = connection.send_request(head, body).await?; - fn send_request_extra( - &mut self, - head: Rc, - extra_headers: Option, - body: Body, - addr: Option, - ) -> Pin>>> { - // connect to the host - let fut = self.0.call(ClientConnect { - uri: head.uri.clone(), - addr, - }); + Ok(ConnectResponse::Client(ClientResponse::new(head, payload))) + } + ConnectRequest::Tunnel(head, ..) => { + // send request + let (head, framed) = + connection.open_tunnel(RequestHeadType::from(head)).await?; - Box::pin(async move { - let connection = fut.await?; - - // send request - let (head, payload) = connection - .send_request(RequestHeadType::Rc(head, extra_headers), body) - .await?; - - Ok(ClientResponse::new(head, payload)) - }) - } - - fn open_tunnel( - &mut self, - head: RequestHead, - addr: Option, - ) -> Pin< - Box< - dyn Future< - Output = Result< - (ResponseHead, Framed), - SendRequestError, - >, - >, - >, - > { - // connect to the host - let fut = self.0.call(ClientConnect { - uri: head.uri.clone(), - addr, - }); - - Box::pin(async move { - let connection = fut.await?; - - // send request - let (head, framed) = - connection.open_tunnel(RequestHeadType::from(head)).await?; - - let framed = framed.into_map_io(|io| BoxedSocket(Box::new(Socket(io)))); - Ok((head, framed)) - }) - } - - fn open_tunnel_extra( - &mut self, - head: Rc, - extra_headers: Option, - addr: Option, - ) -> Pin< - Box< - dyn Future< - Output = Result< - (ResponseHead, Framed), - SendRequestError, - >, - >, - >, - > { - // connect to the host - let fut = self.0.call(ClientConnect { - uri: head.uri.clone(), - addr, - }); - - Box::pin(async move { - let connection = fut.await?; - - // send request - let (head, framed) = connection - .open_tunnel(RequestHeadType::Rc(head, extra_headers)) - .await?; - - let framed = framed.into_map_io(|io| BoxedSocket(Box::new(Socket(io)))); - Ok((head, framed)) + let framed = framed.into_map_io(|io| BoxedSocket(Box::new(Socket(io)))); + Ok(ConnectResponse::Tunnel(head, framed)) + } + } }) } } @@ -221,18 +143,11 @@ impl fmt::Debug for BoxedSocket { } impl AsyncRead for BoxedSocket { - unsafe fn prepare_uninitialized_buffer( - &self, - buf: &mut [mem::MaybeUninit], - ) -> bool { - self.0.as_read().prepare_uninitialized_buffer(buf) - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf) } } @@ -250,10 +165,7 @@ impl AsyncWrite for BoxedSocket { Pin::new(self.get_mut().0.as_write()).poll_flush(cx) } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(self.get_mut().0.as_write()).poll_shutdown(cx) } } diff --git a/awc/src/error.rs b/awc/src/error.rs index d008166d9..f86224e62 100644 --- a/awc/src/error.rs +++ b/awc/src/error.rs @@ -1,7 +1,6 @@ -//! Http client errors -pub use actix_http::client::{ - ConnectError, FreezeRequestError, InvalidUrl, SendRequestError, -}; +//! HTTP client errors + +pub use actix_http::client::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}; pub use actix_http::error::PayloadError; pub use actix_http::http::Error as HttpError; pub use actix_http::ws::HandshakeError as WsHandshakeError; diff --git a/awc/src/frozen.rs b/awc/src/frozen.rs index f7098863c..46b4063a0 100644 --- a/awc/src/frozen.rs +++ b/awc/src/frozen.rs @@ -144,8 +144,10 @@ impl FrozenSendBuilder { V: IntoHeaderValue, { match HeaderName::try_from(key) { - Ok(key) => match value.try_into() { - Ok(value) => self.extra_headers.insert(key, value), + Ok(key) => match value.try_into_value() { + Ok(value) => { + self.extra_headers.insert(key, value); + } Err(e) => self.err = Some(e.into()), }, Err(e) => self.err = Some(e.into()), diff --git a/awc/src/lib.rs b/awc/src/lib.rs index 611a7daa7..f92eabbfb 100644 --- a/awc/src/lib.rs +++ b/awc/src/lib.rs @@ -2,12 +2,12 @@ //! //! ## Making a GET request //! -//! ```rust +//! ```no_run //! # #[actix_rt::main] //! # async fn main() -> Result<(), awc::error::SendRequestError> { //! let mut client = awc::Client::default(); //! let response = client.get("http://www.rust-lang.org") // <- Create request builder -//! .header("User-Agent", "Actix-web") +//! .insert_header(("User-Agent", "Actix-web")) //! .send() // <- Send http request //! .await?; //! @@ -20,7 +20,7 @@ //! //! ### Raw body contents //! -//! ```rust +//! ```no_run //! # #[actix_rt::main] //! # async fn main() -> Result<(), awc::error::SendRequestError> { //! let mut client = awc::Client::default(); @@ -33,7 +33,7 @@ //! //! ### Forms //! -//! ```rust +//! ```no_run //! # #[actix_rt::main] //! # async fn main() -> Result<(), awc::error::SendRequestError> { //! let params = [("foo", "bar"), ("baz", "quux")]; @@ -48,7 +48,7 @@ //! //! ### JSON //! -//! ```rust +//! ```no_run //! # #[actix_rt::main] //! # async fn main() -> Result<(), awc::error::SendRequestError> { //! let request = serde_json::json!({ @@ -66,7 +66,7 @@ //! //! ## WebSocket support //! -//! ``` +//! ```no_run //! # #[actix_rt::main] //! # async fn main() -> Result<(), Box> { //! use futures_util::{sink::SinkExt, stream::StreamExt}; @@ -76,7 +76,7 @@ //! .await?; //! //! connection -//! .send(awc::ws::Message::Text("Echo".to_string())) +//! .send(awc::ws::Message::Text("Echo".into())) //! .await?; //! let response = connection.next().await.unwrap()?; //! # assert_eq!(response, awc::ws::Frame::Text("Echo".as_bytes().into())); @@ -94,15 +94,21 @@ #![doc(html_favicon_url = "https://actix.rs/favicon.ico")] #![clippy::msrv = "1.46"] -use std::cell::RefCell; use std::convert::TryFrom; use std::rc::Rc; use std::time::Duration; -pub use actix_http::{client::Connector, cookie, http}; +#[cfg(feature = "cookies")] +pub use actix_http::cookie; +pub use actix_http::{client::Connector, http}; -use actix_http::http::{Error as HttpError, HeaderMap, Method, Uri}; -use actix_http::RequestHead; +use actix_http::{ + client::{TcpConnect, TcpConnectError, TcpConnection}, + http::{Error as HttpError, HeaderMap, Method, Uri}, + RequestHead, +}; +use actix_rt::net::TcpStream; +use actix_service::Service; mod builder; mod connect; @@ -115,13 +121,13 @@ pub mod test; pub mod ws; pub use self::builder::ClientBuilder; -pub use self::connect::BoxedSocket; +pub use self::connect::{BoxedSocket, ConnectRequest, ConnectResponse, ConnectService}; pub use self::frozen::{FrozenClientRequest, FrozenSendBuilder}; pub use self::request::ClientRequest; pub use self::response::{ClientResponse, JsonBody, MessageBody}; pub use self::sender::SendClientRequest; -use self::connect::{Connect, ConnectorWrapper}; +use self::connect::ConnectorWrapper; /// An asynchronous HTTP and WebSocket client. /// @@ -135,8 +141,8 @@ use self::connect::{Connect, ConnectorWrapper}; /// let mut client = Client::default(); /// /// let res = client.get("http://www.rust-lang.org") // <- Create request builder -/// .header("User-Agent", "Actix-web") -/// .send() // <- Send http request +/// .insert_header(("User-Agent", "Actix-web")) +/// .send() // <- Send HTTP request /// .await; // <- send request and wait for response /// /// println!("Response: {:?}", res); @@ -146,7 +152,7 @@ use self::connect::{Connect, ConnectorWrapper}; pub struct Client(Rc); pub(crate) struct ClientConfig { - pub(crate) connector: RefCell>, + pub(crate) connector: ConnectService, pub(crate) headers: HeaderMap, pub(crate) timeout: Option, } @@ -154,9 +160,7 @@ pub(crate) struct ClientConfig { impl Default for Client { fn default() -> Self { Client(Rc::new(ClientConfig { - connector: RefCell::new(Box::new(ConnectorWrapper( - Connector::new().finish(), - ))), + connector: Box::new(ConnectorWrapper::new(Connector::new().finish())), headers: HeaderMap::new(), timeout: Some(Duration::from_secs(5)), })) @@ -171,7 +175,14 @@ impl Client { /// Create `Client` builder. /// This function is equivalent of `ClientBuilder::new()`. - pub fn builder() -> ClientBuilder { + pub fn builder() -> ClientBuilder< + impl Service< + TcpConnect, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone, + TcpStream, + > { ClientBuilder::new() } @@ -183,8 +194,8 @@ impl Client { { let mut req = ClientRequest::new(method, url, self.0.clone()); - for (key, value) in self.0.headers.iter() { - req = req.set_header_if_none(key.clone(), value.clone()); + for header in self.0.headers.iter() { + req = req.insert_header_if_none(header); } req } @@ -199,8 +210,8 @@ impl Client { >::Error: Into, { let mut req = self.request(head.method.clone(), url); - for (key, value) in head.headers.iter() { - req = req.set_header_if_none(key.clone(), value.clone()); + for header in head.headers.iter() { + req = req.insert_header_if_none(header); } req } diff --git a/awc/src/request.rs b/awc/src/request.rs index 1e49aae3c..812c0e805 100644 --- a/awc/src/request.rs +++ b/awc/src/request.rs @@ -8,11 +8,11 @@ use futures_core::Stream; use serde::Serialize; use actix_http::body::Body; +#[cfg(feature = "cookies")] use actix_http::cookie::{Cookie, CookieJar}; -use actix_http::http::header::{self, Header, IntoHeaderValue}; +use actix_http::http::header::{self, IntoHeaderPair}; use actix_http::http::{ - uri, ConnectionType, Error as HttpError, HeaderMap, HeaderName, HeaderValue, Method, - Uri, Version, + uri, ConnectionType, Error as HttpError, HeaderMap, HeaderValue, Method, Uri, Version, }; use actix_http::{Error, RequestHead}; @@ -37,17 +37,15 @@ cfg_if::cfg_if! { /// builder-like pattern. /// /// ```rust -/// use actix_rt::System; -/// /// #[actix_rt::main] /// async fn main() { /// let response = awc::Client::new() /// .get("http://www.rust-lang.org") // <- Create request builder -/// .header("User-Agent", "Actix-web") -/// .send() // <- Send http request +/// .insert_header(("User-Agent", "Actix-web")) +/// .send() // <- Send HTTP request /// .await; /// -/// response.and_then(|response| { // <- server http response +/// response.and_then(|response| { // <- server HTTP response /// println!("Response: {:?}", response); /// Ok(()) /// }); @@ -57,10 +55,12 @@ pub struct ClientRequest { pub(crate) head: RequestHead, err: Option, addr: Option, - cookies: Option, response_decompress: bool, timeout: Option, config: Rc, + + #[cfg(feature = "cookies")] + cookies: Option, } impl ClientRequest { @@ -75,6 +75,7 @@ impl ClientRequest { head: RequestHead::default(), err: None, addr: None, + #[cfg(feature = "cookies")] cookies: None, timeout: None, response_decompress: true, @@ -143,110 +144,73 @@ impl ClientRequest { &self.head.peer_addr } - #[inline] /// Returns request's headers. + #[inline] pub fn headers(&self) -> &HeaderMap { &self.head.headers } - #[inline] /// Returns request's mutable headers. + #[inline] pub fn headers_mut(&mut self) -> &mut HeaderMap { &mut self.head.headers } - /// Set a header. - /// - /// ```rust - /// fn main() { - /// # actix_rt::System::new("test").block_on(futures_util::future::lazy(|_| { - /// let req = awc::Client::new() - /// .get("http://www.rust-lang.org") - /// .set(awc::http::header::Date::now()) - /// .set(awc::http::header::ContentType(mime::TEXT_HTML)); - /// # Ok::<_, ()>(()) - /// # })); - /// } - /// ``` - pub fn set(mut self, hdr: H) -> Self { - match hdr.try_into() { - Ok(value) => { - self.head.headers.insert(H::name(), value); + /// Insert a header, replacing any that were set with an equivalent field name. + pub fn insert_header(mut self, header: H) -> Self + where + H: IntoHeaderPair, + { + match header.try_into_header_pair() { + Ok((key, value)) => { + self.head.headers.insert(key, value); } Err(e) => self.err = Some(e.into()), - } - self - } + }; - /// Append a header. - /// - /// Header gets appended to existing header. - /// To override header use `set_header()` method. - /// - /// ```rust - /// use awc::{http, Client}; - /// - /// fn main() { - /// # actix_rt::System::new("test").block_on(async { - /// let req = Client::new() - /// .get("http://www.rust-lang.org") - /// .header("X-TEST", "value") - /// .header(http::header::CONTENT_TYPE, "application/json"); - /// # Ok::<_, ()>(()) - /// # }); - /// } - /// ``` - pub fn header(mut self, key: K, value: V) -> Self - where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, - { - match HeaderName::try_from(key) { - Ok(key) => match value.try_into() { - Ok(value) => self.head.headers.append(key, value), - Err(e) => self.err = Some(e.into()), - }, - Err(e) => self.err = Some(e.into()), - } - self - } - - /// Insert a header, replaces existing header. - pub fn set_header(mut self, key: K, value: V) -> Self - where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, - { - match HeaderName::try_from(key) { - Ok(key) => match value.try_into() { - Ok(value) => self.head.headers.insert(key, value), - Err(e) => self.err = Some(e.into()), - }, - Err(e) => self.err = Some(e.into()), - } self } /// Insert a header only if it is not yet set. - pub fn set_header_if_none(mut self, key: K, value: V) -> Self + pub fn insert_header_if_none(mut self, header: H) -> Self where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, + H: IntoHeaderPair, { - match HeaderName::try_from(key) { - Ok(key) => { + match header.try_into_header_pair() { + Ok((key, value)) => { if !self.head.headers.contains_key(&key) { - match value.try_into() { - Ok(value) => self.head.headers.insert(key, value), - Err(e) => self.err = Some(e.into()), - } + self.head.headers.insert(key, value); } } Err(e) => self.err = Some(e.into()), - } + }; + + self + } + + /// Append a header, keeping any that were set with an equivalent field name. + /// + /// ```rust + /// # #[actix_rt::main] + /// # async fn main() { + /// # use awc::Client; + /// use awc::http::header::ContentType; + /// + /// Client::new() + /// .get("http://www.rust-lang.org") + /// .insert_header(("X-TEST", "value")) + /// .insert_header(ContentType(mime::APPLICATION_JSON)); + /// # } + /// ``` + pub fn append_header(mut self, header: H) -> Self + where + H: IntoHeaderPair, + { + match header.try_into_header_pair() { + Ok((key, value)) => self.head.headers.append(key, value), + Err(e) => self.err = Some(e.into()), + }; + self } @@ -258,7 +222,7 @@ impl ClientRequest { } /// Force close connection instead of returning it back to connections pool. - /// This setting affect only http/1 connections. + /// This setting affect only HTTP/1 connections. #[inline] pub fn force_close(mut self) -> Self { self.head.set_connection_type(ConnectionType::Close); @@ -273,7 +237,9 @@ impl ClientRequest { >::Error: Into, { match HeaderValue::try_from(value) { - Ok(value) => self.head.headers.insert(header::CONTENT_TYPE, value), + Ok(value) => { + self.head.headers.insert(header::CONTENT_TYPE, value); + } Err(e) => self.err = Some(e.into()), } self @@ -282,7 +248,7 @@ impl ClientRequest { /// Set content length #[inline] pub fn content_length(self, len: u64) -> Self { - self.header(header::CONTENT_LENGTH, len) + self.append_header((header::CONTENT_LENGTH, len)) } /// Set HTTP basic authorization header @@ -294,10 +260,10 @@ impl ClientRequest { Some(password) => format!("{}:{}", username, password), None => format!("{}:", username), }; - self.header( + self.append_header(( header::AUTHORIZATION, format!("Basic {}", base64::encode(&auth)), - ) + )) } /// Set HTTP bearer authentication header @@ -305,7 +271,7 @@ impl ClientRequest { where T: fmt::Display, { - self.header(header::AUTHORIZATION, format!("Bearer {}", token)) + self.append_header((header::AUTHORIZATION, format!("Bearer {}", token))) } /// Set a cookie @@ -328,6 +294,7 @@ impl ClientRequest { /// println!("Response: {:?}", resp); /// } /// ``` + #[cfg(feature = "cookies")] pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { if self.cookies.is_none() { let mut jar = CookieJar::new(); @@ -510,7 +477,8 @@ impl ClientRequest { ) } - fn prep_for_sending(mut self) -> Result { + // allow unused mut when cookies feature is disabled + fn prep_for_sending(#[allow(unused_mut)] mut self) -> Result { if let Some(e) = self.err { return Err(e.into()); } @@ -523,7 +491,7 @@ impl ClientRequest { return Err(InvalidUrl::MissingScheme.into()); } else if let Some(scheme) = uri.scheme() { match scheme.as_str() { - "http" | "ws" | "https" | "wss" => (), + "http" | "ws" | "https" | "wss" => {} _ => return Err(InvalidUrl::UnknownScheme.into()), } } else { @@ -531,6 +499,7 @@ impl ClientRequest { } // set cookies + #[cfg(feature = "cookies")] if let Some(ref mut jar) = self.cookies { let cookie: String = jar .delta() @@ -557,12 +526,11 @@ impl ClientRequest { .unwrap_or(true); if https { - slf = slf.set_header_if_none(header::ACCEPT_ENCODING, HTTPS_ENCODING) + slf = slf.insert_header_if_none((header::ACCEPT_ENCODING, HTTPS_ENCODING)) } else { #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] { - slf = - slf.set_header_if_none(header::ACCEPT_ENCODING, "gzip, deflate") + slf = slf.insert_header_if_none((header::ACCEPT_ENCODING, "gzip, deflate")) } }; } @@ -595,7 +563,7 @@ mod tests { #[actix_rt::test] async fn test_debug() { - let request = Client::new().get("/").header("x-test", "111"); + let request = Client::new().get("/").append_header(("x-test", "111")); let repr = format!("{:?}", request); assert!(repr.contains("ClientRequest")); assert!(repr.contains("x-test")); @@ -606,18 +574,18 @@ mod tests { let req = Client::new() .put("/") .version(Version::HTTP_2) - .set(header::Date(SystemTime::now().into())) + .insert_header(header::Date(SystemTime::now().into())) .content_type("plain/text") - .header(header::SERVER, "awc"); + .append_header((header::SERVER, "awc")); let req = if let Some(val) = Some("server") { - req.header(header::USER_AGENT, val) + req.append_header((header::USER_AGENT, val)) } else { req }; let req = if let Some(_val) = Option::<&str>::None { - req.header(header::ALLOW, "1") + req.append_header((header::ALLOW, "1")) } else { req }; @@ -660,7 +628,7 @@ mod tests { .header(header::CONTENT_TYPE, "111") .finish() .get("/") - .set_header(header::CONTENT_TYPE, "222"); + .insert_header((header::CONTENT_TYPE, "222")); assert_eq!( req.head diff --git a/awc/src/response.rs b/awc/src/response.rs index 8364aa556..514b8a90b 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -1,26 +1,65 @@ -use std::cell::{Ref, RefMut}; -use std::fmt; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{ + cell::{Ref, RefMut}, + fmt, + future::Future, + io, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, +}; +use actix_http::{ + error::PayloadError, + http::{header, HeaderMap, StatusCode, Version}, + Extensions, HttpMessage, Payload, PayloadStream, ResponseHead, +}; +use actix_rt::time::{sleep, Sleep}; use bytes::{Bytes, BytesMut}; use futures_core::{ready, Stream}; - -use actix_http::cookie::Cookie; -use actix_http::error::{CookieParseError, PayloadError}; -use actix_http::http::header::{CONTENT_LENGTH, SET_COOKIE}; -use actix_http::http::{HeaderMap, StatusCode, Version}; -use actix_http::{Extensions, HttpMessage, Payload, PayloadStream, ResponseHead}; use serde::de::DeserializeOwned; +#[cfg(feature = "cookies")] +use actix_http::{cookie::Cookie, error::CookieParseError}; + use crate::error::JsonPayloadError; /// Client Response pub struct ClientResponse { pub(crate) head: ResponseHead, pub(crate) payload: Payload, + pub(crate) timeout: ResponseTimeout, +} + +/// helper enum with reusable sleep passed from `SendClientResponse`. +/// See `ClientResponse::_timeout` for reason. +pub(crate) enum ResponseTimeout { + Disabled(Option>>), + Enabled(Pin>), +} + +impl Default for ResponseTimeout { + fn default() -> Self { + Self::Disabled(None) + } +} + +impl ResponseTimeout { + fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> { + match *self { + Self::Enabled(ref mut timeout) => { + if timeout.as_mut().poll(cx).is_ready() { + Err(PayloadError::Io(io::Error::new( + io::ErrorKind::TimedOut, + "Response Payload IO timed out", + ))) + } else { + Ok(()) + } + } + Self::Disabled(_) => Ok(()), + } + } } impl HttpMessage for ClientResponse { @@ -30,6 +69,10 @@ impl HttpMessage for ClientResponse { &self.head.headers } + fn take_payload(&mut self) -> Payload { + std::mem::replace(&mut self.payload, Payload::None) + } + fn extensions(&self) -> Ref<'_, Extensions> { self.head.extensions() } @@ -38,20 +81,15 @@ impl HttpMessage for ClientResponse { self.head.extensions_mut() } - fn take_payload(&mut self) -> Payload { - std::mem::replace(&mut self.payload, Payload::None) - } - /// Load request cookies. - #[inline] + #[cfg(feature = "cookies")] fn cookies(&self) -> Result>>, CookieParseError> { struct Cookies(Vec>); if self.extensions().get::().is_none() { let mut cookies = Vec::new(); - for hdr in self.headers().get_all(&SET_COOKIE) { - let s = std::str::from_utf8(hdr.as_bytes()) - .map_err(CookieParseError::from)?; + for hdr in self.headers().get_all(&header::SET_COOKIE) { + let s = std::str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?; cookies.push(Cookie::parse_encoded(s)?.into_owned()); } self.extensions_mut().insert(Cookies(cookies)); @@ -65,7 +103,11 @@ impl HttpMessage for ClientResponse { impl ClientResponse { /// Create new Request instance pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self { - ClientResponse { head, payload } + ClientResponse { + head, + payload, + timeout: ResponseTimeout::default(), + } } #[inline] @@ -101,15 +143,50 @@ impl ClientResponse { ClientResponse { payload, head: self.head, + timeout: self.timeout, } } + + /// Set a timeout duration for [`ClientResponse`](self::ClientResponse). + /// + /// This duration covers the duration of processing the response body stream + /// and would end it as timeout error when deadline met. + /// + /// Disabled by default. + pub fn timeout(self, dur: Duration) -> Self { + let timeout = match self.timeout { + ResponseTimeout::Disabled(Some(mut timeout)) + | ResponseTimeout::Enabled(mut timeout) => match Instant::now().checked_add(dur) { + Some(deadline) => { + timeout.as_mut().reset(deadline.into()); + ResponseTimeout::Enabled(timeout) + } + None => ResponseTimeout::Enabled(Box::pin(sleep(dur))), + }, + _ => ResponseTimeout::Enabled(Box::pin(sleep(dur))), + }; + + Self { + payload: self.payload, + head: self.head, + timeout, + } + } + + /// This method does not enable timeout. It's used to pass the boxed `Sleep` from + /// `SendClientRequest` and reuse it's heap allocation together with it's slot in + /// timer wheel. + pub(crate) fn _timeout(mut self, timeout: Option>>) -> Self { + self.timeout = ResponseTimeout::Disabled(timeout); + self + } } impl ClientResponse where S: Stream>, { - /// Loads http response's body. + /// Loads HTTP response's body. pub fn body(&mut self) -> MessageBody { MessageBody::new(self) } @@ -132,11 +209,11 @@ where { type Item = Result; - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().payload).poll_next(cx) + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + this.timeout.poll_timeout(cx)?; + + Pin::new(&mut this.payload).poll_next(cx) } } @@ -151,10 +228,11 @@ impl fmt::Debug for ClientResponse { } } -/// Future that resolves to a complete http message body. +/// Future that resolves to a complete HTTP message body. pub struct MessageBody { length: Option, err: Option, + timeout: ResponseTimeout, fut: Option>, } @@ -165,7 +243,7 @@ where /// Create `MessageBody` for request. pub fn new(res: &mut ClientResponse) -> MessageBody { let mut len = None; - if let Some(l) = res.headers().get(&CONTENT_LENGTH) { + if let Some(l) = res.headers().get(&header::CONTENT_LENGTH) { if let Ok(s) = l.to_str() { if let Ok(l) = s.parse::() { len = Some(l) @@ -180,11 +258,12 @@ where MessageBody { length: len, err: None, + timeout: std::mem::take(&mut res.timeout), fut: Some(ReadBody::new(res.take_payload(), 262_144)), } } - /// Change max size of payload. By default max size is 256Kb + /// Change max size of payload. By default max size is 256kB pub fn limit(mut self, limit: usize) -> Self { if let Some(ref mut fut) = self.fut { fut.limit = limit; @@ -197,6 +276,7 @@ where fut: None, err: Some(e), length: None, + timeout: ResponseTimeout::default(), } } } @@ -220,6 +300,8 @@ where } } + this.timeout.poll_timeout(cx)?; + Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) } } @@ -233,8 +315,9 @@ where pub struct JsonBody { length: Option, err: Option, + timeout: ResponseTimeout, fut: Option>, - _t: PhantomData, + _phantom: PhantomData, } impl JsonBody @@ -243,9 +326,9 @@ where U: DeserializeOwned, { /// Create `JsonBody` for request. - pub fn new(req: &mut ClientResponse) -> Self { + pub fn new(res: &mut ClientResponse) -> Self { // check content-type - let json = if let Ok(Some(mime)) = req.mime_type() { + let json = if let Ok(Some(mime)) = res.mime_type() { mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) } else { false @@ -254,13 +337,15 @@ where return JsonBody { length: None, fut: None, + timeout: ResponseTimeout::default(), err: Some(JsonPayloadError::ContentType), - _t: PhantomData, + _phantom: PhantomData, }; } let mut len = None; - if let Some(l) = req.headers().get(&CONTENT_LENGTH) { + + if let Some(l) = res.headers().get(&header::CONTENT_LENGTH) { if let Ok(s) = l.to_str() { if let Ok(l) = s.parse::() { len = Some(l) @@ -271,12 +356,13 @@ where JsonBody { length: len, err: None, - fut: Some(ReadBody::new(req.take_payload(), 65536)), - _t: PhantomData, + timeout: std::mem::take(&mut res.timeout), + fut: Some(ReadBody::new(res.take_payload(), 65536)), + _phantom: PhantomData, } } - /// Change max size of payload. By default max size is 64Kb + /// Change max size of payload. By default max size is 64kB pub fn limit(mut self, limit: usize) -> Self { if let Some(ref mut fut) = self.fut { fut.limit = limit; @@ -306,12 +392,14 @@ where if let Some(len) = self.length.take() { if len > self.fut.as_ref().unwrap().limit { - return Poll::Ready(Err(JsonPayloadError::Payload( - PayloadError::Overflow, - ))); + return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Overflow))); } } + self.timeout + .poll_timeout(cx) + .map_err(JsonPayloadError::Payload)?; + let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) } @@ -370,14 +458,13 @@ mod tests { async fn test_body() { let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish(); match req.body().await.err().unwrap() { - PayloadError::UnknownLength => (), + PayloadError::UnknownLength => {} _ => unreachable!("error"), } - let mut req = - TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish(); + let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish(); match req.body().await.err().unwrap() { - PayloadError::Overflow => (), + PayloadError::Overflow => {} _ => unreachable!("error"), } @@ -390,7 +477,7 @@ mod tests { .set_payload(Bytes::from_static(b"11111111111111")) .finish(); match req.body().limit(5).await.err().unwrap() { - PayloadError::Overflow => (), + PayloadError::Overflow => {} _ => unreachable!("error"), } } diff --git a/awc/src/sender.rs b/awc/src/sender.rs index 0bcdf4307..1170c69a0 100644 --- a/awc/src/sender.rs +++ b/awc/src/sender.rs @@ -1,28 +1,30 @@ -use std::future::Future; -use std::net; -use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll}; -use std::time::Duration; +use std::{ + future::Future, + net, + pin::Pin, + rc::Rc, + task::{Context, Poll}, + time::Duration, +}; -use actix_rt::time::{delay_for, Delay}; +use actix_http::{ + body::{Body, BodyStream}, + http::{ + header::{self, HeaderMap, HeaderName, IntoHeaderValue}, + Error as HttpError, + }, + Error, RequestHead, RequestHeadType, +}; +use actix_rt::time::{sleep, Sleep}; use bytes::Bytes; use derive_more::From; use futures_core::Stream; use serde::Serialize; -use actix_http::body::{Body, BodyStream}; -use actix_http::http::header::{self, IntoHeaderValue}; -use actix_http::http::{Error as HttpError, HeaderMap, HeaderName}; -use actix_http::{Error, RequestHead}; - -#[cfg(feature = "compress")] -use actix_http::encoding::Decoder; -#[cfg(feature = "compress")] -use actix_http::http::header::ContentEncoding; -#[cfg(feature = "compress")] -use actix_http::{Payload, PayloadStream}; +#[cfg(feature = "compress")] +use actix_http::{encoding::Decoder, http::header::ContentEncoding, Payload, PayloadStream}; +use crate::connect::{ConnectRequest, ConnectResponse}; use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError}; use crate::response::ClientResponse; use crate::ClientConfig; @@ -33,18 +35,18 @@ pub(crate) enum PrepForSendingError { Http(HttpError), } -impl Into for PrepForSendingError { - fn into(self) -> FreezeRequestError { - match self { +impl From for FreezeRequestError { + fn from(err: PrepForSendingError) -> FreezeRequestError { + match err { PrepForSendingError::Url(e) => FreezeRequestError::Url(e), PrepForSendingError::Http(e) => FreezeRequestError::Http(e), } } } -impl Into for PrepForSendingError { - fn into(self) -> SendRequestError { - match self { +impl From for SendRequestError { + fn from(err: PrepForSendingError) -> SendRequestError { + match err { PrepForSendingError::Url(e) => SendRequestError::Url(e), PrepForSendingError::Http(e) => SendRequestError::Http(e), } @@ -55,8 +57,9 @@ impl Into for PrepForSendingError { #[must_use = "futures do nothing unless polled"] pub enum SendClientRequest { Fut( - Pin>>>, - Option, + Pin>>>, + // FIXME: use a pinned Sleep instead of box. + Option>>, bool, ), Err(Option), @@ -64,46 +67,43 @@ pub enum SendClientRequest { impl SendClientRequest { pub(crate) fn new( - send: Pin>>>, + send: Pin>>>, response_decompress: bool, timeout: Option, ) -> SendClientRequest { - let delay = timeout.map(delay_for); + let delay = timeout.map(|d| Box::pin(sleep(d))); SendClientRequest::Fut(send, delay, response_decompress) } } #[cfg(feature = "compress")] impl Future for SendClientRequest { - type Output = - Result>>, SendRequestError>; + type Output = Result>>, SendRequestError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); match this { SendClientRequest::Fut(send, delay, response_decompress) => { - if delay.is_some() { - match Pin::new(delay.as_mut().unwrap()).poll(cx) { - Poll::Pending => (), - _ => return Poll::Ready(Err(SendRequestError::Timeout)), + if let Some(delay) = delay { + if delay.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(SendRequestError::Timeout)); } } - let res = futures_core::ready!(Pin::new(send).poll(cx)).map(|res| { - res.map_body(|head, payload| { - if *response_decompress { - Payload::Stream(Decoder::from_headers( - payload, - &head.headers, - )) - } else { - Payload::Stream(Decoder::new( - payload, - ContentEncoding::Identity, - )) - } - }) + let res = futures_core::ready!(send.as_mut().poll(cx)).map(|res| { + res.into_client_response()._timeout(delay.take()).map_body( + |head, payload| { + if *response_decompress { + Payload::Stream(Decoder::from_headers(payload, &head.headers)) + } else { + Payload::Stream(Decoder::new( + payload, + ContentEncoding::Identity, + )) + } + }, + ) }); Poll::Ready(res) @@ -124,13 +124,14 @@ impl Future for SendClientRequest { let this = self.get_mut(); match this { SendClientRequest::Fut(send, delay, _) => { - if delay.is_some() { - match Pin::new(delay.as_mut().unwrap()).poll(cx) { - Poll::Pending => (), - _ => return Poll::Ready(Err(SendRequestError::Timeout)), + if let Some(delay) = delay { + if delay.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(SendRequestError::Timeout)); } } - Pin::new(send).poll(cx) + send.as_mut() + .poll(cx) + .map_ok(|res| res.into_client_response()._timeout(delay.take())) } SendClientRequest::Err(ref mut e) => match e.take() { Some(e) => Poll::Ready(Err(e)), @@ -182,17 +183,19 @@ impl RequestSender { where B: Into, { - let mut connector = config.connector.borrow_mut(); - - let fut = match self { + let req = match self { RequestSender::Owned(head) => { - connector.send_request(head, body.into(), addr) - } - RequestSender::Rc(head, extra_headers) => { - connector.send_request_extra(head, extra_headers, body.into(), addr) + ConnectRequest::Client(RequestHeadType::Owned(head), body.into(), addr) } + RequestSender::Rc(head, extra_headers) => ConnectRequest::Client( + RequestHeadType::Rc(head, extra_headers), + body.into(), + addr, + ), }; + let fut = config.connector.call(req); + SendClientRequest::new(fut, response_decompress, timeout.or(config.timeout)) } @@ -209,8 +212,7 @@ impl RequestSender { Err(e) => return Error::from(e).into(), }; - if let Err(e) = self.set_header_if_none(header::CONTENT_TYPE, "application/json") - { + if let Err(e) = self.set_header_if_none(header::CONTENT_TYPE, "application/json") { return e.into(); } @@ -237,10 +239,9 @@ impl RequestSender { }; // set content-type - if let Err(e) = self.set_header_if_none( - header::CONTENT_TYPE, - "application/x-www-form-urlencoded", - ) { + if let Err(e) = + self.set_header_if_none(header::CONTENT_TYPE, "application/x-www-form-urlencoded") + { return e.into(); } @@ -284,19 +285,17 @@ impl RequestSender { self.send_body(addr, response_decompress, timeout, config, Body::Empty) } - fn set_header_if_none( - &mut self, - key: HeaderName, - value: V, - ) -> Result<(), HttpError> + fn set_header_if_none(&mut self, key: HeaderName, value: V) -> Result<(), HttpError> where V: IntoHeaderValue, { match self { RequestSender::Owned(head) => { if !head.headers.contains_key(&key) { - match value.try_into() { - Ok(value) => head.headers.insert(key, value), + match value.try_into_value() { + Ok(value) => { + head.headers.insert(key, value); + } Err(e) => return Err(e.into()), } } @@ -305,7 +304,7 @@ impl RequestSender { if !head.headers.contains_key(&key) && !extra_headers.iter().any(|h| h.contains_key(&key)) { - match value.try_into() { + match value.try_into_value() { Ok(v) => { let h = extra_headers.get_or_insert(HeaderMap::new()); h.insert(key, v) diff --git a/awc/src/test.rs b/awc/src/test.rs index 68e5c9dc5..97bbb9c3d 100644 --- a/awc/src/test.rs +++ b/awc/src/test.rs @@ -1,9 +1,13 @@ //! Test helpers for actix http client to use during testing. use std::convert::TryFrom; -use actix_http::cookie::{Cookie, CookieJar}; -use actix_http::http::header::{self, Header, HeaderValue, IntoHeaderValue}; +use actix_http::http::header::{Header, IntoHeaderValue}; use actix_http::http::{Error as HttpError, HeaderName, StatusCode, Version}; +#[cfg(feature = "cookies")] +use actix_http::{ + cookie::{Cookie, CookieJar}, + http::header::{self, HeaderValue}, +}; use actix_http::{h1, Payload, ResponseHead}; use bytes::Bytes; @@ -12,6 +16,7 @@ use crate::ClientResponse; /// Test `ClientResponse` builder pub struct TestResponse { head: ResponseHead, + #[cfg(feature = "cookies")] cookies: CookieJar, payload: Option, } @@ -20,6 +25,7 @@ impl Default for TestResponse { fn default() -> TestResponse { TestResponse { head: ResponseHead::new(StatusCode::OK), + #[cfg(feature = "cookies")] cookies: CookieJar::new(), payload: None, } @@ -45,7 +51,7 @@ impl TestResponse { /// Set a header pub fn set(mut self, hdr: H) -> Self { - if let Ok(value) = hdr.try_into() { + if let Ok(value) = hdr.try_into_value() { self.head.headers.append(H::name(), value); return self; } @@ -60,7 +66,7 @@ impl TestResponse { V: IntoHeaderValue, { if let Ok(key) = HeaderName::try_from(key) { - if let Ok(value) = value.try_into() { + if let Ok(value) = value.try_into_value() { self.head.headers.append(key, value); return self; } @@ -69,6 +75,7 @@ impl TestResponse { } /// Set cookie for this response + #[cfg(feature = "cookies")] pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { self.cookies.add(cookie.into_owned()); self @@ -84,8 +91,11 @@ impl TestResponse { /// Complete response creation and generate `ClientResponse` instance pub fn finish(self) -> ClientResponse { + // allow unused mut when cookies feature is disabled + #[allow(unused_mut)] let mut head = self.head; + #[cfg(feature = "cookies")] for cookie in self.cookies.delta() { head.headers.insert( header::SET_COOKIE, diff --git a/awc/src/ws.rs b/awc/src/ws.rs index dd43d08b3..5f4570963 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -4,7 +4,7 @@ //! //! # Example //! -//! ``` +//! ```no_run //! use awc::{Client, ws}; //! use futures_util::{sink::SinkExt, stream::StreamExt}; //! @@ -17,7 +17,7 @@ //! .unwrap(); //! //! connection -//! .send(ws::Message::Text("Echo".to_string())) +//! .send(ws::Message::Text("Echo".into())) //! .await //! .unwrap(); //! let response = connection.next().await.unwrap().unwrap(); @@ -32,24 +32,22 @@ use std::rc::Rc; use std::{fmt, str}; use actix_codec::Framed; +#[cfg(feature = "cookies")] use actix_http::cookie::{Cookie, CookieJar}; use actix_http::{ws, Payload, RequestHead}; use actix_rt::time::timeout; +use actix_service::Service; pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message}; -use crate::connect::BoxedSocket; +use crate::connect::{BoxedSocket, ConnectRequest}; use crate::error::{InvalidUrl, SendRequestError, WsClientError}; -use crate::http::header::{ - self, HeaderName, HeaderValue, IntoHeaderValue, AUTHORIZATION, -}; -use crate::http::{ - ConnectionType, Error as HttpError, Method, StatusCode, Uri, Version, -}; +use crate::http::header::{self, HeaderName, HeaderValue, IntoHeaderValue, AUTHORIZATION}; +use crate::http::{ConnectionType, Error as HttpError, Method, StatusCode, Uri, Version}; use crate::response::ClientResponse; use crate::ClientConfig; -/// `WebSocket` connection +/// WebSocket connection. pub struct WebsocketsRequest { pub(crate) head: RequestHead, err: Option, @@ -58,12 +56,14 @@ pub struct WebsocketsRequest { addr: Option, max_size: usize, server_mode: bool, - cookies: Option, config: Rc, + + #[cfg(feature = "cookies")] + cookies: Option, } impl WebsocketsRequest { - /// Create new websocket connection + /// Create new WebSocket connection pub(crate) fn new(uri: U, config: Rc) -> Self where Uri: TryFrom, @@ -93,6 +93,7 @@ impl WebsocketsRequest { protocols: None, max_size: 65_536, server_mode: false, + #[cfg(feature = "cookies")] cookies: None, } } @@ -106,7 +107,7 @@ impl WebsocketsRequest { self } - /// Set supported websocket protocols + /// Set supported WebSocket protocols pub fn protocols(mut self, protos: U) -> Self where U: IntoIterator, @@ -121,6 +122,7 @@ impl WebsocketsRequest { } /// Set a cookie + #[cfg(feature = "cookies")] pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { if self.cookies.is_none() { let mut jar = CookieJar::new(); @@ -147,7 +149,7 @@ impl WebsocketsRequest { /// Set max frame size /// - /// By default max size is set to 64kb + /// By default max size is set to 64kB pub fn max_frame_size(mut self, size: usize) -> Self { self.max_size = size; self @@ -170,7 +172,7 @@ impl WebsocketsRequest { V: IntoHeaderValue, { match HeaderName::try_from(key) { - Ok(key) => match value.try_into() { + Ok(key) => match value.try_into_value() { Ok(value) => { self.head.headers.append(key, value); } @@ -189,7 +191,7 @@ impl WebsocketsRequest { V: IntoHeaderValue, { match HeaderName::try_from(key) { - Ok(key) => match value.try_into() { + Ok(key) => match value.try_into_value() { Ok(value) => { self.head.headers.insert(key, value); } @@ -210,7 +212,7 @@ impl WebsocketsRequest { match HeaderName::try_from(key) { Ok(key) => { if !self.head.headers.contains_key(&key) { - match value.try_into() { + match value.try_into_value() { Ok(value) => { self.head.headers.insert(key, value); } @@ -243,7 +245,7 @@ impl WebsocketsRequest { self.header(AUTHORIZATION, format!("Bearer {}", token)) } - /// Complete request construction and connect to a websockets server. + /// Complete request construction and connect to a WebSocket server. pub async fn connect( mut self, ) -> Result<(ClientResponse, Framed), WsClientError> { @@ -259,7 +261,7 @@ impl WebsocketsRequest { return Err(InvalidUrl::MissingScheme.into()); } else if let Some(scheme) = uri.scheme() { match scheme.as_str() { - "http" | "ws" | "https" | "wss" => (), + "http" | "ws" | "https" | "wss" => {} _ => return Err(InvalidUrl::UnknownScheme.into()), } } else { @@ -274,6 +276,7 @@ impl WebsocketsRequest { } // set cookies + #[cfg(feature = "cookies")] if let Some(ref mut jar) = self.cookies { let cookie: String = jar .delta() @@ -325,28 +328,27 @@ impl WebsocketsRequest { let max_size = self.max_size; let server_mode = self.server_mode; - let fut = self - .config - .connector - .borrow_mut() - .open_tunnel(head, self.addr); + let req = ConnectRequest::Tunnel(head, self.addr); + + let fut = self.config.connector.call(req); // set request timeout - let (head, framed) = if let Some(to) = self.config.timeout { + let res = if let Some(to) = self.config.timeout { timeout(to, fut) .await - .map_err(|_| SendRequestError::Timeout) - .and_then(|res| res)? + .map_err(|_| SendRequestError::Timeout)?? } else { fut.await? }; + let (head, framed) = res.into_tunnel_response(); + // verify response if head.status != StatusCode::SWITCHING_PROTOCOLS { return Err(WsClientError::InvalidResponseStatus(head.status)); } - // Check for "UPGRADE" to websocket header + // check for "UPGRADE" to WebSocket header let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) { if let Ok(s) = hdr.to_str() { s.to_ascii_lowercase().contains("websocket") diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs index 0024c6652..a41a8dac3 100644 --- a/awc/tests/test_client.rs +++ b/awc/tests/test_client.rs @@ -9,19 +9,22 @@ use bytes::Bytes; use flate2::read::GzDecoder; use flate2::write::GzEncoder; use flate2::Compression; -use futures_util::future::ok; +use futures_util::{future::ok, stream}; use rand::Rng; -use actix_http::HttpService; +use actix_http::{ + http::{self, StatusCode}, + HttpService, +}; use actix_http_test::test_server; use actix_service::{map_config, pipeline_factory}; -use actix_web::dev::{AppConfig, BodyEncoding}; -use actix_web::http::Cookie; -use actix_web::middleware::Compress; use actix_web::{ - http::header, test, web, App, Error, HttpMessage, HttpRequest, HttpResponse, + dev::{AppConfig, BodyEncoding}, + http::{header, Cookie}, + middleware::Compress, + test, web, App, Error, HttpMessage, HttpRequest, HttpResponse, }; -use awc::error::SendRequestError; +use awc::error::{JsonPayloadError, PayloadError, SendRequestError}; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ @@ -48,11 +51,10 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ #[actix_rt::test] async fn test_simple() { let srv = test::start(|| { - App::new() - .service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) + App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) }); - let request = srv.get("/").header("x-test", "111").send(); + let request = srv.get("/").insert_header(("x-test", "111")).send(); let mut response = request.await.unwrap(); assert!(response.status().is_success()); @@ -82,7 +84,7 @@ async fn test_json() { let request = srv .get("/") - .header("x-test", "111") + .insert_header(("x-test", "111")) .send_json(&"TEST".to_string()); let response = request.await.unwrap(); assert!(response.status().is_success()); @@ -99,7 +101,10 @@ async fn test_form() { let mut data = HashMap::new(); let _ = data.insert("key".to_string(), "TEST".to_string()); - let request = srv.get("/").header("x-test", "111").send_form(&data); + let request = srv + .get("/") + .append_header(("x-test", "111")) + .send_form(&data); let response = request.await.unwrap(); assert!(response.status().is_success()); } @@ -108,17 +113,14 @@ async fn test_form() { async fn test_timeout() { let srv = test::start(|| { App::new().service(web::resource("/").route(web::to(|| async { - actix_rt::time::delay_for(Duration::from_millis(200)).await; + actix_rt::time::sleep(Duration::from_millis(200)).await; Ok::<_, Error>(HttpResponse::Ok().body(STR)) }))) }); let connector = awc::Connector::new() - .connector(actix_connect::new_connector( - actix_connect::start_default_resolver().await.unwrap(), - )) - .timeout(Duration::from_secs(15)) - .finish(); + .connector(actix_tls::connect::default_connector()) + .timeout(Duration::from_secs(15)); let client = awc::Client::builder() .connector(connector) @@ -127,7 +129,7 @@ async fn test_timeout() { let request = client.get(srv.url("/")).send(); match request.await { - Err(SendRequestError::Timeout) => (), + Err(SendRequestError::Timeout) => {} _ => panic!(), } } @@ -136,7 +138,7 @@ async fn test_timeout() { async fn test_timeout_override() { let srv = test::start(|| { App::new().service(web::resource("/").route(web::to(|| async { - actix_rt::time::delay_for(Duration::from_millis(200)).await; + actix_rt::time::sleep(Duration::from_millis(200)).await; Ok::<_, Error>(HttpResponse::Ok().body(STR)) }))) }); @@ -149,11 +151,84 @@ async fn test_timeout_override() { .timeout(Duration::from_millis(50)) .send(); match request.await { - Err(SendRequestError::Timeout) => (), + Err(SendRequestError::Timeout) => {} _ => panic!(), } } +#[actix_rt::test] +async fn test_response_timeout() { + use futures_util::stream::{once, StreamExt}; + + let srv = test::start(|| { + App::new().service(web::resource("/").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .content_type("application/json") + .streaming(Box::pin(once(async { + actix_rt::time::sleep(Duration::from_millis(200)).await; + Ok::<_, Error>(Bytes::from(STR)) + }))), + ) + }))) + }); + + let client = awc::Client::new(); + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(500)) + .body() + .await + .unwrap(); + assert_eq!(std::str::from_utf8(res.as_ref()).unwrap(), STR); + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .next() + .await + .unwrap(); + match res { + Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut), + _ => panic!("Response error type is not matched"), + } + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .body() + .await; + match res { + Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut), + _ => panic!("Response error type is not matched"), + } + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .json::>() + .await; + match res { + Err(JsonPayloadError::Payload(PayloadError::Io(e))) => { + assert_eq!(e.kind(), std::io::ErrorKind::TimedOut) + } + _ => panic!("Response error type is not matched"), + } +} + #[actix_rt::test] async fn test_connection_reuse() { let num = Arc::new(AtomicUsize::new(0)); @@ -292,7 +367,7 @@ async fn test_connection_wait_queue() { .await; let client = awc::Client::builder() - .connector(awc::Connector::new().limit(1).finish()) + .connector(awc::Connector::new().limit(1)) .finish(); // req 1 @@ -341,7 +416,7 @@ async fn test_connection_wait_queue_force_close() { .await; let client = awc::Client::builder() - .connector(awc::Connector::new().limit(1).finish()) + .connector(awc::Connector::new().limit(1)) .finish(); // req 1 @@ -438,7 +513,7 @@ async fn test_client_gzip_encoding() { let data = e.finish().unwrap(); HttpResponse::Ok() - .header("content-encoding", "gzip") + .insert_header(("content-encoding", "gzip")) .body(data) }))) }); @@ -461,7 +536,7 @@ async fn test_client_gzip_encoding_large() { let data = e.finish().unwrap(); HttpResponse::Ok() - .header("content-encoding", "gzip") + .insert_header(("content-encoding", "gzip")) .body(data) }))) }); @@ -489,7 +564,7 @@ async fn test_client_gzip_encoding_large_random() { e.write_all(&data).unwrap(); let data = e.finish().unwrap(); HttpResponse::Ok() - .header("content-encoding", "gzip") + .insert_header(("content-encoding", "gzip")) .body(data) }))) }); @@ -511,7 +586,7 @@ async fn test_client_brotli_encoding() { e.write_all(&data).unwrap(); let data = e.finish().unwrap(); HttpResponse::Ok() - .header("content-encoding", "br") + .insert_header(("content-encoding", "br")) .body(data) }))) }); @@ -539,7 +614,7 @@ async fn test_client_brotli_encoding_large_random() { e.write_all(&data).unwrap(); let data = e.finish().unwrap(); HttpResponse::Ok() - .header("content-encoding", "br") + .insert_header(("content-encoding", "br")) .body(data) }))) }); @@ -554,113 +629,93 @@ async fn test_client_brotli_encoding_large_random() { assert_eq!(bytes, Bytes::from(data)); } -// #[actix_rt::test] -// async fn test_client_deflate_encoding() { -// let srv = test::TestServer::start(|app| { -// app.handler(|req: &HttpRequest| { -// req.body() -// .and_then(|bytes: Bytes| { -// Ok(HttpResponse::Ok() -// .content_encoding(http::ContentEncoding::Br) -// .body(bytes)) -// }) -// .responder() -// }) -// }); +#[actix_rt::test] +async fn test_client_deflate_encoding() { + let srv = test::start(|| { + App::new().default_service(web::to(|body: Bytes| { + HttpResponse::Ok() + .encoding(http::ContentEncoding::Br) + .body(body) + })) + }); -// // client request -// let request = srv -// .post() -// .content_encoding(http::ContentEncoding::Deflate) -// .body(STR) -// .unwrap(); -// let response = srv.execute(request.send()).unwrap(); -// assert!(response.status().is_success()); + let req = srv.post("/").send_body(STR); -// // read response -// let bytes = srv.execute(response.body()).unwrap(); -// assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -// } + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); -// #[actix_rt::test] -// async fn test_client_deflate_encoding_large_random() { -// let data = rand::thread_rng() -// .sample_iter(&rand::distributions::Alphanumeric) -// .take(70_000) -// .collect::(); + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} -// let srv = test::TestServer::start(|app| { -// app.handler(|req: &HttpRequest| { -// req.body() -// .and_then(|bytes: Bytes| { -// Ok(HttpResponse::Ok() -// .content_encoding(http::ContentEncoding::Br) -// .body(bytes)) -// }) -// .responder() -// }) -// }); +#[actix_rt::test] +async fn test_client_deflate_encoding_large_random() { + let data = rand::thread_rng() + .sample_iter(rand::distributions::Alphanumeric) + .map(char::from) + .take(70_000) + .collect::(); -// // client request -// let request = srv -// .post() -// .content_encoding(http::ContentEncoding::Deflate) -// .body(data.clone()) -// .unwrap(); -// let response = srv.execute(request.send()).unwrap(); -// assert!(response.status().is_success()); + let srv = test::start(|| { + App::new().default_service(web::to(|body: Bytes| { + HttpResponse::Ok() + .encoding(http::ContentEncoding::Br) + .body(body) + })) + }); -// // read response -// let bytes = srv.execute(response.body()).unwrap(); -// assert_eq!(bytes, Bytes::from(data)); -// } + let req = srv.post("/").send_body(data.clone()); -// #[actix_rt::test] -// async fn test_client_streaming_explicit() { -// let srv = test::TestServer::start(|app| { -// app.handler(|req: &HttpRequest| { -// req.body() -// .map_err(Error::from) -// .and_then(|body| { -// Ok(HttpResponse::Ok() -// .chunked() -// .content_encoding(http::ContentEncoding::Identity) -// .body(body)) -// }) -// .responder() -// }) -// }); + let mut res = req.await.unwrap(); + let bytes = res.body().await.unwrap(); -// let body = once(Ok(Bytes::from_static(STR.as_ref()))); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(bytes, Bytes::from(data)); +} -// let request = srv.get("/").body(Body::Streaming(Box::new(body))).unwrap(); -// let response = srv.execute(request.send()).unwrap(); -// assert!(response.status().is_success()); +#[actix_rt::test] +async fn test_client_streaming_explicit() { + let srv = test::start(|| { + App::new().default_service(web::to(|body: web::Payload| { + HttpResponse::Ok() + .encoding(http::ContentEncoding::Identity) + .streaming(body) + })) + }); -// // read response -// let bytes = srv.execute(response.body()).unwrap(); -// assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -// } + let body = + stream::once(async { Ok::<_, actix_http::Error>(Bytes::from_static(STR.as_bytes())) }); + let req = srv.post("/").send_stream(Box::pin(body)); -// #[actix_rt::test] -// async fn test_body_streaming_implicit() { -// let srv = test::TestServer::start(|app| { -// app.handler(|_| { -// let body = once(Ok(Bytes::from_static(STR.as_ref()))); -// HttpResponse::Ok() -// .content_encoding(http::ContentEncoding::Gzip) -// .body(Body::Streaming(Box::new(body))) -// }) -// }); + let mut res = req.await.unwrap(); + assert!(res.status().is_success()); -// let request = srv.get("/").finish().unwrap(); -// let response = srv.execute(request.send()).unwrap(); -// assert!(response.status().is_success()); + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} -// // read response -// let bytes = srv.execute(response.body()).unwrap(); -// assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -// } +#[actix_rt::test] +async fn test_body_streaming_implicit() { + let srv = test::start(|| { + App::new().default_service(web::to(|| { + let body = stream::once(async { + Ok::<_, actix_http::Error>(Bytes::from_static(STR.as_bytes())) + }); + + HttpResponse::Ok() + .encoding(http::ContentEncoding::Gzip) + .streaming(Box::pin(body)) + })) + }); + + let req = srv.get("/").send(); + + let mut res = req.await.unwrap(); + assert!(res.status().is_success()); + + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} #[actix_rt::test] async fn test_client_cookie_handling() { @@ -731,35 +786,35 @@ async fn test_client_cookie_handling() { assert_eq!(c2, cookie2); } -// #[actix_rt::test] -// fn client_read_until_eof() { -// let addr = test::TestServer::unused_addr(); +#[actix_rt::test] +async fn client_unread_response() { + let addr = test::unused_addr(); -// thread::spawn(move || { -// let lst = net::TcpListener::bind(addr).unwrap(); + let lst = std::net::TcpListener::bind(addr).unwrap(); -// for stream in lst.incoming() { -// let mut stream = stream.unwrap(); -// let mut b = [0; 1000]; -// let _ = stream.read(&mut b).unwrap(); -// let _ = stream -// .write_all(b"HTTP/1.1 200 OK\r\nconnection: close\r\n\r\nwelcome!"); -// } -// }); + std::thread::spawn(move || { + for stream in lst.incoming() { + let mut stream = stream.unwrap(); + let mut b = [0; 1000]; + let _ = stream.read(&mut b).unwrap(); + let _ = stream.write_all( + b"HTTP/1.1 200 OK\r\n\ + connection: close\r\n\ + \r\n\ + welcome!", + ); + } + }); -// let mut sys = actix::System::new("test"); + // client request + let req = awc::Client::new().get(format!("http://{}/", addr).as_str()); + let mut res = req.send().await.unwrap(); + assert!(res.status().is_success()); -// // client request -// let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()) -// .finish() -// .unwrap(); -// let response = req.send().await.unwrap(); -// assert!(response.status().is_success()); - -// // read response -// let bytes = response.body().await.unwrap(); -// assert_eq!(bytes, Bytes::from_static(b"welcome!")); -// } + // awc does not read all bytes unless content-length is specified + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"")); +} #[actix_rt::test] async fn client_basic_auth() { diff --git a/awc/tests/test_connector.rs b/awc/tests/test_connector.rs index 888f7a900..632f68b72 100644 --- a/awc/tests/test_connector.rs +++ b/awc/tests/test_connector.rs @@ -1,29 +1,39 @@ #![cfg(feature = "openssl")] + +extern crate tls_openssl as openssl; + use actix_http::HttpService; use actix_http_test::test_server; -use actix_service::{map_config, ServiceFactory}; +use actix_service::{map_config, ServiceFactoryExt}; use actix_web::http::Version; use actix_web::{dev::AppConfig, web, App, HttpResponse}; -use open_ssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode}; +use openssl::{ + pkey::PKey, + ssl::{SslAcceptor, SslConnector, SslMethod, SslVerifyMode}, + x509::X509, +}; + +fn tls_config() -> SslAcceptor { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); + let cert_file = cert.serialize_pem().unwrap(); + let key_file = cert.serialize_private_key_pem(); + let cert = X509::from_pem(cert_file.as_bytes()).unwrap(); + let key = PKey::private_key_from_pem(key_file.as_bytes()).unwrap(); -fn ssl_acceptor() -> SslAcceptor { - // load ssl keys let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder - .set_private_key_file("../tests/key.pem", SslFiletype::PEM) - .unwrap(); - builder - .set_certificate_chain_file("../tests/cert.pem") - .unwrap(); + builder.set_certificate(&cert).unwrap(); + builder.set_private_key(&key).unwrap(); + builder.set_alpn_select_callback(|_, protos| { const H2: &[u8] = b"\x02h2"; if protos.windows(3).any(|window| window == H2) { Ok(b"h2") } else { - Err(open_ssl::ssl::AlpnError::NOACK) + Err(openssl::ssl::AlpnError::NOACK) } }); builder.set_alpn_protos(b"\x02h2").unwrap(); + builder.build() } @@ -35,7 +45,7 @@ async fn test_connection_window_size() { App::new().service(web::resource("/").route(web::to(HttpResponse::Ok))), |_| AppConfig::default(), )) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()) }) .await; @@ -48,7 +58,7 @@ async fn test_connection_window_size() { .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); let client = awc::Client::builder() - .connector(awc::Connector::new().ssl(builder.build()).finish()) + .connector(awc::Connector::new().ssl(builder.build())) .initial_window_size(100) .initial_connection_window_size(100) .finish(); diff --git a/awc/tests/test_rustls_client.rs b/awc/tests/test_rustls_client.rs index 0df6b154c..464edfe89 100644 --- a/awc/tests/test_rustls_client.rs +++ b/awc/tests/test_rustls_client.rs @@ -1,57 +1,57 @@ #![cfg(feature = "rustls")] -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; + +extern crate tls_rustls as rustls; + +use std::{ + io::BufReader, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; use actix_http::HttpService; use actix_http_test::test_server; -use actix_service::{map_config, pipeline_factory, ServiceFactory}; -use actix_web::http::Version; -use actix_web::{dev::AppConfig, web, App, HttpResponse}; +use actix_service::{map_config, pipeline_factory, ServiceFactoryExt}; +use actix_web::{dev::AppConfig, http::Version, web, App, HttpResponse}; use futures_util::future::ok; -use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslVerifyMode}; -use rust_tls::ClientConfig; +use rustls::internal::pemfile::{certs, pkcs8_private_keys}; +use rustls::{ClientConfig, NoClientAuth, ServerConfig}; -#[allow(unused)] -fn ssl_acceptor() -> SslAcceptor { - // load ssl keys - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder.set_verify_callback(SslVerifyMode::NONE, |_, _| true); - builder - .set_private_key_file("../tests/key.pem", SslFiletype::PEM) - .unwrap(); - builder - .set_certificate_chain_file("../tests/cert.pem") - .unwrap(); - builder.set_alpn_select_callback(|_, protos| { - const H2: &[u8] = b"\x02h2"; - if protos.windows(3).any(|window| window == H2) { - Ok(b"h2") - } else { - Err(open_ssl::ssl::AlpnError::NOACK) - } - }); - builder.set_alpn_protos(b"\x02h2").unwrap(); - builder.build() +fn tls_config() -> ServerConfig { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); + let cert_file = cert.serialize_pem().unwrap(); + let key_file = cert.serialize_private_key_pem(); + + let mut config = ServerConfig::new(NoClientAuth::new()); + let cert_file = &mut BufReader::new(cert_file.as_bytes()); + let key_file = &mut BufReader::new(key_file.as_bytes()); + + let cert_chain = certs(cert_file).unwrap(); + let mut keys = pkcs8_private_keys(key_file).unwrap(); + config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + + config } mod danger { - pub struct NoCertificateVerification {} + pub struct NoCertificateVerification; - impl rust_tls::ServerCertVerifier for NoCertificateVerification { + impl rustls::ServerCertVerifier for NoCertificateVerification { fn verify_server_cert( &self, - _roots: &rust_tls::RootCertStore, - _presented_certs: &[rust_tls::Certificate], + _roots: &rustls::RootCertStore, + _presented_certs: &[rustls::Certificate], _dns_name: webpki::DNSNameRef<'_>, _ocsp: &[u8], - ) -> Result { - Ok(rust_tls::ServerCertVerified::assertion()) + ) -> Result { + Ok(rustls::ServerCertVerified::assertion()) } } } -// #[actix_rt::test] -async fn _test_connection_reuse_h2() { +#[actix_rt::test] +async fn test_connection_reuse_h2() { let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone(); @@ -64,26 +64,25 @@ async fn _test_connection_reuse_h2() { .and_then( HttpService::build() .h2(map_config( - App::new() - .service(web::resource("/").route(web::to(HttpResponse::Ok))), + App::new().service(web::resource("/").route(web::to(HttpResponse::Ok))), |_| AppConfig::default(), )) - .openssl(ssl_acceptor()) + .rustls(tls_config()) .map_err(|_| ()), ) }) .await; - // disable ssl verification + // disable TLS verification let mut config = ClientConfig::new(); let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; config.set_protocols(&protos); config .dangerous() - .set_certificate_verifier(Arc::new(danger::NoCertificateVerification {})); + .set_certificate_verifier(Arc::new(danger::NoCertificateVerification)); let client = awc::Client::builder() - .connector(awc::Connector::new().rustls(Arc::new(config)).finish()) + .connector(awc::Connector::new().rustls(Arc::new(config))) .finish(); // req 1 diff --git a/awc/tests/test_ssl_client.rs b/awc/tests/test_ssl_client.rs index eced5f14b..3079aaf5e 100644 --- a/awc/tests/test_ssl_client.rs +++ b/awc/tests/test_ssl_client.rs @@ -1,33 +1,43 @@ #![cfg(feature = "openssl")] + +extern crate tls_openssl as openssl; + use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use actix_http::HttpService; use actix_http_test::test_server; -use actix_service::{map_config, pipeline_factory, ServiceFactory}; +use actix_service::{map_config, pipeline_factory, ServiceFactoryExt}; use actix_web::http::Version; use actix_web::{dev::AppConfig, web, App, HttpResponse}; use futures_util::future::ok; -use open_ssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode}; +use openssl::{ + pkey::PKey, + ssl::{SslAcceptor, SslConnector, SslMethod, SslVerifyMode}, + x509::X509, +}; + +fn tls_config() -> SslAcceptor { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); + let cert_file = cert.serialize_pem().unwrap(); + let key_file = cert.serialize_private_key_pem(); + let cert = X509::from_pem(cert_file.as_bytes()).unwrap(); + let key = PKey::private_key_from_pem(key_file.as_bytes()).unwrap(); -fn ssl_acceptor() -> SslAcceptor { - // load ssl keys let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder - .set_private_key_file("../tests/key.pem", SslFiletype::PEM) - .unwrap(); - builder - .set_certificate_chain_file("../tests/cert.pem") - .unwrap(); + builder.set_certificate(&cert).unwrap(); + builder.set_private_key(&key).unwrap(); + builder.set_alpn_select_callback(|_, protos| { const H2: &[u8] = b"\x02h2"; if protos.windows(3).any(|window| window == H2) { Ok(b"h2") } else { - Err(open_ssl::ssl::AlpnError::NOACK) + Err(openssl::ssl::AlpnError::NOACK) } }); builder.set_alpn_protos(b"\x02h2").unwrap(); + builder.build() } @@ -45,11 +55,10 @@ async fn test_connection_reuse_h2() { .and_then( HttpService::build() .h2(map_config( - App::new() - .service(web::resource("/").route(web::to(HttpResponse::Ok))), + App::new().service(web::resource("/").route(web::to(HttpResponse::Ok))), |_| AppConfig::default(), )) - .openssl(ssl_acceptor()) + .openssl(tls_config()) .map_err(|_| ()), ) }) @@ -63,7 +72,7 @@ async fn test_connection_reuse_h2() { .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); let client = awc::Client::builder() - .connector(awc::Connector::new().ssl(builder.build()).finish()) + .connector(awc::Connector::new().ssl(builder.build())) .finish(); // req 1 diff --git a/awc/tests/test_ws.rs b/awc/tests/test_ws.rs index 1c1068668..1b3f780dc 100644 --- a/awc/tests/test_ws.rs +++ b/awc/tests/test_ws.rs @@ -11,7 +11,7 @@ async fn ws_service(req: ws::Frame) -> Result { match req { ws::Frame::Ping(msg) => Ok(ws::Message::Pong(msg)), ws::Frame::Text(text) => Ok(ws::Message::Text( - String::from_utf8(Vec::from(text.as_ref())).unwrap(), + String::from_utf8(Vec::from(text.as_ref())).unwrap().into(), )), ws::Frame::Binary(bin) => Ok(ws::Message::Binary(bin)), ws::Frame::Close(reason) => Ok(ws::Message::Close(reason)), @@ -31,7 +31,7 @@ async fn test_simple() { .send(h1::Message::Item((res.drop_body(), BodySize::None))) .await?; - // start websocket service + // start WebSocket service let framed = framed.replace_codec(ws::Codec::new()); ws::Dispatcher::with(framed, ws_service).await } @@ -43,10 +43,7 @@ async fn test_simple() { // client service let mut framed = srv.ws().await.unwrap(); - framed - .send(ws::Message::Text("text".to_string())) - .await - .unwrap(); + framed.send(ws::Message::Text("text".into())).await.unwrap(); let item = framed.next().await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); diff --git a/benches/responder.rs b/benches/responder.rs new file mode 100644 index 000000000..8cfdbd3ea --- /dev/null +++ b/benches/responder.rs @@ -0,0 +1,113 @@ +use std::future::Future; +use std::time::Instant; + +use actix_http::Response; +use actix_web::http::StatusCode; +use actix_web::test::TestRequest; +use actix_web::{error, Error, HttpRequest, HttpResponse, Responder}; +use criterion::{criterion_group, criterion_main, Criterion}; +use futures_util::future::{ready, Either, Ready}; + +// responder simulate the old responder trait. +trait FutureResponder { + type Error; + type Future: Future>; + + fn future_respond_to(self, req: &HttpRequest) -> Self::Future; +} + +// a simple option responder type. +struct OptionResponder(Option); + +// a simple wrapper type around string +struct StringResponder(String); + +impl FutureResponder for StringResponder { + type Error = Error; + type Future = Ready>; + + fn future_respond_to(self, _: &HttpRequest) -> Self::Future { + // this is default builder for string response in both new and old responder trait. + ready(Ok(Response::build(StatusCode::OK) + .content_type("text/plain; charset=utf-8") + .body(self.0))) + } +} + +impl FutureResponder for OptionResponder +where + T: FutureResponder, + T::Future: Future>, +{ + type Error = Error; + type Future = Either>>; + + fn future_respond_to(self, req: &HttpRequest) -> Self::Future { + match self.0 { + Some(t) => Either::Left(t.future_respond_to(req)), + None => Either::Right(ready(Err(error::ErrorInternalServerError("err")))), + } + } +} + +impl Responder for StringResponder { + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + Response::build(StatusCode::OK) + .content_type("text/plain; charset=utf-8") + .body(self.0) + } +} + +impl Responder for OptionResponder { + fn respond_to(self, req: &HttpRequest) -> HttpResponse { + match self.0 { + Some(t) => t.respond_to(req), + None => Response::from_error(error::ErrorInternalServerError("err")), + } + } +} + +fn future_responder(c: &mut Criterion) { + let rt = actix_rt::System::new(); + let req = TestRequest::default().to_http_request(); + + c.bench_function("future_responder", move |b| { + b.iter_custom(|_| { + let futs = (0..100_000).map(|_| async { + StringResponder(String::from("Hello World!!")) + .future_respond_to(&req) + .await + }); + + let futs = futures_util::future::join_all(futs); + + let start = Instant::now(); + + let _res = rt.block_on(async { futs.await }); + + start.elapsed() + }) + }); +} + +fn responder(c: &mut Criterion) { + let rt = actix_rt::System::new(); + let req = TestRequest::default().to_http_request(); + c.bench_function("responder", move |b| { + b.iter_custom(|_| { + let responders = + (0..100_000).map(|_| StringResponder(String::from("Hello World!!"))); + + let start = Instant::now(); + let _res = rt.block_on(async { + // don't need runtime block on but to be fair. + responders.map(|r| r.respond_to(&req)).collect::>() + }); + + start.elapsed() + }) + }); +} + +criterion_group!(responder_bench, future_responder, responder); +criterion_main!(responder_bench); diff --git a/benches/server.rs b/benches/server.rs index 041d0fa57..9dd540a73 100644 --- a/benches/server.rs +++ b/benches/server.rs @@ -29,34 +29,38 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ fn bench_async_burst(c: &mut Criterion) { // We are using System here, since Runtime requires preinitialized tokio // Maybe add to actix_rt docs - let mut rt = actix_rt::System::new("test"); + let rt = actix_rt::System::new(); - let srv = test::start(|| { - App::new() - .service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) + let srv = rt.block_on(async { + test::start(|| { + App::new() + .service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) + }) }); let url = srv.url("/"); c.bench_function("get_body_async_burst", move |b| { b.iter_custom(|iters| { - let client = Client::new().get(url.clone()).freeze().unwrap(); + rt.block_on(async { + let client = Client::new().get(url.clone()).freeze().unwrap(); + + let start = std::time::Instant::now(); + // benchmark body - let start = std::time::Instant::now(); - // benchmark body - let resps = rt.block_on(async move { let burst = (0..iters).map(|_| client.send()); - join_all(burst).await - }); - let elapsed = start.elapsed(); + let resps = join_all(burst).await; - // if there are failed requests that might be an issue - let failed = resps.iter().filter(|r| r.is_err()).count(); - if failed > 0 { - eprintln!("failed {} requests (might be bench timeout)", failed); - }; + let elapsed = start.elapsed(); - elapsed + // if there are failed requests that might be an issue + let failed = resps.iter().filter(|r| r.is_err()).count(); + if failed > 0 { + eprintln!("failed {} requests (might be bench timeout)", failed); + }; + + elapsed + }) }) }); } diff --git a/benches/service.rs b/benches/service.rs index 8adbc8a0c..0d3264857 100644 --- a/benches/service.rs +++ b/benches/service.rs @@ -23,10 +23,9 @@ use actix_web::test::{init_service, ok_service, TestRequest}; /// async_service_direct time: [1.0908 us 1.1656 us 1.2613 us] pub fn bench_async_service(c: &mut Criterion, srv: S, name: &str) where - S: Service - + 'static, + S: Service + 'static, { - let mut rt = actix_rt::System::new("test"); + let rt = actix_rt::System::new(); let srv = Rc::new(RefCell::new(srv)); let req = TestRequest::default().to_srv_request(); @@ -41,14 +40,15 @@ where b.iter_custom(|iters| { let srv = srv.clone(); // exclude request generation, it appears it takes significant time vs call (3us vs 1us) - let reqs: Vec<_> = (0..iters) + let futs = (0..iters) .map(|_| TestRequest::default().to_srv_request()) - .collect(); + .map(|req| srv.borrow_mut().call(req)); + let start = std::time::Instant::now(); // benchmark body rt.block_on(async move { - for req in reqs { - srv.borrow_mut().call(req).await.unwrap(); + for fut in futs { + fut.await.unwrap(); } }); let elapsed = start.elapsed(); @@ -67,7 +67,7 @@ async fn index(req: ServiceRequest) -> Result { // Sample results on MacBook Pro '14 // time: [2.0724 us 2.1345 us 2.2074 us] fn async_web_service(c: &mut Criterion) { - let mut rt = actix_rt::System::new("test"); + let rt = actix_rt::System::new(); let srv = Rc::new(RefCell::new(rt.block_on(init_service( App::new().service(web::service("/").finish(index)), )))); @@ -83,13 +83,14 @@ fn async_web_service(c: &mut Criterion) { c.bench_function("async_web_service_direct", move |b| { b.iter_custom(|iters| { let srv = srv.clone(); - let reqs = (0..iters).map(|_| TestRequest::get().uri("/").to_request()); - + let futs = (0..iters) + .map(|_| TestRequest::get().uri("/").to_request()) + .map(|req| srv.borrow_mut().call(req)); let start = std::time::Instant::now(); // benchmark body rt.block_on(async move { - for req in reqs { - srv.borrow_mut().call(req).await.unwrap(); + for fut in futs { + fut.await.unwrap(); } }); let elapsed = start.elapsed(); diff --git a/docs/graphs/web-focus.dot b/docs/graphs/web-focus.dot index 55a82bb41..ec0f7a946 100644 --- a/docs/graphs/web-focus.dot +++ b/docs/graphs/web-focus.dot @@ -1,6 +1,7 @@ digraph { subgraph cluster_web { label="actix/actix-web" + "awc" "actix-web" "actix-files" @@ -16,7 +17,7 @@ digraph { "actix-web-actors" -> { "actix" "actix-web" "actix-http" "actix-codec" } "actix-multipart" -> { "actix-web" "actix-service" "actix-utils" } "actix-http" -> { "actix-service" "actix-codec" "actix-tls" "actix-utils" "actix-rt" "threadpool" } - "actix-http" -> { "actix" "actix-tls" }[color=blue] // optional + "actix-http" -> { "actix-tls" }[color=blue] // optional "actix-files" -> { "actix-web" } "actix-http-test" -> { "actix-service" "actix-codec" "actix-tls" "actix-utils" "actix-rt" "actix-server" "awc" } @@ -27,4 +28,8 @@ digraph { "actix-tls" -> { "actix-service" "actix-codec" "actix-utils" } "actix-server" -> { "actix-service" "actix-rt" "actix-codec" "actix-utils" } "actix-rt" -> { "macros" "threadpool" } + + // actix + + "actix" -> { "actix-rt" } } diff --git a/docs/graphs/web-only.dot b/docs/graphs/web-only.dot index 9e1bb2805..6f8292a3a 100644 --- a/docs/graphs/web-only.dot +++ b/docs/graphs/web-only.dot @@ -15,7 +15,6 @@ digraph { "awc" -> { "actix-http" } "actix-web-actors" -> { "actix" "actix-web" "actix-http" } "actix-multipart" -> { "actix-web" } - "actix-http" -> { "actix" }[color=blue] // optional "actix-files" -> { "actix-web" } "actix-http-test" -> { "awc" } } diff --git a/examples/basic.rs b/examples/basic.rs index 8b2bf2319..99eef3ee1 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -18,8 +18,7 @@ async fn no_params() -> &'static str { #[actix_web::main] async fn main() -> std::io::Result<()> { - std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info"); - env_logger::init(); + env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); HttpServer::new(|| { App::new() @@ -30,12 +29,8 @@ async fn main() -> std::io::Result<()> { .service(no_params) .service( web::resource("/resource2/index.html") - .wrap( - middleware::DefaultHeaders::new().header("X-Version-R2", "0.3"), - ) - .default_service( - web::route().to(|| HttpResponse::MethodNotAllowed()), - ) + .wrap(middleware::DefaultHeaders::new().header("X-Version-R2", "0.3")) + .default_service(web::route().to(|| HttpResponse::MethodNotAllowed())) .route(web::get().to(index_async)), ) .service(web::resource("/test1.html").to(|| async { "Test\r\n" })) diff --git a/examples/client.rs b/examples/client.rs index 15cf24fa8..b9574590d 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -10,7 +10,7 @@ async fn main() -> Result<(), Error> { // Create request builder, configure request and send let mut response = client .get("https://www.rust-lang.org/") - .header("User-Agent", "Actix-web") + .append_header(("User-Agent", "Actix-web")) .send() .await?; diff --git a/examples/on_connect.rs b/examples/on_connect.rs index bdad7e67e..ba5a18f3f 100644 --- a/examples/on_connect.rs +++ b/examples/on_connect.rs @@ -4,7 +4,7 @@ //! For an example of extracting a client TLS certificate, see: //! -use std::{any::Any, env, io, net::SocketAddr}; +use std::{any::Any, io, net::SocketAddr}; use actix_web::{dev::Extensions, rt::net::TcpStream, web, App, HttpServer}; @@ -36,11 +36,7 @@ fn get_conn_info(connection: &dyn Any, data: &mut Extensions) { #[actix_web::main] async fn main() -> io::Result<()> { - if env::var("RUST_LOG").is_err() { - env::set_var("RUST_LOG", "info"); - } - - env_logger::init(); + env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); HttpServer::new(|| App::new().default_service(web::to(route_whoami))) .on_connect(get_conn_info) diff --git a/examples/uds.rs b/examples/uds.rs index e34fa5ac9..096781984 100644 --- a/examples/uds.rs +++ b/examples/uds.rs @@ -22,8 +22,7 @@ async fn no_params() -> &'static str { #[cfg(unix)] #[actix_web::main] async fn main() -> std::io::Result<()> { - std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info"); - env_logger::init(); + env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); HttpServer::new(|| { App::new() @@ -34,12 +33,8 @@ async fn main() -> std::io::Result<()> { .service(no_params) .service( web::resource("/resource2/index.html") - .wrap( - middleware::DefaultHeaders::new().header("X-Version-R2", "0.3"), - ) - .default_service( - web::route().to(|| HttpResponse::MethodNotAllowed()), - ) + .wrap(middleware::DefaultHeaders::new().header("X-Version-R2", "0.3")) + .default_service(web::route().to(|| HttpResponse::MethodNotAllowed())) .route(web::get().to(index_async)), ) .service(web::resource("/test1.html").to(|| async { "Test\r\n" })) diff --git a/rustfmt.toml b/rustfmt.toml index 94bd11d51..973e002c0 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,2 +1,2 @@ -max_width = 89 +max_width = 96 reorder_imports = true diff --git a/src/app.rs b/src/app.rs index 8dd86f7ec..7a26a3a89 100644 --- a/src/app.rs +++ b/src/app.rs @@ -5,10 +5,10 @@ use std::marker::PhantomData; use std::rc::Rc; use actix_http::body::{Body, MessageBody}; -use actix_http::Extensions; +use actix_http::{Extensions, Request}; use actix_service::boxed::{self, BoxServiceFactory}; use actix_service::{ - apply, apply_fn_factory, IntoServiceFactory, ServiceFactory, Transform, + apply, apply_fn_factory, IntoServiceFactory, ServiceFactory, ServiceFactoryExt, Transform, }; use futures_util::future::FutureExt; @@ -33,11 +33,10 @@ pub struct App { services: Vec>, default: Option>, factory_ref: Rc>>, - data: Vec>, data_factories: Vec, external: Vec, extensions: Extensions, - _t: PhantomData, + _phantom: PhantomData, } impl App { @@ -47,14 +46,13 @@ impl App { let fref = Rc::new(RefCell::new(None)); App { endpoint: AppEntry::new(fref.clone()), - data: Vec::new(), data_factories: Vec::new(), services: Vec::new(), default: None, factory_ref: fref, external: Vec::new(), extensions: Extensions::new(), - _t: PhantomData, + _phantom: PhantomData, } } } @@ -63,8 +61,8 @@ impl App where B: MessageBody, T: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -73,7 +71,7 @@ where /// Set application data. Application data could be accessed /// by using `Data` extractor where `T` is data type. /// - /// **Note**: http server accepts an application factory rather than + /// **Note**: HTTP server accepts an application factory rather than /// an application instance. Http server constructs an application /// instance for each thread, thus application data must be constructed /// multiple times. If you want to share data between different @@ -100,9 +98,8 @@ where /// web::resource("/index.html").route( /// web::get().to(index))); /// ``` - pub fn data(mut self, data: U) -> Self { - self.data.push(Box::new(Data::new(data))); - self + pub fn data(self, data: U) -> Self { + self.app_data(Data::new(data)) } /// Set application data factory. This function is @@ -156,8 +153,7 @@ where /// some of the resource's configuration could be moved to different module. /// /// ```rust - /// # extern crate actix_web; - /// use actix_web::{web, middleware, App, HttpResponse}; + /// use actix_web::{web, App, HttpResponse}; /// /// // this function could be located in different module /// fn config(cfg: &mut web::ServiceConfig) { @@ -167,12 +163,9 @@ where /// ); /// } /// - /// fn main() { - /// let app = App::new() - /// .wrap(middleware::Logger::default()) - /// .configure(config) // <- register resources - /// .route("/index.html", web::get().to(|| HttpResponse::Ok())); - /// } + /// App::new() + /// .configure(config) // <- register resources + /// .route("/index.html", web::get().to(|| HttpResponse::Ok())); /// ``` pub fn configure(mut self, f: F) -> Self where @@ -180,10 +173,9 @@ where { let mut cfg = ServiceConfig::new(); f(&mut cfg); - self.data.extend(cfg.data); self.services.extend(cfg.services); self.external.extend(cfg.external); - self.extensions.extend(cfg.extensions); + self.extensions.extend(cfg.app_data); self } @@ -214,11 +206,11 @@ where ) } - /// Register http service. + /// Register HTTP service. /// /// Http service is any type that implements `HttpServiceFactory` trait. /// - /// Actix web provides several services implementations: + /// Actix Web provides several services implementations: /// /// * *Resource* is an entry in resource table which corresponds to requested URL. /// * *Scope* is a set of resources with common root path. @@ -268,10 +260,10 @@ where /// ``` pub fn default_service(mut self, f: F) -> Self where - F: IntoServiceFactory, + F: IntoServiceFactory, U: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, > + 'static, @@ -353,8 +345,8 @@ where mw: M, ) -> App< impl ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -364,7 +356,7 @@ where where M: Transform< T::Service, - Request = ServiceRequest, + ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -373,14 +365,13 @@ where { App { endpoint: apply(mw, self.endpoint), - data: self.data, data_factories: self.data_factories, services: self.services, default: self.default, factory_ref: self.factory_ref, external: self.external, extensions: self.extensions, - _t: PhantomData, + _phantom: PhantomData, } } @@ -420,8 +411,8 @@ where mw: F, ) -> App< impl ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -430,38 +421,37 @@ where > where B1: MessageBody, - F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, + F: Fn(ServiceRequest, &T::Service) -> R + Clone, R: Future, Error>>, { App { endpoint: apply_fn_factory(self.endpoint, mw), - data: self.data, data_factories: self.data_factories, services: self.services, default: self.default, factory_ref: self.factory_ref, external: self.external, extensions: self.extensions, - _t: PhantomData, + _phantom: PhantomData, } } } -impl IntoServiceFactory> for App +impl IntoServiceFactory, Request> for App where B: MessageBody, T: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), >, + T::Future: 'static, { fn into_factory(self) -> AppInit { AppInit { - data: self.data.into_boxed_slice().into(), - data_factories: self.data_factories.into_boxed_slice().into(), + async_data_factories: self.data_factories.into_boxed_slice().into(), endpoint: self.endpoint, services: Rc::new(RefCell::new(self.services)), external: RefCell::new(self.external), @@ -482,17 +472,13 @@ mod tests { use crate::http::{header, HeaderValue, Method, StatusCode}; use crate::middleware::DefaultHeaders; use crate::service::ServiceRequest; - use crate::test::{ - call_service, init_service, read_body, try_init_service, TestRequest, - }; + use crate::test::{call_service, init_service, read_body, try_init_service, TestRequest}; use crate::{web, HttpRequest, HttpResponse}; #[actix_rt::test] async fn test_default_resource() { - let mut srv = init_service( - App::new().service(web::resource("/test").to(HttpResponse::Ok)), - ) - .await; + let srv = + init_service(App::new().service(web::resource("/test").to(HttpResponse::Ok))).await; let req = TestRequest::with_uri("/test").to_request(); let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); @@ -501,7 +487,7 @@ mod tests { let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let mut srv = init_service( + let srv = init_service( App::new() .service(web::resource("/test").to(HttpResponse::Ok)) .service( @@ -534,20 +520,22 @@ mod tests { #[actix_rt::test] async fn test_data_factory() { - let mut srv = - init_service(App::new().data_factory(|| ok::<_, ()>(10usize)).service( - web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )) - .await; + let srv = init_service( + App::new() + .data_factory(|| ok::<_, ()>(10usize)) + .service(web::resource("/").to(|_: web::Data| HttpResponse::Ok())), + ) + .await; let req = TestRequest::default().to_request(); let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); - let mut srv = - init_service(App::new().data_factory(|| ok::<_, ()>(10u32)).service( - web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )) - .await; + let srv = init_service( + App::new() + .data_factory(|| ok::<_, ()>(10u32)) + .service(web::resource("/").to(|_: web::Data| HttpResponse::Ok())), + ) + .await; let req = TestRequest::default().to_request(); let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); @@ -555,23 +543,24 @@ mod tests { #[actix_rt::test] async fn test_data_factory_errors() { - let srv = - try_init_service(App::new().data_factory(|| err::(())).service( - web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )) - .await; + let srv = try_init_service( + App::new() + .data_factory(|| err::(())) + .service(web::resource("/").to(|_: web::Data| HttpResponse::Ok())), + ) + .await; assert!(srv.is_err()); } #[actix_rt::test] async fn test_extension() { - let mut srv = init_service(App::new().app_data(10usize).service( - web::resource("/").to(|req: HttpRequest| { + let srv = init_service(App::new().app_data(10usize).service(web::resource("/").to( + |req: HttpRequest| { assert_eq!(*req.app_data::().unwrap(), 10); HttpResponse::Ok() - }), - )) + }, + ))) .await; let req = TestRequest::default().to_request(); let resp = srv.call(req).await.unwrap(); @@ -580,7 +569,7 @@ mod tests { #[actix_rt::test] async fn test_wrap() { - let mut srv = init_service( + let srv = init_service( App::new() .wrap( DefaultHeaders::new() @@ -590,7 +579,7 @@ mod tests { ) .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -600,7 +589,7 @@ mod tests { #[actix_rt::test] async fn test_router_wrap() { - let mut srv = init_service( + let srv = init_service( App::new() .route("/test", web::get().to(HttpResponse::Ok)) .wrap( @@ -610,7 +599,7 @@ mod tests { ) .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -620,16 +609,14 @@ mod tests { #[actix_rt::test] async fn test_wrap_fn() { - let mut srv = init_service( + let srv = init_service( App::new() .wrap_fn(|req, srv| { let fut = srv.call(req); async move { let mut res = fut.await?; - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("0001"), - ); + res.headers_mut() + .insert(header::CONTENT_TYPE, HeaderValue::from_static("0001")); Ok(res) } }) @@ -637,7 +624,7 @@ mod tests { ) .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -647,24 +634,22 @@ mod tests { #[actix_rt::test] async fn test_router_wrap_fn() { - let mut srv = init_service( + let srv = init_service( App::new() .route("/test", web::get().to(HttpResponse::Ok)) .wrap_fn(|req, srv| { let fut = srv.call(req); async { let mut res = fut.await?; - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("0001"), - ); + res.headers_mut() + .insert(header::CONTENT_TYPE, HeaderValue::from_static("0001")); Ok(res) } }), ) .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -674,21 +659,20 @@ mod tests { #[actix_rt::test] async fn test_external_resource() { - let mut srv = init_service( + let srv = init_service( App::new() .external_resource("youtube", "https://youtube.com/watch/{video_id}") .route( "/test", web::get().to(|req: HttpRequest| { - HttpResponse::Ok().body( - req.url_for("youtube", &["12345"]).unwrap().to_string(), - ) + HttpResponse::Ok() + .body(req.url_for("youtube", &["12345"]).unwrap().to_string()) }), ), ) .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let body = read_body(resp).await; assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); diff --git a/src/app_service.rs b/src/app_service.rs index e5f8dd9cf..9b4ae3354 100644 --- a/src/app_service.rs +++ b/src/app_service.rs @@ -1,19 +1,16 @@ use std::cell::RefCell; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; use std::rc::Rc; -use std::task::{Context, Poll}; +use std::task::Poll; use actix_http::{Extensions, Request, Response}; -use actix_router::{Path, ResourceDef, ResourceInfo, Router, Url}; +use actix_router::{Path, ResourceDef, Router, Url}; use actix_service::boxed::{self, BoxService, BoxServiceFactory}; use actix_service::{fn_service, Service, ServiceFactory}; -use futures_util::future::{join_all, ok, FutureExt, LocalBoxFuture}; -use tinyvec::tiny_vec; +use futures_core::future::LocalBoxFuture; +use futures_util::future::join_all; use crate::config::{AppConfig, AppService}; -use crate::data::{DataFactory, FnDataFactory}; +use crate::data::FnDataFactory; use crate::error::Error; use crate::guard::Guard; use crate::request::{HttpRequest, HttpRequestPool}; @@ -23,15 +20,14 @@ use crate::service::{AppServiceFactory, ServiceRequest, ServiceResponse}; type Guards = Vec>; type HttpService = BoxService; type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; -type BoxResponse = LocalBoxFuture<'static, Result>; /// Service factory to convert `Request` to a `ServiceRequest`. /// It also executes data factories. pub struct AppInit where T: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -39,42 +35,42 @@ where { pub(crate) endpoint: T, pub(crate) extensions: RefCell>, - pub(crate) data: Rc<[Box]>, - pub(crate) data_factories: Rc<[FnDataFactory]>, + pub(crate) async_data_factories: Rc<[FnDataFactory]>, pub(crate) services: Rc>>>, pub(crate) default: Option>, pub(crate) factory_ref: Rc>>, pub(crate) external: RefCell>, } -impl ServiceFactory for AppInit +impl ServiceFactory for AppInit where T: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), >, + T::Future: 'static, { - type Config = AppConfig; - type Request = Request; type Response = ServiceResponse; type Error = T::Error; - type InitError = T::InitError; + type Config = AppConfig; type Service = AppInitService; - type Future = AppInitResult; + type InitError = T::InitError; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, config: AppConfig) -> Self::Future { - // update resource default service + // set AppService's default service to 404 NotFound + // if no user defined default service exists. let default = self.default.clone().unwrap_or_else(|| { - Rc::new(boxed::factory(fn_service(|req: ServiceRequest| { - ok(req.into_response(Response::NotFound().finish())) + Rc::new(boxed::factory(fn_service(|req: ServiceRequest| async { + Ok(req.into_response(Response::NotFound().finish())) }))) }); // App config - let mut config = AppService::new(config, default.clone(), self.data.clone()); + let mut config = AppService::new(config, default.clone()); // register services std::mem::take(&mut *self.services.borrow_mut()) @@ -85,7 +81,7 @@ where let (config, services) = config.into_services(); - // complete pipeline creation + // complete pipeline creation. *self.factory_ref.borrow_mut() = Some(AppRoutingFactory { default, services: services @@ -108,167 +104,125 @@ where let rmap = Rc::new(rmap); rmap.finish(rmap.clone()); - // start all data factory futures - let factory_futs = join_all(self.data_factories.iter().map(|f| f())); + // construct all async data factory futures + let factory_futs = join_all(self.async_data_factories.iter().map(|f| f())); - AppInitResult { - endpoint: None, - endpoint_fut: self.endpoint.new_service(()), - data: self.data.clone(), - data_factories: None, - data_factories_fut: factory_futs.boxed_local(), - extensions: Some( - self.extensions - .borrow_mut() - .take() - .unwrap_or_else(Extensions::new), - ), - config, - rmap, - _t: PhantomData, - } + // construct app service and middleware service factory future. + let endpoint_fut = self.endpoint.new_service(()); + + // take extensions or create new one as app data container. + let mut app_data = self + .extensions + .borrow_mut() + .take() + .unwrap_or_else(Extensions::new); + + Box::pin(async move { + // async data factories + let async_data_factories = factory_futs + .await + .into_iter() + .collect::, _>>() + .map_err(|_| ())?; + + // app service and middleware + let service = endpoint_fut.await?; + + // populate app data container from (async) data factories. + async_data_factories.iter().for_each(|factory| { + factory.create(&mut app_data); + }); + + Ok(AppInitService { + service, + app_data: Rc::new(app_data), + app_state: AppInitServiceState::new(rmap, config), + }) + }) } } -#[pin_project::pin_project] -pub struct AppInitResult -where - T: ServiceFactory, -{ - #[pin] - endpoint_fut: T::Future, - // a Some signals completion of endpoint creation - endpoint: Option, - - #[pin] - data_factories_fut: LocalBoxFuture<'static, Vec, ()>>>, - // a Some signals completion of factory futures - data_factories: Option>>, - - rmap: Rc, - config: AppConfig, - data: Rc<[Box]>, - extensions: Option, - - _t: PhantomData, -} - -impl Future for AppInitResult -where - T: ServiceFactory< - Config = (), - Request = ServiceRequest, - Response = ServiceResponse, - Error = Error, - InitError = (), - >, -{ - type Output = Result, ()>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - // async data factories - if let Poll::Ready(factories) = this.data_factories_fut.poll(cx) { - let factories: Result, ()> = factories.into_iter().collect(); - - if let Ok(factories) = factories { - this.data_factories.replace(factories); - } else { - return Poll::Ready(Err(())); - } - } - - // app service and middleware - if this.endpoint.is_none() { - if let Poll::Ready(srv) = this.endpoint_fut.poll(cx)? { - *this.endpoint = Some(srv); - } - } - - // not using if let so condition only needs shared ref - if this.endpoint.is_some() && this.data_factories.is_some() { - // create app data container - let mut data = this.extensions.take().unwrap(); - - for f in this.data.iter() { - f.create(&mut data); - } - - for f in this.data_factories.take().unwrap().iter() { - f.create(&mut data); - } - - return Poll::Ready(Ok(AppInitService { - service: this.endpoint.take().unwrap(), - rmap: this.rmap.clone(), - config: this.config.clone(), - data: Rc::new(data), - pool: HttpRequestPool::create(), - })); - } - - Poll::Pending - } -} - -/// Service to convert `Request` to a `ServiceRequest` +/// Service that takes a [`Request`] and delegates to a service that take a [`ServiceRequest`]. pub struct AppInitService where - T: Service, Error = Error>, + T: Service, Error = Error>, { service: T, - rmap: Rc, - config: AppConfig, - data: Rc, - pool: &'static HttpRequestPool, + app_data: Rc, + app_state: Rc, } -impl Service for AppInitService +/// A collection of [`AppInitService`] state that shared across `HttpRequest`s. +pub(crate) struct AppInitServiceState { + rmap: Rc, + config: AppConfig, + pool: HttpRequestPool, +} + +impl AppInitServiceState { + pub(crate) fn new(rmap: Rc, config: AppConfig) -> Rc { + Rc::new(AppInitServiceState { + rmap, + config, + // TODO: AppConfig can be used to pass user defined HttpRequestPool + // capacity. + pool: HttpRequestPool::default(), + }) + } + + #[inline] + pub(crate) fn rmap(&self) -> &ResourceMap { + &*self.rmap + } + + #[inline] + pub(crate) fn config(&self) -> &AppConfig { + &self.config + } + + #[inline] + pub(crate) fn pool(&self) -> &HttpRequestPool { + &self.pool + } +} + +impl Service for AppInitService where - T: Service, Error = Error>, + T: Service, Error = Error>, { - type Request = Request; type Response = ServiceResponse; type Error = T::Error; type Future = T::Future; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx) - } + actix_service::forward_ready!(service); - fn call(&mut self, req: Request) -> Self::Future { + fn call(&self, req: Request) -> Self::Future { let (head, payload) = req.into_parts(); - let req = if let Some(mut req) = self.pool.get_request() { - let inner = Rc::get_mut(&mut req.0).unwrap(); + let req = if let Some(mut req) = self.app_state.pool().pop() { + let inner = Rc::get_mut(&mut req.inner).unwrap(); inner.path.get_mut().update(&head.uri); inner.path.reset(); inner.head = head; - inner.payload = payload; - inner.app_data = tiny_vec![self.data.clone()]; req } else { HttpRequest::new( Path::new(Url::new(head.uri.clone())), head, - payload, - self.rmap.clone(), - self.config.clone(), - self.data.clone(), - self.pool, + self.app_state.clone(), + self.app_data.clone(), ) }; - self.service.call(ServiceRequest::new(req)) + self.service.call(ServiceRequest::new(req, payload)) } } impl Drop for AppInitService where - T: Service, Error = Error>, + T: Service, Error = Error>, { fn drop(&mut self) { - self.pool.clear(); + self.app_state.pool().clear(); } } @@ -277,133 +231,63 @@ pub struct AppRoutingFactory { default: Rc, } -impl ServiceFactory for AppRoutingFactory { - type Config = (); - type Request = ServiceRequest; +impl ServiceFactory for AppRoutingFactory { type Response = ServiceResponse; type Error = Error; - type InitError = (); + type Config = (); type Service = AppRouting; - type Future = AppRoutingFactoryResponse; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - AppRoutingFactoryResponse { - fut: self - .services - .iter() - .map(|(path, service, guards)| { - CreateAppRoutingItem::Future( - Some(path.clone()), - guards.borrow_mut().take(), - service.new_service(()).boxed_local(), - ) - }) - .collect(), - default: None, - default_fut: Some(self.default.new_service(())), - } - } -} - -type HttpServiceFut = LocalBoxFuture<'static, Result>; - -/// Create app service -#[doc(hidden)] -pub struct AppRoutingFactoryResponse { - fut: Vec, - default: Option, - default_fut: Option>>, -} - -enum CreateAppRoutingItem { - Future(Option, Option, HttpServiceFut), - Service(ResourceDef, Option, HttpService), -} - -impl Future for AppRoutingFactoryResponse { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut done = true; - - if let Some(ref mut fut) = self.default_fut { - match Pin::new(fut).poll(cx)? { - Poll::Ready(default) => self.default = Some(default), - Poll::Pending => done = false, + // construct all services factory future with it's resource def and guards. + let factory_fut = join_all(self.services.iter().map(|(path, factory, guards)| { + let path = path.clone(); + let guards = guards.borrow_mut().take(); + let factory_fut = factory.new_service(()); + async move { + let service = factory_fut.await?; + Ok((path, guards, service)) } - } + })); - // poll http services - for item in &mut self.fut { - let res = match item { - CreateAppRoutingItem::Future( - ref mut path, - ref mut guards, - ref mut fut, - ) => match Pin::new(fut).poll(cx) { - Poll::Ready(Ok(service)) => { - Some((path.take().unwrap(), guards.take(), service)) - } - Poll::Ready(Err(_)) => return Poll::Ready(Err(())), - Poll::Pending => { - done = false; - None - } - }, - CreateAppRoutingItem::Service(_, _, _) => continue, - }; + // construct default service factory future + let default_fut = self.default.new_service(()); - if let Some((path, guards, service)) = res { - *item = CreateAppRoutingItem::Service(path, guards, service); - } - } + Box::pin(async move { + let default = default_fut.await?; - if done { - let router = self - .fut + // build router from the factory future result. + let router = factory_fut + .await + .into_iter() + .collect::, _>>()? .drain(..) - .fold(Router::build(), |mut router, item| { - match item { - CreateAppRoutingItem::Service(path, guards, service) => { - router.rdef(path, service).2 = guards; - } - CreateAppRoutingItem::Future(_, _, _) => unreachable!(), - } + .fold(Router::build(), |mut router, (path, guards, service)| { + router.rdef(path, service).2 = guards; router - }); - Poll::Ready(Ok(AppRouting { - ready: None, - router: router.finish(), - default: self.default.take(), - })) - } else { - Poll::Pending - } + }) + .finish(); + + Ok(AppRouting { router, default }) + }) } } pub struct AppRouting { router: Router, - ready: Option<(ServiceRequest, ResourceInfo)>, - default: Option, + default: HttpService, } -impl Service for AppRouting { - type Request = ServiceRequest; +impl Service for AppRouting { type Response = ServiceResponse; type Error = Error; - type Future = BoxResponse; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - if self.ready.is_none() { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } + actix_service::always_ready!(); - fn call(&mut self, mut req: ServiceRequest) -> Self::Future { - let res = self.router.recognize_mut_checked(&mut req, |req, guards| { + fn call(&self, mut req: ServiceRequest) -> Self::Future { + let res = self.router.recognize_checked(&mut req, |req, guards| { if let Some(ref guards) = guards { for f in guards { if !f.check(req.head()) { @@ -416,11 +300,8 @@ impl Service for AppRouting { if let Some((srv, _info)) = res { srv.call(req) - } else if let Some(ref mut default) = self.default { - default.call(req) } else { - let req = req.into_parts().0; - ok(ServiceResponse::new(req, Response::NotFound().finish())).boxed_local() + self.default.call(req) } } } @@ -436,14 +317,13 @@ impl AppEntry { } } -impl ServiceFactory for AppEntry { - type Config = (); - type Request = ServiceRequest; +impl ServiceFactory for AppEntry { type Response = ServiceResponse; type Error = Error; - type InitError = (); + type Config = (); type Service = AppRouting; - type Future = AppRoutingFactoryResponse; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { self.factory.borrow_mut().as_mut().unwrap().new_service(()) @@ -455,9 +335,10 @@ mod tests { use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; + use actix_service::Service; + use crate::test::{init_service, TestRequest}; use crate::{web, App, HttpResponse}; - use actix_service::Service; struct DropData(Arc); @@ -472,7 +353,7 @@ mod tests { let data = Arc::new(AtomicBool::new(false)); { - let mut app = init_service( + let app = init_service( App::new() .data(DropData(data.clone())) .service(web::resource("/test").to(HttpResponse::Ok)), diff --git a/src/config.rs b/src/config.rs index 01959daa1..bd9a25c6f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -5,7 +5,7 @@ use actix_http::Extensions; use actix_router::ResourceDef; use actix_service::{boxed, IntoServiceFactory, ServiceFactory}; -use crate::data::{Data, DataFactory}; +use crate::data::Data; use crate::error::Error; use crate::guard::Guard; use crate::resource::Resource; @@ -17,8 +17,7 @@ use crate::service::{ }; type Guards = Vec>; -type HttpNewService = - boxed::BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; +type HttpNewService = boxed::BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; /// Application configuration pub struct AppService { @@ -31,20 +30,14 @@ pub struct AppService { Option, Option>, )>, - service_data: Rc<[Box]>, } impl AppService { - /// Crate server settings instance - pub(crate) fn new( - config: AppConfig, - default: Rc, - service_data: Rc<[Box]>, - ) -> Self { + /// Crate server settings instance. + pub(crate) fn new(config: AppConfig, default: Rc) -> Self { AppService { config, default, - service_data, root: true, services: Vec::new(), } @@ -75,7 +68,6 @@ impl AppService { default: self.default.clone(), services: Vec::new(), root: false, - service_data: self.service_data.clone(), } } @@ -89,15 +81,7 @@ impl AppService { self.default.clone() } - /// Set global route data - pub fn set_service_data(&self, extensions: &mut Extensions) -> bool { - for f in self.service_data.iter() { - f.create(extensions); - } - !self.service_data.is_empty() - } - - /// Register http service + /// Register HTTP service. pub fn register_service( &mut self, rdef: ResourceDef, @@ -105,29 +89,23 @@ impl AppService { factory: F, nested: Option>, ) where - F: IntoServiceFactory, + F: IntoServiceFactory, S: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), > + 'static, { - self.services.push(( - rdef, - boxed::factory(factory.into_factory()), - guards, - nested, - )); + self.services + .push((rdef, boxed::factory(factory.into_factory()), guards, nested)); } } /// Application connection config #[derive(Clone)] -pub struct AppConfig(Rc); - -struct AppConfigInner { +pub struct AppConfig { secure: bool, host: String, addr: SocketAddr, @@ -135,7 +113,7 @@ struct AppConfigInner { impl AppConfig { pub(crate) fn new(secure: bool, addr: SocketAddr, host: String) -> Self { - AppConfig(Rc::new(AppConfigInner { secure, addr, host })) + AppConfig { secure, addr, host } } /// Server host name. @@ -146,17 +124,17 @@ impl AppConfig { /// /// By default host name is set to a "localhost" value. pub fn host(&self) -> &str { - &self.0.host + &self.host } /// Returns true if connection is secure(https) pub fn secure(&self) -> bool { - self.0.secure + self.secure } /// Returns the socket address of the local half of this TCP connection pub fn local_addr(&self) -> SocketAddr { - self.0.addr + self.addr } } @@ -170,47 +148,60 @@ impl Default for AppConfig { } } -/// Service config is used for external configuration. -/// Part of application configuration could be offloaded -/// to set of external methods. This could help with -/// modularization of big application configuration. +/// Enables parts of app configuration to be declared separately from the app itself. Helpful for +/// modularizing large applications. +/// +/// Merge a `ServiceConfig` into an app using [`App::configure`](crate::App::configure). Scope and +/// resources services have similar methods. +/// +/// ``` +/// use actix_web::{web, App, HttpResponse}; +/// +/// // this function could be located in different module +/// fn config(cfg: &mut web::ServiceConfig) { +/// cfg.service(web::resource("/test") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// .route(web::head().to(|| HttpResponse::MethodNotAllowed())) +/// ); +/// } +/// +/// // merge `/test` routes from config function to App +/// App::new().configure(config); +/// ``` pub struct ServiceConfig { pub(crate) services: Vec>, - pub(crate) data: Vec>, pub(crate) external: Vec, - pub(crate) extensions: Extensions, + pub(crate) app_data: Extensions, } impl ServiceConfig { pub(crate) fn new() -> Self { Self { services: Vec::new(), - data: Vec::new(), external: Vec::new(), - extensions: Extensions::new(), + app_data: Extensions::new(), } } - /// Set application data. Application data could be accessed - /// by using `Data` extractor where `T` is data type. + /// Add shared app data item. /// - /// This is same as `App::data()` method. - pub fn data(&mut self, data: S) -> &mut Self { - self.data.push(Box::new(Data::new(data))); + /// Counterpart to [`App::data()`](crate::App::data). + pub fn data(&mut self, data: U) -> &mut Self { + self.app_data(Data::new(data)); self } - /// Set arbitrary data item. + /// Add arbitrary app data item. /// - /// This is same as `App::data()` method. + /// Counterpart to [`App::app_data()`](crate::App::app_data). pub fn app_data(&mut self, ext: U) -> &mut Self { - self.extensions.insert(ext); + self.app_data.insert(ext); self } /// Configure route for a specific path. /// - /// This is same as `App::route()` method. + /// Counterpart to [`App::route()`](crate::App::route). pub fn route(&mut self, path: &str, mut route: Route) -> &mut Self { self.service( Resource::new(path) @@ -219,9 +210,9 @@ impl ServiceConfig { ) } - /// Register http service. + /// Register HTTP service factory. /// - /// This is same as `App::service()` method. + /// Counterpart to [`App::service()`](crate::App::service). pub fn service(&mut self, factory: F) -> &mut Self where F: HttpServiceFactory + 'static, @@ -233,11 +224,11 @@ impl ServiceConfig { /// Register an external resource. /// - /// External resources are useful for URL generation purposes only - /// and are never considered for matching at request time. Calls to - /// `HttpRequest::url_for()` will work as expected. + /// External resources are useful for URL generation purposes only and are never considered for + /// matching at request time. Calls to [`HttpRequest::url_for()`](crate::HttpRequest::url_for) + /// will work as expected. /// - /// This is same as `App::external_service()` method. + /// Counterpart to [`App::external_resource()`](crate::App::external_resource). pub fn external_resource(&mut self, name: N, url: U) -> &mut Self where N: AsRef, @@ -267,12 +258,12 @@ mod tests { cfg.app_data(15u8); }; - let mut srv = init_service(App::new().configure(cfg).service( - web::resource("/").to(|_: web::Data, req: HttpRequest| { + let srv = init_service(App::new().configure(cfg).service(web::resource("/").to( + |_: web::Data, req: HttpRequest| { assert_eq!(*req.app_data::().unwrap(), 15u8); HttpResponse::Ok() - }), - )) + }, + ))) .await; let req = TestRequest::default().to_request(); let resp = srv.call(req).await.unwrap(); @@ -290,7 +281,7 @@ mod tests { // }); // }; - // let mut srv = + // let srv = // init_service(App::new().configure(cfg).service( // web::resource("/").to(|_: web::Data| HttpResponse::Ok()), // )); @@ -301,7 +292,7 @@ mod tests { // let cfg2 = |cfg: &mut ServiceConfig| { // cfg.data_factory(|| Ok::<_, ()>(10u32)); // }; - // let mut srv = init_service( + // let srv = init_service( // App::new() // .service(web::resource("/").to(|_: web::Data| HttpResponse::Ok())) // .configure(cfg2), @@ -313,26 +304,22 @@ mod tests { #[actix_rt::test] async fn test_external_resource() { - let mut srv = init_service( + let srv = init_service( App::new() .configure(|cfg| { - cfg.external_resource( - "youtube", - "https://youtube.com/watch/{video_id}", - ); + cfg.external_resource("youtube", "https://youtube.com/watch/{video_id}"); }) .route( "/test", web::get().to(|req: HttpRequest| { - HttpResponse::Ok().body( - req.url_for("youtube", &["12345"]).unwrap().to_string(), - ) + HttpResponse::Ok() + .body(req.url_for("youtube", &["12345"]).unwrap().to_string()) }), ), ) .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let body = read_body(resp).await; assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); @@ -340,24 +327,22 @@ mod tests { #[actix_rt::test] async fn test_service() { - let mut srv = init_service(App::new().configure(|cfg| { - cfg.service( - web::resource("/test").route(web::get().to(HttpResponse::Created)), - ) - .route("/index.html", web::get().to(HttpResponse::Ok)); + let srv = init_service(App::new().configure(|cfg| { + cfg.service(web::resource("/test").route(web::get().to(HttpResponse::Created))) + .route("/index.html", web::get().to(HttpResponse::Ok)); })) .await; let req = TestRequest::with_uri("/test") .method(Method::GET) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::CREATED); let req = TestRequest::with_uri("/index.html") .method(Method::GET) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } } diff --git a/src/data.rs b/src/data.rs index 19c258ff0..133248212 100644 --- a/src/data.rs +++ b/src/data.rs @@ -10,8 +10,9 @@ use crate::dev::Payload; use crate::extract::FromRequest; use crate::request::HttpRequest; -/// Application data factory +/// Data factory. pub(crate) trait DataFactory { + /// Return true if modifications were made to extensions map. fn create(&self, extensions: &mut Extensions) -> bool; } @@ -26,7 +27,7 @@ pub(crate) type FnDataFactory = /// /// Application data can be accessed by using `Data` extractor where `T` is data type. /// -/// **Note**: http server accepts an application factory rather than an application instance. HTTP +/// **Note**: HTTP server accepts an application factory rather than an application instance. HTTP /// server constructs an application instance for each thread, thus application data must be /// constructed multiple times. If you want to share data between different threads, a shareable /// object should be used, e.g. `Send + Sync`. Application data does not need to be `Send` @@ -126,12 +127,8 @@ impl FromRequest for Data { impl DataFactory for Data { fn create(&self, extensions: &mut Extensions) -> bool { - if !extensions.contains::>() { - extensions.insert(Data(self.0.clone())); - true - } else { - false - } + extensions.insert(Data(self.0.clone())); + true } } @@ -147,7 +144,7 @@ mod tests { #[actix_rt::test] async fn test_data_extractor() { - let mut srv = init_service(App::new().data("TEST".to_string()).service( + let srv = init_service(App::new().data("TEST".to_string()).service( web::resource("/").to(|data: web::Data| { assert_eq!(data.to_lowercase(), "test"); HttpResponse::Ok() @@ -159,33 +156,54 @@ mod tests { let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); - let mut srv = - init_service(App::new().data(10u32).service( - web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )) - .await; + let srv = init_service( + App::new() + .data(10u32) + .service(web::resource("/").to(|_: web::Data| HttpResponse::Ok())), + ) + .await; let req = TestRequest::default().to_request(); let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let srv = init_service( + App::new() + .data(10u32) + .data(13u32) + .app_data(12u64) + .app_data(15u64) + .default_service(web::to(|n: web::Data, req: HttpRequest| { + // in each case, the latter insertion should be preserved + assert_eq!(*req.app_data::().unwrap(), 15); + assert_eq!(*n.into_inner(), 13); + HttpResponse::Ok() + })), + ) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); } #[actix_rt::test] async fn test_app_data_extractor() { - let mut srv = - init_service(App::new().app_data(Data::new(10usize)).service( - web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )) - .await; + let srv = init_service( + App::new() + .app_data(Data::new(10usize)) + .service(web::resource("/").to(|_: web::Data| HttpResponse::Ok())), + ) + .await; let req = TestRequest::default().to_request(); let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); - let mut srv = - init_service(App::new().app_data(Data::new(10u32)).service( - web::resource("/").to(|_: web::Data| HttpResponse::Ok()), - )) - .await; + let srv = init_service( + App::new() + .app_data(Data::new(10u32)) + .service(web::resource("/").to(|_: web::Data| HttpResponse::Ok())), + ) + .await; let req = TestRequest::default().to_request(); let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); @@ -193,7 +211,7 @@ mod tests { #[actix_rt::test] async fn test_route_data_extractor() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::resource("/") .data(10usize) @@ -207,7 +225,7 @@ mod tests { assert_eq!(resp.status(), StatusCode::OK); // different type - let mut srv = init_service( + let srv = init_service( App::new().service( web::resource("/") .data(10u32) @@ -222,15 +240,16 @@ mod tests { #[actix_rt::test] async fn test_override_data() { - let mut srv = init_service(App::new().data(1usize).service( - web::resource("/").data(10usize).route(web::get().to( - |data: web::Data| { - assert_eq!(**data, 10); - HttpResponse::Ok() - }, - )), - )) - .await; + let srv = + init_service(App::new().data(1usize).service( + web::resource("/").data(10usize).route(web::get().to( + |data: web::Data| { + assert_eq!(**data, 10); + HttpResponse::Ok() + }, + )), + )) + .await; let req = TestRequest::default().to_request(); let resp = srv.call(req).await.unwrap(); diff --git a/src/error.rs b/src/error.rs index 60af8fa11..0865257d3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,12 +1,11 @@ //! Error and Result module pub use actix_http::error::*; -use derive_more::{Display, From}; +use derive_more::{Display, Error, From}; use serde_json::error::Error as JsonError; use url::ParseError as UrlParseError; -use crate::http::StatusCode; -use crate::HttpResponse; +use crate::{http::StatusCode, HttpResponse}; /// Errors which can occur when attempting to generate resource uri. #[derive(Debug, PartialEq, Display, From)] @@ -14,9 +13,11 @@ pub enum UrlGenerationError { /// Resource not found #[display(fmt = "Resource not found")] ResourceNotFound, + /// Not all path pattern covered #[display(fmt = "Not all path pattern covered")] NotEnoughElements, + /// URL parse error #[display(fmt = "{}", _0)] ParseError(UrlParseError), @@ -28,34 +29,37 @@ impl std::error::Error for UrlGenerationError {} impl ResponseError for UrlGenerationError {} /// A set of errors that can occur during parsing urlencoded payloads -#[derive(Debug, Display, From)] +#[derive(Debug, Display, Error, From)] pub enum UrlencodedError { - /// Can not decode chunked transfer encoding - #[display(fmt = "Can not decode chunked transfer encoding")] + /// Can not decode chunked transfer encoding. + #[display(fmt = "Can not decode chunked transfer encoding.")] Chunked, - /// Payload size is bigger than allowed. (default: 256kB) + + /// Payload size is larger than allowed. (default limit: 256kB). #[display( - fmt = "Urlencoded payload size is bigger ({} bytes) than allowed (default: {} bytes)", + fmt = "URL encoded payload is larger ({} bytes) than allowed (limit: {} bytes).", size, limit )] Overflow { size: usize, limit: usize }, - /// Payload size is now known - #[display(fmt = "Payload size is now known")] + + /// Payload size is now known. + #[display(fmt = "Payload size is now known.")] UnknownLength, - /// Content type error - #[display(fmt = "Content type error")] + + /// Content type error. + #[display(fmt = "Content type error.")] ContentType, - /// Parse error - #[display(fmt = "Parse error")] + + /// Parse error. + #[display(fmt = "Parse error.")] Parse, - /// Payload error - #[display(fmt = "Error that occur during reading payload: {}", _0)] + + /// Payload error. + #[display(fmt = "Error that occur during reading payload: {}.", _0)] Payload(PayloadError), } -impl std::error::Error for UrlencodedError {} - /// Return `BadRequest` for `UrlencodedError` impl ResponseError for UrlencodedError { fn status_code(&self) -> StatusCode { @@ -90,9 +94,7 @@ impl std::error::Error for JsonPayloadError {} impl ResponseError for JsonPayloadError { fn error_response(&self) -> HttpResponse { match *self { - JsonPayloadError::Overflow => { - HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE) - } + JsonPayloadError::Overflow => HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE), _ => HttpResponse::new(StatusCode::BAD_REQUEST), } } @@ -115,16 +117,14 @@ impl ResponseError for PathError { } } -/// A set of errors that can occur during parsing query strings -#[derive(Debug, Display, From)] +/// A set of errors that can occur during parsing query strings. +#[derive(Debug, Display, Error, From)] pub enum QueryPayloadError { - /// Deserialize error + /// Query deserialize error. #[display(fmt = "Query deserialize error: {}", _0)] Deserialize(serde::de::value::Error), } -impl std::error::Error for QueryPayloadError {} - /// Return `BadRequest` for `QueryPayloadError` impl ResponseError for QueryPayloadError { fn status_code(&self) -> StatusCode { diff --git a/src/extract.rs b/src/extract.rs index 5916b1bc5..7a677bca4 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -1,34 +1,37 @@ //! Request extractors -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -use actix_http::error::Error; -use futures_util::future::{ready, Ready}; -use futures_util::ready; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; -use crate::dev::Payload; -use crate::request::HttpRequest; +use futures_util::{ + future::{ready, Ready}, + ready, +}; + +use crate::{dev::Payload, Error, HttpRequest}; /// Trait implemented by types that can be extracted from request. /// /// Types that implement this trait can be used with `Route` handlers. pub trait FromRequest: Sized { + /// Configuration for this extractor. + type Config: Default + 'static; + /// The associated error which can be returned. type Error: Into; - /// Future that resolves to a Self + /// Future that resolves to a Self. type Future: Future>; - /// Configuration for this extractor - type Config: Default + 'static; - - /// Convert request to a Self + /// Create a Self from request parts asynchronously. fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future; - /// Convert request to a Self + /// Create a Self from request head asynchronously. /// - /// This method uses `Payload::None` as payload stream. + /// This method is short for `T::from_request(req, &mut Payload::None)`. fn extract(req: &HttpRequest) -> Self::Future { Self::from_request(req, &mut Payload::None) } @@ -344,25 +347,21 @@ mod tests { #[actix_rt::test] async fn test_option() { - let (req, mut pl) = TestRequest::with_header( - header::CONTENT_TYPE, - "application/x-www-form-urlencoded", - ) - .data(FormConfig::default().limit(4096)) - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header((header::CONTENT_TYPE, "application/x-www-form-urlencoded")) + .data(FormConfig::default().limit(4096)) + .to_http_parts(); let r = Option::>::from_request(&req, &mut pl) .await .unwrap(); assert_eq!(r, None); - let (req, mut pl) = TestRequest::with_header( - header::CONTENT_TYPE, - "application/x-www-form-urlencoded", - ) - .header(header::CONTENT_LENGTH, "9") - .set_payload(Bytes::from_static(b"hello=world")) - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header((header::CONTENT_TYPE, "application/x-www-form-urlencoded")) + .insert_header((header::CONTENT_LENGTH, "9")) + .set_payload(Bytes::from_static(b"hello=world")) + .to_http_parts(); let r = Option::>::from_request(&req, &mut pl) .await @@ -374,13 +373,11 @@ mod tests { })) ); - let (req, mut pl) = TestRequest::with_header( - header::CONTENT_TYPE, - "application/x-www-form-urlencoded", - ) - .header(header::CONTENT_LENGTH, "9") - .set_payload(Bytes::from_static(b"bye=world")) - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header((header::CONTENT_TYPE, "application/x-www-form-urlencoded")) + .insert_header((header::CONTENT_LENGTH, "9")) + .set_payload(Bytes::from_static(b"bye=world")) + .to_http_parts(); let r = Option::>::from_request(&req, &mut pl) .await @@ -390,13 +387,11 @@ mod tests { #[actix_rt::test] async fn test_result() { - let (req, mut pl) = TestRequest::with_header( - header::CONTENT_TYPE, - "application/x-www-form-urlencoded", - ) - .header(header::CONTENT_LENGTH, "11") - .set_payload(Bytes::from_static(b"hello=world")) - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header((header::CONTENT_TYPE, "application/x-www-form-urlencoded")) + .insert_header((header::CONTENT_LENGTH, "11")) + .set_payload(Bytes::from_static(b"hello=world")) + .to_http_parts(); let r = Result::, Error>::from_request(&req, &mut pl) .await @@ -409,13 +404,11 @@ mod tests { }) ); - let (req, mut pl) = TestRequest::with_header( - header::CONTENT_TYPE, - "application/x-www-form-urlencoded", - ) - .header(header::CONTENT_LENGTH, "9") - .set_payload(Bytes::from_static(b"bye=world")) - .to_http_parts(); + let (req, mut pl) = TestRequest::default() + .insert_header((header::CONTENT_TYPE, "application/x-www-form-urlencoded")) + .insert_header((header::CONTENT_LENGTH, 9)) + .set_payload(Bytes::from_static(b"bye=world")) + .to_http_parts(); let r = Result::, Error>::from_request(&req, &mut pl) .await diff --git a/src/guard.rs b/src/guard.rs index 25284236a..5d0de58c2 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -176,7 +176,7 @@ impl Guard for NotGuard { } } -/// Http method guard +/// HTTP method guard. #[doc(hidden)] pub struct MethodGuard(http::Method); @@ -186,52 +186,52 @@ impl Guard for MethodGuard { } } -/// Guard to match *GET* http method +/// Guard to match *GET* HTTP method. pub fn Get() -> MethodGuard { MethodGuard(http::Method::GET) } -/// Predicate to match *POST* http method +/// Predicate to match *POST* HTTP method. pub fn Post() -> MethodGuard { MethodGuard(http::Method::POST) } -/// Predicate to match *PUT* http method +/// Predicate to match *PUT* HTTP method. pub fn Put() -> MethodGuard { MethodGuard(http::Method::PUT) } -/// Predicate to match *DELETE* http method +/// Predicate to match *DELETE* HTTP method. pub fn Delete() -> MethodGuard { MethodGuard(http::Method::DELETE) } -/// Predicate to match *HEAD* http method +/// Predicate to match *HEAD* HTTP method. pub fn Head() -> MethodGuard { MethodGuard(http::Method::HEAD) } -/// Predicate to match *OPTIONS* http method +/// Predicate to match *OPTIONS* HTTP method. pub fn Options() -> MethodGuard { MethodGuard(http::Method::OPTIONS) } -/// Predicate to match *CONNECT* http method +/// Predicate to match *CONNECT* HTTP method. pub fn Connect() -> MethodGuard { MethodGuard(http::Method::CONNECT) } -/// Predicate to match *PATCH* http method +/// Predicate to match *PATCH* HTTP method. pub fn Patch() -> MethodGuard { MethodGuard(http::Method::PATCH) } -/// Predicate to match *TRACE* http method +/// Predicate to match *TRACE* HTTP method. pub fn Trace() -> MethodGuard { MethodGuard(http::Method::TRACE) } -/// Predicate to match specified http method +/// Predicate to match specified HTTP method. pub fn Method(method: http::Method) -> MethodGuard { MethodGuard(method) } @@ -330,7 +330,8 @@ mod tests { #[test] fn test_header() { - let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked") + let req = TestRequest::default() + .insert_header((header::TRANSFER_ENCODING, "chunked")) .to_http_request(); let pred = Header("transfer-encoding", "chunked"); @@ -346,10 +347,10 @@ mod tests { #[test] fn test_host() { let req = TestRequest::default() - .header( + .insert_header(( header::HOST, header::HeaderValue::from_static("www.rust-lang.org"), - ) + )) .to_http_request(); let pred = Host("www.rust-lang.org"); @@ -374,10 +375,10 @@ mod tests { #[test] fn test_host_scheme() { let req = TestRequest::default() - .header( + .insert_header(( header::HOST, header::HeaderValue::from_static("https://www.rust-lang.org"), - ) + )) .to_http_request(); let pred = Host("www.rust-lang.org").scheme("https"); diff --git a/src/handler.rs b/src/handler.rs index d4b755e57..0016b741e 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -49,7 +49,7 @@ where R::Output: Responder, { hnd: F, - _t: PhantomData<(T, R)>, + _phantom: PhantomData<(T, R)>, } impl HandlerService @@ -62,7 +62,7 @@ where pub fn new(hnd: F) -> Self { Self { hnd, - _t: PhantomData, + _phantom: PhantomData, } } } @@ -77,19 +77,18 @@ where fn clone(&self) -> Self { Self { hnd: self.hnd.clone(), - _t: PhantomData, + _phantom: PhantomData, } } } -impl ServiceFactory for HandlerService +impl ServiceFactory for HandlerService where F: Handler, T: FromRequest, R: Future, R::Output: Responder, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; type Config = (); @@ -102,24 +101,23 @@ where } } -// Handler is both it's ServiceHandler and Service Type. -impl Service for HandlerService +/// HandlerService is both it's ServiceFactory and Service Type. +impl Service for HandlerService where F: Handler, T: FromRequest, R: Future, R::Output: Responder, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; type Future = HandlerServiceFuture; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + fn poll_ready(&self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, req: Self::Request) -> Self::Future { + fn call(&self, req: ServiceRequest) -> Self::Future { let (req, mut payload) = req.into_parts(); let fut = T::from_request(&req, &mut payload); HandlerServiceFuture::Extract(fut, Some(req), self.hnd.clone()) @@ -137,7 +135,6 @@ where { Extract(#[pin] T::Future, Option, F), Handle(#[pin] R, Option), - Respond(#[pin] ::Future, Option), } impl Future for HandlerServiceFuture @@ -170,13 +167,8 @@ where } HandlerProj::Handle(fut, req) => { let res = ready!(fut.poll(cx)); - let fut = res.respond_to(req.as_ref().unwrap()); - let state = HandlerServiceFuture::Respond(fut, req.take()); - self.as_mut().set(state); - } - HandlerProj::Respond(fut, req) => { - let res = ready!(fut.poll(cx)).unwrap_or_else(|e| e.into().into()); let req = req.take().unwrap(); + let res = res.respond_to(&req); return Poll::Ready(Ok(ServiceResponse::new(req, res))); } } diff --git a/src/info.rs b/src/info.rs index 975604041..c9ddf6ec4 100644 --- a/src/info.rs +++ b/src/info.rs @@ -55,7 +55,7 @@ impl ConnectionInfo { host = Some(val.trim()); } } - _ => (), + _ => {} } } } @@ -200,10 +200,10 @@ mod tests { assert_eq!(info.host(), "localhost:8080"); let req = TestRequest::default() - .header( + .insert_header(( header::FORWARDED, "for=192.0.2.60; proto=https; by=203.0.113.43; host=rust-lang.org", - ) + )) .to_http_request(); let info = req.connection_info(); @@ -212,7 +212,7 @@ mod tests { assert_eq!(info.realip_remote_addr(), Some("192.0.2.60")); let req = TestRequest::default() - .header(header::HOST, "rust-lang.org") + .insert_header((header::HOST, "rust-lang.org")) .to_http_request(); let info = req.connection_info(); @@ -221,20 +221,20 @@ mod tests { assert_eq!(info.realip_remote_addr(), None); let req = TestRequest::default() - .header(X_FORWARDED_FOR, "192.0.2.60") + .insert_header((X_FORWARDED_FOR, "192.0.2.60")) .to_http_request(); let info = req.connection_info(); assert_eq!(info.realip_remote_addr(), Some("192.0.2.60")); let req = TestRequest::default() - .header(X_FORWARDED_HOST, "192.0.2.60") + .insert_header((X_FORWARDED_HOST, "192.0.2.60")) .to_http_request(); let info = req.connection_info(); assert_eq!(info.host(), "192.0.2.60"); assert_eq!(info.realip_remote_addr(), None); let req = TestRequest::default() - .header(X_FORWARDED_PROTO, "https") + .insert_header((X_FORWARDED_PROTO, "https")) .to_http_request(); let info = req.connection_info(); assert_eq!(info.scheme(), "https"); diff --git a/src/lib.rs b/src/lib.rs index f2233c220..4beb54f98 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,8 @@ //! use actix_web::{get, web, App, HttpServer, Responder}; //! //! #[get("/{id}/{name}/index.html")] -//! async fn index(web::Path((id, name)): web::Path<(u32, String)>) -> impl Responder { +//! async fn index(path: web::Path<(u32, String)>) -> impl Responder { +//! let (id, name) = path.into_inner(); //! format!("Hello {}! id:{}", name, id) //! } //! @@ -29,7 +30,7 @@ //! //! To get started navigating the API docs, you may consider looking at the following pages first: //! -//! * [App]: This struct represents an Actix web application and is used to +//! * [App]: This struct represents an Actix Web application and is used to //! configure routes and other common application settings. //! //! * [HttpServer]: This struct represents an HTTP server instance and is @@ -55,12 +56,12 @@ //! * SSL support using OpenSSL or Rustls //! * Middlewares ([Logger, Session, CORS, etc](https://actix.rs/docs/middleware/)) //! * Includes an async [HTTP client](https://actix.rs/actix-web/actix_web/client/index.html) -//! * Supports [Actix actor framework](https://github.com/actix/actix) //! * Runs on stable Rust 1.46+ //! //! ## Crate Features //! //! * `compress` - content encoding compression support (enabled by default) +//! * `cookies` - cookies support (enabled by default) //! * `openssl` - HTTPS support via `openssl` crate, supports `HTTP/2` //! * `rustls` - HTTPS support via `rustls` crate, supports `HTTP/2` //! * `secure-cookies` - secure cookies support @@ -71,6 +72,11 @@ #![doc(html_favicon_url = "https://actix.rs/favicon.ico")] #![clippy::msrv = "1.46"] +#[cfg(feature = "openssl")] +extern crate tls_openssl as openssl; +#[cfg(feature = "rustls")] +extern crate tls_rustls as rustls; + mod app; mod app_service; mod config; @@ -91,11 +97,13 @@ mod scope; mod server; mod service; pub mod test; -mod types; +pub(crate) mod types; pub mod web; +#[cfg(feature = "cookies")] +pub use actix_http::cookie; pub use actix_http::Response as HttpResponse; -pub use actix_http::{body, cookie, http, Error, HttpMessage, ResponseError, Result}; +pub use actix_http::{body, http, Error, HttpMessage, ResponseError, Result}; pub use actix_rt as rt; pub use actix_web_codegen::*; @@ -107,6 +115,7 @@ pub use crate::responder::Responder; pub use crate::route::Route; pub use crate::scope::Scope; pub use crate::server::HttpServer; +// TODO: is exposing the error directly really needed pub use crate::types::{Either, EitherExtractError}; pub mod dev { @@ -125,9 +134,7 @@ pub mod dev { pub use crate::handler::Handler; pub use crate::info::ConnectionInfo; pub use crate::rmap::ResourceMap; - pub use crate::service::{ - HttpServiceFactory, ServiceRequest, ServiceResponse, WebService, - }; + pub use crate::service::{HttpServiceFactory, ServiceRequest, ServiceResponse, WebService}; pub use crate::types::form::UrlEncoded; pub use crate::types::json::JsonBody; @@ -137,9 +144,7 @@ pub mod dev { #[cfg(feature = "compress")] pub use actix_http::encoding::Decoder as Decompress; pub use actix_http::ResponseBuilder as HttpResponseBuilder; - pub use actix_http::{ - Extensions, Payload, PayloadStream, RequestHead, ResponseHead, - }; + pub use actix_http::{Extensions, Payload, PayloadStream, RequestHead, ResponseHead}; pub use actix_router::{Path, ResourceDef, ResourcePath, Url}; pub use actix_server::Server; pub use actix_service::{Service, Transform}; @@ -197,29 +202,3 @@ pub mod dev { } } } - -pub mod client { - //! Actix web async HTTP client. - //! - //! ```rust - //! use actix_web::client::Client; - //! - //! #[actix_web::main] - //! async fn main() { - //! let mut client = Client::default(); - //! - //! // Create request builder and send request - //! let response = client.get("http://www.rust-lang.org") - //! .header("User-Agent", "actix-web/3.0") - //! .send() // <- Send request - //! .await; // <- Wait for response - //! - //! println!("Response: {:?}", response); - //! } - //! ``` - - pub use awc::error::*; - pub use awc::{ - test, Client, ClientBuilder, ClientRequest, ClientResponse, Connector, - }; -} diff --git a/src/middleware/compat.rs b/src/middleware/compat.rs new file mode 100644 index 000000000..6f60264b1 --- /dev/null +++ b/src/middleware/compat.rs @@ -0,0 +1,202 @@ +//! For middleware documentation, see [`Compat`]. + +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_http::body::{Body, MessageBody, ResponseBody}; +use actix_service::{Service, Transform}; +use futures_core::{future::LocalBoxFuture, ready}; + +use crate::{error::Error, service::ServiceResponse}; + +/// Middleware for enabling any middleware to be used in [`Resource::wrap`](crate::Resource::wrap), +/// [`Scope::wrap`](crate::Scope::wrap) and [`Condition`](super::Condition). +/// +/// # Examples +/// ```rust +/// use actix_web::middleware::{Logger, Compat}; +/// use actix_web::{App, web}; +/// +/// let logger = Logger::default(); +/// +/// // this would not compile because of incompatible body types +/// // let app = App::new() +/// // .service(web::scope("scoped").wrap(logger)); +/// +/// // by using this middleware we can use the logger on a scope +/// let app = App::new() +/// .service(web::scope("scoped").wrap(Compat::new(logger))); +/// ``` +pub struct Compat { + transform: T, +} + +impl Compat { + /// Wrap a middleware to give it broader compatibility. + pub fn new(middleware: T) -> Self { + Self { + transform: middleware, + } + } +} + +impl Transform for Compat +where + S: Service, + T: Transform, + T::Future: 'static, + T::Response: MapServiceResponseBody, + Error: From, +{ + type Response = ServiceResponse; + type Error = Error; + type Transform = CompatMiddleware; + type InitError = T::InitError; + type Future = LocalBoxFuture<'static, Result>; + + fn new_transform(&self, service: S) -> Self::Future { + let fut = self.transform.new_transform(service); + Box::pin(async move { + let service = fut.await?; + Ok(CompatMiddleware { service }) + }) + } +} + +pub struct CompatMiddleware { + service: S, +} + +impl Service for CompatMiddleware +where + S: Service, + S::Response: MapServiceResponseBody, + Error: From, +{ + type Response = ServiceResponse; + type Error = Error; + type Future = CompatMiddlewareFuture; + + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx).map_err(From::from) + } + + fn call(&self, req: Req) -> Self::Future { + let fut = self.service.call(req); + CompatMiddlewareFuture { fut } + } +} + +#[pin_project::pin_project] +pub struct CompatMiddlewareFuture { + #[pin] + fut: Fut, +} + +impl Future for CompatMiddlewareFuture +where + Fut: Future>, + T: MapServiceResponseBody, + Error: From, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let res = ready!(self.project().fut.poll(cx))?; + Poll::Ready(Ok(res.map_body())) + } +} + +/// Convert `ServiceResponse`'s `ResponseBody` generic type to `ResponseBody`. +pub trait MapServiceResponseBody { + fn map_body(self) -> ServiceResponse; +} + +impl MapServiceResponseBody for ServiceResponse { + fn map_body(self) -> ServiceResponse { + self.map_body(|_, body| ResponseBody::Other(Body::from_message(body))) + } +} + +#[cfg(test)] +mod tests { + // easier to code when cookies feature is disabled + #![allow(unused_imports)] + + use super::*; + + use actix_service::IntoService; + + use crate::dev::ServiceRequest; + use crate::http::StatusCode; + use crate::middleware::{self, Condition, Logger}; + use crate::test::{call_service, init_service, TestRequest}; + use crate::{web, App, HttpResponse}; + + #[actix_rt::test] + #[cfg(feature = "cookies")] + async fn test_scope_middleware() { + use crate::middleware::Compress; + + let logger = Logger::default(); + let compress = Compress::default(); + + let srv = init_service( + App::new().service( + web::scope("app") + .wrap(Compat::new(logger)) + .wrap(Compat::new(compress)) + .service(web::resource("/test").route(web::get().to(HttpResponse::Ok))), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/test").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + #[cfg(feature = "cookies")] + async fn test_resource_scope_middleware() { + use crate::middleware::Compress; + + let logger = Logger::default(); + let compress = Compress::default(); + + let srv = init_service( + App::new().service( + web::resource("app/test") + .wrap(Compat::new(logger)) + .wrap(Compat::new(compress)) + .route(web::get().to(HttpResponse::Ok)), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/test").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_condition_scope_middleware() { + let srv = |req: ServiceRequest| { + Box::pin(async move { + Ok(req.into_response(HttpResponse::InternalServerError().finish())) + }) + }; + + let logger = Logger::default(); + + let mw = Condition::new(true, Compat::new(logger)) + .new_transform(srv.into_service()) + .await + .unwrap(); + let resp = call_service(&mw, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } +} diff --git a/src/middleware/compress.rs b/src/middleware/compress.rs index 7575d7455..698ba768e 100644 --- a/src/middleware/compress.rs +++ b/src/middleware/compress.rs @@ -1,45 +1,48 @@ -//! `Middleware` for compressing response body. -use std::cmp; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::str::FromStr; -use std::task::{Context, Poll}; +//! For middleware documentation, see [`Compress`]. -use actix_http::body::MessageBody; -use actix_http::encoding::Encoder; -use actix_http::http::header::{ContentEncoding, ACCEPT_ENCODING}; -use actix_http::Error; +use std::{ + cmp, + future::Future, + marker::PhantomData, + pin::Pin, + str::FromStr, + task::{Context, Poll}, +}; + +use actix_http::{ + body::MessageBody, + encoding::Encoder, + http::header::{ContentEncoding, ACCEPT_ENCODING}, + Error, +}; use actix_service::{Service, Transform}; +use futures_core::ready; use futures_util::future::{ok, Ready}; use pin_project::pin_project; -use crate::dev::BodyEncoding; -use crate::service::{ServiceRequest, ServiceResponse}; +use crate::{ + dev::BodyEncoding, + service::{ServiceRequest, ServiceResponse}, +}; -#[derive(Debug, Clone)] -/// `Middleware` for compressing response body. +/// Middleware for compressing response payloads. /// -/// Use `BodyEncoding` trait for overriding response compression. -/// To disable compression set encoding to `ContentEncoding::Identity` value. +/// Use `BodyEncoding` trait for overriding response compression. To disable compression set +/// encoding to `ContentEncoding::Identity`. /// +/// # Examples /// ```rust /// use actix_web::{web, middleware, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new() -/// .wrap(middleware::Compress::default()) -/// .service( -/// web::resource("/test") -/// .route(web::get().to(|| HttpResponse::Ok())) -/// .route(web::head().to(|| HttpResponse::MethodNotAllowed())) -/// ); -/// } +/// let app = App::new() +/// .wrap(middleware::Compress::default()) +/// .default_service(web::to(|| HttpResponse::NotFound())); /// ``` +#[derive(Debug, Clone)] pub struct Compress(ContentEncoding); impl Compress { - /// Create new `Compress` middleware with default encoding. + /// Create new `Compress` middleware with the specified encoding. pub fn new(encoding: ContentEncoding) -> Self { Compress(encoding) } @@ -51,16 +54,15 @@ impl Default for Compress { } } -impl Transform for Compress +impl Transform for Compress where B: MessageBody, - S: Service, Error = Error>, + S: Service, Error = Error>, { - type Request = ServiceRequest; type Response = ServiceResponse>; type Error = Error; - type InitError = (); type Transform = CompressMiddleware; + type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { @@ -76,22 +78,19 @@ pub struct CompressMiddleware { encoding: ContentEncoding, } -impl Service for CompressMiddleware +impl Service for CompressMiddleware where B: MessageBody, - S: Service, Error = Error>, + S: Service, Error = Error>, { - type Request = ServiceRequest; type Response = ServiceResponse>; type Error = Error; type Future = CompressResponse; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx) - } + actix_service::forward_ready!(service); #[allow(clippy::borrow_interior_mutable_const)] - fn call(&mut self, req: ServiceRequest) -> Self::Future { + fn call(&self, req: ServiceRequest) -> Self::Future { // negotiate content-encoding let encoding = if let Some(val) = req.headers().get(&ACCEPT_ENCODING) { if let Ok(enc) = val.to_str() { @@ -106,35 +105,34 @@ where CompressResponse { encoding, fut: self.service.call(req), - _t: PhantomData, + _phantom: PhantomData, } } } -#[doc(hidden)] #[pin_project] pub struct CompressResponse where - S: Service, + S: Service, B: MessageBody, { #[pin] fut: S::Future, encoding: ContentEncoding, - _t: PhantomData, + _phantom: PhantomData, } impl Future for CompressResponse where B: MessageBody, - S: Service, Error = Error>, + S: Service, Error = Error>, { type Output = Result>, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - match futures_util::ready!(this.fut.poll(cx)) { + match ready!(this.fut.poll(cx)) { Ok(resp) => { let enc = if let Some(enc) = resp.response().get_encoding() { enc diff --git a/src/middleware/condition.rs b/src/middleware/condition.rs index 9061c7458..f0c344062 100644 --- a/src/middleware/condition.rs +++ b/src/middleware/condition.rs @@ -1,63 +1,64 @@ -//! `Middleware` for conditionally enables another middleware. +//! For middleware documentation, see [`Condition`]. + use std::task::{Context, Poll}; use actix_service::{Service, Transform}; -use futures_util::future::{ok, Either, FutureExt, LocalBoxFuture}; +use futures_util::future::{Either, FutureExt, LocalBoxFuture}; -/// `Middleware` for conditionally enables another middleware. -/// The controlled middleware must not change the `Service` interfaces. -/// This means you cannot control such middlewares like `Logger` or `Compress`. +/// Middleware for conditionally enabling other middleware. /// -/// ## Usage +/// The controlled middleware must not change the `Service` interfaces. This means you cannot +/// control such middlewares like `Logger` or `Compress` directly. See the [`Compat`](super::Compat) +/// middleware for a workaround. /// +/// # Examples /// ```rust /// use actix_web::middleware::{Condition, NormalizePath}; /// use actix_web::App; /// -/// # fn main() { -/// let enable_normalize = std::env::var("NORMALIZE_PATH") == Ok("true".into()); +/// let enable_normalize = std::env::var("NORMALIZE_PATH").is_ok(); /// let app = App::new() /// .wrap(Condition::new(enable_normalize, NormalizePath::default())); -/// # } /// ``` pub struct Condition { - trans: T, + transformer: T, enable: bool, } impl Condition { - pub fn new(enable: bool, trans: T) -> Self { - Self { trans, enable } + pub fn new(enable: bool, transformer: T) -> Self { + Self { + transformer, + enable, + } } } -impl Transform for Condition +impl Transform for Condition where - S: Service + 'static, - T: Transform, + S: Service + 'static, + T: Transform, T::Future: 'static, T::InitError: 'static, T::Transform: 'static, { - type Request = S::Request; type Response = S::Response; type Error = S::Error; - type InitError = T::InitError; type Transform = ConditionMiddleware; + type InitError = T::InitError; type Future = LocalBoxFuture<'static, Result>; fn new_transform(&self, service: S) -> Self::Future { if self.enable { - let f = self.trans.new_transform(service).map(|res| { - res.map( - ConditionMiddleware::Enable as fn(T::Transform) -> Self::Transform, - ) - }); - Either::Left(f) + let fut = self.transformer.new_transform(service); + async move { + let wrapped_svc = fut.await?; + Ok(ConditionMiddleware::Enable(wrapped_svc)) + } + .boxed_local() } else { - Either::Right(ok(ConditionMiddleware::Disable(service))) + async move { Ok(ConditionMiddleware::Disable(service)) }.boxed_local() } - .boxed_local() } } @@ -66,29 +67,26 @@ pub enum ConditionMiddleware { Disable(D), } -impl Service for ConditionMiddleware +impl Service for ConditionMiddleware where - E: Service, - D: Service, + E: Service, + D: Service, { - type Request = E::Request; type Response = E::Response; type Error = E::Error; type Future = Either; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - use ConditionMiddleware::*; + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { match self { - Enable(service) => service.poll_ready(cx), - Disable(service) => service.poll_ready(cx), + ConditionMiddleware::Enable(service) => service.poll_ready(cx), + ConditionMiddleware::Disable(service) => service.poll_ready(cx), } } - fn call(&mut self, req: E::Request) -> Self::Future { - use ConditionMiddleware::*; + fn call(&self, req: Req) -> Self::Future { match self { - Enable(service) => Either::Left(service.call(req)), - Disable(service) => Either::Right(service.call(req)), + ConditionMiddleware::Enable(service) => Either::Left(service.call(req)), + ConditionMiddleware::Disable(service) => Either::Right(service.call(req)), } } } @@ -96,16 +94,18 @@ where #[cfg(test)] mod tests { use actix_service::IntoService; + use futures_util::future::ok; use super::*; - use crate::dev::{ServiceRequest, ServiceResponse}; - use crate::error::Result; - use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; - use crate::middleware::errhandlers::*; - use crate::test::{self, TestRequest}; - use crate::HttpResponse; + use crate::{ + dev::{ServiceRequest, ServiceResponse}, + error::Result, + http::{header::CONTENT_TYPE, HeaderValue, StatusCode}, + middleware::err_handlers::*, + test::{self, TestRequest}, + HttpResponse, + }; - #[allow(clippy::unnecessary_wraps)] fn render_500(mut res: ServiceResponse) -> Result> { res.response_mut() .headers_mut() @@ -119,15 +119,13 @@ mod tests { ok(req.into_response(HttpResponse::InternalServerError().finish())) }; - let mw = - ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); - let mut mw = Condition::new(true, mw) + let mw = Condition::new(true, mw) .new_transform(srv.into_service()) .await .unwrap(); - let resp = - test::call_service(&mut mw, TestRequest::default().to_srv_request()).await; + let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } @@ -137,16 +135,14 @@ mod tests { ok(req.into_response(HttpResponse::InternalServerError().finish())) }; - let mw = - ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); - let mut mw = Condition::new(false, mw) + let mw = Condition::new(false, mw) .new_transform(srv.into_service()) .await .unwrap(); - let resp = - test::call_service(&mut mw, TestRequest::default().to_srv_request()).await; + let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE), None); } } diff --git a/src/middleware/defaultheaders.rs b/src/middleware/default_headers.rs similarity index 67% rename from src/middleware/defaultheaders.rs rename to src/middleware/default_headers.rs index a6f1a4336..a36cc2f29 100644 --- a/src/middleware/defaultheaders.rs +++ b/src/middleware/default_headers.rs @@ -1,24 +1,34 @@ -//! Middleware for setting default response headers -use std::convert::TryFrom; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll}; +//! For middleware documentation, see [`DefaultHeaders`]. -use actix_service::{Service, Transform}; -use futures_util::future::{ready, Ready}; -use futures_util::ready; +use std::{ + convert::TryFrom, + future::Future, + marker::PhantomData, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; -use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; -use crate::http::{Error as HttpError, HeaderMap}; -use crate::service::{ServiceRequest, ServiceResponse}; -use crate::Error; +use futures_util::{ + future::{ready, Ready}, + ready, +}; -/// `Middleware` for setting default response headers. +use crate::{ + dev::{Service, Transform}, + http::{ + header::{HeaderName, HeaderValue, CONTENT_TYPE}, + Error as HttpError, HeaderMap, + }, + service::{ServiceRequest, ServiceResponse}, + Error, +}; + +/// Middleware for setting default response headers. /// -/// This middleware does not set header if response headers already contains it. +/// Headers with the same key that are already set in a response will *not* be overwritten. /// +/// # Examples /// ```rust /// use actix_web::{web, http, middleware, App, HttpResponse}; /// @@ -38,7 +48,6 @@ pub struct DefaultHeaders { } struct Inner { - ct: bool, headers: HeaderMap, } @@ -46,7 +55,6 @@ impl Default for DefaultHeaders { fn default() -> Self { DefaultHeaders { inner: Rc::new(Inner { - ct: false, headers: HeaderMap::new(), }), } @@ -54,12 +62,12 @@ impl Default for DefaultHeaders { } impl DefaultHeaders { - /// Construct `DefaultHeaders` middleware. + /// Constructs an empty `DefaultHeaders` middleware. pub fn new() -> DefaultHeaders { DefaultHeaders::default() } - /// Set a header. + /// Adds a header to the default set. #[inline] pub fn header(mut self, key: K, value: V) -> Self where @@ -84,21 +92,27 @@ impl DefaultHeaders { self } - /// Set *CONTENT-TYPE* header if response does not contain this header. - pub fn content_type(mut self) -> Self { + /// Adds a default *Content-Type* header if response does not contain one. + /// + /// Default is `application/octet-stream`. + pub fn add_content_type(mut self) -> Self { Rc::get_mut(&mut self.inner) - .expect("Multiple copies exist") - .ct = true; + .expect("Multiple `Inner` copies exist.") + .headers + .insert( + CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); + self } } -impl Transform for DefaultHeaders +impl Transform for DefaultHeaders where - S: Service, Error = Error>, + S: Service, Error = Error>, S::Future: 'static, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; type Transform = DefaultHeadersMiddleware; @@ -118,21 +132,18 @@ pub struct DefaultHeadersMiddleware { inner: Rc, } -impl Service for DefaultHeadersMiddleware +impl Service for DefaultHeadersMiddleware where - S: Service, Error = Error>, + S: Service, Error = Error>, S::Future: 'static, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; type Future = DefaultHeaderFuture; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx) - } + actix_service::forward_ready!(service); - fn call(&mut self, req: ServiceRequest) -> Self::Future { + fn call(&self, req: ServiceRequest) -> Self::Future { let inner = self.inner.clone(); let fut = self.service.call(req); @@ -145,7 +156,7 @@ where } #[pin_project::pin_project] -pub struct DefaultHeaderFuture { +pub struct DefaultHeaderFuture, B> { #[pin] fut: S::Future, inner: Rc, @@ -154,7 +165,7 @@ pub struct DefaultHeaderFuture { impl Future for DefaultHeaderFuture where - S: Service, Error = Error>, + S: Service, Error = Error>, { type Output = ::Output; @@ -162,19 +173,14 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let mut res = ready!(this.fut.poll(cx))?; + // set response headers for (key, value) in this.inner.headers.iter() { if !res.headers().contains_key(key) { res.headers_mut().insert(key.clone(), value.clone()); } } - // default content-type - if this.inner.ct && !res.headers().contains_key(&CONTENT_TYPE) { - res.headers_mut().insert( - CONTENT_TYPE, - HeaderValue::from_static("application/octet-stream"), - ); - } + Poll::Ready(Ok(res)) } } @@ -185,14 +191,16 @@ mod tests { use futures_util::future::ok; use super::*; - use crate::dev::ServiceRequest; - use crate::http::header::CONTENT_TYPE; - use crate::test::{ok_service, TestRequest}; - use crate::HttpResponse; + use crate::{ + dev::ServiceRequest, + http::header::CONTENT_TYPE, + test::{ok_service, TestRequest}, + HttpResponse, + }; #[actix_rt::test] async fn test_default_headers() { - let mut mw = DefaultHeaders::new() + let mw = DefaultHeaders::new() .header(CONTENT_TYPE, "0001") .new_transform(ok_service()) .await @@ -204,10 +212,13 @@ mod tests { let req = TestRequest::default().to_srv_request(); let srv = |req: ServiceRequest| { - ok(req - .into_response(HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish())) + ok(req.into_response( + HttpResponse::Ok() + .insert_header((CONTENT_TYPE, "0002")) + .finish(), + )) }; - let mut mw = DefaultHeaders::new() + let mw = DefaultHeaders::new() .header(CONTENT_TYPE, "0001") .new_transform(srv.into_service()) .await @@ -218,10 +229,9 @@ mod tests { #[actix_rt::test] async fn test_content_type() { - let srv = - |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())); - let mut mw = DefaultHeaders::new() - .content_type() + let srv = |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())); + let mw = DefaultHeaders::new() + .add_content_type() .new_transform(srv.into_service()) .await .unwrap(); diff --git a/src/middleware/errhandlers.rs b/src/middleware/err_handlers.rs similarity index 53% rename from src/middleware/errhandlers.rs rename to src/middleware/err_handlers.rs index d2d3b0d8c..4673ed4ce 100644 --- a/src/middleware/errhandlers.rs +++ b/src/middleware/err_handlers.rs @@ -1,35 +1,41 @@ -//! Custom handlers service for responses. -use std::rc::Rc; -use std::task::{Context, Poll}; +//! For middleware documentation, see [`ErrorHandlers`]. + +use std::{ + future::Future, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; use actix_service::{Service, Transform}; -use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready}; -use fxhash::FxHashMap; +use ahash::AHashMap; +use futures_core::{future::LocalBoxFuture, ready}; -use crate::dev::{ServiceRequest, ServiceResponse}; -use crate::error::{Error, Result}; -use crate::http::StatusCode; +use crate::{ + dev::{ServiceRequest, ServiceResponse}, + error::{Error, Result}, + http::StatusCode, +}; -/// Error handler response +/// Return type for [`ErrorHandlers`] custom handlers. pub enum ErrorHandlerResponse { - /// New http response got generated + /// Immediate HTTP response. Response(ServiceResponse), - /// Result is a future that resolves to a new http response + + /// A future that resolves to an HTTP response. Future(LocalBoxFuture<'static, Result, Error>>), } type ErrorHandler = dyn Fn(ServiceResponse) -> Result>; -/// `Middleware` for allowing custom handlers for responses. +/// Middleware for registering custom status code based error handlers. /// -/// You can use `ErrorHandlers::handler()` method to register a custom error -/// handler for specific status code. You can modify existing response or -/// create completely new one. -/// -/// ## Example +/// Register handlers with the `ErrorHandlers::handler()` method to register a custom error handler +/// for a given status code. Handlers can modify existing responses or create completely new ones. /// +/// # Examples /// ```rust -/// use actix_web::middleware::errhandlers::{ErrorHandlers, ErrorHandlerResponse}; +/// use actix_web::middleware::{ErrorHandlers, ErrorHandlerResponse}; /// use actix_web::{web, http, dev, App, HttpRequest, HttpResponse, Result}; /// /// fn render_500(mut res: dev::ServiceResponse) -> Result> { @@ -39,7 +45,6 @@ type ErrorHandler = dyn Fn(ServiceResponse) -> Result = dyn Fn(ServiceResponse) -> Result { - handlers: Rc>>>, + handlers: Handlers, } +type Handlers = Rc>>>; + impl Default for ErrorHandlers { fn default() -> Self { ErrorHandlers { - handlers: Rc::new(FxHashMap::default()), + handlers: Rc::new(AHashMap::default()), } } } impl ErrorHandlers { - /// Construct new `ErrorHandlers` instance + /// Construct new `ErrorHandlers` instance. pub fn new() -> Self { ErrorHandlers::default() } - /// Register error handler for specified status code + /// Register error handler for specified status code. pub fn handler(mut self, status: StatusCode, handler: F) -> Self where F: Fn(ServiceResponse) -> Result> + 'static, @@ -81,80 +87,101 @@ impl ErrorHandlers { } } -impl Transform for ErrorHandlers +impl Transform for ErrorHandlers where - S: Service, Error = Error>, + S: Service, Error = Error> + 'static, S::Future: 'static, B: 'static, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type InitError = (); type Transform = ErrorHandlersMiddleware; - type Future = Ready>; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; fn new_transform(&self, service: S) -> Self::Future { - ok(ErrorHandlersMiddleware { - service, - handlers: self.handlers.clone(), - }) + let handlers = self.handlers.clone(); + Box::pin(async move { Ok(ErrorHandlersMiddleware { service, handlers }) }) } } #[doc(hidden)] pub struct ErrorHandlersMiddleware { service: S, - handlers: Rc>>>, + handlers: Handlers, } -impl Service for ErrorHandlersMiddleware +impl Service for ErrorHandlersMiddleware where - S: Service, Error = Error>, + S: Service, Error = Error>, S::Future: 'static, B: 'static, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = LocalBoxFuture<'static, Result>; + type Future = ErrorHandlersFuture; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx) - } + actix_service::forward_ready!(service); - fn call(&mut self, req: ServiceRequest) -> Self::Future { + fn call(&self, req: ServiceRequest) -> Self::Future { let handlers = self.handlers.clone(); let fut = self.service.call(req); + ErrorHandlersFuture::ServiceFuture { fut, handlers } + } +} - async move { - let res = fut.await?; +#[pin_project::pin_project(project = ErrorHandlersProj)] +pub enum ErrorHandlersFuture +where + Fut: Future, +{ + ServiceFuture { + #[pin] + fut: Fut, + handlers: Handlers, + }, + HandlerFuture { + fut: LocalBoxFuture<'static, Fut::Output>, + }, +} - if let Some(handler) = handlers.get(&res.status()) { - match handler(res) { - Ok(ErrorHandlerResponse::Response(res)) => Ok(res), - Ok(ErrorHandlerResponse::Future(fut)) => fut.await, - Err(e) => Err(e), +impl Future for ErrorHandlersFuture +where + Fut: Future, Error>>, +{ + type Output = Fut::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project() { + ErrorHandlersProj::ServiceFuture { fut, handlers } => { + let res = ready!(fut.poll(cx))?; + match handlers.get(&res.status()) { + Some(handler) => match handler(res)? { + ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)), + ErrorHandlerResponse::Future(fut) => { + self.as_mut() + .set(ErrorHandlersFuture::HandlerFuture { fut }); + self.poll(cx) + } + }, + None => Poll::Ready(Ok(res)), } - } else { - Ok(res) } + ErrorHandlersProj::HandlerFuture { fut } => fut.as_mut().poll(cx), } - .boxed_local() } } #[cfg(test)] mod tests { use actix_service::IntoService; - use futures_util::future::ok; + use futures_util::future::{ok, FutureExt}; use super::*; use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; use crate::test::{self, TestRequest}; use crate::HttpResponse; - #[allow(clippy::unnecessary_wraps)] fn render_500(mut res: ServiceResponse) -> Result> { res.response_mut() .headers_mut() @@ -168,18 +195,16 @@ mod tests { ok(req.into_response(HttpResponse::InternalServerError().finish())) }; - let mut mw = ErrorHandlers::new() + let mw = ErrorHandlers::new() .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500) .new_transform(srv.into_service()) .await .unwrap(); - let resp = - test::call_service(&mut mw, TestRequest::default().to_srv_request()).await; + let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } - #[allow(clippy::unnecessary_wraps)] fn render_500_async( mut res: ServiceResponse, ) -> Result> { @@ -195,14 +220,13 @@ mod tests { ok(req.into_response(HttpResponse::InternalServerError().finish())) }; - let mut mw = ErrorHandlers::new() + let mw = ErrorHandlers::new() .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async) .new_transform(srv.into_service()) .await .unwrap(); - let resp = - test::call_service(&mut mw, TestRequest::default().to_srv_request()).await; + let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } } diff --git a/src/middleware/logger.rs b/src/middleware/logger.rs index 563cb6c32..5b5b5577c 100644 --- a/src/middleware/logger.rs +++ b/src/middleware/logger.rs @@ -1,93 +1,87 @@ -//! Request logging middleware -use std::collections::HashSet; -use std::convert::TryFrom; -use std::env; -use std::fmt::{self, Display, Formatter}; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll}; +//! For middleware documentation, see [`Logger`]. + +use std::{ + collections::HashSet, + convert::TryFrom, + env, + fmt::{self, Display as _}, + future::Future, + marker::PhantomData, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; use actix_service::{Service, Transform}; use bytes::Bytes; use futures_util::future::{ok, Ready}; -use log::debug; +use log::{debug, warn}; use regex::{Regex, RegexSet}; use time::OffsetDateTime; -use crate::dev::{BodySize, MessageBody, ResponseBody}; -use crate::error::{Error, Result}; -use crate::http::{HeaderName, StatusCode}; -use crate::service::{ServiceRequest, ServiceResponse}; -use crate::HttpResponse; +use crate::{ + dev::{BodySize, MessageBody, ResponseBody}, + error::{Error, Result}, + http::{HeaderName, StatusCode}, + service::{ServiceRequest, ServiceResponse}, + HttpResponse, +}; -/// `Middleware` for logging request and response info to the terminal. +/// Middleware for logging request and response summaries to the terminal. /// -/// `Logger` middleware uses standard log crate to log information. You should -/// enable logger for `actix_web` package to see access log. -/// ([`env_logger`](https://docs.rs/env_logger/*/env_logger/) or similar) +/// This middleware uses the `log` crate to output information. Enable `log`'s output for the +/// "actix_web" scope using [`env_logger`](https://docs.rs/env_logger) or similar crate. /// -/// ## Usage -/// -/// Create `Logger` middleware with the specified `format`. -/// Default `Logger` could be created with `default` method, it uses the -/// default format: +/// # Default Format +/// The [`default`](Logger::default) Logger uses the following format: /// /// ```plain /// %a "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T +/// +/// Example Output: +/// 127.0.0.1:54278 "GET /test HTTP/1.1" 404 20 "-" "HTTPie/2.2.0" 0.001074 /// ``` /// +/// # Examples /// ```rust /// use actix_web::{middleware::Logger, App}; /// -/// std::env::set_var("RUST_LOG", "actix_web=info"); -/// env_logger::init(); +/// // access logs are printed with the INFO level so ensure it is enabled by default +/// env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); /// /// let app = App::new() -/// .wrap(Logger::default()) +/// // .wrap(Logger::default()) /// .wrap(Logger::new("%a %{User-Agent}i")); /// ``` /// -/// ## Format -/// -/// `%%` The percent sign -/// -/// `%a` Remote IP-address (IP-address of proxy if using reverse proxy) -/// -/// `%t` Time when the request was started to process (in rfc3339 format) -/// -/// `%r` First line of request -/// -/// `%s` Response status code -/// -/// `%b` Size of response in bytes, including HTTP headers -/// -/// `%T` Time taken to serve the request, in seconds with floating fraction in -/// .06f format -/// -/// `%D` Time taken to serve the request, in milliseconds -/// -/// `%U` Request URL -/// -/// `%{r}a` Real IP remote address **\*** -/// -/// `%{FOO}i` request.headers['FOO'] -/// -/// `%{FOO}o` response.headers['FOO'] -/// -/// `%{FOO}e` os.environ['FOO'] -/// -/// `%{FOO}xi` [custom request replacement](Logger::custom_request_replace) labelled "FOO" +/// # Format +/// Variable | Description +/// -------- | ----------- +/// `%%` | The percent sign +/// `%a` | Peer IP address (or IP address of reverse proxy if used) +/// `%t` | Time when the request started processing (in RFC 3339 format) +/// `%r` | First line of request (Example: `GET /test HTTP/1.1`) +/// `%s` | Response status code +/// `%b` | Size of response in bytes, including HTTP headers +/// `%T` | Time taken to serve the request, in seconds to 6 decimal places +/// `%D` | Time taken to serve the request, in milliseconds +/// `%U` | Request URL +/// `%{r}a` | "Real IP" remote address **\*** +/// `%{FOO}i` | `request.headers["FOO"]` +/// `%{FOO}o` | `response.headers["FOO"]` +/// `%{FOO}e` | `env_var["FOO"]` +/// `%{FOO}xi` | [Custom request replacement](Logger::custom_request_replace) labelled "FOO" /// /// # Security -/// **\*** It is calculated using -/// [`ConnectionInfo::realip_remote_addr()`](crate::dev::ConnectionInfo::realip_remote_addr()) +/// **\*** "Real IP" remote address is calculated using +/// [`ConnectionInfo::realip_remote_addr()`](crate::dev::ConnectionInfo::realip_remote_addr()) /// -/// If you use this value ensure that all requests come from trusted hosts, since it is trivial -/// for the remote client to simulate being another client. +/// If you use this value, ensure that all requests come from trusted hosts. Otherwise, it is +/// trivial for the remote client to falsify their source IP address. +#[derive(Debug)] pub struct Logger(Rc); +#[derive(Debug, Clone)] struct Inner { format: Format, exclude: HashSet, @@ -113,7 +107,7 @@ impl Logger { self } - /// Ignore and do not log access info for paths that match regex + /// Ignore and do not log access info for paths that match regex. pub fn exclude_regex>(mut self, path: T) -> Self { let inner = Rc::get_mut(&mut self.0).unwrap(); let mut patterns = inner.exclude_regex.patterns().to_vec(); @@ -143,9 +137,9 @@ impl Logger { ) -> Self { let inner = Rc::get_mut(&mut self.0).unwrap(); - let ft = inner.format.0.iter_mut().find(|ft| { - matches!(ft, FormatText::CustomRequest(unit_label, _) if label == unit_label) - }); + let ft = inner.format.0.iter_mut().find( + |ft| matches!(ft, FormatText::CustomRequest(unit_label, _) if label == unit_label), + ); if let Some(FormatText::CustomRequest(_, request_fn)) = ft { // replace into None or previously registered fn using same label @@ -179,12 +173,11 @@ impl Default for Logger { } } -impl Transform for Logger +impl Transform for Logger where - S: Service, Error = Error>, + S: Service, Error = Error>, B: MessageBody, { - type Request = ServiceRequest; type Response = ServiceResponse>; type Error = Error; type InitError = (); @@ -195,9 +188,8 @@ where for unit in &self.0.format.0 { // missing request replacement function diagnostic if let FormatText::CustomRequest(label, None) = unit { - debug!( - "No custom request replacement function was registered for label {} in\ - logger format.", + warn!( + "No custom request replacement function was registered for label \"{}\".", label ); } @@ -210,27 +202,24 @@ where } } -/// Logger middleware +/// Logger middleware service. pub struct LoggerMiddleware { inner: Rc, service: S, } -impl Service for LoggerMiddleware +impl Service for LoggerMiddleware where - S: Service, Error = Error>, + S: Service, Error = Error>, B: MessageBody, { - type Request = ServiceRequest; type Response = ServiceResponse>; type Error = Error; type Future = LoggerResponse; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx) - } + actix_service::forward_ready!(service); - fn call(&mut self, req: ServiceRequest) -> Self::Future { + fn call(&self, req: ServiceRequest) -> Self::Future { if self.inner.exclude.contains(req.path()) || self.inner.exclude_regex.is_match(req.path()) { @@ -238,7 +227,7 @@ where fut: self.service.call(req), format: None, time: OffsetDateTime::now_utc(), - _t: PhantomData, + _phantom: PhantomData, } } else { let now = OffsetDateTime::now_utc(); @@ -251,30 +240,29 @@ where fut: self.service.call(req), format: Some(format), time: now, - _t: PhantomData, + _phantom: PhantomData, } } } } -#[doc(hidden)] #[pin_project::pin_project] pub struct LoggerResponse where B: MessageBody, - S: Service, + S: Service, { #[pin] fut: S::Future, time: OffsetDateTime, format: Option, - _t: PhantomData<(B,)>, + _phantom: PhantomData, } impl Future for LoggerResponse where B: MessageBody, - S: Service, Error = Error>, + S: Service, Error = Error>, { type Output = Result>, Error>; @@ -327,7 +315,7 @@ pub struct StreamLog { impl PinnedDrop for StreamLog { fn drop(self: Pin<&mut Self>) { if let Some(ref format) = self.format { - let render = |fmt: &mut Formatter<'_>| { + let render = |fmt: &mut fmt::Formatter<'_>| { for unit in &format.0 { unit.render(fmt, self.size, self.time)?; } @@ -358,9 +346,8 @@ impl MessageBody for StreamLog { } } -/// A formatting style for the `Logger`, consisting of multiple -/// `FormatText`s concatenated into one line. -#[derive(Clone)] +/// A formatting style for the `Logger` consisting of multiple concatenated `FormatText` items. +#[derive(Debug, Clone)] struct Format(Vec); impl Default for Format { @@ -376,8 +363,7 @@ impl Format { /// Returns `None` if the format string syntax is incorrect. pub fn new(s: &str) -> Format { log::trace!("Access log format: {}", s); - let fmt = - Regex::new(r"%(\{([A-Za-z0-9\-_]+)\}([aioe]|xi)|[atPrUsbTD]?)").unwrap(); + let fmt = Regex::new(r"%(\{([A-Za-z0-9\-_]+)\}([aioe]|xi)|[atPrUsbTD]?)").unwrap(); let mut idx = 0; let mut results = Vec::new(); @@ -398,12 +384,12 @@ impl Format { unreachable!() } } - "i" => FormatText::RequestHeader( - HeaderName::try_from(key.as_str()).unwrap(), - ), - "o" => FormatText::ResponseHeader( - HeaderName::try_from(key.as_str()).unwrap(), - ), + "i" => { + FormatText::RequestHeader(HeaderName::try_from(key.as_str()).unwrap()) + } + "o" => { + FormatText::ResponseHeader(HeaderName::try_from(key.as_str()).unwrap()) + } "e" => FormatText::EnvironHeader(key.as_str().to_owned()), "xi" => FormatText::CustomRequest(key.as_str().to_owned(), None), _ => unreachable!(), @@ -432,13 +418,12 @@ impl Format { } } -/// A string of text to be logged. This is either one of the data -/// fields supported by the `Logger`, or a custom `String`. -#[doc(hidden)] +/// A string of text to be logged. +/// +/// This is either one of the data fields supported by the `Logger`, or a custom `String`. #[non_exhaustive] #[derive(Debug, Clone)] -// TODO: remove pub on next breaking change -pub enum FormatText { +enum FormatText { Str(String), Percent, RequestLine, @@ -456,10 +441,8 @@ pub enum FormatText { CustomRequest(String, Option), } -// TODO: remove pub on next breaking change -#[doc(hidden)] #[derive(Clone)] -pub struct CustomRequestFn { +struct CustomRequestFn { inner_fn: Rc String>, } @@ -478,11 +461,11 @@ impl fmt::Debug for CustomRequestFn { impl FormatText { fn render( &self, - fmt: &mut Formatter<'_>, + fmt: &mut fmt::Formatter<'_>, size: usize, entry_time: OffsetDateTime, ) -> Result<(), fmt::Error> { - match *self { + match self { FormatText::Str(ref string) => fmt.write_str(string), FormatText::Percent => "%".fmt(fmt), FormatText::ResponseSize => size.fmt(fmt), @@ -508,7 +491,7 @@ impl FormatText { } fn render_response(&mut self, res: &HttpResponse) { - match *self { + match self { FormatText::ResponseStatus => { *self = FormatText::Str(format!("{}", res.status().as_u16())) } @@ -524,12 +507,12 @@ impl FormatText { }; *self = FormatText::Str(s.to_string()) } - _ => (), + _ => {} } } fn render_request(&mut self, now: OffsetDateTime, req: &ServiceRequest) { - match &*self { + match self { FormatText::RequestLine => { *self = if req.query_string().is_empty() { FormatText::Str(format!( @@ -549,9 +532,7 @@ impl FormatText { }; } FormatText::UrlPath => *self = FormatText::Str(req.path().to_string()), - FormatText::RequestTime => { - *self = FormatText::Str(now.format("%Y-%m-%dT%H:%M:%S")) - } + FormatText::RequestTime => *self = FormatText::Str(now.format("%Y-%m-%dT%H:%M:%S")), FormatText::RequestHeader(ref name) => { let s = if let Some(val) = req.headers().get(name) { if let Ok(s) = val.to_str() { @@ -573,8 +554,7 @@ impl FormatText { *self = s; } FormatText::RealIPRemoteAddr => { - let s = if let Some(remote) = req.connection_info().realip_remote_addr() - { + let s = if let Some(remote) = req.connection_info().realip_remote_addr() { FormatText::Str(remote.to_string()) } else { FormatText::Str("-".to_string()) @@ -589,18 +569,18 @@ impl FormatText { *self = s; } - _ => (), + _ => {} } } } /// Converter to get a String from something that writes to a Formatter. pub(crate) struct FormatDisplay<'a>( - &'a dyn Fn(&mut Formatter<'_>) -> Result<(), fmt::Error>, + &'a dyn Fn(&mut fmt::Formatter<'_>) -> Result<(), fmt::Error>, ); impl<'a> fmt::Display for FormatDisplay<'a> { - fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), fmt::Error> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { (self.0)(fmt) } } @@ -619,19 +599,20 @@ mod tests { let srv = |req: ServiceRequest| { ok(req.into_response( HttpResponse::build(StatusCode::OK) - .header("X-Test", "ttt") + .insert_header(("X-Test", "ttt")) .finish(), )) }; let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test"); - let mut srv = logger.new_transform(srv.into_service()).await.unwrap(); + let srv = logger.new_transform(srv.into_service()).await.unwrap(); - let req = TestRequest::with_header( - header::USER_AGENT, - header::HeaderValue::from_static("ACTIX-WEB"), - ) - .to_srv_request(); + let req = TestRequest::default() + .insert_header(( + header::USER_AGENT, + header::HeaderValue::from_static("ACTIX-WEB"), + )) + .to_srv_request(); let _res = srv.call(req).await; } @@ -640,32 +621,34 @@ mod tests { let srv = |req: ServiceRequest| { ok(req.into_response( HttpResponse::build(StatusCode::OK) - .header("X-Test", "ttt") + .insert_header(("X-Test", "ttt")) .finish(), )) }; - let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test") - .exclude_regex("\\w"); + let logger = + Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test").exclude_regex("\\w"); - let mut srv = logger.new_transform(srv.into_service()).await.unwrap(); + let srv = logger.new_transform(srv.into_service()).await.unwrap(); - let req = TestRequest::with_header( - header::USER_AGENT, - header::HeaderValue::from_static("ACTIX-WEB"), - ) - .to_srv_request(); + let req = TestRequest::default() + .insert_header(( + header::USER_AGENT, + header::HeaderValue::from_static("ACTIX-WEB"), + )) + .to_srv_request(); let _res = srv.call(req).await.unwrap(); } #[actix_rt::test] async fn test_url_path() { let mut format = Format::new("%T %U"); - let req = TestRequest::with_header( - header::USER_AGENT, - header::HeaderValue::from_static("ACTIX-WEB"), - ) - .uri("/test/route/yeah") - .to_srv_request(); + let req = TestRequest::default() + .insert_header(( + header::USER_AGENT, + header::HeaderValue::from_static("ACTIX-WEB"), + )) + .uri("/test/route/yeah") + .to_srv_request(); let now = OffsetDateTime::now_utc(); for unit in &mut format.0 { @@ -677,7 +660,7 @@ mod tests { unit.render_response(&resp); } - let render = |fmt: &mut Formatter<'_>| { + let render = |fmt: &mut fmt::Formatter<'_>| { for unit in &format.0 { unit.render(fmt, 1024, now)?; } @@ -692,12 +675,13 @@ mod tests { async fn test_default_format() { let mut format = Format::default(); - let req = TestRequest::with_header( - header::USER_AGENT, - header::HeaderValue::from_static("ACTIX-WEB"), - ) - .peer_addr("127.0.0.1:8081".parse().unwrap()) - .to_srv_request(); + let req = TestRequest::default() + .insert_header(( + header::USER_AGENT, + header::HeaderValue::from_static("ACTIX-WEB"), + )) + .peer_addr("127.0.0.1:8081".parse().unwrap()) + .to_srv_request(); let now = OffsetDateTime::now_utc(); for unit in &mut format.0 { @@ -710,7 +694,7 @@ mod tests { } let entry_time = OffsetDateTime::now_utc(); - let render = |fmt: &mut Formatter<'_>| { + let render = |fmt: &mut fmt::Formatter<'_>| { for unit in &format.0 { unit.render(fmt, 1024, entry_time)?; } @@ -738,7 +722,7 @@ mod tests { unit.render_response(&resp); } - let render = |fmt: &mut Formatter<'_>| { + let render = |fmt: &mut fmt::Formatter<'_>| { for unit in &format.0 { unit.render(fmt, 1024, now)?; } @@ -752,13 +736,12 @@ mod tests { async fn test_remote_addr_format() { let mut format = Format::new("%{r}a"); - let req = TestRequest::with_header( - header::FORWARDED, - header::HeaderValue::from_static( - "for=192.0.2.60;proto=http;by=203.0.113.43", - ), - ) - .to_srv_request(); + let req = TestRequest::default() + .insert_header(( + header::FORWARDED, + header::HeaderValue::from_static("for=192.0.2.60;proto=http;by=203.0.113.43"), + )) + .to_srv_request(); let now = OffsetDateTime::now_utc(); for unit in &mut format.0 { @@ -771,7 +754,7 @@ mod tests { } let entry_time = OffsetDateTime::now_utc(); - let render = |fmt: &mut Formatter<'_>| { + let render = |fmt: &mut fmt::Formatter<'_>| { for unit in &format.0 { unit.render(fmt, 1024, entry_time)?; } @@ -802,7 +785,7 @@ mod tests { unit.render_request(now, &req); - let render = |fmt: &mut Formatter<'_>| unit.render(fmt, 1024, now); + let render = |fmt: &mut fmt::Formatter<'_>| unit.render(fmt, 1024, now); let log_output = FormatDisplay(&render).to_string(); assert_eq!(log_output, "custom_log"); @@ -817,7 +800,7 @@ mod tests { captured.to_owned() }); - let mut srv = logger.new_transform(test::ok_service()).await.unwrap(); + let srv = logger.new_transform(test::ok_service()).await.unwrap(); let req = TestRequest::default().to_srv_request(); srv.call(req).await.unwrap(); diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 12c12a98c..e24782f07 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -1,17 +1,20 @@ -//! Middlewares +//! Commonly used middleware. + +mod compat; +mod condition; +mod default_headers; +mod err_handlers; +mod logger; +mod normalize; + +pub use self::compat::Compat; +pub use self::condition::Condition; +pub use self::default_headers::DefaultHeaders; +pub use self::err_handlers::{ErrorHandlerResponse, ErrorHandlers}; +pub use self::logger::Logger; +pub use self::normalize::{NormalizePath, TrailingSlash}; #[cfg(feature = "compress")] mod compress; #[cfg(feature = "compress")] pub use self::compress::Compress; - -mod condition; -mod defaultheaders; -pub mod errhandlers; -mod logger; -pub mod normalize; - -pub use self::condition::Condition; -pub use self::defaultheaders::DefaultHeaders; -pub use self::logger::Logger; -pub use self::normalize::NormalizePath; diff --git a/src/middleware/normalize.rs b/src/middleware/normalize.rs index ad9f51079..ea21a7215 100644 --- a/src/middleware/normalize.rs +++ b/src/middleware/normalize.rs @@ -1,60 +1,63 @@ //! For middleware documentation, see [`NormalizePath`]. -use std::task::{Context, Poll}; - use actix_http::http::{PathAndQuery, Uri}; use actix_service::{Service, Transform}; use bytes::Bytes; use futures_util::future::{ready, Ready}; use regex::Regex; -use crate::service::{ServiceRequest, ServiceResponse}; -use crate::Error; +use crate::{ + service::{ServiceRequest, ServiceResponse}, + Error, +}; -/// To be used when constructing `NormalizePath` to define it's behavior. +/// Determines the behavior of the [`NormalizePath`] middleware. +/// +/// The default is `TrailingSlash::Trim`. #[non_exhaustive] -#[derive(Clone, Copy)] +#[derive(Debug, Clone, Copy)] pub enum TrailingSlash { - /// Always add a trailing slash to the end of the path. - /// This will require all routes to end in a trailing slash for them to be accessible. - Always, + /// Trim trailing slashes from the end of the path. + /// + /// Using this will require all routes to omit trailing slashes for them to be accessible. + Trim, /// Only merge any present multiple trailing slashes. /// - /// Note: This option provides the best compatibility with the v2 version of this middleware. + /// This option provides the best compatibility with behavior in actix-web v2.0. MergeOnly, - /// Trim trailing slashes from the end of the path. - Trim, + /// Always add a trailing slash to the end of the path. + /// + /// Using this will require all routes have a trailing slash for them to be accessible. + Always, } impl Default for TrailingSlash { fn default() -> Self { - TrailingSlash::Always + TrailingSlash::Trim } } -#[derive(Default, Clone, Copy)] -/// Middleware to normalize a request's path so that routes can be matched less strictly. +/// Middleware for normalizing a request's path so that routes can be matched more flexibly. /// /// # Normalization Steps -/// - Merges multiple consecutive slashes into one. (For example, `/path//one` always -/// becomes `/path/one`.) +/// - Merges consecutive slashes into one. (For example, `/path//one` always becomes `/path/one`.) /// - Appends a trailing slash if one is not present, removes one if present, or keeps trailing /// slashes as-is, depending on which [`TrailingSlash`] variant is supplied /// to [`new`](NormalizePath::new()). /// /// # Default Behavior -/// The default constructor chooses to strip trailing slashes from the end -/// ([`TrailingSlash::Trim`]), the effect is that route definitions should be defined without -/// trailing slashes or else they will be inaccessible. +/// The default constructor chooses to strip trailing slashes from the end of paths with them +/// ([`TrailingSlash::Trim`]). The implication is that route definitions should be defined without +/// trailing slashes or else they will be inaccessible (or vice versa when using the +/// `TrailingSlash::Always` behavior), as shown in the example tests below. /// -/// # Example +/// # Examples /// ```rust /// use actix_web::{web, middleware, App}; /// -/// # #[actix_rt::test] -/// # async fn normalize() { +/// # actix_web::rt::System::new().block_on(async { /// let app = App::new() /// .wrap(middleware::NormalizePath::default()) /// .route("/test", web::get().to(|| async { "test" })) @@ -63,25 +66,26 @@ impl Default for TrailingSlash { /// use actix_web::http::StatusCode; /// use actix_web::test::{call_service, init_service, TestRequest}; /// -/// let mut app = init_service(app).await; +/// let app = init_service(app).await; /// /// let req = TestRequest::with_uri("/test").to_request(); -/// let res = call_service(&mut app, req).await; +/// let res = call_service(&app, req).await; /// assert_eq!(res.status(), StatusCode::OK); /// /// let req = TestRequest::with_uri("/test/").to_request(); -/// let res = call_service(&mut app, req).await; +/// let res = call_service(&app, req).await; /// assert_eq!(res.status(), StatusCode::OK); /// /// let req = TestRequest::with_uri("/unmatchable").to_request(); -/// let res = call_service(&mut app, req).await; +/// let res = call_service(&app, req).await; /// assert_eq!(res.status(), StatusCode::NOT_FOUND); /// /// let req = TestRequest::with_uri("/unmatchable/").to_request(); -/// let res = call_service(&mut app, req).await; +/// let res = call_service(&app, req).await; /// assert_eq!(res.status(), StatusCode::NOT_FOUND); -/// # } +/// # }) /// ``` +#[derive(Debug, Clone, Copy, Default)] pub struct NormalizePath(TrailingSlash); impl NormalizePath { @@ -91,16 +95,15 @@ impl NormalizePath { } } -impl Transform for NormalizePath +impl Transform for NormalizePath where - S: Service, Error = Error>, + S: Service, Error = Error>, S::Future: 'static, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type InitError = (); type Transform = NormalizePathNormalization; + type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { @@ -112,28 +115,24 @@ where } } -#[doc(hidden)] pub struct NormalizePathNormalization { service: S, merge_slash: Regex, trailing_slash_behavior: TrailingSlash, } -impl Service for NormalizePathNormalization +impl Service for NormalizePathNormalization where - S: Service, Error = Error>, + S: Service, Error = Error>, S::Future: 'static, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; type Future = S::Future; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx) - } + actix_service::forward_ready!(service); - fn call(&mut self, mut req: ServiceRequest) -> Self::Future { + fn call(&self, mut req: ServiceRequest) -> Self::Future { let head = req.head_mut(); let original_path = head.uri.path(); @@ -196,46 +195,46 @@ mod tests { #[actix_rt::test] async fn test_wrap() { - let mut app = init_service( + let app = init_service( App::new() .wrap(NormalizePath::default()) .service(web::resource("/").to(HttpResponse::Ok)) - .service(web::resource("/v1/something/").to(HttpResponse::Ok)), + .service(web::resource("/v1/something").to(HttpResponse::Ok)), ) .await; let req = TestRequest::with_uri("/").to_request(); - let res = call_service(&mut app, req).await; + let res = call_service(&app, req).await; assert!(res.status().is_success()); let req = TestRequest::with_uri("/?query=test").to_request(); - let res = call_service(&mut app, req).await; + let res = call_service(&app, req).await; assert!(res.status().is_success()); let req = TestRequest::with_uri("///").to_request(); - let res = call_service(&mut app, req).await; + let res = call_service(&app, req).await; assert!(res.status().is_success()); let req = TestRequest::with_uri("/v1//something////").to_request(); - let res = call_service(&mut app, req).await; + let res = call_service(&app, req).await; assert!(res.status().is_success()); let req2 = TestRequest::with_uri("//v1/something").to_request(); - let res2 = call_service(&mut app, req2).await; + let res2 = call_service(&app, req2).await; assert!(res2.status().is_success()); let req3 = TestRequest::with_uri("//v1//////something").to_request(); - let res3 = call_service(&mut app, req3).await; + let res3 = call_service(&app, req3).await; assert!(res3.status().is_success()); let req4 = TestRequest::with_uri("/v1//something").to_request(); - let res4 = call_service(&mut app, req4).await; + let res4 = call_service(&app, req4).await; assert!(res4.status().is_success()); } #[actix_rt::test] async fn trim_trailing_slashes() { - let mut app = init_service( + let app = init_service( App::new() .wrap(NormalizePath(TrailingSlash::Trim)) .service(web::resource("/").to(HttpResponse::Ok)) @@ -245,37 +244,37 @@ mod tests { // root paths should still work let req = TestRequest::with_uri("/").to_request(); - let res = call_service(&mut app, req).await; + let res = call_service(&app, req).await; assert!(res.status().is_success()); let req = TestRequest::with_uri("/?query=test").to_request(); - let res = call_service(&mut app, req).await; + let res = call_service(&app, req).await; assert!(res.status().is_success()); let req = TestRequest::with_uri("///").to_request(); - let res = call_service(&mut app, req).await; + let res = call_service(&app, req).await; assert!(res.status().is_success()); let req = TestRequest::with_uri("/v1/something////").to_request(); - let res = call_service(&mut app, req).await; + let res = call_service(&app, req).await; assert!(res.status().is_success()); let req2 = TestRequest::with_uri("/v1/something/").to_request(); - let res2 = call_service(&mut app, req2).await; + let res2 = call_service(&app, req2).await; assert!(res2.status().is_success()); let req3 = TestRequest::with_uri("//v1//something//").to_request(); - let res3 = call_service(&mut app, req3).await; + let res3 = call_service(&app, req3).await; assert!(res3.status().is_success()); let req4 = TestRequest::with_uri("//v1//something").to_request(); - let res4 = call_service(&mut app, req4).await; + let res4 = call_service(&app, req4).await; assert!(res4.status().is_success()); } #[actix_rt::test] async fn keep_trailing_slash_unchanged() { - let mut app = init_service( + let app = init_service( App::new() .wrap(NormalizePath(TrailingSlash::MergeOnly)) .service(web::resource("/").to(HttpResponse::Ok)) @@ -300,7 +299,7 @@ mod tests { for (path, success) in tests { let req = TestRequest::with_uri(path).to_request(); - let res = call_service(&mut app, req).await; + let res = call_service(&app, req).await; assert_eq!(res.status().is_success(), success); } } @@ -308,11 +307,11 @@ mod tests { #[actix_rt::test] async fn test_in_place_normalization() { let srv = |req: ServiceRequest| { - assert_eq!("/v1/something/", req.path()); + assert_eq!("/v1/something", req.path()); ready(Ok(req.into_response(HttpResponse::Ok().finish()))) }; - let mut normalize = NormalizePath::default() + let normalize = NormalizePath::default() .new_transform(srv.into_service()) .await .unwrap(); @@ -336,14 +335,14 @@ mod tests { #[actix_rt::test] async fn should_normalize_nothing() { - const URI: &str = "/v1/something/"; + const URI: &str = "/v1/something"; let srv = |req: ServiceRequest| { assert_eq!(URI, req.path()); ready(Ok(req.into_response(HttpResponse::Ok().finish()))) }; - let mut normalize = NormalizePath::default() + let normalize = NormalizePath::default() .new_transform(srv.into_service()) .await .unwrap(); @@ -355,19 +354,17 @@ mod tests { #[actix_rt::test] async fn should_normalize_no_trail() { - const URI: &str = "/v1/something"; - let srv = |req: ServiceRequest| { - assert_eq!(URI.to_string() + "/", req.path()); + assert_eq!("/v1/something", req.path()); ready(Ok(req.into_response(HttpResponse::Ok().finish()))) }; - let mut normalize = NormalizePath::default() + let normalize = NormalizePath::default() .new_transform(srv.into_service()) .await .unwrap(); - let req = TestRequest::with_uri(URI).to_srv_request(); + let req = TestRequest::with_uri("/v1/something/").to_srv_request(); let res = normalize.call(req).await.unwrap(); assert!(res.status().is_success()); } diff --git a/src/request.rs b/src/request.rs index bd4bbbf58..514b7466e 100644 --- a/src/request.rs +++ b/src/request.rs @@ -6,8 +6,9 @@ use actix_http::http::{HeaderMap, Method, Uri, Version}; use actix_http::{Error, Extensions, HttpMessage, Message, Payload, RequestHead}; use actix_router::{Path, Url}; use futures_util::future::{ok, Ready}; -use tinyvec::TinyVec; +use smallvec::SmallVec; +use crate::app_service::AppInitServiceState; use crate::config::AppConfig; use crate::error::UrlGenerationError; use crate::extract::FromRequest; @@ -16,16 +17,19 @@ use crate::rmap::ResourceMap; #[derive(Clone)] /// An HTTP Request -pub struct HttpRequest(pub(crate) Rc); +pub struct HttpRequest { + /// # Panics + /// `Rc` is used exclusively and NO `Weak` + /// is allowed anywhere in the code. Weak pointer is purposely ignored when + /// doing `Rc`'s ref counter check. Expect panics if this invariant is violated. + pub(crate) inner: Rc, +} pub(crate) struct HttpRequestInner { pub(crate) head: Message, pub(crate) path: Path, - pub(crate) payload: Payload, - pub(crate) app_data: TinyVec<[Rc; 4]>, - rmap: Rc, - config: AppConfig, - pool: &'static HttpRequestPool, + pub(crate) app_data: SmallVec<[Rc; 4]>, + app_state: Rc, } impl HttpRequest { @@ -33,24 +37,20 @@ impl HttpRequest { pub(crate) fn new( path: Path, head: Message, - payload: Payload, - rmap: Rc, - config: AppConfig, + app_state: Rc, app_data: Rc, - pool: &'static HttpRequestPool, ) -> HttpRequest { - let mut data = TinyVec::<[Rc; 4]>::new(); + let mut data = SmallVec::<[Rc; 4]>::new(); data.push(app_data); - HttpRequest(Rc::new(HttpRequestInner { - head, - path, - payload, - rmap, - config, - app_data: data, - pool, - })) + HttpRequest { + inner: Rc::new(HttpRequestInner { + head, + path, + app_state, + app_data: data, + }), + } } } @@ -58,14 +58,14 @@ impl HttpRequest { /// This method returns reference to the request head #[inline] pub fn head(&self) -> &RequestHead { - &self.0.head + &self.inner.head } /// This method returns mutable reference to the request head. - /// panics if multiple references of http request exists. + /// panics if multiple references of HTTP request exists. #[inline] pub(crate) fn head_mut(&mut self) -> &mut RequestHead { - &mut Rc::get_mut(&mut self.0).unwrap().head + &mut Rc::get_mut(&mut self.inner).unwrap().head } /// Request's uri. @@ -118,12 +118,12 @@ impl HttpRequest { /// access the matched value for that segment. #[inline] pub fn match_info(&self) -> &Path { - &self.0.path + &self.inner.path } #[inline] pub(crate) fn match_info_mut(&mut self) -> &mut Path { - &mut Rc::get_mut(&mut self.0).unwrap().path + &mut Rc::get_mut(&mut self.inner).unwrap().path } /// The resource definition pattern that matched the path. Useful for logging and metrics. @@ -134,7 +134,7 @@ impl HttpRequest { /// Returns a None when no resource is fully matched, including default services. #[inline] pub fn match_pattern(&self) -> Option { - self.0.rmap.match_pattern(self.path()) + self.resource_map().match_pattern(self.path()) } /// The resource name that matched the path. Useful for logging and metrics. @@ -142,7 +142,7 @@ impl HttpRequest { /// Returns a None when no resource is fully matched, including default services. #[inline] pub fn match_name(&self) -> Option<&str> { - self.0.rmap.match_name(self.path()) + self.resource_map().match_name(self.path()) } /// Request extensions @@ -175,16 +175,12 @@ impl HttpRequest { /// ); /// } /// ``` - pub fn url_for( - &self, - name: &str, - elements: U, - ) -> Result + pub fn url_for(&self, name: &str, elements: U) -> Result where U: IntoIterator, I: AsRef, { - self.0.rmap.url_for(&self, name, elements) + self.resource_map().url_for(&self, name, elements) } /// Generate url for named resource @@ -199,15 +195,17 @@ impl HttpRequest { #[inline] /// Get a reference to a `ResourceMap` of current application. pub fn resource_map(&self) -> &ResourceMap { - &self.0.rmap + &self.app_state().rmap() } - /// Peer socket address + /// Peer socket address. /// - /// Peer address is actual socket address, if proxy is used in front of - /// actix http server, then peer address would be address of this proxy. + /// Peer address is the directly connected peer's socket address. If a proxy is used in front of + /// the Actix Web server, then it would be address of this proxy. /// /// To get client connection information `.connection_info()` should be used. + /// + /// Will only return None when called in unit tests. #[inline] pub fn peer_addr(&self) -> Option { self.head().peer_addr @@ -219,13 +217,13 @@ impl HttpRequest { /// borrowed. #[inline] pub fn connection_info(&self) -> Ref<'_, ConnectionInfo> { - ConnectionInfo::get(self.head(), &*self.app_config()) + ConnectionInfo::get(self.head(), self.app_config()) } /// App config #[inline] pub fn app_config(&self) -> &AppConfig { - &self.0.config + self.app_state().config() } /// Get an application data object stored with `App::data` or `App::app_data` @@ -237,7 +235,7 @@ impl HttpRequest { /// let opt_t = req.app_data::>(); /// ``` pub fn app_data(&self) -> Option<&T> { - for container in self.0.app_data.iter().rev() { + for container in self.inner.app_data.iter().rev() { if let Some(data) = container.get::() { return Some(data); } @@ -245,6 +243,11 @@ impl HttpRequest { None } + + #[inline] + fn app_state(&self) -> &AppInitServiceState { + &*self.inner.app_state + } } impl HttpMessage for HttpRequest { @@ -259,13 +262,13 @@ impl HttpMessage for HttpRequest { /// Request extensions #[inline] fn extensions(&self) -> Ref<'_, Extensions> { - self.0.head.extensions() + self.inner.head.extensions() } /// Mutable reference to a the request's extensions #[inline] fn extensions_mut(&self) -> RefMut<'_, Extensions> { - self.0.head.extensions_mut() + self.inner.head.extensions_mut() } #[inline] @@ -277,11 +280,19 @@ impl HttpMessage for HttpRequest { impl Drop for HttpRequest { fn drop(&mut self) { // if possible, contribute to current worker's HttpRequest allocation pool - if Rc::strong_count(&self.0) == 1 { - let v = &mut self.0.pool.0.borrow_mut(); - if v.len() < 128 { - self.extensions_mut().clear(); - v.push(self.0.clone()); + + // This relies on no Weak exists anywhere.(There is none) + if let Some(inner) = Rc::get_mut(&mut self.inner) { + if inner.app_state.pool().is_available() { + // clear additional app_data and keep the root one for reuse. + inner.app_data.truncate(1); + // inner is borrowed mut here. get head's Extension mutably + // to reduce borrow check + inner.head.extensions.get_mut().clear(); + + // a re-borrow of pool is necessary here. + let req = self.inner.clone(); + self.app_state().pool().push(req); } } } @@ -323,8 +334,8 @@ impl fmt::Debug for HttpRequest { writeln!( f, "\nHttpRequest {:?} {}:{}", - self.0.head.version, - self.0.head.method, + self.inner.head.version, + self.inner.head.method, self.path() )?; if !self.query_string().is_empty() { @@ -350,25 +361,50 @@ impl fmt::Debug for HttpRequest { /// Request objects are added when they are dropped (see `::drop`) and re-used /// in `::call` when there are available objects in the list. /// -/// The pool's initial capacity is 128 items. -pub(crate) struct HttpRequestPool(RefCell>>); +/// The pool's default capacity is 128 items. +pub(crate) struct HttpRequestPool { + inner: RefCell>>, + cap: usize, +} + +impl Default for HttpRequestPool { + fn default() -> Self { + Self::with_capacity(128) + } +} impl HttpRequestPool { - /// Allocates a slab of memory for pool use. - pub(crate) fn create() -> &'static HttpRequestPool { - let pool = HttpRequestPool(RefCell::new(Vec::with_capacity(128))); - Box::leak(Box::new(pool)) + pub(crate) fn with_capacity(cap: usize) -> Self { + HttpRequestPool { + inner: RefCell::new(Vec::with_capacity(cap)), + cap, + } } /// Re-use a previously allocated (but now completed/discarded) HttpRequest object. #[inline] - pub(crate) fn get_request(&self) -> Option { - self.0.borrow_mut().pop().map(HttpRequest) + pub(crate) fn pop(&self) -> Option { + self.inner + .borrow_mut() + .pop() + .map(|inner| HttpRequest { inner }) + } + + /// Check if the pool still has capacity for request storage. + #[inline] + pub(crate) fn is_available(&self) -> bool { + self.inner.borrow_mut().len() < self.cap + } + + /// Push a request to pool. + #[inline] + pub(crate) fn push(&self, req: Rc) { + self.inner.borrow_mut().push(req); } /// Clears all allocated HttpRequest objects. pub(crate) fn clear(&self) { - self.0.borrow_mut().clear() + self.inner.borrow_mut().clear() } } @@ -385,31 +421,34 @@ mod tests { #[test] fn test_debug() { - let req = - TestRequest::with_header("content-type", "text/plain").to_http_request(); + let req = TestRequest::default() + .insert_header(("content-type", "text/plain")) + .to_http_request(); let dbg = format!("{:?}", req); assert!(dbg.contains("HttpRequest")); } #[test] + #[cfg(feature = "cookies")] fn test_no_request_cookies() { let req = TestRequest::default().to_http_request(); assert!(req.cookies().unwrap().is_empty()); } #[test] + #[cfg(feature = "cookies")] fn test_request_cookies() { let req = TestRequest::default() - .header(header::COOKIE, "cookie1=value1") - .header(header::COOKIE, "cookie2=value2") + .append_header((header::COOKIE, "cookie1=value1")) + .append_header((header::COOKIE, "cookie2=value2")) .to_http_request(); { let cookies = req.cookies().unwrap(); assert_eq!(cookies.len(), 2); - assert_eq!(cookies[0].name(), "cookie2"); - assert_eq!(cookies[0].value(), "value2"); - assert_eq!(cookies[1].name(), "cookie1"); - assert_eq!(cookies[1].value(), "value1"); + assert_eq!(cookies[0].name(), "cookie1"); + assert_eq!(cookies[0].value(), "value1"); + assert_eq!(cookies[1].name(), "cookie2"); + assert_eq!(cookies[1].value(), "value2"); } let cookie = req.cookie("cookie1"); @@ -438,7 +477,8 @@ mod tests { assert!(rmap.has_resource("/user/test.html")); assert!(!rmap.has_resource("/test/unknown")); - let req = TestRequest::with_header(header::HOST, "www.rust-lang.org") + let req = TestRequest::default() + .insert_header((header::HOST, "www.rust-lang.org")) .rmap(rmap) .to_http_request(); @@ -468,7 +508,7 @@ mod tests { assert!(rmap.has_resource("/index.html")); let req = TestRequest::with_uri("/test") - .header(header::HOST, "www.rust-lang.org") + .insert_header((header::HOST, "www.rust-lang.org")) .rmap(rmap) .to_http_request(); let url = req.url_for_static("index"); @@ -514,36 +554,55 @@ mod tests { ); } + #[actix_rt::test] + async fn test_drop_http_request_pool() { + let srv = init_service(App::new().service(web::resource("/").to( + |req: HttpRequest| { + HttpResponse::Ok() + .insert_header(("pool_cap", req.app_state().pool().cap)) + .finish() + }, + ))) + .await; + + let req = TestRequest::default().to_request(); + let resp = call_service(&srv, req).await; + + drop(srv); + + assert_eq!(resp.headers().get("pool_cap").unwrap(), "128"); + } + #[actix_rt::test] async fn test_data() { - let mut srv = init_service(App::new().app_data(10usize).service( - web::resource("/").to(|req: HttpRequest| { + let srv = init_service(App::new().app_data(10usize).service(web::resource("/").to( + |req: HttpRequest| { if req.app_data::().is_some() { HttpResponse::Ok() } else { HttpResponse::BadRequest() } - }), - )) + }, + ))) .await; let req = TestRequest::default().to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); - let mut srv = init_service(App::new().app_data(10u32).service( - web::resource("/").to(|req: HttpRequest| { + let srv = init_service(App::new().app_data(10u32).service(web::resource("/").to( + |req: HttpRequest| { if req.app_data::().is_some() { HttpResponse::Ok() } else { HttpResponse::BadRequest() } - }), - )) + }, + ))) .await; let req = TestRequest::default().to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } @@ -555,7 +614,7 @@ mod tests { HttpResponse::Ok().body(num.to_string()) } - let mut srv = init_service( + let srv = init_service( App::new() .app_data(88usize) .service(web::resource("/").route(web::get().to(echo_usize))) @@ -586,7 +645,7 @@ mod tests { HttpResponse::Ok().body(num.to_string()) } - let mut srv = init_service( + let srv = init_service( App::new() .app_data(88usize) .service(web::resource("/").route(web::get().to(echo_usize))) @@ -626,18 +685,18 @@ mod tests { let tracker = Rc::new(RefCell::new(Tracker { dropped: false })); { let tracker2 = Rc::clone(&tracker); - let mut srv = init_service(App::new().data(10u32).service( - web::resource("/").to(move |req: HttpRequest| { + let srv = init_service(App::new().data(10u32).service(web::resource("/").to( + move |req: HttpRequest| { req.extensions_mut().insert(Foo { tracker: Rc::clone(&tracker2), }); HttpResponse::Ok() - }), - )) + }, + ))) .await; let req = TestRequest::default().to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } @@ -646,7 +705,7 @@ mod tests { #[actix_rt::test] async fn extract_path_pattern() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::scope("/user/{id}") .service(web::resource("/profile").route(web::get().to( @@ -668,17 +727,17 @@ mod tests { .await; let req = TestRequest::get().uri("/user/22/profile").to_request(); - let res = call_service(&mut srv, req).await; + let res = call_service(&srv, req).await; assert_eq!(res.status(), StatusCode::OK); let req = TestRequest::get().uri("/user/22/not-exist").to_request(); - let res = call_service(&mut srv, req).await; + let res = call_service(&srv, req).await; assert_eq!(res.status(), StatusCode::OK); } #[actix_rt::test] async fn extract_path_pattern_complex() { - let mut srv = init_service( + let srv = init_service( App::new() .service(web::scope("/user").service(web::scope("/{id}").service( web::resource("").to(move |req: HttpRequest| { @@ -700,15 +759,15 @@ mod tests { .await; let req = TestRequest::get().uri("/user/test").to_request(); - let res = call_service(&mut srv, req).await; + let res = call_service(&srv, req).await; assert_eq!(res.status(), StatusCode::OK); let req = TestRequest::get().uri("/").to_request(); - let res = call_service(&mut srv, req).await; + let res = call_service(&srv, req).await; assert_eq!(res.status(), StatusCode::OK); let req = TestRequest::get().uri("/not-exist").to_request(); - let res = call_service(&mut srv, req).await; + let res = call_service(&srv, req).await; assert_eq!(res.status(), StatusCode::OK); } } diff --git a/src/request_data.rs b/src/request_data.rs index 285154884..beee8ac12 100644 --- a/src/request_data.rs +++ b/src/request_data.rs @@ -102,7 +102,7 @@ mod tests { #[actix_rt::test] async fn req_data_extractor() { - let mut srv = init_service( + let srv = init_service( App::new() .wrap_fn(|req, srv| { if req.method() == Method::POST { @@ -142,7 +142,7 @@ mod tests { #[actix_rt::test] async fn req_data_internal_mutability() { - let mut srv = init_service( + let srv = init_service( App::new() .wrap_fn(|req, srv| { let data_before = Rc::new(RefCell::new(42u32)); diff --git a/src/resource.rs b/src/resource.rs index f6046d652..944beeefa 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -1,17 +1,18 @@ use std::cell::RefCell; use std::fmt; use std::future::Future; -use std::pin::Pin; use std::rc::Rc; -use std::task::{Context, Poll}; +use std::task::Poll; use actix_http::{Error, Extensions, Response}; use actix_router::IntoPattern; use actix_service::boxed::{self, BoxService, BoxServiceFactory}; use actix_service::{ - apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, Transform, + apply, apply_fn_factory, fn_service, IntoServiceFactory, Service, ServiceFactory, + ServiceFactoryExt, Transform, }; -use futures_util::future::{ok, Either, LocalBoxFuture, Ready}; +use futures_core::future::LocalBoxFuture; +use futures_util::future::join_all; use crate::data::Data; use crate::dev::{insert_slash, AppService, HttpServiceFactory, ResourceDef}; @@ -19,7 +20,7 @@ use crate::extract::FromRequest; use crate::guard::Guard; use crate::handler::Handler; use crate::responder::Responder; -use crate::route::{CreateRouteService, Route, RouteService}; +use crate::route::{Route, RouteService}; use crate::service::{ServiceRequest, ServiceResponse}; type HttpService = BoxService; @@ -52,9 +53,9 @@ pub struct Resource { rdef: Vec, name: Option, routes: Vec, - data: Option, + app_data: Option, guards: Vec>, - default: Rc>>>, + default: HttpNewService, factory_ref: Rc>>, } @@ -69,8 +70,10 @@ impl Resource { endpoint: ResourceEndpoint::new(fref.clone()), factory_ref: fref, guards: Vec::new(), - data: None, - default: Rc::new(RefCell::new(None)), + app_data: None, + default: boxed::factory(fn_service(|req: ServiceRequest| async { + Ok(req.into_response(Response::MethodNotAllowed().finish())) + })), } } } @@ -78,8 +81,8 @@ impl Resource { impl Resource where T: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -200,10 +203,10 @@ where /// /// Data of different types from parent contexts will still be accessible. pub fn app_data(mut self, data: U) -> Self { - if self.data.is_none() { - self.data = Some(Extensions::new()); - } - self.data.as_mut().unwrap().insert(data); + self.app_data + .get_or_insert_with(Extensions::new) + .insert(data); + self } @@ -250,8 +253,8 @@ where mw: M, ) -> Resource< impl ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -260,7 +263,7 @@ where where M: Transform< T::Service, - Request = ServiceRequest, + ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -273,7 +276,7 @@ where guards: self.guards, routes: self.routes, default: self.default, - data: self.data, + app_data: self.app_data, factory_ref: self.factory_ref, } } @@ -317,15 +320,15 @@ where mw: F, ) -> Resource< impl ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), >, > where - F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, + F: Fn(ServiceRequest, &T::Service) -> R + Clone, R: Future>, { Resource { @@ -335,7 +338,7 @@ where guards: self.guards, routes: self.routes, default: self.default, - data: self.data, + app_data: self.app_data, factory_ref: self.factory_ref, } } @@ -345,21 +348,20 @@ where /// default handler from `App` or `Scope`. pub fn default_service(mut self, f: F) -> Self where - F: IntoServiceFactory, + F: IntoServiceFactory, U: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, > + 'static, U::InitError: fmt::Debug, { // create and configure default resource - self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::factory( - f.into_factory().map_init_err(|e| { - log::error!("Can not construct default service: {:?}", e) - }), - ))))); + self.default = boxed::factory( + f.into_factory() + .map_init_err(|e| log::error!("Can not construct default service: {:?}", e)), + ); self } @@ -368,8 +370,8 @@ where impl HttpServiceFactory for Resource where T: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -381,28 +383,26 @@ where } else { Some(std::mem::take(&mut self.guards)) }; + let mut rdef = if config.is_root() || !self.rdef.is_empty() { ResourceDef::new(insert_slash(self.rdef.clone())) } else { ResourceDef::new(self.rdef.clone()) }; + if let Some(ref name) = self.name { *rdef.name_mut() = name.clone(); } - // custom app data storage - if let Some(ref mut ext) = self.data { - config.set_service_data(ext); - } config.register_service(rdef, guards, self, None) } } -impl IntoServiceFactory for Resource +impl IntoServiceFactory for Resource where T: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -411,7 +411,7 @@ where fn into_factory(self) -> T { *self.factory_ref.borrow_mut() = Some(ResourceFactory { routes: self.routes, - data: self.data.map(Rc::new), + app_data: self.app_data.map(Rc::new), default: self.default, }); @@ -421,139 +421,72 @@ where pub struct ResourceFactory { routes: Vec, - data: Option>, - default: Rc>>>, + app_data: Option>, + default: HttpNewService, } -impl ServiceFactory for ResourceFactory { - type Config = (); - type Request = ServiceRequest; +impl ServiceFactory for ResourceFactory { type Response = ServiceResponse; type Error = Error; - type InitError = (); + type Config = (); type Service = ResourceService; - type Future = CreateResourceService; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - let default_fut = if let Some(ref default) = *self.default.borrow() { - Some(default.new_service(())) - } else { - None - }; + // construct default service factory future. + let default_fut = self.default.new_service(()); - CreateResourceService { - fut: self - .routes - .iter() - .map(|route| CreateRouteServiceItem::Future(route.new_service(()))) - .collect(), - data: self.data.clone(), - default: None, - default_fut, - } - } -} + // construct route service factory futures + let factory_fut = join_all(self.routes.iter().map(|route| route.new_service(()))); -enum CreateRouteServiceItem { - Future(CreateRouteService), - Service(RouteService), -} + let app_data = self.app_data.clone(); -pub struct CreateResourceService { - fut: Vec, - data: Option>, - default: Option, - default_fut: Option>>, -} + Box::pin(async move { + let default = default_fut.await?; + let routes = factory_fut + .await + .into_iter() + .collect::, _>>()?; -impl Future for CreateResourceService { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut done = true; - - if let Some(ref mut fut) = self.default_fut { - match Pin::new(fut).poll(cx)? { - Poll::Ready(default) => self.default = Some(default), - Poll::Pending => done = false, - } - } - - // poll http services - for item in &mut self.fut { - match item { - CreateRouteServiceItem::Future(ref mut fut) => match Pin::new(fut) - .poll(cx)? - { - Poll::Ready(route) => *item = CreateRouteServiceItem::Service(route), - Poll::Pending => { - done = false; - } - }, - CreateRouteServiceItem::Service(_) => continue, - }; - } - - if done { - let routes = self - .fut - .drain(..) - .map(|item| match item { - CreateRouteServiceItem::Service(service) => service, - CreateRouteServiceItem::Future(_) => unreachable!(), - }) - .collect(); - Poll::Ready(Ok(ResourceService { + Ok(ResourceService { + app_data, + default, routes, - data: self.data.clone(), - default: self.default.take(), - })) - } else { - Poll::Pending - } + }) + }) } } pub struct ResourceService { routes: Vec, - data: Option>, - default: Option, + app_data: Option>, + default: HttpService, } -impl Service for ResourceService { - type Request = ServiceRequest; +impl Service for ResourceService { type Response = ServiceResponse; type Error = Error; - type Future = Either< - Ready>, - LocalBoxFuture<'static, Result>, - >; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } + actix_service::always_ready!(); - fn call(&mut self, mut req: ServiceRequest) -> Self::Future { - for route in self.routes.iter_mut() { + fn call(&self, mut req: ServiceRequest) -> Self::Future { + for route in self.routes.iter() { if route.check(&mut req) { - if let Some(ref data) = self.data { - req.add_data_container(data.clone()); + if let Some(ref app_data) = self.app_data { + req.add_data_container(app_data.clone()); } - return Either::Right(route.call(req)); + + return route.call(req); } } - if let Some(ref mut default) = self.default { - if let Some(ref data) = self.data { - req.add_data_container(data.clone()); - } - Either::Right(default.call(req)) - } else { - let req = req.into_parts().0; - Either::Left(ok(ServiceResponse::new( - req, - Response::MethodNotAllowed().finish(), - ))) + + if let Some(ref app_data) = self.app_data { + req.add_data_container(app_data.clone()); } + + self.default.call(req) } } @@ -568,17 +501,16 @@ impl ResourceEndpoint { } } -impl ServiceFactory for ResourceEndpoint { - type Config = (); - type Request = ServiceRequest; +impl ServiceFactory for ResourceEndpoint { type Response = ServiceResponse; type Error = Error; - type InitError = (); + type Config = (); type Service = ResourceService; - type Future = CreateResourceService; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - self.factory.borrow_mut().as_mut().unwrap().new_service(()) + self.factory.borrow().as_ref().unwrap().new_service(()) } } @@ -586,7 +518,7 @@ impl ServiceFactory for ResourceEndpoint { mod tests { use std::time::Duration; - use actix_rt::time::delay_for; + use actix_rt::time::sleep; use actix_service::Service; use futures_util::future::ok; @@ -598,21 +530,20 @@ mod tests { #[actix_rt::test] async fn test_middleware() { - let mut srv = - init_service( - App::new().service( - web::resource("/test") - .name("test") - .wrap(DefaultHeaders::new().header( - header::CONTENT_TYPE, - HeaderValue::from_static("0001"), - )) - .route(web::get().to(HttpResponse::Ok)), - ), - ) - .await; + let srv = init_service( + App::new().service( + web::resource("/test") + .name("test") + .wrap( + DefaultHeaders::new() + .header(header::CONTENT_TYPE, HeaderValue::from_static("0001")), + ) + .route(web::get().to(HttpResponse::Ok)), + ), + ) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -622,7 +553,7 @@ mod tests { #[actix_rt::test] async fn test_middleware_fn() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::resource("/test") .wrap_fn(|req, srv| { @@ -642,7 +573,7 @@ mod tests { ) .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -652,20 +583,19 @@ mod tests { #[actix_rt::test] async fn test_to() { - let mut srv = - init_service(App::new().service(web::resource("/test").to(|| async { - delay_for(Duration::from_millis(100)).await; - Ok::<_, Error>(HttpResponse::Ok()) - }))) - .await; + let srv = init_service(App::new().service(web::resource("/test").to(|| async { + sleep(Duration::from_millis(100)).await; + Ok::<_, Error>(HttpResponse::Ok()) + }))) + .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_rt::test] async fn test_pattern() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::resource(["/test", "/test2"]) .to(|| async { Ok::<_, Error>(HttpResponse::Ok()) }), @@ -673,16 +603,16 @@ mod tests { ) .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test2").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_rt::test] async fn test_default_resource() { - let mut srv = init_service( + let srv = init_service( App::new() .service(web::resource("/test").route(web::get().to(HttpResponse::Ok))) .default_service(|r: ServiceRequest| { @@ -691,16 +621,16 @@ mod tests { ) .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test") .method(Method::POST) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - let mut srv = init_service( + let srv = init_service( App::new().service( web::resource("/test") .route(web::get().to(HttpResponse::Ok)) @@ -712,19 +642,19 @@ mod tests { .await; let req = TestRequest::with_uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test") .method(Method::POST) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } #[actix_rt::test] async fn test_resource_guards() { - let mut srv = init_service( + let srv = init_service( App::new() .service( web::resource("/test/{p}") @@ -747,25 +677,25 @@ mod tests { let req = TestRequest::with_uri("/test/it") .method(Method::GET) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test/it") .method(Method::PUT) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::CREATED); let req = TestRequest::with_uri("/test/it") .method(Method::DELETE) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::NO_CONTENT); } #[actix_rt::test] async fn test_data() { - let mut srv = init_service( + let srv = init_service( App::new() .data(1.0f64) .data(1usize) @@ -791,13 +721,13 @@ mod tests { .await; let req = TestRequest::get().uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_rt::test] async fn test_data_default_service() { - let mut srv = init_service( + let srv = init_service( App::new().data(1usize).service( web::resource("/test") .data(10usize) @@ -810,7 +740,7 @@ mod tests { .await; let req = TestRequest::get().uri("/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } } diff --git a/src/responder.rs b/src/responder.rs index d1c22323f..92945cdaa 100644 --- a/src/responder.rs +++ b/src/responder.rs @@ -1,43 +1,29 @@ -use std::convert::TryFrom; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::fmt; -use actix_http::error::InternalError; -use actix_http::http::{ - header::IntoHeaderValue, Error as HttpError, HeaderMap, HeaderName, StatusCode, +use actix_http::{ + error::InternalError, + http::{header::IntoHeaderPair, Error as HttpError, HeaderMap, StatusCode}, + ResponseBuilder, }; -use actix_http::{Error, Response, ResponseBuilder}; use bytes::{Bytes, BytesMut}; -use futures_util::future::{err, ok, Either as EitherFuture, Ready}; -use futures_util::ready; -use pin_project::pin_project; -use crate::request::HttpRequest; +use crate::{Error, HttpRequest, HttpResponse}; -/// Trait implemented by types that can be converted to a http response. +/// Trait implemented by types that can be converted to an HTTP response. /// -/// Types that implement this trait can be used as the return type of a handler. +/// Any types that implement this trait can be used in the return type of a handler. pub trait Responder { - /// The associated error which can be returned. - type Error: Into; - - /// The future response value. - type Future: Future>; - - /// Convert itself to `AsyncResult` or `Error`. - fn respond_to(self, req: &HttpRequest) -> Self::Future; + /// Convert self to `HttpResponse`. + fn respond_to(self, req: &HttpRequest) -> HttpResponse; /// Override a status code for a Responder. /// /// ```rust - /// use actix_web::{HttpRequest, Responder, http::StatusCode}; + /// use actix_web::{http::StatusCode, HttpRequest, Responder}; /// /// fn index(req: HttpRequest) -> impl Responder { /// "Welcome!".with_status(StatusCode::OK) /// } - /// # fn main() {} /// ``` fn with_status(self, status: StatusCode) -> CustomResponder where @@ -46,7 +32,9 @@ pub trait Responder { CustomResponder::new(self).with_status(status) } - /// Add header to the Responder's response. + /// Insert header to the final response. + /// + /// Overrides other headers with the same name. /// /// ```rust /// use actix_web::{web, HttpRequest, Responder}; @@ -58,47 +46,31 @@ pub trait Responder { /// } /// /// fn index(req: HttpRequest) -> impl Responder { - /// web::Json( - /// MyObj{name: "Name".to_string()} - /// ) - /// .with_header("x-version", "1.2.3") + /// web::Json(MyObj { name: "Name".to_owned() }) + /// .with_header(("x-version", "1.2.3")) /// } - /// # fn main() {} /// ``` - fn with_header(self, key: K, value: V) -> CustomResponder + fn with_header(self, header: H) -> CustomResponder where Self: Sized, - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, + H: IntoHeaderPair, { - CustomResponder::new(self).with_header(key, value) + CustomResponder::new(self).with_header(header) } } -impl Responder for Response { - type Error = Error; - type Future = Ready>; - +impl Responder for HttpResponse { #[inline] - fn respond_to(self, _: &HttpRequest) -> Self::Future { - ok(self) + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + self } } -impl Responder for Option -where - T: Responder, -{ - type Error = T::Error; - type Future = EitherFuture>>; - - fn respond_to(self, req: &HttpRequest) -> Self::Future { +impl Responder for Option { + fn respond_to(self, req: &HttpRequest) -> HttpResponse { match self { - Some(t) => EitherFuture::Left(t.respond_to(req)), - None => { - EitherFuture::Right(ok(Response::build(StatusCode::NOT_FOUND).finish())) - } + Some(t) => t.respond_to(req), + None => HttpResponse::build(StatusCode::NOT_FOUND).finish(), } } } @@ -108,113 +80,78 @@ where T: Responder, E: Into, { - type Error = Error; - type Future = EitherFuture< - ResponseFuture, - Ready>, - >; - - fn respond_to(self, req: &HttpRequest) -> Self::Future { + fn respond_to(self, req: &HttpRequest) -> HttpResponse { match self { - Ok(val) => EitherFuture::Left(ResponseFuture::new(val.respond_to(req))), - Err(e) => EitherFuture::Right(err(e.into())), + Ok(val) => val.respond_to(req), + Err(e) => HttpResponse::from_error(e.into()), } } } impl Responder for ResponseBuilder { - type Error = Error; - type Future = Ready>; - #[inline] - fn respond_to(mut self, _: &HttpRequest) -> Self::Future { - ok(self.finish()) + fn respond_to(mut self, _: &HttpRequest) -> HttpResponse { + self.finish() } } -impl Responder for (T, StatusCode) -where - T: Responder, -{ - type Error = T::Error; - type Future = CustomResponderFut; - - fn respond_to(self, req: &HttpRequest) -> Self::Future { - CustomResponderFut { - fut: self.0.respond_to(req), - status: Some(self.1), - headers: None, - } +impl Responder for (T, StatusCode) { + fn respond_to(self, req: &HttpRequest) -> HttpResponse { + let mut res = self.0.respond_to(req); + *res.status_mut() = self.1; + res } } impl Responder for &'static str { - type Error = Error; - type Future = Ready>; - - fn respond_to(self, _: &HttpRequest) -> Self::Future { - ok(Response::build(StatusCode::OK) - .content_type("text/plain; charset=utf-8") - .body(self)) + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + HttpResponse::Ok() + .content_type(mime::TEXT_PLAIN_UTF_8) + .body(self) } } impl Responder for &'static [u8] { - type Error = Error; - type Future = Ready>; - - fn respond_to(self, _: &HttpRequest) -> Self::Future { - ok(Response::build(StatusCode::OK) - .content_type("application/octet-stream") - .body(self)) + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + HttpResponse::Ok() + .content_type(mime::APPLICATION_OCTET_STREAM) + .body(self) } } impl Responder for String { - type Error = Error; - type Future = Ready>; - - fn respond_to(self, _: &HttpRequest) -> Self::Future { - ok(Response::build(StatusCode::OK) - .content_type("text/plain; charset=utf-8") - .body(self)) + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + HttpResponse::Ok() + .content_type(mime::TEXT_PLAIN_UTF_8) + .body(self) } } impl<'a> Responder for &'a String { - type Error = Error; - type Future = Ready>; - - fn respond_to(self, _: &HttpRequest) -> Self::Future { - ok(Response::build(StatusCode::OK) - .content_type("text/plain; charset=utf-8") - .body(self)) + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + HttpResponse::Ok() + .content_type(mime::TEXT_PLAIN_UTF_8) + .body(self) } } impl Responder for Bytes { - type Error = Error; - type Future = Ready>; - - fn respond_to(self, _: &HttpRequest) -> Self::Future { - ok(Response::build(StatusCode::OK) - .content_type("application/octet-stream") - .body(self)) + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + HttpResponse::Ok() + .content_type(mime::APPLICATION_OCTET_STREAM) + .body(self) } } impl Responder for BytesMut { - type Error = Error; - type Future = Ready>; - - fn respond_to(self, _: &HttpRequest) -> Self::Future { - ok(Response::build(StatusCode::OK) - .content_type("application/octet-stream") - .body(self)) + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + HttpResponse::Ok() + .content_type(mime::APPLICATION_OCTET_STREAM) + .body(self) } } -/// Allows to override status code and headers for a responder. +/// Allows overriding status code and headers for a responder. pub struct CustomResponder { responder: T, status: Option, @@ -240,14 +177,15 @@ impl CustomResponder { /// fn index(req: HttpRequest) -> impl Responder { /// "Welcome!".with_status(StatusCode::OK) /// } - /// # fn main() {} /// ``` pub fn with_status(mut self, status: StatusCode) -> Self { self.status = Some(status); self } - /// Add header to the Responder's response. + /// Insert header to the final response. + /// + /// Overrides other headers with the same name. /// /// ```rust /// use actix_web::{web, HttpRequest, Responder}; @@ -259,117 +197,53 @@ impl CustomResponder { /// } /// /// fn index(req: HttpRequest) -> impl Responder { - /// web::Json( - /// MyObj{name: "Name".to_string()} - /// ) - /// .with_header("x-version", "1.2.3") + /// web::Json(MyObj { name: "Name".to_string() }) + /// .with_header(("x-version", "1.2.3")) + /// .with_header(("x-version", "1.2.3")) /// } - /// # fn main() {} /// ``` - pub fn with_header(mut self, key: K, value: V) -> Self + pub fn with_header(mut self, header: H) -> Self where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, + H: IntoHeaderPair, { if self.headers.is_none() { self.headers = Some(HeaderMap::new()); } - match HeaderName::try_from(key) { - Ok(key) => match value.try_into() { - Ok(value) => { - self.headers.as_mut().unwrap().append(key, value); - } - Err(e) => self.error = Some(e.into()), - }, + match header.try_into_header_pair() { + Ok((key, value)) => self.headers.as_mut().unwrap().append(key, value), Err(e) => self.error = Some(e.into()), }; + self } } impl Responder for CustomResponder { - type Error = T::Error; - type Future = CustomResponderFut; + fn respond_to(self, req: &HttpRequest) -> HttpResponse { + let mut res = self.responder.respond_to(req); - fn respond_to(self, req: &HttpRequest) -> Self::Future { - CustomResponderFut { - fut: self.responder.respond_to(req), - status: self.status, - headers: self.headers, - } - } -} - -#[pin_project] -pub struct CustomResponderFut { - #[pin] - fut: T::Future, - status: Option, - headers: Option, -} - -impl Future for CustomResponderFut { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - let mut res = match ready!(this.fut.poll(cx)) { - Ok(res) => res, - Err(e) => return Poll::Ready(Err(e)), - }; - if let Some(status) = this.status.take() { + if let Some(status) = self.status { *res.status_mut() = status; } - if let Some(ref headers) = this.headers { + + if let Some(ref headers) = self.headers { for (k, v) in headers { + // TODO: before v4, decide if this should be append instead res.headers_mut().insert(k.clone(), v.clone()); } } - Poll::Ready(Ok(res)) + + res } } impl Responder for InternalError where - T: std::fmt::Debug + std::fmt::Display + 'static, + T: fmt::Debug + fmt::Display + 'static, { - type Error = Error; - type Future = Ready>; - - fn respond_to(self, _: &HttpRequest) -> Self::Future { - let err: Error = self.into(); - ok(err.into()) - } -} - -#[pin_project] -pub struct ResponseFuture { - #[pin] - fut: T, - _t: PhantomData, -} - -impl ResponseFuture { - pub fn new(fut: T) -> Self { - ResponseFuture { - fut, - _t: PhantomData, - } - } -} - -impl Future for ResponseFuture -where - T: Future>, - E: Into, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Poll::Ready(ready!(self.project().fut.poll(cx)).map_err(|e| e.into())) + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + HttpResponse::from_error(self.into()) } } @@ -382,15 +256,13 @@ pub(crate) mod tests { use crate::dev::{Body, ResponseBody}; use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; use crate::test::{init_service, TestRequest}; - use crate::{error, web, App, HttpResponse}; + use crate::{error, web, App}; #[actix_rt::test] async fn test_option_responder() { - let mut srv = init_service( + let srv = init_service( App::new() - .service( - web::resource("/none").to(|| async { Option::<&'static str>::None }), - ) + .service(web::resource("/none").to(|| async { Option::<&'static str>::None })) .service(web::resource("/some").to(|| async { Some("some") })), ) .await; @@ -441,7 +313,7 @@ pub(crate) mod tests { async fn test_responder() { let req = TestRequest::default().to_http_request(); - let resp: HttpResponse = "test".respond_to(&req).await.unwrap(); + let resp = "test".respond_to(&req); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -449,7 +321,7 @@ pub(crate) mod tests { HeaderValue::from_static("text/plain; charset=utf-8") ); - let resp: HttpResponse = b"test".respond_to(&req).await.unwrap(); + let resp = b"test".respond_to(&req); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -457,7 +329,7 @@ pub(crate) mod tests { HeaderValue::from_static("application/octet-stream") ); - let resp: HttpResponse = "test".to_string().respond_to(&req).await.unwrap(); + let resp = "test".to_string().respond_to(&req); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -465,7 +337,7 @@ pub(crate) mod tests { HeaderValue::from_static("text/plain; charset=utf-8") ); - let resp: HttpResponse = (&"test".to_string()).respond_to(&req).await.unwrap(); + let resp = (&"test".to_string()).respond_to(&req); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -473,8 +345,7 @@ pub(crate) mod tests { HeaderValue::from_static("text/plain; charset=utf-8") ); - let resp: HttpResponse = - Bytes::from_static(b"test").respond_to(&req).await.unwrap(); + let resp = Bytes::from_static(b"test").respond_to(&req); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -482,10 +353,7 @@ pub(crate) mod tests { HeaderValue::from_static("application/octet-stream") ); - let resp: HttpResponse = BytesMut::from(b"test".as_ref()) - .respond_to(&req) - .await - .unwrap(); + let resp = BytesMut::from(b"test".as_ref()).respond_to(&req); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -494,11 +362,7 @@ pub(crate) mod tests { ); // InternalError - let resp: HttpResponse = - error::InternalError::new("err", StatusCode::BAD_REQUEST) - .respond_to(&req) - .await - .unwrap(); + let resp = error::InternalError::new("err", StatusCode::BAD_REQUEST).respond_to(&req); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } @@ -507,10 +371,7 @@ pub(crate) mod tests { let req = TestRequest::default().to_http_request(); // Result - let resp: HttpResponse = Ok::<_, Error>("test".to_string()) - .respond_to(&req) - .await - .unwrap(); + let resp = Ok::<_, Error>("test".to_string()).respond_to(&req); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( @@ -518,11 +379,10 @@ pub(crate) mod tests { HeaderValue::from_static("text/plain; charset=utf-8") ); - let res = - Err::(error::InternalError::new("err", StatusCode::BAD_REQUEST)) - .respond_to(&req) - .await; - assert!(res.is_err()); + let res = Err::(error::InternalError::new("err", StatusCode::BAD_REQUEST)) + .respond_to(&req); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); } #[actix_rt::test] @@ -531,18 +391,15 @@ pub(crate) mod tests { let res = "test" .to_string() .with_status(StatusCode::BAD_REQUEST) - .respond_to(&req) - .await - .unwrap(); + .respond_to(&req); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); assert_eq!(res.body().bin_ref(), b"test"); let res = "test" .to_string() - .with_header("content-type", "json") - .respond_to(&req) - .await - .unwrap(); + .with_header(("content-type", "json")) + .respond_to(&req); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.body().bin_ref(), b"test"); @@ -555,24 +412,19 @@ pub(crate) mod tests { #[actix_rt::test] async fn test_tuple_responder_with_status_code() { let req = TestRequest::default().to_http_request(); - let res = ("test".to_string(), StatusCode::BAD_REQUEST) - .respond_to(&req) - .await - .unwrap(); + let res = ("test".to_string(), StatusCode::BAD_REQUEST).respond_to(&req); assert_eq!(res.status(), StatusCode::BAD_REQUEST); assert_eq!(res.body().bin_ref(), b"test"); let req = TestRequest::default().to_http_request(); let res = ("test".to_string(), StatusCode::OK) - .with_header("content-type", "json") - .respond_to(&req) - .await - .unwrap(); + .with_header((CONTENT_TYPE, mime::APPLICATION_JSON)) + .respond_to(&req); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.body().bin_ref(), b"test"); assert_eq!( res.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("json") + HeaderValue::from_static("application/json") ); } } diff --git a/src/rmap.rs b/src/rmap.rs index 6827a11b2..3c8805d57 100644 --- a/src/rmap.rs +++ b/src/rmap.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use std::rc::{Rc, Weak}; use actix_router::ResourceDef; -use fxhash::FxHashMap; +use ahash::AHashMap; use url::Url; use crate::error::UrlGenerationError; @@ -12,7 +12,7 @@ use crate::request::HttpRequest; pub struct ResourceMap { root: ResourceDef, parent: RefCell>, - named: FxHashMap, + named: AHashMap, patterns: Vec<(ResourceDef, Option>)>, } @@ -21,7 +21,7 @@ impl ResourceMap { ResourceMap { root, parent: RefCell::new(Weak::new()), - named: FxHashMap::default(), + named: AHashMap::default(), patterns: Vec::new(), } } diff --git a/src/route.rs b/src/route.rs index 00d93fce9..b6b2482cd 100644 --- a/src/route.rs +++ b/src/route.rs @@ -18,7 +18,7 @@ use crate::HttpResponse; type BoxedRouteService = Box< dyn Service< - Request = ServiceRequest, + ServiceRequest, Response = ServiceResponse, Error = Error, Future = LocalBoxFuture<'static, Result>, @@ -27,8 +27,8 @@ type BoxedRouteService = Box< type BoxedRouteNewService = Box< dyn ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -63,9 +63,8 @@ impl Route { } } -impl ServiceFactory for Route { +impl ServiceFactory for Route { type Config = (); - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; type InitError = (); @@ -117,17 +116,16 @@ impl RouteService { } } -impl Service for RouteService { - type Request = ServiceRequest; +impl Service for RouteService { type Response = ServiceResponse; type Error = Error; type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } - fn call(&mut self, req: ServiceRequest) -> Self::Future { + fn call(&self, req: ServiceRequest) -> Self::Future { self.service.call(req) } } @@ -233,41 +231,30 @@ impl Route { struct RouteNewService where - T: ServiceFactory, + T: ServiceFactory, { service: T, } impl RouteNewService where - T: ServiceFactory< - Config = (), - Request = ServiceRequest, - Response = ServiceResponse, - Error = Error, - >, + T: ServiceFactory, T::Future: 'static, T::Service: 'static, - ::Future: 'static, + >::Future: 'static, { pub fn new(service: T) -> Self { RouteNewService { service } } } -impl ServiceFactory for RouteNewService +impl ServiceFactory for RouteNewService where - T: ServiceFactory< - Config = (), - Request = ServiceRequest, - Response = ServiceResponse, - Error = Error, - >, + T: ServiceFactory, T::Future: 'static, T::Service: 'static, - ::Future: 'static, + >::Future: 'static, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; type Config = (); @@ -289,25 +276,24 @@ where } } -struct RouteServiceWrapper { +struct RouteServiceWrapper> { service: T, } -impl Service for RouteServiceWrapper +impl Service for RouteServiceWrapper where T::Future: 'static, - T: Service, + T: Service, { - type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } - fn call(&mut self, req: ServiceRequest) -> Self::Future { + fn call(&self, req: ServiceRequest) -> Self::Future { Box::pin(self.service.call(req)) } } @@ -316,7 +302,7 @@ where mod tests { use std::time::Duration; - use actix_rt::time::delay_for; + use actix_rt::time::sleep; use bytes::Bytes; use serde_derive::Serialize; @@ -331,7 +317,7 @@ mod tests { #[actix_rt::test] async fn test_route() { - let mut srv = init_service( + let srv = init_service( App::new() .service( web::resource("/test") @@ -340,16 +326,16 @@ mod tests { Err::(error::ErrorBadRequest("err")) })) .route(web::post().to(|| async { - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; Ok::<_, ()>(HttpResponse::Created()) })) .route(web::delete().to(|| async { - delay_for(Duration::from_millis(100)).await; + sleep(Duration::from_millis(100)).await; Err::(error::ErrorBadRequest("err")) })), ) .service(web::resource("/json").route(web::get().to(|| async { - delay_for(Duration::from_millis(25)).await; + sleep(Duration::from_millis(25)).await; web::Json(MyObject { name: "test".to_string(), }) @@ -360,35 +346,35 @@ mod tests { let req = TestRequest::with_uri("/test") .method(Method::GET) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test") .method(Method::POST) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::CREATED); let req = TestRequest::with_uri("/test") .method(Method::PUT) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let req = TestRequest::with_uri("/test") .method(Method::DELETE) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let req = TestRequest::with_uri("/test") .method(Method::HEAD) .to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); let req = TestRequest::with_uri("/json").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let body = read_body(resp).await; diff --git a/src/scope.rs b/src/scope.rs index 681d142be..dd02501b0 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -1,17 +1,18 @@ use std::cell::RefCell; use std::fmt; use std::future::Future; -use std::pin::Pin; use std::rc::Rc; -use std::task::{Context, Poll}; +use std::task::Poll; -use actix_http::{Extensions, Response}; -use actix_router::{ResourceDef, ResourceInfo, Router}; +use actix_http::Extensions; +use actix_router::{ResourceDef, Router}; use actix_service::boxed::{self, BoxService, BoxServiceFactory}; use actix_service::{ - apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, Transform, + apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt, + Transform, }; -use futures_util::future::{ok, Either, LocalBoxFuture, Ready}; +use futures_core::future::LocalBoxFuture; +use futures_util::future::join_all; use crate::config::ServiceConfig; use crate::data::Data; @@ -28,7 +29,6 @@ use crate::service::{ type Guards = Vec>; type HttpService = BoxService; type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; -type BoxedResponse = LocalBoxFuture<'static, Result>; /// Resources scope. /// @@ -61,10 +61,10 @@ type BoxedResponse = LocalBoxFuture<'static, Result>; pub struct Scope { endpoint: T, rdef: String, - data: Option, + app_data: Option, services: Vec>, guards: Vec>, - default: Rc>>>, + default: Option>, external: Vec, factory_ref: Rc>>, } @@ -76,10 +76,10 @@ impl Scope { Scope { endpoint: ScopeEndpoint::new(fref.clone()), rdef: path.to_string(), - data: None, + app_data: None, guards: Vec::new(), services: Vec::new(), - default: Rc::new(RefCell::new(None)), + default: None, external: Vec::new(), factory_ref: fref, } @@ -89,8 +89,8 @@ impl Scope { impl Scope where T: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -155,10 +155,10 @@ where /// /// Data of different types from parent contexts will still be accessible. pub fn app_data(mut self, data: U) -> Self { - if self.data.is_none() { - self.data = Some(Extensions::new()); - } - self.data.as_mut().unwrap().insert(data); + self.app_data + .get_or_insert_with(Extensions::new) + .insert(data); + self } @@ -200,26 +200,17 @@ where self.services.extend(cfg.services); self.external.extend(cfg.external); - if !cfg.data.is_empty() { - let mut data = self.data.unwrap_or_else(Extensions::new); - - for value in cfg.data.iter() { - value.create(&mut data); - } - - self.data = Some(data); - } - self.data + self.app_data .get_or_insert_with(Extensions::new) - .extend(cfg.extensions); + .extend(cfg.app_data); self } - /// Register http service. + /// Register HTTP service. /// /// This is similar to `App's` service registration. /// - /// Actix web provides several services implementations: + /// Actix Web provides several services implementations: /// /// * *Resource* is an entry in resource table which corresponds to requested URL. /// * *Scope* is a set of resources with common root path. @@ -285,21 +276,19 @@ where /// If default resource is not registered, app's default resource is being used. pub fn default_service(mut self, f: F) -> Self where - F: IntoServiceFactory, + F: IntoServiceFactory, U: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, > + 'static, U::InitError: fmt::Debug, { // create and configure default resource - self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::factory( - f.into_factory().map_init_err(|e| { - log::error!("Can not construct default service: {:?}", e) - }), - ))))); + self.default = Some(Rc::new(boxed::factory(f.into_factory().map_init_err( + |e| log::error!("Can not construct default service: {:?}", e), + )))); self } @@ -318,8 +307,8 @@ where mw: M, ) -> Scope< impl ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -328,7 +317,7 @@ where where M: Transform< T::Service, - Request = ServiceRequest, + ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -337,7 +326,7 @@ where Scope { endpoint: apply(mw, self.endpoint), rdef: self.rdef, - data: self.data, + app_data: self.app_data, guards: self.guards, services: self.services, default: self.default, @@ -383,21 +372,21 @@ where mw: F, ) -> Scope< impl ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), >, > where - F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, + F: Fn(ServiceRequest, &T::Service) -> R + Clone, R: Future>, { Scope { endpoint: apply_fn_factory(self.endpoint, mw), rdef: self.rdef, - data: self.data, + app_data: self.app_data, guards: self.guards, services: self.services, default: self.default, @@ -410,8 +399,8 @@ where impl HttpServiceFactory for Scope where T: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -419,9 +408,7 @@ where { fn register(mut self, config: &mut AppService) { // update default resource if needed - if self.default.borrow().is_none() { - *self.default.borrow_mut() = Some(config.default_service()); - } + let default = self.default.unwrap_or_else(|| config.default_service()); // register nested services let mut cfg = config.clone_config(); @@ -436,15 +423,10 @@ where rmap.add(&mut rdef, None); } - // custom app data storage - if let Some(ref mut ext) = self.data { - config.set_service_data(ext); - } - // complete scope pipeline creation *self.factory_ref.borrow_mut() = Some(ScopeFactory { - data: self.data.take().map(Rc::new), - default: self.default.clone(), + app_data: self.app_data.take().map(Rc::new), + default, services: cfg .into_services() .1 @@ -476,144 +458,75 @@ where } pub struct ScopeFactory { - data: Option>, + app_data: Option>, services: Rc<[(ResourceDef, HttpNewService, RefCell>)]>, - default: Rc>>>, + default: Rc, } -impl ServiceFactory for ScopeFactory { - type Config = (); - type Request = ServiceRequest; +impl ServiceFactory for ScopeFactory { type Response = ServiceResponse; type Error = Error; - type InitError = (); + type Config = (); type Service = ScopeService; - type Future = ScopeFactoryResponse; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - let default_fut = if let Some(ref default) = *self.default.borrow() { - Some(default.new_service(())) - } else { - None - }; + // construct default service factory future + let default_fut = self.default.new_service(()); - ScopeFactoryResponse { - fut: self - .services - .iter() - .map(|(path, service, guards)| { - CreateScopeServiceItem::Future( - Some(path.clone()), - guards.borrow_mut().take(), - service.new_service(()), - ) - }) - .collect(), - default: None, - data: self.data.clone(), - default_fut, - } - } -} - -/// Create scope service -#[doc(hidden)] -#[pin_project::pin_project] -pub struct ScopeFactoryResponse { - fut: Vec, - data: Option>, - default: Option, - default_fut: Option>>, -} - -type HttpServiceFut = LocalBoxFuture<'static, Result>; - -enum CreateScopeServiceItem { - Future(Option, Option, HttpServiceFut), - Service(ResourceDef, Option, HttpService), -} - -impl Future for ScopeFactoryResponse { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut done = true; - - if let Some(ref mut fut) = self.default_fut { - match Pin::new(fut).poll(cx)? { - Poll::Ready(default) => self.default = Some(default), - Poll::Pending => done = false, + // construct all services factory future with it's resource def and guards. + let factory_fut = join_all(self.services.iter().map(|(path, factory, guards)| { + let path = path.clone(); + let guards = guards.borrow_mut().take(); + let factory_fut = factory.new_service(()); + async move { + let service = factory_fut.await?; + Ok((path, guards, service)) } - } + })); - // poll http services - for item in &mut self.fut { - let res = match item { - CreateScopeServiceItem::Future( - ref mut path, - ref mut guards, - ref mut fut, - ) => match Pin::new(fut).poll(cx)? { - Poll::Ready(service) => { - Some((path.take().unwrap(), guards.take(), service)) - } - Poll::Pending => { - done = false; - None - } - }, - CreateScopeServiceItem::Service(_, _, _) => continue, - }; + let app_data = self.app_data.clone(); - if let Some((path, guards, service)) = res { - *item = CreateScopeServiceItem::Service(path, guards, service); - } - } + Box::pin(async move { + let default = default_fut.await?; - if done { - let router = self - .fut + // build router from the factory future result. + let router = factory_fut + .await + .into_iter() + .collect::, _>>()? .drain(..) - .fold(Router::build(), |mut router, item| { - match item { - CreateScopeServiceItem::Service(path, guards, service) => { - router.rdef(path, service).2 = guards; - } - CreateScopeServiceItem::Future(_, _, _) => unreachable!(), - } + .fold(Router::build(), |mut router, (path, guards, service)| { + router.rdef(path, service).2 = guards; router - }); - Poll::Ready(Ok(ScopeService { - data: self.data.clone(), - router: router.finish(), - default: self.default.take(), - _ready: None, - })) - } else { - Poll::Pending - } + }) + .finish(); + + Ok(ScopeService { + app_data, + router, + default, + }) + }) } } pub struct ScopeService { - data: Option>, + app_data: Option>, router: Router>>, - default: Option, - _ready: Option<(ServiceRequest, ResourceInfo)>, + default: HttpService, } -impl Service for ScopeService { - type Request = ServiceRequest; +impl Service for ScopeService { type Response = ServiceResponse; type Error = Error; - type Future = Either>>; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } + actix_service::always_ready!(); - fn call(&mut self, mut req: ServiceRequest) -> Self::Future { - let res = self.router.recognize_mut_checked(&mut req, |req, guards| { + fn call(&self, mut req: ServiceRequest) -> Self::Future { + let res = self.router.recognize_checked(&mut req, |req, guards| { if let Some(ref guards) = guards { for f in guards { if !f.check(req.head()) { @@ -624,19 +537,14 @@ impl Service for ScopeService { true }); + if let Some(ref app_data) = self.app_data { + req.add_data_container(app_data.clone()); + } + if let Some((srv, _info)) = res { - if let Some(ref data) = self.data { - req.add_data_container(data.clone()); - } - Either::Left(srv.call(req)) - } else if let Some(ref mut default) = self.default { - if let Some(ref data) = self.data { - req.add_data_container(data.clone()); - } - Either::Left(default.call(req)) + srv.call(req) } else { - let req = req.into_parts().0; - Either::Right(ok(ServiceResponse::new(req, Response::NotFound().finish()))) + self.default.call(req) } } } @@ -652,14 +560,13 @@ impl ScopeEndpoint { } } -impl ServiceFactory for ScopeEndpoint { - type Config = (); - type Request = ServiceRequest; +impl ServiceFactory for ScopeEndpoint { type Response = ServiceResponse; type Error = Error; - type InitError = (); + type Config = (); type Service = ScopeService; - type Future = ScopeFactoryResponse; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { self.factory.borrow_mut().as_mut().unwrap().new_service(()) @@ -681,10 +588,11 @@ mod tests { #[actix_rt::test] async fn test_scope() { - let mut srv = init_service(App::new().service( - web::scope("/app").service(web::resource("/path1").to(HttpResponse::Ok)), - )) - .await; + let srv = + init_service(App::new().service( + web::scope("/app").service(web::resource("/path1").to(HttpResponse::Ok)), + )) + .await; let req = TestRequest::with_uri("/app/path1").to_request(); let resp = srv.call(req).await.unwrap(); @@ -693,7 +601,7 @@ mod tests { #[actix_rt::test] async fn test_scope_root() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::scope("/app") .service(web::resource("").to(HttpResponse::Ok)) @@ -713,9 +621,10 @@ mod tests { #[actix_rt::test] async fn test_scope_root2() { - let mut srv = init_service(App::new().service( - web::scope("/app/").service(web::resource("").to(HttpResponse::Ok)), - )) + let srv = init_service( + App::new() + .service(web::scope("/app/").service(web::resource("").to(HttpResponse::Ok))), + ) .await; let req = TestRequest::with_uri("/app").to_request(); @@ -729,9 +638,10 @@ mod tests { #[actix_rt::test] async fn test_scope_root3() { - let mut srv = init_service(App::new().service( - web::scope("/app/").service(web::resource("/").to(HttpResponse::Ok)), - )) + let srv = init_service( + App::new() + .service(web::scope("/app/").service(web::resource("/").to(HttpResponse::Ok))), + ) .await; let req = TestRequest::with_uri("/app").to_request(); @@ -745,7 +655,7 @@ mod tests { #[actix_rt::test] async fn test_scope_route() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::scope("app") .route("/path1", web::get().to(HttpResponse::Ok)) @@ -773,7 +683,7 @@ mod tests { #[actix_rt::test] async fn test_scope_route_without_leading_slash() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::scope("app").service( web::resource("path1") @@ -803,7 +713,7 @@ mod tests { #[actix_rt::test] async fn test_scope_guard() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::scope("/app") .guard(guard::Get()) @@ -827,14 +737,12 @@ mod tests { #[actix_rt::test] async fn test_scope_variable_segment() { - let mut srv = - init_service(App::new().service(web::scope("/ab-{project}").service( - web::resource("/path1").to(|r: HttpRequest| { - HttpResponse::Ok() - .body(format!("project: {}", &r.match_info()["project"])) - }), - ))) - .await; + let srv = init_service(App::new().service(web::scope("/ab-{project}").service( + web::resource("/path1").to(|r: HttpRequest| { + HttpResponse::Ok().body(format!("project: {}", &r.match_info()["project"])) + }), + ))) + .await; let req = TestRequest::with_uri("/ab-project1/path1").to_request(); let resp = srv.call(req).await.unwrap(); @@ -855,7 +763,7 @@ mod tests { #[actix_rt::test] async fn test_nested_scope() { - let mut srv = init_service(App::new().service(web::scope("/app").service( + let srv = init_service(App::new().service(web::scope("/app").service( web::scope("/t1").service(web::resource("/path1").to(HttpResponse::Created)), ))) .await; @@ -867,7 +775,7 @@ mod tests { #[actix_rt::test] async fn test_nested_scope_no_slash() { - let mut srv = init_service(App::new().service(web::scope("/app").service( + let srv = init_service(App::new().service(web::scope("/app").service( web::scope("t1").service(web::resource("/path1").to(HttpResponse::Created)), ))) .await; @@ -879,7 +787,7 @@ mod tests { #[actix_rt::test] async fn test_nested_scope_root() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::scope("/app").service( web::scope("/t1") @@ -901,7 +809,7 @@ mod tests { #[actix_rt::test] async fn test_nested_scope_filter() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::scope("/app").service( web::scope("/t1") @@ -927,7 +835,7 @@ mod tests { #[actix_rt::test] async fn test_nested_scope_with_variable_segment() { - let mut srv = init_service(App::new().service(web::scope("/app").service( + let srv = init_service(App::new().service(web::scope("/app").service( web::scope("/{project_id}").service(web::resource("/path1").to( |r: HttpRequest| { HttpResponse::Created() @@ -952,7 +860,7 @@ mod tests { #[actix_rt::test] async fn test_nested2_scope_with_variable_segment() { - let mut srv = init_service(App::new().service(web::scope("/app").service( + let srv = init_service(App::new().service(web::scope("/app").service( web::scope("/{project}").service(web::scope("/{id}").service( web::resource("/path1").to(|r: HttpRequest| { HttpResponse::Created().body(format!( @@ -984,7 +892,7 @@ mod tests { #[actix_rt::test] async fn test_default_resource() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::scope("/app") .service(web::resource("/path1").to(HttpResponse::Ok)) @@ -1006,7 +914,7 @@ mod tests { #[actix_rt::test] async fn test_default_resource_propagation() { - let mut srv = init_service( + let srv = init_service( App::new() .service( web::scope("/app1") @@ -1034,24 +942,20 @@ mod tests { #[actix_rt::test] async fn test_middleware() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::scope("app") .wrap( - DefaultHeaders::new().header( - header::CONTENT_TYPE, - HeaderValue::from_static("0001"), - ), + DefaultHeaders::new() + .header(header::CONTENT_TYPE, HeaderValue::from_static("0001")), ) - .service( - web::resource("/test").route(web::get().to(HttpResponse::Ok)), - ), + .service(web::resource("/test").route(web::get().to(HttpResponse::Ok))), ), ) .await; let req = TestRequest::with_uri("/app/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -1061,17 +965,15 @@ mod tests { #[actix_rt::test] async fn test_middleware_fn() { - let mut srv = init_service( + let srv = init_service( App::new().service( web::scope("app") .wrap_fn(|req, srv| { let fut = srv.call(req); async move { let mut res = fut.await?; - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("0001"), - ); + res.headers_mut() + .insert(header::CONTENT_TYPE, HeaderValue::from_static("0001")); Ok(res) } }) @@ -1081,7 +983,7 @@ mod tests { .await; let req = TestRequest::with_uri("/app/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); assert_eq!( resp.headers().get(header::CONTENT_TYPE).unwrap(), @@ -1091,7 +993,7 @@ mod tests { #[actix_rt::test] async fn test_override_data() { - let mut srv = init_service(App::new().data(1usize).service( + let srv = init_service(App::new().data(1usize).service( web::scope("app").data(10usize).route( "/t", web::get().to(|data: web::Data| { @@ -1103,13 +1005,13 @@ mod tests { .await; let req = TestRequest::with_uri("/app/t").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_rt::test] async fn test_override_data_default_service() { - let mut srv = init_service(App::new().data(1usize).service( + let srv = init_service(App::new().data(1usize).service( web::scope("app").data(10usize).default_service(web::to( |data: web::Data| { assert_eq!(**data, 10); @@ -1120,13 +1022,13 @@ mod tests { .await; let req = TestRequest::with_uri("/app/t").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_rt::test] async fn test_override_app_data() { - let mut srv = init_service(App::new().app_data(web::Data::new(1usize)).service( + let srv = init_service(App::new().app_data(web::Data::new(1usize)).service( web::scope("app").app_data(web::Data::new(10usize)).route( "/t", web::get().to(|data: web::Data| { @@ -1138,17 +1040,16 @@ mod tests { .await; let req = TestRequest::with_uri("/app/t").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } #[actix_rt::test] async fn test_scope_config() { - let mut srv = - init_service(App::new().service(web::scope("/app").configure(|s| { - s.route("/path1", web::get().to(HttpResponse::Ok)); - }))) - .await; + let srv = init_service(App::new().service(web::scope("/app").configure(|s| { + s.route("/path1", web::get().to(HttpResponse::Ok)); + }))) + .await; let req = TestRequest::with_uri("/app/path1").to_request(); let resp = srv.call(req).await.unwrap(); @@ -1157,13 +1058,12 @@ mod tests { #[actix_rt::test] async fn test_scope_config_2() { - let mut srv = - init_service(App::new().service(web::scope("/app").configure(|s| { - s.service(web::scope("/v1").configure(|s| { - s.route("/", web::get().to(HttpResponse::Ok)); - })); - }))) - .await; + let srv = init_service(App::new().service(web::scope("/app").configure(|s| { + s.service(web::scope("/v1").configure(|s| { + s.route("/", web::get().to(HttpResponse::Ok)); + })); + }))) + .await; let req = TestRequest::with_uri("/app/v1/").to_request(); let resp = srv.call(req).await.unwrap(); @@ -1172,24 +1072,19 @@ mod tests { #[actix_rt::test] async fn test_url_for_external() { - let mut srv = - init_service(App::new().service(web::scope("/app").configure(|s| { - s.service(web::scope("/v1").configure(|s| { - s.external_resource( - "youtube", - "https://youtube.com/watch/{video_id}", - ); - s.route( - "/", - web::get().to(|req: HttpRequest| { - HttpResponse::Ok().body( - req.url_for("youtube", &["xxxxxx"]).unwrap().to_string(), - ) - }), - ); - })); - }))) - .await; + let srv = init_service(App::new().service(web::scope("/app").configure(|s| { + s.service(web::scope("/v1").configure(|s| { + s.external_resource("youtube", "https://youtube.com/watch/{video_id}"); + s.route( + "/", + web::get().to(|req: HttpRequest| { + HttpResponse::Ok() + .body(req.url_for("youtube", &["xxxxxx"]).unwrap().to_string()) + }), + ); + })); + }))) + .await; let req = TestRequest::with_uri("/app/v1/").to_request(); let resp = srv.call(req).await.unwrap(); @@ -1200,7 +1095,7 @@ mod tests { #[actix_rt::test] async fn test_url_for_nested() { - let mut srv = init_service(App::new().service(web::scope("/a").service( + let srv = init_service(App::new().service(web::scope("/a").service( web::scope("/b").service(web::resource("/c/{stuff}").name("c").route( web::get().to(|req: HttpRequest| { HttpResponse::Ok() @@ -1211,7 +1106,7 @@ mod tests { .await; let req = TestRequest::with_uri("/a/b/c/test").to_request(); - let resp = call_service(&mut srv, req).await; + let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); let body = read_body(resp).await; assert_eq!( diff --git a/src/server.rs b/src/server.rs index be97e8a0d..d69d6570d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,6 +1,6 @@ use std::{ any::Any, - fmt, io, + cmp, fmt, io, marker::PhantomData, net, sync::{Arc, Mutex}, @@ -20,9 +20,9 @@ use actix_service::pipeline_factory; use futures_util::future::ok; #[cfg(feature = "openssl")] -use actix_tls::openssl::{AlpnError, SslAcceptor, SslAcceptorBuilder}; +use actix_tls::accept::openssl::{AlpnError, SslAcceptor, SslAcceptorBuilder}; #[cfg(feature = "rustls")] -use actix_tls::rustls::ServerConfig as RustlsServerConfig; +use actix_tls::accept::rustls::ServerConfig as RustlsServerConfig; use crate::config::AppConfig; @@ -40,7 +40,7 @@ struct Config { /// An HTTP Server. /// -/// Create new http server with application factory. +/// Create new HTTP server with application factory. /// /// ```rust,no_run /// use actix_web::{web, App, HttpResponse, HttpServer}; @@ -58,8 +58,8 @@ struct Config { pub struct HttpServer where F: Fn() -> I + Send + Clone + 'static, - I: IntoServiceFactory, - S: ServiceFactory, + I: IntoServiceFactory, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, @@ -67,25 +67,26 @@ where { pub(super) factory: F, config: Arc>, - backlog: i32, + backlog: u32, sockets: Vec, builder: ServerBuilder, on_connect_fn: Option>, - _t: PhantomData<(S, B)>, + _phantom: PhantomData<(S, B)>, } impl HttpServer where F: Fn() -> I + Send + Clone + 'static, - I: IntoServiceFactory, - S: ServiceFactory, + I: IntoServiceFactory, + S: ServiceFactory + 'static, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, + S::Service: 'static, B: MessageBody + 'static, { - /// Create new http server with application factory + /// Create new HTTP server with application factory pub fn new(factory: F) -> Self { HttpServer { factory, @@ -99,7 +100,7 @@ where sockets: Vec::new(), builder: ServerBuilder::default(), on_connect_fn: None, - _t: PhantomData, + _phantom: PhantomData, } } @@ -124,14 +125,13 @@ where sockets: self.sockets, builder: self.builder, on_connect_fn: Some(Arc::new(f)), - _t: PhantomData, + _phantom: PhantomData, } } /// Set number of workers to start. /// - /// By default http server uses number of available logical cpu as threads - /// count. + /// By default, server uses number of available logical CPU as thread count. pub fn workers(mut self, num: usize) -> Self { self.builder = self.builder.workers(num); self @@ -147,7 +147,7 @@ where /// Generally set in the 64-2048 range. Default value is 2048. /// /// This method should be called before `bind()` method call. - pub fn backlog(mut self, backlog: i32) -> Self { + pub fn backlog(mut self, backlog: u32) -> Self { self.backlog = backlog; self.builder = self.builder.backlog(backlog); self @@ -170,8 +170,10 @@ where /// limit the global TLS CPU usage. /// /// By default max connections is set to a 256. + #[allow(unused_variables)] pub fn max_connection_rate(self, num: usize) -> Self { - actix_tls::max_concurrent_tls_connect(num); + #[cfg(any(feature = "rustls", feature = "openssl"))] + actix_tls::accept::max_concurrent_tls_connect(num); self } @@ -254,7 +256,7 @@ where /// Get addresses of bound sockets and the scheme for it. /// /// This is useful when the server is bound from different sources - /// with some sockets listening on http and some listening on https + /// with some sockets listening on HTTP and some listening on HTTPS /// and the user should be presented with an enumeration of which /// socket requires which protocol. pub fn addrs_with_scheme(&self) -> Vec<(net::SocketAddr, &str)> { @@ -275,34 +277,28 @@ where }); let on_connect_fn = self.on_connect_fn.clone(); - self.builder = self.builder.listen( - format!("actix-web-service-{}", addr), - lst, - move || { - let c = cfg.lock().unwrap(); - let cfg = AppConfig::new( - false, - addr, - c.host.clone().unwrap_or_else(|| format!("{}", addr)), - ); + self.builder = + self.builder + .listen(format!("actix-web-service-{}", addr), lst, move || { + let c = cfg.lock().unwrap(); + let host = c.host.clone().unwrap_or_else(|| format!("{}", addr)); - let svc = HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .local_addr(addr); + let svc = HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .local_addr(addr); - let svc = if let Some(handler) = on_connect_fn.clone() { - svc.on_connect_ext(move |io: &_, ext: _| { - (handler)(io as &dyn Any, ext) - }) - } else { - svc - }; + let svc = if let Some(handler) = on_connect_fn.clone() { + svc.on_connect_ext(move |io: &_, ext: _| (handler)(io as &dyn Any, ext)) + } else { + svc + }; - svc.finish(map_config(factory(), move |_| cfg.clone())) + svc.finish(map_config(factory(), move |_| { + AppConfig::new(false, addr, host.clone()) + })) .tcp() - }, - )?; + })?; Ok(self) } @@ -334,34 +330,30 @@ where let on_connect_fn = self.on_connect_fn.clone(); - self.builder = self.builder.listen( - format!("actix-web-service-{}", addr), - lst, - move || { - let c = cfg.lock().unwrap(); - let cfg = AppConfig::new( - true, - addr, - c.host.clone().unwrap_or_else(|| format!("{}", addr)), - ); + self.builder = + self.builder + .listen(format!("actix-web-service-{}", addr), lst, move || { + let c = cfg.lock().unwrap(); + let host = c.host.clone().unwrap_or_else(|| format!("{}", addr)); - let svc = HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .client_disconnect(c.client_shutdown); + let svc = HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .client_disconnect(c.client_shutdown); - let svc = if let Some(handler) = on_connect_fn.clone() { - svc.on_connect_ext(move |io: &_, ext: _| { - (&*handler)(io as &dyn Any, ext) - }) - } else { - svc - }; + let svc = if let Some(handler) = on_connect_fn.clone() { + svc.on_connect_ext(move |io: &_, ext: _| { + (&*handler)(io as &dyn Any, ext) + }) + } else { + svc + }; - svc.finish(map_config(factory(), move |_| cfg.clone())) + svc.finish(map_config(factory(), move |_| { + AppConfig::new(true, addr, host.clone()) + })) .openssl(acceptor.clone()) - }, - )?; + })?; Ok(self) } @@ -393,34 +385,28 @@ where let on_connect_fn = self.on_connect_fn.clone(); - self.builder = self.builder.listen( - format!("actix-web-service-{}", addr), - lst, - move || { - let c = cfg.lock().unwrap(); - let cfg = AppConfig::new( - true, - addr, - c.host.clone().unwrap_or_else(|| format!("{}", addr)), - ); + self.builder = + self.builder + .listen(format!("actix-web-service-{}", addr), lst, move || { + let c = cfg.lock().unwrap(); + let host = c.host.clone().unwrap_or_else(|| format!("{}", addr)); - let svc = HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .client_disconnect(c.client_shutdown); + let svc = HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .client_disconnect(c.client_shutdown); - let svc = if let Some(handler) = on_connect_fn.clone() { - svc.on_connect_ext(move |io: &_, ext: _| { - (handler)(io as &dyn Any, ext) - }) - } else { - svc - }; + let svc = if let Some(handler) = on_connect_fn.clone() { + svc.on_connect_ext(move |io: &_, ext: _| (handler)(io as &dyn Any, ext)) + } else { + svc + }; - svc.finish(map_config(factory(), move |_| cfg.clone())) + svc.finish(map_config(factory(), move |_| { + AppConfig::new(true, addr, host.clone()) + })) .rustls(config.clone()) - }, - )?; + })?; Ok(self) } @@ -437,10 +423,7 @@ where Ok(self) } - fn bind2( - &self, - addr: A, - ) -> io::Result> { + fn bind2(&self, addr: A) -> io::Result> { let mut err = None; let mut success = false; let mut sockets = Vec::new(); @@ -473,11 +456,7 @@ where /// Start listening for incoming tls connections. /// /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn bind_openssl
( - mut self, - addr: A, - builder: SslAcceptorBuilder, - ) -> io::Result + pub fn bind_openssl(mut self, addr: A, builder: SslAcceptorBuilder) -> io::Result where A: net::ToSocketAddrs, { @@ -509,18 +488,13 @@ where #[cfg(unix)] /// Start listening for unix domain (UDS) connections on existing listener. - pub fn listen_uds( - mut self, - lst: std::os::unix::net::UnixListener, - ) -> io::Result { + pub fn listen_uds(mut self, lst: std::os::unix::net::UnixListener) -> io::Result { use actix_rt::net::UnixStream; let cfg = self.config.clone(); let factory = self.factory.clone(); - let socket_addr = net::SocketAddr::new( - net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), - 8080, - ); + let socket_addr = + net::SocketAddr::new(net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), 8080); self.sockets.push(Socket { scheme: "http", addr: socket_addr, @@ -537,23 +511,19 @@ where c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), ); - pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then( - { - let svc = HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout); + pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then({ + let svc = HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout); - let svc = if let Some(handler) = on_connect_fn.clone() { - svc.on_connect_ext(move |io: &_, ext: _| { - (&*handler)(io as &dyn Any, ext) - }) - } else { - svc - }; + let svc = if let Some(handler) = on_connect_fn.clone() { + svc.on_connect_ext(move |io: &_, ext: _| (&*handler)(io as &dyn Any, ext)) + } else { + svc + }; - svc.finish(map_config(factory(), move |_| config.clone())) - }, - ) + svc.finish(map_config(factory(), move |_| config.clone())) + }) })?; Ok(self) } @@ -568,10 +538,8 @@ where let cfg = self.config.clone(); let factory = self.factory.clone(); - let socket_addr = net::SocketAddr::new( - net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), - 8080, - ); + let socket_addr = + net::SocketAddr::new(net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), 8080); self.sockets.push(Socket { scheme: "http", addr: socket_addr, @@ -587,13 +555,12 @@ where socket_addr, c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), ); - pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))) - .and_then( - HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .finish(map_config(factory(), move |_| config.clone())), - ) + pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then( + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .finish(map_config(factory(), move |_| config.clone())), + ) }, )?; Ok(self) @@ -603,8 +570,8 @@ where impl HttpServer where F: Fn() -> I + Send + Clone + 'static, - I: IntoServiceFactory, - S: ServiceFactory, + I: IntoServiceFactory, + S: ServiceFactory, S::Error: Into, S::InitError: fmt::Debug, S::Response: Into>, @@ -613,7 +580,7 @@ where { /// Start listening for incoming connections. /// - /// This method starts number of http workers in separate threads. + /// This method starts number of HTTP workers in separate threads. /// For each address this method starts separate thread which does /// `accept()` in a loop. /// @@ -633,14 +600,11 @@ where /// } /// ``` pub fn run(self) -> Server { - self.builder.start() + self.builder.run() } } -fn create_tcp_listener( - addr: net::SocketAddr, - backlog: i32, -) -> io::Result { +fn create_tcp_listener(addr: net::SocketAddr, backlog: u32) -> io::Result { use socket2::{Domain, Protocol, Socket, Type}; let domain = match addr { net::SocketAddr::V4(_) => Domain::ipv4(), @@ -649,6 +613,8 @@ fn create_tcp_listener( let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))?; socket.set_reuse_address(true)?; socket.bind(&addr.into())?; + // clamp backlog to max u32 that fits in i32 range + let backlog = cmp::min(backlog, i32::MAX as u32) as i32; socket.listen(backlog)?; Ok(socket.into_tcp_listener()) } diff --git a/src/service.rs b/src/service.rs index 189ba5554..fcbe61a02 100644 --- a/src/service.rs +++ b/src/service.rs @@ -5,8 +5,7 @@ use std::{fmt, net}; use actix_http::body::{Body, MessageBody, ResponseBody}; use actix_http::http::{HeaderMap, Method, StatusCode, Uri, Version}; use actix_http::{ - Error, Extensions, HttpMessage, Payload, PayloadStream, RequestHead, Response, - ResponseHead, + Error, Extensions, HttpMessage, Payload, PayloadStream, RequestHead, Response, ResponseHead, }; use actix_router::{IntoPattern, Path, Resource, ResourceDef, Url}; use actix_service::{IntoServiceFactory, ServiceFactory}; @@ -22,6 +21,13 @@ pub trait HttpServiceFactory { fn register(self, config: &mut AppService); } +impl HttpServiceFactory for Vec { + fn register(self, config: &mut AppService) { + self.into_iter() + .for_each(|factory| factory.register(config)); + } +} + pub(crate) trait AppServiceFactory { fn register(&mut self, config: &mut AppService); } @@ -52,71 +58,61 @@ where /// An service http request /// /// ServiceRequest allows mutable access to request's internal structures -pub struct ServiceRequest(HttpRequest); +pub struct ServiceRequest { + req: HttpRequest, + payload: Payload, +} impl ServiceRequest { /// Construct service request - pub(crate) fn new(req: HttpRequest) -> Self { - ServiceRequest(req) + pub(crate) fn new(req: HttpRequest, payload: Payload) -> Self { + Self { req, payload } } /// Deconstruct request into parts - pub fn into_parts(mut self) -> (HttpRequest, Payload) { - let pl = Rc::get_mut(&mut (self.0).0).unwrap().payload.take(); - (self.0, pl) + #[inline] + pub fn into_parts(self) -> (HttpRequest, Payload) { + (self.req, self.payload) } /// Construct request from parts. - /// - /// `ServiceRequest` can be re-constructed only if `req` hasn't been cloned. - pub fn from_parts( - mut req: HttpRequest, - pl: Payload, - ) -> Result { - if Rc::strong_count(&req.0) == 1 && Rc::weak_count(&req.0) == 0 { - Rc::get_mut(&mut req.0).unwrap().payload = pl; - Ok(ServiceRequest(req)) - } else { - Err((req, pl)) - } + pub fn from_parts(req: HttpRequest, payload: Payload) -> Self { + Self { req, payload } } /// Construct request from request. /// - /// `HttpRequest` implements `Clone` trait via `Rc` type. `ServiceRequest` - /// can be re-constructed only if rc's strong pointers count eq 1 and - /// weak pointers count is 0. - pub fn from_request(req: HttpRequest) -> Result { - if Rc::strong_count(&req.0) == 1 && Rc::weak_count(&req.0) == 0 { - Ok(ServiceRequest(req)) - } else { - Err(req) + /// The returned `ServiceRequest` would have no payload. + pub fn from_request(req: HttpRequest) -> Self { + ServiceRequest { + req, + payload: Payload::None, } } /// Create service response #[inline] pub fn into_response>>(self, res: R) -> ServiceResponse { - ServiceResponse::new(self.0, res.into()) + ServiceResponse::new(self.req, res.into()) } /// Create service response for error #[inline] pub fn error_response>(self, err: E) -> ServiceResponse { let res: Response = err.into().into(); - ServiceResponse::new(self.0, res.into_body()) + ServiceResponse::new(self.req, res.into_body()) } /// This method returns reference to the request head #[inline] pub fn head(&self) -> &RequestHead { - &self.0.head() + &self.req.head() } /// This method returns reference to the request head #[inline] pub fn head_mut(&mut self) -> &mut RequestHead { - self.0.head_mut() + self.req.head_mut() } /// Request's uri. @@ -167,12 +163,14 @@ impl ServiceRequest { } } - /// Peer socket address + /// Peer socket address. /// - /// Peer address is actual socket address, if proxy is used in front of - /// actix http server, then peer address would be address of this proxy. + /// Peer address is the directly connected peer's socket address. If a proxy is used in front of + /// the Actix Web server, then it would be address of this proxy. /// /// To get client connection information `ConnectionInfo` should be used. + /// + /// Will only return None when called in unit tests. #[inline] pub fn peer_addr(&self) -> Option { self.head().peer_addr @@ -192,42 +190,42 @@ impl ServiceRequest { /// access the matched value for that segment. #[inline] pub fn match_info(&self) -> &Path { - self.0.match_info() + self.req.match_info() } /// Counterpart to [`HttpRequest::match_name`](super::HttpRequest::match_name()). #[inline] pub fn match_name(&self) -> Option<&str> { - self.0.match_name() + self.req.match_name() } /// Counterpart to [`HttpRequest::match_pattern`](super::HttpRequest::match_pattern()). #[inline] pub fn match_pattern(&self) -> Option { - self.0.match_pattern() + self.req.match_pattern() } #[inline] /// Get a mutable reference to the Path parameters. pub fn match_info_mut(&mut self) -> &mut Path { - self.0.match_info_mut() + self.req.match_info_mut() } #[inline] /// Get a reference to a `ResourceMap` of current application. pub fn resource_map(&self) -> &ResourceMap { - self.0.resource_map() + self.req.resource_map() } /// Service configuration #[inline] pub fn app_config(&self) -> &AppConfig { - self.0.app_config() + self.req.app_config() } /// Counterpart to [`HttpRequest::app_data`](super::HttpRequest::app_data()). pub fn app_data(&self) -> Option<&T> { - for container in (self.0).0.app_data.iter().rev() { + for container in self.req.inner.app_data.iter().rev() { if let Some(data) = container.get::() { return Some(data); } @@ -238,13 +236,15 @@ impl ServiceRequest { /// Set request payload. pub fn set_payload(&mut self, payload: Payload) { - Rc::get_mut(&mut (self.0).0).unwrap().payload = payload; + self.payload = payload; } - #[doc(hidden)] - /// Add app data container to request's resolution set. + /// Add data container to request's resolution set. + /// + /// In middleware, prefer [`extensions_mut`](ServiceRequest::extensions_mut) for request-local + /// data since it is assumed that the same app data is presented for every request. pub fn add_data_container(&mut self, extensions: Rc) { - Rc::get_mut(&mut (self.0).0) + Rc::get_mut(&mut (self.req).inner) .unwrap() .app_data .push(extensions); @@ -269,18 +269,18 @@ impl HttpMessage for ServiceRequest { /// Request extensions #[inline] fn extensions(&self) -> Ref<'_, Extensions> { - self.0.extensions() + self.req.extensions() } /// Mutable reference to a the request's extensions #[inline] fn extensions_mut(&self) -> RefMut<'_, Extensions> { - self.0.extensions_mut() + self.req.extensions_mut() } #[inline] fn take_payload(&mut self) -> Payload { - Rc::get_mut(&mut (self.0).0).unwrap().payload.take() + self.payload.take() } } @@ -412,9 +412,9 @@ impl ServiceResponse { } } -impl Into> for ServiceResponse { - fn into(self) -> Response { - self.response +impl From> for Response { + fn from(res: ServiceResponse) -> Response { + res.response } } @@ -486,10 +486,10 @@ impl WebService { /// Set a service factory implementation and generate web service. pub fn finish(self, service: F) -> impl HttpServiceFactory where - F: IntoServiceFactory, + F: IntoServiceFactory, T: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -514,8 +514,8 @@ struct WebServiceImpl { impl HttpServiceFactory for WebServiceImpl where T: ServiceFactory< + ServiceRequest, Config = (), - Request = ServiceRequest, Response = ServiceResponse, Error = Error, InitError = (), @@ -540,6 +540,65 @@ where } } +/// Macro helping register different types of services at the sametime. +/// +/// The service type must be implementing [`HttpServiceFactory`](self::HttpServiceFactory) trait. +/// +/// The max number of services can be grouped together is 12. +/// +/// # Examples +/// +/// ``` +/// use actix_web::{services, web, App}; +/// +/// let services = services![ +/// web::resource("/test2").to(|| async { "test2" }), +/// web::scope("/test3").route("/", web::get().to(|| async { "test3" })) +/// ]; +/// +/// let app = App::new().service(services); +/// +/// // services macro just convert multiple services to a tuple. +/// // below would also work without importing the macro. +/// let app = App::new().service(( +/// web::resource("/test2").to(|| async { "test2" }), +/// web::scope("/test3").route("/", web::get().to(|| async { "test3" })) +/// )); +/// ``` +#[macro_export] +macro_rules! services { + ($($x:expr),+ $(,)?) => { + ($($x,)+) + } +} + +/// HttpServiceFactory trait impl for tuples +macro_rules! service_tuple ({ $(($n:tt, $T:ident)),+} => { + impl<$($T: HttpServiceFactory),+> HttpServiceFactory for ($($T,)+) { + fn register(self, config: &mut AppService) { + $(self.$n.register(config);)+ + } + } +}); + +#[rustfmt::skip] +mod m { + use super::*; + + service_tuple!((0, A)); + service_tuple!((0, A), (1, B)); + service_tuple!((0, A), (1, B), (2, C)); + service_tuple!((0, A), (1, B), (2, C), (3, D)); + service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E)); + service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F)); + service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G)); + service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H)); + service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I)); + service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I), (9, J)); + service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I), (9, J), (10, K)); + service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I), (9, J), (10, K), (11, L)); +} + #[cfg(test)] mod tests { use super::*; @@ -548,30 +607,9 @@ mod tests { use actix_service::Service; use futures_util::future::ok; - #[test] - fn test_service_request() { - let req = TestRequest::default().to_srv_request(); - let (r, pl) = req.into_parts(); - assert!(ServiceRequest::from_parts(r, pl).is_ok()); - - let req = TestRequest::default().to_srv_request(); - let (r, pl) = req.into_parts(); - let _r2 = r.clone(); - assert!(ServiceRequest::from_parts(r, pl).is_err()); - - let req = TestRequest::default().to_srv_request(); - let (r, _pl) = req.into_parts(); - assert!(ServiceRequest::from_request(r).is_ok()); - - let req = TestRequest::default().to_srv_request(); - let (r, _pl) = req.into_parts(); - let _r2 = r.clone(); - assert!(ServiceRequest::from_request(r).is_err()); - } - #[actix_rt::test] async fn test_service() { - let mut srv = init_service( + let srv = init_service( App::new().service(web::service("/test").name("test").finish( |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())), )), @@ -581,7 +619,7 @@ mod tests { let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), http::StatusCode::OK); - let mut srv = init_service( + let srv = init_service( App::new().service(web::service("/test").guard(guard::Get()).finish( |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())), )), @@ -596,20 +634,18 @@ mod tests { #[actix_rt::test] async fn test_service_data() { - let mut srv = init_service( - App::new() - .data(42u32) - .service(web::service("/test").name("test").finish( - |req: ServiceRequest| { - assert_eq!( - req.app_data::>().unwrap().as_ref(), - &42 - ); - ok(req.into_response(HttpResponse::Ok().finish())) - }, - )), - ) - .await; + let srv = + init_service( + App::new() + .data(42u32) + .service(web::service("/test").name("test").finish( + |req: ServiceRequest| { + assert_eq!(req.app_data::>().unwrap().as_ref(), &42); + ok(req.into_response(HttpResponse::Ok().finish())) + }, + )), + ) + .await; let req = TestRequest::with_uri("/test").to_request(); let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), http::StatusCode::OK); @@ -619,14 +655,14 @@ mod tests { fn test_fmt_debug() { let req = TestRequest::get() .uri("/index.html?test=1") - .header("x-test", "111") + .insert_header(("x-test", "111")) .to_srv_request(); let s = format!("{:?}", req); assert!(s.contains("ServiceRequest")); assert!(s.contains("test=1")); assert!(s.contains("x-test")); - let res = HttpResponse::Ok().header("x-test", "111").finish(); + let res = HttpResponse::Ok().insert_header(("x-test", "111")).finish(); let res = TestRequest::post() .uri("/index.html?test=1") .to_srv_response(res); @@ -635,4 +671,80 @@ mod tests { assert!(s.contains("ServiceResponse")); assert!(s.contains("x-test")); } + + #[actix_rt::test] + async fn test_services_macro() { + let scoped = services![ + web::service("/scoped_test1").name("scoped_test1").finish( + |req: ServiceRequest| async { + Ok(req.into_response(HttpResponse::Ok().finish())) + } + ), + web::resource("/scoped_test2").to(|| async { "test2" }), + ]; + + let services = services![ + web::service("/test1") + .name("test") + .finish(|req: ServiceRequest| async { + Ok(req.into_response(HttpResponse::Ok().finish())) + }), + web::resource("/test2").to(|| async { "test2" }), + web::scope("/test3").service(scoped) + ]; + + let srv = init_service(App::new().service(services)).await; + + let req = TestRequest::with_uri("/test1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::OK); + + let req = TestRequest::with_uri("/test2").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::OK); + + let req = TestRequest::with_uri("/test3/scoped_test1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::OK); + + let req = TestRequest::with_uri("/test3/scoped_test2").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::OK); + } + + #[actix_rt::test] + async fn test_services_vec() { + let services = vec![ + web::resource("/test1").to(|| async { "test1" }), + web::resource("/test2").to(|| async { "test2" }), + ]; + + let scoped = vec![ + web::resource("/scoped_test1").to(|| async { "test1" }), + web::resource("/scoped_test2").to(|| async { "test2" }), + ]; + + let srv = init_service( + App::new() + .service(services) + .service(web::scope("/test3").service(scoped)), + ) + .await; + + let req = TestRequest::with_uri("/test1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::OK); + + let req = TestRequest::with_uri("/test2").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::OK); + + let req = TestRequest::with_uri("/test3/scoped_test1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::OK); + + let req = TestRequest::with_uri("/test3/scoped_test2").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::OK); + } } diff --git a/src/test.rs b/src/test.rs index cff6c3e51..dd2426fec 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,20 +1,20 @@ //! Various helpers for Actix applications to use during testing. -use std::convert::TryFrom; + use std::net::SocketAddr; use std::rc::Rc; use std::sync::mpsc; use std::{fmt, net, thread, time}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use actix_http::http::header::{ContentType, Header, HeaderName, IntoHeaderValue}; -use actix_http::http::{Error as HttpError, Method, StatusCode, Uri, Version}; +#[cfg(feature = "cookies")] +use actix_http::cookie::Cookie; +use actix_http::http::header::{ContentType, IntoHeaderPair}; +use actix_http::http::{Method, StatusCode, Uri, Version}; use actix_http::test::TestRequest as HttpTestRequest; -use actix_http::{cookie::Cookie, ws, Extensions, HttpService, Request}; +use actix_http::{ws, Extensions, HttpService, Request}; use actix_router::{Path, ResourceDef, Url}; -use actix_rt::{time::delay_for, System}; -use actix_service::{ - map_config, IntoService, IntoServiceFactory, Service, ServiceFactory, -}; +use actix_rt::{time::sleep, System}; +use actix_service::{map_config, IntoService, IntoServiceFactory, Service, ServiceFactory}; use awc::error::PayloadError; use awc::{Client, ClientRequest, ClientResponse, Connector}; use bytes::{Bytes, BytesMut}; @@ -27,26 +27,24 @@ use socket2::{Domain, Protocol, Socket, Type}; pub use actix_http::test::TestBuffer; +use crate::app_service::AppInitServiceState; use crate::config::AppConfig; use crate::data::Data; use crate::dev::{Body, MessageBody, Payload, Server}; -use crate::request::HttpRequestPool; use crate::rmap::ResourceMap; use crate::service::{ServiceRequest, ServiceResponse}; use crate::{Error, HttpRequest, HttpResponse}; /// Create service that always responds with `HttpResponse::Ok()` pub fn ok_service( -) -> impl Service, Error = Error> -{ +) -> impl Service, Error = Error> { default_service(StatusCode::OK) } /// Create service that responds with response with specified status code pub fn default_service( status_code: StatusCode, -) -> impl Service, Error = Error> -{ +) -> impl Service, Error = Error> { (move |req: ServiceRequest| { ok(req.into_response(HttpResponse::build(status_code).finish())) }) @@ -62,7 +60,7 @@ pub fn default_service( /// /// #[actix_rt::test] /// async fn test_init_service() { -/// let mut app = test::init_service( +/// let app = test::init_service( /// App::new() /// .service(web::resource("/test").to(|| async { HttpResponse::Ok() })) /// ).await; @@ -77,15 +75,10 @@ pub fn default_service( /// ``` pub async fn init_service( app: R, -) -> impl Service, Error = E> +) -> impl Service, Error = E> where - R: IntoServiceFactory, - S: ServiceFactory< - Config = AppConfig, - Request = Request, - Response = ServiceResponse, - Error = E, - >, + R: IntoServiceFactory, + S: ServiceFactory, Error = E>, S::InitError: std::fmt::Debug, { try_init_service(app) @@ -96,18 +89,10 @@ where /// Fallible version of init_service that allows testing data factory errors. pub(crate) async fn try_init_service( app: R, -) -> Result< - impl Service, Error = E>, - S::InitError, -> +) -> Result, Error = E>, S::InitError> where - R: IntoServiceFactory, - S: ServiceFactory< - Config = AppConfig, - Request = Request, - Response = ServiceResponse, - Error = E, - >, + R: IntoServiceFactory, + S: ServiceFactory, Error = E>, S::InitError: std::fmt::Debug, { let srv = app.into_factory(); @@ -121,7 +106,7 @@ where /// /// #[actix_rt::test] /// async fn test_response() { -/// let mut app = test::init_service( +/// let app = test::init_service( /// App::new() /// .service(web::resource("/test").to(|| async { /// HttpResponse::Ok() @@ -132,13 +117,13 @@ where /// let req = test::TestRequest::with_uri("/test").to_request(); /// /// // Call application -/// let resp = test::call_service(&mut app, req).await; +/// let resp = test::call_service(&app, req).await; /// assert_eq!(resp.status(), StatusCode::OK); /// } /// ``` -pub async fn call_service(app: &mut S, req: R) -> S::Response +pub async fn call_service(app: &S, req: R) -> S::Response where - S: Service, Error = E>, + S: Service, Error = E>, E: std::fmt::Debug, { app.call(req).await.unwrap() @@ -152,7 +137,7 @@ where /// /// #[actix_rt::test] /// async fn test_index() { -/// let mut app = test::init_service( +/// let app = test::init_service( /// App::new().service( /// web::resource("/index.html") /// .route(web::post().to(|| async { @@ -165,13 +150,13 @@ where /// .header(header::CONTENT_TYPE, "application/json") /// .to_request(); /// -/// let result = test::read_response(&mut app, req).await; +/// let result = test::read_response(&app, req).await; /// assert_eq!(result, Bytes::from_static(b"welcome!")); /// } /// ``` -pub async fn read_response(app: &mut S, req: Request) -> Bytes +pub async fn read_response(app: &S, req: Request) -> Bytes where - S: Service, Error = Error>, + S: Service, Error = Error>, B: MessageBody + Unpin, { let mut resp = app @@ -195,7 +180,7 @@ where /// /// #[actix_rt::test] /// async fn test_index() { -/// let mut app = test::init_service( +/// let app = test::init_service( /// App::new().service( /// web::resource("/index.html") /// .route(web::post().to(|| async { @@ -208,7 +193,7 @@ where /// .header(header::CONTENT_TYPE, "application/json") /// .to_request(); /// -/// let resp = test::call_service(&mut app, req).await; +/// let resp = test::call_service(&app, req).await; /// let result = test::read_body(resp).await; /// assert_eq!(result, Bytes::from_static(b"welcome!")); /// } @@ -239,12 +224,12 @@ where /// /// #[actix_rt::test] /// async fn test_post_person() { -/// let mut app = test::init_service( +/// let app = test::init_service( /// App::new().service( /// web::resource("/people") /// .route(web::post().to(|person: web::Json| async { /// HttpResponse::Ok() -/// .json(person.into_inner())}) +/// .json(person)}) /// )) /// ).await; /// @@ -269,9 +254,8 @@ where { let body = read_body(res).await; - serde_json::from_slice(&body).unwrap_or_else(|e| { - panic!("read_response_json failed during deserialization: {}", e) - }) + serde_json::from_slice(&body) + .unwrap_or_else(|e| panic!("read_response_json failed during deserialization: {}", e)) } pub async fn load_stream(mut stream: S) -> Result @@ -299,12 +283,12 @@ where /// /// #[actix_rt::test] /// async fn test_add_person() { -/// let mut app = test::init_service( +/// let app = test::init_service( /// App::new().service( /// web::resource("/people") /// .route(web::post().to(|person: web::Json| async { /// HttpResponse::Ok() -/// .json(person.into_inner())}) +/// .json(person)}) /// )) /// ).await; /// @@ -319,9 +303,9 @@ where /// let result: Person = test::read_response_json(&mut app, req).await; /// } /// ``` -pub async fn read_response_json(app: &mut S, req: Request) -> T +pub async fn read_response_json(app: &S, req: Request) -> T where - S: Service, Error = Error>, + S: Service, Error = Error>, B: MessageBody + Unpin, T: DeserializeOwned, { @@ -354,7 +338,7 @@ where /// /// #[test] /// fn test_index() { -/// let req = test::TestRequest::with_header("content-type", "text/plain") +/// let req = test::TestRequest::default().insert_header("content-type", "text/plain") /// .to_http_request(); /// /// let resp = index(req).await.unwrap(); @@ -394,21 +378,6 @@ impl TestRequest { TestRequest::default().uri(path) } - /// Create TestRequest and set header - pub fn with_hdr(hdr: H) -> TestRequest { - TestRequest::default().set(hdr) - } - - /// Create TestRequest and set header - pub fn with_header(key: K, value: V) -> TestRequest - where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, - { - TestRequest::default().header(key, value) - } - /// Create TestRequest and set method to `Method::GET` pub fn get() -> TestRequest { TestRequest::default().method(Method::GET) @@ -452,24 +421,26 @@ impl TestRequest { self } - /// Set a header - pub fn set(mut self, hdr: H) -> Self { - self.req.set(hdr); - self - } - - /// Set a header - pub fn header(mut self, key: K, value: V) -> Self + /// Insert a header, replacing any that were set with an equivalent field name. + pub fn insert_header(mut self, header: H) -> Self where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, + H: IntoHeaderPair, { - self.req.header(key, value); + self.req.insert_header(header); self } - /// Set cookie for this request + /// Append a header, keeping any that were set with an equivalent field name. + pub fn append_header(mut self, header: H) -> Self + where + H: IntoHeaderPair, + { + self.req.append_header(header); + self + } + + /// Set cookie for this request. + #[cfg(feature = "cookies")] pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { self.req.cookie(cookie); self @@ -499,17 +470,16 @@ impl TestRequest { let bytes = serde_urlencoded::to_string(data) .expect("Failed to serialize test data as a urlencoded form"); self.req.set_payload(bytes); - self.req.set(ContentType::form_url_encoded()); + self.req.insert_header(ContentType::form_url_encoded()); self } /// Serialize `data` to JSON and set it as the request payload. The `Content-Type` header is /// set to `application/json`. pub fn set_json(mut self, data: &T) -> Self { - let bytes = - serde_json::to_string(data).expect("Failed to serialize test data to json"); + let bytes = serde_json::to_string(data).expect("Failed to serialize test data to json"); self.req.set_payload(bytes); - self.req.set(ContentType::json()); + self.req.insert_header(ContentType::json()); self } @@ -547,15 +517,12 @@ impl TestRequest { head.peer_addr = self.peer_addr; self.path.get_mut().update(&head.uri); - ServiceRequest::new(HttpRequest::new( - self.path, - head, + let app_state = AppInitServiceState::new(Rc::new(self.rmap), self.config.clone()); + + ServiceRequest::new( + HttpRequest::new(self.path, head, app_state, Rc::new(self.app_data)), payload, - Rc::new(self.rmap), - self.config.clone(), - Rc::new(self.app_data), - HttpRequestPool::create(), - )) + ) } /// Complete request creation and generate `ServiceResponse` instance @@ -565,19 +532,13 @@ impl TestRequest { /// Complete request creation and generate `HttpRequest` instance pub fn to_http_request(mut self) -> HttpRequest { - let (mut head, payload) = self.req.finish().into_parts(); + let (mut head, _) = self.req.finish().into_parts(); head.peer_addr = self.peer_addr; self.path.get_mut().update(&head.uri); - HttpRequest::new( - self.path, - head, - payload, - Rc::new(self.rmap), - self.config.clone(), - Rc::new(self.app_data), - HttpRequestPool::create(), - ) + let app_state = AppInitServiceState::new(Rc::new(self.rmap), self.config.clone()); + + HttpRequest::new(self.path, head, app_state, Rc::new(self.app_data)) } /// Complete request creation and generate `HttpRequest` and `Payload` instances @@ -586,23 +547,17 @@ impl TestRequest { head.peer_addr = self.peer_addr; self.path.get_mut().update(&head.uri); - let req = HttpRequest::new( - self.path, - head, - Payload::None, - Rc::new(self.rmap), - self.config.clone(), - Rc::new(self.app_data), - HttpRequestPool::create(), - ); + let app_state = AppInitServiceState::new(Rc::new(self.rmap), self.config.clone()); + + let req = HttpRequest::new(self.path, head, app_state, Rc::new(self.app_data)); (req, payload) } /// Complete request creation, calls service and waits for response future completion. - pub async fn send_request(self, app: &mut S) -> S::Response + pub async fn send_request(self, app: &S) -> S::Response where - S: Service, Error = E>, + S: Service, Error = E>, E: std::fmt::Debug, { let req = self.to_request(); @@ -626,7 +581,7 @@ impl TestRequest { /// /// #[actix_rt::test] /// async fn test_example() { -/// let mut srv = test::start( +/// let srv = test::start( /// || App::new().service( /// web::resource("/").to(my_handler)) /// ); @@ -639,12 +594,12 @@ impl TestRequest { pub fn start(factory: F) -> TestServer where F: Fn() -> I + Send + Clone + 'static, - I: IntoServiceFactory, - S: ServiceFactory + 'static, + I: IntoServiceFactory, + S: ServiceFactory + 'static, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, { start_with(TestServerConfig::default(), factory) @@ -666,7 +621,7 @@ where /// /// #[actix_rt::test] /// async fn test_example() { -/// let mut srv = test::start_with(test::config().h1(), || +/// let srv = test::start_with(test::config().h1(), || /// App::new().service(web::resource("/").to(my_handler)) /// ); /// @@ -678,12 +633,12 @@ where pub fn start_with(cfg: TestServerConfig, factory: F) -> TestServer where F: Fn() -> I + Send + Clone + 'static, - I: IntoServiceFactory, - S: ServiceFactory + 'static, + I: IntoServiceFactory, + S: ServiceFactory + 'static, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - ::Future: 'static, + >::Future: 'static, B: MessageBody + 'static, { let (tx, rx) = mpsc::channel(); @@ -698,7 +653,7 @@ where // run server in separate thread thread::spawn(move || { - let sys = System::new("actix-test-server"); + let sys = System::new(); let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); let local_addr = tcp.local_addr().unwrap(); let factory = factory.clone(); @@ -709,24 +664,21 @@ where let srv = match cfg.stream { StreamType::Tcp => match cfg.tp { HttpVer::Http1 => builder.listen("test", tcp, move || { - let cfg = - AppConfig::new(false, local_addr, format!("{}", local_addr)); + let cfg = AppConfig::new(false, local_addr, format!("{}", local_addr)); HttpService::build() .client_timeout(ctimeout) .h1(map_config(factory(), move |_| cfg.clone())) .tcp() }), HttpVer::Http2 => builder.listen("test", tcp, move || { - let cfg = - AppConfig::new(false, local_addr, format!("{}", local_addr)); + let cfg = AppConfig::new(false, local_addr, format!("{}", local_addr)); HttpService::build() .client_timeout(ctimeout) .h2(map_config(factory(), move |_| cfg.clone())) .tcp() }), HttpVer::Both => builder.listen("test", tcp, move || { - let cfg = - AppConfig::new(false, local_addr, format!("{}", local_addr)); + let cfg = AppConfig::new(false, local_addr, format!("{}", local_addr)); HttpService::build() .client_timeout(ctimeout) .finish(map_config(factory(), move |_| cfg.clone())) @@ -736,24 +688,21 @@ where #[cfg(feature = "openssl")] StreamType::Openssl(acceptor) => match cfg.tp { HttpVer::Http1 => builder.listen("test", tcp, move || { - let cfg = - AppConfig::new(true, local_addr, format!("{}", local_addr)); + let cfg = AppConfig::new(true, local_addr, format!("{}", local_addr)); HttpService::build() .client_timeout(ctimeout) .h1(map_config(factory(), move |_| cfg.clone())) .openssl(acceptor.clone()) }), HttpVer::Http2 => builder.listen("test", tcp, move || { - let cfg = - AppConfig::new(true, local_addr, format!("{}", local_addr)); + let cfg = AppConfig::new(true, local_addr, format!("{}", local_addr)); HttpService::build() .client_timeout(ctimeout) .h2(map_config(factory(), move |_| cfg.clone())) .openssl(acceptor.clone()) }), HttpVer::Both => builder.listen("test", tcp, move || { - let cfg = - AppConfig::new(true, local_addr, format!("{}", local_addr)); + let cfg = AppConfig::new(true, local_addr, format!("{}", local_addr)); HttpService::build() .client_timeout(ctimeout) .finish(map_config(factory(), move |_| cfg.clone())) @@ -763,24 +712,21 @@ where #[cfg(feature = "rustls")] StreamType::Rustls(config) => match cfg.tp { HttpVer::Http1 => builder.listen("test", tcp, move || { - let cfg = - AppConfig::new(true, local_addr, format!("{}", local_addr)); + let cfg = AppConfig::new(true, local_addr, format!("{}", local_addr)); HttpService::build() .client_timeout(ctimeout) .h1(map_config(factory(), move |_| cfg.clone())) .rustls(config.clone()) }), HttpVer::Http2 => builder.listen("test", tcp, move || { - let cfg = - AppConfig::new(true, local_addr, format!("{}", local_addr)); + let cfg = AppConfig::new(true, local_addr, format!("{}", local_addr)); HttpService::build() .client_timeout(ctimeout) .h2(map_config(factory(), move |_| cfg.clone())) .rustls(config.clone()) }), HttpVer::Both => builder.listen("test", tcp, move || { - let cfg = - AppConfig::new(true, local_addr, format!("{}", local_addr)); + let cfg = AppConfig::new(true, local_addr, format!("{}", local_addr)); HttpService::build() .client_timeout(ctimeout) .finish(map_config(factory(), move |_| cfg.clone())) @@ -788,10 +734,13 @@ where }), }, } - .unwrap() - .start(); + .unwrap(); + + sys.block_on(async { + let srv = srv.run(); + tx.send((System::current(), srv, local_addr)).unwrap(); + }); - tx.send((System::current(), srv, local_addr)).unwrap(); sys.run() }); @@ -801,7 +750,7 @@ where let connector = { #[cfg(feature = "openssl")] { - use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); builder.set_verify(SslVerifyMode::NONE); @@ -812,14 +761,12 @@ where .conn_lifetime(time::Duration::from_secs(0)) .timeout(time::Duration::from_millis(30000)) .ssl(builder.build()) - .finish() } #[cfg(not(feature = "openssl"))] { Connector::new() .conn_lifetime(time::Duration::from_secs(0)) .timeout(time::Duration::from_millis(30000)) - .finish() } }; @@ -853,9 +800,9 @@ enum HttpVer { enum StreamType { Tcp, #[cfg(feature = "openssl")] - Openssl(open_ssl::ssl::SslAcceptor), + Openssl(openssl::ssl::SslAcceptor), #[cfg(feature = "rustls")] - Rustls(rust_tls::ServerConfig), + Rustls(rustls::ServerConfig), } impl Default for TestServerConfig { @@ -879,13 +826,13 @@ impl TestServerConfig { } } - /// Start http/1.1 server only + /// Start HTTP/1.1 server only pub fn h1(mut self) -> Self { self.tp = HttpVer::Http1; self } - /// Start http/2 server only + /// Start HTTP/2 server only pub fn h2(mut self) -> Self { self.tp = HttpVer::Http2; self @@ -893,14 +840,14 @@ impl TestServerConfig { /// Start openssl server #[cfg(feature = "openssl")] - pub fn openssl(mut self, acceptor: open_ssl::ssl::SslAcceptor) -> Self { + pub fn openssl(mut self, acceptor: openssl::ssl::SslAcceptor) -> Self { self.stream = StreamType::Openssl(acceptor); self } /// Start rustls server #[cfg(feature = "rustls")] - pub fn rustls(mut self, config: rust_tls::ServerConfig) -> Self { + pub fn rustls(mut self, config: rustls::ServerConfig) -> Self { self.stream = StreamType::Rustls(config); self } @@ -915,8 +862,7 @@ impl TestServerConfig { /// Get first available unused address pub fn unused_addr() -> net::SocketAddr { let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); - let socket = - Socket::new(Domain::ipv4(), Type::stream(), Some(Protocol::tcp())).unwrap(); + let socket = Socket::new(Domain::ipv4(), Type::stream(), Some(Protocol::tcp())).unwrap(); socket.bind(&addr.into()).unwrap(); socket.set_reuse_address(true).unwrap(); let tcp = socket.into_tcp_listener(); @@ -984,7 +930,7 @@ impl TestServer { self.client.options(self.url(path.as_ref()).as_str()) } - /// Connect to test http server + /// Connect to test HTTP server pub fn request>(&self, method: Method, path: S) -> ClientRequest { self.client.request(method, path.as_ref()) } @@ -999,30 +945,28 @@ impl TestServer { response.body().limit(10_485_760).await } - /// Connect to websocket server at a given path + /// Connect to WebSocket server at a given path. pub async fn ws_at( &mut self, path: &str, - ) -> Result, awc::error::WsClientError> - { + ) -> Result, awc::error::WsClientError> { let url = self.url(path); let connect = self.client.ws(url).connect(); connect.await.map(|(_, framed)| framed) } - /// Connect to a websocket server + /// Connect to a WebSocket server. pub async fn ws( &mut self, - ) -> Result, awc::error::WsClientError> - { + ) -> Result, awc::error::WsClientError> { self.ws_at("/").await } - /// Gracefully stop http server + /// Gracefully stop HTTP server pub async fn stop(self) { self.server.stop(true).await; self.system.stop(); - delay_for(time::Duration::from_millis(100)).await; + sleep(time::Duration::from_millis(100)).await; } } @@ -1034,7 +978,7 @@ impl Drop for TestServer { #[cfg(test)] mod tests { - use actix_http::httpmessage::HttpMessage; + use actix_http::HttpMessage; use serde::{Deserialize, Serialize}; use std::time::SystemTime; @@ -1043,9 +987,10 @@ mod tests { #[actix_rt::test] async fn test_basics() { - let req = TestRequest::with_hdr(header::ContentType::json()) + let req = TestRequest::default() .version(Version::HTTP_2) - .set(header::Date(SystemTime::now().into())) + .insert_header(header::ContentType::json()) + .insert_header(header::Date(SystemTime::now().into())) .param("test", "123") .data(10u32) .app_data(20u64) @@ -1070,7 +1015,7 @@ mod tests { #[actix_rt::test] async fn test_request_methods() { - let mut app = init_service( + let app = init_service( App::new().service( web::resource("/index.html") .route(web::put().to(|| HttpResponse::Ok().body("put!"))) @@ -1082,28 +1027,28 @@ mod tests { let put_req = TestRequest::put() .uri("/index.html") - .header(header::CONTENT_TYPE, "application/json") + .insert_header((header::CONTENT_TYPE, "application/json")) .to_request(); - let result = read_response(&mut app, put_req).await; + let result = read_response(&app, put_req).await; assert_eq!(result, Bytes::from_static(b"put!")); let patch_req = TestRequest::patch() .uri("/index.html") - .header(header::CONTENT_TYPE, "application/json") + .insert_header((header::CONTENT_TYPE, "application/json")) .to_request(); - let result = read_response(&mut app, patch_req).await; + let result = read_response(&app, patch_req).await; assert_eq!(result, Bytes::from_static(b"patch!")); let delete_req = TestRequest::delete().uri("/index.html").to_request(); - let result = read_response(&mut app, delete_req).await; + let result = read_response(&app, delete_req).await; assert_eq!(result, Bytes::from_static(b"delete!")); } #[actix_rt::test] async fn test_response() { - let mut app = init_service( + let app = init_service( App::new().service( web::resource("/index.html") .route(web::post().to(|| HttpResponse::Ok().body("welcome!"))), @@ -1113,16 +1058,16 @@ mod tests { let req = TestRequest::post() .uri("/index.html") - .header(header::CONTENT_TYPE, "application/json") + .insert_header((header::CONTENT_TYPE, "application/json")) .to_request(); - let result = read_response(&mut app, req).await; + let result = read_response(&app, req).await; assert_eq!(result, Bytes::from_static(b"welcome!")); } #[actix_rt::test] async fn test_send_request() { - let mut app = init_service( + let app = init_service( App::new().service( web::resource("/index.html") .route(web::get().to(|| HttpResponse::Ok().body("welcome!"))), @@ -1132,7 +1077,7 @@ mod tests { let resp = TestRequest::get() .uri("/index.html") - .send_request(&mut app) + .send_request(&app) .await; let result = read_body(resp).await; @@ -1147,10 +1092,8 @@ mod tests { #[actix_rt::test] async fn test_response_json() { - let mut app = init_service(App::new().service(web::resource("/people").route( - web::post().to(|person: web::Json| { - HttpResponse::Ok().json(person.into_inner()) - }), + let app = init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Json| HttpResponse::Ok().json(person)), ))) .await; @@ -1158,20 +1101,18 @@ mod tests { let req = TestRequest::post() .uri("/people") - .header(header::CONTENT_TYPE, "application/json") + .insert_header((header::CONTENT_TYPE, "application/json")) .set_payload(payload) .to_request(); - let result: Person = read_response_json(&mut app, req).await; + let result: Person = read_response_json(&app, req).await; assert_eq!(&result.id, "12345"); } #[actix_rt::test] async fn test_body_json() { - let mut app = init_service(App::new().service(web::resource("/people").route( - web::post().to(|person: web::Json| { - HttpResponse::Ok().json(person.into_inner()) - }), + let app = init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Json| HttpResponse::Ok().json(person)), ))) .await; @@ -1179,9 +1120,9 @@ mod tests { let resp = TestRequest::post() .uri("/people") - .header(header::CONTENT_TYPE, "application/json") + .insert_header((header::CONTENT_TYPE, "application/json")) .set_payload(payload) - .send_request(&mut app) + .send_request(&app) .await; let result: Person = read_body_json(resp).await; @@ -1190,10 +1131,8 @@ mod tests { #[actix_rt::test] async fn test_request_response_form() { - let mut app = init_service(App::new().service(web::resource("/people").route( - web::post().to(|person: web::Form| { - HttpResponse::Ok().json(person.into_inner()) - }), + let app = init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Form| HttpResponse::Ok().json(person)), ))) .await; @@ -1209,17 +1148,15 @@ mod tests { assert_eq!(req.content_type(), "application/x-www-form-urlencoded"); - let result: Person = read_response_json(&mut app, req).await; + let result: Person = read_response_json(&app, req).await; assert_eq!(&result.id, "12345"); assert_eq!(&result.name, "User name"); } #[actix_rt::test] async fn test_request_response_json() { - let mut app = init_service(App::new().service(web::resource("/people").route( - web::post().to(|person: web::Json| { - HttpResponse::Ok().json(person.into_inner()) - }), + let app = init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Json| HttpResponse::Ok().json(person)), ))) .await; @@ -1235,7 +1172,7 @@ mod tests { assert_eq!(req.content_type(), "application/json"); - let result: Person = read_response_json(&mut app, req).await; + let result: Person = read_response_json(&app, req).await; assert_eq!(&result.id, "12345"); assert_eq!(&result.name, "User name"); } @@ -1248,15 +1185,14 @@ mod tests { match res { Ok(value) => Ok(HttpResponse::Ok() .content_type("text/plain") - .body(format!("Async with block value: {}", value))), + .body(format!("Async with block value: {:?}", value))), Err(_) => panic!("Unexpected"), } } - let mut app = init_service( - App::new().service(web::resource("/index.html").to(async_with_block)), - ) - .await; + let app = + init_service(App::new().service(web::resource("/index.html").to(async_with_block))) + .await; let req = TestRequest::post().uri("/index.html").to_request(); let res = app.call(req).await.unwrap(); @@ -1270,7 +1206,7 @@ mod tests { HttpResponse::Ok() } - let mut app = init_service( + let app = init_service( App::new() .data(10usize) .service(web::resource("/index.html").to(handler)), @@ -1281,55 +1217,4 @@ mod tests { let res = app.call(req).await.unwrap(); assert!(res.status().is_success()); } - - #[actix_rt::test] - async fn test_actor() { - use crate::Error; - use actix::prelude::*; - - struct MyActor; - - impl Actor for MyActor { - type Context = Context; - } - - struct Num(usize); - - impl Message for Num { - type Result = usize; - } - - impl Handler for MyActor { - type Result = usize; - - fn handle(&mut self, msg: Num, _: &mut Self::Context) -> Self::Result { - msg.0 - } - } - - let addr = MyActor.start(); - - async fn actor_handler( - addr: Data>, - ) -> Result { - // `?` operator tests "actors" feature flag on actix-http - let res = addr.send(Num(1)).await?; - - if res == 1 { - Ok(HttpResponse::Ok()) - } else { - Ok(HttpResponse::BadRequest()) - } - } - - let srv = App::new() - .data(addr.clone()) - .service(web::resource("/").to(actor_handler)); - - let mut app = init_service(srv).await; - - let req = TestRequest::post().uri("/").to_request(); - let res = app.call(req).await.unwrap(); - assert!(res.status().is_success()); - } } diff --git a/src/types/either.rs b/src/types/either.rs index 9f1d81a0b..bbab48dec 100644 --- a/src/types/either.rs +++ b/src/types/either.rs @@ -1,150 +1,184 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, +//! For either helper, see [`Either`]. + +use bytes::Bytes; +use futures_util::{future::LocalBoxFuture, FutureExt, TryFutureExt}; + +use crate::{ + dev, + web::{Form, Json}, + Error, FromRequest, HttpRequest, HttpResponse, Responder, }; -use actix_http::{Error, Response}; -use bytes::Bytes; -use futures_util::{future::LocalBoxFuture, ready, FutureExt, TryFutureExt}; -use pin_project::pin_project; - -use crate::{dev, request::HttpRequest, FromRequest, Responder}; - -/// Combines two different responder types into a single type +/// Combines two extractor or responder types into a single type. /// -/// ```rust -/// use actix_web::{Either, Error, HttpResponse}; +/// Can be converted to and from an [`either::Either`]. /// -/// type RegisterResult = Either>; +/// # Extractor +/// Provides a mechanism for trying two extractors, a primary and a fallback. Useful for +/// "polymorphic payloads" where, for example, a form might be JSON or URL encoded. /// -/// fn index() -> RegisterResult { -/// if is_a_variant() { -/// // <- choose left variant -/// Either::A(HttpResponse::BadRequest().body("Bad data")) +/// It is important to note that this extractor, by necessity, buffers the entire request payload +/// as part of its implementation. Though, it does respect any `PayloadConfig` maximum size limits. +/// +/// ``` +/// use actix_web::{post, web, Either}; +/// use serde::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Info { +/// name: String, +/// } +/// +/// // handler that accepts form as JSON or form-urlencoded. +/// #[post("/")] +/// async fn index(form: Either, web::Form>) -> String { +/// let name: String = match form { +/// Either::Left(json) => json.name.to_owned(), +/// Either::Right(form) => form.name.to_owned(), +/// }; +/// +/// format!("Welcome {}!", name) +/// } +/// ``` +/// +/// # Responder +/// It may be desireable to use a concrete type for a response with multiple branches. As long as +/// both types implement `Responder`, so will the `Either` type, enabling it to be used as a +/// handler's return type. +/// +/// All properties of a response are determined by the Responder branch returned. +/// +/// ``` +/// use actix_web::{get, Either, Error, HttpResponse}; +/// +/// #[get("/")] +/// async fn index() -> Either<&'static str, Result> { +/// if 1 == 2 { +/// // respond with Left variant +/// Either::Left("Bad data") /// } else { -/// Either::B( -/// // <- Right variant +/// // respond with Right variant +/// Either::Right( /// Ok(HttpResponse::Ok() -/// .content_type("text/html") -/// .body("Hello!")) +/// .content_type(mime::TEXT_HTML) +/// .body("